From 8bb54df7168ac83a789841276920999337169c3f Mon Sep 17 00:00:00 2001 From: Jiajun Bao Date: Mon, 21 Nov 2022 20:16:56 -0500 Subject: [PATCH 1/3] add padded tokens in token counts to compare to max_token in embed.py --- source/embed.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/source/embed.py b/source/embed.py index 62bdc334..977480be 100644 --- a/source/embed.py +++ b/source/embed.py @@ -158,18 +158,22 @@ def batch(tokens, lengths, indices): batch_tokens, batch_lengths, batch_indices = [], [], [] ntokens = nsentences = 0 + num_tokens_padded = -1 for i in indices: + if num_tokens_padded == -1: + num_tokens_padded = tokens[i].shape[0] if nsentences > 0 and ( - (self.max_tokens is not None and ntokens + lengths[i] > self.max_tokens) + (self.max_tokens is not None and ntokens + num_tokens_padded > self.max_tokens) or (self.max_sentences is not None and nsentences == self.max_sentences) ): yield batch(batch_tokens, batch_lengths, batch_indices) ntokens = nsentences = 0 batch_tokens, batch_lengths, batch_indices = [], [], [] + num_tokens_padded = -1 batch_tokens.append(tokens[i]) batch_lengths.append(lengths[i]) batch_indices.append(i) - ntokens += tokens[i].shape[0] + ntokens += num_tokens_padded nsentences += 1 if nsentences > 0: yield batch(batch_tokens, batch_lengths, batch_indices) From e14ebdaf47da708928309d0b9cb7edaa65e215d8 Mon Sep 17 00:00:00 2001 From: Jiajun Bao Date: Mon, 21 Nov 2022 21:53:02 -0500 Subject: [PATCH 2/3] fixed default value to 0 --- source/embed.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/source/embed.py b/source/embed.py index 977480be..8c29d187 100644 --- a/source/embed.py +++ b/source/embed.py @@ -158,9 +158,9 @@ def batch(tokens, lengths, indices): batch_tokens, batch_lengths, batch_indices = [], [], [] ntokens = nsentences = 0 - num_tokens_padded = -1 + num_tokens_padded = 0 for i in indices: - if num_tokens_padded == -1: + if num_tokens_padded == 0: num_tokens_padded = tokens[i].shape[0] if nsentences > 0 and ( (self.max_tokens is not None and ntokens + num_tokens_padded > self.max_tokens) @@ -169,7 +169,7 @@ def batch(tokens, lengths, indices): yield batch(batch_tokens, batch_lengths, batch_indices) ntokens = nsentences = 0 batch_tokens, batch_lengths, batch_indices = [], [], [] - num_tokens_padded = -1 + num_tokens_padded = 0 batch_tokens.append(tokens[i]) batch_lengths.append(lengths[i]) batch_indices.append(i) From ba5d8fcee9a34e26c7195a93469c62bc5904dd72 Mon Sep 17 00:00:00 2001 From: Jiajun Bao Date: Mon, 21 Nov 2022 22:25:38 -0500 Subject: [PATCH 3/3] fixed batching --- source/embed.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/source/embed.py b/source/embed.py index 8c29d187..5f288031 100644 --- a/source/embed.py +++ b/source/embed.py @@ -157,24 +157,22 @@ def batch(tokens, lengths, indices): ) batch_tokens, batch_lengths, batch_indices = [], [], [] - ntokens = nsentences = 0 - num_tokens_padded = 0 + nsentences = 0 + num_tokens_padded = -1 for i in indices: - if num_tokens_padded == 0: + if num_tokens_padded == -1: num_tokens_padded = tokens[i].shape[0] - if nsentences > 0 and ( - (self.max_tokens is not None and ntokens + num_tokens_padded > self.max_tokens) - or (self.max_sentences is not None and nsentences == self.max_sentences) - ): - yield batch(batch_tokens, batch_lengths, batch_indices) - ntokens = nsentences = 0 - batch_tokens, batch_lengths, batch_indices = [], [], [] - num_tokens_padded = 0 batch_tokens.append(tokens[i]) batch_lengths.append(lengths[i]) batch_indices.append(i) - ntokens += num_tokens_padded nsentences += 1 + if ((self.max_tokens is not None and (nsentences + 1) * num_tokens_padded > self.max_tokens) + or (self.max_sentences is not None and nsentences + 1 == self.max_sentences)): + yield batch(batch_tokens, batch_lengths, batch_indices) + nsentences = 0 + batch_tokens, batch_lengths, batch_indices = [], [], [] + num_tokens_padded = -1 + if nsentences > 0: yield batch(batch_tokens, batch_lengths, batch_indices)