Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions apps/api/permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from django.utils.translation import gettext_lazy as _
from rest_framework.permissions import BasePermission

from apps.core.models import User


def _normalize_roles(roles: Iterable[str]) -> frozenset[str]:
"""Normalize role values to strings."""
Expand All @@ -21,6 +23,63 @@ def _normalize_roles(roles: Iterable[str]) -> frozenset[str]:
return frozenset(normalized)


class IsAdmin(BasePermission):
"""
Allow requests only from users with Admin role or superusers.
"""

def has_permission(self, request, view) -> bool:
user = request.user
if not user or not user.is_authenticated:
return False
if user.is_superuser:
return True
return getattr(user, "is_admin_role", False) or user.role == User.Role.ADMIN

def has_object_permission(self, request, view, obj) -> bool:
return self.has_permission(request, view)


class IsDoctor(BasePermission):
"""
Allow requests only from users with Doctor role, Admin role, or superusers.
"""

def has_permission(self, request, view) -> bool:
user = request.user
if not user or not user.is_authenticated:
return False
if user.is_superuser:
return True
if getattr(user, "is_admin_role", False) or user.role == User.Role.ADMIN:
return True
return getattr(user, "is_doctor_role", False) or user.role == User.Role.DOCTOR

def has_object_permission(self, request, view, obj) -> bool:
return self.has_permission(request, view)


class IsReceptionist(BasePermission):
"""
Allow requests from users with Receptionist, Doctor, Admin roles, or superusers.
"""

def has_permission(self, request, view) -> bool:
user = request.user
if not user or not user.is_authenticated:
return False
if user.is_superuser:
return True
if getattr(user, "is_admin_role", False) or user.role == User.Role.ADMIN:
return True
if getattr(user, "is_doctor_role", False) or user.role == User.Role.DOCTOR:
return True
return getattr(user, "is_receptionist_role", False) or user.role == User.Role.RECEPTIONIST

def has_object_permission(self, request, view, obj) -> bool:
return self.has_permission(request, view)


