Commit d9c6ce4

kv-cache : support V-less cache (ggml-org#19067)

* kv-cache : support V-less cache
* cuda : better check for V_is_K_view
* cuda : improve V_is_K_view check
* graph : add comments
* hparams : refactor

1 parent 70d8608

11 files changed, with 246 additions and 53 deletions.

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 1 addition & 1 deletion

@@ -782,7 +782,7 @@ void launch_fattn(
     const ggml_tensor * K = dst->src[1];
     const ggml_tensor * V = dst->src[2];

-    const bool V_is_K_view = V->op == GGML_OP_VIEW && V->src[0] == K && V->data == K->data;
+    const bool V_is_K_view = V->view_src && V->view_offs == 0 && (V->view_src == K || V->view_src == K->view_src);

     const ggml_tensor * mask = dst->src[3];
     const ggml_tensor * sinks = dst->src[4];
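The old condition only recognized V as a view of K when V was literally a `GGML_OP_VIEW` node whose first source was K and whose data pointer matched. The new condition inspects `view_src`/`view_offs` instead, which also covers the case where K is itself a view of the per-layer cache tensor, so both K and V resolve to the same root. Below is a minimal standalone sketch of the check; it assumes ggml's behavior of collapsing a view of a view onto the root tensor, and the tensor sizes are invented for illustration.

```cpp
// Minimal sketch of the V_is_K_view condition. Assumption: ggml records the
// root tensor of a view chain in view_src and the accumulated byte offset in
// view_offs. Sizes are arbitrary example values, not taken from the commit.
#include "ggml.h"

static bool v_is_k_view(const ggml_tensor * V, const ggml_tensor * K) {
    // V must be a zero-offset view whose root is K itself, or the same root
    // that K is a view of (e.g. the per-layer K cache buffer)
    return V->view_src && V->view_offs == 0 && (V->view_src == K || V->view_src == K->view_src);
}

int main() {
    ggml_init_params params = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ nullptr, /*no_alloc*/ false };
    ggml_context * ctx = ggml_init(params);

    // stand-in for a per-layer K cache tensor: 576 elements per cell, 1024 cells
    ggml_tensor * cache = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 576, 1024);

    // K as handed to the graph: a zero-offset view of the cache
    ggml_tensor * K = ggml_view_2d(ctx, cache, 576, 1024, cache->nb[1], 0);

    // V-less cache: V reuses the first 512 elements of every K row; the
    // view-of-a-view collapses, so V->view_src == K->view_src == cache
    ggml_tensor * V = ggml_view_2d(ctx, K, 512, 1024, K->nb[1], 0);

    GGML_ASSERT(v_is_k_view(V, K));

    ggml_free(ctx);
    return 0;
}
```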

ggml/src/ggml-cuda/fattn.cu

Lines changed: 1 addition & 1 deletion

@@ -247,7 +247,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
         }
     }

-    const bool V_is_K_view = V->op == GGML_OP_VIEW && V->src[0] == K && V->data == K->data;
+    const bool V_is_K_view = V->view_src && V->view_offs == 0 && (V->view_src == K || V->view_src == K->view_src);

     const int cc = ggml_cuda_info().devices[device].cc;

src/llama-context.cpp

Lines changed: 4 additions & 4 deletions

@@ -793,7 +793,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
             throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
         }

-        const uint32_t n_embd_out = model.hparams.get_n_embd_out();
+        const uint32_t n_embd_out = model.hparams.n_embd_out();
         return embd + j*n_embd_out;
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
@@ -1279,7 +1279,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
                 {
                     // extract token embeddings
                     GGML_ASSERT(embd != nullptr);
-                    const uint32_t n_embd_out = hparams.get_n_embd_out();
+                    const uint32_t n_embd_out = hparams.n_embd_out();

                     GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd_size);
                     ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd_out*sizeof(float));
@@ -1688,7 +1688,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
                 {
                     // extract token embeddings
                     GGML_ASSERT(embd != nullptr);
-                    const uint32_t n_embd_out = hparams.get_n_embd_out();
+                    const uint32_t n_embd_out = hparams.n_embd_out();
                     float * embd_out = embd + n_outputs_prev*n_embd_out;

                     if (n_outputs) {
@@ -1821,7 +1821,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba

     const auto n_batch = cparams.n_batch;
     const auto n_vocab = vocab.n_tokens();
-    const auto n_embd_out = hparams.get_n_embd_out();
+    const auto n_embd_out = hparams.n_embd_out();

     bool has_logits = true;
     bool has_embd = cparams.embeddings;

src/llama-graph.cpp

Lines changed: 111 additions & 6 deletions

@@ -407,6 +407,27 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
     return res;
 }

+void llm_graph_input_attn_k::set_input(const llama_ubatch * ubatch) {
+    mctx->set_input_k_idxs(self_k_idxs, ubatch);
+
+    mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+}
+
+bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);
+
+    this->mctx = mctx;
+
+    bool res = true;
+
+    res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
+
+    res &= self_kq_mask->ne[0] == mctx->get_n_kv();
+    res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;
+
+    return res;
+}
+
 void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
     mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
     mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
@@ -1596,11 +1617,6 @@ ggml_tensor * llm_graph_context::build_attn_mha(
            v = ggml_transpose(ctx0, v);
        }

-        // TODO: update llama_kv_cache to not store V cache in the MLA case and automatically return a view of K
-        if (v_mla) {
-            v = ggml_view_4d(ctx0, k, v->ne[0], v->ne[1], v->ne[2], v->ne[3], k->nb[1], k->nb[2], k->nb[3], 0);
-        }
-
        // this can happen when KV cache is not used (e.g. an embedding model with non-causal attn)
        if (k->type == GGML_TYPE_F32) {
            k = ggml_cast(ctx0, k, GGML_TYPE_F16);
@@ -1823,9 +1839,11 @@ ggml_tensor * llm_graph_context::build_attn(
        ggml_tensor * v_cur,
        ggml_tensor * kq_b,
        ggml_tensor * sinks,
-        ggml_tensor * v_mla,
+        ggml_tensor * v_mla, // TODO: remove
        float kq_scale,
        int il) const {
+    GGML_ASSERT(v_mla == nullptr);
+
     // these nodes are added to the graph together so that they are not reordered
     // by doing so, the number of splits in the graph is reduced
     // expand k later to enable rope fusion which directly writes into k-v cache
@@ -1868,6 +1886,93 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }

+static std::unique_ptr<llm_graph_input_attn_k> build_attn_inp_k_impl(
+        ggml_context * ctx0,
+        const llama_ubatch & ubatch,
+        const llama_hparams & hparams,
+        const llama_cparams & cparams,
+        const llama_kv_cache_context * mctx_cur) {
+
+    auto inp = std::make_unique<llm_graph_input_attn_k>(hparams, cparams, mctx_cur);
+
+    {
+        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
+
+        const auto n_kv     = mctx_cur->get_n_kv();
+        const auto n_tokens = ubatch.n_tokens;
+        const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
+
+        inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
+
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
+        ggml_set_input(inp->self_kq_mask);
+
+        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    }
+
+    return inp;
+}
+
+llm_graph_input_attn_k * llm_graph_context::build_attn_inp_k() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
+
+    auto inp = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
+
+    return (llm_graph_input_attn_k *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_attn_k * inp,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+        ggml_tensor * sinks,
+        ggml_tensor * v_mla,
+        float kq_scale,
+        int il) const {
+    // these nodes are added to the graph together so that they are not reordered
+    // by doing so, the number of splits in the graph is reduced
+    // expand k later to enable rope fusion which directly writes into k-v cache
+    ggml_build_forward_expand(gf, q_cur);
+    ggml_build_forward_expand(gf, v_cur);
+    ggml_build_forward_expand(gf, k_cur);
+
+    const auto * mctx_cur = inp->mctx;
+
+    // store to KV cache
+    {
+        const auto & k_idxs = inp->get_k_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+    }
+
+    const auto & kq_mask = inp->get_kq_mask();
+
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
+    ggml_tensor * v = ggml_view_4d(ctx0, k, v_cur->ne[0], k->ne[1], k->ne[2], k->ne[3], k->nb[1], k->nb[2], k->nb[3], 0);
+
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
+    cb(cur, "kqv_out", il);
+
+    if (wo) {
+        cur = build_lora_mm(wo, cur);
+        if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
+            // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
+            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
+        }
+    }
+
+    if (wo_b) {
+        cur = ggml_add(ctx0, cur, wo_b);
+    }
+
+    return cur;
+}
+
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_kv_iswa * inp,
         ggml_tensor * wo,
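The key change in the new V-less `build_attn` overload is that no V tensor is ever stored: `cpy_k` writes only K into the cache, and V is recreated on the fly as a zero-offset view of K that keeps the first `v_cur->ne[0]` elements of every K row (for a DeepSeek-style MLA cache this would be the compressed-KV portion, with the RoPE'd part at the end of each K row left out of the view). A hedged sketch of that view construction in isolation, with `n_embd_head_v` standing in for `v_cur->ne[0]`:

```cpp
// Sketch of the V-from-K view used by the V-less build_attn() above.
// Assumption: the cached K rows store the part that doubles as V at offset 0,
// so a zero-offset view with K's own strides is sufficient.
static ggml_tensor * make_v_from_k(ggml_context * ctx0, ggml_tensor * k, int64_t n_embd_head_v) {
    GGML_ASSERT(n_embd_head_v <= k->ne[0]);

    // same strides and zero offset -> the CUDA flash-attention code detects
    // V_is_K_view and does not need a separate V buffer
    return ggml_view_4d(ctx0, k,
            n_embd_head_v, k->ne[1], k->ne[2], k->ne[3],
            k->nb[1], k->nb[2], k->nb[3], 0);
}
```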

src/llama-graph.h

Lines changed: 48 additions & 0 deletions

@@ -317,6 +317,39 @@ class llm_graph_input_attn_kv : public llm_graph_input_i {
     const llama_kv_cache_context * mctx;
 };

+// V-less input for the KV cache
+// ref: https://github.com/ggml-org/llama.cpp/pull/19067
+class llm_graph_input_attn_k : public llm_graph_input_i {
+public:
+    llm_graph_input_attn_k(
+            const llama_hparams & hparams,
+            const llama_cparams & cparams,
+            const llama_kv_cache_context * mctx) :
+        hparams(hparams),
+        cparams(cparams),
+        mctx(mctx) {
+    }
+    ~llm_graph_input_attn_k() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    bool can_reuse(const llm_graph_params & params) override;
+
+    ggml_tensor * get_k_idxs() const { return self_k_idxs; }
+
+    ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
+
+    ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
+
+    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]
+
+    const llama_hparams hparams;
+    const llama_cparams cparams;
+
+    const llama_kv_cache_context * mctx;
+};
+
 class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
 public:
     llm_graph_input_attn_kv_iswa(
@@ -833,6 +866,21 @@ struct llm_graph_context {
             ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
             ggml_tensor * sinks, // [n_head_q]
+            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] // TODO: remove
+            float kq_scale,
+            int il) const;
+
+    llm_graph_input_attn_k * build_attn_inp_k() const;
+
+    ggml_tensor * build_attn(
+            llm_graph_input_attn_k * inp,
+            ggml_tensor * wo,
+            ggml_tensor * wo_b,
+            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
+            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
+            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
+            ggml_tensor * kq_b,
+            ggml_tensor * sinks, // [n_head_q]
             ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
             int il) const;
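For context, here is a hypothetical sketch of how a model's graph builder could switch to the V-less path. It is not taken from this commit: the layer loop, `inpL`, `n_layer`, `kq_scale` and the `model.layers[il].wo`/`bo` weights are placeholders modeled on the existing llm_build_* code. The input is created once per graph with `build_attn_inp_k()` and shared by all layers, and the new `build_attn` overload stores only K.

```cpp
// Hypothetical usage fragment inside a model's graph-build function
// (assumes the usual llm_graph_context members; not part of this commit).
auto * inp_attn = build_attn_inp_k();   // V-less attention input, one per graph

ggml_tensor * cur = inpL;
for (int il = 0; il < n_layer; ++il) {
    // ... compute q_cur, k_cur, v_cur for layer il ...

    // kq_b and sinks are unused in this sketch; for an MLA model, v_mla would
    // carry the per-layer decompression matrix instead of nullptr
    cur = build_attn(inp_attn,
            model.layers[il].wo, model.layers[il].bo,
            q_cur, k_cur, v_cur,
            /*kq_b  =*/ nullptr,
            /*sinks =*/ nullptr,
            /*v_mla =*/ nullptr,
            kq_scale, il);
}
```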

src/llama-hparams.cpp

Lines changed: 17 additions & 2 deletions

@@ -72,8 +72,8 @@ uint32_t llama_hparams::n_embd_inp() const {
     return n_embd_inp;
 }

-uint32_t llama_hparams::get_n_embd_out() const {
-    return n_embd_out > 0 ? n_embd_out : n_embd;
+uint32_t llama_hparams::n_embd_out() const {
+    return n_embd_out_impl > 0 ? n_embd_out_impl : n_embd;
 }

 uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const {
@@ -175,6 +175,21 @@ bool llama_hparams::is_swa(uint32_t il) const {
     GGML_ABORT("fatal error");
 }

+bool llama_hparams::is_mla() const {
+    assert((n_embd_head_k_mla_impl == 0 && n_embd_head_v_mla_impl == 0) ||
+           (n_embd_head_k_mla_impl != 0 && n_embd_head_v_mla_impl != 0));
+
+    return n_embd_head_k_mla_impl != 0 && n_embd_head_v_mla_impl != 0;
+}
+
+uint32_t llama_hparams::n_embd_head_k_mla() const {
+    return is_mla() ? n_embd_head_k_mla_impl : n_embd_head_k;
+}
+
+uint32_t llama_hparams::n_embd_head_v_mla() const {
+    return is_mla() ? n_embd_head_v_mla_impl : n_embd_head_v;
+}
+
 bool llama_hparams::has_kv(uint32_t il) const {
     if (n_layer_kv_from_start >= 0) {
         if (il < (uint32_t) n_layer_kv_from_start) {
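The hparams refactor renames the raw GGUF-loaded fields to `*_impl` and routes all reads through accessors that fall back to a sensible default (`n_embd` for the output dimension, the regular head sizes for non-MLA models). A self-contained illustration of this pattern with a stand-in struct, not the real `llama_hparams`:

```cpp
// Stand-in illustration of the *_impl + accessor pattern used by this commit;
// the struct and its default values are invented for the example.
#include <cstdint>
#include <cstdio>

struct hparams_sketch {
    uint32_t n_embd                 = 4096;
    uint32_t n_embd_out_impl        = 0;   // 0 = use n_embd
    uint32_t n_embd_head_k          = 128;
    uint32_t n_embd_head_k_mla_impl = 0;   // 0 = model is not MLA

    uint32_t n_embd_out() const {
        return n_embd_out_impl > 0 ? n_embd_out_impl : n_embd;
    }

    bool is_mla() const {
        return n_embd_head_k_mla_impl != 0;
    }

    uint32_t n_embd_head_k_mla() const {
        // callers no longer need to special-case non-MLA models
        return is_mla() ? n_embd_head_k_mla_impl : n_embd_head_k;
    }
};

int main() {
    hparams_sketch hp;
    std::printf("n_embd_out = %u, n_embd_head_k_mla = %u\n", hp.n_embd_out(), hp.n_embd_head_k_mla());
    return 0;
}
```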

src/llama-hparams.h

Lines changed: 10 additions & 4 deletions

@@ -53,8 +53,8 @@ struct llama_hparams {
     uint32_t n_rel_attn_bkts = 0;

     // note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
-    uint32_t n_embd_head_k_mla = 0;
-    uint32_t n_embd_head_v_mla = 0;
+    uint32_t n_embd_head_k_mla_impl = 0;
+    uint32_t n_embd_head_v_mla_impl = 0;

     // for WavTokenizer
     struct llama_hparams_posnet posnet;
@@ -164,7 +164,7 @@ struct llama_hparams {
     uint32_t n_cls_out = 1;

     // output embedding dimension (0 = use n_embd)
-    uint32_t n_embd_out = 0;
+    uint32_t n_embd_out_impl = 0;

     // llama4 smallthinker
     uint32_t n_moe_layer_step = 0;
@@ -239,7 +239,7 @@ struct llama_hparams {
     uint32_t n_embd_inp() const;

     // dimension of output embeddings
-    uint32_t get_n_embd_out() const;
+    uint32_t n_embd_out() const;

     // dimension of key embeddings across all k-v heads
     uint32_t n_embd_k_gqa(uint32_t il = 0) const;
@@ -269,6 +269,12 @@ struct llama_hparams {

     bool is_swa(uint32_t il) const;

+    // note: currently only support if either all or none of the layers are MLA
+    bool is_mla() const;
+
+    uint32_t n_embd_head_k_mla() const;
+    uint32_t n_embd_head_v_mla() const;
+
     bool has_kv(uint32_t il) const;

     // number of layers for which has_kv() returns true
