
Commit ed7862c

Ashish Vaswani authored and Ryan Sepassi committed
1d Dilated masked and unmasked self-attention. Added spaces between tokens for logging during inference.
PiperOrigin-RevId: 170552095
1 parent 464f9ad commit ed7862c

File tree: 2 files changed, +294 −5 lines

tensor2tensor/layers/common_attention.py

Lines changed: 292 additions & 3 deletions
@@ -1090,6 +1090,280 @@ def pad_l_and_r(x, pad_length):
     return output
 
 
+def reshape_by_blocks(x, x_shape, memory_block_size):
+  x = tf.reshape(x, [
+      x_shape[0], x_shape[1], x_shape[2] // memory_block_size,
+      memory_block_size, x_shape[3]
+  ])
+  return x
+
+
+def dilated_self_attention_1d(q,
+                              k,
+                              v,
+                              query_block_size=128,
+                              memory_block_size=128,
+                              gap_size=2,
+                              num_memory_blocks=2,
+                              name=None):
+  """Dilated self-attention.
+
+  Args:
+    q: a Tensor with shape [batch, heads, length, depth_k]
+    k: a Tensor with shape [batch, heads, length, depth_k]
+    v: a Tensor with shape [batch, heads, length, depth_v]
+    query_block_size: an integer indicating size of query block
+    memory_block_size: an integer indicating the size of a memory block.
+    gap_size: an integer indicating the gap size
+    num_memory_blocks: how many memory blocks to look at to the left and right.
+      Each will be separated by gap_size.
+    name: an optional string
+
+  Returns:
+    a Tensor of shape [batch, heads, length, depth_v]
+  """
+  with tf.variable_scope(
+      name, default_name="dilated_self_attention_1d", values=[q, k, v]):
+    v_list_shape = v.get_shape().as_list()
+    v_shape = tf.shape(v)
+    depth_v = v_shape[3]
+    batch_size = v_shape[0]
+    num_heads = v_shape[1]
+    original_length = tf.shape(q)[2]
+
+    # making sure q is a multiple of query block size
+    def pad_to_multiple(x, pad_length):
+      x_length = tf.shape(x)[2]
+      return tf.pad(x, [[0, 0], [0, 0], [0, -x_length % pad_length], [0, 0]])
+
+    def pad_l_and_r(x, pad_length):
+      return tf.pad(x, [[0, 0], [0, 0], [pad_length, pad_length], [0, 0]])
+
+    q = pad_to_multiple(q, query_block_size)
+    v = pad_to_multiple(v, query_block_size)
+    k = pad_to_multiple(k, query_block_size)
+
+    q.set_shape(v_list_shape)
+    v.set_shape(v_list_shape)
+    k.set_shape(v_list_shape)
+
+    # Setting up q blocks
+    new_q_shape = tf.shape(q)
+    q = reshape_by_blocks(q, new_q_shape, query_block_size)
+    self_k_part = reshape_by_blocks(k, new_q_shape, query_block_size)
+    self_v_part = reshape_by_blocks(v, new_q_shape, query_block_size)
+
+    # Setting up k and v windows
+    k_v_padding = (gap_size + memory_block_size) * num_memory_blocks
+    k = pad_l_and_r(k, k_v_padding)
+    v = pad_l_and_r(v, k_v_padding)
+
+    # getting gather indices
+    index_length = (new_q_shape[2] - query_block_size + memory_block_size)
+    indices = tf.range(0, index_length, delta=1, name="index_range")
+    # making indices [1, length, 1] to apply convs
+    indices = tf.reshape(indices, [1, -1, 1])
+    kernel = tf.expand_dims(tf.eye(memory_block_size), axis=1)
+    gather_indices = tf.nn.conv1d(
+        tf.cast(indices, tf.float32),
+        kernel,
+        query_block_size,
+        padding="VALID",
+        name="gather_conv")
+    gather_indices = tf.squeeze(tf.cast(gather_indices, tf.int32), axis=0)
+
+    # get left and right memory blocks for each query
+    # [length, batch, heads, dim]
+    k_t = tf.transpose(k, [2, 0, 1, 3])
+    v_t = tf.transpose(v, [2, 0, 1, 3])
+    left_k = gather_dilated_memory_blocks(k_t[:-k_v_padding, :, :, :],
+                                          num_memory_blocks, gap_size,
+                                          query_block_size, memory_block_size,
+                                          gather_indices)
+    left_v = gather_dilated_memory_blocks(v_t[:-k_v_padding, :, :, :],
+                                          num_memory_blocks, gap_size,
+                                          query_block_size, memory_block_size,
+                                          gather_indices)
+
+    right_k = gather_dilated_memory_blocks(k_t[k_v_padding:, :, :, :],
+                                           num_memory_blocks, gap_size,
+                                           query_block_size, memory_block_size,
+                                           gather_indices, direction="right")
+    right_v = gather_dilated_memory_blocks(v_t[k_v_padding:, :, :, :],
+                                           num_memory_blocks, gap_size,
+                                           query_block_size, memory_block_size,
+                                           gather_indices, direction="right")
+
+    k_windows = tf.concat([left_k, self_k_part, right_k], axis=3)
+    v_windows = tf.concat([left_v, self_v_part, right_v], axis=3)
+    attention_bias = tf.expand_dims(
+        embedding_to_padding(k_windows) * -1e9, axis=-2)
+
+    output = dot_product_attention(
+        q, k_windows, v_windows, attention_bias, dropout_rate=0.,
+        name="dilated_1d", make_image_summary=False)
+    output = tf.reshape(output, [batch_size, num_heads, -1, depth_v])
+    # Remove the padding if introduced
+    output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1])
+    output.set_shape(v_list_shape)
+    return output
+
+
+def gather_dilated_memory_blocks(x, num_memory_blocks, gap_size,
+                                 query_block_size, memory_block_size,
+                                 gather_indices, direction="left"):
+  """Gathers blocks with gaps in between.
+
+  Args:
+    x: a tensor of shape [length, batch, heads, depth]
+    num_memory_blocks: how many memory blocks to look in "direction". Each will
+      be separated by gap_size.
+    gap_size: an integer indicating the gap size
+    query_block_size: an integer indicating size of query block
+    memory_block_size: an integer indicating the size of a memory block.
+    gather_indices: the indices to gather from.
+    direction: left or right
+
+  Returns:
+    a tensor of shape [batch, heads, blocks, block_length, depth]
+  """
+  gathered_blocks = []
+  # gathering memory blocks
+  for block_id in range(num_memory_blocks):
+    block_end_index = -(query_block_size +
+                        gap_size * (block_id + 1) + memory_block_size *
+                        block_id) - 1
+    block_start_index = (
+        (memory_block_size + gap_size) *
+        (num_memory_blocks - (block_id + 1))
+    )
+    if direction != "left":
+      [block_end_index, block_start_index] = [
+          -block_start_index - 1, -block_end_index + 1
+      ]
+
+    def gather_dilated_1d_blocks(x, gather_indices):
+      x_new = tf.gather(x, gather_indices)
+      # [batch, heads, blocks, block_length, dim]
+      return tf.transpose(x_new, [2, 3, 0, 1, 4])
+
+    gathered_blocks.append(
+        gather_dilated_1d_blocks(x[block_start_index:block_end_index],
+                                 gather_indices))
+  return tf.concat(gathered_blocks, 3)
+
+
+def masked_dilated_self_attention_1d(q,
+                                     k,
+                                     v,
+                                     query_block_size=64,
+                                     memory_block_size=64,
+                                     gap_size=2,
+                                     num_memory_blocks=2,
+                                     name=None):
+  """Masked dilated self-attention.
+
+  Args:
+    q: a Tensor with shape [batch, heads, length, depth_k]
+    k: a Tensor with shape [batch, heads, length, depth_k]
+    v: a Tensor with shape [batch, heads, length, depth_v]
+    query_block_size: an integer
+    memory_block_size: an integer indicating how much to look left.
+    gap_size: an integer indicating the gap size
+    num_memory_blocks: how many memory blocks to look at to the left. Each will
+      be separated by gap_size.
+    name: an optional string
+
+  Returns:
+    a Tensor of shape [batch, heads, length, depth_v]
+  """
+  with tf.variable_scope(
+      name, default_name="masked_dilated_self_attention_1d", values=[q, k, v]):
+    v_list_shape = v.get_shape().as_list()
+    v_shape = tf.shape(v)
+    depth_v = v_shape[3]
+    batch_size = v_shape[0]
+    num_heads = v_shape[1]
+    original_length = tf.shape(q)[2]
+
+    # making sure q is a multiple of query block size
+    def pad_to_multiple(x, pad_length):
+      x_length = tf.shape(x)[2]
+      return tf.pad(x, [[0, 0], [0, 0], [0, -x_length % pad_length], [0, 0]])
+
+    def pad_l(x, left_pad_length):
+      return tf.pad(x, [[0, 0], [0, 0], [left_pad_length, 0], [0, 0]])
+
+    q = pad_to_multiple(q, query_block_size)
+    v = pad_to_multiple(v, query_block_size)
+    k = pad_to_multiple(k, query_block_size)
+    q.set_shape(v_list_shape)
+    v.set_shape(v_list_shape)
+    k.set_shape(v_list_shape)
+
+    # Setting up q blocks
+    new_q_shape = tf.shape(q)
+    q = reshape_by_blocks(q, new_q_shape, query_block_size)
+    self_k_part = reshape_by_blocks(k, new_q_shape, query_block_size)
+    self_v_part = reshape_by_blocks(v, new_q_shape, query_block_size)
+
+    # Setting up k and v windows
+    k_v_padding = (gap_size + memory_block_size) * num_memory_blocks
+    k = pad_l(k, k_v_padding)
+    v = pad_l(v, k_v_padding)
+
+    # getting gather indices
+    index_length = (new_q_shape[2] - query_block_size + memory_block_size)
+    indices = tf.range(0, index_length, delta=1, name="index_range")
+    # making indices [1, length, 1] to apply convs
+    indices = tf.reshape(indices, [1, -1, 1])
+    kernel = tf.expand_dims(tf.eye(memory_block_size), axis=1)
+    gather_indices = tf.nn.conv1d(
+        tf.cast(indices, tf.float32),
+        kernel,
+        query_block_size,
+        padding="VALID",
+        name="gather_conv")
+    gather_indices = tf.squeeze(tf.cast(gather_indices, tf.int32), axis=0)
+
+    # get left memory blocks for each query
+    # [length, batch, heads, dim]
+    k_t = tf.transpose(k, [2, 0, 1, 3])
+    v_t = tf.transpose(v, [2, 0, 1, 3])
+
+    k_unmasked_windows = gather_dilated_memory_blocks(k_t, num_memory_blocks,
+                                                      gap_size,
+                                                      query_block_size,
+                                                      memory_block_size,
+                                                      gather_indices)
+    v_unmasked_windows = gather_dilated_memory_blocks(v_t, num_memory_blocks,
+                                                      gap_size,
+                                                      query_block_size,
+                                                      memory_block_size,
+                                                      gather_indices)
+
+    # build the causal bias for the self block and the padding bias for the
+    # gathered memory windows
+    block_q_shape = tf.shape(q)
+    masked_attention_bias = tf.tile(
+        tf.expand_dims(attention_bias_lower_triangle(query_block_size),
+                       axis=0),
+        [block_q_shape[0], block_q_shape[1], block_q_shape[2], 1, 1])
+    padding_attention_bias = tf.expand_dims(
+        embedding_to_padding(k_unmasked_windows) * -1e9, axis=-2)
+    padding_attention_bias = tf.tile(padding_attention_bias,
+                                     [1, 1, 1, query_block_size, 1])
+    attention_bias = tf.concat([masked_attention_bias, padding_attention_bias],
+                               axis=-1)
+
+    # combine memory windows
+    k_windows = tf.concat([self_k_part, k_unmasked_windows], 3)
+    v_windows = tf.concat([self_v_part, v_unmasked_windows], 3)
+    output = dot_product_attention(
+        q, k_windows, v_windows, attention_bias, dropout_rate=0.,
+        name="dilated_1d", make_image_summary=False)
+    output = tf.reshape(output, [batch_size, num_heads, -1, depth_v])
+    # Remove the padding if introduced
+    output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1])
+    output.set_shape(v_list_shape)
+    return output
+
+
 def local_attention_2d(q,
                        k,
                        v,
@@ -1441,6 +1715,8 @@ def multihead_attention(query_antecedent,
                         q_padding="VALID",
                         kv_padding="VALID",
                         cache=None,
+                        gap_size=0,
+                        num_memory_blocks=2,
                         name=None,
                         **kwargs):
   """Multihead scaled-dot-product attention with input/output transformations.
@@ -1475,6 +1751,10 @@ def multihead_attention(query_antecedent,
       be empty Tensors of the appropriate shape.
           'k' [batch_size, 0, key_channels]
           'v' [batch_size, 0, value_channels]
+    gap_size: Integer option for dilated attention to indicate spacing between
+      memory blocks.
+    num_memory_blocks: Integer option to indicate how many memory blocks to
+      look at.
     name: an optional string
     **kwargs (dict): Params for the attention function
 
@@ -1542,13 +1822,22 @@ def multihead_attention(query_antecedent,
                                  dropout_rate, image_shapes)
     elif attention_type == "local_mask_right":
       x = masked_local_attention_1d(q, k, v, block_length=block_length)
-    else:
-      assert attention_type == "local_unmasked"
+    elif attention_type == "local_unmasked":
       x = local_attention_1d(
           q, k, v, block_length=block_length, filter_width=block_width)
+    elif attention_type == "masked_dilated_1d":
+      x = masked_dilated_self_attention_1d(q, k, v, block_length,
+                                           block_width,
+                                           gap_size,
+                                           num_memory_blocks)
+    else:
+      assert attention_type == "unmasked_dilated_1d"
+      x = dilated_self_attention_1d(q, k, v, block_length,
+                                    block_width,
+                                    gap_size,
+                                    num_memory_blocks)
     x = combine_heads(x)
     x = common_layers.conv1d(x, output_depth, 1, name="output_transform")
-
     if additional_returned_value is not None:
      return x, additional_returned_value
     return x
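
Not part of the commit: a minimal TF1 graph-construction sketch of calling the new unmasked dilated attention directly, assuming the tensor2tensor package at this revision is importable. The shapes and hyperparameters below are illustrative only.

```python
import tensorflow as tf
from tensor2tensor.layers import common_attention

# Toy [batch, heads, length, depth] inputs for self-attention.
batch, heads, length, depth = 2, 4, 256, 64
q = tf.random_normal([batch, heads, length, depth])
k = tf.random_normal([batch, heads, length, depth])
v = tf.random_normal([batch, heads, length, depth])

# Each 128-position query block attends to itself plus num_memory_blocks
# memory blocks of 128 positions on each side, separated by gap_size positions.
out = common_attention.dilated_self_attention_1d(
    q, k, v,
    query_block_size=128,
    memory_block_size=128,
    gap_size=2,
    num_memory_blocks=2)
print(out.shape)  # (2, 4, 256, 64) -- same shape as v
```

The masked variant, masked_dilated_self_attention_1d, takes the same arguments but pads and gathers memory blocks only to the left and adds a lower-triangular bias within each query block, so it can be used for autoregressive decoding.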

tensor2tensor/utils/decoding.py

Lines changed: 2 additions & 2 deletions
@@ -86,10 +86,10 @@ def log_decode_results(inputs,
     if targets is not None:
       decoded_targets = " ".join(map(str, targets.flatten()))
   else:
-    decoded_outputs = "".join(
+    decoded_outputs = " ".join(
         map(str, targets_vocab.decode(_save_until_eos(outputs.flatten()))))
     if targets is not None:
-      decoded_targets = "".join(
+      decoded_targets = " ".join(
           map(str, targets_vocab.decode(_save_until_eos(targets.flatten()))))
 
   tf.logging.info("Inference results OUTPUT: %s" % decoded_outputs)
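
Not part of the commit: a tiny illustration of why the separator change matters when logging decoded tokens during inference.

```python
# With "".join the token boundaries are lost in the log line;
# with " ".join they stay readable.
ids = [17, 4, 23]
print("".join(map(str, ids)))   # '17423'
print(" ".join(map(str, ids)))  # '17 4 23'
```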
