From d71cc23907c9b6ec5944472689c408b1bdf1fd38 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Reyes=20Hern=C3=A1ndez?= Date: Tue, 4 Apr 2023 17:09:23 +0200 Subject: [PATCH 1/2] Created a EarlyStoppingKFold class, black formatted the main script and update the .gitignore file --- .gitignore | 3 + pytorchtools.py | 163 +++++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 149 insertions(+), 17 deletions(-) diff --git a/.gitignore b/.gitignore index 30eab60..a0088fc 100644 --- a/.gitignore +++ b/.gitignore @@ -105,3 +105,6 @@ venv.bak/ # mypy .mypy_cache/ + +# .DS_Store file from MacOS +**/.DS_Store diff --git a/pytorchtools.py b/pytorchtools.py index 9644e4b..6b43c44 100644 --- a/pytorchtools.py +++ b/pytorchtools.py @@ -1,22 +1,31 @@ import numpy as np import torch +from torch import nn + class EarlyStopping: - """Early stops the training if validation loss doesn't improve after a given patience.""" - def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print): + """The EarlyStopping class monitors the validation loss during training and stop the training process early + if the validation loss does not improve after a certain number of epochs""" + + def __init__( + self, + patience: int = 7, + verbose: bool = False, + delta: float = 0, + path: str = "checkpoint.pt", + trace_func=print, + ): """ + Initializes the EarlyStopping object with the given parameters. + Args: - patience (int): How long to wait after last time validation loss improved. - Default: 7 - verbose (bool): If True, prints a message for each validation loss improvement. - Default: False - delta (float): Minimum change in the monitored quantity to qualify as an improvement. - Default: 0 - path (str): Path for the checkpoint to be saved to. - Default: 'checkpoint.pt' - trace_func (function): trace print function. - Default: print + patience: How long to wait after last time validation loss improved. + verbose: If True, prints a message for each validation loss improvement. + delta: Minimum change in the monitored quantity to qualify as an improvement. + path: Path for the checkpoint to be saved to. + trace_func: trace print function. """ + # Setting the instance variables to the values passed to the constructor self.patience = patience self.verbose = verbose self.counter = 0 @@ -26,26 +35,146 @@ def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', tra self.delta = delta self.path = path self.trace_func = trace_func - def __call__(self, val_loss, model): + def __call__(self, val_loss: float, model: nn.Module): + """ + This method is called during the training process to monitor the validation loss and decide whether to stop + the training process early or not. + + Args: + val_loss: Validation loss of the model at the current epoch. + model: The PyTorch model being trained. + """ + + # Calculating the score by negating the validation loss score = -val_loss + # If the best score is None, sets it to the current score and saves the checkpoint if self.best_score is None: self.best_score = score self.save_checkpoint(val_loss, model) + + # If the score is less than the best score plus delta, increments the counter + # and checks if the patience has been reached elif score < self.best_score + self.delta: self.counter += 1 - self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}') + self.trace_func( + f"EarlyStopping counter: {self.counter} out of {self.patience}" + ) if self.counter >= self.patience: self.early_stop = True + + # If the score is better than the best score plus delta, saves the checkpoint and resets the counter else: self.best_score = score self.save_checkpoint(val_loss, model) self.counter = 0 - def save_checkpoint(self, val_loss, model): - '''Saves model when validation loss decrease.''' + def save_checkpoint(self, val_loss: float, model: nn.Module): + """Saves model when validation loss decrease""" + # If verbose mode is on, print a message about the validation loss decreasing and saving the model if self.verbose: - self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...') + self.trace_func( + f"Validation loss decreased ({self.val_loss_min:.4f} --> {val_loss:.4f}). Saving model ..." + ) + + # Save the state of the model to a file specified by `self.path` torch.save(model.state_dict(), self.path) + + # Update the minimum validation loss seen so far to the current validation loss + self.val_loss_min = val_loss + + +# Implementing an Early Stopping class to save the best performing model on a given fold +class EarlyStoppingKFold: + """The EarlyStoppingKFold class monitors the validation loss during training and stop the training process early + if the validation loss does not improve after a certain number of epochs for a given fold + """ + + def __init__( + self, + patience: int = 7, + verbose: bool = False, + delta: float = 0, + path: str = "checkpoint.pt", + trace_func=print, + ): + """ + Initializes the EarlyStoppingKFold object with the given parameters. + + Args: + patience: How long to wait after last time validation loss improved. + verbose: If True, prints a message for each validation loss improvement. + delta: Minimum change in the monitored quantity to qualify as an improvement. + path: Path for the checkpoint to be saved to. + trace_func: trace print function. + """ + # Setting the instance variables to the values passed to the constructor + self.patience = patience + self.verbose = verbose + self.counter = 0 + self.best_score = None + self.early_stop = False + self.val_loss_min = np.Inf + self.delta = delta + self.path = path + self.trace_func = trace_func + self.fold = None + self.filename = None + + def __call__(self, val_loss: float, model: nn.Module, fold: int): + """ + This method is called during the training process to monitor the validation loss and decide whether to stop + the training process early or not. + + Args: + val_loss: Validation loss of the model at the current epoch. + model: The PyTorch model being trained. + fold: The current fold of the KFold cross-validation. + """ + # If it's a new fold, resets the early stopping object and sets the filename to save the model + if fold != self.fold: + self.fold = fold + self.counter = 0 + self.best_score = None + self.early_stop = False + self.val_loss_min = np.Inf + self.filename = self.path.replace(".pt", f"_fold_{fold}.pt") + + # Calculating the score by negating the validation loss + score = -val_loss + + # If the best score is None, sets it to the current score and saves the checkpoint + if self.best_score is None: + self.best_score = score + self.save_checkpoint(val_loss, model) + + # If the score is less than the best score plus delta, increments the counter + # and checks if the patience has been reached + elif score < self.best_score + self.delta: + self.counter += 1 + self.trace_func( + f"EarlyStopping counter: {self.counter} out of {self.patience}" + ) + if self.counter >= self.patience: + self.early_stop = True + + # If the score is better than the best score plus delta, saves the checkpoint and resets the counter + else: + self.best_score = score + self.save_checkpoint(val_loss, model) + self.counter = 0 + + def save_checkpoint(self, val_loss: float, model: nn.Module): + """Saves model when validation loss decrease""" + # If verbose mode is on, print a message about the validation loss decreasing and saving the model + if self.verbose: + self.trace_func( + f"Validation loss decreased ({self.val_loss_min:.4f} --> {val_loss:.4f}). Saving model ..." + ) + + # Save the state of the model to the filename specified for the current fold + torch.save(model.state_dict(), self.filename) + + # Update the minimum validation loss seen so far to the current validation loss self.val_loss_min = val_loss From ad0a7c6811309a55b0cc078133e8306376735e65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Reyes=20Hern=C3=A1ndez?= Date: Sun, 7 May 2023 19:41:17 +0200 Subject: [PATCH 2/2] Merged EarlyStopping and EarlyStoppingKFold into one single class --- pytorchtools.py | 129 ++++++++++++------------------------------------ 1 file changed, 31 insertions(+), 98 deletions(-) diff --git a/pytorchtools.py b/pytorchtools.py index 6b43c44..e9803fc 100644 --- a/pytorchtools.py +++ b/pytorchtools.py @@ -4,8 +4,11 @@ class EarlyStopping: - """The EarlyStopping class monitors the validation loss during training and stop the training process early - if the validation loss does not improve after a certain number of epochs""" + """ + EarlyStopping can be used to monitor the validation loss during training and stop the training process early + if the validation loss does not improve after a certain number of epochs. It can handle both KFold and + non-KFold cases. + """ def __init__( self, @@ -13,6 +16,7 @@ def __init__( verbose: bool = False, delta: float = 0, path: str = "checkpoint.pt", + use_kfold: bool = False, trace_func=print, ): """ @@ -23,9 +27,9 @@ def __init__( verbose: If True, prints a message for each validation loss improvement. delta: Minimum change in the monitored quantity to qualify as an improvement. path: Path for the checkpoint to be saved to. + use_kfold: If True, saves the model with the lowest loss metric for each fold. trace_func: trace print function. """ - # Setting the instance variables to the values passed to the constructor self.patience = patience self.verbose = verbose self.counter = 0 @@ -34,9 +38,12 @@ def __init__( self.val_loss_min = np.Inf self.delta = delta self.path = path + self.use_kfold = use_kfold self.trace_func = trace_func + self.fold = None + self.filename = None - def __call__(self, val_loss: float, model: nn.Module): + def __call__(self, val_loss: float, model: nn.Module, fold: int = None): """ This method is called during the training process to monitor the validation loss and decide whether to stop the training process early or not. @@ -44,7 +51,19 @@ def __call__(self, val_loss: float, model: nn.Module): Args: val_loss: Validation loss of the model at the current epoch. model: The PyTorch model being trained. + fold: The current fold of the KFold cross-validation. Required if use_kfold is True. """ + if self.use_kfold: + assert fold is not None, "Fold must be provided when use_kfold is True" + + # If it's a new fold, resets the early stopping object and sets the filename to save the model + if fold != self.fold: + self.fold = fold + self.counter = 0 + self.best_score = None + self.early_stop = False + self.val_loss_min = np.Inf + self.filename = self.path.replace(".pt", f"_fold_{fold}.pt") # Calculating the score by negating the validation loss score = -val_loss @@ -71,110 +90,24 @@ def __call__(self, val_loss: float, model: nn.Module): self.counter = 0 def save_checkpoint(self, val_loss: float, model: nn.Module): - """Saves model when validation loss decrease""" - # If verbose mode is on, print a message about the validation loss decreasing and saving the model - if self.verbose: - self.trace_func( - f"Validation loss decreased ({self.val_loss_min:.4f} --> {val_loss:.4f}). Saving model ..." - ) - - # Save the state of the model to a file specified by `self.path` - torch.save(model.state_dict(), self.path) - - # Update the minimum validation loss seen so far to the current validation loss - self.val_loss_min = val_loss - - -# Implementing an Early Stopping class to save the best performing model on a given fold -class EarlyStoppingKFold: - """The EarlyStoppingKFold class monitors the validation loss during training and stop the training process early - if the validation loss does not improve after a certain number of epochs for a given fold - """ - - def __init__( - self, - patience: int = 7, - verbose: bool = False, - delta: float = 0, - path: str = "checkpoint.pt", - trace_func=print, - ): - """ - Initializes the EarlyStoppingKFold object with the given parameters. - - Args: - patience: How long to wait after last time validation loss improved. - verbose: If True, prints a message for each validation loss improvement. - delta: Minimum change in the monitored quantity to qualify as an improvement. - path: Path for the checkpoint to be saved to. - trace_func: trace print function. - """ - # Setting the instance variables to the values passed to the constructor - self.patience = patience - self.verbose = verbose - self.counter = 0 - self.best_score = None - self.early_stop = False - self.val_loss_min = np.Inf - self.delta = delta - self.path = path - self.trace_func = trace_func - self.fold = None - self.filename = None - - def __call__(self, val_loss: float, model: nn.Module, fold: int): """ - This method is called during the training process to monitor the validation loss and decide whether to stop - the training process early or not. + Saves the model when validation loss decreases. Args: - val_loss: Validation loss of the model at the current epoch. + val_loss: The current validation loss. model: The PyTorch model being trained. - fold: The current fold of the KFold cross-validation. """ - # If it's a new fold, resets the early stopping object and sets the filename to save the model - if fold != self.fold: - self.fold = fold - self.counter = 0 - self.best_score = None - self.early_stop = False - self.val_loss_min = np.Inf - self.filename = self.path.replace(".pt", f"_fold_{fold}.pt") - - # Calculating the score by negating the validation loss - score = -val_loss - - # If the best score is None, sets it to the current score and saves the checkpoint - if self.best_score is None: - self.best_score = score - self.save_checkpoint(val_loss, model) - - # If the score is less than the best score plus delta, increments the counter - # and checks if the patience has been reached - elif score < self.best_score + self.delta: - self.counter += 1 - self.trace_func( - f"EarlyStopping counter: {self.counter} out of {self.patience}" - ) - if self.counter >= self.patience: - self.early_stop = True - - # If the score is better than the best score plus delta, saves the checkpoint and resets the counter - else: - self.best_score = score - self.save_checkpoint(val_loss, model) - self.counter = 0 - - def save_checkpoint(self, val_loss: float, model: nn.Module): - """Saves model when validation loss decrease""" # If verbose mode is on, print a message about the validation loss decreasing and saving the model if self.verbose: self.trace_func( f"Validation loss decreased ({self.val_loss_min:.4f} --> {val_loss:.4f}). Saving model ..." ) - # Save the state of the model to the filename specified for the current fold - torch.save(model.state_dict(), self.filename) + # Save the state of the model to the appropriate filename based on whether KFold is used or not + if self.use_kfold: + torch.save(model.state_dict(), self.filename) + else: + torch.save(model.state_dict(), self.path) # Update the minimum validation loss seen so far to the current validation loss - self.val_loss_min = val_loss + self.val_loss_min = val_loss \ No newline at end of file