From 1874a33b1151bacafbc611acffa2d6dce3245c15 Mon Sep 17 00:00:00 2001
From: Elijah Rippeth
Date: Tue, 10 Nov 2020 19:25:11 -0500
Subject: [PATCH] make Encoder scriptable.

---
 source/embed.py | 46 +++++++++++++++++++++-------------------------
 1 file changed, 21 insertions(+), 25 deletions(-)

diff --git a/source/embed.py b/source/embed.py
index c775ffce..790651cb 100644
--- a/source/embed.py
+++ b/source/embed.py
@@ -50,17 +50,10 @@ def buffered_read(fp, buffer_size):
     if len(buffer) > 0:
         yield buffer
 
 
-def buffered_arange(max):
-    if not hasattr(buffered_arange, 'buf'):
-        buffered_arange.buf = torch.LongTensor()
-    if max > buffered_arange.buf.numel():
-        torch.arange(max, out=buffered_arange.buf)
-    return buffered_arange.buf[:max]
-
-
 # TODO Do proper padding from the beginning
-def convert_padding_direction(src_tokens, padding_idx, right_to_left=False, left_to_right=False):
+@torch.jit.script
+def convert_padding_direction(src_tokens, padding_idx: int, right_to_left: bool = False, left_to_right: bool = False):
     assert right_to_left ^ left_to_right
     pad_mask = src_tokens.eq(padding_idx)
     if not pad_mask.any():
@@ -73,7 +66,7 @@ def convert_padding_direction(src_tokens, padding_idx, right_to_left=False, left_to_right=False):
         # already left padded
         return src_tokens
     max_len = src_tokens.size(1)
-    range = buffered_arange(max_len).type_as(src_tokens).expand_as(src_tokens)
+    range = torch.arange(max_len).type_as(src_tokens).expand_as(src_tokens)
     num_pads = pad_mask.long().sum(dim=1, keepdim=True)
     if right_to_left:
         index = torch.remainder(range - num_pads, max_len)
@@ -193,6 +186,13 @@ def __init__(
         if bidirectional:
             self.output_units *= 2
 
+    def combine_bidir(self, outs, bsz: int):
+        return torch.cat([
+            torch.cat([outs[2 * i], outs[2 * i + 1]], dim=0).view(1, bsz, self.output_units)
+            for i in range(self.num_layers)
+        ], dim=0)
+
+
     def forward(self, src_tokens, src_lengths):
         if self.left_pad:
             # convert left-padding to right-padding
@@ -211,15 +211,15 @@ def forward(self, src_tokens, src_lengths):
         x = x.transpose(0, 1)
 
         # pack embedded source tokens into a PackedSequence
-        packed_x = nn.utils.rnn.pack_padded_sequence(x, src_lengths.data.tolist())
+        packed_x = nn.utils.rnn.pack_padded_sequence(x, src_lengths)
 
         # apply LSTM
         if self.bidirectional:
-            state_size = 2 * self.num_layers, bsz, self.hidden_size
+            state_size = [2 * self.num_layers, bsz, self.hidden_size]
         else:
-            state_size = self.num_layers, bsz, self.hidden_size
-        h0 = x.data.new(*state_size).zero_()
-        c0 = x.data.new(*state_size).zero_()
+            state_size = [self.num_layers, bsz, self.hidden_size]
+        h0 = x.new_zeros(state_size)
+        c0 = x.new_zeros(state_size)
         packed_outs, (final_hiddens, final_cells) = self.lstm(packed_x, (h0, c0))
 
         # unpack outputs and apply dropout
@@ -227,14 +227,8 @@ def forward(self, src_tokens, src_lengths):
         assert list(x.size()) == [seqlen, bsz, self.output_units]
 
         if self.bidirectional:
-            def combine_bidir(outs):
-                return torch.cat([
-                    torch.cat([outs[2 * i], outs[2 * i + 1]], dim=0).view(1, bsz, self.output_units)
-                    for i in range(self.num_layers)
-                ], dim=0)
-
-            final_hiddens = combine_bidir(final_hiddens)
-            final_cells = combine_bidir(final_cells)
+            final_hiddens = self.combine_bidir(final_hiddens, bsz)
+            final_cells = self.combine_bidir(final_cells, bsz)
 
         encoder_padding_mask = src_tokens.eq(self.padding_idx).t()
 
@@ -248,8 +242,10 @@ def combine_bidir(outs):
 
         return {
             'sentemb': sentemb,
-            'encoder_out': (x, final_hiddens, final_cells),
-            'encoder_padding_mask': encoder_padding_mask if encoder_padding_mask.any() else None
+            'encoder_out': x,
+            'final_hiddens': final_hiddens,
+            'final_cells': final_cells,
+            'encoder_padding_mask': encoder_padding_mask
         }
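
Note: with the nested combine_bidir hoisted to a method, the padding helper
compiled via @torch.jit.script, and forward() returning a flat
Dict[str, Tensor] (TorchScript dictionaries need a single value type, so the
old tuple and Optional values had to go), the Encoder should compile under
torch.jit.script. A minimal smoke test, offered only as a sketch: the
constructor arguments below are hypothetical placeholders, so check the real
signature in source/embed.py before running.

    import torch

    from embed import Encoder  # assumes source/ is on sys.path

    # hypothetical hyperparameters, for illustration only
    encoder = Encoder(num_embeddings=1000, padding_idx=1, embed_dim=320,
                      hidden_size=512, num_layers=1, bidirectional=True)
    scripted = torch.jit.script(encoder)  # compiles forward() and combine_bidir()

    tokens = torch.randint(2, 1000, (4, 7))          # batch of 4 sequences, length 7
    lengths = torch.full((4,), 7, dtype=torch.long)  # equal lengths, so no padding
    out = scripted(tokens, lengths)
    assert out['sentemb'].shape == (4, 1024)         # bidirectional doubles output_units

    scripted.save('encoder.pt')  # self-contained archive, loadable without the Python source

Saving through torch.jit is the payoff of the change: the serialized encoder
can be loaded with torch.jit.load from C++ or Python with no dependency on
this file.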