|
9 | 9 | from ..logging.logging_config import get_logger |
10 | 10 | from .node_classifier import NodeClassifier |
11 | 11 | from ..configs.model_config import ModelConfig |
| 12 | +from ..configs.comfyui_models import MULTI_MODEL_WIDGET_CONFIGS |
12 | 13 | from ..models.workflow import ( |
13 | 14 | WorkflowNodeWidgetRef, |
14 | 15 | WorkflowNode, |
@@ -88,57 +89,151 @@ def analyze_dependencies(self) -> WorkflowDependencies: |
88 | 89 | def _extract_model_node_refs(self, node_id: str, node_info: WorkflowNode) -> List["WorkflowNodeWidgetRef"]: |
89 | 90 | """Extract possible model references from a single node. |
90 | 91 |
|
| 92 | + Uses a two-pronged approach: |
| 93 | + 1. Extract from properties.models (preferred - has URLs for auto-download) |
| 94 | + 2. Fall back to widget extraction using MULTI_MODEL_WIDGET_CONFIGS |
| 95 | +
|
91 | 96 | Args: |
92 | 97 | node_id: Scoped node ID from workflow.nodes dict key (e.g., "uuid:12" for subgraph nodes) |
93 | 98 | node_info: WorkflowNode object containing node data |
94 | 99 | """ |
| 100 | + refs: list[WorkflowNodeWidgetRef] = [] |
95 | 101 |
|
96 | | - refs = [] |
| 102 | + # Strategy 1: Extract from properties.models (preferred - has URLs) |
| 103 | + property_models = node_info.properties.get('models', []) |
| 104 | + if property_models: |
| 105 | + refs.extend(self._extract_from_properties_models(node_id, node_info, property_models)) |
97 | 106 |
|
98 | | - # Handle multi-model nodes specially |
99 | | - if node_info.type == "CheckpointLoader": |
100 | | - # Index 0: checkpoint, Index 1: config |
101 | | - widgets = node_info.widgets_values or [] |
102 | | - if len(widgets) > 0 and widgets[0]: |
103 | | - refs.append(WorkflowNodeWidgetRef( |
104 | | - node_id=node_id, # Use scoped ID from dict key |
105 | | - node_type=node_info.type, |
106 | | - widget_index=0, |
107 | | - widget_value=widgets[0] |
108 | | - )) |
109 | | - if len(widgets) > 1 and widgets[1]: |
110 | | - refs.append(WorkflowNodeWidgetRef( |
111 | | - node_id=node_id, # Use scoped ID from dict key |
112 | | - node_type=node_info.type, |
113 | | - widget_index=1, |
114 | | - widget_value=widgets[1] |
115 | | - )) |
| 107 | + # Strategy 2: Multi-model nodes (explicit widget indices from config) |
| 108 | + if node_info.type in MULTI_MODEL_WIDGET_CONFIGS: |
| 109 | + widget_refs = self._extract_multi_model_widgets(node_id, node_info) |
| 110 | + refs = self._merge_model_refs(refs, widget_refs) |
116 | 111 |
|
117 | | - # Standard single-model loaders |
| 112 | + # Strategy 3: Standard single-model loaders |
118 | 113 | elif self.model_config.is_model_loader_node(node_info.type): |
119 | | - widget_idx = self.model_config.get_widget_index_for_node(node_info.type) |
120 | | - widgets = node_info.widgets_values or [] |
121 | | - if widget_idx < len(widgets) and widgets[widget_idx]: |
122 | | - refs.append(WorkflowNodeWidgetRef( |
123 | | - node_id=node_id, # Use scoped ID from dict key |
124 | | - node_type=node_info.type, |
125 | | - widget_index=widget_idx, |
126 | | - widget_value=widgets[widget_idx] |
127 | | - )) |
| 114 | + widget_refs = self._extract_single_model_widget(node_id, node_info) |
| 115 | + refs = self._merge_model_refs(refs, widget_refs) |
128 | 116 |
|
129 | | - # Pattern match all widgets for custom nodes |
| 117 | + # Strategy 4: Pattern match all widgets for custom nodes |
130 | 118 | else: |
131 | | - widgets = node_info.widgets_values or [] |
132 | | - for idx, value in enumerate(widgets): |
133 | | - if self._looks_like_model(value): |
| 119 | + widget_refs = self._extract_by_pattern(node_id, node_info) |
| 120 | + refs = self._merge_model_refs(refs, widget_refs) |
| 121 | + |
| 122 | + return refs |
| 123 | + |
| 124 | + def _extract_from_properties_models( |
| 125 | + self, |
| 126 | + node_id: str, |
| 127 | + node_info: WorkflowNode, |
| 128 | + property_models: list[dict] |
| 129 | + ) -> list[WorkflowNodeWidgetRef]: |
| 130 | + """Extract model refs from node.properties.models array. |
| 131 | +
|
| 132 | + Properties models have structure: |
| 133 | + {"name": "model.safetensors", "url": "https://...", "directory": "text_encoders"} |
| 134 | + """ |
| 135 | + refs = [] |
| 136 | + for idx, model_entry in enumerate(property_models): |
| 137 | + if not isinstance(model_entry, dict): |
| 138 | + continue |
| 139 | + name = model_entry.get('name', '') |
| 140 | + if not name: |
| 141 | + continue |
| 142 | + |
| 143 | + # Find corresponding widget index by matching name to widgets_values |
| 144 | + widget_idx = self._find_widget_index_for_name(node_info, name) |
| 145 | + |
| 146 | + refs.append(WorkflowNodeWidgetRef( |
| 147 | + node_id=node_id, |
| 148 | + node_type=node_info.type, |
| 149 | + widget_index=widget_idx if widget_idx is not None else idx, |
| 150 | + widget_value=name, |
| 151 | + property_url=model_entry.get('url'), |
| 152 | + property_directory=model_entry.get('directory') |
| 153 | + )) |
| 154 | + return refs |
| 155 | + |
| 156 | + def _find_widget_index_for_name(self, node_info: WorkflowNode, name: str) -> int | None: |
| 157 | + """Find widget index that contains the given model name.""" |
| 158 | + widgets = node_info.widgets_values or [] |
| 159 | + for idx, value in enumerate(widgets): |
| 160 | + if isinstance(value, str) and value == name: |
| 161 | + return idx |
| 162 | + return None |
| 163 | + |
| 164 | + def _extract_multi_model_widgets(self, node_id: str, node_info: WorkflowNode) -> list[WorkflowNodeWidgetRef]: |
| 165 | + """Extract models from multi-model nodes using MULTI_MODEL_WIDGET_CONFIGS. |
| 166 | +
|
| 167 | + Note: Unlike pattern matching, multi-model configs explicitly define which |
| 168 | + widgets contain models, so we trust them without extension filtering. |
| 169 | + This allows CheckpointLoader to capture both .safetensors and .yaml configs. |
| 170 | + """ |
| 171 | + refs = [] |
| 172 | + widget_indices = MULTI_MODEL_WIDGET_CONFIGS.get(node_info.type, []) |
| 173 | + widgets = node_info.widgets_values or [] |
| 174 | + |
| 175 | + for widget_idx in widget_indices: |
| 176 | + if widget_idx < len(widgets) and widgets[widget_idx]: |
| 177 | + value = widgets[widget_idx] |
| 178 | + if isinstance(value, str) and value.strip(): |
134 | 179 | refs.append(WorkflowNodeWidgetRef( |
135 | | - node_id=node_id, # Use scoped ID from dict key |
| 180 | + node_id=node_id, |
136 | 181 | node_type=node_info.type, |
137 | | - widget_index=idx, |
| 182 | + widget_index=widget_idx, |
138 | 183 | widget_value=value |
139 | 184 | )) |
| 185 | + return refs |
| 186 | + |
| 187 | + def _extract_single_model_widget(self, node_id: str, node_info: WorkflowNode) -> list[WorkflowNodeWidgetRef]: |
| 188 | + """Extract model from standard single-model loader nodes.""" |
| 189 | + refs = [] |
| 190 | + widget_idx = self.model_config.get_widget_index_for_node(node_info.type) |
| 191 | + widgets = node_info.widgets_values or [] |
| 192 | + |
| 193 | + if widget_idx < len(widgets) and widgets[widget_idx]: |
| 194 | + refs.append(WorkflowNodeWidgetRef( |
| 195 | + node_id=node_id, |
| 196 | + node_type=node_info.type, |
| 197 | + widget_index=widget_idx, |
| 198 | + widget_value=widgets[widget_idx] |
| 199 | + )) |
| 200 | + return refs |
| 201 | + |
| 202 | + def _extract_by_pattern(self, node_id: str, node_info: WorkflowNode) -> list[WorkflowNodeWidgetRef]: |
| 203 | + """Extract models by pattern matching widget values (for custom nodes).""" |
| 204 | + refs = [] |
| 205 | + widgets = node_info.widgets_values or [] |
140 | 206 |
|
| 207 | + for idx, value in enumerate(widgets): |
| 208 | + if self._looks_like_model(value): |
| 209 | + refs.append(WorkflowNodeWidgetRef( |
| 210 | + node_id=node_id, |
| 211 | + node_type=node_info.type, |
| 212 | + widget_index=idx, |
| 213 | + widget_value=value |
| 214 | + )) |
141 | 215 | return refs |
| 216 | + |
| 217 | + def _merge_model_refs( |
| 218 | + self, |
| 219 | + property_refs: list[WorkflowNodeWidgetRef], |
| 220 | + widget_refs: list[WorkflowNodeWidgetRef] |
| 221 | + ) -> list[WorkflowNodeWidgetRef]: |
| 222 | + """Merge property refs with widget refs, preserving property metadata. |
| 223 | +
|
| 224 | + Property refs take precedence when both have the same widget_value, |
| 225 | + since they may contain URL metadata for auto-download. |
| 226 | + """ |
| 227 | + # Build set of values already in property_refs |
| 228 | + property_values = {ref.widget_value for ref in property_refs} |
| 229 | + |
| 230 | + # Add widget refs that aren't already covered by property refs |
| 231 | + merged = list(property_refs) |
| 232 | + for ref in widget_refs: |
| 233 | + if ref.widget_value not in property_values: |
| 234 | + merged.append(ref) |
| 235 | + |
| 236 | + return merged |
142 | 237 |
|
143 | 238 | def _looks_like_model(self, value: Any) -> bool: |
144 | 239 | """Check if value looks like a model path""" |
|
0 commit comments