@@ -226,8 +226,8 @@ def train_one_epoch(
226226 torch.Tensor: Training loss for the epoch.
227227 """
228228 include_body_region = unet .include_top_region_index_input
229- include_modality = ( unet .num_class_embeds is not None )
230-
229+ include_modality = unet .num_class_embeds is not None
230+
231231 if local_rank == 0 :
232232 current_lr = optimizer .param_groups [0 ]["lr" ]
233233 logger .info (f"Epoch { epoch + 1 } , lr { current_lr } ." )
@@ -248,7 +248,7 @@ def train_one_epoch(
248248 bottom_region_index_tensor = train_data ["bottom_region_index" ].to (device )
249249 # We trained with only CT in this version
250250 if include_modality :
251- modality_tensor = torch .ones ((len (images ),),dtype = torch .long ).to (device )
251+ modality_tensor = torch .ones ((len (images ),), dtype = torch .long ).to (device )
252252 spacing_tensor = train_data ["spacing" ].to (device )
253253
254254 optimizer .zero_grad (set_to_none = True )
@@ -268,18 +268,22 @@ def train_one_epoch(
268268 "x" : noisy_latent ,
269269 "timesteps" : timesteps ,
270270 "spacing_tensor" : spacing_tensor ,
271- }
271+ }
272272 # Add extra arguments if include_body_region is True
273273 if include_body_region :
274- unet_inputs .update ({
275- "top_region_index_tensor" : top_region_index_tensor ,
276- "bottom_region_index_tensor" : bottom_region_index_tensor
277- })
274+ unet_inputs .update (
275+ {
276+ "top_region_index_tensor" : top_region_index_tensor ,
277+ "bottom_region_index_tensor" : bottom_region_index_tensor ,
278+ }
279+ )
278280 if include_modality :
279- unet_inputs .update ({
280- "class_labels" : modality_tensor ,
281- })
282- model_output = unet (** unet_inputs )
281+ unet_inputs .update (
282+ {
283+ "class_labels" : modality_tensor ,
284+ }
285+ )
286+ model_output = unet (** unet_inputs )
283287
284288 if noise_scheduler .prediction_type == DDPMPredictionType .EPSILON :
285289 # predict noise
@@ -359,11 +363,7 @@ def save_checkpoint(
359363
360364
361365def diff_model_train (
362- env_config_path : str ,
363- model_config_path : str ,
364- model_def_path : str ,
365- num_gpus : int ,
366- amp : bool = True
366+ env_config_path : str , model_config_path : str , model_def_path : str , num_gpus : int , amp : bool = True
367367) -> None :
368368 """
369369 Main function to train a diffusion model.
@@ -424,8 +424,6 @@ def diff_model_train(
424424 include_body_region = include_body_region ,
425425 )
426426
427-
428-
429427 scale_factor = calculate_scale_factor (train_loader , device , logger )
430428 optimizer = create_optimizer (unet , args .diffusion_unet_train ["lr" ])
431429
@@ -455,7 +453,7 @@ def diff_model_train(
455453 device ,
456454 logger ,
457455 local_rank ,
458- amp = amp
456+ amp = amp ,
459457 )
460458
461459 loss_torch = loss_torch .tolist ()
@@ -498,6 +496,4 @@ def diff_model_train(
498496 parser .add_argument ("--no_amp" , dest = "amp" , action = "store_false" , help = "Disable automatic mixed precision training" )
499497
500498 args = parser .parse_args ()
501- diff_model_train (
502- args .env_config , args .model_config , args .model_def , args .num_gpus , args .amp
503- )
499+ diff_model_train (args .env_config , args .model_config , args .model_def , args .num_gpus , args .amp )
0 commit comments