diff --git a/Libraries/Embedders/Bert.swift b/Libraries/Embedders/Bert.swift
index aa9cdc5a..0bba6502 100644
--- a/Libraries/Embedders/Bert.swift
+++ b/Libraries/Embedders/Bert.swift
@@ -2,6 +2,7 @@
 import MLX
 import MLXNN
+import MLXLMCommon
 
 extension MLXArray {
     public static func arange(_ size: Int) -> MLXArray {
@@ -196,6 +197,10 @@ public class BertModel: Module, EmbeddingModel {
             result[key] = item.value
         }.filter { key, _ in key != "embeddings.position_ids" }
     }
+
+    public func sanitize(weights: [String: MLXArray], quantizationConfig: MLXLMCommon.BaseConfiguration.Quantization?) -> [String: MLXArray] {
+        // Bert checkpoints are not quantized; reuse the standard cleanup instead of crashing,
+        // since Load.swift now always calls this overload.
+        precondition(quantizationConfig == nil, "Bert does not support quantization")
+        return sanitize(weights: weights)
+    }
 }
 
 public class DistilBertModel: BertModel {
diff --git a/Libraries/Embedders/Configuration.swift b/Libraries/Embedders/Configuration.swift
index 358dcf82..473c7f07 100644
--- a/Libraries/Embedders/Configuration.swift
+++ b/Libraries/Embedders/Configuration.swift
@@ -1,6 +1,7 @@
 // Copyright © 2024 Apple Inc.
 
 import Foundation
+import MLXLLM
 
 public enum StringOrNumber: Codable, Equatable, Sendable {
     case string(String)
@@ -69,6 +70,13 @@ private class ModelTypeRegistry: @unchecked Sendable {
             let model = NomicBertModel(configuration)
             return model
         },
+        "gemma3_text": { url in
+            let configuration = try JSONDecoder().decode(
+                Gemma3TextConfiguration.self, from: Data(contentsOf: url))
+            let model = EmbeddingGemma(configuration)
+            return model
+        },
     ]
 
     public func registerModelType(
diff --git a/Libraries/Embedders/EmbeddingGemma.swift b/Libraries/Embedders/EmbeddingGemma.swift
new file mode 100644
index 00000000..75235bc3
--- /dev/null
+++ b/Libraries/Embedders/EmbeddingGemma.swift
@@ -0,0 +1,104 @@
+import MLX
+import MLXNN
+import MLXLLM
+import MLXLMCommon
+
+public class EmbeddingGemma: Module, EmbeddingModel {
+    @ModuleInfo private var model: Gemma3TextModel
+    @ModuleInfo private var dense: [Module]
+
+    public let config: Gemma3TextConfiguration
+    public var vocabularySize: Int { config.vocabularySize }
+
+    public init(_ config: Gemma3TextConfiguration) {
+        self.config = config
+        self.model = Gemma3TextModel(config)
+        self.dense = [
+            Linear(768, 3072, bias: false), Linear(3072, 768, bias: false)
+        ]
+    }
+
+    public func callAsFunction(
+        _ inputs: MLXArray, positionIds: MLXArray?, tokenTypeIds: MLXArray?,
+        attentionMask: MLXArray?
+    ) -> EmbeddingModelOutput {
+        // Note: attentionMask is currently unused; padding is inferred from token id 0.
+        var out = model.getHiddenStates(inputs, mask: nil, cache: nil)
+
+        // mean pooling over non-padding tokens
+        let notPadding = inputs .!= 0
+        let sum = (out * notPadding[.ellipsis, .newAxis]).sum(axis: 1)
+        let nonMasked = notPadding.sum(axis: -1, keepDims: true)
+        out = sum / nonMasked
+
+        for dense in self.dense {
+            if let dense = dense as? Linear {
+                out = dense(out)
+            } else if let dense = dense as? QuantizedLinear {
+                out = dense(out)
+            }
+        }
+
+        // L2-normalize the pooled embedding
+        out = out.asType(Float32.self)
+        let norm = maximum(norm(out, ord: 2.0, axis: -1, keepDims: true), MLXArray(1e-6))
+        let pooledOutput = out / norm
+
+        return EmbeddingModelOutput(hiddenStates: out, pooledOutput: pooledOutput)
+    }
+
+    /// Get hidden states before the dense projection head
+    public func getHiddenStates(
+        _ inputs: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode? = nil,
+        cache: [KVCache]? = nil
+    ) -> MLXArray {
+        // Use the text model's hidden states, not its lm_head logits.
+        return model.getHiddenStates(inputs, mask: mask, cache: cache)
+    }
+
+    public func sanitize(weights: [String: MLXArray], quantizationConfig: MLXLMCommon.BaseConfiguration.Quantization? = nil)
+        -> [String: MLXArray]
+    {
+        var processedWeights = model.sanitize(weights: weights, quantizationConfig: quantizationConfig)
+
+        // 1. Nest the inner model's keys under this wrapper's `model` child: keys beginning with
+        //    `model.` or `lm_head.` get an extra `model.` prefix
+        //    (e.g. `model.embed_tokens` -> `model.model.embed_tokens`).
+        processedWeights = Dictionary(uniqueKeysWithValues: processedWeights.map { key, value in
+            if key.hasPrefix("model.") || key.hasPrefix("lm_head.") {
+                return ("model.\(key)", value)
+            } else {
+                return (key, value)
+            }
+        })
+
+        // 2. Apply quantization to dense layers, if needed
+        let hasQuantizedDense = hasQuantizedWeights(layerPath: "dense.0", in: processedWeights)
+        if hasQuantizedDense {
+            let groupSize = quantizationConfig?.groupSize ?? 64
+            let bits = quantizationConfig?.bits ?? 4
+
+            quantize(model: self) { path, module in
+                if hasQuantizedWeights(layerPath: path, in: processedWeights) {
+                    return (groupSize, bits)
+                }
+                return nil
+            }
+        }
+
+        return processedWeights.filter { key, _ in
+            !key.contains("self_attn.rotary_emb.inv_freq")
+        }
+    }
+
+    public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
+        sanitize(weights: weights, quantizationConfig: nil)
+    }
+
+    /// Check if a layer has quantized weights
+    private func hasQuantizedWeights(layerPath: String, in weights: [String: MLXArray]) -> Bool {
+        let scalesKey = "\(layerPath).scales"
+        let biasesKey = "\(layerPath).biases"
+        let weightKey = "\(layerPath).weight"
+
+        let hasScales = weights[scalesKey] != nil
+        let hasBiases = weights[biasesKey] != nil
+        let hasWeight = weights[weightKey]?.dtype == .uint32
+
+        return hasScales && hasBiases && hasWeight
+    }
+}
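// Illustration (not part of the patch): constructing and calling EmbeddingGemma directly.
// The configuration values below are placeholders chosen only so the module builds
// (hiddenSize must be 768 to match the dense head); real models are created from
// config.json via the "gemma3_text" registry entry and the existing Embedders loading path.
import MLX
import MLXEmbedders
import MLXLLM

let exampleConfig = Gemma3TextConfiguration(
    modelType: "gemma3_text", hiddenSize: 768, hiddenLayers: 2, intermediateSize: 1152,
    attentionHeads: 4, headDim: 64, rmsNormEps: 1e-6, vocabularySize: 1_000, kvHeads: 1,
    ropeGlobalBaseFreq: 1_000_000, ropeLocalBaseFreq: 10_000, ropeTraditional: false,
    queryPreAttnScalar: 256, slidingWindow: 128, slidingWindowPattern: 6,
    useBidirectionalAttention: true)
let embedder = EmbeddingGemma(exampleConfig)

// Token id 0 is treated as padding by the mean-pooling step above.
let tokenIds = MLXArray([2, 10, 11, 12, 0, 0]).reshaped(1, 6)
let output = embedder(tokenIds, positionIds: nil, tokenTypeIds: nil, attentionMask: nil)
// output.pooledOutput is a [1, 768], L2-normalized embedding (meaningful only once real
// checkpoint weights have been loaded).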
diff --git a/Libraries/Embedders/EmbeddingModel.swift b/Libraries/Embedders/EmbeddingModel.swift
index 3c4fbed7..ab10e0c9 100644
--- a/Libraries/Embedders/EmbeddingModel.swift
+++ b/Libraries/Embedders/EmbeddingModel.swift
@@ -4,6 +4,7 @@ import Foundation
 @preconcurrency import Hub
 import MLX
 import MLXNN
+import MLXLMCommon
 import Tokenizers
 
 /// Container for models that guarantees single threaded access.
@@ -87,8 +88,8 @@ extension Module {
 }
 
 public struct EmbeddingModelOutput {
-    let hiddenStates: MLXArray?
-    let pooledOutput: MLXArray?
+    public let hiddenStates: MLXArray?
+    public let pooledOutput: MLXArray?
 }
 
 public protocol EmbeddingModel: Module {
@@ -99,6 +100,7 @@ public protocol EmbeddingModel: Module {
     ) -> EmbeddingModelOutput
     /// Optionally preprocess the weights and modify / remove values as needed.
     func sanitize(weights: [String: MLXArray]) -> [String: MLXArray]
+    func sanitize(weights: [String: MLXArray], quantizationConfig: MLXLMCommon.BaseConfiguration.Quantization?) -> [String: MLXArray]
 }
 
 extension EmbeddingModel {
diff --git a/Libraries/Embedders/Load.swift b/Libraries/Embedders/Load.swift
index 868cc6c4..ecab22d0 100644
--- a/Libraries/Embedders/Load.swift
+++ b/Libraries/Embedders/Load.swift
@@ -4,6 +4,7 @@ import Foundation
 @preconcurrency import Hub
 import MLX
 import MLXNN
+import MLXLMCommon
 import Tokenizers
 
 struct EmbedderError: Error {
@@ -60,6 +61,8 @@ func loadSynchronous(modelDirectory: URL) throws -> EmbeddingModel {
     let configurationURL = modelDirectory.appending(component: "config.json")
     let baseConfig = try JSONDecoder().decode(
         BaseConfiguration.self, from: Data(contentsOf: configurationURL))
+    let commonBaseConfig = try JSONDecoder().decode(
+        MLXLMCommon.BaseConfiguration.self, from: Data(contentsOf: configurationURL))
 
     let modelType = ModelType(rawValue: baseConfig.modelType)
     let model = try modelType.createModel(configuration: configurationURL)
@@ -78,7 +81,7 @@ func loadSynchronous(modelDirectory: URL) throws -> EmbeddingModel {
     }
 
     // per-model cleanup
-    weights = model.sanitize(weights: weights)
+    weights = model.sanitize(weights: weights, quantizationConfig: commonBaseConfig.quantization)
 
     // quantize if needed
     if let perLayerQuantization = baseConfig.perLayerQuantization {
diff --git a/Libraries/Embedders/Models.swift b/Libraries/Embedders/Models.swift
index 3c47a83c..fbc2dfd6 100644
--- a/Libraries/Embedders/Models.swift
+++ b/Libraries/Embedders/Models.swift
@@ -108,6 +108,14 @@ extension ModelConfiguration {
     public static let bge_m3 = ModelConfiguration(id: "BAAI/bge-m3")
     public static let mixedbread_large = ModelConfiguration(
         id: "mixedbread-ai/mxbai-embed-large-v1")
+    public static let embeddinggemma_300m = ModelConfiguration(
+        id: "mlx-community/embeddinggemma-300m-bf16")
+    public static let embeddinggemma_300m_8bit = ModelConfiguration(
+        id: "mlx-community/embeddinggemma-300m-8bit")
+    public static let embeddinggemma_300m_6bit = ModelConfiguration(
+        id: "mlx-community/embeddinggemma-300m-6bit")
+    public static let embeddinggemma_300m_4bit = ModelConfiguration(
+        id: "mlx-community/embeddinggemma-300m-4bit")
 
     private enum BootstrapState: Sendable {
         case idle
@@ -138,6 +146,10 @@ extension ModelConfiguration {
                 snowflake_lg,
                 bge_m3,
                 mixedbread_large,
+                embeddinggemma_300m,
+                embeddinggemma_300m_8bit,
+                embeddinggemma_300m_6bit,
+                embeddinggemma_300m_4bit,
             ])
 
             bootstrapState = .bootstrapped
diff --git a/Libraries/Embedders/NomicBert.swift b/Libraries/Embedders/NomicBert.swift
index f98a69be..5107570d 100644
--- a/Libraries/Embedders/NomicBert.swift
+++ b/Libraries/Embedders/NomicBert.swift
@@ -3,6 +3,7 @@ import Foundation
 
 import MLX
 import MLXNN
+import MLXLMCommon
 
 class NomicEmbedding: Module {
 
@@ -390,6 +391,10 @@ public class NomicBertModel: Module, EmbeddingModel {
             result[key] = item.value
         }
     }
+
+    public func sanitize(weights: [String: MLXArray], quantizationConfig: MLXLMCommon.BaseConfiguration.Quantization?) -> [String: MLXArray] {
+        // Nomic checkpoints are not quantized; reuse the standard cleanup instead of crashing,
+        // since Load.swift now always calls this overload.
+        precondition(quantizationConfig == nil, "Nomic does not support quantization")
+        return sanitize(weights: weights)
+    }
 }
 
 public struct NomicBertConfiguration: Decodable, Sendable {
diff --git a/Libraries/MLXLLM/Models/Gemma3Text.swift b/Libraries/MLXLLM/Models/Gemma3Text.swift
index eb953ec4..9d2f8b7b 100644
--- a/Libraries/MLXLLM/Models/Gemma3Text.swift
+++ b/Libraries/MLXLLM/Models/Gemma3Text.swift
@@ -14,22 +14,48 @@ import MLXLLM
 import MLXLMCommon
 import MLXNN
 
+/// Create a bidirectional sliding window mask where tokens can attend to others within the sliding window distance
+public func createBidirectionalSlidingWindowMask(
+    n: Int,
+    offset: Int,
+    windowSize: Int
+) -> MLXArray {
+    let rinds = MLXArray(Int32(0) ..< Int32(offset + n))
+    var linds = offset != 0 ? MLXArray(Int32(offset) ..< Int32(offset + n)) : rinds
+    linds = linds[0..., .newAxis]
+    let rindsBcast = rinds[.newAxis]
+
+    // Create mask where abs(q_idx - kv_idx) < windowSize (bidirectional window)
+    let distance = abs(linds - rindsBcast)
+    let mask = distance .< windowSize
+
+    return mask
+}
+
+/// Reference attention path that accepts an explicit boolean mask (fast SDPA does not).
+func simpleSDPA(queries: MLXArray, keys: MLXArray, values: MLXArray, mask: MLXArray, scale: Float) -> MLXArray {
+    var attn = matmul(queries, keys.transposed(0, 1, 3, 2))
+    // Push masked-out positions to a large negative value before the softmax.
+    attn = attn - (1 - mask) * 1e6
+    let weights = softmax(scale * attn, axis: -1)
+    return matmul(weights, values)
+}
+
 public struct Gemma3TextConfiguration: Codable {
-    let modelType: String
-    let hiddenSize: Int
-    let hiddenLayers: Int
-    let intermediateSize: Int
-    let attentionHeads: Int
-    let headDim: Int
-    let rmsNormEps: Float
-    let vocabularySize: Int
-    let kvHeads: Int
-    let ropeGlobalBaseFreq: Float
-    let ropeLocalBaseFreq: Float
-    let ropeTraditional: Bool
-    let queryPreAttnScalar: Float
-    let slidingWindow: Int
-    let slidingWindowPattern: Int
+    public let modelType: String
+    public let hiddenSize: Int
+    public let hiddenLayers: Int
+    public let intermediateSize: Int
+    public let attentionHeads: Int
+    public let headDim: Int
+    public let rmsNormEps: Float
+    public let vocabularySize: Int
+    public let kvHeads: Int
+    public let ropeGlobalBaseFreq: Float
+    public let ropeLocalBaseFreq: Float
+    public let ropeTraditional: Bool
+    public let queryPreAttnScalar: Float
+    public let slidingWindow: Int
+    public let slidingWindowPattern: Int
+    public let useBidirectionalAttention: Bool
 
     enum CodingKeys: String, CodingKey {
         case modelType = "model_type"
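// Illustration (not part of the patch): what the bidirectional window produces. For n = 5,
// offset = 0, windowSize = 2, position i may attend to position j iff |i - j| < 2, in both
// directions (unlike a causal mask). Assumes the definitions above are in scope (import MLXLLM).
let exampleWindowMask = createBidirectionalSlidingWindowMask(n: 5, offset: 0, windowSize: 2)
// exampleWindowMask has shape [5, 5]; e.g. row 2 is [false, true, true, true, false].
// simpleSDPA consumes such a boolean mask by pushing the false positions to a large negative
// logit before the softmax.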
@@ -47,12 +73,32 @@ public struct Gemma3TextConfiguration: Codable {
         case queryPreAttnScalar = "query_pre_attn_scalar"
         case slidingWindow = "sliding_window"
         case slidingWindowPattern = "sliding_window_pattern"
+        case useBidirectionalAttention = "use_bidirectional_attention"
     }
 
     enum VLMCodingKeys: String, CodingKey {
         case textConfig = "text_config"
     }
 
+    public init(modelType: String, hiddenSize: Int, hiddenLayers: Int, intermediateSize: Int, attentionHeads: Int, headDim: Int, rmsNormEps: Float, vocabularySize: Int, kvHeads: Int, ropeGlobalBaseFreq: Float, ropeLocalBaseFreq: Float, ropeTraditional: Bool, queryPreAttnScalar: Float, slidingWindow: Int, slidingWindowPattern: Int, useBidirectionalAttention: Bool, quantizationConfig: QuantizationConfig? = nil) {
+        self.modelType = modelType
+        self.hiddenSize = hiddenSize
+        self.hiddenLayers = hiddenLayers
+        self.intermediateSize = intermediateSize
+        self.attentionHeads = attentionHeads
+        self.headDim = headDim
+        self.rmsNormEps = rmsNormEps
+        self.vocabularySize = vocabularySize
+        self.kvHeads = kvHeads
+        self.ropeGlobalBaseFreq = ropeGlobalBaseFreq
+        self.ropeLocalBaseFreq = ropeLocalBaseFreq
+        self.ropeTraditional = ropeTraditional
+        self.queryPreAttnScalar = queryPreAttnScalar
+        self.slidingWindow = slidingWindow
+        self.slidingWindowPattern = slidingWindowPattern
+        self.useBidirectionalAttention = useBidirectionalAttention
+    }
+
     public init(from decoder: Decoder) throws {
         let nestedContainer = try decoder.container(keyedBy: VLMCodingKeys.self)
@@ -82,12 +128,29 @@ public struct Gemma3TextConfiguration: Codable {
             try container.decodeIfPresent(Bool.self, forKey: .ropeTraditional) ?? false
         queryPreAttnScalar =
             try container.decodeIfPresent(Float.self, forKey: .queryPreAttnScalar) ?? 256
-        slidingWindow = try container.decodeIfPresent(Int.self, forKey: .slidingWindow) ?? 512
+        useBidirectionalAttention =
+            try container.decodeIfPresent(Bool.self, forKey: .useBidirectionalAttention) ?? false
+
+        let rawSlidingWindow = try container.decodeIfPresent(Int.self, forKey: .slidingWindow) ?? 512
+        // Bidirectional attention uses a reduced window, (sliding_window // 2) + 1, following the reference patch.
+        slidingWindow = useBidirectionalAttention ? (rawSlidingWindow / 2) + 1 : rawSlidingWindow
         slidingWindowPattern =
             try container.decodeIfPresent(Int.self, forKey: .slidingWindowPattern) ?? 6
     }
 }
 
+// MARK: - Quantization Configuration
+
+public struct QuantizationConfig: Codable, Sendable {
+    public let groupSize: Int
+    public let bits: Int
+
+    enum CodingKeys: String, CodingKey {
+        case groupSize = "group_size"
+        case bits
+    }
+}
+
 private class Attention: Module {
     let nHeads: Int
     let nKVHeads: Int
@@ -98,6 +161,7 @@ private class Attention: Module {
     let isSliding: Bool
     let slidingWindow: Int
     let slidingWindowPattern: Int
+    let useBidirectionalAttention: Bool
 
     @ModuleInfo(key: "q_proj") var queryProj: Linear
     @ModuleInfo(key: "k_proj") var keyProj: Linear
@@ -118,6 +182,7 @@ private class Attention: Module {
         self.layerIdx = layerIdx
         self.slidingWindow = config.slidingWindow
         self.slidingWindowPattern = config.slidingWindowPattern
+        self.useBidirectionalAttention = config.useBidirectionalAttention
 
         self.scale = pow(config.queryPreAttnScalar, -0.5)
 
@@ -178,14 +243,17 @@
             }
         }
 
-        let output = attentionWithCacheUpdate(
-            queries: queries,
-            keys: keys,
-            values: values,
-            cache: cache,
-            scale: scale,
-            mask: finalMask
-        )
+        let attnOutput: MLXArray
+        if case .array(let maskArray) = finalMask {
+            // Boolean masks (e.g. the bidirectional sliding window) need the explicit path.
+            attnOutput = simpleSDPA(
+                queries: queries, keys: keys, values: values, mask: maskArray, scale: scale)
+        } else {
+            // Causal / empty masks keep the original fast path, including the KV cache update.
+            attnOutput = attentionWithCacheUpdate(
+                queries: queries, keys: keys, values: values, cache: cache, scale: scale,
+                mask: finalMask)
+        }
+        let output = attnOutput
             .transposed(0, 2, 1, 3)
             .reshaped(B, L, -1)
         return outputProj(output)
@@ -290,7 +358,7 @@ private class Gemma3Model: Module {
     {
         var h: MLXArray
         h = embedTokens(inputs)
-        let scale = MLXArray(sqrt(Float(config.hiddenSize)), dtype: .bfloat16)
+        let scale = MLXArray(sqrt(Float(config.hiddenSize)), dtype: .float32)
         h = h * scale.asType(h.dtype)
         var layerCache = cache
         if layerCache == nil {
@@ -299,14 +367,36 @@
         // Create attention masks
         var fullMask: MLXFast.ScaledDotProductAttentionMaskMode = .none
         var slidingWindowMask: MLXFast.ScaledDotProductAttentionMaskMode = .none
-        if mask == nil {
-            let j = config.slidingWindowPattern
-            let globalLayerCache: [KVCache]
-            if j > 0 && j <= (layerCache?.count ?? 0), let globalCache = layerCache?[j - 1] {
-                globalLayerCache = [globalCache]
-            } else {
-                globalLayerCache = []
+        let j = config.slidingWindowPattern
+        let globalLayerCache: [KVCache]
+        if j > 0 && j <= (layerCache?.count ?? 0), let globalCache = layerCache?[j - 1] {
+            globalLayerCache = [globalCache]
+        } else {
+            globalLayerCache = []
+        }
+
+        if config.useBidirectionalAttention {
+            // For bidirectional attention: full attention for global layers, bidirectional sliding window for others
+            var fullMaskArray = MLXArray.ones([h.dim(1), h.dim(1)], dtype: .bool)
+            if case .array(let maskArray) = mask {
+                fullMaskArray = fullMaskArray & maskArray
             }
+            fullMask = .array(fullMaskArray)
+
+            let t = h.dim(1)
+            var offset = 0
+            if let cache = layerCache?.compactMap({ $0 }).first {
+                offset = cache.offset
+            }
+            var slidingWindowMaskArray = createBidirectionalSlidingWindowMask(
+                n: t, offset: offset, windowSize: config.slidingWindow)
+            if case .array(let maskArray) = mask {
+                slidingWindowMaskArray = slidingWindowMaskArray & maskArray
+            }
+            slidingWindowMask = .array(slidingWindowMaskArray)
+        } else {
+            // Standard causal attention
+            // TODO: probably need to merge the custom mask in
             fullMask = createAttentionMask(h: h, cache: globalLayerCache)
             let allCaches = layerCache?.compactMap { $0 } ?? []
             slidingWindowMask = createAttentionMask(h: h, cache: allCaches)
@@ -315,9 +405,7 @@
             let isGlobal = (i % config.slidingWindowPattern == config.slidingWindowPattern - 1)
 
             let localMask: MLXFast.ScaledDotProductAttentionMaskMode
-            if let mask {
-                localMask = mask
-            } else if isGlobal {
+            if isGlobal {
                 localMask = fullMask
             } else {
                 localMask = slidingWindowMask
@@ -331,7 +419,7 @@
 
 public class Gemma3TextModel: Module, LLMModel {
     @ModuleInfo private var model: Gemma3Model
-    @ModuleInfo(key: "lm_head") var lmHead: Linear
+    @ModuleInfo(key: "lm_head") var lmHead: Module  // Can be Linear or QuantizedLinear
 
     public let config: Gemma3TextConfiguration
     public var vocabularySize: Int { config.vocabularySize }
@@ -343,32 +431,83 @@ public class Gemma3TextModel: Module, LLMModel {
         super.init()
     }
 
-    public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]? = nil) -> MLXArray {
-        var out = model(inputs, mask: nil, cache: cache)
-        out = lmHead(out)
+    public func callAsFunction(_ inputs: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode? = nil, cache: [KVCache]? = nil) -> MLXArray {
+        var out = model(inputs, mask: mask, cache: cache)
+
+        // Call the lmHead (works whether it's Linear or QuantizedLinear)
+        if let linear = lmHead as? Linear {
+            out = linear(out)
+        } else if let quantized = lmHead as? QuantizedLinear {
+            out = quantized(out)
+        } else {
+            fatalError("lmHead must be Linear or QuantizedLinear")
+        }
+
         return out
     }
+
+    /// Get hidden states before the language modeling head for embedding use cases
+    public func getHiddenStates(_ inputs: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode? = nil, cache: [KVCache]? = nil) -> MLXArray {
+        return model(inputs, mask: mask, cache: cache)
+    }
 
-    public func sanitize(weights: [String: MLXArray])
-        -> [String: MLXArray]
-    {
+    public func sanitize(
+        weights: [String: MLXArray],
+        quantizationConfig: BaseConfiguration.Quantization? = nil
+    ) -> [String: MLXArray] {
         var processedWeights = weights
 
-        // VLM models converted using mlx_vlm.convert will still have
-        // the weights are under a language_model key
+        // 1. Handle VLM weight extraction first - VLM models converted using mlx_vlm.convert
+        //    will still have the weights under a language_model key
         let unflattened = ModuleParameters.unflattened(weights)
         if let lm = unflattened["language_model"] {
             processedWeights = Dictionary(uniqueKeysWithValues: lm.flattened())
         }
 
+        // 2. Handle weight sharing (works for both regular and quantized weights):
+        //    copy embedding weights to lm_head if lm_head weights don't exist (weight tying)
         if processedWeights["lm_head.weight"] == nil {
-            ["weight", "scales", "biases"].forEach { key in
-                if let embedWeight = processedWeights["model.embed_tokens.\(key)"] {
-                    processedWeights["lm_head.\(key)"] = embedWeight
+            for suffix in ["weight", "scales", "biases"] {
+                let embedKey = "model.embed_tokens.\(suffix)"
+                let lmHeadKey = "lm_head.\(suffix)"
+
+                if let embedWeight = processedWeights[embedKey] {
+                    processedWeights[lmHeadKey] = embedWeight
                 }
             }
         }
-        return processedWeights
+
+        // 3. Apply quantization if needed
+        let hasQuantizedLmHead = hasQuantizedWeights(layerPath: "lm_head", in: processedWeights)
+        if hasQuantizedLmHead {
+            let groupSize = quantizationConfig?.groupSize ?? 64
+            let bits = quantizationConfig?.bits ?? 4
+
+            quantize(model: self) { path, module in
+                if hasQuantizedWeights(layerPath: path, in: processedWeights) {
+                    return (groupSize, bits)
+                }
+                return nil
+            }
+        }
+
+        // Remove unused precomputed rotary freqs
+        return processedWeights.filter { key, _ in
+            !key.contains("self_attn.rotary_emb.inv_freq")
+        }
+    }
+
+    /// Check if a layer has quantized weights
+    private func hasQuantizedWeights(layerPath: String, in weights: [String: MLXArray]) -> Bool {
+        let scalesKey = "\(layerPath).scales"
+        let biasesKey = "\(layerPath).biases"
+        let weightKey = "\(layerPath).weight"
+
+        let hasScales = weights[scalesKey] != nil
+        let hasBiases = weights[biasesKey] != nil
+        let hasWeight = weights[weightKey]?.dtype == .uint32
+
+        return hasScales && hasBiases && hasWeight
     }
 
     public func newCache(parameters: GenerateParameters? = nil) -> [KVCache] {
diff --git a/Package.swift b/Package.swift
index 365d7814..29f28885 100644
--- a/Package.swift
+++ b/Package.swift
@@ -113,6 +113,7 @@ let package = Package(
         .target(
             name: "MLXEmbedders",
             dependencies: [
+                "MLXLLM",
                 .product(name: "MLX", package: "mlx-swift"),
                 .product(name: "MLXFast", package: "mlx-swift"),
                 .product(name: "MLXNN", package: "mlx-swift"),
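// Illustration (not part of the patch): the quantized-checkpoint detection used by the sanitize
// methods above. The keys and shapes here are hypothetical; real ones come from the model's
// safetensors files.
import MLX

let quantizedStyleWeights: [String: MLXArray] = [
    // Packed quantized weights are stored as .uint32, next to per-group scales and biases.
    "lm_head.weight": MLXArray.zeros([8, 4], dtype: .uint32),
    "lm_head.scales": MLXArray.zeros([8, 1], dtype: .float32),
    "lm_head.biases": MLXArray.zeros([8, 1], dtype: .float32),
]
// All three keys exist and the weight dtype is .uint32, so hasQuantizedWeights(layerPath: "lm_head", in:)
// reports the layer as quantized and quantize(model:) converts it to QuantizedLinear using the
// group size / bits from config.json (64 / 4 when the config does not provide them).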