@@ -33,7 +33,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
3333 }
3434}
3535
36- static void batch_decode (llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd , int embd_norm) {
36+ static void batch_decode (llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd_out , int embd_norm) {
3737 const enum llama_pooling_type pooling_type = llama_pooling_type (ctx);
3838
3939 // clear previous kv_cache values (irrelevant for embeddings)
@@ -65,8 +65,8 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
6565 GGML_ASSERT (embd != NULL && " failed to get sequence embeddings" );
6666 }
6767
68- float * out = output + embd_pos * n_embd ;
69- common_embd_normalize (embd, out, n_embd , embd_norm);
68+ float * out = output + embd_pos * n_embd_out ;
69+ common_embd_normalize (embd, out, n_embd_out , embd_norm);
7070 }
7171}
7272
@@ -252,8 +252,8 @@ int main(int argc, char ** argv) {
252252 }
253253
254254 // allocate output
255- const int n_embd = llama_model_n_embd (model);
256- std::vector<float > embeddings (n_embd_count * n_embd , 0 );
255+ const int n_embd_out = llama_model_n_embd_out (model);
256+ std::vector<float > embeddings (n_embd_count * n_embd_out , 0 );
257257 float * emb = embeddings.data ();
258258
259259 // break into batches
@@ -267,8 +267,8 @@ int main(int argc, char ** argv) {
267267
268268 // encode if at capacity
269269 if (batch.n_tokens + n_toks > n_batch || s >= n_seq_max) {
270- float * out = emb + e * n_embd ;
271- batch_decode (ctx, batch, out, s, n_embd , params.embd_normalize );
270+ float * out = emb + e * n_embd_out ;
271+ batch_decode (ctx, batch, out, s, n_embd_out , params.embd_normalize );
272272 e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s;
273273 s = 0 ;
274274 common_batch_clear (batch);
@@ -280,28 +280,28 @@ int main(int argc, char ** argv) {
280280 }
281281
282282 // final batch
283- float * out = emb + e * n_embd ;
284- batch_decode (ctx, batch, out, s, n_embd , params.embd_normalize );
283+ float * out = emb + e * n_embd_out ;
284+ batch_decode (ctx, batch, out, s, n_embd_out , params.embd_normalize );
285285
286286 if (params.embd_out .empty ()) {
287287 LOG (" \n " );
288288
289289 if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
290290 for (int j = 0 ; j < n_embd_count; j++) {
291291 LOG (" embedding %d: " , j);
292- for (int i = 0 ; i < std::min (3 , n_embd ); i++) {
292+ for (int i = 0 ; i < std::min (3 , n_embd_out ); i++) {
293293 if (params.embd_normalize == 0 ) {
294- LOG (" %6.0f " , emb[j * n_embd + i]);
294+ LOG (" %6.0f " , emb[j * n_embd_out + i]);
295295 } else {
296- LOG (" %9.6f " , emb[j * n_embd + i]);
296+ LOG (" %9.6f " , emb[j * n_embd_out + i]);
297297 }
298298 }
299299 LOG (" ... " );
300- for (int i = n_embd - 3 ; i < n_embd ; i++) {
300+ for (int i = n_embd_out - 3 ; i < n_embd_out ; i++) {
301301 if (params.embd_normalize == 0 ) {
302- LOG (" %6.0f " , emb[j * n_embd + i]);
302+ LOG (" %6.0f " , emb[j * n_embd_out + i]);
303303 } else {
304- LOG (" %9.6f " , emb[j * n_embd + i]);
304+ LOG (" %9.6f " , emb[j * n_embd_out + i]);
305305 }
306306 }
307307 LOG (" \n " );
@@ -320,21 +320,21 @@ int main(int argc, char ** argv) {
320320 for (uint32_t i = 0 ; i < n_cls_out; i++) {
321321 // NOTE: if you change this log - update the tests in ci/run.sh
322322 if (n_cls_out == 1 ) {
323- LOG (" rerank score %d: %8.3f\n " , j, emb[j * n_embd ]);
323+ LOG (" rerank score %d: %8.3f\n " , j, emb[j * n_embd_out ]);
324324 } else {
325- LOG (" rerank score %d: %8.3f [%s]\n " , j, emb[j * n_embd + i], cls_out_labels[i].c_str ());
325+ LOG (" rerank score %d: %8.3f [%s]\n " , j, emb[j * n_embd_out + i], cls_out_labels[i].c_str ());
326326 }
327327 }
328328 }
329329 } else {
330330 // print the first part of the embeddings or for a single prompt, the full embedding
331331 for (int j = 0 ; j < n_prompts; j++) {
332332 LOG (" embedding %d: " , j);
333- for (int i = 0 ; i < (n_prompts > 1 ? std::min (16 , n_embd ) : n_embd ); i++) {
333+ for (int i = 0 ; i < (n_prompts > 1 ? std::min (16 , n_embd_out ) : n_embd_out ); i++) {
334334 if (params.embd_normalize == 0 ) {
335- LOG (" %6.0f " , emb[j * n_embd + i]);
335+ LOG (" %6.0f " , emb[j * n_embd_out + i]);
336336 } else {
337- LOG (" %9.6f " , emb[j * n_embd + i]);
337+ LOG (" %9.6f " , emb[j * n_embd_out + i]);
338338 }
339339 }
340340 LOG (" \n " );
@@ -350,7 +350,7 @@ int main(int argc, char ** argv) {
350350 LOG (" \n " );
351351 for (int i = 0 ; i < n_prompts; i++) {
352352 for (int j = 0 ; j < n_prompts; j++) {
353- float sim = common_embd_similarity_cos (emb + i * n_embd , emb + j * n_embd, n_embd );
353+ float sim = common_embd_similarity_cos (emb + i * n_embd_out , emb + j * n_embd_out, n_embd_out );
354354 LOG (" %6.2f " , sim);
355355 }
356356 LOG (" %1.10s" , prompts[i].c_str ());
@@ -368,9 +368,9 @@ int main(int argc, char ** argv) {
368368 if (notArray) LOG (" {\n \" object\" : \" embedding\" ,\n \" index\" : %d,\n \" embedding\" : " ,j);
369369 LOG (" [" );
370370 for (int i = 0 ;;) { // at least one iteration (n_embd > 0)
371- LOG (params.embd_normalize == 0 ? " %1.0f" : " %1.7f" , emb[j * n_embd + i]);
371+ LOG (params.embd_normalize == 0 ? " %1.0f" : " %1.7f" , emb[j * n_embd_out + i]);
372372 i++;
373- if (i < n_embd ) LOG (" ," ); else break ;
373+ if (i < n_embd_out ) LOG (" ," ); else break ;
374374 }
375375 LOG (notArray ? " ]\n }" : " ]" );
376376 j++;
@@ -383,7 +383,7 @@ int main(int argc, char ** argv) {
383383 for (int i = 0 ;;) { // at least two iteration (n_embd_count > 1)
384384 LOG (" [" );
385385 for (int j = 0 ;;) { // at least two iteration (n_embd_count > 1)
386- float sim = common_embd_similarity_cos (emb + i * n_embd , emb + j * n_embd, n_embd );
386+ float sim = common_embd_similarity_cos (emb + i * n_embd_out , emb + j * n_embd_out, n_embd_out );
387387 LOG (" %6.2f" , sim);
388388 j++;
389389 if (j < n_embd_count) LOG (" , " ); else break ;
@@ -397,7 +397,7 @@ int main(int argc, char ** argv) {
397397
398398 if (notArray) LOG (" \n }\n " );
399399 } else if (params.embd_out == " raw" ) {
400- print_raw_embeddings (emb, n_embd_count, n_embd , model, pooling_type, params.embd_normalize );
400+ print_raw_embeddings (emb, n_embd_count, n_embd_out , model, pooling_type, params.embd_normalize );
401401 }
402402
403403 LOG (" \n " );
0 commit comments