-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathauth.py
More file actions
91 lines (80 loc) · 3.73 KB
/
auth.py
File metadata and controls
91 lines (80 loc) · 3.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os
from datetime import datetime, timedelta, timezone
from typing import Optional, Annotated, Callable
from jose import JWTError, jwt
from passlib.context import CryptContext
from fastapi import Depends, HTTPException, status, APIRouter
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.orm import Session
from models import UserInDB, UserRole, TokenData, DBUser
from database import get_db
# --- Configuration and Contexts ---
# JWT settings are read from the environment once, when this module is imported.
# NOTE(review): SECRET_KEY has no fallback and stays None when unset; presumably
# that only surfaces as an error when a token is first signed or verified —
# consider validating it at startup.
SECRET_KEY = os.environ.get("SECRET_KEY")
ALGORITHM = os.environ.get("ALGORITHM", "HS256")
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.environ.get("ACCESS_TOKEN_EXPIRE_MINUTES", "30"))

# passlib hashing context; with deprecated="auto" every scheme except the first
# listed one would be flagged as deprecated (only one scheme is listed here).
pwd_context = CryptContext(schemes=["sha256_crypt"], deprecated="auto")

# Extracts the bearer token from requests; tokenUrl is the login path that is
# advertised to clients for obtaining a token.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login")
# --- Password Utilities ---
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a plaintext password against a stored hash.

    Delegates entirely to the module-level passlib ``pwd_context``.

    Returns:
        True when *plain_password* matches *hashed_password*, else False.
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
def get_password_hash(password: str) -> str:
    """Produce a storable hash of *password* using the module's ``pwd_context``."""
    hashed = pwd_context.hash(password)
    return hashed
# --- JWT Token Utilities ---
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT access token.

    Args:
        data: Claims to embed in the token. The dict is copied, so the
            caller's mapping is not modified.
        expires_delta: Token lifetime measured from now (UTC). When ``None``,
            ``ACCESS_TOKEN_EXPIRE_MINUTES`` minutes is used. A zero or negative
            delta is honoured as given rather than replaced by the default.

    Returns:
        The encoded JWT, signed with ``SECRET_KEY`` using ``ALGORITHM``, with
        an ``exp`` claim added.
    """
    # Compare against None explicitly: ``timedelta(0)`` is falsy, so a plain
    # truthiness check would silently swap a zero lifetime for the default.
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    to_encode = data.copy()
    to_encode["exp"] = datetime.now(timezone.utc) + expires_delta
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def decode_access_token(token: str) -> TokenData:
    """Decode and validate a JWT access token.

    Args:
        token: The raw encoded JWT string.

    Returns:
        A ``TokenData`` built from the token's ``user_id`` and ``role`` claims.
        Never returns ``None``; every failure path raises.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when
            ``jwt.decode`` rejects the token (``JWTError``), when the
            ``user_id`` or ``role`` claim is missing, or when building the
            result raises ``ValueError`` (e.g. ``role`` is not a valid
            ``UserRole`` value).
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        user_id = payload.get("user_id")
        role = payload.get("role")
        if user_id is None or role is None:
            raise credentials_exception
        # UserRole(role) raises ValueError for a role string it does not know.
        return TokenData(user_id=user_id, role=UserRole(role))
    except (JWTError, ValueError) as exc:
        # Keep the underlying cause for server-side debugging; the client only
        # ever sees the generic 401.
        raise credentials_exception from exc
# --- Authentication Dependency ---
def get_current_user(
    db: Session = Depends(get_db),
    token: str = Depends(oauth2_scheme),
) -> DBUser:
    """FastAPI dependency that resolves the bearer token to a ``DBUser``.

    Args:
        db: Database session supplied by ``get_db``.
        token: Bearer token extracted by ``oauth2_scheme``.

    Returns:
        The ``DBUser`` whose ``id`` equals the token's ``user_id`` claim.

    Raises:
        HTTPException: 401 if ``decode_access_token`` rejects the token, or if
            no user with the token's ``user_id`` exists.
    """
    token_data = decode_access_token(token)
    user = db.query(DBUser).filter(DBUser.id == token_data.user_id).first()
    if user is None:
        # A validly signed token for a missing (e.g. deleted) account is an
        # authentication failure, not a missing resource. Answer 401 with the
        # Bearer challenge, as decode_access_token does, so clients
        # re-authenticate instead of concluding the endpoint does not exist.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return user
# --- Role-Based Access Control (RBAC) Dependency ---
class RoleChecker:
    """Callable dependency that admits only users holding one of the given roles.

    Used as ``Depends(RoleChecker([...]))``. When called, it receives the user
    resolved by ``get_current_user``.
    """

    def __init__(self, allowed_roles: list[UserRole]):
        # Kept as the caller's list (no copy). Membership is tested with ``in``,
        # i.e. by equality.
        self.allowed_roles = allowed_roles

    def __call__(self, user: DBUser = Depends(get_current_user)):
        """Return *user* when their role is allowed; otherwise raise HTTP 403."""
        if user.role in self.allowed_roles:
            # Hand the user back so endpoints can use it directly.
            return user
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"User role ({user.role}) is not authorized for this action."
        )
# Define common dependencies
# Annotated aliases for endpoint parameters. Each resolves to the current
# DBUser. The role-restricted aliases go through RoleChecker, which raises
# HTTP 403 for any role not listed.
ALLOW_ALL = Annotated[DBUser, Depends(get_current_user)]  # any authenticated user
ALLOW_JUDGE = Annotated[DBUser, Depends(RoleChecker([UserRole.JUDGE]))]
ALLOW_JUROR = Annotated[DBUser, Depends(RoleChecker([UserRole.JUROR]))]
ALLOW_DEF_PL = Annotated[
    DBUser,
    Depends(RoleChecker([UserRole.DEFENDANT, UserRole.PLAINTIFF])),
]