Description
Hi,
First off thanks for this great contribution!
There seems to be an issue with the handling of the encoder_outputs at the pooler level when passing output_all_encoded_layers=True.
examples/examples/benchmarks/bert/src/bert_layers.py
Lines 676 to 689 in daddaef
```python
encoder_outputs = self.encoder(
    embedding_output,
    attention_mask,
    output_all_encoded_layers=output_all_encoded_layers,
    subset_mask=subset_mask)
if masked_tokens_mask is None:
    sequence_output = encoder_outputs[-1]
    pooled_output = self.pooler(
        sequence_output) if self.pooler is not None else None
else:
    # TD [2022-03-01]: the indexing here is very tricky.
    attention_mask_bool = attention_mask.bool()
```
because when doing that, I'm getting:
```
File ~/.conda/envs/mimibert/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/PatientTrajectoryForecasting/utils/bert_layers_mosa.py:567, in BertPooler.forward(self, hidden_states, pool)
    561 def forward(self,
    562             hidden_states: torch.Tensor,
    563             pool: Optional[bool] = True) -> torch.Tensor:
    564     # We "pool" the model by simply taking the hidden state corresponding
    565     # to the first token.
    566     first_token_tensor = hidden_states[:, 0] if pool else hidden_states
--> 567     pooled_output = self.dense(first_token_tensor)
    568     pooled_output = self.activation(pooled_output)
    569     return pooled_output

File ~/.conda/envs/mimibert/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.conda/envs/mimibert/lib/python3.11/site-packages/torch/nn/modules/linear.py:114, in Linear.forward(self, input)
    113 def forward(self, input: Tensor) -> Tensor:
--> 114     return F.linear(input, self.weight, self.bias)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x54784 and 768x768)
```
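To make the shape mismatch concrete, here is a small NumPy sketch of what I believe is happening (the sizes are illustrative, chosen to match the numbers in the error; this only mimics the shapes, not the actual model):

```python
import numpy as np

# Hypothetical sizes for illustration: 54784 non-padded tokens (matching
# the number in the error above) and hidden size 768.
ntokens_unpad, hidden = 54784, 768

# With output_all_encoded_layers=True, encoder_outputs[-1] is still the
# *unpadded* 2-D tensor [ntokens_unpad, hidden] rather than the padded
# 3-D [batch, seqlen, hidden] the pooler expects.
hidden_states = np.zeros((ntokens_unpad, hidden), dtype=np.float32)

# BertPooler does hidden_states[:, 0]; on a 2-D tensor this slices the
# first *feature* column, yielding a 1-D vector of length ntokens_unpad.
first_token_tensor = hidden_states[:, 0]

# The dense layer then sees a single row of 54784 features against a
# 768x768 weight, reproducing the "(1x54784 and 768x768)" mismatch.
dense_weight = np.zeros((hidden, hidden), dtype=np.float32)
try:
    first_token_tensor @ dense_weight
    raised = False
except ValueError:
    raised = True
```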
I believe the issue is that the padding function is not applied to the hidden states before they are appended to the list at the BERT encoder level:
examples/examples/benchmarks/bert/src/bert_layers.py
Lines 511 to 530 in daddaef
```python
all_encoder_layers = []
if subset_mask is None:
    for layer_module in self.layer:
        hidden_states = layer_module(hidden_states,
                                     cu_seqlens,
                                     seqlen,
                                     None,
                                     indices,
                                     attn_mask=attention_mask,
                                     bias=alibi_attn_mask)
        if output_all_encoded_layers:
            all_encoder_layers.append(hidden_states)
    # Pad inputs and mask. It will insert back zero-padded tokens.
    # Assume ntokens is total number of tokens (padded and non-padded)
    # and ntokens_unpad is total number of non-padded tokens.
    # Then padding performs the following de-compression:
    #     hidden_states[ntokens_unpad, hidden] -> hidden_states[ntokens, hidden]
    hidden_states = bert_padding_module.pad_input(
        hidden_states, indices, batch, seqlen)
else:
```
(Edit: yes, the following works, though I haven't checked whether anything else depends on the unpadded entries.)
```python
all_encoder_layers.append(bert_padding_module.pad_input(
    hidden_states, indices, batch, seqlen))
```
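For context, here is a rough pure-NumPy sketch of the de-compression that pad_input performs, per the comment in the encoder snippet (pad_input_sketch, the indices layout, and the toy sizes are my own illustration, not the library's implementation):

```python
import numpy as np

def pad_input_sketch(hidden_states, indices, batch, seqlen):
    # Illustrative stand-in for bert_padding_module.pad_input: scatter the
    # [ntokens_unpad, hidden] rows back to their flat positions in the
    # padded layout, re-inserting zero rows for padded tokens, then reshape
    # to [batch, seqlen, hidden].
    hidden = hidden_states.shape[-1]
    out = np.zeros((batch * seqlen, hidden), dtype=hidden_states.dtype)
    out[indices] = hidden_states
    return out.reshape(batch, seqlen, hidden)

# Two sequences padded to seqlen=3, with lengths 3 and 1 -> 4 real tokens.
batch, seqlen, hidden = 2, 3, 4
indices = np.array([0, 1, 2, 3])  # flat positions of the non-padded tokens
unpadded = np.ones((4, hidden), dtype=np.float32)

padded = pad_input_sketch(unpadded, indices, batch, seqlen)
```

Appending the padded tensor keeps every entry of all_encoder_layers in the [batch, seqlen, hidden] shape the pooler's hidden_states[:, 0] indexing expects.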
The same fix should probably be applied in the branch where subset_mask is not None...
Thanks again for your contribution to the community!