From 9dd8763e33d90926b28edba6f9215db4d1774af2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A1t=C3=A9=20Gy=C3=B6ngy=C3=B6si?= Date: Mon, 25 Aug 2025 12:46:32 +0200 Subject: [PATCH 01/10] feat: add image input support for vision models - Add --image/-i flag to include image files in prompts - Support common formats: PNG, JPEG, GIF, WebP (max 5MB, max 10 images) - OpenAI: use content arrays with text and image_url parts - Ollama: use native Images field for vision models like LLaVA - Panic with an explanatory message on image input for non-vision APIs (Anthropic, Google, Cohere) - Validate file existence, format, and size limits - Works with any OpenAI-compatible endpoint in config Authored-By: claude, @anuramat, @gy-mate --- config.go | 2 + images.go | 116 +++++++++++++++++++++++++++++++++++ internal/anthropic/format.go | 7 +++ internal/cohere/format.go | 7 +++ internal/google/format.go | 7 +++ internal/ollama/format.go | 16 +++++ internal/openai/format.go | 25 +++++++- internal/openai/openai.go | 10 ++- internal/proto/proto.go | 8 +++ main.go | 1 + stream.go | 10 +++ 11 files changed, 206 insertions(+), 3 deletions(-) create mode 100644 images.go diff --git a/config.go b/config.go index 9f0abfe2..9902ff9b 100644 --- a/config.go +++ b/config.go @@ -75,6 +75,7 @@ var help = map[string]string{ "mcp-list": "List all available MCP servers", "mcp-list-tools": "List all available tools from enabled MCP servers", "mcp-timeout": "Timeout for MCP server calls, defaults to 15 seconds", + "image": "Include image file(s) in the prompt (vision models only)", } // Model represents the LLM model used in the API call. 
@@ -185,6 +186,7 @@ type Config struct { Delete []string DeleteOlderThan time.Duration User string + Images []string MCPServers map[string]MCPServerConfig `yaml:"mcp-servers"` MCPList bool diff --git a/images.go b/images.go new file mode 100644 index 00000000..1654503d --- /dev/null +++ b/images.go @@ -0,0 +1,116 @@ +package main + +import ( + "encoding/base64" + "fmt" + "io" + "os" + "path/filepath" + "strings" + + "github.com/charmbracelet/mods/internal/proto" +) + +// supportedImageFormats maps file extensions to MIME types +var supportedImageFormats = map[string]string{ + ".jpg": "image/jpeg", + ".jpeg": "image/jpeg", + ".png": "image/png", + ".gif": "image/gif", + ".webp": "image/webp", +} + +// maxImageSize sets the maximum allowed image file size (5MB) +const maxImageSize = 5 * 1024 * 1024 + +// processImageFiles reads and validates image files from paths +func processImageFiles(imagePaths []string) ([]proto.ImageContent, error) { + if len(imagePaths) == 0 { + return nil, nil + } + + if len(imagePaths) > 10 { + return nil, fmt.Errorf("too many images: maximum 10 images allowed, got %d", len(imagePaths)) + } + + var images []proto.ImageContent + for _, path := range imagePaths { + image, err := readImageFile(path) + if err != nil { + return nil, fmt.Errorf("error processing image %s: %w", path, err) + } + images = append(images, *image) + } + + return images, nil +} + +// readImageFile reads and validates a single image file +func readImageFile(path string) (*proto.ImageContent, error) { + // Check if file exists + if _, err := os.Stat(path); os.IsNotExist(err) { + return nil, fmt.Errorf("image file does not exist: %s", path) + } + + // Get file info for size check + fileInfo, err := os.Stat(path) + if err != nil { + return nil, fmt.Errorf("could not get file info: %w", err) + } + + if fileInfo.Size() > maxImageSize { + return nil, fmt.Errorf("image file too large: %s (%.2f MB > 5 MB)", + path, float64(fileInfo.Size())/(1024*1024)) + } + + // Detect MIME type 
from extension + ext := strings.ToLower(filepath.Ext(path)) + mimeType, supported := supportedImageFormats[ext] + if !supported { + return nil, fmt.Errorf("unsupported image format: %s (supported: %s)", + ext, getSupportedFormats()) + } + + // Read file data + file, err := os.Open(path) + if err != nil { + return nil, fmt.Errorf("could not open file: %w", err) + } + defer file.Close() + + data, err := io.ReadAll(file) + if err != nil { + return nil, fmt.Errorf("could not read file: %w", err) + } + + return &proto.ImageContent{ + Data: data, + MimeType: mimeType, + Filename: filepath.Base(path), + }, nil +} + +// getSupportedFormats returns a comma-separated list of supported formats +func getSupportedFormats() string { + var formats []string + for ext := range supportedImageFormats { + formats = append(formats, ext) + } + return strings.Join(formats, ", ") +} + +// createDataURI creates a data URI for base64 encoded image data +func createDataURI(image proto.ImageContent) string { + base64Data := base64.StdEncoding.EncodeToString(image.Data) + return fmt.Sprintf("data:%s;base64,%s", image.MimeType, base64Data) +} + +// hasImages checks if any message contains images +func hasImages(messages []proto.Message) bool { + for _, msg := range messages { + if len(msg.Images) > 0 { + return true + } + } + return false +} diff --git a/internal/anthropic/format.go b/internal/anthropic/format.go index cc806d44..113b6c81 100644 --- a/internal/anthropic/format.go +++ b/internal/anthropic/format.go @@ -28,6 +28,13 @@ func fromMCPTools(mcps map[string][]mcp.Tool) []anthropic.ToolUnionParam { } func fromProtoMessages(input []proto.Message) (system []anthropic.TextBlockParam, messages []anthropic.MessageParam) { + // Check for images and error if present (not supported yet) + for _, msg := range input { + if len(msg.Images) > 0 { + panic("image input is not supported for Anthropic API yet - use OpenAI API for vision capabilities") + } + } + for _, msg := range input { switch 
msg.Role { case proto.RoleSystem: diff --git a/internal/cohere/format.go b/internal/cohere/format.go index dd1a76b3..95d8e0c0 100644 --- a/internal/cohere/format.go +++ b/internal/cohere/format.go @@ -6,6 +6,13 @@ import ( ) func fromProtoMessages(input []proto.Message) (history []*cohere.Message, message string) { + // Check for images and error if present (not supported yet) + for _, msg := range input { + if len(msg.Images) > 0 { + panic("image input is not supported for Cohere API - use OpenAI API for vision capabilities") + } + } + var messages []*cohere.Message //nolint:prealloc for _, msg := range input { messages = append(messages, &cohere.Message{ diff --git a/internal/google/format.go b/internal/google/format.go index db3656d8..ab72c5ec 100644 --- a/internal/google/format.go +++ b/internal/google/format.go @@ -3,6 +3,13 @@ package google import "github.com/charmbracelet/mods/internal/proto" func fromProtoMessages(input []proto.Message) []Content { + // Check for images and error if present (not supported yet) + for _, msg := range input { + if len(msg.Images) > 0 { + panic("image input is not supported for Google API yet - use OpenAI API for vision capabilities") + } + } + result := make([]Content, 0, len(input)) for _, in := range input { switch in.Role { diff --git a/internal/ollama/format.go b/internal/ollama/format.go index 140848b4..dbd81ec0 100644 --- a/internal/ollama/format.go +++ b/internal/ollama/format.go @@ -42,6 +42,12 @@ func fromProtoMessage(input proto.Message) api.Message { Content: input.Content, Role: input.Role, } + + // Handle images + for _, img := range input.Images { + m.Images = append(m.Images, api.ImageData(img.Data)) + } + for _, call := range input.ToolCalls { var args api.ToolCallFunctionArguments _ = json.Unmarshal(call.Function.Arguments, &args) @@ -62,6 +68,16 @@ func toProtoMessage(in api.Message) proto.Message { Role: in.Role, Content: in.Content, } + + // Handle images + for _, imgData := range in.Images { + msg.Images 
= append(msg.Images, proto.ImageContent{ + Data: []byte(imgData), + MimeType: "image/jpeg", // Ollama doesn't provide MIME type, assume JPEG + Filename: "", // Ollama doesn't provide filename + }) + } + for _, call := range in.ToolCalls { msg.ToolCalls = append(msg.ToolCalls, proto.ToolCall{ ID: strconv.Itoa(call.Function.Index), diff --git a/internal/openai/format.go b/internal/openai/format.go index d56cfc33..42409125 100644 --- a/internal/openai/format.go +++ b/internal/openai/format.go @@ -1,6 +1,7 @@ package openai import ( + "encoding/base64" "fmt" "github.com/charmbracelet/mods/internal/proto" @@ -46,7 +47,29 @@ func fromProtoMessages(input []proto.Message) []openai.ChatCompletionMessagePara break } case proto.RoleUser: - messages = append(messages, openai.UserMessage(msg.Content)) + if len(msg.Images) > 0 { + // Create content array with text and images + var content []openai.ChatCompletionContentPartUnionParam + + // Add text content if present + if msg.Content != "" { + content = append(content, openai.TextContentPart(msg.Content)) + } + + // Add image content + for _, img := range msg.Images { + base64Data := base64.StdEncoding.EncodeToString(img.Data) + dataURI := fmt.Sprintf("data:%s;base64,%s", img.MimeType, base64Data) + + content = append(content, openai.ImageContentPart(openai.ChatCompletionContentPartImageImageURLParam{ + URL: dataURI, + })) + } + + messages = append(messages, openai.UserMessage(content)) + } else { + messages = append(messages, openai.UserMessage(msg.Content)) + } case proto.RoleAssistant: m := openai.AssistantMessage(msg.Content) for _, tool := range msg.ToolCalls { diff --git a/internal/openai/openai.go b/internal/openai/openai.go index 47c498af..06bd5003 100644 --- a/internal/openai/openai.go +++ b/internal/openai/openai.go @@ -15,7 +15,13 @@ import ( "github.com/openai/openai-go/shared" ) -var _ stream.Client = &Client{} +var ( + _ stream.Client = &Client{} + apisWithJSONResponseFormat = map[string]bool{ + "openai": true, + 
"copilot": true, + } +) // Client is the openai client. type Client struct { @@ -86,7 +92,7 @@ func (c *Client) Request(ctx context.Context, request proto.Request) stream.Stre if request.MaxTokens != nil { body.MaxTokens = openai.Int(*request.MaxTokens) } - if request.API == "openai" && request.ResponseFormat != nil && *request.ResponseFormat == "json" { + if apisWithJSONResponseFormat[request.API] && request.ResponseFormat != nil && *request.ResponseFormat == "json" { body.ResponseFormat = openai.ChatCompletionNewParamsResponseFormatUnion{ OfJSONObject: &shared.ResponseFormatJSONObjectParam{}, } diff --git a/internal/proto/proto.go b/internal/proto/proto.go index 93a90be6..cdeeb437 100644 --- a/internal/proto/proto.go +++ b/internal/proto/proto.go @@ -42,10 +42,18 @@ func (c ToolCallStatus) String() string { return sb.String() } +// ImageContent represents an image attachment in a message. +type ImageContent struct { + Data []byte // raw image file bytes; the OpenAI formatter base64-encodes them + MimeType string // image/jpeg, image/png, etc. + Filename string // original filename for reference +} + // Message is a message in the conversation. 
type Message struct { Role string Content string + Images []ImageContent ToolCalls []ToolCall } diff --git a/main.go b/main.go index 77027448..74500b09 100644 --- a/main.go +++ b/main.go @@ -286,6 +286,7 @@ func initFlags() { flags.BoolVar(&config.MCPList, "mcp-list", false, stdoutStyles().FlagDesc.Render(help["mcp-list"])) flags.BoolVar(&config.MCPListTools, "mcp-list-tools", false, stdoutStyles().FlagDesc.Render(help["mcp-list-tools"])) flags.StringArrayVar(&config.MCPDisable, "mcp-disable", nil, stdoutStyles().FlagDesc.Render(help["mcp-disable"])) + flags.StringArrayVarP(&config.Images, "image", "i", nil, stdoutStyles().FlagDesc.Render(help["image"])) flags.Lookup("prompt").NoOptDefVal = "-1" flags.SortFlags = false diff --git a/stream.go b/stream.go index 5b92a319..8b107b2e 100644 --- a/stream.go +++ b/stream.go @@ -61,9 +61,19 @@ func (m *Mods) setupStreamContext(content string, mod Model) error { } } + // Process image files if provided + images, err := processImageFiles(cfg.Images) + if err != nil { + return modsError{ + err: err, + reason: "Could not process image files", + } + } + m.messages = append(m.messages, proto.Message{ Role: proto.RoleUser, Content: content, + Images: images, }) return nil From e7dd8720af3ce57c4aed0f8477e853dca1eaaa27 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A1t=C3=A9=20Gy=C3=B6ngy=C3=B6si?= Date: Mon, 25 Aug 2025 20:30:04 +0200 Subject: [PATCH 02/10] fix: reformat comments --- images.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/images.go b/images.go index 1654503d..aacabf02 100644 --- a/images.go +++ b/images.go @@ -11,7 +11,7 @@ import ( "github.com/charmbracelet/mods/internal/proto" ) -// supportedImageFormats maps file extensions to MIME types +// Map file extensions to MIME types. 
var supportedImageFormats = map[string]string{ ".jpg": "image/jpeg", ".jpeg": "image/jpeg", @@ -20,10 +20,10 @@ var supportedImageFormats = map[string]string{ ".webp": "image/webp", } -// maxImageSize sets the maximum allowed image file size (5MB) +// Set the maximum allowed image file size (5 MB). const maxImageSize = 5 * 1024 * 1024 -// processImageFiles reads and validates image files from paths +// Read and validate image files from paths. func processImageFiles(imagePaths []string) ([]proto.ImageContent, error) { if len(imagePaths) == 0 { return nil, nil @@ -45,7 +45,7 @@ func processImageFiles(imagePaths []string) ([]proto.ImageContent, error) { return images, nil } -// readImageFile reads and validates a single image file +// Read and validate a single image file. func readImageFile(path string) (*proto.ImageContent, error) { // Check if file exists if _, err := os.Stat(path); os.IsNotExist(err) { @@ -90,7 +90,7 @@ func readImageFile(path string) (*proto.ImageContent, error) { }, nil } -// getSupportedFormats returns a comma-separated list of supported formats +// Return a comma-separated list of supported formats. 
func getSupportedFormats() string { var formats []string for ext := range supportedImageFormats { From 5e71687dcc83d19844a6b52f9993d7695c2045c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A1t=C3=A9=20Gy=C3=B6ngy=C3=B6si?= Date: Mon, 25 Aug 2025 20:31:47 +0200 Subject: [PATCH 03/10] fix: remove unused functions --- images.go | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/images.go b/images.go index aacabf02..61de04b5 100644 --- a/images.go +++ b/images.go @@ -98,19 +98,3 @@ func getSupportedFormats() string { } return strings.Join(formats, ", ") } - -// createDataURI creates a data URI for base64 encoded image data -func createDataURI(image proto.ImageContent) string { - base64Data := base64.StdEncoding.EncodeToString(image.Data) - return fmt.Sprintf("data:%s;base64,%s", image.MimeType, base64Data) -} - -// hasImages checks if any message contains images -func hasImages(messages []proto.Message) bool { - for _, msg := range messages { - if len(msg.Images) > 0 { - return true - } - } - return false -} From 6a541f7891bde963550eec3d21a30633f6812637 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A1t=C3=A9=20Gy=C3=B6ngy=C3=B6si?= Date: Mon, 25 Aug 2025 20:43:41 +0200 Subject: [PATCH 04/10] fix: remove unused import --- images.go | 1 - 1 file changed, 1 deletion(-) diff --git a/images.go b/images.go index 61de04b5..82bf3b83 100644 --- a/images.go +++ b/images.go @@ -1,7 +1,6 @@ package main import ( - "encoding/base64" "fmt" "io" "os" From 208d16b82b2719e70ee9bd0a6c1cef4b5c1d9837 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A1t=C3=A9=20Gy=C3=B6ngy=C3=B6si?= Date: Mon, 25 Aug 2025 21:11:59 +0200 Subject: [PATCH 05/10] fix: catch file closure exception --- images.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/images.go b/images.go index 82bf3b83..f42a6767 100644 --- a/images.go +++ b/images.go @@ -75,7 +75,11 @@ func readImageFile(path string) (*proto.ImageContent, error) { if err != nil { return nil, fmt.Errorf("could not 
open file: %w", err) } - defer file.Close() + defer func() { + if cerr := file.Close(); cerr != nil { + fmt.Errorf("could not close file: %v", cerr) + } + }() data, err := io.ReadAll(file) if err != nil { From 3407f0bd5677fc46184a96ef4e8e6fe5cdab5545 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A1t=C3=A9=20Gy=C3=B6ngy=C3=B6si?= Date: Mon, 25 Aug 2025 21:13:01 +0200 Subject: [PATCH 06/10] fix: pre-allocate slices --- images.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/images.go b/images.go index f42a6767..a2022dc5 100644 --- a/images.go +++ b/images.go @@ -32,7 +32,7 @@ func processImageFiles(imagePaths []string) ([]proto.ImageContent, error) { return nil, fmt.Errorf("too many images: maximum 10 images allowed, got %d", len(imagePaths)) } - var images []proto.ImageContent + var images = make([]proto.ImageContent, 0, len(imagePaths)) for _, path := range imagePaths { image, err := readImageFile(path) if err != nil { @@ -95,7 +95,7 @@ func readImageFile(path string) (*proto.ImageContent, error) { // Return a comma-separated list of supported formats. 
func getSupportedFormats() string { - var formats []string + var formats = make([]string, 0, len(supportedImageFormats)) for ext := range supportedImageFormats { formats = append(formats, ext) } From 5157f51ecfb4cf560ab4b98917595bb31c7735de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A1t=C3=A9=20Gy=C3=B6ngy=C3=B6si?= Date: Mon, 25 Aug 2025 21:37:08 +0200 Subject: [PATCH 07/10] fix: change error printing function --- images.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/images.go b/images.go index a2022dc5..647b7347 100644 --- a/images.go +++ b/images.go @@ -77,7 +77,7 @@ func readImageFile(path string) (*proto.ImageContent, error) { } defer func() { if cerr := file.Close(); cerr != nil { - fmt.Errorf("could not close file: %v", cerr) + fmt.Fprintf(os.Stderr, "could not close file: %v\n", cerr) } }() From b66ac64498454bc9d6997fb8e57451d079f7fa2c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A1t=C3=A9=20Gy=C3=B6ngy=C3=B6si?= Date: Mon, 25 Aug 2025 21:38:45 +0200 Subject: [PATCH 08/10] fix: change variable declaration --- images.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/images.go b/images.go index 647b7347..7b31f1f5 100644 --- a/images.go +++ b/images.go @@ -32,7 +32,7 @@ func processImageFiles(imagePaths []string) ([]proto.ImageContent, error) { return nil, fmt.Errorf("too many images: maximum 10 images allowed, got %d", len(imagePaths)) } - var images = make([]proto.ImageContent, 0, len(imagePaths)) + images := make([]proto.ImageContent, 0, len(imagePaths)) for _, path := range imagePaths { image, err := readImageFile(path) if err != nil { @@ -95,7 +95,7 @@ func readImageFile(path string) (*proto.ImageContent, error) { // Return a comma-separated list of supported formats. 
func getSupportedFormats() string { - var formats = make([]string, 0, len(supportedImageFormats)) + formats := make([]string, 0, len(supportedImageFormats)) for ext := range supportedImageFormats { formats = append(formats, ext) } From b5c13377ad59e45554f0b0f47202d2c7cd21a1bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A1t=C3=A9=20Gy=C3=B6ngy=C3=B6si?= Date: Mon, 25 Aug 2025 21:48:34 +0200 Subject: [PATCH 09/10] fix: replace space indentation with tabs --- images.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/images.go b/images.go index 7b31f1f5..3618de2c 100644 --- a/images.go +++ b/images.go @@ -76,10 +76,10 @@ func readImageFile(path string) (*proto.ImageContent, error) { return nil, fmt.Errorf("could not open file: %w", err) } defer func() { - if cerr := file.Close(); cerr != nil { - fmt.Fprintf(os.Stderr, "could not close file: %v\n", cerr) - } - }() + if cerr := file.Close(); cerr != nil { + fmt.Fprintf(os.Stderr, "could not close file: %v\n", cerr) + } + }() data, err := io.ReadAll(file) if err != nil { From a1c16d344d50bd605ac6c3a169451414507f121a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A1t=C3=A9=20Gy=C3=B6ngy=C3=B6si?= Date: Mon, 25 Aug 2025 21:57:43 +0200 Subject: [PATCH 10/10] fix: align variable declarations --- internal/openai/openai.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/openai/openai.go b/internal/openai/openai.go index 06bd5003..0e7579c3 100644 --- a/internal/openai/openai.go +++ b/internal/openai/openai.go @@ -16,8 +16,8 @@ import ( ) var ( - _ stream.Client = &Client{} - apisWithJSONResponseFormat = map[string]bool{ + _ stream.Client = &Client{} + apisWithJSONResponseFormat = map[string]bool{ "openai": true, "copilot": true, }