-
Notifications
You must be signed in to change notification settings - Fork 59
Open
Description
Here's the error I get:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-15-37f467c4f834> in <module>
1 for epoch in range(1, epochs + 1):
----> 2 train(epoch)
3 test(epoch)
4 with torch.no_grad():
5 sample = torch.randn(2, 2048).to(device)
<ipython-input-13-8f191bde6513> in train(epoch)
6 optimizer.zero_grad()
7 recon_batch, mu, logvar = model(data)
----> 8 loss = loss_mse(recon_batch, data, mu, logvar)
9 loss.backward()
10 train_loss += loss.item()
~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
720 result = self._slow_forward(*input, **kwargs)
721 else:
--> 722 result = self.forward(*input, **kwargs)
723 for hook in itertools.chain(
724 _global_forward_hooks.values(),
<ipython-input-9-6c49edf3f96a> in forward(self, x_recon, x, mu, logvar)
5
6 def forward(self, x_recon, x, mu, logvar):
----> 7 loss_MSE = self.mse_loss(x_recon, x)
8 loss_KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
9
~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
720 result = self._slow_forward(*input, **kwargs)
721 else:
--> 722 result = self.forward(*input, **kwargs)
723 for hook in itertools.chain(
724 _global_forward_hooks.values(),
~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/loss.py in forward(self, input, target)
443
444 def forward(self, input: Tensor, target: Tensor) -> Tensor:
--> 445 return F.mse_loss(input, target, reduction=self.reduction)
446
447
~/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py in mse_loss(input, target, size_average, reduce, reduction)
2645 ret = torch.mean(ret) if reduction == 'mean' else torch.sum(ret)
2646 else:
-> 2647 expanded_input, expanded_target = torch.broadcast_tensors(input, target)
2648 ret = torch._C._nn.mse_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
2649 return ret
~/anaconda3/lib/python3.7/site-packages/torch/functional.py in broadcast_tensors(*tensors)
63 if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):
64 return handle_torch_function(broadcast_tensors, tensors, *tensors)
---> 65 return _VF.broadcast_tensors(tensors)
66
67
RuntimeError: The size of tensor a (100) must match the size of tensor b (800) at non-singleton dimension 3
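My input images are 600x800 pixels (the 800 corresponds to tensor b, the target, in the error above; the reconstruction, tensor a, has size 100 along that dimension).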
Metadata
Metadata
Assignees
Labels
No labels