@@ -407,6 +407,27 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
     return res;
 }
 
+void llm_graph_input_attn_k::set_input(const llama_ubatch * ubatch) {
+    mctx->set_input_k_idxs(self_k_idxs, ubatch);
+
+    mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+}
+
+bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);
+
+    this->mctx = mctx;
+
+    bool res = true;
+
+    res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
+
+    res &= self_kq_mask->ne[0] == mctx->get_n_kv();
+    res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;
+
+    return res;
+}
+
 void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
     mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
     mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
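The new llm_graph_input_attn_k input carries only the K-cache index tensor and the KQ mask, with no V-cache members. Its declaration is not part of this hunk; below is a minimal sketch of what it could look like in llama-graph.h, inferred from the members and getters used in this commit. The constructor signature and field layout are assumptions, modeled on the existing llm_graph_input_attn_kv.

// sketch only (not part of this commit): assumed declaration, modeled on llm_graph_input_attn_kv
class llm_graph_input_attn_k : public llm_graph_input_i {
public:
    llm_graph_input_attn_k(const llama_hparams & hparams, const llama_cparams & cparams, const llama_kv_cache_context * mctx) :
        hparams(hparams), cparams(cparams), mctx(mctx) {}
    ~llm_graph_input_attn_k() = default;

    void set_input(const llama_ubatch * ubatch) override;
    bool can_reuse(const llm_graph_params & params) override;

    ggml_tensor * get_k_idxs()  const { return self_k_idxs; }
    ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }

    ggml_tensor * self_k_idxs      = nullptr; // I64 [n_tokens]
    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_tokens/n_stream, 1, n_stream]
    ggml_tensor * self_kq_mask_cnv = nullptr; // same shape, cast to F16 when flash attention is enabled

    const llama_hparams & hparams;
    const llama_cparams & cparams;

    const llama_kv_cache_context * mctx;
};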
@@ -1596,11 +1617,6 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         v = ggml_transpose(ctx0, v);
     }
 
-    // TODO: update llama_kv_cache to not store V cache in the MLA case and automatically return a view of K
-    if (v_mla) {
-        v = ggml_view_4d(ctx0, k, v->ne[0], v->ne[1], v->ne[2], v->ne[3], k->nb[1], k->nb[2], k->nb[3], 0);
-    }
-
     // this can happen when KV cache is not used (e.g. an embedding model with non-causal attn)
     if (k->type == GGML_TYPE_F32) {
         k = ggml_cast(ctx0, k, GGML_TYPE_F16);
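Note that the MLA special case removed here does not disappear: the new K-only build_attn overload added later in this commit builds v directly as a view of the cached K tensor before calling build_attn_mha, so build_attn_mha itself no longer has to special-case v_mla.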
@@ -1823,9 +1839,11 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
         ggml_tensor * sinks,
-        ggml_tensor * v_mla,
+        ggml_tensor * v_mla, // TODO: remove
         float kq_scale,
         int il) const {
+    GGML_ASSERT(v_mla == nullptr);
+
     // these nodes are added to the graph together so that they are not reordered
     // by doing so, the number of splits in the graph is reduced
     // expand k later to enable rope fusion which directly writes into k-v cache
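With the MLA view handled elsewhere, this existing KV-cache overload of build_attn apparently should no longer receive an MLA projection: the new GGML_ASSERT documents that expectation, and the TODO marks the now-unused v_mla parameter for removal once callers have migrated to the K-only path added below.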
@@ -1868,6 +1886,93 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
+static std::unique_ptr<llm_graph_input_attn_k> build_attn_inp_k_impl(
+        ggml_context * ctx0,
+        const llama_ubatch & ubatch,
+        const llama_hparams & hparams,
+        const llama_cparams & cparams,
+        const llama_kv_cache_context * mctx_cur) {
+
+    auto inp = std::make_unique<llm_graph_input_attn_k>(hparams, cparams, mctx_cur);
+
+    {
+        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
+
+        const auto n_kv     = mctx_cur->get_n_kv();
+        const auto n_tokens = ubatch.n_tokens;
+        const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
+
+        inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
+
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
+        ggml_set_input(inp->self_kq_mask);
+
+        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    }
+
+    return inp;
+}
+
+llm_graph_input_attn_k * llm_graph_context::build_attn_inp_k() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
+
+    auto inp = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
+
+    return (llm_graph_input_attn_k *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_attn_k * inp,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+        ggml_tensor * sinks,
+        ggml_tensor * v_mla,
+        float kq_scale,
+        int il) const {
+    // these nodes are added to the graph together so that they are not reordered
+    // by doing so, the number of splits in the graph is reduced
+    // expand k later to enable rope fusion which directly writes into k-v cache
+    ggml_build_forward_expand(gf, q_cur);
+    ggml_build_forward_expand(gf, v_cur);
+    ggml_build_forward_expand(gf, k_cur);
+
+    const auto * mctx_cur = inp->mctx;
+
+    // store to KV cache
+    {
+        const auto & k_idxs = inp->get_k_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+    }
+
+    const auto & kq_mask = inp->get_kq_mask();
+
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
+    ggml_tensor * v = ggml_view_4d(ctx0, k, v_cur->ne[0], k->ne[1], k->ne[2], k->ne[3], k->nb[1], k->nb[2], k->nb[3], 0);
+
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
+    cb(cur, "kqv_out", il);
+
+    if (wo) {
+        cur = build_lora_mm(wo, cur);
+        if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
+            // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
+            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
+        }
+    }
+
+    if (wo_b) {
+        cur = ggml_add(ctx0, cur, wo_b);
+    }
+
+    return cur;
+}
+
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_kv_iswa * inp,
         ggml_tensor * wo,
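For illustration, a hypothetical call site in a model's graph builder might look like the following; the layer and tensor names are placeholders, not taken from this commit. With this overload only k_cur is written to the cache, and v_cur contributes just the size of the first dimension of the V view, since V is read back as a view of the cached K inside build_attn.

// hypothetical usage sketch (names are illustrative, not from this commit)
auto * inp_attn = build_attn_inp_k();

for (int il = 0; il < n_layer; ++il) {
    // ... compute Qcur, Kcur, Vcur for layer il ...

    // only Kcur is stored; Vcur->ne[0] determines how much of each cached K row
    // is reinterpreted as V via ggml_view_4d
    cur = build_attn(inp_attn,
            model.layers[il].wo, nullptr /* wo_b */,
            Qcur, Kcur, Vcur,
            nullptr /* kq_b */, nullptr /* sinks */, nullptr /* v_mla */,
            kq_scale, il);
}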