Skip to content

Commit 16e5334

Browse files
fix: consolidate llm client logic so all requests have consistent headers [IDE-1350] (#113)
1 parent d916785 commit 16e5334

File tree

10 files changed

+108
-66
lines changed

10 files changed

+108
-66
lines changed

README.md

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ Implement the `config.Config` interface to configure the Snyk Code API client fr
8686

8787
Use the Code Scanner to trigger a scan for a Snyk Code workspace using the Bundle Manager.
8888

89-
The Code Scanner exposes a `UploadAndAnalyze` function, which can be used like this:
89+
The Code Scanner exposes two scanning functions: `UploadAndAnalyze` (which supports Code Consistent Ignores) and
90+
`UploadAndAnalyzeLegacy`. These functions may be used like this:
9091

9192
```go
9293
import (
@@ -102,7 +103,12 @@ codeScanner := codeClient.NewCodeScanner(
102103
codeClientHTTP.WithInstrumentor(instrumentor),
103104
codeClientHTTP.WithErrorReporter(errorReporter),
104105
)
105-
codeScanner.UploadAndAnalyze(context.Background(), requestId, target, channelForWalkingFiles, changedFiles)
106+
if useCodeConsistentIgnores() {
107+
codeScanner.UploadAndAnalyze(context.Background(), requestId, target, channelForWalkingFiles, changedFiles)
108+
} else {
109+
codeScanner.UploadAndAnalyzeLegacy(context.Background(), requestId, target, shardKey, files, changedFiles, statusChannel)
110+
}
111+
106112
```
107113

108114

http/http.go

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525

2626
"github.com/rs/zerolog"
2727

28+
"github.com/snyk/code-client-go/internal/util/encoding"
2829
"github.com/snyk/code-client-go/observability"
2930
)
3031

@@ -100,6 +101,8 @@ var retryErrorCodes = map[int]bool{
100101
http.StatusInternalServerError: true,
101102
}
102103

