Commit 19fc0d7

prepare dataset outside of child processes
1 parent bb8f4b9 commit 19fc0d7

1 file changed: +15 −14

tests/zero_code_change/test_pytorch_multiprocessing.py

@@ -39,27 +39,17 @@ def forward(self, x):
         return F.log_softmax(x, dim=1)
 
 
-def train(rank, model, device, dataloader_kwargs):
+def train(rank, model, device, data_set, dataloader_kwargs):
     # Training Settings
+
     batch_size = 64
     epochs = 1
     lr = 0.01
     momentum = 0.5
 
     torch.manual_seed(1 + rank)
     train_loader = torch.utils.data.DataLoader(
-        datasets.MNIST(
-            data_dir,
-            train=True,
-            download=True,
-            transform=transforms.Compose(
-                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
-            ),
-        ),
-        batch_size=batch_size,
-        shuffle=True,
-        num_workers=1,
-        **dataloader_kwargs
+        data_set, batch_size=batch_size, shuffle=True, num_workers=1, **dataloader_kwargs
     )
 
     optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
@@ -85,6 +75,17 @@ def test_no_failure_with_torch_mp(out_dir):
     path = str(path)
     os.environ["SMDEBUG_CONFIG_FILE_PATH"] = path
     device = "cpu"
+
+    # clear data_dir before saving to it
+    shutil.rmtree(data_dir, ignore_errors=True)
+    data_set = datasets.MNIST(
+        data_dir,
+        train=True,
+        download=True,
+        transform=transforms.Compose(
+            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
+        ),
+    )
     dataloader_kwargs = {}
     cpu_count = 2 if mp.cpu_count() > 2 else mp.cpu_count()
 
@@ -95,7 +96,7 @@ def test_no_failure_with_torch_mp(out_dir):
 
     processes = []
     for rank in range(cpu_count):
-        p = mp.Process(target=train, args=(rank, model, device, dataloader_kwargs))
+        p = mp.Process(target=train, args=(rank, model, device, data_set, dataloader_kwargs))
         # We first train the model across `num_processes` processes
        p.start()
         processes.append(p)
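For context, a minimal standalone sketch of the pattern this commit applies: the dataset is constructed once in the parent process and handed to every child through the mp.Process args, instead of being built (and downloaded) inside each child. This is not part of the test code; TensorDataset with random tensors stands in for datasets.MNIST, and the train function here is stripped down to the DataLoader usage.

import torch
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, TensorDataset


def train(rank, data_set, dataloader_kwargs):
    # Each worker only consumes the shared dataset; it never constructs or downloads it.
    torch.manual_seed(1 + rank)
    loader = DataLoader(data_set, batch_size=8, shuffle=True, **dataloader_kwargs)
    for inputs, targets in loader:
        pass  # a real training step would run here


if __name__ == "__main__":
    # Prepare the dataset outside of the child processes (the point of this commit).
    # TensorDataset is an illustrative stand-in for datasets.MNIST.
    data_set = TensorDataset(torch.randn(64, 4), torch.randint(0, 2, (64,)))

    processes = []
    for rank in range(2):
        p = mp.Process(target=train, args=(rank, data_set, {}))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()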
