@@ -62,13 +62,15 @@ def num_shards(self):
6262 return 10
6363
6464 def generate_data (self , data_dir , _ , task_id = - 1 ):
65+
6566 def generator_eos (nbr_symbols , max_length , nbr_cases ):
6667 """Shift by NUM_RESERVED_IDS and append EOS token."""
6768 for case in self .generator (nbr_symbols , max_length , nbr_cases ):
6869 new_case = {}
6970 for feature in case :
70- new_case [feature ] = [i + text_encoder .NUM_RESERVED_TOKENS
71- for i in case [feature ]] + [text_encoder .EOS_ID ]
71+ new_case [feature ] = [
72+ i + text_encoder .NUM_RESERVED_TOKENS for i in case [feature ]
73+ ] + [text_encoder .EOS_ID ]
7274 yield new_case
7375
7476 utils .generate_dataset_and_shuffle (
@@ -154,10 +156,7 @@ def generator(self, nbr_symbols, max_length, nbr_cases):
154156 for _ in xrange (nbr_cases ):
155157 l = np .random .randint (max_length ) + 1
156158 inputs = [np .random .randint (nbr_symbols - shift ) for _ in xrange (l )]
157- yield {
158- "inputs" : inputs ,
159- "targets" : [i + shift for i in inputs ]
160- }
159+ yield {"inputs" : inputs , "targets" : [i + shift for i in inputs ]}
161160
162161 @property
163162 def dev_length (self ):
@@ -191,10 +190,7 @@ def generator(self, nbr_symbols, max_length, nbr_cases):
191190 for _ in xrange (nbr_cases ):
192191 l = np .random .randint (max_length ) + 1
193192 inputs = [np .random .randint (nbr_symbols ) for _ in xrange (l )]
194- yield {
195- "inputs" : inputs ,
196- "targets" : list (reversed (inputs ))
197- }
193+ yield {"inputs" : inputs , "targets" : list (reversed (inputs ))}
198194
199195
200196@registry .register_problem
@@ -272,10 +268,7 @@ def reverse_generator_nlplike(nbr_symbols,
272268 for _ in xrange (nbr_cases ):
273269 l = int (abs (np .random .normal (loc = max_length / 2 , scale = std_dev )) + 1 )
274270 inputs = zipf_random_sample (distr_map , l )
275- yield {
276- "inputs" : inputs ,
277- "targets" : list (reversed (inputs ))
278- }
271+ yield {"inputs" : inputs , "targets" : list (reversed (inputs ))}
279272
280273
281274@registry .register_problem
@@ -287,8 +280,8 @@ def num_symbols(self):
287280 return 8000
288281
289282 def generator (self , nbr_symbols , max_length , nbr_cases ):
290- return reverse_generator_nlplike (
291- nbr_symbols , max_length , nbr_cases , 10 , 1.300 )
283+ return reverse_generator_nlplike (nbr_symbols , max_length , nbr_cases , 10 ,
284+ 1.300 )
292285
293286 @property
294287 def train_length (self ):
@@ -308,8 +301,8 @@ def num_symbols(self):
308301 return 32000
309302
310303 def generator (self , nbr_symbols , max_length , nbr_cases ):
311- return reverse_generator_nlplike (
312- nbr_symbols , max_length , nbr_cases , 10 , 1.050 )
304+ return reverse_generator_nlplike (nbr_symbols , max_length , nbr_cases , 10 ,
305+ 1.050 )
313306
314307
315308def lower_endian_to_number (l , base ):
@@ -431,3 +424,28 @@ class AlgorithmicMultiplicationDecimal40(AlgorithmicMultiplicationBinary40):
431424 @property
432425 def num_symbols (self ):
433426 return 10
427+
428+
429+ @registry .register_problem
430+ class AlgorithmicReverseBinary40Test (AlgorithmicReverseBinary40 ):
431+ """Test Problem with tiny dataset."""
432+
433+ @property
434+ def train_length (self ):
435+ return 10
436+
437+ @property
438+ def dev_length (self ):
439+ return 10
440+
441+ @property
442+ def train_size (self ):
443+ return 1000
444+
445+ @property
446+ def dev_size (self ):
447+ return 100
448+
449+ @property
450+ def num_shards (self ):
451+ return 1
0 commit comments