Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion app/two_factor_authentication/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .user_2fa import Users2FA
from .user_2fa import Users2FA
20 changes: 15 additions & 5 deletions app/two_factor_authentication/models/user_2fa.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,25 @@
from typing import TYPE_CHECKING
from uuid import UUID

from app.common.models.base_class import Base
from sqlalchemy import ForeignKey
from sqlalchemy.orm import (
Mapped,
mapped_column,
)
from sqlalchemy.orm import Mapped, mapped_column, relationship

from app.common.models.base_class import Base


if TYPE_CHECKING:
from app.users.models import User


class Users2FA(Base):
__tablename__ = "users_2fa"

# fields
secret_key: Mapped[str]
user_id: Mapped[UUID] = mapped_column(ForeignKey("users.id"))
active: Mapped[bool]

# relationships
user: Mapped["User"] = relationship(
back_populates="two_factor_authentications",
)
2 changes: 2 additions & 0 deletions app/two_factor_authentication/use_cases/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .create_new_user_2fa_use_case import CreateNewUser2FAUseCase
from .verify_user_2fa_use_case import VerifyUser2FAUseCase
2 changes: 1 addition & 1 deletion app/users/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .user import User # noqa
from .user import User
19 changes: 16 additions & 3 deletions app/users/models/user.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,27 @@
from sqlalchemy.orm import Mapped, mapped_column
from typing import TYPE_CHECKING

from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.types import String

from app.common.models.base_class import Base

from app.users.constants.user_constants import USER_EMAIL_MAX_LENGTH


if TYPE_CHECKING:
from app.two_factor_authentication.models.user_2fa import Users2FA


class User(Base):
__tablename__ = "users"

# fields
email: Mapped[str] = mapped_column(
String(USER_EMAIL_MAX_LENGTH), unique=True
String(USER_EMAIL_MAX_LENGTH),
unique=True,
)
hashed_password: Mapped[str]

# relationships
two_factor_authentications: Mapped[list["Users2FA"]] = relationship(
back_populates="user",
)
1 change: 1 addition & 0 deletions app/users/repositories/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .users_repository import UsersRepository, users_repository
31 changes: 17 additions & 14 deletions tests/auth/api/endpoints/test_auth.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,48 @@
from sqlalchemy.orm import Session

from fastapi.testclient import TestClient

from tests.utils.create_user import create_user
from tests.factories import UserFactory


login_path = "api/v1/auth/login"


class TestLogin:
def test_login(self, client: TestClient, session: Session):
created_user = create_user(session)
def test_login(
self, client: TestClient, user_factory: UserFactory
) -> None:
user_password = "password"

user = user_factory.create(password=user_password)

response = client.post(
login_path,
json={
"email": created_user.email,
"password": "password",
"email": user.email,
"password": user_password,
},
)

assert response.status_code == 204

def test_login_incorrect_password(
self, client: TestClient, session: Session
):
created_user = create_user(session)
self, client: TestClient, user_factory: UserFactory
) -> None:
user = user_factory.create(password="password")

response = client.post(
login_path,
json={
"email": created_user.email,
"email": user.email,
"password": "incorrect_password",
},
)

assert response.status_code == 401

def test_login_non_existent_email(
self, client: TestClient, session: Session
):
create_user(session)
self, client: TestClient, user_factory: UserFactory
) -> None:
user_factory.create(email="test@user.com")

