diff --git a/adv_control/control.py b/adv_control/control.py
index 810f3d1..3711d32 100644
--- a/adv_control/control.py
+++ b/adv_control/control.py
@@ -523,10 +523,50 @@ def convert_to_advanced(control, timestep_keyframe: TimestepKeyframeGroup=None):
 def is_advanced_controlnet(input_object):
     return hasattr(input_object, "sub_idxs")
 
+
+def adjust_positional_encoding_parameters(controlnet_data, expected_seq_len):
+    """
+    Resize positional-encoding tensors in a state dict to a target length.
+
+    Every entry whose key contains "pos_encoder.pe" is unpacked as
+    (batch, seq_len, dim). When seq_len differs from the target, the entry is
+    replaced in place: longer tensors are truncated, shorter ones are padded
+    with zeros at the end, so padded positions hold zeros, not encoding values.
+
+    Args:
+        controlnet_data: State dict (key -> tensor); mutated in place.
+        expected_seq_len: Target sequence length, coerced with int().
+
+    Raises:
+        ValueError: If expected_seq_len is smaller than 1.
+    """
+    # Coerce once up front: a float such as 32.0 is not a valid tensor size.
+    expected_seq_len = int(expected_seq_len)
+    if expected_seq_len < 1:
+        raise ValueError(f"expected_seq_len must be >= 1, got {expected_seq_len}")
+
+    pe_keys = [key for key in controlnet_data if "pos_encoder.pe" in key]
+    for key in pe_keys:
+        original_pe = controlnet_data[key]
+        batch, seq_len, dim = original_pe.shape
+        if seq_len == expected_seq_len:
+            continue
+        # new_zeros keeps the dtype and device of the checkpoint tensor, which
+        # a bare torch.zeros call would replace with the global defaults.
+        adjusted_pe = original_pe.new_zeros((batch, expected_seq_len, dim))
+        length_to_copy = min(seq_len, expected_seq_len)
+        adjusted_pe[:, :length_to_copy, :] = original_pe[:, :length_to_copy, :]
+        controlnet_data[key] = adjusted_pe
+
 
 def load_sparsectrl(ckpt_path: str, controlnet_data: dict[str, Tensor]=None, timestep_keyframe: TimestepKeyframeGroup=None, sparse_settings=SparseSettings.default(), model=None) -> SparseCtrlAdvanced:
     if controlnet_data is None:
         controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
+
+    # Resize positional encodings to sparse_settings.expected_seq_len before the
+    # state dict is split into motion and controlnet parts below.
+    adjust_positional_encoding_parameters(controlnet_data, sparse_settings.expected_seq_len)
+
     # first, separate out motion part from normal controlnet part and attempt to load that portion
     motion_data = {}
     for key in list(controlnet_data.keys()):
diff --git a/adv_control/control_sparsectrl.py b/adv_control/control_sparsectrl.py
index 5885ed5..41e793c 100644
--- a/adv_control/control_sparsectrl.py
+++ b/adv_control/control_sparsectrl.py
@@ -127,16 +127,17 @@ def __setitem__(self, *args, **kwargs):
 
 
 class SparseSettings:
-    def __init__(self, sparse_method: 'SparseMethod', use_motion: bool=True, motion_strength=1.0, motion_scale=1.0, merged=False):
+    def __init__(self, sparse_method: 'SparseMethod', use_motion: bool=True, motion_strength=1.0, motion_scale=1.0, merged=False, expected_seq_len: int=32):
         self.sparse_method = sparse_method
         self.use_motion = use_motion
         self.motion_strength = motion_strength
         self.motion_scale = motion_scale
         self.merged = merged
+        self.expected_seq_len = expected_seq_len  # target length for positional-encoding tensors when loading
 
     @classmethod
     def default(cls):
-        return SparseSettings(sparse_method=SparseSpreadMethod(), use_motion=True)
+        return SparseSettings(sparse_method=SparseSpreadMethod(), use_motion=True, expected_seq_len=32)
 
 
 class SparseMethod(ABC):
diff --git a/adv_control/nodes_sparsectrl.py b/adv_control/nodes_sparsectrl.py
index 4df32b0..1d2b2af 100644
--- a/adv_control/nodes_sparsectrl.py
+++ b/adv_control/nodes_sparsectrl.py
@@ -20,6 +20,7 @@ def INPUT_TYPES(s):
                 "use_motion": ("BOOLEAN", {"default": True}, ),
                 "motion_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
                 "motion_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
+                "expected_seq_len": ("INT", {"default": 32, "min": 1, "step": 1}, ),
             },
             "optional": {
                 "sparse_method": ("SPARSE_METHOD", ),
@@ -32,9 +33,9 @@ def INPUT_TYPES(s):
 
     CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/SparseCtrl"
 
-    def load_controlnet(self, sparsectrl_name: str, use_motion: bool, motion_strength: float, motion_scale: float, sparse_method: SparseMethod=SparseSpreadMethod(), tk_optional: TimestepKeyframeGroup=None):
+    def load_controlnet(self, sparsectrl_name: str, use_motion: bool, motion_strength: float, motion_scale: float, sparse_method: SparseMethod=SparseSpreadMethod(), tk_optional: TimestepKeyframeGroup=None, expected_seq_len: int=32):
         sparsectrl_path = folder_paths.get_full_path("controlnet", sparsectrl_name)
-        sparse_settings = SparseSettings(sparse_method=sparse_method, use_motion=use_motion, motion_strength=motion_strength, motion_scale=motion_scale)
+        sparse_settings = SparseSettings(sparse_method=sparse_method, use_motion=use_motion, motion_strength=motion_strength, motion_scale=motion_scale, expected_seq_len=expected_seq_len)
         sparsectrl = load_sparsectrl(sparsectrl_path, timestep_keyframe=tk_optional, sparse_settings=sparse_settings)
         return (sparsectrl,)
 