@@ -36,7 +36,6 @@
 
 import tensorflow as tf
 
-
 # End-of-sentence marker (should correspond to the position of EOS in the
 # RESERVED_TOKENS list in text_encoder.py)
 EOS = 1
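
For context, the comment above pins EOS = 1 to the position of the EOS token in text_encoder.RESERVED_TOKENS. A hedged sanity check, assuming that list begins ["<pad>", "<EOS>"] as the comment implies:

    # Illustrative guard against drift; the RESERVED_TOKENS layout is an
    # assumption, not shown in this diff.
    assert text_encoder.RESERVED_TOKENS.index("<EOS>") == EOS
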
@@ -59,9 +58,10 @@ def _original_vocab(tmp_dir):
   vocab_filepath = os.path.join(tmp_dir, vocab_filename)
   if not os.path.exists(vocab_filepath):
     generator_utils.maybe_download(tmp_dir, vocab_filename, vocab_url)
-  return set(
-      [text_encoder.native_to_unicode(l.strip()) for l in
-       tf.gfile.Open(vocab_filepath)])
+  return set([
+      text_encoder.native_to_unicode(l.strip())
+      for l in tf.gfile.Open(vocab_filepath)
+  ])
 
 
 def _replace_oov(original_vocab, line):
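
_replace_oov itself is unchanged and its body sits outside this hunk. As a reading aid, a minimal sketch of what it presumably does, assuming whitespace-tokenized lines and a literal "UNK" marker:

    def _replace_oov(original_vocab, line):
      # Map any token missing from the reference vocab to "UNK", so the
      # subword vocab is built over the benchmark's closed vocabulary.
      return u" ".join([word if word in original_vocab else u"UNK"
                        for word in line.split()])
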
@@ -81,19 +81,19 @@ def _replace_oov(original_vocab, line):
 
 
 def _train_data_filenames(tmp_dir):
-  return [os.path.join(
-      tmp_dir,
-      "1-billion-word-language-modeling-benchmark-r13output",
-      "training-monolingual.tokenized.shuffled",
-      "news.en-%05d-of-00100" % i) for i in xrange(1, 100)]
+  return [
+      os.path.join(tmp_dir,
+                   "1-billion-word-language-modeling-benchmark-r13output",
+                   "training-monolingual.tokenized.shuffled",
+                   "news.en-%05d-of-00100" % i) for i in xrange(1, 100)
+  ]
 
 
 def _dev_data_filename(tmp_dir):
-  return os.path.join(
-      tmp_dir,
-      "1-billion-word-language-modeling-benchmark-r13output",
-      "heldout-monolingual.tokenized.shuffled",
-      "news.en.heldout-00000-of-00050")
+  return os.path.join(tmp_dir,
+                      "1-billion-word-language-modeling-benchmark-r13output",
+                      "heldout-monolingual.tokenized.shuffled",
+                      "news.en.heldout-00000-of-00050")
 
 
 def _maybe_download_corpus(tmp_dir):
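
A quick check on the shard helper above: the LM1B release ships its tokenized training data as 99 files, news.en-00001-of-00100 through news.en-00099-of-00100 (there is no shard 00000), which is why the comprehension runs over xrange(1, 100). Illustrative only, with a placeholder tmp_dir:

    paths = _train_data_filenames("/tmp/lm1b")
    assert len(paths) == 99
    assert paths[0].endswith("news.en-00001-of-00100")
    assert paths[-1].endswith("news.en-00099-of-00100")
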
@@ -112,17 +112,18 @@ def _maybe_download_corpus(tmp_dir):
       corpus_tar.extractall(tmp_dir)
 
 
-def _get_or_build_subword_text_encoder(tmp_dir):
+def _get_or_build_subword_text_encoder(tmp_dir, vocab_filepath):
   """Builds a SubwordTextEncoder based on the corpus.
 
   Args:
     tmp_dir: directory containing dataset.
+    vocab_filepath: path to store (or load) vocab.
+
   Returns:
     a SubwordTextEncoder.
   """
-  filepath = os.path.join(tmp_dir, "lm1b_32k.subword_text_encoder")
-  if tf.gfile.Exists(filepath):
-    return text_encoder.SubwordTextEncoder(filepath)
+  if tf.gfile.Exists(vocab_filepath):
+    return text_encoder.SubwordTextEncoder(vocab_filepath)
   _maybe_download_corpus(tmp_dir)
   original_vocab = _original_vocab(tmp_dir)
   token_counts = defaultdict(int)
@@ -138,7 +139,7 @@ def _get_or_build_subword_text_encoder(tmp_dir):
       break
   ret = text_encoder.SubwordTextEncoder()
   ret.build_from_token_counts(token_counts, min_count=5)
-  ret.store_to_file(filepath)
+  ret.store_to_file(vocab_filepath)
   return ret
 
 
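
Taken together, the two hunks above turn the function into a standard get-or-build cache keyed on the caller-supplied vocab_filepath: load the stored encoder if the file exists, otherwise count tokens, build, and persist. A hedged usage sketch (the filename is illustrative, reusing the old default):

    vocab_filepath = os.path.join(data_dir, "lm1b_32k.subword_text_encoder")
    encoder = _get_or_build_subword_text_encoder(tmp_dir, vocab_filepath)
    # A second call is a cheap load: tf.gfile.Exists(vocab_filepath) now
    # holds, so the token-counting and building path is skipped.
    encoder = _get_or_build_subword_text_encoder(tmp_dir, vocab_filepath)
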
@@ -152,7 +153,7 @@ def is_character_level(self):
 
   @property
   def has_inputs(self):
-    return True
+    return False
 
   @property
   def input_space_id(self):
@@ -184,25 +185,26 @@ def targeted_vocab_size(self):
   def use_train_shards_for_dev(self):
     return True
 
-  def generator(self, tmp_dir, train, characters=False):
+  def generator(self, data_dir, tmp_dir, is_training):
     """Generator for lm1b sentences.
 
     Args:
-      tmp_dir: a string.
-      train: a boolean.
-      characters: a boolean
+      data_dir: data dir.
+      tmp_dir: tmp dir.
+      is_training: a boolean.
 
     Yields:
       A dictionary {"inputs": [0], "targets": [<subword ids>]}
     """
     _maybe_download_corpus(tmp_dir)
     original_vocab = _original_vocab(tmp_dir)
-    files = (_train_data_filenames(tmp_dir) if train
-             else [_dev_data_filename(tmp_dir)])
-    if characters:
+    files = (_train_data_filenames(tmp_dir)
+             if is_training else [_dev_data_filename(tmp_dir)])
+    if self.is_character_level:
       encoder = text_encoder.ByteTextEncoder()
     else:
-      encoder = _get_or_build_subword_text_encoder(tmp_dir)
+      vocab_filepath = os.path.join(data_dir, self.vocab_file)
+      encoder = _get_or_build_subword_text_encoder(tmp_dir, vocab_filepath)
     for filepath in files:
       tf.logging.info("filepath = %s", filepath)
       for line in tf.gfile.Open(filepath):
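
The per-line loop body falls outside this hunk. Presumably it encodes the OOV-replaced line and yields the dummy-inputs dictionary promised by the docstring; a sketch under that assumption (the placeholder [0] input matches has_inputs now returning False):

    tokens = encoder.encode(_replace_oov(original_vocab, line)) + [EOS]
    yield {"inputs": [0], "targets": tokens}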