Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 73 additions & 42 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@ package client
import (
_context "context"
"encoding/json"
"errors"
"fmt"
"math"
_nethttp "net/http"
"time"

"github.com/sourcegraph/conc/pool"
"golang.org/x/sync/errgroup"

fgaSdk "github.com/openfga/go-sdk"
"github.com/openfga/go-sdk/credentials"
Expand Down Expand Up @@ -1786,18 +1786,14 @@ func (client *OpenFgaClient) WriteExecute(request SdkClientWriteRequestInterface
if request.GetBody() != nil {
for i := 0; i < len(request.GetBody().Writes); i += writeChunkSize {
end := int(math.Min(float64(i+writeChunkSize), float64(len(request.GetBody().Writes))))

writeChunks = append(writeChunks, (request.GetBody().Writes)[i:end])
}
}

writeGroup, ctx := errgroup.WithContext(request.GetContext())

writeGroup.SetLimit(int(maxParallelReqs))
writeResponses := make([]ClientWriteResponse, len(writeChunks))
for index, writeBody := range writeChunks {
index, writeBody := index, writeBody
writeGroup.Go(func() error {
writePool := pool.NewWithResults[*ClientWriteResponse]().WithContext(request.GetContext()).WithMaxGoroutines(int(maxParallelReqs))
for _, writeBody := range writeChunks {
writeBody := writeBody
writePool.Go(func(ctx _context.Context) (*ClientWriteResponse, error) {
singleResponse, err := client.WriteExecute(&SdkClientWriteRequest{
ctx: ctx,
Client: client,
Expand All @@ -1811,19 +1807,24 @@ func (client *OpenFgaClient) WriteExecute(request SdkClientWriteRequestInterface
Conflict: options.Conflict,
},
})

if _, ok := err.(fgaSdk.FgaApiAuthenticationError); ok {
return err
var authErr fgaSdk.FgaApiAuthenticationError
// If an authentication error was returned, we want to return it immediately
if errors.As(err, &authErr) {
return nil, err
}

writeResponses[index] = *singleResponse
// Handle nil response - create zero value if singleResponse is nil
if singleResponse == nil {
return &ClientWriteResponse{
Writes: []ClientWriteRequestWriteResponse{},
Deletes: []ClientWriteRequestDeleteResponse{},
}, nil
}

return nil
return singleResponse, nil
})
}

err = writeGroup.Wait()
// If an error was returned then it will be an authentication error so we want to return
writeResponses, err := writePool.Wait()
if err != nil {
return &response, err
}
Expand All @@ -1838,12 +1839,10 @@ func (client *OpenFgaClient) WriteExecute(request SdkClientWriteRequestInterface
}
}

deleteGroup, ctx := errgroup.WithContext(request.GetContext())
deleteGroup.SetLimit(int(maxParallelReqs))
deleteResponses := make([]ClientWriteResponse, len(deleteChunks))
for index, deleteBody := range deleteChunks {
index, deleteBody := index, deleteBody
deleteGroup.Go(func() error {
deletePool := pool.NewWithResults[*ClientWriteResponse]().WithContext(request.GetContext()).WithMaxGoroutines(int(maxParallelReqs))
for _, deleteBody := range deleteChunks {
deleteBody := deleteBody
deletePool.Go(func(ctx _context.Context) (*ClientWriteResponse, error) {
singleResponse, err := client.WriteExecute(&SdkClientWriteRequest{
ctx: ctx,
Client: client,
Expand All @@ -1858,28 +1857,40 @@ func (client *OpenFgaClient) WriteExecute(request SdkClientWriteRequestInterface
},
})

if _, ok := err.(fgaSdk.FgaApiAuthenticationError); ok {
return err
var authErr fgaSdk.FgaApiAuthenticationError
// If an authentication error was returned, we want to return it immediately
if errors.As(err, &authErr) {
return nil, err
}

deleteResponses[index] = *singleResponse
// Handle nil response - create zero value if singleResponse is nil
if singleResponse == nil {
return &ClientWriteResponse{
Writes: []ClientWriteRequestWriteResponse{},
Deletes: []ClientWriteRequestDeleteResponse{},
}, nil
}

return nil
return singleResponse, nil
})
}

err = deleteGroup.Wait()
deleteResponses, err := deletePool.Wait()
// If authencication error was returned, we want to return it immediately
if err != nil {
// If an error was returned then it will be an authentication error so we want to return
return &response, err
}

for _, writeResponse := range writeResponses {
response.Writes = append(response.Writes, writeResponse.Writes...)
if writeResponse != nil {
response.Writes = append(response.Writes, writeResponse.Writes...)
}
}

for _, deleteResponse := range deleteResponses {
response.Deletes = append(response.Deletes, deleteResponse.Deletes...)
if deleteResponse != nil {
response.Deletes = append(response.Deletes, deleteResponse.Deletes...)
}
}

return &response, nil
Expand Down Expand Up @@ -2238,7 +2249,7 @@ func (request *SdkClientBatchCheckClientRequest) GetOptions() *ClientBatchCheckC
}

func (client *OpenFgaClient) ClientBatchCheckExecute(request SdkClientBatchCheckClientRequestInterface) (*ClientBatchCheckClientResponse, error) {
group, ctx := errgroup.WithContext(request.GetContext())
ctx := request.GetContext()
requestOptions := RequestOptions{}
maxParallelReqs := int(DEFAULT_MAX_METHOD_PARALLEL_REQS)
if request.GetOptions() != nil {
Expand All @@ -2248,7 +2259,6 @@ func (client *OpenFgaClient) ClientBatchCheckExecute(request SdkClientBatchCheck
}
}

group.SetLimit(maxParallelReqs)
var numOfChecks = len(*request.GetBody())
response := make(ClientBatchCheckClientResponse, numOfChecks)
authorizationModelId, err := client.getAuthorizationModelId(request.GetAuthorizationModelIdOverride())
Expand All @@ -2272,34 +2282,55 @@ func (client *OpenFgaClient) ClientBatchCheckExecute(request SdkClientBatchCheck
checkOptions.Consistency = request.GetOptions().Consistency
}

type batchCheckResult struct {
Index int
Response ClientBatchCheckClientSingleResponse
}

checkPool := pool.NewWithResults[*batchCheckResult]().WithContext(ctx).WithMaxGoroutines(maxParallelReqs)
for index, checkBody := range *request.GetBody() {
index, checkBody := index, checkBody
group.Go(func() error {
checkPool.Go(func(ctx _context.Context) (*batchCheckResult, error) {
singleResponse, err := client.CheckExecute(&SdkClientCheckRequest{
ctx: ctx,
Client: client,
body: &checkBody,
options: checkOptions,
})

if _, ok := err.(fgaSdk.FgaApiAuthenticationError); ok {
return err
var authErr fgaSdk.FgaApiAuthenticationError
// If the error is an authentication error, propagate it so the batch fails fast.
// Non-authentication errors are captured in the per-request response below.
if errors.As(err, &authErr) {
return nil, err
}

response[index] = ClientBatchCheckClientSingleResponse{
Request: checkBody,
ClientCheckResponse: *singleResponse,
Error: err,
// Handle nil response
var checkResponse ClientCheckResponse
if singleResponse != nil {
checkResponse = *singleResponse
}

return nil
return &batchCheckResult{
Index: index,
Response: ClientBatchCheckClientSingleResponse{
Request: checkBody,
ClientCheckResponse: checkResponse,
Error: err,
},
}, nil
})
}

if err := group.Wait(); err != nil {
results, err := checkPool.Wait()
if err != nil {
return nil, err
}

for _, result := range results {
response[result.Index] = result.Response
}

return &response, nil
}

Expand Down
6 changes: 4 additions & 2 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package client_test
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
Expand Down Expand Up @@ -1764,7 +1765,8 @@ func TestOpenFgaClient(t *testing.T) {
t.Fatalf("Expect error with invalid auth but there is none")
}

if _, ok := err.(openfga.FgaApiAuthenticationError); !ok {
var authErr openfga.FgaApiAuthenticationError
if !errors.As(err, &authErr) {
t.Fatalf("Expected an api auth error")
}

Expand All @@ -1782,7 +1784,7 @@ func TestOpenFgaClient(t *testing.T) {
t.Fatalf("Expect error with invalid auth but there is none")
}

if _, ok := err.(openfga.FgaApiAuthenticationError); !ok {
if !errors.As(err, &authErr) {
t.Fatalf("Expected an api auth error")
}
})
Expand Down
Loading