diff --git a/graphcore/graph.py b/graphcore/graph.py
index 7ae8b5e..aae550c 100644
--- a/graphcore/graph.py
+++ b/graphcore/graph.py
@@ -249,7 +249,7 @@ def __call__(
 
 def _get_summarizer_pure(
     system_prompt: str,
-    initial_prompt: str,
+    initial_prompt: str | dict,
     state_type: type[StateT],
     context: SummaryConfig[StateT]
 ) -> PureFunction:
@@ -268,11 +268,13 @@ def to_return(state: StateT) -> PureFunctionGenerator:
         summary = msg.text
         resume_message = config.get_resume_prompt(state, summary)
         config.on_summary(state, summary, resume_message)
+        # Handle initial_prompt as either str or dict (with cache_control)
+        initial_content: list[str | dict] = [initial_prompt]
         return {
             "messages": [
                 RemoveMessage(id="__remove_all__"),
                 SystemMessage(content=system_prompt),
-                HumanMessage(content=initial_prompt, display_tag="initial_prompt"),
+                HumanMessage(content=initial_content, display_tag="initial_prompt"),
                 HumanMessage(content=resume_message, display_tag="resume"),
             ]
         }
@@ -300,7 +302,7 @@ def _get_initial_pure(
     t: Type[I],
     output_state: Type[O],
     sys_prompt: str,
-    initial_prompt: str,
+    initial_prompt: str | dict,
 ) -> PureFunction[I, O]:
     def impl(
         state: I
@@ -350,7 +352,7 @@ def impl(
 def get_summarizer(
     llm: LLM,
     system_prompt: str,
-    initial_prompt: str,
+    initial_prompt: str | dict,
     state_type: type[StateT],
     context: SummaryConfig[StateT]
 ) -> ChatNodeFunction[StateT]:
@@ -362,7 +364,7 @@ def get_summarizer(
 def get_async_summarizer(
     llm: LLM,
     system_prompt: str,
-    initial_prompt: str,
+    initial_prompt: str | dict,
     state_type: type[StateT],
     context: SummaryConfig[StateT]
 ) -> AsyncChatNodeFunction[StateT]:
@@ -376,7 +378,7 @@ def __call__(
         self,
         llm: LLM,
         system_prompt: str,
-        initial_prompt: str,
+        initial_prompt: str | dict,
         state_type: type[StateT],
         context: SummaryConfig[StateT]
     ) -> AnyChatNodeFunction[StateT]:
@@ -386,7 +388,7 @@ def initial_node(
     t: Type[InputState],
     output_state: Type[StateT],
     sys_prompt: str,
-    initial_prompt: str,
+    initial_prompt: str | dict,
     llm: LLM
 ) -> NodeFunction[InputState, StateT]:
     return _stitch_sync_impl(
@@ -398,7 +400,7 @@ def async_initial_node(
     t: Type[InputState],
     output_state: Type[StateT],
     sys_prompt: str,
-    initial_prompt: str,
+    initial_prompt: str | dict,
     llm: LLM
 ) -> AsyncNodeFunction[InputState, StateT]:
     return _stitch_async_impl(
@@ -412,7 +414,7 @@ def __call__(
         t: Type[InputState],
         output_state: Type[StateT],
         sys_prompt: str,
-        initial_prompt: str,
+        initial_prompt: str | dict,
         llm: LLM
     ) -> AnyNodeFunction[InputState, StateT]:
         ...
@@ -656,7 +658,7 @@ def build_workflow(
     input_type: Type[InputState],
     tools_list: Iterable[BaseTool | SplitTool],
     sys_prompt: str,
-    initial_prompt: str,
+    initial_prompt: str | dict,
     output_key: str,
     unbound_llm: BaseChatModel,
     output_schema: Optional[Type[OutputT]] = None,
@@ -685,7 +687,7 @@ def build_async_workflow(
     input_type: Type[InputState],
     tools_list: Iterable[BaseTool | SplitTool],
     sys_prompt: str,
-    initial_prompt: str,
+    initial_prompt: str | dict,
     output_key: str,
     unbound_llm: BaseChatModel,
     output_schema: Optional[Type[OutputT]] = None,
@@ -713,7 +715,7 @@ def _build_workflow(
     input_type: Type[InputState],
     tools_list: Iterable[BaseTool | SplitTool],
     sys_prompt: str,
-    initial_prompt: str,
+    initial_prompt: str | dict,
     output_key: str,
     unbound_llm: BaseChatModel,
     output_schema: Optional[Type[OutputT]],