From d493e10735874aee084f0c3bce05ee29936d381e Mon Sep 17 00:00:00 2001 From: snowclipsed Date: Mon, 19 May 2025 13:46:15 -0700 Subject: [PATCH 1/8] Add 4 bit inference capability --- .gitignore | 1 + moondream/torch/config.py | 4 +- moondream/torch/layers.py | 57 +++++++++- moondream/torch/moondream.py | 30 ++--- moondream/torch/sample.py | 21 ++-- moondream/torch/text.py | 59 +++++++--- moondream/torch/weights.py | 214 ++++++++++++++++++++++++++++------- 7 files changed, 302 insertions(+), 84 deletions(-) diff --git a/.gitignore b/.gitignore index c5bc1ee8..24e9c03d 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,5 @@ poetry.lock dist clients/python/moondream/torch wandb/ +bitblas_cache/ moondream_finetune.safetensors diff --git a/moondream/torch/config.py b/moondream/torch/config.py index fa4ae05c..01570010 100644 --- a/moondream/torch/config.py +++ b/moondream/torch/config.py @@ -12,6 +12,8 @@ class TextConfig: n_heads: int = 32 n_kv_heads: int = 32 prefix_attn: int = 730 + group_size: int = 128 + cache_dir: str = "./bitblas_cache" @dataclass(frozen=True) @@ -83,4 +85,4 @@ def to_dict(self): "vision": self.vision.__dict__, "region": self.region.__dict__, "tokenizer": self.tokenizer.__dict__, - } + } \ No newline at end of file diff --git a/moondream/torch/layers.py b/moondream/torch/layers.py index 4140c18f..9dfdaf56 100644 --- a/moondream/torch/layers.py +++ b/moondream/torch/layers.py @@ -1,8 +1,12 @@ from dataclasses import dataclass from typing import Literal +import bitblas +from bitblas.cache import OperatorCache + import torch from torch.nn import functional as F +import torch.nn as nn def gelu_approx(x): @@ -14,6 +18,56 @@ class LinearWeights: weight: torch.Tensor bias: torch.Tensor +class Linear(nn.Module): + """ + Linear layer with support for bitblas quantization. + If dtype is torch.int8, it uses bitblas for quantization. + Otherwise, it uses a standard nn.Linear layer. + """ + def __init__(self, in_features:int, out_features:int, bias: bool = True, + dtype:torch.dtype=None, operator_cache:OperatorCache=None, cache_dir:str=None, group_size:int=128): + super().__init__() + + if dtype == torch.int8: + self.linear = bitblas.Linear( + in_features=in_features, + out_features=out_features, + bias=bias, + with_zeros=True, + zeros_mode="original", + with_scaling=True, + A_dtype="float16", + W_dtype="uint4", + accum_dtype="float16", + out_dtype="float16", + fast_decoding=True, + enable_tuning=True, + operator_cache=operator_cache, + database_path=cache_dir, + group_size=group_size, + ) + else: + self.linear = nn.Linear( + in_features=in_features, + out_features=out_features, + bias=bias, + dtype=torch.float16 + ) + def forward(self, x): + return self.linear(x) + + @property + def weight(self) -> torch.Tensor: + try: + return self.linear.weight + except AttributeError: + return self.linear.qweight + + + @property + def bias(self) -> torch.Tensor: + return self.linear.bias + def linear(x: torch.Tensor, w: LinearWeights) -> torch.Tensor: return F.linear(x, w.weight, w.bias) @@ -37,6 +91,7 @@ class MLPWeights: def mlp(x: torch.Tensor, w: MLPWeights) -> torch.Tensor: + x = w.fc1(x) x = gelu_approx(x) x = w.fc2(x) @@ -60,4 +115,4 @@ def attn(x: torch.Tensor, w: AttentionWeights, n_heads: int) -> torch.Tensor: out = F.scaled_dot_product_attention(q, k, v) out = out.transpose(1, 2).reshape(bsz, q_len, d_model) out = linear(out, w.proj) - return out + return out \ No newline at end of file diff --git a/moondream/torch/moondream.py b/moondream/torch/moondream.py index ec9014ae..91d07811 100644 --- a/moondream/torch/moondream.py +++ b/moondream/torch/moondream.py @@ -66,12 +66,16 @@ class MoondreamModel(nn.Module): def __init__(self, config: MoondreamConfig, dtype=torch.float16, setup_caches=True): super().__init__() self.config = config + self.dtype = dtype + self.setup_caches_flag = setup_caches self.tokenizer = Tokenizer.from_pretrained( "vikhyatk/moondream2", revision="2025-01-09" ) + self.vision = build_vision_model(config.vision, dtype) - self.text = build_text_model(config.text, dtype) + + self.text = None # Region Model self.region = nn.ModuleDict( @@ -125,11 +129,11 @@ def __init__(self, config: MoondreamConfig, dtype=torch.float16, setup_caches=Tr attn_mask[..., :prefix_attn_len, :prefix_attn_len] = 1 self.register_buffer("attn_mask", attn_mask, persistent=False) - # Initialize KV caches. - if setup_caches: - self._setup_caches() - def _setup_caches(self): + """Setup KV caches for the text model""" + if self.text is None: + return # Can't set up caches without text model + c = self.config.text for b in self.text.blocks: b.kv_cache = KVCache( @@ -163,15 +167,12 @@ def _decode_one_tok( def compile(self): # TODO: vision_projection is not being compiled - self._vis_enc = torch.compile(self._vis_enc, fullgraph=True) - self._prefill = torch.compile(self._prefill, fullgraph=True) - self._decode_one_tok = torch.compile( - self._decode_one_tok, fullgraph=True, mode="reduce-overhead" - ) + self._vis_enc = torch.compile(self._vis_enc, fullgraph=False, mode="reduce-overhead") + # self._prefill = torch.compile(self._prefill) + # self._decode_one_tok = torch.compile(self._decode_one_tok) def _run_vision_encoder(self, image: Image.Image) -> torch.Tensor: all_crops, tiling = prepare_crops(image, self.config.vision, device=self.device) - torch._dynamo.mark_dynamic(all_crops, 0) outputs = self._vis_enc(all_crops) @@ -201,6 +202,7 @@ def encode_image(self, image: Union[Image.Image, EncodedImage]) -> EncodedImage: # Run through text model in addition to the vision encoder, to minimize # re-computation if multiple queries are performed on this image. + with torch.inference_mode(): img_emb = self._run_vision_encoder(image) bos_emb = text_encoder( @@ -211,7 +213,7 @@ def encode_image(self, image: Union[Image.Image, EncodedImage]) -> EncodedImage: mask = self.attn_mask[:, :, 0 : inputs_embeds.size(1), :] pos_ids = torch.arange(inputs_embeds.size(1), dtype=torch.long) self._prefill(inputs_embeds, mask, pos_ids) - + return EncodedImage( pos=inputs_embeds.size(1), caches=[ @@ -235,11 +237,11 @@ def _apply_top_p(self, probs: torch.Tensor, top_p: float): def _prefill_prompt( self, prompt_tokens: torch.Tensor, pos: int, temperature: float, top_p: float - ): + ): + with torch.inference_mode(): prompt_emb = text_encoder(prompt_tokens, self.text) torch._dynamo.mark_dynamic(prompt_emb, 1) - mask = self.attn_mask[:, :, pos : pos + prompt_emb.size(1), :] pos_ids = torch.arange(pos, pos + prompt_emb.size(1), dtype=torch.long) hidden = self._prefill(prompt_emb, mask, pos_ids) diff --git a/moondream/torch/sample.py b/moondream/torch/sample.py index 20c72be2..ab4026f2 100644 --- a/moondream/torch/sample.py +++ b/moondream/torch/sample.py @@ -5,25 +5,30 @@ from PIL import Image, ImageDraw from tqdm import tqdm +import logging +import bitblas +bitblas.logger.setLevel('FATAL') from .weights import load_weights_into_model from .moondream import MoondreamModel, MoondreamConfig +import time if __name__ == "__main__": + start = time.time() parser = argparse.ArgumentParser() parser.add_argument("--image", "-i", type=str, required=True) parser.add_argument("--prompt", "-p", type=str, required=True) parser.add_argument("--model", "-m", type=str, required=True) parser.add_argument("--config", "-c", type=str, default=None) - parser.add_argument("--max-tokens", "-t", type=int, default=200) + parser.add_argument("--max-tokens", "-t", type=int, default=100) parser.add_argument("--sampler", "-s", type=str, default="greedy") parser.add_argument("--benchmark", "-b", action="store_true") args = parser.parse_args() if torch.cuda.is_available(): - device = "cuda" + torch.set_default_device("cuda") elif torch.backends.mps.is_available(): - device = "mps" + torch.set_default_device("mps") # Load model. if args.config is not None: @@ -32,9 +37,9 @@ config = MoondreamConfig.from_dict(config) else: config = MoondreamConfig() + model = MoondreamModel(config) load_weights_into_model(args.model, model) - model = model.to(device) # Encode image. image_path = args.image @@ -43,6 +48,8 @@ image = Image.open(image_path) if not args.benchmark: + + model.compile() encoded_image = model.encode_image(image) # Short caption @@ -96,7 +103,7 @@ # Detect gaze model.detect_gaze(encoded_image, (0.5, 0.5)) - elif model.device.type != "mps": + else: torch._dynamo.reset() model.compile() @@ -142,6 +149,4 @@ print("\nQuery Speed (tokens/sec):") print(f" Mean: {sum(query_speeds)/len(query_speeds):.2f}") print(f" Min: {min(query_speeds):.2f}") - print(f" Max: {max(query_speeds):.2f}") - else: - raise ValueError("To run benchmarks, make sure you are on a CUDA device") + print(f" Max: {max(query_speeds):.2f}") \ No newline at end of file diff --git a/moondream/torch/text.py b/moondream/torch/text.py index de75fb58..4fb31021 100644 --- a/moondream/torch/text.py +++ b/moondream/torch/text.py @@ -2,8 +2,9 @@ import torch.nn as nn from torch.nn import functional as F +from bitblas.cache import OperatorCache -from .layers import layer_norm, mlp +from .layers import layer_norm, mlp, Linear from .rope import apply_rotary_emb, precompute_freqs_cis from .config import TextConfig @@ -26,6 +27,7 @@ def attn( head_dim = d_model // n_heads qkv_out = w.qkv(x) # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim) + q_dim = n_heads * head_dim kv_dim = n_kv_heads * head_dim @@ -52,6 +54,7 @@ def attn( ) out = out.transpose(1, 2).reshape(bsz, q_len, d_model) out = w.proj(out) + # print("out", out[:5, :5, :5]) return out @@ -139,6 +142,7 @@ def text_decoder( n_kv_heads=config.n_kv_heads, position_ids=position_ids, ) + l_mlp = mlp(l_in, block.mlp) x = x + l_attn + l_mlp @@ -158,32 +162,51 @@ def _lm_head(hidden_BTC: torch.Tensor, w: nn.Module): return logits -def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module: +def build_text_model(config: TextConfig, dtype: torch.dtype, group_size: int = 128) -> nn.Module: qkv_dim = int(config.dim * (1 + 2 * config.n_kv_heads / config.n_heads)) + operator_cache = None + cache_dir = None + layernorm_dtype = torch.float16 + if dtype == torch.int8: + + print("INITIALIZING QUANTIZED MODEL") + operator_cache = OperatorCache() + cache_dir = config.cache_dir + + + def create_linear(in_features, out_features, dtype=dtype): + # factory function for creating Linear layers so we dont have to pass everything again and again + return Linear( + in_features=in_features, + out_features=out_features, + dtype=dtype, + operator_cache=operator_cache, + cache_dir=cache_dir, + group_size=group_size, + ) + + text = nn.ModuleDict( { "blocks": nn.ModuleList( [ nn.ModuleDict( { - "ln": nn.LayerNorm(config.dim, dtype=dtype), + "ln": nn.LayerNorm(config.dim, dtype=layernorm_dtype), "attn": nn.ModuleDict( { - "qkv": nn.Linear(config.dim, qkv_dim, dtype=dtype), - "proj": nn.Linear( - config.dim, config.dim, dtype=dtype - ), + "qkv": create_linear(config.dim, qkv_dim), + "proj": create_linear( + config.dim, config.dim) } ), "mlp": nn.ModuleDict( { - "fc1": nn.Linear( - config.dim, config.ff_dim, dtype=dtype - ), - "fc2": nn.Linear( - config.ff_dim, config.dim, dtype=dtype - ), + "fc1": create_linear( + config.dim, config.ff_dim), + "fc2": create_linear( + config.ff_dim, config.dim) } ), } @@ -191,15 +214,17 @@ def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module: for _ in range(config.n_layers) ] ), - "post_ln": nn.LayerNorm(config.dim, dtype=dtype), - "lm_head": nn.Linear(config.dim, config.vocab_size, dtype=dtype), + "post_ln": nn.LayerNorm(config.dim, dtype=layernorm_dtype), + "lm_head": nn.Linear( + config.dim, config.vocab_size, dtype=layernorm_dtype + ), } ) - text.wte = nn.Parameter(torch.empty(config.vocab_size, config.dim, dtype=dtype)) + text.wte = nn.Parameter(torch.empty(config.vocab_size, config.dim, dtype=layernorm_dtype)) text.register_buffer( "freqs_cis", precompute_freqs_cis(config.dim // (2 * config.n_heads), config.max_context), persistent=False, ) - return text + return text \ No newline at end of file diff --git a/moondream/torch/weights.py b/moondream/torch/weights.py index f9634b5c..b7bf662b 100644 --- a/moondream/torch/weights.py +++ b/moondream/torch/weights.py @@ -4,7 +4,11 @@ from contextlib import contextmanager from typing import Callable, List +from .text import build_text_model +# from .vision import build_vision_model # Not used +import gc +# import time # Not used @contextmanager def safetensors_open(safetensors_file: str): @@ -12,9 +16,10 @@ def safetensors_open(safetensors_file: str): Simplify interfacing with safetensors files. Eliminates the need to ignore type errors when using the `safe_open` function. """ + # Open the safetensors file for reading in PyTorch framework with safetensors.safe_open( safetensors_file, framework="pt" - ) as st: # pyright: ignore + ) as st: def get_tensor(name: str) -> torch.Tensor: return st.get_tensor(name) @@ -29,10 +34,12 @@ def get_keys() -> List[str]: def _load_weights(get_tensor: Callable[[str], torch.Tensor], model: nn.Module) -> None: """Internal function to load weights using a tensor getter function.""" - model = model.to(dtype=torch.float16) + + model.to(dtype=torch.float16) vision = model.vision region = model.region + # Define a mapping from expected weight names in the file to model parameters weight_map = { "vision_encoder.encoder.model.visual.patch_embed.linear.weight": vision[ "patch_emb" @@ -70,6 +77,7 @@ def _load_weights(get_tensor: Callable[[str], torch.Tensor], model: nn.Module) - "region_model.size_decoder.fc2.bias": region["size_decoder"]["fc2"].bias, } + # Dynamically add weights for vision transformer blocks for i in range(len(model.vision["blocks"])): prefix = f"vision_encoder.encoder.model.visual.blocks.{i}" blk = model.vision["blocks"][i] @@ -90,6 +98,7 @@ def _load_weights(get_tensor: Callable[[str], torch.Tensor], model: nn.Module) - } ) + # Dynamically add weights for text transformer blocks for i in range(len(model.text["blocks"])): prefix = f"text_model.transformer.h.{i}" blk = model.text["blocks"][i] @@ -108,62 +117,181 @@ def _load_weights(get_tensor: Callable[[str], torch.Tensor], model: nn.Module) - } ) - for key, tensor in weight_map.items(): - tensor.data.copy_(get_tensor(key)) + # Copy data from loaded tensors to model parameters + for key, tensor_target in weight_map.items(): + source_tensor = get_tensor(key) # get_tensor is expected to provide fp16 tensors if model is fp16 + # because the lambda passed to _load_weights will do .to(fp16) + tensor_target.data.copy_(source_tensor) + + + # Special handling for transposed weights + coord_features_weight = get_tensor("region_model.coordinate_features.weight") + region.coord_features.data.copy_(coord_features_weight.T) - region.coord_features.data.copy_( - get_tensor("region_model.coordinate_features.weight").T - ) - region.size_features.data.copy_(get_tensor("region_model.size_features.weight").T) + size_features_weight = get_tensor("region_model.size_features.weight") + region.size_features.data.copy_(size_features_weight.T) def load_weights_from_safetensors(weights_file: str, model: nn.Module) -> None: - """Load weights from a safetensors file into a MoondreamModel instance.""" + """Load weights from a safetensors file into the model instance, + with support for quantized models.""" + device = next(model.parameters()).device + print(f"Loading .safetensors file to CPU first (target device: {device})...") + with safetensors_open(weights_file) as get_tensor: - if ( - "vision.blocks.0.attn.proj.bias" in get_tensor.keys() - or "model.vision.blocks.0.attn.proj.bias" in get_tensor.keys() - ): - with safetensors_open(weights_file) as get_tensor: - tensors = { - k.replace("model.", ""): get_tensor(k) for k in get_tensor.keys() - } - model.load_state_dict(tensors, strict=False) + all_keys = get_tensor.keys() + + # Detect if the model is quantized by inspecting tensor keys + is_quantized = any('.qweight' in key or '_quantized' in key or 'quant.' in key for key in all_keys) + + if is_quantized: + print("Quantized model detected from safetensors file keys.") + else: + print("Non-quantized model detected from safetensors file keys.") + + if hasattr(model, 'text') and model.text is None: + # Determine the dtype for the text model based on quantization status + text_dtype = torch.int8 if is_quantized else torch.float16 + print(f"Building text model with dtype: {text_dtype} for safetensors loading.") + # Assuming build_text_model can handle the specified dtype (e.g., int8 for quantized) + model.text = build_text_model(model.config.text, dtype=text_dtype) + model.text.to(device) # Move the newly built text model to the target device + + if hasattr(model, 'setup_caches_flag') and model.setup_caches_flag: + model._setup_caches() + print("Caches set up for text model.") + + elif hasattr(model, 'text') and model.text is not None: + # Text model already exists, check if its dtype needs adjustment + current_text_dtype = next(model.text.parameters()).dtype + print(f"Text model already exists. Current text model dtype: {current_text_dtype}") + if not is_quantized and current_text_dtype != torch.float16: + # If loading non-quantized weights and existing text model is not fp16, convert it + print(f"Converting existing text model to float16 for non-quantized safetensors.") + model.text.to(dtype=torch.float16) + + + # --- Weight loading logic --- + if is_quantized: + print("Loading state_dict for quantized safetensors model.") + + tensors_processed = {} + for k in all_keys: + cleaned_key = k + if cleaned_key.startswith("model."): + cleaned_key = cleaned_key[len("model."):] + + if "._orig_mod" in cleaned_key: # A more careful replacement might be needed + cleaned_key = cleaned_key.replace("._orig_mod", "") + + tensor = get_tensor(k).to(device) # Move tensor to target device + tensors_processed[cleaned_key] = tensor + + model.load_state_dict(tensors_processed, strict=False) else: - # Wrap the get_tensor function to handle key normalization - name_map = {k.replace("._orig_mod", ""): k for k in get_tensor.keys()} - _load_weights( - lambda x: get_tensor(name_map[x]).to(dtype=torch.float16), model - ) + + is_direct_load_style = False + + if any(key_name in all_keys for key_name in ["vision.blocks.0.attn.proj.bias", "model.vision.blocks.0.attn.proj.bias"]): + is_direct_load_style = True + + if is_direct_load_style: + print("Using load_state_dict for non-quantized safetensors (direct state_dict loading path).") + model.to(dtype=torch.float16, device=device) + + tensors_processed = {} + for k in all_keys: + cleaned_key = k + if cleaned_key.startswith("model."): + cleaned_key = cleaned_key[len("model."):] + + tensor = get_tensor(k).to(dtype=torch.float16, device=device) + tensors_processed[cleaned_key] = tensor + model.load_state_dict(tensors_processed, strict=False) + else: + print("Using _load_weights for non-quantized safetensors (custom mapping path).") + + name_map = {key.replace("._orig_mod", ""): key for key in all_keys} + + def get_tensor_for_load_weights(name: str) -> torch.Tensor: + source_key = name_map.get(name) + if source_key is None: + # This case should ideally not happen if _load_weights uses correct keys its a fallback + raise KeyError(f"Key {name} not found in safetensor after mapping. Available mapped keys: {list(name_map.keys())}") + return get_tensor(source_key).to(dtype=torch.float16, device=device) + + _load_weights(get_tensor_for_load_weights, model) + + model.to(device) + print("✓ Successfully loaded weights from safetensors file!") def load_weights_from_pt(weights_file: str, model: nn.Module) -> None: - """Load weights from a PyTorch file into a MoondreamModel instance.""" - device = str(torch.empty(0).device) - tensors = torch.load(weights_file, map_location=device, weights_only=True) - if "vision.blocks.0.attn.proj.bias" in tensors.keys(): - model.load_state_dict(tensors, strict=False) - else: - tensors = { + """Load weights from a PyTorch (.pt) file into the model instance.""" + device = next(model.parameters()).device + print(f"Loading .pt file to CPU first to conserve GPU memory (target device: {device})...") + + state_dict_on_cpu = torch.load(weights_file, map_location='cpu', weights_only=True) + + is_quantized = any('.qweight' in key or '_quantized' in key or 'quant.' in key for key in state_dict_on_cpu.keys()) + if is_quantized: + print("Quantized model detected from .pt file keys.") + + if hasattr(model, 'text') and model.text is None: + text_dtype = torch.int8 if is_quantized else torch.float16 + print(f"Building text model with dtype: {text_dtype} for .pt loading.") + model.text = build_text_model(model.config.text, dtype=text_dtype) + model.text.to(device) + + if hasattr(model, 'setup_caches_flag') and model.setup_caches_flag: + model._setup_caches() + print("Caches set up for text model.") + elif hasattr(model, 'text') and model.text is not None: + current_text_dtype = next(model.text.parameters()).dtype + print(f"Text model already exists. Current text model dtype: {current_text_dtype}") + if not is_quantized and current_text_dtype != torch.float16: + print(f"Converting existing text model to float16.") + model.text.to(dtype=torch.float16) + + if not is_quantized: + print( + "Model is not quantized. Loading weights from PyTorch file using _load_weights. This may take a while, please be patient." + ) + + processed_tensors = { k.replace("._orig_mod", ""): v.to(dtype=torch.float16) - for k, v in tensors.items() + for k, v in state_dict_on_cpu.items() } - _load_weights(lambda x: tensors[x], model) + _load_weights(lambda x: processed_tensors[x], model) + del processed_tensors # Clean up intermediate dictionary + else: # Quantized path + print("Loading state_dict (from CPU) into model (on device) for quantized model...") -def load_weights_into_model(weights_file: str, model: nn.Module) -> None: - """ - Load weights from either a safetensors or PyTorch file directly into a MoondreamModel instance. + model.load_state_dict(state_dict_on_cpu, strict=False) + + del state_dict_on_cpu + gc.collect() + print("Cleaned up original CPU state_dict from .pt loading.") - Args: - weights_file: Path to weights file (either .safetensors or .pt) - model: MoondreamModel instance to load weights into + +def load_weights_into_model(weights_file: str, model: nn.Module) -> nn.Module: + """ + Main function to load weights into a model. + Determines file type and calls the appropriate loading function. """ - if weights_file.endswith(".safetensors"): + device = next(iter(model.parameters())).device + print(f"Starting weight loading process for model on {device}...") + + if weights_file.endswith('.pt'): + load_weights_from_pt(weights_file, model) + elif weights_file.endswith('.safetensors'): + print("Loading .safetensors file...") load_weights_from_safetensors(weights_file, model) else: - load_weights_from_pt(weights_file, model) + print(f"Unsupported weights file format: {weights_file}. Please use .pt or .safetensors.") + return model + + print("✓✓ Overall weight loading process complete!") + return model - # Make all parameters contiguous - for param in model.parameters(): - param.data = param.data.contiguous() From e341dd9c38863a39002cc883f25d7b4421dce67e Mon Sep 17 00:00:00 2001 From: snowclipsed Date: Mon, 19 May 2025 18:37:24 -0700 Subject: [PATCH 2/8] update weights.py --- moondream/torch/weights.py | 288 +++++++++++++------------------------ 1 file changed, 102 insertions(+), 186 deletions(-) diff --git a/moondream/torch/weights.py b/moondream/torch/weights.py index b7bf662b..d32a3ec6 100644 --- a/moondream/torch/weights.py +++ b/moondream/torch/weights.py @@ -4,11 +4,9 @@ from contextlib import contextmanager from typing import Callable, List -from .text import build_text_model -# from .vision import build_vision_model # Not used -import gc -# import time # Not used +from .text import build_text_model +from .config import TextConfig @contextmanager def safetensors_open(safetensors_file: str): @@ -16,10 +14,9 @@ def safetensors_open(safetensors_file: str): Simplify interfacing with safetensors files. Eliminates the need to ignore type errors when using the `safe_open` function. """ - # Open the safetensors file for reading in PyTorch framework with safetensors.safe_open( safetensors_file, framework="pt" - ) as st: + ) as st: # pyright: ignore def get_tensor(name: str) -> torch.Tensor: return st.get_tensor(name) @@ -32,14 +29,13 @@ def get_keys() -> List[str]: yield get_tensor -def _load_weights(get_tensor: Callable[[str], torch.Tensor], model: nn.Module) -> None: +def _load_weights(get_tensor: Callable[[str], torch.Tensor], model: nn.Module, is_quantized:bool=False) -> None: """Internal function to load weights using a tensor getter function.""" - - model.to(dtype=torch.float16) + model = model.to(dtype=torch.float16) vision = model.vision region = model.region - # Define a mapping from expected weight names in the file to model parameters + weight_map = { "vision_encoder.encoder.model.visual.patch_embed.linear.weight": vision[ "patch_emb" @@ -77,7 +73,6 @@ def _load_weights(get_tensor: Callable[[str], torch.Tensor], model: nn.Module) - "region_model.size_decoder.fc2.bias": region["size_decoder"]["fc2"].bias, } - # Dynamically add weights for vision transformer blocks for i in range(len(model.vision["blocks"])): prefix = f"vision_encoder.encoder.model.visual.blocks.{i}" blk = model.vision["blocks"][i] @@ -98,200 +93,121 @@ def _load_weights(get_tensor: Callable[[str], torch.Tensor], model: nn.Module) - } ) - # Dynamically add weights for text transformer blocks - for i in range(len(model.text["blocks"])): - prefix = f"text_model.transformer.h.{i}" - blk = model.text["blocks"][i] - weight_map.update( - { - f"{prefix}.ln.weight": blk["ln"].weight, - f"{prefix}.ln.bias": blk["ln"].bias, - f"{prefix}.mixer.Wqkv.weight": blk["attn"]["qkv"].weight, - f"{prefix}.mixer.Wqkv.bias": blk["attn"]["qkv"].bias, - f"{prefix}.mixer.out_proj.weight": blk["attn"]["proj"].weight, - f"{prefix}.mixer.out_proj.bias": blk["attn"]["proj"].bias, - f"{prefix}.mlp.fc1.weight": blk["mlp"]["fc1"].weight, - f"{prefix}.mlp.fc1.bias": blk["mlp"]["fc1"].bias, - f"{prefix}.mlp.fc2.weight": blk["mlp"]["fc2"].weight, - f"{prefix}.mlp.fc2.bias": blk["mlp"]["fc2"].bias, - } - ) - - # Copy data from loaded tensors to model parameters - for key, tensor_target in weight_map.items(): - source_tensor = get_tensor(key) # get_tensor is expected to provide fp16 tensors if model is fp16 - # because the lambda passed to _load_weights will do .to(fp16) - tensor_target.data.copy_(source_tensor) - - # Special handling for transposed weights - coord_features_weight = get_tensor("region_model.coordinate_features.weight") - region.coord_features.data.copy_(coord_features_weight.T) - - size_features_weight = get_tensor("region_model.size_features.weight") - region.size_features.data.copy_(size_features_weight.T) + if not is_quantized: + for i in range(len(model.text["blocks"])): + prefix = f"text_model.transformer.h.{i}" + blk = model.text["blocks"][i] + weight_map.update( + { + f"{prefix}.ln.weight": blk["ln"].weight, + f"{prefix}.ln.bias": blk["ln"].bias, + f"{prefix}.mixer.Wqkv.weight": blk["attn"]["qkv"].weight, + f"{prefix}.mixer.Wqkv.bias": blk["attn"]["qkv"].bias, + f"{prefix}.mixer.out_proj.weight": blk["attn"]["proj"].weight, + f"{prefix}.mixer.out_proj.bias": blk["attn"]["proj"].bias, + f"{prefix}.mlp.fc1.weight": blk["mlp"]["fc1"].weight, + f"{prefix}.mlp.fc1.bias": blk["mlp"]["fc1"].bias, + f"{prefix}.mlp.fc2.weight": blk["mlp"]["fc2"].weight, + f"{prefix}.mlp.fc2.bias": blk["mlp"]["fc2"].bias, + } + ) + else: # add special quantized path. this is specific to how bitblas expects weights to be loaded (.qweight) + for i in range(len(model.text["blocks"])): + prefix = f"text_model.transformer.h.{i}" + blk = model.text["blocks"][i] + weight_map.update( + { + f"{prefix}.ln.qweight": blk["ln"].weight, + f"{prefix}.ln.bias": blk["ln"].bias, + f"{prefix}.mixer.Wqkv.qweight": blk["attn"]["qkv"].weight, + f"{prefix}.mixer.Wqkv.bias": blk["attn"]["qkv"].bias, + f"{prefix}.mixer.out_proj.qweight": blk["attn"]["proj"].weight, + f"{prefix}.mixer.out_proj.bias": blk["attn"]["proj"].bias, + f"{prefix}.mlp.fc1.qweight": blk["mlp"]["fc1"].weight, + f"{prefix}.mlp.fc1.bias": blk["mlp"]["fc1"].bias, + f"{prefix}.mlp.fc2.qweight": blk["mlp"]["fc2"].weight, + f"{prefix}.mlp.fc2.bias": blk["mlp"]["fc2"].bias, + } + ) + + + for key, tensor in weight_map.items(): + tensor.data.copy_(get_tensor(key)) + + region.coord_features.data.copy_( + get_tensor("region_model.coordinate_features.weight").T + ) + region.size_features.data.copy_(get_tensor("region_model.size_features.weight").T) def load_weights_from_safetensors(weights_file: str, model: nn.Module) -> None: - """Load weights from a safetensors file into the model instance, - with support for quantized models.""" - device = next(model.parameters()).device - print(f"Loading .safetensors file to CPU first (target device: {device})...") - + """Load weights from a safetensors file into a MoondreamModel instance.""" with safetensors_open(weights_file) as get_tensor: all_keys = get_tensor.keys() - - # Detect if the model is quantized by inspecting tensor keys + is_quantized = any('.qweight' in key or '_quantized' in key or 'quant.' in key for key in all_keys) - if is_quantized: - print("Quantized model detected from safetensors file keys.") - else: - print("Non-quantized model detected from safetensors file keys.") - - if hasattr(model, 'text') and model.text is None: - # Determine the dtype for the text model based on quantization status - text_dtype = torch.int8 if is_quantized else torch.float16 - print(f"Building text model with dtype: {text_dtype} for safetensors loading.") - # Assuming build_text_model can handle the specified dtype (e.g., int8 for quantized) - model.text = build_text_model(model.config.text, dtype=text_dtype) - model.text.to(device) # Move the newly built text model to the target device - - if hasattr(model, 'setup_caches_flag') and model.setup_caches_flag: - model._setup_caches() - print("Caches set up for text model.") - - elif hasattr(model, 'text') and model.text is not None: - # Text model already exists, check if its dtype needs adjustment - current_text_dtype = next(model.text.parameters()).dtype - print(f"Text model already exists. Current text model dtype: {current_text_dtype}") - if not is_quantized and current_text_dtype != torch.float16: - # If loading non-quantized weights and existing text model is not fp16, convert it - print(f"Converting existing text model to float16 for non-quantized safetensors.") - model.text.to(dtype=torch.float16) - - - # --- Weight loading logic --- - if is_quantized: - print("Loading state_dict for quantized safetensors model.") - - tensors_processed = {} - for k in all_keys: - cleaned_key = k - if cleaned_key.startswith("model."): - cleaned_key = cleaned_key[len("model."):] + text_dtype = torch.int8 if is_quantized else torch.float16 + model.text = build_text_model( + TextConfig, text_dtype + ) + if model.setup_caches_flag: + model._setup_caches() - if "._orig_mod" in cleaned_key: # A more careful replacement might be needed - cleaned_key = cleaned_key.replace("._orig_mod", "") - - tensor = get_tensor(k).to(device) # Move tensor to target device - tensors_processed[cleaned_key] = tensor - - model.load_state_dict(tensors_processed, strict=False) + if ( + "vision.blocks.0.attn.proj.bias" in all_keys + or "model.vision.blocks.0.attn.proj.bias" in all_keys + ): + with safetensors_open(weights_file) as get_tensor: + tensors = { + k.replace("model.", ""): get_tensor(k) for k in all_keys + } + model.load_state_dict(tensors, strict=False) else: - - is_direct_load_style = False - - if any(key_name in all_keys for key_name in ["vision.blocks.0.attn.proj.bias", "model.vision.blocks.0.attn.proj.bias"]): - is_direct_load_style = True - - if is_direct_load_style: - print("Using load_state_dict for non-quantized safetensors (direct state_dict loading path).") - model.to(dtype=torch.float16, device=device) - - tensors_processed = {} - for k in all_keys: - cleaned_key = k - if cleaned_key.startswith("model."): - cleaned_key = cleaned_key[len("model."):] - - tensor = get_tensor(k).to(dtype=torch.float16, device=device) - tensors_processed[cleaned_key] = tensor - model.load_state_dict(tensors_processed, strict=False) - else: - print("Using _load_weights for non-quantized safetensors (custom mapping path).") - - name_map = {key.replace("._orig_mod", ""): key for key in all_keys} - - def get_tensor_for_load_weights(name: str) -> torch.Tensor: - source_key = name_map.get(name) - if source_key is None: - # This case should ideally not happen if _load_weights uses correct keys its a fallback - raise KeyError(f"Key {name} not found in safetensor after mapping. Available mapped keys: {list(name_map.keys())}") - return get_tensor(source_key).to(dtype=torch.float16, device=device) - - _load_weights(get_tensor_for_load_weights, model) - - model.to(device) - print("✓ Successfully loaded weights from safetensors file!") + # Wrap the get_tensor function to handle key normalization + name_map = {k.replace("._orig_mod", ""): k for k in all_keys} + _load_weights( + lambda x: get_tensor(name_map[x]).to(dtype=torch.float16), model, is_quantized + ) def load_weights_from_pt(weights_file: str, model: nn.Module) -> None: - """Load weights from a PyTorch (.pt) file into the model instance.""" - device = next(model.parameters()).device - print(f"Loading .pt file to CPU first to conserve GPU memory (target device: {device})...") - - state_dict_on_cpu = torch.load(weights_file, map_location='cpu', weights_only=True) - - is_quantized = any('.qweight' in key or '_quantized' in key or 'quant.' in key for key in state_dict_on_cpu.keys()) - if is_quantized: - print("Quantized model detected from .pt file keys.") - - if hasattr(model, 'text') and model.text is None: - text_dtype = torch.int8 if is_quantized else torch.float16 - print(f"Building text model with dtype: {text_dtype} for .pt loading.") - model.text = build_text_model(model.config.text, dtype=text_dtype) - model.text.to(device) - - if hasattr(model, 'setup_caches_flag') and model.setup_caches_flag: - model._setup_caches() - print("Caches set up for text model.") - elif hasattr(model, 'text') and model.text is not None: - current_text_dtype = next(model.text.parameters()).dtype - print(f"Text model already exists. Current text model dtype: {current_text_dtype}") - if not is_quantized and current_text_dtype != torch.float16: - print(f"Converting existing text model to float16.") - model.text.to(dtype=torch.float16) - - if not is_quantized: - print( - "Model is not quantized. Loading weights from PyTorch file using _load_weights. This may take a while, please be patient." - ) - - processed_tensors = { + """Load weights from a PyTorch file into a MoondreamModel instance.""" + device = str(torch.empty(0).device) + tensors = torch.load(weights_file, map_location='cpu', weights_only=True) + is_quantized = any('.qweight' in key or '_quantized' in key or 'quant.' in key for key in tensors.keys()) + + text_dtype = torch.int8 if is_quantized else torch.float16 + model.text = build_text_model( + TextConfig, text_dtype + ) + if model.setup_caches_flag: + model._setup_caches() + + if "vision.blocks.0.attn.proj.bias" in tensors.keys(): + model.load_state_dict(tensors, strict=False) + else: + tensors = { k.replace("._orig_mod", ""): v.to(dtype=torch.float16) - for k, v in state_dict_on_cpu.items() + for k, v in tensors.items() } - _load_weights(lambda x: processed_tensors[x], model) - del processed_tensors # Clean up intermediate dictionary - - else: # Quantized path - print("Loading state_dict (from CPU) into model (on device) for quantized model...") - - model.load_state_dict(state_dict_on_cpu, strict=False) + _load_weights(lambda x: tensors[x], model, is_quantized) - del state_dict_on_cpu - gc.collect() - print("Cleaned up original CPU state_dict from .pt loading.") - -def load_weights_into_model(weights_file: str, model: nn.Module) -> nn.Module: - """ - Main function to load weights into a model. - Determines file type and calls the appropriate loading function. +def load_weights_into_model(weights_file: str, model: nn.Module) -> None: """ - device = next(iter(model.parameters())).device - print(f"Starting weight loading process for model on {device}...") + Load weights from either a safetensors or PyTorch file directly into a MoondreamModel instance. - if weights_file.endswith('.pt'): - load_weights_from_pt(weights_file, model) - elif weights_file.endswith('.safetensors'): - print("Loading .safetensors file...") + Args: + weights_file: Path to weights file (either .safetensors or .pt) + model: MoondreamModel instance to load weights into + """ + if weights_file.endswith(".safetensors"): load_weights_from_safetensors(weights_file, model) else: - print(f"Unsupported weights file format: {weights_file}. Please use .pt or .safetensors.") - return model - - print("✓✓ Overall weight loading process complete!") - return model + load_weights_from_pt(weights_file, model) + # Make all parameters contiguous + for param in model.parameters(): + param.data = param.data.contiguous() \ No newline at end of file From 1c7f9e9ee3f8fa16589d455ad5b16613c0232345 Mon Sep 17 00:00:00 2001 From: snowclipsed Date: Mon, 19 May 2025 18:37:54 -0700 Subject: [PATCH 3/8] clean sample.py and text.py --- moondream/torch/sample.py | 12 +++++++----- moondream/torch/text.py | 5 +++-- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/moondream/torch/sample.py b/moondream/torch/sample.py index ab4026f2..40440f2d 100644 --- a/moondream/torch/sample.py +++ b/moondream/torch/sample.py @@ -26,9 +26,9 @@ args = parser.parse_args() if torch.cuda.is_available(): - torch.set_default_device("cuda") + device = "cuda" elif torch.backends.mps.is_available(): - torch.set_default_device("mps") + device = "mps" # Load model. if args.config is not None: @@ -46,10 +46,11 @@ if not os.path.exists(image_path): raise FileNotFoundError(f"Image not found at {image_path}") image = Image.open(image_path) + model = model.to(device) if not args.benchmark: - model.compile() + # model.compile() encoded_image = model.encode_image(image) # Short caption @@ -103,7 +104,7 @@ # Detect gaze model.detect_gaze(encoded_image, (0.5, 0.5)) - else: + elif model.device.type != "mps": torch._dynamo.reset() model.compile() @@ -149,4 +150,5 @@ print("\nQuery Speed (tokens/sec):") print(f" Mean: {sum(query_speeds)/len(query_speeds):.2f}") print(f" Min: {min(query_speeds):.2f}") - print(f" Max: {max(query_speeds):.2f}") \ No newline at end of file + else: + raise ValueError("To run benchmarks, make sure you are on a CUDA device") \ No newline at end of file diff --git a/moondream/torch/text.py b/moondream/torch/text.py index 4fb31021..aadf5665 100644 --- a/moondream/torch/text.py +++ b/moondream/torch/text.py @@ -54,7 +54,6 @@ def attn( ) out = out.transpose(1, 2).reshape(bsz, q_len, d_model) out = w.proj(out) - # print("out", out[:5, :5, :5]) return out @@ -162,17 +161,19 @@ def _lm_head(hidden_BTC: torch.Tensor, w: nn.Module): return logits -def build_text_model(config: TextConfig, dtype: torch.dtype, group_size: int = 128) -> nn.Module: +def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module: qkv_dim = int(config.dim * (1 + 2 * config.n_kv_heads / config.n_heads)) operator_cache = None cache_dir = None + group_size = None layernorm_dtype = torch.float16 if dtype == torch.int8: print("INITIALIZING QUANTIZED MODEL") operator_cache = OperatorCache() cache_dir = config.cache_dir + group_size = config.group_size def create_linear(in_features, out_features, dtype=dtype): From dc0251cd9e63641966adec553c66d52806439c03 Mon Sep 17 00:00:00 2001 From: snowclipsed Date: Mon, 19 May 2025 19:05:59 -0700 Subject: [PATCH 4/8] allow setting layernorm and linear dtypes independently and dynamically --- moondream/torch/text.py | 7 +++---- moondream/torch/weights.py | 26 +++++++++++++++++++------- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/moondream/torch/text.py b/moondream/torch/text.py index aadf5665..1787aa32 100644 --- a/moondream/torch/text.py +++ b/moondream/torch/text.py @@ -161,14 +161,13 @@ def _lm_head(hidden_BTC: torch.Tensor, w: nn.Module): return logits -def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module: +def build_text_model(config: TextConfig, linear_dtype: torch.dtype = torch.float16, layernorm_dtype:torch.dtype = torch.float16) -> nn.Module: # note : layernorm dtype is used for layernorm, lm_head and wte not just layernorm qkv_dim = int(config.dim * (1 + 2 * config.n_kv_heads / config.n_heads)) operator_cache = None cache_dir = None group_size = None - layernorm_dtype = torch.float16 - if dtype == torch.int8: + if linear_dtype == torch.int8: print("INITIALIZING QUANTIZED MODEL") operator_cache = OperatorCache() @@ -176,7 +175,7 @@ def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module: group_size = config.group_size - def create_linear(in_features, out_features, dtype=dtype): + def create_linear(in_features, out_features, dtype=linear_dtype): # factory function for creating Linear layers so we dont have to pass everything again and again return Linear( in_features=in_features, diff --git a/moondream/torch/weights.py b/moondream/torch/weights.py index d32a3ec6..b565d6db 100644 --- a/moondream/torch/weights.py +++ b/moondream/torch/weights.py @@ -148,9 +148,16 @@ def load_weights_from_safetensors(weights_file: str, model: nn.Module) -> None: is_quantized = any('.qweight' in key or '_quantized' in key or 'quant.' in key for key in all_keys) - text_dtype = torch.int8 if is_quantized else torch.float16 + + if "text_model.transformer.h.0.ln.weight" in all_keys: + layernorm_dtype = get_tensor("text_model.transformer.h.0.ln.weight").dtype + else: + layernorm_dtype = torch.float16 + + linear_dtype = torch.int8 if is_quantized else torch.float16 + model.text = build_text_model( - TextConfig, text_dtype + TextConfig, linear_dtype=linear_dtype, layernorm_dtype=layernorm_dtype ) if model.setup_caches_flag: model._setup_caches() @@ -174,18 +181,23 @@ def load_weights_from_safetensors(weights_file: str, model: nn.Module) -> None: def load_weights_from_pt(weights_file: str, model: nn.Module) -> None: """Load weights from a PyTorch file into a MoondreamModel instance.""" - device = str(torch.empty(0).device) tensors = torch.load(weights_file, map_location='cpu', weights_only=True) - is_quantized = any('.qweight' in key or '_quantized' in key or 'quant.' in key for key in tensors.keys()) + all_keys = tensors.keys() + is_quantized = any('.qweight' in key or '_quantized' in key or 'quant.' in key for key in all_keys) + + if "text.blocks.0.ln.weight" in all_keys: + layernorm_dtype = tensors["text.blocks.0.ln.weight"].dtype + else: + layernorm_dtype = torch.float16 - text_dtype = torch.int8 if is_quantized else torch.float16 + linear_dtype = torch.int8 if is_quantized else torch.float16 model.text = build_text_model( - TextConfig, text_dtype + TextConfig, linear_dtype=linear_dtype, layernorm_dtype=layernorm_dtype ) if model.setup_caches_flag: model._setup_caches() - if "vision.blocks.0.attn.proj.bias" in tensors.keys(): + if "vision.blocks.0.attn.proj.bias" in all_keys: model.load_state_dict(tensors, strict=False) else: tensors = { From 6b3caeb4ff69ef51b86cd8b6bf6bc374e788a3ea Mon Sep 17 00:00:00 2001 From: snowclipsed Date: Mon, 19 May 2025 19:58:52 -0700 Subject: [PATCH 5/8] remove unnecessary debug statement --- moondream/torch/text.py | 1 - 1 file changed, 1 deletion(-) diff --git a/moondream/torch/text.py b/moondream/torch/text.py index 1787aa32..c966043d 100644 --- a/moondream/torch/text.py +++ b/moondream/torch/text.py @@ -169,7 +169,6 @@ def build_text_model(config: TextConfig, linear_dtype: torch.dtype = torch.float group_size = None if linear_dtype == torch.int8: - print("INITIALIZING QUANTIZED MODEL") operator_cache = OperatorCache() cache_dir = config.cache_dir group_size = config.group_size From 2eb349bebc4af6d61c8ef43e1160a5514a1a22d1 Mon Sep 17 00:00:00 2001 From: snowclipsed Date: Mon, 19 May 2025 22:15:41 -0700 Subject: [PATCH 6/8] run through black formatter --- moondream/torch/config.py | 2 +- moondream/torch/layers.py | 26 +++++++++++++++++-------- moondream/torch/moondream.py | 12 +++++++----- moondream/torch/sample.py | 9 +++++---- moondream/torch/text.py | 31 +++++++++++++++--------------- moondream/torch/weights.py | 37 +++++++++++++++++++++--------------- 6 files changed, 69 insertions(+), 48 deletions(-) diff --git a/moondream/torch/config.py b/moondream/torch/config.py index 01570010..e7d18ce7 100644 --- a/moondream/torch/config.py +++ b/moondream/torch/config.py @@ -85,4 +85,4 @@ def to_dict(self): "vision": self.vision.__dict__, "region": self.region.__dict__, "tokenizer": self.tokenizer.__dict__, - } \ No newline at end of file + } diff --git a/moondream/torch/layers.py b/moondream/torch/layers.py index 9dfdaf56..6949c7ca 100644 --- a/moondream/torch/layers.py +++ b/moondream/torch/layers.py @@ -18,16 +18,26 @@ class LinearWeights: weight: torch.Tensor bias: torch.Tensor + class Linear(nn.Module): """ Linear layer with support for bitblas quantization. If dtype is torch.int8, it uses bitblas for quantization. Otherwise, it uses a standard nn.Linear layer. """ - def __init__(self, in_features:int, out_features:int, bias: bool = True, - dtype:torch.dtype=None, operator_cache:OperatorCache=None, cache_dir:str=None, group_size:int=128): + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + operator_cache: OperatorCache = None, + cache_dir: str = None, + group_size: int = 128, + ): super().__init__() - + if dtype == torch.int8: self.linear = bitblas.Linear( in_features=in_features, @@ -51,11 +61,12 @@ def __init__(self, in_features:int, out_features:int, bias: bool = True, in_features=in_features, out_features=out_features, bias=bias, - dtype=torch.float16 - ) + dtype=torch.float16, + ) + def forward(self, x): return self.linear(x) - + @property def weight(self) -> torch.Tensor: try: @@ -63,7 +74,6 @@ def weight(self) -> torch.Tensor: except AttributeError: return self.linear.qweight - @property def bias(self) -> torch.Tensor: return self.linear.bias @@ -115,4 +125,4 @@ def attn(x: torch.Tensor, w: AttentionWeights, n_heads: int) -> torch.Tensor: out = F.scaled_dot_product_attention(q, k, v) out = out.transpose(1, 2).reshape(bsz, q_len, d_model) out = linear(out, w.proj) - return out \ No newline at end of file + return out diff --git a/moondream/torch/moondream.py b/moondream/torch/moondream.py index 91d07811..1190642a 100644 --- a/moondream/torch/moondream.py +++ b/moondream/torch/moondream.py @@ -133,7 +133,7 @@ def _setup_caches(self): """Setup KV caches for the text model""" if self.text is None: return # Can't set up caches without text model - + c = self.config.text for b in self.text.blocks: b.kv_cache = KVCache( @@ -167,7 +167,9 @@ def _decode_one_tok( def compile(self): # TODO: vision_projection is not being compiled - self._vis_enc = torch.compile(self._vis_enc, fullgraph=False, mode="reduce-overhead") + self._vis_enc = torch.compile( + self._vis_enc, fullgraph=False, mode="reduce-overhead" + ) # self._prefill = torch.compile(self._prefill) # self._decode_one_tok = torch.compile(self._decode_one_tok) @@ -213,7 +215,7 @@ def encode_image(self, image: Union[Image.Image, EncodedImage]) -> EncodedImage: mask = self.attn_mask[:, :, 0 : inputs_embeds.size(1), :] pos_ids = torch.arange(inputs_embeds.size(1), dtype=torch.long) self._prefill(inputs_embeds, mask, pos_ids) - + return EncodedImage( pos=inputs_embeds.size(1), caches=[ @@ -237,8 +239,8 @@ def _apply_top_p(self, probs: torch.Tensor, top_p: float): def _prefill_prompt( self, prompt_tokens: torch.Tensor, pos: int, temperature: float, top_p: float - ): - + ): + with torch.inference_mode(): prompt_emb = text_encoder(prompt_tokens, self.text) torch._dynamo.mark_dynamic(prompt_emb, 1) diff --git a/moondream/torch/sample.py b/moondream/torch/sample.py index 40440f2d..cd789bd7 100644 --- a/moondream/torch/sample.py +++ b/moondream/torch/sample.py @@ -7,7 +7,8 @@ from tqdm import tqdm import logging import bitblas -bitblas.logger.setLevel('FATAL') + +bitblas.logger.setLevel("INFO") from .weights import load_weights_into_model from .moondream import MoondreamModel, MoondreamConfig @@ -37,7 +38,7 @@ config = MoondreamConfig.from_dict(config) else: config = MoondreamConfig() - + model = MoondreamModel(config) load_weights_into_model(args.model, model) @@ -150,5 +151,5 @@ print("\nQuery Speed (tokens/sec):") print(f" Mean: {sum(query_speeds)/len(query_speeds):.2f}") print(f" Min: {min(query_speeds):.2f}") - else: - raise ValueError("To run benchmarks, make sure you are on a CUDA device") \ No newline at end of file + else: + raise ValueError("To run benchmarks, make sure you are on a CUDA device") diff --git a/moondream/torch/text.py b/moondream/torch/text.py index c966043d..ee97a1ba 100644 --- a/moondream/torch/text.py +++ b/moondream/torch/text.py @@ -161,7 +161,13 @@ def _lm_head(hidden_BTC: torch.Tensor, w: nn.Module): return logits -def build_text_model(config: TextConfig, linear_dtype: torch.dtype = torch.float16, layernorm_dtype:torch.dtype = torch.float16) -> nn.Module: # note : layernorm dtype is used for layernorm, lm_head and wte not just layernorm +def build_text_model( + config: TextConfig, + linear_dtype: torch.dtype = torch.float16, + layernorm_dtype: torch.dtype = torch.float16, +) -> ( + nn.Module +): # note : layernorm dtype is used for layernorm, lm_head and wte not just layernorm qkv_dim = int(config.dim * (1 + 2 * config.n_kv_heads / config.n_heads)) operator_cache = None @@ -173,9 +179,8 @@ def build_text_model(config: TextConfig, linear_dtype: torch.dtype = torch.float cache_dir = config.cache_dir group_size = config.group_size - def create_linear(in_features, out_features, dtype=linear_dtype): - # factory function for creating Linear layers so we dont have to pass everything again and again + # factory function for creating Linear layers so we dont have to pass everything again and again return Linear( in_features=in_features, out_features=out_features, @@ -184,7 +189,6 @@ def create_linear(in_features, out_features, dtype=linear_dtype): cache_dir=cache_dir, group_size=group_size, ) - text = nn.ModuleDict( { @@ -196,16 +200,13 @@ def create_linear(in_features, out_features, dtype=linear_dtype): "attn": nn.ModuleDict( { "qkv": create_linear(config.dim, qkv_dim), - "proj": create_linear( - config.dim, config.dim) + "proj": create_linear(config.dim, config.dim), } ), "mlp": nn.ModuleDict( { - "fc1": create_linear( - config.dim, config.ff_dim), - "fc2": create_linear( - config.ff_dim, config.dim) + "fc1": create_linear(config.dim, config.ff_dim), + "fc2": create_linear(config.ff_dim, config.dim), } ), } @@ -214,16 +215,16 @@ def create_linear(in_features, out_features, dtype=linear_dtype): ] ), "post_ln": nn.LayerNorm(config.dim, dtype=layernorm_dtype), - "lm_head": nn.Linear( - config.dim, config.vocab_size, dtype=layernorm_dtype - ), + "lm_head": nn.Linear(config.dim, config.vocab_size, dtype=layernorm_dtype), } ) - text.wte = nn.Parameter(torch.empty(config.vocab_size, config.dim, dtype=layernorm_dtype)) + text.wte = nn.Parameter( + torch.empty(config.vocab_size, config.dim, dtype=layernorm_dtype) + ) text.register_buffer( "freqs_cis", precompute_freqs_cis(config.dim // (2 * config.n_heads), config.max_context), persistent=False, ) - return text \ No newline at end of file + return text diff --git a/moondream/torch/weights.py b/moondream/torch/weights.py index b565d6db..5f21565c 100644 --- a/moondream/torch/weights.py +++ b/moondream/torch/weights.py @@ -8,6 +8,7 @@ from .text import build_text_model from .config import TextConfig + @contextmanager def safetensors_open(safetensors_file: str): """ @@ -29,7 +30,11 @@ def get_keys() -> List[str]: yield get_tensor -def _load_weights(get_tensor: Callable[[str], torch.Tensor], model: nn.Module, is_quantized:bool=False) -> None: +def _load_weights( + get_tensor: Callable[[str], torch.Tensor], + model: nn.Module, + is_quantized: bool = False, +) -> None: """Internal function to load weights using a tensor getter function.""" model = model.to(dtype=torch.float16) @@ -93,8 +98,7 @@ def _load_weights(get_tensor: Callable[[str], torch.Tensor], model: nn.Module, i } ) - - if not is_quantized: + if not is_quantized: for i in range(len(model.text["blocks"])): prefix = f"text_model.transformer.h.{i}" blk = model.text["blocks"][i] @@ -112,7 +116,7 @@ def _load_weights(get_tensor: Callable[[str], torch.Tensor], model: nn.Module, i f"{prefix}.mlp.fc2.bias": blk["mlp"]["fc2"].bias, } ) - else: # add special quantized path. this is specific to how bitblas expects weights to be loaded (.qweight) + else: # add special quantized path. this is specific to how bitblas expects weights to be loaded (.qweight) for i in range(len(model.text["blocks"])): prefix = f"text_model.transformer.h.{i}" blk = model.text["blocks"][i] @@ -131,7 +135,6 @@ def _load_weights(get_tensor: Callable[[str], torch.Tensor], model: nn.Module, i } ) - for key, tensor in weight_map.items(): tensor.data.copy_(get_tensor(key)) @@ -145,10 +148,12 @@ def load_weights_from_safetensors(weights_file: str, model: nn.Module) -> None: """Load weights from a safetensors file into a MoondreamModel instance.""" with safetensors_open(weights_file) as get_tensor: all_keys = get_tensor.keys() - - is_quantized = any('.qweight' in key or '_quantized' in key or 'quant.' in key for key in all_keys) - + is_quantized = any( + ".qweight" in key or "_quantized" in key or "quant." in key + for key in all_keys + ) + if "text_model.transformer.h.0.ln.weight" in all_keys: layernorm_dtype = get_tensor("text_model.transformer.h.0.ln.weight").dtype else: @@ -167,23 +172,25 @@ def load_weights_from_safetensors(weights_file: str, model: nn.Module) -> None: or "model.vision.blocks.0.attn.proj.bias" in all_keys ): with safetensors_open(weights_file) as get_tensor: - tensors = { - k.replace("model.", ""): get_tensor(k) for k in all_keys - } + tensors = {k.replace("model.", ""): get_tensor(k) for k in all_keys} model.load_state_dict(tensors, strict=False) else: # Wrap the get_tensor function to handle key normalization name_map = {k.replace("._orig_mod", ""): k for k in all_keys} _load_weights( - lambda x: get_tensor(name_map[x]).to(dtype=torch.float16), model, is_quantized + lambda x: get_tensor(name_map[x]).to(dtype=torch.float16), + model, + is_quantized, ) def load_weights_from_pt(weights_file: str, model: nn.Module) -> None: """Load weights from a PyTorch file into a MoondreamModel instance.""" - tensors = torch.load(weights_file, map_location='cpu', weights_only=True) + tensors = torch.load(weights_file, map_location="cpu", weights_only=True) all_keys = tensors.keys() - is_quantized = any('.qweight' in key or '_quantized' in key or 'quant.' in key for key in all_keys) + is_quantized = any( + ".qweight" in key or "_quantized" in key or "quant." in key for key in all_keys + ) if "text.blocks.0.ln.weight" in all_keys: layernorm_dtype = tensors["text.blocks.0.ln.weight"].dtype @@ -222,4 +229,4 @@ def load_weights_into_model(weights_file: str, model: nn.Module) -> None: # Make all parameters contiguous for param in model.parameters(): - param.data = param.data.contiguous() \ No newline at end of file + param.data = param.data.contiguous() From 78d2b88e9a0f47b8359a4e02c5a3ce2abe3bb137 Mon Sep 17 00:00:00 2001 From: snowclipsed Date: Mon, 19 May 2025 22:32:48 -0700 Subject: [PATCH 7/8] changed logger level and max tokens default --- moondream/torch/sample.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/moondream/torch/sample.py b/moondream/torch/sample.py index cd789bd7..8c5189af 100644 --- a/moondream/torch/sample.py +++ b/moondream/torch/sample.py @@ -8,7 +8,7 @@ import logging import bitblas -bitblas.logger.setLevel("INFO") +bitblas.logger.setLevel("FATAL") from .weights import load_weights_into_model from .moondream import MoondreamModel, MoondreamConfig @@ -21,7 +21,7 @@ parser.add_argument("--prompt", "-p", type=str, required=True) parser.add_argument("--model", "-m", type=str, required=True) parser.add_argument("--config", "-c", type=str, default=None) - parser.add_argument("--max-tokens", "-t", type=int, default=100) + parser.add_argument("--max-tokens", "-t", type=int, default=200) parser.add_argument("--sampler", "-s", type=str, default="greedy") parser.add_argument("--benchmark", "-b", action="store_true") args = parser.parse_args() From bd9740d94723f463e5112701b851160bda45bb6a Mon Sep 17 00:00:00 2001 From: EthanReid Date: Tue, 20 May 2025 09:55:42 +0000 Subject: [PATCH 8/8] Fixed fp16 naming --- moondream/torch/weights.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/moondream/torch/weights.py b/moondream/torch/weights.py index 5f21565c..e1193443 100644 --- a/moondream/torch/weights.py +++ b/moondream/torch/weights.py @@ -1,6 +1,7 @@ import safetensors import torch import torch.nn as nn +import re from contextlib import contextmanager from typing import Callable, List @@ -9,6 +10,18 @@ from .config import TextConfig +# Our custom linear has an module named linear, so we add linear to the name +def add_linear_to_key(k: str) -> str: + k = k.replace("model.", "") + if k.startswith("text.") and ".linear." not in k: + k = re.sub( + r"(attn\.(?:qkv|proj)|mlp\.fc[12])\.(weight|bias)$", + r"\1.linear.\2", + k, + ) + return k + + @contextmanager def safetensors_open(safetensors_file: str): """ @@ -172,7 +185,7 @@ def load_weights_from_safetensors(weights_file: str, model: nn.Module) -> None: or "model.vision.blocks.0.attn.proj.bias" in all_keys ): with safetensors_open(weights_file) as get_tensor: - tensors = {k.replace("model.", ""): get_tensor(k) for k in all_keys} + tensors = {add_linear_to_key(k): get_tensor(k) for k in all_keys} model.load_state_dict(tensors, strict=False) else: # Wrap the get_tensor function to handle key normalization @@ -204,7 +217,11 @@ def load_weights_from_pt(weights_file: str, model: nn.Module) -> None: if model.setup_caches_flag: model._setup_caches() - if "vision.blocks.0.attn.proj.bias" in all_keys: + if ( + "vision.blocks.0.attn.proj.bias" in all_keys + or "model.vision.blocks.0.attn.proj.bias" in all_keys + ): + tensors = {add_linear_to_key(k): v for k, v in tensors.items()} model.load_state_dict(tensors, strict=False) else: tensors = {