response = client.post(
login_path,
Expand Down
14 changes: 7 additions & 7 deletions tests/auth/api/endpoints/test_reset_password.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,17 @@

from app.auth.schemas.token_schema import EmailTokenPayload
from app.auth.utils.security import create_access_token, verify_password
from app.users.models.user import User
from tests.utils.create_user import create_user

from tests.factories import UserFactory


class TestResetPasswordEndpoint:
def test_reset_password_success(
self, client: TestClient, session: Session
self, client: TestClient, user_factory: UserFactory, session: Session
) -> None:
created_user = create_user(session)
user = user_factory.create()
reset_token = create_access_token(
EmailTokenPayload(user_email=created_user.email)
EmailTokenPayload(user_email=user.email)
)
new_password = "MyNewSecurePassword123!"

Expand All @@ -25,8 +25,8 @@ def test_reset_password_success(

assert response.status_code == 204

user = session.query(User).filter(User.id == created_user.id).first()
assert user and verify_password(new_password, user.hashed_password)
session.refresh(user)
assert verify_password(new_password, user.hashed_password)

def test_reset_password_invalid_token(
self,
Expand Down
29 changes: 17 additions & 12 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from typing import Generator

from fastapi.testclient import TestClient
import pytest
from sqlalchemy import RootTransaction, event
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session

from app.common.api.dependencies.get_session import get_session
from app.core.config import get_settings
from app.core.config import settings
from app.db.session import engine, SessionLocal
from app.main import app
from sqlalchemy.orm import Session

settings = get_settings()
from tests.factories import ClientFactory, UserFactory


TEST_DATABASE_URI = settings.SQLALCHEMY_DATABASE_URI


Expand Down Expand Up @@ -39,12 +41,15 @@ def end_savepoint(session: Session, transaction: RootTransaction) -> None:


@pytest.fixture()
def client(session: Session) -> Generator:
# Use the same session as the session fixture
def override_get_session() -> Generator:
yield session
def client_factory(session: Session) -> ClientFactory:
return ClientFactory(app, session)


app.dependency_overrides[get_session] = override_get_session
@pytest.fixture()
def client(client_factory: ClientFactory) -> TestClient:
return client_factory.create()

with TestClient(app) as client:
yield client

@pytest.fixture()
def user_factory(session: Session) -> UserFactory:
return UserFactory(session)
2 changes: 2 additions & 0 deletions tests/factories/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .client_factory import ClientFactory
from .user_factory import UserFactory
34 changes: 34 additions & 0 deletions tests/factories/client_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from typing import Generator
from fastapi import FastAPI
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session

from app.auth.schemas.token_schema import TokenPayload
from app.auth.utils import security
from app.common.api.dependencies.get_session import get_session
from app.users.models.user import User


class ClientFactory:
def __init__(self, app: FastAPI, session: Session):
self._app = app
self._session = session

def _get_session() -> Generator:
yield self._session

self._app.dependency_overrides[get_session] = _get_session

def create(self, user_to_log_in: User | None = None, /) -> TestClient:
client = TestClient(self._app)

if not user_to_log_in:
return client

# NOTE: set access token cookie to get around login endpoint limiter
access_token = security.create_access_token(
TokenPayload(user_id=str(user_to_log_in.id))
)
client.cookies = {"access_token": access_token}

return client
126 changes: 126 additions & 0 deletions tests/factories/user_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
from datetime import datetime
from typing import Any, Iterable, Sequence
from uuid import UUID

from pydantic import BaseModel
from sqlalchemy.orm import Session

from app.auth.utils import security
from app.two_factor_authentication.use_cases import CreateNewUser2FAUseCase
from app.users.models import User


class UserFactory:
class _UserConfig(BaseModel):
id: UUID | None = None
created_at: datetime | None = None
email: str = "test@user.com"
password: str = "password"
with_2fa: bool = False

@property
def config(self) -> type[_UserConfig]:
return self.__class__._UserConfig

def __init__(self, session: Session):
self.session = session
self._create_2fa_use_case = CreateNewUser2FAUseCase(session)

def create(
self,
config: _UserConfig | None = None,
/,
**kw: Any,
) -> User:
if config is None:
config = self.config(**kw)

user = self._create_model(config)
self.session.add(user)
self.session.flush()

if config.with_2fa:
self._create_2fa(user)

return user

def create_many(
self,
amount: int,
/,
*,
email_base: str = "test@mail.com",
configurations: Sequence[_UserConfig] | None = None,
**kw: Any,
) -> list[User]:
assert amount > 1, "Amount must be greater than 1"
assert email_base.count("@") == 1, (
"`email_base` must contain exactly one `@`"
)

configurations = self._parse_configurations(
amount, email_base, configurations, kw
)

users = [self._create_model(config) for config in configurations]

self.session.add_all(users)
self.session.flush()

self._create_2fa(
user
for user, config in zip(users, configurations, strict=True)
if config.with_2fa
)

return users

def _create_model(self, config: _UserConfig, /) -> User:
user = User(
email=config.email,
hashed_password=security.get_password_hash(config.password),
)

if config.id:
user.id = config.id

if config.created_at:
user.created_at = config.created_at

return user

def _create_2fa(self, users: User | Iterable[User], /) -> None:
if isinstance(users, User):
users = [users]

for user in users:
self._create_2fa_use_case.execute(user.id)
self.session.refresh(
user, attribute_names=("two_factor_authentications",)
)
assert len(user.two_factor_authentications) == 1, (
"UserFactory failed to create user 2fa"
)

def _parse_configurations(
self,
amount: int,
email_base: str,
configurations: Sequence[_UserConfig] | None,
kw: dict[str, Any],
) -> list[_UserConfig]:
address, domain = email_base.split("@")

parsed_configurations = []
for i in range(amount):
if configurations is None or len(configurations) <= i:
config = self.config(**kw)
else:
config = configurations[i]

if "email" not in config.model_fields_set:
config.email = f"{address}{i}@{domain}"

parsed_configurations.append(config)

return parsed_configurations
Loading