@@ -39,27 +39,17 @@ def forward(self, x):
3939 return F .log_softmax (x , dim = 1 )
4040
4141
42- def train (rank , model , device , dataloader_kwargs ):
42+ def train (rank , model , device , data_set , dataloader_kwargs ):
4343 # Training Settings
44+
4445 batch_size = 64
4546 epochs = 1
4647 lr = 0.01
4748 momentum = 0.5
4849
4950 torch .manual_seed (1 + rank )
5051 train_loader = torch .utils .data .DataLoader (
51- datasets .MNIST (
52- data_dir ,
53- train = True ,
54- download = True ,
55- transform = transforms .Compose (
56- [transforms .ToTensor (), transforms .Normalize ((0.1307 ,), (0.3081 ,))]
57- ),
58- ),
59- batch_size = batch_size ,
60- shuffle = True ,
61- num_workers = 1 ,
62- ** dataloader_kwargs
52+ data_set , batch_size = batch_size , shuffle = True , num_workers = 1 , ** dataloader_kwargs
6353 )
6454
6555 optimizer = optim .SGD (model .parameters (), lr = lr , momentum = momentum )
@@ -85,6 +75,17 @@ def test_no_failure_with_torch_mp(out_dir):
8575 path = str (path )
8676 os .environ ["SMDEBUG_CONFIG_FILE_PATH" ] = path
8777 device = "cpu"
78+
79+ # clear data_dir before saving to it
80+ shutil .rmtree (data_dir , ignore_errors = True )
81+ data_set = datasets .MNIST (
82+ data_dir ,
83+ train = True ,
84+ download = True ,
85+ transform = transforms .Compose (
86+ [transforms .ToTensor (), transforms .Normalize ((0.1307 ,), (0.3081 ,))]
87+ ),
88+ )
8889 dataloader_kwargs = {}
8990 cpu_count = 2 if mp .cpu_count () > 2 else mp .cpu_count ()
9091
@@ -95,7 +96,7 @@ def test_no_failure_with_torch_mp(out_dir):
9596
9697 processes = []
9798 for rank in range (cpu_count ):
98- p = mp .Process (target = train , args = (rank , model , device , dataloader_kwargs ))
99+ p = mp .Process (target = train , args = (rank , model , device , data_set , dataloader_kwargs ))
99100 # We first train the model across `num_processes` processes
100101 p .start ()
101102 processes .append (p )
0 commit comments