Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,15 @@ dmypy.json
# Model weights and checkpoints
*.pth
*.pt
*.pt2
*.bin
*.ckpt
*.safetensors
weights/
checkpoints/
sam3_logs/
artifacts/
tests/export/export_logs/

# Data files
*.h5
Expand Down
Binary file added assets/images/cat_dog.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
22 changes: 21 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ name = "sam3"
dynamic = ["version"]
description = "SAM3 (Segment Anything Model 3) implementation"
readme = "README.md"
requires-python = ">=3.8"
requires-python = ">=3.9,<3.13"
license = {file = "LICENSE"}
authors = [
{name = "Meta AI Research"}
Expand All @@ -33,6 +33,8 @@ dependencies = [
"iopath>=0.1.10",
"typing_extensions",
"huggingface_hub",
"einops",
"psutil",
]

[project.optional-dependencies]
Expand Down Expand Up @@ -92,6 +94,23 @@ sam3 = ["assets/*.txt.gz"]
[tool.setuptools.dynamic]
version = {attr = "sam3.__version__"}

[dependency-groups]
dev = [
"pytest",
"pytest-cov",
"black==24.2.0",
"ufmt==2.8.0",
"ruff-api==0.1.0",
"usort==1.0.2",
"gitpython==3.1.31",
"yt-dlp",
"pandas",
"opencv-python",
"pycocotools",
"numba",
"python-rapidjson",
]

[tool.black]
line-length = 88
target-version = ['py38', 'py39', 'py310', 'py311', 'py312']
Expand Down Expand Up @@ -133,3 +152,4 @@ testpaths = ["tests"]
python_files = "test_*.py"
python_classes = "Test*"
python_functions = "test_*"
markers = ["slow: long-running export and artifact tests"]
103 changes: 73 additions & 30 deletions sam3/model/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import numpy as np
import torch
import torch.nn.functional as F
from sam3.sam.transformer import RoPEAttention
from torch import nn, Tensor
from torchvision.ops.roi_align import RoIAlign
Expand Down Expand Up @@ -151,24 +152,40 @@ def forward(
tgt = tgt + self.catext_dropout(tgt2)
tgt = self.catext_norm(tgt)

if presence_token is not None:
presence_token_mask = torch.zeros_like(cross_attn_mask[:, :1, :])
cross_attn_mask = torch.cat(
[presence_token_mask, cross_attn_mask], dim=1
) # (bs*nheads, 1+nq, hw)
if presence_token is not None and cross_attn_mask is not None:
if cross_attn_mask.dim() == 4:
presence_token_mask = torch.zeros_like(cross_attn_mask[:, :, :1, :])
cross_attn_mask = torch.cat(
[presence_token_mask, cross_attn_mask], dim=2
) # (bs, nheads, 1+nq, hw)
else:
presence_token_mask = torch.zeros_like(cross_attn_mask[:, :1, :])
cross_attn_mask = torch.cat(
[presence_token_mask, cross_attn_mask], dim=1
) # (bs*nheads, 1+nq, hw)

# Cross attention to image
tgt2 = self.cross_attn(
query=self.with_pos_embed(tgt, tgt_query_pos),
key=self.with_pos_embed(memory, memory_pos),
value=memory,
attn_mask=cross_attn_mask,
key_padding_mask=(
memory_key_padding_mask.transpose(0, 1)
if memory_key_padding_mask is not None
else None
),
)[0]
key_padding_mask = (
memory_key_padding_mask.transpose(0, 1)
if memory_key_padding_mask is not None
else None
)
if cross_attn_mask is not None and cross_attn_mask.dim() == 4:
tgt2 = self._cross_attn_with_rpb(
query=self.with_pos_embed(tgt, tgt_query_pos),
key=self.with_pos_embed(memory, memory_pos),
value=memory,
attn_bias=cross_attn_mask,
key_padding_mask=key_padding_mask,
)
else:
tgt2 = self.cross_attn(
query=self.with_pos_embed(tgt, tgt_query_pos),
key=self.with_pos_embed(memory, memory_pos),
value=memory,
attn_mask=cross_attn_mask,
key_padding_mask=key_padding_mask,
)[0]

tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
Expand All @@ -183,6 +200,44 @@ def forward(

return tgt, presence_token_out

def _cross_attn_with_rpb(
    self,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    attn_bias: Tensor,
    key_padding_mask: Optional[Tensor],
) -> Tensor:
    """Cross-attention with a floating-point relative-position bias added to the logits.

    ``nn.MultiheadAttention``'s ``attn_mask`` argument cannot carry a full
    per-(batch, head) float bias in the shape used here, so this helper
    re-implements the attention call on top of
    ``F.scaled_dot_product_attention``, reusing the weights of
    ``self.cross_attn`` for the input and output projections.

    Args:
        query: query tensor; shape is sequence-first, ``(tgt_len, bsz, embed_dim)``
            (established by the ``tgt_len, bsz, _ = q.shape`` unpack below).
        key: key tensor, sequence-first like ``query``.
        value: value tensor, sequence-first like ``query``.
        attn_bias: additive attention bias, either
            ``(bsz, num_heads, tgt_len, src_len)`` or a flattened 3-D form
            that is reshaped below.
        key_padding_mask: optional ``(bsz, src_len)`` mask; nonzero entries
            mark padded keys and are converted to ``-inf`` bias.

    Returns:
        Attention output of shape ``(tgt_len, bsz, embed_dim)``.
    """
    mha = self.cross_attn
    assert isinstance(mha, nn.MultiheadAttention)
    # NOTE(review): F._in_projection_packed is a private torch API; it may
    # change between torch versions — confirm the pinned torch exposes it.
    q, k, v = F._in_projection_packed(
        query, key, value, mha.in_proj_weight, mha.in_proj_bias
    )
    tgt_len, bsz, _ = q.shape
    num_heads = mha.num_heads
    head_dim = mha.head_dim
    # (L, B, H*D) -> (B, H, L, D): the layout scaled_dot_product_attention expects.
    q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim).permute(1, 2, 0, 3)
    k = k.contiguous().view(-1, bsz, num_heads, head_dim).permute(1, 2, 0, 3)
    v = v.contiguous().view(-1, bsz, num_heads, head_dim).permute(1, 2, 0, 3)
    src_len = k.shape[2]
    bias = attn_bias
    if bias.dim() == 3:
        # Flattened (bs*n_heads, tgt, src) -> (bs, n_heads, tgt, src).
        bias = bias.view(bsz, num_heads, tgt_len, src_len)
    if key_padding_mask is not None:
        # Turn the padding mask into an additive bias: padded key positions
        # (nonzero entries) become -inf so softmax assigns them zero weight.
        pad = key_padding_mask[:, None, None, :].to(dtype=q.dtype)
        pad = pad.masked_fill(pad > 0, float("-inf"))
        bias = pad if bias is None else bias + pad
    attn_output = F.scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=bias,
        # Dropout only during training, mirroring nn.MultiheadAttention.
        dropout_p=mha.dropout if self.training else 0.0,
        is_causal=False,
    )
    # (B, H, L, D) -> (L, B, H*D), then apply the module's output projection.
    attn_output = attn_output.permute(2, 0, 1, 3).reshape(tgt_len, bsz, -1)
    return F.linear(attn_output, mha.out_proj.weight, mha.out_proj.bias)


class TransformerDecoder(nn.Module):
def __init__(
Expand Down Expand Up @@ -333,10 +388,7 @@ def _get_rpb_matrix(self, reference_boxes, feat_size):
self.compilable_cord_cache = self._get_coords(H, W, reference_boxes.device)
self.compilable_stored_size = (H, W)

if torch.compiler.is_dynamo_compiling() or self.compilable_stored_size == (
H,
W,
):
if torch.compiler.is_dynamo_compiling():
# good, hitting the cache, will be compilable
coords_h, coords_w = self.compilable_cord_cache
else:
Expand All @@ -348,8 +400,6 @@ def _get_rpb_matrix(self, reference_boxes, feat_size):
)
coords_h, coords_w = self.coord_cache[feat_size]

assert coords_h.shape == (H,)
assert coords_w.shape == (W,)

deltas_y = coords_h.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 1:4:2]
deltas_y = deltas_y.view(bs, num_queries, -1, 2)
Expand Down Expand Up @@ -388,20 +438,13 @@ def _get_rpb_matrix(self, reference_boxes, feat_size):
act_ckpt_enable=self.training and self.use_act_checkpoint,
) # bs, num_queries, H, n_heads

