@@ -39,10 +39,16 @@ class TransformerFeedables(NamedTuple(
3939 ("input_mask" , tf .Tensor )])):
4040 """Additional feedables used only by the Transformer-based decoder.
4141
 42+ Follows the shape pattern of having a batch-sized first dimension,
43+ shape(batch_size, ...)
44+
4245 Attributes:
4346 input_sequence: The whole input sequence (embedded) that is fed into
4447 the decoder in each decoding step.
45- input_mask: Mask for masking finished sequences.
48+ shape(batch, len, emb)
49+ input_mask: Mask for masking finished sequences. The last dimension
50+ is required for compatibility with the beam_search_decoder.
51+ shape(batch, len, 1)
4652 """
4753
4854
@@ -392,14 +398,14 @@ def train_loop_result(self) -> LoopState:
392398 decoder_ls = AutoregressiveDecoder .get_initial_loop_state (self )
393399
394400 input_sequence = self .embed_input_symbols (self .train_input_symbols )
401+ input_mask = tf .transpose (self .train_mask )
402+
395403 last_layer = self .layer (
396- self .depth , input_sequence , tf . transpose ( self . train_mask ) )
404+ self .depth , input_sequence , input_mask )
397405
398- # We transpose input sequence and mask only to convey to
399- # the defined shapes
400406 tr_feedables = TransformerFeedables (
401- input_sequence = tf . transpose ( input_sequence ) ,
402- input_mask = self . train_mask )
407+ input_sequence = input_sequence ,
408+ input_mask = tf . expand_dims ( input_mask , - 1 ) )
403409
404410 # t_states shape: (batch, time, channels)
405411 # dec_w shape: (channels, vocab)
@@ -453,11 +459,11 @@ def get_initial_loop_state(self) -> LoopState:
453459
454460 tr_feedables = TransformerFeedables (
455461 input_sequence = tf .zeros (
456- shape = [0 , self .batch_size , self .dimension ],
462+ shape = [self .batch_size , 0 , self .dimension ],
457463 dtype = tf .float32 ,
458464 name = "input_sequence" ),
459465 input_mask = tf .zeros (
460- shape = [0 , self .batch_size ],
466+ shape = [self .batch_size , 0 , 1 ],
461467 dtype = tf .float32 ,
462468 name = "input_mask" ))
463469
@@ -486,16 +492,16 @@ def next_state(self, loop_state: LoopState) -> Tuple[tf.Tensor, Any, Any]:
486492 with tf .variable_scope (self ._variable_scope , reuse = tf .AUTO_REUSE ):
 487493 # shape (batch, len, emb) after this change
488494 input_sequence = append_tensor (
489- tr_feedables .input_sequence , feedables .embedded_input )
495+ tr_feedables .input_sequence , feedables .embedded_input , 1 )
490496
491497 unfinished_mask = tf .to_float (tf .logical_not (feedables .finished ))
492498 input_mask = append_tensor (
493- tr_feedables .input_mask , unfinished_mask )
499+ tr_feedables .input_mask ,
500+ tf .expand_dims (unfinished_mask , - 1 ),
501+ axis = 1 )
494502
495503 last_layer = self .layer (
496- self .depth ,
497- tf .transpose (input_sequence , [1 , 0 , 2 ]),
498- tf .transpose (input_mask ))
504+ self .depth , input_sequence , tf .squeeze (input_mask , - 1 ))
499505
500506 # (batch, state_size)
501507 output_state = last_layer .temporal_states [:, - 1 , :]
0 commit comments