Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
272 changes: 269 additions & 3 deletions invokeai/app/api/routers/auth.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
"""Authentication endpoints."""

import secrets
import string
from datetime import timedelta
from typing import Annotated

from fastapi import APIRouter, Body, HTTPException, status
from fastapi import APIRouter, Body, HTTPException, Path, status
from pydantic import BaseModel, Field, field_validator

from invokeai.app.api.auth_dependencies import CurrentUser
from invokeai.app.api.auth_dependencies import AdminUser, CurrentUser
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.auth.token_service import TokenData, create_access_token
from invokeai.app.services.users.users_common import UserCreateRequest, UserDTO, validate_email_with_special_domains
from invokeai.app.services.users.users_common import (
UserCreateRequest,
UserDTO,
UserUpdateRequest,
validate_email_with_special_domains,
)

auth_router = APIRouter(prefix="/v1/auth", tags=["authentication"])

Expand Down Expand Up @@ -246,3 +253,262 @@ async def setup_admin(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e

return SetupResponse(success=True, user=user)


# ---------------------------------------------------------------------------
# User management models
# ---------------------------------------------------------------------------

_PASSWORD_ALPHABET = string.ascii_letters + string.digits + string.punctuation


class AdminUserCreateRequest(BaseModel):
"""Request body for admin to create a new user."""

email: str = Field(description="User email address")
display_name: str | None = Field(default=None, description="Display name")
password: str = Field(description="User password")
is_admin: bool = Field(default=False, description="Whether user should have admin privileges")

@field_validator("email")
@classmethod
def validate_email(cls, v: str) -> str:
"""Validate email address, allowing special-use domains."""
return validate_email_with_special_domains(v)


class AdminUserUpdateRequest(BaseModel):
"""Request body for admin to update any user."""

display_name: str | None = Field(default=None, description="Display name")
password: str | None = Field(default=None, description="New password")
is_admin: bool | None = Field(default=None, description="Whether user should have admin privileges")
is_active: bool | None = Field(default=None, description="Whether user account should be active")


class UserProfileUpdateRequest(BaseModel):
"""Request body for a user to update their own profile."""

display_name: str | None = Field(default=None, description="New display name")
current_password: str | None = Field(default=None, description="Current password (required when changing password)")
new_password: str | None = Field(default=None, description="New password")


class GeneratePasswordResponse(BaseModel):
"""Response containing a generated password."""

password: str = Field(description="Generated strong password")


# ---------------------------------------------------------------------------
# User management endpoints
# ---------------------------------------------------------------------------


@auth_router.get("/generate-password", response_model=GeneratePasswordResponse)
async def generate_password(
current_user: CurrentUser,
) -> GeneratePasswordResponse:
"""Generate a strong random password.

Returns a cryptographically secure random password of 16 characters
containing uppercase, lowercase, digits, and punctuation.
"""
# Ensure the generated password always meets strength requirements:
# at least one uppercase, one lowercase, one digit, one special char.
while True:
password = "".join(secrets.choice(_PASSWORD_ALPHABET) for _ in range(16))
if (
any(c.isupper() for c in password)
and any(c.islower() for c in password)
and any(c.isdigit() for c in password)
):
return GeneratePasswordResponse(password=password)


@auth_router.get("/users", response_model=list[UserDTO])
async def list_users(
current_user: AdminUser,
) -> list[UserDTO]:
"""List all users. Requires admin privileges.

The internal 'system' user (created for backward compatibility) is excluded
from the results since it cannot be managed through this interface.

Returns:
List of all real users (system user excluded)
"""
user_service = ApiDependencies.invoker.services.users
return [u for u in user_service.list_users() if u.user_id != "system"]


@auth_router.post("/users", response_model=UserDTO, status_code=status.HTTP_201_CREATED)
async def create_user(
request: Annotated[AdminUserCreateRequest, Body(description="New user details")],
current_user: AdminUser,
) -> UserDTO:
"""Create a new user. Requires admin privileges.

Args:
request: New user details

Returns:
The created user

Raises:
HTTPException: 400 if email already exists or password is weak
"""
user_service = ApiDependencies.invoker.services.users
try:
user_data = UserCreateRequest(
email=request.email,
display_name=request.display_name,
password=request.password,
is_admin=request.is_admin,
)
return user_service.create(user_data)
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e


@auth_router.get("/users/{user_id}", response_model=UserDTO)
async def get_user(
user_id: Annotated[str, Path(description="User ID")],
current_user: AdminUser,
) -> UserDTO:
"""Get a user by ID. Requires admin privileges.

Args:
user_id: The user ID

Returns:
The user

Raises:
HTTPException: 404 if user not found
"""
user_service = ApiDependencies.invoker.services.users
user = user_service.get(user_id)
if user is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
return user


@auth_router.patch("/users/{user_id}", response_model=UserDTO)
async def update_user(
user_id: Annotated[str, Path(description="User ID")],
request: Annotated[AdminUserUpdateRequest, Body(description="User fields to update")],
current_user: AdminUser,
) -> UserDTO:
"""Update a user. Requires admin privileges.

Args:
user_id: The user ID
request: Fields to update

Returns:
The updated user

Raises:
HTTPException: 400 if password is weak
HTTPException: 404 if user not found
"""
user_service = ApiDependencies.invoker.services.users
try:
changes = UserUpdateRequest(
display_name=request.display_name,
password=request.password,
is_admin=request.is_admin,
is_active=request.is_active,
)
return user_service.update(user_id, changes)
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e


@auth_router.delete("/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_user(
user_id: Annotated[str, Path(description="User ID")],
current_user: AdminUser,
) -> None:
"""Delete a user. Requires admin privileges.

Admins can delete any user including other admins, but cannot delete the last
remaining admin.

Args:
user_id: The user ID

Raises:
HTTPException: 400 if attempting to delete the last admin
HTTPException: 404 if user not found
"""
user_service = ApiDependencies.invoker.services.users
user = user_service.get(user_id)
if user is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")

# Prevent deleting the last active admin
if user.is_admin and user.is_active and user_service.count_admins() <= 1:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Cannot delete the last administrator",
)