class IsClinicalStaff(BasePermission):
"""
Allow requests from authenticated clinical staff or superusers.
Expand Down
106 changes: 82 additions & 24 deletions apps/api/tests/test_rbac.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from django.urls import reverse
from rest_framework.test import APIClient, APIRequestFactory

from apps.api.permissions import RoleRequired
from apps.api.permissions import IsAdmin, IsDoctor, IsReceptionist, RoleRequired

User = get_user_model()

Expand All @@ -29,35 +29,44 @@ def setUp(self) -> None:
is_superuser=True,
)

self.provider_user = User.objects.create_user(
username="rbac-provider",
self.doctor_user = User.objects.create_user(
username="rbac-doctor",
password="testpass123",
email="provider@example.com",
role=User.Role.PROVIDER,
email="doctor@example.com",
role=User.Role.DOCTOR,
user_type="doctor",
)

self.patient_user = User.objects.create_user(
username="rbac-patient",
self.receptionist_user = User.objects.create_user(
username="rbac-receptionist",
password="testpass123",
email="patient@example.com",
role=User.Role.PATIENT,
email="receptionist@example.com",
role=User.Role.RECEPTIONIST,
user_type="receptionist",
)

def _client_for(self, user: User) -> APIClient:
client = APIClient()
client.force_authenticate(user)
return client

def test_patient_blocked_from_clinical_endpoints(self):
"""Patients should receive 403 when calling staff-only endpoints."""
client = self._client_for(self.patient_user)
url = reverse("api:patients-list")
def test_receptionist_blocked_from_admin_endpoints(self):
"""Receptionists should receive 403 when calling admin-only endpoints."""
client = self._client_for(self.receptionist_user)
url = reverse("api:api_stats")
response = client.get(url)
self.assertEqual(response.status_code, 403)

def test_provider_blocked_from_admin_only_metrics(self):
"""Providers must not access admin-only statistics endpoints."""
client = self._client_for(self.provider_user)
def test_receptionist_blocked_from_doctor_only_endpoints(self):
"""Receptionists should receive 403 when calling doctor-only endpoints."""
client = self._client_for(self.receptionist_user)
url = reverse("api:health-workers-list")
response = client.get(url)
self.assertEqual(response.status_code, 403)

def test_doctor_blocked_from_admin_only_metrics(self):
"""Doctors must not access admin-only statistics endpoints."""
client = self._client_for(self.doctor_user)
url = reverse("api:api_stats")
response = client.get(url)
self.assertEqual(response.status_code, 403)
Expand All @@ -75,27 +84,34 @@ def test_admin_can_access_admin_only_endpoints(self):
self.assertEqual(export_response.status_code, 200)
self.assertEqual(export_response.data["format"], "csv")

def test_role_required_allows_provider_role(self):
def test_role_required_allows_doctor_role(self):
"""RoleRequired should allow users whose role matches the requirement."""
permission = RoleRequired()

class DummyView:
required_roles = frozenset({User.Role.PROVIDER})
required_roles = frozenset({User.Role.DOCTOR})

request = self.factory.get("/dummy")
request.user = self.provider_user
request.user = self.doctor_user

self.assertTrue(permission.has_permission(request, DummyView()))

request.user = self.patient_user
request.user = self.receptionist_user
self.assertFalse(permission.has_permission(request, DummyView()))

request.user = self.admin_user
self.assertTrue(permission.has_permission(request, DummyView()))

def test_provider_can_access_provider_endpoints(self):
"""Providers should access endpoints allowed for providers."""
client = self._client_for(self.provider_user)
def test_doctor_can_access_doctor_endpoints(self):
"""Doctors should access endpoints allowed for doctors."""
client = self._client_for(self.doctor_user)
url = reverse("api:patients-list")
response = client.get(url)
self.assertEqual(response.status_code, 200)

def test_receptionist_can_access_receptionist_endpoints(self):
"""Receptionists should access endpoints allowed for receptionists."""
client = self._client_for(self.receptionist_user)
url = reverse("api:patients-list")
response = client.get(url)
self.assertEqual(response.status_code, 200)
Expand All @@ -104,7 +120,7 @@ def test_admin_can_access_all_endpoints(self):
"""Admins should have access to all endpoints."""
client = self._client_for(self.admin_user)

# Test provider endpoints
# Test patient endpoints (all roles)
patients_url = reverse("api:patients-list")
response = client.get(patients_url)
self.assertEqual(response.status_code, 200)
Expand All @@ -129,3 +145,45 @@ def test_health_check_public(self):
self.assertEqual(response.status_code, 200)
self.assertEqual(response.data["status"], "healthy")

def test_is_admin_permission(self):
"""IsAdmin permission should only allow admin users."""
permission = IsAdmin()
request = self.factory.get("/dummy")

request.user = self.admin_user
self.assertTrue(permission.has_permission(request, None))

request.user = self.doctor_user
self.assertFalse(permission.has_permission(request, None))

request.user = self.receptionist_user
self.assertFalse(permission.has_permission(request, None))

def test_is_doctor_permission(self):
"""IsDoctor permission should allow doctor and admin users."""
permission = IsDoctor()
request = self.factory.get("/dummy")

request.user = self.admin_user
self.assertTrue(permission.has_permission(request, None))

request.user = self.doctor_user
self.assertTrue(permission.has_permission(request, None))

request.user = self.receptionist_user
self.assertFalse(permission.has_permission(request, None))

def test_is_receptionist_permission(self):
"""IsReceptionist permission should allow all roles."""
permission = IsReceptionist()
request = self.factory.get("/dummy")

request.user = self.admin_user
self.assertTrue(permission.has_permission(request, None))

request.user = self.doctor_user
self.assertTrue(permission.has_permission(request, None))

request.user = self.receptionist_user
self.assertTrue(permission.has_permission(request, None))

6 changes: 3 additions & 3 deletions apps/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class PatientViewSet(AuditLogMixin, viewsets.ModelViewSet):
"""

