Skip to content

Commit 07fcf52

Browse files
committed
Merge branch 'main' of github.com:OpenBMB/ModelCenter
2 parents e472393 + 096e657 commit 07fcf52

File tree

1 file changed

+1
-18
lines changed

1 file changed

+1
-18
lines changed

model_center/layer/position_embedding.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def forward(self, key_pos = None, query_pos = None, key_segment = None, query_se
244244
# b*q*k
245245
if self.absolute_inner_segment:
246246
absolute_position_bucket = self._absolute_position_bucket(
247-
query_pos - key_pos,
247+
key_pos - query_pos,
248248
bidirectional=self.bidirectional,
249249
num_buckets=self.num_buckets,
250250
max_distance = self.max_distance
@@ -260,23 +260,6 @@ def forward(self, key_pos = None, query_pos = None, key_segment = None, query_se
260260
embeds = embeds.permute(0, 3, 1, 2).contiguous()
261261
return embeds
262262

263-
def _relative_position_bucket(self, relative_position, bidirectional=True, num_buckets = 32, max_exact=0.125, max_distance=0.5):
264-
relative_buckets = 0
265-
if bidirectional:
266-
num_buckets //= 2
267-
max_exact /= 2
268-
relative_buckets = (relative_position > 0).to(torch.int32) * num_buckets
269-
relative_position = torch.abs(relative_position)
270-
else:
271-
relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
272-
max_exact /= 2
273-
is_small = relative_position < max_exact
274-
half_num_buckets = num_buckets // 2
275-
relative_postion_if_large = half_num_buckets + (torch.log(relative_position / max_exact) / math.log(max_distance / max_exact) * (num_buckets - half_num_buckets)).to(torch.int32)
276-
relative_postion_if_large = torch.min(relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1))
277-
relative_buckets += torch.where(is_small, (relative_position / max_exact * half_num_buckets).to(torch.int32), relative_postion_if_large)
278-
return relative_buckets
279-
280263
def _segment_relative_position_bucket(self, query_segment, key_segment):
281264
return query_segment * self.num_segments + key_segment
282265

0 commit comments

Comments
 (0)