Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 65 additions & 4 deletions agrifoodpy/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def datablock_write(self, path, value):
current = current.setdefault(key, {})
current[path[-1]] = value

def add_node(self, node, params={}, name=None):
def add_node(self, node, params={}, name=None, index=None):
"""Adds a node to the pipeline, including its function and execution
parameters.

Expand All @@ -67,6 +67,9 @@ def add_node(self, node, params={}, name=None):
name : str, optional
The name of the node. If not provided, a generic name will be
assigned.
index : int, optional
Index of the enw node. If None, the new node is appended at the
end of the node list.
"""

# Copy the parameters to avoid modifying the original dictionaries
Expand All @@ -75,9 +78,46 @@ def add_node(self, node, params={}, name=None):
if name is None:
name = "Node {}".format(len(self.nodes) + 1)

self.names.append(name)
self.nodes.append(node)
self.params.append(params)
if index is None:
index = len(self.nodes)

self.names.insert(index, name)
self.nodes.insert(index, node)
self.params.insert(index, params)

def remove_node(self, node):
"""Remove a node from the pipeline by index or name.

Parameters
----------
node : int or str
Index of the node to remove, or its name.
"""
# Resolve index
if isinstance(node, int):
index = node
if index < 0 or index >= len(self.nodes):
raise IndexError(f"Node index {index} out of range.")

elif isinstance(node, str):
matches = [i for i, name in enumerate(self.names) if name == node]
if not matches:
raise ValueError(f"No node found with name '{node}'.")
if len(matches) > 1:
raise ValueError(
f"Multiple nodes found with name '{node}'. "
"Please remove by index instead."
)
index = matches[0]

else:
raise TypeError("node must be an int (index) or str (name).")

# Remove from all internal lists
del self.nodes[index]
del self.params[index]
del self.names[index]


def run(self, from_node=0, to_node=None, skip=None, timing=False):
"""Runs the pipeline
Expand Down Expand Up @@ -133,6 +173,27 @@ def run(self, from_node=0, to_node=None, skip=None, timing=False):
if timing:
print(f"Pipeline executed in {pipeline_time:.4f} seconds.")

def print_nodes(self, show_params=True):
"""Prints the list of nodes associated with a Pipeline instance.

Parameters
----------
show_params : bool, optional
If True, displays the parameters associated with each node.
"""


if not self.nodes:
print("Pipeline is empty.")
return

print("Pipeline nodes:")
for i, (name, node, params) in enumerate(zip(self.names, self.nodes, self.params)):
node_name = getattr(node, "__name__", repr(node))
print(f"[{i}] {name}: {node_name}")
if show_params and params:
for k, v in params.items():
print(f" {k} = {v}")

def standalone(input_keys, return_keys):
""" Decorator to make a pipeline node available as a standalone function
Expand Down
26 changes: 25 additions & 1 deletion agrifoodpy/pipeline/tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,36 @@ def test_add_node():
def dummy_node(datablock, param1):
datablock['result'] = param1
return datablock


# Test simple node addition
pipeline.add_node(dummy_node, params={'param1': 10}, name='Test Node')
assert(len(pipeline.nodes) == 1)
assert(pipeline.names[0] == 'Test Node')
assert(pipeline.params[0] == {'param1': 10})

# Test adding node at index
pipeline.add_node(dummy_node, params={'param1': 20}, name='Test Node 2',
index=0)
assert(len(pipeline.nodes) == 2)
assert(pipeline.names[0] == 'Test Node 2')
assert(pipeline.params[0] == {'param1': 20})

# Test removing a node by index
pipeline.add_node(dummy_node, params={'param1': 30}, name='Test Node 3')
assert(len(pipeline.nodes) == 3)
assert(pipeline.names[2] == 'Test Node 3')
assert(pipeline.params[2] == {'param1': 30})

pipeline.remove_node(2)
assert(len(pipeline.nodes) == 2)
assert(pipeline.names[-1] == 'Test Node')
assert(pipeline.params[-1] == {'param1': 10})

# Test removing a node by name
pipeline.remove_node("Test Node 2")
assert(pipeline.names[-1] == 'Test Node')
assert(pipeline.params[-1] == {'param1': 10})

def test_run_pipeline():
pipeline = Pipeline()
def node1(datablock, param1):
Expand Down