permission_classes = [IsAuthenticated, RoleRequired]
required_roles = frozenset({User.Role.ADMIN, User.Role.PROVIDER})
required_roles = frozenset({User.Role.ADMIN, User.Role.DOCTOR, User.Role.RECEPTIONIST})
queryset = (
Patient.objects.select_related("location", "registered_facility")
.prefetch_related("patientvisit_set")
Expand Down Expand Up @@ -132,7 +132,7 @@ class FacilityViewSet(viewsets.ReadOnlyModelViewSet):
"""

permission_classes = [IsAuthenticated, RoleRequired]
required_roles = frozenset({User.Role.ADMIN, User.Role.PROVIDER})
required_roles = frozenset({User.Role.ADMIN, User.Role.DOCTOR, User.Role.RECEPTIONIST})
queryset = HealthFacility.objects.select_related("location").all()
serializer_class = HealthFacilitySerializer
filterset_fields = ["facility_type", "location"]
Expand All @@ -146,7 +146,7 @@ class PatientVisitViewSet(AuditLogMixin, viewsets.ModelViewSet):
"""

permission_classes = [IsAuthenticated, RoleRequired]
required_roles = frozenset({User.Role.ADMIN, User.Role.PROVIDER})
required_roles = frozenset({User.Role.ADMIN, User.Role.DOCTOR, User.Role.RECEPTIONIST})
queryset = (
PatientVisit.objects.select_related("patient", "facility", "attending_provider")
.all()
Expand Down
23 changes: 23 additions & 0 deletions apps/core/migrations/0003_update_role_choices.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Generated migration for updating role field choices to support RBAC

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('core', '0002_set_admin_role_default'),
]

operations = [
migrations.AlterField(
model_name='user',
name='role',
field=models.CharField(
choices=[('admin', 'Administrator'), ('doctor', 'Doctor'), ('receptionist', 'Receptionist')],
default='receptionist',
help_text='High-level persona used for role-based access control.',
max_length=20
),
),
]
19 changes: 10 additions & 9 deletions apps/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ class User(AbstractUser):
"""
class Role(models.TextChoices):
ADMIN = 'admin', _('Administrator')
PROVIDER = 'provider', _('Healthcare Provider')
PATIENT = 'patient', _('Patient')
DOCTOR = 'doctor', _('Doctor')
RECEPTIONIST = 'receptionist', _('Receptionist')

USER_TYPE_CHOICES = [
('admin', _('Administrator')),
Expand All @@ -25,12 +25,13 @@ class Role(models.TextChoices):
('community_worker', _('Community Health Worker')),
('pharmacist', _('Pharmacist')),
('lab_technician', _('Laboratory Technician')),
('receptionist', _('Receptionist')),
]

role = models.CharField(
max_length=20,
choices=Role.choices,
default=Role.PROVIDER,
default=Role.RECEPTIONIST,
help_text=_('High-level persona used for role-based access control.')
)
user_type = models.CharField(
Expand Down Expand Up @@ -71,14 +72,14 @@ def is_admin_role(self):
return self.role == self.Role.ADMIN

@property
def is_provider_role(self):
"""Check if user has provider role."""
return self.role == self.Role.PROVIDER
def is_doctor_role(self):
"""Check if user has doctor role."""
return self.role == self.Role.DOCTOR

@property
def is_patient_role(self):
"""Check if user has patient role."""
return self.role == self.Role.PATIENT
def is_receptionist_role(self):
"""Check if user has receptionist role."""
return self.role == self.Role.RECEPTIONIST


class Location(models.Model):
Expand Down
5 changes: 3 additions & 2 deletions apps/core/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@ class UserSerializer(serializers.ModelSerializer):
"""
full_name = serializers.SerializerMethodField()
user_type_display = serializers.CharField(source='get_user_type_display', read_only=True)
role_display = serializers.CharField(source='get_role_display', read_only=True)

class Meta:
model = User
fields = [
'id', 'username', 'email', 'first_name', 'last_name', 'full_name',
'user_type', 'user_type_display', 'phone_number', 'date_of_birth',
'role', 'role_display', 'user_type', 'user_type_display', 'phone_number', 'date_of_birth',
'profile_picture', 'license_number', 'specialization', 'years_of_experience',
'is_active', 'date_joined', 'last_login'
]
Expand Down Expand Up @@ -52,7 +53,7 @@ class Meta:
model = User
fields = [
'username', 'email', 'first_name', 'last_name', 'password',
'password_confirm', 'user_type', 'phone_number', 'date_of_birth'
'password_confirm', 'role', 'user_type', 'phone_number', 'date_of_birth'
]

def validate(self, attrs):
Expand Down