104+
const NoRequestId = ""
105+
103106
func (s *httpClient) Do(req *http.Request) (*http.Response, error) {
104107
span := s.instrumentor.StartSpan(req.Context(), "http.Do")
105108
defer s.instrumentor.Finish(span)
@@ -160,14 +163,43 @@ func NewDefaultClientFactory() HTTPClientFactory {
160163
return clientFunc
161164
}
162165

163-
func AddDefaultHeaders(req *http.Request, requestId string, orgId string) {
166+
func AddDefaultHeaders(req *http.Request, requestId string, orgId string, method string) {
164167
// if requestId is empty it will be enriched from the Gateway
165168
if len(requestId) > 0 {
166169
req.Header.Set("snyk-request-id", requestId)
167170
}
168171
if len(orgId) > 0 {
169172
req.Header.Set("snyk-org-name", orgId)
170173
}
174+
175+
// https://www.keycdn.com/blog/http-cache-headers
171176
req.Header.Set("Cache-Control", "private, max-age=0, no-cache")
172-
req.Header.Set("Content-Type", "application/json")
177+
178+
if mustBeEncoded(method) {
179+
req.Header.Set("Content-Type", "application/octet-stream")
180+
req.Header.Set("Content-Encoding", "gzip")
181+
} else {
182+
req.Header.Set("Content-Type", "application/json")
183+
}
184+
}
185+
186+
// EncodeIfNeeded returns a byte buffer for the requestBody. Depending on the request method, it may encode the buffer.
187+
// (See http.mustBeEncoded for the list of methods which require encoding the request body.)
188+
func EncodeIfNeeded(method string, requestBody []byte) (*bytes.Buffer, error) {
189+
b := new(bytes.Buffer)
190+
if mustBeEncoded(method) {
191+
enc := encoding.NewEncoder(b)
192+
_, err := enc.Write(requestBody)
193+
if err != nil {
194+
return nil, err
195+
}
196+
} else {
197+
b = bytes.NewBuffer(requestBody)
198+
}
199+
return b, nil
200+
}
201+
202+
// mustBeEncoded returns true if the request method requires the request body to be encoded.
203+
func mustBeEncoded(method string) bool {
204+
return method == http.MethodPost || method == http.MethodPut
173205
}

internal/analysis/analysis_legacy.go

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,9 @@ import (
2828
"net/url"
2929
"strings"
3030

31-
"github.com/snyk/code-client-go/scan"
32-
3331
codeClientHTTP "github.com/snyk/code-client-go/http"
3432
"github.com/snyk/code-client-go/sarif"
33+
"github.com/snyk/code-client-go/scan"
3534
)
3635

3736
// Legacy analysis types and constants
@@ -76,15 +75,21 @@ type FailedError struct {
7675
func (e FailedError) Error() string { return e.Msg }
7776

7877
// Legacy analysis helper functions
79-
func (a *analysisOrchestrator) newRequestContext() requestContext {
78+
func (a *analysisOrchestrator) newRequestContext(ctx context.Context) requestContext {
8079
unknown := "unknown"
8180
orgId := unknown
8281
if a.config.Organization() != "" {
8382
orgId = a.config.Organization()
8483
}
8584

85+
initiator := unknown
86+
contextInitiator, ok := scan.ScanSourceFromContext(ctx)
87+
if ok {
88+
initiator = string(contextInitiator)
89+
}
90+
8691
return requestContext{
87-
Initiator: "IDE",
92+
Initiator: initiator,
8893
Flow: "language-server",
8994
Org: requestContextOrg{
9095
Name: unknown,
@@ -94,15 +99,15 @@ func (a *analysisOrchestrator) newRequestContext() requestContext {
9499
}
95100
}
96101

97-
func (a *analysisOrchestrator) createRequestBody(bundleHash, shardKey string, limitToFiles []string, severity int) ([]byte, error) {
102+
func (a *analysisOrchestrator) createRequestBody(ctx context.Context, bundleHash, shardKey string, limitToFiles []string, severity int) ([]byte, error) {
98103
request := Request{
99104
Key: RequestKey{
100105
Type: "file",
101106
Hash: bundleHash,
102107
LimitToFiles: limitToFiles,
103108
},
104109
Legacy: false,
105-
AnalysisContext: a.newRequestContext(),
110+
AnalysisContext: a.newRequestContext(ctx),
106111
}
107112
if len(shardKey) > 0 {
108113
request.Key.Shard = shardKey
@@ -144,7 +149,7 @@ func (a *analysisOrchestrator) RunLegacyTest(ctx context.Context, bundleHash str
144149
a.logger.Debug().Str("method", method).Str("bundleHash", bundleHash).Msg("API: Retrieving analysis for bundle")
145150
defer a.logger.Debug().Str("method", method).Str("bundleHash", bundleHash).Msg("API: Retrieving analysis done")
146151

147-
requestBody, err := a.createRequestBody(bundleHash, shardKey, limitToFiles, severity)
152+
requestBody, err := a.createRequestBody(ctx, bundleHash, shardKey, limitToFiles, severity)
148153
if err != nil {
149154
a.logger.Err(err).Str("method", method).Str("requestBody", string(requestBody)).Msg("error creating request body")
150155
return nil, scan.LegacyScanStatus{}, err
@@ -158,12 +163,13 @@ func (a *analysisOrchestrator) RunLegacyTest(ctx context.Context, bundleHash str
158163

159164
// Create HTTP request
160165
analysisUrl := baseUrl + "/analysis"
161-
req, err := http.NewRequestWithContext(span.Context(), http.MethodPost, analysisUrl, bytes.NewBuffer(requestBody))
166+
httpMethod := http.MethodPost
167+
req, err := http.NewRequestWithContext(span.Context(), httpMethod, analysisUrl, bytes.NewBuffer(requestBody))
162168
if err != nil {
163169
a.logger.Err(err).Str("method", method).Msg("error creating HTTP request")
164170
return nil, scan.LegacyScanStatus{}, err
165171
}
166-
codeClientHTTP.AddDefaultHeaders(req, span.GetTraceId(), a.config.Organization())
172+
codeClientHTTP.AddDefaultHeaders(req, span.GetTraceId(), a.config.Organization(), httpMethod)
167173

168174
// Make HTTP call
169175
resp, err := a.httpClient.Do(req)

internal/analysis/analysis_legacy_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -530,7 +530,7 @@ func TestAnalysis_CreateRequestBody(t *testing.T) {
530530

531531
// run method under test
532532
_, _, err := analysisOrchestrator.RunLegacyTest(
533-
t.Context(),
533+
scan.NewContextWithScanSource(t.Context(), scan.IDE),
534534
bundleHash,
535535
shardKey,
536536
limitToFiles,
@@ -568,7 +568,7 @@ func TestAnalysis_CreateRequestBody(t *testing.T) {
568568

569569
// Validate analysisContext
570570
analysisContext := request["analysisContext"].(map[string]interface{})
571-
assert.Equal(t, "IDE", analysisContext["initiator"])
571+
assert.Equal(t, string(scan.IDE), analysisContext["initiator"])
572572
assert.Equal(t, "language-server", analysisContext["flow"])
573573

574574
org := analysisContext["org"].(map[string]interface{})

internal/deepcode/client.go

Lines changed: 2 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
package deepcode
1818

1919
import (
20-
"bytes"
2120
"context"
2221
"encoding/json"
2322
"errors"
@@ -30,7 +29,6 @@ import (
3029

3130
"github.com/rs/zerolog"
3231
"github.com/snyk/code-client-go/config"
33-
"github.com/snyk/code-client-go/internal/util/encoding"
3432

3533
codeClientHTTP "github.com/snyk/code-client-go/http"
3634
"github.com/snyk/code-client-go/observability"
@@ -217,7 +215,7 @@ func (s *deepcodeClient) Request(
217215
return nil, err
218216
}
219217

220-
bodyBuffer, err := s.encodeIfNeeded(method, requestBody)
218+
bodyBuffer, err := codeClientHTTP.EncodeIfNeeded(method, requestBody)
221219
if err != nil {
222220
return nil, err
223221
}
@@ -227,7 +225,7 @@ func (s *deepcodeClient) Request(
227225
return nil, err
228226
}
229227

230-
s.addHeaders(method, req)
228+
codeClientHTTP.AddDefaultHeaders(req, codeClientHTTP.NoRequestId, s.config.Organization(), method)
231229

232230
response, err := s.httpClient.Do(req)
233231
if err != nil {
@@ -255,41 +253,6 @@ func (s *deepcodeClient) Request(
255253
return responseBody, nil
256254
}
257255

258-
func (s *deepcodeClient) addHeaders(method string, req *http.Request) {
259-
// Setting a chosen org name for the request
260-
org := s.config.Organization()
261-
if org != "" {
262-
req.Header.Set("snyk-org-name", org)
263-
}
264-
// https://www.keycdn.com/blog/http-cache-headers
265-
req.Header.Set("Cache-Control", "private, max-age=0, no-cache")
266-
if s.mustBeEncoded(method) {
267-
req.Header.Set("Content-Type", "application/octet-stream")
268-
req.Header.Set("Content-Encoding", "gzip")
269-
} else {
270-
req.Header.Set("Content-Type", "application/json")
271-
}
272-
}
273-
274-
func (s *deepcodeClient) encodeIfNeeded(method string, requestBody []byte) (*bytes.Buffer, error) {
275-
b := new(bytes.Buffer)
276-
mustBeEncoded := s.mustBeEncoded(method)
277-
if mustBeEncoded {
278-
enc := encoding.NewEncoder(b)
279-
_, err := enc.Write(requestBody)
280-
if err != nil {
281-
return nil, err
282-
}
283-
} else {
284-
b = bytes.NewBuffer(requestBody)
285-
}
286-
return b, nil
287-
}
288-
289-
func (s *deepcodeClient) mustBeEncoded(method string) bool {
290-
return method == http.MethodPost || method == http.MethodPut
291-
}
292-
293256
func (s *deepcodeClient) checkResponseCode(r *http.Response) error {
294257
if r.StatusCode >= 200 && r.StatusCode <= 299 {
295258
return nil

llm/api_client.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package llm
22

33
import (
4-
"bytes"
54
"context"
65
"encoding/base64"
76
"encoding/json"
@@ -12,7 +11,7 @@ import (
1211
"net/url"
1312
"strings"
1413

15-
codeClientHTTP "github.com/snyk/code-client-go/http"
14+
http2 "github.com/snyk/code-client-go/http"
1615
)
1716

1817
var (
@@ -70,13 +69,20 @@ func (d *DeepCodeLLMBindingImpl) submitRequest(ctx context.Context, url *url.URL
7069
span := d.instrumentor.StartSpan(ctx, "code.SubmitRequest")
7170
defer span.Finish()
7271

73-
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url.String(), bytes.NewBuffer(requestBody))
72+
// Encode the request body
73+
bodyBuffer, err := http2.EncodeIfNeeded(http.MethodPost, requestBody)
74+
if err != nil {
75+
logger.Err(err).Str("requestBody", string(requestBody)).Msg("error encoding request body")
76+
return nil, err
77+
}
78+
79+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url.String(), bodyBuffer)
7480
if err != nil {
7581
logger.Err(err).Str("requestBody", string(requestBody)).Msg("error creating request")
7682
return nil, err
7783
}
7884

79-
codeClientHTTP.AddDefaultHeaders(req, span.GetTraceId(), orgId)
85+
http2.AddDefaultHeaders(req, http2.NoRequestId, orgId, http.MethodPost)
8086

8187
resp, err := d.httpClientFunc().Do(req) //nolint:bodyclose // this seems to be a false positive
8288
if err != nil {

llm/api_client_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@ import (
1010
"testing"
1111

1212
"github.com/rs/zerolog"
13+
http2 "github.com/snyk/code-client-go/http"
1314
"github.com/stretchr/testify/assert"
1415
"github.com/stretchr/testify/require"
1516

16-
codeClientHTTP "github.com/snyk/code-client-go/http"
1717
"github.com/snyk/code-client-go/observability"
1818
)
1919

@@ -264,7 +264,7 @@ func testLogger(t *testing.T) *zerolog.Logger {
264264
func TestAddDefaultHeadersWithExistingHeaders(t *testing.T) {
265265
req := &http.Request{Header: http.Header{"Existing-Header": {"existing-value"}}}
266266

267-
codeClientHTTP.AddDefaultHeaders(req, "", "")
267+
http2.AddDefaultHeaders(req, http2.NoRequestId, "", http.MethodGet)
268268

269269
cacheControl := req.Header.Get("Cache-Control")
270270
contentType := req.Header.Get("Content-Type")

scan.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ type CodeScanner interface {
6666
changedFiles map[string]bool,
6767
) (*sarif.SarifResponse, string, error)
6868

69+
// UploadAndAnalyzeLegacy runs the legacy scanner (no consistent ignores)
70+
// ctx may include a scan.ScanSource value for use in the requestContext (see analysis_legacy.go)
6971
UploadAndAnalyzeLegacy(
7072
ctx context.Context,
7173
requestId string,

scan/resultmetadata.go

Lines changed: 0 additions & 6 deletions
This file was deleted.

scan/scan_metadata.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package scan
2+
3+
import "context"
4+
5+
type ResultMetaData struct {
6+
FindingsUrl string
7+
WebUiUrl string
8+
}
9+
10+
type ScanSource string
11+
12+
func (s ScanSource) String() string {
13+
return string(s)
14+
}
15+
16+
const (
17+
LLM ScanSource = "LLM"
18+
IDE ScanSource = "IDE"
19+
CLI ScanSource = "CLI"
20+
)
21+
22+
type scanSourceKeyType int
23+
24+
var scanSourceKey scanSourceKeyType
25+
26+
func NewContextWithScanSource(ctx context.Context, source ScanSource) context.Context {
27+
return context.WithValue(ctx, scanSourceKey, source)
28+
}
29+
30+
func ScanSourceFromContext(ctx context.Context) (ScanSource, bool) {
31+
s, ok := ctx.Value(scanSourceKey).(ScanSource)
32+
return s, ok
33+
}

0 commit comments

Comments
 (0)