-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtokenizer.py
More file actions
163 lines (130 loc) · 4.91 KB
/
tokenizer.py
File metadata and controls
163 lines (130 loc) · 4.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
from typing import Iterator, List
import torch
import torchtext
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
class Tokenizer:
    """Tokenizer class for tokenizing and detokenizing text data.

    Holds one spaCy tokenizer and one vocabulary per side (source and
    target), both built from the same parallel dataset.
    """

    special_symbols = ["<unk>", "<pad>", "<bos>", "<eos>"]
    # Positions of the special symbols in ``special_symbols``. The vocabularies
    # are built with ``special_first=True``, so these are also their vocab ids.
    UNK_IDX = special_symbols.index("<unk>")
    BOS_IDX = special_symbols.index("<bos>")
    EOS_IDX = special_symbols.index("<eos>")

    def __init__(
        self,
        src_model: str,
        tgt_model: str,
        vocab_size: int,
        dataset: "torchtext.datasets",
    ) -> None:
        """Constructor for Tokenizer class.

        :param src_model: the name of the spaCy model for the source language
        :type src_model: str
        :param tgt_model: the name of the spaCy model for the target language
        :type tgt_model: str
        :param vocab_size: the maximum number of tokens in each vocabulary,
            special symbols included (passed as ``max_tokens``)
        :type vocab_size: int
        :param dataset: the dataset to build the vocabularies from; each sample
            is indexed as ``(source_text, target_text)``
        :type dataset: torchtext.datasets
        :return: None
        """
        self.__src = get_tokenizer("spacy", language=src_model)
        self.__tgt = get_tokenizer("spacy", language=tgt_model)
        # NOTE(review): ``dataset`` is iterated twice (once per vocabulary).
        # A one-shot iterator would leave the target vocabulary containing
        # only the special symbols -- confirm callers pass a re-iterable.
        self.__vocab_src = build_vocab_from_iterator(
            self.__yield_tokens(dataset),
            min_freq=1,
            specials=self.special_symbols,
            special_first=True,
            max_tokens=vocab_size,
        )
        self.__vocab_tgt = build_vocab_from_iterator(
            self.__yield_tokens(dataset, src=False),
            min_freq=1,
            specials=self.special_symbols,
            special_first=True,
            max_tokens=vocab_size,
        )
        # Out-of-vocabulary tokens are mapped to <unk>.
        self.__vocab_src.set_default_index(self.UNK_IDX)
        self.__vocab_tgt.set_default_index(self.UNK_IDX)

    def __yield_tokens(
        self, data_iter: "torchtext.datasets", src: bool = True
    ) -> Iterator:
        """Function to yield tokens from a dataset.

        :param data_iter: the dataset to yield tokens from; element 0 of each
            sample is the source text, element 1 the target text
        :type data_iter: torchtext.datasets
        :param src: tokenize the source side if True, the target side
            otherwise, defaults to True
        :type src: bool, optional
        :yield: the token list of each sample
        :rtype: Iterator
        """
        column = 0 if src else 1
        for data_sample in data_iter:
            yield self.tokenize(data_sample[column], src)

    def tokenize(self, text: str, src: bool = True) -> List[str]:
        """Function to tokenize text data.

        :param text: the text data to tokenize
        :type text: str
        :param src: use the source tokenizer if True, the target tokenizer
            otherwise, defaults to True
        :type src: bool, optional
        :return: the tokens of the text
        :rtype: List[str]
        """
        return self.__src(text) if src else self.__tgt(text)

    def detokenize(self, tokens: List[str]) -> str:
        """Function to detokenize text data.

        :param tokens: the tokenized text data
        :type tokens: List[str]
        :return: the tokens joined by single spaces
        :rtype: str
        """
        return " ".join(tokens)

    def create_label(self, target: torch.Tensor) -> torch.Tensor:
        """Function to create label data for the training process.

        Drops the first element of ``target`` (the leading <bos> produced by
        ``string_to_vocab(..., src=False)``) and appends the <eos> id, i.e.
        the label is ``target`` shifted left by one position.

        :param target: the target text (encoded), assumed 1-D
        :type target: torch.Tensor
        :return: the label, with the same dtype and device as ``target``
        :rtype: torch.Tensor
        """
        # Build <eos> on target's device/dtype so torch.cat works on GPU too.
        eos = torch.tensor([self.EOS_IDX], dtype=target.dtype, device=target.device)
        return torch.cat((target[1:], eos))

    def string_to_vocab(self, text: str, src: bool = True) -> torch.Tensor:
        """Function to convert text data to vocabulary ids.

        Source text is encoded as ``<bos> tokens <eos>``. Target text is
        encoded as ``<bos> tokens`` (no <eos>; ``create_label`` adds it to
        the shifted label instead).

        :param text: the text data to convert
        :type text: str
        :param src: use the source tokenizer/vocabulary if True, the target
            ones otherwise, defaults to True
        :type src: bool, optional
        :return: a 1-D ``torch.long`` tensor of vocabulary ids
        :rtype: torch.Tensor
        """
        vocab = self.__vocab_src if src else self.__vocab_tgt
        ids = [self.BOS_IDX, *vocab(self.tokenize(text, src))]
        if src:
            ids.append(self.EOS_IDX)
        # Explicit dtype: for empty text torch.tensor([]) would be float32.
        return torch.tensor(ids, dtype=torch.long)

    def vocab_to_string(self, vocab: torch.Tensor, src: bool = True) -> str:
        """Function to convert vocabulary ids back to text data.

        Special symbols (e.g. <bos>, <eos>) are not stripped.

        :param vocab: the vocabulary ids to convert (1-D)
        :type vocab: torch.Tensor
        :param src: use the source vocabulary if True, the target one
            otherwise, defaults to True
        :type src: bool, optional
        :return: the space-joined tokens
        :rtype: str
        """
        lookup = self.__vocab_src if src else self.__vocab_tgt
        return self.detokenize(lookup.lookup_tokens(vocab.tolist()))