Int4 #281

1 change: 1 addition & 0 deletions .gitignore
@@ -7,4 +7,5 @@ poetry.lock
dist
clients/python/moondream/torch
wandb/
bitblas_cache/
moondream_finetune.safetensors
2 changes: 2 additions & 0 deletions moondream/torch/config.py
@@ -12,6 +12,8 @@ class TextConfig:
n_heads: int = 32
n_kv_heads: int = 32
prefix_attn: int = 730
group_size: int = 128
cache_dir: str = "./bitblas_cache"


@dataclass(frozen=True)
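The two new `TextConfig` fields feed the bitblas setup in `text.py` further down: `group_size` is the grouping granularity for the uint4 weight scales, and `cache_dir` is where tuned kernels get persisted (hence the new `.gitignore` entry). A minimal sketch of overriding them; the `/tmp` path is purely illustrative:

```python
import dataclasses

from moondream.torch.config import TextConfig

# TextConfig is a frozen dataclass, so overrides go through dataclasses.replace.
text_config = dataclasses.replace(
    TextConfig(),
    group_size=128,                  # grouping granularity for the uint4 weight scales
    cache_dir="/tmp/bitblas_cache",  # where bitblas persists tuned kernels
)
print(text_config.group_size, text_config.cache_dir)
```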
65 changes: 65 additions & 0 deletions moondream/torch/layers.py
@@ -1,8 +1,12 @@
from dataclasses import dataclass
from typing import Literal

import bitblas
from bitblas.cache import OperatorCache

import torch
from torch.nn import functional as F
import torch.nn as nn


def gelu_approx(x):
@@ -15,6 +19,66 @@ class LinearWeights:
bias: torch.Tensor


class Linear(nn.Module):
"""
Linear layer with optional bitblas quantization.
If dtype is torch.int8, a bitblas uint4 (W4A16) quantized linear layer is used;
otherwise a standard float16 nn.Linear layer is used.
"""

def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
operator_cache: OperatorCache = None,
cache_dir: str = None,
group_size: int = 128,
):
super().__init__()

if dtype == torch.int8:
self.linear = bitblas.Linear(
in_features=in_features,
out_features=out_features,
bias=bias,
with_zeros=True,
zeros_mode="original",
with_scaling=True,
A_dtype="float16",
W_dtype="uint4",
accum_dtype="float16",
out_dtype="float16",
fast_decoding=True,
enable_tuning=True,
operator_cache=operator_cache,
database_path=cache_dir,
group_size=group_size,
)
else:
self.linear = nn.Linear(
in_features=in_features,
out_features=out_features,
bias=bias,
dtype=torch.float16,
)

def forward(self, x):
return self.linear(x)

@property
def weight(self) -> torch.Tensor:
try:
return self.linear.weight
except AttributeError:
return self.linear.qweight

@property
def bias(self) -> torch.Tensor:
return self.linear.bias


def linear(x: torch.Tensor, w: LinearWeights) -> torch.Tensor:
return F.linear(x, w.weight, w.bias)

@@ -37,6 +101,7 @@ class MLPWeights:


def mlp(x: torch.Tensor, w: MLPWeights) -> torch.Tensor:

x = w.fc1(x)
x = gelu_approx(x)
x = w.fc2(x)
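A hedged usage sketch for the new `Linear` wrapper: `dtype=torch.int8` acts as a flag that selects the bitblas uint4/float16 path, while any other dtype falls back to a plain float16 `nn.Linear`. The feature sizes are illustrative, and the quantized constructor assumes a CUDA device with bitblas installed, since it triggers kernel tuning:

```python
import torch
from bitblas.cache import OperatorCache

from moondream.torch.layers import Linear

# int4 path: weights held as uint4, activations/outputs in float16 (matching the
# bitblas.Linear arguments above). Tuning on first construction can take a while;
# tuned operators are persisted under cache_dir.
quantized = Linear(
    in_features=2048,                # illustrative sizes, not taken from the model config
    out_features=2048,
    dtype=torch.int8,                # flag value that selects the bitblas path
    operator_cache=OperatorCache(),  # shared across layers to reuse tuned kernels
    cache_dir="./bitblas_cache",
    group_size=128,
)

# Fallback path: a plain float16 nn.Linear.
full_precision = Linear(in_features=2048, out_features=2048, dtype=torch.float16)

x = torch.randn(1, 16, 2048, dtype=torch.float16, device="cuda")
print(full_precision.cuda()(x).shape)  # torch.Size([1, 16, 2048])
# quantized(x) computes the same matmul through the tuned uint4 kernel once real
# weights have been loaded into it (see load_weights_into_model in weights.py).
```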
26 changes: 15 additions & 11 deletions moondream/torch/moondream.py
@@ -66,12 +66,16 @@ class MoondreamModel(nn.Module):
def __init__(self, config: MoondreamConfig, dtype=torch.float16, setup_caches=True):
super().__init__()
self.config = config
self.dtype = dtype
self.setup_caches_flag = setup_caches

self.tokenizer = Tokenizer.from_pretrained(
"vikhyatk/moondream2", revision="2025-01-09"
)

self.vision = build_vision_model(config.vision, dtype)
self.text = build_text_model(config.text, dtype)

self.text = None

# Region Model
self.region = nn.ModuleDict(
@@ -125,11 +129,11 @@ def __init__(self, config: MoondreamConfig, dtype=torch.float16, setup_caches=Tr
attn_mask[..., :prefix_attn_len, :prefix_attn_len] = 1
self.register_buffer("attn_mask", attn_mask, persistent=False)

# Initialize KV caches.
if setup_caches:
self._setup_caches()

def _setup_caches(self):
"""Setup KV caches for the text model"""
if self.text is None:
return # Can't set up caches without text model

c = self.config.text
for b in self.text.blocks:
b.kv_cache = KVCache(
@@ -163,15 +167,14 @@ def _decode_one_tok(

def compile(self):
# TODO: vision_projection is not being compiled
self._vis_enc = torch.compile(self._vis_enc, fullgraph=True)
self._prefill = torch.compile(self._prefill, fullgraph=True)
self._decode_one_tok = torch.compile(
self._decode_one_tok, fullgraph=True, mode="reduce-overhead"
self._vis_enc = torch.compile(
self._vis_enc, fullgraph=False, mode="reduce-overhead"
)
# self._prefill = torch.compile(self._prefill)
# self._decode_one_tok = torch.compile(self._decode_one_tok)

def _run_vision_encoder(self, image: Image.Image) -> torch.Tensor:
all_crops, tiling = prepare_crops(image, self.config.vision, device=self.device)

torch._dynamo.mark_dynamic(all_crops, 0)

outputs = self._vis_enc(all_crops)
@@ -201,6 +204,7 @@ def encode_image(self, image: Union[Image.Image, EncodedImage]) -> EncodedImage:

# Run through text model in addition to the vision encoder, to minimize
# re-computation if multiple queries are performed on this image.

with torch.inference_mode():
img_emb = self._run_vision_encoder(image)
bos_emb = text_encoder(
@@ -236,10 +240,10 @@ def _apply_top_p(self, probs: torch.Tensor, top_p: float):
def _prefill_prompt(
self, prompt_tokens: torch.Tensor, pos: int, temperature: float, top_p: float
):

with torch.inference_mode():
prompt_emb = text_encoder(prompt_tokens, self.text)
torch._dynamo.mark_dynamic(prompt_emb, 1)

mask = self.attn_mask[:, :, pos : pos + prompt_emb.size(1), :]
pos_ids = torch.arange(pos, pos + prompt_emb.size(1), dtype=torch.long)
hidden = self._prefill(prompt_emb, mask, pos_ids)
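Since `__init__` now leaves `self.text = None` and `_setup_caches` bails out until a text model exists, the text stack is presumably attached after construction. A sketch of that flow, under the assumption that weight loading is handled separately (e.g. by `load_weights_into_model`):

```python
import torch

from moondream.torch.moondream import MoondreamModel, MoondreamConfig
from moondream.torch.text import build_text_model

config = MoondreamConfig()
model = MoondreamModel(config, setup_caches=False)

# Build the text stack with int4 (bitblas) linears and float16 everywhere else,
# then attach it; nn.Module.__setattr__ registers it as a submodule.
model.text = build_text_model(
    config.text,
    linear_dtype=torch.int8,
    layernorm_dtype=torch.float16,
)

# KV caches depend on the text blocks, so they can only be set up once text exists.
model._setup_caches()
```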
12 changes: 10 additions & 2 deletions moondream/torch/sample.py
@@ -5,11 +5,17 @@

from PIL import Image, ImageDraw
from tqdm import tqdm
import logging
import bitblas

bitblas.logger.setLevel("FATAL")

from .weights import load_weights_into_model
from .moondream import MoondreamModel, MoondreamConfig
import time

if __name__ == "__main__":
start = time.time()
parser = argparse.ArgumentParser()
parser.add_argument("--image", "-i", type=str, required=True)
parser.add_argument("--prompt", "-p", type=str, required=True)
@@ -32,17 +38,20 @@
config = MoondreamConfig.from_dict(config)
else:
config = MoondreamConfig()

model = MoondreamModel(config)
load_weights_into_model(args.model, model)
model = model.to(device)

# Encode image.
image_path = args.image
if not os.path.exists(image_path):
raise FileNotFoundError(f"Image not found at {image_path}")
image = Image.open(image_path)
model = model.to(device)

if not args.benchmark:

# model.compile()
encoded_image = model.encode_image(image)

# Short caption
@@ -142,6 +151,5 @@
print("\nQuery Speed (tokens/sec):")
print(f" Mean: {sum(query_speeds)/len(query_speeds):.2f}")
print(f" Min: {min(query_speeds):.2f}")
print(f" Max: {max(query_speeds):.2f}")
else:
raise ValueError("To run benchmarks, make sure you are on a CUDA device")
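The same flow as the script above, minus the argparse plumbing; both file paths are placeholders:

```python
import torch
from PIL import Image

from moondream.torch.moondream import MoondreamModel, MoondreamConfig
from moondream.torch.weights import load_weights_into_model

device = "cuda" if torch.cuda.is_available() else "cpu"

config = MoondreamConfig()
model = MoondreamModel(config)
load_weights_into_model("moondream_latest.safetensors", model)  # placeholder path
model = model.to(device)

image = Image.open("example.jpg")  # placeholder path
encoded_image = model.encode_image(image)
# Captioning and querying then reuse encoded_image, as in the rest of the script above.
```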
57 changes: 41 additions & 16 deletions moondream/torch/text.py
@@ -2,8 +2,9 @@
import torch.nn as nn

from torch.nn import functional as F
from bitblas.cache import OperatorCache

from .layers import layer_norm, mlp
from .layers import layer_norm, mlp, Linear
from .rope import apply_rotary_emb, precompute_freqs_cis
from .config import TextConfig

@@ -26,6 +27,7 @@ def attn(
head_dim = d_model // n_heads

qkv_out = w.qkv(x) # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim)

q_dim = n_heads * head_dim
kv_dim = n_kv_heads * head_dim

@@ -139,6 +141,7 @@ def text_decoder(
n_kv_heads=config.n_kv_heads,
position_ids=position_ids,
)

l_mlp = mlp(l_in, block.mlp)
x = x + l_attn + l_mlp

@@ -158,44 +161,66 @@ def _lm_head(hidden_BTC: torch.Tensor, w: nn.Module):
return logits


def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
def build_text_model(
config: TextConfig,
linear_dtype: torch.dtype = torch.float16,
layernorm_dtype: torch.dtype = torch.float16,
) -> nn.Module:
# Note: layernorm_dtype is applied to the layer norms, lm_head, and wte, not only the layer norms.
qkv_dim = int(config.dim * (1 + 2 * config.n_kv_heads / config.n_heads))

operator_cache = None
cache_dir = None
group_size = None
if linear_dtype == torch.int8:

operator_cache = OperatorCache()
cache_dir = config.cache_dir
group_size = config.group_size

def create_linear(in_features, out_features, dtype=linear_dtype):
# Factory for Linear layers so the quantization and cache arguments don't have to be repeated for every layer.
return Linear(
in_features=in_features,
out_features=out_features,
dtype=dtype,
operator_cache=operator_cache,
cache_dir=cache_dir,
group_size=group_size,
)

text = nn.ModuleDict(
{
"blocks": nn.ModuleList(
[
nn.ModuleDict(
{
"ln": nn.LayerNorm(config.dim, dtype=dtype),
"ln": nn.LayerNorm(config.dim, dtype=layernorm_dtype),
"attn": nn.ModuleDict(
{
"qkv": nn.Linear(config.dim, qkv_dim, dtype=dtype),
"proj": nn.Linear(
config.dim, config.dim, dtype=dtype
),
"qkv": create_linear(config.dim, qkv_dim),
"proj": create_linear(config.dim, config.dim),
}
),
"mlp": nn.ModuleDict(
{
"fc1": nn.Linear(
config.dim, config.ff_dim, dtype=dtype
),
"fc2": nn.Linear(
config.ff_dim, config.dim, dtype=dtype
),
"fc1": create_linear(config.dim, config.ff_dim),
"fc2": create_linear(config.ff_dim, config.dim),
}
),
}
)
for _ in range(config.n_layers)
]
),
"post_ln": nn.LayerNorm(config.dim, dtype=dtype),
"lm_head": nn.Linear(config.dim, config.vocab_size, dtype=dtype),
"post_ln": nn.LayerNorm(config.dim, dtype=layernorm_dtype),
"lm_head": nn.Linear(config.dim, config.vocab_size, dtype=layernorm_dtype),
}
)
text.wte = nn.Parameter(torch.empty(config.vocab_size, config.dim, dtype=dtype))
text.wte = nn.Parameter(
torch.empty(config.vocab_size, config.dim, dtype=layernorm_dtype)
)
text.register_buffer(
"freqs_cis",
precompute_freqs_cis(config.dim // (2 * config.n_heads), config.max_context),
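To see what the `linear_dtype` / `layernorm_dtype` split produces: block linears go through the `Linear` wrapper from `layers.py`, while the layer norms, `lm_head`, and `wte` stay in `layernorm_dtype`. A small inspection sketch (it assumes bitblas is installed, since `layers.py` imports it unconditionally; switching `linear_dtype` to `torch.int8` selects the quantized path and additionally needs a CUDA device for tuning):

```python
import torch

from moondream.torch.config import TextConfig
from moondream.torch.text import build_text_model

text = build_text_model(
    TextConfig(),
    linear_dtype=torch.float16,   # use torch.int8 for the bitblas uint4 path
    layernorm_dtype=torch.float16,
)

block = text.blocks[0]
print(type(block.attn.qkv).__name__)  # Linear (the wrapper from layers.py)
print(block.ln.weight.dtype)          # torch.float16 (layernorm_dtype)
print(text.lm_head.weight.dtype)      # torch.float16 (layernorm_dtype)
print(text.wte.dtype)                 # torch.float16 (layernorm_dtype)
```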