From b1cbaed886b41c48ed4d748c23d5d8efe27c1ecf Mon Sep 17 00:00:00 2001 From: Alejandro Date: Tue, 3 Sep 2019 11:01:11 +0200 Subject: [PATCH] Fixed controller misuse of latent_mu --- utils/misc.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/utils/misc.py b/utils/misc.py index 1be3165..40a0a61 100644 --- a/utils/misc.py +++ b/utils/misc.py @@ -156,9 +156,15 @@ def get_action_and_transition(self, obs, hidden): - action: 1D np array - next_hidden (1 x 256) torch tensor """ - _, latent_mu, _ = self.vae(obs) - action = self.controller(latent_mu, hidden[0]) - _, _, _, _, _, next_hidden = self.mdrnn(action, latent_mu, hidden) + _, mu, logsigma = self.vae(obs) + + # Get latent variable z + sigma = logsigma.exp() + eps = torch.randn_like(sigma) + z = eps.mul(sigma).add_(mu) + + action = self.controller(z, hidden[0]) + _, _, _, _, _, next_hidden = self.mdrnn(action, z, hidden) return action.squeeze().cpu().numpy(), next_hidden def rollout(self, params, render=False):