diff --git a/train.py b/train.py index 5c3805208c..3ec4246618 100644 --- a/train.py +++ b/train.py @@ -112,37 +112,40 @@ def train_net(net, # Evaluation round division_step = (n_train // (10 * batch_size)) - if division_step > 0: - if global_step % division_step == 0: - histograms = {} - for tag, value in net.named_parameters(): - tag = tag.replace('/', '.') - histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu()) - histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu()) - - val_score = evaluate(net, val_loader, device) - scheduler.step(val_score) - - logging.info('Validation Dice score: {}'.format(val_score)) - experiment.log({ - 'learning rate': optimizer.param_groups[0]['lr'], - 'validation Dice': val_score, - 'images': wandb.Image(images[0].cpu()), - 'masks': { - 'true': wandb.Image(true_masks[0].float().cpu()), - 'pred': wandb.Image(masks_pred.argmax(dim=1)[0].float().cpu()), - }, - 'step': global_step, - 'epoch': epoch, - **histograms - }) + + if division_step > 0 and global_step % division_step == 0: + histograms = {} + for tag, value in net.named_parameters(): + tag = tag.replace('/', '.') + histograms[f'Weights/{tag}'] = wandb.Histogram(value.data.cpu()) + histograms[f'Gradients/{tag}'] = wandb.Histogram(value.grad.data.cpu()) + + val_score = evaluate(net, val_loader, device) + scheduler.step(val_score) + + logging.info(f'Validation Dice score: {val_score}') + experiment.log({ + 'learning rate': optimizer.param_groups[0]['lr'], + 'validation Dice': val_score, + 'images': wandb.Image(images[0].cpu()), + 'masks': { + 'true': wandb.Image(true_masks[0].float().cpu()), + 'pred': wandb.Image(masks_pred.argmax(dim=1)[0].float().cpu()), + }, + 'step': global_step, + 'epoch': epoch, + **histograms + }) + if save_checkpoint: Path(dir_checkpoint).mkdir(parents=True, exist_ok=True) + torch.save(net.state_dict(), str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch))) logging.info(f'Checkpoint {epoch} saved!') + def get_args(): parser = argparse.ArgumentParser(description='Train the UNet on images and target masks') parser.add_argument('--epochs', '-e', metavar='E', type=int, default=5, help='Number of epochs') diff --git a/unet/unet_model.py b/unet/unet_model.py index 20c35b52cc..c9ed21d34f 100644 --- a/unet/unet_model.py +++ b/unet/unet_model.py @@ -32,5 +32,4 @@ def forward(self, x): x = self.up2(x, x3) x = self.up3(x, x2) x = self.up4(x, x1) - logits = self.outc(x) - return logits + return self.outc(x) diff --git a/utils/data_loading.py b/utils/data_loading.py index 8bb4f9252c..331c165cbe 100644 --- a/utils/data_loading.py +++ b/utils/data_loading.py @@ -56,7 +56,7 @@ def load(filename): def __getitem__(self, idx): name = self.ids[idx] mask_file = list(self.masks_dir.glob(name + self.mask_suffix + '.*')) - img_file = list(self.images_dir.glob(name + '.*')) + img_file = list(self.images_dir.glob(f'{name}.*')) assert len(img_file) == 1, f'Either no image or multiple images found for the ID {name}: {img_file}' assert len(mask_file) == 1, f'Either no mask or multiple masks found for the ID {name}: {mask_file}' diff --git a/utils/dice_score.py b/utils/dice_score.py index c07f0d0fbe..26873572af 100644 --- a/utils/dice_score.py +++ b/utils/dice_score.py @@ -17,18 +17,26 @@ def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, return (2 * inter + epsilon) / (sets_sum + epsilon) else: # compute and average metric for each batch element - dice = 0 - for i in range(input.shape[0]): - dice += dice_coeff(input[i, ...], target[i, ...]) + dice = sum( + dice_coeff(input[i, ...], target[i, ...]) + for i in range(input.shape[0]) + ) + return dice / input.shape[0] def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6): # Average of Dice coefficient for all classes assert input.size() == target.size() - dice = 0 - for channel in range(input.shape[1]): - dice += dice_coeff(input[:, channel, ...], target[:, channel, ...], reduce_batch_first, epsilon) + dice = sum( + dice_coeff( + input[:, channel, ...], + target[:, channel, ...], + reduce_batch_first, + epsilon, + ) + for channel in range(input.shape[1]) + ) return dice / input.shape[1] diff --git a/utils/utils.py b/utils/utils.py index 859e887dfb..86a07daac4 100644 --- a/utils/utils.py +++ b/utils/utils.py @@ -11,7 +11,7 @@ def plot_img_and_mask(img, mask): ax[i + 1].set_title(f'Output mask (class {i + 1})') ax[i + 1].imshow(mask[i, :, :]) else: - ax[1].set_title(f'Output mask') + ax[1].set_title('Output mask') ax[1].imshow(mask) plt.xticks([]), plt.yticks([]) plt.show()