Skip to content

Commit c7bb250

Browse files
committed
Updated tests
1 parent d7d0300 commit c7bb250

File tree

5 files changed

+61
-6
lines changed

5 files changed

+61
-6
lines changed

scripts/smart-dispatch

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def parse_arguments():
209209

210210
parser.add_argument('-p', '--pool', type=int, help="Number of workers that will be consuming commands. Default: Nb commands")
211211
parser.add_argument('--pbsFlags', type=str, help='ADVANCED USAGE: Allow to pass a space seperated list of PBS flags. Ex:--pbsFlags="-lfeature=k80 -t0-4"')
212-
parser.add_argument('--sbatchFlags', type=str, help='ADVANCED USAGE: Allow to pass a space seperated list of SBATCH flags. Ex:--sbatchFlags="-qos=high --output=file.out"')
212+
parser.add_argument('--sbatchFlags', type=str, help='ADVANCED USAGE: Allow to pass a space seperated list of SBATCH flags. Ex:--sbatchFlags="--qos=high --ofile.out"')
213213
subparsers = parser.add_subparsers(dest="mode")
214214

215215
launch_parser = subparsers.add_parser('launch', help="Launch jobs.")

smartdispatch/job_generator.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,13 @@ def add_sbatch_flags(self, flags):
7979
for flag in flags:
8080
split = flag.find('=')
8181
if flag.startswith('--'):
82-
options[flag[2:split]] = flag[split+1:]
83-
elif flag.startswith('-'):
84-
options[flag[1:split]] = flag[split+1:]
82+
if split == -1:
83+
raise ValueError("Invalid SBATCH flag ({})".format(flag))
84+
options[flag[:split].lstrip("-")] = flag[split+1:]
85+
elif flag.startswith('-') and split == -1:
86+
options[flag[1:2]] = flag[2:]
8587
else:
86-
raise ValueError("Invalid SBATCH flag ({})".format(flag))
88+
raise ValueError("Invalid SBATCH flag ({}, is it a PBS flag?)".format(flag))
8789

8890
for pbs in self.pbs_list:
8991
pbs.add_sbatch_options(**options)

smartdispatch/pbs.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,10 @@ def __str__(self):
178178
pbs += ["#PBS -l {0}={1}".format(resource_name, resource_value)]
179179

180180
for option_name, option_value in self.sbatch_options.items():
181-
pbs += ["#SBATCH {0}={1}".format(option_name, option_value)]
181+
if option_name.startswith('--'):
182+
pbs += ["#SBATCH {0}={1}".format(option_name, option_value)]
183+
else:
184+
pbs += ["#SBATCH {0} {1}".format(option_name, option_value)]
182185

183186
pbs += ["\n# Modules #"]
184187
for module in self.modules:

smartdispatch/tests/test_job_generator.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
class TestJobGenerator(object):
1515
pbs_flags = ['-lfeature=k80', '-lwalltime=42:42', '-lnodes=6:gpus=66', '-m', '-A123-asd-11', '-t10,20,30']
16+
sbatch_flags = ['--qos=high', '--output=file.out', '-Cminmemory']
1617

1718
def setUp(self):
1819
self.testing_dir = tempfile.mkdtemp()
@@ -129,6 +130,32 @@ def test_add_pbs_flags_invalid(self):
129130
def test_add_pbs_flags_invalid_resource(self):
130131
assert_raises(ValueError, self._test_add_pbs_flags, '-l weeee')
131132

133+
def _test_add_sbatch_flags(self, flags):
134+
job_generator = JobGenerator(self.queue, self.commands)
135+
job_generator.add_sbatch_flags(flags)
136+
options = []
137+
138+
for flag in flags:
139+
if flag.startswith('--'):
140+
options += [flag]
141+
elif flag.startswith('-'):
142+
options += [(flag[:2] + ' ' + flag[2:]).strip()]
143+
144+
for pbs in job_generator.pbs_list:
145+
pbs_str = pbs.__str__()
146+
for flag in options:
147+
assert_equal(pbs_str.count(flag), 1)
148+
149+
def test_add_sbatch_flags(self):
150+
for flag in self.sbatch_flags:
151+
yield self._test_add_sbatch_flags, [flag]
152+
153+
yield self._test_add_sbatch_flags, [flag]
154+
155+
def test_add_sbatch_flag_invalid(self):
156+
invalid_flags = ["--qos high", "gpu", "-lfeature=k80"]
157+
for flag in invalid_flags:
158+
assert_raises(ValueError, self._test_add_sbatch_flags, "--qos high")
132159

133160
class TestGuilliminQueue(object):
134161

@@ -244,6 +271,28 @@ def test_pbs_split_2_job_nb_commands(self):
244271
assert_true("ppn=6" in str(self.pbs8[0]))
245272
assert_true("ppn=2" in str(self.pbs8[1]))
246273

274+
class TestSlurmQueue(object):
275+
276+
def setUp(self):
277+
self.walltime = "10:00"
278+
self.cores = 42
279+
self.mem_per_node = 32
280+
self.nb_cores_per_node = 1
281+
self.nb_gpus_per_node = 2
282+
self.queue = Queue("slurm", "mila", self.walltime, self.nb_cores_per_node, self.nb_gpus_per_node, self.mem_per_node)
283+
284+
self.commands = ["echo 1", "echo 2", "echo 3", "echo 4"]
285+
job_generator = SlurmJobGenerator(self.queue, self.commands)
286+
self.pbs = job_generator.pbs_list
287+
288+
def test_ppn_ncpus(self):
289+
assert_true("ppn" not in str(self.pbs[0]))
290+
assert_true("ncpus" in str(self.pbs[0]))
291+
292+
def test_gpus_naccelerators(self):
293+
assert_true("gpus" not in str(self.pbs[0]))
294+
assert_true("naccelerators" in str(self.pbs[0]))
295+
247296
class TestJobGeneratorFactory(object):
248297

249298
def setUp(self):

smartdispatch/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def detect_cluster():
115115
output = Popen(["qstat", "-B"], stdout=PIPE).communicate()[0]
116116
except OSError:
117117
# If qstat is not available we assume that the cluster is unknown.
118+
# TODO: handle MILA + CEDAR + GRAHAM
118119
cluster_name = get_slurm_cluster_name()
119120
return None
120121
# Get server name from status

0 commit comments

Comments
 (0)