diff --git a/chat.go b/chat.go index 0aa018715..ebfd9fdd2 100644 --- a/chat.go +++ b/chat.go @@ -1,6 +1,7 @@ package openai import ( + "bytes" "context" "encoding/json" "errors" @@ -86,12 +87,35 @@ type ChatMessagePartType string const ( ChatMessagePartTypeText ChatMessagePartType = "text" ChatMessagePartTypeImageURL ChatMessagePartType = "image_url" + ChatMessagePartTypeAudio ChatMessagePartType = "input_audio" + ChatMessagePartTypeVideo ChatMessagePartType = "video" + ChatMessagePartTypeVideoURL ChatMessagePartType = "video_url" ) +/* reference: + * https://bailian.console.aliyun.com/ + * ?spm=5176.29597918.J_SEsSjsNv72yRuRFS2VknO.2.191e7b08wdOQzD&tab=api#/api/?type=model&url=2712576 + * https://help.aliyun.com/zh/model-studio/qwen-omni#423736d367a7x + */ +type InputAudio struct { + Data string `json:"data"` + Format string `json:"format"` +} + +type CacheControl struct { + Type string `json:"type"` // must be "ephemeral" +} + type ChatMessagePart struct { - Type ChatMessagePartType `json:"type,omitempty"` - Text string `json:"text,omitempty"` - ImageURL *ChatMessageImageURL `json:"image_url,omitempty"` + Type ChatMessagePartType `json:"type,omitempty"` + Text string `json:"text,omitempty"` + ImageURL *ChatMessageImageURL `json:"image_url,omitempty"` + Audio *InputAudio `json:"input_audio,omitempty"` // required when Type is "input_audio" + VideoURL *ChatMessageImageURL `json:"video_url,omitempty"` // required when Type is "video_url" + Video []string `json:"video,omitempty"` // required when Type is "video", array of image URLs + MinPixels int `json:"min_pixels,omitempty"` + MaxPixels int `json:"max_pixels,omitempty"` + *CacheControl `json:"cache_control,omitempty"` } type ChatCompletionMessage struct { @@ -333,6 +357,34 @@ type ChatCompletionRequest struct { SafetyIdentifier string `json:"safety_identifier,omitempty"` // Embedded struct for non-OpenAI extensions ChatCompletionRequestExtensions + // non-OpenAI extensions + Extensions map[string]interface{} 
`json:"-"` +} + +type customChatCompletionRequest ChatCompletionRequest + +const TrailingLen = 2 // length of "}\n" +func (r *ChatCompletionRequest) MarshalJSON() ([]byte, error) { + if len(r.Extensions) == 0 { + return json.Marshal((*customChatCompletionRequest)(r)) + } + buf := bytes.NewBuffer(nil) + encoder := json.NewEncoder(buf) + if err := encoder.Encode((*customChatCompletionRequest)(r)); err != nil { + return nil, err + } + // remove the trailing "}\n" + buf.Truncate(buf.Len() - TrailingLen) + // record the current position + pos := buf.Len() + // append extensions + if err := encoder.Encode(r.Extensions); err != nil { + return nil, err + } + data := buf.Bytes() + // change the leading '{' of extensions to ',' + data[pos] = ',' + return data, nil } type StreamOptions struct { diff --git a/chat_stream.go b/chat_stream.go index 80d16cc63..fd3bb6b95 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -5,6 +5,12 @@ import ( "net/http" ) +type OutputAudio struct { + Transcript string `json:"transcript"` // streamed text content + Data string `json:"data"` // base64-encoded audio data + ExpiresAt int `json:"expires_at"` // the Unix timestamp at which the audio response expires +} + type ChatCompletionStreamChoiceDelta struct { Content string `json:"content,omitempty"` Role string `json:"role,omitempty"` @@ -17,6 +23,8 @@ type ChatCompletionStreamChoiceDelta struct { // the doc from deepseek: // - https://api-docs.deepseek.com/api/create-chat-completion#responses ReasoningContent string `json:"reasoning_content,omitempty"` + // Audio is used for audio responses, if supported by the model, such as "qwen-omni". 
+ Audio *OutputAudio `json:"audio,omitempty"` } type ChatCompletionStreamChoiceLogprobs struct { @@ -95,7 +103,7 @@ func (c *Client) CreateChatCompletionStream( ctx, http.MethodPost, c.fullURL(urlSuffix, withModel(request.Model)), - withBody(request), + withBody(&request), ) if err != nil { return nil, err diff --git a/chat_stream_test.go b/chat_stream_test.go index eabb0f3a2..e39bb6077 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -1021,3 +1021,183 @@ func compareChatStreamResponseChoices(c1, c2 openai.ChatCompletionStreamChoice) } return true } + +func TestOutputAudio(t *testing.T) { + audio := openai.OutputAudio{ + Transcript: "Hello, world!", + Data: "base64encodedaudiodata", + ExpiresAt: 1234567890, + } + + data, err := json.Marshal(audio) + if err != nil { + t.Errorf("Failed to marshal OutputAudio: %v", err) + return + } + + var result openai.OutputAudio + if err = json.Unmarshal(data, &result); err != nil { + t.Errorf("Failed to unmarshal OutputAudio: %v", err) + return + } + + if result.Transcript != audio.Transcript { + t.Errorf("Expected transcript %s, got %s", audio.Transcript, result.Transcript) + } + if result.Data != audio.Data { + t.Errorf("Expected data %s, got %s", audio.Data, result.Data) + } + if result.ExpiresAt != audio.ExpiresAt { + t.Errorf("Expected expires_at %d, got %d", audio.ExpiresAt, result.ExpiresAt) + } +} + +// verifyAudioContent checks if the audio content matches between expected and actual +func verifyAudioContent(t *testing.T, expected, actual *openai.OutputAudio) { + if actual.Transcript != expected.Transcript { + t.Errorf("Expected audio transcript %s, got %s", expected.Transcript, actual.Transcript) + } + if actual.Data != expected.Data { + t.Errorf("Expected audio data %s, got %s", expected.Data, actual.Data) + } + if actual.ExpiresAt != expected.ExpiresAt { + t.Errorf("Expected audio expires_at %d, got %d", expected.ExpiresAt, actual.ExpiresAt) + } +} + +// verifyAudioInDelta verifies the audio field in 
ChatCompletionStreamChoiceDelta +func verifyAudioInDelta(t *testing.T, expected, actual openai.ChatCompletionStreamChoiceDelta) { + if expected.Audio != nil { + if actual.Audio == nil { + t.Error("Expected audio to be present, but it's nil") + return + } + verifyAudioContent(t, expected.Audio, actual.Audio) + } else if actual.Audio != nil { + t.Error("Expected audio to be nil, but it's present") + } +} + +// testDeltaSerialization tests JSON marshaling and unmarshaling of a delta +func testDeltaSerialization(t *testing.T, delta openai.ChatCompletionStreamChoiceDelta) openai.ChatCompletionStreamChoiceDelta { + // Test JSON marshaling + data, err := json.Marshal(delta) + if err != nil { + t.Errorf("Failed to marshal ChatCompletionStreamChoiceDelta: %v", err) + return openai.ChatCompletionStreamChoiceDelta{} + } + + // Test JSON unmarshaling + var result openai.ChatCompletionStreamChoiceDelta + if err = json.Unmarshal(data, &result); err != nil { + t.Errorf("Failed to unmarshal ChatCompletionStreamChoiceDelta: %v", err) + return openai.ChatCompletionStreamChoiceDelta{} + } + + return result +} + +func TestChatCompletionStreamChoiceDelta_Audio(t *testing.T) { + tests := []struct { + name string + delta openai.ChatCompletionStreamChoiceDelta + }{ + { + name: "with audio", + delta: openai.ChatCompletionStreamChoiceDelta{ + Content: "Hello", + Audio: &openai.OutputAudio{ + Transcript: "Hello, world!", + Data: "base64encodedaudiodata", + ExpiresAt: 1234567890, + }, + }, + }, + { + name: "without audio", + delta: openai.ChatCompletionStreamChoiceDelta{ + Content: "Hello", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := testDeltaSerialization(t, tt.delta) + + // Verify the content is preserved + if result.Content != tt.delta.Content { + t.Errorf("Expected content %s, got %s", tt.delta.Content, result.Content) + } + + // Verify audio is preserved when present + verifyAudioInDelta(t, tt.delta, result) + }) + } +} + +func 
TestCreateChatCompletionStreamWithAudio(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + // Send test responses with audio + dataBytes := []byte{} + dataBytes = append(dataBytes, []byte("event: message\n")...) + data := `{"id":"1","object":"chat.completion.chunk","created":1729585728,"model":"qwen-omni","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + dataBytes = append(dataBytes, []byte("event: message\n")...) + data = `{"id":"2","object":"chat.completion.chunk","created":1729585728,"model":"qwen-omni","choices":[{"index":0,"delta":{"audio":{"transcript":"Hello, world!","data":"base64encodedaudiodata","expires_at":1234567890}},"finish_reason":null}]}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) + _, _ = w.Write(dataBytes) + }) + + ctx := context.Background() + req := openai.ChatCompletionRequest{ + Model: "qwen-omni", + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + } + + stream, err := client.CreateChatCompletionStream(ctx, req) + if err != nil { + t.Fatalf("CreateChatCompletionStream error: %v", err) + } + defer stream.Close() + + hasAudio := false + for { + var resp openai.ChatCompletionStreamResponse + resp, err = stream.Recv() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + t.Fatalf("Stream error: %v", err) + } + + if len(resp.Choices) > 0 && resp.Choices[0].Delta.Audio != nil { + hasAudio = true + if resp.Choices[0].Delta.Audio.Transcript != "Hello, world!" 
{ + t.Errorf("Expected transcript 'Hello, world!', got %s", resp.Choices[0].Delta.Audio.Transcript) + } + if resp.Choices[0].Delta.Audio.Data != "base64encodedaudiodata" { + t.Errorf("Expected audio data 'base64encodedaudiodata', got %s", resp.Choices[0].Delta.Audio.Data) + } + } + } + + if !hasAudio { + t.Error("Expected to receive audio in stream response") + } +} diff --git a/chat_test.go b/chat_test.go index 236cff736..90219c21e 100644 --- a/chat_test.go +++ b/chat_test.go @@ -1205,3 +1205,295 @@ func TestChatCompletionRequest_UnmarshalJSON(t *testing.T) { }) } } + +func TestChatCompletionRequest_MarshalJSON(t *testing.T) { + tests := []struct { + name string + request openai.ChatCompletionRequest + want string + wantErr bool + }{ + { + name: "without extensions", + request: openai.ChatCompletionRequest{ + Model: "gpt-3.5-turbo", + Messages: []openai.ChatCompletionMessage{ + {Role: openai.ChatMessageRoleUser, Content: "Hello"}, + }, + }, + wantErr: false, + }, + { + name: "with extensions", + request: openai.ChatCompletionRequest{ + Model: "gpt-3.5-turbo", + Messages: []openai.ChatCompletionMessage{ + {Role: openai.ChatMessageRoleUser, Content: "Hello"}, + }, + Extensions: map[string]interface{}{ + "custom_field": "custom_value", + "number": 42, + }, + }, + wantErr: false, + }, + { + name: "with empty extensions", + request: openai.ChatCompletionRequest{ + Model: "gpt-3.5-turbo", + Messages: []openai.ChatCompletionMessage{ + {Role: openai.ChatMessageRoleUser, Content: "Hello"}, + }, + Extensions: map[string]interface{}{}, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := json.Marshal(&tt.request) + if (err != nil) != tt.wantErr { + t.Errorf("MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr { + return + } + + var result map[string]interface{} + if unmarshalErr := json.Unmarshal(data, &result); unmarshalErr != nil { + t.Errorf("Failed to unmarshal result: %v", 
unmarshalErr) + return + } + + validateChatCompletionRequestResult(t, tt.request, result) + }) + } +} + +func validateChatCompletionRequestResult(t *testing.T, request openai.ChatCompletionRequest, result map[string]interface{}) { + // Check that model is present + if result["model"] != request.Model { + t.Errorf("Expected model %s, got %v", request.Model, result["model"]) + } + + // Check extensions are merged properly when present + if len(request.Extensions) > 0 { + validateExtensions(t, request.Extensions, result) + } +} + +func validateExtensions(t *testing.T, extensions map[string]interface{}, result map[string]interface{}) { + for key, value := range extensions { + // Convert both to string for comparison to handle type differences + resultStr := fmt.Sprintf("%v", result[key]) + valueStr := fmt.Sprintf("%v", value) + if resultStr != valueStr { + t.Errorf("Expected extension %s = %v (%s), got %v (%s)", key, value, valueStr, result[key], resultStr) + } + } +} + +func TestChatMessagePart_NewFields(t *testing.T) { + tests := []struct { + name string + part openai.ChatMessagePart + }{ + { + name: "with audio part", + part: openai.ChatMessagePart{ + Type: openai.ChatMessagePartTypeAudio, + Audio: &openai.InputAudio{ + Data: "base64encodedaudiodata", + Format: "wav"}, + }, + }, + { + name: "with video URL part", + part: openai.ChatMessagePart{ + Type: openai.ChatMessagePartTypeVideoURL, + VideoURL: &openai.ChatMessageImageURL{ + URL: "https://example.com/video.mp4", + }, + }, + }, + { + name: "with video part", + part: openai.ChatMessagePart{ + Type: openai.ChatMessagePartTypeVideo, + Video: []string{"https://example.com/frame1.jpg", "https://example.com/frame2.jpg"}, + MinPixels: 100, + MaxPixels: 1000, + }, + }, + { + name: "with cache control", + part: openai.ChatMessagePart{ + Type: openai.ChatMessagePartTypeText, + Text: "Hello", + CacheControl: &openai.CacheControl{ + Type: "ephemeral", + }, + }, + }, + } + + for _, 
tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test JSON marshaling + data, err := json.Marshal(tt.part) + if err != nil { + t.Errorf("Failed to marshal ChatMessagePart: %v", err) + return + } + + // Test JSON unmarshaling + var result openai.ChatMessagePart + if unmarshalErr := json.Unmarshal(data, &result); unmarshalErr != nil { + t.Errorf("Failed to unmarshal ChatMessagePart: %v", unmarshalErr) + return + } + + // Verify the type is preserved + if result.Type != tt.part.Type { + t.Errorf("Expected type %s, got %s", tt.part.Type, result.Type) + } + }) + } +} + +func TestInputAudio(t *testing.T) { + audio := openai.InputAudio{ + Data: "base64encodedaudiodata", + Format: "wav", + } + + data, err := json.Marshal(audio) + if err != nil { + t.Errorf("Failed to marshal InputAudio: %v", err) + return + } + + var result openai.InputAudio + if unmarshalErr := json.Unmarshal(data, &result); unmarshalErr != nil { + t.Errorf("Failed to unmarshal InputAudio: %v", unmarshalErr) + return + } + + if result.Data != audio.Data { + t.Errorf("Expected data %s, got %s", audio.Data, result.Data) + } + if result.Format != audio.Format { + t.Errorf("Expected format %s, got %s", audio.Format, result.Format) + } +} + +func TestCacheControl(t *testing.T) { + cacheControl := openai.CacheControl{ + Type: "ephemeral", + } + + data, err := json.Marshal(cacheControl) + if err != nil { + t.Errorf("Failed to marshal CacheControl: %v", err) + return + } + + var result openai.CacheControl + if unmarshalErr := json.Unmarshal(data, &result); unmarshalErr != nil { + t.Errorf("Failed to unmarshal CacheControl: %v", unmarshalErr) + return + } + + if result.Type != cacheControl.Type { + t.Errorf("Expected type %s, got %s", cacheControl.Type, result.Type) + } +} + +func TestChatCompletionRequest_MarshalJSON_EdgeCases(t *testing.T) { + // Test with complex extension values that might cause encoding issues + tests := []struct { + name string + request openai.ChatCompletionRequest + expectErr 
bool + }{ + { + name: "with nil value in extensions", + request: openai.ChatCompletionRequest{ + Model: "gpt-3.5-turbo", + Messages: []openai.ChatCompletionMessage{ + {Role: openai.ChatMessageRoleUser, Content: "Hello"}, + }, + Extensions: map[string]interface{}{ + "nil_value": nil, + }, + }, + expectErr: false, + }, + { + name: "with complex nested structure in extensions", + request: openai.ChatCompletionRequest{ + Model: "gpt-3.5-turbo", + Messages: []openai.ChatCompletionMessage{ + {Role: openai.ChatMessageRoleUser, Content: "Hello"}, + }, + Extensions: map[string]interface{}{ + "nested": map[string]interface{}{ + "inner": map[string]interface{}{ + "value": "deep", + }, + }, + }, + }, + expectErr: false, + }, + { + name: "with function in extensions (should cause error)", + request: openai.ChatCompletionRequest{ + Model: "gpt-3.5-turbo", + Messages: []openai.ChatCompletionMessage{ + {Role: openai.ChatMessageRoleUser, Content: "Hello"}, + }, + Extensions: map[string]interface{}{ + "function": func() {}, // functions cannot be JSON encoded + }, + }, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := json.Marshal(&tt.request) + if (err != nil) != tt.expectErr { + t.Errorf("MarshalJSON() error = %v, expectErr %v", err, tt.expectErr) + } + }) + } +} + +func TestChatCompletionRequest_MarshalJSON_InvalidRequest(t *testing.T) { + // Test with invalid request data that might cause encoding errors + // Create a request with invalid channel in messages (this should cause an encoding error) + invalidMsg := make(chan int) + request := openai.ChatCompletionRequest{ + Model: "gpt-3.5-turbo", + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello", + // We can't directly add a channel to the message struct, + // so we'll test with extensions instead + }, + }, + Extensions: map[string]interface{}{ + "invalid_channel": invalidMsg, // channels cannot be JSON encoded + }, + } + 
+ _, err := json.Marshal(&request) + if err == nil { + t.Error("Expected marshal to fail with invalid channel in extensions") + } +}