From ad2b43a4259821426982e3b592b99810c393c4cf Mon Sep 17 00:00:00 2001 From: Tanvir Ahmed Date: Tue, 24 Mar 2026 15:31:03 +0000 Subject: [PATCH 1/2] feat: Organizations Plugin chore: Wrote tests for organization handlers and more chore: Implemented more tests for the plugin and improved naming of structs and more chore: Many more updates across the plugin chore: Updated tests chore: Updated --- .agents/skills/plugin-architecture/SKILL.md | 7 +- .../examples/todo_plugin.go | 65 -- config.example.toml | 44 - internal/bootstrap/plugin_factory.go | 18 + internal/errors/errors.go | 35 + models/plugin.go | 1 + plugins/admin/services/state_service_test.go | 2 +- .../handlers/verify_handler_test.go | 4 +- plugins/organizations/api.go | 115 ++ plugins/organizations/constants/constants.go | 5 + plugins/organizations/events/events.go | 15 + .../handlers/organization_handlers.go | 157 +++ .../handlers/organization_handlers_test.go | 458 ++++++++ .../organization_invitation_handlers.go | 191 +++ .../organization_invitation_handlers_test.go | 1019 +++++++++++++++++ .../handlers/organization_member_handlers.go | 166 +++ .../organization_member_handlers_test.go | 550 +++++++++ .../handlers/organization_team_handlers.go | 136 +++ .../organization_team_handlers_test.go | 597 ++++++++++ .../organization_team_member_handlers.go | 134 +++ .../organization_team_member_handlers_test.go | 648 +++++++++++ plugins/organizations/hooks.go | 203 ++++ plugins/organizations/hooks_test.go | 149 +++ plugins/organizations/migrations.go | 330 ++++++ plugins/organizations/plugin.go | 115 ++ .../bun_organization_invitation_repository.go | 98 ++ ...organization_invitation_repository_test.go | 307 +++++ .../bun_organization_member_repository.go | 91 ++ ...bun_organization_member_repository_test.go | 315 +++++ .../bun_organization_repository.go | 90 ++ .../bun_organization_repository_test.go | 257 +++++ ...bun_organization_team_member_repository.go | 75 ++ ...rganization_team_member_repository_test.go | 1 + .../bun_organization_team_repository.go | 92 ++ .../bun_organization_team_repository_test.go | 438 +++++++ .../organizations/repositories/interfaces.go | 58 + plugins/organizations/routes.go | 167 +++ .../organization_invitation_service.go | 572 +++++++++ .../organization_invitation_service_test.go | 703 ++++++++++++ .../services/organization_member_service.go | 209 ++++ .../organization_member_service_test.go | 800 +++++++++++++ .../services/organization_service.go | 222 ++++ .../services/organization_service_test.go | 374 ++++++ .../services/organization_team_service.go | 361 ++++++ .../organization_team_service_test.go | 548 +++++++++ plugins/organizations/tests/test_helpers.go | 503 ++++++++ plugins/organizations/types/models.go | 93 ++ plugins/organizations/types/requests.go | 52 + plugins/organizations/types/types.go | 66 ++ .../organization_invitation_usecase.go | 40 + .../usecases/organization_member_usecase.go | 36 + .../usecases/organization_team_usecase.go | 48 + .../usecases/organization_usecase.go | 36 + 53 files changed, 11701 insertions(+), 115 deletions(-) delete mode 100644 .agents/skills/plugin-architecture/examples/todo_plugin.go create mode 100644 internal/errors/errors.go create mode 100644 plugins/organizations/api.go create mode 100644 plugins/organizations/constants/constants.go create mode 100644 plugins/organizations/events/events.go create mode 100644 plugins/organizations/handlers/organization_handlers.go create mode 100644 plugins/organizations/handlers/organization_handlers_test.go create mode 100644 plugins/organizations/handlers/organization_invitation_handlers.go create mode 100644 plugins/organizations/handlers/organization_invitation_handlers_test.go create mode 100644 plugins/organizations/handlers/organization_member_handlers.go create mode 100644 plugins/organizations/handlers/organization_member_handlers_test.go create mode 100644 plugins/organizations/handlers/organization_team_handlers.go create mode 100644 plugins/organizations/handlers/organization_team_handlers_test.go create mode 100644 plugins/organizations/handlers/organization_team_member_handlers.go create mode 100644 plugins/organizations/handlers/organization_team_member_handlers_test.go create mode 100644 plugins/organizations/hooks.go create mode 100644 plugins/organizations/hooks_test.go create mode 100644 plugins/organizations/migrations.go create mode 100644 plugins/organizations/plugin.go create mode 100644 plugins/organizations/repositories/bun_organization_invitation_repository.go create mode 100644 plugins/organizations/repositories/bun_organization_invitation_repository_test.go create mode 100644 plugins/organizations/repositories/bun_organization_member_repository.go create mode 100644 plugins/organizations/repositories/bun_organization_member_repository_test.go create mode 100644 plugins/organizations/repositories/bun_organization_repository.go create mode 100644 plugins/organizations/repositories/bun_organization_repository_test.go create mode 100644 plugins/organizations/repositories/bun_organization_team_member_repository.go create mode 100644 plugins/organizations/repositories/bun_organization_team_member_repository_test.go create mode 100644 plugins/organizations/repositories/bun_organization_team_repository.go create mode 100644 plugins/organizations/repositories/bun_organization_team_repository_test.go create mode 100644 plugins/organizations/repositories/interfaces.go create mode 100644 plugins/organizations/routes.go create mode 100644 plugins/organizations/services/organization_invitation_service.go create mode 100644 plugins/organizations/services/organization_invitation_service_test.go create mode 100644 plugins/organizations/services/organization_member_service.go create mode 100644 plugins/organizations/services/organization_member_service_test.go create mode 100644 plugins/organizations/services/organization_service.go create mode 100644 plugins/organizations/services/organization_service_test.go create mode 100644 plugins/organizations/services/organization_team_service.go create mode 100644 plugins/organizations/services/organization_team_service_test.go create mode 100644 plugins/organizations/tests/test_helpers.go create mode 100644 plugins/organizations/types/models.go create mode 100644 plugins/organizations/types/requests.go create mode 100644 plugins/organizations/types/types.go create mode 100644 plugins/organizations/usecases/organization_invitation_usecase.go create mode 100644 plugins/organizations/usecases/organization_member_usecase.go create mode 100644 plugins/organizations/usecases/organization_team_usecase.go create mode 100644 plugins/organizations/usecases/organization_usecase.go diff --git a/.agents/skills/plugin-architecture/SKILL.md b/.agents/skills/plugin-architecture/SKILL.md index fb441f66..6397366f 100644 --- a/.agents/skills/plugin-architecture/SKILL.md +++ b/.agents/skills/plugin-architecture/SKILL.md @@ -26,6 +26,7 @@ description: Build pluggable authentication features using the plugin system wit ## Pattern Plugin lifecycle: + - Define metadata (ID, name, version) - Implement Init to retrieve services, create repositories/services - Implement optional Routes, Migrations, Middleware @@ -34,8 +35,9 @@ Plugin lifecycle: ## Example -See [examples/todo_plugin.go](examples/todo_plugin.go) for: -- TodosPlugin metadata and Init pattern +See [plugins/email-password/plugin.go](../../../plugins/email-password/plugin.go) for: + +- EmailPasswordPlugin metadata and Init pattern - Service retrieval and registration - Routes definition and handler wiring - Lifecycle management (Close) @@ -52,7 +54,6 @@ See [examples/todo_plugin.go](examples/todo_plugin.go) for: ## References -- [models/plugin.go](../../../models/plugin.go) - Plugin interface definitions - [plugins/email-password/plugin.go](../../../plugins/email-password/plugin.go) - Email-password plugin - [plugins/jwt/plugin.go](../../../plugins/jwt/plugin.go) - JWT plugin - [internal/bootstrap/plugin_factory.go](../../../internal/bootstrap/plugin_factory.go) - Plugin factory diff --git a/.agents/skills/plugin-architecture/examples/todo_plugin.go b/.agents/skills/plugin-architecture/examples/todo_plugin.go deleted file mode 100644 index c998aa77..00000000 --- a/.agents/skills/plugin-architecture/examples/todo_plugin.go +++ /dev/null @@ -1,65 +0,0 @@ -package examples - -import ( - "net/http" - - "github.com/uptightgo/bun" -) - -// Type aliases for plugin interfaces (simplified) -type PluginContext struct { - Database bun.IDB -} - -type Route struct { - Method string - Path string - Handler func(w http.ResponseWriter, r *http.Request) error -} - -// TodosPlugin -type TodosPlugin struct { - todoService TodoService - createTodoUseCase *CreateTodoUseCase - markCompleteUseCase *MarkTodoCompleteUseCase -} - -func (p *TodosPlugin) Init(ctx *PluginContext) error { - // Step 1: Create repository with database - todoRepo := NewBunTodoRepository(ctx.Database) - - // Step 2: Create service - p.todoService = NewTodoService(todoRepo) - - // Step 3: Create use cases - p.createTodoUseCase = NewCreateTodoUseCase(p.todoService) - p.markCompleteUseCase = NewMarkTodoCompleteUseCase(p.todoService) - - return nil -} - -func (p *TodosPlugin) Routes() []Route { - createHandler := NewCreateTodoHandler(p.createTodoUseCase) - completeHandler := NewMarkTodoCompleteHandler(p.markCompleteUseCase) - - return []Route{ - { - Method: http.MethodPost, - Path: "/todos", - Handler: func(w http.ResponseWriter, r *http.Request) error { - return createHandler.Handle(w, r) - }, - }, - { - Method: http.MethodPut, - Path: "/todos/{id}/complete", - Handler: func(w http.ResponseWriter, r *http.Request) error { - return completeHandler.Handle(w, r) - }, - }, - } -} - -func (p *TodosPlugin) Close() error { - return nil -} diff --git a/config.example.toml b/config.example.toml index 288530a8..40a8f165 100644 --- a/config.example.toml +++ b/config.example.toml @@ -151,47 +151,3 @@ buffer_size = 100 # In standalone mode, all plugin-to-route associations are defined here via [[route_mappings]] tables. # This enables full plugin routing control without code changes. # Plugin IDs follow the format "{plugin_name}.{operation}" (e.g., "session.auth", "csrf.protect") - -# Example routes: -# [[route_mappings]] -# path = "/me" -# method = "GET" -# plugins = ["session.auth"] (SSR) or ["bearer.auth"] (SPA/mobile) - -# [[route_mappings]] -# path = "/sign-in" -# method = "POST" -# plugins = ["session.auth.optional"] - -# [[route_mappings]] -# path = "/sign-up" -# method = "POST" - -# [[route_mappings]] -# path = "/change-password" -# method = "POST" -# plugins = ["session.auth", "csrf.protect"] - -# [[route_mappings]] -# path = "/sign-out" -# method = "POST" -# plugins = ["session.auth", "csrf.protect"] - -# Access control (opt-in per route) -# [[route_mappings]] -# path = "/admin/users" -# method = "GET" -# plugins = ["session.auth", "access_control.enforce"] -# permissions = ["users.read"] - -# If using TOTP plugin, keep /totp/verify and /totp/verify-backup-code accessible -# to the pending-token flow (do not require an existing session cookie). -# [[route_mappings]] -# path = "/totp/verify" -# method = "POST" -# plugins = ["session.auth.optional"] - -# [[route_mappings]] -# path = "/totp/verify-backup-code" -# method = "POST" -# plugins = ["session.auth.optional"] diff --git a/internal/bootstrap/plugin_factory.go b/internal/bootstrap/plugin_factory.go index 6d4b39da..d1b2cd28 100644 --- a/internal/bootstrap/plugin_factory.go +++ b/internal/bootstrap/plugin_factory.go @@ -23,6 +23,8 @@ import ( magiclinkplugintypes "github.com/Authula/authula/plugins/magic-link/types" oauth2plugin "github.com/Authula/authula/plugins/oauth2" oauth2plugintypes "github.com/Authula/authula/plugins/oauth2/types" + organizationsplugin "github.com/Authula/authula/plugins/organizations" + organizationsplugintypes "github.com/Authula/authula/plugins/organizations/types" ratelimitplugin "github.com/Authula/authula/plugins/rate-limit" secondarystorageplugin "github.com/Authula/authula/plugins/secondary-storage" sessionplugin "github.com/Authula/authula/plugins/session" @@ -248,6 +250,22 @@ var pluginFactories = []PluginFactory{ return accesscontrolplugin.New(typedConfig.(accesscontrolplugintypes.AccessControlPluginConfig)) }, }, + { + ID: models.PluginOrganizations.String(), + RequiredByDefault: false, + ConfigParser: func(rawConfig any) (any, error) { + config := organizationsplugintypes.OrganizationsPluginConfig{} + if rawConfig != nil { + if err := util.ParsePluginConfig(rawConfig, &config); err != nil { + return nil, fmt.Errorf("failed to parse organizations plugin config: %w", err) + } + } + return config, nil + }, + Constructor: func(typedConfig any) models.Plugin { + return organizationsplugin.New(typedConfig.(organizationsplugintypes.OrganizationsPluginConfig)) + }, + }, { ID: models.PluginTOTP.String(), RequiredByDefault: false, diff --git a/internal/errors/errors.go b/internal/errors/errors.go new file mode 100644 index 00000000..0b19cf39 --- /dev/null +++ b/internal/errors/errors.go @@ -0,0 +1,35 @@ +package errors + +import ( + "errors" + "net/http" + + "github.com/Authula/authula/models" +) + +var ( + ErrBadRequest = errors.New("bad request") + ErrUnauthorized = errors.New("unauthorized") + ErrForbidden = errors.New("forbidden") + ErrNotFound = errors.New("not found") + ErrConflict = errors.New("conflict") + ErrUnprocessableEntity = errors.New("unprocessable entity") +) + +func HandleError(err error, reqCtx *models.RequestContext) { + status := http.StatusBadRequest + switch err { + case ErrUnauthorized: + status = http.StatusUnauthorized + case ErrForbidden: + status = http.StatusForbidden + case ErrNotFound: + status = http.StatusNotFound + case ErrConflict: + status = http.StatusConflict + case ErrUnprocessableEntity: + status = http.StatusUnprocessableEntity + } + reqCtx.SetJSONResponse(status, map[string]any{"message": err.Error()}) + reqCtx.Handled = true +} diff --git a/models/plugin.go b/models/plugin.go index a1b2128e..a35be3e9 100644 --- a/models/plugin.go +++ b/models/plugin.go @@ -26,6 +26,7 @@ const ( PluginRateLimit PluginID = "ratelimit" PluginMagicLink PluginID = "magic_link" PluginTOTP PluginID = "totp" + PluginOrganizations PluginID = "organizations" ) func (id PluginID) String() string { diff --git a/plugins/admin/services/state_service_test.go b/plugins/admin/services/state_service_test.go index 5eeb8242..4ff4a61b 100644 --- a/plugins/admin/services/state_service_test.go +++ b/plugins/admin/services/state_service_test.go @@ -97,7 +97,7 @@ func TestStateService_UpsertUserState(t *testing.T) { if tc.userExists { // always prepare an Upsert expectation when user exists if tc.hasRepoErr { - usr.On("Upsert", mock.Anything, mock.Anything).Return(errors.New("boom")).Once() + usr.On("Upsert", mock.Anything, mock.Anything).Return(errors.New("some error")).Once() } else if tc.expectCall != nil { usr.On("Upsert", mock.Anything, mock.MatchedBy(tc.expectCall)).Return(nil).Once() } else { diff --git a/plugins/magic-link/handlers/verify_handler_test.go b/plugins/magic-link/handlers/verify_handler_test.go index 11db34fe..b5e7e7ec 100644 --- a/plugins/magic-link/handlers/verify_handler_test.go +++ b/plugins/magic-link/handlers/verify_handler_test.go @@ -119,14 +119,14 @@ func TestVerifyHandler_UntrustedCallbackURL(t *testing.T) { func TestVerifyHandler_UseCaseError(t *testing.T) { useCase := &mockVerifyUseCase{} - useCase.On("Verify", mock.Anything, "abc", mock.Anything, mock.Anything).Return("", errors.New("boom")).Once() + useCase.On("Verify", mock.Anything, "abc", mock.Anything, mock.Anything).Return("", errors.New("some error")).Once() handler := &VerifyHandler{UseCase: useCase} req, reqCtx, w := newCallbackRequest(t, "/magic-link/verify?token=abc") handler.Handler()(w, req) - assertErrorResponse(t, reqCtx, http.StatusBadRequest, "boom") + assertErrorResponse(t, reqCtx, http.StatusBadRequest, "some error") useCase.AssertExpectations(t) } diff --git a/plugins/organizations/api.go b/plugins/organizations/api.go new file mode 100644 index 00000000..daed0e48 --- /dev/null +++ b/plugins/organizations/api.go @@ -0,0 +1,115 @@ +package organizations + +import ( + "context" + + "github.com/Authula/authula/plugins/organizations/types" + "github.com/Authula/authula/plugins/organizations/usecases" +) + +type API struct { + organizationUseCase *usecases.OrganizationUseCase + invitationUseCase *usecases.OrganizationInvitationUseCase + memberUseCase *usecases.OrganizationMemberUseCase + teamUseCase *usecases.OrganizationTeamUseCase +} + +func BuildAPI(organizationUseCase *usecases.OrganizationUseCase, invitationUseCase *usecases.OrganizationInvitationUseCase, memberUseCase *usecases.OrganizationMemberUseCase, teamUseCase *usecases.OrganizationTeamUseCase) *API { + return &API{organizationUseCase: organizationUseCase, invitationUseCase: invitationUseCase, memberUseCase: memberUseCase, teamUseCase: teamUseCase} +} + +func (a *API) CreateOrganization(ctx context.Context, actorUserID string, request types.CreateOrganizationRequest) (*types.Organization, error) { + return a.organizationUseCase.CreateOrganization(ctx, actorUserID, request) +} + +func (a *API) GetAllOrganizationsByUserID(ctx context.Context, actorUserID string) ([]types.Organization, error) { + return a.organizationUseCase.GetAllOrganizationsByUserID(ctx, actorUserID) +} + +func (a *API) GetOrganization(ctx context.Context, actorUserID string, organizationID string) (*types.Organization, error) { + return a.organizationUseCase.GetOrganization(ctx, actorUserID, organizationID) +} + +func (a *API) UpdateOrganization(ctx context.Context, actorUserID string, organizationID string, request types.UpdateOrganizationRequest) (*types.Organization, error) { + return a.organizationUseCase.UpdateOrganization(ctx, actorUserID, organizationID, request) +} + +func (a *API) DeleteOrganization(ctx context.Context, actorUserID string, organizationID string) error { + return a.organizationUseCase.DeleteOrganization(ctx, actorUserID, organizationID) +} + +func (a *API) CreateInvitation(ctx context.Context, actorUserID string, organizationID string, request types.CreateOrganizationInvitationRequest) (*types.OrganizationInvitation, error) { + return a.invitationUseCase.CreateInvitation(ctx, actorUserID, organizationID, request) +} + +func (a *API) GetInvitation(ctx context.Context, actorUserID string, organizationID string, invitationID string) (*types.OrganizationInvitation, error) { + return a.invitationUseCase.GetInvitation(ctx, actorUserID, organizationID, invitationID) +} + +func (a *API) GetAllInvitations(ctx context.Context, actorUserID string, organizationID string) ([]types.OrganizationInvitation, error) { + return a.invitationUseCase.GetAllInvitations(ctx, actorUserID, organizationID) +} + +func (a *API) RevokeInvitation(ctx context.Context, actorUserID string, organizationID string, invitationID string) (*types.OrganizationInvitation, error) { + return a.invitationUseCase.RevokeInvitation(ctx, actorUserID, organizationID, invitationID) +} + +func (a *API) AcceptInvitation(ctx context.Context, actorUserID string, organizationID string, invitationID string) (*types.OrganizationInvitation, error) { + return a.invitationUseCase.AcceptInvitation(ctx, actorUserID, organizationID, invitationID) +} + +func (a *API) RejectInvitation(ctx context.Context, actorUserID string, organizationID string, invitationID string) (*types.OrganizationInvitation, error) { + return a.invitationUseCase.RejectInvitation(ctx, actorUserID, organizationID, invitationID) +} + +func (a *API) AddMember(ctx context.Context, actorUserID string, organizationID string, request types.AddOrganizationMemberRequest) (*types.OrganizationMember, error) { + return a.memberUseCase.AddMember(ctx, actorUserID, organizationID, request) +} + +func (a *API) GetAllMembers(ctx context.Context, actorUserID string, organizationID string, page int, limit int) ([]types.OrganizationMember, error) { + return a.memberUseCase.GetAllMembers(ctx, actorUserID, organizationID, page, limit) +} + +func (a *API) GetMember(ctx context.Context, actorUserID string, organizationID string, memberID string) (*types.OrganizationMember, error) { + return a.memberUseCase.GetMember(ctx, actorUserID, organizationID, memberID) +} + +func (a *API) UpdateMember(ctx context.Context, actorUserID string, organizationID string, memberID string, request types.UpdateOrganizationMemberRequest) (*types.OrganizationMember, error) { + return a.memberUseCase.UpdateMember(ctx, actorUserID, organizationID, memberID, request) +} + +func (a *API) RemoveMember(ctx context.Context, actorUserID string, organizationID string, memberID string) error { + return a.memberUseCase.RemoveMember(ctx, actorUserID, organizationID, memberID) +} + +func (a *API) CreateTeam(ctx context.Context, actorUserID string, organizationID string, request types.CreateOrganizationTeamRequest) (*types.OrganizationTeam, error) { + return a.teamUseCase.CreateTeam(ctx, actorUserID, organizationID, request) +} + +func (a *API) GetAllTeams(ctx context.Context, actorUserID string, organizationID string) ([]types.OrganizationTeam, error) { + return a.teamUseCase.GetAllTeams(ctx, actorUserID, organizationID) +} + +func (a *API) UpdateTeam(ctx context.Context, actorUserID string, organizationID string, teamID string, request types.UpdateOrganizationTeamRequest) (*types.OrganizationTeam, error) { + return a.teamUseCase.UpdateTeam(ctx, actorUserID, organizationID, teamID, request) +} + +func (a *API) DeleteTeam(ctx context.Context, actorUserID string, organizationID string, teamID string) error { + return a.teamUseCase.DeleteTeam(ctx, actorUserID, organizationID, teamID) +} + +func (a *API) AddTeamMember(ctx context.Context, actorUserID string, organizationID string, teamID string, request types.AddOrganizationTeamMemberRequest) (*types.OrganizationTeamMember, error) { + return a.teamUseCase.AddTeamMember(ctx, actorUserID, organizationID, teamID, request) +} + +func (a *API) GetTeamMember(ctx context.Context, actorUserID string, organizationID string, teamID string, memberID string) (*types.OrganizationTeamMember, error) { + return a.teamUseCase.GetTeamMember(ctx, actorUserID, organizationID, teamID, memberID) +} + +func (a *API) GetAllTeamMembers(ctx context.Context, actorUserID string, organizationID string, teamID string, page int, limit int) ([]types.OrganizationTeamMember, error) { + return a.teamUseCase.GetAllTeamMembers(ctx, actorUserID, organizationID, teamID, page, limit) +} + +func (a *API) RemoveTeamMember(ctx context.Context, actorUserID string, organizationID string, teamID string, memberID string) error { + return a.teamUseCase.RemoveTeamMember(ctx, actorUserID, organizationID, teamID, memberID) +} diff --git a/plugins/organizations/constants/constants.go b/plugins/organizations/constants/constants.go new file mode 100644 index 00000000..1530dde7 --- /dev/null +++ b/plugins/organizations/constants/constants.go @@ -0,0 +1,5 @@ +package constants + +const ( + EventOrganizationsInvitationCreated = "organizations.invitation.created" +) diff --git a/plugins/organizations/events/events.go b/plugins/organizations/events/events.go new file mode 100644 index 00000000..1a2d562e --- /dev/null +++ b/plugins/organizations/events/events.go @@ -0,0 +1,15 @@ +package events + +import "time" + +type OrganizationInvitationCreatedEvent struct { + ID string `json:"id"` + InvitationID string `json:"invitation_id"` + OrganizationID string `json:"organization_id"` + OrganizationName string `json:"organization_name"` + InviteeEmail string `json:"invitee_email"` + InviterID string `json:"inviter_id"` + Role string `json:"role"` + ExpiresAt time.Time `json:"expires_at"` + RedirectURL string `json:"redirect_url,omitempty"` +} diff --git a/plugins/organizations/handlers/organization_handlers.go b/plugins/organizations/handlers/organization_handlers.go new file mode 100644 index 00000000..333cb1ae --- /dev/null +++ b/plugins/organizations/handlers/organization_handlers.go @@ -0,0 +1,157 @@ +package handlers + +import ( + "net/http" + + internalerrors "github.com/Authula/authula/internal/errors" + "github.com/Authula/authula/internal/util" + "github.com/Authula/authula/models" + "github.com/Authula/authula/plugins/organizations/types" + "github.com/Authula/authula/plugins/organizations/usecases" +) + +type CreateOrganizationHandler struct { + UseCase *usecases.OrganizationUseCase +} + +func (h *CreateOrganizationHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + var payload types.CreateOrganizationRequest + if err := util.ParseJSON(r, &payload); err != nil { + reqCtx.SetJSONResponse(http.StatusBadRequest, map[string]any{"message": "invalid request body"}) + reqCtx.Handled = true + return + } + + organization, err := h.UseCase.CreateOrganization(ctx, userID, payload) + if err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusCreated, organization) + } +} + +type GetAllOrganizationsByUserIDHandler struct { + UseCase *usecases.OrganizationUseCase +} + +func (h *GetAllOrganizationsByUserIDHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizations, err := h.UseCase.GetAllOrganizationsByUserID(ctx, userID) + if err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, organizations) + } +} + +type GetOrganizationHandler struct { + UseCase *usecases.OrganizationUseCase +} + +func (h *GetOrganizationHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + organization, err := h.UseCase.GetOrganization(ctx, userID, organizationID) + if err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, organization) + } +} + +type UpdateOrganizationHandler struct { + UseCase *usecases.OrganizationUseCase +} + +func (h *UpdateOrganizationHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + var payload types.UpdateOrganizationRequest + if err := util.ParseJSON(r, &payload); err != nil { + reqCtx.SetJSONResponse(http.StatusBadRequest, map[string]any{"message": "invalid request body"}) + reqCtx.Handled = true + return + } + + organization, err := h.UseCase.UpdateOrganization(ctx, userID, organizationID, payload) + if err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, organization) + } +} + +type DeleteOrganizationHandler struct { + UseCase *usecases.OrganizationUseCase +} + +func (h *DeleteOrganizationHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + if err := h.UseCase.DeleteOrganization(ctx, userID, organizationID); err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusNoContent, nil) + } +} diff --git a/plugins/organizations/handlers/organization_handlers_test.go b/plugins/organizations/handlers/organization_handlers_test.go new file mode 100644 index 00000000..3782096c --- /dev/null +++ b/plugins/organizations/handlers/organization_handlers_test.go @@ -0,0 +1,458 @@ +package handlers + +import ( + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + internaltests "github.com/Authula/authula/internal/tests" + "github.com/Authula/authula/models" + orgservices "github.com/Authula/authula/plugins/organizations/services" + orgtests "github.com/Authula/authula/plugins/organizations/tests" + orgtypes "github.com/Authula/authula/plugins/organizations/types" + "github.com/Authula/authula/plugins/organizations/usecases" +) + +type organizationHandlerFixture struct { + repo *orgtests.MockOrganizationRepository +} + +type organizationHandlerCase struct { + name string + userID *string + body []byte + organizationID string + prepare func(*organizationHandlerFixture) + expectedStatus int + expectedMessage string + checkResponse func(t *testing.T, reqCtx *models.RequestContext) +} + +func newOrganizationHandlerFixture() *organizationHandlerFixture { + return &organizationHandlerFixture{repo: &orgtests.MockOrganizationRepository{}} +} + +func (f *organizationHandlerFixture) useCase() *usecases.OrganizationUseCase { + service := orgservices.NewOrganizationService(f.repo, nil) + return usecases.NewOrganizationUseCase(service) +} + +func (f *organizationHandlerFixture) newRequest(t *testing.T, method, path string, body []byte, userID *string, organizationID string) (*http.Request, *httptest.ResponseRecorder, *models.RequestContext) { + req, w, reqCtx := internaltests.NewHandlerRequest(t, method, path, body, userID) + if organizationID != "" { + req.SetPathValue("organization_id", organizationID) + } + return req, w, reqCtx +} + +func TestCreateOrganizationHandler(t *testing.T) { + t.Parallel() + + tests := []organizationHandlerCase{ + { + name: "missing_user", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationRequest{Name: "Acme Inc"}), + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "invalid_json", + userID: new("user-1"), + body: []byte("{invalid"), + expectedStatus: http.StatusBadRequest, + expectedMessage: "invalid request body", + }, + { + name: "unprocessable_entity", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationRequest{Name: " "}), + expectedStatus: http.StatusUnprocessableEntity, + expectedMessage: "unprocessable entity", + }, + { + name: "repo_error", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationRequest{ + Name: "Acme Inc", + }), + prepare: func(f *organizationHandlerFixture) { + f.repo.On("Create", mock.Anything, mock.MatchedBy(func(org *orgtypes.Organization) bool { + return org != nil && org.OwnerID == "user-1" && org.Name == "Acme Inc" && org.Slug == "acme-inc" + })).Return((*orgtypes.Organization)(nil), errors.New("create failed")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "create failed", + }, + { + name: "success", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationRequest{ + Name: "Acme Inc", + }), + prepare: func(f *organizationHandlerFixture) { + f.repo.On("Create", mock.Anything, mock.MatchedBy(func(org *orgtypes.Organization) bool { + return org != nil && org.OwnerID == "user-1" && org.Name == "Acme Inc" && org.Slug == "acme-inc" + })).Return(&orgtypes.Organization{ + ID: "org-1", + OwnerID: "user-1", + Name: "Acme Inc", + Slug: "acme-inc", + }, nil).Once() + }, + expectedStatus: http.StatusCreated, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + org := internaltests.DecodeResponseJSON[orgtypes.Organization](t, reqCtx) + assert.Equal(t, "org-1", org.ID) + assert.Equal(t, "user-1", org.OwnerID) + assert.Equal(t, "Acme Inc", org.Name) + assert.Equal(t, "acme-inc", org.Slug) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + fixture := newOrganizationHandlerFixture() + if tt.prepare != nil { + tt.prepare(fixture) + } + + handler := &CreateOrganizationHandler{UseCase: fixture.useCase()} + req, w, reqCtx := fixture.newRequest(t, http.MethodPost, "/organizations", tt.body, tt.userID, "") + handler.Handler().ServeHTTP(w, req) + + assert.Equal(t, tt.expectedStatus, reqCtx.ResponseStatus) + if tt.expectedMessage != "" { + internaltests.AssertErrorMessage(t, reqCtx, tt.expectedStatus, tt.expectedMessage) + } + if tt.checkResponse != nil { + tt.checkResponse(t, reqCtx) + } + fixture.repo.AssertExpectations(t) + }) + } +} + +func TestGetAllOrganizationsByUserIDHandler(t *testing.T) { + t.Parallel() + + tests := []organizationHandlerCase{ + { + name: "missing_user", + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "repo_error", + userID: new("user-1"), + prepare: func(f *organizationHandlerFixture) { + f.repo.On("GetAllByOwnerID", mock.Anything, "user-1").Return(([]orgtypes.Organization)(nil), errors.New("some error")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "some error", + }, + { + name: "success", + userID: new("user-1"), + prepare: func(f *organizationHandlerFixture) { + f.repo.On("GetAllByOwnerID", mock.Anything, "user-1").Return([]orgtypes.Organization{{ID: "org-1", OwnerID: "user-1", Name: "Acme Inc", Slug: "acme-inc"}}, nil).Once() + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + organizations := internaltests.DecodeResponseJSON[[]orgtypes.Organization](t, reqCtx) + require.Len(t, organizations, 1) + assert.Equal(t, "org-1", organizations[0].ID) + assert.Equal(t, "user-1", organizations[0].OwnerID) + assert.Equal(t, "Acme Inc", organizations[0].Name) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + fixture := newOrganizationHandlerFixture() + if tt.prepare != nil { + tt.prepare(fixture) + } + + handler := &GetAllOrganizationsByUserIDHandler{UseCase: fixture.useCase()} + req, w, reqCtx := fixture.newRequest(t, http.MethodGet, "/organizations", nil, tt.userID, "") + handler.Handler().ServeHTTP(w, req) + + assert.Equal(t, tt.expectedStatus, reqCtx.ResponseStatus) + if tt.expectedMessage != "" { + internaltests.AssertErrorMessage(t, reqCtx, tt.expectedStatus, tt.expectedMessage) + } + if tt.checkResponse != nil { + tt.checkResponse(t, reqCtx) + } + fixture.repo.AssertExpectations(t) + }) + } +} + +func TestGetOrganizationHandler(t *testing.T) { + t.Parallel() + + tests := []organizationHandlerCase{ + { + name: "missing_user", + organizationID: "org-1", + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "not_found", + userID: new("user-1"), + organizationID: "org-1", + prepare: func(f *organizationHandlerFixture) { + f.repo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "forbidden", + userID: new("user-1"), + organizationID: "org-1", + prepare: func(f *organizationHandlerFixture) { + f.repo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "owner-2", Name: "Acme Inc", Slug: "acme-inc"}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: "forbidden", + }, + { + name: "success", + userID: new("user-1"), + organizationID: "org-1", + prepare: func(f *organizationHandlerFixture) { + f.repo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1", Name: "Acme Inc", Slug: "acme-inc"}, nil).Once() + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + org := internaltests.DecodeResponseJSON[orgtypes.Organization](t, reqCtx) + assert.Equal(t, "org-1", org.ID) + assert.Equal(t, "user-1", org.OwnerID) + assert.Equal(t, "Acme Inc", org.Name) + assert.Equal(t, "acme-inc", org.Slug) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + fixture := newOrganizationHandlerFixture() + if tt.prepare != nil { + tt.prepare(fixture) + } + + handler := &GetOrganizationHandler{UseCase: fixture.useCase()} + req, w, reqCtx := fixture.newRequest(t, http.MethodGet, "/organizations/test-id", nil, tt.userID, tt.organizationID) + handler.Handler().ServeHTTP(w, req) + + assert.Equal(t, tt.expectedStatus, reqCtx.ResponseStatus) + if tt.expectedMessage != "" { + internaltests.AssertErrorMessage(t, reqCtx, tt.expectedStatus, tt.expectedMessage) + } + if tt.checkResponse != nil { + tt.checkResponse(t, reqCtx) + } + fixture.repo.AssertExpectations(t) + }) + } +} + +func TestUpdateOrganizationHandler(t *testing.T) { + t.Parallel() + + tests := []organizationHandlerCase{ + { + name: "missing_user", + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationRequest{Name: "Acme Platform"}), + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "invalid_json", + userID: new("user-1"), + organizationID: "org-1", + body: []byte("{invalid"), + expectedStatus: http.StatusBadRequest, + expectedMessage: "invalid request body", + }, + { + name: "unprocessable_entity", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationRequest{Name: " "}), + prepare: func(f *organizationHandlerFixture) { + f.repo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1", Name: "Acme Inc", Slug: "acme-inc"}, nil).Once() + }, + expectedStatus: http.StatusUnprocessableEntity, + expectedMessage: "unprocessable entity", + }, + { + name: "not_found", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationRequest{Name: "Acme Platform"}), + prepare: func(f *organizationHandlerFixture) { + f.repo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "forbidden", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationRequest{Name: "Acme Platform"}), + prepare: func(f *organizationHandlerFixture) { + f.repo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "owner-2", Name: "Acme Inc", Slug: "acme-inc"}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: "forbidden", + }, + { + name: "success", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationRequest{ + Name: "Acme Platform", + Logo: new("http://some/url/logo.svg"), + Metadata: json.RawMessage(`{"tier":"pro"}`), + }), + prepare: func(f *organizationHandlerFixture) { + f.repo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1", Name: "Acme Inc", Slug: "acme-inc"}, nil).Once() + f.repo.On("Update", mock.Anything, mock.MatchedBy(func(org *orgtypes.Organization) bool { + return org != nil && org.ID == "org-1" && org.OwnerID == "user-1" && org.Name == "Acme Platform" && org.Slug == "acme-inc" + })).Return(&orgtypes.Organization{ + ID: "org-1", + OwnerID: "user-1", + Name: "Acme Platform", + Slug: "acme-inc", + Logo: new("http://some/url/logo.svg"), + Metadata: json.RawMessage(`{"tier":"pro"}`), + }, nil).Once() + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + org := internaltests.DecodeResponseJSON[orgtypes.Organization](t, reqCtx) + assert.Equal(t, "org-1", org.ID) + assert.Equal(t, "user-1", org.OwnerID) + assert.Equal(t, "Acme Platform", org.Name) + assert.Equal(t, "acme-inc", org.Slug) + require.NotNil(t, org.Logo) + assert.Equal(t, "http://some/url/logo.svg", *org.Logo) + assert.JSONEq(t, `{"tier":"pro"}`, string(org.Metadata)) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + fixture := newOrganizationHandlerFixture() + if tt.prepare != nil { + tt.prepare(fixture) + } + + handler := &UpdateOrganizationHandler{UseCase: fixture.useCase()} + req, w, reqCtx := fixture.newRequest(t, http.MethodPatch, "/organizations/test-id", tt.body, tt.userID, tt.organizationID) + handler.Handler().ServeHTTP(w, req) + + assert.Equal(t, tt.expectedStatus, reqCtx.ResponseStatus) + if tt.expectedMessage != "" { + internaltests.AssertErrorMessage(t, reqCtx, tt.expectedStatus, tt.expectedMessage) + } + if tt.checkResponse != nil { + tt.checkResponse(t, reqCtx) + } + fixture.repo.AssertExpectations(t) + }) + } +} + +func TestDeleteOrganizationHandler(t *testing.T) { + t.Parallel() + + tests := []organizationHandlerCase{ + { + name: "missing_user", + organizationID: "org-1", + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "not_found", + userID: new("user-1"), + organizationID: "org-1", + prepare: func(f *organizationHandlerFixture) { + f.repo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "forbidden", + userID: new("user-1"), + organizationID: "org-1", + prepare: func(f *organizationHandlerFixture) { + f.repo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "owner-2", Name: "Acme Inc", Slug: "acme-inc"}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: "forbidden", + }, + { + name: "success", + userID: new("user-1"), + organizationID: "org-1", + prepare: func(f *organizationHandlerFixture) { + f.repo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1", Name: "Acme Inc", Slug: "acme-inc"}, nil).Once() + f.repo.On("Delete", mock.Anything, "org-1").Return(nil).Once() + }, + expectedStatus: http.StatusNoContent, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + assert.Equal(t, "null", string(reqCtx.ResponseBody)) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + fixture := newOrganizationHandlerFixture() + if tt.prepare != nil { + tt.prepare(fixture) + } + + handler := &DeleteOrganizationHandler{UseCase: fixture.useCase()} + req, w, reqCtx := fixture.newRequest(t, http.MethodDelete, "/organizations/test-id", nil, tt.userID, tt.organizationID) + handler.Handler().ServeHTTP(w, req) + + assert.Equal(t, tt.expectedStatus, reqCtx.ResponseStatus) + if tt.expectedMessage != "" { + internaltests.AssertErrorMessage(t, reqCtx, tt.expectedStatus, tt.expectedMessage) + } + if tt.checkResponse != nil { + tt.checkResponse(t, reqCtx) + } + fixture.repo.AssertExpectations(t) + }) + } +} diff --git a/plugins/organizations/handlers/organization_invitation_handlers.go b/plugins/organizations/handlers/organization_invitation_handlers.go new file mode 100644 index 00000000..b7c385ec --- /dev/null +++ b/plugins/organizations/handlers/organization_invitation_handlers.go @@ -0,0 +1,191 @@ +package handlers + +import ( + "net/http" + "strings" + + internalerrors "github.com/Authula/authula/internal/errors" + "github.com/Authula/authula/internal/util" + "github.com/Authula/authula/models" + "github.com/Authula/authula/plugins/organizations/types" + "github.com/Authula/authula/plugins/organizations/usecases" +) + +type CreateOrganizationInvitationHandler struct { + UseCase *usecases.OrganizationInvitationUseCase +} + +func (h *CreateOrganizationInvitationHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + + var payload types.CreateOrganizationInvitationRequest + if err := util.ParseJSON(r, &payload); err != nil { + reqCtx.SetJSONResponse(http.StatusBadRequest, map[string]any{"message": "invalid request body"}) + reqCtx.Handled = true + return + } + + invitation, err := h.UseCase.CreateInvitation(ctx, userID, organizationID, payload) + if err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusCreated, invitation) + } +} + +type GetAllOrganizationInvitationsHandler struct { + UseCase *usecases.OrganizationInvitationUseCase +} + +func (h *GetAllOrganizationInvitationsHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + invitations, err := h.UseCase.GetAllInvitations(ctx, userID, organizationID) + if err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, invitations) + } +} + +type GetOrganizationInvitationHandler struct { + UseCase *usecases.OrganizationInvitationUseCase +} + +func (h *GetOrganizationInvitationHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + invitationID := r.PathValue("invitation_id") + invitation, err := h.UseCase.GetInvitation(ctx, userID, organizationID, invitationID) + if err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, invitation) + } +} + +type RevokeOrganizationInvitationHandler struct { + UseCase *usecases.OrganizationInvitationUseCase +} + +func (h *RevokeOrganizationInvitationHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + invitationID := r.PathValue("invitation_id") + invitation, err := h.UseCase.RevokeInvitation(ctx, userID, organizationID, invitationID) + if err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, invitation) + } +} + +type AcceptOrganizationInvitationHandler struct { + UseCase *usecases.OrganizationInvitationUseCase +} + +func (h *AcceptOrganizationInvitationHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + invitationID := r.PathValue("invitation_id") + invitation, err := h.UseCase.AcceptInvitation(ctx, userID, organizationID, invitationID) + if err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + if redirectURL := strings.TrimSpace(r.URL.Query().Get("redirect_url")); redirectURL != "" { + reqCtx.RedirectURL = redirectURL + return + } + + reqCtx.SetJSONResponse(http.StatusOK, invitation) + } +} + +type RejectOrganizationInvitationHandler struct { + UseCase *usecases.OrganizationInvitationUseCase +} + +func (h *RejectOrganizationInvitationHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + invitationID := r.PathValue("invitation_id") + invitation, err := h.UseCase.RejectInvitation(ctx, userID, organizationID, invitationID) + if err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, invitation) + } +} diff --git a/plugins/organizations/handlers/organization_invitation_handlers_test.go b/plugins/organizations/handlers/organization_invitation_handlers_test.go new file mode 100644 index 00000000..1516bf06 --- /dev/null +++ b/plugins/organizations/handlers/organization_invitation_handlers_test.go @@ -0,0 +1,1019 @@ +package handlers + +import ( + "context" + "database/sql" + "errors" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/uptrace/bun" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + internalerrors "github.com/Authula/authula/internal/errors" + internaltests "github.com/Authula/authula/internal/tests" + "github.com/Authula/authula/models" + orgservices "github.com/Authula/authula/plugins/organizations/services" + orgtests "github.com/Authula/authula/plugins/organizations/tests" + orgtypes "github.com/Authula/authula/plugins/organizations/types" + "github.com/Authula/authula/plugins/organizations/usecases" + rootservices "github.com/Authula/authula/services" +) + +type organizationInvitationTxRunner interface { + RunInTx(ctx context.Context, opts *sql.TxOptions, fn func(context.Context, bun.Tx) error) error +} + +func newOrganizationInvitationServiceForHandlerTest( + txRunner organizationInvitationTxRunner, + pluginConfig *orgtypes.OrganizationsPluginConfig, + userService rootservices.UserService, + orgRepo *orgtests.MockOrganizationRepository, + invRepo *orgtests.MockOrganizationInvitationRepository, + memberRepo *orgtests.MockOrganizationMemberRepository, + invHooks orgservices.OrganizationInvitationHookExecutor, + memberHooks orgservices.OrganizationMemberHookExecutor, +) *orgservices.OrganizationInvitationService { + return orgservices.NewOrganizationInvitationService( + txRunner, + &models.Config{BaseURL: "https://example.com", BasePath: "/auth"}, + pluginConfig, + &internaltests.MockLogger{}, + userService, + orgRepo, + invRepo, + memberRepo, + nil, + nil, + invHooks, + memberHooks, + ) +} + +type organizationInvitationHandlerFixture struct { + pluginConfig *orgtypes.OrganizationsPluginConfig + orgRepo *orgtests.MockOrganizationRepository + invRepo *orgtests.MockOrganizationInvitationRepository + memberRepo *orgtests.MockOrganizationMemberRepository + userSvc *internaltests.MockUserService + txRunner *orgtests.MockOrganizationInvitationTxRunner +} + +type organizationInvitationHandlerCase struct { + name string + userID *string + body []byte + organizationID string + invitationID string + prepare func(*organizationInvitationHandlerFixture) + expectedStatus int + expectedMessage string + checkResponse func(*testing.T, *models.RequestContext) +} + +func newOrganizationInvitationHandlerFixture() *organizationInvitationHandlerFixture { + return &organizationInvitationHandlerFixture{ + pluginConfig: &orgtypes.OrganizationsPluginConfig{ + Enabled: true, + InvitationExpiresIn: 7 * 24 * time.Hour, + }, + orgRepo: &orgtests.MockOrganizationRepository{}, + invRepo: &orgtests.MockOrganizationInvitationRepository{}, + memberRepo: &orgtests.MockOrganizationMemberRepository{}, + userSvc: &internaltests.MockUserService{}, + txRunner: &orgtests.MockOrganizationInvitationTxRunner{}, + } +} + +func (f *organizationInvitationHandlerFixture) useCase() *usecases.OrganizationInvitationUseCase { + service := newOrganizationInvitationServiceForHandlerTest(f.txRunner, f.pluginConfig, f.userSvc, f.orgRepo, f.invRepo, f.memberRepo, nil, nil) + return usecases.NewOrganizationInvitationUseCase(service) +} + +func (f *organizationInvitationHandlerFixture) newRequest(t *testing.T, method, path string, body []byte, userID *string, organizationID, invitationID string) (*http.Request, *httptest.ResponseRecorder, *models.RequestContext) { + t.Helper() + + req, w, reqCtx := internaltests.NewHandlerRequest(t, method, path, body, userID) + if organizationID != "" { + req.SetPathValue("organization_id", organizationID) + } + if invitationID != "" { + req.SetPathValue("invitation_id", invitationID) + } + return req, w, reqCtx +} + +func runOrganizationInvitationHandlerCases(t *testing.T, method, path string, buildHandler func(*organizationInvitationHandlerFixture) http.HandlerFunc, cases []organizationInvitationHandlerCase) { + t.Helper() + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + fixture := newOrganizationInvitationHandlerFixture() + if tt.prepare != nil { + tt.prepare(fixture) + } + + handler := buildHandler(fixture) + req, w, reqCtx := fixture.newRequest(t, method, path, tt.body, tt.userID, tt.organizationID, tt.invitationID) + handler.ServeHTTP(w, req) + + assert.Equal(t, tt.expectedStatus, reqCtx.ResponseStatus) + if tt.expectedMessage != "" { + internaltests.AssertErrorMessage(t, reqCtx, tt.expectedStatus, tt.expectedMessage) + } + if tt.checkResponse != nil { + tt.checkResponse(t, reqCtx) + } + + if fixture.userSvc != nil { + fixture.userSvc.AssertExpectations(t) + } + if fixture.orgRepo != nil { + fixture.orgRepo.AssertExpectations(t) + } + if fixture.invRepo != nil { + fixture.invRepo.AssertExpectations(t) + } + if fixture.memberRepo != nil { + fixture.memberRepo.AssertExpectations(t) + } + + }) + } +} + +func TestCreateOrganizationInvitationHandler(t *testing.T) { + t.Parallel() + + runOrganizationInvitationHandlerCases(t, http.MethodPost, "/organizations/org-1/invitations", func(fixture *organizationInvitationHandlerFixture) http.HandlerFunc { + return (&CreateOrganizationInvitationHandler{UseCase: fixture.useCase()}).Handler() + }, []organizationInvitationHandlerCase{ + { + name: "missing_user", + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationInvitationRequest{Email: "user@example.com", Role: "member"}), + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "invalid_json", + userID: new("user-1"), + organizationID: "org-1", + body: []byte("{"), + expectedStatus: http.StatusBadRequest, + expectedMessage: "invalid request body", + }, + { + name: "empty_email", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationInvitationRequest{Email: " ", Role: "member"}), + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + expectedStatus: http.StatusUnprocessableEntity, + expectedMessage: internalerrors.ErrUnprocessableEntity.Error(), + }, + { + name: "invalid_email", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationInvitationRequest{Email: "not-an-email", Role: "member"}), + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + expectedStatus: http.StatusUnprocessableEntity, + expectedMessage: internalerrors.ErrUnprocessableEntity.Error(), + }, + { + name: "empty_role", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationInvitationRequest{Email: "user@example.com", Role: " "}), + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + expectedStatus: http.StatusUnprocessableEntity, + expectedMessage: internalerrors.ErrUnprocessableEntity.Error(), + }, + { + name: "organization_not_found", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationInvitationRequest{Email: "user@example.com", Role: "member"}), + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: internalerrors.ErrNotFound.Error(), + }, + { + name: "forbidden", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationInvitationRequest{Email: "user@example.com", Role: "member"}), + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "owner-2"}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: internalerrors.ErrForbidden.Error(), + }, + { + name: "organization_lookup_error", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationInvitationRequest{Email: "user@example.com", Role: "member"}), + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), errors.New("lookup failed")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "lookup failed", + }, + { + name: "existing_invitation_lookup_error", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationInvitationRequest{Email: "user@example.com", Role: "member"}), + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.invRepo.On("GetByOrganizationIDAndEmail", mock.Anything, "org-1", "user@example.com").Return((*orgtypes.OrganizationInvitation)(nil), errors.New("some error")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "some error", + }, + { + name: "pending_conflict", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationInvitationRequest{Email: "USER@EXAMPLE.COM", Role: "member"}), + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.invRepo.On("GetByOrganizationIDAndEmail", mock.Anything, "org-1", "user@example.com").Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + }, + expectedStatus: http.StatusConflict, + expectedMessage: internalerrors.ErrConflict.Error(), + }, + { + name: "expired_existing_invitation_then_create_success", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationInvitationRequest{ + Email: "user@example.com", + Role: "member", + }), + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.pluginConfig.InvitationExpiresIn = 36 * time.Hour + expectedExpiresAt := time.Now().UTC().Add(fixture.pluginConfig.InvitationExpiresIn) + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.invRepo.On("GetByOrganizationIDAndEmail", mock.Anything, "org-1", "user@example.com").Return(&orgtypes.OrganizationInvitation{ID: "inv-old", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(-time.Hour)}, nil).Once() + fixture.invRepo.On("Update", mock.Anything, mock.MatchedBy(func(invitation *orgtypes.OrganizationInvitation) bool { + return invitation != nil && invitation.ID == "inv-old" && invitation.Status == orgtypes.OrganizationInvitationStatusExpired + })).Return(&orgtypes.OrganizationInvitation{ID: "inv-old", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusExpired, ExpiresAt: time.Now().UTC().Add(-time.Hour)}, nil).Once() + fixture.invRepo.On("Create", mock.Anything, mock.MatchedBy(func(invitation *orgtypes.OrganizationInvitation) bool { + return invitation != nil && invitation.OrganizationID == "org-1" && invitation.InviterID == "user-1" && invitation.Email == "user@example.com" && invitation.Role == "member" && invitation.Status == orgtypes.OrganizationInvitationStatusPending && invitation.ExpiresAt.After(expectedExpiresAt.Add(-2*time.Second)) && invitation.ExpiresAt.Before(expectedExpiresAt.Add(2*time.Second)) + })).Return(&orgtypes.OrganizationInvitation{ID: "inv-2", OrganizationID: "org-1", InviterID: "user-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: expectedExpiresAt}, nil).Once() + }, + expectedStatus: http.StatusCreated, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + invitation := internaltests.DecodeResponseJSON[orgtypes.OrganizationInvitation](t, reqCtx) + require.Equal(t, "inv-2", invitation.ID) + require.Equal(t, "org-1", invitation.OrganizationID) + require.Equal(t, "user-1", invitation.InviterID) + require.Equal(t, "user@example.com", invitation.Email) + require.Equal(t, "member", invitation.Role) + require.Equal(t, orgtypes.OrganizationInvitationStatusPending, invitation.Status) + require.WithinDuration(t, time.Now().UTC().Add(36*time.Hour), invitation.ExpiresAt, 2*time.Second) + }, + }, + { + name: "invalid_expiry_config", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationInvitationRequest{ + Email: "user@example.com", + Role: "member", + }), + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.pluginConfig.InvitationExpiresIn = 0 + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.invRepo.On("GetByOrganizationIDAndEmail", mock.Anything, "org-1", "user@example.com").Return(nil, nil).Once() + }, + expectedStatus: http.StatusUnprocessableEntity, + expectedMessage: internalerrors.ErrUnprocessableEntity.Error(), + }, + { + name: "success_applies_configured_expiry", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationInvitationRequest{ + Email: "user@example.com", + Role: "member", + }), + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.pluginConfig.InvitationExpiresIn = 90 * time.Minute + expectedExpiresAt := time.Now().UTC().Add(fixture.pluginConfig.InvitationExpiresIn) + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.invRepo.On("GetByOrganizationIDAndEmail", mock.Anything, "org-1", "user@example.com").Return(nil, nil).Once() + fixture.invRepo.On("Create", mock.Anything, mock.MatchedBy(func(invitation *orgtypes.OrganizationInvitation) bool { + return invitation != nil && invitation.OrganizationID == "org-1" && invitation.InviterID == "user-1" && invitation.Email == "user@example.com" && invitation.Role == "member" && invitation.Status == orgtypes.OrganizationInvitationStatusPending && invitation.ExpiresAt.After(expectedExpiresAt.Add(-2*time.Second)) && invitation.ExpiresAt.Before(expectedExpiresAt.Add(2*time.Second)) + })).Return(&orgtypes.OrganizationInvitation{ID: "inv-2", OrganizationID: "org-1", InviterID: "user-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: expectedExpiresAt}, nil).Once() + }, + expectedStatus: http.StatusCreated, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + invitation := internaltests.DecodeResponseJSON[orgtypes.OrganizationInvitation](t, reqCtx) + require.Equal(t, "inv-2", invitation.ID) + require.Equal(t, orgtypes.OrganizationInvitationStatusPending, invitation.Status) + require.Equal(t, "org-1", invitation.OrganizationID) + require.Equal(t, "user-1", invitation.InviterID) + require.Equal(t, "user@example.com", invitation.Email) + require.WithinDuration(t, time.Now().UTC().Add(90*time.Minute), invitation.ExpiresAt, 2*time.Second) + }, + }, + }) +} + +func TestGetAllOrganizationInvitationsHandler(t *testing.T) { + t.Parallel() + + runOrganizationInvitationHandlerCases(t, http.MethodGet, "/organizations/org-1/invitations", func(fixture *organizationInvitationHandlerFixture) http.HandlerFunc { + return (&GetAllOrganizationInvitationsHandler{UseCase: fixture.useCase()}).Handler() + }, []organizationInvitationHandlerCase{ + { + name: "missing_user", + organizationID: "org-1", + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "organization_not_found", + userID: new("user-1"), + organizationID: "org-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: internalerrors.ErrNotFound.Error(), + }, + { + name: "forbidden", + userID: new("user-1"), + organizationID: "org-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "owner-2"}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: internalerrors.ErrForbidden.Error(), + }, + { + name: "organization_lookup_error", + userID: new("user-1"), + organizationID: "org-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), errors.New("lookup failed")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "lookup failed", + }, + { + name: "repo_error", + userID: new("user-1"), + organizationID: "org-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.invRepo.On("GetAllByOrganizationID", mock.Anything, "org-1").Return(([]orgtypes.OrganizationInvitation)(nil), errors.New("some error")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "some error", + }, + { + name: "expired_pending_invitation_is_updated", + userID: new("user-1"), + organizationID: "org-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.invRepo.On("GetAllByOrganizationID", mock.Anything, "org-1").Return([]orgtypes.OrganizationInvitation{{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: time.Date(2020, time.January, 2, 3, 4, 5, 0, time.UTC)}}, nil).Once() + fixture.invRepo.On("Update", mock.Anything, mock.MatchedBy(func(invitation *orgtypes.OrganizationInvitation) bool { + return invitation != nil && invitation.ID == "inv-1" && invitation.Status == orgtypes.OrganizationInvitationStatusExpired + })).Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusExpired, ExpiresAt: time.Date(2020, time.January, 2, 3, 4, 5, 0, time.UTC)}, nil).Once() + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + invitations := internaltests.DecodeResponseJSON[[]orgtypes.OrganizationInvitation](t, reqCtx) + require.Len(t, invitations, 1) + require.Equal(t, "inv-1", invitations[0].ID) + require.Equal(t, orgtypes.OrganizationInvitationStatusExpired, invitations[0].Status) + }, + }, + { + name: "success", + userID: new("user-1"), + organizationID: "org-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.invRepo.On("GetAllByOrganizationID", mock.Anything, "org-1").Return([]orgtypes.OrganizationInvitation{{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusAccepted, ExpiresAt: time.Date(2026, time.January, 2, 3, 4, 5, 0, time.UTC)}}, nil).Once() + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + invitations := internaltests.DecodeResponseJSON[[]orgtypes.OrganizationInvitation](t, reqCtx) + require.Len(t, invitations, 1) + require.Equal(t, "inv-1", invitations[0].ID) + require.Equal(t, orgtypes.OrganizationInvitationStatusAccepted, invitations[0].Status) + }, + }, + }) +} + +func TestGetOrganizationInvitationHandler(t *testing.T) { + t.Parallel() + + runOrganizationInvitationHandlerCases(t, http.MethodGet, "/organizations/org-1/invitations/inv-1", func(fixture *organizationInvitationHandlerFixture) http.HandlerFunc { + return (&GetOrganizationInvitationHandler{UseCase: fixture.useCase()}).Handler() + }, []organizationInvitationHandlerCase{ + { + name: "missing_user", + organizationID: "org-1", + invitationID: "inv-1", + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "organization_not_found", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: internalerrors.ErrNotFound.Error(), + }, + { + name: "forbidden", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "owner-2"}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: internalerrors.ErrForbidden.Error(), + }, + { + name: "organization_lookup_error", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), errors.New("lookup failed")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "lookup failed", + }, + { + name: "invitation_not_found", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return((*orgtypes.OrganizationInvitation)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: internalerrors.ErrNotFound.Error(), + }, + { + name: "invitation_from_other_organization", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-2", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: internalerrors.ErrNotFound.Error(), + }, + { + name: "invitation_lookup_error", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return((*orgtypes.OrganizationInvitation)(nil), errors.New("some error")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "some error", + }, + { + name: "expired_pending_invitation_is_updated", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: time.Date(2020, time.January, 2, 3, 4, 5, 0, time.UTC)}, nil).Once() + fixture.invRepo.On("Update", mock.Anything, mock.MatchedBy(func(invitation *orgtypes.OrganizationInvitation) bool { + return invitation != nil && invitation.ID == "inv-1" && invitation.Status == orgtypes.OrganizationInvitationStatusExpired + })).Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusExpired, ExpiresAt: time.Date(2020, time.January, 2, 3, 4, 5, 0, time.UTC)}, nil).Once() + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + invitation := internaltests.DecodeResponseJSON[orgtypes.OrganizationInvitation](t, reqCtx) + require.Equal(t, "inv-1", invitation.ID) + require.Equal(t, orgtypes.OrganizationInvitationStatusExpired, invitation.Status) + }, + }, + { + name: "success", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusAccepted, ExpiresAt: time.Date(2026, time.January, 2, 3, 4, 5, 0, time.UTC)}, nil).Once() + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + invitation := internaltests.DecodeResponseJSON[orgtypes.OrganizationInvitation](t, reqCtx) + require.Equal(t, "inv-1", invitation.ID) + require.Equal(t, orgtypes.OrganizationInvitationStatusAccepted, invitation.Status) + }, + }, + }) +} + +func TestRevokeOrganizationInvitationHandler(t *testing.T) { + t.Parallel() + + runOrganizationInvitationHandlerCases(t, http.MethodDelete, "/organizations/org-1/invitations/inv-1", func(fixture *organizationInvitationHandlerFixture) http.HandlerFunc { + return (&RevokeOrganizationInvitationHandler{UseCase: fixture.useCase()}).Handler() + }, []organizationInvitationHandlerCase{ + { + name: "missing_user", + organizationID: "org-1", + invitationID: "inv-1", + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "organization_not_found", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: internalerrors.ErrNotFound.Error(), + }, + { + name: "forbidden", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "owner-2"}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: internalerrors.ErrForbidden.Error(), + }, + { + name: "organization_lookup_error", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), errors.New("lookup failed")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "lookup failed", + }, + { + name: "invitation_not_found", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return((*orgtypes.OrganizationInvitation)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: internalerrors.ErrNotFound.Error(), + }, + { + name: "invitation_from_other_organization", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-2", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: internalerrors.ErrNotFound.Error(), + }, + { + name: "invitation_lookup_error", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return((*orgtypes.OrganizationInvitation)(nil), errors.New("some error")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "some error", + }, + { + name: "already_handled_conflict", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusAccepted, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + }, + expectedStatus: http.StatusConflict, + expectedMessage: internalerrors.ErrConflict.Error(), + }, + { + name: "expired_pending_invitation_conflict", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: time.Date(2020, time.January, 2, 3, 4, 5, 0, time.UTC)}, nil).Once() + fixture.invRepo.On("Update", mock.Anything, mock.MatchedBy(func(invitation *orgtypes.OrganizationInvitation) bool { + return invitation != nil && invitation.ID == "inv-1" && invitation.Status == orgtypes.OrganizationInvitationStatusExpired + })).Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusExpired, ExpiresAt: time.Date(2020, time.January, 2, 3, 4, 5, 0, time.UTC)}, nil).Once() + }, + expectedStatus: http.StatusConflict, + expectedMessage: internalerrors.ErrConflict.Error(), + }, + { + name: "success", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + fixture.invRepo.On("Update", mock.Anything, mock.MatchedBy(func(invitation *orgtypes.OrganizationInvitation) bool { + return invitation != nil && invitation.ID == "inv-1" && invitation.Status == orgtypes.OrganizationInvitationStatusRevoked + })).Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusRevoked, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + invitation := internaltests.DecodeResponseJSON[orgtypes.OrganizationInvitation](t, reqCtx) + require.Equal(t, "inv-1", invitation.ID) + require.Equal(t, orgtypes.OrganizationInvitationStatusRevoked, invitation.Status) + }, + }, + }) +} + +func TestAcceptOrganizationInvitationHandler(t *testing.T) { + t.Parallel() + + runOrganizationInvitationHandlerCases(t, http.MethodPost, "/organizations/org-1/invitations/inv-1/accept", func(fixture *organizationInvitationHandlerFixture) http.HandlerFunc { + return (&AcceptOrganizationInvitationHandler{UseCase: fixture.useCase()}).Handler() + }, []organizationInvitationHandlerCase{ + { + name: "missing_user", + organizationID: "org-1", + invitationID: "inv-1", + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "user_lookup_error", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return((*models.User)(nil), errors.New("some error")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "some error", + }, + { + name: "user_not_found", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return((*models.User)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: internalerrors.ErrNotFound.Error(), + }, + { + name: "user_missing_email", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return(&models.User{ID: "user-1", Email: " "}, nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: internalerrors.ErrNotFound.Error(), + }, + { + name: "invitation_not_found", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return(&models.User{ID: "user-1", Email: "user@example.com"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return((*orgtypes.OrganizationInvitation)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: internalerrors.ErrNotFound.Error(), + }, + { + name: "invitation_from_other_organization", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return(&models.User{ID: "user-1", Email: "user@example.com"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-2", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: internalerrors.ErrNotFound.Error(), + }, + { + name: "invitation_lookup_error", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return(&models.User{ID: "user-1", Email: "user@example.com"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return((*orgtypes.OrganizationInvitation)(nil), errors.New("lookup failed")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "lookup failed", + }, + { + name: "expired_pending_invitation_conflict", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return(&models.User{ID: "user-1", Email: "user@example.com"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: time.Date(2020, time.January, 2, 3, 4, 5, 0, time.UTC)}, nil).Once() + fixture.invRepo.On("Update", mock.Anything, mock.MatchedBy(func(invitation *orgtypes.OrganizationInvitation) bool { + return invitation != nil && invitation.ID == "inv-1" && invitation.Status == orgtypes.OrganizationInvitationStatusExpired + })).Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusExpired, ExpiresAt: time.Date(2020, time.January, 2, 3, 4, 5, 0, time.UTC)}, nil).Once() + }, + expectedStatus: http.StatusConflict, + expectedMessage: internalerrors.ErrConflict.Error(), + }, + { + name: "email_mismatch_forbidden", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return(&models.User{ID: "user-1", Email: "other@example.com"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: internalerrors.ErrForbidden.Error(), + }, + { + name: "member_lookup_error", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return(&models.User{ID: "user-1", Email: "user@example.com"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + fixture.memberRepo.On("GetByOrganizationIDAndUserID", mock.Anything, "org-1", "user-1").Return((*orgtypes.OrganizationMember)(nil), errors.New("member lookup failed")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "member lookup failed", + }, + { + name: "success_creates_member", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return(&models.User{ID: "user-1", Email: "user@example.com"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + fixture.memberRepo.On("GetByOrganizationIDAndUserID", mock.Anything, "org-1", "user-1").Return((*orgtypes.OrganizationMember)(nil), nil).Once() + fixture.memberRepo.On("Create", mock.Anything, mock.MatchedBy(func(member *orgtypes.OrganizationMember) bool { + return member != nil && member.OrganizationID == "org-1" && member.UserID == "user-1" && member.Role == "member" + })).Return(&orgtypes.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-1", Role: "member"}, nil).Once() + fixture.invRepo.On("Update", mock.Anything, mock.MatchedBy(func(invitation *orgtypes.OrganizationInvitation) bool { + return invitation != nil && invitation.ID == "inv-1" && invitation.Status == orgtypes.OrganizationInvitationStatusAccepted + })).Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusAccepted, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + invitation := internaltests.DecodeResponseJSON[orgtypes.OrganizationInvitation](t, reqCtx) + require.Equal(t, "inv-1", invitation.ID) + require.Equal(t, orgtypes.OrganizationInvitationStatusAccepted, invitation.Status) + }, + }, + { + name: "success_with_existing_member", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return(&models.User{ID: "user-1", Email: "user@example.com"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + fixture.memberRepo.On("GetByOrganizationIDAndUserID", mock.Anything, "org-1", "user-1").Return(&orgtypes.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-1", Role: "member"}, nil).Once() + fixture.invRepo.On("Update", mock.Anything, mock.MatchedBy(func(invitation *orgtypes.OrganizationInvitation) bool { + return invitation != nil && invitation.ID == "inv-1" && invitation.Status == orgtypes.OrganizationInvitationStatusAccepted + })).Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusAccepted, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + invitation := internaltests.DecodeResponseJSON[orgtypes.OrganizationInvitation](t, reqCtx) + require.Equal(t, "inv-1", invitation.ID) + require.Equal(t, orgtypes.OrganizationInvitationStatusAccepted, invitation.Status) + }, + }, + }) +} + +func TestAcceptOrganizationInvitationHandler_RedirectURL(t *testing.T) { + t.Parallel() + + fixture := newOrganizationInvitationHandlerFixture() + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return(&models.User{ID: "user-1", Email: "user@example.com"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + fixture.memberRepo.On("GetByOrganizationIDAndUserID", mock.Anything, "org-1", "user-1").Return(&orgtypes.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-1", Role: "member"}, nil).Once() + fixture.invRepo.On("Update", mock.Anything, mock.MatchedBy(func(invitation *orgtypes.OrganizationInvitation) bool { + return invitation != nil && invitation.ID == "inv-1" && invitation.Status == orgtypes.OrganizationInvitationStatusAccepted + })).Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusAccepted, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + + handler := (&AcceptOrganizationInvitationHandler{UseCase: fixture.useCase()}).Handler() + userID := "user-1" + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPost, "/organizations/org-1/invitations/inv-1/accept?redirect_url=https%3A%2F%2Fapp.example.com%2Fwelcome", nil, &userID) + req.SetPathValue("organization_id", "org-1") + req.SetPathValue("invitation_id", "inv-1") + handler.ServeHTTP(w, req) + + require.Equal(t, "https://app.example.com/welcome", reqCtx.RedirectURL) + require.Equal(t, 0, reqCtx.ResponseStatus) + require.False(t, reqCtx.ResponseReady) + fixture.userSvc.AssertExpectations(t) + fixture.invRepo.AssertExpectations(t) + fixture.memberRepo.AssertExpectations(t) +} + +func TestRejectOrganizationInvitationHandler(t *testing.T) { + t.Parallel() + + runOrganizationInvitationHandlerCases(t, http.MethodPost, "/organizations/org-1/invitations/inv-1/reject", func(fixture *organizationInvitationHandlerFixture) http.HandlerFunc { + return (&RejectOrganizationInvitationHandler{UseCase: fixture.useCase()}).Handler() + }, []organizationInvitationHandlerCase{ + { + name: "missing_user", + organizationID: "org-1", + invitationID: "inv-1", + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "user_lookup_error", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return((*models.User)(nil), errors.New("some error")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "some error", + }, + { + name: "user_not_found", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return((*models.User)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: internalerrors.ErrNotFound.Error(), + }, + { + name: "user_missing_email", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return(&models.User{ID: "user-1", Email: ""}, nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: internalerrors.ErrNotFound.Error(), + }, + { + name: "invitation_not_found", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return(&models.User{ID: "user-1", Email: "user@example.com"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return((*orgtypes.OrganizationInvitation)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: internalerrors.ErrNotFound.Error(), + }, + { + name: "invitation_from_other_organization", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return(&models.User{ID: "user-1", Email: "user@example.com"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-2", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: internalerrors.ErrNotFound.Error(), + }, + { + name: "invitation_lookup_error", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return(&models.User{ID: "user-1", Email: "user@example.com"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return((*orgtypes.OrganizationInvitation)(nil), errors.New("lookup failed")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "lookup failed", + }, + { + name: "already_handled_conflict", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return(&models.User{ID: "user-1", Email: "user@example.com"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusRejected, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + }, + expectedStatus: http.StatusConflict, + expectedMessage: internalerrors.ErrConflict.Error(), + }, + { + name: "email_mismatch_forbidden", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return(&models.User{ID: "user-1", Email: "other@example.com"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: internalerrors.ErrForbidden.Error(), + }, + { + name: "update_error", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return(&models.User{ID: "user-1", Email: "user@example.com"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + fixture.invRepo.On("Update", mock.Anything, mock.MatchedBy(func(invitation *orgtypes.OrganizationInvitation) bool { + return invitation != nil && invitation.ID == "inv-1" && invitation.Status == orgtypes.OrganizationInvitationStatusRejected + })).Return((*orgtypes.OrganizationInvitation)(nil), errors.New("update failed")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "update failed", + }, + { + name: "success", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return(&models.User{ID: "user-1", Email: "user@example.com"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + fixture.invRepo.On("Update", mock.Anything, mock.MatchedBy(func(invitation *orgtypes.OrganizationInvitation) bool { + return invitation != nil && invitation.ID == "inv-1" && invitation.Status == orgtypes.OrganizationInvitationStatusRejected + })).Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusRejected, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + invitation := internaltests.DecodeResponseJSON[orgtypes.OrganizationInvitation](t, reqCtx) + require.Equal(t, "inv-1", invitation.ID) + require.Equal(t, orgtypes.OrganizationInvitationStatusRejected, invitation.Status) + }, + }, + }) +} diff --git a/plugins/organizations/handlers/organization_member_handlers.go b/plugins/organizations/handlers/organization_member_handlers.go new file mode 100644 index 00000000..de5fc7d1 --- /dev/null +++ b/plugins/organizations/handlers/organization_member_handlers.go @@ -0,0 +1,166 @@ +package handlers + +import ( + "net/http" + + internalerrors "github.com/Authula/authula/internal/errors" + "github.com/Authula/authula/internal/util" + "github.com/Authula/authula/models" + "github.com/Authula/authula/plugins/organizations/types" + "github.com/Authula/authula/plugins/organizations/usecases" +) + +type AddOrganizationMemberHandler struct { + UseCase *usecases.OrganizationMemberUseCase +} + +func (h *AddOrganizationMemberHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + + var payload types.AddOrganizationMemberRequest + if err := util.ParseJSON(r, &payload); err != nil { + reqCtx.SetJSONResponse(http.StatusBadRequest, map[string]any{"message": "invalid request body"}) + reqCtx.Handled = true + return + } + + member, err := h.UseCase.AddMember(ctx, userID, organizationID, payload) + if err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusCreated, member) + } +} + +type GetAllOrganizationMembersHandler struct { + UseCase *usecases.OrganizationMemberUseCase +} + +func (h *GetAllOrganizationMembersHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + page := util.GetQueryInt(r, "page", 1) + limit := util.GetQueryInt(r, "limit", 10) + members, err := h.UseCase.GetAllMembers(ctx, userID, organizationID, page, limit) + if err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, members) + } +} + +type GetOrganizationMemberHandler struct { + UseCase *usecases.OrganizationMemberUseCase +} + +func (h *GetOrganizationMemberHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + memberID := r.PathValue("member_id") + member, err := h.UseCase.GetMember(ctx, userID, organizationID, memberID) + if err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, member) + } +} + +type UpdateOrganizationMemberHandler struct { + UseCase *usecases.OrganizationMemberUseCase +} + +func (h *UpdateOrganizationMemberHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + memberID := r.PathValue("member_id") + + var payload types.UpdateOrganizationMemberRequest + if err := util.ParseJSON(r, &payload); err != nil { + reqCtx.SetJSONResponse(http.StatusBadRequest, map[string]any{"message": "invalid request body"}) + reqCtx.Handled = true + return + } + + member, err := h.UseCase.UpdateMember(ctx, userID, organizationID, memberID, payload) + if err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, member) + } +} + +type DeleteOrganizationMemberHandler struct { + UseCase *usecases.OrganizationMemberUseCase +} + +func (h *DeleteOrganizationMemberHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + memberID := r.PathValue("member_id") + if err := h.UseCase.RemoveMember(ctx, userID, organizationID, memberID); err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusNoContent, nil) + } +} diff --git a/plugins/organizations/handlers/organization_member_handlers_test.go b/plugins/organizations/handlers/organization_member_handlers_test.go new file mode 100644 index 00000000..8559a6d6 --- /dev/null +++ b/plugins/organizations/handlers/organization_member_handlers_test.go @@ -0,0 +1,550 @@ +package handlers + +import ( + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + internaltests "github.com/Authula/authula/internal/tests" + "github.com/Authula/authula/models" + orgservices "github.com/Authula/authula/plugins/organizations/services" + orgtests "github.com/Authula/authula/plugins/organizations/tests" + orgtypes "github.com/Authula/authula/plugins/organizations/types" + "github.com/Authula/authula/plugins/organizations/usecases" +) + +type organizationMemberHandlerFixture struct { + userSvc *internaltests.MockUserService + orgRepo *orgtests.MockOrganizationRepository + orgMemberRepo *orgtests.MockOrganizationMemberRepository +} + +func newOrganizationMemberHandlerFixture() *organizationMemberHandlerFixture { + return &organizationMemberHandlerFixture{ + userSvc: &internaltests.MockUserService{}, + orgRepo: &orgtests.MockOrganizationRepository{}, + orgMemberRepo: &orgtests.MockOrganizationMemberRepository{}, + } +} + +func (f *organizationMemberHandlerFixture) useCase() *usecases.OrganizationMemberUseCase { + service := orgservices.NewOrganizationMemberService(f.userSvc, f.orgRepo, f.orgMemberRepo, nil) + return usecases.NewOrganizationMemberUseCase(service) +} + +func (f *organizationMemberHandlerFixture) newRequest(t *testing.T, method, path string, body []byte, userID *string, organizationID, memberID string) (*http.Request, *httptest.ResponseRecorder, *models.RequestContext) { + req, w, reqCtx := internaltests.NewHandlerRequest(t, method, path, body, userID) + if organizationID != "" { + req.SetPathValue("organization_id", organizationID) + } + if memberID != "" { + req.SetPathValue("member_id", memberID) + } + return req, w, reqCtx +} + +type organizationMemberHandlerCase struct { + name string + userID *string + body []byte + organizationID string + memberID string + prepare func(*organizationMemberHandlerFixture) + expectedStatus int + expectedMessage string + checkResponse func(*testing.T, *models.RequestContext) +} + +func runOrganizationMemberHandlerCases(t *testing.T, method, path string, buildHandler func(*organizationMemberHandlerFixture) http.HandlerFunc, cases []organizationMemberHandlerCase) { + t.Helper() + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + fixture := newOrganizationMemberHandlerFixture() + if tt.prepare != nil { + tt.prepare(fixture) + } + + handler := buildHandler(fixture) + req, w, reqCtx := fixture.newRequest(t, method, path, tt.body, tt.userID, tt.organizationID, tt.memberID) + handler.ServeHTTP(w, req) + + assert.Equal(t, tt.expectedStatus, reqCtx.ResponseStatus) + if tt.expectedMessage != "" { + internaltests.AssertErrorMessage(t, reqCtx, tt.expectedStatus, tt.expectedMessage) + } + if tt.checkResponse != nil { + tt.checkResponse(t, reqCtx) + } + fixture.userSvc.AssertExpectations(t) + fixture.orgRepo.AssertExpectations(t) + fixture.orgMemberRepo.AssertExpectations(t) + }) + } +} + +func TestAddOrganizationMemberHandler(t *testing.T) { + t.Parallel() + + runOrganizationMemberHandlerCases(t, http.MethodPost, "/organizations/org-1/members", func(fixture *organizationMemberHandlerFixture) http.HandlerFunc { + return (&AddOrganizationMemberHandler{UseCase: fixture.useCase()}).Handler() + }, []organizationMemberHandlerCase{ + { + name: "missing user", + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationMemberRequest{UserID: "user-2", Role: "member"}), + organizationID: "org-1", + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "invalid request body", + userID: new("user-1"), + body: []byte("{"), + organizationID: "org-1", + expectedStatus: http.StatusBadRequest, + expectedMessage: "invalid request body", + }, + { + name: "bad request when user id is empty", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationMemberRequest{UserID: "", Role: "member"}), + organizationID: "org-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + expectedStatus: http.StatusUnprocessableEntity, + expectedMessage: "unprocessable entity", + }, + { + name: "bad request when role is empty", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationMemberRequest{UserID: "user-2", Role: ""}), + organizationID: "org-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + expectedStatus: http.StatusUnprocessableEntity, + expectedMessage: "unprocessable entity", + }, + { + name: "organization not found", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationMemberRequest{UserID: "user-2", Role: "member"}), + organizationID: "org-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "forbidden", + userID: new("user-2"), + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationMemberRequest{UserID: "user-3", Role: "member"}), + organizationID: "org-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: "forbidden", + }, + { + name: "user not found", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationMemberRequest{UserID: "user-2", Role: "member"}), + organizationID: "org-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.userSvc.On("GetByID", mock.Anything, "user-2").Return((*models.User)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "member lookup error", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationMemberRequest{UserID: "user-2", Role: "member"}), + organizationID: "org-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.userSvc.On("GetByID", mock.Anything, "user-2").Return(&models.User{ID: "user-2", Email: "user@example.com"}, nil).Once() + fixture.orgMemberRepo.On("GetByOrganizationIDAndUserID", mock.Anything, "org-1", "user-2").Return((*orgtypes.OrganizationMember)(nil), errors.New("some error")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "some error", + }, + { + name: "conflict", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationMemberRequest{UserID: "user-2", Role: "member"}), + organizationID: "org-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.userSvc.On("GetByID", mock.Anything, "user-2").Return(&models.User{ID: "user-2", Email: "user@example.com"}, nil).Once() + fixture.orgMemberRepo.On("GetByOrganizationIDAndUserID", mock.Anything, "org-1", "user-2").Return(&orgtypes.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}, nil).Once() + }, + expectedStatus: http.StatusConflict, + expectedMessage: "conflict", + }, + { + name: "success", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationMemberRequest{UserID: "user-2", Role: "member"}), + organizationID: "org-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.userSvc.On("GetByID", mock.Anything, "user-2").Return(&models.User{ID: "user-2", Email: "user@example.com"}, nil).Once() + fixture.orgMemberRepo.On("GetByOrganizationIDAndUserID", mock.Anything, "org-1", "user-2").Return((*orgtypes.OrganizationMember)(nil), nil).Once() + fixture.orgMemberRepo.On("Create", mock.Anything, mock.MatchedBy(func(member *orgtypes.OrganizationMember) bool { + return member != nil && member.OrganizationID == "org-1" && member.UserID == "user-2" && member.Role == "member" + })).Return(&orgtypes.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}, nil).Once() + }, + expectedStatus: http.StatusCreated, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + member := internaltests.DecodeResponseJSON[orgtypes.OrganizationMember](t, reqCtx) + require.Equal(t, "mem-1", member.ID) + require.Equal(t, "org-1", member.OrganizationID) + require.Equal(t, "user-2", member.UserID) + require.Equal(t, "member", member.Role) + }, + }, + }) +} + +func TestGetAllOrganizationMembersHandler(t *testing.T) { + t.Parallel() + + runOrganizationMemberHandlerCases(t, http.MethodGet, "/organizations/org-1/members", func(fixture *organizationMemberHandlerFixture) http.HandlerFunc { + return (&GetAllOrganizationMembersHandler{UseCase: fixture.useCase()}).Handler() + }, []organizationMemberHandlerCase{ + { + name: "missing user", + organizationID: "org-1", + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "organization not found", + userID: new("user-1"), + organizationID: "org-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "forbidden", + userID: new("user-2"), + organizationID: "org-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: "forbidden", + }, + { + name: "repo error", + userID: new("user-1"), + organizationID: "org-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgMemberRepo.On("GetAllByOrganizationID", mock.Anything, "org-1", 1, 10).Return(([]orgtypes.OrganizationMember)(nil), errors.New("some error")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "some error", + }, + { + name: "success", + userID: new("user-1"), + organizationID: "org-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgMemberRepo.On("GetAllByOrganizationID", mock.Anything, "org-1", 1, 10).Return([]orgtypes.OrganizationMember{{ID: "mem-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}}, nil).Once() + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + members := internaltests.DecodeResponseJSON[[]orgtypes.OrganizationMember](t, reqCtx) + require.Len(t, members, 1) + require.Equal(t, "mem-1", members[0].ID) + require.Equal(t, "org-1", members[0].OrganizationID) + require.Equal(t, "user-2", members[0].UserID) + require.Equal(t, "member", members[0].Role) + }, + }, + }) +} + +func TestGetOrganizationMemberHandler(t *testing.T) { + t.Parallel() + + runOrganizationMemberHandlerCases(t, http.MethodGet, "/organizations/org-1/members/mem-1", func(fixture *organizationMemberHandlerFixture) http.HandlerFunc { + return (&GetOrganizationMemberHandler{UseCase: fixture.useCase()}).Handler() + }, []organizationMemberHandlerCase{ + { + name: "missing user", + organizationID: "org-1", + memberID: "mem-1", + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "organization not found", + userID: new("user-1"), + organizationID: "org-1", + memberID: "mem-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "forbidden", + userID: new("user-2"), + organizationID: "org-1", + memberID: "mem-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: "forbidden", + }, + { + name: "member not found", + userID: new("user-1"), + organizationID: "org-1", + memberID: "mem-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgMemberRepo.On("GetByID", mock.Anything, "mem-1").Return((*orgtypes.OrganizationMember)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "repo error", + userID: new("user-1"), + organizationID: "org-1", + memberID: "mem-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgMemberRepo.On("GetByID", mock.Anything, "mem-1").Return((*orgtypes.OrganizationMember)(nil), errors.New("some error")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "some error", + }, + { + name: "success", + userID: new("user-1"), + organizationID: "org-1", + memberID: "mem-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgMemberRepo.On("GetByID", mock.Anything, "mem-1").Return(&orgtypes.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}, nil).Once() + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + member := internaltests.DecodeResponseJSON[orgtypes.OrganizationMember](t, reqCtx) + require.Equal(t, "mem-1", member.ID) + require.Equal(t, "org-1", member.OrganizationID) + require.Equal(t, "user-2", member.UserID) + require.Equal(t, "member", member.Role) + }, + }, + }) +} + +func TestUpdateOrganizationMemberHandler(t *testing.T) { + t.Parallel() + + runOrganizationMemberHandlerCases(t, http.MethodPatch, "/organizations/org-1/members/mem-1", func(fixture *organizationMemberHandlerFixture) http.HandlerFunc { + return (&UpdateOrganizationMemberHandler{UseCase: fixture.useCase()}).Handler() + }, []organizationMemberHandlerCase{ + { + name: "missing user", + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationMemberRequest{Role: "admin"}), + organizationID: "org-1", + memberID: "mem-1", + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "invalid request body", + userID: new("user-1"), + body: []byte("{"), + organizationID: "org-1", + memberID: "mem-1", + expectedStatus: http.StatusBadRequest, + expectedMessage: "invalid request body", + }, + { + name: "bad request when role is empty", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationMemberRequest{Role: " "}), + organizationID: "org-1", + memberID: "mem-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgMemberRepo.On("GetByID", mock.Anything, "mem-1").Return(&orgtypes.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}, nil).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "bad request", + }, + { + name: "organization not found", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationMemberRequest{Role: "admin"}), + organizationID: "org-1", + memberID: "mem-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "forbidden", + userID: new("user-2"), + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationMemberRequest{Role: "admin"}), + organizationID: "org-1", + memberID: "mem-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: "forbidden", + }, + { + name: "member not found", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationMemberRequest{Role: "admin"}), + organizationID: "org-1", + memberID: "mem-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgMemberRepo.On("GetByID", mock.Anything, "mem-1").Return((*orgtypes.OrganizationMember)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "repo error", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationMemberRequest{Role: "admin"}), + organizationID: "org-1", + memberID: "mem-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgMemberRepo.On("GetByID", mock.Anything, "mem-1").Return(&orgtypes.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}, nil).Once() + fixture.orgMemberRepo.On("Update", mock.Anything, mock.MatchedBy(func(member *orgtypes.OrganizationMember) bool { + return member != nil && member.ID == "mem-1" && member.Role == "admin" + })).Return((*orgtypes.OrganizationMember)(nil), errors.New("some error")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "some error", + }, + { + name: "success", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationMemberRequest{Role: "admin"}), + organizationID: "org-1", + memberID: "mem-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgMemberRepo.On("GetByID", mock.Anything, "mem-1").Return(&orgtypes.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}, nil).Once() + fixture.orgMemberRepo.On("Update", mock.Anything, mock.MatchedBy(func(member *orgtypes.OrganizationMember) bool { + return member != nil && member.ID == "mem-1" && member.Role == "admin" + })).Return(&orgtypes.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-2", Role: "admin"}, nil).Once() + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + member := internaltests.DecodeResponseJSON[orgtypes.OrganizationMember](t, reqCtx) + require.Equal(t, "mem-1", member.ID) + require.Equal(t, "admin", member.Role) + }, + }, + }) +} + +func TestDeleteOrganizationMemberHandler_Handler(t *testing.T) { + t.Parallel() + + runOrganizationMemberHandlerCases(t, http.MethodDelete, "/organizations/org-1/members/mem-1", func(fixture *organizationMemberHandlerFixture) http.HandlerFunc { + return (&DeleteOrganizationMemberHandler{UseCase: fixture.useCase()}).Handler() + }, []organizationMemberHandlerCase{ + { + name: "missing user", + organizationID: "org-1", + memberID: "mem-1", + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "organization not found", + userID: new("user-1"), + organizationID: "org-1", + memberID: "mem-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "forbidden", + userID: new("user-2"), + organizationID: "org-1", + memberID: "mem-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: "forbidden", + }, + { + name: "member not found", + userID: new("user-1"), + organizationID: "org-1", + memberID: "mem-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgMemberRepo.On("GetByID", mock.Anything, "mem-1").Return((*orgtypes.OrganizationMember)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "repo error", + userID: new("user-1"), + organizationID: "org-1", + memberID: "mem-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgMemberRepo.On("GetByID", mock.Anything, "mem-1").Return(&orgtypes.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}, nil).Once() + fixture.orgMemberRepo.On("Delete", mock.Anything, "mem-1").Return(errors.New("some error")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "some error", + }, + { + name: "success", + userID: new("user-1"), + organizationID: "org-1", + memberID: "mem-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgMemberRepo.On("GetByID", mock.Anything, "mem-1").Return(&orgtypes.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}, nil).Once() + fixture.orgMemberRepo.On("Delete", mock.Anything, "mem-1").Return(nil).Once() + }, + expectedStatus: http.StatusNoContent, + }, + }) +} diff --git a/plugins/organizations/handlers/organization_team_handlers.go b/plugins/organizations/handlers/organization_team_handlers.go new file mode 100644 index 00000000..334908f2 --- /dev/null +++ b/plugins/organizations/handlers/organization_team_handlers.go @@ -0,0 +1,136 @@ +package handlers + +import ( + "net/http" + + internalerrors "github.com/Authula/authula/internal/errors" + "github.com/Authula/authula/internal/util" + "github.com/Authula/authula/models" + "github.com/Authula/authula/plugins/organizations/types" + "github.com/Authula/authula/plugins/organizations/usecases" +) + +type CreateOrganizationTeamHandler struct { + UseCase *usecases.OrganizationTeamUseCase +} + +func (h *CreateOrganizationTeamHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + + var payload types.CreateOrganizationTeamRequest + if err := util.ParseJSON(r, &payload); err != nil { + reqCtx.SetJSONResponse(http.StatusBadRequest, map[string]any{"message": "invalid request body"}) + reqCtx.Handled = true + return + } + + team, err := h.UseCase.CreateTeam(ctx, userID, organizationID, payload) + if err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusCreated, team) + } +} + +type GetAllOrganizationTeamsHandler struct { + UseCase *usecases.OrganizationTeamUseCase +} + +func (h *GetAllOrganizationTeamsHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + teams, err := h.UseCase.GetAllTeams(ctx, userID, organizationID) + if err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, teams) + } +} + +type UpdateOrganizationTeamHandler struct { + UseCase *usecases.OrganizationTeamUseCase +} + +func (h *UpdateOrganizationTeamHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + teamID := r.PathValue("team_id") + + var payload types.UpdateOrganizationTeamRequest + if err := util.ParseJSON(r, &payload); err != nil { + reqCtx.SetJSONResponse(http.StatusBadRequest, map[string]any{"message": "invalid request body"}) + reqCtx.Handled = true + return + } + + team, err := h.UseCase.UpdateTeam(ctx, userID, organizationID, teamID, payload) + if err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, team) + } +} + +type DeleteOrganizationTeamHandler struct { + UseCase *usecases.OrganizationTeamUseCase +} + +func (h *DeleteOrganizationTeamHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + teamID := r.PathValue("team_id") + if err := h.UseCase.DeleteTeam(ctx, userID, organizationID, teamID); err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusNoContent, nil) + } +} diff --git a/plugins/organizations/handlers/organization_team_handlers_test.go b/plugins/organizations/handlers/organization_team_handlers_test.go new file mode 100644 index 00000000..aea93f46 --- /dev/null +++ b/plugins/organizations/handlers/organization_team_handlers_test.go @@ -0,0 +1,597 @@ +package handlers + +import ( + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + internalerrors "github.com/Authula/authula/internal/errors" + internaltests "github.com/Authula/authula/internal/tests" + "github.com/Authula/authula/models" + orgservices "github.com/Authula/authula/plugins/organizations/services" + orgtests "github.com/Authula/authula/plugins/organizations/tests" + orgtypes "github.com/Authula/authula/plugins/organizations/types" + "github.com/Authula/authula/plugins/organizations/usecases" +) + +type organizationTeamHandlerFixture struct { + orgRepo *orgtests.MockOrganizationRepository + teamRepo *orgtests.MockOrganizationTeamRepository + memberRepo *orgtests.MockOrganizationMemberRepository + teamMemberRepo *orgtests.MockOrganizationTeamMemberRepository +} + +type organizationTeamHandlerCase struct { + name string + userID *string + body []byte + organizationID string + teamID string + prepare func(*organizationTeamHandlerFixture) + expectedStatus int + expectedMessage string + checkResponse func(*testing.T, *models.RequestContext) +} + +func newOrganizationTeamHandlerFixture() *organizationTeamHandlerFixture { + return &organizationTeamHandlerFixture{ + orgRepo: &orgtests.MockOrganizationRepository{}, + teamRepo: &orgtests.MockOrganizationTeamRepository{}, + memberRepo: &orgtests.MockOrganizationMemberRepository{}, + teamMemberRepo: &orgtests.MockOrganizationTeamMemberRepository{}, + } +} + +func (f *organizationTeamHandlerFixture) useCase() *usecases.OrganizationTeamUseCase { + service := orgservices.NewOrganizationTeamService(f.orgRepo, f.teamRepo, f.memberRepo, f.teamMemberRepo, nil, nil) + return usecases.NewOrganizationTeamUseCase(service) +} + +func (f *organizationTeamHandlerFixture) newRequest(t *testing.T, method, path string, body []byte, userID *string, organizationID, teamID string) (*http.Request, *httptest.ResponseRecorder, *models.RequestContext) { + t.Helper() + + req, w, reqCtx := internaltests.NewHandlerRequest(t, method, path, body, userID) + if organizationID != "" { + req.SetPathValue("organization_id", organizationID) + } + if teamID != "" { + req.SetPathValue("team_id", teamID) + } + return req, w, reqCtx +} + +func TestCreateOrganizationTeamHandler(t *testing.T) { + t.Parallel() + + tests := []organizationTeamHandlerCase{ + { + name: "missing_user", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationTeamRequest{Name: "Platform"}), + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "invalid_json", + userID: new("user-1"), + organizationID: "org-1", + body: []byte("{invalid"), + expectedStatus: http.StatusBadRequest, + expectedMessage: "invalid request body", + }, + { + name: "organization_not_found", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationTeamRequest{Name: "Platform"}), + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "forbidden", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationTeamRequest{Name: "Platform"}), + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "owner-2"}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: "forbidden", + }, + { + name: "bad_request", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationTeamRequest{Name: " "}), + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: internalerrors.ErrBadRequest.Error(), + }, + { + name: "slug_conflict", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationTeamRequest{ + Name: "Platform", + }), + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + f.teamRepo.On("GetByOrganizationIDAndSlug", mock.Anything, "org-1", "platform").Return(&orgtypes.OrganizationTeam{ID: "team-2"}, nil).Once() + }, + expectedStatus: http.StatusConflict, + expectedMessage: internalerrors.ErrConflict.Error(), + }, + { + name: "repository_error", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationTeamRequest{ + Name: "Platform", + }), + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + f.teamRepo.On("GetByOrganizationIDAndSlug", mock.Anything, "org-1", "platform").Return((*orgtypes.OrganizationTeam)(nil), errors.New("lookup failed")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "lookup failed", + }, + { + name: "success", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationTeamRequest{ + Name: "Platform Team", + Slug: new("platform-team"), + Description: new("Core platform team"), + Metadata: json.RawMessage(`{"tier":"core"}`), + }), + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + f.teamRepo.On("GetByOrganizationIDAndSlug", mock.Anything, "org-1", "platform-team").Return((*orgtypes.OrganizationTeam)(nil), nil).Once() + f.teamRepo.On("Create", mock.Anything, mock.MatchedBy(func(team *orgtypes.OrganizationTeam) bool { + return team != nil && team.OrganizationID == "org-1" && team.Name == "Platform Team" && team.Slug == "platform-team" && team.Description != nil && *team.Description == "Core platform team" && string(team.Metadata) == `{"tier":"core"}` + })).Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform Team", Slug: "platform-team", Description: new("Core platform team"), Metadata: json.RawMessage(`{"tier":"core"}`)}, nil).Once() + }, + expectedStatus: http.StatusCreated, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + team := internaltests.DecodeResponseJSON[orgtypes.OrganizationTeam](t, reqCtx) + assert.Equal(t, "team-1", team.ID) + assert.Equal(t, "org-1", team.OrganizationID) + assert.Equal(t, "Platform Team", team.Name) + assert.Equal(t, "platform-team", team.Slug) + require.NotNil(t, team.Description) + assert.Equal(t, "Core platform team", *team.Description) + assert.JSONEq(t, `{"tier":"core"}`, string(team.Metadata)) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + fixture := newOrganizationTeamHandlerFixture() + if tt.prepare != nil { + tt.prepare(fixture) + } + + handler := &CreateOrganizationTeamHandler{UseCase: fixture.useCase()} + req, w, reqCtx := fixture.newRequest(t, http.MethodPost, "/organizations/org-1/teams", tt.body, tt.userID, tt.organizationID, "") + handler.Handler().ServeHTTP(w, req) + + assert.Equal(t, tt.expectedStatus, reqCtx.ResponseStatus) + if tt.expectedMessage != "" { + internaltests.AssertErrorMessage(t, reqCtx, tt.expectedStatus, tt.expectedMessage) + } + if tt.checkResponse != nil { + tt.checkResponse(t, reqCtx) + } + fixture.orgRepo.AssertExpectations(t) + fixture.teamRepo.AssertExpectations(t) + fixture.memberRepo.AssertExpectations(t) + fixture.teamMemberRepo.AssertExpectations(t) + }) + } +} + +func TestGetAllOrganizationTeamsHandler(t *testing.T) { + t.Parallel() + + tests := []organizationTeamHandlerCase{ + { + name: "missing_user", + organizationID: "org-1", + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "organization_not_found", + userID: new("user-1"), + organizationID: "org-1", + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "forbidden", + userID: new("user-1"), + organizationID: "org-1", + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "owner-2"}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: "forbidden", + }, + { + name: "repository_error", + userID: new("user-1"), + organizationID: "org-1", + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), errors.New("lookup failed")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "lookup failed", + }, + { + name: "success", + userID: new("user-1"), + organizationID: "org-1", + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + f.teamRepo.On("GetAllByOrganizationID", mock.Anything, "org-1").Return([]orgtypes.OrganizationTeam{{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}}, nil).Once() + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + teams := internaltests.DecodeResponseJSON[[]orgtypes.OrganizationTeam](t, reqCtx) + require.Len(t, teams, 1) + assert.Equal(t, "team-1", teams[0].ID) + assert.Equal(t, "org-1", teams[0].OrganizationID) + assert.Equal(t, "Platform", teams[0].Name) + assert.Equal(t, "platform", teams[0].Slug) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + fixture := newOrganizationTeamHandlerFixture() + if tt.prepare != nil { + tt.prepare(fixture) + } + + handler := &GetAllOrganizationTeamsHandler{UseCase: fixture.useCase()} + req, w, reqCtx := fixture.newRequest(t, http.MethodGet, "/organizations/org-1/teams", nil, tt.userID, tt.organizationID, "") + handler.Handler().ServeHTTP(w, req) + + assert.Equal(t, tt.expectedStatus, reqCtx.ResponseStatus) + if tt.expectedMessage != "" { + internaltests.AssertErrorMessage(t, reqCtx, tt.expectedStatus, tt.expectedMessage) + } + if tt.checkResponse != nil { + tt.checkResponse(t, reqCtx) + } + fixture.orgRepo.AssertExpectations(t) + fixture.teamRepo.AssertExpectations(t) + fixture.memberRepo.AssertExpectations(t) + fixture.teamMemberRepo.AssertExpectations(t) + }) + } +} + +func TestUpdateOrganizationTeamHandler(t *testing.T) { + t.Parallel() + + tests := []organizationTeamHandlerCase{ + { + name: "missing_user", + organizationID: "org-1", + teamID: "team-1", + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationTeamRequest{Name: "Platform Revamp"}), + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "invalid_json", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + body: []byte("{invalid"), + expectedStatus: http.StatusBadRequest, + expectedMessage: "invalid request body", + }, + { + name: "organization_not_found", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationTeamRequest{Name: "Platform Revamp"}), + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "forbidden", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationTeamRequest{Name: "Platform Revamp"}), + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "owner-2"}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: "forbidden", + }, + { + name: "team_not_found", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationTeamRequest{Name: "Platform Revamp"}), + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + f.teamRepo.On("GetByID", mock.Anything, "team-1").Return((*orgtypes.OrganizationTeam)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "team_wrong_organization", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationTeamRequest{Name: "Platform Revamp"}), + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + f.teamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-2", Name: "Platform", Slug: "platform"}, nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "bad_request", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationTeamRequest{Name: " "}), + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + f.teamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: internalerrors.ErrBadRequest.Error(), + }, + { + name: "slug_conflict", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationTeamRequest{Name: "Platform Revamp", Slug: new("platform")}), + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + f.teamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + f.teamRepo.On("GetByOrganizationIDAndSlug", mock.Anything, "org-1", "platform").Return(&orgtypes.OrganizationTeam{ID: "team-2"}, nil).Once() + }, + expectedStatus: http.StatusConflict, + expectedMessage: internalerrors.ErrConflict.Error(), + }, + { + name: "repository_error", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationTeamRequest{Name: "Platform Revamp"}), + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + f.teamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + f.teamRepo.On("GetByOrganizationIDAndSlug", mock.Anything, "org-1", "platform").Return((*orgtypes.OrganizationTeam)(nil), errors.New("lookup failed")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "lookup failed", + }, + { + name: "update_error", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationTeamRequest{Name: "Platform Revamp"}), + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + f.teamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + f.teamRepo.On("GetByOrganizationIDAndSlug", mock.Anything, "org-1", "platform").Return((*orgtypes.OrganizationTeam)(nil), nil).Once() + f.teamRepo.On("Update", mock.Anything, mock.MatchedBy(func(team *orgtypes.OrganizationTeam) bool { + return team != nil && team.ID == "team-1" && team.OrganizationID == "org-1" && team.Name == "Platform Revamp" && team.Slug == "platform" + })).Return((*orgtypes.OrganizationTeam)(nil), errors.New("update failed")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "update failed", + }, + { + name: "success", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationTeamRequest{ + Name: "Platform Revamp", + Description: new("Updated platform team"), + Metadata: json.RawMessage(`{"tier":"core"}`), + }), + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + f.teamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + f.teamRepo.On("GetByOrganizationIDAndSlug", mock.Anything, "org-1", "platform").Return((*orgtypes.OrganizationTeam)(nil), nil).Once() + f.teamRepo.On("Update", mock.Anything, mock.MatchedBy(func(team *orgtypes.OrganizationTeam) bool { + return team != nil && team.ID == "team-1" && team.OrganizationID == "org-1" && team.Name == "Platform Revamp" && team.Slug == "platform" && team.Description != nil && *team.Description == "Updated platform team" && string(team.Metadata) == `{"tier":"core"}` + })).Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform Revamp", Slug: "platform", Description: new("Updated platform team"), Metadata: json.RawMessage(`{"tier":"core"}`)}, nil).Once() + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + team := internaltests.DecodeResponseJSON[orgtypes.OrganizationTeam](t, reqCtx) + assert.Equal(t, "team-1", team.ID) + assert.Equal(t, "org-1", team.OrganizationID) + assert.Equal(t, "Platform Revamp", team.Name) + assert.Equal(t, "platform", team.Slug) + require.NotNil(t, team.Description) + assert.Equal(t, "Updated platform team", *team.Description) + assert.JSONEq(t, `{"tier":"core"}`, string(team.Metadata)) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + fixture := newOrganizationTeamHandlerFixture() + if tt.prepare != nil { + tt.prepare(fixture) + } + + handler := &UpdateOrganizationTeamHandler{UseCase: fixture.useCase()} + req, w, reqCtx := fixture.newRequest(t, http.MethodPatch, "/organizations/org-1/teams/team-1", tt.body, tt.userID, tt.organizationID, tt.teamID) + handler.Handler().ServeHTTP(w, req) + + assert.Equal(t, tt.expectedStatus, reqCtx.ResponseStatus) + if tt.expectedMessage != "" { + internaltests.AssertErrorMessage(t, reqCtx, tt.expectedStatus, tt.expectedMessage) + } + if tt.checkResponse != nil { + tt.checkResponse(t, reqCtx) + } + fixture.orgRepo.AssertExpectations(t) + fixture.teamRepo.AssertExpectations(t) + fixture.memberRepo.AssertExpectations(t) + fixture.teamMemberRepo.AssertExpectations(t) + }) + } +} + +func TestDeleteOrganizationTeamHandler(t *testing.T) { + t.Parallel() + + tests := []organizationTeamHandlerCase{ + { + name: "missing_user", + organizationID: "org-1", + teamID: "team-1", + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "organization_not_found", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "forbidden", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "owner-2"}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: "forbidden", + }, + { + name: "team_not_found", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + f.teamRepo.On("GetByID", mock.Anything, "team-1").Return((*orgtypes.OrganizationTeam)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "team_wrong_organization", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + f.teamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-2", Name: "Platform", Slug: "platform"}, nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "repository_error", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + f.teamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + f.teamRepo.On("Delete", mock.Anything, "team-1").Return(errors.New("delete failed")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "delete failed", + }, + { + name: "success", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + f.teamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + f.teamRepo.On("Delete", mock.Anything, "team-1").Return(nil).Once() + }, + expectedStatus: http.StatusNoContent, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + assert.Equal(t, "null", string(reqCtx.ResponseBody)) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + fixture := newOrganizationTeamHandlerFixture() + if tt.prepare != nil { + tt.prepare(fixture) + } + + handler := &DeleteOrganizationTeamHandler{UseCase: fixture.useCase()} + req, w, reqCtx := fixture.newRequest(t, http.MethodDelete, "/organizations/org-1/teams/team-1", nil, tt.userID, tt.organizationID, tt.teamID) + handler.Handler().ServeHTTP(w, req) + + assert.Equal(t, tt.expectedStatus, reqCtx.ResponseStatus) + if tt.expectedMessage != "" { + internaltests.AssertErrorMessage(t, reqCtx, tt.expectedStatus, tt.expectedMessage) + } + if tt.checkResponse != nil { + tt.checkResponse(t, reqCtx) + } + fixture.orgRepo.AssertExpectations(t) + fixture.teamRepo.AssertExpectations(t) + fixture.memberRepo.AssertExpectations(t) + fixture.teamMemberRepo.AssertExpectations(t) + }) + } +} diff --git a/plugins/organizations/handlers/organization_team_member_handlers.go b/plugins/organizations/handlers/organization_team_member_handlers.go new file mode 100644 index 00000000..0e4be010 --- /dev/null +++ b/plugins/organizations/handlers/organization_team_member_handlers.go @@ -0,0 +1,134 @@ +package handlers + +import ( + "net/http" + + internalerrors "github.com/Authula/authula/internal/errors" + "github.com/Authula/authula/internal/util" + "github.com/Authula/authula/models" + "github.com/Authula/authula/plugins/organizations/types" + "github.com/Authula/authula/plugins/organizations/usecases" +) + +type AddOrganizationTeamMemberHandler struct { + UseCase *usecases.OrganizationTeamUseCase +} + +func (h *AddOrganizationTeamMemberHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + teamID := r.PathValue("team_id") + + var payload types.AddOrganizationTeamMemberRequest + if err := util.ParseJSON(r, &payload); err != nil { + reqCtx.SetJSONResponse(http.StatusBadRequest, map[string]any{"message": "invalid request body"}) + reqCtx.Handled = true + return + } + + teamMember, err := h.UseCase.AddTeamMember(ctx, userID, organizationID, teamID, payload) + if err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusCreated, teamMember) + } +} + +type GetAllOrganizationTeamMembersHandler struct { + UseCase *usecases.OrganizationTeamUseCase +} + +func (h *GetAllOrganizationTeamMembersHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + teamID := r.PathValue("team_id") + page := util.GetQueryInt(r, "page", 1) + limit := util.GetQueryInt(r, "limit", 10) + teamMembers, err := h.UseCase.GetAllTeamMembers(ctx, userID, organizationID, teamID, page, limit) + if err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, teamMembers) + } +} + +type GetOrganizationTeamMemberHandler struct { + UseCase *usecases.OrganizationTeamUseCase +} + +func (h *GetOrganizationTeamMemberHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + teamID := r.PathValue("team_id") + memberID := r.PathValue("member_id") + teamMember, err := h.UseCase.GetTeamMember(ctx, userID, organizationID, teamID, memberID) + if err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, teamMember) + } +} + +type DeleteOrganizationTeamMemberHandler struct { + UseCase *usecases.OrganizationTeamUseCase +} + +func (h *DeleteOrganizationTeamMemberHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + teamID := r.PathValue("team_id") + memberID := r.PathValue("member_id") + if err := h.UseCase.RemoveTeamMember(ctx, userID, organizationID, teamID, memberID); err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusNoContent, nil) + } +} diff --git a/plugins/organizations/handlers/organization_team_member_handlers_test.go b/plugins/organizations/handlers/organization_team_member_handlers_test.go new file mode 100644 index 00000000..75268a0e --- /dev/null +++ b/plugins/organizations/handlers/organization_team_member_handlers_test.go @@ -0,0 +1,648 @@ +package handlers + +import ( + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + internalerrors "github.com/Authula/authula/internal/errors" + internaltests "github.com/Authula/authula/internal/tests" + "github.com/Authula/authula/models" + orgservices "github.com/Authula/authula/plugins/organizations/services" + orgtests "github.com/Authula/authula/plugins/organizations/tests" + orgtypes "github.com/Authula/authula/plugins/organizations/types" + "github.com/Authula/authula/plugins/organizations/usecases" +) + +type organizationTeamMemberHandlerFixture struct { + orgRepo *orgtests.MockOrganizationRepository + orgMemberRepo *orgtests.MockOrganizationMemberRepository + orgTeamRepo *orgtests.MockOrganizationTeamRepository + orgTeamMemberRepo *orgtests.MockOrganizationTeamMemberRepository +} + +func newOrganizationTeamMemberHandlerFixture() *organizationTeamMemberHandlerFixture { + return &organizationTeamMemberHandlerFixture{ + orgRepo: &orgtests.MockOrganizationRepository{}, + orgTeamRepo: &orgtests.MockOrganizationTeamRepository{}, + orgMemberRepo: &orgtests.MockOrganizationMemberRepository{}, + orgTeamMemberRepo: &orgtests.MockOrganizationTeamMemberRepository{}, + } +} + +func (f *organizationTeamMemberHandlerFixture) useCase() *usecases.OrganizationTeamUseCase { + service := orgservices.NewOrganizationTeamService(f.orgRepo, f.orgTeamRepo, f.orgMemberRepo, f.orgTeamMemberRepo, nil, nil) + return usecases.NewOrganizationTeamUseCase(service) +} + +func (f *organizationTeamMemberHandlerFixture) newRequest(t *testing.T, method string, path string, body []byte, userID *string, organizationID string, teamID string, memberID string) (*http.Request, *httptest.ResponseRecorder, *models.RequestContext) { + t.Helper() + + req, w, reqCtx := internaltests.NewHandlerRequest(t, method, path, body, userID) + if organizationID != "" { + req.SetPathValue("organization_id", organizationID) + } + if teamID != "" { + req.SetPathValue("team_id", teamID) + } + if memberID != "" { + req.SetPathValue("member_id", memberID) + } + return req, w, reqCtx +} + +type organizationTeamMemberHandlerCase struct { + name string + userID *string + body []byte + organizationID string + teamID string + memberID string + prepare func(*organizationTeamMemberHandlerFixture) + expectedStatus int + expectedMessage string + checkResponse func(*testing.T, *models.RequestContext) +} + +func runOrganizationTeamMemberHandlerCases(t *testing.T, method, path string, buildHandler func(*organizationTeamMemberHandlerFixture) http.HandlerFunc, cases []organizationTeamMemberHandlerCase) { + t.Helper() + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + fixture := newOrganizationTeamMemberHandlerFixture() + if tt.prepare != nil { + tt.prepare(fixture) + } + + handler := buildHandler(fixture) + req, w, reqCtx := fixture.newRequest(t, method, path, tt.body, tt.userID, tt.organizationID, tt.teamID, tt.memberID) + handler.ServeHTTP(w, req) + + assert.Equal(t, tt.expectedStatus, reqCtx.ResponseStatus) + if tt.expectedMessage != "" { + internaltests.AssertErrorMessage(t, reqCtx, tt.expectedStatus, tt.expectedMessage) + } + if tt.checkResponse != nil { + tt.checkResponse(t, reqCtx) + } + fixture.orgRepo.AssertExpectations(t) + fixture.orgMemberRepo.AssertExpectations(t) + fixture.orgTeamRepo.AssertExpectations(t) + fixture.orgTeamMemberRepo.AssertExpectations(t) + }) + } +} + +func TestAddOrganizationTeamMemberHandler(t *testing.T) { + t.Parallel() + + runOrganizationTeamMemberHandlerCases(t, http.MethodPost, "/organizations/org-1/teams/team-1/members", func(fixture *organizationTeamMemberHandlerFixture) http.HandlerFunc { + return (&AddOrganizationTeamMemberHandler{UseCase: fixture.useCase()}).Handler() + }, []organizationTeamMemberHandlerCase{ + { + name: "missing user", + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationTeamMemberRequest{MemberID: "member-1"}), + organizationID: "org-1", + teamID: "team-1", + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "invalid request body", + userID: new("user-1"), + body: []byte("{"), + organizationID: "org-1", + teamID: "team-1", + expectedStatus: http.StatusBadRequest, + expectedMessage: "invalid request body", + }, + { + name: "organization not found", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationTeamMemberRequest{MemberID: "member-1"}), + organizationID: "org-1", + teamID: "team-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "forbidden", + userID: new("user-2"), + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationTeamMemberRequest{MemberID: "member-1"}), + organizationID: "org-1", + teamID: "team-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: "forbidden", + }, + { + name: "team not found", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationTeamMemberRequest{MemberID: "member-1"}), + organizationID: "org-1", + teamID: "team-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return((*orgtypes.OrganizationTeam)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "team from another organization", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationTeamMemberRequest{MemberID: "member-1"}), + organizationID: "org-1", + teamID: "team-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-2", Name: "Platform", Slug: "platform"}, nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "bad request when member id is empty", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationTeamMemberRequest{MemberID: ""}), + organizationID: "org-1", + teamID: "team-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + }, + expectedStatus: http.StatusUnprocessableEntity, + expectedMessage: internalerrors.ErrUnprocessableEntity.Error(), + }, + { + name: "organization member lookup error", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationTeamMemberRequest{MemberID: "member-1"}), + organizationID: "org-1", + teamID: "team-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + fixture.orgMemberRepo.On("GetByID", mock.Anything, "member-1").Return((*orgtypes.OrganizationMember)(nil), errors.New("lookup failed")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "lookup failed", + }, + { + name: "organization member not found", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationTeamMemberRequest{MemberID: "member-1"}), + organizationID: "org-1", + teamID: "team-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + fixture.orgMemberRepo.On("GetByID", mock.Anything, "member-1").Return((*orgtypes.OrganizationMember)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "team member conflict", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationTeamMemberRequest{MemberID: "member-1"}), + organizationID: "org-1", + teamID: "team-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + fixture.orgMemberRepo.On("GetByID", mock.Anything, "member-1").Return(&orgtypes.OrganizationMember{ID: "member-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}, nil).Once() + fixture.orgTeamMemberRepo.On("GetByTeamIDAndMemberID", mock.Anything, "team-1", "member-1").Return(&orgtypes.OrganizationTeamMember{ID: "team-member-1", TeamID: "team-1", MemberID: "member-1"}, nil).Once() + }, + expectedStatus: http.StatusConflict, + expectedMessage: internalerrors.ErrConflict.Error(), + }, + { + name: "member from another organization", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationTeamMemberRequest{MemberID: "member-1"}), + organizationID: "org-1", + teamID: "team-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + fixture.orgMemberRepo.On("GetByID", mock.Anything, "member-1").Return(&orgtypes.OrganizationMember{ID: "member-1", OrganizationID: "org-2", UserID: "user-2", Role: "member"}, nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: internalerrors.ErrNotFound.Error(), + }, + { + name: "team member create error", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationTeamMemberRequest{MemberID: "member-1"}), + organizationID: "org-1", + teamID: "team-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + fixture.orgMemberRepo.On("GetByID", mock.Anything, "member-1").Return(&orgtypes.OrganizationMember{ID: "member-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}, nil).Once() + fixture.orgTeamMemberRepo.On("GetByTeamIDAndMemberID", mock.Anything, "team-1", "member-1").Return((*orgtypes.OrganizationTeamMember)(nil), nil).Once() + fixture.orgTeamMemberRepo.On("Create", mock.Anything, mock.MatchedBy(func(teamMember *orgtypes.OrganizationTeamMember) bool { + return teamMember != nil && teamMember.TeamID == "team-1" && teamMember.MemberID == "member-1" + })).Return((*orgtypes.OrganizationTeamMember)(nil), errors.New("create failed")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "create failed", + }, + { + name: "success", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationTeamMemberRequest{MemberID: "member-1"}), + organizationID: "org-1", + teamID: "team-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + fixture.orgMemberRepo.On("GetByID", mock.Anything, "member-1").Return(&orgtypes.OrganizationMember{ID: "member-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}, nil).Once() + fixture.orgTeamMemberRepo.On("GetByTeamIDAndMemberID", mock.Anything, "team-1", "member-1").Return((*orgtypes.OrganizationTeamMember)(nil), nil).Once() + fixture.orgTeamMemberRepo.On("Create", mock.Anything, mock.MatchedBy(func(teamMember *orgtypes.OrganizationTeamMember) bool { + return teamMember != nil && teamMember.TeamID == "team-1" && teamMember.MemberID == "member-1" + })).Return(&orgtypes.OrganizationTeamMember{ID: "team-member-1", TeamID: "team-1", MemberID: "member-1"}, nil).Once() + }, + expectedStatus: http.StatusCreated, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + teamMember := internaltests.DecodeResponseJSON[orgtypes.OrganizationTeamMember](t, reqCtx) + require.Equal(t, "team-member-1", teamMember.ID) + require.Equal(t, "team-1", teamMember.TeamID) + require.Equal(t, "member-1", teamMember.MemberID) + }, + }, + }) +} + +func TestGetAllOrganizationTeamMembersHandler(t *testing.T) { + t.Parallel() + + runOrganizationTeamMemberHandlerCases(t, http.MethodGet, "/organizations/org-1/teams/team-1/members", func(fixture *organizationTeamMemberHandlerFixture) http.HandlerFunc { + return (&GetAllOrganizationTeamMembersHandler{UseCase: fixture.useCase()}).Handler() + }, []organizationTeamMemberHandlerCase{ + { + name: "missing user", + organizationID: "org-1", + teamID: "team-1", + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "organization not found", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "forbidden", + userID: new("user-2"), + organizationID: "org-1", + teamID: "team-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: "forbidden", + }, + { + name: "team not found", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return((*orgtypes.OrganizationTeam)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "team from another organization", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-2", Name: "Platform", Slug: "platform"}, nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "team lookup error", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return((*orgtypes.OrganizationTeam)(nil), errors.New("lookup failed")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "lookup failed", + }, + { + name: "repo error", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + fixture.orgTeamMemberRepo.On("GetAllByTeamID", mock.Anything, "team-1", 1, 10).Return(([]orgtypes.OrganizationTeamMember)(nil), errors.New("some error")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "some error", + }, + { + name: "success", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + fixture.orgTeamMemberRepo.On("GetAllByTeamID", mock.Anything, "team-1", 1, 10).Return([]orgtypes.OrganizationTeamMember{{ID: "team-member-1", TeamID: "team-1", MemberID: "member-1"}}, nil).Once() + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + teamMembers := internaltests.DecodeResponseJSON[[]orgtypes.OrganizationTeamMember](t, reqCtx) + require.Len(t, teamMembers, 1) + require.Equal(t, "team-member-1", teamMembers[0].ID) + require.Equal(t, "team-1", teamMembers[0].TeamID) + require.Equal(t, "member-1", teamMembers[0].MemberID) + }, + }, + }) +} + +func TestGetOrganizationTeamMemberHandler(t *testing.T) { + t.Parallel() + + runOrganizationTeamMemberHandlerCases(t, http.MethodGet, "/organizations/org-1/teams/team-1/members/member-1", func(fixture *organizationTeamMemberHandlerFixture) http.HandlerFunc { + return (&GetOrganizationTeamMemberHandler{UseCase: fixture.useCase()}).Handler() + }, []organizationTeamMemberHandlerCase{ + { + name: "missing user", + organizationID: "org-1", + teamID: "team-1", + memberID: "member-1", + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "organization not found", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + memberID: "member-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "forbidden", + userID: new("user-2"), + organizationID: "org-1", + teamID: "team-1", + memberID: "member-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: "forbidden", + }, + { + name: "team not found", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + memberID: "member-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return((*orgtypes.OrganizationTeam)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "team from another organization", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + memberID: "member-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-2", Name: "Platform", Slug: "platform"}, nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "team member lookup error", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + memberID: "member-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + fixture.orgTeamMemberRepo.On("GetByTeamIDAndMemberID", mock.Anything, "team-1", "member-1").Return((*orgtypes.OrganizationTeamMember)(nil), errors.New("some error")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "some error", + }, + { + name: "member not found", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + memberID: "member-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + fixture.orgTeamMemberRepo.On("GetByTeamIDAndMemberID", mock.Anything, "team-1", "member-1").Return((*orgtypes.OrganizationTeamMember)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "success", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + memberID: "member-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + fixture.orgTeamMemberRepo.On("GetByTeamIDAndMemberID", mock.Anything, "team-1", "member-1").Return(&orgtypes.OrganizationTeamMember{ID: "team-member-1", TeamID: "team-1", MemberID: "member-1"}, nil).Once() + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + teamMember := internaltests.DecodeResponseJSON[orgtypes.OrganizationTeamMember](t, reqCtx) + require.Equal(t, "team-member-1", teamMember.ID) + require.Equal(t, "team-1", teamMember.TeamID) + require.Equal(t, "member-1", teamMember.MemberID) + }, + }, + }) +} + +func TestDeleteOrganizationTeamMemberHandler(t *testing.T) { + t.Parallel() + + runOrganizationTeamMemberHandlerCases(t, http.MethodDelete, "/organizations/org-1/teams/team-1/members/member-1", func(fixture *organizationTeamMemberHandlerFixture) http.HandlerFunc { + return (&DeleteOrganizationTeamMemberHandler{UseCase: fixture.useCase()}).Handler() + }, []organizationTeamMemberHandlerCase{ + { + name: "missing user", + organizationID: "org-1", + teamID: "team-1", + memberID: "member-1", + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "organization not found", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + memberID: "member-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "forbidden", + userID: new("user-2"), + organizationID: "org-1", + teamID: "team-1", + memberID: "member-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: "forbidden", + }, + { + name: "team not found", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + memberID: "member-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return((*orgtypes.OrganizationTeam)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "team from another organization", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + memberID: "member-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-2", Name: "Platform", Slug: "platform"}, nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "team lookup error", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + memberID: "member-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return((*orgtypes.OrganizationTeam)(nil), errors.New("lookup failed")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "lookup failed", + }, + { + name: "team member lookup error", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + memberID: "member-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + fixture.orgTeamMemberRepo.On("GetByTeamIDAndMemberID", mock.Anything, "team-1", "member-1").Return((*orgtypes.OrganizationTeamMember)(nil), errors.New("some error")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "some error", + }, + { + name: "member not found", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + memberID: "member-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + fixture.orgTeamMemberRepo.On("GetByTeamIDAndMemberID", mock.Anything, "team-1", "member-1").Return((*orgtypes.OrganizationTeamMember)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "delete error", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + memberID: "member-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + fixture.orgTeamMemberRepo.On("GetByTeamIDAndMemberID", mock.Anything, "team-1", "member-1").Return(&orgtypes.OrganizationTeamMember{ID: "team-member-1", TeamID: "team-1", MemberID: "member-1"}, nil).Once() + fixture.orgTeamMemberRepo.On("DeleteByTeamIDAndMemberID", mock.Anything, "team-1", "member-1").Return(errors.New("delete failed")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "delete failed", + }, + { + name: "success", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + memberID: "member-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + fixture.orgTeamMemberRepo.On("GetByTeamIDAndMemberID", mock.Anything, "team-1", "member-1").Return(&orgtypes.OrganizationTeamMember{ID: "team-member-1", TeamID: "team-1", MemberID: "member-1"}, nil).Once() + fixture.orgTeamMemberRepo.On("DeleteByTeamIDAndMemberID", mock.Anything, "team-1", "member-1").Return(nil).Once() + }, + expectedStatus: http.StatusNoContent, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + assert.Equal(t, "null", string(reqCtx.ResponseBody)) + }, + }, + }) +} diff --git a/plugins/organizations/hooks.go b/plugins/organizations/hooks.go new file mode 100644 index 00000000..50e9f660 --- /dev/null +++ b/plugins/organizations/hooks.go @@ -0,0 +1,203 @@ +package organizations + +import "github.com/Authula/authula/plugins/organizations/types" + +type OrganizationsHookExecutor struct { + config *types.OrganizationsDatabaseHooksConfig +} + +func NewOrganizationsHookExecutor(config *types.OrganizationsDatabaseHooksConfig) *OrganizationsHookExecutor { + return &OrganizationsHookExecutor{config: config} +} + +// Organization Hooks + +func (e *OrganizationsHookExecutor) BeforeCreateOrganization(organization *types.Organization) error { + if e == nil || e.config == nil || e.config.Organizations == nil || e.config.Organizations.BeforeCreate == nil { + return nil + } + return e.config.Organizations.BeforeCreate(organization) +} + +func (e *OrganizationsHookExecutor) AfterCreateOrganization(organization types.Organization) error { + if e == nil || e.config == nil || e.config.Organizations == nil || e.config.Organizations.AfterCreate == nil { + return nil + } + return e.config.Organizations.AfterCreate(organization) +} + +func (e *OrganizationsHookExecutor) BeforeUpdateOrganization(organization *types.Organization) error { + if e == nil || e.config == nil || e.config.Organizations == nil || e.config.Organizations.BeforeUpdate == nil { + return nil + } + return e.config.Organizations.BeforeUpdate(organization) +} + +func (e *OrganizationsHookExecutor) AfterUpdateOrganization(organization types.Organization) error { + if e == nil || e.config == nil || e.config.Organizations == nil || e.config.Organizations.AfterUpdate == nil { + return nil + } + return e.config.Organizations.AfterUpdate(organization) +} + +func (e *OrganizationsHookExecutor) BeforeDeleteOrganization(organization *types.Organization) error { + if e == nil || e.config == nil || e.config.Organizations == nil || e.config.Organizations.BeforeDelete == nil { + return nil + } + return e.config.Organizations.BeforeDelete(organization) +} + +func (e *OrganizationsHookExecutor) AfterDeleteOrganization(organization types.Organization) error { + if e == nil || e.config == nil || e.config.Organizations == nil || e.config.Organizations.AfterDelete == nil { + return nil + } + return e.config.Organizations.AfterDelete(organization) +} + +// Organization Member Hooks + +func (e *OrganizationsHookExecutor) BeforeCreateOrganizationMember(member *types.OrganizationMember) error { + if e == nil || e.config == nil || e.config.Members == nil || e.config.Members.BeforeCreate == nil { + return nil + } + return e.config.Members.BeforeCreate(member) +} + +func (e *OrganizationsHookExecutor) AfterCreateOrganizationMember(member types.OrganizationMember) error { + if e == nil || e.config == nil || e.config.Members == nil || e.config.Members.AfterCreate == nil { + return nil + } + return e.config.Members.AfterCreate(member) +} + +func (e *OrganizationsHookExecutor) BeforeUpdateOrganizationMember(member *types.OrganizationMember) error { + if e == nil || e.config == nil || e.config.Members == nil || e.config.Members.BeforeUpdate == nil { + return nil + } + return e.config.Members.BeforeUpdate(member) +} + +func (e *OrganizationsHookExecutor) AfterUpdateOrganizationMember(member types.OrganizationMember) error { + if e == nil || e.config == nil || e.config.Members == nil || e.config.Members.AfterUpdate == nil { + return nil + } + return e.config.Members.AfterUpdate(member) +} + +func (e *OrganizationsHookExecutor) BeforeDeleteOrganizationMember(member *types.OrganizationMember) error { + if e == nil || e.config == nil || e.config.Members == nil || e.config.Members.BeforeDelete == nil { + return nil + } + return e.config.Members.BeforeDelete(member) +} + +func (e *OrganizationsHookExecutor) AfterDeleteOrganizationMember(member types.OrganizationMember) error { + if e == nil || e.config == nil || e.config.Members == nil || e.config.Members.AfterDelete == nil { + return nil + } + return e.config.Members.AfterDelete(member) +} + +// Organization Invitation Hooks + +func (e *OrganizationsHookExecutor) BeforeCreateOrganizationInvitation(invitation *types.OrganizationInvitation) error { + if e == nil || e.config == nil || e.config.Invitations == nil || e.config.Invitations.BeforeCreate == nil { + return nil + } + return e.config.Invitations.BeforeCreate(invitation) +} + +func (e *OrganizationsHookExecutor) AfterCreateOrganizationInvitation(invitation types.OrganizationInvitation) error { + if e == nil || e.config == nil || e.config.Invitations == nil || e.config.Invitations.AfterCreate == nil { + return nil + } + return e.config.Invitations.AfterCreate(invitation) +} + +func (e *OrganizationsHookExecutor) BeforeUpdateOrganizationInvitation(invitation *types.OrganizationInvitation) error { + if e == nil || e.config == nil || e.config.Invitations == nil || e.config.Invitations.BeforeUpdate == nil { + return nil + } + return e.config.Invitations.BeforeUpdate(invitation) +} + +func (e *OrganizationsHookExecutor) AfterUpdateOrganizationInvitation(invitation types.OrganizationInvitation) error { + if e == nil || e.config == nil || e.config.Invitations == nil || e.config.Invitations.AfterUpdate == nil { + return nil + } + return e.config.Invitations.AfterUpdate(invitation) +} + +// Organization Team Hooks + +func (e *OrganizationsHookExecutor) BeforeCreateOrganizationTeam(team *types.OrganizationTeam) error { + if e == nil || e.config == nil || e.config.Teams == nil || e.config.Teams.BeforeCreate == nil { + return nil + } + return e.config.Teams.BeforeCreate(team) +} + +func (e *OrganizationsHookExecutor) AfterCreateOrganizationTeam(team types.OrganizationTeam) error { + if e == nil || e.config == nil || e.config.Teams == nil || e.config.Teams.AfterCreate == nil { + return nil + } + return e.config.Teams.AfterCreate(team) +} + +func (e *OrganizationsHookExecutor) BeforeUpdateOrganizationTeam(team *types.OrganizationTeam) error { + if e == nil || e.config == nil || e.config.Teams == nil || e.config.Teams.BeforeUpdate == nil { + return nil + } + return e.config.Teams.BeforeUpdate(team) +} + +func (e *OrganizationsHookExecutor) AfterUpdateOrganizationTeam(team types.OrganizationTeam) error { + if e == nil || e.config == nil || e.config.Teams == nil || e.config.Teams.AfterUpdate == nil { + return nil + } + return e.config.Teams.AfterUpdate(team) +} + +func (e *OrganizationsHookExecutor) BeforeDeleteOrganizationTeam(team *types.OrganizationTeam) error { + if e == nil || e.config == nil || e.config.Teams == nil || e.config.Teams.BeforeDelete == nil { + return nil + } + return e.config.Teams.BeforeDelete(team) +} + +func (e *OrganizationsHookExecutor) AfterDeleteOrganizationTeam(team types.OrganizationTeam) error { + if e == nil || e.config == nil || e.config.Teams == nil || e.config.Teams.AfterDelete == nil { + return nil + } + return e.config.Teams.AfterDelete(team) +} + +// Organization Team Member Hooks + +func (e *OrganizationsHookExecutor) BeforeCreateOrganizationTeamMember(member *types.OrganizationTeamMember) error { + if e == nil || e.config == nil || e.config.TeamMembers == nil || e.config.TeamMembers.BeforeCreate == nil { + return nil + } + return e.config.TeamMembers.BeforeCreate(member) +} + +func (e *OrganizationsHookExecutor) AfterCreateOrganizationTeamMember(member types.OrganizationTeamMember) error { + if e == nil || e.config == nil || e.config.TeamMembers == nil || e.config.TeamMembers.AfterCreate == nil { + return nil + } + return e.config.TeamMembers.AfterCreate(member) +} + +func (e *OrganizationsHookExecutor) BeforeDeleteOrganizationTeamMember(member *types.OrganizationTeamMember) error { + if e == nil || e.config == nil || e.config.TeamMembers == nil || e.config.TeamMembers.BeforeDelete == nil { + return nil + } + return e.config.TeamMembers.BeforeDelete(member) +} + +func (e *OrganizationsHookExecutor) AfterDeleteOrganizationTeamMember(member types.OrganizationTeamMember) error { + if e == nil || e.config == nil || e.config.TeamMembers == nil || e.config.TeamMembers.AfterDelete == nil { + return nil + } + return e.config.TeamMembers.AfterDelete(member) +} diff --git a/plugins/organizations/hooks_test.go b/plugins/organizations/hooks_test.go new file mode 100644 index 00000000..7b308410 --- /dev/null +++ b/plugins/organizations/hooks_test.go @@ -0,0 +1,149 @@ +package organizations + +import ( + "errors" + "testing" + + "github.com/Authula/authula/plugins/organizations/types" +) + +func TestOrganizationsHookExecutor_NilHooksAreNoop(t *testing.T) { + t.Parallel() + + executor := NewOrganizationsHookExecutor(nil) + + if err := executor.BeforeCreateOrganization(&types.Organization{ID: "org-1"}); err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if err := executor.AfterCreateOrganization(types.Organization{ID: "org-1"}); err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if err := executor.BeforeCreateOrganizationInvitation(&types.OrganizationInvitation{ID: "inv-1"}); err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if err := executor.AfterCreateOrganizationTeam(types.OrganizationTeam{ID: "team-1"}); err != nil { + t.Fatalf("expected nil error, got %v", err) + } +} + +func TestOrganizationsHookExecutor_OrganizationCreateHooks(t *testing.T) { + t.Parallel() + + var beforeCalled bool + var afterCalled bool + + executor := NewOrganizationsHookExecutor(&types.OrganizationsDatabaseHooksConfig{ + Organizations: &types.OrganizationDatabaseHooksConfig{ + BeforeCreate: func(organization *types.Organization) error { + beforeCalled = true + if organization == nil { + return errors.New("organization is nil") + } + if organization.ID != "org-1" { + t.Fatalf("unexpected organization ID: %s", organization.ID) + } + return nil + }, + AfterCreate: func(organization types.Organization) error { + afterCalled = true + if organization.ID != "org-1" { + t.Fatalf("unexpected organization ID: %s", organization.ID) + } + return nil + }, + }, + }) + + organization := &types.Organization{ID: "org-1", Name: "Acme"} + if err := executor.BeforeCreateOrganization(organization); err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if err := executor.AfterCreateOrganization(*organization); err != nil { + t.Fatalf("expected nil error, got %v", err) + } + + if !beforeCalled { + t.Fatal("expected BeforeCreate hook to be called") + } + if !afterCalled { + t.Fatal("expected AfterCreate hook to be called") + } +} + +func TestOrganizationsHookExecutor_OrganizationCreateHookError(t *testing.T) { + t.Parallel() + + someErr := errors.New("some error") + executor := NewOrganizationsHookExecutor(&types.OrganizationsDatabaseHooksConfig{ + Organizations: &types.OrganizationDatabaseHooksConfig{ + BeforeCreate: func(organization *types.Organization) error { + return someErr + }, + }, + }) + + err := executor.BeforeCreateOrganization(&types.Organization{ID: "org-1"}) + if !errors.Is(err, someErr) { + t.Fatalf("expected someErr error, got %v", err) + } +} + +func TestOrganizationsHookExecutor_MemberUpdateDeleteHooks(t *testing.T) { + t.Parallel() + + var beforeUpdateCalled bool + var afterUpdateCalled bool + var beforeDeleteCalled bool + var afterDeleteCalled bool + + executor := NewOrganizationsHookExecutor(&types.OrganizationsDatabaseHooksConfig{ + Members: &types.OrganizationMemberDatabaseHooksConfig{ + BeforeUpdate: func(member *types.OrganizationMember) error { + beforeUpdateCalled = true + if member == nil || member.ID != "mem-1" { + t.Fatalf("unexpected member in before update hook: %+v", member) + } + return nil + }, + AfterUpdate: func(member types.OrganizationMember) error { + afterUpdateCalled = true + if member.ID != "mem-1" { + t.Fatalf("unexpected member in after update hook: %+v", member) + } + return nil + }, + BeforeDelete: func(member *types.OrganizationMember) error { + beforeDeleteCalled = true + if member == nil || member.ID != "mem-1" { + t.Fatalf("unexpected member in before delete hook: %+v", member) + } + return nil + }, + AfterDelete: func(member types.OrganizationMember) error { + afterDeleteCalled = true + if member.ID != "mem-1" { + t.Fatalf("unexpected member in after delete hook: %+v", member) + } + return nil + }, + }, + }) + + member := &types.OrganizationMember{ID: "mem-1", Role: "member"} + if err := executor.BeforeUpdateOrganizationMember(member); err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if err := executor.AfterUpdateOrganizationMember(*member); err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if err := executor.BeforeDeleteOrganizationMember(member); err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if err := executor.AfterDeleteOrganizationMember(*member); err != nil { + t.Fatalf("expected nil error, got %v", err) + } + + if !beforeUpdateCalled || !afterUpdateCalled || !beforeDeleteCalled || !afterDeleteCalled { + t.Fatal("expected member update and delete hooks to be called") + } +} diff --git a/plugins/organizations/migrations.go b/plugins/organizations/migrations.go new file mode 100644 index 00000000..c149e944 --- /dev/null +++ b/plugins/organizations/migrations.go @@ -0,0 +1,330 @@ +package organizations + +import ( + "context" + + "github.com/uptrace/bun" + + "github.com/Authula/authula/migrations" +) + +func organizationsMigrationsForProvider(provider string) []migrations.Migration { + return migrations.ForProvider(provider, migrations.ProviderVariants{ + "sqlite": func() []migrations.Migration { return []migrations.Migration{organizationsSQLiteInitial()} }, + "postgres": func() []migrations.Migration { return []migrations.Migration{organizationsPostgresInitial()} }, + "mysql": func() []migrations.Migration { return []migrations.Migration{organizationsMySQLInitial()} }, + }) +} + +func organizationsSQLiteInitial() migrations.Migration { + return migrations.Migration{ + Version: "20260321000000_organizations_initial", + Up: func(ctx context.Context, tx bun.Tx) error { + return migrations.ExecStatements( + ctx, + tx, + `PRAGMA foreign_keys = ON;`, + // ----------------------------------- + `CREATE TABLE IF NOT EXISTS organizations ( + id TEXT PRIMARY KEY, + owner_id TEXT NOT NULL, + name VARCHAR(255) NOT NULL, + slug VARCHAR(255) NOT NULL UNIQUE, + logo TEXT, + metadata TEXT, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE + );`, + `CREATE INDEX IF NOT EXISTS idx_organizations_owner_id ON organizations(owner_id);`, + // ----------------------------------- + `CREATE TABLE IF NOT EXISTS organization_members ( + id TEXT PRIMARY KEY, + organization_id TEXT NOT NULL, + user_id TEXT NOT NULL, + role VARCHAR(255) NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, + UNIQUE (organization_id, user_id) + );`, + `CREATE INDEX IF NOT EXISTS idx_organization_members_organization_id ON organization_members(organization_id);`, + `CREATE INDEX IF NOT EXISTS idx_organization_members_user_id ON organization_members(user_id);`, + // ----------------------------------- + `CREATE TABLE IF NOT EXISTS organization_invitations ( + id TEXT PRIMARY KEY, + email VARCHAR(255) NOT NULL, + inviter_id TEXT NOT NULL, + organization_id TEXT NOT NULL, + role VARCHAR(255) NOT NULL, + status VARCHAR(32) NOT NULL DEFAULT 'pending', + expires_at TIMESTAMP NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE, + FOREIGN KEY (inviter_id) REFERENCES users(id) ON DELETE CASCADE, + UNIQUE (organization_id, email) + );`, + `CREATE INDEX IF NOT EXISTS idx_organization_invitations_email ON organization_invitations(email);`, + `CREATE INDEX IF NOT EXISTS idx_organization_invitations_organization_id ON organization_invitations(organization_id);`, + `CREATE INDEX IF NOT EXISTS idx_organization_invitations_inviter_id ON organization_invitations(inviter_id);`, + `CREATE INDEX IF NOT EXISTS idx_organization_invitations_status_expires_at ON organization_invitations(status, expires_at);`, + // ----------------------------------- + `CREATE TABLE IF NOT EXISTS organization_teams ( + id TEXT PRIMARY KEY, + organization_id TEXT NOT NULL, + name VARCHAR(255) NOT NULL, + slug VARCHAR(255) NOT NULL, + description TEXT, + metadata TEXT, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE, + UNIQUE (organization_id, slug) + );`, + `CREATE INDEX IF NOT EXISTS idx_organization_teams_organization_id ON organization_teams(organization_id);`, + `CREATE INDEX IF NOT EXISTS idx_organization_teams_slug ON organization_teams(slug);`, + `CREATE TABLE IF NOT EXISTS organization_team_members ( + id TEXT PRIMARY KEY, + team_id TEXT NOT NULL, + member_id TEXT NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (team_id) REFERENCES organization_teams(id) ON DELETE CASCADE, + FOREIGN KEY (member_id) REFERENCES organization_members(id) ON DELETE CASCADE, + UNIQUE (team_id, member_id) + );`, + `CREATE INDEX IF NOT EXISTS idx_organization_team_members_team_id ON organization_team_members(team_id);`, + `CREATE INDEX IF NOT EXISTS idx_organization_team_members_member_id ON organization_team_members(member_id);`, + // ----------------------------------- + ) + }, + Down: func(ctx context.Context, tx bun.Tx) error { + return migrations.ExecStatements( + ctx, + tx, + `DROP TABLE IF EXISTS organization_team_members;`, + `DROP TABLE IF EXISTS organization_teams;`, + `DROP TABLE IF EXISTS organization_invitations;`, + `DROP TABLE IF EXISTS organization_members;`, + `DROP TABLE IF EXISTS organizations;`, + ) + }, + } +} + +func organizationsPostgresInitial() migrations.Migration { + return migrations.Migration{ + Version: "20260309000000_organizations_initial", + Up: func(ctx context.Context, tx bun.Tx) error { + return migrations.ExecStatements( + ctx, + tx, + `CREATE OR REPLACE FUNCTION organizations_set_updated_at_fn() + RETURNS TRIGGER AS $$ + BEGIN + NEW.updated_at = NOW(); + RETURN NEW; + END; + $$ LANGUAGE plpgsql;`, + `CREATE TABLE IF NOT EXISTS organizations ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + owner_id UUID NOT NULL, + name VARCHAR(255) NOT NULL, + slug VARCHAR(255) NOT NULL UNIQUE, + logo TEXT, + metadata JSONB, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + CONSTRAINT fk_organizations_owner FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE + );`, + `DROP TRIGGER IF EXISTS update_organizations_updated_at_trigger ON organizations;`, + `CREATE TRIGGER update_organizations_updated_at_trigger + BEFORE UPDATE ON organizations + FOR EACH ROW + EXECUTE FUNCTION organizations_set_updated_at_fn();`, + `CREATE INDEX IF NOT EXISTS idx_organizations_owner_id ON organizations(owner_id);`, + `CREATE TABLE IF NOT EXISTS organization_members ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + organization_id UUID NOT NULL, + user_id UUID NOT NULL, + role VARCHAR(255) NOT NULL, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + CONSTRAINT fk_organization_members_organization FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE, + CONSTRAINT fk_organization_members_user FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, + CONSTRAINT uq_organization_members_organization_user UNIQUE (organization_id, user_id) + );`, + `DROP TRIGGER IF EXISTS update_organization_members_updated_at_trigger ON organization_members;`, + `CREATE TRIGGER update_organization_members_updated_at_trigger + BEFORE UPDATE ON organization_members + FOR EACH ROW + EXECUTE FUNCTION organizations_set_updated_at_fn();`, + `CREATE INDEX IF NOT EXISTS idx_organization_members_organization_id ON organization_members(organization_id);`, + `CREATE INDEX IF NOT EXISTS idx_organization_members_user_id ON organization_members(user_id);`, + `CREATE TABLE IF NOT EXISTS organization_invitations ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + organization_id UUID NOT NULL, + inviter_id UUID NOT NULL, + email VARCHAR(255) NOT NULL, + role VARCHAR(255) NOT NULL, + status VARCHAR(32) NOT NULL DEFAULT 'pending', + expires_at TIMESTAMP WITH TIME ZONE NOT NULL, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + CONSTRAINT fk_organization_invitations_organization FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE, + CONSTRAINT fk_organization_invitations_inviter FOREIGN KEY (inviter_id) REFERENCES users(id) ON DELETE CASCADE, + CONSTRAINT chk_organization_invitations_status CHECK (status IN ('pending', 'accepted', 'rejected', 'revoked', 'expired')), + CONSTRAINT uq_organization_invitations_organization_email UNIQUE (organization_id, email) + );`, + `DROP TRIGGER IF EXISTS update_organization_invitations_updated_at_trigger ON organization_invitations;`, + `CREATE TRIGGER update_organization_invitations_updated_at_trigger + BEFORE UPDATE ON organization_invitations + FOR EACH ROW + EXECUTE FUNCTION organizations_set_updated_at_fn();`, + `CREATE INDEX IF NOT EXISTS idx_organization_invitations_organization_id ON organization_invitations(organization_id);`, + `CREATE INDEX IF NOT EXISTS idx_organization_invitations_inviter_id ON organization_invitations(inviter_id);`, + `CREATE INDEX IF NOT EXISTS idx_organization_invitations_email ON organization_invitations(email);`, + `CREATE INDEX IF NOT EXISTS idx_organization_invitations_status_expires_at ON organization_invitations(status, expires_at);`, + `CREATE TABLE IF NOT EXISTS organization_teams ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + organization_id UUID NOT NULL, + name VARCHAR(255) NOT NULL, + slug VARCHAR(255) NOT NULL, + description TEXT, + metadata JSONB, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + CONSTRAINT fk_organization_teams_organization FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE, + CONSTRAINT uq_organization_teams_organization_slug UNIQUE (organization_id, slug) + );`, + `DROP TRIGGER IF EXISTS update_organization_teams_updated_at_trigger ON organization_teams;`, + `CREATE TRIGGER update_organization_teams_updated_at_trigger + BEFORE UPDATE ON organization_teams + FOR EACH ROW + EXECUTE FUNCTION organizations_set_updated_at_fn();`, + `CREATE INDEX IF NOT EXISTS idx_organization_teams_organization_id ON organization_teams(organization_id);`, + `CREATE INDEX IF NOT EXISTS idx_organization_teams_slug ON organization_teams(slug);`, + `CREATE TABLE IF NOT EXISTS organization_team_members ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + team_id UUID NOT NULL, + member_id UUID NOT NULL, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + CONSTRAINT fk_organization_team_members_team FOREIGN KEY (team_id) REFERENCES organization_teams(id) ON DELETE CASCADE, + CONSTRAINT fk_organization_team_members_member FOREIGN KEY (member_id) REFERENCES organization_members(id) ON DELETE CASCADE, + CONSTRAINT uq_organization_team_members_team_member UNIQUE (team_id, member_id) + );`, + `CREATE INDEX IF NOT EXISTS idx_organization_team_members_team_id ON organization_team_members(team_id);`, + `CREATE INDEX IF NOT EXISTS idx_organization_team_members_member_id ON organization_team_members(member_id);`, + ) + }, + Down: func(ctx context.Context, tx bun.Tx) error { + return migrations.ExecStatements( + ctx, + tx, + `DROP TABLE IF EXISTS organization_team_members;`, + `DROP TRIGGER IF EXISTS update_organization_teams_updated_at_trigger ON organization_teams;`, + `DROP TABLE IF EXISTS organization_teams;`, + `DROP TRIGGER IF EXISTS update_organization_invitations_updated_at_trigger ON organization_invitations;`, + `DROP TABLE IF EXISTS organization_invitations;`, + `DROP TRIGGER IF EXISTS update_organization_members_updated_at_trigger ON organization_members;`, + `DROP TABLE IF EXISTS organization_members;`, + `DROP TRIGGER IF EXISTS update_organizations_updated_at_trigger ON organizations;`, + `DROP TABLE IF EXISTS organizations;`, + `DROP FUNCTION IF EXISTS organizations_set_updated_at_fn();`, + ) + }, + } +} + +func organizationsMySQLInitial() migrations.Migration { + return migrations.Migration{ + Version: "20260309000000_organizations_initial", + Up: func(ctx context.Context, tx bun.Tx) error { + return migrations.ExecStatements( + ctx, + tx, + `CREATE TABLE IF NOT EXISTS organizations ( + id BINARY(16) NOT NULL PRIMARY KEY, + owner_id BINARY(16) NOT NULL, + name VARCHAR(255) NOT NULL, + slug VARCHAR(255) NOT NULL UNIQUE, + logo TEXT NULL, + metadata JSON NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + CONSTRAINT fk_organizations_owner FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE, + INDEX idx_organizations_owner_id (owner_id) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;`, + `CREATE TABLE IF NOT EXISTS organization_members ( + id BINARY(16) NOT NULL PRIMARY KEY, + organization_id BINARY(16) NOT NULL, + user_id BINARY(16) NOT NULL, + role VARCHAR(255) NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + CONSTRAINT fk_organization_members_organization FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE, + CONSTRAINT fk_organization_members_user FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, + CONSTRAINT uq_organization_members_organization_user UNIQUE (organization_id, user_id), + INDEX idx_organization_members_organization_id (organization_id), + INDEX idx_organization_members_user_id (user_id) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;`, + `CREATE TABLE IF NOT EXISTS organization_invitations ( + id BINARY(16) NOT NULL PRIMARY KEY, + organization_id BINARY(16) NOT NULL, + inviter_id BINARY(16) NOT NULL, + email VARCHAR(255) NOT NULL, + role VARCHAR(255) NOT NULL, + status VARCHAR(32) NOT NULL DEFAULT 'pending', + expires_at TIMESTAMP NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + CONSTRAINT fk_organization_invitations_organization FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE, + CONSTRAINT fk_organization_invitations_inviter FOREIGN KEY (inviter_id) REFERENCES users(id) ON DELETE CASCADE, + CONSTRAINT chk_organization_invitations_status CHECK (status IN ('pending', 'accepted', 'rejected', 'revoked', 'expired')), + CONSTRAINT uq_organization_invitations_organization_email UNIQUE (organization_id, email), + INDEX idx_organization_invitations_organization_id (organization_id), + INDEX idx_organization_invitations_inviter_id (inviter_id), + INDEX idx_organization_invitations_email (email), + INDEX idx_organization_invitations_status_expires_at (status, expires_at) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;`, + `CREATE TABLE IF NOT EXISTS organization_teams ( + id BINARY(16) NOT NULL PRIMARY KEY, + organization_id BINARY(16) NOT NULL, + name VARCHAR(255) NOT NULL, + slug VARCHAR(255) NOT NULL, + description TEXT NULL, + metadata JSON NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + CONSTRAINT fk_organization_teams_organization FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE, + CONSTRAINT uq_organization_teams_organization_slug UNIQUE (organization_id, slug), + INDEX idx_organization_teams_organization_id (organization_id), + INDEX idx_organization_teams_slug (slug) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;`, + `CREATE TABLE IF NOT EXISTS organization_team_members ( + id BINARY(16) NOT NULL PRIMARY KEY, + team_id BINARY(16) NOT NULL, + member_id BINARY(16) NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + CONSTRAINT fk_organization_team_members_team FOREIGN KEY (team_id) REFERENCES organization_teams(id) ON DELETE CASCADE, + CONSTRAINT fk_organization_team_members_member FOREIGN KEY (member_id) REFERENCES organization_members(id) ON DELETE CASCADE, + CONSTRAINT uq_organization_team_members_team_member UNIQUE (team_id, member_id), + INDEX idx_organization_team_members_team_id (team_id), + INDEX idx_organization_team_members_member_id (member_id) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;`, + ) + }, + Down: func(ctx context.Context, tx bun.Tx) error { + return migrations.ExecStatements( + ctx, + tx, + `DROP TABLE IF EXISTS organization_team_members;`, + `DROP TABLE IF EXISTS organization_teams;`, + `DROP TABLE IF EXISTS organization_invitations;`, + `DROP TABLE IF EXISTS organization_members;`, + `DROP TABLE IF EXISTS organizations;`, + ) + }, + } +} diff --git a/plugins/organizations/plugin.go b/plugins/organizations/plugin.go new file mode 100644 index 00000000..cf7cd425 --- /dev/null +++ b/plugins/organizations/plugin.go @@ -0,0 +1,115 @@ +package organizations + +import ( + "fmt" + + "github.com/Authula/authula/internal/util" + "github.com/Authula/authula/migrations" + "github.com/Authula/authula/models" + "github.com/Authula/authula/plugins/organizations/repositories" + "github.com/Authula/authula/plugins/organizations/services" + "github.com/Authula/authula/plugins/organizations/types" + "github.com/Authula/authula/plugins/organizations/usecases" + rootservices "github.com/Authula/authula/services" +) + +type OrganizationsPlugin struct { + globalConfig *models.Config + pluginConfig types.OrganizationsPluginConfig + ctx *models.PluginContext + logger models.Logger + databaseHooks *OrganizationsHookExecutor + organizationRepo repositories.OrganizationRepository + organizationService *services.OrganizationService + invitationRepo repositories.OrganizationInvitationRepository + invitationService *services.OrganizationInvitationService + memberRepo repositories.OrganizationMemberRepository + memberService *services.OrganizationMemberService + teamRepo repositories.OrganizationTeamRepository + teamMemberRepo repositories.OrganizationTeamMemberRepository + teamService *services.OrganizationTeamService + organizationUseCase *usecases.OrganizationUseCase + invitationUseCase *usecases.OrganizationInvitationUseCase + memberUseCase *usecases.OrganizationMemberUseCase + teamUseCase *usecases.OrganizationTeamUseCase + Api *API +} + +func New(config types.OrganizationsPluginConfig) *OrganizationsPlugin { + config.ApplyDefaults() + return &OrganizationsPlugin{pluginConfig: config} +} + +func (p *OrganizationsPlugin) Metadata() models.PluginMetadata { + return models.PluginMetadata{ + ID: models.PluginOrganizations.String(), + Version: "1.0.0", + Description: "Provides organization, teams, members, and invitations.", + } +} + +func (p *OrganizationsPlugin) Config() any { + return p.pluginConfig +} + +func (p *OrganizationsPlugin) Init(ctx *models.PluginContext) error { + p.ctx = ctx + p.logger = ctx.Logger + p.globalConfig = ctx.GetConfig() + + if err := util.LoadPluginConfig(p.globalConfig, p.Metadata().ID, &p.pluginConfig); err != nil { + return err + } + + p.pluginConfig.ApplyDefaults() + p.databaseHooks = NewOrganizationsHookExecutor(p.pluginConfig.DatabaseHooks) + + userService, ok := ctx.ServiceRegistry.Get(models.ServiceUser.String()).(rootservices.UserService) + if !ok { + return fmt.Errorf("user service not available in service registry") + } + + var mailerService rootservices.MailerService + if rawMailerService := ctx.ServiceRegistry.Get(models.ServiceMailer.String()); rawMailerService != nil { + if typedMailerService, ok := rawMailerService.(rootservices.MailerService); ok { + mailerService = typedMailerService + } else { + p.logger.Warn("mailer service has unexpected type, skipping organization invitation email delivery") + } + } else { + p.logger.Warn("mailer service not available, skipping organization invitation email delivery") + } + + p.organizationRepo = repositories.NewBunOrganizationRepository(ctx.DB) + p.organizationService = services.NewOrganizationService(p.organizationRepo, p.databaseHooks) + p.invitationRepo = repositories.NewBunOrganizationInvitationRepository(ctx.DB) + p.memberRepo = repositories.NewBunOrganizationMemberRepository(ctx.DB) + p.teamRepo = repositories.NewBunOrganizationTeamRepository(ctx.DB) + p.teamMemberRepo = repositories.NewBunOrganizationTeamMemberRepository(ctx.DB) + p.invitationService = services.NewOrganizationInvitationService(ctx.DB, p.globalConfig, &p.pluginConfig, p.logger, userService, p.organizationRepo, p.invitationRepo, p.memberRepo, mailerService, ctx.EventBus, p.databaseHooks, p.databaseHooks) + p.memberService = services.NewOrganizationMemberService(userService, p.organizationRepo, p.memberRepo, p.databaseHooks) + p.teamService = services.NewOrganizationTeamService(p.organizationRepo, p.teamRepo, p.memberRepo, p.teamMemberRepo, p.databaseHooks, p.databaseHooks) + p.organizationUseCase = usecases.NewOrganizationUseCase(p.organizationService) + p.invitationUseCase = usecases.NewOrganizationInvitationUseCase(p.invitationService) + p.memberUseCase = usecases.NewOrganizationMemberUseCase(p.memberService) + p.teamUseCase = usecases.NewOrganizationTeamUseCase(p.teamService) + p.Api = BuildAPI(p.organizationUseCase, p.invitationUseCase, p.memberUseCase, p.teamUseCase) + + return nil +} + +func (p *OrganizationsPlugin) Migrations(provider string) []migrations.Migration { + return organizationsMigrationsForProvider(provider) +} + +func (p *OrganizationsPlugin) DependsOn() []string { + return nil +} + +func (p *OrganizationsPlugin) Routes() []models.Route { + return Routes(p) +} + +func (p *OrganizationsPlugin) Close() error { + return nil +} diff --git a/plugins/organizations/repositories/bun_organization_invitation_repository.go b/plugins/organizations/repositories/bun_organization_invitation_repository.go new file mode 100644 index 00000000..f23b20d7 --- /dev/null +++ b/plugins/organizations/repositories/bun_organization_invitation_repository.go @@ -0,0 +1,98 @@ +package repositories + +import ( + "context" + "database/sql" + "time" + + "github.com/uptrace/bun" + + "github.com/Authula/authula/plugins/organizations/types" +) + +type BunOrganizationInvitationRepository struct { + db bun.IDB +} + +func NewBunOrganizationInvitationRepository(db bun.IDB) OrganizationInvitationRepository { + return &BunOrganizationInvitationRepository{db: db} +} + +func (r *BunOrganizationInvitationRepository) Create(ctx context.Context, invitation *types.OrganizationInvitation) (*types.OrganizationInvitation, error) { + err := r.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + _, err := tx.NewInsert().Model(invitation).Exec(ctx) + if err != nil { + return err + } + + return tx.NewSelect().Model(invitation).WherePK().Scan(ctx) + }) + if err != nil { + return nil, err + } + + return invitation, nil +} + +func (r *BunOrganizationInvitationRepository) GetByID(ctx context.Context, invitationID string) (*types.OrganizationInvitation, error) { + invitation := new(types.OrganizationInvitation) + err := r.db.NewSelect().Model(invitation).Where("id = ?", invitationID).Scan(ctx) + if err == sql.ErrNoRows { + return nil, nil + } + return invitation, err +} + +func (r *BunOrganizationInvitationRepository) GetByOrganizationIDAndEmail(ctx context.Context, organizationID, email string) (*types.OrganizationInvitation, error) { + invitation := new(types.OrganizationInvitation) + err := r.db.NewSelect().Model(invitation). + Where("organization_id = ? AND email = ?", organizationID, email). + Scan(ctx) + if err == sql.ErrNoRows { + return nil, nil + } + return invitation, err +} + +func (r *BunOrganizationInvitationRepository) GetAllByOrganizationID(ctx context.Context, organizationID string) ([]types.OrganizationInvitation, error) { + invitations := make([]types.OrganizationInvitation, 0) + err := r.db.NewSelect().Model(&invitations). + Where("organization_id = ?", organizationID). + OrderExpr("created_at DESC"). + Scan(ctx) + if err == sql.ErrNoRows { + return []types.OrganizationInvitation{}, nil + } + return invitations, err +} + +func (r *BunOrganizationInvitationRepository) GetAllPendingByEmail(ctx context.Context, email string) ([]types.OrganizationInvitation, error) { + invites := make([]types.OrganizationInvitation, 0) + err := r.db.NewSelect().Model(&invites). + Where("email = ? AND status = ? AND expires_at > ?", email, types.OrganizationInvitationStatusPending, time.Now().UTC()). + Scan(ctx) + if err == sql.ErrNoRows { + return []types.OrganizationInvitation{}, nil + } + return invites, err +} + +func (r *BunOrganizationInvitationRepository) Update(ctx context.Context, invitation *types.OrganizationInvitation) (*types.OrganizationInvitation, error) { + err := r.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + _, err := tx.NewUpdate().Model(invitation).WherePK().Exec(ctx) + if err != nil { + return err + } + + return tx.NewSelect().Model(invitation).WherePK().Scan(ctx) + }) + if err != nil { + return nil, err + } + + return invitation, nil +} + +func (r *BunOrganizationInvitationRepository) WithTx(tx bun.IDB) OrganizationInvitationRepository { + return &BunOrganizationInvitationRepository{db: tx} +} diff --git a/plugins/organizations/repositories/bun_organization_invitation_repository_test.go b/plugins/organizations/repositories/bun_organization_invitation_repository_test.go new file mode 100644 index 00000000..14de3521 --- /dev/null +++ b/plugins/organizations/repositories/bun_organization_invitation_repository_test.go @@ -0,0 +1,307 @@ +package repositories + +import ( + "context" + "database/sql" + "testing" + "time" + + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/require" + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect/sqlitedialect" + + "github.com/Authula/authula/plugins/organizations/types" +) + +func newOrganizationInvitationRepositoryTestDB(t *testing.T) *bun.DB { + t.Helper() + + sqlDB, err := sql.Open("sqlite3", ":memory:") + require.NoError(t, err) + sqlDB.SetMaxOpenConns(1) + t.Cleanup(func() { + _ = sqlDB.Close() + }) + + db := bun.NewDB(sqlDB, sqlitedialect.New()) + t.Cleanup(func() { + _ = db.Close() + }) + + _, err = db.ExecContext(context.Background(), ` + CREATE TABLE organization_invitations ( + id TEXT PRIMARY KEY, + email TEXT NOT NULL, + inviter_id TEXT NOT NULL, + organization_id TEXT NOT NULL, + role TEXT NOT NULL, + status TEXT NOT NULL, + expires_at TIMESTAMP NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + ); + `) + require.NoError(t, err) + + return db +} + +func TestBunOrganizationInvitationRepository_Create(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + invitation *types.OrganizationInvitation + expectStatus types.OrganizationInvitationStatus + }{ + { + name: "pending", + invitation: &types.OrganizationInvitation{ + ID: "inv-1", + Email: "user@example.com", + InviterID: "user-1", + OrganizationID: "org-1", + Role: "member", + Status: types.OrganizationInvitationStatusPending, + ExpiresAt: time.Now().UTC().Add(time.Hour), + }, + expectStatus: types.OrganizationInvitationStatusPending, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + db := newOrganizationInvitationRepositoryTestDB(t) + repo := NewBunOrganizationInvitationRepository(db) + + created, err := repo.Create(context.Background(), tt.invitation) + require.NoError(t, err) + require.NotNil(t, created) + require.Equal(t, tt.expectStatus, created.Status) + require.Equal(t, tt.invitation.ID, created.ID) + }) + } +} + +func TestBunOrganizationInvitationRepository_GetByOrganizationIDAndEmail(t *testing.T) { + t.Parallel() + + db := newOrganizationInvitationRepositoryTestDB(t) + repo := NewBunOrganizationInvitationRepository(db) + ctx := context.Background() + + _, err := repo.Create(ctx, &types.OrganizationInvitation{ + ID: "inv-1", + Email: "user@example.com", + InviterID: "user-1", + OrganizationID: "org-1", + Role: "member", + Status: types.OrganizationInvitationStatusPending, + ExpiresAt: time.Now().UTC().Add(time.Hour), + }) + require.NoError(t, err) + + tests := []struct { + name string + organizationID string + email string + expectFound bool + }{ + {name: "found", organizationID: "org-1", email: "user@example.com", expectFound: true}, + {name: "not found", organizationID: "org-1", email: "missing@example.com"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + found, err := repo.GetByOrganizationIDAndEmail(ctx, tt.organizationID, tt.email) + require.NoError(t, err) + if tt.expectFound { + require.NotNil(t, found) + require.Equal(t, "inv-1", found.ID) + return + } + require.Nil(t, found) + }) + } +} + +func TestBunOrganizationInvitationRepository_GetAllPendingByEmail(t *testing.T) { + t.Parallel() + + db := newOrganizationInvitationRepositoryTestDB(t) + repo := NewBunOrganizationInvitationRepository(db) + ctx := context.Background() + + _, err := repo.Create(ctx, &types.OrganizationInvitation{ + ID: "inv-1", + Email: "user@example.com", + InviterID: "user-1", + OrganizationID: "org-1", + Role: "member", + Status: types.OrganizationInvitationStatusPending, + ExpiresAt: time.Now().UTC().Add(time.Hour), + }) + require.NoError(t, err) + + _, err = repo.Create(ctx, &types.OrganizationInvitation{ + ID: "inv-2", + Email: "user@example.com", + InviterID: "user-1", + OrganizationID: "org-1", + Role: "member", + Status: types.OrganizationInvitationStatusAccepted, + ExpiresAt: time.Now().UTC().Add(time.Hour), + }) + require.NoError(t, err) + + tests := []struct { + name string + email string + expectPending int + }{ + {name: "pending only", email: "user@example.com", expectPending: 1}, + {name: "missing", email: "missing@example.com", expectPending: 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + pending, err := repo.GetAllPendingByEmail(ctx, tt.email) + require.NoError(t, err) + require.Len(t, pending, tt.expectPending) + }) + } +} + +func TestBunOrganizationInvitationRepository_GetAllByOrganizationID(t *testing.T) { + t.Parallel() + + db := newOrganizationInvitationRepositoryTestDB(t) + repo := NewBunOrganizationInvitationRepository(db) + ctx := context.Background() + + _, err := repo.Create(ctx, &types.OrganizationInvitation{ + ID: "inv-1", + Email: "user@example.com", + InviterID: "user-1", + OrganizationID: "org-1", + Role: "member", + Status: types.OrganizationInvitationStatusPending, + ExpiresAt: time.Now().UTC().Add(time.Hour), + }) + require.NoError(t, err) + + _, err = repo.Create(ctx, &types.OrganizationInvitation{ + ID: "inv-2", + Email: "other@example.com", + InviterID: "user-1", + OrganizationID: "org-1", + Role: "member", + Status: types.OrganizationInvitationStatusPending, + ExpiresAt: time.Now().UTC().Add(time.Hour), + }) + require.NoError(t, err) + + invitations, err := repo.GetAllByOrganizationID(ctx, "org-1") + require.NoError(t, err) + require.Len(t, invitations, 2) +} + +func TestBunOrganizationInvitationRepository_Update(t *testing.T) { + t.Parallel() + + db := newOrganizationInvitationRepositoryTestDB(t) + repo := NewBunOrganizationInvitationRepository(db) + ctx := context.Background() + + created, err := repo.Create(ctx, &types.OrganizationInvitation{ + ID: "inv-1", + Email: "user@example.com", + InviterID: "user-1", + OrganizationID: "org-1", + Role: "member", + Status: types.OrganizationInvitationStatusPending, + ExpiresAt: time.Now().UTC().Add(time.Hour), + }) + require.NoError(t, err) + + created.Status = types.OrganizationInvitationStatusAccepted + updated, err := repo.Update(ctx, created) + require.NoError(t, err) + require.NotNil(t, updated) + require.Equal(t, types.OrganizationInvitationStatusAccepted, updated.Status) +} + +func TestBunOrganizationInvitationRepository_GetByID(t *testing.T) { + t.Parallel() + + db := newOrganizationInvitationRepositoryTestDB(t) + repo := NewBunOrganizationInvitationRepository(db) + ctx := context.Background() + + _, err := repo.Create(ctx, &types.OrganizationInvitation{ + ID: "inv-1", + Email: "user@example.com", + InviterID: "user-1", + OrganizationID: "org-1", + Role: "member", + Status: types.OrganizationInvitationStatusPending, + ExpiresAt: time.Now().UTC().Add(time.Hour), + }) + require.NoError(t, err) + + tests := []struct { + name string + invitationID string + expectFound bool + }{ + {name: "found", invitationID: "inv-1", expectFound: true}, + {name: "not found", invitationID: "missing"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + found, err := repo.GetByID(ctx, tt.invitationID) + require.NoError(t, err) + if tt.expectFound { + require.NotNil(t, found) + require.Equal(t, "inv-1", found.ID) + return + } + require.Nil(t, found) + }) + } +} + +func TestBunOrganizationInvitationRepository_WithTx(t *testing.T) { + t.Parallel() + + db := newOrganizationInvitationRepositoryTestDB(t) + repo := NewBunOrganizationInvitationRepository(db) + ctx := context.Background() + + txRepo := repo.WithTx(db) + require.NotNil(t, txRepo) + require.IsType(t, &BunOrganizationInvitationRepository{}, txRepo) + + created, err := txRepo.Create(ctx, &types.OrganizationInvitation{ + ID: "inv-1", + Email: "user@example.com", + InviterID: "user-1", + OrganizationID: "org-1", + Role: "member", + Status: types.OrganizationInvitationStatusPending, + ExpiresAt: time.Now().UTC().Add(time.Hour), + }) + require.NoError(t, err) + require.NotNil(t, created) + require.Equal(t, "inv-1", created.ID) +} diff --git a/plugins/organizations/repositories/bun_organization_member_repository.go b/plugins/organizations/repositories/bun_organization_member_repository.go new file mode 100644 index 00000000..124a5456 --- /dev/null +++ b/plugins/organizations/repositories/bun_organization_member_repository.go @@ -0,0 +1,91 @@ +package repositories + +import ( + "context" + "database/sql" + + "github.com/uptrace/bun" + + "github.com/Authula/authula/plugins/organizations/types" +) + +type BunOrganizationMemberRepository struct { + db bun.IDB +} + +func NewBunOrganizationMemberRepository(db bun.IDB) OrganizationMemberRepository { + return &BunOrganizationMemberRepository{db: db} +} + +func (r *BunOrganizationMemberRepository) Create(ctx context.Context, member *types.OrganizationMember) (*types.OrganizationMember, error) { + err := r.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + _, err := tx.NewInsert().Model(member).Exec(ctx) + if err != nil { + return err + } + + return tx.NewSelect().Model(member).WherePK().Scan(ctx) + }) + if err != nil { + return nil, err + } + + return member, nil +} + +func (r *BunOrganizationMemberRepository) GetAllByOrganizationID(ctx context.Context, organizationID string, page int, limit int) ([]types.OrganizationMember, error) { + members := make([]types.OrganizationMember, 0) + err := r.db.NewSelect().Model(&members). + Where("organization_id = ?", organizationID). + Offset((page - 1) * limit).Limit(limit). + Scan(ctx) + if err == sql.ErrNoRows { + return []types.OrganizationMember{}, nil + } + return members, err +} + +func (r *BunOrganizationMemberRepository) GetByOrganizationIDAndUserID(ctx context.Context, organizationID string, userID string) (*types.OrganizationMember, error) { + member := new(types.OrganizationMember) + err := r.db.NewSelect().Model(member). + Where("organization_id = ? AND user_id = ?", organizationID, userID). + Scan(ctx) + if err == sql.ErrNoRows { + return nil, nil + } + return member, err +} + +func (r *BunOrganizationMemberRepository) GetByID(ctx context.Context, memberID string) (*types.OrganizationMember, error) { + member := new(types.OrganizationMember) + err := r.db.NewSelect().Model(member).Where("id = ?", memberID).Scan(ctx) + if err == sql.ErrNoRows { + return nil, nil + } + return member, err +} + +func (r *BunOrganizationMemberRepository) Update(ctx context.Context, member *types.OrganizationMember) (*types.OrganizationMember, error) { + err := r.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + _, err := tx.NewUpdate().Model(member).WherePK().Exec(ctx) + if err != nil { + return err + } + + return tx.NewSelect().Model(member).WherePK().Scan(ctx) + }) + if err != nil { + return nil, err + } + + return member, nil +} + +func (r *BunOrganizationMemberRepository) Delete(ctx context.Context, memberID string) error { + _, err := r.db.NewDelete().Model(&types.OrganizationMember{}).Where("id = ?", memberID).Exec(ctx) + return err +} + +func (r *BunOrganizationMemberRepository) WithTx(tx bun.IDB) OrganizationMemberRepository { + return &BunOrganizationMemberRepository{db: tx} +} diff --git a/plugins/organizations/repositories/bun_organization_member_repository_test.go b/plugins/organizations/repositories/bun_organization_member_repository_test.go new file mode 100644 index 00000000..04c33b8c --- /dev/null +++ b/plugins/organizations/repositories/bun_organization_member_repository_test.go @@ -0,0 +1,315 @@ +package repositories + +import ( + "context" + "database/sql" + "testing" + + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/require" + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect/sqlitedialect" + + "github.com/Authula/authula/plugins/organizations/types" +) + +func newOrganizationMemberRepositoryTestDB(t *testing.T) *bun.DB { + t.Helper() + + sqlDB, err := sql.Open("sqlite3", ":memory:") + require.NoError(t, err) + sqlDB.SetMaxOpenConns(1) + t.Cleanup(func() { _ = sqlDB.Close() }) + + db := bun.NewDB(sqlDB, sqlitedialect.New()) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.ExecContext(context.Background(), ` + CREATE TABLE organization_members (id TEXT PRIMARY KEY, organization_id TEXT NOT NULL, user_id TEXT NOT NULL, role TEXT NOT NULL, created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP); + `) + require.NoError(t, err) + + return db +} + +func TestBunOrganizationMemberRepository_CreateGetUpdateDelete(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + run func(*testing.T, OrganizationMemberRepository, context.Context, *types.OrganizationMember) + }{ + { + name: "get by id returns created member", + run: func(t *testing.T, orgMemberRepo OrganizationMemberRepository, ctx context.Context, created *types.OrganizationMember) { + t.Helper() + found, err := orgMemberRepo.GetByID(ctx, "mem-1") + require.NoError(t, err) + require.NotNil(t, found) + require.Equal(t, created.Role, found.Role) + }, + }, + { + name: "list by organization returns created member", + run: func(t *testing.T, orgMemberRepo OrganizationMemberRepository, ctx context.Context, created *types.OrganizationMember) { + t.Helper() + members, err := orgMemberRepo.GetAllByOrganizationID(ctx, "org-1", 1, 10) + require.NoError(t, err) + require.Len(t, members, 1) + require.Equal(t, created.ID, members[0].ID) + }, + }, + { + name: "update persists changed role", + run: func(t *testing.T, orgMemberRepo OrganizationMemberRepository, ctx context.Context, created *types.OrganizationMember) { + t.Helper() + created.Role = "admin" + updated, err := orgMemberRepo.Update(ctx, created) + require.NoError(t, err) + require.Equal(t, "admin", updated.Role) + }, + }, + { + name: "delete removes member", + run: func(t *testing.T, orgMemberRepo OrganizationMemberRepository, ctx context.Context, created *types.OrganizationMember) { + t.Helper() + require.NoError(t, orgMemberRepo.Delete(ctx, created.ID)) + found, err := orgMemberRepo.GetByID(ctx, created.ID) + require.NoError(t, err) + require.Nil(t, found) + }, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + db := newOrganizationMemberRepositoryTestDB(t) + orgMemberRepo := NewBunOrganizationMemberRepository(db) + ctx := context.Background() + created, err := orgMemberRepo.Create(ctx, &types.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}) + require.NoError(t, err) + require.NotNil(t, created) + + tt.run(t, orgMemberRepo, ctx, created) + }) + } +} + +func TestBunOrganizationMemberRepository_Create(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + member *types.OrganizationMember + }{ + {name: "member", member: &types.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-1", Role: "member"}}, + {name: "admin", member: &types.OrganizationMember{ID: "mem-2", OrganizationID: "org-1", UserID: "user-2", Role: "admin"}}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + repo := NewBunOrganizationMemberRepository(newOrganizationMemberRepositoryTestDB(t)) + created, err := repo.Create(context.Background(), tt.member) + require.NoError(t, err) + require.NotNil(t, created) + require.Equal(t, tt.member.ID, created.ID) + require.Equal(t, tt.member.Role, created.Role) + }) + } +} + +func TestBunOrganizationMemberRepository_GetAllByOrganizationID(t *testing.T) { + t.Parallel() + + db := newOrganizationMemberRepositoryTestDB(t) + repo := NewBunOrganizationMemberRepository(db) + ctx := context.Background() + + _, err := repo.Create(ctx, &types.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-1", Role: "member"}) + require.NoError(t, err) + _, err = repo.Create(ctx, &types.OrganizationMember{ID: "mem-2", OrganizationID: "org-1", UserID: "user-2", Role: "admin"}) + require.NoError(t, err) + + tests := []struct { + name string + organizationID string + page int + limit int + expectCount int + }{ + {name: "first page", organizationID: "org-1", page: 1, limit: 10, expectCount: 2}, + {name: "empty result", organizationID: "org-2", page: 1, limit: 10, expectCount: 0}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + members, err := repo.GetAllByOrganizationID(ctx, tt.organizationID, tt.page, tt.limit) + require.NoError(t, err) + require.Len(t, members, tt.expectCount) + }) + } +} + +func TestBunOrganizationMemberRepository_GetByOrganizationIDAndUserID(t *testing.T) { + t.Parallel() + + db := newOrganizationMemberRepositoryTestDB(t) + repo := NewBunOrganizationMemberRepository(db) + ctx := context.Background() + + _, err := repo.Create(ctx, &types.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-1", Role: "member"}) + require.NoError(t, err) + + tests := []struct { + name string + organizationID string + userID string + expectFound bool + }{ + {name: "found", organizationID: "org-1", userID: "user-1", expectFound: true}, + {name: "not found", organizationID: "org-1", userID: "user-2"}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + found, err := repo.GetByOrganizationIDAndUserID(ctx, tt.organizationID, tt.userID) + require.NoError(t, err) + if tt.expectFound { + require.NotNil(t, found) + require.Equal(t, "mem-1", found.ID) + return + } + require.Nil(t, found) + }) + } +} + +func TestBunOrganizationMemberRepository_GetByID(t *testing.T) { + t.Parallel() + + db := newOrganizationMemberRepositoryTestDB(t) + repo := NewBunOrganizationMemberRepository(db) + ctx := context.Background() + + _, err := repo.Create(ctx, &types.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-1", Role: "member"}) + require.NoError(t, err) + + tests := []struct { + name string + memberID string + expectFound bool + }{ + {name: "found", memberID: "mem-1", expectFound: true}, + {name: "not found", memberID: "missing"}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + found, err := repo.GetByID(ctx, tt.memberID) + require.NoError(t, err) + if tt.expectFound { + require.NotNil(t, found) + require.Equal(t, "mem-1", found.ID) + return + } + require.Nil(t, found) + }) + } +} + +func TestBunOrganizationMemberRepository_Update(t *testing.T) { + t.Parallel() + + db := newOrganizationMemberRepositoryTestDB(t) + repo := NewBunOrganizationMemberRepository(db) + ctx := context.Background() + + created, err := repo.Create(ctx, &types.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-1", Role: "member"}) + require.NoError(t, err) + + tests := []struct { + name string + run func(*testing.T, *types.OrganizationMember) + }{ + { + name: "change role", + run: func(t *testing.T, member *types.OrganizationMember) { + t.Helper() + member.Role = "admin" + updated, err := repo.Update(ctx, member) + require.NoError(t, err) + require.Equal(t, "admin", updated.Role) + }, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + tt.run(t, created) + }) + } +} + +func TestBunOrganizationMemberRepository_Delete(t *testing.T) { + t.Parallel() + + db := newOrganizationMemberRepositoryTestDB(t) + repo := NewBunOrganizationMemberRepository(db) + ctx := context.Background() + + _, err := repo.Create(ctx, &types.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-1", Role: "member"}) + require.NoError(t, err) + + tests := []struct { + name string + memberID string + }{ + {name: "delete existing", memberID: "mem-1"}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + require.NoError(t, repo.Delete(ctx, tt.memberID)) + found, err := repo.GetByID(ctx, tt.memberID) + require.NoError(t, err) + require.Nil(t, found) + }) + } +} + +func TestBunOrganizationMemberRepository_WithTx(t *testing.T) { + t.Parallel() + + db := newOrganizationMemberRepositoryTestDB(t) + repo := NewBunOrganizationMemberRepository(db) + ctx := context.Background() + + txRepo := repo.WithTx(db) + require.NotNil(t, txRepo) + require.IsType(t, &BunOrganizationMemberRepository{}, txRepo) + + created, err := txRepo.Create(ctx, &types.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-1", Role: "member"}) + require.NoError(t, err) + require.NotNil(t, created) + require.Equal(t, "mem-1", created.ID) +} diff --git a/plugins/organizations/repositories/bun_organization_repository.go b/plugins/organizations/repositories/bun_organization_repository.go new file mode 100644 index 00000000..59465835 --- /dev/null +++ b/plugins/organizations/repositories/bun_organization_repository.go @@ -0,0 +1,90 @@ +package repositories + +import ( + "context" + "database/sql" + + "github.com/uptrace/bun" + + "github.com/Authula/authula/plugins/organizations/types" +) + +type BunOrganizationRepository struct { + db bun.IDB +} + +func NewBunOrganizationRepository(db bun.IDB) OrganizationRepository { + return &BunOrganizationRepository{db: db} +} + +func (r *BunOrganizationRepository) Create(ctx context.Context, organization *types.Organization) (*types.Organization, error) { + if len(organization.Metadata) == 0 { + organization.Metadata = []byte("{}") + } + + err := r.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + _, err := tx.NewInsert().Model(organization).Exec(ctx) + if err != nil { + return err + } + + return tx.NewSelect().Model(organization).WherePK().Scan(ctx) + }) + if err != nil { + return nil, err + } + + return organization, nil +} + +func (r *BunOrganizationRepository) GetByID(ctx context.Context, organizationID string) (*types.Organization, error) { + organization := new(types.Organization) + err := r.db.NewSelect().Model(organization).Where("id = ?", organizationID).Scan(ctx) + if err == sql.ErrNoRows { + return nil, nil + } + return organization, err +} + +func (r *BunOrganizationRepository) GetBySlug(ctx context.Context, slug string) (*types.Organization, error) { + organization := new(types.Organization) + err := r.db.NewSelect().Model(organization).Where("slug = ?", slug).Scan(ctx) + if err == sql.ErrNoRows { + return nil, nil + } + return organization, err +} + +func (r *BunOrganizationRepository) GetAllByOwnerID(ctx context.Context, ownerID string) ([]types.Organization, error) { + organizations := make([]types.Organization, 0) + err := r.db.NewSelect().Model(&organizations).Where("owner_id = ?", ownerID).Scan(ctx) + if err == sql.ErrNoRows { + return []types.Organization{}, nil + } + return organizations, err +} + +func (r *BunOrganizationRepository) Update(ctx context.Context, organization *types.Organization) (*types.Organization, error) { + err := r.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + _, err := tx.NewUpdate().Model(organization).WherePK().Exec(ctx) + if err != nil { + return err + } + + return tx.NewSelect().Model(organization).WherePK().Scan(ctx) + }) + if err != nil { + return nil, err + } + + return organization, nil +} + +func (r *BunOrganizationRepository) Delete(ctx context.Context, organizationID string) error { + _, err := r.db.NewDelete().Model(&types.Organization{}).Where("id = ?", organizationID).Exec(ctx) + return err +} + +func (r *BunOrganizationRepository) WithTx(tx bun.IDB) OrganizationRepository { + return &BunOrganizationRepository{db: tx} +} diff --git a/plugins/organizations/repositories/bun_organization_repository_test.go b/plugins/organizations/repositories/bun_organization_repository_test.go new file mode 100644 index 00000000..a4f07d7e --- /dev/null +++ b/plugins/organizations/repositories/bun_organization_repository_test.go @@ -0,0 +1,257 @@ +package repositories + +import ( + "context" + "database/sql" + "testing" + + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/require" + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect/sqlitedialect" + + "github.com/Authula/authula/plugins/organizations/types" +) + +func newOrganizationRepositoryTestDB(t *testing.T) *bun.DB { + t.Helper() + + sqlDB, err := sql.Open("sqlite3", ":memory:") + require.NoError(t, err) + sqlDB.SetMaxOpenConns(1) + t.Cleanup(func() { _ = sqlDB.Close() }) + + db := bun.NewDB(sqlDB, sqlitedialect.New()) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.ExecContext(context.Background(), ` + CREATE TABLE organizations (id TEXT PRIMARY KEY, owner_id TEXT NOT NULL, name TEXT NOT NULL, slug TEXT NOT NULL, logo TEXT, metadata TEXT, created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP); + `) + require.NoError(t, err) + + return db +} + +func TestBunOrganizationRepository_Create(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + organization *types.Organization + expectMetadata string + }{ + { + name: "defaults metadata", + organization: &types.Organization{ID: "org-1", OwnerID: "user-1", Name: "Acme Inc", Slug: "acme-inc"}, + expectMetadata: "{}", + }, + { + name: "keeps provided metadata", + organization: &types.Organization{ID: "org-2", OwnerID: "user-1", Name: "Platform", Slug: "platform", Metadata: []byte(`{"tier":"core"}`)}, + expectMetadata: `{"tier":"core"}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + db := newOrganizationRepositoryTestDB(t) + repo := NewBunOrganizationRepository(db) + + created, err := repo.Create(context.Background(), tt.organization) + require.NoError(t, err) + require.NotNil(t, created) + require.Equal(t, tt.organization.ID, created.ID) + require.Equal(t, tt.expectMetadata, string(created.Metadata)) + }) + } +} + +func TestBunOrganizationRepository_GetByID(t *testing.T) { + t.Parallel() + + db := newOrganizationRepositoryTestDB(t) + repo := NewBunOrganizationRepository(db) + ctx := context.Background() + + _, err := repo.Create(ctx, &types.Organization{ID: "org-1", OwnerID: "user-1", Name: "Acme Inc", Slug: "acme-inc", Metadata: []byte(`{"tier":"core"}`)}) + require.NoError(t, err) + + tests := []struct { + name string + organizationID string + expectFound bool + }{ + {name: "found", organizationID: "org-1", expectFound: true}, + {name: "not found", organizationID: "missing"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + found, err := repo.GetByID(ctx, tt.organizationID) + require.NoError(t, err) + if tt.expectFound { + require.NotNil(t, found) + require.Equal(t, "org-1", found.ID) + require.Equal(t, "user-1", found.OwnerID) + return + } + require.Nil(t, found) + }) + } +} + +func TestBunOrganizationRepository_GetBySlug(t *testing.T) { + t.Parallel() + + db := newOrganizationRepositoryTestDB(t) + repo := NewBunOrganizationRepository(db) + ctx := context.Background() + + _, err := repo.Create(ctx, &types.Organization{ID: "org-1", OwnerID: "user-1", Name: "Acme Inc", Slug: "acme-inc", Metadata: []byte(`{"tier":"core"}`)}) + require.NoError(t, err) + + tests := []struct { + name string + slug string + expectFound bool + }{ + {name: "found", slug: "acme-inc", expectFound: true}, + {name: "not found", slug: "missing"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + found, err := repo.GetBySlug(ctx, tt.slug) + require.NoError(t, err) + if tt.expectFound { + require.NotNil(t, found) + require.Equal(t, "org-1", found.ID) + return + } + require.Nil(t, found) + }) + } +} + +func TestBunOrganizationRepository_GetAllByOwnerID(t *testing.T) { + t.Parallel() + + db := newOrganizationRepositoryTestDB(t) + repo := NewBunOrganizationRepository(db) + ctx := context.Background() + + _, err := repo.Create(ctx, &types.Organization{ID: "org-1", OwnerID: "user-1", Name: "Acme Inc", Slug: "acme-inc"}) + require.NoError(t, err) + _, err = repo.Create(ctx, &types.Organization{ID: "org-2", OwnerID: "user-1", Name: "Platform", Slug: "platform"}) + require.NoError(t, err) + + tests := []struct { + name string + ownerID string + expectCount int + }{ + {name: "found", ownerID: "user-1", expectCount: 2}, + {name: "empty", ownerID: "user-2", expectCount: 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + found, err := repo.GetAllByOwnerID(ctx, tt.ownerID) + require.NoError(t, err) + require.Len(t, found, tt.expectCount) + }) + } +} + +func TestBunOrganizationRepository_Update(t *testing.T) { + t.Parallel() + + db := newOrganizationRepositoryTestDB(t) + repo := NewBunOrganizationRepository(db) + ctx := context.Background() + + created, err := repo.Create(ctx, &types.Organization{ID: "org-1", OwnerID: "user-1", Name: "Acme Inc", Slug: "acme-inc", Metadata: []byte(`{"tier":"core"}`)}) + require.NoError(t, err) + + tests := []struct { + name string + run func(*testing.T, *types.Organization) + }{ + { + name: "update name and logo", + run: func(t *testing.T, organization *types.Organization) { + t.Helper() + organization.Name = "Acme Platform" + logo := new(string) + *logo = "http://example.com/logo.svg" + organization.Logo = logo + updated, err := repo.Update(ctx, organization) + require.NoError(t, err) + require.Equal(t, "Acme Platform", updated.Name) + require.NotNil(t, updated.Logo) + require.Equal(t, "http://example.com/logo.svg", *updated.Logo) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + tt.run(t, created) + }) + } +} + +func TestBunOrganizationRepository_Delete(t *testing.T) { + t.Parallel() + + db := newOrganizationRepositoryTestDB(t) + repo := NewBunOrganizationRepository(db) + ctx := context.Background() + + _, err := repo.Create(ctx, &types.Organization{ID: "org-1", OwnerID: "user-1", Name: "Acme Inc", Slug: "acme-inc"}) + require.NoError(t, err) + + tests := []struct { + name string + organizationID string + }{ + {name: "delete existing", organizationID: "org-1"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + require.NoError(t, repo.Delete(ctx, tt.organizationID)) + found, err := repo.GetByID(ctx, tt.organizationID) + require.NoError(t, err) + require.Nil(t, found) + }) + } +} + +func TestBunOrganizationRepository_WithTx(t *testing.T) { + t.Parallel() + + db := newOrganizationRepositoryTestDB(t) + repo := NewBunOrganizationRepository(db) + ctx := context.Background() + + txRepo := repo.WithTx(db) + require.NotNil(t, txRepo) + + created, err := txRepo.Create(ctx, &types.Organization{ID: "org-1", OwnerID: "user-1", Name: "Acme Inc", Slug: "acme-inc"}) + require.NoError(t, err) + require.NotNil(t, created) + require.Equal(t, "org-1", created.ID) + require.IsType(t, &BunOrganizationRepository{}, txRepo) +} diff --git a/plugins/organizations/repositories/bun_organization_team_member_repository.go b/plugins/organizations/repositories/bun_organization_team_member_repository.go new file mode 100644 index 00000000..c716c007 --- /dev/null +++ b/plugins/organizations/repositories/bun_organization_team_member_repository.go @@ -0,0 +1,75 @@ +package repositories + +import ( + "context" + "database/sql" + + "github.com/uptrace/bun" + + "github.com/Authula/authula/plugins/organizations/types" +) + +type BunOrganizationTeamMemberRepository struct { + db bun.IDB +} + +func NewBunOrganizationTeamMemberRepository(db bun.IDB) OrganizationTeamMemberRepository { + return &BunOrganizationTeamMemberRepository{db: db} +} + +func (r *BunOrganizationTeamMemberRepository) Create(ctx context.Context, teamMember *types.OrganizationTeamMember) (*types.OrganizationTeamMember, error) { + err := r.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + _, err := tx.NewInsert().Model(teamMember).Exec(ctx) + if err != nil { + return err + } + + return tx.NewSelect().Model(teamMember).WherePK().Scan(ctx) + }) + if err != nil { + return nil, err + } + + return teamMember, nil +} + +func (r *BunOrganizationTeamMemberRepository) GetByID(ctx context.Context, teamMemberID string) (*types.OrganizationTeamMember, error) { + teamMember := new(types.OrganizationTeamMember) + err := r.db.NewSelect().Model(teamMember).Where("id = ?", teamMemberID).Scan(ctx) + if err == sql.ErrNoRows { + return nil, nil + } + return teamMember, err +} + +func (r *BunOrganizationTeamMemberRepository) GetByTeamIDAndMemberID(ctx context.Context, teamID, memberID string) (*types.OrganizationTeamMember, error) { + teamMember := new(types.OrganizationTeamMember) + err := r.db.NewSelect().Model(teamMember). + Where("team_id = ? AND member_id = ?", teamID, memberID). + Scan(ctx) + if err == sql.ErrNoRows { + return nil, nil + } + return teamMember, err +} + +func (r *BunOrganizationTeamMemberRepository) GetAllByTeamID(ctx context.Context, teamID string, page int, limit int) ([]types.OrganizationTeamMember, error) { + teamMembers := make([]types.OrganizationTeamMember, 0) + err := r.db.NewSelect().Model(&teamMembers). + Where("team_id = ?", teamID). + Offset((page - 1) * limit).Limit(limit). + Scan(ctx) + if err == sql.ErrNoRows { + return []types.OrganizationTeamMember{}, nil + } + return teamMembers, err +} + +func (r *BunOrganizationTeamMemberRepository) DeleteByTeamIDAndMemberID(ctx context.Context, teamID, memberID string) error { + _, err := r.db.NewDelete().Model(&types.OrganizationTeamMember{}).Where("team_id = ? AND member_id = ?", teamID, memberID).Exec(ctx) + return err +} + +func (r *BunOrganizationTeamMemberRepository) WithTx(tx bun.IDB) OrganizationTeamMemberRepository { + return &BunOrganizationTeamMemberRepository{db: tx} +} diff --git a/plugins/organizations/repositories/bun_organization_team_member_repository_test.go b/plugins/organizations/repositories/bun_organization_team_member_repository_test.go new file mode 100644 index 00000000..3f43206c --- /dev/null +++ b/plugins/organizations/repositories/bun_organization_team_member_repository_test.go @@ -0,0 +1 @@ +package repositories diff --git a/plugins/organizations/repositories/bun_organization_team_repository.go b/plugins/organizations/repositories/bun_organization_team_repository.go new file mode 100644 index 00000000..4734f912 --- /dev/null +++ b/plugins/organizations/repositories/bun_organization_team_repository.go @@ -0,0 +1,92 @@ +package repositories + +import ( + "context" + "database/sql" + + "github.com/uptrace/bun" + + "github.com/Authula/authula/plugins/organizations/types" +) + +type BunOrganizationTeamRepository struct { + db bun.IDB +} + +func NewBunOrganizationTeamRepository(db bun.IDB) OrganizationTeamRepository { + return &BunOrganizationTeamRepository{db: db} +} + +func (r *BunOrganizationTeamRepository) Create(ctx context.Context, team *types.OrganizationTeam) (*types.OrganizationTeam, error) { + if len(team.Metadata) == 0 { + team.Metadata = []byte("{}") + } + + err := r.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + _, err := tx.NewInsert().Model(team).Exec(ctx) + if err != nil { + return err + } + + return tx.NewSelect().Model(team).WherePK().Scan(ctx) + }) + if err != nil { + return nil, err + } + + return team, nil +} + +func (r *BunOrganizationTeamRepository) GetByID(ctx context.Context, teamID string) (*types.OrganizationTeam, error) { + team := new(types.OrganizationTeam) + err := r.db.NewSelect().Model(team).Where("id = ?", teamID).Scan(ctx) + if err == sql.ErrNoRows { + return nil, nil + } + return team, err +} + +func (r *BunOrganizationTeamRepository) GetByOrganizationIDAndSlug(ctx context.Context, organizationID, slug string) (*types.OrganizationTeam, error) { + team := new(types.OrganizationTeam) + err := r.db.NewSelect().Model(team). + Where("organization_id = ? AND slug = ?", organizationID, slug). + Scan(ctx) + if err == sql.ErrNoRows { + return nil, nil + } + return team, err +} + +func (r *BunOrganizationTeamRepository) GetAllByOrganizationID(ctx context.Context, organizationID string) ([]types.OrganizationTeam, error) { + teams := make([]types.OrganizationTeam, 0) + err := r.db.NewSelect().Model(&teams).Where("organization_id = ?", organizationID).Scan(ctx) + if err == sql.ErrNoRows { + return []types.OrganizationTeam{}, nil + } + return teams, err +} + +func (r *BunOrganizationTeamRepository) Update(ctx context.Context, team *types.OrganizationTeam) (*types.OrganizationTeam, error) { + err := r.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + _, err := tx.NewUpdate().Model(team).WherePK().Exec(ctx) + if err != nil { + return err + } + + return tx.NewSelect().Model(team).WherePK().Scan(ctx) + }) + if err != nil { + return nil, err + } + + return team, nil +} + +func (r *BunOrganizationTeamRepository) Delete(ctx context.Context, teamID string) error { + _, err := r.db.NewDelete().Model(&types.OrganizationTeam{}).Where("id = ?", teamID).Exec(ctx) + return err +} + +func (r *BunOrganizationTeamRepository) WithTx(tx bun.IDB) OrganizationTeamRepository { + return &BunOrganizationTeamRepository{db: tx} +} diff --git a/plugins/organizations/repositories/bun_organization_team_repository_test.go b/plugins/organizations/repositories/bun_organization_team_repository_test.go new file mode 100644 index 00000000..f9b98f6c --- /dev/null +++ b/plugins/organizations/repositories/bun_organization_team_repository_test.go @@ -0,0 +1,438 @@ +package repositories + +import ( + "context" + "database/sql" + "testing" + + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/require" + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect/sqlitedialect" + + "github.com/Authula/authula/plugins/organizations/types" +) + +func newOrganizationTeamRepositoryTestDB(t *testing.T) *bun.DB { + t.Helper() + + sqlDB, err := sql.Open("sqlite3", ":memory:") + require.NoError(t, err) + sqlDB.SetMaxOpenConns(1) + t.Cleanup(func() { _ = sqlDB.Close() }) + + db := bun.NewDB(sqlDB, sqlitedialect.New()) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.ExecContext(context.Background(), ` + CREATE TABLE organizations (id TEXT PRIMARY KEY, owner_id TEXT NOT NULL, name TEXT NOT NULL, slug TEXT NOT NULL, logo TEXT, metadata TEXT, created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP); + CREATE TABLE organization_members (id TEXT PRIMARY KEY, organization_id TEXT NOT NULL, user_id TEXT NOT NULL, role TEXT NOT NULL, created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP); + CREATE TABLE organization_teams (id TEXT PRIMARY KEY, organization_id TEXT NOT NULL, name TEXT NOT NULL, slug TEXT NOT NULL, description TEXT, metadata TEXT, created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP); + CREATE TABLE organization_team_members (id TEXT PRIMARY KEY, team_id TEXT NOT NULL, member_id TEXT NOT NULL, created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP); + `) + require.NoError(t, err) + + return db +} + +func TestBunOrganizationTeamRepository_Create(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + team *types.OrganizationTeam + expect string + }{ + { + name: "defaults metadata", + team: &types.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, + expect: "{}", + }, + { + name: "keeps metadata and description", + team: func() *types.OrganizationTeam { + description := new(string) + *description = "Core" + return &types.OrganizationTeam{ID: "team-2", OrganizationID: "org-1", Name: "Core", Slug: "core", Description: description, Metadata: []byte(`{"tier":"core"}`)} + }(), + expect: `{"tier":"core"}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + repo := NewBunOrganizationTeamRepository(newOrganizationTeamRepositoryTestDB(t)) + created, err := repo.Create(context.Background(), tt.team) + require.NoError(t, err) + require.NotNil(t, created) + require.Equal(t, tt.team.ID, created.ID) + require.Equal(t, tt.expect, string(created.Metadata)) + }) + } +} + +func TestBunOrganizationTeamRepository_GetByID(t *testing.T) { + t.Parallel() + + db := newOrganizationTeamRepositoryTestDB(t) + repo := NewBunOrganizationTeamRepository(db) + ctx := context.Background() + + _, err := repo.Create(ctx, &types.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}) + require.NoError(t, err) + + tests := []struct { + name string + teamID string + expectFound bool + }{ + {name: "found", teamID: "team-1", expectFound: true}, + {name: "not found", teamID: "missing"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + found, err := repo.GetByID(ctx, tt.teamID) + require.NoError(t, err) + if tt.expectFound { + require.NotNil(t, found) + require.Equal(t, "team-1", found.ID) + return + } + require.Nil(t, found) + }) + } +} + +func TestBunOrganizationTeamRepository_GetByOrganizationIDAndSlug(t *testing.T) { + t.Parallel() + + db := newOrganizationTeamRepositoryTestDB(t) + repo := NewBunOrganizationTeamRepository(db) + ctx := context.Background() + + _, err := repo.Create(ctx, &types.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}) + require.NoError(t, err) + + tests := []struct { + name string + organizationID string + slug string + expectFound bool + }{ + {name: "found", organizationID: "org-1", slug: "platform", expectFound: true}, + {name: "not found", organizationID: "org-1", slug: "missing"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + found, err := repo.GetByOrganizationIDAndSlug(ctx, tt.organizationID, tt.slug) + require.NoError(t, err) + if tt.expectFound { + require.NotNil(t, found) + require.Equal(t, "team-1", found.ID) + return + } + require.Nil(t, found) + }) + } +} + +func TestBunOrganizationTeamRepository_GetAllByOrganizationID(t *testing.T) { + t.Parallel() + + db := newOrganizationTeamRepositoryTestDB(t) + repo := NewBunOrganizationTeamRepository(db) + ctx := context.Background() + + _, err := repo.Create(ctx, &types.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}) + require.NoError(t, err) + _, err = repo.Create(ctx, &types.OrganizationTeam{ID: "team-2", OrganizationID: "org-1", Name: "Core", Slug: "core"}) + require.NoError(t, err) + + tests := []struct { + name string + organizationID string + expectCount int + }{ + {name: "found", organizationID: "org-1", expectCount: 2}, + {name: "empty", organizationID: "org-2", expectCount: 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + found, err := repo.GetAllByOrganizationID(ctx, tt.organizationID) + require.NoError(t, err) + require.Len(t, found, tt.expectCount) + }) + } +} + +func TestBunOrganizationTeamRepository_Update(t *testing.T) { + t.Parallel() + + db := newOrganizationTeamRepositoryTestDB(t) + repo := NewBunOrganizationTeamRepository(db) + ctx := context.Background() + + created, err := repo.Create(ctx, &types.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}) + require.NoError(t, err) + + tests := []struct { + name string + run func(*testing.T, *types.OrganizationTeam) + }{ + { + name: "change name and description", + run: func(t *testing.T, team *types.OrganizationTeam) { + t.Helper() + team.Name = "Platform Team" + description := new(string) + *description = "Core platform" + team.Description = description + updated, err := repo.Update(ctx, team) + require.NoError(t, err) + require.Equal(t, "Platform Team", updated.Name) + require.NotNil(t, updated.Description) + require.Equal(t, "Core platform", *updated.Description) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + tt.run(t, created) + }) + } +} + +func TestBunOrganizationTeamRepository_Delete(t *testing.T) { + t.Parallel() + + db := newOrganizationTeamRepositoryTestDB(t) + repo := NewBunOrganizationTeamRepository(db) + ctx := context.Background() + + _, err := repo.Create(ctx, &types.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}) + require.NoError(t, err) + + tests := []struct { + name string + teamID string + }{ + {name: "delete existing", teamID: "team-1"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + require.NoError(t, repo.Delete(ctx, tt.teamID)) + found, err := repo.GetByID(ctx, tt.teamID) + require.NoError(t, err) + require.Nil(t, found) + }) + } +} + +func TestBunOrganizationTeamRepository_WithTx(t *testing.T) { + t.Parallel() + + db := newOrganizationTeamRepositoryTestDB(t) + repo := NewBunOrganizationTeamRepository(db) + ctx := context.Background() + + txRepo := repo.WithTx(db) + require.NotNil(t, txRepo) + require.IsType(t, &BunOrganizationTeamRepository{}, txRepo) + + created, err := txRepo.Create(ctx, &types.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}) + require.NoError(t, err) + require.NotNil(t, created) + require.Equal(t, "team-1", created.ID) +} + +func TestBunOrganizationTeamMemberRepository_Create(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + teamMember *types.OrganizationTeamMember + }{ + {name: "team member", teamMember: &types.OrganizationTeamMember{ID: "team-member-1", TeamID: "team-1", MemberID: "member-1"}}, + {name: "another team member", teamMember: &types.OrganizationTeamMember{ID: "team-member-2", TeamID: "team-1", MemberID: "member-2"}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + repo := NewBunOrganizationTeamMemberRepository(newOrganizationTeamRepositoryTestDB(t)) + created, err := repo.Create(context.Background(), tt.teamMember) + require.NoError(t, err) + require.NotNil(t, created) + require.Equal(t, tt.teamMember.ID, created.ID) + }) + } +} + +func TestBunOrganizationTeamMemberRepository_GetByID(t *testing.T) { + t.Parallel() + + db := newOrganizationTeamRepositoryTestDB(t) + repo := NewBunOrganizationTeamMemberRepository(db) + ctx := context.Background() + + _, err := repo.Create(ctx, &types.OrganizationTeamMember{ID: "team-member-1", TeamID: "team-1", MemberID: "member-1"}) + require.NoError(t, err) + + tests := []struct { + name string + teamMemberID string + expectFound bool + }{ + {name: "found", teamMemberID: "team-member-1", expectFound: true}, + {name: "not found", teamMemberID: "missing"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + found, err := repo.GetByID(ctx, tt.teamMemberID) + require.NoError(t, err) + if tt.expectFound { + require.NotNil(t, found) + require.Equal(t, "team-member-1", found.ID) + return + } + require.Nil(t, found) + }) + } +} + +func TestBunOrganizationTeamMemberRepository_GetByTeamIDAndMemberID(t *testing.T) { + t.Parallel() + + db := newOrganizationTeamRepositoryTestDB(t) + repo := NewBunOrganizationTeamMemberRepository(db) + ctx := context.Background() + + _, err := repo.Create(ctx, &types.OrganizationTeamMember{ID: "team-member-1", TeamID: "team-1", MemberID: "member-1"}) + require.NoError(t, err) + + tests := []struct { + name string + teamID string + memberID string + expectFound bool + }{ + {name: "found", teamID: "team-1", memberID: "member-1", expectFound: true}, + {name: "not found", teamID: "team-1", memberID: "member-2"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + found, err := repo.GetByTeamIDAndMemberID(ctx, tt.teamID, tt.memberID) + require.NoError(t, err) + if tt.expectFound { + require.NotNil(t, found) + require.Equal(t, "team-member-1", found.ID) + return + } + require.Nil(t, found) + }) + } +} + +func TestBunOrganizationTeamMemberRepository_GetAllByTeamID(t *testing.T) { + t.Parallel() + + db := newOrganizationTeamRepositoryTestDB(t) + repo := NewBunOrganizationTeamMemberRepository(db) + ctx := context.Background() + + _, err := repo.Create(ctx, &types.OrganizationTeamMember{ID: "team-member-1", TeamID: "team-1", MemberID: "member-1"}) + require.NoError(t, err) + _, err = repo.Create(ctx, &types.OrganizationTeamMember{ID: "team-member-2", TeamID: "team-1", MemberID: "member-2"}) + require.NoError(t, err) + + tests := []struct { + name string + teamID string + page int + limit int + expectCount int + }{ + {name: "found", teamID: "team-1", page: 1, limit: 10, expectCount: 2}, + {name: "empty", teamID: "team-2", page: 1, limit: 10, expectCount: 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + found, err := repo.GetAllByTeamID(ctx, tt.teamID, tt.page, tt.limit) + require.NoError(t, err) + require.Len(t, found, tt.expectCount) + }) + } +} + +func TestBunOrganizationTeamMemberRepository_DeleteByTeamIDAndMemberID(t *testing.T) { + t.Parallel() + + db := newOrganizationTeamRepositoryTestDB(t) + repo := NewBunOrganizationTeamMemberRepository(db) + ctx := context.Background() + + _, err := repo.Create(ctx, &types.OrganizationTeamMember{ID: "team-member-1", TeamID: "team-1", MemberID: "member-1"}) + require.NoError(t, err) + + tests := []struct { + name string + teamID string + memberID string + }{ + {name: "delete existing", teamID: "team-1", memberID: "member-1"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + require.NoError(t, repo.DeleteByTeamIDAndMemberID(ctx, tt.teamID, tt.memberID)) + found, err := repo.GetByTeamIDAndMemberID(ctx, tt.teamID, tt.memberID) + require.NoError(t, err) + require.Nil(t, found) + }) + } +} + +func TestBunOrganizationTeamMemberRepository_WithTx(t *testing.T) { + t.Parallel() + + db := newOrganizationTeamRepositoryTestDB(t) + repo := NewBunOrganizationTeamMemberRepository(db) + ctx := context.Background() + + txRepo := repo.WithTx(db) + require.NotNil(t, txRepo) + require.IsType(t, &BunOrganizationTeamMemberRepository{}, txRepo) + + created, err := txRepo.Create(ctx, &types.OrganizationTeamMember{ID: "team-member-1", TeamID: "team-1", MemberID: "member-1"}) + require.NoError(t, err) + require.NotNil(t, created) + require.Equal(t, "team-member-1", created.ID) +} diff --git a/plugins/organizations/repositories/interfaces.go b/plugins/organizations/repositories/interfaces.go new file mode 100644 index 00000000..5cafa782 --- /dev/null +++ b/plugins/organizations/repositories/interfaces.go @@ -0,0 +1,58 @@ +package repositories + +import ( + "context" + + "github.com/uptrace/bun" + + "github.com/Authula/authula/plugins/organizations/types" +) + +type OrganizationRepository interface { + Create(ctx context.Context, organization *types.Organization) (*types.Organization, error) + GetByID(ctx context.Context, organizationID string) (*types.Organization, error) + GetBySlug(ctx context.Context, slug string) (*types.Organization, error) + GetAllByOwnerID(ctx context.Context, ownerID string) ([]types.Organization, error) + Update(ctx context.Context, organization *types.Organization) (*types.Organization, error) + Delete(ctx context.Context, organizationID string) error + WithTx(tx bun.IDB) OrganizationRepository +} + +type OrganizationInvitationRepository interface { + Create(ctx context.Context, invitation *types.OrganizationInvitation) (*types.OrganizationInvitation, error) + GetByID(ctx context.Context, invitationID string) (*types.OrganizationInvitation, error) + GetByOrganizationIDAndEmail(ctx context.Context, organizationID string, email string) (*types.OrganizationInvitation, error) + GetAllByOrganizationID(ctx context.Context, organizationID string) ([]types.OrganizationInvitation, error) + GetAllPendingByEmail(ctx context.Context, email string) ([]types.OrganizationInvitation, error) + Update(ctx context.Context, invitation *types.OrganizationInvitation) (*types.OrganizationInvitation, error) + WithTx(tx bun.IDB) OrganizationInvitationRepository +} + +type OrganizationMemberRepository interface { + Create(ctx context.Context, member *types.OrganizationMember) (*types.OrganizationMember, error) + GetAllByOrganizationID(ctx context.Context, organizationID string, page int, limit int) ([]types.OrganizationMember, error) + GetByOrganizationIDAndUserID(ctx context.Context, organizationID string, userID string) (*types.OrganizationMember, error) + GetByID(ctx context.Context, memberID string) (*types.OrganizationMember, error) + Update(ctx context.Context, member *types.OrganizationMember) (*types.OrganizationMember, error) + Delete(ctx context.Context, memberID string) error + WithTx(tx bun.IDB) OrganizationMemberRepository +} + +type OrganizationTeamRepository interface { + Create(ctx context.Context, team *types.OrganizationTeam) (*types.OrganizationTeam, error) + GetByID(ctx context.Context, teamID string) (*types.OrganizationTeam, error) + GetByOrganizationIDAndSlug(ctx context.Context, organizationID, slug string) (*types.OrganizationTeam, error) + GetAllByOrganizationID(ctx context.Context, organizationID string) ([]types.OrganizationTeam, error) + Update(ctx context.Context, team *types.OrganizationTeam) (*types.OrganizationTeam, error) + Delete(ctx context.Context, teamID string) error + WithTx(tx bun.IDB) OrganizationTeamRepository +} + +type OrganizationTeamMemberRepository interface { + Create(ctx context.Context, teamMember *types.OrganizationTeamMember) (*types.OrganizationTeamMember, error) + GetByID(ctx context.Context, teamMemberID string) (*types.OrganizationTeamMember, error) + GetByTeamIDAndMemberID(ctx context.Context, teamID, memberID string) (*types.OrganizationTeamMember, error) + GetAllByTeamID(ctx context.Context, teamID string, page int, limit int) ([]types.OrganizationTeamMember, error) + DeleteByTeamIDAndMemberID(ctx context.Context, teamID, memberID string) error + WithTx(tx bun.IDB) OrganizationTeamMemberRepository +} diff --git a/plugins/organizations/routes.go b/plugins/organizations/routes.go new file mode 100644 index 00000000..4db4515c --- /dev/null +++ b/plugins/organizations/routes.go @@ -0,0 +1,167 @@ +package organizations + +import ( + "net/http" + + "github.com/Authula/authula/models" + "github.com/Authula/authula/plugins/organizations/handlers" +) + +func Routes(plugin *OrganizationsPlugin) []models.Route { + if plugin == nil || plugin.Api == nil { + return []models.Route{} + } + + createOrganizationHandler := &handlers.CreateOrganizationHandler{UseCase: plugin.organizationUseCase} + getAllOrganizationsByUserIDHandler := &handlers.GetAllOrganizationsByUserIDHandler{UseCase: plugin.organizationUseCase} + getOrganizationHandler := &handlers.GetOrganizationHandler{UseCase: plugin.organizationUseCase} + updateOrganizationHandler := &handlers.UpdateOrganizationHandler{UseCase: plugin.organizationUseCase} + deleteOrganizationHandler := &handlers.DeleteOrganizationHandler{UseCase: plugin.organizationUseCase} + createInvitationHandler := &handlers.CreateOrganizationInvitationHandler{UseCase: plugin.invitationUseCase} + getInvitationHandler := &handlers.GetOrganizationInvitationHandler{UseCase: plugin.invitationUseCase} + listInvitationsHandler := &handlers.GetAllOrganizationInvitationsHandler{UseCase: plugin.invitationUseCase} + revokeInvitationHandler := &handlers.RevokeOrganizationInvitationHandler{UseCase: plugin.invitationUseCase} + acceptInvitationHandler := &handlers.AcceptOrganizationInvitationHandler{UseCase: plugin.invitationUseCase} + rejectInvitationHandler := &handlers.RejectOrganizationInvitationHandler{UseCase: plugin.invitationUseCase} + addMemberHandler := &handlers.AddOrganizationMemberHandler{UseCase: plugin.memberUseCase} + getAllMembersHandler := &handlers.GetAllOrganizationMembersHandler{UseCase: plugin.memberUseCase} + getMemberHandler := &handlers.GetOrganizationMemberHandler{UseCase: plugin.memberUseCase} + updateMemberHandler := &handlers.UpdateOrganizationMemberHandler{UseCase: plugin.memberUseCase} + deleteMemberHandler := &handlers.DeleteOrganizationMemberHandler{UseCase: plugin.memberUseCase} + createTeamHandler := &handlers.CreateOrganizationTeamHandler{UseCase: plugin.teamUseCase} + getAllTeamsHandler := &handlers.GetAllOrganizationTeamsHandler{UseCase: plugin.teamUseCase} + updateTeamHandler := &handlers.UpdateOrganizationTeamHandler{UseCase: plugin.teamUseCase} + deleteTeamHandler := &handlers.DeleteOrganizationTeamHandler{UseCase: plugin.teamUseCase} + addTeamMemberHandler := &handlers.AddOrganizationTeamMemberHandler{UseCase: plugin.teamUseCase} + getTeamMemberHandler := &handlers.GetOrganizationTeamMemberHandler{UseCase: plugin.teamUseCase} + getAllTeamMembersHandler := &handlers.GetAllOrganizationTeamMembersHandler{UseCase: plugin.teamUseCase} + deleteTeamMemberHandler := &handlers.DeleteOrganizationTeamMemberHandler{UseCase: plugin.teamUseCase} + + return []models.Route{ + // Organizations + { + Method: http.MethodPost, + Path: "/organizations", + Handler: createOrganizationHandler.Handler(), + }, + { + Method: http.MethodGet, + Path: "/organizations", + Handler: getAllOrganizationsByUserIDHandler.Handler(), + }, + { + Method: http.MethodGet, + Path: "/organizations/{organization_id}", + Handler: getOrganizationHandler.Handler(), + }, + { + Method: http.MethodPatch, + Path: "/organizations/{organization_id}", + Handler: updateOrganizationHandler.Handler(), + }, + { + Method: http.MethodDelete, + Path: "/organizations/{organization_id}", + Handler: deleteOrganizationHandler.Handler(), + }, + // Organization Members + { + Method: http.MethodPost, + Path: "/organizations/{organization_id}/members", + Handler: addMemberHandler.Handler(), + }, + { + Method: http.MethodGet, + Path: "/organizations/{organization_id}/members", + Handler: getAllMembersHandler.Handler(), + }, + { + Method: http.MethodGet, + Path: "/organizations/{organization_id}/members/{member_id}", + Handler: getMemberHandler.Handler(), + }, + { + Method: http.MethodPatch, + Path: "/organizations/{organization_id}/members/{member_id}", + Handler: updateMemberHandler.Handler(), + }, + { + Method: http.MethodDelete, + Path: "/organizations/{organization_id}/members/{member_id}", + Handler: deleteMemberHandler.Handler(), + }, + // Teams + { + Method: http.MethodPost, + Path: "/organizations/{organization_id}/teams", + Handler: createTeamHandler.Handler(), + }, + { + Method: http.MethodGet, + Path: "/organizations/{organization_id}/teams", + Handler: getAllTeamsHandler.Handler(), + }, + { + Method: http.MethodPatch, + Path: "/organizations/{organization_id}/teams/{team_id}", + Handler: updateTeamHandler.Handler(), + }, + { + Method: http.MethodDelete, + Path: "/organizations/{organization_id}/teams/{team_id}", + Handler: deleteTeamHandler.Handler(), + }, + // Team Members + { + Method: http.MethodPost, + Path: "/organizations/{organization_id}/teams/{team_id}/members", + Handler: addTeamMemberHandler.Handler(), + }, + { + Method: http.MethodGet, + Path: "/organizations/{organization_id}/teams/{team_id}/members/{member_id}", + Handler: getTeamMemberHandler.Handler(), + }, + { + Method: http.MethodGet, + Path: "/organizations/{organization_id}/teams/{team_id}/members", + Handler: getAllTeamMembersHandler.Handler(), + }, + { + Method: http.MethodDelete, + Path: "/organizations/{organization_id}/teams/{team_id}/members/{member_id}", + Handler: deleteTeamMemberHandler.Handler(), + }, + // Invitations + { + Method: http.MethodPost, + Path: "/organizations/{organization_id}/invitations", + Handler: createInvitationHandler.Handler(), + }, + { + Method: http.MethodGet, + Path: "/organizations/{organization_id}/invitations", + Handler: listInvitationsHandler.Handler(), + }, + { + Method: http.MethodGet, + Path: "/organizations/{organization_id}/invitations/{invitation_id}", + Handler: getInvitationHandler.Handler(), + }, + { + Method: http.MethodPatch, + Path: "/organizations/{organization_id}/invitations/{invitation_id}", + Handler: revokeInvitationHandler.Handler(), + }, + { + Method: http.MethodPost, + Path: "/organizations/{organization_id}/invitations/{invitation_id}/accept", + Handler: acceptInvitationHandler.Handler(), + }, + { + Method: http.MethodPost, + Path: "/organizations/{organization_id}/invitations/{invitation_id}/reject", + Handler: rejectInvitationHandler.Handler(), + }, + } +} diff --git a/plugins/organizations/services/organization_invitation_service.go b/plugins/organizations/services/organization_invitation_service.go new file mode 100644 index 00000000..465642c0 --- /dev/null +++ b/plugins/organizations/services/organization_invitation_service.go @@ -0,0 +1,572 @@ +package services + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "html" + "net/mail" + "net/url" + "strings" + "time" + + "github.com/uptrace/bun" + + internalerrors "github.com/Authula/authula/internal/errors" + "github.com/Authula/authula/internal/util" + "github.com/Authula/authula/models" + orgconstants "github.com/Authula/authula/plugins/organizations/constants" + orgevents "github.com/Authula/authula/plugins/organizations/events" + "github.com/Authula/authula/plugins/organizations/repositories" + "github.com/Authula/authula/plugins/organizations/types" + rootservices "github.com/Authula/authula/services" +) + +type OrganizationInvitationHookExecutor interface { + BeforeCreateOrganizationInvitation(invitation *types.OrganizationInvitation) error + AfterCreateOrganizationInvitation(invitation types.OrganizationInvitation) error + BeforeUpdateOrganizationInvitation(invitation *types.OrganizationInvitation) error + AfterUpdateOrganizationInvitation(invitation types.OrganizationInvitation) error +} + +type organizationInvitationTxRunner interface { + RunInTx(ctx context.Context, opts *sql.TxOptions, fn func(context.Context, bun.Tx) error) error +} + +type OrganizationInvitationService struct { + txRunner organizationInvitationTxRunner + globalConfig *models.Config + pluginConfig *types.OrganizationsPluginConfig + logger models.Logger + mailerService rootservices.MailerService + eventBus models.EventBus + userService rootservices.UserService + organizationRepo repositories.OrganizationRepository + orgInvitationRepo repositories.OrganizationInvitationRepository + orgMemberRepo repositories.OrganizationMemberRepository + orgInvitationHooks OrganizationInvitationHookExecutor + orgMemberHooks OrganizationMemberHookExecutor +} + +func NewOrganizationInvitationService( + txRunner organizationInvitationTxRunner, + globalConfig *models.Config, + pluginConfig *types.OrganizationsPluginConfig, + logger models.Logger, + userService rootservices.UserService, + organizationRepo repositories.OrganizationRepository, + orgInvitationRepo repositories.OrganizationInvitationRepository, + orgMemberRepo repositories.OrganizationMemberRepository, + mailerService rootservices.MailerService, + eventBus models.EventBus, + orgInvitationHooks OrganizationInvitationHookExecutor, + orgMemberHooks OrganizationMemberHookExecutor, +) *OrganizationInvitationService { + return &OrganizationInvitationService{ + globalConfig: globalConfig, + pluginConfig: pluginConfig, + logger: logger, + mailerService: mailerService, + eventBus: eventBus, + userService: userService, + organizationRepo: organizationRepo, + orgInvitationRepo: orgInvitationRepo, + orgMemberRepo: orgMemberRepo, + orgInvitationHooks: orgInvitationHooks, + orgMemberHooks: orgMemberHooks, + txRunner: txRunner, + } +} + +func (s *OrganizationInvitationService) CreateOrganizationInvitation(ctx context.Context, actorUserID string, organizationID string, request types.CreateOrganizationInvitationRequest) (*types.OrganizationInvitation, error) { + if actorUserID == "" || organizationID == "" { + return nil, internalerrors.ErrUnauthorized + } + + organization, err := s.organizationRepo.GetByID(ctx, organizationID) + if err != nil { + return nil, err + } + if organization == nil { + return nil, internalerrors.ErrNotFound + } + if organization.OwnerID != actorUserID { + return nil, internalerrors.ErrForbidden + } + + email := strings.ToLower(strings.TrimSpace(request.Email)) + if email == "" { + return nil, internalerrors.ErrUnprocessableEntity + } + if _, err := mail.ParseAddress(email); err != nil { + return nil, internalerrors.ErrUnprocessableEntity + } + + role := strings.TrimSpace(request.Role) + if role == "" { + return nil, internalerrors.ErrUnprocessableEntity + } + + if existing, err := s.orgInvitationRepo.GetByOrganizationIDAndEmail(ctx, organizationID, email); err != nil { + return nil, err + } else if existing != nil { + if err := s.expireOrganizationInvitationIfNeeded(ctx, existing); err != nil { + return nil, err + } + if existing.Status == types.OrganizationInvitationStatusPending { + return nil, internalerrors.ErrConflict + } + } + + expiresAt := time.Now().UTC().Add(s.pluginConfig.InvitationExpiresIn) + if !expiresAt.After(time.Now().UTC()) { + return nil, internalerrors.ErrUnprocessableEntity + } + invitation := &types.OrganizationInvitation{ + ID: util.GenerateUUID(), + Email: email, + InviterID: actorUserID, + OrganizationID: organizationID, + Role: role, + Status: types.OrganizationInvitationStatusPending, + ExpiresAt: expiresAt, + } + + if s.orgInvitationHooks != nil { + if err := s.orgInvitationHooks.BeforeCreateOrganizationInvitation(invitation); err != nil { + return nil, err + } + } + + created, err := s.orgInvitationRepo.Create(ctx, invitation) + if err != nil { + return nil, err + } + + if s.orgInvitationHooks != nil { + if err := s.orgInvitationHooks.AfterCreateOrganizationInvitation(*created); err != nil { + return nil, err + } + } + + s.publishOrganizationInvitationCreatedEvent(created, organization, request.RedirectURL) + s.sendOrganizationInvitationEmailAsync(ctx, created, organization, request.RedirectURL) + + return created, nil +} + +func (s *OrganizationInvitationService) publishOrganizationInvitationCreatedEvent(invitation *types.OrganizationInvitation, organization *types.Organization, redirectURL string) { + payload, err := json.Marshal(orgevents.OrganizationInvitationCreatedEvent{ + ID: util.GenerateUUID(), + InvitationID: invitation.ID, + OrganizationID: invitation.OrganizationID, + OrganizationName: organization.Name, + InviteeEmail: invitation.Email, + InviterID: invitation.InviterID, + Role: invitation.Role, + ExpiresAt: invitation.ExpiresAt, + RedirectURL: redirectURL, + }) + if err != nil { + s.logger.Error("failed to marshal organization invitation created event", "error", err) + return + } + + util.PublishEventAsync(s.eventBus, s.logger, models.Event{ + ID: util.GenerateUUID(), + Type: orgconstants.EventOrganizationsInvitationCreated, + Timestamp: time.Now().UTC(), + Payload: payload, + }) +} + +func (s *OrganizationInvitationService) sendOrganizationInvitationEmailAsync(ctx context.Context, invitation *types.OrganizationInvitation, organization *types.Organization, redirectURL string) { + if s.mailerService == nil { + s.logger.Warn("mailer service not available, skipping organization invitation email") + return + } + + go func() { + detachedCtx := context.WithoutCancel(ctx) + taskCtx, cancel := context.WithTimeout(detachedCtx, 15*time.Second) + defer cancel() + + if err := s.sendOrganizationInvitationEmail(taskCtx, invitation, organization, redirectURL); err != nil { + s.logger.Error("failed to send organization invitation email", "invitation_id", invitation.ID, "error", err) + } + }() +} + +func (s *OrganizationInvitationService) sendOrganizationInvitationEmail(ctx context.Context, invitation *types.OrganizationInvitation, organization *types.Organization, redirectURL string) error { + acceptURL := s.buildOrganizationInvitationAcceptURL(invitation, redirectURL) + appName := "Authula" + if s.globalConfig.AppName != "" { + appName = s.globalConfig.AppName + } + subject := fmt.Sprintf("You're invited to join %s on %s", organization.Name, appName) + textBody := fmt.Sprintf("You have been invited to join %s on %s as %s. Open this link to accept the invitation: %s", organization.Name, appName, invitation.Role, acceptURL) + htmlBody := fmt.Sprintf(` +
+

Hello,

+

You have been invited to join %s on %s as %s.

+

Accept invitation

+

If the button does not work, copy this link:

+

%s

+
`, + html.EscapeString(organization.Name), + html.EscapeString(appName), + html.EscapeString(invitation.Role), + html.EscapeString(acceptURL), + html.EscapeString(acceptURL), + html.EscapeString(acceptURL), + ) + + return s.mailerService.SendEmail(ctx, invitation.Email, subject, textBody, htmlBody) +} + +func (s *OrganizationInvitationService) buildOrganizationInvitationAcceptURL(invitation *types.OrganizationInvitation, redirectURL string) string { + baseURL := s.globalConfig.BaseURL + basePath := s.globalConfig.BasePath + acceptPath := fmt.Sprintf("/organizations/%s/invitations/%s/accept", url.PathEscape(invitation.OrganizationID), url.PathEscape(invitation.ID)) + + fullURL := baseURL + basePath + acceptPath + parsedURL, err := url.Parse(fullURL) + if err != nil { + return fullURL + } + + if redirectURL != "" { + query := parsedURL.Query() + query.Set("redirect_url", redirectURL) + parsedURL.RawQuery = query.Encode() + } + + return parsedURL.String() +} + +func (s *OrganizationInvitationService) GetOrganizationInvitation(ctx context.Context, actorUserID string, organizationID string, invitationID string) (*types.OrganizationInvitation, error) { + if actorUserID == "" || organizationID == "" || invitationID == "" { + return nil, internalerrors.ErrUnauthorized + } + + organization, err := s.organizationRepo.GetByID(ctx, organizationID) + if err != nil { + return nil, err + } + if organization == nil { + return nil, internalerrors.ErrNotFound + } + if organization.OwnerID != actorUserID { + return nil, internalerrors.ErrForbidden + } + + invitation, err := s.orgInvitationRepo.GetByID(ctx, invitationID) + if err != nil { + return nil, err + } + if invitation == nil || invitation.OrganizationID != organizationID { + return nil, internalerrors.ErrNotFound + } + if err := s.expireOrganizationInvitationIfNeeded(ctx, invitation); err != nil { + return nil, err + } + + return invitation, nil +} + +func (s *OrganizationInvitationService) GetAllOrganizationInvitations(ctx context.Context, actorUserID string, organizationID string) ([]types.OrganizationInvitation, error) { + if actorUserID == "" || organizationID == "" { + return nil, internalerrors.ErrUnauthorized + } + + organization, err := s.organizationRepo.GetByID(ctx, organizationID) + if err != nil { + return nil, err + } + if organization == nil { + return nil, internalerrors.ErrNotFound + } + if organization.OwnerID != actorUserID { + return nil, internalerrors.ErrForbidden + } + + invitations, err := s.orgInvitationRepo.GetAllByOrganizationID(ctx, organizationID) + if err != nil { + return nil, err + } + + for index := range invitations { + if err := s.expireOrganizationInvitationIfNeeded(ctx, &invitations[index]); err != nil { + return nil, err + } + } + + return invitations, nil +} + +func (s *OrganizationInvitationService) RevokeOrganizationInvitation(ctx context.Context, actorUserID string, organizationID string, invitationID string) (*types.OrganizationInvitation, error) { + if actorUserID == "" || organizationID == "" || invitationID == "" { + return nil, internalerrors.ErrUnauthorized + } + + organization, err := s.organizationRepo.GetByID(ctx, organizationID) + if err != nil { + return nil, err + } + if organization == nil { + return nil, internalerrors.ErrNotFound + } + if organization.OwnerID != actorUserID { + return nil, internalerrors.ErrForbidden + } + + invitation, err := s.orgInvitationRepo.GetByID(ctx, invitationID) + if err != nil { + return nil, err + } + if invitation == nil || invitation.OrganizationID != organizationID { + return nil, internalerrors.ErrNotFound + } + if err := s.expireOrganizationInvitationIfNeeded(ctx, invitation); err != nil { + return nil, err + } + if invitation.Status != types.OrganizationInvitationStatusPending { + return nil, internalerrors.ErrConflict + } + + invitation.Status = types.OrganizationInvitationStatusRevoked + + if s.orgInvitationHooks != nil { + if err := s.orgInvitationHooks.BeforeUpdateOrganizationInvitation(invitation); err != nil { + return nil, err + } + } + + updated, err := s.orgInvitationRepo.Update(ctx, invitation) + if err != nil { + return nil, err + } + + if s.orgInvitationHooks != nil { + if err := s.orgInvitationHooks.AfterUpdateOrganizationInvitation(*updated); err != nil { + return nil, err + } + } + + return updated, nil +} + +func (s *OrganizationInvitationService) AcceptOrganizationInvitation(ctx context.Context, actorUserID string, organizationID string, invitationID string) (*types.OrganizationInvitation, error) { + if actorUserID == "" || organizationID == "" || invitationID == "" { + return nil, internalerrors.ErrUnauthorized + } + + user, err := s.userService.GetByID(ctx, actorUserID) + if err != nil { + return nil, err + } + if user == nil || strings.TrimSpace(user.Email) == "" { + return nil, internalerrors.ErrNotFound + } + + invitation, err := s.orgInvitationRepo.GetByID(ctx, invitationID) + if err != nil { + return nil, err + } + if invitation == nil || invitation.OrganizationID != organizationID { + return nil, internalerrors.ErrNotFound + } + if err := s.expireOrganizationInvitationIfNeeded(ctx, invitation); err != nil { + return nil, err + } + if invitation.Status != types.OrganizationInvitationStatusPending { + return nil, internalerrors.ErrConflict + } + if !strings.EqualFold(invitation.Email, user.Email) { + return nil, internalerrors.ErrForbidden + } + + accepted, err := s.acceptOrganizationInvitations(ctx, actorUserID, []types.OrganizationInvitation{*invitation}) + if err != nil { + return nil, err + } + if len(accepted) == 0 { + return nil, internalerrors.ErrConflict + } + + return &accepted[0], nil +} + +func (s *OrganizationInvitationService) RejectOrganizationInvitation(ctx context.Context, actorUserID string, organizationID string, invitationID string) (*types.OrganizationInvitation, error) { + if actorUserID == "" || organizationID == "" || invitationID == "" { + return nil, internalerrors.ErrUnauthorized + } + + user, err := s.userService.GetByID(ctx, actorUserID) + if err != nil { + return nil, err + } + if user == nil || strings.TrimSpace(user.Email) == "" { + return nil, internalerrors.ErrNotFound + } + + invitation, err := s.orgInvitationRepo.GetByID(ctx, invitationID) + if err != nil { + return nil, err + } + if invitation == nil || invitation.OrganizationID != organizationID { + return nil, internalerrors.ErrNotFound + } + if err := s.expireOrganizationInvitationIfNeeded(ctx, invitation); err != nil { + return nil, err + } + if invitation.Status != types.OrganizationInvitationStatusPending { + return nil, internalerrors.ErrConflict + } + if !strings.EqualFold(invitation.Email, user.Email) { + return nil, internalerrors.ErrForbidden + } + + invitation.Status = types.OrganizationInvitationStatusRejected + + if s.orgInvitationHooks != nil { + if err := s.orgInvitationHooks.BeforeUpdateOrganizationInvitation(invitation); err != nil { + return nil, err + } + } + + updated, err := s.orgInvitationRepo.Update(ctx, invitation) + if err != nil { + return nil, err + } + + if s.orgInvitationHooks != nil { + if err := s.orgInvitationHooks.AfterUpdateOrganizationInvitation(*updated); err != nil { + return nil, err + } + } + + return updated, nil +} + +func (s *OrganizationInvitationService) AcceptPendingOrganizationInvitationsForEmail(ctx context.Context, userID string, email string) ([]types.OrganizationInvitation, error) { + email = strings.ToLower(email) + if userID == "" || email == "" { + return nil, internalerrors.ErrUnprocessableEntity + } + + if _, err := mail.ParseAddress(email); err != nil { + return nil, internalerrors.ErrBadRequest + } + + pendingInvitations, err := s.orgInvitationRepo.GetAllPendingByEmail(ctx, email) + if err != nil { + return nil, err + } + if len(pendingInvitations) == 0 { + return []types.OrganizationInvitation{}, nil + } + + return s.acceptOrganizationInvitations(ctx, userID, pendingInvitations) +} + +func (s *OrganizationInvitationService) acceptOrganizationInvitations(ctx context.Context, userID string, invitations []types.OrganizationInvitation) ([]types.OrganizationInvitation, error) { + accepted := make([]types.OrganizationInvitation, 0, len(invitations)) + err := s.txRunner.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + invitationRepo := s.orgInvitationRepo.WithTx(tx) + memberRepo := s.orgMemberRepo.WithTx(tx) + + for _, pendingInvitation := range invitations { + invitation := pendingInvitation + + if s.orgInvitationHooks != nil { + if err := s.orgInvitationHooks.BeforeUpdateOrganizationInvitation(&invitation); err != nil { + return err + } + } + + existingMember, err := memberRepo.GetByOrganizationIDAndUserID(ctx, invitation.OrganizationID, userID) + if err != nil { + return err + } + if existingMember == nil { + member := &types.OrganizationMember{ + ID: util.GenerateUUID(), + OrganizationID: invitation.OrganizationID, + UserID: userID, + Role: invitation.Role, + } + + if s.orgMemberHooks != nil { + if err := s.orgMemberHooks.BeforeCreateOrganizationMember(member); err != nil { + return err + } + } + + createdMember, err := memberRepo.Create(ctx, member) + if err != nil { + return err + } + + if s.orgMemberHooks != nil { + if err := s.orgMemberHooks.AfterCreateOrganizationMember(*createdMember); err != nil { + return err + } + } + } + + invitation.Status = types.OrganizationInvitationStatusAccepted + updatedInvitation, err := invitationRepo.Update(ctx, &invitation) + if err != nil { + return err + } + + if s.orgInvitationHooks != nil { + if err := s.orgInvitationHooks.AfterUpdateOrganizationInvitation(*updatedInvitation); err != nil { + return err + } + } + + accepted = append(accepted, *updatedInvitation) + } + + return nil + }) + if err != nil { + return nil, err + } + + return accepted, nil +} + +func (s *OrganizationInvitationService) expireOrganizationInvitationIfNeeded(ctx context.Context, invitation *types.OrganizationInvitation) error { + if invitation.Status != types.OrganizationInvitationStatusPending { + return nil + } + if invitation.ExpiresAt.After(time.Now().UTC()) { + return nil + } + + invitation.Status = types.OrganizationInvitationStatusExpired + + if s.orgInvitationHooks != nil { + if err := s.orgInvitationHooks.BeforeUpdateOrganizationInvitation(invitation); err != nil { + return err + } + } + + updated, err := s.orgInvitationRepo.Update(ctx, invitation) + if err != nil { + return err + } + *invitation = *updated + + if s.orgInvitationHooks != nil { + if err := s.orgInvitationHooks.AfterUpdateOrganizationInvitation(*updated); err != nil { + return err + } + } + + return nil +} diff --git a/plugins/organizations/services/organization_invitation_service_test.go b/plugins/organizations/services/organization_invitation_service_test.go new file mode 100644 index 00000000..a786bd91 --- /dev/null +++ b/plugins/organizations/services/organization_invitation_service_test.go @@ -0,0 +1,703 @@ +package services + +import ( + "context" + "encoding/json" + "errors" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + internalerrors "github.com/Authula/authula/internal/errors" + internaltests "github.com/Authula/authula/internal/tests" + "github.com/Authula/authula/models" + orgconstants "github.com/Authula/authula/plugins/organizations/constants" + orgevents "github.com/Authula/authula/plugins/organizations/events" + "github.com/Authula/authula/plugins/organizations/repositories" + "github.com/Authula/authula/plugins/organizations/types" + rootservices "github.com/Authula/authula/services" +) + +type testInvitationLogger struct { + mu sync.Mutex + warnings []string + errors []string +} + +func (l *testInvitationLogger) Debug(msg string, args ...any) {} +func (l *testInvitationLogger) Info(msg string, args ...any) {} +func (l *testInvitationLogger) Warn(msg string, args ...any) { + l.mu.Lock() + defer l.mu.Unlock() + l.warnings = append(l.warnings, msg) +} +func (l *testInvitationLogger) Error(msg string, args ...any) { + l.mu.Lock() + defer l.mu.Unlock() + l.errors = append(l.errors, msg) +} + +func newTestOrganizationInvitationService( + txRunner organizationInvitationTxRunner, + pluginConfig *types.OrganizationsPluginConfig, + userService rootservices.UserService, + orgRepo repositories.OrganizationRepository, + invRepo repositories.OrganizationInvitationRepository, + memberRepo repositories.OrganizationMemberRepository, + invHooks OrganizationInvitationHookExecutor, + memberHooks OrganizationMemberHookExecutor, +) *OrganizationInvitationService { + return NewOrganizationInvitationService( + txRunner, + &models.Config{BaseURL: "https://example.com", BasePath: "/auth"}, + pluginConfig, + &testInvitationLogger{}, + userService, + orgRepo, + invRepo, + memberRepo, + nil, + nil, + invHooks, + memberHooks, + ) +} + +type invitationEmailCall struct { + to string + subject string + text string + html string +} + +type capturingMailer struct { + called chan invitationEmailCall + err error +} + +func (m *capturingMailer) SendEmail(ctx context.Context, to string, subject string, text string, html string) error { + if m.called != nil { + m.called <- invitationEmailCall{to: to, subject: subject, text: text, html: html} + } + return m.err +} + +type capturingEventBus struct { + called chan models.Event + err error +} + +func (b *capturingEventBus) Publish(ctx context.Context, event models.Event) error { + if b.called != nil { + b.called <- event + } + return b.err +} + +func (b *capturingEventBus) Close() error { return nil } +func (b *capturingEventBus) Subscribe(topic string, handler models.EventHandler) (models.SubscriptionID, error) { + return 0, nil +} +func (b *capturingEventBus) Unsubscribe(topic string, id models.SubscriptionID) {} + +func TestOrganizationInvitationService_CreateOrganizationInvitation(t *testing.T) { + t.Parallel() + + someErr := errors.New("some error") + afterErr := errors.New("after error") + + tests := []struct { + name string + actorUserID string + organizationID string + request types.CreateOrganizationInvitationRequest + invitationExpiresIn time.Duration + setup func(*mockOrganizationRepository, *mockOrganizationInvitationRepository, *testOrganizationInvitationHooks) + expectErr error + expectCalled bool + }{ + { + name: "unauthorized", + actorUserID: "", + organizationID: "org-1", + request: types.CreateOrganizationInvitationRequest{Email: "user@example.com", Role: "member"}, + expectErr: internalerrors.ErrUnauthorized, + }, + { + name: "bad request empty email", + actorUserID: "user-1", + organizationID: "org-1", + request: types.CreateOrganizationInvitationRequest{Email: " ", Role: "member"}, + setup: func(orgRepo *mockOrganizationRepository, invRepo *mockOrganizationInvitationRepository, hooks *testOrganizationInvitationHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + expectErr: internalerrors.ErrUnprocessableEntity, + }, + { + name: "forbidden for non owner", + actorUserID: "user-1", + organizationID: "org-1", + request: types.CreateOrganizationInvitationRequest{Email: "user@example.com", Role: "member"}, + setup: func(orgRepo *mockOrganizationRepository, invRepo *mockOrganizationInvitationRepository, hooks *testOrganizationInvitationHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "owner-1"}, nil).Once() + }, + expectErr: internalerrors.ErrForbidden, + }, + { + name: "success", + actorUserID: "user-1", + organizationID: "org-1", + invitationExpiresIn: 36 * time.Hour, + request: types.CreateOrganizationInvitationRequest{Email: "user@example.com", Role: "member"}, + setup: func(orgRepo *mockOrganizationRepository, invRepo *mockOrganizationInvitationRepository, hooks *testOrganizationInvitationHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + invRepo.On("GetByOrganizationIDAndEmail", mock.Anything, "org-1", "user@example.com").Return(nil, nil).Once() + expectedExpiresAt := time.Now().UTC().Add(36 * time.Hour) + invRepo.On("Create", mock.Anything, mock.MatchedBy(func(inv *types.OrganizationInvitation) bool { + return inv != nil && inv.OrganizationID == "org-1" && inv.InviterID == "user-1" && inv.Email == "user@example.com" && inv.Role == "member" && inv.Status == types.OrganizationInvitationStatusPending && inv.ExpiresAt.After(expectedExpiresAt.Add(-2*time.Second)) && inv.ExpiresAt.Before(expectedExpiresAt.Add(2*time.Second)) + })).Return(&types.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", InviterID: "user-1", Email: "user@example.com", Role: "member", Status: types.OrganizationInvitationStatusPending, ExpiresAt: expectedExpiresAt}, nil).Once() + hooks.Before = func(invitation *types.OrganizationInvitation) error { + require.Equal(t, "user@example.com", invitation.Email) + return nil + } + hooks.After = func(invitation types.OrganizationInvitation) error { + require.Equal(t, "inv-1", invitation.ID) + return nil + } + }, + expectCalled: true, + }, + { + name: "before hook blocks", + actorUserID: "user-1", + organizationID: "org-1", + request: types.CreateOrganizationInvitationRequest{Email: "user@example.com", Role: "member"}, + setup: func(orgRepo *mockOrganizationRepository, invRepo *mockOrganizationInvitationRepository, hooks *testOrganizationInvitationHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + invRepo.On("GetByOrganizationIDAndEmail", mock.Anything, "org-1", "user@example.com").Return(nil, nil).Once() + hooks.Before = func(invitation *types.OrganizationInvitation) error { return someErr } + }, + expectErr: someErr, + }, + { + name: "after hook error", + actorUserID: "user-1", + organizationID: "org-1", + request: types.CreateOrganizationInvitationRequest{Email: "user@example.com", Role: "member"}, + setup: func(orgRepo *mockOrganizationRepository, invRepo *mockOrganizationInvitationRepository, hooks *testOrganizationInvitationHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + invRepo.On("GetByOrganizationIDAndEmail", mock.Anything, "org-1", "user@example.com").Return(nil, nil).Once() + invRepo.On("Create", mock.Anything, mock.AnythingOfType("*types.OrganizationInvitation")).Return(&types.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", InviterID: "user-1", Email: "user@example.com", Role: "member", Status: types.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + hooks.After = func(invitation types.OrganizationInvitation) error { return afterErr } + }, + expectErr: afterErr, + expectCalled: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + pluginConfig := &types.OrganizationsPluginConfig{ + Enabled: true, + InvitationExpiresIn: 7 * 24 * time.Hour, + } + if tt.invitationExpiresIn != 0 { + pluginConfig.InvitationExpiresIn = tt.invitationExpiresIn + } + orgRepo := &mockOrganizationRepository{} + orgInvitationRepo := &mockOrganizationInvitationRepository{} + orgInvitationHooks := &testOrganizationInvitationHooks{} + if tt.setup != nil { + tt.setup(orgRepo, orgInvitationRepo, orgInvitationHooks) + } + + svc := newTestOrganizationInvitationService(&mockOrganizationInvitationTxRunner{}, pluginConfig, &internaltests.MockUserService{}, orgRepo, orgInvitationRepo, nil, orgInvitationHooks, nil) + inv, err := svc.CreateOrganizationInvitation(context.Background(), tt.actorUserID, tt.organizationID, tt.request) + if tt.expectErr != nil { + require.Error(t, err) + require.ErrorIs(t, err, tt.expectErr) + if tt.expectCalled { + require.True(t, orgInvitationRepo.AssertExpectations(t)) + require.True(t, orgRepo.AssertExpectations(t)) + } + return + } + require.NoError(t, err) + require.NotNil(t, inv) + require.WithinDuration(t, time.Now().UTC().Add(tt.invitationExpiresIn), inv.ExpiresAt, 2*time.Second) + require.Equal(t, tt.expectCalled, orgInvitationRepo.AssertExpectations(t)) + require.Equal(t, tt.expectCalled, orgRepo.AssertExpectations(t)) + }) + } +} + +func TestOrganizationInvitationService_GetOrganizationInvitation(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + actorUserID string + organizationID string + invitationID string + setup func(*mockOrganizationRepository, *mockOrganizationInvitationRepository) + expectErr error + expectID string + }{ + { + name: "success", + actorUserID: "user-1", + organizationID: "org-1", + invitationID: "inv-1", + setup: func(orgRepo *mockOrganizationRepository, invRepo *mockOrganizationInvitationRepository) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + invRepo.On("GetByID", mock.Anything, "inv-1").Return(&types.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: types.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + }, + expectID: "inv-1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + pluginConfig := &types.OrganizationsPluginConfig{ + Enabled: true, + InvitationExpiresIn: 7 * 24 * time.Hour, + } + orgRepo := &mockOrganizationRepository{} + invRepo := &mockOrganizationInvitationRepository{} + if tt.setup != nil { + tt.setup(orgRepo, invRepo) + } + + svc := newTestOrganizationInvitationService(&mockOrganizationInvitationTxRunner{}, pluginConfig, &internaltests.MockUserService{}, orgRepo, invRepo, nil, nil, nil) + invitation, err := svc.GetOrganizationInvitation(context.Background(), tt.actorUserID, tt.organizationID, tt.invitationID) + if tt.expectErr != nil { + require.Error(t, err) + require.ErrorIs(t, err, tt.expectErr) + return + } + require.NoError(t, err) + require.NotNil(t, invitation) + require.Equal(t, tt.expectID, invitation.ID) + require.True(t, orgRepo.AssertExpectations(t)) + require.True(t, invRepo.AssertExpectations(t)) + }) + } +} + +func TestOrganizationInvitationService_GetAllOrganizationInvitations(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + actorUserID string + organizationID string + setup func(*mockOrganizationRepository, *mockOrganizationInvitationRepository) + expectErr error + expectLen int + }{ + { + name: "success", + actorUserID: "user-1", + organizationID: "org-1", + setup: func(orgRepo *mockOrganizationRepository, invRepo *mockOrganizationInvitationRepository) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + invRepo.On("GetAllByOrganizationID", mock.Anything, "org-1").Return([]types.OrganizationInvitation{{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: types.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(time.Hour)}}, nil).Once() + }, + expectLen: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + pluginConfig := &types.OrganizationsPluginConfig{ + Enabled: true, + InvitationExpiresIn: 7 * 24 * time.Hour, + } + orgRepo := &mockOrganizationRepository{} + invRepo := &mockOrganizationInvitationRepository{} + if tt.setup != nil { + tt.setup(orgRepo, invRepo) + } + + svc := newTestOrganizationInvitationService(&mockOrganizationInvitationTxRunner{}, pluginConfig, &internaltests.MockUserService{}, orgRepo, invRepo, nil, nil, nil) + invitations, err := svc.GetAllOrganizationInvitations(context.Background(), tt.actorUserID, tt.organizationID) + if tt.expectErr != nil { + require.Error(t, err) + require.ErrorIs(t, err, tt.expectErr) + return + } + require.NoError(t, err) + require.Len(t, invitations, tt.expectLen) + require.True(t, orgRepo.AssertExpectations(t)) + require.True(t, invRepo.AssertExpectations(t)) + }) + } +} + +func TestOrganizationInvitationService_RevokeOrganizationInvitation(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + actorUserID string + organizationID string + invitationID string + setup func(*mockOrganizationRepository, *mockOrganizationInvitationRepository, *testOrganizationInvitationHooks) + expectErr error + expectStatus types.OrganizationInvitationStatus + }{ + { + name: "success", + actorUserID: "user-1", + organizationID: "org-1", + invitationID: "inv-1", + setup: func(orgRepo *mockOrganizationRepository, invRepo *mockOrganizationInvitationRepository, hooks *testOrganizationInvitationHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + invRepo.On("GetByID", mock.Anything, "inv-1").Return(&types.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Status: types.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + invRepo.On("Update", mock.Anything, mock.MatchedBy(func(invitation *types.OrganizationInvitation) bool { + return invitation != nil && invitation.Status == types.OrganizationInvitationStatusRevoked + })).Return(&types.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Status: types.OrganizationInvitationStatusRevoked}, nil).Once() + hooks.BeforeUpdate = func(invitation *types.OrganizationInvitation) error { + require.Equal(t, types.OrganizationInvitationStatusRevoked, invitation.Status) + return nil + } + hooks.AfterUpdate = func(invitation types.OrganizationInvitation) error { + require.Equal(t, types.OrganizationInvitationStatusRevoked, invitation.Status) + return nil + } + }, + expectStatus: types.OrganizationInvitationStatusRevoked, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + pluginConfig := &types.OrganizationsPluginConfig{ + Enabled: true, + InvitationExpiresIn: 7 * 24 * time.Hour, + } + orgRepo := &mockOrganizationRepository{} + invRepo := &mockOrganizationInvitationRepository{} + hooks := &testOrganizationInvitationHooks{} + if tt.setup != nil { + tt.setup(orgRepo, invRepo, hooks) + } + + svc := newTestOrganizationInvitationService(&mockOrganizationInvitationTxRunner{}, pluginConfig, &internaltests.MockUserService{}, orgRepo, invRepo, nil, hooks, nil) + invitation, err := svc.RevokeOrganizationInvitation(context.Background(), tt.actorUserID, tt.organizationID, tt.invitationID) + if tt.expectErr != nil { + require.Error(t, err) + require.ErrorIs(t, err, tt.expectErr) + return + } + require.NoError(t, err) + require.NotNil(t, invitation) + require.Equal(t, tt.expectStatus, invitation.Status) + require.True(t, orgRepo.AssertExpectations(t)) + require.True(t, invRepo.AssertExpectations(t)) + }) + } +} + +func TestOrganizationInvitationService_AcceptPendingOrganizationInvitationsForEmail(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + userID string + email string + setup func(*mockOrganizationRepository, *mockOrganizationInvitationRepository, *mockOrganizationMemberRepository, *testOrganizationInvitationHooks, *mockOrganizationMemberHooks) + expectErr error + expectLen int + }{ + { + name: "bad request invalid email", + userID: "user-2", + email: "not-an-email", + expectErr: internalerrors.ErrBadRequest, + }, + { + name: "success", + userID: "user-2", + email: "USER@EXAMPLE.COM", + setup: func(orgRepo *mockOrganizationRepository, invRepo *mockOrganizationInvitationRepository, memberRepo *mockOrganizationMemberRepository, hooks *testOrganizationInvitationHooks, memberHooks *mockOrganizationMemberHooks) { + invRepo.On("GetAllPendingByEmail", mock.Anything, "user@example.com").Return([]types.OrganizationInvitation{{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: types.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(time.Hour)}}, nil).Once() + memberRepo.On("GetByOrganizationIDAndUserID", mock.Anything, "org-1", "user-2").Return(nil, nil).Once() + memberRepo.On("Create", mock.Anything, mock.MatchedBy(func(member *types.OrganizationMember) bool { + return member != nil && member.OrganizationID == "org-1" && member.UserID == "user-2" && member.Role == "member" + })).Return(&types.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}, nil).Once() + invRepo.On("Update", mock.Anything, mock.MatchedBy(func(invitation *types.OrganizationInvitation) bool { + return invitation != nil && invitation.ID == "inv-1" && invitation.Status == types.OrganizationInvitationStatusAccepted + })).Return(&types.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: types.OrganizationInvitationStatusAccepted}, nil).Once() + hooks.BeforeUpdate = func(invitation *types.OrganizationInvitation) error { + require.Equal(t, types.OrganizationInvitationStatusPending, invitation.Status) + return nil + } + hooks.AfterUpdate = func(invitation types.OrganizationInvitation) error { + require.Equal(t, types.OrganizationInvitationStatusAccepted, invitation.Status) + return nil + } + memberHooks.Before = func(member *types.OrganizationMember) error { + require.Equal(t, "org-1", member.OrganizationID) + require.Equal(t, "user-2", member.UserID) + return nil + } + memberHooks.After = func(member types.OrganizationMember) error { + require.Equal(t, "mem-1", member.ID) + return nil + } + }, + expectLen: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + pluginConfig := &types.OrganizationsPluginConfig{ + Enabled: true, + InvitationExpiresIn: 7 * 24 * time.Hour, + } + orgRepo := &mockOrganizationRepository{} + invRepo := &mockOrganizationInvitationRepository{} + memberRepo := &mockOrganizationMemberRepository{} + hooks := &testOrganizationInvitationHooks{} + memberHooks := &mockOrganizationMemberHooks{} + userSvc := &internaltests.MockUserService{} + if tt.setup != nil { + tt.setup(orgRepo, invRepo, memberRepo, hooks, memberHooks) + } + + txRunner := &mockOrganizationInvitationTxRunner{} + svc := newTestOrganizationInvitationService(txRunner, pluginConfig, userSvc, orgRepo, invRepo, memberRepo, hooks, memberHooks) + accepted, err := svc.AcceptPendingOrganizationInvitationsForEmail(context.Background(), tt.userID, tt.email) + if tt.expectErr != nil { + require.Error(t, err) + require.ErrorContains(t, err, tt.expectErr.Error()) + return + } + require.NoError(t, err) + require.Len(t, accepted, tt.expectLen) + }) + } +} + +func TestOrganizationInvitationService_AcceptOrganizationInvitation(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + actorUserID string + organization string + invitationID string + setup func(*internaltests.MockUserService, *mockOrganizationInvitationRepository, *mockOrganizationMemberRepository, *testOrganizationInvitationHooks, *mockOrganizationMemberHooks) + expectErr error + expectStatus types.OrganizationInvitationStatus + }{ + { + name: "success", + actorUserID: "user-2", + organization: "org-1", + invitationID: "inv-1", + setup: func(userSvc *internaltests.MockUserService, invRepo *mockOrganizationInvitationRepository, memberRepo *mockOrganizationMemberRepository, hooks *testOrganizationInvitationHooks, memberHooks *mockOrganizationMemberHooks) { + userSvc.On("GetByID", mock.Anything, "user-2").Return(&models.User{ID: "user-2", Email: "user@example.com"}, nil).Once() + invRepo.On("GetByID", mock.Anything, "inv-1").Return(&types.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: types.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + memberRepo.On("GetByOrganizationIDAndUserID", mock.Anything, "org-1", "user-2").Return(nil, nil).Once() + memberRepo.On("Create", mock.Anything, mock.AnythingOfType("*types.OrganizationMember")).Return(&types.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}, nil).Once() + invRepo.On("Update", mock.Anything, mock.MatchedBy(func(invitation *types.OrganizationInvitation) bool { + return invitation != nil && invitation.ID == "inv-1" && invitation.Status == types.OrganizationInvitationStatusAccepted + })).Return(&types.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: types.OrganizationInvitationStatusAccepted}, nil).Once() + hooks.BeforeUpdate = func(invitation *types.OrganizationInvitation) error { return nil } + hooks.AfterUpdate = func(invitation types.OrganizationInvitation) error { return nil } + memberHooks.Before = func(member *types.OrganizationMember) error { return nil } + memberHooks.After = func(member types.OrganizationMember) error { return nil } + }, + expectStatus: types.OrganizationInvitationStatusAccepted, + }, + { + name: "forbidden when emails differ", + actorUserID: "user-2", + organization: "org-1", + invitationID: "inv-1", + setup: func(userSvc *internaltests.MockUserService, invRepo *mockOrganizationInvitationRepository, memberRepo *mockOrganizationMemberRepository, hooks *testOrganizationInvitationHooks, memberHooks *mockOrganizationMemberHooks) { + userSvc.On("GetByID", mock.Anything, "user-2").Return(&models.User{ID: "user-2", Email: "other@example.com"}, nil).Once() + invRepo.On("GetByID", mock.Anything, "inv-1").Return(&types.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: types.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + }, + expectErr: internalerrors.ErrForbidden, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + pluginConfig := &types.OrganizationsPluginConfig{ + Enabled: true, + InvitationExpiresIn: 7 * 24 * time.Hour, + } + orgRepo := &mockOrganizationRepository{} + invRepo := &mockOrganizationInvitationRepository{} + memberRepo := &mockOrganizationMemberRepository{} + hooks := &testOrganizationInvitationHooks{} + memberHooks := &mockOrganizationMemberHooks{} + userSvc := &internaltests.MockUserService{} + if tt.setup != nil { + tt.setup(userSvc, invRepo, memberRepo, hooks, memberHooks) + } + + svc := newTestOrganizationInvitationService(&mockOrganizationInvitationTxRunner{}, pluginConfig, userSvc, orgRepo, invRepo, memberRepo, hooks, memberHooks) + invitation, err := svc.AcceptOrganizationInvitation(context.Background(), tt.actorUserID, tt.organization, tt.invitationID) + if tt.expectErr != nil { + require.ErrorIs(t, err, tt.expectErr) + return + } + require.NoError(t, err) + require.NotNil(t, invitation) + require.Equal(t, tt.expectStatus, invitation.Status) + }) + } +} + +func TestOrganizationInvitationService_RejectOrganizationInvitation(t *testing.T) { + t.Parallel() + + pluginConfig := &types.OrganizationsPluginConfig{ + Enabled: true, + InvitationExpiresIn: 7 * 24 * time.Hour, + } + userSvc := &internaltests.MockUserService{} + invRepo := &mockOrganizationInvitationRepository{} + userSvc.On("GetByID", mock.Anything, "user-2").Return(&models.User{ID: "user-2", Email: "user@example.com"}, nil).Once() + invRepo.On("GetByID", mock.Anything, "inv-1").Return(&types.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: types.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + invRepo.On("Update", mock.Anything, mock.MatchedBy(func(invitation *types.OrganizationInvitation) bool { + return invitation != nil && invitation.Status == types.OrganizationInvitationStatusRejected + })).Return(&types.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Status: types.OrganizationInvitationStatusRejected}, nil).Once() + + svc := newTestOrganizationInvitationService(&mockOrganizationInvitationTxRunner{}, pluginConfig, userSvc, &mockOrganizationRepository{}, invRepo, &mockOrganizationMemberRepository{}, &testOrganizationInvitationHooks{}, &mockOrganizationMemberHooks{}) + invitation, err := svc.RejectOrganizationInvitation(context.Background(), "user-2", "org-1", "inv-1") + require.NoError(t, err) + require.NotNil(t, invitation) + require.Equal(t, types.OrganizationInvitationStatusRejected, invitation.Status) +} + +func TestOrganizationInvitationService_GetOrganizationInvitation_ExpiresPendingInvitation(t *testing.T) { + t.Parallel() + + pluginConfig := &types.OrganizationsPluginConfig{ + Enabled: true, + InvitationExpiresIn: 7 * 24 * time.Hour, + } + orgRepo := &mockOrganizationRepository{} + invRepo := &mockOrganizationInvitationRepository{} + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + invRepo.On("GetByID", mock.Anything, "inv-1").Return(&types.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: types.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(-time.Hour)}, nil).Once() + invRepo.On("Update", mock.Anything, mock.MatchedBy(func(invitation *types.OrganizationInvitation) bool { + return invitation != nil && invitation.Status == types.OrganizationInvitationStatusExpired + })).Return(&types.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Status: types.OrganizationInvitationStatusExpired}, nil).Once() + + svc := newTestOrganizationInvitationService(&mockOrganizationInvitationTxRunner{}, pluginConfig, &internaltests.MockUserService{}, orgRepo, invRepo, nil, &testOrganizationInvitationHooks{}, nil) + invitation, err := svc.GetOrganizationInvitation(context.Background(), "user-1", "org-1", "inv-1") + require.NoError(t, err) + require.NotNil(t, invitation) + require.Equal(t, types.OrganizationInvitationStatusExpired, invitation.Status) +} + +func TestOrganizationInvitationService_CreateOrganizationInvitation_SendsEmailAndPublishesEvent(t *testing.T) { + t.Parallel() + + orgRepo := &mockOrganizationRepository{} + invRepo := &mockOrganizationInvitationRepository{} + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1", Name: "Acme"}, nil).Once() + invRepo.On("GetByOrganizationIDAndEmail", mock.Anything, "org-1", "user@example.com").Return(nil, nil).Once() + invRepo.On("Create", mock.Anything, mock.MatchedBy(func(inv *types.OrganizationInvitation) bool { + return inv != nil && inv.OrganizationID == "org-1" && inv.InviterID == "user-1" && inv.Email == "user@example.com" && inv.Role == "member" && inv.Status == types.OrganizationInvitationStatusPending + })).Return(&types.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", InviterID: "user-1", Email: "user@example.com", Role: "member", Status: types.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(36 * time.Hour)}, nil).Once() + + mailerCalls := make(chan invitationEmailCall, 1) + eventCalls := make(chan models.Event, 1) + mailer := &capturingMailer{called: mailerCalls} + eventBus := &capturingEventBus{called: eventCalls} + logger := &testInvitationLogger{} + + svc := NewOrganizationInvitationService( + &mockOrganizationInvitationTxRunner{}, + &models.Config{BaseURL: "https://example.com", BasePath: "/auth"}, + &types.OrganizationsPluginConfig{Enabled: true, InvitationExpiresIn: 36 * time.Hour}, + logger, + &internaltests.MockUserService{}, + orgRepo, + invRepo, + nil, + mailer, + eventBus, + nil, + nil, + ) + + invitation, err := svc.CreateOrganizationInvitation(context.Background(), "user-1", "org-1", types.CreateOrganizationInvitationRequest{Email: "user@example.com", Role: "member", RedirectURL: "https://app.example.com/welcome"}) + require.NoError(t, err) + require.NotNil(t, invitation) + + require.Eventually(t, func() bool { return len(mailerCalls) > 0 }, time.Second, 10*time.Millisecond) + require.Eventually(t, func() bool { return len(eventCalls) > 0 }, time.Second, 10*time.Millisecond) + + mailCall := <-mailerCalls + require.Equal(t, "user@example.com", mailCall.to) + require.Contains(t, mailCall.text, "https://example.com/auth/organizations/org-1/invitations/inv-1/accept?redirect_url=https%3A%2F%2Fapp.example.com%2Fwelcome") + require.Contains(t, mailCall.html, "Accept invitation") + + event := <-eventCalls + require.Equal(t, orgconstants.EventOrganizationsInvitationCreated, event.Type) + + var payload orgevents.OrganizationInvitationCreatedEvent + require.NoError(t, json.Unmarshal(event.Payload, &payload)) + require.Equal(t, "inv-1", payload.InvitationID) + require.Equal(t, "org-1", payload.OrganizationID) + require.Equal(t, "Acme", payload.OrganizationName) + require.Equal(t, "user@example.com", payload.InviteeEmail) + require.Equal(t, "https://app.example.com/welcome", payload.RedirectURL) + require.Empty(t, logger.warnings) +} + +func TestOrganizationInvitationService_CreateOrganizationInvitation_SkipsMissingMailer(t *testing.T) { + t.Parallel() + + orgRepo := &mockOrganizationRepository{} + invRepo := &mockOrganizationInvitationRepository{} + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1", Name: "Acme"}, nil).Once() + invRepo.On("GetByOrganizationIDAndEmail", mock.Anything, "org-1", "user@example.com").Return(nil, nil).Once() + invRepo.On("Create", mock.Anything, mock.AnythingOfType("*types.OrganizationInvitation")).Return(&types.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", InviterID: "user-1", Email: "user@example.com", Role: "member", Status: types.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(36 * time.Hour)}, nil).Once() + + logger := &testInvitationLogger{} + svc := NewOrganizationInvitationService( + &mockOrganizationInvitationTxRunner{}, + &models.Config{BaseURL: "https://example.com", BasePath: "/auth"}, + &types.OrganizationsPluginConfig{Enabled: true, InvitationExpiresIn: 36 * time.Hour}, + logger, + &internaltests.MockUserService{}, + orgRepo, + invRepo, + nil, + nil, + nil, + nil, + nil, + ) + + _, err := svc.CreateOrganizationInvitation(context.Background(), "user-1", "org-1", types.CreateOrganizationInvitationRequest{Email: "user@example.com", Role: "member"}) + require.NoError(t, err) + require.Eventually(t, func() bool { return len(logger.warnings) > 0 }, time.Second, 10*time.Millisecond) + require.Contains(t, logger.warnings[0], "mailer service not available") +} diff --git a/plugins/organizations/services/organization_member_service.go b/plugins/organizations/services/organization_member_service.go new file mode 100644 index 00000000..fcc42423 --- /dev/null +++ b/plugins/organizations/services/organization_member_service.go @@ -0,0 +1,209 @@ +package services + +import ( + "context" + "strings" + + internalerrors "github.com/Authula/authula/internal/errors" + "github.com/Authula/authula/internal/util" + "github.com/Authula/authula/plugins/organizations/repositories" + "github.com/Authula/authula/plugins/organizations/types" + rootservices "github.com/Authula/authula/services" +) + +type OrganizationMemberHookExecutor interface { + BeforeCreateOrganizationMember(member *types.OrganizationMember) error + AfterCreateOrganizationMember(member types.OrganizationMember) error + BeforeUpdateOrganizationMember(member *types.OrganizationMember) error + AfterUpdateOrganizationMember(member types.OrganizationMember) error + BeforeDeleteOrganizationMember(member *types.OrganizationMember) error + AfterDeleteOrganizationMember(member types.OrganizationMember) error +} + +type OrganizationMemberService struct { + userService rootservices.UserService + orgRepo repositories.OrganizationRepository + orgMemberRepo repositories.OrganizationMemberRepository + hooks OrganizationMemberHookExecutor +} + +func NewOrganizationMemberService(userService rootservices.UserService, orgRepo repositories.OrganizationRepository, orgMemberRepo repositories.OrganizationMemberRepository, hooks OrganizationMemberHookExecutor) *OrganizationMemberService { + return &OrganizationMemberService{userService: userService, orgRepo: orgRepo, orgMemberRepo: orgMemberRepo, hooks: hooks} +} + +func (s *OrganizationMemberService) AddMember(ctx context.Context, actorUserID string, organizationID string, request types.AddOrganizationMemberRequest) (*types.OrganizationMember, error) { + if _, err := s.authorizeOrganizationOwner(ctx, actorUserID, organizationID); err != nil { + return nil, err + } + + userID := strings.TrimSpace(request.UserID) + if userID == "" { + return nil, internalerrors.ErrUnprocessableEntity + } + + role := strings.TrimSpace(request.Role) + if role == "" { + return nil, internalerrors.ErrUnprocessableEntity + } + + user, err := s.userService.GetByID(ctx, userID) + if err != nil { + return nil, err + } + if user == nil { + return nil, internalerrors.ErrNotFound + } + + if existing, err := s.orgMemberRepo.GetByOrganizationIDAndUserID(ctx, organizationID, userID); err != nil { + return nil, err + } else if existing != nil { + return nil, internalerrors.ErrConflict + } + + member := &types.OrganizationMember{ + ID: util.GenerateUUID(), + OrganizationID: organizationID, + UserID: userID, + Role: role, + } + + if s.hooks != nil { + if err := s.hooks.BeforeCreateOrganizationMember(member); err != nil { + return nil, err + } + } + + created, err := s.orgMemberRepo.Create(ctx, member) + if err != nil { + return nil, err + } + + if s.hooks != nil { + if err := s.hooks.AfterCreateOrganizationMember(*created); err != nil { + return nil, err + } + } + + return created, nil +} + +func (s *OrganizationMemberService) GetAllMembers(ctx context.Context, actorUserID string, organizationID string, page int, limit int) ([]types.OrganizationMember, error) { + if _, err := s.authorizeOrganizationOwner(ctx, actorUserID, organizationID); err != nil { + return nil, err + } + + return s.orgMemberRepo.GetAllByOrganizationID(ctx, organizationID, page, limit) +} + +func (s *OrganizationMemberService) GetMember(ctx context.Context, actorUserID string, organizationID string, memberID string) (*types.OrganizationMember, error) { + if _, err := s.authorizeOrganizationOwner(ctx, actorUserID, organizationID); err != nil { + return nil, err + } + + if memberID == "" { + return nil, internalerrors.ErrUnauthorized + } + + member, err := s.orgMemberRepo.GetByID(ctx, strings.TrimSpace(memberID)) + if err != nil { + return nil, err + } + if member == nil || member.OrganizationID != strings.TrimSpace(organizationID) { + return nil, internalerrors.ErrNotFound + } + + return member, nil +} + +func (s *OrganizationMemberService) UpdateMember(ctx context.Context, actorUserID string, organizationID string, memberID string, request types.UpdateOrganizationMemberRequest) (*types.OrganizationMember, error) { + if _, err := s.authorizeOrganizationOwner(ctx, actorUserID, organizationID); err != nil { + return nil, err + } + + member, err := s.orgMemberRepo.GetByID(ctx, strings.TrimSpace(memberID)) + if err != nil { + return nil, err + } + if member == nil || member.OrganizationID != strings.TrimSpace(organizationID) { + return nil, internalerrors.ErrNotFound + } + + role := strings.TrimSpace(request.Role) + if role == "" { + return nil, internalerrors.ErrBadRequest + } + + member.Role = role + + if s.hooks != nil { + if err := s.hooks.BeforeUpdateOrganizationMember(member); err != nil { + return nil, err + } + } + + updated, err := s.orgMemberRepo.Update(ctx, member) + if err != nil { + return nil, err + } + + if s.hooks != nil { + if err := s.hooks.AfterUpdateOrganizationMember(*updated); err != nil { + return nil, err + } + } + + return updated, nil +} + +func (s *OrganizationMemberService) RemoveMember(ctx context.Context, actorUserID string, organizationID string, memberID string) error { + if _, err := s.authorizeOrganizationOwner(ctx, actorUserID, organizationID); err != nil { + return err + } + + member, err := s.orgMemberRepo.GetByID(ctx, strings.TrimSpace(memberID)) + if err != nil { + return err + } + if member == nil || member.OrganizationID != strings.TrimSpace(organizationID) { + return internalerrors.ErrNotFound + } + + if s.hooks != nil { + if err := s.hooks.BeforeDeleteOrganizationMember(member); err != nil { + return err + } + } + + if err := s.orgMemberRepo.Delete(ctx, member.ID); err != nil { + return err + } + + if s.hooks != nil { + if err := s.hooks.AfterDeleteOrganizationMember(*member); err != nil { + return err + } + } + + return nil +} + +func (s *OrganizationMemberService) authorizeOrganizationOwner(ctx context.Context, actorUserID string, organizationID string) (*types.Organization, error) { + actorUserID = strings.TrimSpace(actorUserID) + organizationID = strings.TrimSpace(organizationID) + if actorUserID == "" || organizationID == "" { + return nil, internalerrors.ErrUnauthorized + } + + organization, err := s.orgRepo.GetByID(ctx, organizationID) + if err != nil { + return nil, err + } + if organization == nil { + return nil, internalerrors.ErrNotFound + } + if strings.TrimSpace(organization.OwnerID) != actorUserID { + return nil, internalerrors.ErrForbidden + } + + return organization, nil +} diff --git a/plugins/organizations/services/organization_member_service_test.go b/plugins/organizations/services/organization_member_service_test.go new file mode 100644 index 00000000..c1cc1a7f --- /dev/null +++ b/plugins/organizations/services/organization_member_service_test.go @@ -0,0 +1,800 @@ +package services + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + internalerrors "github.com/Authula/authula/internal/errors" + internaltests "github.com/Authula/authula/internal/tests" + "github.com/Authula/authula/models" + "github.com/Authula/authula/plugins/organizations/types" +) + +func TestOrganizationMemberService_AddMember(t *testing.T) { + t.Parallel() + + someErr := errors.New("some error") + afterErr := errors.New("after error") + repoErr := errors.New("repository error") + + tests := []struct { + name string + actorUserID string + organizationID string + request types.AddOrganizationMemberRequest + setup func(*mockOrganizationRepository, *mockOrganizationMemberRepository, *internaltests.MockUserService, *mockOrganizationMemberHooks) + expectErr error + }{ + { + name: "unauthorized", + actorUserID: "", + organizationID: "org-1", + request: types.AddOrganizationMemberRequest{UserID: "user-2", Role: "member"}, + expectErr: internalerrors.ErrUnauthorized, + }, + { + name: "organization not found", + actorUserID: "user-1", + organizationID: "org-1", + request: types.AddOrganizationMemberRequest{UserID: "user-2", Role: "member"}, + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository, userSvc *internaltests.MockUserService, hooks *mockOrganizationMemberHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(nil, nil).Once() + }, + expectErr: internalerrors.ErrNotFound, + }, + { + name: "forbidden for non owner", + actorUserID: "user-1", + organizationID: "org-1", + request: types.AddOrganizationMemberRequest{UserID: "user-2", Role: "member"}, + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository, userSvc *internaltests.MockUserService, hooks *mockOrganizationMemberHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "owner-1"}, nil).Once() + }, + expectErr: internalerrors.ErrForbidden, + }, + { + name: "bad request empty user id", + actorUserID: "user-1", + organizationID: "org-1", + request: types.AddOrganizationMemberRequest{UserID: " ", Role: "member"}, + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository, userSvc *internaltests.MockUserService, hooks *mockOrganizationMemberHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + expectErr: internalerrors.ErrUnprocessableEntity, + }, + { + name: "bad request empty role", + actorUserID: "user-1", + organizationID: "org-1", + request: types.AddOrganizationMemberRequest{UserID: "user-2", Role: " "}, + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository, userSvc *internaltests.MockUserService, hooks *mockOrganizationMemberHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + expectErr: internalerrors.ErrUnprocessableEntity, + }, + { + name: "user lookup error", + actorUserID: "user-1", + organizationID: "org-1", + request: types.AddOrganizationMemberRequest{UserID: "user-2", Role: "member"}, + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository, userSvc *internaltests.MockUserService, hooks *mockOrganizationMemberHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + userSvc.On("GetByID", mock.Anything, "user-2").Return(nil, repoErr).Once() + }, + expectErr: repoErr, + }, + { + name: "user not found", + actorUserID: "user-1", + organizationID: "org-1", + request: types.AddOrganizationMemberRequest{UserID: "user-2", Role: "member"}, + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository, userSvc *internaltests.MockUserService, hooks *mockOrganizationMemberHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + userSvc.On("GetByID", mock.Anything, "user-2").Return(nil, nil).Once() + }, + expectErr: internalerrors.ErrNotFound, + }, + { + name: "existing member conflict", + actorUserID: "user-1", + organizationID: "org-1", + request: types.AddOrganizationMemberRequest{UserID: "user-2", Role: "member"}, + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository, userSvc *internaltests.MockUserService, hooks *mockOrganizationMemberHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + userSvc.On("GetByID", mock.Anything, "user-2").Return(&models.User{ID: "user-2", Email: "user2@example.com"}, nil).Once() + memberRepo.On("GetByOrganizationIDAndUserID", mock.Anything, "org-1", "user-2").Return(&types.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}, nil).Once() + }, + expectErr: internalerrors.ErrConflict, + }, + { + name: "lookup existing member error", + actorUserID: "user-1", + organizationID: "org-1", + request: types.AddOrganizationMemberRequest{UserID: "user-2", Role: "member"}, + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository, userSvc *internaltests.MockUserService, hooks *mockOrganizationMemberHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + userSvc.On("GetByID", mock.Anything, "user-2").Return(&models.User{ID: "user-2", Email: "user2@example.com"}, nil).Once() + memberRepo.On("GetByOrganizationIDAndUserID", mock.Anything, "org-1", "user-2").Return(nil, repoErr).Once() + }, + expectErr: repoErr, + }, + { + name: "before hook error", + actorUserID: "user-1", + organizationID: "org-1", + request: types.AddOrganizationMemberRequest{UserID: "user-2", Role: "member"}, + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository, userSvc *internaltests.MockUserService, hooks *mockOrganizationMemberHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + userSvc.On("GetByID", mock.Anything, "user-2").Return(&models.User{ID: "user-2", Email: "user2@example.com"}, nil).Once() + memberRepo.On("GetByOrganizationIDAndUserID", mock.Anything, "org-1", "user-2").Return(nil, nil).Once() + hooks.Before = func(member *types.OrganizationMember) error { return someErr } + }, + expectErr: someErr, + }, + { + name: "create error", + actorUserID: "user-1", + organizationID: "org-1", + request: types.AddOrganizationMemberRequest{UserID: "user-2", Role: "member"}, + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository, userSvc *internaltests.MockUserService, hooks *mockOrganizationMemberHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + userSvc.On("GetByID", mock.Anything, "user-2").Return(&models.User{ID: "user-2", Email: "user2@example.com"}, nil).Once() + memberRepo.On("GetByOrganizationIDAndUserID", mock.Anything, "org-1", "user-2").Return(nil, nil).Once() + memberRepo.On("Create", mock.Anything, mock.MatchedBy(func(member *types.OrganizationMember) bool { + return member != nil && member.OrganizationID == "org-1" && member.UserID == "user-2" && member.Role == "member" + })).Return(nil, repoErr).Once() + }, + expectErr: repoErr, + }, + { + name: "after hook error", + actorUserID: "user-1", + organizationID: "org-1", + request: types.AddOrganizationMemberRequest{UserID: "user-2", Role: "member"}, + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository, userSvc *internaltests.MockUserService, hooks *mockOrganizationMemberHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + userSvc.On("GetByID", mock.Anything, "user-2").Return(&models.User{ID: "user-2", Email: "user2@example.com"}, nil).Once() + memberRepo.On("GetByOrganizationIDAndUserID", mock.Anything, "org-1", "user-2").Return(nil, nil).Once() + memberRepo.On("Create", mock.Anything, mock.MatchedBy(func(member *types.OrganizationMember) bool { + return member != nil && member.OrganizationID == "org-1" && member.UserID == "user-2" && member.Role == "member" + })).Return(&types.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}, nil).Once() + hooks.After = func(member types.OrganizationMember) error { return afterErr } + }, + expectErr: afterErr, + }, + { + name: "success", + actorUserID: "user-1", + organizationID: "org-1", + request: types.AddOrganizationMemberRequest{UserID: "user-2", Role: "member"}, + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository, userSvc *internaltests.MockUserService, hooks *mockOrganizationMemberHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + userSvc.On("GetByID", mock.Anything, "user-2").Return(&models.User{ID: "user-2", Email: "user2@example.com"}, nil).Once() + memberRepo.On("GetByOrganizationIDAndUserID", mock.Anything, "org-1", "user-2").Return(nil, nil).Once() + memberRepo.On("Create", mock.Anything, mock.MatchedBy(func(member *types.OrganizationMember) bool { + return member != nil && member.OrganizationID == "org-1" && member.UserID == "user-2" && member.Role == "member" + })).Return(&types.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}, nil).Once() + hooks.Before = func(member *types.OrganizationMember) error { return nil } + hooks.After = func(member types.OrganizationMember) error { return nil } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + orgRepo := &mockOrganizationRepository{} + memberRepo := &mockOrganizationMemberRepository{} + userSvc := &internaltests.MockUserService{} + hooks := &mockOrganizationMemberHooks{} + if tt.setup != nil { + tt.setup(orgRepo, memberRepo, userSvc, hooks) + } + + svc := NewOrganizationMemberService(userSvc, orgRepo, memberRepo, hooks) + member, err := svc.AddMember(context.Background(), tt.actorUserID, tt.organizationID, tt.request) + if tt.expectErr != nil { + require.Error(t, err) + require.ErrorIs(t, err, tt.expectErr) + if tt.setup != nil { + require.True(t, orgRepo.AssertExpectations(t)) + require.True(t, memberRepo.AssertExpectations(t)) + require.True(t, userSvc.AssertExpectations(t)) + } + return + } + require.NoError(t, err) + require.NotNil(t, member) + if tt.setup != nil { + require.True(t, orgRepo.AssertExpectations(t)) + require.True(t, memberRepo.AssertExpectations(t)) + require.True(t, userSvc.AssertExpectations(t)) + } + }) + } +} + +func TestOrganizationMemberService_GetAllMembers(t *testing.T) { + t.Parallel() + + repoErr := errors.New("repository error") + + tests := []struct { + name string + actorUserID string + organizationID string + setup func(*mockOrganizationRepository, *mockOrganizationMemberRepository) + expectErr error + expectLen int + }{ + { + name: "unauthorized", + actorUserID: "", + organizationID: "org-1", + expectErr: internalerrors.ErrUnauthorized, + }, + { + name: "organization not found", + actorUserID: "user-1", + organizationID: "org-1", + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(nil, nil).Once() + }, + expectErr: internalerrors.ErrNotFound, + }, + { + name: "forbidden", + actorUserID: "user-1", + organizationID: "org-1", + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "owner-1"}, nil).Once() + }, + expectErr: internalerrors.ErrForbidden, + }, + { + name: "repository error", + actorUserID: "user-1", + organizationID: "org-1", + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + memberRepo.On("GetAllByOrganizationID", mock.Anything, "org-1", 1, 10).Return(nil, repoErr).Once() + }, + expectErr: repoErr, + }, + { + name: "success", + actorUserID: "user-1", + organizationID: "org-1", + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + memberRepo.On("GetAllByOrganizationID", mock.Anything, "org-1", 1, 10).Return([]types.OrganizationMember{{ID: "mem-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}}, nil).Once() + }, + expectLen: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + orgRepo := &mockOrganizationRepository{} + memberRepo := &mockOrganizationMemberRepository{} + if tt.setup != nil { + tt.setup(orgRepo, memberRepo) + } + + svc := NewOrganizationMemberService(&internaltests.MockUserService{}, orgRepo, memberRepo, nil) + members, err := svc.GetAllMembers(context.Background(), tt.actorUserID, tt.organizationID, 1, 10) + if tt.expectErr != nil { + require.Error(t, err) + require.ErrorIs(t, err, tt.expectErr) + if tt.setup != nil { + require.True(t, orgRepo.AssertExpectations(t)) + require.True(t, memberRepo.AssertExpectations(t)) + } + return + } + require.NoError(t, err) + require.Len(t, members, tt.expectLen) + if tt.setup != nil { + require.True(t, orgRepo.AssertExpectations(t)) + require.True(t, memberRepo.AssertExpectations(t)) + } + }) + } +} + +func TestOrganizationMemberService_GetMember(t *testing.T) { + t.Parallel() + + repoErr := errors.New("repository error") + + tests := []struct { + name string + actorUserID string + organizationID string + memberID string + setup func(*mockOrganizationRepository, *mockOrganizationMemberRepository) + expectErr error + expectMemberID string + }{ + { + name: "unauthorized", + actorUserID: "", + organizationID: "org-1", + memberID: "mem-1", + expectErr: internalerrors.ErrUnauthorized, + }, + { + name: "member id empty", + actorUserID: "user-1", + organizationID: "org-1", + memberID: "", + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + expectErr: internalerrors.ErrUnauthorized, + }, + { + name: "member id whitespace", + actorUserID: "user-1", + organizationID: "org-1", + memberID: " ", + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + memberRepo.On("GetByID", mock.Anything, "").Return(nil, nil).Once() + }, + expectErr: internalerrors.ErrNotFound, + }, + { + name: "organization not found", + actorUserID: "user-1", + organizationID: "org-1", + memberID: "mem-1", + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(nil, nil).Once() + }, + expectErr: internalerrors.ErrNotFound, + }, + { + name: "forbidden", + actorUserID: "user-1", + organizationID: "org-1", + memberID: "mem-1", + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "owner-1"}, nil).Once() + }, + expectErr: internalerrors.ErrForbidden, + }, + { + name: "repository error", + actorUserID: "user-1", + organizationID: "org-1", + memberID: "mem-1", + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + memberRepo.On("GetByID", mock.Anything, "mem-1").Return(nil, repoErr).Once() + }, + expectErr: repoErr, + }, + { + name: "not found when member is missing", + actorUserID: "user-1", + organizationID: "org-1", + memberID: "mem-1", + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + memberRepo.On("GetByID", mock.Anything, "mem-1").Return(nil, nil).Once() + }, + expectErr: internalerrors.ErrNotFound, + }, + { + name: "not found when member belongs to another organization", + actorUserID: "user-1", + organizationID: "org-1", + memberID: "mem-1", + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + memberRepo.On("GetByID", mock.Anything, "mem-1").Return(&types.OrganizationMember{ID: "mem-1", OrganizationID: "org-2"}, nil).Once() + }, + expectErr: internalerrors.ErrNotFound, + }, + { + name: "success", + actorUserID: "user-1", + organizationID: "org-1", + memberID: "mem-1", + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + memberRepo.On("GetByID", mock.Anything, "mem-1").Return(&types.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}, nil).Once() + }, + expectMemberID: "mem-1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + orgRepo := &mockOrganizationRepository{} + memberRepo := &mockOrganizationMemberRepository{} + if tt.setup != nil { + tt.setup(orgRepo, memberRepo) + } + + svc := NewOrganizationMemberService(&internaltests.MockUserService{}, orgRepo, memberRepo, nil) + member, err := svc.GetMember(context.Background(), tt.actorUserID, tt.organizationID, tt.memberID) + if tt.expectErr != nil { + require.Error(t, err) + require.ErrorIs(t, err, tt.expectErr) + if tt.setup != nil { + require.True(t, orgRepo.AssertExpectations(t)) + require.True(t, memberRepo.AssertExpectations(t)) + } + return + } + require.NoError(t, err) + require.NotNil(t, member) + require.Equal(t, tt.expectMemberID, member.ID) + if tt.setup != nil { + require.True(t, orgRepo.AssertExpectations(t)) + require.True(t, memberRepo.AssertExpectations(t)) + } + }) + } +} + +func TestOrganizationMemberService_UpdateMember(t *testing.T) { + t.Parallel() + + someErr := errors.New("some error") + updateErr := errors.New("update error") + afterErr := errors.New("after error") + repoErr := errors.New("repository error") + + tests := []struct { + name string + actorUserID string + organizationID string + memberID string + request types.UpdateOrganizationMemberRequest + setup func(*mockOrganizationRepository, *mockOrganizationMemberRepository, *mockOrganizationMemberHooks) + expectErr error + expectRole string + }{ + { + name: "unauthorized", + actorUserID: "", + organizationID: "org-1", + memberID: "mem-1", + request: types.UpdateOrganizationMemberRequest{Role: "admin"}, + expectErr: internalerrors.ErrUnauthorized, + }, + { + name: "organization not found", + actorUserID: "user-1", + organizationID: "org-1", + memberID: "mem-1", + request: types.UpdateOrganizationMemberRequest{Role: "admin"}, + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository, hooks *mockOrganizationMemberHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(nil, nil).Once() + }, + expectErr: internalerrors.ErrNotFound, + }, + { + name: "forbidden", + actorUserID: "user-1", + organizationID: "org-1", + memberID: "mem-1", + request: types.UpdateOrganizationMemberRequest{Role: "admin"}, + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository, hooks *mockOrganizationMemberHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "owner-1"}, nil).Once() + }, + expectErr: internalerrors.ErrForbidden, + }, + { + name: "repository error fetching member", + actorUserID: "user-1", + organizationID: "org-1", + memberID: "mem-1", + request: types.UpdateOrganizationMemberRequest{Role: "admin"}, + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository, hooks *mockOrganizationMemberHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + memberRepo.On("GetByID", mock.Anything, "mem-1").Return(nil, repoErr).Once() + }, + expectErr: repoErr, + }, + { + name: "not found when member is missing", + actorUserID: "user-1", + organizationID: "org-1", + memberID: "mem-1", + request: types.UpdateOrganizationMemberRequest{Role: "admin"}, + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository, hooks *mockOrganizationMemberHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + memberRepo.On("GetByID", mock.Anything, "mem-1").Return(nil, nil).Once() + }, + expectErr: internalerrors.ErrNotFound, + }, + { + name: "not found when member belongs to another organization", + actorUserID: "user-1", + organizationID: "org-1", + memberID: "mem-1", + request: types.UpdateOrganizationMemberRequest{Role: "admin"}, + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository, hooks *mockOrganizationMemberHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + memberRepo.On("GetByID", mock.Anything, "mem-1").Return(&types.OrganizationMember{ID: "mem-1", OrganizationID: "org-2", Role: "member"}, nil).Once() + }, + expectErr: internalerrors.ErrNotFound, + }, + { + name: "bad request empty role", + actorUserID: "user-1", + organizationID: "org-1", + memberID: "mem-1", + request: types.UpdateOrganizationMemberRequest{Role: " "}, + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository, hooks *mockOrganizationMemberHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + memberRepo.On("GetByID", mock.Anything, "mem-1").Return(&types.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", Role: "member"}, nil).Once() + }, + expectErr: internalerrors.ErrBadRequest, + }, + { + name: "before update hook error", + actorUserID: "user-1", + organizationID: "org-1", + memberID: "mem-1", + request: types.UpdateOrganizationMemberRequest{Role: "admin"}, + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository, hooks *mockOrganizationMemberHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + memberRepo.On("GetByID", mock.Anything, "mem-1").Return(&types.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}, nil).Once() + hooks.BeforeUpdate = func(member *types.OrganizationMember) error { return someErr } + }, + expectErr: someErr, + }, + { + name: "update error", + actorUserID: "user-1", + organizationID: "org-1", + memberID: "mem-1", + request: types.UpdateOrganizationMemberRequest{Role: "admin"}, + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository, hooks *mockOrganizationMemberHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + memberRepo.On("GetByID", mock.Anything, "mem-1").Return(&types.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}, nil).Once() + memberRepo.On("Update", mock.Anything, mock.MatchedBy(func(member *types.OrganizationMember) bool { + return member != nil && member.ID == "mem-1" && member.Role == "admin" + })).Return(nil, updateErr).Once() + }, + expectErr: updateErr, + }, + { + name: "after update hook error", + actorUserID: "user-1", + organizationID: "org-1", + memberID: "mem-1", + request: types.UpdateOrganizationMemberRequest{Role: "admin"}, + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository, hooks *mockOrganizationMemberHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + memberRepo.On("GetByID", mock.Anything, "mem-1").Return(&types.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}, nil).Once() + memberRepo.On("Update", mock.Anything, mock.MatchedBy(func(member *types.OrganizationMember) bool { + return member != nil && member.ID == "mem-1" && member.Role == "admin" + })).Return(&types.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-2", Role: "admin"}, nil).Once() + hooks.AfterUpdate = func(member types.OrganizationMember) error { return afterErr } + }, + expectErr: afterErr, + }, + { + name: "success", + actorUserID: "user-1", + organizationID: "org-1", + memberID: "mem-1", + request: types.UpdateOrganizationMemberRequest{Role: "admin"}, + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository, hooks *mockOrganizationMemberHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + memberRepo.On("GetByID", mock.Anything, "mem-1").Return(&types.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}, nil).Once() + memberRepo.On("Update", mock.Anything, mock.MatchedBy(func(member *types.OrganizationMember) bool { + return member != nil && member.ID == "mem-1" && member.Role == "admin" + })).Return(&types.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-2", Role: "admin"}, nil).Once() + hooks.BeforeUpdate = func(member *types.OrganizationMember) error { return nil } + hooks.AfterUpdate = func(member types.OrganizationMember) error { return nil } + }, + expectRole: "admin", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + orgRepo := &mockOrganizationRepository{} + memberRepo := &mockOrganizationMemberRepository{} + hooks := &mockOrganizationMemberHooks{} + if tt.setup != nil { + tt.setup(orgRepo, memberRepo, hooks) + } + + svc := NewOrganizationMemberService(&internaltests.MockUserService{}, orgRepo, memberRepo, hooks) + member, err := svc.UpdateMember(context.Background(), tt.actorUserID, tt.organizationID, tt.memberID, tt.request) + if tt.expectErr != nil { + require.Error(t, err) + require.ErrorIs(t, err, tt.expectErr) + if tt.setup != nil { + require.True(t, orgRepo.AssertExpectations(t)) + require.True(t, memberRepo.AssertExpectations(t)) + } + return + } + require.NoError(t, err) + require.NotNil(t, member) + require.Equal(t, tt.expectRole, member.Role) + if tt.setup != nil { + require.True(t, orgRepo.AssertExpectations(t)) + require.True(t, memberRepo.AssertExpectations(t)) + } + }) + } +} + +func TestOrganizationMemberService_RemoveMember(t *testing.T) { + t.Parallel() + + someErr := errors.New("some error") + deleteErr := errors.New("delete error") + repoErr := errors.New("repository error") + + tests := []struct { + name string + actorUserID string + organizationID string + memberID string + setup func(*mockOrganizationRepository, *mockOrganizationMemberRepository, *mockOrganizationMemberHooks) + expectErr error + }{ + { + name: "unauthorized", + actorUserID: "", + organizationID: "org-1", + memberID: "mem-1", + expectErr: internalerrors.ErrUnauthorized, + }, + { + name: "organization not found", + actorUserID: "user-1", + organizationID: "org-1", + memberID: "mem-1", + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository, hooks *mockOrganizationMemberHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(nil, nil).Once() + }, + expectErr: internalerrors.ErrNotFound, + }, + { + name: "forbidden", + actorUserID: "user-1", + organizationID: "org-1", + memberID: "mem-1", + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository, hooks *mockOrganizationMemberHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "owner-1"}, nil).Once() + }, + expectErr: internalerrors.ErrForbidden, + }, + { + name: "repository error fetching member", + actorUserID: "user-1", + organizationID: "org-1", + memberID: "mem-1", + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository, hooks *mockOrganizationMemberHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + memberRepo.On("GetByID", mock.Anything, "mem-1").Return(nil, repoErr).Once() + }, + expectErr: repoErr, + }, + { + name: "not found when member is missing", + actorUserID: "user-1", + organizationID: "org-1", + memberID: "mem-1", + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository, hooks *mockOrganizationMemberHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + memberRepo.On("GetByID", mock.Anything, "mem-1").Return(nil, nil).Once() + }, + expectErr: internalerrors.ErrNotFound, + }, + { + name: "not found when member belongs to another organization", + actorUserID: "user-1", + organizationID: "org-1", + memberID: "mem-1", + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository, hooks *mockOrganizationMemberHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + memberRepo.On("GetByID", mock.Anything, "mem-1").Return(&types.OrganizationMember{ID: "mem-1", OrganizationID: "org-2"}, nil).Once() + }, + expectErr: internalerrors.ErrNotFound, + }, + { + name: "before delete hook error", + actorUserID: "user-1", + organizationID: "org-1", + memberID: "mem-1", + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository, hooks *mockOrganizationMemberHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + memberRepo.On("GetByID", mock.Anything, "mem-1").Return(&types.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}, nil).Once() + hooks.BeforeDelete = func(member *types.OrganizationMember) error { return someErr } + }, + expectErr: someErr, + }, + { + name: "delete error", + actorUserID: "user-1", + organizationID: "org-1", + memberID: "mem-1", + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository, hooks *mockOrganizationMemberHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + memberRepo.On("GetByID", mock.Anything, "mem-1").Return(&types.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}, nil).Once() + memberRepo.On("Delete", mock.Anything, "mem-1").Return(deleteErr).Once() + }, + expectErr: deleteErr, + }, + { + name: "after delete hook error", + actorUserID: "user-1", + organizationID: "org-1", + memberID: "mem-1", + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository, hooks *mockOrganizationMemberHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + memberRepo.On("GetByID", mock.Anything, "mem-1").Return(&types.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}, nil).Once() + memberRepo.On("Delete", mock.Anything, "mem-1").Return(nil).Once() + hooks.AfterDelete = func(member types.OrganizationMember) error { return deleteErr } + }, + expectErr: deleteErr, + }, + { + name: "success", + actorUserID: "user-1", + organizationID: "org-1", + memberID: "mem-1", + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository, hooks *mockOrganizationMemberHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + memberRepo.On("GetByID", mock.Anything, "mem-1").Return(&types.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}, nil).Once() + memberRepo.On("Delete", mock.Anything, "mem-1").Return(nil).Once() + hooks.BeforeDelete = func(member *types.OrganizationMember) error { return nil } + hooks.AfterDelete = func(member types.OrganizationMember) error { return nil } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + orgRepo := &mockOrganizationRepository{} + memberRepo := &mockOrganizationMemberRepository{} + hooks := &mockOrganizationMemberHooks{} + if tt.setup != nil { + tt.setup(orgRepo, memberRepo, hooks) + } + + svc := NewOrganizationMemberService(&internaltests.MockUserService{}, orgRepo, memberRepo, hooks) + err := svc.RemoveMember(context.Background(), tt.actorUserID, tt.organizationID, tt.memberID) + if tt.expectErr != nil { + require.Error(t, err) + require.ErrorIs(t, err, tt.expectErr) + if tt.setup != nil { + require.True(t, orgRepo.AssertExpectations(t)) + require.True(t, memberRepo.AssertExpectations(t)) + } + return + } + require.NoError(t, err) + if tt.setup != nil { + require.True(t, orgRepo.AssertExpectations(t)) + require.True(t, memberRepo.AssertExpectations(t)) + } + }) + } +} diff --git a/plugins/organizations/services/organization_service.go b/plugins/organizations/services/organization_service.go new file mode 100644 index 00000000..1d058e15 --- /dev/null +++ b/plugins/organizations/services/organization_service.go @@ -0,0 +1,222 @@ +package services + +import ( + "context" + "strings" + "unicode" + + internalerrors "github.com/Authula/authula/internal/errors" + "github.com/Authula/authula/internal/util" + "github.com/Authula/authula/plugins/organizations/repositories" + "github.com/Authula/authula/plugins/organizations/types" +) + +type OrganizationService struct { + repo repositories.OrganizationRepository + hooks OrganizationHookExecutor +} + +type OrganizationHookExecutor interface { + BeforeCreateOrganization(organization *types.Organization) error + AfterCreateOrganization(organization types.Organization) error + BeforeUpdateOrganization(organization *types.Organization) error + AfterUpdateOrganization(organization types.Organization) error + BeforeDeleteOrganization(organization *types.Organization) error + AfterDeleteOrganization(organization types.Organization) error +} + +func NewOrganizationService(repo repositories.OrganizationRepository, hooks OrganizationHookExecutor) *OrganizationService { + return &OrganizationService{repo: repo, hooks: hooks} +} + +func (s *OrganizationService) CreateOrganization(ctx context.Context, actorUserID string, request types.CreateOrganizationRequest) (*types.Organization, error) { + actorUserID = strings.TrimSpace(actorUserID) + if actorUserID == "" { + return nil, internalerrors.ErrUnauthorized + } + + name := strings.TrimSpace(request.Name) + if name == "" { + return nil, internalerrors.ErrUnprocessableEntity + } + + slug := "" + if request.Slug != nil { + slug = strings.TrimSpace(*request.Slug) + } + if slug == "" { + slug = slugify(name) + } + if slug == "" { + return nil, internalerrors.ErrUnprocessableEntity + } + + organization := &types.Organization{ + ID: util.GenerateUUID(), + OwnerID: actorUserID, + Name: name, + Slug: slug, + Logo: request.Logo, + Metadata: request.Metadata, + } + if len(organization.Metadata) == 0 { + organization.Metadata = []byte("{}") + } + + if s.hooks != nil { + if err := s.hooks.BeforeCreateOrganization(organization); err != nil { + return nil, err + } + } + + created, err := s.repo.Create(ctx, organization) + if err != nil { + return nil, err + } + + if s.hooks != nil { + if err := s.hooks.AfterCreateOrganization(*created); err != nil { + return nil, err + } + } + + return created, nil +} + +func (s *OrganizationService) GetAllOrganizationsByUserID(ctx context.Context, actorUserID string) ([]types.Organization, error) { + actorUserID = strings.TrimSpace(actorUserID) + if actorUserID == "" { + return nil, internalerrors.ErrUnauthorized + } + + return s.repo.GetAllByOwnerID(ctx, actorUserID) +} + +func (s *OrganizationService) GetOrganization(ctx context.Context, actorUserID string, organizationID string) (*types.Organization, error) { + organization, err := s.authorizeOwner(ctx, actorUserID, organizationID) + if err != nil { + return nil, err + } + + return organization, nil +} + +func (s *OrganizationService) UpdateOrganization(ctx context.Context, actorUserID string, organizationID string, request types.UpdateOrganizationRequest) (*types.Organization, error) { + organization, err := s.authorizeOwner(ctx, actorUserID, organizationID) + if err != nil { + return nil, err + } + + name := strings.TrimSpace(request.Name) + if name == "" { + return nil, internalerrors.ErrUnprocessableEntity + } + + slug := organization.Slug + if request.Slug != nil { + slug = strings.TrimSpace(*request.Slug) + } + if slug == "" { + slug = slugify(name) + } + if slug == "" { + return nil, internalerrors.ErrUnprocessableEntity + } + + organization.Name = name + organization.Slug = slug + if request.Logo != nil { + organization.Logo = request.Logo + } + if request.Metadata != nil { + organization.Metadata = request.Metadata + } + if len(organization.Metadata) == 0 { + organization.Metadata = []byte("{}") + } + + if s.hooks != nil { + if err := s.hooks.BeforeUpdateOrganization(organization); err != nil { + return nil, err + } + } + + updated, err := s.repo.Update(ctx, organization) + if err != nil { + return nil, err + } + + if s.hooks != nil { + if err := s.hooks.AfterUpdateOrganization(*updated); err != nil { + return nil, err + } + } + + return updated, nil +} + +func (s *OrganizationService) DeleteOrganization(ctx context.Context, actorUserID string, organizationID string) error { + organization, err := s.authorizeOwner(ctx, actorUserID, organizationID) + if err != nil { + return err + } + + if s.hooks != nil { + if err := s.hooks.BeforeDeleteOrganization(organization); err != nil { + return err + } + } + + if err := s.repo.Delete(ctx, organizationID); err != nil { + return err + } + + if s.hooks != nil { + if err := s.hooks.AfterDeleteOrganization(*organization); err != nil { + return err + } + } + + return nil +} + +func (s *OrganizationService) authorizeOwner(ctx context.Context, actorUserID string, organizationID string) (*types.Organization, error) { + actorUserID = strings.TrimSpace(actorUserID) + organizationID = strings.TrimSpace(organizationID) + if actorUserID == "" || organizationID == "" { + return nil, internalerrors.ErrUnauthorized + } + + organization, err := s.repo.GetByID(ctx, organizationID) + if err != nil { + return nil, err + } + if organization == nil { + return nil, internalerrors.ErrNotFound + } + if strings.TrimSpace(organization.OwnerID) != actorUserID { + return nil, internalerrors.ErrForbidden + } + + return organization, nil +} + +func slugify(input string) string { + input = strings.ToLower(strings.TrimSpace(input)) + var builder strings.Builder + lastDash := false + + for _, r := range input { + if unicode.IsLetter(r) || unicode.IsDigit(r) { + builder.WriteRune(r) + lastDash = false + continue + } + if !lastDash { + builder.WriteByte('-') + lastDash = true + } + } + + return strings.Trim(builder.String(), "-") +} diff --git a/plugins/organizations/services/organization_service_test.go b/plugins/organizations/services/organization_service_test.go new file mode 100644 index 00000000..8532ea0d --- /dev/null +++ b/plugins/organizations/services/organization_service_test.go @@ -0,0 +1,374 @@ +package services + +import ( + "context" + "encoding/json" + "errors" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + internalerrors "github.com/Authula/authula/internal/errors" + orgtests "github.com/Authula/authula/plugins/organizations/tests" + "github.com/Authula/authula/plugins/organizations/types" +) + +type mockOrganizationRepository = orgtests.MockOrganizationRepository +type testOrganizationHooks = orgtests.TestOrganizationHooks +type mockOrganizationMemberRepository = orgtests.MockOrganizationMemberRepository +type mockOrganizationMemberHooks = orgtests.TestOrganizationMemberHooks +type mockOrganizationInvitationRepository = orgtests.MockOrganizationInvitationRepository +type mockOrganizationInvitationTxRunner = orgtests.MockOrganizationInvitationTxRunner +type testOrganizationInvitationHooks = orgtests.TestOrganizationInvitationHooks +type mockOrganizationTeamRepository = orgtests.MockOrganizationTeamRepository +type mockOrganizationTeamMemberRepository = orgtests.MockOrganizationTeamMemberRepository +type testOrganizationTeamHooks = orgtests.TestOrganizationTeamHooks +type testOrganizationTeamMemberHooks = orgtests.TestOrganizationTeamMemberHooks + +func TestOrganizationService_CreateOrganization(t *testing.T) { + t.Parallel() + + someError := errors.New("some error") + afterErr := errors.New("after error") + + tests := []struct { + name string + actorUserID string + request types.CreateOrganizationRequest + setup func(*mockOrganizationRepository, *testOrganizationHooks) + expectErr error + expectCalled bool + expectReturned string + }{ + { + name: "unauthorized", + actorUserID: " ", + request: types.CreateOrganizationRequest{Name: "Acme"}, + expectErr: internalerrors.ErrUnauthorized, + }, + { + name: "bad request", + actorUserID: "user-1", + request: types.CreateOrganizationRequest{Name: " "}, + expectErr: internalerrors.ErrUnprocessableEntity, + }, + { + name: "success", + actorUserID: "user-1", + request: types.CreateOrganizationRequest{Name: "Acme Inc", Metadata: json.RawMessage(`{"tier":"pro"}`)}, + setup: func(repo *mockOrganizationRepository, hooks *testOrganizationHooks) { + repo.On("Create", mock.Anything, mock.MatchedBy(func(org *types.Organization) bool { + return org != nil && org.OwnerID == "user-1" && org.Name == "Acme Inc" && org.Slug == "acme-inc" + })).Return(&types.Organization{ID: "org-1", OwnerID: "user-1", Name: "Acme Inc", Slug: "acme-inc"}, nil).Once() + hooks.BeforeCreate = func(organization *types.Organization) error { + require.NotNil(t, organization) + require.Equal(t, "user-1", organization.OwnerID) + require.Equal(t, "acme-inc", organization.Slug) + return nil + } + hooks.AfterCreate = func(organization types.Organization) error { + require.Equal(t, "org-1", organization.ID) + return nil + } + }, + expectCalled: true, + expectReturned: "org-1", + }, + { + name: "before hook blocks", + actorUserID: "user-1", + request: types.CreateOrganizationRequest{Name: "Acme"}, + setup: func(repo *mockOrganizationRepository, hooks *testOrganizationHooks) { + hooks.BeforeCreate = func(organization *types.Organization) error { return someError } + }, + expectErr: someError, + }, + { + name: "after hook error", + actorUserID: "user-1", + request: types.CreateOrganizationRequest{Name: "Acme"}, + setup: func(repo *mockOrganizationRepository, hooks *testOrganizationHooks) { + repo.On("Create", mock.Anything, mock.AnythingOfType("*types.Organization")).Return(&types.Organization{ID: "org-1", OwnerID: "user-1", Name: "Acme", Slug: "acme"}, nil).Once() + hooks.AfterCreate = func(organization types.Organization) error { return afterErr } + }, + expectErr: afterErr, + expectCalled: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + repo := &mockOrganizationRepository{} + hooks := &testOrganizationHooks{} + if tt.setup != nil { + tt.setup(repo, hooks) + } + + svc := NewOrganizationService(repo, hooks) + org, err := svc.CreateOrganization(context.Background(), tt.actorUserID, tt.request) + if tt.expectErr != nil { + require.Error(t, err) + require.ErrorIs(t, err, tt.expectErr) + if tt.expectCalled { + require.True(t, repo.AssertExpectations(t)) + } + return + } + + require.NoError(t, err) + require.NotNil(t, org) + if tt.expectReturned != "" { + require.Equal(t, tt.expectReturned, org.ID) + } + if tt.expectCalled { + repo.AssertExpectations(t) + } + }) + } +} + +func TestOrganizationService_GetAllOrganizations(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + actorUserID string + setup func(*mockOrganizationRepository) + expectErr error + expectLen int + }{ + { + name: "unauthorized", + actorUserID: " ", + expectErr: internalerrors.ErrUnauthorized, + }, + { + name: "success", + actorUserID: "user-1", + setup: func(repo *mockOrganizationRepository) { + repo.On("GetAllByOwnerID", mock.Anything, "user-1").Return([]types.Organization{{ID: "org-1", OwnerID: "user-1", Name: "Acme"}}, nil).Once() + }, + expectLen: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + repo := &mockOrganizationRepository{} + if tt.setup != nil { + tt.setup(repo) + } + + svc := NewOrganizationService(repo, &testOrganizationHooks{}) + organizations, err := svc.GetAllOrganizationsByUserID(context.Background(), tt.actorUserID) + if tt.expectErr != nil { + require.Error(t, err) + require.ErrorIs(t, err, tt.expectErr) + return + } + require.NoError(t, err) + require.Len(t, organizations, tt.expectLen) + }) + } +} + +func TestOrganizationService_GetOrganization(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + actorUserID string + organizationID string + setup func(*mockOrganizationRepository) + expectErr error + }{ + { + name: "forbidden", + actorUserID: "user-1", + organizationID: "org-1", + setup: func(repo *mockOrganizationRepository) { + repo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "owner-1"}, nil).Once() + }, + expectErr: internalerrors.ErrForbidden, + }, + { + name: "success", + actorUserID: "user-1", + organizationID: "org-1", + setup: func(repo *mockOrganizationRepository) { + repo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + repo := &mockOrganizationRepository{} + if tt.setup != nil { + tt.setup(repo) + } + + svc := NewOrganizationService(repo, &testOrganizationHooks{}) + org, err := svc.GetOrganization(context.Background(), tt.actorUserID, tt.organizationID) + if tt.expectErr != nil { + require.Error(t, err) + require.ErrorIs(t, err, tt.expectErr) + return + } + require.NoError(t, err) + require.NotNil(t, org) + }) + } +} + +func TestOrganizationService_UpdateOrganization(t *testing.T) { + t.Parallel() + + someError := errors.New("some error") + + tests := []struct { + name string + actorUserID string + organizationID string + request types.UpdateOrganizationRequest + setup func(*mockOrganizationRepository, *testOrganizationHooks) + expectErr error + }{ + { + name: "unauthorized if not user ID provided", + actorUserID: "", + organizationID: "", + request: types.UpdateOrganizationRequest{Name: "Acme Platform"}, + expectErr: internalerrors.ErrUnauthorized, + }, + { + name: "unauthorized if no organization ID provided", + actorUserID: "user-1", + organizationID: "", + request: types.UpdateOrganizationRequest{Name: "Acme Platform"}, + expectErr: internalerrors.ErrUnauthorized, + }, + { + name: "forbidden", + actorUserID: "user-1", + organizationID: "org-1", + request: types.UpdateOrganizationRequest{Name: "Acme Platform"}, + setup: func(repo *mockOrganizationRepository, hooks *testOrganizationHooks) { + repo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "owner-1"}, nil).Once() + }, + expectErr: internalerrors.ErrForbidden, + }, + { + name: "success", + actorUserID: "user-1", + organizationID: "org-1", + request: types.UpdateOrganizationRequest{Name: "Acme Platform"}, + setup: func(repo *mockOrganizationRepository, hooks *testOrganizationHooks) { + repo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1", Name: "Acme", Slug: "acme"}, nil).Once() + repo.On("Update", mock.Anything, mock.MatchedBy(func(org *types.Organization) bool { + return org != nil && org.ID == "org-1" && org.Name == "Acme Platform" && org.Slug == "acme" + })).Return(&types.Organization{ID: "org-1", OwnerID: "user-1", Name: "Acme Platform", Slug: "acme"}, nil).Once() + hooks.BeforeUpdate = func(organization *types.Organization) error { return nil } + hooks.AfterUpdate = func(organization types.Organization) error { return nil } + }, + }, + { + name: "before hook blocks", + actorUserID: "user-1", + organizationID: "org-1", + request: types.UpdateOrganizationRequest{Name: "Acme Platform"}, + setup: func(repo *mockOrganizationRepository, hooks *testOrganizationHooks) { + repo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1", Name: "Acme", Slug: "acme"}, nil).Once() + hooks.BeforeUpdate = func(organization *types.Organization) error { return someError } + }, + expectErr: someError, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + repo := &mockOrganizationRepository{} + hooks := &testOrganizationHooks{} + if tt.setup != nil { + tt.setup(repo, hooks) + } + + svc := NewOrganizationService(repo, hooks) + org, err := svc.UpdateOrganization(context.Background(), tt.actorUserID, tt.organizationID, tt.request) + if tt.expectErr != nil { + require.Error(t, err) + require.ErrorIs(t, err, tt.expectErr) + return + } + require.NoError(t, err) + require.NotNil(t, org) + }) + } +} + +func TestOrganizationService_DeleteOrganization(t *testing.T) { + t.Parallel() + + someErr := errors.New("some error") + + tests := []struct { + name string + actorUserID string + organizationID string + setup func(*mockOrganizationRepository, *testOrganizationHooks) + expectErr error + }{ + { + name: "success", + actorUserID: "user-1", + organizationID: "org-1", + setup: func(repo *mockOrganizationRepository, hooks *testOrganizationHooks) { + repo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + repo.On("Delete", mock.Anything, "org-1").Return(nil).Once() + hooks.BeforeDelete = func(organization *types.Organization) error { return nil } + hooks.AfterDelete = func(organization types.Organization) error { return nil } + }, + }, + { + name: "before hook blocks", + actorUserID: "user-1", + organizationID: "org-1", + setup: func(repo *mockOrganizationRepository, hooks *testOrganizationHooks) { + repo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + hooks.BeforeDelete = func(organization *types.Organization) error { return someErr } + }, + expectErr: someErr, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + repo := &mockOrganizationRepository{} + hooks := &testOrganizationHooks{} + if tt.setup != nil { + tt.setup(repo, hooks) + } + + svc := NewOrganizationService(repo, hooks) + err := svc.DeleteOrganization(context.Background(), tt.actorUserID, tt.organizationID) + if tt.expectErr != nil { + require.Error(t, err) + require.ErrorIs(t, err, tt.expectErr) + return + } + require.NoError(t, err) + }) + } +} diff --git a/plugins/organizations/services/organization_team_service.go b/plugins/organizations/services/organization_team_service.go new file mode 100644 index 00000000..4c1c0198 --- /dev/null +++ b/plugins/organizations/services/organization_team_service.go @@ -0,0 +1,361 @@ +package services + +import ( + "context" + "strings" + + internalerrors "github.com/Authula/authula/internal/errors" + "github.com/Authula/authula/internal/util" + "github.com/Authula/authula/plugins/organizations/repositories" + "github.com/Authula/authula/plugins/organizations/types" +) + +type OrganizationTeamHookExecutor interface { + BeforeCreateOrganizationTeam(team *types.OrganizationTeam) error + AfterCreateOrganizationTeam(team types.OrganizationTeam) error + BeforeUpdateOrganizationTeam(team *types.OrganizationTeam) error + AfterUpdateOrganizationTeam(team types.OrganizationTeam) error + BeforeDeleteOrganizationTeam(team *types.OrganizationTeam) error + AfterDeleteOrganizationTeam(team types.OrganizationTeam) error +} + +type OrganizationTeamMemberHookExecutor interface { + BeforeCreateOrganizationTeamMember(teamMember *types.OrganizationTeamMember) error + AfterCreateOrganizationTeamMember(teamMember types.OrganizationTeamMember) error + BeforeDeleteOrganizationTeamMember(teamMember *types.OrganizationTeamMember) error + AfterDeleteOrganizationTeamMember(teamMember types.OrganizationTeamMember) error +} + +type OrganizationTeamService struct { + orgRepo repositories.OrganizationRepository + orgTeamRepo repositories.OrganizationTeamRepository + orgMemberRepo repositories.OrganizationMemberRepository + orgTeamMemberRepo repositories.OrganizationTeamMemberRepository + orgTeamHooks OrganizationTeamHookExecutor + teamMemberHooks OrganizationTeamMemberHookExecutor +} + +func NewOrganizationTeamService(organizationRepo repositories.OrganizationRepository, teamRepo repositories.OrganizationTeamRepository, organizationMemberRepo repositories.OrganizationMemberRepository, teamMemberRepo repositories.OrganizationTeamMemberRepository, hooks OrganizationTeamHookExecutor, teamMemberHooks OrganizationTeamMemberHookExecutor) *OrganizationTeamService { + return &OrganizationTeamService{orgRepo: organizationRepo, orgTeamRepo: teamRepo, orgMemberRepo: organizationMemberRepo, orgTeamMemberRepo: teamMemberRepo, orgTeamHooks: hooks, teamMemberHooks: teamMemberHooks} +} + +func (s *OrganizationTeamService) CreateTeam(ctx context.Context, actorUserID string, organizationID string, request types.CreateOrganizationTeamRequest) (*types.OrganizationTeam, error) { + if _, err := s.authorizeOrganizationOwner(ctx, actorUserID, organizationID); err != nil { + return nil, err + } + + name := strings.TrimSpace(request.Name) + if name == "" { + return nil, internalerrors.ErrBadRequest + } + + slug := "" + if request.Slug != nil { + slug = strings.TrimSpace(*request.Slug) + } + if slug == "" { + slug = slugify(name) + } + if slug == "" { + return nil, internalerrors.ErrBadRequest + } + + if existing, err := s.orgTeamRepo.GetByOrganizationIDAndSlug(ctx, organizationID, slug); err != nil { + return nil, err + } else if existing != nil { + return nil, internalerrors.ErrConflict + } + + team := &types.OrganizationTeam{ + ID: util.GenerateUUID(), + OrganizationID: organizationID, + Name: name, + Slug: slug, + Description: request.Description, + Metadata: request.Metadata, + } + if len(team.Metadata) == 0 { + team.Metadata = []byte("{}") + } + + if s.orgTeamHooks != nil { + if err := s.orgTeamHooks.BeforeCreateOrganizationTeam(team); err != nil { + return nil, err + } + } + + created, err := s.orgTeamRepo.Create(ctx, team) + if err != nil { + return nil, err + } + + if s.orgTeamHooks != nil { + if err := s.orgTeamHooks.AfterCreateOrganizationTeam(*created); err != nil { + return nil, err + } + } + + return created, nil +} + +func (s *OrganizationTeamService) GetAllTeams(ctx context.Context, actorUserID string, organizationID string) ([]types.OrganizationTeam, error) { + if _, err := s.authorizeOrganizationOwner(ctx, actorUserID, organizationID); err != nil { + return nil, err + } + + return s.orgTeamRepo.GetAllByOrganizationID(ctx, organizationID) +} + +func (s *OrganizationTeamService) UpdateTeam(ctx context.Context, actorUserID string, organizationID string, teamID string, request types.UpdateOrganizationTeamRequest) (*types.OrganizationTeam, error) { + if _, err := s.authorizeOrganizationOwner(ctx, actorUserID, organizationID); err != nil { + return nil, err + } + + team, err := s.orgTeamRepo.GetByID(ctx, teamID) + if err != nil { + return nil, err + } + if team == nil || team.OrganizationID != organizationID { + return nil, internalerrors.ErrNotFound + } + + name := strings.TrimSpace(request.Name) + if name == "" { + return nil, internalerrors.ErrBadRequest + } + + slug := team.Slug + if request.Slug != nil { + slug = strings.TrimSpace(*request.Slug) + } + if slug == "" { + slug = slugify(name) + } + if slug == "" { + return nil, internalerrors.ErrBadRequest + } + + if existing, err := s.orgTeamRepo.GetByOrganizationIDAndSlug(ctx, organizationID, slug); err != nil { + return nil, err + } else if existing != nil && existing.ID != teamID { + return nil, internalerrors.ErrConflict + } + + team.Name = name + team.Slug = slug + team.Description = request.Description + team.Metadata = request.Metadata + if len(team.Metadata) == 0 { + team.Metadata = []byte("{}") + } + + if s.orgTeamHooks != nil { + if err := s.orgTeamHooks.BeforeUpdateOrganizationTeam(team); err != nil { + return nil, err + } + } + + updated, err := s.orgTeamRepo.Update(ctx, team) + if err != nil { + return nil, err + } + + if s.orgTeamHooks != nil { + if err := s.orgTeamHooks.AfterUpdateOrganizationTeam(*updated); err != nil { + return nil, err + } + } + + return updated, nil +} + +func (s *OrganizationTeamService) DeleteTeam(ctx context.Context, actorUserID string, organizationID string, teamID string) error { + if _, err := s.authorizeOrganizationOwner(ctx, actorUserID, organizationID); err != nil { + return err + } + + team, err := s.orgTeamRepo.GetByID(ctx, teamID) + if err != nil { + return err + } + if team == nil || team.OrganizationID != organizationID { + return internalerrors.ErrNotFound + } + + if s.orgTeamHooks != nil { + if err := s.orgTeamHooks.BeforeDeleteOrganizationTeam(team); err != nil { + return err + } + } + + if err := s.orgTeamRepo.Delete(ctx, teamID); err != nil { + return err + } + + if s.orgTeamHooks != nil { + if err := s.orgTeamHooks.AfterDeleteOrganizationTeam(*team); err != nil { + return err + } + } + + return nil +} + +func (s *OrganizationTeamService) AddTeamMember(ctx context.Context, actorUserID string, organizationID string, teamID string, request types.AddOrganizationTeamMemberRequest) (*types.OrganizationTeamMember, error) { + if _, err := s.authorizeOrganizationOwner(ctx, actorUserID, organizationID); err != nil { + return nil, err + } + + team, err := s.orgTeamRepo.GetByID(ctx, teamID) + if err != nil { + return nil, err + } + if team == nil || team.OrganizationID != organizationID { + return nil, internalerrors.ErrNotFound + } + + orgMemberID := strings.TrimSpace(request.MemberID) + if orgMemberID == "" { + return nil, internalerrors.ErrUnprocessableEntity + } + + orgMember, err := s.orgMemberRepo.GetByID(ctx, orgMemberID) + if err != nil { + return nil, err + } + if orgMember == nil || orgMember.OrganizationID != organizationID { + return nil, internalerrors.ErrNotFound + } + + if existing, err := s.orgTeamMemberRepo.GetByTeamIDAndMemberID(ctx, teamID, orgMemberID); err != nil { + return nil, err + } else if existing != nil { + return nil, internalerrors.ErrConflict + } + + teamMember := &types.OrganizationTeamMember{ + ID: util.GenerateUUID(), + TeamID: teamID, + MemberID: orgMemberID, + } + + if s.teamMemberHooks != nil { + if err := s.teamMemberHooks.BeforeCreateOrganizationTeamMember(teamMember); err != nil { + return nil, err + } + } + + created, err := s.orgTeamMemberRepo.Create(ctx, teamMember) + if err != nil { + return nil, err + } + + if s.teamMemberHooks != nil { + if err := s.teamMemberHooks.AfterCreateOrganizationTeamMember(*created); err != nil { + return nil, err + } + } + + return created, nil +} + +func (s *OrganizationTeamService) GetAllTeamMembers(ctx context.Context, actorUserID string, organizationID string, teamID string, page int, limit int) ([]types.OrganizationTeamMember, error) { + if _, err := s.authorizeOrganizationOwner(ctx, actorUserID, organizationID); err != nil { + return nil, err + } + + team, err := s.orgTeamRepo.GetByID(ctx, teamID) + if err != nil { + return nil, err + } + if team == nil || team.OrganizationID != organizationID { + return nil, internalerrors.ErrNotFound + } + + return s.orgTeamMemberRepo.GetAllByTeamID(ctx, teamID, page, limit) +} + +func (s *OrganizationTeamService) GetTeamMember(ctx context.Context, actorUserID string, organizationID string, teamID string, memberID string) (*types.OrganizationTeamMember, error) { + if _, err := s.authorizeOrganizationOwner(ctx, actorUserID, organizationID); err != nil { + return nil, err + } + + team, err := s.orgTeamRepo.GetByID(ctx, teamID) + if err != nil { + return nil, err + } + if team == nil || team.OrganizationID != organizationID { + return nil, internalerrors.ErrNotFound + } + + teamMember, err := s.orgTeamMemberRepo.GetByTeamIDAndMemberID(ctx, teamID, strings.TrimSpace(memberID)) + if err != nil { + return nil, err + } + if teamMember == nil { + return nil, internalerrors.ErrNotFound + } + + return teamMember, nil +} + +func (s *OrganizationTeamService) RemoveTeamMember(ctx context.Context, actorUserID string, organizationID string, teamID string, memberID string) error { + if _, err := s.authorizeOrganizationOwner(ctx, actorUserID, organizationID); err != nil { + return err + } + + team, err := s.orgTeamRepo.GetByID(ctx, teamID) + if err != nil { + return err + } + if team == nil || team.OrganizationID != organizationID { + return internalerrors.ErrNotFound + } + + teamMember, err := s.orgTeamMemberRepo.GetByTeamIDAndMemberID(ctx, teamID, strings.TrimSpace(memberID)) + if err != nil { + return err + } + if teamMember == nil { + return internalerrors.ErrNotFound + } + + if s.teamMemberHooks != nil { + if err := s.teamMemberHooks.BeforeDeleteOrganizationTeamMember(teamMember); err != nil { + return err + } + } + + if err := s.orgTeamMemberRepo.DeleteByTeamIDAndMemberID(ctx, teamID, strings.TrimSpace(memberID)); err != nil { + return err + } + + if s.teamMemberHooks != nil { + if err := s.teamMemberHooks.AfterDeleteOrganizationTeamMember(*teamMember); err != nil { + return err + } + } + + return nil +} + +func (s *OrganizationTeamService) authorizeOrganizationOwner(ctx context.Context, actorUserID string, organizationID string) (*types.Organization, error) { + actorUserID = strings.TrimSpace(actorUserID) + organizationID = strings.TrimSpace(organizationID) + if actorUserID == "" || organizationID == "" { + return nil, internalerrors.ErrUnauthorized + } + + organization, err := s.orgRepo.GetByID(ctx, organizationID) + if err != nil { + return nil, err + } + if organization == nil { + return nil, internalerrors.ErrNotFound + } + if strings.TrimSpace(organization.OwnerID) != actorUserID { + return nil, internalerrors.ErrForbidden + } + + return organization, nil +} diff --git a/plugins/organizations/services/organization_team_service_test.go b/plugins/organizations/services/organization_team_service_test.go new file mode 100644 index 00000000..2966f25e --- /dev/null +++ b/plugins/organizations/services/organization_team_service_test.go @@ -0,0 +1,548 @@ +package services + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + internalerrors "github.com/Authula/authula/internal/errors" + "github.com/Authula/authula/plugins/organizations/types" +) + +func TestOrganizationTeamService_CreateTeam(t *testing.T) { + t.Parallel() + + someErr := errors.New("some error") + + tests := []struct { + name string + actorUserID string + organizationID string + request types.CreateOrganizationTeamRequest + setup func(*mockOrganizationRepository, *mockOrganizationTeamRepository, *testOrganizationTeamHooks) + expectErr error + expectCalled bool + }{ + { + name: "success", + actorUserID: "user-1", + organizationID: "org-1", + request: types.CreateOrganizationTeamRequest{Name: "Acme Platform"}, + setup: func(orgRepo *mockOrganizationRepository, teamRepo *mockOrganizationTeamRepository, hooks *testOrganizationTeamHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + teamRepo.On("GetByOrganizationIDAndSlug", mock.Anything, "org-1", "acme-platform").Return(nil, nil).Once() + teamRepo.On("Create", mock.Anything, mock.MatchedBy(func(team *types.OrganizationTeam) bool { + return team != nil && team.OrganizationID == "org-1" && team.Name == "Acme Platform" && team.Slug == "acme-platform" + })).Return(&types.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Acme Platform", Slug: "acme-platform"}, nil).Once() + hooks.BeforeCreate = func(team *types.OrganizationTeam) error { return nil } + hooks.AfterCreate = func(team types.OrganizationTeam) error { return nil } + }, + expectCalled: true, + }, + { + name: "before hook blocks", + actorUserID: "user-1", + organizationID: "org-1", + request: types.CreateOrganizationTeamRequest{Name: "Acme Platform"}, + setup: func(orgRepo *mockOrganizationRepository, teamRepo *mockOrganizationTeamRepository, hooks *testOrganizationTeamHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + teamRepo.On("GetByOrganizationIDAndSlug", mock.Anything, "org-1", "acme-platform").Return(nil, nil).Once() + hooks.BeforeCreate = func(team *types.OrganizationTeam) error { return someErr } + }, + expectErr: someErr, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + orgRepo := &mockOrganizationRepository{} + teamRepo := &mockOrganizationTeamRepository{} + memberRepo := &mockOrganizationMemberRepository{} + teamMemberRepo := &mockOrganizationTeamMemberRepository{} + hooks := &testOrganizationTeamHooks{} + if tt.setup != nil { + tt.setup(orgRepo, teamRepo, hooks) + } + + svc := NewOrganizationTeamService(orgRepo, teamRepo, memberRepo, teamMemberRepo, hooks, &testOrganizationTeamMemberHooks{}) + team, err := svc.CreateTeam(context.Background(), tt.actorUserID, tt.organizationID, tt.request) + if tt.expectErr != nil { + require.Error(t, err) + require.ErrorIs(t, err, tt.expectErr) + return + } + require.NoError(t, err) + require.NotNil(t, team) + require.Equal(t, tt.expectCalled, orgRepo.AssertExpectations(t)) + require.Equal(t, tt.expectCalled, teamRepo.AssertExpectations(t)) + }) + } +} + +func TestOrganizationTeamService_AddTeamMember(t *testing.T) { + t.Parallel() + + someErr := errors.New("some error") + + tests := []struct { + name string + actorUserID string + organizationID string + teamID string + request types.AddOrganizationTeamMemberRequest + setup func(*mockOrganizationRepository, *mockOrganizationTeamRepository, *mockOrganizationMemberRepository, *mockOrganizationTeamMemberRepository, *testOrganizationTeamMemberHooks) + expectErr error + expectCalled bool + }{ + { + name: "success", + actorUserID: "user-1", + organizationID: "org-1", + teamID: "team-1", + request: types.AddOrganizationTeamMemberRequest{MemberID: "member-1"}, + setup: func(orgRepo *mockOrganizationRepository, teamRepo *mockOrganizationTeamRepository, memberRepo *mockOrganizationMemberRepository, teamMemberRepo *mockOrganizationTeamMemberRepository, hooks *testOrganizationTeamMemberHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + teamRepo.On("GetByID", mock.Anything, "team-1").Return(&types.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + memberRepo.On("GetByID", mock.Anything, "member-1").Return(&types.OrganizationMember{ID: "member-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}, nil).Once() + teamMemberRepo.On("GetByTeamIDAndMemberID", mock.Anything, "team-1", "member-1").Return(nil, nil).Once() + teamMemberRepo.On("Create", mock.Anything, mock.MatchedBy(func(teamMember *types.OrganizationTeamMember) bool { + return teamMember != nil && teamMember.TeamID == "team-1" && teamMember.MemberID == "member-1" + })).Return(&types.OrganizationTeamMember{ID: "team-member-1", TeamID: "team-1", MemberID: "member-1"}, nil).Once() + hooks.BeforeCreate = func(teamMember *types.OrganizationTeamMember) error { return nil } + hooks.AfterCreate = func(teamMember types.OrganizationTeamMember) error { return nil } + }, + expectCalled: true, + }, + { + name: "before hook blocks", + actorUserID: "user-1", + organizationID: "org-1", + teamID: "team-1", + request: types.AddOrganizationTeamMemberRequest{MemberID: "member-1"}, + setup: func(orgRepo *mockOrganizationRepository, teamRepo *mockOrganizationTeamRepository, memberRepo *mockOrganizationMemberRepository, teamMemberRepo *mockOrganizationTeamMemberRepository, hooks *testOrganizationTeamMemberHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + teamRepo.On("GetByID", mock.Anything, "team-1").Return(&types.OrganizationTeam{ID: "team-1", OrganizationID: "org-1"}, nil).Once() + memberRepo.On("GetByID", mock.Anything, "member-1").Return(&types.OrganizationMember{ID: "member-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}, nil).Once() + teamMemberRepo.On("GetByTeamIDAndMemberID", mock.Anything, "team-1", "member-1").Return(nil, nil).Once() + hooks.BeforeCreate = func(teamMember *types.OrganizationTeamMember) error { return someErr } + }, + expectErr: someErr, + }, + { + name: "member from another organization", + actorUserID: "user-1", + organizationID: "org-1", + teamID: "team-1", + request: types.AddOrganizationTeamMemberRequest{MemberID: "member-1"}, + setup: func(orgRepo *mockOrganizationRepository, teamRepo *mockOrganizationTeamRepository, memberRepo *mockOrganizationMemberRepository, teamMemberRepo *mockOrganizationTeamMemberRepository, hooks *testOrganizationTeamMemberHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + teamRepo.On("GetByID", mock.Anything, "team-1").Return(&types.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + memberRepo.On("GetByID", mock.Anything, "member-1").Return(&types.OrganizationMember{ID: "member-1", OrganizationID: "org-2", UserID: "user-2", Role: "member"}, nil).Once() + }, + expectErr: internalerrors.ErrNotFound, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + orgRepo := &mockOrganizationRepository{} + teamRepo := &mockOrganizationTeamRepository{} + memberRepo := &mockOrganizationMemberRepository{} + teamMemberRepo := &mockOrganizationTeamMemberRepository{} + hooks := &testOrganizationTeamMemberHooks{} + if tt.setup != nil { + tt.setup(orgRepo, teamRepo, memberRepo, teamMemberRepo, hooks) + } + + svc := NewOrganizationTeamService(orgRepo, teamRepo, memberRepo, teamMemberRepo, &testOrganizationTeamHooks{}, hooks) + teamMember, err := svc.AddTeamMember(context.Background(), tt.actorUserID, tt.organizationID, tt.teamID, tt.request) + if tt.expectErr != nil { + require.Error(t, err) + require.ErrorIs(t, err, tt.expectErr) + return + } + require.NoError(t, err) + require.NotNil(t, teamMember) + require.Equal(t, tt.expectCalled, orgRepo.AssertExpectations(t)) + require.Equal(t, tt.expectCalled, teamRepo.AssertExpectations(t)) + require.Equal(t, tt.expectCalled, memberRepo.AssertExpectations(t)) + require.Equal(t, tt.expectCalled, teamMemberRepo.AssertExpectations(t)) + }) + } +} + +func TestOrganizationTeamService_GetTeamMember(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + actorUserID string + organizationID string + teamID string + memberID string + setup func(*mockOrganizationRepository, *mockOrganizationTeamRepository, *mockOrganizationTeamMemberRepository) + expectErr error + expectID string + }{ + { + name: "success", + actorUserID: "user-1", + organizationID: "org-1", + teamID: "team-1", + memberID: "member-1", + setup: func(orgRepo *mockOrganizationRepository, teamRepo *mockOrganizationTeamRepository, teamMemberRepo *mockOrganizationTeamMemberRepository) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + teamRepo.On("GetByID", mock.Anything, "team-1").Return(&types.OrganizationTeam{ID: "team-1", OrganizationID: "org-1"}, nil).Once() + teamMemberRepo.On("GetByTeamIDAndMemberID", mock.Anything, "team-1", "member-1").Return(&types.OrganizationTeamMember{ID: "tm-1", TeamID: "team-1", MemberID: "member-1"}, nil).Once() + }, + expectID: "tm-1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + orgRepo := &mockOrganizationRepository{} + teamRepo := &mockOrganizationTeamRepository{} + teamMemberRepo := &mockOrganizationTeamMemberRepository{} + if tt.setup != nil { + tt.setup(orgRepo, teamRepo, teamMemberRepo) + } + + svc := NewOrganizationTeamService(orgRepo, teamRepo, &mockOrganizationMemberRepository{}, teamMemberRepo, nil, nil) + teamMember, err := svc.GetTeamMember(context.Background(), tt.actorUserID, tt.organizationID, tt.teamID, tt.memberID) + if tt.expectErr != nil { + require.Error(t, err) + require.ErrorIs(t, err, tt.expectErr) + return + } + require.NoError(t, err) + require.NotNil(t, teamMember) + require.Equal(t, tt.expectID, teamMember.ID) + require.True(t, orgRepo.AssertExpectations(t)) + require.True(t, teamRepo.AssertExpectations(t)) + require.True(t, teamMemberRepo.AssertExpectations(t)) + }) + } +} + +func TestOrganizationTeamService_RemoveTeamMember(t *testing.T) { + t.Parallel() + + someErr := errors.New("some error") + + tests := []struct { + name string + actorUserID string + organizationID string + teamID string + memberID string + setup func(*mockOrganizationRepository, *mockOrganizationTeamRepository, *mockOrganizationTeamMemberRepository, *testOrganizationTeamMemberHooks) + expectErr error + }{ + { + name: "not found", + actorUserID: "user-1", + organizationID: "org-1", + teamID: "team-1", + memberID: "member-1", + setup: func(orgRepo *mockOrganizationRepository, teamRepo *mockOrganizationTeamRepository, teamMemberRepo *mockOrganizationTeamMemberRepository, hooks *testOrganizationTeamMemberHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + teamRepo.On("GetByID", mock.Anything, "team-1").Return(&types.OrganizationTeam{ID: "team-1", OrganizationID: "org-1"}, nil).Once() + teamMemberRepo.On("GetByTeamIDAndMemberID", mock.Anything, "team-1", "member-1").Return(nil, nil).Once() + }, + expectErr: internalerrors.ErrNotFound, + }, + { + name: "before hook blocks", + actorUserID: "user-1", + organizationID: "org-1", + teamID: "team-1", + memberID: "member-1", + setup: func(orgRepo *mockOrganizationRepository, teamRepo *mockOrganizationTeamRepository, teamMemberRepo *mockOrganizationTeamMemberRepository, hooks *testOrganizationTeamMemberHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + teamRepo.On("GetByID", mock.Anything, "team-1").Return(&types.OrganizationTeam{ID: "team-1", OrganizationID: "org-1"}, nil).Once() + teamMemberRepo.On("GetByTeamIDAndMemberID", mock.Anything, "team-1", "member-1").Return(&types.OrganizationTeamMember{ID: "tm-1", TeamID: "team-1", MemberID: "member-1"}, nil).Once() + hooks.BeforeDelete = func(teamMember *types.OrganizationTeamMember) error { return someErr } + }, + expectErr: someErr, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + orgRepo := &mockOrganizationRepository{} + teamRepo := &mockOrganizationTeamRepository{} + memberRepo := &mockOrganizationMemberRepository{} + teamMemberRepo := &mockOrganizationTeamMemberRepository{} + hooks := &testOrganizationTeamMemberHooks{} + if tt.setup != nil { + tt.setup(orgRepo, teamRepo, teamMemberRepo, hooks) + } + + svc := NewOrganizationTeamService(orgRepo, teamRepo, memberRepo, teamMemberRepo, &testOrganizationTeamHooks{}, hooks) + err := svc.RemoveTeamMember(context.Background(), tt.actorUserID, tt.organizationID, tt.teamID, tt.memberID) + if tt.expectErr != nil { + require.Error(t, err) + require.ErrorIs(t, err, tt.expectErr) + return + } + require.NoError(t, err) + }) + } +} + +func TestOrganizationTeamService_GetAllTeams(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + actorUserID string + organizationID string + setup func(*mockOrganizationRepository, *mockOrganizationTeamRepository) + expectErr error + expectLen int + }{ + { + name: "success", + actorUserID: "user-1", + organizationID: "org-1", + setup: func(orgRepo *mockOrganizationRepository, teamRepo *mockOrganizationTeamRepository) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + teamRepo.On("GetAllByOrganizationID", mock.Anything, "org-1").Return([]types.OrganizationTeam{{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}}, nil).Once() + }, + expectLen: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + orgRepo := &mockOrganizationRepository{} + teamRepo := &mockOrganizationTeamRepository{} + if tt.setup != nil { + tt.setup(orgRepo, teamRepo) + } + + svc := NewOrganizationTeamService(orgRepo, teamRepo, &mockOrganizationMemberRepository{}, &mockOrganizationTeamMemberRepository{}, nil, nil) + teams, err := svc.GetAllTeams(context.Background(), tt.actorUserID, tt.organizationID) + if tt.expectErr != nil { + require.Error(t, err) + require.ErrorIs(t, err, tt.expectErr) + return + } + require.NoError(t, err) + require.Len(t, teams, tt.expectLen) + }) + } +} + +func TestOrganizationTeamService_UpdateTeam(t *testing.T) { + t.Parallel() + + someErr := errors.New("some error") + + tests := []struct { + name string + actorUserID string + organizationID string + teamID string + request types.UpdateOrganizationTeamRequest + setup func(*mockOrganizationRepository, *mockOrganizationTeamRepository, *testOrganizationTeamHooks) + expectErr error + }{ + { + name: "forbidden", + actorUserID: "user-1", + organizationID: "org-1", + teamID: "team-1", + request: types.UpdateOrganizationTeamRequest{Name: "Platform"}, + setup: func(orgRepo *mockOrganizationRepository, teamRepo *mockOrganizationTeamRepository, hooks *testOrganizationTeamHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "owner-1"}, nil).Once() + }, + expectErr: internalerrors.ErrForbidden, + }, + { + name: "success", + actorUserID: "user-1", + organizationID: "org-1", + teamID: "team-1", + request: types.UpdateOrganizationTeamRequest{Name: "Platform Revamp"}, + setup: func(orgRepo *mockOrganizationRepository, teamRepo *mockOrganizationTeamRepository, hooks *testOrganizationTeamHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + teamRepo.On("GetByID", mock.Anything, "team-1").Return(&types.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + teamRepo.On("GetByOrganizationIDAndSlug", mock.Anything, "org-1", "platform").Return(nil, nil).Once() + teamRepo.On("Update", mock.Anything, mock.MatchedBy(func(team *types.OrganizationTeam) bool { + return team != nil && team.ID == "team-1" && team.Name == "Platform Revamp" && team.Slug == "platform" + })).Return(&types.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform Revamp", Slug: "platform"}, nil).Once() + hooks.BeforeUpdate = func(team *types.OrganizationTeam) error { return nil } + hooks.AfterUpdate = func(team types.OrganizationTeam) error { return nil } + }, + }, + { + name: "before hook blocks", + actorUserID: "user-1", + organizationID: "org-1", + teamID: "team-1", + request: types.UpdateOrganizationTeamRequest{Name: "Platform Revamp"}, + setup: func(orgRepo *mockOrganizationRepository, teamRepo *mockOrganizationTeamRepository, hooks *testOrganizationTeamHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + teamRepo.On("GetByID", mock.Anything, "team-1").Return(&types.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + teamRepo.On("GetByOrganizationIDAndSlug", mock.Anything, "org-1", "platform").Return(nil, nil).Once() + hooks.BeforeUpdate = func(team *types.OrganizationTeam) error { return someErr } + }, + expectErr: someErr, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + orgRepo := &mockOrganizationRepository{} + teamRepo := &mockOrganizationTeamRepository{} + hooks := &testOrganizationTeamHooks{} + if tt.setup != nil { + tt.setup(orgRepo, teamRepo, hooks) + } + + svc := NewOrganizationTeamService(orgRepo, teamRepo, &mockOrganizationMemberRepository{}, &mockOrganizationTeamMemberRepository{}, hooks, nil) + team, err := svc.UpdateTeam(context.Background(), tt.actorUserID, tt.organizationID, tt.teamID, tt.request) + if tt.expectErr != nil { + require.Error(t, err) + require.ErrorIs(t, err, tt.expectErr) + return + } + require.NoError(t, err) + require.NotNil(t, team) + require.True(t, orgRepo.AssertExpectations(t)) + require.True(t, teamRepo.AssertExpectations(t)) + }) + } +} + +func TestOrganizationTeamService_DeleteTeam(t *testing.T) { + t.Parallel() + + someErr := errors.New("some error") + + tests := []struct { + name string + actorUserID string + organizationID string + teamID string + setup func(*mockOrganizationRepository, *mockOrganizationTeamRepository, *testOrganizationTeamHooks) + expectErr error + }{ + { + name: "success", + actorUserID: "user-1", + organizationID: "org-1", + teamID: "team-1", + setup: func(orgRepo *mockOrganizationRepository, teamRepo *mockOrganizationTeamRepository, hooks *testOrganizationTeamHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + teamRepo.On("GetByID", mock.Anything, "team-1").Return(&types.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + teamRepo.On("Delete", mock.Anything, "team-1").Return(nil).Once() + hooks.BeforeDelete = func(team *types.OrganizationTeam) error { return nil } + hooks.AfterDelete = func(team types.OrganizationTeam) error { return nil } + }, + }, + { + name: "before hook blocks", + actorUserID: "user-1", + organizationID: "org-1", + teamID: "team-1", + setup: func(orgRepo *mockOrganizationRepository, teamRepo *mockOrganizationTeamRepository, hooks *testOrganizationTeamHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + teamRepo.On("GetByID", mock.Anything, "team-1").Return(&types.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + hooks.BeforeDelete = func(team *types.OrganizationTeam) error { return someErr } + }, + expectErr: someErr, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + orgRepo := &mockOrganizationRepository{} + teamRepo := &mockOrganizationTeamRepository{} + hooks := &testOrganizationTeamHooks{} + if tt.setup != nil { + tt.setup(orgRepo, teamRepo, hooks) + } + + svc := NewOrganizationTeamService(orgRepo, teamRepo, &mockOrganizationMemberRepository{}, &mockOrganizationTeamMemberRepository{}, hooks, nil) + err := svc.DeleteTeam(context.Background(), tt.actorUserID, tt.organizationID, tt.teamID) + if tt.expectErr != nil { + require.Error(t, err) + require.ErrorIs(t, err, tt.expectErr) + return + } + require.NoError(t, err) + require.True(t, orgRepo.AssertExpectations(t)) + require.True(t, teamRepo.AssertExpectations(t)) + }) + } +} + +func TestOrganizationTeamService_GetAllTeamMembers(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + actorUserID string + organizationID string + teamID string + setup func(*mockOrganizationRepository, *mockOrganizationTeamRepository, *mockOrganizationTeamMemberRepository) + expectErr error + expectLen int + }{ + { + name: "success", + actorUserID: "user-1", + organizationID: "org-1", + teamID: "team-1", + setup: func(orgRepo *mockOrganizationRepository, teamRepo *mockOrganizationTeamRepository, teamMemberRepo *mockOrganizationTeamMemberRepository) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + teamRepo.On("GetByID", mock.Anything, "team-1").Return(&types.OrganizationTeam{ID: "team-1", OrganizationID: "org-1"}, nil).Once() + teamMemberRepo.On("GetAllByTeamID", mock.Anything, "team-1", 1, 10).Return([]types.OrganizationTeamMember{{ID: "tm-1", TeamID: "team-1", MemberID: "member-1"}}, nil).Once() + }, + expectLen: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + orgRepo := &mockOrganizationRepository{} + teamRepo := &mockOrganizationTeamRepository{} + teamMemberRepo := &mockOrganizationTeamMemberRepository{} + if tt.setup != nil { + tt.setup(orgRepo, teamRepo, teamMemberRepo) + } + + svc := NewOrganizationTeamService(orgRepo, teamRepo, &mockOrganizationMemberRepository{}, teamMemberRepo, nil, nil) + members, err := svc.GetAllTeamMembers(context.Background(), tt.actorUserID, tt.organizationID, tt.teamID, 1, 10) + if tt.expectErr != nil { + require.Error(t, err) + require.ErrorIs(t, err, tt.expectErr) + return + } + require.NoError(t, err) + require.Len(t, members, tt.expectLen) + }) + } +} diff --git a/plugins/organizations/tests/test_helpers.go b/plugins/organizations/tests/test_helpers.go new file mode 100644 index 00000000..210a85cf --- /dev/null +++ b/plugins/organizations/tests/test_helpers.go @@ -0,0 +1,503 @@ +package tests + +import ( + "context" + "database/sql" + + "github.com/stretchr/testify/mock" + "github.com/uptrace/bun" + + "github.com/Authula/authula/plugins/organizations/repositories" + "github.com/Authula/authula/plugins/organizations/types" +) + +type MockOrganizationRepository struct { + mock.Mock +} + +func (m *MockOrganizationRepository) Create(ctx context.Context, organization *types.Organization) (*types.Organization, error) { + args := m.Called(ctx, organization) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*types.Organization), args.Error(1) +} + +func (m *MockOrganizationRepository) GetByID(ctx context.Context, organizationID string) (*types.Organization, error) { + args := m.Called(ctx, organizationID) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*types.Organization), args.Error(1) +} + +func (m *MockOrganizationRepository) GetBySlug(ctx context.Context, slug string) (*types.Organization, error) { + args := m.Called(ctx, slug) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*types.Organization), args.Error(1) +} + +func (m *MockOrganizationRepository) GetAllByOwnerID(ctx context.Context, ownerID string) ([]types.Organization, error) { + args := m.Called(ctx, ownerID) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]types.Organization), args.Error(1) +} + +func (m *MockOrganizationRepository) Update(ctx context.Context, organization *types.Organization) (*types.Organization, error) { + args := m.Called(ctx, organization) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*types.Organization), args.Error(1) +} + +func (m *MockOrganizationRepository) Delete(ctx context.Context, organizationID string) error { + return m.Called(ctx, organizationID).Error(0) +} + +func (m *MockOrganizationRepository) WithTx(_ bun.IDB) repositories.OrganizationRepository { + return m +} + +type MockOrganizationMemberRepository struct { + mock.Mock +} + +func (m *MockOrganizationMemberRepository) Create(ctx context.Context, member *types.OrganizationMember) (*types.OrganizationMember, error) { + args := m.Called(ctx, member) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*types.OrganizationMember), args.Error(1) +} + +func (m *MockOrganizationMemberRepository) GetByOrganizationIDAndUserID(ctx context.Context, organizationID, userID string) (*types.OrganizationMember, error) { + args := m.Called(ctx, organizationID, userID) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*types.OrganizationMember), args.Error(1) +} + +func (m *MockOrganizationMemberRepository) GetAllByOrganizationID(ctx context.Context, organizationID string, page int, limit int) ([]types.OrganizationMember, error) { + args := m.Called(ctx, organizationID, page, limit) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]types.OrganizationMember), args.Error(1) +} + +func (m *MockOrganizationMemberRepository) GetByID(ctx context.Context, memberID string) (*types.OrganizationMember, error) { + args := m.Called(ctx, memberID) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*types.OrganizationMember), args.Error(1) +} + +func (m *MockOrganizationMemberRepository) Update(ctx context.Context, member *types.OrganizationMember) (*types.OrganizationMember, error) { + args := m.Called(ctx, member) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*types.OrganizationMember), args.Error(1) +} + +func (m *MockOrganizationMemberRepository) Delete(ctx context.Context, memberID string) error { + return m.Called(ctx, memberID).Error(0) +} + +func (m *MockOrganizationMemberRepository) WithTx(_ bun.IDB) repositories.OrganizationMemberRepository { + return m +} + +type MockOrganizationInvitationRepository struct { + mock.Mock +} + +func (m *MockOrganizationInvitationRepository) Create(ctx context.Context, invitation *types.OrganizationInvitation) (*types.OrganizationInvitation, error) { + args := m.Called(ctx, invitation) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*types.OrganizationInvitation), args.Error(1) +} + +func (m *MockOrganizationInvitationRepository) GetByID(ctx context.Context, invitationID string) (*types.OrganizationInvitation, error) { + args := m.Called(ctx, invitationID) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*types.OrganizationInvitation), args.Error(1) +} + +func (m *MockOrganizationInvitationRepository) GetByOrganizationIDAndEmail(ctx context.Context, organizationID, email string) (*types.OrganizationInvitation, error) { + args := m.Called(ctx, organizationID, email) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*types.OrganizationInvitation), args.Error(1) +} + +func (m *MockOrganizationInvitationRepository) GetAllByOrganizationID(ctx context.Context, organizationID string) ([]types.OrganizationInvitation, error) { + args := m.Called(ctx, organizationID) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]types.OrganizationInvitation), args.Error(1) +} + +func (m *MockOrganizationInvitationRepository) GetAllPendingByEmail(ctx context.Context, email string) ([]types.OrganizationInvitation, error) { + args := m.Called(ctx, email) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]types.OrganizationInvitation), args.Error(1) +} + +func (m *MockOrganizationInvitationRepository) Update(ctx context.Context, invitation *types.OrganizationInvitation) (*types.OrganizationInvitation, error) { + args := m.Called(ctx, invitation) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*types.OrganizationInvitation), args.Error(1) +} + +func (m *MockOrganizationInvitationRepository) WithTx(_ bun.IDB) repositories.OrganizationInvitationRepository { + return m +} + +type MockOrganizationTeamRepository struct { + mock.Mock +} + +func (m *MockOrganizationTeamRepository) Create(ctx context.Context, team *types.OrganizationTeam) (*types.OrganizationTeam, error) { + args := m.Called(ctx, team) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*types.OrganizationTeam), args.Error(1) +} + +func (m *MockOrganizationTeamRepository) GetByID(ctx context.Context, teamID string) (*types.OrganizationTeam, error) { + args := m.Called(ctx, teamID) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*types.OrganizationTeam), args.Error(1) +} + +func (m *MockOrganizationTeamRepository) GetByOrganizationIDAndSlug(ctx context.Context, organizationID, slug string) (*types.OrganizationTeam, error) { + args := m.Called(ctx, organizationID, slug) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*types.OrganizationTeam), args.Error(1) +} + +func (m *MockOrganizationTeamRepository) GetAllByOrganizationID(ctx context.Context, organizationID string) ([]types.OrganizationTeam, error) { + args := m.Called(ctx, organizationID) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]types.OrganizationTeam), args.Error(1) +} + +func (m *MockOrganizationTeamRepository) Update(ctx context.Context, team *types.OrganizationTeam) (*types.OrganizationTeam, error) { + args := m.Called(ctx, team) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*types.OrganizationTeam), args.Error(1) +} + +func (m *MockOrganizationTeamRepository) Delete(ctx context.Context, teamID string) error { + return m.Called(ctx, teamID).Error(0) +} + +func (m *MockOrganizationTeamRepository) WithTx(_ bun.IDB) repositories.OrganizationTeamRepository { + return m +} + +type MockOrganizationTeamMemberRepository struct { + mock.Mock +} + +func (m *MockOrganizationTeamMemberRepository) Create(ctx context.Context, teamMember *types.OrganizationTeamMember) (*types.OrganizationTeamMember, error) { + args := m.Called(ctx, teamMember) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*types.OrganizationTeamMember), args.Error(1) +} + +func (m *MockOrganizationTeamMemberRepository) GetByID(ctx context.Context, teamMemberID string) (*types.OrganizationTeamMember, error) { + args := m.Called(ctx, teamMemberID) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*types.OrganizationTeamMember), args.Error(1) +} + +func (m *MockOrganizationTeamMemberRepository) GetByTeamIDAndMemberID(ctx context.Context, teamID, memberID string) (*types.OrganizationTeamMember, error) { + args := m.Called(ctx, teamID, memberID) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*types.OrganizationTeamMember), args.Error(1) +} + +func (m *MockOrganizationTeamMemberRepository) GetAllByTeamID(ctx context.Context, teamID string, page int, limit int) ([]types.OrganizationTeamMember, error) { + args := m.Called(ctx, teamID, page, limit) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]types.OrganizationTeamMember), args.Error(1) +} + +func (m *MockOrganizationTeamMemberRepository) DeleteByTeamIDAndMemberID(ctx context.Context, teamID, memberID string) error { + return m.Called(ctx, teamID, memberID).Error(0) +} + +func (m *MockOrganizationTeamMemberRepository) WithTx(_ bun.IDB) repositories.OrganizationTeamMemberRepository { + return m +} + +type MockOrganizationInvitationTxRunner struct { + run func(func(context.Context, bun.Tx) error) error +} + +func (r *MockOrganizationInvitationTxRunner) RunInTx(ctx context.Context, _ *sql.TxOptions, fn func(context.Context, bun.Tx) error) error { + if r.run != nil { + return r.run(fn) + } + var tx bun.Tx + return fn(ctx, tx) +} + +type TestOrganizationHooks struct { + BeforeCreate func(*types.Organization) error + AfterCreate func(types.Organization) error + BeforeUpdate func(*types.Organization) error + AfterUpdate func(types.Organization) error + BeforeDelete func(*types.Organization) error + AfterDelete func(types.Organization) error +} + +func (h *TestOrganizationHooks) BeforeCreateOrganization(organization *types.Organization) error { + if h.BeforeCreate == nil { + return nil + } + return h.BeforeCreate(organization) +} + +func (h *TestOrganizationHooks) AfterCreateOrganization(organization types.Organization) error { + if h.AfterCreate == nil { + return nil + } + return h.AfterCreate(organization) +} + +func (h *TestOrganizationHooks) BeforeUpdateOrganization(organization *types.Organization) error { + if h.BeforeUpdate == nil { + return nil + } + return h.BeforeUpdate(organization) +} + +func (h *TestOrganizationHooks) AfterUpdateOrganization(organization types.Organization) error { + if h.AfterUpdate == nil { + return nil + } + return h.AfterUpdate(organization) +} + +func (h *TestOrganizationHooks) BeforeDeleteOrganization(organization *types.Organization) error { + if h.BeforeDelete == nil { + return nil + } + return h.BeforeDelete(organization) +} + +func (h *TestOrganizationHooks) AfterDeleteOrganization(organization types.Organization) error { + if h.AfterDelete == nil { + return nil + } + return h.AfterDelete(organization) +} + +type TestOrganizationMemberHooks struct { + Before func(*types.OrganizationMember) error + After func(types.OrganizationMember) error + BeforeUpdate func(*types.OrganizationMember) error + AfterUpdate func(types.OrganizationMember) error + BeforeDelete func(*types.OrganizationMember) error + AfterDelete func(types.OrganizationMember) error +} + +func (h *TestOrganizationMemberHooks) BeforeCreateOrganizationMember(member *types.OrganizationMember) error { + if h.Before == nil { + return nil + } + return h.Before(member) +} + +func (h *TestOrganizationMemberHooks) AfterCreateOrganizationMember(member types.OrganizationMember) error { + if h.After == nil { + return nil + } + return h.After(member) +} + +func (h *TestOrganizationMemberHooks) BeforeUpdateOrganizationMember(member *types.OrganizationMember) error { + if h.BeforeUpdate == nil { + return nil + } + return h.BeforeUpdate(member) +} + +func (h *TestOrganizationMemberHooks) AfterUpdateOrganizationMember(member types.OrganizationMember) error { + if h.AfterUpdate == nil { + return nil + } + return h.AfterUpdate(member) +} + +func (h *TestOrganizationMemberHooks) BeforeDeleteOrganizationMember(member *types.OrganizationMember) error { + if h.BeforeDelete == nil { + return nil + } + return h.BeforeDelete(member) +} + +func (h *TestOrganizationMemberHooks) AfterDeleteOrganizationMember(member types.OrganizationMember) error { + if h.AfterDelete == nil { + return nil + } + return h.AfterDelete(member) +} + +type TestOrganizationInvitationHooks struct { + Before func(*types.OrganizationInvitation) error + After func(types.OrganizationInvitation) error + BeforeUpdate func(*types.OrganizationInvitation) error + AfterUpdate func(types.OrganizationInvitation) error +} + +func (h *TestOrganizationInvitationHooks) BeforeCreateOrganizationInvitation(invitation *types.OrganizationInvitation) error { + if h.Before == nil { + return nil + } + return h.Before(invitation) +} + +func (h *TestOrganizationInvitationHooks) AfterCreateOrganizationInvitation(invitation types.OrganizationInvitation) error { + if h.After == nil { + return nil + } + return h.After(invitation) +} + +func (h *TestOrganizationInvitationHooks) BeforeUpdateOrganizationInvitation(invitation *types.OrganizationInvitation) error { + if h.BeforeUpdate == nil { + return nil + } + return h.BeforeUpdate(invitation) +} + +func (h *TestOrganizationInvitationHooks) AfterUpdateOrganizationInvitation(invitation types.OrganizationInvitation) error { + if h.AfterUpdate == nil { + return nil + } + return h.AfterUpdate(invitation) +} + +type TestOrganizationTeamHooks struct { + BeforeCreate func(*types.OrganizationTeam) error + AfterCreate func(types.OrganizationTeam) error + BeforeUpdate func(*types.OrganizationTeam) error + AfterUpdate func(types.OrganizationTeam) error + BeforeDelete func(*types.OrganizationTeam) error + AfterDelete func(types.OrganizationTeam) error +} + +func (h *TestOrganizationTeamHooks) BeforeCreateOrganizationTeam(team *types.OrganizationTeam) error { + if h.BeforeCreate == nil { + return nil + } + return h.BeforeCreate(team) +} + +func (h *TestOrganizationTeamHooks) AfterCreateOrganizationTeam(team types.OrganizationTeam) error { + if h.AfterCreate == nil { + return nil + } + return h.AfterCreate(team) +} + +func (h *TestOrganizationTeamHooks) BeforeUpdateOrganizationTeam(team *types.OrganizationTeam) error { + if h.BeforeUpdate == nil { + return nil + } + return h.BeforeUpdate(team) +} + +func (h *TestOrganizationTeamHooks) AfterUpdateOrganizationTeam(team types.OrganizationTeam) error { + if h.AfterUpdate == nil { + return nil + } + return h.AfterUpdate(team) +} + +func (h *TestOrganizationTeamHooks) BeforeDeleteOrganizationTeam(team *types.OrganizationTeam) error { + if h.BeforeDelete == nil { + return nil + } + return h.BeforeDelete(team) +} + +func (h *TestOrganizationTeamHooks) AfterDeleteOrganizationTeam(team types.OrganizationTeam) error { + if h.AfterDelete == nil { + return nil + } + return h.AfterDelete(team) +} + +type TestOrganizationTeamMemberHooks struct { + BeforeCreate func(*types.OrganizationTeamMember) error + AfterCreate func(types.OrganizationTeamMember) error + BeforeDelete func(*types.OrganizationTeamMember) error + AfterDelete func(types.OrganizationTeamMember) error +} + +func (h *TestOrganizationTeamMemberHooks) BeforeCreateOrganizationTeamMember(teamMember *types.OrganizationTeamMember) error { + if h.BeforeCreate == nil { + return nil + } + return h.BeforeCreate(teamMember) +} + +func (h *TestOrganizationTeamMemberHooks) AfterCreateOrganizationTeamMember(teamMember types.OrganizationTeamMember) error { + if h.AfterCreate == nil { + return nil + } + return h.AfterCreate(teamMember) +} + +func (h *TestOrganizationTeamMemberHooks) BeforeDeleteOrganizationTeamMember(teamMember *types.OrganizationTeamMember) error { + if h.BeforeDelete == nil { + return nil + } + return h.BeforeDelete(teamMember) +} + +func (h *TestOrganizationTeamMemberHooks) AfterDeleteOrganizationTeamMember(teamMember types.OrganizationTeamMember) error { + if h.AfterDelete == nil { + return nil + } + return h.AfterDelete(teamMember) +} diff --git a/plugins/organizations/types/models.go b/plugins/organizations/types/models.go new file mode 100644 index 00000000..5c134c0d --- /dev/null +++ b/plugins/organizations/types/models.go @@ -0,0 +1,93 @@ +package types + +import ( + "encoding/json" + "time" + + "github.com/uptrace/bun" + + "github.com/Authula/authula/models" +) + +type Organization struct { + bun.BaseModel `bun:"table:organizations"` + + ID string `json:"id" bun:"column:id,pk"` + OwnerID string `json:"owner_id" bun:"column:owner_id"` + Name string `json:"name" bun:"column:name"` + Slug string `json:"slug" bun:"column:slug"` + Logo *string `json:"logo" bun:"column:logo"` + Metadata json.RawMessage `json:"metadata" bun:"column:metadata"` + CreatedAt time.Time `json:"created_at" bun:"column:created_at,default:current_timestamp"` + UpdatedAt time.Time `json:"updated_at" bun:"column:updated_at,default:current_timestamp"` + + Owner models.User `json:"-" bun:"rel:belongs-to,join:owner_id=id"` +} + +type OrganizationInvitationStatus string + +const ( + OrganizationInvitationStatusPending OrganizationInvitationStatus = "pending" + OrganizationInvitationStatusAccepted OrganizationInvitationStatus = "accepted" + OrganizationInvitationStatusRejected OrganizationInvitationStatus = "rejected" + OrganizationInvitationStatusRevoked OrganizationInvitationStatus = "revoked" + OrganizationInvitationStatusExpired OrganizationInvitationStatus = "expired" +) + +type OrganizationInvitation struct { + bun.BaseModel `bun:"table:organization_invitations"` + + ID string `json:"id" bun:"column:id,pk"` + Email string `json:"email" bun:"column:email"` + InviterID string `json:"inviter_id" bun:"column:inviter_id"` + OrganizationID string `json:"organization_id" bun:"column:organization_id"` + Role string `json:"role" bun:"column:role"` + Status OrganizationInvitationStatus `json:"status" bun:"column:status"` + ExpiresAt time.Time `json:"expires_at" bun:"column:expires_at"` + CreatedAt time.Time `json:"created_at" bun:"column:created_at,default:current_timestamp"` + UpdatedAt time.Time `json:"updated_at" bun:"column:updated_at,default:current_timestamp"` + + Organization Organization `json:"-" bun:"rel:belongs-to,join:organization_id=id"` + Inviter *models.User `json:"-" bun:"rel:belongs-to,join:inviter_id=id"` +} + +type OrganizationMember struct { + bun.BaseModel `bun:"table:organization_members"` + + ID string `json:"id" bun:"column:id,pk"` + OrganizationID string `json:"organization_id" bun:"column:organization_id"` + UserID string `json:"user_id" bun:"column:user_id"` + Role string `json:"role" bun:"column:role"` + CreatedAt time.Time `json:"created_at" bun:"column:created_at,default:current_timestamp"` + UpdatedAt time.Time `json:"updated_at" bun:"column:updated_at,default:current_timestamp"` + + Organization Organization `json:"-" bun:"rel:belongs-to,join:organization_id=id"` + User models.User `json:"-" bun:"rel:belongs-to,join:user_id=id"` +} + +type OrganizationTeam struct { + bun.BaseModel `bun:"table:organization_teams"` + + ID string `json:"id" bun:"column:id,pk"` + OrganizationID string `json:"organization_id" bun:"column:organization_id"` + Name string `json:"name" bun:"column:name"` + Slug string `json:"slug" bun:"column:slug"` + Description *string `json:"description" bun:"column:description"` + Metadata json.RawMessage `json:"metadata" bun:"column:metadata"` + CreatedAt time.Time `json:"created_at" bun:"column:created_at,default:current_timestamp"` + UpdatedAt time.Time `json:"updated_at" bun:"column:updated_at,default:current_timestamp"` + + Organization Organization `json:"-" bun:"rel:belongs-to,join:organization_id=id"` +} + +type OrganizationTeamMember struct { + bun.BaseModel `bun:"table:organization_team_members"` + + ID string `json:"id" bun:"column:id,pk"` + TeamID string `json:"team_id" bun:"column:team_id"` + MemberID string `json:"member_id" bun:"column:member_id"` + CreatedAt time.Time `json:"created_at" bun:"column:created_at,default:current_timestamp"` + + Team OrganizationTeam `json:"-" bun:"rel:belongs-to,join:team_id=id"` + Member OrganizationMember `json:"-" bun:"rel:belongs-to,join:member_id=id"` +} diff --git a/plugins/organizations/types/requests.go b/plugins/organizations/types/requests.go new file mode 100644 index 00000000..3dc9b1ca --- /dev/null +++ b/plugins/organizations/types/requests.go @@ -0,0 +1,52 @@ +package types + +import ( + "encoding/json" +) + +type CreateOrganizationRequest struct { + Name string `json:"name"` + Slug *string `json:"slug,omitempty"` + Logo *string `json:"logo,omitempty"` + Metadata json.RawMessage `json:"metadata,omitempty"` +} + +type UpdateOrganizationRequest struct { + Name string `json:"name"` + Slug *string `json:"slug,omitempty"` + Logo *string `json:"logo,omitempty"` + Metadata json.RawMessage `json:"metadata,omitempty"` +} + +type CreateOrganizationInvitationRequest struct { + Email string `json:"email"` + Role string `json:"role"` + RedirectURL string `json:"redirect_url,omitempty"` +} + +type AddOrganizationMemberRequest struct { + UserID string `json:"user_id"` + Role string `json:"role"` +} + +type UpdateOrganizationMemberRequest struct { + Role string `json:"role"` +} + +type CreateOrganizationTeamRequest struct { + Name string `json:"name"` + Slug *string `json:"slug,omitempty"` + Description *string `json:"description,omitempty"` + Metadata json.RawMessage `json:"metadata,omitempty"` +} + +type UpdateOrganizationTeamRequest struct { + Name string `json:"name"` + Slug *string `json:"slug,omitempty"` + Description *string `json:"description,omitempty"` + Metadata json.RawMessage `json:"metadata,omitempty"` +} + +type AddOrganizationTeamMemberRequest struct { + MemberID string `json:"member_id"` +} diff --git a/plugins/organizations/types/types.go b/plugins/organizations/types/types.go new file mode 100644 index 00000000..85d107ff --- /dev/null +++ b/plugins/organizations/types/types.go @@ -0,0 +1,66 @@ +package types + +import "time" + +type OrganizationsPluginConfig struct { + Enabled bool `json:"enabled" toml:"enabled"` + InvitationExpiresIn time.Duration `json:"invitation_expires_in" toml:"invitation_expires_in"` + DatabaseHooks *OrganizationsDatabaseHooksConfig `json:"-" toml:"-"` +} + +func (config *OrganizationsPluginConfig) ApplyDefaults() { + if config.InvitationExpiresIn == 0 { + config.InvitationExpiresIn = 7 * 24 * time.Hour + } +} + +type OrganizationsDatabaseHooksConfig struct { + Organizations *OrganizationDatabaseHooksConfig + Members *OrganizationMemberDatabaseHooksConfig + Invitations *OrganizationInvitationDatabaseHooksConfig + Teams *OrganizationTeamDatabaseHooksConfig + TeamMembers *OrganizationTeamMemberDatabaseHooksConfig +} + +type OrganizationDatabaseHooksConfig struct { + BeforeCreate func(organization *Organization) error + AfterCreate func(organization Organization) error + BeforeUpdate func(organization *Organization) error + AfterUpdate func(organization Organization) error + BeforeDelete func(organization *Organization) error + AfterDelete func(organization Organization) error +} + +type OrganizationMemberDatabaseHooksConfig struct { + BeforeCreate func(member *OrganizationMember) error + AfterCreate func(member OrganizationMember) error + BeforeUpdate func(member *OrganizationMember) error + AfterUpdate func(member OrganizationMember) error + BeforeDelete func(member *OrganizationMember) error + AfterDelete func(member OrganizationMember) error +} + +type OrganizationInvitationDatabaseHooksConfig struct { + BeforeCreate func(invitation *OrganizationInvitation) error + AfterCreate func(invitation OrganizationInvitation) error + BeforeUpdate func(invitation *OrganizationInvitation) error + AfterUpdate func(invitation OrganizationInvitation) error + BeforeDelete func(invitation *OrganizationInvitation) error + AfterDelete func(invitation OrganizationInvitation) error +} + +type OrganizationTeamDatabaseHooksConfig struct { + BeforeCreate func(team *OrganizationTeam) error + AfterCreate func(team OrganizationTeam) error + BeforeUpdate func(team *OrganizationTeam) error + AfterUpdate func(team OrganizationTeam) error + BeforeDelete func(team *OrganizationTeam) error + AfterDelete func(team OrganizationTeam) error +} + +type OrganizationTeamMemberDatabaseHooksConfig struct { + BeforeCreate func(member *OrganizationTeamMember) error + AfterCreate func(member OrganizationTeamMember) error + BeforeDelete func(member *OrganizationTeamMember) error + AfterDelete func(member OrganizationTeamMember) error +} diff --git a/plugins/organizations/usecases/organization_invitation_usecase.go b/plugins/organizations/usecases/organization_invitation_usecase.go new file mode 100644 index 00000000..dedbc658 --- /dev/null +++ b/plugins/organizations/usecases/organization_invitation_usecase.go @@ -0,0 +1,40 @@ +package usecases + +import ( + "context" + + "github.com/Authula/authula/plugins/organizations/services" + "github.com/Authula/authula/plugins/organizations/types" +) + +type OrganizationInvitationUseCase struct { + service *services.OrganizationInvitationService +} + +func NewOrganizationInvitationUseCase(service *services.OrganizationInvitationService) *OrganizationInvitationUseCase { + return &OrganizationInvitationUseCase{service: service} +} + +func (u *OrganizationInvitationUseCase) CreateInvitation(ctx context.Context, actorUserID string, organizationID string, request types.CreateOrganizationInvitationRequest) (*types.OrganizationInvitation, error) { + return u.service.CreateOrganizationInvitation(ctx, actorUserID, organizationID, request) +} + +func (u *OrganizationInvitationUseCase) GetInvitation(ctx context.Context, actorUserID string, organizationID string, invitationID string) (*types.OrganizationInvitation, error) { + return u.service.GetOrganizationInvitation(ctx, actorUserID, organizationID, invitationID) +} + +func (u *OrganizationInvitationUseCase) GetAllInvitations(ctx context.Context, actorUserID string, organizationID string) ([]types.OrganizationInvitation, error) { + return u.service.GetAllOrganizationInvitations(ctx, actorUserID, organizationID) +} + +func (u *OrganizationInvitationUseCase) RevokeInvitation(ctx context.Context, actorUserID string, organizationID string, invitationID string) (*types.OrganizationInvitation, error) { + return u.service.RevokeOrganizationInvitation(ctx, actorUserID, organizationID, invitationID) +} + +func (u *OrganizationInvitationUseCase) AcceptInvitation(ctx context.Context, actorUserID string, organizationID string, invitationID string) (*types.OrganizationInvitation, error) { + return u.service.AcceptOrganizationInvitation(ctx, actorUserID, organizationID, invitationID) +} + +func (u *OrganizationInvitationUseCase) RejectInvitation(ctx context.Context, actorUserID string, organizationID string, invitationID string) (*types.OrganizationInvitation, error) { + return u.service.RejectOrganizationInvitation(ctx, actorUserID, organizationID, invitationID) +} diff --git a/plugins/organizations/usecases/organization_member_usecase.go b/plugins/organizations/usecases/organization_member_usecase.go new file mode 100644 index 00000000..e4502d41 --- /dev/null +++ b/plugins/organizations/usecases/organization_member_usecase.go @@ -0,0 +1,36 @@ +package usecases + +import ( + "context" + + "github.com/Authula/authula/plugins/organizations/services" + "github.com/Authula/authula/plugins/organizations/types" +) + +type OrganizationMemberUseCase struct { + service *services.OrganizationMemberService +} + +func NewOrganizationMemberUseCase(service *services.OrganizationMemberService) *OrganizationMemberUseCase { + return &OrganizationMemberUseCase{service: service} +} + +func (u *OrganizationMemberUseCase) AddMember(ctx context.Context, actorUserID string, organizationID string, request types.AddOrganizationMemberRequest) (*types.OrganizationMember, error) { + return u.service.AddMember(ctx, actorUserID, organizationID, request) +} + +func (u *OrganizationMemberUseCase) GetAllMembers(ctx context.Context, actorUserID string, organizationID string, page int, limit int) ([]types.OrganizationMember, error) { + return u.service.GetAllMembers(ctx, actorUserID, organizationID, page, limit) +} + +func (u *OrganizationMemberUseCase) GetMember(ctx context.Context, actorUserID string, organizationID string, memberID string) (*types.OrganizationMember, error) { + return u.service.GetMember(ctx, actorUserID, organizationID, memberID) +} + +func (u *OrganizationMemberUseCase) UpdateMember(ctx context.Context, actorUserID string, organizationID string, memberID string, request types.UpdateOrganizationMemberRequest) (*types.OrganizationMember, error) { + return u.service.UpdateMember(ctx, actorUserID, organizationID, memberID, request) +} + +func (u *OrganizationMemberUseCase) RemoveMember(ctx context.Context, actorUserID string, organizationID string, memberID string) error { + return u.service.RemoveMember(ctx, actorUserID, organizationID, memberID) +} diff --git a/plugins/organizations/usecases/organization_team_usecase.go b/plugins/organizations/usecases/organization_team_usecase.go new file mode 100644 index 00000000..afe0c9bd --- /dev/null +++ b/plugins/organizations/usecases/organization_team_usecase.go @@ -0,0 +1,48 @@ +package usecases + +import ( + "context" + + "github.com/Authula/authula/plugins/organizations/services" + "github.com/Authula/authula/plugins/organizations/types" +) + +type OrganizationTeamUseCase struct { + service *services.OrganizationTeamService +} + +func NewOrganizationTeamUseCase(service *services.OrganizationTeamService) *OrganizationTeamUseCase { + return &OrganizationTeamUseCase{service: service} +} + +func (u *OrganizationTeamUseCase) CreateTeam(ctx context.Context, actorUserID string, organizationID string, request types.CreateOrganizationTeamRequest) (*types.OrganizationTeam, error) { + return u.service.CreateTeam(ctx, actorUserID, organizationID, request) +} + +func (u *OrganizationTeamUseCase) GetAllTeams(ctx context.Context, actorUserID string, organizationID string) ([]types.OrganizationTeam, error) { + return u.service.GetAllTeams(ctx, actorUserID, organizationID) +} + +func (u *OrganizationTeamUseCase) UpdateTeam(ctx context.Context, actorUserID string, organizationID string, teamID string, request types.UpdateOrganizationTeamRequest) (*types.OrganizationTeam, error) { + return u.service.UpdateTeam(ctx, actorUserID, organizationID, teamID, request) +} + +func (u *OrganizationTeamUseCase) DeleteTeam(ctx context.Context, actorUserID string, organizationID string, teamID string) error { + return u.service.DeleteTeam(ctx, actorUserID, organizationID, teamID) +} + +func (u *OrganizationTeamUseCase) AddTeamMember(ctx context.Context, actorUserID string, organizationID string, teamID string, request types.AddOrganizationTeamMemberRequest) (*types.OrganizationTeamMember, error) { + return u.service.AddTeamMember(ctx, actorUserID, organizationID, teamID, request) +} + +func (u *OrganizationTeamUseCase) GetTeamMember(ctx context.Context, actorUserID string, organizationID string, teamID string, memberID string) (*types.OrganizationTeamMember, error) { + return u.service.GetTeamMember(ctx, actorUserID, organizationID, teamID, memberID) +} + +func (u *OrganizationTeamUseCase) GetAllTeamMembers(ctx context.Context, actorUserID string, organizationID string, teamID string, page int, limit int) ([]types.OrganizationTeamMember, error) { + return u.service.GetAllTeamMembers(ctx, actorUserID, organizationID, teamID, page, limit) +} + +func (u *OrganizationTeamUseCase) RemoveTeamMember(ctx context.Context, actorUserID string, organizationID string, teamID string, memberID string) error { + return u.service.RemoveTeamMember(ctx, actorUserID, organizationID, teamID, memberID) +} diff --git a/plugins/organizations/usecases/organization_usecase.go b/plugins/organizations/usecases/organization_usecase.go new file mode 100644 index 00000000..8760fd67 --- /dev/null +++ b/plugins/organizations/usecases/organization_usecase.go @@ -0,0 +1,36 @@ +package usecases + +import ( + "context" + + "github.com/Authula/authula/plugins/organizations/services" + "github.com/Authula/authula/plugins/organizations/types" +) + +type OrganizationUseCase struct { + service *services.OrganizationService +} + +func NewOrganizationUseCase(service *services.OrganizationService) *OrganizationUseCase { + return &OrganizationUseCase{service: service} +} + +func (u *OrganizationUseCase) CreateOrganization(ctx context.Context, actorUserID string, request types.CreateOrganizationRequest) (*types.Organization, error) { + return u.service.CreateOrganization(ctx, actorUserID, request) +} + +func (u *OrganizationUseCase) GetAllOrganizationsByUserID(ctx context.Context, actorUserID string) ([]types.Organization, error) { + return u.service.GetAllOrganizationsByUserID(ctx, actorUserID) +} + +func (u *OrganizationUseCase) GetOrganization(ctx context.Context, actorUserID string, organizationID string) (*types.Organization, error) { + return u.service.GetOrganization(ctx, actorUserID, organizationID) +} + +func (u *OrganizationUseCase) UpdateOrganization(ctx context.Context, actorUserID string, organizationID string, request types.UpdateOrganizationRequest) (*types.Organization, error) { + return u.service.UpdateOrganization(ctx, actorUserID, organizationID, request) +} + +func (u *OrganizationUseCase) DeleteOrganization(ctx context.Context, actorUserID string, organizationID string) error { + return u.service.DeleteOrganization(ctx, actorUserID, organizationID) +} From 0419efa9ec2dbc4293a669b720330a0472f76123 Mon Sep 17 00:00:00 2001 From: Tanvir Ahmed Date: Tue, 7 Apr 2026 15:50:36 +0000 Subject: [PATCH 2/2] chore: Integrated access control role validation --- .../organization_invitation_handlers_test.go | 55 ++++++++++++---- .../organization_member_handlers_test.go | 33 +++++++++- plugins/organizations/plugin.go | 11 +++- .../organization_invitation_service.go | 63 +++++++++++-------- .../organization_invitation_service_test.go | 46 +++++++++++--- .../services/organization_member_service.go | 40 ++++++++---- .../organization_member_service_test.go | 48 ++++++++++++-- 7 files changed, 230 insertions(+), 66 deletions(-) diff --git a/plugins/organizations/handlers/organization_invitation_handlers_test.go b/plugins/organizations/handlers/organization_invitation_handlers_test.go index 1516bf06..bab76456 100644 --- a/plugins/organizations/handlers/organization_invitation_handlers_test.go +++ b/plugins/organizations/handlers/organization_invitation_handlers_test.go @@ -25,6 +25,22 @@ import ( rootservices "github.com/Authula/authula/services" ) +type invitationHandlerAccessControlServiceStub struct { + roles map[string]bool + err error +} + +func (s *invitationHandlerAccessControlServiceStub) RoleExists(ctx context.Context, roleName string) (bool, error) { + if s.err != nil { + return false, s.err + } + return s.roles[roleName], nil +} + +func newInvitationHandlerAccessControlServiceStub() *invitationHandlerAccessControlServiceStub { + return &invitationHandlerAccessControlServiceStub{roles: map[string]bool{"member": true, "admin": true}} +} + type organizationInvitationTxRunner interface { RunInTx(ctx context.Context, opts *sql.TxOptions, fn func(context.Context, bun.Tx) error) error } @@ -33,6 +49,7 @@ func newOrganizationInvitationServiceForHandlerTest( txRunner organizationInvitationTxRunner, pluginConfig *orgtypes.OrganizationsPluginConfig, userService rootservices.UserService, + accessControlService rootservices.AccessControlService, orgRepo *orgtests.MockOrganizationRepository, invRepo *orgtests.MockOrganizationInvitationRepository, memberRepo *orgtests.MockOrganizationMemberRepository, @@ -45,6 +62,7 @@ func newOrganizationInvitationServiceForHandlerTest( pluginConfig, &internaltests.MockLogger{}, userService, + accessControlService, orgRepo, invRepo, memberRepo, @@ -56,12 +74,13 @@ func newOrganizationInvitationServiceForHandlerTest( } type organizationInvitationHandlerFixture struct { - pluginConfig *orgtypes.OrganizationsPluginConfig - orgRepo *orgtests.MockOrganizationRepository - invRepo *orgtests.MockOrganizationInvitationRepository - memberRepo *orgtests.MockOrganizationMemberRepository - userSvc *internaltests.MockUserService - txRunner *orgtests.MockOrganizationInvitationTxRunner + pluginConfig *orgtypes.OrganizationsPluginConfig + orgRepo *orgtests.MockOrganizationRepository + invRepo *orgtests.MockOrganizationInvitationRepository + memberRepo *orgtests.MockOrganizationMemberRepository + userSvc *internaltests.MockUserService + accessControl *invitationHandlerAccessControlServiceStub + txRunner *orgtests.MockOrganizationInvitationTxRunner } type organizationInvitationHandlerCase struct { @@ -82,16 +101,17 @@ func newOrganizationInvitationHandlerFixture() *organizationInvitationHandlerFix Enabled: true, InvitationExpiresIn: 7 * 24 * time.Hour, }, - orgRepo: &orgtests.MockOrganizationRepository{}, - invRepo: &orgtests.MockOrganizationInvitationRepository{}, - memberRepo: &orgtests.MockOrganizationMemberRepository{}, - userSvc: &internaltests.MockUserService{}, - txRunner: &orgtests.MockOrganizationInvitationTxRunner{}, + orgRepo: &orgtests.MockOrganizationRepository{}, + invRepo: &orgtests.MockOrganizationInvitationRepository{}, + memberRepo: &orgtests.MockOrganizationMemberRepository{}, + userSvc: &internaltests.MockUserService{}, + accessControl: newInvitationHandlerAccessControlServiceStub(), + txRunner: &orgtests.MockOrganizationInvitationTxRunner{}, } } func (f *organizationInvitationHandlerFixture) useCase() *usecases.OrganizationInvitationUseCase { - service := newOrganizationInvitationServiceForHandlerTest(f.txRunner, f.pluginConfig, f.userSvc, f.orgRepo, f.invRepo, f.memberRepo, nil, nil) + service := newOrganizationInvitationServiceForHandlerTest(f.txRunner, f.pluginConfig, f.userSvc, f.accessControl, f.orgRepo, f.invRepo, f.memberRepo, nil, nil) return usecases.NewOrganizationInvitationUseCase(service) } @@ -203,6 +223,17 @@ func TestCreateOrganizationInvitationHandler(t *testing.T) { expectedStatus: http.StatusUnprocessableEntity, expectedMessage: internalerrors.ErrUnprocessableEntity.Error(), }, + { + name: "invalid_role", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationInvitationRequest{Email: "user@example.com", Role: "ghost"}), + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + expectedStatus: http.StatusUnprocessableEntity, + expectedMessage: internalerrors.ErrUnprocessableEntity.Error(), + }, { name: "organization_not_found", userID: new("user-1"), diff --git a/plugins/organizations/handlers/organization_member_handlers_test.go b/plugins/organizations/handlers/organization_member_handlers_test.go index 8559a6d6..877b70b4 100644 --- a/plugins/organizations/handlers/organization_member_handlers_test.go +++ b/plugins/organizations/handlers/organization_member_handlers_test.go @@ -1,6 +1,7 @@ package handlers import ( + "context" "errors" "net/http" "net/http/httptest" @@ -10,6 +11,7 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + internalerrors "github.com/Authula/authula/internal/errors" internaltests "github.com/Authula/authula/internal/tests" "github.com/Authula/authula/models" orgservices "github.com/Authula/authula/plugins/organizations/services" @@ -18,8 +20,25 @@ import ( "github.com/Authula/authula/plugins/organizations/usecases" ) +type memberHandlerAccessControlServiceStub struct { + roles map[string]bool + err error +} + +func (s *memberHandlerAccessControlServiceStub) RoleExists(ctx context.Context, roleName string) (bool, error) { + if s.err != nil { + return false, s.err + } + return s.roles[roleName], nil +} + +func newMemberHandlerAccessControlServiceStub() *memberHandlerAccessControlServiceStub { + return &memberHandlerAccessControlServiceStub{roles: map[string]bool{"member": true, "admin": true}} +} + type organizationMemberHandlerFixture struct { userSvc *internaltests.MockUserService + accessControl *memberHandlerAccessControlServiceStub orgRepo *orgtests.MockOrganizationRepository orgMemberRepo *orgtests.MockOrganizationMemberRepository } @@ -27,13 +46,14 @@ type organizationMemberHandlerFixture struct { func newOrganizationMemberHandlerFixture() *organizationMemberHandlerFixture { return &organizationMemberHandlerFixture{ userSvc: &internaltests.MockUserService{}, + accessControl: newMemberHandlerAccessControlServiceStub(), orgRepo: &orgtests.MockOrganizationRepository{}, orgMemberRepo: &orgtests.MockOrganizationMemberRepository{}, } } func (f *organizationMemberHandlerFixture) useCase() *usecases.OrganizationMemberUseCase { - service := orgservices.NewOrganizationMemberService(f.userSvc, f.orgRepo, f.orgMemberRepo, nil) + service := orgservices.NewOrganizationMemberService(f.userSvc, f.accessControl, f.orgRepo, f.orgMemberRepo, nil) return usecases.NewOrganizationMemberUseCase(service) } @@ -133,6 +153,17 @@ func TestAddOrganizationMemberHandler(t *testing.T) { expectedStatus: http.StatusUnprocessableEntity, expectedMessage: "unprocessable entity", }, + { + name: "invalid role", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationMemberRequest{UserID: "user-2", Role: "ghost"}), + organizationID: "org-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + expectedStatus: http.StatusUnprocessableEntity, + expectedMessage: internalerrors.ErrUnprocessableEntity.Error(), + }, { name: "organization not found", userID: new("user-1"), diff --git a/plugins/organizations/plugin.go b/plugins/organizations/plugin.go index cf7cd425..ee54665b 100644 --- a/plugins/organizations/plugin.go +++ b/plugins/organizations/plugin.go @@ -69,6 +69,11 @@ func (p *OrganizationsPlugin) Init(ctx *models.PluginContext) error { return fmt.Errorf("user service not available in service registry") } + accessControlService, ok := ctx.ServiceRegistry.Get(models.ServiceAccessControl.String()).(rootservices.AccessControlService) + if !ok { + return fmt.Errorf("access control service not available in service registry") + } + var mailerService rootservices.MailerService if rawMailerService := ctx.ServiceRegistry.Get(models.ServiceMailer.String()); rawMailerService != nil { if typedMailerService, ok := rawMailerService.(rootservices.MailerService); ok { @@ -86,8 +91,8 @@ func (p *OrganizationsPlugin) Init(ctx *models.PluginContext) error { p.memberRepo = repositories.NewBunOrganizationMemberRepository(ctx.DB) p.teamRepo = repositories.NewBunOrganizationTeamRepository(ctx.DB) p.teamMemberRepo = repositories.NewBunOrganizationTeamMemberRepository(ctx.DB) - p.invitationService = services.NewOrganizationInvitationService(ctx.DB, p.globalConfig, &p.pluginConfig, p.logger, userService, p.organizationRepo, p.invitationRepo, p.memberRepo, mailerService, ctx.EventBus, p.databaseHooks, p.databaseHooks) - p.memberService = services.NewOrganizationMemberService(userService, p.organizationRepo, p.memberRepo, p.databaseHooks) + p.invitationService = services.NewOrganizationInvitationService(ctx.DB, p.globalConfig, &p.pluginConfig, p.logger, userService, accessControlService, p.organizationRepo, p.invitationRepo, p.memberRepo, mailerService, ctx.EventBus, p.databaseHooks, p.databaseHooks) + p.memberService = services.NewOrganizationMemberService(userService, accessControlService, p.organizationRepo, p.memberRepo, p.databaseHooks) p.teamService = services.NewOrganizationTeamService(p.organizationRepo, p.teamRepo, p.memberRepo, p.teamMemberRepo, p.databaseHooks, p.databaseHooks) p.organizationUseCase = usecases.NewOrganizationUseCase(p.organizationService) p.invitationUseCase = usecases.NewOrganizationInvitationUseCase(p.invitationService) @@ -103,7 +108,7 @@ func (p *OrganizationsPlugin) Migrations(provider string) []migrations.Migration } func (p *OrganizationsPlugin) DependsOn() []string { - return nil + return []string{models.PluginAccessControl.String()} } func (p *OrganizationsPlugin) Routes() []models.Route { diff --git a/plugins/organizations/services/organization_invitation_service.go b/plugins/organizations/services/organization_invitation_service.go index 465642c0..bfe121c4 100644 --- a/plugins/organizations/services/organization_invitation_service.go +++ b/plugins/organizations/services/organization_invitation_service.go @@ -35,18 +35,19 @@ type organizationInvitationTxRunner interface { } type OrganizationInvitationService struct { - txRunner organizationInvitationTxRunner - globalConfig *models.Config - pluginConfig *types.OrganizationsPluginConfig - logger models.Logger - mailerService rootservices.MailerService - eventBus models.EventBus - userService rootservices.UserService - organizationRepo repositories.OrganizationRepository - orgInvitationRepo repositories.OrganizationInvitationRepository - orgMemberRepo repositories.OrganizationMemberRepository - orgInvitationHooks OrganizationInvitationHookExecutor - orgMemberHooks OrganizationMemberHookExecutor + txRunner organizationInvitationTxRunner + globalConfig *models.Config + pluginConfig *types.OrganizationsPluginConfig + logger models.Logger + mailerService rootservices.MailerService + eventBus models.EventBus + userService rootservices.UserService + accessControlService rootservices.AccessControlService + organizationRepo repositories.OrganizationRepository + orgInvitationRepo repositories.OrganizationInvitationRepository + orgMemberRepo repositories.OrganizationMemberRepository + orgInvitationHooks OrganizationInvitationHookExecutor + orgMemberHooks OrganizationMemberHookExecutor } func NewOrganizationInvitationService( @@ -55,6 +56,7 @@ func NewOrganizationInvitationService( pluginConfig *types.OrganizationsPluginConfig, logger models.Logger, userService rootservices.UserService, + accessControlService rootservices.AccessControlService, organizationRepo repositories.OrganizationRepository, orgInvitationRepo repositories.OrganizationInvitationRepository, orgMemberRepo repositories.OrganizationMemberRepository, @@ -64,18 +66,19 @@ func NewOrganizationInvitationService( orgMemberHooks OrganizationMemberHookExecutor, ) *OrganizationInvitationService { return &OrganizationInvitationService{ - globalConfig: globalConfig, - pluginConfig: pluginConfig, - logger: logger, - mailerService: mailerService, - eventBus: eventBus, - userService: userService, - organizationRepo: organizationRepo, - orgInvitationRepo: orgInvitationRepo, - orgMemberRepo: orgMemberRepo, - orgInvitationHooks: orgInvitationHooks, - orgMemberHooks: orgMemberHooks, - txRunner: txRunner, + globalConfig: globalConfig, + pluginConfig: pluginConfig, + logger: logger, + mailerService: mailerService, + eventBus: eventBus, + userService: userService, + accessControlService: accessControlService, + organizationRepo: organizationRepo, + orgInvitationRepo: orgInvitationRepo, + orgMemberRepo: orgMemberRepo, + orgInvitationHooks: orgInvitationHooks, + orgMemberHooks: orgMemberHooks, + txRunner: txRunner, } } @@ -95,7 +98,7 @@ func (s *OrganizationInvitationService) CreateOrganizationInvitation(ctx context return nil, internalerrors.ErrForbidden } - email := strings.ToLower(strings.TrimSpace(request.Email)) + email := strings.ToLower(request.Email) if email == "" { return nil, internalerrors.ErrUnprocessableEntity } @@ -103,11 +106,19 @@ func (s *OrganizationInvitationService) CreateOrganizationInvitation(ctx context return nil, internalerrors.ErrUnprocessableEntity } - role := strings.TrimSpace(request.Role) + role := request.Role if role == "" { return nil, internalerrors.ErrUnprocessableEntity } + exists, err := s.accessControlService.RoleExists(ctx, role) + if err != nil { + return nil, err + } + if !exists { + return nil, fmt.Errorf("role %s doesn't exist", role) + } + if existing, err := s.orgInvitationRepo.GetByOrganizationIDAndEmail(ctx, organizationID, email); err != nil { return nil, err } else if existing != nil { diff --git a/plugins/organizations/services/organization_invitation_service_test.go b/plugins/organizations/services/organization_invitation_service_test.go index a786bd91..cfb8af71 100644 --- a/plugins/organizations/services/organization_invitation_service_test.go +++ b/plugins/organizations/services/organization_invitation_service_test.go @@ -21,6 +21,22 @@ import ( rootservices "github.com/Authula/authula/services" ) +type invitationAccessControlServiceStub struct { + roles map[string]bool + err error +} + +func (s *invitationAccessControlServiceStub) RoleExists(ctx context.Context, roleName string) (bool, error) { + if s.err != nil { + return false, s.err + } + return s.roles[roleName], nil +} + +func newInvitationAccessControlServiceStub() *invitationAccessControlServiceStub { + return &invitationAccessControlServiceStub{roles: map[string]bool{"member": true, "admin": true}} +} + type testInvitationLogger struct { mu sync.Mutex warnings []string @@ -44,6 +60,7 @@ func newTestOrganizationInvitationService( txRunner organizationInvitationTxRunner, pluginConfig *types.OrganizationsPluginConfig, userService rootservices.UserService, + accessControlService rootservices.AccessControlService, orgRepo repositories.OrganizationRepository, invRepo repositories.OrganizationInvitationRepository, memberRepo repositories.OrganizationMemberRepository, @@ -56,6 +73,7 @@ func newTestOrganizationInvitationService( pluginConfig, &testInvitationLogger{}, userService, + accessControlService, orgRepo, invRepo, memberRepo, @@ -136,6 +154,16 @@ func TestOrganizationInvitationService_CreateOrganizationInvitation(t *testing.T }, expectErr: internalerrors.ErrUnprocessableEntity, }, + { + name: "unprocessable entity invalid role", + actorUserID: "user-1", + organizationID: "org-1", + request: types.CreateOrganizationInvitationRequest{Email: "user@example.com", Role: "ghost"}, + setup: func(orgRepo *mockOrganizationRepository, invRepo *mockOrganizationInvitationRepository, hooks *testOrganizationInvitationHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + expectErr: internalerrors.ErrUnprocessableEntity, + }, { name: "forbidden for non owner", actorUserID: "user-1", @@ -216,7 +244,7 @@ func TestOrganizationInvitationService_CreateOrganizationInvitation(t *testing.T tt.setup(orgRepo, orgInvitationRepo, orgInvitationHooks) } - svc := newTestOrganizationInvitationService(&mockOrganizationInvitationTxRunner{}, pluginConfig, &internaltests.MockUserService{}, orgRepo, orgInvitationRepo, nil, orgInvitationHooks, nil) + svc := newTestOrganizationInvitationService(&mockOrganizationInvitationTxRunner{}, pluginConfig, &internaltests.MockUserService{}, newInvitationAccessControlServiceStub(), orgRepo, orgInvitationRepo, nil, orgInvitationHooks, nil) inv, err := svc.CreateOrganizationInvitation(context.Background(), tt.actorUserID, tt.organizationID, tt.request) if tt.expectErr != nil { require.Error(t, err) @@ -275,7 +303,7 @@ func TestOrganizationInvitationService_GetOrganizationInvitation(t *testing.T) { tt.setup(orgRepo, invRepo) } - svc := newTestOrganizationInvitationService(&mockOrganizationInvitationTxRunner{}, pluginConfig, &internaltests.MockUserService{}, orgRepo, invRepo, nil, nil, nil) + svc := newTestOrganizationInvitationService(&mockOrganizationInvitationTxRunner{}, pluginConfig, &internaltests.MockUserService{}, newInvitationAccessControlServiceStub(), orgRepo, invRepo, nil, nil, nil) invitation, err := svc.GetOrganizationInvitation(context.Background(), tt.actorUserID, tt.organizationID, tt.invitationID) if tt.expectErr != nil { require.Error(t, err) @@ -328,7 +356,7 @@ func TestOrganizationInvitationService_GetAllOrganizationInvitations(t *testing. tt.setup(orgRepo, invRepo) } - svc := newTestOrganizationInvitationService(&mockOrganizationInvitationTxRunner{}, pluginConfig, &internaltests.MockUserService{}, orgRepo, invRepo, nil, nil, nil) + svc := newTestOrganizationInvitationService(&mockOrganizationInvitationTxRunner{}, pluginConfig, &internaltests.MockUserService{}, newInvitationAccessControlServiceStub(), orgRepo, invRepo, nil, nil, nil) invitations, err := svc.GetAllOrganizationInvitations(context.Background(), tt.actorUserID, tt.organizationID) if tt.expectErr != nil { require.Error(t, err) @@ -394,7 +422,7 @@ func TestOrganizationInvitationService_RevokeOrganizationInvitation(t *testing.T tt.setup(orgRepo, invRepo, hooks) } - svc := newTestOrganizationInvitationService(&mockOrganizationInvitationTxRunner{}, pluginConfig, &internaltests.MockUserService{}, orgRepo, invRepo, nil, hooks, nil) + svc := newTestOrganizationInvitationService(&mockOrganizationInvitationTxRunner{}, pluginConfig, &internaltests.MockUserService{}, newInvitationAccessControlServiceStub(), orgRepo, invRepo, nil, hooks, nil) invitation, err := svc.RevokeOrganizationInvitation(context.Background(), tt.actorUserID, tt.organizationID, tt.invitationID) if tt.expectErr != nil { require.Error(t, err) @@ -481,7 +509,7 @@ func TestOrganizationInvitationService_AcceptPendingOrganizationInvitationsForEm } txRunner := &mockOrganizationInvitationTxRunner{} - svc := newTestOrganizationInvitationService(txRunner, pluginConfig, userSvc, orgRepo, invRepo, memberRepo, hooks, memberHooks) + svc := newTestOrganizationInvitationService(txRunner, pluginConfig, userSvc, newInvitationAccessControlServiceStub(), orgRepo, invRepo, memberRepo, hooks, memberHooks) accepted, err := svc.AcceptPendingOrganizationInvitationsForEmail(context.Background(), tt.userID, tt.email) if tt.expectErr != nil { require.Error(t, err) @@ -557,7 +585,7 @@ func TestOrganizationInvitationService_AcceptOrganizationInvitation(t *testing.T tt.setup(userSvc, invRepo, memberRepo, hooks, memberHooks) } - svc := newTestOrganizationInvitationService(&mockOrganizationInvitationTxRunner{}, pluginConfig, userSvc, orgRepo, invRepo, memberRepo, hooks, memberHooks) + svc := newTestOrganizationInvitationService(&mockOrganizationInvitationTxRunner{}, pluginConfig, userSvc, newInvitationAccessControlServiceStub(), orgRepo, invRepo, memberRepo, hooks, memberHooks) invitation, err := svc.AcceptOrganizationInvitation(context.Background(), tt.actorUserID, tt.organization, tt.invitationID) if tt.expectErr != nil { require.ErrorIs(t, err, tt.expectErr) @@ -585,7 +613,7 @@ func TestOrganizationInvitationService_RejectOrganizationInvitation(t *testing.T return invitation != nil && invitation.Status == types.OrganizationInvitationStatusRejected })).Return(&types.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Status: types.OrganizationInvitationStatusRejected}, nil).Once() - svc := newTestOrganizationInvitationService(&mockOrganizationInvitationTxRunner{}, pluginConfig, userSvc, &mockOrganizationRepository{}, invRepo, &mockOrganizationMemberRepository{}, &testOrganizationInvitationHooks{}, &mockOrganizationMemberHooks{}) + svc := newTestOrganizationInvitationService(&mockOrganizationInvitationTxRunner{}, pluginConfig, userSvc, newInvitationAccessControlServiceStub(), &mockOrganizationRepository{}, invRepo, &mockOrganizationMemberRepository{}, &testOrganizationInvitationHooks{}, &mockOrganizationMemberHooks{}) invitation, err := svc.RejectOrganizationInvitation(context.Background(), "user-2", "org-1", "inv-1") require.NoError(t, err) require.NotNil(t, invitation) @@ -607,7 +635,7 @@ func TestOrganizationInvitationService_GetOrganizationInvitation_ExpiresPendingI return invitation != nil && invitation.Status == types.OrganizationInvitationStatusExpired })).Return(&types.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Status: types.OrganizationInvitationStatusExpired}, nil).Once() - svc := newTestOrganizationInvitationService(&mockOrganizationInvitationTxRunner{}, pluginConfig, &internaltests.MockUserService{}, orgRepo, invRepo, nil, &testOrganizationInvitationHooks{}, nil) + svc := newTestOrganizationInvitationService(&mockOrganizationInvitationTxRunner{}, pluginConfig, &internaltests.MockUserService{}, newInvitationAccessControlServiceStub(), orgRepo, invRepo, nil, &testOrganizationInvitationHooks{}, nil) invitation, err := svc.GetOrganizationInvitation(context.Background(), "user-1", "org-1", "inv-1") require.NoError(t, err) require.NotNil(t, invitation) @@ -637,6 +665,7 @@ func TestOrganizationInvitationService_CreateOrganizationInvitation_SendsEmailAn &types.OrganizationsPluginConfig{Enabled: true, InvitationExpiresIn: 36 * time.Hour}, logger, &internaltests.MockUserService{}, + newInvitationAccessControlServiceStub(), orgRepo, invRepo, nil, @@ -687,6 +716,7 @@ func TestOrganizationInvitationService_CreateOrganizationInvitation_SkipsMissing &types.OrganizationsPluginConfig{Enabled: true, InvitationExpiresIn: 36 * time.Hour}, logger, &internaltests.MockUserService{}, + newInvitationAccessControlServiceStub(), orgRepo, invRepo, nil, diff --git a/plugins/organizations/services/organization_member_service.go b/plugins/organizations/services/organization_member_service.go index fcc42423..dc3aec7f 100644 --- a/plugins/organizations/services/organization_member_service.go +++ b/plugins/organizations/services/organization_member_service.go @@ -2,6 +2,7 @@ package services import ( "context" + "fmt" "strings" internalerrors "github.com/Authula/authula/internal/errors" @@ -21,14 +22,15 @@ type OrganizationMemberHookExecutor interface { } type OrganizationMemberService struct { - userService rootservices.UserService - orgRepo repositories.OrganizationRepository - orgMemberRepo repositories.OrganizationMemberRepository - hooks OrganizationMemberHookExecutor + userService rootservices.UserService + accessControlService rootservices.AccessControlService + orgRepo repositories.OrganizationRepository + orgMemberRepo repositories.OrganizationMemberRepository + hooks OrganizationMemberHookExecutor } -func NewOrganizationMemberService(userService rootservices.UserService, orgRepo repositories.OrganizationRepository, orgMemberRepo repositories.OrganizationMemberRepository, hooks OrganizationMemberHookExecutor) *OrganizationMemberService { - return &OrganizationMemberService{userService: userService, orgRepo: orgRepo, orgMemberRepo: orgMemberRepo, hooks: hooks} +func NewOrganizationMemberService(userService rootservices.UserService, accessControlService rootservices.AccessControlService, orgRepo repositories.OrganizationRepository, orgMemberRepo repositories.OrganizationMemberRepository, hooks OrganizationMemberHookExecutor) *OrganizationMemberService { + return &OrganizationMemberService{userService: userService, accessControlService: accessControlService, orgRepo: orgRepo, orgMemberRepo: orgMemberRepo, hooks: hooks} } func (s *OrganizationMemberService) AddMember(ctx context.Context, actorUserID string, organizationID string, request types.AddOrganizationMemberRequest) (*types.OrganizationMember, error) { @@ -36,16 +38,24 @@ func (s *OrganizationMemberService) AddMember(ctx context.Context, actorUserID s return nil, err } - userID := strings.TrimSpace(request.UserID) + userID := request.UserID if userID == "" { return nil, internalerrors.ErrUnprocessableEntity } - role := strings.TrimSpace(request.Role) + role := request.Role if role == "" { return nil, internalerrors.ErrUnprocessableEntity } + exists, err := s.accessControlService.RoleExists(ctx, role) + if err != nil { + return nil, err + } + if !exists { + return nil, fmt.Errorf("role %s doesn't exist", role) + } + user, err := s.userService.GetByID(ctx, userID) if err != nil { return nil, err @@ -120,19 +130,27 @@ func (s *OrganizationMemberService) UpdateMember(ctx context.Context, actorUserI return nil, err } - member, err := s.orgMemberRepo.GetByID(ctx, strings.TrimSpace(memberID)) + member, err := s.orgMemberRepo.GetByID(ctx, memberID) if err != nil { return nil, err } - if member == nil || member.OrganizationID != strings.TrimSpace(organizationID) { + if member == nil || member.OrganizationID != organizationID { return nil, internalerrors.ErrNotFound } - role := strings.TrimSpace(request.Role) + role := request.Role if role == "" { return nil, internalerrors.ErrBadRequest } + exists, err := s.accessControlService.RoleExists(ctx, role) + if err != nil { + return nil, err + } + if !exists { + return nil, fmt.Errorf("role %s doesn't exist", role) + } + member.Role = role if s.hooks != nil { diff --git a/plugins/organizations/services/organization_member_service_test.go b/plugins/organizations/services/organization_member_service_test.go index c1cc1a7f..04a71ce5 100644 --- a/plugins/organizations/services/organization_member_service_test.go +++ b/plugins/organizations/services/organization_member_service_test.go @@ -14,6 +14,22 @@ import ( "github.com/Authula/authula/plugins/organizations/types" ) +type memberAccessControlServiceStub struct { + roles map[string]bool + err error +} + +func (s *memberAccessControlServiceStub) RoleExists(ctx context.Context, roleName string) (bool, error) { + if s.err != nil { + return false, s.err + } + return s.roles[roleName], nil +} + +func newMemberAccessControlServiceStub() *memberAccessControlServiceStub { + return &memberAccessControlServiceStub{roles: map[string]bool{"member": true, "admin": true}} +} + func TestOrganizationMemberService_AddMember(t *testing.T) { t.Parallel() @@ -76,6 +92,16 @@ func TestOrganizationMemberService_AddMember(t *testing.T) { }, expectErr: internalerrors.ErrUnprocessableEntity, }, + { + name: "unprocessable entity invalid role", + actorUserID: "user-1", + organizationID: "org-1", + request: types.AddOrganizationMemberRequest{UserID: "user-2", Role: "ghost"}, + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository, userSvc *internaltests.MockUserService, hooks *mockOrganizationMemberHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + expectErr: internalerrors.ErrUnprocessableEntity, + }, { name: "user lookup error", actorUserID: "user-1", @@ -196,7 +222,7 @@ func TestOrganizationMemberService_AddMember(t *testing.T) { tt.setup(orgRepo, memberRepo, userSvc, hooks) } - svc := NewOrganizationMemberService(userSvc, orgRepo, memberRepo, hooks) + svc := NewOrganizationMemberService(userSvc, newMemberAccessControlServiceStub(), orgRepo, memberRepo, hooks) member, err := svc.AddMember(context.Background(), tt.actorUserID, tt.organizationID, tt.request) if tt.expectErr != nil { require.Error(t, err) @@ -288,7 +314,7 @@ func TestOrganizationMemberService_GetAllMembers(t *testing.T) { tt.setup(orgRepo, memberRepo) } - svc := NewOrganizationMemberService(&internaltests.MockUserService{}, orgRepo, memberRepo, nil) + svc := NewOrganizationMemberService(&internaltests.MockUserService{}, newMemberAccessControlServiceStub(), orgRepo, memberRepo, nil) members, err := svc.GetAllMembers(context.Background(), tt.actorUserID, tt.organizationID, 1, 10) if tt.expectErr != nil { require.Error(t, err) @@ -427,7 +453,7 @@ func TestOrganizationMemberService_GetMember(t *testing.T) { tt.setup(orgRepo, memberRepo) } - svc := NewOrganizationMemberService(&internaltests.MockUserService{}, orgRepo, memberRepo, nil) + svc := NewOrganizationMemberService(&internaltests.MockUserService{}, newMemberAccessControlServiceStub(), orgRepo, memberRepo, nil) member, err := svc.GetMember(context.Background(), tt.actorUserID, tt.organizationID, tt.memberID) if tt.expectErr != nil { require.Error(t, err) @@ -545,6 +571,18 @@ func TestOrganizationMemberService_UpdateMember(t *testing.T) { }, expectErr: internalerrors.ErrBadRequest, }, + { + name: "unprocessable entity invalid role", + actorUserID: "user-1", + organizationID: "org-1", + memberID: "mem-1", + request: types.UpdateOrganizationMemberRequest{Role: "ghost"}, + setup: func(orgRepo *mockOrganizationRepository, memberRepo *mockOrganizationMemberRepository, hooks *mockOrganizationMemberHooks) { + orgRepo.On("GetByID", mock.Anything, "org-1").Return(&types.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + memberRepo.On("GetByID", mock.Anything, "mem-1").Return(&types.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}, nil).Once() + }, + expectErr: internalerrors.ErrUnprocessableEntity, + }, { name: "before update hook error", actorUserID: "user-1", @@ -619,7 +657,7 @@ func TestOrganizationMemberService_UpdateMember(t *testing.T) { tt.setup(orgRepo, memberRepo, hooks) } - svc := NewOrganizationMemberService(&internaltests.MockUserService{}, orgRepo, memberRepo, hooks) + svc := NewOrganizationMemberService(&internaltests.MockUserService{}, newMemberAccessControlServiceStub(), orgRepo, memberRepo, hooks) member, err := svc.UpdateMember(context.Background(), tt.actorUserID, tt.organizationID, tt.memberID, tt.request) if tt.expectErr != nil { require.Error(t, err) @@ -779,7 +817,7 @@ func TestOrganizationMemberService_RemoveMember(t *testing.T) { tt.setup(orgRepo, memberRepo, hooks) } - svc := NewOrganizationMemberService(&internaltests.MockUserService{}, orgRepo, memberRepo, hooks) + svc := NewOrganizationMemberService(&internaltests.MockUserService{}, newMemberAccessControlServiceStub(), orgRepo, memberRepo, hooks) err := svc.RemoveMember(context.Background(), tt.actorUserID, tt.organizationID, tt.memberID) if tt.expectErr != nil { require.Error(t, err)