@@ -244,7 +244,7 @@ def forward(self, key_pos = None, query_pos = None, key_segment = None, query_se
244244 # b*q*k
245245 if self .absolute_inner_segment :
246246 absolute_position_bucket = self ._absolute_position_bucket (
247- query_pos - key_pos ,
247+ key_pos - query_pos ,
248248 bidirectional = self .bidirectional ,
249249 num_buckets = self .num_buckets ,
250250 max_distance = self .max_distance
@@ -260,23 +260,6 @@ def forward(self, key_pos = None, query_pos = None, key_segment = None, query_se
260260 embeds = embeds .permute (0 , 3 , 1 , 2 ).contiguous ()
261261 return embeds
262262
# Map a tensor of relative positions to integer bucket indices.
#
# Every line of this method carries the diff removal marker ("-"), i.e. the
# change shown here deletes it.
#
# Parameters (as visible in the signature):
#   relative_position -- tensor of position differences; the torch calls below
#                        (torch.abs / torch.min / torch.log) require a tensor.
#   bidirectional     -- when true, the sign of the offset picks one of two
#                        halves of the bucket range; otherwise only
#                        non-positive offsets keep their magnitude.
#   num_buckets       -- total number of buckets (halved per direction when
#                        bidirectional).
#   max_exact         -- distance below which buckets are spaced linearly.
#   max_distance      -- distance that the logarithmic range is scaled to;
#                        results are capped at the last bucket.
# Returns the bucket indices (int32 tensors accumulated into relative_buckets).
#
# NOTE(review): the float defaults (0.125 / 0.5) suggest positions are
# normalised fractions rather than token counts -- confirm against the caller.
263- def _relative_position_bucket (self , relative_position , bidirectional = True , num_buckets = 32 , max_exact = 0.125 , max_distance = 0.5 ):
264- relative_buckets = 0
265- if bidirectional :
# Split the bucket range in two: strictly positive offsets are shifted by the
# (halved) num_buckets, then only the magnitude is bucketed.
266- num_buckets //= 2
267- max_exact /= 2
268- relative_buckets = (relative_position > 0 ).to (torch .int32 ) * num_buckets
269- relative_position = torch .abs (relative_position )
270- else :
# Clamp positive offsets to zero and negate, leaving the magnitude of the
# non-positive offsets.
271- relative_position = - torch .min (relative_position , torch .zeros_like (relative_position ))
# NOTE(review): indentation is lost in this view, so it cannot be determined
# whether the following halving belongs to the else-branch only or runs after
# the whole if/else (which would halve max_exact twice when bidirectional).
# Confirm against the original file before relying on either reading.
272- max_exact /= 2
# Distances below max_exact use the linear branch in the final torch.where.
273- is_small = relative_position < max_exact
274- half_num_buckets = num_buckets // 2
# Logarithmic branch: log(distance / max_exact) scaled by
# log(max_distance / max_exact) and spread over the remaining
# (num_buckets - half_num_buckets) buckets, truncated to int32.
275- relative_postion_if_large = half_num_buckets + (torch .log (relative_position / max_exact ) / math .log (max_distance / max_exact ) * (num_buckets - half_num_buckets )).to (torch .int32 )
# Cap the logarithmic result at the last valid bucket index.
276- relative_postion_if_large = torch .min (relative_postion_if_large , torch .full_like (relative_postion_if_large , num_buckets - 1 ))
# Linear branch: distance / max_exact * half_num_buckets truncated to int32;
# torch.where picks it wherever is_small holds, the capped log value elsewhere.
277- relative_buckets += torch .where (is_small , (relative_position / max_exact * half_num_buckets ).to (torch .int32 ), relative_postion_if_large )
278- return relative_buckets
279-
280263 def _segment_relative_position_bucket (self , query_segment , key_segment ):
281264 return query_segment * self .num_segments + key_segment
282265
0 commit comments