diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index cd05b0e..10b619d 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -1,6 +1,10 @@ # Authula Project Guidelines -### Code Style Guide +**Authula** is an open-source authentication solution that scales with you. Embed it as a library in your Go app, or run it as a standalone auth server with any tech stack. It simplifies adding robust authentication to backend services, empowering developers to build secure applications faster. + +--- + +## Code Style Guide - Always write clean code that is easy to read and maintain. - Follow consistent naming conventions for variables, functions, and structs. @@ -14,7 +18,9 @@ - Use interfaces to define behavior and promote decoupling. Never code to implementations. When writing services, make sure they implement an interface of a repository e.g. `UserService` imports `UserRepository`. This ensures that the service can be easily tested and swapped out with different implementations if needed. - For other services, define interfaces in the `interfaces.go` file within the `services` package and implement them in separate files just like the password service is an interface which has an argon2 implementation. So now it can easily be swapped out for another implementation if needed without changing the rest of the code that depends on it. -# Testing Guidelines +--- + +## Testing Guidelines - Write unit tests for as many components as possible to ensure reliability such as repositories, services and handlers as well as plugins. - Use descriptive names for test cases to clearly indicate their purpose. @@ -26,7 +32,9 @@ - Run `make build` to ensure the project builds successfully after changes. - Then run `make test` to run all tests in the project. -### Documentation Guidelines +--- + +## Documentation Guidelines - Keep documentation up to date with code changes. - Use clear and concise language in documentation. @@ -38,7 +46,9 @@ - When updating a feature, ensure that any related documentation is also updated to reflect the changes. - Create all docs in markdown format and within a top level docs/ directory. -### Security Guidelines +--- + +## Security Guidelines - Follow best practices for secure coding to prevent vulnerabilities. - Regularly review and update dependencies to address security issues. @@ -48,6 +58,8 @@ - Take into account the principle of least privilege when designing access controls. - Always take into consideration edge cases and loopholes that could be exploited by attackers and implement safeguards against them. -### Final Notes +--- + +## Agent Skills Always follow the Agent Skills located in the folder `.github/skills/` as it contains all the skills and playbooks you need to follow to make sure you are adhering to the project guidelines and best practices. diff --git a/internal/tests/mock_objects.go b/internal/tests/mock_objects.go new file mode 100644 index 0000000..f732a4a --- /dev/null +++ b/internal/tests/mock_objects.go @@ -0,0 +1,49 @@ +package tests + +import ( + "context" + + "github.com/stretchr/testify/mock" + + "github.com/Authula/authula/models" +) + +type MockLogger struct{} + +func (m *MockLogger) Debug(msg string, args ...any) {} +func (m *MockLogger) Info(msg string, args ...any) {} +func (m *MockLogger) Warn(msg string, args ...any) {} +func (m *MockLogger) Error(msg string, args ...any) {} +func (m *MockLogger) Panic(msg string, args ...any) {} +func (m *MockLogger) WithField(key string, value any) models.Logger { + return m +} +func (m *MockLogger) WithFields(fields map[string]any) models.Logger { + return m +} + +type MockEventBus struct { + mock.Mock +} + +func (m *MockEventBus) Publish(ctx context.Context, event models.Event) error { + args := m.Called(ctx, event) + return args.Error(0) +} + +func (m *MockEventBus) Close() error { + args := m.Called() + return args.Error(0) +} + +func (m *MockEventBus) Subscribe(topic string, handler models.EventHandler) (models.SubscriptionID, error) { + args := m.Called(topic, handler) + if args.Get(0) == nil { + return 0, args.Error(1) + } + return args.Get(0).(models.SubscriptionID), args.Error(1) +} + +func (m *MockEventBus) Unsubscribe(topic string, subscriptionID models.SubscriptionID) { + m.Called(topic, subscriptionID) +} diff --git a/internal/tests/mock_services.go b/internal/tests/mock_services.go index 3b0192c..95d7bad 100644 --- a/internal/tests/mock_services.go +++ b/internal/tests/mock_services.go @@ -302,46 +302,6 @@ func (m *MockMailerService) SendEmail(ctx context.Context, to string, subject st return args.Error(0) } -type MockLogger struct{} - -func (m *MockLogger) Debug(msg string, args ...any) {} -func (m *MockLogger) Info(msg string, args ...any) {} -func (m *MockLogger) Warn(msg string, args ...any) {} -func (m *MockLogger) Error(msg string, args ...any) {} -func (m *MockLogger) Panic(msg string, args ...any) {} -func (m *MockLogger) WithField(key string, value any) models.Logger { - return m -} -func (m *MockLogger) WithFields(fields map[string]any) models.Logger { - return m -} - -type MockEventBus struct { - mock.Mock -} - -func (m *MockEventBus) Publish(ctx context.Context, event models.Event) error { - args := m.Called(ctx, event) - return args.Error(0) -} - -func (m *MockEventBus) Close() error { - args := m.Called() - return args.Error(0) -} - -func (m *MockEventBus) Subscribe(topic string, handler models.EventHandler) (models.SubscriptionID, error) { - args := m.Called(topic, handler) - if args.Get(0) == nil { - return 0, args.Error(1) - } - return args.Get(0).(models.SubscriptionID), args.Error(1) -} - -func (m *MockEventBus) Unsubscribe(topic string, subscriptionID models.SubscriptionID) { - m.Called(topic, subscriptionID) -} - type MockServiceRegistry struct { mock.Mock } diff --git a/models/plugins_context.go b/models/plugins_context.go new file mode 100644 index 0000000..09cbd49 --- /dev/null +++ b/models/plugins_context.go @@ -0,0 +1,12 @@ +package models + +const ( + ContextAccessControlAssignRole ContextKey = "access_control.assign_role" +) + +// Access Control + +type AccessControlAssignRoleContext struct { + UserID string + RoleName string +} diff --git a/plugins/access-control/api.go b/plugins/access-control/api.go index aa024c8..189571e 100644 --- a/plugins/access-control/api.go +++ b/plugins/access-control/api.go @@ -3,43 +3,28 @@ package accesscontrol import ( "context" - "github.com/Authula/authula/plugins/access-control/repositories" "github.com/Authula/authula/plugins/access-control/types" "github.com/Authula/authula/plugins/access-control/usecases" ) type API struct { - useCases *usecases.UseCases - rolePermissionRepo repositories.RolePermissionRepository - userAccessRepo repositories.UserAccessRepository + useCases *usecases.UseCases } -func NewAPI( - useCases *usecases.UseCases, - rolePermissionRepo repositories.RolePermissionRepository, - userAccessRepo repositories.UserAccessRepository, -) *API { - return &API{ - useCases: useCases, - rolePermissionRepo: rolePermissionRepo, - userAccessRepo: userAccessRepo, - } +func NewAPI(useCases *usecases.UseCases) *API { + return &API{useCases: useCases} } -func (a *API) RolePermissionRepository() repositories.RolePermissionRepository { - return a.rolePermissionRepo -} - -func (a *API) UserAccessRepository() repositories.UserAccessRepository { - return a.userAccessRepo -} - -// Roles and permissions +// Roles func (a *API) GetAllRoles(ctx context.Context) ([]types.Role, error) { return a.useCases.GetAllRoles(ctx) } +func (a *API) GetRoleByName(ctx context.Context, roleName string) (*types.Role, error) { + return a.useCases.GetRoleByName(ctx, roleName) +} + func (a *API) GetRoleByID(ctx context.Context, roleID string) (*types.RoleDetails, error) { return a.useCases.GetRoleByID(ctx, roleID) } @@ -56,6 +41,8 @@ func (a *API) DeleteRole(ctx context.Context, roleID string) error { return a.useCases.DeleteRole(ctx, roleID) } +// Permissions + func (a *API) CreatePermission(ctx context.Context, req types.CreatePermissionRequest) (*types.Permission, error) { return a.useCases.CreatePermission(ctx, req) } @@ -64,6 +51,10 @@ func (a *API) GetAllPermissions(ctx context.Context) ([]types.Permission, error) return a.useCases.GetAllPermissions(ctx) } +func (a *API) GetPermissionByID(ctx context.Context, permissionID string) (*types.Permission, error) { + return a.useCases.GetPermissionByID(ctx, permissionID) +} + func (a *API) GetRolePermissions(ctx context.Context, roleID string) ([]types.UserPermissionInfo, error) { return a.useCases.GetRolePermissions(ctx, roleID) } @@ -76,6 +67,8 @@ func (a *API) DeletePermission(ctx context.Context, permissionID string) error { return a.useCases.DeletePermission(ctx, permissionID) } +// Role permissions + func (a *API) AddPermissionToRole(ctx context.Context, roleID string, permissionID string, grantedByUserID *string) error { return a.useCases.AddPermissionToRole(ctx, roleID, permissionID, grantedByUserID) } @@ -88,7 +81,7 @@ func (a *API) ReplaceRolePermissions(ctx context.Context, roleID string, permiss return a.useCases.ReplaceRolePermissions(ctx, roleID, permissionIDs, grantedByUserID) } -// User roles and permissions +// User roles func (a *API) GetUserRoles(ctx context.Context, userID string) ([]types.UserRoleInfo, error) { return a.useCases.GetUserRoles(ctx, userID) @@ -106,20 +99,12 @@ func (a *API) ReplaceUserRoles(ctx context.Context, userID string, roleIDs []str return a.useCases.ReplaceUserRoles(ctx, userID, roleIDs, assignedByUserID) } -func (a *API) GetUserEffectivePermissions(ctx context.Context, userID string) ([]types.UserPermissionInfo, error) { - return a.useCases.GetUserEffectivePermissions(ctx, userID) -} - -// User access and permissions - -func (a *API) HasPermissions(ctx context.Context, userID string, requiredPermissions []string) (bool, error) { - return a.useCases.HasPermissions(ctx, userID, requiredPermissions) -} +// User permissions -func (a *API) GetUserWithRolesByID(ctx context.Context, userID string) (*types.UserWithRoles, error) { - return a.useCases.GetUserWithRolesByID(ctx, userID) +func (a *API) GetUserPermissions(ctx context.Context, userID string) ([]types.UserPermissionInfo, error) { + return a.useCases.GetUserPermissions(ctx, userID) } -func (a *API) GetUserWithPermissionsByID(ctx context.Context, userID string) (*types.UserWithPermissions, error) { - return a.useCases.GetUserWithPermissionsByID(ctx, userID) +func (a *API) HasPermissions(ctx context.Context, userID string, permissionNames []string) (bool, error) { + return a.useCases.HasPermissions(ctx, userID, permissionNames) } diff --git a/plugins/access-control/handlers/permission_handlers.go b/plugins/access-control/handlers/permission_handlers.go new file mode 100644 index 0000000..74c0784 --- /dev/null +++ b/plugins/access-control/handlers/permission_handlers.go @@ -0,0 +1,147 @@ +package handlers + +import ( + "net/http" + + "github.com/Authula/authula/internal/util" + "github.com/Authula/authula/models" + "github.com/Authula/authula/plugins/access-control/types" + "github.com/Authula/authula/plugins/access-control/usecases" +) + +type CreatePermissionHandler struct { + useCase *usecases.PermissionsUseCase +} + +func NewCreatePermissionHandler(useCase *usecases.PermissionsUseCase) *CreatePermissionHandler { + return &CreatePermissionHandler{useCase: useCase} +} + +func (h *CreatePermissionHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + var payload types.CreatePermissionRequest + if err := util.ParseJSON(r, &payload); err != nil { + reqCtx.SetJSONResponse(http.StatusUnprocessableEntity, map[string]any{"message": "invalid request body"}) + reqCtx.Handled = true + return + } + + permission, err := h.useCase.CreatePermission(r.Context(), payload) + if err != nil { + respondRolePermissionError(reqCtx, err) + return + } + + reqCtx.SetJSONResponse(http.StatusCreated, &types.CreatePermissionResponse{ + Permission: permission, + }) + } +} + +type GetAllPermissionsHandler struct { + useCase *usecases.PermissionsUseCase +} + +func NewGetAllPermissionsHandler(useCase *usecases.PermissionsUseCase) *GetAllPermissionsHandler { + return &GetAllPermissionsHandler{useCase: useCase} +} + +func (h *GetAllPermissionsHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + permissions, err := h.useCase.GetAllPermissions(r.Context()) + if err != nil { + respondRolePermissionError(reqCtx, err) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, permissions) + } +} + +type GetPermissionByIDHandler struct { + useCase *usecases.PermissionsUseCase +} + +func NewGetPermissionByIDHandler(useCase *usecases.PermissionsUseCase) *GetPermissionByIDHandler { + return &GetPermissionByIDHandler{useCase: useCase} +} + +func (h *GetPermissionByIDHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + permissionID := r.PathValue("permission_id") + + permission, err := h.useCase.GetPermissionByID(r.Context(), permissionID) + if err != nil { + respondRolePermissionError(reqCtx, err) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, permission) + } +} + +type UpdatePermissionHandler struct { + useCase *usecases.PermissionsUseCase +} + +func NewUpdatePermissionHandler(useCase *usecases.PermissionsUseCase) *UpdatePermissionHandler { + return &UpdatePermissionHandler{useCase: useCase} +} + +func (h *UpdatePermissionHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + permissionID := r.PathValue("permission_id") + + var payload types.UpdatePermissionRequest + if err := util.ParseJSON(r, &payload); err != nil { + reqCtx.SetJSONResponse(http.StatusUnprocessableEntity, map[string]any{"message": "invalid request body"}) + reqCtx.Handled = true + return + } + + permission, err := h.useCase.UpdatePermission(r.Context(), permissionID, payload) + if err != nil { + respondRolePermissionError(reqCtx, err) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, &types.UpdatePermissionResponse{ + Permission: permission, + }) + } +} + +type DeletePermissionHandler struct { + useCase *usecases.PermissionsUseCase +} + +func NewDeletePermissionHandler(useCase *usecases.PermissionsUseCase) *DeletePermissionHandler { + return &DeletePermissionHandler{useCase: useCase} +} + +func (h *DeletePermissionHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + permissionID := r.PathValue("permission_id") + + if err := h.useCase.DeletePermission(r.Context(), permissionID); err != nil { + respondRolePermissionError(reqCtx, err) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, &types.DeletePermissionResponse{ + Message: "permission deleted", + }) + } +} diff --git a/plugins/access-control/handlers/permission_handlers_test.go b/plugins/access-control/handlers/permission_handlers_test.go new file mode 100644 index 0000000..07307c0 --- /dev/null +++ b/plugins/access-control/handlers/permission_handlers_test.go @@ -0,0 +1,558 @@ +package handlers + +import ( + "errors" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/mock" + + internaltests "github.com/Authula/authula/internal/tests" + "github.com/Authula/authula/plugins/access-control/constants" + "github.com/Authula/authula/plugins/access-control/services" + accesscontroltests "github.com/Authula/authula/plugins/access-control/tests" + "github.com/Authula/authula/plugins/access-control/types" + "github.com/Authula/authula/plugins/access-control/usecases" +) + +func TestGetAllPermissionsHandler(t *testing.T) { + t.Parallel() + + fixedTime := time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC) + description := new(string) + *description = "read access" + + tests := []struct { + name string + setupMock func(*accesscontroltests.MockPermissionsRepository) + expectedStatus int + expectedBody any + }{ + { + name: "service error", + setupMock: func(m *accesscontroltests.MockPermissionsRepository) { + m.On("GetAllPermissions", mock.Anything).Return(([]types.Permission)(nil), constants.ErrForbidden).Once() + }, + expectedStatus: http.StatusForbidden, + expectedBody: map[string]string{"message": "forbidden"}, + }, + { + name: "success", + setupMock: func(m *accesscontroltests.MockPermissionsRepository) { + m.On("GetAllPermissions", mock.Anything).Return([]types.Permission{ + { + ID: "perm-1", + Key: "users.read", + Description: description, + IsSystem: false, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, + { + ID: "perm-2", + Key: "users.write", + Description: nil, + IsSystem: true, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, + }, nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: []types.Permission{ + { + ID: "perm-1", + Key: "users.read", + Description: description, + IsSystem: false, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, + { + ID: "perm-2", + Key: "users.write", + Description: nil, + IsSystem: true, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + permissionsRepo := &accesscontroltests.MockPermissionsRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} + if tc.setupMock != nil { + tc.setupMock(permissionsRepo) + } + + useCase := newPermissionsUseCase(permissionsRepo, rolePermissionsRepo) + handler := NewGetAllPermissionsHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/permissions", nil, nil) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + permissionsRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[[]types.Permission](t, reqCtx) + assertPermissionsEqual(t, payload, tc.expectedBody.([]types.Permission)) + + permissionsRepo.AssertExpectations(t) + }) + } +} + +func TestCreatePermissionHandler(t *testing.T) { + t.Parallel() + + fixedTime := time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC) + description := new(string) + *description = "create access" + + tests := []struct { + name string + body []byte + setupMock func(*accesscontroltests.MockPermissionsRepository) + expectedStatus int + expectedBody any + }{ + { + name: "invalid request body", + body: []byte("{invalid json"), + expectedStatus: http.StatusUnprocessableEntity, + expectedBody: map[string]string{"message": "invalid request body"}, + }, + { + name: "service error", + body: internaltests.MarshalToJSON(t, types.CreatePermissionRequest{ + Key: "users.create", + Description: description, + IsSystem: false, + }), + setupMock: func(m *accesscontroltests.MockPermissionsRepository) { + m.On("CreatePermission", mock.Anything, mock.MatchedBy(func(permission *types.Permission) bool { + return permission != nil && permission.Key == "users.create" && permission.Description != nil && *permission.Description == *description && !permission.IsSystem && permission.ID != "" + })).Return(constants.ErrBadRequest).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedBody: map[string]string{"message": "bad request"}, + }, + { + name: "success", + body: internaltests.MarshalToJSON(t, types.CreatePermissionRequest{ + Key: "users.create", + Description: description, + IsSystem: false, + }), + setupMock: func(m *accesscontroltests.MockPermissionsRepository) { + m.On("CreatePermission", mock.Anything, mock.MatchedBy(func(permission *types.Permission) bool { + return permission != nil && permission.Key == "users.create" && permission.Description != nil && *permission.Description == *description && !permission.IsSystem && permission.ID != "" + })).Run(func(args mock.Arguments) { + permission := args.Get(1).(*types.Permission) + permission.ID = "perm-1" + permission.CreatedAt = fixedTime + permission.UpdatedAt = fixedTime + }).Return(nil).Once() + }, + expectedStatus: http.StatusCreated, + expectedBody: types.CreatePermissionResponse{ + Permission: &types.Permission{ + ID: "perm-1", + Key: "users.create", + Description: description, + IsSystem: false, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + permissionsRepo := &accesscontroltests.MockPermissionsRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} + if tc.setupMock != nil { + tc.setupMock(permissionsRepo) + } + + useCase := newPermissionsUseCase(permissionsRepo, rolePermissionsRepo) + handler := NewCreatePermissionHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPost, "/permissions", tc.body, nil) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusCreated { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + permissionsRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[types.CreatePermissionResponse](t, reqCtx) + assertCreatePermissionResponseEqual(t, payload, tc.expectedBody.(types.CreatePermissionResponse)) + + permissionsRepo.AssertExpectations(t) + }) + } +} + +func TestGetPermissionByIDHandler(t *testing.T) { + t.Parallel() + + fixedTime := time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC) + description := new(string) + *description = "read access" + + tests := []struct { + name string + permissionID string + setupMock func(*accesscontroltests.MockPermissionsRepository) + expectedStatus int + expectedBody any + }{ + { + name: "service error", + permissionID: "perm-404", + setupMock: func(m *accesscontroltests.MockPermissionsRepository) { + m.On("GetPermissionByID", mock.Anything, "perm-404").Return((*types.Permission)(nil), constants.ErrNotFound).Once() + }, + expectedStatus: http.StatusNotFound, + expectedBody: map[string]string{"message": "not found"}, + }, + { + name: "success", + permissionID: "perm-1", + setupMock: func(m *accesscontroltests.MockPermissionsRepository) { + m.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ + ID: "perm-1", + Key: "users.read", + Description: description, + IsSystem: false, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: &types.Permission{ + ID: "perm-1", + Key: "users.read", + Description: description, + IsSystem: false, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + permissionsRepo := &accesscontroltests.MockPermissionsRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} + if tc.setupMock != nil { + tc.setupMock(permissionsRepo) + } + + useCase := newPermissionsUseCase(permissionsRepo, rolePermissionsRepo) + handler := NewGetPermissionByIDHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/permissions/"+tc.permissionID, nil, nil) + req.SetPathValue("permission_id", tc.permissionID) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + permissionsRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[types.Permission](t, reqCtx) + assertPermissionEqual(t, payload, *tc.expectedBody.(*types.Permission)) + + permissionsRepo.AssertExpectations(t) + }) + } +} + +func TestUpdatePermissionHandler(t *testing.T) { + t.Parallel() + + fixedTime := time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC) + updatedDescription := "updated read access" + updatedDescriptionPtr := &updatedDescription + existingDescription := new(string) + *existingDescription = "read access" + + tests := []struct { + name string + permissionID string + body []byte + setupMock func(*accesscontroltests.MockPermissionsRepository) + expectedStatus int + expectedBody any + }{ + { + name: "invalid request body", + permissionID: "perm-1", + body: []byte("{invalid json"), + expectedStatus: http.StatusUnprocessableEntity, + expectedBody: map[string]string{"message": "invalid request body"}, + }, + { + name: "service error", + permissionID: "perm-1", + body: internaltests.MarshalToJSON(t, types.UpdatePermissionRequest{Description: updatedDescriptionPtr}), + setupMock: func(m *accesscontroltests.MockPermissionsRepository) { + m.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read", Description: existingDescription, IsSystem: false}, nil).Once() + m.On("UpdatePermission", mock.Anything, "perm-1", mock.MatchedBy(func(description *string) bool { + return description != nil && *description == *updatedDescriptionPtr + })).Return(false, constants.ErrBadRequest).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedBody: map[string]string{"message": "bad request"}, + }, + { + name: "success", + permissionID: "perm-1", + body: internaltests.MarshalToJSON(t, types.UpdatePermissionRequest{Description: updatedDescriptionPtr}), + setupMock: func(m *accesscontroltests.MockPermissionsRepository) { + m.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read", Description: existingDescription, IsSystem: false}, nil).Once() + m.On("UpdatePermission", mock.Anything, "perm-1", mock.MatchedBy(func(description *string) bool { + return description != nil && *description == *updatedDescriptionPtr + })).Return(true, nil).Once() + m.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ + ID: "perm-1", + Key: "users.read", + Description: updatedDescriptionPtr, + IsSystem: false, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: types.UpdatePermissionResponse{ + Permission: &types.Permission{ + ID: "perm-1", + Key: "users.read", + Description: updatedDescriptionPtr, + IsSystem: false, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + permissionsRepo := &accesscontroltests.MockPermissionsRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} + if tc.setupMock != nil { + tc.setupMock(permissionsRepo) + } + + useCase := newPermissionsUseCase(permissionsRepo, rolePermissionsRepo) + handler := NewUpdatePermissionHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPut, "/permissions/"+tc.permissionID, tc.body, nil) + req.SetPathValue("permission_id", tc.permissionID) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + permissionsRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[types.UpdatePermissionResponse](t, reqCtx) + assertUpdatePermissionResponseEqual(t, payload, tc.expectedBody.(types.UpdatePermissionResponse)) + + permissionsRepo.AssertExpectations(t) + }) + } +} + +func TestDeletePermissionHandler(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + permissionID string + setupMock func(*accesscontroltests.MockPermissionsRepository, *accesscontroltests.MockRolePermissionsRepository) + expectedStatus int + expectedBody any + }{ + { + name: "service error", + permissionID: "perm-1", + setupMock: func(m *accesscontroltests.MockPermissionsRepository, rolePermissionsRepo *accesscontroltests.MockRolePermissionsRepository) { + m.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read", IsSystem: false}, nil).Once() + rolePermissionsRepo.On("CountRolesByPermission", mock.Anything, "perm-1").Return(0, nil).Once() + m.On("DeletePermission", mock.Anything, "perm-1").Return(false, errors.New("database error")).Once() + }, + expectedStatus: http.StatusInternalServerError, + expectedBody: map[string]string{"message": "database error"}, + }, + { + name: "success", + permissionID: "perm-1", + setupMock: func(m *accesscontroltests.MockPermissionsRepository, rolePermissionsRepo *accesscontroltests.MockRolePermissionsRepository) { + m.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read", IsSystem: false}, nil).Once() + rolePermissionsRepo.On("CountRolesByPermission", mock.Anything, "perm-1").Return(0, nil).Once() + m.On("DeletePermission", mock.Anything, "perm-1").Return(true, nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: types.DeletePermissionResponse{Message: "permission deleted"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + permissionsRepo := &accesscontroltests.MockPermissionsRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} + if tc.setupMock != nil { + tc.setupMock(permissionsRepo, rolePermissionsRepo) + } + + useCase := newPermissionsUseCase(permissionsRepo, rolePermissionsRepo) + handler := NewDeletePermissionHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodDelete, "/permissions/"+tc.permissionID, nil, nil) + req.SetPathValue("permission_id", tc.permissionID) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + permissionsRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[types.DeletePermissionResponse](t, reqCtx) + assertDeletePermissionResponseEqual(t, payload, tc.expectedBody.(types.DeletePermissionResponse)) + + permissionsRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + }) + } +} + +func newPermissionsUseCase(permissionsRepo *accesscontroltests.MockPermissionsRepository, rolePermissionsRepo *accesscontroltests.MockRolePermissionsRepository) *usecases.PermissionsUseCase { + return usecases.NewPermissionsUseCase(services.NewPermissionsService(permissionsRepo, rolePermissionsRepo)) +} + +func assertPermissionsEqual(t *testing.T, got []types.Permission, want []types.Permission) { + t.Helper() + + if len(got) != len(want) { + t.Fatalf("expected %d permissions, got %d", len(want), len(got)) + } + + for i := range want { + assertPermissionEqual(t, got[i], want[i]) + } +} + +func assertPermissionEqual(t *testing.T, got types.Permission, want types.Permission) { + t.Helper() + + if got.ID != want.ID { + t.Fatalf("expected id %q, got %q", want.ID, got.ID) + } + if got.Key != want.Key { + t.Fatalf("expected key %q, got %q", want.Key, got.Key) + } + if got.IsSystem != want.IsSystem { + t.Fatalf("expected is_system %v, got %v", want.IsSystem, got.IsSystem) + } + if !timesEqual(got.CreatedAt, want.CreatedAt) { + t.Fatalf("expected created_at %v, got %v", want.CreatedAt, got.CreatedAt) + } + if !timesEqual(got.UpdatedAt, want.UpdatedAt) { + t.Fatalf("expected updated_at %v, got %v", want.UpdatedAt, got.UpdatedAt) + } + if !stringsEqualPtr(got.Description, want.Description) { + t.Fatalf("expected description %#v, got %#v", want.Description, got.Description) + } +} + +func assertCreatePermissionResponseEqual(t *testing.T, got types.CreatePermissionResponse, want types.CreatePermissionResponse) { + t.Helper() + + if got.Permission == nil || want.Permission == nil { + if got.Permission != want.Permission { + t.Fatalf("expected permission %#v, got %#v", want.Permission, got.Permission) + } + return + } + + assertPermissionEqual(t, *got.Permission, *want.Permission) +} + +func assertUpdatePermissionResponseEqual(t *testing.T, got types.UpdatePermissionResponse, want types.UpdatePermissionResponse) { + t.Helper() + + if got.Permission == nil || want.Permission == nil { + if got.Permission != want.Permission { + t.Fatalf("expected permission %#v, got %#v", want.Permission, got.Permission) + } + return + } + + assertPermissionEqual(t, *got.Permission, *want.Permission) +} + +func assertDeletePermissionResponseEqual(t *testing.T, got types.DeletePermissionResponse, want types.DeletePermissionResponse) { + t.Helper() + + if got.Message != want.Message { + t.Fatalf("expected message %q, got %q", want.Message, got.Message) + } +} + +func timesEqual(left, right time.Time) bool { + return left.Equal(right) +} + +func stringsEqualPtr(left, right *string) bool { + if left == nil || right == nil { + return left == right + } + return *left == *right +} diff --git a/plugins/access-control/handlers/role_handlers.go b/plugins/access-control/handlers/role_handlers.go new file mode 100644 index 0000000..dbda4f8 --- /dev/null +++ b/plugins/access-control/handlers/role_handlers.go @@ -0,0 +1,171 @@ +package handlers + +import ( + "net/http" + + "github.com/Authula/authula/internal/util" + "github.com/Authula/authula/models" + "github.com/Authula/authula/plugins/access-control/types" + "github.com/Authula/authula/plugins/access-control/usecases" +) + +type CreateRoleHandler struct { + useCase *usecases.RolesUseCase +} + +func NewCreateRoleHandler(useCase *usecases.RolesUseCase) *CreateRoleHandler { + return &CreateRoleHandler{useCase: useCase} +} + +func (h *CreateRoleHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + var payload types.CreateRoleRequest + if err := util.ParseJSON(r, &payload); err != nil { + reqCtx.SetJSONResponse(http.StatusUnprocessableEntity, map[string]any{"message": "invalid request body"}) + reqCtx.Handled = true + return + } + + role, err := h.useCase.CreateRole(r.Context(), payload) + if err != nil { + respondRolePermissionError(reqCtx, err) + return + } + + reqCtx.SetJSONResponse(http.StatusCreated, &types.CreateRoleResponse{ + Role: role, + }) + } +} + +type GetAllRolesHandler struct { + useCase *usecases.RolesUseCase +} + +func NewGetAllRolesHandler(useCase *usecases.RolesUseCase) *GetAllRolesHandler { + return &GetAllRolesHandler{useCase: useCase} +} + +func (h *GetAllRolesHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + roles, err := h.useCase.GetAllRoles(r.Context()) + if err != nil { + respondRolePermissionError(reqCtx, err) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, roles) + } +} + +type GetRoleByNameHandler struct { + useCase *usecases.RolesUseCase +} + +func NewGetRoleByNameHandler(useCase *usecases.RolesUseCase) *GetRoleByNameHandler { + return &GetRoleByNameHandler{useCase: useCase} +} + +func (h *GetRoleByNameHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + roleName := r.PathValue("role_name") + + role, err := h.useCase.GetRoleByName(r.Context(), roleName) + if err != nil { + respondRolePermissionError(reqCtx, err) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, role) + } +} + +type GetRoleByIDHandler struct { + useCase *usecases.RolesUseCase +} + +func NewGetRoleByIDHandler(useCase *usecases.RolesUseCase) *GetRoleByIDHandler { + return &GetRoleByIDHandler{useCase: useCase} +} + +func (h *GetRoleByIDHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + roleID := r.PathValue("role_id") + + roleDetails, err := h.useCase.GetRoleByID(r.Context(), roleID) + if err != nil { + respondRolePermissionError(reqCtx, err) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, roleDetails) + } +} + +type UpdateRoleHandler struct { + useCase *usecases.RolesUseCase +} + +func NewUpdateRoleHandler(useCase *usecases.RolesUseCase) *UpdateRoleHandler { + return &UpdateRoleHandler{useCase: useCase} +} + +func (h *UpdateRoleHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + roleID := r.PathValue("role_id") + + var payload types.UpdateRoleRequest + if err := util.ParseJSON(r, &payload); err != nil { + reqCtx.SetJSONResponse(http.StatusUnprocessableEntity, map[string]any{"message": "invalid request body"}) + reqCtx.Handled = true + return + } + + role, err := h.useCase.UpdateRole(r.Context(), roleID, payload) + if err != nil { + respondRolePermissionError(reqCtx, err) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, &types.UpdateRoleResponse{ + Role: role, + }) + } +} + +type DeleteRoleHandler struct { + useCase *usecases.RolesUseCase +} + +func NewDeleteRoleHandler(useCase *usecases.RolesUseCase) *DeleteRoleHandler { + return &DeleteRoleHandler{useCase: useCase} +} + +func (h *DeleteRoleHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + roleID := r.PathValue("role_id") + + if err := h.useCase.DeleteRole(r.Context(), roleID); err != nil { + respondRolePermissionError(reqCtx, err) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, &types.DeleteRoleResponse{ + Message: "deleted role", + }) + } +} diff --git a/plugins/access-control/handlers/role_handlers_test.go b/plugins/access-control/handlers/role_handlers_test.go new file mode 100644 index 0000000..96ffeda --- /dev/null +++ b/plugins/access-control/handlers/role_handlers_test.go @@ -0,0 +1,718 @@ +package handlers + +import ( + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/mock" + + internaltests "github.com/Authula/authula/internal/tests" + "github.com/Authula/authula/plugins/access-control/constants" + "github.com/Authula/authula/plugins/access-control/services" + accesscontroltests "github.com/Authula/authula/plugins/access-control/tests" + "github.com/Authula/authula/plugins/access-control/types" + "github.com/Authula/authula/plugins/access-control/usecases" +) + +func TestCreateRoleHandler(t *testing.T) { + t.Parallel() + + fixedTime := time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC) + description := new(string) + *description = "platform administrator" + + tests := []struct { + name string + body []byte + setupMock func(*accesscontroltests.MockRolesRepository) + expectedStatus int + expectedBody any + }{ + { + name: "invalid request body", + body: []byte("{invalid json"), + expectedStatus: http.StatusUnprocessableEntity, + expectedBody: map[string]string{"message": "invalid request body"}, + }, + { + name: "service error", + body: internaltests.MarshalToJSON(t, types.CreateRoleRequest{ + Name: "Administrator", + Description: description, + IsSystem: true, + }), + setupMock: func(m *accesscontroltests.MockRolesRepository) { + m.On("CreateRole", mock.Anything, mock.MatchedBy(func(role *types.Role) bool { + return role != nil && role.Name == "Administrator" && role.Description != nil && *role.Description == *description && role.IsSystem && role.ID != "" + })).Return(constants.ErrUnauthorized).Once() + }, + expectedStatus: http.StatusUnauthorized, + expectedBody: map[string]string{"message": "unauthorized"}, + }, + { + name: "success", + body: internaltests.MarshalToJSON(t, types.CreateRoleRequest{ + Name: "Administrator", + Description: description, + IsSystem: true, + }), + setupMock: func(m *accesscontroltests.MockRolesRepository) { + m.On("CreateRole", mock.Anything, mock.MatchedBy(func(role *types.Role) bool { + return role != nil && role.Name == "Administrator" && role.Description != nil && *role.Description == *description && role.IsSystem && role.ID != "" + })).Run(func(args mock.Arguments) { + role := args.Get(1).(*types.Role) + role.ID = "role-1" + role.CreatedAt = fixedTime + role.UpdatedAt = fixedTime + }).Return(nil).Once() + }, + expectedStatus: http.StatusCreated, + expectedBody: types.CreateRoleResponse{ + Role: &types.Role{ + ID: "role-1", + Name: "Administrator", + Description: description, + IsSystem: true, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + rolesRepo := &accesscontroltests.MockRolesRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} + userRolesRepo := &accesscontroltests.MockUserRolesRepository{} + if tc.setupMock != nil { + tc.setupMock(rolesRepo) + } + + useCase := newRolesUseCase(rolesRepo, rolePermissionsRepo, userRolesRepo) + handler := NewCreateRoleHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPost, "/roles", tc.body, nil) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusCreated { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + rolesRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[types.CreateRoleResponse](t, reqCtx) + assertCreateRoleResponseEqual(t, payload, tc.expectedBody.(types.CreateRoleResponse)) + + rolesRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) + }) + } +} + +func TestGetAllRolesHandler(t *testing.T) { + t.Parallel() + + fixedTime := time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC) + description := new(string) + *description = "platform administrator" + + tests := []struct { + name string + setupMock func(*accesscontroltests.MockRolesRepository) + expectedStatus int + expectedBody any + }{ + { + name: "service error", + setupMock: func(m *accesscontroltests.MockRolesRepository) { + m.On("GetAllRoles", mock.Anything).Return(([]types.Role)(nil), constants.ErrForbidden).Once() + }, + expectedStatus: http.StatusForbidden, + expectedBody: map[string]string{"message": "forbidden"}, + }, + { + name: "success", + setupMock: func(m *accesscontroltests.MockRolesRepository) { + m.On("GetAllRoles", mock.Anything).Return([]types.Role{ + { + ID: "role-1", + Name: "Administrator", + Description: description, + IsSystem: true, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, + { + ID: "role-2", + Name: "Editor", + Description: nil, + IsSystem: false, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, + }, nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: []types.Role{ + { + ID: "role-1", + Name: "Administrator", + Description: description, + IsSystem: true, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, + { + ID: "role-2", + Name: "Editor", + Description: nil, + IsSystem: false, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + rolesRepo := &accesscontroltests.MockRolesRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} + userRolesRepo := &accesscontroltests.MockUserRolesRepository{} + if tc.setupMock != nil { + tc.setupMock(rolesRepo) + } + + useCase := newRolesUseCase(rolesRepo, rolePermissionsRepo, userRolesRepo) + handler := NewGetAllRolesHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/roles", nil, nil) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + rolesRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[[]types.Role](t, reqCtx) + assertRolesEqual(t, payload, tc.expectedBody.([]types.Role)) + + rolesRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) + }) + } +} + +func TestGetRoleByIDHandler(t *testing.T) { + t.Parallel() + + fixedTime := time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC) + description := new(string) + *description = "platform administrator" + grantedAt := fixedTime + grantedByUserID := new(string) + *grantedByUserID = "user-1" + + tests := []struct { + name string + roleID string + setupMock func(*accesscontroltests.MockRolesRepository, *accesscontroltests.MockRolePermissionsRepository) + expectedStatus int + expectedBody any + }{ + { + name: "service error", + roleID: "role-404", + setupMock: func(m *accesscontroltests.MockRolesRepository, _ *accesscontroltests.MockRolePermissionsRepository) { + m.On("GetRoleByID", mock.Anything, "role-404").Return((*types.Role)(nil), constants.ErrNotFound).Once() + }, + expectedStatus: http.StatusNotFound, + expectedBody: map[string]string{"message": "not found"}, + }, + { + name: "success", + roleID: "role-1", + setupMock: func(m *accesscontroltests.MockRolesRepository, rp *accesscontroltests.MockRolePermissionsRepository) { + m.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ + ID: "role-1", + Name: "Administrator", + Description: description, + IsSystem: true, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, nil).Once() + rp.On("GetRolePermissions", mock.Anything, "role-1").Return([]types.UserPermissionInfo{ + { + PermissionID: "perm-1", + PermissionKey: "users.read", + PermissionDescription: description, + GrantedByUserID: grantedByUserID, + GrantedAt: &grantedAt, + }, + }, nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: types.RoleDetails{ + Role: types.Role{ + ID: "role-1", + Name: "Administrator", + Description: description, + IsSystem: true, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, + Permissions: []types.UserPermissionInfo{ + { + PermissionID: "perm-1", + PermissionKey: "users.read", + PermissionDescription: description, + GrantedByUserID: grantedByUserID, + GrantedAt: &grantedAt, + }, + }, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + rolesRepo := &accesscontroltests.MockRolesRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} + userRolesRepo := &accesscontroltests.MockUserRolesRepository{} + if tc.setupMock != nil { + tc.setupMock(rolesRepo, rolePermissionsRepo) + } + + useCase := newRolesUseCase(rolesRepo, rolePermissionsRepo, userRolesRepo) + handler := NewGetRoleByIDHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/roles/"+tc.roleID, nil, nil) + req.SetPathValue("role_id", tc.roleID) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + rolesRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[types.RoleDetails](t, reqCtx) + assertRoleDetailsEqual(t, payload, tc.expectedBody.(types.RoleDetails)) + + rolesRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) + }) + } +} + +func TestGetRoleByNameHandler(t *testing.T) { + t.Parallel() + + fixedTime := time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC) + description := new(string) + *description = "platform administrator" + + tests := []struct { + name string + roleName string + setupMock func(*accesscontroltests.MockRolesRepository) + expectedStatus int + expectedBody any + }{ + { + name: "service error", + roleName: "missing", + setupMock: func(m *accesscontroltests.MockRolesRepository) { + m.On("GetRoleByName", mock.Anything, "missing").Return((*types.Role)(nil), constants.ErrNotFound).Once() + }, + expectedStatus: http.StatusNotFound, + expectedBody: map[string]string{"message": "not found"}, + }, + { + name: "success", + roleName: "Administrator", + setupMock: func(m *accesscontroltests.MockRolesRepository) { + m.On("GetRoleByName", mock.Anything, "Administrator").Return(&types.Role{ + ID: "role-1", + Name: "Administrator", + Description: description, + IsSystem: true, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: types.Role{ + ID: "role-1", + Name: "Administrator", + Description: description, + IsSystem: true, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + rolesRepo := &accesscontroltests.MockRolesRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} + userRolesRepo := &accesscontroltests.MockUserRolesRepository{} + if tc.setupMock != nil { + tc.setupMock(rolesRepo) + } + + useCase := newRolesUseCase(rolesRepo, rolePermissionsRepo, userRolesRepo) + handler := NewGetRoleByNameHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/roles/by-name/"+tc.roleName, nil, nil) + req.SetPathValue("role_name", tc.roleName) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + rolesRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[types.Role](t, reqCtx) + assertRoleEqual(t, payload, tc.expectedBody.(types.Role)) + + rolesRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) + }) + } +} + +func TestUpdateRoleHandler(t *testing.T) { + t.Parallel() + + fixedTime := time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC) + description := new(string) + *description = "updated administrator" + updatedName := "Administrator" + + tests := []struct { + name string + body []byte + setupMock func(*accesscontroltests.MockRolesRepository) + expectedStatus int + expectedBody any + }{ + { + name: "invalid request body", + body: []byte("{invalid json"), + expectedStatus: http.StatusUnprocessableEntity, + expectedBody: map[string]string{"message": "invalid request body"}, + }, + { + name: "service error", + body: internaltests.MarshalToJSON(t, types.UpdateRoleRequest{ + Name: &updatedName, + Description: description, + }), + setupMock: func(m *accesscontroltests.MockRolesRepository) { + m.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "Administrator", IsSystem: true}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedBody: map[string]string{"message": "cannot update system role"}, + }, + { + name: "success", + body: internaltests.MarshalToJSON(t, types.UpdateRoleRequest{ + Name: &updatedName, + Description: description, + }), + setupMock: func(m *accesscontroltests.MockRolesRepository) { + m.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "Administrator", Description: description, IsSystem: false}, nil).Once() + m.On("UpdateRole", mock.Anything, "role-1", mock.MatchedBy(func(name *string) bool { + return name != nil && *name == "Administrator" + }), mock.MatchedBy(func(desc *string) bool { + return desc != nil && *desc == *description + })).Return(true, nil).Once() + m.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ + ID: "role-1", + Name: "Administrator", + Description: description, + IsSystem: false, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: types.UpdateRoleResponse{ + Role: &types.Role{ + ID: "role-1", + Name: "Administrator", + Description: description, + IsSystem: false, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + rolesRepo := &accesscontroltests.MockRolesRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} + userRolesRepo := &accesscontroltests.MockUserRolesRepository{} + if tc.setupMock != nil { + tc.setupMock(rolesRepo) + } + + useCase := newRolesUseCase(rolesRepo, rolePermissionsRepo, userRolesRepo) + handler := NewUpdateRoleHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPut, "/roles/role-1", tc.body, nil) + req.SetPathValue("role_id", "role-1") + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + rolesRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[types.UpdateRoleResponse](t, reqCtx) + assertUpdateRoleResponseEqual(t, payload, tc.expectedBody.(types.UpdateRoleResponse)) + + rolesRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) + }) + } +} + +func TestDeleteRoleHandler(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + roleID string + setupMock func(*accesscontroltests.MockRolesRepository) + expectedStatus int + expectedBody any + }{ + { + name: "service error", + roleID: "role-1", + setupMock: func(m *accesscontroltests.MockRolesRepository) { + m.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "Administrator", IsSystem: true}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedBody: map[string]string{"message": "cannot update system role"}, + }, + { + name: "success", + roleID: "role-1", + setupMock: func(m *accesscontroltests.MockRolesRepository) { + m.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "Administrator", IsSystem: false}, nil).Once() + m.On("DeleteRole", mock.Anything, "role-1").Return(true, nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: types.DeleteRoleResponse{Message: "deleted role"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + rolesRepo := &accesscontroltests.MockRolesRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} + userRolesRepo := &accesscontroltests.MockUserRolesRepository{} + if tc.setupMock != nil { + tc.setupMock(rolesRepo) + } + if tc.roleID == "role-1" && tc.expectedStatus == http.StatusOK { + userRolesRepo.On("CountUsersByRole", mock.Anything, "role-1").Return(0, nil).Once() + } + if tc.roleID == "role-1" && tc.expectedStatus == http.StatusForbidden { + userRolesRepo.On("CountUsersByRole", mock.Anything, "role-1").Return(0, nil).Maybe() + } + + useCase := newRolesUseCase(rolesRepo, rolePermissionsRepo, userRolesRepo) + handler := NewDeleteRoleHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodDelete, "/roles/"+tc.roleID, nil, nil) + req.SetPathValue("role_id", tc.roleID) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + rolesRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[types.DeleteRoleResponse](t, reqCtx) + assertDeleteRoleResponseEqual(t, payload, tc.expectedBody.(types.DeleteRoleResponse)) + + rolesRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) + }) + } +} + +func newRolesUseCase(rolesRepo *accesscontroltests.MockRolesRepository, rolePermissionsRepo *accesscontroltests.MockRolePermissionsRepository, userRolesRepo *accesscontroltests.MockUserRolesRepository) *usecases.RolesUseCase { + return usecases.NewRolesUseCase(services.NewRolesService(rolesRepo, rolePermissionsRepo, userRolesRepo)) +} + +func assertRolesEqual(t *testing.T, got []types.Role, want []types.Role) { + t.Helper() + + if len(got) != len(want) { + t.Fatalf("expected %d roles, got %d", len(want), len(got)) + } + + for i := range want { + assertRoleEqual(t, got[i], want[i]) + } +} + +func assertRoleEqual(t *testing.T, got types.Role, want types.Role) { + t.Helper() + + if got.ID != want.ID { + t.Fatalf("expected id %q, got %q", want.ID, got.ID) + } + if got.Name != want.Name { + t.Fatalf("expected name %q, got %q", want.Name, got.Name) + } + if got.IsSystem != want.IsSystem { + t.Fatalf("expected is_system %v, got %v", want.IsSystem, got.IsSystem) + } + if !timesEqual(got.CreatedAt, want.CreatedAt) { + t.Fatalf("expected created_at %v, got %v", want.CreatedAt, got.CreatedAt) + } + if !timesEqual(got.UpdatedAt, want.UpdatedAt) { + t.Fatalf("expected updated_at %v, got %v", want.UpdatedAt, got.UpdatedAt) + } + if !stringsEqualPtr(got.Description, want.Description) { + t.Fatalf("expected description %#v, got %#v", want.Description, got.Description) + } +} + +func assertRoleDetailsEqual(t *testing.T, got types.RoleDetails, want types.RoleDetails) { + t.Helper() + + assertRoleEqual(t, got.Role, want.Role) + if len(got.Permissions) != len(want.Permissions) { + t.Fatalf("expected %d permissions, got %d", len(want.Permissions), len(got.Permissions)) + } + for i := range want.Permissions { + gotPerm := got.Permissions[i] + wantPerm := want.Permissions[i] + if gotPerm.PermissionID != wantPerm.PermissionID { + t.Fatalf("expected permission id %q, got %q", wantPerm.PermissionID, gotPerm.PermissionID) + } + if gotPerm.PermissionKey != wantPerm.PermissionKey { + t.Fatalf("expected permission key %q, got %q", wantPerm.PermissionKey, gotPerm.PermissionKey) + } + if !stringsEqualPtr(gotPerm.PermissionDescription, wantPerm.PermissionDescription) { + t.Fatalf("expected permission description %#v, got %#v", wantPerm.PermissionDescription, gotPerm.PermissionDescription) + } + if !stringsEqualPtr(gotPerm.GrantedByUserID, wantPerm.GrantedByUserID) { + t.Fatalf("expected granted_by_user_id %#v, got %#v", wantPerm.GrantedByUserID, gotPerm.GrantedByUserID) + } + if !timesEqualPtr(gotPerm.GrantedAt, wantPerm.GrantedAt) { + t.Fatalf("expected granted_at %#v, got %#v", wantPerm.GrantedAt, gotPerm.GrantedAt) + } + } +} + +func assertCreateRoleResponseEqual(t *testing.T, got types.CreateRoleResponse, want types.CreateRoleResponse) { + t.Helper() + + if got.Role == nil || want.Role == nil { + if got.Role != want.Role { + t.Fatalf("expected role %#v, got %#v", want.Role, got.Role) + } + return + } + + assertRoleEqual(t, *got.Role, *want.Role) +} + +func assertUpdateRoleResponseEqual(t *testing.T, got types.UpdateRoleResponse, want types.UpdateRoleResponse) { + t.Helper() + + if got.Role == nil || want.Role == nil { + if got.Role != want.Role { + t.Fatalf("expected role %#v, got %#v", want.Role, got.Role) + } + return + } + + assertRoleEqual(t, *got.Role, *want.Role) +} + +func assertDeleteRoleResponseEqual(t *testing.T, got types.DeleteRoleResponse, want types.DeleteRoleResponse) { + t.Helper() + + if got.Message != want.Message { + t.Fatalf("expected message %q, got %q", want.Message, got.Message) + } +} + +func timesEqualPtr(left, right *time.Time) bool { + if left == nil || right == nil { + return left == right + } + return left.Equal(*right) +} diff --git a/plugins/access-control/handlers/role_permission_handlers.go b/plugins/access-control/handlers/role_permission_handlers.go index 8b8eb27..60c28db 100644 --- a/plugins/access-control/handlers/role_permission_handlers.go +++ b/plugins/access-control/handlers/role_permission_handlers.go @@ -9,171 +9,11 @@ import ( "github.com/Authula/authula/plugins/access-control/usecases" ) -type CreateRoleHandler struct { - useCase usecases.RolePermissionUseCase -} - -func NewCreateRoleHandler(useCase usecases.RolePermissionUseCase) *CreateRoleHandler { - return &CreateRoleHandler{useCase: useCase} -} - -func (h *CreateRoleHandler) Handler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - reqCtx, _ := models.GetRequestContext(ctx) - - var payload types.CreateRoleRequest - if err := util.ParseJSON(r, &payload); err != nil { - reqCtx.SetJSONResponse(http.StatusUnprocessableEntity, map[string]any{"message": "invalid request body"}) - reqCtx.Handled = true - return - } - - role, err := h.useCase.CreateRole(r.Context(), payload) - if err != nil { - respondRolePermissionError(reqCtx, err) - return - } - - reqCtx.SetJSONResponse(http.StatusCreated, &types.CreateRoleResponse{ - Role: role, - }) - } -} - -type GetAllRolesHandler struct { - useCase usecases.RolePermissionUseCase -} - -func NewGetAllRolesHandler(useCase usecases.RolePermissionUseCase) *GetAllRolesHandler { - return &GetAllRolesHandler{useCase: useCase} -} - -func (h *GetAllRolesHandler) Handler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - reqCtx, _ := models.GetRequestContext(ctx) - - roles, err := h.useCase.GetAllRoles(r.Context()) - if err != nil { - respondRolePermissionError(reqCtx, err) - return - } - - reqCtx.SetJSONResponse(http.StatusOK, roles) - } -} - -type GetRoleByIDHandler struct { - useCase usecases.RolePermissionUseCase -} - -func NewGetRoleByIDHandler(useCase usecases.RolePermissionUseCase) *GetRoleByIDHandler { - return &GetRoleByIDHandler{useCase: useCase} -} - -func (h *GetRoleByIDHandler) Handler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - reqCtx, _ := models.GetRequestContext(ctx) - roleID := r.PathValue("role_id") - - roleDetails, err := h.useCase.GetRoleByID(r.Context(), roleID) - if err != nil { - respondRolePermissionError(reqCtx, err) - return - } - - reqCtx.SetJSONResponse(http.StatusOK, roleDetails) - } -} - -type UpdateRoleHandler struct { - useCase usecases.RolePermissionUseCase -} - -func NewUpdateRoleHandler(useCase usecases.RolePermissionUseCase) *UpdateRoleHandler { - return &UpdateRoleHandler{useCase: useCase} -} - -func (h *UpdateRoleHandler) Handler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - reqCtx, _ := models.GetRequestContext(ctx) - roleID := r.PathValue("role_id") - - var payload types.UpdateRoleRequest - if err := util.ParseJSON(r, &payload); err != nil { - reqCtx.SetJSONResponse(http.StatusUnprocessableEntity, map[string]any{"message": "invalid request body"}) - reqCtx.Handled = true - return - } - - role, err := h.useCase.UpdateRole(r.Context(), roleID, payload) - if err != nil { - respondRolePermissionError(reqCtx, err) - return - } - - reqCtx.SetJSONResponse(http.StatusOK, &types.UpdateRoleResponse{ - Role: role, - }) - } -} - -type DeleteRoleHandler struct { - useCase usecases.RolePermissionUseCase -} - -func NewDeleteRoleHandler(useCase usecases.RolePermissionUseCase) *DeleteRoleHandler { - return &DeleteRoleHandler{useCase: useCase} -} - -func (h *DeleteRoleHandler) Handler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - reqCtx, _ := models.GetRequestContext(ctx) - roleID := r.PathValue("role_id") - - if err := h.useCase.DeleteRole(r.Context(), roleID); err != nil { - respondRolePermissionError(reqCtx, err) - return - } - - reqCtx.SetJSONResponse(http.StatusOK, &types.DeleteRoleResponse{ - Message: "deleted role", - }) - } -} - -type GetAllPermissionsHandler struct { - useCase usecases.RolePermissionUseCase -} - -func NewGetAllPermissionsHandler(useCase usecases.RolePermissionUseCase) *GetAllPermissionsHandler { - return &GetAllPermissionsHandler{useCase: useCase} -} - -func (h *GetAllPermissionsHandler) Handler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - reqCtx, _ := models.GetRequestContext(ctx) - - permissions, err := h.useCase.GetAllPermissions(r.Context()) - if err != nil { - respondRolePermissionError(reqCtx, err) - return - } - - reqCtx.SetJSONResponse(http.StatusOK, permissions) - } -} - type GetRolePermissionsHandler struct { - useCase usecases.RolePermissionUseCase + useCase *usecases.RolePermissionsUseCase } -func NewGetRolePermissionsHandler(useCase usecases.RolePermissionUseCase) *GetRolePermissionsHandler { +func NewGetRolePermissionsHandler(useCase *usecases.RolePermissionsUseCase) *GetRolePermissionsHandler { return &GetRolePermissionsHandler{useCase: useCase} } @@ -193,101 +33,11 @@ func (h *GetRolePermissionsHandler) Handler() http.HandlerFunc { } } -type CreatePermissionHandler struct { - useCase usecases.RolePermissionUseCase -} - -func NewCreatePermissionHandler(useCase usecases.RolePermissionUseCase) *CreatePermissionHandler { - return &CreatePermissionHandler{useCase: useCase} -} - -func (h *CreatePermissionHandler) Handler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - reqCtx, _ := models.GetRequestContext(ctx) - - var payload types.CreatePermissionRequest - if err := util.ParseJSON(r, &payload); err != nil { - reqCtx.SetJSONResponse(http.StatusUnprocessableEntity, map[string]any{"message": "invalid request body"}) - reqCtx.Handled = true - return - } - - permission, err := h.useCase.CreatePermission(r.Context(), payload) - if err != nil { - respondRolePermissionError(reqCtx, err) - return - } - - reqCtx.SetJSONResponse(http.StatusCreated, &types.CreatePermissionResponse{ - Permission: permission, - }) - } -} - -type UpdatePermissionHandler struct { - useCase usecases.RolePermissionUseCase -} - -func NewUpdatePermissionHandler(useCase usecases.RolePermissionUseCase) *UpdatePermissionHandler { - return &UpdatePermissionHandler{useCase: useCase} -} - -func (h *UpdatePermissionHandler) Handler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - reqCtx, _ := models.GetRequestContext(ctx) - permissionID := r.PathValue("permission_id") - - var payload types.UpdatePermissionRequest - if err := util.ParseJSON(r, &payload); err != nil { - reqCtx.SetJSONResponse(http.StatusUnprocessableEntity, map[string]any{"message": "invalid request body"}) - reqCtx.Handled = true - return - } - - permission, err := h.useCase.UpdatePermission(r.Context(), permissionID, payload) - if err != nil { - respondRolePermissionError(reqCtx, err) - return - } - - reqCtx.SetJSONResponse(http.StatusOK, &types.UpdatePermissionResponse{ - Permission: permission, - }) - } -} - -type DeletePermissionHandler struct { - useCase usecases.RolePermissionUseCase -} - -func NewDeletePermissionHandler(useCase usecases.RolePermissionUseCase) *DeletePermissionHandler { - return &DeletePermissionHandler{useCase: useCase} -} - -func (h *DeletePermissionHandler) Handler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - reqCtx, _ := models.GetRequestContext(ctx) - permissionID := r.PathValue("permission_id") - - if err := h.useCase.DeletePermission(r.Context(), permissionID); err != nil { - respondRolePermissionError(reqCtx, err) - return - } - - reqCtx.SetJSONResponse(http.StatusOK, &types.DeletePermissionResponse{ - Message: "permission deleted", - }) - } -} - type AddRolePermissionHandler struct { - useCase usecases.RolePermissionUseCase + useCase *usecases.RolePermissionsUseCase } -func NewAddRolePermissionHandler(useCase usecases.RolePermissionUseCase) *AddRolePermissionHandler { +func NewAddRolePermissionHandler(useCase *usecases.RolePermissionsUseCase) *AddRolePermissionHandler { return &AddRolePermissionHandler{useCase: useCase} } @@ -316,10 +66,10 @@ func (h *AddRolePermissionHandler) Handler() http.HandlerFunc { } type ReplaceRolePermissionsHandler struct { - useCase usecases.RolePermissionUseCase + useCase *usecases.RolePermissionsUseCase } -func NewReplaceRolePermissionsHandler(useCase usecases.RolePermissionUseCase) *ReplaceRolePermissionsHandler { +func NewReplaceRolePermissionsHandler(useCase *usecases.RolePermissionsUseCase) *ReplaceRolePermissionsHandler { return &ReplaceRolePermissionsHandler{useCase: useCase} } @@ -348,10 +98,10 @@ func (h *ReplaceRolePermissionsHandler) Handler() http.HandlerFunc { } type RemoveRolePermissionHandler struct { - useCase usecases.RolePermissionUseCase + useCase *usecases.RolePermissionsUseCase } -func NewRemoveRolePermissionHandler(useCase usecases.RolePermissionUseCase) *RemoveRolePermissionHandler { +func NewRemoveRolePermissionHandler(useCase *usecases.RolePermissionsUseCase) *RemoveRolePermissionHandler { return &RemoveRolePermissionHandler{useCase: useCase} } @@ -372,23 +122,3 @@ func (h *RemoveRolePermissionHandler) Handler() http.HandlerFunc { }) } } - -func rolePermissionActorUserID(reqCtx *models.RequestContext) *string { - if reqCtx == nil || reqCtx.UserID == nil || *reqCtx.UserID == "" { - return nil - } - return reqCtx.UserID -} - -func respondRolePermissionError(reqCtx *models.RequestContext, err error) { - if reqCtx == nil { - return - } - - reqCtx.SetJSONResponse(mapRolePermissionErrorStatus(err), map[string]any{"message": mapHttpErrorMessage(err)}) - reqCtx.Handled = true -} - -func mapRolePermissionErrorStatus(err error) int { - return mapHttpErrorStatus(err) -} diff --git a/plugins/access-control/handlers/role_permission_handlers_test.go b/plugins/access-control/handlers/role_permission_handlers_test.go index aaded8e..10d6f1d 100644 --- a/plugins/access-control/handlers/role_permission_handlers_test.go +++ b/plugins/access-control/handlers/role_permission_handlers_test.go @@ -1,660 +1,451 @@ package handlers import ( - "errors" "net/http" "testing" + "time" "github.com/stretchr/testify/mock" internaltests "github.com/Authula/authula/internal/tests" - accesscontrolconstants "github.com/Authula/authula/plugins/access-control/constants" - "github.com/Authula/authula/plugins/access-control/tests" + "github.com/Authula/authula/plugins/access-control/constants" + "github.com/Authula/authula/plugins/access-control/services" + accesscontroltests "github.com/Authula/authula/plugins/access-control/tests" "github.com/Authula/authula/plugins/access-control/types" + "github.com/Authula/authula/plugins/access-control/usecases" ) -func TestCreateRoleHandler(t *testing.T) { +func TestGetRolePermissionsHandler(t *testing.T) { t.Parallel() - t.Run("invalid request body", func(t *testing.T) { - t.Parallel() - - useCase, _ := tests.NewRolePermissionUseCaseFixture() - handler := NewCreateRoleHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPost, "/access-control/roles", []byte("{invalid"), nil) - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusUnprocessableEntity, "invalid request body") - }) - - t.Run("use case error", func(t *testing.T) { - t.Parallel() - - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("CreateRole", mock.Anything, mock.AnythingOfType("*types.Role")).Return(accesscontrolconstants.ErrConflict).Once() - handler := NewCreateRoleHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPost, "/access-control/roles", internaltests.MarshalToJSON(t, types.CreateRoleRequest{Name: "admin"}), nil) - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusConflict, "conflict") - repo.AssertExpectations(t) - }) - - t.Run("success", func(t *testing.T) { - t.Parallel() - - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("CreateRole", mock.Anything, mock.AnythingOfType("*types.Role")).Return(nil).Once() - handler := NewCreateRoleHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPost, "/access-control/roles", internaltests.MarshalToJSON(t, types.CreateRoleRequest{Name: "admin"}), nil) - - handler.Handler()(w, req) - - if reqCtx.ResponseStatus != http.StatusCreated { - t.Fatalf("expected status %d, got %d", http.StatusCreated, reqCtx.ResponseStatus) - } - payload := internaltests.DecodeResponseJSON[types.CreateRoleResponse](t, reqCtx) - if payload.Role == nil { - t.Fatal("expected role key, got nil") - } - repo.AssertExpectations(t) - }) + fixedTime := time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC) + description := new(string) + *description = "read access" + grantedByUserID := new(string) + *grantedByUserID = "user-1" + + tests := []struct { + name string + roleID string + setupMock func(*accesscontroltests.MockRolesRepository, *accesscontroltests.MockRolePermissionsRepository) + expectedStatus int + expectedBody any + }{ + { + name: "blank role id", + roleID: "", + expectedStatus: http.StatusUnprocessableEntity, + expectedBody: map[string]string{"message": "unprocessable entity"}, + }, + { + name: "use case error", + roleID: "role-404", + setupMock: func(rolesRepo *accesscontroltests.MockRolesRepository, _ *accesscontroltests.MockRolePermissionsRepository) { + rolesRepo.On("GetRoleByID", mock.Anything, "role-404").Return((*types.Role)(nil), constants.ErrNotFound).Once() + }, + expectedStatus: http.StatusNotFound, + expectedBody: map[string]string{"message": "not found"}, + }, + { + name: "success", + roleID: "role-1", + setupMock: func(rolesRepo *accesscontroltests.MockRolesRepository, rolePermissionsRepo *accesscontroltests.MockRolePermissionsRepository) { + rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "Administrator"}, nil).Once() + rolePermissionsRepo.On("GetRolePermissions", mock.Anything, "role-1").Return([]types.UserPermissionInfo{{ + PermissionID: "perm-1", + PermissionKey: "users.read", + PermissionDescription: description, + GrantedByUserID: grantedByUserID, + GrantedAt: &fixedTime, + }}, nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: []types.UserPermissionInfo{{ + PermissionID: "perm-1", + PermissionKey: "users.read", + PermissionDescription: description, + GrantedByUserID: grantedByUserID, + GrantedAt: &fixedTime, + }}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + rolesRepo := &accesscontroltests.MockRolesRepository{} + permissionsRepo := &accesscontroltests.MockPermissionsRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} + if tc.setupMock != nil { + tc.setupMock(rolesRepo, rolePermissionsRepo) + } + + useCase := newRolePermissionsUseCase(rolesRepo, permissionsRepo, rolePermissionsRepo) + handler := NewGetRolePermissionsHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/roles/"+tc.roleID+"/permissions", nil, nil) + req.SetPathValue("role_id", tc.roleID) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + rolesRepo.AssertExpectations(t) + permissionsRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[[]types.UserPermissionInfo](t, reqCtx) + assertUserPermissionInfosEqual(t, payload, tc.expectedBody.([]types.UserPermissionInfo)) + + rolesRepo.AssertExpectations(t) + permissionsRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + }) + } } -func TestGetAllRolesHandler(t *testing.T) { +func TestAddRolePermissionHandler(t *testing.T) { t.Parallel() - t.Run("use case error", func(t *testing.T) { - t.Parallel() - - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetAllRoles", mock.Anything).Return(([]types.Role)(nil), errors.New("internal error")).Once() - handler := NewGetAllRolesHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/access-control/roles", nil, nil) - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusInternalServerError, "internal error") - repo.AssertExpectations(t) - }) - - t.Run("success", func(t *testing.T) { - t.Parallel() - - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetAllRoles", mock.Anything).Return([]types.Role{{ID: "role-1", Name: "admin"}}, nil).Once() - handler := NewGetAllRolesHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/access-control/roles", nil, nil) - - handler.Handler()(w, req) - - if reqCtx.ResponseStatus != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, reqCtx.ResponseStatus) - } - payload := internaltests.DecodeResponseJSON[[]types.Role](t, reqCtx) - if payload == nil { - t.Fatalf("expected roles key, got %v", payload) - } - repo.AssertExpectations(t) - }) + grantedByUserID := "user-1" + actorUserID := &grantedByUserID + + tests := []struct { + name string + roleID string + body []byte + userID *string + setupMock func(*accesscontroltests.MockRolesRepository, *accesscontroltests.MockPermissionsRepository, *accesscontroltests.MockRolePermissionsRepository) + expectedStatus int + expectedBody any + }{ + { + name: "invalid request body", + roleID: "role-1", + body: []byte("{invalid json"), + expectedStatus: http.StatusUnprocessableEntity, + expectedBody: map[string]string{"message": "invalid request body"}, + }, + { + name: "use case error", + roleID: "role-1", + body: internaltests.MarshalToJSON(t, types.AddRolePermissionRequest{PermissionID: "perm-1"}), + setupMock: func(rolesRepo *accesscontroltests.MockRolesRepository, permissionsRepo *accesscontroltests.MockPermissionsRepository, rolePermissionsRepo *accesscontroltests.MockRolePermissionsRepository) { + rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "Administrator"}, nil).Once() + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read"}, nil).Once() + rolePermissionsRepo.On("AddRolePermission", mock.Anything, "role-1", "perm-1", (*string)(nil)).Return(constants.ErrUnauthorized).Once() + }, + expectedStatus: http.StatusUnauthorized, + expectedBody: map[string]string{"message": "unauthorized"}, + }, + { + name: "success", + roleID: "role-1", + body: internaltests.MarshalToJSON(t, types.AddRolePermissionRequest{PermissionID: "perm-1"}), + setupMock: func(rolesRepo *accesscontroltests.MockRolesRepository, permissionsRepo *accesscontroltests.MockPermissionsRepository, rolePermissionsRepo *accesscontroltests.MockRolePermissionsRepository) { + rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "Administrator"}, nil).Once() + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read"}, nil).Once() + rolePermissionsRepo.On("AddRolePermission", mock.Anything, "role-1", "perm-1", (*string)(nil)).Return(nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: types.AddRolePermissionResponse{Message: "permission assigned to role"}, + }, + { + name: "success with actor user id", + roleID: "role-1", + body: internaltests.MarshalToJSON(t, types.AddRolePermissionRequest{PermissionID: "perm-2"}), + userID: actorUserID, + setupMock: func(rolesRepo *accesscontroltests.MockRolesRepository, permissionsRepo *accesscontroltests.MockPermissionsRepository, rolePermissionsRepo *accesscontroltests.MockRolePermissionsRepository) { + rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "Administrator"}, nil).Once() + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-2").Return(&types.Permission{ID: "perm-2", Key: "users.write"}, nil).Once() + rolePermissionsRepo.On("AddRolePermission", mock.Anything, "role-1", "perm-2", actorUserID).Return(nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: types.AddRolePermissionResponse{Message: "permission assigned to role"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + rolesRepo := &accesscontroltests.MockRolesRepository{} + permissionsRepo := &accesscontroltests.MockPermissionsRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} + if tc.setupMock != nil { + tc.setupMock(rolesRepo, permissionsRepo, rolePermissionsRepo) + } + + useCase := newRolePermissionsUseCase(rolesRepo, permissionsRepo, rolePermissionsRepo) + handler := NewAddRolePermissionHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPost, "/roles/"+tc.roleID+"/permissions", tc.body, tc.userID) + req.SetPathValue("role_id", tc.roleID) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + rolesRepo.AssertExpectations(t) + permissionsRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[types.AddRolePermissionResponse](t, reqCtx) + assertAddRolePermissionResponseEqual(t, payload, tc.expectedBody.(types.AddRolePermissionResponse)) + + rolesRepo.AssertExpectations(t) + permissionsRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + }) + } } -func TestGetRoleByIDHandler(t *testing.T) { +func TestReplaceRolePermissionsHandler(t *testing.T) { t.Parallel() - t.Run("not found", func(t *testing.T) { - t.Parallel() - - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetRoleByID", mock.Anything, "role-1").Return((*types.Role)(nil), nil).Once() - handler := NewGetRoleByIDHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/access-control/roles/role-1", nil, nil) - req.SetPathValue("role_id", "role-1") - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusNotFound, "not found") - repo.AssertExpectations(t) - }) - - t.Run("success", func(t *testing.T) { - t.Parallel() - - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "admin"}, nil).Once() - repo.On("GetRolePermissions", mock.Anything, "role-1").Return([]types.UserPermissionInfo{{PermissionID: "perm-1", PermissionKey: "users.read"}}, nil).Once() - handler := NewGetRoleByIDHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/access-control/roles/role-1", nil, nil) - req.SetPathValue("role_id", "role-1") - - handler.Handler()(w, req) - - if reqCtx.ResponseStatus != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, reqCtx.ResponseStatus) - } - payload := internaltests.DecodeResponseJSON[*types.RoleDetails](t, reqCtx) - if payload == nil { - t.Fatalf("expected role details, got %v", payload) - } - repo.AssertExpectations(t) - }) + actorUserID := "user-1" + actorUserIDPtr := &actorUserID + + tests := []struct { + name string + roleID string + body []byte + userID *string + setupMock func(*accesscontroltests.MockRolesRepository, *accesscontroltests.MockPermissionsRepository, *accesscontroltests.MockRolePermissionsRepository) + expectedStatus int + expectedBody any + }{ + { + name: "invalid request body", + roleID: "role-1", + body: []byte("{invalid json"), + expectedStatus: http.StatusUnprocessableEntity, + expectedBody: map[string]string{"message": "invalid request body"}, + }, + { + name: "use case error", + roleID: "role-1", + body: internaltests.MarshalToJSON(t, types.ReplaceRolePermissionsRequest{PermissionIDs: []string{"perm-1", "perm-2"}}), + setupMock: func(rolesRepo *accesscontroltests.MockRolesRepository, permissionsRepo *accesscontroltests.MockPermissionsRepository, rolePermissionsRepo *accesscontroltests.MockRolePermissionsRepository) { + rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "Administrator"}, nil).Once() + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read"}, nil).Once() + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-2").Return(&types.Permission{ID: "perm-2", Key: "users.write"}, nil).Once() + rolePermissionsRepo.On("ReplaceRolePermissions", mock.Anything, "role-1", []string{"perm-1", "perm-2"}, (*string)(nil)).Return(constants.ErrConflict).Once() + }, + expectedStatus: http.StatusConflict, + expectedBody: map[string]string{"message": "conflict"}, + }, + { + name: "success", + roleID: "role-1", + body: internaltests.MarshalToJSON(t, types.ReplaceRolePermissionsRequest{PermissionIDs: []string{"perm-1", "perm-2"}}), + setupMock: func(rolesRepo *accesscontroltests.MockRolesRepository, permissionsRepo *accesscontroltests.MockPermissionsRepository, rolePermissionsRepo *accesscontroltests.MockRolePermissionsRepository) { + rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "Administrator"}, nil).Once() + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read"}, nil).Once() + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-2").Return(&types.Permission{ID: "perm-2", Key: "users.write"}, nil).Once() + rolePermissionsRepo.On("ReplaceRolePermissions", mock.Anything, "role-1", []string{"perm-1", "perm-2"}, (*string)(nil)).Return(nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: types.ReplaceRolePermissionResponse{Message: "role permissions replaced"}, + }, + { + name: "success with actor user id", + roleID: "role-1", + body: internaltests.MarshalToJSON(t, types.ReplaceRolePermissionsRequest{PermissionIDs: []string{"perm-3", "perm-4"}}), + userID: actorUserIDPtr, + setupMock: func(rolesRepo *accesscontroltests.MockRolesRepository, permissionsRepo *accesscontroltests.MockPermissionsRepository, rolePermissionsRepo *accesscontroltests.MockRolePermissionsRepository) { + rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "Administrator"}, nil).Once() + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-3").Return(&types.Permission{ID: "perm-3", Key: "users.create"}, nil).Once() + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-4").Return(&types.Permission{ID: "perm-4", Key: "users.update"}, nil).Once() + rolePermissionsRepo.On("ReplaceRolePermissions", mock.Anything, "role-1", []string{"perm-3", "perm-4"}, actorUserIDPtr).Return(nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: types.ReplaceRolePermissionResponse{Message: "role permissions replaced"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + rolesRepo := &accesscontroltests.MockRolesRepository{} + permissionsRepo := &accesscontroltests.MockPermissionsRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} + if tc.setupMock != nil { + tc.setupMock(rolesRepo, permissionsRepo, rolePermissionsRepo) + } + + useCase := newRolePermissionsUseCase(rolesRepo, permissionsRepo, rolePermissionsRepo) + handler := NewReplaceRolePermissionsHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPut, "/roles/"+tc.roleID+"/permissions", tc.body, tc.userID) + req.SetPathValue("role_id", tc.roleID) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + rolesRepo.AssertExpectations(t) + permissionsRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[types.ReplaceRolePermissionResponse](t, reqCtx) + assertReplaceRolePermissionResponseEqual(t, payload, tc.expectedBody.(types.ReplaceRolePermissionResponse)) + + rolesRepo.AssertExpectations(t) + permissionsRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + }) + } } -func TestUpdateRoleHandler(t *testing.T) { +func TestRemoveRolePermissionHandler(t *testing.T) { t.Parallel() - t.Run("invalid payload", func(t *testing.T) { - t.Parallel() - - useCase, _ := tests.NewRolePermissionUseCaseFixture() - handler := NewUpdateRoleHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPatch, "/access-control/roles/role-1", []byte("{invalid"), nil) - req.SetPathValue("role_id", "role-1") - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusUnprocessableEntity, "invalid request body") - }) - - t.Run("use case error", func(t *testing.T) { - t.Parallel() - - name := "new-admin" - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetRoleByID", mock.Anything, "role-1").Return((*types.Role)(nil), nil).Once() - handler := NewUpdateRoleHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPatch, "/access-control/roles/role-1", internaltests.MarshalToJSON(t, types.UpdateRoleRequest{Name: &name}), nil) - req.SetPathValue("role_id", "role-1") - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusNotFound, "not found") - repo.AssertExpectations(t) - }) - - t.Run("cannot update system role", func(t *testing.T) { - t.Parallel() - - name := "new-admin" - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "admin", IsSystem: true}, nil).Once() - handler := NewUpdateRoleHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPatch, "/access-control/roles/role-1", internaltests.MarshalToJSON(t, types.UpdateRoleRequest{Name: &name}), nil) - req.SetPathValue("role_id", "role-1") - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusForbidden, "cannot update system role") - repo.AssertExpectations(t) - }) - - t.Run("success", func(t *testing.T) { - t.Parallel() - - name := "new-admin" - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "admin"}, nil).Once() - repo.On("UpdateRole", mock.Anything, "role-1", &name, (*string)(nil)).Return(true, nil).Once() - repo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: name}, nil).Once() - handler := NewUpdateRoleHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPatch, "/access-control/roles/role-1", internaltests.MarshalToJSON(t, types.UpdateRoleRequest{Name: &name}), nil) - req.SetPathValue("role_id", "role-1") - - handler.Handler()(w, req) - - if reqCtx.ResponseStatus != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, reqCtx.ResponseStatus) - } - payload := internaltests.DecodeResponseJSON[types.UpdateRoleResponse](t, reqCtx) - if payload.Role == nil { - t.Fatalf("expected role key, got %v", payload) - } - repo.AssertExpectations(t) - }) + tests := []struct { + name string + roleID string + permissionID string + setupMock func(*accesscontroltests.MockRolesRepository, *accesscontroltests.MockPermissionsRepository, *accesscontroltests.MockRolePermissionsRepository) + expectedStatus int + expectedBody any + }{ + { + name: "use case error", + roleID: "role-1", + permissionID: "perm-1", + setupMock: func(rolesRepo *accesscontroltests.MockRolesRepository, permissionsRepo *accesscontroltests.MockPermissionsRepository, rolePermissionsRepo *accesscontroltests.MockRolePermissionsRepository) { + rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "Administrator"}, nil).Once() + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read"}, nil).Once() + rolePermissionsRepo.On("RemoveRolePermission", mock.Anything, "role-1", "perm-1").Return(constants.ErrNotFound).Once() + }, + expectedStatus: http.StatusNotFound, + expectedBody: map[string]string{"message": "not found"}, + }, + { + name: "success", + roleID: "role-1", + permissionID: "perm-1", + setupMock: func(rolesRepo *accesscontroltests.MockRolesRepository, permissionsRepo *accesscontroltests.MockPermissionsRepository, rolePermissionsRepo *accesscontroltests.MockRolePermissionsRepository) { + rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "Administrator"}, nil).Once() + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read"}, nil).Once() + rolePermissionsRepo.On("RemoveRolePermission", mock.Anything, "role-1", "perm-1").Return(nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: types.RemoveRolePermissionResponse{Message: "permission removed from role"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + rolesRepo := &accesscontroltests.MockRolesRepository{} + permissionsRepo := &accesscontroltests.MockPermissionsRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} + if tc.setupMock != nil { + tc.setupMock(rolesRepo, permissionsRepo, rolePermissionsRepo) + } + + useCase := newRolePermissionsUseCase(rolesRepo, permissionsRepo, rolePermissionsRepo) + handler := NewRemoveRolePermissionHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodDelete, "/roles/"+tc.roleID+"/permissions/"+tc.permissionID, nil, nil) + req.SetPathValue("role_id", tc.roleID) + req.SetPathValue("permission_id", tc.permissionID) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + rolesRepo.AssertExpectations(t) + permissionsRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[types.RemoveRolePermissionResponse](t, reqCtx) + assertRemoveRolePermissionResponseEqual(t, payload, tc.expectedBody.(types.RemoveRolePermissionResponse)) + + rolesRepo.AssertExpectations(t) + permissionsRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + }) + } } -func TestDeleteRoleHandler(t *testing.T) { - t.Parallel() - - t.Run("conflict when assigned", func(t *testing.T) { - t.Parallel() - - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "admin"}, nil).Once() - repo.On("CountUserAssignmentsByRoleID", mock.Anything, "role-1").Return(1, nil).Once() - handler := NewDeleteRoleHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodDelete, "/access-control/roles/role-1", nil, nil) - req.SetPathValue("role_id", "role-1") - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusConflict, "conflict") - repo.AssertExpectations(t) - }) - - t.Run("success", func(t *testing.T) { - t.Parallel() - - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "admin"}, nil).Once() - repo.On("CountUserAssignmentsByRoleID", mock.Anything, "role-1").Return(0, nil).Once() - repo.On("DeleteRole", mock.Anything, "role-1").Return(true, nil).Once() - handler := NewDeleteRoleHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodDelete, "/access-control/roles/role-1", nil, nil) - req.SetPathValue("role_id", "role-1") - - handler.Handler()(w, req) - - if reqCtx.ResponseStatus != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, reqCtx.ResponseStatus) - } - payload := internaltests.DecodeResponseJSON[types.DeleteRoleResponse](t, reqCtx) - if payload.Message != "deleted role" { - t.Fatalf("expected deleted role message, got %v", payload.Message) - } - repo.AssertExpectations(t) - }) +func newRolePermissionsUseCase(rolesRepo *accesscontroltests.MockRolesRepository, permissionsRepo *accesscontroltests.MockPermissionsRepository, rolePermissionsRepo *accesscontroltests.MockRolePermissionsRepository) *usecases.RolePermissionsUseCase { + return usecases.NewRolePermissionsUseCase(services.NewRolePermissionsService(rolesRepo, permissionsRepo, rolePermissionsRepo)) } -func TestGetAllPermissionsHandler(t *testing.T) { - t.Parallel() - - t.Run("use case error", func(t *testing.T) { - t.Parallel() - - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetAllPermissions", mock.Anything).Return(([]types.Permission)(nil), accesscontrolconstants.ErrForbidden).Once() - handler := NewGetAllPermissionsHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/access-control/permissions", nil, nil) - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusForbidden, "forbidden") - repo.AssertExpectations(t) - }) - - t.Run("success", func(t *testing.T) { - t.Parallel() +func assertUserPermissionInfosEqual(t *testing.T, got []types.UserPermissionInfo, want []types.UserPermissionInfo) { + t.Helper() - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetAllPermissions", mock.Anything).Return([]types.Permission{{ID: "perm-1", Key: "users.read"}}, nil).Once() - handler := NewGetAllPermissionsHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/access-control/permissions", nil, nil) + if len(got) != len(want) { + t.Fatalf("expected %d permissions, got %d", len(want), len(got)) + } - handler.Handler()(w, req) - - if reqCtx.ResponseStatus != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, reqCtx.ResponseStatus) - } - payload := internaltests.DecodeResponseJSON[[]types.Permission](t, reqCtx) - if payload == nil { - t.Fatalf("expected permissions key, got %v", payload) - } - repo.AssertExpectations(t) - }) + for i := range want { + assertUserPermissionInfoEqual(t, got[i], want[i]) + } } -func TestGetRolePermissionsHandler(t *testing.T) { - t.Parallel() - - t.Run("use case error", func(t *testing.T) { - t.Parallel() - - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetRoleByID", mock.Anything, "role-1").Return((*types.Role)(nil), nil).Once() - handler := NewGetRolePermissionsHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/access-control/roles/role-1/permissions", nil, nil) - req.SetPathValue("role_id", "role-1") - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusNotFound, "not found") - repo.AssertExpectations(t) - }) - - t.Run("invalid role id", func(t *testing.T) { - t.Parallel() - - useCase, _ := tests.NewRolePermissionUseCaseFixture() - handler := NewGetRolePermissionsHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/access-control/roles/%20%20%20/permissions", nil, nil) - req.SetPathValue("role_id", " ") - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusUnprocessableEntity, "unprocessable entity") - }) - - t.Run("success", func(t *testing.T) { - t.Parallel() - - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "admin"}, nil).Once() - repo.On("GetRolePermissions", mock.Anything, "role-1").Return([]types.UserPermissionInfo{{PermissionID: "perm-1", PermissionKey: "users.read"}}, nil).Once() - handler := NewGetRolePermissionsHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/access-control/roles/role-1/permissions", nil, nil) - req.SetPathValue("role_id", "role-1") - - handler.Handler()(w, req) - - if reqCtx.ResponseStatus != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, reqCtx.ResponseStatus) - } - payload := internaltests.DecodeResponseJSON[[]types.UserPermissionInfo](t, reqCtx) - if len(payload) != 1 || payload[0].PermissionID != "perm-1" { - t.Fatalf("expected role permissions payload, got %v", payload) - } - repo.AssertExpectations(t) - }) -} - -func TestCreatePermissionHandler(t *testing.T) { - t.Parallel() - - t.Run("invalid payload", func(t *testing.T) { - t.Parallel() - - useCase, _ := tests.NewRolePermissionUseCaseFixture() - handler := NewCreatePermissionHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPost, "/access-control/permissions", []byte("{invalid"), nil) - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusUnprocessableEntity, "invalid request body") - }) - - t.Run("use case error", func(t *testing.T) { - t.Parallel() - - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("CreatePermission", mock.Anything, mock.AnythingOfType("*types.Permission")).Return(accesscontrolconstants.ErrBadRequest).Once() - handler := NewCreatePermissionHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPost, "/access-control/permissions", internaltests.MarshalToJSON(t, types.CreatePermissionRequest{Key: "users.read"}), nil) - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusBadRequest, "bad request") - repo.AssertExpectations(t) - }) - - t.Run("success", func(t *testing.T) { - t.Parallel() - - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("CreatePermission", mock.Anything, mock.AnythingOfType("*types.Permission")).Return(nil).Once() - handler := NewCreatePermissionHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPost, "/access-control/permissions", internaltests.MarshalToJSON(t, types.CreatePermissionRequest{Key: "users.read"}), nil) - - handler.Handler()(w, req) - - if reqCtx.ResponseStatus != http.StatusCreated { - t.Fatalf("expected status %d, got %d", http.StatusCreated, reqCtx.ResponseStatus) - } - payload := internaltests.DecodeResponseJSON[types.CreatePermissionResponse](t, reqCtx) - if payload.Permission == nil { - t.Fatalf("expected permission key, got %v", payload) - } - repo.AssertExpectations(t) - }) +func assertUserPermissionInfoEqual(t *testing.T, got types.UserPermissionInfo, want types.UserPermissionInfo) { + t.Helper() + + if got.PermissionID != want.PermissionID { + t.Fatalf("expected permission id %q, got %q", want.PermissionID, got.PermissionID) + } + if got.PermissionKey != want.PermissionKey { + t.Fatalf("expected permission key %q, got %q", want.PermissionKey, got.PermissionKey) + } + if !stringsEqualPtr(got.PermissionDescription, want.PermissionDescription) { + t.Fatalf("expected permission description %#v, got %#v", want.PermissionDescription, got.PermissionDescription) + } + if !stringsEqualPtr(got.GrantedByUserID, want.GrantedByUserID) { + t.Fatalf("expected granted_by_user_id %#v, got %#v", want.GrantedByUserID, got.GrantedByUserID) + } + if !timesEqualPtr(got.GrantedAt, want.GrantedAt) { + t.Fatalf("expected granted_at %#v, got %#v", want.GrantedAt, got.GrantedAt) + } } -func TestUpdatePermissionHandler(t *testing.T) { - t.Parallel() - - t.Run("invalid payload", func(t *testing.T) { - t.Parallel() - - useCase, _ := tests.NewRolePermissionUseCaseFixture() - handler := NewUpdatePermissionHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPut, "/access-control/permissions/perm-1", []byte("{invalid"), nil) - req.SetPathValue("permission_id", "perm-1") - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusUnprocessableEntity, "invalid request body") - }) - - t.Run("use case error", func(t *testing.T) { - t.Parallel() - - desc := "updated" - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetPermissionByID", mock.Anything, "perm-1").Return((*types.Permission)(nil), nil).Once() - handler := NewUpdatePermissionHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPut, "/access-control/permissions/perm-1", internaltests.MarshalToJSON(t, types.UpdatePermissionRequest{Description: &desc}), nil) - req.SetPathValue("permission_id", "perm-1") - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusNotFound, "not found") - repo.AssertExpectations(t) - }) - - t.Run("success", func(t *testing.T) { - t.Parallel() - - desc := "updated" - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read"}, nil).Once() - repo.On("UpdatePermission", mock.Anything, "perm-1", &desc).Return(true, nil).Once() - repo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read", Description: &desc}, nil).Once() - handler := NewUpdatePermissionHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPut, "/access-control/permissions/perm-1", internaltests.MarshalToJSON(t, types.UpdatePermissionRequest{Description: &desc}), nil) - req.SetPathValue("permission_id", "perm-1") - - handler.Handler()(w, req) - - if reqCtx.ResponseStatus != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, reqCtx.ResponseStatus) - } - payload := internaltests.DecodeResponseJSON[types.UpdatePermissionResponse](t, reqCtx) - if payload.Permission == nil { - t.Fatalf("expected permission key, got %v", payload) - } - repo.AssertExpectations(t) - }) +func assertAddRolePermissionResponseEqual(t *testing.T, got types.AddRolePermissionResponse, want types.AddRolePermissionResponse) { + t.Helper() + if got.Message != want.Message { + t.Fatalf("expected message %q, got %q", want.Message, got.Message) + } } -func TestDeletePermissionHandler(t *testing.T) { - t.Parallel() - - t.Run("in use conflict", func(t *testing.T) { - t.Parallel() - - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read"}, nil).Once() - repo.On("CountRoleAssignmentsByPermissionID", mock.Anything, "perm-1").Return(1, nil).Once() - handler := NewDeletePermissionHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodDelete, "/access-control/permissions/perm-1", nil, nil) - req.SetPathValue("permission_id", "perm-1") - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusConflict, "conflict") - repo.AssertExpectations(t) - }) - - t.Run("success", func(t *testing.T) { - t.Parallel() - - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "admin.read"}, nil).Once() - repo.On("CountRoleAssignmentsByPermissionID", mock.Anything, "perm-1").Return(0, nil).Once() - repo.On("DeletePermission", mock.Anything, "perm-1").Return(true, nil).Once() - handler := NewDeletePermissionHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodDelete, "/access-control/permissions/perm-1", nil, nil) - req.SetPathValue("permission_id", "perm-1") - - handler.Handler()(w, req) - - if reqCtx.ResponseStatus != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, reqCtx.ResponseStatus) - } - payload := internaltests.DecodeResponseJSON[map[string]any](t, reqCtx) - if payload["message"] != "permission deleted" { - t.Fatalf("expected permission deleted message, got %v", payload["message"]) - } - repo.AssertExpectations(t) - }) +func assertReplaceRolePermissionResponseEqual(t *testing.T, got types.ReplaceRolePermissionResponse, want types.ReplaceRolePermissionResponse) { + t.Helper() + if got.Message != want.Message { + t.Fatalf("expected message %q, got %q", want.Message, got.Message) + } } -func TestAddRolePermissionHandler(t *testing.T) { - t.Parallel() - - t.Run("invalid request body", func(t *testing.T) { - t.Parallel() - - useCase, _ := tests.NewRolePermissionUseCaseFixture() - handler := NewAddRolePermissionHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPost, "/access-control/roles/role-1/permissions", []byte("{invalid"), nil) - req.SetPathValue("role_id", "role-1") - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusUnprocessableEntity, "invalid request body") - }) - - t.Run("use case error", func(t *testing.T) { - t.Parallel() - - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetRoleByID", mock.Anything, "role-1").Return((*types.Role)(nil), nil).Once() - handler := NewAddRolePermissionHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPost, "/access-control/roles/role-1/permissions", internaltests.MarshalToJSON(t, types.AddRolePermissionRequest{PermissionID: "perm-1"}), nil) - req.SetPathValue("role_id", "role-1") - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusNotFound, "not found") - repo.AssertExpectations(t) - }) - - t.Run("success", func(t *testing.T) { - t.Parallel() - - actorID := "actor-1" - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "admin"}, nil).Once() - repo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read"}, nil).Once() - repo.On("AddRolePermission", mock.Anything, "role-1", "perm-1", &actorID).Return(nil).Once() - handler := NewAddRolePermissionHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPost, "/access-control/roles/role-1/permissions", internaltests.MarshalToJSON(t, types.AddRolePermissionRequest{PermissionID: "perm-1"}), nil) - req.SetPathValue("role_id", "role-1") - reqCtx.UserID = &actorID - - handler.Handler()(w, req) - - if reqCtx.ResponseStatus != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, reqCtx.ResponseStatus) - } - payload := internaltests.DecodeResponseJSON[map[string]any](t, reqCtx) - if payload["message"] != "permission assigned to role" { - t.Fatalf("expected permission assigned to role message, got %v", payload["message"]) - } - repo.AssertExpectations(t) - }) -} - -func TestReplaceRolePermissionsHandler(t *testing.T) { - t.Parallel() - - t.Run("invalid payload", func(t *testing.T) { - t.Parallel() - - useCase, _ := tests.NewRolePermissionUseCaseFixture() - handler := NewReplaceRolePermissionsHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPut, "/access-control/roles/role-1/permissions", []byte("{invalid"), nil) - req.SetPathValue("role_id", "role-1") - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusUnprocessableEntity, "invalid request body") - }) - - t.Run("use case error", func(t *testing.T) { - t.Parallel() - - actorID := "actor-1" - request := types.ReplaceRolePermissionsRequest{PermissionIDs: []string{"perm-1"}} - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("ReplaceRolePermissions", mock.Anything, "role-1", request.PermissionIDs, &actorID).Return(accesscontrolconstants.ErrForbidden).Once() - handler := NewReplaceRolePermissionsHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPut, "/access-control/roles/role-1/permissions", internaltests.MarshalToJSON(t, request), nil) - req.SetPathValue("role_id", "role-1") - reqCtx.UserID = &actorID - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusForbidden, "forbidden") - repo.AssertExpectations(t) - }) - - t.Run("success", func(t *testing.T) { - t.Parallel() - - request := types.ReplaceRolePermissionsRequest{PermissionIDs: []string{"perm-1", "perm-1", "perm-2"}} - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("ReplaceRolePermissions", mock.Anything, "role-1", []string{"perm-1", "perm-2"}, (*string)(nil)).Return(nil).Once() - handler := NewReplaceRolePermissionsHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPut, "/access-control/roles/role-1/permissions", internaltests.MarshalToJSON(t, request), nil) - req.SetPathValue("role_id", "role-1") - - handler.Handler()(w, req) - - if reqCtx.ResponseStatus != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, reqCtx.ResponseStatus) - } - payload := internaltests.DecodeResponseJSON[map[string]any](t, reqCtx) - if payload["message"] != "role permissions replaced" { - t.Fatalf("expected role permissions replaced message, got %v", payload["message"]) - } - repo.AssertExpectations(t) - }) -} - -func TestRemoveRolePermissionHandler(t *testing.T) { - t.Parallel() - - t.Run("use case error", func(t *testing.T) { - t.Parallel() - - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "admin"}, nil).Once() - repo.On("GetPermissionByID", mock.Anything, "perm-1").Return((*types.Permission)(nil), nil).Once() - handler := NewRemoveRolePermissionHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodDelete, "/access-control/roles/role-1/permissions/perm-1", nil, nil) - req.SetPathValue("role_id", "role-1") - req.SetPathValue("permission_id", "perm-1") - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusNotFound, "not found") - repo.AssertExpectations(t) - }) - - t.Run("success", func(t *testing.T) { - t.Parallel() - - useCase, repo := tests.NewRolePermissionUseCaseFixture() - repo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "admin"}, nil).Once() - repo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "admin.read"}, nil).Once() - repo.On("RemoveRolePermission", mock.Anything, "role-1", "perm-1").Return(nil).Once() - handler := NewRemoveRolePermissionHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodDelete, "/access-control/roles/role-1/permissions/perm-1", nil, nil) - req.SetPathValue("role_id", "role-1") - req.SetPathValue("permission_id", "perm-1") - - handler.Handler()(w, req) - - if reqCtx.ResponseStatus != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, reqCtx.ResponseStatus) - } - payload := internaltests.DecodeResponseJSON[map[string]any](t, reqCtx) - if payload["message"] != "permission removed from role" { - t.Fatalf("expected permission removed from role message, got %v", payload["message"]) - } - repo.AssertExpectations(t) - }) +func assertRemoveRolePermissionResponseEqual(t *testing.T, got types.RemoveRolePermissionResponse, want types.RemoveRolePermissionResponse) { + t.Helper() + if got.Message != want.Message { + t.Fatalf("expected message %q, got %q", want.Message, got.Message) + } } diff --git a/plugins/access-control/handlers/shared_helpers.go b/plugins/access-control/handlers/shared_helpers.go new file mode 100644 index 0000000..d4f9f79 --- /dev/null +++ b/plugins/access-control/handlers/shared_helpers.go @@ -0,0 +1,35 @@ +package handlers + +import "github.com/Authula/authula/models" + +func rolePermissionActorUserID(reqCtx *models.RequestContext) *string { + if reqCtx.UserID == nil || *reqCtx.UserID == "" { + return nil + } + return reqCtx.UserID +} + +func respondRolePermissionError(reqCtx *models.RequestContext, err error) { + reqCtx.SetJSONResponse(mapRolePermissionErrorStatus(err), map[string]any{"message": mapHttpErrorMessage(err)}) + reqCtx.Handled = true +} + +func mapRolePermissionErrorStatus(err error) int { + return mapHttpErrorStatus(err) +} + +func userActorUserID(reqCtx *models.RequestContext) *string { + if reqCtx.UserID == nil || *reqCtx.UserID == "" { + return nil + } + return reqCtx.UserID +} + +func respondUserHandlerError(reqCtx *models.RequestContext, err error) { + reqCtx.SetJSONResponse(mapUserHandlerErrorStatus(err), map[string]any{"message": mapHttpErrorMessage(err)}) + reqCtx.Handled = true +} + +func mapUserHandlerErrorStatus(err error) int { + return mapHttpErrorStatus(err) +} diff --git a/plugins/access-control/handlers/user_permissions_handlers.go b/plugins/access-control/handlers/user_permissions_handlers.go new file mode 100644 index 0000000..562cbc3 --- /dev/null +++ b/plugins/access-control/handlers/user_permissions_handlers.go @@ -0,0 +1,65 @@ +package handlers + +import ( + "net/http" + + "github.com/Authula/authula/internal/util" + "github.com/Authula/authula/models" + "github.com/Authula/authula/plugins/access-control/types" + "github.com/Authula/authula/plugins/access-control/usecases" +) + +type GetUserPermissionsHandler struct { + useCase *usecases.UserPermissionsUseCase +} + +func NewGetUserPermissionsHandler(useCase *usecases.UserPermissionsUseCase) *GetUserPermissionsHandler { + return &GetUserPermissionsHandler{useCase: useCase} +} + +func (h *GetUserPermissionsHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + userID := r.PathValue("user_id") + + permissions, err := h.useCase.GetUserPermissions(r.Context(), userID) + if err != nil { + respondUserHandlerError(reqCtx, err) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, &types.GetUserEffectivePermissionsResponse{Permissions: permissions}) + } +} + +type CheckUserPermissionsHandler struct { + useCase *usecases.UserPermissionsUseCase +} + +func NewCheckUserPermissionsHandler(useCase *usecases.UserPermissionsUseCase) *CheckUserPermissionsHandler { + return &CheckUserPermissionsHandler{useCase: useCase} +} + +func (h *CheckUserPermissionsHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + userID := r.PathValue("user_id") + + var payload types.CheckUserPermissionsRequest + if err := util.ParseJSON(r, &payload); err != nil { + reqCtx.SetJSONResponse(http.StatusUnprocessableEntity, map[string]any{"message": "invalid request body"}) + reqCtx.Handled = true + return + } + + allowed, err := h.useCase.HasPermissions(r.Context(), userID, payload.PermissionKeys) + if err != nil { + respondUserHandlerError(reqCtx, err) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, &types.CheckUserPermissionsResponse{HasPermissions: allowed}) + } +} diff --git a/plugins/access-control/handlers/user_permissions_handlers_test.go b/plugins/access-control/handlers/user_permissions_handlers_test.go new file mode 100644 index 0000000..424181a --- /dev/null +++ b/plugins/access-control/handlers/user_permissions_handlers_test.go @@ -0,0 +1,161 @@ +package handlers + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/mock" + + internaltests "github.com/Authula/authula/internal/tests" + accesscontrolconstants "github.com/Authula/authula/plugins/access-control/constants" + "github.com/Authula/authula/plugins/access-control/services" + accesscontroltests "github.com/Authula/authula/plugins/access-control/tests" + "github.com/Authula/authula/plugins/access-control/types" + "github.com/Authula/authula/plugins/access-control/usecases" +) + +func TestGetUserPermissionsHandler(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + userID string + setupMock func(*accesscontroltests.MockUserPermissionsRepository) + expectedStatus int + expectedBody any + }{ + { + name: "blank user id", + userID: "", + expectedStatus: http.StatusUnprocessableEntity, + expectedBody: map[string]string{"message": "unprocessable entity"}, + }, + { + name: "success", + userID: "u1", + setupMock: func(m *accesscontroltests.MockUserPermissionsRepository) { + m.On("GetUserPermissions", mock.Anything, "u1").Return([]types.UserPermissionInfo{{PermissionID: "perm-1", PermissionKey: "users.read"}}, nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: types.GetUserEffectivePermissionsResponse{Permissions: []types.UserPermissionInfo{{PermissionID: "perm-1", PermissionKey: "users.read"}}}, + }, + { + name: "repo error", + userID: "u1", + setupMock: func(m *accesscontroltests.MockUserPermissionsRepository) { + m.On("GetUserPermissions", mock.Anything, "u1").Return(([]types.UserPermissionInfo)(nil), accesscontrolconstants.ErrNotFound).Once() + }, + expectedStatus: http.StatusNotFound, + expectedBody: map[string]string{"message": "not found"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + repo := &accesscontroltests.MockUserPermissionsRepository{} + if tc.setupMock != nil { + tc.setupMock(repo) + } + + handler := NewGetUserPermissionsHandler(newUserPermissionsUseCase(repo)) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/users/"+tc.userID+"/permissions", nil, nil) + req.SetPathValue("user_id", tc.userID) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + repo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[types.GetUserEffectivePermissionsResponse](t, reqCtx) + assertUserPermissionInfosEqual(t, payload.Permissions, tc.expectedBody.(types.GetUserEffectivePermissionsResponse).Permissions) + repo.AssertExpectations(t) + }) + } +} + +func TestCheckUserPermissionsHandler(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + userID string + body []byte + setupMock func(*accesscontroltests.MockUserPermissionsRepository) + expectedStatus int + expectedBody any + }{ + { + name: "invalid request body", + userID: "u1", + body: []byte("{invalid json"), + expectedStatus: http.StatusUnprocessableEntity, + expectedBody: map[string]string{"message": "invalid request body"}, + }, + { + name: "success", + userID: "u1", + body: internaltests.MarshalToJSON(t, types.CheckUserPermissionsRequest{PermissionKeys: []string{"users.read", "users.write"}}), + setupMock: func(m *accesscontroltests.MockUserPermissionsRepository) { + m.On("HasPermissions", mock.Anything, "u1", []string{"users.read", "users.write"}).Return(true, nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: types.CheckUserPermissionsResponse{HasPermissions: true}, + }, + { + name: "repo error", + userID: "u1", + body: internaltests.MarshalToJSON(t, types.CheckUserPermissionsRequest{PermissionKeys: []string{"users.read"}}), + setupMock: func(m *accesscontroltests.MockUserPermissionsRepository) { + m.On("HasPermissions", mock.Anything, "u1", []string{"users.read"}).Return(false, accesscontrolconstants.ErrForbidden).Once() + }, + expectedStatus: http.StatusForbidden, + expectedBody: map[string]string{"message": "forbidden"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + repo := &accesscontroltests.MockUserPermissionsRepository{} + if tc.setupMock != nil { + tc.setupMock(repo) + } + + handler := NewCheckUserPermissionsHandler(newUserPermissionsUseCase(repo)) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPost, "/users/"+tc.userID+"/permissions/check", tc.body, nil) + req.SetPathValue("user_id", tc.userID) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + repo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[types.CheckUserPermissionsResponse](t, reqCtx) + if payload != tc.expectedBody.(types.CheckUserPermissionsResponse) { + t.Fatalf("unexpected response: %#v", payload) + } + repo.AssertExpectations(t) + }) + } +} + +func newUserPermissionsUseCase(repo *accesscontroltests.MockUserPermissionsRepository) *usecases.UserPermissionsUseCase { + return usecases.NewUserPermissionsUseCase(services.NewUserPermissionsService(repo)) +} diff --git a/plugins/access-control/handlers/user_roles_handlers.go b/plugins/access-control/handlers/user_roles_handlers.go index 5b2ea59..a3aa1b2 100644 --- a/plugins/access-control/handlers/user_roles_handlers.go +++ b/plugins/access-control/handlers/user_roles_handlers.go @@ -10,10 +10,10 @@ import ( ) type GetUserRolesHandler struct { - useCase usecases.UserRolesUseCase + useCase *usecases.UserRolesUseCase } -func NewGetUserRolesHandler(useCase usecases.UserRolesUseCase) *GetUserRolesHandler { +func NewGetUserRolesHandler(useCase *usecases.UserRolesUseCase) *GetUserRolesHandler { return &GetUserRolesHandler{ useCase: useCase, } @@ -36,10 +36,10 @@ func (h *GetUserRolesHandler) Handler() http.HandlerFunc { } type ReplaceUserRolesHandler struct { - useCase usecases.RolePermissionUseCase + useCase *usecases.UserRolesUseCase } -func NewReplaceUserRolesHandler(useCase usecases.RolePermissionUseCase) *ReplaceUserRolesHandler { +func NewReplaceUserRolesHandler(useCase *usecases.UserRolesUseCase) *ReplaceUserRolesHandler { return &ReplaceUserRolesHandler{ useCase: useCase, } @@ -68,10 +68,10 @@ func (h *ReplaceUserRolesHandler) Handler() http.HandlerFunc { } type AssignUserRoleHandler struct { - useCase usecases.RolePermissionUseCase + useCase *usecases.UserRolesUseCase } -func NewAssignUserRoleHandler(useCase usecases.RolePermissionUseCase) *AssignUserRoleHandler { +func NewAssignUserRoleHandler(useCase *usecases.UserRolesUseCase) *AssignUserRoleHandler { return &AssignUserRoleHandler{useCase: useCase} } @@ -98,10 +98,10 @@ func (h *AssignUserRoleHandler) Handler() http.HandlerFunc { } type RemoveUserRoleHandler struct { - useCase usecases.RolePermissionUseCase + useCase *usecases.UserRolesUseCase } -func NewRemoveUserRoleHandler(useCase usecases.RolePermissionUseCase) *RemoveUserRoleHandler { +func NewRemoveUserRoleHandler(useCase *usecases.UserRolesUseCase) *RemoveUserRoleHandler { return &RemoveUserRoleHandler{useCase: useCase} } @@ -120,49 +120,3 @@ func (h *RemoveUserRoleHandler) Handler() http.HandlerFunc { reqCtx.SetJSONResponse(http.StatusOK, &types.RemoveUserRoleResponse{Message: "role removed"}) } } - -type GetUserEffectivePermissionsHandler struct { - useCase usecases.UserRolesUseCase -} - -func NewGetUserEffectivePermissionsHandler(useCase usecases.UserRolesUseCase) *GetUserEffectivePermissionsHandler { - return &GetUserEffectivePermissionsHandler{ - useCase: useCase, - } -} - -func (h *GetUserEffectivePermissionsHandler) Handler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - reqCtx, _ := models.GetRequestContext(ctx) - userID := r.PathValue("user_id") - - permissions, err := h.useCase.GetUserEffectivePermissions(r.Context(), userID) - if err != nil { - respondUserHandlerError(reqCtx, err) - return - } - - reqCtx.SetJSONResponse(http.StatusOK, &types.GetUserEffectivePermissionsResponse{Permissions: permissions}) - } -} - -func userActorUserID(reqCtx *models.RequestContext) *string { - if reqCtx == nil || reqCtx.UserID == nil || *reqCtx.UserID == "" { - return nil - } - return reqCtx.UserID -} - -func respondUserHandlerError(reqCtx *models.RequestContext, err error) { - if reqCtx == nil { - return - } - - reqCtx.SetJSONResponse(mapUserHandlerErrorStatus(err), map[string]any{"message": mapHttpErrorMessage(err)}) - reqCtx.Handled = true -} - -func mapUserHandlerErrorStatus(err error) int { - return mapHttpErrorStatus(err) -} diff --git a/plugins/access-control/handlers/user_roles_handlers_test.go b/plugins/access-control/handlers/user_roles_handlers_test.go index 20d12a5..f96ca2b 100644 --- a/plugins/access-control/handlers/user_roles_handlers_test.go +++ b/plugins/access-control/handlers/user_roles_handlers_test.go @@ -8,279 +8,433 @@ import ( "github.com/stretchr/testify/mock" internaltests "github.com/Authula/authula/internal/tests" - accesscontrolconstants "github.com/Authula/authula/plugins/access-control/constants" - "github.com/Authula/authula/plugins/access-control/tests" + "github.com/Authula/authula/plugins/access-control/constants" + "github.com/Authula/authula/plugins/access-control/services" + accesscontroltests "github.com/Authula/authula/plugins/access-control/tests" "github.com/Authula/authula/plugins/access-control/types" + "github.com/Authula/authula/plugins/access-control/usecases" ) func TestGetUserRolesHandler(t *testing.T) { t.Parallel() - t.Run("missing user id", func(t *testing.T) { - t.Parallel() - - useCase, _ := tests.NewUserRolesUseCaseFixture() - handler := NewGetUserRolesHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/access-control/users//roles", nil, nil) - req.SetPathValue("user_id", " ") - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusUnprocessableEntity, "unprocessable entity") - }) - - t.Run("error", func(t *testing.T) { - t.Parallel() - - useCase, accessRepo := tests.NewUserRolesUseCaseFixture() - accessRepo.On("GetUserRoles", mock.Anything, "user-1").Return(([]types.UserRoleInfo)(nil), accesscontrolconstants.ErrNotFound).Once() - handler := NewGetUserRolesHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/access-control/users/user-1/roles", nil, nil) - req.SetPathValue("user_id", "user-1") - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusNotFound, "not found") - accessRepo.AssertExpectations(t) - }) - - t.Run("success", func(t *testing.T) { - t.Parallel() - - useCase, accessRepo := tests.NewUserRolesUseCaseFixture() - expiresAt := time.Now().UTC().Add(time.Hour) - accessRepo.On("GetUserRoles", mock.Anything, "user-1").Return([]types.UserRoleInfo{{RoleID: "role-1", RoleName: "admin", ExpiresAt: &expiresAt}}, nil).Once() - handler := NewGetUserRolesHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/access-control/users/user-1/roles", nil, nil) - req.SetPathValue("user_id", "user-1") - - handler.Handler()(w, req) - - if reqCtx.ResponseStatus != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, reqCtx.ResponseStatus) - } - payload := internaltests.DecodeResponseJSON[[]types.UserRoleInfo](t, reqCtx) - if len(payload) != 1 { - t.Fatalf("expected 1 role, got %d", len(payload)) - } - accessRepo.AssertExpectations(t) - }) + fixedTime := time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC) + description := new(string) + *description = "editor role" + assignedByUserID := new(string) + *assignedByUserID = "user-2" + + tests := []struct { + name string + userID string + setupMock func(*accesscontroltests.MockUserRolesRepository) + expectedStatus int + expectedBody any + }{ + { + name: "blank user id", + userID: "", + expectedStatus: http.StatusUnprocessableEntity, + expectedBody: map[string]string{"message": "unprocessable entity"}, + }, + { + name: "use case error", + userID: "user-404", + setupMock: func(m *accesscontroltests.MockUserRolesRepository) { + m.On("GetUserRoles", mock.Anything, "user-404").Return(([]types.UserRoleInfo)(nil), constants.ErrNotFound).Once() + }, + expectedStatus: http.StatusNotFound, + expectedBody: map[string]string{"message": "not found"}, + }, + { + name: "success", + userID: "user-1", + setupMock: func(m *accesscontroltests.MockUserRolesRepository) { + m.On("GetUserRoles", mock.Anything, "user-1").Return([]types.UserRoleInfo{{ + RoleID: "role-1", + RoleName: "Editor", + RoleDescription: description, + AssignedByUserID: assignedByUserID, + AssignedAt: &fixedTime, + ExpiresAt: &fixedTime, + }}, nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: []types.UserRoleInfo{{ + RoleID: "role-1", + RoleName: "Editor", + RoleDescription: description, + AssignedByUserID: assignedByUserID, + AssignedAt: &fixedTime, + ExpiresAt: &fixedTime, + }}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + userRolesRepo := &accesscontroltests.MockUserRolesRepository{} + rolesRepo := &accesscontroltests.MockRolesRepository{} + if tc.setupMock != nil { + tc.setupMock(userRolesRepo) + } + + useCase := newUserRolesUseCase(rolesRepo, userRolesRepo) + handler := NewGetUserRolesHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/users/"+tc.userID+"/roles", nil, nil) + req.SetPathValue("user_id", tc.userID) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + userRolesRepo.AssertExpectations(t) + rolesRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[[]types.UserRoleInfo](t, reqCtx) + assertUserRoleInfosEqual(t, payload, tc.expectedBody.([]types.UserRoleInfo)) + + userRolesRepo.AssertExpectations(t) + rolesRepo.AssertExpectations(t) + }) + } } func TestReplaceUserRolesHandler(t *testing.T) { t.Parallel() - t.Run("invalid json", func(t *testing.T) { - t.Parallel() - - useCase, _ := tests.NewRolePermissionUseCaseFixture() - handler := NewReplaceUserRolesHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPut, "/access-control/users/user-1/roles", []byte("{invalid"), nil) - req.SetPathValue("user_id", "user-1") - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusUnprocessableEntity, "invalid request body") - }) - - t.Run("error", func(t *testing.T) { - t.Parallel() - - useCase, roleRepo := tests.NewRolePermissionUseCaseFixture() - request := types.ReplaceUserRolesRequest{RoleIDs: []string{"role-1"}} - actorID := "actor-1" - roleRepo.On("ReplaceUserRoles", mock.Anything, "user-1", request.RoleIDs, &actorID).Return(accesscontrolconstants.ErrForbidden).Once() - handler := NewReplaceUserRolesHandler(useCase) - - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPut, "/access-control/users/user-1/roles", internaltests.MarshalToJSON(t, request), nil) - req.SetPathValue("user_id", "user-1") - reqCtx.UserID = &actorID - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusForbidden, "forbidden") - roleRepo.AssertExpectations(t) - }) - - t.Run("success", func(t *testing.T) { - t.Parallel() - - useCase, roleRepo := tests.NewRolePermissionUseCaseFixture() - request := types.ReplaceUserRolesRequest{RoleIDs: []string{"role-1", "role-2"}} - roleRepo.On("ReplaceUserRoles", mock.Anything, "user-1", request.RoleIDs, (*string)(nil)).Return(nil).Once() - handler := NewReplaceUserRolesHandler(useCase) - - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPut, "/access-control/users/user-1/roles", internaltests.MarshalToJSON(t, request), nil) - req.SetPathValue("user_id", "user-1") - - handler.Handler()(w, req) - - if reqCtx.ResponseStatus != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, reqCtx.ResponseStatus) - } - payload := internaltests.DecodeResponseJSON[types.ReplaceUserRolesResponse](t, reqCtx) - if payload.Message != "user roles replaced" { - t.Fatalf("expected message user roles replaced, got %v", payload.Message) - } - roleRepo.AssertExpectations(t) - }) + actorUserID := "user-1" + actorUserIDPtr := &actorUserID + + tests := []struct { + name string + userID string + body []byte + userIDPtr *string + setupMock func(*accesscontroltests.MockRolesRepository, *accesscontroltests.MockUserRolesRepository) + expectedStatus int + expectedBody any + }{ + { + name: "invalid request body", + userID: "user-1", + body: []byte("{invalid json"), + expectedStatus: http.StatusUnprocessableEntity, + expectedBody: map[string]string{"message": "invalid request body"}, + }, + { + name: "blank user id", + userID: "", + body: internaltests.MarshalToJSON(t, types.ReplaceUserRolesRequest{RoleIDs: []string{"role-1"}}), + expectedStatus: http.StatusBadRequest, + expectedBody: map[string]string{"message": "bad request"}, + }, + { + name: "success", + userID: "user-1", + body: internaltests.MarshalToJSON(t, types.ReplaceUserRolesRequest{RoleIDs: []string{"role-1", "role-2"}}), + setupMock: func(rolesRepo *accesscontroltests.MockRolesRepository, userRolesRepo *accesscontroltests.MockUserRolesRepository) { + rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "Editor"}, nil).Once() + rolesRepo.On("GetRoleByID", mock.Anything, "role-2").Return(&types.Role{ID: "role-2", Name: "Viewer"}, nil).Once() + userRolesRepo.On("ReplaceUserRoles", mock.Anything, "user-1", []string{"role-1", "role-2"}, (*string)(nil)).Return(nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: types.ReplaceUserRolesResponse{Message: "user roles replaced"}, + }, + { + name: "success with actor user id", + userID: "user-1", + userIDPtr: actorUserIDPtr, + body: internaltests.MarshalToJSON(t, types.ReplaceUserRolesRequest{RoleIDs: []string{"role-3", "role-4"}}), + setupMock: func(rolesRepo *accesscontroltests.MockRolesRepository, userRolesRepo *accesscontroltests.MockUserRolesRepository) { + rolesRepo.On("GetRoleByID", mock.Anything, "role-3").Return(&types.Role{ID: "role-3", Name: "Reviewer"}, nil).Once() + rolesRepo.On("GetRoleByID", mock.Anything, "role-4").Return(&types.Role{ID: "role-4", Name: "Commenter"}, nil).Once() + userRolesRepo.On("ReplaceUserRoles", mock.Anything, "user-1", []string{"role-3", "role-4"}, actorUserIDPtr).Return(nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: types.ReplaceUserRolesResponse{Message: "user roles replaced"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + rolesRepo := &accesscontroltests.MockRolesRepository{} + userRolesRepo := &accesscontroltests.MockUserRolesRepository{} + if tc.setupMock != nil { + tc.setupMock(rolesRepo, userRolesRepo) + } + + useCase := newUserRolesUseCase(rolesRepo, userRolesRepo) + handler := NewReplaceUserRolesHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPut, "/users/"+tc.userID+"/roles", tc.body, tc.userIDPtr) + req.SetPathValue("user_id", tc.userID) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + rolesRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[types.ReplaceUserRolesResponse](t, reqCtx) + assertReplaceUserRolesResponseEqual(t, payload, tc.expectedBody.(types.ReplaceUserRolesResponse)) + + rolesRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) + }) + } } func TestAssignUserRoleHandler(t *testing.T) { t.Parallel() - t.Run("invalid json", func(t *testing.T) { - t.Parallel() - - useCase, _ := tests.NewRolePermissionUseCaseFixture() - handler := NewAssignUserRoleHandler(useCase) - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPost, "/access-control/users/user-1/roles", []byte("{invalid"), nil) - req.SetPathValue("user_id", "user-1") - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusUnprocessableEntity, "invalid request body") - }) - - t.Run("error", func(t *testing.T) { - t.Parallel() - - useCase, roleRepo := tests.NewRolePermissionUseCaseFixture() - request := types.AssignUserRoleRequest{RoleID: "role-1"} - actorID := "actor-1" - roleRepo.On("AssignUserRole", mock.Anything, "user-1", "role-1", &actorID, (*time.Time)(nil)).Return(accesscontrolconstants.ErrBadRequest).Once() - handler := NewAssignUserRoleHandler(useCase) - - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPost, "/access-control/users/user-1/roles", internaltests.MarshalToJSON(t, request), nil) - req.SetPathValue("user_id", "user-1") - reqCtx.UserID = &actorID - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusBadRequest, "bad request") - roleRepo.AssertExpectations(t) - }) - - t.Run("success", func(t *testing.T) { - t.Parallel() - - useCase, roleRepo := tests.NewRolePermissionUseCaseFixture() - request := types.AssignUserRoleRequest{RoleID: "role-1"} - roleRepo.On("AssignUserRole", mock.Anything, "user-1", "role-1", (*string)(nil), (*time.Time)(nil)).Return(nil).Once() - handler := NewAssignUserRoleHandler(useCase) - - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPost, "/access-control/users/user-1/roles", internaltests.MarshalToJSON(t, request), nil) - req.SetPathValue("user_id", "user-1") - - handler.Handler()(w, req) - - if reqCtx.ResponseStatus != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, reqCtx.ResponseStatus) - } - payload := internaltests.DecodeResponseJSON[types.AssignUserRoleResponse](t, reqCtx) - if payload.Message != "role assigned" { - t.Fatalf("expected role assigned message, got %v", payload.Message) - } - roleRepo.AssertExpectations(t) - }) + futureTime := time.Date(2030, 3, 29, 13, 0, 0, 0, time.UTC) + actorUserID := "user-1" + actorUserIDPtr := &actorUserID + + tests := []struct { + name string + userID string + body []byte + userIDPtr *string + setupMock func(*accesscontroltests.MockRolesRepository, *accesscontroltests.MockUserRolesRepository) + expectedStatus int + expectedBody any + }{ + { + name: "invalid request body", + userID: "user-1", + body: []byte("{invalid json"), + expectedStatus: http.StatusUnprocessableEntity, + expectedBody: map[string]string{"message": "invalid request body"}, + }, + { + name: "blank user id", + userID: "", + body: internaltests.MarshalToJSON(t, types.AssignUserRoleRequest{RoleID: "role-1"}), + expectedStatus: http.StatusBadRequest, + expectedBody: map[string]string{"message": "bad request"}, + }, + { + name: "use case error", + userID: "user-1", + body: internaltests.MarshalToJSON(t, types.AssignUserRoleRequest{RoleID: "role-1"}), + setupMock: func(rolesRepo *accesscontroltests.MockRolesRepository, userRolesRepo *accesscontroltests.MockUserRolesRepository) { + rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "Editor"}, nil).Once() + userRolesRepo.On("AssignUserRole", mock.Anything, "user-1", "role-1", (*string)(nil), (*time.Time)(nil)).Return(constants.ErrUnauthorized).Once() + }, + expectedStatus: http.StatusUnauthorized, + expectedBody: map[string]string{"message": "unauthorized"}, + }, + { + name: "success", + userID: "user-1", + body: internaltests.MarshalToJSON(t, types.AssignUserRoleRequest{RoleID: "role-1", ExpiresAt: &futureTime}), + setupMock: func(rolesRepo *accesscontroltests.MockRolesRepository, userRolesRepo *accesscontroltests.MockUserRolesRepository) { + rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "Editor"}, nil).Once() + userRolesRepo.On("AssignUserRole", mock.Anything, "user-1", "role-1", (*string)(nil), &futureTime).Return(nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: types.AssignUserRoleResponse{Message: "role assigned"}, + }, + { + name: "success with actor user id", + userID: "user-1", + userIDPtr: actorUserIDPtr, + body: internaltests.MarshalToJSON(t, types.AssignUserRoleRequest{RoleID: "role-2"}), + setupMock: func(rolesRepo *accesscontroltests.MockRolesRepository, userRolesRepo *accesscontroltests.MockUserRolesRepository) { + rolesRepo.On("GetRoleByID", mock.Anything, "role-2").Return(&types.Role{ID: "role-2", Name: "Reviewer"}, nil).Once() + userRolesRepo.On("AssignUserRole", mock.Anything, "user-1", "role-2", actorUserIDPtr, (*time.Time)(nil)).Return(nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: types.AssignUserRoleResponse{Message: "role assigned"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + rolesRepo := &accesscontroltests.MockRolesRepository{} + userRolesRepo := &accesscontroltests.MockUserRolesRepository{} + if tc.setupMock != nil { + tc.setupMock(rolesRepo, userRolesRepo) + } + + useCase := newUserRolesUseCase(rolesRepo, userRolesRepo) + handler := NewAssignUserRoleHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPost, "/users/"+tc.userID+"/roles", tc.body, tc.userIDPtr) + req.SetPathValue("user_id", tc.userID) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + rolesRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[types.AssignUserRoleResponse](t, reqCtx) + assertAssignUserRoleResponseEqual(t, payload, tc.expectedBody.(types.AssignUserRoleResponse)) + + rolesRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) + }) + } } func TestRemoveUserRoleHandler(t *testing.T) { t.Parallel() - t.Run("error", func(t *testing.T) { - t.Parallel() - - useCase, roleRepo := tests.NewRolePermissionUseCaseFixture() - roleRepo.On("RemoveUserRole", mock.Anything, "user-1", "role-1").Return(accesscontrolconstants.ErrNotFound).Once() - handler := NewRemoveUserRoleHandler(useCase) - - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodDelete, "/access-control/users/user-1/roles/role-1", nil, nil) - req.SetPathValue("user_id", "user-1") - req.SetPathValue("role_id", "role-1") - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusNotFound, "not found") - roleRepo.AssertExpectations(t) - }) - - t.Run("success", func(t *testing.T) { - t.Parallel() - - useCase, roleRepo := tests.NewRolePermissionUseCaseFixture() - roleRepo.On("RemoveUserRole", mock.Anything, "user-1", "role-1").Return(nil).Once() - handler := NewRemoveUserRoleHandler(useCase) - - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodDelete, "/access-control/users/user-1/roles/role-1", nil, nil) - req.SetPathValue("user_id", "user-1") - req.SetPathValue("role_id", "role-1") - - handler.Handler()(w, req) - - if reqCtx.ResponseStatus != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, reqCtx.ResponseStatus) - } - payload := internaltests.DecodeResponseJSON[types.RemoveUserRoleResponse](t, reqCtx) - if payload.Message != "role removed" { - t.Fatalf("expected role removed message, got %v", payload.Message) - } - roleRepo.AssertExpectations(t) - }) + tests := []struct { + name string + userID string + roleID string + setupMock func(*accesscontroltests.MockUserRolesRepository) + expectedStatus int + expectedBody any + }{ + { + name: "blank role id", + userID: "user-1", + roleID: "", + expectedStatus: http.StatusBadRequest, + expectedBody: map[string]string{"message": "bad request"}, + }, + { + name: "use case error", + userID: "user-1", + roleID: "role-1", + setupMock: func(m *accesscontroltests.MockUserRolesRepository) { + m.On("RemoveUserRole", mock.Anything, "user-1", "role-1").Return(constants.ErrNotFound).Once() + }, + expectedStatus: http.StatusNotFound, + expectedBody: map[string]string{"message": "not found"}, + }, + { + name: "success", + userID: "user-1", + roleID: "role-1", + setupMock: func(m *accesscontroltests.MockUserRolesRepository) { + m.On("RemoveUserRole", mock.Anything, "user-1", "role-1").Return(nil).Once() + }, + expectedStatus: http.StatusOK, + expectedBody: types.RemoveUserRoleResponse{Message: "role removed"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + userRolesRepo := &accesscontroltests.MockUserRolesRepository{} + if tc.setupMock != nil { + tc.setupMock(userRolesRepo) + } + + rolesRepo := &accesscontroltests.MockRolesRepository{} + useCase := newUserRolesUseCase(rolesRepo, userRolesRepo) + handler := NewRemoveUserRoleHandler(useCase) + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodDelete, "/users/"+tc.userID+"/roles/"+tc.roleID, nil, nil) + req.SetPathValue("user_id", tc.userID) + req.SetPathValue("role_id", tc.roleID) + + handler.Handler()(w, req) + + if tc.expectedStatus != http.StatusOK { + internaltests.AssertErrorMessage(t, reqCtx, tc.expectedStatus, tc.expectedBody.(map[string]string)["message"]) + userRolesRepo.AssertExpectations(t) + rolesRepo.AssertExpectations(t) + return + } + + if reqCtx.ResponseStatus != tc.expectedStatus { + t.Fatalf("expected status %d, got %d", tc.expectedStatus, reqCtx.ResponseStatus) + } + + payload := internaltests.DecodeResponseJSON[types.RemoveUserRoleResponse](t, reqCtx) + assertRemoveUserRoleResponseEqual(t, payload, tc.expectedBody.(types.RemoveUserRoleResponse)) + + userRolesRepo.AssertExpectations(t) + rolesRepo.AssertExpectations(t) + }) + } } -func TestGetUserEffectivePermissionsHandler(t *testing.T) { - t.Parallel() - - t.Run("missing user id", func(t *testing.T) { - t.Parallel() - - useCase, _ := tests.NewUserRolesUseCaseFixture() - handler := NewGetUserEffectivePermissionsHandler(useCase) - - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/access-control/users//permissions", nil, nil) - req.SetPathValue("user_id", "") - - handler.Handler()(w, req) - - internaltests.AssertErrorMessage(t, reqCtx, http.StatusUnprocessableEntity, "unprocessable entity") - }) - - t.Run("error", func(t *testing.T) { - t.Parallel() - - useCase, accessRepo := tests.NewUserRolesUseCaseFixture() - accessRepo.On("GetUserEffectivePermissions", mock.Anything, "user-1").Return(([]types.UserPermissionInfo)(nil), accesscontrolconstants.ErrUnauthorized).Once() - handler := NewGetUserEffectivePermissionsHandler(useCase) - - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/access-control/users/user-1/permissions", nil, nil) - req.SetPathValue("user_id", "user-1") +func newUserRolesUseCase(rolesRepo *accesscontroltests.MockRolesRepository, userRolesRepo *accesscontroltests.MockUserRolesRepository) *usecases.UserRolesUseCase { + return usecases.NewUserRolesUseCase(services.NewUserRolesService(userRolesRepo, rolesRepo)) +} - handler.Handler()(w, req) +func assertUserRoleInfosEqual(t *testing.T, got []types.UserRoleInfo, want []types.UserRoleInfo) { + t.Helper() - internaltests.AssertErrorMessage(t, reqCtx, http.StatusUnauthorized, "unauthorized") - accessRepo.AssertExpectations(t) - }) + if len(got) != len(want) { + t.Fatalf("expected %d roles, got %d", len(want), len(got)) + } - t.Run("success", func(t *testing.T) { - t.Parallel() + for i := range want { + assertUserRoleInfoEqual(t, got[i], want[i]) + } +} - useCase, accessRepo := tests.NewUserRolesUseCaseFixture() - accessRepo.On("GetUserEffectivePermissions", mock.Anything, "user-1").Return([]types.UserPermissionInfo{{PermissionID: "perm-1", PermissionKey: "admin.read"}}, nil).Once() - handler := NewGetUserEffectivePermissionsHandler(useCase) +func assertUserRoleInfoEqual(t *testing.T, got types.UserRoleInfo, want types.UserRoleInfo) { + t.Helper() + + if got.RoleID != want.RoleID || got.RoleName != want.RoleName { + t.Fatalf("unexpected role info: %#v", got) + } + if !stringsEqualPtr(got.RoleDescription, want.RoleDescription) { + t.Fatalf("unexpected role description: %#v", got) + } + if !stringsEqualPtr(got.AssignedByUserID, want.AssignedByUserID) { + t.Fatalf("unexpected assigned_by_user_id: %#v", got) + } + if !timesEqualPtr(got.AssignedAt, want.AssignedAt) { + t.Fatalf("unexpected assigned_at: %#v", got) + } + if !timesEqualPtr(got.ExpiresAt, want.ExpiresAt) { + t.Fatalf("unexpected expires_at: %#v", got) + } +} - req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodGet, "/access-control/users/user-1/permissions", nil, nil) - req.SetPathValue("user_id", "user-1") +func assertReplaceUserRolesResponseEqual(t *testing.T, got types.ReplaceUserRolesResponse, want types.ReplaceUserRolesResponse) { + t.Helper() + if got.Message != want.Message { + t.Fatalf("expected message %q, got %q", want.Message, got.Message) + } +} - handler.Handler()(w, req) +func assertAssignUserRoleResponseEqual(t *testing.T, got types.AssignUserRoleResponse, want types.AssignUserRoleResponse) { + t.Helper() + if got.Message != want.Message { + t.Fatalf("expected message %q, got %q", want.Message, got.Message) + } +} - if reqCtx.ResponseStatus != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, reqCtx.ResponseStatus) - } - payload := internaltests.DecodeResponseJSON[types.GetUserEffectivePermissionsResponse](t, reqCtx) - if len(payload.Permissions) != 1 { - t.Fatalf("expected 1 permission, got %d", len(payload.Permissions)) - } - accessRepo.AssertExpectations(t) - }) +func assertRemoveUserRoleResponseEqual(t *testing.T, got types.RemoveUserRoleResponse, want types.RemoveUserRoleResponse) { + t.Helper() + if got.Message != want.Message { + t.Fatalf("expected message %q, got %q", want.Message, got.Message) + } } diff --git a/plugins/access-control/hooks.go b/plugins/access-control/hooks.go index ce06d9a..a1229cf 100644 --- a/plugins/access-control/hooks.go +++ b/plugins/access-control/hooks.go @@ -1,10 +1,12 @@ package accesscontrol import ( + "fmt" "net/http" "github.com/Authula/authula/internal/util" "github.com/Authula/authula/models" + "github.com/Authula/authula/plugins/access-control/types" ) type AccessControlHookID string @@ -19,6 +21,11 @@ func (id AccessControlHookID) String() string { func (p *AccessControlPlugin) Hooks() []models.Hook { return []models.Hook{ + { + Stage: models.HookAfter, + Handler: p.assignRoleFromContextHook, + Order: 20, + }, { Stage: models.HookBefore, PluginID: HookIDAccessControlEnforce.String(), @@ -28,6 +35,70 @@ func (p *AccessControlPlugin) Hooks() []models.Hook { } } +func (p *AccessControlPlugin) assignRoleFromContextHook(reqCtx *models.RequestContext) error { + rawValue, ok := reqCtx.Values[models.ContextAccessControlAssignRole.String()] + if !ok || rawValue == nil { + return nil + } + + assignCtx, ok := accessControlAssignRoleContext(rawValue) + if !ok || assignCtx.UserID == "" || assignCtx.RoleName == "" { + return nil + } + + ctx := reqCtx.Request.Context() + userRoles, err := p.Api.GetUserRoles(ctx, assignCtx.UserID) + if err != nil { + p.logAssignRoleHookError("failed to load user roles", assignCtx, err) + return nil + } + + for _, userRole := range userRoles { + if userRole.RoleName == assignCtx.RoleName { + return nil + } + } + + role, err := p.Api.GetRoleByName(ctx, assignCtx.RoleName) + if err != nil { + p.logAssignRoleHookError("failed to resolve role", assignCtx, err) + return nil + } + if role == nil || role.ID == "" { + p.logAssignRoleHookError("resolved role is empty", assignCtx, fmt.Errorf("role not found")) + return nil + } + + if err := p.Api.AssignRoleToUser(ctx, assignCtx.UserID, types.AssignUserRoleRequest{RoleID: role.ID}, nil); err != nil { + p.logAssignRoleHookError("failed to assign role", assignCtx, err) + } + + return nil +} + +func (p *AccessControlPlugin) logAssignRoleHookError(message string, assignCtx models.AccessControlAssignRoleContext, err error) { + p.logger.Error( + message, + "user_id", assignCtx.UserID, + "role_name", assignCtx.RoleName, + "error", err, + ) +} + +func accessControlAssignRoleContext(value any) (models.AccessControlAssignRoleContext, bool) { + switch typed := value.(type) { + case models.AccessControlAssignRoleContext: + return typed, true + case *models.AccessControlAssignRoleContext: + if typed == nil { + return models.AccessControlAssignRoleContext{}, false + } + return *typed, true + default: + return models.AccessControlAssignRoleContext{}, false + } +} + func (p *AccessControlPlugin) requireAccessControl(reqCtx *models.RequestContext) error { ctx := reqCtx.Request.Context() diff --git a/plugins/access-control/hooks_test.go b/plugins/access-control/hooks_test.go index e06c944..9ead6df 100644 --- a/plugins/access-control/hooks_test.go +++ b/plugins/access-control/hooks_test.go @@ -5,200 +5,131 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" - "github.com/Authula/authula/models" - "github.com/Authula/authula/plugins/access-control/tests" + internaltests "github.com/Authula/authula/internal/tests" + authmodels "github.com/Authula/authula/models" + "github.com/Authula/authula/plugins/access-control/services" + accesscontroltests "github.com/Authula/authula/plugins/access-control/tests" "github.com/Authula/authula/plugins/access-control/types" "github.com/Authula/authula/plugins/access-control/usecases" ) -func newHookTestPluginWithUserAccessRepoMock(t *testing.T) (*AccessControlPlugin, interface { - On(string, ...any) *mock.Call - AssertExpectations(mock.TestingT) bool -}) { - t.Helper() - - userAccessUseCase, userAccessRepo := tests.NewUserRolesUseCaseFixture() - useCases := usecases.NewAccessControlUseCases(usecases.RolePermissionUseCase{}, userAccessUseCase) - plugin := &AccessControlPlugin{Api: &API{useCases: useCases}} - - return plugin, userAccessRepo +func newAccessControlHookTestPlugin(logger authmodels.Logger, rolesRepo *accesscontroltests.MockRolesRepository, userRolesRepo *accesscontroltests.MockUserRolesRepository) *AccessControlPlugin { + rolePermissionsService := services.NewRolePermissionsService(nil, nil, nil) + useCases := usecases.NewAccessControlUseCases( + usecases.NewRolesUseCase(services.NewRolesService(rolesRepo, nil, userRolesRepo)), + usecases.NewPermissionsUseCase(nil), + usecases.NewRolePermissionsUseCase(rolePermissionsService), + usecases.NewUserRolesUseCase(services.NewUserRolesService(userRolesRepo, rolesRepo)), + usecases.NewUserPermissionsUseCase(nil), + ) + + return &AccessControlPlugin{ + Api: NewAPI(useCases), + logger: logger, + } } -func TestRequireAccessControlHook(t *testing.T) { +func TestAccessControlPluginHooksIncludesGlobalAssignRoleHook(t *testing.T) { t.Parallel() - t.Run("unauthorized without user", func(t *testing.T) { - t.Parallel() - - plugin := &AccessControlPlugin{} - req := httptest.NewRequest(http.MethodGet, "/resource", nil) - reqCtx := &models.RequestContext{Request: req} - - err := plugin.requireAccessControl(reqCtx) - require.NoError(t, err) - require.True(t, reqCtx.Handled, "expected request to be handled") - require.Equal(t, http.StatusUnauthorized, reqCtx.ResponseStatus) - }) - - t.Run("unauthorized with empty user id", func(t *testing.T) { - t.Parallel() + hooks := (&AccessControlPlugin{}).Hooks() + if len(hooks) != 2 { + t.Fatalf("expected 2 hooks, got %d", len(hooks)) + } - plugin := &AccessControlPlugin{} - userID := "" - req := httptest.NewRequest(http.MethodGet, "/resource", nil) - reqCtx := &models.RequestContext{ - Request: req, - UserID: &userID, + var foundGlobal bool + for _, hook := range hooks { + if hook.Stage == authmodels.HookAfter && hook.PluginID == "" && hook.Handler != nil { + foundGlobal = true } + } - err := plugin.requireAccessControl(reqCtx) - require.NoError(t, err) - require.True(t, reqCtx.Handled, "expected request to be handled") - require.Equal(t, http.StatusUnauthorized, reqCtx.ResponseStatus) - }) - - t.Run("opt in skips when no permissions metadata", func(t *testing.T) { - t.Parallel() - - plugin := &AccessControlPlugin{} - userID := "user-1" - req := httptest.NewRequest(http.MethodGet, "/resource", nil) - reqCtx := &models.RequestContext{ - Request: req, - Route: &models.Route{ - Metadata: map[string]any{}, - }, - UserID: &userID, - } + if !foundGlobal { + t.Fatal("expected a global HookAfter assignment hook") + } +} - err := plugin.requireAccessControl(reqCtx) - require.NoError(t, err) - require.False(t, reqCtx.Handled, "expected request not to be handled when no permissions metadata is set") - }) - - t.Run("internal error when HasPermissions fails", func(t *testing.T) { - t.Parallel() - - plugin, userAccessRepo := newHookTestPluginWithUserAccessRepoMock(t) - userID := "user-1" - userAccessRepo.On("GetUserEffectivePermissions", mock.Anything, userID).Return(nil, errors.New("database error")).Once() - - req := httptest.NewRequest(http.MethodGet, "/resource", nil) - reqCtx := &models.RequestContext{ - Request: req, - Route: &models.Route{ - Metadata: map[string]any{ - "permissions": []string{"admin"}, - }, - }, - UserID: &userID, - } +func TestAccessControlPluginAssignRoleFromContextHook(t *testing.T) { + t.Parallel() - err := plugin.requireAccessControl(reqCtx) - require.NoError(t, err) - require.True(t, reqCtx.Handled) - require.Equal(t, http.StatusInternalServerError, reqCtx.ResponseStatus) - userAccessRepo.AssertExpectations(t) - }) - - t.Run("forbidden when user lacks permissions", func(t *testing.T) { - t.Parallel() - - plugin, userAccessRepo := newHookTestPluginWithUserAccessRepoMock(t) - userID := "user-1" - userAccessRepo.On("GetUserEffectivePermissions", mock.Anything, userID).Return([]types.UserPermissionInfo{{PermissionID: "perm-1", PermissionKey: "profiles.read"}}, nil).Once() - - req := httptest.NewRequest(http.MethodGet, "/resource", nil) - reqCtx := &models.RequestContext{ - Request: req, - Route: &models.Route{ - Metadata: map[string]any{ - "permissions": []string{"users.read"}, - }, + tests := []struct { + name string + contextValue any + setup func(*accesscontroltests.MockRolesRepository, *accesscontroltests.MockUserRolesRepository) + }{ + { + name: "missing context is a no-op", + contextValue: nil, + }, + { + name: "already assigned role is skipped", + contextValue: authmodels.AccessControlAssignRoleContext{UserID: "user-1", RoleName: "Editor"}, + setup: func(rolesRepo *accesscontroltests.MockRolesRepository, userRolesRepo *accesscontroltests.MockUserRolesRepository) { + userRolesRepo.On("GetUserRoles", mock.Anything, "user-1").Return([]types.UserRoleInfo{{RoleID: "role-1", RoleName: "Editor"}}, nil).Once() }, - UserID: &userID, - } - - err := plugin.requireAccessControl(reqCtx) - require.NoError(t, err) - require.True(t, reqCtx.Handled, "expected request to be handled") - require.Equal(t, http.StatusForbidden, reqCtx.ResponseStatus) - userAccessRepo.AssertExpectations(t) - }) - - t.Run("allowed when user has permissions", func(t *testing.T) { - t.Parallel() - - plugin, userAccessRepo := newHookTestPluginWithUserAccessRepoMock(t) - userID := "user-1" - userAccessRepo.On("GetUserEffectivePermissions", mock.Anything, userID).Return([]types.UserPermissionInfo{{PermissionKey: "users.read"}}, nil).Once() - - req := httptest.NewRequest(http.MethodGet, "/resource", nil) - reqCtx := &models.RequestContext{ - Request: req, - Route: &models.Route{ - Metadata: map[string]any{ - "permissions": []string{"users.read"}, - }, + }, + { + name: "assigns role when missing", + contextValue: authmodels.AccessControlAssignRoleContext{UserID: "user-1", RoleName: "Editor"}, + setup: func(rolesRepo *accesscontroltests.MockRolesRepository, userRolesRepo *accesscontroltests.MockUserRolesRepository) { + userRolesRepo.On("GetUserRoles", mock.Anything, "user-1").Return([]types.UserRoleInfo{}, nil).Once() + rolesRepo.On("GetRoleByName", mock.Anything, "Editor").Return(&types.Role{ID: "role-1", Name: "Editor"}, nil).Once() + rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "Editor"}, nil).Once() + userRolesRepo.On("AssignUserRole", mock.Anything, "user-1", "role-1", (*string)(nil), (*time.Time)(nil)).Return(nil).Once() }, - UserID: &userID, - } - - err := plugin.requireAccessControl(reqCtx) - require.NoError(t, err) - require.False(t, reqCtx.Handled, "expected request not to be handled when user has permissions") - userAccessRepo.AssertExpectations(t) - }) - - t.Run("multiple permissions allow when user has any one required permission", func(t *testing.T) { - t.Parallel() - - plugin, userAccessRepo := newHookTestPluginWithUserAccessRepoMock(t) - userID := "user-1" - userAccessRepo.On("GetUserEffectivePermissions", mock.Anything, userID).Return([]types.UserPermissionInfo{{PermissionKey: "users.read"}}, nil).Once() - - req := httptest.NewRequest(http.MethodGet, "/resource", nil) - reqCtx := &models.RequestContext{ - Request: req, - Route: &models.Route{ - Metadata: map[string]any{ - "permissions": []string{"users.read", "users.write"}, - }, + }, + { + name: "role lookup failure is logged and ignored", + contextValue: authmodels.AccessControlAssignRoleContext{UserID: "user-1", RoleName: "Editor"}, + setup: func(rolesRepo *accesscontroltests.MockRolesRepository, userRolesRepo *accesscontroltests.MockUserRolesRepository) { + userRolesRepo.On("GetUserRoles", mock.Anything, "user-1").Return([]types.UserRoleInfo{}, nil).Once() + rolesRepo.On("GetRoleByName", mock.Anything, "Editor").Return((*types.Role)(nil), errors.New("lookup failed")).Once() }, - UserID: &userID, - } - - err := plugin.requireAccessControl(reqCtx) - require.NoError(t, err) - require.False(t, reqCtx.Handled, "expected request not to be handled when user has at least one required permission") - userAccessRepo.AssertExpectations(t) - }) - - t.Run("permissions metadata with whitespace trimming", func(t *testing.T) { - t.Parallel() - - plugin, userAccessRepo := newHookTestPluginWithUserAccessRepoMock(t) - userID := "user-1" - userAccessRepo.On("GetUserEffectivePermissions", mock.Anything, userID).Return([]types.UserPermissionInfo{{PermissionKey: "users.read"}}, nil).Once() - - req := httptest.NewRequest(http.MethodGet, "/resource", nil) - reqCtx := &models.RequestContext{ - Request: req, - Route: &models.Route{ - Metadata: map[string]any{ - "permissions": []string{" users.read ", " ", ""}, - }, + }, + { + name: "assignment failure is logged and ignored", + contextValue: authmodels.AccessControlAssignRoleContext{UserID: "user-1", RoleName: "Editor"}, + setup: func(rolesRepo *accesscontroltests.MockRolesRepository, userRolesRepo *accesscontroltests.MockUserRolesRepository) { + userRolesRepo.On("GetUserRoles", mock.Anything, "user-1").Return([]types.UserRoleInfo{}, nil).Once() + rolesRepo.On("GetRoleByName", mock.Anything, "Editor").Return(&types.Role{ID: "role-1", Name: "Editor"}, nil).Once() + rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "Editor"}, nil).Once() + userRolesRepo.On("AssignUserRole", mock.Anything, "user-1", "role-1", (*string)(nil), (*time.Time)(nil)).Return(errors.New("assign failed")).Once() }, - UserID: &userID, - } - - err := plugin.requireAccessControl(reqCtx) - require.NoError(t, err) - require.False(t, reqCtx.Handled, "expected request not to be handled when user has permissions (after trimming)") - userAccessRepo.AssertExpectations(t) - }) + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + rolesRepo := &accesscontroltests.MockRolesRepository{} + userRolesRepo := &accesscontroltests.MockUserRolesRepository{} + if tc.setup != nil { + tc.setup(rolesRepo, userRolesRepo) + } + + plugin := newAccessControlHookTestPlugin(&internaltests.MockLogger{}, rolesRepo, userRolesRepo) + + req := httptest.NewRequest(http.MethodPost, "/test", nil) + reqCtx := &authmodels.RequestContext{ + Request: req, + Values: map[string]any{}, + } + if tc.contextValue != nil { + reqCtx.Values[authmodels.ContextAccessControlAssignRole.String()] = tc.contextValue + } + + err := plugin.assignRoleFromContextHook(reqCtx) + if err != nil { + t.Fatalf("expected nil error, got %v", err) + } + + rolesRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) + }) + } } diff --git a/plugins/access-control/migrations.go b/plugins/access-control/migrations.go index ef6a0e9..d2faacc 100644 --- a/plugins/access-control/migrations.go +++ b/plugins/access-control/migrations.go @@ -1,242 +1,10 @@ package accesscontrol import ( - "context" - - "github.com/uptrace/bun" - "github.com/Authula/authula/migrations" + "github.com/Authula/authula/plugins/access-control/migrationset" ) func accessControlMigrationsForProvider(provider string) []migrations.Migration { - return migrations.ForProvider(provider, migrations.ProviderVariants{ - "sqlite": func() []migrations.Migration { return []migrations.Migration{accessControlSQLiteInitial()} }, - "postgres": func() []migrations.Migration { return []migrations.Migration{accessControlPostgresInitial()} }, - "mysql": func() []migrations.Migration { return []migrations.Migration{accessControlMySQLInitial()} }, - }) -} - -func accessControlSQLiteInitial() migrations.Migration { - return migrations.Migration{ - Version: "20260309000000_access_control_initial", - Up: func(ctx context.Context, tx bun.Tx) error { - return migrations.ExecStatements( - ctx, - tx, - `PRAGMA foreign_keys = ON;`, - // ------------------------------- - `CREATE TABLE IF NOT EXISTS access_control_roles ( - id TEXT PRIMARY KEY, - name VARCHAR(255) NOT NULL UNIQUE, - description TEXT, - is_system BOOLEAN NOT NULL DEFAULT 0, - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP - );`, - // ------------------------------- - `CREATE TABLE IF NOT EXISTS access_control_permissions ( - id TEXT PRIMARY KEY, - key VARCHAR(255) NOT NULL UNIQUE, - description TEXT, - is_system BOOLEAN NOT NULL DEFAULT 0, - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP - );`, - // ------------------------------- - `CREATE TABLE IF NOT EXISTS access_control_role_permissions ( - role_id TEXT NOT NULL, - permission_id TEXT NOT NULL, - granted_by_user_id TEXT, - granted_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - PRIMARY KEY (role_id, permission_id), - FOREIGN KEY (role_id) REFERENCES access_control_roles(id) ON DELETE CASCADE, - FOREIGN KEY (permission_id) REFERENCES access_control_permissions(id) ON DELETE CASCADE, - FOREIGN KEY (granted_by_user_id) REFERENCES users(id) ON DELETE SET NULL - );`, - `CREATE INDEX IF NOT EXISTS idx_access_control_role_permissions_role_id ON access_control_role_permissions(role_id);`, - `CREATE INDEX IF NOT EXISTS idx_access_control_role_permissions_permission_id ON access_control_role_permissions(permission_id);`, - // ------------------------------- - `CREATE TABLE IF NOT EXISTS access_control_user_roles ( - user_id TEXT NOT NULL, - role_id TEXT NOT NULL, - assigned_by_user_id TEXT, - assigned_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - expires_at TIMESTAMP, - PRIMARY KEY (user_id, role_id), - FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, - FOREIGN KEY (role_id) REFERENCES access_control_roles(id) ON DELETE CASCADE, - FOREIGN KEY (assigned_by_user_id) REFERENCES users(id) ON DELETE SET NULL - );`, - `CREATE INDEX IF NOT EXISTS idx_access_control_user_roles_role_id ON access_control_user_roles(role_id);`, - `CREATE INDEX IF NOT EXISTS idx_access_control_user_roles_expires_at ON access_control_user_roles(expires_at);`, - // ------------------------------- - ) - }, - Down: func(ctx context.Context, tx bun.Tx) error { - return migrations.ExecStatements( - ctx, - tx, - `DROP TABLE IF EXISTS access_control_user_roles;`, - `DROP TABLE IF EXISTS access_control_role_permissions;`, - `DROP TABLE IF EXISTS access_control_permissions;`, - `DROP TABLE IF EXISTS access_control_roles;`, - ) - }, - } -} - -func accessControlPostgresInitial() migrations.Migration { - return migrations.Migration{ - Version: "20260309000000_access_control_initial", - Up: func(ctx context.Context, tx bun.Tx) error { - return migrations.ExecStatements( - ctx, - tx, - `CREATE OR REPLACE FUNCTION access_control_set_updated_at_fn() - RETURNS TRIGGER AS $$ - BEGIN - NEW.updated_at = NOW(); - RETURN NEW; - END; - - $$ LANGUAGE plpgsql;`, - // ------------------------------- - `CREATE TABLE IF NOT EXISTS access_control_roles ( - id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - name VARCHAR(255) NOT NULL UNIQUE, - description TEXT, - is_system BOOLEAN NOT NULL DEFAULT FALSE, - created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), - updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW() - );`, - `DROP TRIGGER IF EXISTS update_access_control_roles_updated_at_trigger ON access_control_roles;`, - `CREATE TRIGGER update_access_control_roles_updated_at_trigger - BEFORE UPDATE ON access_control_roles - FOR EACH ROW - EXECUTE FUNCTION access_control_set_updated_at_fn();`, - // ------------------------------- - `CREATE TABLE IF NOT EXISTS access_control_permissions ( - id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - key VARCHAR(255) NOT NULL UNIQUE, - description TEXT, - is_system BOOLEAN NOT NULL DEFAULT FALSE, - created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), - updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW() - );`, - `DROP TRIGGER IF EXISTS update_access_control_permissions_updated_at_trigger ON access_control_permissions;`, - `CREATE TRIGGER update_access_control_permissions_updated_at_trigger - BEFORE UPDATE ON access_control_permissions - FOR EACH ROW - EXECUTE FUNCTION access_control_set_updated_at_fn();`, - // ------------------------------- - `CREATE TABLE IF NOT EXISTS access_control_role_permissions ( - role_id UUID NOT NULL, - permission_id UUID NOT NULL, - granted_by_user_id UUID, - granted_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), - PRIMARY KEY (role_id, permission_id), - CONSTRAINT fk_access_control_role_permissions_role FOREIGN KEY (role_id) REFERENCES access_control_roles(id) ON DELETE CASCADE, - CONSTRAINT fk_access_control_role_permissions_permission FOREIGN KEY (permission_id) REFERENCES access_control_permissions(id) ON DELETE CASCADE, - CONSTRAINT fk_access_control_role_permissions_granted_by FOREIGN KEY (granted_by_user_id) REFERENCES users(id) ON DELETE SET NULL - );`, - `CREATE INDEX IF NOT EXISTS idx_access_control_role_permissions_role_id ON access_control_role_permissions(role_id);`, - `CREATE INDEX IF NOT EXISTS idx_access_control_role_permissions_permission_id ON access_control_role_permissions(permission_id);`, - // ------------------------------- - `CREATE TABLE IF NOT EXISTS access_control_user_roles ( - user_id UUID NOT NULL, - role_id UUID NOT NULL, - assigned_by_user_id UUID, - assigned_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), - expires_at TIMESTAMP WITH TIME ZONE, - PRIMARY KEY (user_id, role_id), - CONSTRAINT fk_access_control_user_roles_user FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, - CONSTRAINT fk_access_control_user_roles_role FOREIGN KEY (role_id) REFERENCES access_control_roles(id) ON DELETE CASCADE, - CONSTRAINT fk_access_control_user_roles_assigned_by FOREIGN KEY (assigned_by_user_id) REFERENCES users(id) ON DELETE SET NULL - );`, - `CREATE INDEX IF NOT EXISTS idx_access_control_user_roles_role_id ON access_control_user_roles(role_id);`, - `CREATE INDEX IF NOT EXISTS idx_access_control_user_roles_expires_at ON access_control_user_roles(expires_at);`, - // ------------------------------- - ) - }, - Down: func(ctx context.Context, tx bun.Tx) error { - return migrations.ExecStatements( - ctx, - tx, - `DROP TABLE IF EXISTS access_control_user_roles;`, - `DROP TABLE IF EXISTS access_control_role_permissions;`, - `DROP TRIGGER IF EXISTS update_access_control_roles_updated_at_trigger ON access_control_roles;`, - `DROP TRIGGER IF EXISTS update_access_control_permissions_updated_at_trigger ON access_control_permissions;`, - `DROP TABLE IF EXISTS access_control_permissions;`, - `DROP TABLE IF EXISTS access_control_roles;`, - `DROP FUNCTION IF EXISTS access_control_set_updated_at_fn();`, - ) - }, - } -} - -func accessControlMySQLInitial() migrations.Migration { - return migrations.Migration{ - Version: "20260309000000_access_control_initial", - Up: func(ctx context.Context, tx bun.Tx) error { - return migrations.ExecStatements( - ctx, - tx, - `CREATE TABLE IF NOT EXISTS access_control_roles ( - id BINARY(16) NOT NULL PRIMARY KEY, - name VARCHAR(255) NOT NULL UNIQUE, - description TEXT, - is_system TINYINT(1) NOT NULL DEFAULT 0, - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;`, - // ------------------------------- - `CREATE TABLE IF NOT EXISTS access_control_permissions ( - id BINARY(16) NOT NULL PRIMARY KEY, - key VARCHAR(255) NOT NULL UNIQUE, - description TEXT, - is_system TINYINT(1) NOT NULL DEFAULT 0, - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;`, - // ------------------------------- - `CREATE TABLE IF NOT EXISTS access_control_role_permissions ( - role_id BINARY(16) NOT NULL, - permission_id BINARY(16) NOT NULL, - granted_by_user_id BINARY(16) NULL, - granted_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - PRIMARY KEY (role_id, permission_id), - CONSTRAINT fk_access_control_role_permissions_role_id FOREIGN KEY (role_id) REFERENCES access_control_roles(id) ON DELETE CASCADE, - CONSTRAINT fk_access_control_role_permissions_permission_id FOREIGN KEY (permission_id) REFERENCES access_control_permissions(id) ON DELETE CASCADE, - CONSTRAINT fk_access_control_role_permissions_granted_by_user_id FOREIGN KEY (granted_by_user_id) REFERENCES users(id) ON DELETE SET NULL, - INDEX idx_access_control_role_permissions_role_id (role_id), - INDEX idx_access_control_role_permissions_permission_id (permission_id) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;`, - // ------------------------------- - `CREATE TABLE IF NOT EXISTS access_control_user_roles ( - user_id BINARY(16) NOT NULL, - role_id BINARY(16) NOT NULL, - assigned_by_user_id BINARY(16) NULL, - assigned_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - expires_at TIMESTAMP NULL, - PRIMARY KEY (user_id, role_id), - CONSTRAINT fk_access_control_user_roles_user_id FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, - CONSTRAINT fk_access_control_user_roles_role_id FOREIGN KEY (role_id) REFERENCES access_control_roles(id) ON DELETE CASCADE, - CONSTRAINT fk_access_control_user_roles_assigned_by_user_id FOREIGN KEY (assigned_by_user_id) REFERENCES users(id) ON DELETE SET NULL, - INDEX idx_access_control_user_roles_role_id (role_id), - INDEX idx_access_control_user_roles_expires_at (expires_at) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;`, - // ------------------------------- - ) - }, - Down: func(ctx context.Context, tx bun.Tx) error { - return migrations.ExecStatements( - ctx, - tx, - `DROP TABLE IF EXISTS access_control_user_roles;`, - `DROP TABLE IF EXISTS access_control_role_permissions;`, - `DROP TABLE IF EXISTS access_control_permissions;`, - `DROP TABLE IF EXISTS access_control_roles;`, - ) - }, - } + return migrationset.ForProvider(provider) } diff --git a/plugins/access-control/migrationset/migrations.go b/plugins/access-control/migrationset/migrations.go new file mode 100644 index 0000000..02f1efd --- /dev/null +++ b/plugins/access-control/migrationset/migrations.go @@ -0,0 +1,241 @@ +package migrationset + +import ( + "context" + + "github.com/uptrace/bun" + + "github.com/Authula/authula/migrations" +) + +func ForProvider(provider string) []migrations.Migration { + return migrations.ForProvider(provider, migrations.ProviderVariants{ + "sqlite": func() []migrations.Migration { return []migrations.Migration{accessControlSQLiteInitial()} }, + "postgres": func() []migrations.Migration { return []migrations.Migration{accessControlPostgresInitial()} }, + "mysql": func() []migrations.Migration { return []migrations.Migration{accessControlMySQLInitial()} }, + }) +} + +func accessControlSQLiteInitial() migrations.Migration { + return migrations.Migration{ + Version: "20260309000000_access_control_initial", + Up: func(ctx context.Context, tx bun.Tx) error { + return migrations.ExecStatements( + ctx, + tx, + `PRAGMA foreign_keys = ON;`, + // ------------------------------- + `CREATE TABLE IF NOT EXISTS access_control_roles ( + id TEXT PRIMARY KEY, + name VARCHAR(255) NOT NULL UNIQUE, + description TEXT, + is_system BOOLEAN NOT NULL DEFAULT 0, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + );`, + // ------------------------------- + `CREATE TABLE IF NOT EXISTS access_control_permissions ( + id TEXT PRIMARY KEY, + key VARCHAR(255) NOT NULL UNIQUE, + description TEXT, + is_system BOOLEAN NOT NULL DEFAULT 0, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + );`, + // ------------------------------- + `CREATE TABLE IF NOT EXISTS access_control_role_permissions ( + role_id TEXT NOT NULL, + permission_id TEXT NOT NULL, + granted_by_user_id TEXT, + granted_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (role_id, permission_id), + FOREIGN KEY (role_id) REFERENCES access_control_roles(id) ON DELETE CASCADE, + FOREIGN KEY (permission_id) REFERENCES access_control_permissions(id) ON DELETE CASCADE, + FOREIGN KEY (granted_by_user_id) REFERENCES users(id) ON DELETE SET NULL + );`, + `CREATE INDEX IF NOT EXISTS idx_access_control_role_permissions_role_id ON access_control_role_permissions(role_id);`, + `CREATE INDEX IF NOT EXISTS idx_access_control_role_permissions_permission_id ON access_control_role_permissions(permission_id);`, + // ------------------------------- + `CREATE TABLE IF NOT EXISTS access_control_user_roles ( + user_id TEXT NOT NULL, + role_id TEXT NOT NULL, + assigned_by_user_id TEXT, + assigned_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + expires_at TIMESTAMP, + PRIMARY KEY (user_id, role_id), + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, + FOREIGN KEY (role_id) REFERENCES access_control_roles(id) ON DELETE CASCADE, + FOREIGN KEY (assigned_by_user_id) REFERENCES users(id) ON DELETE SET NULL + );`, + `CREATE INDEX IF NOT EXISTS idx_access_control_user_roles_role_id ON access_control_user_roles(role_id);`, + `CREATE INDEX IF NOT EXISTS idx_access_control_user_roles_expires_at ON access_control_user_roles(expires_at);`, + // ------------------------------- + ) + }, + Down: func(ctx context.Context, tx bun.Tx) error { + return migrations.ExecStatements( + ctx, + tx, + `DROP TABLE IF EXISTS access_control_user_roles;`, + `DROP TABLE IF EXISTS access_control_role_permissions;`, + `DROP TABLE IF EXISTS access_control_permissions;`, + `DROP TABLE IF EXISTS access_control_roles;`, + ) + }, + } +} + +func accessControlPostgresInitial() migrations.Migration { + return migrations.Migration{ + Version: "20260309000000_access_control_initial", + Up: func(ctx context.Context, tx bun.Tx) error { + return migrations.ExecStatements( + ctx, + tx, + `CREATE OR REPLACE FUNCTION access_control_set_updated_at_fn() + RETURNS TRIGGER AS $$ + BEGIN + NEW.updated_at = NOW(); + RETURN NEW; + END; + $$ LANGUAGE plpgsql;`, + // ------------------------------- + `CREATE TABLE IF NOT EXISTS access_control_roles ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + name VARCHAR(255) NOT NULL UNIQUE, + description TEXT, + is_system BOOLEAN NOT NULL DEFAULT FALSE, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW() + );`, + `DROP TRIGGER IF EXISTS update_access_control_roles_updated_at_trigger ON access_control_roles;`, + `CREATE TRIGGER update_access_control_roles_updated_at_trigger + BEFORE UPDATE ON access_control_roles + FOR EACH ROW + EXECUTE FUNCTION access_control_set_updated_at_fn();`, + // ------------------------------- + `CREATE TABLE IF NOT EXISTS access_control_permissions ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + key VARCHAR(255) NOT NULL UNIQUE, + description TEXT, + is_system BOOLEAN NOT NULL DEFAULT FALSE, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW() + );`, + `DROP TRIGGER IF EXISTS update_access_control_permissions_updated_at_trigger ON access_control_permissions;`, + `CREATE TRIGGER update_access_control_permissions_updated_at_trigger + BEFORE UPDATE ON access_control_permissions + FOR EACH ROW + EXECUTE FUNCTION access_control_set_updated_at_fn();`, + // ------------------------------- + `CREATE TABLE IF NOT EXISTS access_control_role_permissions ( + role_id UUID NOT NULL, + permission_id UUID NOT NULL, + granted_by_user_id UUID, + granted_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + PRIMARY KEY (role_id, permission_id), + CONSTRAINT fk_access_control_role_permissions_role FOREIGN KEY (role_id) REFERENCES access_control_roles(id) ON DELETE CASCADE, + CONSTRAINT fk_access_control_role_permissions_permission FOREIGN KEY (permission_id) REFERENCES access_control_permissions(id) ON DELETE CASCADE, + CONSTRAINT fk_access_control_role_permissions_granted_by FOREIGN KEY (granted_by_user_id) REFERENCES users(id) ON DELETE SET NULL + );`, + `CREATE INDEX IF NOT EXISTS idx_access_control_role_permissions_role_id ON access_control_role_permissions(role_id);`, + `CREATE INDEX IF NOT EXISTS idx_access_control_role_permissions_permission_id ON access_control_role_permissions(permission_id);`, + // ------------------------------- + `CREATE TABLE IF NOT EXISTS access_control_user_roles ( + user_id UUID NOT NULL, + role_id UUID NOT NULL, + assigned_by_user_id UUID, + assigned_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + expires_at TIMESTAMP WITH TIME ZONE, + PRIMARY KEY (user_id, role_id), + CONSTRAINT fk_access_control_user_roles_user FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, + CONSTRAINT fk_access_control_user_roles_role FOREIGN KEY (role_id) REFERENCES access_control_roles(id) ON DELETE CASCADE, + CONSTRAINT fk_access_control_user_roles_assigned_by FOREIGN KEY (assigned_by_user_id) REFERENCES users(id) ON DELETE SET NULL + );`, + `CREATE INDEX IF NOT EXISTS idx_access_control_user_roles_role_id ON access_control_user_roles(role_id);`, + `CREATE INDEX IF NOT EXISTS idx_access_control_user_roles_expires_at ON access_control_user_roles(expires_at);`, + // ------------------------------- + ) + }, + Down: func(ctx context.Context, tx bun.Tx) error { + return migrations.ExecStatements( + ctx, + tx, + `DROP TABLE IF EXISTS access_control_user_roles;`, + `DROP TABLE IF EXISTS access_control_role_permissions;`, + `DROP TRIGGER IF EXISTS update_access_control_roles_updated_at_trigger ON access_control_roles;`, + `DROP TRIGGER IF EXISTS update_access_control_permissions_updated_at_trigger ON access_control_permissions;`, + `DROP TABLE IF EXISTS access_control_permissions;`, + `DROP TABLE IF EXISTS access_control_roles;`, + `DROP FUNCTION IF EXISTS access_control_set_updated_at_fn();`, + ) + }, + } +} + +func accessControlMySQLInitial() migrations.Migration { + return migrations.Migration{ + Version: "20260309000000_access_control_initial", + Up: func(ctx context.Context, tx bun.Tx) error { + return migrations.ExecStatements( + ctx, + tx, + `CREATE TABLE IF NOT EXISTS access_control_roles ( + id BINARY(16) NOT NULL PRIMARY KEY, + name VARCHAR(255) NOT NULL UNIQUE, + description TEXT, + is_system TINYINT(1) NOT NULL DEFAULT 0, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;`, + // ------------------------------- + `CREATE TABLE IF NOT EXISTS access_control_permissions ( + id BINARY(16) NOT NULL PRIMARY KEY, + key VARCHAR(255) NOT NULL UNIQUE, + description TEXT, + is_system TINYINT(1) NOT NULL DEFAULT 0, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;`, + // ------------------------------- + `CREATE TABLE IF NOT EXISTS access_control_role_permissions ( + role_id BINARY(16) NOT NULL, + permission_id BINARY(16) NOT NULL, + granted_by_user_id BINARY(16) NULL, + granted_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (role_id, permission_id), + CONSTRAINT fk_access_control_role_permissions_role_id FOREIGN KEY (role_id) REFERENCES access_control_roles(id) ON DELETE CASCADE, + CONSTRAINT fk_access_control_role_permissions_permission_id FOREIGN KEY (permission_id) REFERENCES access_control_permissions(id) ON DELETE CASCADE, + CONSTRAINT fk_access_control_role_permissions_granted_by_user_id FOREIGN KEY (granted_by_user_id) REFERENCES users(id) ON DELETE SET NULL, + INDEX idx_access_control_role_permissions_role_id (role_id), + INDEX idx_access_control_role_permissions_permission_id (permission_id) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;`, + // ------------------------------- + `CREATE TABLE IF NOT EXISTS access_control_user_roles ( + user_id BINARY(16) NOT NULL, + role_id BINARY(16) NOT NULL, + assigned_by_user_id BINARY(16) NULL, + assigned_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + expires_at TIMESTAMP NULL, + PRIMARY KEY (user_id, role_id), + CONSTRAINT fk_access_control_user_roles_user_id FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, + CONSTRAINT fk_access_control_user_roles_role_id FOREIGN KEY (role_id) REFERENCES access_control_roles(id) ON DELETE CASCADE, + CONSTRAINT fk_access_control_user_roles_assigned_by_user_id FOREIGN KEY (assigned_by_user_id) REFERENCES users(id) ON DELETE SET NULL, + INDEX idx_access_control_user_roles_role_id (role_id), + INDEX idx_access_control_user_roles_expires_at (expires_at) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;`, + // ------------------------------- + ) + }, + Down: func(ctx context.Context, tx bun.Tx) error { + return migrations.ExecStatements( + ctx, + tx, + `DROP TABLE IF EXISTS access_control_user_roles;`, + `DROP TABLE IF EXISTS access_control_role_permissions;`, + `DROP TABLE IF EXISTS access_control_permissions;`, + `DROP TABLE IF EXISTS access_control_roles;`, + ) + }, + } +} diff --git a/plugins/access-control/plugin.go b/plugins/access-control/plugin.go index 25d4841..3a94b31 100644 --- a/plugins/access-control/plugin.go +++ b/plugins/access-control/plugin.go @@ -42,20 +42,26 @@ func (p *AccessControlPlugin) Init(ctx *models.PluginContext) error { return err } - rolePermissionRepo := repositories.NewBunRolePermissionRepository(ctx.DB) - userAccessRepo := repositories.NewBunUserAccessRepository(ctx.DB) - rolePermissionService := services.NewRolePermissionService(rolePermissionRepo) - userAccessService := services.NewUserAccessService(userAccessRepo) + rolesRepo := repositories.NewBunRolesRepository(ctx.DB) + permissionsRepo := repositories.NewBunPermissionsRepository(ctx.DB) + rolePermissionsRepo := repositories.NewBunRolePermissionsRepository(ctx.DB) + userRolesRepo := repositories.NewBunUserRolesRepository(ctx.DB) + userPermissionsRepo := repositories.NewBunUserPermissionsRepository(ctx.DB) + + rolesService := services.NewRolesService(rolesRepo, rolePermissionsRepo, userRolesRepo) + permissionsService := services.NewPermissionsService(permissionsRepo, rolePermissionsRepo) + rolePermissionsService := services.NewRolePermissionsService(rolesRepo, permissionsRepo, rolePermissionsRepo) + userRolesService := services.NewUserRolesService(userRolesRepo, rolesRepo) + userPermissionsService := services.NewUserPermissionsService(userPermissionsRepo) useCases := usecases.NewAccessControlUseCases( - usecases.NewRolePermissionUseCase(rolePermissionService), - usecases.NewUserRolesUseCase(userAccessService), - ) - p.Api = NewAPI( - useCases, - rolePermissionRepo, - userAccessRepo, + usecases.NewRolesUseCase(rolesService), + usecases.NewPermissionsUseCase(permissionsService), + usecases.NewRolePermissionsUseCase(rolePermissionsService), + usecases.NewUserRolesUseCase(userRolesService), + usecases.NewUserPermissionsUseCase(userPermissionsService), ) + p.Api = NewAPI(useCases) return nil } diff --git a/plugins/access-control/repositories/interfaces.go b/plugins/access-control/repositories/interfaces.go index 7ed2387..34a7fc8 100644 --- a/plugins/access-control/repositories/interfaces.go +++ b/plugins/access-control/repositories/interfaces.go @@ -7,31 +7,41 @@ import ( "github.com/Authula/authula/plugins/access-control/types" ) -type RolePermissionRepository interface { +type RolesRepository interface { CreateRole(ctx context.Context, role *types.Role) error GetAllRoles(ctx context.Context) ([]types.Role, error) GetRoleByID(ctx context.Context, roleID string) (*types.Role, error) + GetRoleByName(ctx context.Context, roleName string) (*types.Role, error) UpdateRole(ctx context.Context, roleID string, name *string, description *string) (bool, error) DeleteRole(ctx context.Context, roleID string) (bool, error) +} + +type PermissionsRepository interface { GetAllPermissions(ctx context.Context) ([]types.Permission, error) GetPermissionByID(ctx context.Context, permissionID string) (*types.Permission, error) + GetPermissionByKey(ctx context.Context, permissionKey string) (*types.Permission, error) CreatePermission(ctx context.Context, permission *types.Permission) error UpdatePermission(ctx context.Context, permissionID string, description *string) (bool, error) DeletePermission(ctx context.Context, permissionID string) (bool, error) +} + +type RolePermissionsRepository interface { GetRolePermissions(ctx context.Context, roleID string) ([]types.UserPermissionInfo, error) ReplaceRolePermissions(ctx context.Context, roleID string, permissionIDs []string, grantedByUserID *string) error AddRolePermission(ctx context.Context, roleID string, permissionID string, grantedByUserID *string) error RemoveRolePermission(ctx context.Context, roleID string, permissionID string) error + CountRolesByPermission(ctx context.Context, permissionID string) (int, error) +} + +type UserRolesRepository interface { + GetUserRoles(ctx context.Context, userID string) ([]types.UserRoleInfo, error) ReplaceUserRoles(ctx context.Context, userID string, roleIDs []string, assignedByUserID *string) error AssignUserRole(ctx context.Context, userID string, roleID string, assignedByUserID *string, expiresAt *time.Time) error RemoveUserRole(ctx context.Context, userID string, roleID string) error - CountUserAssignmentsByRoleID(ctx context.Context, roleID string) (int, error) - CountRoleAssignmentsByPermissionID(ctx context.Context, permissionID string) (int, error) + CountUsersByRole(ctx context.Context, roleID string) (int, error) } -type UserAccessRepository interface { - GetUserRoles(ctx context.Context, userID string) ([]types.UserRoleInfo, error) - GetUserEffectivePermissions(ctx context.Context, userID string) ([]types.UserPermissionInfo, error) - GetUserWithRolesByID(ctx context.Context, userID string) (*types.UserWithRoles, error) - GetUserWithPermissionsByID(ctx context.Context, userID string) (*types.UserWithPermissions, error) +type UserPermissionsRepository interface { + GetUserPermissions(ctx context.Context, userID string) ([]types.UserPermissionInfo, error) + HasPermissions(ctx context.Context, userID string, permissionKeys []string) (bool, error) } diff --git a/plugins/access-control/repositories/permissions_repository.go b/plugins/access-control/repositories/permissions_repository.go new file mode 100644 index 0000000..4886ad4 --- /dev/null +++ b/plugins/access-control/repositories/permissions_repository.go @@ -0,0 +1,99 @@ +package repositories + +import ( + "context" + "database/sql" + "fmt" + "time" + + "github.com/uptrace/bun" + + "github.com/Authula/authula/plugins/access-control/types" +) + +type BunPermissionsRepository struct { + db bun.IDB +} + +func NewBunPermissionsRepository(db bun.IDB) *BunPermissionsRepository { + return &BunPermissionsRepository{db: db} +} + +func (r *BunPermissionsRepository) CreatePermission(ctx context.Context, permission *types.Permission) error { + if _, err := r.db.NewInsert().Model(permission).Exec(ctx); err != nil { + return err + } + return nil +} + +func (r *BunPermissionsRepository) GetAllPermissions(ctx context.Context) ([]types.Permission, error) { + permissions := make([]types.Permission, 0) + err := r.db.NewSelect().Model(&permissions).Order("created_at ASC").Scan(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get permissions: %w", err) + } + return permissions, nil +} + +func (r *BunPermissionsRepository) GetPermissionByID(ctx context.Context, permissionID string) (*types.Permission, error) { + permission := new(types.Permission) + err := r.db.NewSelect().Model(permission).Where("id = ?", permissionID).Scan(ctx) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, fmt.Errorf("failed to get permission by id: %w", err) + } + + return permission, nil +} + +func (r *BunPermissionsRepository) GetPermissionByKey(ctx context.Context, permissionKey string) (*types.Permission, error) { + permission := new(types.Permission) + err := r.db.NewSelect().Model(permission).Where("key = ?", permissionKey).Scan(ctx) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, fmt.Errorf("failed to get permission by key: %w", err) + } + + return permission, nil +} + +func (r *BunPermissionsRepository) UpdatePermission(ctx context.Context, permissionID string, description *string) (bool, error) { + query := r.db.NewUpdate(). + Model((*types.Permission)(nil)). + Set("updated_at = ?", time.Now().UTC()). + Where("id = ?", permissionID) + + if description != nil { + query = query.Set("description = ?", *description) + } + + result, err := query.Exec(ctx) + if err != nil { + return false, err + } + + affected, err := result.RowsAffected() + if err != nil { + return false, fmt.Errorf("failed to determine updated rows: %w", err) + } + + return affected > 0, nil +} + +func (r *BunPermissionsRepository) DeletePermission(ctx context.Context, permissionID string) (bool, error) { + result, err := r.db.NewDelete().Model((*types.Permission)(nil)).Where("id = ?", permissionID).Exec(ctx) + if err != nil { + return false, err + } + + affected, err := result.RowsAffected() + if err != nil { + return false, fmt.Errorf("failed to determine deleted rows: %w", err) + } + + return affected > 0, nil +} diff --git a/plugins/access-control/repositories/permissions_repository_test.go b/plugins/access-control/repositories/permissions_repository_test.go new file mode 100644 index 0000000..06152c0 --- /dev/null +++ b/plugins/access-control/repositories/permissions_repository_test.go @@ -0,0 +1,273 @@ +package repositories + +import ( + "context" + "strings" + "testing" + + plugintests "github.com/Authula/authula/plugins/access-control/tests" + "github.com/Authula/authula/plugins/access-control/types" +) + +func TestBunPermissionsRepositoryCreatePermission(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + permission *types.Permission + wantErr error + wantID string + wantKey string + wantDescription *string + }{ + { + name: "success", + permission: &types.Permission{ID: "p1", Key: "users.read", Description: new("Read users"), IsSystem: false}, + wantID: "p1", + wantKey: "users.read", + wantDescription: new("Read users"), + }, + { + name: "duplicate key returns conflict", + permission: &types.Permission{ID: "p2", Key: "users.read", Description: new("Duplicate"), IsSystem: false}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + repo := NewBunPermissionsRepository(db) + ctx := context.Background() + + if tc.name == "duplicate key returns conflict" { + if err := repo.CreatePermission(ctx, &types.Permission{ID: "p1", Key: "users.read", Description: new("Read users"), IsSystem: false}); err != nil { + t.Fatalf("failed to seed permission: %v", err) + } + } + + err := repo.CreatePermission(ctx, tc.permission) + if tc.name == "duplicate key returns conflict" { + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "UNIQUE constraint failed: access_control_permissions.key") { + t.Fatalf("expected raw unique constraint error, got %v", err) + } + return + } + if err != tc.wantErr { + t.Fatalf("expected err %v, got %v", tc.wantErr, err) + } + + if tc.wantErr != nil { + return + } + + stored, err := repo.GetPermissionByID(ctx, tc.wantID) + if err != nil { + t.Fatalf("failed to fetch stored permission: %v", err) + } + if stored == nil { + t.Fatal("expected stored permission, got nil") + } + if stored.ID != tc.wantID || stored.Key != tc.wantKey { + t.Fatalf("unexpected stored permission: %#v", stored) + } + if tc.wantDescription != nil { + if stored.Description == nil || *stored.Description != *tc.wantDescription { + t.Fatalf("expected description %q, got %#v", *tc.wantDescription, stored.Description) + } + } + if stored.CreatedAt.IsZero() || stored.UpdatedAt.IsZero() { + t.Fatal("expected timestamps to be populated") + } + }) + } +} + +func TestBunPermissionsRepositoryGetAllPermissions(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + repo := NewBunPermissionsRepository(db) + ctx := context.Background() + + if err := repo.CreatePermission(ctx, &types.Permission{ID: "p2", Key: "users.write"}); err != nil { + t.Fatalf("failed to seed permission p2: %v", err) + } + if err := repo.CreatePermission(ctx, &types.Permission{ID: "p1", Key: "users.read"}); err != nil { + t.Fatalf("failed to seed permission p1: %v", err) + } + + permissions, err := repo.GetAllPermissions(ctx) + if err != nil { + t.Fatalf("failed to get permissions: %v", err) + } + if len(permissions) != 2 { + t.Fatalf("expected 2 permissions, got %d", len(permissions)) + } + if permissions[0].ID != "p2" || permissions[1].ID != "p1" { + t.Fatalf("expected permissions ordered by creation time, got %#v", permissions) + } +} + +func TestBunPermissionsRepositoryGetPermissionByID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + permissionID string + seedPermission *types.Permission + wantNil bool + }{ + { + name: "not found", + permissionID: "missing", + wantNil: true, + }, + { + name: "success", + permissionID: "p1", + seedPermission: &types.Permission{ID: "p1", Key: "users.read", Description: new("Read users"), IsSystem: false}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + repo := NewBunPermissionsRepository(db) + ctx := context.Background() + + if tc.seedPermission != nil { + if err := repo.CreatePermission(ctx, tc.seedPermission); err != nil { + t.Fatalf("failed to seed permission: %v", err) + } + } + + permission, err := repo.GetPermissionByID(ctx, tc.permissionID) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tc.wantNil { + if permission != nil { + t.Fatalf("expected nil permission, got %#v", permission) + } + return + } + if permission == nil || permission.ID != tc.permissionID { + t.Fatalf("unexpected permission: %#v", permission) + } + }) + } +} + +func TestBunPermissionsRepositoryUpdatePermission(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + seedPermission *types.Permission + permissionID string + description *string + wantUpdated bool + }{ + { + name: "missing permission", + permissionID: "missing", + description: new("updated"), + wantUpdated: false, + }, + { + name: "success", + seedPermission: &types.Permission{ID: "p1", Key: "users.read", Description: new("Read users"), IsSystem: false}, + permissionID: "p1", + description: new("updated"), + wantUpdated: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + repo := NewBunPermissionsRepository(db) + ctx := context.Background() + + if tc.seedPermission != nil { + if err := repo.CreatePermission(ctx, tc.seedPermission); err != nil { + t.Fatalf("failed to seed permission: %v", err) + } + } + + updated, err := repo.UpdatePermission(ctx, tc.permissionID, tc.description) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if updated != tc.wantUpdated { + t.Fatalf("expected updated=%v, got %v", tc.wantUpdated, updated) + } + + if tc.wantUpdated { + permission, err := repo.GetPermissionByID(ctx, tc.permissionID) + if err != nil { + t.Fatalf("failed to fetch permission: %v", err) + } + if permission == nil || permission.Description == nil || *permission.Description != *tc.description { + t.Fatalf("unexpected permission after update: %#v", permission) + } + } + }) + } +} + +func TestBunPermissionsRepositoryDeletePermission(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + seedPermission *types.Permission + permissionID string + wantDeleted bool + }{ + { + name: "missing permission", + permissionID: "missing", + wantDeleted: false, + }, + { + name: "success", + seedPermission: &types.Permission{ID: "p1", Key: "users.read", Description: new("Read users"), IsSystem: false}, + permissionID: "p1", + wantDeleted: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + repo := NewBunPermissionsRepository(db) + ctx := context.Background() + + if tc.seedPermission != nil { + if err := repo.CreatePermission(ctx, tc.seedPermission); err != nil { + t.Fatalf("failed to seed permission: %v", err) + } + } + + deleted, err := repo.DeletePermission(ctx, tc.permissionID) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if deleted != tc.wantDeleted { + t.Fatalf("expected deleted=%v, got %v", tc.wantDeleted, deleted) + } + }) + } +} diff --git a/plugins/access-control/repositories/repository_test_helpers_test.go b/plugins/access-control/repositories/repository_test_helpers_test.go deleted file mode 100644 index 08eb4d6..0000000 --- a/plugins/access-control/repositories/repository_test_helpers_test.go +++ /dev/null @@ -1,55 +0,0 @@ -package repositories - -import ( - "context" - "database/sql" - "testing" - - _ "github.com/mattn/go-sqlite3" - "github.com/uptrace/bun" - "github.com/uptrace/bun/dialect/sqlitedialect" - - "github.com/Authula/authula/models" - "github.com/Authula/authula/plugins/access-control/types" -) - -func setupRepoDB(t *testing.T) *bun.DB { - t.Helper() - - sqldb, err := sql.Open("sqlite3", ":memory:") - if err != nil { - t.Fatalf("failed to open sqlite: %v", err) - } - - db := bun.NewDB(sqldb, sqlitedialect.New()) - t.Cleanup(func() { - _ = db.Close() - }) - - ctx := context.Background() - - if _, err := db.NewCreateTable().Model((*models.User)(nil)).IfNotExists().Exec(ctx); err != nil { - t.Fatalf("failed to create users table: %v", err) - } - if _, err := db.NewCreateTable().Model((*types.Role)(nil)).IfNotExists().Exec(ctx); err != nil { - t.Fatalf("failed to create access control roles table: %v", err) - } - if _, err := db.NewCreateTable().Model((*types.Permission)(nil)).IfNotExists().Exec(ctx); err != nil { - t.Fatalf("failed to create access control permissions table: %v", err) - } - if _, err := db.NewCreateTable().Model((*types.RolePermission)(nil)).IfNotExists().Exec(ctx); err != nil { - t.Fatalf("failed to create access control role permissions table: %v", err) - } - if _, err := db.NewCreateTable().Model((*types.UserRole)(nil)).IfNotExists().Exec(ctx); err != nil { - t.Fatalf("failed to create access control user roles table: %v", err) - } - - if _, err := db.ExecContext(ctx, `INSERT INTO users (id, name, email, email_verified, metadata) VALUES ('u1', 'User One', 'u1@example.com', 1, '{}')`); err != nil { - t.Fatalf("failed to seed user u1: %v", err) - } - if _, err := db.ExecContext(ctx, `INSERT INTO users (id, name, email, email_verified, metadata) VALUES ('u2', 'User Two', 'u2@example.com', 1, '{}')`); err != nil { - t.Fatalf("failed to seed user u2: %v", err) - } - - return db -} diff --git a/plugins/access-control/repositories/role_permission_repository.go b/plugins/access-control/repositories/role_permission_repository.go index 7d3e61f..807b5ca 100644 --- a/plugins/access-control/repositories/role_permission_repository.go +++ b/plugins/access-control/repositories/role_permission_repository.go @@ -2,7 +2,6 @@ package repositories import ( "context" - "database/sql" "fmt" "time" @@ -12,172 +11,39 @@ import ( ) type BunRolePermissionRepository struct { - db bun.IDB + RolesRepository + PermissionsRepository + RolePermissionsRepository + UserRolesRepository } func NewBunRolePermissionRepository(db bun.IDB) *BunRolePermissionRepository { - return &BunRolePermissionRepository{db: db} -} - -func (r *BunRolePermissionRepository) CreateRole(ctx context.Context, role *types.Role) error { - _, err := r.db.NewInsert().Model(role).Exec(ctx) - if err != nil { - return fmt.Errorf("failed to create role: %w", err) - } - return nil -} - -func (r *BunRolePermissionRepository) GetAllRoles(ctx context.Context) ([]types.Role, error) { - roles := make([]types.Role, 0) - err := r.db.NewSelect().Model(&roles).Order("created_at ASC").Scan(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get roles: %w", err) - } - return roles, nil -} - -func (r *BunRolePermissionRepository) GetRoleByID(ctx context.Context, roleID string) (*types.Role, error) { - role := new(types.Role) - err := r.db.NewSelect().Model(role).Where("id = ?", roleID).Scan(ctx) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil - } - return nil, fmt.Errorf("failed to get role by id: %w", err) - } - - return role, nil -} - -func (r *BunRolePermissionRepository) UpdateRole(ctx context.Context, roleID string, name *string, description *string) (bool, error) { - query := r.db.NewUpdate(). - Model((*types.Role)(nil)). - Set("updated_at = ?", time.Now().UTC()). - Where("id = ?", roleID) - - if name != nil { - query = query.Set("name = ?", *name) - } - - if description != nil { - query = query.Set("description = ?", *description) - } - - result, err := query.Exec(ctx) - if err != nil { - return false, fmt.Errorf("failed to update role: %w", err) - } - - affected, err := result.RowsAffected() - if err != nil { - return false, fmt.Errorf("failed to determine updated rows: %w", err) - } - - return affected > 0, nil -} - -func (r *BunRolePermissionRepository) DeleteRole(ctx context.Context, roleID string) (bool, error) { - result, err := r.db.NewDelete().Model((*types.Role)(nil)).Where("id = ?", roleID).Exec(ctx) - if err != nil { - return false, fmt.Errorf("failed to delete role: %w", err) - } - - affected, err := result.RowsAffected() - if err != nil { - return false, fmt.Errorf("failed to determine deleted rows: %w", err) + return &BunRolePermissionRepository{ + RolesRepository: NewBunRolesRepository(db), + PermissionsRepository: NewBunPermissionsRepository(db), + RolePermissionsRepository: NewBunRolePermissionsRepository(db), + UserRolesRepository: NewBunUserRolesRepository(db), } - - return affected > 0, nil } -func (r *BunRolePermissionRepository) CountUserAssignmentsByRoleID(ctx context.Context, roleID string) (int, error) { - count, err := r.db.NewSelect(). - Model((*types.UserRole)(nil)). - Where("role_id = ?", roleID). - Count(ctx) - if err != nil { - return 0, fmt.Errorf("failed to count role user assignments: %w", err) - } - - return count, nil -} - -func (r *BunRolePermissionRepository) GetAllPermissions(ctx context.Context) ([]types.Permission, error) { - permissions := make([]types.Permission, 0) - err := r.db.NewSelect().Model(&permissions).Order("created_at ASC").Scan(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get permissions: %w", err) - } - return permissions, nil -} - -func (r *BunRolePermissionRepository) GetPermissionByID(ctx context.Context, permissionID string) (*types.Permission, error) { - permission := new(types.Permission) - err := r.db.NewSelect().Model(permission).Where("id = ?", permissionID).Scan(ctx) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil - } - return nil, fmt.Errorf("failed to get permission by id: %w", err) - } - - return permission, nil -} - -func (r *BunRolePermissionRepository) CreatePermission(ctx context.Context, permission *types.Permission) error { - _, err := r.db.NewInsert().Model(permission).Exec(ctx) - if err != nil { - return fmt.Errorf("failed to create permission: %w", err) - } - return nil +type BunRolePermissionsRepository struct { + db bun.IDB } -func (r *BunRolePermissionRepository) UpdatePermission(ctx context.Context, permissionID string, description *string) (bool, error) { - result, err := r.db.NewUpdate(). - Model((*types.Permission)(nil)). - Set("description = ?", *description). - Set("updated_at = ?", time.Now().UTC()). - Where("id = ?", permissionID). - Exec(ctx) - if err != nil { - return false, fmt.Errorf("failed to update permission: %w", err) - } - - affected, err := result.RowsAffected() - if err != nil { - return false, fmt.Errorf("failed to determine updated rows: %w", err) - } - - return affected > 0, nil +func NewBunRolePermissionsRepository(db bun.IDB) *BunRolePermissionsRepository { + return &BunRolePermissionsRepository{db: db} } -func (r *BunRolePermissionRepository) DeletePermission(ctx context.Context, permissionID string) (bool, error) { - result, err := r.db.NewDelete().Model((*types.Permission)(nil)).Where("id = ?", permissionID).Exec(ctx) - if err != nil { - return false, fmt.Errorf("failed to delete permission: %w", err) - } - - affected, err := result.RowsAffected() - if err != nil { - return false, fmt.Errorf("failed to determine deleted rows: %w", err) - } - - return affected > 0, nil +type rolePermissionRow struct { + PermissionID string `bun:"permission_id"` + PermissionKey string `bun:"permission_key"` + PermissionDescription *string `bun:"permission_description"` + GrantedByUserID *string `bun:"granted_by_user_id"` + GrantedAt *time.Time `bun:"granted_at"` } -func (r *BunRolePermissionRepository) CountRoleAssignmentsByPermissionID(ctx context.Context, permissionID string) (int, error) { - count, err := r.db.NewSelect(). - Model((*types.RolePermission)(nil)). - Where("permission_id = ?", permissionID). - Count(ctx) - if err != nil { - return 0, fmt.Errorf("failed to count permission role assignments: %w", err) - } - - return count, nil -} - -func (r *BunRolePermissionRepository) GetRolePermissions(ctx context.Context, roleID string) ([]types.UserPermissionInfo, error) { +func (r *BunRolePermissionsRepository) GetRolePermissions(ctx context.Context, roleID string) ([]types.UserPermissionInfo, error) { + var scanned []rolePermissionRow rows := make([]types.UserPermissionInfo, 0) err := r.db.NewSelect(). @@ -190,15 +56,25 @@ func (r *BunRolePermissionRepository) GetRolePermissions(ctx context.Context, ro Join("JOIN access_control_permissions ap ON ap.id = arp.permission_id"). Where("arp.role_id = ?", roleID). OrderExpr("ap.key ASC"). - Scan(ctx, &rows) + Scan(ctx, &scanned) if err != nil { return nil, fmt.Errorf("failed to get role permissions: %w", err) } + for _, row := range scanned { + rows = append(rows, types.UserPermissionInfo{ + PermissionID: row.PermissionID, + PermissionKey: row.PermissionKey, + PermissionDescription: row.PermissionDescription, + GrantedByUserID: row.GrantedByUserID, + GrantedAt: row.GrantedAt, + }) + } + return rows, nil } -func (r *BunRolePermissionRepository) ReplaceRolePermissions(ctx context.Context, roleID string, permissionIDs []string, grantedByUserID *string) error { +func (r *BunRolePermissionsRepository) ReplaceRolePermissions(ctx context.Context, roleID string, permissionIDs []string, grantedByUserID *string) error { return r.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { if _, err := tx.NewDelete().Model((*types.RolePermission)(nil)).Where("role_id = ?", roleID).Exec(ctx); err != nil { return fmt.Errorf("failed to clear role permissions: %w", err) @@ -212,8 +88,9 @@ func (r *BunRolePermissionRepository) ReplaceRolePermissions(ctx context.Context GrantedByUserID: grantedByUserID, GrantedAt: now, } - if _, err := tx.NewInsert().Model(rp).Exec(ctx); err != nil { - return fmt.Errorf("failed to insert role permission: %w", err) + _, err := tx.NewInsert().Model(rp).Exec(ctx) + if err != nil { + return err } } @@ -221,7 +98,7 @@ func (r *BunRolePermissionRepository) ReplaceRolePermissions(ctx context.Context }) } -func (r *BunRolePermissionRepository) AddRolePermission(ctx context.Context, roleID string, permissionID string, grantedByUserID *string) error { +func (r *BunRolePermissionsRepository) AddRolePermission(ctx context.Context, roleID string, permissionID string, grantedByUserID *string) error { rp := &types.RolePermission{ RoleID: roleID, PermissionID: permissionID, @@ -231,74 +108,31 @@ func (r *BunRolePermissionRepository) AddRolePermission(ctx context.Context, rol _, err := r.db.NewInsert().Model(rp).Exec(ctx) if err != nil { - return fmt.Errorf("failed to add role permission: %w", err) + return err } - return nil } -func (r *BunRolePermissionRepository) RemoveRolePermission(ctx context.Context, roleID string, permissionID string) error { +func (r *BunRolePermissionsRepository) RemoveRolePermission(ctx context.Context, roleID string, permissionID string) error { _, err := r.db.NewDelete(). Model((*types.RolePermission)(nil)). Where("role_id = ?", roleID). Where("permission_id = ?", permissionID). Exec(ctx) if err != nil { - return fmt.Errorf("failed to remove role permission: %w", err) + return err } return nil } -func (r *BunRolePermissionRepository) ReplaceUserRoles(ctx context.Context, userID string, roleIDs []string, assignedByUserID *string) error { - return r.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { - if _, err := tx.NewDelete().Model((*types.UserRole)(nil)).Where("user_id = ?", userID).Exec(ctx); err != nil { - return fmt.Errorf("failed to clear user roles: %w", err) - } - - now := time.Now().UTC() - for _, roleID := range roleIDs { - ur := &types.UserRole{ - UserID: userID, - RoleID: roleID, - AssignedByUserID: assignedByUserID, - AssignedAt: now, - } - if _, err := tx.NewInsert().Model(ur).Exec(ctx); err != nil { - return fmt.Errorf("failed to insert user role: %w", err) - } - } - - return nil - }) -} - -func (r *BunRolePermissionRepository) AssignUserRole(ctx context.Context, userID string, roleID string, assignedByUserID *string, expiresAt *time.Time) error { - ur := &types.UserRole{ - UserID: userID, - RoleID: roleID, - AssignedByUserID: assignedByUserID, - AssignedAt: time.Now().UTC(), - ExpiresAt: expiresAt, - } - - _, err := r.db.NewInsert().Model(ur).Exec(ctx) - if err != nil { - return fmt.Errorf("failed to assign user role: %w", err) - } - - return nil -} - -func (r *BunRolePermissionRepository) RemoveUserRole(ctx context.Context, userID string, roleID string) error { - _, err := r.db.NewDelete(). - Model((*types.UserRole)(nil)). - Where("user_id = ?", userID). - Where("role_id = ?", roleID). - Exec(ctx) +func (r *BunRolePermissionsRepository) CountRolesByPermission(ctx context.Context, permissionID string) (int, error) { + count, err := r.db.NewSelect(). + Model((*types.RolePermission)(nil)). + Where("permission_id = ?", permissionID). + Count(ctx) if err != nil { - return fmt.Errorf("failed to remove user role: %w", err) + return 0, fmt.Errorf("failed to count roles by permission: %w", err) } - - return nil + return count, nil } diff --git a/plugins/access-control/repositories/role_permission_repository_test.go b/plugins/access-control/repositories/role_permission_repository_test.go index 7ad5393..9fb6ad8 100644 --- a/plugins/access-control/repositories/role_permission_repository_test.go +++ b/plugins/access-control/repositories/role_permission_repository_test.go @@ -2,89 +2,304 @@ package repositories import ( "context" + "strings" "testing" - "time" - internaltests "github.com/Authula/authula/internal/tests" + plugintests "github.com/Authula/authula/plugins/access-control/tests" "github.com/Authula/authula/plugins/access-control/types" ) -func TestBunRolePermissionRepositoryRoleCRUDAndAssignmentCount(t *testing.T) { - db := setupRepoDB(t) - repo := NewBunRolePermissionRepository(db) - ctx := context.Background() +func TestBunRolePermissionsRepositoryGetRolePermissions(t *testing.T) { + t.Parallel() - role := &types.Role{ID: "r1", Name: "admin"} - if err := repo.CreateRole(ctx, role); err != nil { - t.Fatalf("failed to create role: %v", err) + tests := []struct { + name string + seed func(*BunRolesRepository, *BunPermissionsRepository, *BunRolePermissionsRepository, context.Context) + roleID string + wantIDs []string + wantKeys []string + wantNil bool + }{ + { + name: "empty result", + roleID: "role-missing", + wantNil: false, + wantIDs: []string{}, + wantKeys: []string{}, + }, + { + name: "success", + roleID: "role-1", + seed: func(rolesRepo *BunRolesRepository, permissionsRepo *BunPermissionsRepository, rolePermissionsRepo *BunRolePermissionsRepository, ctx context.Context) { + if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "role-1", Name: "editor"}); err != nil { + panic(err) + } + if err := permissionsRepo.CreatePermission(ctx, &types.Permission{ID: "perm-2", Key: "users.write", Description: new("Write users")}); err != nil { + panic(err) + } + if err := permissionsRepo.CreatePermission(ctx, &types.Permission{ID: "perm-1", Key: "users.read", Description: new("Read users")}); err != nil { + panic(err) + } + grantedBy := "u2" + if err := rolePermissionsRepo.AddRolePermission(ctx, "role-1", "perm-2", &grantedBy); err != nil { + panic(err) + } + if err := rolePermissionsRepo.AddRolePermission(ctx, "role-1", "perm-1", &grantedBy); err != nil { + panic(err) + } + }, + wantIDs: []string{"perm-1", "perm-2"}, + wantKeys: []string{"users.read", "users.write"}, + }, } - if err := repo.AssignUserRole(ctx, "u1", "r1", nil, nil); err != nil { - t.Fatalf("failed to assign user role: %v", err) - } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() - count, err := repo.CountUserAssignmentsByRoleID(ctx, "r1") - if err != nil { - t.Fatalf("failed to count assignments: %v", err) - } - if count != 1 { - t.Fatalf("expected 1 assignment, got %d", count) - } + db := plugintests.SetupRepoDB(t) + rolesRepo := NewBunRolesRepository(db) + permissionsRepo := NewBunPermissionsRepository(db) + rolePermissionsRepo := NewBunRolePermissionsRepository(db) + ctx := context.Background() - deleted, err := repo.DeleteRole(ctx, "r1") - if err != nil { - t.Fatalf("failed to delete role: %v", err) - } - if !deleted { - t.Fatal("expected role to be deleted") + if tc.seed != nil { + tc.seed(rolesRepo, permissionsRepo, rolePermissionsRepo, ctx) + } + + permissions, err := rolePermissionsRepo.GetRolePermissions(ctx, tc.roleID) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if permissions == nil { + t.Fatal("expected permissions slice, got nil") + } + if len(permissions) != len(tc.wantIDs) { + t.Fatalf("expected %d permissions, got %d", len(tc.wantIDs), len(permissions)) + } + for i := range tc.wantIDs { + if permissions[i].PermissionID != tc.wantIDs[i] || permissions[i].PermissionKey != tc.wantKeys[i] { + t.Fatalf("unexpected permission at %d: %#v", i, permissions[i]) + } + } + }) } } -func TestBunRolePermissionRepositoryReplaceRolePermissions(t *testing.T) { - db := setupRepoDB(t) - repo := NewBunRolePermissionRepository(db) - ctx := context.Background() +func TestBunRolePermissionsRepositoryAddRolePermission(t *testing.T) { + t.Parallel() - if err := repo.CreateRole(ctx, &types.Role{ID: "r1", Name: "editor"}); err != nil { - t.Fatalf("failed to create role: %v", err) - } - if err := repo.CreatePermission(ctx, &types.Permission{ID: "p1", Key: "posts.read"}); err != nil { - t.Fatalf("failed to create permission p1: %v", err) - } - if err := repo.CreatePermission(ctx, &types.Permission{ID: "p2", Key: "posts.write"}); err != nil { - t.Fatalf("failed to create permission p2: %v", err) + tests := []struct { + name string + setup func(*BunRolesRepository, *BunPermissionsRepository, *BunRolePermissionsRepository, context.Context) + roleID string + permissionID string + grantedByUserID *string + wantErr error + }{ + { + name: "success", + roleID: "role-1", + permissionID: "perm-1", + setup: func(rolesRepo *BunRolesRepository, permissionsRepo *BunPermissionsRepository, rolePermissionsRepo *BunRolePermissionsRepository, ctx context.Context) { + if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "role-1", Name: "editor"}); err != nil { + panic(err) + } + if err := permissionsRepo.CreatePermission(ctx, &types.Permission{ID: "perm-1", Key: "users.read"}); err != nil { + panic(err) + } + }, + }, + { + name: "duplicate grant returns conflict", + roleID: "role-1", + permissionID: "perm-1", + setup: func(rolesRepo *BunRolesRepository, permissionsRepo *BunPermissionsRepository, rolePermissionsRepo *BunRolePermissionsRepository, ctx context.Context) { + if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "role-1", Name: "editor"}); err != nil { + panic(err) + } + if err := permissionsRepo.CreatePermission(ctx, &types.Permission{ID: "perm-1", Key: "users.read"}); err != nil { + panic(err) + } + if err := rolePermissionsRepo.AddRolePermission(ctx, "role-1", "perm-1", nil); err != nil { + panic(err) + } + }, + }, } - if err := repo.ReplaceRolePermissions(ctx, "r1", []string{"p1", "p2"}, nil); err != nil { - t.Fatalf("failed to replace role permissions: %v", err) - } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() - permissions, err := repo.GetRolePermissions(ctx, "r1") - if err != nil { - t.Fatalf("failed to get role permissions: %v", err) - } - if len(permissions) != 2 { - t.Fatalf("expected 2 permissions, got %d", len(permissions)) + db := plugintests.SetupRepoDB(t) + rolesRepo := NewBunRolesRepository(db) + permissionsRepo := NewBunPermissionsRepository(db) + rolePermissionsRepo := NewBunRolePermissionsRepository(db) + ctx := context.Background() + + if tc.setup != nil { + tc.setup(rolesRepo, permissionsRepo, rolePermissionsRepo, ctx) + } + + err := rolePermissionsRepo.AddRolePermission(ctx, tc.roleID, tc.permissionID, tc.grantedByUserID) + if tc.name == "duplicate grant returns conflict" { + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "UNIQUE constraint failed: access_control_role_permissions.role_id, access_control_role_permissions.permission_id") { + t.Fatalf("expected raw unique constraint error, got %v", err) + } + return + } + if err != tc.wantErr { + t.Fatalf("expected err %v, got %v", tc.wantErr, err) + } + + if tc.wantErr != nil { + return + } + + permissions, err := rolePermissionsRepo.GetRolePermissions(ctx, tc.roleID) + if err != nil { + t.Fatalf("failed to fetch role permissions: %v", err) + } + if len(permissions) != 1 || permissions[0].PermissionID != tc.permissionID { + t.Fatalf("unexpected permissions: %#v", permissions) + } + if permissions[0].GrantedAt == nil { + t.Fatal("expected granted_at to be populated") + } + }) } } -func TestBunRolePermissionRepositoryReplaceUserRoles(t *testing.T) { - db := setupRepoDB(t) - repo := NewBunRolePermissionRepository(db) - ctx := context.Background() +func TestBunRolePermissionsRepositoryReplaceRolePermissions(t *testing.T) { + t.Parallel() - if err := repo.CreateRole(ctx, &types.Role{ID: "r1", Name: "role-1"}); err != nil { - t.Fatalf("failed to create role r1: %v", err) + tests := []struct { + name string + seed func(*BunRolesRepository, *BunPermissionsRepository, *BunRolePermissionsRepository, context.Context) + roleID string + permissionIDs []string + wantIDs []string + }{ + { + name: "success", + roleID: "role-1", + permissionIDs: []string{"perm-2", "perm-1"}, + seed: func(rolesRepo *BunRolesRepository, permissionsRepo *BunPermissionsRepository, rolePermissionsRepo *BunRolePermissionsRepository, ctx context.Context) { + if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "role-1", Name: "editor"}); err != nil { + panic(err) + } + if err := permissionsRepo.CreatePermission(ctx, &types.Permission{ID: "perm-1", Key: "users.read"}); err != nil { + panic(err) + } + if err := permissionsRepo.CreatePermission(ctx, &types.Permission{ID: "perm-2", Key: "users.write"}); err != nil { + panic(err) + } + if err := rolePermissionsRepo.AddRolePermission(ctx, "role-1", "perm-2", nil); err != nil { + panic(err) + } + }, + wantIDs: []string{"perm-1", "perm-2"}, + }, } - if err := repo.CreateRole(ctx, &types.Role{ID: "r2", Name: "role-2"}); err != nil { - t.Fatalf("failed to create role r2: %v", err) + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + rolesRepo := NewBunRolesRepository(db) + permissionsRepo := NewBunPermissionsRepository(db) + rolePermissionsRepo := NewBunRolePermissionsRepository(db) + ctx := context.Background() + + if tc.seed != nil { + tc.seed(rolesRepo, permissionsRepo, rolePermissionsRepo, ctx) + } + + if err := rolePermissionsRepo.ReplaceRolePermissions(ctx, tc.roleID, tc.permissionIDs, nil); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + permissions, err := rolePermissionsRepo.GetRolePermissions(ctx, tc.roleID) + if err != nil { + t.Fatalf("failed to fetch role permissions: %v", err) + } + if len(permissions) != len(tc.wantIDs) { + t.Fatalf("expected %d permissions, got %d", len(tc.wantIDs), len(permissions)) + } + for i, wantID := range tc.wantIDs { + if permissions[i].PermissionID != wantID { + t.Fatalf("expected permission %s at index %d, got %#v", wantID, i, permissions[i]) + } + } + }) } +} + +func TestBunRolePermissionsRepositoryRemoveRolePermission(t *testing.T) { + t.Parallel() - if err := repo.ReplaceUserRoles(ctx, "u1", []string{"r1", "r2"}, nil); err != nil { - t.Fatalf("failed to replace user roles: %v", err) + tests := []struct { + name string + seed func(*BunRolesRepository, *BunPermissionsRepository, *BunRolePermissionsRepository, context.Context) + roleID string + permissionID string + wantExists bool + }{ + { + name: "success", + roleID: "role-1", + permissionID: "perm-1", + seed: func(rolesRepo *BunRolesRepository, permissionsRepo *BunPermissionsRepository, rolePermissionsRepo *BunRolePermissionsRepository, ctx context.Context) { + if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "role-1", Name: "editor"}); err != nil { + panic(err) + } + if err := permissionsRepo.CreatePermission(ctx, &types.Permission{ID: "perm-1", Key: "users.read"}); err != nil { + panic(err) + } + if err := rolePermissionsRepo.AddRolePermission(ctx, "role-1", "perm-1", nil); err != nil { + panic(err) + } + }, + wantExists: false, + }, } - if err := repo.AssignUserRole(ctx, "u1", "r1", nil, internaltests.PtrTime(time.Now().UTC().Add(1*time.Hour))); err == nil { - t.Fatal("expected duplicate assignment to fail due primary key") + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + rolesRepo := NewBunRolesRepository(db) + permissionsRepo := NewBunPermissionsRepository(db) + rolePermissionsRepo := NewBunRolePermissionsRepository(db) + ctx := context.Background() + + if tc.seed != nil { + tc.seed(rolesRepo, permissionsRepo, rolePermissionsRepo, ctx) + } + + if err := rolePermissionsRepo.RemoveRolePermission(ctx, tc.roleID, tc.permissionID); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + permissions, err := rolePermissionsRepo.GetRolePermissions(ctx, tc.roleID) + if err != nil { + t.Fatalf("failed to fetch role permissions: %v", err) + } + if tc.wantExists { + if len(permissions) == 0 { + t.Fatal("expected permission to remain") + } + return + } + if len(permissions) != 0 { + t.Fatalf("expected no permissions after remove, got %#v", permissions) + } + }) } } diff --git a/plugins/access-control/repositories/roles_repository.go b/plugins/access-control/repositories/roles_repository.go new file mode 100644 index 0000000..5d87600 --- /dev/null +++ b/plugins/access-control/repositories/roles_repository.go @@ -0,0 +1,101 @@ +package repositories + +import ( + "context" + "database/sql" + "fmt" + "time" + + "github.com/uptrace/bun" + + "github.com/Authula/authula/plugins/access-control/types" +) + +type BunRolesRepository struct { + db bun.IDB +} + +func NewBunRolesRepository(db bun.IDB) *BunRolesRepository { + return &BunRolesRepository{db: db} +} + +func (r *BunRolesRepository) CreateRole(ctx context.Context, role *types.Role) error { + _, err := r.db.NewInsert().Model(role).Exec(ctx) + return err +} + +func (r *BunRolesRepository) GetAllRoles(ctx context.Context) ([]types.Role, error) { + roles := make([]types.Role, 0) + err := r.db.NewSelect().Model(&roles).Order("created_at ASC").Scan(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get roles: %w", err) + } + return roles, nil +} + +func (r *BunRolesRepository) GetRoleByID(ctx context.Context, roleID string) (*types.Role, error) { + role := new(types.Role) + err := r.db.NewSelect().Model(role).Where("id = ?", roleID).Scan(ctx) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, fmt.Errorf("failed to get role by id: %w", err) + } + + return role, nil +} + +func (r *BunRolesRepository) GetRoleByName(ctx context.Context, roleName string) (*types.Role, error) { + role := new(types.Role) + err := r.db.NewSelect().Model(role).Where("name = ?", roleName).Scan(ctx) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, fmt.Errorf("failed to get role by name: %w", err) + } + + return role, nil +} + +func (r *BunRolesRepository) UpdateRole(ctx context.Context, roleID string, name *string, description *string) (bool, error) { + query := r.db.NewUpdate(). + Model((*types.Role)(nil)). + Set("updated_at = ?", time.Now().UTC()). + Where("id = ?", roleID) + + if name != nil { + query = query.Set("name = ?", *name) + } + + if description != nil { + query = query.Set("description = ?", *description) + } + + result, err := query.Exec(ctx) + if err != nil { + return false, err + } + + affected, err := result.RowsAffected() + if err != nil { + return false, fmt.Errorf("failed to determine updated rows: %w", err) + } + + return affected > 0, nil +} + +func (r *BunRolesRepository) DeleteRole(ctx context.Context, roleID string) (bool, error) { + result, err := r.db.NewDelete().Model((*types.Role)(nil)).Where("id = ?", roleID).Exec(ctx) + if err != nil { + return false, err + } + + affected, err := result.RowsAffected() + if err != nil { + return false, fmt.Errorf("failed to determine deleted rows: %w", err) + } + + return affected > 0, nil +} diff --git a/plugins/access-control/repositories/roles_repository_test.go b/plugins/access-control/repositories/roles_repository_test.go new file mode 100644 index 0000000..9905eff --- /dev/null +++ b/plugins/access-control/repositories/roles_repository_test.go @@ -0,0 +1,604 @@ +package repositories + +import ( + "context" + "strings" + "testing" + + plugintests "github.com/Authula/authula/plugins/access-control/tests" + "github.com/Authula/authula/plugins/access-control/types" +) + +func TestBunRolesRepositoryCreateRole(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + role *types.Role + seed *types.Role + wantErr error + wantID string + wantName string + wantDesc *string + }{ + { + name: "success", + role: &types.Role{ID: "r1", Name: "editor", Description: new("Editor role"), IsSystem: false}, + wantID: "r1", + wantName: "editor", + wantDesc: new("Editor role"), + }, + { + name: "duplicate name returns conflict", + role: &types.Role{ID: "r2", Name: "editor", Description: new("Duplicate role"), IsSystem: false}, + seed: &types.Role{ID: "r1", Name: "editor", Description: new("Original role"), IsSystem: false}, + wantErr: nil, + }, + { + name: "query error returns wrapped error", + role: &types.Role{ID: "r3", Name: "reviewer", Description: new("Reviewer role"), IsSystem: false}, + wantErr: nil, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + repo := NewBunRolesRepository(db) + ctx := context.Background() + + if tc.seed != nil { + if err := repo.CreateRole(ctx, tc.seed); err != nil { + t.Fatalf("failed to seed role: %v", err) + } + } + + if tc.name == "query error returns wrapped error" { + if err := db.Close(); err != nil { + t.Fatalf("failed to close db: %v", err) + } + } + + err := repo.CreateRole(ctx, tc.role) + if tc.name == "query error returns wrapped error" { + if err == nil { + t.Fatal("expected error, got nil") + } + if err.Error() != "sql: database is closed" { + t.Fatalf("expected direct db error, got %v", err) + } + return + } + if tc.name == "duplicate name returns conflict" { + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "UNIQUE constraint failed: access_control_roles.name") { + t.Fatalf("expected raw unique constraint error, got %v", err) + } + return + } + if err != tc.wantErr { + t.Fatalf("expected err %v, got %v", tc.wantErr, err) + } + + if tc.wantErr != nil { + return + } + + stored, err := repo.GetRoleByID(ctx, tc.wantID) + if err != nil { + t.Fatalf("failed to fetch stored role: %v", err) + } + if stored == nil { + t.Fatal("expected stored role, got nil") + } + if stored.ID != tc.wantID || stored.Name != tc.wantName { + t.Fatalf("unexpected stored role: %#v", stored) + } + if !stringPtrEqual(stored.Description, tc.wantDesc) { + t.Fatalf("expected description %#v, got %#v", tc.wantDesc, stored.Description) + } + if stored.CreatedAt.IsZero() || stored.UpdatedAt.IsZero() { + t.Fatal("expected timestamps to be populated") + } + }) + } +} + +func TestBunRolesRepositoryGetAllRoles(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + seedRoles []*types.Role + wantIDs []string + wantNames []string + wantDescs []*string + wantErrMsg string + }{ + { + name: "empty result", + wantIDs: []string{}, + wantNames: []string{}, + wantDescs: []*string{}, + }, + { + name: "returns roles ordered by creation time", + seedRoles: []*types.Role{ + {ID: "r2", Name: "viewer", Description: new("Viewer role")}, + {ID: "r1", Name: "editor", Description: new("Editor role")}, + }, + wantIDs: []string{"r2", "r1"}, + wantNames: []string{"viewer", "editor"}, + wantDescs: []*string{new("Viewer role"), new("Editor role")}, + }, + { + name: "query error", + wantErrMsg: "failed to get roles", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + repo := NewBunRolesRepository(db) + ctx := context.Background() + + for _, role := range tc.seedRoles { + if err := repo.CreateRole(ctx, role); err != nil { + t.Fatalf("failed to seed role %s: %v", role.ID, err) + } + } + + if tc.wantErrMsg != "" { + if err := db.Close(); err != nil { + t.Fatalf("failed to close db: %v", err) + } + } + + roles, err := repo.GetAllRoles(ctx) + if tc.wantErrMsg != "" { + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tc.wantErrMsg) { + t.Fatalf("expected direct db error, got %v", err) + } + if roles != nil { + t.Fatalf("expected nil roles on error, got %#v", roles) + } + return + } + if err != nil { + t.Fatalf("failed to get roles: %v", err) + } + if roles == nil { + t.Fatal("expected roles slice, got nil") + } + if len(roles) != len(tc.wantIDs) { + t.Fatalf("expected %d roles, got %d", len(tc.wantIDs), len(roles)) + } + for i := range tc.wantIDs { + if roles[i].ID != tc.wantIDs[i] || roles[i].Name != tc.wantNames[i] { + t.Fatalf("unexpected role at %d: %#v", i, roles[i]) + } + if !stringPtrEqual(roles[i].Description, tc.wantDescs[i]) { + t.Fatalf("unexpected description at %d: %#v", i, roles[i]) + } + if roles[i].CreatedAt.IsZero() || roles[i].UpdatedAt.IsZero() { + t.Fatalf("expected timestamps to be populated at %d", i) + } + } + }) + } +} + +func TestBunRolesRepositoryGetRoleByID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + roleID string + seedRole *types.Role + wantNil bool + wantName string + wantDesc *string + wantSystem bool + wantErrMsg string + }{ + { + name: "not found", + roleID: "missing", + wantNil: true, + }, + { + name: "success", + roleID: "r1", + seedRole: &types.Role{ID: "r1", Name: "editor", Description: new("Editor role"), IsSystem: true}, + wantName: "editor", + wantDesc: new("Editor role"), + wantSystem: true, + }, + { + name: "query error", + roleID: "r1", + wantErrMsg: "failed to get role by id", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + repo := NewBunRolesRepository(db) + ctx := context.Background() + + if tc.seedRole != nil { + if err := repo.CreateRole(ctx, tc.seedRole); err != nil { + t.Fatalf("failed to seed role: %v", err) + } + } + + if tc.wantErrMsg != "" { + if err := db.Close(); err != nil { + t.Fatalf("failed to close db: %v", err) + } + } + + role, err := repo.GetRoleByID(ctx, tc.roleID) + if tc.wantErrMsg != "" { + if err == nil { + t.Fatal("expected error, got nil") + } + if role != nil { + t.Fatalf("expected nil role on error, got %#v", role) + } + if !strings.Contains(err.Error(), tc.wantErrMsg) { + t.Fatalf("expected direct db error, got %v", err) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tc.wantNil { + if role != nil { + t.Fatalf("expected nil role, got %#v", role) + } + return + } + if role == nil { + t.Fatal("expected role, got nil") + } + if role.ID != tc.roleID || role.Name != tc.wantName || role.IsSystem != tc.wantSystem { + t.Fatalf("unexpected role: %#v", role) + } + if !stringPtrEqual(role.Description, tc.wantDesc) { + t.Fatalf("expected description %#v, got %#v", tc.wantDesc, role.Description) + } + }) + } +} + +func TestBunRolesRepositoryGetRoleByName(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + roleName string + seedRole *types.Role + wantNil bool + wantName string + wantDesc *string + wantSystem bool + wantErr bool + wantErrMsg string + }{ + { + name: "not found", + roleName: "missing", + wantNil: true, + }, + { + name: "success", + roleName: "editor", + seedRole: &types.Role{ID: "r1", Name: "editor", Description: new("Editor role"), IsSystem: true}, + wantName: "editor", + wantDesc: new("Editor role"), + wantSystem: true, + }, + { + name: "query error", + roleName: "editor", + seedRole: &types.Role{ID: "r1", Name: "editor", Description: new("Editor role"), IsSystem: false}, + wantErr: true, + wantErrMsg: "failed to get role by name", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + repo := NewBunRolesRepository(db) + ctx := context.Background() + + if tc.seedRole != nil { + if err := repo.CreateRole(ctx, tc.seedRole); err != nil { + t.Fatalf("failed to seed role: %v", err) + } + } + + if tc.wantErr { + if err := db.Close(); err != nil { + t.Fatalf("failed to close db: %v", err) + } + } + + role, err := repo.GetRoleByName(ctx, tc.roleName) + if tc.wantErr { + if err == nil { + t.Fatal("expected error, got nil") + } + if role != nil { + t.Fatalf("expected nil role on error, got %#v", role) + } + if !strings.Contains(err.Error(), tc.wantErrMsg) { + t.Fatalf("expected direct db error, got %v", err) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tc.wantNil { + if role != nil { + t.Fatalf("expected nil role, got %#v", role) + } + return + } + if role == nil { + t.Fatal("expected role, got nil") + } + if role.Name != tc.wantName || role.IsSystem != tc.wantSystem { + t.Fatalf("unexpected role: %#v", role) + } + if !stringPtrEqual(role.Description, tc.wantDesc) { + t.Fatalf("expected description %#v, got %#v", tc.wantDesc, role.Description) + } + }) + } +} + +func TestBunRolesRepositoryUpdateRole(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + seedRole *types.Role + roleID string + nameValue *string + description *string + wantUpdated bool + wantName *string + wantDesc *string + wantErrMsg string + }{ + { + name: "missing role", + roleID: "missing", + nameValue: new("updated"), + description: new("updated description"), + wantUpdated: false, + }, + { + name: "update name only", + seedRole: &types.Role{ID: "r1", Name: "editor", Description: new("Editor role"), IsSystem: false}, + roleID: "r1", + nameValue: new("editor-updated"), + wantUpdated: true, + wantName: new("editor-updated"), + wantDesc: new("Editor role"), + }, + { + name: "update description only", + seedRole: &types.Role{ID: "r2", Name: "viewer", Description: new("Viewer role"), IsSystem: false}, + roleID: "r2", + description: new("Viewer role updated"), + wantUpdated: true, + wantName: new("viewer"), + wantDesc: new("Viewer role updated"), + }, + { + name: "update name and description", + seedRole: &types.Role{ID: "r3", Name: "author", Description: new("Author role"), IsSystem: false}, + roleID: "r3", + nameValue: new("author-updated"), + description: new("Author role updated"), + wantUpdated: true, + wantName: new("author-updated"), + wantDesc: new("Author role updated"), + }, + { + name: "update with no fields still updates timestamp", + seedRole: &types.Role{ID: "r4", Name: "reviewer", Description: new("Reviewer role"), IsSystem: false}, + roleID: "r4", + wantUpdated: true, + wantName: new("reviewer"), + wantDesc: new("Reviewer role"), + }, + { + name: "query error", + roleID: "r5", + nameValue: new("updated"), + wantErrMsg: "sql: database is closed", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + repo := NewBunRolesRepository(db) + ctx := context.Background() + + if tc.seedRole != nil { + if err := repo.CreateRole(ctx, tc.seedRole); err != nil { + t.Fatalf("failed to seed role: %v", err) + } + } + + if tc.wantErrMsg != "" { + if err := db.Close(); err != nil { + t.Fatalf("failed to close db: %v", err) + } + } + + updated, err := repo.UpdateRole(ctx, tc.roleID, tc.nameValue, tc.description) + if tc.wantErrMsg != "" { + if err == nil { + t.Fatal("expected error, got nil") + } + if updated { + t.Fatal("expected updated=false on error") + } + if !strings.Contains(err.Error(), tc.wantErrMsg) { + t.Fatalf("expected direct db error, got %v", err) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if updated != tc.wantUpdated { + t.Fatalf("expected updated=%v, got %v", tc.wantUpdated, updated) + } + + if !tc.wantUpdated { + return + } + + role, err := repo.GetRoleByID(ctx, tc.roleID) + if err != nil { + t.Fatalf("failed to fetch updated role: %v", err) + } + if role == nil { + t.Fatal("expected updated role, got nil") + } + if role.Name != derefOrEmpty(tc.wantName) { + t.Fatalf("expected name %q, got %q", derefOrEmpty(tc.wantName), role.Name) + } + if !stringPtrEqual(role.Description, tc.wantDesc) { + t.Fatalf("expected description %#v, got %#v", tc.wantDesc, role.Description) + } + if role.UpdatedAt.IsZero() { + t.Fatal("expected updated_at to be populated") + } + }) + } +} + +func TestBunRolesRepositoryDeleteRole(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + seedRole *types.Role + roleID string + wantDeleted bool + wantErrMsg string + }{ + { + name: "missing role", + roleID: "missing", + wantDeleted: false, + }, + { + name: "success", + seedRole: &types.Role{ID: "r1", Name: "editor", Description: new("Editor role"), IsSystem: false}, + roleID: "r1", + wantDeleted: true, + }, + { + name: "query error", + roleID: "r5", + wantErrMsg: "sql: database is closed", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + repo := NewBunRolesRepository(db) + ctx := context.Background() + + if tc.seedRole != nil { + if err := repo.CreateRole(ctx, tc.seedRole); err != nil { + t.Fatalf("failed to seed role: %v", err) + } + } + + if tc.wantErrMsg != "" { + if err := db.Close(); err != nil { + t.Fatalf("failed to close db: %v", err) + } + } + + deleted, err := repo.DeleteRole(ctx, tc.roleID) + if tc.wantErrMsg != "" { + if err == nil { + t.Fatal("expected error, got nil") + } + if deleted { + t.Fatal("expected deleted=false on error") + } + if !strings.Contains(err.Error(), tc.wantErrMsg) { + t.Fatalf("expected direct db error, got %v", err) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if deleted != tc.wantDeleted { + t.Fatalf("expected deleted=%v, got %v", tc.wantDeleted, deleted) + } + + if !tc.wantDeleted { + return + } + + role, err := repo.GetRoleByID(ctx, tc.roleID) + if err != nil { + t.Fatalf("failed to verify role deletion: %v", err) + } + if role != nil { + t.Fatalf("expected deleted role to be absent, got %#v", role) + } + }) + } +} + +func stringPtrEqual(left *string, right *string) bool { + if left == nil || right == nil { + return left == right + } + return *left == *right +} + +func derefOrEmpty(value *string) string { + if value == nil { + return "" + } + return *value +} diff --git a/plugins/access-control/repositories/user_access_repository.go b/plugins/access-control/repositories/user_access_repository.go deleted file mode 100644 index bd1b948..0000000 --- a/plugins/access-control/repositories/user_access_repository.go +++ /dev/null @@ -1,288 +0,0 @@ -package repositories - -import ( - "context" - "database/sql" - "fmt" - "time" - - "github.com/uptrace/bun" - - "github.com/Authula/authula/models" - "github.com/Authula/authula/plugins/access-control/types" -) - -type BunUserAccessRepository struct { - db bun.IDB -} - -func NewBunUserAccessRepository(db bun.IDB) *BunUserAccessRepository { - return &BunUserAccessRepository{db: db} -} - -func (r *BunUserAccessRepository) GetUserRoles(ctx context.Context, userID string) ([]types.UserRoleInfo, error) { - var rows []types.UserRoleInfo - err := r.db.NewSelect(). - TableExpr("access_control_user_roles acur"). - ColumnExpr("acur.role_id AS role_id"). - ColumnExpr("acr.name AS role_name"). - ColumnExpr("acr.description AS role_description"). - ColumnExpr("acur.assigned_by_user_id AS assigned_by_user_id"). - ColumnExpr("acur.assigned_at AS assigned_at"). - ColumnExpr("acur.expires_at AS expires_at"). - Join("JOIN access_control_roles acr ON acr.id = acur.role_id"). - Where("acur.user_id = ?", userID). - OrderExpr("acr.name ASC, acur.assigned_at DESC"). - Scan(ctx, &rows) - if err != nil { - return nil, fmt.Errorf("failed to get user roles: %w", err) - } - if rows == nil { - return []types.UserRoleInfo{}, nil - } - return rows, nil -} - -type userEffectivePermissionRow struct { - PermissionID string `bun:"permission_id"` - PermissionKey string `bun:"permission_key"` - PermissionDescription *string `bun:"permission_description"` - SourceRoleID string `bun:"source_role_id"` - SourceRoleName string `bun:"source_role_name"` - GrantedByUserID *string `bun:"granted_by_user_id"` - GrantedAt *time.Time `bun:"granted_at"` -} - -func (r *BunUserAccessRepository) GetUserEffectivePermissions(ctx context.Context, userID string) ([]types.UserPermissionInfo, error) { - var rows []userEffectivePermissionRow - now := time.Now().UTC() - err := r.db.NewSelect(). - TableExpr("access_control_user_roles pur"). - ColumnExpr("pp.id AS permission_id"). - ColumnExpr("pp.key AS permission_key"). - ColumnExpr("pp.description AS permission_description"). - ColumnExpr("pr.id AS source_role_id"). - ColumnExpr("pr.name AS source_role_name"). - ColumnExpr("prp.granted_by_user_id AS granted_by_user_id"). - ColumnExpr("prp.granted_at AS granted_at"). - Join("JOIN access_control_role_permissions prp ON prp.role_id = pur.role_id"). - Join("JOIN access_control_permissions pp ON pp.id = prp.permission_id"). - Join("JOIN access_control_roles pr ON pr.id = pur.role_id"). - Where("pur.user_id = ?", userID). - Where("pur.expires_at IS NULL OR pur.expires_at > ?", now). - OrderExpr("pp.key ASC"). - OrderExpr("pr.name ASC"). - OrderExpr("CASE WHEN prp.granted_at IS NULL THEN 1 ELSE 0 END ASC"). - OrderExpr("prp.granted_at DESC"). - Scan(ctx, &rows) - if err != nil { - return nil, fmt.Errorf("failed to get user effective permissions: %w", err) - } - if rows == nil { - return []types.UserPermissionInfo{}, nil - } - - permissions := make([]types.UserPermissionInfo, 0) - permissionIndex := make(map[string]int) - - for _, row := range rows { - idx, exists := permissionIndex[row.PermissionID] - if !exists { - permissions = append(permissions, types.UserPermissionInfo{ - PermissionID: row.PermissionID, - PermissionKey: row.PermissionKey, - PermissionDescription: row.PermissionDescription, - }) - idx = len(permissions) - 1 - permissionIndex[row.PermissionID] = idx - } - - source := types.PermissionGrantSource{ - RoleID: row.SourceRoleID, - RoleName: row.SourceRoleName, - GrantedByUserID: row.GrantedByUserID, - GrantedAt: row.GrantedAt, - } - permissions[idx].Sources = append(permissions[idx].Sources, source) - } - - return permissions, nil -} - -type userWithRoleRow struct { - UserID string `bun:"user_id"` - UserName string `bun:"user_name"` - UserEmail string `bun:"user_email"` - EmailVerified bool `bun:"email_verified"` - Image *string - Metadata []byte - CreatedAt time.Time `bun:"created_at"` - UpdatedAt time.Time `bun:"updated_at"` - RoleID *string `bun:"role_id"` - RoleName *string `bun:"role_name"` -} - -func (r *BunUserAccessRepository) GetUserWithRolesByID(ctx context.Context, userID string) (*types.UserWithRoles, error) { - var rows []userWithRoleRow - now := time.Now().UTC() - err := r.db.NewSelect(). - TableExpr("users u"). - ColumnExpr("u.id AS user_id"). - ColumnExpr("u.name AS user_name"). - ColumnExpr("u.email AS user_email"). - ColumnExpr("u.email_verified AS email_verified"). - ColumnExpr("u.image AS image"). - ColumnExpr("u.metadata AS metadata"). - ColumnExpr("u.created_at AS created_at"). - ColumnExpr("u.updated_at AS updated_at"). - ColumnExpr("pr.id AS role_id"). - ColumnExpr("pr.name AS role_name"). - Join("LEFT JOIN access_control_user_roles pur ON pur.user_id = u.id AND (pur.expires_at IS NULL OR pur.expires_at > ?)", now). - Join("LEFT JOIN access_control_roles pr ON pr.id = pur.role_id"). - Where("u.id = ?", userID). - OrderExpr("pr.name ASC"). - Scan(ctx, &rows) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil - } - return nil, fmt.Errorf("failed to get user with roles: %w", err) - } - if len(rows) == 0 { - return nil, nil - } - - result := &types.UserWithRoles{ - User: mapRowToUser(rows[0]), - } - - seen := make(map[string]struct{}) - for _, row := range rows { - if row.RoleID == nil || *row.RoleID == "" { - continue - } - if _, ok := seen[*row.RoleID]; ok { - continue - } - seen[*row.RoleID] = struct{}{} - roleName := "" - if row.RoleName != nil { - roleName = *row.RoleName - } - result.Roles = append(result.Roles, types.UserRoleInfo{RoleID: *row.RoleID, RoleName: roleName}) - } - - return result, nil -} - -type userWithPermissionRow struct { - UserID string `bun:"user_id"` - UserName string `bun:"user_name"` - UserEmail string `bun:"user_email"` - EmailVerified bool `bun:"email_verified"` - Image *string - Metadata []byte - CreatedAt time.Time `bun:"created_at"` - UpdatedAt time.Time `bun:"updated_at"` - PermissionID *string `bun:"permission_id"` - PermissionKey *string `bun:"permission_key"` -} - -func (r *BunUserAccessRepository) GetUserWithPermissionsByID(ctx context.Context, userID string) (*types.UserWithPermissions, error) { - var rows []userWithPermissionRow - now := time.Now().UTC() - err := r.db.NewSelect(). - TableExpr("users u"). - ColumnExpr("u.id AS user_id"). - ColumnExpr("u.name AS user_name"). - ColumnExpr("u.email AS user_email"). - ColumnExpr("u.email_verified AS email_verified"). - ColumnExpr("u.image AS image"). - ColumnExpr("u.metadata AS metadata"). - ColumnExpr("u.created_at AS created_at"). - ColumnExpr("u.updated_at AS updated_at"). - ColumnExpr("ap.id AS permission_id"). - ColumnExpr("ap.key AS permission_key"). - Join("LEFT JOIN access_control_user_roles aur ON aur.user_id = u.id AND (aur.expires_at IS NULL OR aur.expires_at > ?)", now). - Join("LEFT JOIN access_control_role_permissions arp ON arp.role_id = aur.role_id"). - Join("LEFT JOIN access_control_permissions ap ON ap.id = arp.permission_id"). - Where("u.id = ?", userID). - OrderExpr("ap.key ASC"). - Scan(ctx, &rows) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil - } - return nil, fmt.Errorf("failed to get user with permissions: %w", err) - } - if len(rows) == 0 { - return nil, nil - } - - result := &types.UserWithPermissions{ - User: mapRowToUser(rows[0]), - } - - seen := make(map[string]struct{}) - for _, row := range rows { - if row.PermissionID == nil || *row.PermissionID == "" { - continue - } - if _, ok := seen[*row.PermissionID]; ok { - continue - } - seen[*row.PermissionID] = struct{}{} - permissionKey := "" - if row.PermissionKey != nil { - permissionKey = *row.PermissionKey - } - result.Permissions = append(result.Permissions, types.UserPermissionInfo{PermissionID: *row.PermissionID, PermissionKey: permissionKey}) - } - - return result, nil -} - -type userRow interface { - GetUserID() string - GetUserName() string - GetUserEmail() string - GetEmailVerified() bool - GetImage() *string - GetMetadata() []byte - GetCreatedAt() time.Time - GetUpdatedAt() time.Time -} - -func mapRowToUser(row userRow) models.User { - return models.User{ - ID: row.GetUserID(), - Name: row.GetUserName(), - Email: row.GetUserEmail(), - EmailVerified: row.GetEmailVerified(), - Image: row.GetImage(), - Metadata: row.GetMetadata(), - CreatedAt: row.GetCreatedAt(), - UpdatedAt: row.GetUpdatedAt(), - } -} - -func (r userWithRoleRow) GetUserID() string { return r.UserID } -func (r userWithRoleRow) GetUserName() string { return r.UserName } -func (r userWithRoleRow) GetUserEmail() string { return r.UserEmail } -func (r userWithRoleRow) GetEmailVerified() bool { return r.EmailVerified } -func (r userWithRoleRow) GetImage() *string { return r.Image } -func (r userWithRoleRow) GetMetadata() []byte { return r.Metadata } -func (r userWithRoleRow) GetCreatedAt() time.Time { return r.CreatedAt } -func (r userWithRoleRow) GetUpdatedAt() time.Time { return r.UpdatedAt } -func (r userWithPermissionRow) GetUserID() string { return r.UserID } -func (r userWithPermissionRow) GetUserName() string { return r.UserName } -func (r userWithPermissionRow) GetUserEmail() string { return r.UserEmail } -func (r userWithPermissionRow) GetEmailVerified() bool { return r.EmailVerified } -func (r userWithPermissionRow) GetImage() *string { return r.Image } -func (r userWithPermissionRow) GetMetadata() []byte { return r.Metadata } -func (r userWithPermissionRow) GetCreatedAt() time.Time { - return r.CreatedAt -} -func (r userWithPermissionRow) GetUpdatedAt() time.Time { - return r.UpdatedAt -} diff --git a/plugins/access-control/repositories/user_access_repository_test.go b/plugins/access-control/repositories/user_access_repository_test.go deleted file mode 100644 index 57af803..0000000 --- a/plugins/access-control/repositories/user_access_repository_test.go +++ /dev/null @@ -1,147 +0,0 @@ -package repositories - -import ( - "context" - "testing" - "testing/synctest" - "time" - - internaltests "github.com/Authula/authula/internal/tests" - "github.com/Authula/authula/plugins/access-control/types" -) - -func TestBunUserAccessRepositoryGetUserRolesIncludesExpiredWithMetadata(t *testing.T) { - db := setupRepoDB(t) - rpRepo := NewBunRolePermissionRepository(db) - uaRepo := NewBunUserAccessRepository(db) - ctx := context.Background() - - if err := rpRepo.CreateRole(ctx, &types.Role{ID: "r-active", Name: "active"}); err != nil { - t.Fatalf("failed to create active role: %v", err) - } - if err := rpRepo.CreateRole(ctx, &types.Role{ID: "r-expired", Name: "expired"}); err != nil { - t.Fatalf("failed to create expired role: %v", err) - } - - assignerID := "u2" - if err := rpRepo.AssignUserRole(ctx, "u1", "r-active", &assignerID, nil); err != nil { - t.Fatalf("failed to assign active role: %v", err) - } - if err := rpRepo.AssignUserRole(ctx, "u1", "r-expired", &assignerID, internaltests.PtrTime(time.Now().UTC().Add(-1*time.Hour))); err != nil { - t.Fatalf("failed to assign expired role: %v", err) - } - - roles, err := uaRepo.GetUserRoles(ctx, "u1") - if err != nil { - t.Fatalf("failed to get user roles: %v", err) - } - if len(roles) != 2 { - t.Fatalf("expected 2 roles including expired, got %d", len(roles)) - } - if roles[0].RoleID != "r-active" { - t.Fatalf("expected role sorted by name, got %s", roles[0].RoleID) - } - if roles[0].AssignedByUserID == nil || *roles[0].AssignedByUserID != assignerID { - t.Fatalf("expected assigned_by_user_id=%s", assignerID) - } - if roles[0].AssignedAt == nil { - t.Fatal("expected assigned_at to be populated") - } - if roles[1].RoleID != "r-expired" { - t.Fatalf("expected expired role included, got %s", roles[1].RoleID) - } - if roles[1].ExpiresAt == nil { - t.Fatal("expected expired role to include expires_at") - } -} - -func TestBunUserAccessRepositoryGetUserRolesReturnsEmptyArrayWhenNoRoles(t *testing.T) { - db := setupRepoDB(t) - uaRepo := NewBunUserAccessRepository(db) - - roles, err := uaRepo.GetUserRoles(context.Background(), "missing-user") - if err != nil { - t.Fatalf("failed to get user roles: %v", err) - } - if roles == nil { - t.Fatal("expected empty roles slice, got nil") - } - if len(roles) != 0 { - t.Fatalf("expected 0 roles, got %d", len(roles)) - } -} - -func TestBunUserAccessRepositoryGetUserEffectivePermissions(t *testing.T) { - synctest.Test(t, func(t *testing.T) { - db := setupRepoDB(t) - rpRepo := NewBunRolePermissionRepository(db) - uaRepo := NewBunUserAccessRepository(db) - ctx := context.Background() - - if err := rpRepo.CreateRole(ctx, &types.Role{ID: "r1", Name: "editor"}); err != nil { - t.Fatalf("failed to create role: %v", err) - } - if err := rpRepo.CreateRole(ctx, &types.Role{ID: "r2", Name: "viewer"}); err != nil { - t.Fatalf("failed to create second role: %v", err) - } - if err := rpRepo.CreatePermission(ctx, &types.Permission{ID: "p1", Key: "posts.read", Description: internaltests.PtrString("Read posts")}); err != nil { - t.Fatalf("failed to create permission: %v", err) - } - grantedBy := "u2" - if err := rpRepo.AddRolePermission(ctx, "r1", "p1", &grantedBy); err != nil { - t.Fatalf("failed to add role permission: %v", err) - } - time.Sleep(10 * time.Millisecond) - if err := rpRepo.AddRolePermission(ctx, "r2", "p1", &grantedBy); err != nil { - t.Fatalf("failed to add second role permission: %v", err) - } - if err := rpRepo.AssignUserRole(ctx, "u1", "r1", nil, nil); err != nil { - t.Fatalf("failed to assign role: %v", err) - } - if err := rpRepo.AssignUserRole(ctx, "u1", "r2", nil, nil); err != nil { - t.Fatalf("failed to assign second role: %v", err) - } - - perms, err := uaRepo.GetUserEffectivePermissions(ctx, "u1") - if err != nil { - t.Fatalf("failed to get effective permissions: %v", err) - } - if len(perms) != 1 { - t.Fatalf("expected 1 deduplicated permission, got %d", len(perms)) - } - if perms[0].PermissionKey != "posts.read" { - t.Fatalf("expected posts.read, got %s", perms[0].PermissionKey) - } - if perms[0].PermissionDescription == nil || *perms[0].PermissionDescription != "Read posts" { - t.Fatal("expected permission description to be populated") - } - if len(perms[0].Sources) != 2 { - t.Fatalf("expected 2 permission sources, got %d", len(perms[0].Sources)) - } - if perms[0].Sources[0].RoleName != "editor" || perms[0].Sources[1].RoleName != "viewer" { - t.Fatalf("expected deterministic source ordering by role_name, got %s then %s", perms[0].Sources[0].RoleName, perms[0].Sources[1].RoleName) - } - if perms[0].Sources[0].GrantedByUserID == nil || *perms[0].Sources[0].GrantedByUserID != grantedBy { - t.Fatal("expected source granted_by_user_id to be populated") - } - if perms[0].Sources[0].GrantedAt == nil || perms[0].Sources[1].GrantedAt == nil { - t.Fatal("expected source granted_at timestamps to be populated") - } - }) -} - -func TestBunUserAccessRepositoryGetUserEffectivePermissionsReturnsEmptyArrayWhenNoPermissions(t *testing.T) { - db := setupRepoDB(t) - uaRepo := NewBunUserAccessRepository(db) - - perms, err := uaRepo.GetUserEffectivePermissions(context.Background(), "missing-user") - if err != nil { - t.Fatalf("failed to get effective permissions: %v", err) - } - if perms == nil { - t.Fatal("expected empty permissions slice, got nil") - } - if len(perms) != 0 { - t.Fatalf("expected 0 permissions, got %d", len(perms)) - } -} diff --git a/plugins/access-control/repositories/user_permissions_repository.go b/plugins/access-control/repositories/user_permissions_repository.go new file mode 100644 index 0000000..198e3c1 --- /dev/null +++ b/plugins/access-control/repositories/user_permissions_repository.go @@ -0,0 +1,112 @@ +package repositories + +import ( + "context" + "fmt" + "time" + + "github.com/Authula/authula/plugins/access-control/types" + "github.com/uptrace/bun" +) + +type BunUserPermissionsRepository struct { + db bun.IDB +} + +func NewBunUserPermissionsRepository(db bun.IDB) *BunUserPermissionsRepository { + return &BunUserPermissionsRepository{db: db} +} + +type userPermissionGrantRow struct { + PermissionID string `bun:"permission_id"` + PermissionKey string `bun:"permission_key"` + PermissionDescription *string `bun:"permission_description"` + GrantedByUserID *string `bun:"granted_by_user_id"` + GrantedAt *time.Time `bun:"granted_at"` + RoleID string `bun:"role_id"` + RoleName string `bun:"role_name"` +} + +func (r *BunUserPermissionsRepository) GetUserPermissions(ctx context.Context, userID string) ([]types.UserPermissionInfo, error) { + var scanned []userPermissionGrantRow + err := r.db.NewSelect(). + TableExpr("access_control_user_roles acur"). + ColumnExpr("ap.id AS permission_id"). + ColumnExpr("ap.key AS permission_key"). + ColumnExpr("ap.description AS permission_description"). + ColumnExpr("arp.granted_by_user_id AS granted_by_user_id"). + ColumnExpr("arp.granted_at AS granted_at"). + ColumnExpr("acr.id AS role_id"). + ColumnExpr("acr.name AS role_name"). + Join("JOIN access_control_roles acr ON acr.id = acur.role_id"). + Join("JOIN access_control_role_permissions arp ON arp.role_id = acr.id"). + Join("JOIN access_control_permissions ap ON ap.id = arp.permission_id"). + Where("acur.user_id = ?", userID). + Where("(acur.expires_at IS NULL OR acur.expires_at > CURRENT_TIMESTAMP)"). + OrderExpr("ap.key ASC, acr.name ASC, arp.granted_at ASC"). + Scan(ctx, &scanned) + if err != nil { + return nil, fmt.Errorf("failed to get user permissions: %w", err) + } + + if len(scanned) == 0 { + return []types.UserPermissionInfo{}, nil + } + + permissionsByID := make(map[string]*types.UserPermissionInfo, len(scanned)) + orderedPermissionIDs := make([]string, 0, len(scanned)) + + for _, row := range scanned { + permission, exists := permissionsByID[row.PermissionID] + if !exists { + permission = &types.UserPermissionInfo{ + PermissionID: row.PermissionID, + PermissionKey: row.PermissionKey, + PermissionDescription: row.PermissionDescription, + GrantedByUserID: row.GrantedByUserID, + GrantedAt: row.GrantedAt, + Sources: []types.PermissionGrantSource{}, + } + permissionsByID[row.PermissionID] = permission + orderedPermissionIDs = append(orderedPermissionIDs, row.PermissionID) + } + + permission.Sources = append(permission.Sources, types.PermissionGrantSource{ + RoleID: row.RoleID, + RoleName: row.RoleName, + GrantedByUserID: row.GrantedByUserID, + GrantedAt: row.GrantedAt, + }) + } + + permissions := make([]types.UserPermissionInfo, 0, len(orderedPermissionIDs)) + for _, permissionID := range orderedPermissionIDs { + permissions = append(permissions, *permissionsByID[permissionID]) + } + + return permissions, nil +} + +func (r *BunUserPermissionsRepository) HasPermissions(ctx context.Context, userID string, permissionKeys []string) (bool, error) { + if len(permissionKeys) == 0 { + return true, nil + } + + permissions, err := r.GetUserPermissions(ctx, userID) + if err != nil { + return false, err + } + + granted := make(map[string]struct{}, len(permissions)) + for _, permission := range permissions { + granted[permission.PermissionKey] = struct{}{} + } + + for _, permissionKey := range permissionKeys { + if _, ok := granted[permissionKey]; !ok { + return false, nil + } + } + + return true, nil +} diff --git a/plugins/access-control/repositories/user_permissions_repository_test.go b/plugins/access-control/repositories/user_permissions_repository_test.go new file mode 100644 index 0000000..2fbfee2 --- /dev/null +++ b/plugins/access-control/repositories/user_permissions_repository_test.go @@ -0,0 +1,209 @@ +package repositories + +import ( + "context" + "testing" + + plugintests "github.com/Authula/authula/plugins/access-control/tests" + "github.com/Authula/authula/plugins/access-control/types" +) + +func TestBunUserPermissionsRepositoryGetUserPermissions(t *testing.T) { + t.Parallel() + + description := new(string) + *description = "Read users" + + tests := []struct { + name string + seed func(*BunRolesRepository, *BunPermissionsRepository, *BunRolePermissionsRepository, *BunUserRolesRepository, context.Context) + userID string + wantEmpty bool + wantKeys []string + wantSrcs []int + }{ + { + name: "empty result", + userID: "missing-user", + wantEmpty: true, + }, + { + name: "aggregates permissions across roles and ignores expired roles", + userID: "u1", + seed: func(rolesRepo *BunRolesRepository, permissionsRepo *BunPermissionsRepository, rolePermissionsRepo *BunRolePermissionsRepository, userRolesRepo *BunUserRolesRepository, ctx context.Context) { + if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "role-1", Name: "editor"}); err != nil { + panic(err) + } + if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "role-2", Name: "viewer"}); err != nil { + panic(err) + } + if err := permissionsRepo.CreatePermission(ctx, &types.Permission{ID: "perm-1", Key: "users.read", Description: description}); err != nil { + panic(err) + } + if err := permissionsRepo.CreatePermission(ctx, &types.Permission{ID: "perm-2", Key: "users.write"}); err != nil { + panic(err) + } + grantedBy := "u2" + if err := rolePermissionsRepo.AddRolePermission(ctx, "role-1", "perm-1", &grantedBy); err != nil { + panic(err) + } + if err := rolePermissionsRepo.AddRolePermission(ctx, "role-2", "perm-1", nil); err != nil { + panic(err) + } + if err := rolePermissionsRepo.AddRolePermission(ctx, "role-2", "perm-2", &grantedBy); err != nil { + panic(err) + } + if err := userRolesRepo.AssignUserRole(ctx, "u1", "role-1", &grantedBy, nil); err != nil { + panic(err) + } + if err := userRolesRepo.AssignUserRole(ctx, "u1", "role-2", nil, nil); err != nil { + panic(err) + } + }, + wantKeys: []string{"users.read", "users.write"}, + wantSrcs: []int{2, 1}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + rolesRepo := NewBunRolesRepository(db) + permissionsRepo := NewBunPermissionsRepository(db) + rolePermissionsRepo := NewBunRolePermissionsRepository(db) + userRolesRepo := NewBunUserRolesRepository(db) + repo := NewBunUserPermissionsRepository(db) + ctx := context.Background() + + if tc.seed != nil { + tc.seed(rolesRepo, permissionsRepo, rolePermissionsRepo, userRolesRepo, ctx) + } + + permissions, err := repo.GetUserPermissions(ctx, tc.userID) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if permissions == nil { + t.Fatal("expected permissions slice, got nil") + } + if tc.wantEmpty { + if len(permissions) != 0 { + t.Fatalf("expected no permissions, got %#v", permissions) + } + return + } + if len(permissions) != len(tc.wantKeys) { + t.Fatalf("expected %d permissions, got %d", len(tc.wantKeys), len(permissions)) + } + for i := range tc.wantKeys { + if permissions[i].PermissionKey != tc.wantKeys[i] { + t.Fatalf("unexpected permission key at %d: %#v", i, permissions[i]) + } + if permissions[i].PermissionKey == "users.read" && !sameStringPtr(permissions[i].PermissionDescription, description) { + t.Fatalf("unexpected permission description at %d: %#v", i, permissions[i]) + } + if permissions[i].GrantedAt == nil { + t.Fatalf("expected granted_at to be populated at %d", i) + } + if len(permissions[i].Sources) != tc.wantSrcs[i] { + t.Fatalf("expected %d sources at %d, got %#v", tc.wantSrcs[i], i, permissions[i]) + } + } + }) + } +} + +func TestBunUserPermissionsRepositoryHasPermissions(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + seed func(*BunRolesRepository, *BunPermissionsRepository, *BunRolePermissionsRepository, *BunUserRolesRepository, context.Context) + userID string + permissionKeys []string + wantHasPerms bool + }{ + { + name: "empty permission list returns true", + userID: "u1", + permissionKeys: []string{}, + wantHasPerms: true, + }, + { + name: "missing permission returns false", + userID: "u1", + permissionKeys: []string{"users.read", "users.delete"}, + seed: func(rolesRepo *BunRolesRepository, permissionsRepo *BunPermissionsRepository, rolePermissionsRepo *BunRolePermissionsRepository, userRolesRepo *BunUserRolesRepository, ctx context.Context) { + if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "role-1", Name: "editor"}); err != nil { + panic(err) + } + if err := permissionsRepo.CreatePermission(ctx, &types.Permission{ID: "perm-1", Key: "users.read"}); err != nil { + panic(err) + } + if err := rolePermissionsRepo.AddRolePermission(ctx, "role-1", "perm-1", nil); err != nil { + panic(err) + } + if err := userRolesRepo.AssignUserRole(ctx, "u1", "role-1", nil, nil); err != nil { + panic(err) + } + }, + wantHasPerms: false, + }, + { + name: "success", + userID: "u1", + permissionKeys: []string{"users.read"}, + seed: func(rolesRepo *BunRolesRepository, permissionsRepo *BunPermissionsRepository, rolePermissionsRepo *BunRolePermissionsRepository, userRolesRepo *BunUserRolesRepository, ctx context.Context) { + if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "role-1", Name: "editor"}); err != nil { + panic(err) + } + if err := permissionsRepo.CreatePermission(ctx, &types.Permission{ID: "perm-1", Key: "users.read"}); err != nil { + panic(err) + } + grantedBy := "u2" + if err := rolePermissionsRepo.AddRolePermission(ctx, "role-1", "perm-1", &grantedBy); err != nil { + panic(err) + } + if err := userRolesRepo.AssignUserRole(ctx, "u1", "role-1", nil, nil); err != nil { + panic(err) + } + }, + wantHasPerms: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + rolesRepo := NewBunRolesRepository(db) + permissionsRepo := NewBunPermissionsRepository(db) + rolePermissionsRepo := NewBunRolePermissionsRepository(db) + userRolesRepo := NewBunUserRolesRepository(db) + repo := NewBunUserPermissionsRepository(db) + ctx := context.Background() + + if tc.seed != nil { + tc.seed(rolesRepo, permissionsRepo, rolePermissionsRepo, userRolesRepo, ctx) + } + + hasPermissions, err := repo.HasPermissions(ctx, tc.userID, tc.permissionKeys) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if hasPermissions != tc.wantHasPerms { + t.Fatalf("expected hasPermissions=%v, got %v", tc.wantHasPerms, hasPermissions) + } + }) + } +} + +func sameStringPtr(got, want *string) bool { + if got == nil || want == nil { + return got == nil && want == nil + } + return *got == *want +} diff --git a/plugins/access-control/repositories/user_roles_repository.go b/plugins/access-control/repositories/user_roles_repository.go new file mode 100644 index 0000000..f3fb249 --- /dev/null +++ b/plugins/access-control/repositories/user_roles_repository.go @@ -0,0 +1,105 @@ +package repositories + +import ( + "context" + "fmt" + "time" + + "github.com/uptrace/bun" + + "github.com/Authula/authula/plugins/access-control/types" +) + +type BunUserRolesRepository struct { + db bun.IDB +} + +func NewBunUserRolesRepository(db bun.IDB) *BunUserRolesRepository { + return &BunUserRolesRepository{db: db} +} + +func (r *BunUserRolesRepository) GetUserRoles(ctx context.Context, userID string) ([]types.UserRoleInfo, error) { + var rows []types.UserRoleInfo + err := r.db.NewSelect(). + TableExpr("access_control_user_roles acur"). + ColumnExpr("acur.role_id AS role_id"). + ColumnExpr("acr.name AS role_name"). + ColumnExpr("acr.description AS role_description"). + ColumnExpr("acur.assigned_by_user_id AS assigned_by_user_id"). + ColumnExpr("acur.assigned_at AS assigned_at"). + ColumnExpr("acur.expires_at AS expires_at"). + Join("JOIN access_control_roles acr ON acr.id = acur.role_id"). + Where("acur.user_id = ?", userID). + OrderExpr("acr.name ASC, acur.assigned_at DESC"). + Scan(ctx, &rows) + if err != nil { + return nil, fmt.Errorf("failed to get user roles: %w", err) + } + if rows == nil { + return []types.UserRoleInfo{}, nil + } + return rows, nil +} + +func (r *BunUserRolesRepository) ReplaceUserRoles(ctx context.Context, userID string, roleIDs []string, assignedByUserID *string) error { + return r.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + if _, err := tx.NewDelete().Model((*types.UserRole)(nil)).Where("user_id = ?", userID).Exec(ctx); err != nil { + return fmt.Errorf("failed to clear user roles: %w", err) + } + + now := time.Now().UTC() + for _, roleID := range roleIDs { + ur := &types.UserRole{ + UserID: userID, + RoleID: roleID, + AssignedByUserID: assignedByUserID, + AssignedAt: now, + } + if _, err := tx.NewInsert().Model(ur).Exec(ctx); err != nil { + return err + } + } + + return nil + }) +} + +func (r *BunUserRolesRepository) AssignUserRole(ctx context.Context, userID string, roleID string, assignedByUserID *string, expiresAt *time.Time) error { + ur := &types.UserRole{ + UserID: userID, + RoleID: roleID, + AssignedByUserID: assignedByUserID, + AssignedAt: time.Now().UTC(), + ExpiresAt: expiresAt, + } + + _, err := r.db.NewInsert().Model(ur).Exec(ctx) + if err != nil { + return err + } + return nil +} + +func (r *BunUserRolesRepository) RemoveUserRole(ctx context.Context, userID string, roleID string) error { + _, err := r.db.NewDelete(). + Model((*types.UserRole)(nil)). + Where("user_id = ?", userID). + Where("role_id = ?", roleID). + Exec(ctx) + if err != nil { + return err + } + + return nil +} + +func (r *BunUserRolesRepository) CountUsersByRole(ctx context.Context, roleID string) (int, error) { + count, err := r.db.NewSelect(). + Model((*types.UserRole)(nil)). + Where("role_id = ?", roleID). + Count(ctx) + if err != nil { + return 0, fmt.Errorf("failed to count users by role: %w", err) + } + return count, nil +} diff --git a/plugins/access-control/repositories/user_roles_repository_test.go b/plugins/access-control/repositories/user_roles_repository_test.go new file mode 100644 index 0000000..2858361 --- /dev/null +++ b/plugins/access-control/repositories/user_roles_repository_test.go @@ -0,0 +1,318 @@ +package repositories + +import ( + "context" + "reflect" + "strings" + "testing" + "time" + + plugintests "github.com/Authula/authula/plugins/access-control/tests" + "github.com/Authula/authula/plugins/access-control/types" +) + +func TestBunUserRolesRepositoryGetUserRoles(t *testing.T) { + t.Parallel() + + now := time.Now().UTC() + futureExpiry := time.Unix(now.Add(24*time.Hour).Unix(), 0).UTC() + roleDescription := new("Editor role") + assignedBy := new("u2") + + tests := []struct { + name string + seed func(*BunRolesRepository, *BunUserRolesRepository, context.Context) + userID string + wantRoles []types.UserRoleInfo + }{ + { + name: "empty result", + userID: "missing-user", + wantRoles: []types.UserRoleInfo{}, + }, + { + name: "returns assigned roles ordered by role name", + userID: "u1", + seed: func(rolesRepo *BunRolesRepository, userRolesRepo *BunUserRolesRepository, ctx context.Context) { + if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "r2", Name: "viewer"}); err != nil { + panic(err) + } + if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "r1", Name: "editor", Description: roleDescription}); err != nil { + panic(err) + } + if err := userRolesRepo.AssignUserRole(ctx, "u1", "r1", assignedBy, &futureExpiry); err != nil { + panic(err) + } + if err := userRolesRepo.AssignUserRole(ctx, "u1", "r2", nil, nil); err != nil { + panic(err) + } + }, + wantRoles: []types.UserRoleInfo{ + { + RoleID: "r1", + RoleName: "editor", + RoleDescription: roleDescription, + AssignedByUserID: assignedBy, + ExpiresAt: &futureExpiry, + }, + { + RoleID: "r2", + RoleName: "viewer", + RoleDescription: nil, + AssignedByUserID: nil, + ExpiresAt: nil, + }, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + rolesRepo := NewBunRolesRepository(db) + userRolesRepo := NewBunUserRolesRepository(db) + ctx := context.Background() + + if tc.seed != nil { + tc.seed(rolesRepo, userRolesRepo, ctx) + } + + roles, err := userRolesRepo.GetUserRoles(ctx, tc.userID) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if roles == nil { + t.Fatal("expected roles slice, got nil") + } + if len(roles) != len(tc.wantRoles) { + t.Fatalf("expected %d roles, got %d", len(tc.wantRoles), len(roles)) + } + for i := range tc.wantRoles { + if roles[i].RoleID != tc.wantRoles[i].RoleID || roles[i].RoleName != tc.wantRoles[i].RoleName { + t.Fatalf("unexpected role at %d: %#v", i, roles[i]) + } + if !reflect.DeepEqual(roles[i].RoleDescription, tc.wantRoles[i].RoleDescription) { + t.Fatalf("unexpected role description at %d: %#v", i, roles[i]) + } + if !reflect.DeepEqual(roles[i].AssignedByUserID, tc.wantRoles[i].AssignedByUserID) || !reflect.DeepEqual(roles[i].ExpiresAt, tc.wantRoles[i].ExpiresAt) { + t.Fatalf("unexpected assignment metadata at %d: %#v", i, roles[i]) + } + if roles[i].AssignedAt == nil { + t.Fatalf("expected assigned_at to be populated at %d", i) + } + } + }) + } +} + +func TestBunUserRolesRepositoryReplaceUserRoles(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + seed func(*BunRolesRepository, *BunUserRolesRepository, context.Context) + userID string + roleIDs []string + wantRoleIDs []string + }{ + { + name: "replaces all roles", + userID: "u1", + roleIDs: []string{"r2", "r1"}, + seed: func(rolesRepo *BunRolesRepository, userRolesRepo *BunUserRolesRepository, ctx context.Context) { + if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "r1", Name: "editor"}); err != nil { + panic(err) + } + if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "r2", Name: "viewer"}); err != nil { + panic(err) + } + if err := userRolesRepo.AssignUserRole(ctx, "u1", "r1", nil, nil); err != nil { + panic(err) + } + }, + wantRoleIDs: []string{"r1", "r2"}, + }, + { + name: "empty list clears roles", + userID: "u1", + roleIDs: []string{}, + wantRoleIDs: []string{}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + rolesRepo := NewBunRolesRepository(db) + userRolesRepo := NewBunUserRolesRepository(db) + ctx := context.Background() + + if tc.seed != nil { + tc.seed(rolesRepo, userRolesRepo, ctx) + } + + if err := userRolesRepo.ReplaceUserRoles(ctx, tc.userID, tc.roleIDs, nil); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + roles, err := userRolesRepo.GetUserRoles(ctx, tc.userID) + if err != nil { + t.Fatalf("failed to fetch roles: %v", err) + } + if len(roles) != len(tc.wantRoleIDs) { + t.Fatalf("expected %d roles, got %d", len(tc.wantRoleIDs), len(roles)) + } + for i, wantRoleID := range tc.wantRoleIDs { + if roles[i].RoleID != wantRoleID { + t.Fatalf("expected role %s at index %d, got %#v", wantRoleID, i, roles[i]) + } + } + }) + } +} + +func TestBunUserRolesRepositoryAssignUserRole(t *testing.T) { + t.Parallel() + + now := time.Now().UTC() + futureExpiry := time.Unix(now.Add(24*time.Hour).Unix(), 0).UTC() + + tests := []struct { + name string + seed func(*BunRolesRepository, *BunUserRolesRepository, context.Context) + userID string + roleID string + expiresAt *time.Time + wantErr error + }{ + { + name: "success", + userID: "u1", + roleID: "r1", + expiresAt: &futureExpiry, + seed: func(rolesRepo *BunRolesRepository, userRolesRepo *BunUserRolesRepository, ctx context.Context) { + if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "r1", Name: "editor"}); err != nil { + panic(err) + } + }, + }, + { + name: "duplicate assignment returns conflict", + userID: "u1", + roleID: "r1", + seed: func(rolesRepo *BunRolesRepository, userRolesRepo *BunUserRolesRepository, ctx context.Context) { + if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "r1", Name: "editor"}); err != nil { + panic(err) + } + if err := userRolesRepo.AssignUserRole(ctx, "u1", "r1", nil, nil); err != nil { + panic(err) + } + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + rolesRepo := NewBunRolesRepository(db) + userRolesRepo := NewBunUserRolesRepository(db) + ctx := context.Background() + + if tc.seed != nil { + tc.seed(rolesRepo, userRolesRepo, ctx) + } + + err := userRolesRepo.AssignUserRole(ctx, tc.userID, tc.roleID, nil, tc.expiresAt) + if tc.name == "duplicate assignment returns conflict" { + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "UNIQUE constraint failed: access_control_user_roles.user_id, access_control_user_roles.role_id") { + t.Fatalf("expected raw unique constraint error, got %v", err) + } + return + } + if err != tc.wantErr { + t.Fatalf("expected err %v, got %v", tc.wantErr, err) + } + + if tc.wantErr != nil { + return + } + + roles, err := userRolesRepo.GetUserRoles(ctx, tc.userID) + if err != nil { + t.Fatalf("failed to fetch assigned role: %v", err) + } + if len(roles) != 1 || roles[0].RoleID != tc.roleID || roles[0].RoleName != "editor" { + t.Fatalf("unexpected roles after assign: %#v", roles) + } + if roles[0].AssignedAt == nil { + t.Fatal("expected assigned_at to be populated") + } + if tc.expiresAt != nil && !reflect.DeepEqual(roles[0].ExpiresAt, tc.expiresAt) { + t.Fatalf("unexpected expiry: %#v", roles[0]) + } + }) + } +} + +func TestBunUserRolesRepositoryRemoveUserRole(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + seed func(*BunRolesRepository, *BunUserRolesRepository, context.Context) + userID string + roleID string + wantRoles []types.UserRoleInfo + }{ + { + name: "success", + userID: "u1", + roleID: "r1", + seed: func(rolesRepo *BunRolesRepository, userRolesRepo *BunUserRolesRepository, ctx context.Context) { + if err := rolesRepo.CreateRole(ctx, &types.Role{ID: "r1", Name: "editor"}); err != nil { + panic(err) + } + if err := userRolesRepo.AssignUserRole(ctx, "u1", "r1", nil, nil); err != nil { + panic(err) + } + }, + wantRoles: []types.UserRoleInfo{}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db := plugintests.SetupRepoDB(t) + rolesRepo := NewBunRolesRepository(db) + userRolesRepo := NewBunUserRolesRepository(db) + ctx := context.Background() + + if tc.seed != nil { + tc.seed(rolesRepo, userRolesRepo, ctx) + } + + if err := userRolesRepo.RemoveUserRole(ctx, tc.userID, tc.roleID); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + roles, err := userRolesRepo.GetUserRoles(ctx, tc.userID) + if err != nil { + t.Fatalf("failed to fetch roles: %v", err) + } + if len(roles) != len(tc.wantRoles) { + t.Fatalf("expected %d roles after remove, got %#v", len(tc.wantRoles), roles) + } + }) + } +} diff --git a/plugins/access-control/routes.go b/plugins/access-control/routes.go index 47d7e58..f1ee6d4 100644 --- a/plugins/access-control/routes.go +++ b/plugins/access-control/routes.go @@ -9,14 +9,20 @@ import ( ) type routeUseCases struct { - rolePermission usecases.RolePermissionUseCase - userAccess usecases.UserRolesUseCase + roles *usecases.RolesUseCase + permissions *usecases.PermissionsUseCase + rolePermissions *usecases.RolePermissionsUseCase + userRoles *usecases.UserRolesUseCase + userPermissions *usecases.UserPermissionsUseCase } func newRouteUseCases(api *API) routeUseCases { return routeUseCases{ - rolePermission: api.useCases.RolePermissionUseCase(), - userAccess: api.useCases.UserAccessUseCase(), + roles: api.useCases.RolesUseCase(), + permissions: api.useCases.PermissionsUseCase(), + rolePermissions: api.useCases.RolePermissionsUseCase(), + userRoles: api.useCases.UserRolesUseCase(), + userPermissions: api.useCases.UserPermissionsUseCase(), } } @@ -24,26 +30,35 @@ func Routes(api *API) []models.Route { usecases := newRouteUseCases(api) return []models.Route{ - // Roles and permissions - {Method: http.MethodPost, Path: "/access-control/roles", Handler: handlers.NewCreateRoleHandler(usecases.rolePermission).Handler()}, - {Method: http.MethodGet, Path: "/access-control/roles", Handler: handlers.NewGetAllRolesHandler(usecases.rolePermission).Handler()}, - {Method: http.MethodGet, Path: "/access-control/roles/{role_id}", Handler: handlers.NewGetRoleByIDHandler(usecases.rolePermission).Handler()}, - {Method: http.MethodPatch, Path: "/access-control/roles/{role_id}", Handler: handlers.NewUpdateRoleHandler(usecases.rolePermission).Handler()}, - {Method: http.MethodDelete, Path: "/access-control/roles/{role_id}", Handler: handlers.NewDeleteRoleHandler(usecases.rolePermission).Handler()}, - {Method: http.MethodPost, Path: "/access-control/permissions", Handler: handlers.NewCreatePermissionHandler(usecases.rolePermission).Handler()}, - {Method: http.MethodGet, Path: "/access-control/permissions", Handler: handlers.NewGetAllPermissionsHandler(usecases.rolePermission).Handler()}, - {Method: http.MethodPatch, Path: "/access-control/permissions/{permission_id}", Handler: handlers.NewUpdatePermissionHandler(usecases.rolePermission).Handler()}, - {Method: http.MethodDelete, Path: "/access-control/permissions/{permission_id}", Handler: handlers.NewDeletePermissionHandler(usecases.rolePermission).Handler()}, - {Method: http.MethodPost, Path: "/access-control/roles/{role_id}/permissions", Handler: handlers.NewAddRolePermissionHandler(usecases.rolePermission).Handler()}, - {Method: http.MethodGet, Path: "/access-control/roles/{role_id}/permissions", Handler: handlers.NewGetRolePermissionsHandler(usecases.rolePermission).Handler()}, - {Method: http.MethodPut, Path: "/access-control/roles/{role_id}/permissions", Handler: handlers.NewReplaceRolePermissionsHandler(usecases.rolePermission).Handler()}, - {Method: http.MethodDelete, Path: "/access-control/roles/{role_id}/permissions/{permission_id}", Handler: handlers.NewRemoveRolePermissionHandler(usecases.rolePermission).Handler()}, - - // User roles and permissions - {Method: http.MethodGet, Path: "/access-control/users/{user_id}/roles", Handler: handlers.NewGetUserRolesHandler(usecases.userAccess).Handler()}, - {Method: http.MethodPost, Path: "/access-control/users/{user_id}/roles", Handler: handlers.NewAssignUserRoleHandler(usecases.rolePermission).Handler()}, - {Method: http.MethodPut, Path: "/access-control/users/{user_id}/roles", Handler: handlers.NewReplaceUserRolesHandler(usecases.rolePermission).Handler()}, - {Method: http.MethodDelete, Path: "/access-control/users/{user_id}/roles/{role_id}", Handler: handlers.NewRemoveUserRoleHandler(usecases.rolePermission).Handler()}, - {Method: http.MethodGet, Path: "/access-control/users/{user_id}/permissions", Handler: handlers.NewGetUserEffectivePermissionsHandler(usecases.userAccess).Handler()}, + // Roles + {Method: http.MethodPost, Path: "/access-control/roles", Handler: handlers.NewCreateRoleHandler(usecases.roles).Handler()}, + {Method: http.MethodGet, Path: "/access-control/roles", Handler: handlers.NewGetAllRolesHandler(usecases.roles).Handler()}, + {Method: http.MethodGet, Path: "/access-control/roles/by-name/{role_name}", Handler: handlers.NewGetRoleByNameHandler(usecases.roles).Handler()}, + {Method: http.MethodGet, Path: "/access-control/roles/{role_id}", Handler: handlers.NewGetRoleByIDHandler(usecases.roles).Handler()}, + {Method: http.MethodPatch, Path: "/access-control/roles/{role_id}", Handler: handlers.NewUpdateRoleHandler(usecases.roles).Handler()}, + {Method: http.MethodDelete, Path: "/access-control/roles/{role_id}", Handler: handlers.NewDeleteRoleHandler(usecases.roles).Handler()}, + + // Permissions + {Method: http.MethodPost, Path: "/access-control/permissions", Handler: handlers.NewCreatePermissionHandler(usecases.permissions).Handler()}, + {Method: http.MethodGet, Path: "/access-control/permissions", Handler: handlers.NewGetAllPermissionsHandler(usecases.permissions).Handler()}, + {Method: http.MethodGet, Path: "/access-control/permissions/{permission_id}", Handler: handlers.NewGetPermissionByIDHandler(usecases.permissions).Handler()}, + {Method: http.MethodPatch, Path: "/access-control/permissions/{permission_id}", Handler: handlers.NewUpdatePermissionHandler(usecases.permissions).Handler()}, + {Method: http.MethodDelete, Path: "/access-control/permissions/{permission_id}", Handler: handlers.NewDeletePermissionHandler(usecases.permissions).Handler()}, + + // Role permissions + {Method: http.MethodPost, Path: "/access-control/roles/{role_id}/permissions", Handler: handlers.NewAddRolePermissionHandler(usecases.rolePermissions).Handler()}, + {Method: http.MethodGet, Path: "/access-control/roles/{role_id}/permissions", Handler: handlers.NewGetRolePermissionsHandler(usecases.rolePermissions).Handler()}, + {Method: http.MethodPut, Path: "/access-control/roles/{role_id}/permissions", Handler: handlers.NewReplaceRolePermissionsHandler(usecases.rolePermissions).Handler()}, + {Method: http.MethodDelete, Path: "/access-control/roles/{role_id}/permissions/{permission_id}", Handler: handlers.NewRemoveRolePermissionHandler(usecases.rolePermissions).Handler()}, + + // User roles + {Method: http.MethodGet, Path: "/access-control/users/{user_id}/roles", Handler: handlers.NewGetUserRolesHandler(usecases.userRoles).Handler()}, + {Method: http.MethodPut, Path: "/access-control/users/{user_id}/roles", Handler: handlers.NewReplaceUserRolesHandler(usecases.userRoles).Handler()}, + {Method: http.MethodPost, Path: "/access-control/users/{user_id}/roles", Handler: handlers.NewAssignUserRoleHandler(usecases.userRoles).Handler()}, + {Method: http.MethodDelete, Path: "/access-control/users/{user_id}/roles/{role_id}", Handler: handlers.NewRemoveUserRoleHandler(usecases.userRoles).Handler()}, + + // User permissions + {Method: http.MethodGet, Path: "/access-control/users/{user_id}/permissions", Handler: handlers.NewGetUserPermissionsHandler(usecases.userPermissions).Handler()}, + {Method: http.MethodPost, Path: "/access-control/users/{user_id}/permissions/check", Handler: handlers.NewCheckUserPermissionsHandler(usecases.userPermissions).Handler()}, } } diff --git a/plugins/access-control/services/permissions_service.go b/plugins/access-control/services/permissions_service.go new file mode 100644 index 0000000..9d910ad --- /dev/null +++ b/plugins/access-control/services/permissions_service.go @@ -0,0 +1,157 @@ +package services + +import ( + "context" + + "github.com/Authula/authula/internal/util" + "github.com/Authula/authula/plugins/access-control/constants" + "github.com/Authula/authula/plugins/access-control/repositories" + "github.com/Authula/authula/plugins/access-control/types" +) + +type PermissionsService struct { + permissionsRepo repositories.PermissionsRepository + rolePermissionsRepo repositories.RolePermissionsRepository +} + +func NewPermissionsService(permissionsRepo repositories.PermissionsRepository, rolePermissionsRepo repositories.RolePermissionsRepository) *PermissionsService { + return &PermissionsService{permissionsRepo: permissionsRepo, rolePermissionsRepo: rolePermissionsRepo} +} + +func (s *PermissionsService) CreatePermission(ctx context.Context, req types.CreatePermissionRequest) (*types.Permission, error) { + if req.Key == "" { + return nil, constants.ErrBadRequest + } + + var description *string + if req.Description != nil { + description = req.Description + } + + permission := &types.Permission{ + ID: util.GenerateUUID(), + Key: req.Key, + Description: description, + IsSystem: req.IsSystem, + } + + if err := s.permissionsRepo.CreatePermission(ctx, permission); err != nil { + return nil, err + } + + return permission, nil +} + +func (s *PermissionsService) GetAllPermissions(ctx context.Context) ([]types.Permission, error) { + return s.permissionsRepo.GetAllPermissions(ctx) +} + +func (s *PermissionsService) GetPermissionByID(ctx context.Context, permissionID string) (*types.Permission, error) { + if permissionID == "" { + return nil, constants.ErrBadRequest + } + + permission, err := s.permissionsRepo.GetPermissionByID(ctx, permissionID) + if err != nil { + return nil, err + } + if permission == nil { + return nil, constants.ErrNotFound + } + + return permission, nil +} + +func (s *PermissionsService) GetPermissionByKey(ctx context.Context, permissionKey string) (*types.Permission, error) { + if permissionKey == "" { + return nil, constants.ErrBadRequest + } + + permission, err := s.permissionsRepo.GetPermissionByKey(ctx, permissionKey) + if err != nil { + return nil, err + } + if permission == nil { + return nil, constants.ErrNotFound + } + + return permission, nil +} + +func (s *PermissionsService) UpdatePermission(ctx context.Context, permissionID string, req types.UpdatePermissionRequest) (*types.Permission, error) { + if permissionID == "" { + return nil, constants.ErrUnprocessableEntity + } + if req.Description == nil { + return nil, constants.ErrUnprocessableEntity + } + + description := *req.Description + if description == "" { + return nil, constants.ErrUnprocessableEntity + } + + permission, err := s.permissionsRepo.GetPermissionByID(ctx, permissionID) + if err != nil { + return nil, err + } + if permission == nil { + return nil, constants.ErrNotFound + } + if permission.IsSystem { + return nil, constants.ErrBadRequest + } + + updated, err := s.permissionsRepo.UpdatePermission(ctx, permissionID, &description) + if err != nil { + return nil, err + } + if !updated { + return nil, constants.ErrNotFound + } + + permission, err = s.permissionsRepo.GetPermissionByID(ctx, permissionID) + if err != nil { + return nil, err + } + if permission == nil { + return nil, constants.ErrNotFound + } + + return permission, nil +} + +func (s *PermissionsService) DeletePermission(ctx context.Context, permissionID string) error { + if permissionID == "" { + return constants.ErrBadRequest + } + + permission, err := s.permissionsRepo.GetPermissionByID(ctx, permissionID) + if err != nil { + return err + } + if permission == nil { + return constants.ErrNotFound + } + if permission.IsSystem { + return constants.ErrBadRequest + } + + totalCountOfRolesByPermission, err := s.rolePermissionsRepo.CountRolesByPermission(ctx, permissionID) + if err != nil { + return err + } + if totalCountOfRolesByPermission > 0 { + return constants.ErrConflict + } + + deleted, err := s.permissionsRepo.DeletePermission(ctx, permissionID) + if err != nil { + return err + } + if !deleted { + return constants.ErrNotFound + } + + return nil +} diff --git a/plugins/access-control/services/permissions_service_test.go b/plugins/access-control/services/permissions_service_test.go new file mode 100644 index 0000000..c78921e --- /dev/null +++ b/plugins/access-control/services/permissions_service_test.go @@ -0,0 +1,402 @@ +package services + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/mock" + + accesscontrolconstants "github.com/Authula/authula/plugins/access-control/constants" + accesscontroltests "github.com/Authula/authula/plugins/access-control/tests" + "github.com/Authula/authula/plugins/access-control/types" +) + +func TestPermissionsServiceCreatePermission(t *testing.T) { + t.Parallel() + + description := "Read users" + + tests := []struct { + name string + req types.CreatePermissionRequest + setup func(*accesscontroltests.MockPermissionsRepository) + wantErr error + assert func(*testing.T, *types.Permission) + }{ + { + name: "blank key", + req: types.CreatePermissionRequest{Key: ""}, + wantErr: accesscontrolconstants.ErrBadRequest, + }, + { + name: "success", + req: types.CreatePermissionRequest{ + Key: "users.read", + Description: &description, + IsSystem: true, + }, + setup: func(permissionsRepo *accesscontroltests.MockPermissionsRepository) { + permissionsRepo.On("CreatePermission", mock.Anything, mock.MatchedBy(func(permission *types.Permission) bool { + return permission != nil && permission.ID != "" && permission.Key == "users.read" && permission.IsSystem && permission.Description != nil && *permission.Description == description + })).Return(nil).Once() + }, + assert: func(t *testing.T, permission *types.Permission) { + if permission == nil { + t.Fatal("expected permission, got nil") + } + if permission.ID == "" { + t.Fatal("expected generated ID") + } + if permission.Key != "users.read" { + t.Fatalf("expected key %q, got %q", "users.read", permission.Key) + } + if permission.Description == nil || *permission.Description != description { + t.Fatalf("expected description %q, got %#v", description, permission.Description) + } + if !permission.IsSystem { + t.Fatal("expected system permission flag to be preserved") + } + }, + }, + { + name: "repository error is returned", + req: types.CreatePermissionRequest{Key: "users.write"}, + setup: func(permissionsRepo *accesscontroltests.MockPermissionsRepository) { + permissionsRepo.On("CreatePermission", mock.Anything, mock.AnythingOfType("*types.Permission")).Return(errors.New("boom")).Once() + }, + wantErr: errors.New("boom"), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + permissionsRepo := &accesscontroltests.MockPermissionsRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} + _ = rolePermissionsRepo + if tc.setup != nil { + tc.setup(permissionsRepo) + } + + service := NewPermissionsService(permissionsRepo, rolePermissionsRepo) + permission, err := service.CreatePermission(context.Background(), tc.req) + if tc.wantErr == nil { + if err != nil { + t.Fatalf("expected nil err, got %v", err) + } + } else if err == nil || err.Error() != tc.wantErr.Error() { + t.Fatalf("expected err %v, got %v", tc.wantErr, err) + } + + if tc.assert != nil { + tc.assert(t, permission) + } + + permissionsRepo.AssertExpectations(t) + }) + } +} + +func TestPermissionsServiceGetAllPermissions(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setup func(*accesscontroltests.MockPermissionsRepository) + want []types.Permission + wantErr error + }{ + { + name: "success", + setup: func(permissionsRepo *accesscontroltests.MockPermissionsRepository) { + permissionsRepo.On("GetAllPermissions", mock.Anything).Return([]types.Permission{{ID: "perm-1", Key: "users.read"}}, nil).Once() + }, + want: []types.Permission{{ID: "perm-1", Key: "users.read"}}, + }, + { + name: "error", + setup: func(permissionsRepo *accesscontroltests.MockPermissionsRepository) { + permissionsRepo.On("GetAllPermissions", mock.Anything).Return(nil, errors.New("boom")).Once() + }, + wantErr: errors.New("boom"), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + permissionsRepo := &accesscontroltests.MockPermissionsRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} + _ = rolePermissionsRepo + if tc.setup != nil { + tc.setup(permissionsRepo) + } + + service := NewPermissionsService(permissionsRepo, rolePermissionsRepo) + permissions, err := service.GetAllPermissions(context.Background()) + if tc.wantErr == nil { + if err != nil { + t.Fatalf("expected nil err, got %v", err) + } + if len(permissions) != len(tc.want) || (len(permissions) == 1 && permissions[0].ID != tc.want[0].ID) { + t.Fatalf("unexpected permissions %#v", permissions) + } + } else if err == nil || err.Error() != tc.wantErr.Error() { + t.Fatalf("expected err %v, got %v", tc.wantErr, err) + } + + permissionsRepo.AssertExpectations(t) + }) + } +} + +func TestPermissionsServiceGetPermissionByID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + id string + setup func(*accesscontroltests.MockPermissionsRepository) + wantID string + wantErr error + }{ + { + name: "blank id", + id: "", + wantErr: accesscontrolconstants.ErrBadRequest, + }, + { + name: "not found", + id: "perm-1", + setup: func(permissionsRepo *accesscontroltests.MockPermissionsRepository) { + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return((*types.Permission)(nil), nil).Once() + }, + wantErr: accesscontrolconstants.ErrNotFound, + }, + { + name: "success", + id: "perm-1", + setup: func(permissionsRepo *accesscontroltests.MockPermissionsRepository) { + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read"}, nil).Once() + }, + wantID: "perm-1", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + permissionsRepo := &accesscontroltests.MockPermissionsRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} + _ = rolePermissionsRepo + if tc.setup != nil { + tc.setup(permissionsRepo) + } + + service := NewPermissionsService(permissionsRepo, rolePermissionsRepo) + permission, err := service.GetPermissionByID(context.Background(), tc.id) + if tc.wantErr == nil { + if err != nil { + t.Fatalf("expected nil err, got %v", err) + } + if permission == nil || permission.ID != tc.wantID { + t.Fatalf("unexpected permission %#v", permission) + } + } else if err != tc.wantErr { + t.Fatalf("expected err %v, got %v", tc.wantErr, err) + } + + permissionsRepo.AssertExpectations(t) + }) + } +} + +func TestPermissionsServiceUpdatePermission(t *testing.T) { + t.Parallel() + + updatedDescription := "Updated description" + + tests := []struct { + name string + id string + req types.UpdatePermissionRequest + setup func(*accesscontroltests.MockPermissionsRepository) + wantID string + wantErr error + }{ + { + name: "blank id", + id: "", + req: types.UpdatePermissionRequest{Description: &updatedDescription}, + wantErr: accesscontrolconstants.ErrUnprocessableEntity, + }, + { + name: "nil description", + id: "perm-1", + req: types.UpdatePermissionRequest{}, + wantErr: accesscontrolconstants.ErrUnprocessableEntity, + }, + { + name: "blank description", + id: "perm-1", + req: types.UpdatePermissionRequest{Description: func() *string { value := ""; return &value }()}, + wantErr: accesscontrolconstants.ErrUnprocessableEntity, + }, + { + name: "not found", + id: "perm-1", + req: types.UpdatePermissionRequest{Description: &updatedDescription}, + setup: func(permissionsRepo *accesscontroltests.MockPermissionsRepository) { + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return((*types.Permission)(nil), nil).Once() + }, + wantErr: accesscontrolconstants.ErrNotFound, + }, + { + name: "system permission", + id: "perm-1", + req: types.UpdatePermissionRequest{Description: &updatedDescription}, + setup: func(permissionsRepo *accesscontroltests.MockPermissionsRepository) { + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read", IsSystem: true}, nil).Once() + }, + wantErr: accesscontrolconstants.ErrBadRequest, + }, + { + name: "update returns false", + id: "perm-1", + req: types.UpdatePermissionRequest{Description: &updatedDescription}, + setup: func(permissionsRepo *accesscontroltests.MockPermissionsRepository) { + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read"}, nil).Once() + permissionsRepo.On("UpdatePermission", mock.Anything, "perm-1", &updatedDescription).Return(false, nil).Once() + }, + wantErr: accesscontrolconstants.ErrNotFound, + }, + { + name: "success", + id: "perm-1", + req: types.UpdatePermissionRequest{Description: &updatedDescription}, + setup: func(permissionsRepo *accesscontroltests.MockPermissionsRepository) { + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read"}, nil).Once() + permissionsRepo.On("UpdatePermission", mock.Anything, "perm-1", &updatedDescription).Return(true, nil).Once() + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read", Description: &updatedDescription}, nil).Once() + }, + wantID: "perm-1", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + permissionsRepo := &accesscontroltests.MockPermissionsRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} + if tc.setup != nil { + tc.setup(permissionsRepo) + } + + service := NewPermissionsService(permissionsRepo, rolePermissionsRepo) + permission, err := service.UpdatePermission(context.Background(), tc.id, tc.req) + if tc.wantErr == nil { + if err != nil { + t.Fatalf("expected nil err, got %v", err) + } + if permission == nil || permission.ID != tc.wantID { + t.Fatalf("unexpected permission %#v", permission) + } + } else if err != tc.wantErr { + t.Fatalf("expected err %v, got %v", tc.wantErr, err) + } + + permissionsRepo.AssertExpectations(t) + }) + } +} + +func TestPermissionsServiceDeletePermission(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + id string + setup func(*accesscontroltests.MockPermissionsRepository, *accesscontroltests.MockRolePermissionsRepository) + wantErr error + }{ + { + name: "blank id", + id: "", + wantErr: accesscontrolconstants.ErrBadRequest, + }, + { + name: "not found", + id: "perm-1", + setup: func(permissionsRepo *accesscontroltests.MockPermissionsRepository, rolePermissionsRepo *accesscontroltests.MockRolePermissionsRepository) { + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return((*types.Permission)(nil), nil).Once() + }, + wantErr: accesscontrolconstants.ErrNotFound, + }, + { + name: "system permission", + id: "perm-1", + setup: func(permissionsRepo *accesscontroltests.MockPermissionsRepository, rolePermissionsRepo *accesscontroltests.MockRolePermissionsRepository) { + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read", IsSystem: true}, nil).Once() + }, + wantErr: accesscontrolconstants.ErrBadRequest, + }, + { + name: "permission in use", + id: "perm-1", + setup: func(permissionsRepo *accesscontroltests.MockPermissionsRepository, rolePermissionsRepo *accesscontroltests.MockRolePermissionsRepository) { + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read"}, nil).Once() + rolePermissionsRepo.On("CountRolesByPermission", mock.Anything, "perm-1").Return(2, nil).Once() + }, + wantErr: accesscontrolconstants.ErrConflict, + }, + { + name: "delete returns false", + id: "perm-1", + setup: func(permissionsRepo *accesscontroltests.MockPermissionsRepository, rolePermissionsRepo *accesscontroltests.MockRolePermissionsRepository) { + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read"}, nil).Once() + rolePermissionsRepo.On("CountRolesByPermission", mock.Anything, "perm-1").Return(0, nil).Once() + permissionsRepo.On("DeletePermission", mock.Anything, "perm-1").Return(false, nil).Once() + }, + wantErr: accesscontrolconstants.ErrNotFound, + }, + { + name: "success", + id: "perm-1", + setup: func(permissionsRepo *accesscontroltests.MockPermissionsRepository, rolePermissionsRepo *accesscontroltests.MockRolePermissionsRepository) { + permissionsRepo.On("GetPermissionByID", mock.Anything, "perm-1").Return(&types.Permission{ID: "perm-1", Key: "users.read"}, nil).Once() + rolePermissionsRepo.On("CountRolesByPermission", mock.Anything, "perm-1").Return(0, nil).Once() + permissionsRepo.On("DeletePermission", mock.Anything, "perm-1").Return(true, nil).Once() + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + permissionsRepo := &accesscontroltests.MockPermissionsRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} + if tc.setup != nil { + tc.setup(permissionsRepo, rolePermissionsRepo) + } + + service := NewPermissionsService(permissionsRepo, rolePermissionsRepo) + err := service.DeletePermission(context.Background(), tc.id) + if tc.wantErr == nil { + if err != nil { + t.Fatalf("expected nil err, got %v", err) + } + } else if err != tc.wantErr { + t.Fatalf("expected err %v, got %v", tc.wantErr, err) + } + + permissionsRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + }) + } +} diff --git a/plugins/access-control/services/role_permission_service.go b/plugins/access-control/services/role_permission_service.go index a269a13..c135965 100644 --- a/plugins/access-control/services/role_permission_service.go +++ b/plugins/access-control/services/role_permission_service.go @@ -2,114 +2,28 @@ package services import ( "context" - "strings" - "time" - "github.com/Authula/authula/internal/util" "github.com/Authula/authula/plugins/access-control/constants" "github.com/Authula/authula/plugins/access-control/repositories" "github.com/Authula/authula/plugins/access-control/types" ) -type RolePermissionService struct { - repo repositories.RolePermissionRepository +type RolePermissionsService struct { + rolesRepo repositories.RolesRepository + permissionsRepo repositories.PermissionsRepository + rolePermissionsRepo repositories.RolePermissionsRepository } -func NewRolePermissionService(repo repositories.RolePermissionRepository) *RolePermissionService { - return &RolePermissionService{repo: repo} +func NewRolePermissionsService(rolesRepo repositories.RolesRepository, permissionsRepo repositories.PermissionsRepository, rolePermissionsRepo repositories.RolePermissionsRepository) *RolePermissionsService { + return &RolePermissionsService{rolesRepo: rolesRepo, permissionsRepo: permissionsRepo, rolePermissionsRepo: rolePermissionsRepo} } -func (s *RolePermissionService) CreateRole(ctx context.Context, req types.CreateRoleRequest) (*types.Role, error) { - name := strings.TrimSpace(req.Name) - if name == "" { - return nil, constants.ErrBadRequest - } - - role := &types.Role{ - ID: util.GenerateUUID(), - Name: name, - Description: req.Description, - IsSystem: req.IsSystem, - } - - if err := s.repo.CreateRole(ctx, role); err != nil { - return nil, err - } - - return role, nil -} - -func (s *RolePermissionService) GetAllRoles(ctx context.Context) ([]types.Role, error) { - return s.repo.GetAllRoles(ctx) -} - -func (s *RolePermissionService) GetRoleByID(ctx context.Context, roleID string) (*types.RoleDetails, error) { - roleID = strings.TrimSpace(roleID) +func (s *RolePermissionsService) GetRolePermissions(ctx context.Context, roleID string) ([]types.UserPermissionInfo, error) { if roleID == "" { - return nil, constants.ErrBadRequest - } - - role, err := s.repo.GetRoleByID(ctx, roleID) - if err != nil { - return nil, err - } - if role == nil { - return nil, constants.ErrNotFound - } - - permissions, err := s.repo.GetRolePermissions(ctx, roleID) - if err != nil { - return nil, err - } - - return &types.RoleDetails{Role: *role, Permissions: permissions}, nil -} - -func (s *RolePermissionService) UpdateRole(ctx context.Context, roleID string, req types.UpdateRoleRequest) (*types.Role, error) { - roleID = strings.TrimSpace(roleID) - if roleID == "" { - return nil, constants.ErrBadRequest - } - - if req.Name == nil && req.Description == nil { return nil, constants.ErrUnprocessableEntity } - role, err := s.repo.GetRoleByID(ctx, roleID) - if err != nil { - return nil, err - } - if role == nil { - return nil, constants.ErrNotFound - } - if role.IsSystem { - return nil, constants.ErrCannotUpdateSystemRole - } - - var name *string - if req.Name != nil { - trimmed := strings.TrimSpace(*req.Name) - if trimmed == "" { - return nil, constants.ErrBadRequest - } - name = &trimmed - } - - var description *string - if req.Description != nil { - trimmed := strings.TrimSpace(*req.Description) - description = &trimmed - } - - updated, err := s.repo.UpdateRole(ctx, roleID, name, description) - if err != nil { - return nil, err - } - if !updated { - return nil, constants.ErrNotFound - } - - role, err = s.repo.GetRoleByID(ctx, roleID) + role, err := s.rolesRepo.GetRoleByID(ctx, roleID) if err != nil { return nil, err } @@ -117,16 +31,18 @@ func (s *RolePermissionService) UpdateRole(ctx context.Context, roleID string, r return nil, constants.ErrNotFound } - return role, nil + return s.rolePermissionsRepo.GetRolePermissions(ctx, roleID) } -func (s *RolePermissionService) DeleteRole(ctx context.Context, roleID string) error { - roleID = strings.TrimSpace(roleID) +func (s *RolePermissionsService) AddPermissionToRole(ctx context.Context, roleID string, permissionID string, grantedByUserID *string) error { if roleID == "" { return constants.ErrBadRequest } + if permissionID == "" { + return constants.ErrBadRequest + } - role, err := s.repo.GetRoleByID(ctx, roleID) + role, err := s.rolesRepo.GetRoleByID(ctx, roleID) if err != nil { return err } @@ -134,120 +50,10 @@ func (s *RolePermissionService) DeleteRole(ctx context.Context, roleID string) e return constants.ErrNotFound } if role.IsSystem { - return constants.ErrCannotUpdateSystemRole - } - - assignmentsCount, err := s.repo.CountUserAssignmentsByRoleID(ctx, roleID) - if err != nil { - return err - } - if assignmentsCount > 0 { - return constants.ErrConflict - } - - deleted, err := s.repo.DeleteRole(ctx, roleID) - if err != nil { - return err - } - if !deleted { - return constants.ErrNotFound - } - - return nil -} - -func (s *RolePermissionService) CreatePermission(ctx context.Context, req types.CreatePermissionRequest) (*types.Permission, error) { - key := strings.TrimSpace(req.Key) - if key == "" { - return nil, constants.ErrBadRequest - } - - permission := &types.Permission{ - ID: util.GenerateUUID(), - Key: key, - Description: req.Description, - IsSystem: req.IsSystem, - } - - if err := s.repo.CreatePermission(ctx, permission); err != nil { - return nil, err - } - - return permission, nil -} - -func (s *RolePermissionService) GetAllPermissions(ctx context.Context) ([]types.Permission, error) { - return s.repo.GetAllPermissions(ctx) -} - -func (s *RolePermissionService) GetRolePermissions(ctx context.Context, roleID string) ([]types.UserPermissionInfo, error) { - roleID = strings.TrimSpace(roleID) - if roleID == "" { - return nil, constants.ErrUnprocessableEntity - } - - role, err := s.repo.GetRoleByID(ctx, roleID) - if err != nil { - return nil, err - } - if role == nil { - return nil, constants.ErrNotFound - } - - return s.repo.GetRolePermissions(ctx, roleID) -} - -func (s *RolePermissionService) UpdatePermission(ctx context.Context, permissionID string, req types.UpdatePermissionRequest) (*types.Permission, error) { - permissionID = strings.TrimSpace(permissionID) - if permissionID == "" { - return nil, constants.ErrUnprocessableEntity - } - if req.Description == nil { - return nil, constants.ErrUnprocessableEntity - } - - description := strings.TrimSpace(*req.Description) - if description == "" { - return nil, constants.ErrUnprocessableEntity - } - - permission, err := s.repo.GetPermissionByID(ctx, permissionID) - if err != nil { - return nil, err - } - if permission == nil { - return nil, constants.ErrNotFound - } - if permission.IsSystem { - return nil, constants.ErrBadRequest - } - - updated, err := s.repo.UpdatePermission(ctx, permissionID, &description) - if err != nil { - return nil, err - } - if !updated { - return nil, constants.ErrNotFound - } - - permission, err = s.repo.GetPermissionByID(ctx, permissionID) - if err != nil { - return nil, err - } - if permission == nil { - return nil, constants.ErrNotFound - } - - return permission, nil -} - -func (s *RolePermissionService) DeletePermission(ctx context.Context, permissionID string) error { - permissionID = strings.TrimSpace(permissionID) - if permissionID == "" { return constants.ErrBadRequest } - permission, err := s.repo.GetPermissionByID(ctx, permissionID) + permission, err := s.permissionsRepo.GetPermissionByID(ctx, permissionID) if err != nil { return err } @@ -258,37 +64,18 @@ func (s *RolePermissionService) DeletePermission(ctx context.Context, permission return constants.ErrBadRequest } - assignmentsCount, err := s.repo.CountRoleAssignmentsByPermissionID(ctx, permissionID) - if err != nil { - return err - } - if assignmentsCount > 0 { - return constants.ErrConflict - } - - deleted, err := s.repo.DeletePermission(ctx, permissionID) - if err != nil { - return err - } - if !deleted { - return constants.ErrNotFound - } - - return nil + return s.rolePermissionsRepo.AddRolePermission(ctx, roleID, permissionID, grantedByUserID) } -func (s *RolePermissionService) AddPermissionToRole(ctx context.Context, roleID string, permissionID string, grantedByUserID *string) error { - roleID = strings.TrimSpace(roleID) - permissionID = strings.TrimSpace(permissionID) - +func (s *RolePermissionsService) RemovePermissionFromRole(ctx context.Context, roleID string, permissionID string) error { if roleID == "" { - return constants.ErrBadRequest + return constants.ErrUnprocessableEntity } if permissionID == "" { - return constants.ErrBadRequest + return constants.ErrUnprocessableEntity } - role, err := s.repo.GetRoleByID(ctx, roleID) + role, err := s.rolesRepo.GetRoleByID(ctx, roleID) if err != nil { return err } @@ -299,7 +86,7 @@ func (s *RolePermissionService) AddPermissionToRole(ctx context.Context, roleID return constants.ErrBadRequest } - permission, err := s.repo.GetPermissionByID(ctx, permissionID) + permission, err := s.permissionsRepo.GetPermissionByID(ctx, permissionID) if err != nil { return err } @@ -310,21 +97,15 @@ func (s *RolePermissionService) AddPermissionToRole(ctx context.Context, roleID return constants.ErrBadRequest } - return s.repo.AddRolePermission(ctx, roleID, permissionID, grantedByUserID) + return s.rolePermissionsRepo.RemoveRolePermission(ctx, roleID, permissionID) } -func (s *RolePermissionService) RemovePermissionFromRole(ctx context.Context, roleID string, permissionID string) error { - roleID = strings.TrimSpace(roleID) - permissionID = strings.TrimSpace(permissionID) - +func (s *RolePermissionsService) ReplaceRolePermissions(ctx context.Context, roleID string, permissionIDs []string, grantedByUserID *string) error { if roleID == "" { - return constants.ErrUnprocessableEntity - } - if permissionID == "" { - return constants.ErrUnprocessableEntity + return constants.ErrBadRequest } - role, err := s.repo.GetRoleByID(ctx, roleID) + role, err := s.rolesRepo.GetRoleByID(ctx, roleID) if err != nil { return err } @@ -335,29 +116,9 @@ func (s *RolePermissionService) RemovePermissionFromRole(ctx context.Context, ro return constants.ErrBadRequest } - permission, err := s.repo.GetPermissionByID(ctx, permissionID) - if err != nil { - return err - } - if permission == nil { - return constants.ErrNotFound - } - if permission.IsSystem { - return constants.ErrBadRequest - } - - return s.repo.RemoveRolePermission(ctx, roleID, permissionID) -} - -func (s *RolePermissionService) ReplaceRolePermissions(ctx context.Context, roleID string, permissionIDs []string, grantedByUserID *string) error { - if strings.TrimSpace(roleID) == "" { - return constants.ErrBadRequest - } - normalized := make([]string, 0, len(permissionIDs)) seen := make(map[string]struct{}, len(permissionIDs)) for _, permissionID := range permissionIDs { - permissionID = strings.TrimSpace(permissionID) if permissionID == "" { continue } @@ -365,62 +126,20 @@ func (s *RolePermissionService) ReplaceRolePermissions(ctx context.Context, role continue } seen[permissionID] = struct{}{} - normalized = append(normalized, permissionID) - } - - return s.repo.ReplaceRolePermissions(ctx, roleID, normalized, grantedByUserID) -} - -func (s *RolePermissionService) ReplaceUserRoles(ctx context.Context, userID string, roleIDs []string, assignedByUserID *string) error { - if strings.TrimSpace(userID) == "" { - return constants.ErrBadRequest - } - normalized := make([]string, 0, len(roleIDs)) - seen := make(map[string]struct{}, len(roleIDs)) - for _, roleID := range roleIDs { - roleID = strings.TrimSpace(roleID) - if roleID == "" { - continue + permission, err := s.permissionsRepo.GetPermissionByID(ctx, permissionID) + if err != nil { + return err } - if _, ok := seen[roleID]; ok { - continue + if permission == nil { + return constants.ErrNotFound + } + if permission.IsSystem { + return constants.ErrBadRequest } - seen[roleID] = struct{}{} - normalized = append(normalized, roleID) - } - - return s.repo.ReplaceUserRoles(ctx, userID, normalized, assignedByUserID) -} - -func (s *RolePermissionService) AssignRoleToUser(ctx context.Context, userID string, req types.AssignUserRoleRequest, assignedByUserID *string) error { - userID = strings.TrimSpace(userID) - if userID == "" { - return constants.ErrUnprocessableEntity - } - - roleID := strings.TrimSpace(req.RoleID) - if roleID == "" { - return constants.ErrUnprocessableEntity - } - - if req.ExpiresAt != nil && req.ExpiresAt.Before(time.Now().UTC()) { - return constants.ErrBadRequest - } - - return s.repo.AssignUserRole(ctx, userID, roleID, assignedByUserID, req.ExpiresAt) -} - -func (s *RolePermissionService) RemoveRoleFromUser(ctx context.Context, userID string, roleID string) error { - userID = strings.TrimSpace(userID) - roleID = strings.TrimSpace(roleID) - if userID == "" { - return constants.ErrBadRequest - } - if roleID == "" { - return constants.ErrBadRequest + normalized = append(normalized, permissionID) } - return s.repo.RemoveUserRole(ctx, userID, roleID) + return s.rolePermissionsRepo.ReplaceRolePermissions(ctx, roleID, normalized, grantedByUserID) } diff --git a/plugins/access-control/services/role_permission_service_test.go b/plugins/access-control/services/role_permission_service_test.go index 7e18d21..5e568ea 100644 --- a/plugins/access-control/services/role_permission_service_test.go +++ b/plugins/access-control/services/role_permission_service_test.go @@ -1,139 +1 @@ -package services_test - -import ( - "context" - "errors" - "testing" - "time" - - "github.com/stretchr/testify/mock" - - "github.com/Authula/authula/plugins/access-control/constants" - servicespkg "github.com/Authula/authula/plugins/access-control/services" - testshelpers "github.com/Authula/authula/plugins/access-control/tests" - "github.com/Authula/authula/plugins/access-control/types" -) - -func newRolePermissionServiceFixture() (*servicespkg.RolePermissionService, *testshelpers.MockRolePermissionRepository) { - repo := testshelpers.NewMockRolePermissionRepository() - svc := servicespkg.NewRolePermissionService(repo) - return svc, repo -} - -func TestRolePermissionServiceCreateRoleTrimsName(t *testing.T) { - t.Parallel() - - svc, repo := newRolePermissionServiceFixture() - repo.On("CreateRole", mock.Anything, mock.AnythingOfType("*types.Role")). - Run(func(args mock.Arguments) { - role := args.Get(1).(*types.Role) - if role.Name != "admin" { - t.Fatalf("expected trimmed role name, got %q", role.Name) - } - }). - Return(nil). - Once() - - role, err := svc.CreateRole(context.Background(), types.CreateRoleRequest{Name: " admin "}) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if role.Name != "admin" { - t.Fatalf("expected trimmed role name, got %q", role.Name) - } - repo.AssertExpectations(t) -} - -func TestRolePermissionServiceAssignRoleToUserRejectsPastExpiry(t *testing.T) { - t.Parallel() - - svc, _ := newRolePermissionServiceFixture() - past := time.Now().UTC().Add(-1 * time.Hour) - - err := svc.AssignRoleToUser( - context.Background(), - "user-1", - types.AssignUserRoleRequest{RoleID: "role-1", ExpiresAt: &past}, - nil, - ) - if !errors.Is(err, constants.ErrBadRequest) { - t.Fatalf("expected ErrBadRequest, got %v", err) - } -} - -func TestRolePermissionServiceDeleteRoleReturnsConflictWhenAssigned(t *testing.T) { - t.Parallel() - - svc, repo := newRolePermissionServiceFixture() - ctx := context.Background() - roleID := "role-1" - - repo.On("GetRoleByID", mock.Anything, roleID). - Return(&types.Role{ID: roleID, Name: "role-for-delete"}, nil). - Once() - repo.On("CountUserAssignmentsByRoleID", mock.Anything, roleID). - Return(1, nil). - Once() - - err := svc.DeleteRole(ctx, roleID) - if !errors.Is(err, constants.ErrConflict) { - t.Fatalf("expected ErrConflict, got %v", err) - } - repo.AssertExpectations(t) -} - -func TestRolePermissionServiceGetRolePermissionsRejectsEmptyRoleID(t *testing.T) { - t.Parallel() - - svc, _ := newRolePermissionServiceFixture() - - _, err := svc.GetRolePermissions(context.Background(), " ") - if !errors.Is(err, constants.ErrUnprocessableEntity) { - t.Fatalf("expected ErrUnprocessableEntity, got %v", err) - } -} - -func TestRolePermissionServiceGetRolePermissionsReturnsNotFoundForMissingRole(t *testing.T) { - t.Parallel() - - svc, repo := newRolePermissionServiceFixture() - repo.On("GetRoleByID", mock.Anything, "missing-role").Return((*types.Role)(nil), nil).Once() - - _, err := svc.GetRolePermissions(context.Background(), "missing-role") - if !errors.Is(err, constants.ErrNotFound) { - t.Fatalf("expected ErrNotFound, got %v", err) - } - repo.AssertExpectations(t) -} - -func TestRolePermissionServiceGetRolePermissionsReturnsAssignedPermissions(t *testing.T) { - t.Parallel() - - svc, repo := newRolePermissionServiceFixture() - ctx := context.Background() - roleID := "role-1" - permissionID := "perm-1" - permissionKey := "users.read" - - repo.On("GetRoleByID", mock.Anything, roleID). - Return(&types.Role{ID: roleID, Name: "RolePermReader"}, nil). - Once() - repo.On("GetRolePermissions", mock.Anything, roleID). - Return([]types.UserPermissionInfo{{PermissionID: permissionID, PermissionKey: permissionKey}}, nil). - Once() - - permissions, err := svc.GetRolePermissions(ctx, " "+roleID+" ") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if len(permissions) != 1 { - t.Fatalf("expected 1 permission, got %d", len(permissions)) - } - if permissions[0].PermissionID != permissionID { - t.Fatalf("expected permission id %q, got %q", permissionID, permissions[0].PermissionID) - } - if permissions[0].PermissionKey != permissionKey { - t.Fatalf("expected permission key %q, got %q", permissionKey, permissions[0].PermissionKey) - } - repo.AssertExpectations(t) -} +package services diff --git a/plugins/access-control/services/roles_service.go b/plugins/access-control/services/roles_service.go new file mode 100644 index 0000000..35ac7e2 --- /dev/null +++ b/plugins/access-control/services/roles_service.go @@ -0,0 +1,172 @@ +package services + +import ( + "context" + + "github.com/Authula/authula/internal/util" + "github.com/Authula/authula/plugins/access-control/constants" + "github.com/Authula/authula/plugins/access-control/repositories" + "github.com/Authula/authula/plugins/access-control/types" +) + +type RolesService struct { + rolesRepo repositories.RolesRepository + rolePermissionsRepo repositories.RolePermissionsRepository + userRolesRepo repositories.UserRolesRepository +} + +func NewRolesService(rolesRepo repositories.RolesRepository, rolePermissionsRepo repositories.RolePermissionsRepository, userRolesRepo repositories.UserRolesRepository) *RolesService { + return &RolesService{rolesRepo: rolesRepo, rolePermissionsRepo: rolePermissionsRepo, userRolesRepo: userRolesRepo} +} + +func (s *RolesService) CreateRole(ctx context.Context, req types.CreateRoleRequest) (*types.Role, error) { + if req.Name == "" { + return nil, constants.ErrBadRequest + } + + var description *string + if req.Description != nil { + description = req.Description + } + + role := &types.Role{ + ID: util.GenerateUUID(), + Name: req.Name, + Description: description, + IsSystem: req.IsSystem, + } + + if err := s.rolesRepo.CreateRole(ctx, role); err != nil { + return nil, err + } + + return role, nil +} + +func (s *RolesService) GetAllRoles(ctx context.Context) ([]types.Role, error) { + return s.rolesRepo.GetAllRoles(ctx) +} + +func (s *RolesService) GetRoleByName(ctx context.Context, roleName string) (*types.Role, error) { + if roleName == "" { + return nil, constants.ErrBadRequest + } + + role, err := s.rolesRepo.GetRoleByName(ctx, roleName) + if err != nil { + return nil, err + } + if role == nil { + return nil, constants.ErrNotFound + } + + return role, nil +} + +func (s *RolesService) GetRoleByID(ctx context.Context, roleID string) (*types.RoleDetails, error) { + if roleID == "" { + return nil, constants.ErrBadRequest + } + + role, err := s.rolesRepo.GetRoleByID(ctx, roleID) + if err != nil { + return nil, err + } + if role == nil { + return nil, constants.ErrNotFound + } + + permissions, err := s.rolePermissionsRepo.GetRolePermissions(ctx, roleID) + if err != nil { + return nil, err + } + + return &types.RoleDetails{Role: *role, Permissions: permissions}, nil +} + +func (s *RolesService) UpdateRole(ctx context.Context, roleID string, req types.UpdateRoleRequest) (*types.Role, error) { + if roleID == "" { + return nil, constants.ErrBadRequest + } + + if req.Name == nil && req.Description == nil { + return nil, constants.ErrUnprocessableEntity + } + + role, err := s.rolesRepo.GetRoleByID(ctx, roleID) + if err != nil { + return nil, err + } + if role == nil { + return nil, constants.ErrNotFound + } + if role.IsSystem { + return nil, constants.ErrCannotUpdateSystemRole + } + + var name *string + if req.Name != nil { + if *req.Name == "" { + return nil, constants.ErrBadRequest + } + name = req.Name + } + + var description *string + if req.Description != nil { + description = req.Description + } + + updated, err := s.rolesRepo.UpdateRole(ctx, roleID, name, description) + if err != nil { + return nil, err + } + if !updated { + return nil, constants.ErrNotFound + } + + role, err = s.rolesRepo.GetRoleByID(ctx, roleID) + if err != nil { + return nil, err + } + if role == nil { + return nil, constants.ErrNotFound + } + + return role, nil +} + +func (s *RolesService) DeleteRole(ctx context.Context, roleID string) error { + if roleID == "" { + return constants.ErrBadRequest + } + + role, err := s.rolesRepo.GetRoleByID(ctx, roleID) + if err != nil { + return err + } + if role == nil { + return constants.ErrNotFound + } + if role.IsSystem { + return constants.ErrCannotUpdateSystemRole + } + + totalUsersByRole, err := s.userRolesRepo.CountUsersByRole(ctx, roleID) + if err != nil { + return err + } + if totalUsersByRole > 0 { + return constants.ErrConflict + } + + deleted, err := s.rolesRepo.DeleteRole(ctx, roleID) + if err != nil { + return err + } + if !deleted { + return constants.ErrNotFound + } + + return nil +} diff --git a/plugins/access-control/services/roles_service_test.go b/plugins/access-control/services/roles_service_test.go new file mode 100644 index 0000000..6f588e5 --- /dev/null +++ b/plugins/access-control/services/roles_service_test.go @@ -0,0 +1,149 @@ +package services + +import ( + "context" + "testing" + + "github.com/stretchr/testify/mock" + + accesscontrolconstants "github.com/Authula/authula/plugins/access-control/constants" + accesscontroltests "github.com/Authula/authula/plugins/access-control/tests" + "github.com/Authula/authula/plugins/access-control/types" +) + +func TestRolesServiceGetRoleByID(t *testing.T) { + t.Parallel() + + rolesRepo := &accesscontroltests.MockRolesRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} + userRolesRepo := &accesscontroltests.MockUserRolesRepository{} + + rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "admin"}, nil).Once() + rolePermissionsRepo.On("GetRolePermissions", mock.Anything, "role-1").Return([]types.UserPermissionInfo{{PermissionID: "perm-1", PermissionKey: "users.read"}}, nil).Once() + + service := NewRolesService(rolesRepo, rolePermissionsRepo, userRolesRepo) + details, err := service.GetRoleByID(context.Background(), "role-1") + if err != nil { + t.Fatalf("expected nil err, got %v", err) + } + if details == nil || details.Role.ID != "role-1" || len(details.Permissions) != 1 { + t.Fatalf("unexpected details %#v", details) + } + + rolesRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) +} + +func TestRolesServiceGetRoleByName(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + roleName string + setup func(*accesscontroltests.MockRolesRepository, *accesscontroltests.MockUserRolesRepository) + wantErr error + }{ + { + name: "bad request", + roleName: "", + wantErr: accesscontrolconstants.ErrBadRequest, + }, + { + name: "not found", + roleName: "missing", + setup: func(rolesRepo *accesscontroltests.MockRolesRepository, userRolesRepo *accesscontroltests.MockUserRolesRepository) { + rolesRepo.On("GetRoleByName", mock.Anything, "missing").Return((*types.Role)(nil), nil).Once() + }, + wantErr: accesscontrolconstants.ErrNotFound, + }, + { + name: "success", + roleName: "admin", + setup: func(rolesRepo *accesscontroltests.MockRolesRepository, userRolesRepo *accesscontroltests.MockUserRolesRepository) { + rolesRepo.On("GetRoleByName", mock.Anything, "admin").Return(&types.Role{ID: "role-1", Name: "admin"}, nil).Once() + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + rolesRepo := &accesscontroltests.MockRolesRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} + userRolesRepo := &accesscontroltests.MockUserRolesRepository{} + if tc.setup != nil { + tc.setup(rolesRepo, userRolesRepo) + } + + service := NewRolesService(rolesRepo, rolePermissionsRepo, userRolesRepo) + role, err := service.GetRoleByName(context.Background(), tc.roleName) + if err != tc.wantErr { + t.Fatalf("expected err %v, got %v", tc.wantErr, err) + } + if tc.wantErr != nil { + if role != nil { + t.Fatalf("expected nil role, got %#v", role) + } + } else { + if role == nil || role.ID != "role-1" || role.Name != "admin" { + t.Fatalf("unexpected role %#v", role) + } + } + + rolesRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + }) + } +} + +func TestRolesServiceDeleteRole(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setup func(*accesscontroltests.MockRolesRepository, *accesscontroltests.MockUserRolesRepository) + wantErr error + }{ + { + name: "role in use", + setup: func(rolesRepo *accesscontroltests.MockRolesRepository, userRolesRepo *accesscontroltests.MockUserRolesRepository) { + rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "admin"}, nil).Once() + userRolesRepo.On("CountUsersByRole", mock.Anything, "role-1").Return(1, nil).Once() + }, + wantErr: accesscontrolconstants.ErrConflict, + }, + { + name: "success", + setup: func(rolesRepo *accesscontroltests.MockRolesRepository, userRolesRepo *accesscontroltests.MockUserRolesRepository) { + rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "admin"}, nil).Once() + userRolesRepo.On("CountUsersByRole", mock.Anything, "role-1").Return(0, nil).Once() + rolesRepo.On("DeleteRole", mock.Anything, "role-1").Return(true, nil).Once() + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + rolesRepo := &accesscontroltests.MockRolesRepository{} + rolePermissionsRepo := &accesscontroltests.MockRolePermissionsRepository{} + userRolesRepo := &accesscontroltests.MockUserRolesRepository{} + if tc.setup != nil { + tc.setup(rolesRepo, userRolesRepo) + } + + service := NewRolesService(rolesRepo, rolePermissionsRepo, userRolesRepo) + err := service.DeleteRole(context.Background(), "role-1") + if err != tc.wantErr { + t.Fatalf("expected err %v, got %v", tc.wantErr, err) + } + + rolesRepo.AssertExpectations(t) + rolePermissionsRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) + }) + } +} diff --git a/plugins/access-control/services/user_access_service.go b/plugins/access-control/services/user_access_service.go deleted file mode 100644 index 4db4d0c..0000000 --- a/plugins/access-control/services/user_access_service.go +++ /dev/null @@ -1,95 +0,0 @@ -package services - -import ( - "context" - "strings" - - "github.com/Authula/authula/plugins/access-control/constants" - "github.com/Authula/authula/plugins/access-control/repositories" - "github.com/Authula/authula/plugins/access-control/types" -) - -type UserAccessService struct { - userAccessRepo repositories.UserAccessRepository -} - -func NewUserAccessService(repo repositories.UserAccessRepository) *UserAccessService { - return &UserAccessService{userAccessRepo: repo} -} - -func (s *UserAccessService) GetUserRoles(ctx context.Context, userID string) ([]types.UserRoleInfo, error) { - if strings.TrimSpace(userID) == "" { - return nil, constants.ErrUnprocessableEntity - } - return s.userAccessRepo.GetUserRoles(ctx, userID) -} - -func (s *UserAccessService) GetUserWithRolesByID(ctx context.Context, userID string) (*types.UserWithRoles, error) { - if strings.TrimSpace(userID) == "" { - return nil, constants.ErrUnprocessableEntity - } - return s.userAccessRepo.GetUserWithRolesByID(ctx, userID) -} - -func (s *UserAccessService) GetUserWithPermissionsByID(ctx context.Context, userID string) (*types.UserWithPermissions, error) { - if strings.TrimSpace(userID) == "" { - return nil, constants.ErrUnprocessableEntity - } - return s.userAccessRepo.GetUserWithPermissionsByID(ctx, userID) -} - -func (s *UserAccessService) GetUserAuthorizationProfile(ctx context.Context, userID string) (*types.UserAuthorizationProfile, error) { - withRoles, err := s.GetUserWithRolesByID(ctx, userID) - if err != nil { - return nil, err - } - if withRoles == nil { - return nil, nil - } - - withPermissions, err := s.GetUserWithPermissionsByID(ctx, userID) - if err != nil { - return nil, err - } - - profile := &types.UserAuthorizationProfile{ - User: withRoles.User, - Roles: withRoles.Roles, - } - if withPermissions != nil { - profile.Permissions = withPermissions.Permissions - } - - return profile, nil -} - -func (s *UserAccessService) GetUserEffectivePermissions(ctx context.Context, userID string) ([]types.UserPermissionInfo, error) { - if strings.TrimSpace(userID) == "" { - return nil, constants.ErrUnprocessableEntity - } - return s.userAccessRepo.GetUserEffectivePermissions(ctx, userID) -} - -func (s *UserAccessService) HasPermissions(ctx context.Context, userID string, requiredPermissions []string) (bool, error) { - permissions, err := s.GetUserEffectivePermissions(ctx, userID) - if err != nil { - return false, err - } - - granted := make(map[string]struct{}, len(permissions)) - for _, permission := range permissions { - granted[permission.PermissionKey] = struct{}{} - } - - for _, required := range requiredPermissions { - required = strings.TrimSpace(required) - if required == "" { - continue - } - if _, ok := granted[required]; ok { - return true, nil - } - } - - return false, nil -} diff --git a/plugins/access-control/services/user_access_service_test.go b/plugins/access-control/services/user_access_service_test.go deleted file mode 100644 index 00fa6dd..0000000 --- a/plugins/access-control/services/user_access_service_test.go +++ /dev/null @@ -1,77 +0,0 @@ -package services - -import ( - "context" - "errors" - "testing" - - "github.com/Authula/authula/plugins/access-control/constants" - "github.com/Authula/authula/plugins/access-control/types" -) - -type stubUserAccessRepo struct { - rolesResult []types.UserRoleInfo - rolesErr error - permissionsResult []types.UserPermissionInfo - permissionsErr error - withRolesResult *types.UserWithRoles - withRolesErr error - withPermsResult *types.UserWithPermissions - withPermsErr error -} - -func (s *stubUserAccessRepo) GetUserRoles(_ context.Context, _ string) ([]types.UserRoleInfo, error) { - return s.rolesResult, s.rolesErr -} - -func (s *stubUserAccessRepo) GetUserEffectivePermissions(_ context.Context, _ string) ([]types.UserPermissionInfo, error) { - return s.permissionsResult, s.permissionsErr -} - -func (s *stubUserAccessRepo) GetUserWithRolesByID(_ context.Context, _ string) (*types.UserWithRoles, error) { - return s.withRolesResult, s.withRolesErr -} - -func (s *stubUserAccessRepo) GetUserWithPermissionsByID(_ context.Context, _ string) (*types.UserWithPermissions, error) { - return s.withPermsResult, s.withPermsErr -} - -func TestUserAccessServiceGetUserRolesUnprocessableEntity(t *testing.T) { - t.Parallel() - - svc := NewUserAccessService(&stubUserAccessRepo{}) - _, err := svc.GetUserRoles(context.Background(), " ") - if !errors.Is(err, constants.ErrUnprocessableEntity) { - t.Fatalf("expected ErrUnprocessableEntity, got %v", err) - } -} - -func TestUserAccessServiceHasPermissionsMatchesAnyRequiredPermission(t *testing.T) { - t.Parallel() - - svc := NewUserAccessService(&stubUserAccessRepo{ - permissionsResult: []types.UserPermissionInfo{{PermissionKey: "users.read"}, {PermissionKey: "users.write"}}, - }) - - ok, err := svc.HasPermissions(context.Background(), "user-1", []string{"billing.read", "users.write"}) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if !ok { - t.Fatal("expected permission check to pass when any required permission matches") - } -} - -func TestUserAccessServiceGetUserAuthorizationProfileNilUser(t *testing.T) { - t.Parallel() - - svc := NewUserAccessService(&stubUserAccessRepo{withRolesResult: nil}) - - profile, err := svc.GetUserAuthorizationProfile(context.Background(), "user-1") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if profile != nil { - t.Fatalf("expected nil profile, got %+v", profile) - } -} diff --git a/plugins/access-control/services/user_permissions_service.go b/plugins/access-control/services/user_permissions_service.go new file mode 100644 index 0000000..5a92cf1 --- /dev/null +++ b/plugins/access-control/services/user_permissions_service.go @@ -0,0 +1,33 @@ +package services + +import ( + "context" + + "github.com/Authula/authula/plugins/access-control/constants" + "github.com/Authula/authula/plugins/access-control/repositories" + "github.com/Authula/authula/plugins/access-control/types" +) + +type UserPermissionsService struct { + repo repositories.UserPermissionsRepository +} + +func NewUserPermissionsService(repo repositories.UserPermissionsRepository) *UserPermissionsService { + return &UserPermissionsService{repo: repo} +} + +func (s *UserPermissionsService) GetUserPermissions(ctx context.Context, userID string) ([]types.UserPermissionInfo, error) { + if userID == "" { + return nil, constants.ErrUnprocessableEntity + } + + return s.repo.GetUserPermissions(ctx, userID) +} + +func (s *UserPermissionsService) HasPermissions(ctx context.Context, userID string, permissionKeys []string) (bool, error) { + if userID == "" { + return false, constants.ErrUnprocessableEntity + } + + return s.repo.HasPermissions(ctx, userID, permissionKeys) +} diff --git a/plugins/access-control/services/user_permissions_service_test.go b/plugins/access-control/services/user_permissions_service_test.go new file mode 100644 index 0000000..90a6f83 --- /dev/null +++ b/plugins/access-control/services/user_permissions_service_test.go @@ -0,0 +1,139 @@ +package services + +import ( + "context" + "testing" + + "github.com/stretchr/testify/mock" + + accesscontrolconstants "github.com/Authula/authula/plugins/access-control/constants" + accesscontroltests "github.com/Authula/authula/plugins/access-control/tests" + "github.com/Authula/authula/plugins/access-control/types" +) + +func TestUserPermissionsServiceGetUserPermissions(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + userID string + setupMock func(*accesscontroltests.MockUserPermissionsRepository) + expectedCount int + expectedErr error + }{ + { + name: "blank user id", + userID: "", + expectedErr: accesscontrolconstants.ErrUnprocessableEntity, + }, + { + name: "success", + userID: "u1", + setupMock: func(m *accesscontroltests.MockUserPermissionsRepository) { + m.On("GetUserPermissions", mock.Anything, "u1").Return([]types.UserPermissionInfo{{PermissionID: "perm-1", PermissionKey: "users.read"}}, nil).Once() + }, + expectedCount: 1, + }, + { + name: "repo error", + userID: "u1", + setupMock: func(m *accesscontroltests.MockUserPermissionsRepository) { + m.On("GetUserPermissions", mock.Anything, "u1").Return(([]types.UserPermissionInfo)(nil), accesscontrolconstants.ErrNotFound).Once() + }, + expectedErr: accesscontrolconstants.ErrNotFound, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + repo := &accesscontroltests.MockUserPermissionsRepository{} + if tc.setupMock != nil { + tc.setupMock(repo) + } + + service := NewUserPermissionsService(repo) + permissions, err := service.GetUserPermissions(context.Background(), tc.userID) + if tc.expectedErr != nil { + if err != tc.expectedErr { + t.Fatalf("expected err %v, got %v", tc.expectedErr, err) + } + repo.AssertExpectations(t) + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(permissions) != tc.expectedCount { + t.Fatalf("expected %d permissions, got %d", tc.expectedCount, len(permissions)) + } + repo.AssertExpectations(t) + }) + } +} + +func TestUserPermissionsServiceHasPermissions(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + userID string + permissionKeys []string + setupMock func(*accesscontroltests.MockUserPermissionsRepository) + expectedHas bool + expectedErr error + }{ + { + name: "blank user id", + userID: "", + expectedErr: accesscontrolconstants.ErrUnprocessableEntity, + }, + { + name: "success", + userID: "u1", + permissionKeys: []string{"users.read"}, + setupMock: func(m *accesscontroltests.MockUserPermissionsRepository) { + m.On("HasPermissions", mock.Anything, "u1", []string{"users.read"}).Return(true, nil).Once() + }, + expectedHas: true, + }, + { + name: "repo error", + userID: "u1", + permissionKeys: []string{"users.read"}, + setupMock: func(m *accesscontroltests.MockUserPermissionsRepository) { + m.On("HasPermissions", mock.Anything, "u1", []string{"users.read"}).Return(false, accesscontrolconstants.ErrForbidden).Once() + }, + expectedErr: accesscontrolconstants.ErrForbidden, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + repo := &accesscontroltests.MockUserPermissionsRepository{} + if tc.setupMock != nil { + tc.setupMock(repo) + } + + service := NewUserPermissionsService(repo) + hasPermissions, err := service.HasPermissions(context.Background(), tc.userID, tc.permissionKeys) + if tc.expectedErr != nil { + if err != tc.expectedErr { + t.Fatalf("expected err %v, got %v", tc.expectedErr, err) + } + repo.AssertExpectations(t) + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if hasPermissions != tc.expectedHas { + t.Fatalf("expected hasPermissions=%v, got %v", tc.expectedHas, hasPermissions) + } + repo.AssertExpectations(t) + }) + } +} diff --git a/plugins/access-control/services/user_roles_service.go b/plugins/access-control/services/user_roles_service.go new file mode 100644 index 0000000..05d8aa1 --- /dev/null +++ b/plugins/access-control/services/user_roles_service.go @@ -0,0 +1,90 @@ +package services + +import ( + "context" + "time" + + "github.com/Authula/authula/plugins/access-control/constants" + "github.com/Authula/authula/plugins/access-control/repositories" + "github.com/Authula/authula/plugins/access-control/types" +) + +type UserRolesService struct { + userRolesRepo repositories.UserRolesRepository + rolesRepo repositories.RolesRepository +} + +func NewUserRolesService(userRolesRepo repositories.UserRolesRepository, rolesRepo repositories.RolesRepository) *UserRolesService { + return &UserRolesService{userRolesRepo: userRolesRepo, rolesRepo: rolesRepo} +} + +func (s *UserRolesService) GetUserRoles(ctx context.Context, userID string) ([]types.UserRoleInfo, error) { + if userID == "" { + return nil, constants.ErrUnprocessableEntity + } + + return s.userRolesRepo.GetUserRoles(ctx, userID) +} + +func (s *UserRolesService) ReplaceUserRoles(ctx context.Context, userID string, roleIDs []string, assignedByUserID *string) error { + if userID == "" { + return constants.ErrBadRequest + } + + normalized := make([]string, 0, len(roleIDs)) + seen := make(map[string]struct{}, len(roleIDs)) + for _, roleID := range roleIDs { + if roleID == "" { + continue + } + if _, ok := seen[roleID]; ok { + continue + } + + role, err := s.rolesRepo.GetRoleByID(ctx, roleID) + if err != nil { + return err + } + if role == nil { + return constants.ErrNotFound + } + + seen[roleID] = struct{}{} + normalized = append(normalized, roleID) + } + + return s.userRolesRepo.ReplaceUserRoles(ctx, userID, normalized, assignedByUserID) +} + +func (s *UserRolesService) AssignRoleToUser(ctx context.Context, userID string, req types.AssignUserRoleRequest, assignedByUserID *string) error { + if userID == "" { + return constants.ErrBadRequest + } + + roleID := req.RoleID + if roleID == "" { + return constants.ErrBadRequest + } + + if req.ExpiresAt != nil && req.ExpiresAt.Before(time.Now().UTC()) { + return constants.ErrBadRequest + } + + role, err := s.rolesRepo.GetRoleByID(ctx, roleID) + if err != nil { + return err + } + if role == nil { + return constants.ErrNotFound + } + + return s.userRolesRepo.AssignUserRole(ctx, userID, roleID, assignedByUserID, req.ExpiresAt) +} + +func (s *UserRolesService) RemoveRoleFromUser(ctx context.Context, userID string, roleID string) error { + if userID == "" || roleID == "" { + return constants.ErrBadRequest + } + + return s.userRolesRepo.RemoveUserRole(ctx, userID, roleID) +} diff --git a/plugins/access-control/services/user_roles_service_test.go b/plugins/access-control/services/user_roles_service_test.go new file mode 100644 index 0000000..b96ed2c --- /dev/null +++ b/plugins/access-control/services/user_roles_service_test.go @@ -0,0 +1,78 @@ +package services + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/mock" + + accesscontrolconstants "github.com/Authula/authula/plugins/access-control/constants" + accesscontroltests "github.com/Authula/authula/plugins/access-control/tests" + "github.com/Authula/authula/plugins/access-control/types" +) + +func TestUserRolesServiceAssignRoleToUser(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + req types.AssignUserRoleRequest + setup func(*accesscontroltests.MockUserRolesRepository, *accesscontroltests.MockRolesRepository) + wantErr error + }{ + { + name: "expired assignment", + req: types.AssignUserRoleRequest{RoleID: "role-1", ExpiresAt: func() *time.Time { t := time.Now().UTC().Add(-time.Hour); return &t }()}, + wantErr: accesscontrolconstants.ErrBadRequest, + }, + { + name: "success", + req: types.AssignUserRoleRequest{RoleID: "role-1"}, + setup: func(userRolesRepo *accesscontroltests.MockUserRolesRepository, rolesRepo *accesscontroltests.MockRolesRepository) { + rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "admin"}, nil).Once() + userRolesRepo.On("AssignUserRole", mock.Anything, "user-1", "role-1", (*string)(nil), (*time.Time)(nil)).Return(nil).Once() + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + userRolesRepo := &accesscontroltests.MockUserRolesRepository{} + rolesRepo := &accesscontroltests.MockRolesRepository{} + if tc.setup != nil { + tc.setup(userRolesRepo, rolesRepo) + } + + service := NewUserRolesService(userRolesRepo, rolesRepo) + err := service.AssignRoleToUser(context.Background(), "user-1", tc.req, nil) + if err != tc.wantErr { + t.Fatalf("expected err %v, got %v", tc.wantErr, err) + } + + userRolesRepo.AssertExpectations(t) + rolesRepo.AssertExpectations(t) + }) + } +} + +func TestUserRolesServiceReplaceUserRolesDedupes(t *testing.T) { + t.Parallel() + + userRolesRepo := &accesscontroltests.MockUserRolesRepository{} + rolesRepo := &accesscontroltests.MockRolesRepository{} + rolesRepo.On("GetRoleByID", mock.Anything, "role-1").Return(&types.Role{ID: "role-1", Name: "admin"}, nil).Once() + rolesRepo.On("GetRoleByID", mock.Anything, "role-2").Return(&types.Role{ID: "role-2", Name: "editor"}, nil).Once() + userRolesRepo.On("ReplaceUserRoles", mock.Anything, "user-1", []string{"role-1", "role-2"}, (*string)(nil)).Return(nil).Once() + + service := NewUserRolesService(userRolesRepo, rolesRepo) + err := service.ReplaceUserRoles(context.Background(), "user-1", []string{"role-1", "role-1", "role-2"}, nil) + if err != nil { + t.Fatalf("expected nil err, got %v", err) + } + + rolesRepo.AssertExpectations(t) + userRolesRepo.AssertExpectations(t) +} diff --git a/plugins/access-control/tests/test_helpers.go b/plugins/access-control/tests/test_helpers.go index ce2bc66..ddfbcc4 100644 --- a/plugins/access-control/tests/test_helpers.go +++ b/plugins/access-control/tests/test_helpers.go @@ -2,71 +2,32 @@ package tests import ( "context" + "database/sql" + "testing" "time" + _ "github.com/mattn/go-sqlite3" "github.com/stretchr/testify/mock" + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect/sqlitedialect" - "github.com/Authula/authula/plugins/access-control/services" + internaltests "github.com/Authula/authula/internal/tests" + "github.com/Authula/authula/migrations" + "github.com/Authula/authula/models" + accesscontrolmigrations "github.com/Authula/authula/plugins/access-control/migrationset" "github.com/Authula/authula/plugins/access-control/types" - "github.com/Authula/authula/plugins/access-control/usecases" ) -type mockUserAccessRepository struct { +type MockRolesRepository struct { mock.Mock } -func (m *mockUserAccessRepository) GetUserRoles(ctx context.Context, userID string) ([]types.UserRoleInfo, error) { - args := m.Called(ctx, userID) - if args.Get(0) == nil { - return nil, args.Error(1) - } - return args.Get(0).([]types.UserRoleInfo), args.Error(1) -} - -func (m *mockUserAccessRepository) GetUserEffectivePermissions(ctx context.Context, userID string) ([]types.UserPermissionInfo, error) { - args := m.Called(ctx, userID) - if args.Get(0) == nil { - return nil, args.Error(1) - } - return args.Get(0).([]types.UserPermissionInfo), args.Error(1) -} - -func (m *mockUserAccessRepository) GetUserWithRolesByID(ctx context.Context, userID string) (*types.UserWithRoles, error) { - args := m.Called(ctx, userID) - if args.Get(0) == nil { - return nil, args.Error(1) - } - return args.Get(0).(*types.UserWithRoles), args.Error(1) -} - -func (m *mockUserAccessRepository) GetUserWithPermissionsByID(ctx context.Context, userID string) (*types.UserWithPermissions, error) { - args := m.Called(ctx, userID) - if args.Get(0) == nil { - return nil, args.Error(1) - } - return args.Get(0).(*types.UserWithPermissions), args.Error(1) -} - -func NewUserRolesUseCaseFixture() (usecases.UserRolesUseCase, *mockUserAccessRepository) { - repo := &mockUserAccessRepository{} - service := services.NewUserAccessService(repo) - return usecases.NewUserRolesUseCase(service), repo -} - -type MockRolePermissionRepository struct { - mock.Mock -} - -func NewMockRolePermissionRepository() *MockRolePermissionRepository { - return &MockRolePermissionRepository{} -} - -func (m *MockRolePermissionRepository) CreateRole(ctx context.Context, role *types.Role) error { +func (m *MockRolesRepository) CreateRole(ctx context.Context, role *types.Role) error { args := m.Called(ctx, role) return args.Error(0) } -func (m *MockRolePermissionRepository) GetAllRoles(ctx context.Context) ([]types.Role, error) { +func (m *MockRolesRepository) GetAllRoles(ctx context.Context) ([]types.Role, error) { args := m.Called(ctx) if args.Get(0) == nil { return nil, args.Error(1) @@ -74,7 +35,7 @@ func (m *MockRolePermissionRepository) GetAllRoles(ctx context.Context) ([]types return args.Get(0).([]types.Role), args.Error(1) } -func (m *MockRolePermissionRepository) GetRoleByID(ctx context.Context, roleID string) (*types.Role, error) { +func (m *MockRolesRepository) GetRoleByID(ctx context.Context, roleID string) (*types.Role, error) { args := m.Called(ctx, roleID) if args.Get(0) == nil { return nil, args.Error(1) @@ -82,22 +43,29 @@ func (m *MockRolePermissionRepository) GetRoleByID(ctx context.Context, roleID s return args.Get(0).(*types.Role), args.Error(1) } -func (m *MockRolePermissionRepository) UpdateRole(ctx context.Context, roleID string, name *string, description *string) (bool, error) { +func (m *MockRolesRepository) GetRoleByName(ctx context.Context, roleName string) (*types.Role, error) { + args := m.Called(ctx, roleName) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*types.Role), args.Error(1) +} + +func (m *MockRolesRepository) UpdateRole(ctx context.Context, roleID string, name *string, description *string) (bool, error) { args := m.Called(ctx, roleID, name, description) return args.Bool(0), args.Error(1) } -func (m *MockRolePermissionRepository) DeleteRole(ctx context.Context, roleID string) (bool, error) { +func (m *MockRolesRepository) DeleteRole(ctx context.Context, roleID string) (bool, error) { args := m.Called(ctx, roleID) return args.Bool(0), args.Error(1) } -func (m *MockRolePermissionRepository) CountUserAssignmentsByRoleID(ctx context.Context, roleID string) (int, error) { - args := m.Called(ctx, roleID) - return args.Int(0), args.Error(1) +type MockPermissionsRepository struct { + mock.Mock } -func (m *MockRolePermissionRepository) GetAllPermissions(ctx context.Context) ([]types.Permission, error) { +func (m *MockPermissionsRepository) GetAllPermissions(ctx context.Context) ([]types.Permission, error) { args := m.Called(ctx) if args.Get(0) == nil { return nil, args.Error(1) @@ -105,7 +73,7 @@ func (m *MockRolePermissionRepository) GetAllPermissions(ctx context.Context) ([ return args.Get(0).([]types.Permission), args.Error(1) } -func (m *MockRolePermissionRepository) GetPermissionByID(ctx context.Context, permissionID string) (*types.Permission, error) { +func (m *MockPermissionsRepository) GetPermissionByID(ctx context.Context, permissionID string) (*types.Permission, error) { args := m.Called(ctx, permissionID) if args.Get(0) == nil { return nil, args.Error(1) @@ -113,27 +81,34 @@ func (m *MockRolePermissionRepository) GetPermissionByID(ctx context.Context, pe return args.Get(0).(*types.Permission), args.Error(1) } -func (m *MockRolePermissionRepository) CreatePermission(ctx context.Context, permission *types.Permission) error { +func (m *MockPermissionsRepository) GetPermissionByKey(ctx context.Context, permissionKey string) (*types.Permission, error) { + args := m.Called(ctx, permissionKey) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*types.Permission), args.Error(1) +} + +func (m *MockPermissionsRepository) CreatePermission(ctx context.Context, permission *types.Permission) error { args := m.Called(ctx, permission) return args.Error(0) } -func (m *MockRolePermissionRepository) UpdatePermission(ctx context.Context, permissionID string, description *string) (bool, error) { +func (m *MockPermissionsRepository) UpdatePermission(ctx context.Context, permissionID string, description *string) (bool, error) { args := m.Called(ctx, permissionID, description) return args.Bool(0), args.Error(1) } -func (m *MockRolePermissionRepository) DeletePermission(ctx context.Context, permissionID string) (bool, error) { +func (m *MockPermissionsRepository) DeletePermission(ctx context.Context, permissionID string) (bool, error) { args := m.Called(ctx, permissionID) return args.Bool(0), args.Error(1) } -func (m *MockRolePermissionRepository) CountRoleAssignmentsByPermissionID(ctx context.Context, permissionID string) (int, error) { - args := m.Called(ctx, permissionID) - return args.Int(0), args.Error(1) +type MockRolePermissionsRepository struct { + mock.Mock } -func (m *MockRolePermissionRepository) GetRolePermissions(ctx context.Context, roleID string) ([]types.UserPermissionInfo, error) { +func (m *MockRolePermissionsRepository) GetRolePermissions(ctx context.Context, roleID string) ([]types.UserPermissionInfo, error) { args := m.Called(ctx, roleID) if args.Get(0) == nil { return nil, args.Error(1) @@ -141,38 +116,116 @@ func (m *MockRolePermissionRepository) GetRolePermissions(ctx context.Context, r return args.Get(0).([]types.UserPermissionInfo), args.Error(1) } -func (m *MockRolePermissionRepository) ReplaceRolePermissions(ctx context.Context, roleID string, permissionIDs []string, grantedByUserID *string) error { +func (m *MockRolePermissionsRepository) ReplaceRolePermissions(ctx context.Context, roleID string, permissionIDs []string, grantedByUserID *string) error { args := m.Called(ctx, roleID, permissionIDs, grantedByUserID) return args.Error(0) } -func (m *MockRolePermissionRepository) AddRolePermission(ctx context.Context, roleID string, permissionID string, grantedByUserID *string) error { +func (m *MockRolePermissionsRepository) AddRolePermission(ctx context.Context, roleID string, permissionID string, grantedByUserID *string) error { args := m.Called(ctx, roleID, permissionID, grantedByUserID) return args.Error(0) } -func (m *MockRolePermissionRepository) RemoveRolePermission(ctx context.Context, roleID string, permissionID string) error { +func (m *MockRolePermissionsRepository) RemoveRolePermission(ctx context.Context, roleID string, permissionID string) error { args := m.Called(ctx, roleID, permissionID) return args.Error(0) } -func (m *MockRolePermissionRepository) ReplaceUserRoles(ctx context.Context, userID string, roleIDs []string, assignedByUserID *string) error { +func (m *MockRolePermissionsRepository) CountRolesByPermission(ctx context.Context, permissionID string) (int, error) { + args := m.Called(ctx, permissionID) + return args.Int(0), args.Error(1) +} + +type MockUserRolesRepository struct { + mock.Mock +} + +func (m *MockUserRolesRepository) GetUserRoles(ctx context.Context, userID string) ([]types.UserRoleInfo, error) { + args := m.Called(ctx, userID) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]types.UserRoleInfo), args.Error(1) +} + +func (m *MockUserRolesRepository) ReplaceUserRoles(ctx context.Context, userID string, roleIDs []string, assignedByUserID *string) error { args := m.Called(ctx, userID, roleIDs, assignedByUserID) return args.Error(0) } -func (m *MockRolePermissionRepository) AssignUserRole(ctx context.Context, userID string, roleID string, assignedByUserID *string, expiresAt *time.Time) error { +func (m *MockUserRolesRepository) AssignUserRole(ctx context.Context, userID string, roleID string, assignedByUserID *string, expiresAt *time.Time) error { args := m.Called(ctx, userID, roleID, assignedByUserID, expiresAt) return args.Error(0) } -func (m *MockRolePermissionRepository) RemoveUserRole(ctx context.Context, userID string, roleID string) error { +func (m *MockUserRolesRepository) RemoveUserRole(ctx context.Context, userID string, roleID string) error { args := m.Called(ctx, userID, roleID) return args.Error(0) } -func NewRolePermissionUseCaseFixture() (usecases.RolePermissionUseCase, *MockRolePermissionRepository) { - repo := NewMockRolePermissionRepository() - service := services.NewRolePermissionService(repo) - return usecases.NewRolePermissionUseCase(service), repo +func (m *MockUserRolesRepository) CountUsersByRole(ctx context.Context, roleID string) (int, error) { + args := m.Called(ctx, roleID) + return args.Int(0), args.Error(1) +} + +type MockUserPermissionsRepository struct { + mock.Mock +} + +func (m *MockUserPermissionsRepository) GetUserPermissions(ctx context.Context, userID string) ([]types.UserPermissionInfo, error) { + args := m.Called(ctx, userID) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]types.UserPermissionInfo), args.Error(1) +} + +func (m *MockUserPermissionsRepository) HasPermissions(ctx context.Context, userID string, permissionKeys []string) (bool, error) { + args := m.Called(ctx, userID, permissionKeys) + return args.Bool(0), args.Error(1) +} + +func SetupRepoDB(t *testing.T) *bun.DB { + t.Helper() + + sqldb, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("failed to open sqlite: %v", err) + } + + db := bun.NewDB(sqldb, sqlitedialect.New()) + t.Cleanup(func() { + _ = db.Close() + }) + + ctx := context.Background() + + migrator, err := migrations.NewMigrator(db, &internaltests.MockLogger{}) + if err != nil { + t.Fatalf("failed to create migrator: %v", err) + } + + coreSet, err := migrations.CoreMigrationSet("sqlite") + if err != nil { + t.Fatalf("failed to build core migration set: %v", err) + } + + accessControlSet := migrations.MigrationSet{ + PluginID: models.PluginAccessControl.String(), + DependsOn: []string{migrations.CorePluginID}, + Migrations: accesscontrolmigrations.ForProvider("sqlite"), + } + + if err := migrator.Migrate(ctx, []migrations.MigrationSet{coreSet, accessControlSet}); err != nil { + t.Fatalf("failed to run migrations: %v", err) + } + + if _, err := db.ExecContext(ctx, `INSERT INTO users (id, name, email, email_verified, metadata) VALUES ('u1', 'User One', 'u1@example.com', 1, '{}')`); err != nil { + t.Fatalf("failed to seed user u1: %v", err) + } + if _, err := db.ExecContext(ctx, `INSERT INTO users (id, name, email, email_verified, metadata) VALUES ('u2', 'User Two', 'u2@example.com', 1, '{}')`); err != nil { + t.Fatalf("failed to seed user u2: %v", err) + } + + return db } diff --git a/plugins/access-control/types/config.go b/plugins/access-control/types/config.go new file mode 100644 index 0000000..293b347 --- /dev/null +++ b/plugins/access-control/types/config.go @@ -0,0 +1,7 @@ +package types + +type AccessControlPluginConfig struct { + Enabled bool `json:"enabled" toml:"enabled"` +} + +func (config *AccessControlPluginConfig) ApplyDefaults() {} diff --git a/plugins/access-control/types/models.go b/plugins/access-control/types/models.go index 25a6112..4359497 100644 --- a/plugins/access-control/types/models.go +++ b/plugins/access-control/types/models.go @@ -4,12 +4,8 @@ import ( "time" "github.com/uptrace/bun" - - "github.com/Authula/authula/models" ) -// Models - type Role struct { bun.BaseModel `bun:"table:access_control_roles"` @@ -50,141 +46,3 @@ type UserRole struct { AssignedAt time.Time `json:"assigned_at" bun:"column:assigned_at"` ExpiresAt *time.Time `json:"expires_at" bun:"column:expires_at"` } - -// Types - -type CreateRoleRequest struct { - Name string `json:"name"` - Description *string `json:"description,omitempty"` - IsSystem bool `json:"is_system"` -} - -type CreateRoleResponse struct { - Role *Role `json:"role"` -} - -type UpdateRoleRequest struct { - Name *string `json:"name,omitempty"` - Description *string `json:"description,omitempty"` -} - -type UpdateRoleResponse struct { - Role *Role `json:"role"` -} - -type DeleteRoleResponse struct { - Message string `json:"message"` -} - -type CreatePermissionRequest struct { - Key string `json:"key"` - Description *string `json:"description,omitempty"` - IsSystem bool `json:"is_system"` -} - -type CreatePermissionResponse struct { - Permission *Permission `json:"permission"` -} - -type UpdatePermissionRequest struct { - Description *string `json:"description,omitempty"` -} - -type UpdatePermissionResponse struct { - Permission *Permission `json:"permission"` -} - -type DeletePermissionResponse struct { - Message string `json:"message"` -} - -type AddRolePermissionRequest struct { - PermissionID string `json:"permission_id"` -} - -type AddRolePermissionResponse struct { - Message string `json:"message"` -} - -type ReplaceRolePermissionsRequest struct { - PermissionIDs []string `json:"permission_ids"` -} - -type ReplaceRolePermissionResponse struct { - Message string `json:"message"` -} - -type RemoveRolePermissionResponse struct { - Message string `json:"message"` -} - -type AssignUserRoleRequest struct { - RoleID string `json:"role_id"` - ExpiresAt *time.Time `json:"expires_at,omitempty"` -} - -type ReplaceUserRolesRequest struct { - RoleIDs []string `json:"role_ids"` -} - -type ReplaceUserRolesResponse struct { - Message string `json:"message"` -} - -type AssignUserRoleResponse struct { - Message string `json:"message"` -} - -type RemoveUserRoleResponse struct { - Message string `json:"message"` -} - -type GetUserEffectivePermissionsResponse struct { - Permissions []UserPermissionInfo `json:"permissions"` -} - -type UserRoleInfo struct { - RoleID string `json:"role_id"` - RoleName string `json:"role_name"` - RoleDescription *string `json:"role_description,omitempty"` - AssignedByUserID *string `json:"assigned_by_user_id,omitempty"` - AssignedAt *time.Time `json:"assigned_at,omitempty"` - ExpiresAt *time.Time `json:"expires_at,omitempty"` -} - -type PermissionGrantSource struct { - RoleID string `json:"role_id"` - RoleName string `json:"role_name"` - GrantedByUserID *string `json:"granted_by_user_id,omitempty"` - GrantedAt *time.Time `json:"granted_at,omitempty"` -} - -type UserPermissionInfo struct { - PermissionID string `json:"permission_id"` - PermissionKey string `json:"permission_key"` - PermissionDescription *string `json:"permission_description,omitempty"` - GrantedByUserID *string `json:"granted_by_user_id,omitempty"` - GrantedAt *time.Time `json:"granted_at,omitempty"` - Sources []PermissionGrantSource `json:"sources,omitempty"` -} - -type UserWithRoles struct { - User models.User `json:"user"` - Roles []UserRoleInfo `json:"roles"` -} - -type UserWithPermissions struct { - User models.User `json:"user"` - Permissions []UserPermissionInfo `json:"permissions"` -} - -type UserAuthorizationProfile struct { - User models.User `json:"user"` - Roles []UserRoleInfo `json:"roles"` - Permissions []UserPermissionInfo `json:"permissions"` -} - -type RoleDetails struct { - Role Role `json:"role"` - Permissions []UserPermissionInfo `json:"permissions"` -} diff --git a/plugins/access-control/types/types.go b/plugins/access-control/types/types.go index c086c87..d39a224 100644 --- a/plugins/access-control/types/types.go +++ b/plugins/access-control/types/types.go @@ -1,8 +1,140 @@ package types -type AccessControlPluginConfig struct { - Enabled bool `json:"enabled" toml:"enabled"` - // TODO: Add a field to enable auto-cleanup of expired user roles and permissions +import ( + "time" + + "github.com/Authula/authula/models" +) + +type CreateRoleRequest struct { + Name string `json:"name"` + Description *string `json:"description,omitempty"` + IsSystem bool `json:"is_system"` +} + +type CreateRoleResponse struct { + Role *Role `json:"role"` +} + +type UpdateRoleRequest struct { + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` +} + +type UpdateRoleResponse struct { + Role *Role `json:"role"` +} + +type DeleteRoleResponse struct { + Message string `json:"message"` +} + +type CreatePermissionRequest struct { + Key string `json:"key"` + Description *string `json:"description,omitempty"` + IsSystem bool `json:"is_system"` +} + +type CreatePermissionResponse struct { + Permission *Permission `json:"permission"` +} + +type UpdatePermissionRequest struct { + Description *string `json:"description,omitempty"` +} + +type UpdatePermissionResponse struct { + Permission *Permission `json:"permission"` +} + +type DeletePermissionResponse struct { + Message string `json:"message"` +} + +type AddRolePermissionRequest struct { + PermissionID string `json:"permission_id"` +} + +type AddRolePermissionResponse struct { + Message string `json:"message"` +} + +type ReplaceRolePermissionsRequest struct { + PermissionIDs []string `json:"permission_ids"` +} + +type ReplaceRolePermissionResponse struct { + Message string `json:"message"` } -func (config *AccessControlPluginConfig) ApplyDefaults() {} +type RemoveRolePermissionResponse struct { + Message string `json:"message"` +} + +type AssignUserRoleRequest struct { + RoleID string `json:"role_id"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` +} + +type ReplaceUserRolesRequest struct { + RoleIDs []string `json:"role_ids"` +} + +type ReplaceUserRolesResponse struct { + Message string `json:"message"` +} + +type AssignUserRoleResponse struct { + Message string `json:"message"` +} + +type RemoveUserRoleResponse struct { + Message string `json:"message"` +} + +type CheckUserPermissionsRequest struct { + PermissionKeys []string `json:"permission_keys"` +} + +type CheckUserPermissionsResponse struct { + HasPermissions bool `json:"has_permissions"` +} + +type GetUserEffectivePermissionsResponse struct { + Permissions []UserPermissionInfo `json:"permissions"` +} + +type UserRoleInfo struct { + RoleID string `json:"role_id"` + RoleName string `json:"role_name"` + RoleDescription *string `json:"role_description,omitempty"` + AssignedByUserID *string `json:"assigned_by_user_id,omitempty"` + AssignedAt *time.Time `json:"assigned_at,omitempty"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` +} + +type PermissionGrantSource struct { + RoleID string `json:"role_id"` + RoleName string `json:"role_name"` + GrantedByUserID *string `json:"granted_by_user_id,omitempty"` + GrantedAt *time.Time `json:"granted_at,omitempty"` +} + +type UserPermissionInfo struct { + PermissionID string `json:"permission_id"` + PermissionKey string `json:"permission_key"` + PermissionDescription *string `json:"permission_description,omitempty"` + GrantedByUserID *string `json:"granted_by_user_id,omitempty"` + GrantedAt *time.Time `json:"granted_at,omitempty"` + Sources []PermissionGrantSource `json:"sources,omitempty"` +} + +type UserWithPermissions struct { + User models.User `json:"user"` + Permissions []UserPermissionInfo `json:"permissions"` +} + +type RoleDetails struct { + Role Role `json:"role"` + Permissions []UserPermissionInfo `json:"permissions"` +} diff --git a/plugins/access-control/usecases/permissions_usecase.go b/plugins/access-control/usecases/permissions_usecase.go new file mode 100644 index 0000000..5f2f246 --- /dev/null +++ b/plugins/access-control/usecases/permissions_usecase.go @@ -0,0 +1,40 @@ +package usecases + +import ( + "context" + + "github.com/Authula/authula/plugins/access-control/services" + "github.com/Authula/authula/plugins/access-control/types" +) + +type PermissionsUseCase struct { + service *services.PermissionsService +} + +func NewPermissionsUseCase(service *services.PermissionsService) *PermissionsUseCase { + return &PermissionsUseCase{service: service} +} + +func (u *PermissionsUseCase) CreatePermission(ctx context.Context, req types.CreatePermissionRequest) (*types.Permission, error) { + return u.service.CreatePermission(ctx, req) +} + +func (u *PermissionsUseCase) GetAllPermissions(ctx context.Context) ([]types.Permission, error) { + return u.service.GetAllPermissions(ctx) +} + +func (u *PermissionsUseCase) GetPermissionByID(ctx context.Context, permissionID string) (*types.Permission, error) { + return u.service.GetPermissionByID(ctx, permissionID) +} + +func (u *PermissionsUseCase) GetPermissionByKey(ctx context.Context, permissionKey string) (*types.Permission, error) { + return u.service.GetPermissionByKey(ctx, permissionKey) +} + +func (u *PermissionsUseCase) UpdatePermission(ctx context.Context, permissionID string, req types.UpdatePermissionRequest) (*types.Permission, error) { + return u.service.UpdatePermission(ctx, permissionID, req) +} + +func (u *PermissionsUseCase) DeletePermission(ctx context.Context, permissionID string) error { + return u.service.DeletePermission(ctx, permissionID) +} diff --git a/plugins/access-control/usecases/role_permission_mock_test.go b/plugins/access-control/usecases/role_permission_mock_test.go deleted file mode 100644 index a99c806..0000000 --- a/plugins/access-control/usecases/role_permission_mock_test.go +++ /dev/null @@ -1,126 +0,0 @@ -package usecases - -import ( - "context" - "time" - - "github.com/stretchr/testify/mock" - - "github.com/Authula/authula/plugins/access-control/types" -) - -type MockRolePermissionRepository struct { - mock.Mock -} - -type MockRolePermissionService = MockRolePermissionRepository - -func (m *MockRolePermissionRepository) CreateRole(ctx context.Context, role *types.Role) error { - args := m.Called(ctx, role) - return args.Error(0) -} - -func (m *MockRolePermissionRepository) GetAllRoles(ctx context.Context) ([]types.Role, error) { - args := m.Called(ctx) - if roles := args.Get(0); roles != nil { - return roles.([]types.Role), args.Error(1) - } - return nil, args.Error(1) -} - -func (m *MockRolePermissionRepository) GetRoleByID(ctx context.Context, roleID string) (*types.Role, error) { - args := m.Called(ctx, roleID) - if role := args.Get(0); role != nil { - return role.(*types.Role), args.Error(1) - } - return nil, args.Error(1) -} - -func (m *MockRolePermissionRepository) UpdateRole(ctx context.Context, roleID string, name *string, description *string) (bool, error) { - args := m.Called(ctx, roleID, name, description) - return args.Bool(0), args.Error(1) -} - -func (m *MockRolePermissionRepository) DeleteRole(ctx context.Context, roleID string) (bool, error) { - args := m.Called(ctx, roleID) - return args.Bool(0), args.Error(1) -} - -func (m *MockRolePermissionRepository) CountUserAssignmentsByRoleID(ctx context.Context, roleID string) (int, error) { - args := m.Called(ctx, roleID) - return args.Int(0), args.Error(1) -} - -func (m *MockRolePermissionRepository) GetRolePermissions(ctx context.Context, roleID string) ([]types.UserPermissionInfo, error) { - args := m.Called(ctx, roleID) - if perms := args.Get(0); perms != nil { - return perms.([]types.UserPermissionInfo), args.Error(1) - } - return nil, args.Error(1) -} - -func (m *MockRolePermissionRepository) CreatePermission(ctx context.Context, permission *types.Permission) error { - args := m.Called(ctx, permission) - return args.Error(0) -} - -func (m *MockRolePermissionRepository) GetAllPermissions(ctx context.Context) ([]types.Permission, error) { - args := m.Called(ctx) - if perms := args.Get(0); perms != nil { - return perms.([]types.Permission), args.Error(1) - } - return nil, args.Error(1) -} - -func (m *MockRolePermissionRepository) GetPermissionByID(ctx context.Context, permissionID string) (*types.Permission, error) { - args := m.Called(ctx, permissionID) - if perm := args.Get(0); perm != nil { - return perm.(*types.Permission), args.Error(1) - } - return nil, args.Error(1) -} - -func (m *MockRolePermissionRepository) UpdatePermission(ctx context.Context, permissionID string, description *string) (bool, error) { - args := m.Called(ctx, permissionID, description) - return args.Bool(0), args.Error(1) -} - -func (m *MockRolePermissionRepository) DeletePermission(ctx context.Context, permissionID string) (bool, error) { - args := m.Called(ctx, permissionID) - return args.Bool(0), args.Error(1) -} - -func (m *MockRolePermissionRepository) CountRoleAssignmentsByPermissionID(ctx context.Context, permissionID string) (int, error) { - args := m.Called(ctx, permissionID) - return args.Int(0), args.Error(1) -} - -func (m *MockRolePermissionRepository) AddRolePermission(ctx context.Context, roleID, permissionID string, grantedByUserID *string) error { - args := m.Called(ctx, roleID, permissionID, grantedByUserID) - return args.Error(0) -} - -func (m *MockRolePermissionRepository) RemoveRolePermission(ctx context.Context, roleID, permissionID string) error { - args := m.Called(ctx, roleID, permissionID) - return args.Error(0) -} - -func (m *MockRolePermissionRepository) ReplaceRolePermissions(ctx context.Context, roleID string, permissionIDs []string, grantedByUserID *string) error { - args := m.Called(ctx, roleID, permissionIDs, grantedByUserID) - return args.Error(0) -} - -func (m *MockRolePermissionRepository) AssignUserRole(ctx context.Context, userID, roleID string, assignedByUserID *string, expiresAt *time.Time) error { - args := m.Called(ctx, userID, roleID, assignedByUserID, expiresAt) - return args.Error(0) -} - -func (m *MockRolePermissionRepository) RemoveUserRole(ctx context.Context, userID, roleID string) error { - args := m.Called(ctx, userID, roleID) - return args.Error(0) -} - -func (m *MockRolePermissionRepository) ReplaceUserRoles(ctx context.Context, userID string, roleIDs []string, assignedByUserID *string) error { - args := m.Called(ctx, userID, roleIDs, assignedByUserID) - return args.Error(0) -} diff --git a/plugins/access-control/usecases/role_permission_usecase.go b/plugins/access-control/usecases/role_permission_usecase.go index 8820fe2..d538d11 100644 --- a/plugins/access-control/usecases/role_permission_usecase.go +++ b/plugins/access-control/usecases/role_permission_usecase.go @@ -7,74 +7,26 @@ import ( "github.com/Authula/authula/plugins/access-control/types" ) -type RolePermissionUseCase struct { - service *services.RolePermissionService +type RolePermissionsUseCase struct { + service *services.RolePermissionsService } -func NewRolePermissionUseCase(service *services.RolePermissionService) RolePermissionUseCase { - return RolePermissionUseCase{service: service} +func NewRolePermissionsUseCase(service *services.RolePermissionsService) *RolePermissionsUseCase { + return &RolePermissionsUseCase{service: service} } -func (u RolePermissionUseCase) CreateRole(ctx context.Context, req types.CreateRoleRequest) (*types.Role, error) { - return u.service.CreateRole(ctx, req) -} - -func (u RolePermissionUseCase) GetAllRoles(ctx context.Context) ([]types.Role, error) { - return u.service.GetAllRoles(ctx) -} - -func (u RolePermissionUseCase) GetRoleByID(ctx context.Context, roleID string) (*types.RoleDetails, error) { - return u.service.GetRoleByID(ctx, roleID) -} - -func (u RolePermissionUseCase) UpdateRole(ctx context.Context, roleID string, req types.UpdateRoleRequest) (*types.Role, error) { - return u.service.UpdateRole(ctx, roleID, req) -} - -func (u RolePermissionUseCase) DeleteRole(ctx context.Context, roleID string) error { - return u.service.DeleteRole(ctx, roleID) -} - -func (u RolePermissionUseCase) CreatePermission(ctx context.Context, req types.CreatePermissionRequest) (*types.Permission, error) { - return u.service.CreatePermission(ctx, req) -} - -func (u RolePermissionUseCase) GetAllPermissions(ctx context.Context) ([]types.Permission, error) { - return u.service.GetAllPermissions(ctx) -} - -func (u RolePermissionUseCase) GetRolePermissions(ctx context.Context, roleID string) ([]types.UserPermissionInfo, error) { +func (u *RolePermissionsUseCase) GetRolePermissions(ctx context.Context, roleID string) ([]types.UserPermissionInfo, error) { return u.service.GetRolePermissions(ctx, roleID) } -func (u RolePermissionUseCase) UpdatePermission(ctx context.Context, permissionID string, req types.UpdatePermissionRequest) (*types.Permission, error) { - return u.service.UpdatePermission(ctx, permissionID, req) -} - -func (u RolePermissionUseCase) DeletePermission(ctx context.Context, permissionID string) error { - return u.service.DeletePermission(ctx, permissionID) -} - -func (u RolePermissionUseCase) AddPermissionToRole(ctx context.Context, roleID string, permissionID string, grantedByUserID *string) error { +func (u *RolePermissionsUseCase) AddPermissionToRole(ctx context.Context, roleID string, permissionID string, grantedByUserID *string) error { return u.service.AddPermissionToRole(ctx, roleID, permissionID, grantedByUserID) } -func (u RolePermissionUseCase) RemovePermissionFromRole(ctx context.Context, roleID string, permissionID string) error { +func (u *RolePermissionsUseCase) RemovePermissionFromRole(ctx context.Context, roleID string, permissionID string) error { return u.service.RemovePermissionFromRole(ctx, roleID, permissionID) } -func (u RolePermissionUseCase) ReplaceRolePermissions(ctx context.Context, roleID string, permissionIDs []string, grantedByUserID *string) error { +func (u *RolePermissionsUseCase) ReplaceRolePermissions(ctx context.Context, roleID string, permissionIDs []string, grantedByUserID *string) error { return u.service.ReplaceRolePermissions(ctx, roleID, permissionIDs, grantedByUserID) } - -func (u RolePermissionUseCase) ReplaceUserRoles(ctx context.Context, userID string, roleIDs []string, assignedByUserID *string) error { - return u.service.ReplaceUserRoles(ctx, userID, roleIDs, assignedByUserID) -} - -func (u RolePermissionUseCase) AssignRoleToUser(ctx context.Context, userID string, req types.AssignUserRoleRequest, assignedByUserID *string) error { - return u.service.AssignRoleToUser(ctx, userID, req, assignedByUserID) -} - -func (u RolePermissionUseCase) RemoveRoleFromUser(ctx context.Context, userID string, roleID string) error { - return u.service.RemoveRoleFromUser(ctx, userID, roleID) -} diff --git a/plugins/access-control/usecases/role_permission_usecase_test.go b/plugins/access-control/usecases/role_permission_usecase_test.go deleted file mode 100644 index 193c47d..0000000 --- a/plugins/access-control/usecases/role_permission_usecase_test.go +++ /dev/null @@ -1,1320 +0,0 @@ -package usecases - -import ( - "context" - "errors" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - - "github.com/Authula/authula/plugins/access-control/services" - "github.com/Authula/authula/plugins/access-control/types" -) - -func newRolePermissionUseCase(mockRepo *MockRolePermissionService) RolePermissionUseCase { - return NewRolePermissionUseCase(services.NewRolePermissionService(mockRepo)) -} - -func TestCreateRole(t *testing.T) { - testCases := []struct { - name string - req types.CreateRoleRequest - mockErr error - expectedError string - }{ - { - name: "success with all fields", - req: types.CreateRoleRequest{ - Name: "Admin", - Description: new("Administrator role"), - IsSystem: true, - }, - expectedError: "", - }, - { - name: "success with minimal fields", - req: types.CreateRoleRequest{ - Name: "User", - }, - expectedError: "", - }, - { - name: "empty name", - req: types.CreateRoleRequest{Name: ""}, - expectedError: "bad request", - }, - { - name: "whitespace only name", - req: types.CreateRoleRequest{Name: " "}, - expectedError: "bad request", - }, - { - name: "service error", - req: types.CreateRoleRequest{Name: "Test"}, - mockErr: errors.New("database error"), - expectedError: "database error", - }, - } - - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - mockService := &MockRolePermissionService{} - mockService.On("CreateRole", mock.Anything, mock.Anything).Return(tt.mockErr).Maybe() - - uc := newRolePermissionUseCase(mockService) - role, err := uc.CreateRole(context.Background(), tt.req) - - if tt.expectedError == "" { - assert.NoError(t, err) - assert.NotNil(t, role) - assert.Equal(t, tt.req.Name, role.Name) - } else { - assert.Error(t, err) - assert.Equal(t, tt.expectedError, err.Error()) - assert.Nil(t, role) - } - }) - } -} - -func TestGetAllRoles(t *testing.T) { - testCases := []struct { - name string - mockResult []types.Role - mockErr error - expectedError string - expectedLength int - }{ - { - name: "success with roles", - mockResult: []types.Role{ - {ID: "1", Name: "Admin"}, - {ID: "2", Name: "User"}, - }, - expectedLength: 2, - }, - { - name: "empty list", - mockResult: []types.Role{}, - expectedLength: 0, - }, - { - name: "service error", - mockErr: errors.New("database error"), - expectedError: "database error", - }, - } - - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - mockService := &MockRolePermissionService{} - mockService.On("GetAllRoles", mock.Anything).Return(tt.mockResult, tt.mockErr).Maybe() - - uc := newRolePermissionUseCase(mockService) - roles, err := uc.GetAllRoles(context.Background()) - - if tt.expectedError == "" { - assert.NoError(t, err) - assert.Equal(t, tt.expectedLength, len(roles)) - } else { - assert.Error(t, err) - assert.Equal(t, tt.expectedError, err.Error()) - } - - mockService.AssertExpectations(t) - }) - } -} - -func TestGetRoleByID(t *testing.T) { - testCases := []struct { - name string - roleID string - mockRoleResult *types.Role - mockRoleErr error - mockPermErr error - mockPermissions []types.UserPermissionInfo - expectedError string - expectedHasRole bool - }{ - { - name: "success with permissions", - roleID: "role-1", - mockRoleResult: &types.Role{ - ID: "role-1", - Name: "Admin", - }, - mockPermissions: []types.UserPermissionInfo{ - {PermissionID: "perm-1", PermissionKey: "admin:read"}, - }, - expectedHasRole: true, - }, - { - name: "empty roleID", - roleID: "", - expectedError: "bad request", - }, - { - name: "whitespace roleID", - roleID: " ", - expectedError: "bad request", - }, - { - name: "not found", - roleID: "nonexistent", - mockRoleResult: nil, - expectedError: "not found", - }, - { - name: "service error on role fetch", - roleID: "role-1", - mockRoleErr: errors.New("db error"), - expectedError: "db error", - }, - { - name: "service error on permissions fetch", - roleID: "role-1", - mockRoleResult: &types.Role{ - ID: "role-1", - Name: "Admin", - }, - mockPermErr: errors.New("perm db error"), - expectedError: "perm db error", - }, - } - - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - mockService := &MockRolePermissionService{} - mockService.On("GetRoleByID", mock.Anything, mock.Anything).Return(tt.mockRoleResult, tt.mockRoleErr).Maybe() - mockService.On("GetRolePermissions", mock.Anything, mock.Anything).Return(tt.mockPermissions, tt.mockPermErr).Maybe() - - uc := newRolePermissionUseCase(mockService) - result, err := uc.GetRoleByID(context.Background(), tt.roleID) - - if tt.expectedError == "" { - assert.NoError(t, err) - assert.NotNil(t, result) - assert.Equal(t, len(tt.mockPermissions), len(result.Permissions)) - } else { - assert.Error(t, err) - assert.Equal(t, tt.expectedError, err.Error()) - assert.Nil(t, result) - } - }) - } -} - -func TestUpdateRole(t *testing.T) { - testCases := []struct { - name string - roleID string - req types.UpdateRoleRequest - mockRole *types.Role - mockGetErr error - mockUpdateErr error - expectedError string - shouldCallUpdate bool - }{ - { - name: "update name only", - roleID: "role-1", - req: types.UpdateRoleRequest{ - Name: new("NewAdmin"), - }, - mockRole: &types.Role{ID: "role-1", Name: "Admin", IsSystem: false}, - shouldCallUpdate: true, - }, - { - name: "update description only", - roleID: "role-1", - req: types.UpdateRoleRequest{ - Description: new("New description"), - }, - mockRole: &types.Role{ID: "role-1", Name: "Admin", IsSystem: false}, - shouldCallUpdate: true, - }, - { - name: "update both", - roleID: "role-1", - req: types.UpdateRoleRequest{ - Name: new("NewAdmin"), - Description: new("New description"), - }, - mockRole: &types.Role{ID: "role-1", Name: "Admin", IsSystem: false}, - shouldCallUpdate: true, - }, - { - name: "empty roleID", - roleID: "", - expectedError: "bad request", - }, - { - name: "whitespace roleID", - roleID: " ", - req: types.UpdateRoleRequest{ - Name: new("Admin"), - }, - expectedError: "bad request", - }, - { - name: "no fields provided", - roleID: "role-1", - req: types.UpdateRoleRequest{}, - expectedError: "unprocessable entity", - }, - { - name: "empty name", - roleID: "role-1", - req: types.UpdateRoleRequest{ - Name: new(" "), - }, - mockRole: &types.Role{ID: "role-1", Name: "Admin", IsSystem: false}, - expectedError: "bad request", - }, - { - name: "not found", - roleID: "role-1", - req: types.UpdateRoleRequest{Name: new("Admin")}, - mockRole: nil, - expectedError: "not found", - }, - { - name: "system role protection", - roleID: "role-1", - req: types.UpdateRoleRequest{ - Name: new("NewAdmin"), - }, - mockRole: &types.Role{ID: "role-1", Name: "Admin", IsSystem: true}, - expectedError: "cannot update system role", - }, - { - name: "service error on get", - roleID: "role-1", - req: types.UpdateRoleRequest{Name: new("Admin")}, - mockGetErr: errors.New("db error"), - expectedError: "db error", - }, - { - name: "service error on update", - roleID: "role-1", - req: types.UpdateRoleRequest{ - Name: new("NewAdmin"), - }, - mockRole: &types.Role{ID: "role-1", Name: "Admin", IsSystem: false}, - mockUpdateErr: errors.New("update failed"), - expectedError: "update failed", - shouldCallUpdate: true, - }, - } - - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - mockService := &MockRolePermissionService{} - mockService.On("GetRoleByID", mock.Anything, mock.Anything).Return(tt.mockRole, tt.mockGetErr).Maybe() - mockService.On("UpdateRole", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(true, tt.mockUpdateErr).Maybe() - - uc := newRolePermissionUseCase(mockService) - role, err := uc.UpdateRole(context.Background(), tt.roleID, tt.req) - - if tt.expectedError == "" { - assert.NoError(t, err) - assert.NotNil(t, role) - } else { - assert.Error(t, err) - assert.Equal(t, tt.expectedError, err.Error()) - assert.Nil(t, role) - } - - mockService.AssertExpectations(t) - }) - } -} - -func TestDeleteRole(t *testing.T) { - testCases := []struct { - name string - roleID string - mockRole *types.Role - mockGetErr error - mockCountResult int - mockCountErr error - mockDeleteErr error - expectedError string - shouldCallDelete bool - }{ - { - name: "success", - roleID: "role-1", - mockRole: &types.Role{ID: "role-1", Name: "User", IsSystem: false}, - mockCountResult: 0, - shouldCallDelete: true, - }, - { - name: "empty roleID", - roleID: "", - expectedError: "bad request", - }, - { - name: "whitespace roleID", - roleID: " ", - expectedError: "bad request", - }, - { - name: "not found", - roleID: "role-1", - mockRole: nil, - expectedError: "not found", - }, - { - name: "system role protection", - roleID: "role-1", - mockRole: &types.Role{ID: "role-1", Name: "Admin", IsSystem: true}, - expectedError: "cannot update system role", - }, - { - name: "service error on get", - roleID: "role-1", - mockGetErr: errors.New("db error"), - expectedError: "db error", - }, - { - name: "has user assignments", - roleID: "role-1", - mockRole: &types.Role{ID: "role-1", Name: "User", IsSystem: false}, - mockCountResult: 5, - expectedError: "conflict", - }, - { - name: "service error on count", - roleID: "role-1", - mockRole: &types.Role{ID: "role-1", Name: "User", IsSystem: false}, - mockCountErr: errors.New("count error"), - expectedError: "count error", - }, - { - name: "service error on delete", - roleID: "role-1", - mockRole: &types.Role{ID: "role-1", Name: "User", IsSystem: false}, - mockCountResult: 0, - mockDeleteErr: errors.New("delete failed"), - expectedError: "delete failed", - shouldCallDelete: true, - }, - } - - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - mockService := &MockRolePermissionService{} - mockService.On("GetRoleByID", mock.Anything, mock.Anything).Return(tt.mockRole, tt.mockGetErr).Maybe() - mockService.On("CountUserAssignmentsByRoleID", mock.Anything, mock.Anything).Return(tt.mockCountResult, tt.mockCountErr).Maybe() - mockService.On("DeleteRole", mock.Anything, mock.Anything).Return(true, tt.mockDeleteErr).Maybe() - - uc := newRolePermissionUseCase(mockService) - err := uc.DeleteRole(context.Background(), tt.roleID) - - if tt.expectedError == "" { - assert.NoError(t, err) - } else { - assert.Error(t, err) - assert.Equal(t, tt.expectedError, err.Error()) - } - - mockService.AssertExpectations(t) - }) - } -} - -func TestCreatePermission(t *testing.T) { - testCases := []struct { - name string - req types.CreatePermissionRequest - mockErr error - expectedError string - }{ - { - name: "success with all fields", - req: types.CreatePermissionRequest{ - Key: "admin:write", - Description: new("Admin write access"), - IsSystem: true, - }, - expectedError: "", - }, - { - name: "success with minimal fields", - req: types.CreatePermissionRequest{ - Key: "user:read", - }, - expectedError: "", - }, - { - name: "empty key", - req: types.CreatePermissionRequest{Key: ""}, - expectedError: "bad request", - }, - { - name: "whitespace only key", - req: types.CreatePermissionRequest{Key: " "}, - expectedError: "bad request", - }, - { - name: "service error", - req: types.CreatePermissionRequest{Key: "admin:read"}, - mockErr: errors.New("database error"), - expectedError: "database error", - }, - } - - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - mockService := &MockRolePermissionService{} - mockService.On("CreatePermission", mock.Anything, mock.Anything).Return(tt.mockErr).Maybe() - - uc := newRolePermissionUseCase(mockService) - perm, err := uc.CreatePermission(context.Background(), tt.req) - - if tt.expectedError == "" { - assert.NoError(t, err) - assert.NotNil(t, perm) - } else { - assert.Error(t, err) - assert.Equal(t, tt.expectedError, err.Error()) - assert.Nil(t, perm) - } - - mockService.AssertExpectations(t) - }) - } -} - -func TestGetAllPermissions(t *testing.T) { - testCases := []struct { - name string - mockResult []types.Permission - mockErr error - expectedError string - expectedLength int - }{ - { - name: "success with permissions", - mockResult: []types.Permission{ - {ID: "1", Key: "admin:read"}, - {ID: "2", Key: "user:read"}, - }, - expectedLength: 2, - }, - { - name: "empty list", - mockResult: []types.Permission{}, - expectedLength: 0, - }, - { - name: "service error", - mockErr: errors.New("database error"), - expectedError: "database error", - }, - } - - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - mockService := &MockRolePermissionService{} - mockService.On("GetAllPermissions", mock.Anything).Return(tt.mockResult, tt.mockErr).Maybe() - - uc := newRolePermissionUseCase(mockService) - perms, err := uc.GetAllPermissions(context.Background()) - - if tt.expectedError == "" { - assert.NoError(t, err) - assert.Equal(t, tt.expectedLength, len(perms)) - } else { - assert.Error(t, err) - assert.Equal(t, tt.expectedError, err.Error()) - } - - mockService.AssertExpectations(t) - }) - } -} - -func TestGetRolePermissions(t *testing.T) { - testCases := []struct { - name string - roleID string - mockRole *types.Role - mockRoleErr error - mockPermissions []types.UserPermissionInfo - mockPermErr error - expectedError string - expectedLength int - }{ - { - name: "success with permissions", - roleID: "role-1", - mockRole: &types.Role{ - ID: "role-1", - Name: "Admin", - }, - mockPermissions: []types.UserPermissionInfo{ - {PermissionID: "perm-1", PermissionKey: "admin:read"}, - {PermissionID: "perm-2", PermissionKey: "admin:write"}, - }, - expectedLength: 2, - }, - { - name: "empty roleID", - roleID: "", - expectedError: "unprocessable entity", - }, - { - name: "whitespace roleID", - roleID: " ", - expectedError: "unprocessable entity", - }, - { - name: "role not found", - roleID: "missing-role", - mockRole: nil, - expectedError: "not found", - }, - { - name: "role lookup error", - roleID: "role-1", - mockRoleErr: errors.New("db error"), - expectedError: "db error", - }, - { - name: "permissions lookup error", - roleID: "role-1", - mockRole: &types.Role{ - ID: "role-1", - Name: "Admin", - }, - mockPermErr: errors.New("perm db error"), - expectedError: "perm db error", - }, - } - - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - mockService := &MockRolePermissionService{} - mockService.On("GetRoleByID", mock.Anything, mock.Anything).Return(tt.mockRole, tt.mockRoleErr).Maybe() - mockService.On("GetRolePermissions", mock.Anything, mock.Anything).Return(tt.mockPermissions, tt.mockPermErr).Maybe() - - uc := newRolePermissionUseCase(mockService) - permissions, err := uc.GetRolePermissions(context.Background(), tt.roleID) - - if tt.expectedError == "" { - assert.NoError(t, err) - assert.Equal(t, tt.expectedLength, len(permissions)) - } else { - assert.Error(t, err) - assert.Equal(t, tt.expectedError, err.Error()) - assert.Nil(t, permissions) - } - }) - } -} - -func TestUpdatePermission(t *testing.T) { - testCases := []struct { - name string - permissionID string - req types.UpdatePermissionRequest - mockPerm *types.Permission - mockGetErr error - mockUpdateErr error - expectedError string - shouldCallUpdate bool - }{ - { - name: "success", - permissionID: "perm-1", - req: types.UpdatePermissionRequest{ - Description: new("New description"), - }, - mockPerm: &types.Permission{ID: "perm-1", Key: "admin:read", IsSystem: false}, - shouldCallUpdate: true, - }, - { - name: "empty permissionID", - permissionID: "", - expectedError: "unprocessable entity", - }, - { - name: "whitespace permissionID", - permissionID: " ", - expectedError: "unprocessable entity", - }, - { - name: "nil description", - permissionID: "perm-1", - req: types.UpdatePermissionRequest{ - Description: nil, - }, - expectedError: "unprocessable entity", - }, - { - name: "empty description", - permissionID: "perm-1", - req: types.UpdatePermissionRequest{ - Description: new(" "), - }, - expectedError: "unprocessable entity", - }, - { - name: "not found", - permissionID: "perm-1", - req: types.UpdatePermissionRequest{Description: new("New")}, - mockPerm: nil, - expectedError: "not found", - }, - { - name: "system permission protection", - permissionID: "perm-1", - req: types.UpdatePermissionRequest{ - Description: new("New"), - }, - mockPerm: &types.Permission{ID: "perm-1", Key: "admin:read", IsSystem: true}, - expectedError: "bad request", - }, - { - name: "service error on get", - permissionID: "perm-1", - req: types.UpdatePermissionRequest{Description: new("New")}, - mockGetErr: errors.New("db error"), - expectedError: "db error", - }, - { - name: "service error on update", - permissionID: "perm-1", - req: types.UpdatePermissionRequest{ - Description: new("New"), - }, - mockPerm: &types.Permission{ID: "perm-1", Key: "admin:read", IsSystem: false}, - mockUpdateErr: errors.New("update failed"), - expectedError: "update failed", - shouldCallUpdate: true, - }, - } - - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - mockService := &MockRolePermissionService{} - mockService.On("GetPermissionByID", mock.Anything, mock.Anything).Return(tt.mockPerm, tt.mockGetErr).Maybe() - mockService.On("UpdatePermission", mock.Anything, mock.Anything, mock.Anything).Return(true, tt.mockUpdateErr).Maybe() - - uc := newRolePermissionUseCase(mockService) - perm, err := uc.UpdatePermission(context.Background(), tt.permissionID, tt.req) - - if tt.expectedError == "" { - assert.NoError(t, err) - assert.NotNil(t, perm) - } else { - assert.Error(t, err) - assert.Equal(t, tt.expectedError, err.Error()) - assert.Nil(t, perm) - } - - mockService.AssertExpectations(t) - }) - } -} - -func TestDeletePermission(t *testing.T) { - testCases := []struct { - name string - permissionID string - mockPerm *types.Permission - mockGetErr error - mockCountResult int - mockCountErr error - mockDeleteErr error - expectedError string - shouldCallDelete bool - }{ - { - name: "success", - permissionID: "perm-1", - mockPerm: &types.Permission{ID: "perm-1", Key: "user:read", IsSystem: false}, - mockCountResult: 0, - shouldCallDelete: true, - }, - { - name: "empty permissionID", - permissionID: "", - expectedError: "bad request", - }, - { - name: "whitespace permissionID", - permissionID: " ", - expectedError: "bad request", - }, - { - name: "not found", - permissionID: "perm-1", - mockPerm: nil, - expectedError: "not found", - }, - { - name: "system permission protection", - permissionID: "perm-1", - mockPerm: &types.Permission{ID: "perm-1", Key: "admin:read", IsSystem: true}, - expectedError: "bad request", - }, - { - name: "service error on get", - permissionID: "perm-1", - mockGetErr: errors.New("db error"), - expectedError: "db error", - }, - { - name: "has role assignments", - permissionID: "perm-1", - mockPerm: &types.Permission{ID: "perm-1", Key: "user:read", IsSystem: false}, - mockCountResult: 3, - expectedError: "conflict", - }, - { - name: "service error on count", - permissionID: "perm-1", - mockPerm: &types.Permission{ID: "perm-1", Key: "user:read", IsSystem: false}, - mockCountErr: errors.New("count error"), - expectedError: "count error", - }, - { - name: "service error on delete", - permissionID: "perm-1", - mockPerm: &types.Permission{ID: "perm-1", Key: "user:read", IsSystem: false}, - mockCountResult: 0, - mockDeleteErr: errors.New("delete failed"), - expectedError: "delete failed", - shouldCallDelete: true, - }, - } - - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - mockService := &MockRolePermissionService{} - mockService.On("GetPermissionByID", mock.Anything, mock.Anything).Return(tt.mockPerm, tt.mockGetErr).Maybe() - mockService.On("CountRoleAssignmentsByPermissionID", mock.Anything, mock.Anything).Return(tt.mockCountResult, tt.mockCountErr).Maybe() - mockService.On("DeletePermission", mock.Anything, mock.Anything).Return(true, tt.mockDeleteErr).Maybe() - - uc := newRolePermissionUseCase(mockService) - err := uc.DeletePermission(context.Background(), tt.permissionID) - - if tt.expectedError == "" { - assert.NoError(t, err) - } else { - assert.Error(t, err) - assert.Equal(t, tt.expectedError, err.Error()) - } - - mockService.AssertExpectations(t) - }) - } -} - -func TestAddPermissionToRole(t *testing.T) { - testCases := []struct { - name string - roleID string - permissionID string - mockRole *types.Role - mockRoleErr error - mockPerm *types.Permission - mockPermErr error - mockAddErr error - expectedError string - shouldCallAddRole bool - }{ - { - name: "success", - roleID: "role-1", - permissionID: "perm-1", - mockRole: &types.Role{ID: "role-1", Name: "User", IsSystem: false}, - mockPerm: &types.Permission{ID: "perm-1", Key: "user:read", IsSystem: false}, - shouldCallAddRole: true, - }, - { - name: "empty roleID", - roleID: "", - permissionID: "perm-1", - expectedError: "bad request", - }, - { - name: "empty permissionID", - roleID: "role-1", - permissionID: "", - expectedError: "bad request", - }, - { - name: "whitespace roleID", - roleID: " ", - permissionID: "perm-1", - expectedError: "bad request", - }, - { - name: "whitespace permissionID", - roleID: "role-1", - permissionID: " ", - expectedError: "bad request", - }, - { - name: "not found", - roleID: "role-1", - permissionID: "perm-1", - mockRole: nil, - expectedError: "not found", - }, - { - name: "not found", - roleID: "role-1", - permissionID: "perm-1", - mockRole: &types.Role{ID: "role-1", Name: "User", IsSystem: false}, - mockPerm: nil, - expectedError: "not found", - }, - { - name: "system role protection", - roleID: "role-1", - permissionID: "perm-1", - mockRole: &types.Role{ID: "role-1", Name: "Admin", IsSystem: true}, - mockPerm: &types.Permission{ID: "perm-1", Key: "admin:read", IsSystem: false}, - expectedError: "bad request", - }, - { - name: "system permission protection", - roleID: "role-1", - permissionID: "perm-1", - mockRole: &types.Role{ID: "role-1", Name: "User", IsSystem: false}, - mockPerm: &types.Permission{ID: "perm-1", Key: "admin:read", IsSystem: true}, - expectedError: "bad request", - }, - { - name: "service error on add", - roleID: "role-1", - permissionID: "perm-1", - mockRole: &types.Role{ID: "role-1", Name: "User", IsSystem: false}, - mockPerm: &types.Permission{ID: "perm-1", Key: "user:read", IsSystem: false}, - mockAddErr: errors.New("add failed"), - expectedError: "add failed", - shouldCallAddRole: true, - }, - } - - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - mockService := &MockRolePermissionService{} - mockService.On("GetRoleByID", mock.Anything, mock.Anything).Return(tt.mockRole, tt.mockRoleErr).Maybe() - mockService.On("GetPermissionByID", mock.Anything, mock.Anything).Return(tt.mockPerm, tt.mockPermErr).Maybe() - mockService.On("AddRolePermission", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.mockAddErr).Maybe() - - uc := newRolePermissionUseCase(mockService) - err := uc.AddPermissionToRole(context.Background(), tt.roleID, tt.permissionID, nil) - - if tt.expectedError == "" { - assert.NoError(t, err) - } else { - assert.Error(t, err) - assert.Equal(t, tt.expectedError, err.Error()) - } - - mockService.AssertExpectations(t) - }) - } -} - -func TestRemovePermissionFromRole(t *testing.T) { - testCases := []struct { - name string - roleID string - permissionID string - mockRole *types.Role - mockRoleErr error - mockPerm *types.Permission - mockPermErr error - mockRemoveErr error - expectedError string - shouldCallRemoveRole bool - }{ - { - name: "success", - roleID: "role-1", - permissionID: "perm-1", - mockRole: &types.Role{ID: "role-1", Name: "User", IsSystem: false}, - mockPerm: &types.Permission{ID: "perm-1", Key: "user:read", IsSystem: false}, - shouldCallRemoveRole: true, - }, - { - name: "empty roleID", - roleID: "", - permissionID: "perm-1", - expectedError: "unprocessable entity", - }, - { - name: "empty permissionID", - roleID: "role-1", - permissionID: "", - expectedError: "unprocessable entity", - }, - { - name: "system role protection", - roleID: "role-1", - permissionID: "perm-1", - mockRole: &types.Role{ID: "role-1", Name: "Admin", IsSystem: true}, - mockPerm: &types.Permission{ID: "perm-1", Key: "admin:read", IsSystem: false}, - expectedError: "bad request", - }, - { - name: "system permission protection", - roleID: "role-1", - permissionID: "perm-1", - mockRole: &types.Role{ID: "role-1", Name: "User", IsSystem: false}, - mockPerm: &types.Permission{ID: "perm-1", Key: "admin:read", IsSystem: true}, - expectedError: "bad request", - }, - { - name: "service error on remove", - roleID: "role-1", - permissionID: "perm-1", - mockRole: &types.Role{ID: "role-1", Name: "User", IsSystem: false}, - mockPerm: &types.Permission{ID: "perm-1", Key: "user:read", IsSystem: false}, - mockRemoveErr: errors.New("remove failed"), - expectedError: "remove failed", - shouldCallRemoveRole: true, - }, - } - - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - mockService := &MockRolePermissionService{} - mockService.On("GetRoleByID", mock.Anything, mock.Anything).Return(tt.mockRole, tt.mockRoleErr).Maybe() - mockService.On("GetPermissionByID", mock.Anything, mock.Anything).Return(tt.mockPerm, tt.mockPermErr).Maybe() - mockService.On("RemoveRolePermission", mock.Anything, mock.Anything, mock.Anything).Return(tt.mockRemoveErr).Maybe() - - uc := newRolePermissionUseCase(mockService) - err := uc.RemovePermissionFromRole(context.Background(), tt.roleID, tt.permissionID) - - if tt.expectedError == "" { - assert.NoError(t, err) - } else { - assert.Error(t, err) - assert.Equal(t, tt.expectedError, err.Error()) - } - - mockService.AssertExpectations(t) - }) - } -} - -func TestReplaceRolePermissions(t *testing.T) { - testCases := []struct { - name string - roleID string - permissionIDs []string - mockReplaceErr error - expectedError string - shouldCallReplace bool - }{ - { - name: "success with permissions", - roleID: "role-1", - permissionIDs: []string{"perm-1", "perm-2"}, - shouldCallReplace: true, - }, - { - name: "success with empty list", - roleID: "role-1", - permissionIDs: []string{}, - shouldCallReplace: true, - }, - { - name: "empty roleID", - roleID: "", - permissionIDs: []string{"perm-1"}, - expectedError: "bad request", - }, - { - name: "whitespace roleID", - roleID: " ", - permissionIDs: []string{"perm-1"}, - expectedError: "bad request", - }, - { - name: "deduplication", - roleID: "role-1", - permissionIDs: []string{"perm-1", "perm-1", "perm-2"}, - shouldCallReplace: true, - }, - { - name: "filters empty strings", - roleID: "role-1", - permissionIDs: []string{"perm-1", "", " ", "perm-2"}, - shouldCallReplace: true, - }, - { - name: "service error", - roleID: "role-1", - permissionIDs: []string{"perm-1"}, - mockReplaceErr: errors.New("db error"), - expectedError: "db error", - shouldCallReplace: true, - }, - } - - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - mockService := &MockRolePermissionService{} - mockService.On("ReplaceRolePermissions", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.mockReplaceErr).Maybe() - - uc := newRolePermissionUseCase(mockService) - err := uc.ReplaceRolePermissions(context.Background(), tt.roleID, tt.permissionIDs, nil) - - if tt.expectedError == "" { - assert.NoError(t, err) - } else { - assert.Error(t, err) - assert.Equal(t, tt.expectedError, err.Error()) - } - - mockService.AssertExpectations(t) - }) - } -} - -func TestReplaceUserRoles(t *testing.T) { - testCases := []struct { - name string - userID string - roleIDs []string - mockReplaceErr error - expectedError string - shouldCallReplace bool - }{ - { - name: "success with roles", - userID: "user-1", - roleIDs: []string{"role-1", "role-2"}, - shouldCallReplace: true, - }, - { - name: "success with empty list", - userID: "user-1", - roleIDs: []string{}, - shouldCallReplace: true, - }, - { - name: "empty userID", - userID: "", - roleIDs: []string{"role-1"}, - expectedError: "bad request", - }, - { - name: "whitespace userID", - userID: " ", - roleIDs: []string{"role-1"}, - expectedError: "bad request", - }, - { - name: "deduplication", - userID: "user-1", - roleIDs: []string{"role-1", "role-1", "role-2"}, - shouldCallReplace: true, - }, - { - name: "filters empty strings", - userID: "user-1", - roleIDs: []string{"role-1", "", " ", "role-2"}, - shouldCallReplace: true, - }, - { - name: "service error", - userID: "user-1", - roleIDs: []string{"role-1"}, - mockReplaceErr: errors.New("db error"), - expectedError: "db error", - shouldCallReplace: true, - }, - } - - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - mockService := &MockRolePermissionService{} - mockService.On("ReplaceUserRoles", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.mockReplaceErr).Maybe() - - uc := newRolePermissionUseCase(mockService) - err := uc.ReplaceUserRoles(context.Background(), tt.userID, tt.roleIDs, nil) - - if tt.expectedError == "" { - assert.NoError(t, err) - } else { - assert.Error(t, err) - assert.Equal(t, tt.expectedError, err.Error()) - } - - mockService.AssertExpectations(t) - }) - } -} - -func TestAssignRoleToUser(t *testing.T) { - futureTime := time.Now().UTC().Add(24 * time.Hour) - pastTime := time.Now().UTC().Add(-24 * time.Hour) - - testCases := []struct { - name string - userID string - req types.AssignUserRoleRequest - mockAssignErr error - expectedError string - shouldCallAssign bool - }{ - { - name: "success without expiry", - userID: "user-1", - req: types.AssignUserRoleRequest{ - RoleID: "role-1", - }, - shouldCallAssign: true, - }, - { - name: "success with future expiry", - userID: "user-1", - req: types.AssignUserRoleRequest{ - RoleID: "role-1", - ExpiresAt: &futureTime, - }, - shouldCallAssign: true, - }, - { - name: "empty userID", - userID: "", - req: types.AssignUserRoleRequest{RoleID: "role-1"}, - expectedError: "unprocessable entity", - }, - { - name: "whitespace userID", - userID: " ", - req: types.AssignUserRoleRequest{RoleID: "role-1"}, - expectedError: "unprocessable entity", - }, - { - name: "empty roleID", - userID: "user-1", - req: types.AssignUserRoleRequest{ - RoleID: "", - }, - expectedError: "unprocessable entity", - }, - { - name: "past expiry date", - userID: "user-1", - req: types.AssignUserRoleRequest{ - RoleID: "role-1", - ExpiresAt: &pastTime, - }, - expectedError: "bad request", - }, - { - name: "service error on assign", - userID: "user-1", - req: types.AssignUserRoleRequest{ - RoleID: "role-1", - }, - mockAssignErr: errors.New("assign failed"), - expectedError: "assign failed", - shouldCallAssign: true, - }, - } - - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - mockService := &MockRolePermissionService{} - mockService.On("AssignUserRole", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.mockAssignErr).Maybe() - - uc := newRolePermissionUseCase(mockService) - err := uc.AssignRoleToUser(context.Background(), tt.userID, tt.req, nil) - - if tt.expectedError == "" { - assert.NoError(t, err) - } else { - assert.Error(t, err) - assert.Equal(t, tt.expectedError, err.Error()) - } - - mockService.AssertExpectations(t) - }) - } -} - -func TestRemoveRoleFromUser(t *testing.T) { - testCases := []struct { - name string - userID string - roleID string - mockRemoveErr error - expectedError string - shouldCallRemove bool - }{ - { - name: "success", - userID: "user-1", - roleID: "role-1", - shouldCallRemove: true, - }, - { - name: "empty userID", - userID: "", - roleID: "role-1", - expectedError: "bad request", - }, - { - name: "empty roleID", - userID: "user-1", - roleID: "", - expectedError: "bad request", - }, - { - name: "whitespace userID", - userID: " ", - roleID: "role-1", - expectedError: "bad request", - }, - { - name: "whitespace roleID", - userID: "user-1", - roleID: " ", - expectedError: "bad request", - }, - { - name: "service error", - userID: "user-1", - roleID: "role-1", - mockRemoveErr: errors.New("remove failed"), - expectedError: "remove failed", - shouldCallRemove: true, - }, - } - - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - mockService := &MockRolePermissionService{} - mockService.On("RemoveUserRole", mock.Anything, mock.Anything, mock.Anything).Return(tt.mockRemoveErr).Maybe() - - uc := newRolePermissionUseCase(mockService) - err := uc.RemoveRoleFromUser(context.Background(), tt.userID, tt.roleID) - - if tt.expectedError == "" { - assert.NoError(t, err) - } else { - assert.Error(t, err) - assert.Equal(t, tt.expectedError, err.Error()) - } - - mockService.AssertExpectations(t) - }) - } -} diff --git a/plugins/access-control/usecases/roles_usecase.go b/plugins/access-control/usecases/roles_usecase.go new file mode 100644 index 0000000..05633e0 --- /dev/null +++ b/plugins/access-control/usecases/roles_usecase.go @@ -0,0 +1,40 @@ +package usecases + +import ( + "context" + + "github.com/Authula/authula/plugins/access-control/services" + "github.com/Authula/authula/plugins/access-control/types" +) + +type RolesUseCase struct { + service *services.RolesService +} + +func NewRolesUseCase(service *services.RolesService) *RolesUseCase { + return &RolesUseCase{service: service} +} + +func (u *RolesUseCase) CreateRole(ctx context.Context, req types.CreateRoleRequest) (*types.Role, error) { + return u.service.CreateRole(ctx, req) +} + +func (u *RolesUseCase) GetAllRoles(ctx context.Context) ([]types.Role, error) { + return u.service.GetAllRoles(ctx) +} + +func (u *RolesUseCase) GetRoleByName(ctx context.Context, roleName string) (*types.Role, error) { + return u.service.GetRoleByName(ctx, roleName) +} + +func (u *RolesUseCase) GetRoleByID(ctx context.Context, roleID string) (*types.RoleDetails, error) { + return u.service.GetRoleByID(ctx, roleID) +} + +func (u *RolesUseCase) UpdateRole(ctx context.Context, roleID string, req types.UpdateRoleRequest) (*types.Role, error) { + return u.service.UpdateRole(ctx, roleID, req) +} + +func (u *RolesUseCase) DeleteRole(ctx context.Context, roleID string) error { + return u.service.DeleteRole(ctx, roleID) +} diff --git a/plugins/access-control/usecases/usecases.go b/plugins/access-control/usecases/usecases.go index 44fef05..42ab788 100644 --- a/plugins/access-control/usecases/usecases.go +++ b/plugins/access-control/usecases/usecases.go @@ -7,109 +7,139 @@ import ( ) type UseCases struct { - rolePermission RolePermissionUseCase - userAccess UserRolesUseCase -} - -func NewAccessControlUseCases(rolePermission RolePermissionUseCase, userAccess UserRolesUseCase) *UseCases { + roles *RolesUseCase + permissions *PermissionsUseCase + rolePermissions *RolePermissionsUseCase + userRoles *UserRolesUseCase + userPermissions *UserPermissionsUseCase +} + +func NewAccessControlUseCases( + roles *RolesUseCase, + permissions *PermissionsUseCase, + rolePermissions *RolePermissionsUseCase, + userRoles *UserRolesUseCase, + userPermissions *UserPermissionsUseCase, +) *UseCases { return &UseCases{ - rolePermission: rolePermission, - userAccess: userAccess, + roles: roles, + permissions: permissions, + rolePermissions: rolePermissions, + userRoles: userRoles, + userPermissions: userPermissions, } } -func (u *UseCases) RolePermissionUseCase() RolePermissionUseCase { - return u.rolePermission +func (u *UseCases) RolesUseCase() *RolesUseCase { + return u.roles } -func (u *UseCases) UserAccessUseCase() UserRolesUseCase { - return u.userAccess -} - -func (u *UseCases) CreatePermission(ctx context.Context, req types.CreatePermissionRequest) (*types.Permission, error) { - return u.rolePermission.CreatePermission(ctx, req) +func (u *UseCases) PermissionsUseCase() *PermissionsUseCase { + return u.permissions } -func (u *UseCases) GetAllPermissions(ctx context.Context) ([]types.Permission, error) { - return u.rolePermission.GetAllPermissions(ctx) +func (u *UseCases) RolePermissionsUseCase() *RolePermissionsUseCase { + return u.rolePermissions } -func (u *UseCases) GetRolePermissions(ctx context.Context, roleID string) ([]types.UserPermissionInfo, error) { - return u.rolePermission.GetRolePermissions(ctx, roleID) +func (u *UseCases) UserRolesUseCase() *UserRolesUseCase { + return u.userRoles } -func (u *UseCases) UpdatePermission(ctx context.Context, permissionID string, req types.UpdatePermissionRequest) (*types.Permission, error) { - return u.rolePermission.UpdatePermission(ctx, permissionID, req) +func (u *UseCases) UserPermissionsUseCase() *UserPermissionsUseCase { + return u.userPermissions } -func (u *UseCases) DeletePermission(ctx context.Context, permissionID string) error { - return u.rolePermission.DeletePermission(ctx, permissionID) -} +// Roles func (u *UseCases) CreateRole(ctx context.Context, req types.CreateRoleRequest) (*types.Role, error) { - return u.rolePermission.CreateRole(ctx, req) + return u.roles.CreateRole(ctx, req) } func (u *UseCases) GetAllRoles(ctx context.Context) ([]types.Role, error) { - return u.rolePermission.GetAllRoles(ctx) + return u.roles.GetAllRoles(ctx) +} + +func (u *UseCases) GetRoleByName(ctx context.Context, roleName string) (*types.Role, error) { + return u.roles.GetRoleByName(ctx, roleName) } func (u *UseCases) GetRoleByID(ctx context.Context, roleID string) (*types.RoleDetails, error) { - return u.rolePermission.GetRoleByID(ctx, roleID) + return u.roles.GetRoleByID(ctx, roleID) } func (u *UseCases) UpdateRole(ctx context.Context, roleID string, req types.UpdateRoleRequest) (*types.Role, error) { - return u.rolePermission.UpdateRole(ctx, roleID, req) + return u.roles.UpdateRole(ctx, roleID, req) } func (u *UseCases) DeleteRole(ctx context.Context, roleID string) error { - return u.rolePermission.DeleteRole(ctx, roleID) + return u.roles.DeleteRole(ctx, roleID) } -func (u *UseCases) AddPermissionToRole(ctx context.Context, roleID string, permissionID string, grantedByUserID *string) error { - return u.rolePermission.AddPermissionToRole(ctx, roleID, permissionID, grantedByUserID) +// Permissions + +func (u *UseCases) CreatePermission(ctx context.Context, req types.CreatePermissionRequest) (*types.Permission, error) { + return u.permissions.CreatePermission(ctx, req) } -func (u *UseCases) RemovePermissionFromRole(ctx context.Context, roleID string, permissionID string) error { - return u.rolePermission.RemovePermissionFromRole(ctx, roleID, permissionID) +func (u *UseCases) GetAllPermissions(ctx context.Context) ([]types.Permission, error) { + return u.permissions.GetAllPermissions(ctx) } -func (u *UseCases) ReplaceRolePermissions(ctx context.Context, roleID string, permissionIDs []string, grantedByUserID *string) error { - return u.rolePermission.ReplaceRolePermissions(ctx, roleID, permissionIDs, grantedByUserID) +func (u *UseCases) GetPermissionByID(ctx context.Context, permissionID string) (*types.Permission, error) { + return u.permissions.GetPermissionByID(ctx, permissionID) } -func (u *UseCases) ReplaceUserRoles(ctx context.Context, userID string, roleIDs []string, assignedByUserID *string) error { - return u.rolePermission.ReplaceUserRoles(ctx, userID, roleIDs, assignedByUserID) +func (u *UseCases) UpdatePermission(ctx context.Context, permissionID string, req types.UpdatePermissionRequest) (*types.Permission, error) { + return u.permissions.UpdatePermission(ctx, permissionID, req) } -func (u *UseCases) AssignRoleToUser(ctx context.Context, userID string, req types.AssignUserRoleRequest, assignedByUserID *string) error { - return u.rolePermission.AssignRoleToUser(ctx, userID, req, assignedByUserID) +func (u *UseCases) DeletePermission(ctx context.Context, permissionID string) error { + return u.permissions.DeletePermission(ctx, permissionID) } -func (u *UseCases) RemoveRoleFromUser(ctx context.Context, userID string, roleID string) error { - return u.rolePermission.RemoveRoleFromUser(ctx, userID, roleID) +// Role Permissions + +func (u *UseCases) GetRolePermissions(ctx context.Context, roleID string) ([]types.UserPermissionInfo, error) { + return u.rolePermissions.GetRolePermissions(ctx, roleID) } +func (u *UseCases) AddPermissionToRole(ctx context.Context, roleID string, permissionID string, grantedByUserID *string) error { + return u.rolePermissions.AddPermissionToRole(ctx, roleID, permissionID, grantedByUserID) +} + +func (u *UseCases) RemovePermissionFromRole(ctx context.Context, roleID string, permissionID string) error { + return u.rolePermissions.RemovePermissionFromRole(ctx, roleID, permissionID) +} + +func (u *UseCases) ReplaceRolePermissions(ctx context.Context, roleID string, permissionIDs []string, grantedByUserID *string) error { + return u.rolePermissions.ReplaceRolePermissions(ctx, roleID, permissionIDs, grantedByUserID) +} + +// User Roles + func (u *UseCases) GetUserRoles(ctx context.Context, userID string) ([]types.UserRoleInfo, error) { - return u.userAccess.GetUserRoles(ctx, userID) + return u.userRoles.GetUserRoles(ctx, userID) } -func (u *UseCases) GetUserEffectivePermissions(ctx context.Context, userID string) ([]types.UserPermissionInfo, error) { - return u.userAccess.GetUserEffectivePermissions(ctx, userID) +func (u *UseCases) ReplaceUserRoles(ctx context.Context, userID string, roleIDs []string, assignedByUserID *string) error { + return u.userRoles.ReplaceUserRoles(ctx, userID, roleIDs, assignedByUserID) } -func (u *UseCases) HasPermissions(ctx context.Context, userID string, requiredPermissions []string) (bool, error) { - return u.userAccess.HasPermissions(ctx, userID, requiredPermissions) +func (u *UseCases) AssignRoleToUser(ctx context.Context, userID string, req types.AssignUserRoleRequest, assignedByUserID *string) error { + return u.userRoles.AssignRoleToUser(ctx, userID, req, assignedByUserID) } -func (u *UseCases) GetUserWithRolesByID(ctx context.Context, userID string) (*types.UserWithRoles, error) { - return u.userAccess.GetUserWithRolesByID(ctx, userID) +func (u *UseCases) RemoveRoleFromUser(ctx context.Context, userID string, roleID string) error { + return u.userRoles.RemoveRoleFromUser(ctx, userID, roleID) } -func (u *UseCases) GetUserWithPermissionsByID(ctx context.Context, userID string) (*types.UserWithPermissions, error) { - return u.userAccess.GetUserWithPermissionsByID(ctx, userID) +// User Permissions + +func (u *UseCases) GetUserPermissions(ctx context.Context, userID string) ([]types.UserPermissionInfo, error) { + return u.userPermissions.GetUserPermissions(ctx, userID) } -func (u *UseCases) GetUserAuthorizationProfile(ctx context.Context, userID string) (*types.UserAuthorizationProfile, error) { - return u.userAccess.GetUserAuthorizationProfile(ctx, userID) +func (u *UseCases) HasPermissions(ctx context.Context, userID string, permissionKeys []string) (bool, error) { + return u.userPermissions.HasPermissions(ctx, userID, permissionKeys) } diff --git a/plugins/access-control/usecases/user_access_usecase.go b/plugins/access-control/usecases/user_access_usecase.go deleted file mode 100644 index 0fc5b33..0000000 --- a/plugins/access-control/usecases/user_access_usecase.go +++ /dev/null @@ -1,40 +0,0 @@ -package usecases - -import ( - "context" - - "github.com/Authula/authula/plugins/access-control/services" - "github.com/Authula/authula/plugins/access-control/types" -) - -type UserRolesUseCase struct { - service *services.UserAccessService -} - -func NewUserRolesUseCase(service *services.UserAccessService) UserRolesUseCase { - return UserRolesUseCase{service: service} -} - -func (u UserRolesUseCase) GetUserRoles(ctx context.Context, userID string) ([]types.UserRoleInfo, error) { - return u.service.GetUserRoles(ctx, userID) -} - -func (u UserRolesUseCase) GetUserWithRolesByID(ctx context.Context, userID string) (*types.UserWithRoles, error) { - return u.service.GetUserWithRolesByID(ctx, userID) -} - -func (u UserRolesUseCase) GetUserWithPermissionsByID(ctx context.Context, userID string) (*types.UserWithPermissions, error) { - return u.service.GetUserWithPermissionsByID(ctx, userID) -} - -func (u UserRolesUseCase) GetUserAuthorizationProfile(ctx context.Context, userID string) (*types.UserAuthorizationProfile, error) { - return u.service.GetUserAuthorizationProfile(ctx, userID) -} - -func (u UserRolesUseCase) GetUserEffectivePermissions(ctx context.Context, userID string) ([]types.UserPermissionInfo, error) { - return u.service.GetUserEffectivePermissions(ctx, userID) -} - -func (u UserRolesUseCase) HasPermissions(ctx context.Context, userID string, requiredPermissions []string) (bool, error) { - return u.service.HasPermissions(ctx, userID, requiredPermissions) -} diff --git a/plugins/access-control/usecases/user_access_usecase_test.go b/plugins/access-control/usecases/user_access_usecase_test.go deleted file mode 100644 index ef9814c..0000000 --- a/plugins/access-control/usecases/user_access_usecase_test.go +++ /dev/null @@ -1,42 +0,0 @@ -package usecases - -import ( - "context" - "testing" - - "github.com/Authula/authula/plugins/access-control/services" - "github.com/Authula/authula/plugins/access-control/types" -) - -type stubUserAccessRepo struct { - permissions []types.UserPermissionInfo -} - -func (s *stubUserAccessRepo) GetUserRoles(_ context.Context, _ string) ([]types.UserRoleInfo, error) { - return []types.UserRoleInfo{{RoleID: "role-1", RoleName: "admin"}}, nil -} - -func (s *stubUserAccessRepo) GetUserEffectivePermissions(_ context.Context, _ string) ([]types.UserPermissionInfo, error) { - return s.permissions, nil -} - -func (s *stubUserAccessRepo) GetUserWithRolesByID(_ context.Context, _ string) (*types.UserWithRoles, error) { - return &types.UserWithRoles{}, nil -} - -func (s *stubUserAccessRepo) GetUserWithPermissionsByID(_ context.Context, _ string) (*types.UserWithPermissions, error) { - return &types.UserWithPermissions{}, nil -} - -func TestUserRolesUseCaseHasPermissionsPassThrough(t *testing.T) { - repo := &stubUserAccessRepo{permissions: []types.UserPermissionInfo{{PermissionKey: "users.read"}}} - uc := NewUserRolesUseCase(services.NewUserAccessService(repo)) - - ok, err := uc.HasPermissions(context.Background(), "user-1", []string{"users.read"}) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if !ok { - t.Fatal("expected permission check to pass") - } -} diff --git a/plugins/access-control/usecases/user_permissions_usecase.go b/plugins/access-control/usecases/user_permissions_usecase.go new file mode 100644 index 0000000..c35bfe5 --- /dev/null +++ b/plugins/access-control/usecases/user_permissions_usecase.go @@ -0,0 +1,24 @@ +package usecases + +import ( + "context" + + "github.com/Authula/authula/plugins/access-control/services" + "github.com/Authula/authula/plugins/access-control/types" +) + +type UserPermissionsUseCase struct { + service *services.UserPermissionsService +} + +func NewUserPermissionsUseCase(service *services.UserPermissionsService) *UserPermissionsUseCase { + return &UserPermissionsUseCase{service: service} +} + +func (u *UserPermissionsUseCase) GetUserPermissions(ctx context.Context, userID string) ([]types.UserPermissionInfo, error) { + return u.service.GetUserPermissions(ctx, userID) +} + +func (u *UserPermissionsUseCase) HasPermissions(ctx context.Context, userID string, permissionKeys []string) (bool, error) { + return u.service.HasPermissions(ctx, userID, permissionKeys) +} diff --git a/plugins/access-control/usecases/user_roles_usecase.go b/plugins/access-control/usecases/user_roles_usecase.go new file mode 100644 index 0000000..6f4c25e --- /dev/null +++ b/plugins/access-control/usecases/user_roles_usecase.go @@ -0,0 +1,32 @@ +package usecases + +import ( + "context" + + "github.com/Authula/authula/plugins/access-control/services" + "github.com/Authula/authula/plugins/access-control/types" +) + +type UserRolesUseCase struct { + service *services.UserRolesService +} + +func NewUserRolesUseCase(service *services.UserRolesService) *UserRolesUseCase { + return &UserRolesUseCase{service: service} +} + +func (u *UserRolesUseCase) GetUserRoles(ctx context.Context, userID string) ([]types.UserRoleInfo, error) { + return u.service.GetUserRoles(ctx, userID) +} + +func (u *UserRolesUseCase) ReplaceUserRoles(ctx context.Context, userID string, roleIDs []string, assignedByUserID *string) error { + return u.service.ReplaceUserRoles(ctx, userID, roleIDs, assignedByUserID) +} + +func (u *UserRolesUseCase) AssignRoleToUser(ctx context.Context, userID string, req types.AssignUserRoleRequest, assignedByUserID *string) error { + return u.service.AssignRoleToUser(ctx, userID, req, assignedByUserID) +} + +func (u *UserRolesUseCase) RemoveRoleFromUser(ctx context.Context, userID string, roleID string) error { + return u.service.RemoveRoleFromUser(ctx, userID, roleID) +}