Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions invokeai/app/invocations/logic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from typing import Any, Optional

from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext


@invocation_output("if_output")
class IfInvocationOutput(BaseInvocationOutput):
value: Optional[Any] = OutputField(
default=None, description="The selected value", title="Output", ui_type=UIType.Any
)


@invocation("if", title="If", tags=["logic", "conditional"], category="logic", version="1.0.0")
class IfInvocation(BaseInvocation):
"""Selects between two optional inputs based on a boolean condition."""

condition: bool = InputField(default=False, description="The condition used to select an input", title="Condition")
true_input: Optional[Any] = InputField(
default=None,
description="Selected when the condition is true",
title="True Input",
ui_type=UIType.Any,
)
false_input: Optional[Any] = InputField(
default=None,
description="Selected when the condition is false",
title="False Input",
ui_type=UIType.Any,
)

def invoke(self, context: InvocationContext) -> IfInvocationOutput:
return IfInvocationOutput(value=self.true_input if self.condition else self.false_input)
Original file line number Diff line number Diff line change
@@ -1,10 +1,144 @@
import { deepClone } from 'common/util/deepClone';
import { set } from 'es-toolkit/compat';
import type { InvocationTemplate } from 'features/nodes/types/invocation';
import { describe, expect, it } from 'vitest';

import { add, buildEdge, buildNode, collect, img_resize, main_model_loader, sub, templates } from './testUtils';
import { validateConnection } from './validateConnection';

const ifTemplate: InvocationTemplate = {
title: 'If',
type: 'if',
version: '1.0.0',
tags: [],
description: 'Selects between two inputs based on a boolean condition',
outputType: 'if_output',
inputs: {
condition: {
name: 'condition',
title: 'Condition',
required: true,
description: 'The condition used to select an input',
fieldKind: 'input',
input: 'connection',
ui_hidden: false,
ui_type: 'BooleanField',
type: {
name: 'BooleanField',
cardinality: 'SINGLE',
batch: false,
},
default: false,
},
true_input: {
name: 'true_input',
title: 'True Input',
required: false,
description: 'Selected when condition is true',
fieldKind: 'input',
input: 'connection',
ui_hidden: false,
ui_type: 'AnyField',
type: {
name: 'AnyField',
cardinality: 'SINGLE',
batch: false,
},
default: undefined,
},
false_input: {
name: 'false_input',
title: 'False Input',
required: false,
description: 'Selected when condition is false',
fieldKind: 'input',
input: 'connection',
ui_hidden: false,
ui_type: 'AnyField',
type: {
name: 'AnyField',
cardinality: 'SINGLE',
batch: false,
},
default: undefined,
},
},
outputs: {
value: {
fieldKind: 'output',
name: 'value',
title: 'Output',
description: 'The selected value',
type: {
name: 'AnyField',
cardinality: 'SINGLE',
batch: false,
},
ui_hidden: false,
ui_type: 'AnyField',
},
},
useCache: true,
nodePack: 'invokeai',
classification: 'stable',
};

const floatOutputTemplate: InvocationTemplate = {
title: 'Float Output',
type: 'float_output',
version: '1.0.0',
tags: [],
description: 'Outputs a float',
outputType: 'float_output',
inputs: {},
outputs: {
value: {
fieldKind: 'output',
name: 'value',
title: 'Value',
description: 'Float value',
type: {
name: 'FloatField',
cardinality: 'SINGLE',
batch: false,
},
ui_hidden: false,
ui_type: 'FloatField',
},
},
useCache: true,
nodePack: 'invokeai',
classification: 'stable',
};

const integerCollectionOutputTemplate: InvocationTemplate = {
title: 'Integer Collection Output',
type: 'integer_collection_output',
version: '1.0.0',
tags: [],
description: 'Outputs an integer collection',
outputType: 'integer_collection_output',
inputs: {},
outputs: {
value: {
fieldKind: 'output',
name: 'value',
title: 'Value',
description: 'Integer collection value',
type: {
name: 'IntegerField',
cardinality: 'COLLECTION',
batch: false,
},
ui_hidden: false,
ui_type: 'IntegerField',
},
},
useCache: true,
nodePack: 'invokeai',
classification: 'stable',
};

