Skip to content

Commit 28dc3d3

Browse files
Merge pull request #37 from patrickfleith/feature/33-implement-ollama-provider
Feature/33 implement ollama provider
2 parents eff0d6d + 0cf79ca commit 28dc3d3

4 files changed

Lines changed: 118 additions & 4 deletions

File tree

README.md

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,22 @@ Currently we support the following dataset types:
1616

1717
- ✅ Text Classification
1818
- ✅ Raw Text Generation
19-
- ✅ Instruction Dataset
20-
- ✅ UltraChat method
19+
- ✅ Instruction Dataset (UltraChat-like)
2120
- [ ] Preference Dataset
2221
- 📋 More coming soon!
2322

24-
⭐️ Star me if this is something you like!
23+
⭐️ Star me if this is something you like! 🌟
24+
25+
26+
## Supported LLM Providers
27+
28+
Currently we support the following LLM providers:
29+
30+
- ✔︎ OpenAI
31+
- ✔︎ Anthropic
32+
- ✔︎ Google
33+
- ✔︎ Ollama
34+
- ✔︎ HF Endpoints (buggy!)
2535

2636
## Key Features
2737

datafast/examples/test_llm_providers.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,9 @@ def main():
128128

129129
# Test OpenAI (GPT-4)
130130
test_provider("openai", "gpt-4o-mini")
131+
132+
# Test Ollama (local LLM)
133+
test_provider("ollama", "gemma3:4b")
131134

132135

133136
if __name__ == "__main__":
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
from datafast.llms import create_provider
2+
from pydantic import BaseModel, Field
3+
from typing import Optional
4+
import sys
5+
6+
"""
7+
A simple test script for the OllamaProvider.
8+
This script requires Ollama to be installed and running locally.
9+
10+
You can install Ollama from https://ollama.com/ and then run:
11+
ollama pull gemma3:4b
12+
"""
13+
14+
15+
class SimpleResponse(BaseModel):
    """Minimal structured-output schema used to exercise the OllamaProvider.

    Two required string fields keep the generated JSON schema small, so even
    a small local model can satisfy it reliably.
    """

    # Required: the model's direct answer to the question asked.
    answer: str = Field(..., description="The answer to the question")
    # Required: a short justification for the answer above.
    reasoning: str = Field(..., description="The reasoning behind the answer")
19+
20+
21+
def test_ollama(model_id: str = "gemma3:4b") -> bool:
    """Run a single structured query against the OllamaProvider.

    Requires a local Ollama server with *model_id* already pulled.

    Args:
        model_id: Ollama model tag to query.

    Returns:
        True when a response was generated and parsed successfully,
        False if anything went wrong.
    """
    banner = "=" * 50
    print(f"\n{banner}")
    print(f"Testing Ollama provider with model {model_id}")
    print(banner)

    question = "What is the capital of France? Provide a short answer and brief reasoning."
    try:
        llm = create_provider("ollama", model_id)

        print(f"Sending prompt: {question}")
        print("Waiting for response (this might take a bit)...")

        # Ask the provider for a SimpleResponse-shaped structured answer.
        result = llm.generate(question, SimpleResponse)

        print("\nResponse received:")
        print(f"Answer: {result.answer}")
        print(f"Reasoning: {result.reasoning}")
        print("\nTest successful!")
        return True
    except Exception as err:  # best-effort demo script: report and signal failure
        print(f"Error testing Ollama provider: {str(err)}")
        return False
50+
51+
52+
if __name__ == "__main__":
    # The first positional CLI argument, when given, overrides the default model tag.
    cli_args = sys.argv[1:]
    test_ollama(cli_args[0] if cli_args else "gemma3:4b")

datafast/llms.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,13 +203,58 @@ def _generate_impl(self, prompt: str | list[dict[str, str]], response_format: ty
203203
)
204204

205205

206+
class OllamaProvider(LLMProvider):
    """LLM provider backed by a locally running Ollama server.

    Unlike the hosted providers, Ollama requires no API key, and its chat
    endpoint returns raw JSON text that we validate ourselves against the
    caller's Pydantic model.
    """

    # A local Ollama server needs no authentication; only a default model tag.
    DEFAULT_MODEL = "llama3:latest"

    @property
    def name(self) -> str:
        """Canonical provider identifier used for provider lookup."""
        return "ollama"

    def _get_api_key(self) -> str:
        """Override _get_api_key since Ollama doesn't need an API key"""
        # Placeholder value keeps the base-class contract satisfied.
        return "not_needed"

    def _initialize_client(self):
        """Import and return the `ollama` package as the client handle.

        Raises:
            ImportError: if the `ollama` package is not installed.
            ValueError: if the package import fails for any other reason.
        """
        try:
            import ollama
        except ImportError as exc:
            raise ImportError(f"Ollama package not installed. Install it with 'pip install ollama': {str(exc)}")
        except Exception as exc:
            raise ValueError(f"Error initializing Ollama client: {str(exc)}")
        return ollama

    def _generate_impl(
        self, prompt: str | list[dict[str, str]], response_format: type[BaseModel]
    ) -> BaseModel:
        """Send the prompt to Ollama and parse the reply into *response_format*.

        Args:
            prompt: Plain string, or a pre-built list of chat messages.
            response_format: Pydantic model the JSON reply must conform to.

        Returns:
            A validated instance of *response_format*.
        """
        # Plain strings are wrapped into chat-message form; lists pass through.
        if isinstance(prompt, str):
            messages = get_messages(prompt)
        else:
            messages = prompt

        # Ollama's `format` argument constrains the output to this JSON schema.
        reply = self.client.chat(
            model=self.model_id,
            messages=messages,
            format=response_format.model_json_schema(),
        )

        # Unlike the instructor-backed providers, which hand back a parsed
        # model directly, here we must validate the raw JSON text ourselves.
        return response_format.model_validate_json(reply.message.content)
250+
206251
def create_provider(
207252
provider: str, model_id: str | None = None, **kwargs
208253
) -> LLMProvider:
209254
"""Create an LLM provider for structured text generation.
210255
211256
Args:
212-
provider: Provider name ('anthropic', 'google', or 'openai')
257+
provider: Provider name ('anthropic', 'google', 'openai', 'ollama')
213258
model_id: Optional model identifier. If not provided, uses provider's default
214259
**kwargs: Additional provider-specific arguments
215260
@@ -220,6 +265,7 @@ def create_provider(
220265
"anthropic": AnthropicProvider,
221266
"google": GoogleProvider,
222267
"openai": OpenAIProvider,
268+
"ollama": OllamaProvider,
223269
}
224270

225271
provider_class = provider_map.get(provider.lower())

0 commit comments

Comments
 (0)