@@ -1090,6 +1090,280 @@ def pad_l_and_r(x, pad_length):
     return output
 
 
+def reshape_by_blocks(x, x_shape, memory_block_size):
+  x = tf.reshape(x, [
+      x_shape[0], x_shape[1], x_shape[2] // memory_block_size,
+      memory_block_size, x_shape[3]
+  ])
+  return x
+
+
+def dilated_self_attention_1d(q,
+                              k,
+                              v,
+                              query_block_size=128,
+                              memory_block_size=128,
+                              gap_size=2,
+                              num_memory_blocks=2,
+                              name=None):
+  """Dilated self-attention.
+
+  Args:
+    q: a Tensor with shape [batch, heads, length, depth_k]
+    k: a Tensor with shape [batch, heads, length, depth_k]
+    v: a Tensor with shape [batch, heads, length, depth_v]
+    query_block_size: an integer indicating the size of a query block.
+    memory_block_size: an integer indicating the size of a memory block.
+    gap_size: an integer indicating the gap, in positions, between memory blocks.
+    num_memory_blocks: how many memory blocks to look at to the left and right.
+      Each will be separated by gap_size.
+    name: an optional string
+
+  Returns:
+    a Tensor of shape [batch, heads, length, depth_v]
+  """
+  with tf.variable_scope(
+      name, default_name="dilated_self_attention_1d", values=[q, k, v]):
+    v_list_shape = v.get_shape().as_list()
+    v_shape = tf.shape(v)
+    depth_v = v_shape[3]
+    batch_size = v_shape[0]
+    num_heads = v_shape[1]
+    original_length = tf.shape(q)[2]
+    # Make sure the sequence length is a multiple of query_block_size.
+    def pad_to_multiple(x, pad_length):
+      x_length = tf.shape(x)[2]
+      return tf.pad(x, [[0, 0], [0, 0], [0, -x_length % pad_length], [0, 0]])
+
+    def pad_l_and_r(x, pad_length):
+      return tf.pad(x, [[0, 0], [0, 0], [pad_length, pad_length], [0, 0]])
+
+    q = pad_to_multiple(q, query_block_size)
+    v = pad_to_multiple(v, query_block_size)
+    k = pad_to_multiple(k, query_block_size)
+
+    q.set_shape(v_list_shape)
+    v.set_shape(v_list_shape)
+    k.set_shape(v_list_shape)
+    # Shape of q after padding.
+    new_q_shape = tf.shape(q)
+    # Set up query blocks.
+    q = reshape_by_blocks(q, new_q_shape, query_block_size)
+    self_k_part = reshape_by_blocks(k, new_q_shape, query_block_size)
+    self_v_part = reshape_by_blocks(v, new_q_shape, query_block_size)
+
+    # Set up key and value windows.
+    k_v_padding = (gap_size + memory_block_size) * num_memory_blocks
+    k = pad_l_and_r(k, k_v_padding)
+    v = pad_l_and_r(v, k_v_padding)
+    # Get the gather indices.
+    index_length = (new_q_shape[2] - query_block_size + memory_block_size)
+    indices = tf.range(0, index_length, delta=1, name="index_range")
+    # Make the indices [1, length, 1] to apply convs.
+    indices = tf.reshape(indices, [1, -1, 1])
+    kernel = tf.expand_dims(tf.eye(memory_block_size), axis=1)
+    gather_indices = tf.nn.conv1d(
+        tf.cast(indices, tf.float32),
+        kernel,
+        query_block_size,
+        padding="VALID",
+        name="gather_conv")
+
+    gather_indices = tf.squeeze(tf.cast(gather_indices, tf.int32), axis=0)
+
+    # Get the left and right memory blocks for each query.
+    # [length, batch, heads, dim]
+    k_t = tf.transpose(k, [2, 0, 1, 3])
+    v_t = tf.transpose(v, [2, 0, 1, 3])
+    left_k = gather_dilated_memory_blocks(k_t[:-k_v_padding, :, :, :],
+                                          num_memory_blocks, gap_size,
+                                          query_block_size, memory_block_size,
+                                          gather_indices)
+    left_v = gather_dilated_memory_blocks(v_t[:-k_v_padding, :, :, :],
+                                          num_memory_blocks, gap_size,
+                                          query_block_size, memory_block_size,
+                                          gather_indices)
+
+    right_k = gather_dilated_memory_blocks(k_t[k_v_padding:, :, :, :],
+                                           num_memory_blocks, gap_size,
+                                           query_block_size, memory_block_size,
+                                           gather_indices, direction="right")
+    right_v = gather_dilated_memory_blocks(v_t[k_v_padding:, :, :, :],
+                                           num_memory_blocks, gap_size,
+                                           query_block_size, memory_block_size,
+                                           gather_indices, direction="right")
+
+    k_windows = tf.concat([left_k, self_k_part, right_k], axis=3)
+    v_windows = tf.concat([left_v, self_v_part, right_v], axis=3)
+    attention_bias = tf.expand_dims(
+        embedding_to_padding(k_windows) * -1e9, axis=-2)
+
+    output = dot_product_attention(
+        q, k_windows, v_windows, attention_bias, dropout_rate=0.,
+        name="dilated_1d", make_image_summary=False)
+    output = tf.reshape(output, [batch_size, num_heads, -1, depth_v])
+    # Remove the padding if it was introduced.
+    output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1])
+    output.set_shape(v_list_shape)
+    return output
+
+
+def gather_dilated_memory_blocks(x, num_memory_blocks, gap_size,
+                                 query_block_size, memory_block_size,
+                                 gather_indices, direction="left"):
+  """Gathers blocks with gaps in between.
+
+  Args:
+    x: a tensor of shape [length, batch, heads, depth]
+    num_memory_blocks: how many memory blocks to look at in "direction".
+      Each will be separated by gap_size.
+    gap_size: an integer indicating the gap size between memory blocks.
+    query_block_size: an integer indicating the size of a query block.
+    memory_block_size: an integer indicating the size of a memory block.
+    gather_indices: the indices to gather from.
+    direction: "left" or "right"
+  Returns:
+    a tensor of shape [batch, heads, blocks, block_length, depth]
+  """
+
+  gathered_blocks = []
+  # Gather the memory blocks.
+  for block_id in range(num_memory_blocks):
+    block_end_index = -(query_block_size +
+                        gap_size * (block_id + 1) + memory_block_size *
+                        block_id) - 1
+    block_start_index = (
+        (memory_block_size + gap_size) *
+        (num_memory_blocks - (block_id + 1))
+    )
+    if direction != "left":
+      [block_end_index, block_start_index] = [
+          -block_start_index - 1, -block_end_index + 1
+      ]
+    def gather_dilated_1d_blocks(x, gather_indices):
+      x_new = tf.gather(x, gather_indices)
+      # [batch, heads, blocks, block_length, dim]
+      return tf.transpose(x_new, [2, 3, 0, 1, 4])
+
+    gathered_blocks.append(
+        gather_dilated_1d_blocks(x[block_start_index:block_end_index],
+                                 gather_indices))
+  return tf.concat(gathered_blocks, 3)
+
+
+def masked_dilated_self_attention_1d(q,
+                                     k,
+                                     v,
+                                     query_block_size=64,
+                                     memory_block_size=64,
+                                     gap_size=2,
+                                     num_memory_blocks=2,
+                                     name=None):
+  """Masked dilated self-attention.
+
+  Args:
+    q: a Tensor with shape [batch, heads, length, depth_k]
+    k: a Tensor with shape [batch, heads, length, depth_k]
+    v: a Tensor with shape [batch, heads, length, depth_v]
+    query_block_size: an integer indicating the size of a query block.
+    memory_block_size: an integer indicating the size of a (left) memory block.
+    gap_size: an integer indicating the gap, in positions, between memory blocks.
+    num_memory_blocks: how many memory blocks to look at to the left. Each will
+      be separated by gap_size.
+    name: an optional string
+
+  Returns:
+    a Tensor of shape [batch, heads, length, depth_v]
+  """
+  with tf.variable_scope(
+      name, default_name="masked_dilated_self_attention_1d", values=[q, k, v]):
+    v_list_shape = v.get_shape().as_list()
+    v_shape = tf.shape(v)
+    depth_v = v_shape[3]
+    batch_size = v_shape[0]
+    num_heads = v_shape[1]
+    original_length = tf.shape(q)[2]
+    # Make sure the sequence length is a multiple of query_block_size.
+    def pad_to_multiple(x, pad_length):
+      x_length = tf.shape(x)[2]
+      return tf.pad(x, [[0, 0], [0, 0], [0, -x_length % pad_length], [0, 0]])
+
+    def pad_l(x, left_pad_length):
+      return tf.pad(x, [[0, 0], [0, 0], [left_pad_length, 0], [0, 0]])
+
+    q = pad_to_multiple(q, query_block_size)
+    v = pad_to_multiple(v, query_block_size)
+    k = pad_to_multiple(k, query_block_size)
+    q.set_shape(v_list_shape)
+    v.set_shape(v_list_shape)
+    k.set_shape(v_list_shape)
+    # Shape of q after padding.
+    new_q_shape = tf.shape(q)
+
+    # Set up query blocks.
+    q = reshape_by_blocks(q, new_q_shape, query_block_size)
+    self_k_part = reshape_by_blocks(k, new_q_shape, query_block_size)
+    self_v_part = reshape_by_blocks(v, new_q_shape, query_block_size)
+    # Set up key and value windows.
+    k_v_padding = (gap_size + memory_block_size) * num_memory_blocks
+    k = pad_l(k, k_v_padding)
+    v = pad_l(v, k_v_padding)
+    # Get the gather indices.
+    index_length = (new_q_shape[2] - query_block_size + memory_block_size)
+
+    indices = tf.range(0, index_length, delta=1, name="index_range")
+    # Make the indices [1, length, 1] to apply convs.
+    indices = tf.reshape(indices, [1, -1, 1])
+    kernel = tf.expand_dims(tf.eye(memory_block_size), axis=1)
+    gather_indices = tf.nn.conv1d(
+        tf.cast(indices, tf.float32),
+        kernel,
+        query_block_size,
+        padding="VALID",
+        name="gather_conv")
+    gather_indices = tf.squeeze(tf.cast(gather_indices, tf.int32), axis=0)
+
+    # Get the left memory blocks for each query.
+    # [length, batch, heads, dim]
+    k_t = tf.transpose(k, [2, 0, 1, 3])
+    v_t = tf.transpose(v, [2, 0, 1, 3])
+
+    k_unmasked_windows = gather_dilated_memory_blocks(k_t, num_memory_blocks,
+                                                      gap_size,
+                                                      query_block_size,
+                                                      memory_block_size,
+                                                      gather_indices)
+    v_unmasked_windows = gather_dilated_memory_blocks(v_t, num_memory_blocks,
+                                                      gap_size,
+                                                      query_block_size,
+                                                      memory_block_size,
+                                                      gather_indices)
+
+    # Build the attention bias: causal within blocks, padding mask for memory.
+    block_q_shape = tf.shape(q)
+    masked_attention_bias = tf.tile(
+        tf.expand_dims(attention_bias_lower_triangle(query_block_size),
+                       axis=0),
+        [block_q_shape[0], block_q_shape[1], block_q_shape[2], 1, 1])
+    padding_attention_bias = tf.expand_dims(
+        embedding_to_padding(k_unmasked_windows) * -1e9, axis=-2)
+    padding_attention_bias = tf.tile(padding_attention_bias,
+                                     [1, 1, 1, query_block_size, 1])
+    attention_bias = tf.concat([masked_attention_bias, padding_attention_bias],
+                               axis=-1)
+    # Combine memory windows.
+    k_windows = tf.concat([self_k_part, k_unmasked_windows], 3)
+    v_windows = tf.concat([self_v_part, v_unmasked_windows], 3)
+    output = dot_product_attention(
+        q, k_windows, v_windows, attention_bias, dropout_rate=0.,
+        name="dilated_1d", make_image_summary=False)
+    output = tf.reshape(output, [batch_size, num_heads, -1, depth_v])
+    # Remove the padding if it was introduced.
+    output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1])
+    output.set_shape(v_list_shape)
+    return output
+
+
 def local_attention_2d(q,
                        k,
                        v,
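Note on the gather pattern: the slicing arithmetic in gather_dilated_memory_blocks is easier to see with a small pure-Python sketch. The helper below is hypothetical (not part of this commit) and reflects my reading of the left-side gather: for a given query block, each of the num_memory_blocks left memory blocks spans memory_block_size positions, consecutive blocks are separated by gap_size positions, and negative indices fall into the zero padding.

def left_memory_ranges(block_idx, query_block_size, memory_block_size,
                       gap_size, num_memory_blocks):
  # Hypothetical helper: index ranges [start, end), in the unpadded sequence,
  # of the left memory blocks attended to by query block `block_idx`.
  q_start = block_idx * query_block_size
  return [(q_start - i * (memory_block_size + gap_size),
           q_start - i * (memory_block_size + gap_size) + memory_block_size)
          for i in range(1, num_memory_blocks + 1)]

# With the masked defaults (query/memory block 64, gap 2, 2 memory blocks),
# query block 2 yields [(62, 126), (-4, 60)]: one block ending two positions
# left of the queries, and a second block partially inside the padding.
print(left_memory_ranges(2, 64, 64, 2, 2))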
@@ -1441,6 +1715,8 @@ def multihead_attention(query_antecedent,
                         q_padding="VALID",
                         kv_padding="VALID",
                         cache=None,
+                        gap_size=0,
+                        num_memory_blocks=2,
                         name=None,
                         **kwargs):
14461722 """Multihead scaled-dot-product attention with input/output transformations.
@@ -1475,6 +1751,10 @@ def multihead_attention(query_antecedent,
       be empty Tensors of the appropriate shape.
           'k' [batch_size, 0, key_channels]
           'v' [batch_size, 0, value_channels]
+    gap_size: Integer option for dilated attention to indicate spacing between
+      memory blocks.
+    num_memory_blocks: Integer option to indicate how many memory blocks to
+      look at.
     name: an optional string
     **kwargs (dict): Params for the attention function
 
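The gap_size and num_memory_blocks options documented above are forwarded to the dilated attention functions added in the first hunk. A minimal graph-mode shape check might look like the sketch below, assuming TF 1.x and that this file is importable as tensor2tensor.layers.common_attention (path assumed; the commit itself adds no test):

import tensorflow as tf
from tensor2tensor.layers import common_attention  # import path assumed

# [batch, heads, length, depth]; length 256 is a multiple of the block size.
q = tf.random_normal([2, 4, 256, 64])
k = tf.random_normal([2, 4, 256, 64])
v = tf.random_normal([2, 4, 256, 64])
out = common_attention.dilated_self_attention_1d(
    q, k, v, query_block_size=128, memory_block_size=128,
    gap_size=2, num_memory_blocks=2)
print(out.get_shape())  # (2, 4, 256, 64): [batch, heads, length, depth_v]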
@@ -1542,13 +1822,22 @@ def multihead_attention(query_antecedent,
                                          dropout_rate, image_shapes)
     elif attention_type == "local_mask_right":
       x = masked_local_attention_1d(q, k, v, block_length=block_length)
-    else:
-      assert attention_type == "local_unmasked"
+    elif attention_type == "local_unmasked":
       x = local_attention_1d(
           q, k, v, block_length=block_length, filter_width=block_width)
+    elif attention_type == "masked_dilated_1d":
+      x = masked_dilated_self_attention_1d(q, k, v, block_length,
+                                           block_width,
+                                           gap_size,
+                                           num_memory_blocks)
+    else:
+      assert attention_type == "unmasked_dilated_1d"
+      x = dilated_self_attention_1d(q, k, v, block_length,
+                                    block_width,
+                                    gap_size,
+                                    num_memory_blocks)
     x = combine_heads(x)
     x = common_layers.conv1d(x, output_depth, 1, name="output_transform")
-
     if additional_returned_value is not None:
       return x, additional_returned_value
     return x
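At the multihead level the new code paths are selected through attention_type. A hypothetical invocation sketch follows, assuming the rest of multihead_attention's signature is unchanged from the existing function and with purely illustrative sizes (not taken from the commit):

import tensorflow as tf
from tensor2tensor.layers import common_attention  # import path assumed

x = tf.random_normal([2, 256, 512])  # [batch, length, hidden_dim]
y = common_attention.multihead_attention(
    query_antecedent=x,
    memory_antecedent=None,  # self-attention
    bias=None,               # bias is not used by the dilated branches
    total_key_depth=512,
    total_value_depth=512,
    output_depth=512,
    num_heads=8,
    dropout_rate=0.0,
    attention_type="unmasked_dilated_1d",  # or "masked_dilated_1d"
    block_length=128,        # passed through as query_block_size
    block_width=128,         # passed through as memory_block_size
    gap_size=2,
    num_memory_blocks=2)
# y has shape [2, 256, 512], i.e. [batch, length, output_depth].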