From 2be06583c89858cb00d48154eee6d42e74b0a810 Mon Sep 17 00:00:00 2001 From: Shelly Grossman Date: Mon, 23 Feb 2026 22:48:14 +0200 Subject: [PATCH 1/3] Allow caching the initial prompt, change of types --- graph.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/graph.py b/graph.py index 955b45f..4177071 100644 --- a/graph.py +++ b/graph.py @@ -248,7 +248,7 @@ def __call__( def _get_summarizer_pure( system_prompt: str, - initial_prompt: str, + initial_prompt: str | dict, state_type: type[StateT], context: SummaryConfig[StateT] ) -> PureFunction: @@ -267,11 +267,13 @@ def to_return(state: StateT) -> PureFunctionGenerator: summary = msg.text() resume_message = config.get_resume_prompt(state, summary) config.on_summary(state, summary, resume_message) + # Handle initial_prompt as either str or dict (with cache_control) + initial_content = [initial_prompt] if isinstance(initial_prompt, dict) else initial_prompt return { "messages": [ RemoveMessage(id="__remove_all__"), SystemMessage(content=system_prompt), - HumanMessage(content=initial_prompt), + HumanMessage(content=initial_content), HumanMessage(content=resume_message), ] } @@ -305,7 +307,7 @@ def _get_initial_pure( t: Type[I], output_state: Type[O], sys_prompt: str, - initial_prompt: str, + initial_prompt: str | dict, ) -> PureFunction[I, O]: def impl( state: I @@ -354,7 +356,7 @@ def impl( def get_summarizer( llm: LLM, system_prompt: str, - initial_prompt: str, + initial_prompt: str | dict, state_type: type[StateT], context: SummaryConfig[StateT] ) -> ChatNodeFunction[StateT]: @@ -366,7 +368,7 @@ def get_summarizer( def get_async_summarizer( llm: LLM, system_prompt: str, - initial_prompt: str, + initial_prompt: str | dict, state_type: type[StateT], context: SummaryConfig[StateT] ) -> AsyncChatNodeFunction[StateT]: @@ -380,7 +382,7 @@ def __call__( self, llm: LLM, system_prompt: str, - initial_prompt: str, + initial_prompt: str | dict, state_type: type[StateT], context: SummaryConfig[StateT] ) -> AnyChatNodeFunction[StateT]: @@ -390,7 +392,7 @@ def initial_node( t: Type[InputState], output_state: Type[StateT], sys_prompt: str, - initial_prompt: str, + initial_prompt: str | dict, llm: LLM ) -> NodeFunction[InputState, StateT]: return _stitch_sync_impl( @@ -402,7 +404,7 @@ def async_initial_node( t: Type[InputState], output_state: Type[StateT], sys_prompt: str, - initial_prompt: str, + initial_prompt: str | dict, llm: LLM ) -> AsyncNodeFunction[InputState, StateT]: return _stitch_async_impl( @@ -416,7 +418,7 @@ def __call__( t: Type[InputState], output_state: Type[StateT], sys_prompt: str, - initial_prompt: str, + initial_prompt: str | dict, llm: LLM ) -> AnyNodeFunction[InputState, StateT]: ... @@ -655,7 +657,7 @@ def build_workflow( input_type: Type[InputState], tools_list: Iterable[BaseTool | SplitTool], sys_prompt: str, - initial_prompt: str, + initial_prompt: str | dict, output_key: str, unbound_llm: BaseChatModel, output_schema: Optional[Type[OutputT]] = None, @@ -685,7 +687,7 @@ def build_async_workflow( input_type: Type[InputState], tools_list: Iterable[BaseTool | SplitTool], sys_prompt: str, - initial_prompt: str, + initial_prompt: str | dict, output_key: str, unbound_llm: BaseChatModel, output_schema: Optional[Type[OutputT]] = None, @@ -714,7 +716,7 @@ def _build_workflow( input_type: Type[InputState], tools_list: Iterable[BaseTool | SplitTool], sys_prompt: str, - initial_prompt: str, + initial_prompt: str | dict, output_key: str, unbound_llm: BaseChatModel, output_schema: Optional[Type[OutputT]], From 36f79c18514d6db6192ea52ef372ae4a36ae6250 Mon Sep 17 00:00:00 2001 From: Shelly Grossman Date: Wed, 18 Mar 2026 19:31:38 +0200 Subject: [PATCH 2/3] fix mypy type error for initial_content Co-Authored-By: Claude Opus 4.6 --- graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graph.py b/graph.py index 4177071..7c0bd50 100644 --- a/graph.py +++ b/graph.py @@ -268,7 +268,7 @@ def to_return(state: StateT) -> PureFunctionGenerator: resume_message = config.get_resume_prompt(state, summary) config.on_summary(state, summary, resume_message) # Handle initial_prompt as either str or dict (with cache_control) - initial_content = [initial_prompt] if isinstance(initial_prompt, dict) else initial_prompt + initial_content: str | list[str | dict] = [initial_prompt] if isinstance(initial_prompt, dict) else initial_prompt return { "messages": [ RemoveMessage(id="__remove_all__"), From a396bd15296060ecff61ba7ef3266924a2eb8eab Mon Sep 17 00:00:00 2001 From: Shelly Grossman Date: Thu, 2 Apr 2026 22:29:54 +0300 Subject: [PATCH 3/3] simplified --- graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graph.py b/graph.py index 3184e22..abeb2cd 100644 --- a/graph.py +++ b/graph.py @@ -269,7 +269,7 @@ def to_return(state: StateT) -> PureFunctionGenerator: resume_message = config.get_resume_prompt(state, summary) config.on_summary(state, summary, resume_message) # Handle initial_prompt as either str or dict (with cache_control) - initial_content: str | list[str | dict] = [initial_prompt] if isinstance(initial_prompt, dict) else initial_prompt + initial_content: list[str | dict] = [initial_prompt] return { "messages": [ RemoveMessage(id="__remove_all__"),