3 changes: 3 additions & 0 deletions pyproject.toml
@@ -35,6 +35,9 @@ dependencies = [
]

[project.optional-dependencies]
+transformers = [
+    "transformers>=4.50.0",
+]
vllm = [
"vllm>=0.14.0",
]
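This adds a `transformers` extra alongside the existing `vllm` one, so the new backend's dependency can be pulled in with something like `pip install "gimkit[transformers]"` (assuming the package is installed from PyPI under that name).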
3 changes: 2 additions & 1 deletion src/gimkit/__init__.py
@@ -1,7 +1,7 @@
from importlib.metadata import PackageNotFoundError, version

from gimkit.guides import guide
-from gimkit.models import from_openai, from_vllm, from_vllm_offline
+from gimkit.models import from_openai, from_transformers, from_vllm, from_vllm_offline


try:
@@ -12,6 +12,7 @@

__all__ = [
"from_openai",
"from_transformers",
"from_vllm",
"from_vllm_offline",
"guide",
3 changes: 2 additions & 1 deletion src/gimkit/models/__init__.py
@@ -1,6 +1,7 @@
from .openai import from_openai
+from .transformers import from_transformers
from .vllm import from_vllm
from .vllm_offline import from_vllm_offline


__all__ = ["from_openai", "from_vllm", "from_vllm_offline"]
__all__ = ["from_openai", "from_transformers", "from_vllm", "from_vllm_offline"]
46 changes: 46 additions & 0 deletions src/gimkit/models/transformers.py
@@ -0,0 +1,46 @@
# Adapted from https://github.com/dottxt-ai/outlines/blob/main/outlines/models/transformers.py


from typing import TYPE_CHECKING, Any, Literal

from outlines.generator import Generator
from outlines.models.transformers import Transformers as OutlinesTransformers

from gimkit.contexts import Query, Result
from gimkit.log import get_logger
from gimkit.models.utils import get_outlines_model_input, get_outlines_output_type, infill_responses
from gimkit.schemas import ContextInput


logger = get_logger(__name__)

if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer


class Transformers(OutlinesTransformers):
def __call__(
self,
model_input: ContextInput | Query,
output_type: Literal["cfg", "json"] | None = "cfg",
backend: str | None = None,
use_gim_prompt: bool = False,
**inference_kwargs: Any,
) -> Result | list[Result]:
outlines_model_input = get_outlines_model_input(model_input, output_type, use_gim_prompt)
outlines_output_type = get_outlines_output_type(model_input, output_type)
generator = Generator(self, outlines_output_type, backend)
raw_responses = generator(outlines_model_input, **inference_kwargs)
logger.debug(f"Raw responses of {self}: {raw_responses}")
return infill_responses(
model_input,
raw_responses, # type: ignore[arg-type]
json_responses=(output_type == "json"),
)


def from_transformers(
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
) -> Transformers:
return Transformers(model, tokenizer)
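For context, a minimal usage sketch of the new entry point: `from_transformers` wraps a Hugging Face model/tokenizer pair, and calling the resulting model with a masked prompt returns a `Result` whose tags carry the generated content. The `gpt2` checkpoint and the bare `MaskedTag()` prompt are illustrative assumptions, mirroring the mocked call in the tests below.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

import gimkit
from gimkit.schemas import MaskedTag

# Illustrative checkpoint; any causal LM supported by transformers should do.
hf_model = AutoModelForCausalLM.from_pretrained("gpt2")
hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Wrap the raw transformers objects in gimkit's Transformers model.
model = gimkit.from_transformers(hf_model, hf_tokenizer)

# Infill a single masked span; mirrors test_transformers_call below.
result = model(MaskedTag())
print(result.tags[0].content)
```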
71 changes: 71 additions & 0 deletions tests/models/test_transformers.py
@@ -0,0 +1,71 @@
from unittest.mock import MagicMock, patch

import pytest

from outlines.models.transformers import Transformers as OutlinesTransformers

from gimkit.contexts import Result
from gimkit.models.transformers import Transformers as GIMTransformers
from gimkit.models.transformers import from_transformers
from gimkit.schemas import MaskedTag


def _make_model_and_tokenizer():
from transformers import PreTrainedModel, PreTrainedTokenizerBase

mock_model = MagicMock(spec=PreTrainedModel)
mock_model.device = "cpu"
mock_model.config = MagicMock()
mock_model.config.is_encoder_decoder = False

mock_tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
mock_tokenizer.eos_token_id = 0
mock_tokenizer.eos_token = "</s>"
mock_tokenizer.pad_token_id = 0
mock_tokenizer.pad_token = "</s>"
mock_tokenizer.all_special_tokens = []
mock_tokenizer.get_vocab.return_value = {}
mock_tokenizer.chat_template = None

return mock_model, mock_tokenizer


def test_from_transformers():
mock_model, mock_tokenizer = _make_model_and_tokenizer()
model = from_transformers(mock_model, mock_tokenizer)
assert type(model) is GIMTransformers
assert type(model) is not OutlinesTransformers
assert model.model is mock_model


def test_transformers_call():
mock_model, mock_tokenizer = _make_model_and_tokenizer()
model = from_transformers(mock_model, mock_tokenizer)

with patch("gimkit.models.transformers.Generator") as mock_generator:
generator_instance = MagicMock()
generator_instance.return_value = '<|MASKED id="m_0"|>hi<|/MASKED|>'
mock_generator.return_value = generator_instance

returned = model(MaskedTag())
assert isinstance(returned, Result)
assert returned.tags[0].content == "hi"


def test_transformers_call_invalid_response():
mock_model, mock_tokenizer = _make_model_and_tokenizer()
model = from_transformers(mock_model, mock_tokenizer)

with patch("gimkit.models.transformers.Generator") as mock_generator:
generator_instance = MagicMock()
generator_instance.return_value = set()
mock_generator.return_value = generator_instance
with pytest.raises(TypeError, match="Expected responses to be str or list of str, got"):
model(MaskedTag())

with patch("gimkit.models.transformers.Generator") as mock_generator:
generator_instance = MagicMock()
generator_instance.return_value = []
mock_generator.return_value = generator_instance
with pytest.raises(ValueError, match="Response list is empty"):
model(MaskedTag())
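Because the heavy `transformers` objects are replaced with `MagicMock` specs and `Generator` is patched, these tests should run without downloading any model weights; a plain `pytest tests/models/test_transformers.py` ought to cover them (assuming the optional `transformers` extra is installed so the spec classes can be imported).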
6 changes: 5 additions & 1 deletion uv.lock

Some generated files are not rendered by default.