diff --git a/utils/misc.py b/utils/misc.py index 1be3165..40a0a61 100644 --- a/utils/misc.py +++ b/utils/misc.py @@ -156,9 +156,15 @@ def get_action_and_transition(self, obs, hidden): - action: 1D np array - next_hidden (1 x 256) torch tensor """ - _, latent_mu, _ = self.vae(obs) - action = self.controller(latent_mu, hidden[0]) - _, _, _, _, _, next_hidden = self.mdrnn(action, latent_mu, hidden) + _, mu, logsigma = self.vae(obs) + + # Get latent variable z + sigma = logsigma.exp() + eps = torch.randn_like(sigma) + z = eps.mul(sigma).add_(mu) + + action = self.controller(z, hidden[0]) + _, _, _, _, _, next_hidden = self.mdrnn(action, z, hidden) return action.squeeze().cpu().numpy(), next_hidden def rollout(self, params, render=False):