Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 17 additions & 53 deletions sam3/model/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,7 @@ def forward(
tgt_query_pos_o2o = torch.cat(
[torch.zeros_like(presence_token), tgt_query_pos_o2o], dim=0
)
tgt_query_pos = torch.cat(
[torch.zeros_like(presence_token), tgt_query_pos], dim=0
)
tgt_query_pos = torch.cat([torch.zeros_like(presence_token), tgt_query_pos], dim=0)

q = k = self.with_pos_embed(tgt_o2o, tgt_query_pos_o2o)
tgt2 = self.self_attn(q, k, tgt_o2o, attn_mask=self_attn_mask)[0]
Expand Down Expand Up @@ -274,9 +272,7 @@ def __init__(

if resolution is not None and stride is not None:
feat_size = resolution // stride
coords_h, coords_w = self._get_coords(
feat_size, feat_size, device="cuda"
)
coords_h, coords_w = self._get_coords(feat_size, feat_size, device="cuda")
self.compilable_cord_cache = (coords_h, coords_w)
self.compilable_stored_size = (feat_size, feat_size)

Expand Down Expand Up @@ -329,27 +325,11 @@ def _get_rpb_matrix(self, reference_boxes, feat_size):
H, W = feat_size
boxes_xyxy = box_cxcywh_to_xyxy(reference_boxes).transpose(0, 1)
bs, num_queries, _ = boxes_xyxy.shape
if self.compilable_cord_cache is None:
self.compilable_cord_cache = self._get_coords(H, W, reference_boxes.device)
self.compilable_stored_size = (H, W)

if torch.compiler.is_dynamo_compiling() or self.compilable_stored_size == (
H,
W,
):
# good, hitting the cache, will be compilable
coords_h, coords_w = self.compilable_cord_cache
else:
# cache miss, will create compilation issue
# In case we're not compiling, we'll still rely on the dict-based cache
if feat_size not in self.coord_cache:
self.coord_cache[feat_size] = self._get_coords(
H, W, reference_boxes.device
)
coords_h, coords_w = self.coord_cache[feat_size]

assert coords_h.shape == (H,)
assert coords_w.shape == (W,)
# NOTE: Cache removed for torch.export dynamic shape support.
# With symbolic H/W (SymInt), cache keys are unhashable.
# Always compute - cheap operation (just torch.arange).
coords_h, coords_w = self._get_coords(H, W, reference_boxes.device)

deltas_y = coords_h.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 1:4:2]
deltas_y = deltas_y.view(bs, num_queries, -1, 2)
Expand All @@ -359,16 +339,12 @@ def _get_rpb_matrix(self, reference_boxes, feat_size):
if self.boxRPB in ["log", "both"]:
deltas_x_log = deltas_x * 8 # normalize to -8, 8
deltas_x_log = (
torch.sign(deltas_x_log)
* torch.log2(torch.abs(deltas_x_log) + 1.0)
/ np.log2(8)
torch.sign(deltas_x_log) * torch.log2(torch.abs(deltas_x_log) + 1.0) / np.log2(8)
)

deltas_y_log = deltas_y * 8 # normalize to -8, 8
deltas_y_log = (
torch.sign(deltas_y_log)
* torch.log2(torch.abs(deltas_y_log) + 1.0)
/ np.log2(8)
torch.sign(deltas_y_log) * torch.log2(torch.abs(deltas_y_log) + 1.0) / np.log2(8)
)
if self.boxRPB == "log":
deltas_x = deltas_x_log
Expand All @@ -388,19 +364,17 @@ def _get_rpb_matrix(self, reference_boxes, feat_size):
act_ckpt_enable=self.training and self.use_act_checkpoint,
) # bs, num_queries, H, n_heads

if not torch.compiler.is_dynamo_compiling():
if not torch.compiler.is_compiling():
assert deltas_x.shape[:3] == (bs, num_queries, W)
assert deltas_y.shape[:3] == (bs, num_queries, H)

