Skip to content

Commit d7d0300

Browse files
committed
Cleaned code with PR feedback
1 parent 04f0ff1 commit d7d0300

File tree

6 files changed

+38
-35
lines changed

6 files changed

+38
-35
lines changed

smartdispatch/job_generator.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,12 @@ def add_sbatch_flags(self, flags):
7878

7979
for flag in flags:
8080
split = flag.find('=')
81-
options[flag[:split]] = flag[split+1:]
81+
if flag.startswith('--'):
82+
options[flag[2:split]] = flag[split+1:]
83+
elif flag.startswith('-'):
84+
options[flag[1:split]] = flag[split+1:]
85+
else:
86+
raise ValueError("Invalid SBATCH flag ({})".format(flag))
8287

8388
for pbs in self.pbs_list:
8489
pbs.add_sbatch_options(**options)
@@ -182,12 +187,13 @@ def _add_cluster_specific_rules(self):
182187
# Remove forbidden ppn option. Default is 2 cores per gpu.
183188
pbs.resources['nodes'] = re.sub(":ppn=[0-9]+", "", pbs.resources['nodes'])
184189

185-
class SlurmClusterGenerator(JobGenerator):
190+
class SlurmJobGenerator(JobGenerator):
186191

187192
def _add_cluster_specific_rules(self):
188193
for pbs in self.pbs_list:
189-
node_resource = pbs.resources.pop('nodes')
190-
gpus = re.match(".*gpus=([0-9]+)", node_resource).group(1)
191-
ppn = re.match(".*ppn=([0-9]+)", node_resource).group(1)
194+
gpus = re.match(".*gpus=([0-9]+)", pbs.resources['nodes']).group(1)
195+
ppn = re.match(".*ppn=([0-9]+)", pbs.resources['nodes']).group(1)
196+
pbs.resources['nodes'] = re.sub("ppn=[0-9]+", "", pbs.resources['nodes'])
197+
pbs.resources['nodes'] = re.sub(":gpus=[0-9]+", "", pbs.resources['nodes'])
192198
pbs.add_resources(naccelerators=gpus)
193199
pbs.add_resources(ncpus=ppn)