if not torch.compiler.is_dynamo_compiling():
assert deltas_x.shape[:3] == (bs, num_queries, W)
assert deltas_y.shape[:3] == (bs, num_queries, H)

B = deltas_y.unsqueeze(3) + deltas_x.unsqueeze(
2
) # bs, num_queries, H, W, n_heads
if not torch.compiler.is_dynamo_compiling():
assert B.shape[:4] == (bs, num_queries, H, W)
B = B.flatten(2, 3) # bs, num_queries, H*W, n_heads
B = B.permute(0, 3, 1, 2) # bs, n_heads, num_queries, H*W
B = B.contiguous() # memeff attn likes ordered strides
if not torch.compiler.is_dynamo_compiling():
assert B.shape[2:] == (num_queries, H * W)
return B

def forward(
Expand Down Expand Up @@ -510,6 +553,7 @@ def forward(
# conditional query
query_pos = self.ref_point_head(query_sine_embed) # nq, bs, d_model

memory_mask = None
if self.boxRPB != "none" and reference_boxes is not None:
assert spatial_shapes.shape[0] == 1, (
"only single scale support implemented"
Expand All @@ -518,7 +562,6 @@ def forward(
reference_boxes,
(spatial_shapes[0, 0], spatial_shapes[0, 1]),
)
memory_mask = memory_mask.flatten(0, 1) # (bs*n_heads, nq, H*W)
if self.training:
assert self.use_act_checkpoint, (
"Activation checkpointing not enabled in the decoder"
Expand Down
2 changes: 1 addition & 1 deletion sam3/model/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ def forward(
else None
)
else:
assert all(x.dim == 4 for x in src), (
assert all(x.dim() == 4 for x in src), (
"expected list of (bs, c, h, w) tensors"
)

Expand Down
21 changes: 19 additions & 2 deletions sam3/model/geometry_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,11 +645,28 @@ def _encode_boxes(self, boxes, boxes_mask, boxes_labels, img_feats):
# We need to denormalize, and convert to [x, y, x, y]
boxes_xyxy = box_cxcywh_to_xyxy(boxes)
scale = torch.tensor([W, H, W, H], dtype=boxes_xyxy.dtype)
scale = scale.pin_memory().to(device=boxes_xyxy.device, non_blocking=True)
if (
torch.is_tensor(scale)
and scale.device.type == "cpu"
and boxes_xyxy.device.type != "cpu"
and not torch._dynamo.is_compiling()
):
scale = scale.pin_memory()
scale = scale.to(
device=boxes_xyxy.device,
non_blocking=boxes_xyxy.device.type != "cpu" and not torch._dynamo.is_compiling(),
)
scale = scale.view(1, 1, 4)
boxes_xyxy = boxes_xyxy * scale
boxes_xyxy = boxes_xyxy.transpose(0, 1)
batch_idx = torch.arange(
bs, device=boxes_xyxy.device, dtype=boxes_xyxy.dtype
)
batch_idx = batch_idx.view(bs, 1, 1).expand(bs, n_boxes, 1)
boxes_for_roi = torch.cat([batch_idx, boxes_xyxy], dim=-1)
boxes_for_roi = boxes_for_roi.reshape(-1, 5)
sampled = torchvision.ops.roi_align(
img_feats, boxes_xyxy.float().transpose(0, 1).unbind(0), self.roi_size
img_feats, boxes_for_roi.float(), self.roi_size
)
assert list(sampled.shape) == [
bs * n_boxes,
Expand Down
16 changes: 6 additions & 10 deletions sam3/model/position_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(

def _encode_xy(self, x, y):
# The positions are expected to be normalized
assert len(x) == len(y) and x.ndim == y.ndim == 1
# torch._check(len(x) == len(y) and x.ndim == y.ndim == 1)
x_embed = x * self.scale
y_embed = y * self.scale

Expand All @@ -62,12 +62,8 @@ def _encode_xy(self, x, y):

pos_x = x_embed[:, None] / dim_t
pos_y = y_embed[:, None] / dim_t
pos_x = torch.stack(
(pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2
).flatten(1)
pos_y = torch.stack(
(pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2
).flatten(1)
pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1)
pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1)
return pos_x, pos_y

@torch.no_grad()
Expand All @@ -89,9 +85,9 @@ def encode_points(self, x, y, labels):

@torch.no_grad()
def forward(self, x):
cache_key = None
cache_key = (x.shape[-2], x.shape[-1])
if cache_key in self.cache:
use_cache = all(isinstance(dim, int) for dim in cache_key)
if use_cache and cache_key in self.cache:
return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
y_embed = (
torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
Expand Down Expand Up @@ -121,6 +117,6 @@ def forward(self, x):
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
if cache_key is not None:
if use_cache:
self.cache[cache_key] = pos[0]
return pos
Loading