diff --git a/.agents/skills/plugin-architecture/SKILL.md b/.agents/skills/plugin-architecture/SKILL.md index fb441f66..6397366f 100644 --- a/.agents/skills/plugin-architecture/SKILL.md +++ b/.agents/skills/plugin-architecture/SKILL.md @@ -26,6 +26,7 @@ description: Build pluggable authentication features using the plugin system wit ## Pattern Plugin lifecycle: + - Define metadata (ID, name, version) - Implement Init to retrieve services, create repositories/services - Implement optional Routes, Migrations, Middleware @@ -34,8 +35,9 @@ Plugin lifecycle: ## Example -See [examples/todo_plugin.go](examples/todo_plugin.go) for: -- TodosPlugin metadata and Init pattern +See [plugins/email-password/plugin.go](../../../plugins/email-password/plugin.go) for: + +- EmailPasswordPlugin metadata and Init pattern - Service retrieval and registration - Routes definition and handler wiring - Lifecycle management (Close) @@ -52,7 +54,6 @@ See [examples/todo_plugin.go](examples/todo_plugin.go) for: ## References -- [models/plugin.go](../../../models/plugin.go) - Plugin interface definitions - [plugins/email-password/plugin.go](../../../plugins/email-password/plugin.go) - Email-password plugin - [plugins/jwt/plugin.go](../../../plugins/jwt/plugin.go) - JWT plugin - [internal/bootstrap/plugin_factory.go](../../../internal/bootstrap/plugin_factory.go) - Plugin factory diff --git a/.agents/skills/plugin-architecture/examples/todo_plugin.go b/.agents/skills/plugin-architecture/examples/todo_plugin.go deleted file mode 100644 index c998aa77..00000000 --- a/.agents/skills/plugin-architecture/examples/todo_plugin.go +++ /dev/null @@ -1,65 +0,0 @@ -package examples - -import ( - "net/http" - - "github.com/uptightgo/bun" -) - -// Type aliases for plugin interfaces (simplified) -type PluginContext struct { - Database bun.IDB -} - -type Route struct { - Method string - Path string - Handler func(w http.ResponseWriter, r *http.Request) error -} - -// TodosPlugin -type TodosPlugin struct { - todoService TodoService - createTodoUseCase *CreateTodoUseCase - markCompleteUseCase *MarkTodoCompleteUseCase -} - -func (p *TodosPlugin) Init(ctx *PluginContext) error { - // Step 1: Create repository with database - todoRepo := NewBunTodoRepository(ctx.Database) - - // Step 2: Create service - p.todoService = NewTodoService(todoRepo) - - // Step 3: Create use cases - p.createTodoUseCase = NewCreateTodoUseCase(p.todoService) - p.markCompleteUseCase = NewMarkTodoCompleteUseCase(p.todoService) - - return nil -} - -func (p *TodosPlugin) Routes() []Route { - createHandler := NewCreateTodoHandler(p.createTodoUseCase) - completeHandler := NewMarkTodoCompleteHandler(p.markCompleteUseCase) - - return []Route{ - { - Method: http.MethodPost, - Path: "/todos", - Handler: func(w http.ResponseWriter, r *http.Request) error { - return createHandler.Handle(w, r) - }, - }, - { - Method: http.MethodPut, - Path: "/todos/{id}/complete", - Handler: func(w http.ResponseWriter, r *http.Request) error { - return completeHandler.Handle(w, r) - }, - }, - } -} - -func (p *TodosPlugin) Close() error { - return nil -} diff --git a/config.example.toml b/config.example.toml index 288530a8..40a8f165 100644 --- a/config.example.toml +++ b/config.example.toml @@ -151,47 +151,3 @@ buffer_size = 100 # In standalone mode, all plugin-to-route associations are defined here via [[route_mappings]] tables. # This enables full plugin routing control without code changes. # Plugin IDs follow the format "{plugin_name}.{operation}" (e.g., "session.auth", "csrf.protect") - -# Example routes: -# [[route_mappings]] -# path = "/me" -# method = "GET" -# plugins = ["session.auth"] (SSR) or ["bearer.auth"] (SPA/mobile) - -# [[route_mappings]] -# path = "/sign-in" -# method = "POST" -# plugins = ["session.auth.optional"] - -# [[route_mappings]] -# path = "/sign-up" -# method = "POST" - -# [[route_mappings]] -# path = "/change-password" -# method = "POST" -# plugins = ["session.auth", "csrf.protect"] - -# [[route_mappings]] -# path = "/sign-out" -# method = "POST" -# plugins = ["session.auth", "csrf.protect"] - -# Access control (opt-in per route) -# [[route_mappings]] -# path = "/admin/users" -# method = "GET" -# plugins = ["session.auth", "access_control.enforce"] -# permissions = ["users.read"] - -# If using TOTP plugin, keep /totp/verify and /totp/verify-backup-code accessible -# to the pending-token flow (do not require an existing session cookie). -# [[route_mappings]] -# path = "/totp/verify" -# method = "POST" -# plugins = ["session.auth.optional"] - -# [[route_mappings]] -# path = "/totp/verify-backup-code" -# method = "POST" -# plugins = ["session.auth.optional"] diff --git a/internal/bootstrap/plugin_factory.go b/internal/bootstrap/plugin_factory.go index 6d4b39da..d1b2cd28 100644 --- a/internal/bootstrap/plugin_factory.go +++ b/internal/bootstrap/plugin_factory.go @@ -23,6 +23,8 @@ import ( magiclinkplugintypes "github.com/Authula/authula/plugins/magic-link/types" oauth2plugin "github.com/Authula/authula/plugins/oauth2" oauth2plugintypes "github.com/Authula/authula/plugins/oauth2/types" + organizationsplugin "github.com/Authula/authula/plugins/organizations" + organizationsplugintypes "github.com/Authula/authula/plugins/organizations/types" ratelimitplugin "github.com/Authula/authula/plugins/rate-limit" secondarystorageplugin "github.com/Authula/authula/plugins/secondary-storage" sessionplugin "github.com/Authula/authula/plugins/session" @@ -248,6 +250,22 @@ var pluginFactories = []PluginFactory{ return accesscontrolplugin.New(typedConfig.(accesscontrolplugintypes.AccessControlPluginConfig)) }, }, + { + ID: models.PluginOrganizations.String(), + RequiredByDefault: false, + ConfigParser: func(rawConfig any) (any, error) { + config := organizationsplugintypes.OrganizationsPluginConfig{} + if rawConfig != nil { + if err := util.ParsePluginConfig(rawConfig, &config); err != nil { + return nil, fmt.Errorf("failed to parse organizations plugin config: %w", err) + } + } + return config, nil + }, + Constructor: func(typedConfig any) models.Plugin { + return organizationsplugin.New(typedConfig.(organizationsplugintypes.OrganizationsPluginConfig)) + }, + }, { ID: models.PluginTOTP.String(), RequiredByDefault: false, diff --git a/internal/errors/errors.go b/internal/errors/errors.go new file mode 100644 index 00000000..0b19cf39 --- /dev/null +++ b/internal/errors/errors.go @@ -0,0 +1,35 @@ +package errors + +import ( + "errors" + "net/http" + + "github.com/Authula/authula/models" +) + +var ( + ErrBadRequest = errors.New("bad request") + ErrUnauthorized = errors.New("unauthorized") + ErrForbidden = errors.New("forbidden") + ErrNotFound = errors.New("not found") + ErrConflict = errors.New("conflict") + ErrUnprocessableEntity = errors.New("unprocessable entity") +) + +func HandleError(err error, reqCtx *models.RequestContext) { + status := http.StatusBadRequest + switch err { + case ErrUnauthorized: + status = http.StatusUnauthorized + case ErrForbidden: + status = http.StatusForbidden + case ErrNotFound: + status = http.StatusNotFound + case ErrConflict: + status = http.StatusConflict + case ErrUnprocessableEntity: + status = http.StatusUnprocessableEntity + } + reqCtx.SetJSONResponse(status, map[string]any{"message": err.Error()}) + reqCtx.Handled = true +} diff --git a/models/plugin.go b/models/plugin.go index a1b2128e..a35be3e9 100644 --- a/models/plugin.go +++ b/models/plugin.go @@ -26,6 +26,7 @@ const ( PluginRateLimit PluginID = "ratelimit" PluginMagicLink PluginID = "magic_link" PluginTOTP PluginID = "totp" + PluginOrganizations PluginID = "organizations" ) func (id PluginID) String() string { diff --git a/plugins/admin/services/state_service_test.go b/plugins/admin/services/state_service_test.go index 5eeb8242..4ff4a61b 100644 --- a/plugins/admin/services/state_service_test.go +++ b/plugins/admin/services/state_service_test.go @@ -97,7 +97,7 @@ func TestStateService_UpsertUserState(t *testing.T) { if tc.userExists { // always prepare an Upsert expectation when user exists if tc.hasRepoErr { - usr.On("Upsert", mock.Anything, mock.Anything).Return(errors.New("boom")).Once() + usr.On("Upsert", mock.Anything, mock.Anything).Return(errors.New("some error")).Once() } else if tc.expectCall != nil { usr.On("Upsert", mock.Anything, mock.MatchedBy(tc.expectCall)).Return(nil).Once() } else { diff --git a/plugins/magic-link/handlers/verify_handler_test.go b/plugins/magic-link/handlers/verify_handler_test.go index 11db34fe..b5e7e7ec 100644 --- a/plugins/magic-link/handlers/verify_handler_test.go +++ b/plugins/magic-link/handlers/verify_handler_test.go @@ -119,14 +119,14 @@ func TestVerifyHandler_UntrustedCallbackURL(t *testing.T) { func TestVerifyHandler_UseCaseError(t *testing.T) { useCase := &mockVerifyUseCase{} - useCase.On("Verify", mock.Anything, "abc", mock.Anything, mock.Anything).Return("", errors.New("boom")).Once() + useCase.On("Verify", mock.Anything, "abc", mock.Anything, mock.Anything).Return("", errors.New("some error")).Once() handler := &VerifyHandler{UseCase: useCase} req, reqCtx, w := newCallbackRequest(t, "/magic-link/verify?token=abc") handler.Handler()(w, req) - assertErrorResponse(t, reqCtx, http.StatusBadRequest, "boom") + assertErrorResponse(t, reqCtx, http.StatusBadRequest, "some error") useCase.AssertExpectations(t) } diff --git a/plugins/organizations/api.go b/plugins/organizations/api.go new file mode 100644 index 00000000..daed0e48 --- /dev/null +++ b/plugins/organizations/api.go @@ -0,0 +1,115 @@ +package organizations + +import ( + "context" + + "github.com/Authula/authula/plugins/organizations/types" + "github.com/Authula/authula/plugins/organizations/usecases" +) + +type API struct { + organizationUseCase *usecases.OrganizationUseCase + invitationUseCase *usecases.OrganizationInvitationUseCase + memberUseCase *usecases.OrganizationMemberUseCase + teamUseCase *usecases.OrganizationTeamUseCase +} + +func BuildAPI(organizationUseCase *usecases.OrganizationUseCase, invitationUseCase *usecases.OrganizationInvitationUseCase, memberUseCase *usecases.OrganizationMemberUseCase, teamUseCase *usecases.OrganizationTeamUseCase) *API { + return &API{organizationUseCase: organizationUseCase, invitationUseCase: invitationUseCase, memberUseCase: memberUseCase, teamUseCase: teamUseCase} +} + +func (a *API) CreateOrganization(ctx context.Context, actorUserID string, request types.CreateOrganizationRequest) (*types.Organization, error) { + return a.organizationUseCase.CreateOrganization(ctx, actorUserID, request) +} + +func (a *API) GetAllOrganizationsByUserID(ctx context.Context, actorUserID string) ([]types.Organization, error) { + return a.organizationUseCase.GetAllOrganizationsByUserID(ctx, actorUserID) +} + +func (a *API) GetOrganization(ctx context.Context, actorUserID string, organizationID string) (*types.Organization, error) { + return a.organizationUseCase.GetOrganization(ctx, actorUserID, organizationID) +} + +func (a *API) UpdateOrganization(ctx context.Context, actorUserID string, organizationID string, request types.UpdateOrganizationRequest) (*types.Organization, error) { + return a.organizationUseCase.UpdateOrganization(ctx, actorUserID, organizationID, request) +} + +func (a *API) DeleteOrganization(ctx context.Context, actorUserID string, organizationID string) error { + return a.organizationUseCase.DeleteOrganization(ctx, actorUserID, organizationID) +} + +func (a *API) CreateInvitation(ctx context.Context, actorUserID string, organizationID string, request types.CreateOrganizationInvitationRequest) (*types.OrganizationInvitation, error) { + return a.invitationUseCase.CreateInvitation(ctx, actorUserID, organizationID, request) +} + +func (a *API) GetInvitation(ctx context.Context, actorUserID string, organizationID string, invitationID string) (*types.OrganizationInvitation, error) { + return a.invitationUseCase.GetInvitation(ctx, actorUserID, organizationID, invitationID) +} + +func (a *API) GetAllInvitations(ctx context.Context, actorUserID string, organizationID string) ([]types.OrganizationInvitation, error) { + return a.invitationUseCase.GetAllInvitations(ctx, actorUserID, organizationID) +} + +func (a *API) RevokeInvitation(ctx context.Context, actorUserID string, organizationID string, invitationID string) (*types.OrganizationInvitation, error) { + return a.invitationUseCase.RevokeInvitation(ctx, actorUserID, organizationID, invitationID) +} + +func (a *API) AcceptInvitation(ctx context.Context, actorUserID string, organizationID string, invitationID string) (*types.OrganizationInvitation, error) { + return a.invitationUseCase.AcceptInvitation(ctx, actorUserID, organizationID, invitationID) +} + +func (a *API) RejectInvitation(ctx context.Context, actorUserID string, organizationID string, invitationID string) (*types.OrganizationInvitation, error) { + return a.invitationUseCase.RejectInvitation(ctx, actorUserID, organizationID, invitationID) +} + +func (a *API) AddMember(ctx context.Context, actorUserID string, organizationID string, request types.AddOrganizationMemberRequest) (*types.OrganizationMember, error) { + return a.memberUseCase.AddMember(ctx, actorUserID, organizationID, request) +} + +func (a *API) GetAllMembers(ctx context.Context, actorUserID string, organizationID string, page int, limit int) ([]types.OrganizationMember, error) { + return a.memberUseCase.GetAllMembers(ctx, actorUserID, organizationID, page, limit) +} + +func (a *API) GetMember(ctx context.Context, actorUserID string, organizationID string, memberID string) (*types.OrganizationMember, error) { + return a.memberUseCase.GetMember(ctx, actorUserID, organizationID, memberID) +} + +func (a *API) UpdateMember(ctx context.Context, actorUserID string, organizationID string, memberID string, request types.UpdateOrganizationMemberRequest) (*types.OrganizationMember, error) { + return a.memberUseCase.UpdateMember(ctx, actorUserID, organizationID, memberID, request) +} + +func (a *API) RemoveMember(ctx context.Context, actorUserID string, organizationID string, memberID string) error { + return a.memberUseCase.RemoveMember(ctx, actorUserID, organizationID, memberID) +} + +func (a *API) CreateTeam(ctx context.Context, actorUserID string, organizationID string, request types.CreateOrganizationTeamRequest) (*types.OrganizationTeam, error) { + return a.teamUseCase.CreateTeam(ctx, actorUserID, organizationID, request) +} + +func (a *API) GetAllTeams(ctx context.Context, actorUserID string, organizationID string) ([]types.OrganizationTeam, error) { + return a.teamUseCase.GetAllTeams(ctx, actorUserID, organizationID) +} + +func (a *API) UpdateTeam(ctx context.Context, actorUserID string, organizationID string, teamID string, request types.UpdateOrganizationTeamRequest) (*types.OrganizationTeam, error) { + return a.teamUseCase.UpdateTeam(ctx, actorUserID, organizationID, teamID, request) +} + +func (a *API) DeleteTeam(ctx context.Context, actorUserID string, organizationID string, teamID string) error { + return a.teamUseCase.DeleteTeam(ctx, actorUserID, organizationID, teamID) +} + +func (a *API) AddTeamMember(ctx context.Context, actorUserID string, organizationID string, teamID string, request types.AddOrganizationTeamMemberRequest) (*types.OrganizationTeamMember, error) { + return a.teamUseCase.AddTeamMember(ctx, actorUserID, organizationID, teamID, request) +} + +func (a *API) GetTeamMember(ctx context.Context, actorUserID string, organizationID string, teamID string, memberID string) (*types.OrganizationTeamMember, error) { + return a.teamUseCase.GetTeamMember(ctx, actorUserID, organizationID, teamID, memberID) +} + +func (a *API) GetAllTeamMembers(ctx context.Context, actorUserID string, organizationID string, teamID string, page int, limit int) ([]types.OrganizationTeamMember, error) { + return a.teamUseCase.GetAllTeamMembers(ctx, actorUserID, organizationID, teamID, page, limit) +} + +func (a *API) RemoveTeamMember(ctx context.Context, actorUserID string, organizationID string, teamID string, memberID string) error { + return a.teamUseCase.RemoveTeamMember(ctx, actorUserID, organizationID, teamID, memberID) +} diff --git a/plugins/organizations/constants/constants.go b/plugins/organizations/constants/constants.go new file mode 100644 index 00000000..1530dde7 --- /dev/null +++ b/plugins/organizations/constants/constants.go @@ -0,0 +1,5 @@ +package constants + +const ( + EventOrganizationsInvitationCreated = "organizations.invitation.created" +) diff --git a/plugins/organizations/events/events.go b/plugins/organizations/events/events.go new file mode 100644 index 00000000..1a2d562e --- /dev/null +++ b/plugins/organizations/events/events.go @@ -0,0 +1,15 @@ +package events + +import "time" + +type OrganizationInvitationCreatedEvent struct { + ID string `json:"id"` + InvitationID string `json:"invitation_id"` + OrganizationID string `json:"organization_id"` + OrganizationName string `json:"organization_name"` + InviteeEmail string `json:"invitee_email"` + InviterID string `json:"inviter_id"` + Role string `json:"role"` + ExpiresAt time.Time `json:"expires_at"` + RedirectURL string `json:"redirect_url,omitempty"` +} diff --git a/plugins/organizations/handlers/organization_handlers.go b/plugins/organizations/handlers/organization_handlers.go new file mode 100644 index 00000000..333cb1ae --- /dev/null +++ b/plugins/organizations/handlers/organization_handlers.go @@ -0,0 +1,157 @@ +package handlers + +import ( + "net/http" + + internalerrors "github.com/Authula/authula/internal/errors" + "github.com/Authula/authula/internal/util" + "github.com/Authula/authula/models" + "github.com/Authula/authula/plugins/organizations/types" + "github.com/Authula/authula/plugins/organizations/usecases" +) + +type CreateOrganizationHandler struct { + UseCase *usecases.OrganizationUseCase +} + +func (h *CreateOrganizationHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + var payload types.CreateOrganizationRequest + if err := util.ParseJSON(r, &payload); err != nil { + reqCtx.SetJSONResponse(http.StatusBadRequest, map[string]any{"message": "invalid request body"}) + reqCtx.Handled = true + return + } + + organization, err := h.UseCase.CreateOrganization(ctx, userID, payload) + if err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusCreated, organization) + } +} + +type GetAllOrganizationsByUserIDHandler struct { + UseCase *usecases.OrganizationUseCase +} + +func (h *GetAllOrganizationsByUserIDHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizations, err := h.UseCase.GetAllOrganizationsByUserID(ctx, userID) + if err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, organizations) + } +} + +type GetOrganizationHandler struct { + UseCase *usecases.OrganizationUseCase +} + +func (h *GetOrganizationHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + organization, err := h.UseCase.GetOrganization(ctx, userID, organizationID) + if err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, organization) + } +} + +type UpdateOrganizationHandler struct { + UseCase *usecases.OrganizationUseCase +} + +func (h *UpdateOrganizationHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + var payload types.UpdateOrganizationRequest + if err := util.ParseJSON(r, &payload); err != nil { + reqCtx.SetJSONResponse(http.StatusBadRequest, map[string]any{"message": "invalid request body"}) + reqCtx.Handled = true + return + } + + organization, err := h.UseCase.UpdateOrganization(ctx, userID, organizationID, payload) + if err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, organization) + } +} + +type DeleteOrganizationHandler struct { + UseCase *usecases.OrganizationUseCase +} + +func (h *DeleteOrganizationHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + if err := h.UseCase.DeleteOrganization(ctx, userID, organizationID); err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusNoContent, nil) + } +} diff --git a/plugins/organizations/handlers/organization_handlers_test.go b/plugins/organizations/handlers/organization_handlers_test.go new file mode 100644 index 00000000..3782096c --- /dev/null +++ b/plugins/organizations/handlers/organization_handlers_test.go @@ -0,0 +1,458 @@ +package handlers + +import ( + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + internaltests "github.com/Authula/authula/internal/tests" + "github.com/Authula/authula/models" + orgservices "github.com/Authula/authula/plugins/organizations/services" + orgtests "github.com/Authula/authula/plugins/organizations/tests" + orgtypes "github.com/Authula/authula/plugins/organizations/types" + "github.com/Authula/authula/plugins/organizations/usecases" +) + +type organizationHandlerFixture struct { + repo *orgtests.MockOrganizationRepository +} + +type organizationHandlerCase struct { + name string + userID *string + body []byte + organizationID string + prepare func(*organizationHandlerFixture) + expectedStatus int + expectedMessage string + checkResponse func(t *testing.T, reqCtx *models.RequestContext) +} + +func newOrganizationHandlerFixture() *organizationHandlerFixture { + return &organizationHandlerFixture{repo: &orgtests.MockOrganizationRepository{}} +} + +func (f *organizationHandlerFixture) useCase() *usecases.OrganizationUseCase { + service := orgservices.NewOrganizationService(f.repo, nil) + return usecases.NewOrganizationUseCase(service) +} + +func (f *organizationHandlerFixture) newRequest(t *testing.T, method, path string, body []byte, userID *string, organizationID string) (*http.Request, *httptest.ResponseRecorder, *models.RequestContext) { + req, w, reqCtx := internaltests.NewHandlerRequest(t, method, path, body, userID) + if organizationID != "" { + req.SetPathValue("organization_id", organizationID) + } + return req, w, reqCtx +} + +func TestCreateOrganizationHandler(t *testing.T) { + t.Parallel() + + tests := []organizationHandlerCase{ + { + name: "missing_user", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationRequest{Name: "Acme Inc"}), + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "invalid_json", + userID: new("user-1"), + body: []byte("{invalid"), + expectedStatus: http.StatusBadRequest, + expectedMessage: "invalid request body", + }, + { + name: "unprocessable_entity", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationRequest{Name: " "}), + expectedStatus: http.StatusUnprocessableEntity, + expectedMessage: "unprocessable entity", + }, + { + name: "repo_error", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationRequest{ + Name: "Acme Inc", + }), + prepare: func(f *organizationHandlerFixture) { + f.repo.On("Create", mock.Anything, mock.MatchedBy(func(org *orgtypes.Organization) bool { + return org != nil && org.OwnerID == "user-1" && org.Name == "Acme Inc" && org.Slug == "acme-inc" + })).Return((*orgtypes.Organization)(nil), errors.New("create failed")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "create failed", + }, + { + name: "success", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationRequest{ + Name: "Acme Inc", + }), + prepare: func(f *organizationHandlerFixture) { + f.repo.On("Create", mock.Anything, mock.MatchedBy(func(org *orgtypes.Organization) bool { + return org != nil && org.OwnerID == "user-1" && org.Name == "Acme Inc" && org.Slug == "acme-inc" + })).Return(&orgtypes.Organization{ + ID: "org-1", + OwnerID: "user-1", + Name: "Acme Inc", + Slug: "acme-inc", + }, nil).Once() + }, + expectedStatus: http.StatusCreated, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + org := internaltests.DecodeResponseJSON[orgtypes.Organization](t, reqCtx) + assert.Equal(t, "org-1", org.ID) + assert.Equal(t, "user-1", org.OwnerID) + assert.Equal(t, "Acme Inc", org.Name) + assert.Equal(t, "acme-inc", org.Slug) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + fixture := newOrganizationHandlerFixture() + if tt.prepare != nil { + tt.prepare(fixture) + } + + handler := &CreateOrganizationHandler{UseCase: fixture.useCase()} + req, w, reqCtx := fixture.newRequest(t, http.MethodPost, "/organizations", tt.body, tt.userID, "") + handler.Handler().ServeHTTP(w, req) + + assert.Equal(t, tt.expectedStatus, reqCtx.ResponseStatus) + if tt.expectedMessage != "" { + internaltests.AssertErrorMessage(t, reqCtx, tt.expectedStatus, tt.expectedMessage) + } + if tt.checkResponse != nil { + tt.checkResponse(t, reqCtx) + } + fixture.repo.AssertExpectations(t) + }) + } +} + +func TestGetAllOrganizationsByUserIDHandler(t *testing.T) { + t.Parallel() + + tests := []organizationHandlerCase{ + { + name: "missing_user", + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "repo_error", + userID: new("user-1"), + prepare: func(f *organizationHandlerFixture) { + f.repo.On("GetAllByOwnerID", mock.Anything, "user-1").Return(([]orgtypes.Organization)(nil), errors.New("some error")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "some error", + }, + { + name: "success", + userID: new("user-1"), + prepare: func(f *organizationHandlerFixture) { + f.repo.On("GetAllByOwnerID", mock.Anything, "user-1").Return([]orgtypes.Organization{{ID: "org-1", OwnerID: "user-1", Name: "Acme Inc", Slug: "acme-inc"}}, nil).Once() + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + organizations := internaltests.DecodeResponseJSON[[]orgtypes.Organization](t, reqCtx) + require.Len(t, organizations, 1) + assert.Equal(t, "org-1", organizations[0].ID) + assert.Equal(t, "user-1", organizations[0].OwnerID) + assert.Equal(t, "Acme Inc", organizations[0].Name) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + fixture := newOrganizationHandlerFixture() + if tt.prepare != nil { + tt.prepare(fixture) + } + + handler := &GetAllOrganizationsByUserIDHandler{UseCase: fixture.useCase()} + req, w, reqCtx := fixture.newRequest(t, http.MethodGet, "/organizations", nil, tt.userID, "") + handler.Handler().ServeHTTP(w, req) + + assert.Equal(t, tt.expectedStatus, reqCtx.ResponseStatus) + if tt.expectedMessage != "" { + internaltests.AssertErrorMessage(t, reqCtx, tt.expectedStatus, tt.expectedMessage) + } + if tt.checkResponse != nil { + tt.checkResponse(t, reqCtx) + } + fixture.repo.AssertExpectations(t) + }) + } +} + +func TestGetOrganizationHandler(t *testing.T) { + t.Parallel() + + tests := []organizationHandlerCase{ + { + name: "missing_user", + organizationID: "org-1", + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "not_found", + userID: new("user-1"), + organizationID: "org-1", + prepare: func(f *organizationHandlerFixture) { + f.repo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "forbidden", + userID: new("user-1"), + organizationID: "org-1", + prepare: func(f *organizationHandlerFixture) { + f.repo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "owner-2", Name: "Acme Inc", Slug: "acme-inc"}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: "forbidden", + }, + { + name: "success", + userID: new("user-1"), + organizationID: "org-1", + prepare: func(f *organizationHandlerFixture) { + f.repo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1", Name: "Acme Inc", Slug: "acme-inc"}, nil).Once() + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + org := internaltests.DecodeResponseJSON[orgtypes.Organization](t, reqCtx) + assert.Equal(t, "org-1", org.ID) + assert.Equal(t, "user-1", org.OwnerID) + assert.Equal(t, "Acme Inc", org.Name) + assert.Equal(t, "acme-inc", org.Slug) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + fixture := newOrganizationHandlerFixture() + if tt.prepare != nil { + tt.prepare(fixture) + } + + handler := &GetOrganizationHandler{UseCase: fixture.useCase()} + req, w, reqCtx := fixture.newRequest(t, http.MethodGet, "/organizations/test-id", nil, tt.userID, tt.organizationID) + handler.Handler().ServeHTTP(w, req) + + assert.Equal(t, tt.expectedStatus, reqCtx.ResponseStatus) + if tt.expectedMessage != "" { + internaltests.AssertErrorMessage(t, reqCtx, tt.expectedStatus, tt.expectedMessage) + } + if tt.checkResponse != nil { + tt.checkResponse(t, reqCtx) + } + fixture.repo.AssertExpectations(t) + }) + } +} + +func TestUpdateOrganizationHandler(t *testing.T) { + t.Parallel() + + tests := []organizationHandlerCase{ + { + name: "missing_user", + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationRequest{Name: "Acme Platform"}), + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "invalid_json", + userID: new("user-1"), + organizationID: "org-1", + body: []byte("{invalid"), + expectedStatus: http.StatusBadRequest, + expectedMessage: "invalid request body", + }, + { + name: "unprocessable_entity", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationRequest{Name: " "}), + prepare: func(f *organizationHandlerFixture) { + f.repo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1", Name: "Acme Inc", Slug: "acme-inc"}, nil).Once() + }, + expectedStatus: http.StatusUnprocessableEntity, + expectedMessage: "unprocessable entity", + }, + { + name: "not_found", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationRequest{Name: "Acme Platform"}), + prepare: func(f *organizationHandlerFixture) { + f.repo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "forbidden", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationRequest{Name: "Acme Platform"}), + prepare: func(f *organizationHandlerFixture) { + f.repo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "owner-2", Name: "Acme Inc", Slug: "acme-inc"}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: "forbidden", + }, + { + name: "success", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationRequest{ + Name: "Acme Platform", + Logo: new("http://some/url/logo.svg"), + Metadata: json.RawMessage(`{"tier":"pro"}`), + }), + prepare: func(f *organizationHandlerFixture) { + f.repo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1", Name: "Acme Inc", Slug: "acme-inc"}, nil).Once() + f.repo.On("Update", mock.Anything, mock.MatchedBy(func(org *orgtypes.Organization) bool { + return org != nil && org.ID == "org-1" && org.OwnerID == "user-1" && org.Name == "Acme Platform" && org.Slug == "acme-inc" + })).Return(&orgtypes.Organization{ + ID: "org-1", + OwnerID: "user-1", + Name: "Acme Platform", + Slug: "acme-inc", + Logo: new("http://some/url/logo.svg"), + Metadata: json.RawMessage(`{"tier":"pro"}`), + }, nil).Once() + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + org := internaltests.DecodeResponseJSON[orgtypes.Organization](t, reqCtx) + assert.Equal(t, "org-1", org.ID) + assert.Equal(t, "user-1", org.OwnerID) + assert.Equal(t, "Acme Platform", org.Name) + assert.Equal(t, "acme-inc", org.Slug) + require.NotNil(t, org.Logo) + assert.Equal(t, "http://some/url/logo.svg", *org.Logo) + assert.JSONEq(t, `{"tier":"pro"}`, string(org.Metadata)) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + fixture := newOrganizationHandlerFixture() + if tt.prepare != nil { + tt.prepare(fixture) + } + + handler := &UpdateOrganizationHandler{UseCase: fixture.useCase()} + req, w, reqCtx := fixture.newRequest(t, http.MethodPatch, "/organizations/test-id", tt.body, tt.userID, tt.organizationID) + handler.Handler().ServeHTTP(w, req) + + assert.Equal(t, tt.expectedStatus, reqCtx.ResponseStatus) + if tt.expectedMessage != "" { + internaltests.AssertErrorMessage(t, reqCtx, tt.expectedStatus, tt.expectedMessage) + } + if tt.checkResponse != nil { + tt.checkResponse(t, reqCtx) + } + fixture.repo.AssertExpectations(t) + }) + } +} + +func TestDeleteOrganizationHandler(t *testing.T) { + t.Parallel() + + tests := []organizationHandlerCase{ + { + name: "missing_user", + organizationID: "org-1", + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "not_found", + userID: new("user-1"), + organizationID: "org-1", + prepare: func(f *organizationHandlerFixture) { + f.repo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "forbidden", + userID: new("user-1"), + organizationID: "org-1", + prepare: func(f *organizationHandlerFixture) { + f.repo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "owner-2", Name: "Acme Inc", Slug: "acme-inc"}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: "forbidden", + }, + { + name: "success", + userID: new("user-1"), + organizationID: "org-1", + prepare: func(f *organizationHandlerFixture) { + f.repo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1", Name: "Acme Inc", Slug: "acme-inc"}, nil).Once() + f.repo.On("Delete", mock.Anything, "org-1").Return(nil).Once() + }, + expectedStatus: http.StatusNoContent, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + assert.Equal(t, "null", string(reqCtx.ResponseBody)) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + fixture := newOrganizationHandlerFixture() + if tt.prepare != nil { + tt.prepare(fixture) + } + + handler := &DeleteOrganizationHandler{UseCase: fixture.useCase()} + req, w, reqCtx := fixture.newRequest(t, http.MethodDelete, "/organizations/test-id", nil, tt.userID, tt.organizationID) + handler.Handler().ServeHTTP(w, req) + + assert.Equal(t, tt.expectedStatus, reqCtx.ResponseStatus) + if tt.expectedMessage != "" { + internaltests.AssertErrorMessage(t, reqCtx, tt.expectedStatus, tt.expectedMessage) + } + if tt.checkResponse != nil { + tt.checkResponse(t, reqCtx) + } + fixture.repo.AssertExpectations(t) + }) + } +} diff --git a/plugins/organizations/handlers/organization_invitation_handlers.go b/plugins/organizations/handlers/organization_invitation_handlers.go new file mode 100644 index 00000000..b7c385ec --- /dev/null +++ b/plugins/organizations/handlers/organization_invitation_handlers.go @@ -0,0 +1,191 @@ +package handlers + +import ( + "net/http" + "strings" + + internalerrors "github.com/Authula/authula/internal/errors" + "github.com/Authula/authula/internal/util" + "github.com/Authula/authula/models" + "github.com/Authula/authula/plugins/organizations/types" + "github.com/Authula/authula/plugins/organizations/usecases" +) + +type CreateOrganizationInvitationHandler struct { + UseCase *usecases.OrganizationInvitationUseCase +} + +func (h *CreateOrganizationInvitationHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + + var payload types.CreateOrganizationInvitationRequest + if err := util.ParseJSON(r, &payload); err != nil { + reqCtx.SetJSONResponse(http.StatusBadRequest, map[string]any{"message": "invalid request body"}) + reqCtx.Handled = true + return + } + + invitation, err := h.UseCase.CreateInvitation(ctx, userID, organizationID, payload) + if err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusCreated, invitation) + } +} + +type GetAllOrganizationInvitationsHandler struct { + UseCase *usecases.OrganizationInvitationUseCase +} + +func (h *GetAllOrganizationInvitationsHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + invitations, err := h.UseCase.GetAllInvitations(ctx, userID, organizationID) + if err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, invitations) + } +} + +type GetOrganizationInvitationHandler struct { + UseCase *usecases.OrganizationInvitationUseCase +} + +func (h *GetOrganizationInvitationHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + invitationID := r.PathValue("invitation_id") + invitation, err := h.UseCase.GetInvitation(ctx, userID, organizationID, invitationID) + if err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, invitation) + } +} + +type RevokeOrganizationInvitationHandler struct { + UseCase *usecases.OrganizationInvitationUseCase +} + +func (h *RevokeOrganizationInvitationHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + invitationID := r.PathValue("invitation_id") + invitation, err := h.UseCase.RevokeInvitation(ctx, userID, organizationID, invitationID) + if err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, invitation) + } +} + +type AcceptOrganizationInvitationHandler struct { + UseCase *usecases.OrganizationInvitationUseCase +} + +func (h *AcceptOrganizationInvitationHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + invitationID := r.PathValue("invitation_id") + invitation, err := h.UseCase.AcceptInvitation(ctx, userID, organizationID, invitationID) + if err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + if redirectURL := strings.TrimSpace(r.URL.Query().Get("redirect_url")); redirectURL != "" { + reqCtx.RedirectURL = redirectURL + return + } + + reqCtx.SetJSONResponse(http.StatusOK, invitation) + } +} + +type RejectOrganizationInvitationHandler struct { + UseCase *usecases.OrganizationInvitationUseCase +} + +func (h *RejectOrganizationInvitationHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + invitationID := r.PathValue("invitation_id") + invitation, err := h.UseCase.RejectInvitation(ctx, userID, organizationID, invitationID) + if err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, invitation) + } +} diff --git a/plugins/organizations/handlers/organization_invitation_handlers_test.go b/plugins/organizations/handlers/organization_invitation_handlers_test.go new file mode 100644 index 00000000..bab76456 --- /dev/null +++ b/plugins/organizations/handlers/organization_invitation_handlers_test.go @@ -0,0 +1,1050 @@ +package handlers + +import ( + "context" + "database/sql" + "errors" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/uptrace/bun" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + internalerrors "github.com/Authula/authula/internal/errors" + internaltests "github.com/Authula/authula/internal/tests" + "github.com/Authula/authula/models" + orgservices "github.com/Authula/authula/plugins/organizations/services" + orgtests "github.com/Authula/authula/plugins/organizations/tests" + orgtypes "github.com/Authula/authula/plugins/organizations/types" + "github.com/Authula/authula/plugins/organizations/usecases" + rootservices "github.com/Authula/authula/services" +) + +type invitationHandlerAccessControlServiceStub struct { + roles map[string]bool + err error +} + +func (s *invitationHandlerAccessControlServiceStub) RoleExists(ctx context.Context, roleName string) (bool, error) { + if s.err != nil { + return false, s.err + } + return s.roles[roleName], nil +} + +func newInvitationHandlerAccessControlServiceStub() *invitationHandlerAccessControlServiceStub { + return &invitationHandlerAccessControlServiceStub{roles: map[string]bool{"member": true, "admin": true}} +} + +type organizationInvitationTxRunner interface { + RunInTx(ctx context.Context, opts *sql.TxOptions, fn func(context.Context, bun.Tx) error) error +} + +func newOrganizationInvitationServiceForHandlerTest( + txRunner organizationInvitationTxRunner, + pluginConfig *orgtypes.OrganizationsPluginConfig, + userService rootservices.UserService, + accessControlService rootservices.AccessControlService, + orgRepo *orgtests.MockOrganizationRepository, + invRepo *orgtests.MockOrganizationInvitationRepository, + memberRepo *orgtests.MockOrganizationMemberRepository, + invHooks orgservices.OrganizationInvitationHookExecutor, + memberHooks orgservices.OrganizationMemberHookExecutor, +) *orgservices.OrganizationInvitationService { + return orgservices.NewOrganizationInvitationService( + txRunner, + &models.Config{BaseURL: "https://example.com", BasePath: "/auth"}, + pluginConfig, + &internaltests.MockLogger{}, + userService, + accessControlService, + orgRepo, + invRepo, + memberRepo, + nil, + nil, + invHooks, + memberHooks, + ) +} + +type organizationInvitationHandlerFixture struct { + pluginConfig *orgtypes.OrganizationsPluginConfig + orgRepo *orgtests.MockOrganizationRepository + invRepo *orgtests.MockOrganizationInvitationRepository + memberRepo *orgtests.MockOrganizationMemberRepository + userSvc *internaltests.MockUserService + accessControl *invitationHandlerAccessControlServiceStub + txRunner *orgtests.MockOrganizationInvitationTxRunner +} + +type organizationInvitationHandlerCase struct { + name string + userID *string + body []byte + organizationID string + invitationID string + prepare func(*organizationInvitationHandlerFixture) + expectedStatus int + expectedMessage string + checkResponse func(*testing.T, *models.RequestContext) +} + +func newOrganizationInvitationHandlerFixture() *organizationInvitationHandlerFixture { + return &organizationInvitationHandlerFixture{ + pluginConfig: &orgtypes.OrganizationsPluginConfig{ + Enabled: true, + InvitationExpiresIn: 7 * 24 * time.Hour, + }, + orgRepo: &orgtests.MockOrganizationRepository{}, + invRepo: &orgtests.MockOrganizationInvitationRepository{}, + memberRepo: &orgtests.MockOrganizationMemberRepository{}, + userSvc: &internaltests.MockUserService{}, + accessControl: newInvitationHandlerAccessControlServiceStub(), + txRunner: &orgtests.MockOrganizationInvitationTxRunner{}, + } +} + +func (f *organizationInvitationHandlerFixture) useCase() *usecases.OrganizationInvitationUseCase { + service := newOrganizationInvitationServiceForHandlerTest(f.txRunner, f.pluginConfig, f.userSvc, f.accessControl, f.orgRepo, f.invRepo, f.memberRepo, nil, nil) + return usecases.NewOrganizationInvitationUseCase(service) +} + +func (f *organizationInvitationHandlerFixture) newRequest(t *testing.T, method, path string, body []byte, userID *string, organizationID, invitationID string) (*http.Request, *httptest.ResponseRecorder, *models.RequestContext) { + t.Helper() + + req, w, reqCtx := internaltests.NewHandlerRequest(t, method, path, body, userID) + if organizationID != "" { + req.SetPathValue("organization_id", organizationID) + } + if invitationID != "" { + req.SetPathValue("invitation_id", invitationID) + } + return req, w, reqCtx +} + +func runOrganizationInvitationHandlerCases(t *testing.T, method, path string, buildHandler func(*organizationInvitationHandlerFixture) http.HandlerFunc, cases []organizationInvitationHandlerCase) { + t.Helper() + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + fixture := newOrganizationInvitationHandlerFixture() + if tt.prepare != nil { + tt.prepare(fixture) + } + + handler := buildHandler(fixture) + req, w, reqCtx := fixture.newRequest(t, method, path, tt.body, tt.userID, tt.organizationID, tt.invitationID) + handler.ServeHTTP(w, req) + + assert.Equal(t, tt.expectedStatus, reqCtx.ResponseStatus) + if tt.expectedMessage != "" { + internaltests.AssertErrorMessage(t, reqCtx, tt.expectedStatus, tt.expectedMessage) + } + if tt.checkResponse != nil { + tt.checkResponse(t, reqCtx) + } + + if fixture.userSvc != nil { + fixture.userSvc.AssertExpectations(t) + } + if fixture.orgRepo != nil { + fixture.orgRepo.AssertExpectations(t) + } + if fixture.invRepo != nil { + fixture.invRepo.AssertExpectations(t) + } + if fixture.memberRepo != nil { + fixture.memberRepo.AssertExpectations(t) + } + + }) + } +} + +func TestCreateOrganizationInvitationHandler(t *testing.T) { + t.Parallel() + + runOrganizationInvitationHandlerCases(t, http.MethodPost, "/organizations/org-1/invitations", func(fixture *organizationInvitationHandlerFixture) http.HandlerFunc { + return (&CreateOrganizationInvitationHandler{UseCase: fixture.useCase()}).Handler() + }, []organizationInvitationHandlerCase{ + { + name: "missing_user", + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationInvitationRequest{Email: "user@example.com", Role: "member"}), + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "invalid_json", + userID: new("user-1"), + organizationID: "org-1", + body: []byte("{"), + expectedStatus: http.StatusBadRequest, + expectedMessage: "invalid request body", + }, + { + name: "empty_email", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationInvitationRequest{Email: " ", Role: "member"}), + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + expectedStatus: http.StatusUnprocessableEntity, + expectedMessage: internalerrors.ErrUnprocessableEntity.Error(), + }, + { + name: "invalid_email", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationInvitationRequest{Email: "not-an-email", Role: "member"}), + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + expectedStatus: http.StatusUnprocessableEntity, + expectedMessage: internalerrors.ErrUnprocessableEntity.Error(), + }, + { + name: "empty_role", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationInvitationRequest{Email: "user@example.com", Role: " "}), + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + expectedStatus: http.StatusUnprocessableEntity, + expectedMessage: internalerrors.ErrUnprocessableEntity.Error(), + }, + { + name: "invalid_role", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationInvitationRequest{Email: "user@example.com", Role: "ghost"}), + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + expectedStatus: http.StatusUnprocessableEntity, + expectedMessage: internalerrors.ErrUnprocessableEntity.Error(), + }, + { + name: "organization_not_found", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationInvitationRequest{Email: "user@example.com", Role: "member"}), + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: internalerrors.ErrNotFound.Error(), + }, + { + name: "forbidden", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationInvitationRequest{Email: "user@example.com", Role: "member"}), + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "owner-2"}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: internalerrors.ErrForbidden.Error(), + }, + { + name: "organization_lookup_error", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationInvitationRequest{Email: "user@example.com", Role: "member"}), + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), errors.New("lookup failed")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "lookup failed", + }, + { + name: "existing_invitation_lookup_error", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationInvitationRequest{Email: "user@example.com", Role: "member"}), + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.invRepo.On("GetByOrganizationIDAndEmail", mock.Anything, "org-1", "user@example.com").Return((*orgtypes.OrganizationInvitation)(nil), errors.New("some error")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "some error", + }, + { + name: "pending_conflict", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationInvitationRequest{Email: "USER@EXAMPLE.COM", Role: "member"}), + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.invRepo.On("GetByOrganizationIDAndEmail", mock.Anything, "org-1", "user@example.com").Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + }, + expectedStatus: http.StatusConflict, + expectedMessage: internalerrors.ErrConflict.Error(), + }, + { + name: "expired_existing_invitation_then_create_success", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationInvitationRequest{ + Email: "user@example.com", + Role: "member", + }), + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.pluginConfig.InvitationExpiresIn = 36 * time.Hour + expectedExpiresAt := time.Now().UTC().Add(fixture.pluginConfig.InvitationExpiresIn) + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.invRepo.On("GetByOrganizationIDAndEmail", mock.Anything, "org-1", "user@example.com").Return(&orgtypes.OrganizationInvitation{ID: "inv-old", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(-time.Hour)}, nil).Once() + fixture.invRepo.On("Update", mock.Anything, mock.MatchedBy(func(invitation *orgtypes.OrganizationInvitation) bool { + return invitation != nil && invitation.ID == "inv-old" && invitation.Status == orgtypes.OrganizationInvitationStatusExpired + })).Return(&orgtypes.OrganizationInvitation{ID: "inv-old", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusExpired, ExpiresAt: time.Now().UTC().Add(-time.Hour)}, nil).Once() + fixture.invRepo.On("Create", mock.Anything, mock.MatchedBy(func(invitation *orgtypes.OrganizationInvitation) bool { + return invitation != nil && invitation.OrganizationID == "org-1" && invitation.InviterID == "user-1" && invitation.Email == "user@example.com" && invitation.Role == "member" && invitation.Status == orgtypes.OrganizationInvitationStatusPending && invitation.ExpiresAt.After(expectedExpiresAt.Add(-2*time.Second)) && invitation.ExpiresAt.Before(expectedExpiresAt.Add(2*time.Second)) + })).Return(&orgtypes.OrganizationInvitation{ID: "inv-2", OrganizationID: "org-1", InviterID: "user-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: expectedExpiresAt}, nil).Once() + }, + expectedStatus: http.StatusCreated, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + invitation := internaltests.DecodeResponseJSON[orgtypes.OrganizationInvitation](t, reqCtx) + require.Equal(t, "inv-2", invitation.ID) + require.Equal(t, "org-1", invitation.OrganizationID) + require.Equal(t, "user-1", invitation.InviterID) + require.Equal(t, "user@example.com", invitation.Email) + require.Equal(t, "member", invitation.Role) + require.Equal(t, orgtypes.OrganizationInvitationStatusPending, invitation.Status) + require.WithinDuration(t, time.Now().UTC().Add(36*time.Hour), invitation.ExpiresAt, 2*time.Second) + }, + }, + { + name: "invalid_expiry_config", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationInvitationRequest{ + Email: "user@example.com", + Role: "member", + }), + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.pluginConfig.InvitationExpiresIn = 0 + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.invRepo.On("GetByOrganizationIDAndEmail", mock.Anything, "org-1", "user@example.com").Return(nil, nil).Once() + }, + expectedStatus: http.StatusUnprocessableEntity, + expectedMessage: internalerrors.ErrUnprocessableEntity.Error(), + }, + { + name: "success_applies_configured_expiry", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationInvitationRequest{ + Email: "user@example.com", + Role: "member", + }), + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.pluginConfig.InvitationExpiresIn = 90 * time.Minute + expectedExpiresAt := time.Now().UTC().Add(fixture.pluginConfig.InvitationExpiresIn) + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.invRepo.On("GetByOrganizationIDAndEmail", mock.Anything, "org-1", "user@example.com").Return(nil, nil).Once() + fixture.invRepo.On("Create", mock.Anything, mock.MatchedBy(func(invitation *orgtypes.OrganizationInvitation) bool { + return invitation != nil && invitation.OrganizationID == "org-1" && invitation.InviterID == "user-1" && invitation.Email == "user@example.com" && invitation.Role == "member" && invitation.Status == orgtypes.OrganizationInvitationStatusPending && invitation.ExpiresAt.After(expectedExpiresAt.Add(-2*time.Second)) && invitation.ExpiresAt.Before(expectedExpiresAt.Add(2*time.Second)) + })).Return(&orgtypes.OrganizationInvitation{ID: "inv-2", OrganizationID: "org-1", InviterID: "user-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: expectedExpiresAt}, nil).Once() + }, + expectedStatus: http.StatusCreated, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + invitation := internaltests.DecodeResponseJSON[orgtypes.OrganizationInvitation](t, reqCtx) + require.Equal(t, "inv-2", invitation.ID) + require.Equal(t, orgtypes.OrganizationInvitationStatusPending, invitation.Status) + require.Equal(t, "org-1", invitation.OrganizationID) + require.Equal(t, "user-1", invitation.InviterID) + require.Equal(t, "user@example.com", invitation.Email) + require.WithinDuration(t, time.Now().UTC().Add(90*time.Minute), invitation.ExpiresAt, 2*time.Second) + }, + }, + }) +} + +func TestGetAllOrganizationInvitationsHandler(t *testing.T) { + t.Parallel() + + runOrganizationInvitationHandlerCases(t, http.MethodGet, "/organizations/org-1/invitations", func(fixture *organizationInvitationHandlerFixture) http.HandlerFunc { + return (&GetAllOrganizationInvitationsHandler{UseCase: fixture.useCase()}).Handler() + }, []organizationInvitationHandlerCase{ + { + name: "missing_user", + organizationID: "org-1", + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "organization_not_found", + userID: new("user-1"), + organizationID: "org-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: internalerrors.ErrNotFound.Error(), + }, + { + name: "forbidden", + userID: new("user-1"), + organizationID: "org-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "owner-2"}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: internalerrors.ErrForbidden.Error(), + }, + { + name: "organization_lookup_error", + userID: new("user-1"), + organizationID: "org-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), errors.New("lookup failed")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "lookup failed", + }, + { + name: "repo_error", + userID: new("user-1"), + organizationID: "org-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.invRepo.On("GetAllByOrganizationID", mock.Anything, "org-1").Return(([]orgtypes.OrganizationInvitation)(nil), errors.New("some error")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "some error", + }, + { + name: "expired_pending_invitation_is_updated", + userID: new("user-1"), + organizationID: "org-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.invRepo.On("GetAllByOrganizationID", mock.Anything, "org-1").Return([]orgtypes.OrganizationInvitation{{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: time.Date(2020, time.January, 2, 3, 4, 5, 0, time.UTC)}}, nil).Once() + fixture.invRepo.On("Update", mock.Anything, mock.MatchedBy(func(invitation *orgtypes.OrganizationInvitation) bool { + return invitation != nil && invitation.ID == "inv-1" && invitation.Status == orgtypes.OrganizationInvitationStatusExpired + })).Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusExpired, ExpiresAt: time.Date(2020, time.January, 2, 3, 4, 5, 0, time.UTC)}, nil).Once() + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + invitations := internaltests.DecodeResponseJSON[[]orgtypes.OrganizationInvitation](t, reqCtx) + require.Len(t, invitations, 1) + require.Equal(t, "inv-1", invitations[0].ID) + require.Equal(t, orgtypes.OrganizationInvitationStatusExpired, invitations[0].Status) + }, + }, + { + name: "success", + userID: new("user-1"), + organizationID: "org-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.invRepo.On("GetAllByOrganizationID", mock.Anything, "org-1").Return([]orgtypes.OrganizationInvitation{{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusAccepted, ExpiresAt: time.Date(2026, time.January, 2, 3, 4, 5, 0, time.UTC)}}, nil).Once() + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + invitations := internaltests.DecodeResponseJSON[[]orgtypes.OrganizationInvitation](t, reqCtx) + require.Len(t, invitations, 1) + require.Equal(t, "inv-1", invitations[0].ID) + require.Equal(t, orgtypes.OrganizationInvitationStatusAccepted, invitations[0].Status) + }, + }, + }) +} + +func TestGetOrganizationInvitationHandler(t *testing.T) { + t.Parallel() + + runOrganizationInvitationHandlerCases(t, http.MethodGet, "/organizations/org-1/invitations/inv-1", func(fixture *organizationInvitationHandlerFixture) http.HandlerFunc { + return (&GetOrganizationInvitationHandler{UseCase: fixture.useCase()}).Handler() + }, []organizationInvitationHandlerCase{ + { + name: "missing_user", + organizationID: "org-1", + invitationID: "inv-1", + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "organization_not_found", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: internalerrors.ErrNotFound.Error(), + }, + { + name: "forbidden", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "owner-2"}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: internalerrors.ErrForbidden.Error(), + }, + { + name: "organization_lookup_error", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), errors.New("lookup failed")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "lookup failed", + }, + { + name: "invitation_not_found", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return((*orgtypes.OrganizationInvitation)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: internalerrors.ErrNotFound.Error(), + }, + { + name: "invitation_from_other_organization", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-2", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: internalerrors.ErrNotFound.Error(), + }, + { + name: "invitation_lookup_error", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return((*orgtypes.OrganizationInvitation)(nil), errors.New("some error")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "some error", + }, + { + name: "expired_pending_invitation_is_updated", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: time.Date(2020, time.January, 2, 3, 4, 5, 0, time.UTC)}, nil).Once() + fixture.invRepo.On("Update", mock.Anything, mock.MatchedBy(func(invitation *orgtypes.OrganizationInvitation) bool { + return invitation != nil && invitation.ID == "inv-1" && invitation.Status == orgtypes.OrganizationInvitationStatusExpired + })).Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusExpired, ExpiresAt: time.Date(2020, time.January, 2, 3, 4, 5, 0, time.UTC)}, nil).Once() + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + invitation := internaltests.DecodeResponseJSON[orgtypes.OrganizationInvitation](t, reqCtx) + require.Equal(t, "inv-1", invitation.ID) + require.Equal(t, orgtypes.OrganizationInvitationStatusExpired, invitation.Status) + }, + }, + { + name: "success", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusAccepted, ExpiresAt: time.Date(2026, time.January, 2, 3, 4, 5, 0, time.UTC)}, nil).Once() + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + invitation := internaltests.DecodeResponseJSON[orgtypes.OrganizationInvitation](t, reqCtx) + require.Equal(t, "inv-1", invitation.ID) + require.Equal(t, orgtypes.OrganizationInvitationStatusAccepted, invitation.Status) + }, + }, + }) +} + +func TestRevokeOrganizationInvitationHandler(t *testing.T) { + t.Parallel() + + runOrganizationInvitationHandlerCases(t, http.MethodDelete, "/organizations/org-1/invitations/inv-1", func(fixture *organizationInvitationHandlerFixture) http.HandlerFunc { + return (&RevokeOrganizationInvitationHandler{UseCase: fixture.useCase()}).Handler() + }, []organizationInvitationHandlerCase{ + { + name: "missing_user", + organizationID: "org-1", + invitationID: "inv-1", + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "organization_not_found", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: internalerrors.ErrNotFound.Error(), + }, + { + name: "forbidden", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "owner-2"}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: internalerrors.ErrForbidden.Error(), + }, + { + name: "organization_lookup_error", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), errors.New("lookup failed")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "lookup failed", + }, + { + name: "invitation_not_found", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return((*orgtypes.OrganizationInvitation)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: internalerrors.ErrNotFound.Error(), + }, + { + name: "invitation_from_other_organization", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-2", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: internalerrors.ErrNotFound.Error(), + }, + { + name: "invitation_lookup_error", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return((*orgtypes.OrganizationInvitation)(nil), errors.New("some error")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "some error", + }, + { + name: "already_handled_conflict", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusAccepted, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + }, + expectedStatus: http.StatusConflict, + expectedMessage: internalerrors.ErrConflict.Error(), + }, + { + name: "expired_pending_invitation_conflict", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: time.Date(2020, time.January, 2, 3, 4, 5, 0, time.UTC)}, nil).Once() + fixture.invRepo.On("Update", mock.Anything, mock.MatchedBy(func(invitation *orgtypes.OrganizationInvitation) bool { + return invitation != nil && invitation.ID == "inv-1" && invitation.Status == orgtypes.OrganizationInvitationStatusExpired + })).Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusExpired, ExpiresAt: time.Date(2020, time.January, 2, 3, 4, 5, 0, time.UTC)}, nil).Once() + }, + expectedStatus: http.StatusConflict, + expectedMessage: internalerrors.ErrConflict.Error(), + }, + { + name: "success", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + fixture.invRepo.On("Update", mock.Anything, mock.MatchedBy(func(invitation *orgtypes.OrganizationInvitation) bool { + return invitation != nil && invitation.ID == "inv-1" && invitation.Status == orgtypes.OrganizationInvitationStatusRevoked + })).Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusRevoked, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + invitation := internaltests.DecodeResponseJSON[orgtypes.OrganizationInvitation](t, reqCtx) + require.Equal(t, "inv-1", invitation.ID) + require.Equal(t, orgtypes.OrganizationInvitationStatusRevoked, invitation.Status) + }, + }, + }) +} + +func TestAcceptOrganizationInvitationHandler(t *testing.T) { + t.Parallel() + + runOrganizationInvitationHandlerCases(t, http.MethodPost, "/organizations/org-1/invitations/inv-1/accept", func(fixture *organizationInvitationHandlerFixture) http.HandlerFunc { + return (&AcceptOrganizationInvitationHandler{UseCase: fixture.useCase()}).Handler() + }, []organizationInvitationHandlerCase{ + { + name: "missing_user", + organizationID: "org-1", + invitationID: "inv-1", + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "user_lookup_error", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return((*models.User)(nil), errors.New("some error")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "some error", + }, + { + name: "user_not_found", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return((*models.User)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: internalerrors.ErrNotFound.Error(), + }, + { + name: "user_missing_email", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return(&models.User{ID: "user-1", Email: " "}, nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: internalerrors.ErrNotFound.Error(), + }, + { + name: "invitation_not_found", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return(&models.User{ID: "user-1", Email: "user@example.com"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return((*orgtypes.OrganizationInvitation)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: internalerrors.ErrNotFound.Error(), + }, + { + name: "invitation_from_other_organization", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return(&models.User{ID: "user-1", Email: "user@example.com"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-2", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: internalerrors.ErrNotFound.Error(), + }, + { + name: "invitation_lookup_error", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return(&models.User{ID: "user-1", Email: "user@example.com"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return((*orgtypes.OrganizationInvitation)(nil), errors.New("lookup failed")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "lookup failed", + }, + { + name: "expired_pending_invitation_conflict", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return(&models.User{ID: "user-1", Email: "user@example.com"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: time.Date(2020, time.January, 2, 3, 4, 5, 0, time.UTC)}, nil).Once() + fixture.invRepo.On("Update", mock.Anything, mock.MatchedBy(func(invitation *orgtypes.OrganizationInvitation) bool { + return invitation != nil && invitation.ID == "inv-1" && invitation.Status == orgtypes.OrganizationInvitationStatusExpired + })).Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusExpired, ExpiresAt: time.Date(2020, time.January, 2, 3, 4, 5, 0, time.UTC)}, nil).Once() + }, + expectedStatus: http.StatusConflict, + expectedMessage: internalerrors.ErrConflict.Error(), + }, + { + name: "email_mismatch_forbidden", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return(&models.User{ID: "user-1", Email: "other@example.com"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: internalerrors.ErrForbidden.Error(), + }, + { + name: "member_lookup_error", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return(&models.User{ID: "user-1", Email: "user@example.com"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + fixture.memberRepo.On("GetByOrganizationIDAndUserID", mock.Anything, "org-1", "user-1").Return((*orgtypes.OrganizationMember)(nil), errors.New("member lookup failed")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "member lookup failed", + }, + { + name: "success_creates_member", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return(&models.User{ID: "user-1", Email: "user@example.com"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + fixture.memberRepo.On("GetByOrganizationIDAndUserID", mock.Anything, "org-1", "user-1").Return((*orgtypes.OrganizationMember)(nil), nil).Once() + fixture.memberRepo.On("Create", mock.Anything, mock.MatchedBy(func(member *orgtypes.OrganizationMember) bool { + return member != nil && member.OrganizationID == "org-1" && member.UserID == "user-1" && member.Role == "member" + })).Return(&orgtypes.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-1", Role: "member"}, nil).Once() + fixture.invRepo.On("Update", mock.Anything, mock.MatchedBy(func(invitation *orgtypes.OrganizationInvitation) bool { + return invitation != nil && invitation.ID == "inv-1" && invitation.Status == orgtypes.OrganizationInvitationStatusAccepted + })).Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusAccepted, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + invitation := internaltests.DecodeResponseJSON[orgtypes.OrganizationInvitation](t, reqCtx) + require.Equal(t, "inv-1", invitation.ID) + require.Equal(t, orgtypes.OrganizationInvitationStatusAccepted, invitation.Status) + }, + }, + { + name: "success_with_existing_member", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return(&models.User{ID: "user-1", Email: "user@example.com"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + fixture.memberRepo.On("GetByOrganizationIDAndUserID", mock.Anything, "org-1", "user-1").Return(&orgtypes.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-1", Role: "member"}, nil).Once() + fixture.invRepo.On("Update", mock.Anything, mock.MatchedBy(func(invitation *orgtypes.OrganizationInvitation) bool { + return invitation != nil && invitation.ID == "inv-1" && invitation.Status == orgtypes.OrganizationInvitationStatusAccepted + })).Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusAccepted, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + invitation := internaltests.DecodeResponseJSON[orgtypes.OrganizationInvitation](t, reqCtx) + require.Equal(t, "inv-1", invitation.ID) + require.Equal(t, orgtypes.OrganizationInvitationStatusAccepted, invitation.Status) + }, + }, + }) +} + +func TestAcceptOrganizationInvitationHandler_RedirectURL(t *testing.T) { + t.Parallel() + + fixture := newOrganizationInvitationHandlerFixture() + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return(&models.User{ID: "user-1", Email: "user@example.com"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + fixture.memberRepo.On("GetByOrganizationIDAndUserID", mock.Anything, "org-1", "user-1").Return(&orgtypes.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-1", Role: "member"}, nil).Once() + fixture.invRepo.On("Update", mock.Anything, mock.MatchedBy(func(invitation *orgtypes.OrganizationInvitation) bool { + return invitation != nil && invitation.ID == "inv-1" && invitation.Status == orgtypes.OrganizationInvitationStatusAccepted + })).Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusAccepted, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + + handler := (&AcceptOrganizationInvitationHandler{UseCase: fixture.useCase()}).Handler() + userID := "user-1" + req, w, reqCtx := internaltests.NewHandlerRequest(t, http.MethodPost, "/organizations/org-1/invitations/inv-1/accept?redirect_url=https%3A%2F%2Fapp.example.com%2Fwelcome", nil, &userID) + req.SetPathValue("organization_id", "org-1") + req.SetPathValue("invitation_id", "inv-1") + handler.ServeHTTP(w, req) + + require.Equal(t, "https://app.example.com/welcome", reqCtx.RedirectURL) + require.Equal(t, 0, reqCtx.ResponseStatus) + require.False(t, reqCtx.ResponseReady) + fixture.userSvc.AssertExpectations(t) + fixture.invRepo.AssertExpectations(t) + fixture.memberRepo.AssertExpectations(t) +} + +func TestRejectOrganizationInvitationHandler(t *testing.T) { + t.Parallel() + + runOrganizationInvitationHandlerCases(t, http.MethodPost, "/organizations/org-1/invitations/inv-1/reject", func(fixture *organizationInvitationHandlerFixture) http.HandlerFunc { + return (&RejectOrganizationInvitationHandler{UseCase: fixture.useCase()}).Handler() + }, []organizationInvitationHandlerCase{ + { + name: "missing_user", + organizationID: "org-1", + invitationID: "inv-1", + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "user_lookup_error", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return((*models.User)(nil), errors.New("some error")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "some error", + }, + { + name: "user_not_found", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return((*models.User)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: internalerrors.ErrNotFound.Error(), + }, + { + name: "user_missing_email", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return(&models.User{ID: "user-1", Email: ""}, nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: internalerrors.ErrNotFound.Error(), + }, + { + name: "invitation_not_found", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return(&models.User{ID: "user-1", Email: "user@example.com"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return((*orgtypes.OrganizationInvitation)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: internalerrors.ErrNotFound.Error(), + }, + { + name: "invitation_from_other_organization", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return(&models.User{ID: "user-1", Email: "user@example.com"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-2", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: internalerrors.ErrNotFound.Error(), + }, + { + name: "invitation_lookup_error", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return(&models.User{ID: "user-1", Email: "user@example.com"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return((*orgtypes.OrganizationInvitation)(nil), errors.New("lookup failed")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "lookup failed", + }, + { + name: "already_handled_conflict", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return(&models.User{ID: "user-1", Email: "user@example.com"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusRejected, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + }, + expectedStatus: http.StatusConflict, + expectedMessage: internalerrors.ErrConflict.Error(), + }, + { + name: "email_mismatch_forbidden", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return(&models.User{ID: "user-1", Email: "other@example.com"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: internalerrors.ErrForbidden.Error(), + }, + { + name: "update_error", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return(&models.User{ID: "user-1", Email: "user@example.com"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + fixture.invRepo.On("Update", mock.Anything, mock.MatchedBy(func(invitation *orgtypes.OrganizationInvitation) bool { + return invitation != nil && invitation.ID == "inv-1" && invitation.Status == orgtypes.OrganizationInvitationStatusRejected + })).Return((*orgtypes.OrganizationInvitation)(nil), errors.New("update failed")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "update failed", + }, + { + name: "success", + userID: new("user-1"), + organizationID: "org-1", + invitationID: "inv-1", + prepare: func(fixture *organizationInvitationHandlerFixture) { + fixture.userSvc.On("GetByID", mock.Anything, "user-1").Return(&models.User{ID: "user-1", Email: "user@example.com"}, nil).Once() + fixture.invRepo.On("GetByID", mock.Anything, "inv-1").Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusPending, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + fixture.invRepo.On("Update", mock.Anything, mock.MatchedBy(func(invitation *orgtypes.OrganizationInvitation) bool { + return invitation != nil && invitation.ID == "inv-1" && invitation.Status == orgtypes.OrganizationInvitationStatusRejected + })).Return(&orgtypes.OrganizationInvitation{ID: "inv-1", OrganizationID: "org-1", Email: "user@example.com", Role: "member", Status: orgtypes.OrganizationInvitationStatusRejected, ExpiresAt: time.Now().UTC().Add(time.Hour)}, nil).Once() + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + invitation := internaltests.DecodeResponseJSON[orgtypes.OrganizationInvitation](t, reqCtx) + require.Equal(t, "inv-1", invitation.ID) + require.Equal(t, orgtypes.OrganizationInvitationStatusRejected, invitation.Status) + }, + }, + }) +} diff --git a/plugins/organizations/handlers/organization_member_handlers.go b/plugins/organizations/handlers/organization_member_handlers.go new file mode 100644 index 00000000..de5fc7d1 --- /dev/null +++ b/plugins/organizations/handlers/organization_member_handlers.go @@ -0,0 +1,166 @@ +package handlers + +import ( + "net/http" + + internalerrors "github.com/Authula/authula/internal/errors" + "github.com/Authula/authula/internal/util" + "github.com/Authula/authula/models" + "github.com/Authula/authula/plugins/organizations/types" + "github.com/Authula/authula/plugins/organizations/usecases" +) + +type AddOrganizationMemberHandler struct { + UseCase *usecases.OrganizationMemberUseCase +} + +func (h *AddOrganizationMemberHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + + var payload types.AddOrganizationMemberRequest + if err := util.ParseJSON(r, &payload); err != nil { + reqCtx.SetJSONResponse(http.StatusBadRequest, map[string]any{"message": "invalid request body"}) + reqCtx.Handled = true + return + } + + member, err := h.UseCase.AddMember(ctx, userID, organizationID, payload) + if err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusCreated, member) + } +} + +type GetAllOrganizationMembersHandler struct { + UseCase *usecases.OrganizationMemberUseCase +} + +func (h *GetAllOrganizationMembersHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + page := util.GetQueryInt(r, "page", 1) + limit := util.GetQueryInt(r, "limit", 10) + members, err := h.UseCase.GetAllMembers(ctx, userID, organizationID, page, limit) + if err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, members) + } +} + +type GetOrganizationMemberHandler struct { + UseCase *usecases.OrganizationMemberUseCase +} + +func (h *GetOrganizationMemberHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + memberID := r.PathValue("member_id") + member, err := h.UseCase.GetMember(ctx, userID, organizationID, memberID) + if err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, member) + } +} + +type UpdateOrganizationMemberHandler struct { + UseCase *usecases.OrganizationMemberUseCase +} + +func (h *UpdateOrganizationMemberHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + memberID := r.PathValue("member_id") + + var payload types.UpdateOrganizationMemberRequest + if err := util.ParseJSON(r, &payload); err != nil { + reqCtx.SetJSONResponse(http.StatusBadRequest, map[string]any{"message": "invalid request body"}) + reqCtx.Handled = true + return + } + + member, err := h.UseCase.UpdateMember(ctx, userID, organizationID, memberID, payload) + if err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, member) + } +} + +type DeleteOrganizationMemberHandler struct { + UseCase *usecases.OrganizationMemberUseCase +} + +func (h *DeleteOrganizationMemberHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + memberID := r.PathValue("member_id") + if err := h.UseCase.RemoveMember(ctx, userID, organizationID, memberID); err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusNoContent, nil) + } +} diff --git a/plugins/organizations/handlers/organization_member_handlers_test.go b/plugins/organizations/handlers/organization_member_handlers_test.go new file mode 100644 index 00000000..877b70b4 --- /dev/null +++ b/plugins/organizations/handlers/organization_member_handlers_test.go @@ -0,0 +1,581 @@ +package handlers + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + internalerrors "github.com/Authula/authula/internal/errors" + internaltests "github.com/Authula/authula/internal/tests" + "github.com/Authula/authula/models" + orgservices "github.com/Authula/authula/plugins/organizations/services" + orgtests "github.com/Authula/authula/plugins/organizations/tests" + orgtypes "github.com/Authula/authula/plugins/organizations/types" + "github.com/Authula/authula/plugins/organizations/usecases" +) + +type memberHandlerAccessControlServiceStub struct { + roles map[string]bool + err error +} + +func (s *memberHandlerAccessControlServiceStub) RoleExists(ctx context.Context, roleName string) (bool, error) { + if s.err != nil { + return false, s.err + } + return s.roles[roleName], nil +} + +func newMemberHandlerAccessControlServiceStub() *memberHandlerAccessControlServiceStub { + return &memberHandlerAccessControlServiceStub{roles: map[string]bool{"member": true, "admin": true}} +} + +type organizationMemberHandlerFixture struct { + userSvc *internaltests.MockUserService + accessControl *memberHandlerAccessControlServiceStub + orgRepo *orgtests.MockOrganizationRepository + orgMemberRepo *orgtests.MockOrganizationMemberRepository +} + +func newOrganizationMemberHandlerFixture() *organizationMemberHandlerFixture { + return &organizationMemberHandlerFixture{ + userSvc: &internaltests.MockUserService{}, + accessControl: newMemberHandlerAccessControlServiceStub(), + orgRepo: &orgtests.MockOrganizationRepository{}, + orgMemberRepo: &orgtests.MockOrganizationMemberRepository{}, + } +} + +func (f *organizationMemberHandlerFixture) useCase() *usecases.OrganizationMemberUseCase { + service := orgservices.NewOrganizationMemberService(f.userSvc, f.accessControl, f.orgRepo, f.orgMemberRepo, nil) + return usecases.NewOrganizationMemberUseCase(service) +} + +func (f *organizationMemberHandlerFixture) newRequest(t *testing.T, method, path string, body []byte, userID *string, organizationID, memberID string) (*http.Request, *httptest.ResponseRecorder, *models.RequestContext) { + req, w, reqCtx := internaltests.NewHandlerRequest(t, method, path, body, userID) + if organizationID != "" { + req.SetPathValue("organization_id", organizationID) + } + if memberID != "" { + req.SetPathValue("member_id", memberID) + } + return req, w, reqCtx +} + +type organizationMemberHandlerCase struct { + name string + userID *string + body []byte + organizationID string + memberID string + prepare func(*organizationMemberHandlerFixture) + expectedStatus int + expectedMessage string + checkResponse func(*testing.T, *models.RequestContext) +} + +func runOrganizationMemberHandlerCases(t *testing.T, method, path string, buildHandler func(*organizationMemberHandlerFixture) http.HandlerFunc, cases []organizationMemberHandlerCase) { + t.Helper() + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + fixture := newOrganizationMemberHandlerFixture() + if tt.prepare != nil { + tt.prepare(fixture) + } + + handler := buildHandler(fixture) + req, w, reqCtx := fixture.newRequest(t, method, path, tt.body, tt.userID, tt.organizationID, tt.memberID) + handler.ServeHTTP(w, req) + + assert.Equal(t, tt.expectedStatus, reqCtx.ResponseStatus) + if tt.expectedMessage != "" { + internaltests.AssertErrorMessage(t, reqCtx, tt.expectedStatus, tt.expectedMessage) + } + if tt.checkResponse != nil { + tt.checkResponse(t, reqCtx) + } + fixture.userSvc.AssertExpectations(t) + fixture.orgRepo.AssertExpectations(t) + fixture.orgMemberRepo.AssertExpectations(t) + }) + } +} + +func TestAddOrganizationMemberHandler(t *testing.T) { + t.Parallel() + + runOrganizationMemberHandlerCases(t, http.MethodPost, "/organizations/org-1/members", func(fixture *organizationMemberHandlerFixture) http.HandlerFunc { + return (&AddOrganizationMemberHandler{UseCase: fixture.useCase()}).Handler() + }, []organizationMemberHandlerCase{ + { + name: "missing user", + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationMemberRequest{UserID: "user-2", Role: "member"}), + organizationID: "org-1", + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "invalid request body", + userID: new("user-1"), + body: []byte("{"), + organizationID: "org-1", + expectedStatus: http.StatusBadRequest, + expectedMessage: "invalid request body", + }, + { + name: "bad request when user id is empty", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationMemberRequest{UserID: "", Role: "member"}), + organizationID: "org-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + expectedStatus: http.StatusUnprocessableEntity, + expectedMessage: "unprocessable entity", + }, + { + name: "bad request when role is empty", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationMemberRequest{UserID: "user-2", Role: ""}), + organizationID: "org-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + expectedStatus: http.StatusUnprocessableEntity, + expectedMessage: "unprocessable entity", + }, + { + name: "invalid role", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationMemberRequest{UserID: "user-2", Role: "ghost"}), + organizationID: "org-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + expectedStatus: http.StatusUnprocessableEntity, + expectedMessage: internalerrors.ErrUnprocessableEntity.Error(), + }, + { + name: "organization not found", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationMemberRequest{UserID: "user-2", Role: "member"}), + organizationID: "org-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "forbidden", + userID: new("user-2"), + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationMemberRequest{UserID: "user-3", Role: "member"}), + organizationID: "org-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: "forbidden", + }, + { + name: "user not found", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationMemberRequest{UserID: "user-2", Role: "member"}), + organizationID: "org-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.userSvc.On("GetByID", mock.Anything, "user-2").Return((*models.User)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "member lookup error", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationMemberRequest{UserID: "user-2", Role: "member"}), + organizationID: "org-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.userSvc.On("GetByID", mock.Anything, "user-2").Return(&models.User{ID: "user-2", Email: "user@example.com"}, nil).Once() + fixture.orgMemberRepo.On("GetByOrganizationIDAndUserID", mock.Anything, "org-1", "user-2").Return((*orgtypes.OrganizationMember)(nil), errors.New("some error")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "some error", + }, + { + name: "conflict", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationMemberRequest{UserID: "user-2", Role: "member"}), + organizationID: "org-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.userSvc.On("GetByID", mock.Anything, "user-2").Return(&models.User{ID: "user-2", Email: "user@example.com"}, nil).Once() + fixture.orgMemberRepo.On("GetByOrganizationIDAndUserID", mock.Anything, "org-1", "user-2").Return(&orgtypes.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}, nil).Once() + }, + expectedStatus: http.StatusConflict, + expectedMessage: "conflict", + }, + { + name: "success", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationMemberRequest{UserID: "user-2", Role: "member"}), + organizationID: "org-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.userSvc.On("GetByID", mock.Anything, "user-2").Return(&models.User{ID: "user-2", Email: "user@example.com"}, nil).Once() + fixture.orgMemberRepo.On("GetByOrganizationIDAndUserID", mock.Anything, "org-1", "user-2").Return((*orgtypes.OrganizationMember)(nil), nil).Once() + fixture.orgMemberRepo.On("Create", mock.Anything, mock.MatchedBy(func(member *orgtypes.OrganizationMember) bool { + return member != nil && member.OrganizationID == "org-1" && member.UserID == "user-2" && member.Role == "member" + })).Return(&orgtypes.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}, nil).Once() + }, + expectedStatus: http.StatusCreated, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + member := internaltests.DecodeResponseJSON[orgtypes.OrganizationMember](t, reqCtx) + require.Equal(t, "mem-1", member.ID) + require.Equal(t, "org-1", member.OrganizationID) + require.Equal(t, "user-2", member.UserID) + require.Equal(t, "member", member.Role) + }, + }, + }) +} + +func TestGetAllOrganizationMembersHandler(t *testing.T) { + t.Parallel() + + runOrganizationMemberHandlerCases(t, http.MethodGet, "/organizations/org-1/members", func(fixture *organizationMemberHandlerFixture) http.HandlerFunc { + return (&GetAllOrganizationMembersHandler{UseCase: fixture.useCase()}).Handler() + }, []organizationMemberHandlerCase{ + { + name: "missing user", + organizationID: "org-1", + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "organization not found", + userID: new("user-1"), + organizationID: "org-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "forbidden", + userID: new("user-2"), + organizationID: "org-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: "forbidden", + }, + { + name: "repo error", + userID: new("user-1"), + organizationID: "org-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgMemberRepo.On("GetAllByOrganizationID", mock.Anything, "org-1", 1, 10).Return(([]orgtypes.OrganizationMember)(nil), errors.New("some error")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "some error", + }, + { + name: "success", + userID: new("user-1"), + organizationID: "org-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgMemberRepo.On("GetAllByOrganizationID", mock.Anything, "org-1", 1, 10).Return([]orgtypes.OrganizationMember{{ID: "mem-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}}, nil).Once() + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + members := internaltests.DecodeResponseJSON[[]orgtypes.OrganizationMember](t, reqCtx) + require.Len(t, members, 1) + require.Equal(t, "mem-1", members[0].ID) + require.Equal(t, "org-1", members[0].OrganizationID) + require.Equal(t, "user-2", members[0].UserID) + require.Equal(t, "member", members[0].Role) + }, + }, + }) +} + +func TestGetOrganizationMemberHandler(t *testing.T) { + t.Parallel() + + runOrganizationMemberHandlerCases(t, http.MethodGet, "/organizations/org-1/members/mem-1", func(fixture *organizationMemberHandlerFixture) http.HandlerFunc { + return (&GetOrganizationMemberHandler{UseCase: fixture.useCase()}).Handler() + }, []organizationMemberHandlerCase{ + { + name: "missing user", + organizationID: "org-1", + memberID: "mem-1", + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "organization not found", + userID: new("user-1"), + organizationID: "org-1", + memberID: "mem-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "forbidden", + userID: new("user-2"), + organizationID: "org-1", + memberID: "mem-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: "forbidden", + }, + { + name: "member not found", + userID: new("user-1"), + organizationID: "org-1", + memberID: "mem-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgMemberRepo.On("GetByID", mock.Anything, "mem-1").Return((*orgtypes.OrganizationMember)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "repo error", + userID: new("user-1"), + organizationID: "org-1", + memberID: "mem-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgMemberRepo.On("GetByID", mock.Anything, "mem-1").Return((*orgtypes.OrganizationMember)(nil), errors.New("some error")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "some error", + }, + { + name: "success", + userID: new("user-1"), + organizationID: "org-1", + memberID: "mem-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgMemberRepo.On("GetByID", mock.Anything, "mem-1").Return(&orgtypes.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}, nil).Once() + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + member := internaltests.DecodeResponseJSON[orgtypes.OrganizationMember](t, reqCtx) + require.Equal(t, "mem-1", member.ID) + require.Equal(t, "org-1", member.OrganizationID) + require.Equal(t, "user-2", member.UserID) + require.Equal(t, "member", member.Role) + }, + }, + }) +} + +func TestUpdateOrganizationMemberHandler(t *testing.T) { + t.Parallel() + + runOrganizationMemberHandlerCases(t, http.MethodPatch, "/organizations/org-1/members/mem-1", func(fixture *organizationMemberHandlerFixture) http.HandlerFunc { + return (&UpdateOrganizationMemberHandler{UseCase: fixture.useCase()}).Handler() + }, []organizationMemberHandlerCase{ + { + name: "missing user", + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationMemberRequest{Role: "admin"}), + organizationID: "org-1", + memberID: "mem-1", + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "invalid request body", + userID: new("user-1"), + body: []byte("{"), + organizationID: "org-1", + memberID: "mem-1", + expectedStatus: http.StatusBadRequest, + expectedMessage: "invalid request body", + }, + { + name: "bad request when role is empty", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationMemberRequest{Role: " "}), + organizationID: "org-1", + memberID: "mem-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgMemberRepo.On("GetByID", mock.Anything, "mem-1").Return(&orgtypes.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}, nil).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "bad request", + }, + { + name: "organization not found", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationMemberRequest{Role: "admin"}), + organizationID: "org-1", + memberID: "mem-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "forbidden", + userID: new("user-2"), + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationMemberRequest{Role: "admin"}), + organizationID: "org-1", + memberID: "mem-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: "forbidden", + }, + { + name: "member not found", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationMemberRequest{Role: "admin"}), + organizationID: "org-1", + memberID: "mem-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgMemberRepo.On("GetByID", mock.Anything, "mem-1").Return((*orgtypes.OrganizationMember)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "repo error", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationMemberRequest{Role: "admin"}), + organizationID: "org-1", + memberID: "mem-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgMemberRepo.On("GetByID", mock.Anything, "mem-1").Return(&orgtypes.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}, nil).Once() + fixture.orgMemberRepo.On("Update", mock.Anything, mock.MatchedBy(func(member *orgtypes.OrganizationMember) bool { + return member != nil && member.ID == "mem-1" && member.Role == "admin" + })).Return((*orgtypes.OrganizationMember)(nil), errors.New("some error")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "some error", + }, + { + name: "success", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationMemberRequest{Role: "admin"}), + organizationID: "org-1", + memberID: "mem-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgMemberRepo.On("GetByID", mock.Anything, "mem-1").Return(&orgtypes.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}, nil).Once() + fixture.orgMemberRepo.On("Update", mock.Anything, mock.MatchedBy(func(member *orgtypes.OrganizationMember) bool { + return member != nil && member.ID == "mem-1" && member.Role == "admin" + })).Return(&orgtypes.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-2", Role: "admin"}, nil).Once() + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + member := internaltests.DecodeResponseJSON[orgtypes.OrganizationMember](t, reqCtx) + require.Equal(t, "mem-1", member.ID) + require.Equal(t, "admin", member.Role) + }, + }, + }) +} + +func TestDeleteOrganizationMemberHandler_Handler(t *testing.T) { + t.Parallel() + + runOrganizationMemberHandlerCases(t, http.MethodDelete, "/organizations/org-1/members/mem-1", func(fixture *organizationMemberHandlerFixture) http.HandlerFunc { + return (&DeleteOrganizationMemberHandler{UseCase: fixture.useCase()}).Handler() + }, []organizationMemberHandlerCase{ + { + name: "missing user", + organizationID: "org-1", + memberID: "mem-1", + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "organization not found", + userID: new("user-1"), + organizationID: "org-1", + memberID: "mem-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "forbidden", + userID: new("user-2"), + organizationID: "org-1", + memberID: "mem-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: "forbidden", + }, + { + name: "member not found", + userID: new("user-1"), + organizationID: "org-1", + memberID: "mem-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgMemberRepo.On("GetByID", mock.Anything, "mem-1").Return((*orgtypes.OrganizationMember)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "repo error", + userID: new("user-1"), + organizationID: "org-1", + memberID: "mem-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgMemberRepo.On("GetByID", mock.Anything, "mem-1").Return(&orgtypes.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}, nil).Once() + fixture.orgMemberRepo.On("Delete", mock.Anything, "mem-1").Return(errors.New("some error")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "some error", + }, + { + name: "success", + userID: new("user-1"), + organizationID: "org-1", + memberID: "mem-1", + prepare: func(fixture *organizationMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgMemberRepo.On("GetByID", mock.Anything, "mem-1").Return(&orgtypes.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}, nil).Once() + fixture.orgMemberRepo.On("Delete", mock.Anything, "mem-1").Return(nil).Once() + }, + expectedStatus: http.StatusNoContent, + }, + }) +} diff --git a/plugins/organizations/handlers/organization_team_handlers.go b/plugins/organizations/handlers/organization_team_handlers.go new file mode 100644 index 00000000..334908f2 --- /dev/null +++ b/plugins/organizations/handlers/organization_team_handlers.go @@ -0,0 +1,136 @@ +package handlers + +import ( + "net/http" + + internalerrors "github.com/Authula/authula/internal/errors" + "github.com/Authula/authula/internal/util" + "github.com/Authula/authula/models" + "github.com/Authula/authula/plugins/organizations/types" + "github.com/Authula/authula/plugins/organizations/usecases" +) + +type CreateOrganizationTeamHandler struct { + UseCase *usecases.OrganizationTeamUseCase +} + +func (h *CreateOrganizationTeamHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + + var payload types.CreateOrganizationTeamRequest + if err := util.ParseJSON(r, &payload); err != nil { + reqCtx.SetJSONResponse(http.StatusBadRequest, map[string]any{"message": "invalid request body"}) + reqCtx.Handled = true + return + } + + team, err := h.UseCase.CreateTeam(ctx, userID, organizationID, payload) + if err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusCreated, team) + } +} + +type GetAllOrganizationTeamsHandler struct { + UseCase *usecases.OrganizationTeamUseCase +} + +func (h *GetAllOrganizationTeamsHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + teams, err := h.UseCase.GetAllTeams(ctx, userID, organizationID) + if err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, teams) + } +} + +type UpdateOrganizationTeamHandler struct { + UseCase *usecases.OrganizationTeamUseCase +} + +func (h *UpdateOrganizationTeamHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + teamID := r.PathValue("team_id") + + var payload types.UpdateOrganizationTeamRequest + if err := util.ParseJSON(r, &payload); err != nil { + reqCtx.SetJSONResponse(http.StatusBadRequest, map[string]any{"message": "invalid request body"}) + reqCtx.Handled = true + return + } + + team, err := h.UseCase.UpdateTeam(ctx, userID, organizationID, teamID, payload) + if err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, team) + } +} + +type DeleteOrganizationTeamHandler struct { + UseCase *usecases.OrganizationTeamUseCase +} + +func (h *DeleteOrganizationTeamHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + teamID := r.PathValue("team_id") + if err := h.UseCase.DeleteTeam(ctx, userID, organizationID, teamID); err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusNoContent, nil) + } +} diff --git a/plugins/organizations/handlers/organization_team_handlers_test.go b/plugins/organizations/handlers/organization_team_handlers_test.go new file mode 100644 index 00000000..aea93f46 --- /dev/null +++ b/plugins/organizations/handlers/organization_team_handlers_test.go @@ -0,0 +1,597 @@ +package handlers + +import ( + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + internalerrors "github.com/Authula/authula/internal/errors" + internaltests "github.com/Authula/authula/internal/tests" + "github.com/Authula/authula/models" + orgservices "github.com/Authula/authula/plugins/organizations/services" + orgtests "github.com/Authula/authula/plugins/organizations/tests" + orgtypes "github.com/Authula/authula/plugins/organizations/types" + "github.com/Authula/authula/plugins/organizations/usecases" +) + +type organizationTeamHandlerFixture struct { + orgRepo *orgtests.MockOrganizationRepository + teamRepo *orgtests.MockOrganizationTeamRepository + memberRepo *orgtests.MockOrganizationMemberRepository + teamMemberRepo *orgtests.MockOrganizationTeamMemberRepository +} + +type organizationTeamHandlerCase struct { + name string + userID *string + body []byte + organizationID string + teamID string + prepare func(*organizationTeamHandlerFixture) + expectedStatus int + expectedMessage string + checkResponse func(*testing.T, *models.RequestContext) +} + +func newOrganizationTeamHandlerFixture() *organizationTeamHandlerFixture { + return &organizationTeamHandlerFixture{ + orgRepo: &orgtests.MockOrganizationRepository{}, + teamRepo: &orgtests.MockOrganizationTeamRepository{}, + memberRepo: &orgtests.MockOrganizationMemberRepository{}, + teamMemberRepo: &orgtests.MockOrganizationTeamMemberRepository{}, + } +} + +func (f *organizationTeamHandlerFixture) useCase() *usecases.OrganizationTeamUseCase { + service := orgservices.NewOrganizationTeamService(f.orgRepo, f.teamRepo, f.memberRepo, f.teamMemberRepo, nil, nil) + return usecases.NewOrganizationTeamUseCase(service) +} + +func (f *organizationTeamHandlerFixture) newRequest(t *testing.T, method, path string, body []byte, userID *string, organizationID, teamID string) (*http.Request, *httptest.ResponseRecorder, *models.RequestContext) { + t.Helper() + + req, w, reqCtx := internaltests.NewHandlerRequest(t, method, path, body, userID) + if organizationID != "" { + req.SetPathValue("organization_id", organizationID) + } + if teamID != "" { + req.SetPathValue("team_id", teamID) + } + return req, w, reqCtx +} + +func TestCreateOrganizationTeamHandler(t *testing.T) { + t.Parallel() + + tests := []organizationTeamHandlerCase{ + { + name: "missing_user", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationTeamRequest{Name: "Platform"}), + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "invalid_json", + userID: new("user-1"), + organizationID: "org-1", + body: []byte("{invalid"), + expectedStatus: http.StatusBadRequest, + expectedMessage: "invalid request body", + }, + { + name: "organization_not_found", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationTeamRequest{Name: "Platform"}), + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "forbidden", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationTeamRequest{Name: "Platform"}), + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "owner-2"}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: "forbidden", + }, + { + name: "bad_request", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationTeamRequest{Name: " "}), + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: internalerrors.ErrBadRequest.Error(), + }, + { + name: "slug_conflict", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationTeamRequest{ + Name: "Platform", + }), + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + f.teamRepo.On("GetByOrganizationIDAndSlug", mock.Anything, "org-1", "platform").Return(&orgtypes.OrganizationTeam{ID: "team-2"}, nil).Once() + }, + expectedStatus: http.StatusConflict, + expectedMessage: internalerrors.ErrConflict.Error(), + }, + { + name: "repository_error", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationTeamRequest{ + Name: "Platform", + }), + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + f.teamRepo.On("GetByOrganizationIDAndSlug", mock.Anything, "org-1", "platform").Return((*orgtypes.OrganizationTeam)(nil), errors.New("lookup failed")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "lookup failed", + }, + { + name: "success", + userID: new("user-1"), + organizationID: "org-1", + body: internaltests.MarshalToJSON(t, orgtypes.CreateOrganizationTeamRequest{ + Name: "Platform Team", + Slug: new("platform-team"), + Description: new("Core platform team"), + Metadata: json.RawMessage(`{"tier":"core"}`), + }), + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + f.teamRepo.On("GetByOrganizationIDAndSlug", mock.Anything, "org-1", "platform-team").Return((*orgtypes.OrganizationTeam)(nil), nil).Once() + f.teamRepo.On("Create", mock.Anything, mock.MatchedBy(func(team *orgtypes.OrganizationTeam) bool { + return team != nil && team.OrganizationID == "org-1" && team.Name == "Platform Team" && team.Slug == "platform-team" && team.Description != nil && *team.Description == "Core platform team" && string(team.Metadata) == `{"tier":"core"}` + })).Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform Team", Slug: "platform-team", Description: new("Core platform team"), Metadata: json.RawMessage(`{"tier":"core"}`)}, nil).Once() + }, + expectedStatus: http.StatusCreated, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + team := internaltests.DecodeResponseJSON[orgtypes.OrganizationTeam](t, reqCtx) + assert.Equal(t, "team-1", team.ID) + assert.Equal(t, "org-1", team.OrganizationID) + assert.Equal(t, "Platform Team", team.Name) + assert.Equal(t, "platform-team", team.Slug) + require.NotNil(t, team.Description) + assert.Equal(t, "Core platform team", *team.Description) + assert.JSONEq(t, `{"tier":"core"}`, string(team.Metadata)) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + fixture := newOrganizationTeamHandlerFixture() + if tt.prepare != nil { + tt.prepare(fixture) + } + + handler := &CreateOrganizationTeamHandler{UseCase: fixture.useCase()} + req, w, reqCtx := fixture.newRequest(t, http.MethodPost, "/organizations/org-1/teams", tt.body, tt.userID, tt.organizationID, "") + handler.Handler().ServeHTTP(w, req) + + assert.Equal(t, tt.expectedStatus, reqCtx.ResponseStatus) + if tt.expectedMessage != "" { + internaltests.AssertErrorMessage(t, reqCtx, tt.expectedStatus, tt.expectedMessage) + } + if tt.checkResponse != nil { + tt.checkResponse(t, reqCtx) + } + fixture.orgRepo.AssertExpectations(t) + fixture.teamRepo.AssertExpectations(t) + fixture.memberRepo.AssertExpectations(t) + fixture.teamMemberRepo.AssertExpectations(t) + }) + } +} + +func TestGetAllOrganizationTeamsHandler(t *testing.T) { + t.Parallel() + + tests := []organizationTeamHandlerCase{ + { + name: "missing_user", + organizationID: "org-1", + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "organization_not_found", + userID: new("user-1"), + organizationID: "org-1", + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "forbidden", + userID: new("user-1"), + organizationID: "org-1", + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "owner-2"}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: "forbidden", + }, + { + name: "repository_error", + userID: new("user-1"), + organizationID: "org-1", + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), errors.New("lookup failed")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "lookup failed", + }, + { + name: "success", + userID: new("user-1"), + organizationID: "org-1", + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + f.teamRepo.On("GetAllByOrganizationID", mock.Anything, "org-1").Return([]orgtypes.OrganizationTeam{{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}}, nil).Once() + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + teams := internaltests.DecodeResponseJSON[[]orgtypes.OrganizationTeam](t, reqCtx) + require.Len(t, teams, 1) + assert.Equal(t, "team-1", teams[0].ID) + assert.Equal(t, "org-1", teams[0].OrganizationID) + assert.Equal(t, "Platform", teams[0].Name) + assert.Equal(t, "platform", teams[0].Slug) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + fixture := newOrganizationTeamHandlerFixture() + if tt.prepare != nil { + tt.prepare(fixture) + } + + handler := &GetAllOrganizationTeamsHandler{UseCase: fixture.useCase()} + req, w, reqCtx := fixture.newRequest(t, http.MethodGet, "/organizations/org-1/teams", nil, tt.userID, tt.organizationID, "") + handler.Handler().ServeHTTP(w, req) + + assert.Equal(t, tt.expectedStatus, reqCtx.ResponseStatus) + if tt.expectedMessage != "" { + internaltests.AssertErrorMessage(t, reqCtx, tt.expectedStatus, tt.expectedMessage) + } + if tt.checkResponse != nil { + tt.checkResponse(t, reqCtx) + } + fixture.orgRepo.AssertExpectations(t) + fixture.teamRepo.AssertExpectations(t) + fixture.memberRepo.AssertExpectations(t) + fixture.teamMemberRepo.AssertExpectations(t) + }) + } +} + +func TestUpdateOrganizationTeamHandler(t *testing.T) { + t.Parallel() + + tests := []organizationTeamHandlerCase{ + { + name: "missing_user", + organizationID: "org-1", + teamID: "team-1", + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationTeamRequest{Name: "Platform Revamp"}), + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "invalid_json", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + body: []byte("{invalid"), + expectedStatus: http.StatusBadRequest, + expectedMessage: "invalid request body", + }, + { + name: "organization_not_found", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationTeamRequest{Name: "Platform Revamp"}), + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "forbidden", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationTeamRequest{Name: "Platform Revamp"}), + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "owner-2"}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: "forbidden", + }, + { + name: "team_not_found", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationTeamRequest{Name: "Platform Revamp"}), + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + f.teamRepo.On("GetByID", mock.Anything, "team-1").Return((*orgtypes.OrganizationTeam)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "team_wrong_organization", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationTeamRequest{Name: "Platform Revamp"}), + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + f.teamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-2", Name: "Platform", Slug: "platform"}, nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "bad_request", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationTeamRequest{Name: " "}), + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + f.teamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: internalerrors.ErrBadRequest.Error(), + }, + { + name: "slug_conflict", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationTeamRequest{Name: "Platform Revamp", Slug: new("platform")}), + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + f.teamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + f.teamRepo.On("GetByOrganizationIDAndSlug", mock.Anything, "org-1", "platform").Return(&orgtypes.OrganizationTeam{ID: "team-2"}, nil).Once() + }, + expectedStatus: http.StatusConflict, + expectedMessage: internalerrors.ErrConflict.Error(), + }, + { + name: "repository_error", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationTeamRequest{Name: "Platform Revamp"}), + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + f.teamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + f.teamRepo.On("GetByOrganizationIDAndSlug", mock.Anything, "org-1", "platform").Return((*orgtypes.OrganizationTeam)(nil), errors.New("lookup failed")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "lookup failed", + }, + { + name: "update_error", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationTeamRequest{Name: "Platform Revamp"}), + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + f.teamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + f.teamRepo.On("GetByOrganizationIDAndSlug", mock.Anything, "org-1", "platform").Return((*orgtypes.OrganizationTeam)(nil), nil).Once() + f.teamRepo.On("Update", mock.Anything, mock.MatchedBy(func(team *orgtypes.OrganizationTeam) bool { + return team != nil && team.ID == "team-1" && team.OrganizationID == "org-1" && team.Name == "Platform Revamp" && team.Slug == "platform" + })).Return((*orgtypes.OrganizationTeam)(nil), errors.New("update failed")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "update failed", + }, + { + name: "success", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + body: internaltests.MarshalToJSON(t, orgtypes.UpdateOrganizationTeamRequest{ + Name: "Platform Revamp", + Description: new("Updated platform team"), + Metadata: json.RawMessage(`{"tier":"core"}`), + }), + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + f.teamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + f.teamRepo.On("GetByOrganizationIDAndSlug", mock.Anything, "org-1", "platform").Return((*orgtypes.OrganizationTeam)(nil), nil).Once() + f.teamRepo.On("Update", mock.Anything, mock.MatchedBy(func(team *orgtypes.OrganizationTeam) bool { + return team != nil && team.ID == "team-1" && team.OrganizationID == "org-1" && team.Name == "Platform Revamp" && team.Slug == "platform" && team.Description != nil && *team.Description == "Updated platform team" && string(team.Metadata) == `{"tier":"core"}` + })).Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform Revamp", Slug: "platform", Description: new("Updated platform team"), Metadata: json.RawMessage(`{"tier":"core"}`)}, nil).Once() + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + team := internaltests.DecodeResponseJSON[orgtypes.OrganizationTeam](t, reqCtx) + assert.Equal(t, "team-1", team.ID) + assert.Equal(t, "org-1", team.OrganizationID) + assert.Equal(t, "Platform Revamp", team.Name) + assert.Equal(t, "platform", team.Slug) + require.NotNil(t, team.Description) + assert.Equal(t, "Updated platform team", *team.Description) + assert.JSONEq(t, `{"tier":"core"}`, string(team.Metadata)) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + fixture := newOrganizationTeamHandlerFixture() + if tt.prepare != nil { + tt.prepare(fixture) + } + + handler := &UpdateOrganizationTeamHandler{UseCase: fixture.useCase()} + req, w, reqCtx := fixture.newRequest(t, http.MethodPatch, "/organizations/org-1/teams/team-1", tt.body, tt.userID, tt.organizationID, tt.teamID) + handler.Handler().ServeHTTP(w, req) + + assert.Equal(t, tt.expectedStatus, reqCtx.ResponseStatus) + if tt.expectedMessage != "" { + internaltests.AssertErrorMessage(t, reqCtx, tt.expectedStatus, tt.expectedMessage) + } + if tt.checkResponse != nil { + tt.checkResponse(t, reqCtx) + } + fixture.orgRepo.AssertExpectations(t) + fixture.teamRepo.AssertExpectations(t) + fixture.memberRepo.AssertExpectations(t) + fixture.teamMemberRepo.AssertExpectations(t) + }) + } +} + +func TestDeleteOrganizationTeamHandler(t *testing.T) { + t.Parallel() + + tests := []organizationTeamHandlerCase{ + { + name: "missing_user", + organizationID: "org-1", + teamID: "team-1", + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "organization_not_found", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "forbidden", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "owner-2"}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: "forbidden", + }, + { + name: "team_not_found", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + f.teamRepo.On("GetByID", mock.Anything, "team-1").Return((*orgtypes.OrganizationTeam)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "team_wrong_organization", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + f.teamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-2", Name: "Platform", Slug: "platform"}, nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "repository_error", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + f.teamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + f.teamRepo.On("Delete", mock.Anything, "team-1").Return(errors.New("delete failed")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "delete failed", + }, + { + name: "success", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + prepare: func(f *organizationTeamHandlerFixture) { + f.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + f.teamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + f.teamRepo.On("Delete", mock.Anything, "team-1").Return(nil).Once() + }, + expectedStatus: http.StatusNoContent, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + assert.Equal(t, "null", string(reqCtx.ResponseBody)) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + fixture := newOrganizationTeamHandlerFixture() + if tt.prepare != nil { + tt.prepare(fixture) + } + + handler := &DeleteOrganizationTeamHandler{UseCase: fixture.useCase()} + req, w, reqCtx := fixture.newRequest(t, http.MethodDelete, "/organizations/org-1/teams/team-1", nil, tt.userID, tt.organizationID, tt.teamID) + handler.Handler().ServeHTTP(w, req) + + assert.Equal(t, tt.expectedStatus, reqCtx.ResponseStatus) + if tt.expectedMessage != "" { + internaltests.AssertErrorMessage(t, reqCtx, tt.expectedStatus, tt.expectedMessage) + } + if tt.checkResponse != nil { + tt.checkResponse(t, reqCtx) + } + fixture.orgRepo.AssertExpectations(t) + fixture.teamRepo.AssertExpectations(t) + fixture.memberRepo.AssertExpectations(t) + fixture.teamMemberRepo.AssertExpectations(t) + }) + } +} diff --git a/plugins/organizations/handlers/organization_team_member_handlers.go b/plugins/organizations/handlers/organization_team_member_handlers.go new file mode 100644 index 00000000..0e4be010 --- /dev/null +++ b/plugins/organizations/handlers/organization_team_member_handlers.go @@ -0,0 +1,134 @@ +package handlers + +import ( + "net/http" + + internalerrors "github.com/Authula/authula/internal/errors" + "github.com/Authula/authula/internal/util" + "github.com/Authula/authula/models" + "github.com/Authula/authula/plugins/organizations/types" + "github.com/Authula/authula/plugins/organizations/usecases" +) + +type AddOrganizationTeamMemberHandler struct { + UseCase *usecases.OrganizationTeamUseCase +} + +func (h *AddOrganizationTeamMemberHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + teamID := r.PathValue("team_id") + + var payload types.AddOrganizationTeamMemberRequest + if err := util.ParseJSON(r, &payload); err != nil { + reqCtx.SetJSONResponse(http.StatusBadRequest, map[string]any{"message": "invalid request body"}) + reqCtx.Handled = true + return + } + + teamMember, err := h.UseCase.AddTeamMember(ctx, userID, organizationID, teamID, payload) + if err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusCreated, teamMember) + } +} + +type GetAllOrganizationTeamMembersHandler struct { + UseCase *usecases.OrganizationTeamUseCase +} + +func (h *GetAllOrganizationTeamMembersHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + teamID := r.PathValue("team_id") + page := util.GetQueryInt(r, "page", 1) + limit := util.GetQueryInt(r, "limit", 10) + teamMembers, err := h.UseCase.GetAllTeamMembers(ctx, userID, organizationID, teamID, page, limit) + if err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, teamMembers) + } +} + +type GetOrganizationTeamMemberHandler struct { + UseCase *usecases.OrganizationTeamUseCase +} + +func (h *GetOrganizationTeamMemberHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + teamID := r.PathValue("team_id") + memberID := r.PathValue("member_id") + teamMember, err := h.UseCase.GetTeamMember(ctx, userID, organizationID, teamID, memberID) + if err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusOK, teamMember) + } +} + +type DeleteOrganizationTeamMemberHandler struct { + UseCase *usecases.OrganizationTeamUseCase +} + +func (h *DeleteOrganizationTeamMemberHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + reqCtx, _ := models.GetRequestContext(ctx) + + userID, ok := models.GetUserIDFromContext(ctx) + if !ok { + reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{"message": "Unauthorized"}) + reqCtx.Handled = true + return + } + + organizationID := r.PathValue("organization_id") + teamID := r.PathValue("team_id") + memberID := r.PathValue("member_id") + if err := h.UseCase.RemoveTeamMember(ctx, userID, organizationID, teamID, memberID); err != nil { + internalerrors.HandleError(err, reqCtx) + return + } + + reqCtx.SetJSONResponse(http.StatusNoContent, nil) + } +} diff --git a/plugins/organizations/handlers/organization_team_member_handlers_test.go b/plugins/organizations/handlers/organization_team_member_handlers_test.go new file mode 100644 index 00000000..75268a0e --- /dev/null +++ b/plugins/organizations/handlers/organization_team_member_handlers_test.go @@ -0,0 +1,648 @@ +package handlers + +import ( + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + internalerrors "github.com/Authula/authula/internal/errors" + internaltests "github.com/Authula/authula/internal/tests" + "github.com/Authula/authula/models" + orgservices "github.com/Authula/authula/plugins/organizations/services" + orgtests "github.com/Authula/authula/plugins/organizations/tests" + orgtypes "github.com/Authula/authula/plugins/organizations/types" + "github.com/Authula/authula/plugins/organizations/usecases" +) + +type organizationTeamMemberHandlerFixture struct { + orgRepo *orgtests.MockOrganizationRepository + orgMemberRepo *orgtests.MockOrganizationMemberRepository + orgTeamRepo *orgtests.MockOrganizationTeamRepository + orgTeamMemberRepo *orgtests.MockOrganizationTeamMemberRepository +} + +func newOrganizationTeamMemberHandlerFixture() *organizationTeamMemberHandlerFixture { + return &organizationTeamMemberHandlerFixture{ + orgRepo: &orgtests.MockOrganizationRepository{}, + orgTeamRepo: &orgtests.MockOrganizationTeamRepository{}, + orgMemberRepo: &orgtests.MockOrganizationMemberRepository{}, + orgTeamMemberRepo: &orgtests.MockOrganizationTeamMemberRepository{}, + } +} + +func (f *organizationTeamMemberHandlerFixture) useCase() *usecases.OrganizationTeamUseCase { + service := orgservices.NewOrganizationTeamService(f.orgRepo, f.orgTeamRepo, f.orgMemberRepo, f.orgTeamMemberRepo, nil, nil) + return usecases.NewOrganizationTeamUseCase(service) +} + +func (f *organizationTeamMemberHandlerFixture) newRequest(t *testing.T, method string, path string, body []byte, userID *string, organizationID string, teamID string, memberID string) (*http.Request, *httptest.ResponseRecorder, *models.RequestContext) { + t.Helper() + + req, w, reqCtx := internaltests.NewHandlerRequest(t, method, path, body, userID) + if organizationID != "" { + req.SetPathValue("organization_id", organizationID) + } + if teamID != "" { + req.SetPathValue("team_id", teamID) + } + if memberID != "" { + req.SetPathValue("member_id", memberID) + } + return req, w, reqCtx +} + +type organizationTeamMemberHandlerCase struct { + name string + userID *string + body []byte + organizationID string + teamID string + memberID string + prepare func(*organizationTeamMemberHandlerFixture) + expectedStatus int + expectedMessage string + checkResponse func(*testing.T, *models.RequestContext) +} + +func runOrganizationTeamMemberHandlerCases(t *testing.T, method, path string, buildHandler func(*organizationTeamMemberHandlerFixture) http.HandlerFunc, cases []organizationTeamMemberHandlerCase) { + t.Helper() + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + fixture := newOrganizationTeamMemberHandlerFixture() + if tt.prepare != nil { + tt.prepare(fixture) + } + + handler := buildHandler(fixture) + req, w, reqCtx := fixture.newRequest(t, method, path, tt.body, tt.userID, tt.organizationID, tt.teamID, tt.memberID) + handler.ServeHTTP(w, req) + + assert.Equal(t, tt.expectedStatus, reqCtx.ResponseStatus) + if tt.expectedMessage != "" { + internaltests.AssertErrorMessage(t, reqCtx, tt.expectedStatus, tt.expectedMessage) + } + if tt.checkResponse != nil { + tt.checkResponse(t, reqCtx) + } + fixture.orgRepo.AssertExpectations(t) + fixture.orgMemberRepo.AssertExpectations(t) + fixture.orgTeamRepo.AssertExpectations(t) + fixture.orgTeamMemberRepo.AssertExpectations(t) + }) + } +} + +func TestAddOrganizationTeamMemberHandler(t *testing.T) { + t.Parallel() + + runOrganizationTeamMemberHandlerCases(t, http.MethodPost, "/organizations/org-1/teams/team-1/members", func(fixture *organizationTeamMemberHandlerFixture) http.HandlerFunc { + return (&AddOrganizationTeamMemberHandler{UseCase: fixture.useCase()}).Handler() + }, []organizationTeamMemberHandlerCase{ + { + name: "missing user", + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationTeamMemberRequest{MemberID: "member-1"}), + organizationID: "org-1", + teamID: "team-1", + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "invalid request body", + userID: new("user-1"), + body: []byte("{"), + organizationID: "org-1", + teamID: "team-1", + expectedStatus: http.StatusBadRequest, + expectedMessage: "invalid request body", + }, + { + name: "organization not found", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationTeamMemberRequest{MemberID: "member-1"}), + organizationID: "org-1", + teamID: "team-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "forbidden", + userID: new("user-2"), + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationTeamMemberRequest{MemberID: "member-1"}), + organizationID: "org-1", + teamID: "team-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: "forbidden", + }, + { + name: "team not found", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationTeamMemberRequest{MemberID: "member-1"}), + organizationID: "org-1", + teamID: "team-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return((*orgtypes.OrganizationTeam)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "team from another organization", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationTeamMemberRequest{MemberID: "member-1"}), + organizationID: "org-1", + teamID: "team-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-2", Name: "Platform", Slug: "platform"}, nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "bad request when member id is empty", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationTeamMemberRequest{MemberID: ""}), + organizationID: "org-1", + teamID: "team-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + }, + expectedStatus: http.StatusUnprocessableEntity, + expectedMessage: internalerrors.ErrUnprocessableEntity.Error(), + }, + { + name: "organization member lookup error", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationTeamMemberRequest{MemberID: "member-1"}), + organizationID: "org-1", + teamID: "team-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + fixture.orgMemberRepo.On("GetByID", mock.Anything, "member-1").Return((*orgtypes.OrganizationMember)(nil), errors.New("lookup failed")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "lookup failed", + }, + { + name: "organization member not found", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationTeamMemberRequest{MemberID: "member-1"}), + organizationID: "org-1", + teamID: "team-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + fixture.orgMemberRepo.On("GetByID", mock.Anything, "member-1").Return((*orgtypes.OrganizationMember)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "team member conflict", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationTeamMemberRequest{MemberID: "member-1"}), + organizationID: "org-1", + teamID: "team-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + fixture.orgMemberRepo.On("GetByID", mock.Anything, "member-1").Return(&orgtypes.OrganizationMember{ID: "member-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}, nil).Once() + fixture.orgTeamMemberRepo.On("GetByTeamIDAndMemberID", mock.Anything, "team-1", "member-1").Return(&orgtypes.OrganizationTeamMember{ID: "team-member-1", TeamID: "team-1", MemberID: "member-1"}, nil).Once() + }, + expectedStatus: http.StatusConflict, + expectedMessage: internalerrors.ErrConflict.Error(), + }, + { + name: "member from another organization", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationTeamMemberRequest{MemberID: "member-1"}), + organizationID: "org-1", + teamID: "team-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + fixture.orgMemberRepo.On("GetByID", mock.Anything, "member-1").Return(&orgtypes.OrganizationMember{ID: "member-1", OrganizationID: "org-2", UserID: "user-2", Role: "member"}, nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: internalerrors.ErrNotFound.Error(), + }, + { + name: "team member create error", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationTeamMemberRequest{MemberID: "member-1"}), + organizationID: "org-1", + teamID: "team-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + fixture.orgMemberRepo.On("GetByID", mock.Anything, "member-1").Return(&orgtypes.OrganizationMember{ID: "member-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}, nil).Once() + fixture.orgTeamMemberRepo.On("GetByTeamIDAndMemberID", mock.Anything, "team-1", "member-1").Return((*orgtypes.OrganizationTeamMember)(nil), nil).Once() + fixture.orgTeamMemberRepo.On("Create", mock.Anything, mock.MatchedBy(func(teamMember *orgtypes.OrganizationTeamMember) bool { + return teamMember != nil && teamMember.TeamID == "team-1" && teamMember.MemberID == "member-1" + })).Return((*orgtypes.OrganizationTeamMember)(nil), errors.New("create failed")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "create failed", + }, + { + name: "success", + userID: new("user-1"), + body: internaltests.MarshalToJSON(t, orgtypes.AddOrganizationTeamMemberRequest{MemberID: "member-1"}), + organizationID: "org-1", + teamID: "team-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + fixture.orgMemberRepo.On("GetByID", mock.Anything, "member-1").Return(&orgtypes.OrganizationMember{ID: "member-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}, nil).Once() + fixture.orgTeamMemberRepo.On("GetByTeamIDAndMemberID", mock.Anything, "team-1", "member-1").Return((*orgtypes.OrganizationTeamMember)(nil), nil).Once() + fixture.orgTeamMemberRepo.On("Create", mock.Anything, mock.MatchedBy(func(teamMember *orgtypes.OrganizationTeamMember) bool { + return teamMember != nil && teamMember.TeamID == "team-1" && teamMember.MemberID == "member-1" + })).Return(&orgtypes.OrganizationTeamMember{ID: "team-member-1", TeamID: "team-1", MemberID: "member-1"}, nil).Once() + }, + expectedStatus: http.StatusCreated, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + teamMember := internaltests.DecodeResponseJSON[orgtypes.OrganizationTeamMember](t, reqCtx) + require.Equal(t, "team-member-1", teamMember.ID) + require.Equal(t, "team-1", teamMember.TeamID) + require.Equal(t, "member-1", teamMember.MemberID) + }, + }, + }) +} + +func TestGetAllOrganizationTeamMembersHandler(t *testing.T) { + t.Parallel() + + runOrganizationTeamMemberHandlerCases(t, http.MethodGet, "/organizations/org-1/teams/team-1/members", func(fixture *organizationTeamMemberHandlerFixture) http.HandlerFunc { + return (&GetAllOrganizationTeamMembersHandler{UseCase: fixture.useCase()}).Handler() + }, []organizationTeamMemberHandlerCase{ + { + name: "missing user", + organizationID: "org-1", + teamID: "team-1", + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "organization not found", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "forbidden", + userID: new("user-2"), + organizationID: "org-1", + teamID: "team-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: "forbidden", + }, + { + name: "team not found", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return((*orgtypes.OrganizationTeam)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "team from another organization", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-2", Name: "Platform", Slug: "platform"}, nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "team lookup error", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return((*orgtypes.OrganizationTeam)(nil), errors.New("lookup failed")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "lookup failed", + }, + { + name: "repo error", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + fixture.orgTeamMemberRepo.On("GetAllByTeamID", mock.Anything, "team-1", 1, 10).Return(([]orgtypes.OrganizationTeamMember)(nil), errors.New("some error")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "some error", + }, + { + name: "success", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + fixture.orgTeamMemberRepo.On("GetAllByTeamID", mock.Anything, "team-1", 1, 10).Return([]orgtypes.OrganizationTeamMember{{ID: "team-member-1", TeamID: "team-1", MemberID: "member-1"}}, nil).Once() + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + teamMembers := internaltests.DecodeResponseJSON[[]orgtypes.OrganizationTeamMember](t, reqCtx) + require.Len(t, teamMembers, 1) + require.Equal(t, "team-member-1", teamMembers[0].ID) + require.Equal(t, "team-1", teamMembers[0].TeamID) + require.Equal(t, "member-1", teamMembers[0].MemberID) + }, + }, + }) +} + +func TestGetOrganizationTeamMemberHandler(t *testing.T) { + t.Parallel() + + runOrganizationTeamMemberHandlerCases(t, http.MethodGet, "/organizations/org-1/teams/team-1/members/member-1", func(fixture *organizationTeamMemberHandlerFixture) http.HandlerFunc { + return (&GetOrganizationTeamMemberHandler{UseCase: fixture.useCase()}).Handler() + }, []organizationTeamMemberHandlerCase{ + { + name: "missing user", + organizationID: "org-1", + teamID: "team-1", + memberID: "member-1", + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "organization not found", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + memberID: "member-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "forbidden", + userID: new("user-2"), + organizationID: "org-1", + teamID: "team-1", + memberID: "member-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: "forbidden", + }, + { + name: "team not found", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + memberID: "member-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return((*orgtypes.OrganizationTeam)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "team from another organization", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + memberID: "member-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-2", Name: "Platform", Slug: "platform"}, nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "team member lookup error", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + memberID: "member-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + fixture.orgTeamMemberRepo.On("GetByTeamIDAndMemberID", mock.Anything, "team-1", "member-1").Return((*orgtypes.OrganizationTeamMember)(nil), errors.New("some error")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "some error", + }, + { + name: "member not found", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + memberID: "member-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + fixture.orgTeamMemberRepo.On("GetByTeamIDAndMemberID", mock.Anything, "team-1", "member-1").Return((*orgtypes.OrganizationTeamMember)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "success", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + memberID: "member-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + fixture.orgTeamMemberRepo.On("GetByTeamIDAndMemberID", mock.Anything, "team-1", "member-1").Return(&orgtypes.OrganizationTeamMember{ID: "team-member-1", TeamID: "team-1", MemberID: "member-1"}, nil).Once() + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + teamMember := internaltests.DecodeResponseJSON[orgtypes.OrganizationTeamMember](t, reqCtx) + require.Equal(t, "team-member-1", teamMember.ID) + require.Equal(t, "team-1", teamMember.TeamID) + require.Equal(t, "member-1", teamMember.MemberID) + }, + }, + }) +} + +func TestDeleteOrganizationTeamMemberHandler(t *testing.T) { + t.Parallel() + + runOrganizationTeamMemberHandlerCases(t, http.MethodDelete, "/organizations/org-1/teams/team-1/members/member-1", func(fixture *organizationTeamMemberHandlerFixture) http.HandlerFunc { + return (&DeleteOrganizationTeamMemberHandler{UseCase: fixture.useCase()}).Handler() + }, []organizationTeamMemberHandlerCase{ + { + name: "missing user", + organizationID: "org-1", + teamID: "team-1", + memberID: "member-1", + expectedStatus: http.StatusUnauthorized, + expectedMessage: "Unauthorized", + }, + { + name: "organization not found", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + memberID: "member-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return((*orgtypes.Organization)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "forbidden", + userID: new("user-2"), + organizationID: "org-1", + teamID: "team-1", + memberID: "member-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + }, + expectedStatus: http.StatusForbidden, + expectedMessage: "forbidden", + }, + { + name: "team not found", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + memberID: "member-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return((*orgtypes.OrganizationTeam)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "team from another organization", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + memberID: "member-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-2", Name: "Platform", Slug: "platform"}, nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "team lookup error", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + memberID: "member-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return((*orgtypes.OrganizationTeam)(nil), errors.New("lookup failed")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "lookup failed", + }, + { + name: "team member lookup error", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + memberID: "member-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + fixture.orgTeamMemberRepo.On("GetByTeamIDAndMemberID", mock.Anything, "team-1", "member-1").Return((*orgtypes.OrganizationTeamMember)(nil), errors.New("some error")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "some error", + }, + { + name: "member not found", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + memberID: "member-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + fixture.orgTeamMemberRepo.On("GetByTeamIDAndMemberID", mock.Anything, "team-1", "member-1").Return((*orgtypes.OrganizationTeamMember)(nil), nil).Once() + }, + expectedStatus: http.StatusNotFound, + expectedMessage: "not found", + }, + { + name: "delete error", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + memberID: "member-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + fixture.orgTeamMemberRepo.On("GetByTeamIDAndMemberID", mock.Anything, "team-1", "member-1").Return(&orgtypes.OrganizationTeamMember{ID: "team-member-1", TeamID: "team-1", MemberID: "member-1"}, nil).Once() + fixture.orgTeamMemberRepo.On("DeleteByTeamIDAndMemberID", mock.Anything, "team-1", "member-1").Return(errors.New("delete failed")).Once() + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "delete failed", + }, + { + name: "success", + userID: new("user-1"), + organizationID: "org-1", + teamID: "team-1", + memberID: "member-1", + prepare: func(fixture *organizationTeamMemberHandlerFixture) { + fixture.orgRepo.On("GetByID", mock.Anything, "org-1").Return(&orgtypes.Organization{ID: "org-1", OwnerID: "user-1"}, nil).Once() + fixture.orgTeamRepo.On("GetByID", mock.Anything, "team-1").Return(&orgtypes.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, nil).Once() + fixture.orgTeamMemberRepo.On("GetByTeamIDAndMemberID", mock.Anything, "team-1", "member-1").Return(&orgtypes.OrganizationTeamMember{ID: "team-member-1", TeamID: "team-1", MemberID: "member-1"}, nil).Once() + fixture.orgTeamMemberRepo.On("DeleteByTeamIDAndMemberID", mock.Anything, "team-1", "member-1").Return(nil).Once() + }, + expectedStatus: http.StatusNoContent, + checkResponse: func(t *testing.T, reqCtx *models.RequestContext) { + assert.Equal(t, "null", string(reqCtx.ResponseBody)) + }, + }, + }) +} diff --git a/plugins/organizations/hooks.go b/plugins/organizations/hooks.go new file mode 100644 index 00000000..50e9f660 --- /dev/null +++ b/plugins/organizations/hooks.go @@ -0,0 +1,203 @@ +package organizations + +import "github.com/Authula/authula/plugins/organizations/types" + +type OrganizationsHookExecutor struct { + config *types.OrganizationsDatabaseHooksConfig +} + +func NewOrganizationsHookExecutor(config *types.OrganizationsDatabaseHooksConfig) *OrganizationsHookExecutor { + return &OrganizationsHookExecutor{config: config} +} + +// Organization Hooks + +func (e *OrganizationsHookExecutor) BeforeCreateOrganization(organization *types.Organization) error { + if e == nil || e.config == nil || e.config.Organizations == nil || e.config.Organizations.BeforeCreate == nil { + return nil + } + return e.config.Organizations.BeforeCreate(organization) +} + +func (e *OrganizationsHookExecutor) AfterCreateOrganization(organization types.Organization) error { + if e == nil || e.config == nil || e.config.Organizations == nil || e.config.Organizations.AfterCreate == nil { + return nil + } + return e.config.Organizations.AfterCreate(organization) +} + +func (e *OrganizationsHookExecutor) BeforeUpdateOrganization(organization *types.Organization) error { + if e == nil || e.config == nil || e.config.Organizations == nil || e.config.Organizations.BeforeUpdate == nil { + return nil + } + return e.config.Organizations.BeforeUpdate(organization) +} + +func (e *OrganizationsHookExecutor) AfterUpdateOrganization(organization types.Organization) error { + if e == nil || e.config == nil || e.config.Organizations == nil || e.config.Organizations.AfterUpdate == nil { + return nil + } + return e.config.Organizations.AfterUpdate(organization) +} + +func (e *OrganizationsHookExecutor) BeforeDeleteOrganization(organization *types.Organization) error { + if e == nil || e.config == nil || e.config.Organizations == nil || e.config.Organizations.BeforeDelete == nil { + return nil + } + return e.config.Organizations.BeforeDelete(organization) +} + +func (e *OrganizationsHookExecutor) AfterDeleteOrganization(organization types.Organization) error { + if e == nil || e.config == nil || e.config.Organizations == nil || e.config.Organizations.AfterDelete == nil { + return nil + } + return e.config.Organizations.AfterDelete(organization) +} + +// Organization Member Hooks + +func (e *OrganizationsHookExecutor) BeforeCreateOrganizationMember(member *types.OrganizationMember) error { + if e == nil || e.config == nil || e.config.Members == nil || e.config.Members.BeforeCreate == nil { + return nil + } + return e.config.Members.BeforeCreate(member) +} + +func (e *OrganizationsHookExecutor) AfterCreateOrganizationMember(member types.OrganizationMember) error { + if e == nil || e.config == nil || e.config.Members == nil || e.config.Members.AfterCreate == nil { + return nil + } + return e.config.Members.AfterCreate(member) +} + +func (e *OrganizationsHookExecutor) BeforeUpdateOrganizationMember(member *types.OrganizationMember) error { + if e == nil || e.config == nil || e.config.Members == nil || e.config.Members.BeforeUpdate == nil { + return nil + } + return e.config.Members.BeforeUpdate(member) +} + +func (e *OrganizationsHookExecutor) AfterUpdateOrganizationMember(member types.OrganizationMember) error { + if e == nil || e.config == nil || e.config.Members == nil || e.config.Members.AfterUpdate == nil { + return nil + } + return e.config.Members.AfterUpdate(member) +} + +func (e *OrganizationsHookExecutor) BeforeDeleteOrganizationMember(member *types.OrganizationMember) error { + if e == nil || e.config == nil || e.config.Members == nil || e.config.Members.BeforeDelete == nil { + return nil + } + return e.config.Members.BeforeDelete(member) +} + +func (e *OrganizationsHookExecutor) AfterDeleteOrganizationMember(member types.OrganizationMember) error { + if e == nil || e.config == nil || e.config.Members == nil || e.config.Members.AfterDelete == nil { + return nil + } + return e.config.Members.AfterDelete(member) +} + +// Organization Invitation Hooks + +func (e *OrganizationsHookExecutor) BeforeCreateOrganizationInvitation(invitation *types.OrganizationInvitation) error { + if e == nil || e.config == nil || e.config.Invitations == nil || e.config.Invitations.BeforeCreate == nil { + return nil + } + return e.config.Invitations.BeforeCreate(invitation) +} + +func (e *OrganizationsHookExecutor) AfterCreateOrganizationInvitation(invitation types.OrganizationInvitation) error { + if e == nil || e.config == nil || e.config.Invitations == nil || e.config.Invitations.AfterCreate == nil { + return nil + } + return e.config.Invitations.AfterCreate(invitation) +} + +func (e *OrganizationsHookExecutor) BeforeUpdateOrganizationInvitation(invitation *types.OrganizationInvitation) error { + if e == nil || e.config == nil || e.config.Invitations == nil || e.config.Invitations.BeforeUpdate == nil { + return nil + } + return e.config.Invitations.BeforeUpdate(invitation) +} + +func (e *OrganizationsHookExecutor) AfterUpdateOrganizationInvitation(invitation types.OrganizationInvitation) error { + if e == nil || e.config == nil || e.config.Invitations == nil || e.config.Invitations.AfterUpdate == nil { + return nil + } + return e.config.Invitations.AfterUpdate(invitation) +} + +// Organization Team Hooks + +func (e *OrganizationsHookExecutor) BeforeCreateOrganizationTeam(team *types.OrganizationTeam) error { + if e == nil || e.config == nil || e.config.Teams == nil || e.config.Teams.BeforeCreate == nil { + return nil + } + return e.config.Teams.BeforeCreate(team) +} + +func (e *OrganizationsHookExecutor) AfterCreateOrganizationTeam(team types.OrganizationTeam) error { + if e == nil || e.config == nil || e.config.Teams == nil || e.config.Teams.AfterCreate == nil { + return nil + } + return e.config.Teams.AfterCreate(team) +} + +func (e *OrganizationsHookExecutor) BeforeUpdateOrganizationTeam(team *types.OrganizationTeam) error { + if e == nil || e.config == nil || e.config.Teams == nil || e.config.Teams.BeforeUpdate == nil { + return nil + } + return e.config.Teams.BeforeUpdate(team) +} + +func (e *OrganizationsHookExecutor) AfterUpdateOrganizationTeam(team types.OrganizationTeam) error { + if e == nil || e.config == nil || e.config.Teams == nil || e.config.Teams.AfterUpdate == nil { + return nil + } + return e.config.Teams.AfterUpdate(team) +} + +func (e *OrganizationsHookExecutor) BeforeDeleteOrganizationTeam(team *types.OrganizationTeam) error { + if e == nil || e.config == nil || e.config.Teams == nil || e.config.Teams.BeforeDelete == nil { + return nil + } + return e.config.Teams.BeforeDelete(team) +} + +func (e *OrganizationsHookExecutor) AfterDeleteOrganizationTeam(team types.OrganizationTeam) error { + if e == nil || e.config == nil || e.config.Teams == nil || e.config.Teams.AfterDelete == nil { + return nil + } + return e.config.Teams.AfterDelete(team) +} + +// Organization Team Member Hooks + +func (e *OrganizationsHookExecutor) BeforeCreateOrganizationTeamMember(member *types.OrganizationTeamMember) error { + if e == nil || e.config == nil || e.config.TeamMembers == nil || e.config.TeamMembers.BeforeCreate == nil { + return nil + } + return e.config.TeamMembers.BeforeCreate(member) +} + +func (e *OrganizationsHookExecutor) AfterCreateOrganizationTeamMember(member types.OrganizationTeamMember) error { + if e == nil || e.config == nil || e.config.TeamMembers == nil || e.config.TeamMembers.AfterCreate == nil { + return nil + } + return e.config.TeamMembers.AfterCreate(member) +} + +func (e *OrganizationsHookExecutor) BeforeDeleteOrganizationTeamMember(member *types.OrganizationTeamMember) error { + if e == nil || e.config == nil || e.config.TeamMembers == nil || e.config.TeamMembers.BeforeDelete == nil { + return nil + } + return e.config.TeamMembers.BeforeDelete(member) +} + +func (e *OrganizationsHookExecutor) AfterDeleteOrganizationTeamMember(member types.OrganizationTeamMember) error { + if e == nil || e.config == nil || e.config.TeamMembers == nil || e.config.TeamMembers.AfterDelete == nil { + return nil + } + return e.config.TeamMembers.AfterDelete(member) +} diff --git a/plugins/organizations/hooks_test.go b/plugins/organizations/hooks_test.go new file mode 100644 index 00000000..7b308410 --- /dev/null +++ b/plugins/organizations/hooks_test.go @@ -0,0 +1,149 @@ +package organizations + +import ( + "errors" + "testing" + + "github.com/Authula/authula/plugins/organizations/types" +) + +func TestOrganizationsHookExecutor_NilHooksAreNoop(t *testing.T) { + t.Parallel() + + executor := NewOrganizationsHookExecutor(nil) + + if err := executor.BeforeCreateOrganization(&types.Organization{ID: "org-1"}); err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if err := executor.AfterCreateOrganization(types.Organization{ID: "org-1"}); err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if err := executor.BeforeCreateOrganizationInvitation(&types.OrganizationInvitation{ID: "inv-1"}); err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if err := executor.AfterCreateOrganizationTeam(types.OrganizationTeam{ID: "team-1"}); err != nil { + t.Fatalf("expected nil error, got %v", err) + } +} + +func TestOrganizationsHookExecutor_OrganizationCreateHooks(t *testing.T) { + t.Parallel() + + var beforeCalled bool + var afterCalled bool + + executor := NewOrganizationsHookExecutor(&types.OrganizationsDatabaseHooksConfig{ + Organizations: &types.OrganizationDatabaseHooksConfig{ + BeforeCreate: func(organization *types.Organization) error { + beforeCalled = true + if organization == nil { + return errors.New("organization is nil") + } + if organization.ID != "org-1" { + t.Fatalf("unexpected organization ID: %s", organization.ID) + } + return nil + }, + AfterCreate: func(organization types.Organization) error { + afterCalled = true + if organization.ID != "org-1" { + t.Fatalf("unexpected organization ID: %s", organization.ID) + } + return nil + }, + }, + }) + + organization := &types.Organization{ID: "org-1", Name: "Acme"} + if err := executor.BeforeCreateOrganization(organization); err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if err := executor.AfterCreateOrganization(*organization); err != nil { + t.Fatalf("expected nil error, got %v", err) + } + + if !beforeCalled { + t.Fatal("expected BeforeCreate hook to be called") + } + if !afterCalled { + t.Fatal("expected AfterCreate hook to be called") + } +} + +func TestOrganizationsHookExecutor_OrganizationCreateHookError(t *testing.T) { + t.Parallel() + + someErr := errors.New("some error") + executor := NewOrganizationsHookExecutor(&types.OrganizationsDatabaseHooksConfig{ + Organizations: &types.OrganizationDatabaseHooksConfig{ + BeforeCreate: func(organization *types.Organization) error { + return someErr + }, + }, + }) + + err := executor.BeforeCreateOrganization(&types.Organization{ID: "org-1"}) + if !errors.Is(err, someErr) { + t.Fatalf("expected someErr error, got %v", err) + } +} + +func TestOrganizationsHookExecutor_MemberUpdateDeleteHooks(t *testing.T) { + t.Parallel() + + var beforeUpdateCalled bool + var afterUpdateCalled bool + var beforeDeleteCalled bool + var afterDeleteCalled bool + + executor := NewOrganizationsHookExecutor(&types.OrganizationsDatabaseHooksConfig{ + Members: &types.OrganizationMemberDatabaseHooksConfig{ + BeforeUpdate: func(member *types.OrganizationMember) error { + beforeUpdateCalled = true + if member == nil || member.ID != "mem-1" { + t.Fatalf("unexpected member in before update hook: %+v", member) + } + return nil + }, + AfterUpdate: func(member types.OrganizationMember) error { + afterUpdateCalled = true + if member.ID != "mem-1" { + t.Fatalf("unexpected member in after update hook: %+v", member) + } + return nil + }, + BeforeDelete: func(member *types.OrganizationMember) error { + beforeDeleteCalled = true + if member == nil || member.ID != "mem-1" { + t.Fatalf("unexpected member in before delete hook: %+v", member) + } + return nil + }, + AfterDelete: func(member types.OrganizationMember) error { + afterDeleteCalled = true + if member.ID != "mem-1" { + t.Fatalf("unexpected member in after delete hook: %+v", member) + } + return nil + }, + }, + }) + + member := &types.OrganizationMember{ID: "mem-1", Role: "member"} + if err := executor.BeforeUpdateOrganizationMember(member); err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if err := executor.AfterUpdateOrganizationMember(*member); err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if err := executor.BeforeDeleteOrganizationMember(member); err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if err := executor.AfterDeleteOrganizationMember(*member); err != nil { + t.Fatalf("expected nil error, got %v", err) + } + + if !beforeUpdateCalled || !afterUpdateCalled || !beforeDeleteCalled || !afterDeleteCalled { + t.Fatal("expected member update and delete hooks to be called") + } +} diff --git a/plugins/organizations/migrations.go b/plugins/organizations/migrations.go new file mode 100644 index 00000000..c149e944 --- /dev/null +++ b/plugins/organizations/migrations.go @@ -0,0 +1,330 @@ +package organizations + +import ( + "context" + + "github.com/uptrace/bun" + + "github.com/Authula/authula/migrations" +) + +func organizationsMigrationsForProvider(provider string) []migrations.Migration { + return migrations.ForProvider(provider, migrations.ProviderVariants{ + "sqlite": func() []migrations.Migration { return []migrations.Migration{organizationsSQLiteInitial()} }, + "postgres": func() []migrations.Migration { return []migrations.Migration{organizationsPostgresInitial()} }, + "mysql": func() []migrations.Migration { return []migrations.Migration{organizationsMySQLInitial()} }, + }) +} + +func organizationsSQLiteInitial() migrations.Migration { + return migrations.Migration{ + Version: "20260321000000_organizations_initial", + Up: func(ctx context.Context, tx bun.Tx) error { + return migrations.ExecStatements( + ctx, + tx, + `PRAGMA foreign_keys = ON;`, + // ----------------------------------- + `CREATE TABLE IF NOT EXISTS organizations ( + id TEXT PRIMARY KEY, + owner_id TEXT NOT NULL, + name VARCHAR(255) NOT NULL, + slug VARCHAR(255) NOT NULL UNIQUE, + logo TEXT, + metadata TEXT, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE + );`, + `CREATE INDEX IF NOT EXISTS idx_organizations_owner_id ON organizations(owner_id);`, + // ----------------------------------- + `CREATE TABLE IF NOT EXISTS organization_members ( + id TEXT PRIMARY KEY, + organization_id TEXT NOT NULL, + user_id TEXT NOT NULL, + role VARCHAR(255) NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, + UNIQUE (organization_id, user_id) + );`, + `CREATE INDEX IF NOT EXISTS idx_organization_members_organization_id ON organization_members(organization_id);`, + `CREATE INDEX IF NOT EXISTS idx_organization_members_user_id ON organization_members(user_id);`, + // ----------------------------------- + `CREATE TABLE IF NOT EXISTS organization_invitations ( + id TEXT PRIMARY KEY, + email VARCHAR(255) NOT NULL, + inviter_id TEXT NOT NULL, + organization_id TEXT NOT NULL, + role VARCHAR(255) NOT NULL, + status VARCHAR(32) NOT NULL DEFAULT 'pending', + expires_at TIMESTAMP NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE, + FOREIGN KEY (inviter_id) REFERENCES users(id) ON DELETE CASCADE, + UNIQUE (organization_id, email) + );`, + `CREATE INDEX IF NOT EXISTS idx_organization_invitations_email ON organization_invitations(email);`, + `CREATE INDEX IF NOT EXISTS idx_organization_invitations_organization_id ON organization_invitations(organization_id);`, + `CREATE INDEX IF NOT EXISTS idx_organization_invitations_inviter_id ON organization_invitations(inviter_id);`, + `CREATE INDEX IF NOT EXISTS idx_organization_invitations_status_expires_at ON organization_invitations(status, expires_at);`, + // ----------------------------------- + `CREATE TABLE IF NOT EXISTS organization_teams ( + id TEXT PRIMARY KEY, + organization_id TEXT NOT NULL, + name VARCHAR(255) NOT NULL, + slug VARCHAR(255) NOT NULL, + description TEXT, + metadata TEXT, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE, + UNIQUE (organization_id, slug) + );`, + `CREATE INDEX IF NOT EXISTS idx_organization_teams_organization_id ON organization_teams(organization_id);`, + `CREATE INDEX IF NOT EXISTS idx_organization_teams_slug ON organization_teams(slug);`, + `CREATE TABLE IF NOT EXISTS organization_team_members ( + id TEXT PRIMARY KEY, + team_id TEXT NOT NULL, + member_id TEXT NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (team_id) REFERENCES organization_teams(id) ON DELETE CASCADE, + FOREIGN KEY (member_id) REFERENCES organization_members(id) ON DELETE CASCADE, + UNIQUE (team_id, member_id) + );`, + `CREATE INDEX IF NOT EXISTS idx_organization_team_members_team_id ON organization_team_members(team_id);`, + `CREATE INDEX IF NOT EXISTS idx_organization_team_members_member_id ON organization_team_members(member_id);`, + // ----------------------------------- + ) + }, + Down: func(ctx context.Context, tx bun.Tx) error { + return migrations.ExecStatements( + ctx, + tx, + `DROP TABLE IF EXISTS organization_team_members;`, + `DROP TABLE IF EXISTS organization_teams;`, + `DROP TABLE IF EXISTS organization_invitations;`, + `DROP TABLE IF EXISTS organization_members;`, + `DROP TABLE IF EXISTS organizations;`, + ) + }, + } +} + +func organizationsPostgresInitial() migrations.Migration { + return migrations.Migration{ + Version: "20260309000000_organizations_initial", + Up: func(ctx context.Context, tx bun.Tx) error { + return migrations.ExecStatements( + ctx, + tx, + `CREATE OR REPLACE FUNCTION organizations_set_updated_at_fn() + RETURNS TRIGGER AS $$ + BEGIN + NEW.updated_at = NOW(); + RETURN NEW; + END; + $$ LANGUAGE plpgsql;`, + `CREATE TABLE IF NOT EXISTS organizations ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + owner_id UUID NOT NULL, + name VARCHAR(255) NOT NULL, + slug VARCHAR(255) NOT NULL UNIQUE, + logo TEXT, + metadata JSONB, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + CONSTRAINT fk_organizations_owner FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE + );`, + `DROP TRIGGER IF EXISTS update_organizations_updated_at_trigger ON organizations;`, + `CREATE TRIGGER update_organizations_updated_at_trigger + BEFORE UPDATE ON organizations + FOR EACH ROW + EXECUTE FUNCTION organizations_set_updated_at_fn();`, + `CREATE INDEX IF NOT EXISTS idx_organizations_owner_id ON organizations(owner_id);`, + `CREATE TABLE IF NOT EXISTS organization_members ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + organization_id UUID NOT NULL, + user_id UUID NOT NULL, + role VARCHAR(255) NOT NULL, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + CONSTRAINT fk_organization_members_organization FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE, + CONSTRAINT fk_organization_members_user FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, + CONSTRAINT uq_organization_members_organization_user UNIQUE (organization_id, user_id) + );`, + `DROP TRIGGER IF EXISTS update_organization_members_updated_at_trigger ON organization_members;`, + `CREATE TRIGGER update_organization_members_updated_at_trigger + BEFORE UPDATE ON organization_members + FOR EACH ROW + EXECUTE FUNCTION organizations_set_updated_at_fn();`, + `CREATE INDEX IF NOT EXISTS idx_organization_members_organization_id ON organization_members(organization_id);`, + `CREATE INDEX IF NOT EXISTS idx_organization_members_user_id ON organization_members(user_id);`, + `CREATE TABLE IF NOT EXISTS organization_invitations ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + organization_id UUID NOT NULL, + inviter_id UUID NOT NULL, + email VARCHAR(255) NOT NULL, + role VARCHAR(255) NOT NULL, + status VARCHAR(32) NOT NULL DEFAULT 'pending', + expires_at TIMESTAMP WITH TIME ZONE NOT NULL, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + CONSTRAINT fk_organization_invitations_organization FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE, + CONSTRAINT fk_organization_invitations_inviter FOREIGN KEY (inviter_id) REFERENCES users(id) ON DELETE CASCADE, + CONSTRAINT chk_organization_invitations_status CHECK (status IN ('pending', 'accepted', 'rejected', 'revoked', 'expired')), + CONSTRAINT uq_organization_invitations_organization_email UNIQUE (organization_id, email) + );`, + `DROP TRIGGER IF EXISTS update_organization_invitations_updated_at_trigger ON organization_invitations;`, + `CREATE TRIGGER update_organization_invitations_updated_at_trigger + BEFORE UPDATE ON organization_invitations + FOR EACH ROW + EXECUTE FUNCTION organizations_set_updated_at_fn();`, + `CREATE INDEX IF NOT EXISTS idx_organization_invitations_organization_id ON organization_invitations(organization_id);`, + `CREATE INDEX IF NOT EXISTS idx_organization_invitations_inviter_id ON organization_invitations(inviter_id);`, + `CREATE INDEX IF NOT EXISTS idx_organization_invitations_email ON organization_invitations(email);`, + `CREATE INDEX IF NOT EXISTS idx_organization_invitations_status_expires_at ON organization_invitations(status, expires_at);`, + `CREATE TABLE IF NOT EXISTS organization_teams ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + organization_id UUID NOT NULL, + name VARCHAR(255) NOT NULL, + slug VARCHAR(255) NOT NULL, + description TEXT, + metadata JSONB, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + CONSTRAINT fk_organization_teams_organization FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE, + CONSTRAINT uq_organization_teams_organization_slug UNIQUE (organization_id, slug) + );`, + `DROP TRIGGER IF EXISTS update_organization_teams_updated_at_trigger ON organization_teams;`, + `CREATE TRIGGER update_organization_teams_updated_at_trigger + BEFORE UPDATE ON organization_teams + FOR EACH ROW + EXECUTE FUNCTION organizations_set_updated_at_fn();`, + `CREATE INDEX IF NOT EXISTS idx_organization_teams_organization_id ON organization_teams(organization_id);`, + `CREATE INDEX IF NOT EXISTS idx_organization_teams_slug ON organization_teams(slug);`, + `CREATE TABLE IF NOT EXISTS organization_team_members ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + team_id UUID NOT NULL, + member_id UUID NOT NULL, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + CONSTRAINT fk_organization_team_members_team FOREIGN KEY (team_id) REFERENCES organization_teams(id) ON DELETE CASCADE, + CONSTRAINT fk_organization_team_members_member FOREIGN KEY (member_id) REFERENCES organization_members(id) ON DELETE CASCADE, + CONSTRAINT uq_organization_team_members_team_member UNIQUE (team_id, member_id) + );`, + `CREATE INDEX IF NOT EXISTS idx_organization_team_members_team_id ON organization_team_members(team_id);`, + `CREATE INDEX IF NOT EXISTS idx_organization_team_members_member_id ON organization_team_members(member_id);`, + ) + }, + Down: func(ctx context.Context, tx bun.Tx) error { + return migrations.ExecStatements( + ctx, + tx, + `DROP TABLE IF EXISTS organization_team_members;`, + `DROP TRIGGER IF EXISTS update_organization_teams_updated_at_trigger ON organization_teams;`, + `DROP TABLE IF EXISTS organization_teams;`, + `DROP TRIGGER IF EXISTS update_organization_invitations_updated_at_trigger ON organization_invitations;`, + `DROP TABLE IF EXISTS organization_invitations;`, + `DROP TRIGGER IF EXISTS update_organization_members_updated_at_trigger ON organization_members;`, + `DROP TABLE IF EXISTS organization_members;`, + `DROP TRIGGER IF EXISTS update_organizations_updated_at_trigger ON organizations;`, + `DROP TABLE IF EXISTS organizations;`, + `DROP FUNCTION IF EXISTS organizations_set_updated_at_fn();`, + ) + }, + } +} + +func organizationsMySQLInitial() migrations.Migration { + return migrations.Migration{ + Version: "20260309000000_organizations_initial", + Up: func(ctx context.Context, tx bun.Tx) error { + return migrations.ExecStatements( + ctx, + tx, + `CREATE TABLE IF NOT EXISTS organizations ( + id BINARY(16) NOT NULL PRIMARY KEY, + owner_id BINARY(16) NOT NULL, + name VARCHAR(255) NOT NULL, + slug VARCHAR(255) NOT NULL UNIQUE, + logo TEXT NULL, + metadata JSON NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + CONSTRAINT fk_organizations_owner FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE, + INDEX idx_organizations_owner_id (owner_id) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;`, + `CREATE TABLE IF NOT EXISTS organization_members ( + id BINARY(16) NOT NULL PRIMARY KEY, + organization_id BINARY(16) NOT NULL, + user_id BINARY(16) NOT NULL, + role VARCHAR(255) NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + CONSTRAINT fk_organization_members_organization FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE, + CONSTRAINT fk_organization_members_user FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, + CONSTRAINT uq_organization_members_organization_user UNIQUE (organization_id, user_id), + INDEX idx_organization_members_organization_id (organization_id), + INDEX idx_organization_members_user_id (user_id) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;`, + `CREATE TABLE IF NOT EXISTS organization_invitations ( + id BINARY(16) NOT NULL PRIMARY KEY, + organization_id BINARY(16) NOT NULL, + inviter_id BINARY(16) NOT NULL, + email VARCHAR(255) NOT NULL, + role VARCHAR(255) NOT NULL, + status VARCHAR(32) NOT NULL DEFAULT 'pending', + expires_at TIMESTAMP NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + CONSTRAINT fk_organization_invitations_organization FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE, + CONSTRAINT fk_organization_invitations_inviter FOREIGN KEY (inviter_id) REFERENCES users(id) ON DELETE CASCADE, + CONSTRAINT chk_organization_invitations_status CHECK (status IN ('pending', 'accepted', 'rejected', 'revoked', 'expired')), + CONSTRAINT uq_organization_invitations_organization_email UNIQUE (organization_id, email), + INDEX idx_organization_invitations_organization_id (organization_id), + INDEX idx_organization_invitations_inviter_id (inviter_id), + INDEX idx_organization_invitations_email (email), + INDEX idx_organization_invitations_status_expires_at (status, expires_at) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;`, + `CREATE TABLE IF NOT EXISTS organization_teams ( + id BINARY(16) NOT NULL PRIMARY KEY, + organization_id BINARY(16) NOT NULL, + name VARCHAR(255) NOT NULL, + slug VARCHAR(255) NOT NULL, + description TEXT NULL, + metadata JSON NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + CONSTRAINT fk_organization_teams_organization FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE, + CONSTRAINT uq_organization_teams_organization_slug UNIQUE (organization_id, slug), + INDEX idx_organization_teams_organization_id (organization_id), + INDEX idx_organization_teams_slug (slug) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;`, + `CREATE TABLE IF NOT EXISTS organization_team_members ( + id BINARY(16) NOT NULL PRIMARY KEY, + team_id BINARY(16) NOT NULL, + member_id BINARY(16) NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + CONSTRAINT fk_organization_team_members_team FOREIGN KEY (team_id) REFERENCES organization_teams(id) ON DELETE CASCADE, + CONSTRAINT fk_organization_team_members_member FOREIGN KEY (member_id) REFERENCES organization_members(id) ON DELETE CASCADE, + CONSTRAINT uq_organization_team_members_team_member UNIQUE (team_id, member_id), + INDEX idx_organization_team_members_team_id (team_id), + INDEX idx_organization_team_members_member_id (member_id) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;`, + ) + }, + Down: func(ctx context.Context, tx bun.Tx) error { + return migrations.ExecStatements( + ctx, + tx, + `DROP TABLE IF EXISTS organization_team_members;`, + `DROP TABLE IF EXISTS organization_teams;`, + `DROP TABLE IF EXISTS organization_invitations;`, + `DROP TABLE IF EXISTS organization_members;`, + `DROP TABLE IF EXISTS organizations;`, + ) + }, + } +} diff --git a/plugins/organizations/plugin.go b/plugins/organizations/plugin.go new file mode 100644 index 00000000..ee54665b --- /dev/null +++ b/plugins/organizations/plugin.go @@ -0,0 +1,120 @@ +package organizations + +import ( + "fmt" + + "github.com/Authula/authula/internal/util" + "github.com/Authula/authula/migrations" + "github.com/Authula/authula/models" + "github.com/Authula/authula/plugins/organizations/repositories" + "github.com/Authula/authula/plugins/organizations/services" + "github.com/Authula/authula/plugins/organizations/types" + "github.com/Authula/authula/plugins/organizations/usecases" + rootservices "github.com/Authula/authula/services" +) + +type OrganizationsPlugin struct { + globalConfig *models.Config + pluginConfig types.OrganizationsPluginConfig + ctx *models.PluginContext + logger models.Logger + databaseHooks *OrganizationsHookExecutor + organizationRepo repositories.OrganizationRepository + organizationService *services.OrganizationService + invitationRepo repositories.OrganizationInvitationRepository + invitationService *services.OrganizationInvitationService + memberRepo repositories.OrganizationMemberRepository + memberService *services.OrganizationMemberService + teamRepo repositories.OrganizationTeamRepository + teamMemberRepo repositories.OrganizationTeamMemberRepository + teamService *services.OrganizationTeamService + organizationUseCase *usecases.OrganizationUseCase + invitationUseCase *usecases.OrganizationInvitationUseCase + memberUseCase *usecases.OrganizationMemberUseCase + teamUseCase *usecases.OrganizationTeamUseCase + Api *API +} + +func New(config types.OrganizationsPluginConfig) *OrganizationsPlugin { + config.ApplyDefaults() + return &OrganizationsPlugin{pluginConfig: config} +} + +func (p *OrganizationsPlugin) Metadata() models.PluginMetadata { + return models.PluginMetadata{ + ID: models.PluginOrganizations.String(), + Version: "1.0.0", + Description: "Provides organization, teams, members, and invitations.", + } +} + +func (p *OrganizationsPlugin) Config() any { + return p.pluginConfig +} + +func (p *OrganizationsPlugin) Init(ctx *models.PluginContext) error { + p.ctx = ctx + p.logger = ctx.Logger + p.globalConfig = ctx.GetConfig() + + if err := util.LoadPluginConfig(p.globalConfig, p.Metadata().ID, &p.pluginConfig); err != nil { + return err + } + + p.pluginConfig.ApplyDefaults() + p.databaseHooks = NewOrganizationsHookExecutor(p.pluginConfig.DatabaseHooks) + + userService, ok := ctx.ServiceRegistry.Get(models.ServiceUser.String()).(rootservices.UserService) + if !ok { + return fmt.Errorf("user service not available in service registry") + } + + accessControlService, ok := ctx.ServiceRegistry.Get(models.ServiceAccessControl.String()).(rootservices.AccessControlService) + if !ok { + return fmt.Errorf("access control service not available in service registry") + } + + var mailerService rootservices.MailerService + if rawMailerService := ctx.ServiceRegistry.Get(models.ServiceMailer.String()); rawMailerService != nil { + if typedMailerService, ok := rawMailerService.(rootservices.MailerService); ok { + mailerService = typedMailerService + } else { + p.logger.Warn("mailer service has unexpected type, skipping organization invitation email delivery") + } + } else { + p.logger.Warn("mailer service not available, skipping organization invitation email delivery") + } + + p.organizationRepo = repositories.NewBunOrganizationRepository(ctx.DB) + p.organizationService = services.NewOrganizationService(p.organizationRepo, p.databaseHooks) + p.invitationRepo = repositories.NewBunOrganizationInvitationRepository(ctx.DB) + p.memberRepo = repositories.NewBunOrganizationMemberRepository(ctx.DB) + p.teamRepo = repositories.NewBunOrganizationTeamRepository(ctx.DB) + p.teamMemberRepo = repositories.NewBunOrganizationTeamMemberRepository(ctx.DB) + p.invitationService = services.NewOrganizationInvitationService(ctx.DB, p.globalConfig, &p.pluginConfig, p.logger, userService, accessControlService, p.organizationRepo, p.invitationRepo, p.memberRepo, mailerService, ctx.EventBus, p.databaseHooks, p.databaseHooks) + p.memberService = services.NewOrganizationMemberService(userService, accessControlService, p.organizationRepo, p.memberRepo, p.databaseHooks) + p.teamService = services.NewOrganizationTeamService(p.organizationRepo, p.teamRepo, p.memberRepo, p.teamMemberRepo, p.databaseHooks, p.databaseHooks) + p.organizationUseCase = usecases.NewOrganizationUseCase(p.organizationService) + p.invitationUseCase = usecases.NewOrganizationInvitationUseCase(p.invitationService) + p.memberUseCase = usecases.NewOrganizationMemberUseCase(p.memberService) + p.teamUseCase = usecases.NewOrganizationTeamUseCase(p.teamService) + p.Api = BuildAPI(p.organizationUseCase, p.invitationUseCase, p.memberUseCase, p.teamUseCase) + + return nil +} + +func (p *OrganizationsPlugin) Migrations(provider string) []migrations.Migration { + return organizationsMigrationsForProvider(provider) +} + +func (p *OrganizationsPlugin) DependsOn() []string { + return []string{models.PluginAccessControl.String()} +} + +func (p *OrganizationsPlugin) Routes() []models.Route { + return Routes(p) +} + +func (p *OrganizationsPlugin) Close() error { + return nil +} diff --git a/plugins/organizations/repositories/bun_organization_invitation_repository.go b/plugins/organizations/repositories/bun_organization_invitation_repository.go new file mode 100644 index 00000000..f23b20d7 --- /dev/null +++ b/plugins/organizations/repositories/bun_organization_invitation_repository.go @@ -0,0 +1,98 @@ +package repositories + +import ( + "context" + "database/sql" + "time" + + "github.com/uptrace/bun" + + "github.com/Authula/authula/plugins/organizations/types" +) + +type BunOrganizationInvitationRepository struct { + db bun.IDB +} + +func NewBunOrganizationInvitationRepository(db bun.IDB) OrganizationInvitationRepository { + return &BunOrganizationInvitationRepository{db: db} +} + +func (r *BunOrganizationInvitationRepository) Create(ctx context.Context, invitation *types.OrganizationInvitation) (*types.OrganizationInvitation, error) { + err := r.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + _, err := tx.NewInsert().Model(invitation).Exec(ctx) + if err != nil { + return err + } + + return tx.NewSelect().Model(invitation).WherePK().Scan(ctx) + }) + if err != nil { + return nil, err + } + + return invitation, nil +} + +func (r *BunOrganizationInvitationRepository) GetByID(ctx context.Context, invitationID string) (*types.OrganizationInvitation, error) { + invitation := new(types.OrganizationInvitation) + err := r.db.NewSelect().Model(invitation).Where("id = ?", invitationID).Scan(ctx) + if err == sql.ErrNoRows { + return nil, nil + } + return invitation, err +} + +func (r *BunOrganizationInvitationRepository) GetByOrganizationIDAndEmail(ctx context.Context, organizationID, email string) (*types.OrganizationInvitation, error) { + invitation := new(types.OrganizationInvitation) + err := r.db.NewSelect().Model(invitation). + Where("organization_id = ? AND email = ?", organizationID, email). + Scan(ctx) + if err == sql.ErrNoRows { + return nil, nil + } + return invitation, err +} + +func (r *BunOrganizationInvitationRepository) GetAllByOrganizationID(ctx context.Context, organizationID string) ([]types.OrganizationInvitation, error) { + invitations := make([]types.OrganizationInvitation, 0) + err := r.db.NewSelect().Model(&invitations). + Where("organization_id = ?", organizationID). + OrderExpr("created_at DESC"). + Scan(ctx) + if err == sql.ErrNoRows { + return []types.OrganizationInvitation{}, nil + } + return invitations, err +} + +func (r *BunOrganizationInvitationRepository) GetAllPendingByEmail(ctx context.Context, email string) ([]types.OrganizationInvitation, error) { + invites := make([]types.OrganizationInvitation, 0) + err := r.db.NewSelect().Model(&invites). + Where("email = ? AND status = ? AND expires_at > ?", email, types.OrganizationInvitationStatusPending, time.Now().UTC()). + Scan(ctx) + if err == sql.ErrNoRows { + return []types.OrganizationInvitation{}, nil + } + return invites, err +} + +func (r *BunOrganizationInvitationRepository) Update(ctx context.Context, invitation *types.OrganizationInvitation) (*types.OrganizationInvitation, error) { + err := r.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + _, err := tx.NewUpdate().Model(invitation).WherePK().Exec(ctx) + if err != nil { + return err + } + + return tx.NewSelect().Model(invitation).WherePK().Scan(ctx) + }) + if err != nil { + return nil, err + } + + return invitation, nil +} + +func (r *BunOrganizationInvitationRepository) WithTx(tx bun.IDB) OrganizationInvitationRepository { + return &BunOrganizationInvitationRepository{db: tx} +} diff --git a/plugins/organizations/repositories/bun_organization_invitation_repository_test.go b/plugins/organizations/repositories/bun_organization_invitation_repository_test.go new file mode 100644 index 00000000..14de3521 --- /dev/null +++ b/plugins/organizations/repositories/bun_organization_invitation_repository_test.go @@ -0,0 +1,307 @@ +package repositories + +import ( + "context" + "database/sql" + "testing" + "time" + + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/require" + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect/sqlitedialect" + + "github.com/Authula/authula/plugins/organizations/types" +) + +func newOrganizationInvitationRepositoryTestDB(t *testing.T) *bun.DB { + t.Helper() + + sqlDB, err := sql.Open("sqlite3", ":memory:") + require.NoError(t, err) + sqlDB.SetMaxOpenConns(1) + t.Cleanup(func() { + _ = sqlDB.Close() + }) + + db := bun.NewDB(sqlDB, sqlitedialect.New()) + t.Cleanup(func() { + _ = db.Close() + }) + + _, err = db.ExecContext(context.Background(), ` + CREATE TABLE organization_invitations ( + id TEXT PRIMARY KEY, + email TEXT NOT NULL, + inviter_id TEXT NOT NULL, + organization_id TEXT NOT NULL, + role TEXT NOT NULL, + status TEXT NOT NULL, + expires_at TIMESTAMP NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + ); + `) + require.NoError(t, err) + + return db +} + +func TestBunOrganizationInvitationRepository_Create(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + invitation *types.OrganizationInvitation + expectStatus types.OrganizationInvitationStatus + }{ + { + name: "pending", + invitation: &types.OrganizationInvitation{ + ID: "inv-1", + Email: "user@example.com", + InviterID: "user-1", + OrganizationID: "org-1", + Role: "member", + Status: types.OrganizationInvitationStatusPending, + ExpiresAt: time.Now().UTC().Add(time.Hour), + }, + expectStatus: types.OrganizationInvitationStatusPending, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + db := newOrganizationInvitationRepositoryTestDB(t) + repo := NewBunOrganizationInvitationRepository(db) + + created, err := repo.Create(context.Background(), tt.invitation) + require.NoError(t, err) + require.NotNil(t, created) + require.Equal(t, tt.expectStatus, created.Status) + require.Equal(t, tt.invitation.ID, created.ID) + }) + } +} + +func TestBunOrganizationInvitationRepository_GetByOrganizationIDAndEmail(t *testing.T) { + t.Parallel() + + db := newOrganizationInvitationRepositoryTestDB(t) + repo := NewBunOrganizationInvitationRepository(db) + ctx := context.Background() + + _, err := repo.Create(ctx, &types.OrganizationInvitation{ + ID: "inv-1", + Email: "user@example.com", + InviterID: "user-1", + OrganizationID: "org-1", + Role: "member", + Status: types.OrganizationInvitationStatusPending, + ExpiresAt: time.Now().UTC().Add(time.Hour), + }) + require.NoError(t, err) + + tests := []struct { + name string + organizationID string + email string + expectFound bool + }{ + {name: "found", organizationID: "org-1", email: "user@example.com", expectFound: true}, + {name: "not found", organizationID: "org-1", email: "missing@example.com"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + found, err := repo.GetByOrganizationIDAndEmail(ctx, tt.organizationID, tt.email) + require.NoError(t, err) + if tt.expectFound { + require.NotNil(t, found) + require.Equal(t, "inv-1", found.ID) + return + } + require.Nil(t, found) + }) + } +} + +func TestBunOrganizationInvitationRepository_GetAllPendingByEmail(t *testing.T) { + t.Parallel() + + db := newOrganizationInvitationRepositoryTestDB(t) + repo := NewBunOrganizationInvitationRepository(db) + ctx := context.Background() + + _, err := repo.Create(ctx, &types.OrganizationInvitation{ + ID: "inv-1", + Email: "user@example.com", + InviterID: "user-1", + OrganizationID: "org-1", + Role: "member", + Status: types.OrganizationInvitationStatusPending, + ExpiresAt: time.Now().UTC().Add(time.Hour), + }) + require.NoError(t, err) + + _, err = repo.Create(ctx, &types.OrganizationInvitation{ + ID: "inv-2", + Email: "user@example.com", + InviterID: "user-1", + OrganizationID: "org-1", + Role: "member", + Status: types.OrganizationInvitationStatusAccepted, + ExpiresAt: time.Now().UTC().Add(time.Hour), + }) + require.NoError(t, err) + + tests := []struct { + name string + email string + expectPending int + }{ + {name: "pending only", email: "user@example.com", expectPending: 1}, + {name: "missing", email: "missing@example.com", expectPending: 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + pending, err := repo.GetAllPendingByEmail(ctx, tt.email) + require.NoError(t, err) + require.Len(t, pending, tt.expectPending) + }) + } +} + +func TestBunOrganizationInvitationRepository_GetAllByOrganizationID(t *testing.T) { + t.Parallel() + + db := newOrganizationInvitationRepositoryTestDB(t) + repo := NewBunOrganizationInvitationRepository(db) + ctx := context.Background() + + _, err := repo.Create(ctx, &types.OrganizationInvitation{ + ID: "inv-1", + Email: "user@example.com", + InviterID: "user-1", + OrganizationID: "org-1", + Role: "member", + Status: types.OrganizationInvitationStatusPending, + ExpiresAt: time.Now().UTC().Add(time.Hour), + }) + require.NoError(t, err) + + _, err = repo.Create(ctx, &types.OrganizationInvitation{ + ID: "inv-2", + Email: "other@example.com", + InviterID: "user-1", + OrganizationID: "org-1", + Role: "member", + Status: types.OrganizationInvitationStatusPending, + ExpiresAt: time.Now().UTC().Add(time.Hour), + }) + require.NoError(t, err) + + invitations, err := repo.GetAllByOrganizationID(ctx, "org-1") + require.NoError(t, err) + require.Len(t, invitations, 2) +} + +func TestBunOrganizationInvitationRepository_Update(t *testing.T) { + t.Parallel() + + db := newOrganizationInvitationRepositoryTestDB(t) + repo := NewBunOrganizationInvitationRepository(db) + ctx := context.Background() + + created, err := repo.Create(ctx, &types.OrganizationInvitation{ + ID: "inv-1", + Email: "user@example.com", + InviterID: "user-1", + OrganizationID: "org-1", + Role: "member", + Status: types.OrganizationInvitationStatusPending, + ExpiresAt: time.Now().UTC().Add(time.Hour), + }) + require.NoError(t, err) + + created.Status = types.OrganizationInvitationStatusAccepted + updated, err := repo.Update(ctx, created) + require.NoError(t, err) + require.NotNil(t, updated) + require.Equal(t, types.OrganizationInvitationStatusAccepted, updated.Status) +} + +func TestBunOrganizationInvitationRepository_GetByID(t *testing.T) { + t.Parallel() + + db := newOrganizationInvitationRepositoryTestDB(t) + repo := NewBunOrganizationInvitationRepository(db) + ctx := context.Background() + + _, err := repo.Create(ctx, &types.OrganizationInvitation{ + ID: "inv-1", + Email: "user@example.com", + InviterID: "user-1", + OrganizationID: "org-1", + Role: "member", + Status: types.OrganizationInvitationStatusPending, + ExpiresAt: time.Now().UTC().Add(time.Hour), + }) + require.NoError(t, err) + + tests := []struct { + name string + invitationID string + expectFound bool + }{ + {name: "found", invitationID: "inv-1", expectFound: true}, + {name: "not found", invitationID: "missing"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + found, err := repo.GetByID(ctx, tt.invitationID) + require.NoError(t, err) + if tt.expectFound { + require.NotNil(t, found) + require.Equal(t, "inv-1", found.ID) + return + } + require.Nil(t, found) + }) + } +} + +func TestBunOrganizationInvitationRepository_WithTx(t *testing.T) { + t.Parallel() + + db := newOrganizationInvitationRepositoryTestDB(t) + repo := NewBunOrganizationInvitationRepository(db) + ctx := context.Background() + + txRepo := repo.WithTx(db) + require.NotNil(t, txRepo) + require.IsType(t, &BunOrganizationInvitationRepository{}, txRepo) + + created, err := txRepo.Create(ctx, &types.OrganizationInvitation{ + ID: "inv-1", + Email: "user@example.com", + InviterID: "user-1", + OrganizationID: "org-1", + Role: "member", + Status: types.OrganizationInvitationStatusPending, + ExpiresAt: time.Now().UTC().Add(time.Hour), + }) + require.NoError(t, err) + require.NotNil(t, created) + require.Equal(t, "inv-1", created.ID) +} diff --git a/plugins/organizations/repositories/bun_organization_member_repository.go b/plugins/organizations/repositories/bun_organization_member_repository.go new file mode 100644 index 00000000..124a5456 --- /dev/null +++ b/plugins/organizations/repositories/bun_organization_member_repository.go @@ -0,0 +1,91 @@ +package repositories + +import ( + "context" + "database/sql" + + "github.com/uptrace/bun" + + "github.com/Authula/authula/plugins/organizations/types" +) + +type BunOrganizationMemberRepository struct { + db bun.IDB +} + +func NewBunOrganizationMemberRepository(db bun.IDB) OrganizationMemberRepository { + return &BunOrganizationMemberRepository{db: db} +} + +func (r *BunOrganizationMemberRepository) Create(ctx context.Context, member *types.OrganizationMember) (*types.OrganizationMember, error) { + err := r.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + _, err := tx.NewInsert().Model(member).Exec(ctx) + if err != nil { + return err + } + + return tx.NewSelect().Model(member).WherePK().Scan(ctx) + }) + if err != nil { + return nil, err + } + + return member, nil +} + +func (r *BunOrganizationMemberRepository) GetAllByOrganizationID(ctx context.Context, organizationID string, page int, limit int) ([]types.OrganizationMember, error) { + members := make([]types.OrganizationMember, 0) + err := r.db.NewSelect().Model(&members). + Where("organization_id = ?", organizationID). + Offset((page - 1) * limit).Limit(limit). + Scan(ctx) + if err == sql.ErrNoRows { + return []types.OrganizationMember{}, nil + } + return members, err +} + +func (r *BunOrganizationMemberRepository) GetByOrganizationIDAndUserID(ctx context.Context, organizationID string, userID string) (*types.OrganizationMember, error) { + member := new(types.OrganizationMember) + err := r.db.NewSelect().Model(member). + Where("organization_id = ? AND user_id = ?", organizationID, userID). + Scan(ctx) + if err == sql.ErrNoRows { + return nil, nil + } + return member, err +} + +func (r *BunOrganizationMemberRepository) GetByID(ctx context.Context, memberID string) (*types.OrganizationMember, error) { + member := new(types.OrganizationMember) + err := r.db.NewSelect().Model(member).Where("id = ?", memberID).Scan(ctx) + if err == sql.ErrNoRows { + return nil, nil + } + return member, err +} + +func (r *BunOrganizationMemberRepository) Update(ctx context.Context, member *types.OrganizationMember) (*types.OrganizationMember, error) { + err := r.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + _, err := tx.NewUpdate().Model(member).WherePK().Exec(ctx) + if err != nil { + return err + } + + return tx.NewSelect().Model(member).WherePK().Scan(ctx) + }) + if err != nil { + return nil, err + } + + return member, nil +} + +func (r *BunOrganizationMemberRepository) Delete(ctx context.Context, memberID string) error { + _, err := r.db.NewDelete().Model(&types.OrganizationMember{}).Where("id = ?", memberID).Exec(ctx) + return err +} + +func (r *BunOrganizationMemberRepository) WithTx(tx bun.IDB) OrganizationMemberRepository { + return &BunOrganizationMemberRepository{db: tx} +} diff --git a/plugins/organizations/repositories/bun_organization_member_repository_test.go b/plugins/organizations/repositories/bun_organization_member_repository_test.go new file mode 100644 index 00000000..04c33b8c --- /dev/null +++ b/plugins/organizations/repositories/bun_organization_member_repository_test.go @@ -0,0 +1,315 @@ +package repositories + +import ( + "context" + "database/sql" + "testing" + + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/require" + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect/sqlitedialect" + + "github.com/Authula/authula/plugins/organizations/types" +) + +func newOrganizationMemberRepositoryTestDB(t *testing.T) *bun.DB { + t.Helper() + + sqlDB, err := sql.Open("sqlite3", ":memory:") + require.NoError(t, err) + sqlDB.SetMaxOpenConns(1) + t.Cleanup(func() { _ = sqlDB.Close() }) + + db := bun.NewDB(sqlDB, sqlitedialect.New()) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.ExecContext(context.Background(), ` + CREATE TABLE organization_members (id TEXT PRIMARY KEY, organization_id TEXT NOT NULL, user_id TEXT NOT NULL, role TEXT NOT NULL, created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP); + `) + require.NoError(t, err) + + return db +} + +func TestBunOrganizationMemberRepository_CreateGetUpdateDelete(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + run func(*testing.T, OrganizationMemberRepository, context.Context, *types.OrganizationMember) + }{ + { + name: "get by id returns created member", + run: func(t *testing.T, orgMemberRepo OrganizationMemberRepository, ctx context.Context, created *types.OrganizationMember) { + t.Helper() + found, err := orgMemberRepo.GetByID(ctx, "mem-1") + require.NoError(t, err) + require.NotNil(t, found) + require.Equal(t, created.Role, found.Role) + }, + }, + { + name: "list by organization returns created member", + run: func(t *testing.T, orgMemberRepo OrganizationMemberRepository, ctx context.Context, created *types.OrganizationMember) { + t.Helper() + members, err := orgMemberRepo.GetAllByOrganizationID(ctx, "org-1", 1, 10) + require.NoError(t, err) + require.Len(t, members, 1) + require.Equal(t, created.ID, members[0].ID) + }, + }, + { + name: "update persists changed role", + run: func(t *testing.T, orgMemberRepo OrganizationMemberRepository, ctx context.Context, created *types.OrganizationMember) { + t.Helper() + created.Role = "admin" + updated, err := orgMemberRepo.Update(ctx, created) + require.NoError(t, err) + require.Equal(t, "admin", updated.Role) + }, + }, + { + name: "delete removes member", + run: func(t *testing.T, orgMemberRepo OrganizationMemberRepository, ctx context.Context, created *types.OrganizationMember) { + t.Helper() + require.NoError(t, orgMemberRepo.Delete(ctx, created.ID)) + found, err := orgMemberRepo.GetByID(ctx, created.ID) + require.NoError(t, err) + require.Nil(t, found) + }, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + db := newOrganizationMemberRepositoryTestDB(t) + orgMemberRepo := NewBunOrganizationMemberRepository(db) + ctx := context.Background() + created, err := orgMemberRepo.Create(ctx, &types.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-2", Role: "member"}) + require.NoError(t, err) + require.NotNil(t, created) + + tt.run(t, orgMemberRepo, ctx, created) + }) + } +} + +func TestBunOrganizationMemberRepository_Create(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + member *types.OrganizationMember + }{ + {name: "member", member: &types.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-1", Role: "member"}}, + {name: "admin", member: &types.OrganizationMember{ID: "mem-2", OrganizationID: "org-1", UserID: "user-2", Role: "admin"}}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + repo := NewBunOrganizationMemberRepository(newOrganizationMemberRepositoryTestDB(t)) + created, err := repo.Create(context.Background(), tt.member) + require.NoError(t, err) + require.NotNil(t, created) + require.Equal(t, tt.member.ID, created.ID) + require.Equal(t, tt.member.Role, created.Role) + }) + } +} + +func TestBunOrganizationMemberRepository_GetAllByOrganizationID(t *testing.T) { + t.Parallel() + + db := newOrganizationMemberRepositoryTestDB(t) + repo := NewBunOrganizationMemberRepository(db) + ctx := context.Background() + + _, err := repo.Create(ctx, &types.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-1", Role: "member"}) + require.NoError(t, err) + _, err = repo.Create(ctx, &types.OrganizationMember{ID: "mem-2", OrganizationID: "org-1", UserID: "user-2", Role: "admin"}) + require.NoError(t, err) + + tests := []struct { + name string + organizationID string + page int + limit int + expectCount int + }{ + {name: "first page", organizationID: "org-1", page: 1, limit: 10, expectCount: 2}, + {name: "empty result", organizationID: "org-2", page: 1, limit: 10, expectCount: 0}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + members, err := repo.GetAllByOrganizationID(ctx, tt.organizationID, tt.page, tt.limit) + require.NoError(t, err) + require.Len(t, members, tt.expectCount) + }) + } +} + +func TestBunOrganizationMemberRepository_GetByOrganizationIDAndUserID(t *testing.T) { + t.Parallel() + + db := newOrganizationMemberRepositoryTestDB(t) + repo := NewBunOrganizationMemberRepository(db) + ctx := context.Background() + + _, err := repo.Create(ctx, &types.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-1", Role: "member"}) + require.NoError(t, err) + + tests := []struct { + name string + organizationID string + userID string + expectFound bool + }{ + {name: "found", organizationID: "org-1", userID: "user-1", expectFound: true}, + {name: "not found", organizationID: "org-1", userID: "user-2"}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + found, err := repo.GetByOrganizationIDAndUserID(ctx, tt.organizationID, tt.userID) + require.NoError(t, err) + if tt.expectFound { + require.NotNil(t, found) + require.Equal(t, "mem-1", found.ID) + return + } + require.Nil(t, found) + }) + } +} + +func TestBunOrganizationMemberRepository_GetByID(t *testing.T) { + t.Parallel() + + db := newOrganizationMemberRepositoryTestDB(t) + repo := NewBunOrganizationMemberRepository(db) + ctx := context.Background() + + _, err := repo.Create(ctx, &types.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-1", Role: "member"}) + require.NoError(t, err) + + tests := []struct { + name string + memberID string + expectFound bool + }{ + {name: "found", memberID: "mem-1", expectFound: true}, + {name: "not found", memberID: "missing"}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + found, err := repo.GetByID(ctx, tt.memberID) + require.NoError(t, err) + if tt.expectFound { + require.NotNil(t, found) + require.Equal(t, "mem-1", found.ID) + return + } + require.Nil(t, found) + }) + } +} + +func TestBunOrganizationMemberRepository_Update(t *testing.T) { + t.Parallel() + + db := newOrganizationMemberRepositoryTestDB(t) + repo := NewBunOrganizationMemberRepository(db) + ctx := context.Background() + + created, err := repo.Create(ctx, &types.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-1", Role: "member"}) + require.NoError(t, err) + + tests := []struct { + name string + run func(*testing.T, *types.OrganizationMember) + }{ + { + name: "change role", + run: func(t *testing.T, member *types.OrganizationMember) { + t.Helper() + member.Role = "admin" + updated, err := repo.Update(ctx, member) + require.NoError(t, err) + require.Equal(t, "admin", updated.Role) + }, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + tt.run(t, created) + }) + } +} + +func TestBunOrganizationMemberRepository_Delete(t *testing.T) { + t.Parallel() + + db := newOrganizationMemberRepositoryTestDB(t) + repo := NewBunOrganizationMemberRepository(db) + ctx := context.Background() + + _, err := repo.Create(ctx, &types.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-1", Role: "member"}) + require.NoError(t, err) + + tests := []struct { + name string + memberID string + }{ + {name: "delete existing", memberID: "mem-1"}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + require.NoError(t, repo.Delete(ctx, tt.memberID)) + found, err := repo.GetByID(ctx, tt.memberID) + require.NoError(t, err) + require.Nil(t, found) + }) + } +} + +func TestBunOrganizationMemberRepository_WithTx(t *testing.T) { + t.Parallel() + + db := newOrganizationMemberRepositoryTestDB(t) + repo := NewBunOrganizationMemberRepository(db) + ctx := context.Background() + + txRepo := repo.WithTx(db) + require.NotNil(t, txRepo) + require.IsType(t, &BunOrganizationMemberRepository{}, txRepo) + + created, err := txRepo.Create(ctx, &types.OrganizationMember{ID: "mem-1", OrganizationID: "org-1", UserID: "user-1", Role: "member"}) + require.NoError(t, err) + require.NotNil(t, created) + require.Equal(t, "mem-1", created.ID) +} diff --git a/plugins/organizations/repositories/bun_organization_repository.go b/plugins/organizations/repositories/bun_organization_repository.go new file mode 100644 index 00000000..59465835 --- /dev/null +++ b/plugins/organizations/repositories/bun_organization_repository.go @@ -0,0 +1,90 @@ +package repositories + +import ( + "context" + "database/sql" + + "github.com/uptrace/bun" + + "github.com/Authula/authula/plugins/organizations/types" +) + +type BunOrganizationRepository struct { + db bun.IDB +} + +func NewBunOrganizationRepository(db bun.IDB) OrganizationRepository { + return &BunOrganizationRepository{db: db} +} + +func (r *BunOrganizationRepository) Create(ctx context.Context, organization *types.Organization) (*types.Organization, error) { + if len(organization.Metadata) == 0 { + organization.Metadata = []byte("{}") + } + + err := r.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + _, err := tx.NewInsert().Model(organization).Exec(ctx) + if err != nil { + return err + } + + return tx.NewSelect().Model(organization).WherePK().Scan(ctx) + }) + if err != nil { + return nil, err + } + + return organization, nil +} + +func (r *BunOrganizationRepository) GetByID(ctx context.Context, organizationID string) (*types.Organization, error) { + organization := new(types.Organization) + err := r.db.NewSelect().Model(organization).Where("id = ?", organizationID).Scan(ctx) + if err == sql.ErrNoRows { + return nil, nil + } + return organization, err +} + +func (r *BunOrganizationRepository) GetBySlug(ctx context.Context, slug string) (*types.Organization, error) { + organization := new(types.Organization) + err := r.db.NewSelect().Model(organization).Where("slug = ?", slug).Scan(ctx) + if err == sql.ErrNoRows { + return nil, nil + } + return organization, err +} + +func (r *BunOrganizationRepository) GetAllByOwnerID(ctx context.Context, ownerID string) ([]types.Organization, error) { + organizations := make([]types.Organization, 0) + err := r.db.NewSelect().Model(&organizations).Where("owner_id = ?", ownerID).Scan(ctx) + if err == sql.ErrNoRows { + return []types.Organization{}, nil + } + return organizations, err +} + +func (r *BunOrganizationRepository) Update(ctx context.Context, organization *types.Organization) (*types.Organization, error) { + err := r.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + _, err := tx.NewUpdate().Model(organization).WherePK().Exec(ctx) + if err != nil { + return err + } + + return tx.NewSelect().Model(organization).WherePK().Scan(ctx) + }) + if err != nil { + return nil, err + } + + return organization, nil +} + +func (r *BunOrganizationRepository) Delete(ctx context.Context, organizationID string) error { + _, err := r.db.NewDelete().Model(&types.Organization{}).Where("id = ?", organizationID).Exec(ctx) + return err +} + +func (r *BunOrganizationRepository) WithTx(tx bun.IDB) OrganizationRepository { + return &BunOrganizationRepository{db: tx} +} diff --git a/plugins/organizations/repositories/bun_organization_repository_test.go b/plugins/organizations/repositories/bun_organization_repository_test.go new file mode 100644 index 00000000..a4f07d7e --- /dev/null +++ b/plugins/organizations/repositories/bun_organization_repository_test.go @@ -0,0 +1,257 @@ +package repositories + +import ( + "context" + "database/sql" + "testing" + + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/require" + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect/sqlitedialect" + + "github.com/Authula/authula/plugins/organizations/types" +) + +func newOrganizationRepositoryTestDB(t *testing.T) *bun.DB { + t.Helper() + + sqlDB, err := sql.Open("sqlite3", ":memory:") + require.NoError(t, err) + sqlDB.SetMaxOpenConns(1) + t.Cleanup(func() { _ = sqlDB.Close() }) + + db := bun.NewDB(sqlDB, sqlitedialect.New()) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.ExecContext(context.Background(), ` + CREATE TABLE organizations (id TEXT PRIMARY KEY, owner_id TEXT NOT NULL, name TEXT NOT NULL, slug TEXT NOT NULL, logo TEXT, metadata TEXT, created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP); + `) + require.NoError(t, err) + + return db +} + +func TestBunOrganizationRepository_Create(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + organization *types.Organization + expectMetadata string + }{ + { + name: "defaults metadata", + organization: &types.Organization{ID: "org-1", OwnerID: "user-1", Name: "Acme Inc", Slug: "acme-inc"}, + expectMetadata: "{}", + }, + { + name: "keeps provided metadata", + organization: &types.Organization{ID: "org-2", OwnerID: "user-1", Name: "Platform", Slug: "platform", Metadata: []byte(`{"tier":"core"}`)}, + expectMetadata: `{"tier":"core"}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + db := newOrganizationRepositoryTestDB(t) + repo := NewBunOrganizationRepository(db) + + created, err := repo.Create(context.Background(), tt.organization) + require.NoError(t, err) + require.NotNil(t, created) + require.Equal(t, tt.organization.ID, created.ID) + require.Equal(t, tt.expectMetadata, string(created.Metadata)) + }) + } +} + +func TestBunOrganizationRepository_GetByID(t *testing.T) { + t.Parallel() + + db := newOrganizationRepositoryTestDB(t) + repo := NewBunOrganizationRepository(db) + ctx := context.Background() + + _, err := repo.Create(ctx, &types.Organization{ID: "org-1", OwnerID: "user-1", Name: "Acme Inc", Slug: "acme-inc", Metadata: []byte(`{"tier":"core"}`)}) + require.NoError(t, err) + + tests := []struct { + name string + organizationID string + expectFound bool + }{ + {name: "found", organizationID: "org-1", expectFound: true}, + {name: "not found", organizationID: "missing"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + found, err := repo.GetByID(ctx, tt.organizationID) + require.NoError(t, err) + if tt.expectFound { + require.NotNil(t, found) + require.Equal(t, "org-1", found.ID) + require.Equal(t, "user-1", found.OwnerID) + return + } + require.Nil(t, found) + }) + } +} + +func TestBunOrganizationRepository_GetBySlug(t *testing.T) { + t.Parallel() + + db := newOrganizationRepositoryTestDB(t) + repo := NewBunOrganizationRepository(db) + ctx := context.Background() + + _, err := repo.Create(ctx, &types.Organization{ID: "org-1", OwnerID: "user-1", Name: "Acme Inc", Slug: "acme-inc", Metadata: []byte(`{"tier":"core"}`)}) + require.NoError(t, err) + + tests := []struct { + name string + slug string + expectFound bool + }{ + {name: "found", slug: "acme-inc", expectFound: true}, + {name: "not found", slug: "missing"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + found, err := repo.GetBySlug(ctx, tt.slug) + require.NoError(t, err) + if tt.expectFound { + require.NotNil(t, found) + require.Equal(t, "org-1", found.ID) + return + } + require.Nil(t, found) + }) + } +} + +func TestBunOrganizationRepository_GetAllByOwnerID(t *testing.T) { + t.Parallel() + + db := newOrganizationRepositoryTestDB(t) + repo := NewBunOrganizationRepository(db) + ctx := context.Background() + + _, err := repo.Create(ctx, &types.Organization{ID: "org-1", OwnerID: "user-1", Name: "Acme Inc", Slug: "acme-inc"}) + require.NoError(t, err) + _, err = repo.Create(ctx, &types.Organization{ID: "org-2", OwnerID: "user-1", Name: "Platform", Slug: "platform"}) + require.NoError(t, err) + + tests := []struct { + name string + ownerID string + expectCount int + }{ + {name: "found", ownerID: "user-1", expectCount: 2}, + {name: "empty", ownerID: "user-2", expectCount: 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + found, err := repo.GetAllByOwnerID(ctx, tt.ownerID) + require.NoError(t, err) + require.Len(t, found, tt.expectCount) + }) + } +} + +func TestBunOrganizationRepository_Update(t *testing.T) { + t.Parallel() + + db := newOrganizationRepositoryTestDB(t) + repo := NewBunOrganizationRepository(db) + ctx := context.Background() + + created, err := repo.Create(ctx, &types.Organization{ID: "org-1", OwnerID: "user-1", Name: "Acme Inc", Slug: "acme-inc", Metadata: []byte(`{"tier":"core"}`)}) + require.NoError(t, err) + + tests := []struct { + name string + run func(*testing.T, *types.Organization) + }{ + { + name: "update name and logo", + run: func(t *testing.T, organization *types.Organization) { + t.Helper() + organization.Name = "Acme Platform" + logo := new(string) + *logo = "http://example.com/logo.svg" + organization.Logo = logo + updated, err := repo.Update(ctx, organization) + require.NoError(t, err) + require.Equal(t, "Acme Platform", updated.Name) + require.NotNil(t, updated.Logo) + require.Equal(t, "http://example.com/logo.svg", *updated.Logo) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + tt.run(t, created) + }) + } +} + +func TestBunOrganizationRepository_Delete(t *testing.T) { + t.Parallel() + + db := newOrganizationRepositoryTestDB(t) + repo := NewBunOrganizationRepository(db) + ctx := context.Background() + + _, err := repo.Create(ctx, &types.Organization{ID: "org-1", OwnerID: "user-1", Name: "Acme Inc", Slug: "acme-inc"}) + require.NoError(t, err) + + tests := []struct { + name string + organizationID string + }{ + {name: "delete existing", organizationID: "org-1"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + require.NoError(t, repo.Delete(ctx, tt.organizationID)) + found, err := repo.GetByID(ctx, tt.organizationID) + require.NoError(t, err) + require.Nil(t, found) + }) + } +} + +func TestBunOrganizationRepository_WithTx(t *testing.T) { + t.Parallel() + + db := newOrganizationRepositoryTestDB(t) + repo := NewBunOrganizationRepository(db) + ctx := context.Background() + + txRepo := repo.WithTx(db) + require.NotNil(t, txRepo) + + created, err := txRepo.Create(ctx, &types.Organization{ID: "org-1", OwnerID: "user-1", Name: "Acme Inc", Slug: "acme-inc"}) + require.NoError(t, err) + require.NotNil(t, created) + require.Equal(t, "org-1", created.ID) + require.IsType(t, &BunOrganizationRepository{}, txRepo) +} diff --git a/plugins/organizations/repositories/bun_organization_team_member_repository.go b/plugins/organizations/repositories/bun_organization_team_member_repository.go new file mode 100644 index 00000000..c716c007 --- /dev/null +++ b/plugins/organizations/repositories/bun_organization_team_member_repository.go @@ -0,0 +1,75 @@ +package repositories + +import ( + "context" + "database/sql" + + "github.com/uptrace/bun" + + "github.com/Authula/authula/plugins/organizations/types" +) + +type BunOrganizationTeamMemberRepository struct { + db bun.IDB +} + +func NewBunOrganizationTeamMemberRepository(db bun.IDB) OrganizationTeamMemberRepository { + return &BunOrganizationTeamMemberRepository{db: db} +} + +func (r *BunOrganizationTeamMemberRepository) Create(ctx context.Context, teamMember *types.OrganizationTeamMember) (*types.OrganizationTeamMember, error) { + err := r.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + _, err := tx.NewInsert().Model(teamMember).Exec(ctx) + if err != nil { + return err + } + + return tx.NewSelect().Model(teamMember).WherePK().Scan(ctx) + }) + if err != nil { + return nil, err + } + + return teamMember, nil +} + +func (r *BunOrganizationTeamMemberRepository) GetByID(ctx context.Context, teamMemberID string) (*types.OrganizationTeamMember, error) { + teamMember := new(types.OrganizationTeamMember) + err := r.db.NewSelect().Model(teamMember).Where("id = ?", teamMemberID).Scan(ctx) + if err == sql.ErrNoRows { + return nil, nil + } + return teamMember, err +} + +func (r *BunOrganizationTeamMemberRepository) GetByTeamIDAndMemberID(ctx context.Context, teamID, memberID string) (*types.OrganizationTeamMember, error) { + teamMember := new(types.OrganizationTeamMember) + err := r.db.NewSelect().Model(teamMember). + Where("team_id = ? AND member_id = ?", teamID, memberID). + Scan(ctx) + if err == sql.ErrNoRows { + return nil, nil + } + return teamMember, err +} + +func (r *BunOrganizationTeamMemberRepository) GetAllByTeamID(ctx context.Context, teamID string, page int, limit int) ([]types.OrganizationTeamMember, error) { + teamMembers := make([]types.OrganizationTeamMember, 0) + err := r.db.NewSelect().Model(&teamMembers). + Where("team_id = ?", teamID). + Offset((page - 1) * limit).Limit(limit). + Scan(ctx) + if err == sql.ErrNoRows { + return []types.OrganizationTeamMember{}, nil + } + return teamMembers, err +} + +func (r *BunOrganizationTeamMemberRepository) DeleteByTeamIDAndMemberID(ctx context.Context, teamID, memberID string) error { + _, err := r.db.NewDelete().Model(&types.OrganizationTeamMember{}).Where("team_id = ? AND member_id = ?", teamID, memberID).Exec(ctx) + return err +} + +func (r *BunOrganizationTeamMemberRepository) WithTx(tx bun.IDB) OrganizationTeamMemberRepository { + return &BunOrganizationTeamMemberRepository{db: tx} +} diff --git a/plugins/organizations/repositories/bun_organization_team_member_repository_test.go b/plugins/organizations/repositories/bun_organization_team_member_repository_test.go new file mode 100644 index 00000000..3f43206c --- /dev/null +++ b/plugins/organizations/repositories/bun_organization_team_member_repository_test.go @@ -0,0 +1 @@ +package repositories diff --git a/plugins/organizations/repositories/bun_organization_team_repository.go b/plugins/organizations/repositories/bun_organization_team_repository.go new file mode 100644 index 00000000..4734f912 --- /dev/null +++ b/plugins/organizations/repositories/bun_organization_team_repository.go @@ -0,0 +1,92 @@ +package repositories + +import ( + "context" + "database/sql" + + "github.com/uptrace/bun" + + "github.com/Authula/authula/plugins/organizations/types" +) + +type BunOrganizationTeamRepository struct { + db bun.IDB +} + +func NewBunOrganizationTeamRepository(db bun.IDB) OrganizationTeamRepository { + return &BunOrganizationTeamRepository{db: db} +} + +func (r *BunOrganizationTeamRepository) Create(ctx context.Context, team *types.OrganizationTeam) (*types.OrganizationTeam, error) { + if len(team.Metadata) == 0 { + team.Metadata = []byte("{}") + } + + err := r.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + _, err := tx.NewInsert().Model(team).Exec(ctx) + if err != nil { + return err + } + + return tx.NewSelect().Model(team).WherePK().Scan(ctx) + }) + if err != nil { + return nil, err + } + + return team, nil +} + +func (r *BunOrganizationTeamRepository) GetByID(ctx context.Context, teamID string) (*types.OrganizationTeam, error) { + team := new(types.OrganizationTeam) + err := r.db.NewSelect().Model(team).Where("id = ?", teamID).Scan(ctx) + if err == sql.ErrNoRows { + return nil, nil + } + return team, err +} + +func (r *BunOrganizationTeamRepository) GetByOrganizationIDAndSlug(ctx context.Context, organizationID, slug string) (*types.OrganizationTeam, error) { + team := new(types.OrganizationTeam) + err := r.db.NewSelect().Model(team). + Where("organization_id = ? AND slug = ?", organizationID, slug). + Scan(ctx) + if err == sql.ErrNoRows { + return nil, nil + } + return team, err +} + +func (r *BunOrganizationTeamRepository) GetAllByOrganizationID(ctx context.Context, organizationID string) ([]types.OrganizationTeam, error) { + teams := make([]types.OrganizationTeam, 0) + err := r.db.NewSelect().Model(&teams).Where("organization_id = ?", organizationID).Scan(ctx) + if err == sql.ErrNoRows { + return []types.OrganizationTeam{}, nil + } + return teams, err +} + +func (r *BunOrganizationTeamRepository) Update(ctx context.Context, team *types.OrganizationTeam) (*types.OrganizationTeam, error) { + err := r.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + _, err := tx.NewUpdate().Model(team).WherePK().Exec(ctx) + if err != nil { + return err + } + + return tx.NewSelect().Model(team).WherePK().Scan(ctx) + }) + if err != nil { + return nil, err + } + + return team, nil +} + +func (r *BunOrganizationTeamRepository) Delete(ctx context.Context, teamID string) error { + _, err := r.db.NewDelete().Model(&types.OrganizationTeam{}).Where("id = ?", teamID).Exec(ctx) + return err +} + +func (r *BunOrganizationTeamRepository) WithTx(tx bun.IDB) OrganizationTeamRepository { + return &BunOrganizationTeamRepository{db: tx} +} diff --git a/plugins/organizations/repositories/bun_organization_team_repository_test.go b/plugins/organizations/repositories/bun_organization_team_repository_test.go new file mode 100644 index 00000000..f9b98f6c --- /dev/null +++ b/plugins/organizations/repositories/bun_organization_team_repository_test.go @@ -0,0 +1,438 @@ +package repositories + +import ( + "context" + "database/sql" + "testing" + + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/require" + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect/sqlitedialect" + + "github.com/Authula/authula/plugins/organizations/types" +) + +func newOrganizationTeamRepositoryTestDB(t *testing.T) *bun.DB { + t.Helper() + + sqlDB, err := sql.Open("sqlite3", ":memory:") + require.NoError(t, err) + sqlDB.SetMaxOpenConns(1) + t.Cleanup(func() { _ = sqlDB.Close() }) + + db := bun.NewDB(sqlDB, sqlitedialect.New()) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.ExecContext(context.Background(), ` + CREATE TABLE organizations (id TEXT PRIMARY KEY, owner_id TEXT NOT NULL, name TEXT NOT NULL, slug TEXT NOT NULL, logo TEXT, metadata TEXT, created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP); + CREATE TABLE organization_members (id TEXT PRIMARY KEY, organization_id TEXT NOT NULL, user_id TEXT NOT NULL, role TEXT NOT NULL, created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP); + CREATE TABLE organization_teams (id TEXT PRIMARY KEY, organization_id TEXT NOT NULL, name TEXT NOT NULL, slug TEXT NOT NULL, description TEXT, metadata TEXT, created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP); + CREATE TABLE organization_team_members (id TEXT PRIMARY KEY, team_id TEXT NOT NULL, member_id TEXT NOT NULL, created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP); + `) + require.NoError(t, err) + + return db +} + +func TestBunOrganizationTeamRepository_Create(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + team *types.OrganizationTeam + expect string + }{ + { + name: "defaults metadata", + team: &types.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}, + expect: "{}", + }, + { + name: "keeps metadata and description", + team: func() *types.OrganizationTeam { + description := new(string) + *description = "Core" + return &types.OrganizationTeam{ID: "team-2", OrganizationID: "org-1", Name: "Core", Slug: "core", Description: description, Metadata: []byte(`{"tier":"core"}`)} + }(), + expect: `{"tier":"core"}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + repo := NewBunOrganizationTeamRepository(newOrganizationTeamRepositoryTestDB(t)) + created, err := repo.Create(context.Background(), tt.team) + require.NoError(t, err) + require.NotNil(t, created) + require.Equal(t, tt.team.ID, created.ID) + require.Equal(t, tt.expect, string(created.Metadata)) + }) + } +} + +func TestBunOrganizationTeamRepository_GetByID(t *testing.T) { + t.Parallel() + + db := newOrganizationTeamRepositoryTestDB(t) + repo := NewBunOrganizationTeamRepository(db) + ctx := context.Background() + + _, err := repo.Create(ctx, &types.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}) + require.NoError(t, err) + + tests := []struct { + name string + teamID string + expectFound bool + }{ + {name: "found", teamID: "team-1", expectFound: true}, + {name: "not found", teamID: "missing"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + found, err := repo.GetByID(ctx, tt.teamID) + require.NoError(t, err) + if tt.expectFound { + require.NotNil(t, found) + require.Equal(t, "team-1", found.ID) + return + } + require.Nil(t, found) + }) + } +} + +func TestBunOrganizationTeamRepository_GetByOrganizationIDAndSlug(t *testing.T) { + t.Parallel() + + db := newOrganizationTeamRepositoryTestDB(t) + repo := NewBunOrganizationTeamRepository(db) + ctx := context.Background() + + _, err := repo.Create(ctx, &types.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}) + require.NoError(t, err) + + tests := []struct { + name string + organizationID string + slug string + expectFound bool + }{ + {name: "found", organizationID: "org-1", slug: "platform", expectFound: true}, + {name: "not found", organizationID: "org-1", slug: "missing"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + found, err := repo.GetByOrganizationIDAndSlug(ctx, tt.organizationID, tt.slug) + require.NoError(t, err) + if tt.expectFound { + require.NotNil(t, found) + require.Equal(t, "team-1", found.ID) + return + } + require.Nil(t, found) + }) + } +} + +func TestBunOrganizationTeamRepository_GetAllByOrganizationID(t *testing.T) { + t.Parallel() + + db := newOrganizationTeamRepositoryTestDB(t) + repo := NewBunOrganizationTeamRepository(db) + ctx := context.Background() + + _, err := repo.Create(ctx, &types.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}) + require.NoError(t, err) + _, err = repo.Create(ctx, &types.OrganizationTeam{ID: "team-2", OrganizationID: "org-1", Name: "Core", Slug: "core"}) + require.NoError(t, err) + + tests := []struct { + name string + organizationID string + expectCount int + }{ + {name: "found", organizationID: "org-1", expectCount: 2}, + {name: "empty", organizationID: "org-2", expectCount: 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + found, err := repo.GetAllByOrganizationID(ctx, tt.organizationID) + require.NoError(t, err) + require.Len(t, found, tt.expectCount) + }) + } +} + +func TestBunOrganizationTeamRepository_Update(t *testing.T) { + t.Parallel() + + db := newOrganizationTeamRepositoryTestDB(t) + repo := NewBunOrganizationTeamRepository(db) + ctx := context.Background() + + created, err := repo.Create(ctx, &types.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}) + require.NoError(t, err) + + tests := []struct { + name string + run func(*testing.T, *types.OrganizationTeam) + }{ + { + name: "change name and description", + run: func(t *testing.T, team *types.OrganizationTeam) { + t.Helper() + team.Name = "Platform Team" + description := new(string) + *description = "Core platform" + team.Description = description + updated, err := repo.Update(ctx, team) + require.NoError(t, err) + require.Equal(t, "Platform Team", updated.Name) + require.NotNil(t, updated.Description) + require.Equal(t, "Core platform", *updated.Description) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + tt.run(t, created) + }) + } +} + +func TestBunOrganizationTeamRepository_Delete(t *testing.T) { + t.Parallel() + + db := newOrganizationTeamRepositoryTestDB(t) + repo := NewBunOrganizationTeamRepository(db) + ctx := context.Background() + + _, err := repo.Create(ctx, &types.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}) + require.NoError(t, err) + + tests := []struct { + name string + teamID string + }{ + {name: "delete existing", teamID: "team-1"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + require.NoError(t, repo.Delete(ctx, tt.teamID)) + found, err := repo.GetByID(ctx, tt.teamID) + require.NoError(t, err) + require.Nil(t, found) + }) + } +} + +func TestBunOrganizationTeamRepository_WithTx(t *testing.T) { + t.Parallel() + + db := newOrganizationTeamRepositoryTestDB(t) + repo := NewBunOrganizationTeamRepository(db) + ctx := context.Background() + + txRepo := repo.WithTx(db) + require.NotNil(t, txRepo) + require.IsType(t, &BunOrganizationTeamRepository{}, txRepo) + + created, err := txRepo.Create(ctx, &types.OrganizationTeam{ID: "team-1", OrganizationID: "org-1", Name: "Platform", Slug: "platform"}) + require.NoError(t, err) + require.NotNil(t, created) + require.Equal(t, "team-1", created.ID) +} + +func TestBunOrganizationTeamMemberRepository_Create(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + teamMember *types.OrganizationTeamMember + }{ + {name: "team member", teamMember: &types.OrganizationTeamMember{ID: "team-member-1", TeamID: "team-1", MemberID: "member-1"}}, + {name: "another team member", teamMember: &types.OrganizationTeamMember{ID: "team-member-2", TeamID: "team-1", MemberID: "member-2"}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + repo := NewBunOrganizationTeamMemberRepository(newOrganizationTeamRepositoryTestDB(t)) + created, err := repo.Create(context.Background(), tt.teamMember) + require.NoError(t, err) + require.NotNil(t, created) + require.Equal(t, tt.teamMember.ID, created.ID) + }) + } +} + +func TestBunOrganizationTeamMemberRepository_GetByID(t *testing.T) { + t.Parallel() + + db := newOrganizationTeamRepositoryTestDB(t) + repo := NewBunOrganizationTeamMemberRepository(db) + ctx := context.Background() + + _, err := repo.Create(ctx, &types.OrganizationTeamMember{ID: "team-member-1", TeamID: "team-1", MemberID: "member-1"}) + require.NoError(t, err) + + tests := []struct { + name string + teamMemberID string + expectFound bool + }{ + {name: "found", teamMemberID: "team-member-1", expectFound: true}, + {name: "not found", teamMemberID: "missing"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + found, err := repo.GetByID(ctx, tt.teamMemberID) + require.NoError(t, err) + if tt.expectFound { + require.NotNil(t, found) + require.Equal(t, "team-member-1", found.ID) + return + } + require.Nil(t, found) + }) + } +} + +func TestBunOrganizationTeamMemberRepository_GetByTeamIDAndMemberID(t *testing.T) { + t.Parallel() + + db := newOrganizationTeamRepositoryTestDB(t) + repo := NewBunOrganizationTeamMemberRepository(db) + ctx := context.Background() + + _, err := repo.Create(ctx, &types.OrganizationTeamMember{ID: "team-member-1", TeamID: "team-1", MemberID: "member-1"}) + require.NoError(t, err) + + tests := []struct { + name string + teamID string + memberID string + expectFound bool + }{ + {name: "found", teamID: "team-1", memberID: "member-1", expectFound: true}, + {name: "not found", teamID: "team-1", memberID: "member-2"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + found, err := repo.GetByTeamIDAndMemberID(ctx, tt.teamID, tt.memberID) + require.NoError(t, err) + if tt.expectFound { + require.NotNil(t, found) + require.Equal(t, "team-member-1", found.ID) + return + } + require.Nil(t, found) + }) + } +} + +func TestBunOrganizationTeamMemberRepository_GetAllByTeamID(t *testing.T) { + t.Parallel() + + db := newOrganizationTeamRepositoryTestDB(t) + repo := NewBunOrganizationTeamMemberRepository(db) + ctx := context.Background() + + _, err := repo.Create(ctx, &types.OrganizationTeamMember{ID: "team-member-1", TeamID: "team-1", MemberID: "member-1"}) + require.NoError(t, err) + _, err = repo.Create(ctx, &types.OrganizationTeamMember{ID: "team-member-2", TeamID: "team-1", MemberID: "member-2"}) + require.NoError(t, err) + + tests := []struct { + name string + teamID string + page int + limit int + expectCount int + }{ + {name: "found", teamID: "team-1", page: 1, limit: 10, expectCount: 2}, + {name: "empty", teamID: "team-2", page: 1, limit: 10, expectCount: 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + found, err := repo.GetAllByTeamID(ctx, tt.teamID, tt.page, tt.limit) + require.NoError(t, err) + require.Len(t, found, tt.expectCount) + }) + } +} + +func TestBunOrganizationTeamMemberRepository_DeleteByTeamIDAndMemberID(t *testing.T) { + t.Parallel() + + db := newOrganizationTeamRepositoryTestDB(t) + repo := NewBunOrganizationTeamMemberRepository(db) + ctx := context.Background() + + _, err := repo.Create(ctx, &types.OrganizationTeamMember{ID: "team-member-1", TeamID: "team-1", MemberID: "member-1"}) + require.NoError(t, err) + + tests := []struct { + name string + teamID string + memberID string + }{ + {name: "delete existing", teamID: "team-1", memberID: "member-1"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + require.NoError(t, repo.DeleteByTeamIDAndMemberID(ctx, tt.teamID, tt.memberID)) + found, err := repo.GetByTeamIDAndMemberID(ctx, tt.teamID, tt.memberID) + require.NoError(t, err) + require.Nil(t, found) + }) + } +} + +func TestBunOrganizationTeamMemberRepository_WithTx(t *testing.T) { + t.Parallel() + + db := newOrganizationTeamRepositoryTestDB(t) + repo := NewBunOrganizationTeamMemberRepository(db) + ctx := context.Background() + + txRepo := repo.WithTx(db) + require.NotNil(t, txRepo) + require.IsType(t, &BunOrganizationTeamMemberRepository{}, txRepo) + + created, err := txRepo.Create(ctx, &types.OrganizationTeamMember{ID: "team-member-1", TeamID: "team-1", MemberID: "member-1"}) + require.NoError(t, err) + require.NotNil(t, created) + require.Equal(t, "team-member-1", created.ID) +} diff --git a/plugins/organizations/repositories/interfaces.go b/plugins/organizations/repositories/interfaces.go new file mode 100644 index 00000000..5cafa782 --- /dev/null +++ b/plugins/organizations/repositories/interfaces.go @@ -0,0 +1,58 @@ +package repositories + +import ( + "context" + + "github.com/uptrace/bun" + + "github.com/Authula/authula/plugins/organizations/types" +) + +type OrganizationRepository interface { + Create(ctx context.Context, organization *types.Organization) (*types.Organization, error) + GetByID(ctx context.Context, organizationID string) (*types.Organization, error) + GetBySlug(ctx context.Context, slug string) (*types.Organization, error) + GetAllByOwnerID(ctx context.Context, ownerID string) ([]types.Organization, error) + Update(ctx context.Context, organization *types.Organization) (*types.Organization, error) + Delete(ctx context.Context, organizationID string) error + WithTx(tx bun.IDB) OrganizationRepository +} + +type OrganizationInvitationRepository interface { + Create(ctx context.Context, invitation *types.OrganizationInvitation) (*types.OrganizationInvitation, error) + GetByID(ctx context.Context, invitationID string) (*types.OrganizationInvitation, error) + GetByOrganizationIDAndEmail(ctx context.Context, organizationID string, email string) (*types.OrganizationInvitation, error) + GetAllByOrganizationID(ctx context.Context, organizationID string) ([]types.OrganizationInvitation, error) + GetAllPendingByEmail(ctx context.Context, email string) ([]types.OrganizationInvitation, error) + Update(ctx context.Context, invitation *types.OrganizationInvitation) (*types.OrganizationInvitation, error) + WithTx(tx bun.IDB) OrganizationInvitationRepository +} + +type OrganizationMemberRepository interface { + Create(ctx context.Context, member *types.OrganizationMember) (*types.OrganizationMember, error) + GetAllByOrganizationID(ctx context.Context, organizationID string, page int, limit int) ([]types.OrganizationMember, error) + GetByOrganizationIDAndUserID(ctx context.Context, organizationID string, userID string) (*types.OrganizationMember, error) + GetByID(ctx context.Context, memberID string) (*types.OrganizationMember, error) + Update(ctx context.Context, member *types.OrganizationMember) (*types.OrganizationMember, error) + Delete(ctx context.Context, memberID string) error + WithTx(tx bun.IDB) OrganizationMemberRepository +} + +type OrganizationTeamRepository interface { + Create(ctx context.Context, team *types.OrganizationTeam) (*types.OrganizationTeam, error) + GetByID(ctx context.Context, teamID string) (*types.OrganizationTeam, error) + GetByOrganizationIDAndSlug(ctx context.Context, organizationID, slug string) (*types.OrganizationTeam, error) + GetAllByOrganizationID(ctx context.Context, organizationID string) ([]types.OrganizationTeam, error) + Update(ctx context.Context, team *types.OrganizationTeam) (*types.OrganizationTeam, error) + Delete(ctx context.Context, teamID string) error + WithTx(tx bun.IDB) OrganizationTeamRepository +} + +type OrganizationTeamMemberRepository interface { + Create(ctx context.Context, teamMember *types.OrganizationTeamMember) (*types.OrganizationTeamMember, error) + GetByID(ctx context.Context, teamMemberID string) (*types.OrganizationTeamMember, error) + GetByTeamIDAndMemberID(ctx context.Context, teamID, memberID string) (*types.OrganizationTeamMember, error) + GetAllByTeamID(ctx context.Context, teamID string, page int, limit int) ([]types.OrganizationTeamMember, error) + DeleteByTeamIDAndMemberID(ctx context.Context, teamID, memberID string) error + WithTx(tx bun.IDB) OrganizationTeamMemberRepository +} diff --git a/plugins/organizations/routes.go b/plugins/organizations/routes.go new file mode 100644 index 00000000..4db4515c --- /dev/null +++ b/plugins/organizations/routes.go @@ -0,0 +1,167 @@ +package organizations + +import ( + "net/http" + + "github.com/Authula/authula/models" + "github.com/Authula/authula/plugins/organizations/handlers" +) + +func Routes(plugin *OrganizationsPlugin) []models.Route { + if plugin == nil || plugin.Api == nil { + return []models.Route{} + } + + createOrganizationHandler := &handlers.CreateOrganizationHandler{UseCase: plugin.organizationUseCase} + getAllOrganizationsByUserIDHandler := &handlers.GetAllOrganizationsByUserIDHandler{UseCase: plugin.organizationUseCase} + getOrganizationHandler := &handlers.GetOrganizationHandler{UseCase: plugin.organizationUseCase} + updateOrganizationHandler := &handlers.UpdateOrganizationHandler{UseCase: plugin.organizationUseCase} + deleteOrganizationHandler := &handlers.DeleteOrganizationHandler{UseCase: plugin.organizationUseCase} + createInvitationHandler := &handlers.CreateOrganizationInvitationHandler{UseCase: plugin.invitationUseCase} + getInvitationHandler := &handlers.GetOrganizationInvitationHandler{UseCase: plugin.invitationUseCase} + listInvitationsHandler := &handlers.GetAllOrganizationInvitationsHandler{UseCase: plugin.invitationUseCase} + revokeInvitationHandler := &handlers.RevokeOrganizationInvitationHandler{UseCase: plugin.invitationUseCase} + acceptInvitationHandler := &handlers.AcceptOrganizationInvitationHandler{UseCase: plugin.invitationUseCase} + rejectInvitationHandler := &handlers.RejectOrganizationInvitationHandler{UseCase: plugin.invitationUseCase} + addMemberHandler := &handlers.AddOrganizationMemberHandler{UseCase: plugin.memberUseCase} + getAllMembersHandler := &handlers.GetAllOrganizationMembersHandler{UseCase: plugin.memberUseCase} + getMemberHandler := &handlers.GetOrganizationMemberHandler{UseCase: plugin.memberUseCase} + updateMemberHandler := &handlers.UpdateOrganizationMemberHandler{UseCase: plugin.memberUseCase} + deleteMemberHandler := &handlers.DeleteOrganizationMemberHandler{UseCase: plugin.memberUseCase} + createTeamHandler := &handlers.CreateOrganizationTeamHandler{UseCase: plugin.teamUseCase} + getAllTeamsHandler := &handlers.GetAllOrganizationTeamsHandler{UseCase: plugin.teamUseCase} + updateTeamHandler := &handlers.UpdateOrganizationTeamHandler{UseCase: plugin.teamUseCase} + deleteTeamHandler := &handlers.DeleteOrganizationTeamHandler{UseCase: plugin.teamUseCase} + addTeamMemberHandler := &handlers.AddOrganizationTeamMemberHandler{UseCase: plugin.teamUseCase} + getTeamMemberHandler := &handlers.GetOrganizationTeamMemberHandler{UseCase: plugin.teamUseCase} + getAllTeamMembersHandler := &handlers.GetAllOrganizationTeamMembersHandler{UseCase: plugin.teamUseCase} + deleteTeamMemberHandler := &handlers.DeleteOrganizationTeamMemberHandler{UseCase: plugin.teamUseCase} + + return []models.Route{ + // Organizations + { + Method: http.MethodPost, + Path: "/organizations", + Handler: createOrganizationHandler.Handler(), + }, + { + Method: http.MethodGet, + Path: "/organizations", + Handler: getAllOrganizationsByUserIDHandler.Handler(), + }, + { + Method: http.MethodGet, + Path: "/organizations/{organization_id}", + Handler: getOrganizationHandler.Handler(), + }, + { + Method: http.MethodPatch, + Path: "/organizations/{organization_id}", + Handler: updateOrganizationHandler.Handler(), + }, + { + Method: http.MethodDelete, + Path: "/organizations/{organization_id}", + Handler: deleteOrganizationHandler.Handler(), + }, + // Organization Members + { + Method: http.MethodPost, + Path: "/organizations/{organization_id}/members", + Handler: addMemberHandler.Handler(), + }, + { + Method: http.MethodGet, + Path: "/organizations/{organization_id}/members", + Handler: getAllMembersHandler.Handler(), + }, + { + Method: http.MethodGet, + Path: "/organizations/{organization_id}/members/{member_id}", + Handler: getMemberHandler.Handler(), + }, + { + Method: http.MethodPatch, + Path: "/organizations/{organization_id}/members/{member_id}", + Handler: updateMemberHandler.Handler(), + }, + { + Method: http.MethodDelete, + Path: "/organizations/{organization_id}/members/{member_id}", + Handler: deleteMemberHandler.Handler(), + }, + // Teams + { + Method: http.MethodPost, + Path: "/organizations/{organization_id}/teams", + Handler: createTeamHandler.Handler(), + }, + { + Method: http.MethodGet, + Path: "/organizations/{organization_id}/teams", + Handler: getAllTeamsHandler.Handler(), + }, + { + Method: http.MethodPatch, + Path: "/organizations/{organization_id}/teams/{team_id}", + Handler: updateTeamHandler.Handler(), + }, + { + Method: http.MethodDelete, + Path: "/organizations/{organization_id}/teams/{team_id}", + Handler: deleteTeamHandler.Handler(), + }, + // Team Members + { + Method: http.MethodPost, + Path: "/organizations/{organization_id}/teams/{team_id}/members", + Handler: addTeamMemberHandler.Handler(), + }, + { + Method: http.MethodGet, + Path: "/organizations/{organization_id}/teams/{team_id}/members/{member_id}", + Handler: getTeamMemberHandler.Handler(), + }, + { + Method: http.MethodGet, + Path: "/organizations/{organization_id}/teams/{team_id}/members", + Handler: getAllTeamMembersHandler.Handler(), + }, + { + Method: http.MethodDelete, + Path: "/organizations/{organization_id}/teams/{team_id}/members/{member_id}", + Handler: deleteTeamMemberHandler.Handler(), + }, + // Invitations + { + Method: http.MethodPost, + Path: "/organizations/{organization_id}/invitations", + Handler: createInvitationHandler.Handler(), + }, + { + Method: http.MethodGet, + Path: "/organizations/{organization_id}/invitations", + Handler: listInvitationsHandler.Handler(), + }, + { + Method: http.MethodGet, + Path: "/organizations/{organization_id}/invitations/{invitation_id}", + Handler: getInvitationHandler.Handler(), + }, + { + Method: http.MethodPatch, + Path: "/organizations/{organization_id}/invitations/{invitation_id}", + Handler: revokeInvitationHandler.Handler(), + }, + { + Method: http.MethodPost, + Path: "/organizations/{organization_id}/invitations/{invitation_id}/accept", + Handler: acceptInvitationHandler.Handler(), + }, + { + Method: http.MethodPost, + Path: "/organizations/{organization_id}/invitations/{invitation_id}/reject", + Handler: rejectInvitationHandler.Handler(), + }, + } +} diff --git a/plugins/organizations/services/organization_invitation_service.go b/plugins/organizations/services/organization_invitation_service.go new file mode 100644 index 00000000..bfe121c4 --- /dev/null +++ b/plugins/organizations/services/organization_invitation_service.go @@ -0,0 +1,583 @@ +package services + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "html" + "net/mail" + "net/url" + "strings" + "time" + + "github.com/uptrace/bun" + + internalerrors "github.com/Authula/authula/internal/errors" + "github.com/Authula/authula/internal/util" + "github.com/Authula/authula/models" + orgconstants "github.com/Authula/authula/plugins/organizations/constants" + orgevents "github.com/Authula/authula/plugins/organizations/events" + "github.com/Authula/authula/plugins/organizations/repositories" + "github.com/Authula/authula/plugins/organizations/types" + rootservices "github.com/Authula/authula/services" +) + +type OrganizationInvitationHookExecutor interface { + BeforeCreateOrganizationInvitation(invitation *types.OrganizationInvitation) error + AfterCreateOrganizationInvitation(invitation types.OrganizationInvitation) error + BeforeUpdateOrganizationInvitation(invitation *types.OrganizationInvitation) error + AfterUpdateOrganizationInvitation(invitation types.OrganizationInvitation) error +} + +type organizationInvitationTxRunner interface { + RunInTx(ctx context.Context, opts *sql.TxOptions, fn func(context.Context, bun.Tx) error) error +} + +type OrganizationInvitationService struct { + txRunner organizationInvitationTxRunner + globalConfig *models.Config + pluginConfig *types.OrganizationsPluginConfig + logger models.Logger + mailerService rootservices.MailerService + eventBus models.EventBus + userService rootservices.UserService + accessControlService rootservices.AccessControlService + organizationRepo repositories.OrganizationRepository + orgInvitationRepo repositories.OrganizationInvitationRepository + orgMemberRepo repositories.OrganizationMemberRepository + orgInvitationHooks OrganizationInvitationHookExecutor + orgMemberHooks OrganizationMemberHookExecutor +} + +func NewOrganizationInvitationService( + txRunner organizationInvitationTxRunner, + globalConfig *models.Config, + pluginConfig *types.OrganizationsPluginConfig, + logger models.Logger, + userService rootservices.UserService, + accessControlService rootservices.AccessControlService, + organizationRepo repositories.OrganizationRepository, + orgInvitationRepo repositories.OrganizationInvitationRepository, + orgMemberRepo repositories.OrganizationMemberRepository, + mailerService rootservices.MailerService, + eventBus models.EventBus, + orgInvitationHooks OrganizationInvitationHookExecutor, + orgMemberHooks OrganizationMemberHookExecutor, +) *OrganizationInvitationService { + return &OrganizationInvitationService{ + globalConfig: globalConfig, + pluginConfig: pluginConfig, + logger: logger, + mailerService: mailerService, + eventBus: eventBus, + userService: userService, + accessControlService: accessControlService, + organizationRepo: organizationRepo, + orgInvitationRepo: orgInvitationRepo, + orgMemberRepo: orgMemberRepo, + orgInvitationHooks: orgInvitationHooks, + orgMemberHooks: orgMemberHooks, + txRunner: txRunner, + } +} + +func (s *OrganizationInvitationService) CreateOrganizationInvitation(ctx context.Context, actorUserID string, organizationID string, request types.CreateOrganizationInvitationRequest) (*types.OrganizationInvitation, error) { + if actorUserID == "" || organizationID == "" { + return nil, internalerrors.ErrUnauthorized + } + + organization, err := s.organizationRepo.GetByID(ctx, organizationID) + if err != nil { + return nil, err + } + if organization == nil { + return nil, internalerrors.ErrNotFound + } + if organization.OwnerID != actorUserID { + return nil, internalerrors.ErrForbidden + } + + email := strings.ToLower(request.Email) + if email == "" { + return nil, internalerrors.ErrUnprocessableEntity + } + if _, err := mail.ParseAddress(email); err != nil { + return nil, internalerrors.ErrUnprocessableEntity + } + + role := request.Role + if role == "" { + return nil, internalerrors.ErrUnprocessableEntity + } + + exists, err := s.accessControlService.RoleExists(ctx, role) + if err != nil { + return nil, err + } + if !exists { + return nil, fmt.Errorf("role %s doesn't exist", role) + } + + if existing, err := s.orgInvitationRepo.GetByOrganizationIDAndEmail(ctx, organizationID, email); err != nil { + return nil, err + } else if existing != nil { + if err := s.expireOrganizationInvitationIfNeeded(ctx, existing); err != nil { + return nil, err + } + if existing.Status == types.OrganizationInvitationStatusPending { + return nil, internalerrors.ErrConflict + } + } + + expiresAt := time.Now().UTC().Add(s.pluginConfig.InvitationExpiresIn) + if !expiresAt.After(time.Now().UTC()) { + return nil, internalerrors.ErrUnprocessableEntity + } + invitation := &types.OrganizationInvitation{ + ID: util.GenerateUUID(), + Email: email, + InviterID: actorUserID, + OrganizationID: organizationID, + Role: role, + Status: types.OrganizationInvitationStatusPending, + ExpiresAt: expiresAt, + } + + if s.orgInvitationHooks != nil { + if err := s.orgInvitationHooks.BeforeCreateOrganizationInvitation(invitation); err != nil { + return nil, err + } + } + + created, err := s.orgInvitationRepo.Create(ctx, invitation) + if err != nil { + return nil, err + } + + if s.orgInvitationHooks != nil { + if err := s.orgInvitationHooks.AfterCreateOrganizationInvitation(*created); err != nil { + return nil, err + } + } + + s.publishOrganizationInvitationCreatedEvent(created, organization, request.RedirectURL) + s.sendOrganizationInvitationEmailAsync(ctx, created, organization, request.RedirectURL) + + return created, nil +} + +func (s *OrganizationInvitationService) publishOrganizationInvitationCreatedEvent(invitation *types.OrganizationInvitation, organization *types.Organization, redirectURL string) { + payload, err := json.Marshal(orgevents.OrganizationInvitationCreatedEvent{ + ID: util.GenerateUUID(), + InvitationID: invitation.ID, + OrganizationID: invitation.OrganizationID, + OrganizationName: organization.Name, + InviteeEmail: invitation.Email, + InviterID: invitation.InviterID, + Role: invitation.Role, + ExpiresAt: invitation.ExpiresAt, + RedirectURL: redirectURL, + }) + if err != nil { + s.logger.Error("failed to marshal organization invitation created event", "error", err) + return + } + + util.PublishEventAsync(s.eventBus, s.logger, models.Event{ + ID: util.GenerateUUID(), + Type: orgconstants.EventOrganizationsInvitationCreated, + Timestamp: time.Now().UTC(), + Payload: payload, + }) +} + +func (s *OrganizationInvitationService) sendOrganizationInvitationEmailAsync(ctx context.Context, invitation *types.OrganizationInvitation, organization *types.Organization, redirectURL string) { + if s.mailerService == nil { + s.logger.Warn("mailer service not available, skipping organization invitation email") + return + } + + go func() { + detachedCtx := context.WithoutCancel(ctx) + taskCtx, cancel := context.WithTimeout(detachedCtx, 15*time.Second) + defer cancel() + + if err := s.sendOrganizationInvitationEmail(taskCtx, invitation, organization, redirectURL); err != nil { + s.logger.Error("failed to send organization invitation email", "invitation_id", invitation.ID, "error", err) + } + }() +} + +func (s *OrganizationInvitationService) sendOrganizationInvitationEmail(ctx context.Context, invitation *types.OrganizationInvitation, organization *types.Organization, redirectURL string) error { + acceptURL := s.buildOrganizationInvitationAcceptURL(invitation, redirectURL) + appName := "Authula" + if s.globalConfig.AppName != "" { + appName = s.globalConfig.AppName + } + subject := fmt.Sprintf("You're invited to join %s on %s", organization.Name, appName) + textBody := fmt.Sprintf("You have been invited to join %s on %s as %s. Open this link to accept the invitation: %s", organization.Name, appName, invitation.Role, acceptURL) + htmlBody := fmt.Sprintf(` +
Hello,
+You have been invited to join %s on %s as %s.
+ +If the button does not work, copy this link:
+ +