diff --git a/.gitignore b/.gitignore index c5bc1ee8..24e9c03d 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,5 @@ poetry.lock dist clients/python/moondream/torch wandb/ +bitblas_cache/ moondream_finetune.safetensors diff --git a/moondream/torch/config.py b/moondream/torch/config.py index fa4ae05c..e7d18ce7 100644 --- a/moondream/torch/config.py +++ b/moondream/torch/config.py @@ -12,6 +12,8 @@ class TextConfig: n_heads: int = 32 n_kv_heads: int = 32 prefix_attn: int = 730 + group_size: int = 128 + cache_dir: str = "./bitblas_cache" @dataclass(frozen=True) diff --git a/moondream/torch/layers.py b/moondream/torch/layers.py index 4140c18f..6949c7ca 100644 --- a/moondream/torch/layers.py +++ b/moondream/torch/layers.py @@ -1,8 +1,12 @@ from dataclasses import dataclass from typing import Literal +import bitblas +from bitblas.cache import OperatorCache + import torch from torch.nn import functional as F +import torch.nn as nn def gelu_approx(x): @@ -15,6 +19,66 @@ class LinearWeights: bias: torch.Tensor +class Linear(nn.Module): + """ + Linear layer with support for bitblas quantization. + If dtype is torch.int8, it uses bitblas for quantization. + Otherwise, it uses a standard nn.Linear layer. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + operator_cache: OperatorCache = None, + cache_dir: str = None, + group_size: int = 128, + ): + super().__init__() + + if dtype == torch.int8: + self.linear = bitblas.Linear( + in_features=in_features, + out_features=out_features, + bias=bias, + with_zeros=True, + zeros_mode="original", + with_scaling=True, + A_dtype="float16", + W_dtype="uint4", + accum_dtype="float16", + out_dtype="float16", + fast_decoding=True, + enable_tuning=True, + operator_cache=operator_cache, + database_path=cache_dir, + group_size=group_size, + ) + else: + self.linear = nn.Linear( + in_features=in_features, + out_features=out_features, + bias=bias, + dtype=torch.float16, + ) + + def forward(self, x): + return self.linear(x) + + @property + def weight(self) -> torch.Tensor: + try: + return self.linear.weight + except AttributeError: + return self.linear.qweight + + @property + def bias(self) -> torch.Tensor: + return self.linear.bias + + def linear(x: torch.Tensor, w: LinearWeights) -> torch.Tensor: return F.linear(x, w.weight, w.bias) @@ -37,6 +101,7 @@ class MLPWeights: def mlp(x: torch.Tensor, w: MLPWeights) -> torch.Tensor: + x = w.fc1(x) x = gelu_approx(x) x = w.fc2(x) diff --git a/moondream/torch/moondream.py b/moondream/torch/moondream.py index ec9014ae..1190642a 100644 --- a/moondream/torch/moondream.py +++ b/moondream/torch/moondream.py @@ -66,12 +66,16 @@ class MoondreamModel(nn.Module): def __init__(self, config: MoondreamConfig, dtype=torch.float16, setup_caches=True): super().__init__() self.config = config + self.dtype = dtype + self.setup_caches_flag = setup_caches self.tokenizer = Tokenizer.from_pretrained( "vikhyatk/moondream2", revision="2025-01-09" ) + self.vision = build_vision_model(config.vision, dtype) - self.text = build_text_model(config.text, dtype) + + self.text = None # Region Model self.region = nn.ModuleDict( @@ -125,11 +129,11 @@ def __init__(self, config: MoondreamConfig, dtype=torch.float16, setup_caches=Tr attn_mask[..., :prefix_attn_len, :prefix_attn_len] = 1 self.register_buffer("attn_mask", attn_mask, persistent=False) - # Initialize KV caches. - if setup_caches: - self._setup_caches() - def _setup_caches(self): + """Setup KV caches for the text model""" + if self.text is None: + return # Can't set up caches without text model + c = self.config.text for b in self.text.blocks: b.kv_cache = KVCache( @@ -163,15 +167,14 @@ def _decode_one_tok( def compile(self): # TODO: vision_projection is not being compiled - self._vis_enc = torch.compile(self._vis_enc, fullgraph=True) - self._prefill = torch.compile(self._prefill, fullgraph=True) - self._decode_one_tok = torch.compile( - self._decode_one_tok, fullgraph=True, mode="reduce-overhead" + self._vis_enc = torch.compile( + self._vis_enc, fullgraph=False, mode="reduce-overhead" ) + # self._prefill = torch.compile(self._prefill) + # self._decode_one_tok = torch.compile(self._decode_one_tok) def _run_vision_encoder(self, image: Image.Image) -> torch.Tensor: all_crops, tiling = prepare_crops(image, self.config.vision, device=self.device) - torch._dynamo.mark_dynamic(all_crops, 0) outputs = self._vis_enc(all_crops) @@ -201,6 +204,7 @@ def encode_image(self, image: Union[Image.Image, EncodedImage]) -> EncodedImage: # Run through text model in addition to the vision encoder, to minimize # re-computation if multiple queries are performed on this image. + with torch.inference_mode(): img_emb = self._run_vision_encoder(image) bos_emb = text_encoder( @@ -236,10 +240,10 @@ def _apply_top_p(self, probs: torch.Tensor, top_p: float): def _prefill_prompt( self, prompt_tokens: torch.Tensor, pos: int, temperature: float, top_p: float ): + with torch.inference_mode(): prompt_emb = text_encoder(prompt_tokens, self.text) torch._dynamo.mark_dynamic(prompt_emb, 1) - mask = self.attn_mask[:, :, pos : pos + prompt_emb.size(1), :] pos_ids = torch.arange(pos, pos + prompt_emb.size(1), dtype=torch.long) hidden = self._prefill(prompt_emb, mask, pos_ids) diff --git a/moondream/torch/sample.py b/moondream/torch/sample.py index 20c72be2..8c5189af 100644 --- a/moondream/torch/sample.py +++ b/moondream/torch/sample.py @@ -5,11 +5,17 @@ from PIL import Image, ImageDraw from tqdm import tqdm +import logging +import bitblas + +bitblas.logger.setLevel("FATAL") from .weights import load_weights_into_model from .moondream import MoondreamModel, MoondreamConfig +import time if __name__ == "__main__": + start = time.time() parser = argparse.ArgumentParser() parser.add_argument("--image", "-i", type=str, required=True) parser.add_argument("--prompt", "-p", type=str, required=True) @@ -32,17 +38,20 @@ config = MoondreamConfig.from_dict(config) else: config = MoondreamConfig() + model = MoondreamModel(config) load_weights_into_model(args.model, model) - model = model.to(device) # Encode image. image_path = args.image if not os.path.exists(image_path): raise FileNotFoundError(f"Image not found at {image_path}") image = Image.open(image_path) + model = model.to(device) if not args.benchmark: + + # model.compile() encoded_image = model.encode_image(image) # Short caption @@ -142,6 +151,5 @@ print("\nQuery Speed (tokens/sec):") print(f" Mean: {sum(query_speeds)/len(query_speeds):.2f}") print(f" Min: {min(query_speeds):.2f}") - print(f" Max: {max(query_speeds):.2f}") else: raise ValueError("To run benchmarks, make sure you are on a CUDA device") diff --git a/moondream/torch/text.py b/moondream/torch/text.py index de75fb58..ee97a1ba 100644 --- a/moondream/torch/text.py +++ b/moondream/torch/text.py @@ -2,8 +2,9 @@ import torch.nn as nn from torch.nn import functional as F +from bitblas.cache import OperatorCache -from .layers import layer_norm, mlp +from .layers import layer_norm, mlp, Linear from .rope import apply_rotary_emb, precompute_freqs_cis from .config import TextConfig @@ -26,6 +27,7 @@ def attn( head_dim = d_model // n_heads qkv_out = w.qkv(x) # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim) + q_dim = n_heads * head_dim kv_dim = n_kv_heads * head_dim @@ -139,6 +141,7 @@ def text_decoder( n_kv_heads=config.n_kv_heads, position_ids=position_ids, ) + l_mlp = mlp(l_in, block.mlp) x = x + l_attn + l_mlp @@ -158,32 +161,52 @@ def _lm_head(hidden_BTC: torch.Tensor, w: nn.Module): return logits -def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module: +def build_text_model( + config: TextConfig, + linear_dtype: torch.dtype = torch.float16, + layernorm_dtype: torch.dtype = torch.float16, +) -> ( + nn.Module +): # note : layernorm dtype is used for layernorm, lm_head and wte not just layernorm qkv_dim = int(config.dim * (1 + 2 * config.n_kv_heads / config.n_heads)) + operator_cache = None + cache_dir = None + group_size = None + if linear_dtype == torch.int8: + + operator_cache = OperatorCache() + cache_dir = config.cache_dir + group_size = config.group_size + + def create_linear(in_features, out_features, dtype=linear_dtype): + # factory function for creating Linear layers so we dont have to pass everything again and again + return Linear( + in_features=in_features, + out_features=out_features, + dtype=dtype, + operator_cache=operator_cache, + cache_dir=cache_dir, + group_size=group_size, + ) + text = nn.ModuleDict( { "blocks": nn.ModuleList( [ nn.ModuleDict( { - "ln": nn.LayerNorm(config.dim, dtype=dtype), + "ln": nn.LayerNorm(config.dim, dtype=layernorm_dtype), "attn": nn.ModuleDict( { - "qkv": nn.Linear(config.dim, qkv_dim, dtype=dtype), - "proj": nn.Linear( - config.dim, config.dim, dtype=dtype - ), + "qkv": create_linear(config.dim, qkv_dim), + "proj": create_linear(config.dim, config.dim), } ), "mlp": nn.ModuleDict( { - "fc1": nn.Linear( - config.dim, config.ff_dim, dtype=dtype - ), - "fc2": nn.Linear( - config.ff_dim, config.dim, dtype=dtype - ), + "fc1": create_linear(config.dim, config.ff_dim), + "fc2": create_linear(config.ff_dim, config.dim), } ), } @@ -191,11 +214,13 @@ def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module: for _ in range(config.n_layers) ] ), - "post_ln": nn.LayerNorm(config.dim, dtype=dtype), - "lm_head": nn.Linear(config.dim, config.vocab_size, dtype=dtype), + "post_ln": nn.LayerNorm(config.dim, dtype=layernorm_dtype), + "lm_head": nn.Linear(config.dim, config.vocab_size, dtype=layernorm_dtype), } ) - text.wte = nn.Parameter(torch.empty(config.vocab_size, config.dim, dtype=dtype)) + text.wte = nn.Parameter( + torch.empty(config.vocab_size, config.dim, dtype=layernorm_dtype) + ) text.register_buffer( "freqs_cis", precompute_freqs_cis(config.dim // (2 * config.n_heads), config.max_context), diff --git a/moondream/torch/weights.py b/moondream/torch/weights.py index f9634b5c..e1193443 100644 --- a/moondream/torch/weights.py +++ b/moondream/torch/weights.py @@ -1,10 +1,26 @@ import safetensors import torch import torch.nn as nn +import re from contextlib import contextmanager from typing import Callable, List +from .text import build_text_model +from .config import TextConfig + + +# Our custom linear has an module named linear, so we add linear to the name +def add_linear_to_key(k: str) -> str: + k = k.replace("model.", "") + if k.startswith("text.") and ".linear." not in k: + k = re.sub( + r"(attn\.(?:qkv|proj)|mlp\.fc[12])\.(weight|bias)$", + r"\1.linear.\2", + k, + ) + return k + @contextmanager def safetensors_open(safetensors_file: str): @@ -27,12 +43,17 @@ def get_keys() -> List[str]: yield get_tensor -def _load_weights(get_tensor: Callable[[str], torch.Tensor], model: nn.Module) -> None: +def _load_weights( + get_tensor: Callable[[str], torch.Tensor], + model: nn.Module, + is_quantized: bool = False, +) -> None: """Internal function to load weights using a tensor getter function.""" model = model.to(dtype=torch.float16) vision = model.vision region = model.region + weight_map = { "vision_encoder.encoder.model.visual.patch_embed.linear.weight": vision[ "patch_emb" @@ -90,23 +111,42 @@ def _load_weights(get_tensor: Callable[[str], torch.Tensor], model: nn.Module) - } ) - for i in range(len(model.text["blocks"])): - prefix = f"text_model.transformer.h.{i}" - blk = model.text["blocks"][i] - weight_map.update( - { - f"{prefix}.ln.weight": blk["ln"].weight, - f"{prefix}.ln.bias": blk["ln"].bias, - f"{prefix}.mixer.Wqkv.weight": blk["attn"]["qkv"].weight, - f"{prefix}.mixer.Wqkv.bias": blk["attn"]["qkv"].bias, - f"{prefix}.mixer.out_proj.weight": blk["attn"]["proj"].weight, - f"{prefix}.mixer.out_proj.bias": blk["attn"]["proj"].bias, - f"{prefix}.mlp.fc1.weight": blk["mlp"]["fc1"].weight, - f"{prefix}.mlp.fc1.bias": blk["mlp"]["fc1"].bias, - f"{prefix}.mlp.fc2.weight": blk["mlp"]["fc2"].weight, - f"{prefix}.mlp.fc2.bias": blk["mlp"]["fc2"].bias, - } - ) + if not is_quantized: + for i in range(len(model.text["blocks"])): + prefix = f"text_model.transformer.h.{i}" + blk = model.text["blocks"][i] + weight_map.update( + { + f"{prefix}.ln.weight": blk["ln"].weight, + f"{prefix}.ln.bias": blk["ln"].bias, + f"{prefix}.mixer.Wqkv.weight": blk["attn"]["qkv"].weight, + f"{prefix}.mixer.Wqkv.bias": blk["attn"]["qkv"].bias, + f"{prefix}.mixer.out_proj.weight": blk["attn"]["proj"].weight, + f"{prefix}.mixer.out_proj.bias": blk["attn"]["proj"].bias, + f"{prefix}.mlp.fc1.weight": blk["mlp"]["fc1"].weight, + f"{prefix}.mlp.fc1.bias": blk["mlp"]["fc1"].bias, + f"{prefix}.mlp.fc2.weight": blk["mlp"]["fc2"].weight, + f"{prefix}.mlp.fc2.bias": blk["mlp"]["fc2"].bias, + } + ) + else: # add special quantized path. this is specific to how bitblas expects weights to be loaded (.qweight) + for i in range(len(model.text["blocks"])): + prefix = f"text_model.transformer.h.{i}" + blk = model.text["blocks"][i] + weight_map.update( + { + f"{prefix}.ln.qweight": blk["ln"].weight, + f"{prefix}.ln.bias": blk["ln"].bias, + f"{prefix}.mixer.Wqkv.qweight": blk["attn"]["qkv"].weight, + f"{prefix}.mixer.Wqkv.bias": blk["attn"]["qkv"].bias, + f"{prefix}.mixer.out_proj.qweight": blk["attn"]["proj"].weight, + f"{prefix}.mixer.out_proj.bias": blk["attn"]["proj"].bias, + f"{prefix}.mlp.fc1.qweight": blk["mlp"]["fc1"].weight, + f"{prefix}.mlp.fc1.bias": blk["mlp"]["fc1"].bias, + f"{prefix}.mlp.fc2.qweight": blk["mlp"]["fc2"].weight, + f"{prefix}.mlp.fc2.bias": blk["mlp"]["fc2"].bias, + } + ) for key, tensor in weight_map.items(): tensor.data.copy_(get_tensor(key)) @@ -120,35 +160,75 @@ def _load_weights(get_tensor: Callable[[str], torch.Tensor], model: nn.Module) - def load_weights_from_safetensors(weights_file: str, model: nn.Module) -> None: """Load weights from a safetensors file into a MoondreamModel instance.""" with safetensors_open(weights_file) as get_tensor: + all_keys = get_tensor.keys() + + is_quantized = any( + ".qweight" in key or "_quantized" in key or "quant." in key + for key in all_keys + ) + + if "text_model.transformer.h.0.ln.weight" in all_keys: + layernorm_dtype = get_tensor("text_model.transformer.h.0.ln.weight").dtype + else: + layernorm_dtype = torch.float16 + + linear_dtype = torch.int8 if is_quantized else torch.float16 + + model.text = build_text_model( + TextConfig, linear_dtype=linear_dtype, layernorm_dtype=layernorm_dtype + ) + if model.setup_caches_flag: + model._setup_caches() + if ( - "vision.blocks.0.attn.proj.bias" in get_tensor.keys() - or "model.vision.blocks.0.attn.proj.bias" in get_tensor.keys() + "vision.blocks.0.attn.proj.bias" in all_keys + or "model.vision.blocks.0.attn.proj.bias" in all_keys ): with safetensors_open(weights_file) as get_tensor: - tensors = { - k.replace("model.", ""): get_tensor(k) for k in get_tensor.keys() - } + tensors = {add_linear_to_key(k): get_tensor(k) for k in all_keys} model.load_state_dict(tensors, strict=False) else: # Wrap the get_tensor function to handle key normalization - name_map = {k.replace("._orig_mod", ""): k for k in get_tensor.keys()} + name_map = {k.replace("._orig_mod", ""): k for k in all_keys} _load_weights( - lambda x: get_tensor(name_map[x]).to(dtype=torch.float16), model + lambda x: get_tensor(name_map[x]).to(dtype=torch.float16), + model, + is_quantized, ) def load_weights_from_pt(weights_file: str, model: nn.Module) -> None: """Load weights from a PyTorch file into a MoondreamModel instance.""" - device = str(torch.empty(0).device) - tensors = torch.load(weights_file, map_location=device, weights_only=True) - if "vision.blocks.0.attn.proj.bias" in tensors.keys(): + tensors = torch.load(weights_file, map_location="cpu", weights_only=True) + all_keys = tensors.keys() + is_quantized = any( + ".qweight" in key or "_quantized" in key or "quant." in key for key in all_keys + ) + + if "text.blocks.0.ln.weight" in all_keys: + layernorm_dtype = tensors["text.blocks.0.ln.weight"].dtype + else: + layernorm_dtype = torch.float16 + + linear_dtype = torch.int8 if is_quantized else torch.float16 + model.text = build_text_model( + TextConfig, linear_dtype=linear_dtype, layernorm_dtype=layernorm_dtype + ) + if model.setup_caches_flag: + model._setup_caches() + + if ( + "vision.blocks.0.attn.proj.bias" in all_keys + or "model.vision.blocks.0.attn.proj.bias" in all_keys + ): + tensors = {add_linear_to_key(k): v for k, v in tensors.items()} model.load_state_dict(tensors, strict=False) else: tensors = { k.replace("._orig_mod", ""): v.to(dtype=torch.float16) for k, v in tensors.items() } - _load_weights(lambda x: tensors[x], model) + _load_weights(lambda x: tensors[x], model, is_quantized) def load_weights_into_model(weights_file: str, model: nn.Module) -> None: