Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 150 additions & 39 deletions invokeai/app/services/shared/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,23 @@ def is_any(t: Any) -> bool:
return t == Any or Any in get_args(t)


def extract_collection_item_types(t: Any) -> set[Any]:
"""Extracts list item types from a collection annotation, including unions containing list branches."""
if is_any(t):
return {Any}

if get_origin(t) is list:
return {arg for arg in get_args(t) if arg != NoneType}

item_types: set[Any] = set()
for arg in get_args(t):
if is_any(arg):
item_types.add(Any)
elif get_origin(arg) is list:
item_types.update(item_arg for item_arg in get_args(arg) if item_arg != NoneType)
return item_types


def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
if not from_type or not to_type:
return False
Expand Down Expand Up @@ -280,7 +297,7 @@ class CollectInvocationOutput(BaseInvocationOutput):
)


@invocation("collect", version="1.0.0")
@invocation("collect", version="1.1.0")
class CollectInvocation(BaseInvocation):
"""Collects values into a collection"""

Expand All @@ -292,7 +309,10 @@ class CollectInvocation(BaseInvocation):
input=Input.Connection,
)
collection: list[Any] = InputField(
description="The collection, will be provided on execution", default=[], ui_hidden=True
description="An optional collection to append to",
default=[],
ui_type=UIType._Collection,
input=Input.Connection,
)

def invoke(self, context: InvocationContext) -> CollectInvocationOutput:
Expand Down Expand Up @@ -520,7 +540,9 @@ def _validate_edge(self, edge: Edge):

# Validate that an edge to this node+field doesn't already exist
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
if len(input_edges) > 0 and (
not isinstance(to_node, CollectInvocation) or edge.destination.field != ITEM_FIELD
):
raise InvalidEdgeError(f"Edge already exists ({edge})")

# Validate that no cycles would be created
Expand All @@ -546,8 +568,10 @@ def _validate_edge(self, edge: Edge):
raise InvalidEdgeError(f"Iterator output type does not match iterator input type ({edge}): {err}")

# Validate if collector input type matches output type (if this edge results in both being set)
if isinstance(to_node, CollectInvocation) and edge.destination.field == ITEM_FIELD:
err = self._is_collector_connection_valid(edge.destination.node_id, new_input=edge.source)
if isinstance(to_node, CollectInvocation) and edge.destination.field in (ITEM_FIELD, COLLECTION_FIELD):
err = self._is_collector_connection_valid(
edge.destination.node_id, new_input=edge.source, new_input_field=edge.destination.field
)
if err is not None:
raise InvalidEdgeError(f"Collector output type does not match collector input type ({edge}): {err}")

