diff --git a/ninja/params/models.py b/ninja/params/models.py index bba135345..7a70e938d 100644 --- a/ninja/params/models.py +++ b/ninja/params/models.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from collections import defaultdict from typing import ( TYPE_CHECKING, Any, @@ -39,10 +38,6 @@ TModels = List[TModel] -def NestedDict() -> DictStrAny: - return defaultdict(NestedDict) - - class ParamModel(BaseModel, ABC): __ninja_param_source__ = None @@ -65,11 +60,6 @@ def resolve( return cls() data = cls._map_data_paths(data) - # Convert defaultdict to dict for pydantic 2.12+ compatibility - # In pydantic 2.12+, accessing missing keys in defaultdict creates nested - # defaultdicts which then fail validation - if isinstance(data, defaultdict): - data = dict(data) return cls.model_validate(data, context={"request": request}) @classmethod @@ -78,22 +68,20 @@ def _map_data_paths(cls, data: DictStrAny) -> DictStrAny: if not flatten_map: return data - mapped_data: DictStrAny = NestedDict() - for k in flatten_map: - if k in data: - cls._map_data_path(mapped_data, data[k], flatten_map[k]) - else: - cls._map_data_path(mapped_data, None, flatten_map[k]) - + mapped_data: DictStrAny = {} + for key, path in flatten_map.items(): + cls._map_data_path(mapped_data, data.get(key), path) return mapped_data @classmethod - def _map_data_path(cls, data: DictStrAny, value: Any, path: Tuple) -> None: - if len(path) == 1: - if value is not None: - data[path[0]] = value - else: - cls._map_data_path(data[path[0]], value, path[1:]) + def _map_data_path( + cls, data: DictStrAny, value: Any, path: Tuple[str, ...] + ) -> None: + current = data + for key in path[:-1]: + current = current.setdefault(key, {}) + if value is not None: + current[path[-1]] = value class QueryModel(ParamModel): diff --git a/tests/test_params_models.py b/tests/test_params_models.py new file mode 100644 index 000000000..d73a1488b --- /dev/null +++ b/tests/test_params_models.py @@ -0,0 +1,23 @@ +from typing import Optional + +from ninja.params.models import DictStrAny, ParamModel + + +class _NestedParamModel(ParamModel): + outer: DictStrAny + leaf: Optional[int] + + __ninja_flatten_map__ = { + "foo": ("outer", "foo"), + "bar": ("outer", "bar"), + "leaf": ("leaf",), + } + + +def test_map_data_paths_creates_parent_for_missing_nested_values(): + assert _NestedParamModel._map_data_paths({}) == {"outer": {}} + + +def test_map_data_paths_sets_values_when_present(): + data = _NestedParamModel._map_data_paths({"foo": 1, "leaf": 2}) + assert data == {"outer": {"foo": 1}, "leaf": 2} diff --git a/tests/test_query_schema.py b/tests/test_query_schema.py index 465e6779b..335ddcbda 100644 --- a/tests/test_query_schema.py +++ b/tests/test_query_schema.py @@ -1,9 +1,10 @@ from datetime import datetime from enum import IntEnum -from pydantic import Field +from pydantic import BaseModel, Field from ninja import NinjaAPI, Query, Schema +from ninja.testing.client import TestClient class Range(IntEnum): @@ -12,7 +13,7 @@ class Range(IntEnum): TWO_HUNDRED = 200 -class Filter(Schema): +class Filter(BaseModel): to_datetime: datetime = Field(alias="to") from_datetime: datetime = Field(alias="from") range: Range = Range.TWENTY @@ -28,7 +29,7 @@ class Data(Schema): @api.get("/test") def query_params_schema(request, filters: Filter = Query(...)): - return filters.dict() + return filters.model_dump() @api.get("/test-mixed") @@ -39,57 +40,80 @@ def query_params_mixed_schema( filters: Filter = Query(...), data: Data = Query(...), ): - return dict(query1=query1, query2=query2, filters=filters.dict(), data=data.dict()) - - -# def test_request(): -# client = TestClient(api) -# response = client.get("/test?from=1&to=2&range=20&foo=1&range2=50") -# print("!", response.json()) -# assert response.json() == { -# "to_datetime": "1970-01-01T00:00:02Z", -# "from_datetime": "1970-01-01T00:00:01Z", -# "range": 20, -# } - -# response = client.get("/test?from=1&to=2&range=21") -# assert response.status_code == 422 - - -# def test_request_mixed(): -# client = TestClient(api) -# response = client.get( -# "/test-mixed?from=1&to=2&range=20&foo=1&range2=50&query1=2&int=3&float=1.6" -# ) -# print(response.json()) -# assert response.json() == { -# "data": {"a_float": 1.6, "an_int": 3}, -# "filters": { -# "from_datetime": "1970-01-01T00:00:01Z", -# "range": 20, -# "to_datetime": "1970-01-01T00:00:02Z", -# }, -# "query1": 2, -# "query2": 5, -# } - -# response = client.get( -# "/test-mixed?from=1&to=2&range=20&foo=1&range2=50&query1=2&query2=10" -# ) -# print(response.json()) -# assert response.json() == { -# "data": {"a_float": 1.5, "an_int": 0}, -# "filters": { -# "from_datetime": "1970-01-01T00:00:01Z", -# "range": 20, -# "to_datetime": "1970-01-01T00:00:02Z", -# }, -# "query1": 2, -# "query2": 10, -# } - -# response = client.get("/test-mixed?from=1&to=2") -# assert response.status_code == 422 + return dict( + query1=query1, + query2=query2, + filters=filters.model_dump(), + data=data.model_dump(), + ) + + +def test_request(): + client = TestClient(api) + response = client.get("/test?from=1&to=2&range=20&foo=1&range2=50") + print("!", response.json()) + assert response.json() == { + "to_datetime": "1970-01-01T00:00:02Z", + "from_datetime": "1970-01-01T00:00:01Z", + "range": 20, + } + + response = client.get("/test?from=1&to=2&range=21") + assert response.status_code == 422 + + +def test_request_mixed(): + client = TestClient(api) + response = client.get( + "/test-mixed?from=1&to=2&range=20&foo=1&range2=50&query1=2&int=3&float=1.6" + ) + print(response.json()) + assert response.json() == { + "data": {"a_float": 1.6, "an_int": 3}, + "filters": { + "from_datetime": "1970-01-01T00:00:01Z", + "range": 20, + "to_datetime": "1970-01-01T00:00:02Z", + }, + "query1": 2, + "query2": 5, + } + + response = client.get( + "/test-mixed?from=1&to=2&range=20&foo=1&range2=50&query1=2&query2=10" + ) + print(response.json()) + assert response.json() == { + "data": {"a_float": 1.5, "an_int": 0}, + "filters": { + "from_datetime": "1970-01-01T00:00:01Z", + "range": 20, + "to_datetime": "1970-01-01T00:00:02Z", + }, + "query1": 2, + "query2": 10, + } + + response = client.get("/test-mixed?from=1&to=2") + assert response.status_code == 422 + + +def test_request_query_params_using_basemodel(): + class Foo(BaseModel): + start: int + optional: int = 42 + + temp_api = NinjaAPI() + + @temp_api.get("/foo") + def view(request, foo: Foo = Query(...)): + return foo.model_dump() + + client = TestClient(temp_api) + resp = client.get("/foo?start=1") + + assert resp.status_code == 200 + assert resp.json() == {"start": 1, "optional": 42} def test_schema():