smartdispatch/pbs.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,14 @@ def add_sbatch_options(self, **options):
7070
Parameters
7171
----------
7272
**options : dict
73-
each key is the name of a SBATCH option (see `Options`)
73+
each key is the name of a SBATCH option
7474
"""
7575

7676
for option_name, option_value in options.items():
77+
if len(option_name) == 1:
78+
self.sbatch_options["-" + option_name] = option_value
79+
else:
80+
self.sbatch_options["--" + option_name] = option_value
7781
self.sbatch_options[option_name] = option_value
7882

7983
def add_resources(self, **resources):

smartdispatch/tests/test_job_generator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from nose.tools import assert_true, assert_false, assert_equal, assert_raises
2+
import unittest
23

34
import os
45
import tempfile
@@ -7,6 +8,7 @@
78
from smartdispatch.job_generator import JobGenerator, job_generator_factory
89
from smartdispatch.job_generator import HeliosJobGenerator, HadesJobGenerator
910
from smartdispatch.job_generator import GuilliminJobGenerator, MammouthJobGenerator
11+
from smartdispatch.job_generator import SlurmJobGenerator
1012

1113

1214
class TestJobGenerator(object):
@@ -242,7 +244,6 @@ def test_pbs_split_2_job_nb_commands(self):
242244
assert_true("ppn=6" in str(self.pbs8[0]))
243245
assert_true("ppn=2" in str(self.pbs8[1]))
244246

245-
246247
class TestJobGeneratorFactory(object):
247248

248249
def setUp(self):

smartdispatch/tests/test_pbs.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from nose.tools import assert_true, assert_equal, assert_raises
22
from numpy.testing import assert_array_equal
33

4-
54
from smartdispatch.pbs import PBS
65
import unittest
76
import tempfile
@@ -38,6 +37,13 @@ def test_add_options(self):
3837
assert_equal(self.pbs.options["-A"], "option2")
3938
assert_equal(self.pbs.options["-B"], "option3")
4039

40+
def test_add_sbatch_options(self):
41+
self.pbs.add_sbatch_options(a="value1")
42+
assert_equal(self.pbs.sbatch_options["-a"], "value1")
43+
self.pbs.add_sbatch_options(option1="value2", option2="value3")
44+
assert_equal(self.pbs.sbatch_options["--option1"], "value2")
45+
assert_equal(self.pbs.sbatch_options["--option2"], "value3")
46+
4147
def test_add_resources(self):
4248
assert_equal(len(self.pbs.resources), 1)
4349
assert_equal(self.pbs.resources["walltime"], self.walltime)

smartdispatch/tests/pbs_slurm_test.py renamed to smartdispatch/tests/verify_slurms_pbs_wrapper.py

Lines changed: 7 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
import os
33
import time
44
import unittest
5-
65
from subprocess import Popen, PIPE
76

7+
from smartdispatch.utils import get_slurm_cluster_name
8+
89
pbs_string = """\
910
#!/usr/bin/env /bin/bash
1011
@@ -22,28 +23,8 @@
2223
nvidia-smi
2324
"""
2425

25-
sbatch_string = """\
26-
#!/usr/bin/env -i /bin/zsh
27-
28-
#SBATCH --job-name=arrayJob
29-
#SBATCH --output=arrayJob_%A_%a.out
30-
#SBATCH --time=01:00:00
31-
{}
32-
33-
######################
34-
# Begin work section #
35-
######################
36-
37-
echo "My SLURM_ARRAY_JOB_ID:" $SLURM_ARRAY_JOB_ID
38-
echo "My SLURM_ARRAY_TASK_ID: " $SLURM_ARRAY_TASK_ID
39-
nvidia-smi
40-
"""
41-
4226
# Checking which cluster is running the tests first
43-
process = Popen("sacctmgr list cluster", stdout=PIPE, stderr=PIPE, shell=True)
44-
stdout, _ = process.communicate()
45-
stdout = stdout.decode()
46-
cluster = stdout.splitlines()[2].strip().split(' ')[0]
27+
cluster = get_slurm_cluster_name()
4728
to_skip = cluster in ['graham', 'cedar']
4829
message = "Test does not run on cluster {}".format(cluster)
4930

@@ -53,14 +34,14 @@ def tearDown(self):
5334
for file_name in (glob('*.out') + ["test.pbs"]):
5435
os.remove(file_name)
5536

56-
def _test_param(self, param_array, command, flag, string=pbs_string, output_array=None):
37+
def _test_param(self, param_array, command_template, flag, string=pbs_string, output_array=None):
5738
output_array = output_array or param_array
5839
for param, output in zip(param_array, output_array):
59-
com = pbs_string.format(
60-
string.format(command.format(param))
40+
param_command = pbs_string.format(
41+
string.format(command_template.format(param))
6142
)
6243
with open("test.pbs", "w") as text_file:
63-
text_file.write(com)
44+
text_file.write(param_command)
6445
process = Popen("sbatch test.pbs", stdout=PIPE, stderr=PIPE, shell=True)
6546
stdout, _ = process.communicate()
6647
stdout = stdout.decode()

smartdispatch/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def detect_cluster():
115115
output = Popen(["qstat", "-B"], stdout=PIPE).communicate()[0]
116116
except OSError:
117117
# If qstat is not available we assume that the cluster is unknown.
118-
# TODO: handle MILA + CEDAR + GRAHAM
118+
cluster_name = get_slurm_cluster_name()
119119
return None
120120
# Get server name from status
121121
server_name = output.split('\n')[2].split(' ')[0]
@@ -131,6 +131,11 @@ def detect_cluster():
131131
cluster_name = "hades"
132132
return cluster_name
133133

134+
def get_slurm_cluster_name():
135+
stdout = Popen("sacctmgr list cluster", stdout=PIPE, shell=True).communicate()[0]
136+
stdout = stdout.decode()
137+
cluster_name = stdout.splitlines()[2].strip().split(' ')[0]
138+
return cluster_name
134139

135140
def get_launcher(cluster_name):
136141
if cluster_name == "helios":

0 commit comments

Comments
 (0)