diff --git a/agrifoodpy/pipeline/pipeline.py b/agrifoodpy/pipeline/pipeline.py index ecc3f60..b867884 100644 --- a/agrifoodpy/pipeline/pipeline.py +++ b/agrifoodpy/pipeline/pipeline.py @@ -79,7 +79,7 @@ def add_node(self, node, params={}, name=None): self.nodes.append(node) self.params.append(params) - def run(self, from_node=0, to_node=None, timing=False): + def run(self, from_node=0, to_node=None, skip=None, timing=False): """Runs the pipeline Parameters @@ -90,6 +90,9 @@ def run(self, from_node=0, to_node=None, timing=False): to_node : int, optional The index of the last node to be executed. If not provided, all nodes will be executed + + skip : list of int, optional + List of node indices to skip during execution. Defaults to None. timing : bool, optional If True, the execution time of each node will be printed. Defaults @@ -103,6 +106,12 @@ def run(self, from_node=0, to_node=None, timing=False): # Execute the node functions within the specified range for i in range(from_node, to_node): + + if skip is not None and i in skip: + if timing: + print(f"Node {i + 1}: {self.names[i]}, skipped.") + continue + node = self.nodes[i] params = self.params[i] diff --git a/agrifoodpy/pipeline/tests/test_pipeline.py b/agrifoodpy/pipeline/tests/test_pipeline.py index 1562c91..7a78f04 100644 --- a/agrifoodpy/pipeline/tests/test_pipeline.py +++ b/agrifoodpy/pipeline/tests/test_pipeline.py @@ -68,6 +68,29 @@ def node2(datablock, param2): pipeline.run(from_node=1) assert(pipeline.datablock['result1'] == 20) +def test_run_with_skip(): + pipeline = Pipeline() + def node1(datablock, param1): + datablock['result1'] = param1 + return datablock + + def node2(datablock, param2): + datablock['result2'] = param2 + return datablock + + def node3(datablock, param3): + datablock['result3'] = param3 + return datablock + + pipeline.add_node(node1, params={'param1': 10}) + pipeline.add_node(node2, params={'param2': 20}) + pipeline.add_node(node3, params={'param3': 30}) + + pipeline.run(skip=[1]) + assert(pipeline.datablock['result1'] == 10) + assert('result2' not in pipeline.datablock) + assert(pipeline.datablock['result3'] == 30) + def test_standalone_decorator(): pipeline = Pipeline() @standalone([], ['output1'])