Description
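import torch

tgt = torch.tensor([[1, 2, 3, 0, 0]])  # one target sequence; 0 is the padding id
pad = 0
# Padding mask: True for real tokens; this shape broadcasts across rows (query positions).
padding_mask = (tgt != pad).unsqueeze(-2)  # shape: [1, 1, 5]
# Causal mask: True on and below the diagonal, so position i sees positions <= i.
sub_mask = torch.triu(torch.ones((1, 5, 5)), diagonal=1).bool()
sub_mask = ~sub_mask
final_mask = padding_mask & sub_mask  # shape: [1, 5, 5]
print(final_mask[0].int())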
Assume the target sequence is [1, 2, 3, 0, 0], where 0 is the padding token. The output is:
tensor([[1, 0, 0, 0, 0],
        [1, 1, 0, 0, 0],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 0, 0]], dtype=torch.int32)
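Rows 3 and 4 belong to padding positions, yet they are still allowed to attend to the first three tokens. When the line is changed to padding_mask = (tgt != pad).unsqueeze(-1), which has shape [1, 5, 1], those rows are masked entirely and I obtain the correct output: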
tensor([[1, 0, 0, 0, 0],
        [1, 1, 0, 0, 0],
        [1, 1, 1, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0]], dtype=torch.int32)
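
To make the difference between the two variants explicit, here is a minimal pure-Python sketch of what the broadcasting does at the index level. The helper `build_mask` and its `pad_axis` parameter are hypothetical names introduced only for this illustration; they are not part of the repository or of PyTorch. With `unsqueeze(-2)` the padding check is applied per column (key position j), while with `unsqueeze(-1)` it is applied per row (query position i).

```python
def build_mask(tgt, pad=0, pad_axis="key"):
    """Hypothetical helper: combine a causal mask with a padding mask.

    pad_axis="key"   mirrors (tgt != pad).unsqueeze(-2): padded columns are masked.
    pad_axis="query" mirrors (tgt != pad).unsqueeze(-1): padded rows are masked.
    """
    n = len(tgt)
    not_pad = [tok != pad for tok in tgt]
    mask = []
    for i in range(n):          # i indexes the query position (row)
        row = []
        for j in range(n):      # j indexes the key position (column)
            causal = j <= i     # position i may only look at positions <= i
            keep = not_pad[j] if pad_axis == "key" else not_pad[i]
            row.append(int(causal and keep))
        mask.append(row)
    return mask


tgt = [1, 2, 3, 0, 0]
for axis in ("key", "query"):
    print(axis)
    for row in build_mask(tgt, pad_axis=axis):
        print(row)
```

Under these assumptions, the "key" variant should reproduce the first matrix above and the "query" variant the second one.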