Conversation

@aditya0by0
Member

For inference, PyTorch Lightning automatically calls model.eval() and wraps the forward pass in torch.inference_mode().

When performing inference outside Lightning, using only model.eval() is not sufficient.

  • model.eval() does not disable gradient tracking. It only switches layers such as dropout and batchnorm into evaluation mode.
  • Autograd still records the forward pass unless gradient tracking is explicitly disabled.
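This is easy to verify: in the minimal sketch below (the `nn.Linear` model is just an illustration, not from the original discussion), the output of a model in eval mode still has `requires_grad=True`.

```python
import torch
import torch.nn as nn

# Hypothetical minimal model, used only to demonstrate the point.
model = nn.Linear(4, 2)
model.eval()  # switches dropout/batchnorm to eval mode; does NOT touch autograd

x = torch.randn(1, 4)
y = model(x)

# Autograd still recorded the forward pass:
print(y.requires_grad)  # True
```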

Disabling gradient tracking is important because it:

  1. Reduces memory usage by avoiding storing intermediate activations for backpropagation.
  2. Speeds up inference by skipping autograd bookkeeping.

For this purpose, torch.inference_mode() is recommended. It is newer and faster than torch.no_grad(), but also more restrictive: tensors created inside it can never be used in a computation tracked by autograd later, which makes it ideal for pure inference.
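Putting the two together, inference outside Lightning needs both calls. A minimal sketch (again using a placeholder `nn.Linear` model, not code from the original thread):

```python
import torch
import torch.nn as nn

# Hypothetical model standing in for a real one.
model = nn.Linear(4, 2)
model.eval()  # eval mode for dropout/batchnorm

x = torch.randn(1, 4)
with torch.inference_mode():  # disables autograd tracking and its bookkeeping
    y = model(x)

print(y.requires_grad)  # False: no graph was built for this forward pass
```

Outputs produced under inference_mode cannot later take part in a backward pass; if that flexibility is needed, torch.no_grad() is the safer choice.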

References:

@aditya0by0 aditya0by0 self-assigned this Nov 22, 2025
@aditya0by0
Member Author

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        return torch.relu(self.linear(x))

# Original model
model = MyModel()

# Compile for faster execution
model = torch.compile(model, mode="default")

x = torch.randn(32, 10)
y = model(x)  # first call triggers compilation; subsequent calls run faster
