Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
{
"cSpell.words": ["pydantic", "RABBITMQ"]
"cSpell.words": ["conint", "joinedload", "pydantic", "RABBITMQ"]
}
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,7 @@ We use ruff as linter, to run them on pre-commit please run `pre-commit install`

- You can run tests locally by running `./tests-start.sh` (this will use your local database).
- Tests can also run in docker container with the following command: `docker-compose exec api ./tests-start.sh`.

### Seed db

Run `./scripts/seed-db.sh`
36 changes: 36 additions & 0 deletions alembic/versions/2024_03_24_1223-646ed9779614_add_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,46 @@ def upgrade() -> None:
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("email"),
)
op.create_table(
"providers",
sa.Column("id", sa.UUID(), nullable=False),
sa.ForeignKeyConstraint(
["id"],
["users.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"patients",
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("provider_id", sa.UUID(), nullable=False),
sa.ForeignKeyConstraint(
["id"],
["users.id"],
),
sa.ForeignKeyConstraint(
["provider_id"],
["providers.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.add_column("users", sa.Column("type", sa.String(), nullable=False))
op.add_column(
"users", sa.Column("first_name", sa.String(length=30), nullable=False)
)
op.add_column(
"users", sa.Column("last_name", sa.String(length=30), nullable=False)
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("users", "last_name")
op.drop_column("users", "first_name")
op.drop_column("users", "type")

op.drop_table("patients")
op.drop_table("providers")
op.drop_table("users")
# ### end Alembic commands ###
30 changes: 27 additions & 3 deletions app/auth/api/endpoints.py → app/auth/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,47 @@
from slowapi import Limiter
from slowapi.util import get_remote_address

from app.users.repositories.providers_repository import providers_repository
from app.users.repositories.patients_repository import patients_repository


router = APIRouter()
settings = get_settings()

limiter = Limiter(key_func=get_remote_address)


@router.post("/login", status_code=status.HTTP_204_NO_CONTENT)
@router.post("/providers/login", status_code=status.HTTP_204_NO_CONTENT)
@limiter.limit(settings.AUTHENTICATION_API_RATE_LIMIT)
def login_access_token(
def login_provider_access_token(
request: Request,
session: SessionDependency,
login_data: UserLogin,
response: Response,
) -> None:
try:
AuthUserUseCase(session).execute(login_data, response)
AuthUserUseCase(session).execute(
login_data, response, providers_repository
)
except (ModelNotFoundException, InvalidCredentialsException):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid credentials",
)


@router.post("/patients/login", status_code=status.HTTP_204_NO_CONTENT)
@limiter.limit(settings.AUTHENTICATION_API_RATE_LIMIT)
def login_patient_access_token(
request: Request,
session: SessionDependency,
login_data: UserLogin,
response: Response,
) -> None:
try:
AuthUserUseCase(session).execute(
login_data, response, patients_repository
)
except (ModelNotFoundException, InvalidCredentialsException):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
Expand Down
4 changes: 2 additions & 2 deletions app/auth/api/routers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from fastapi import APIRouter

from app.auth.api import endpoints
from app.auth.api import auth

api_router = APIRouter()
api_router.include_router(endpoints.router, prefix="/auth", tags=["login"])
api_router.include_router(auth.router, prefix="/auth", tags=["login"])
9 changes: 7 additions & 2 deletions app/auth/use_cases/auth_user_use_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,19 @@
from app.auth.schemas.auth_schema import UserLogin
from app.auth.services.auth_service import AuthService
from app.auth.utils.set_http_only_cookie import set_http_only_cookie
from app.users.repositories.users_repository import users_repository
from app.users.repositories.users_repository import UsersRepository


class AuthUserUseCase:
def __init__(self, session: Session):
self.session = session

def execute(self, login_data: UserLogin, response: Response) -> None:
def execute(
self,
login_data: UserLogin,
response: Response,
users_repository: UsersRepository,
) -> None:
patient = AuthService(self.session, users_repository).authenticate(
login_data
)
Expand Down
21 changes: 12 additions & 9 deletions app/celery/tasks/emails.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@


from app.core.config import get_settings
from app.users.repositories.users_repository import users_repository
from app.users.repositories.patients_repository import patients_repository
from app.users.schemas.patient_schema import PatientInDB
from app.users.schemas.user_schema import UserInDB
from app.users.services.users_service import UsersService
from app.users.services.patients_service import PatientsService

settings = get_settings()

Expand All @@ -19,12 +20,12 @@
def send_reminder_email() -> None:
session = SessionLocal()
try:
users = UsersService(session, users_repository).list(
patients = PatientsService(session, patients_repository).list(
ListFilter(page=1, page_size=100)
)
for user in users.data:
for patient in patients:
EmailService(ExampleEmailClient()).send_user_remind_email(
UserInDB.model_validate(user)
UserInDB.model_validate(patient)
)
finally:
session.close()
Expand All @@ -36,13 +37,15 @@ def send_reminder_email() -> None:
max_retries=settings.SEND_WELCOME_EMAIL_MAX_RETRIES,
retry_jitter=False,
)
def send_welcome_email(user_id: UUID) -> None:
def send_welcome_email(patient_id: UUID) -> None:
session = SessionLocal()
try:
user = UsersService(session, users_repository).get_by_id(user_id)
if user:
patient = PatientsService(session, patients_repository).get_by_id(
patient_id
)
if patient:
EmailService(ExampleEmailClient()).send_new_user_email(
UserInDB.model_validate(user)
PatientInDB.model_validate(patient)
)
finally:
session.close()
Empty file added app/commands/__init__.py
Empty file.
24 changes: 24 additions & 0 deletions app/commands/seed_users.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from setuptools import Command

from app.db.seeders.db_seeder import DBSeeder
from app.db.seeders.users_seeder import UsersSeeder
from app.db.session import SessionLocal


class SeedUsersCommand(Command):
description = "Seed patients"
user_options = [] # type: ignore

def run(self) -> None:
"""Run the command."""
session = SessionLocal()
db_seeder = DBSeeder(session)
db_seeder.run_seeder(UsersSeeder)

session.close()

def initialize_options(self) -> None:
"""Set default values for options."""

def finalize_options(self) -> None:
"""Finalize options."""
7 changes: 7 additions & 0 deletions app/common/enums/extended_enum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from enum import Enum


class ExtendedEnum(Enum):
@classmethod
def list(cls) -> list[str]:
return list(map(lambda c: c.value, cls))
44 changes: 25 additions & 19 deletions app/common/repositories/base_repository.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from math import ceil
from typing import Any, Generic, Optional, Type, TypeVar
from typing import Any, Generic, List, Type, TypeVar
from uuid import UUID

from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy import asc, desc
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, joinedload
from sqlalchemy.orm.query import Query

from app.common.schemas.pagination_schema import ListFilter, ListResponse
from app.common.schemas.pagination_schema import ListFilter


ModelType = TypeVar("ModelType", bound=Any)
Expand All @@ -17,7 +16,9 @@


class BaseRepository(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
def __init__(self, model: Type[ModelType]):
def __init__(
self, model: Type[ModelType], joined_loads: List[str] | None = None
):
"""
CRUD object with default methods to Create, Read, Update, Delete (CRUD).

Expand All @@ -27,18 +28,29 @@ def __init__(self, model: Type[ModelType]):
* `schema`: A Pydantic model (schema) class
"""
self.model = model
self.joined_loads = joined_loads

def get(self, db: Session, model_id: UUID) -> Optional[ModelType]:
return db.query(self.model).filter(self.model.id == model_id).first()
def get(
self,
db: Session,
model_id: UUID,
) -> ModelType | None:
query = db.query(self.model).filter(self.model.id == model_id)

if self.joined_loads is not None:
for relation in self.joined_loads:
query = query.options(
joinedload(getattr(self.model, relation))
)

return query.first()

def list(
self, db: Session, list_options: ListFilter, query: Query | None = None
) -> ListResponse:
) -> List[ModelType]:
if not query:
query = db.query(self.model)

total = query.count()

if list_options.order_by:
column = list_options.order_by
direction = list_options.order
Expand All @@ -47,15 +59,9 @@ def list(
query = query.order_by(by(column))

query = query.offset(list_options.page_size * (list_options.page - 1))

query = query.limit(list_options.page_size)
return ListResponse(
data=query.all(),
page=list_options.page,
page_size=list_options.page_size,
total=total,
total_pages=ceil(total / list_options.page_size),
)

return query.all()

def create(self, db: Session, obj_in: CreateSchemaType) -> ModelType:
obj_in_data = jsonable_encoder(obj_in)
Expand All @@ -68,7 +74,7 @@ def update(
self, db: Session, db_obj: ModelType, obj_in: UpdateSchemaType
) -> ModelType:
obj_data = jsonable_encoder(db_obj)
update_data = obj_in.dict(exclude_unset=True)
update_data = obj_in.model_dump(exclude_unset=True)
for field in obj_data:
if field in update_data:
setattr(db_obj, field, update_data[field])
Expand Down
10 changes: 5 additions & 5 deletions app/common/schemas/pagination_schema.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import Generic, List, Literal, TypeVar
from typing import Annotated, Generic, List, Literal, TypeVar

from pydantic import BaseModel, conint


class ListFilter(BaseModel):
page: conint(ge=1) = 1
page_size: conint(ge=1, le=100) = 10
page: Annotated[int, conint(ge=1)] = 1
page_size: Annotated[int, conint(ge=1, le=100)] = 10
name: str | None = None
order: Literal["asc", "desc"] | None = None
order_by: str | None = None
Expand All @@ -16,7 +16,7 @@ class ListFilter(BaseModel):

class ListResponse(BaseModel, Generic[T]):
data: List[T]
page_size: conint(ge=1, le=100)
page: conint(ge=1)
page_size: Annotated[int, conint(ge=1, le=100)]
page: Annotated[int, conint(ge=1)]
total: int
total_pages: int
Empty file added app/db/seeders/__init__.py
Empty file.
11 changes: 11 additions & 0 deletions app/db/seeders/db_seeder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from typing import Any

from sqlalchemy.orm import Session


class DBSeeder:
def __init__(self, db: Session):
self.db = db

def run_seeder(self, seeder: Any) -> None:
seeder.run(self.db)
44 changes: 44 additions & 0 deletions app/db/seeders/users_seeder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from sqlalchemy.orm import Session


from app.auth.utils import security
from app.users.enums.user_type_enum import UserTypeEnum
from app.users.repositories.patients_repository import patients_repository
from app.users.repositories.providers_repository import providers_repository
from app.users.schemas.patient_schema import PatientCreate
from app.users.schemas.provider_schema import ProviderCreate
from app.users.services.providers_service import ProvidersService
from app.users.services.patients_service import PatientsService

from faker import Faker

fake = Faker()


class UsersSeeder:
@staticmethod
def run(db: Session) -> None:
hashed_password = security.get_password_hash("password")
provider = ProvidersService(db, providers_repository).create(
ProviderCreate(
email="test0@provider.com",
hashed_password=hashed_password,
first_name="Provider",
last_name="Test",
type=UserTypeEnum.PROVIDER,
)
)

patients_service = PatientsService(db, patients_repository)
for _ in range(1, 10):
patients_service.create(
PatientCreate(
email=fake.email(),
hashed_password=hashed_password,
first_name=fake.first_name(),
last_name=fake.last_name(),
type=UserTypeEnum.PATIENT,
provider_id=provider.id,
)
)
db.commit()
Loading
Loading