3 changes: 3 additions & 0 deletions .gitignore
@@ -105,3 +105,6 @@ venv.bak/
 
 # mypy
 .mypy_cache/
+
+# .DS_Store file from MacOS
+**/.DS_Store
100 changes: 81 additions & 19 deletions pytorchtools.py
@@ -1,21 +1,34 @@
 import numpy as np
 import torch
+from torch import nn
 
+
 class EarlyStopping:
-    """Early stops the training if validation loss doesn't improve after a given patience."""
-    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
+    """
+    EarlyStopping can be used to monitor the validation loss during training and stop the training process early
+    if the validation loss does not improve after a certain number of epochs. It can handle both KFold and
+    non-KFold cases.
+    """
+
+    def __init__(
+        self,
+        patience: int = 7,
+        verbose: bool = False,
+        delta: float = 0,
+        path: str = "checkpoint.pt",
+        use_kfold: bool = False,
+        trace_func=print,
+    ):
         """
+        Initializes the EarlyStopping object with the given parameters.
+
         Args:
-            patience (int): How long to wait after last time validation loss improved.
-                            Default: 7
-            verbose (bool): If True, prints a message for each validation loss improvement.
-                            Default: False
-            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
-                            Default: 0
-            path (str): Path for the checkpoint to be saved to.
-                            Default: 'checkpoint.pt'
-            trace_func (function): trace print function.
-                            Default: print
+            patience: How long to wait after last time validation loss improved.
+            verbose: If True, prints a message for each validation loss improvement.
+            delta: Minimum change in the monitored quantity to qualify as an improvement.
+            path: Path for the checkpoint to be saved to.
+            use_kfold: If True, saves the model with the lowest loss metric for each fold.
+            trace_func: trace print function.
         """
         self.patience = patience
         self.verbose = verbose
@@ -25,27 +38,76 @@ def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', tra
         self.val_loss_min = np.Inf
         self.delta = delta
         self.path = path
+        self.use_kfold = use_kfold
         self.trace_func = trace_func
-    def __call__(self, val_loss, model):
+        self.fold = None
+        self.filename = None
+
+    def __call__(self, val_loss: float, model: nn.Module, fold: int = None):
+        """
+        This method is called during the training process to monitor the validation loss and decide whether to stop
+        the training process early or not.
+
+        Args:
+            val_loss: Validation loss of the model at the current epoch.
+            model: The PyTorch model being trained.
+            fold: The current fold of the KFold cross-validation. Required if use_kfold is True.
+        """
+        if self.use_kfold:
+            assert fold is not None, "Fold must be provided when use_kfold is True"
+
+            # If it's a new fold, resets the early stopping object and sets the filename to save the model
+            if fold != self.fold:
+                self.fold = fold
+                self.counter = 0
+                self.best_score = None
+                self.early_stop = False
+                self.val_loss_min = np.Inf
+                self.filename = self.path.replace(".pt", f"_fold_{fold}.pt")
 
+        # Calculating the score by negating the validation loss
         score = -val_loss
 
+        # If the best score is None, sets it to the current score and saves the checkpoint
         if self.best_score is None:
             self.best_score = score
             self.save_checkpoint(val_loss, model)
+
+        # If the score is less than the best score plus delta, increments the counter
+        # and checks if the patience has been reached
         elif score < self.best_score + self.delta:
             self.counter += 1
-            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
+            self.trace_func(
+                f"EarlyStopping counter: {self.counter} out of {self.patience}"
+            )
             if self.counter >= self.patience:
                 self.early_stop = True
+
+        # If the score is better than the best score plus delta, saves the checkpoint and resets the counter
         else:
             self.best_score = score
             self.save_checkpoint(val_loss, model)
             self.counter = 0
 
-    def save_checkpoint(self, val_loss, model):
-        '''Saves model when validation loss decrease.'''
+    def save_checkpoint(self, val_loss: float, model: nn.Module):
+        """
+        Saves the model when validation loss decreases.
+
+        Args:
+            val_loss: The current validation loss.
+            model: The PyTorch model being trained.
+        """
+        # If verbose mode is on, print a message about the validation loss decreasing and saving the model
         if self.verbose:
-            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
-        torch.save(model.state_dict(), self.path)
-        self.val_loss_min = val_loss
+            self.trace_func(
+                f"Validation loss decreased ({self.val_loss_min:.4f} --> {val_loss:.4f}). Saving model ..."
+            )
+
+        # Save the state of the model to the appropriate filename based on whether KFold is used or not
+        if self.use_kfold:
+            torch.save(model.state_dict(), self.filename)
+        else:
+            torch.save(model.state_dict(), self.path)
+
+        # Update the minimum validation loss seen so far to the current validation loss
+        self.val_loss_min = val_loss
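
For reviewers who want to try the new flag, here is a minimal usage sketch (not part of this PR). It assumes scikit-learn's `KFold` for the splits; the toy model, data, and training loop are illustrative stand-ins rather than code from this repository:

```python
# Usage sketch for the new use_kfold flag (assumed names; not part of this PR).
import torch
from torch import nn
from sklearn.model_selection import KFold

from pytorchtools import EarlyStopping

X = torch.randn(200, 10)   # synthetic features
y = torch.randn(200, 1)    # synthetic regression targets
loss_fn = nn.MSELoss()

# One instance is reused across folds; passing a new `fold` value resets its
# internal state and redirects checkpoints to checkpoint_fold_<fold>.pt.
early_stopping = EarlyStopping(patience=5, verbose=True, path="checkpoint.pt", use_kfold=True)

for fold, (train_idx, val_idx) in enumerate(KFold(n_splits=3).split(X)):
    model = nn.Linear(10, 1)  # fresh model per fold
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

    for epoch in range(100):
        # Training step on this fold's training split
        model.train()
        optimizer.zero_grad()
        loss = loss_fn(model(X[train_idx]), y[train_idx])
        loss.backward()
        optimizer.step()

        # Validation loss on the held-out split
        model.eval()
        with torch.no_grad():
            val_loss = loss_fn(model(X[val_idx]), y[val_idx]).item()

        # `fold` is required because use_kfold=True
        early_stopping(val_loss, model, fold=fold)
        if early_stopping.early_stop:
            print(f"Fold {fold}: stopped early at epoch {epoch}")
            break
```

Because a single `EarlyStopping` instance is shared across folds, its counters and best score are reset on the first `__call__` of each new fold, and each fold's best weights land in `checkpoint_fold_0.pt`, `checkpoint_fold_1.pt`, and so on, derived from `path`.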