@@ -70,14 +70,14 @@ def __init__(self, nodes: List[Node], edges: List[Edge], **kwargs):
7070 self .node_init_variables : Dict [str , Any ] = {}
7171
7272
73- async def dynamic_fan_in_pipes_task (self , node : Node ):
73+ def dynamic_fan_in_pipes_task (self , node : Node ):
7474 """
7575 Fan in pipes for a node.
7676 """
7777 precious_nodes = get_previous_nodes (self .nx_graph , node .processor_unique_name )
7878
7979 if len (precious_nodes ) == 0 :
80- return
80+ return asyncio . sleep ( 0 )
8181
8282 previous_input_pipes = [
8383 self .processor_pipes [previous_node .processor_unique_name ].output_pipe
@@ -87,12 +87,12 @@ async def dynamic_fan_in_pipes_task(self, node: Node):
8787 self ._fan_in_pipes (
8888 previous_input_pipes ,
8989 self .processor_pipes [node .processor_unique_name ].input_pipe ),
90- name = f"fan_in_pipes_task_ { node . processor_unique_name } _to_ { '=' .join ([previous_node .processor_unique_name for previous_node in precious_nodes ])} "
90+ name = f"fan_in_pipes_task__from_ { '=' .join ([previous_node .processor_unique_name for previous_node in precious_nodes ])} to { node . processor_unique_name } "
9191 )
9292
93- await task
93+ return task
9494
95- async def dynamic_fan_out_pipes_task (self , node : Node ):
95+ def dynamic_fan_out_pipes_task (self , node : Node ):
9696 """
9797 Fan out pipes for a node.
9898 """
@@ -112,20 +112,11 @@ async def dynamic_fan_out_pipes_task(self, node: Node):
112112 ),
113113 name = f"fan_out_pipes_task_{ node .processor_unique_name } _to_{ '=' .join ([next_node .processor_unique_name for next_node in next_nodes ])} "
114114 )
115- await task
115+ return task
116116 else :
117- await asyncio .sleep (0 )
117+ return asyncio .sleep (0 )
118118
119119
120-
121- async def add_processor_into_graph (self , node : Node ):
122- self .initialize_node (node )
123- # merge input pipes
124- self .background_tasks .append (self .dynamic_fan_in_pipes_task (node ))
125- # merge output pipes
126- self .background_tasks .append (self .dynamic_fan_out_pipes_task (node ))
127-
128-
129120 def initialize (self ):
130121
131122 for node in self .nodes :
@@ -149,24 +140,33 @@ def initialize(self):
149140
150141
151142
152- for node in self .nodes :
153- # merge input pipes
154- self .background_tasks .append (self .dynamic_fan_in_pipes_task (node ))
155-
156- # merge output pipes
157- self .background_tasks .append (self .dynamic_fan_out_pipes_task (node ))
158143
159144
160145 root_node = get_root_nodes (self .nx_graph )[0 ]
161146
162147 # Create graph's own input pipe (separate from root node's input pipe)
148+
149+ async def root_task ():
150+ await self ._fan_in_pipes (
151+ [self .input_pipe ],
152+ self .processor_pipes [root_node .processor_unique_name ].input_pipe
153+ )
154+ self .logger .warning ("Graph pipe into root input pipe task completed" )
155+
163156 graph_pipe_into_root_input_pipe_task = asyncio .create_task (
164- self ._fan_out_pipes (
165- self .input_pipe ,
166- [self .processor_pipes [root_node .processor_unique_name ].input_pipe ]
167- ),
157+ root_task (),
168158 name = "graph_pipe_into_root_input_pipe_task"
169159 )
160+ self .background_tasks .append (graph_pipe_into_root_input_pipe_task )
161+
162+
163+ for node in self .nodes :
164+ # merge input pipes
165+ self .background_tasks .append (self .dynamic_fan_in_pipes_task (node ))
166+
167+ # # merge output pipes
168+ # self.background_tasks.append(self.dynamic_fan_out_pipes_task(node))
169+
170170
171171
172172 # create the task for all leaf nodes
@@ -180,8 +180,7 @@ def initialize(self):
180180 name = "leaf_nodes_output_pipe_task"
181181 )
182182
183- self .background_tasks .insert (0 , graph_pipe_into_root_input_pipe_task )
184- self .background_tasks .insert (0 , leaf_nodes_output_pipe_task )
183+ self .background_tasks .append (leaf_nodes_output_pipe_task )
185184 self .logger .info (f"Graph initialized: { self .processor_id } , { self .background_tasks } " )
186185
187186
@@ -217,14 +216,13 @@ async def _fan_in_pipes(self, input_pipes: List[PipeInterface], output_pipe: Pip
217216 """Merge multiple input pipes into a single output pipe using asyncio.as_completed"""
218217 async def read_pipe_task (pipe ):
219218 """Read all data from a single pipe"""
220- # import ipdb; ipdb.set_trace()
221219 async for message_id , data in pipe :
222220 self .logger .info (f"Fan in pipe: { data } : from { pipe ._pipe_id } to { output_pipe ._pipe_id } " )
223221 if data is None :
224- await output_pipe .put (data )
225222 break
226223 await output_pipe .put (data )
227224
225+
228226 try :
229227 # Create tasks for reading from each input pipe
230228 tasks = [asyncio .create_task (read_pipe_task (pipe ), name = f"read_pipe_task_{ pipe ._pipe_id } " ) for pipe in input_pipes ]
@@ -258,7 +256,9 @@ async def _fan_out_pipes(self, source_pipe: PipeInterface, output_pipes: List[Pi
258256 # import ipdb; ipdb.set_trace()
259257 for pipe in output_pipes :
260258 try :
259+ self .logger .info (f"Fan out pipe: None: from { source_pipe ._pipe_id } to { pipe ._pipe_id } " )
261260 await pipe .put (None )
261+
262262 except Exception as e :
263263 self .logger .error (f"Error signaling end-of-stream to pipe { getattr (pipe , '_pipe_id' , 'unknown' )} : { e } " )
264264
@@ -318,16 +318,21 @@ async def execute(self, data: Any, session_id: Optional[str] = None, *args, **kw
318318 tasks .append (task )
319319
320320 # Wait for all processors to complete
321- await asyncio .gather (* self .background_tasks )
322321 await asyncio .gather (* tasks )
322+
323+ # Close input pipe to signal end-of-stream to background tasks
324+ await self .input_pipe .close ()
325+
326+ self .logger .warning (f"Graph background tasks: { self .background_tasks } " )
327+ await asyncio .gather (* self .background_tasks )
323328
324329
325330 except Exception as e :
326331 self .logger .error (f"Error executing graph: { e } " )
327332 self .logger .error (traceback .format_exc ())
328333 caught_exception = e
329334 finally :
330- await self . input_pipe . close ()
335+ # Input pipe is already closed above, only close output pipe here
331336 await self .output_pipe .close ()
332337 if caught_exception :
333338 raise caught_exception
@@ -337,4 +342,4 @@ async def process(self, data: Any, *args, **kwargs) -> Any:
337342 """
338343 Process data.
339344 """
340- pass
345+ pass
0 commit comments