6. Loss Functions

Custom Loss Functions

In the section 2. Network Architectures and Training, some network architectures can have their loss functions replaced with custom ones; however, to enable safe saving of the architecture state, custom loss functions should be created using BaseLoss from netloader.loss_funcs.

Currently, the supported loss functions are:

  • MSELoss
  • CrossEntropyLoss

BaseLoss

BaseLoss is the parent class of all loss functions; therefore, its methods are what all loss functions build upon.

Private Attributes:

  • _args: tuple[Any, ...], arguments to pass to the initialisation of a loss function
  • _kwargs: dict[str, Any], optional keyword arguments to pass to the initialisation of a loss function
  • _loss_func: Callable, loss function object

Methods of BaseLoss

Public Methods:

  • forward: Forward pass of the loss function
    • output: Tensor, output from the network with shape (N,...), where N is the number of elements
    • target: Tensor, target values with shape (N,...)
    • return: Tensor, loss value with shape (1)

Magic Methods:

  • __repr__: Returns a string representation of the loss function
    • return: str, string representation of the loss function
  • __getstate__: Returns a dictionary containing the state of the loss function for pickling
    • return: dict[str, Any], dictionary containing the state of the loss function
  • __setstate__: Sets the state of the loss function when unpickling
    • state: dict[str, Any], dictionary containing the state of the loss function

Initialisation Arguments:

  • loss_func: type, loss function class to be used
  • *args: Any, optional arguments to be passed to loss_func
  • **kwargs: Any, optional keyword arguments to be passed to loss_func
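
As a minimal sketch of how these pieces fit together, the snippet below wraps a PyTorch loss class in BaseLoss and computes a loss value. Direct instantiation of BaseLoss is assumed to work here because loss_func is an initialisation argument; in practice, one of the subclasses below would normally be used.

import torch
from torch import nn
from netloader.loss_funcs import BaseLoss

# Wrap a PyTorch loss class; *args and **kwargs are forwarded to its initialisation
loss = BaseLoss(nn.L1Loss, reduction='mean')

output = torch.randn(16, 8)  # network output with shape (N, ...)
target = torch.randn(16, 8)  # target values with the same shape
value = loss.forward(output, target)  # loss value
print(loss)  # __repr__ returns a string representation of the loss function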

MSELoss

The mean squared error loss function is identical to MSELoss from PyTorch, but safe for saving.

Initialisation Arguments:

  • *args: Any, optional arguments to be passed to MSELoss
  • **kwargs: Any, optional keyword arguments to be passed to MSELoss
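
As a usage sketch, and assuming MSELoss can be imported from netloader.loss_funcs, it behaves like the PyTorch loss it wraps:

import torch
from netloader.loss_funcs import MSELoss

loss = MSELoss()  # optional arguments are forwarded to torch.nn.MSELoss
output = torch.randn(32, 5)  # network output with shape (N, ...)
target = torch.randn(32, 5)  # target values with the same shape
value = loss.forward(output, target)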

CrossEntropyLoss

The cross entropy loss function is identical to CrossEntropyLoss from PyTorch, but safe for saving.

Initialisation Arguments:

  • *args: Any, optional arguments to be passed to CrossEntropyLoss
  • **kwargs: Any, optional keyword arguments to be passed to CrossEntropyLoss
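
Similarly, assuming CrossEntropyLoss can be imported from netloader.loss_funcs, the targets can be class indices, as with the PyTorch version:

import torch
from netloader.loss_funcs import CrossEntropyLoss

loss = CrossEntropyLoss()  # optional arguments are forwarded to torch.nn.CrossEntropyLoss
logits = torch.randn(32, 10)           # network outputs for 10 classes
classes = torch.randint(0, 10, (32,))  # target class indices
value = loss.forward(logits, classes)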

Creating Custom Loss Functions

If none of the current loss functions is suitable for your problem, you can extend the BaseLoss class to create custom loss functions.

Creating MSELoss from BaseLoss

First, the new class must inherit from BaseLoss and define the initialisation method. The __init__ method is defined as:

from typing import Any

from torch import nn
from netloader.loss_funcs import BaseLoss


class MSELoss(BaseLoss):
    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(nn.MSELoss, *args, **kwargs)

Finally, for safe saving, the underlying loss function object is not saved; therefore, when loading the loss function from a saved state, it needs to be reinitialised, so we define __setstate__:

def __setstate__(self, state: dict[str, Any]) -> None:
    super().__setstate__(state)
    self._loss_func = nn.MSELoss(*self._args, **self._kwargs)

Therefore, the full class is:

from typing import Any

from torch import nn
from netloader.loss_funcs import BaseLoss


class MSELoss(BaseLoss):
    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(nn.MSELoss, *args, **kwargs)

    def __setstate__(self, state: dict[str, Any]) -> None:
        super().__setstate__(state)
        self._loss_func = nn.MSELoss(*self._args, **self._kwargs)
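
To see why __setstate__ matters, a quick save/load round trip is sketched below; this assumes that __getstate__ omits the underlying loss function object, as described above, so it has to be rebuilt on load.

import pickle

loss = MSELoss()
state = pickle.dumps(loss)      # __getstate__ omits the underlying nn.MSELoss
restored = pickle.loads(state)  # __setstate__ rebuilds it from _args and _kwargs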

By default, the forward method passes the network output and target straight into the loss function; however, if the loss function is more complex and requires manipulation of the data, the forward method can be redefined in your child class, as shown in the sketch below.
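
As a purely illustrative sketch, the hypothetical class below redefines forward to compare log-scaled outputs and targets; the class name and the log scaling are assumptions for demonstration and not part of netloader.

from typing import Any

import torch
from torch import nn, Tensor
from netloader.loss_funcs import BaseLoss


class LogMSELoss(BaseLoss):
    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(nn.MSELoss, *args, **kwargs)

    def __setstate__(self, state: dict[str, Any]) -> None:
        super().__setstate__(state)
        self._loss_func = nn.MSELoss(*self._args, **self._kwargs)

    def forward(self, output: Tensor, target: Tensor) -> Tensor:
        # Manipulate the data before passing it to the wrapped loss function (_loss_func)
        return self._loss_func(torch.log1p(output.abs()), torch.log1p(target.abs()))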
