From 3952f9437419fe5970cb984ccb22c0ff53c64831 Mon Sep 17 00:00:00 2001 From: git-hulk Date: Fri, 19 Dec 2025 17:53:02 +0800 Subject: [PATCH] Implement RWLock in internalArtifacts to prevent data races --- internal/context/callback_context.go | 11 +++- internal/context/callback_context_test.go | 80 +++++++++++++++++++++++ internal/toolinternal/context.go | 11 +++- 3 files changed, 96 insertions(+), 6 deletions(-) create mode 100644 internal/context/callback_context_test.go diff --git a/internal/context/callback_context.go b/internal/context/callback_context.go index 0c2870fad..f800c027a 100644 --- a/internal/context/callback_context.go +++ b/internal/context/callback_context.go @@ -17,6 +17,7 @@ package context import ( "context" "iter" + "sync" "google.golang.org/genai" @@ -27,7 +28,8 @@ import ( type internalArtifacts struct { agent.Artifacts - eventActions *session.EventActions + eventActions *session.EventActions + mu sync.RWMutex } func (ia *internalArtifacts) Save(ctx context.Context, name string, data *genai.Part) (*artifact.SaveResponse, error) { @@ -36,11 +38,14 @@ func (ia *internalArtifacts) Save(ctx context.Context, name string, data *genai. return resp, err } if ia.eventActions != nil { + ia.mu.Lock() + defer ia.mu.Unlock() if ia.eventActions.ArtifactDelta == nil { ia.eventActions.ArtifactDelta = make(map[string]int64) } - // TODO: RWLock, check the version stored is newer in case multiple tools save the same file. - ia.eventActions.ArtifactDelta[name] = resp.Version + if current, ok := ia.eventActions.ArtifactDelta[name]; !ok || resp.Version > current { + ia.eventActions.ArtifactDelta[name] = resp.Version + } } return resp, nil } diff --git a/internal/context/callback_context_test.go b/internal/context/callback_context_test.go new file mode 100644 index 000000000..e0dde9455 --- /dev/null +++ b/internal/context/callback_context_test.go @@ -0,0 +1,80 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "context" + "testing" + + "google.golang.org/genai" + + "google.golang.org/adk/artifact" + "google.golang.org/adk/session" +) + +type fakeArtifacts struct { + version int64 +} + +func (f *fakeArtifacts) Save(ctx context.Context, name string, data *genai.Part) (*artifact.SaveResponse, error) { + return &artifact.SaveResponse{Version: f.version}, nil +} + +func (f *fakeArtifacts) List(ctx context.Context) (*artifact.ListResponse, error) { + return nil, nil +} + +func (f *fakeArtifacts) Load(ctx context.Context, name string) (*artifact.LoadResponse, error) { + return nil, nil +} + +func (f *fakeArtifacts) LoadVersion(ctx context.Context, name string, version int) (*artifact.LoadResponse, error) { + return nil, nil +} + +func TestInternalArtifactsSaveKeepsNewestVersion(t *testing.T) { + t.Parallel() + + actions := &session.EventActions{} + fake := &fakeArtifacts{version: 1} + ia := &internalArtifacts{ + Artifacts: fake, + eventActions: actions, + } + + ctx := context.Background() + if _, err := ia.Save(ctx, "file.txt", nil); err != nil { + t.Fatalf("Save returned error: %v", err) + } + if got := actions.ArtifactDelta["file.txt"]; got != 1 { + t.Fatalf("expected version 1 after first save, got %d", got) + } + + fake.version = 3 + if _, err := ia.Save(ctx, "file.txt", nil); err != nil { + t.Fatalf("Save returned error: %v", err) + } + if got := actions.ArtifactDelta["file.txt"]; got != 3 { + t.Fatalf("expected version 3 after newer save, got %d", got) + } + + fake.version = 2 + if _, err := ia.Save(ctx, "file.txt", nil); err != nil { + t.Fatalf("Save returned error: %v", err) + } + if got := actions.ArtifactDelta["file.txt"]; got != 3 { + t.Fatalf("expected version 3 after older save, got %d", got) + } +} diff --git a/internal/toolinternal/context.go b/internal/toolinternal/context.go index 9a3525d84..4ec512982 100644 --- a/internal/toolinternal/context.go +++ b/internal/toolinternal/context.go @@ -16,6 +16,7 @@ package toolinternal import ( "context" + "sync" "github.com/google/uuid" "google.golang.org/genai" @@ -30,7 +31,8 @@ import ( type internalArtifacts struct { agent.Artifacts - eventActions *session.EventActions + eventActions *session.EventActions + mu sync.RWMutex } func (ia *internalArtifacts) Save(ctx context.Context, name string, data *genai.Part) (*artifact.SaveResponse, error) { @@ -39,11 +41,14 @@ func (ia *internalArtifacts) Save(ctx context.Context, name string, data *genai. return resp, err } if ia.eventActions != nil { + ia.mu.Lock() + defer ia.mu.Unlock() if ia.eventActions.ArtifactDelta == nil { ia.eventActions.ArtifactDelta = make(map[string]int64) } - // TODO: RWLock, check the version stored is newer in case multiple tools save the same file. - ia.eventActions.ArtifactDelta[name] = resp.Version + if current, ok := ia.eventActions.ArtifactDelta[name]; !ok || resp.Version > current { + ia.eventActions.ArtifactDelta[name] = resp.Version + } } return resp, nil }