Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 22 additions & 14 deletions ninja_extra/controllers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,23 +73,31 @@ class MissingAPIControllerDecoratorException(Exception):

def get_route_functions(cls: Type) -> Iterable[RouteFunction]:
"""
Get all route functions from a controller class.
This function will recursively search for route functions in the base classes of the controller class
in order that they are defined.
Return fresh RouteFunction instances for a controller class.

Args:
cls (Type): The controller class.

Returns:
Iterable[RouteFunction]: An iterable of route functions.
Each call yields a clone of the RouteFunction template stored on the
controller method, ensuring metadata is not shared across subclasses.
"""

bases = inspect.getmro(cls)
for base_cls in reversed(bases):
if base_cls not in [ControllerBase, ABC, object]:
for method in base_cls.__dict__.values():
if hasattr(method, ROUTE_FUNCTION):
yield getattr(method, ROUTE_FUNCTION)
for _, method, template in _iter_route_templates(cls):
yield template.clone(method)


def _iter_route_templates(
cls: Type,
) -> Iterable[Tuple[str, Callable[..., Any], RouteFunction]]:
seen: set[str] = set()
for base_cls in inspect.getmro(cls):
if base_cls in (ControllerBase, ABC, object):
continue
for attr_name, method in base_cls.__dict__.items():
if attr_name in seen:
continue
route_template = getattr(method, ROUTE_FUNCTION, None)
if route_template is None:
continue
seen.add(attr_name)
yield attr_name, method, route_template


def get_all_controller_route_function(
Expand Down
3 changes: 2 additions & 1 deletion ninja_extra/controllers/model/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def __init__(
)

def _add_to_controller(self, func: t.Callable) -> None:
route_function = getattr(func, ROUTE_FUNCTION)
route_template = getattr(func, ROUTE_FUNCTION)
route_function = route_template.clone(func)
route_function.api_controller = self._api_controller_instance
self._api_controller_instance.add_controller_route_function(route_function)

Expand Down
40 changes: 39 additions & 1 deletion ninja_extra/controllers/route/route_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,18 @@
import warnings
from contextlib import contextmanager
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Iterator, Optional, Tuple, cast
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterator,
List,
Optional,
Tuple,
Type,
Union,
cast,
)

from django.http import HttpRequest, HttpResponse

Expand All @@ -16,6 +27,11 @@
from ninja_extra.controllers.base import APIController, ControllerBase
from ninja_extra.controllers.route import Route
from ninja_extra.operation import Operation
from ninja_extra.permissions import BasePermission

RoutePermissions = Optional[
List[Union[Type["BasePermission"], "BasePermission", Any]]
]


class RouteFunctionContext:
Expand Down Expand Up @@ -104,6 +120,28 @@ def as_view(
as_view.get_route_function = lambda: self # type:ignore
return as_view

def clone(self, view_func: Callable[..., Any]) -> "RouteFunction":
from ninja_extra.controllers.route import Route

route_params = self.route.route_params.dict()
permissions: RoutePermissions
if self.route.permissions is None:
permissions = None
else:
permissions = cast(RoutePermissions, list(self.route.permissions))

if route_params["tags"] is not None:
route_params["tags"] = list(route_params["tags"])
route_params["methods"] = list(route_params["methods"])

cloned_route = Route(
view_func,
**route_params,
permissions=permissions,
)

return type(self)(route=cloned_route)

def _process_view_function_result(self, result: Any) -> Any:
"""
This process any a returned value from view_func
Expand Down
17 changes: 14 additions & 3 deletions ninja_extra/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,17 @@ def get_function_name(func_class: t.Any) -> str:

@t.no_type_check
def get_route_function(func: t.Callable) -> t.Optional["RouteFunction"]:
if hasattr(func, ROUTE_FUNCTION):
return func.__dict__[ROUTE_FUNCTION]
return None # pragma: no cover
controller_instance = getattr(func, "__self__", None)

if controller_instance is not None:
controller_class = controller_instance.__class__
api_controller = controller_class.get_api_controller()
return api_controller._controller_class_route_functions.get(func.__name__)

# Unbound function – return a clone of the template for introspection
underlying_func = getattr(func, "__func__", func)
route_template = getattr(underlying_func, ROUTE_FUNCTION, None)
if route_template is None:
return None # pragma: no cover

return route_template.clone(underlying_func)
17 changes: 16 additions & 1 deletion ninja_extra/testing/client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from json import dumps as json_dumps
from typing import Any, Callable, Dict, Optional, Type, Union
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
from unittest.mock import Mock
from urllib.parse import urlencode

from django.urls import Resolver404
from ninja import NinjaAPI, Router
from ninja.responses import NinjaJSONEncoder
from ninja.testing.client import NinjaClientBase, NinjaResponse
Expand Down Expand Up @@ -42,6 +43,20 @@ def request(
)
return self._call(func, request, kwargs) # type: ignore

def _resolve(
self, method: str, path: str, data: Dict, request_params: Any
) -> Tuple[Callable, Mock, Dict]:
url_path = path.split("?")[0].lstrip("/")
for url in self.urls:
try:
match = url.resolve(url_path)
except Resolver404:
continue
if match:
request = self._build_request(method, path, data, request_params)
return match.func, request, match.kwargs
raise Exception(f'Cannot resolve "{path}"')


class TestClient(NinjaExtraClientBase):
def _call(self, func: Callable, request: Mock, kwargs: Dict) -> "NinjaResponse":
Expand Down
59 changes: 59 additions & 0 deletions tests/test_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,22 @@ def example(self):
pass


class ReportControllerBase(ControllerBase):
@http_get("")
def report(self):
return {"controller": type(self).__name__}


@api_controller("/alpha", urls_namespace="alpha")
class AlphaReportController(ReportControllerBase):
pass


@api_controller("/beta", urls_namespace="beta")
class BetaReportController(ReportControllerBase):
pass


class TestAPIController:
def test_api_controller_as_decorator(self):
controller_type = api_controller("prefix", tags="new_tag", auth=FakeAuth())(
Expand Down Expand Up @@ -321,6 +337,49 @@ async def test_controller_base_aget_object_or_none_works(self):
assert isinstance(ex, exceptions.PermissionDenied)


def test_controller_subclass_routes_remain_isolated():
api = NinjaExtraAPI()
api.register_controllers(AlphaReportController)
api.register_controllers(BetaReportController)
client = testing.TestClient(api)

alpha_response = client.get("/alpha")
beta_response = client.get("/beta")

assert alpha_response.status_code == 200
assert beta_response.status_code == 200
assert alpha_response.json() == {"controller": "AlphaReportController"}
assert beta_response.json() == {"controller": "BetaReportController"}


def test_controller_multi_level_inheritance_routes_isolated():
"""Test that route isolation works with multi-level inheritance."""
# Middle layer doesn't override the route
class MiddleReportController(ReportControllerBase):
pass

@api_controller("/gamma")
class GammaReportController(MiddleReportController):
pass

@api_controller("/delta")
class DeltaReportController(MiddleReportController):
pass

api = NinjaExtraAPI()
api.register_controllers(GammaReportController)
api.register_controllers(DeltaReportController)
client = testing.TestClient(api)

gamma_response = client.get("/gamma")
delta_response = client.get("/delta")

assert gamma_response.status_code == 200
assert delta_response.status_code == 200
assert gamma_response.json() == {"controller": "GammaReportController"}
assert delta_response.json() == {"controller": "DeltaReportController"}


def test_controller_registration_through_string():
assert DisableAutoImportController.get_api_controller().registered is False

Expand Down