diff --git a/trainvae.py b/trainvae.py index 1da4403..3364cd3 100644 --- a/trainvae.py +++ b/trainvae.py @@ -91,8 +91,8 @@ def train(epoch): for batch_idx, data in enumerate(train_loader): data = data.to(device) optimizer.zero_grad() - recon_batch, mu, logvar = model(data) - loss = loss_function(recon_batch, data, mu, logvar) + recon_batch, mu, logsigma = model(data) + loss = loss_function(recon_batch, data, mu, logsigma) loss.backward() train_loss += loss.item() optimizer.step() @@ -114,8 +114,8 @@ def test(): with torch.no_grad(): for data in test_loader: data = data.to(device) - recon_batch, mu, logvar = model(data) - test_loss += loss_function(recon_batch, data, mu, logvar).item() + recon_batch, mu, logsigma = model(data) + test_loss += loss_function(recon_batch, data, mu, logsigma).item() test_loss /= len(test_loader.dataset) print('====> Test set loss: {:.4f}'.format(test_loss))