diff --git a/config.go b/config.go index 9f0abfe2..9902ff9b 100644 --- a/config.go +++ b/config.go @@ -75,6 +75,7 @@ var help = map[string]string{ "mcp-list": "List all available MCP servers", "mcp-list-tools": "List all available tools from enabled MCP servers", "mcp-timeout": "Timeout for MCP server calls, defaults to 15 seconds", + "image": "Include image file(s) in the prompt (vision models only)", } // Model represents the LLM model used in the API call. @@ -185,6 +186,7 @@ type Config struct { Delete []string DeleteOlderThan time.Duration User string + Images []string MCPServers map[string]MCPServerConfig `yaml:"mcp-servers"` MCPList bool diff --git a/images.go b/images.go new file mode 100644 index 00000000..3618de2c --- /dev/null +++ b/images.go @@ -0,0 +1,103 @@
+package main
+
+import (
+	"fmt"
+	"io"
+	"os"
+	"path/filepath"
+	"strings"
+
+	"github.com/charmbracelet/mods/internal/proto"
+)
+
+// Map file extensions to MIME types.
+var supportedImageFormats = map[string]string{
+	".jpg":  "image/jpeg",
+	".jpeg": "image/jpeg",
+	".png":  "image/png",
+	".gif":  "image/gif",
+	".webp": "image/webp",
+}
+
+const (
+	// Set the maximum allowed image file size (5 MB).
+	maxImageSize = 5 * 1024 * 1024
+
+	// Set the maximum number of images accepted for a single prompt.
+	maxImageCount = 10
+)
+
+// Read and validate image files from paths.
+//
+// An empty imagePaths yields (nil, nil). Passing more than maxImageCount
+// paths is an error. Files are read in the order given; the first one that
+// fails validation aborts processing and its path is included in the
+// returned error, which wraps the underlying cause.
+func processImageFiles(imagePaths []string) ([]proto.ImageContent, error) {
+	if len(imagePaths) == 0 {
+		return nil, nil
+	}
+
+	if len(imagePaths) > maxImageCount {
+		return nil, fmt.Errorf("too many images: maximum %d images allowed, got %d",
+			maxImageCount, len(imagePaths))
+	}
+
+	images := make([]proto.ImageContent, 0, len(imagePaths))
+	for _, path := range imagePaths {
+		image, err := readImageFile(path)
+		if err != nil {
+			return nil, fmt.Errorf("error processing image %s: %w", path, err)
+		}
+		images = append(images, *image)
+	}
+
+	return images, nil
+}
+
+// Read and validate a single image file.
+// It rejects missing paths, directories, files larger than maxImageSize and
+// extensions absent from supportedImageFormats. On success the returned
+// ImageContent carries the raw file bytes, the MIME type derived from the
+// extension, and the base name of path.
+func readImageFile(path string) (*proto.ImageContent, error) {
+	// One Stat call serves both the existence check and the size check.
+	fileInfo, err := os.Stat(path)
+	if err != nil {
+		if os.IsNotExist(err) {
+			return nil, fmt.Errorf("image file does not exist: %s", path)
+		}
+		return nil, fmt.Errorf("could not get file info: %w", err)
+	}
+
+	// Without this check a directory would fall through to the extension
+	// lookup and be reported as an unsupported format.
+	if fileInfo.IsDir() {
+		return nil, fmt.Errorf("image path is a directory: %s", path)
+	}
+
+	if fileInfo.Size() > maxImageSize {
+		return nil, fmt.Errorf("image file too large: %s (%.2f MB > %d MB)",
+			path, float64(fileInfo.Size())/(1024*1024), maxImageSize/(1024*1024))
+	}
+
+	// Detect MIME type from extension
+	ext := strings.ToLower(filepath.Ext(path))
+	mimeType, supported := supportedImageFormats[ext]
+	if !supported {
+		return nil, fmt.Errorf("unsupported image format: %s (supported: %s)",
+			ext, getSupportedFormats())
+	}
+
+	// Read file data
+	file, err := os.Open(path)
+	if err != nil {
+		return nil, fmt.Errorf("could not open file: %w", err)
+	}
+	defer func() {
+		if cerr := file.Close(); cerr != nil {
+			fmt.Fprintf(os.Stderr, "could not close file: %v\n", cerr)
+		}
+	}()
+
+	// The Stat size can be stale by the time the file is read, so the read
+	// itself is capped. One extra byte is allowed so that an oversized file
+	// is detected instead of being silently truncated.
+	data, err := io.ReadAll(io.LimitReader(file, maxImageSize+1))
+	if err != nil {
+		return nil, fmt.Errorf("could not read file: %w", err)
+	}
+	if len(data) > maxImageSize {
+		return nil, fmt.Errorf("image file too large: %s (> %d MB)",
+			path, maxImageSize/(1024*1024))
+	}
+
+	return &proto.ImageContent{
+		Data:     data,
+		MimeType: mimeType,
+		Filename: filepath.Base(path),
+	}, nil
+}
+
+// Return a comma-separated list of supported formats.
+// The extensions are returned in ascending order so the list is identical on
+// every call; ranging over the map alone would yield a random order.
+func getSupportedFormats() string {
+	formats := make([]string, 0, len(supportedImageFormats))
+	for ext := range supportedImageFormats {
+		// Insert each key at its sorted position. The map holds only a
+		// handful of entries, and doing it inline avoids needing an import
+		// this file does not already have.
+		i := len(formats)
+		formats = append(formats, ext)
+		for ; i > 0 && formats[i-1] > ext; i-- {
+			formats[i] = formats[i-1]
+		}
+		formats[i] = ext
+	}
+	return strings.Join(formats, ", ")
+}
diff --git a/internal/anthropic/format.go b/internal/anthropic/format.go index cc806d44..113b6c81 100644 --- a/internal/anthropic/format.go +++ b/internal/anthropic/format.go @@ -28,6 +28,13 @@ func fromMCPTools(mcps map[string][]mcp.Tool) []anthropic.ToolUnionParam { } func fromProtoMessages(input []proto.Message) (system []anthropic.TextBlockParam, messages []anthropic.MessageParam) { + // Check for images and error if present (not supported yet) + for _, msg := range input { + if len(msg.Images) > 0 { + panic("image input is not supported for Anthropic API yet - use OpenAI API for vision capabilities") + } + } + for _, msg := range input { switch msg.Role { case proto.RoleSystem: diff --git a/internal/cohere/format.go b/internal/cohere/format.go index dd1a76b3..95d8e0c0 100644 --- a/internal/cohere/format.go +++ b/internal/cohere/format.go @@ -6,6 +6,13 @@ import ( ) func fromProtoMessages(input []proto.Message) (history []*cohere.Message, message string) { + // Check for images and error if present (not supported yet) + for _, msg := range input { + if len(msg.Images) > 0 { + panic("image input is not supported for Cohere API - use OpenAI API for vision capabilities") + } + } + var messages []*cohere.Message //nolint:prealloc for _, msg := range input { messages = append(messages, &cohere.Message{ diff --git a/internal/google/format.go b/internal/google/format.go index db3656d8..ab72c5ec 100644 --- a/internal/google/format.go +++ b/internal/google/format.go @@ -3,6 +3,13 @@ package google import "github.com/charmbracelet/mods/internal/proto" func fromProtoMessages(input []proto.Message) []Content { + // Check for images and error if present (not supported yet) + for _, msg := range input { + if len(msg.Images) > 0 { + panic("image input is not 
supported for Google API yet - use OpenAI API for vision capabilities") + } + } + result := make([]Content, 0, len(input)) for _, in := range input { switch in.Role { diff --git a/internal/ollama/format.go b/internal/ollama/format.go index 140848b4..dbd81ec0 100644 --- a/internal/ollama/format.go +++ b/internal/ollama/format.go @@ -42,6 +42,12 @@ func fromProtoMessage(input proto.Message) api.Message { Content: input.Content, Role: input.Role, } + + // Handle images + for _, img := range input.Images { + m.Images = append(m.Images, api.ImageData(img.Data)) + } + for _, call := range input.ToolCalls { var args api.ToolCallFunctionArguments _ = json.Unmarshal(call.Function.Arguments, &args) @@ -62,6 +68,16 @@ func toProtoMessage(in api.Message) proto.Message { Role: in.Role, Content: in.Content, } + + // Handle images + for _, imgData := range in.Images { + msg.Images = append(msg.Images, proto.ImageContent{ + Data: []byte(imgData), + MimeType: "image/jpeg", // Ollama doesn't provide MIME type, assume JPEG + Filename: "", // Ollama doesn't provide filename + }) + } + for _, call := range in.ToolCalls { msg.ToolCalls = append(msg.ToolCalls, proto.ToolCall{ ID: strconv.Itoa(call.Function.Index), diff --git a/internal/openai/format.go b/internal/openai/format.go index d56cfc33..42409125 100644 --- a/internal/openai/format.go +++ b/internal/openai/format.go @@ -1,6 +1,7 @@ package openai import ( + "encoding/base64" "fmt" "github.com/charmbracelet/mods/internal/proto" @@ -46,7 +47,29 @@ func fromProtoMessages(input []proto.Message) []openai.ChatCompletionMessagePara break } case proto.RoleUser: - messages = append(messages, openai.UserMessage(msg.Content)) + if len(msg.Images) > 0 { + // Create content array with text and images + var content []openai.ChatCompletionContentPartUnionParam + + // Add text content if present + if msg.Content != "" { + content = append(content, openai.TextContentPart(msg.Content)) + } + + // Add image content + for _, img := range msg.Images 
{ + base64Data := base64.StdEncoding.EncodeToString(img.Data) + dataURI := fmt.Sprintf("data:%s;base64,%s", img.MimeType, base64Data) + + content = append(content, openai.ImageContentPart(openai.ChatCompletionContentPartImageImageURLParam{ + URL: dataURI, + })) + } + + messages = append(messages, openai.UserMessage(content)) + } else { + messages = append(messages, openai.UserMessage(msg.Content)) + } case proto.RoleAssistant: m := openai.AssistantMessage(msg.Content) for _, tool := range msg.ToolCalls { diff --git a/internal/openai/openai.go b/internal/openai/openai.go index 47c498af..0e7579c3 100644 --- a/internal/openai/openai.go +++ b/internal/openai/openai.go @@ -15,7 +15,13 @@ import ( "github.com/openai/openai-go/shared" ) -var _ stream.Client = &Client{} +var ( + _ stream.Client = &Client{} + apisWithJSONResponseFormat = map[string]bool{ + "openai": true, + "copilot": true, + } +) // Client is the openai client. type Client struct { @@ -86,7 +92,7 @@ func (c *Client) Request(ctx context.Context, request proto.Request) stream.Stre if request.MaxTokens != nil { body.MaxTokens = openai.Int(*request.MaxTokens) } - if request.API == "openai" && request.ResponseFormat != nil && *request.ResponseFormat == "json" { + if apisWithJSONResponseFormat[request.API] && request.ResponseFormat != nil && *request.ResponseFormat == "json" { body.ResponseFormat = openai.ChatCompletionNewParamsResponseFormatUnion{ OfJSONObject: &shared.ResponseFormatJSONObjectParam{}, } diff --git a/internal/proto/proto.go b/internal/proto/proto.go index 93a90be6..cdeeb437 100644 --- a/internal/proto/proto.go +++ b/internal/proto/proto.go @@ -42,10 +42,18 @@ func (c ToolCallStatus) String() string { return sb.String() } +// ImageContent represents an image attachment in a message. +type ImageContent struct { + Data []byte // raw image bytes, not base64; the OpenAI formatter base64-encodes them itself + MimeType string // image/jpeg, image/png, etc. 
+ Filename string // original filename for reference +} + // Message is a message in the conversation. type Message struct { Role string Content string + Images []ImageContent ToolCalls []ToolCall } diff --git a/main.go b/main.go index 77027448..74500b09 100644 --- a/main.go +++ b/main.go @@ -286,6 +286,7 @@ func initFlags() { flags.BoolVar(&config.MCPList, "mcp-list", false, stdoutStyles().FlagDesc.Render(help["mcp-list"])) flags.BoolVar(&config.MCPListTools, "mcp-list-tools", false, stdoutStyles().FlagDesc.Render(help["mcp-list-tools"])) flags.StringArrayVar(&config.MCPDisable, "mcp-disable", nil, stdoutStyles().FlagDesc.Render(help["mcp-disable"])) + flags.StringArrayVarP(&config.Images, "image", "i", nil, stdoutStyles().FlagDesc.Render(help["image"])) flags.Lookup("prompt").NoOptDefVal = "-1" flags.SortFlags = false diff --git a/stream.go b/stream.go index 5b92a319..8b107b2e 100644 --- a/stream.go +++ b/stream.go @@ -61,9 +61,19 @@ func (m *Mods) setupStreamContext(content string, mod Model) error { } } + // Process image files if provided + images, err := processImageFiles(cfg.Images) + if err != nil { + return modsError{ + err: err, + reason: "Could not process image files", + } + } + m.messages = append(m.messages, proto.Message{ Role: proto.RoleUser, Content: content, + Images: images, }) return nil