Describe the bug
For torchinfo 1.8 on a Mac with an M2 chip, the following code resulted in a runtime error:

```python
import torch
import torch.nn as nn
from torchinfo import summary

batch_size = 64  # assumed example value; not given in the original report
device = torch.accelerator.current_accelerator()
model = nn.Sequential(nn.Flatten(), nn.Linear(3072, 10)).to(device)
summary(model, input_size=(batch_size, 3, 32, 32))
...
out = model(data)  # fails: data is on the accelerator, but summary() moved the model to the CPU
```
with the error message
```
RuntimeError: Tensor for argument weight is on cpu but expected on mps
```
The same code ran fine on Linux with an NVIDIA card.
Expected behavior
No runtime error.
Desktop
- OS: macOS Sequoia 15.5
- CPU: M2 chip ("Apple silicon")
- torchinfo version: 1.8
Quick workarounds
- (preferred) Add `device=` to the call to `torchinfo.summary()`, e.g. `summary(model, input_size=(batch_size, 3, 32, 32), device=device)`; see the sketch after this list.
- (also works) Move the model back to the device (`model = model.to(device)`) after the call to `torchinfo.summary()`.
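For concreteness, here is a minimal sketch of the preferred workaround, using the same toy model as above (`batch_size` is an assumed example value, not from the original report):

```python
import torch
import torch.nn as nn
from torchinfo import summary

batch_size = 64  # assumed example value
device = torch.accelerator.current_accelerator()
model = nn.Sequential(nn.Flatten(), nn.Linear(3072, 10)).to(device)

# Passing device= explicitly keeps torchinfo from moving the model to the CPU.
summary(model, input_size=(batch_size, 3, 32, 32), device=device)

# The model's parameters are still on the accelerator afterwards.
assert next(model.parameters()).device.type == device.type
```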
Cause of bug
In `torchinfo.py`, the function `get_device()` appears to recognise only CUDA as an
accelerator, whereas other platforms may expose different accelerators; M-chip Macs,
for example, use `"mps"`.
As a result, when `device=` is not given in the call to `summary()`, torchinfo pushes
the model to the CPU, which then causes a runtime error during model training
(or evaluation): the data is on the accelerator while the model (or parts of it) is
on the CPU.
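The following is not torchinfo's actual code, but a minimal sketch of a device lookup that is not CUDA-only. It assumes the `torch.accelerator` API introduced in PyTorch 2.6, with per-backend checks as a fallback for older versions; the helper name `pick_device` is hypothetical:

```python
import torch

def pick_device() -> torch.device:
    # Hypothetical helper, not torchinfo's get_device(): prefer whatever
    # accelerator PyTorch itself reports (cuda, mps, xpu, ...) and fall
    # back to the CPU only when no accelerator is available.
    if hasattr(torch, "accelerator") and torch.accelerator.is_available():
        return torch.accelerator.current_accelerator()
    if torch.cuda.is_available():           # pre-2.6 fallback: NVIDIA GPUs
        return torch.device("cuda")
    if torch.backends.mps.is_available():   # pre-2.6 fallback: Apple silicon
        return torch.device("mps")
    return torch.device("cpu")
```

Defaulting to a lookup like this inside `summary()` would put the model on the same device as the generated input tensors, avoiding the CPU/accelerator mismatch.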
Bug fix
I have created a PR that should fix the bug for any accelerator recognised by PyTorch.