Skip to content

Commit 73d284a

Browse files
authored
model : add LFM2-ColBert-350M (ggml-org#18607)
* model : add LFM2-ColBert-350M * llama_model_n_embd_out() - returns `hparams.n_embd_out` if set and fallbacks to `hparams.n_embd`
1 parent df17a4c commit 73d284a

16 files changed

Lines changed: 118 additions & 60 deletions

File tree

convert_hf_to_gguf.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9956,6 +9956,27 @@ def _is_audio_tensor(self, name: str):
99569956
return any(p in name for p in ["audio", "codebook", "conformer", "depth_embedding", "depthformer", "depth_linear"])
99579957

99589958

9959+
@ModelBase.register("Lfm2Model")
9960+
class LFM2ColBertModel(LFM2Model):
9961+
model_arch = gguf.MODEL_ARCH.LFM2
9962+
dense_tensor_name = "dense_2"
9963+
9964+
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
9965+
if not name.startswith(self.dense_tensor_name):
9966+
name = "model." + name
9967+
9968+
return super().modify_tensors(data_torch, name, bid)
9969+
9970+
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
9971+
# dense tensor is stored in a separate safetensors file
9972+
from safetensors.torch import load_file
9973+
tensors_file = self.dir_model / "1_Dense" / "model.safetensors"
9974+
assert tensors_file.is_file()
9975+
tensor = load_file(tensors_file)["linear.weight"]
9976+
self.gguf_writer.add_embedding_length_out(tensor.shape[0])
9977+
yield f"{self.dense_tensor_name}.weight", tensor.clone()
9978+
9979+
99599980
@ModelBase.register("Lfm2MoeForCausalLM")
99609981
class LFM2MoeModel(TextModel):
99619982
model_arch = gguf.MODEL_ARCH.LFM2MOE

examples/embedding/embedding.cpp

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
3333
}
3434
}
3535

36-
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
36+
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd_out, int embd_norm) {
3737
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
3838

3939
// clear previous kv_cache values (irrelevant for embeddings)
@@ -65,8 +65,8 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
6565
GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
6666
}
6767

68-
float * out = output + embd_pos * n_embd;
69-
common_embd_normalize(embd, out, n_embd, embd_norm);
68+
float * out = output + embd_pos * n_embd_out;
69+
common_embd_normalize(embd, out, n_embd_out, embd_norm);
7070
}
7171
}
7272

@@ -252,8 +252,8 @@ int main(int argc, char ** argv) {
252252
}
253253

254254
// allocate output
255-
const int n_embd = llama_model_n_embd(model);
256-
std::vector<float> embeddings(n_embd_count * n_embd, 0);
255+
const int n_embd_out = llama_model_n_embd_out(model);
256+
std::vector<float> embeddings(n_embd_count * n_embd_out, 0);
257257
float * emb = embeddings.data();
258258

