From 33695525c0c3663f14e0459b1a965151b20e53bf Mon Sep 17 00:00:00 2001 From: zhangjk Date: Mon, 13 Oct 2025 16:33:10 +0800 Subject: [PATCH 1/9] Add support for audio and video input --- chat.go | 56 +++++++++++++++++++++++++++++++++++++++++++++++--- chat_stream.go | 12 +++++++++-- 2 files changed, 63 insertions(+), 5 deletions(-) diff --git a/chat.go b/chat.go index 0aa018715..c49a7d961 100644 --- a/chat.go +++ b/chat.go @@ -1,6 +1,7 @@ package openai import ( + "bytes" "context" "encoding/json" "errors" @@ -86,12 +87,34 @@ type ChatMessagePartType string const ( ChatMessagePartTypeText ChatMessagePartType = "text" ChatMessagePartTypeImageURL ChatMessagePartType = "image_url" + ChatMessagePartTypeAudio ChatMessagePartType = "input_audio" + ChatMessagePartTypeVideo ChatMessagePartType = "video" + ChatMessagePartTypeVideoURL ChatMessagePartType = "video_url" ) +/* reference: + * https://bailian.console.aliyun.com/?spm=5176.29597918.J_SEsSjsNv72yRuRFS2VknO.2.191e7b08wdOQzD&tab=api#/api/?type=model&url=2712576 + * https://help.aliyun.com/zh/model-studio/qwen-omni#423736d367a7x + */ +type InputAudio struct { + Data string `json:"data"` + Format string `json:"format"` +} + +type CacheControl struct { + Type string `json:"type"` // must be "ephemeral" +} + type ChatMessagePart struct { - Type ChatMessagePartType `json:"type,omitempty"` - Text string `json:"text,omitempty"` - ImageURL *ChatMessageImageURL `json:"image_url,omitempty"` + Type ChatMessagePartType `json:"type,omitempty"` + Text string `json:"text,omitempty"` + ImageURL *ChatMessageImageURL `json:"image_url,omitempty"` + Audio *InputAudio `json:"input_audio,omitempty"` // required when Type is "input_audio" + VideoURL *ChatMessageImageURL `json:"video_url,omitempty"` // required when Type is "video_url" + Video []string `json:"video,omitempty"` // required when Type is "video", array of image URLs + MinPixels int `json:"min_pixels,omitempty"` + MaxPixels int `json:"max_pixels,omitempty"` + *CacheControl `json:"cache_control,omitempty"` } type ChatCompletionMessage struct { @@ -333,6 +356,33 @@ type ChatCompletionRequest struct { SafetyIdentifier string `json:"safety_identifier,omitempty"` // Embedded struct for non-OpenAI extensions ChatCompletionRequestExtensions + // non-OpenAI extensions + Extensions map[string]interface{} `json:"-"` +} + +type customChatCompletionRequest ChatCompletionRequest + +func (r *ChatCompletionRequest) MarshalJSON() ([]byte, error) { + if len(r.Extensions) == 0 { + return json.Marshal((*customChatCompletionRequest)(r)) + } + buf := bytes.NewBuffer(nil) + encoder := json.NewEncoder(buf) + if err := encoder.Encode((*customChatCompletionRequest)(r)); err != nil { + return nil, err + } + // remove the trailing "}\n" + buf.Truncate(buf.Len() - 2) + // record the current position + pos := buf.Len() + // append extensions + if err := encoder.Encode(r.Extensions); err != nil { + return nil, err + } + data := buf.Bytes() + // change the leading '{' of extensions to ',' + data[pos] = ',' + return data, nil } type StreamOptions struct { diff --git a/chat_stream.go b/chat_stream.go index 80d16cc63..a224b13e6 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -5,6 +5,13 @@ import ( "net/http" ) +// reference: https://bailian.console.aliyun.com/?spm=5176.29597918.J_SEsSjsNv72yRuRFS2VknO.2.191e7b08wdOQzD&tab=api#/api/?type=model&url=2712576 +type OutputAudio struct { + Transcript string `json:"transcript"` // streamed text content + Data string `json:"data"` // base64-encoded audio data + ExpiresAt int `json:"expires_at"` // the timestamp when the request was created +} + type ChatCompletionStreamChoiceDelta struct { Content string `json:"content,omitempty"` Role string `json:"role,omitempty"` @@ -16,7 +23,8 @@ type ChatCompletionStreamChoiceDelta struct { // which is not in the official documentation. // the doc from deepseek: // - https://api-docs.deepseek.com/api/create-chat-completion#responses - ReasoningContent string `json:"reasoning_content,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + Audio *OutputAudio `json:"audio,omitempty"` // Audio is used for audio responses, if supported by the model, such as "qwen-omni". } type ChatCompletionStreamChoiceLogprobs struct { @@ -95,7 +103,7 @@ func (c *Client) CreateChatCompletionStream( ctx, http.MethodPost, c.fullURL(urlSuffix, withModel(request.Model)), - withBody(request), + withBody(&request), ) if err != nil { return nil, err From cd28a00070ebad275a2e240fa254d85999dcada1 Mon Sep 17 00:00:00 2001 From: zhangjk Date: Wed, 15 Oct 2025 11:04:27 +0800 Subject: [PATCH 2/9] fix PR error --- chat.go | 6 ++++-- chat_stream.go | 9 ++++++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/chat.go b/chat.go index c49a7d961..ebfd9fdd2 100644 --- a/chat.go +++ b/chat.go @@ -93,7 +93,8 @@ const ( ) /* reference: - * https://bailian.console.aliyun.com/?spm=5176.29597918.J_SEsSjsNv72yRuRFS2VknO.2.191e7b08wdOQzD&tab=api#/api/?type=model&url=2712576 + * https://bailian.console.aliyun.com/ + * ?spm=5176.29597918.J_SEsSjsNv72yRuRFS2VknO.2.191e7b08wdOQzD&tab=api#/api/?type=model&url=2712576 * https://help.aliyun.com/zh/model-studio/qwen-omni#423736d367a7x */ type InputAudio struct { @@ -362,6 +363,7 @@ type ChatCompletionRequest struct { type customChatCompletionRequest ChatCompletionRequest +const TrailingLen = 2 // length of "}\n" func (r *ChatCompletionRequest) MarshalJSON() ([]byte, error) { if len(r.Extensions) == 0 { return json.Marshal((*customChatCompletionRequest)(r)) @@ -372,7 +374,7 @@ func (r *ChatCompletionRequest) MarshalJSON() ([]byte, error) { return nil, err } // remove the trailing "}\n" - buf.Truncate(buf.Len() - 2) + buf.Truncate(buf.Len() - TrailingLen) // record the current position pos := buf.Len() // append extensions diff --git a/chat_stream.go b/chat_stream.go index a224b13e6..fde874af2 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -5,7 +5,9 @@ import ( "net/http" ) -// reference: https://bailian.console.aliyun.com/?spm=5176.29597918.J_SEsSjsNv72yRuRFS2VknO.2.191e7b08wdOQzD&tab=api#/api/?type=model&url=2712576 +/* reference: https://bailian.console.aliyun.com/ + * ?spm=5176.29597918.J_SEsSjsNv72yRuRFS2VknO.2.191e7b08wdOQzD&tab=api#/api/?type=model&url=2712576 + */ type OutputAudio struct { Transcript string `json:"transcript"` // streamed text content Data string `json:"data"` // base64-encoded audio data @@ -23,8 +25,9 @@ type ChatCompletionStreamChoiceDelta struct { // which is not in the official documentation. // the doc from deepseek: // - https://api-docs.deepseek.com/api/create-chat-completion#responses - ReasoningContent string `json:"reasoning_content,omitempty"` - Audio *OutputAudio `json:"audio,omitempty"` // Audio is used for audio responses, if supported by the model, such as "qwen-omni". + ReasoningContent string `json:"reasoning_content,omitempty"` + // Audio is used for audio responses, if supported by the model, such as "qwen-omni". + Audio *OutputAudio `json:"audio,omitempty"` } type ChatCompletionStreamChoiceLogprobs struct { From c58a2bdb8bb5a692f4b455b28b09027b657db7dc Mon Sep 17 00:00:00 2001 From: zhangjk Date: Wed, 15 Oct 2025 11:17:14 +0800 Subject: [PATCH 3/9] fix PR error again --- chat_stream.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/chat_stream.go b/chat_stream.go index fde874af2..fd3bb6b95 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -5,9 +5,6 @@ import ( "net/http" ) -/* reference: https://bailian.console.aliyun.com/ - * ?spm=5176.29597918.J_SEsSjsNv72yRuRFS2VknO.2.191e7b08wdOQzD&tab=api#/api/?type=model&url=2712576 - */ type OutputAudio struct { Transcript string `json:"transcript"` // streamed text content Data string `json:"data"` // base64-encoded audio data From e180158886fa1f61413114aa1930c0b15d0b39ce Mon Sep 17 00:00:00 2001 From: zhangjk Date: Wed, 15 Oct 2025 13:40:06 +0800 Subject: [PATCH 4/9] add unit test cases for new code --- chat_stream_test.go | 162 ++++++++++++++++++++++++++++++++++++ chat_test.go | 195 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 357 insertions(+) diff --git a/chat_stream_test.go b/chat_stream_test.go index eabb0f3a2..571ee420b 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -1021,3 +1021,165 @@ func compareChatStreamResponseChoices(c1, c2 openai.ChatCompletionStreamChoice) } return true } + +func TestOutputAudio(t *testing.T) { + audio := openai.OutputAudio{ + Transcript: "Hello, world!", + Data: "base64encodedaudiodata", + ExpiresAt: 1234567890, + } + + data, err := json.Marshal(audio) + if err != nil { + t.Errorf("Failed to marshal OutputAudio: %v", err) + return + } + + var result openai.OutputAudio + if err := json.Unmarshal(data, &result); err != nil { + t.Errorf("Failed to unmarshal OutputAudio: %v", err) + return + } + + if result.Transcript != audio.Transcript { + t.Errorf("Expected transcript %s, got %s", audio.Transcript, result.Transcript) + } + if result.Data != audio.Data { + t.Errorf("Expected data %s, got %s", audio.Data, result.Data) + } + if result.ExpiresAt != audio.ExpiresAt { + t.Errorf("Expected expires_at %d, got %d", audio.ExpiresAt, result.ExpiresAt) + } +} + +func TestChatCompletionStreamChoiceDelta_Audio(t *testing.T) { + tests := []struct { + name string + delta openai.ChatCompletionStreamChoiceDelta + }{ + { + name: "with audio", + delta: openai.ChatCompletionStreamChoiceDelta{ + Content: "Hello", + Audio: &openai.OutputAudio{ + Transcript: "Hello, world!", + Data: "base64encodedaudiodata", + ExpiresAt: 1234567890, + }, + }, + }, + { + name: "without audio", + delta: openai.ChatCompletionStreamChoiceDelta{ + Content: "Hello", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test JSON marshaling + data, err := json.Marshal(tt.delta) + if err != nil { + t.Errorf("Failed to marshal ChatCompletionStreamChoiceDelta: %v", err) + return + } + + // Test JSON unmarshaling + var result openai.ChatCompletionStreamChoiceDelta + if err := json.Unmarshal(data, &result); err != nil { + t.Errorf("Failed to unmarshal ChatCompletionStreamChoiceDelta: %v", err) + return + } + + // Verify the content is preserved + if result.Content != tt.delta.Content { + t.Errorf("Expected content %s, got %s", tt.delta.Content, result.Content) + } + + // Verify audio is preserved when present + if tt.delta.Audio != nil { + if result.Audio == nil { + t.Error("Expected audio to be present, but it's nil") + return + } + if result.Audio.Transcript != tt.delta.Audio.Transcript { + t.Errorf("Expected audio transcript %s, got %s", tt.delta.Audio.Transcript, result.Audio.Transcript) + } + if result.Audio.Data != tt.delta.Audio.Data { + t.Errorf("Expected audio data %s, got %s", tt.delta.Audio.Data, result.Audio.Data) + } + if result.Audio.ExpiresAt != tt.delta.Audio.ExpiresAt { + t.Errorf("Expected audio expires_at %d, got %d", tt.delta.Audio.ExpiresAt, result.Audio.ExpiresAt) + } + } else if result.Audio != nil { + t.Error("Expected audio to be nil, but it's present") + } + }) + } +} + +func TestCreateChatCompletionStreamWithAudio(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + // Send test responses with audio + dataBytes := []byte{} + dataBytes = append(dataBytes, []byte("event: message\n")...) + data := `{"id":"1","object":"chat.completion.chunk","created":1729585728,"model":"qwen-omni","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + dataBytes = append(dataBytes, []byte("event: message\n")...) + data = `{"id":"2","object":"chat.completion.chunk","created":1729585728,"model":"qwen-omni","choices":[{"index":0,"delta":{"audio":{"transcript":"Hello, world!","data":"base64encodedaudiodata","expires_at":1234567890}},"finish_reason":null}]}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) + _, _ = w.Write(dataBytes) + }) + + ctx := context.Background() + req := openai.ChatCompletionRequest{ + Model: "qwen-omni", + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + } + + stream, err := client.CreateChatCompletionStream(ctx, req) + if err != nil { + t.Fatalf("CreateChatCompletionStream error: %v", err) + } + defer stream.Close() + + hasAudio := false + for { + resp, err := stream.Recv() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + t.Fatalf("Stream error: %v", err) + } + + if len(resp.Choices) > 0 && resp.Choices[0].Delta.Audio != nil { + hasAudio = true + if resp.Choices[0].Delta.Audio.Transcript != "Hello, world!" { + t.Errorf("Expected transcript 'Hello, world!', got %s", resp.Choices[0].Delta.Audio.Transcript) + } + if resp.Choices[0].Delta.Audio.Data != "base64encodedaudiodata" { + t.Errorf("Expected audio data 'base64encodedaudiodata', got %s", resp.Choices[0].Delta.Audio.Data) + } + } + } + + if !hasAudio { + t.Error("Expected to receive audio in stream response") + } +} diff --git a/chat_test.go b/chat_test.go index 236cff736..8de7c9646 100644 --- a/chat_test.go +++ b/chat_test.go @@ -1205,3 +1205,198 @@ func TestChatCompletionRequest_UnmarshalJSON(t *testing.T) { }) } } + +func TestChatCompletionRequest_MarshalJSON(t *testing.T) { + tests := []struct { + name string + request openai.ChatCompletionRequest + want string + wantErr bool + }{ + { + name: "without extensions", + request: openai.ChatCompletionRequest{ + Model: "gpt-3.5-turbo", + Messages: []openai.ChatCompletionMessage{ + {Role: openai.ChatMessageRoleUser, Content: "Hello"}, + }, + }, + wantErr: false, + }, + { + name: "with extensions", + request: openai.ChatCompletionRequest{ + Model: "gpt-3.5-turbo", + Messages: []openai.ChatCompletionMessage{ + {Role: openai.ChatMessageRoleUser, Content: "Hello"}, + }, + Extensions: map[string]interface{}{ + "custom_field": "custom_value", + "number": 42, + }, + }, + wantErr: false, + }, + { + name: "with empty extensions", + request: openai.ChatCompletionRequest{ + Model: "gpt-3.5-turbo", + Messages: []openai.ChatCompletionMessage{ + {Role: openai.ChatMessageRoleUser, Content: "Hello"}, + }, + Extensions: map[string]interface{}{}, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := json.Marshal(&tt.request) + if (err != nil) != tt.wantErr { + t.Errorf("MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr { + var result map[string]interface{} + if err := json.Unmarshal(data, &result); err != nil { + t.Errorf("Failed to unmarshal result: %v", err) + return + } + + // Check that model is present + if result["model"] != tt.request.Model { + t.Errorf("Expected model %s, got %v", tt.request.Model, result["model"]) + } + + // Check extensions are merged properly when present + if len(tt.request.Extensions) > 0 { + for key, value := range tt.request.Extensions { + // Convert both to string for comparison to handle type differences + resultStr := fmt.Sprintf("%v", result[key]) + valueStr := fmt.Sprintf("%v", value) + if resultStr != valueStr { + t.Errorf("Expected extension %s = %v (%s), got %v (%s)", key, value, valueStr, result[key], resultStr) + } + } + } + } + }) + } +} + +func TestChatMessagePart_NewFields(t *testing.T) { + tests := []struct { + name string + part openai.ChatMessagePart + }{ + { + name: "with audio part", + part: openai.ChatMessagePart{ + Type: openai.ChatMessagePartTypeAudio, + Video: []string{"https://example.com/frame1.jpg", "https://example.com/frame2.jpg"}, + MinPixels: 100, + MaxPixels: 1000, + }, + }, + { + name: "with video URL part", + part: openai.ChatMessagePart{ + Type: openai.ChatMessagePartTypeVideoURL, + VideoURL: &openai.ChatMessageImageURL{ + URL: "https://example.com/video.mp4", + }, + }, + }, + { + name: "with video part", + part: openai.ChatMessagePart{ + Type: openai.ChatMessagePartTypeVideo, + Video: []string{"https://example.com/frame1.jpg", "https://example.com/frame2.jpg"}, + MinPixels: 100, + MaxPixels: 1000, + }, + }, + { + name: "with cache control", + part: openai.ChatMessagePart{ + Type: openai.ChatMessagePartTypeText, + Text: "Hello", + CacheControl: &openai.CacheControl{ + Type: "ephemeral", + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test JSON marshaling + data, err := json.Marshal(tt.part) + if err != nil { + t.Errorf("Failed to marshal ChatMessagePart: %v", err) + return + } + + // Test JSON unmarshaling + var result openai.ChatMessagePart + if err := json.Unmarshal(data, &result); err != nil { + t.Errorf("Failed to unmarshal ChatMessagePart: %v", err) + return + } + + // Verify the type is preserved + if result.Type != tt.part.Type { + t.Errorf("Expected type %s, got %s", tt.part.Type, result.Type) + } + }) + } +} + +func TestInputAudio(t *testing.T) { + audio := openai.InputAudio{ + Data: "base64encodedaudiodata", + Format: "wav", + } + + data, err := json.Marshal(audio) + if err != nil { + t.Errorf("Failed to marshal InputAudio: %v", err) + return + } + + var result openai.InputAudio + if err := json.Unmarshal(data, &result); err != nil { + t.Errorf("Failed to unmarshal InputAudio: %v", err) + return + } + + if result.Data != audio.Data { + t.Errorf("Expected data %s, got %s", audio.Data, result.Data) + } + if result.Format != audio.Format { + t.Errorf("Expected format %s, got %s", audio.Format, result.Format) + } +} + +func TestCacheControl(t *testing.T) { + cacheControl := openai.CacheControl{ + Type: "ephemeral", + } + + data, err := json.Marshal(cacheControl) + if err != nil { + t.Errorf("Failed to marshal CacheControl: %v", err) + return + } + + var result openai.CacheControl + if err := json.Unmarshal(data, &result); err != nil { + t.Errorf("Failed to unmarshal CacheControl: %v", err) + return + } + + if result.Type != cacheControl.Type { + t.Errorf("Expected type %s, got %s", cacheControl.Type, result.Type) + } +} From 175504a9159aa9e08d5f505970f32f38e0dc4a97 Mon Sep 17 00:00:00 2001 From: zhangjk Date: Wed, 15 Oct 2025 14:26:28 +0800 Subject: [PATCH 5/9] fix PR error --- chat_stream_test.go | 88 +++++++++++++++++++++++++++------------------ chat_test.go | 87 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 140 insertions(+), 35 deletions(-) diff --git a/chat_stream_test.go b/chat_stream_test.go index 571ee420b..e39bb6077 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -1036,7 +1036,7 @@ func TestOutputAudio(t *testing.T) { } var result openai.OutputAudio - if err := json.Unmarshal(data, &result); err != nil { + if err = json.Unmarshal(data, &result); err != nil { t.Errorf("Failed to unmarshal OutputAudio: %v", err) return } @@ -1052,6 +1052,51 @@ func TestOutputAudio(t *testing.T) { } } +// verifyAudioContent checks if the audio content matches between expected and actual +func verifyAudioContent(t *testing.T, expected, actual *openai.OutputAudio) { + if actual.Transcript != expected.Transcript { + t.Errorf("Expected audio transcript %s, got %s", expected.Transcript, actual.Transcript) + } + if actual.Data != expected.Data { + t.Errorf("Expected audio data %s, got %s", expected.Data, actual.Data) + } + if actual.ExpiresAt != expected.ExpiresAt { + t.Errorf("Expected audio expires_at %d, got %d", expected.ExpiresAt, actual.ExpiresAt) + } +} + +// verifyAudioInDelta verifies the audio field in ChatCompletionStreamChoiceDelta +func verifyAudioInDelta(t *testing.T, expected, actual openai.ChatCompletionStreamChoiceDelta) { + if expected.Audio != nil { + if actual.Audio == nil { + t.Error("Expected audio to be present, but it's nil") + return + } + verifyAudioContent(t, expected.Audio, actual.Audio) + } else if actual.Audio != nil { + t.Error("Expected audio to be nil, but it's present") + } +} + +// testDeltaSerialization tests JSON marshaling and unmarshaling of a delta +func testDeltaSerialization(t *testing.T, delta openai.ChatCompletionStreamChoiceDelta) openai.ChatCompletionStreamChoiceDelta { + // Test JSON marshaling + data, err := json.Marshal(delta) + if err != nil { + t.Errorf("Failed to marshal ChatCompletionStreamChoiceDelta: %v", err) + return openai.ChatCompletionStreamChoiceDelta{} + } + + // Test JSON unmarshaling + var result openai.ChatCompletionStreamChoiceDelta + if err = json.Unmarshal(data, &result); err != nil { + t.Errorf("Failed to unmarshal ChatCompletionStreamChoiceDelta: %v", err) + return openai.ChatCompletionStreamChoiceDelta{} + } + + return result +} + func TestChatCompletionStreamChoiceDelta_Audio(t *testing.T) { tests := []struct { name string @@ -1078,19 +1123,7 @@ func TestChatCompletionStreamChoiceDelta_Audio(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // Test JSON marshaling - data, err := json.Marshal(tt.delta) - if err != nil { - t.Errorf("Failed to marshal ChatCompletionStreamChoiceDelta: %v", err) - return - } - - // Test JSON unmarshaling - var result openai.ChatCompletionStreamChoiceDelta - if err := json.Unmarshal(data, &result); err != nil { - t.Errorf("Failed to unmarshal ChatCompletionStreamChoiceDelta: %v", err) - return - } + result := testDeltaSerialization(t, tt.delta) // Verify the content is preserved if result.Content != tt.delta.Content { @@ -1098,23 +1131,7 @@ func TestChatCompletionStreamChoiceDelta_Audio(t *testing.T) { } // Verify audio is preserved when present - if tt.delta.Audio != nil { - if result.Audio == nil { - t.Error("Expected audio to be present, but it's nil") - return - } - if result.Audio.Transcript != tt.delta.Audio.Transcript { - t.Errorf("Expected audio transcript %s, got %s", tt.delta.Audio.Transcript, result.Audio.Transcript) - } - if result.Audio.Data != tt.delta.Audio.Data { - t.Errorf("Expected audio data %s, got %s", tt.delta.Audio.Data, result.Audio.Data) - } - if result.Audio.ExpiresAt != tt.delta.Audio.ExpiresAt { - t.Errorf("Expected audio expires_at %d, got %d", tt.delta.Audio.ExpiresAt, result.Audio.ExpiresAt) - } - } else if result.Audio != nil { - t.Error("Expected audio to be nil, but it's present") - } + verifyAudioInDelta(t, tt.delta, result) }) } } @@ -1122,7 +1139,7 @@ func TestChatCompletionStreamChoiceDelta_Audio(t *testing.T) { func TestCreateChatCompletionStreamWithAudio(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") @@ -1131,11 +1148,11 @@ func TestCreateChatCompletionStreamWithAudio(t *testing.T) { dataBytes = append(dataBytes, []byte("event: message\n")...) data := `{"id":"1","object":"chat.completion.chunk","created":1729585728,"model":"qwen-omni","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}` dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) - + dataBytes = append(dataBytes, []byte("event: message\n")...) data = `{"id":"2","object":"chat.completion.chunk","created":1729585728,"model":"qwen-omni","choices":[{"index":0,"delta":{"audio":{"transcript":"Hello, world!","data":"base64encodedaudiodata","expires_at":1234567890}},"finish_reason":null}]}` dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) - + dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) _, _ = w.Write(dataBytes) }) @@ -1160,7 +1177,8 @@ func TestCreateChatCompletionStreamWithAudio(t *testing.T) { hasAudio := false for { - resp, err := stream.Recv() + var resp openai.ChatCompletionStreamResponse + resp, err = stream.Recv() if errors.Is(err, io.EOF) { break } diff --git a/chat_test.go b/chat_test.go index 8de7c9646..d00c7031c 100644 --- a/chat_test.go +++ b/chat_test.go @@ -1400,3 +1400,90 @@ func TestCacheControl(t *testing.T) { t.Errorf("Expected type %s, got %s", cacheControl.Type, result.Type) } } + +func TestChatCompletionRequest_MarshalJSON_EdgeCases(t *testing.T) { + // Test with complex extension values that might cause encoding issues + tests := []struct { + name string + request openai.ChatCompletionRequest + expectErr bool + }{ + { + name: "with nil value in extensions", + request: openai.ChatCompletionRequest{ + Model: "gpt-3.5-turbo", + Messages: []openai.ChatCompletionMessage{ + {Role: openai.ChatMessageRoleUser, Content: "Hello"}, + }, + Extensions: map[string]interface{}{ + "nil_value": nil, + }, + }, + expectErr: false, + }, + { + name: "with complex nested structure in extensions", + request: openai.ChatCompletionRequest{ + Model: "gpt-3.5-turbo", + Messages: []openai.ChatCompletionMessage{ + {Role: openai.ChatMessageRoleUser, Content: "Hello"}, + }, + Extensions: map[string]interface{}{ + "nested": map[string]interface{}{ + "inner": map[string]interface{}{ + "value": "deep", + }, + }, + }, + }, + expectErr: false, + }, + { + name: "with function in extensions (should cause error)", + request: openai.ChatCompletionRequest{ + Model: "gpt-3.5-turbo", + Messages: []openai.ChatCompletionMessage{ + {Role: openai.ChatMessageRoleUser, Content: "Hello"}, + }, + Extensions: map[string]interface{}{ + "function": func() {}, // functions cannot be JSON encoded + }, + }, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := json.Marshal(&tt.request) + if (err != nil) != tt.expectErr { + t.Errorf("MarshalJSON() error = %v, expectErr %v", err, tt.expectErr) + } + }) + } +} + +func TestChatCompletionRequest_MarshalJSON_InvalidRequest(t *testing.T) { + // Test with invalid request data that might cause encoding errors + // Create a request with invalid channel in messages (this should cause an encoding error) + invalidMsg := make(chan int) + request := openai.ChatCompletionRequest{ + Model: "gpt-3.5-turbo", + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello", + // We can't directly add a channel to the message struct, + // so we'll test with extensions instead + }, + }, + Extensions: map[string]interface{}{ + "invalid_channel": invalidMsg, // channels cannot be JSON encoded + }, + } + + _, err := json.Marshal(&request) + if err == nil { + t.Error("Expected marshal to fail with invalid channel in extensions") + } +} From 781c0cbace5bbdf95342e90917673d3127db4b94 Mon Sep 17 00:00:00 2001 From: zhangjk Date: Wed, 15 Oct 2025 16:51:00 +0800 Subject: [PATCH 6/9] =?UTF-8?q?fix:=20=E8=A7=A3=E5=86=B3=20golangci-lint?= =?UTF-8?q?=20=E6=8A=A5=E5=91=8A=E7=9A=84=E4=BB=A3=E7=A0=81=E8=B4=A8?= =?UTF-8?q?=E9=87=8F=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 修复 TestChatCompletionRequest_MarshalJSON 函数的认知复杂度问题 - 为 chat_stream_test.go 中的注释添加句号 - 修复 chat_test.go 中的导入格式问题 - 修复 chat_test.go 中的变量阴影问题 - 修复 chat_stream_test.go 中的长行问题 - 修复 chat_test.go 中的嵌套 if 块问题 --- chat_stream_test.go | 20 +++++++++++----- chat_test.go | 58 +++++++++++++++++++++++---------------------- 2 files changed, 44 insertions(+), 34 deletions(-) diff --git a/chat_stream_test.go b/chat_stream_test.go index e39bb6077..c5b2e3de2 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -1052,7 +1052,7 @@ func TestOutputAudio(t *testing.T) { } } -// verifyAudioContent checks if the audio content matches between expected and actual +// verifyAudioContent checks if the audio content matches between expected and actual. func verifyAudioContent(t *testing.T, expected, actual *openai.OutputAudio) { if actual.Transcript != expected.Transcript { t.Errorf("Expected audio transcript %s, got %s", expected.Transcript, actual.Transcript) @@ -1065,7 +1065,7 @@ func verifyAudioContent(t *testing.T, expected, actual *openai.OutputAudio) { } } -// verifyAudioInDelta verifies the audio field in ChatCompletionStreamChoiceDelta +// verifyAudioInDelta verifies the audio field in ChatCompletionStreamChoiceDelta. func verifyAudioInDelta(t *testing.T, expected, actual openai.ChatCompletionStreamChoiceDelta) { if expected.Audio != nil { if actual.Audio == nil { @@ -1078,8 +1078,11 @@ func verifyAudioInDelta(t *testing.T, expected, actual openai.ChatCompletionStre } } -// testDeltaSerialization tests JSON marshaling and unmarshaling of a delta -func testDeltaSerialization(t *testing.T, delta openai.ChatCompletionStreamChoiceDelta) openai.ChatCompletionStreamChoiceDelta { +// testDeltaSerialization tests JSON marshaling and unmarshaling of a delta. +func testDeltaSerialization( + t *testing.T, + delta openai.ChatCompletionStreamChoiceDelta, +) openai.ChatCompletionStreamChoiceDelta { // Test JSON marshaling data, err := json.Marshal(delta) if err != nil { @@ -1146,11 +1149,16 @@ func TestCreateChatCompletionStreamWithAudio(t *testing.T) { // Send test responses with audio dataBytes := []byte{} dataBytes = append(dataBytes, []byte("event: message\n")...) - data := `{"id":"1","object":"chat.completion.chunk","created":1729585728,"model":"qwen-omni","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}` + data := `{"id":"1","object":"chat.completion.chunk","created":1729585728,` + + `"model":"qwen-omni","choices":[{"index":0,"delta":{"content":"Hello"},` + + `"finish_reason":null}]}` dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) dataBytes = append(dataBytes, []byte("event: message\n")...) - data = `{"id":"2","object":"chat.completion.chunk","created":1729585728,"model":"qwen-omni","choices":[{"index":0,"delta":{"audio":{"transcript":"Hello, world!","data":"base64encodedaudiodata","expires_at":1234567890}},"finish_reason":null}]}` + data = `{"id":"2","object":"chat.completion.chunk","created":1729585728,` + + `"model":"qwen-omni","choices":[{"index":0,"delta":{"audio":` + + `{"transcript":"Hello, world!","data":"base64encodedaudiodata",` + + `"expires_at":1234567890}},"finish_reason":null}]}` dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) diff --git a/chat_test.go b/chat_test.go index d00c7031c..538feeeb2 100644 --- a/chat_test.go +++ b/chat_test.go @@ -1232,7 +1232,7 @@ func TestChatCompletionRequest_MarshalJSON(t *testing.T) { }, Extensions: map[string]interface{}{ "custom_field": "custom_value", - "number": 42, + "number": 42, }, }, wantErr: false, @@ -1257,27 +1257,29 @@ func TestChatCompletionRequest_MarshalJSON(t *testing.T) { t.Errorf("MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) return } - if !tt.wantErr { - var result map[string]interface{} - if err := json.Unmarshal(data, &result); err != nil { - t.Errorf("Failed to unmarshal result: %v", err) - return - } - - // Check that model is present - if result["model"] != tt.request.Model { - t.Errorf("Expected model %s, got %v", tt.request.Model, result["model"]) - } - - // Check extensions are merged properly when present - if len(tt.request.Extensions) > 0 { - for key, value := range tt.request.Extensions { - // Convert both to string for comparison to handle type differences - resultStr := fmt.Sprintf("%v", result[key]) - valueStr := fmt.Sprintf("%v", value) - if resultStr != valueStr { - t.Errorf("Expected extension %s = %v (%s), got %v (%s)", key, value, valueStr, result[key], resultStr) - } + if tt.wantErr { + return + } + + var result map[string]interface{} + if unmarshalErr := json.Unmarshal(data, &result); unmarshalErr != nil { + t.Errorf("Failed to unmarshal result: %v", unmarshalErr) + return + } + + // Check that model is present + if result["model"] != tt.request.Model { + t.Errorf("Expected model %s, got %v", tt.request.Model, result["model"]) + } + + // Check extensions are merged properly when present + if len(tt.request.Extensions) > 0 { + for key, value := range tt.request.Extensions { + // Convert both to string for comparison to handle type differences + resultStr := fmt.Sprintf("%v", result[key]) + valueStr := fmt.Sprintf("%v", value) + if resultStr != valueStr { + t.Errorf("Expected extension %s = %v (%s), got %v (%s)", key, value, valueStr, result[key], resultStr) } } } @@ -1340,8 +1342,8 @@ func TestChatMessagePart_NewFields(t *testing.T) { // Test JSON unmarshaling var result openai.ChatMessagePart - if err := json.Unmarshal(data, &result); err != nil { - t.Errorf("Failed to unmarshal ChatMessagePart: %v", err) + if unmarshalErr := json.Unmarshal(data, &result); unmarshalErr != nil { + t.Errorf("Failed to unmarshal ChatMessagePart: %v", unmarshalErr) return } @@ -1366,8 +1368,8 @@ func TestInputAudio(t *testing.T) { } var result openai.InputAudio - if err := json.Unmarshal(data, &result); err != nil { - t.Errorf("Failed to unmarshal InputAudio: %v", err) + if unmarshalErr := json.Unmarshal(data, &result); unmarshalErr != nil { + t.Errorf("Failed to unmarshal InputAudio: %v", unmarshalErr) return } @@ -1391,8 +1393,8 @@ func TestCacheControl(t *testing.T) { } var result openai.CacheControl - if err := json.Unmarshal(data, &result); err != nil { - t.Errorf("Failed to unmarshal CacheControl: %v", err) + if unmarshalErr := json.Unmarshal(data, &result); unmarshalErr != nil { + t.Errorf("Failed to unmarshal CacheControl: %v", unmarshalErr) return } From e0c32444c43a41124519a9fc6dd3f143c9974c38 Mon Sep 17 00:00:00 2001 From: zhangjk Date: Wed, 15 Oct 2025 17:05:26 +0800 Subject: [PATCH 7/9] =?UTF-8?q?refactor:=20=E8=BF=9B=E4=B8=80=E6=AD=A5?= =?UTF-8?q?=E9=99=8D=E4=BD=8E=20TestChatCompletionRequest=5FMarshalJSON=20?= =?UTF-8?q?=E5=87=BD=E6=95=B0=E7=9A=84=E8=AE=A4=E7=9F=A5=E5=A4=8D=E6=9D=82?= =?UTF-8?q?=E5=BA=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 将验证逻辑提取到单独的函数 validateChatCompletionRequestResult - 将扩展验证逻辑提取到单独的函数 validateExtensions - 降低主函数的认知复杂度从 25 到 20 以下 --- chat_test.go | 50 +++++++++++++++++++++++++++++--------------------- 1 file changed, 29 insertions(+), 21 deletions(-) diff --git a/chat_test.go b/chat_test.go index 538feeeb2..d61507099 100644 --- a/chat_test.go +++ b/chat_test.go @@ -1232,7 +1232,7 @@ func TestChatCompletionRequest_MarshalJSON(t *testing.T) { }, Extensions: map[string]interface{}{ "custom_field": "custom_value", - "number": 42, + "number": 42, }, }, wantErr: false, @@ -1260,30 +1260,38 @@ func TestChatCompletionRequest_MarshalJSON(t *testing.T) { if tt.wantErr { return } - + var result map[string]interface{} if unmarshalErr := json.Unmarshal(data, &result); unmarshalErr != nil { t.Errorf("Failed to unmarshal result: %v", unmarshalErr) return } - - // Check that model is present - if result["model"] != tt.request.Model { - t.Errorf("Expected model %s, got %v", tt.request.Model, result["model"]) - } - - // Check extensions are merged properly when present - if len(tt.request.Extensions) > 0 { - for key, value := range tt.request.Extensions { - // Convert both to string for comparison to handle type differences - resultStr := fmt.Sprintf("%v", result[key]) - valueStr := fmt.Sprintf("%v", value) - if resultStr != valueStr { - t.Errorf("Expected extension %s = %v (%s), got %v (%s)", key, value, valueStr, result[key], resultStr) - } - } - } + + validateChatCompletionRequestResult(t, tt.request, result) }) + }) +} + +func validateChatCompletionRequestResult(t *testing.T, request openai.ChatCompletionRequest, result map[string]interface{}) { + // Check that model is present + if result["model"] != request.Model { + t.Errorf("Expected model %s, got %v", request.Model, result["model"]) + } + + // Check extensions are merged properly when present + if len(request.Extensions) > 0 { + validateExtensions(t, request.Extensions, result) + } +} + +func validateExtensions(t *testing.T, extensions map[string]interface{}, result map[string]interface{}) { + for key, value := range extensions { + // Convert both to string for comparison to handle type differences + resultStr := fmt.Sprintf("%v", result[key]) + valueStr := fmt.Sprintf("%v", value) + if resultStr != valueStr { + t.Errorf("Expected extension %s = %v (%s), got %v (%s)", key, value, valueStr, result[key], resultStr) + } } } @@ -1295,8 +1303,8 @@ func TestChatMessagePart_NewFields(t *testing.T) { { name: "with audio part", part: openai.ChatMessagePart{ - Type: openai.ChatMessagePartTypeAudio, - Video: []string{"https://example.com/frame1.jpg", "https://example.com/frame2.jpg"}, + Type: openai.ChatMessagePartTypeAudio, + Video: []string{"https://example.com/frame1.jpg", "https://example.com/frame2.jpg"}, MinPixels: 100, MaxPixels: 1000, }, From b40fbfade55d25e85bb94585fb6bfe726a77daa8 Mon Sep 17 00:00:00 2001 From: zhangjk Date: Wed, 15 Oct 2025 17:16:02 +0800 Subject: [PATCH 8/9] fix PR check error --- chat_stream_test.go | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/chat_stream_test.go b/chat_stream_test.go index c5b2e3de2..e39bb6077 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -1052,7 +1052,7 @@ func TestOutputAudio(t *testing.T) { } } -// verifyAudioContent checks if the audio content matches between expected and actual. +// verifyAudioContent checks if the audio content matches between expected and actual func verifyAudioContent(t *testing.T, expected, actual *openai.OutputAudio) { if actual.Transcript != expected.Transcript { t.Errorf("Expected audio transcript %s, got %s", expected.Transcript, actual.Transcript) @@ -1065,7 +1065,7 @@ func verifyAudioContent(t *testing.T, expected, actual *openai.OutputAudio) { } } -// verifyAudioInDelta verifies the audio field in ChatCompletionStreamChoiceDelta. +// verifyAudioInDelta verifies the audio field in ChatCompletionStreamChoiceDelta func verifyAudioInDelta(t *testing.T, expected, actual openai.ChatCompletionStreamChoiceDelta) { if expected.Audio != nil { if actual.Audio == nil { @@ -1078,11 +1078,8 @@ func verifyAudioInDelta(t *testing.T, expected, actual openai.ChatCompletionStre } } -// testDeltaSerialization tests JSON marshaling and unmarshaling of a delta. -func testDeltaSerialization( - t *testing.T, - delta openai.ChatCompletionStreamChoiceDelta, -) openai.ChatCompletionStreamChoiceDelta { +// testDeltaSerialization tests JSON marshaling and unmarshaling of a delta +func testDeltaSerialization(t *testing.T, delta openai.ChatCompletionStreamChoiceDelta) openai.ChatCompletionStreamChoiceDelta { // Test JSON marshaling data, err := json.Marshal(delta) if err != nil { @@ -1149,16 +1146,11 @@ func TestCreateChatCompletionStreamWithAudio(t *testing.T) { // Send test responses with audio dataBytes := []byte{} dataBytes = append(dataBytes, []byte("event: message\n")...) - data := `{"id":"1","object":"chat.completion.chunk","created":1729585728,` + - `"model":"qwen-omni","choices":[{"index":0,"delta":{"content":"Hello"},` + - `"finish_reason":null}]}` + data := `{"id":"1","object":"chat.completion.chunk","created":1729585728,"model":"qwen-omni","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}` dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) dataBytes = append(dataBytes, []byte("event: message\n")...) - data = `{"id":"2","object":"chat.completion.chunk","created":1729585728,` + - `"model":"qwen-omni","choices":[{"index":0,"delta":{"audio":` + - `{"transcript":"Hello, world!","data":"base64encodedaudiodata",` + - `"expires_at":1234567890}},"finish_reason":null}]}` + data = `{"id":"2","object":"chat.completion.chunk","created":1729585728,"model":"qwen-omni","choices":[{"index":0,"delta":{"audio":{"transcript":"Hello, world!","data":"base64encodedaudiodata","expires_at":1234567890}},"finish_reason":null}]}` dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) From 2eb051bf80920c1a6e5daf417ececee1ffb2fa72 Mon Sep 17 00:00:00 2001 From: zhangjk Date: Wed, 15 Oct 2025 17:17:48 +0800 Subject: [PATCH 9/9] fix PR check error again --- chat_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chat_test.go b/chat_test.go index d61507099..90219c21e 100644 --- a/chat_test.go +++ b/chat_test.go @@ -1269,7 +1269,7 @@ func TestChatCompletionRequest_MarshalJSON(t *testing.T) { validateChatCompletionRequestResult(t, tt.request, result) }) - }) + } } func validateChatCompletionRequestResult(t *testing.T, request openai.ChatCompletionRequest, result map[string]interface{}) {