diff --git a/pkg/provider/github/github.go b/pkg/provider/github/github.go index f482536dff..2928393b07 100644 --- a/pkg/provider/github/github.go +++ b/pkg/provider/github/github.go @@ -338,8 +338,17 @@ func (v *Provider) GetTektonDir(ctx context.Context, runevent *info.Event, path, // default set provenance from the SHA revision := runevent.SHA if provenance == "default_branch" { - revision = runevent.DefaultBranch v.Logger.Infof("Using PipelineRun definition from default_branch: %s", runevent.DefaultBranch) + branch, _, err := wrapAPI(v, "get_default_branch", func() (*github.Branch, *github.Response, error) { + return v.Client().Repositories.GetBranch(ctx, runevent.Organization, runevent.Repository, runevent.DefaultBranch, 1) + }) + if err != nil { + return "", err + } + revision = branch.GetCommit().GetSHA() + if revision == "" { + return "", fmt.Errorf("default_branch %s did not resolve to a commit SHA", runevent.DefaultBranch) + } } else { prInfo := "" if runevent.TriggerTarget == triggertype.PullRequest { @@ -379,7 +388,7 @@ func (v *Provider) GetTektonDir(ctx context.Context, runevent *info.Event, path, if err != nil { return "", err } - return v.concatAllYamlFiles(ctx, tektonDirObjects.Entries, runevent) + return v.concatAllYamlFiles(ctx, tektonDirObjects.Entries, runevent, path, revision) } // GetCommitInfo get info (url and title) on a commit in runevent, this needs to @@ -473,25 +482,53 @@ func (v *Provider) GetFileInsideRepo(ctx context.Context, runevent *info.Event, } // concatAllYamlFiles concat all yaml files from a directory as one big multi document yaml string. -func (v *Provider) concatAllYamlFiles(ctx context.Context, objects []*github.TreeEntry, runevent *info.Event) (string, error) { - var allTemplates string - +// tektonDirPath is the path to the .tekton directory (e.g., ".tekton") used to construct full paths. +func (v *Provider) concatAllYamlFiles(ctx context.Context, objects []*github.TreeEntry, runevent *info.Event, tektonDirPath, ref string) (string, error) { + // Collect all YAML file paths and preserve order + // Tree entries have paths relative to the .tekton directory, so prepend tektonDirPath + var yamlFiles []string for _, value := range objects { if strings.HasSuffix(value.GetPath(), ".yaml") || strings.HasSuffix(value.GetPath(), ".yml") { - data, err := v.getObject(ctx, value.GetSHA(), runevent) - if err != nil { - return "", err - } - if err := provider.ValidateYaml(data, value.GetPath()); err != nil { - return "", err - } - if allTemplates != "" && !strings.HasPrefix(string(data), "---") { - allTemplates += "---" - } - allTemplates += "\n" + string(data) + "\n" + // Construct full path from repo root for GraphQL query + fullPath := tektonDirPath + "/" + value.GetPath() + yamlFiles = append(yamlFiles, fullPath) } } + + if len(yamlFiles) == 0 { + return "", nil + } + + client, err := newGraphQLClient(v) + if err != nil { + return "", fmt.Errorf("failed to create GraphQL client: %w", err) + } + graphQLResults, err := client.fetchFiles(ctx, runevent.Organization, runevent.Repository, ref, yamlFiles) + if err != nil { + return "", fmt.Errorf("failed to fetch .tekton files via GraphQL: %w", err) + } + + var allTemplates string + + for _, path := range yamlFiles { + content, ok := graphQLResults[path] + if !ok { + return "", fmt.Errorf("file %s not found in GraphQL response", path) + } + + // it used to be like that (stripped prefix) before we moved to GraphQL so + // let's keep it that way. + relativePath := strings.TrimPrefix(path, tektonDirPath+"/") + if err := provider.ValidateYaml(content, relativePath); err != nil { + return "", err + } + if allTemplates != "" && !strings.HasPrefix(string(content), "---") { + allTemplates += "---" + } + allTemplates += "\n" + string(content) + "\n" + } + return allTemplates, nil } diff --git a/pkg/provider/github/github_test.go b/pkg/provider/github/github_test.go index 511194f90b..fec26d7421 100644 --- a/pkg/provider/github/github_test.go +++ b/pkg/provider/github/github_test.go @@ -253,8 +253,8 @@ func TestGetTektonDir(t *testing.T) { filterMessageSnippet: "Using PipelineRun definition from source pull_request tekton/cat#0", // 1. Get Repo root objects // 2. Get Tekton Dir objects - // 3/4. Get object content for each object (pipelinerun.yaml, pipeline.yaml) - expectedGHApiCalls: 4, + // 3. GraphQL batch fetch for 2 files (replaces 2 REST calls) + expectedGHApiCalls: 3, }, { name: "test no subtree on push", @@ -269,8 +269,8 @@ func TestGetTektonDir(t *testing.T) { filterMessageSnippet: "Using PipelineRun definition from source push", // 1. Get Repo root objects // 2. Get Tekton Dir objects - // 3/4. Get object content for each object (pipelinerun.yaml, pipeline.yaml) - expectedGHApiCalls: 4, + // 3. GraphQL batch fetch for 2 files (replaces 2 REST calls) + expectedGHApiCalls: 3, }, { name: "test provenance default_branch ", @@ -283,9 +283,10 @@ func TestGetTektonDir(t *testing.T) { treepath: "testdata/tree/defaultbranch", provenance: "default_branch", filterMessageSnippet: "Using PipelineRun definition from default_branch: main", - // 1. Get Repo root objects - // 2. Get Tekton Dir objects - // 3/4. Get object content for each object (pipelinerun.yaml, pipeline.yaml) + // 1. Resolve default branch to a commit SHA + // 2. Get Repo root objects + // 3. Get Tekton Dir objects + // 4. GraphQL batch fetch for 2 files expectedGHApiCalls: 4, }, { @@ -371,11 +372,21 @@ func TestGetTektonDir(t *testing.T) { } }() + shaDir := fmt.Sprintf("%x", sha256.Sum256([]byte(tt.treepath))) + tt.event.SHA = shaDir if tt.provenance == "default_branch" { - tt.event.SHA = tt.event.DefaultBranch - } else { - shaDir := fmt.Sprintf("%x", sha256.Sum256([]byte(tt.treepath))) - tt.event.SHA = shaDir + mux.HandleFunc(fmt.Sprintf("/repos/%s/%s/branches/%s", + tt.event.Organization, tt.event.Repository, tt.event.DefaultBranch), + func(rw http.ResponseWriter, _ *http.Request) { + branch := &github.Branch{ + Name: github.Ptr(tt.event.DefaultBranch), + Commit: &github.RepositoryCommit{ + SHA: github.Ptr(shaDir), + }, + } + b, _ := json.Marshal(branch) + fmt.Fprint(rw, string(b)) + }) } ghtesthelper.SetupGitTree(t, mux, tt.treepath, tt.event, false) @@ -403,6 +414,216 @@ func TestGetTektonDir(t *testing.T) { } } +func TestGetTektonDir_GraphQLBatchFetch(t *testing.T) { + // Test that GraphQL is used for batch fetching multiple files + metricsutils.ResetMetrics() + observer, exporter := zapobserver.New(zap.DebugLevel) + fakelogger := zap.New(observer).Sugar() + ctx, _ := rtesting.SetupFakeContext(t) + fakeclient, mux, _, teardown := ghtesthelper.SetupGH() + defer teardown() + gvcs := Provider{ + ghClient: fakeclient, + providerName: "github", + Logger: fakelogger, + } + + event := &info.Event{ + Organization: "tekton", + Repository: "cat", + SHA: "123", + TriggerTarget: triggertype.PullRequest, + } + shaDir := fmt.Sprintf("%x", sha256.Sum256([]byte("testdata/tree/simple"))) + event.SHA = shaDir + ghtesthelper.SetupGitTree(t, mux, "testdata/tree/simple", event, false) + + got, err := gvcs.GetTektonDir(ctx, event, ".tekton", "") + assert.NilError(t, err) + assert.Assert(t, strings.Contains(got, "PipelineRun"), "expected PipelineRun in output, got %s", got) + + // Verify GraphQL was used (check logs) + graphQLLogs := exporter.FilterMessageSnippet("GraphQL batch fetch") + assert.Assert(t, graphQLLogs.Len() > 0, "expected GraphQL batch fetch log message") + + // Verify reduced API calls: 1 root tree + 1 tekton tree + 1 GraphQL = 3 (instead of 4) + metricstest.CheckCountData( + t, + "pipelines_as_code_git_provider_api_request_count", + map[string]string{"provider": "github"}, + 3, + ) +} + +func TestGetTektonDir_GraphQLError(t *testing.T) { + // Test that GraphQL errors are properly returned + metricsutils.ResetMetrics() + observer, _ := zapobserver.New(zap.DebugLevel) + fakelogger := zap.New(observer).Sugar() + ctx, _ := rtesting.SetupFakeContext(t) + fakeclient, mux, _, teardown := ghtesthelper.SetupGH() + defer teardown() + + // Register error handler on this mux (SetupGitTree is not called in this test, + // so this is the only /api/graphql handler) + mux.HandleFunc("/api/graphql", func(w http.ResponseWriter, _ *http.Request) { + http.Error(w, "GraphQL endpoint not available", http.StatusNotFound) + }) + + gvcs := Provider{ + ghClient: fakeclient, + providerName: "github", + Logger: fakelogger, + } + + event := &info.Event{ + Organization: "tekton", + Repository: "cat", + SHA: "123", + TriggerTarget: triggertype.PullRequest, + } + shaDir := fmt.Sprintf("%x", sha256.Sum256([]byte("testdata/tree/simple"))) + event.SHA = shaDir + + // Setup tree endpoints manually (skip SetupGitTree to avoid GraphQL handler registration) + // Set up root tree + mux.HandleFunc(fmt.Sprintf("/repos/%v/%v/git/trees/%v", event.Organization, event.Repository, event.SHA), + func(rw http.ResponseWriter, _ *http.Request) { + tree := &github.Tree{ + SHA: &event.SHA, + Entries: []*github.TreeEntry{ + { + Path: github.Ptr(".tekton"), + Type: github.Ptr("tree"), + SHA: github.Ptr("tektondirsha"), + }, + }, + } + b, _ := json.Marshal(tree) + fmt.Fprint(rw, string(b)) + }) + + // Set up .tekton directory tree + tektonDirSha := "tektondirsha" + mux.HandleFunc(fmt.Sprintf("/repos/%v/%v/git/trees/%v", event.Organization, event.Repository, tektonDirSha), + func(rw http.ResponseWriter, _ *http.Request) { + tree := &github.Tree{ + SHA: &tektonDirSha, + Entries: []*github.TreeEntry{ + { + Path: github.Ptr("pipeline.yaml"), + Type: github.Ptr("blob"), + SHA: github.Ptr("pipelinesha"), + }, + { + Path: github.Ptr("pipelinerun.yaml"), + Type: github.Ptr("blob"), + SHA: github.Ptr("pipelinerunsha"), + }, + }, + } + b, _ := json.Marshal(tree) + fmt.Fprint(rw, string(b)) + }) + + _, err := gvcs.GetTektonDir(ctx, event, ".tekton", "") + assert.ErrorContains(t, err, "failed to fetch .tekton files via GraphQL") +} + +func TestGetTektonDir_DefaultBranchUsesResolvedSHAForGraphQL(t *testing.T) { + metricsutils.ResetMetrics() + observer, _ := zapobserver.New(zap.DebugLevel) + fakelogger := zap.New(observer).Sugar() + ctx, _ := rtesting.SetupFakeContext(t) + fakeclient, mux, _, teardown := ghtesthelper.SetupGH() + defer teardown() + + gvcs := Provider{ + ghClient: fakeclient, + providerName: "github", + Logger: fakelogger, + } + + event := &info.Event{ + Organization: "tekton", + Repository: "cat", + DefaultBranch: "main", + } + resolvedSHA := "resolved-default-branch-sha" + tektonDirSHA := "tektondirsha" + + mux.HandleFunc("/repos/tekton/cat/branches/main", func(rw http.ResponseWriter, _ *http.Request) { + branch := &github.Branch{ + Name: github.Ptr("main"), + Commit: &github.RepositoryCommit{ + SHA: github.Ptr(resolvedSHA), + }, + } + b, _ := json.Marshal(branch) + fmt.Fprint(rw, string(b)) + }) + mux.HandleFunc("/repos/tekton/cat/git/trees/"+resolvedSHA, func(rw http.ResponseWriter, _ *http.Request) { + tree := &github.Tree{ + SHA: github.Ptr(resolvedSHA), + Entries: []*github.TreeEntry{ + { + Path: github.Ptr(".tekton"), + Type: github.Ptr("tree"), + SHA: github.Ptr(tektonDirSHA), + }, + }, + } + b, _ := json.Marshal(tree) + fmt.Fprint(rw, string(b)) + }) + mux.HandleFunc("/repos/tekton/cat/git/trees/"+tektonDirSHA, func(rw http.ResponseWriter, _ *http.Request) { + tree := &github.Tree{ + SHA: github.Ptr(tektonDirSHA), + Entries: []*github.TreeEntry{ + { + Path: github.Ptr("pipeline.yaml"), + Type: github.Ptr("blob"), + SHA: github.Ptr("pipeline-sha"), + }, + { + Path: github.Ptr("pipelinerun.yaml"), + Type: github.Ptr("blob"), + SHA: github.Ptr("pipelinerun-sha"), + }, + }, + } + b, _ := json.Marshal(tree) + fmt.Fprint(rw, string(b)) + }) + mux.HandleFunc("/api/graphql", func(w http.ResponseWriter, r *http.Request) { + var graphQLReq struct { + Query string `json:"query"` + } + assert.NilError(t, json.NewDecoder(r.Body).Decode(&graphQLReq)) + assert.Assert(t, strings.Contains(graphQLReq.Query, resolvedSHA+":.tekton/pipeline.yaml"), graphQLReq.Query) + assert.Assert(t, !strings.Contains(graphQLReq.Query, `main:.tekton/pipeline.yaml`), graphQLReq.Query) + + _ = json.NewEncoder(w).Encode(map[string]any{ + "data": map[string]any{ + "repository": map[string]any{ + "file0": map[string]any{"text": "kind: Pipeline\nmetadata:\n name: pipeline\n"}, + "file1": map[string]any{"text": "kind: PipelineRun\nmetadata:\n name: run\n"}, + }, + }, + }) + }) + + got, err := gvcs.GetTektonDir(ctx, event, ".tekton", "default_branch") + assert.NilError(t, err) + assert.Assert(t, strings.Contains(got, "PipelineRun")) + metricstest.CheckCountData( + t, + "pipelines_as_code_git_provider_api_request_count", + map[string]string{"provider": "github"}, + 4, + ) +} + func TestGetFileInsideRepo(t *testing.T) { testGetTektonDir := []struct { name string diff --git a/pkg/provider/github/graphql.go b/pkg/provider/github/graphql.go new file mode 100644 index 0000000000..e163251cf4 --- /dev/null +++ b/pkg/provider/github/graphql.go @@ -0,0 +1,276 @@ +package github + +import ( + "context" + "encoding/json" + "fmt" + "io" + "maps" + "net/http" + "net/url" + "strings" + "time" + + "github.com/google/go-github/v81/github" + "github.com/openshift-pipelines/pipelines-as-code/pkg/apis/pipelinesascode/v1alpha1" + providerMetrics "github.com/openshift-pipelines/pipelines-as-code/pkg/provider/providermetrics" + "go.uber.org/zap" +) + +// graphQLClient handles GraphQL API requests for fetching file contents. +type graphQLClient struct { + httpClient *http.Client + ghClient *github.Client + endpoint string + logger *zap.SugaredLogger + provider *Provider + triggerEvent string + repo *v1alpha1.Repository +} + +// newGraphQLClient creates a new GraphQL client from a GitHub provider. +func newGraphQLClient(p *Provider) (*graphQLClient, error) { + httpClient := p.Client().Client() + if httpClient == nil { + return nil, fmt.Errorf("GitHub client HTTP client is nil") + } + + endpoint, err := buildGraphQLEndpoint(p) + if err != nil { + return nil, fmt.Errorf("failed to build GraphQL endpoint: %w", err) + } + + return &graphQLClient{ + httpClient: httpClient, + ghClient: p.Client(), + endpoint: endpoint, + logger: p.Logger, + provider: p, + triggerEvent: p.triggerEvent, + repo: p.repo, + }, nil +} + +// buildGraphQLEndpoint constructs the GraphQL API endpoint URL from the GitHub client's BaseURL. +func buildGraphQLEndpoint(p *Provider) (string, error) { + baseURL := p.Client().BaseURL.String() + baseURL = strings.TrimSuffix(baseURL, "/") + + // For GitHub.com, use standard GraphQL endpoint + // apiPublicURL has a trailing slash which TrimSuffix above removes, + // so compare directly with the slash-less form. + if baseURL == "https://api.github.com" { + return "https://api.github.com/graphql", nil + } + + // For GHE and test servers, construct GraphQL endpoint from the base URL + // BaseURL could be: + // - https://ghe.example.com/api/v3/ -> https://ghe.example.com/api/graphql + // - http://127.0.0.1:PORT/api/v3/ -> http://127.0.0.1:PORT/api/graphql + parsedURL, err := url.Parse(baseURL) + if err != nil { + return "", fmt.Errorf("failed to parse BaseURL: %w", err) + } + + // Replace /api/v3 with /api/graphql in the path + path := parsedURL.Path + if strings.HasSuffix(path, "/api/v3") || strings.HasSuffix(path, "/api/v3/") { + path = strings.TrimSuffix(path, "/api/v3/") + path = strings.TrimSuffix(path, "/api/v3") + path += "/api/graphql" + } else { + // Fallback: just use the host with /api/graphql + path = "/api/graphql" + } + + parsedURL.Path = path + return parsedURL.String(), nil +} + +// buildGraphQLQuery constructs a GraphQL query string with aliases for batch fetching multiple files. +func buildGraphQLQuery(ref string, paths []string) string { + // Escape ref for GraphQL string (escape quotes and backslashes) + escapedRef := strings.ReplaceAll(ref, `\`, `\\`) + escapedRef = strings.ReplaceAll(escapedRef, `"`, `\"`) + + aliases := make([]string, 0, len(paths)) + for i, path := range paths { + // Escape path for GraphQL string (escape quotes and backslashes) + escapedPath := strings.ReplaceAll(path, `\`, `\\`) + escapedPath = strings.ReplaceAll(escapedPath, `"`, `\"`) + aliases = append(aliases, fmt.Sprintf(` file%d: object(expression: "%s:%s") { + ... on Blob { + text + } + }`, i, escapedRef, escapedPath)) + } + + query := fmt.Sprintf(`query($owner: String!, $name: String!) { + repository(owner: $owner, name: $name) { +%s + } +}`, strings.Join(aliases, "\n")) + + return query +} + +// graphQLResponse represents the structure of a GraphQL API response. +type graphQLResponse struct { + Data struct { + Repository map[string]struct { + Text *string `json:"text"` + } `json:"repository"` + } `json:"data"` + Errors []struct { + Message string `json:"message"` + } `json:"errors,omitempty"` +} + +type rateLimitHeaders struct { + limit string + remaining string + reset string +} + +func getRateLimitHeaders(header http.Header) rateLimitHeaders { + return rateLimitHeaders{ + limit: header.Get("X-RateLimit-Limit"), + remaining: header.Get("X-RateLimit-Remaining"), + reset: header.Get("X-RateLimit-Reset"), + } +} + +// fetchFiles fetches multiple file contents using GraphQL batch queries. +// Returns a map of path -> content. +func (c *graphQLClient) fetchFiles(ctx context.Context, owner, repo, ref string, paths []string) (map[string][]byte, error) { + if len(paths) == 0 { + return make(map[string][]byte), nil + } + + // Limit batch size to avoid query complexity issues + const maxBatchSize = 50 + result := make(map[string][]byte, len(paths)) + for start := 0; start < len(paths); start += maxBatchSize { + end := min(start+maxBatchSize, len(paths)) + batch := paths[start:end] + batchResult, err := c.fetchFilesBatch(ctx, owner, repo, ref, batch) + if err != nil { + return nil, err + } + maps.Copy(result, batchResult) + } + + return result, nil +} + +// fetchFilesBatch fetches multiple file contents in a single GraphQL query. +// Returns a map of path -> content. +func (c *graphQLClient) fetchFilesBatch(ctx context.Context, owner, repo, ref string, paths []string) (map[string][]byte, error) { + if len(paths) == 0 { + return make(map[string][]byte), nil + } + + query := buildGraphQLQuery(ref, paths) + variables := map[string]any{ + "owner": owner, + "name": repo, + } + + requestBody := map[string]any{ + "query": query, + "variables": variables, + } + + req, err := c.ghClient.NewRequest(http.MethodPost, c.endpoint, requestBody) + if err != nil { + return nil, fmt.Errorf("failed to create GraphQL request: %w", err) + } + req = req.WithContext(ctx) + + // Record metrics for GraphQL API call + if c.logger != nil { + providerMetrics.RecordAPIUsage(c.logger, c.provider.providerName, c.triggerEvent, c.repo) + } + + start := time.Now() + resp, err := c.httpClient.Do(req) + duration := time.Since(start) + + if err != nil { + if c.logger != nil { + c.logger.Debugw("GraphQL request failed", + "error", err.Error(), + "duration_ms", duration.Milliseconds(), + ) + } + return nil, fmt.Errorf("GraphQL request failed: %w", err) + } + defer resp.Body.Close() + + rateLimit := getRateLimitHeaders(resp.Header) + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read GraphQL response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + if c.logger != nil { + c.logger.Debugw("GraphQL request returned non-200 status", + "status_code", resp.StatusCode, + "response", string(body), + "rate_limit", rateLimit.limit, + "rate_limit_remaining", rateLimit.remaining, + "rate_limit_reset", rateLimit.reset, + ) + } + return nil, fmt.Errorf("GraphQL request failed with status %d: %s", resp.StatusCode, string(body)) + } + + var graphQLResp graphQLResponse + if err := json.Unmarshal(body, &graphQLResp); err != nil { + return nil, fmt.Errorf("failed to unmarshal GraphQL response: %w", err) + } + + // Check for GraphQL errors + if len(graphQLResp.Errors) > 0 { + errorMsgs := make([]string, len(graphQLResp.Errors)) + for i, e := range graphQLResp.Errors { + errorMsgs[i] = e.Message + } + if c.logger != nil { + c.logger.Debugw("GraphQL returned errors", + "errors", strings.Join(errorMsgs, "; "), + ) + } + return nil, fmt.Errorf("GraphQL errors: %s", strings.Join(errorMsgs, "; ")) + } + + // Extract file contents from response + result := make(map[string][]byte, len(paths)) + for i, path := range paths { + alias := fmt.Sprintf("file%d", i) + blobData, ok := graphQLResp.Data.Repository[alias] + if !ok { + return nil, fmt.Errorf("file %s (alias %s) not found in GraphQL response", path, alias) + } + + if blobData.Text == nil { + return nil, fmt.Errorf("file %s returned null content (may be binary)", path) + } + + result[path] = []byte(*blobData.Text) + } + + if c.logger != nil { + c.logger.Debugw("GraphQL batch fetch completed", + "files_requested", len(paths), + "duration_ms", duration.Milliseconds(), + "rate_limit", rateLimit.limit, + "rate_limit_remaining", rateLimit.remaining, + "rate_limit_reset", rateLimit.reset, + ) + } + + return result, nil +} diff --git a/pkg/provider/github/graphql_test.go b/pkg/provider/github/graphql_test.go new file mode 100644 index 0000000000..46a4a1d935 --- /dev/null +++ b/pkg/provider/github/graphql_test.go @@ -0,0 +1,227 @@ +package github + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/google/go-github/v81/github" + "github.com/openshift-pipelines/pipelines-as-code/pkg/apis/pipelinesascode/v1alpha1" + "go.uber.org/zap" + zapobserver "go.uber.org/zap/zaptest/observer" + "gotest.tools/v3/assert" + "gotest.tools/v3/assert/cmp" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +/* ---------------- helpers ---------------- */ + +func newTestProvider(baseURL string) (*Provider, *zapobserver.ObservedLogs) { + client := github.NewClient(nil) + parsed, _ := url.Parse(baseURL) + client.BaseURL = parsed + + core, observedLogs := zapobserver.New(zap.DebugLevel) + logger := zap.New(core).Sugar() + + return &Provider{ + ghClient: client, + Logger: logger, + providerName: "github", + triggerEvent: "push", + repo: &v1alpha1.Repository{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "test-ns", + Name: "test-repo", + }, + }, + }, observedLogs +} + +func newTestGraphQLClient(t *testing.T, baseURL string) (*graphQLClient, *zapobserver.ObservedLogs) { + t.Helper() + provider, observedLogs := newTestProvider(baseURL) + c, err := newGraphQLClient(provider) + assert.NilError(t, err) + return c, observedLogs +} + +func withServer(t *testing.T, h http.Handler) *httptest.Server { + t.Helper() + s := httptest.NewServer(h) + t.Cleanup(s.Close) + return s +} + +func graphqlOK(repo map[string]any) http.HandlerFunc { + return func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("X-RateLimit-Limit", "5000") + w.Header().Set("X-RateLimit-Remaining", "4999") + w.Header().Set("X-RateLimit-Reset", "1735689600") + _ = json.NewEncoder(w).Encode(map[string]any{ + "data": map[string]any{"repository": repo}, + }) + } +} + +func graphqlStatus(code int) http.HandlerFunc { + return func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("X-RateLimit-Limit", "5000") + w.Header().Set("X-RateLimit-Remaining", "4998") + w.Header().Set("X-RateLimit-Reset", "1735689601") + w.WriteHeader(code) + } +} + +func mockRepo(n int) map[string]any { + repo := make(map[string]any) + for i := range n { + repo[fmt.Sprintf("file%d", i)] = map[string]any{ + "text": fmt.Sprintf("content-%d", i), + } + } + return repo +} + +/* ---------------- tests ---------------- */ + +func TestBuildGraphQLEndpoint(t *testing.T) { + cases := []struct { + name string + base string + want string + }{ + {"public", "https://api.github.com", "https://api.github.com/graphql"}, + {"public slash", "https://api.github.com/", "https://api.github.com/graphql"}, + {"ghe v3", "https://ghe/x/api/v3", "https://ghe/x/api/graphql"}, + {"ghe v3 slash", "https://ghe/x/api/v3/", "https://ghe/x/api/graphql"}, + {"ghe root", "https://ghe", "https://ghe/api/graphql"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + client := github.NewClient(nil) + parsed, _ := url.Parse(tc.base) + client.BaseURL = parsed + + got, err := buildGraphQLEndpoint(&Provider{ghClient: client}) + assert.NilError(t, err) + assert.Check(t, cmp.Equal(tc.want, got)) + }) + } +} + +func TestBuildGraphQLQuery(t *testing.T) { + cases := []struct { + name string + ref string + paths []string + want []string + }{ + { + name: "two files", + ref: "main", + paths: []string{"a", "b"}, + want: []string{"query(", "repository(", "file0:", "file1:"}, + }, + { + name: "no files", + ref: "main", + paths: nil, + want: []string{"query(", "repository("}, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + q := buildGraphQLQuery(tc.ref, tc.paths) + for _, s := range tc.want { + assert.Check(t, strings.Contains(q, s)) + } + }) + } +} + +func TestFetchFilesBatch(t *testing.T) { + cases := []struct { + name string + paths []string + handler http.HandlerFunc + want int + wantErr bool + }{ + { + name: "single batch", + paths: []string{"a", "b"}, + handler: graphqlOK(mockRepo(2)), + want: 2, + }, + { + name: "http error", + paths: []string{"a"}, + handler: graphqlStatus(http.StatusNotFound), + wantErr: true, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/api/graphql", tc.handler) + + srv := withServer(t, mux) + c, observedLogs := newTestGraphQLClient(t, srv.URL+"/api/v3/") + + res, err := c.fetchFilesBatch(context.Background(), "o", "r", "main", tc.paths) + if tc.wantErr { + assert.Assert(t, err != nil) + entries := observedLogs.FilterMessage("GraphQL request returned non-200 status").All() + assert.Check(t, cmp.Len(entries, 1)) + assert.Check(t, cmp.Equal(entries[0].ContextMap()["rate_limit"], "5000")) + assert.Check(t, cmp.Equal(entries[0].ContextMap()["rate_limit_remaining"], "4998")) + assert.Check(t, cmp.Equal(entries[0].ContextMap()["rate_limit_reset"], "1735689601")) + return + } + + assert.NilError(t, err) + assert.Check(t, cmp.Len(res, tc.want)) + entries := observedLogs.FilterMessage("GraphQL batch fetch completed").All() + assert.Check(t, cmp.Len(entries, 1)) + assert.Check(t, cmp.Equal(entries[0].ContextMap()["files_requested"], int64(len(tc.paths)))) + assert.Check(t, cmp.Equal(entries[0].ContextMap()["rate_limit"], "5000")) + assert.Check(t, cmp.Equal(entries[0].ContextMap()["rate_limit_remaining"], "4999")) + assert.Check(t, cmp.Equal(entries[0].ContextMap()["rate_limit_reset"], "1735689600")) + _, exists := entries[0].ContextMap()["files_fetched"] + assert.Check(t, !exists) + }) + } +} + +func TestFetchFilesBatchPreservesGitHubHeaders(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/api/graphql", func(w http.ResponseWriter, r *http.Request) { + assert.Assert(t, r.Header.Get("User-Agent") != "") + assert.Assert(t, r.Header.Get("X-GitHub-Api-Version") != "") + assert.Check(t, cmp.Equal("application/json", r.Header.Get("Content-Type"))) + + _ = json.NewEncoder(w).Encode(map[string]any{ + "data": map[string]any{ + "repository": map[string]any{ + "file0": map[string]any{"text": "content-0"}, + }, + }, + }) + }) + + srv := withServer(t, mux) + c, _ := newTestGraphQLClient(t, srv.URL+"/api/v3/") + + res, err := c.fetchFilesBatch(context.Background(), "o", "r", "main", []string{"a"}) + assert.NilError(t, err) + assert.Check(t, cmp.Len(res, 1)) +} diff --git a/pkg/test/github/github.go b/pkg/test/github/github.go index fb2fcd493f..a638981714 100644 --- a/pkg/test/github/github.go +++ b/pkg/test/github/github.go @@ -34,6 +34,11 @@ func SetupGH() (client *github.Client, mux *http.ServeMux, serverURL string, tea // when there's a non-empty base URL path. So, use that. See issue #752. apiHandler := http.NewServeMux() apiHandler.Handle(githubBaseURLPath+"/", http.StripPrefix(githubBaseURLPath, mux)) + // GraphQL endpoint is at /api/graphql (not under /api/v3) + apiHandler.HandleFunc("/api/graphql", func(w http.ResponseWriter, r *http.Request) { + // Forward to mux for GraphQL handling + mux.ServeHTTP(w, r) + }) apiHandler.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { fmt.Fprintln(os.Stderr, "FAIL: Client.BaseURL path prefix is not preserved in the request URL:") fmt.Fprintln(os.Stderr) @@ -57,6 +62,12 @@ func SetupGH() (client *github.Client, mux *http.ServeMux, serverURL string, tea return client, mux, server.URL, server.Close } +// graphQLFileMapType is used to store files for GraphQL handler lookup. +type graphQLFileMapType map[string]struct { + sha, name string + isdir bool +} + // SetupGitTree Take a dir and fake a full GitTree GitHub api calls reply recursively over a muxer. func SetupGitTree(t *testing.T, mux *http.ServeMux, dir string, event *info.Event, recursive bool) { type file struct { @@ -64,6 +75,7 @@ func SetupGitTree(t *testing.T, mux *http.ServeMux, dir string, event *info.Even isdir bool } files := []file{} + if recursive { err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { sha := fmt.Sprintf("%x", sha256.Sum256([]byte(path))) @@ -142,4 +154,121 @@ func SetupGitTree(t *testing.T, mux *http.ServeMux, dir string, event *info.Even assert.NilError(t, err) fmt.Fprint(rw, string(b)) }) + + // Setup GraphQL endpoint handler for batch file fetching (only once per mux) + // Only register GraphQL handler once (at the root level, when recursive=false) + if !recursive { + // Walk the entire directory tree to collect all files for the GraphQL handler + allFiles := make(graphQLFileMapType) + err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { + if err == nil && !info.IsDir() && path != dir { + relPath := strings.TrimPrefix(path, dir+"/") + allFiles[relPath] = struct { + sha, name string + isdir bool + }{ + sha: fmt.Sprintf("%x", sha256.Sum256([]byte(path))), + name: path, + isdir: false, + } + } + return nil + }) + assert.NilError(t, err) + + // Register handler once with all collected files (only if we have files) + if len(allFiles) > 0 { + mux.HandleFunc("/api/graphql", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var graphQLReq struct { + Query string `json:"query"` + Variables map[string]any `json:"variables"` + } + if err := json.NewDecoder(r.Body).Decode(&graphQLReq); err != nil { + http.Error(w, fmt.Sprintf("Invalid GraphQL request: %v", err), http.StatusBadRequest) + return + } + + // Build response with file contents + repositoryData := make(map[string]any) + + // Parse query to extract aliases and paths + queryLines := strings.SplitSeq(graphQLReq.Query, "\n") + for line := range queryLines { + line = strings.TrimSpace(line) + if strings.Contains(line, ": object(expression:") && strings.Contains(line, "file") { + // Extract alias (e.g., "file0") + aliasEnd := strings.Index(line, ":") + if aliasEnd <= 0 { + continue + } + alias := strings.TrimSpace(line[:aliasEnd]) + + // Extract expression value between quotes: "ref:path" + exprStart := strings.Index(line, `expression: "`) + if exprStart < 0 { + continue + } + exprStart += len(`expression: "`) + exprEnd := strings.Index(line[exprStart:], `"`) + if exprEnd < 0 { + continue + } + expr := line[exprStart : exprStart+exprEnd] + // Unescape the expression (handle \" and \\) + expr = strings.ReplaceAll(expr, `\"`, `"`) + expr = strings.ReplaceAll(expr, `\\`, `\`) + // Split "ref:path" and take path + parts := strings.SplitN(expr, ":", 2) + if len(parts) != 2 { + continue + } + path := parts[1] + + // Look up file by path in the file map + var foundFile struct { + sha, name string + isdir bool + } + var found bool + if f, ok := allFiles[path]; ok { + foundFile = f + found = true + } else { + // Try to find by matching the end of the path or other variations + for k, f := range allFiles { + if strings.HasSuffix(k, "/"+path) || k == path { + foundFile = f + found = true + break + } + } + } + + if found { + content, err := os.ReadFile(foundFile.name) + if err == nil { + repositoryData[alias] = map[string]any{ + "text": string(content), + } + } + } + } + } + + responseData := map[string]any{ + "data": map[string]any{ + "repository": repositoryData, + }, + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(responseData) + }) + } + } }