diff --git a/graph.py b/graph.py index 0946e89..23537d9 100644 --- a/graph.py +++ b/graph.py @@ -13,13 +13,15 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -from typing import Optional, List, TypedDict, Annotated, Literal, TypeVar, Type, Protocol, cast, Any, Tuple, NotRequired, Iterable, Callable, Generator, Awaitable, Coroutine +from typing import Optional, List, TypedDict, Annotated, Literal, TypeVar, Type, Protocol, cast, Any, Tuple, NotRequired, Iterable, Generic, Callable, Generator, Awaitable, Coroutine from langchain_core.messages import ToolMessage, AnyMessage, SystemMessage, HumanMessage, BaseMessage, AIMessage, RemoveMessage from langchain_core.tools import InjectedToolCallId, BaseTool from langchain_core.language_models.base import LanguageModelInput from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.runnables import Runnable from langgraph.graph import StateGraph, MessagesState +from langgraph.graph.state import CompiledStateGraph +from langgraph.types import Checkpointer from langgraph._internal._typing import StateLike from langgraph.types import Command from langgraph.prebuilt import ToolNode @@ -409,6 +411,196 @@ def __call__( SplitTool = tuple[dict[str, Any], BaseTool] +_BStateT = TypeVar("_BStateT", bound=MessagesState | None) +_BStateBind = TypeVar("_BStateBind", bound=MessagesState) + +_BContextT = TypeVar("_BContextT", bound=StateLike | None) +_BContextBind = TypeVar("_BContextBind", bound=StateLike) + +_BInputT = TypeVar("_BInputT", bound=FlowInput | None) +_BInputTBind = TypeVar("_BInputTBind", bound=FlowInput) + + +class TemplateLoader(Protocol): + def __call__(self, template_name: str, **kwargs: Any) -> str: ... + + +class Builder( + Generic[_BStateT, _BContextT, _BInputT] +): + def __init__(self): + self._initial_prompt : str | None = None + self._sys_prompt : str | None = None + + self._summary_config : SummaryConfig[_BStateT] | None = None + self._unbound_llm : BaseChatModel | None = None + + self._state_class: type[_BStateT] | None = None + self._input_type : type[_BInputT] | None = None + self._context_type : type[_BContextT] | None = None + self._output_key : str | None = None + self._tools : list[BaseTool | SplitTool] = [] + self._loader : TemplateLoader | None = None + + def _copy_untyped_to_(self, other: "Builder[Any, Any, Any]"): + other._initial_prompt = self._initial_prompt + other._sys_prompt = self._sys_prompt + other._unbound_llm = self._unbound_llm + other._output_key = self._output_key + other._tools.extend(self._tools) + other._loader = self._loader + + def _copy_typed_to(self, other: "Builder[_BStateT, _BContextT, _BInputT]"): + other._state_class = self._state_class + other._context_type = self._context_type + other._input_type = self._input_type + other._summary_config = self._summary_config + + def with_state(self, t: type[_BStateBind]) -> "Builder[_BStateBind, _BContextT, _BInputT]": + to_ret: "Builder[_BStateBind, _BContextT, _BInputT]" = Builder() + self._copy_untyped_to_(to_ret) + to_ret._state_class = t + to_ret._context_type = self._context_type + to_ret._input_type = self._input_type + return to_ret + + def with_context(self, t: type[_BContextBind]) -> "Builder[_BStateT, _BContextBind, _BInputT]": + to_ret: "Builder[_BStateT, _BContextBind, _BInputT]" = Builder() + self._copy_untyped_to_(to_ret) + to_ret._state_class = self._state_class + to_ret._context_type = t + to_ret._input_type = self._input_type + to_ret._summary_config = self._summary_config + return to_ret + + def with_input(self, t: type[_BInputTBind]) -> "Builder[_BStateT, _BContextT, _BInputTBind]": + to_ret: "Builder[_BStateT, _BContextT, _BInputTBind]" = Builder() + self._copy_untyped_to_(to_ret) + to_ret._state_class = self._state_class + to_ret._context_type = self._context_type + to_ret._input_type = t + to_ret._summary_config = self._summary_config + return to_ret + + def with_initial_prompt(self, prompt: str) -> "Builder[_BStateT, _BContextT, _BInputT]": + to_ret: "Builder[_BStateT, _BContextT, _BInputT]" = Builder() + self._copy_untyped_to_(to_ret) + self._copy_typed_to(to_ret) + to_ret._initial_prompt = prompt + return to_ret + + def with_initial_prompt_template(self, template: str, **kwargs) -> "Builder[_BStateT, _BContextT, _BInputT]": + if self._loader is None: + raise ValueError("No loader configured. Use with_loader first.") + return self.with_initial_prompt(self._loader(template, **kwargs)) + + def with_sys_prompt(self, prompt: str) -> "Builder[_BStateT, _BContextT, _BInputT]": + to_ret: "Builder[_BStateT, _BContextT, _BInputT]" = Builder() + self._copy_untyped_to_(to_ret) + self._copy_typed_to(to_ret) + to_ret._sys_prompt = prompt + return to_ret + + def with_sys_prompt_template(self, template: str, **kwargs) -> "Builder[_BStateT, _BContextT, _BInputT]": + if self._loader is None: + raise ValueError("No loader configured. Use with_loader first.") + return self.with_sys_prompt(self._loader(template, **kwargs)) + + def with_loader(self, loader: TemplateLoader) -> "Builder[_BStateT, _BContextT, _BInputT]": + to_ret: "Builder[_BStateT, _BContextT, _BInputT]" = Builder() + self._copy_untyped_to_(to_ret) + self._copy_typed_to(to_ret) + to_ret._loader = loader + return to_ret + + def with_llm(self, llm: BaseChatModel) -> "Builder[_BStateT, _BContextT, _BInputT]": + to_ret: "Builder[_BStateT, _BContextT, _BInputT]" = Builder() + self._copy_untyped_to_(to_ret) + self._copy_typed_to(to_ret) + to_ret._unbound_llm = llm + return to_ret + + def with_output_key(self, key: str) -> "Builder[_BStateT, _BContextT, _BInputT]": + to_ret: "Builder[_BStateT, _BContextT, _BInputT]" = Builder() + self._copy_untyped_to_(to_ret) + self._copy_typed_to(to_ret) + to_ret._output_key = key + return to_ret + + def with_summary_config(self, config: SummaryConfig[_BStateT]) -> "Builder[_BStateT, _BContextT, _BInputT]": + to_ret: "Builder[_BStateT, _BContextT, _BInputT]" = Builder() + self._copy_untyped_to_(to_ret) + self._copy_typed_to(to_ret) + to_ret._summary_config = config + return to_ret + + def with_default_summarizer(self, *, max_messages: int = 20, enabled: bool = True) -> "Builder[_BStateT, _BContextT, _BInputT]": + return self.with_summary_config(SummaryConfig(max_messages=max_messages, enabled=enabled)) + + def with_tools(self, l: Iterable[BaseTool | SplitTool]) -> "Builder[_BStateT, _BContextT, _BInputT]": + to_ret: "Builder[_BStateT, _BContextT, _BInputT]" = Builder() + self._copy_typed_to(to_ret) + self._copy_untyped_to_(to_ret) + to_ret._tools.extend(l) + return to_ret + + def _build_internal(self, r: _ResultFact, i: _InitialFact, s: _SummarizerFact) -> Tuple["StateGraph[_BStateT, _BContextT, _BInputT, Any]", BoundLLM]: #type: ignore + if self._state_class is None: + raise ValueError("state_class is required") + if self._input_type is None: + raise ValueError("input_type is required") + if self._sys_prompt is None: + raise ValueError("sys_prompt is required") + if self._initial_prompt is None: + raise ValueError("initial_prompt is required") + if self._output_key is None: + raise ValueError("output_key is required") + if self._unbound_llm is None: + raise ValueError("unbound_llm is required") + + return _build_workflow( + state_class=self._state_class, #type: ignore + input_type=self._input_type, #type: ignore + tools_list=self._tools, + sys_prompt=self._sys_prompt, + initial_prompt=self._initial_prompt, + output_key=self._output_key, + unbound_llm=self._unbound_llm, + context_schema=self._context_type, + summary_config=self._summary_config, #type: ignore + init_fact=i, + result_fact=r, + summary_fact=s, + output_schema=None + ) + + def build(self) -> Tuple["StateGraph[_BStateT, _BContextT, _BInputT, Any]", BoundLLM]: #type: ignore + return self._build_internal( + s=get_summarizer, + i=initial_node, + r=tool_result_generator + ) + + def build_async(self) -> Tuple["StateGraph[_BStateT, _BContextT, _BInputT, Any]", BoundLLM]: #type: ignore + return self._build_internal( + s=get_async_summarizer, + i=async_initial_node, + r=async_tool_result_generator + ) + + def compile_async( + self, *, + checkpointer: Checkpointer = None + ) -> CompiledStateGraph[_BStateT, _BContextT, _BInputT, Any]: #type: ignore + return self.build_async()[0].compile( + checkpointer=checkpointer + ) + + def compile(self, checkpointer: Checkpointer = None) -> CompiledStateGraph[_BStateT, _BContextT, _BInputT, Any]: #type: ignore + return self.build()[0].compile( + checkpointer=checkpointer + ) + def build_workflow( state_class: Type[StateT], input_type: Type[InputState], diff --git a/summary.py b/summary.py index 4bd0b7c..bfb5182 100644 --- a/summary.py +++ b/summary.py @@ -17,7 +17,7 @@ from typing import Generic, TypeVar -StateT = TypeVar("StateT") +StateT = TypeVar("StateT", contravariant=True) logger = logging.getLogger(__name__) diff --git a/tools/results.py b/tools/results.py index 70ab37a..eed4929 100644 --- a/tools/results.py +++ b/tools/results.py @@ -34,7 +34,7 @@ def result_tool_generator( outkey: str, result_schema: type[M], doc: str, - validator: tuple[type[ST], Callable[[ST, M, str], ValidationResult]] + validator: tuple[type[ST], Callable[[ST, M, str], ValidationResult]] | Callable[[M, str], ValidationResult] | None = None ) -> BaseTool: """ Generates a tool that can be used to complete a workflow @@ -53,61 +53,12 @@ def result_tool_generator( """ ... -@overload -def result_tool_generator( - outkey: str, - result_schema: type[M], - doc: str, - validator: Callable[[M, str], ValidationResult] -) -> BaseTool: - """ - Generates a tool that can be used to complete a workflow - Args: - outkey (str): The name of the key in the state which holds the result, and whose presence signals - completion - result_schema (type[M]): A BaseModel type which is the type of the completed state. Each field of this - basemodel becomes a field in the generated tool schema, and so these fields SHOULD have string descriptions. - doc (str): The documentation to use for the generated tool - validator (Callable[[M, str], ValidationResult]): A validator which simply accepts the resultant basemodel - and the current tool call id, and return None if there is no issue, otherwise it may return a string - (which is returned as the result of the tool call WITHOUT setting outkey), or it may return an arbitrary command. - - Returns: - BaseTool: The generated result tool - """ - ... - - @overload def result_tool_generator( outkey: str, result_schema: tuple[type[R], str], doc: str, - validator: Callable[[R, str], ValidationResult] -) -> BaseTool: - """ - Generates a tool that can be used to complete a workflow - Args: - outkey (str): The name of the key in the state which holds the result, and whose presence signals - completion - result_schema (tuple[type[R], str]): A tuple of the desired result type, and a description of what the output - should be. - doc (str): The documentation to use for the generated tool - validator (Callable[[R, str], ValidationResult]): A validator which simply accepts the resultant value - and the current tool call id, and return None if there is no issue, otherwise it may return a string - (which is returned as the result of the tool call WITHOUT setting outkey), or it may return an arbitrary command. - - Returns: - BaseTool: The generated result tool - """ - ... - -@overload -def result_tool_generator( - outkey: str, - result_schema: tuple[type[R], str], - doc: str, - validator: tuple[type[ST], Callable[[ST, R, str], ValidationResult]] + validator: tuple[type[ST], Callable[[ST, R, str], ValidationResult]] | Callable[[R, str], ValidationResult] | None = None ) -> BaseTool: """ Generates a tool that can be used to complete a workflow @@ -127,26 +78,6 @@ def result_tool_generator( """ ... -@overload -def result_tool_generator( - outkey: str, - result_schema: type[BaseModel] | tuple[type, str], - doc: str, -) -> BaseTool: - """ - Generates a tool that can be used to complete a workflow - Args: - outkey (str): The name of the key in the state which holds the result, and whose presence signals - completion - result_schema (type[BaseModel] | tuple[type, str]): Either a BaseModel type (where each field becomes - a field in the generated tool schema) or a tuple of the desired result type and description. - doc (str): The documentation to use for the generated tool - - Returns: - BaseTool: The generated result tool - """ - ... - def result_tool_generator( outkey: str, result_schema: type[BaseModel] | tuple[type, str], diff --git a/tools/schemas.py b/tools/schemas.py new file mode 100644 index 0000000..9be7168 --- /dev/null +++ b/tools/schemas.py @@ -0,0 +1,71 @@ +from typing import Generic, TypeVar, Annotated, Any + +from pydantic import BaseModel + +from langchain_core.tools import InjectedToolCallId +from langgraph.prebuilt import InjectedState +from langgraph.types import Command +from langchain_core.tools import StructuredTool, BaseTool + +ST = TypeVar("ST") + +T_RES = TypeVar("T_RES", bound=str | Command) + +class WithInjectedState(BaseModel, Generic[ST]): + state: Annotated[ST, InjectedState] + +class WithInjectedId(BaseModel): + tool_call_id: Annotated[str, InjectedToolCallId] + +class WithImplementation(BaseModel, Generic[T_RES]): + def run(self) -> T_RES: + """Override this method to implement the tool logic.""" + raise NotImplementedError("Subclasses must implement run()") + + @classmethod + def as_tool( + cls, + name: str + ) -> BaseTool: + impl_method = getattr(cls, "run") + + # Simple wrapper - just accept kwargs, instantiate model, call method + def wrapper(**kwargs: Any) -> Any: + instance = cls(**kwargs) + return impl_method(instance) + + return StructuredTool.from_function( + func=wrapper, + args_schema=cls, + description=cls.__doc__, + name=name, + ) + +class WithAsyncImplementation(BaseModel, Generic[T_RES]): + async def run(self) -> T_RES: + """Override this method to implement the tool logic.""" + raise NotImplementedError("Subclasses must implement run()") + + @classmethod + def as_tool( + cls, + name: str + ) -> BaseTool: + impl_method = getattr(cls, "run") + + # Simple wrapper - just accept kwargs, instantiate model, call method + async def wrapper(**kwargs: Any) -> Any: + instance = cls(**kwargs) + d = await impl_method(instance) + return d + + return StructuredTool.from_function( + coroutine=wrapper, + args_schema=cls, + description=cls.__doc__, + name=name, + ) + + +class InjectAll(WithInjectedState[ST], WithInjectedId): + pass diff --git a/tools/vfs.py b/tools/vfs.py index e50d8e9..c78504c 100644 --- a/tools/vfs.py +++ b/tools/vfs.py @@ -27,9 +27,135 @@ from langchain_core.tools.base import BaseTool from langgraph.prebuilt import InjectedState from langgraph.types import Command +from ..graph import FlowInput from ..graph import tool_output + +def _copy_base_doc[T](cls: T) -> T: + """Decorator to copy __doc__ from the first base class.""" + for base in cls.__bases__: # type: ignore + if base.__doc__: + cls.__doc__ = base.__doc__ + break + return cls + + +# returns true if the file is okay to access +def _make_checker(patt: str | None) -> Callable[[str], bool]: + if patt is None: + return lambda f_name: True + match = re.compile(patt) + return lambda f_name: match.fullmatch(f_name) is None + +class FileRange(BaseModel): + start_line: int = Field(description="The line to start reading from; lines are numbered starting from 1.") + end_line: int = Field(description="The line to read until EXCLUSIVE.") + +def _get_file(cont: str | None, range: FileRange | None) -> str: + if cont is None: + return "File not found" + if not range: + return cont + start = range.start_line - 1 + to_ret = cont.splitlines()[start:range.end_line - 1] + return "\n".join(to_ret) + + + +def _grep_impl( + search_string: str, + matching_lines: bool, + file_contents: Iterator[tuple[str, str]], + get_content: Callable[[str], str | None], + match_in: list[str] | None +) -> str: + """ + Generic grep implementation over file contents. + + Args: + search_string: Regex pattern to search for + file_contents: Iterator of (filename, content) tuples + check_allowed: Filter function for allowed filenames + + Returns: + Newline-separated list of matching filenames, or error message + """ + comp: re.Pattern + try: + comp = re.compile(search_string, re.MULTILINE) + except Exception: + return "Illegal pattern name, check your syntax and try again." + + matches: list[str] = [] + + match_set = None if not match_in else set(match_in) + + should_search = \ + (lambda _: True) if match_set is None else \ + (lambda f: f in match_set) + + for (k, v) in file_contents: + if not should_search(k): + continue + if comp.search(v) is not None: + matches.append(k) + + if not matching_lines: + return "\n".join(matches) + + matched_lines = [] + + for match_name in matches: + cont_s = get_content(match_name) + assert cont_s is not None + cont = cont_s.splitlines() + for (lno, l) in enumerate(cont, start=1): + if comp.search(l): + matched_lines.append(f"{match_name}:{lno}:{l}") + + return "\n".join(matched_lines) + + + +class _GetFileSchemaBase(BaseModel): + """ + Read the contents of the VFS at some relative path. + + If the path doesn't exist, this function returns "File not found". + """ + path: str = Field(description="The relative path of the file on the VFS. IMPORTANT: Do NOT include a leading `./` it is implied") + range: FileRange | None = Field(description="If set, (start, end) indicates to return lines starting from line `start` (lines are 1 indexed) until `end` (exclusive). If unset, the entire file is returned.", default=None) + + +class _ListFileSchemaBase(BaseModel): + """ + Lists all file contents of the VFS, including in any subdirectories. Directory entries are *not* included. + Each file in the VFS has its own line in the output, any empty lines should be ignored. + """ + pass + + +class _GrepFileSchemaBase(BaseModel): + """ + Search for a specific string in the files on the VFS. The output depends on the + value of the `matching_lines` argument. If false, returns a list of + file names which contain the query somewhere in their contents, with one file name per line. + If true, returns a list of matching lines in files with the format: + ``` + $filename:$lineno:$line + ``` + $line is a line matching the search string in $filename at $lineno (starting at 1). + + In both output modes, empty lines should be ignored. + + In both modes, the paths to search can be restricted with `match_in`. + """ + + search_string: str = Field(description="The query string to search for provided as a python regex. Thus, you must escape any special characters (like [, |, etc.)") + matching_lines: bool = Field(description="If true, show the matching lines and the line number; if false, simply list the matching files") + match_in: list[str] | None = Field(description="If set, narrow the search to only the paths listed here.", default=None) + def merge_vfs(left: dict[str, str], right: dict[str, str]) -> dict[str, str]: new_left = left.copy() for (f_name, cont) in right.items(): @@ -40,6 +166,9 @@ def merge_vfs(left: dict[str, str], right: dict[str, str]) -> dict[str, str]: class VFSState(TypedDict): vfs: Annotated[dict[str, str], merge_vfs] +class VFSInput(FlowInput): + vfs: dict[str, str] + InputType = TypeVar("InputType", bound=VFSState) StateVar = TypeVar("StateVar", contravariant=True) @@ -170,15 +299,8 @@ class PutFileSchema(BaseModel): PutFileSchema.__doc__ = pf_doc - # returns true if the file is okay to put or get - def make_checker(patt: str | None) -> Callable[[str], bool]: - if patt is None: - return lambda f_name: True - match = re.compile(patt) - return lambda f_name: match.fullmatch(f_name) is None - - put_filter = make_checker(conf.get("forbidden_write")) - get_filter = make_checker(conf.get("forbidden_read")) + put_filter = _make_checker(conf.get("forbidden_write")) + get_filter = _make_checker(conf.get("forbidden_read")) @tool(args_schema=PutFileSchema) def put_file( @@ -195,9 +317,7 @@ def put_file( } ) - def _get_content(s: InputType, path: str) -> str | None: - if not get_filter(path): - return None + def _get_content_raw(s: InputType, path: str) -> str | None: vfs = s["vfs"] if path not in vfs: layer = conf.get("fs_layer", None) @@ -206,38 +326,31 @@ def _get_content(s: InputType, path: str) -> str | None: child = pathlib.Path(layer) / path if not child.is_file(): return None - return child.read_text() + try: + return child.read_text() + except: + return None else: return vfs[path] - - class FileRange(BaseModel): - start_line: int = Field(description="The line to start reading from; lines are numbered starting from 1.") - end_line: int = Field(description="The line to read until EXCLUSIVE.") + + def _get_content(s: InputType, path: str) -> str | None: + if not get_filter(path): + return None + return _get_content_raw(s, path) @inject(doc_extra=conf.get('get_doc_extra')) - class GetFileSchema(BaseModel): - """ - Read the contents of the VFS at some relative path. - - If the path doesn't exist, this function returns "File not found". - """ - path: str = Field(description="The relative path of the file on the VFS. IMPORTANT: Do NOT include a leading `./` it is implied") - range: FileRange | None = Field(description="If set, (start, end) indicates to return lines starting from line `start` (lines are 1 indexed) until `end` (exclusive). If unset, the entire file is returned.") + @_copy_base_doc + class GetFileSchema(_GetFileSchemaBase): + pass @tool(args_schema=GetFileSchema) def get_file( path: str, - range: FileRange | None, - state: Annotated[InputType, InjectedState] + state: Annotated[InputType, InjectedState], + range: FileRange | None = None ) -> str: cont = _get_content(state, path) - if cont is None: - return "File not found" - if not range: - return cont - start = range.start_line - 1 - to_ret = cont.splitlines()[start:range.end_line - 1] - return "\n".join(to_ret) + return _get_file(cont, range) @cache def list_underlying() -> Sequence[str]: @@ -248,48 +361,33 @@ def list_underlying() -> Sequence[str]: return [str(f.relative_to(base)) for f in base.rglob("*") if f.is_file()] @inject() - class ListFileSchema(BaseModel): - """ - Lists all file contents of the VFS, including in any subdirectories. Directory entries are *not* included. - Each file in the VFS has its own line in the output, any empty lines should be ignored. - """ + @_copy_base_doc + class ListFileSchema(_ListFileSchemaBase): pass - @tool(args_schema=ListFileSchema) - def list_files( - state: Annotated[InputType, InjectedState] - ) -> str: - to_ret = [] + def _list_files( + state: InputType + ) -> Iterator[str]: for (k, _) in state["vfs"].items(): if not get_filter(k): continue - to_ret.append(k) + yield k for f_name in list_underlying(): if not get_filter(f_name) or f_name in state["vfs"]: continue - to_ret.append(f_name) + yield f_name + + @tool(args_schema=ListFileSchema) + def list_files( + state: Annotated[InputType, InjectedState] + ) -> str: + to_ret = list(_list_files(state)) return "\n".join(to_ret) @inject() - class GrepFileSchema(BaseModel): - """ - Search for a specific string in the files on the VFS. The output depends on the - value of the `matching_lines` argument. If false, returns a list of - file names which contain the query somewhere in their contents, with one file name per line. - If true, returns a list of matching lines in files with the format: - ``` - $filename:$lineno:$line - ``` - $line is a line matching the search string in $filename at $lineno (starting at 1). - - In both output modes, empty lines should be ignored. - - In both modes, the paths to search can be restricted with `match_in`. - """ - - search_string: str = Field(description="The query string to search for provided as a python regex. Thus, you must escape any special characters (like [, |, etc.)") - matching_lines: bool = Field(description="If true, show the matching lines and the line number; if false, simply list the matching files") - match_in: list[str] | None = Field(description="If set, narrow the search to only the paths listed here.", default=None) + @_copy_base_doc + class GrepFileSchema(_GrepFileSchemaBase): + pass @tool(args_schema=GrepFileSchema) def grep_files( @@ -298,58 +396,14 @@ def grep_files( matching_lines: bool, match_in: list[str] | None = None ) -> str: - comp: re.Pattern - try: - comp = re.compile(search_string, re.MULTILINE) - except Exception: - return "Illegal pattern name, check your syntax and try again." - - match_set = set(match_in) if match_in else None - - matches: list[str] = [] - - def should_search(s: str) -> bool: - if not get_filter(s): - return False - if match_set is None: - return True - return s in match_set - - for (k, v) in state["vfs"].items(): - if not should_search(k): - continue - if comp.search(v) is not None: - matches.append(k) - - if (layer := conf.get("fs_layer", None)) is not None: - p = pathlib.Path(layer) - for f in p.rglob("*"): - if not f.is_file(): + def file_contents() -> Iterator[tuple[str, str]]: + for path in _list_files(state): + cont = _get_content_raw(state, path) + if not cont: continue - rel_name = str(f.relative_to(p)) - if not should_search(rel_name): - continue - if rel_name in state["vfs"]: - continue - if not comp.search(f.read_text()): - continue - matches.append(rel_name) - - if not matching_lines: - return "\n".join(matches) - - matched_lines = [] - - for match_name in matches: - cont_s = _get_content(state, match_name) - assert cont_s is not None - cont = cont_s.splitlines() - for (lno, l) in enumerate(cont, start=1): - if comp.search(l): - matched_lines.append(f"{match_name}:{lno}:{l}") - - return "\n".join(matched_lines) - + yield (path, cont) + + return _grep_impl(search_string, matching_lines, file_contents(), lambda p: _get_content_raw(state, p), match_in) tools: list[BaseTool] = [get_file, list_files, grep_files] if not conf["immutable"]: @@ -358,3 +412,80 @@ def should_search(s: str) -> bool: materializer = _VFSAccess[InputType](conf=conf) return (tools, materializer) + + +def fs_tools(fs_layer: str, forbidden_read: str | None = None) -> list[BaseTool]: + """ + Create stateless file system tools that operate directly on a directory. + + Unlike vfs_tools, these tools don't use langgraph state - they simply + read from the provided filesystem path. Useful for immutable file access + where no VFS overlay is needed. + + Args: + fs_layer: Path to the directory to expose + forbidden_read: Optional regex pattern for paths that cannot be read + + Returns: + List of tools: [get_file, list_files, grep_files] + """ + base_path = pathlib.Path(fs_layer) + check_allowed = _make_checker(forbidden_read) + + @cache + def list_all_files() -> Sequence[str]: + return [str(f.relative_to(base_path)) for f in base_path.rglob("*") if f.is_file()] + + @_copy_base_doc + class GetFileSchema(_GetFileSchemaBase): + pass + + @tool(args_schema=GetFileSchema) + def get_file(path: str, range: FileRange | None = None) -> str: + if not check_allowed(path): + return "File not found" + child = base_path / path + if child.is_file(): + try: + return _get_file(child.read_text(), range) + except Exception: + return "File not found" + return "File not found" + + @_copy_base_doc + class ListFileSchema(_ListFileSchemaBase): + pass + + @tool(args_schema=ListFileSchema) + def list_files() -> str: + return "\n".join(f for f in list_all_files() if check_allowed(f)) + + @_copy_base_doc + class GrepFileSchema(_GrepFileSchemaBase): + pass + + @tool(args_schema=GrepFileSchema) + def grep_files( + search_string: str, + matching_lines: bool, + match_in: list[str] | None = None + ) -> str: + def file_contents() -> Iterator[tuple[str, str]]: + for f in base_path.rglob("*"): + if not f.is_file(): + continue + rel_name = str(f.relative_to(base_path)) + if not check_allowed(rel_name): + continue + try: + yield (rel_name, f.read_text()) + except Exception: + continue + def read_file(p: str) -> str | None: + try: + return (base_path / p).read_text() + except: + return None + return _grep_impl(search_string, matching_lines, file_contents(), read_file, match_in) + + return [get_file, list_files, grep_files]