Expand Down Expand Up @@ -676,76 +700,152 @@ def _is_iterator_connection_valid(

# Collector input type must match all iterator output types
if isinstance(input_node, CollectInvocation):
collector_inputs = self._get_input_edges(input_node.id, ITEM_FIELD)
if len(collector_inputs) == 0:
return "Iterator input collector must have at least one item input edge"

# Traverse the graph to find the first collector input edge. Collectors validate that their collection
# inputs are all of the same type, so we can use the first input edge to determine the collector's type
first_collector_input_edge = collector_inputs[0]
first_collector_input_type = get_output_field_type(
self.get_node(first_collector_input_edge.source.node_id), first_collector_input_edge.source.field
)
resolved_collector_type = (
first_collector_input_type
if get_origin(first_collector_input_type) is None
else get_args(first_collector_input_type)
)
if not all((are_connection_types_compatible(resolved_collector_type, t) for t in output_field_types)):
input_root_type = self._get_collector_input_root_type(input_node.id)
if input_root_type is None:
return "Iterator input collector must have at least one item or collection input edge"
if not all((are_connection_types_compatible(input_root_type, t) for t in output_field_types)):
return "Iterator collection type must match all iterator output types"

return None

def _resolve_collector_input_types(self, node_id: str, visited: Optional[set[str]] = None) -> set[Any]:
"""Resolves possible item types for a collector's inputs, recursively following chained collectors."""
visited = visited or set()
if node_id in visited:
return set()
visited.add(node_id)

input_types: set[Any] = set()

for edge in self._get_input_edges(node_id, ITEM_FIELD):
input_field_type = get_output_field_type(self.get_node(edge.source.node_id), edge.source.field)
resolved_types = [input_field_type] if get_origin(input_field_type) is None else get_args(input_field_type)
input_types.update(t for t in resolved_types if t != NoneType)

for edge in self._get_input_edges(node_id, COLLECTION_FIELD):
source_node = self.get_node(edge.source.node_id)
if isinstance(source_node, CollectInvocation) and edge.source.field == COLLECTION_FIELD:
input_types.update(self._resolve_collector_input_types(source_node.id, visited.copy()))
continue

input_field_type = get_output_field_type(source_node, edge.source.field)
input_types.update(extract_collection_item_types(input_field_type))

return input_types

def _get_collector_input_root_type(self, node_id: str) -> Any | None:
input_types = self._resolve_collector_input_types(node_id)
non_any_input_types = {t for t in input_types if t != Any}
if len(non_any_input_types) == 0 and Any in input_types:
return Any
if len(non_any_input_types) == 0:
return None

type_tree = nx.DiGraph()
type_tree.add_nodes_from(non_any_input_types)
type_tree.add_edges_from([e for e in itertools.permutations(non_any_input_types, 2) if issubclass(e[1], e[0])])
type_degrees = type_tree.in_degree(type_tree.nodes)
root_types = [t[0] for t in type_degrees if t[1] == 0] # type: ignore
if len(root_types) != 1:
return Any
return root_types[0]

def _is_collector_connection_valid(
self,
node_id: str,
new_input: Optional[EdgeConnection] = None,
new_input_field: Optional[str] = None,
new_output: Optional[EdgeConnection] = None,
) -> str | None:
inputs = [e.source for e in self._get_input_edges(node_id, ITEM_FIELD)]
item_inputs = [e.source for e in self._get_input_edges(node_id, ITEM_FIELD)]
collection_inputs = [e.source for e in self._get_input_edges(node_id, COLLECTION_FIELD)]
outputs = [e.destination for e in self._get_output_edges(node_id, COLLECTION_FIELD)]

if new_input is not None:
inputs.append(new_input)
field = new_input_field or ITEM_FIELD
if field == ITEM_FIELD:
item_inputs.append(new_input)
elif field == COLLECTION_FIELD:
collection_inputs.append(new_input)
if new_output is not None:
outputs.append(new_output)

# Get input and output fields (the fields linked to the iterator's input/output)
input_field_types = [get_output_field_type(self.get_node(e.node_id), e.field) for e in inputs]
if len(item_inputs) == 0 and len(collection_inputs) == 0:
return "Collector must have at least one item or collection input edge"

# Get input and output fields (the fields linked to the collector's input/output)
item_input_field_types = [get_output_field_type(self.get_node(e.node_id), e.field) for e in item_inputs]
collection_input_field_types = [
get_output_field_type(self.get_node(e.node_id), e.field) for e in collection_inputs
]
output_field_types = [get_input_field_type(self.get_node(e.node_id), e.field) for e in outputs]

if not all((is_list_or_contains_list(t) or is_any(t) for t in collection_input_field_types)):
return "Collector collection input must be a collection"

# Validate that all inputs are derived from or match a single type
input_field_types = {
resolved_type
for input_field_type in input_field_types
for input_field_type in item_input_field_types
for resolved_type in (
[input_field_type] if get_origin(input_field_type) is None else get_args(input_field_type)
)
if resolved_type != NoneType
} # Get unique types

for input_conn, input_field_type in zip(collection_inputs, collection_input_field_types, strict=False):
source_node = self.get_node(input_conn.node_id)
if isinstance(source_node, CollectInvocation) and input_conn.field == COLLECTION_FIELD:
input_field_types.update(self._resolve_collector_input_types(source_node.id))
continue
input_field_types.update(extract_collection_item_types(input_field_type))

non_any_input_field_types = {t for t in input_field_types if t != Any}
type_tree = nx.DiGraph()
type_tree.add_nodes_from(input_field_types)
type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])])
type_tree.add_nodes_from(non_any_input_field_types)
type_tree.add_edges_from(
[e for e in itertools.permutations(non_any_input_field_types, 2) if issubclass(e[1], e[0])]
)
type_degrees = type_tree.in_degree(type_tree.nodes)
if sum((t[1] == 0 for t in type_degrees)) != 1: # type: ignore
root_types = [t[0] for t in type_degrees if t[1] == 0] # type: ignore
if len(root_types) > 1:
return "Collector input collection items must be of a single type"

# Get the input root type
input_root_type = next(t[0] for t in type_degrees if t[1] == 0) # type: ignore
# Get the input root type (if known)
input_root_type = root_types[0] if len(root_types) == 1 else None

# Verify that all outputs are lists
if not all(is_list_or_contains_list(t) or is_any(t) for t in output_field_types):
return "Collector output must connect to a collection input"

# Verify that all outputs match the input type (are a base class or the same class)
if not all(
is_any(t)
or is_union_subtype(input_root_type, get_args(t)[0])
or issubclass(input_root_type, get_args(t)[0])
for t in output_field_types
):
if input_root_type is not None:
if not all(
is_any(t)
or is_union_subtype(input_root_type, get_args(t)[0])
or issubclass(input_root_type, get_args(t)[0])
for t in output_field_types
):
return "Collector outputs must connect to a collection input with a matching type"
elif any(not is_any(t) and get_args(t)[0] != Any for t in output_field_types):
return "Collector outputs must connect to a collection input with a matching type"

# If this collector outputs to another collector's collection input, validate against the downstream
# collector's resolved input type (if available).
for output in outputs:
output_node = self.get_node(output.node_id)
if not isinstance(output_node, CollectInvocation) or output.field != COLLECTION_FIELD:
continue
output_root_type = self._get_collector_input_root_type(output_node.id)
if output_root_type is None:
continue
if input_root_type is None:
if output_root_type != Any:
return "Collector outputs must connect to a collection input with a matching type"
continue
if not are_connection_types_compatible(input_root_type, output_root_type):
return "Collector outputs must connect to a collection input with a matching type"

