Skip to content

Commit de4e8a7

Browse files
committed
feat: Enhance content generation with structured messages and logging
Signed-off-by: Eden Reich <eden.reich@gmail.com>
1 parent c9c92f6 commit de4e8a7

File tree

2 files changed

+52
-19
lines changed

2 files changed

+52
-19
lines changed

README.md

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,29 +20,36 @@ Run `cargo add inference-gateway-sdk`.
2020
### Creating a Client
2121

2222
```rust
23-
use inference_gateway_rust_sdk::{InferenceGatewayClient, Provider};
23+
use inference_gateway_rust_sdk::{InferenceGatewayClient, Provider, Message};
24+
use log::{info, error};
2425
use std::error::Error;
2526

2627
fn main() -> Result<(), Box<dyn Error>> {
28+
env_logger::init();
29+
2730
let client = InferenceGatewayClient::new("http://localhost:8080");
2831

2932
// List available models
3033
let models = client.list_models()?;
3134
for provider_models in models {
32-
println!("Provider: {}", provider_models.provider);
35+
info!("Provider: {:?}", provider_models.provider);
3336
for model in provider_models.models {
34-
println!(" Model: {}", model.id);
37+
info!("Model: {:?}", model.id);
3538
}
3639
}
3740

38-
// Generate content
3941
let response = client.generate_content(
4042
Provider::Ollama,
4143
"llama2",
42-
"Tell me a joke"
44+
vec![
45+
Message {
46+
role: "user".to_string(),
47+
content: "Tell me a joke".to_string(),
48+
}
49+
]
4350
)?;
44-
println!("Response: {}", response.response);
4551

52+
info!("Response: {:?}", response);
4653
Ok(())
4754
}
4855
```
@@ -52,11 +59,13 @@ fn main() -> Result<(), Box<dyn Error>> {
5259
To list available models, use the `list_models` method:
5360

5461
```rust
62+
use log::info;
63+
5564
let models = client.list_models()?;
5665
for provider_models in models {
57-
println!("Provider: {}", provider_models.provider);
66+
info!("Provider: {:?}", provider_models.provider);
5867
for model in provider_models.models {
59-
println!(" Model: {}", model.id);
68+
info!("Model: {:?}", model.id);
6069
}
6170
}
6271
```
@@ -66,13 +75,21 @@ for provider_models in models {
6675
To generate content using a model, use the `generate_content` method:
6776

6877
```rust
78+
use log::info;
79+
6980
let response = client.generate_content(
7081
Provider::Ollama,
7182
"llama2",
72-
"Tell me a joke"
83+
vec![
84+
Message {
85+
role: "user".to_string(),
86+
content: "Tell me a joke".to_string(),
87+
}
88+
]
7389
)?;
74-
println!("Provider: {}", response.provider);
75-
println!("Response: {}", response.response);
90+
91+
info!("Provider: {:?}", response.provider);
92+
info!("Response: {:?}", response.response);
7693
```
7794

7895
### Health Check

src/lib.rs

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,28 @@ impl fmt::Display for Provider {
4040
}
4141
}
4242

43+
#[derive(Debug, Serialize, Deserialize)]
44+
pub struct Message {
45+
pub role: String,
46+
pub content: String,
47+
}
48+
4349
#[derive(Debug, Serialize)]
4450
struct GenerateRequest {
4551
model: String,
46-
prompt: String,
52+
messages: Vec<Message>,
4753
}
4854

4955
#[derive(Debug, Deserialize)]
5056
pub struct GenerateResponse {
5157
pub provider: String,
52-
pub response: String,
58+
pub response: ResponseContent,
59+
}
60+
61+
#[derive(Debug, Deserialize)]
62+
pub struct ResponseContent {
63+
pub role: String,
64+
pub model: String,
5365
}
5466

5567
pub struct InferenceGatewayClient {
@@ -76,16 +88,15 @@ impl InferenceGatewayClient {
7688
&self,
7789
provider: Provider,
7890
model: &str,
79-
prompt: &str,
91+
messages: Vec<Message>,
8092
) -> Result<GenerateResponse, Box<dyn Error>> {
8193
let url = format!("{}/llms/{}/generate", self.base_url, provider);
8294
let request = GenerateRequest {
8395
model: model.to_string(),
84-
prompt: prompt.to_string(),
96+
messages,
8597
};
8698

8799
let response = self.client.post(&url).json(&request).send()?.json()?;
88-
89100
Ok(response)
90101
}
91102

@@ -126,16 +137,21 @@ mod tests {
126137
.mock("POST", "/llms/ollama/generate")
127138
.with_status(200)
128139
.with_header("content-type", "application/json")
129-
.with_body(r#"{"provider":"ollama","response":"Generated text"}"#)
140+
.with_body(r#"{"provider":"ollama","response":{"role":"assistant","model":"llama2"}}"#)
130141
.create();
131142

132143
let client = InferenceGatewayClient::new(&server.url());
144+
let messages = vec![Message {
145+
role: "user".to_string(),
146+
content: "Hello".to_string(),
147+
}];
133148
let response = client
134-
.generate_content(Provider::Ollama, "llama2", "Hello")
149+
.generate_content(Provider::Ollama, "llama2", messages)
135150
.unwrap();
136151

137152
assert_eq!(response.provider, "ollama");
138-
assert_eq!(response.response, "Generated text");
153+
assert_eq!(response.response.role, "assistant");
154+
assert_eq!(response.response.model, "llama2");
139155
mock.assert();
140156
}
141157

0 commit comments

Comments
 (0)