diff --git a/.gitignore b/.gitignore
index 30eab60..a0088fc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -105,3 +105,6 @@ venv.bak/
 
 # mypy
 .mypy_cache/
+
+# .DS_Store file from MacOS
+**/.DS_Store
diff --git a/pytorchtools.py b/pytorchtools.py
index 9644e4b..e9803fc 100644
--- a/pytorchtools.py
+++ b/pytorchtools.py
@@ -1,21 +1,34 @@
 import numpy as np
 import torch
+from torch import nn
+
 
 class EarlyStopping:
-    """Early stops the training if validation loss doesn't improve after a given patience."""
-    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
+    """
+    EarlyStopping can be used to monitor the validation loss during training and stop the training process early
+    if the validation loss does not improve after a certain number of epochs. It can handle both KFold and
+    non-KFold cases.
+    """
+
+    def __init__(
+        self,
+        patience: int = 7,
+        verbose: bool = False,
+        delta: float = 0,
+        path: str = "checkpoint.pt",
+        use_kfold: bool = False,
+        trace_func=print,
+    ):
         """
+        Initializes the EarlyStopping object with the given parameters.
+
         Args:
-            patience (int): How long to wait after last time validation loss improved.
-                            Default: 7
-            verbose (bool): If True, prints a message for each validation loss improvement.
-                            Default: False
-            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
-                            Default: 0
-            path (str): Path for the checkpoint to be saved to.
-                            Default: 'checkpoint.pt'
-            trace_func (function): trace print function.
-                            Default: print
+            patience: How long to wait after last time validation loss improved.
+            verbose: If True, prints a message for each validation loss improvement.
+            delta: Minimum change in the monitored quantity to qualify as an improvement.
+            path: Path for the checkpoint to be saved to.
+            use_kfold: If True, saves the model with the lowest loss metric for each fold.
+            trace_func: trace print function.
         """
         self.patience = patience
         self.verbose = verbose
@@ -25,27 +38,76 @@ def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', tra
         self.val_loss_min = np.Inf
         self.delta = delta
         self.path = path
+        self.use_kfold = use_kfold
         self.trace_func = trace_func
-    def __call__(self, val_loss, model):
+        self.fold = None
+        self.filename = None
+
+    def __call__(self, val_loss: float, model: nn.Module, fold: int = None):
+        """
+        This method is called during the training process to monitor the validation loss and decide whether to stop
+        the training process early or not.
+
+        Args:
+            val_loss: Validation loss of the model at the current epoch.
+            model: The PyTorch model being trained.
+            fold: The current fold of the KFold cross-validation. Required if use_kfold is True.
+        """
+        if self.use_kfold:
+            assert fold is not None, "Fold must be provided when use_kfold is True"
+
+            # If it's a new fold, resets the early stopping object and sets the filename to save the model
+            if fold != self.fold:
+                self.fold = fold
+                self.counter = 0
+                self.best_score = None
+                self.early_stop = False
+                self.val_loss_min = np.Inf
+                self.filename = self.path.replace(".pt", f"_fold_{fold}.pt")
 
+        # Calculating the score by negating the validation loss
         score = -val_loss
 
+        # If the best score is None, sets it to the current score and saves the checkpoint
         if self.best_score is None:
             self.best_score = score
             self.save_checkpoint(val_loss, model)
+
+        # If the score is less than the best score plus delta, increments the counter
+        # and checks if the patience has been reached
         elif score < self.best_score + self.delta:
             self.counter += 1
-            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
+            self.trace_func(
+                f"EarlyStopping counter: {self.counter} out of {self.patience}"
+            )
             if self.counter >= self.patience:
                 self.early_stop = True
+
+        # If the score is better than the best score plus delta, saves the checkpoint and resets the counter
         else:
             self.best_score = score
             self.save_checkpoint(val_loss, model)
             self.counter = 0
 
-    def save_checkpoint(self, val_loss, model):
-        '''Saves model when validation loss decrease.'''
+    def save_checkpoint(self, val_loss: float, model: nn.Module):
+        """
+        Saves the model when validation loss decreases.
+
+        Args:
+            val_loss: The current validation loss.
+            model: The PyTorch model being trained.
+        """
+        # If verbose mode is on, print a message about the validation loss decreasing and saving the model
         if self.verbose:
-            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
-        torch.save(model.state_dict(), self.path)
-        self.val_loss_min = val_loss
+            self.trace_func(
+                f"Validation loss decreased ({self.val_loss_min:.4f} --> {val_loss:.4f}). Saving model ..."
+            )
+
+        # Save the state of the model to the appropriate filename based on whether KFold is used or not
+        if self.use_kfold:
+            torch.save(model.state_dict(), self.filename)
+        else:
+            torch.save(model.state_dict(), self.path)
+
+        # Update the minimum validation loss seen so far to the current validation loss
+        self.val_loss_min = val_loss
\ No newline at end of file