Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion agrifoodpy/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def add_node(self, node, params={}, name=None):
self.nodes.append(node)
self.params.append(params)

def run(self, from_node=0, to_node=None, timing=False):
def run(self, from_node=0, to_node=None, skip=None, timing=False):
"""Runs the pipeline

Parameters
Expand All @@ -90,6 +90,9 @@ def run(self, from_node=0, to_node=None, timing=False):
to_node : int, optional
The index of the last node to be executed. If not provided, all
nodes will be executed

skip : list of int, optional
List of node indices to skip during execution. Defaults to None.

timing : bool, optional
If True, the execution time of each node will be printed. Defaults
Expand All @@ -103,6 +106,12 @@ def run(self, from_node=0, to_node=None, timing=False):

# Execute the node functions within the specified range
for i in range(from_node, to_node):

if skip is not None and i in skip:
if timing:
print(f"Node {i + 1}: {self.names[i]}, skipped.")
continue

node = self.nodes[i]
params = self.params[i]

Expand Down
23 changes: 23 additions & 0 deletions agrifoodpy/pipeline/tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,29 @@ def node2(datablock, param2):
pipeline.run(from_node=1)
assert(pipeline.datablock['result1'] == 20)

def test_run_with_skip():
pipeline = Pipeline()
def node1(datablock, param1):
datablock['result1'] = param1
return datablock

def node2(datablock, param2):
datablock['result2'] = param2
return datablock

def node3(datablock, param3):
datablock['result3'] = param3
return datablock

pipeline.add_node(node1, params={'param1': 10})
pipeline.add_node(node2, params={'param2': 20})
pipeline.add_node(node3, params={'param3': 30})

pipeline.run(skip=[1])
assert(pipeline.datablock['result1'] == 10)
assert('result2' not in pipeline.datablock)
assert(pipeline.datablock['result3'] == 30)

def test_standalone_decorator():
pipeline = Pipeline()
@standalone([], ['output1'])
Expand Down