diff --git a/source/embed.py b/source/embed.py index 62bdc334..5f288031 100644 --- a/source/embed.py +++ b/source/embed.py @@ -157,20 +157,22 @@ def batch(tokens, lengths, indices): ) batch_tokens, batch_lengths, batch_indices = [], [], [] - ntokens = nsentences = 0 + nsentences = 0 + num_tokens_padded = -1 for i in indices: - if nsentences > 0 and ( - (self.max_tokens is not None and ntokens + lengths[i] > self.max_tokens) - or (self.max_sentences is not None and nsentences == self.max_sentences) - ): - yield batch(batch_tokens, batch_lengths, batch_indices) - ntokens = nsentences = 0 - batch_tokens, batch_lengths, batch_indices = [], [], [] + if num_tokens_padded == -1: + num_tokens_padded = tokens[i].shape[0] batch_tokens.append(tokens[i]) batch_lengths.append(lengths[i]) batch_indices.append(i) - ntokens += tokens[i].shape[0] nsentences += 1 + if ((self.max_tokens is not None and (nsentences + 1) * num_tokens_padded > self.max_tokens) + or (self.max_sentences is not None and nsentences + 1 == self.max_sentences)): + yield batch(batch_tokens, batch_lengths, batch_indices) + nsentences = 0 + batch_tokens, batch_lengths, batch_indices = [], [], [] + num_tokens_padded = -1 + if nsentences > 0: yield batch(batch_tokens, batch_lengths, batch_indices)