3232class BasicFcRelu (t2t_model .T2TModel ):
3333
3434 def body (self , features ):
35- hparams = self ._hparams
35+ hparams = self .hparams
3636 x = features ["inputs" ]
3737 shape = common_layers .shape_list (x )
3838 x = tf .reshape (x , [- 1 , shape [1 ] * shape [2 ] * shape [3 ]])
@@ -53,7 +53,7 @@ def __init__(self, *args, **kwargs):
5353
5454 def bottleneck (self , x ):
5555 with tf .variable_scope ("bottleneck" ):
56- hparams = self ._hparams
56+ hparams = self .hparams
5757 x = tf .layers .dense (x , hparams .bottleneck_size , name = "bottleneck" )
5858 if hparams .mode == tf .estimator .ModeKeys .TRAIN :
5959 noise = 2.0 * tf .random_uniform (common_layers .shape_list (x )) - 1.0
@@ -68,12 +68,27 @@ def unbottleneck(self, x, res_size):
6868 def bottleneck_loss (self , b ):
6969 return 0.0
7070
71+ def make_even_size (self , x ):
72+ shape = [dim if dim is not None else - 1 for dim in x .get_shape ().as_list ()]
73+ if shape [1 ] % 2 == 0 and shape [2 ] % 2 == 0 :
74+ return x
75+ if shape [1 ] % 2 == 0 and self .is1d :
76+ return x
77+ x , _ = common_layers .pad_to_same_length (
78+ x , x , final_length_divisible_by = 2 , axis = 1 )
79+ if self .is1d :
80+ return x
81+ x , _ = common_layers .pad_to_same_length (
82+ x , x , final_length_divisible_by = 2 , axis = 2 )
83+ return x
84+
7185 def encoder (self , x ):
7286 with tf .variable_scope ("encoder" ):
73- hparams = self ._hparams
87+ hparams = self .hparams
7488 kernel , strides = self ._get_kernel_and_strides ()
7589 # Down-convolutions.
7690 for i in range (hparams .num_hidden_layers ):
91+ x = self .make_even_size (x )
7792 x = tf .layers .conv2d (
7893 x , hparams .hidden_size * 2 ** (i + 1 ), kernel , strides = strides ,
7994 padding = "SAME" , activation = common_layers .belu , name = "conv_%d" % i )
@@ -82,7 +97,7 @@ def encoder(self, x):
8297
8398 def decoder (self , x ):
8499 with tf .variable_scope ("decoder" ):
85- hparams = self ._hparams
100+ hparams = self .hparams
86101 kernel , strides = self ._get_kernel_and_strides ()
87102 # Up-convolutions.
88103 for i in range (hparams .num_hidden_layers ):
@@ -94,19 +109,13 @@ def decoder(self, x):
94109 return x
95110
96111 def body (self , features ):
97- hparams = self ._hparams
112+ hparams = self .hparams
98113 is_training = hparams .mode == tf .estimator .ModeKeys .TRAIN
99114 if hparams .mode != tf .estimator .ModeKeys .PREDICT :
100115 x = features ["targets" ]
101116 shape = common_layers .shape_list (x )
102117 is1d = shape [2 ] == 1
103118 self .is1d = is1d
104- x , _ = common_layers .pad_to_same_length (
105- x , x , final_length_divisible_by = 2 ** hparams .num_hidden_layers , axis = 1 )
106- if not is1d :
107- x , _ = common_layers .pad_to_same_length (
108- x , x , final_length_divisible_by = 2 ** hparams .num_hidden_layers ,
109- axis = 2 )
110119 # Run encoder.
111120 x = self .encoder (x )
112121 # Bottleneck (mix during early training, not too important but stable).
@@ -122,21 +131,21 @@ def body(self, features):
122131 x = b
123132 else :
124133 b = self .sample ()
125- res_size = self ._hparams .hidden_size * 2 ** self ._hparams .num_hidden_layers
134+ res_size = self .hparams .hidden_size * 2 ** self .hparams .num_hidden_layers
126135 res_size = min (res_size , hparams .max_hidden_size )
127136 x = self .unbottleneck (b , res_size )
128137 # Run decoder.
129138 x = self .decoder (x )
130139 if hparams .mode == tf .estimator .ModeKeys .PREDICT :
131- return x
140+ return x , { "bottleneck_loss" : 0.0 }
132141 # Cut to the right size and mix before returning.
133142 res = x [:, :shape [1 ], :shape [2 ], :]
134143 res = common_layers .mix (res , features ["targets" ],
135144 hparams .bottleneck_warmup_steps // 2 , is_training )
136145 return res , {"bottleneck_loss" : b_loss }
137146
138147 def sample (self ):
139- hp = self ._hparams
148+ hp = self .hparams
140149 div_x = 2 ** hp .num_hidden_layers
141150 div_y = 1 if self .is1d else 2 ** hp .num_hidden_layers
142151 size = [hp .batch_size , hp .sample_height // div_x , hp .sample_width // div_y ,
@@ -158,11 +167,11 @@ def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1,
158167 # Sample and decode.
159168 # TODO(lukaszkaiser): is this a universal enough way to get channels?
160169 try :
161- num_channels = self ._hparams .problem .num_channels
170+ num_channels = self .hparams .problem .num_channels
162171 except AttributeError :
163172 num_channels = 1
164173 features ["targets" ] = tf .zeros (
165- [self ._hparams .batch_size , 1 , 1 , num_channels ],
174+ [self .hparams .batch_size , 1 , 1 , num_channels ],
166175 dtype = tf .int32 )
167176 logits , _ = self (features ) # pylint: disable=not-callable
168177 samples = tf .argmax (logits , axis = - 1 )
@@ -175,7 +184,7 @@ def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1,
175184 return samples
176185
177186 def _get_kernel_and_strides (self ):
178- hparams = self ._hparams
187+ hparams = self .hparams
179188 kernel = (hparams .kernel_height , hparams .kernel_width )
180189 kernel = (hparams .kernel_height , 1 ) if self .is1d else kernel
181190 strides = (2 , 1 ) if self .is1d else (2 , 2 )
0 commit comments