Skip to content

Commit a20405b

Browse files
committed
Updated tests using mock
1 parent c7bb250 commit a20405b

File tree

3 files changed

+61
-8
lines changed

3 files changed

+61
-8
lines changed

smartdispatch/job_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def add_sbatch_flags(self, flags):
8080
split = flag.find('=')
8181
if flag.startswith('--'):
8282
if split == -1:
83-
raise ValueError("Invalid SBATCH flag ({})".format(flag))
83+
raise ValueError("Invalid SBATCH flag ({}), no '=' character found' ".format(flag))
8484
options[flag[:split].lstrip("-")] = flag[split+1:]
8585
elif flag.startswith('-') and split == -1:
8686
options[flag[1:2]] = flag[2:]

smartdispatch/tests/test_job_generator.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
from nose.tools import assert_true, assert_false, assert_equal, assert_raises
2-
import unittest
3-
42
import os
5-
import tempfile
63
import shutil
4+
import tempfile
5+
import unittest
6+
7+
try:
8+
from mock import patch
9+
except ImportError:
10+
from unittest.mock import patch
11+
712
from smartdispatch.queue import Queue
813
from smartdispatch.job_generator import JobGenerator, job_generator_factory
914
from smartdispatch.job_generator import HeliosJobGenerator, HadesJobGenerator
@@ -155,7 +160,7 @@ def test_add_sbatch_flags(self):
155160
def test_add_sbatch_flag_invalid(self):
156161
invalid_flags = ["--qos high", "gpu", "-lfeature=k80"]
157162
for flag in invalid_flags:
158-
assert_raises(ValueError, self._test_add_sbatch_flags, "--qos high")
163+
assert_raises(ValueError, self._test_add_sbatch_flags, flag)
159164

160165
class TestGuilliminQueue(object):
161166

@@ -285,9 +290,15 @@ def setUp(self):
285290
job_generator = SlurmJobGenerator(self.queue, self.commands)
286291
self.pbs = job_generator.pbs_list
287292

293+
with patch.object(SlurmJobGenerator,'_add_cluster_specific_rules', side_effect=lambda: None):
294+
dummy_generator = SlurmJobGenerator(self.queue, self.commands)
295+
self.dummy_pbs = dummy_generator.pbs_list
296+
288297
def test_ppn_ncpus(self):
289298
assert_true("ppn" not in str(self.pbs[0]))
290299
assert_true("ncpus" in str(self.pbs[0]))
300+
assert_true("ppn" in str(self.dummy_pbs[0]))
301+
assert_true("ncpus" not in str(self.dummy_pbs[0]))
291302

292303
def test_gpus_naccelerators(self):
293304
assert_true("gpus" not in str(self.pbs[0]))

smartdispatch/tests/test_utils.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
# -*- coding: utf-8 -*-
22
import unittest
3-
4-
from smartdispatch import utils
5-
3+
try:
4+
from mock import patch
5+
import mock
6+
except ImportError:
7+
from unittest.mock import patch
8+
import unittest.mock
69
from nose.tools import assert_equal, assert_true
710
from numpy.testing import assert_array_equal
11+
import subprocess
812

13+
from smartdispatch import utils
914

1015
class PrintBoxedTests(unittest.TestCase):
1116

@@ -49,3 +54,40 @@ def test_slugify():
4954

5055
for arg, expected in testing_arguments:
5156
assert_equal(utils.slugify(arg), expected)
57+
58+
command_output = """\
59+
Server Max Tot Que Run Hld Wat Trn Ext Com Status
60+
---------------- --- --- --- --- --- --- --- --- --- ----------
61+
gpu-srv1.{} 0 1674 524 121 47 0 0 22 960 Idle
62+
"""
63+
64+
slurm_command = """\
65+
Cluster ControlHost ControlPort RPC Share GrpJobs GrpTRES GrpSubmit MaxJobs MaxTRES MaxSubmit MaxWall QOS Def QOS
66+
---------- --------------- ------------ ----- --------- ------- ------------- --------- ------- ------------- --------- ----------- -------------------- ---------
67+
{} 132.204.24.224 6817 7680 1 normal
68+
"""
69+
70+
71+
class ClusterIdentificationTest(unittest.TestCase):
72+
73+
def test_detect_cluster(self):
74+
server_name = ["hades", "m", "guil", "helios", "hades"]
75+
clusters = ["hades", "mammouth", "guillimin", "helios"]
76+
77+
for name, cluster in zip(server_name, clusters):
78+
with patch('smartdispatch.utils.Popen') as mock_communicate:
79+
mock_communicate.return_value.communicate.return_value = (command_output.format(name),)
80+
self.assertEquals(utils.detect_cluster(), cluster)
81+
82+
# def test_detect_mila_cluster(self):
83+
# with patch('smartdispatch.utils.Popen') as mock_communicate:
84+
# mock_communicate.return_value.communicate.side_effect = OSError
85+
# self.assertIsNone(utils.detect_cluster())
86+
87+
def test_get_slurm_cluster_name(self):
88+
clusters = ["graham", "cedar", "mila"]
89+
90+
for cluster in clusters:
91+
with patch('smartdispatch.utils.Popen') as mock_communicate:
92+
mock_communicate.return_value.communicate.return_value = (slurm_command.format(cluster),)
93+
self.assertEquals(utils.get_slurm_cluster_name(), cluster)

0 commit comments

Comments
 (0)