try:
user_service.delete(user_id)
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e


@auth_router.patch("/me", response_model=UserDTO)
async def update_current_user(
request: Annotated[UserProfileUpdateRequest, Body(description="Profile fields to update")],
current_user: CurrentUser,
) -> UserDTO:
"""Update the current user's own profile.

To change the password, both ``current_password`` and ``new_password`` must
be provided. The current password is verified before the change is applied.

Args:
request: Profile fields to update
current_user: The authenticated user

Returns:
The updated user

Raises:
HTTPException: 400 if current password is incorrect or new password is weak
HTTPException: 404 if user not found
"""
user_service = ApiDependencies.invoker.services.users

# Verify current password when attempting a password change
if request.new_password is not None:
if not request.current_password:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Current password is required to set a new password",
)

# Re-authenticate to verify the current password
user = user_service.get(current_user.user_id)
if user is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")

authenticated = user_service.authenticate(user.email, request.current_password)
if authenticated is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Current password is incorrect",
)

try:
changes = UserUpdateRequest(
display_name=request.display_name,
password=request.new_password,
)
return user_service.update(current_user.user_id, changes)
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
9 changes: 9 additions & 0 deletions invokeai/app/services/users/users_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,12 @@ def list_users(self, limit: int = 100, offset: int = 0) -> list[UserDTO]:
List of users
"""
pass

@abstractmethod
def count_admins(self) -> int:
"""Count active admin users.
Returns:
The number of active admin users
"""
pass
7 changes: 7 additions & 0 deletions invokeai/app/services/users/users_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,3 +249,10 @@ def list_users(self, limit: int = 100, offset: int = 0) -> list[UserDTO]:
)
for row in rows
]

def count_admins(self) -> int:
"""Count active admin users."""
with self._db.transaction() as cursor:
cursor.execute("SELECT COUNT(*) FROM users WHERE is_admin = TRUE AND is_active = TRUE")
row = cursor.fetchone()
return int(row[0]) if row else 0
Loading