return None

def nx_graph(self) -> nx.DiGraph:
Expand Down Expand Up @@ -1211,8 +1311,19 @@ def _prepare_inputs(self, node: BaseInvocation):
if isinstance(node, CollectInvocation):
item_edges = [e for e in input_edges if e.destination.field == ITEM_FIELD]
item_edges.sort(key=lambda e: (self._get_iteration_path(e.source.node_id), e.source.node_id))

output_collection = [copydeep(getattr(self.results[e.source.node_id], e.source.field)) for e in item_edges]
collection_edges = [e for e in input_edges if e.destination.field == COLLECTION_FIELD]
collection_edges.sort(key=lambda e: (self._get_iteration_path(e.source.node_id), e.source.node_id))

output_collection = []
for edge in collection_edges:
source_value = copydeep(getattr(self.results[edge.source.node_id], edge.source.field))
if isinstance(source_value, list):
output_collection.extend(source_value)
else:
output_collection.append(source_value)
output_collection.extend(
copydeep(getattr(self.results[e.source.node_id], e.source.field)) for e in item_edges
)
node.collection = output_collection
else:
for edge in input_edges:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,14 @@ describe(getCollectItemType.name, () => {
const result = getCollectItemType({ add: addWithoutOutputValue, collect }, [n2, n1], [e1], n1.id);
expect(result).toBeNull();
});

it('should return the upstream collect item type for chained collects', () => {
const n1 = buildNode(collect);
const n2 = buildNode(collect);
const n3 = buildNode(add);
const e1 = buildEdge(n3.id, 'value', n1.id, 'item');
const e2 = buildEdge(n1.id, 'collection', n2.id, 'collection');
const result = getCollectItemType(templates, [n1, n2, n3], [e1, e2], n2.id);
expect(result).toEqual<FieldType>({ name: 'IntegerField', cardinality: 'SINGLE', batch: false });
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,16 @@ import type { Templates } from 'features/nodes/store/types';
import type { FieldType } from 'features/nodes/types/field';
import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation';

const toItemType = (fieldType: FieldType): FieldType | null => {
if (fieldType.name === 'CollectionField') {
return null;
}
if (fieldType.cardinality === 'COLLECTION' || fieldType.cardinality === 'SINGLE_OR_COLLECTION') {
return { ...fieldType, cardinality: 'SINGLE' };
}
return fieldType;
};

/**
* Given a collect node, return the type of the items it collects. The graph is traversed to find the first node and
* field connected to the collector's `item` input. The field type of that field is returned, else null if there is no
Expand All @@ -18,21 +28,56 @@ export const getCollectItemType = (
edges: AnyEdge[],
nodeId: string
): FieldType | null => {
const firstEdgeToCollect = edges.find((edge) => edge.target === nodeId && edge.targetHandle === 'item');
if (!firstEdgeToCollect?.sourceHandle) {
return null;
}
const node = nodes.find((n) => n.id === firstEdgeToCollect.source);
if (!node) {
return null;
}
const template = templates[node.data.type];
if (!template) {
return null;
}
const fieldTemplate = template.outputs[firstEdgeToCollect.sourceHandle];
if (!fieldTemplate) {
const getCollectItemTypeInternal = (currentNodeId: string, visited: Set<string>): FieldType | null => {
if (visited.has(currentNodeId)) {
return null;
}
visited.add(currentNodeId);

const firstItemEdgeToCollect = edges.find((edge) => edge.target === currentNodeId && edge.targetHandle === 'item');
if (firstItemEdgeToCollect?.sourceHandle) {
const node = nodes.find((n) => n.id === firstItemEdgeToCollect.source);
if (!node) {
return null;
}
const template = templates[node.data.type];
if (!template) {
return null;
}
const fieldTemplate = template.outputs[firstItemEdgeToCollect.sourceHandle];
if (!fieldTemplate) {
return null;
}
return toItemType(fieldTemplate.type);
}

const firstCollectionEdgeToCollect = edges.find(
(edge) => edge.target === currentNodeId && edge.targetHandle === 'collection'
);
if (!firstCollectionEdgeToCollect?.sourceHandle) {
return null;
}
const sourceNode = nodes.find((n) => n.id === firstCollectionEdgeToCollect.source);
if (!sourceNode) {
return null;
}
if (sourceNode.data.type === 'collect' && firstCollectionEdgeToCollect.sourceHandle === 'collection') {
return getCollectItemTypeInternal(sourceNode.id, visited);
}
const sourceTemplate = templates[sourceNode.data.type];
if (!sourceTemplate) {
return null;
}
const sourceFieldTemplate = sourceTemplate.outputs[firstCollectionEdgeToCollect.sourceHandle];
if (!sourceFieldTemplate) {
return null;
}
return toItemType(sourceFieldTemplate.type);
};

const itemType = getCollectItemTypeInternal(nodeId, new Set());
if (!itemType) {
return null;
}
return fieldTemplate.type;
return itemType;
};
Loading