From 434d4491258b7d4e79d77ba912bc89a932db87df Mon Sep 17 00:00:00 2001 From: Chmouel Boudjnah Date: Wed, 28 Jan 2026 18:28:29 +0100 Subject: [PATCH] feat: Implement GraphQL batch fetching for .tekton files This commit introduces the use of GitHub's GraphQL API for fetching multiple `.tekton` directory files simultaneously. Previously, each file was fetched individually using REST API calls. The new implementation leverages GraphQL batching to retrieve all necessary YAML files in a single request, significantly reducing the number of API calls and improving performance. This change also necessitates new test cases to cover the GraphQL client functionality and ensures compatibility with GitHub Enterprise instances. Signed-off-by: Chmouel Boudjnah --- pkg/provider/github/github.go | 69 +++++-- pkg/provider/github/github_test.go | 243 ++++++++++++++++++++++-- pkg/provider/github/graphql.go | 276 ++++++++++++++++++++++++++++ pkg/provider/github/graphql_test.go | 227 +++++++++++++++++++++++ pkg/test/github/github.go | 129 +++++++++++++ 5 files changed, 917 insertions(+), 27 deletions(-) create mode 100644 pkg/provider/github/graphql.go create mode 100644 pkg/provider/github/graphql_test.go diff --git a/pkg/provider/github/github.go b/pkg/provider/github/github.go index f482536dff..2928393b07 100644 --- a/pkg/provider/github/github.go +++ b/pkg/provider/github/github.go @@ -338,8 +338,17 @@ func (v *Provider) GetTektonDir(ctx context.Context, runevent *info.Event, path, // default set provenance from the SHA revision := runevent.SHA if provenance == "default_branch" { - revision = runevent.DefaultBranch v.Logger.Infof("Using PipelineRun definition from default_branch: %s", runevent.DefaultBranch) + branch, _, err := wrapAPI(v, "get_default_branch", func() (*github.Branch, *github.Response, error) { + return v.Client().Repositories.GetBranch(ctx, runevent.Organization, runevent.Repository, runevent.DefaultBranch, 1) + }) + if err != nil { + return "", err + } + revision = branch.GetCommit().GetSHA() + if revision == "" { + return "", fmt.Errorf("default_branch %s did not resolve to a commit SHA", runevent.DefaultBranch) + } } else { prInfo := "" if runevent.TriggerTarget == triggertype.PullRequest { @@ -379,7 +388,7 @@ func (v *Provider) GetTektonDir(ctx context.Context, runevent *info.Event, path, if err != nil { return "", err } - return v.concatAllYamlFiles(ctx, tektonDirObjects.Entries, runevent) + return v.concatAllYamlFiles(ctx, tektonDirObjects.Entries, runevent, path, revision) } // GetCommitInfo get info (url and title) on a commit in runevent, this needs to @@ -473,25 +482,53 @@ func (v *Provider) GetFileInsideRepo(ctx context.Context, runevent *info.Event, } // concatAllYamlFiles concat all yaml files from a directory as one big multi document yaml string. -func (v *Provider) concatAllYamlFiles(ctx context.Context, objects []*github.TreeEntry, runevent *info.Event) (string, error) { - var allTemplates string - +// tektonDirPath is the path to the .tekton directory (e.g., ".tekton") used to construct full paths. +func (v *Provider) concatAllYamlFiles(ctx context.Context, objects []*github.TreeEntry, runevent *info.Event, tektonDirPath, ref string) (string, error) { + // Collect all YAML file paths and preserve order + // Tree entries have paths relative to the .tekton directory, so prepend tektonDirPath + var yamlFiles []string for _, value := range objects { if strings.HasSuffix(value.GetPath(), ".yaml") || strings.HasSuffix(value.GetPath(), ".yml") { - data, err := v.getObject(ctx, value.GetSHA(), runevent) - if err != nil { - return "", err - } - if err := provider.ValidateYaml(data, value.GetPath()); err != nil { - return "", err - } - if allTemplates != "" && !strings.HasPrefix(string(data), "---") { - allTemplates += "---" - } - allTemplates += "\n" + string(data) + "\n" + // Construct full path from repo root for GraphQL query + fullPath := tektonDirPath + "/" + value.GetPath() + yamlFiles = append(yamlFiles, fullPath) } } + + if len(yamlFiles) == 0 { + return "", nil + } + + client, err := newGraphQLClient(v) + if err != nil { + return "", fmt.Errorf("failed to create GraphQL client: %w", err) + } + graphQLResults, err := client.fetchFiles(ctx, runevent.Organization, runevent.Repository, ref, yamlFiles) + if err != nil { + return "", fmt.Errorf("failed to fetch .tekton files via GraphQL: %w", err) + } + + var allTemplates string + + for _, path := range yamlFiles { + content, ok := graphQLResults[path] + if !ok { + return "", fmt.Errorf("file %s not found in GraphQL response", path) + } + + // it used to be like that (stripped prefix) before we moved to GraphQL so + // let's keep it that way. + relativePath := strings.TrimPrefix(path, tektonDirPath+"/") + if err := provider.ValidateYaml(content, relativePath); err != nil { + return "", err + } + if allTemplates != "" && !strings.HasPrefix(string(content), "---") { + allTemplates += "---" + } + allTemplates += "\n" + string(content) + "\n" + } + return allTemplates, nil } diff --git a/pkg/provider/github/github_test.go b/pkg/provider/github/github_test.go index 511194f90b..fec26d7421 100644 --- a/pkg/provider/github/github_test.go +++ b/pkg/provider/github/github_test.go @@ -253,8 +253,8 @@ func TestGetTektonDir(t *testing.T) { filterMessageSnippet: "Using PipelineRun definition from source pull_request tekton/cat#0", // 1. Get Repo root objects // 2. Get Tekton Dir objects - // 3/4. Get object content for each object (pipelinerun.yaml, pipeline.yaml) - expectedGHApiCalls: 4, + // 3. GraphQL batch fetch for 2 files (replaces 2 REST calls) + expectedGHApiCalls: 3, }, { name: "test no subtree on push", @@ -269,8 +269,8 @@ func TestGetTektonDir(t *testing.T) { filterMessageSnippet: "Using PipelineRun definition from source push", // 1. Get Repo root objects // 2. Get Tekton Dir objects - // 3/4. Get object content for each object (pipelinerun.yaml, pipeline.yaml) - expectedGHApiCalls: 4, + // 3. GraphQL batch fetch for 2 files (replaces 2 REST calls) + expectedGHApiCalls: 3, }, { name: "test provenance default_branch ", @@ -283,9 +283,10 @@ func TestGetTektonDir(t *testing.T) { treepath: "testdata/tree/defaultbranch", provenance: "default_branch", filterMessageSnippet: "Using PipelineRun definition from default_branch: main", - // 1. Get Repo root objects - // 2. Get Tekton Dir objects - // 3/4. Get object content for each object (pipelinerun.yaml, pipeline.yaml) + // 1. Resolve default branch to a commit SHA + // 2. Get Repo root objects + // 3. Get Tekton Dir objects + // 4. GraphQL batch fetch for 2 files expectedGHApiCalls: 4, }, { @@ -371,11 +372,21 @@ func TestGetTektonDir(t *testing.T) { } }() + shaDir := fmt.Sprintf("%x", sha256.Sum256([]byte(tt.treepath))) + tt.event.SHA = shaDir if tt.provenance == "default_branch" { - tt.event.SHA = tt.event.DefaultBranch - } else { - shaDir := fmt.Sprintf("%x", sha256.Sum256([]byte(tt.treepath))) - tt.event.SHA = shaDir + mux.HandleFunc(fmt.Sprintf("/repos/%s/%s/branches/%s", + tt.event.Organization, tt.event.Repository, tt.event.DefaultBranch), + func(rw http.ResponseWriter, _ *http.Request) { + branch := &github.Branch{ + Name: github.Ptr(tt.event.DefaultBranch), + Commit: &github.RepositoryCommit{ + SHA: github.Ptr(shaDir), + }, + } + b, _ := json.Marshal(branch) + fmt.Fprint(rw, string(b)) + }) } ghtesthelper.SetupGitTree(t, mux, tt.treepath, tt.event, false) @@ -403,6 +414,216 @@ func TestGetTektonDir(t *testing.T) { } } +func TestGetTektonDir_GraphQLBatchFetch(t *testing.T) { + // Test that GraphQL is used for batch fetching multiple files + metricsutils.ResetMetrics() + observer, exporter := zapobserver.New(zap.DebugLevel) + fakelogger := zap.New(observer).Sugar() + ctx, _ := rtesting.SetupFakeContext(t) + fakeclient, mux, _, teardown := ghtesthelper.SetupGH() + defer teardown() + gvcs := Provider{ + ghClient: fakeclient, + providerName: "github", + Logger: fakelogger, + } + + event := &info.Event{ + Organization: "tekton", + Repository: "cat", + SHA: "123", + TriggerTarget: triggertype.PullRequest, + } + shaDir := fmt.Sprintf("%x", sha256.Sum256([]byte("testdata/tree/simple"))) + event.SHA = shaDir + ghtesthelper.SetupGitTree(t, mux, "testdata/tree/simple", event, false) + + got, err := gvcs.GetTektonDir(ctx, event, ".tekton", "") + assert.NilError(t, err) + assert.Assert(t, strings.Contains(got, "PipelineRun"), "expected PipelineRun in output, got %s", got) + + // Verify GraphQL was used (check logs) + graphQLLogs := exporter.FilterMessageSnippet("GraphQL batch fetch") + assert.Assert(t, graphQLLogs.Len() > 0, "expected GraphQL batch fetch log message") + + // Verify reduced API calls: 1 root tree + 1 tekton tree + 1 GraphQL = 3 (instead of 4) + metricstest.CheckCountData( + t, + "pipelines_as_code_git_provider_api_request_count", + map[string]string{"provider": "github"}, + 3, + ) +} + +func TestGetTektonDir_GraphQLError(t *testing.T) { + // Test that GraphQL errors are properly returned + metricsutils.ResetMetrics() + observer, _ := zapobserver.New(zap.DebugLevel) + fakelogger := zap.New(observer).Sugar() + ctx, _ := rtesting.SetupFakeContext(t) + fakeclient, mux, _, teardown := ghtesthelper.SetupGH() + defer teardown() + + // Register error handler on this mux (SetupGitTree is not called in this test, + // so this is the only /api/graphql handler) + mux.HandleFunc("/api/graphql", func(w http.ResponseWriter, _ *http.Request) { + http.Error(w, "GraphQL endpoint not available", http.StatusNotFound) + }) + + gvcs := Provider{ + ghClient: fakeclient, + providerName: "github", + Logger: fakelogger, + } + + event := &info.Event{ + Organization: "tekton", + Repository: "cat", + SHA: "123", + TriggerTarget: triggertype.PullRequest, + } + shaDir := fmt.Sprintf("%x", sha256.Sum256([]byte("testdata/tree/simple"))) + event.SHA = shaDir + + // Setup tree endpoints manually (skip SetupGitTree to avoid GraphQL handler registration) + // Set up root tree + mux.HandleFunc(fmt.Sprintf("/repos/%v/%v/git/trees/%v", event.Organization, event.Repository, event.SHA), + func(rw http.ResponseWriter, _ *http.Request) { + tree := &github.Tree{ + SHA: &event.SHA, + Entries: []*github.TreeEntry{ + { + Path: github.Ptr(".tekton"), + Type: github.Ptr("tree"), + SHA: github.Ptr("tektondirsha"), + }, + }, + } + b, _ := json.Marshal(tree) + fmt.Fprint(rw, string(b)) + }) + + // Set up .tekton directory tree + tektonDirSha := "tektondirsha" + mux.HandleFunc(fmt.Sprintf("/repos/%v/%v/git/trees/%v", event.Organization, event.Repository, tektonDirSha), + func(rw http.ResponseWriter, _ *http.Request) { + tree := &github.Tree{ + SHA: &tektonDirSha, + Entries: []*github.TreeEntry{ + { + Path: github.Ptr("pipeline.yaml"), + Type: github.Ptr("blob"), + SHA: github.Ptr("pipelinesha"), + }, + { + Path: github.Ptr("pipelinerun.yaml"), + Type: github.Ptr("blob"), + SHA: github.Ptr("pipelinerunsha"), + }, + }, + } + b, _ := json.Marshal(tree) + fmt.Fprint(rw, string(b)) + }) + + _, err := gvcs.GetTektonDir(ctx, event, ".tekton", "") + assert.ErrorContains(t, err, "failed to fetch .tekton files via GraphQL") +} + +func TestGetTektonDir_DefaultBranchUsesResolvedSHAForGraphQL(t *testing.T) { + metricsutils.ResetMetrics() + observer, _ := zapobserver.New(zap.DebugLevel) + fakelogger := zap.New(observer).Sugar() + ctx, _ := rtesting.SetupFakeContext(t) + fakeclient, mux, _, teardown := ghtesthelper.SetupGH() + defer teardown() + + gvcs := Provider{ + ghClient: fakeclient, + providerName: "github", + Logger: fakelogger, + } + + event := &info.Event{ + Organization: "tekton", + Repository: "cat", + DefaultBranch: "main", + } + resolvedSHA := "resolved-default-branch-sha" + tektonDirSHA := "tektondirsha" + + mux.HandleFunc("/repos/tekton/cat/branches/main", func(rw http.ResponseWriter, _ *http.Request) { + branch := &github.Branch{ + Name: github.Ptr("main"), + Commit: &github.RepositoryCommit{ + SHA: github.Ptr(resolvedSHA), + }, + } + b, _ := json.Marshal(branch) + fmt.Fprint(rw, string(b)) + }) + mux.HandleFunc("/repos/tekton/cat/git/trees/"+resolvedSHA, func(rw http.ResponseWriter, _ *http.Request) { + tree := &github.Tree{ + SHA: github.Ptr(resolvedSHA), + Entries: []*github.TreeEntry{ + { + Path: github.Ptr(".tekton"), + Type: github.Ptr("tree"), + SHA: github.Ptr(tektonDirSHA), + }, + }, + } + b, _ := json.Marshal(tree) + fmt.Fprint(rw, string(b)) + }) + mux.HandleFunc("/repos/tekton/cat/git/trees/"+tektonDirSHA, func(rw http.ResponseWriter, _ *http.Request) { + tree := &github.Tree{ + SHA: github.Ptr(tektonDirSHA), + Entries: []*github.TreeEntry{ + { + Path: github.Ptr("pipeline.yaml"), + Type: github.Ptr("blob"), + SHA: github.Ptr("pipeline-sha"), + }, + { + Path: github.Ptr("pipelinerun.yaml"), + Type: github.Ptr("blob"), + SHA: github.Ptr("pipelinerun-sha"), + }, + }, + } + b, _ := json.Marshal(tree) + fmt.Fprint(rw, string(b)) + }) + mux.HandleFunc("/api/graphql", func(w http.ResponseWriter, r *http.Request) { + var graphQLReq struct { + Query string `json:"query"` + } + assert.NilError(t, json.NewDecoder(r.Body).Decode(&graphQLReq)) + assert.Assert(t, strings.Contains(graphQLReq.Query, resolvedSHA+":.tekton/pipeline.yaml"), graphQLReq.Query) + assert.Assert(t, !strings.Contains(graphQLReq.Query, `main:.tekton/pipeline.yaml`), graphQLReq.Query) + + _ = json.NewEncoder(w).Encode(map[string]any{ + "data": map[string]any{ + "repository": map[string]any{ + "file0": map[string]any{"text": "kind: Pipeline\nmetadata:\n name: pipeline\n"}, + "file1": map[string]any{"text": "kind: PipelineRun\nmetadata:\n name: run\n"}, + }, + }, + }) + }) + + got, err := gvcs.GetTektonDir(ctx, event, ".tekton", "default_branch") + assert.NilError(t, err) + assert.Assert(t, strings.Contains(got, "PipelineRun")) + metricstest.CheckCountData( + t, + "pipelines_as_code_git_provider_api_request_count", + map[string]string{"provider": "github"}, + 4, + ) +} + func TestGetFileInsideRepo(t *testing.T) { testGetTektonDir := []struct { name string diff --git a/pkg/provider/github/graphql.go b/pkg/provider/github/graphql.go new file mode 100644 index 0000000000..e163251cf4 --- /dev/null +++ b/pkg/provider/github/graphql.go @@ -0,0 +1,276 @@ +package github + +import ( + "context" + "encoding/json" + "fmt" + "io" + "maps" + "net/http" + "net/url" + "strings" + "time" + + "github.com/google/go-github/v81/github" + "github.com/openshift-pipelines/pipelines-as-code/pkg/apis/pipelinesascode/v1alpha1" + providerMetrics "github.com/openshift-pipelines/pipelines-as-code/pkg/provider/providermetrics" + "go.uber.org/zap" +) + +// graphQLClient handles GraphQL API requests for fetching file contents. +type graphQLClient struct { + httpClient *http.Client + ghClient *github.Client + endpoint string + logger *zap.SugaredLogger + provider *Provider + triggerEvent string + repo *v1alpha1.Repository +} + +// newGraphQLClient creates a new GraphQL client from a GitHub provider. +func newGraphQLClient(p *Provider) (*graphQLClient, error) { + httpClient := p.Client().Client() + if httpClient == nil { + return nil, fmt.Errorf("GitHub client HTTP client is nil") + } + + endpoint, err := buildGraphQLEndpoint(p) + if err != nil { + return nil, fmt.Errorf("failed to build GraphQL endpoint: %w", err) + } + + return &graphQLClient{ + httpClient: httpClient, + ghClient: p.Client(), + endpoint: endpoint, + logger: p.Logger, + provider: p, + triggerEvent: p.triggerEvent, + repo: p.repo, + }, nil +} + +// buildGraphQLEndpoint constructs the GraphQL API endpoint URL from the GitHub client's BaseURL. +func buildGraphQLEndpoint(p *Provider) (string, error) { + baseURL := p.Client().BaseURL.String() + baseURL = strings.TrimSuffix(baseURL, "/") + + // For GitHub.com, use standard GraphQL endpoint + // apiPublicURL has a trailing slash which TrimSuffix above removes, + // so compare directly with the slash-less form. + if baseURL == "https://api.github.com" { + return "https://api.github.com/graphql", nil + } + + // For GHE and test servers, construct GraphQL endpoint from the base URL + // BaseURL could be: + // - https://ghe.example.com/api/v3/ -> https://ghe.example.com/api/graphql + // - http://127.0.0.1:PORT/api/v3/ -> http://127.0.0.1:PORT/api/graphql + parsedURL, err := url.Parse(baseURL) + if err != nil { + return "", fmt.Errorf("failed to parse BaseURL: %w", err) + } + + // Replace /api/v3 with /api/graphql in the path + path := parsedURL.Path + if strings.HasSuffix(path, "/api/v3") || strings.HasSuffix(path, "/api/v3/") { + path = strings.TrimSuffix(path, "/api/v3/") + path = strings.TrimSuffix(path, "/api/v3") + path += "/api/graphql" + } else { + // Fallback: just use the host with /api/graphql + path = "/api/graphql" + } + + parsedURL.Path = path + return parsedURL.String(), nil +} + +// buildGraphQLQuery constructs a GraphQL query string with aliases for batch fetching multiple files. +func buildGraphQLQuery(ref string, paths []string) string { + // Escape ref for GraphQL string (escape quotes and backslashes) + escapedRef := strings.ReplaceAll(ref, `\`, `\\`) + escapedRef = strings.ReplaceAll(escapedRef, `"`, `\"`) + + aliases := make([]string, 0, len(paths)) + for i, path := range paths { + // Escape path for GraphQL string (escape quotes and backslashes) + escapedPath := strings.ReplaceAll(path, `\`, `\\`) + escapedPath = strings.ReplaceAll(escapedPath, `"`, `\"`) + aliases = append(aliases, fmt.Sprintf(` file%d: object(expression: "%s:%s") { + ... on Blob { + text + } + }`, i, escapedRef, escapedPath)) + } + + query := fmt.Sprintf(`query($owner: String!, $name: String!) { + repository(owner: $owner, name: $name) { +%s + } +}`, strings.Join(aliases, "\n")) + + return query +} + +// graphQLResponse represents the structure of a GraphQL API response. +type graphQLResponse struct { + Data struct { + Repository map[string]struct { + Text *string `json:"text"` + } `json:"repository"` + } `json:"data"` + Errors []struct { + Message string `json:"message"` + } `json:"errors,omitempty"` +} + +type rateLimitHeaders struct { + limit string + remaining string + reset string +} + +func getRateLimitHeaders(header http.Header) rateLimitHeaders { + return rateLimitHeaders{ + limit: header.Get("X-RateLimit-Limit"), + remaining: header.Get("X-RateLimit-Remaining"), + reset: header.Get("X-RateLimit-Reset"), + } +} + +// fetchFiles fetches multiple file contents using GraphQL batch queries. +// Returns a map of path -> content. +func (c *graphQLClient) fetchFiles(ctx context.Context, owner, repo, ref string, paths []string) (map[string][]byte, error) { + if len(paths) == 0 { + return make(map[string][]byte), nil + } + + // Limit batch size to avoid query complexity issues + const maxBatchSize = 50 + result := make(map[string][]byte, len(paths)) + for start := 0; start < len(paths); start += maxBatchSize { + end := min(start+maxBatchSize, len(paths)) + batch := paths[start:end] + batchResult, err := c.fetchFilesBatch(ctx, owner, repo, ref, batch) + if err != nil { + return nil, err + } + maps.Copy(result, batchResult) + } + + return result, nil +} + +// fetchFilesBatch fetches multiple file contents in a single GraphQL query. +// Returns a map of path -> content. +func (c *graphQLClient) fetchFilesBatch(ctx context.Context, owner, repo, ref string, paths []string) (map[string][]byte, error) { + if len(paths) == 0 { + return make(map[string][]byte), nil + } + + query := buildGraphQLQuery(ref, paths) + variables := map[string]any{ + "owner": owner, + "name": repo, + } + + requestBody := map[string]any{ + "query": query, + "variables": variables, + } + + req, err := c.ghClient.NewRequest(http.MethodPost, c.endpoint, requestBody) + if err != nil { + return nil, fmt.Errorf("failed to create GraphQL request: %w", err) + } + req = req.WithContext(ctx) + + // Record metrics for GraphQL API call + if c.logger != nil { + providerMetrics.RecordAPIUsage(c.logger, c.provider.providerName, c.triggerEvent, c.repo) + } + + start := time.Now() + resp, err := c.httpClient.Do(req) + duration := time.Since(start) + + if err != nil { + if c.logger != nil { + c.logger.Debugw("GraphQL request failed", + "error", err.Error(), + "duration_ms", duration.Milliseconds(), + ) + } + return nil, fmt.Errorf("GraphQL request failed: %w", err) + } + defer resp.Body.Close() + + rateLimit := getRateLimitHeaders(resp.Header) + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read GraphQL response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + if c.logger != nil { + c.logger.Debugw("GraphQL request returned non-200 status", + "status_code", resp.StatusCode, + "response", string(body), + "rate_limit", rateLimit.limit, + "rate_limit_remaining", rateLimit.remaining, + "rate_limit_reset", rateLimit.reset, + ) + } + return nil, fmt.Errorf("GraphQL request failed with status %d: %s", resp.StatusCode, string(body)) + } + + var graphQLResp graphQLResponse + if err := json.Unmarshal(body, &graphQLResp); err != nil { + return nil, fmt.Errorf("failed to unmarshal GraphQL response: %w", err) + } + + // Check for GraphQL errors + if len(graphQLResp.Errors) > 0 { + errorMsgs := make([]string, len(graphQLResp.Errors)) + for i, e := range graphQLResp.Errors { + errorMsgs[i] = e.Message + } + if c.logger != nil { + c.logger.Debugw("GraphQL returned errors", + "errors", strings.Join(errorMsgs, "; "), + ) + } + return nil, fmt.Errorf("GraphQL errors: %s", strings.Join(errorMsgs, "; ")) + } + + // Extract file contents from response + result := make(map[string][]byte, len(paths)) + for i, path := range paths { + alias := fmt.Sprintf("file%d", i) + blobData, ok := graphQLResp.Data.Repository[alias] + if !ok { + return nil, fmt.Errorf("file %s (alias %s) not found in GraphQL response", path, alias) + } + + if blobData.Text == nil { + return nil, fmt.Errorf("file %s returned null content (may be binary)", path) + } + + result[path] = []byte(*blobData.Text) + } + + if c.logger != nil { + c.logger.Debugw("GraphQL batch fetch completed", + "files_requested", len(paths), + "duration_ms", duration.Milliseconds(), + "rate_limit", rateLimit.limit, + "rate_limit_remaining", rateLimit.remaining, + "rate_limit_reset", rateLimit.reset, + ) + } + + return result, nil +} diff --git a/pkg/provider/github/graphql_test.go b/pkg/provider/github/graphql_test.go new file mode 100644 index 0000000000..46a4a1d935 --- /dev/null +++ b/pkg/provider/github/graphql_test.go @@ -0,0 +1,227 @@ +package github + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/google/go-github/v81/github" + "github.com/openshift-pipelines/pipelines-as-code/pkg/apis/pipelinesascode/v1alpha1" + "go.uber.org/zap" + zapobserver "go.uber.org/zap/zaptest/observer" + "gotest.tools/v3/assert" + "gotest.tools/v3/assert/cmp" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +/* ---------------- helpers ---------------- */ + +func newTestProvider(baseURL string) (*Provider, *zapobserver.ObservedLogs) { + client := github.NewClient(nil) + parsed, _ := url.Parse(baseURL) + client.BaseURL = parsed + + core, observedLogs := zapobserver.New(zap.DebugLevel) + logger := zap.New(core).Sugar() + + return &Provider{ + ghClient: client, + Logger: logger, + providerName: "github", + triggerEvent: "push", + repo: &v1alpha1.Repository{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "test-ns", + Name: "test-repo", + }, + }, + }, observedLogs +} + +func newTestGraphQLClient(t *testing.T, baseURL string) (*graphQLClient, *zapobserver.ObservedLogs) { + t.Helper() + provider, observedLogs := newTestProvider(baseURL) + c, err := newGraphQLClient(provider) + assert.NilError(t, err) + return c, observedLogs +} + +func withServer(t *testing.T, h http.Handler) *httptest.Server { + t.Helper() + s := httptest.NewServer(h) + t.Cleanup(s.Close) + return s +} + +func graphqlOK(repo map[string]any) http.HandlerFunc { + return func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("X-RateLimit-Limit", "5000") + w.Header().Set("X-RateLimit-Remaining", "4999") + w.Header().Set("X-RateLimit-Reset", "1735689600") + _ = json.NewEncoder(w).Encode(map[string]any{ + "data": map[string]any{"repository": repo}, + }) + } +} + +func graphqlStatus(code int) http.HandlerFunc { + return func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("X-RateLimit-Limit", "5000") + w.Header().Set("X-RateLimit-Remaining", "4998") + w.Header().Set("X-RateLimit-Reset", "1735689601") + w.WriteHeader(code) + } +} + +func mockRepo(n int) map[string]any { + repo := make(map[string]any) + for i := range n { + repo[fmt.Sprintf("file%d", i)] = map[string]any{ + "text": fmt.Sprintf("content-%d", i), + } + } + return repo +} + +/* ---------------- tests ---------------- */ + +func TestBuildGraphQLEndpoint(t *testing.T) { + cases := []struct { + name string + base string + want string + }{ + {"public", "https://api.github.com", "https://api.github.com/graphql"}, + {"public slash", "https://api.github.com/", "https://api.github.com/graphql"}, + {"ghe v3", "https://ghe/x/api/v3", "https://ghe/x/api/graphql"}, + {"ghe v3 slash", "https://ghe/x/api/v3/", "https://ghe/x/api/graphql"}, + {"ghe root", "https://ghe", "https://ghe/api/graphql"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + client := github.NewClient(nil) + parsed, _ := url.Parse(tc.base) + client.BaseURL = parsed + + got, err := buildGraphQLEndpoint(&Provider{ghClient: client}) + assert.NilError(t, err) + assert.Check(t, cmp.Equal(tc.want, got)) + }) + } +} + +func TestBuildGraphQLQuery(t *testing.T) { + cases := []struct { + name string + ref string + paths []string + want []string + }{ + { + name: "two files", + ref: "main", + paths: []string{"a", "b"}, + want: []string{"query(", "repository(", "file0:", "file1:"}, + }, + { + name: "no files", + ref: "main", + paths: nil, + want: []string{"query(", "repository("}, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + q := buildGraphQLQuery(tc.ref, tc.paths) + for _, s := range tc.want { + assert.Check(t, strings.Contains(q, s)) + } + }) + } +} + +func TestFetchFilesBatch(t *testing.T) { + cases := []struct { + name string + paths []string + handler http.HandlerFunc + want int + wantErr bool + }{ + { + name: "single batch", + paths: []string{"a", "b"}, + handler: graphqlOK(mockRepo(2)), + want: 2, + }, + { + name: "http error", + paths: []string{"a"}, + handler: graphqlStatus(http.StatusNotFound), + wantErr: true, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/api/graphql", tc.handler) + + srv := withServer(t, mux) + c, observedLogs := newTestGraphQLClient(t, srv.URL+"/api/v3/") + + res, err := c.fetchFilesBatch(context.Background(), "o", "r", "main", tc.paths) + if tc.wantErr { + assert.Assert(t, err != nil) + entries := observedLogs.FilterMessage("GraphQL request returned non-200 status").All() + assert.Check(t, cmp.Len(entries, 1)) + assert.Check(t, cmp.Equal(entries[0].ContextMap()["rate_limit"], "5000")) + assert.Check(t, cmp.Equal(entries[0].ContextMap()["rate_limit_remaining"], "4998")) + assert.Check(t, cmp.Equal(entries[0].ContextMap()["rate_limit_reset"], "1735689601")) + return + } + + assert.NilError(t, err) + assert.Check(t, cmp.Len(res, tc.want)) + entries := observedLogs.FilterMessage("GraphQL batch fetch completed").All() + assert.Check(t, cmp.Len(entries, 1)) + assert.Check(t, cmp.Equal(entries[0].ContextMap()["files_requested"], int64(len(tc.paths)))) + assert.Check(t, cmp.Equal(entries[0].ContextMap()["rate_limit"], "5000")) + assert.Check(t, cmp.Equal(entries[0].ContextMap()["rate_limit_remaining"], "4999")) + assert.Check(t, cmp.Equal(entries[0].ContextMap()["rate_limit_reset"], "1735689600")) + _, exists := entries[0].ContextMap()["files_fetched"] + assert.Check(t, !exists) + }) + } +} + +func TestFetchFilesBatchPreservesGitHubHeaders(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/api/graphql", func(w http.ResponseWriter, r *http.Request) { + assert.Assert(t, r.Header.Get("User-Agent") != "") + assert.Assert(t, r.Header.Get("X-GitHub-Api-Version") != "") + assert.Check(t, cmp.Equal("application/json", r.Header.Get("Content-Type"))) + + _ = json.NewEncoder(w).Encode(map[string]any{ + "data": map[string]any{ + "repository": map[string]any{ + "file0": map[string]any{"text": "content-0"}, + }, + }, + }) + }) + + srv := withServer(t, mux) + c, _ := newTestGraphQLClient(t, srv.URL+"/api/v3/") + + res, err := c.fetchFilesBatch(context.Background(), "o", "r", "main", []string{"a"}) + assert.NilError(t, err) + assert.Check(t, cmp.Len(res, 1)) +} diff --git a/pkg/test/github/github.go b/pkg/test/github/github.go index fb2fcd493f..a638981714 100644 --- a/pkg/test/github/github.go +++ b/pkg/test/github/github.go @@ -34,6 +34,11 @@ func SetupGH() (client *github.Client, mux *http.ServeMux, serverURL string, tea // when there's a non-empty base URL path. So, use that. See issue #752. apiHandler := http.NewServeMux() apiHandler.Handle(githubBaseURLPath+"/", http.StripPrefix(githubBaseURLPath, mux)) + // GraphQL endpoint is at /api/graphql (not under /api/v3) + apiHandler.HandleFunc("/api/graphql", func(w http.ResponseWriter, r *http.Request) { + // Forward to mux for GraphQL handling + mux.ServeHTTP(w, r) + }) apiHandler.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { fmt.Fprintln(os.Stderr, "FAIL: Client.BaseURL path prefix is not preserved in the request URL:") fmt.Fprintln(os.Stderr) @@ -57,6 +62,12 @@ func SetupGH() (client *github.Client, mux *http.ServeMux, serverURL string, tea return client, mux, server.URL, server.Close } +// graphQLFileMapType is used to store files for GraphQL handler lookup. +type graphQLFileMapType map[string]struct { + sha, name string + isdir bool +} + // SetupGitTree Take a dir and fake a full GitTree GitHub api calls reply recursively over a muxer. func SetupGitTree(t *testing.T, mux *http.ServeMux, dir string, event *info.Event, recursive bool) { type file struct { @@ -64,6 +75,7 @@ func SetupGitTree(t *testing.T, mux *http.ServeMux, dir string, event *info.Even isdir bool } files := []file{} + if recursive { err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { sha := fmt.Sprintf("%x", sha256.Sum256([]byte(path))) @@ -142,4 +154,121 @@ func SetupGitTree(t *testing.T, mux *http.ServeMux, dir string, event *info.Even assert.NilError(t, err) fmt.Fprint(rw, string(b)) }) + + // Setup GraphQL endpoint handler for batch file fetching (only once per mux) + // Only register GraphQL handler once (at the root level, when recursive=false) + if !recursive { + // Walk the entire directory tree to collect all files for the GraphQL handler + allFiles := make(graphQLFileMapType) + err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { + if err == nil && !info.IsDir() && path != dir { + relPath := strings.TrimPrefix(path, dir+"/") + allFiles[relPath] = struct { + sha, name string + isdir bool + }{ + sha: fmt.Sprintf("%x", sha256.Sum256([]byte(path))), + name: path, + isdir: false, + } + } + return nil + }) + assert.NilError(t, err) + + // Register handler once with all collected files (only if we have files) + if len(allFiles) > 0 { + mux.HandleFunc("/api/graphql", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var graphQLReq struct { + Query string `json:"query"` + Variables map[string]any `json:"variables"` + } + if err := json.NewDecoder(r.Body).Decode(&graphQLReq); err != nil { + http.Error(w, fmt.Sprintf("Invalid GraphQL request: %v", err), http.StatusBadRequest) + return + } + + // Build response with file contents + repositoryData := make(map[string]any) + + // Parse query to extract aliases and paths + queryLines := strings.SplitSeq(graphQLReq.Query, "\n") + for line := range queryLines { + line = strings.TrimSpace(line) + if strings.Contains(line, ": object(expression:") && strings.Contains(line, "file") { + // Extract alias (e.g., "file0") + aliasEnd := strings.Index(line, ":") + if aliasEnd <= 0 { + continue + } + alias := strings.TrimSpace(line[:aliasEnd]) + + // Extract expression value between quotes: "ref:path" + exprStart := strings.Index(line, `expression: "`) + if exprStart < 0 { + continue + } + exprStart += len(`expression: "`) + exprEnd := strings.Index(line[exprStart:], `"`) + if exprEnd < 0 { + continue + } + expr := line[exprStart : exprStart+exprEnd] + // Unescape the expression (handle \" and \\) + expr = strings.ReplaceAll(expr, `\"`, `"`) + expr = strings.ReplaceAll(expr, `\\`, `\`) + // Split "ref:path" and take path + parts := strings.SplitN(expr, ":", 2) + if len(parts) != 2 { + continue + } + path := parts[1] + + // Look up file by path in the file map + var foundFile struct { + sha, name string + isdir bool + } + var found bool + if f, ok := allFiles[path]; ok { + foundFile = f + found = true + } else { + // Try to find by matching the end of the path or other variations + for k, f := range allFiles { + if strings.HasSuffix(k, "/"+path) || k == path { + foundFile = f + found = true + break + } + } + } + + if found { + content, err := os.ReadFile(foundFile.name) + if err == nil { + repositoryData[alias] = map[string]any{ + "text": string(content), + } + } + } + } + } + + responseData := map[string]any{ + "data": map[string]any{ + "repository": repositoryData, + }, + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(responseData) + }) + } + } }