From 206e7b5abc91db63aefa7a878e69f284d44cb3fb Mon Sep 17 00:00:00 2001 From: Tanvir Ahmed Date: Mon, 30 Mar 2026 00:27:15 +0000 Subject: [PATCH 1/6] chore: Many updates and more refactors --- models/plugins_context.go | 12 + plugins/access-control/api.go | 61 +- .../handlers/permission_handlers.go | 147 ++ .../handlers/permission_handlers_test.go | 566 +++++++ .../access-control/handlers/role_handlers.go | 147 ++ .../handlers/role_handlers_test.go | 625 ++++++++ .../handlers/role_permission_handlers.go | 286 +--- .../handlers/role_permission_handlers_test.go | 1029 +++++-------- .../access-control/handlers/shared_helpers.go | 35 + .../handlers/user_access_handlers.go | 59 + .../handlers/user_access_handlers_test.go | 241 +++ .../handlers/user_roles_handlers.go | 86 +- .../handlers/user_roles_handlers_test.go | 790 +++++++--- plugins/access-control/hooks_test.go | 204 --- plugins/access-control/migrations.go | 236 +-- .../access-control/migrationset/migrations.go | 241 +++ plugins/access-control/plugin.go | 26 +- plugins/access-control/repositories/errors.go | 30 + .../access-control/repositories/interfaces.go | 19 +- .../repositories/permissions_repository.go | 84 ++ .../permissions_repository_test.go | 265 ++++ .../repository_test_helpers_test.go | 55 - .../role_permission_repository.go | 257 +--- .../role_permission_repository_test.go | 323 +++- .../repositories/roles_repository.go | 88 ++ .../repositories/roles_repository_test.go | 388 +++++ .../repositories/user_access_repository.go | 136 +- .../user_access_repository_test.go | 365 +++-- .../repositories/user_roles_repository.go | 189 +++ .../user_roles_repository_test.go | 400 +++++ plugins/access-control/routes.go | 50 +- .../services/permissions_service.go | 141 ++ .../services/permissions_service_test.go | 403 +++++ .../services/role_permission_service.go | 351 +---- .../services/role_permission_service_test.go | 140 +- .../access-control/services/roles_service.go | 156 ++ .../services/roles_service_test.go | 84 ++ .../services/user_access_service.go | 35 +- .../services/user_access_service_test.go | 112 +- .../services/user_roles_service.go | 98 ++ .../services/user_roles_service_test.go | 78 + plugins/access-control/tests/test_helpers.go | 200 ++- .../usecases/permissions_usecase.go | 36 + .../usecases/role_permission_mock_test.go | 126 -- .../usecases/role_permission_usecase.go | 64 +- .../usecases/role_permission_usecase_test.go | 1320 ----------------- .../access-control/usecases/roles_usecase.go | 36 + plugins/access-control/usecases/usecases.go | 120 +- .../usecases/user_access_usecase.go | 22 +- .../usecases/user_access_usecase_test.go | 42 - .../usecases/user_roles_usecase.go | 36 + 51 files changed, 6573 insertions(+), 4467 deletions(-) create mode 100644 models/plugins_context.go create mode 100644 plugins/access-control/handlers/permission_handlers.go create mode 100644 plugins/access-control/handlers/permission_handlers_test.go create mode 100644 plugins/access-control/handlers/role_handlers.go create mode 100644 plugins/access-control/handlers/role_handlers_test.go create mode 100644 plugins/access-control/handlers/shared_helpers.go create mode 100644 plugins/access-control/handlers/user_access_handlers.go create mode 100644 plugins/access-control/handlers/user_access_handlers_test.go delete mode 100644 plugins/access-control/hooks_test.go create mode 100644 plugins/access-control/migrationset/migrations.go create mode 100644 plugins/access-control/repositories/errors.go create mode 100644 plugins/access-control/repositories/permissions_repository.go create mode 100644 plugins/access-control/repositories/permissions_repository_test.go delete mode 100644 plugins/access-control/repositories/repository_test_helpers_test.go create mode 100644 plugins/access-control/repositories/roles_repository.go create mode 100644 plugins/access-control/repositories/roles_repository_test.go create mode 100644 plugins/access-control/repositories/user_roles_repository.go create mode 100644 plugins/access-control/repositories/user_roles_repository_test.go create mode 100644 plugins/access-control/services/permissions_service.go create mode 100644 plugins/access-control/services/permissions_service_test.go create mode 100644 plugins/access-control/services/roles_service.go create mode 100644 plugins/access-control/services/roles_service_test.go create mode 100644 plugins/access-control/services/user_roles_service.go create mode 100644 plugins/access-control/services/user_roles_service_test.go create mode 100644 plugins/access-control/usecases/permissions_usecase.go delete mode 100644 plugins/access-control/usecases/role_permission_mock_test.go delete mode 100644 plugins/access-control/usecases/role_permission_usecase_test.go create mode 100644 plugins/access-control/usecases/roles_usecase.go delete mode 100644 plugins/access-control/usecases/user_access_usecase_test.go create mode 100644 plugins/access-control/usecases/user_roles_usecase.go diff --git a/models/plugins_context.go b/models/plugins_context.go new file mode 100644 index 00000000..09cbd49e --- /dev/null +++ b/models/plugins_context.go @@ -0,0 +1,12 @@ +package models + +const ( + ContextAccessControlAssignRole ContextKey = "access_control.assign_role" +) + +// Access Control + +type AccessControlAssignRoleContext struct { + UserID string + RoleName string +} diff --git a/plugins/access-control/api.go b/plugins/access-control/api.go index aa024c8b..b12cb211 100644 --- a/plugins/access-control/api.go +++ b/plugins/access-control/api.go @@ -3,38 +3,19 @@ package accesscontrol import ( "context" - "github.com/Authula/authula/plugins/access-control/repositories" "github.com/Authula/authula/plugins/access-control/types" "github.com/Authula/authula/plugins/access-control/usecases" ) type API struct { - useCases *usecases.UseCases - rolePermissionRepo repositories.RolePermissionRepository - userAccessRepo repositories.UserAccessRepository + useCases *usecases.UseCases } -func NewAPI( - useCases *usecases.UseCases, - rolePermissionRepo repositories.RolePermissionRepository, - userAccessRepo repositories.UserAccessRepository, -) *API { - return &API{ - useCases: useCases, - rolePermissionRepo: rolePermissionRepo, - userAccessRepo: userAccessRepo, - } +func NewAPI(useCases *usecases.UseCases) *API { + return &API{useCases: useCases} } -func (a *API) RolePermissionRepository() repositories.RolePermissionRepository { - return a.rolePermissionRepo -} - -func (a *API) UserAccessRepository() repositories.UserAccessRepository { - return a.userAccessRepo -} - -// Roles and permissions +// Roles func (a *API) GetAllRoles(ctx context.Context) ([]types.Role, error) { return a.useCases.GetAllRoles(ctx) @@ -56,6 +37,8 @@ func (a *API) DeleteRole(ctx context.Context, roleID string) error { return a.useCases.DeleteRole(ctx, roleID) } +// Permissions + func (a *API) CreatePermission(ctx context.Context, req types.CreatePermissionRequest) (*types.Permission, error) { return a.useCases.CreatePermission(ctx, req) } @@ -64,6 +47,10 @@ func (a *API) GetAllPermissions(ctx context.Context) ([]types.Permission, error) return a.useCases.GetAllPermissions(ctx) } +func (a *API) GetPermissionByID(ctx context.Context, permissionID string) (*types.Permission, error) { + return a.useCases.GetPermissionByID(ctx, permissionID) +} + func (a *API) GetRolePermissions(ctx context.Context, roleID string) ([]types.UserPermissionInfo, error) { return a.useCases.GetRolePermissions(ctx, roleID) } @@ -76,6 +63,8 @@ func (a *API) DeletePermission(ctx context.Context, permissionID string) error { return a.useCases.DeletePermission(ctx, permissionID) } +// Role permissions + func (a *API) AddPermissionToRole(ctx context.Context, roleID string, permissionID string, grantedByUserID *string) error { return a.useCases.AddPermissionToRole(ctx, roleID, permissionID, grantedByUserID) } @@ -88,12 +77,16 @@ func (a *API) ReplaceRolePermissions(ctx context.Context, roleID string, permiss return a.useCases.ReplaceRolePermissions(ctx, roleID, permissionIDs, grantedByUserID) } -// User roles and permissions +// User roles func (a *API) GetUserRoles(ctx context.Context, userID string) ([]types.UserRoleInfo, error) { return a.useCases.GetUserRoles(ctx, userID) } +func (a *API) GetUserWithRolesByID(ctx context.Context, userID string) (*types.UserWithRoles, error) { + return a.useCases.GetUserWithRolesByID(ctx, userID) +} + func (a *API) AssignRoleToUser(ctx context.Context, userID string, req types.AssignUserRoleRequest, assignedByUserID *string) error { return a.useCases.AssignRoleToUser(ctx, userID, req, assignedByUserID) } @@ -106,20 +99,20 @@ func (a *API) ReplaceUserRoles(ctx context.Context, userID string, roleIDs []str return a.useCases.ReplaceUserRoles(ctx, userID, roleIDs, assignedByUserID) } -func (a *API) GetUserEffectivePermissions(ctx context.Context, userID string) ([]types.UserPermissionInfo, error) { - return a.useCases.GetUserEffectivePermissions(ctx, userID) -} +// User access -// User access and permissions +func (a *API) GetUserWithPermissionsByID(ctx context.Context, userID string) (*types.UserWithPermissions, error) { + return a.useCases.GetUserWithPermissionsByID(ctx, userID) +} -func (a *API) HasPermissions(ctx context.Context, userID string, requiredPermissions []string) (bool, error) { - return a.useCases.HasPermissions(ctx, userID, requiredPermissions) +func (a *API) GetUserAuthorizationProfile(ctx context.Context, userID string) (*types.UserAuthorizationProfile, error) { + return a.useCases.GetUserAuthorizationProfile(ctx, userID) } -func (a *API) GetUserWithRolesByID(ctx context.Context, userID string) (*types.UserWithRoles, error) { - return a.useCases.GetUserWithRolesByID(ctx, userID) +func (a *API) GetUserEffectivePermissions(ctx context.Context, userID string) ([]types.UserPermissionInfo, error) { + return a.useCases.GetUserEffectivePermissions(ctx, userID) } -func (a *API) GetUserWithPermissionsByID(ctx context.Context, userID string) (*types.UserWithPermissions, error) { - return a.useCases.GetUserWithPermissionsByID(ctx, userID) +func (a *API) HasPermissions(ctx context.Context, userID string, requiredPermissions []string) (bool, error) { + return a.useCases.HasPermissions(ctx, userID, requiredPermissions) } diff --git a/plugins/access-control/handlers/permission_handlers.go b/plugins/access-control/handlers/permission_handlers.go new file mode 100644 index 00000000..74c07845 --- /dev/null +++ b/plugins/access-control/handlers/permission_handlers.go @@ -0,0 +1,147 @@ +package handlers + +import ( + "net/http" + + "github.com/Authula/authula/internal/util" + "github.com/Authula/authula/models" + "github.com/Authula/authula/plugins/access-control/types" + "github.com/Authula/authula/plugins/access-control/usecases" +) + +type CreatePermissionHandler struct { + useCase *usecases.PermissionsUseCase +} + +func NewCreatePermissionHandler(useCase *usecases.PermissionsUseCase) *CreatePermissionHandler { + return &CreatePermissionHandler{useCase: useCase} +} + +func (h *CreatePermissionHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + var payload types.CreatePermissionRequest + if err := util.ParseJSON(r, &payload); err != nil { + reqCtx.SetJSONResponse(http.StatusUnprocessableEntity, map[string]any{"message": "invalid request body"}) + reqCtx.Handled = true + return + } + + permission, err := h.useCase.CreatePermission(r.Context(), payload) + if err != nil { + respondRolePermissionError(reqCtx, err) + return + } + + reqCtx.SetJSONResponse(http.StatusCreated, &types.CreatePermissionResponse{ + Permission: permission, + }) + } +} + +type GetAllPermissionsHandler struct { + useCase *usecases.PermissionsUseCase +} + +func NewGetAllPermissionsHandler(useCase *usecases.PermissionsUseCase) *GetAllPermissionsHandler { + return &GetAllPermissionsHandler{useCase: useCase} +} + +func (h *GetAllPermissionsHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + permissions, err := h.useCase.GetAllPermissions(r.Context()) + if err != nil { + respondRolePermissionError(reqCtx, err) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, permissions) + } +} + +type GetPermissionByIDHandler struct { + useCase *usecases.PermissionsUseCase +} + +func NewGetPermissionByIDHandler(useCase *usecases.PermissionsUseCase) *GetPermissionByIDHandler { + return &GetPermissionByIDHandler{useCase: useCase} +} + +func (h *GetPermissionByIDHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + permissionID := r.PathValue("permission_id") + + permission, err := h.useCase.GetPermissionByID(r.Context(), permissionID) + if err != nil { + respondRolePermissionError(reqCtx, err) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, permission) + } +} + +type UpdatePermissionHandler struct { + useCase *usecases.PermissionsUseCase +} + +func NewUpdatePermissionHandler(useCase *usecases.PermissionsUseCase) *UpdatePermissionHandler { + return &UpdatePermissionHandler{useCase: useCase} +} + +func (h *UpdatePermissionHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + permissionID := r.PathValue("permission_id") + + var payload types.UpdatePermissionRequest + if err := util.ParseJSON(r, &payload); err != nil { + reqCtx.SetJSONResponse(http.StatusUnprocessableEntity, map[string]any{"message": "invalid request body"}) + reqCtx.Handled = true + return + } + + permission, err := h.useCase.UpdatePermission(r.Context(), permissionID, payload) + if err != nil { + respondRolePermissionError(reqCtx, err) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, &types.UpdatePermissionResponse{ + Permission: permission, + }) + } +} + +type DeletePermissionHandler struct { + useCase *usecases.PermissionsUseCase +} + +func NewDeletePermissionHandler(useCase *usecases.PermissionsUseCase) *DeletePermissionHandler { + return &DeletePermissionHandler{useCase: useCase} +} + +func (h *DeletePermissionHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + permissionID := r.PathValue("permission_id") + + if err := h.useCase.DeletePermission(r.Context(), permissionID); err != nil { + respondRolePermissionError(reqCtx, err) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, &types.DeletePermissionResponse{ + Message: "permission deleted", + }) + } +} diff --git a/plugins/access-control/handlers/permission_handlers_test.go b/plugins/access-control/handlers/permission_handlers_test.go new file mode 100644 index 00000000..3dec403f --- /dev/null +++ b/plugins/access-control/handlers/permission_handlers_test.go @@ -0,0 +1,566 @@ +package handlers + +import ( + "errors" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/mock" + + internaltests "github.com/Authula/authula/internal/tests" + "github.com/Authula/authula/plugins/access-control/constants" + "github.com/Authula/authula/plugins/access-control/services" + accesscontroltests "github.com/Authula/authula/plugins/access-control/tests" + "github.com/Authula/authula/plugins/access-control/types" + "github.com/Authula/authula/plugins/access-control/usecases" +) + +func TestGetAllPermissionsHandler(t *testing.T) { + t.Parallel() + + fixedTime := time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC) + description := new(string) + *description = "read access" + + tests := []struct { + name string + setupMock func(*accesscontroltests.MockPermissionsRepository) + expectedStatus int + expectedBody any + }{ + { + name: "service error", + setupMock: func(m *accesscontroltests.MockPermissionsRepository) { + m.On("GetAllPermissions", mock.Anything).Return(([]types.Permission)(nil), constants.ErrForbidden).Once() + }, + expectedStatus: http.StatusForbidden, + expectedBody: map[string]string{"message": "forbidden"}, + }, + { + name: "success", + setupMock: func(m *accesscontroltests.MockPermissionsRepository) { + m.On("GetAllPermissions", mock.Anything).Return([]types.Permission{ + { + ID: "perm-1", + Key: "users.read", + Description: description, + IsSystem: false, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, + { + ID: "perm-2", + Key: "users.write", + Description: nil, + IsSystem: true, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, + }, nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: []types.Permission{ + { + ID: "perm-1", + Key: "users.read", + Description: description, + IsSystem: false, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, + { + ID: "perm-2", + Key: "users.write", + Description: nil, + IsSystem: true, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + permissionsRepo := &accesscontroltests.MockPermissionsRepository{} + userAccessRepo := &accesscontroltests.MockUserAccessRepository{} + if tc.setupMock != nil { + tc.setupMock(permissionsRepo) + } + + useCase := newPermissionsUseCase(permissionsRepo, userAccessRepo) + handler := NewGetAllPermissionsHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/permissions", nil, nil) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + permissionsRepo.AssertExpectations(t) + userAccessRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[[]types.Permission](t, reqCtx) + assertPermissionsEqual(t, payload, tc.expectedBody.([]types.Permission)) + + permissionsRepo.AssertExpectations(t) + userAccessRepo.AssertExpectations(t) + }) + } +} + +func TestCreatePermissionHandler(t *testing.T) { + t.Parallel() + + fixedTime := time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC) + description := new(string) + *description = "create access" + + tests := []struct { + name string + body []byte + setupMock func(*accesscontroltests.MockPermissionsRepository) + expectedStatus int + expectedBody any + }{ + { + name: "invalid request body", + body: []byte("{invalid json"), + expectedStatus: http.StatusUnprocessableEntity, + expectedBody: map[string]string{"message": "invalid request body"}, + }, + { + name: "service error", + body: internaltests.MarshalToJSON(t, types.CreatePermissionRequest{ + Key: "users.create", + Description: description, + IsSystem: false, + }), + setupMock: func(m *accesscontroltests.MockPermissionsRepository) { + m.On("CreatePermission", mock.Anything, mock.MatchedBy(func(permission *types.Permission) bool { + return permission != nil && permission.Key == "users.create" && permission.Description != nil && *permission.Description == *description && !permission.IsSystem && permission.ID != "" + })).Return(constants.ErrBadRequest).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedBody: map[string]string{"message": "bad request"}, + }, + { + name: "success", + body: internaltests.MarshalToJSON(t, types.CreatePermissionRequest{ + Key: "users.create", + Description: description, + IsSystem: false, + }), + setupMock: func(m *accesscontroltests.MockPermissionsRepository) { + m.On("CreatePermission", mock.Anything, mock.MatchedBy(func(permission *types.Permission) bool { + return permission != nil && permission.Key == "users.create" && permission.Description != nil && *permission.Description == *description && !permission.IsSystem && permission.ID != "" + })).Run(func(args mock.Arguments) { + permission := args.Get(1).(*types.Permission) + permission.ID = "perm-1" + permission.CreatedAt = fixedTime + permission.UpdatedAt = fixedTime + }).Return(nil).Once() + }, + expectedStatus: http.StatusCreated, + expectedBody: types.CreatePermissionResponse{ + Permission: &types.Permission{ + ID: "perm-1", + Key: "users.create", + Description: description, + IsSystem: false, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + permissionsRepo := &accesscontroltests.MockPermissionsRepository{} + userAccessRepo := &accesscontroltests.MockUserAccessRepository{} + if tc.setupMock != nil { + tc.setupMock(permissionsRepo) + } + + useCase := newPermissionsUseCase(permissionsRepo, userAccessRepo) + handler := NewCreatePermissionHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPost, "/permissions", tc.body, nil) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusCreated { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + permissionsRepo.AssertExpectations(t) + userAccessRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[types.CreatePermissionResponse](t, reqCtx) + assertCreatePermissionResponseEqual(t, payload, tc.expectedBody.(types.CreatePermissionResponse)) + + permissionsRepo.AssertExpectations(t) + userAccessRepo.AssertExpectations(t) + }) + } +} + +func TestGetPermissionByIDHandler(t *testing.T) { + t.Parallel() + + fixedTime := time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC) + description := new(string) + *description = "read access" + + tests := []struct { + name string + permissionID string + setupMock func(*accesscontroltests.MockPermissionsRepository) + expectedStatus int + expectedBody any + }{ + { + name: "service error", + permissionID: "perm-404", + setupMock: func(m *accesscontroltests.MockPermissionsRepository) { + m.On("GetPermissionByID", mock.Anything, "perm-404").Return((*types.Permission)(nil), constants.ErrNotFound).Once() + }, + expectedStatus: http.StatusNotFound, + expectedBody: map[string]string{"message": "not found"}, + }, + { + name: "success", + permissionID: "perm-1", + setupMock: func(m *accesscontroltests.MockPermissionsRepository) { + m.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ + ID: "perm-1", + Key: "users.read", + Description: description, + IsSystem: false, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: &types.Permission{ + ID: "perm-1", + Key: "users.read", + Description: description, + IsSystem: false, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + permissionsRepo := &accesscontroltests.MockPermissionsRepository{} + userAccessRepo := &accesscontroltests.MockUserAccessRepository{} + if tc.setupMock != nil { + tc.setupMock(permissionsRepo) + } + + useCase := newPermissionsUseCase(permissionsRepo, userAccessRepo) + handler := NewGetPermissionByIDHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/permissions/"+tc.permissionID, nil, nil) + req.SetPathValue("permission_id", tc.permissionID) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + permissionsRepo.AssertExpectations(t) + userAccessRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[types.Permission](t, reqCtx) + assertPermissionEqual(t, payload, *tc.expectedBody.(*types.Permission)) + + permissionsRepo.AssertExpectations(t) + userAccessRepo.AssertExpectations(t) + }) + } +} + +func TestUpdatePermissionHandler(t *testing.T) { + t.Parallel() + + fixedTime := time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC) + updatedDescription := "updated read access" + updatedDescriptionPtr := &updatedDescription + existingDescription := new(string) + *existingDescription = "read access" + + tests := []struct { + name string + permissionID string + body []byte + setupMock func(*accesscontroltests.MockPermissionsRepository) + expectedStatus int + expectedBody any + }{ + { + name: "invalid request body", + permissionID: "perm-1", + body: []byte("{invalid json"), + expectedStatus: http.StatusUnprocessableEntity, + expectedBody: map[string]string{"message": "invalid request body"}, + }, + { + name: "service error", + permissionID: "perm-1", + body: internaltests.MarshalToJSON(t, types.UpdatePermissionRequest{Description: updatedDescriptionPtr}), + setupMock: func(m *accesscontroltests.MockPermissionsRepository) { + m.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read", Description: existingDescription, IsSystem: false}, nil).Once() + m.On("UpdatePermission", mock.Anything, "perm-1", mock.MatchedBy(func(description *string) bool { + return description != nil && *description == *updatedDescriptionPtr + })).Return(false, constants.ErrBadRequest).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedBody: map[string]string{"message": "bad request"}, + }, + { + name: "success", + permissionID: "perm-1", + body: internaltests.MarshalToJSON(t, types.UpdatePermissionRequest{Description: updatedDescriptionPtr}), + setupMock: func(m *accesscontroltests.MockPermissionsRepository) { + m.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read", Description: existingDescription, IsSystem: false}, nil).Once() + m.On("UpdatePermission", mock.Anything, "perm-1", mock.MatchedBy(func(description *string) bool { + return description != nil && *description == *updatedDescriptionPtr + })).Return(true, nil).Once() + m.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ + ID: "perm-1", + Key: "users.read", + Description: updatedDescriptionPtr, + IsSystem: false, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: types.UpdatePermissionResponse{ + Permission: &types.Permission{ + ID: "perm-1", + Key: "users.read", + Description: updatedDescriptionPtr, + IsSystem: false, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + permissionsRepo := &accesscontroltests.MockPermissionsRepository{} + userAccessRepo := &accesscontroltests.MockUserAccessRepository{} + if tc.setupMock != nil { + tc.setupMock(permissionsRepo) + } + + useCase := newPermissionsUseCase(permissionsRepo, userAccessRepo) + handler := NewUpdatePermissionHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPut, "/permissions/"+tc.permissionID, tc.body, nil) + req.SetPathValue("permission_id", tc.permissionID) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + permissionsRepo.AssertExpectations(t) + userAccessRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[types.UpdatePermissionResponse](t, reqCtx) + assertUpdatePermissionResponseEqual(t, payload, tc.expectedBody.(types.UpdatePermissionResponse)) + + permissionsRepo.AssertExpectations(t) + userAccessRepo.AssertExpectations(t) + }) + } +} + +func TestDeletePermissionHandler(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + permissionID string + setupMock func(*accesscontroltests.MockPermissionsRepository, *accesscontroltests.MockUserAccessRepository) + expectedStatus int + expectedBody any + }{ + { + name: "service error", + permissionID: "perm-1", + setupMock: func(m *accesscontroltests.MockPermissionsRepository, userAccessRepo *accesscontroltests.MockUserAccessRepository) { + m.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read", IsSystem: false}, nil).Once() + userAccessRepo.On("CountRoleAssignmentsByPermissionID", mock.Anything, "perm-1").Return(0, nil).Once() + m.On("DeletePermission", mock.Anything, "perm-1").Return(false, errors.New("database error")).Once() + }, + expectedStatus: http.StatusInternalServerError, + expectedBody: map[string]string{"message": "database error"}, + }, + { + name: "success", + permissionID: "perm-1", + setupMock: func(m *accesscontroltests.MockPermissionsRepository, u *accesscontroltests.MockUserAccessRepository) { + m.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read", IsSystem: false}, nil).Once() + u.On("CountRoleAssignmentsByPermissionID", mock.Anything, "perm-1").Return(0, nil).Once() + m.On("DeletePermission", mock.Anything, "perm-1").Return(true, nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: types.DeletePermissionResponse{Message: "permission deleted"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + permissionsRepo := &accesscontroltests.MockPermissionsRepository{} + userAccessRepo := &accesscontroltests.MockUserAccessRepository{} + if tc.setupMock != nil { + tc.setupMock(permissionsRepo, userAccessRepo) + } + + useCase := newPermissionsUseCase(permissionsRepo, userAccessRepo) + handler := NewDeletePermissionHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodDelete, "/permissions/"+tc.permissionID, nil, nil) + req.SetPathValue("permission_id", tc.permissionID) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + permissionsRepo.AssertExpectations(t) + userAccessRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[types.DeletePermissionResponse](t, reqCtx) + assertDeletePermissionResponseEqual(t, payload, tc.expectedBody.(types.DeletePermissionResponse)) + + permissionsRepo.AssertExpectations(t) + userAccessRepo.AssertExpectations(t) + }) + } +} + +func newPermissionsUseCase(permissionsRepo *accesscontroltests.MockPermissionsRepository, userAccessRepo *accesscontroltests.MockUserAccessRepository) *usecases.PermissionsUseCase { + return usecases.NewPermissionsUseCase(services.NewPermissionsService(permissionsRepo, userAccessRepo)) +} + +func assertPermissionsEqual(t *testing.T, got []types.Permission, want []types.Permission) { + t.Helper() + + if len(got) != len(want) { + t.Fatalf("expected %d permissions, got %d", len(want), len(got)) + } + + for i := range want { + assertPermissionEqual(t, got[i], want[i]) + } +} + +func assertPermissionEqual(t *testing.T, got types.Permission, want types.Permission) { + t.Helper() + + if got.ID != want.ID { + t.Fatalf("expected id %q, got %q", want.ID, got.ID) + } + if got.Key != want.Key { + t.Fatalf("expected key %q, got %q", want.Key, got.Key) + } + if got.IsSystem != want.IsSystem { + t.Fatalf("expected is_system %v, got %v", want.IsSystem, got.IsSystem) + } + if !timesEqual(got.CreatedAt, want.CreatedAt) { + t.Fatalf("expected created_at %v, got %v", want.CreatedAt, got.CreatedAt) + } + if !timesEqual(got.UpdatedAt, want.UpdatedAt) { + t.Fatalf("expected updated_at %v, got %v", want.UpdatedAt, got.UpdatedAt) + } + if !stringsEqualPtr(got.Description, want.Description) { + t.Fatalf("expected description %#v, got %#v", want.Description, got.Description) + } +} + +func assertCreatePermissionResponseEqual(t *testing.T, got types.CreatePermissionResponse, want types.CreatePermissionResponse) { + t.Helper() + + if got.Permission == nil || want.Permission == nil { + if got.Permission != want.Permission { + t.Fatalf("expected permission %#v, got %#v", want.Permission, got.Permission) + } + return + } + + assertPermissionEqual(t, *got.Permission, *want.Permission) +} + +func assertUpdatePermissionResponseEqual(t *testing.T, got types.UpdatePermissionResponse, want types.UpdatePermissionResponse) { + t.Helper() + + if got.Permission == nil || want.Permission == nil { + if got.Permission != want.Permission { + t.Fatalf("expected permission %#v, got %#v", want.Permission, got.Permission) + } + return + } + + assertPermissionEqual(t, *got.Permission, *want.Permission) +} + +func assertDeletePermissionResponseEqual(t *testing.T, got types.DeletePermissionResponse, want types.DeletePermissionResponse) { + t.Helper() + + if got.Message != want.Message { + t.Fatalf("expected message %q, got %q", want.Message, got.Message) + } +} + +func timesEqual(left, right time.Time) bool { + return left.Equal(right) +} + +func stringsEqualPtr(left, right *string) bool { + if left == nil || right == nil { + return left == right + } + return *left == *right +} diff --git a/plugins/access-control/handlers/role_handlers.go b/plugins/access-control/handlers/role_handlers.go new file mode 100644 index 00000000..51e45fc6 --- /dev/null +++ b/plugins/access-control/handlers/role_handlers.go @@ -0,0 +1,147 @@ +package handlers + +import ( + "net/http" + + "github.com/Authula/authula/internal/util" + "github.com/Authula/authula/models" + "github.com/Authula/authula/plugins/access-control/types" + "github.com/Authula/authula/plugins/access-control/usecases" +) + +type CreateRoleHandler struct { + useCase *usecases.RolesUseCase +} + +func NewCreateRoleHandler(useCase *usecases.RolesUseCase) *CreateRoleHandler { + return &CreateRoleHandler{useCase: useCase} +} + +func (h *CreateRoleHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + var payload types.CreateRoleRequest + if err := util.ParseJSON(r, &payload); err != nil { + reqCtx.SetJSONResponse(http.StatusUnprocessableEntity, map[string]any{"message": "invalid request body"}) + reqCtx.Handled = true + return + } + + role, err := h.useCase.CreateRole(r.Context(), payload) + if err != nil { + respondRolePermissionError(reqCtx, err) + return + } + + reqCtx.SetJSONResponse(http.StatusCreated, &types.CreateRoleResponse{ + Role: role, + }) + } +} + +type GetAllRolesHandler struct { + useCase *usecases.RolesUseCase +} + +func NewGetAllRolesHandler(useCase *usecases.RolesUseCase) *GetAllRolesHandler { + return &GetAllRolesHandler{useCase: useCase} +} + +func (h *GetAllRolesHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + roles, err := h.useCase.GetAllRoles(r.Context()) + if err != nil { + respondRolePermissionError(reqCtx, err) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, roles) + } +} + +type GetRoleByIDHandler struct { + useCase *usecases.RolesUseCase +} + +func NewGetRoleByIDHandler(useCase *usecases.RolesUseCase) *GetRoleByIDHandler { + return &GetRoleByIDHandler{useCase: useCase} +} + +func (h *GetRoleByIDHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + roleID := r.PathValue("role_id") + + roleDetails, err := h.useCase.GetRoleByID(r.Context(), roleID) + if err != nil { + respondRolePermissionError(reqCtx, err) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, roleDetails) + } +} + +type UpdateRoleHandler struct { + useCase *usecases.RolesUseCase +} + +func NewUpdateRoleHandler(useCase *usecases.RolesUseCase) *UpdateRoleHandler { + return &UpdateRoleHandler{useCase: useCase} +} + +func (h *UpdateRoleHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + roleID := r.PathValue("role_id") + + var payload types.UpdateRoleRequest + if err := util.ParseJSON(r, &payload); err != nil { + reqCtx.SetJSONResponse(http.StatusUnprocessableEntity, map[string]any{"message": "invalid request body"}) + reqCtx.Handled = true + return + } + + role, err := h.useCase.UpdateRole(r.Context(), roleID, payload) + if err != nil { + respondRolePermissionError(reqCtx, err) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, &types.UpdateRoleResponse{ + Role: role, + }) + } +} + +type DeleteRoleHandler struct { + useCase *usecases.RolesUseCase +} + +func NewDeleteRoleHandler(useCase *usecases.RolesUseCase) *DeleteRoleHandler { + return &DeleteRoleHandler{useCase: useCase} +} + +func (h *DeleteRoleHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + roleID := r.PathValue("role_id") + + if err := h.useCase.DeleteRole(r.Context(), roleID); err != nil { + respondRolePermissionError(reqCtx, err) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, &types.DeleteRoleResponse{ + Message: "deleted role", + }) + } +} diff --git a/plugins/access-control/handlers/role_handlers_test.go b/plugins/access-control/handlers/role_handlers_test.go new file mode 100644 index 00000000..2da254a2 --- /dev/null +++ b/plugins/access-control/handlers/role_handlers_test.go @@ -0,0 +1,625 @@ +package handlers + +import ( + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/mock" + + internaltests "github.com/Authula/authula/internal/tests" + "github.com/Authula/authula/plugins/access-control/constants" + "github.com/Authula/authula/plugins/access-control/services" + accesscontroltests "github.com/Authula/authula/plugins/access-control/tests" + "github.com/Authula/authula/plugins/access-control/types" + "github.com/Authula/authula/plugins/access-control/usecases" +) + +func TestCreateRoleHandler(t *testing.T) { + t.Parallel() + + fixedTime := time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC) + description := new(string) + *description = "platform administrator" + + tests := []struct { + name string + body []byte + setupMock func(*accesscontroltests.MockRolesRepository) + expectedStatus int + expectedBody any + }{ + { + name: "invalid request body", + body: []byte("{invalid json"), + expectedStatus: http.StatusUnprocessableEntity, + expectedBody: map[string]string{"message": "invalid request body"}, + }, + { + name: "service error", + body: internaltests.MarshalToJSON(t, types.CreateRoleRequest{ + Name: "Administrator", + Description: description, + IsSystem: true, + }), + setupMock: func(m *accesscontroltests.MockRolesRepository) { + m.On("CreateRole", mock.Anything, mock.MatchedBy(func(role *types.Role) bool { + return role != nil && role.Name == "Administrator" && role.Description != nil && *role.Description == *description && role.IsSystem && role.ID != "" + })).Return(constants.ErrUnauthorized).Once() + }, + expectedStatus: http.StatusUnauthorized, + expectedBody: map[string]string{"message": "unauthorized"}, + }, + { + name: "success", + body: internaltests.MarshalToJSON(t, types.CreateRoleRequest{ + Name: "Administrator", + Description: description, + IsSystem: true, + }), + setupMock: func(m *accesscontroltests.MockRolesRepository) { + m.On("CreateRole", mock.Anything, mock.MatchedBy(func(role *types.Role) bool { + return role != nil && role.Name == "Administrator" && role.Description != nil && *role.Description == *description && role.IsSystem && role.ID != "" + })).Run(func(args mock.Arguments) { + role := args.Get(1).(*types.Role) + role.ID = "role-1" + role.CreatedAt = fixedTime + role.UpdatedAt = fixedTime + }).Return(nil).Once() + }, + expectedStatus: http.StatusCreated, + expectedBody: types.CreateRoleResponse{ + Role: &types.Role{ + ID: "role-1", + Name: "Administrator", + Description: description, + IsSystem: true, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + rolesRepo := &accesscontroltests.MockRolesRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} + userAccessRepo := &accesscontroltests.MockUserAccessRepository{} + if tc.setupMock != nil { + tc.setupMock(rolesRepo) + } + + useCase := newRolesUseCase(rolesRepo, rolePermissionsRepo, userAccessRepo) + handler := NewCreateRoleHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPost, "/roles", tc.body, nil) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusCreated { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + rolesRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + userAccessRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[types.CreateRoleResponse](t, reqCtx) + assertCreateRoleResponseEqual(t, payload, tc.expectedBody.(types.CreateRoleResponse)) + + rolesRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + userAccessRepo.AssertExpectations(t) + }) + } +} + +func TestGetAllRolesHandler(t *testing.T) { + t.Parallel() + + fixedTime := time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC) + description := new(string) + *description = "platform administrator" + + tests := []struct { + name string + setupMock func(*accesscontroltests.MockRolesRepository) + expectedStatus int + expectedBody any + }{ + { + name: "service error", + setupMock: func(m *accesscontroltests.MockRolesRepository) { + m.On("GetAllRoles", mock.Anything).Return(([]types.Role)(nil), constants.ErrForbidden).Once() + }, + expectedStatus: http.StatusForbidden, + expectedBody: map[string]string{"message": "forbidden"}, + }, + { + name: "success", + setupMock: func(m *accesscontroltests.MockRolesRepository) { + m.On("GetAllRoles", mock.Anything).Return([]types.Role{ + { + ID: "role-1", + Name: "Administrator", + Description: description, + IsSystem: true, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, + { + ID: "role-2", + Name: "Editor", + Description: nil, + IsSystem: false, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, + }, nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: []types.Role{ + { + ID: "role-1", + Name: "Administrator", + Description: description, + IsSystem: true, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, + { + ID: "role-2", + Name: "Editor", + Description: nil, + IsSystem: false, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + rolesRepo := &accesscontroltests.MockRolesRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} + userAccessRepo := &accesscontroltests.MockUserAccessRepository{} + if tc.setupMock != nil { + tc.setupMock(rolesRepo) + } + + useCase := newRolesUseCase(rolesRepo, rolePermissionsRepo, userAccessRepo) + handler := NewGetAllRolesHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/roles", nil, nil) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + rolesRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + userAccessRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[[]types.Role](t, reqCtx) + assertRolesEqual(t, payload, tc.expectedBody.([]types.Role)) + + rolesRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + userAccessRepo.AssertExpectations(t) + }) + } +} + +func TestGetRoleByIDHandler(t *testing.T) { + t.Parallel() + + fixedTime := time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC) + description := new(string) + *description = "platform administrator" + grantedAt := fixedTime + grantedByUserID := new(string) + *grantedByUserID = "user-1" + + tests := []struct { + name string + roleID string + setupMock func(*accesscontroltests.MockRolesRepository, *accesscontroltests.MockRolePermissionsRepository) + expectedStatus int + expectedBody any + }{ + { + name: "service error", + roleID: "role-404", + setupMock: func(m *accesscontroltests.MockRolesRepository, _ *accesscontroltests.MockRolePermissionsRepository) { + m.On("GetRoleByID", mock.Anything, "role-404").Return((*types.Role)(nil), constants.ErrNotFound).Once() + }, + expectedStatus: http.StatusNotFound, + expectedBody: map[string]string{"message": "not found"}, + }, + { + name: "success", + roleID: "role-1", + setupMock: func(m *accesscontroltests.MockRolesRepository, rp *accesscontroltests.MockRolePermissionsRepository) { + m.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ + ID: "role-1", + Name: "Administrator", + Description: description, + IsSystem: true, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, nil).Once() + rp.On("GetRolePermissions", mock.Anything, "role-1").Return([]types.UserPermissionInfo{ + { + PermissionID: "perm-1", + PermissionKey: "users.read", + PermissionDescription: description, + GrantedByUserID: grantedByUserID, + GrantedAt: &grantedAt, + }, + }, nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: types.RoleDetails{ + Role: types.Role{ + ID: "role-1", + Name: "Administrator", + Description: description, + IsSystem: true, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, + Permissions: []types.UserPermissionInfo{ + { + PermissionID: "perm-1", + PermissionKey: "users.read", + PermissionDescription: description, + GrantedByUserID: grantedByUserID, + GrantedAt: &grantedAt, + }, + }, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + rolesRepo := &accesscontroltests.MockRolesRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} + userAccessRepo := &accesscontroltests.MockUserAccessRepository{} + if tc.setupMock != nil { + tc.setupMock(rolesRepo, rolePermissionsRepo) + } + + useCase := newRolesUseCase(rolesRepo, rolePermissionsRepo, userAccessRepo) + handler := NewGetRoleByIDHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/roles/"+tc.roleID, nil, nil) + req.SetPathValue("role_id", tc.roleID) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + rolesRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + userAccessRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[types.RoleDetails](t, reqCtx) + assertRoleDetailsEqual(t, payload, tc.expectedBody.(types.RoleDetails)) + + rolesRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + userAccessRepo.AssertExpectations(t) + }) + } +} + +func TestUpdateRoleHandler(t *testing.T) { + t.Parallel() + + fixedTime := time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC) + description := new(string) + *description = "updated administrator" + updatedName := "Administrator" + + tests := []struct { + name string + body []byte + setupMock func(*accesscontroltests.MockRolesRepository) + expectedStatus int + expectedBody any + }{ + { + name: "invalid request body", + body: []byte("{invalid json"), + expectedStatus: http.StatusUnprocessableEntity, + expectedBody: map[string]string{"message": "invalid request body"}, + }, + { + name: "service error", + body: internaltests.MarshalToJSON(t, types.UpdateRoleRequest{ + Name: &updatedName, + Description: description, + }), + setupMock: func(m *accesscontroltests.MockRolesRepository) { + m.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "Administrator", IsSystem: true}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedBody: map[string]string{"message": "cannot update system role"}, + }, + { + name: "success", + body: internaltests.MarshalToJSON(t, types.UpdateRoleRequest{ + Name: &updatedName, + Description: description, + }), + setupMock: func(m *accesscontroltests.MockRolesRepository) { + m.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "Administrator", Description: description, IsSystem: false}, nil).Once() + m.On("UpdateRole", mock.Anything, "role-1", mock.MatchedBy(func(name *string) bool { + return name != nil && *name == "Administrator" + }), mock.MatchedBy(func(desc *string) bool { + return desc != nil && *desc == *description + })).Return(true, nil).Once() + m.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ + ID: "role-1", + Name: "Administrator", + Description: description, + IsSystem: false, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: types.UpdateRoleResponse{ + Role: &types.Role{ + ID: "role-1", + Name: "Administrator", + Description: description, + IsSystem: false, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + rolesRepo := &accesscontroltests.MockRolesRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} + userAccessRepo := &accesscontroltests.MockUserAccessRepository{} + if tc.setupMock != nil { + tc.setupMock(rolesRepo) + } + + useCase := newRolesUseCase(rolesRepo, rolePermissionsRepo, userAccessRepo) + handler := NewUpdateRoleHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPut, "/roles/role-1", tc.body, nil) + req.SetPathValue("role_id", "role-1") + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + rolesRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + userAccessRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[types.UpdateRoleResponse](t, reqCtx) + assertUpdateRoleResponseEqual(t, payload, tc.expectedBody.(types.UpdateRoleResponse)) + + rolesRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + userAccessRepo.AssertExpectations(t) + }) + } +} + +func TestDeleteRoleHandler(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + roleID string + setupMock func(*accesscontroltests.MockRolesRepository, *accesscontroltests.MockUserAccessRepository) + expectedStatus int + expectedBody any + }{ + { + name: "service error", + roleID: "role-1", + setupMock: func(m *accesscontroltests.MockRolesRepository, _ *accesscontroltests.MockUserAccessRepository) { + m.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "Administrator", IsSystem: true}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedBody: map[string]string{"message": "cannot update system role"}, + }, + { + name: "success", + roleID: "role-1", + setupMock: func(m *accesscontroltests.MockRolesRepository, u *accesscontroltests.MockUserAccessRepository) { + m.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "Administrator", IsSystem: false}, nil).Once() + u.On("CountUserAssignmentsByRoleID", mock.Anything, "role-1").Return(0, nil).Once() + m.On("DeleteRole", mock.Anything, "role-1").Return(true, nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: types.DeleteRoleResponse{Message: "deleted role"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + rolesRepo := &accesscontroltests.MockRolesRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} + userAccessRepo := &accesscontroltests.MockUserAccessRepository{} + if tc.setupMock != nil { + tc.setupMock(rolesRepo, userAccessRepo) + } + + useCase := newRolesUseCase(rolesRepo, rolePermissionsRepo, userAccessRepo) + handler := NewDeleteRoleHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodDelete, "/roles/"+tc.roleID, nil, nil) + req.SetPathValue("role_id", tc.roleID) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + rolesRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + userAccessRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[types.DeleteRoleResponse](t, reqCtx) + assertDeleteRoleResponseEqual(t, payload, tc.expectedBody.(types.DeleteRoleResponse)) + + rolesRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + userAccessRepo.AssertExpectations(t) + }) + } +} + +func newRolesUseCase(rolesRepo *accesscontroltests.MockRolesRepository, rolePermissionsRepo *accesscontroltests.MockRolePermissionsRepository, userAccessRepo *accesscontroltests.MockUserAccessRepository) *usecases.RolesUseCase { + return usecases.NewRolesUseCase(services.NewRolesService(rolesRepo, rolePermissionsRepo, userAccessRepo)) +} + +func assertRolesEqual(t *testing.T, got []types.Role, want []types.Role) { + t.Helper() + + if len(got) != len(want) { + t.Fatalf("expected %d roles, got %d", len(want), len(got)) + } + + for i := range want { + assertRoleEqual(t, got[i], want[i]) + } +} + +func assertRoleEqual(t *testing.T, got types.Role, want types.Role) { + t.Helper() + + if got.ID != want.ID { + t.Fatalf("expected id %q, got %q", want.ID, got.ID) + } + if got.Name != want.Name { + t.Fatalf("expected name %q, got %q", want.Name, got.Name) + } + if got.IsSystem != want.IsSystem { + t.Fatalf("expected is_system %v, got %v", want.IsSystem, got.IsSystem) + } + if !timesEqual(got.CreatedAt, want.CreatedAt) { + t.Fatalf("expected created_at %v, got %v", want.CreatedAt, got.CreatedAt) + } + if !timesEqual(got.UpdatedAt, want.UpdatedAt) { + t.Fatalf("expected updated_at %v, got %v", want.UpdatedAt, got.UpdatedAt) + } + if !stringsEqualPtr(got.Description, want.Description) { + t.Fatalf("expected description %#v, got %#v", want.Description, got.Description) + } +} + +func assertRoleDetailsEqual(t *testing.T, got types.RoleDetails, want types.RoleDetails) { + t.Helper() + + assertRoleEqual(t, got.Role, want.Role) + if len(got.Permissions) != len(want.Permissions) { + t.Fatalf("expected %d permissions, got %d", len(want.Permissions), len(got.Permissions)) + } + for i := range want.Permissions { + gotPerm := got.Permissions[i] + wantPerm := want.Permissions[i] + if gotPerm.PermissionID != wantPerm.PermissionID { + t.Fatalf("expected permission id %q, got %q", wantPerm.PermissionID, gotPerm.PermissionID) + } + if gotPerm.PermissionKey != wantPerm.PermissionKey { + t.Fatalf("expected permission key %q, got %q", wantPerm.PermissionKey, gotPerm.PermissionKey) + } + if !stringsEqualPtr(gotPerm.PermissionDescription, wantPerm.PermissionDescription) { + t.Fatalf("expected permission description %#v, got %#v", wantPerm.PermissionDescription, gotPerm.PermissionDescription) + } + if !stringsEqualPtr(gotPerm.GrantedByUserID, wantPerm.GrantedByUserID) { + t.Fatalf("expected granted_by_user_id %#v, got %#v", wantPerm.GrantedByUserID, gotPerm.GrantedByUserID) + } + if !timesEqualPtr(gotPerm.GrantedAt, wantPerm.GrantedAt) { + t.Fatalf("expected granted_at %#v, got %#v", wantPerm.GrantedAt, gotPerm.GrantedAt) + } + } +} + +func assertCreateRoleResponseEqual(t *testing.T, got types.CreateRoleResponse, want types.CreateRoleResponse) { + t.Helper() + + if got.Role == nil || want.Role == nil { + if got.Role != want.Role { + t.Fatalf("expected role %#v, got %#v", want.Role, got.Role) + } + return + } + + assertRoleEqual(t, *got.Role, *want.Role) +} + +func assertUpdateRoleResponseEqual(t *testing.T, got types.UpdateRoleResponse, want types.UpdateRoleResponse) { + t.Helper() + + if got.Role == nil || want.Role == nil { + if got.Role != want.Role { + t.Fatalf("expected role %#v, got %#v", want.Role, got.Role) + } + return + } + + assertRoleEqual(t, *got.Role, *want.Role) +} + +func assertDeleteRoleResponseEqual(t *testing.T, got types.DeleteRoleResponse, want types.DeleteRoleResponse) { + t.Helper() + + if got.Message != want.Message { + t.Fatalf("expected message %q, got %q", want.Message, got.Message) + } +} + +func timesEqualPtr(left, right *time.Time) bool { + if left == nil || right == nil { + return left == right + } + return left.Equal(*right) +} diff --git a/plugins/access-control/handlers/role_permission_handlers.go b/plugins/access-control/handlers/role_permission_handlers.go index 8b8eb271..60c28db6 100644 --- a/plugins/access-control/handlers/role_permission_handlers.go +++ b/plugins/access-control/handlers/role_permission_handlers.go @@ -9,171 +9,11 @@ import ( "github.com/Authula/authula/plugins/access-control/usecases" ) -type CreateRoleHandler struct { - useCase usecases.RolePermissionUseCase -} - -func NewCreateRoleHandler(useCase usecases.RolePermissionUseCase) *CreateRoleHandler { - return &CreateRoleHandler{useCase: useCase} -} - -func (h *CreateRoleHandler) Handler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - reqCtx, _ := models.GetRequestContext(ctx) - - var payload types.CreateRoleRequest - if err := util.ParseJSON(r, &payload); err != nil { - reqCtx.SetJSONResponse(http.StatusUnprocessableEntity, map[string]any{"message": "invalid request body"}) - reqCtx.Handled = true - return - } - - role, err := h.useCase.CreateRole(r.Context(), payload) - if err != nil { - respondRolePermissionError(reqCtx, err) - return - } - - reqCtx.SetJSONResponse(http.StatusCreated, &types.CreateRoleResponse{ - Role: role, - }) - } -} - -type GetAllRolesHandler struct { - useCase usecases.RolePermissionUseCase -} - -func NewGetAllRolesHandler(useCase usecases.RolePermissionUseCase) *GetAllRolesHandler { - return &GetAllRolesHandler{useCase: useCase} -} - -func (h *GetAllRolesHandler) Handler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - reqCtx, _ := models.GetRequestContext(ctx) - - roles, err := h.useCase.GetAllRoles(r.Context()) - if err != nil { - respondRolePermissionError(reqCtx, err) - return - } - - reqCtx.SetJSONResponse(http.StatusOK, roles) - } -} - -type GetRoleByIDHandler struct { - useCase usecases.RolePermissionUseCase -} - -func NewGetRoleByIDHandler(useCase usecases.RolePermissionUseCase) *GetRoleByIDHandler { - return &GetRoleByIDHandler{useCase: useCase} -} - -func (h *GetRoleByIDHandler) Handler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - reqCtx, _ := models.GetRequestContext(ctx) - roleID := r.PathValue("role_id") - - roleDetails, err := h.useCase.GetRoleByID(r.Context(), roleID) - if err != nil { - respondRolePermissionError(reqCtx, err) - return - } - - reqCtx.SetJSONResponse(http.StatusOK, roleDetails) - } -} - -type UpdateRoleHandler struct { - useCase usecases.RolePermissionUseCase -} - -func NewUpdateRoleHandler(useCase usecases.RolePermissionUseCase) *UpdateRoleHandler { - return &UpdateRoleHandler{useCase: useCase} -} - -func (h *UpdateRoleHandler) Handler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - reqCtx, _ := models.GetRequestContext(ctx) - roleID := r.PathValue("role_id") - - var payload types.UpdateRoleRequest - if err := util.ParseJSON(r, &payload); err != nil { - reqCtx.SetJSONResponse(http.StatusUnprocessableEntity, map[string]any{"message": "invalid request body"}) - reqCtx.Handled = true - return - } - - role, err := h.useCase.UpdateRole(r.Context(), roleID, payload) - if err != nil { - respondRolePermissionError(reqCtx, err) - return - } - - reqCtx.SetJSONResponse(http.StatusOK, &types.UpdateRoleResponse{ - Role: role, - }) - } -} - -type DeleteRoleHandler struct { - useCase usecases.RolePermissionUseCase -} - -func NewDeleteRoleHandler(useCase usecases.RolePermissionUseCase) *DeleteRoleHandler { - return &DeleteRoleHandler{useCase: useCase} -} - -func (h *DeleteRoleHandler) Handler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - reqCtx, _ := models.GetRequestContext(ctx) - roleID := r.PathValue("role_id") - - if err := h.useCase.DeleteRole(r.Context(), roleID); err != nil { - respondRolePermissionError(reqCtx, err) - return - } - - reqCtx.SetJSONResponse(http.StatusOK, &types.DeleteRoleResponse{ - Message: "deleted role", - }) - } -} - -type GetAllPermissionsHandler struct { - useCase usecases.RolePermissionUseCase -} - -func NewGetAllPermissionsHandler(useCase usecases.RolePermissionUseCase) *GetAllPermissionsHandler { - return &GetAllPermissionsHandler{useCase: useCase} -} - -func (h *GetAllPermissionsHandler) Handler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - reqCtx, _ := models.GetRequestContext(ctx) - - permissions, err := h.useCase.GetAllPermissions(r.Context()) - if err != nil { - respondRolePermissionError(reqCtx, err) - return - } - - reqCtx.SetJSONResponse(http.StatusOK, permissions) - } -} - type GetRolePermissionsHandler struct { - useCase usecases.RolePermissionUseCase + useCase *usecases.RolePermissionsUseCase } -func NewGetRolePermissionsHandler(useCase usecases.RolePermissionUseCase) *GetRolePermissionsHandler { +func NewGetRolePermissionsHandler(useCase *usecases.RolePermissionsUseCase) *GetRolePermissionsHandler { return &GetRolePermissionsHandler{useCase: useCase} } @@ -193,101 +33,11 @@ func (h *GetRolePermissionsHandler) Handler() http.HandlerFunc { } } -type CreatePermissionHandler struct { - useCase usecases.RolePermissionUseCase -} - -func NewCreatePermissionHandler(useCase usecases.RolePermissionUseCase) *CreatePermissionHandler { - return &CreatePermissionHandler{useCase: useCase} -} - -func (h *CreatePermissionHandler) Handler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - reqCtx, _ := models.GetRequestContext(ctx) - - var payload types.CreatePermissionRequest - if err := util.ParseJSON(r, &payload); err != nil { - reqCtx.SetJSONResponse(http.StatusUnprocessableEntity, map[string]any{"message": "invalid request body"}) - reqCtx.Handled = true - return - } - - permission, err := h.useCase.CreatePermission(r.Context(), payload) - if err != nil { - respondRolePermissionError(reqCtx, err) - return - } - - reqCtx.SetJSONResponse(http.StatusCreated, &types.CreatePermissionResponse{ - Permission: permission, - }) - } -} - -type UpdatePermissionHandler struct { - useCase usecases.RolePermissionUseCase -} - -func NewUpdatePermissionHandler(useCase usecases.RolePermissionUseCase) *UpdatePermissionHandler { - return &UpdatePermissionHandler{useCase: useCase} -} - -func (h *UpdatePermissionHandler) Handler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - reqCtx, _ := models.GetRequestContext(ctx) - permissionID := r.PathValue("permission_id") - - var payload types.UpdatePermissionRequest - if err := util.ParseJSON(r, &payload); err != nil { - reqCtx.SetJSONResponse(http.StatusUnprocessableEntity, map[string]any{"message": "invalid request body"}) - reqCtx.Handled = true - return - } - - permission, err := h.useCase.UpdatePermission(r.Context(), permissionID, payload) - if err != nil { - respondRolePermissionError(reqCtx, err) - return - } - - reqCtx.SetJSONResponse(http.StatusOK, &types.UpdatePermissionResponse{ - Permission: permission, - }) - } -} - -type DeletePermissionHandler struct { - useCase usecases.RolePermissionUseCase -} - -func NewDeletePermissionHandler(useCase usecases.RolePermissionUseCase) *DeletePermissionHandler { - return &DeletePermissionHandler{useCase: useCase} -} - -func (h *DeletePermissionHandler) Handler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - reqCtx, _ := models.GetRequestContext(ctx) - permissionID := r.PathValue("permission_id") - - if err := h.useCase.DeletePermission(r.Context(), permissionID); err != nil { - respondRolePermissionError(reqCtx, err) - return - } - - reqCtx.SetJSONResponse(http.StatusOK, &types.DeletePermissionResponse{ - Message: "permission deleted", - }) - } -} - type AddRolePermissionHandler struct { - useCase usecases.RolePermissionUseCase + useCase *usecases.RolePermissionsUseCase } -func NewAddRolePermissionHandler(useCase usecases.RolePermissionUseCase) *AddRolePermissionHandler { +func NewAddRolePermissionHandler(useCase *usecases.RolePermissionsUseCase) *AddRolePermissionHandler { return &AddRolePermissionHandler{useCase: useCase} } @@ -316,10 +66,10 @@ func (h *AddRolePermissionHandler) Handler() http.HandlerFunc { } type ReplaceRolePermissionsHandler struct { - useCase usecases.RolePermissionUseCase + useCase *usecases.RolePermissionsUseCase } -func NewReplaceRolePermissionsHandler(useCase usecases.RolePermissionUseCase) *ReplaceRolePermissionsHandler { +func NewReplaceRolePermissionsHandler(useCase *usecases.RolePermissionsUseCase) *ReplaceRolePermissionsHandler { return &ReplaceRolePermissionsHandler{useCase: useCase} } @@ -348,10 +98,10 @@ func (h *ReplaceRolePermissionsHandler) Handler() http.HandlerFunc { } type RemoveRolePermissionHandler struct { - useCase usecases.RolePermissionUseCase + useCase *usecases.RolePermissionsUseCase } -func NewRemoveRolePermissionHandler(useCase usecases.RolePermissionUseCase) *RemoveRolePermissionHandler { +func NewRemoveRolePermissionHandler(useCase *usecases.RolePermissionsUseCase) *RemoveRolePermissionHandler { return &RemoveRolePermissionHandler{useCase: useCase} } @@ -372,23 +122,3 @@ func (h *RemoveRolePermissionHandler) Handler() http.HandlerFunc { }) } } - -func rolePermissionActorUserID(reqCtx *models.RequestContext) *string { - if reqCtx == nil || reqCtx.UserID == nil || *reqCtx.UserID == "" { - return nil - } - return reqCtx.UserID -} - -func respondRolePermissionError(reqCtx *models.RequestContext, err error) { - if reqCtx == nil { - return - } - - reqCtx.SetJSONResponse(mapRolePermissionErrorStatus(err), map[string]any{"message": mapHttpErrorMessage(err)}) - reqCtx.Handled = true -} - -func mapRolePermissionErrorStatus(err error) int { - return mapHttpErrorStatus(err) -} diff --git a/plugins/access-control/handlers/role_permission_handlers_test.go b/plugins/access-control/handlers/role_permission_handlers_test.go index aaded8eb..10d6f1dd 100644 --- a/plugins/access-control/handlers/role_permission_handlers_test.go +++ b/plugins/access-control/handlers/role_permission_handlers_test.go @@ -1,660 +1,451 @@ package handlers import ( - "errors" "net/http" "testing" + "time" "github.com/stretchr/testify/mock" internaltests "github.com/Authula/authula/internal/tests" - accesscontrolconstants "github.com/Authula/authula/plugins/access-control/constants" - "github.com/Authula/authula/plugins/access-control/tests" + "github.com/Authula/authula/plugins/access-control/constants" + "github.com/Authula/authula/plugins/access-control/services" + accesscontroltests "github.com/Authula/authula/plugins/access-control/tests" "github.com/Authula/authula/plugins/access-control/types" + "github.com/Authula/authula/plugins/access-control/usecases" ) -func TestCreateRoleHandler(t *testing.T) { +func TestGetRolePermissionsHandler(t *testing.T) { t.Parallel() - t.Run("invalid request body", func(t *testing.T) { - t.Parallel() - - useCase, _ := tests.NewRolePermissionUseCaseFixture() - handler := NewCreateRoleHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPost, "/access-control/roles", []byte("{invalid"), nil) - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusUnprocessableEntity, "invalid request body") - }) - - t.Run("use case error", func(t *testing.T) { - t.Parallel() - - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("CreateRole", mock.Anything, mock.AnythingOfType("*types.Role")).Return(accesscontrolconstants.ErrConflict).Once() - handler := NewCreateRoleHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPost, "/access-control/roles", internaltests.MarshalToJSON(t, types.CreateRoleRequest{Name: "admin"}), nil) - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusConflict, "conflict") - repo.AssertExpectations(t) - }) - - t.Run("success", func(t *testing.T) { - t.Parallel() - - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("CreateRole", mock.Anything, mock.AnythingOfType("*types.Role")).Return(nil).Once() - handler := NewCreateRoleHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPost, "/access-control/roles", internaltests.MarshalToJSON(t, types.CreateRoleRequest{Name: "admin"}), nil) - - handler.Handler()(w, req) - - if reqCtx.ResponseStatus != http.StatusCreated { - t.Fatalf("expected status %d, got %d", http.StatusCreated, reqCtx.ResponseStatus) - } - payload := internaltests.DecodeResponseJSON[types.CreateRoleResponse](t, reqCtx) - if payload.Role == nil { - t.Fatal("expected role key, got nil") - } - repo.AssertExpectations(t) - }) + fixedTime := time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC) + description := new(string) + *description = "read access" + grantedByUserID := new(string) + *grantedByUserID = "user-1" + + tests := []struct { + name string + roleID string + setupMock func(*accesscontroltests.MockRolesRepository, *accesscontroltests.MockRolePermissionsRepository) + expectedStatus int + expectedBody any + }{ + { + name: "blank role id", + roleID: "", + expectedStatus: http.StatusUnprocessableEntity, + expectedBody: map[string]string{"message": "unprocessable entity"}, + }, + { + name: "use case error", + roleID: "role-404", + setupMock: func(rolesRepo *accesscontroltests.MockRolesRepository, _ *accesscontroltests.MockRolePermissionsRepository) { + rolesRepo.On("GetRoleByID", mock.Anything, "role-404").Return((*types.Role)(nil), constants.ErrNotFound).Once() + }, + expectedStatus: http.StatusNotFound, + expectedBody: map[string]string{"message": "not found"}, + }, + { + name: "success", + roleID: "role-1", + setupMock: func(rolesRepo *accesscontroltests.MockRolesRepository, rolePermissionsRepo *accesscontroltests.MockRolePermissionsRepository) { + rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "Administrator"}, nil).Once() + rolePermissionsRepo.On("GetRolePermissions", mock.Anything, "role-1").Return([]types.UserPermissionInfo{{ + PermissionID: "perm-1", + PermissionKey: "users.read", + PermissionDescription: description, + GrantedByUserID: grantedByUserID, + GrantedAt: &fixedTime, + }}, nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: []types.UserPermissionInfo{{ + PermissionID: "perm-1", + PermissionKey: "users.read", + PermissionDescription: description, + GrantedByUserID: grantedByUserID, + GrantedAt: &fixedTime, + }}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + rolesRepo := &accesscontroltests.MockRolesRepository{} + permissionsRepo := &accesscontroltests.MockPermissionsRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} + if tc.setupMock != nil { + tc.setupMock(rolesRepo, rolePermissionsRepo) + } + + useCase := newRolePermissionsUseCase(rolesRepo, permissionsRepo, rolePermissionsRepo) + handler := NewGetRolePermissionsHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/roles/"+tc.roleID+"/permissions", nil, nil) + req.SetPathValue("role_id", tc.roleID) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + rolesRepo.AssertExpectations(t) + permissionsRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[[]types.UserPermissionInfo](t, reqCtx) + assertUserPermissionInfosEqual(t, payload, tc.expectedBody.([]types.UserPermissionInfo)) + + rolesRepo.AssertExpectations(t) + permissionsRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + }) + } } -func TestGetAllRolesHandler(t *testing.T) { +func TestAddRolePermissionHandler(t *testing.T) { t.Parallel() - t.Run("use case error", func(t *testing.T) { - t.Parallel() - - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetAllRoles", mock.Anything).Return(([]types.Role)(nil), errors.New("internal error")).Once() - handler := NewGetAllRolesHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/access-control/roles", nil, nil) - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusInternalServerError, "internal error") - repo.AssertExpectations(t) - }) - - t.Run("success", func(t *testing.T) { - t.Parallel() - - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetAllRoles", mock.Anything).Return([]types.Role{{ID: "role-1", Name: "admin"}}, nil).Once() - handler := NewGetAllRolesHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/access-control/roles", nil, nil) - - handler.Handler()(w, req) - - if reqCtx.ResponseStatus != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, reqCtx.ResponseStatus) - } - payload := internaltests.DecodeResponseJSON[[]types.Role](t, reqCtx) - if payload == nil { - t.Fatalf("expected roles key, got %v", payload) - } - repo.AssertExpectations(t) - }) + grantedByUserID := "user-1" + actorUserID := &grantedByUserID + + tests := []struct { + name string + roleID string + body []byte + userID *string + setupMock func(*accesscontroltests.MockRolesRepository, *accesscontroltests.MockPermissionsRepository, *accesscontroltests.MockRolePermissionsRepository) + expectedStatus int + expectedBody any + }{ + { + name: "invalid request body", + roleID: "role-1", + body: []byte("{invalid json"), + expectedStatus: http.StatusUnprocessableEntity, + expectedBody: map[string]string{"message": "invalid request body"}, + }, + { + name: "use case error", + roleID: "role-1", + body: internaltests.MarshalToJSON(t, types.AddRolePermissionRequest{PermissionID: "perm-1"}), + setupMock: func(rolesRepo *accesscontroltests.MockRolesRepository, permissionsRepo *accesscontroltests.MockPermissionsRepository, rolePermissionsRepo *accesscontroltests.MockRolePermissionsRepository) { + rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "Administrator"}, nil).Once() + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read"}, nil).Once() + rolePermissionsRepo.On("AddRolePermission", mock.Anything, "role-1", "perm-1", (*string)(nil)).Return(constants.ErrUnauthorized).Once() + }, + expectedStatus: http.StatusUnauthorized, + expectedBody: map[string]string{"message": "unauthorized"}, + }, + { + name: "success", + roleID: "role-1", + body: internaltests.MarshalToJSON(t, types.AddRolePermissionRequest{PermissionID: "perm-1"}), + setupMock: func(rolesRepo *accesscontroltests.MockRolesRepository, permissionsRepo *accesscontroltests.MockPermissionsRepository, rolePermissionsRepo *accesscontroltests.MockRolePermissionsRepository) { + rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "Administrator"}, nil).Once() + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read"}, nil).Once() + rolePermissionsRepo.On("AddRolePermission", mock.Anything, "role-1", "perm-1", (*string)(nil)).Return(nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: types.AddRolePermissionResponse{Message: "permission assigned to role"}, + }, + { + name: "success with actor user id", + roleID: "role-1", + body: internaltests.MarshalToJSON(t, types.AddRolePermissionRequest{PermissionID: "perm-2"}), + userID: actorUserID, + setupMock: func(rolesRepo *accesscontroltests.MockRolesRepository, permissionsRepo *accesscontroltests.MockPermissionsRepository, rolePermissionsRepo *accesscontroltests.MockRolePermissionsRepository) { + rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "Administrator"}, nil).Once() + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-2").Return(&types.Permission{ID: "perm-2", Key: "users.write"}, nil).Once() + rolePermissionsRepo.On("AddRolePermission", mock.Anything, "role-1", "perm-2", actorUserID).Return(nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: types.AddRolePermissionResponse{Message: "permission assigned to role"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + rolesRepo := &accesscontroltests.MockRolesRepository{} + permissionsRepo := &accesscontroltests.MockPermissionsRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} + if tc.setupMock != nil { + tc.setupMock(rolesRepo, permissionsRepo, rolePermissionsRepo) + } + + useCase := newRolePermissionsUseCase(rolesRepo, permissionsRepo, rolePermissionsRepo) + handler := NewAddRolePermissionHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPost, "/roles/"+tc.roleID+"/permissions", tc.body, tc.userID) + req.SetPathValue("role_id", tc.roleID) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + rolesRepo.AssertExpectations(t) + permissionsRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[types.AddRolePermissionResponse](t, reqCtx) + assertAddRolePermissionResponseEqual(t, payload, tc.expectedBody.(types.AddRolePermissionResponse)) + + rolesRepo.AssertExpectations(t) + permissionsRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + }) + } } -func TestGetRoleByIDHandler(t *testing.T) { +func TestReplaceRolePermissionsHandler(t *testing.T) { t.Parallel() - t.Run("not found", func(t *testing.T) { - t.Parallel() - - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetRoleByID", mock.Anything, "role-1").Return((*types.Role)(nil), nil).Once() - handler := NewGetRoleByIDHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/access-control/roles/role-1", nil, nil) - req.SetPathValue("role_id", "role-1") - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusNotFound, "not found") - repo.AssertExpectations(t) - }) - - t.Run("success", func(t *testing.T) { - t.Parallel() - - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "admin"}, nil).Once() - repo.On("GetRolePermissions", mock.Anything, "role-1").Return([]types.UserPermissionInfo{{PermissionID: "perm-1", PermissionKey: "users.read"}}, nil).Once() - handler := NewGetRoleByIDHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/access-control/roles/role-1", nil, nil) - req.SetPathValue("role_id", "role-1") - - handler.Handler()(w, req) - - if reqCtx.ResponseStatus != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, reqCtx.ResponseStatus) - } - payload := internaltests.DecodeResponseJSON[*types.RoleDetails](t, reqCtx) - if payload == nil { - t.Fatalf("expected role details, got %v", payload) - } - repo.AssertExpectations(t) - }) + actorUserID := "user-1" + actorUserIDPtr := &actorUserID + + tests := []struct { + name string + roleID string + body []byte + userID *string + setupMock func(*accesscontroltests.MockRolesRepository, *accesscontroltests.MockPermissionsRepository, *accesscontroltests.MockRolePermissionsRepository) + expectedStatus int + expectedBody any + }{ + { + name: "invalid request body", + roleID: "role-1", + body: []byte("{invalid json"), + expectedStatus: http.StatusUnprocessableEntity, + expectedBody: map[string]string{"message": "invalid request body"}, + }, + { + name: "use case error", + roleID: "role-1", + body: internaltests.MarshalToJSON(t, types.ReplaceRolePermissionsRequest{PermissionIDs: []string{"perm-1", "perm-2"}}), + setupMock: func(rolesRepo *accesscontroltests.MockRolesRepository, permissionsRepo *accesscontroltests.MockPermissionsRepository, rolePermissionsRepo *accesscontroltests.MockRolePermissionsRepository) { + rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "Administrator"}, nil).Once() + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read"}, nil).Once() + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-2").Return(&types.Permission{ID: "perm-2", Key: "users.write"}, nil).Once() + rolePermissionsRepo.On("ReplaceRolePermissions", mock.Anything, "role-1", []string{"perm-1", "perm-2"}, (*string)(nil)).Return(constants.ErrConflict).Once() + }, + expectedStatus: http.StatusConflict, + expectedBody: map[string]string{"message": "conflict"}, + }, + { + name: "success", + roleID: "role-1", + body: internaltests.MarshalToJSON(t, types.ReplaceRolePermissionsRequest{PermissionIDs: []string{"perm-1", "perm-2"}}), + setupMock: func(rolesRepo *accesscontroltests.MockRolesRepository, permissionsRepo *accesscontroltests.MockPermissionsRepository, rolePermissionsRepo *accesscontroltests.MockRolePermissionsRepository) { + rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "Administrator"}, nil).Once() + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read"}, nil).Once() + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-2").Return(&types.Permission{ID: "perm-2", Key: "users.write"}, nil).Once() + rolePermissionsRepo.On("ReplaceRolePermissions", mock.Anything, "role-1", []string{"perm-1", "perm-2"}, (*string)(nil)).Return(nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: types.ReplaceRolePermissionResponse{Message: "role permissions replaced"}, + }, + { + name: "success with actor user id", + roleID: "role-1", + body: internaltests.MarshalToJSON(t, types.ReplaceRolePermissionsRequest{PermissionIDs: []string{"perm-3", "perm-4"}}), + userID: actorUserIDPtr, + setupMock: func(rolesRepo *accesscontroltests.MockRolesRepository, permissionsRepo *accesscontroltests.MockPermissionsRepository, rolePermissionsRepo *accesscontroltests.MockRolePermissionsRepository) { + rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "Administrator"}, nil).Once() + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-3").Return(&types.Permission{ID: "perm-3", Key: "users.create"}, nil).Once() + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-4").Return(&types.Permission{ID: "perm-4", Key: "users.update"}, nil).Once() + rolePermissionsRepo.On("ReplaceRolePermissions", mock.Anything, "role-1", []string{"perm-3", "perm-4"}, actorUserIDPtr).Return(nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: types.ReplaceRolePermissionResponse{Message: "role permissions replaced"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + rolesRepo := &accesscontroltests.MockRolesRepository{} + permissionsRepo := &accesscontroltests.MockPermissionsRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} + if tc.setupMock != nil { + tc.setupMock(rolesRepo, permissionsRepo, rolePermissionsRepo) + } + + useCase := newRolePermissionsUseCase(rolesRepo, permissionsRepo, rolePermissionsRepo) + handler := NewReplaceRolePermissionsHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPut, "/roles/"+tc.roleID+"/permissions", tc.body, tc.userID) + req.SetPathValue("role_id", tc.roleID) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + rolesRepo.AssertExpectations(t) + permissionsRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[types.ReplaceRolePermissionResponse](t, reqCtx) + assertReplaceRolePermissionResponseEqual(t, payload, tc.expectedBody.(types.ReplaceRolePermissionResponse)) + + rolesRepo.AssertExpectations(t) + permissionsRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + }) + } } -func TestUpdateRoleHandler(t *testing.T) { +func TestRemoveRolePermissionHandler(t *testing.T) { t.Parallel() - t.Run("invalid payload", func(t *testing.T) { - t.Parallel() - - useCase, _ := tests.NewRolePermissionUseCaseFixture() - handler := NewUpdateRoleHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPatch, "/access-control/roles/role-1", []byte("{invalid"), nil) - req.SetPathValue("role_id", "role-1") - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusUnprocessableEntity, "invalid request body") - }) - - t.Run("use case error", func(t *testing.T) { - t.Parallel() - - name := "new-admin" - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetRoleByID", mock.Anything, "role-1").Return((*types.Role)(nil), nil).Once() - handler := NewUpdateRoleHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPatch, "/access-control/roles/role-1", internaltests.MarshalToJSON(t, types.UpdateRoleRequest{Name: &name}), nil) - req.SetPathValue("role_id", "role-1") - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusNotFound, "not found") - repo.AssertExpectations(t) - }) - - t.Run("cannot update system role", func(t *testing.T) { - t.Parallel() - - name := "new-admin" - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "admin", IsSystem: true}, nil).Once() - handler := NewUpdateRoleHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPatch, "/access-control/roles/role-1", internaltests.MarshalToJSON(t, types.UpdateRoleRequest{Name: &name}), nil) - req.SetPathValue("role_id", "role-1") - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusForbidden, "cannot update system role") - repo.AssertExpectations(t) - }) - - t.Run("success", func(t *testing.T) { - t.Parallel() - - name := "new-admin" - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "admin"}, nil).Once() - repo.On("UpdateRole", mock.Anything, "role-1", &name, (*string)(nil)).Return(true, nil).Once() - repo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: name}, nil).Once() - handler := NewUpdateRoleHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPatch, "/access-control/roles/role-1", internaltests.MarshalToJSON(t, types.UpdateRoleRequest{Name: &name}), nil) - req.SetPathValue("role_id", "role-1") - - handler.Handler()(w, req) - - if reqCtx.ResponseStatus != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, reqCtx.ResponseStatus) - } - payload := internaltests.DecodeResponseJSON[types.UpdateRoleResponse](t, reqCtx) - if payload.Role == nil { - t.Fatalf("expected role key, got %v", payload) - } - repo.AssertExpectations(t) - }) + tests := []struct { + name string + roleID string + permissionID string + setupMock func(*accesscontroltests.MockRolesRepository, *accesscontroltests.MockPermissionsRepository, *accesscontroltests.MockRolePermissionsRepository) + expectedStatus int + expectedBody any + }{ + { + name: "use case error", + roleID: "role-1", + permissionID: "perm-1", + setupMock: func(rolesRepo *accesscontroltests.MockRolesRepository, permissionsRepo *accesscontroltests.MockPermissionsRepository, rolePermissionsRepo *accesscontroltests.MockRolePermissionsRepository) { + rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "Administrator"}, nil).Once() + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read"}, nil).Once() + rolePermissionsRepo.On("RemoveRolePermission", mock.Anything, "role-1", "perm-1").Return(constants.ErrNotFound).Once() + }, + expectedStatus: http.StatusNotFound, + expectedBody: map[string]string{"message": "not found"}, + }, + { + name: "success", + roleID: "role-1", + permissionID: "perm-1", + setupMock: func(rolesRepo *accesscontroltests.MockRolesRepository, permissionsRepo *accesscontroltests.MockPermissionsRepository, rolePermissionsRepo *accesscontroltests.MockRolePermissionsRepository) { + rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "Administrator"}, nil).Once() + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read"}, nil).Once() + rolePermissionsRepo.On("RemoveRolePermission", mock.Anything, "role-1", "perm-1").Return(nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: types.RemoveRolePermissionResponse{Message: "permission removed from role"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + rolesRepo := &accesscontroltests.MockRolesRepository{} + permissionsRepo := &accesscontroltests.MockPermissionsRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} + if tc.setupMock != nil { + tc.setupMock(rolesRepo, permissionsRepo, rolePermissionsRepo) + } + + useCase := newRolePermissionsUseCase(rolesRepo, permissionsRepo, rolePermissionsRepo) + handler := NewRemoveRolePermissionHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodDelete, "/roles/"+tc.roleID+"/permissions/"+tc.permissionID, nil, nil) + req.SetPathValue("role_id", tc.roleID) + req.SetPathValue("permission_id", tc.permissionID) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + rolesRepo.AssertExpectations(t) + permissionsRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[types.RemoveRolePermissionResponse](t, reqCtx) + assertRemoveRolePermissionResponseEqual(t, payload, tc.expectedBody.(types.RemoveRolePermissionResponse)) + + rolesRepo.AssertExpectations(t) + permissionsRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + }) + } } -func TestDeleteRoleHandler(t *testing.T) { - t.Parallel() - - t.Run("conflict when assigned", func(t *testing.T) { - t.Parallel() - - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "admin"}, nil).Once() - repo.On("CountUserAssignmentsByRoleID", mock.Anything, "role-1").Return(1, nil).Once() - handler := NewDeleteRoleHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodDelete, "/access-control/roles/role-1", nil, nil) - req.SetPathValue("role_id", "role-1") - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusConflict, "conflict") - repo.AssertExpectations(t) - }) - - t.Run("success", func(t *testing.T) { - t.Parallel() - - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "admin"}, nil).Once() - repo.On("CountUserAssignmentsByRoleID", mock.Anything, "role-1").Return(0, nil).Once() - repo.On("DeleteRole", mock.Anything, "role-1").Return(true, nil).Once() - handler := NewDeleteRoleHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodDelete, "/access-control/roles/role-1", nil, nil) - req.SetPathValue("role_id", "role-1") - - handler.Handler()(w, req) - - if reqCtx.ResponseStatus != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, reqCtx.ResponseStatus) - } - payload := internaltests.DecodeResponseJSON[types.DeleteRoleResponse](t, reqCtx) - if payload.Message != "deleted role" { - t.Fatalf("expected deleted role message, got %v", payload.Message) - } - repo.AssertExpectations(t) - }) +func newRolePermissionsUseCase(rolesRepo *accesscontroltests.MockRolesRepository, permissionsRepo *accesscontroltests.MockPermissionsRepository, rolePermissionsRepo *accesscontroltests.MockRolePermissionsRepository) *usecases.RolePermissionsUseCase { + return usecases.NewRolePermissionsUseCase(services.NewRolePermissionsService(rolesRepo, permissionsRepo, rolePermissionsRepo)) } -func TestGetAllPermissionsHandler(t *testing.T) { - t.Parallel() - - t.Run("use case error", func(t *testing.T) { - t.Parallel() - - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetAllPermissions", mock.Anything).Return(([]types.Permission)(nil), accesscontrolconstants.ErrForbidden).Once() - handler := NewGetAllPermissionsHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/access-control/permissions", nil, nil) - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusForbidden, "forbidden") - repo.AssertExpectations(t) - }) - - t.Run("success", func(t *testing.T) { - t.Parallel() +func assertUserPermissionInfosEqual(t *testing.T, got []types.UserPermissionInfo, want []types.UserPermissionInfo) { + t.Helper() - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetAllPermissions", mock.Anything).Return([]types.Permission{{ID: "perm-1", Key: "users.read"}}, nil).Once() - handler := NewGetAllPermissionsHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/access-control/permissions", nil, nil) + if len(got) != len(want) { + t.Fatalf("expected %d permissions, got %d", len(want), len(got)) + } - handler.Handler()(w, req) - - if reqCtx.ResponseStatus != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, reqCtx.ResponseStatus) - } - payload := internaltests.DecodeResponseJSON[[]types.Permission](t, reqCtx) - if payload == nil { - t.Fatalf("expected permissions key, got %v", payload) - } - repo.AssertExpectations(t) - }) + for i := range want { + assertUserPermissionInfoEqual(t, got[i], want[i]) + } } -func TestGetRolePermissionsHandler(t *testing.T) { - t.Parallel() - - t.Run("use case error", func(t *testing.T) { - t.Parallel() - - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetRoleByID", mock.Anything, "role-1").Return((*types.Role)(nil), nil).Once() - handler := NewGetRolePermissionsHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/access-control/roles/role-1/permissions", nil, nil) - req.SetPathValue("role_id", "role-1") - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusNotFound, "not found") - repo.AssertExpectations(t) - }) - - t.Run("invalid role id", func(t *testing.T) { - t.Parallel() - - useCase, _ := tests.NewRolePermissionUseCaseFixture() - handler := NewGetRolePermissionsHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/access-control/roles/%20%20%20/permissions", nil, nil) - req.SetPathValue("role_id", " ") - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusUnprocessableEntity, "unprocessable entity") - }) - - t.Run("success", func(t *testing.T) { - t.Parallel() - - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "admin"}, nil).Once() - repo.On("GetRolePermissions", mock.Anything, "role-1").Return([]types.UserPermissionInfo{{PermissionID: "perm-1", PermissionKey: "users.read"}}, nil).Once() - handler := NewGetRolePermissionsHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/access-control/roles/role-1/permissions", nil, nil) - req.SetPathValue("role_id", "role-1") - - handler.Handler()(w, req) - - if reqCtx.ResponseStatus != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, reqCtx.ResponseStatus) - } - payload := internaltests.DecodeResponseJSON[[]types.UserPermissionInfo](t, reqCtx) - if len(payload) != 1 || payload[0].PermissionID != "perm-1" { - t.Fatalf("expected role permissions payload, got %v", payload) - } - repo.AssertExpectations(t) - }) -} - -func TestCreatePermissionHandler(t *testing.T) { - t.Parallel() - - t.Run("invalid payload", func(t *testing.T) { - t.Parallel() - - useCase, _ := tests.NewRolePermissionUseCaseFixture() - handler := NewCreatePermissionHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPost, "/access-control/permissions", []byte("{invalid"), nil) - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusUnprocessableEntity, "invalid request body") - }) - - t.Run("use case error", func(t *testing.T) { - t.Parallel() - - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("CreatePermission", mock.Anything, mock.AnythingOfType("*types.Permission")).Return(accesscontrolconstants.ErrBadRequest).Once() - handler := NewCreatePermissionHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPost, "/access-control/permissions", internaltests.MarshalToJSON(t, types.CreatePermissionRequest{Key: "users.read"}), nil) - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusBadRequest, "bad request") - repo.AssertExpectations(t) - }) - - t.Run("success", func(t *testing.T) { - t.Parallel() - - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("CreatePermission", mock.Anything, mock.AnythingOfType("*types.Permission")).Return(nil).Once() - handler := NewCreatePermissionHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPost, "/access-control/permissions", internaltests.MarshalToJSON(t, types.CreatePermissionRequest{Key: "users.read"}), nil) - - handler.Handler()(w, req) - - if reqCtx.ResponseStatus != http.StatusCreated { - t.Fatalf("expected status %d, got %d", http.StatusCreated, reqCtx.ResponseStatus) - } - payload := internaltests.DecodeResponseJSON[types.CreatePermissionResponse](t, reqCtx) - if payload.Permission == nil { - t.Fatalf("expected permission key, got %v", payload) - } - repo.AssertExpectations(t) - }) +func assertUserPermissionInfoEqual(t *testing.T, got types.UserPermissionInfo, want types.UserPermissionInfo) { + t.Helper() + + if got.PermissionID != want.PermissionID { + t.Fatalf("expected permission id %q, got %q", want.PermissionID, got.PermissionID) + } + if got.PermissionKey != want.PermissionKey { + t.Fatalf("expected permission key %q, got %q", want.PermissionKey, got.PermissionKey) + } + if !stringsEqualPtr(got.PermissionDescription, want.PermissionDescription) { + t.Fatalf("expected permission description %#v, got %#v", want.PermissionDescription, got.PermissionDescription) + } + if !stringsEqualPtr(got.GrantedByUserID, want.GrantedByUserID) { + t.Fatalf("expected granted_by_user_id %#v, got %#v", want.GrantedByUserID, got.GrantedByUserID) + } + if !timesEqualPtr(got.GrantedAt, want.GrantedAt) { + t.Fatalf("expected granted_at %#v, got %#v", want.GrantedAt, got.GrantedAt) + } } -func TestUpdatePermissionHandler(t *testing.T) { - t.Parallel() - - t.Run("invalid payload", func(t *testing.T) { - t.Parallel() - - useCase, _ := tests.NewRolePermissionUseCaseFixture() - handler := NewUpdatePermissionHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPut, "/access-control/permissions/perm-1", []byte("{invalid"), nil) - req.SetPathValue("permission_id", "perm-1") - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusUnprocessableEntity, "invalid request body") - }) - - t.Run("use case error", func(t *testing.T) { - t.Parallel() - - desc := "updated" - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetPermissionByID", mock.Anything, "perm-1").Return((*types.Permission)(nil), nil).Once() - handler := NewUpdatePermissionHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPut, "/access-control/permissions/perm-1", internaltests.MarshalToJSON(t, types.UpdatePermissionRequest{Description: &desc}), nil) - req.SetPathValue("permission_id", "perm-1") - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusNotFound, "not found") - repo.AssertExpectations(t) - }) - - t.Run("success", func(t *testing.T) { - t.Parallel() - - desc := "updated" - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read"}, nil).Once() - repo.On("UpdatePermission", mock.Anything, "perm-1", &desc).Return(true, nil).Once() - repo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read", Description: &desc}, nil).Once() - handler := NewUpdatePermissionHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPut, "/access-control/permissions/perm-1", internaltests.MarshalToJSON(t, types.UpdatePermissionRequest{Description: &desc}), nil) - req.SetPathValue("permission_id", "perm-1") - - handler.Handler()(w, req) - - if reqCtx.ResponseStatus != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, reqCtx.ResponseStatus) - } - payload := internaltests.DecodeResponseJSON[types.UpdatePermissionResponse](t, reqCtx) - if payload.Permission == nil { - t.Fatalf("expected permission key, got %v", payload) - } - repo.AssertExpectations(t) - }) +func assertAddRolePermissionResponseEqual(t *testing.T, got types.AddRolePermissionResponse, want types.AddRolePermissionResponse) { + t.Helper() + if got.Message != want.Message { + t.Fatalf("expected message %q, got %q", want.Message, got.Message) + } } -func TestDeletePermissionHandler(t *testing.T) { - t.Parallel() - - t.Run("in use conflict", func(t *testing.T) { - t.Parallel() - - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read"}, nil).Once() - repo.On("CountRoleAssignmentsByPermissionID", mock.Anything, "perm-1").Return(1, nil).Once() - handler := NewDeletePermissionHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodDelete, "/access-control/permissions/perm-1", nil, nil) - req.SetPathValue("permission_id", "perm-1") - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusConflict, "conflict") - repo.AssertExpectations(t) - }) - - t.Run("success", func(t *testing.T) { - t.Parallel() - - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "admin.read"}, nil).Once() - repo.On("CountRoleAssignmentsByPermissionID", mock.Anything, "perm-1").Return(0, nil).Once() - repo.On("DeletePermission", mock.Anything, "perm-1").Return(true, nil).Once() - handler := NewDeletePermissionHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodDelete, "/access-control/permissions/perm-1", nil, nil) - req.SetPathValue("permission_id", "perm-1") - - handler.Handler()(w, req) - - if reqCtx.ResponseStatus != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, reqCtx.ResponseStatus) - } - payload := internaltests.DecodeResponseJSON[map[string]any](t, reqCtx) - if payload["message"] != "permission deleted" { - t.Fatalf("expected permission deleted message, got %v", payload["message"]) - } - repo.AssertExpectations(t) - }) +func assertReplaceRolePermissionResponseEqual(t *testing.T, got types.ReplaceRolePermissionResponse, want types.ReplaceRolePermissionResponse) { + t.Helper() + if got.Message != want.Message { + t.Fatalf("expected message %q, got %q", want.Message, got.Message) + } } -func TestAddRolePermissionHandler(t *testing.T) { - t.Parallel() - - t.Run("invalid request body", func(t *testing.T) { - t.Parallel() - - useCase, _ := tests.NewRolePermissionUseCaseFixture() - handler := NewAddRolePermissionHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPost, "/access-control/roles/role-1/permissions", []byte("{invalid"), nil) - req.SetPathValue("role_id", "role-1") - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusUnprocessableEntity, "invalid request body") - }) - - t.Run("use case error", func(t *testing.T) { - t.Parallel() - - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetRoleByID", mock.Anything, "role-1").Return((*types.Role)(nil), nil).Once() - handler := NewAddRolePermissionHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPost, "/access-control/roles/role-1/permissions", internaltests.MarshalToJSON(t, types.AddRolePermissionRequest{PermissionID: "perm-1"}), nil) - req.SetPathValue("role_id", "role-1") - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusNotFound, "not found") - repo.AssertExpectations(t) - }) - - t.Run("success", func(t *testing.T) { - t.Parallel() - - actorID := "actor-1" - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "admin"}, nil).Once() - repo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read"}, nil).Once() - repo.On("AddRolePermission", mock.Anything, "role-1", "perm-1", &actorID).Return(nil).Once() - handler := NewAddRolePermissionHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPost, "/access-control/roles/role-1/permissions", internaltests.MarshalToJSON(t, types.AddRolePermissionRequest{PermissionID: "perm-1"}), nil) - req.SetPathValue("role_id", "role-1") - reqCtx.UserID = &actorID - - handler.Handler()(w, req) - - if reqCtx.ResponseStatus != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, reqCtx.ResponseStatus) - } - payload := internaltests.DecodeResponseJSON[map[string]any](t, reqCtx) - if payload["message"] != "permission assigned to role" { - t.Fatalf("expected permission assigned to role message, got %v", payload["message"]) - } - repo.AssertExpectations(t) - }) -} - -func TestReplaceRolePermissionsHandler(t *testing.T) { - t.Parallel() - - t.Run("invalid payload", func(t *testing.T) { - t.Parallel() - - useCase, _ := tests.NewRolePermissionUseCaseFixture() - handler := NewReplaceRolePermissionsHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPut, "/access-control/roles/role-1/permissions", []byte("{invalid"), nil) - req.SetPathValue("role_id", "role-1") - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusUnprocessableEntity, "invalid request body") - }) - - t.Run("use case error", func(t *testing.T) { - t.Parallel() - - actorID := "actor-1" - request := types.ReplaceRolePermissionsRequest{PermissionIDs: []string{"perm-1"}} - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("ReplaceRolePermissions", mock.Anything, "role-1", request.PermissionIDs, &actorID).Return(accesscontrolconstants.ErrForbidden).Once() - handler := NewReplaceRolePermissionsHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPut, "/access-control/roles/role-1/permissions", internaltests.MarshalToJSON(t, request), nil) - req.SetPathValue("role_id", "role-1") - reqCtx.UserID = &actorID - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusForbidden, "forbidden") - repo.AssertExpectations(t) - }) - - t.Run("success", func(t *testing.T) { - t.Parallel() - - request := types.ReplaceRolePermissionsRequest{PermissionIDs: []string{"perm-1", "perm-1", "perm-2"}} - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("ReplaceRolePermissions", mock.Anything, "role-1", []string{"perm-1", "perm-2"}, (*string)(nil)).Return(nil).Once() - handler := NewReplaceRolePermissionsHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPut, "/access-control/roles/role-1/permissions", internaltests.MarshalToJSON(t, request), nil) - req.SetPathValue("role_id", "role-1") - - handler.Handler()(w, req) - - if reqCtx.ResponseStatus != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, reqCtx.ResponseStatus) - } - payload := internaltests.DecodeResponseJSON[map[string]any](t, reqCtx) - if payload["message"] != "role permissions replaced" { - t.Fatalf("expected role permissions replaced message, got %v", payload["message"]) - } - repo.AssertExpectations(t) - }) -} - -func TestRemoveRolePermissionHandler(t *testing.T) { - t.Parallel() - - t.Run("use case error", func(t *testing.T) { - t.Parallel() - - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "admin"}, nil).Once() - repo.On("GetPermissionByID", mock.Anything, "perm-1").Return((*types.Permission)(nil), nil).Once() - handler := NewRemoveRolePermissionHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodDelete, "/access-control/roles/role-1/permissions/perm-1", nil, nil) - req.SetPathValue("role_id", "role-1") - req.SetPathValue("permission_id", "perm-1") - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusNotFound, "not found") - repo.AssertExpectations(t) - }) - - t.Run("success", func(t *testing.T) { - t.Parallel() - - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "admin"}, nil).Once() - repo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "admin.read"}, nil).Once() - repo.On("RemoveRolePermission", mock.Anything, "role-1", "perm-1").Return(nil).Once() - handler := NewRemoveRolePermissionHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodDelete, "/access-control/roles/role-1/permissions/perm-1", nil, nil) - req.SetPathValue("role_id", "role-1") - req.SetPathValue("permission_id", "perm-1") - - handler.Handler()(w, req) - - if reqCtx.ResponseStatus != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, reqCtx.ResponseStatus) - } - payload := internaltests.DecodeResponseJSON[map[string]any](t, reqCtx) - if payload["message"] != "permission removed from role" { - t.Fatalf("expected permission removed from role message, got %v", payload["message"]) - } - repo.AssertExpectations(t) - }) +func assertRemoveRolePermissionResponseEqual(t *testing.T, got types.RemoveRolePermissionResponse, want types.RemoveRolePermissionResponse) { + t.Helper() + if got.Message != want.Message { + t.Fatalf("expected message %q, got %q", want.Message, got.Message) + } } diff --git a/plugins/access-control/handlers/shared_helpers.go b/plugins/access-control/handlers/shared_helpers.go new file mode 100644 index 00000000..d4f9f79a --- /dev/null +++ b/plugins/access-control/handlers/shared_helpers.go @@ -0,0 +1,35 @@ +package handlers + +import "github.com/Authula/authula/models" + +func rolePermissionActorUserID(reqCtx *models.RequestContext) *string { + if reqCtx.UserID == nil || *reqCtx.UserID == "" { + return nil + } + return reqCtx.UserID +} + +func respondRolePermissionError(reqCtx *models.RequestContext, err error) { + reqCtx.SetJSONResponse(mapRolePermissionErrorStatus(err), map[string]any{"message": mapHttpErrorMessage(err)}) + reqCtx.Handled = true +} + +func mapRolePermissionErrorStatus(err error) int { + return mapHttpErrorStatus(err) +} + +func userActorUserID(reqCtx *models.RequestContext) *string { + if reqCtx.UserID == nil || *reqCtx.UserID == "" { + return nil + } + return reqCtx.UserID +} + +func respondUserHandlerError(reqCtx *models.RequestContext, err error) { + reqCtx.SetJSONResponse(mapUserHandlerErrorStatus(err), map[string]any{"message": mapHttpErrorMessage(err)}) + reqCtx.Handled = true +} + +func mapUserHandlerErrorStatus(err error) int { + return mapHttpErrorStatus(err) +} diff --git a/plugins/access-control/handlers/user_access_handlers.go b/plugins/access-control/handlers/user_access_handlers.go new file mode 100644 index 00000000..925694a0 --- /dev/null +++ b/plugins/access-control/handlers/user_access_handlers.go @@ -0,0 +1,59 @@ +package handlers + +import ( + "net/http" + + "github.com/Authula/authula/models" + "github.com/Authula/authula/plugins/access-control/types" + "github.com/Authula/authula/plugins/access-control/usecases" +) + +type GetUserEffectivePermissionsHandler struct { + useCase *usecases.UserAccessUseCase +} + +func NewGetUserEffectivePermissionsHandler(useCase *usecases.UserAccessUseCase) *GetUserEffectivePermissionsHandler { + return &GetUserEffectivePermissionsHandler{ + useCase: useCase, + } +} + +func (h *GetUserEffectivePermissionsHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + userID := r.PathValue("user_id") + + permissions, err := h.useCase.GetUserEffectivePermissions(r.Context(), userID) + if err != nil { + respondUserHandlerError(reqCtx, err) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, &types.GetUserEffectivePermissionsResponse{Permissions: permissions}) + } +} + +type GetUserAuthorizationProfileHandler struct { + useCase *usecases.UserAccessUseCase +} + +func NewGetUserAuthorizationProfileHandler(useCase *usecases.UserAccessUseCase) *GetUserAuthorizationProfileHandler { + return &GetUserAuthorizationProfileHandler{useCase: useCase} +} + +func (h *GetUserAuthorizationProfileHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + userID := r.PathValue("user_id") + + profile, err := h.useCase.GetUserAuthorizationProfile(r.Context(), userID) + if err != nil { + respondUserHandlerError(reqCtx, err) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, profile) + } +} diff --git a/plugins/access-control/handlers/user_access_handlers_test.go b/plugins/access-control/handlers/user_access_handlers_test.go new file mode 100644 index 00000000..4ba113ef --- /dev/null +++ b/plugins/access-control/handlers/user_access_handlers_test.go @@ -0,0 +1,241 @@ +package handlers + +import ( + "encoding/json" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/mock" + + internaltests "github.com/Authula/authula/internal/tests" + authmodels "github.com/Authula/authula/models" + "github.com/Authula/authula/plugins/access-control/constants" + accesscontroltests "github.com/Authula/authula/plugins/access-control/tests" + "github.com/Authula/authula/plugins/access-control/types" +) + +func TestGetUserEffectivePermissionsHandler(t *testing.T) { + t.Parallel() + + fixedTime := time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC) + description := new(string) + *description = "read access" + + tests := []struct { + name string + userID string + setupMock func(*accesscontroltests.MockUserAccessRepository) + expectedStatus int + expectedBody any + }{ + { + name: "blank user id", + userID: "", + expectedStatus: http.StatusUnprocessableEntity, + expectedBody: map[string]string{"message": "unprocessable entity"}, + }, + { + name: "use case error", + userID: "user-1", + setupMock: func(m *accesscontroltests.MockUserAccessRepository) { + m.On("GetUserEffectivePermissions", mock.Anything, "user-1").Return(([]types.UserPermissionInfo)(nil), constants.ErrUnauthorized).Once() + }, + expectedStatus: http.StatusUnauthorized, + expectedBody: map[string]string{"message": "unauthorized"}, + }, + { + name: "success", + userID: "user-1", + setupMock: func(m *accesscontroltests.MockUserAccessRepository) { + m.On("GetUserEffectivePermissions", mock.Anything, "user-1").Return([]types.UserPermissionInfo{{ + PermissionID: "perm-1", + PermissionKey: "users.read", + PermissionDescription: description, + GrantedAt: &fixedTime, + }}, nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: &types.GetUserEffectivePermissionsResponse{Permissions: []types.UserPermissionInfo{{ + PermissionID: "perm-1", + PermissionKey: "users.read", + PermissionDescription: description, + GrantedAt: &fixedTime, + }}}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + userAccessRepo := &accesscontroltests.MockUserAccessRepository{} + if tc.setupMock != nil { + tc.setupMock(userAccessRepo) + } + + userRolesRepo := &accesscontroltests.MockUserRolesRepository{} + useCase := newUserAccessUseCase(userRolesRepo, userAccessRepo) + handler := NewGetUserEffectivePermissionsHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/users/"+tc.userID+"/permissions", nil, nil) + req.SetPathValue("user_id", tc.userID) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + userAccessRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[types.GetUserEffectivePermissionsResponse](t, reqCtx) + assertGetUserEffectivePermissionsResponseEqual(t, payload, *tc.expectedBody.(*types.GetUserEffectivePermissionsResponse)) + + userAccessRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) + }) + } +} + +func TestGetUserAuthorizationProfileHandler(t *testing.T) { + t.Parallel() + + fixedTime := time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC) + roleDescription := new(string) + *roleDescription = "profile role" + permissionDescription := new(string) + *permissionDescription = "profile access" + + tests := []struct { + name string + userID string + setupMock func(*accesscontroltests.MockUserRolesRepository, *accesscontroltests.MockUserAccessRepository) + expectedStatus int + expectedBody any + }{ + { + name: "blank user id", + userID: "", + expectedStatus: http.StatusUnprocessableEntity, + expectedBody: map[string]string{"message": "unprocessable entity"}, + }, + { + name: "use case error", + userID: "user-404", + setupMock: func(userRolesRepo *accesscontroltests.MockUserRolesRepository, _ *accesscontroltests.MockUserAccessRepository) { + userRolesRepo.On("GetUserWithRolesByID", mock.Anything, "user-404").Return((*types.UserWithRoles)(nil), constants.ErrNotFound).Once() + }, + expectedStatus: http.StatusNotFound, + expectedBody: map[string]string{"message": "not found"}, + }, + { + name: "success", + userID: "user-1", + setupMock: func(userRolesRepo *accesscontroltests.MockUserRolesRepository, userAccessRepo *accesscontroltests.MockUserAccessRepository) { + userRolesRepo.On("GetUserWithRolesByID", mock.Anything, "user-1").Return(&types.UserWithRoles{ + User: authmodels.User{ + ID: "user-1", + Name: "Pat", + Email: "pat@example.com", + EmailVerified: true, + Metadata: json.RawMessage("null"), + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, + Roles: []types.UserRoleInfo{{ + RoleID: "role-1", + RoleName: "Editor", + RoleDescription: roleDescription, + AssignedByUserID: nil, + AssignedAt: &fixedTime, + ExpiresAt: nil, + }}, + }, nil).Once() + userAccessRepo.On("GetUserWithPermissionsByID", mock.Anything, "user-1").Return(&types.UserWithPermissions{ + User: authmodels.User{ + ID: "user-1", + Name: "Pat", + Email: "pat@example.com", + EmailVerified: true, + Metadata: json.RawMessage("null"), + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, + Permissions: []types.UserPermissionInfo{{ + PermissionID: "perm-1", + PermissionKey: "users.read", + PermissionDescription: permissionDescription, + GrantedAt: &fixedTime, + }}, + }, nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: &types.UserAuthorizationProfile{ + User: authmodels.User{ + ID: "user-1", + Name: "Pat", + Email: "pat@example.com", + EmailVerified: true, + Metadata: json.RawMessage("null"), + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, + Roles: []types.UserRoleInfo{{ + RoleID: "role-1", + RoleName: "Editor", + RoleDescription: roleDescription, + AssignedByUserID: nil, + AssignedAt: &fixedTime, + ExpiresAt: nil, + }}, + Permissions: []types.UserPermissionInfo{{ + PermissionID: "perm-1", + PermissionKey: "users.read", + PermissionDescription: permissionDescription, + GrantedAt: &fixedTime, + }}, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + userRolesRepo := &accesscontroltests.MockUserRolesRepository{} + userAccessRepo := &accesscontroltests.MockUserAccessRepository{} + if tc.setupMock != nil { + tc.setupMock(userRolesRepo, userAccessRepo) + } + + useCase := newUserAccessUseCase(userRolesRepo, userAccessRepo) + handler := NewGetUserAuthorizationProfileHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/users/"+tc.userID+"/authorization-profile", nil, nil) + req.SetPathValue("user_id", tc.userID) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + userRolesRepo.AssertExpectations(t) + userAccessRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[types.UserAuthorizationProfile](t, reqCtx) + assertUserAuthorizationProfileEqual(t, payload, *tc.expectedBody.(*types.UserAuthorizationProfile)) + + userRolesRepo.AssertExpectations(t) + userAccessRepo.AssertExpectations(t) + }) + } +} diff --git a/plugins/access-control/handlers/user_roles_handlers.go b/plugins/access-control/handlers/user_roles_handlers.go index 5b2ea59f..93514f35 100644 --- a/plugins/access-control/handlers/user_roles_handlers.go +++ b/plugins/access-control/handlers/user_roles_handlers.go @@ -10,10 +10,10 @@ import ( ) type GetUserRolesHandler struct { - useCase usecases.UserRolesUseCase + useCase *usecases.UserRolesUseCase } -func NewGetUserRolesHandler(useCase usecases.UserRolesUseCase) *GetUserRolesHandler { +func NewGetUserRolesHandler(useCase *usecases.UserRolesUseCase) *GetUserRolesHandler { return &GetUserRolesHandler{ useCase: useCase, } @@ -35,11 +35,35 @@ func (h *GetUserRolesHandler) Handler() http.HandlerFunc { } } +type GetUserWithRolesHandler struct { + useCase *usecases.UserRolesUseCase +} + +func NewGetUserWithRolesHandler(useCase *usecases.UserRolesUseCase) *GetUserWithRolesHandler { + return &GetUserWithRolesHandler{useCase: useCase} +} + +func (h *GetUserWithRolesHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + userID := r.PathValue("user_id") + + userWithRoles, err := h.useCase.GetUserWithRolesByID(r.Context(), userID) + if err != nil { + respondUserHandlerError(reqCtx, err) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, userWithRoles) + } +} + type ReplaceUserRolesHandler struct { - useCase usecases.RolePermissionUseCase + useCase *usecases.UserRolesUseCase } -func NewReplaceUserRolesHandler(useCase usecases.RolePermissionUseCase) *ReplaceUserRolesHandler { +func NewReplaceUserRolesHandler(useCase *usecases.UserRolesUseCase) *ReplaceUserRolesHandler { return &ReplaceUserRolesHandler{ useCase: useCase, } @@ -68,10 +92,10 @@ func (h *ReplaceUserRolesHandler) Handler() http.HandlerFunc { } type AssignUserRoleHandler struct { - useCase usecases.RolePermissionUseCase + useCase *usecases.UserRolesUseCase } -func NewAssignUserRoleHandler(useCase usecases.RolePermissionUseCase) *AssignUserRoleHandler { +func NewAssignUserRoleHandler(useCase *usecases.UserRolesUseCase) *AssignUserRoleHandler { return &AssignUserRoleHandler{useCase: useCase} } @@ -98,10 +122,10 @@ func (h *AssignUserRoleHandler) Handler() http.HandlerFunc { } type RemoveUserRoleHandler struct { - useCase usecases.RolePermissionUseCase + useCase *usecases.UserRolesUseCase } -func NewRemoveUserRoleHandler(useCase usecases.RolePermissionUseCase) *RemoveUserRoleHandler { +func NewRemoveUserRoleHandler(useCase *usecases.UserRolesUseCase) *RemoveUserRoleHandler { return &RemoveUserRoleHandler{useCase: useCase} } @@ -120,49 +144,3 @@ func (h *RemoveUserRoleHandler) Handler() http.HandlerFunc { reqCtx.SetJSONResponse(http.StatusOK, &types.RemoveUserRoleResponse{Message: "role removed"}) } } - -type GetUserEffectivePermissionsHandler struct { - useCase usecases.UserRolesUseCase -} - -func NewGetUserEffectivePermissionsHandler(useCase usecases.UserRolesUseCase) *GetUserEffectivePermissionsHandler { - return &GetUserEffectivePermissionsHandler{ - useCase: useCase, - } -} - -func (h *GetUserEffectivePermissionsHandler) Handler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - reqCtx, _ := models.GetRequestContext(ctx) - userID := r.PathValue("user_id") - - permissions, err := h.useCase.GetUserEffectivePermissions(r.Context(), userID) - if err != nil { - respondUserHandlerError(reqCtx, err) - return - } - - reqCtx.SetJSONResponse(http.StatusOK, &types.GetUserEffectivePermissionsResponse{Permissions: permissions}) - } -} - -func userActorUserID(reqCtx *models.RequestContext) *string { - if reqCtx == nil || reqCtx.UserID == nil || *reqCtx.UserID == "" { - return nil - } - return reqCtx.UserID -} - -func respondUserHandlerError(reqCtx *models.RequestContext, err error) { - if reqCtx == nil { - return - } - - reqCtx.SetJSONResponse(mapUserHandlerErrorStatus(err), map[string]any{"message": mapHttpErrorMessage(err)}) - reqCtx.Handled = true -} - -func mapUserHandlerErrorStatus(err error) int { - return mapHttpErrorStatus(err) -} diff --git a/plugins/access-control/handlers/user_roles_handlers_test.go b/plugins/access-control/handlers/user_roles_handlers_test.go index 20d12a5e..47fb7909 100644 --- a/plugins/access-control/handlers/user_roles_handlers_test.go +++ b/plugins/access-control/handlers/user_roles_handlers_test.go @@ -1,6 +1,7 @@ package handlers import ( + "encoding/json" "net/http" "testing" "time" @@ -8,279 +9,598 @@ import ( "github.com/stretchr/testify/mock" internaltests "github.com/Authula/authula/internal/tests" - accesscontrolconstants "github.com/Authula/authula/plugins/access-control/constants" - "github.com/Authula/authula/plugins/access-control/tests" + authmodels "github.com/Authula/authula/models" + "github.com/Authula/authula/plugins/access-control/constants" + "github.com/Authula/authula/plugins/access-control/services" + accesscontroltests "github.com/Authula/authula/plugins/access-control/tests" "github.com/Authula/authula/plugins/access-control/types" + "github.com/Authula/authula/plugins/access-control/usecases" ) func TestGetUserRolesHandler(t *testing.T) { t.Parallel() - t.Run("missing user id", func(t *testing.T) { - t.Parallel() - - useCase, _ := tests.NewUserRolesUseCaseFixture() - handler := NewGetUserRolesHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/access-control/users//roles", nil, nil) - req.SetPathValue("user_id", " ") - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusUnprocessableEntity, "unprocessable entity") - }) - - t.Run("error", func(t *testing.T) { - t.Parallel() - - useCase, accessRepo := tests.NewUserRolesUseCaseFixture() - accessRepo.On("GetUserRoles", mock.Anything, "user-1").Return(([]types.UserRoleInfo)(nil), accesscontrolconstants.ErrNotFound).Once() - handler := NewGetUserRolesHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/access-control/users/user-1/roles", nil, nil) - req.SetPathValue("user_id", "user-1") - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusNotFound, "not found") - accessRepo.AssertExpectations(t) - }) - - t.Run("success", func(t *testing.T) { - t.Parallel() - - useCase, accessRepo := tests.NewUserRolesUseCaseFixture() - expiresAt := time.Now().UTC().Add(time.Hour) - accessRepo.On("GetUserRoles", mock.Anything, "user-1").Return([]types.UserRoleInfo{{RoleID: "role-1", RoleName: "admin", ExpiresAt: &expiresAt}}, nil).Once() - handler := NewGetUserRolesHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/access-control/users/user-1/roles", nil, nil) - req.SetPathValue("user_id", "user-1") + fixedTime := time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC) + description := new(string) + *description = "editor role" + assignedByUserID := new(string) + *assignedByUserID = "user-2" + + tests := []struct { + name string + userID string + setupMock func(*accesscontroltests.MockUserRolesRepository) + expectedStatus int + expectedBody any + }{ + { + name: "blank user id", + userID: "", + expectedStatus: http.StatusUnprocessableEntity, + expectedBody: map[string]string{"message": "unprocessable entity"}, + }, + { + name: "use case error", + userID: "user-404", + setupMock: func(m *accesscontroltests.MockUserRolesRepository) { + m.On("GetUserRoles", mock.Anything, "user-404").Return(([]types.UserRoleInfo)(nil), constants.ErrNotFound).Once() + }, + expectedStatus: http.StatusNotFound, + expectedBody: map[string]string{"message": "not found"}, + }, + { + name: "success", + userID: "user-1", + setupMock: func(m *accesscontroltests.MockUserRolesRepository) { + m.On("GetUserRoles", mock.Anything, "user-1").Return([]types.UserRoleInfo{{ + RoleID: "role-1", + RoleName: "Editor", + RoleDescription: description, + AssignedByUserID: assignedByUserID, + AssignedAt: &fixedTime, + ExpiresAt: &fixedTime, + }}, nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: []types.UserRoleInfo{{ + RoleID: "role-1", + RoleName: "Editor", + RoleDescription: description, + AssignedByUserID: assignedByUserID, + AssignedAt: &fixedTime, + ExpiresAt: &fixedTime, + }}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + userRolesRepo := &accesscontroltests.MockUserRolesRepository{} + rolesRepo := &accesscontroltests.MockRolesRepository{} + if tc.setupMock != nil { + tc.setupMock(userRolesRepo) + } + + useCase := newUserRolesUseCase(rolesRepo, userRolesRepo) + handler := NewGetUserRolesHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/users/"+tc.userID+"/roles", nil, nil) + req.SetPathValue("user_id", tc.userID) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + userRolesRepo.AssertExpectations(t) + rolesRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[[]types.UserRoleInfo](t, reqCtx) + assertUserRoleInfosEqual(t, payload, tc.expectedBody.([]types.UserRoleInfo)) + + userRolesRepo.AssertExpectations(t) + rolesRepo.AssertExpectations(t) + }) + } +} - handler.Handler()(w, req) +func TestGetUserWithRolesHandler(t *testing.T) { + t.Parallel() - if reqCtx.ResponseStatus != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, reqCtx.ResponseStatus) - } - payload := internaltests.DecodeResponseJSON[[]types.UserRoleInfo](t, reqCtx) - if len(payload) != 1 { - t.Fatalf("expected 1 role, got %d", len(payload)) - } - accessRepo.AssertExpectations(t) - }) + fixedTime := time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC) + description := new(string) + *description = "platform user" + + tests := []struct { + name string + userID string + setupMock func(*accesscontroltests.MockUserRolesRepository) + expectedStatus int + expectedBody any + }{ + { + name: "blank user id", + userID: "", + expectedStatus: http.StatusUnprocessableEntity, + expectedBody: map[string]string{"message": "unprocessable entity"}, + }, + { + name: "use case error", + userID: "user-404", + setupMock: func(m *accesscontroltests.MockUserRolesRepository) { + m.On("GetUserWithRolesByID", mock.Anything, "user-404").Return((*types.UserWithRoles)(nil), constants.ErrNotFound).Once() + }, + expectedStatus: http.StatusNotFound, + expectedBody: map[string]string{"message": "not found"}, + }, + { + name: "success", + userID: "user-1", + setupMock: func(m *accesscontroltests.MockUserRolesRepository) { + m.On("GetUserWithRolesByID", mock.Anything, "user-1").Return(&types.UserWithRoles{ + User: authmodels.User{ + ID: "user-1", + Name: "Pat", + Email: "pat@example.com", + EmailVerified: true, + Metadata: json.RawMessage("null"), + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, + Roles: []types.UserRoleInfo{{ + RoleID: "role-1", + RoleName: "Editor", + RoleDescription: description, + AssignedByUserID: nil, + AssignedAt: &fixedTime, + ExpiresAt: nil, + }}, + }, nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: &types.UserWithRoles{ + User: authmodels.User{ + ID: "user-1", + Name: "Pat", + Email: "pat@example.com", + EmailVerified: true, + Metadata: json.RawMessage("null"), + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, + Roles: []types.UserRoleInfo{{ + RoleID: "role-1", + RoleName: "Editor", + RoleDescription: description, + AssignedByUserID: nil, + AssignedAt: &fixedTime, + ExpiresAt: nil, + }}, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + userRolesRepo := &accesscontroltests.MockUserRolesRepository{} + rolesRepo := &accesscontroltests.MockRolesRepository{} + if tc.setupMock != nil { + tc.setupMock(userRolesRepo) + } + + useCase := newUserRolesUseCase(rolesRepo, userRolesRepo) + handler := NewGetUserWithRolesHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/users/"+tc.userID+"/roles/details", nil, nil) + req.SetPathValue("user_id", tc.userID) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + userRolesRepo.AssertExpectations(t) + rolesRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[types.UserWithRoles](t, reqCtx) + assertUserWithRolesEqual(t, payload, *tc.expectedBody.(*types.UserWithRoles)) + + userRolesRepo.AssertExpectations(t) + rolesRepo.AssertExpectations(t) + }) + } } func TestReplaceUserRolesHandler(t *testing.T) { t.Parallel() - t.Run("invalid json", func(t *testing.T) { - t.Parallel() - - useCase, _ := tests.NewRolePermissionUseCaseFixture() - handler := NewReplaceUserRolesHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPut, "/access-control/users/user-1/roles", []byte("{invalid"), nil) - req.SetPathValue("user_id", "user-1") - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusUnprocessableEntity, "invalid request body") - }) - - t.Run("error", func(t *testing.T) { - t.Parallel() - - useCase, roleRepo := tests.NewRolePermissionUseCaseFixture() - request := types.ReplaceUserRolesRequest{RoleIDs: []string{"role-1"}} - actorID := "actor-1" - roleRepo.On("ReplaceUserRoles", mock.Anything, "user-1", request.RoleIDs, &actorID).Return(accesscontrolconstants.ErrForbidden).Once() - handler := NewReplaceUserRolesHandler(useCase) - - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPut, "/access-control/users/user-1/roles", internaltests.MarshalToJSON(t, request), nil) - req.SetPathValue("user_id", "user-1") - reqCtx.UserID = &actorID - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusForbidden, "forbidden") - roleRepo.AssertExpectations(t) - }) - - t.Run("success", func(t *testing.T) { - t.Parallel() - - useCase, roleRepo := tests.NewRolePermissionUseCaseFixture() - request := types.ReplaceUserRolesRequest{RoleIDs: []string{"role-1", "role-2"}} - roleRepo.On("ReplaceUserRoles", mock.Anything, "user-1", request.RoleIDs, (*string)(nil)).Return(nil).Once() - handler := NewReplaceUserRolesHandler(useCase) - - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPut, "/access-control/users/user-1/roles", internaltests.MarshalToJSON(t, request), nil) - req.SetPathValue("user_id", "user-1") - - handler.Handler()(w, req) - - if reqCtx.ResponseStatus != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, reqCtx.ResponseStatus) - } - payload := internaltests.DecodeResponseJSON[types.ReplaceUserRolesResponse](t, reqCtx) - if payload.Message != "user roles replaced" { - t.Fatalf("expected message user roles replaced, got %v", payload.Message) - } - roleRepo.AssertExpectations(t) - }) + actorUserID := "user-1" + actorUserIDPtr := &actorUserID + + tests := []struct { + name string + userID string + body []byte + userIDPtr *string + setupMock func(*accesscontroltests.MockRolesRepository, *accesscontroltests.MockUserRolesRepository) + expectedStatus int + expectedBody any + }{ + { + name: "invalid request body", + userID: "user-1", + body: []byte("{invalid json"), + expectedStatus: http.StatusUnprocessableEntity, + expectedBody: map[string]string{"message": "invalid request body"}, + }, + { + name: "blank user id", + userID: "", + body: internaltests.MarshalToJSON(t, types.ReplaceUserRolesRequest{RoleIDs: []string{"role-1"}}), + expectedStatus: http.StatusBadRequest, + expectedBody: map[string]string{"message": "bad request"}, + }, + { + name: "success", + userID: "user-1", + body: internaltests.MarshalToJSON(t, types.ReplaceUserRolesRequest{RoleIDs: []string{"role-1", "role-2"}}), + setupMock: func(rolesRepo *accesscontroltests.MockRolesRepository, userRolesRepo *accesscontroltests.MockUserRolesRepository) { + rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "Editor"}, nil).Once() + rolesRepo.On("GetRoleByID", mock.Anything, "role-2").Return(&types.Role{ID: "role-2", Name: "Viewer"}, nil).Once() + userRolesRepo.On("ReplaceUserRoles", mock.Anything, "user-1", []string{"role-1", "role-2"}, (*string)(nil)).Return(nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: types.ReplaceUserRolesResponse{Message: "user roles replaced"}, + }, + { + name: "success with actor user id", + userID: "user-1", + userIDPtr: actorUserIDPtr, + body: internaltests.MarshalToJSON(t, types.ReplaceUserRolesRequest{RoleIDs: []string{"role-3", "role-4"}}), + setupMock: func(rolesRepo *accesscontroltests.MockRolesRepository, userRolesRepo *accesscontroltests.MockUserRolesRepository) { + rolesRepo.On("GetRoleByID", mock.Anything, "role-3").Return(&types.Role{ID: "role-3", Name: "Reviewer"}, nil).Once() + rolesRepo.On("GetRoleByID", mock.Anything, "role-4").Return(&types.Role{ID: "role-4", Name: "Commenter"}, nil).Once() + userRolesRepo.On("ReplaceUserRoles", mock.Anything, "user-1", []string{"role-3", "role-4"}, actorUserIDPtr).Return(nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: types.ReplaceUserRolesResponse{Message: "user roles replaced"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + rolesRepo := &accesscontroltests.MockRolesRepository{} + userRolesRepo := &accesscontroltests.MockUserRolesRepository{} + if tc.setupMock != nil { + tc.setupMock(rolesRepo, userRolesRepo) + } + + useCase := newUserRolesUseCase(rolesRepo, userRolesRepo) + handler := NewReplaceUserRolesHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPut, "/users/"+tc.userID+"/roles", tc.body, tc.userIDPtr) + req.SetPathValue("user_id", tc.userID) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + rolesRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[types.ReplaceUserRolesResponse](t, reqCtx) + assertReplaceUserRolesResponseEqual(t, payload, tc.expectedBody.(types.ReplaceUserRolesResponse)) + + rolesRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) + }) + } } func TestAssignUserRoleHandler(t *testing.T) { t.Parallel() - t.Run("invalid json", func(t *testing.T) { - t.Parallel() - - useCase, _ := tests.NewRolePermissionUseCaseFixture() - handler := NewAssignUserRoleHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPost, "/access-control/users/user-1/roles", []byte("{invalid"), nil) - req.SetPathValue("user_id", "user-1") - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusUnprocessableEntity, "invalid request body") - }) - - t.Run("error", func(t *testing.T) { - t.Parallel() - - useCase, roleRepo := tests.NewRolePermissionUseCaseFixture() - request := types.AssignUserRoleRequest{RoleID: "role-1"} - actorID := "actor-1" - roleRepo.On("AssignUserRole", mock.Anything, "user-1", "role-1", &actorID, (*time.Time)(nil)).Return(accesscontrolconstants.ErrBadRequest).Once() - handler := NewAssignUserRoleHandler(useCase) - - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPost, "/access-control/users/user-1/roles", internaltests.MarshalToJSON(t, request), nil) - req.SetPathValue("user_id", "user-1") - reqCtx.UserID = &actorID - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusBadRequest, "bad request") - roleRepo.AssertExpectations(t) - }) - - t.Run("success", func(t *testing.T) { - t.Parallel() - - useCase, roleRepo := tests.NewRolePermissionUseCaseFixture() - request := types.AssignUserRoleRequest{RoleID: "role-1"} - roleRepo.On("AssignUserRole", mock.Anything, "user-1", "role-1", (*string)(nil), (*time.Time)(nil)).Return(nil).Once() - handler := NewAssignUserRoleHandler(useCase) - - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPost, "/access-control/users/user-1/roles", internaltests.MarshalToJSON(t, request), nil) - req.SetPathValue("user_id", "user-1") - - handler.Handler()(w, req) - - if reqCtx.ResponseStatus != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, reqCtx.ResponseStatus) - } - payload := internaltests.DecodeResponseJSON[types.AssignUserRoleResponse](t, reqCtx) - if payload.Message != "role assigned" { - t.Fatalf("expected role assigned message, got %v", payload.Message) - } - roleRepo.AssertExpectations(t) - }) + futureTime := time.Date(2030, 3, 29, 13, 0, 0, 0, time.UTC) + actorUserID := "user-1" + actorUserIDPtr := &actorUserID + + tests := []struct { + name string + userID string + body []byte + userIDPtr *string + setupMock func(*accesscontroltests.MockRolesRepository, *accesscontroltests.MockUserRolesRepository) + expectedStatus int + expectedBody any + }{ + { + name: "invalid request body", + userID: "user-1", + body: []byte("{invalid json"), + expectedStatus: http.StatusUnprocessableEntity, + expectedBody: map[string]string{"message": "invalid request body"}, + }, + { + name: "blank user id", + userID: "", + body: internaltests.MarshalToJSON(t, types.AssignUserRoleRequest{RoleID: "role-1"}), + expectedStatus: http.StatusBadRequest, + expectedBody: map[string]string{"message": "bad request"}, + }, + { + name: "use case error", + userID: "user-1", + body: internaltests.MarshalToJSON(t, types.AssignUserRoleRequest{RoleID: "role-1"}), + setupMock: func(rolesRepo *accesscontroltests.MockRolesRepository, userRolesRepo *accesscontroltests.MockUserRolesRepository) { + rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "Editor"}, nil).Once() + userRolesRepo.On("AssignUserRole", mock.Anything, "user-1", "role-1", (*string)(nil), (*time.Time)(nil)).Return(constants.ErrUnauthorized).Once() + }, + expectedStatus: http.StatusUnauthorized, + expectedBody: map[string]string{"message": "unauthorized"}, + }, + { + name: "success", + userID: "user-1", + body: internaltests.MarshalToJSON(t, types.AssignUserRoleRequest{RoleID: "role-1", ExpiresAt: &futureTime}), + setupMock: func(rolesRepo *accesscontroltests.MockRolesRepository, userRolesRepo *accesscontroltests.MockUserRolesRepository) { + rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "Editor"}, nil).Once() + userRolesRepo.On("AssignUserRole", mock.Anything, "user-1", "role-1", (*string)(nil), &futureTime).Return(nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: types.AssignUserRoleResponse{Message: "role assigned"}, + }, + { + name: "success with actor user id", + userID: "user-1", + userIDPtr: actorUserIDPtr, + body: internaltests.MarshalToJSON(t, types.AssignUserRoleRequest{RoleID: "role-2"}), + setupMock: func(rolesRepo *accesscontroltests.MockRolesRepository, userRolesRepo *accesscontroltests.MockUserRolesRepository) { + rolesRepo.On("GetRoleByID", mock.Anything, "role-2").Return(&types.Role{ID: "role-2", Name: "Reviewer"}, nil).Once() + userRolesRepo.On("AssignUserRole", mock.Anything, "user-1", "role-2", actorUserIDPtr, (*time.Time)(nil)).Return(nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: types.AssignUserRoleResponse{Message: "role assigned"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + rolesRepo := &accesscontroltests.MockRolesRepository{} + userRolesRepo := &accesscontroltests.MockUserRolesRepository{} + if tc.setupMock != nil { + tc.setupMock(rolesRepo, userRolesRepo) + } + + useCase := newUserRolesUseCase(rolesRepo, userRolesRepo) + handler := NewAssignUserRoleHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPost, "/users/"+tc.userID+"/roles", tc.body, tc.userIDPtr) + req.SetPathValue("user_id", tc.userID) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + rolesRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[types.AssignUserRoleResponse](t, reqCtx) + assertAssignUserRoleResponseEqual(t, payload, tc.expectedBody.(types.AssignUserRoleResponse)) + + rolesRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) + }) + } } func TestRemoveUserRoleHandler(t *testing.T) { t.Parallel() - t.Run("error", func(t *testing.T) { - t.Parallel() - - useCase, roleRepo := tests.NewRolePermissionUseCaseFixture() - roleRepo.On("RemoveUserRole", mock.Anything, "user-1", "role-1").Return(accesscontrolconstants.ErrNotFound).Once() - handler := NewRemoveUserRoleHandler(useCase) - - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodDelete, "/access-control/users/user-1/roles/role-1", nil, nil) - req.SetPathValue("user_id", "user-1") - req.SetPathValue("role_id", "role-1") - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusNotFound, "not found") - roleRepo.AssertExpectations(t) - }) - - t.Run("success", func(t *testing.T) { - t.Parallel() - - useCase, roleRepo := tests.NewRolePermissionUseCaseFixture() - roleRepo.On("RemoveUserRole", mock.Anything, "user-1", "role-1").Return(nil).Once() - handler := NewRemoveUserRoleHandler(useCase) - - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodDelete, "/access-control/users/user-1/roles/role-1", nil, nil) - req.SetPathValue("user_id", "user-1") - req.SetPathValue("role_id", "role-1") - - handler.Handler()(w, req) - - if reqCtx.ResponseStatus != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, reqCtx.ResponseStatus) - } - payload := internaltests.DecodeResponseJSON[types.RemoveUserRoleResponse](t, reqCtx) - if payload.Message != "role removed" { - t.Fatalf("expected role removed message, got %v", payload.Message) - } - roleRepo.AssertExpectations(t) - }) + tests := []struct { + name string + userID string + roleID string + setupMock func(*accesscontroltests.MockUserRolesRepository) + expectedStatus int + expectedBody any + }{ + { + name: "blank role id", + userID: "user-1", + roleID: "", + expectedStatus: http.StatusBadRequest, + expectedBody: map[string]string{"message": "bad request"}, + }, + { + name: "use case error", + userID: "user-1", + roleID: "role-1", + setupMock: func(m *accesscontroltests.MockUserRolesRepository) { + m.On("RemoveUserRole", mock.Anything, "user-1", "role-1").Return(constants.ErrNotFound).Once() + }, + expectedStatus: http.StatusNotFound, + expectedBody: map[string]string{"message": "not found"}, + }, + { + name: "success", + userID: "user-1", + roleID: "role-1", + setupMock: func(m *accesscontroltests.MockUserRolesRepository) { + m.On("RemoveUserRole", mock.Anything, "user-1", "role-1").Return(nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: types.RemoveUserRoleResponse{Message: "role removed"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + userRolesRepo := &accesscontroltests.MockUserRolesRepository{} + if tc.setupMock != nil { + tc.setupMock(userRolesRepo) + } + + rolesRepo := &accesscontroltests.MockRolesRepository{} + useCase := newUserRolesUseCase(rolesRepo, userRolesRepo) + handler := NewRemoveUserRoleHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodDelete, "/users/"+tc.userID+"/roles/"+tc.roleID, nil, nil) + req.SetPathValue("user_id", tc.userID) + req.SetPathValue("role_id", tc.roleID) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + userRolesRepo.AssertExpectations(t) + rolesRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[types.RemoveUserRoleResponse](t, reqCtx) + assertRemoveUserRoleResponseEqual(t, payload, tc.expectedBody.(types.RemoveUserRoleResponse)) + + userRolesRepo.AssertExpectations(t) + rolesRepo.AssertExpectations(t) + }) + } } -func TestGetUserEffectivePermissionsHandler(t *testing.T) { - t.Parallel() +func newUserRolesUseCase(rolesRepo *accesscontroltests.MockRolesRepository, userRolesRepo *accesscontroltests.MockUserRolesRepository) *usecases.UserRolesUseCase { + return usecases.NewUserRolesUseCase(services.NewUserRolesService(userRolesRepo, rolesRepo)) +} - t.Run("missing user id", func(t *testing.T) { - t.Parallel() +func newUserAccessUseCase(userRolesRepo *accesscontroltests.MockUserRolesRepository, userAccessRepo *accesscontroltests.MockUserAccessRepository) *usecases.UserAccessUseCase { + return usecases.NewUserAccessUseCase(services.NewUserAccessService(userRolesRepo, userAccessRepo)) +} - useCase, _ := tests.NewUserRolesUseCaseFixture() - handler := NewGetUserEffectivePermissionsHandler(useCase) +func assertUserRoleInfosEqual(t *testing.T, got []types.UserRoleInfo, want []types.UserRoleInfo) { + t.Helper() - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/access-control/users//permissions", nil, nil) - req.SetPathValue("user_id", "") + if len(got) != len(want) { + t.Fatalf("expected %d roles, got %d", len(want), len(got)) + } - handler.Handler()(w, req) + for i := range want { + assertUserRoleInfoEqual(t, got[i], want[i]) + } +} - internaltests.AssertErrorMessage(t, reqCtx, http.StatusUnprocessableEntity, "unprocessable entity") - }) +func assertUserRoleInfoEqual(t *testing.T, got types.UserRoleInfo, want types.UserRoleInfo) { + t.Helper() + + if got.RoleID != want.RoleID || got.RoleName != want.RoleName { + t.Fatalf("unexpected role info: %#v", got) + } + if !stringsEqualPtr(got.RoleDescription, want.RoleDescription) { + t.Fatalf("unexpected role description: %#v", got) + } + if !stringsEqualPtr(got.AssignedByUserID, want.AssignedByUserID) { + t.Fatalf("unexpected assigned_by_user_id: %#v", got) + } + if !timesEqualPtr(got.AssignedAt, want.AssignedAt) { + t.Fatalf("unexpected assigned_at: %#v", got) + } + if !timesEqualPtr(got.ExpiresAt, want.ExpiresAt) { + t.Fatalf("unexpected expires_at: %#v", got) + } +} - t.Run("error", func(t *testing.T) { - t.Parallel() +func assertUserWithRolesEqual(t *testing.T, got types.UserWithRoles, want types.UserWithRoles) { + t.Helper() - useCase, accessRepo := tests.NewUserRolesUseCaseFixture() - accessRepo.On("GetUserEffectivePermissions", mock.Anything, "user-1").Return(([]types.UserPermissionInfo)(nil), accesscontrolconstants.ErrUnauthorized).Once() - handler := NewGetUserEffectivePermissionsHandler(useCase) + assertUserEqual(t, got.User, want.User) + assertUserRoleInfosEqual(t, got.Roles, want.Roles) +} - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/access-control/users/user-1/permissions", nil, nil) - req.SetPathValue("user_id", "user-1") +func assertGetUserEffectivePermissionsResponseEqual(t *testing.T, got types.GetUserEffectivePermissionsResponse, want types.GetUserEffectivePermissionsResponse) { + t.Helper() - handler.Handler()(w, req) + assertUserPermissionInfosEqual(t, got.Permissions, want.Permissions) +} - internaltests.AssertErrorMessage(t, reqCtx, http.StatusUnauthorized, "unauthorized") - accessRepo.AssertExpectations(t) - }) +func assertUserAuthorizationProfileEqual(t *testing.T, got types.UserAuthorizationProfile, want types.UserAuthorizationProfile) { + t.Helper() - t.Run("success", func(t *testing.T) { - t.Parallel() + assertUserEqual(t, got.User, want.User) + assertUserRoleInfosEqual(t, got.Roles, want.Roles) + assertUserPermissionInfosEqual(t, got.Permissions, want.Permissions) +} - useCase, accessRepo := tests.NewUserRolesUseCaseFixture() - accessRepo.On("GetUserEffectivePermissions", mock.Anything, "user-1").Return([]types.UserPermissionInfo{{PermissionID: "perm-1", PermissionKey: "admin.read"}}, nil).Once() - handler := NewGetUserEffectivePermissionsHandler(useCase) +func assertUserEqual(t *testing.T, got authmodels.User, want authmodels.User) { + t.Helper() + + if got.ID != want.ID { + t.Fatalf("expected id %q, got %q", want.ID, got.ID) + } + if got.Name != want.Name { + t.Fatalf("expected name %q, got %q", want.Name, got.Name) + } + if got.Email != want.Email { + t.Fatalf("expected email %q, got %q", want.Email, got.Email) + } + if got.EmailVerified != want.EmailVerified { + t.Fatalf("expected email_verified %v, got %v", want.EmailVerified, got.EmailVerified) + } + if string(got.Metadata) != string(want.Metadata) { + t.Fatalf("expected metadata %s, got %s", string(want.Metadata), string(got.Metadata)) + } + if !timesEqual(got.CreatedAt, want.CreatedAt) { + t.Fatalf("expected created_at %v, got %v", want.CreatedAt, got.CreatedAt) + } + if !timesEqual(got.UpdatedAt, want.UpdatedAt) { + t.Fatalf("expected updated_at %v, got %v", want.UpdatedAt, got.UpdatedAt) + } +} - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/access-control/users/user-1/permissions", nil, nil) - req.SetPathValue("user_id", "user-1") +func assertReplaceUserRolesResponseEqual(t *testing.T, got types.ReplaceUserRolesResponse, want types.ReplaceUserRolesResponse) { + t.Helper() + if got.Message != want.Message { + t.Fatalf("expected message %q, got %q", want.Message, got.Message) + } +} - handler.Handler()(w, req) +func assertAssignUserRoleResponseEqual(t *testing.T, got types.AssignUserRoleResponse, want types.AssignUserRoleResponse) { + t.Helper() + if got.Message != want.Message { + t.Fatalf("expected message %q, got %q", want.Message, got.Message) + } +} - if reqCtx.ResponseStatus != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, reqCtx.ResponseStatus) - } - payload := internaltests.DecodeResponseJSON[types.GetUserEffectivePermissionsResponse](t, reqCtx) - if len(payload.Permissions) != 1 { - t.Fatalf("expected 1 permission, got %d", len(payload.Permissions)) - } - accessRepo.AssertExpectations(t) - }) +func assertRemoveUserRoleResponseEqual(t *testing.T, got types.RemoveUserRoleResponse, want types.RemoveUserRoleResponse) { + t.Helper() + if got.Message != want.Message { + t.Fatalf("expected message %q, got %q", want.Message, got.Message) + } } diff --git a/plugins/access-control/hooks_test.go b/plugins/access-control/hooks_test.go deleted file mode 100644 index e06c944e..00000000 --- a/plugins/access-control/hooks_test.go +++ /dev/null @@ -1,204 +0,0 @@ -package accesscontrol - -import ( - "errors" - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" - - "github.com/Authula/authula/models" - "github.com/Authula/authula/plugins/access-control/tests" - "github.com/Authula/authula/plugins/access-control/types" - "github.com/Authula/authula/plugins/access-control/usecases" -) - -func newHookTestPluginWithUserAccessRepoMock(t *testing.T) (*AccessControlPlugin, interface { - On(string, ...any) *mock.Call - AssertExpectations(mock.TestingT) bool -}) { - t.Helper() - - userAccessUseCase, userAccessRepo := tests.NewUserRolesUseCaseFixture() - useCases := usecases.NewAccessControlUseCases(usecases.RolePermissionUseCase{}, userAccessUseCase) - plugin := &AccessControlPlugin{Api: &API{useCases: useCases}} - - return plugin, userAccessRepo -} - -func TestRequireAccessControlHook(t *testing.T) { - t.Parallel() - - t.Run("unauthorized without user", func(t *testing.T) { - t.Parallel() - - plugin := &AccessControlPlugin{} - req := httptest.NewRequest(http.MethodGet, "/resource", nil) - reqCtx := &models.RequestContext{Request: req} - - err := plugin.requireAccessControl(reqCtx) - require.NoError(t, err) - require.True(t, reqCtx.Handled, "expected request to be handled") - require.Equal(t, http.StatusUnauthorized, reqCtx.ResponseStatus) - }) - - t.Run("unauthorized with empty user id", func(t *testing.T) { - t.Parallel() - - plugin := &AccessControlPlugin{} - userID := "" - req := httptest.NewRequest(http.MethodGet, "/resource", nil) - reqCtx := &models.RequestContext{ - Request: req, - UserID: &userID, - } - - err := plugin.requireAccessControl(reqCtx) - require.NoError(t, err) - require.True(t, reqCtx.Handled, "expected request to be handled") - require.Equal(t, http.StatusUnauthorized, reqCtx.ResponseStatus) - }) - - t.Run("opt in skips when no permissions metadata", func(t *testing.T) { - t.Parallel() - - plugin := &AccessControlPlugin{} - userID := "user-1" - req := httptest.NewRequest(http.MethodGet, "/resource", nil) - reqCtx := &models.RequestContext{ - Request: req, - Route: &models.Route{ - Metadata: map[string]any{}, - }, - UserID: &userID, - } - - err := plugin.requireAccessControl(reqCtx) - require.NoError(t, err) - require.False(t, reqCtx.Handled, "expected request not to be handled when no permissions metadata is set") - }) - - t.Run("internal error when HasPermissions fails", func(t *testing.T) { - t.Parallel() - - plugin, userAccessRepo := newHookTestPluginWithUserAccessRepoMock(t) - userID := "user-1" - userAccessRepo.On("GetUserEffectivePermissions", mock.Anything, userID).Return(nil, errors.New("database error")).Once() - - req := httptest.NewRequest(http.MethodGet, "/resource", nil) - reqCtx := &models.RequestContext{ - Request: req, - Route: &models.Route{ - Metadata: map[string]any{ - "permissions": []string{"admin"}, - }, - }, - UserID: &userID, - } - - err := plugin.requireAccessControl(reqCtx) - require.NoError(t, err) - require.True(t, reqCtx.Handled) - require.Equal(t, http.StatusInternalServerError, reqCtx.ResponseStatus) - userAccessRepo.AssertExpectations(t) - }) - - t.Run("forbidden when user lacks permissions", func(t *testing.T) { - t.Parallel() - - plugin, userAccessRepo := newHookTestPluginWithUserAccessRepoMock(t) - userID := "user-1" - userAccessRepo.On("GetUserEffectivePermissions", mock.Anything, userID).Return([]types.UserPermissionInfo{{PermissionID: "perm-1", PermissionKey: "profiles.read"}}, nil).Once() - - req := httptest.NewRequest(http.MethodGet, "/resource", nil) - reqCtx := &models.RequestContext{ - Request: req, - Route: &models.Route{ - Metadata: map[string]any{ - "permissions": []string{"users.read"}, - }, - }, - UserID: &userID, - } - - err := plugin.requireAccessControl(reqCtx) - require.NoError(t, err) - require.True(t, reqCtx.Handled, "expected request to be handled") - require.Equal(t, http.StatusForbidden, reqCtx.ResponseStatus) - userAccessRepo.AssertExpectations(t) - }) - - t.Run("allowed when user has permissions", func(t *testing.T) { - t.Parallel() - - plugin, userAccessRepo := newHookTestPluginWithUserAccessRepoMock(t) - userID := "user-1" - userAccessRepo.On("GetUserEffectivePermissions", mock.Anything, userID).Return([]types.UserPermissionInfo{{PermissionKey: "users.read"}}, nil).Once() - - req := httptest.NewRequest(http.MethodGet, "/resource", nil) - reqCtx := &models.RequestContext{ - Request: req, - Route: &models.Route{ - Metadata: map[string]any{ - "permissions": []string{"users.read"}, - }, - }, - UserID: &userID, - } - - err := plugin.requireAccessControl(reqCtx) - require.NoError(t, err) - require.False(t, reqCtx.Handled, "expected request not to be handled when user has permissions") - userAccessRepo.AssertExpectations(t) - }) - - t.Run("multiple permissions allow when user has any one required permission", func(t *testing.T) { - t.Parallel() - - plugin, userAccessRepo := newHookTestPluginWithUserAccessRepoMock(t) - userID := "user-1" - userAccessRepo.On("GetUserEffectivePermissions", mock.Anything, userID).Return([]types.UserPermissionInfo{{PermissionKey: "users.read"}}, nil).Once() - - req := httptest.NewRequest(http.MethodGet, "/resource", nil) - reqCtx := &models.RequestContext{ - Request: req, - Route: &models.Route{ - Metadata: map[string]any{ - "permissions": []string{"users.read", "users.write"}, - }, - }, - UserID: &userID, - } - - err := plugin.requireAccessControl(reqCtx) - require.NoError(t, err) - require.False(t, reqCtx.Handled, "expected request not to be handled when user has at least one required permission") - userAccessRepo.AssertExpectations(t) - }) - - t.Run("permissions metadata with whitespace trimming", func(t *testing.T) { - t.Parallel() - - plugin, userAccessRepo := newHookTestPluginWithUserAccessRepoMock(t) - userID := "user-1" - userAccessRepo.On("GetUserEffectivePermissions", mock.Anything, userID).Return([]types.UserPermissionInfo{{PermissionKey: "users.read"}}, nil).Once() - - req := httptest.NewRequest(http.MethodGet, "/resource", nil) - reqCtx := &models.RequestContext{ - Request: req, - Route: &models.Route{ - Metadata: map[string]any{ - "permissions": []string{" users.read ", " ", ""}, - }, - }, - UserID: &userID, - } - - err := plugin.requireAccessControl(reqCtx) - require.NoError(t, err) - require.False(t, reqCtx.Handled, "expected request not to be handled when user has permissions (after trimming)") - userAccessRepo.AssertExpectations(t) - }) -} diff --git a/plugins/access-control/migrations.go b/plugins/access-control/migrations.go index ef6a0e9d..d2faacc0 100644 --- a/plugins/access-control/migrations.go +++ b/plugins/access-control/migrations.go @@ -1,242 +1,10 @@ package accesscontrol import ( - "context" - - "github.com/uptrace/bun" - "github.com/Authula/authula/migrations" + "github.com/Authula/authula/plugins/access-control/migrationset" ) func accessControlMigrationsForProvider(provider string) []migrations.Migration { - return migrations.ForProvider(provider, migrations.ProviderVariants{ - "sqlite": func() []migrations.Migration { return []migrations.Migration{accessControlSQLiteInitial()} }, - "postgres": func() []migrations.Migration { return []migrations.Migration{accessControlPostgresInitial()} }, - "mysql": func() []migrations.Migration { return []migrations.Migration{accessControlMySQLInitial()} }, - }) -} - -func accessControlSQLiteInitial() migrations.Migration { - return migrations.Migration{ - Version: "20260309000000_access_control_initial", - Up: func(ctx context.Context, tx bun.Tx) error { - return migrations.ExecStatements( - ctx, - tx, - `PRAGMA foreign_keys = ON;`, - // ------------------------------- - `CREATE TABLE IF NOT EXISTS access_control_roles ( - id TEXT PRIMARY KEY, - name VARCHAR(255) NOT NULL UNIQUE, - description TEXT, - is_system BOOLEAN NOT NULL DEFAULT 0, - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP - );`, - // ------------------------------- - `CREATE TABLE IF NOT EXISTS access_control_permissions ( - id TEXT PRIMARY KEY, - key VARCHAR(255) NOT NULL UNIQUE, - description TEXT, - is_system BOOLEAN NOT NULL DEFAULT 0, - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP - );`, - // ------------------------------- - `CREATE TABLE IF NOT EXISTS access_control_role_permissions ( - role_id TEXT NOT NULL, - permission_id TEXT NOT NULL, - granted_by_user_id TEXT, - granted_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - PRIMARY KEY (role_id, permission_id), - FOREIGN KEY (role_id) REFERENCES access_control_roles(id) ON DELETE CASCADE, - FOREIGN KEY (permission_id) REFERENCES access_control_permissions(id) ON DELETE CASCADE, - FOREIGN KEY (granted_by_user_id) REFERENCES users(id) ON DELETE SET NULL - );`, - `CREATE INDEX IF NOT EXISTS idx_access_control_role_permissions_role_id ON access_control_role_permissions(role_id);`, - `CREATE INDEX IF NOT EXISTS idx_access_control_role_permissions_permission_id ON access_control_role_permissions(permission_id);`, - // ------------------------------- - `CREATE TABLE IF NOT EXISTS access_control_user_roles ( - user_id TEXT NOT NULL, - role_id TEXT NOT NULL, - assigned_by_user_id TEXT, - assigned_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - expires_at TIMESTAMP, - PRIMARY KEY (user_id, role_id), - FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, - FOREIGN KEY (role_id) REFERENCES access_control_roles(id) ON DELETE CASCADE, - FOREIGN KEY (assigned_by_user_id) REFERENCES users(id) ON DELETE SET NULL - );`, - `CREATE INDEX IF NOT EXISTS idx_access_control_user_roles_role_id ON access_control_user_roles(role_id);`, - `CREATE INDEX IF NOT EXISTS idx_access_control_user_roles_expires_at ON access_control_user_roles(expires_at);`, - // ------------------------------- - ) - }, - Down: func(ctx context.Context, tx bun.Tx) error { - return migrations.ExecStatements( - ctx, - tx, - `DROP TABLE IF EXISTS access_control_user_roles;`, - `DROP TABLE IF EXISTS access_control_role_permissions;`, - `DROP TABLE IF EXISTS access_control_permissions;`, - `DROP TABLE IF EXISTS access_control_roles;`, - ) - }, - } -} - -func accessControlPostgresInitial() migrations.Migration { - return migrations.Migration{ - Version: "20260309000000_access_control_initial", - Up: func(ctx context.Context, tx bun.Tx) error { - return migrations.ExecStatements( - ctx, - tx, - `CREATE OR REPLACE FUNCTION access_control_set_updated_at_fn() - RETURNS TRIGGER AS $$ - BEGIN - NEW.updated_at = NOW(); - RETURN NEW; - END; - - $$ LANGUAGE plpgsql;`, - // ------------------------------- - `CREATE TABLE IF NOT EXISTS access_control_roles ( - id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - name VARCHAR(255) NOT NULL UNIQUE, - description TEXT, - is_system BOOLEAN NOT NULL DEFAULT FALSE, - created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), - updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW() - );`, - `DROP TRIGGER IF EXISTS update_access_control_roles_updated_at_trigger ON access_control_roles;`, - `CREATE TRIGGER update_access_control_roles_updated_at_trigger - BEFORE UPDATE ON access_control_roles - FOR EACH ROW - EXECUTE FUNCTION access_control_set_updated_at_fn();`, - // ------------------------------- - `CREATE TABLE IF NOT EXISTS access_control_permissions ( - id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - key VARCHAR(255) NOT NULL UNIQUE, - description TEXT, - is_system BOOLEAN NOT NULL DEFAULT FALSE, - created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), - updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW() - );`, - `DROP TRIGGER IF EXISTS update_access_control_permissions_updated_at_trigger ON access_control_permissions;`, - `CREATE TRIGGER update_access_control_permissions_updated_at_trigger - BEFORE UPDATE ON access_control_permissions - FOR EACH ROW - EXECUTE FUNCTION access_control_set_updated_at_fn();`, - // ------------------------------- - `CREATE TABLE IF NOT EXISTS access_control_role_permissions ( - role_id UUID NOT NULL, - permission_id UUID NOT NULL, - granted_by_user_id UUID, - granted_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), - PRIMARY KEY (role_id, permission_id), - CONSTRAINT fk_access_control_role_permissions_role FOREIGN KEY (role_id) REFERENCES access_control_roles(id) ON DELETE CASCADE, - CONSTRAINT fk_access_control_role_permissions_permission FOREIGN KEY (permission_id) REFERENCES access_control_permissions(id) ON DELETE CASCADE, - CONSTRAINT fk_access_control_role_permissions_granted_by FOREIGN KEY (granted_by_user_id) REFERENCES users(id) ON DELETE SET NULL - );`, - `CREATE INDEX IF NOT EXISTS idx_access_control_role_permissions_role_id ON access_control_role_permissions(role_id);`, - `CREATE INDEX IF NOT EXISTS idx_access_control_role_permissions_permission_id ON access_control_role_permissions(permission_id);`, - // ------------------------------- - `CREATE TABLE IF NOT EXISTS access_control_user_roles ( - user_id UUID NOT NULL, - role_id UUID NOT NULL, - assigned_by_user_id UUID, - assigned_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), - expires_at TIMESTAMP WITH TIME ZONE, - PRIMARY KEY (user_id, role_id), - CONSTRAINT fk_access_control_user_roles_user FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, - CONSTRAINT fk_access_control_user_roles_role FOREIGN KEY (role_id) REFERENCES access_control_roles(id) ON DELETE CASCADE, - CONSTRAINT fk_access_control_user_roles_assigned_by FOREIGN KEY (assigned_by_user_id) REFERENCES users(id) ON DELETE SET NULL - );`, - `CREATE INDEX IF NOT EXISTS idx_access_control_user_roles_role_id ON access_control_user_roles(role_id);`, - `CREATE INDEX IF NOT EXISTS idx_access_control_user_roles_expires_at ON access_control_user_roles(expires_at);`, - // ------------------------------- - ) - }, - Down: func(ctx context.Context, tx bun.Tx) error { - return migrations.ExecStatements( - ctx, - tx, - `DROP TABLE IF EXISTS access_control_user_roles;`, - `DROP TABLE IF EXISTS access_control_role_permissions;`, - `DROP TRIGGER IF EXISTS update_access_control_roles_updated_at_trigger ON access_control_roles;`, - `DROP TRIGGER IF EXISTS update_access_control_permissions_updated_at_trigger ON access_control_permissions;`, - `DROP TABLE IF EXISTS access_control_permissions;`, - `DROP TABLE IF EXISTS access_control_roles;`, - `DROP FUNCTION IF EXISTS access_control_set_updated_at_fn();`, - ) - }, - } -} - -func accessControlMySQLInitial() migrations.Migration { - return migrations.Migration{ - Version: "20260309000000_access_control_initial", - Up: func(ctx context.Context, tx bun.Tx) error { - return migrations.ExecStatements( - ctx, - tx, - `CREATE TABLE IF NOT EXISTS access_control_roles ( - id BINARY(16) NOT NULL PRIMARY KEY, - name VARCHAR(255) NOT NULL UNIQUE, - description TEXT, - is_system TINYINT(1) NOT NULL DEFAULT 0, - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;`, - // ------------------------------- - `CREATE TABLE IF NOT EXISTS access_control_permissions ( - id BINARY(16) NOT NULL PRIMARY KEY, - key VARCHAR(255) NOT NULL UNIQUE, - description TEXT, - is_system TINYINT(1) NOT NULL DEFAULT 0, - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;`, - // ------------------------------- - `CREATE TABLE IF NOT EXISTS access_control_role_permissions ( - role_id BINARY(16) NOT NULL, - permission_id BINARY(16) NOT NULL, - granted_by_user_id BINARY(16) NULL, - granted_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - PRIMARY KEY (role_id, permission_id), - CONSTRAINT fk_access_control_role_permissions_role_id FOREIGN KEY (role_id) REFERENCES access_control_roles(id) ON DELETE CASCADE, - CONSTRAINT fk_access_control_role_permissions_permission_id FOREIGN KEY (permission_id) REFERENCES access_control_permissions(id) ON DELETE CASCADE, - CONSTRAINT fk_access_control_role_permissions_granted_by_user_id FOREIGN KEY (granted_by_user_id) REFERENCES users(id) ON DELETE SET NULL, - INDEX idx_access_control_role_permissions_role_id (role_id), - INDEX idx_access_control_role_permissions_permission_id (permission_id) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;`, - // ------------------------------- - `CREATE TABLE IF NOT EXISTS access_control_user_roles ( - user_id BINARY(16) NOT NULL, - role_id BINARY(16) NOT NULL, - assigned_by_user_id BINARY(16) NULL, - assigned_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - expires_at TIMESTAMP NULL, - PRIMARY KEY (user_id, role_id), - CONSTRAINT fk_access_control_user_roles_user_id FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, - CONSTRAINT fk_access_control_user_roles_role_id FOREIGN KEY (role_id) REFERENCES access_control_roles(id) ON DELETE CASCADE, - CONSTRAINT fk_access_control_user_roles_assigned_by_user_id FOREIGN KEY (assigned_by_user_id) REFERENCES users(id) ON DELETE SET NULL, - INDEX idx_access_control_user_roles_role_id (role_id), - INDEX idx_access_control_user_roles_expires_at (expires_at) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;`, - // ------------------------------- - ) - }, - Down: func(ctx context.Context, tx bun.Tx) error { - return migrations.ExecStatements( - ctx, - tx, - `DROP TABLE IF EXISTS access_control_user_roles;`, - `DROP TABLE IF EXISTS access_control_role_permissions;`, - `DROP TABLE IF EXISTS access_control_permissions;`, - `DROP TABLE IF EXISTS access_control_roles;`, - ) - }, - } + return migrationset.ForProvider(provider) } diff --git a/plugins/access-control/migrationset/migrations.go b/plugins/access-control/migrationset/migrations.go new file mode 100644 index 00000000..02f1efdb --- /dev/null +++ b/plugins/access-control/migrationset/migrations.go @@ -0,0 +1,241 @@ +package migrationset + +import ( + "context" + + "github.com/uptrace/bun" + + "github.com/Authula/authula/migrations" +) + +func ForProvider(provider string) []migrations.Migration { + return migrations.ForProvider(provider, migrations.ProviderVariants{ + "sqlite": func() []migrations.Migration { return []migrations.Migration{accessControlSQLiteInitial()} }, + "postgres": func() []migrations.Migration { return []migrations.Migration{accessControlPostgresInitial()} }, + "mysql": func() []migrations.Migration { return []migrations.Migration{accessControlMySQLInitial()} }, + }) +} + +func accessControlSQLiteInitial() migrations.Migration { + return migrations.Migration{ + Version: "20260309000000_access_control_initial", + Up: func(ctx context.Context, tx bun.Tx) error { + return migrations.ExecStatements( + ctx, + tx, + `PRAGMA foreign_keys = ON;`, + // ------------------------------- + `CREATE TABLE IF NOT EXISTS access_control_roles ( + id TEXT PRIMARY KEY, + name VARCHAR(255) NOT NULL UNIQUE, + description TEXT, + is_system BOOLEAN NOT NULL DEFAULT 0, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + );`, + // ------------------------------- + `CREATE TABLE IF NOT EXISTS access_control_permissions ( + id TEXT PRIMARY KEY, + key VARCHAR(255) NOT NULL UNIQUE, + description TEXT, + is_system BOOLEAN NOT NULL DEFAULT 0, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + );`, + // ------------------------------- + `CREATE TABLE IF NOT EXISTS access_control_role_permissions ( + role_id TEXT NOT NULL, + permission_id TEXT NOT NULL, + granted_by_user_id TEXT, + granted_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (role_id, permission_id), + FOREIGN KEY (role_id) REFERENCES access_control_roles(id) ON DELETE CASCADE, + FOREIGN KEY (permission_id) REFERENCES access_control_permissions(id) ON DELETE CASCADE, + FOREIGN KEY (granted_by_user_id) REFERENCES users(id) ON DELETE SET NULL + );`, + `CREATE INDEX IF NOT EXISTS idx_access_control_role_permissions_role_id ON access_control_role_permissions(role_id);`, + `CREATE INDEX IF NOT EXISTS idx_access_control_role_permissions_permission_id ON access_control_role_permissions(permission_id);`, + // ------------------------------- + `CREATE TABLE IF NOT EXISTS access_control_user_roles ( + user_id TEXT NOT NULL, + role_id TEXT NOT NULL, + assigned_by_user_id TEXT, + assigned_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + expires_at TIMESTAMP, + PRIMARY KEY (user_id, role_id), + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, + FOREIGN KEY (role_id) REFERENCES access_control_roles(id) ON DELETE CASCADE, + FOREIGN KEY (assigned_by_user_id) REFERENCES users(id) ON DELETE SET NULL + );`, + `CREATE INDEX IF NOT EXISTS idx_access_control_user_roles_role_id ON access_control_user_roles(role_id);`, + `CREATE INDEX IF NOT EXISTS idx_access_control_user_roles_expires_at ON access_control_user_roles(expires_at);`, + // ------------------------------- + ) + }, + Down: func(ctx context.Context, tx bun.Tx) error { + return migrations.ExecStatements( + ctx, + tx, + `DROP TABLE IF EXISTS access_control_user_roles;`, + `DROP TABLE IF EXISTS access_control_role_permissions;`, + `DROP TABLE IF EXISTS access_control_permissions;`, + `DROP TABLE IF EXISTS access_control_roles;`, + ) + }, + } +} + +func accessControlPostgresInitial() migrations.Migration { + return migrations.Migration{ + Version: "20260309000000_access_control_initial", + Up: func(ctx context.Context, tx bun.Tx) error { + return migrations.ExecStatements( + ctx, + tx, + `CREATE OR REPLACE FUNCTION access_control_set_updated_at_fn() + RETURNS TRIGGER AS $$ + BEGIN + NEW.updated_at = NOW(); + RETURN NEW; + END; + $$ LANGUAGE plpgsql;`, + // ------------------------------- + `CREATE TABLE IF NOT EXISTS access_control_roles ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + name VARCHAR(255) NOT NULL UNIQUE, + description TEXT, + is_system BOOLEAN NOT NULL DEFAULT FALSE, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW() + );`, + `DROP TRIGGER IF EXISTS update_access_control_roles_updated_at_trigger ON access_control_roles;`, + `CREATE TRIGGER update_access_control_roles_updated_at_trigger + BEFORE UPDATE ON access_control_roles + FOR EACH ROW + EXECUTE FUNCTION access_control_set_updated_at_fn();`, + // ------------------------------- + `CREATE TABLE IF NOT EXISTS access_control_permissions ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + key VARCHAR(255) NOT NULL UNIQUE, + description TEXT, + is_system BOOLEAN NOT NULL DEFAULT FALSE, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW() + );`, + `DROP TRIGGER IF EXISTS update_access_control_permissions_updated_at_trigger ON access_control_permissions;`, + `CREATE TRIGGER update_access_control_permissions_updated_at_trigger + BEFORE UPDATE ON access_control_permissions + FOR EACH ROW + EXECUTE FUNCTION access_control_set_updated_at_fn();`, + // ------------------------------- + `CREATE TABLE IF NOT EXISTS access_control_role_permissions ( + role_id UUID NOT NULL, + permission_id UUID NOT NULL, + granted_by_user_id UUID, + granted_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + PRIMARY KEY (role_id, permission_id), + CONSTRAINT fk_access_control_role_permissions_role FOREIGN KEY (role_id) REFERENCES access_control_roles(id) ON DELETE CASCADE, + CONSTRAINT fk_access_control_role_permissions_permission FOREIGN KEY (permission_id) REFERENCES access_control_permissions(id) ON DELETE CASCADE, + CONSTRAINT fk_access_control_role_permissions_granted_by FOREIGN KEY (granted_by_user_id) REFERENCES users(id) ON DELETE SET NULL + );`, + `CREATE INDEX IF NOT EXISTS idx_access_control_role_permissions_role_id ON access_control_role_permissions(role_id);`, + `CREATE INDEX IF NOT EXISTS idx_access_control_role_permissions_permission_id ON access_control_role_permissions(permission_id);`, + // ------------------------------- + `CREATE TABLE IF NOT EXISTS access_control_user_roles ( + user_id UUID NOT NULL, + role_id UUID NOT NULL, + assigned_by_user_id UUID, + assigned_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + expires_at TIMESTAMP WITH TIME ZONE, + PRIMARY KEY (user_id, role_id), + CONSTRAINT fk_access_control_user_roles_user FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, + CONSTRAINT fk_access_control_user_roles_role FOREIGN KEY (role_id) REFERENCES access_control_roles(id) ON DELETE CASCADE, + CONSTRAINT fk_access_control_user_roles_assigned_by FOREIGN KEY (assigned_by_user_id) REFERENCES users(id) ON DELETE SET NULL + );`, + `CREATE INDEX IF NOT EXISTS idx_access_control_user_roles_role_id ON access_control_user_roles(role_id);`, + `CREATE INDEX IF NOT EXISTS idx_access_control_user_roles_expires_at ON access_control_user_roles(expires_at);`, + // ------------------------------- + ) + }, + Down: func(ctx context.Context, tx bun.Tx) error { + return migrations.ExecStatements( + ctx, + tx, + `DROP TABLE IF EXISTS access_control_user_roles;`, + `DROP TABLE IF EXISTS access_control_role_permissions;`, + `DROP TRIGGER IF EXISTS update_access_control_roles_updated_at_trigger ON access_control_roles;`, + `DROP TRIGGER IF EXISTS update_access_control_permissions_updated_at_trigger ON access_control_permissions;`, + `DROP TABLE IF EXISTS access_control_permissions;`, + `DROP TABLE IF EXISTS access_control_roles;`, + `DROP FUNCTION IF EXISTS access_control_set_updated_at_fn();`, + ) + }, + } +} + +func accessControlMySQLInitial() migrations.Migration { + return migrations.Migration{ + Version: "20260309000000_access_control_initial", + Up: func(ctx context.Context, tx bun.Tx) error { + return migrations.ExecStatements( + ctx, + tx, + `CREATE TABLE IF NOT EXISTS access_control_roles ( + id BINARY(16) NOT NULL PRIMARY KEY, + name VARCHAR(255) NOT NULL UNIQUE, + description TEXT, + is_system TINYINT(1) NOT NULL DEFAULT 0, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;`, + // ------------------------------- + `CREATE TABLE IF NOT EXISTS access_control_permissions ( + id BINARY(16) NOT NULL PRIMARY KEY, + key VARCHAR(255) NOT NULL UNIQUE, + description TEXT, + is_system TINYINT(1) NOT NULL DEFAULT 0, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;`, + // ------------------------------- + `CREATE TABLE IF NOT EXISTS access_control_role_permissions ( + role_id BINARY(16) NOT NULL, + permission_id BINARY(16) NOT NULL, + granted_by_user_id BINARY(16) NULL, + granted_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (role_id, permission_id), + CONSTRAINT fk_access_control_role_permissions_role_id FOREIGN KEY (role_id) REFERENCES access_control_roles(id) ON DELETE CASCADE, + CONSTRAINT fk_access_control_role_permissions_permission_id FOREIGN KEY (permission_id) REFERENCES access_control_permissions(id) ON DELETE CASCADE, + CONSTRAINT fk_access_control_role_permissions_granted_by_user_id FOREIGN KEY (granted_by_user_id) REFERENCES users(id) ON DELETE SET NULL, + INDEX idx_access_control_role_permissions_role_id (role_id), + INDEX idx_access_control_role_permissions_permission_id (permission_id) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;`, + // ------------------------------- + `CREATE TABLE IF NOT EXISTS access_control_user_roles ( + user_id BINARY(16) NOT NULL, + role_id BINARY(16) NOT NULL, + assigned_by_user_id BINARY(16) NULL, + assigned_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + expires_at TIMESTAMP NULL, + PRIMARY KEY (user_id, role_id), + CONSTRAINT fk_access_control_user_roles_user_id FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, + CONSTRAINT fk_access_control_user_roles_role_id FOREIGN KEY (role_id) REFERENCES access_control_roles(id) ON DELETE CASCADE, + CONSTRAINT fk_access_control_user_roles_assigned_by_user_id FOREIGN KEY (assigned_by_user_id) REFERENCES users(id) ON DELETE SET NULL, + INDEX idx_access_control_user_roles_role_id (role_id), + INDEX idx_access_control_user_roles_expires_at (expires_at) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;`, + // ------------------------------- + ) + }, + Down: func(ctx context.Context, tx bun.Tx) error { + return migrations.ExecStatements( + ctx, + tx, + `DROP TABLE IF EXISTS access_control_user_roles;`, + `DROP TABLE IF EXISTS access_control_role_permissions;`, + `DROP TABLE IF EXISTS access_control_permissions;`, + `DROP TABLE IF EXISTS access_control_roles;`, + ) + }, + } +} diff --git a/plugins/access-control/plugin.go b/plugins/access-control/plugin.go index 25d48412..e15395f6 100644 --- a/plugins/access-control/plugin.go +++ b/plugins/access-control/plugin.go @@ -42,20 +42,26 @@ func (p *AccessControlPlugin) Init(ctx *models.PluginContext) error { return err } - rolePermissionRepo := repositories.NewBunRolePermissionRepository(ctx.DB) + rolesRepo := repositories.NewBunRolesRepository(ctx.DB) + permissionsRepo := repositories.NewBunPermissionsRepository(ctx.DB) + rolePermissionsRepo := repositories.NewBunRolePermissionsRepository(ctx.DB) + userRolesRepo := repositories.NewBunUserRolesRepository(ctx.DB) userAccessRepo := repositories.NewBunUserAccessRepository(ctx.DB) - rolePermissionService := services.NewRolePermissionService(rolePermissionRepo) - userAccessService := services.NewUserAccessService(userAccessRepo) + + rolesService := services.NewRolesService(rolesRepo, rolePermissionsRepo, userAccessRepo) + permissionsService := services.NewPermissionsService(permissionsRepo, userAccessRepo) + rolePermissionsService := services.NewRolePermissionsService(rolesRepo, permissionsRepo, rolePermissionsRepo) + userRolesService := services.NewUserRolesService(userRolesRepo, rolesRepo) + userAccessService := services.NewUserAccessService(userRolesRepo, userAccessRepo) useCases := usecases.NewAccessControlUseCases( - usecases.NewRolePermissionUseCase(rolePermissionService), - usecases.NewUserRolesUseCase(userAccessService), - ) - p.Api = NewAPI( - useCases, - rolePermissionRepo, - userAccessRepo, + usecases.NewRolesUseCase(rolesService), + usecases.NewPermissionsUseCase(permissionsService), + usecases.NewRolePermissionsUseCase(rolePermissionsService), + usecases.NewUserRolesUseCase(userRolesService), + usecases.NewUserAccessUseCase(userAccessService), ) + p.Api = NewAPI(useCases) return nil } diff --git a/plugins/access-control/repositories/errors.go b/plugins/access-control/repositories/errors.go new file mode 100644 index 00000000..39ea9b18 --- /dev/null +++ b/plugins/access-control/repositories/errors.go @@ -0,0 +1,30 @@ +package repositories + +import ( + "fmt" + "strings" + + "github.com/Authula/authula/plugins/access-control/constants" +) + +func wrapRepositoryError(action string, err error) error { + if err == nil { + return nil + } + + if isUniqueConstraintError(err) { + return constants.ErrConflict + } + + return fmt.Errorf("failed to %s: %w", action, err) +} + +func isUniqueConstraintError(err error) bool { + message := strings.ToLower(err.Error()) + + return strings.Contains(message, "unique constraint") || + strings.Contains(message, "unique violation") || + strings.Contains(message, "duplicate key value") || + strings.Contains(message, "duplicate entry") || + strings.Contains(message, "error 1062") +} diff --git a/plugins/access-control/repositories/interfaces.go b/plugins/access-control/repositories/interfaces.go index 7ed2387e..15f9a80e 100644 --- a/plugins/access-control/repositories/interfaces.go +++ b/plugins/access-control/repositories/interfaces.go @@ -7,31 +7,40 @@ import ( "github.com/Authula/authula/plugins/access-control/types" ) -type RolePermissionRepository interface { +type RolesRepository interface { CreateRole(ctx context.Context, role *types.Role) error GetAllRoles(ctx context.Context) ([]types.Role, error) GetRoleByID(ctx context.Context, roleID string) (*types.Role, error) UpdateRole(ctx context.Context, roleID string, name *string, description *string) (bool, error) DeleteRole(ctx context.Context, roleID string) (bool, error) +} + +type PermissionsRepository interface { GetAllPermissions(ctx context.Context) ([]types.Permission, error) GetPermissionByID(ctx context.Context, permissionID string) (*types.Permission, error) CreatePermission(ctx context.Context, permission *types.Permission) error UpdatePermission(ctx context.Context, permissionID string, description *string) (bool, error) DeletePermission(ctx context.Context, permissionID string) (bool, error) +} + +type RolePermissionsRepository interface { GetRolePermissions(ctx context.Context, roleID string) ([]types.UserPermissionInfo, error) ReplaceRolePermissions(ctx context.Context, roleID string, permissionIDs []string, grantedByUserID *string) error AddRolePermission(ctx context.Context, roleID string, permissionID string, grantedByUserID *string) error RemoveRolePermission(ctx context.Context, roleID string, permissionID string) error +} + +type UserRolesRepository interface { + GetUserRoles(ctx context.Context, userID string) ([]types.UserRoleInfo, error) + GetUserWithRolesByID(ctx context.Context, userID string) (*types.UserWithRoles, error) ReplaceUserRoles(ctx context.Context, userID string, roleIDs []string, assignedByUserID *string) error AssignUserRole(ctx context.Context, userID string, roleID string, assignedByUserID *string, expiresAt *time.Time) error RemoveUserRole(ctx context.Context, userID string, roleID string) error - CountUserAssignmentsByRoleID(ctx context.Context, roleID string) (int, error) - CountRoleAssignmentsByPermissionID(ctx context.Context, permissionID string) (int, error) } type UserAccessRepository interface { - GetUserRoles(ctx context.Context, userID string) ([]types.UserRoleInfo, error) + CountUserAssignmentsByRoleID(ctx context.Context, roleID string) (int, error) + CountRoleAssignmentsByPermissionID(ctx context.Context, permissionID string) (int, error) GetUserEffectivePermissions(ctx context.Context, userID string) ([]types.UserPermissionInfo, error) - GetUserWithRolesByID(ctx context.Context, userID string) (*types.UserWithRoles, error) GetUserWithPermissionsByID(ctx context.Context, userID string) (*types.UserWithPermissions, error) } diff --git a/plugins/access-control/repositories/permissions_repository.go b/plugins/access-control/repositories/permissions_repository.go new file mode 100644 index 00000000..5680863d --- /dev/null +++ b/plugins/access-control/repositories/permissions_repository.go @@ -0,0 +1,84 @@ +package repositories + +import ( + "context" + "database/sql" + "fmt" + "time" + + "github.com/uptrace/bun" + + "github.com/Authula/authula/plugins/access-control/types" +) + +type BunPermissionsRepository struct { + db bun.IDB +} + +func NewBunPermissionsRepository(db bun.IDB) *BunPermissionsRepository { + return &BunPermissionsRepository{db: db} +} + +func (r *BunPermissionsRepository) CreatePermission(ctx context.Context, permission *types.Permission) error { + _, err := r.db.NewInsert().Model(permission).Exec(ctx) + return wrapRepositoryError("create permission", err) +} + +func (r *BunPermissionsRepository) GetAllPermissions(ctx context.Context) ([]types.Permission, error) { + permissions := make([]types.Permission, 0) + err := r.db.NewSelect().Model(&permissions).Order("created_at ASC").Scan(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get permissions: %w", err) + } + return permissions, nil +} + +func (r *BunPermissionsRepository) GetPermissionByID(ctx context.Context, permissionID string) (*types.Permission, error) { + permission := new(types.Permission) + err := r.db.NewSelect().Model(permission).Where("id = ?", permissionID).Scan(ctx) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, fmt.Errorf("failed to get permission by id: %w", err) + } + + return permission, nil +} + +func (r *BunPermissionsRepository) UpdatePermission(ctx context.Context, permissionID string, description *string) (bool, error) { + query := r.db.NewUpdate(). + Model((*types.Permission)(nil)). + Set("updated_at = ?", time.Now().UTC()). + Where("id = ?", permissionID) + + if description != nil { + query = query.Set("description = ?", *description) + } + + result, err := query.Exec(ctx) + if err != nil { + return false, wrapRepositoryError("update permission", err) + } + + affected, err := result.RowsAffected() + if err != nil { + return false, fmt.Errorf("failed to determine updated rows: %w", err) + } + + return affected > 0, nil +} + +func (r *BunPermissionsRepository) DeletePermission(ctx context.Context, permissionID string) (bool, error) { + result, err := r.db.NewDelete().Model((*types.Permission)(nil)).Where("id = ?", permissionID).Exec(ctx) + if err != nil { + return false, fmt.Errorf("failed to delete permission: %w", err) + } + + affected, err := result.RowsAffected() + if err != nil { + return false, fmt.Errorf("failed to determine deleted rows: %w", err) + } + + return affected > 0, nil +} diff --git a/plugins/access-control/repositories/permissions_repository_test.go b/plugins/access-control/repositories/permissions_repository_test.go new file mode 100644 index 00000000..95ece8c3 --- /dev/null +++ b/plugins/access-control/repositories/permissions_repository_test.go @@ -0,0 +1,265 @@ +package repositories + +import ( + "context" + "testing" + + accesscontrolconstants "github.com/Authula/authula/plugins/access-control/constants" + plugintests "github.com/Authula/authula/plugins/access-control/tests" + "github.com/Authula/authula/plugins/access-control/types" +) + +func TestBunPermissionsRepositoryCreatePermission(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + permission *types.Permission + wantErr error + wantID string + wantKey string + wantDescription *string + }{ + { + name: "success", + permission: &types.Permission{ID: "p1", Key: "users.read", Description: new("Read users"), IsSystem: false}, + wantID: "p1", + wantKey: "users.read", + wantDescription: new("Read users"), + }, + { + name: "duplicate key returns conflict", + permission: &types.Permission{ID: "p2", Key: "users.read", Description: new("Duplicate"), IsSystem: false}, + wantErr: accesscontrolconstants.ErrConflict, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + repo := NewBunPermissionsRepository(db) + ctx := context.Background() + + if tc.name == "duplicate key returns conflict" { + if err := repo.CreatePermission(ctx, &types.Permission{ID: "p1", Key: "users.read", Description: new("Read users"), IsSystem: false}); err != nil { + t.Fatalf("failed to seed permission: %v", err) + } + } + + err := repo.CreatePermission(ctx, tc.permission) + if err != tc.wantErr { + t.Fatalf("expected err %v, got %v", tc.wantErr, err) + } + + if tc.wantErr != nil { + return + } + + stored, err := repo.GetPermissionByID(ctx, tc.wantID) + if err != nil { + t.Fatalf("failed to fetch stored permission: %v", err) + } + if stored == nil { + t.Fatal("expected stored permission, got nil") + } + if stored.ID != tc.wantID || stored.Key != tc.wantKey { + t.Fatalf("unexpected stored permission: %#v", stored) + } + if tc.wantDescription != nil { + if stored.Description == nil || *stored.Description != *tc.wantDescription { + t.Fatalf("expected description %q, got %#v", *tc.wantDescription, stored.Description) + } + } + if stored.CreatedAt.IsZero() || stored.UpdatedAt.IsZero() { + t.Fatal("expected timestamps to be populated") + } + }) + } +} + +func TestBunPermissionsRepositoryGetAllPermissions(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + repo := NewBunPermissionsRepository(db) + ctx := context.Background() + + if err := repo.CreatePermission(ctx, &types.Permission{ID: "p2", Key: "users.write"}); err != nil { + t.Fatalf("failed to seed permission p2: %v", err) + } + if err := repo.CreatePermission(ctx, &types.Permission{ID: "p1", Key: "users.read"}); err != nil { + t.Fatalf("failed to seed permission p1: %v", err) + } + + permissions, err := repo.GetAllPermissions(ctx) + if err != nil { + t.Fatalf("failed to get permissions: %v", err) + } + if len(permissions) != 2 { + t.Fatalf("expected 2 permissions, got %d", len(permissions)) + } + if permissions[0].ID != "p2" || permissions[1].ID != "p1" { + t.Fatalf("expected permissions ordered by creation time, got %#v", permissions) + } +} + +func TestBunPermissionsRepositoryGetPermissionByID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + permissionID string + seedPermission *types.Permission + wantNil bool + }{ + { + name: "not found", + permissionID: "missing", + wantNil: true, + }, + { + name: "success", + permissionID: "p1", + seedPermission: &types.Permission{ID: "p1", Key: "users.read", Description: new("Read users"), IsSystem: false}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + repo := NewBunPermissionsRepository(db) + ctx := context.Background() + + if tc.seedPermission != nil { + if err := repo.CreatePermission(ctx, tc.seedPermission); err != nil { + t.Fatalf("failed to seed permission: %v", err) + } + } + + permission, err := repo.GetPermissionByID(ctx, tc.permissionID) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tc.wantNil { + if permission != nil { + t.Fatalf("expected nil permission, got %#v", permission) + } + return + } + if permission == nil || permission.ID != tc.permissionID { + t.Fatalf("unexpected permission: %#v", permission) + } + }) + } +} + +func TestBunPermissionsRepositoryUpdatePermission(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + seedPermission *types.Permission + permissionID string + description *string + wantUpdated bool + }{ + { + name: "missing permission", + permissionID: "missing", + description: new("updated"), + wantUpdated: false, + }, + { + name: "success", + seedPermission: &types.Permission{ID: "p1", Key: "users.read", Description: new("Read users"), IsSystem: false}, + permissionID: "p1", + description: new("updated"), + wantUpdated: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + repo := NewBunPermissionsRepository(db) + ctx := context.Background() + + if tc.seedPermission != nil { + if err := repo.CreatePermission(ctx, tc.seedPermission); err != nil { + t.Fatalf("failed to seed permission: %v", err) + } + } + + updated, err := repo.UpdatePermission(ctx, tc.permissionID, tc.description) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if updated != tc.wantUpdated { + t.Fatalf("expected updated=%v, got %v", tc.wantUpdated, updated) + } + + if tc.wantUpdated { + permission, err := repo.GetPermissionByID(ctx, tc.permissionID) + if err != nil { + t.Fatalf("failed to fetch permission: %v", err) + } + if permission == nil || permission.Description == nil || *permission.Description != *tc.description { + t.Fatalf("unexpected permission after update: %#v", permission) + } + } + }) + } +} + +func TestBunPermissionsRepositoryDeletePermission(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + seedPermission *types.Permission + permissionID string + wantDeleted bool + }{ + { + name: "missing permission", + permissionID: "missing", + wantDeleted: false, + }, + { + name: "success", + seedPermission: &types.Permission{ID: "p1", Key: "users.read", Description: new("Read users"), IsSystem: false}, + permissionID: "p1", + wantDeleted: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + repo := NewBunPermissionsRepository(db) + ctx := context.Background() + + if tc.seedPermission != nil { + if err := repo.CreatePermission(ctx, tc.seedPermission); err != nil { + t.Fatalf("failed to seed permission: %v", err) + } + } + + deleted, err := repo.DeletePermission(ctx, tc.permissionID) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if deleted != tc.wantDeleted { + t.Fatalf("expected deleted=%v, got %v", tc.wantDeleted, deleted) + } + }) + } +} diff --git a/plugins/access-control/repositories/repository_test_helpers_test.go b/plugins/access-control/repositories/repository_test_helpers_test.go deleted file mode 100644 index 08eb4d60..00000000 --- a/plugins/access-control/repositories/repository_test_helpers_test.go +++ /dev/null @@ -1,55 +0,0 @@ -package repositories - -import ( - "context" - "database/sql" - "testing" - - _ "github.com/mattn/go-sqlite3" - "github.com/uptrace/bun" - "github.com/uptrace/bun/dialect/sqlitedialect" - - "github.com/Authula/authula/models" - "github.com/Authula/authula/plugins/access-control/types" -) - -func setupRepoDB(t *testing.T) *bun.DB { - t.Helper() - - sqldb, err := sql.Open("sqlite3", ":memory:") - if err != nil { - t.Fatalf("failed to open sqlite: %v", err) - } - - db := bun.NewDB(sqldb, sqlitedialect.New()) - t.Cleanup(func() { - _ = db.Close() - }) - - ctx := context.Background() - - if _, err := db.NewCreateTable().Model((*models.User)(nil)).IfNotExists().Exec(ctx); err != nil { - t.Fatalf("failed to create users table: %v", err) - } - if _, err := db.NewCreateTable().Model((*types.Role)(nil)).IfNotExists().Exec(ctx); err != nil { - t.Fatalf("failed to create access control roles table: %v", err) - } - if _, err := db.NewCreateTable().Model((*types.Permission)(nil)).IfNotExists().Exec(ctx); err != nil { - t.Fatalf("failed to create access control permissions table: %v", err) - } - if _, err := db.NewCreateTable().Model((*types.RolePermission)(nil)).IfNotExists().Exec(ctx); err != nil { - t.Fatalf("failed to create access control role permissions table: %v", err) - } - if _, err := db.NewCreateTable().Model((*types.UserRole)(nil)).IfNotExists().Exec(ctx); err != nil { - t.Fatalf("failed to create access control user roles table: %v", err) - } - - if _, err := db.ExecContext(ctx, `INSERT INTO users (id, name, email, email_verified, metadata) VALUES ('u1', 'User One', 'u1@example.com', 1, '{}')`); err != nil { - t.Fatalf("failed to seed user u1: %v", err) - } - if _, err := db.ExecContext(ctx, `INSERT INTO users (id, name, email, email_verified, metadata) VALUES ('u2', 'User Two', 'u2@example.com', 1, '{}')`); err != nil { - t.Fatalf("failed to seed user u2: %v", err) - } - - return db -} diff --git a/plugins/access-control/repositories/role_permission_repository.go b/plugins/access-control/repositories/role_permission_repository.go index 7d3e61fe..4a36a87d 100644 --- a/plugins/access-control/repositories/role_permission_repository.go +++ b/plugins/access-control/repositories/role_permission_repository.go @@ -2,7 +2,6 @@ package repositories import ( "context" - "database/sql" "fmt" "time" @@ -12,172 +11,41 @@ import ( ) type BunRolePermissionRepository struct { - db bun.IDB + RolesRepository + PermissionsRepository + RolePermissionsRepository + UserRolesRepository + UserAccessRepository } func NewBunRolePermissionRepository(db bun.IDB) *BunRolePermissionRepository { - return &BunRolePermissionRepository{db: db} -} - -func (r *BunRolePermissionRepository) CreateRole(ctx context.Context, role *types.Role) error { - _, err := r.db.NewInsert().Model(role).Exec(ctx) - if err != nil { - return fmt.Errorf("failed to create role: %w", err) - } - return nil -} - -func (r *BunRolePermissionRepository) GetAllRoles(ctx context.Context) ([]types.Role, error) { - roles := make([]types.Role, 0) - err := r.db.NewSelect().Model(&roles).Order("created_at ASC").Scan(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get roles: %w", err) - } - return roles, nil -} - -func (r *BunRolePermissionRepository) GetRoleByID(ctx context.Context, roleID string) (*types.Role, error) { - role := new(types.Role) - err := r.db.NewSelect().Model(role).Where("id = ?", roleID).Scan(ctx) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil - } - return nil, fmt.Errorf("failed to get role by id: %w", err) - } - - return role, nil -} - -func (r *BunRolePermissionRepository) UpdateRole(ctx context.Context, roleID string, name *string, description *string) (bool, error) { - query := r.db.NewUpdate(). - Model((*types.Role)(nil)). - Set("updated_at = ?", time.Now().UTC()). - Where("id = ?", roleID) - - if name != nil { - query = query.Set("name = ?", *name) - } - - if description != nil { - query = query.Set("description = ?", *description) - } - - result, err := query.Exec(ctx) - if err != nil { - return false, fmt.Errorf("failed to update role: %w", err) - } - - affected, err := result.RowsAffected() - if err != nil { - return false, fmt.Errorf("failed to determine updated rows: %w", err) - } - - return affected > 0, nil -} - -func (r *BunRolePermissionRepository) DeleteRole(ctx context.Context, roleID string) (bool, error) { - result, err := r.db.NewDelete().Model((*types.Role)(nil)).Where("id = ?", roleID).Exec(ctx) - if err != nil { - return false, fmt.Errorf("failed to delete role: %w", err) - } - - affected, err := result.RowsAffected() - if err != nil { - return false, fmt.Errorf("failed to determine deleted rows: %w", err) - } - - return affected > 0, nil -} - -func (r *BunRolePermissionRepository) CountUserAssignmentsByRoleID(ctx context.Context, roleID string) (int, error) { - count, err := r.db.NewSelect(). - Model((*types.UserRole)(nil)). - Where("role_id = ?", roleID). - Count(ctx) - if err != nil { - return 0, fmt.Errorf("failed to count role user assignments: %w", err) - } - - return count, nil -} - -func (r *BunRolePermissionRepository) GetAllPermissions(ctx context.Context) ([]types.Permission, error) { - permissions := make([]types.Permission, 0) - err := r.db.NewSelect().Model(&permissions).Order("created_at ASC").Scan(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get permissions: %w", err) - } - return permissions, nil -} - -func (r *BunRolePermissionRepository) GetPermissionByID(ctx context.Context, permissionID string) (*types.Permission, error) { - permission := new(types.Permission) - err := r.db.NewSelect().Model(permission).Where("id = ?", permissionID).Scan(ctx) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil - } - return nil, fmt.Errorf("failed to get permission by id: %w", err) + return &BunRolePermissionRepository{ + RolesRepository: NewBunRolesRepository(db), + PermissionsRepository: NewBunPermissionsRepository(db), + RolePermissionsRepository: NewBunRolePermissionsRepository(db), + UserRolesRepository: NewBunUserRolesRepository(db), + UserAccessRepository: NewBunUserAccessRepository(db), } - - return permission, nil } -func (r *BunRolePermissionRepository) CreatePermission(ctx context.Context, permission *types.Permission) error { - _, err := r.db.NewInsert().Model(permission).Exec(ctx) - if err != nil { - return fmt.Errorf("failed to create permission: %w", err) - } - return nil -} - -func (r *BunRolePermissionRepository) UpdatePermission(ctx context.Context, permissionID string, description *string) (bool, error) { - result, err := r.db.NewUpdate(). - Model((*types.Permission)(nil)). - Set("description = ?", *description). - Set("updated_at = ?", time.Now().UTC()). - Where("id = ?", permissionID). - Exec(ctx) - if err != nil { - return false, fmt.Errorf("failed to update permission: %w", err) - } - - affected, err := result.RowsAffected() - if err != nil { - return false, fmt.Errorf("failed to determine updated rows: %w", err) - } - - return affected > 0, nil +type BunRolePermissionsRepository struct { + db bun.IDB } -func (r *BunRolePermissionRepository) DeletePermission(ctx context.Context, permissionID string) (bool, error) { - result, err := r.db.NewDelete().Model((*types.Permission)(nil)).Where("id = ?", permissionID).Exec(ctx) - if err != nil { - return false, fmt.Errorf("failed to delete permission: %w", err) - } - - affected, err := result.RowsAffected() - if err != nil { - return false, fmt.Errorf("failed to determine deleted rows: %w", err) - } - - return affected > 0, nil +func NewBunRolePermissionsRepository(db bun.IDB) *BunRolePermissionsRepository { + return &BunRolePermissionsRepository{db: db} } -func (r *BunRolePermissionRepository) CountRoleAssignmentsByPermissionID(ctx context.Context, permissionID string) (int, error) { - count, err := r.db.NewSelect(). - Model((*types.RolePermission)(nil)). - Where("permission_id = ?", permissionID). - Count(ctx) - if err != nil { - return 0, fmt.Errorf("failed to count permission role assignments: %w", err) - } - - return count, nil +type rolePermissionRow struct { + PermissionID string `bun:"permission_id"` + PermissionKey string `bun:"permission_key"` + PermissionDescription *string `bun:"permission_description"` + GrantedByUserID *string `bun:"granted_by_user_id"` + GrantedAt *time.Time `bun:"granted_at"` } -func (r *BunRolePermissionRepository) GetRolePermissions(ctx context.Context, roleID string) ([]types.UserPermissionInfo, error) { +func (r *BunRolePermissionsRepository) GetRolePermissions(ctx context.Context, roleID string) ([]types.UserPermissionInfo, error) { + var scanned []rolePermissionRow rows := make([]types.UserPermissionInfo, 0) err := r.db.NewSelect(). @@ -190,15 +58,25 @@ func (r *BunRolePermissionRepository) GetRolePermissions(ctx context.Context, ro Join("JOIN access_control_permissions ap ON ap.id = arp.permission_id"). Where("arp.role_id = ?", roleID). OrderExpr("ap.key ASC"). - Scan(ctx, &rows) + Scan(ctx, &scanned) if err != nil { return nil, fmt.Errorf("failed to get role permissions: %w", err) } + for _, row := range scanned { + rows = append(rows, types.UserPermissionInfo{ + PermissionID: row.PermissionID, + PermissionKey: row.PermissionKey, + PermissionDescription: row.PermissionDescription, + GrantedByUserID: row.GrantedByUserID, + GrantedAt: row.GrantedAt, + }) + } + return rows, nil } -func (r *BunRolePermissionRepository) ReplaceRolePermissions(ctx context.Context, roleID string, permissionIDs []string, grantedByUserID *string) error { +func (r *BunRolePermissionsRepository) ReplaceRolePermissions(ctx context.Context, roleID string, permissionIDs []string, grantedByUserID *string) error { return r.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { if _, err := tx.NewDelete().Model((*types.RolePermission)(nil)).Where("role_id = ?", roleID).Exec(ctx); err != nil { return fmt.Errorf("failed to clear role permissions: %w", err) @@ -213,7 +91,7 @@ func (r *BunRolePermissionRepository) ReplaceRolePermissions(ctx context.Context GrantedAt: now, } if _, err := tx.NewInsert().Model(rp).Exec(ctx); err != nil { - return fmt.Errorf("failed to insert role permission: %w", err) + return wrapRepositoryError("insert role permission", err) } } @@ -221,7 +99,7 @@ func (r *BunRolePermissionRepository) ReplaceRolePermissions(ctx context.Context }) } -func (r *BunRolePermissionRepository) AddRolePermission(ctx context.Context, roleID string, permissionID string, grantedByUserID *string) error { +func (r *BunRolePermissionsRepository) AddRolePermission(ctx context.Context, roleID string, permissionID string, grantedByUserID *string) error { rp := &types.RolePermission{ RoleID: roleID, PermissionID: permissionID, @@ -230,14 +108,10 @@ func (r *BunRolePermissionRepository) AddRolePermission(ctx context.Context, rol } _, err := r.db.NewInsert().Model(rp).Exec(ctx) - if err != nil { - return fmt.Errorf("failed to add role permission: %w", err) - } - - return nil + return wrapRepositoryError("add role permission", err) } -func (r *BunRolePermissionRepository) RemoveRolePermission(ctx context.Context, roleID string, permissionID string) error { +func (r *BunRolePermissionsRepository) RemoveRolePermission(ctx context.Context, roleID string, permissionID string) error { _, err := r.db.NewDelete(). Model((*types.RolePermission)(nil)). Where("role_id = ?", roleID). @@ -249,56 +123,3 @@ func (r *BunRolePermissionRepository) RemoveRolePermission(ctx context.Context, return nil } - -func (r *BunRolePermissionRepository) ReplaceUserRoles(ctx context.Context, userID string, roleIDs []string, assignedByUserID *string) error { - return r.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { - if _, err := tx.NewDelete().Model((*types.UserRole)(nil)).Where("user_id = ?", userID).Exec(ctx); err != nil { - return fmt.Errorf("failed to clear user roles: %w", err) - } - - now := time.Now().UTC() - for _, roleID := range roleIDs { - ur := &types.UserRole{ - UserID: userID, - RoleID: roleID, - AssignedByUserID: assignedByUserID, - AssignedAt: now, - } - if _, err := tx.NewInsert().Model(ur).Exec(ctx); err != nil { - return fmt.Errorf("failed to insert user role: %w", err) - } - } - - return nil - }) -} - -func (r *BunRolePermissionRepository) AssignUserRole(ctx context.Context, userID string, roleID string, assignedByUserID *string, expiresAt *time.Time) error { - ur := &types.UserRole{ - UserID: userID, - RoleID: roleID, - AssignedByUserID: assignedByUserID, - AssignedAt: time.Now().UTC(), - ExpiresAt: expiresAt, - } - - _, err := r.db.NewInsert().Model(ur).Exec(ctx) - if err != nil { - return fmt.Errorf("failed to assign user role: %w", err) - } - - return nil -} - -func (r *BunRolePermissionRepository) RemoveUserRole(ctx context.Context, userID string, roleID string) error { - _, err := r.db.NewDelete(). - Model((*types.UserRole)(nil)). - Where("user_id = ?", userID). - Where("role_id = ?", roleID). - Exec(ctx) - if err != nil { - return fmt.Errorf("failed to remove user role: %w", err) - } - - return nil -} diff --git a/plugins/access-control/repositories/role_permission_repository_test.go b/plugins/access-control/repositories/role_permission_repository_test.go index 7ad5393d..3e9e2ab8 100644 --- a/plugins/access-control/repositories/role_permission_repository_test.go +++ b/plugins/access-control/repositories/role_permission_repository_test.go @@ -3,88 +3,295 @@ package repositories import ( "context" "testing" - "time" - internaltests "github.com/Authula/authula/internal/tests" + accesscontrolconstants "github.com/Authula/authula/plugins/access-control/constants" + plugintests "github.com/Authula/authula/plugins/access-control/tests" "github.com/Authula/authula/plugins/access-control/types" ) -func TestBunRolePermissionRepositoryRoleCRUDAndAssignmentCount(t *testing.T) { - db := setupRepoDB(t) - repo := NewBunRolePermissionRepository(db) - ctx := context.Background() +func TestBunRolePermissionsRepositoryGetRolePermissions(t *testing.T) { + t.Parallel() - role := &types.Role{ID: "r1", Name: "admin"} - if err := repo.CreateRole(ctx, role); err != nil { - t.Fatalf("failed to create role: %v", err) + tests := []struct { + name string + seed func(*BunRolesRepository, *BunPermissionsRepository, *BunRolePermissionsRepository, context.Context) + roleID string + wantIDs []string + wantKeys []string + wantNil bool + }{ + { + name: "empty result", + roleID: "role-missing", + wantNil: false, + wantIDs: []string{}, + wantKeys: []string{}, + }, + { + name: "success", + roleID: "role-1", + seed: func(rolesRepo *BunRolesRepository, permissionsRepo *BunPermissionsRepository, rolePermissionsRepo *BunRolePermissionsRepository, ctx context.Context) { + if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "role-1", Name: "editor"}); err != nil { + panic(err) + } + if err := permissionsRepo.CreatePermission(ctx, &types.Permission{ID: "perm-2", Key: "users.write", Description: new("Write users")}); err != nil { + panic(err) + } + if err := permissionsRepo.CreatePermission(ctx, &types.Permission{ID: "perm-1", Key: "users.read", Description: new("Read users")}); err != nil { + panic(err) + } + grantedBy := "u2" + if err := rolePermissionsRepo.AddRolePermission(ctx, "role-1", "perm-2", &grantedBy); err != nil { + panic(err) + } + if err := rolePermissionsRepo.AddRolePermission(ctx, "role-1", "perm-1", &grantedBy); err != nil { + panic(err) + } + }, + wantIDs: []string{"perm-1", "perm-2"}, + wantKeys: []string{"users.read", "users.write"}, + }, } - if err := repo.AssignUserRole(ctx, "u1", "r1", nil, nil); err != nil { - t.Fatalf("failed to assign user role: %v", err) - } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() - count, err := repo.CountUserAssignmentsByRoleID(ctx, "r1") - if err != nil { - t.Fatalf("failed to count assignments: %v", err) - } - if count != 1 { - t.Fatalf("expected 1 assignment, got %d", count) - } + db := plugintests.SetupRepoDB(t) + rolesRepo := NewBunRolesRepository(db) + permissionsRepo := NewBunPermissionsRepository(db) + rolePermissionsRepo := NewBunRolePermissionsRepository(db) + ctx := context.Background() - deleted, err := repo.DeleteRole(ctx, "r1") - if err != nil { - t.Fatalf("failed to delete role: %v", err) - } - if !deleted { - t.Fatal("expected role to be deleted") + if tc.seed != nil { + tc.seed(rolesRepo, permissionsRepo, rolePermissionsRepo, ctx) + } + + permissions, err := rolePermissionsRepo.GetRolePermissions(ctx, tc.roleID) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if permissions == nil { + t.Fatal("expected permissions slice, got nil") + } + if len(permissions) != len(tc.wantIDs) { + t.Fatalf("expected %d permissions, got %d", len(tc.wantIDs), len(permissions)) + } + for i := range tc.wantIDs { + if permissions[i].PermissionID != tc.wantIDs[i] || permissions[i].PermissionKey != tc.wantKeys[i] { + t.Fatalf("unexpected permission at %d: %#v", i, permissions[i]) + } + } + }) } } -func TestBunRolePermissionRepositoryReplaceRolePermissions(t *testing.T) { - db := setupRepoDB(t) - repo := NewBunRolePermissionRepository(db) - ctx := context.Background() +func TestBunRolePermissionsRepositoryAddRolePermission(t *testing.T) { + t.Parallel() - if err := repo.CreateRole(ctx, &types.Role{ID: "r1", Name: "editor"}); err != nil { - t.Fatalf("failed to create role: %v", err) - } - if err := repo.CreatePermission(ctx, &types.Permission{ID: "p1", Key: "posts.read"}); err != nil { - t.Fatalf("failed to create permission p1: %v", err) - } - if err := repo.CreatePermission(ctx, &types.Permission{ID: "p2", Key: "posts.write"}); err != nil { - t.Fatalf("failed to create permission p2: %v", err) + tests := []struct { + name string + setup func(*BunRolesRepository, *BunPermissionsRepository, *BunRolePermissionsRepository, context.Context) + roleID string + permissionID string + grantedByUserID *string + wantErr error + }{ + { + name: "success", + roleID: "role-1", + permissionID: "perm-1", + setup: func(rolesRepo *BunRolesRepository, permissionsRepo *BunPermissionsRepository, rolePermissionsRepo *BunRolePermissionsRepository, ctx context.Context) { + if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "role-1", Name: "editor"}); err != nil { + panic(err) + } + if err := permissionsRepo.CreatePermission(ctx, &types.Permission{ID: "perm-1", Key: "users.read"}); err != nil { + panic(err) + } + }, + }, + { + name: "duplicate grant returns conflict", + roleID: "role-1", + permissionID: "perm-1", + setup: func(rolesRepo *BunRolesRepository, permissionsRepo *BunPermissionsRepository, rolePermissionsRepo *BunRolePermissionsRepository, ctx context.Context) { + if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "role-1", Name: "editor"}); err != nil { + panic(err) + } + if err := permissionsRepo.CreatePermission(ctx, &types.Permission{ID: "perm-1", Key: "users.read"}); err != nil { + panic(err) + } + if err := rolePermissionsRepo.AddRolePermission(ctx, "role-1", "perm-1", nil); err != nil { + panic(err) + } + }, + wantErr: accesscontrolconstants.ErrConflict, + }, } - if err := repo.ReplaceRolePermissions(ctx, "r1", []string{"p1", "p2"}, nil); err != nil { - t.Fatalf("failed to replace role permissions: %v", err) - } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() - permissions, err := repo.GetRolePermissions(ctx, "r1") - if err != nil { - t.Fatalf("failed to get role permissions: %v", err) - } - if len(permissions) != 2 { - t.Fatalf("expected 2 permissions, got %d", len(permissions)) + db := plugintests.SetupRepoDB(t) + rolesRepo := NewBunRolesRepository(db) + permissionsRepo := NewBunPermissionsRepository(db) + rolePermissionsRepo := NewBunRolePermissionsRepository(db) + ctx := context.Background() + + if tc.setup != nil { + tc.setup(rolesRepo, permissionsRepo, rolePermissionsRepo, ctx) + } + + err := rolePermissionsRepo.AddRolePermission(ctx, tc.roleID, tc.permissionID, tc.grantedByUserID) + if err != tc.wantErr { + t.Fatalf("expected err %v, got %v", tc.wantErr, err) + } + + if tc.wantErr != nil { + return + } + + permissions, err := rolePermissionsRepo.GetRolePermissions(ctx, tc.roleID) + if err != nil { + t.Fatalf("failed to fetch role permissions: %v", err) + } + if len(permissions) != 1 || permissions[0].PermissionID != tc.permissionID { + t.Fatalf("unexpected permissions: %#v", permissions) + } + if permissions[0].GrantedAt == nil { + t.Fatal("expected granted_at to be populated") + } + }) } } -func TestBunRolePermissionRepositoryReplaceUserRoles(t *testing.T) { - db := setupRepoDB(t) - repo := NewBunRolePermissionRepository(db) - ctx := context.Background() +func TestBunRolePermissionsRepositoryReplaceRolePermissions(t *testing.T) { + t.Parallel() - if err := repo.CreateRole(ctx, &types.Role{ID: "r1", Name: "role-1"}); err != nil { - t.Fatalf("failed to create role r1: %v", err) + tests := []struct { + name string + seed func(*BunRolesRepository, *BunPermissionsRepository, *BunRolePermissionsRepository, context.Context) + roleID string + permissionIDs []string + wantIDs []string + }{ + { + name: "success", + roleID: "role-1", + permissionIDs: []string{"perm-2", "perm-1"}, + seed: func(rolesRepo *BunRolesRepository, permissionsRepo *BunPermissionsRepository, rolePermissionsRepo *BunRolePermissionsRepository, ctx context.Context) { + if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "role-1", Name: "editor"}); err != nil { + panic(err) + } + if err := permissionsRepo.CreatePermission(ctx, &types.Permission{ID: "perm-1", Key: "users.read"}); err != nil { + panic(err) + } + if err := permissionsRepo.CreatePermission(ctx, &types.Permission{ID: "perm-2", Key: "users.write"}); err != nil { + panic(err) + } + if err := rolePermissionsRepo.AddRolePermission(ctx, "role-1", "perm-2", nil); err != nil { + panic(err) + } + }, + wantIDs: []string{"perm-1", "perm-2"}, + }, } - if err := repo.CreateRole(ctx, &types.Role{ID: "r2", Name: "role-2"}); err != nil { - t.Fatalf("failed to create role r2: %v", err) + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + rolesRepo := NewBunRolesRepository(db) + permissionsRepo := NewBunPermissionsRepository(db) + rolePermissionsRepo := NewBunRolePermissionsRepository(db) + ctx := context.Background() + + if tc.seed != nil { + tc.seed(rolesRepo, permissionsRepo, rolePermissionsRepo, ctx) + } + + if err := rolePermissionsRepo.ReplaceRolePermissions(ctx, tc.roleID, tc.permissionIDs, nil); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + permissions, err := rolePermissionsRepo.GetRolePermissions(ctx, tc.roleID) + if err != nil { + t.Fatalf("failed to fetch role permissions: %v", err) + } + if len(permissions) != len(tc.wantIDs) { + t.Fatalf("expected %d permissions, got %d", len(tc.wantIDs), len(permissions)) + } + for i, wantID := range tc.wantIDs { + if permissions[i].PermissionID != wantID { + t.Fatalf("expected permission %s at index %d, got %#v", wantID, i, permissions[i]) + } + } + }) } +} + +func TestBunRolePermissionsRepositoryRemoveRolePermission(t *testing.T) { + t.Parallel() - if err := repo.ReplaceUserRoles(ctx, "u1", []string{"r1", "r2"}, nil); err != nil { - t.Fatalf("failed to replace user roles: %v", err) + tests := []struct { + name string + seed func(*BunRolesRepository, *BunPermissionsRepository, *BunRolePermissionsRepository, context.Context) + roleID string + permissionID string + wantExists bool + }{ + { + name: "success", + roleID: "role-1", + permissionID: "perm-1", + seed: func(rolesRepo *BunRolesRepository, permissionsRepo *BunPermissionsRepository, rolePermissionsRepo *BunRolePermissionsRepository, ctx context.Context) { + if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "role-1", Name: "editor"}); err != nil { + panic(err) + } + if err := permissionsRepo.CreatePermission(ctx, &types.Permission{ID: "perm-1", Key: "users.read"}); err != nil { + panic(err) + } + if err := rolePermissionsRepo.AddRolePermission(ctx, "role-1", "perm-1", nil); err != nil { + panic(err) + } + }, + wantExists: false, + }, } - if err := repo.AssignUserRole(ctx, "u1", "r1", nil, internaltests.PtrTime(time.Now().UTC().Add(1*time.Hour))); err == nil { - t.Fatal("expected duplicate assignment to fail due primary key") + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + rolesRepo := NewBunRolesRepository(db) + permissionsRepo := NewBunPermissionsRepository(db) + rolePermissionsRepo := NewBunRolePermissionsRepository(db) + ctx := context.Background() + + if tc.seed != nil { + tc.seed(rolesRepo, permissionsRepo, rolePermissionsRepo, ctx) + } + + if err := rolePermissionsRepo.RemoveRolePermission(ctx, tc.roleID, tc.permissionID); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + permissions, err := rolePermissionsRepo.GetRolePermissions(ctx, tc.roleID) + if err != nil { + t.Fatalf("failed to fetch role permissions: %v", err) + } + if tc.wantExists { + if len(permissions) == 0 { + t.Fatal("expected permission to remain") + } + return + } + if len(permissions) != 0 { + t.Fatalf("expected no permissions after remove, got %#v", permissions) + } + }) } } diff --git a/plugins/access-control/repositories/roles_repository.go b/plugins/access-control/repositories/roles_repository.go new file mode 100644 index 00000000..718c5d81 --- /dev/null +++ b/plugins/access-control/repositories/roles_repository.go @@ -0,0 +1,88 @@ +package repositories + +import ( + "context" + "database/sql" + "fmt" + "time" + + "github.com/uptrace/bun" + + "github.com/Authula/authula/plugins/access-control/types" +) + +type BunRolesRepository struct { + db bun.IDB +} + +func NewBunRolesRepository(db bun.IDB) *BunRolesRepository { + return &BunRolesRepository{db: db} +} + +func (r *BunRolesRepository) CreateRole(ctx context.Context, role *types.Role) error { + _, err := r.db.NewInsert().Model(role).Exec(ctx) + return wrapRepositoryError("create role", err) +} + +func (r *BunRolesRepository) GetAllRoles(ctx context.Context) ([]types.Role, error) { + roles := make([]types.Role, 0) + err := r.db.NewSelect().Model(&roles).Order("created_at ASC").Scan(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get roles: %w", err) + } + return roles, nil +} + +func (r *BunRolesRepository) GetRoleByID(ctx context.Context, roleID string) (*types.Role, error) { + role := new(types.Role) + err := r.db.NewSelect().Model(role).Where("id = ?", roleID).Scan(ctx) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, fmt.Errorf("failed to get role by id: %w", err) + } + + return role, nil +} + +func (r *BunRolesRepository) UpdateRole(ctx context.Context, roleID string, name *string, description *string) (bool, error) { + query := r.db.NewUpdate(). + Model((*types.Role)(nil)). + Set("updated_at = ?", time.Now().UTC()). + Where("id = ?", roleID) + + if name != nil { + query = query.Set("name = ?", *name) + } + + if description != nil { + query = query.Set("description = ?", *description) + } + + result, err := query.Exec(ctx) + if err != nil { + return false, wrapRepositoryError("update role", err) + } + + affected, err := result.RowsAffected() + if err != nil { + return false, fmt.Errorf("failed to determine updated rows: %w", err) + } + + return affected > 0, nil +} + +func (r *BunRolesRepository) DeleteRole(ctx context.Context, roleID string) (bool, error) { + result, err := r.db.NewDelete().Model((*types.Role)(nil)).Where("id = ?", roleID).Exec(ctx) + if err != nil { + return false, fmt.Errorf("failed to delete role: %w", err) + } + + affected, err := result.RowsAffected() + if err != nil { + return false, fmt.Errorf("failed to determine deleted rows: %w", err) + } + + return affected > 0, nil +} diff --git a/plugins/access-control/repositories/roles_repository_test.go b/plugins/access-control/repositories/roles_repository_test.go new file mode 100644 index 00000000..13379f73 --- /dev/null +++ b/plugins/access-control/repositories/roles_repository_test.go @@ -0,0 +1,388 @@ +package repositories + +import ( + "context" + "testing" + + accesscontrolconstants "github.com/Authula/authula/plugins/access-control/constants" + plugintests "github.com/Authula/authula/plugins/access-control/tests" + "github.com/Authula/authula/plugins/access-control/types" +) + +func TestBunRolesRepositoryCreateRole(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + role *types.Role + seed *types.Role + wantErr error + wantID string + wantName string + wantDesc *string + }{ + { + name: "success", + role: &types.Role{ID: "r1", Name: "editor", Description: new("Editor role"), IsSystem: false}, + wantID: "r1", + wantName: "editor", + wantDesc: new("Editor role"), + }, + { + name: "duplicate name returns conflict", + role: &types.Role{ID: "r2", Name: "editor", Description: new("Duplicate role"), IsSystem: false}, + seed: &types.Role{ID: "r1", Name: "editor", Description: new("Original role"), IsSystem: false}, + wantErr: accesscontrolconstants.ErrConflict, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + repo := NewBunRolesRepository(db) + ctx := context.Background() + + if tc.seed != nil { + if err := repo.CreateRole(ctx, tc.seed); err != nil { + t.Fatalf("failed to seed role: %v", err) + } + } + + err := repo.CreateRole(ctx, tc.role) + if err != tc.wantErr { + t.Fatalf("expected err %v, got %v", tc.wantErr, err) + } + + if tc.wantErr != nil { + return + } + + stored, err := repo.GetRoleByID(ctx, tc.wantID) + if err != nil { + t.Fatalf("failed to fetch stored role: %v", err) + } + if stored == nil { + t.Fatal("expected stored role, got nil") + } + if stored.ID != tc.wantID || stored.Name != tc.wantName { + t.Fatalf("unexpected stored role: %#v", stored) + } + if !stringPtrEqual(stored.Description, tc.wantDesc) { + t.Fatalf("expected description %#v, got %#v", tc.wantDesc, stored.Description) + } + if stored.CreatedAt.IsZero() || stored.UpdatedAt.IsZero() { + t.Fatal("expected timestamps to be populated") + } + }) + } +} + +func TestBunRolesRepositoryGetAllRoles(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + seedRoles []*types.Role + wantIDs []string + wantNames []string + wantDescs []*string + }{ + { + name: "empty result", + wantIDs: []string{}, + wantNames: []string{}, + wantDescs: []*string{}, + }, + { + name: "returns roles ordered by creation time", + seedRoles: []*types.Role{ + {ID: "r2", Name: "viewer", Description: new("Viewer role")}, + {ID: "r1", Name: "editor", Description: new("Editor role")}, + }, + wantIDs: []string{"r2", "r1"}, + wantNames: []string{"viewer", "editor"}, + wantDescs: []*string{new("Viewer role"), new("Editor role")}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + repo := NewBunRolesRepository(db) + ctx := context.Background() + + for _, role := range tc.seedRoles { + if err := repo.CreateRole(ctx, role); err != nil { + t.Fatalf("failed to seed role %s: %v", role.ID, err) + } + } + + roles, err := repo.GetAllRoles(ctx) + if err != nil { + t.Fatalf("failed to get roles: %v", err) + } + if roles == nil { + t.Fatal("expected roles slice, got nil") + } + if len(roles) != len(tc.wantIDs) { + t.Fatalf("expected %d roles, got %d", len(tc.wantIDs), len(roles)) + } + for i := range tc.wantIDs { + if roles[i].ID != tc.wantIDs[i] || roles[i].Name != tc.wantNames[i] { + t.Fatalf("unexpected role at %d: %#v", i, roles[i]) + } + if !stringPtrEqual(roles[i].Description, tc.wantDescs[i]) { + t.Fatalf("unexpected description at %d: %#v", i, roles[i]) + } + if roles[i].CreatedAt.IsZero() || roles[i].UpdatedAt.IsZero() { + t.Fatalf("expected timestamps to be populated at %d", i) + } + } + }) + } +} + +func TestBunRolesRepositoryGetRoleByID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + roleID string + seedRole *types.Role + wantNil bool + wantName string + wantDesc *string + wantSystem bool + }{ + { + name: "not found", + roleID: "missing", + wantNil: true, + }, + { + name: "success", + roleID: "r1", + seedRole: &types.Role{ID: "r1", Name: "editor", Description: new("Editor role"), IsSystem: true}, + wantName: "editor", + wantDesc: new("Editor role"), + wantSystem: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + repo := NewBunRolesRepository(db) + ctx := context.Background() + + if tc.seedRole != nil { + if err := repo.CreateRole(ctx, tc.seedRole); err != nil { + t.Fatalf("failed to seed role: %v", err) + } + } + + role, err := repo.GetRoleByID(ctx, tc.roleID) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tc.wantNil { + if role != nil { + t.Fatalf("expected nil role, got %#v", role) + } + return + } + if role == nil { + t.Fatal("expected role, got nil") + } + if role.ID != tc.roleID || role.Name != tc.wantName || role.IsSystem != tc.wantSystem { + t.Fatalf("unexpected role: %#v", role) + } + if !stringPtrEqual(role.Description, tc.wantDesc) { + t.Fatalf("expected description %#v, got %#v", tc.wantDesc, role.Description) + } + }) + } +} + +func TestBunRolesRepositoryUpdateRole(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + seedRole *types.Role + roleID string + nameValue *string + description *string + wantUpdated bool + wantName *string + wantDesc *string + }{ + { + name: "missing role", + roleID: "missing", + nameValue: new("updated"), + description: new("updated description"), + wantUpdated: false, + }, + { + name: "update name only", + seedRole: &types.Role{ID: "r1", Name: "editor", Description: new("Editor role"), IsSystem: false}, + roleID: "r1", + nameValue: new("editor-updated"), + wantUpdated: true, + wantName: new("editor-updated"), + wantDesc: new("Editor role"), + }, + { + name: "update description only", + seedRole: &types.Role{ID: "r2", Name: "viewer", Description: new("Viewer role"), IsSystem: false}, + roleID: "r2", + description: new("Viewer role updated"), + wantUpdated: true, + wantName: new("viewer"), + wantDesc: new("Viewer role updated"), + }, + { + name: "update name and description", + seedRole: &types.Role{ID: "r3", Name: "author", Description: new("Author role"), IsSystem: false}, + roleID: "r3", + nameValue: new("author-updated"), + description: new("Author role updated"), + wantUpdated: true, + wantName: new("author-updated"), + wantDesc: new("Author role updated"), + }, + { + name: "update with no fields still updates timestamp", + seedRole: &types.Role{ID: "r4", Name: "reviewer", Description: new("Reviewer role"), IsSystem: false}, + roleID: "r4", + wantUpdated: true, + wantName: new("reviewer"), + wantDesc: new("Reviewer role"), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + repo := NewBunRolesRepository(db) + ctx := context.Background() + + if tc.seedRole != nil { + if err := repo.CreateRole(ctx, tc.seedRole); err != nil { + t.Fatalf("failed to seed role: %v", err) + } + } + + updated, err := repo.UpdateRole(ctx, tc.roleID, tc.nameValue, tc.description) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if updated != tc.wantUpdated { + t.Fatalf("expected updated=%v, got %v", tc.wantUpdated, updated) + } + + if !tc.wantUpdated { + return + } + + role, err := repo.GetRoleByID(ctx, tc.roleID) + if err != nil { + t.Fatalf("failed to fetch updated role: %v", err) + } + if role == nil { + t.Fatal("expected updated role, got nil") + } + if role.Name != derefOrEmpty(tc.wantName) { + t.Fatalf("expected name %q, got %q", derefOrEmpty(tc.wantName), role.Name) + } + if !stringPtrEqual(role.Description, tc.wantDesc) { + t.Fatalf("expected description %#v, got %#v", tc.wantDesc, role.Description) + } + if role.UpdatedAt.IsZero() { + t.Fatal("expected updated_at to be populated") + } + }) + } +} + +func TestBunRolesRepositoryDeleteRole(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + seedRole *types.Role + roleID string + wantDeleted bool + }{ + { + name: "missing role", + roleID: "missing", + wantDeleted: false, + }, + { + name: "success", + seedRole: &types.Role{ID: "r1", Name: "editor", Description: new("Editor role"), IsSystem: false}, + roleID: "r1", + wantDeleted: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + repo := NewBunRolesRepository(db) + ctx := context.Background() + + if tc.seedRole != nil { + if err := repo.CreateRole(ctx, tc.seedRole); err != nil { + t.Fatalf("failed to seed role: %v", err) + } + } + + deleted, err := repo.DeleteRole(ctx, tc.roleID) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if deleted != tc.wantDeleted { + t.Fatalf("expected deleted=%v, got %v", tc.wantDeleted, deleted) + } + + if !tc.wantDeleted { + return + } + + role, err := repo.GetRoleByID(ctx, tc.roleID) + if err != nil { + t.Fatalf("failed to verify role deletion: %v", err) + } + if role != nil { + t.Fatalf("expected deleted role to be absent, got %#v", role) + } + }) + } +} + +func stringPtrEqual(left *string, right *string) bool { + if left == nil || right == nil { + return left == right + } + return *left == *right +} + +func derefOrEmpty(value *string) string { + if value == nil { + return "" + } + return *value +} diff --git a/plugins/access-control/repositories/user_access_repository.go b/plugins/access-control/repositories/user_access_repository.go index bd1b9482..ccf0cbbb 100644 --- a/plugins/access-control/repositories/user_access_repository.go +++ b/plugins/access-control/repositories/user_access_repository.go @@ -8,7 +8,6 @@ import ( "github.com/uptrace/bun" - "github.com/Authula/authula/models" "github.com/Authula/authula/plugins/access-control/types" ) @@ -20,27 +19,28 @@ func NewBunUserAccessRepository(db bun.IDB) *BunUserAccessRepository { return &BunUserAccessRepository{db: db} } -func (r *BunUserAccessRepository) GetUserRoles(ctx context.Context, userID string) ([]types.UserRoleInfo, error) { - var rows []types.UserRoleInfo - err := r.db.NewSelect(). - TableExpr("access_control_user_roles acur"). - ColumnExpr("acur.role_id AS role_id"). - ColumnExpr("acr.name AS role_name"). - ColumnExpr("acr.description AS role_description"). - ColumnExpr("acur.assigned_by_user_id AS assigned_by_user_id"). - ColumnExpr("acur.assigned_at AS assigned_at"). - ColumnExpr("acur.expires_at AS expires_at"). - Join("JOIN access_control_roles acr ON acr.id = acur.role_id"). - Where("acur.user_id = ?", userID). - OrderExpr("acr.name ASC, acur.assigned_at DESC"). - Scan(ctx, &rows) +func (r *BunUserAccessRepository) CountUserAssignmentsByRoleID(ctx context.Context, roleID string) (int, error) { + count, err := r.db.NewSelect(). + Model((*types.UserRole)(nil)). + Where("role_id = ?", roleID). + Count(ctx) if err != nil { - return nil, fmt.Errorf("failed to get user roles: %w", err) + return 0, fmt.Errorf("failed to count role user assignments: %w", err) } - if rows == nil { - return []types.UserRoleInfo{}, nil + + return count, nil +} + +func (r *BunUserAccessRepository) CountRoleAssignmentsByPermissionID(ctx context.Context, permissionID string) (int, error) { + count, err := r.db.NewSelect(). + Model((*types.RolePermission)(nil)). + Where("permission_id = ?", permissionID). + Count(ctx) + if err != nil { + return 0, fmt.Errorf("failed to count permission role assignments: %w", err) } - return rows, nil + + return count, nil } type userEffectivePermissionRow struct { @@ -109,72 +109,6 @@ func (r *BunUserAccessRepository) GetUserEffectivePermissions(ctx context.Contex return permissions, nil } -type userWithRoleRow struct { - UserID string `bun:"user_id"` - UserName string `bun:"user_name"` - UserEmail string `bun:"user_email"` - EmailVerified bool `bun:"email_verified"` - Image *string - Metadata []byte - CreatedAt time.Time `bun:"created_at"` - UpdatedAt time.Time `bun:"updated_at"` - RoleID *string `bun:"role_id"` - RoleName *string `bun:"role_name"` -} - -func (r *BunUserAccessRepository) GetUserWithRolesByID(ctx context.Context, userID string) (*types.UserWithRoles, error) { - var rows []userWithRoleRow - now := time.Now().UTC() - err := r.db.NewSelect(). - TableExpr("users u"). - ColumnExpr("u.id AS user_id"). - ColumnExpr("u.name AS user_name"). - ColumnExpr("u.email AS user_email"). - ColumnExpr("u.email_verified AS email_verified"). - ColumnExpr("u.image AS image"). - ColumnExpr("u.metadata AS metadata"). - ColumnExpr("u.created_at AS created_at"). - ColumnExpr("u.updated_at AS updated_at"). - ColumnExpr("pr.id AS role_id"). - ColumnExpr("pr.name AS role_name"). - Join("LEFT JOIN access_control_user_roles pur ON pur.user_id = u.id AND (pur.expires_at IS NULL OR pur.expires_at > ?)", now). - Join("LEFT JOIN access_control_roles pr ON pr.id = pur.role_id"). - Where("u.id = ?", userID). - OrderExpr("pr.name ASC"). - Scan(ctx, &rows) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil - } - return nil, fmt.Errorf("failed to get user with roles: %w", err) - } - if len(rows) == 0 { - return nil, nil - } - - result := &types.UserWithRoles{ - User: mapRowToUser(rows[0]), - } - - seen := make(map[string]struct{}) - for _, row := range rows { - if row.RoleID == nil || *row.RoleID == "" { - continue - } - if _, ok := seen[*row.RoleID]; ok { - continue - } - seen[*row.RoleID] = struct{}{} - roleName := "" - if row.RoleName != nil { - roleName = *row.RoleName - } - result.Roles = append(result.Roles, types.UserRoleInfo{RoleID: *row.RoleID, RoleName: roleName}) - } - - return result, nil -} - type userWithPermissionRow struct { UserID string `bun:"user_id"` UserName string `bun:"user_name"` @@ -242,38 +176,6 @@ func (r *BunUserAccessRepository) GetUserWithPermissionsByID(ctx context.Context return result, nil } -type userRow interface { - GetUserID() string - GetUserName() string - GetUserEmail() string - GetEmailVerified() bool - GetImage() *string - GetMetadata() []byte - GetCreatedAt() time.Time - GetUpdatedAt() time.Time -} - -func mapRowToUser(row userRow) models.User { - return models.User{ - ID: row.GetUserID(), - Name: row.GetUserName(), - Email: row.GetUserEmail(), - EmailVerified: row.GetEmailVerified(), - Image: row.GetImage(), - Metadata: row.GetMetadata(), - CreatedAt: row.GetCreatedAt(), - UpdatedAt: row.GetUpdatedAt(), - } -} - -func (r userWithRoleRow) GetUserID() string { return r.UserID } -func (r userWithRoleRow) GetUserName() string { return r.UserName } -func (r userWithRoleRow) GetUserEmail() string { return r.UserEmail } -func (r userWithRoleRow) GetEmailVerified() bool { return r.EmailVerified } -func (r userWithRoleRow) GetImage() *string { return r.Image } -func (r userWithRoleRow) GetMetadata() []byte { return r.Metadata } -func (r userWithRoleRow) GetCreatedAt() time.Time { return r.CreatedAt } -func (r userWithRoleRow) GetUpdatedAt() time.Time { return r.UpdatedAt } func (r userWithPermissionRow) GetUserID() string { return r.UserID } func (r userWithPermissionRow) GetUserName() string { return r.UserName } func (r userWithPermissionRow) GetUserEmail() string { return r.UserEmail } diff --git a/plugins/access-control/repositories/user_access_repository_test.go b/plugins/access-control/repositories/user_access_repository_test.go index 57af8039..b7f9dbde 100644 --- a/plugins/access-control/repositories/user_access_repository_test.go +++ b/plugins/access-control/repositories/user_access_repository_test.go @@ -3,145 +3,270 @@ package repositories import ( "context" "testing" - "testing/synctest" "time" - internaltests "github.com/Authula/authula/internal/tests" + plugintests "github.com/Authula/authula/plugins/access-control/tests" "github.com/Authula/authula/plugins/access-control/types" ) -func TestBunUserAccessRepositoryGetUserRolesIncludesExpiredWithMetadata(t *testing.T) { - db := setupRepoDB(t) - rpRepo := NewBunRolePermissionRepository(db) - uaRepo := NewBunUserAccessRepository(db) - ctx := context.Background() +func TestBunUserAccessRepositoryCounts(t *testing.T) { + t.Parallel() - if err := rpRepo.CreateRole(ctx, &types.Role{ID: "r-active", Name: "active"}); err != nil { - t.Fatalf("failed to create active role: %v", err) - } - if err := rpRepo.CreateRole(ctx, &types.Role{ID: "r-expired", Name: "expired"}); err != nil { - t.Fatalf("failed to create expired role: %v", err) + tests := []struct { + name string + seed func(*BunRolesRepository, *BunPermissionsRepository, *BunRolePermissionsRepository, *BunUserRolesRepository, context.Context) + roleID string + permissionID string + wantRoleCount int + wantPermissionCount int + }{ + { + name: "empty counts", + roleID: "role-missing", + permissionID: "perm-missing", + wantRoleCount: 0, + wantPermissionCount: 0, + }, + { + name: "counts assigned records", + roleID: "role-1", + permissionID: "perm-1", + seed: func(rolesRepo *BunRolesRepository, permissionsRepo *BunPermissionsRepository, rolePermissionsRepo *BunRolePermissionsRepository, userRolesRepo *BunUserRolesRepository, ctx context.Context) { + if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "role-1", Name: "editor"}); err != nil { + panic(err) + } + if err := permissionsRepo.CreatePermission(ctx, &types.Permission{ID: "perm-1", Key: "users.read"}); err != nil { + panic(err) + } + if err := rolePermissionsRepo.AddRolePermission(ctx, "role-1", "perm-1", nil); err != nil { + panic(err) + } + if err := userRolesRepo.AssignUserRole(ctx, "u1", "role-1", nil, nil); err != nil { + panic(err) + } + }, + wantRoleCount: 1, + wantPermissionCount: 1, + }, } - assignerID := "u2" - if err := rpRepo.AssignUserRole(ctx, "u1", "r-active", &assignerID, nil); err != nil { - t.Fatalf("failed to assign active role: %v", err) - } - if err := rpRepo.AssignUserRole(ctx, "u1", "r-expired", &assignerID, internaltests.PtrTime(time.Now().UTC().Add(-1*time.Hour))); err != nil { - t.Fatalf("failed to assign expired role: %v", err) - } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() - roles, err := uaRepo.GetUserRoles(ctx, "u1") - if err != nil { - t.Fatalf("failed to get user roles: %v", err) - } - if len(roles) != 2 { - t.Fatalf("expected 2 roles including expired, got %d", len(roles)) - } - if roles[0].RoleID != "r-active" { - t.Fatalf("expected role sorted by name, got %s", roles[0].RoleID) - } - if roles[0].AssignedByUserID == nil || *roles[0].AssignedByUserID != assignerID { - t.Fatalf("expected assigned_by_user_id=%s", assignerID) - } - if roles[0].AssignedAt == nil { - t.Fatal("expected assigned_at to be populated") - } - if roles[1].RoleID != "r-expired" { - t.Fatalf("expected expired role included, got %s", roles[1].RoleID) - } - if roles[1].ExpiresAt == nil { - t.Fatal("expected expired role to include expires_at") + db := plugintests.SetupRepoDB(t) + rolesRepo := NewBunRolesRepository(db) + permissionsRepo := NewBunPermissionsRepository(db) + rolePermissionsRepo := NewBunRolePermissionsRepository(db) + userRolesRepo := NewBunUserRolesRepository(db) + repo := NewBunUserAccessRepository(db) + ctx := context.Background() + + if tc.seed != nil { + tc.seed(rolesRepo, permissionsRepo, rolePermissionsRepo, userRolesRepo, ctx) + } + + roleCount, err := repo.CountUserAssignmentsByRoleID(ctx, tc.roleID) + if err != nil { + t.Fatalf("failed to count role assignments: %v", err) + } + permissionCount, err := repo.CountRoleAssignmentsByPermissionID(ctx, tc.permissionID) + if err != nil { + t.Fatalf("failed to count permission assignments: %v", err) + } + + if roleCount != tc.wantRoleCount { + t.Fatalf("expected role count %d, got %d", tc.wantRoleCount, roleCount) + } + if permissionCount != tc.wantPermissionCount { + t.Fatalf("expected permission count %d, got %d", tc.wantPermissionCount, permissionCount) + } + }) } } -func TestBunUserAccessRepositoryGetUserRolesReturnsEmptyArrayWhenNoRoles(t *testing.T) { - db := setupRepoDB(t) - uaRepo := NewBunUserAccessRepository(db) +func TestBunUserAccessRepositoryGetUserEffectivePermissions(t *testing.T) { + t.Parallel() + + activeUntil := time.Now().UTC().Add(1 * time.Hour) + expiredAt := time.Now().UTC().Add(-1 * time.Hour) + description := new("Read users") + grantedBy := "u2" - roles, err := uaRepo.GetUserRoles(context.Background(), "missing-user") - if err != nil { - t.Fatalf("failed to get user roles: %v", err) - } - if roles == nil { - t.Fatal("expected empty roles slice, got nil") + tests := []struct { + name string + seed func(*BunRolesRepository, *BunPermissionsRepository, *BunRolePermissionsRepository, *BunUserRolesRepository, context.Context) + userID string + wantKeys []string + wantSources int + }{ + { + name: "empty result", + userID: "missing-user", + wantKeys: []string{}, + }, + { + name: "aggregates active permissions and ignores expired roles", + userID: "u1", + seed: func(rolesRepo *BunRolesRepository, permissionsRepo *BunPermissionsRepository, rolePermissionsRepo *BunRolePermissionsRepository, userRolesRepo *BunUserRolesRepository, ctx context.Context) { + if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "r1", Name: "editor"}); err != nil { + panic(err) + } + if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "r2", Name: "viewer"}); err != nil { + panic(err) + } + if err := permissionsRepo.CreatePermission(ctx, &types.Permission{ID: "p1", Key: "posts.read", Description: description}); err != nil { + panic(err) + } + if err := rolePermissionsRepo.AddRolePermission(ctx, "r1", "p1", &grantedBy); err != nil { + panic(err) + } + if err := rolePermissionsRepo.AddRolePermission(ctx, "r2", "p1", &grantedBy); err != nil { + panic(err) + } + if err := userRolesRepo.AssignUserRole(ctx, "u1", "r1", nil, &activeUntil); err != nil { + panic(err) + } + if err := userRolesRepo.AssignUserRole(ctx, "u1", "r2", nil, &expiredAt); err != nil { + panic(err) + } + }, + wantKeys: []string{"posts.read"}, + wantSources: 1, + }, } - if len(roles) != 0 { - t.Fatalf("expected 0 roles, got %d", len(roles)) + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + rolesRepo := NewBunRolesRepository(db) + permissionsRepo := NewBunPermissionsRepository(db) + rolePermissionsRepo := NewBunRolePermissionsRepository(db) + userRolesRepo := NewBunUserRolesRepository(db) + repo := NewBunUserAccessRepository(db) + ctx := context.Background() + + if tc.seed != nil { + tc.seed(rolesRepo, permissionsRepo, rolePermissionsRepo, userRolesRepo, ctx) + } + + permissions, err := repo.GetUserEffectivePermissions(ctx, tc.userID) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if permissions == nil { + t.Fatal("expected permissions slice, got nil") + } + if len(permissions) != len(tc.wantKeys) { + t.Fatalf("expected %d permissions, got %d", len(tc.wantKeys), len(permissions)) + } + for i, wantKey := range tc.wantKeys { + if permissions[i].PermissionKey != wantKey { + t.Fatalf("expected permission key %s at index %d, got %#v", wantKey, i, permissions[i]) + } + if len(permissions[i].Sources) != tc.wantSources { + t.Fatalf("expected %d sources, got %d", tc.wantSources, len(permissions[i].Sources)) + } + if permissions[i].PermissionDescription == nil || *permissions[i].PermissionDescription != "Read users" { + t.Fatalf("expected permission description to be populated, got %#v", permissions[i].PermissionDescription) + } + } + }) } } -func TestBunUserAccessRepositoryGetUserEffectivePermissions(t *testing.T) { - synctest.Test(t, func(t *testing.T) { - db := setupRepoDB(t) - rpRepo := NewBunRolePermissionRepository(db) - uaRepo := NewBunUserAccessRepository(db) - ctx := context.Background() - - if err := rpRepo.CreateRole(ctx, &types.Role{ID: "r1", Name: "editor"}); err != nil { - t.Fatalf("failed to create role: %v", err) - } - if err := rpRepo.CreateRole(ctx, &types.Role{ID: "r2", Name: "viewer"}); err != nil { - t.Fatalf("failed to create second role: %v", err) - } - if err := rpRepo.CreatePermission(ctx, &types.Permission{ID: "p1", Key: "posts.read", Description: internaltests.PtrString("Read posts")}); err != nil { - t.Fatalf("failed to create permission: %v", err) - } - grantedBy := "u2" - if err := rpRepo.AddRolePermission(ctx, "r1", "p1", &grantedBy); err != nil { - t.Fatalf("failed to add role permission: %v", err) - } - time.Sleep(10 * time.Millisecond) - if err := rpRepo.AddRolePermission(ctx, "r2", "p1", &grantedBy); err != nil { - t.Fatalf("failed to add second role permission: %v", err) - } - if err := rpRepo.AssignUserRole(ctx, "u1", "r1", nil, nil); err != nil { - t.Fatalf("failed to assign role: %v", err) - } - if err := rpRepo.AssignUserRole(ctx, "u1", "r2", nil, nil); err != nil { - t.Fatalf("failed to assign second role: %v", err) - } - - perms, err := uaRepo.GetUserEffectivePermissions(ctx, "u1") - if err != nil { - t.Fatalf("failed to get effective permissions: %v", err) - } - if len(perms) != 1 { - t.Fatalf("expected 1 deduplicated permission, got %d", len(perms)) - } - if perms[0].PermissionKey != "posts.read" { - t.Fatalf("expected posts.read, got %s", perms[0].PermissionKey) - } - if perms[0].PermissionDescription == nil || *perms[0].PermissionDescription != "Read posts" { - t.Fatal("expected permission description to be populated") - } - if len(perms[0].Sources) != 2 { - t.Fatalf("expected 2 permission sources, got %d", len(perms[0].Sources)) - } - if perms[0].Sources[0].RoleName != "editor" || perms[0].Sources[1].RoleName != "viewer" { - t.Fatalf("expected deterministic source ordering by role_name, got %s then %s", perms[0].Sources[0].RoleName, perms[0].Sources[1].RoleName) - } - if perms[0].Sources[0].GrantedByUserID == nil || *perms[0].Sources[0].GrantedByUserID != grantedBy { - t.Fatal("expected source granted_by_user_id to be populated") - } - if perms[0].Sources[0].GrantedAt == nil || perms[0].Sources[1].GrantedAt == nil { - t.Fatal("expected source granted_at timestamps to be populated") - } - }) -} +func TestBunUserAccessRepositoryGetUserWithPermissionsByID(t *testing.T) { + t.Parallel() -func TestBunUserAccessRepositoryGetUserEffectivePermissionsReturnsEmptyArrayWhenNoPermissions(t *testing.T) { - db := setupRepoDB(t) - uaRepo := NewBunUserAccessRepository(db) + permissionDescription := new("Read users") - perms, err := uaRepo.GetUserEffectivePermissions(context.Background(), "missing-user") - if err != nil { - t.Fatalf("failed to get effective permissions: %v", err) - } - if perms == nil { - t.Fatal("expected empty permissions slice, got nil") + tests := []struct { + name string + seed func(*BunRolesRepository, *BunPermissionsRepository, *BunRolePermissionsRepository, *BunUserRolesRepository, context.Context) + userID string + wantNil bool + wantPermissionKeys []string + }{ + { + name: "not found", + userID: "missing-user", + wantNil: true, + }, + { + name: "success", + userID: "u1", + seed: func(rolesRepo *BunRolesRepository, permissionsRepo *BunPermissionsRepository, rolePermissionsRepo *BunRolePermissionsRepository, userRolesRepo *BunUserRolesRepository, ctx context.Context) { + if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "r1", Name: "editor"}); err != nil { + panic(err) + } + if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "r2", Name: "viewer"}); err != nil { + panic(err) + } + if err := permissionsRepo.CreatePermission(ctx, &types.Permission{ID: "p1", Key: "posts.read", Description: permissionDescription}); err != nil { + panic(err) + } + if err := permissionsRepo.CreatePermission(ctx, &types.Permission{ID: "p2", Key: "posts.write"}); err != nil { + panic(err) + } + if err := rolePermissionsRepo.AddRolePermission(ctx, "r1", "p1", nil); err != nil { + panic(err) + } + if err := rolePermissionsRepo.AddRolePermission(ctx, "r2", "p2", nil); err != nil { + panic(err) + } + if err := userRolesRepo.AssignUserRole(ctx, "u1", "r1", nil, nil); err != nil { + panic(err) + } + if err := userRolesRepo.AssignUserRole(ctx, "u1", "r2", nil, nil); err != nil { + panic(err) + } + }, + wantPermissionKeys: []string{"posts.read", "posts.write"}, + }, } - if len(perms) != 0 { - t.Fatalf("expected 0 permissions, got %d", len(perms)) + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + rolesRepo := NewBunRolesRepository(db) + permissionsRepo := NewBunPermissionsRepository(db) + rolePermissionsRepo := NewBunRolePermissionsRepository(db) + userRolesRepo := NewBunUserRolesRepository(db) + repo := NewBunUserAccessRepository(db) + ctx := context.Background() + + if tc.seed != nil { + tc.seed(rolesRepo, permissionsRepo, rolePermissionsRepo, userRolesRepo, ctx) + } + + userWithPermissions, err := repo.GetUserWithPermissionsByID(ctx, tc.userID) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tc.wantNil { + if userWithPermissions != nil { + t.Fatalf("expected nil user, got %#v", userWithPermissions) + } + return + } + if userWithPermissions == nil { + t.Fatal("expected user, got nil") + } + if userWithPermissions.User.ID != tc.userID { + t.Fatalf("expected user ID %s, got %s", tc.userID, userWithPermissions.User.ID) + } + if len(userWithPermissions.Permissions) != len(tc.wantPermissionKeys) { + t.Fatalf("expected %d permissions, got %d", len(tc.wantPermissionKeys), len(userWithPermissions.Permissions)) + } + for i, wantKey := range tc.wantPermissionKeys { + if userWithPermissions.Permissions[i].PermissionKey != wantKey { + t.Fatalf("expected permission key %s at index %d, got %#v", wantKey, i, userWithPermissions.Permissions[i]) + } + } + }) } } diff --git a/plugins/access-control/repositories/user_roles_repository.go b/plugins/access-control/repositories/user_roles_repository.go new file mode 100644 index 00000000..3dc6af51 --- /dev/null +++ b/plugins/access-control/repositories/user_roles_repository.go @@ -0,0 +1,189 @@ +package repositories + +import ( + "context" + "database/sql" + "fmt" + "time" + + "github.com/uptrace/bun" + + "github.com/Authula/authula/models" + "github.com/Authula/authula/plugins/access-control/types" +) + +type BunUserRolesRepository struct { + db bun.IDB +} + +func NewBunUserRolesRepository(db bun.IDB) *BunUserRolesRepository { + return &BunUserRolesRepository{db: db} +} + +func (r *BunUserRolesRepository) GetUserRoles(ctx context.Context, userID string) ([]types.UserRoleInfo, error) { + var rows []types.UserRoleInfo + err := r.db.NewSelect(). + TableExpr("access_control_user_roles acur"). + ColumnExpr("acur.role_id AS role_id"). + ColumnExpr("acr.name AS role_name"). + ColumnExpr("acr.description AS role_description"). + ColumnExpr("acur.assigned_by_user_id AS assigned_by_user_id"). + ColumnExpr("acur.assigned_at AS assigned_at"). + ColumnExpr("acur.expires_at AS expires_at"). + Join("JOIN access_control_roles acr ON acr.id = acur.role_id"). + Where("acur.user_id = ?", userID). + OrderExpr("acr.name ASC, acur.assigned_at DESC"). + Scan(ctx, &rows) + if err != nil { + return nil, fmt.Errorf("failed to get user roles: %w", err) + } + if rows == nil { + return []types.UserRoleInfo{}, nil + } + return rows, nil +} + +type userWithRoleRow struct { + UserID string `bun:"user_id"` + UserName string `bun:"user_name"` + UserEmail string `bun:"user_email"` + EmailVerified bool `bun:"email_verified"` + Image *string + Metadata []byte + CreatedAt time.Time `bun:"created_at"` + UpdatedAt time.Time `bun:"updated_at"` + RoleID *string `bun:"role_id"` + RoleName *string `bun:"role_name"` +} + +func (r *BunUserRolesRepository) GetUserWithRolesByID(ctx context.Context, userID string) (*types.UserWithRoles, error) { + var rows []userWithRoleRow + now := time.Now().UTC() + err := r.db.NewSelect(). + TableExpr("users u"). + ColumnExpr("u.id AS user_id"). + ColumnExpr("u.name AS user_name"). + ColumnExpr("u.email AS user_email"). + ColumnExpr("u.email_verified AS email_verified"). + ColumnExpr("u.image AS image"). + ColumnExpr("u.metadata AS metadata"). + ColumnExpr("u.created_at AS created_at"). + ColumnExpr("u.updated_at AS updated_at"). + ColumnExpr("pr.id AS role_id"). + ColumnExpr("pr.name AS role_name"). + Join("LEFT JOIN access_control_user_roles pur ON pur.user_id = u.id AND (pur.expires_at IS NULL OR pur.expires_at > ?)", now). + Join("LEFT JOIN access_control_roles pr ON pr.id = pur.role_id"). + Where("u.id = ?", userID). + OrderExpr("pr.name ASC"). + Scan(ctx, &rows) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, fmt.Errorf("failed to get user with roles: %w", err) + } + if len(rows) == 0 { + return nil, nil + } + + result := &types.UserWithRoles{User: mapRowToUser(rows[0])} + seen := make(map[string]struct{}) + for _, row := range rows { + if row.RoleID == nil || *row.RoleID == "" { + continue + } + if _, ok := seen[*row.RoleID]; ok { + continue + } + seen[*row.RoleID] = struct{}{} + roleName := "" + if row.RoleName != nil { + roleName = *row.RoleName + } + result.Roles = append(result.Roles, types.UserRoleInfo{RoleID: *row.RoleID, RoleName: roleName}) + } + + return result, nil +} + +func (r *BunUserRolesRepository) ReplaceUserRoles(ctx context.Context, userID string, roleIDs []string, assignedByUserID *string) error { + return r.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + if _, err := tx.NewDelete().Model((*types.UserRole)(nil)).Where("user_id = ?", userID).Exec(ctx); err != nil { + return fmt.Errorf("failed to clear user roles: %w", err) + } + + now := time.Now().UTC() + for _, roleID := range roleIDs { + ur := &types.UserRole{ + UserID: userID, + RoleID: roleID, + AssignedByUserID: assignedByUserID, + AssignedAt: now, + } + if _, err := tx.NewInsert().Model(ur).Exec(ctx); err != nil { + return wrapRepositoryError("insert user role", err) + } + } + + return nil + }) +} + +func (r *BunUserRolesRepository) AssignUserRole(ctx context.Context, userID string, roleID string, assignedByUserID *string, expiresAt *time.Time) error { + ur := &types.UserRole{ + UserID: userID, + RoleID: roleID, + AssignedByUserID: assignedByUserID, + AssignedAt: time.Now().UTC(), + ExpiresAt: expiresAt, + } + + _, err := r.db.NewInsert().Model(ur).Exec(ctx) + return wrapRepositoryError("assign user role", err) +} + +func (r *BunUserRolesRepository) RemoveUserRole(ctx context.Context, userID string, roleID string) error { + _, err := r.db.NewDelete(). + Model((*types.UserRole)(nil)). + Where("user_id = ?", userID). + Where("role_id = ?", roleID). + Exec(ctx) + if err != nil { + return fmt.Errorf("failed to remove user role: %w", err) + } + + return nil +} + +type userRow interface { + GetUserID() string + GetUserName() string + GetUserEmail() string + GetEmailVerified() bool + GetImage() *string + GetMetadata() []byte + GetCreatedAt() time.Time + GetUpdatedAt() time.Time +} + +func mapRowToUser(row userRow) models.User { + return models.User{ + ID: row.GetUserID(), + Name: row.GetUserName(), + Email: row.GetUserEmail(), + EmailVerified: row.GetEmailVerified(), + Image: row.GetImage(), + Metadata: row.GetMetadata(), + CreatedAt: row.GetCreatedAt(), + UpdatedAt: row.GetUpdatedAt(), + } +} + +func (r userWithRoleRow) GetUserID() string { return r.UserID } +func (r userWithRoleRow) GetUserName() string { return r.UserName } +func (r userWithRoleRow) GetUserEmail() string { return r.UserEmail } +func (r userWithRoleRow) GetEmailVerified() bool { return r.EmailVerified } +func (r userWithRoleRow) GetImage() *string { return r.Image } +func (r userWithRoleRow) GetMetadata() []byte { return r.Metadata } +func (r userWithRoleRow) GetCreatedAt() time.Time { return r.CreatedAt } +func (r userWithRoleRow) GetUpdatedAt() time.Time { return r.UpdatedAt } diff --git a/plugins/access-control/repositories/user_roles_repository_test.go b/plugins/access-control/repositories/user_roles_repository_test.go new file mode 100644 index 00000000..8024bc1a --- /dev/null +++ b/plugins/access-control/repositories/user_roles_repository_test.go @@ -0,0 +1,400 @@ +package repositories + +import ( + "context" + "reflect" + "testing" + "time" + + accesscontrolconstants "github.com/Authula/authula/plugins/access-control/constants" + plugintests "github.com/Authula/authula/plugins/access-control/tests" + "github.com/Authula/authula/plugins/access-control/types" +) + +func TestBunUserRolesRepositoryGetUserRoles(t *testing.T) { + t.Parallel() + + futureExpiry := time.Date(2026, 3, 30, 12, 0, 0, 0, time.UTC) + roleDescription := new("Editor role") + assignedBy := new("u2") + + tests := []struct { + name string + seed func(*BunRolesRepository, *BunUserRolesRepository, context.Context) + userID string + wantRoles []types.UserRoleInfo + }{ + { + name: "empty result", + userID: "missing-user", + wantRoles: []types.UserRoleInfo{}, + }, + { + name: "returns assigned roles ordered by role name", + userID: "u1", + seed: func(rolesRepo *BunRolesRepository, userRolesRepo *BunUserRolesRepository, ctx context.Context) { + if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "r2", Name: "viewer"}); err != nil { + panic(err) + } + if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "r1", Name: "editor", Description: roleDescription}); err != nil { + panic(err) + } + if err := userRolesRepo.AssignUserRole(ctx, "u1", "r1", assignedBy, &futureExpiry); err != nil { + panic(err) + } + if err := userRolesRepo.AssignUserRole(ctx, "u1", "r2", nil, nil); err != nil { + panic(err) + } + }, + wantRoles: []types.UserRoleInfo{ + { + RoleID: "r1", + RoleName: "editor", + RoleDescription: roleDescription, + AssignedByUserID: assignedBy, + ExpiresAt: &futureExpiry, + }, + { + RoleID: "r2", + RoleName: "viewer", + RoleDescription: nil, + AssignedByUserID: nil, + ExpiresAt: nil, + }, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + rolesRepo := NewBunRolesRepository(db) + userRolesRepo := NewBunUserRolesRepository(db) + ctx := context.Background() + + if tc.seed != nil { + tc.seed(rolesRepo, userRolesRepo, ctx) + } + + roles, err := userRolesRepo.GetUserRoles(ctx, tc.userID) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if roles == nil { + t.Fatal("expected roles slice, got nil") + } + if len(roles) != len(tc.wantRoles) { + t.Fatalf("expected %d roles, got %d", len(tc.wantRoles), len(roles)) + } + for i := range tc.wantRoles { + if roles[i].RoleID != tc.wantRoles[i].RoleID || roles[i].RoleName != tc.wantRoles[i].RoleName { + t.Fatalf("unexpected role at %d: %#v", i, roles[i]) + } + if !reflect.DeepEqual(roles[i].RoleDescription, tc.wantRoles[i].RoleDescription) { + t.Fatalf("unexpected role description at %d: %#v", i, roles[i]) + } + if !reflect.DeepEqual(roles[i].AssignedByUserID, tc.wantRoles[i].AssignedByUserID) || !reflect.DeepEqual(roles[i].ExpiresAt, tc.wantRoles[i].ExpiresAt) { + t.Fatalf("unexpected assignment metadata at %d: %#v", i, roles[i]) + } + if roles[i].AssignedAt == nil { + t.Fatalf("expected assigned_at to be populated at %d", i) + } + } + }) + } +} + +func TestBunUserRolesRepositoryGetUserWithRolesByID(t *testing.T) { + t.Parallel() + + activeExpiry := time.Date(2026, 3, 30, 12, 0, 0, 0, time.UTC) + + tests := []struct { + name string + seed func(*BunRolesRepository, *BunUserRolesRepository, context.Context) + userID string + wantNil bool + wantRoles []types.UserRoleInfo + }{ + { + name: "not found", + userID: "missing-user", + wantNil: true, + }, + { + name: "returns user with active roles only", + userID: "u1", + seed: func(rolesRepo *BunRolesRepository, userRolesRepo *BunUserRolesRepository, ctx context.Context) { + if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "r1", Name: "editor"}); err != nil { + panic(err) + } + if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "r2", Name: "viewer"}); err != nil { + panic(err) + } + if err := userRolesRepo.AssignUserRole(ctx, "u1", "r1", nil, &activeExpiry); err != nil { + panic(err) + } + if err := userRolesRepo.AssignUserRole(ctx, "u1", "r2", nil, new(time.Date(2026, 3, 28, 12, 0, 0, 0, time.UTC))); err != nil { + panic(err) + } + }, + wantRoles: []types.UserRoleInfo{ + { + RoleID: "r1", + RoleName: "editor", + }, + }, + }, + { + name: "returns user with no roles", + userID: "u2", + wantRoles: []types.UserRoleInfo{}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + rolesRepo := NewBunRolesRepository(db) + userRolesRepo := NewBunUserRolesRepository(db) + ctx := context.Background() + + if tc.seed != nil { + tc.seed(rolesRepo, userRolesRepo, ctx) + } + + userWithRoles, err := userRolesRepo.GetUserWithRolesByID(ctx, tc.userID) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tc.wantNil { + if userWithRoles != nil { + t.Fatalf("expected nil user, got %#v", userWithRoles) + } + return + } + if userWithRoles == nil { + t.Fatal("expected user, got nil") + } + if userWithRoles.User.ID != tc.userID { + t.Fatalf("expected user ID %s, got %s", tc.userID, userWithRoles.User.ID) + } + if userWithRoles.User.Name == "" || userWithRoles.User.Email == "" { + t.Fatalf("expected user fields to be populated, got %#v", userWithRoles.User) + } + if len(userWithRoles.Roles) != len(tc.wantRoles) { + t.Fatalf("expected %d roles, got %#v", len(tc.wantRoles), userWithRoles.Roles) + } + for i := range tc.wantRoles { + if userWithRoles.Roles[i].RoleID != tc.wantRoles[i].RoleID || userWithRoles.Roles[i].RoleName != tc.wantRoles[i].RoleName { + t.Fatalf("unexpected role at %d: %#v", i, userWithRoles.Roles[i]) + } + } + }) + } +} + +func TestBunUserRolesRepositoryReplaceUserRoles(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + seed func(*BunRolesRepository, *BunUserRolesRepository, context.Context) + userID string + roleIDs []string + wantRoleIDs []string + }{ + { + name: "replaces all roles", + userID: "u1", + roleIDs: []string{"r2", "r1"}, + seed: func(rolesRepo *BunRolesRepository, userRolesRepo *BunUserRolesRepository, ctx context.Context) { + if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "r1", Name: "editor"}); err != nil { + panic(err) + } + if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "r2", Name: "viewer"}); err != nil { + panic(err) + } + if err := userRolesRepo.AssignUserRole(ctx, "u1", "r1", nil, nil); err != nil { + panic(err) + } + }, + wantRoleIDs: []string{"r1", "r2"}, + }, + { + name: "empty list clears roles", + userID: "u1", + roleIDs: []string{}, + wantRoleIDs: []string{}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + rolesRepo := NewBunRolesRepository(db) + userRolesRepo := NewBunUserRolesRepository(db) + ctx := context.Background() + + if tc.seed != nil { + tc.seed(rolesRepo, userRolesRepo, ctx) + } + + if err := userRolesRepo.ReplaceUserRoles(ctx, tc.userID, tc.roleIDs, nil); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + roles, err := userRolesRepo.GetUserRoles(ctx, tc.userID) + if err != nil { + t.Fatalf("failed to fetch roles: %v", err) + } + if len(roles) != len(tc.wantRoleIDs) { + t.Fatalf("expected %d roles, got %d", len(tc.wantRoleIDs), len(roles)) + } + for i, wantRoleID := range tc.wantRoleIDs { + if roles[i].RoleID != wantRoleID { + t.Fatalf("expected role %s at index %d, got %#v", wantRoleID, i, roles[i]) + } + } + }) + } +} + +func TestBunUserRolesRepositoryAssignUserRole(t *testing.T) { + t.Parallel() + + futureExpiry := time.Date(2026, 4, 1, 12, 0, 0, 0, time.UTC) + + tests := []struct { + name string + seed func(*BunRolesRepository, *BunUserRolesRepository, context.Context) + userID string + roleID string + expiresAt *time.Time + wantErr error + }{ + { + name: "success", + userID: "u1", + roleID: "r1", + expiresAt: &futureExpiry, + seed: func(rolesRepo *BunRolesRepository, userRolesRepo *BunUserRolesRepository, ctx context.Context) { + if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "r1", Name: "editor"}); err != nil { + panic(err) + } + }, + }, + { + name: "duplicate assignment returns conflict", + userID: "u1", + roleID: "r1", + seed: func(rolesRepo *BunRolesRepository, userRolesRepo *BunUserRolesRepository, ctx context.Context) { + if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "r1", Name: "editor"}); err != nil { + panic(err) + } + if err := userRolesRepo.AssignUserRole(ctx, "u1", "r1", nil, nil); err != nil { + panic(err) + } + }, + wantErr: accesscontrolconstants.ErrConflict, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + rolesRepo := NewBunRolesRepository(db) + userRolesRepo := NewBunUserRolesRepository(db) + ctx := context.Background() + + if tc.seed != nil { + tc.seed(rolesRepo, userRolesRepo, ctx) + } + + err := userRolesRepo.AssignUserRole(ctx, tc.userID, tc.roleID, nil, tc.expiresAt) + if err != tc.wantErr { + t.Fatalf("expected err %v, got %v", tc.wantErr, err) + } + + if tc.wantErr != nil { + return + } + + roles, err := userRolesRepo.GetUserRoles(ctx, tc.userID) + if err != nil { + t.Fatalf("failed to fetch assigned role: %v", err) + } + if len(roles) != 1 || roles[0].RoleID != tc.roleID || roles[0].RoleName != "editor" { + t.Fatalf("unexpected roles after assign: %#v", roles) + } + if roles[0].AssignedAt == nil { + t.Fatal("expected assigned_at to be populated") + } + if tc.expiresAt != nil && !reflect.DeepEqual(roles[0].ExpiresAt, tc.expiresAt) { + t.Fatalf("unexpected expiry: %#v", roles[0]) + } + }) + } +} + +func TestBunUserRolesRepositoryRemoveUserRole(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + seed func(*BunRolesRepository, *BunUserRolesRepository, context.Context) + userID string + roleID string + wantRoles []types.UserRoleInfo + }{ + { + name: "success", + userID: "u1", + roleID: "r1", + seed: func(rolesRepo *BunRolesRepository, userRolesRepo *BunUserRolesRepository, ctx context.Context) { + if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "r1", Name: "editor"}); err != nil { + panic(err) + } + if err := userRolesRepo.AssignUserRole(ctx, "u1", "r1", nil, nil); err != nil { + panic(err) + } + }, + wantRoles: []types.UserRoleInfo{}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + rolesRepo := NewBunRolesRepository(db) + userRolesRepo := NewBunUserRolesRepository(db) + ctx := context.Background() + + if tc.seed != nil { + tc.seed(rolesRepo, userRolesRepo, ctx) + } + + if err := userRolesRepo.RemoveUserRole(ctx, tc.userID, tc.roleID); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + roles, err := userRolesRepo.GetUserRoles(ctx, tc.userID) + if err != nil { + t.Fatalf("failed to fetch roles: %v", err) + } + if len(roles) != len(tc.wantRoles) { + t.Fatalf("expected %d roles after remove, got %#v", len(tc.wantRoles), roles) + } + }) + } +} diff --git a/plugins/access-control/routes.go b/plugins/access-control/routes.go index 47d7e587..db4775fe 100644 --- a/plugins/access-control/routes.go +++ b/plugins/access-control/routes.go @@ -9,14 +9,20 @@ import ( ) type routeUseCases struct { - rolePermission usecases.RolePermissionUseCase - userAccess usecases.UserRolesUseCase + roles *usecases.RolesUseCase + permissions *usecases.PermissionsUseCase + rolePermissions *usecases.RolePermissionsUseCase + userRoles *usecases.UserRolesUseCase + userAccess *usecases.UserAccessUseCase } func newRouteUseCases(api *API) routeUseCases { return routeUseCases{ - rolePermission: api.useCases.RolePermissionUseCase(), - userAccess: api.useCases.UserAccessUseCase(), + roles: api.useCases.RolesUseCase(), + permissions: api.useCases.PermissionsUseCase(), + rolePermissions: api.useCases.RolePermissionsUseCase(), + userRoles: api.useCases.UserRolesUseCase(), + userAccess: api.useCases.UserAccessUseCase(), } } @@ -25,25 +31,27 @@ func Routes(api *API) []models.Route { return []models.Route{ // Roles and permissions - {Method: http.MethodPost, Path: "/access-control/roles", Handler: handlers.NewCreateRoleHandler(usecases.rolePermission).Handler()}, - {Method: http.MethodGet, Path: "/access-control/roles", Handler: handlers.NewGetAllRolesHandler(usecases.rolePermission).Handler()}, - {Method: http.MethodGet, Path: "/access-control/roles/{role_id}", Handler: handlers.NewGetRoleByIDHandler(usecases.rolePermission).Handler()}, - {Method: http.MethodPatch, Path: "/access-control/roles/{role_id}", Handler: handlers.NewUpdateRoleHandler(usecases.rolePermission).Handler()}, - {Method: http.MethodDelete, Path: "/access-control/roles/{role_id}", Handler: handlers.NewDeleteRoleHandler(usecases.rolePermission).Handler()}, - {Method: http.MethodPost, Path: "/access-control/permissions", Handler: handlers.NewCreatePermissionHandler(usecases.rolePermission).Handler()}, - {Method: http.MethodGet, Path: "/access-control/permissions", Handler: handlers.NewGetAllPermissionsHandler(usecases.rolePermission).Handler()}, - {Method: http.MethodPatch, Path: "/access-control/permissions/{permission_id}", Handler: handlers.NewUpdatePermissionHandler(usecases.rolePermission).Handler()}, - {Method: http.MethodDelete, Path: "/access-control/permissions/{permission_id}", Handler: handlers.NewDeletePermissionHandler(usecases.rolePermission).Handler()}, - {Method: http.MethodPost, Path: "/access-control/roles/{role_id}/permissions", Handler: handlers.NewAddRolePermissionHandler(usecases.rolePermission).Handler()}, - {Method: http.MethodGet, Path: "/access-control/roles/{role_id}/permissions", Handler: handlers.NewGetRolePermissionsHandler(usecases.rolePermission).Handler()}, - {Method: http.MethodPut, Path: "/access-control/roles/{role_id}/permissions", Handler: handlers.NewReplaceRolePermissionsHandler(usecases.rolePermission).Handler()}, - {Method: http.MethodDelete, Path: "/access-control/roles/{role_id}/permissions/{permission_id}", Handler: handlers.NewRemoveRolePermissionHandler(usecases.rolePermission).Handler()}, + {Method: http.MethodPost, Path: "/access-control/roles", Handler: handlers.NewCreateRoleHandler(usecases.roles).Handler()}, + {Method: http.MethodGet, Path: "/access-control/roles", Handler: handlers.NewGetAllRolesHandler(usecases.roles).Handler()}, + {Method: http.MethodGet, Path: "/access-control/roles/{role_id}", Handler: handlers.NewGetRoleByIDHandler(usecases.roles).Handler()}, + {Method: http.MethodPatch, Path: "/access-control/roles/{role_id}", Handler: handlers.NewUpdateRoleHandler(usecases.roles).Handler()}, + {Method: http.MethodDelete, Path: "/access-control/roles/{role_id}", Handler: handlers.NewDeleteRoleHandler(usecases.roles).Handler()}, + {Method: http.MethodPost, Path: "/access-control/permissions", Handler: handlers.NewCreatePermissionHandler(usecases.permissions).Handler()}, + {Method: http.MethodGet, Path: "/access-control/permissions", Handler: handlers.NewGetAllPermissionsHandler(usecases.permissions).Handler()}, + {Method: http.MethodGet, Path: "/access-control/permissions/{permission_id}", Handler: handlers.NewGetPermissionByIDHandler(usecases.permissions).Handler()}, + {Method: http.MethodPatch, Path: "/access-control/permissions/{permission_id}", Handler: handlers.NewUpdatePermissionHandler(usecases.permissions).Handler()}, + {Method: http.MethodDelete, Path: "/access-control/permissions/{permission_id}", Handler: handlers.NewDeletePermissionHandler(usecases.permissions).Handler()}, + {Method: http.MethodPost, Path: "/access-control/roles/{role_id}/permissions", Handler: handlers.NewAddRolePermissionHandler(usecases.rolePermissions).Handler()}, + {Method: http.MethodGet, Path: "/access-control/roles/{role_id}/permissions", Handler: handlers.NewGetRolePermissionsHandler(usecases.rolePermissions).Handler()}, + {Method: http.MethodPut, Path: "/access-control/roles/{role_id}/permissions", Handler: handlers.NewReplaceRolePermissionsHandler(usecases.rolePermissions).Handler()}, + {Method: http.MethodDelete, Path: "/access-control/roles/{role_id}/permissions/{permission_id}", Handler: handlers.NewRemoveRolePermissionHandler(usecases.rolePermissions).Handler()}, // User roles and permissions - {Method: http.MethodGet, Path: "/access-control/users/{user_id}/roles", Handler: handlers.NewGetUserRolesHandler(usecases.userAccess).Handler()}, - {Method: http.MethodPost, Path: "/access-control/users/{user_id}/roles", Handler: handlers.NewAssignUserRoleHandler(usecases.rolePermission).Handler()}, - {Method: http.MethodPut, Path: "/access-control/users/{user_id}/roles", Handler: handlers.NewReplaceUserRolesHandler(usecases.rolePermission).Handler()}, - {Method: http.MethodDelete, Path: "/access-control/users/{user_id}/roles/{role_id}", Handler: handlers.NewRemoveUserRoleHandler(usecases.rolePermission).Handler()}, + {Method: http.MethodGet, Path: "/access-control/users/{user_id}/roles", Handler: handlers.NewGetUserRolesHandler(usecases.userRoles).Handler()}, + {Method: http.MethodGet, Path: "/access-control/users/{user_id}/authorization-profile", Handler: handlers.NewGetUserAuthorizationProfileHandler(usecases.userAccess).Handler()}, + {Method: http.MethodPost, Path: "/access-control/users/{user_id}/roles", Handler: handlers.NewAssignUserRoleHandler(usecases.userRoles).Handler()}, + {Method: http.MethodPut, Path: "/access-control/users/{user_id}/roles", Handler: handlers.NewReplaceUserRolesHandler(usecases.userRoles).Handler()}, + {Method: http.MethodDelete, Path: "/access-control/users/{user_id}/roles/{role_id}", Handler: handlers.NewRemoveUserRoleHandler(usecases.userRoles).Handler()}, {Method: http.MethodGet, Path: "/access-control/users/{user_id}/permissions", Handler: handlers.NewGetUserEffectivePermissionsHandler(usecases.userAccess).Handler()}, } } diff --git a/plugins/access-control/services/permissions_service.go b/plugins/access-control/services/permissions_service.go new file mode 100644 index 00000000..5c08f4f4 --- /dev/null +++ b/plugins/access-control/services/permissions_service.go @@ -0,0 +1,141 @@ +package services + +import ( + "context" + + "github.com/Authula/authula/internal/util" + "github.com/Authula/authula/plugins/access-control/constants" + "github.com/Authula/authula/plugins/access-control/repositories" + "github.com/Authula/authula/plugins/access-control/types" +) + +type PermissionsService struct { + permissionsRepo repositories.PermissionsRepository + userAccessRepo repositories.UserAccessRepository +} + +func NewPermissionsService(permissionsRepo repositories.PermissionsRepository, userAccessRepo repositories.UserAccessRepository) *PermissionsService { + return &PermissionsService{permissionsRepo: permissionsRepo, userAccessRepo: userAccessRepo} +} + +func (s *PermissionsService) CreatePermission(ctx context.Context, req types.CreatePermissionRequest) (*types.Permission, error) { + if req.Key == "" { + return nil, constants.ErrBadRequest + } + + var description *string + if req.Description != nil { + description = req.Description + } + + permission := &types.Permission{ + ID: util.GenerateUUID(), + Key: req.Key, + Description: description, + IsSystem: req.IsSystem, + } + + if err := s.permissionsRepo.CreatePermission(ctx, permission); err != nil { + return nil, err + } + + return permission, nil +} + +func (s *PermissionsService) GetAllPermissions(ctx context.Context) ([]types.Permission, error) { + return s.permissionsRepo.GetAllPermissions(ctx) +} + +func (s *PermissionsService) GetPermissionByID(ctx context.Context, permissionID string) (*types.Permission, error) { + if permissionID == "" { + return nil, constants.ErrBadRequest + } + + permission, err := s.permissionsRepo.GetPermissionByID(ctx, permissionID) + if err != nil { + return nil, err + } + if permission == nil { + return nil, constants.ErrNotFound + } + + return permission, nil +} + +func (s *PermissionsService) UpdatePermission(ctx context.Context, permissionID string, req types.UpdatePermissionRequest) (*types.Permission, error) { + if permissionID == "" { + return nil, constants.ErrUnprocessableEntity + } + if req.Description == nil { + return nil, constants.ErrUnprocessableEntity + } + + description := *req.Description + if description == "" { + return nil, constants.ErrUnprocessableEntity + } + + permission, err := s.permissionsRepo.GetPermissionByID(ctx, permissionID) + if err != nil { + return nil, err + } + if permission == nil { + return nil, constants.ErrNotFound + } + if permission.IsSystem { + return nil, constants.ErrBadRequest + } + + updated, err := s.permissionsRepo.UpdatePermission(ctx, permissionID, &description) + if err != nil { + return nil, err + } + if !updated { + return nil, constants.ErrNotFound + } + + permission, err = s.permissionsRepo.GetPermissionByID(ctx, permissionID) + if err != nil { + return nil, err + } + if permission == nil { + return nil, constants.ErrNotFound + } + + return permission, nil +} + +func (s *PermissionsService) DeletePermission(ctx context.Context, permissionID string) error { + if permissionID == "" { + return constants.ErrBadRequest + } + + permission, err := s.permissionsRepo.GetPermissionByID(ctx, permissionID) + if err != nil { + return err + } + if permission == nil { + return constants.ErrNotFound + } + if permission.IsSystem { + return constants.ErrBadRequest + } + + assignmentsCount, err := s.userAccessRepo.CountRoleAssignmentsByPermissionID(ctx, permissionID) + if err != nil { + return err + } + if assignmentsCount > 0 { + return constants.ErrConflict + } + + deleted, err := s.permissionsRepo.DeletePermission(ctx, permissionID) + if err != nil { + return err + } + if !deleted { + return constants.ErrNotFound + } + + return nil +} diff --git a/plugins/access-control/services/permissions_service_test.go b/plugins/access-control/services/permissions_service_test.go new file mode 100644 index 00000000..971672cb --- /dev/null +++ b/plugins/access-control/services/permissions_service_test.go @@ -0,0 +1,403 @@ +package services + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/mock" + + accesscontrolconstants "github.com/Authula/authula/plugins/access-control/constants" + accesscontroltests "github.com/Authula/authula/plugins/access-control/tests" + "github.com/Authula/authula/plugins/access-control/types" +) + +func TestPermissionsServiceCreatePermission(t *testing.T) { + t.Parallel() + + description := "Read users" + + tests := []struct { + name string + req types.CreatePermissionRequest + setup func(*accesscontroltests.MockPermissionsRepository) + wantErr error + assert func(*testing.T, *types.Permission) + }{ + { + name: "blank key", + req: types.CreatePermissionRequest{Key: ""}, + wantErr: accesscontrolconstants.ErrBadRequest, + }, + { + name: "success", + req: types.CreatePermissionRequest{ + Key: "users.read", + Description: &description, + IsSystem: true, + }, + setup: func(permissionsRepo *accesscontroltests.MockPermissionsRepository) { + permissionsRepo.On("CreatePermission", mock.Anything, mock.MatchedBy(func(permission *types.Permission) bool { + return permission != nil && permission.ID != "" && permission.Key == "users.read" && permission.IsSystem && permission.Description != nil && *permission.Description == description + })).Return(nil).Once() + }, + assert: func(t *testing.T, permission *types.Permission) { + if permission == nil { + t.Fatal("expected permission, got nil") + } + if permission.ID == "" { + t.Fatal("expected generated ID") + } + if permission.Key != "users.read" { + t.Fatalf("expected key %q, got %q", "users.read", permission.Key) + } + if permission.Description == nil || *permission.Description != description { + t.Fatalf("expected description %q, got %#v", description, permission.Description) + } + if !permission.IsSystem { + t.Fatal("expected system permission flag to be preserved") + } + }, + }, + { + name: "repository error is returned", + req: types.CreatePermissionRequest{Key: "users.write"}, + setup: func(permissionsRepo *accesscontroltests.MockPermissionsRepository) { + permissionsRepo.On("CreatePermission", mock.Anything, mock.AnythingOfType("*types.Permission")).Return(errors.New("boom")).Once() + }, + wantErr: errors.New("boom"), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + permissionsRepo := &accesscontroltests.MockPermissionsRepository{} + userAccessRepo := &accesscontroltests.MockUserAccessRepository{} + if tc.setup != nil { + tc.setup(permissionsRepo) + } + + service := NewPermissionsService(permissionsRepo, userAccessRepo) + permission, err := service.CreatePermission(context.Background(), tc.req) + if tc.wantErr == nil { + if err != nil { + t.Fatalf("expected nil err, got %v", err) + } + } else if err == nil || err.Error() != tc.wantErr.Error() { + t.Fatalf("expected err %v, got %v", tc.wantErr, err) + } + + if tc.assert != nil { + tc.assert(t, permission) + } + + permissionsRepo.AssertExpectations(t) + userAccessRepo.AssertExpectations(t) + }) + } +} + +func TestPermissionsServiceGetAllPermissions(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setup func(*accesscontroltests.MockPermissionsRepository) + want []types.Permission + wantErr error + }{ + { + name: "success", + setup: func(permissionsRepo *accesscontroltests.MockPermissionsRepository) { + permissionsRepo.On("GetAllPermissions", mock.Anything).Return([]types.Permission{{ID: "perm-1", Key: "users.read"}}, nil).Once() + }, + want: []types.Permission{{ID: "perm-1", Key: "users.read"}}, + }, + { + name: "error", + setup: func(permissionsRepo *accesscontroltests.MockPermissionsRepository) { + permissionsRepo.On("GetAllPermissions", mock.Anything).Return(nil, errors.New("boom")).Once() + }, + wantErr: errors.New("boom"), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + permissionsRepo := &accesscontroltests.MockPermissionsRepository{} + userAccessRepo := &accesscontroltests.MockUserAccessRepository{} + if tc.setup != nil { + tc.setup(permissionsRepo) + } + + service := NewPermissionsService(permissionsRepo, userAccessRepo) + permissions, err := service.GetAllPermissions(context.Background()) + if tc.wantErr == nil { + if err != nil { + t.Fatalf("expected nil err, got %v", err) + } + if len(permissions) != len(tc.want) || (len(permissions) == 1 && permissions[0].ID != tc.want[0].ID) { + t.Fatalf("unexpected permissions %#v", permissions) + } + } else if err == nil || err.Error() != tc.wantErr.Error() { + t.Fatalf("expected err %v, got %v", tc.wantErr, err) + } + + permissionsRepo.AssertExpectations(t) + userAccessRepo.AssertExpectations(t) + }) + } +} + +func TestPermissionsServiceGetPermissionByID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + id string + setup func(*accesscontroltests.MockPermissionsRepository) + wantID string + wantErr error + }{ + { + name: "blank id", + id: "", + wantErr: accesscontrolconstants.ErrBadRequest, + }, + { + name: "not found", + id: "perm-1", + setup: func(permissionsRepo *accesscontroltests.MockPermissionsRepository) { + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return((*types.Permission)(nil), nil).Once() + }, + wantErr: accesscontrolconstants.ErrNotFound, + }, + { + name: "success", + id: "perm-1", + setup: func(permissionsRepo *accesscontroltests.MockPermissionsRepository) { + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read"}, nil).Once() + }, + wantID: "perm-1", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + permissionsRepo := &accesscontroltests.MockPermissionsRepository{} + userAccessRepo := &accesscontroltests.MockUserAccessRepository{} + if tc.setup != nil { + tc.setup(permissionsRepo) + } + + service := NewPermissionsService(permissionsRepo, userAccessRepo) + permission, err := service.GetPermissionByID(context.Background(), tc.id) + if tc.wantErr == nil { + if err != nil { + t.Fatalf("expected nil err, got %v", err) + } + if permission == nil || permission.ID != tc.wantID { + t.Fatalf("unexpected permission %#v", permission) + } + } else if err != tc.wantErr { + t.Fatalf("expected err %v, got %v", tc.wantErr, err) + } + + permissionsRepo.AssertExpectations(t) + userAccessRepo.AssertExpectations(t) + }) + } +} + +func TestPermissionsServiceUpdatePermission(t *testing.T) { + t.Parallel() + + updatedDescription := "Updated description" + + tests := []struct { + name string + id string + req types.UpdatePermissionRequest + setup func(*accesscontroltests.MockPermissionsRepository) + wantID string + wantErr error + }{ + { + name: "blank id", + id: "", + req: types.UpdatePermissionRequest{Description: &updatedDescription}, + wantErr: accesscontrolconstants.ErrUnprocessableEntity, + }, + { + name: "nil description", + id: "perm-1", + req: types.UpdatePermissionRequest{}, + wantErr: accesscontrolconstants.ErrUnprocessableEntity, + }, + { + name: "blank description", + id: "perm-1", + req: types.UpdatePermissionRequest{Description: func() *string { value := ""; return &value }()}, + wantErr: accesscontrolconstants.ErrUnprocessableEntity, + }, + { + name: "not found", + id: "perm-1", + req: types.UpdatePermissionRequest{Description: &updatedDescription}, + setup: func(permissionsRepo *accesscontroltests.MockPermissionsRepository) { + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return((*types.Permission)(nil), nil).Once() + }, + wantErr: accesscontrolconstants.ErrNotFound, + }, + { + name: "system permission", + id: "perm-1", + req: types.UpdatePermissionRequest{Description: &updatedDescription}, + setup: func(permissionsRepo *accesscontroltests.MockPermissionsRepository) { + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read", IsSystem: true}, nil).Once() + }, + wantErr: accesscontrolconstants.ErrBadRequest, + }, + { + name: "update returns false", + id: "perm-1", + req: types.UpdatePermissionRequest{Description: &updatedDescription}, + setup: func(permissionsRepo *accesscontroltests.MockPermissionsRepository) { + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read"}, nil).Once() + permissionsRepo.On("UpdatePermission", mock.Anything, "perm-1", &updatedDescription).Return(false, nil).Once() + }, + wantErr: accesscontrolconstants.ErrNotFound, + }, + { + name: "success", + id: "perm-1", + req: types.UpdatePermissionRequest{Description: &updatedDescription}, + setup: func(permissionsRepo *accesscontroltests.MockPermissionsRepository) { + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read"}, nil).Once() + permissionsRepo.On("UpdatePermission", mock.Anything, "perm-1", &updatedDescription).Return(true, nil).Once() + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read", Description: &updatedDescription}, nil).Once() + }, + wantID: "perm-1", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + permissionsRepo := &accesscontroltests.MockPermissionsRepository{} + userAccessRepo := &accesscontroltests.MockUserAccessRepository{} + if tc.setup != nil { + tc.setup(permissionsRepo) + } + + service := NewPermissionsService(permissionsRepo, userAccessRepo) + permission, err := service.UpdatePermission(context.Background(), tc.id, tc.req) + if tc.wantErr == nil { + if err != nil { + t.Fatalf("expected nil err, got %v", err) + } + if permission == nil || permission.ID != tc.wantID { + t.Fatalf("unexpected permission %#v", permission) + } + } else if err != tc.wantErr { + t.Fatalf("expected err %v, got %v", tc.wantErr, err) + } + + permissionsRepo.AssertExpectations(t) + userAccessRepo.AssertExpectations(t) + }) + } +} + +func TestPermissionsServiceDeletePermission(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + id string + setup func(*accesscontroltests.MockPermissionsRepository, *accesscontroltests.MockUserAccessRepository) + wantErr error + }{ + { + name: "blank id", + id: "", + wantErr: accesscontrolconstants.ErrBadRequest, + }, + { + name: "not found", + id: "perm-1", + setup: func(permissionsRepo *accesscontroltests.MockPermissionsRepository, userAccessRepo *accesscontroltests.MockUserAccessRepository) { + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return((*types.Permission)(nil), nil).Once() + }, + wantErr: accesscontrolconstants.ErrNotFound, + }, + { + name: "system permission", + id: "perm-1", + setup: func(permissionsRepo *accesscontroltests.MockPermissionsRepository, userAccessRepo *accesscontroltests.MockUserAccessRepository) { + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read", IsSystem: true}, nil).Once() + }, + wantErr: accesscontrolconstants.ErrBadRequest, + }, + { + name: "permission in use", + id: "perm-1", + setup: func(permissionsRepo *accesscontroltests.MockPermissionsRepository, userAccessRepo *accesscontroltests.MockUserAccessRepository) { + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read"}, nil).Once() + userAccessRepo.On("CountRoleAssignmentsByPermissionID", mock.Anything, "perm-1").Return(2, nil).Once() + }, + wantErr: accesscontrolconstants.ErrConflict, + }, + { + name: "delete returns false", + id: "perm-1", + setup: func(permissionsRepo *accesscontroltests.MockPermissionsRepository, userAccessRepo *accesscontroltests.MockUserAccessRepository) { + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read"}, nil).Once() + userAccessRepo.On("CountRoleAssignmentsByPermissionID", mock.Anything, "perm-1").Return(0, nil).Once() + permissionsRepo.On("DeletePermission", mock.Anything, "perm-1").Return(false, nil).Once() + }, + wantErr: accesscontrolconstants.ErrNotFound, + }, + { + name: "success", + id: "perm-1", + setup: func(permissionsRepo *accesscontroltests.MockPermissionsRepository, userAccessRepo *accesscontroltests.MockUserAccessRepository) { + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read"}, nil).Once() + userAccessRepo.On("CountRoleAssignmentsByPermissionID", mock.Anything, "perm-1").Return(0, nil).Once() + permissionsRepo.On("DeletePermission", mock.Anything, "perm-1").Return(true, nil).Once() + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + permissionsRepo := &accesscontroltests.MockPermissionsRepository{} + userAccessRepo := &accesscontroltests.MockUserAccessRepository{} + if tc.setup != nil { + tc.setup(permissionsRepo, userAccessRepo) + } + + service := NewPermissionsService(permissionsRepo, userAccessRepo) + err := service.DeletePermission(context.Background(), tc.id) + if tc.wantErr == nil { + if err != nil { + t.Fatalf("expected nil err, got %v", err) + } + } else if err != tc.wantErr { + t.Fatalf("expected err %v, got %v", tc.wantErr, err) + } + + permissionsRepo.AssertExpectations(t) + userAccessRepo.AssertExpectations(t) + }) + } +} diff --git a/plugins/access-control/services/role_permission_service.go b/plugins/access-control/services/role_permission_service.go index a269a13d..c1359653 100644 --- a/plugins/access-control/services/role_permission_service.go +++ b/plugins/access-control/services/role_permission_service.go @@ -2,114 +2,28 @@ package services import ( "context" - "strings" - "time" - "github.com/Authula/authula/internal/util" "github.com/Authula/authula/plugins/access-control/constants" "github.com/Authula/authula/plugins/access-control/repositories" "github.com/Authula/authula/plugins/access-control/types" ) -type RolePermissionService struct { - repo repositories.RolePermissionRepository +type RolePermissionsService struct { + rolesRepo repositories.RolesRepository + permissionsRepo repositories.PermissionsRepository + rolePermissionsRepo repositories.RolePermissionsRepository } -func NewRolePermissionService(repo repositories.RolePermissionRepository) *RolePermissionService { - return &RolePermissionService{repo: repo} +func NewRolePermissionsService(rolesRepo repositories.RolesRepository, permissionsRepo repositories.PermissionsRepository, rolePermissionsRepo repositories.RolePermissionsRepository) *RolePermissionsService { + return &RolePermissionsService{rolesRepo: rolesRepo, permissionsRepo: permissionsRepo, rolePermissionsRepo: rolePermissionsRepo} } -func (s *RolePermissionService) CreateRole(ctx context.Context, req types.CreateRoleRequest) (*types.Role, error) { - name := strings.TrimSpace(req.Name) - if name == "" { - return nil, constants.ErrBadRequest - } - - role := &types.Role{ - ID: util.GenerateUUID(), - Name: name, - Description: req.Description, - IsSystem: req.IsSystem, - } - - if err := s.repo.CreateRole(ctx, role); err != nil { - return nil, err - } - - return role, nil -} - -func (s *RolePermissionService) GetAllRoles(ctx context.Context) ([]types.Role, error) { - return s.repo.GetAllRoles(ctx) -} - -func (s *RolePermissionService) GetRoleByID(ctx context.Context, roleID string) (*types.RoleDetails, error) { - roleID = strings.TrimSpace(roleID) +func (s *RolePermissionsService) GetRolePermissions(ctx context.Context, roleID string) ([]types.UserPermissionInfo, error) { if roleID == "" { - return nil, constants.ErrBadRequest - } - - role, err := s.repo.GetRoleByID(ctx, roleID) - if err != nil { - return nil, err - } - if role == nil { - return nil, constants.ErrNotFound - } - - permissions, err := s.repo.GetRolePermissions(ctx, roleID) - if err != nil { - return nil, err - } - - return &types.RoleDetails{Role: *role, Permissions: permissions}, nil -} - -func (s *RolePermissionService) UpdateRole(ctx context.Context, roleID string, req types.UpdateRoleRequest) (*types.Role, error) { - roleID = strings.TrimSpace(roleID) - if roleID == "" { - return nil, constants.ErrBadRequest - } - - if req.Name == nil && req.Description == nil { return nil, constants.ErrUnprocessableEntity } - role, err := s.repo.GetRoleByID(ctx, roleID) - if err != nil { - return nil, err - } - if role == nil { - return nil, constants.ErrNotFound - } - if role.IsSystem { - return nil, constants.ErrCannotUpdateSystemRole - } - - var name *string - if req.Name != nil { - trimmed := strings.TrimSpace(*req.Name) - if trimmed == "" { - return nil, constants.ErrBadRequest - } - name = &trimmed - } - - var description *string - if req.Description != nil { - trimmed := strings.TrimSpace(*req.Description) - description = &trimmed - } - - updated, err := s.repo.UpdateRole(ctx, roleID, name, description) - if err != nil { - return nil, err - } - if !updated { - return nil, constants.ErrNotFound - } - - role, err = s.repo.GetRoleByID(ctx, roleID) + role, err := s.rolesRepo.GetRoleByID(ctx, roleID) if err != nil { return nil, err } @@ -117,16 +31,18 @@ func (s *RolePermissionService) UpdateRole(ctx context.Context, roleID string, r return nil, constants.ErrNotFound } - return role, nil + return s.rolePermissionsRepo.GetRolePermissions(ctx, roleID) } -func (s *RolePermissionService) DeleteRole(ctx context.Context, roleID string) error { - roleID = strings.TrimSpace(roleID) +func (s *RolePermissionsService) AddPermissionToRole(ctx context.Context, roleID string, permissionID string, grantedByUserID *string) error { if roleID == "" { return constants.ErrBadRequest } + if permissionID == "" { + return constants.ErrBadRequest + } - role, err := s.repo.GetRoleByID(ctx, roleID) + role, err := s.rolesRepo.GetRoleByID(ctx, roleID) if err != nil { return err } @@ -134,120 +50,10 @@ func (s *RolePermissionService) DeleteRole(ctx context.Context, roleID string) e return constants.ErrNotFound } if role.IsSystem { - return constants.ErrCannotUpdateSystemRole - } - - assignmentsCount, err := s.repo.CountUserAssignmentsByRoleID(ctx, roleID) - if err != nil { - return err - } - if assignmentsCount > 0 { - return constants.ErrConflict - } - - deleted, err := s.repo.DeleteRole(ctx, roleID) - if err != nil { - return err - } - if !deleted { - return constants.ErrNotFound - } - - return nil -} - -func (s *RolePermissionService) CreatePermission(ctx context.Context, req types.CreatePermissionRequest) (*types.Permission, error) { - key := strings.TrimSpace(req.Key) - if key == "" { - return nil, constants.ErrBadRequest - } - - permission := &types.Permission{ - ID: util.GenerateUUID(), - Key: key, - Description: req.Description, - IsSystem: req.IsSystem, - } - - if err := s.repo.CreatePermission(ctx, permission); err != nil { - return nil, err - } - - return permission, nil -} - -func (s *RolePermissionService) GetAllPermissions(ctx context.Context) ([]types.Permission, error) { - return s.repo.GetAllPermissions(ctx) -} - -func (s *RolePermissionService) GetRolePermissions(ctx context.Context, roleID string) ([]types.UserPermissionInfo, error) { - roleID = strings.TrimSpace(roleID) - if roleID == "" { - return nil, constants.ErrUnprocessableEntity - } - - role, err := s.repo.GetRoleByID(ctx, roleID) - if err != nil { - return nil, err - } - if role == nil { - return nil, constants.ErrNotFound - } - - return s.repo.GetRolePermissions(ctx, roleID) -} - -func (s *RolePermissionService) UpdatePermission(ctx context.Context, permissionID string, req types.UpdatePermissionRequest) (*types.Permission, error) { - permissionID = strings.TrimSpace(permissionID) - if permissionID == "" { - return nil, constants.ErrUnprocessableEntity - } - if req.Description == nil { - return nil, constants.ErrUnprocessableEntity - } - - description := strings.TrimSpace(*req.Description) - if description == "" { - return nil, constants.ErrUnprocessableEntity - } - - permission, err := s.repo.GetPermissionByID(ctx, permissionID) - if err != nil { - return nil, err - } - if permission == nil { - return nil, constants.ErrNotFound - } - if permission.IsSystem { - return nil, constants.ErrBadRequest - } - - updated, err := s.repo.UpdatePermission(ctx, permissionID, &description) - if err != nil { - return nil, err - } - if !updated { - return nil, constants.ErrNotFound - } - - permission, err = s.repo.GetPermissionByID(ctx, permissionID) - if err != nil { - return nil, err - } - if permission == nil { - return nil, constants.ErrNotFound - } - - return permission, nil -} - -func (s *RolePermissionService) DeletePermission(ctx context.Context, permissionID string) error { - permissionID = strings.TrimSpace(permissionID) - if permissionID == "" { return constants.ErrBadRequest } - permission, err := s.repo.GetPermissionByID(ctx, permissionID) + permission, err := s.permissionsRepo.GetPermissionByID(ctx, permissionID) if err != nil { return err } @@ -258,37 +64,18 @@ func (s *RolePermissionService) DeletePermission(ctx context.Context, permission return constants.ErrBadRequest } - assignmentsCount, err := s.repo.CountRoleAssignmentsByPermissionID(ctx, permissionID) - if err != nil { - return err - } - if assignmentsCount > 0 { - return constants.ErrConflict - } - - deleted, err := s.repo.DeletePermission(ctx, permissionID) - if err != nil { - return err - } - if !deleted { - return constants.ErrNotFound - } - - return nil + return s.rolePermissionsRepo.AddRolePermission(ctx, roleID, permissionID, grantedByUserID) } -func (s *RolePermissionService) AddPermissionToRole(ctx context.Context, roleID string, permissionID string, grantedByUserID *string) error { - roleID = strings.TrimSpace(roleID) - permissionID = strings.TrimSpace(permissionID) - +func (s *RolePermissionsService) RemovePermissionFromRole(ctx context.Context, roleID string, permissionID string) error { if roleID == "" { - return constants.ErrBadRequest + return constants.ErrUnprocessableEntity } if permissionID == "" { - return constants.ErrBadRequest + return constants.ErrUnprocessableEntity } - role, err := s.repo.GetRoleByID(ctx, roleID) + role, err := s.rolesRepo.GetRoleByID(ctx, roleID) if err != nil { return err } @@ -299,7 +86,7 @@ func (s *RolePermissionService) AddPermissionToRole(ctx context.Context, roleID return constants.ErrBadRequest } - permission, err := s.repo.GetPermissionByID(ctx, permissionID) + permission, err := s.permissionsRepo.GetPermissionByID(ctx, permissionID) if err != nil { return err } @@ -310,21 +97,15 @@ func (s *RolePermissionService) AddPermissionToRole(ctx context.Context, roleID return constants.ErrBadRequest } - return s.repo.AddRolePermission(ctx, roleID, permissionID, grantedByUserID) + return s.rolePermissionsRepo.RemoveRolePermission(ctx, roleID, permissionID) } -func (s *RolePermissionService) RemovePermissionFromRole(ctx context.Context, roleID string, permissionID string) error { - roleID = strings.TrimSpace(roleID) - permissionID = strings.TrimSpace(permissionID) - +func (s *RolePermissionsService) ReplaceRolePermissions(ctx context.Context, roleID string, permissionIDs []string, grantedByUserID *string) error { if roleID == "" { - return constants.ErrUnprocessableEntity - } - if permissionID == "" { - return constants.ErrUnprocessableEntity + return constants.ErrBadRequest } - role, err := s.repo.GetRoleByID(ctx, roleID) + role, err := s.rolesRepo.GetRoleByID(ctx, roleID) if err != nil { return err } @@ -335,29 +116,9 @@ func (s *RolePermissionService) RemovePermissionFromRole(ctx context.Context, ro return constants.ErrBadRequest } - permission, err := s.repo.GetPermissionByID(ctx, permissionID) - if err != nil { - return err - } - if permission == nil { - return constants.ErrNotFound - } - if permission.IsSystem { - return constants.ErrBadRequest - } - - return s.repo.RemoveRolePermission(ctx, roleID, permissionID) -} - -func (s *RolePermissionService) ReplaceRolePermissions(ctx context.Context, roleID string, permissionIDs []string, grantedByUserID *string) error { - if strings.TrimSpace(roleID) == "" { - return constants.ErrBadRequest - } - normalized := make([]string, 0, len(permissionIDs)) seen := make(map[string]struct{}, len(permissionIDs)) for _, permissionID := range permissionIDs { - permissionID = strings.TrimSpace(permissionID) if permissionID == "" { continue } @@ -365,62 +126,20 @@ func (s *RolePermissionService) ReplaceRolePermissions(ctx context.Context, role continue } seen[permissionID] = struct{}{} - normalized = append(normalized, permissionID) - } - - return s.repo.ReplaceRolePermissions(ctx, roleID, normalized, grantedByUserID) -} - -func (s *RolePermissionService) ReplaceUserRoles(ctx context.Context, userID string, roleIDs []string, assignedByUserID *string) error { - if strings.TrimSpace(userID) == "" { - return constants.ErrBadRequest - } - normalized := make([]string, 0, len(roleIDs)) - seen := make(map[string]struct{}, len(roleIDs)) - for _, roleID := range roleIDs { - roleID = strings.TrimSpace(roleID) - if roleID == "" { - continue + permission, err := s.permissionsRepo.GetPermissionByID(ctx, permissionID) + if err != nil { + return err } - if _, ok := seen[roleID]; ok { - continue + if permission == nil { + return constants.ErrNotFound + } + if permission.IsSystem { + return constants.ErrBadRequest } - seen[roleID] = struct{}{} - normalized = append(normalized, roleID) - } - - return s.repo.ReplaceUserRoles(ctx, userID, normalized, assignedByUserID) -} - -func (s *RolePermissionService) AssignRoleToUser(ctx context.Context, userID string, req types.AssignUserRoleRequest, assignedByUserID *string) error { - userID = strings.TrimSpace(userID) - if userID == "" { - return constants.ErrUnprocessableEntity - } - - roleID := strings.TrimSpace(req.RoleID) - if roleID == "" { - return constants.ErrUnprocessableEntity - } - - if req.ExpiresAt != nil && req.ExpiresAt.Before(time.Now().UTC()) { - return constants.ErrBadRequest - } - - return s.repo.AssignUserRole(ctx, userID, roleID, assignedByUserID, req.ExpiresAt) -} - -func (s *RolePermissionService) RemoveRoleFromUser(ctx context.Context, userID string, roleID string) error { - userID = strings.TrimSpace(userID) - roleID = strings.TrimSpace(roleID) - if userID == "" { - return constants.ErrBadRequest - } - if roleID == "" { - return constants.ErrBadRequest + normalized = append(normalized, permissionID) } - return s.repo.RemoveUserRole(ctx, userID, roleID) + return s.rolePermissionsRepo.ReplaceRolePermissions(ctx, roleID, normalized, grantedByUserID) } diff --git a/plugins/access-control/services/role_permission_service_test.go b/plugins/access-control/services/role_permission_service_test.go index 7e18d213..5e568ea8 100644 --- a/plugins/access-control/services/role_permission_service_test.go +++ b/plugins/access-control/services/role_permission_service_test.go @@ -1,139 +1 @@ -package services_test - -import ( - "context" - "errors" - "testing" - "time" - - "github.com/stretchr/testify/mock" - - "github.com/Authula/authula/plugins/access-control/constants" - servicespkg "github.com/Authula/authula/plugins/access-control/services" - testshelpers "github.com/Authula/authula/plugins/access-control/tests" - "github.com/Authula/authula/plugins/access-control/types" -) - -func newRolePermissionServiceFixture() (*servicespkg.RolePermissionService, *testshelpers.MockRolePermissionRepository) { - repo := testshelpers.NewMockRolePermissionRepository() - svc := servicespkg.NewRolePermissionService(repo) - return svc, repo -} - -func TestRolePermissionServiceCreateRoleTrimsName(t *testing.T) { - t.Parallel() - - svc, repo := newRolePermissionServiceFixture() - repo.On("CreateRole", mock.Anything, mock.AnythingOfType("*types.Role")). - Run(func(args mock.Arguments) { - role := args.Get(1).(*types.Role) - if role.Name != "admin" { - t.Fatalf("expected trimmed role name, got %q", role.Name) - } - }). - Return(nil). - Once() - - role, err := svc.CreateRole(context.Background(), types.CreateRoleRequest{Name: " admin "}) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if role.Name != "admin" { - t.Fatalf("expected trimmed role name, got %q", role.Name) - } - repo.AssertExpectations(t) -} - -func TestRolePermissionServiceAssignRoleToUserRejectsPastExpiry(t *testing.T) { - t.Parallel() - - svc, _ := newRolePermissionServiceFixture() - past := time.Now().UTC().Add(-1 * time.Hour) - - err := svc.AssignRoleToUser( - context.Background(), - "user-1", - types.AssignUserRoleRequest{RoleID: "role-1", ExpiresAt: &past}, - nil, - ) - if !errors.Is(err, constants.ErrBadRequest) { - t.Fatalf("expected ErrBadRequest, got %v", err) - } -} - -func TestRolePermissionServiceDeleteRoleReturnsConflictWhenAssigned(t *testing.T) { - t.Parallel() - - svc, repo := newRolePermissionServiceFixture() - ctx := context.Background() - roleID := "role-1" - - repo.On("GetRoleByID", mock.Anything, roleID). - Return(&types.Role{ID: roleID, Name: "role-for-delete"}, nil). - Once() - repo.On("CountUserAssignmentsByRoleID", mock.Anything, roleID). - Return(1, nil). - Once() - - err := svc.DeleteRole(ctx, roleID) - if !errors.Is(err, constants.ErrConflict) { - t.Fatalf("expected ErrConflict, got %v", err) - } - repo.AssertExpectations(t) -} - -func TestRolePermissionServiceGetRolePermissionsRejectsEmptyRoleID(t *testing.T) { - t.Parallel() - - svc, _ := newRolePermissionServiceFixture() - - _, err := svc.GetRolePermissions(context.Background(), " ") - if !errors.Is(err, constants.ErrUnprocessableEntity) { - t.Fatalf("expected ErrUnprocessableEntity, got %v", err) - } -} - -func TestRolePermissionServiceGetRolePermissionsReturnsNotFoundForMissingRole(t *testing.T) { - t.Parallel() - - svc, repo := newRolePermissionServiceFixture() - repo.On("GetRoleByID", mock.Anything, "missing-role").Return((*types.Role)(nil), nil).Once() - - _, err := svc.GetRolePermissions(context.Background(), "missing-role") - if !errors.Is(err, constants.ErrNotFound) { - t.Fatalf("expected ErrNotFound, got %v", err) - } - repo.AssertExpectations(t) -} - -func TestRolePermissionServiceGetRolePermissionsReturnsAssignedPermissions(t *testing.T) { - t.Parallel() - - svc, repo := newRolePermissionServiceFixture() - ctx := context.Background() - roleID := "role-1" - permissionID := "perm-1" - permissionKey := "users.read" - - repo.On("GetRoleByID", mock.Anything, roleID). - Return(&types.Role{ID: roleID, Name: "RolePermReader"}, nil). - Once() - repo.On("GetRolePermissions", mock.Anything, roleID). - Return([]types.UserPermissionInfo{{PermissionID: permissionID, PermissionKey: permissionKey}}, nil). - Once() - - permissions, err := svc.GetRolePermissions(ctx, " "+roleID+" ") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if len(permissions) != 1 { - t.Fatalf("expected 1 permission, got %d", len(permissions)) - } - if permissions[0].PermissionID != permissionID { - t.Fatalf("expected permission id %q, got %q", permissionID, permissions[0].PermissionID) - } - if permissions[0].PermissionKey != permissionKey { - t.Fatalf("expected permission key %q, got %q", permissionKey, permissions[0].PermissionKey) - } - repo.AssertExpectations(t) -} +package services diff --git a/plugins/access-control/services/roles_service.go b/plugins/access-control/services/roles_service.go new file mode 100644 index 00000000..ca7cd308 --- /dev/null +++ b/plugins/access-control/services/roles_service.go @@ -0,0 +1,156 @@ +package services + +import ( + "context" + + "github.com/Authula/authula/internal/util" + "github.com/Authula/authula/plugins/access-control/constants" + "github.com/Authula/authula/plugins/access-control/repositories" + "github.com/Authula/authula/plugins/access-control/types" +) + +type RolesService struct { + rolesRepo repositories.RolesRepository + rolePermissionsRepo repositories.RolePermissionsRepository + userAccessRepo repositories.UserAccessRepository +} + +func NewRolesService(rolesRepo repositories.RolesRepository, rolePermissionsRepo repositories.RolePermissionsRepository, userAccessRepo repositories.UserAccessRepository) *RolesService { + return &RolesService{rolesRepo: rolesRepo, rolePermissionsRepo: rolePermissionsRepo, userAccessRepo: userAccessRepo} +} + +func (s *RolesService) CreateRole(ctx context.Context, req types.CreateRoleRequest) (*types.Role, error) { + if req.Name == "" { + return nil, constants.ErrBadRequest + } + + var description *string + if req.Description != nil { + description = req.Description + } + + role := &types.Role{ + ID: util.GenerateUUID(), + Name: req.Name, + Description: description, + IsSystem: req.IsSystem, + } + + if err := s.rolesRepo.CreateRole(ctx, role); err != nil { + return nil, err + } + + return role, nil +} + +func (s *RolesService) GetAllRoles(ctx context.Context) ([]types.Role, error) { + return s.rolesRepo.GetAllRoles(ctx) +} + +func (s *RolesService) GetRoleByID(ctx context.Context, roleID string) (*types.RoleDetails, error) { + if roleID == "" { + return nil, constants.ErrBadRequest + } + + role, err := s.rolesRepo.GetRoleByID(ctx, roleID) + if err != nil { + return nil, err + } + if role == nil { + return nil, constants.ErrNotFound + } + + permissions, err := s.rolePermissionsRepo.GetRolePermissions(ctx, roleID) + if err != nil { + return nil, err + } + + return &types.RoleDetails{Role: *role, Permissions: permissions}, nil +} + +func (s *RolesService) UpdateRole(ctx context.Context, roleID string, req types.UpdateRoleRequest) (*types.Role, error) { + if roleID == "" { + return nil, constants.ErrBadRequest + } + + if req.Name == nil && req.Description == nil { + return nil, constants.ErrUnprocessableEntity + } + + role, err := s.rolesRepo.GetRoleByID(ctx, roleID) + if err != nil { + return nil, err + } + if role == nil { + return nil, constants.ErrNotFound + } + if role.IsSystem { + return nil, constants.ErrCannotUpdateSystemRole + } + + var name *string + if req.Name != nil { + if *req.Name == "" { + return nil, constants.ErrBadRequest + } + name = req.Name + } + + var description *string + if req.Description != nil { + description = req.Description + } + + updated, err := s.rolesRepo.UpdateRole(ctx, roleID, name, description) + if err != nil { + return nil, err + } + if !updated { + return nil, constants.ErrNotFound + } + + role, err = s.rolesRepo.GetRoleByID(ctx, roleID) + if err != nil { + return nil, err + } + if role == nil { + return nil, constants.ErrNotFound + } + + return role, nil +} + +func (s *RolesService) DeleteRole(ctx context.Context, roleID string) error { + if roleID == "" { + return constants.ErrBadRequest + } + + role, err := s.rolesRepo.GetRoleByID(ctx, roleID) + if err != nil { + return err + } + if role == nil { + return constants.ErrNotFound + } + if role.IsSystem { + return constants.ErrCannotUpdateSystemRole + } + + assignmentsCount, err := s.userAccessRepo.CountUserAssignmentsByRoleID(ctx, roleID) + if err != nil { + return err + } + if assignmentsCount > 0 { + return constants.ErrConflict + } + + deleted, err := s.rolesRepo.DeleteRole(ctx, roleID) + if err != nil { + return err + } + if !deleted { + return constants.ErrNotFound + } + + return nil +} diff --git a/plugins/access-control/services/roles_service_test.go b/plugins/access-control/services/roles_service_test.go new file mode 100644 index 00000000..99cb2263 --- /dev/null +++ b/plugins/access-control/services/roles_service_test.go @@ -0,0 +1,84 @@ +package services + +import ( + "context" + "testing" + + "github.com/stretchr/testify/mock" + + accesscontrolconstants "github.com/Authula/authula/plugins/access-control/constants" + accesscontroltests "github.com/Authula/authula/plugins/access-control/tests" + "github.com/Authula/authula/plugins/access-control/types" +) + +func TestRolesServiceGetRoleByID(t *testing.T) { + t.Parallel() + + rolesRepo := &accesscontroltests.MockRolesRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} + userAccessRepo := &accesscontroltests.MockUserAccessRepository{} + + rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "admin"}, nil).Once() + rolePermissionsRepo.On("GetRolePermissions", mock.Anything, "role-1").Return([]types.UserPermissionInfo{{PermissionID: "perm-1", PermissionKey: "users.read"}}, nil).Once() + + service := NewRolesService(rolesRepo, rolePermissionsRepo, userAccessRepo) + details, err := service.GetRoleByID(context.Background(), "role-1") + if err != nil { + t.Fatalf("expected nil err, got %v", err) + } + if details == nil || details.Role.ID != "role-1" || len(details.Permissions) != 1 { + t.Fatalf("unexpected details %#v", details) + } + + rolesRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) +} + +func TestRolesServiceDeleteRole(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setup func(*accesscontroltests.MockRolesRepository, *accesscontroltests.MockUserAccessRepository) + wantErr error + }{ + { + name: "role in use", + setup: func(rolesRepo *accesscontroltests.MockRolesRepository, userAccessRepo *accesscontroltests.MockUserAccessRepository) { + rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "admin"}, nil).Once() + userAccessRepo.On("CountUserAssignmentsByRoleID", mock.Anything, "role-1").Return(1, nil).Once() + }, + wantErr: accesscontrolconstants.ErrConflict, + }, + { + name: "success", + setup: func(rolesRepo *accesscontroltests.MockRolesRepository, userAccessRepo *accesscontroltests.MockUserAccessRepository) { + rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "admin"}, nil).Once() + userAccessRepo.On("CountUserAssignmentsByRoleID", mock.Anything, "role-1").Return(0, nil).Once() + rolesRepo.On("DeleteRole", mock.Anything, "role-1").Return(true, nil).Once() + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + rolesRepo := &accesscontroltests.MockRolesRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} + userAccessRepo := &accesscontroltests.MockUserAccessRepository{} + if tc.setup != nil { + tc.setup(rolesRepo, userAccessRepo) + } + + service := NewRolesService(rolesRepo, rolePermissionsRepo, userAccessRepo) + err := service.DeleteRole(context.Background(), "role-1") + if err != tc.wantErr { + t.Fatalf("expected err %v, got %v", tc.wantErr, err) + } + + rolesRepo.AssertExpectations(t) + userAccessRepo.AssertExpectations(t) + }) + } +} diff --git a/plugins/access-control/services/user_access_service.go b/plugins/access-control/services/user_access_service.go index 4db4d0c9..a1a5ee17 100644 --- a/plugins/access-control/services/user_access_service.go +++ b/plugins/access-control/services/user_access_service.go @@ -2,7 +2,6 @@ package services import ( "context" - "strings" "github.com/Authula/authula/plugins/access-control/constants" "github.com/Authula/authula/plugins/access-control/repositories" @@ -10,36 +9,28 @@ import ( ) type UserAccessService struct { + userRolesRepo repositories.UserRolesRepository userAccessRepo repositories.UserAccessRepository } -func NewUserAccessService(repo repositories.UserAccessRepository) *UserAccessService { - return &UserAccessService{userAccessRepo: repo} +func NewUserAccessService(userRolesRepo repositories.UserRolesRepository, userAccessRepo repositories.UserAccessRepository) *UserAccessService { + return &UserAccessService{userRolesRepo: userRolesRepo, userAccessRepo: userAccessRepo} } -func (s *UserAccessService) GetUserRoles(ctx context.Context, userID string) ([]types.UserRoleInfo, error) { - if strings.TrimSpace(userID) == "" { +func (s *UserAccessService) GetUserWithPermissionsByID(ctx context.Context, userID string) (*types.UserWithPermissions, error) { + if userID == "" { return nil, constants.ErrUnprocessableEntity } - return s.userAccessRepo.GetUserRoles(ctx, userID) -} -func (s *UserAccessService) GetUserWithRolesByID(ctx context.Context, userID string) (*types.UserWithRoles, error) { - if strings.TrimSpace(userID) == "" { - return nil, constants.ErrUnprocessableEntity - } - return s.userAccessRepo.GetUserWithRolesByID(ctx, userID) + return s.userAccessRepo.GetUserWithPermissionsByID(ctx, userID) } -func (s *UserAccessService) GetUserWithPermissionsByID(ctx context.Context, userID string) (*types.UserWithPermissions, error) { - if strings.TrimSpace(userID) == "" { +func (s *UserAccessService) GetUserAuthorizationProfile(ctx context.Context, userID string) (*types.UserAuthorizationProfile, error) { + if userID == "" { return nil, constants.ErrUnprocessableEntity } - return s.userAccessRepo.GetUserWithPermissionsByID(ctx, userID) -} -func (s *UserAccessService) GetUserAuthorizationProfile(ctx context.Context, userID string) (*types.UserAuthorizationProfile, error) { - withRoles, err := s.GetUserWithRolesByID(ctx, userID) + withRoles, err := s.userRolesRepo.GetUserWithRolesByID(ctx, userID) if err != nil { return nil, err } @@ -64,13 +55,18 @@ func (s *UserAccessService) GetUserAuthorizationProfile(ctx context.Context, use } func (s *UserAccessService) GetUserEffectivePermissions(ctx context.Context, userID string) ([]types.UserPermissionInfo, error) { - if strings.TrimSpace(userID) == "" { + if userID == "" { return nil, constants.ErrUnprocessableEntity } + return s.userAccessRepo.GetUserEffectivePermissions(ctx, userID) } func (s *UserAccessService) HasPermissions(ctx context.Context, userID string, requiredPermissions []string) (bool, error) { + if userID == "" { + return false, constants.ErrUnprocessableEntity + } + permissions, err := s.GetUserEffectivePermissions(ctx, userID) if err != nil { return false, err @@ -82,7 +78,6 @@ func (s *UserAccessService) HasPermissions(ctx context.Context, userID string, r } for _, required := range requiredPermissions { - required = strings.TrimSpace(required) if required == "" { continue } diff --git a/plugins/access-control/services/user_access_service_test.go b/plugins/access-control/services/user_access_service_test.go index 00fa6dd8..a0fd6301 100644 --- a/plugins/access-control/services/user_access_service_test.go +++ b/plugins/access-control/services/user_access_service_test.go @@ -2,76 +2,80 @@ package services import ( "context" - "errors" "testing" - "github.com/Authula/authula/plugins/access-control/constants" + "github.com/stretchr/testify/mock" + + accesscontrolconstants "github.com/Authula/authula/plugins/access-control/constants" + accesscontroltests "github.com/Authula/authula/plugins/access-control/tests" "github.com/Authula/authula/plugins/access-control/types" ) -type stubUserAccessRepo struct { - rolesResult []types.UserRoleInfo - rolesErr error - permissionsResult []types.UserPermissionInfo - permissionsErr error - withRolesResult *types.UserWithRoles - withRolesErr error - withPermsResult *types.UserWithPermissions - withPermsErr error -} - -func (s *stubUserAccessRepo) GetUserRoles(_ context.Context, _ string) ([]types.UserRoleInfo, error) { - return s.rolesResult, s.rolesErr -} - -func (s *stubUserAccessRepo) GetUserEffectivePermissions(_ context.Context, _ string) ([]types.UserPermissionInfo, error) { - return s.permissionsResult, s.permissionsErr -} - -func (s *stubUserAccessRepo) GetUserWithRolesByID(_ context.Context, _ string) (*types.UserWithRoles, error) { - return s.withRolesResult, s.withRolesErr -} - -func (s *stubUserAccessRepo) GetUserWithPermissionsByID(_ context.Context, _ string) (*types.UserWithPermissions, error) { - return s.withPermsResult, s.withPermsErr -} - -func TestUserAccessServiceGetUserRolesUnprocessableEntity(t *testing.T) { +func TestUserAccessServiceHasPermissions(t *testing.T) { t.Parallel() - svc := NewUserAccessService(&stubUserAccessRepo{}) - _, err := svc.GetUserRoles(context.Background(), " ") - if !errors.Is(err, constants.ErrUnprocessableEntity) { - t.Fatalf("expected ErrUnprocessableEntity, got %v", err) - } -} - -func TestUserAccessServiceHasPermissionsMatchesAnyRequiredPermission(t *testing.T) { - t.Parallel() + userRolesRepo := &accesscontroltests.MockUserRolesRepository{} + userAccessRepo := &accesscontroltests.MockUserAccessRepository{} + userAccessRepo.On("GetUserEffectivePermissions", mock.Anything, "user-1").Return([]types.UserPermissionInfo{{PermissionID: "perm-1", PermissionKey: "users.read"}}, nil).Once() - svc := NewUserAccessService(&stubUserAccessRepo{ - permissionsResult: []types.UserPermissionInfo{{PermissionKey: "users.read"}, {PermissionKey: "users.write"}}, - }) - - ok, err := svc.HasPermissions(context.Background(), "user-1", []string{"billing.read", "users.write"}) + service := NewUserAccessService(userRolesRepo, userAccessRepo) + ok, err := service.HasPermissions(context.Background(), "user-1", []string{"users.write", "users.read"}) if err != nil { - t.Fatalf("unexpected error: %v", err) + t.Fatalf("expected nil err, got %v", err) } if !ok { - t.Fatal("expected permission check to pass when any required permission matches") + t.Fatal("expected permission check to pass") } + + userAccessRepo.AssertExpectations(t) } -func TestUserAccessServiceGetUserAuthorizationProfileNilUser(t *testing.T) { +func TestUserAccessServiceGetUserAuthorizationProfile(t *testing.T) { t.Parallel() - svc := NewUserAccessService(&stubUserAccessRepo{withRolesResult: nil}) - - profile, err := svc.GetUserAuthorizationProfile(context.Background(), "user-1") - if err != nil { - t.Fatalf("unexpected error: %v", err) + tests := []struct { + name string + userID string + setup func(*accesscontroltests.MockUserRolesRepository, *accesscontroltests.MockUserAccessRepository) + wantErr error + wantNil bool + }{ + { + name: "blank user id", + userID: "", + wantErr: accesscontrolconstants.ErrUnprocessableEntity, + }, + { + name: "composes profile", + userID: "user-1", + setup: func(userRolesRepo *accesscontroltests.MockUserRolesRepository, userAccessRepo *accesscontroltests.MockUserAccessRepository) { + userRolesRepo.On("GetUserWithRolesByID", mock.Anything, "user-1").Return(&types.UserWithRoles{Roles: []types.UserRoleInfo{{RoleID: "role-1", RoleName: "admin"}}}, nil).Once() + userAccessRepo.On("GetUserWithPermissionsByID", mock.Anything, "user-1").Return(&types.UserWithPermissions{Permissions: []types.UserPermissionInfo{{PermissionID: "perm-1", PermissionKey: "users.read"}}}, nil).Once() + }, + }, } - if profile != nil { - t.Fatalf("expected nil profile, got %+v", profile) + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + userRolesRepo := &accesscontroltests.MockUserRolesRepository{} + userAccessRepo := &accesscontroltests.MockUserAccessRepository{} + if tc.setup != nil { + tc.setup(userRolesRepo, userAccessRepo) + } + + service := NewUserAccessService(userRolesRepo, userAccessRepo) + profile, err := service.GetUserAuthorizationProfile(context.Background(), tc.userID) + if err != tc.wantErr { + t.Fatalf("expected err %v, got %v", tc.wantErr, err) + } + if tc.wantErr == nil && profile == nil { + t.Fatal("expected profile, got nil") + } + + userRolesRepo.AssertExpectations(t) + userAccessRepo.AssertExpectations(t) + }) } } diff --git a/plugins/access-control/services/user_roles_service.go b/plugins/access-control/services/user_roles_service.go new file mode 100644 index 00000000..4efbf7d7 --- /dev/null +++ b/plugins/access-control/services/user_roles_service.go @@ -0,0 +1,98 @@ +package services + +import ( + "context" + "time" + + "github.com/Authula/authula/plugins/access-control/constants" + "github.com/Authula/authula/plugins/access-control/repositories" + "github.com/Authula/authula/plugins/access-control/types" +) + +type UserRolesService struct { + userRolesRepo repositories.UserRolesRepository + rolesRepo repositories.RolesRepository +} + +func NewUserRolesService(userRolesRepo repositories.UserRolesRepository, rolesRepo repositories.RolesRepository) *UserRolesService { + return &UserRolesService{userRolesRepo: userRolesRepo, rolesRepo: rolesRepo} +} + +func (s *UserRolesService) GetUserRoles(ctx context.Context, userID string) ([]types.UserRoleInfo, error) { + if userID == "" { + return nil, constants.ErrUnprocessableEntity + } + + return s.userRolesRepo.GetUserRoles(ctx, userID) +} + +func (s *UserRolesService) GetUserWithRolesByID(ctx context.Context, userID string) (*types.UserWithRoles, error) { + if userID == "" { + return nil, constants.ErrUnprocessableEntity + } + + return s.userRolesRepo.GetUserWithRolesByID(ctx, userID) +} + +func (s *UserRolesService) ReplaceUserRoles(ctx context.Context, userID string, roleIDs []string, assignedByUserID *string) error { + if userID == "" { + return constants.ErrBadRequest + } + + normalized := make([]string, 0, len(roleIDs)) + seen := make(map[string]struct{}, len(roleIDs)) + for _, roleID := range roleIDs { + if roleID == "" { + continue + } + if _, ok := seen[roleID]; ok { + continue + } + + role, err := s.rolesRepo.GetRoleByID(ctx, roleID) + if err != nil { + return err + } + if role == nil { + return constants.ErrNotFound + } + + seen[roleID] = struct{}{} + normalized = append(normalized, roleID) + } + + return s.userRolesRepo.ReplaceUserRoles(ctx, userID, normalized, assignedByUserID) +} + +func (s *UserRolesService) AssignRoleToUser(ctx context.Context, userID string, req types.AssignUserRoleRequest, assignedByUserID *string) error { + if userID == "" { + return constants.ErrBadRequest + } + + roleID := req.RoleID + if roleID == "" { + return constants.ErrBadRequest + } + + if req.ExpiresAt != nil && req.ExpiresAt.Before(time.Now().UTC()) { + return constants.ErrBadRequest + } + + role, err := s.rolesRepo.GetRoleByID(ctx, roleID) + if err != nil { + return err + } + if role == nil { + return constants.ErrNotFound + } + + return s.userRolesRepo.AssignUserRole(ctx, userID, roleID, assignedByUserID, req.ExpiresAt) +} + +func (s *UserRolesService) RemoveRoleFromUser(ctx context.Context, userID string, roleID string) error { + if userID == "" || roleID == "" { + return constants.ErrBadRequest + } + + return s.userRolesRepo.RemoveUserRole(ctx, userID, roleID) +} diff --git a/plugins/access-control/services/user_roles_service_test.go b/plugins/access-control/services/user_roles_service_test.go new file mode 100644 index 00000000..b96ed2c7 --- /dev/null +++ b/plugins/access-control/services/user_roles_service_test.go @@ -0,0 +1,78 @@ +package services + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/mock" + + accesscontrolconstants "github.com/Authula/authula/plugins/access-control/constants" + accesscontroltests "github.com/Authula/authula/plugins/access-control/tests" + "github.com/Authula/authula/plugins/access-control/types" +) + +func TestUserRolesServiceAssignRoleToUser(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + req types.AssignUserRoleRequest + setup func(*accesscontroltests.MockUserRolesRepository, *accesscontroltests.MockRolesRepository) + wantErr error + }{ + { + name: "expired assignment", + req: types.AssignUserRoleRequest{RoleID: "role-1", ExpiresAt: func() *time.Time { t := time.Now().UTC().Add(-time.Hour); return &t }()}, + wantErr: accesscontrolconstants.ErrBadRequest, + }, + { + name: "success", + req: types.AssignUserRoleRequest{RoleID: "role-1"}, + setup: func(userRolesRepo *accesscontroltests.MockUserRolesRepository, rolesRepo *accesscontroltests.MockRolesRepository) { + rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "admin"}, nil).Once() + userRolesRepo.On("AssignUserRole", mock.Anything, "user-1", "role-1", (*string)(nil), (*time.Time)(nil)).Return(nil).Once() + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + userRolesRepo := &accesscontroltests.MockUserRolesRepository{} + rolesRepo := &accesscontroltests.MockRolesRepository{} + if tc.setup != nil { + tc.setup(userRolesRepo, rolesRepo) + } + + service := NewUserRolesService(userRolesRepo, rolesRepo) + err := service.AssignRoleToUser(context.Background(), "user-1", tc.req, nil) + if err != tc.wantErr { + t.Fatalf("expected err %v, got %v", tc.wantErr, err) + } + + userRolesRepo.AssertExpectations(t) + rolesRepo.AssertExpectations(t) + }) + } +} + +func TestUserRolesServiceReplaceUserRolesDedupes(t *testing.T) { + t.Parallel() + + userRolesRepo := &accesscontroltests.MockUserRolesRepository{} + rolesRepo := &accesscontroltests.MockRolesRepository{} + rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "admin"}, nil).Once() + rolesRepo.On("GetRoleByID", mock.Anything, "role-2").Return(&types.Role{ID: "role-2", Name: "editor"}, nil).Once() + userRolesRepo.On("ReplaceUserRoles", mock.Anything, "user-1", []string{"role-1", "role-2"}, (*string)(nil)).Return(nil).Once() + + service := NewUserRolesService(userRolesRepo, rolesRepo) + err := service.ReplaceUserRoles(context.Background(), "user-1", []string{"role-1", "role-1", "role-2"}, nil) + if err != nil { + t.Fatalf("expected nil err, got %v", err) + } + + rolesRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) +} diff --git a/plugins/access-control/tests/test_helpers.go b/plugins/access-control/tests/test_helpers.go index ce2bc664..5b7ad155 100644 --- a/plugins/access-control/tests/test_helpers.go +++ b/plugins/access-control/tests/test_helpers.go @@ -2,71 +2,32 @@ package tests import ( "context" + "database/sql" + "testing" "time" + _ "github.com/mattn/go-sqlite3" "github.com/stretchr/testify/mock" + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect/sqlitedialect" - "github.com/Authula/authula/plugins/access-control/services" + internaltests "github.com/Authula/authula/internal/tests" + "github.com/Authula/authula/migrations" + "github.com/Authula/authula/models" + accesscontrolmigrations "github.com/Authula/authula/plugins/access-control/migrationset" "github.com/Authula/authula/plugins/access-control/types" - "github.com/Authula/authula/plugins/access-control/usecases" ) -type mockUserAccessRepository struct { +type MockRolesRepository struct { mock.Mock } -func (m *mockUserAccessRepository) GetUserRoles(ctx context.Context, userID string) ([]types.UserRoleInfo, error) { - args := m.Called(ctx, userID) - if args.Get(0) == nil { - return nil, args.Error(1) - } - return args.Get(0).([]types.UserRoleInfo), args.Error(1) -} - -func (m *mockUserAccessRepository) GetUserEffectivePermissions(ctx context.Context, userID string) ([]types.UserPermissionInfo, error) { - args := m.Called(ctx, userID) - if args.Get(0) == nil { - return nil, args.Error(1) - } - return args.Get(0).([]types.UserPermissionInfo), args.Error(1) -} - -func (m *mockUserAccessRepository) GetUserWithRolesByID(ctx context.Context, userID string) (*types.UserWithRoles, error) { - args := m.Called(ctx, userID) - if args.Get(0) == nil { - return nil, args.Error(1) - } - return args.Get(0).(*types.UserWithRoles), args.Error(1) -} - -func (m *mockUserAccessRepository) GetUserWithPermissionsByID(ctx context.Context, userID string) (*types.UserWithPermissions, error) { - args := m.Called(ctx, userID) - if args.Get(0) == nil { - return nil, args.Error(1) - } - return args.Get(0).(*types.UserWithPermissions), args.Error(1) -} - -func NewUserRolesUseCaseFixture() (usecases.UserRolesUseCase, *mockUserAccessRepository) { - repo := &mockUserAccessRepository{} - service := services.NewUserAccessService(repo) - return usecases.NewUserRolesUseCase(service), repo -} - -type MockRolePermissionRepository struct { - mock.Mock -} - -func NewMockRolePermissionRepository() *MockRolePermissionRepository { - return &MockRolePermissionRepository{} -} - -func (m *MockRolePermissionRepository) CreateRole(ctx context.Context, role *types.Role) error { +func (m *MockRolesRepository) CreateRole(ctx context.Context, role *types.Role) error { args := m.Called(ctx, role) return args.Error(0) } -func (m *MockRolePermissionRepository) GetAllRoles(ctx context.Context) ([]types.Role, error) { +func (m *MockRolesRepository) GetAllRoles(ctx context.Context) ([]types.Role, error) { args := m.Called(ctx) if args.Get(0) == nil { return nil, args.Error(1) @@ -74,7 +35,7 @@ func (m *MockRolePermissionRepository) GetAllRoles(ctx context.Context) ([]types return args.Get(0).([]types.Role), args.Error(1) } -func (m *MockRolePermissionRepository) GetRoleByID(ctx context.Context, roleID string) (*types.Role, error) { +func (m *MockRolesRepository) GetRoleByID(ctx context.Context, roleID string) (*types.Role, error) { args := m.Called(ctx, roleID) if args.Get(0) == nil { return nil, args.Error(1) @@ -82,22 +43,21 @@ func (m *MockRolePermissionRepository) GetRoleByID(ctx context.Context, roleID s return args.Get(0).(*types.Role), args.Error(1) } -func (m *MockRolePermissionRepository) UpdateRole(ctx context.Context, roleID string, name *string, description *string) (bool, error) { +func (m *MockRolesRepository) UpdateRole(ctx context.Context, roleID string, name *string, description *string) (bool, error) { args := m.Called(ctx, roleID, name, description) return args.Bool(0), args.Error(1) } -func (m *MockRolePermissionRepository) DeleteRole(ctx context.Context, roleID string) (bool, error) { +func (m *MockRolesRepository) DeleteRole(ctx context.Context, roleID string) (bool, error) { args := m.Called(ctx, roleID) return args.Bool(0), args.Error(1) } -func (m *MockRolePermissionRepository) CountUserAssignmentsByRoleID(ctx context.Context, roleID string) (int, error) { - args := m.Called(ctx, roleID) - return args.Int(0), args.Error(1) +type MockPermissionsRepository struct { + mock.Mock } -func (m *MockRolePermissionRepository) GetAllPermissions(ctx context.Context) ([]types.Permission, error) { +func (m *MockPermissionsRepository) GetAllPermissions(ctx context.Context) ([]types.Permission, error) { args := m.Called(ctx) if args.Get(0) == nil { return nil, args.Error(1) @@ -105,7 +65,7 @@ func (m *MockRolePermissionRepository) GetAllPermissions(ctx context.Context) ([ return args.Get(0).([]types.Permission), args.Error(1) } -func (m *MockRolePermissionRepository) GetPermissionByID(ctx context.Context, permissionID string) (*types.Permission, error) { +func (m *MockPermissionsRepository) GetPermissionByID(ctx context.Context, permissionID string) (*types.Permission, error) { args := m.Called(ctx, permissionID) if args.Get(0) == nil { return nil, args.Error(1) @@ -113,27 +73,26 @@ func (m *MockRolePermissionRepository) GetPermissionByID(ctx context.Context, pe return args.Get(0).(*types.Permission), args.Error(1) } -func (m *MockRolePermissionRepository) CreatePermission(ctx context.Context, permission *types.Permission) error { +func (m *MockPermissionsRepository) CreatePermission(ctx context.Context, permission *types.Permission) error { args := m.Called(ctx, permission) return args.Error(0) } -func (m *MockRolePermissionRepository) UpdatePermission(ctx context.Context, permissionID string, description *string) (bool, error) { +func (m *MockPermissionsRepository) UpdatePermission(ctx context.Context, permissionID string, description *string) (bool, error) { args := m.Called(ctx, permissionID, description) return args.Bool(0), args.Error(1) } -func (m *MockRolePermissionRepository) DeletePermission(ctx context.Context, permissionID string) (bool, error) { +func (m *MockPermissionsRepository) DeletePermission(ctx context.Context, permissionID string) (bool, error) { args := m.Called(ctx, permissionID) return args.Bool(0), args.Error(1) } -func (m *MockRolePermissionRepository) CountRoleAssignmentsByPermissionID(ctx context.Context, permissionID string) (int, error) { - args := m.Called(ctx, permissionID) - return args.Int(0), args.Error(1) +type MockRolePermissionsRepository struct { + mock.Mock } -func (m *MockRolePermissionRepository) GetRolePermissions(ctx context.Context, roleID string) ([]types.UserPermissionInfo, error) { +func (m *MockRolePermissionsRepository) GetRolePermissions(ctx context.Context, roleID string) ([]types.UserPermissionInfo, error) { args := m.Called(ctx, roleID) if args.Get(0) == nil { return nil, args.Error(1) @@ -141,38 +100,127 @@ func (m *MockRolePermissionRepository) GetRolePermissions(ctx context.Context, r return args.Get(0).([]types.UserPermissionInfo), args.Error(1) } -func (m *MockRolePermissionRepository) ReplaceRolePermissions(ctx context.Context, roleID string, permissionIDs []string, grantedByUserID *string) error { +func (m *MockRolePermissionsRepository) ReplaceRolePermissions(ctx context.Context, roleID string, permissionIDs []string, grantedByUserID *string) error { args := m.Called(ctx, roleID, permissionIDs, grantedByUserID) return args.Error(0) } -func (m *MockRolePermissionRepository) AddRolePermission(ctx context.Context, roleID string, permissionID string, grantedByUserID *string) error { +func (m *MockRolePermissionsRepository) AddRolePermission(ctx context.Context, roleID string, permissionID string, grantedByUserID *string) error { args := m.Called(ctx, roleID, permissionID, grantedByUserID) return args.Error(0) } -func (m *MockRolePermissionRepository) RemoveRolePermission(ctx context.Context, roleID string, permissionID string) error { +func (m *MockRolePermissionsRepository) RemoveRolePermission(ctx context.Context, roleID string, permissionID string) error { args := m.Called(ctx, roleID, permissionID) return args.Error(0) } -func (m *MockRolePermissionRepository) ReplaceUserRoles(ctx context.Context, userID string, roleIDs []string, assignedByUserID *string) error { +type MockUserRolesRepository struct { + mock.Mock +} + +func (m *MockUserRolesRepository) GetUserRoles(ctx context.Context, userID string) ([]types.UserRoleInfo, error) { + args := m.Called(ctx, userID) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]types.UserRoleInfo), args.Error(1) +} + +func (m *MockUserRolesRepository) GetUserWithRolesByID(ctx context.Context, userID string) (*types.UserWithRoles, error) { + args := m.Called(ctx, userID) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*types.UserWithRoles), args.Error(1) +} + +func (m *MockUserRolesRepository) ReplaceUserRoles(ctx context.Context, userID string, roleIDs []string, assignedByUserID *string) error { args := m.Called(ctx, userID, roleIDs, assignedByUserID) return args.Error(0) } -func (m *MockRolePermissionRepository) AssignUserRole(ctx context.Context, userID string, roleID string, assignedByUserID *string, expiresAt *time.Time) error { +func (m *MockUserRolesRepository) AssignUserRole(ctx context.Context, userID string, roleID string, assignedByUserID *string, expiresAt *time.Time) error { args := m.Called(ctx, userID, roleID, assignedByUserID, expiresAt) return args.Error(0) } -func (m *MockRolePermissionRepository) RemoveUserRole(ctx context.Context, userID string, roleID string) error { +func (m *MockUserRolesRepository) RemoveUserRole(ctx context.Context, userID string, roleID string) error { args := m.Called(ctx, userID, roleID) return args.Error(0) } -func NewRolePermissionUseCaseFixture() (usecases.RolePermissionUseCase, *MockRolePermissionRepository) { - repo := NewMockRolePermissionRepository() - service := services.NewRolePermissionService(repo) - return usecases.NewRolePermissionUseCase(service), repo +type MockUserAccessRepository struct { + mock.Mock +} + +func (m *MockUserAccessRepository) CountUserAssignmentsByRoleID(ctx context.Context, roleID string) (int, error) { + args := m.Called(ctx, roleID) + return args.Int(0), args.Error(1) +} + +func (m *MockUserAccessRepository) CountRoleAssignmentsByPermissionID(ctx context.Context, permissionID string) (int, error) { + args := m.Called(ctx, permissionID) + return args.Int(0), args.Error(1) +} + +func (m *MockUserAccessRepository) GetUserEffectivePermissions(ctx context.Context, userID string) ([]types.UserPermissionInfo, error) { + args := m.Called(ctx, userID) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]types.UserPermissionInfo), args.Error(1) +} + +func (m *MockUserAccessRepository) GetUserWithPermissionsByID(ctx context.Context, userID string) (*types.UserWithPermissions, error) { + args := m.Called(ctx, userID) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*types.UserWithPermissions), args.Error(1) +} + +func SetupRepoDB(t *testing.T) *bun.DB { + t.Helper() + + sqldb, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("failed to open sqlite: %v", err) + } + + db := bun.NewDB(sqldb, sqlitedialect.New()) + t.Cleanup(func() { + _ = db.Close() + }) + + ctx := context.Background() + + migrator, err := migrations.NewMigrator(db, &internaltests.MockLogger{}) + if err != nil { + t.Fatalf("failed to create migrator: %v", err) + } + + coreSet, err := migrations.CoreMigrationSet("sqlite") + if err != nil { + t.Fatalf("failed to build core migration set: %v", err) + } + + accessControlSet := migrations.MigrationSet{ + PluginID: models.PluginAccessControl.String(), + DependsOn: []string{migrations.CorePluginID}, + Migrations: accesscontrolmigrations.ForProvider("sqlite"), + } + + if err := migrator.Migrate(ctx, []migrations.MigrationSet{coreSet, accessControlSet}); err != nil { + t.Fatalf("failed to run migrations: %v", err) + } + + if _, err := db.ExecContext(ctx, `INSERT INTO users (id, name, email, email_verified, metadata) VALUES ('u1', 'User One', 'u1@example.com', 1, '{}')`); err != nil { + t.Fatalf("failed to seed user u1: %v", err) + } + if _, err := db.ExecContext(ctx, `INSERT INTO users (id, name, email, email_verified, metadata) VALUES ('u2', 'User Two', 'u2@example.com', 1, '{}')`); err != nil { + t.Fatalf("failed to seed user u2: %v", err) + } + + return db } diff --git a/plugins/access-control/usecases/permissions_usecase.go b/plugins/access-control/usecases/permissions_usecase.go new file mode 100644 index 00000000..3d1ac21d --- /dev/null +++ b/plugins/access-control/usecases/permissions_usecase.go @@ -0,0 +1,36 @@ +package usecases + +import ( + "context" + + "github.com/Authula/authula/plugins/access-control/services" + "github.com/Authula/authula/plugins/access-control/types" +) + +type PermissionsUseCase struct { + service *services.PermissionsService +} + +func NewPermissionsUseCase(service *services.PermissionsService) *PermissionsUseCase { + return &PermissionsUseCase{service: service} +} + +func (u *PermissionsUseCase) CreatePermission(ctx context.Context, req types.CreatePermissionRequest) (*types.Permission, error) { + return u.service.CreatePermission(ctx, req) +} + +func (u *PermissionsUseCase) GetAllPermissions(ctx context.Context) ([]types.Permission, error) { + return u.service.GetAllPermissions(ctx) +} + +func (u *PermissionsUseCase) GetPermissionByID(ctx context.Context, permissionID string) (*types.Permission, error) { + return u.service.GetPermissionByID(ctx, permissionID) +} + +func (u *PermissionsUseCase) UpdatePermission(ctx context.Context, permissionID string, req types.UpdatePermissionRequest) (*types.Permission, error) { + return u.service.UpdatePermission(ctx, permissionID, req) +} + +func (u *PermissionsUseCase) DeletePermission(ctx context.Context, permissionID string) error { + return u.service.DeletePermission(ctx, permissionID) +} diff --git a/plugins/access-control/usecases/role_permission_mock_test.go b/plugins/access-control/usecases/role_permission_mock_test.go deleted file mode 100644 index a99c8067..00000000 --- a/plugins/access-control/usecases/role_permission_mock_test.go +++ /dev/null @@ -1,126 +0,0 @@ -package usecases - -import ( - "context" - "time" - - "github.com/stretchr/testify/mock" - - "github.com/Authula/authula/plugins/access-control/types" -) - -type MockRolePermissionRepository struct { - mock.Mock -} - -type MockRolePermissionService = MockRolePermissionRepository - -func (m *MockRolePermissionRepository) CreateRole(ctx context.Context, role *types.Role) error { - args := m.Called(ctx, role) - return args.Error(0) -} - -func (m *MockRolePermissionRepository) GetAllRoles(ctx context.Context) ([]types.Role, error) { - args := m.Called(ctx) - if roles := args.Get(0); roles != nil { - return roles.([]types.Role), args.Error(1) - } - return nil, args.Error(1) -} - -func (m *MockRolePermissionRepository) GetRoleByID(ctx context.Context, roleID string) (*types.Role, error) { - args := m.Called(ctx, roleID) - if role := args.Get(0); role != nil { - return role.(*types.Role), args.Error(1) - } - return nil, args.Error(1) -} - -func (m *MockRolePermissionRepository) UpdateRole(ctx context.Context, roleID string, name *string, description *string) (bool, error) { - args := m.Called(ctx, roleID, name, description) - return args.Bool(0), args.Error(1) -} - -func (m *MockRolePermissionRepository) DeleteRole(ctx context.Context, roleID string) (bool, error) { - args := m.Called(ctx, roleID) - return args.Bool(0), args.Error(1) -} - -func (m *MockRolePermissionRepository) CountUserAssignmentsByRoleID(ctx context.Context, roleID string) (int, error) { - args := m.Called(ctx, roleID) - return args.Int(0), args.Error(1) -} - -func (m *MockRolePermissionRepository) GetRolePermissions(ctx context.Context, roleID string) ([]types.UserPermissionInfo, error) { - args := m.Called(ctx, roleID) - if perms := args.Get(0); perms != nil { - return perms.([]types.UserPermissionInfo), args.Error(1) - } - return nil, args.Error(1) -} - -func (m *MockRolePermissionRepository) CreatePermission(ctx context.Context, permission *types.Permission) error { - args := m.Called(ctx, permission) - return args.Error(0) -} - -func (m *MockRolePermissionRepository) GetAllPermissions(ctx context.Context) ([]types.Permission, error) { - args := m.Called(ctx) - if perms := args.Get(0); perms != nil { - return perms.([]types.Permission), args.Error(1) - } - return nil, args.Error(1) -} - -func (m *MockRolePermissionRepository) GetPermissionByID(ctx context.Context, permissionID string) (*types.Permission, error) { - args := m.Called(ctx, permissionID) - if perm := args.Get(0); perm != nil { - return perm.(*types.Permission), args.Error(1) - } - return nil, args.Error(1) -} - -func (m *MockRolePermissionRepository) UpdatePermission(ctx context.Context, permissionID string, description *string) (bool, error) { - args := m.Called(ctx, permissionID, description) - return args.Bool(0), args.Error(1) -} - -func (m *MockRolePermissionRepository) DeletePermission(ctx context.Context, permissionID string) (bool, error) { - args := m.Called(ctx, permissionID) - return args.Bool(0), args.Error(1) -} - -func (m *MockRolePermissionRepository) CountRoleAssignmentsByPermissionID(ctx context.Context, permissionID string) (int, error) { - args := m.Called(ctx, permissionID) - return args.Int(0), args.Error(1) -} - -func (m *MockRolePermissionRepository) AddRolePermission(ctx context.Context, roleID, permissionID string, grantedByUserID *string) error { - args := m.Called(ctx, roleID, permissionID, grantedByUserID) - return args.Error(0) -} - -func (m *MockRolePermissionRepository) RemoveRolePermission(ctx context.Context, roleID, permissionID string) error { - args := m.Called(ctx, roleID, permissionID) - return args.Error(0) -} - -func (m *MockRolePermissionRepository) ReplaceRolePermissions(ctx context.Context, roleID string, permissionIDs []string, grantedByUserID *string) error { - args := m.Called(ctx, roleID, permissionIDs, grantedByUserID) - return args.Error(0) -} - -func (m *MockRolePermissionRepository) AssignUserRole(ctx context.Context, userID, roleID string, assignedByUserID *string, expiresAt *time.Time) error { - args := m.Called(ctx, userID, roleID, assignedByUserID, expiresAt) - return args.Error(0) -} - -func (m *MockRolePermissionRepository) RemoveUserRole(ctx context.Context, userID, roleID string) error { - args := m.Called(ctx, userID, roleID) - return args.Error(0) -} - -func (m *MockRolePermissionRepository) ReplaceUserRoles(ctx context.Context, userID string, roleIDs []string, assignedByUserID *string) error { - args := m.Called(ctx, userID, roleIDs, assignedByUserID) - return args.Error(0) -} diff --git a/plugins/access-control/usecases/role_permission_usecase.go b/plugins/access-control/usecases/role_permission_usecase.go index 8820fe22..d538d118 100644 --- a/plugins/access-control/usecases/role_permission_usecase.go +++ b/plugins/access-control/usecases/role_permission_usecase.go @@ -7,74 +7,26 @@ import ( "github.com/Authula/authula/plugins/access-control/types" ) -type RolePermissionUseCase struct { - service *services.RolePermissionService +type RolePermissionsUseCase struct { + service *services.RolePermissionsService } -func NewRolePermissionUseCase(service *services.RolePermissionService) RolePermissionUseCase { - return RolePermissionUseCase{service: service} +func NewRolePermissionsUseCase(service *services.RolePermissionsService) *RolePermissionsUseCase { + return &RolePermissionsUseCase{service: service} } -func (u RolePermissionUseCase) CreateRole(ctx context.Context, req types.CreateRoleRequest) (*types.Role, error) { - return u.service.CreateRole(ctx, req) -} - -func (u RolePermissionUseCase) GetAllRoles(ctx context.Context) ([]types.Role, error) { - return u.service.GetAllRoles(ctx) -} - -func (u RolePermissionUseCase) GetRoleByID(ctx context.Context, roleID string) (*types.RoleDetails, error) { - return u.service.GetRoleByID(ctx, roleID) -} - -func (u RolePermissionUseCase) UpdateRole(ctx context.Context, roleID string, req types.UpdateRoleRequest) (*types.Role, error) { - return u.service.UpdateRole(ctx, roleID, req) -} - -func (u RolePermissionUseCase) DeleteRole(ctx context.Context, roleID string) error { - return u.service.DeleteRole(ctx, roleID) -} - -func (u RolePermissionUseCase) CreatePermission(ctx context.Context, req types.CreatePermissionRequest) (*types.Permission, error) { - return u.service.CreatePermission(ctx, req) -} - -func (u RolePermissionUseCase) GetAllPermissions(ctx context.Context) ([]types.Permission, error) { - return u.service.GetAllPermissions(ctx) -} - -func (u RolePermissionUseCase) GetRolePermissions(ctx context.Context, roleID string) ([]types.UserPermissionInfo, error) { +func (u *RolePermissionsUseCase) GetRolePermissions(ctx context.Context, roleID string) ([]types.UserPermissionInfo, error) { return u.service.GetRolePermissions(ctx, roleID) } -func (u RolePermissionUseCase) UpdatePermission(ctx context.Context, permissionID string, req types.UpdatePermissionRequest) (*types.Permission, error) { - return u.service.UpdatePermission(ctx, permissionID, req) -} - -func (u RolePermissionUseCase) DeletePermission(ctx context.Context, permissionID string) error { - return u.service.DeletePermission(ctx, permissionID) -} - -func (u RolePermissionUseCase) AddPermissionToRole(ctx context.Context, roleID string, permissionID string, grantedByUserID *string) error { +func (u *RolePermissionsUseCase) AddPermissionToRole(ctx context.Context, roleID string, permissionID string, grantedByUserID *string) error { return u.service.AddPermissionToRole(ctx, roleID, permissionID, grantedByUserID) } -func (u RolePermissionUseCase) RemovePermissionFromRole(ctx context.Context, roleID string, permissionID string) error { +func (u *RolePermissionsUseCase) RemovePermissionFromRole(ctx context.Context, roleID string, permissionID string) error { return u.service.RemovePermissionFromRole(ctx, roleID, permissionID) } -func (u RolePermissionUseCase) ReplaceRolePermissions(ctx context.Context, roleID string, permissionIDs []string, grantedByUserID *string) error { +func (u *RolePermissionsUseCase) ReplaceRolePermissions(ctx context.Context, roleID string, permissionIDs []string, grantedByUserID *string) error { return u.service.ReplaceRolePermissions(ctx, roleID, permissionIDs, grantedByUserID) } - -func (u RolePermissionUseCase) ReplaceUserRoles(ctx context.Context, userID string, roleIDs []string, assignedByUserID *string) error { - return u.service.ReplaceUserRoles(ctx, userID, roleIDs, assignedByUserID) -} - -func (u RolePermissionUseCase) AssignRoleToUser(ctx context.Context, userID string, req types.AssignUserRoleRequest, assignedByUserID *string) error { - return u.service.AssignRoleToUser(ctx, userID, req, assignedByUserID) -} - -func (u RolePermissionUseCase) RemoveRoleFromUser(ctx context.Context, userID string, roleID string) error { - return u.service.RemoveRoleFromUser(ctx, userID, roleID) -} diff --git a/plugins/access-control/usecases/role_permission_usecase_test.go b/plugins/access-control/usecases/role_permission_usecase_test.go deleted file mode 100644 index 193c47da..00000000 --- a/plugins/access-control/usecases/role_permission_usecase_test.go +++ /dev/null @@ -1,1320 +0,0 @@ -package usecases - -import ( - "context" - "errors" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - - "github.com/Authula/authula/plugins/access-control/services" - "github.com/Authula/authula/plugins/access-control/types" -) - -func newRolePermissionUseCase(mockRepo *MockRolePermissionService) RolePermissionUseCase { - return NewRolePermissionUseCase(services.NewRolePermissionService(mockRepo)) -} - -func TestCreateRole(t *testing.T) { - testCases := []struct { - name string - req types.CreateRoleRequest - mockErr error - expectedError string - }{ - { - name: "success with all fields", - req: types.CreateRoleRequest{ - Name: "Admin", - Description: new("Administrator role"), - IsSystem: true, - }, - expectedError: "", - }, - { - name: "success with minimal fields", - req: types.CreateRoleRequest{ - Name: "User", - }, - expectedError: "", - }, - { - name: "empty name", - req: types.CreateRoleRequest{Name: ""}, - expectedError: "bad request", - }, - { - name: "whitespace only name", - req: types.CreateRoleRequest{Name: " "}, - expectedError: "bad request", - }, - { - name: "service error", - req: types.CreateRoleRequest{Name: "Test"}, - mockErr: errors.New("database error"), - expectedError: "database error", - }, - } - - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - mockService := &MockRolePermissionService{} - mockService.On("CreateRole", mock.Anything, mock.Anything).Return(tt.mockErr).Maybe() - - uc := newRolePermissionUseCase(mockService) - role, err := uc.CreateRole(context.Background(), tt.req) - - if tt.expectedError == "" { - assert.NoError(t, err) - assert.NotNil(t, role) - assert.Equal(t, tt.req.Name, role.Name) - } else { - assert.Error(t, err) - assert.Equal(t, tt.expectedError, err.Error()) - assert.Nil(t, role) - } - }) - } -} - -func TestGetAllRoles(t *testing.T) { - testCases := []struct { - name string - mockResult []types.Role - mockErr error - expectedError string - expectedLength int - }{ - { - name: "success with roles", - mockResult: []types.Role{ - {ID: "1", Name: "Admin"}, - {ID: "2", Name: "User"}, - }, - expectedLength: 2, - }, - { - name: "empty list", - mockResult: []types.Role{}, - expectedLength: 0, - }, - { - name: "service error", - mockErr: errors.New("database error"), - expectedError: "database error", - }, - } - - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - mockService := &MockRolePermissionService{} - mockService.On("GetAllRoles", mock.Anything).Return(tt.mockResult, tt.mockErr).Maybe() - - uc := newRolePermissionUseCase(mockService) - roles, err := uc.GetAllRoles(context.Background()) - - if tt.expectedError == "" { - assert.NoError(t, err) - assert.Equal(t, tt.expectedLength, len(roles)) - } else { - assert.Error(t, err) - assert.Equal(t, tt.expectedError, err.Error()) - } - - mockService.AssertExpectations(t) - }) - } -} - -func TestGetRoleByID(t *testing.T) { - testCases := []struct { - name string - roleID string - mockRoleResult *types.Role - mockRoleErr error - mockPermErr error - mockPermissions []types.UserPermissionInfo - expectedError string - expectedHasRole bool - }{ - { - name: "success with permissions", - roleID: "role-1", - mockRoleResult: &types.Role{ - ID: "role-1", - Name: "Admin", - }, - mockPermissions: []types.UserPermissionInfo{ - {PermissionID: "perm-1", PermissionKey: "admin:read"}, - }, - expectedHasRole: true, - }, - { - name: "empty roleID", - roleID: "", - expectedError: "bad request", - }, - { - name: "whitespace roleID", - roleID: " ", - expectedError: "bad request", - }, - { - name: "not found", - roleID: "nonexistent", - mockRoleResult: nil, - expectedError: "not found", - }, - { - name: "service error on role fetch", - roleID: "role-1", - mockRoleErr: errors.New("db error"), - expectedError: "db error", - }, - { - name: "service error on permissions fetch", - roleID: "role-1", - mockRoleResult: &types.Role{ - ID: "role-1", - Name: "Admin", - }, - mockPermErr: errors.New("perm db error"), - expectedError: "perm db error", - }, - } - - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - mockService := &MockRolePermissionService{} - mockService.On("GetRoleByID", mock.Anything, mock.Anything).Return(tt.mockRoleResult, tt.mockRoleErr).Maybe() - mockService.On("GetRolePermissions", mock.Anything, mock.Anything).Return(tt.mockPermissions, tt.mockPermErr).Maybe() - - uc := newRolePermissionUseCase(mockService) - result, err := uc.GetRoleByID(context.Background(), tt.roleID) - - if tt.expectedError == "" { - assert.NoError(t, err) - assert.NotNil(t, result) - assert.Equal(t, len(tt.mockPermissions), len(result.Permissions)) - } else { - assert.Error(t, err) - assert.Equal(t, tt.expectedError, err.Error()) - assert.Nil(t, result) - } - }) - } -} - -func TestUpdateRole(t *testing.T) { - testCases := []struct { - name string - roleID string - req types.UpdateRoleRequest - mockRole *types.Role - mockGetErr error - mockUpdateErr error - expectedError string - shouldCallUpdate bool - }{ - { - name: "update name only", - roleID: "role-1", - req: types.UpdateRoleRequest{ - Name: new("NewAdmin"), - }, - mockRole: &types.Role{ID: "role-1", Name: "Admin", IsSystem: false}, - shouldCallUpdate: true, - }, - { - name: "update description only", - roleID: "role-1", - req: types.UpdateRoleRequest{ - Description: new("New description"), - }, - mockRole: &types.Role{ID: "role-1", Name: "Admin", IsSystem: false}, - shouldCallUpdate: true, - }, - { - name: "update both", - roleID: "role-1", - req: types.UpdateRoleRequest{ - Name: new("NewAdmin"), - Description: new("New description"), - }, - mockRole: &types.Role{ID: "role-1", Name: "Admin", IsSystem: false}, - shouldCallUpdate: true, - }, - { - name: "empty roleID", - roleID: "", - expectedError: "bad request", - }, - { - name: "whitespace roleID", - roleID: " ", - req: types.UpdateRoleRequest{ - Name: new("Admin"), - }, - expectedError: "bad request", - }, - { - name: "no fields provided", - roleID: "role-1", - req: types.UpdateRoleRequest{}, - expectedError: "unprocessable entity", - }, - { - name: "empty name", - roleID: "role-1", - req: types.UpdateRoleRequest{ - Name: new(" "), - }, - mockRole: &types.Role{ID: "role-1", Name: "Admin", IsSystem: false}, - expectedError: "bad request", - }, - { - name: "not found", - roleID: "role-1", - req: types.UpdateRoleRequest{Name: new("Admin")}, - mockRole: nil, - expectedError: "not found", - }, - { - name: "system role protection", - roleID: "role-1", - req: types.UpdateRoleRequest{ - Name: new("NewAdmin"), - }, - mockRole: &types.Role{ID: "role-1", Name: "Admin", IsSystem: true}, - expectedError: "cannot update system role", - }, - { - name: "service error on get", - roleID: "role-1", - req: types.UpdateRoleRequest{Name: new("Admin")}, - mockGetErr: errors.New("db error"), - expectedError: "db error", - }, - { - name: "service error on update", - roleID: "role-1", - req: types.UpdateRoleRequest{ - Name: new("NewAdmin"), - }, - mockRole: &types.Role{ID: "role-1", Name: "Admin", IsSystem: false}, - mockUpdateErr: errors.New("update failed"), - expectedError: "update failed", - shouldCallUpdate: true, - }, - } - - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - mockService := &MockRolePermissionService{} - mockService.On("GetRoleByID", mock.Anything, mock.Anything).Return(tt.mockRole, tt.mockGetErr).Maybe() - mockService.On("UpdateRole", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(true, tt.mockUpdateErr).Maybe() - - uc := newRolePermissionUseCase(mockService) - role, err := uc.UpdateRole(context.Background(), tt.roleID, tt.req) - - if tt.expectedError == "" { - assert.NoError(t, err) - assert.NotNil(t, role) - } else { - assert.Error(t, err) - assert.Equal(t, tt.expectedError, err.Error()) - assert.Nil(t, role) - } - - mockService.AssertExpectations(t) - }) - } -} - -func TestDeleteRole(t *testing.T) { - testCases := []struct { - name string - roleID string - mockRole *types.Role - mockGetErr error - mockCountResult int - mockCountErr error - mockDeleteErr error - expectedError string - shouldCallDelete bool - }{ - { - name: "success", - roleID: "role-1", - mockRole: &types.Role{ID: "role-1", Name: "User", IsSystem: false}, - mockCountResult: 0, - shouldCallDelete: true, - }, - { - name: "empty roleID", - roleID: "", - expectedError: "bad request", - }, - { - name: "whitespace roleID", - roleID: " ", - expectedError: "bad request", - }, - { - name: "not found", - roleID: "role-1", - mockRole: nil, - expectedError: "not found", - }, - { - name: "system role protection", - roleID: "role-1", - mockRole: &types.Role{ID: "role-1", Name: "Admin", IsSystem: true}, - expectedError: "cannot update system role", - }, - { - name: "service error on get", - roleID: "role-1", - mockGetErr: errors.New("db error"), - expectedError: "db error", - }, - { - name: "has user assignments", - roleID: "role-1", - mockRole: &types.Role{ID: "role-1", Name: "User", IsSystem: false}, - mockCountResult: 5, - expectedError: "conflict", - }, - { - name: "service error on count", - roleID: "role-1", - mockRole: &types.Role{ID: "role-1", Name: "User", IsSystem: false}, - mockCountErr: errors.New("count error"), - expectedError: "count error", - }, - { - name: "service error on delete", - roleID: "role-1", - mockRole: &types.Role{ID: "role-1", Name: "User", IsSystem: false}, - mockCountResult: 0, - mockDeleteErr: errors.New("delete failed"), - expectedError: "delete failed", - shouldCallDelete: true, - }, - } - - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - mockService := &MockRolePermissionService{} - mockService.On("GetRoleByID", mock.Anything, mock.Anything).Return(tt.mockRole, tt.mockGetErr).Maybe() - mockService.On("CountUserAssignmentsByRoleID", mock.Anything, mock.Anything).Return(tt.mockCountResult, tt.mockCountErr).Maybe() - mockService.On("DeleteRole", mock.Anything, mock.Anything).Return(true, tt.mockDeleteErr).Maybe() - - uc := newRolePermissionUseCase(mockService) - err := uc.DeleteRole(context.Background(), tt.roleID) - - if tt.expectedError == "" { - assert.NoError(t, err) - } else { - assert.Error(t, err) - assert.Equal(t, tt.expectedError, err.Error()) - } - - mockService.AssertExpectations(t) - }) - } -} - -func TestCreatePermission(t *testing.T) { - testCases := []struct { - name string - req types.CreatePermissionRequest - mockErr error - expectedError string - }{ - { - name: "success with all fields", - req: types.CreatePermissionRequest{ - Key: "admin:write", - Description: new("Admin write access"), - IsSystem: true, - }, - expectedError: "", - }, - { - name: "success with minimal fields", - req: types.CreatePermissionRequest{ - Key: "user:read", - }, - expectedError: "", - }, - { - name: "empty key", - req: types.CreatePermissionRequest{Key: ""}, - expectedError: "bad request", - }, - { - name: "whitespace only key", - req: types.CreatePermissionRequest{Key: " "}, - expectedError: "bad request", - }, - { - name: "service error", - req: types.CreatePermissionRequest{Key: "admin:read"}, - mockErr: errors.New("database error"), - expectedError: "database error", - }, - } - - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - mockService := &MockRolePermissionService{} - mockService.On("CreatePermission", mock.Anything, mock.Anything).Return(tt.mockErr).Maybe() - - uc := newRolePermissionUseCase(mockService) - perm, err := uc.CreatePermission(context.Background(), tt.req) - - if tt.expectedError == "" { - assert.NoError(t, err) - assert.NotNil(t, perm) - } else { - assert.Error(t, err) - assert.Equal(t, tt.expectedError, err.Error()) - assert.Nil(t, perm) - } - - mockService.AssertExpectations(t) - }) - } -} - -func TestGetAllPermissions(t *testing.T) { - testCases := []struct { - name string - mockResult []types.Permission - mockErr error - expectedError string - expectedLength int - }{ - { - name: "success with permissions", - mockResult: []types.Permission{ - {ID: "1", Key: "admin:read"}, - {ID: "2", Key: "user:read"}, - }, - expectedLength: 2, - }, - { - name: "empty list", - mockResult: []types.Permission{}, - expectedLength: 0, - }, - { - name: "service error", - mockErr: errors.New("database error"), - expectedError: "database error", - }, - } - - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - mockService := &MockRolePermissionService{} - mockService.On("GetAllPermissions", mock.Anything).Return(tt.mockResult, tt.mockErr).Maybe() - - uc := newRolePermissionUseCase(mockService) - perms, err := uc.GetAllPermissions(context.Background()) - - if tt.expectedError == "" { - assert.NoError(t, err) - assert.Equal(t, tt.expectedLength, len(perms)) - } else { - assert.Error(t, err) - assert.Equal(t, tt.expectedError, err.Error()) - } - - mockService.AssertExpectations(t) - }) - } -} - -func TestGetRolePermissions(t *testing.T) { - testCases := []struct { - name string - roleID string - mockRole *types.Role - mockRoleErr error - mockPermissions []types.UserPermissionInfo - mockPermErr error - expectedError string - expectedLength int - }{ - { - name: "success with permissions", - roleID: "role-1", - mockRole: &types.Role{ - ID: "role-1", - Name: "Admin", - }, - mockPermissions: []types.UserPermissionInfo{ - {PermissionID: "perm-1", PermissionKey: "admin:read"}, - {PermissionID: "perm-2", PermissionKey: "admin:write"}, - }, - expectedLength: 2, - }, - { - name: "empty roleID", - roleID: "", - expectedError: "unprocessable entity", - }, - { - name: "whitespace roleID", - roleID: " ", - expectedError: "unprocessable entity", - }, - { - name: "role not found", - roleID: "missing-role", - mockRole: nil, - expectedError: "not found", - }, - { - name: "role lookup error", - roleID: "role-1", - mockRoleErr: errors.New("db error"), - expectedError: "db error", - }, - { - name: "permissions lookup error", - roleID: "role-1", - mockRole: &types.Role{ - ID: "role-1", - Name: "Admin", - }, - mockPermErr: errors.New("perm db error"), - expectedError: "perm db error", - }, - } - - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - mockService := &MockRolePermissionService{} - mockService.On("GetRoleByID", mock.Anything, mock.Anything).Return(tt.mockRole, tt.mockRoleErr).Maybe() - mockService.On("GetRolePermissions", mock.Anything, mock.Anything).Return(tt.mockPermissions, tt.mockPermErr).Maybe() - - uc := newRolePermissionUseCase(mockService) - permissions, err := uc.GetRolePermissions(context.Background(), tt.roleID) - - if tt.expectedError == "" { - assert.NoError(t, err) - assert.Equal(t, tt.expectedLength, len(permissions)) - } else { - assert.Error(t, err) - assert.Equal(t, tt.expectedError, err.Error()) - assert.Nil(t, permissions) - } - }) - } -} - -func TestUpdatePermission(t *testing.T) { - testCases := []struct { - name string - permissionID string - req types.UpdatePermissionRequest - mockPerm *types.Permission - mockGetErr error - mockUpdateErr error - expectedError string - shouldCallUpdate bool - }{ - { - name: "success", - permissionID: "perm-1", - req: types.UpdatePermissionRequest{ - Description: new("New description"), - }, - mockPerm: &types.Permission{ID: "perm-1", Key: "admin:read", IsSystem: false}, - shouldCallUpdate: true, - }, - { - name: "empty permissionID", - permissionID: "", - expectedError: "unprocessable entity", - }, - { - name: "whitespace permissionID", - permissionID: " ", - expectedError: "unprocessable entity", - }, - { - name: "nil description", - permissionID: "perm-1", - req: types.UpdatePermissionRequest{ - Description: nil, - }, - expectedError: "unprocessable entity", - }, - { - name: "empty description", - permissionID: "perm-1", - req: types.UpdatePermissionRequest{ - Description: new(" "), - }, - expectedError: "unprocessable entity", - }, - { - name: "not found", - permissionID: "perm-1", - req: types.UpdatePermissionRequest{Description: new("New")}, - mockPerm: nil, - expectedError: "not found", - }, - { - name: "system permission protection", - permissionID: "perm-1", - req: types.UpdatePermissionRequest{ - Description: new("New"), - }, - mockPerm: &types.Permission{ID: "perm-1", Key: "admin:read", IsSystem: true}, - expectedError: "bad request", - }, - { - name: "service error on get", - permissionID: "perm-1", - req: types.UpdatePermissionRequest{Description: new("New")}, - mockGetErr: errors.New("db error"), - expectedError: "db error", - }, - { - name: "service error on update", - permissionID: "perm-1", - req: types.UpdatePermissionRequest{ - Description: new("New"), - }, - mockPerm: &types.Permission{ID: "perm-1", Key: "admin:read", IsSystem: false}, - mockUpdateErr: errors.New("update failed"), - expectedError: "update failed", - shouldCallUpdate: true, - }, - } - - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - mockService := &MockRolePermissionService{} - mockService.On("GetPermissionByID", mock.Anything, mock.Anything).Return(tt.mockPerm, tt.mockGetErr).Maybe() - mockService.On("UpdatePermission", mock.Anything, mock.Anything, mock.Anything).Return(true, tt.mockUpdateErr).Maybe() - - uc := newRolePermissionUseCase(mockService) - perm, err := uc.UpdatePermission(context.Background(), tt.permissionID, tt.req) - - if tt.expectedError == "" { - assert.NoError(t, err) - assert.NotNil(t, perm) - } else { - assert.Error(t, err) - assert.Equal(t, tt.expectedError, err.Error()) - assert.Nil(t, perm) - } - - mockService.AssertExpectations(t) - }) - } -} - -func TestDeletePermission(t *testing.T) { - testCases := []struct { - name string - permissionID string - mockPerm *types.Permission - mockGetErr error - mockCountResult int - mockCountErr error - mockDeleteErr error - expectedError string - shouldCallDelete bool - }{ - { - name: "success", - permissionID: "perm-1", - mockPerm: &types.Permission{ID: "perm-1", Key: "user:read", IsSystem: false}, - mockCountResult: 0, - shouldCallDelete: true, - }, - { - name: "empty permissionID", - permissionID: "", - expectedError: "bad request", - }, - { - name: "whitespace permissionID", - permissionID: " ", - expectedError: "bad request", - }, - { - name: "not found", - permissionID: "perm-1", - mockPerm: nil, - expectedError: "not found", - }, - { - name: "system permission protection", - permissionID: "perm-1", - mockPerm: &types.Permission{ID: "perm-1", Key: "admin:read", IsSystem: true}, - expectedError: "bad request", - }, - { - name: "service error on get", - permissionID: "perm-1", - mockGetErr: errors.New("db error"), - expectedError: "db error", - }, - { - name: "has role assignments", - permissionID: "perm-1", - mockPerm: &types.Permission{ID: "perm-1", Key: "user:read", IsSystem: false}, - mockCountResult: 3, - expectedError: "conflict", - }, - { - name: "service error on count", - permissionID: "perm-1", - mockPerm: &types.Permission{ID: "perm-1", Key: "user:read", IsSystem: false}, - mockCountErr: errors.New("count error"), - expectedError: "count error", - }, - { - name: "service error on delete", - permissionID: "perm-1", - mockPerm: &types.Permission{ID: "perm-1", Key: "user:read", IsSystem: false}, - mockCountResult: 0, - mockDeleteErr: errors.New("delete failed"), - expectedError: "delete failed", - shouldCallDelete: true, - }, - } - - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - mockService := &MockRolePermissionService{} - mockService.On("GetPermissionByID", mock.Anything, mock.Anything).Return(tt.mockPerm, tt.mockGetErr).Maybe() - mockService.On("CountRoleAssignmentsByPermissionID", mock.Anything, mock.Anything).Return(tt.mockCountResult, tt.mockCountErr).Maybe() - mockService.On("DeletePermission", mock.Anything, mock.Anything).Return(true, tt.mockDeleteErr).Maybe() - - uc := newRolePermissionUseCase(mockService) - err := uc.DeletePermission(context.Background(), tt.permissionID) - - if tt.expectedError == "" { - assert.NoError(t, err) - } else { - assert.Error(t, err) - assert.Equal(t, tt.expectedError, err.Error()) - } - - mockService.AssertExpectations(t) - }) - } -} - -func TestAddPermissionToRole(t *testing.T) { - testCases := []struct { - name string - roleID string - permissionID string - mockRole *types.Role - mockRoleErr error - mockPerm *types.Permission - mockPermErr error - mockAddErr error - expectedError string - shouldCallAddRole bool - }{ - { - name: "success", - roleID: "role-1", - permissionID: "perm-1", - mockRole: &types.Role{ID: "role-1", Name: "User", IsSystem: false}, - mockPerm: &types.Permission{ID: "perm-1", Key: "user:read", IsSystem: false}, - shouldCallAddRole: true, - }, - { - name: "empty roleID", - roleID: "", - permissionID: "perm-1", - expectedError: "bad request", - }, - { - name: "empty permissionID", - roleID: "role-1", - permissionID: "", - expectedError: "bad request", - }, - { - name: "whitespace roleID", - roleID: " ", - permissionID: "perm-1", - expectedError: "bad request", - }, - { - name: "whitespace permissionID", - roleID: "role-1", - permissionID: " ", - expectedError: "bad request", - }, - { - name: "not found", - roleID: "role-1", - permissionID: "perm-1", - mockRole: nil, - expectedError: "not found", - }, - { - name: "not found", - roleID: "role-1", - permissionID: "perm-1", - mockRole: &types.Role{ID: "role-1", Name: "User", IsSystem: false}, - mockPerm: nil, - expectedError: "not found", - }, - { - name: "system role protection", - roleID: "role-1", - permissionID: "perm-1", - mockRole: &types.Role{ID: "role-1", Name: "Admin", IsSystem: true}, - mockPerm: &types.Permission{ID: "perm-1", Key: "admin:read", IsSystem: false}, - expectedError: "bad request", - }, - { - name: "system permission protection", - roleID: "role-1", - permissionID: "perm-1", - mockRole: &types.Role{ID: "role-1", Name: "User", IsSystem: false}, - mockPerm: &types.Permission{ID: "perm-1", Key: "admin:read", IsSystem: true}, - expectedError: "bad request", - }, - { - name: "service error on add", - roleID: "role-1", - permissionID: "perm-1", - mockRole: &types.Role{ID: "role-1", Name: "User", IsSystem: false}, - mockPerm: &types.Permission{ID: "perm-1", Key: "user:read", IsSystem: false}, - mockAddErr: errors.New("add failed"), - expectedError: "add failed", - shouldCallAddRole: true, - }, - } - - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - mockService := &MockRolePermissionService{} - mockService.On("GetRoleByID", mock.Anything, mock.Anything).Return(tt.mockRole, tt.mockRoleErr).Maybe() - mockService.On("GetPermissionByID", mock.Anything, mock.Anything).Return(tt.mockPerm, tt.mockPermErr).Maybe() - mockService.On("AddRolePermission", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.mockAddErr).Maybe() - - uc := newRolePermissionUseCase(mockService) - err := uc.AddPermissionToRole(context.Background(), tt.roleID, tt.permissionID, nil) - - if tt.expectedError == "" { - assert.NoError(t, err) - } else { - assert.Error(t, err) - assert.Equal(t, tt.expectedError, err.Error()) - } - - mockService.AssertExpectations(t) - }) - } -} - -func TestRemovePermissionFromRole(t *testing.T) { - testCases := []struct { - name string - roleID string - permissionID string - mockRole *types.Role - mockRoleErr error - mockPerm *types.Permission - mockPermErr error - mockRemoveErr error - expectedError string - shouldCallRemoveRole bool - }{ - { - name: "success", - roleID: "role-1", - permissionID: "perm-1", - mockRole: &types.Role{ID: "role-1", Name: "User", IsSystem: false}, - mockPerm: &types.Permission{ID: "perm-1", Key: "user:read", IsSystem: false}, - shouldCallRemoveRole: true, - }, - { - name: "empty roleID", - roleID: "", - permissionID: "perm-1", - expectedError: "unprocessable entity", - }, - { - name: "empty permissionID", - roleID: "role-1", - permissionID: "", - expectedError: "unprocessable entity", - }, - { - name: "system role protection", - roleID: "role-1", - permissionID: "perm-1", - mockRole: &types.Role{ID: "role-1", Name: "Admin", IsSystem: true}, - mockPerm: &types.Permission{ID: "perm-1", Key: "admin:read", IsSystem: false}, - expectedError: "bad request", - }, - { - name: "system permission protection", - roleID: "role-1", - permissionID: "perm-1", - mockRole: &types.Role{ID: "role-1", Name: "User", IsSystem: false}, - mockPerm: &types.Permission{ID: "perm-1", Key: "admin:read", IsSystem: true}, - expectedError: "bad request", - }, - { - name: "service error on remove", - roleID: "role-1", - permissionID: "perm-1", - mockRole: &types.Role{ID: "role-1", Name: "User", IsSystem: false}, - mockPerm: &types.Permission{ID: "perm-1", Key: "user:read", IsSystem: false}, - mockRemoveErr: errors.New("remove failed"), - expectedError: "remove failed", - shouldCallRemoveRole: true, - }, - } - - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - mockService := &MockRolePermissionService{} - mockService.On("GetRoleByID", mock.Anything, mock.Anything).Return(tt.mockRole, tt.mockRoleErr).Maybe() - mockService.On("GetPermissionByID", mock.Anything, mock.Anything).Return(tt.mockPerm, tt.mockPermErr).Maybe() - mockService.On("RemoveRolePermission", mock.Anything, mock.Anything, mock.Anything).Return(tt.mockRemoveErr).Maybe() - - uc := newRolePermissionUseCase(mockService) - err := uc.RemovePermissionFromRole(context.Background(), tt.roleID, tt.permissionID) - - if tt.expectedError == "" { - assert.NoError(t, err) - } else { - assert.Error(t, err) - assert.Equal(t, tt.expectedError, err.Error()) - } - - mockService.AssertExpectations(t) - }) - } -} - -func TestReplaceRolePermissions(t *testing.T) { - testCases := []struct { - name string - roleID string - permissionIDs []string - mockReplaceErr error - expectedError string - shouldCallReplace bool - }{ - { - name: "success with permissions", - roleID: "role-1", - permissionIDs: []string{"perm-1", "perm-2"}, - shouldCallReplace: true, - }, - { - name: "success with empty list", - roleID: "role-1", - permissionIDs: []string{}, - shouldCallReplace: true, - }, - { - name: "empty roleID", - roleID: "", - permissionIDs: []string{"perm-1"}, - expectedError: "bad request", - }, - { - name: "whitespace roleID", - roleID: " ", - permissionIDs: []string{"perm-1"}, - expectedError: "bad request", - }, - { - name: "deduplication", - roleID: "role-1", - permissionIDs: []string{"perm-1", "perm-1", "perm-2"}, - shouldCallReplace: true, - }, - { - name: "filters empty strings", - roleID: "role-1", - permissionIDs: []string{"perm-1", "", " ", "perm-2"}, - shouldCallReplace: true, - }, - { - name: "service error", - roleID: "role-1", - permissionIDs: []string{"perm-1"}, - mockReplaceErr: errors.New("db error"), - expectedError: "db error", - shouldCallReplace: true, - }, - } - - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - mockService := &MockRolePermissionService{} - mockService.On("ReplaceRolePermissions", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.mockReplaceErr).Maybe() - - uc := newRolePermissionUseCase(mockService) - err := uc.ReplaceRolePermissions(context.Background(), tt.roleID, tt.permissionIDs, nil) - - if tt.expectedError == "" { - assert.NoError(t, err) - } else { - assert.Error(t, err) - assert.Equal(t, tt.expectedError, err.Error()) - } - - mockService.AssertExpectations(t) - }) - } -} - -func TestReplaceUserRoles(t *testing.T) { - testCases := []struct { - name string - userID string - roleIDs []string - mockReplaceErr error - expectedError string - shouldCallReplace bool - }{ - { - name: "success with roles", - userID: "user-1", - roleIDs: []string{"role-1", "role-2"}, - shouldCallReplace: true, - }, - { - name: "success with empty list", - userID: "user-1", - roleIDs: []string{}, - shouldCallReplace: true, - }, - { - name: "empty userID", - userID: "", - roleIDs: []string{"role-1"}, - expectedError: "bad request", - }, - { - name: "whitespace userID", - userID: " ", - roleIDs: []string{"role-1"}, - expectedError: "bad request", - }, - { - name: "deduplication", - userID: "user-1", - roleIDs: []string{"role-1", "role-1", "role-2"}, - shouldCallReplace: true, - }, - { - name: "filters empty strings", - userID: "user-1", - roleIDs: []string{"role-1", "", " ", "role-2"}, - shouldCallReplace: true, - }, - { - name: "service error", - userID: "user-1", - roleIDs: []string{"role-1"}, - mockReplaceErr: errors.New("db error"), - expectedError: "db error", - shouldCallReplace: true, - }, - } - - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - mockService := &MockRolePermissionService{} - mockService.On("ReplaceUserRoles", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.mockReplaceErr).Maybe() - - uc := newRolePermissionUseCase(mockService) - err := uc.ReplaceUserRoles(context.Background(), tt.userID, tt.roleIDs, nil) - - if tt.expectedError == "" { - assert.NoError(t, err) - } else { - assert.Error(t, err) - assert.Equal(t, tt.expectedError, err.Error()) - } - - mockService.AssertExpectations(t) - }) - } -} - -func TestAssignRoleToUser(t *testing.T) { - futureTime := time.Now().UTC().Add(24 * time.Hour) - pastTime := time.Now().UTC().Add(-24 * time.Hour) - - testCases := []struct { - name string - userID string - req types.AssignUserRoleRequest - mockAssignErr error - expectedError string - shouldCallAssign bool - }{ - { - name: "success without expiry", - userID: "user-1", - req: types.AssignUserRoleRequest{ - RoleID: "role-1", - }, - shouldCallAssign: true, - }, - { - name: "success with future expiry", - userID: "user-1", - req: types.AssignUserRoleRequest{ - RoleID: "role-1", - ExpiresAt: &futureTime, - }, - shouldCallAssign: true, - }, - { - name: "empty userID", - userID: "", - req: types.AssignUserRoleRequest{RoleID: "role-1"}, - expectedError: "unprocessable entity", - }, - { - name: "whitespace userID", - userID: " ", - req: types.AssignUserRoleRequest{RoleID: "role-1"}, - expectedError: "unprocessable entity", - }, - { - name: "empty roleID", - userID: "user-1", - req: types.AssignUserRoleRequest{ - RoleID: "", - }, - expectedError: "unprocessable entity", - }, - { - name: "past expiry date", - userID: "user-1", - req: types.AssignUserRoleRequest{ - RoleID: "role-1", - ExpiresAt: &pastTime, - }, - expectedError: "bad request", - }, - { - name: "service error on assign", - userID: "user-1", - req: types.AssignUserRoleRequest{ - RoleID: "role-1", - }, - mockAssignErr: errors.New("assign failed"), - expectedError: "assign failed", - shouldCallAssign: true, - }, - } - - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - mockService := &MockRolePermissionService{} - mockService.On("AssignUserRole", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.mockAssignErr).Maybe() - - uc := newRolePermissionUseCase(mockService) - err := uc.AssignRoleToUser(context.Background(), tt.userID, tt.req, nil) - - if tt.expectedError == "" { - assert.NoError(t, err) - } else { - assert.Error(t, err) - assert.Equal(t, tt.expectedError, err.Error()) - } - - mockService.AssertExpectations(t) - }) - } -} - -func TestRemoveRoleFromUser(t *testing.T) { - testCases := []struct { - name string - userID string - roleID string - mockRemoveErr error - expectedError string - shouldCallRemove bool - }{ - { - name: "success", - userID: "user-1", - roleID: "role-1", - shouldCallRemove: true, - }, - { - name: "empty userID", - userID: "", - roleID: "role-1", - expectedError: "bad request", - }, - { - name: "empty roleID", - userID: "user-1", - roleID: "", - expectedError: "bad request", - }, - { - name: "whitespace userID", - userID: " ", - roleID: "role-1", - expectedError: "bad request", - }, - { - name: "whitespace roleID", - userID: "user-1", - roleID: " ", - expectedError: "bad request", - }, - { - name: "service error", - userID: "user-1", - roleID: "role-1", - mockRemoveErr: errors.New("remove failed"), - expectedError: "remove failed", - shouldCallRemove: true, - }, - } - - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - mockService := &MockRolePermissionService{} - mockService.On("RemoveUserRole", mock.Anything, mock.Anything, mock.Anything).Return(tt.mockRemoveErr).Maybe() - - uc := newRolePermissionUseCase(mockService) - err := uc.RemoveRoleFromUser(context.Background(), tt.userID, tt.roleID) - - if tt.expectedError == "" { - assert.NoError(t, err) - } else { - assert.Error(t, err) - assert.Equal(t, tt.expectedError, err.Error()) - } - - mockService.AssertExpectations(t) - }) - } -} diff --git a/plugins/access-control/usecases/roles_usecase.go b/plugins/access-control/usecases/roles_usecase.go new file mode 100644 index 00000000..d0c9b92a --- /dev/null +++ b/plugins/access-control/usecases/roles_usecase.go @@ -0,0 +1,36 @@ +package usecases + +import ( + "context" + + "github.com/Authula/authula/plugins/access-control/services" + "github.com/Authula/authula/plugins/access-control/types" +) + +type RolesUseCase struct { + service *services.RolesService +} + +func NewRolesUseCase(service *services.RolesService) *RolesUseCase { + return &RolesUseCase{service: service} +} + +func (u *RolesUseCase) CreateRole(ctx context.Context, req types.CreateRoleRequest) (*types.Role, error) { + return u.service.CreateRole(ctx, req) +} + +func (u *RolesUseCase) GetAllRoles(ctx context.Context) ([]types.Role, error) { + return u.service.GetAllRoles(ctx) +} + +func (u *RolesUseCase) GetRoleByID(ctx context.Context, roleID string) (*types.RoleDetails, error) { + return u.service.GetRoleByID(ctx, roleID) +} + +func (u *RolesUseCase) UpdateRole(ctx context.Context, roleID string, req types.UpdateRoleRequest) (*types.Role, error) { + return u.service.UpdateRole(ctx, roleID, req) +} + +func (u *RolesUseCase) DeleteRole(ctx context.Context, roleID string) error { + return u.service.DeleteRole(ctx, roleID) +} diff --git a/plugins/access-control/usecases/usecases.go b/plugins/access-control/usecases/usecases.go index 44fef050..e50792d7 100644 --- a/plugins/access-control/usecases/usecases.go +++ b/plugins/access-control/usecases/usecases.go @@ -7,93 +7,135 @@ import ( ) type UseCases struct { - rolePermission RolePermissionUseCase - userAccess UserRolesUseCase -} - -func NewAccessControlUseCases(rolePermission RolePermissionUseCase, userAccess UserRolesUseCase) *UseCases { + roles *RolesUseCase + permissions *PermissionsUseCase + rolePermissions *RolePermissionsUseCase + userRoles *UserRolesUseCase + userAccess *UserAccessUseCase +} + +func NewAccessControlUseCases( + roles *RolesUseCase, + permissions *PermissionsUseCase, + rolePermissions *RolePermissionsUseCase, + userRoles *UserRolesUseCase, + userAccess *UserAccessUseCase, +) *UseCases { return &UseCases{ - rolePermission: rolePermission, - userAccess: userAccess, + roles: roles, + permissions: permissions, + rolePermissions: rolePermissions, + userRoles: userRoles, + userAccess: userAccess, } } -func (u *UseCases) RolePermissionUseCase() RolePermissionUseCase { - return u.rolePermission +func (u *UseCases) RolesUseCase() *RolesUseCase { + return u.roles } -func (u *UseCases) UserAccessUseCase() UserRolesUseCase { - return u.userAccess +func (u *UseCases) PermissionsUseCase() *PermissionsUseCase { + return u.permissions } -func (u *UseCases) CreatePermission(ctx context.Context, req types.CreatePermissionRequest) (*types.Permission, error) { - return u.rolePermission.CreatePermission(ctx, req) +func (u *UseCases) RolePermissionsUseCase() *RolePermissionsUseCase { + return u.rolePermissions } -func (u *UseCases) GetAllPermissions(ctx context.Context) ([]types.Permission, error) { - return u.rolePermission.GetAllPermissions(ctx) +func (u *UseCases) UserRolesUseCase() *UserRolesUseCase { + return u.userRoles } -func (u *UseCases) GetRolePermissions(ctx context.Context, roleID string) ([]types.UserPermissionInfo, error) { - return u.rolePermission.GetRolePermissions(ctx, roleID) +func (u *UseCases) UserAccessUseCase() *UserAccessUseCase { + return u.userAccess } -func (u *UseCases) UpdatePermission(ctx context.Context, permissionID string, req types.UpdatePermissionRequest) (*types.Permission, error) { - return u.rolePermission.UpdatePermission(ctx, permissionID, req) -} - -func (u *UseCases) DeletePermission(ctx context.Context, permissionID string) error { - return u.rolePermission.DeletePermission(ctx, permissionID) -} +// Roles func (u *UseCases) CreateRole(ctx context.Context, req types.CreateRoleRequest) (*types.Role, error) { - return u.rolePermission.CreateRole(ctx, req) + return u.roles.CreateRole(ctx, req) } func (u *UseCases) GetAllRoles(ctx context.Context) ([]types.Role, error) { - return u.rolePermission.GetAllRoles(ctx) + return u.roles.GetAllRoles(ctx) } func (u *UseCases) GetRoleByID(ctx context.Context, roleID string) (*types.RoleDetails, error) { - return u.rolePermission.GetRoleByID(ctx, roleID) + return u.roles.GetRoleByID(ctx, roleID) } func (u *UseCases) UpdateRole(ctx context.Context, roleID string, req types.UpdateRoleRequest) (*types.Role, error) { - return u.rolePermission.UpdateRole(ctx, roleID, req) + return u.roles.UpdateRole(ctx, roleID, req) } func (u *UseCases) DeleteRole(ctx context.Context, roleID string) error { - return u.rolePermission.DeleteRole(ctx, roleID) + return u.roles.DeleteRole(ctx, roleID) +} + +// Permissions + +func (u *UseCases) CreatePermission(ctx context.Context, req types.CreatePermissionRequest) (*types.Permission, error) { + return u.permissions.CreatePermission(ctx, req) +} + +func (u *UseCases) GetAllPermissions(ctx context.Context) ([]types.Permission, error) { + return u.permissions.GetAllPermissions(ctx) +} + +func (u *UseCases) GetPermissionByID(ctx context.Context, permissionID string) (*types.Permission, error) { + return u.permissions.GetPermissionByID(ctx, permissionID) +} + +func (u *UseCases) UpdatePermission(ctx context.Context, permissionID string, req types.UpdatePermissionRequest) (*types.Permission, error) { + return u.permissions.UpdatePermission(ctx, permissionID, req) +} + +func (u *UseCases) DeletePermission(ctx context.Context, permissionID string) error { + return u.permissions.DeletePermission(ctx, permissionID) +} + +// Role Permissions + +func (u *UseCases) GetRolePermissions(ctx context.Context, roleID string) ([]types.UserPermissionInfo, error) { + return u.rolePermissions.GetRolePermissions(ctx, roleID) } func (u *UseCases) AddPermissionToRole(ctx context.Context, roleID string, permissionID string, grantedByUserID *string) error { - return u.rolePermission.AddPermissionToRole(ctx, roleID, permissionID, grantedByUserID) + return u.rolePermissions.AddPermissionToRole(ctx, roleID, permissionID, grantedByUserID) } func (u *UseCases) RemovePermissionFromRole(ctx context.Context, roleID string, permissionID string) error { - return u.rolePermission.RemovePermissionFromRole(ctx, roleID, permissionID) + return u.rolePermissions.RemovePermissionFromRole(ctx, roleID, permissionID) } func (u *UseCases) ReplaceRolePermissions(ctx context.Context, roleID string, permissionIDs []string, grantedByUserID *string) error { - return u.rolePermission.ReplaceRolePermissions(ctx, roleID, permissionIDs, grantedByUserID) + return u.rolePermissions.ReplaceRolePermissions(ctx, roleID, permissionIDs, grantedByUserID) +} + +// User Roles + +func (u *UseCases) GetUserRoles(ctx context.Context, userID string) ([]types.UserRoleInfo, error) { + return u.userRoles.GetUserRoles(ctx, userID) } func (u *UseCases) ReplaceUserRoles(ctx context.Context, userID string, roleIDs []string, assignedByUserID *string) error { - return u.rolePermission.ReplaceUserRoles(ctx, userID, roleIDs, assignedByUserID) + return u.userRoles.ReplaceUserRoles(ctx, userID, roleIDs, assignedByUserID) } func (u *UseCases) AssignRoleToUser(ctx context.Context, userID string, req types.AssignUserRoleRequest, assignedByUserID *string) error { - return u.rolePermission.AssignRoleToUser(ctx, userID, req, assignedByUserID) + return u.userRoles.AssignRoleToUser(ctx, userID, req, assignedByUserID) } func (u *UseCases) RemoveRoleFromUser(ctx context.Context, userID string, roleID string) error { - return u.rolePermission.RemoveRoleFromUser(ctx, userID, roleID) + return u.userRoles.RemoveRoleFromUser(ctx, userID, roleID) } -func (u *UseCases) GetUserRoles(ctx context.Context, userID string) ([]types.UserRoleInfo, error) { - return u.userAccess.GetUserRoles(ctx, userID) +func (u *UseCases) GetUserWithRolesByID(ctx context.Context, userID string) (*types.UserWithRoles, error) { + return u.userRoles.GetUserWithRolesByID(ctx, userID) } +// User Access + func (u *UseCases) GetUserEffectivePermissions(ctx context.Context, userID string) ([]types.UserPermissionInfo, error) { return u.userAccess.GetUserEffectivePermissions(ctx, userID) } @@ -102,10 +144,6 @@ func (u *UseCases) HasPermissions(ctx context.Context, userID string, requiredPe return u.userAccess.HasPermissions(ctx, userID, requiredPermissions) } -func (u *UseCases) GetUserWithRolesByID(ctx context.Context, userID string) (*types.UserWithRoles, error) { - return u.userAccess.GetUserWithRolesByID(ctx, userID) -} - func (u *UseCases) GetUserWithPermissionsByID(ctx context.Context, userID string) (*types.UserWithPermissions, error) { return u.userAccess.GetUserWithPermissionsByID(ctx, userID) } diff --git a/plugins/access-control/usecases/user_access_usecase.go b/plugins/access-control/usecases/user_access_usecase.go index 0fc5b334..8a7c5acf 100644 --- a/plugins/access-control/usecases/user_access_usecase.go +++ b/plugins/access-control/usecases/user_access_usecase.go @@ -7,34 +7,26 @@ import ( "github.com/Authula/authula/plugins/access-control/types" ) -type UserRolesUseCase struct { +type UserAccessUseCase struct { service *services.UserAccessService } -func NewUserRolesUseCase(service *services.UserAccessService) UserRolesUseCase { - return UserRolesUseCase{service: service} +func NewUserAccessUseCase(service *services.UserAccessService) *UserAccessUseCase { + return &UserAccessUseCase{service: service} } -func (u UserRolesUseCase) GetUserRoles(ctx context.Context, userID string) ([]types.UserRoleInfo, error) { - return u.service.GetUserRoles(ctx, userID) -} - -func (u UserRolesUseCase) GetUserWithRolesByID(ctx context.Context, userID string) (*types.UserWithRoles, error) { - return u.service.GetUserWithRolesByID(ctx, userID) -} - -func (u UserRolesUseCase) GetUserWithPermissionsByID(ctx context.Context, userID string) (*types.UserWithPermissions, error) { +func (u *UserAccessUseCase) GetUserWithPermissionsByID(ctx context.Context, userID string) (*types.UserWithPermissions, error) { return u.service.GetUserWithPermissionsByID(ctx, userID) } -func (u UserRolesUseCase) GetUserAuthorizationProfile(ctx context.Context, userID string) (*types.UserAuthorizationProfile, error) { +func (u *UserAccessUseCase) GetUserAuthorizationProfile(ctx context.Context, userID string) (*types.UserAuthorizationProfile, error) { return u.service.GetUserAuthorizationProfile(ctx, userID) } -func (u UserRolesUseCase) GetUserEffectivePermissions(ctx context.Context, userID string) ([]types.UserPermissionInfo, error) { +func (u *UserAccessUseCase) GetUserEffectivePermissions(ctx context.Context, userID string) ([]types.UserPermissionInfo, error) { return u.service.GetUserEffectivePermissions(ctx, userID) } -func (u UserRolesUseCase) HasPermissions(ctx context.Context, userID string, requiredPermissions []string) (bool, error) { +func (u *UserAccessUseCase) HasPermissions(ctx context.Context, userID string, requiredPermissions []string) (bool, error) { return u.service.HasPermissions(ctx, userID, requiredPermissions) } diff --git a/plugins/access-control/usecases/user_access_usecase_test.go b/plugins/access-control/usecases/user_access_usecase_test.go deleted file mode 100644 index ef9814cb..00000000 --- a/plugins/access-control/usecases/user_access_usecase_test.go +++ /dev/null @@ -1,42 +0,0 @@ -package usecases - -import ( - "context" - "testing" - - "github.com/Authula/authula/plugins/access-control/services" - "github.com/Authula/authula/plugins/access-control/types" -) - -type stubUserAccessRepo struct { - permissions []types.UserPermissionInfo -} - -func (s *stubUserAccessRepo) GetUserRoles(_ context.Context, _ string) ([]types.UserRoleInfo, error) { - return []types.UserRoleInfo{{RoleID: "role-1", RoleName: "admin"}}, nil -} - -func (s *stubUserAccessRepo) GetUserEffectivePermissions(_ context.Context, _ string) ([]types.UserPermissionInfo, error) { - return s.permissions, nil -} - -func (s *stubUserAccessRepo) GetUserWithRolesByID(_ context.Context, _ string) (*types.UserWithRoles, error) { - return &types.UserWithRoles{}, nil -} - -func (s *stubUserAccessRepo) GetUserWithPermissionsByID(_ context.Context, _ string) (*types.UserWithPermissions, error) { - return &types.UserWithPermissions{}, nil -} - -func TestUserRolesUseCaseHasPermissionsPassThrough(t *testing.T) { - repo := &stubUserAccessRepo{permissions: []types.UserPermissionInfo{{PermissionKey: "users.read"}}} - uc := NewUserRolesUseCase(services.NewUserAccessService(repo)) - - ok, err := uc.HasPermissions(context.Background(), "user-1", []string{"users.read"}) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if !ok { - t.Fatal("expected permission check to pass") - } -} diff --git a/plugins/access-control/usecases/user_roles_usecase.go b/plugins/access-control/usecases/user_roles_usecase.go new file mode 100644 index 00000000..e4029b52 --- /dev/null +++ b/plugins/access-control/usecases/user_roles_usecase.go @@ -0,0 +1,36 @@ +package usecases + +import ( + "context" + + "github.com/Authula/authula/plugins/access-control/services" + "github.com/Authula/authula/plugins/access-control/types" +) + +type UserRolesUseCase struct { + service *services.UserRolesService +} + +func NewUserRolesUseCase(service *services.UserRolesService) *UserRolesUseCase { + return &UserRolesUseCase{service: service} +} + +func (u *UserRolesUseCase) GetUserRoles(ctx context.Context, userID string) ([]types.UserRoleInfo, error) { + return u.service.GetUserRoles(ctx, userID) +} + +func (u *UserRolesUseCase) GetUserWithRolesByID(ctx context.Context, userID string) (*types.UserWithRoles, error) { + return u.service.GetUserWithRolesByID(ctx, userID) +} + +func (u *UserRolesUseCase) ReplaceUserRoles(ctx context.Context, userID string, roleIDs []string, assignedByUserID *string) error { + return u.service.ReplaceUserRoles(ctx, userID, roleIDs, assignedByUserID) +} + +func (u *UserRolesUseCase) AssignRoleToUser(ctx context.Context, userID string, req types.AssignUserRoleRequest, assignedByUserID *string) error { + return u.service.AssignRoleToUser(ctx, userID, req, assignedByUserID) +} + +func (u *UserRolesUseCase) RemoveRoleFromUser(ctx context.Context, userID string, roleID string) error { + return u.service.RemoveRoleFromUser(ctx, userID, roleID) +} From bb45a2e5b85171b300cd2e0c9761d7ade9c12d14 Mon Sep 17 00:00:00 2001 From: Tanvir Ahmed Date: Tue, 31 Mar 2026 03:40:31 +0000 Subject: [PATCH 2/6] chore: Updates made across the whole plugin including adding new functionality --- plugins/access-control/api.go | 26 +- .../handlers/permission_handlers_test.go | 48 ++-- .../access-control/handlers/role_handlers.go | 24 ++ .../handlers/role_handlers_test.go | 147 ++++++++-- .../handlers/user_access_handlers.go | 59 ---- .../handlers/user_access_handlers_test.go | 241 ---------------- .../handlers/user_permissions_handlers.go | 65 +++++ .../user_permissions_handlers_test.go | 161 +++++++++++ .../handlers/user_roles_handlers.go | 24 -- .../handlers/user_roles_handlers_test.go | 139 --------- plugins/access-control/plugin.go | 10 +- plugins/access-control/repositories/errors.go | 30 -- .../access-control/repositories/interfaces.go | 13 +- .../repositories/permissions_repository.go | 23 +- .../permissions_repository_test.go | 12 +- .../role_permission_repository.go | 25 +- .../role_permission_repository_test.go | 12 +- .../repositories/roles_repository.go | 19 +- .../repositories/roles_repository_test.go | 230 ++++++++++++++- .../repositories/user_access_repository.go | 190 ------------ .../user_access_repository_test.go | 272 ------------------ .../user_permissions_repository.go | 112 ++++++++ .../user_permissions_repository_test.go | 209 ++++++++++++++ .../repositories/user_roles_repository.go | 71 ++--- .../user_roles_repository_test.go | 110 +------ plugins/access-control/routes.go | 21 +- .../services/permissions_service.go | 28 +- .../services/permissions_service_test.go | 49 ++-- .../access-control/services/roles_service.go | 26 +- .../services/roles_service_test.go | 87 +++++- .../services/user_access_service.go | 90 ------ .../services/user_access_service_test.go | 81 ------ .../services/user_permissions_service.go | 33 +++ .../services/user_permissions_service_test.go | 139 +++++++++ .../services/user_roles_service.go | 8 - plugins/access-control/tests/test_helpers.go | 51 ++-- plugins/access-control/types/config.go | 8 + plugins/access-control/types/models.go | 142 --------- plugins/access-control/types/types.go | 140 ++++++++- .../usecases/permissions_usecase.go | 4 + .../access-control/usecases/roles_usecase.go | 4 + plugins/access-control/usecases/usecases.go | 36 +-- .../usecases/user_access_usecase.go | 32 --- .../usecases/user_permissions_usecase.go | 24 ++ .../usecases/user_roles_usecase.go | 4 - 45 files changed, 1607 insertions(+), 1672 deletions(-) delete mode 100644 plugins/access-control/handlers/user_access_handlers.go delete mode 100644 plugins/access-control/handlers/user_access_handlers_test.go create mode 100644 plugins/access-control/handlers/user_permissions_handlers.go create mode 100644 plugins/access-control/handlers/user_permissions_handlers_test.go delete mode 100644 plugins/access-control/repositories/errors.go delete mode 100644 plugins/access-control/repositories/user_access_repository.go delete mode 100644 plugins/access-control/repositories/user_access_repository_test.go create mode 100644 plugins/access-control/repositories/user_permissions_repository.go create mode 100644 plugins/access-control/repositories/user_permissions_repository_test.go delete mode 100644 plugins/access-control/services/user_access_service.go delete mode 100644 plugins/access-control/services/user_access_service_test.go create mode 100644 plugins/access-control/services/user_permissions_service.go create mode 100644 plugins/access-control/services/user_permissions_service_test.go create mode 100644 plugins/access-control/types/config.go delete mode 100644 plugins/access-control/usecases/user_access_usecase.go create mode 100644 plugins/access-control/usecases/user_permissions_usecase.go diff --git a/plugins/access-control/api.go b/plugins/access-control/api.go index b12cb211..189571eb 100644 --- a/plugins/access-control/api.go +++ b/plugins/access-control/api.go @@ -21,6 +21,10 @@ func (a *API) GetAllRoles(ctx context.Context) ([]types.Role, error) { return a.useCases.GetAllRoles(ctx) } +func (a *API) GetRoleByName(ctx context.Context, roleName string) (*types.Role, error) { + return a.useCases.GetRoleByName(ctx, roleName) +} + func (a *API) GetRoleByID(ctx context.Context, roleID string) (*types.RoleDetails, error) { return a.useCases.GetRoleByID(ctx, roleID) } @@ -83,10 +87,6 @@ func (a *API) GetUserRoles(ctx context.Context, userID string) ([]types.UserRole return a.useCases.GetUserRoles(ctx, userID) } -func (a *API) GetUserWithRolesByID(ctx context.Context, userID string) (*types.UserWithRoles, error) { - return a.useCases.GetUserWithRolesByID(ctx, userID) -} - func (a *API) AssignRoleToUser(ctx context.Context, userID string, req types.AssignUserRoleRequest, assignedByUserID *string) error { return a.useCases.AssignRoleToUser(ctx, userID, req, assignedByUserID) } @@ -99,20 +99,12 @@ func (a *API) ReplaceUserRoles(ctx context.Context, userID string, roleIDs []str return a.useCases.ReplaceUserRoles(ctx, userID, roleIDs, assignedByUserID) } -// User access - -func (a *API) GetUserWithPermissionsByID(ctx context.Context, userID string) (*types.UserWithPermissions, error) { - return a.useCases.GetUserWithPermissionsByID(ctx, userID) -} - -func (a *API) GetUserAuthorizationProfile(ctx context.Context, userID string) (*types.UserAuthorizationProfile, error) { - return a.useCases.GetUserAuthorizationProfile(ctx, userID) -} +// User permissions -func (a *API) GetUserEffectivePermissions(ctx context.Context, userID string) ([]types.UserPermissionInfo, error) { - return a.useCases.GetUserEffectivePermissions(ctx, userID) +func (a *API) GetUserPermissions(ctx context.Context, userID string) ([]types.UserPermissionInfo, error) { + return a.useCases.GetUserPermissions(ctx, userID) } -func (a *API) HasPermissions(ctx context.Context, userID string, requiredPermissions []string) (bool, error) { - return a.useCases.HasPermissions(ctx, userID, requiredPermissions) +func (a *API) HasPermissions(ctx context.Context, userID string, permissionNames []string) (bool, error) { + return a.useCases.HasPermissions(ctx, userID, permissionNames) } diff --git a/plugins/access-control/handlers/permission_handlers_test.go b/plugins/access-control/handlers/permission_handlers_test.go index 3dec403f..07307c09 100644 --- a/plugins/access-control/handlers/permission_handlers_test.go +++ b/plugins/access-control/handlers/permission_handlers_test.go @@ -86,12 +86,12 @@ func TestGetAllPermissionsHandler(t *testing.T) { t.Parallel() permissionsRepo := &accesscontroltests.MockPermissionsRepository{} - userAccessRepo := &accesscontroltests.MockUserAccessRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} if tc.setupMock != nil { tc.setupMock(permissionsRepo) } - useCase := newPermissionsUseCase(permissionsRepo, userAccessRepo) + useCase := newPermissionsUseCase(permissionsRepo, rolePermissionsRepo) handler := NewGetAllPermissionsHandler(useCase) req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/permissions", nil, nil) @@ -100,7 +100,6 @@ func TestGetAllPermissionsHandler(t *testing.T) { if tc.expectedStatus != http.StatusOK { internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) permissionsRepo.AssertExpectations(t) - userAccessRepo.AssertExpectations(t) return } @@ -112,7 +111,6 @@ func TestGetAllPermissionsHandler(t *testing.T) { assertPermissionsEqual(t, payload, tc.expectedBody.([]types.Permission)) permissionsRepo.AssertExpectations(t) - userAccessRepo.AssertExpectations(t) }) } } @@ -188,12 +186,12 @@ func TestCreatePermissionHandler(t *testing.T) { t.Parallel() permissionsRepo := &accesscontroltests.MockPermissionsRepository{} - userAccessRepo := &accesscontroltests.MockUserAccessRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} if tc.setupMock != nil { tc.setupMock(permissionsRepo) } - useCase := newPermissionsUseCase(permissionsRepo, userAccessRepo) + useCase := newPermissionsUseCase(permissionsRepo, rolePermissionsRepo) handler := NewCreatePermissionHandler(useCase) req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPost, "/permissions", tc.body, nil) @@ -202,7 +200,6 @@ func TestCreatePermissionHandler(t *testing.T) { if tc.expectedStatus != http.StatusCreated { internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) permissionsRepo.AssertExpectations(t) - userAccessRepo.AssertExpectations(t) return } @@ -214,7 +211,6 @@ func TestCreatePermissionHandler(t *testing.T) { assertCreatePermissionResponseEqual(t, payload, tc.expectedBody.(types.CreatePermissionResponse)) permissionsRepo.AssertExpectations(t) - userAccessRepo.AssertExpectations(t) }) } } @@ -272,12 +268,12 @@ func TestGetPermissionByIDHandler(t *testing.T) { t.Parallel() permissionsRepo := &accesscontroltests.MockPermissionsRepository{} - userAccessRepo := &accesscontroltests.MockUserAccessRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} if tc.setupMock != nil { tc.setupMock(permissionsRepo) } - useCase := newPermissionsUseCase(permissionsRepo, userAccessRepo) + useCase := newPermissionsUseCase(permissionsRepo, rolePermissionsRepo) handler := NewGetPermissionByIDHandler(useCase) req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/permissions/"+tc.permissionID, nil, nil) req.SetPathValue("permission_id", tc.permissionID) @@ -287,7 +283,6 @@ func TestGetPermissionByIDHandler(t *testing.T) { if tc.expectedStatus != http.StatusOK { internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) permissionsRepo.AssertExpectations(t) - userAccessRepo.AssertExpectations(t) return } @@ -299,7 +294,6 @@ func TestGetPermissionByIDHandler(t *testing.T) { assertPermissionEqual(t, payload, *tc.expectedBody.(*types.Permission)) permissionsRepo.AssertExpectations(t) - userAccessRepo.AssertExpectations(t) }) } } @@ -378,12 +372,12 @@ func TestUpdatePermissionHandler(t *testing.T) { t.Parallel() permissionsRepo := &accesscontroltests.MockPermissionsRepository{} - userAccessRepo := &accesscontroltests.MockUserAccessRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} if tc.setupMock != nil { tc.setupMock(permissionsRepo) } - useCase := newPermissionsUseCase(permissionsRepo, userAccessRepo) + useCase := newPermissionsUseCase(permissionsRepo, rolePermissionsRepo) handler := NewUpdatePermissionHandler(useCase) req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPut, "/permissions/"+tc.permissionID, tc.body, nil) req.SetPathValue("permission_id", tc.permissionID) @@ -393,7 +387,6 @@ func TestUpdatePermissionHandler(t *testing.T) { if tc.expectedStatus != http.StatusOK { internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) permissionsRepo.AssertExpectations(t) - userAccessRepo.AssertExpectations(t) return } @@ -405,7 +398,6 @@ func TestUpdatePermissionHandler(t *testing.T) { assertUpdatePermissionResponseEqual(t, payload, tc.expectedBody.(types.UpdatePermissionResponse)) permissionsRepo.AssertExpectations(t) - userAccessRepo.AssertExpectations(t) }) } } @@ -416,16 +408,16 @@ func TestDeletePermissionHandler(t *testing.T) { tests := []struct { name string permissionID string - setupMock func(*accesscontroltests.MockPermissionsRepository, *accesscontroltests.MockUserAccessRepository) + setupMock func(*accesscontroltests.MockPermissionsRepository, *accesscontroltests.MockRolePermissionsRepository) expectedStatus int expectedBody any }{ { name: "service error", permissionID: "perm-1", - setupMock: func(m *accesscontroltests.MockPermissionsRepository, userAccessRepo *accesscontroltests.MockUserAccessRepository) { + setupMock: func(m *accesscontroltests.MockPermissionsRepository, rolePermissionsRepo *accesscontroltests.MockRolePermissionsRepository) { m.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read", IsSystem: false}, nil).Once() - userAccessRepo.On("CountRoleAssignmentsByPermissionID", mock.Anything, "perm-1").Return(0, nil).Once() + rolePermissionsRepo.On("CountRolesByPermission", mock.Anything, "perm-1").Return(0, nil).Once() m.On("DeletePermission", mock.Anything, "perm-1").Return(false, errors.New("database error")).Once() }, expectedStatus: http.StatusInternalServerError, @@ -434,9 +426,9 @@ func TestDeletePermissionHandler(t *testing.T) { { name: "success", permissionID: "perm-1", - setupMock: func(m *accesscontroltests.MockPermissionsRepository, u *accesscontroltests.MockUserAccessRepository) { + setupMock: func(m *accesscontroltests.MockPermissionsRepository, rolePermissionsRepo *accesscontroltests.MockRolePermissionsRepository) { m.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read", IsSystem: false}, nil).Once() - u.On("CountRoleAssignmentsByPermissionID", mock.Anything, "perm-1").Return(0, nil).Once() + rolePermissionsRepo.On("CountRolesByPermission", mock.Anything, "perm-1").Return(0, nil).Once() m.On("DeletePermission", mock.Anything, "perm-1").Return(true, nil).Once() }, expectedStatus: http.StatusOK, @@ -449,12 +441,12 @@ func TestDeletePermissionHandler(t *testing.T) { t.Parallel() permissionsRepo := &accesscontroltests.MockPermissionsRepository{} - userAccessRepo := &accesscontroltests.MockUserAccessRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} if tc.setupMock != nil { - tc.setupMock(permissionsRepo, userAccessRepo) + tc.setupMock(permissionsRepo, rolePermissionsRepo) } - useCase := newPermissionsUseCase(permissionsRepo, userAccessRepo) + useCase := newPermissionsUseCase(permissionsRepo, rolePermissionsRepo) handler := NewDeletePermissionHandler(useCase) req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodDelete, "/permissions/"+tc.permissionID, nil, nil) req.SetPathValue("permission_id", tc.permissionID) @@ -464,7 +456,7 @@ func TestDeletePermissionHandler(t *testing.T) { if tc.expectedStatus != http.StatusOK { internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) permissionsRepo.AssertExpectations(t) - userAccessRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) return } @@ -476,13 +468,13 @@ func TestDeletePermissionHandler(t *testing.T) { assertDeletePermissionResponseEqual(t, payload, tc.expectedBody.(types.DeletePermissionResponse)) permissionsRepo.AssertExpectations(t) - userAccessRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) }) } } -func newPermissionsUseCase(permissionsRepo *accesscontroltests.MockPermissionsRepository, userAccessRepo *accesscontroltests.MockUserAccessRepository) *usecases.PermissionsUseCase { - return usecases.NewPermissionsUseCase(services.NewPermissionsService(permissionsRepo, userAccessRepo)) +func newPermissionsUseCase(permissionsRepo *accesscontroltests.MockPermissionsRepository, rolePermissionsRepo *accesscontroltests.MockRolePermissionsRepository) *usecases.PermissionsUseCase { + return usecases.NewPermissionsUseCase(services.NewPermissionsService(permissionsRepo, rolePermissionsRepo)) } func assertPermissionsEqual(t *testing.T, got []types.Permission, want []types.Permission) { diff --git a/plugins/access-control/handlers/role_handlers.go b/plugins/access-control/handlers/role_handlers.go index 51e45fc6..dbda4f82 100644 --- a/plugins/access-control/handlers/role_handlers.go +++ b/plugins/access-control/handlers/role_handlers.go @@ -64,6 +64,30 @@ func (h *GetAllRolesHandler) Handler() http.HandlerFunc { } } +type GetRoleByNameHandler struct { + useCase *usecases.RolesUseCase +} + +func NewGetRoleByNameHandler(useCase *usecases.RolesUseCase) *GetRoleByNameHandler { + return &GetRoleByNameHandler{useCase: useCase} +} + +func (h *GetRoleByNameHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + roleName := r.PathValue("role_name") + + role, err := h.useCase.GetRoleByName(r.Context(), roleName) + if err != nil { + respondRolePermissionError(reqCtx, err) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, role) + } +} + type GetRoleByIDHandler struct { useCase *usecases.RolesUseCase } diff --git a/plugins/access-control/handlers/role_handlers_test.go b/plugins/access-control/handlers/role_handlers_test.go index 2da254a2..96ffedac 100644 --- a/plugins/access-control/handlers/role_handlers_test.go +++ b/plugins/access-control/handlers/role_handlers_test.go @@ -87,12 +87,12 @@ func TestCreateRoleHandler(t *testing.T) { rolesRepo := &accesscontroltests.MockRolesRepository{} rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} - userAccessRepo := &accesscontroltests.MockUserAccessRepository{} + userRolesRepo := &accesscontroltests.MockUserRolesRepository{} if tc.setupMock != nil { tc.setupMock(rolesRepo) } - useCase := newRolesUseCase(rolesRepo, rolePermissionsRepo, userAccessRepo) + useCase := newRolesUseCase(rolesRepo, rolePermissionsRepo, userRolesRepo) handler := NewCreateRoleHandler(useCase) req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPost, "/roles", tc.body, nil) @@ -102,7 +102,7 @@ func TestCreateRoleHandler(t *testing.T) { internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) rolesRepo.AssertExpectations(t) rolePermissionsRepo.AssertExpectations(t) - userAccessRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) return } @@ -115,7 +115,7 @@ func TestCreateRoleHandler(t *testing.T) { rolesRepo.AssertExpectations(t) rolePermissionsRepo.AssertExpectations(t) - userAccessRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) }) } } @@ -191,12 +191,12 @@ func TestGetAllRolesHandler(t *testing.T) { rolesRepo := &accesscontroltests.MockRolesRepository{} rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} - userAccessRepo := &accesscontroltests.MockUserAccessRepository{} + userRolesRepo := &accesscontroltests.MockUserRolesRepository{} if tc.setupMock != nil { tc.setupMock(rolesRepo) } - useCase := newRolesUseCase(rolesRepo, rolePermissionsRepo, userAccessRepo) + useCase := newRolesUseCase(rolesRepo, rolePermissionsRepo, userRolesRepo) handler := NewGetAllRolesHandler(useCase) req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/roles", nil, nil) @@ -206,7 +206,7 @@ func TestGetAllRolesHandler(t *testing.T) { internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) rolesRepo.AssertExpectations(t) rolePermissionsRepo.AssertExpectations(t) - userAccessRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) return } @@ -219,7 +219,7 @@ func TestGetAllRolesHandler(t *testing.T) { rolesRepo.AssertExpectations(t) rolePermissionsRepo.AssertExpectations(t) - userAccessRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) }) } } @@ -301,12 +301,12 @@ func TestGetRoleByIDHandler(t *testing.T) { rolesRepo := &accesscontroltests.MockRolesRepository{} rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} - userAccessRepo := &accesscontroltests.MockUserAccessRepository{} + userRolesRepo := &accesscontroltests.MockUserRolesRepository{} if tc.setupMock != nil { tc.setupMock(rolesRepo, rolePermissionsRepo) } - useCase := newRolesUseCase(rolesRepo, rolePermissionsRepo, userAccessRepo) + useCase := newRolesUseCase(rolesRepo, rolePermissionsRepo, userRolesRepo) handler := NewGetRoleByIDHandler(useCase) req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/roles/"+tc.roleID, nil, nil) req.SetPathValue("role_id", tc.roleID) @@ -317,7 +317,7 @@ func TestGetRoleByIDHandler(t *testing.T) { internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) rolesRepo.AssertExpectations(t) rolePermissionsRepo.AssertExpectations(t) - userAccessRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) return } @@ -330,7 +330,95 @@ func TestGetRoleByIDHandler(t *testing.T) { rolesRepo.AssertExpectations(t) rolePermissionsRepo.AssertExpectations(t) - userAccessRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) + }) + } +} + +func TestGetRoleByNameHandler(t *testing.T) { + t.Parallel() + + fixedTime := time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC) + description := new(string) + *description = "platform administrator" + + tests := []struct { + name string + roleName string + setupMock func(*accesscontroltests.MockRolesRepository) + expectedStatus int + expectedBody any + }{ + { + name: "service error", + roleName: "missing", + setupMock: func(m *accesscontroltests.MockRolesRepository) { + m.On("GetRoleByName", mock.Anything, "missing").Return((*types.Role)(nil), constants.ErrNotFound).Once() + }, + expectedStatus: http.StatusNotFound, + expectedBody: map[string]string{"message": "not found"}, + }, + { + name: "success", + roleName: "Administrator", + setupMock: func(m *accesscontroltests.MockRolesRepository) { + m.On("GetRoleByName", mock.Anything, "Administrator").Return(&types.Role{ + ID: "role-1", + Name: "Administrator", + Description: description, + IsSystem: true, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: types.Role{ + ID: "role-1", + Name: "Administrator", + Description: description, + IsSystem: true, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + rolesRepo := &accesscontroltests.MockRolesRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} + userRolesRepo := &accesscontroltests.MockUserRolesRepository{} + if tc.setupMock != nil { + tc.setupMock(rolesRepo) + } + + useCase := newRolesUseCase(rolesRepo, rolePermissionsRepo, userRolesRepo) + handler := NewGetRoleByNameHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/roles/by-name/"+tc.roleName, nil, nil) + req.SetPathValue("role_name", tc.roleName) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + rolesRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[types.Role](t, reqCtx) + assertRoleEqual(t, payload, tc.expectedBody.(types.Role)) + + rolesRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) }) } } @@ -410,12 +498,12 @@ func TestUpdateRoleHandler(t *testing.T) { rolesRepo := &accesscontroltests.MockRolesRepository{} rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} - userAccessRepo := &accesscontroltests.MockUserAccessRepository{} + userRolesRepo := &accesscontroltests.MockUserRolesRepository{} if tc.setupMock != nil { tc.setupMock(rolesRepo) } - useCase := newRolesUseCase(rolesRepo, rolePermissionsRepo, userAccessRepo) + useCase := newRolesUseCase(rolesRepo, rolePermissionsRepo, userRolesRepo) handler := NewUpdateRoleHandler(useCase) req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPut, "/roles/role-1", tc.body, nil) req.SetPathValue("role_id", "role-1") @@ -426,7 +514,7 @@ func TestUpdateRoleHandler(t *testing.T) { internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) rolesRepo.AssertExpectations(t) rolePermissionsRepo.AssertExpectations(t) - userAccessRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) return } @@ -439,7 +527,7 @@ func TestUpdateRoleHandler(t *testing.T) { rolesRepo.AssertExpectations(t) rolePermissionsRepo.AssertExpectations(t) - userAccessRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) }) } } @@ -450,14 +538,14 @@ func TestDeleteRoleHandler(t *testing.T) { tests := []struct { name string roleID string - setupMock func(*accesscontroltests.MockRolesRepository, *accesscontroltests.MockUserAccessRepository) + setupMock func(*accesscontroltests.MockRolesRepository) expectedStatus int expectedBody any }{ { name: "service error", roleID: "role-1", - setupMock: func(m *accesscontroltests.MockRolesRepository, _ *accesscontroltests.MockUserAccessRepository) { + setupMock: func(m *accesscontroltests.MockRolesRepository) { m.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "Administrator", IsSystem: true}, nil).Once() }, expectedStatus: http.StatusForbidden, @@ -466,9 +554,8 @@ func TestDeleteRoleHandler(t *testing.T) { { name: "success", roleID: "role-1", - setupMock: func(m *accesscontroltests.MockRolesRepository, u *accesscontroltests.MockUserAccessRepository) { + setupMock: func(m *accesscontroltests.MockRolesRepository) { m.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "Administrator", IsSystem: false}, nil).Once() - u.On("CountUserAssignmentsByRoleID", mock.Anything, "role-1").Return(0, nil).Once() m.On("DeleteRole", mock.Anything, "role-1").Return(true, nil).Once() }, expectedStatus: http.StatusOK, @@ -482,12 +569,18 @@ func TestDeleteRoleHandler(t *testing.T) { rolesRepo := &accesscontroltests.MockRolesRepository{} rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} - userAccessRepo := &accesscontroltests.MockUserAccessRepository{} + userRolesRepo := &accesscontroltests.MockUserRolesRepository{} if tc.setupMock != nil { - tc.setupMock(rolesRepo, userAccessRepo) + tc.setupMock(rolesRepo) + } + if tc.roleID == "role-1" && tc.expectedStatus == http.StatusOK { + userRolesRepo.On("CountUsersByRole", mock.Anything, "role-1").Return(0, nil).Once() + } + if tc.roleID == "role-1" && tc.expectedStatus == http.StatusForbidden { + userRolesRepo.On("CountUsersByRole", mock.Anything, "role-1").Return(0, nil).Maybe() } - useCase := newRolesUseCase(rolesRepo, rolePermissionsRepo, userAccessRepo) + useCase := newRolesUseCase(rolesRepo, rolePermissionsRepo, userRolesRepo) handler := NewDeleteRoleHandler(useCase) req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodDelete, "/roles/"+tc.roleID, nil, nil) req.SetPathValue("role_id", tc.roleID) @@ -498,7 +591,7 @@ func TestDeleteRoleHandler(t *testing.T) { internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) rolesRepo.AssertExpectations(t) rolePermissionsRepo.AssertExpectations(t) - userAccessRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) return } @@ -511,13 +604,13 @@ func TestDeleteRoleHandler(t *testing.T) { rolesRepo.AssertExpectations(t) rolePermissionsRepo.AssertExpectations(t) - userAccessRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) }) } } -func newRolesUseCase(rolesRepo *accesscontroltests.MockRolesRepository, rolePermissionsRepo *accesscontroltests.MockRolePermissionsRepository, userAccessRepo *accesscontroltests.MockUserAccessRepository) *usecases.RolesUseCase { - return usecases.NewRolesUseCase(services.NewRolesService(rolesRepo, rolePermissionsRepo, userAccessRepo)) +func newRolesUseCase(rolesRepo *accesscontroltests.MockRolesRepository, rolePermissionsRepo *accesscontroltests.MockRolePermissionsRepository, userRolesRepo *accesscontroltests.MockUserRolesRepository) *usecases.RolesUseCase { + return usecases.NewRolesUseCase(services.NewRolesService(rolesRepo, rolePermissionsRepo, userRolesRepo)) } func assertRolesEqual(t *testing.T, got []types.Role, want []types.Role) { diff --git a/plugins/access-control/handlers/user_access_handlers.go b/plugins/access-control/handlers/user_access_handlers.go deleted file mode 100644 index 925694a0..00000000 --- a/plugins/access-control/handlers/user_access_handlers.go +++ /dev/null @@ -1,59 +0,0 @@ -package handlers - -import ( - "net/http" - - "github.com/Authula/authula/models" - "github.com/Authula/authula/plugins/access-control/types" - "github.com/Authula/authula/plugins/access-control/usecases" -) - -type GetUserEffectivePermissionsHandler struct { - useCase *usecases.UserAccessUseCase -} - -func NewGetUserEffectivePermissionsHandler(useCase *usecases.UserAccessUseCase) *GetUserEffectivePermissionsHandler { - return &GetUserEffectivePermissionsHandler{ - useCase: useCase, - } -} - -func (h *GetUserEffectivePermissionsHandler) Handler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - reqCtx, _ := models.GetRequestContext(ctx) - userID := r.PathValue("user_id") - - permissions, err := h.useCase.GetUserEffectivePermissions(r.Context(), userID) - if err != nil { - respondUserHandlerError(reqCtx, err) - return - } - - reqCtx.SetJSONResponse(http.StatusOK, &types.GetUserEffectivePermissionsResponse{Permissions: permissions}) - } -} - -type GetUserAuthorizationProfileHandler struct { - useCase *usecases.UserAccessUseCase -} - -func NewGetUserAuthorizationProfileHandler(useCase *usecases.UserAccessUseCase) *GetUserAuthorizationProfileHandler { - return &GetUserAuthorizationProfileHandler{useCase: useCase} -} - -func (h *GetUserAuthorizationProfileHandler) Handler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - reqCtx, _ := models.GetRequestContext(ctx) - userID := r.PathValue("user_id") - - profile, err := h.useCase.GetUserAuthorizationProfile(r.Context(), userID) - if err != nil { - respondUserHandlerError(reqCtx, err) - return - } - - reqCtx.SetJSONResponse(http.StatusOK, profile) - } -} diff --git a/plugins/access-control/handlers/user_access_handlers_test.go b/plugins/access-control/handlers/user_access_handlers_test.go deleted file mode 100644 index 4ba113ef..00000000 --- a/plugins/access-control/handlers/user_access_handlers_test.go +++ /dev/null @@ -1,241 +0,0 @@ -package handlers - -import ( - "encoding/json" - "net/http" - "testing" - "time" - - "github.com/stretchr/testify/mock" - - internaltests "github.com/Authula/authula/internal/tests" - authmodels "github.com/Authula/authula/models" - "github.com/Authula/authula/plugins/access-control/constants" - accesscontroltests "github.com/Authula/authula/plugins/access-control/tests" - "github.com/Authula/authula/plugins/access-control/types" -) - -func TestGetUserEffectivePermissionsHandler(t *testing.T) { - t.Parallel() - - fixedTime := time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC) - description := new(string) - *description = "read access" - - tests := []struct { - name string - userID string - setupMock func(*accesscontroltests.MockUserAccessRepository) - expectedStatus int - expectedBody any - }{ - { - name: "blank user id", - userID: "", - expectedStatus: http.StatusUnprocessableEntity, - expectedBody: map[string]string{"message": "unprocessable entity"}, - }, - { - name: "use case error", - userID: "user-1", - setupMock: func(m *accesscontroltests.MockUserAccessRepository) { - m.On("GetUserEffectivePermissions", mock.Anything, "user-1").Return(([]types.UserPermissionInfo)(nil), constants.ErrUnauthorized).Once() - }, - expectedStatus: http.StatusUnauthorized, - expectedBody: map[string]string{"message": "unauthorized"}, - }, - { - name: "success", - userID: "user-1", - setupMock: func(m *accesscontroltests.MockUserAccessRepository) { - m.On("GetUserEffectivePermissions", mock.Anything, "user-1").Return([]types.UserPermissionInfo{{ - PermissionID: "perm-1", - PermissionKey: "users.read", - PermissionDescription: description, - GrantedAt: &fixedTime, - }}, nil).Once() - }, - expectedStatus: http.StatusOK, - expectedBody: &types.GetUserEffectivePermissionsResponse{Permissions: []types.UserPermissionInfo{{ - PermissionID: "perm-1", - PermissionKey: "users.read", - PermissionDescription: description, - GrantedAt: &fixedTime, - }}}, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - userAccessRepo := &accesscontroltests.MockUserAccessRepository{} - if tc.setupMock != nil { - tc.setupMock(userAccessRepo) - } - - userRolesRepo := &accesscontroltests.MockUserRolesRepository{} - useCase := newUserAccessUseCase(userRolesRepo, userAccessRepo) - handler := NewGetUserEffectivePermissionsHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/users/"+tc.userID+"/permissions", nil, nil) - req.SetPathValue("user_id", tc.userID) - - handler.Handler()(w, req) - - if tc.expectedStatus != http.StatusOK { - internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) - userAccessRepo.AssertExpectations(t) - userRolesRepo.AssertExpectations(t) - return - } - - if reqCtx.ResponseStatus != tc.expectedStatus { - t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) - } - - payload := internaltests.DecodeResponseJSON[types.GetUserEffectivePermissionsResponse](t, reqCtx) - assertGetUserEffectivePermissionsResponseEqual(t, payload, *tc.expectedBody.(*types.GetUserEffectivePermissionsResponse)) - - userAccessRepo.AssertExpectations(t) - userRolesRepo.AssertExpectations(t) - }) - } -} - -func TestGetUserAuthorizationProfileHandler(t *testing.T) { - t.Parallel() - - fixedTime := time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC) - roleDescription := new(string) - *roleDescription = "profile role" - permissionDescription := new(string) - *permissionDescription = "profile access" - - tests := []struct { - name string - userID string - setupMock func(*accesscontroltests.MockUserRolesRepository, *accesscontroltests.MockUserAccessRepository) - expectedStatus int - expectedBody any - }{ - { - name: "blank user id", - userID: "", - expectedStatus: http.StatusUnprocessableEntity, - expectedBody: map[string]string{"message": "unprocessable entity"}, - }, - { - name: "use case error", - userID: "user-404", - setupMock: func(userRolesRepo *accesscontroltests.MockUserRolesRepository, _ *accesscontroltests.MockUserAccessRepository) { - userRolesRepo.On("GetUserWithRolesByID", mock.Anything, "user-404").Return((*types.UserWithRoles)(nil), constants.ErrNotFound).Once() - }, - expectedStatus: http.StatusNotFound, - expectedBody: map[string]string{"message": "not found"}, - }, - { - name: "success", - userID: "user-1", - setupMock: func(userRolesRepo *accesscontroltests.MockUserRolesRepository, userAccessRepo *accesscontroltests.MockUserAccessRepository) { - userRolesRepo.On("GetUserWithRolesByID", mock.Anything, "user-1").Return(&types.UserWithRoles{ - User: authmodels.User{ - ID: "user-1", - Name: "Pat", - Email: "pat@example.com", - EmailVerified: true, - Metadata: json.RawMessage("null"), - CreatedAt: fixedTime, - UpdatedAt: fixedTime, - }, - Roles: []types.UserRoleInfo{{ - RoleID: "role-1", - RoleName: "Editor", - RoleDescription: roleDescription, - AssignedByUserID: nil, - AssignedAt: &fixedTime, - ExpiresAt: nil, - }}, - }, nil).Once() - userAccessRepo.On("GetUserWithPermissionsByID", mock.Anything, "user-1").Return(&types.UserWithPermissions{ - User: authmodels.User{ - ID: "user-1", - Name: "Pat", - Email: "pat@example.com", - EmailVerified: true, - Metadata: json.RawMessage("null"), - CreatedAt: fixedTime, - UpdatedAt: fixedTime, - }, - Permissions: []types.UserPermissionInfo{{ - PermissionID: "perm-1", - PermissionKey: "users.read", - PermissionDescription: permissionDescription, - GrantedAt: &fixedTime, - }}, - }, nil).Once() - }, - expectedStatus: http.StatusOK, - expectedBody: &types.UserAuthorizationProfile{ - User: authmodels.User{ - ID: "user-1", - Name: "Pat", - Email: "pat@example.com", - EmailVerified: true, - Metadata: json.RawMessage("null"), - CreatedAt: fixedTime, - UpdatedAt: fixedTime, - }, - Roles: []types.UserRoleInfo{{ - RoleID: "role-1", - RoleName: "Editor", - RoleDescription: roleDescription, - AssignedByUserID: nil, - AssignedAt: &fixedTime, - ExpiresAt: nil, - }}, - Permissions: []types.UserPermissionInfo{{ - PermissionID: "perm-1", - PermissionKey: "users.read", - PermissionDescription: permissionDescription, - GrantedAt: &fixedTime, - }}, - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - userRolesRepo := &accesscontroltests.MockUserRolesRepository{} - userAccessRepo := &accesscontroltests.MockUserAccessRepository{} - if tc.setupMock != nil { - tc.setupMock(userRolesRepo, userAccessRepo) - } - - useCase := newUserAccessUseCase(userRolesRepo, userAccessRepo) - handler := NewGetUserAuthorizationProfileHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/users/"+tc.userID+"/authorization-profile", nil, nil) - req.SetPathValue("user_id", tc.userID) - - handler.Handler()(w, req) - - if tc.expectedStatus != http.StatusOK { - internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) - userRolesRepo.AssertExpectations(t) - userAccessRepo.AssertExpectations(t) - return - } - - if reqCtx.ResponseStatus != tc.expectedStatus { - t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) - } - - payload := internaltests.DecodeResponseJSON[types.UserAuthorizationProfile](t, reqCtx) - assertUserAuthorizationProfileEqual(t, payload, *tc.expectedBody.(*types.UserAuthorizationProfile)) - - userRolesRepo.AssertExpectations(t) - userAccessRepo.AssertExpectations(t) - }) - } -} diff --git a/plugins/access-control/handlers/user_permissions_handlers.go b/plugins/access-control/handlers/user_permissions_handlers.go new file mode 100644 index 00000000..562cbc3f --- /dev/null +++ b/plugins/access-control/handlers/user_permissions_handlers.go @@ -0,0 +1,65 @@ +package handlers + +import ( + "net/http" + + "github.com/Authula/authula/internal/util" + "github.com/Authula/authula/models" + "github.com/Authula/authula/plugins/access-control/types" + "github.com/Authula/authula/plugins/access-control/usecases" +) + +type GetUserPermissionsHandler struct { + useCase *usecases.UserPermissionsUseCase +} + +func NewGetUserPermissionsHandler(useCase *usecases.UserPermissionsUseCase) *GetUserPermissionsHandler { + return &GetUserPermissionsHandler{useCase: useCase} +} + +func (h *GetUserPermissionsHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + userID := r.PathValue("user_id") + + permissions, err := h.useCase.GetUserPermissions(r.Context(), userID) + if err != nil { + respondUserHandlerError(reqCtx, err) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, &types.GetUserEffectivePermissionsResponse{Permissions: permissions}) + } +} + +type CheckUserPermissionsHandler struct { + useCase *usecases.UserPermissionsUseCase +} + +func NewCheckUserPermissionsHandler(useCase *usecases.UserPermissionsUseCase) *CheckUserPermissionsHandler { + return &CheckUserPermissionsHandler{useCase: useCase} +} + +func (h *CheckUserPermissionsHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + userID := r.PathValue("user_id") + + var payload types.CheckUserPermissionsRequest + if err := util.ParseJSON(r, &payload); err != nil { + reqCtx.SetJSONResponse(http.StatusUnprocessableEntity, map[string]any{"message": "invalid request body"}) + reqCtx.Handled = true + return + } + + allowed, err := h.useCase.HasPermissions(r.Context(), userID, payload.PermissionKeys) + if err != nil { + respondUserHandlerError(reqCtx, err) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, &types.CheckUserPermissionsResponse{HasPermissions: allowed}) + } +} diff --git a/plugins/access-control/handlers/user_permissions_handlers_test.go b/plugins/access-control/handlers/user_permissions_handlers_test.go new file mode 100644 index 00000000..424181aa --- /dev/null +++ b/plugins/access-control/handlers/user_permissions_handlers_test.go @@ -0,0 +1,161 @@ +package handlers + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/mock" + + internaltests "github.com/Authula/authula/internal/tests" + accesscontrolconstants "github.com/Authula/authula/plugins/access-control/constants" + "github.com/Authula/authula/plugins/access-control/services" + accesscontroltests "github.com/Authula/authula/plugins/access-control/tests" + "github.com/Authula/authula/plugins/access-control/types" + "github.com/Authula/authula/plugins/access-control/usecases" +) + +func TestGetUserPermissionsHandler(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + userID string + setupMock func(*accesscontroltests.MockUserPermissionsRepository) + expectedStatus int + expectedBody any + }{ + { + name: "blank user id", + userID: "", + expectedStatus: http.StatusUnprocessableEntity, + expectedBody: map[string]string{"message": "unprocessable entity"}, + }, + { + name: "success", + userID: "u1", + setupMock: func(m *accesscontroltests.MockUserPermissionsRepository) { + m.On("GetUserPermissions", mock.Anything, "u1").Return([]types.UserPermissionInfo{{PermissionID: "perm-1", PermissionKey: "users.read"}}, nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: types.GetUserEffectivePermissionsResponse{Permissions: []types.UserPermissionInfo{{PermissionID: "perm-1", PermissionKey: "users.read"}}}, + }, + { + name: "repo error", + userID: "u1", + setupMock: func(m *accesscontroltests.MockUserPermissionsRepository) { + m.On("GetUserPermissions", mock.Anything, "u1").Return(([]types.UserPermissionInfo)(nil), accesscontrolconstants.ErrNotFound).Once() + }, + expectedStatus: http.StatusNotFound, + expectedBody: map[string]string{"message": "not found"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + repo := &accesscontroltests.MockUserPermissionsRepository{} + if tc.setupMock != nil { + tc.setupMock(repo) + } + + handler := NewGetUserPermissionsHandler(newUserPermissionsUseCase(repo)) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/users/"+tc.userID+"/permissions", nil, nil) + req.SetPathValue("user_id", tc.userID) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + repo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[types.GetUserEffectivePermissionsResponse](t, reqCtx) + assertUserPermissionInfosEqual(t, payload.Permissions, tc.expectedBody.(types.GetUserEffectivePermissionsResponse).Permissions) + repo.AssertExpectations(t) + }) + } +} + +func TestCheckUserPermissionsHandler(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + userID string + body []byte + setupMock func(*accesscontroltests.MockUserPermissionsRepository) + expectedStatus int + expectedBody any + }{ + { + name: "invalid request body", + userID: "u1", + body: []byte("{invalid json"), + expectedStatus: http.StatusUnprocessableEntity, + expectedBody: map[string]string{"message": "invalid request body"}, + }, + { + name: "success", + userID: "u1", + body: internaltests.MarshalToJSON(t, types.CheckUserPermissionsRequest{PermissionKeys: []string{"users.read", "users.write"}}), + setupMock: func(m *accesscontroltests.MockUserPermissionsRepository) { + m.On("HasPermissions", mock.Anything, "u1", []string{"users.read", "users.write"}).Return(true, nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: types.CheckUserPermissionsResponse{HasPermissions: true}, + }, + { + name: "repo error", + userID: "u1", + body: internaltests.MarshalToJSON(t, types.CheckUserPermissionsRequest{PermissionKeys: []string{"users.read"}}), + setupMock: func(m *accesscontroltests.MockUserPermissionsRepository) { + m.On("HasPermissions", mock.Anything, "u1", []string{"users.read"}).Return(false, accesscontrolconstants.ErrForbidden).Once() + }, + expectedStatus: http.StatusForbidden, + expectedBody: map[string]string{"message": "forbidden"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + repo := &accesscontroltests.MockUserPermissionsRepository{} + if tc.setupMock != nil { + tc.setupMock(repo) + } + + handler := NewCheckUserPermissionsHandler(newUserPermissionsUseCase(repo)) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPost, "/users/"+tc.userID+"/permissions/check", tc.body, nil) + req.SetPathValue("user_id", tc.userID) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + repo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[types.CheckUserPermissionsResponse](t, reqCtx) + if payload != tc.expectedBody.(types.CheckUserPermissionsResponse) { + t.Fatalf("unexpected response: %#v", payload) + } + repo.AssertExpectations(t) + }) + } +} + +func newUserPermissionsUseCase(repo *accesscontroltests.MockUserPermissionsRepository) *usecases.UserPermissionsUseCase { + return usecases.NewUserPermissionsUseCase(services.NewUserPermissionsService(repo)) +} diff --git a/plugins/access-control/handlers/user_roles_handlers.go b/plugins/access-control/handlers/user_roles_handlers.go index 93514f35..a3aa1b2e 100644 --- a/plugins/access-control/handlers/user_roles_handlers.go +++ b/plugins/access-control/handlers/user_roles_handlers.go @@ -35,30 +35,6 @@ func (h *GetUserRolesHandler) Handler() http.HandlerFunc { } } -type GetUserWithRolesHandler struct { - useCase *usecases.UserRolesUseCase -} - -func NewGetUserWithRolesHandler(useCase *usecases.UserRolesUseCase) *GetUserWithRolesHandler { - return &GetUserWithRolesHandler{useCase: useCase} -} - -func (h *GetUserWithRolesHandler) Handler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - reqCtx, _ := models.GetRequestContext(ctx) - userID := r.PathValue("user_id") - - userWithRoles, err := h.useCase.GetUserWithRolesByID(r.Context(), userID) - if err != nil { - respondUserHandlerError(reqCtx, err) - return - } - - reqCtx.SetJSONResponse(http.StatusOK, userWithRoles) - } -} - type ReplaceUserRolesHandler struct { useCase *usecases.UserRolesUseCase } diff --git a/plugins/access-control/handlers/user_roles_handlers_test.go b/plugins/access-control/handlers/user_roles_handlers_test.go index 47fb7909..dad977d7 100644 --- a/plugins/access-control/handlers/user_roles_handlers_test.go +++ b/plugins/access-control/handlers/user_roles_handlers_test.go @@ -1,7 +1,6 @@ package handlers import ( - "encoding/json" "net/http" "testing" "time" @@ -110,119 +109,6 @@ func TestGetUserRolesHandler(t *testing.T) { } } -func TestGetUserWithRolesHandler(t *testing.T) { - t.Parallel() - - fixedTime := time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC) - description := new(string) - *description = "platform user" - - tests := []struct { - name string - userID string - setupMock func(*accesscontroltests.MockUserRolesRepository) - expectedStatus int - expectedBody any - }{ - { - name: "blank user id", - userID: "", - expectedStatus: http.StatusUnprocessableEntity, - expectedBody: map[string]string{"message": "unprocessable entity"}, - }, - { - name: "use case error", - userID: "user-404", - setupMock: func(m *accesscontroltests.MockUserRolesRepository) { - m.On("GetUserWithRolesByID", mock.Anything, "user-404").Return((*types.UserWithRoles)(nil), constants.ErrNotFound).Once() - }, - expectedStatus: http.StatusNotFound, - expectedBody: map[string]string{"message": "not found"}, - }, - { - name: "success", - userID: "user-1", - setupMock: func(m *accesscontroltests.MockUserRolesRepository) { - m.On("GetUserWithRolesByID", mock.Anything, "user-1").Return(&types.UserWithRoles{ - User: authmodels.User{ - ID: "user-1", - Name: "Pat", - Email: "pat@example.com", - EmailVerified: true, - Metadata: json.RawMessage("null"), - CreatedAt: fixedTime, - UpdatedAt: fixedTime, - }, - Roles: []types.UserRoleInfo{{ - RoleID: "role-1", - RoleName: "Editor", - RoleDescription: description, - AssignedByUserID: nil, - AssignedAt: &fixedTime, - ExpiresAt: nil, - }}, - }, nil).Once() - }, - expectedStatus: http.StatusOK, - expectedBody: &types.UserWithRoles{ - User: authmodels.User{ - ID: "user-1", - Name: "Pat", - Email: "pat@example.com", - EmailVerified: true, - Metadata: json.RawMessage("null"), - CreatedAt: fixedTime, - UpdatedAt: fixedTime, - }, - Roles: []types.UserRoleInfo{{ - RoleID: "role-1", - RoleName: "Editor", - RoleDescription: description, - AssignedByUserID: nil, - AssignedAt: &fixedTime, - ExpiresAt: nil, - }}, - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - userRolesRepo := &accesscontroltests.MockUserRolesRepository{} - rolesRepo := &accesscontroltests.MockRolesRepository{} - if tc.setupMock != nil { - tc.setupMock(userRolesRepo) - } - - useCase := newUserRolesUseCase(rolesRepo, userRolesRepo) - handler := NewGetUserWithRolesHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/users/"+tc.userID+"/roles/details", nil, nil) - req.SetPathValue("user_id", tc.userID) - - handler.Handler()(w, req) - - if tc.expectedStatus != http.StatusOK { - internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) - userRolesRepo.AssertExpectations(t) - rolesRepo.AssertExpectations(t) - return - } - - if reqCtx.ResponseStatus != tc.expectedStatus { - t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) - } - - payload := internaltests.DecodeResponseJSON[types.UserWithRoles](t, reqCtx) - assertUserWithRolesEqual(t, payload, *tc.expectedBody.(*types.UserWithRoles)) - - userRolesRepo.AssertExpectations(t) - rolesRepo.AssertExpectations(t) - }) - } -} - func TestReplaceUserRolesHandler(t *testing.T) { t.Parallel() @@ -501,10 +387,6 @@ func newUserRolesUseCase(rolesRepo *accesscontroltests.MockRolesRepository, user return usecases.NewUserRolesUseCase(services.NewUserRolesService(userRolesRepo, rolesRepo)) } -func newUserAccessUseCase(userRolesRepo *accesscontroltests.MockUserRolesRepository, userAccessRepo *accesscontroltests.MockUserAccessRepository) *usecases.UserAccessUseCase { - return usecases.NewUserAccessUseCase(services.NewUserAccessService(userRolesRepo, userAccessRepo)) -} - func assertUserRoleInfosEqual(t *testing.T, got []types.UserRoleInfo, want []types.UserRoleInfo) { t.Helper() @@ -537,27 +419,6 @@ func assertUserRoleInfoEqual(t *testing.T, got types.UserRoleInfo, want types.Us } } -func assertUserWithRolesEqual(t *testing.T, got types.UserWithRoles, want types.UserWithRoles) { - t.Helper() - - assertUserEqual(t, got.User, want.User) - assertUserRoleInfosEqual(t, got.Roles, want.Roles) -} - -func assertGetUserEffectivePermissionsResponseEqual(t *testing.T, got types.GetUserEffectivePermissionsResponse, want types.GetUserEffectivePermissionsResponse) { - t.Helper() - - assertUserPermissionInfosEqual(t, got.Permissions, want.Permissions) -} - -func assertUserAuthorizationProfileEqual(t *testing.T, got types.UserAuthorizationProfile, want types.UserAuthorizationProfile) { - t.Helper() - - assertUserEqual(t, got.User, want.User) - assertUserRoleInfosEqual(t, got.Roles, want.Roles) - assertUserPermissionInfosEqual(t, got.Permissions, want.Permissions) -} - func assertUserEqual(t *testing.T, got authmodels.User, want authmodels.User) { t.Helper() diff --git a/plugins/access-control/plugin.go b/plugins/access-control/plugin.go index e15395f6..3a94b312 100644 --- a/plugins/access-control/plugin.go +++ b/plugins/access-control/plugin.go @@ -46,20 +46,20 @@ func (p *AccessControlPlugin) Init(ctx *models.PluginContext) error { permissionsRepo := repositories.NewBunPermissionsRepository(ctx.DB) rolePermissionsRepo := repositories.NewBunRolePermissionsRepository(ctx.DB) userRolesRepo := repositories.NewBunUserRolesRepository(ctx.DB) - userAccessRepo := repositories.NewBunUserAccessRepository(ctx.DB) + userPermissionsRepo := repositories.NewBunUserPermissionsRepository(ctx.DB) - rolesService := services.NewRolesService(rolesRepo, rolePermissionsRepo, userAccessRepo) - permissionsService := services.NewPermissionsService(permissionsRepo, userAccessRepo) + rolesService := services.NewRolesService(rolesRepo, rolePermissionsRepo, userRolesRepo) + permissionsService := services.NewPermissionsService(permissionsRepo, rolePermissionsRepo) rolePermissionsService := services.NewRolePermissionsService(rolesRepo, permissionsRepo, rolePermissionsRepo) userRolesService := services.NewUserRolesService(userRolesRepo, rolesRepo) - userAccessService := services.NewUserAccessService(userRolesRepo, userAccessRepo) + userPermissionsService := services.NewUserPermissionsService(userPermissionsRepo) useCases := usecases.NewAccessControlUseCases( usecases.NewRolesUseCase(rolesService), usecases.NewPermissionsUseCase(permissionsService), usecases.NewRolePermissionsUseCase(rolePermissionsService), usecases.NewUserRolesUseCase(userRolesService), - usecases.NewUserAccessUseCase(userAccessService), + usecases.NewUserPermissionsUseCase(userPermissionsService), ) p.Api = NewAPI(useCases) diff --git a/plugins/access-control/repositories/errors.go b/plugins/access-control/repositories/errors.go deleted file mode 100644 index 39ea9b18..00000000 --- a/plugins/access-control/repositories/errors.go +++ /dev/null @@ -1,30 +0,0 @@ -package repositories - -import ( - "fmt" - "strings" - - "github.com/Authula/authula/plugins/access-control/constants" -) - -func wrapRepositoryError(action string, err error) error { - if err == nil { - return nil - } - - if isUniqueConstraintError(err) { - return constants.ErrConflict - } - - return fmt.Errorf("failed to %s: %w", action, err) -} - -func isUniqueConstraintError(err error) bool { - message := strings.ToLower(err.Error()) - - return strings.Contains(message, "unique constraint") || - strings.Contains(message, "unique violation") || - strings.Contains(message, "duplicate key value") || - strings.Contains(message, "duplicate entry") || - strings.Contains(message, "error 1062") -} diff --git a/plugins/access-control/repositories/interfaces.go b/plugins/access-control/repositories/interfaces.go index 15f9a80e..34a7fc86 100644 --- a/plugins/access-control/repositories/interfaces.go +++ b/plugins/access-control/repositories/interfaces.go @@ -11,6 +11,7 @@ type RolesRepository interface { CreateRole(ctx context.Context, role *types.Role) error GetAllRoles(ctx context.Context) ([]types.Role, error) GetRoleByID(ctx context.Context, roleID string) (*types.Role, error) + GetRoleByName(ctx context.Context, roleName string) (*types.Role, error) UpdateRole(ctx context.Context, roleID string, name *string, description *string) (bool, error) DeleteRole(ctx context.Context, roleID string) (bool, error) } @@ -18,6 +19,7 @@ type RolesRepository interface { type PermissionsRepository interface { GetAllPermissions(ctx context.Context) ([]types.Permission, error) GetPermissionByID(ctx context.Context, permissionID string) (*types.Permission, error) + GetPermissionByKey(ctx context.Context, permissionKey string) (*types.Permission, error) CreatePermission(ctx context.Context, permission *types.Permission) error UpdatePermission(ctx context.Context, permissionID string, description *string) (bool, error) DeletePermission(ctx context.Context, permissionID string) (bool, error) @@ -28,19 +30,18 @@ type RolePermissionsRepository interface { ReplaceRolePermissions(ctx context.Context, roleID string, permissionIDs []string, grantedByUserID *string) error AddRolePermission(ctx context.Context, roleID string, permissionID string, grantedByUserID *string) error RemoveRolePermission(ctx context.Context, roleID string, permissionID string) error + CountRolesByPermission(ctx context.Context, permissionID string) (int, error) } type UserRolesRepository interface { GetUserRoles(ctx context.Context, userID string) ([]types.UserRoleInfo, error) - GetUserWithRolesByID(ctx context.Context, userID string) (*types.UserWithRoles, error) ReplaceUserRoles(ctx context.Context, userID string, roleIDs []string, assignedByUserID *string) error AssignUserRole(ctx context.Context, userID string, roleID string, assignedByUserID *string, expiresAt *time.Time) error RemoveUserRole(ctx context.Context, userID string, roleID string) error + CountUsersByRole(ctx context.Context, roleID string) (int, error) } -type UserAccessRepository interface { - CountUserAssignmentsByRoleID(ctx context.Context, roleID string) (int, error) - CountRoleAssignmentsByPermissionID(ctx context.Context, permissionID string) (int, error) - GetUserEffectivePermissions(ctx context.Context, userID string) ([]types.UserPermissionInfo, error) - GetUserWithPermissionsByID(ctx context.Context, userID string) (*types.UserWithPermissions, error) +type UserPermissionsRepository interface { + GetUserPermissions(ctx context.Context, userID string) ([]types.UserPermissionInfo, error) + HasPermissions(ctx context.Context, userID string, permissionKeys []string) (bool, error) } diff --git a/plugins/access-control/repositories/permissions_repository.go b/plugins/access-control/repositories/permissions_repository.go index 5680863d..4886ad4a 100644 --- a/plugins/access-control/repositories/permissions_repository.go +++ b/plugins/access-control/repositories/permissions_repository.go @@ -20,8 +20,10 @@ func NewBunPermissionsRepository(db bun.IDB) *BunPermissionsRepository { } func (r *BunPermissionsRepository) CreatePermission(ctx context.Context, permission *types.Permission) error { - _, err := r.db.NewInsert().Model(permission).Exec(ctx) - return wrapRepositoryError("create permission", err) + if _, err := r.db.NewInsert().Model(permission).Exec(ctx); err != nil { + return err + } + return nil } func (r *BunPermissionsRepository) GetAllPermissions(ctx context.Context) ([]types.Permission, error) { @@ -46,6 +48,19 @@ func (r *BunPermissionsRepository) GetPermissionByID(ctx context.Context, permis return permission, nil } +func (r *BunPermissionsRepository) GetPermissionByKey(ctx context.Context, permissionKey string) (*types.Permission, error) { + permission := new(types.Permission) + err := r.db.NewSelect().Model(permission).Where("key = ?", permissionKey).Scan(ctx) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, fmt.Errorf("failed to get permission by key: %w", err) + } + + return permission, nil +} + func (r *BunPermissionsRepository) UpdatePermission(ctx context.Context, permissionID string, description *string) (bool, error) { query := r.db.NewUpdate(). Model((*types.Permission)(nil)). @@ -58,7 +73,7 @@ func (r *BunPermissionsRepository) UpdatePermission(ctx context.Context, permiss result, err := query.Exec(ctx) if err != nil { - return false, wrapRepositoryError("update permission", err) + return false, err } affected, err := result.RowsAffected() @@ -72,7 +87,7 @@ func (r *BunPermissionsRepository) UpdatePermission(ctx context.Context, permiss func (r *BunPermissionsRepository) DeletePermission(ctx context.Context, permissionID string) (bool, error) { result, err := r.db.NewDelete().Model((*types.Permission)(nil)).Where("id = ?", permissionID).Exec(ctx) if err != nil { - return false, fmt.Errorf("failed to delete permission: %w", err) + return false, err } affected, err := result.RowsAffected() diff --git a/plugins/access-control/repositories/permissions_repository_test.go b/plugins/access-control/repositories/permissions_repository_test.go index 95ece8c3..06152c06 100644 --- a/plugins/access-control/repositories/permissions_repository_test.go +++ b/plugins/access-control/repositories/permissions_repository_test.go @@ -2,9 +2,9 @@ package repositories import ( "context" + "strings" "testing" - accesscontrolconstants "github.com/Authula/authula/plugins/access-control/constants" plugintests "github.com/Authula/authula/plugins/access-control/tests" "github.com/Authula/authula/plugins/access-control/types" ) @@ -30,7 +30,6 @@ func TestBunPermissionsRepositoryCreatePermission(t *testing.T) { { name: "duplicate key returns conflict", permission: &types.Permission{ID: "p2", Key: "users.read", Description: new("Duplicate"), IsSystem: false}, - wantErr: accesscontrolconstants.ErrConflict, }, } @@ -49,6 +48,15 @@ func TestBunPermissionsRepositoryCreatePermission(t *testing.T) { } err := repo.CreatePermission(ctx, tc.permission) + if tc.name == "duplicate key returns conflict" { + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "UNIQUE constraint failed: access_control_permissions.key") { + t.Fatalf("expected raw unique constraint error, got %v", err) + } + return + } if err != tc.wantErr { t.Fatalf("expected err %v, got %v", tc.wantErr, err) } diff --git a/plugins/access-control/repositories/role_permission_repository.go b/plugins/access-control/repositories/role_permission_repository.go index 4a36a87d..807b5ca9 100644 --- a/plugins/access-control/repositories/role_permission_repository.go +++ b/plugins/access-control/repositories/role_permission_repository.go @@ -15,7 +15,6 @@ type BunRolePermissionRepository struct { PermissionsRepository RolePermissionsRepository UserRolesRepository - UserAccessRepository } func NewBunRolePermissionRepository(db bun.IDB) *BunRolePermissionRepository { @@ -24,7 +23,6 @@ func NewBunRolePermissionRepository(db bun.IDB) *BunRolePermissionRepository { PermissionsRepository: NewBunPermissionsRepository(db), RolePermissionsRepository: NewBunRolePermissionsRepository(db), UserRolesRepository: NewBunUserRolesRepository(db), - UserAccessRepository: NewBunUserAccessRepository(db), } } @@ -90,8 +88,9 @@ func (r *BunRolePermissionsRepository) ReplaceRolePermissions(ctx context.Contex GrantedByUserID: grantedByUserID, GrantedAt: now, } - if _, err := tx.NewInsert().Model(rp).Exec(ctx); err != nil { - return wrapRepositoryError("insert role permission", err) + _, err := tx.NewInsert().Model(rp).Exec(ctx) + if err != nil { + return err } } @@ -108,7 +107,10 @@ func (r *BunRolePermissionsRepository) AddRolePermission(ctx context.Context, ro } _, err := r.db.NewInsert().Model(rp).Exec(ctx) - return wrapRepositoryError("add role permission", err) + if err != nil { + return err + } + return nil } func (r *BunRolePermissionsRepository) RemoveRolePermission(ctx context.Context, roleID string, permissionID string) error { @@ -118,8 +120,19 @@ func (r *BunRolePermissionsRepository) RemoveRolePermission(ctx context.Context, Where("permission_id = ?", permissionID). Exec(ctx) if err != nil { - return fmt.Errorf("failed to remove role permission: %w", err) + return err } return nil } + +func (r *BunRolePermissionsRepository) CountRolesByPermission(ctx context.Context, permissionID string) (int, error) { + count, err := r.db.NewSelect(). + Model((*types.RolePermission)(nil)). + Where("permission_id = ?", permissionID). + Count(ctx) + if err != nil { + return 0, fmt.Errorf("failed to count roles by permission: %w", err) + } + return count, nil +} diff --git a/plugins/access-control/repositories/role_permission_repository_test.go b/plugins/access-control/repositories/role_permission_repository_test.go index 3e9e2ab8..9fb6ad89 100644 --- a/plugins/access-control/repositories/role_permission_repository_test.go +++ b/plugins/access-control/repositories/role_permission_repository_test.go @@ -2,9 +2,9 @@ package repositories import ( "context" + "strings" "testing" - accesscontrolconstants "github.com/Authula/authula/plugins/access-control/constants" plugintests "github.com/Authula/authula/plugins/access-control/tests" "github.com/Authula/authula/plugins/access-control/types" ) @@ -125,7 +125,6 @@ func TestBunRolePermissionsRepositoryAddRolePermission(t *testing.T) { panic(err) } }, - wantErr: accesscontrolconstants.ErrConflict, }, } @@ -144,6 +143,15 @@ func TestBunRolePermissionsRepositoryAddRolePermission(t *testing.T) { } err := rolePermissionsRepo.AddRolePermission(ctx, tc.roleID, tc.permissionID, tc.grantedByUserID) + if tc.name == "duplicate grant returns conflict" { + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "UNIQUE constraint failed: access_control_role_permissions.role_id, access_control_role_permissions.permission_id") { + t.Fatalf("expected raw unique constraint error, got %v", err) + } + return + } if err != tc.wantErr { t.Fatalf("expected err %v, got %v", tc.wantErr, err) } diff --git a/plugins/access-control/repositories/roles_repository.go b/plugins/access-control/repositories/roles_repository.go index 718c5d81..5d876008 100644 --- a/plugins/access-control/repositories/roles_repository.go +++ b/plugins/access-control/repositories/roles_repository.go @@ -21,7 +21,7 @@ func NewBunRolesRepository(db bun.IDB) *BunRolesRepository { func (r *BunRolesRepository) CreateRole(ctx context.Context, role *types.Role) error { _, err := r.db.NewInsert().Model(role).Exec(ctx) - return wrapRepositoryError("create role", err) + return err } func (r *BunRolesRepository) GetAllRoles(ctx context.Context) ([]types.Role, error) { @@ -46,6 +46,19 @@ func (r *BunRolesRepository) GetRoleByID(ctx context.Context, roleID string) (*t return role, nil } +func (r *BunRolesRepository) GetRoleByName(ctx context.Context, roleName string) (*types.Role, error) { + role := new(types.Role) + err := r.db.NewSelect().Model(role).Where("name = ?", roleName).Scan(ctx) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, fmt.Errorf("failed to get role by name: %w", err) + } + + return role, nil +} + func (r *BunRolesRepository) UpdateRole(ctx context.Context, roleID string, name *string, description *string) (bool, error) { query := r.db.NewUpdate(). Model((*types.Role)(nil)). @@ -62,7 +75,7 @@ func (r *BunRolesRepository) UpdateRole(ctx context.Context, roleID string, name result, err := query.Exec(ctx) if err != nil { - return false, wrapRepositoryError("update role", err) + return false, err } affected, err := result.RowsAffected() @@ -76,7 +89,7 @@ func (r *BunRolesRepository) UpdateRole(ctx context.Context, roleID string, name func (r *BunRolesRepository) DeleteRole(ctx context.Context, roleID string) (bool, error) { result, err := r.db.NewDelete().Model((*types.Role)(nil)).Where("id = ?", roleID).Exec(ctx) if err != nil { - return false, fmt.Errorf("failed to delete role: %w", err) + return false, err } affected, err := result.RowsAffected() diff --git a/plugins/access-control/repositories/roles_repository_test.go b/plugins/access-control/repositories/roles_repository_test.go index 13379f73..9905eff4 100644 --- a/plugins/access-control/repositories/roles_repository_test.go +++ b/plugins/access-control/repositories/roles_repository_test.go @@ -2,9 +2,9 @@ package repositories import ( "context" + "strings" "testing" - accesscontrolconstants "github.com/Authula/authula/plugins/access-control/constants" plugintests "github.com/Authula/authula/plugins/access-control/tests" "github.com/Authula/authula/plugins/access-control/types" ) @@ -32,7 +32,12 @@ func TestBunRolesRepositoryCreateRole(t *testing.T) { name: "duplicate name returns conflict", role: &types.Role{ID: "r2", Name: "editor", Description: new("Duplicate role"), IsSystem: false}, seed: &types.Role{ID: "r1", Name: "editor", Description: new("Original role"), IsSystem: false}, - wantErr: accesscontrolconstants.ErrConflict, + wantErr: nil, + }, + { + name: "query error returns wrapped error", + role: &types.Role{ID: "r3", Name: "reviewer", Description: new("Reviewer role"), IsSystem: false}, + wantErr: nil, }, } @@ -50,7 +55,31 @@ func TestBunRolesRepositoryCreateRole(t *testing.T) { } } + if tc.name == "query error returns wrapped error" { + if err := db.Close(); err != nil { + t.Fatalf("failed to close db: %v", err) + } + } + err := repo.CreateRole(ctx, tc.role) + if tc.name == "query error returns wrapped error" { + if err == nil { + t.Fatal("expected error, got nil") + } + if err.Error() != "sql: database is closed" { + t.Fatalf("expected direct db error, got %v", err) + } + return + } + if tc.name == "duplicate name returns conflict" { + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "UNIQUE constraint failed: access_control_roles.name") { + t.Fatalf("expected raw unique constraint error, got %v", err) + } + return + } if err != tc.wantErr { t.Fatalf("expected err %v, got %v", tc.wantErr, err) } @@ -83,11 +112,12 @@ func TestBunRolesRepositoryGetAllRoles(t *testing.T) { t.Parallel() tests := []struct { - name string - seedRoles []*types.Role - wantIDs []string - wantNames []string - wantDescs []*string + name string + seedRoles []*types.Role + wantIDs []string + wantNames []string + wantDescs []*string + wantErrMsg string }{ { name: "empty result", @@ -105,6 +135,10 @@ func TestBunRolesRepositoryGetAllRoles(t *testing.T) { wantNames: []string{"viewer", "editor"}, wantDescs: []*string{new("Viewer role"), new("Editor role")}, }, + { + name: "query error", + wantErrMsg: "failed to get roles", + }, } for _, tc := range tests { @@ -121,7 +155,25 @@ func TestBunRolesRepositoryGetAllRoles(t *testing.T) { } } + if tc.wantErrMsg != "" { + if err := db.Close(); err != nil { + t.Fatalf("failed to close db: %v", err) + } + } + roles, err := repo.GetAllRoles(ctx) + if tc.wantErrMsg != "" { + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tc.wantErrMsg) { + t.Fatalf("expected direct db error, got %v", err) + } + if roles != nil { + t.Fatalf("expected nil roles on error, got %#v", roles) + } + return + } if err != nil { t.Fatalf("failed to get roles: %v", err) } @@ -157,6 +209,7 @@ func TestBunRolesRepositoryGetRoleByID(t *testing.T) { wantName string wantDesc *string wantSystem bool + wantErrMsg string }{ { name: "not found", @@ -171,6 +224,11 @@ func TestBunRolesRepositoryGetRoleByID(t *testing.T) { wantDesc: new("Editor role"), wantSystem: true, }, + { + name: "query error", + roleID: "r1", + wantErrMsg: "failed to get role by id", + }, } for _, tc := range tests { @@ -187,7 +245,25 @@ func TestBunRolesRepositoryGetRoleByID(t *testing.T) { } } + if tc.wantErrMsg != "" { + if err := db.Close(); err != nil { + t.Fatalf("failed to close db: %v", err) + } + } + role, err := repo.GetRoleByID(ctx, tc.roleID) + if tc.wantErrMsg != "" { + if err == nil { + t.Fatal("expected error, got nil") + } + if role != nil { + t.Fatalf("expected nil role on error, got %#v", role) + } + if !strings.Contains(err.Error(), tc.wantErrMsg) { + t.Fatalf("expected direct db error, got %v", err) + } + return + } if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -210,6 +286,97 @@ func TestBunRolesRepositoryGetRoleByID(t *testing.T) { } } +func TestBunRolesRepositoryGetRoleByName(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + roleName string + seedRole *types.Role + wantNil bool + wantName string + wantDesc *string + wantSystem bool + wantErr bool + wantErrMsg string + }{ + { + name: "not found", + roleName: "missing", + wantNil: true, + }, + { + name: "success", + roleName: "editor", + seedRole: &types.Role{ID: "r1", Name: "editor", Description: new("Editor role"), IsSystem: true}, + wantName: "editor", + wantDesc: new("Editor role"), + wantSystem: true, + }, + { + name: "query error", + roleName: "editor", + seedRole: &types.Role{ID: "r1", Name: "editor", Description: new("Editor role"), IsSystem: false}, + wantErr: true, + wantErrMsg: "failed to get role by name", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + repo := NewBunRolesRepository(db) + ctx := context.Background() + + if tc.seedRole != nil { + if err := repo.CreateRole(ctx, tc.seedRole); err != nil { + t.Fatalf("failed to seed role: %v", err) + } + } + + if tc.wantErr { + if err := db.Close(); err != nil { + t.Fatalf("failed to close db: %v", err) + } + } + + role, err := repo.GetRoleByName(ctx, tc.roleName) + if tc.wantErr { + if err == nil { + t.Fatal("expected error, got nil") + } + if role != nil { + t.Fatalf("expected nil role on error, got %#v", role) + } + if !strings.Contains(err.Error(), tc.wantErrMsg) { + t.Fatalf("expected direct db error, got %v", err) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tc.wantNil { + if role != nil { + t.Fatalf("expected nil role, got %#v", role) + } + return + } + if role == nil { + t.Fatal("expected role, got nil") + } + if role.Name != tc.wantName || role.IsSystem != tc.wantSystem { + t.Fatalf("unexpected role: %#v", role) + } + if !stringPtrEqual(role.Description, tc.wantDesc) { + t.Fatalf("expected description %#v, got %#v", tc.wantDesc, role.Description) + } + }) + } +} + func TestBunRolesRepositoryUpdateRole(t *testing.T) { t.Parallel() @@ -222,6 +389,7 @@ func TestBunRolesRepositoryUpdateRole(t *testing.T) { wantUpdated bool wantName *string wantDesc *string + wantErrMsg string }{ { name: "missing role", @@ -266,6 +434,12 @@ func TestBunRolesRepositoryUpdateRole(t *testing.T) { wantName: new("reviewer"), wantDesc: new("Reviewer role"), }, + { + name: "query error", + roleID: "r5", + nameValue: new("updated"), + wantErrMsg: "sql: database is closed", + }, } for _, tc := range tests { @@ -282,7 +456,25 @@ func TestBunRolesRepositoryUpdateRole(t *testing.T) { } } + if tc.wantErrMsg != "" { + if err := db.Close(); err != nil { + t.Fatalf("failed to close db: %v", err) + } + } + updated, err := repo.UpdateRole(ctx, tc.roleID, tc.nameValue, tc.description) + if tc.wantErrMsg != "" { + if err == nil { + t.Fatal("expected error, got nil") + } + if updated { + t.Fatal("expected updated=false on error") + } + if !strings.Contains(err.Error(), tc.wantErrMsg) { + t.Fatalf("expected direct db error, got %v", err) + } + return + } if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -322,6 +514,7 @@ func TestBunRolesRepositoryDeleteRole(t *testing.T) { seedRole *types.Role roleID string wantDeleted bool + wantErrMsg string }{ { name: "missing role", @@ -334,6 +527,11 @@ func TestBunRolesRepositoryDeleteRole(t *testing.T) { roleID: "r1", wantDeleted: true, }, + { + name: "query error", + roleID: "r5", + wantErrMsg: "sql: database is closed", + }, } for _, tc := range tests { @@ -350,7 +548,25 @@ func TestBunRolesRepositoryDeleteRole(t *testing.T) { } } + if tc.wantErrMsg != "" { + if err := db.Close(); err != nil { + t.Fatalf("failed to close db: %v", err) + } + } + deleted, err := repo.DeleteRole(ctx, tc.roleID) + if tc.wantErrMsg != "" { + if err == nil { + t.Fatal("expected error, got nil") + } + if deleted { + t.Fatal("expected deleted=false on error") + } + if !strings.Contains(err.Error(), tc.wantErrMsg) { + t.Fatalf("expected direct db error, got %v", err) + } + return + } if err != nil { t.Fatalf("unexpected error: %v", err) } diff --git a/plugins/access-control/repositories/user_access_repository.go b/plugins/access-control/repositories/user_access_repository.go deleted file mode 100644 index ccf0cbbb..00000000 --- a/plugins/access-control/repositories/user_access_repository.go +++ /dev/null @@ -1,190 +0,0 @@ -package repositories - -import ( - "context" - "database/sql" - "fmt" - "time" - - "github.com/uptrace/bun" - - "github.com/Authula/authula/plugins/access-control/types" -) - -type BunUserAccessRepository struct { - db bun.IDB -} - -func NewBunUserAccessRepository(db bun.IDB) *BunUserAccessRepository { - return &BunUserAccessRepository{db: db} -} - -func (r *BunUserAccessRepository) CountUserAssignmentsByRoleID(ctx context.Context, roleID string) (int, error) { - count, err := r.db.NewSelect(). - Model((*types.UserRole)(nil)). - Where("role_id = ?", roleID). - Count(ctx) - if err != nil { - return 0, fmt.Errorf("failed to count role user assignments: %w", err) - } - - return count, nil -} - -func (r *BunUserAccessRepository) CountRoleAssignmentsByPermissionID(ctx context.Context, permissionID string) (int, error) { - count, err := r.db.NewSelect(). - Model((*types.RolePermission)(nil)). - Where("permission_id = ?", permissionID). - Count(ctx) - if err != nil { - return 0, fmt.Errorf("failed to count permission role assignments: %w", err) - } - - return count, nil -} - -type userEffectivePermissionRow struct { - PermissionID string `bun:"permission_id"` - PermissionKey string `bun:"permission_key"` - PermissionDescription *string `bun:"permission_description"` - SourceRoleID string `bun:"source_role_id"` - SourceRoleName string `bun:"source_role_name"` - GrantedByUserID *string `bun:"granted_by_user_id"` - GrantedAt *time.Time `bun:"granted_at"` -} - -func (r *BunUserAccessRepository) GetUserEffectivePermissions(ctx context.Context, userID string) ([]types.UserPermissionInfo, error) { - var rows []userEffectivePermissionRow - now := time.Now().UTC() - err := r.db.NewSelect(). - TableExpr("access_control_user_roles pur"). - ColumnExpr("pp.id AS permission_id"). - ColumnExpr("pp.key AS permission_key"). - ColumnExpr("pp.description AS permission_description"). - ColumnExpr("pr.id AS source_role_id"). - ColumnExpr("pr.name AS source_role_name"). - ColumnExpr("prp.granted_by_user_id AS granted_by_user_id"). - ColumnExpr("prp.granted_at AS granted_at"). - Join("JOIN access_control_role_permissions prp ON prp.role_id = pur.role_id"). - Join("JOIN access_control_permissions pp ON pp.id = prp.permission_id"). - Join("JOIN access_control_roles pr ON pr.id = pur.role_id"). - Where("pur.user_id = ?", userID). - Where("pur.expires_at IS NULL OR pur.expires_at > ?", now). - OrderExpr("pp.key ASC"). - OrderExpr("pr.name ASC"). - OrderExpr("CASE WHEN prp.granted_at IS NULL THEN 1 ELSE 0 END ASC"). - OrderExpr("prp.granted_at DESC"). - Scan(ctx, &rows) - if err != nil { - return nil, fmt.Errorf("failed to get user effective permissions: %w", err) - } - if rows == nil { - return []types.UserPermissionInfo{}, nil - } - - permissions := make([]types.UserPermissionInfo, 0) - permissionIndex := make(map[string]int) - - for _, row := range rows { - idx, exists := permissionIndex[row.PermissionID] - if !exists { - permissions = append(permissions, types.UserPermissionInfo{ - PermissionID: row.PermissionID, - PermissionKey: row.PermissionKey, - PermissionDescription: row.PermissionDescription, - }) - idx = len(permissions) - 1 - permissionIndex[row.PermissionID] = idx - } - - source := types.PermissionGrantSource{ - RoleID: row.SourceRoleID, - RoleName: row.SourceRoleName, - GrantedByUserID: row.GrantedByUserID, - GrantedAt: row.GrantedAt, - } - permissions[idx].Sources = append(permissions[idx].Sources, source) - } - - return permissions, nil -} - -type userWithPermissionRow struct { - UserID string `bun:"user_id"` - UserName string `bun:"user_name"` - UserEmail string `bun:"user_email"` - EmailVerified bool `bun:"email_verified"` - Image *string - Metadata []byte - CreatedAt time.Time `bun:"created_at"` - UpdatedAt time.Time `bun:"updated_at"` - PermissionID *string `bun:"permission_id"` - PermissionKey *string `bun:"permission_key"` -} - -func (r *BunUserAccessRepository) GetUserWithPermissionsByID(ctx context.Context, userID string) (*types.UserWithPermissions, error) { - var rows []userWithPermissionRow - now := time.Now().UTC() - err := r.db.NewSelect(). - TableExpr("users u"). - ColumnExpr("u.id AS user_id"). - ColumnExpr("u.name AS user_name"). - ColumnExpr("u.email AS user_email"). - ColumnExpr("u.email_verified AS email_verified"). - ColumnExpr("u.image AS image"). - ColumnExpr("u.metadata AS metadata"). - ColumnExpr("u.created_at AS created_at"). - ColumnExpr("u.updated_at AS updated_at"). - ColumnExpr("ap.id AS permission_id"). - ColumnExpr("ap.key AS permission_key"). - Join("LEFT JOIN access_control_user_roles aur ON aur.user_id = u.id AND (aur.expires_at IS NULL OR aur.expires_at > ?)", now). - Join("LEFT JOIN access_control_role_permissions arp ON arp.role_id = aur.role_id"). - Join("LEFT JOIN access_control_permissions ap ON ap.id = arp.permission_id"). - Where("u.id = ?", userID). - OrderExpr("ap.key ASC"). - Scan(ctx, &rows) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil - } - return nil, fmt.Errorf("failed to get user with permissions: %w", err) - } - if len(rows) == 0 { - return nil, nil - } - - result := &types.UserWithPermissions{ - User: mapRowToUser(rows[0]), - } - - seen := make(map[string]struct{}) - for _, row := range rows { - if row.PermissionID == nil || *row.PermissionID == "" { - continue - } - if _, ok := seen[*row.PermissionID]; ok { - continue - } - seen[*row.PermissionID] = struct{}{} - permissionKey := "" - if row.PermissionKey != nil { - permissionKey = *row.PermissionKey - } - result.Permissions = append(result.Permissions, types.UserPermissionInfo{PermissionID: *row.PermissionID, PermissionKey: permissionKey}) - } - - return result, nil -} - -func (r userWithPermissionRow) GetUserID() string { return r.UserID } -func (r userWithPermissionRow) GetUserName() string { return r.UserName } -func (r userWithPermissionRow) GetUserEmail() string { return r.UserEmail } -func (r userWithPermissionRow) GetEmailVerified() bool { return r.EmailVerified } -func (r userWithPermissionRow) GetImage() *string { return r.Image } -func (r userWithPermissionRow) GetMetadata() []byte { return r.Metadata } -func (r userWithPermissionRow) GetCreatedAt() time.Time { - return r.CreatedAt -} -func (r userWithPermissionRow) GetUpdatedAt() time.Time { - return r.UpdatedAt -} diff --git a/plugins/access-control/repositories/user_access_repository_test.go b/plugins/access-control/repositories/user_access_repository_test.go deleted file mode 100644 index b7f9dbde..00000000 --- a/plugins/access-control/repositories/user_access_repository_test.go +++ /dev/null @@ -1,272 +0,0 @@ -package repositories - -import ( - "context" - "testing" - "time" - - plugintests "github.com/Authula/authula/plugins/access-control/tests" - "github.com/Authula/authula/plugins/access-control/types" -) - -func TestBunUserAccessRepositoryCounts(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - seed func(*BunRolesRepository, *BunPermissionsRepository, *BunRolePermissionsRepository, *BunUserRolesRepository, context.Context) - roleID string - permissionID string - wantRoleCount int - wantPermissionCount int - }{ - { - name: "empty counts", - roleID: "role-missing", - permissionID: "perm-missing", - wantRoleCount: 0, - wantPermissionCount: 0, - }, - { - name: "counts assigned records", - roleID: "role-1", - permissionID: "perm-1", - seed: func(rolesRepo *BunRolesRepository, permissionsRepo *BunPermissionsRepository, rolePermissionsRepo *BunRolePermissionsRepository, userRolesRepo *BunUserRolesRepository, ctx context.Context) { - if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "role-1", Name: "editor"}); err != nil { - panic(err) - } - if err := permissionsRepo.CreatePermission(ctx, &types.Permission{ID: "perm-1", Key: "users.read"}); err != nil { - panic(err) - } - if err := rolePermissionsRepo.AddRolePermission(ctx, "role-1", "perm-1", nil); err != nil { - panic(err) - } - if err := userRolesRepo.AssignUserRole(ctx, "u1", "role-1", nil, nil); err != nil { - panic(err) - } - }, - wantRoleCount: 1, - wantPermissionCount: 1, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - db := plugintests.SetupRepoDB(t) - rolesRepo := NewBunRolesRepository(db) - permissionsRepo := NewBunPermissionsRepository(db) - rolePermissionsRepo := NewBunRolePermissionsRepository(db) - userRolesRepo := NewBunUserRolesRepository(db) - repo := NewBunUserAccessRepository(db) - ctx := context.Background() - - if tc.seed != nil { - tc.seed(rolesRepo, permissionsRepo, rolePermissionsRepo, userRolesRepo, ctx) - } - - roleCount, err := repo.CountUserAssignmentsByRoleID(ctx, tc.roleID) - if err != nil { - t.Fatalf("failed to count role assignments: %v", err) - } - permissionCount, err := repo.CountRoleAssignmentsByPermissionID(ctx, tc.permissionID) - if err != nil { - t.Fatalf("failed to count permission assignments: %v", err) - } - - if roleCount != tc.wantRoleCount { - t.Fatalf("expected role count %d, got %d", tc.wantRoleCount, roleCount) - } - if permissionCount != tc.wantPermissionCount { - t.Fatalf("expected permission count %d, got %d", tc.wantPermissionCount, permissionCount) - } - }) - } -} - -func TestBunUserAccessRepositoryGetUserEffectivePermissions(t *testing.T) { - t.Parallel() - - activeUntil := time.Now().UTC().Add(1 * time.Hour) - expiredAt := time.Now().UTC().Add(-1 * time.Hour) - description := new("Read users") - grantedBy := "u2" - - tests := []struct { - name string - seed func(*BunRolesRepository, *BunPermissionsRepository, *BunRolePermissionsRepository, *BunUserRolesRepository, context.Context) - userID string - wantKeys []string - wantSources int - }{ - { - name: "empty result", - userID: "missing-user", - wantKeys: []string{}, - }, - { - name: "aggregates active permissions and ignores expired roles", - userID: "u1", - seed: func(rolesRepo *BunRolesRepository, permissionsRepo *BunPermissionsRepository, rolePermissionsRepo *BunRolePermissionsRepository, userRolesRepo *BunUserRolesRepository, ctx context.Context) { - if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "r1", Name: "editor"}); err != nil { - panic(err) - } - if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "r2", Name: "viewer"}); err != nil { - panic(err) - } - if err := permissionsRepo.CreatePermission(ctx, &types.Permission{ID: "p1", Key: "posts.read", Description: description}); err != nil { - panic(err) - } - if err := rolePermissionsRepo.AddRolePermission(ctx, "r1", "p1", &grantedBy); err != nil { - panic(err) - } - if err := rolePermissionsRepo.AddRolePermission(ctx, "r2", "p1", &grantedBy); err != nil { - panic(err) - } - if err := userRolesRepo.AssignUserRole(ctx, "u1", "r1", nil, &activeUntil); err != nil { - panic(err) - } - if err := userRolesRepo.AssignUserRole(ctx, "u1", "r2", nil, &expiredAt); err != nil { - panic(err) - } - }, - wantKeys: []string{"posts.read"}, - wantSources: 1, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - db := plugintests.SetupRepoDB(t) - rolesRepo := NewBunRolesRepository(db) - permissionsRepo := NewBunPermissionsRepository(db) - rolePermissionsRepo := NewBunRolePermissionsRepository(db) - userRolesRepo := NewBunUserRolesRepository(db) - repo := NewBunUserAccessRepository(db) - ctx := context.Background() - - if tc.seed != nil { - tc.seed(rolesRepo, permissionsRepo, rolePermissionsRepo, userRolesRepo, ctx) - } - - permissions, err := repo.GetUserEffectivePermissions(ctx, tc.userID) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if permissions == nil { - t.Fatal("expected permissions slice, got nil") - } - if len(permissions) != len(tc.wantKeys) { - t.Fatalf("expected %d permissions, got %d", len(tc.wantKeys), len(permissions)) - } - for i, wantKey := range tc.wantKeys { - if permissions[i].PermissionKey != wantKey { - t.Fatalf("expected permission key %s at index %d, got %#v", wantKey, i, permissions[i]) - } - if len(permissions[i].Sources) != tc.wantSources { - t.Fatalf("expected %d sources, got %d", tc.wantSources, len(permissions[i].Sources)) - } - if permissions[i].PermissionDescription == nil || *permissions[i].PermissionDescription != "Read users" { - t.Fatalf("expected permission description to be populated, got %#v", permissions[i].PermissionDescription) - } - } - }) - } -} - -func TestBunUserAccessRepositoryGetUserWithPermissionsByID(t *testing.T) { - t.Parallel() - - permissionDescription := new("Read users") - - tests := []struct { - name string - seed func(*BunRolesRepository, *BunPermissionsRepository, *BunRolePermissionsRepository, *BunUserRolesRepository, context.Context) - userID string - wantNil bool - wantPermissionKeys []string - }{ - { - name: "not found", - userID: "missing-user", - wantNil: true, - }, - { - name: "success", - userID: "u1", - seed: func(rolesRepo *BunRolesRepository, permissionsRepo *BunPermissionsRepository, rolePermissionsRepo *BunRolePermissionsRepository, userRolesRepo *BunUserRolesRepository, ctx context.Context) { - if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "r1", Name: "editor"}); err != nil { - panic(err) - } - if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "r2", Name: "viewer"}); err != nil { - panic(err) - } - if err := permissionsRepo.CreatePermission(ctx, &types.Permission{ID: "p1", Key: "posts.read", Description: permissionDescription}); err != nil { - panic(err) - } - if err := permissionsRepo.CreatePermission(ctx, &types.Permission{ID: "p2", Key: "posts.write"}); err != nil { - panic(err) - } - if err := rolePermissionsRepo.AddRolePermission(ctx, "r1", "p1", nil); err != nil { - panic(err) - } - if err := rolePermissionsRepo.AddRolePermission(ctx, "r2", "p2", nil); err != nil { - panic(err) - } - if err := userRolesRepo.AssignUserRole(ctx, "u1", "r1", nil, nil); err != nil { - panic(err) - } - if err := userRolesRepo.AssignUserRole(ctx, "u1", "r2", nil, nil); err != nil { - panic(err) - } - }, - wantPermissionKeys: []string{"posts.read", "posts.write"}, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - db := plugintests.SetupRepoDB(t) - rolesRepo := NewBunRolesRepository(db) - permissionsRepo := NewBunPermissionsRepository(db) - rolePermissionsRepo := NewBunRolePermissionsRepository(db) - userRolesRepo := NewBunUserRolesRepository(db) - repo := NewBunUserAccessRepository(db) - ctx := context.Background() - - if tc.seed != nil { - tc.seed(rolesRepo, permissionsRepo, rolePermissionsRepo, userRolesRepo, ctx) - } - - userWithPermissions, err := repo.GetUserWithPermissionsByID(ctx, tc.userID) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if tc.wantNil { - if userWithPermissions != nil { - t.Fatalf("expected nil user, got %#v", userWithPermissions) - } - return - } - if userWithPermissions == nil { - t.Fatal("expected user, got nil") - } - if userWithPermissions.User.ID != tc.userID { - t.Fatalf("expected user ID %s, got %s", tc.userID, userWithPermissions.User.ID) - } - if len(userWithPermissions.Permissions) != len(tc.wantPermissionKeys) { - t.Fatalf("expected %d permissions, got %d", len(tc.wantPermissionKeys), len(userWithPermissions.Permissions)) - } - for i, wantKey := range tc.wantPermissionKeys { - if userWithPermissions.Permissions[i].PermissionKey != wantKey { - t.Fatalf("expected permission key %s at index %d, got %#v", wantKey, i, userWithPermissions.Permissions[i]) - } - } - }) - } -} diff --git a/plugins/access-control/repositories/user_permissions_repository.go b/plugins/access-control/repositories/user_permissions_repository.go new file mode 100644 index 00000000..198e3c14 --- /dev/null +++ b/plugins/access-control/repositories/user_permissions_repository.go @@ -0,0 +1,112 @@ +package repositories + +import ( + "context" + "fmt" + "time" + + "github.com/Authula/authula/plugins/access-control/types" + "github.com/uptrace/bun" +) + +type BunUserPermissionsRepository struct { + db bun.IDB +} + +func NewBunUserPermissionsRepository(db bun.IDB) *BunUserPermissionsRepository { + return &BunUserPermissionsRepository{db: db} +} + +type userPermissionGrantRow struct { + PermissionID string `bun:"permission_id"` + PermissionKey string `bun:"permission_key"` + PermissionDescription *string `bun:"permission_description"` + GrantedByUserID *string `bun:"granted_by_user_id"` + GrantedAt *time.Time `bun:"granted_at"` + RoleID string `bun:"role_id"` + RoleName string `bun:"role_name"` +} + +func (r *BunUserPermissionsRepository) GetUserPermissions(ctx context.Context, userID string) ([]types.UserPermissionInfo, error) { + var scanned []userPermissionGrantRow + err := r.db.NewSelect(). + TableExpr("access_control_user_roles acur"). + ColumnExpr("ap.id AS permission_id"). + ColumnExpr("ap.key AS permission_key"). + ColumnExpr("ap.description AS permission_description"). + ColumnExpr("arp.granted_by_user_id AS granted_by_user_id"). + ColumnExpr("arp.granted_at AS granted_at"). + ColumnExpr("acr.id AS role_id"). + ColumnExpr("acr.name AS role_name"). + Join("JOIN access_control_roles acr ON acr.id = acur.role_id"). + Join("JOIN access_control_role_permissions arp ON arp.role_id = acr.id"). + Join("JOIN access_control_permissions ap ON ap.id = arp.permission_id"). + Where("acur.user_id = ?", userID). + Where("(acur.expires_at IS NULL OR acur.expires_at > CURRENT_TIMESTAMP)"). + OrderExpr("ap.key ASC, acr.name ASC, arp.granted_at ASC"). + Scan(ctx, &scanned) + if err != nil { + return nil, fmt.Errorf("failed to get user permissions: %w", err) + } + + if len(scanned) == 0 { + return []types.UserPermissionInfo{}, nil + } + + permissionsByID := make(map[string]*types.UserPermissionInfo, len(scanned)) + orderedPermissionIDs := make([]string, 0, len(scanned)) + + for _, row := range scanned { + permission, exists := permissionsByID[row.PermissionID] + if !exists { + permission = &types.UserPermissionInfo{ + PermissionID: row.PermissionID, + PermissionKey: row.PermissionKey, + PermissionDescription: row.PermissionDescription, + GrantedByUserID: row.GrantedByUserID, + GrantedAt: row.GrantedAt, + Sources: []types.PermissionGrantSource{}, + } + permissionsByID[row.PermissionID] = permission + orderedPermissionIDs = append(orderedPermissionIDs, row.PermissionID) + } + + permission.Sources = append(permission.Sources, types.PermissionGrantSource{ + RoleID: row.RoleID, + RoleName: row.RoleName, + GrantedByUserID: row.GrantedByUserID, + GrantedAt: row.GrantedAt, + }) + } + + permissions := make([]types.UserPermissionInfo, 0, len(orderedPermissionIDs)) + for _, permissionID := range orderedPermissionIDs { + permissions = append(permissions, *permissionsByID[permissionID]) + } + + return permissions, nil +} + +func (r *BunUserPermissionsRepository) HasPermissions(ctx context.Context, userID string, permissionKeys []string) (bool, error) { + if len(permissionKeys) == 0 { + return true, nil + } + + permissions, err := r.GetUserPermissions(ctx, userID) + if err != nil { + return false, err + } + + granted := make(map[string]struct{}, len(permissions)) + for _, permission := range permissions { + granted[permission.PermissionKey] = struct{}{} + } + + for _, permissionKey := range permissionKeys { + if _, ok := granted[permissionKey]; !ok { + return false, nil + } + } + + return true, nil +} diff --git a/plugins/access-control/repositories/user_permissions_repository_test.go b/plugins/access-control/repositories/user_permissions_repository_test.go new file mode 100644 index 00000000..2fbfee27 --- /dev/null +++ b/plugins/access-control/repositories/user_permissions_repository_test.go @@ -0,0 +1,209 @@ +package repositories + +import ( + "context" + "testing" + + plugintests "github.com/Authula/authula/plugins/access-control/tests" + "github.com/Authula/authula/plugins/access-control/types" +) + +func TestBunUserPermissionsRepositoryGetUserPermissions(t *testing.T) { + t.Parallel() + + description := new(string) + *description = "Read users" + + tests := []struct { + name string + seed func(*BunRolesRepository, *BunPermissionsRepository, *BunRolePermissionsRepository, *BunUserRolesRepository, context.Context) + userID string + wantEmpty bool + wantKeys []string + wantSrcs []int + }{ + { + name: "empty result", + userID: "missing-user", + wantEmpty: true, + }, + { + name: "aggregates permissions across roles and ignores expired roles", + userID: "u1", + seed: func(rolesRepo *BunRolesRepository, permissionsRepo *BunPermissionsRepository, rolePermissionsRepo *BunRolePermissionsRepository, userRolesRepo *BunUserRolesRepository, ctx context.Context) { + if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "role-1", Name: "editor"}); err != nil { + panic(err) + } + if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "role-2", Name: "viewer"}); err != nil { + panic(err) + } + if err := permissionsRepo.CreatePermission(ctx, &types.Permission{ID: "perm-1", Key: "users.read", Description: description}); err != nil { + panic(err) + } + if err := permissionsRepo.CreatePermission(ctx, &types.Permission{ID: "perm-2", Key: "users.write"}); err != nil { + panic(err) + } + grantedBy := "u2" + if err := rolePermissionsRepo.AddRolePermission(ctx, "role-1", "perm-1", &grantedBy); err != nil { + panic(err) + } + if err := rolePermissionsRepo.AddRolePermission(ctx, "role-2", "perm-1", nil); err != nil { + panic(err) + } + if err := rolePermissionsRepo.AddRolePermission(ctx, "role-2", "perm-2", &grantedBy); err != nil { + panic(err) + } + if err := userRolesRepo.AssignUserRole(ctx, "u1", "role-1", &grantedBy, nil); err != nil { + panic(err) + } + if err := userRolesRepo.AssignUserRole(ctx, "u1", "role-2", nil, nil); err != nil { + panic(err) + } + }, + wantKeys: []string{"users.read", "users.write"}, + wantSrcs: []int{2, 1}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + rolesRepo := NewBunRolesRepository(db) + permissionsRepo := NewBunPermissionsRepository(db) + rolePermissionsRepo := NewBunRolePermissionsRepository(db) + userRolesRepo := NewBunUserRolesRepository(db) + repo := NewBunUserPermissionsRepository(db) + ctx := context.Background() + + if tc.seed != nil { + tc.seed(rolesRepo, permissionsRepo, rolePermissionsRepo, userRolesRepo, ctx) + } + + permissions, err := repo.GetUserPermissions(ctx, tc.userID) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if permissions == nil { + t.Fatal("expected permissions slice, got nil") + } + if tc.wantEmpty { + if len(permissions) != 0 { + t.Fatalf("expected no permissions, got %#v", permissions) + } + return + } + if len(permissions) != len(tc.wantKeys) { + t.Fatalf("expected %d permissions, got %d", len(tc.wantKeys), len(permissions)) + } + for i := range tc.wantKeys { + if permissions[i].PermissionKey != tc.wantKeys[i] { + t.Fatalf("unexpected permission key at %d: %#v", i, permissions[i]) + } + if permissions[i].PermissionKey == "users.read" && !sameStringPtr(permissions[i].PermissionDescription, description) { + t.Fatalf("unexpected permission description at %d: %#v", i, permissions[i]) + } + if permissions[i].GrantedAt == nil { + t.Fatalf("expected granted_at to be populated at %d", i) + } + if len(permissions[i].Sources) != tc.wantSrcs[i] { + t.Fatalf("expected %d sources at %d, got %#v", tc.wantSrcs[i], i, permissions[i]) + } + } + }) + } +} + +func TestBunUserPermissionsRepositoryHasPermissions(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + seed func(*BunRolesRepository, *BunPermissionsRepository, *BunRolePermissionsRepository, *BunUserRolesRepository, context.Context) + userID string + permissionKeys []string + wantHasPerms bool + }{ + { + name: "empty permission list returns true", + userID: "u1", + permissionKeys: []string{}, + wantHasPerms: true, + }, + { + name: "missing permission returns false", + userID: "u1", + permissionKeys: []string{"users.read", "users.delete"}, + seed: func(rolesRepo *BunRolesRepository, permissionsRepo *BunPermissionsRepository, rolePermissionsRepo *BunRolePermissionsRepository, userRolesRepo *BunUserRolesRepository, ctx context.Context) { + if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "role-1", Name: "editor"}); err != nil { + panic(err) + } + if err := permissionsRepo.CreatePermission(ctx, &types.Permission{ID: "perm-1", Key: "users.read"}); err != nil { + panic(err) + } + if err := rolePermissionsRepo.AddRolePermission(ctx, "role-1", "perm-1", nil); err != nil { + panic(err) + } + if err := userRolesRepo.AssignUserRole(ctx, "u1", "role-1", nil, nil); err != nil { + panic(err) + } + }, + wantHasPerms: false, + }, + { + name: "success", + userID: "u1", + permissionKeys: []string{"users.read"}, + seed: func(rolesRepo *BunRolesRepository, permissionsRepo *BunPermissionsRepository, rolePermissionsRepo *BunRolePermissionsRepository, userRolesRepo *BunUserRolesRepository, ctx context.Context) { + if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "role-1", Name: "editor"}); err != nil { + panic(err) + } + if err := permissionsRepo.CreatePermission(ctx, &types.Permission{ID: "perm-1", Key: "users.read"}); err != nil { + panic(err) + } + grantedBy := "u2" + if err := rolePermissionsRepo.AddRolePermission(ctx, "role-1", "perm-1", &grantedBy); err != nil { + panic(err) + } + if err := userRolesRepo.AssignUserRole(ctx, "u1", "role-1", nil, nil); err != nil { + panic(err) + } + }, + wantHasPerms: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + rolesRepo := NewBunRolesRepository(db) + permissionsRepo := NewBunPermissionsRepository(db) + rolePermissionsRepo := NewBunRolePermissionsRepository(db) + userRolesRepo := NewBunUserRolesRepository(db) + repo := NewBunUserPermissionsRepository(db) + ctx := context.Background() + + if tc.seed != nil { + tc.seed(rolesRepo, permissionsRepo, rolePermissionsRepo, userRolesRepo, ctx) + } + + hasPermissions, err := repo.HasPermissions(ctx, tc.userID, tc.permissionKeys) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if hasPermissions != tc.wantHasPerms { + t.Fatalf("expected hasPermissions=%v, got %v", tc.wantHasPerms, hasPermissions) + } + }) + } +} + +func sameStringPtr(got, want *string) bool { + if got == nil || want == nil { + return got == nil && want == nil + } + return *got == *want +} diff --git a/plugins/access-control/repositories/user_roles_repository.go b/plugins/access-control/repositories/user_roles_repository.go index 3dc6af51..a773364d 100644 --- a/plugins/access-control/repositories/user_roles_repository.go +++ b/plugins/access-control/repositories/user_roles_repository.go @@ -2,7 +2,6 @@ package repositories import ( "context" - "database/sql" "fmt" "time" @@ -56,56 +55,6 @@ type userWithRoleRow struct { RoleName *string `bun:"role_name"` } -func (r *BunUserRolesRepository) GetUserWithRolesByID(ctx context.Context, userID string) (*types.UserWithRoles, error) { - var rows []userWithRoleRow - now := time.Now().UTC() - err := r.db.NewSelect(). - TableExpr("users u"). - ColumnExpr("u.id AS user_id"). - ColumnExpr("u.name AS user_name"). - ColumnExpr("u.email AS user_email"). - ColumnExpr("u.email_verified AS email_verified"). - ColumnExpr("u.image AS image"). - ColumnExpr("u.metadata AS metadata"). - ColumnExpr("u.created_at AS created_at"). - ColumnExpr("u.updated_at AS updated_at"). - ColumnExpr("pr.id AS role_id"). - ColumnExpr("pr.name AS role_name"). - Join("LEFT JOIN access_control_user_roles pur ON pur.user_id = u.id AND (pur.expires_at IS NULL OR pur.expires_at > ?)", now). - Join("LEFT JOIN access_control_roles pr ON pr.id = pur.role_id"). - Where("u.id = ?", userID). - OrderExpr("pr.name ASC"). - Scan(ctx, &rows) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil - } - return nil, fmt.Errorf("failed to get user with roles: %w", err) - } - if len(rows) == 0 { - return nil, nil - } - - result := &types.UserWithRoles{User: mapRowToUser(rows[0])} - seen := make(map[string]struct{}) - for _, row := range rows { - if row.RoleID == nil || *row.RoleID == "" { - continue - } - if _, ok := seen[*row.RoleID]; ok { - continue - } - seen[*row.RoleID] = struct{}{} - roleName := "" - if row.RoleName != nil { - roleName = *row.RoleName - } - result.Roles = append(result.Roles, types.UserRoleInfo{RoleID: *row.RoleID, RoleName: roleName}) - } - - return result, nil -} - func (r *BunUserRolesRepository) ReplaceUserRoles(ctx context.Context, userID string, roleIDs []string, assignedByUserID *string) error { return r.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { if _, err := tx.NewDelete().Model((*types.UserRole)(nil)).Where("user_id = ?", userID).Exec(ctx); err != nil { @@ -121,7 +70,7 @@ func (r *BunUserRolesRepository) ReplaceUserRoles(ctx context.Context, userID st AssignedAt: now, } if _, err := tx.NewInsert().Model(ur).Exec(ctx); err != nil { - return wrapRepositoryError("insert user role", err) + return err } } @@ -139,7 +88,10 @@ func (r *BunUserRolesRepository) AssignUserRole(ctx context.Context, userID stri } _, err := r.db.NewInsert().Model(ur).Exec(ctx) - return wrapRepositoryError("assign user role", err) + if err != nil { + return err + } + return nil } func (r *BunUserRolesRepository) RemoveUserRole(ctx context.Context, userID string, roleID string) error { @@ -149,12 +101,23 @@ func (r *BunUserRolesRepository) RemoveUserRole(ctx context.Context, userID stri Where("role_id = ?", roleID). Exec(ctx) if err != nil { - return fmt.Errorf("failed to remove user role: %w", err) + return err } return nil } +func (r *BunUserRolesRepository) CountUsersByRole(ctx context.Context, roleID string) (int, error) { + count, err := r.db.NewSelect(). + Model((*types.UserRole)(nil)). + Where("role_id = ?", roleID). + Count(ctx) + if err != nil { + return 0, fmt.Errorf("failed to count users by role: %w", err) + } + return count, nil +} + type userRow interface { GetUserID() string GetUserName() string diff --git a/plugins/access-control/repositories/user_roles_repository_test.go b/plugins/access-control/repositories/user_roles_repository_test.go index 8024bc1a..28583619 100644 --- a/plugins/access-control/repositories/user_roles_repository_test.go +++ b/plugins/access-control/repositories/user_roles_repository_test.go @@ -3,10 +3,10 @@ package repositories import ( "context" "reflect" + "strings" "testing" "time" - accesscontrolconstants "github.com/Authula/authula/plugins/access-control/constants" plugintests "github.com/Authula/authula/plugins/access-control/tests" "github.com/Authula/authula/plugins/access-control/types" ) @@ -14,7 +14,8 @@ import ( func TestBunUserRolesRepositoryGetUserRoles(t *testing.T) { t.Parallel() - futureExpiry := time.Date(2026, 3, 30, 12, 0, 0, 0, time.UTC) + now := time.Now().UTC() + futureExpiry := time.Unix(now.Add(24*time.Hour).Unix(), 0).UTC() roleDescription := new("Editor role") assignedBy := new("u2") @@ -106,98 +107,6 @@ func TestBunUserRolesRepositoryGetUserRoles(t *testing.T) { } } -func TestBunUserRolesRepositoryGetUserWithRolesByID(t *testing.T) { - t.Parallel() - - activeExpiry := time.Date(2026, 3, 30, 12, 0, 0, 0, time.UTC) - - tests := []struct { - name string - seed func(*BunRolesRepository, *BunUserRolesRepository, context.Context) - userID string - wantNil bool - wantRoles []types.UserRoleInfo - }{ - { - name: "not found", - userID: "missing-user", - wantNil: true, - }, - { - name: "returns user with active roles only", - userID: "u1", - seed: func(rolesRepo *BunRolesRepository, userRolesRepo *BunUserRolesRepository, ctx context.Context) { - if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "r1", Name: "editor"}); err != nil { - panic(err) - } - if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "r2", Name: "viewer"}); err != nil { - panic(err) - } - if err := userRolesRepo.AssignUserRole(ctx, "u1", "r1", nil, &activeExpiry); err != nil { - panic(err) - } - if err := userRolesRepo.AssignUserRole(ctx, "u1", "r2", nil, new(time.Date(2026, 3, 28, 12, 0, 0, 0, time.UTC))); err != nil { - panic(err) - } - }, - wantRoles: []types.UserRoleInfo{ - { - RoleID: "r1", - RoleName: "editor", - }, - }, - }, - { - name: "returns user with no roles", - userID: "u2", - wantRoles: []types.UserRoleInfo{}, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - db := plugintests.SetupRepoDB(t) - rolesRepo := NewBunRolesRepository(db) - userRolesRepo := NewBunUserRolesRepository(db) - ctx := context.Background() - - if tc.seed != nil { - tc.seed(rolesRepo, userRolesRepo, ctx) - } - - userWithRoles, err := userRolesRepo.GetUserWithRolesByID(ctx, tc.userID) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if tc.wantNil { - if userWithRoles != nil { - t.Fatalf("expected nil user, got %#v", userWithRoles) - } - return - } - if userWithRoles == nil { - t.Fatal("expected user, got nil") - } - if userWithRoles.User.ID != tc.userID { - t.Fatalf("expected user ID %s, got %s", tc.userID, userWithRoles.User.ID) - } - if userWithRoles.User.Name == "" || userWithRoles.User.Email == "" { - t.Fatalf("expected user fields to be populated, got %#v", userWithRoles.User) - } - if len(userWithRoles.Roles) != len(tc.wantRoles) { - t.Fatalf("expected %d roles, got %#v", len(tc.wantRoles), userWithRoles.Roles) - } - for i := range tc.wantRoles { - if userWithRoles.Roles[i].RoleID != tc.wantRoles[i].RoleID || userWithRoles.Roles[i].RoleName != tc.wantRoles[i].RoleName { - t.Fatalf("unexpected role at %d: %#v", i, userWithRoles.Roles[i]) - } - } - }) - } -} - func TestBunUserRolesRepositoryReplaceUserRoles(t *testing.T) { t.Parallel() @@ -269,7 +178,8 @@ func TestBunUserRolesRepositoryReplaceUserRoles(t *testing.T) { func TestBunUserRolesRepositoryAssignUserRole(t *testing.T) { t.Parallel() - futureExpiry := time.Date(2026, 4, 1, 12, 0, 0, 0, time.UTC) + now := time.Now().UTC() + futureExpiry := time.Unix(now.Add(24*time.Hour).Unix(), 0).UTC() tests := []struct { name string @@ -302,7 +212,6 @@ func TestBunUserRolesRepositoryAssignUserRole(t *testing.T) { panic(err) } }, - wantErr: accesscontrolconstants.ErrConflict, }, } @@ -320,6 +229,15 @@ func TestBunUserRolesRepositoryAssignUserRole(t *testing.T) { } err := userRolesRepo.AssignUserRole(ctx, tc.userID, tc.roleID, nil, tc.expiresAt) + if tc.name == "duplicate assignment returns conflict" { + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "UNIQUE constraint failed: access_control_user_roles.user_id, access_control_user_roles.role_id") { + t.Fatalf("expected raw unique constraint error, got %v", err) + } + return + } if err != tc.wantErr { t.Fatalf("expected err %v, got %v", tc.wantErr, err) } diff --git a/plugins/access-control/routes.go b/plugins/access-control/routes.go index db4775fe..f1ee6d49 100644 --- a/plugins/access-control/routes.go +++ b/plugins/access-control/routes.go @@ -13,7 +13,7 @@ type routeUseCases struct { permissions *usecases.PermissionsUseCase rolePermissions *usecases.RolePermissionsUseCase userRoles *usecases.UserRolesUseCase - userAccess *usecases.UserAccessUseCase + userPermissions *usecases.UserPermissionsUseCase } func newRouteUseCases(api *API) routeUseCases { @@ -22,7 +22,7 @@ func newRouteUseCases(api *API) routeUseCases { permissions: api.useCases.PermissionsUseCase(), rolePermissions: api.useCases.RolePermissionsUseCase(), userRoles: api.useCases.UserRolesUseCase(), - userAccess: api.useCases.UserAccessUseCase(), + userPermissions: api.useCases.UserPermissionsUseCase(), } } @@ -30,28 +30,35 @@ func Routes(api *API) []models.Route { usecases := newRouteUseCases(api) return []models.Route{ - // Roles and permissions + // Roles {Method: http.MethodPost, Path: "/access-control/roles", Handler: handlers.NewCreateRoleHandler(usecases.roles).Handler()}, {Method: http.MethodGet, Path: "/access-control/roles", Handler: handlers.NewGetAllRolesHandler(usecases.roles).Handler()}, + {Method: http.MethodGet, Path: "/access-control/roles/by-name/{role_name}", Handler: handlers.NewGetRoleByNameHandler(usecases.roles).Handler()}, {Method: http.MethodGet, Path: "/access-control/roles/{role_id}", Handler: handlers.NewGetRoleByIDHandler(usecases.roles).Handler()}, {Method: http.MethodPatch, Path: "/access-control/roles/{role_id}", Handler: handlers.NewUpdateRoleHandler(usecases.roles).Handler()}, {Method: http.MethodDelete, Path: "/access-control/roles/{role_id}", Handler: handlers.NewDeleteRoleHandler(usecases.roles).Handler()}, + + // Permissions {Method: http.MethodPost, Path: "/access-control/permissions", Handler: handlers.NewCreatePermissionHandler(usecases.permissions).Handler()}, {Method: http.MethodGet, Path: "/access-control/permissions", Handler: handlers.NewGetAllPermissionsHandler(usecases.permissions).Handler()}, {Method: http.MethodGet, Path: "/access-control/permissions/{permission_id}", Handler: handlers.NewGetPermissionByIDHandler(usecases.permissions).Handler()}, {Method: http.MethodPatch, Path: "/access-control/permissions/{permission_id}", Handler: handlers.NewUpdatePermissionHandler(usecases.permissions).Handler()}, {Method: http.MethodDelete, Path: "/access-control/permissions/{permission_id}", Handler: handlers.NewDeletePermissionHandler(usecases.permissions).Handler()}, + + // Role permissions {Method: http.MethodPost, Path: "/access-control/roles/{role_id}/permissions", Handler: handlers.NewAddRolePermissionHandler(usecases.rolePermissions).Handler()}, {Method: http.MethodGet, Path: "/access-control/roles/{role_id}/permissions", Handler: handlers.NewGetRolePermissionsHandler(usecases.rolePermissions).Handler()}, {Method: http.MethodPut, Path: "/access-control/roles/{role_id}/permissions", Handler: handlers.NewReplaceRolePermissionsHandler(usecases.rolePermissions).Handler()}, {Method: http.MethodDelete, Path: "/access-control/roles/{role_id}/permissions/{permission_id}", Handler: handlers.NewRemoveRolePermissionHandler(usecases.rolePermissions).Handler()}, - // User roles and permissions + // User roles {Method: http.MethodGet, Path: "/access-control/users/{user_id}/roles", Handler: handlers.NewGetUserRolesHandler(usecases.userRoles).Handler()}, - {Method: http.MethodGet, Path: "/access-control/users/{user_id}/authorization-profile", Handler: handlers.NewGetUserAuthorizationProfileHandler(usecases.userAccess).Handler()}, - {Method: http.MethodPost, Path: "/access-control/users/{user_id}/roles", Handler: handlers.NewAssignUserRoleHandler(usecases.userRoles).Handler()}, {Method: http.MethodPut, Path: "/access-control/users/{user_id}/roles", Handler: handlers.NewReplaceUserRolesHandler(usecases.userRoles).Handler()}, + {Method: http.MethodPost, Path: "/access-control/users/{user_id}/roles", Handler: handlers.NewAssignUserRoleHandler(usecases.userRoles).Handler()}, {Method: http.MethodDelete, Path: "/access-control/users/{user_id}/roles/{role_id}", Handler: handlers.NewRemoveUserRoleHandler(usecases.userRoles).Handler()}, - {Method: http.MethodGet, Path: "/access-control/users/{user_id}/permissions", Handler: handlers.NewGetUserEffectivePermissionsHandler(usecases.userAccess).Handler()}, + + // User permissions + {Method: http.MethodGet, Path: "/access-control/users/{user_id}/permissions", Handler: handlers.NewGetUserPermissionsHandler(usecases.userPermissions).Handler()}, + {Method: http.MethodPost, Path: "/access-control/users/{user_id}/permissions/check", Handler: handlers.NewCheckUserPermissionsHandler(usecases.userPermissions).Handler()}, } } diff --git a/plugins/access-control/services/permissions_service.go b/plugins/access-control/services/permissions_service.go index 5c08f4f4..9d910ada 100644 --- a/plugins/access-control/services/permissions_service.go +++ b/plugins/access-control/services/permissions_service.go @@ -10,12 +10,12 @@ import ( ) type PermissionsService struct { - permissionsRepo repositories.PermissionsRepository - userAccessRepo repositories.UserAccessRepository + permissionsRepo repositories.PermissionsRepository + rolePermissionsRepo repositories.RolePermissionsRepository } -func NewPermissionsService(permissionsRepo repositories.PermissionsRepository, userAccessRepo repositories.UserAccessRepository) *PermissionsService { - return &PermissionsService{permissionsRepo: permissionsRepo, userAccessRepo: userAccessRepo} +func NewPermissionsService(permissionsRepo repositories.PermissionsRepository, rolePermissionsRepo repositories.RolePermissionsRepository) *PermissionsService { + return &PermissionsService{permissionsRepo: permissionsRepo, rolePermissionsRepo: rolePermissionsRepo} } func (s *PermissionsService) CreatePermission(ctx context.Context, req types.CreatePermissionRequest) (*types.Permission, error) { @@ -62,6 +62,22 @@ func (s *PermissionsService) GetPermissionByID(ctx context.Context, permissionID return permission, nil } +func (s *PermissionsService) GetPermissionByKey(ctx context.Context, permissionKey string) (*types.Permission, error) { + if permissionKey == "" { + return nil, constants.ErrBadRequest + } + + permission, err := s.permissionsRepo.GetPermissionByKey(ctx, permissionKey) + if err != nil { + return nil, err + } + if permission == nil { + return nil, constants.ErrNotFound + } + + return permission, nil +} + func (s *PermissionsService) UpdatePermission(ctx context.Context, permissionID string, req types.UpdatePermissionRequest) (*types.Permission, error) { if permissionID == "" { return nil, constants.ErrUnprocessableEntity @@ -121,11 +137,11 @@ func (s *PermissionsService) DeletePermission(ctx context.Context, permissionID return constants.ErrBadRequest } - assignmentsCount, err := s.userAccessRepo.CountRoleAssignmentsByPermissionID(ctx, permissionID) + totalCountOfRolesByPermission, err := s.rolePermissionsRepo.CountRolesByPermission(ctx, permissionID) if err != nil { return err } - if assignmentsCount > 0 { + if totalCountOfRolesByPermission > 0 { return constants.ErrConflict } diff --git a/plugins/access-control/services/permissions_service_test.go b/plugins/access-control/services/permissions_service_test.go index 971672cb..c78921e1 100644 --- a/plugins/access-control/services/permissions_service_test.go +++ b/plugins/access-control/services/permissions_service_test.go @@ -74,12 +74,13 @@ func TestPermissionsServiceCreatePermission(t *testing.T) { t.Parallel() permissionsRepo := &accesscontroltests.MockPermissionsRepository{} - userAccessRepo := &accesscontroltests.MockUserAccessRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} + _ = rolePermissionsRepo if tc.setup != nil { tc.setup(permissionsRepo) } - service := NewPermissionsService(permissionsRepo, userAccessRepo) + service := NewPermissionsService(permissionsRepo, rolePermissionsRepo) permission, err := service.CreatePermission(context.Background(), tc.req) if tc.wantErr == nil { if err != nil { @@ -94,7 +95,6 @@ func TestPermissionsServiceCreatePermission(t *testing.T) { } permissionsRepo.AssertExpectations(t) - userAccessRepo.AssertExpectations(t) }) } } @@ -129,12 +129,13 @@ func TestPermissionsServiceGetAllPermissions(t *testing.T) { t.Parallel() permissionsRepo := &accesscontroltests.MockPermissionsRepository{} - userAccessRepo := &accesscontroltests.MockUserAccessRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} + _ = rolePermissionsRepo if tc.setup != nil { tc.setup(permissionsRepo) } - service := NewPermissionsService(permissionsRepo, userAccessRepo) + service := NewPermissionsService(permissionsRepo, rolePermissionsRepo) permissions, err := service.GetAllPermissions(context.Background()) if tc.wantErr == nil { if err != nil { @@ -148,7 +149,6 @@ func TestPermissionsServiceGetAllPermissions(t *testing.T) { } permissionsRepo.AssertExpectations(t) - userAccessRepo.AssertExpectations(t) }) } } @@ -191,12 +191,13 @@ func TestPermissionsServiceGetPermissionByID(t *testing.T) { t.Parallel() permissionsRepo := &accesscontroltests.MockPermissionsRepository{} - userAccessRepo := &accesscontroltests.MockUserAccessRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} + _ = rolePermissionsRepo if tc.setup != nil { tc.setup(permissionsRepo) } - service := NewPermissionsService(permissionsRepo, userAccessRepo) + service := NewPermissionsService(permissionsRepo, rolePermissionsRepo) permission, err := service.GetPermissionByID(context.Background(), tc.id) if tc.wantErr == nil { if err != nil { @@ -210,7 +211,6 @@ func TestPermissionsServiceGetPermissionByID(t *testing.T) { } permissionsRepo.AssertExpectations(t) - userAccessRepo.AssertExpectations(t) }) } } @@ -292,12 +292,12 @@ func TestPermissionsServiceUpdatePermission(t *testing.T) { t.Parallel() permissionsRepo := &accesscontroltests.MockPermissionsRepository{} - userAccessRepo := &accesscontroltests.MockUserAccessRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} if tc.setup != nil { tc.setup(permissionsRepo) } - service := NewPermissionsService(permissionsRepo, userAccessRepo) + service := NewPermissionsService(permissionsRepo, rolePermissionsRepo) permission, err := service.UpdatePermission(context.Background(), tc.id, tc.req) if tc.wantErr == nil { if err != nil { @@ -311,7 +311,6 @@ func TestPermissionsServiceUpdatePermission(t *testing.T) { } permissionsRepo.AssertExpectations(t) - userAccessRepo.AssertExpectations(t) }) } } @@ -322,7 +321,7 @@ func TestPermissionsServiceDeletePermission(t *testing.T) { tests := []struct { name string id string - setup func(*accesscontroltests.MockPermissionsRepository, *accesscontroltests.MockUserAccessRepository) + setup func(*accesscontroltests.MockPermissionsRepository, *accesscontroltests.MockRolePermissionsRepository) wantErr error }{ { @@ -333,7 +332,7 @@ func TestPermissionsServiceDeletePermission(t *testing.T) { { name: "not found", id: "perm-1", - setup: func(permissionsRepo *accesscontroltests.MockPermissionsRepository, userAccessRepo *accesscontroltests.MockUserAccessRepository) { + setup: func(permissionsRepo *accesscontroltests.MockPermissionsRepository, rolePermissionsRepo *accesscontroltests.MockRolePermissionsRepository) { permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return((*types.Permission)(nil), nil).Once() }, wantErr: accesscontrolconstants.ErrNotFound, @@ -341,7 +340,7 @@ func TestPermissionsServiceDeletePermission(t *testing.T) { { name: "system permission", id: "perm-1", - setup: func(permissionsRepo *accesscontroltests.MockPermissionsRepository, userAccessRepo *accesscontroltests.MockUserAccessRepository) { + setup: func(permissionsRepo *accesscontroltests.MockPermissionsRepository, rolePermissionsRepo *accesscontroltests.MockRolePermissionsRepository) { permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read", IsSystem: true}, nil).Once() }, wantErr: accesscontrolconstants.ErrBadRequest, @@ -349,18 +348,18 @@ func TestPermissionsServiceDeletePermission(t *testing.T) { { name: "permission in use", id: "perm-1", - setup: func(permissionsRepo *accesscontroltests.MockPermissionsRepository, userAccessRepo *accesscontroltests.MockUserAccessRepository) { + setup: func(permissionsRepo *accesscontroltests.MockPermissionsRepository, rolePermissionsRepo *accesscontroltests.MockRolePermissionsRepository) { permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read"}, nil).Once() - userAccessRepo.On("CountRoleAssignmentsByPermissionID", mock.Anything, "perm-1").Return(2, nil).Once() + rolePermissionsRepo.On("CountRolesByPermission", mock.Anything, "perm-1").Return(2, nil).Once() }, wantErr: accesscontrolconstants.ErrConflict, }, { name: "delete returns false", id: "perm-1", - setup: func(permissionsRepo *accesscontroltests.MockPermissionsRepository, userAccessRepo *accesscontroltests.MockUserAccessRepository) { + setup: func(permissionsRepo *accesscontroltests.MockPermissionsRepository, rolePermissionsRepo *accesscontroltests.MockRolePermissionsRepository) { permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read"}, nil).Once() - userAccessRepo.On("CountRoleAssignmentsByPermissionID", mock.Anything, "perm-1").Return(0, nil).Once() + rolePermissionsRepo.On("CountRolesByPermission", mock.Anything, "perm-1").Return(0, nil).Once() permissionsRepo.On("DeletePermission", mock.Anything, "perm-1").Return(false, nil).Once() }, wantErr: accesscontrolconstants.ErrNotFound, @@ -368,9 +367,9 @@ func TestPermissionsServiceDeletePermission(t *testing.T) { { name: "success", id: "perm-1", - setup: func(permissionsRepo *accesscontroltests.MockPermissionsRepository, userAccessRepo *accesscontroltests.MockUserAccessRepository) { + setup: func(permissionsRepo *accesscontroltests.MockPermissionsRepository, rolePermissionsRepo *accesscontroltests.MockRolePermissionsRepository) { permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read"}, nil).Once() - userAccessRepo.On("CountRoleAssignmentsByPermissionID", mock.Anything, "perm-1").Return(0, nil).Once() + rolePermissionsRepo.On("CountRolesByPermission", mock.Anything, "perm-1").Return(0, nil).Once() permissionsRepo.On("DeletePermission", mock.Anything, "perm-1").Return(true, nil).Once() }, }, @@ -381,12 +380,12 @@ func TestPermissionsServiceDeletePermission(t *testing.T) { t.Parallel() permissionsRepo := &accesscontroltests.MockPermissionsRepository{} - userAccessRepo := &accesscontroltests.MockUserAccessRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} if tc.setup != nil { - tc.setup(permissionsRepo, userAccessRepo) + tc.setup(permissionsRepo, rolePermissionsRepo) } - service := NewPermissionsService(permissionsRepo, userAccessRepo) + service := NewPermissionsService(permissionsRepo, rolePermissionsRepo) err := service.DeletePermission(context.Background(), tc.id) if tc.wantErr == nil { if err != nil { @@ -397,7 +396,7 @@ func TestPermissionsServiceDeletePermission(t *testing.T) { } permissionsRepo.AssertExpectations(t) - userAccessRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) }) } } diff --git a/plugins/access-control/services/roles_service.go b/plugins/access-control/services/roles_service.go index ca7cd308..35ac7e25 100644 --- a/plugins/access-control/services/roles_service.go +++ b/plugins/access-control/services/roles_service.go @@ -12,11 +12,11 @@ import ( type RolesService struct { rolesRepo repositories.RolesRepository rolePermissionsRepo repositories.RolePermissionsRepository - userAccessRepo repositories.UserAccessRepository + userRolesRepo repositories.UserRolesRepository } -func NewRolesService(rolesRepo repositories.RolesRepository, rolePermissionsRepo repositories.RolePermissionsRepository, userAccessRepo repositories.UserAccessRepository) *RolesService { - return &RolesService{rolesRepo: rolesRepo, rolePermissionsRepo: rolePermissionsRepo, userAccessRepo: userAccessRepo} +func NewRolesService(rolesRepo repositories.RolesRepository, rolePermissionsRepo repositories.RolePermissionsRepository, userRolesRepo repositories.UserRolesRepository) *RolesService { + return &RolesService{rolesRepo: rolesRepo, rolePermissionsRepo: rolePermissionsRepo, userRolesRepo: userRolesRepo} } func (s *RolesService) CreateRole(ctx context.Context, req types.CreateRoleRequest) (*types.Role, error) { @@ -47,6 +47,22 @@ func (s *RolesService) GetAllRoles(ctx context.Context) ([]types.Role, error) { return s.rolesRepo.GetAllRoles(ctx) } +func (s *RolesService) GetRoleByName(ctx context.Context, roleName string) (*types.Role, error) { + if roleName == "" { + return nil, constants.ErrBadRequest + } + + role, err := s.rolesRepo.GetRoleByName(ctx, roleName) + if err != nil { + return nil, err + } + if role == nil { + return nil, constants.ErrNotFound + } + + return role, nil +} + func (s *RolesService) GetRoleByID(ctx context.Context, roleID string) (*types.RoleDetails, error) { if roleID == "" { return nil, constants.ErrBadRequest @@ -136,11 +152,11 @@ func (s *RolesService) DeleteRole(ctx context.Context, roleID string) error { return constants.ErrCannotUpdateSystemRole } - assignmentsCount, err := s.userAccessRepo.CountUserAssignmentsByRoleID(ctx, roleID) + totalUsersByRole, err := s.userRolesRepo.CountUsersByRole(ctx, roleID) if err != nil { return err } - if assignmentsCount > 0 { + if totalUsersByRole > 0 { return constants.ErrConflict } diff --git a/plugins/access-control/services/roles_service_test.go b/plugins/access-control/services/roles_service_test.go index 99cb2263..6f588e55 100644 --- a/plugins/access-control/services/roles_service_test.go +++ b/plugins/access-control/services/roles_service_test.go @@ -16,12 +16,12 @@ func TestRolesServiceGetRoleByID(t *testing.T) { rolesRepo := &accesscontroltests.MockRolesRepository{} rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} - userAccessRepo := &accesscontroltests.MockUserAccessRepository{} + userRolesRepo := &accesscontroltests.MockUserRolesRepository{} rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "admin"}, nil).Once() rolePermissionsRepo.On("GetRolePermissions", mock.Anything, "role-1").Return([]types.UserPermissionInfo{{PermissionID: "perm-1", PermissionKey: "users.read"}}, nil).Once() - service := NewRolesService(rolesRepo, rolePermissionsRepo, userAccessRepo) + service := NewRolesService(rolesRepo, rolePermissionsRepo, userRolesRepo) details, err := service.GetRoleByID(context.Background(), "role-1") if err != nil { t.Fatalf("expected nil err, got %v", err) @@ -32,6 +32,70 @@ func TestRolesServiceGetRoleByID(t *testing.T) { rolesRepo.AssertExpectations(t) rolePermissionsRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) +} + +func TestRolesServiceGetRoleByName(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + roleName string + setup func(*accesscontroltests.MockRolesRepository, *accesscontroltests.MockUserRolesRepository) + wantErr error + }{ + { + name: "bad request", + roleName: "", + wantErr: accesscontrolconstants.ErrBadRequest, + }, + { + name: "not found", + roleName: "missing", + setup: func(rolesRepo *accesscontroltests.MockRolesRepository, userRolesRepo *accesscontroltests.MockUserRolesRepository) { + rolesRepo.On("GetRoleByName", mock.Anything, "missing").Return((*types.Role)(nil), nil).Once() + }, + wantErr: accesscontrolconstants.ErrNotFound, + }, + { + name: "success", + roleName: "admin", + setup: func(rolesRepo *accesscontroltests.MockRolesRepository, userRolesRepo *accesscontroltests.MockUserRolesRepository) { + rolesRepo.On("GetRoleByName", mock.Anything, "admin").Return(&types.Role{ID: "role-1", Name: "admin"}, nil).Once() + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + rolesRepo := &accesscontroltests.MockRolesRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} + userRolesRepo := &accesscontroltests.MockUserRolesRepository{} + if tc.setup != nil { + tc.setup(rolesRepo, userRolesRepo) + } + + service := NewRolesService(rolesRepo, rolePermissionsRepo, userRolesRepo) + role, err := service.GetRoleByName(context.Background(), tc.roleName) + if err != tc.wantErr { + t.Fatalf("expected err %v, got %v", tc.wantErr, err) + } + if tc.wantErr != nil { + if role != nil { + t.Fatalf("expected nil role, got %#v", role) + } + } else { + if role == nil || role.ID != "role-1" || role.Name != "admin" { + t.Fatalf("unexpected role %#v", role) + } + } + + rolesRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + }) + } } func TestRolesServiceDeleteRole(t *testing.T) { @@ -39,22 +103,22 @@ func TestRolesServiceDeleteRole(t *testing.T) { tests := []struct { name string - setup func(*accesscontroltests.MockRolesRepository, *accesscontroltests.MockUserAccessRepository) + setup func(*accesscontroltests.MockRolesRepository, *accesscontroltests.MockUserRolesRepository) wantErr error }{ { name: "role in use", - setup: func(rolesRepo *accesscontroltests.MockRolesRepository, userAccessRepo *accesscontroltests.MockUserAccessRepository) { + setup: func(rolesRepo *accesscontroltests.MockRolesRepository, userRolesRepo *accesscontroltests.MockUserRolesRepository) { rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "admin"}, nil).Once() - userAccessRepo.On("CountUserAssignmentsByRoleID", mock.Anything, "role-1").Return(1, nil).Once() + userRolesRepo.On("CountUsersByRole", mock.Anything, "role-1").Return(1, nil).Once() }, wantErr: accesscontrolconstants.ErrConflict, }, { name: "success", - setup: func(rolesRepo *accesscontroltests.MockRolesRepository, userAccessRepo *accesscontroltests.MockUserAccessRepository) { + setup: func(rolesRepo *accesscontroltests.MockRolesRepository, userRolesRepo *accesscontroltests.MockUserRolesRepository) { rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "admin"}, nil).Once() - userAccessRepo.On("CountUserAssignmentsByRoleID", mock.Anything, "role-1").Return(0, nil).Once() + userRolesRepo.On("CountUsersByRole", mock.Anything, "role-1").Return(0, nil).Once() rolesRepo.On("DeleteRole", mock.Anything, "role-1").Return(true, nil).Once() }, }, @@ -66,19 +130,20 @@ func TestRolesServiceDeleteRole(t *testing.T) { rolesRepo := &accesscontroltests.MockRolesRepository{} rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} - userAccessRepo := &accesscontroltests.MockUserAccessRepository{} + userRolesRepo := &accesscontroltests.MockUserRolesRepository{} if tc.setup != nil { - tc.setup(rolesRepo, userAccessRepo) + tc.setup(rolesRepo, userRolesRepo) } - service := NewRolesService(rolesRepo, rolePermissionsRepo, userAccessRepo) + service := NewRolesService(rolesRepo, rolePermissionsRepo, userRolesRepo) err := service.DeleteRole(context.Background(), "role-1") if err != tc.wantErr { t.Fatalf("expected err %v, got %v", tc.wantErr, err) } rolesRepo.AssertExpectations(t) - userAccessRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) }) } } diff --git a/plugins/access-control/services/user_access_service.go b/plugins/access-control/services/user_access_service.go deleted file mode 100644 index a1a5ee17..00000000 --- a/plugins/access-control/services/user_access_service.go +++ /dev/null @@ -1,90 +0,0 @@ -package services - -import ( - "context" - - "github.com/Authula/authula/plugins/access-control/constants" - "github.com/Authula/authula/plugins/access-control/repositories" - "github.com/Authula/authula/plugins/access-control/types" -) - -type UserAccessService struct { - userRolesRepo repositories.UserRolesRepository - userAccessRepo repositories.UserAccessRepository -} - -func NewUserAccessService(userRolesRepo repositories.UserRolesRepository, userAccessRepo repositories.UserAccessRepository) *UserAccessService { - return &UserAccessService{userRolesRepo: userRolesRepo, userAccessRepo: userAccessRepo} -} - -func (s *UserAccessService) GetUserWithPermissionsByID(ctx context.Context, userID string) (*types.UserWithPermissions, error) { - if userID == "" { - return nil, constants.ErrUnprocessableEntity - } - - return s.userAccessRepo.GetUserWithPermissionsByID(ctx, userID) -} - -func (s *UserAccessService) GetUserAuthorizationProfile(ctx context.Context, userID string) (*types.UserAuthorizationProfile, error) { - if userID == "" { - return nil, constants.ErrUnprocessableEntity - } - - withRoles, err := s.userRolesRepo.GetUserWithRolesByID(ctx, userID) - if err != nil { - return nil, err - } - if withRoles == nil { - return nil, nil - } - - withPermissions, err := s.GetUserWithPermissionsByID(ctx, userID) - if err != nil { - return nil, err - } - - profile := &types.UserAuthorizationProfile{ - User: withRoles.User, - Roles: withRoles.Roles, - } - if withPermissions != nil { - profile.Permissions = withPermissions.Permissions - } - - return profile, nil -} - -func (s *UserAccessService) GetUserEffectivePermissions(ctx context.Context, userID string) ([]types.UserPermissionInfo, error) { - if userID == "" { - return nil, constants.ErrUnprocessableEntity - } - - return s.userAccessRepo.GetUserEffectivePermissions(ctx, userID) -} - -func (s *UserAccessService) HasPermissions(ctx context.Context, userID string, requiredPermissions []string) (bool, error) { - if userID == "" { - return false, constants.ErrUnprocessableEntity - } - - permissions, err := s.GetUserEffectivePermissions(ctx, userID) - if err != nil { - return false, err - } - - granted := make(map[string]struct{}, len(permissions)) - for _, permission := range permissions { - granted[permission.PermissionKey] = struct{}{} - } - - for _, required := range requiredPermissions { - if required == "" { - continue - } - if _, ok := granted[required]; ok { - return true, nil - } - } - - return false, nil -} diff --git a/plugins/access-control/services/user_access_service_test.go b/plugins/access-control/services/user_access_service_test.go deleted file mode 100644 index a0fd6301..00000000 --- a/plugins/access-control/services/user_access_service_test.go +++ /dev/null @@ -1,81 +0,0 @@ -package services - -import ( - "context" - "testing" - - "github.com/stretchr/testify/mock" - - accesscontrolconstants "github.com/Authula/authula/plugins/access-control/constants" - accesscontroltests "github.com/Authula/authula/plugins/access-control/tests" - "github.com/Authula/authula/plugins/access-control/types" -) - -func TestUserAccessServiceHasPermissions(t *testing.T) { - t.Parallel() - - userRolesRepo := &accesscontroltests.MockUserRolesRepository{} - userAccessRepo := &accesscontroltests.MockUserAccessRepository{} - userAccessRepo.On("GetUserEffectivePermissions", mock.Anything, "user-1").Return([]types.UserPermissionInfo{{PermissionID: "perm-1", PermissionKey: "users.read"}}, nil).Once() - - service := NewUserAccessService(userRolesRepo, userAccessRepo) - ok, err := service.HasPermissions(context.Background(), "user-1", []string{"users.write", "users.read"}) - if err != nil { - t.Fatalf("expected nil err, got %v", err) - } - if !ok { - t.Fatal("expected permission check to pass") - } - - userAccessRepo.AssertExpectations(t) -} - -func TestUserAccessServiceGetUserAuthorizationProfile(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - userID string - setup func(*accesscontroltests.MockUserRolesRepository, *accesscontroltests.MockUserAccessRepository) - wantErr error - wantNil bool - }{ - { - name: "blank user id", - userID: "", - wantErr: accesscontrolconstants.ErrUnprocessableEntity, - }, - { - name: "composes profile", - userID: "user-1", - setup: func(userRolesRepo *accesscontroltests.MockUserRolesRepository, userAccessRepo *accesscontroltests.MockUserAccessRepository) { - userRolesRepo.On("GetUserWithRolesByID", mock.Anything, "user-1").Return(&types.UserWithRoles{Roles: []types.UserRoleInfo{{RoleID: "role-1", RoleName: "admin"}}}, nil).Once() - userAccessRepo.On("GetUserWithPermissionsByID", mock.Anything, "user-1").Return(&types.UserWithPermissions{Permissions: []types.UserPermissionInfo{{PermissionID: "perm-1", PermissionKey: "users.read"}}}, nil).Once() - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - userRolesRepo := &accesscontroltests.MockUserRolesRepository{} - userAccessRepo := &accesscontroltests.MockUserAccessRepository{} - if tc.setup != nil { - tc.setup(userRolesRepo, userAccessRepo) - } - - service := NewUserAccessService(userRolesRepo, userAccessRepo) - profile, err := service.GetUserAuthorizationProfile(context.Background(), tc.userID) - if err != tc.wantErr { - t.Fatalf("expected err %v, got %v", tc.wantErr, err) - } - if tc.wantErr == nil && profile == nil { - t.Fatal("expected profile, got nil") - } - - userRolesRepo.AssertExpectations(t) - userAccessRepo.AssertExpectations(t) - }) - } -} diff --git a/plugins/access-control/services/user_permissions_service.go b/plugins/access-control/services/user_permissions_service.go new file mode 100644 index 00000000..5a92cf15 --- /dev/null +++ b/plugins/access-control/services/user_permissions_service.go @@ -0,0 +1,33 @@ +package services + +import ( + "context" + + "github.com/Authula/authula/plugins/access-control/constants" + "github.com/Authula/authula/plugins/access-control/repositories" + "github.com/Authula/authula/plugins/access-control/types" +) + +type UserPermissionsService struct { + repo repositories.UserPermissionsRepository +} + +func NewUserPermissionsService(repo repositories.UserPermissionsRepository) *UserPermissionsService { + return &UserPermissionsService{repo: repo} +} + +func (s *UserPermissionsService) GetUserPermissions(ctx context.Context, userID string) ([]types.UserPermissionInfo, error) { + if userID == "" { + return nil, constants.ErrUnprocessableEntity + } + + return s.repo.GetUserPermissions(ctx, userID) +} + +func (s *UserPermissionsService) HasPermissions(ctx context.Context, userID string, permissionKeys []string) (bool, error) { + if userID == "" { + return false, constants.ErrUnprocessableEntity + } + + return s.repo.HasPermissions(ctx, userID, permissionKeys) +} diff --git a/plugins/access-control/services/user_permissions_service_test.go b/plugins/access-control/services/user_permissions_service_test.go new file mode 100644 index 00000000..90a6f83d --- /dev/null +++ b/plugins/access-control/services/user_permissions_service_test.go @@ -0,0 +1,139 @@ +package services + +import ( + "context" + "testing" + + "github.com/stretchr/testify/mock" + + accesscontrolconstants "github.com/Authula/authula/plugins/access-control/constants" + accesscontroltests "github.com/Authula/authula/plugins/access-control/tests" + "github.com/Authula/authula/plugins/access-control/types" +) + +func TestUserPermissionsServiceGetUserPermissions(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + userID string + setupMock func(*accesscontroltests.MockUserPermissionsRepository) + expectedCount int + expectedErr error + }{ + { + name: "blank user id", + userID: "", + expectedErr: accesscontrolconstants.ErrUnprocessableEntity, + }, + { + name: "success", + userID: "u1", + setupMock: func(m *accesscontroltests.MockUserPermissionsRepository) { + m.On("GetUserPermissions", mock.Anything, "u1").Return([]types.UserPermissionInfo{{PermissionID: "perm-1", PermissionKey: "users.read"}}, nil).Once() + }, + expectedCount: 1, + }, + { + name: "repo error", + userID: "u1", + setupMock: func(m *accesscontroltests.MockUserPermissionsRepository) { + m.On("GetUserPermissions", mock.Anything, "u1").Return(([]types.UserPermissionInfo)(nil), accesscontrolconstants.ErrNotFound).Once() + }, + expectedErr: accesscontrolconstants.ErrNotFound, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + repo := &accesscontroltests.MockUserPermissionsRepository{} + if tc.setupMock != nil { + tc.setupMock(repo) + } + + service := NewUserPermissionsService(repo) + permissions, err := service.GetUserPermissions(context.Background(), tc.userID) + if tc.expectedErr != nil { + if err != tc.expectedErr { + t.Fatalf("expected err %v, got %v", tc.expectedErr, err) + } + repo.AssertExpectations(t) + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(permissions) != tc.expectedCount { + t.Fatalf("expected %d permissions, got %d", tc.expectedCount, len(permissions)) + } + repo.AssertExpectations(t) + }) + } +} + +func TestUserPermissionsServiceHasPermissions(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + userID string + permissionKeys []string + setupMock func(*accesscontroltests.MockUserPermissionsRepository) + expectedHas bool + expectedErr error + }{ + { + name: "blank user id", + userID: "", + expectedErr: accesscontrolconstants.ErrUnprocessableEntity, + }, + { + name: "success", + userID: "u1", + permissionKeys: []string{"users.read"}, + setupMock: func(m *accesscontroltests.MockUserPermissionsRepository) { + m.On("HasPermissions", mock.Anything, "u1", []string{"users.read"}).Return(true, nil).Once() + }, + expectedHas: true, + }, + { + name: "repo error", + userID: "u1", + permissionKeys: []string{"users.read"}, + setupMock: func(m *accesscontroltests.MockUserPermissionsRepository) { + m.On("HasPermissions", mock.Anything, "u1", []string{"users.read"}).Return(false, accesscontrolconstants.ErrForbidden).Once() + }, + expectedErr: accesscontrolconstants.ErrForbidden, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + repo := &accesscontroltests.MockUserPermissionsRepository{} + if tc.setupMock != nil { + tc.setupMock(repo) + } + + service := NewUserPermissionsService(repo) + hasPermissions, err := service.HasPermissions(context.Background(), tc.userID, tc.permissionKeys) + if tc.expectedErr != nil { + if err != tc.expectedErr { + t.Fatalf("expected err %v, got %v", tc.expectedErr, err) + } + repo.AssertExpectations(t) + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if hasPermissions != tc.expectedHas { + t.Fatalf("expected hasPermissions=%v, got %v", tc.expectedHas, hasPermissions) + } + repo.AssertExpectations(t) + }) + } +} diff --git a/plugins/access-control/services/user_roles_service.go b/plugins/access-control/services/user_roles_service.go index 4efbf7d7..05d8aa1d 100644 --- a/plugins/access-control/services/user_roles_service.go +++ b/plugins/access-control/services/user_roles_service.go @@ -26,14 +26,6 @@ func (s *UserRolesService) GetUserRoles(ctx context.Context, userID string) ([]t return s.userRolesRepo.GetUserRoles(ctx, userID) } -func (s *UserRolesService) GetUserWithRolesByID(ctx context.Context, userID string) (*types.UserWithRoles, error) { - if userID == "" { - return nil, constants.ErrUnprocessableEntity - } - - return s.userRolesRepo.GetUserWithRolesByID(ctx, userID) -} - func (s *UserRolesService) ReplaceUserRoles(ctx context.Context, userID string, roleIDs []string, assignedByUserID *string) error { if userID == "" { return constants.ErrBadRequest diff --git a/plugins/access-control/tests/test_helpers.go b/plugins/access-control/tests/test_helpers.go index 5b7ad155..ddfbcc43 100644 --- a/plugins/access-control/tests/test_helpers.go +++ b/plugins/access-control/tests/test_helpers.go @@ -43,6 +43,14 @@ func (m *MockRolesRepository) GetRoleByID(ctx context.Context, roleID string) (* return args.Get(0).(*types.Role), args.Error(1) } +func (m *MockRolesRepository) GetRoleByName(ctx context.Context, roleName string) (*types.Role, error) { + args := m.Called(ctx, roleName) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*types.Role), args.Error(1) +} + func (m *MockRolesRepository) UpdateRole(ctx context.Context, roleID string, name *string, description *string) (bool, error) { args := m.Called(ctx, roleID, name, description) return args.Bool(0), args.Error(1) @@ -73,6 +81,14 @@ func (m *MockPermissionsRepository) GetPermissionByID(ctx context.Context, permi return args.Get(0).(*types.Permission), args.Error(1) } +func (m *MockPermissionsRepository) GetPermissionByKey(ctx context.Context, permissionKey string) (*types.Permission, error) { + args := m.Called(ctx, permissionKey) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*types.Permission), args.Error(1) +} + func (m *MockPermissionsRepository) CreatePermission(ctx context.Context, permission *types.Permission) error { args := m.Called(ctx, permission) return args.Error(0) @@ -115,6 +131,11 @@ func (m *MockRolePermissionsRepository) RemoveRolePermission(ctx context.Context return args.Error(0) } +func (m *MockRolePermissionsRepository) CountRolesByPermission(ctx context.Context, permissionID string) (int, error) { + args := m.Called(ctx, permissionID) + return args.Int(0), args.Error(1) +} + type MockUserRolesRepository struct { mock.Mock } @@ -127,14 +148,6 @@ func (m *MockUserRolesRepository) GetUserRoles(ctx context.Context, userID strin return args.Get(0).([]types.UserRoleInfo), args.Error(1) } -func (m *MockUserRolesRepository) GetUserWithRolesByID(ctx context.Context, userID string) (*types.UserWithRoles, error) { - args := m.Called(ctx, userID) - if args.Get(0) == nil { - return nil, args.Error(1) - } - return args.Get(0).(*types.UserWithRoles), args.Error(1) -} - func (m *MockUserRolesRepository) ReplaceUserRoles(ctx context.Context, userID string, roleIDs []string, assignedByUserID *string) error { args := m.Called(ctx, userID, roleIDs, assignedByUserID) return args.Error(0) @@ -150,21 +163,16 @@ func (m *MockUserRolesRepository) RemoveUserRole(ctx context.Context, userID str return args.Error(0) } -type MockUserAccessRepository struct { - mock.Mock -} - -func (m *MockUserAccessRepository) CountUserAssignmentsByRoleID(ctx context.Context, roleID string) (int, error) { +func (m *MockUserRolesRepository) CountUsersByRole(ctx context.Context, roleID string) (int, error) { args := m.Called(ctx, roleID) return args.Int(0), args.Error(1) } -func (m *MockUserAccessRepository) CountRoleAssignmentsByPermissionID(ctx context.Context, permissionID string) (int, error) { - args := m.Called(ctx, permissionID) - return args.Int(0), args.Error(1) +type MockUserPermissionsRepository struct { + mock.Mock } -func (m *MockUserAccessRepository) GetUserEffectivePermissions(ctx context.Context, userID string) ([]types.UserPermissionInfo, error) { +func (m *MockUserPermissionsRepository) GetUserPermissions(ctx context.Context, userID string) ([]types.UserPermissionInfo, error) { args := m.Called(ctx, userID) if args.Get(0) == nil { return nil, args.Error(1) @@ -172,12 +180,9 @@ func (m *MockUserAccessRepository) GetUserEffectivePermissions(ctx context.Conte return args.Get(0).([]types.UserPermissionInfo), args.Error(1) } -func (m *MockUserAccessRepository) GetUserWithPermissionsByID(ctx context.Context, userID string) (*types.UserWithPermissions, error) { - args := m.Called(ctx, userID) - if args.Get(0) == nil { - return nil, args.Error(1) - } - return args.Get(0).(*types.UserWithPermissions), args.Error(1) +func (m *MockUserPermissionsRepository) HasPermissions(ctx context.Context, userID string, permissionKeys []string) (bool, error) { + args := m.Called(ctx, userID, permissionKeys) + return args.Bool(0), args.Error(1) } func SetupRepoDB(t *testing.T) *bun.DB { diff --git a/plugins/access-control/types/config.go b/plugins/access-control/types/config.go new file mode 100644 index 00000000..c086c87e --- /dev/null +++ b/plugins/access-control/types/config.go @@ -0,0 +1,8 @@ +package types + +type AccessControlPluginConfig struct { + Enabled bool `json:"enabled" toml:"enabled"` + // TODO: Add a field to enable auto-cleanup of expired user roles and permissions +} + +func (config *AccessControlPluginConfig) ApplyDefaults() {} diff --git a/plugins/access-control/types/models.go b/plugins/access-control/types/models.go index 25a61120..4359497f 100644 --- a/plugins/access-control/types/models.go +++ b/plugins/access-control/types/models.go @@ -4,12 +4,8 @@ import ( "time" "github.com/uptrace/bun" - - "github.com/Authula/authula/models" ) -// Models - type Role struct { bun.BaseModel `bun:"table:access_control_roles"` @@ -50,141 +46,3 @@ type UserRole struct { AssignedAt time.Time `json:"assigned_at" bun:"column:assigned_at"` ExpiresAt *time.Time `json:"expires_at" bun:"column:expires_at"` } - -// Types - -type CreateRoleRequest struct { - Name string `json:"name"` - Description *string `json:"description,omitempty"` - IsSystem bool `json:"is_system"` -} - -type CreateRoleResponse struct { - Role *Role `json:"role"` -} - -type UpdateRoleRequest struct { - Name *string `json:"name,omitempty"` - Description *string `json:"description,omitempty"` -} - -type UpdateRoleResponse struct { - Role *Role `json:"role"` -} - -type DeleteRoleResponse struct { - Message string `json:"message"` -} - -type CreatePermissionRequest struct { - Key string `json:"key"` - Description *string `json:"description,omitempty"` - IsSystem bool `json:"is_system"` -} - -type CreatePermissionResponse struct { - Permission *Permission `json:"permission"` -} - -type UpdatePermissionRequest struct { - Description *string `json:"description,omitempty"` -} - -type UpdatePermissionResponse struct { - Permission *Permission `json:"permission"` -} - -type DeletePermissionResponse struct { - Message string `json:"message"` -} - -type AddRolePermissionRequest struct { - PermissionID string `json:"permission_id"` -} - -type AddRolePermissionResponse struct { - Message string `json:"message"` -} - -type ReplaceRolePermissionsRequest struct { - PermissionIDs []string `json:"permission_ids"` -} - -type ReplaceRolePermissionResponse struct { - Message string `json:"message"` -} - -type RemoveRolePermissionResponse struct { - Message string `json:"message"` -} - -type AssignUserRoleRequest struct { - RoleID string `json:"role_id"` - ExpiresAt *time.Time `json:"expires_at,omitempty"` -} - -type ReplaceUserRolesRequest struct { - RoleIDs []string `json:"role_ids"` -} - -type ReplaceUserRolesResponse struct { - Message string `json:"message"` -} - -type AssignUserRoleResponse struct { - Message string `json:"message"` -} - -type RemoveUserRoleResponse struct { - Message string `json:"message"` -} - -type GetUserEffectivePermissionsResponse struct { - Permissions []UserPermissionInfo `json:"permissions"` -} - -type UserRoleInfo struct { - RoleID string `json:"role_id"` - RoleName string `json:"role_name"` - RoleDescription *string `json:"role_description,omitempty"` - AssignedByUserID *string `json:"assigned_by_user_id,omitempty"` - AssignedAt *time.Time `json:"assigned_at,omitempty"` - ExpiresAt *time.Time `json:"expires_at,omitempty"` -} - -type PermissionGrantSource struct { - RoleID string `json:"role_id"` - RoleName string `json:"role_name"` - GrantedByUserID *string `json:"granted_by_user_id,omitempty"` - GrantedAt *time.Time `json:"granted_at,omitempty"` -} - -type UserPermissionInfo struct { - PermissionID string `json:"permission_id"` - PermissionKey string `json:"permission_key"` - PermissionDescription *string `json:"permission_description,omitempty"` - GrantedByUserID *string `json:"granted_by_user_id,omitempty"` - GrantedAt *time.Time `json:"granted_at,omitempty"` - Sources []PermissionGrantSource `json:"sources,omitempty"` -} - -type UserWithRoles struct { - User models.User `json:"user"` - Roles []UserRoleInfo `json:"roles"` -} - -type UserWithPermissions struct { - User models.User `json:"user"` - Permissions []UserPermissionInfo `json:"permissions"` -} - -type UserAuthorizationProfile struct { - User models.User `json:"user"` - Roles []UserRoleInfo `json:"roles"` - Permissions []UserPermissionInfo `json:"permissions"` -} - -type RoleDetails struct { - Role Role `json:"role"` - Permissions []UserPermissionInfo `json:"permissions"` -} diff --git a/plugins/access-control/types/types.go b/plugins/access-control/types/types.go index c086c87e..d39a2248 100644 --- a/plugins/access-control/types/types.go +++ b/plugins/access-control/types/types.go @@ -1,8 +1,140 @@ package types -type AccessControlPluginConfig struct { - Enabled bool `json:"enabled" toml:"enabled"` - // TODO: Add a field to enable auto-cleanup of expired user roles and permissions +import ( + "time" + + "github.com/Authula/authula/models" +) + +type CreateRoleRequest struct { + Name string `json:"name"` + Description *string `json:"description,omitempty"` + IsSystem bool `json:"is_system"` +} + +type CreateRoleResponse struct { + Role *Role `json:"role"` +} + +type UpdateRoleRequest struct { + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` +} + +type UpdateRoleResponse struct { + Role *Role `json:"role"` +} + +type DeleteRoleResponse struct { + Message string `json:"message"` +} + +type CreatePermissionRequest struct { + Key string `json:"key"` + Description *string `json:"description,omitempty"` + IsSystem bool `json:"is_system"` +} + +type CreatePermissionResponse struct { + Permission *Permission `json:"permission"` +} + +type UpdatePermissionRequest struct { + Description *string `json:"description,omitempty"` +} + +type UpdatePermissionResponse struct { + Permission *Permission `json:"permission"` +} + +type DeletePermissionResponse struct { + Message string `json:"message"` +} + +type AddRolePermissionRequest struct { + PermissionID string `json:"permission_id"` +} + +type AddRolePermissionResponse struct { + Message string `json:"message"` +} + +type ReplaceRolePermissionsRequest struct { + PermissionIDs []string `json:"permission_ids"` +} + +type ReplaceRolePermissionResponse struct { + Message string `json:"message"` } -func (config *AccessControlPluginConfig) ApplyDefaults() {} +type RemoveRolePermissionResponse struct { + Message string `json:"message"` +} + +type AssignUserRoleRequest struct { + RoleID string `json:"role_id"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` +} + +type ReplaceUserRolesRequest struct { + RoleIDs []string `json:"role_ids"` +} + +type ReplaceUserRolesResponse struct { + Message string `json:"message"` +} + +type AssignUserRoleResponse struct { + Message string `json:"message"` +} + +type RemoveUserRoleResponse struct { + Message string `json:"message"` +} + +type CheckUserPermissionsRequest struct { + PermissionKeys []string `json:"permission_keys"` +} + +type CheckUserPermissionsResponse struct { + HasPermissions bool `json:"has_permissions"` +} + +type GetUserEffectivePermissionsResponse struct { + Permissions []UserPermissionInfo `json:"permissions"` +} + +type UserRoleInfo struct { + RoleID string `json:"role_id"` + RoleName string `json:"role_name"` + RoleDescription *string `json:"role_description,omitempty"` + AssignedByUserID *string `json:"assigned_by_user_id,omitempty"` + AssignedAt *time.Time `json:"assigned_at,omitempty"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` +} + +type PermissionGrantSource struct { + RoleID string `json:"role_id"` + RoleName string `json:"role_name"` + GrantedByUserID *string `json:"granted_by_user_id,omitempty"` + GrantedAt *time.Time `json:"granted_at,omitempty"` +} + +type UserPermissionInfo struct { + PermissionID string `json:"permission_id"` + PermissionKey string `json:"permission_key"` + PermissionDescription *string `json:"permission_description,omitempty"` + GrantedByUserID *string `json:"granted_by_user_id,omitempty"` + GrantedAt *time.Time `json:"granted_at,omitempty"` + Sources []PermissionGrantSource `json:"sources,omitempty"` +} + +type UserWithPermissions struct { + User models.User `json:"user"` + Permissions []UserPermissionInfo `json:"permissions"` +} + +type RoleDetails struct { + Role Role `json:"role"` + Permissions []UserPermissionInfo `json:"permissions"` +} diff --git a/plugins/access-control/usecases/permissions_usecase.go b/plugins/access-control/usecases/permissions_usecase.go index 3d1ac21d..5f2f246b 100644 --- a/plugins/access-control/usecases/permissions_usecase.go +++ b/plugins/access-control/usecases/permissions_usecase.go @@ -27,6 +27,10 @@ func (u *PermissionsUseCase) GetPermissionByID(ctx context.Context, permissionID return u.service.GetPermissionByID(ctx, permissionID) } +func (u *PermissionsUseCase) GetPermissionByKey(ctx context.Context, permissionKey string) (*types.Permission, error) { + return u.service.GetPermissionByKey(ctx, permissionKey) +} + func (u *PermissionsUseCase) UpdatePermission(ctx context.Context, permissionID string, req types.UpdatePermissionRequest) (*types.Permission, error) { return u.service.UpdatePermission(ctx, permissionID, req) } diff --git a/plugins/access-control/usecases/roles_usecase.go b/plugins/access-control/usecases/roles_usecase.go index d0c9b92a..05633e0f 100644 --- a/plugins/access-control/usecases/roles_usecase.go +++ b/plugins/access-control/usecases/roles_usecase.go @@ -23,6 +23,10 @@ func (u *RolesUseCase) GetAllRoles(ctx context.Context) ([]types.Role, error) { return u.service.GetAllRoles(ctx) } +func (u *RolesUseCase) GetRoleByName(ctx context.Context, roleName string) (*types.Role, error) { + return u.service.GetRoleByName(ctx, roleName) +} + func (u *RolesUseCase) GetRoleByID(ctx context.Context, roleID string) (*types.RoleDetails, error) { return u.service.GetRoleByID(ctx, roleID) } diff --git a/plugins/access-control/usecases/usecases.go b/plugins/access-control/usecases/usecases.go index e50792d7..42ab7885 100644 --- a/plugins/access-control/usecases/usecases.go +++ b/plugins/access-control/usecases/usecases.go @@ -11,7 +11,7 @@ type UseCases struct { permissions *PermissionsUseCase rolePermissions *RolePermissionsUseCase userRoles *UserRolesUseCase - userAccess *UserAccessUseCase + userPermissions *UserPermissionsUseCase } func NewAccessControlUseCases( @@ -19,14 +19,14 @@ func NewAccessControlUseCases( permissions *PermissionsUseCase, rolePermissions *RolePermissionsUseCase, userRoles *UserRolesUseCase, - userAccess *UserAccessUseCase, + userPermissions *UserPermissionsUseCase, ) *UseCases { return &UseCases{ roles: roles, permissions: permissions, rolePermissions: rolePermissions, userRoles: userRoles, - userAccess: userAccess, + userPermissions: userPermissions, } } @@ -46,8 +46,8 @@ func (u *UseCases) UserRolesUseCase() *UserRolesUseCase { return u.userRoles } -func (u *UseCases) UserAccessUseCase() *UserAccessUseCase { - return u.userAccess +func (u *UseCases) UserPermissionsUseCase() *UserPermissionsUseCase { + return u.userPermissions } // Roles @@ -60,6 +60,10 @@ func (u *UseCases) GetAllRoles(ctx context.Context) ([]types.Role, error) { return u.roles.GetAllRoles(ctx) } +func (u *UseCases) GetRoleByName(ctx context.Context, roleName string) (*types.Role, error) { + return u.roles.GetRoleByName(ctx, roleName) +} + func (u *UseCases) GetRoleByID(ctx context.Context, roleID string) (*types.RoleDetails, error) { return u.roles.GetRoleByID(ctx, roleID) } @@ -130,24 +134,12 @@ func (u *UseCases) RemoveRoleFromUser(ctx context.Context, userID string, roleID return u.userRoles.RemoveRoleFromUser(ctx, userID, roleID) } -func (u *UseCases) GetUserWithRolesByID(ctx context.Context, userID string) (*types.UserWithRoles, error) { - return u.userRoles.GetUserWithRolesByID(ctx, userID) -} - -// User Access - -func (u *UseCases) GetUserEffectivePermissions(ctx context.Context, userID string) ([]types.UserPermissionInfo, error) { - return u.userAccess.GetUserEffectivePermissions(ctx, userID) -} - -func (u *UseCases) HasPermissions(ctx context.Context, userID string, requiredPermissions []string) (bool, error) { - return u.userAccess.HasPermissions(ctx, userID, requiredPermissions) -} +// User Permissions -func (u *UseCases) GetUserWithPermissionsByID(ctx context.Context, userID string) (*types.UserWithPermissions, error) { - return u.userAccess.GetUserWithPermissionsByID(ctx, userID) +func (u *UseCases) GetUserPermissions(ctx context.Context, userID string) ([]types.UserPermissionInfo, error) { + return u.userPermissions.GetUserPermissions(ctx, userID) } -func (u *UseCases) GetUserAuthorizationProfile(ctx context.Context, userID string) (*types.UserAuthorizationProfile, error) { - return u.userAccess.GetUserAuthorizationProfile(ctx, userID) +func (u *UseCases) HasPermissions(ctx context.Context, userID string, permissionKeys []string) (bool, error) { + return u.userPermissions.HasPermissions(ctx, userID, permissionKeys) } diff --git a/plugins/access-control/usecases/user_access_usecase.go b/plugins/access-control/usecases/user_access_usecase.go deleted file mode 100644 index 8a7c5acf..00000000 --- a/plugins/access-control/usecases/user_access_usecase.go +++ /dev/null @@ -1,32 +0,0 @@ -package usecases - -import ( - "context" - - "github.com/Authula/authula/plugins/access-control/services" - "github.com/Authula/authula/plugins/access-control/types" -) - -type UserAccessUseCase struct { - service *services.UserAccessService -} - -func NewUserAccessUseCase(service *services.UserAccessService) *UserAccessUseCase { - return &UserAccessUseCase{service: service} -} - -func (u *UserAccessUseCase) GetUserWithPermissionsByID(ctx context.Context, userID string) (*types.UserWithPermissions, error) { - return u.service.GetUserWithPermissionsByID(ctx, userID) -} - -func (u *UserAccessUseCase) GetUserAuthorizationProfile(ctx context.Context, userID string) (*types.UserAuthorizationProfile, error) { - return u.service.GetUserAuthorizationProfile(ctx, userID) -} - -func (u *UserAccessUseCase) GetUserEffectivePermissions(ctx context.Context, userID string) ([]types.UserPermissionInfo, error) { - return u.service.GetUserEffectivePermissions(ctx, userID) -} - -func (u *UserAccessUseCase) HasPermissions(ctx context.Context, userID string, requiredPermissions []string) (bool, error) { - return u.service.HasPermissions(ctx, userID, requiredPermissions) -} diff --git a/plugins/access-control/usecases/user_permissions_usecase.go b/plugins/access-control/usecases/user_permissions_usecase.go new file mode 100644 index 00000000..c35bfe5b --- /dev/null +++ b/plugins/access-control/usecases/user_permissions_usecase.go @@ -0,0 +1,24 @@ +package usecases + +import ( + "context" + + "github.com/Authula/authula/plugins/access-control/services" + "github.com/Authula/authula/plugins/access-control/types" +) + +type UserPermissionsUseCase struct { + service *services.UserPermissionsService +} + +func NewUserPermissionsUseCase(service *services.UserPermissionsService) *UserPermissionsUseCase { + return &UserPermissionsUseCase{service: service} +} + +func (u *UserPermissionsUseCase) GetUserPermissions(ctx context.Context, userID string) ([]types.UserPermissionInfo, error) { + return u.service.GetUserPermissions(ctx, userID) +} + +func (u *UserPermissionsUseCase) HasPermissions(ctx context.Context, userID string, permissionKeys []string) (bool, error) { + return u.service.HasPermissions(ctx, userID, permissionKeys) +} diff --git a/plugins/access-control/usecases/user_roles_usecase.go b/plugins/access-control/usecases/user_roles_usecase.go index e4029b52..6f4c25e5 100644 --- a/plugins/access-control/usecases/user_roles_usecase.go +++ b/plugins/access-control/usecases/user_roles_usecase.go @@ -19,10 +19,6 @@ func (u *UserRolesUseCase) GetUserRoles(ctx context.Context, userID string) ([]t return u.service.GetUserRoles(ctx, userID) } -func (u *UserRolesUseCase) GetUserWithRolesByID(ctx context.Context, userID string) (*types.UserWithRoles, error) { - return u.service.GetUserWithRolesByID(ctx, userID) -} - func (u *UserRolesUseCase) ReplaceUserRoles(ctx context.Context, userID string, roleIDs []string, assignedByUserID *string) error { return u.service.ReplaceUserRoles(ctx, userID, roleIDs, assignedByUserID) } From f5190fac2ce09f4d6be30c291feb15ec5811094c Mon Sep 17 00:00:00 2001 From: Tanvir Ahmed Date: Tue, 31 Mar 2026 21:19:43 +0000 Subject: [PATCH 3/6] chore: Implemented assign role hook --- plugins/access-control/hooks.go | 79 +++++++++++++ plugins/access-control/hooks_test.go | 165 +++++++++++++++++++++++++++ 2 files changed, 244 insertions(+) create mode 100644 plugins/access-control/hooks_test.go diff --git a/plugins/access-control/hooks.go b/plugins/access-control/hooks.go index ce06d9ac..0d87fa1d 100644 --- a/plugins/access-control/hooks.go +++ b/plugins/access-control/hooks.go @@ -1,10 +1,12 @@ package accesscontrol import ( + "fmt" "net/http" "github.com/Authula/authula/internal/util" "github.com/Authula/authula/models" + "github.com/Authula/authula/plugins/access-control/types" ) type AccessControlHookID string @@ -25,6 +27,83 @@ func (p *AccessControlPlugin) Hooks() []models.Hook { Handler: p.requireAccessControl, Order: 20, }, + { + Stage: models.HookAfter, + Handler: p.assignRoleFromContextHook, + Order: 20, + }, + } +} + +func (p *AccessControlPlugin) assignRoleFromContextHook(reqCtx *models.RequestContext) error { + if p == nil || p.Api == nil || reqCtx == nil || reqCtx.Request == nil { + return nil + } + + rawValue, ok := reqCtx.Values[models.ContextAccessControlAssignRole.String()] + if !ok || rawValue == nil { + return nil + } + + assignCtx, ok := accessControlAssignRoleContext(rawValue) + if !ok || assignCtx.UserID == "" || assignCtx.RoleName == "" { + return nil + } + + ctx := reqCtx.Request.Context() + userRoles, err := p.Api.GetUserRoles(ctx, assignCtx.UserID) + if err != nil { + p.logAssignRoleHookError("failed to load user roles", assignCtx, err) + return nil + } + + for _, userRole := range userRoles { + if userRole.RoleName == assignCtx.RoleName { + return nil + } + } + + role, err := p.Api.GetRoleByName(ctx, assignCtx.RoleName) + if err != nil { + p.logAssignRoleHookError("failed to resolve role", assignCtx, err) + return nil + } + if role == nil || role.ID == "" { + p.logAssignRoleHookError("resolved role is empty", assignCtx, fmt.Errorf("role not found")) + return nil + } + + if err := p.Api.AssignRoleToUser(ctx, assignCtx.UserID, types.AssignUserRoleRequest{RoleID: role.ID}, nil); err != nil { + p.logAssignRoleHookError("failed to assign role", assignCtx, err) + } + + return nil +} + +func (p *AccessControlPlugin) logAssignRoleHookError(message string, assignCtx models.AccessControlAssignRoleContext, err error) { + if p == nil || p.logger == nil { + return + } + + p.logger.Error( + message, + "user_id", assignCtx.UserID, + "role_name", assignCtx.RoleName, + "error", err, + ) +} + +func accessControlAssignRoleContext(value any) (models.AccessControlAssignRoleContext, bool) { + switch typed := value.(type) { + case models.AccessControlAssignRoleContext: + return typed, true + case *models.AccessControlAssignRoleContext: + if typed == nil { + return models.AccessControlAssignRoleContext{}, false + } + return *typed, true + default: + return models.AccessControlAssignRoleContext{}, false } } diff --git a/plugins/access-control/hooks_test.go b/plugins/access-control/hooks_test.go new file mode 100644 index 00000000..3e02ad52 --- /dev/null +++ b/plugins/access-control/hooks_test.go @@ -0,0 +1,165 @@ +package accesscontrol + +import ( + "errors" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/mock" + + authmodels "github.com/Authula/authula/models" + "github.com/Authula/authula/plugins/access-control/services" + accesscontroltests "github.com/Authula/authula/plugins/access-control/tests" + "github.com/Authula/authula/plugins/access-control/types" + "github.com/Authula/authula/plugins/access-control/usecases" +) + +type hookTestLogger struct { + errors []string +} + +func (l *hookTestLogger) Debug(msg string, args ...any) {} +func (l *hookTestLogger) Info(msg string, args ...any) {} +func (l *hookTestLogger) Warn(msg string, args ...any) {} +func (l *hookTestLogger) Error(msg string, args ...any) { + l.errors = append(l.errors, msg+fmt.Sprint(args...)) +} + +func newAccessControlHookTestPlugin(logger authmodels.Logger, rolesRepo *accesscontroltests.MockRolesRepository, userRolesRepo *accesscontroltests.MockUserRolesRepository) *AccessControlPlugin { + rolePermissionsService := services.NewRolePermissionsService(nil, nil, nil) + useCases := usecases.NewAccessControlUseCases( + usecases.NewRolesUseCase(services.NewRolesService(rolesRepo, nil, userRolesRepo)), + usecases.NewPermissionsUseCase(nil), + usecases.NewRolePermissionsUseCase(rolePermissionsService), + usecases.NewUserRolesUseCase(services.NewUserRolesService(userRolesRepo, rolesRepo)), + usecases.NewUserPermissionsUseCase(nil), + ) + + return &AccessControlPlugin{ + Api: NewAPI(useCases), + logger: logger, + } +} + +func TestAccessControlPluginHooksIncludesGlobalAssignRoleHook(t *testing.T) { + t.Parallel() + + hooks := (&AccessControlPlugin{}).Hooks() + if len(hooks) != 2 { + t.Fatalf("expected 2 hooks, got %d", len(hooks)) + } + + var foundGlobal bool + for _, hook := range hooks { + if hook.Stage == authmodels.HookAfter && hook.PluginID == "" && hook.Handler != nil { + foundGlobal = true + } + } + + if !foundGlobal { + t.Fatal("expected a global HookAfter assignment hook") + } +} + +func TestAccessControlPluginAssignRoleFromContextHook(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + contextValue any + setup func(*accesscontroltests.MockRolesRepository, *accesscontroltests.MockUserRolesRepository) + wantErrors int + wantErrMessage string + }{ + { + name: "missing context is a no-op", + contextValue: nil, + }, + { + name: "already assigned role is skipped", + contextValue: authmodels.AccessControlAssignRoleContext{UserID: "user-1", RoleName: "Editor"}, + setup: func(rolesRepo *accesscontroltests.MockRolesRepository, userRolesRepo *accesscontroltests.MockUserRolesRepository) { + userRolesRepo.On("GetUserRoles", mock.Anything, "user-1").Return([]types.UserRoleInfo{{RoleID: "role-1", RoleName: "Editor"}}, nil).Once() + }, + }, + { + name: "assigns role when missing", + contextValue: authmodels.AccessControlAssignRoleContext{UserID: "user-1", RoleName: "Editor"}, + setup: func(rolesRepo *accesscontroltests.MockRolesRepository, userRolesRepo *accesscontroltests.MockUserRolesRepository) { + userRolesRepo.On("GetUserRoles", mock.Anything, "user-1").Return([]types.UserRoleInfo{}, nil).Once() + rolesRepo.On("GetRoleByName", mock.Anything, "Editor").Return(&types.Role{ID: "role-1", Name: "Editor"}, nil).Once() + rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "Editor"}, nil).Once() + userRolesRepo.On("AssignUserRole", mock.Anything, "user-1", "role-1", (*string)(nil), (*time.Time)(nil)).Return(nil).Once() + }, + }, + { + name: "role lookup failure is logged and ignored", + contextValue: authmodels.AccessControlAssignRoleContext{UserID: "user-1", RoleName: "Editor"}, + setup: func(rolesRepo *accesscontroltests.MockRolesRepository, userRolesRepo *accesscontroltests.MockUserRolesRepository) { + userRolesRepo.On("GetUserRoles", mock.Anything, "user-1").Return([]types.UserRoleInfo{}, nil).Once() + rolesRepo.On("GetRoleByName", mock.Anything, "Editor").Return((*types.Role)(nil), errors.New("lookup failed")).Once() + }, + wantErrors: 1, + wantErrMessage: "failed to resolve role", + }, + { + name: "assignment failure is logged and ignored", + contextValue: authmodels.AccessControlAssignRoleContext{UserID: "user-1", RoleName: "Editor"}, + setup: func(rolesRepo *accesscontroltests.MockRolesRepository, userRolesRepo *accesscontroltests.MockUserRolesRepository) { + userRolesRepo.On("GetUserRoles", mock.Anything, "user-1").Return([]types.UserRoleInfo{}, nil).Once() + rolesRepo.On("GetRoleByName", mock.Anything, "Editor").Return(&types.Role{ID: "role-1", Name: "Editor"}, nil).Once() + rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "Editor"}, nil).Once() + userRolesRepo.On("AssignUserRole", mock.Anything, "user-1", "role-1", (*string)(nil), (*time.Time)(nil)).Return(errors.New("assign failed")).Once() + }, + wantErrors: 1, + wantErrMessage: "failed to assign role", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + rolesRepo := &accesscontroltests.MockRolesRepository{} + userRolesRepo := &accesscontroltests.MockUserRolesRepository{} + if tc.setup != nil { + tc.setup(rolesRepo, userRolesRepo) + } + + logger := &hookTestLogger{} + plugin := newAccessControlHookTestPlugin(logger, rolesRepo, userRolesRepo) + + req := httptest.NewRequest(http.MethodPost, "/test", nil) + reqCtx := &authmodels.RequestContext{ + Request: req, + Values: map[string]any{}, + } + if tc.contextValue != nil { + reqCtx.Values[authmodels.ContextAccessControlAssignRole.String()] = tc.contextValue + } + + err := plugin.assignRoleFromContextHook(reqCtx) + if err != nil { + t.Fatalf("expected nil error, got %v", err) + } + + if len(logger.errors) != tc.wantErrors { + t.Fatalf("expected %d logged errors, got %d: %#v", tc.wantErrors, len(logger.errors), logger.errors) + } + if tc.wantErrMessage != "" && (len(logger.errors) == 0 || !containsString(logger.errors[0], tc.wantErrMessage)) { + t.Fatalf("expected log message containing %q, got %#v", tc.wantErrMessage, logger.errors) + } + + rolesRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) + }) + } +} + +func containsString(value string, substring string) bool { + return strings.Contains(value, substring) +} From 7b2c3d3f0a21bf331fb1f650ab40a3d8ebcfd1a2 Mon Sep 17 00:00:00 2001 From: Tanvir Ahmed Date: Wed, 1 Apr 2026 02:34:53 +0000 Subject: [PATCH 4/6] chore: Updated hooks and tests --- .github/copilot-instructions.md | 22 ++++++++++--- internal/tests/mock_objects.go | 49 ++++++++++++++++++++++++++++ internal/tests/mock_services.go | 40 ----------------------- plugins/access-control/hooks.go | 18 +++------- plugins/access-control/hooks_test.go | 40 +++-------------------- 5 files changed, 76 insertions(+), 93 deletions(-) create mode 100644 internal/tests/mock_objects.go diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index cd05b0ef..10b619d2 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -1,6 +1,10 @@ # Authula Project Guidelines -### Code Style Guide +**Authula** is an open-source authentication solution that scales with you. Embed it as a library in your Go app, or run it as a standalone auth server with any tech stack. It simplifies adding robust authentication to backend services, empowering developers to build secure applications faster. + +--- + +## Code Style Guide - Always write clean code that is easy to read and maintain. - Follow consistent naming conventions for variables, functions, and structs. @@ -14,7 +18,9 @@ - Use interfaces to define behavior and promote decoupling. Never code to implementations. When writing services, make sure they implement an interface of a repository e.g. `UserService` imports `UserRepository`. This ensures that the service can be easily tested and swapped out with different implementations if needed. - For other services, define interfaces in the `interfaces.go` file within the `services` package and implement them in separate files just like the password service is an interface which has an argon2 implementation. So now it can easily be swapped out for another implementation if needed without changing the rest of the code that depends on it. -# Testing Guidelines +--- + +## Testing Guidelines - Write unit tests for as many components as possible to ensure reliability such as repositories, services and handlers as well as plugins. - Use descriptive names for test cases to clearly indicate their purpose. @@ -26,7 +32,9 @@ - Run `make build` to ensure the project builds successfully after changes. - Then run `make test` to run all tests in the project. -### Documentation Guidelines +--- + +## Documentation Guidelines - Keep documentation up to date with code changes. - Use clear and concise language in documentation. @@ -38,7 +46,9 @@ - When updating a feature, ensure that any related documentation is also updated to reflect the changes. - Create all docs in markdown format and within a top level docs/ directory. -### Security Guidelines +--- + +## Security Guidelines - Follow best practices for secure coding to prevent vulnerabilities. - Regularly review and update dependencies to address security issues. @@ -48,6 +58,8 @@ - Take into account the principle of least privilege when designing access controls. - Always take into consideration edge cases and loopholes that could be exploited by attackers and implement safeguards against them. -### Final Notes +--- + +## Agent Skills Always follow the Agent Skills located in the folder `.github/skills/` as it contains all the skills and playbooks you need to follow to make sure you are adhering to the project guidelines and best practices. diff --git a/internal/tests/mock_objects.go b/internal/tests/mock_objects.go new file mode 100644 index 00000000..f732a4a7 --- /dev/null +++ b/internal/tests/mock_objects.go @@ -0,0 +1,49 @@ +package tests + +import ( + "context" + + "github.com/stretchr/testify/mock" + + "github.com/Authula/authula/models" +) + +type MockLogger struct{} + +func (m *MockLogger) Debug(msg string, args ...any) {} +func (m *MockLogger) Info(msg string, args ...any) {} +func (m *MockLogger) Warn(msg string, args ...any) {} +func (m *MockLogger) Error(msg string, args ...any) {} +func (m *MockLogger) Panic(msg string, args ...any) {} +func (m *MockLogger) WithField(key string, value any) models.Logger { + return m +} +func (m *MockLogger) WithFields(fields map[string]any) models.Logger { + return m +} + +type MockEventBus struct { + mock.Mock +} + +func (m *MockEventBus) Publish(ctx context.Context, event models.Event) error { + args := m.Called(ctx, event) + return args.Error(0) +} + +func (m *MockEventBus) Close() error { + args := m.Called() + return args.Error(0) +} + +func (m *MockEventBus) Subscribe(topic string, handler models.EventHandler) (models.SubscriptionID, error) { + args := m.Called(topic, handler) + if args.Get(0) == nil { + return 0, args.Error(1) + } + return args.Get(0).(models.SubscriptionID), args.Error(1) +} + +func (m *MockEventBus) Unsubscribe(topic string, subscriptionID models.SubscriptionID) { + m.Called(topic, subscriptionID) +} diff --git a/internal/tests/mock_services.go b/internal/tests/mock_services.go index 3b0192cc..95d7badb 100644 --- a/internal/tests/mock_services.go +++ b/internal/tests/mock_services.go @@ -302,46 +302,6 @@ func (m *MockMailerService) SendEmail(ctx context.Context, to string, subject st return args.Error(0) } -type MockLogger struct{} - -func (m *MockLogger) Debug(msg string, args ...any) {} -func (m *MockLogger) Info(msg string, args ...any) {} -func (m *MockLogger) Warn(msg string, args ...any) {} -func (m *MockLogger) Error(msg string, args ...any) {} -func (m *MockLogger) Panic(msg string, args ...any) {} -func (m *MockLogger) WithField(key string, value any) models.Logger { - return m -} -func (m *MockLogger) WithFields(fields map[string]any) models.Logger { - return m -} - -type MockEventBus struct { - mock.Mock -} - -func (m *MockEventBus) Publish(ctx context.Context, event models.Event) error { - args := m.Called(ctx, event) - return args.Error(0) -} - -func (m *MockEventBus) Close() error { - args := m.Called() - return args.Error(0) -} - -func (m *MockEventBus) Subscribe(topic string, handler models.EventHandler) (models.SubscriptionID, error) { - args := m.Called(topic, handler) - if args.Get(0) == nil { - return 0, args.Error(1) - } - return args.Get(0).(models.SubscriptionID), args.Error(1) -} - -func (m *MockEventBus) Unsubscribe(topic string, subscriptionID models.SubscriptionID) { - m.Called(topic, subscriptionID) -} - type MockServiceRegistry struct { mock.Mock } diff --git a/plugins/access-control/hooks.go b/plugins/access-control/hooks.go index 0d87fa1d..a1229cfd 100644 --- a/plugins/access-control/hooks.go +++ b/plugins/access-control/hooks.go @@ -21,25 +21,21 @@ func (id AccessControlHookID) String() string { func (p *AccessControlPlugin) Hooks() []models.Hook { return []models.Hook{ + { + Stage: models.HookAfter, + Handler: p.assignRoleFromContextHook, + Order: 20, + }, { Stage: models.HookBefore, PluginID: HookIDAccessControlEnforce.String(), Handler: p.requireAccessControl, Order: 20, }, - { - Stage: models.HookAfter, - Handler: p.assignRoleFromContextHook, - Order: 20, - }, } } func (p *AccessControlPlugin) assignRoleFromContextHook(reqCtx *models.RequestContext) error { - if p == nil || p.Api == nil || reqCtx == nil || reqCtx.Request == nil { - return nil - } - rawValue, ok := reqCtx.Values[models.ContextAccessControlAssignRole.String()] if !ok || rawValue == nil { return nil @@ -81,10 +77,6 @@ func (p *AccessControlPlugin) assignRoleFromContextHook(reqCtx *models.RequestCo } func (p *AccessControlPlugin) logAssignRoleHookError(message string, assignCtx models.AccessControlAssignRoleContext, err error) { - if p == nil || p.logger == nil { - return - } - p.logger.Error( message, "user_id", assignCtx.UserID, diff --git a/plugins/access-control/hooks_test.go b/plugins/access-control/hooks_test.go index 3e02ad52..9ead6df3 100644 --- a/plugins/access-control/hooks_test.go +++ b/plugins/access-control/hooks_test.go @@ -2,15 +2,14 @@ package accesscontrol import ( "errors" - "fmt" "net/http" "net/http/httptest" - "strings" "testing" "time" "github.com/stretchr/testify/mock" + internaltests "github.com/Authula/authula/internal/tests" authmodels "github.com/Authula/authula/models" "github.com/Authula/authula/plugins/access-control/services" accesscontroltests "github.com/Authula/authula/plugins/access-control/tests" @@ -18,17 +17,6 @@ import ( "github.com/Authula/authula/plugins/access-control/usecases" ) -type hookTestLogger struct { - errors []string -} - -func (l *hookTestLogger) Debug(msg string, args ...any) {} -func (l *hookTestLogger) Info(msg string, args ...any) {} -func (l *hookTestLogger) Warn(msg string, args ...any) {} -func (l *hookTestLogger) Error(msg string, args ...any) { - l.errors = append(l.errors, msg+fmt.Sprint(args...)) -} - func newAccessControlHookTestPlugin(logger authmodels.Logger, rolesRepo *accesscontroltests.MockRolesRepository, userRolesRepo *accesscontroltests.MockUserRolesRepository) *AccessControlPlugin { rolePermissionsService := services.NewRolePermissionsService(nil, nil, nil) useCases := usecases.NewAccessControlUseCases( @@ -69,11 +57,9 @@ func TestAccessControlPluginAssignRoleFromContextHook(t *testing.T) { t.Parallel() tests := []struct { - name string - contextValue any - setup func(*accesscontroltests.MockRolesRepository, *accesscontroltests.MockUserRolesRepository) - wantErrors int - wantErrMessage string + name string + contextValue any + setup func(*accesscontroltests.MockRolesRepository, *accesscontroltests.MockUserRolesRepository) }{ { name: "missing context is a no-op", @@ -103,8 +89,6 @@ func TestAccessControlPluginAssignRoleFromContextHook(t *testing.T) { userRolesRepo.On("GetUserRoles", mock.Anything, "user-1").Return([]types.UserRoleInfo{}, nil).Once() rolesRepo.On("GetRoleByName", mock.Anything, "Editor").Return((*types.Role)(nil), errors.New("lookup failed")).Once() }, - wantErrors: 1, - wantErrMessage: "failed to resolve role", }, { name: "assignment failure is logged and ignored", @@ -115,8 +99,6 @@ func TestAccessControlPluginAssignRoleFromContextHook(t *testing.T) { rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "Editor"}, nil).Once() userRolesRepo.On("AssignUserRole", mock.Anything, "user-1", "role-1", (*string)(nil), (*time.Time)(nil)).Return(errors.New("assign failed")).Once() }, - wantErrors: 1, - wantErrMessage: "failed to assign role", }, } @@ -130,8 +112,7 @@ func TestAccessControlPluginAssignRoleFromContextHook(t *testing.T) { tc.setup(rolesRepo, userRolesRepo) } - logger := &hookTestLogger{} - plugin := newAccessControlHookTestPlugin(logger, rolesRepo, userRolesRepo) + plugin := newAccessControlHookTestPlugin(&internaltests.MockLogger{}, rolesRepo, userRolesRepo) req := httptest.NewRequest(http.MethodPost, "/test", nil) reqCtx := &authmodels.RequestContext{ @@ -147,19 +128,8 @@ func TestAccessControlPluginAssignRoleFromContextHook(t *testing.T) { t.Fatalf("expected nil error, got %v", err) } - if len(logger.errors) != tc.wantErrors { - t.Fatalf("expected %d logged errors, got %d: %#v", tc.wantErrors, len(logger.errors), logger.errors) - } - if tc.wantErrMessage != "" && (len(logger.errors) == 0 || !containsString(logger.errors[0], tc.wantErrMessage)) { - t.Fatalf("expected log message containing %q, got %#v", tc.wantErrMessage, logger.errors) - } - rolesRepo.AssertExpectations(t) userRolesRepo.AssertExpectations(t) }) } } - -func containsString(value string, substring string) bool { - return strings.Contains(value, substring) -} From afbac10cb40502b1e3cef88f0fc6b161f1da8b6b Mon Sep 17 00:00:00 2001 From: Tanvir Ahmed Date: Wed, 1 Apr 2026 02:39:19 +0000 Subject: [PATCH 5/6] chore: Removed TODO in config.go --- plugins/access-control/types/config.go | 1 - 1 file changed, 1 deletion(-) diff --git a/plugins/access-control/types/config.go b/plugins/access-control/types/config.go index c086c87e..293b3473 100644 --- a/plugins/access-control/types/config.go +++ b/plugins/access-control/types/config.go @@ -2,7 +2,6 @@ package types type AccessControlPluginConfig struct { Enabled bool `json:"enabled" toml:"enabled"` - // TODO: Add a field to enable auto-cleanup of expired user roles and permissions } func (config *AccessControlPluginConfig) ApplyDefaults() {} From ee2a93ddb5984928f75d2db895c0f6e8e1c9ab74 Mon Sep 17 00:00:00 2001 From: Tanvir Ahmed Date: Wed, 1 Apr 2026 03:00:52 +0000 Subject: [PATCH 6/6] fix: Linting errors --- .../handlers/user_roles_handlers_test.go | 27 ----------- .../repositories/user_roles_repository.go | 47 ------------------- 2 files changed, 74 deletions(-) diff --git a/plugins/access-control/handlers/user_roles_handlers_test.go b/plugins/access-control/handlers/user_roles_handlers_test.go index dad977d7..f96ca2b1 100644 --- a/plugins/access-control/handlers/user_roles_handlers_test.go +++ b/plugins/access-control/handlers/user_roles_handlers_test.go @@ -8,7 +8,6 @@ import ( "github.com/stretchr/testify/mock" internaltests "github.com/Authula/authula/internal/tests" - authmodels "github.com/Authula/authula/models" "github.com/Authula/authula/plugins/access-control/constants" "github.com/Authula/authula/plugins/access-control/services" accesscontroltests "github.com/Authula/authula/plugins/access-control/tests" @@ -419,32 +418,6 @@ func assertUserRoleInfoEqual(t *testing.T, got types.UserRoleInfo, want types.Us } } -func assertUserEqual(t *testing.T, got authmodels.User, want authmodels.User) { - t.Helper() - - if got.ID != want.ID { - t.Fatalf("expected id %q, got %q", want.ID, got.ID) - } - if got.Name != want.Name { - t.Fatalf("expected name %q, got %q", want.Name, got.Name) - } - if got.Email != want.Email { - t.Fatalf("expected email %q, got %q", want.Email, got.Email) - } - if got.EmailVerified != want.EmailVerified { - t.Fatalf("expected email_verified %v, got %v", want.EmailVerified, got.EmailVerified) - } - if string(got.Metadata) != string(want.Metadata) { - t.Fatalf("expected metadata %s, got %s", string(want.Metadata), string(got.Metadata)) - } - if !timesEqual(got.CreatedAt, want.CreatedAt) { - t.Fatalf("expected created_at %v, got %v", want.CreatedAt, got.CreatedAt) - } - if !timesEqual(got.UpdatedAt, want.UpdatedAt) { - t.Fatalf("expected updated_at %v, got %v", want.UpdatedAt, got.UpdatedAt) - } -} - func assertReplaceUserRolesResponseEqual(t *testing.T, got types.ReplaceUserRolesResponse, want types.ReplaceUserRolesResponse) { t.Helper() if got.Message != want.Message { diff --git a/plugins/access-control/repositories/user_roles_repository.go b/plugins/access-control/repositories/user_roles_repository.go index a773364d..f3fb2498 100644 --- a/plugins/access-control/repositories/user_roles_repository.go +++ b/plugins/access-control/repositories/user_roles_repository.go @@ -7,7 +7,6 @@ import ( "github.com/uptrace/bun" - "github.com/Authula/authula/models" "github.com/Authula/authula/plugins/access-control/types" ) @@ -42,19 +41,6 @@ func (r *BunUserRolesRepository) GetUserRoles(ctx context.Context, userID string return rows, nil } -type userWithRoleRow struct { - UserID string `bun:"user_id"` - UserName string `bun:"user_name"` - UserEmail string `bun:"user_email"` - EmailVerified bool `bun:"email_verified"` - Image *string - Metadata []byte - CreatedAt time.Time `bun:"created_at"` - UpdatedAt time.Time `bun:"updated_at"` - RoleID *string `bun:"role_id"` - RoleName *string `bun:"role_name"` -} - func (r *BunUserRolesRepository) ReplaceUserRoles(ctx context.Context, userID string, roleIDs []string, assignedByUserID *string) error { return r.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { if _, err := tx.NewDelete().Model((*types.UserRole)(nil)).Where("user_id = ?", userID).Exec(ctx); err != nil { @@ -117,36 +103,3 @@ func (r *BunUserRolesRepository) CountUsersByRole(ctx context.Context, roleID st } return count, nil } - -type userRow interface { - GetUserID() string - GetUserName() string - GetUserEmail() string - GetEmailVerified() bool - GetImage() *string - GetMetadata() []byte - GetCreatedAt() time.Time - GetUpdatedAt() time.Time -} - -func mapRowToUser(row userRow) models.User { - return models.User{ - ID: row.GetUserID(), - Name: row.GetUserName(), - Email: row.GetUserEmail(), - EmailVerified: row.GetEmailVerified(), - Image: row.GetImage(), - Metadata: row.GetMetadata(), - CreatedAt: row.GetCreatedAt(), - UpdatedAt: row.GetUpdatedAt(), - } -} - -func (r userWithRoleRow) GetUserID() string { return r.UserID } -func (r userWithRoleRow) GetUserName() string { return r.UserName } -func (r userWithRoleRow) GetUserEmail() string { return r.UserEmail } -func (r userWithRoleRow) GetEmailVerified() bool { return r.EmailVerified } -func (r userWithRoleRow) GetImage() *string { return r.Image } -func (r userWithRoleRow) GetMetadata() []byte { return r.Metadata } -func (r userWithRoleRow) GetCreatedAt() time.Time { return r.CreatedAt } -func (r userWithRoleRow) GetUpdatedAt() time.Time { return r.UpdatedAt }