
Commit b330630

Fix alignment
1 parent 28f1cef commit b330630


alf/algorithms/sac_algorithm.py

Lines changed: 27 additions & 18 deletions
@@ -273,7 +273,7 @@ def __init__(self,
             critic_network_cls, q_network_cls)
 
         self._use_entropy_reward = use_entropy_reward
-        self._munchausen_reward_weight = min(0, munchausen_reward_weight)
+        self._munchausen_reward_weight = max(0, munchausen_reward_weight)
         if munchausen_reward_weight > 0:
             assert not normalize_entropy_reward, (
                 "should not normalize entropy "
@@ -853,10 +853,23 @@ def _calc_critic_loss(self, info: SacInfo):
         (There is an issue in their implementation: their "terminals" can't
         differentiate between discount=0 (NormalEnd) and discount=1 (TimeOut).
         In the latter case, masking should not be performed.)
-
-        When the reward is multi-dim, the entropy reward will be added to *all*
-        dims.
         """
+        if self._use_entropy_reward:
+            with torch.no_grad():
+                log_pi = info.log_pi
+                if self._entropy_normalizer is not None:
+                    log_pi = self._entropy_normalizer.normalize(log_pi)
+                entropy_reward = nest.map_structure(
+                    lambda la, lp: -torch.exp(la) * lp, self._log_alpha,
+                    log_pi)
+                entropy_reward = sum(nest.flatten(entropy_reward))
+                discount = self._critic_losses[0].gamma * info.discount
+                # When the reward is multi-dim, the entropy reward will be
+                # added to *all* dims.
+                info = info._replace(
+                    reward=(info.reward + common.expand_dims_as(
+                        entropy_reward * discount, info.reward)))
+
         if self._munchausen_reward_weight > 0:
             with torch.no_grad():
                 # calculate the log probability of the rollout action
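Note: for context, here is a minimal standalone sketch of the entropy-reward augmentation that this hunk moves ahead of the Munchausen term. The helper name add_entropy_reward and the [T, B] / [T, B, D] shapes are illustrative assumptions, not code from the repository:

    import torch

    def add_entropy_reward(reward, log_pi, log_alpha, discount, gamma):
        # Assumed shapes: reward is [T, B] or [T, B, D]; log_pi and discount are [T, B].
        with torch.no_grad():
            bonus = -torch.exp(log_alpha) * log_pi * gamma * discount  # [T, B]
            # Broadcast the per-step bonus to *all* reward dims, as in the diff.
            while bonus.ndim < reward.ndim:
                bonus = bonus.unsqueeze(-1)
            return reward + bonus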
@@ -875,26 +888,22 @@ def _calc_critic_loss(self, info: SacInfo):
                 munchausen_reward = nest.map_structure(
                     lambda la, lp: torch.exp(la) * lp, self._log_alpha,
                     log_pi_rollout_a)
+                # [T, B]
                 munchausen_reward = sum(nest.flatten(munchausen_reward))
+                # forward shift the munchausen reward one-step temporally,
+                # with zero-padding for the first step. This dummy reward
+                # for the first step does not impact training as it is not
+                # used in TD-learning.
+                munchausen_reward = torch.cat((torch.zeros_like(
+                    munchausen_reward[0:1]), munchausen_reward[:-1]),
+                                              dim=0)
+                # When the reward is multi-dim, the munchausen reward will be
+                # added to *all* dims.
                 info = info._replace(
                     reward=(
                         info.reward + self._munchausen_reward_weight *
                         common.expand_dims_as(munchausen_reward, info.reward)))
 
-        if self._use_entropy_reward:
-            with torch.no_grad():
-                log_pi = info.log_pi
-                if self._entropy_normalizer is not None:
-                    log_pi = self._entropy_normalizer.normalize(log_pi)
-                entropy_reward = nest.map_structure(
-                    lambda la, lp: -torch.exp(la) * lp, self._log_alpha,
-                    log_pi)
-                entropy_reward = sum(nest.flatten(entropy_reward))
-                discount = self._critic_losses[0].gamma * info.discount
-                info = info._replace(
-                    reward=(info.reward + common.expand_dims_as(
-                        entropy_reward * discount, info.reward)))
-
         critic_info = info.critic
         critic_losses = []
         for i, l in enumerate(self._critic_losses):
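Note: to make the one-step forward shift in this hunk concrete, here is a small self-contained example; the [T, B] time-major layout follows the diff's own comments, while the concrete sizes are made up:

    import torch

    T, B = 5, 3
    munchausen_reward = torch.randn(T, B)           # [T, B], time-major
    shifted = torch.cat(
        (torch.zeros_like(munchausen_reward[0:1]),  # zero padding for step 0
         munchausen_reward[:-1]),                   # drop the last step
        dim=0)
    assert shifted.shape == (T, B)                  # shape is unchanged
    assert torch.all(shifted[0] == 0)               # dummy reward at step 0
    assert torch.equal(shifted[1:], munchausen_reward[:-1])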
