From 7dde3ec0bb239d43ac0339c89dc742e41ede10c0 Mon Sep 17 00:00:00 2001 From: Ryan Avery Date: Wed, 21 Jan 2026 20:59:22 -0800 Subject: [PATCH 1/3] fix: enable CPU export and inference by removing pin_memory on scale tensor MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Create the box scale tensor directly on the target device instead of using pin_memory().to(device, non_blocking=True). This enables: - CPU-only inference (pin_memory requires CUDA) - Apple MPS inference (pin_memory not supported) - PT2 export without runtime patching The scale tensor is always exactly 4 floats (16-32 bytes). For such a small tensor, the pin_memory overhead likely exceeds any async transfer benefit. Creating the tensor directly on device avoids the CPU→GPU transfer entirely. --- sam3/model/geometry_encoders.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sam3/model/geometry_encoders.py b/sam3/model/geometry_encoders.py index d60ee54..b1b2314 100644 --- a/sam3/model/geometry_encoders.py +++ b/sam3/model/geometry_encoders.py @@ -644,9 +644,9 @@ def _encode_boxes(self, boxes, boxes_mask, boxes_labels, img_feats): # boxes are [Num_boxes, bs, 4], normalized in [0, 1] # We need to denormalize, and convert to [x, y, x, y] boxes_xyxy = box_cxcywh_to_xyxy(boxes) - scale = torch.tensor([W, H, W, H], dtype=boxes_xyxy.dtype) - scale = scale.pin_memory().to(device=boxes_xyxy.device, non_blocking=True) - scale = scale.view(1, 1, 4) + scale = torch.tensor( + [W, H, W, H], dtype=boxes_xyxy.dtype, device=boxes_xyxy.device + ).view(1, 1, 4) boxes_xyxy = boxes_xyxy * scale sampled = torchvision.ops.roi_align( img_feats, boxes_xyxy.float().transpose(0, 1).unbind(0), self.roi_size From 8a1f8c2c2d8d3ab854a85513dbe0d8d19d8ac381 Mon Sep 17 00:00:00 2001 From: Ryan Avery Date: Wed, 21 Jan 2026 22:35:18 -0800 Subject: [PATCH 2/3] Add ExportFriendlyMultiheadAttention for dynamic shape torch.export This adds a custom MultiheadAttention implementation that bypasses F.multi_head_attention_forward to enable torch.export with dynamic shapes (e.g., variable image H/W). The problem: nn.MultiheadAttention uses F.multi_head_attention_forward which has internal guards on sequence length (e.g., Eq(seq_len, 5184)) that fail during torch.export because the sequence length is symbolic. The solution: ExportFriendlyMultiheadAttention: - Manually projects Q, K, V using the same combined in_proj_weight - Calls F.scaled_dot_product_attention directly - Avoids all shape validation guards in F.multi_head_attention_forward Also adds replace_mha_with_export_friendly() utility function to recursively replace all nn.MultiheadAttention modules in a model. Related PyTorch issues: - https://github.com/pytorch/pytorch/issues/170127 - https://github.com/pytorch/pytorch/issues/124502 --- sam3/model/model_misc.py | 280 ++++++++++++++++++++++++++++++++++----- 1 file changed, 249 insertions(+), 31 deletions(-) diff --git a/sam3/model/model_misc.py b/sam3/model/model_misc.py index d961461..b0846a7 100644 --- a/sam3/model/model_misc.py +++ b/sam3/model/model_misc.py @@ -36,6 +36,244 @@ def forward(self, *args, **kwargs): return super().forward(*args, **kwargs) +class ExportFriendlyMultiheadAttention(nn.Module): + """MultiheadAttention using F.scaled_dot_product_attention for torch.export compatibility. + + The standard nn.MultiheadAttention uses F.multi_head_attention_forward which has + internal guards on sequence length that fail with dynamic shapes in torch.export. + This implementation uses F.scaled_dot_product_attention directly to avoid those guards. + + Why this is needed: + ------------------- + When exporting with dynamic shapes (e.g., variable image H/W), PyTorch's + F.multi_head_attention_forward creates guards like `Eq(seq_len, 5184)` which + fail during torch.export because the sequence length is symbolic. + + The guard happens in shape validation code (checking attn_mask dimensions) + BEFORE scaled_dot_product_attention is even called, so using + `sdpa_kernel([SDPBackend.MATH])` alone doesn't help. + + This class bypasses F.multi_head_attention_forward entirely by: + 1. Manually projecting Q, K, V + 2. Calling F.scaled_dot_product_attention directly + 3. Avoiding all shape validation guards + + Related PyTorch issues: + - https://github.com/pytorch/pytorch/issues/170127 + - https://github.com/pytorch/pytorch/issues/124502 + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + bias: bool = True, + batch_first: bool = False, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + self.batch_first = batch_first + self.dropout = dropout + + assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads" + + # Combined QKV projection for efficiency + self.in_proj_weight = nn.Parameter(torch.empty(3 * embed_dim, embed_dim)) + if bias: + self.in_proj_bias = nn.Parameter(torch.empty(3 * embed_dim)) + else: + self.register_parameter("in_proj_bias", None) + + # Output projection + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + self._reset_parameters() + + def _reset_parameters(self): + nn.init.xavier_uniform_(self.in_proj_weight) + if self.in_proj_bias is not None: + nn.init.zeros_(self.in_proj_bias) + nn.init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.zeros_(self.out_proj.bias) + + @classmethod + def from_nn_mha(cls, mha: nn.MultiheadAttention) -> "ExportFriendlyMultiheadAttention": + """Create an ExportFriendlyMultiheadAttention from an nn.MultiheadAttention. + + Copies all weights from the source module. + + Args: + mha: Source nn.MultiheadAttention module + + Returns: + New ExportFriendlyMultiheadAttention with copied weights + """ + # Create new instance with same configuration + new_mha = cls( + embed_dim=mha.embed_dim, + num_heads=mha.num_heads, + dropout=mha.dropout, + bias=mha.in_proj_bias is not None, + batch_first=mha.batch_first, + ) + + # Copy weights + with torch.no_grad(): + new_mha.in_proj_weight.copy_(mha.in_proj_weight) + if mha.in_proj_bias is not None: + new_mha.in_proj_bias.copy_(mha.in_proj_bias) + new_mha.out_proj.weight.copy_(mha.out_proj.weight) + if mha.out_proj.bias is not None: + new_mha.out_proj.bias.copy_(mha.out_proj.bias) + + return new_mha + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + need_weights: bool = False, + ) -> tuple[Tensor, None]: + """Forward pass using scaled dot product attention. + + Args: + query: (L, N, E) or (N, L, E) if batch_first + key: (S, N, E) or (N, S, E) if batch_first + value: (S, N, E) or (N, S, E) if batch_first + key_padding_mask: (N, S) where True means ignore + attn_mask: (L, S) or (N*num_heads, L, S) + need_weights: Ignored, always returns None for weights + + Returns: + attn_output: (L, N, E) or (N, L, E) if batch_first + attn_weights: Always None (for compatibility) + """ + # Convert to batch_first format for SDPA + if not self.batch_first: + query = query.transpose(0, 1) # (N, L, E) + key = key.transpose(0, 1) # (N, S, E) + value = value.transpose(0, 1) # (N, S, E) + + batch_size, tgt_len, _ = query.shape + src_len = key.shape[1] + + # Project Q, K, V using combined weight + # For cross-attention, we project Q from query and K,V from key/value + q = F.linear( + query, + self.in_proj_weight[: self.embed_dim], + self.in_proj_bias[: self.embed_dim] if self.in_proj_bias is not None else None, + ) + k = F.linear( + key, + self.in_proj_weight[self.embed_dim : 2 * self.embed_dim], + self.in_proj_bias[self.embed_dim : 2 * self.embed_dim] + if self.in_proj_bias is not None + else None, + ) + v = F.linear( + value, + self.in_proj_weight[2 * self.embed_dim :], + self.in_proj_bias[2 * self.embed_dim :] if self.in_proj_bias is not None else None, + ) + + # Reshape to (N, num_heads, L/S, head_dim) + q = q.view(batch_size, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(batch_size, src_len, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(batch_size, src_len, self.num_heads, self.head_dim).transpose(1, 2) + + # Handle attention mask + # SDPA expects: (N, num_heads, L, S) or (L, S) broadcastable + # Input attn_mask is (L, S) or (N*num_heads, L, S) + if attn_mask is not None: + if attn_mask.dim() == 2: + # (L, S) -> expand for SDPA + attn_mask = attn_mask.unsqueeze(0).unsqueeze(0) # (1, 1, L, S) + elif attn_mask.dim() == 3: + # (N*num_heads, L, S) -> (N, num_heads, L, S) + attn_mask = attn_mask.view(batch_size, self.num_heads, tgt_len, src_len) + + # Handle key_padding_mask + # SDPA expects additive mask, True = -inf + if key_padding_mask is not None: + # (N, S) -> (N, 1, 1, S) + key_padding_mask = key_padding_mask.unsqueeze(1).unsqueeze(2) + # Convert bool mask to additive mask + key_padding_mask = key_padding_mask.to(dtype=q.dtype) + key_padding_mask = key_padding_mask.masked_fill(key_padding_mask.bool(), float("-inf")) + + if attn_mask is not None: + attn_mask = attn_mask + key_padding_mask + else: + attn_mask = key_padding_mask + + # Use scaled dot product attention + dropout_p = self.dropout if self.training else 0.0 + attn_output = F.scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask, dropout_p=dropout_p + ) + + # Reshape back: (N, num_heads, L, head_dim) -> (N, L, E) + attn_output = ( + attn_output.transpose(1, 2).contiguous().view(batch_size, tgt_len, self.embed_dim) + ) + + # Output projection + attn_output = self.out_proj(attn_output) + + # Convert back to seq_first if needed + if not self.batch_first: + attn_output = attn_output.transpose(0, 1) # (L, N, E) + + return attn_output, None + + +def replace_mha_with_export_friendly( + module: nn.Module, + verbose: bool = False, +) -> int: + """Replace all nn.MultiheadAttention modules with ExportFriendlyMultiheadAttention. + + This enables torch.export with dynamic shapes by bypassing + F.multi_head_attention_forward which has guards on sequence length. + + Recursively traverses the module tree and replaces: + - nn.MultiheadAttention + - MultiheadAttentionWrapper (subclass of nn.MultiheadAttention) + + Args: + module: Root module to process (modified in-place) + verbose: If True, print each replacement + + Returns: + Number of modules replaced + """ + count = 0 + + for name, child in list(module.named_children()): + if isinstance(child, nn.MultiheadAttention): + # Replace with export-friendly version + new_child = ExportFriendlyMultiheadAttention.from_nn_mha(child) + setattr(module, name, new_child) + count += 1 + if verbose: + print( + f"Replaced {name}: {type(child).__name__} -> ExportFriendlyMultiheadAttention" + ) + else: + # Recurse into children + count += replace_mha_with_export_friendly(child, verbose=verbose) + + return count + + class DotProductScoring(torch.nn.Module): def __init__( self, @@ -151,11 +389,7 @@ def __init__( def _reset_parameters(self): for n, p in self.named_parameters(): if p.dim() > 1: - if ( - "box_embed" not in n - and "query_embed" not in n - and "reference_points" not in n - ): + if "box_embed" not in n and "query_embed" not in n and "reference_points" not in n: nn.init.xavier_uniform_(p) @@ -249,26 +483,18 @@ def gen_sineembed_for_position(pos_tensor, num_feats=256): y_embed = pos_tensor[:, :, 1] * scale pos_x = x_embed[:, :, None] / dim_t pos_y = y_embed[:, :, None] / dim_t - pos_x = torch.stack( - (pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3 - ).flatten(2) - pos_y = torch.stack( - (pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3 - ).flatten(2) + pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) + pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2) if pos_tensor.size(-1) == 2: pos = torch.cat((pos_y, pos_x), dim=2) elif pos_tensor.size(-1) == 4: w_embed = pos_tensor[:, :, 2] * scale pos_w = w_embed[:, :, None] / dim_t - pos_w = torch.stack( - (pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3 - ).flatten(2) + pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2) h_embed = pos_tensor[:, :, 3] * scale pos_h = h_embed[:, :, None] / dim_t - pos_h = torch.stack( - (pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3 - ).flatten(2) + pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2) pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2) else: @@ -322,11 +548,9 @@ def __init__( loss_stages: Optional[List[int]] = None, ): if output is not None: - assert ( - isinstance(output, list) - and len(output) > 0 - and isinstance(output[0], list) - ), "Expected output to be a list of lists" + assert isinstance(output, list) and len(output) > 0 and isinstance(output[0], list), ( + "Expected output to be a list of lists" + ) self.output = output else: self.output = [] @@ -379,9 +603,7 @@ class _IterationMode(AbstractContextManager): This class is used internally by the SAM3Output.iteration_mode method. """ - def __init__( - self, model_output: "SAM3Output", iter_mode: "SAM3Output.IterMode" - ): + def __init__(self, model_output: "SAM3Output", iter_mode: "SAM3Output.IterMode"): self._model_output = model_output self._orig_iter_mode = model_output.iter_mode self._new_iter_mode = iter_mode @@ -397,9 +619,7 @@ def __exit__(self, exc_type, exc_value, traceback): return super().__exit__(exc_type, exc_value, traceback) @staticmethod - def iteration_mode( - model_output: "SAM3Output", iter_mode: IterMode - ) -> _IterationMode: + def iteration_mode(model_output: "SAM3Output", iter_mode: IterMode) -> _IterationMode: """ Returns a context manager that allows you to temporarily change the iteration mode of the SAM3Output object. Args: @@ -411,9 +631,7 @@ def iteration_mode( return SAM3Output._IterationMode(model_output=model_output, iter_mode=iter_mode) def append(self, item: list): - assert isinstance(item, list), ( - f"Only list items are supported. Got {type(item)}" - ) + assert isinstance(item, list), f"Only list items are supported. Got {type(item)}" self.output.append(item) def __repr__(self): From c956a9af011a72f8ede35272a08ce9e3fc6ec9d0 Mon Sep 17 00:00:00 2001 From: Ryan Avery Date: Wed, 21 Jan 2026 23:18:50 -0800 Subject: [PATCH 3/3] Remove H/W keyed caches for torch.export dynamic shape support During torch.export with dynamic H/W dimensions, SymInt values cannot be used as dict keys. These caches prevented dynamic shape export. Changes: - position_encoding.py: Remove (H, W) keyed cache in forward() - decoder.py: Remove coord_cache dict lookup in _get_rpb_matrix() The computation is cheap (just torch.arange) so always computing is acceptable for export use cases. --- sam3/model/decoder.py | 70 ++++++++------------------------- sam3/model/position_encoding.py | 18 +++------ 2 files changed, 23 insertions(+), 65 deletions(-) diff --git a/sam3/model/decoder.py b/sam3/model/decoder.py index 7a204be..cf52288 100644 --- a/sam3/model/decoder.py +++ b/sam3/model/decoder.py @@ -124,9 +124,7 @@ def forward( tgt_query_pos_o2o = torch.cat( [torch.zeros_like(presence_token), tgt_query_pos_o2o], dim=0 ) - tgt_query_pos = torch.cat( - [torch.zeros_like(presence_token), tgt_query_pos], dim=0 - ) + tgt_query_pos = torch.cat([torch.zeros_like(presence_token), tgt_query_pos], dim=0) q = k = self.with_pos_embed(tgt_o2o, tgt_query_pos_o2o) tgt2 = self.self_attn(q, k, tgt_o2o, attn_mask=self_attn_mask)[0] @@ -274,9 +272,7 @@ def __init__( if resolution is not None and stride is not None: feat_size = resolution // stride - coords_h, coords_w = self._get_coords( - feat_size, feat_size, device="cuda" - ) + coords_h, coords_w = self._get_coords(feat_size, feat_size, device="cuda") self.compilable_cord_cache = (coords_h, coords_w) self.compilable_stored_size = (feat_size, feat_size) @@ -329,27 +325,11 @@ def _get_rpb_matrix(self, reference_boxes, feat_size): H, W = feat_size boxes_xyxy = box_cxcywh_to_xyxy(reference_boxes).transpose(0, 1) bs, num_queries, _ = boxes_xyxy.shape - if self.compilable_cord_cache is None: - self.compilable_cord_cache = self._get_coords(H, W, reference_boxes.device) - self.compilable_stored_size = (H, W) - - if torch.compiler.is_dynamo_compiling() or self.compilable_stored_size == ( - H, - W, - ): - # good, hitting the cache, will be compilable - coords_h, coords_w = self.compilable_cord_cache - else: - # cache miss, will create compilation issue - # In case we're not compiling, we'll still rely on the dict-based cache - if feat_size not in self.coord_cache: - self.coord_cache[feat_size] = self._get_coords( - H, W, reference_boxes.device - ) - coords_h, coords_w = self.coord_cache[feat_size] - assert coords_h.shape == (H,) - assert coords_w.shape == (W,) + # NOTE: Cache removed for torch.export dynamic shape support. + # With symbolic H/W (SymInt), cache keys are unhashable. + # Always compute - cheap operation (just torch.arange). + coords_h, coords_w = self._get_coords(H, W, reference_boxes.device) deltas_y = coords_h.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 1:4:2] deltas_y = deltas_y.view(bs, num_queries, -1, 2) @@ -359,16 +339,12 @@ def _get_rpb_matrix(self, reference_boxes, feat_size): if self.boxRPB in ["log", "both"]: deltas_x_log = deltas_x * 8 # normalize to -8, 8 deltas_x_log = ( - torch.sign(deltas_x_log) - * torch.log2(torch.abs(deltas_x_log) + 1.0) - / np.log2(8) + torch.sign(deltas_x_log) * torch.log2(torch.abs(deltas_x_log) + 1.0) / np.log2(8) ) deltas_y_log = deltas_y * 8 # normalize to -8, 8 deltas_y_log = ( - torch.sign(deltas_y_log) - * torch.log2(torch.abs(deltas_y_log) + 1.0) - / np.log2(8) + torch.sign(deltas_y_log) * torch.log2(torch.abs(deltas_y_log) + 1.0) / np.log2(8) ) if self.boxRPB == "log": deltas_x = deltas_x_log @@ -388,19 +364,17 @@ def _get_rpb_matrix(self, reference_boxes, feat_size): act_ckpt_enable=self.training and self.use_act_checkpoint, ) # bs, num_queries, H, n_heads - if not torch.compiler.is_dynamo_compiling(): + if not torch.compiler.is_compiling(): assert deltas_x.shape[:3] == (bs, num_queries, W) assert deltas_y.shape[:3] == (bs, num_queries, H) - B = deltas_y.unsqueeze(3) + deltas_x.unsqueeze( - 2 - ) # bs, num_queries, H, W, n_heads - if not torch.compiler.is_dynamo_compiling(): + B = deltas_y.unsqueeze(3) + deltas_x.unsqueeze(2) # bs, num_queries, H, W, n_heads + if not torch.compiler.is_compiling(): assert B.shape[:4] == (bs, num_queries, H, W) B = B.flatten(2, 3) # bs, num_queries, H*W, n_heads B = B.permute(0, 3, 1, 2) # bs, n_heads, num_queries, H*W B = B.contiguous() # memeff attn likes ordered strides - if not torch.compiler.is_dynamo_compiling(): + if not torch.compiler.is_compiling(): assert B.shape[2:] == (num_queries, H * W) return B @@ -456,10 +430,7 @@ def forward( if reference_boxes is not None: assert (reference_boxes.shape[0] == self.num_queries) or ( self.use_instance_query - and ( - reference_boxes.shape[0] - == self.instance_query_embed.num_embeddings - ) + and (reference_boxes.shape[0] == self.instance_query_embed.num_embeddings) ) reference_boxes = reference_boxes.repeat(2, 1, 1) @@ -499,8 +470,7 @@ def forward( for layer_idx, layer in enumerate(self.layers): reference_points_input = ( - reference_boxes[:, :, None] - * torch.cat([valid_ratios, valid_ratios], -1)[None, :] + reference_boxes[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[None, :] ) # nq, bs, nlevel, 4 query_sine_embed = gen_sineembed_for_position( @@ -511,9 +481,7 @@ def forward( query_pos = self.ref_point_head(query_sine_embed) # nq, bs, d_model if self.boxRPB != "none" and reference_boxes is not None: - assert spatial_shapes.shape[0] == 1, ( - "only single scale support implemented" - ) + assert spatial_shapes.shape[0] == 1, "only single scale support implemented" memory_mask = self._get_rpb_matrix( reference_boxes, (spatial_shapes[0, 0], spatial_shapes[0, 1]), @@ -591,9 +559,7 @@ def forward( presence_feats = presence_out.clone() if not self.compiled and self.compile_mode is not None: - self.forward = torch.compile( - self.forward, mode=self.compile_mode, fullgraph=True - ) + self.forward = torch.compile(self.forward, mode=self.compile_mode, fullgraph=True) self.compiled = True return ( @@ -671,9 +637,7 @@ def forward( src_pos[0], ) - assert src.shape[1] == prompt.shape[1], ( - "Batch size must be the same for src and prompt" - ) + assert src.shape[1] == prompt.shape[1], "Batch size must be the same for src and prompt" output = src diff --git a/sam3/model/position_encoding.py b/sam3/model/position_encoding.py index a6a1266..c42c759 100644 --- a/sam3/model/position_encoding.py +++ b/sam3/model/position_encoding.py @@ -62,12 +62,8 @@ def _encode_xy(self, x, y): pos_x = x_embed[:, None] / dim_t pos_y = y_embed[:, None] / dim_t - pos_x = torch.stack( - (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2 - ).flatten(1) - pos_y = torch.stack( - (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2 - ).flatten(1) + pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1) + pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1) return pos_x, pos_y @torch.no_grad() @@ -89,10 +85,10 @@ def encode_points(self, x, y, labels): @torch.no_grad() def forward(self, x): - cache_key = None - cache_key = (x.shape[-2], x.shape[-1]) - if cache_key in self.cache: - return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1) + # NOTE: Cache removed for torch.export dynamic shape support. + # With symbolic H/W, cache_key = (x.shape[-2], x.shape[-1]) contains SymInt + # which cannot be used as dict keys. Always compute instead. + # The computation is cheap (creates 1D aranges and broadcasts). y_embed = ( torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device) .view(1, -1, 1) @@ -121,6 +117,4 @@ def forward(self, x): (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 ).flatten(3) pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) - if cache_key is not None: - self.cache[cache_key] = pos[0] return pos