@@ -98,21 +98,116 @@ def _set_conditional_node_edges(self):
             except:
                 node.false_node_name = None

-    def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
-        """
-        Executes the graph by traversing nodes starting from the
-        entry point using the standard method.
+    def _get_node_by_name(self, node_name: str):
+        """Returns a node instance by its name."""
+        return next(node for node in self.nodes if node.node_name == node_name)

-        Args:
-            initial_state (dict): The initial state to pass to the entry point node.
+    def _update_source_info(self, current_node, state):
+        """Updates source type and source information from FetchNode."""
+        source_type = None
+        source = []
+        prompt = None
+
+        if current_node.__class__.__name__ == "FetchNode":
+            source_type = list(state.keys())[1]
+            if state.get("user_prompt", None):
+                prompt = state["user_prompt"] if isinstance(state["user_prompt"], str) else None
+
+            if source_type == "local_dir":
+                source_type = "html_dir"
+            elif source_type == "url":
+                if isinstance(state[source_type], list):
+                    source.extend(url for url in state[source_type] if isinstance(url, str))
+                elif isinstance(state[source_type], str):
+                    source.append(state[source_type])
+
+        return source_type, source, prompt
+
+    def _get_model_info(self, current_node):
+        """Extracts LLM and embedder model information from the node."""
+        llm_model = None
+        llm_model_name = None
+        embedder_model = None

-        Returns:
-            Tuple[dict, list]: A tuple containing the final state and a list of execution info.
+        if hasattr(current_node, "llm_model"):
+            llm_model = current_node.llm_model
+            if hasattr(llm_model, "model_name"):
+                llm_model_name = llm_model.model_name
+            elif hasattr(llm_model, "model"):
+                llm_model_name = llm_model.model
+            elif hasattr(llm_model, "model_id"):
+                llm_model_name = llm_model.model_id
+
+        if hasattr(current_node, "embedder_model"):
+            embedder_model = current_node.embedder_model
+            if hasattr(embedder_model, "model_name"):
+                embedder_model = embedder_model.model_name
+            elif hasattr(embedder_model, "model"):
+                embedder_model = embedder_model.model
+
+        return llm_model, llm_model_name, embedder_model
+
+    def _get_schema(self, current_node):
+        """Extracts schema information from the node configuration."""
+        if not hasattr(current_node, "node_config"):
+            return None
+
+        if not isinstance(current_node.node_config, dict):
+            return None
+
+        schema_config = current_node.node_config.get("schema")
+        if not schema_config or isinstance(schema_config, dict):
+            return None
+
+        try:
+            return schema_config.schema()
+        except Exception:
+            return None
+
+    def _execute_node(self, current_node, state, llm_model, llm_model_name):
+        """Executes a single node and returns execution information."""
+        curr_time = time.time()
+
+        with self.callback_manager.exclusive_get_callback(llm_model, llm_model_name) as cb:
+            result = current_node.execute(state)
+            node_exec_time = time.time() - curr_time
+
+            cb_data = None
+            if cb is not None:
+                cb_data = {
+                    "node_name": current_node.node_name,
+                    "total_tokens": cb.total_tokens,
+                    "prompt_tokens": cb.prompt_tokens,
+                    "completion_tokens": cb.completion_tokens,
+                    "successful_requests": cb.successful_requests,
+                    "total_cost_USD": cb.total_cost,
+                    "exec_time": node_exec_time,
+                }
+
+        return result, node_exec_time, cb_data
+
+    def _get_next_node(self, current_node, result):
+        """Determines the next node to execute based on current node type and result."""
+        if current_node.node_type == "conditional_node":
+            node_names = {node.node_name for node in self.nodes}
+            if result in node_names:
+                return result
+            elif result is None:
+                return None
+            raise ValueError(
+                f"Conditional Node returned a node name '{result}' that does not exist in the graph"
+            )
+
+        return self.edges.get(current_node.node_name)
+
+    def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
+        """
+        Executes the graph by traversing nodes starting from the entry point using the standard method.
         """
         current_node_name = self.entry_point
         state = initial_state
-
-        # variables for tracking execution info
+
+        # Tracking variables
         total_exec_time = 0.0
         exec_info = []
         cb_total = {
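
Helpers like _get_model_info and _get_schema above only read attributes of the node they are given, so they are easy to exercise in isolation. A rough, hedged sketch (it assumes the enclosing class is BaseGraph imported from scrapegraphai.graphs.base_graph; the SimpleNamespace stand-ins are illustrative, not part of the codebase):

from types import SimpleNamespace

from scrapegraphai.graphs.base_graph import BaseGraph  # assumed import path

# Stand-in LLM object without a "model_name" attribute, so _get_model_info
# falls back to the "model" attribute when resolving the model name.
fake_llm = SimpleNamespace(model="gpt-4o-mini")
fake_node = SimpleNamespace(llm_model=fake_llm)

# The helper never touches self, so it can be called unbound for a quick check.
llm_model, llm_model_name, embedder_model = BaseGraph._get_model_info(None, fake_node)
print(llm_model_name)   # "gpt-4o-mini"
print(embedder_model)   # None, since fake_node has no embedder_model

# _get_schema only calls .schema() when node_config["schema"] is an object
# exposing that method (e.g. a Pydantic model class); a plain dict yields None.
dict_schema_node = SimpleNamespace(node_config={"schema": {"already": "a dict"}})
print(BaseGraph._get_schema(None, dict_schema_node))   # None
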
@@ -134,104 +229,51 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
         schema = None

         while current_node_name:
-            curr_time = time.time()
-            current_node = next(node for node in self.nodes if node.node_name == current_node_name)
-
-            if current_node.__class__.__name__ == "FetchNode":
-                source_type = list(state.keys())[1]
-                if state.get("user_prompt", None):
-                    prompt = state["user_prompt"] if isinstance(state["user_prompt"], str) else None
-
-                if source_type == "local_dir":
-                    source_type = "html_dir"
-                elif source_type == "url":
-                    if isinstance(state[source_type], list):
-                        for url in state[source_type]:
-                            if isinstance(url, str):
-                                source.append(url)
-                    elif isinstance(state[source_type], str):
-                        source.append(state[source_type])
-
-            if hasattr(current_node, "llm_model") and llm_model is None:
-                llm_model = current_node.llm_model
-                if hasattr(llm_model, "model_name"):
-                    llm_model_name = llm_model.model_name
-                elif hasattr(llm_model, "model"):
-                    llm_model_name = llm_model.model
-                elif hasattr(llm_model, "model_id"):
-                    llm_model_name = llm_model.model_id
-
-            if hasattr(current_node, "embedder_model") and embedder_model is None:
-                embedder_model = current_node.embedder_model
-                if hasattr(embedder_model, "model_name"):
-                    embedder_model = embedder_model.model_name
-                elif hasattr(embedder_model, "model"):
-                    embedder_model = embedder_model.model
-
-            if hasattr(current_node, "node_config"):
-                if isinstance(current_node.node_config, dict):
-                    if current_node.node_config.get("schema", None) and schema is None:
-                        if not isinstance(current_node.node_config["schema"], dict):
-                            try:
-                                schema = current_node.node_config["schema"].schema()
-                            except Exception as e:
-                                schema = None
-
-            with self.callback_manager.exclusive_get_callback(llm_model, llm_model_name) as cb:
-                try:
-                    result = current_node.execute(state)
-                except Exception as e:
-                    error_node = current_node.node_name
-                    graph_execution_time = time.time() - start_time
-                    log_graph_execution(
-                        graph_name=self.graph_name,
-                        source=source,
-                        prompt=prompt,
-                        schema=schema,
-                        llm_model=llm_model_name,
-                        embedder_model=embedder_model,
-                        source_type=source_type,
-                        execution_time=graph_execution_time,
-                        error_node=error_node,
-                        exception=str(e)
-                    )
-                    raise e
-                node_exec_time = time.time() - curr_time
+            current_node = self._get_node_by_name(current_node_name)
+
+            # Update source information if needed
+            if source_type is None:
+                source_type, source, prompt = self._update_source_info(current_node, state)
+
+            # Get model information if needed
+            if llm_model is None:
+                llm_model, llm_model_name, embedder_model = self._get_model_info(current_node)
+
+            # Get schema if needed
+            if schema is None:
+                schema = self._get_schema(current_node)
+
+            try:
+                result, node_exec_time, cb_data = self._execute_node(
+                    current_node, state, llm_model, llm_model_name
+                )
                 total_exec_time += node_exec_time

-                if cb is not None:
-                    cb_data = {
-                        "node_name": current_node.node_name,
-                        "total_tokens": cb.total_tokens,
-                        "prompt_tokens": cb.prompt_tokens,
-                        "completion_tokens": cb.completion_tokens,
-                        "successful_requests": cb.successful_requests,
-                        "total_cost_USD": cb.total_cost,
-                        "exec_time": node_exec_time,
-                    }
-
+                if cb_data:
                     exec_info.append(cb_data)
-
-                    cb_total["total_tokens"] += cb_data["total_tokens"]
-                    cb_total["prompt_tokens"] += cb_data["prompt_tokens"]
-                    cb_total["completion_tokens"] += cb_data["completion_tokens"]
-                    cb_total["successful_requests"] += cb_data["successful_requests"]
-                    cb_total["total_cost_USD"] += cb_data["total_cost_USD"]
-
-            if current_node.node_type == "conditional_node":
-                node_names = {node.node_name for node in self.nodes}
-                if result in node_names:
-                    current_node_name = result
-                elif result is None:
-                    current_node_name = None
-                else:
-                    raise ValueError(f"Conditional Node returned a node name '{result}' that does not exist in the graph")
-
-            elif current_node_name in self.edges:
-                current_node_name = self.edges[current_node_name]
-            else:
-                current_node_name = None
-
+                    for key in cb_total:
+                        cb_total[key] += cb_data[key]
+
+                current_node_name = self._get_next_node(current_node, result)
+
+            except Exception as e:
+                error_node = current_node.node_name
+                graph_execution_time = time.time() - start_time
+                log_graph_execution(
+                    graph_name=self.graph_name,
+                    source=source,
+                    prompt=prompt,
+                    schema=schema,
+                    llm_model=llm_model_name,
+                    embedder_model=embedder_model,
+                    source_type=source_type,
+                    execution_time=graph_execution_time,
+                    error_node=error_node,
+                    exception=str(e)
+                )
+                raise e
+
+        # Add total results to execution info
         exec_info.append({
             "node_name": "TOTAL RESULT",
             "total_tokens": cb_total["total_tokens"],
@@ -242,6 +284,7 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
             "exec_time": total_exec_time,
         })

+        # Log final execution results
         graph_execution_time = time.time() - start_time
         response = state.get("answer", None) if source_type == "url" else None
         content = state.get("parsed_doc", None) if response is not None else None
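
Two behavioral points in the rewritten loop are worth pinning down with a sketch (the numbers and node names below are made up for illustration, and BaseGraph is assumed to come from scrapegraphai.graphs.base_graph): the key-driven rollup only sums the keys already present in cb_total, so the extra node_name and exec_time entries in cb_data are ignored, and _get_next_node reproduces the old routing rules for conditional and plain nodes.

from types import SimpleNamespace

from scrapegraphai.graphs.base_graph import BaseGraph  # assumed import path

# Rollup: only keys already present in cb_total are accumulated, so the
# "node_name" and "exec_time" entries in cb_data are simply skipped.
cb_total = {"total_tokens": 0, "prompt_tokens": 0, "completion_tokens": 0,
            "successful_requests": 0, "total_cost_USD": 0.0}
cb_data = {"node_name": "GenerateAnswerNode", "total_tokens": 1200,
           "prompt_tokens": 1000, "completion_tokens": 200,
           "successful_requests": 1, "total_cost_USD": 0.0018,
           "exec_time": 2.3}  # made-up numbers
for key in cb_total:
    cb_total[key] += cb_data[key]

# Routing: plain nodes follow self.edges, conditional nodes must return an
# existing node name or None (anything else raises ValueError).
fake_graph = SimpleNamespace(
    nodes=[SimpleNamespace(node_name="fetch", node_type="node"),
           SimpleNamespace(node_name="parse", node_type="node")],
    edges={"fetch": "parse"},
)
print(BaseGraph._get_next_node(fake_graph, fake_graph.nodes[0], None))   # "parse"

conditional = SimpleNamespace(node_name="cond", node_type="conditional_node")
print(BaseGraph._get_next_node(fake_graph, conditional, "parse"))        # "parse"
print(BaseGraph._get_next_node(fake_graph, conditional, None))           # None
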
@@ -300,3 +343,4 @@ def append_node(self, node):
         self.raw_edges.append((last_node, node))
         self.nodes.append(node)
         self.edges = self._create_edges({e for e in self.raw_edges})
+
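
Finally, _execute_node returns cb_data as None whenever the callback context yields no callback, which is why the loop guards the rollup with the if cb_data check. A minimal sketch of that contract, again with stand-ins (the no-op callback manager below is hypothetical, not the project's real callback manager; BaseGraph is assumed to come from scrapegraphai.graphs.base_graph):

from contextlib import contextmanager
from types import SimpleNamespace

from scrapegraphai.graphs.base_graph import BaseGraph  # assumed import path

@contextmanager
def _no_callback(llm_model, llm_model_name):
    # Mimics the case where no token-counting callback applies to the model.
    yield None

fake_graph = SimpleNamespace(
    callback_manager=SimpleNamespace(exclusive_get_callback=_no_callback)
)
fake_node = SimpleNamespace(
    node_name="echo",
    execute=lambda state: {**state, "done": True},
)

result, exec_time, cb_data = BaseGraph._execute_node(
    fake_graph, fake_node, {"user_prompt": "hi"}, None, None
)
print(result)    # {'user_prompt': 'hi', 'done': True}
print(cb_data)   # None -> the caller skips the exec_info / cb_total update
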