
M2 Mac: Runtime error in training of model after call to torchinfo.summary() #371

@DrMicrobit

Description

Describe the bug
With torchinfo 1.8 on a Mac with an M2 chip, the following code resulted in a runtime error:

import torch
import torch.nn as nn
from torchinfo import summary

device = torch.accelerator.current_accelerator()
model = nn.Sequential(nn.Flatten(), nn.Linear(3072, 10)).to(device)
summary(model, input_size=(batch_size, 3, 32, 32))
...
out = model(data)  # data lives on `device`, but summary() has moved the model to the CPU

with the error message

RuntimeError: Tensor for argument weight is on cpu but expected on mps

The same code ran fine on Linux with an NVIDIA card.

Expected behavior
No runtime error.

Desktop

  • OS: macOS Sequoia 15.5
  • CPU: M2 chip ("Apple silicon")
  • torchinfo version: 1.8

Quick workarounds

  1. (preferred) Pass device= in the call to torchinfo.summary(), e.g. summary(model, input_size=(batch_size, 3, 32, 32), device=device).
  2. (also works) Move the model back to the device (model = model.to(device)) after the call to torchinfo.summary(). Both are shown in the sketch below.
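
For reference, a minimal sketch of both workarounds in context (batch_size and the rest of the training loop are placeholders from the repro above):

import torch
import torch.nn as nn
from torchinfo import summary

device = torch.accelerator.current_accelerator()
model = nn.Sequential(nn.Flatten(), nn.Linear(3072, 10)).to(device)

# Workaround 1 (preferred): tell torchinfo which device to run its test forward pass on
summary(model, input_size=(batch_size, 3, 32, 32), device=device)

# Workaround 2: move the model back to the accelerator after summary() has run
summary(model, input_size=(batch_size, 3, 32, 32))
model = model.to(device)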

Cause of bug
In torchinfo.py, the function get_device() appears to recognise only CUDA as an
accelerator, whereas other platforms may have different accelerators; M-chip Macs,
for example, use "mps". As a result, when "device=" is not given in the call to
summary(), torchinfo moves the model to the CPU, which then causes a runtime error
during model training (or evaluation): the data is on the accelerator while the
model (or parts of it) is on the CPU.
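
A device-agnostic check could look roughly like the sketch below; this is a hypothetical illustration, not torchinfo's actual code, and it assumes a recent PyTorch where the torch.accelerator module is available:

import torch

def get_device() -> torch.device:
    # Hypothetical device-agnostic version: prefer whatever accelerator
    # PyTorch itself reports (cuda, mps, xpu, ...), if any.
    if hasattr(torch, "accelerator"):
        accel = torch.accelerator.current_accelerator()
        if accel is not None:
            return accel
    # Older PyTorch: fall back to the original CUDA-or-CPU logic
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")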

Bug fix
I have created a PR that should fix the bug for any accelerator recognised by PyTorch.
