Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 37465a1

Browse files
authored
Merge pull request #296 from rsepassi/push
Added an assert in common_attention.
2 parents 60cebfb + ff231f0 commit 37465a1

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

tensor2tensor/layers/common_attention.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1080,6 +1080,7 @@ def multihead_attention_2d(query_antecedent,
10801080
x = local_attention_2d(
10811081
q, k, v, query_shape=query_shape, memory_flange=memory_flange)
10821082
else:
1083+
assert attention_type == "masked_local_attention_2d"
10831084
x = masked_local_attention_2d(q, k, v, query_shape=query_shape,
10841085
memory_flange=memory_flange)
10851086
x = combine_heads_2d(x)

0 commit comments

Comments
 (0)