describe(validateConnection.name, () => {
it('should reject invalid connection to self', () => {
const c = { source: 'add', sourceHandle: 'value', target: 'add', targetHandle: 'a' };
Expand Down Expand Up @@ -155,6 +289,118 @@ describe(validateConnection.name, () => {
expect(r).toEqual('nodes.fieldTypesMustMatch');
});

it('should reject mismatched types between if node branch inputs', () => {
const n1 = buildNode(add);
const n2 = buildNode(img_resize);
const n3 = buildNode(ifTemplate);
const nodes = [n1, n2, n3];
const e1 = buildEdge(n1.id, 'value', n3.id, 'true_input');
const edges = [e1];
const c = { source: n2.id, sourceHandle: 'image', target: n3.id, targetHandle: 'false_input' };
const r = validateConnection(c, nodes, edges, { ...templates, if: ifTemplate }, null);
expect(r).toEqual('nodes.fieldTypesMustMatch');
});

it('should reject mismatched types between if node branch inputs regardless of branch order', () => {
const n1 = buildNode(add);
const n2 = buildNode(img_resize);
const n3 = buildNode(ifTemplate);
const nodes = [n1, n2, n3];
const e1 = buildEdge(n1.id, 'value', n3.id, 'false_input');
const edges = [e1];
const c = { source: n2.id, sourceHandle: 'image', target: n3.id, targetHandle: 'true_input' };
const r = validateConnection(c, nodes, edges, { ...templates, if: ifTemplate }, null);
expect(r).toEqual('nodes.fieldTypesMustMatch');
});

it('should accept convertible types between if node branch inputs', () => {
const n1 = buildNode(add);
const n2 = buildNode(sub);
const n3 = buildNode(ifTemplate);
const nodes = [n1, n2, n3];
const e1 = buildEdge(n1.id, 'value', n3.id, 'true_input');
const edges = [e1];
const c = { source: n2.id, sourceHandle: 'value', target: n3.id, targetHandle: 'false_input' };
const r = validateConnection(c, nodes, edges, { ...templates, if: ifTemplate }, null);
expect(r).toEqual(null);
});

it('should accept one-way-convertible types between if node branch inputs in either connection order', () => {
const n1 = buildNode(add);
const n2 = buildNode(floatOutputTemplate);
const n3 = buildNode(ifTemplate);
const nodes = [n1, n2, n3];
const e1 = buildEdge(n1.id, 'value', n3.id, 'false_input');
const edges = [e1];
const c = { source: n2.id, sourceHandle: 'value', target: n3.id, targetHandle: 'true_input' };
const r = validateConnection(
c,
nodes,
edges,
{ ...templates, if: ifTemplate, float_output: floatOutputTemplate },
null
);
expect(r).toEqual(null);
});

it('should accept SINGLE and COLLECTION of the same type between if node branch inputs', () => {
const n1 = buildNode(add);
const n2 = buildNode(integerCollectionOutputTemplate);
const n3 = buildNode(ifTemplate);
const nodes = [n1, n2, n3];
const e1 = buildEdge(n1.id, 'value', n3.id, 'true_input');
const edges = [e1];
const c = { source: n2.id, sourceHandle: 'value', target: n3.id, targetHandle: 'false_input' };
const r = validateConnection(
c,
nodes,
edges,
{ ...templates, if: ifTemplate, integer_collection_output: integerCollectionOutputTemplate },
null
);
expect(r).toEqual(null);
});

it('should accept if output to collection input when both if branch inputs are collections of matching type', () => {
const n1 = buildNode(integerCollectionOutputTemplate);
const n2 = buildNode(integerCollectionOutputTemplate);
const n3 = buildNode(ifTemplate);
const n4 = buildNode(templates.iterate!);
const nodes = [n1, n2, n3, n4];
const e1 = buildEdge(n1.id, 'value', n3.id, 'true_input');
const e2 = buildEdge(n2.id, 'value', n3.id, 'false_input');
const edges = [e1, e2];
const c = { source: n3.id, sourceHandle: 'value', target: n4.id, targetHandle: 'collection' };
const r = validateConnection(
c,
nodes,
edges,
{ ...templates, if: ifTemplate, integer_collection_output: integerCollectionOutputTemplate },
null
);
expect(r).toEqual(null);
});

it('should reject if output to collection input when if branch inputs are not both collection-compatible', () => {
const n1 = buildNode(add);
const n2 = buildNode(integerCollectionOutputTemplate);
const n3 = buildNode(ifTemplate);
const n4 = buildNode(templates.iterate!);
const nodes = [n1, n2, n3, n4];
const e1 = buildEdge(n1.id, 'value', n3.id, 'true_input');
const e2 = buildEdge(n2.id, 'value', n3.id, 'false_input');
const edges = [e1, e2];
const c = { source: n3.id, sourceHandle: 'value', target: n4.id, targetHandle: 'collection' };
const r = validateConnection(
c,
nodes,
edges,
{ ...templates, if: ifTemplate, integer_collection_output: integerCollectionOutputTemplate },
null
);
expect(r).toEqual('nodes.fieldTypesMustMatch');
});

it('should reject connections that would create cycles', () => {
const n1 = buildNode(add);
const n2 = buildNode(sub);
Expand Down
Loading