Enhancing output with memory estimations for every layer #373

@take2rohit

Description

I’m often frustrated when analyzing large models where memory bottlenecks occur, especially during training. Torchview provides a clear view of parameter sizes and layer hierarchy, but it lacks estimated per-layer memory usage for the forward pass, backward pass, and activations, figures that are crucial when optimizing memory on multi-GPU setups or deciding where to apply checkpointing.

Please add an option to include estimated memory usage per layer in the tabular output and/or visualization. For each layer, it would be very helpful to show:

  • Forward pass memory (input × weight, if applicable)
  • Backward pass memory (including gradients and intermediates)
  • Activation memory (based on output size and dtype)
  • Total estimated memory per layer

This would allow developers to identify layers that contribute disproportionately to memory use — and help them decide where to apply checkpointing, tensor parallelism, or other memory-saving strategies.

Workarounds I’ve tried:

  • Manually estimating memory from tensor sizes and dtype assumptions, which is tedious and error-prone.
  • Using torch.profiler, which doesn’t attribute memory to layers in the same intuitive per-layer hierarchical format as Torchview.
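The manual workaround could be partly automated with forward hooks. A minimal sketch (the function name, leaf-module heuristic, and the decimal-MB convention are my assumptions, not an existing Torchview API):

```python
import torch
import torch.nn as nn

def estimate_activation_memory(model, input_size, dtype=torch.float32):
    """Rough per-layer activation memory in decimal MB, assuming `dtype` outputs."""
    bytes_per_elem = torch.tensor([], dtype=dtype).element_size()
    stats, hooks = [], []

    def hook(module, inputs, output):
        if isinstance(output, torch.Tensor):
            mb = output.numel() * bytes_per_elem / 1e6  # decimal MB
            stats.append((module.__class__.__name__, tuple(output.shape), mb))

    # Only hook leaf modules so parent containers aren't double-counted.
    for m in model.modules():
        if len(list(m.children())) == 0:
            hooks.append(m.register_forward_hook(hook))

    with torch.no_grad():
        model(torch.zeros(input_size))
    for h in hooks:
        h.remove()
    return stats

# Conv2d(64 -> 128, 3x3) on a [64, 64, 56, 56] input produces
# a [64, 128, 56, 56] output: 64*128*56*56*4 bytes ≈ 102.76 MB
layer = nn.Conv2d(64, 128, kernel_size=3, padding=1)
stats = estimate_activation_memory(layer, (64, 64, 56, 56))
```

This only covers activations; attributing backward-pass and intermediate memory per layer is exactly the part that is hard to do by hand, which is why built-in support would help.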
An example of the desired tabular output:

| Layer  | Output Shape      | Params | Activation (MB) | Forward (MB) | Backward (MB) | Total (MB) |
|--------|-------------------|--------|-----------------|--------------|---------------|------------|
| Conv2d | [64, 128, 56, 56] | 73.8K  | 102.76          | 205.52       | 308.28        | 616.56     |
| ReLU   | [64, 128, 56, 56] | 0      | 102.76          | negligible   | negligible    | 102.76     |
This would help memory debugging especially for:

  • Large attention models (e.g., Transformers, ViTs)
  • High-resolution image models
  • Multi-GPU strategies where memory must be estimated per rank

A toggle like show_memory=True in torchview.draw_graph() or torchview.summary() would be perfect. It could assume dtype=torch.float32 unless overridden.
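The dtype assumption matters: halving precision halves the activation estimate. A quick sketch of the arithmetic behind the table above (the helper name is hypothetical; MB here means 10^6 bytes, matching the table):

```python
import torch

def layer_activation_mb(output_shape, dtype=torch.float32):
    """Activation memory in decimal MB for a given output shape and dtype."""
    numel = 1
    for d in output_shape:
        numel *= d
    return numel * torch.tensor([], dtype=dtype).element_size() / 1e6

fp32 = layer_activation_mb((64, 128, 56, 56))                 # ≈ 102.76 MB
fp16 = layer_activation_mb((64, 128, 56, 56), torch.float16)  # ≈ 51.38 MB
```

Defaulting to float32 with a dtype override would keep the feature simple while still covering mixed-precision workflows.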
