|
| 1 | +"""Test data builders — auto-generate dataclass instances from type hints.""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +import dataclasses |
| 6 | +import random |
| 7 | +import string |
| 8 | +import uuid |
| 9 | +from datetime import date, datetime |
| 10 | +from typing import Any, TypeVar, get_type_hints |
| 11 | + |
| 12 | +T = TypeVar("T") |
| 13 | + |
| 14 | +_counter = 0 |
| 15 | + |
| 16 | + |
| 17 | +def _next_id() -> int: |
| 18 | + global _counter # noqa: PLW0603 |
| 19 | + _counter += 1 |
| 20 | + return _counter |
| 21 | + |
| 22 | + |
| 23 | +def _random_string(length: int = 8) -> str: |
| 24 | + return "".join(random.choices(string.ascii_lowercase, k=length)) |
| 25 | + |
| 26 | + |
| 27 | +def _generate_value(type_hint: Any, field_name: str = "") -> Any: |
| 28 | + """Generate a plausible value for a given type hint.""" |
| 29 | + origin = getattr(type_hint, "__origin__", None) |
| 30 | + args = getattr(type_hint, "__args__", ()) |
| 31 | + |
| 32 | + # Handle Union types (Optional[X], X | Y, etc.) |
| 33 | + import types |
| 34 | + import typing |
| 35 | + if origin is typing.Union or isinstance(type_hint, types.UnionType): |
| 36 | + type_args = args if args else getattr(type_hint, "__args__", ()) |
| 37 | + non_none = [a for a in type_args if a is not type(None)] |
| 38 | + if non_none: |
| 39 | + return _generate_value(non_none[0], field_name) |
| 40 | + return None |
| 41 | + |
| 42 | + if type_hint is str: |
| 43 | + return f"{field_name}_{_random_string()}" if field_name else _random_string() |
| 44 | + if type_hint is int: |
| 45 | + return _next_id() |
| 46 | + if type_hint is float: |
| 47 | + return round(random.uniform(0.0, 100.0), 2) |
| 48 | + if type_hint is bool: |
| 49 | + return random.choice([True, False]) |
| 50 | + if type_hint is bytes: |
| 51 | + return _random_string().encode() |
| 52 | + if type_hint is datetime: |
| 53 | + return datetime(2024, 1, 1, 12, 0, 0) |
| 54 | + if type_hint is date: |
| 55 | + return date(2024, 1, 1) |
| 56 | + |
| 57 | + # list[X] |
| 58 | + if origin is list: |
| 59 | + inner = args[0] if args else str |
| 60 | + return [_generate_value(inner, field_name) for _ in range(2)] |
| 61 | + |
| 62 | + # dict[K, V] |
| 63 | + if origin is dict: |
| 64 | + k_type = args[0] if args else str |
| 65 | + v_type = args[1] if len(args) > 1 else str # type: ignore[misc] |
| 66 | + return {_generate_value(k_type): _generate_value(v_type) for _ in range(2)} |
| 67 | + |
| 68 | + # set[X] |
| 69 | + if origin is set: |
| 70 | + inner = args[0] if args else str |
| 71 | + return {_generate_value(inner, field_name) for _ in range(2)} |
| 72 | + |
| 73 | + # tuple[X, ...] |
| 74 | + if origin is tuple: |
| 75 | + if args: |
| 76 | + return tuple(_generate_value(a, field_name) for a in args if a is not Ellipsis) |
| 77 | + return () |
| 78 | + |
| 79 | + # UUID |
| 80 | + if type_hint is uuid.UUID: |
| 81 | + return uuid.uuid4() |
| 82 | + |
| 83 | + # Nested dataclass |
| 84 | + if dataclasses.is_dataclass(type_hint) and isinstance(type_hint, type): |
| 85 | + return _create_instance(type_hint) |
| 86 | + |
| 87 | + # Fallback |
| 88 | + return None |
| 89 | + |
| 90 | + |
| 91 | +def _create_instance(cls: type[T], overrides: dict[str, Any] | None = None) -> T: |
| 92 | + """Create an instance of *cls* with auto-generated values.""" |
| 93 | + hints = get_type_hints(cls) |
| 94 | + kwargs: dict[str, Any] = {} |
| 95 | + |
| 96 | + for field in dataclasses.fields(cls): # type: ignore[arg-type] |
| 97 | + if overrides and field.name in overrides: |
| 98 | + kwargs[field.name] = overrides[field.name] |
| 99 | + elif field.default is not dataclasses.MISSING: |
| 100 | + kwargs[field.name] = field.default |
| 101 | + elif field.default_factory is not dataclasses.MISSING: |
| 102 | + kwargs[field.name] = field.default_factory() |
| 103 | + else: |
| 104 | + type_hint = hints.get(field.name, str) |
| 105 | + kwargs[field.name] = _generate_value(type_hint, field.name) |
| 106 | + |
| 107 | + return cls(**kwargs) |
| 108 | + |
| 109 | + |
| 110 | +class Fixture: |
| 111 | + """Auto-generate test data from dataclass type hints.""" |
| 112 | + |
| 113 | + @staticmethod |
| 114 | + def create(cls: type[T], **overrides: Any) -> T: |
| 115 | + """Create a single instance of *cls*.""" |
| 116 | + if not dataclasses.is_dataclass(cls): |
| 117 | + raise TypeError(f"{cls.__name__} is not a dataclass") |
| 118 | + return _create_instance(cls, overrides or None) |
| 119 | + |
| 120 | + @staticmethod |
| 121 | + def create_many(cls: type[T], count: int = 3, **overrides: Any) -> list[T]: |
| 122 | + """Create *count* instances of *cls*.""" |
| 123 | + return [Fixture.create(cls, **overrides) for _ in range(count)] |
| 124 | + |
| 125 | + |
| 126 | +class Builder: |
| 127 | + """Fluent builder for constructing test data.""" |
| 128 | + |
| 129 | + def __init__(self, cls: type[Any]) -> None: |
| 130 | + if not dataclasses.is_dataclass(cls): |
| 131 | + raise TypeError(f"{cls.__name__} is not a dataclass") |
| 132 | + self._cls = cls |
| 133 | + self._overrides: dict[str, Any] = {} |
| 134 | + |
| 135 | + def with_field(self, name: str, value: Any) -> Builder: |
| 136 | + """Set a specific field value.""" |
| 137 | + self._overrides[name] = value |
| 138 | + return self |
| 139 | + |
| 140 | + def build(self) -> Any: |
| 141 | + """Build the instance.""" |
| 142 | + return _create_instance(self._cls, self._overrides) |
| 143 | + |
| 144 | + def build_many(self, count: int = 3) -> list[Any]: |
| 145 | + """Build *count* instances.""" |
| 146 | + return [self.build() for _ in range(count)] |
0 commit comments