diff --git a/burr/core/action.py b/burr/core/action.py index 69a7c75b..b2e7c16d 100644 --- a/burr/core/action.py +++ b/burr/core/action.py @@ -21,6 +21,7 @@ import copy import inspect import sys +import textwrap import types import typing from collections.abc import AsyncIterator @@ -49,6 +50,56 @@ from typing import Self from burr.core.state import State + + +def _validate_declared_reads(fn: Callable, declared_reads: list[str]) -> None: + if not declared_reads: + return + + try: + source = inspect.getsource(fn) + except OSError: + return # skip if source unavailable + + # detect actual state parameter name + sig = inspect.signature(fn) + state_param_name = None + + for name, param in sig.parameters.items(): + if param.annotation is State: + state_param_name = name + break + + if state_param_name is None: + return + + tree = ast.parse(textwrap.dedent(source)) + + declared = set(declared_reads) + violations = [] + + class Visitor(ast.NodeVisitor): + def visit_Subscript(self, node): + if ( + isinstance(node.value, ast.Name) + and node.value.id == state_param_name + and isinstance(node.slice, ast.Constant) + and isinstance(node.slice.value, str) + ): + key = node.slice.value + if key not in declared: + violations.append(key) + self.generic_visit(node) + + Visitor().visit(tree) + + if violations: + raise ValueError( + f"Action reads undeclared state keys: {violations}. " + f"Declared reads: {declared_reads}" + ) + + from burr.core.typing import ActionSchema # This is here to make accessing the pydantic actions easier @@ -628,6 +679,8 @@ def __init__( self._fn = fn self._reads = reads self._writes = writes + _validate_declared_reads(self._originating_fn, self._reads) + self._bound_params = bound_params if bound_params is not None else {} self._inputs = ( derive_inputs_from_fn(self._bound_params, self._fn) @@ -1106,9 +1159,12 @@ def __init__( :param writes: """ super(FunctionBasedStreamingAction, self).__init__() + self._originating_fn = originating_fn if originating_fn is not None else fn self._fn = fn self._reads = reads self._writes = writes + _validate_declared_reads(self._originating_fn, self._reads) + self._bound_params = bound_params if bound_params is not None else {} self._inputs = ( derive_inputs_from_fn(self._bound_params, self._fn) @@ -1118,7 +1174,7 @@ def __init__( [item for item in input_spec[1] if item not in self._bound_params], ) ) - self._originating_fn = originating_fn if originating_fn is not None else fn + self._schema = schema self._tags = tags if tags is not None else [] diff --git a/tests/core/test_action.py b/tests/core/test_action.py index fd0ed36b..83fecf3b 100644 --- a/tests/core/test_action.py +++ b/tests/core/test_action.py @@ -823,3 +823,58 @@ def fn(state, a): required, optional = derive_inputs_from_fn(bound_params, fn) assert required == [] assert optional == [] + + +def test_undeclared_state_read_raises_error(): + with pytest.raises(ValueError): + + @action(reads=["foo"], writes=[]) + def bad_action(state: State): + _ = state["bar"] + return {}, state + + +def test_declared_state_read_passes(): + @action(reads=["foo"], writes=[]) + def good_action(state: State): + _ = state["foo"] + return {}, state + + +def test_multiple_undeclared_reads_interleaved(): + with pytest.raises(ValueError) as exc: + + @action(reads=["foo"], writes=[]) + def bad_action(state: State): + _ = state["foo"] + _ = state["bar"] + _ = state["baz"] + return {}, state + + message = str(exc.value) + assert "bar" in message + assert "baz" in message + + +def test_pydantic_action_not_impacted(): + try: + from pydantic import BaseModel + except ImportError: + pytest.skip("pydantic not installed") + + class MyState(BaseModel): + foo: str + + @action.pydantic( + reads=["foo"], + writes=["foo"], + state_input_type=MyState, + state_output_type=MyState, + ) + def good_action(state: MyState): + return {"foo": state.foo} + + # ensure decoration didn't raise and action is creatable + from burr.core.action import create_action + + create_action(good_action, name="test")