Skip to content

Commit b91513d

Browse files
committed
add optional stack remember on batch to AUdioMathNode
1 parent 1393a83 commit b91513d

1 file changed

Lines changed: 13 additions & 3 deletions

File tree

more_math/AudioMathNode.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,14 @@ def define_schema(cls) -> io.Schema:
4848
tooltip="How to handle mismatched image batch sizes. tile: repeat shorter inputs; error: raise error on mismatch; pad: treat missing frames as zero."
4949
),
5050
io.Int.Input(id="batching", default=0),
51+
io.Bool.Input(
52+
id="remember_stack",
53+
default=False,
54+
display_name="Remember stack across batch",
55+
tooltip=(
56+
"If enabled, stack is copied at output leading to changes being remembered during batch operations (node runs multiple times in sucession). If disabled each batch gets it's own copy of the stack."
57+
),
58+
),
5159
MrmthStack.Input(id="stack", tooltip="Access stack between nodes",optional=True)
5260
],
5361
outputs=[
@@ -57,17 +65,17 @@ def define_schema(cls) -> io.Schema:
5765
)
5866

5967
@classmethod
60-
def check_lazy_status(cls, Expression, V, F, length_mismatch="tile",batching=0,stack={}):
68+
def check_lazy_status(cls, Expression, V, F, length_mismatch="tile",batching=0,, remember_stack=False,stack={}):
6169
return checkLazyNew(Expression,V,F)
6270

6371

6472
@classmethod
65-
def execute(cls, V, F, Expression, length_mismatch="tile",batching=0,stack={}):
73+
def execute(cls, V, F, Expression, length_mismatch="tile",batching=0,, remember_stack=False, stack={}):
6674
# Identify all present audio inputs and their keys
6775
tensor_keys = [k for k, v in V.items() if v is not None and isinstance(v, dict) and "waveform" in v]
6876
if not tensor_keys:
6977
raise ValueError("At least one audio input is required.")
70-
stack = copy.deepcopy(stack) if stack is not None else {}
78+
stack = stack if remember_stack else (copy.deepcopy(stack) if stack is not None else {})
7179
waveforms = {k: V[k]["waveform"] for k in tensor_keys}
7280
sample_rates = {k + "sr": V[k].get("sample_rate", 44100) for k in tensor_keys}
7381

@@ -142,4 +150,6 @@ def execute(cls, V, F, Expression, length_mismatch="tile",batching=0,stack={}):
142150
res_list.append({"waveform": result_chunk, "sample_rate": sample_rate})
143151
return (res_list, stack)
144152
else:
153+
stack = stack if remember_stack else copy.deepcopy(stack)
154+
145155
return ([{"waveform": result, "sample_rate": sample_rate}], stack)

0 commit comments

Comments
 (0)