diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index fd31448ea4f..3ea67a52912 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -123,6 +123,23 @@ def is_any(t: Any) -> bool: return t == Any or Any in get_args(t) +def extract_collection_item_types(t: Any) -> set[Any]: + """Extracts list item types from a collection annotation, including unions containing list branches.""" + if is_any(t): + return {Any} + + if get_origin(t) is list: + return {arg for arg in get_args(t) if arg != NoneType} + + item_types: set[Any] = set() + for arg in get_args(t): + if is_any(arg): + item_types.add(Any) + elif get_origin(arg) is list: + item_types.update(item_arg for item_arg in get_args(arg) if item_arg != NoneType) + return item_types + + def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool: if not from_type or not to_type: return False @@ -280,7 +297,7 @@ class CollectInvocationOutput(BaseInvocationOutput): ) -@invocation("collect", version="1.0.0") +@invocation("collect", version="1.1.0") class CollectInvocation(BaseInvocation): """Collects values into a collection""" @@ -292,7 +309,10 @@ class CollectInvocation(BaseInvocation): input=Input.Connection, ) collection: list[Any] = InputField( - description="The collection, will be provided on execution", default=[], ui_hidden=True + description="An optional collection to append to", + default=[], + ui_type=UIType._Collection, + input=Input.Connection, ) def invoke(self, context: InvocationContext) -> CollectInvocationOutput: @@ -520,7 +540,9 @@ def _validate_edge(self, edge: Edge): # Validate that an edge to this node+field doesn't already exist input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field) - if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation): + if len(input_edges) > 0 and ( + not isinstance(to_node, CollectInvocation) or edge.destination.field != ITEM_FIELD + ): raise InvalidEdgeError(f"Edge already exists ({edge})") # Validate that no cycles would be created @@ -546,8 +568,10 @@ def _validate_edge(self, edge: Edge): raise InvalidEdgeError(f"Iterator output type does not match iterator input type ({edge}): {err}") # Validate if collector input type matches output type (if this edge results in both being set) - if isinstance(to_node, CollectInvocation) and edge.destination.field == ITEM_FIELD: - err = self._is_collector_connection_valid(edge.destination.node_id, new_input=edge.source) + if isinstance(to_node, CollectInvocation) and edge.destination.field in (ITEM_FIELD, COLLECTION_FIELD): + err = self._is_collector_connection_valid( + edge.destination.node_id, new_input=edge.source, new_input_field=edge.destination.field + ) if err is not None: raise InvalidEdgeError(f"Collector output type does not match collector input type ({edge}): {err}") @@ -676,76 +700,152 @@ def _is_iterator_connection_valid( # Collector input type must match all iterator output types if isinstance(input_node, CollectInvocation): - collector_inputs = self._get_input_edges(input_node.id, ITEM_FIELD) - if len(collector_inputs) == 0: - return "Iterator input collector must have at least one item input edge" - - # Traverse the graph to find the first collector input edge. Collectors validate that their collection - # inputs are all of the same type, so we can use the first input edge to determine the collector's type - first_collector_input_edge = collector_inputs[0] - first_collector_input_type = get_output_field_type( - self.get_node(first_collector_input_edge.source.node_id), first_collector_input_edge.source.field - ) - resolved_collector_type = ( - first_collector_input_type - if get_origin(first_collector_input_type) is None - else get_args(first_collector_input_type) - ) - if not all((are_connection_types_compatible(resolved_collector_type, t) for t in output_field_types)): + input_root_type = self._get_collector_input_root_type(input_node.id) + if input_root_type is None: + return "Iterator input collector must have at least one item or collection input edge" + if not all((are_connection_types_compatible(input_root_type, t) for t in output_field_types)): return "Iterator collection type must match all iterator output types" return None + def _resolve_collector_input_types(self, node_id: str, visited: Optional[set[str]] = None) -> set[Any]: + """Resolves possible item types for a collector's inputs, recursively following chained collectors.""" + visited = visited or set() + if node_id in visited: + return set() + visited.add(node_id) + + input_types: set[Any] = set() + + for edge in self._get_input_edges(node_id, ITEM_FIELD): + input_field_type = get_output_field_type(self.get_node(edge.source.node_id), edge.source.field) + resolved_types = [input_field_type] if get_origin(input_field_type) is None else get_args(input_field_type) + input_types.update(t for t in resolved_types if t != NoneType) + + for edge in self._get_input_edges(node_id, COLLECTION_FIELD): + source_node = self.get_node(edge.source.node_id) + if isinstance(source_node, CollectInvocation) and edge.source.field == COLLECTION_FIELD: + input_types.update(self._resolve_collector_input_types(source_node.id, visited.copy())) + continue + + input_field_type = get_output_field_type(source_node, edge.source.field) + input_types.update(extract_collection_item_types(input_field_type)) + + return input_types + + def _get_collector_input_root_type(self, node_id: str) -> Any | None: + input_types = self._resolve_collector_input_types(node_id) + non_any_input_types = {t for t in input_types if t != Any} + if len(non_any_input_types) == 0 and Any in input_types: + return Any + if len(non_any_input_types) == 0: + return None + + type_tree = nx.DiGraph() + type_tree.add_nodes_from(non_any_input_types) + type_tree.add_edges_from([e for e in itertools.permutations(non_any_input_types, 2) if issubclass(e[1], e[0])]) + type_degrees = type_tree.in_degree(type_tree.nodes) + root_types = [t[0] for t in type_degrees if t[1] == 0] # type: ignore + if len(root_types) != 1: + return Any + return root_types[0] + def _is_collector_connection_valid( self, node_id: str, new_input: Optional[EdgeConnection] = None, + new_input_field: Optional[str] = None, new_output: Optional[EdgeConnection] = None, ) -> str | None: - inputs = [e.source for e in self._get_input_edges(node_id, ITEM_FIELD)] + item_inputs = [e.source for e in self._get_input_edges(node_id, ITEM_FIELD)] + collection_inputs = [e.source for e in self._get_input_edges(node_id, COLLECTION_FIELD)] outputs = [e.destination for e in self._get_output_edges(node_id, COLLECTION_FIELD)] if new_input is not None: - inputs.append(new_input) + field = new_input_field or ITEM_FIELD + if field == ITEM_FIELD: + item_inputs.append(new_input) + elif field == COLLECTION_FIELD: + collection_inputs.append(new_input) if new_output is not None: outputs.append(new_output) - # Get input and output fields (the fields linked to the iterator's input/output) - input_field_types = [get_output_field_type(self.get_node(e.node_id), e.field) for e in inputs] + if len(item_inputs) == 0 and len(collection_inputs) == 0: + return "Collector must have at least one item or collection input edge" + + # Get input and output fields (the fields linked to the collector's input/output) + item_input_field_types = [get_output_field_type(self.get_node(e.node_id), e.field) for e in item_inputs] + collection_input_field_types = [ + get_output_field_type(self.get_node(e.node_id), e.field) for e in collection_inputs + ] output_field_types = [get_input_field_type(self.get_node(e.node_id), e.field) for e in outputs] + if not all((is_list_or_contains_list(t) or is_any(t) for t in collection_input_field_types)): + return "Collector collection input must be a collection" + # Validate that all inputs are derived from or match a single type input_field_types = { resolved_type - for input_field_type in input_field_types + for input_field_type in item_input_field_types for resolved_type in ( [input_field_type] if get_origin(input_field_type) is None else get_args(input_field_type) ) if resolved_type != NoneType } # Get unique types + + for input_conn, input_field_type in zip(collection_inputs, collection_input_field_types, strict=False): + source_node = self.get_node(input_conn.node_id) + if isinstance(source_node, CollectInvocation) and input_conn.field == COLLECTION_FIELD: + input_field_types.update(self._resolve_collector_input_types(source_node.id)) + continue + input_field_types.update(extract_collection_item_types(input_field_type)) + + non_any_input_field_types = {t for t in input_field_types if t != Any} type_tree = nx.DiGraph() - type_tree.add_nodes_from(input_field_types) - type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])]) + type_tree.add_nodes_from(non_any_input_field_types) + type_tree.add_edges_from( + [e for e in itertools.permutations(non_any_input_field_types, 2) if issubclass(e[1], e[0])] + ) type_degrees = type_tree.in_degree(type_tree.nodes) - if sum((t[1] == 0 for t in type_degrees)) != 1: # type: ignore + root_types = [t[0] for t in type_degrees if t[1] == 0] # type: ignore + if len(root_types) > 1: return "Collector input collection items must be of a single type" - # Get the input root type - input_root_type = next(t[0] for t in type_degrees if t[1] == 0) # type: ignore + # Get the input root type (if known) + input_root_type = root_types[0] if len(root_types) == 1 else None # Verify that all outputs are lists if not all(is_list_or_contains_list(t) or is_any(t) for t in output_field_types): return "Collector output must connect to a collection input" # Verify that all outputs match the input type (are a base class or the same class) - if not all( - is_any(t) - or is_union_subtype(input_root_type, get_args(t)[0]) - or issubclass(input_root_type, get_args(t)[0]) - for t in output_field_types - ): + if input_root_type is not None: + if not all( + is_any(t) + or is_union_subtype(input_root_type, get_args(t)[0]) + or issubclass(input_root_type, get_args(t)[0]) + for t in output_field_types + ): + return "Collector outputs must connect to a collection input with a matching type" + elif any(not is_any(t) and get_args(t)[0] != Any for t in output_field_types): return "Collector outputs must connect to a collection input with a matching type" + # If this collector outputs to another collector's collection input, validate against the downstream + # collector's resolved input type (if available). + for output in outputs: + output_node = self.get_node(output.node_id) + if not isinstance(output_node, CollectInvocation) or output.field != COLLECTION_FIELD: + continue + output_root_type = self._get_collector_input_root_type(output_node.id) + if output_root_type is None: + continue + if input_root_type is None: + if output_root_type != Any: + return "Collector outputs must connect to a collection input with a matching type" + continue + if not are_connection_types_compatible(input_root_type, output_root_type): + return "Collector outputs must connect to a collection input with a matching type" + return None def nx_graph(self) -> nx.DiGraph: @@ -1211,8 +1311,19 @@ def _prepare_inputs(self, node: BaseInvocation): if isinstance(node, CollectInvocation): item_edges = [e for e in input_edges if e.destination.field == ITEM_FIELD] item_edges.sort(key=lambda e: (self._get_iteration_path(e.source.node_id), e.source.node_id)) - - output_collection = [copydeep(getattr(self.results[e.source.node_id], e.source.field)) for e in item_edges] + collection_edges = [e for e in input_edges if e.destination.field == COLLECTION_FIELD] + collection_edges.sort(key=lambda e: (self._get_iteration_path(e.source.node_id), e.source.node_id)) + + output_collection = [] + for edge in collection_edges: + source_value = copydeep(getattr(self.results[edge.source.node_id], edge.source.field)) + if isinstance(source_value, list): + output_collection.extend(source_value) + else: + output_collection.append(source_value) + output_collection.extend( + copydeep(getattr(self.results[e.source.node_id], e.source.field)) for e in item_edges + ) node.collection = output_collection else: for edge in input_edges: diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.test.ts index 8adc013ab97..fb4d7ee48c2 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.test.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.test.ts @@ -41,4 +41,14 @@ describe(getCollectItemType.name, () => { const result = getCollectItemType({ add: addWithoutOutputValue, collect }, [n2, n1], [e1], n1.id); expect(result).toBeNull(); }); + + it('should return the upstream collect item type for chained collects', () => { + const n1 = buildNode(collect); + const n2 = buildNode(collect); + const n3 = buildNode(add); + const e1 = buildEdge(n3.id, 'value', n1.id, 'item'); + const e2 = buildEdge(n1.id, 'collection', n2.id, 'collection'); + const result = getCollectItemType(templates, [n1, n2, n3], [e1, e2], n2.id); + expect(result).toEqual({ name: 'IntegerField', cardinality: 'SINGLE', batch: false }); + }); }); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.ts b/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.ts index 9fb2795ae81..35ec20220ea 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.ts @@ -2,6 +2,16 @@ import type { Templates } from 'features/nodes/store/types'; import type { FieldType } from 'features/nodes/types/field'; import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation'; +const toItemType = (fieldType: FieldType): FieldType | null => { + if (fieldType.name === 'CollectionField') { + return null; + } + if (fieldType.cardinality === 'COLLECTION' || fieldType.cardinality === 'SINGLE_OR_COLLECTION') { + return { ...fieldType, cardinality: 'SINGLE' }; + } + return fieldType; +}; + /** * Given a collect node, return the type of the items it collects. The graph is traversed to find the first node and * field connected to the collector's `item` input. The field type of that field is returned, else null if there is no @@ -18,21 +28,56 @@ export const getCollectItemType = ( edges: AnyEdge[], nodeId: string ): FieldType | null => { - const firstEdgeToCollect = edges.find((edge) => edge.target === nodeId && edge.targetHandle === 'item'); - if (!firstEdgeToCollect?.sourceHandle) { - return null; - } - const node = nodes.find((n) => n.id === firstEdgeToCollect.source); - if (!node) { - return null; - } - const template = templates[node.data.type]; - if (!template) { - return null; - } - const fieldTemplate = template.outputs[firstEdgeToCollect.sourceHandle]; - if (!fieldTemplate) { + const getCollectItemTypeInternal = (currentNodeId: string, visited: Set): FieldType | null => { + if (visited.has(currentNodeId)) { + return null; + } + visited.add(currentNodeId); + + const firstItemEdgeToCollect = edges.find((edge) => edge.target === currentNodeId && edge.targetHandle === 'item'); + if (firstItemEdgeToCollect?.sourceHandle) { + const node = nodes.find((n) => n.id === firstItemEdgeToCollect.source); + if (!node) { + return null; + } + const template = templates[node.data.type]; + if (!template) { + return null; + } + const fieldTemplate = template.outputs[firstItemEdgeToCollect.sourceHandle]; + if (!fieldTemplate) { + return null; + } + return toItemType(fieldTemplate.type); + } + + const firstCollectionEdgeToCollect = edges.find( + (edge) => edge.target === currentNodeId && edge.targetHandle === 'collection' + ); + if (!firstCollectionEdgeToCollect?.sourceHandle) { + return null; + } + const sourceNode = nodes.find((n) => n.id === firstCollectionEdgeToCollect.source); + if (!sourceNode) { + return null; + } + if (sourceNode.data.type === 'collect' && firstCollectionEdgeToCollect.sourceHandle === 'collection') { + return getCollectItemTypeInternal(sourceNode.id, visited); + } + const sourceTemplate = templates[sourceNode.data.type]; + if (!sourceTemplate) { + return null; + } + const sourceFieldTemplate = sourceTemplate.outputs[firstCollectionEdgeToCollect.sourceHandle]; + if (!sourceFieldTemplate) { + return null; + } + return toItemType(sourceFieldTemplate.type); + }; + + const itemType = getCollectItemTypeInternal(nodeId, new Set()); + if (!itemType) { return null; } - return fieldTemplate.type; + return itemType; }; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts b/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts index 74425619844..1eb445beaf7 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts @@ -133,11 +133,27 @@ export const sub: InvocationTemplate = { export const collect: InvocationTemplate = { title: 'Collect', type: 'collect', - version: '1.0.0', + version: '1.1.0', tags: [], description: 'Collects values into a collection', outputType: 'collect_output', inputs: { + collection: { + name: 'collection', + title: 'Collection', + required: false, + default: undefined, + description: 'An optional collection to append to', + fieldKind: 'input', + input: 'connection', + ui_hidden: false, + ui_type: 'CollectionField' as const, + type: { + name: 'CollectionField' as const, + cardinality: 'COLLECTION', + batch: false, + }, + }, item: { name: 'item', title: 'Collection Item', @@ -1162,13 +1178,12 @@ export const schema = { items: {}, type: 'array', title: 'Collection', - description: 'The collection, will be provided on execution', - default: [], + description: 'An optional collection to append to', field_kind: 'input', - input: 'any', - orig_default: [], + input: 'connection', orig_required: false, - ui_hidden: true, + ui_hidden: false, + ui_type: 'CollectionField', }, type: { type: 'string', @@ -1185,7 +1200,7 @@ export const schema = { node_pack: 'invokeai', description: 'Collects values into a collection', classification: 'stable', - version: '1.0.0', + version: '1.1.0', output: { $ref: '#/components/schemas/CollectInvocationOutput', }, diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts index 4108f57c075..947d8745f02 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts @@ -122,6 +122,52 @@ describe(validateConnection.name, () => { expect(r).toEqual(null); }); + it('should accept chaining collect collection output to collect collection input', () => { + const n1 = buildNode(collect); + const n2 = buildNode(collect); + const nodes = [n1, n2]; + const c = { source: n1.id, sourceHandle: 'collection', target: n2.id, targetHandle: 'collection' }; + const r = validateConnection(c, nodes, [], templates, null); + expect(r).toEqual(null); + }); + + it('should reject multiple connections to collect collection input', () => { + const n1 = buildNode(collect); + const n2 = buildNode(collect); + const n3 = buildNode(collect); + const nodes = [n1, n2, n3]; + const e1 = buildEdge(n1.id, 'collection', n2.id, 'collection'); + const c = { source: n3.id, sourceHandle: 'collection', target: n2.id, targetHandle: 'collection' }; + const r = validateConnection(c, nodes, [e1], templates, null); + expect(r).toEqual('nodes.inputMayOnlyHaveOneConnection'); + }); + + it('should reject mismatched item connection when collect is typed via chained collection', () => { + const n1 = buildNode(add); + const n2 = buildNode(collect); + const n3 = buildNode(collect); + const n4 = buildNode(main_model_loader); + const nodes = [n1, n2, n3, n4]; + const e1 = buildEdge(n1.id, 'value', n2.id, 'item'); + const e2 = buildEdge(n2.id, 'collection', n3.id, 'collection'); + const c = { source: n4.id, sourceHandle: 'vae', target: n3.id, targetHandle: 'item' }; + const r = validateConnection(c, nodes, [e1, e2], templates, null); + expect(r).toEqual('nodes.cannotMixAndMatchCollectionItemTypes'); + }); + + it('should reject chaining collection-to-collection for differently typed collects', () => { + const n1 = buildNode(add); + const n2 = buildNode(img_resize); + const n3 = buildNode(collect); + const n4 = buildNode(collect); + const nodes = [n1, n2, n3, n4]; + const e1 = buildEdge(n1.id, 'value', n3.id, 'item'); + const e2 = buildEdge(n2.id, 'image', n4.id, 'item'); + const c = { source: n3.id, sourceHandle: 'collection', target: n4.id, targetHandle: 'collection' }; + const r = validateConnection(c, nodes, [e1, e2], templates, null); + expect(r).toEqual('nodes.cannotMixAndMatchCollectionItemTypes'); + }); + it('should reject connections to target field that is already connected', () => { const n1 = buildNode(add); const n2 = buildNode(add); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts index aaeb10edfdc..9024a16f42e 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts @@ -108,6 +108,24 @@ export const validateConnection: ValidateConnectionFunc = ( } } + if ( + sourceNode.data.type === 'collect' && + c.sourceHandle === 'collection' && + targetNode.data.type === 'collect' && + c.targetHandle === 'collection' + ) { + // Chained collect nodes should preserve a single item type when both ends are already typed. + const sourceCollectItemType = getCollectItemType(templates, nodes, edges, sourceNode.id); + const targetCollectItemType = getCollectItemType(templates, nodes, edges, targetNode.id); + if ( + sourceCollectItemType && + targetCollectItemType && + !areTypesEqual(sourceCollectItemType, targetCollectItemType) + ) { + return 'nodes.cannotMixAndMatchCollectionItemTypes'; + } + } + if (filteredEdges.find(getTargetEqualityPredicate(c))) { // CollectionItemField inputs can have multiple input connections if (targetFieldTemplate.type.name !== 'CollectionItemField') { diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.test.ts b/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.test.ts index c9dad573a98..45dd0f79a34 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.test.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.test.ts @@ -3,17 +3,19 @@ import { schema, templates } from 'features/nodes/store/util/testUtils'; import { parseSchema } from 'features/nodes/util/schema/parseSchema'; import { describe, expect, it } from 'vitest'; +const stripUndefinedDeep = (value: T): T => JSON.parse(JSON.stringify(value)) as T; + describe('parseSchema', () => { it('should parse the schema', () => { const parsed = parseSchema(schema); - expect(parsed).toEqual(templates); + expect(stripUndefinedDeep(parsed)).toEqual(stripUndefinedDeep(templates)); }); it('should omit denied nodes', () => { const parsed = parseSchema(schema, undefined, ['add']); - expect(parsed).toEqual(omit(templates, 'add')); + expect(stripUndefinedDeep(parsed)).toEqual(stripUndefinedDeep(omit(templates, 'add'))); }); it('should include only allowed nodes', () => { const parsed = parseSchema(schema, ['add']); - expect(parsed).toEqual(pick(templates, 'add')); + expect(stripUndefinedDeep(parsed)).toEqual(stripUndefinedDeep(pick(templates, 'add'))); }); }); diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts b/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts index 1371db1568d..57cd9943c57 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts @@ -39,9 +39,6 @@ const isReservedInputField = (nodeType: string, fieldName: string) => { if (RESERVED_INPUT_FIELD_NAMES.includes(fieldName)) { return true; } - if (nodeType === 'collect' && fieldName === 'collection') { - return true; - } if (nodeType === 'iterate' && fieldName === 'index') { return true; } diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index b605413787b..ea473b1d99d 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -4776,7 +4776,7 @@ export type components = { item?: unknown | null; /** * Collection - * @description The collection, will be provided on execution + * @description An optional collection to append to * @default [] */ collection?: unknown[]; diff --git a/tests/test_node_graph.py b/tests/test_node_graph.py index 160dc96d852..4f3b262204a 100644 --- a/tests/test_node_graph.py +++ b/tests/test_node_graph.py @@ -40,6 +40,7 @@ PromptTestInvocation, PromptTestInvocationOutput, TextToImageTestInvocation, + UnionCollectionTestInvocation, get_single_output_from_session, run_session_with_mock_context, ) @@ -337,6 +338,100 @@ def test_graph_collector_invalid_with_non_list_output(): g.add_edge(e3) +def test_graph_collector_can_chain_collection_input(): + g = Graph() + n1 = PromptCollectionTestInvocation(id="1", collection=["Banana", "Sushi"]) + n2 = PromptTestInvocation(id="2", prompt="Ramen") + n3 = CollectInvocation(id="3") + g.add_node(n1) + g.add_node(n2) + g.add_node(n3) + + g.add_edge(create_edge("1", "collection", "3", "collection")) + g.add_edge(create_edge("2", "prompt", "3", "item")) + + session = GraphExecutionState(graph=g) + run_session_with_mock_context(session) + output = get_single_output_from_session(session, n3.id) + + assert isinstance(output, CollectInvocationOutput) + assert output.collection == ["Banana", "Sushi", "Ramen"] + + +def test_graph_collector_chain_rejects_mismatched_item_type(): + g = Graph() + n1 = PromptCollectionTestInvocation(id="1", collection=["Banana", "Sushi"]) + n2 = IntegerInvocation(id="2", value=7) + n3 = CollectInvocation(id="3") + g.add_node(n1) + g.add_node(n2) + g.add_node(n3) + + g.add_edge(create_edge("1", "collection", "3", "collection")) + with pytest.raises(InvalidEdgeError): + g.add_edge(create_edge("2", "value", "3", "item")) + + +def test_graph_iterator_accepts_collector_chained_collection_input(): + g = Graph() + n1 = PromptTestInvocation(id="1", prompt="Banana") + n2 = CollectInvocation(id="2") + n3 = CollectInvocation(id="3") + n4 = IterateInvocation(id="4") + n5 = PromptTestInvocation(id="5") + g.add_node(n1) + g.add_node(n2) + g.add_node(n3) + g.add_node(n4) + g.add_node(n5) + + g.add_edge(create_edge("1", "prompt", "2", "item")) + g.add_edge(create_edge("2", "collection", "3", "collection")) + g.add_edge(create_edge("3", "collection", "4", "collection")) + g.add_edge(create_edge("4", "item", "5", "prompt")) + + session = GraphExecutionState(graph=g) + run_session_with_mock_context(session) + + output = get_single_output_from_session(session, n5.id) + assert isinstance(output, PromptTestInvocationOutput) + assert output.prompt == "Banana" + + +def test_graph_collector_chain_rejects_upstream_mismatch_added_late(): + g = Graph() + n1 = CollectInvocation(id="1") + n2 = CollectInvocation(id="2") + n3 = PromptTestInvocation(id="3", prompt="typed-as-string") + n4 = ColorInvocation(id="4") + g.add_node(n1) + g.add_node(n2) + g.add_node(n3) + g.add_node(n4) + + # Connect chain first while n1 is still untyped. + g.add_edge(create_edge("1", "collection", "2", "collection")) + # Constrain downstream collector to strings. + g.add_edge(create_edge("3", "prompt", "2", "item")) + # Now adding an incompatible type to the upstream collector must fail. + with pytest.raises(InvalidEdgeError): + g.add_edge(create_edge("4", "color", "1", "item")) + + +def test_graph_collector_rejects_mismatched_item_with_union_collection_input(): + g = Graph() + n1 = UnionCollectionTestInvocation(id="1") + n2 = CollectInvocation(id="2") + n3 = ColorInvocation(id="3") + g.add_node(n1) + g.add_node(n2) + g.add_node(n3) + + g.add_edge(create_edge("1", "value", "2", "collection")) + with pytest.raises(InvalidEdgeError): + g.add_edge(create_edge("3", "color", "2", "item")) + + def test_graph_connects_iterator(): g = Graph() n1 = ListPassThroughInvocation(id="1") @@ -712,6 +807,24 @@ def test_iterate_accepts_collection(): g.add_edge(e3) +def test_iterate_accepts_collection_from_any_only_collector(): + g = Graph() + n1 = AnyTypeTestInvocation(id="1") + n2 = CollectInvocation(id="2") + n3 = IterateInvocation(id="3") + n4 = AnyTypeTestInvocation(id="4") + g.add_node(n1) + g.add_node(n2) + g.add_node(n3) + g.add_node(n4) + e1 = create_edge(n1.id, "value", n2.id, "item") + e2 = create_edge(n2.id, "collection", n3.id, "collection") + e3 = create_edge(n3.id, "item", n4.id, "value") + g.add_edge(e1) + g.add_edge(e2) + g.add_edge(e3) + + def test_iterate_validates_collection_inputs_against_iterator_outputs(): g = Graph() n1 = IntegerInvocation(id="1", value=1) diff --git a/tests/test_nodes.py b/tests/test_nodes.py index 04ea5126f02..6e8d25a6034 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -107,6 +107,19 @@ def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOu return PromptCollectionTestInvocationOutput(collection=self.value) +@invocation_output("test_union_collection_output") +class UnionCollectionTestInvocationOutput(BaseInvocationOutput): + value: Union[str, list[str], None] = OutputField(default=None) + + +@invocation("test_union_collection", version="1.0.0") +class UnionCollectionTestInvocation(BaseInvocation): + value: Union[str, list[str], None] = InputField(default=None) + + def invoke(self, context: InvocationContext) -> UnionCollectionTestInvocationOutput: + return UnionCollectionTestInvocationOutput(value=self.value) + + # Importing these must happen after test invocations are defined or they won't register from invokeai.app.services.events.events_base import EventServiceBase # noqa: E402 from invokeai.app.services.shared.graph import Edge, EdgeConnection, GraphExecutionState # noqa: E402