From f771e625fa2eea1898ef0970cd7fb4882a8ef2d7 Mon Sep 17 00:00:00 2001 From: mingruimingrui Date: Sun, 5 Apr 2020 16:33:12 +0800 Subject: [PATCH] fix combine_bidir --- source/embed.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/source/embed.py b/source/embed.py index c775ffce..5f15018a 100644 --- a/source/embed.py +++ b/source/embed.py @@ -228,10 +228,13 @@ def forward(self, src_tokens, src_lengths): if self.bidirectional: def combine_bidir(outs): - return torch.cat([ - torch.cat([outs[2 * i], outs[2 * i + 1]], dim=0).view(1, bsz, self.output_units) - for i in range(self.num_layers) - ], dim=0) + # [num_layers * num_dir, bsz, hidden_size] + # -> [num_layers, num_dir, bsz, hidden_size] + # -> [num_layers, bsz, num_dir, hidden_size] + # -> [num_layers, bsz, num_dir * hidden_size] + outs = outs.reshape(self.num_layers, 2, bsz, self.hidden_size) + outs = outs.transpose(1, 2) + return outs.reshape(self.num_layers, bsz, self.output_units) final_hiddens = combine_bidir(final_hiddens) final_cells = combine_bidir(final_cells)