Skip to content

Commit 246f68c

Browse files
Fixtures depending on comm fixtures: fix bug in sequential scheduler + added test
1 parent 52f3ead commit 246f68c

File tree

5 files changed

+60
-7
lines changed

5 files changed

+60
-7
lines changed

pytest_parallel/mpi_reporter.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def add_sub_comm(items, global_comm, test_comm_creation, mpi_comm_creation_funct
7272

7373
class SequentialScheduler:
7474
def __init__(self, global_comm):
75-
self.global_comm = global_comm.Dup() # ensure that all communications within the framework are private to the framework
75+
self.global_comm = global_comm.Dup() # ensure that all communications within the framework are private to the framework
7676

7777
# These parameters are not accessible through the API, but are left here for tweaking and experimenting
7878
self.test_comm_creation = 'by_rank' # possible values : 'by_rank' | 'by_test'
@@ -83,6 +83,17 @@ def __init__(self, global_comm):
8383
self.mpi_comm_creation_function = 'MPI_Comm_split' # because 'MPI_Comm_create' uses `Create_group`,
8484
# that is not implemented in mpi4py for Windows
8585

86+
# Note: We don't need to do this for the static and dynamic scheduler
87+
# So we should investigate on what is the difference here
88+
@pytest.hookimpl(tryfirst=True)
89+
def pytest_fixture_setup(self, fixturedef, request):
90+
if hasattr(request.node, 'sub_comm'): # for some fixtures (notably pytest built-in ones), the communicator is not set yet
91+
comm = request.node.sub_comm
92+
if comm == MPI.COMM_NULL: # our process does not participate in the test: do not execute fixtures
93+
# pytest needs `cached_result` to be non-None, so put something there, hoping it will never be used as a valid value
94+
fixturedef.cached_result = (None, 'this_key_should_not_be_matched', RuntimeError('Pytest internal error from pytest_parallel::pytest_fixture_setup'))
95+
return True
96+
8697
@pytest.hookimpl(trylast=True)
8798
def pytest_collection_modifyitems(self, config, items):
8899
add_sub_comm(items, self.global_comm, self.test_comm_creation, self.mpi_comm_creation_function)
@@ -97,7 +108,6 @@ def pytest_runtest_protocol(self, item, nextitem):
97108

98109
@pytest.hookimpl(tryfirst=True)
99110
def pytest_pyfunc_call(self, pyfuncitem):
100-
#print(f'pytest_pyfunc_call {MPI.COMM_WORLD.rank=}')
101111
# This is where the test is normally run.
102112
# Only run the test for the ranks that do participate in the test
103113
if pyfuncitem.sub_comm == MPI.COMM_NULL:

pytest_parallel/plugin.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,15 +245,15 @@ def __init__(self, comm):
245245

246246
def __enter__(self):
247247
from mpi4py import MPI
248-
if self.comm != MPI.COMM_NULL: # TODO DEL once non-participating rank do not participate in fixtures either
248+
if self.comm != MPI.COMM_NULL: # TODO 2025-10: should not be needed anymore > try deletion
249249
rank = self.comm.rank
250250
self.tmp_dir = tempfile.TemporaryDirectory() if rank == 0 else None
251251
self.tmp_path = Path(self.tmp_dir.name) if rank == 0 else None
252252
return self.comm.bcast(self.tmp_path, root=0)
253253

254254
def __exit__(self, ex_type, ex_value, traceback):
255255
from mpi4py import MPI
256-
if self.comm != MPI.COMM_NULL: # TODO DEL once non-participating rank do not participate in fixtures either
256+
if self.comm != MPI.COMM_NULL: # TODO 2025-10: should not be needed anymore > try deletion
257257
self.comm.barrier()
258258
if self.comm.rank == 0:
259259
self.tmp_dir.cleanup()
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
[=]+ test session starts [=]+
2+
platform [^\n]*
3+
cachedir: [^\n]*
4+
?(?:metadata: [^\n]*)?
5+
rootdir: [^\n]*
6+
?(?:configfile: [^\n]*)?
7+
?(?:plugins: [^\n]*)?
8+
collecting ... [\s]*collected 1 item[\s]*
9+
?(?:Submitting tests to SLURM...)?
10+
?(?:SLURM job [^\n]* has been submitted)?
11+
12+
[^\n]*test_fixture.py::test_fixture\[2\] FAILED
13+
14+
[=]+ FAILURES [=]+
15+
[_]+ test_fixture\[2\] [_]+
16+
17+
[-]+ On rank 1 of 2 [-]+
18+
my_fixture = 1
19+
20+
@pytest_parallel.mark.parallel\(2\)
21+
def test_fixture\(my_fixture\):
22+
> assert my_fixture == 0 # should fail on proc 1
23+
E assert 1 == 0
24+
25+
[^\n]*test_fixture.py:12: AssertionError
26+
[=]+ short test summary info [=]+
27+
FAILED [^\n]*test_fixture.py::test_fixture\[2\][^\n]*
28+
[=]+ 1 failed in [^\n]*s [=]+
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import pytest
2+
import pytest_parallel
3+
4+
5+
@pytest.fixture
6+
def my_fixture(comm):
7+
return comm.rank
8+
9+
10+
@pytest_parallel.mark.parallel(2)
11+
def test_fixture(my_fixture):
12+
assert my_fixture == 0 # should fail on proc 1

test/test_pytest_parallel.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,13 @@ def test_16(self, scheduler, capfd): run_pytest_parallel_test('two_success_fail_
9999

100100
def test_17(self, scheduler, capfd): run_pytest_parallel_test('fixture_error' , 1, scheduler, capfd) # check that fixture errors are correctly reported
101101

102-
def test_18(self, scheduler, capfd): run_pytest_parallel_test('parametrize' , 2, scheduler, capfd) # check the parametrize API
102+
def test_18(self, scheduler, capfd): run_pytest_parallel_test('fixture' , 2, scheduler, capfd) # check that fixtures can depend on comm
103+
def test_19(self, scheduler, capfd): run_pytest_parallel_test('fixture' , 4, scheduler, capfd) # check that fixtures can depend on comm, more procs
103104

104-
def test_19(self, scheduler, capfd): run_pytest_parallel_test('scheduling' , 4, scheduler, capfd) # check 'real' case
105-
def test_20(self, scheduler, capfd): run_pytest_parallel_test('fail_complex_assert_two_procs' , 2, scheduler, capfd) # check 'complex' error message
105+
def test_20(self, scheduler, capfd): run_pytest_parallel_test('parametrize' , 2, scheduler, capfd) # check the parametrize API
106+
107+
def test_21(self, scheduler, capfd): run_pytest_parallel_test('scheduling' , 4, scheduler, capfd) # check 'real' case
108+
def test_22(self, scheduler, capfd): run_pytest_parallel_test('fail_complex_assert_two_procs' , 2, scheduler, capfd) # check 'complex' error message
106109
# fmt: on
107110

108111
## If one test fail, it may be useful to debug regex matching along the following lines

0 commit comments

Comments
 (0)