Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions internal/errorutil/assertion.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package errorutil

import (
"errors"
"testing"
)

// AssertTestError asserts the presence or absence of an error based on expectations.
// It can also check for a specific expected error.
//
// Parameters:
// - t: the testing instance (marked as Helper for accurate stack traces)
// - err: the actual error to validate
// - wantError: boolean indicating if an error is expected
// - wantSpecificErr: the specific expected error to match against (can be nil if not checking for a specific error)
// - funcName: descriptive name of the function being tested (for error messages)
//
// Example:
//
// err := myService.Update(ctx, req)
// testutil.AssertTestError(t, err, true, session.ErrSessionExpired, "Update()")
func AssertTestError(t *testing.T, err error, wantError bool, wantSpecificErr error, funcName string) {
t.Helper()

if !wantError {
if err != nil {
t.Fatalf("%s unexpected error: %v", funcName, err)
}
return
}

if err == nil {
if wantSpecificErr != nil {
t.Fatalf("%s expected error %v but got nil", funcName, wantSpecificErr)
} else {
t.Fatalf("%s expected an error but got nil", funcName)
}
return
}

if wantSpecificErr != nil && !errors.Is(err, wantSpecificErr) {
t.Fatalf("%s error = %v, want %v", funcName, err, wantSpecificErr)
}
}
26 changes: 17 additions & 9 deletions session/database/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,11 @@ func (s *databaseService) Get(ctx context.Context, req *session.GetRequest) (*se
}).
First(&foundSession).Error
if err != nil {
// For any error including ErrRecordNotFound, return it as a system error.
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("%w with id %q", session.ErrSessionNotFound, sessionID)
}

// For any other database error, return it as a system error.
return nil, fmt.Errorf("database error while fetching session: %w", err)
}

Expand Down Expand Up @@ -312,6 +316,10 @@ func (s *databaseService) Delete(ctx context.Context, req *session.DeleteRequest
return fmt.Errorf("database error during session deletion: %w", result.Error)
}

if result.RowsAffected == 0 {
return fmt.Errorf("%w with id %q", session.ErrSessionNotFound, sessionID)
}

