File tree Expand file tree Collapse file tree 1 file changed +26
-0
lines changed Expand file tree Collapse file tree 1 file changed +26
-0
lines changed Original file line number Diff line number Diff line change 1+ # -*- coding: utf-8 -*-
2+ from __future__ import print_function , division
3+
4+ import torch
5+ import torch .nn as nn
6+
7+ class CombinedLoss (nn .Module ):
8+ def __init__ (self , params , loss_dict ):
9+ super (CombinedLoss , self ).__init__ ()
10+ loss_names = params ['loss_type' ]
11+ self .loss_weight = params ['loss_weight' ]
12+ assert (len (loss_names ) == len (self .loss_weight ))
13+ self .loss_list = []
14+ for loss_name in loss_names :
15+ if (loss_name in loss_dict ):
16+ one_loss = loss_dict [loss_name ](params )
17+ self .loss_list .append (one_loss )
18+ else :
19+ raise ValueError ("{0:} is not defined, or has not been added to the \
20+ loss dictionary" .format (loss_name ))
21+
22+ def forward (self , loss_input_dict ):
23+ loss_value = 0.0
24+ for i in range (len (self .loss_list )):
25+ loss_value = self .loss_weight [i ] + self .loss_list [i ](loss_input_dict )
26+ return loss_value
You can’t perform that action at this time.
0 commit comments