Description
Is your feature request related to a problem? Please describe.
Weight sharing is a common technique in modern LLMs for reducing the number of parameters. For example, in transformers models you can set tie_word_embeddings so that the embedding layers of the encoder and decoder and the lm_head layer share the same weights.
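In plain PyTorch, tying conceptually means two modules hold the very same nn.Parameter object. The toy sketch below uses made-up sizes and a hypothetical make_lm helper; it is not how transformers builds MBart. It shows that the tied weight is one object and that Module.parameters() yields it only once.

```python
import torch.nn as nn

vocab_size, d_model = 10, 4  # toy sizes, chosen only for illustration


def make_lm(tie: bool) -> nn.Module:
    """Hypothetical toy LM: token embedding followed by an output projection."""
    embed = nn.Embedding(vocab_size, d_model)
    # Linear(d_model, vocab_size).weight has shape (vocab_size, d_model), same as embed.weight
    lm_head = nn.Linear(d_model, vocab_size, bias=False)
    if tie:
        # Weight tying: lm_head reuses the embedding's Parameter object
        lm_head.weight = embed.weight
    return nn.Sequential(embed, lm_head)


tied, untied = make_lm(tie=True), make_lm(tie=False)
print(tied[1].weight is tied[0].weight)  # True
# Module.parameters() yields each Parameter object only once
print(sum(p.numel() for p in tied.parameters()))    # 40
print(sum(p.numel() for p in untied.parameters()))  # 80
```

A per-layer report that looks at each module in isolation still sees a 40-element weight in both layers of the tied model.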
Currently, torchinfo.summary cannot tell whether parameters are shared, so a tied weight is counted once for every layer that uses it. To reproduce:
```python
import torch
import torchinfo
from transformers.models.mbart.modeling_mbart import MBartConfig, MBartForCausalLM

DECODER_CONFIG = MBartConfig(
    vocab_size=50000,
    max_position_embeddings=1024,
    d_model=384,
    decoder_layers=2,
    decoder_attention_heads=16,
    decoder_ffn_dim=1536,
    decoder_start_token_id=0,
    layer_norm_eps=1e-05,
    is_decoder=True,
    scale_embedding=True,
    tie_word_embeddings=True,  # switch to False to compare
)
decoder = MBartForCausalLM(DECODER_CONFIG)

input_data = {
    "input_ids": torch.randint(0, 50000, (16, 256), dtype=torch.int32),
    "attention_mask": torch.ones(16, 256, dtype=torch.int32),
    "encoder_hidden_states": torch.randn(16, 144, 384),
}
torchinfo.summary(decoder, input_data=input_data)
```

This produces the same output whether tie_word_embeddings is True or False:
```
=========================================================================================================
Layer (type:depth-idx)                                  Output Shape              Param #
=========================================================================================================
MBartForCausalLM                                        [16, 16, 256, 24]         --
├─MBartDecoderWrapper: 1-1                              --                        --
│    └─MBartDecoder: 2-1                                [16, 16, 256, 24]         --
│    │    └─Embedding: 3-1                              [16, 256, 384]            19,200,000
│    │    └─MBartLearnedPositionalEmbedding: 3-2        [16, 256, 384]            393,984
│    │    └─LayerNorm: 3-3                              [16, 256, 384]            768
│    │    └─ModuleList: 3-4                             --                        4,733,184
│    │    └─LayerNorm: 3-5                              [16, 256, 384]            768
├─Linear: 1-2                                           [16, 256, 50000]          19,200,000
=========================================================================================================
Total params: 43,528,704
Trainable params: 43,528,704
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 696.46
=========================================================================================================
Input size (MB): 3.57
Forward/backward pass size (MB): 2069.36
Params size (MB): 174.11
Estimated Total Size (MB): 2247.05
=========================================================================================================
```
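Adding up the per-layer counts from this output shows how large the overcount is. The sketch below is plain arithmetic on the values copied from the summary. It assumes that, with tie_word_embeddings=True, the Linear (lm_head) weight is the same tensor as the Embedding weight, and that parameters are float32 (4 bytes each), which matches the reported 174.11 MB params size.

```python
# Param counts copied from the torchinfo summary above
layer_params = {
    "Embedding: 3-1": 19_200_000,
    "MBartLearnedPositionalEmbedding: 3-2": 393_984,
    "LayerNorm: 3-3": 768,
    "ModuleList: 3-4": 4_733_184,
    "LayerNorm: 3-5": 768,
    "Linear: 1-2": 19_200_000,
}
reported_total = sum(layer_params.values())

# Assumption: when tied, Linear: 1-2 holds the same weight as Embedding: 3-1
shared = layer_params["Linear: 1-2"]
unique_total = reported_total - shared
bytes_per_param = 4  # float32

print(f"{reported_total:,}")                               # 43,528,704
print(f"{unique_total:,}")                                 # 24,328,704
print(f"{unique_total * bytes_per_param / 1e6:.2f} MB")    # 97.31 MB
```

Under that assumption, the tied model holds about 24.3M distinct parameters, not the 43.5M reported.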
See also #358, #322, and #303; this feature would resolve those issues as well.
Describe the solution you'd like
To estimate the number of parameters more accurately, one option is to keep track of the ids of the model's parameters so that duplicates can be removed, and to report counts both with and without duplicates. We could introduce an argument such as no_duplication to enable this.
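A minimal sketch of the id-based deduplication, assuming plain PyTorch. count_params is a hypothetical helper, not part of torchinfo's API, and no_duplication is not modeled here. The helper walks every module and counts the parameters registered directly on it, which mirrors a per-layer view. It also keeps a set of id(p) so that each parameter object contributes to the unique total only once.

```python
import torch.nn as nn


def count_params(model: nn.Module) -> dict:
    """Hypothetical helper: parameter counts with and without duplicates."""
    with_duplicates = 0
    unique = 0
    seen_ids = set()
    for module in model.modules():
        # recurse=False: only parameters registered directly on this module
        for p in module.parameters(recurse=False):
            with_duplicates += p.numel()
            if id(p) not in seen_ids:
                seen_ids.add(id(p))
                unique += p.numel()
    return {"with_duplicates": with_duplicates, "unique": unique}


# Toy tied model: lm_head reuses the embedding's Parameter object
embed = nn.Embedding(10, 4)
lm_head = nn.Linear(4, 10, bias=False)
lm_head.weight = embed.weight
model = nn.Sequential(embed, lm_head)

print(count_params(model))  # {'with_duplicates': 80, 'unique': 40}
```

Keying on id(p) catches tying done by assigning the same Parameter object. Note that model.modules() already skips a submodule registered twice, so this sketch targets shared parameters rather than whole shared modules.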
Describe alternatives you've considered
No
Additional context
Version info
torchinfo==1.8.0
transformers==4.40.0