Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions evaluate.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import torch
import torch.nn.functional as F
from tqdm import tqdm

from utils.dice_score import multiclass_dice_coeff, dice_coeff


@torch.inference_mode()
def evaluate(net, dataloader, device, amp):
net.eval()
num_val_batches = len(dataloader)
dice_score = 0
dice_score = 0.0

# iterate over the validation set
with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
Expand All @@ -24,17 +23,17 @@ def evaluate(net, dataloader, device, amp):
mask_pred = net(image)

if net.n_classes == 1:
# binary segmentation
assert mask_true.min() >= 0 and mask_true.max() <= 1, 'True mask indices should be in [0, 1]'
mask_pred = (F.sigmoid(mask_pred) > 0.5).float()
# compute the Dice score
mask_pred = (torch.sigmoid(mask_pred) > 0.5).float()
dice_score += dice_coeff(mask_pred, mask_true, reduce_batch_first=False)
else:
# multi-class segmentation: compute Dice directly on class indices
assert mask_true.min() >= 0 and mask_true.max() < net.n_classes, 'True mask indices should be in [0, n_classes['
# convert to one-hot format
mask_true = F.one_hot(mask_true, net.n_classes).permute(0, 3, 1, 2).float()
mask_pred = F.one_hot(mask_pred.argmax(dim=1), net.n_classes).permute(0, 3, 1, 2).float()
# compute the Dice score, ignoring background
dice_score += multiclass_dice_coeff(mask_pred[:, 1:], mask_true[:, 1:], reduce_batch_first=False)
mask_pred_indices = mask_pred.argmax(dim=1)
# ignore background class (0) if desired
dice_score += multiclass_dice_coeff(mask_pred_indices, mask_true, ignore_index=0)

net.train()
return dice_score / max(num_val_batches, 1)