6. Loss Functions
In the section 2. Network Architectures and Training, some network architectures can have their loss
functions replaced with custom ones; however, to enable safe saving of the architecture state, custom loss functions
should be created using BaseLoss from netloader.loss_funcs.
Currently, the supported loss functions are:
- MSELoss
- CrossEntropyLoss
BaseLoss is the parent class of all loss functions; therefore, its methods are what all loss
functions build on.
Attributes:
- `_args : tuple[Any, ...]`, arguments to pass to the initialisation of a loss function
- `_kwargs : dict[str, Any]`, optional keyword arguments to pass to the initialisation of a loss function
- `loss_func : Callable`, loss function object
Methods:
- `forward` : Forward pass of the loss function
  - `output : Tensor`, output from the network with shape (N, ...), where N is the number of elements
  - `target : Tensor`, target values with shape (N, ...)
  - return: `Tensor`, loss value with shape (1)
- `__repr__` : Returns a string representation of the loss function
  - return: `str`, string representation of the loss function
- `__getstate__` : Returns a dictionary containing the state of the loss function for pickling
  - return: `dict[str, Any]`, dictionary containing the state of the loss function
- `__setstate__` : Sets the state of the loss function for pickling
  - `state : dict[str, Any]`, dictionary containing the state of the loss function
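The `__getstate__`/`__setstate__` pattern described above can be illustrated with a minimal, framework-free sketch. This is not netloader's actual implementation, only an illustration of the pattern: the wrapped loss object is dropped from the pickled state and rebuilt from the stored arguments on load.

```python
import pickle
from typing import Any, Callable


def mse(output: list[float], target: list[float]) -> float:
    # Stand-in for a loss object that may not survive pickling directly
    return sum((o - t) ** 2 for o, t in zip(output, target)) / len(output)


class SketchBaseLoss:
    """Minimal sketch of the BaseLoss save/restore pattern."""

    def __init__(self, loss_func: Callable, *args: Any, **kwargs: Any) -> None:
        self._args = args
        self._kwargs = kwargs
        self.loss_func = loss_func

    def forward(self, output, target):
        # The default forward simply delegates to the wrapped loss function
        return self.loss_func(output, target)

    def __getstate__(self) -> dict[str, Any]:
        # Drop the loss object itself; keep only what is needed to rebuild it
        state = self.__dict__.copy()
        del state['loss_func']
        return state

    def __setstate__(self, state: dict[str, Any]) -> None:
        # Restore the stored arguments, then reinitialise the loss function
        self.__dict__.update(state)
        self.loss_func = mse


loss = SketchBaseLoss(mse)
restored = pickle.loads(pickle.dumps(loss))
value = restored.forward([1.0, 2.0], [1.0, 4.0])  # (0 + 4) / 2 = 2.0
```

A subclass only needs to know how to rebuild its own loss object in `__setstate__`, which is exactly what the custom loss function example later on this page does.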
Parameters:
- `loss_func : type`, loss function class to be used
- `*args : Any`, optional arguments to be passed to `loss_func`
- `**kwargs : Any`, optional keyword arguments to be passed to `loss_func`
The mean squared error loss function is identical to MSELoss from PyTorch, but safe for saving.
Parameters:
- `*args : Any`, optional arguments to be passed to `MSELoss`
- `**kwargs : Any`, optional keyword arguments to be passed to `MSELoss`
The cross entropy loss function is identical to CrossEntropyLoss from PyTorch, but safe for saving.
Parameters:
- `*args : Any`, optional arguments to be passed to `CrossEntropyLoss`
- `**kwargs : Any`, optional keyword arguments to be passed to `CrossEntropyLoss`
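Since both wrappers are functionally identical to their PyTorch counterparts, the inputs they expect can be checked against the underlying PyTorch classes directly. The sketch below uses `torch.nn` rather than netloader, so the shapes shown are those PyTorch enforces:

```python
import torch
from torch import nn

# MSELoss: output and target share the same shape (N, ...)
mse = nn.MSELoss()
output = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
target = torch.tensor([[1.0, 2.0], [3.0, 2.0]])
mse_value = mse(output, target)  # mean of squared differences: 4 / 4 = 1.0

# CrossEntropyLoss: output holds raw logits with shape (N, C);
# target holds class indices with shape (N,)
cross_entropy = nn.CrossEntropyLoss()
logits = torch.tensor([[10.0, 0.0], [0.0, 10.0]])
labels = torch.tensor([0, 1])
ce_value = cross_entropy(logits, labels)  # near zero for confident, correct logits
```

Any constructor arguments accepted by the PyTorch classes, such as `reduction`, can be passed through `*args` and `**kwargs`.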
If one of the current loss functions is not suitable for your problem, you can extend the BaseLoss class to make
custom loss functions.
First, the new class must inherit BaseLoss, then define the initialisation method.
The __init__ method is defined as:
```python
from typing import Any

from torch import nn

from netloader.loss_funcs import BaseLoss


class MSELoss(BaseLoss):
    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(nn.MSELoss, *args, **kwargs)
```

Finally, for safe saving, the loss function object itself is not saved; therefore, when loading the loss function from a saved state,
we need to reinitialise it by defining `__setstate__`:
```python
def __setstate__(self, state: dict[str, Any]) -> None:
    super().__setstate__(state)
    self.loss_func = nn.MSELoss(*self._args, **self._kwargs)
```

Therefore, the full class is:
```python
from typing import Any

from torch import nn

from netloader.loss_funcs import BaseLoss


class MSELoss(BaseLoss):
    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(nn.MSELoss, *args, **kwargs)

    def __setstate__(self, state: dict[str, Any]) -> None:
        super().__setstate__(state)
        self.loss_func = nn.MSELoss(*self._args, **self._kwargs)
```

By default, the forward method passes the network output and target into the loss function; however, if the loss function is more complex and requires manipulation of the data, the forward method can be redefined in your child class.
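As an illustration of redefining forward, the hypothetical class below rescales the network output before computing the mean squared error. The class name and the rescaling step are invented for this example, and it is written as a plain class using `torch.nn` so it runs without netloader; in practice it would inherit BaseLoss as shown above.

```python
import torch
from torch import nn


class NormalisedMSELoss:
    """Hypothetical loss that rescales the output by its largest magnitude
    before computing MSE; with netloader this would inherit BaseLoss."""

    def __init__(self) -> None:
        self.loss_func = nn.MSELoss()

    def forward(self, output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Manipulate the data before delegating to the wrapped loss function
        output = output / output.abs().max()
        return self.loss_func(output, target)


loss = NormalisedMSELoss()
value = loss.forward(torch.tensor([2.0, 4.0]), torch.tensor([0.5, 1.0]))  # 0.0
```

Any preprocessing of the output or target belongs in forward, so that the saving behaviour inherited from BaseLoss is unaffected.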