|
13 | 13 |
|
14 | 14 | class TestJobGenerator(object): |
15 | 15 | pbs_flags = ['-lfeature=k80', '-lwalltime=42:42', '-lnodes=6:gpus=66', '-m', '-A123-asd-11', '-t10,20,30'] |
| 16 | + sbatch_flags = ['--qos=high', '--output=file.out', '-Cminmemory'] |
16 | 17 |
|
17 | 18 | def setUp(self): |
18 | 19 | self.testing_dir = tempfile.mkdtemp() |
@@ -129,6 +130,32 @@ def test_add_pbs_flags_invalid(self): |
129 | 130 | def test_add_pbs_flags_invalid_resource(self): |
130 | 131 | assert_raises(ValueError, self._test_add_pbs_flags, '-l weeee') |
131 | 132 |
|
| 133 | + def _test_add_sbatch_flags(self, flags): |
| 134 | + job_generator = JobGenerator(self.queue, self.commands) |
| 135 | + job_generator.add_sbatch_flags(flags) |
| 136 | + options = [] |
| 137 | + |
| 138 | + for flag in flags: |
| 139 | + if flag.startswith('--'): |
| 140 | + options += [flag] |
| 141 | + elif flag.startswith('-'): |
| 142 | + options += [(flag[:2] + ' ' + flag[2:]).strip()] |
| 143 | + |
| 144 | + for pbs in job_generator.pbs_list: |
| 145 | + pbs_str = pbs.__str__() |
| 146 | + for flag in options: |
| 147 | + assert_equal(pbs_str.count(flag), 1) |
| 148 | + |
| 149 | + def test_add_sbatch_flags(self): |
| 150 | + for flag in self.sbatch_flags: |
| 151 | + yield self._test_add_sbatch_flags, [flag] |
| 152 | + |
| 153 | + yield self._test_add_sbatch_flags, [flag] |
| 154 | + |
| 155 | + def test_add_sbatch_flag_invalid(self): |
| 156 | + invalid_flags = ["--qos high", "gpu", "-lfeature=k80"] |
| 157 | + for flag in invalid_flags: |
| 158 | + assert_raises(ValueError, self._test_add_sbatch_flags, "--qos high") |
132 | 159 |
|
133 | 160 | class TestGuilliminQueue(object): |
134 | 161 |
|
@@ -244,6 +271,28 @@ def test_pbs_split_2_job_nb_commands(self): |
244 | 271 | assert_true("ppn=6" in str(self.pbs8[0])) |
245 | 272 | assert_true("ppn=2" in str(self.pbs8[1])) |
246 | 273 |
|
| 274 | +class TestSlurmQueue(object): |
| 275 | + |
| 276 | + def setUp(self): |
| 277 | + self.walltime = "10:00" |
| 278 | + self.cores = 42 |
| 279 | + self.mem_per_node = 32 |
| 280 | + self.nb_cores_per_node = 1 |
| 281 | + self.nb_gpus_per_node = 2 |
| 282 | + self.queue = Queue("slurm", "mila", self.walltime, self.nb_cores_per_node, self.nb_gpus_per_node, self.mem_per_node) |
| 283 | + |
| 284 | + self.commands = ["echo 1", "echo 2", "echo 3", "echo 4"] |
| 285 | + job_generator = SlurmJobGenerator(self.queue, self.commands) |
| 286 | + self.pbs = job_generator.pbs_list |
| 287 | + |
| 288 | + def test_ppn_ncpus(self): |
| 289 | + assert_true("ppn" not in str(self.pbs[0])) |
| 290 | + assert_true("ncpus" in str(self.pbs[0])) |
| 291 | + |
| 292 | + def test_gpus_naccelerators(self): |
| 293 | + assert_true("gpus" not in str(self.pbs[0])) |
| 294 | + assert_true("naccelerators" in str(self.pbs[0])) |
| 295 | + |
247 | 296 | class TestJobGeneratorFactory(object): |
248 | 297 |
|
249 | 298 | def setUp(self): |
|
0 commit comments