From 85ab4c8308e5a9b64eb92dc307156eb3a364e670 Mon Sep 17 00:00:00 2001 From: Smita Ambiger Date: Thu, 12 Feb 2026 19:34:26 +0530 Subject: [PATCH 1/3] core: add AST-based linter for undeclared state reads in function-based actions --- burr/core/action.py | 58 +++++++++++++++++++++++++- tests/core/test_action_reads_linter.py | 23 ++++++++++ 2 files changed, 80 insertions(+), 1 deletion(-) create mode 100644 tests/core/test_action_reads_linter.py diff --git a/burr/core/action.py b/burr/core/action.py index 69a7c75b6..a2d2d875b 100644 --- a/burr/core/action.py +++ b/burr/core/action.py @@ -49,6 +49,56 @@ from typing import Self from burr.core.state import State + +def _validate_declared_reads(fn: Callable, declared_reads: list[str]) -> None: + try: + source = inspect.getsource(fn) + except OSError: + return # skip if source unavailable + + # detect actual state parameter name + sig = inspect.signature(fn) + state_param_name = None + + for name, param in sig.parameters.items(): + if param.annotation is State: + state_param_name = name + break + + if state_param_name is None: + return + + + import textwrap + tree = ast.parse(textwrap.dedent(source)) + + declared = set(declared_reads) + violations = [] + + class Visitor(ast.NodeVisitor): + def visit_Subscript(self, node): + if ( + isinstance(node.value, ast.Name) + and node.value.id == state_param_name + + and isinstance(node.slice, ast.Constant) + and isinstance(node.slice.value, str) + ): + key = node.slice.value + if key not in declared: + violations.append(key) + self.generic_visit(node) + + Visitor().visit(tree) + + if violations: + raise ValueError( + f"Action reads undeclared state keys: {violations}. " + f"Declared reads: {declared_reads}" + ) + + + from burr.core.typing import ActionSchema # This is here to make accessing the pydantic actions easier @@ -628,6 +678,8 @@ def __init__( self._fn = fn self._reads = reads self._writes = writes + _validate_declared_reads(self._originating_fn, self._reads) + self._bound_params = bound_params if bound_params is not None else {} self._inputs = ( derive_inputs_from_fn(self._bound_params, self._fn) @@ -1106,9 +1158,13 @@ def __init__( :param writes: """ super(FunctionBasedStreamingAction, self).__init__() + self._originating_fn = originating_fn if originating_fn is not None else fn self._fn = fn self._reads = reads self._writes = writes + _validate_declared_reads(self._originating_fn, self._reads) + + self._bound_params = bound_params if bound_params is not None else {} self._inputs = ( derive_inputs_from_fn(self._bound_params, self._fn) @@ -1118,7 +1174,7 @@ def __init__( [item for item in input_spec[1] if item not in self._bound_params], ) ) - self._originating_fn = originating_fn if originating_fn is not None else fn + # self._originating_fn = originating_fn if originating_fn is not None else fn self._schema = schema self._tags = tags if tags is not None else [] diff --git a/tests/core/test_action_reads_linter.py b/tests/core/test_action_reads_linter.py new file mode 100644 index 000000000..f8874ee6c --- /dev/null +++ b/tests/core/test_action_reads_linter.py @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements... + + +import pytest +from burr.core.action import action +from burr.core.state import State + + +def test_undeclared_state_read_raises_error(): + with pytest.raises(ValueError): + + @action(reads=["foo"], writes=[]) + def bad_action(state: State): + x = state["bar"] + return {}, state + + +def test_declared_state_read_passes(): + @action(reads=["foo"], writes=[]) + def good_action(state: State): + x = state["foo"] + return {}, state From 957d8f705ad025c525b44cf5aeb7b31a89ac4e87 Mon Sep 17 00:00:00 2001 From: Smita Ambiger Date: Thu, 12 Feb 2026 20:27:49 +0530 Subject: [PATCH 2/3] core: cleanup duplicate originating_fn assignment --- burr/core/action.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/burr/core/action.py b/burr/core/action.py index a2d2d875b..d45af0d31 100644 --- a/burr/core/action.py +++ b/burr/core/action.py @@ -1174,7 +1174,7 @@ def __init__( [item for item in input_spec[1] if item not in self._bound_params], ) ) - # self._originating_fn = originating_fn if originating_fn is not None else fn + self._schema = schema self._tags = tags if tags is not None else [] From d80246f0d6fd764f1bfe0ae8b1c26c682ea3b701 Mon Sep 17 00:00:00 2001 From: Smita Ambiger Date: Tue, 24 Feb 2026 12:01:20 +0530 Subject: [PATCH 3/3] fix: address review feedback and add regression tests --- burr/core/action.py | 14 +++---- tests/core/test_action.py | 55 ++++++++++++++++++++++++++ tests/core/test_action_reads_linter.py | 23 ----------- 3 files changed, 62 insertions(+), 30 deletions(-) delete mode 100644 tests/core/test_action_reads_linter.py diff --git a/burr/core/action.py b/burr/core/action.py index d45af0d31..b2e7c16d3 100644 --- a/burr/core/action.py +++ b/burr/core/action.py @@ -21,6 +21,7 @@ import copy import inspect import sys +import textwrap import types import typing from collections.abc import AsyncIterator @@ -50,12 +51,16 @@ from burr.core.state import State + def _validate_declared_reads(fn: Callable, declared_reads: list[str]) -> None: + if not declared_reads: + return + try: source = inspect.getsource(fn) except OSError: return # skip if source unavailable - + # detect actual state parameter name sig = inspect.signature(fn) state_param_name = None @@ -68,8 +73,6 @@ def _validate_declared_reads(fn: Callable, declared_reads: list[str]) -> None: if state_param_name is None: return - - import textwrap tree = ast.parse(textwrap.dedent(source)) declared = set(declared_reads) @@ -80,7 +83,6 @@ def visit_Subscript(self, node): if ( isinstance(node.value, ast.Name) and node.value.id == state_param_name - and isinstance(node.slice, ast.Constant) and isinstance(node.slice.value, str) ): @@ -98,7 +100,6 @@ def visit_Subscript(self, node): ) - from burr.core.typing import ActionSchema # This is here to make accessing the pydantic actions easier @@ -1164,7 +1165,6 @@ def __init__( self._writes = writes _validate_declared_reads(self._originating_fn, self._reads) - self._bound_params = bound_params if bound_params is not None else {} self._inputs = ( derive_inputs_from_fn(self._bound_params, self._fn) @@ -1174,7 +1174,7 @@ def __init__( [item for item in input_spec[1] if item not in self._bound_params], ) ) - + self._schema = schema self._tags = tags if tags is not None else [] diff --git a/tests/core/test_action.py b/tests/core/test_action.py index fd0ed36b9..83fecf3b3 100644 --- a/tests/core/test_action.py +++ b/tests/core/test_action.py @@ -823,3 +823,58 @@ def fn(state, a): required, optional = derive_inputs_from_fn(bound_params, fn) assert required == [] assert optional == [] + + +def test_undeclared_state_read_raises_error(): + with pytest.raises(ValueError): + + @action(reads=["foo"], writes=[]) + def bad_action(state: State): + _ = state["bar"] + return {}, state + + +def test_declared_state_read_passes(): + @action(reads=["foo"], writes=[]) + def good_action(state: State): + _ = state["foo"] + return {}, state + + +def test_multiple_undeclared_reads_interleaved(): + with pytest.raises(ValueError) as exc: + + @action(reads=["foo"], writes=[]) + def bad_action(state: State): + _ = state["foo"] + _ = state["bar"] + _ = state["baz"] + return {}, state + + message = str(exc.value) + assert "bar" in message + assert "baz" in message + + +def test_pydantic_action_not_impacted(): + try: + from pydantic import BaseModel + except ImportError: + pytest.skip("pydantic not installed") + + class MyState(BaseModel): + foo: str + + @action.pydantic( + reads=["foo"], + writes=["foo"], + state_input_type=MyState, + state_output_type=MyState, + ) + def good_action(state: MyState): + return {"foo": state.foo} + + # ensure decoration didn't raise and action is creatable + from burr.core.action import create_action + + create_action(good_action, name="test") diff --git a/tests/core/test_action_reads_linter.py b/tests/core/test_action_reads_linter.py deleted file mode 100644 index f8874ee6c..000000000 --- a/tests/core/test_action_reads_linter.py +++ /dev/null @@ -1,23 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements... - - -import pytest -from burr.core.action import action -from burr.core.state import State - - -def test_undeclared_state_read_raises_error(): - with pytest.raises(ValueError): - - @action(reads=["foo"], writes=[]) - def bad_action(state: State): - x = state["bar"] - return {}, state - - -def test_declared_state_read_passes(): - @action(reads=["foo"], writes=[]) - def good_action(state: State): - x = state["foo"] - return {}, state