B = deltas_y.unsqueeze(3) + deltas_x.unsqueeze(
2
) # bs, num_queries, H, W, n_heads
if not torch.compiler.is_dynamo_compiling():
B = deltas_y.unsqueeze(3) + deltas_x.unsqueeze(2) # bs, num_queries, H, W, n_heads
if not torch.compiler.is_compiling():
assert B.shape[:4] == (bs, num_queries, H, W)
B = B.flatten(2, 3) # bs, num_queries, H*W, n_heads
B = B.permute(0, 3, 1, 2) # bs, n_heads, num_queries, H*W
B = B.contiguous() # memeff attn likes ordered strides
if not torch.compiler.is_dynamo_compiling():
if not torch.compiler.is_compiling():
assert B.shape[2:] == (num_queries, H * W)
return B

Expand Down Expand Up @@ -456,10 +430,7 @@ def forward(
if reference_boxes is not None:
assert (reference_boxes.shape[0] == self.num_queries) or (
self.use_instance_query
and (
reference_boxes.shape[0]
== self.instance_query_embed.num_embeddings
)
and (reference_boxes.shape[0] == self.instance_query_embed.num_embeddings)
)
reference_boxes = reference_boxes.repeat(2, 1, 1)

Expand Down Expand Up @@ -499,8 +470,7 @@ def forward(

for layer_idx, layer in enumerate(self.layers):
reference_points_input = (
reference_boxes[:, :, None]
* torch.cat([valid_ratios, valid_ratios], -1)[None, :]
reference_boxes[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[None, :]
) # nq, bs, nlevel, 4

query_sine_embed = gen_sineembed_for_position(
Expand All @@ -511,9 +481,7 @@ def forward(
query_pos = self.ref_point_head(query_sine_embed) # nq, bs, d_model

if self.boxRPB != "none" and reference_boxes is not None:
assert spatial_shapes.shape[0] == 1, (
"only single scale support implemented"
)
assert spatial_shapes.shape[0] == 1, "only single scale support implemented"
memory_mask = self._get_rpb_matrix(
reference_boxes,
(spatial_shapes[0, 0], spatial_shapes[0, 1]),
Expand Down Expand Up @@ -591,9 +559,7 @@ def forward(
presence_feats = presence_out.clone()

if not self.compiled and self.compile_mode is not None:
self.forward = torch.compile(
self.forward, mode=self.compile_mode, fullgraph=True
)
self.forward = torch.compile(self.forward, mode=self.compile_mode, fullgraph=True)
self.compiled = True

return (
Expand Down Expand Up @@ -671,9 +637,7 @@ def forward(
src_pos[0],
)

assert src.shape[1] == prompt.shape[1], (
"Batch size must be the same for src and prompt"
)
assert src.shape[1] == prompt.shape[1], "Batch size must be the same for src and prompt"

output = src

Expand Down
6 changes: 3 additions & 3 deletions sam3/model/geometry_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,9 +644,9 @@ def _encode_boxes(self, boxes, boxes_mask, boxes_labels, img_feats):
# boxes are [Num_boxes, bs, 4], normalized in [0, 1]
# We need to denormalize, and convert to [x, y, x, y]
boxes_xyxy = box_cxcywh_to_xyxy(boxes)
scale = torch.tensor([W, H, W, H], dtype=boxes_xyxy.dtype)
scale = scale.pin_memory().to(device=boxes_xyxy.device, non_blocking=True)
scale = scale.view(1, 1, 4)
scale = torch.tensor(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removes memory pinning for cpu export. see #1

[W, H, W, H], dtype=boxes_xyxy.dtype, device=boxes_xyxy.device
).view(1, 1, 4)
boxes_xyxy = boxes_xyxy * scale
sampled = torchvision.ops.roi_align(
img_feats, boxes_xyxy.float().transpose(0, 1).unbind(0), self.roi_size
Expand Down
Loading