-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathEmbedding.py
More file actions
65 lines (55 loc) · 2.54 KB
/
Embedding.py
File metadata and controls
65 lines (55 loc) · 2.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch.nn as nn
import torch
import math
class PositionalEncoding(nn.Module):
    r"""Add fixed sinusoidal positional encodings to a sequence of embeddings.

    Injects information about the absolute position of each token in the
    sequence. The encodings have the same dimension as the embeddings, so the
    two can be summed. Sine and cosine functions of different frequencies are
    used:

    .. math::
        \text{PosEncoder}(pos, 2i) = \sin\left(pos / 10000^{2i / d_{model}}\right)

        \text{PosEncoder}(pos, 2i+1) = \cos\left(pos / 10000^{2i / d_{model}}\right)

    where :math:`pos` is the token position and :math:`i` is the index of the
    sine/cosine pair within the embedding.

    Args:
        d_model: the embedding dimension (required). Odd values are supported;
            the last column then holds a sine with no cosine partner.
        dropout: dropout probability applied after adding the encodings
            (default=0.1).
        max_len: the maximum length of the incoming sequence (default=5000).

    Example:
        >>> pos_encoder = PositionalEncoding(d_model=512)
        >>> pos_encoder(torch.zeros(10, 2, 512)).shape
        torch.Size([10, 2, 512])
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
        # div_term[i] = 10000^(-2i / d_model), computed via exp/log.
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model)
        )  # [ceil(d_model / 2)]
        angles = position * div_term  # [max_len, ceil(d_model / 2)]
        pe = torch.zeros(max_len, d_model)  # [max_len, d_model]
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even column,
        # so the cosine half uses only the first d_model // 2 frequencies.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Insert a batch axis so the buffer broadcasts over [x_len, batch_size, d_model].
        self.register_buffer('pe', pe.unsqueeze(1))  # [max_len, 1, d_model]

    def forward(self, x):
        """Add positional encodings to ``x`` and apply dropout.

        :param x: [x_len, batch_size, emb_size]; x_len must not exceed max_len.
        :return: [x_len, batch_size, emb_size]
        """
        x = x + self.pe[: x.size(0)]  # [x_len, 1, d_model] broadcasts over batch
        return self.dropout(x)
class TokenEmbedding(nn.Module):
    """Look up token embeddings and scale them by ``sqrt(emb_size)``.

    Args:
        vocab_size: number of rows in the embedding table.
        emb_size: dimensionality of each embedding vector.
    """

    def __init__(self, vocab_size: int, emb_size: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens):
        """Embed ``tokens`` and multiply the result by ``sqrt(emb_size)``.

        :param tokens: token ids, shape [len, batch_size]; cast with
            ``.long()`` before the lookup.
        :return: shape [len, batch_size, emb_size]
        """
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)
if __name__ == '__main__':
    # Two token sequences of length 5, one per row: [batch_size, src_len].
    x = torch.tensor([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]], dtype=torch.long)
    # Transpose (not reshape) to [src_len, batch_size]: reshape(5, 2) keeps
    # row-major order and would interleave tokens from the two sequences.
    x = x.t()  # [src_len, batch_size]
    token_embedding = TokenEmbedding(vocab_size=11, emb_size=512)
    x = token_embedding(tokens=x)  # [src_len, batch_size, emb_size]
    pos_embedding = PositionalEncoding(d_model=512)
    x = pos_embedding(x=x)  # [src_len, batch_size, emb_size]
    print(x.shape)