259259
// break into batches
@@ -267,8 +267,8 @@ int main(int argc, char ** argv) {
267267

268268
// encode if at capacity
269269
if (batch.n_tokens + n_toks > n_batch || s >= n_seq_max) {
270-
float * out = emb + e * n_embd;
271-
batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
270+
float * out = emb + e * n_embd_out;
271+
batch_decode(ctx, batch, out, s, n_embd_out, params.embd_normalize);
272272
e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s;
273273
s = 0;
274274
common_batch_clear(batch);
@@ -280,28 +280,28 @@ int main(int argc, char ** argv) {
280280
}
281281

282282
// final batch
283-
float * out = emb + e * n_embd;
284-
batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
283+
float * out = emb + e * n_embd_out;
284+
batch_decode(ctx, batch, out, s, n_embd_out, params.embd_normalize);
285285

286286
if (params.embd_out.empty()) {
287287
LOG("\n");
288288

289289
if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
290290
for (int j = 0; j < n_embd_count; j++) {
291291
LOG("embedding %d: ", j);
292-
for (int i = 0; i < std::min(3, n_embd); i++) {
292+
for (int i = 0; i < std::min(3, n_embd_out); i++) {
293293
if (params.embd_normalize == 0) {
294-
LOG("%6.0f ", emb[j * n_embd + i]);
294+
LOG("%6.0f ", emb[j * n_embd_out + i]);
295295
} else {
296-
LOG("%9.6f ", emb[j * n_embd + i]);
296+
LOG("%9.6f ", emb[j * n_embd_out + i]);
297297
}
298298
}
299299
LOG(" ... ");
300-
for (int i = n_embd - 3; i < n_embd; i++) {
300+
for (int i = n_embd_out - 3; i < n_embd_out; i++) {
301301
if (params.embd_normalize == 0) {
302-
LOG("%6.0f ", emb[j * n_embd + i]);
302+
LOG("%6.0f ", emb[j * n_embd_out + i]);
303303
} else {
304-
LOG("%9.6f ", emb[j * n_embd + i]);
304+
LOG("%9.6f ", emb[j * n_embd_out + i]);
305305
}
306306
}
307307
LOG("\n");
@@ -320,21 +320,21 @@ int main(int argc, char ** argv) {
320320
for (uint32_t i = 0; i < n_cls_out; i++) {
321321
// NOTE: if you change this log - update the tests in ci/run.sh
322322
if (n_cls_out == 1) {
323-
LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]);
323+
LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd_out]);
324324
} else {
325-
LOG("rerank score %d: %8.3f [%s]\n", j, emb[j * n_embd + i], cls_out_labels[i].c_str());
325+
LOG("rerank score %d: %8.3f [%s]\n", j, emb[j * n_embd_out + i], cls_out_labels[i].c_str());
326326
}
327327
}
328328
}
329329
} else {
330330
// print the first part of the embeddings or for a single prompt, the full embedding
331331
for (int j = 0; j < n_prompts; j++) {
332332
LOG("embedding %d: ", j);
333-
for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) {
333+
for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd_out) : n_embd_out); i++) {
334334
if (params.embd_normalize == 0) {
335-
LOG("%6.0f ", emb[j * n_embd + i]);
335+
LOG("%6.0f ", emb[j * n_embd_out + i]);
336336
} else {
337-
LOG("%9.6f ", emb[j * n_embd + i]);
337+
LOG("%9.6f ", emb[j * n_embd_out + i]);
338338
}
339339
}
340340
LOG("\n");
@@ -350,7 +350,7 @@ int main(int argc, char ** argv) {
350350
LOG("\n");
351351
for (int i = 0; i < n_prompts; i++) {
352352
for (int j = 0; j < n_prompts; j++) {
353-
float sim = common_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
353+
float sim = common_embd_similarity_cos(emb + i * n_embd_out, emb + j * n_embd_out, n_embd_out);
354354
LOG("%6.2f ", sim);
355355
}
356356
LOG("%1.10s", prompts[i].c_str());
@@ -368,9 +368,9 @@ int main(int argc, char ** argv) {
368368
if (notArray) LOG(" {\n \"object\": \"embedding\",\n \"index\": %d,\n \"embedding\": ",j);
369369
LOG("[");
370370
for (int i = 0;;) { // at least one iteration (n_embd > 0)
371-
LOG(params.embd_normalize == 0 ? "%1.0f" : "%1.7f", emb[j * n_embd + i]);
371+
LOG(params.embd_normalize == 0 ? "%1.0f" : "%1.7f", emb[j * n_embd_out + i]);
372372
i++;
373-
if (i < n_embd) LOG(","); else break;
373+
if (i < n_embd_out) LOG(","); else break;
374374
}
375375
LOG(notArray ? "]\n }" : "]");
376376
j++;
@@ -383,7 +383,7 @@ int main(int argc, char ** argv) {
383383
for (int i = 0;;) { // at least two iteration (n_embd_count > 1)
384384
LOG(" [");
385385
for (int j = 0;;) { // at least two iteration (n_embd_count > 1)
386-
float sim = common_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
386+
float sim = common_embd_similarity_cos(emb + i * n_embd_out, emb + j * n_embd_out, n_embd_out);
387387
LOG("%6.2f", sim);
388388
j++;
389389
if (j < n_embd_count) LOG(", "); else break;
@@ -397,7 +397,7 @@ int main(int argc, char ** argv) {
397397

398398
if (notArray) LOG("\n}\n");
399399
} else if (params.embd_out == "raw") {
400-
print_raw_embeddings(emb, n_embd_count, n_embd, model, pooling_type, params.embd_normalize);
400+
print_raw_embeddings(emb, n_embd_count, n_embd_out, model, pooling_type, params.embd_normalize);
401401
}
402402

403403
LOG("\n");

examples/model-conversion/logits.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -161,9 +161,9 @@ int main(int argc, char ** argv) {
161161
std::vector<float> embd_out;
162162

163163
if (embedding_mode) {
164-
const int n_embd = llama_model_n_embd(model);
164+
const int n_embd_out = llama_model_n_embd_out(model);
165165
const int n_embd_count = pooling_enabled ? 1 : batch.n_tokens;
166-
const int n_embeddings = n_embd * n_embd_count;
166+
const int n_embeddings = n_embd_out * n_embd_count;
167167
float * embeddings;
168168
type = "-embeddings";
169169

@@ -177,24 +177,24 @@ int main(int argc, char ** argv) {
177177
embeddings = llama_get_embeddings(ctx);
178178
}
179179

180-
printf("Embedding dimension: %d\n", n_embd);
180+
printf("Embedding dimension: %d\n", n_embd_out);
181181
printf("\n");
182182

183183
// Print embeddings in the specified format
184184
for (int j = 0; j < n_embd_count; j++) {
185185
printf("embedding %d: ", j);
186186

187187
// Print first 3 values
188-
for (int i = 0; i < 3 && i < n_embd; i++) {
189-
printf("%9.6f ", embeddings[j * n_embd + i]);
188+
for (int i = 0; i < 3 && i < n_embd_out; i++) {
189+
printf("%9.6f ", embeddings[j * n_embd_out + i]);
190190
}
191191

192192
printf(" ... ");
193193

194194
// Print last 3 values
195-
for (int i = n_embd - 3; i < n_embd; i++) {
195+
for (int i = n_embd_out - 3; i < n_embd_out; i++) {
196196
if (i >= 0) {
197-
printf("%9.6f ", embeddings[j * n_embd + i]);
197+
printf("%9.6f ", embeddings[j * n_embd_out + i]);
198198
}
199199
}
200200

examples/retrieval/retrieval.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,8 @@ int main(int argc, char ** argv) {
217217
struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
218218

219219
// allocate output
220-
const int n_embd = llama_model_n_embd(model);
221-
std::vector<float> embeddings(n_chunks * n_embd, 0);
220+
const int n_embd_out = llama_model_n_embd_out(model);
221+
std::vector<float> embeddings(n_chunks * n_embd_out, 0);
222222
float * emb = embeddings.data();
223223

224224
// break into batches
@@ -232,8 +232,8 @@ int main(int argc, char ** argv) {
232232

233233
// encode if at capacity
234234
if (batch.n_tokens + n_toks > n_batch || s >= llama_n_seq_max(ctx)) {
235-
float * out = emb + p * n_embd;
236-
batch_process(ctx, batch, out, s, n_embd);
235+
float * out = emb + p * n_embd_out;
236+
batch_process(ctx, batch, out, s, n_embd_out);
237237
common_batch_clear(batch);
238238
p += s;
239239
s = 0;
@@ -245,12 +245,12 @@ int main(int argc, char ** argv) {
245245
}
246246

247247
// final batch
248-
float * out = emb + p * n_embd;
249-
batch_process(ctx, batch, out, s, n_embd);
248+
float * out = emb + p * n_embd_out;
249+
batch_process(ctx, batch, out, s, n_embd_out);
250250

251251
// save embeddings to chunks
252252
for (int i = 0; i < n_chunks; i++) {
253-
chunks[i].embedding = std::vector<float>(emb + i * n_embd, emb + (i + 1) * n_embd);
253+
chunks[i].embedding = std::vector<float>(emb + i * n_embd_out, emb + (i + 1) * n_embd_out);
254254
// clear tokens as they are no longer needed
255255
chunks[i].tokens.clear();
256256
}
@@ -266,16 +266,16 @@ int main(int argc, char ** argv) {
266266

267267
batch_add_seq(query_batch, query_tokens, 0);
268268

269-
std::vector<float> query_emb(n_embd, 0);
270-
batch_process(ctx, query_batch, query_emb.data(), 1, n_embd);
269+
std::vector<float> query_emb(n_embd_out, 0);
270+
batch_process(ctx, query_batch, query_emb.data(), 1, n_embd_out);
271271

272272
common_batch_clear(query_batch);
273273

274274
// compute cosine similarities
275275
{
276276
std::vector<std::pair<int, float>> similarities;
277277
for (int i = 0; i < n_chunks; i++) {
278-
float sim = common_embd_similarity_cos(chunks[i].embedding.data(), query_emb.data(), n_embd);
278+
float sim = common_embd_similarity_cos(chunks[i].embedding.data(), query_emb.data(), n_embd_out);
279279
similarities.push_back(std::make_pair(i, sim));
280280
}
281281

gguf-py/gguf/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ class LLM:
104104
VOCAB_SIZE = "{arch}.vocab_size"
105105
CONTEXT_LENGTH = "{arch}.context_length"
106106
EMBEDDING_LENGTH = "{arch}.embedding_length"
107+
EMBEDDING_LENGTH_OUT = "{arch}.embedding_length_out"
107108
FEATURES_LENGTH = "{arch}.features_length"
108109
BLOCK_COUNT = "{arch}.block_count"
109110
LEADING_DENSE_BLOCK_COUNT = "{arch}.leading_dense_block_count"
@@ -3038,6 +3039,7 @@ class MODEL_TENSOR(IntEnum):
30383039
MODEL_TENSOR.ATTN_V,
30393040
MODEL_TENSOR.ATTN_OUT,
30403041
MODEL_TENSOR.OUTPUT,
3042+
MODEL_TENSOR.DENSE_2_OUT, # LFM2-ColBert-350M
30413043
],
30423044
MODEL_ARCH.LFM2MOE: [
30433045
MODEL_TENSOR.TOKEN_EMBD,

gguf-py/gguf/gguf_writer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -681,6 +681,9 @@ def add_context_length(self, length: int) -> None:
681681
def add_embedding_length(self, length: int) -> None:
682682
self.add_uint32(Keys.LLM.EMBEDDING_LENGTH.format(arch=self.arch), length)
683683

684+
def add_embedding_length_out(self, length: int) -> None:
685+
self.add_uint32(Keys.LLM.EMBEDDING_LENGTH_OUT.format(arch=self.arch), length)
686+
684687
def add_features_length(self, length: int) -> None:
685688
self.add_uint32(Keys.LLM.FEATURES_LENGTH.format(arch=self.arch), length)
686689

include/llama.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,7 @@ extern "C" {
535535
LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model);
536536
LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model);
537537
LLAMA_API int32_t llama_model_n_embd_inp (const struct llama_model * model);
538+
LLAMA_API int32_t llama_model_n_embd_out (const struct llama_model * model);
538539
LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
539540
LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
540541
LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);

src/llama-arch.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
152152
{ LLM_KV_VOCAB_SIZE, "%s.vocab_size" },
153153
{ LLM_KV_CONTEXT_LENGTH, "%s.context_length" },
154154
{ LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" },
155+
{ LLM_KV_EMBEDDING_LENGTH_OUT, "%s.embedding_length_out" },
155156
{ LLM_KV_FEATURES_LENGTH, "%s.features_length" },
156157
{ LLM_KV_BLOCK_COUNT, "%s.block_count" },
157158
{ LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" },
@@ -2075,6 +2076,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
20752076
LLM_TENSOR_TOKEN_EMBD,
20762077
LLM_TENSOR_OUTPUT_NORM_LFM2,
20772078
LLM_TENSOR_OUTPUT,
2079+
LLM_TENSOR_DENSE_2_OUT,
20782080
};
20792081
case LLM_ARCH_LFM2MOE:
20802082
return {

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ enum llm_kv {
156156
LLM_KV_VOCAB_SIZE,
157157
LLM_KV_CONTEXT_LENGTH,
158158
LLM_KV_EMBEDDING_LENGTH,
159+
LLM_KV_EMBEDDING_LENGTH_OUT,
159160
LLM_KV_FEATURES_LENGTH,
160161
LLM_KV_BLOCK_COUNT,
161162
LLM_KV_LEADING_DENSE_BLOCK_COUNT,

src/llama-context.cpp

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -758,7 +758,8 @@ float * llama_context::get_embeddings_ith(int32_t i) {
758758
throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
759759
}
760760

761-
return embd + j*model.hparams.n_embd;
761+
const uint32_t n_embd_out = model.hparams.get_n_embd_out();
762+
return embd + j*n_embd_out;
762763
} catch (const std::exception & err) {
763764
LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
764765
#ifndef NDEBUG
@@ -1194,9 +1195,10 @@ int llama_context::encode(const llama_batch & batch_inp) {
11941195
{
11951196
// extract token embeddings
11961197
GGML_ASSERT(embd != nullptr);
1198+
const uint32_t n_embd_out = hparams.get_n_embd_out();
11971199

1198-
GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
1199-
ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float));
1200+
GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd_size);
1201+
ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd_out*sizeof(float));
12001202
} break;
12011203
case LLAMA_POOLING_TYPE_MEAN:
12021204
case LLAMA_POOLING_TYPE_CLS:
@@ -1600,12 +1602,13 @@ int llama_context::decode(const llama_batch & batch_inp) {
16001602
{
16011603
// extract token embeddings
16021604
GGML_ASSERT(embd != nullptr);
1603-
float * embd_out = embd + n_outputs_prev*n_embd;
1605+
const uint32_t n_embd_out = hparams.get_n_embd_out();
1606+
float * embd_out = embd + n_outputs_prev*n_embd_out;
16041607

16051608
if (n_outputs) {
16061609
GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
1607-
GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size);
1608-
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float));
1610+
GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_out <= (int64_t) embd_size);
1611+
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd_out*sizeof(float));
16091612
}
16101613
} break;
16111614
case LLAMA_POOLING_TYPE_MEAN:
@@ -1730,9 +1733,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
17301733

17311734
const int64_t n_outputs_max = std::max<int64_t>(n_outputs, n_seq_max());
17321735

1733-
const auto n_batch = cparams.n_batch;
1734-
const auto n_vocab = vocab.n_tokens();
1735-
const auto n_embd = hparams.n_embd;
1736+
const auto n_batch = cparams.n_batch;
1737+
const auto n_vocab = vocab.n_tokens();
1738+
const auto n_embd_out = hparams.get_n_embd_out();
17361739

17371740
bool has_logits = true;
17381741
bool has_embd = cparams.embeddings;
@@ -1773,7 +1776,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
17731776

17741777
// Allocate CPU logits buffer only if needed by sequences in this batch
17751778
logits_size = (has_logits && cpu_logits) ? n_vocab*n_outputs_max : 0;
1776-
embd_size = has_embd ? n_embd*n_outputs_max : 0;
1779+
embd_size = has_embd ? n_embd_out*n_outputs_max : 0;
17771780

17781781
// TODO: avoid this branching by working with the worst-case
17791782
if (!has_sampling) {

0 commit comments

Comments
 (0)