Add a Transformers backend to gimkit.

Declares a `transformers` optional extra, exports `from_transformers` from
`gimkit` and `gimkit.models`, adds `src/gimkit/models/transformers.py` (an
Outlines `Transformers` subclass plus its factory function), unit tests for
it, and the matching `uv.lock` entries.

Review note: docstrings were added to `src/gimkit/models/transformers.py`, so
the `index` blob id recorded for that file predates them and was not
recomputed.

diff --git a/pyproject.toml b/pyproject.toml
index f32cb17..0339690 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,6 +35,9 @@ dependencies = [
 ]
 
 [project.optional-dependencies]
+transformers = [
+    "transformers>=4.50.0",
+]
 vllm = [
     "vllm>=0.14.0",
 ]
diff --git a/src/gimkit/__init__.py b/src/gimkit/__init__.py
index 37e53fe..90097ce 100644
--- a/src/gimkit/__init__.py
+++ b/src/gimkit/__init__.py
@@ -1,7 +1,7 @@
 from importlib.metadata import PackageNotFoundError, version
 
 from gimkit.guides import guide
-from gimkit.models import from_openai, from_vllm, from_vllm_offline
+from gimkit.models import from_openai, from_transformers, from_vllm, from_vllm_offline
 
 
 try:
@@ -12,6 +12,7 @@
 
 __all__ = [
     "from_openai",
+    "from_transformers",
     "from_vllm",
     "from_vllm_offline",
     "guide",
diff --git a/src/gimkit/models/__init__.py b/src/gimkit/models/__init__.py
index 26d35ce..f620910 100644
--- a/src/gimkit/models/__init__.py
+++ b/src/gimkit/models/__init__.py
@@ -1,6 +1,7 @@
 from .openai import from_openai
+from .transformers import from_transformers
 from .vllm import from_vllm
 from .vllm_offline import from_vllm_offline
 
 
-__all__ = ["from_openai", "from_vllm", "from_vllm_offline"]
+__all__ = ["from_openai", "from_transformers", "from_vllm", "from_vllm_offline"]
diff --git a/src/gimkit/models/transformers.py b/src/gimkit/models/transformers.py
new file mode 100644
index 0000000..7af0325
--- /dev/null
+++ b/src/gimkit/models/transformers.py
@@ -0,0 +1,63 @@
+# Adapted from https://github.com/dottxt-ai/outlines/blob/main/outlines/models/transformers.py
+
+
+from typing import TYPE_CHECKING, Any, Literal
+
+from outlines.generator import Generator
+from outlines.models.transformers import Transformers as OutlinesTransformers
+
+from gimkit.contexts import Query, Result
+from gimkit.log import get_logger
+from gimkit.models.utils import get_outlines_model_input, get_outlines_output_type, infill_responses
+from gimkit.schemas import ContextInput
+
+
+logger = get_logger(__name__)
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel, PreTrainedTokenizer
+
+
+class Transformers(OutlinesTransformers):
+    """Outlines ``Transformers`` model that is called with GIM queries."""
+
+    def __call__(
+        self,
+        model_input: ContextInput | Query,
+        output_type: Literal["cfg", "json"] | None = "cfg",
+        backend: str | None = None,
+        use_gim_prompt: bool = False,
+        **inference_kwargs: Any,
+    ) -> Result | list[Result]:
+        """Generate for ``model_input`` and infill it with the raw responses.
+
+        Args:
+            model_input: Query, or context input, to generate for.
+            output_type: Constraint kind handed to the Outlines input/output
+                helpers; ``"json"`` also switches ``infill_responses`` to its
+                JSON mode.
+            backend: Passed through to the Outlines ``Generator``.
+            use_gim_prompt: Passed through to ``get_outlines_model_input``.
+            **inference_kwargs: Passed through to the generator call.
+
+        Returns:
+            The value produced by ``infill_responses`` for ``model_input``.
+        """
+        outlines_model_input = get_outlines_model_input(model_input, output_type, use_gim_prompt)
+        outlines_output_type = get_outlines_output_type(model_input, output_type)
+        generator = Generator(self, outlines_output_type, backend)
+        raw_responses = generator(outlines_model_input, **inference_kwargs)
+        logger.debug(f"Raw responses of {self}: {raw_responses}")
+        return infill_responses(
+            model_input,
+            raw_responses,  # type: ignore[arg-type]
+            json_responses=(output_type == "json"),
+        )
+
+
+def from_transformers(
+    model: "PreTrainedModel",
+    tokenizer: "PreTrainedTokenizer",
+) -> Transformers:
+    """Wrap a ``transformers`` model and tokenizer in a GIM ``Transformers`` model."""
+    return Transformers(model, tokenizer)
diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py
new file mode 100644
index 0000000..cc59840
--- /dev/null
+++ b/tests/models/test_transformers.py
@@ -0,0 +1,71 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from outlines.models.transformers import Transformers as OutlinesTransformers
+
+from gimkit.contexts import Result
+from gimkit.models.transformers import Transformers as GIMTransformers
+from gimkit.models.transformers import from_transformers
+from gimkit.schemas import MaskedTag
+
+
+def _make_model_and_tokenizer():
+    from transformers import PreTrainedModel, PreTrainedTokenizerBase
+
+    mock_model = MagicMock(spec=PreTrainedModel)
+    mock_model.device = "cpu"
+    mock_model.config = MagicMock()
+    mock_model.config.is_encoder_decoder = False
+
+    mock_tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
+    mock_tokenizer.eos_token_id = 0
+    mock_tokenizer.eos_token = ""
+    mock_tokenizer.pad_token_id = 0
+    mock_tokenizer.pad_token = ""
+    mock_tokenizer.all_special_tokens = []
+    mock_tokenizer.get_vocab.return_value = {}
+    mock_tokenizer.chat_template = None
+
+    return mock_model, mock_tokenizer
+
+
+def test_from_transformers():
+    mock_model, mock_tokenizer = _make_model_and_tokenizer()
+    model = from_transformers(mock_model, mock_tokenizer)
+    assert type(model) is GIMTransformers
+    assert type(model) is not OutlinesTransformers
+    assert model.model is mock_model
+
+
+def test_transformers_call():
+    mock_model, mock_tokenizer = _make_model_and_tokenizer()
+    model = from_transformers(mock_model, mock_tokenizer)
+
+    with patch("gimkit.models.transformers.Generator") as mock_generator:
+        generator_instance = MagicMock()
+        generator_instance.return_value = '<|MASKED id="m_0"|>hi<|/MASKED|>'
+        mock_generator.return_value = generator_instance
+
+        returned = model(MaskedTag())
+        assert isinstance(returned, Result)
+        assert returned.tags[0].content == "hi"
+
+
+def test_transformers_call_invalid_response():
+    mock_model, mock_tokenizer = _make_model_and_tokenizer()
+    model = from_transformers(mock_model, mock_tokenizer)
+
+    with patch("gimkit.models.transformers.Generator") as mock_generator:
+        generator_instance = MagicMock()
+        generator_instance.return_value = set()
+        mock_generator.return_value = generator_instance
+        with pytest.raises(TypeError, match="Expected responses to be str or list of str, got"):
+            model(MaskedTag())
+
+    with patch("gimkit.models.transformers.Generator") as mock_generator:
+        generator_instance = MagicMock()
+        generator_instance.return_value = []
+        mock_generator.return_value = generator_instance
+        with pytest.raises(ValueError, match="Response list is empty"):
+            model(MaskedTag())
diff --git a/uv.lock b/uv.lock
index e30cf5d..237ff43 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1485,6 +1485,9 @@ dependencies = [
 ]
 
 [package.optional-dependencies]
+transformers = [
+    { name = "transformers" },
+]
 vllm = [
     { name = "vllm" },
 ]
@@ -1505,9 +1508,10 @@ requires-dist = [
     { name = "json-repair", specifier = ">=0.55.1" },
     { name = "llguidance", specifier = ">=1.3.0" },
     { name = "outlines", extras = ["openai"], specifier = ">=1.2.9" },
+    { name = "transformers", marker = "extra == 'transformers'", specifier = ">=4.50.0" },
     { name = "vllm", marker = "extra == 'vllm'", specifier = ">=0.14.0" },
 ]
-provides-extras = ["vllm"]
+provides-extras = ["transformers", "vllm"]
 
 [package.metadata.requires-dev]
 dev = [