Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ dependencies = [
]

[project.optional-dependencies]
llamacpp = [
"llama-cpp-python>=0.3.0",
]
vllm = [
"vllm>=0.14.0",
]
Expand Down
3 changes: 2 additions & 1 deletion src/gimkit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from importlib.metadata import PackageNotFoundError, version

from gimkit.guides import guide
from gimkit.models import from_openai, from_vllm, from_vllm_offline
from gimkit.models import from_llamacpp, from_openai, from_vllm, from_vllm_offline


try:
Expand All @@ -11,6 +11,7 @@


__all__ = [
"from_llamacpp",
"from_openai",
"from_vllm",
"from_vllm_offline",
Expand Down
3 changes: 2 additions & 1 deletion src/gimkit/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .llamacpp import from_llamacpp
from .openai import from_openai
from .vllm import from_vllm
from .vllm_offline import from_vllm_offline


__all__ = ["from_openai", "from_vllm", "from_vllm_offline"]
__all__ = ["from_llamacpp", "from_openai", "from_vllm", "from_vllm_offline"]
58 changes: 58 additions & 0 deletions src/gimkit/models/llamacpp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Adapted from https://github.com/dottxt-ai/outlines/blob/main/outlines/models/llamacpp.py


from typing import TYPE_CHECKING, Any, Literal, cast

from outlines.generator import Generator
from outlines.models.llamacpp import LlamaCpp as OutlinesLlamaCpp

from gimkit.contexts import Query, Result
from gimkit.log import get_logger
from gimkit.models.utils import get_outlines_model_input, get_outlines_output_type, infill_responses
from gimkit.schemas import RESPONSE_SUFFIX, ContextInput


logger = get_logger(__name__)

if TYPE_CHECKING:
from llama_cpp import Llama


class LlamaCpp(OutlinesLlamaCpp):
    """GIMKit wrapper around the Outlines llama.cpp model.

    Extends :class:`outlines.models.llamacpp.LlamaCpp` with GIM-specific
    prompt construction, a guaranteed ``RESPONSE_SUFFIX`` stop sequence,
    and infilling of the raw generations back into the original query.
    """

    def __call__(
        self,
        model_input: ContextInput | Query,
        output_type: Literal["cfg", "json"] | None = "cfg",
        backend: str | None = None,
        use_gim_prompt: bool = False,
        **inference_kwargs: Any,
    ) -> Result | list[Result]:
        """Generate completions for *model_input* and infill them.

        Args:
            model_input: The context or query to complete.
            output_type: Constrained-decoding mode (``"cfg"``, ``"json"``,
                or ``None`` for unconstrained generation).
            backend: Optional Outlines generator backend name.
            use_gim_prompt: Whether to wrap the input in the GIM prompt format.
            **inference_kwargs: Extra keyword arguments forwarded to llama.cpp.

        Returns:
            A single :class:`Result` or a list of them, one per raw response.
        """
        # Using `stop=RESPONSE_SUFFIX` is preferred for two reasons:
        # 1. The model might not be trained well enough to generate EOS tokens immediately after RESPONSE_SUFFIX.
        # 2. Even with CFG, inference engines may not guarantee termination when the CFG is satisfied.
        inference_kwargs = self._ensure_response_suffix(inference_kwargs)

        outlines_model_input = get_outlines_model_input(model_input, output_type, use_gim_prompt)
        outlines_output_type = get_outlines_output_type(model_input, output_type)
        generator = Generator(self, outlines_output_type, backend)
        raw_responses = generator(outlines_model_input, **inference_kwargs)
        logger.debug(f"Raw responses of {self}: {raw_responses}")
        return infill_responses(
            model_input,
            cast("str | list[str]", raw_responses),
            json_responses=(output_type == "json"),
        )

    def _ensure_response_suffix(self, inference_kwargs: dict[str, Any]) -> dict[str, Any]:
        """Return *inference_kwargs* with ``RESPONSE_SUFFIX`` guaranteed in ``stop``.

        Fixes over the naive version:
        - The caller's dict is never mutated; a shallow copy is returned.
        - Any list/tuple of stop strings is handled (a tuple previously
          bypassed every branch and the suffix was silently dropped).
        Values of any other type are passed through untouched and left for
        llama.cpp itself to validate.
        """
        kwargs = dict(inference_kwargs)  # do not mutate the caller's dict
        stop = kwargs.get("stop")
        if stop is None:
            kwargs["stop"] = [RESPONSE_SUFFIX]
        elif isinstance(stop, str):
            if stop != RESPONSE_SUFFIX:
                kwargs["stop"] = [stop, RESPONSE_SUFFIX]
        elif isinstance(stop, (list, tuple)):
            if RESPONSE_SUFFIX not in stop:
                kwargs["stop"] = [*stop, RESPONSE_SUFFIX]
        return kwargs


def from_llamacpp(model: "Llama") -> LlamaCpp:
    """Wrap a loaded ``llama_cpp.Llama`` instance in the GIMKit model interface.

    Args:
        model: A loaded ``llama_cpp.Llama`` model.

    Returns:
        A :class:`LlamaCpp` adapter exposing GIM-style generation.
    """
    adapter = LlamaCpp(model)
    return adapter
89 changes: 89 additions & 0 deletions tests/models/test_llamacpp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from unittest.mock import MagicMock, patch

import pytest

from outlines.models.llamacpp import LlamaCpp as OutlinesLlamaCpp

from gimkit.contexts import Result
from gimkit.models.llamacpp import LlamaCpp as GIMLlamaCpp
from gimkit.models.llamacpp import from_llamacpp
from gimkit.schemas import RESPONSE_SUFFIX, MaskedTag


@pytest.fixture(autouse=True)
def patch_tokenizer():
    """Stub out LlamaCppTokenizer so the suite runs without llama-cpp-python."""
    tokenizer_patcher = patch("outlines.models.llamacpp.LlamaCppTokenizer")
    with tokenizer_patcher:
        yield


def test_from_llamacpp():
    """from_llamacpp must produce the GIM subclass, not the Outlines base class."""
    fake_llama = MagicMock()
    wrapped = from_llamacpp(fake_llama)
    assert type(wrapped) is GIMLlamaCpp
    assert type(wrapped) is not OutlinesLlamaCpp
    assert wrapped.model is fake_llama


def test_llamacpp_call():
    """A successful call infills the raw response and forces the stop suffix."""
    model = from_llamacpp(MagicMock())

    with patch("gimkit.models.llamacpp.Generator") as generator_cls:
        fake_generator = MagicMock()
        fake_generator.return_value = '<|MASKED id="m_0"|>hi<|/MASKED|>'
        generator_cls.return_value = fake_generator

        result = model(MaskedTag())

        assert isinstance(result, Result)
        assert result.tags[0].content == "hi"

        # The wrapper must always push RESPONSE_SUFFIX into the stop sequences.
        _, forwarded_kwargs = fake_generator.call_args
        assert RESPONSE_SUFFIX in forwarded_kwargs["stop"]


def test_llamacpp_call_invalid_response():
    """Non-string generator output is rejected with a clear, typed error."""
    model = from_llamacpp(MagicMock())

    # (generator output, expected exception, expected message) triples.
    failure_cases = [
        (set(), TypeError, "Expected responses to be str or list of str, got"),
        ([], ValueError, "Response list is empty"),
    ]
    for bad_output, expected_exc, expected_msg in failure_cases:
        with patch("gimkit.models.llamacpp.Generator") as generator_cls:
            fake_generator = MagicMock()
            fake_generator.return_value = bad_output
            generator_cls.return_value = fake_generator
            with pytest.raises(expected_exc, match=expected_msg):
                model(MaskedTag())


def test_ensure_response_suffix():
    """_ensure_response_suffix guarantees RESPONSE_SUFFIX among the stop strings."""
    model = from_llamacpp(MagicMock())

    # Missing `stop` entirely: the suffix becomes the sole stop sequence.
    assert model._ensure_response_suffix({})["stop"] == [RESPONSE_SUFFIX]

    # List without the suffix: both entries survive, suffix appended.
    appended = model._ensure_response_suffix({"stop": ["other"]})["stop"]
    assert RESPONSE_SUFFIX in appended
    assert "other" in appended

    # List already holding the suffix: left exactly as given.
    assert model._ensure_response_suffix({"stop": [RESPONSE_SUFFIX]})["stop"] == [RESPONSE_SUFFIX]

    # Bare string other than the suffix: wrapped into a two-element list.
    wrapped = model._ensure_response_suffix({"stop": "other"})["stop"]
    assert RESPONSE_SUFFIX in wrapped
    assert "other" in wrapped

    # Bare string equal to the suffix: passed through untouched.
    assert model._ensure_response_suffix({"stop": RESPONSE_SUFFIX})["stop"] == RESPONSE_SUFFIX
18 changes: 17 additions & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.