return nil // Returning nil commits the transaction
})
}
Expand Down Expand Up @@ -348,24 +356,24 @@ func (s *databaseService) AppendEvent(ctx context.Context, curSession session.Se

// applyEvent fetches the session, validates it, applies state changes from an
// event, and saves the event atomically.
func (s *databaseService) applyEvent(ctx context.Context, session *localSession, event *session.Event) error {
func (s *databaseService) applyEvent(ctx context.Context, localSession *localSession, event *session.Event) error {
// Wrap database operations in a single transaction.
err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// Fetch the session object from storage.
var storageSess storageSession
err := tx.Where(&storageSession{AppName: session.AppName(), UserID: session.UserID(), ID: session.ID()}).
err := tx.Where(&storageSession{AppName: localSession.AppName(), UserID: localSession.UserID(), ID: localSession.ID()}).
First(&storageSess).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return fmt.Errorf("session not found, cannot apply event")
return fmt.Errorf("%w with id %q", session.ErrSessionNotFound, localSession.ID())
}
return fmt.Errorf("failed to get session: %w", err)
}

// Ensure the session object is not stale.
// We use UnixNano() for microsecond-level precision, matching the Python code.
storageUpdateTime := storageSess.UpdateTime.UnixNano()
sessionUpdateTime := session.updatedAt.UnixNano()
sessionUpdateTime := localSession.updatedAt.UnixNano()
if storageUpdateTime > sessionUpdateTime {
return fmt.Errorf(
"stale session error: last update time from request (%s) is older than in database (%s)",
Expand All @@ -375,11 +383,11 @@ func (s *databaseService) applyEvent(ctx context.Context, session *localSession,
}

// Fetch App and User states.
storageApp, err := fetchStorageAppState(tx, session.AppName())
storageApp, err := fetchStorageAppState(tx, localSession.AppName())
if err != nil {
return err
}
storageUser, err := fetchStorageUserState(tx, session.AppName(), session.UserID())
storageUser, err := fetchStorageUserState(tx, localSession.AppName(), localSession.UserID())
if err != nil {
return err
}
Expand All @@ -406,7 +414,7 @@ func (s *databaseService) applyEvent(ctx context.Context, session *localSession,
}

// Create the new event record in the database.
storageEv, err := createStorageEvent(session, event)
storageEv, err := createStorageEvent(localSession, event)
if err != nil {
return fmt.Errorf("failed to map event to storage model: %w", err)
}
Expand All @@ -420,7 +428,7 @@ func (s *databaseService) applyEvent(ctx context.Context, session *localSession,
return fmt.Errorf("failed to save session state: %w", err)
}

session.updatedAt = storageSess.UpdateTime
localSession.updatedAt = storageSess.UpdateTime

return nil // Returning nil commits the transaction.
})
Expand Down
67 changes: 41 additions & 26 deletions session/database/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"gorm.io/driver/sqlite"
"gorm.io/gorm"

"google.golang.org/adk/internal/errorutil"
"google.golang.org/adk/model"
"google.golang.org/adk/session"
)
Expand Down Expand Up @@ -80,10 +81,8 @@ func Test_databaseService_Create(t *testing.T) {
s := tt.setup(t)

got, err := s.Create(t.Context(), tt.req)
if (err != nil) != tt.wantErr {
t.Fatalf("databaseService.Create() error = %v, wantErr %v", err, tt.wantErr)
return
}

errorutil.AssertTestError(t, err, tt.wantErr, nil, "databaseService.Create()")

if err != nil {
return
Expand Down Expand Up @@ -119,10 +118,11 @@ func Test_databaseService_Create(t *testing.T) {

func Test_databaseService_Delete(t *testing.T) {
tests := []struct {
name string
req *session.DeleteRequest
setup func(t *testing.T) *databaseService
wantErr bool
name string
req *session.DeleteRequest
setup func(t *testing.T) *databaseService
wantErr bool
wantNotFoundErr bool
}{
{
name: "delete ok",
Expand All @@ -134,20 +134,28 @@ func Test_databaseService_Delete(t *testing.T) {
},
},
{
name: "no error when not found",
name: "error when session not found",
setup: serviceDbWithData,
req: &session.DeleteRequest{
AppName: "appTest",
UserID: "user1",
SessionID: "session1",
},
wantErr: true,
wantNotFoundErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := tt.setup(t)
if err := s.Delete(t.Context(), tt.req); (err != nil) != tt.wantErr {
t.Errorf("databaseService.Delete() error = %v, wantErr %v", err, tt.wantErr)
err := s.Delete(t.Context(), tt.req)
var wantSpecificErr error
if tt.wantNotFoundErr {
wantSpecificErr = session.ErrSessionNotFound
}
errorutil.AssertTestError(t, err, tt.wantErr, wantSpecificErr, "databaseService.Delete()")
if err != nil {
return
}
})
}
Expand Down Expand Up @@ -220,12 +228,13 @@ func Test_databaseService_Get(t *testing.T) {
}

tests := []struct {
name string
req *session.GetRequest
setup func(t *testing.T) *databaseService
wantResponse *session.GetResponse
wantEvents []*session.Event
wantErr bool
name string
req *session.GetRequest
setup func(t *testing.T) *databaseService
wantResponse *session.GetResponse
wantEvents []*session.Event
wantErr bool
wantNotFoundErr bool
}{
{
name: "ok",
Expand Down Expand Up @@ -255,7 +264,8 @@ func Test_databaseService_Get(t *testing.T) {
UserID: "user1",
SessionID: "session1",
},
wantErr: true,
wantErr: true,
wantNotFoundErr: true,
},
{
name: "get session respects user id",
Expand Down Expand Up @@ -338,11 +348,12 @@ func Test_databaseService_Get(t *testing.T) {
s := tt.setup(t)

got, err := s.Get(t.Context(), tt.req)
if (err != nil) != tt.wantErr {
t.Fatalf("databaseService.Get() error = %v, wantErr %v", err, tt.wantErr)
return
}

var wantSpecificErr error
if tt.wantNotFoundErr {
wantSpecificErr = session.ErrSessionNotFound
}
errorutil.AssertTestError(t, err, tt.wantErr, wantSpecificErr, "databaseService.Get()")
if err != nil {
return
}
Expand Down Expand Up @@ -482,6 +493,7 @@ func Test_databaseService_AppendEvent(t *testing.T) {
wantStoredSession *localSession // State of the session after Get
wantEventCount int // Expected event count in storage
wantErr bool
wantNotFoundErr bool
}{
{
name: "append event to the session and overwrite in storage",
Expand Down Expand Up @@ -567,7 +579,8 @@ func Test_databaseService_AppendEvent(t *testing.T) {
Partial: false,
},
},
wantErr: true,
wantErr: true,
wantNotFoundErr: true,
},
{
name: "append event with bytes content",
Expand Down Expand Up @@ -725,10 +738,12 @@ func Test_databaseService_AppendEvent(t *testing.T) {

tt.session.updatedAt = time.Now() // set updatedAt value to pass stale validation
err := s.AppendEvent(ctx, tt.session, tt.event)
if (err != nil) != tt.wantErr {
t.Errorf("databaseService.AppendEvent() error = %v, wantErr %v", err, tt.wantErr)
}

var wantSpecificErr error
if tt.wantNotFoundErr {
wantSpecificErr = session.ErrSessionNotFound
}
errorutil.AssertTestError(t, err, tt.wantErr, wantSpecificErr, "databaseService.AppendEvent()")
if err != nil {
return
}
Expand Down
24 changes: 24 additions & 0 deletions session/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package session

import "errors"

var (
// ErrSessionNotFound is returned when a session is not found.
ErrSessionNotFound = errors.New("session not found")
// ErrStateKeyNotExist is the error thrown when key does not exist.
ErrStateKeyNotExist = errors.New("state key does not exist")
)
12 changes: 9 additions & 3 deletions session/inmemory.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func (s *inMemoryService) Get(ctx context.Context, req *GetRequest) (*GetRespons

res, ok := s.sessions.Get(id.Encode())
if !ok {
return nil, fmt.Errorf("session %+v not found", req.SessionID)
return nil, fmt.Errorf("%w with id %q", ErrSessionNotFound, req.SessionID)
}

copiedSession := copySessionWithoutStateAndEvents(res)
Expand Down Expand Up @@ -191,7 +191,13 @@ func (s *inMemoryService) Delete(ctx context.Context, req *DeleteRequest) error
sessionID: sessionID,
}

s.sessions.Delete(id.Encode())
encodedKey := id.Encode()
_, ok := s.sessions.Get(encodedKey)
if !ok {
return fmt.Errorf("%w with id %q", ErrSessionNotFound, req.SessionID)
}

s.sessions.Delete(encodedKey)
return nil
}

Expand All @@ -216,7 +222,7 @@ func (s *inMemoryService) AppendEvent(ctx context.Context, curSession Session, e

stored_session, ok := s.sessions.Get(sess.id.Encode())
if !ok {
return fmt.Errorf("session not found, cannot apply event")
return fmt.Errorf("%w with id %q", ErrSessionNotFound, sess.ID())
}

// update the in-memory session
Expand Down
Loading