diff --git a/src/main/cpp/jllama.cpp b/src/main/cpp/jllama.cpp index 3e17e5d..0e0826a 100644 --- a/src/main/cpp/jllama.cpp +++ b/src/main/cpp/jllama.cpp @@ -3,6 +3,7 @@ #include "arg.h" #include "llama.h" #include "log.h" +#include "json-schema-to-grammar.h" #include "nlohmann/json.hpp" #include "server.hpp" @@ -431,7 +432,6 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo if (!ctx_server->load_model(params)) { llama_backend_free(); - ; env->ThrowNew(c_llama_error, "could not load model from given file path"); return; } @@ -442,7 +442,7 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo LOG_INF("%s: model loaded\n", __func__); const auto model_meta = ctx_server->model_meta(); - + if (!params.speculative.model.empty() || !params.speculative.hf_repo.empty()) { SRV_INF("loading draft model '%s'\n", params.speculative.model.c_str()); auto params_dft = params; @@ -493,7 +493,7 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, common_chat_templates_source(ctx_server->chat_templates.get()), common_chat_format_example(ctx_server->chat_templates.get(), ctx_server->params_base.use_jinja).c_str()); - + // print sample chat example to make it clear which template is used // LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, @@ -543,9 +543,9 @@ JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestCompletion(JNIEnv try { const auto & prompt = data.at("prompt"); - + std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server->vocab, prompt, true, true); - + tasks.reserve(tokenized_prompts.size()); for (size_t i = 0; i < tokenized_prompts.size(); i++) { @@ -600,7 +600,7 @@ JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_receiveCompletion(JNIE auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr) server_task_result_ptr
result = ctx_server->queue_results.recv(id_task); - + if (result->is_error()) { std::string response = result->to_json()["message"].get<std::string>(); @@ -609,9 +609,9 @@ JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_receiveCompletion(JNIE return nullptr; } const auto out_res = result->to_json(); - - + + std::string response = out_res["content"].get<std::string>(); if (result->is_stop()) { @@ -652,11 +652,11 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env, "model was not loaded with embedding support (see ModelParameters#setEmbedding(boolean))"); return nullptr; } - + const std::string prompt = parse_jstring(env, jprompt); - + SRV_INF("Calling embedding '%s'\n", prompt.c_str()); const auto tokens = tokenize_mixed(ctx_server->vocab, prompt, true, true); @@ -716,7 +716,7 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env, // Extract only the first row const std::vector<float>& first_row = embedding[0]; // Reference to avoid copying - + // Create a new float array in JNI jfloatArray j_embedding = env->NewFloatArray(embedding_cols); if (j_embedding == nullptr) { @@ -819,3 +819,11 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_setLogger(JNIEnv *env, jc } } } + +JNIEXPORT jbyteArray JNICALL Java_de_kherud_llama_LlamaModel_jsonSchemaToGrammarBytes(JNIEnv *env, jclass clazz, jstring j_schema) +{ + const std::string c_schema = parse_jstring(env, j_schema); + nlohmann::ordered_json c_schema_json = nlohmann::ordered_json::parse(c_schema); + const std::string c_grammar = json_schema_to_grammar(c_schema_json); + return parse_jbytes(env, c_grammar); +} \ No newline at end of file diff --git a/src/main/cpp/jllama.h b/src/main/cpp/jllama.h index fcc0148..a97463e 100644 --- a/src/main/cpp/jllama.h +++ b/src/main/cpp/jllama.h @@ -7,7 +7,6 @@ #ifdef __cplusplus extern "C" { #endif - /* * Class: de_kherud_llama_LlamaModel * Method: embed @@ -79,15 +78,22 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel */
JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_delete (JNIEnv *, jobject); - - + /* * Class: de_kherud_llama_LlamaModel * Method: releaseTask - * Signature: ()V + * Signature: (I)V */ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_releaseTask - (JNIEnv *, jobject, jint); + (JNIEnv *, jobject, jint); + +/* + * Class: de_kherud_llama_LlamaModel + * Method: jsonSchemaToGrammarBytes + * Signature: (Ljava/lang/String;)[B + */ +JNIEXPORT jbyteArray JNICALL Java_de_kherud_llama_LlamaModel_jsonSchemaToGrammarBytes + (JNIEnv *, jclass, jstring); #ifdef __cplusplus } diff --git a/src/main/java/de/kherud/llama/LlamaModel.java b/src/main/java/de/kherud/llama/LlamaModel.java index 43bf077..7749b32 100644 --- a/src/main/java/de/kherud/llama/LlamaModel.java +++ b/src/main/java/de/kherud/llama/LlamaModel.java @@ -132,4 +132,9 @@ public void close() { private native void releaseTask(int taskId); + private static native byte[] jsonSchemaToGrammarBytes(String schema); + + public static String jsonSchemaToGrammar(String schema) { + return new String(jsonSchemaToGrammarBytes(schema), StandardCharsets.UTF_8); + } } diff --git a/src/test/java/de/kherud/llama/LlamaModelTest.java b/src/test/java/de/kherud/llama/LlamaModelTest.java index 39b4e0d..f2e931b 100644 --- a/src/test/java/de/kherud/llama/LlamaModelTest.java +++ b/src/test/java/de/kherud/llama/LlamaModelTest.java @@ -271,4 +271,29 @@ private LogMessage(LogLevel level, String text) { this.text = text; } } + + @Test + public void testJsonSchemaToGrammar() { + String schema = "{\n" + + " \"properties\": {\n" + + " \"a\": {\"type\": \"string\"},\n" + + " \"b\": {\"type\": \"string\"},\n" + + " \"c\": {\"type\": \"string\"}\n" + + " },\n" + + " \"additionalProperties\": false\n" + + "}"; + + String expectedGrammar = "a-kv ::= \"\\\"a\\\"\" space \":\" space string\n" + + "a-rest ::= ( \",\" space b-kv )? 
b-rest\n" + + "b-kv ::= \"\\\"b\\\"\" space \":\" space string\n" + + "b-rest ::= ( \",\" space c-kv )?\n" + + "c-kv ::= \"\\\"c\\\"\" space \":\" space string\n" + + "char ::= [^\"\\\\\\x7F\\x00-\\x1F] | [\\\\] ([\"\\\\bfnrt] | \"u\" [0-9a-fA-F]{4})\n" + + "root ::= \"{\" space (a-kv a-rest | b-kv b-rest | c-kv )? \"}\" space\n" + + "space ::= | \" \" | \"\\n\"{1,2} [ \\t]{0,20}\n" + + "string ::= \"\\\"\" char* \"\\\"\" space\n"; + + String actualGrammar = LlamaModel.jsonSchemaToGrammar(schema); + Assert.assertEquals(expectedGrammar, actualGrammar); + } }