diff --git a/client/client.go b/client/client.go index aa5ad2e..5220d67 100644 --- a/client/client.go +++ b/client/client.go @@ -3,13 +3,13 @@ package client import ( _context "context" "encoding/json" + "errors" "fmt" "math" _nethttp "net/http" "time" "github.com/sourcegraph/conc/pool" - "golang.org/x/sync/errgroup" fgaSdk "github.com/openfga/go-sdk" "github.com/openfga/go-sdk/credentials" @@ -1786,18 +1786,14 @@ func (client *OpenFgaClient) WriteExecute(request SdkClientWriteRequestInterface if request.GetBody() != nil { for i := 0; i < len(request.GetBody().Writes); i += writeChunkSize { end := int(math.Min(float64(i+writeChunkSize), float64(len(request.GetBody().Writes)))) - writeChunks = append(writeChunks, (request.GetBody().Writes)[i:end]) } } - writeGroup, ctx := errgroup.WithContext(request.GetContext()) - - writeGroup.SetLimit(int(maxParallelReqs)) - writeResponses := make([]ClientWriteResponse, len(writeChunks)) - for index, writeBody := range writeChunks { - index, writeBody := index, writeBody - writeGroup.Go(func() error { + writePool := pool.NewWithResults[*ClientWriteResponse]().WithContext(request.GetContext()).WithMaxGoroutines(int(maxParallelReqs)) + for _, writeBody := range writeChunks { + writeBody := writeBody + writePool.Go(func(ctx _context.Context) (*ClientWriteResponse, error) { singleResponse, err := client.WriteExecute(&SdkClientWriteRequest{ ctx: ctx, Client: client, @@ -1811,19 +1807,24 @@ func (client *OpenFgaClient) WriteExecute(request SdkClientWriteRequestInterface Conflict: options.Conflict, }, }) - - if _, ok := err.(fgaSdk.FgaApiAuthenticationError); ok { - return err + var authErr fgaSdk.FgaApiAuthenticationError + // If an authentication error was returned, we want to return it immediately + if errors.As(err, &authErr) { + return nil, err } - writeResponses[index] = *singleResponse + // Handle nil response - create zero value if singleResponse is nil + if singleResponse == nil { + return &ClientWriteResponse{ + Writes: []ClientWriteRequestWriteResponse{}, + Deletes: []ClientWriteRequestDeleteResponse{}, + }, nil + } - return nil + return singleResponse, nil }) } - - err = writeGroup.Wait() - // If an error was returned then it will be an authentication error so we want to return + writeResponses, err := writePool.Wait() if err != nil { return &response, err } @@ -1838,12 +1839,10 @@ func (client *OpenFgaClient) WriteExecute(request SdkClientWriteRequestInterface } } - deleteGroup, ctx := errgroup.WithContext(request.GetContext()) - deleteGroup.SetLimit(int(maxParallelReqs)) - deleteResponses := make([]ClientWriteResponse, len(deleteChunks)) - for index, deleteBody := range deleteChunks { - index, deleteBody := index, deleteBody - deleteGroup.Go(func() error { + deletePool := pool.NewWithResults[*ClientWriteResponse]().WithContext(request.GetContext()).WithMaxGoroutines(int(maxParallelReqs)) + for _, deleteBody := range deleteChunks { + deleteBody := deleteBody + deletePool.Go(func(ctx _context.Context) (*ClientWriteResponse, error) { singleResponse, err := client.WriteExecute(&SdkClientWriteRequest{ ctx: ctx, Client: client, @@ -1858,28 +1857,40 @@ func (client *OpenFgaClient) WriteExecute(request SdkClientWriteRequestInterface }, }) - if _, ok := err.(fgaSdk.FgaApiAuthenticationError); ok { - return err + var authErr fgaSdk.FgaApiAuthenticationError + // If an authentication error was returned, we want to return it immediately + if errors.As(err, &authErr) { + return nil, err } - deleteResponses[index] = *singleResponse + // Handle nil response - create zero value if singleResponse is nil + if singleResponse == nil { + return &ClientWriteResponse{ + Writes: []ClientWriteRequestWriteResponse{}, + Deletes: []ClientWriteRequestDeleteResponse{}, + }, nil + } - return nil + return singleResponse, nil }) } - err = deleteGroup.Wait() + deleteResponses, err := deletePool.Wait() + // If authencication error was returned, we want to return it immediately if err != nil { - // If an error was returned then it will be an authentication error so we want to return return &response, err } for _, writeResponse := range writeResponses { - response.Writes = append(response.Writes, writeResponse.Writes...) + if writeResponse != nil { + response.Writes = append(response.Writes, writeResponse.Writes...) + } } for _, deleteResponse := range deleteResponses { - response.Deletes = append(response.Deletes, deleteResponse.Deletes...) + if deleteResponse != nil { + response.Deletes = append(response.Deletes, deleteResponse.Deletes...) + } } return &response, nil @@ -2238,7 +2249,7 @@ func (request *SdkClientBatchCheckClientRequest) GetOptions() *ClientBatchCheckC } func (client *OpenFgaClient) ClientBatchCheckExecute(request SdkClientBatchCheckClientRequestInterface) (*ClientBatchCheckClientResponse, error) { - group, ctx := errgroup.WithContext(request.GetContext()) + ctx := request.GetContext() requestOptions := RequestOptions{} maxParallelReqs := int(DEFAULT_MAX_METHOD_PARALLEL_REQS) if request.GetOptions() != nil { @@ -2248,7 +2259,6 @@ func (client *OpenFgaClient) ClientBatchCheckExecute(request SdkClientBatchCheck } } - group.SetLimit(maxParallelReqs) var numOfChecks = len(*request.GetBody()) response := make(ClientBatchCheckClientResponse, numOfChecks) authorizationModelId, err := client.getAuthorizationModelId(request.GetAuthorizationModelIdOverride()) @@ -2272,9 +2282,15 @@ func (client *OpenFgaClient) ClientBatchCheckExecute(request SdkClientBatchCheck checkOptions.Consistency = request.GetOptions().Consistency } + type batchCheckResult struct { + Index int + Response ClientBatchCheckClientSingleResponse + } + + checkPool := pool.NewWithResults[*batchCheckResult]().WithContext(ctx).WithMaxGoroutines(maxParallelReqs) for index, checkBody := range *request.GetBody() { index, checkBody := index, checkBody - group.Go(func() error { + checkPool.Go(func(ctx _context.Context) (*batchCheckResult, error) { singleResponse, err := client.CheckExecute(&SdkClientCheckRequest{ ctx: ctx, Client: client, @@ -2282,24 +2298,39 @@ func (client *OpenFgaClient) ClientBatchCheckExecute(request SdkClientBatchCheck options: checkOptions, }) - if _, ok := err.(fgaSdk.FgaApiAuthenticationError); ok { - return err + var authErr fgaSdk.FgaApiAuthenticationError + // If the error is an authentication error, propagate it so the batch fails fast. + // Non-authentication errors are captured in the per-request response below. + if errors.As(err, &authErr) { + return nil, err } - response[index] = ClientBatchCheckClientSingleResponse{ - Request: checkBody, - ClientCheckResponse: *singleResponse, - Error: err, + // Handle nil response + var checkResponse ClientCheckResponse + if singleResponse != nil { + checkResponse = *singleResponse } - return nil + return &batchCheckResult{ + Index: index, + Response: ClientBatchCheckClientSingleResponse{ + Request: checkBody, + ClientCheckResponse: checkResponse, + Error: err, + }, + }, nil }) } - if err := group.Wait(); err != nil { + results, err := checkPool.Wait() + if err != nil { return nil, err } + for _, result := range results { + response[result.Index] = result.Response + } + return &response, nil } diff --git a/client/client_test.go b/client/client_test.go index b4d937f..7d9245a 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -3,6 +3,7 @@ package client_test import ( "context" "encoding/json" + "errors" "fmt" "net/http" "strings" @@ -1764,7 +1765,8 @@ func TestOpenFgaClient(t *testing.T) { t.Fatalf("Expect error with invalid auth but there is none") } - if _, ok := err.(openfga.FgaApiAuthenticationError); !ok { + var authErr openfga.FgaApiAuthenticationError + if !errors.As(err, &authErr) { t.Fatalf("Expected an api auth error") } @@ -1782,7 +1784,7 @@ func TestOpenFgaClient(t *testing.T) { t.Fatalf("Expect error with invalid auth but there is none") } - if _, ok := err.(openfga.FgaApiAuthenticationError); !ok { + if !errors.As(err, &authErr) { t.Fatalf("Expected an api auth error") } })