@@ -273,7 +273,7 @@ def __init__(self,
             critic_network_cls, q_network_cls)
 
         self._use_entropy_reward = use_entropy_reward
-        self._munchausen_reward_weight = min(0, munchausen_reward_weight)
+        self._munchausen_reward_weight = max(0, munchausen_reward_weight)
         if munchausen_reward_weight > 0:
             assert not normalize_entropy_reward, (
                 "should not normalize entropy "
@@ -853,10 +853,23 @@ def _calc_critic_loss(self, info: SacInfo):
         (There is an issue in their implementation: their "terminals" can't
         differentiate between discount=0 (NormalEnd) and discount=1 (TimeOut).
         In the latter case, masking should not be performed.)
-
-        When the reward is multi-dim, the entropy reward will be added to *all*
-        dims.
         """
+        if self._use_entropy_reward:
+            with torch.no_grad():
+                log_pi = info.log_pi
+                if self._entropy_normalizer is not None:
+                    log_pi = self._entropy_normalizer.normalize(log_pi)
+                entropy_reward = nest.map_structure(
+                    lambda la, lp: -torch.exp(la) * lp, self._log_alpha,
+                    log_pi)
+                entropy_reward = sum(nest.flatten(entropy_reward))
+                discount = self._critic_losses[0].gamma * info.discount
+                # When the reward is multi-dim, the entropy reward will be
+                # added to *all* dims.
+                info = info._replace(
+                    reward=(info.reward + common.expand_dims_as(
+                        entropy_reward * discount, info.reward)))
+
         if self._munchausen_reward_weight > 0:
             with torch.no_grad():
                 # calculate the log probability of the rollout action
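Aside on the hunk above: the added entropy-reward block amounts to r_t += gamma * discount_t * (-alpha * log_pi_t), broadcast across all reward dimensions. A minimal standalone sketch of that arithmetic with toy tensors and a single scalar alpha (shapes and values are made up for illustration; this is not the ALF API):

import torch

# Standalone sketch of the entropy-reward augmentation added above.
# T = time steps, B = batch size, D = reward dimensions.
T, B, D = 4, 3, 2
gamma = 0.99                     # critic discount factor
alpha = torch.tensor(0.2)        # entropy weight, i.e. exp(log_alpha)
reward = torch.rand(T, B, D)     # possibly multi-dim reward
discount = torch.ones(T, B)      # 0 at NormalEnd, 1 otherwise (including TimeOut)
log_pi = torch.randn(T, B)       # log-prob of the sampled actions

with torch.no_grad():
    entropy_reward = -alpha * log_pi               # [T, B]
    scaled = entropy_reward * gamma * discount     # zeroed wherever discount == 0
    # The scalar entropy bonus is broadcast to *all* reward dims.
    reward = reward + scaled.unsqueeze(-1)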
@@ -875,26 +888,22 @@ def _calc_critic_loss(self, info: SacInfo):
                 munchausen_reward = nest.map_structure(
                     lambda la, lp: torch.exp(la) * lp, self._log_alpha,
                     log_pi_rollout_a)
+                # [T, B]
                 munchausen_reward = sum(nest.flatten(munchausen_reward))
+                # Forward-shift the munchausen reward by one step temporally,
+                # with zero padding for the first step. This dummy reward for
+                # the first step does not impact training as it is not used
+                # in TD-learning.
+                munchausen_reward = torch.cat((torch.zeros_like(
+                    munchausen_reward[0:1]), munchausen_reward[:-1]),
+                    dim=0)
+                # When the reward is multi-dim, the munchausen reward will be
+                # added to *all* dims.
                 info = info._replace(
                     reward=(
                         info.reward + self._munchausen_reward_weight *
                         common.expand_dims_as(munchausen_reward, info.reward)))
 
-        if self._use_entropy_reward:
-            with torch.no_grad():
-                log_pi = info.log_pi
-                if self._entropy_normalizer is not None:
-                    log_pi = self._entropy_normalizer.normalize(log_pi)
-                entropy_reward = nest.map_structure(
-                    lambda la, lp: -torch.exp(la) * lp, self._log_alpha,
-                    log_pi)
-                entropy_reward = sum(nest.flatten(entropy_reward))
-                discount = self._critic_losses[0].gamma * info.discount
-                info = info._replace(
-                    reward=(info.reward + common.expand_dims_as(
-                        entropy_reward * discount, info.reward)))
-
         critic_info = info.critic
         critic_losses = []
         for i, l in enumerate(self._critic_losses):
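The key subtlety in the munchausen branch is the one-step forward shift: the log-probability bonus computed at step t-1 is credited to the reward at step t, and the first step receives a zero pad that TD-learning never consumes. A minimal standalone sketch of just that shift, using toy tensors and a scalar alpha (not the ALF API, purely illustrative):

import torch

# Standalone sketch of the munchausen-reward shift in the diff above.
T, B = 5, 3
alpha = torch.tensor(0.2)                 # exp(log_alpha)
log_pi_rollout_a = torch.randn(T, B)      # log-prob of the rollout actions
munchausen_weight = 0.9                   # stands in for munchausen_reward_weight
reward = torch.rand(T, B)

with torch.no_grad():
    munchausen_reward = alpha * log_pi_rollout_a        # [T, B]
    # Forward-shift by one step: reward[t] receives the bonus computed at t-1;
    # the zero-padded first entry is a dummy that TD-learning does not use.
    munchausen_reward = torch.cat(
        (torch.zeros_like(munchausen_reward[0:1]), munchausen_reward[:-1]), dim=0)
    reward = reward + munchausen_weight * munchausen_reward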