From 313638d13ab6c0f28e8afb7d2cd7c4c67371b231 Mon Sep 17 00:00:00 2001
From: Rattus
Date: Mon, 27 Oct 2025 18:52:07 +1000
Subject: [PATCH 1/4] mm: factor out the current stream getter

Make this a reusable function.
---
 comfy/model_management.py | 22 +++++++++++++---------
 1 file changed, 13 insertions(+), 9 deletions(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index 3e5b977d4375..cd8326f5f40e 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -1013,6 +1013,16 @@ def force_channels_last():
     NUM_STREAMS = 2
     logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
 
+def current_stream(device):
+    if device is None:
+        return None
+    if is_device_cuda(device):
+        return torch.cuda.current_stream()
+    elif is_device_xpu(device):
+        return torch.xpu.current_stream()
+    else:
+        return None
+
 stream_counters = {}
 def get_offload_stream(device):
     stream_counter = stream_counters.get(device, 0)
@@ -1023,10 +1033,7 @@ def get_offload_stream(device):
         ss = STREAMS[device]
         s = ss[stream_counter]
         stream_counter = (stream_counter + 1) % len(ss)
-        if is_device_cuda(device):
-            ss[stream_counter].wait_stream(torch.cuda.current_stream())
-        elif is_device_xpu(device):
-            ss[stream_counter].wait_stream(torch.xpu.current_stream())
+        ss[stream_counter].wait_stream(current_stream(device))
         stream_counters[device] = stream_counter
         return s
     elif is_device_cuda(device):
@@ -1050,12 +1057,9 @@ def get_offload_stream(device):
     return None
 
 def sync_stream(device, stream):
-    if stream is None:
+    if stream is None or current_stream(device) is None:
         return
-    if is_device_cuda(device):
-        torch.cuda.current_stream().wait_stream(stream)
-    elif is_device_xpu(device):
-        torch.xpu.current_stream().wait_stream(stream)
+    current_stream(device).wait_stream(stream)
 
 def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
     if device is None or weight.device == device:

From 2d69bdfabeab1ac20141c8de6ebd9d9594980065 Mon Sep 17 00:00:00 2001
From: Rattus
Date: Mon, 27 Oct 2025 18:27:41 +1000
Subject: [PATCH 2/4] ops: sync the offload stream with the consumption of w&b

This sync is necessary because pytorch queues CUDA async frees on the
same stream that created the tensor. In the case of async offload, this
is the offload stream.

Weights and biases can go out of scope in Python, which triggers the
pytorch garbage collector to queue the free operation on the offload
stream, possibly before the compute stream has used the weight. This
causes a use-after-free on the weight data, leading to total corruption
of some workflows.

So sync the offload stream with the compute stream after the weight has
been used, so that the free cannot run until the weight has been
consumed.

cast_bias_weight() is extended in a backwards-compatible way, with the
new behaviour opt-in via a defaulted parameter. This keeps custom node
packs that call cast_bias_weight() working, but disables async offload
for them (as they do not handle the race).

The pattern is now:

    cast_bias_weight(..., offloadable=True) #This might be offloaded
    thing(weight, bias, ...)
    uncast_bias_weight(...)
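
For custom node authors, a minimal sketch of the opt-in pattern built
only from the API this patch adds (MyPatchedLinear is a hypothetical
example class, not part of this change):

    import torch
    import comfy.ops

    class MyPatchedLinear(torch.nn.Linear, comfy.ops.CastWeightBiasOp):
        def forward(self, input):
            # Opt in: the cast may run on an offload stream.
            weight, bias, offload_stream = comfy.ops.cast_bias_weight(
                self, input, offloadable=True)
            x = torch.nn.functional.linear(input, weight, bias)
            # Must come after the last use of weight/bias so the async
            # free waits for the compute stream to consume them.
            comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
            return x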
---
 comfy/ops.py | 121 +++++++++++++++++++++++++++++++++++++--------------
 1 file changed, 89 insertions(+), 32 deletions(-)

diff --git a/comfy/ops.py b/comfy/ops.py
index 93731eedf276..71ca7a2bd11e 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -70,8 +70,12 @@ def scaled_dot_product_attention(q, k, v, *args, **kwargs):
 def cast_to_input(weight, input, non_blocking=False, copy=True):
     return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
 
+
 @torch.compiler.disable()
-def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
+def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
+    # NOTE: offloadable=False is legacy behaviour; if you are a custom node author reading this, please
+    # pass offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
+    # will add async-offload support to your cast and improve performance.
     if input is not None:
         if dtype is None:
             dtype = input.dtype
@@ -80,7 +84,11 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
         if device is None:
             device = input.device
 
-    offload_stream = comfy.model_management.get_offload_stream(device)
+    if offloadable:
+        offload_stream = comfy.model_management.get_offload_stream(device)
+    else:
+        offload_stream = None
+
     if offload_stream is not None:
         wf_context = offload_stream
     else:
@@ -105,7 +113,24 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
                 weight = f(weight)
 
     comfy.model_management.sync_stream(device, offload_stream)
-    return weight, bias
+    if offloadable:
+        return weight, bias, offload_stream
+    else:
+        #Legacy function signature
+        return weight, bias
+
+
+def uncast_bias_weight(s, weight, bias, offload_stream):
+    if offload_stream is None:
+        return
+    if weight is not None:
+        device = weight.device
+    else:
+        if bias is None:
+            return
+        device = bias.device
+    offload_stream.wait_stream(comfy.model_management.current_stream(device))
+
 
 class CastWeightBiasOp:
     comfy_cast_weights = False
@@ -118,8 +143,10 @@ def reset_parameters(self):
             return None
 
         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.linear(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = torch.nn.functional.linear(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def forward(self, *args, **kwargs):
             run_every_op()
@@ -133,8 +160,10 @@ def reset_parameters(self):
             return None
 
         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return self._conv_forward(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = self._conv_forward(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def forward(self, *args, **kwargs):
             run_every_op()
@@ -148,8 +177,10 @@ def reset_parameters(self):
             return None
 
         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return self._conv_forward(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = self._conv_forward(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def forward(self, *args, **kwargs):
             run_every_op()
@@ -172,8 +203,10 @@ def _conv_forward(self, input, weight, bias, *args, **kwargs):
             return super()._conv_forward(input, weight, bias, *args, **kwargs)
 
         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return self._conv_forward(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = self._conv_forward(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def forward(self, *args, **kwargs):
             run_every_op()
@@ -187,8 +220,10 @@ def reset_parameters(self):
             return None
 
         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def forward(self, *args, **kwargs):
             run_every_op()
@@ -203,11 +238,14 @@ def reset_parameters(self):
 
         def forward_comfy_cast_weights(self, input):
             if self.weight is not None:
-                weight, bias = cast_bias_weight(self, input)
+                weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
             else:
                 weight = None
                 bias = None
-            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
+                offload_stream = None
+            x = torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def forward(self, *args, **kwargs):
             run_every_op()
@@ -223,11 +261,15 @@ def reset_parameters(self):
 
         def forward_comfy_cast_weights(self, input):
             if self.weight is not None:
-                weight, bias = cast_bias_weight(self, input)
+                weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
             else:
                 weight = None
-            return comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
-            # return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
+                bias = None
+                offload_stream = None
+            x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
+            # x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def forward(self, *args, **kwargs):
             run_every_op()
@@ -246,10 +288,12 @@ def forward_comfy_cast_weights(self, input, output_size=None):
                 input, output_size, self.stride, self.padding, self.kernel_size,
                 num_spatial_dims, self.dilation)
 
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.conv_transpose2d(
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = torch.nn.functional.conv_transpose2d(
                 input, weight, bias, self.stride, self.padding,
                 output_padding, self.groups, self.dilation)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def forward(self, *args, **kwargs):
             run_every_op()
@@ -268,10 +312,12 @@ def forward_comfy_cast_weights(self, input, output_size=None):
                 input, output_size, self.stride, self.padding, self.kernel_size,
                 num_spatial_dims, self.dilation)
 
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.conv_transpose1d(
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = torch.nn.functional.conv_transpose1d(
                 input, weight, bias, self.stride, self.padding,
                 output_padding, self.groups, self.dilation)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def forward(self, *args, **kwargs):
             run_every_op()
@@ -289,8 +335,11 @@ def forward_comfy_cast_weights(self, input, out_dtype=None):
             output_dtype = out_dtype
             if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
                 out_dtype = None
-            weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
-            return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
+            weight, bias, offload_stream = cast_bias_weight(self, device=input.device, dtype=out_dtype, offloadable=True)
+            x = torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
+
 
         def forward(self, *args, **kwargs):
             run_every_op()
@@ -361,7 +410,7 @@ def fp8_linear(self, input):
     input_dtype = input.dtype
 
     if len(input.shape) == 3:
-        w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
+        w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
         scale_weight = self.scale_weight
         scale_input = self.scale_input
@@ -382,6 +431,8 @@ def fp8_linear(self, input):
         quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype)
         o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
 
+        uncast_bias_weight(self, w, bias, offload_stream)
+
         if tensor_2d:
             return o.reshape(input_shape[0], -1)
         return o.reshape((-1, input_shape[1], self.weight.shape[0]))
@@ -404,8 +455,10 @@ def forward_comfy_cast_weights(self, input):
             except Exception as e:
                 logging.info("Exception during fp8 op: {}".format(e))
 
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.linear(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = torch.nn.functional.linear(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
 def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
     logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
@@ -433,12 +486,14 @@ def forward_comfy_cast_weights(self, input):
             if out is not None:
                 return out
 
-            weight, bias = cast_bias_weight(self, input)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
 
             if weight.numel() < input.numel(): #TODO: optimize
-                return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
+                x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
             else:
-                return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
+                x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def convert_weight(self, weight, inplace=False, **kwargs):
             if inplace:
@@ -577,8 +632,10 @@ def _forward(self, input, weight, bias):
             return torch.nn.functional.linear(input, weight, bias)
 
         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return self._forward(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = self._forward(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def forward(self, input, *args, **kwargs):
             run_every_op()

From 432a5a9d821654b732196a95e99c91e9a2e13f6c Mon Sep 17 00:00:00 2001
From: Rattus
Date: Mon, 27 Oct 2025 18:53:19 +1000
Subject: [PATCH 3/4] controlnet: adopt new cast_bias_weight synchronization
 scheme

This is necessary for safe async weight offloading.
---
 comfy/controlnet.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index f08ff4b36bd8..0b5e30f52a62 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -310,11 +310,13 @@ def __init__(self, in_features: int, out_features: int, bias: bool = True,
         self.bias = None
 
     def forward(self, input):
-        weight, bias = comfy.ops.cast_bias_weight(self, input)
+        weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
         if self.up is not None:
-            return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
+            x = torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
         else:
-            return torch.nn.functional.linear(input, weight, bias)
+            x = torch.nn.functional.linear(input, weight, bias)
+        comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
+        return x
 
 class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
     def __init__(
@@ -350,12 +352,13 @@ def __init__(
 
 
     def forward(self, input):
-        weight, bias = comfy.ops.cast_bias_weight(self, input)
+        weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
         if self.up is not None:
-            return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
+            x = torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
         else:
-            return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
-
+            x = torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
+        comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
+        return x
 
 class ControlLora(ControlNet):
     def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options

From 6ceba8c9f8e90b99b1c753fa51aa779f32f13710 Mon Sep 17 00:00:00 2001
From: Rattus
Date: Mon, 27 Oct 2025 18:54:32 +1000
Subject: [PATCH 4/4] mm: sync the last stream in the queue, not the next

Currently this peeks ahead to sync the next stream in the queue of
streams with the compute stream. This doesn't allow much
parallelization: the end result is that you can only get one weight
load ahead, regardless of how many streams you have.

Rotate the loop logic here to synchronize the end of the queue before
returning the next stream. This allows weights to be loaded ahead of
the compute stream's position.
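
A toy sketch of the two orderings (plain Python, no torch; "sync"
stands in for wait_stream against the compute stream) showing why the
old peek-ahead order capped the pipeline at one load in flight:

    NUM_STREAMS = 4
    ss = list(range(NUM_STREAMS))  # stand-ins for the offload streams

    def old_order(counter):
        s = ss[counter]                   # stream handed to the caster
        counter = (counter + 1) % len(ss)
        sync = ss[counter]                # next stream synced to compute now,
        return s, sync, counter           # so it is already stalled when handed out

    def new_order(counter):
        sync = ss[counter]                # oldest stream waits for compute
        counter = (counter + 1) % len(ss)
        s = ss[counter]                   # caster gets a stream whose last sync
        return s, sync, counter           # was NUM_STREAMS - 1 calls ago

With four streams the new order lets up to three casts queue ahead of
the compute stream's position instead of one.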
---
 comfy/model_management.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index cd8326f5f40e..79c0dfdb4b7e 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -1031,18 +1031,17 @@ def get_offload_stream(device):
 
     if device in STREAMS:
         ss = STREAMS[device]
-        s = ss[stream_counter]
-        stream_counter = (stream_counter + 1) % len(ss)
+        #Sync the oldest stream in the queue with the current stream
         ss[stream_counter].wait_stream(current_stream(device))
+        stream_counter = (stream_counter + 1) % len(ss)
         stream_counters[device] = stream_counter
-        return s
+        return ss[stream_counter]
     elif is_device_cuda(device):
         ss = []
         for k in range(NUM_STREAMS):
             ss.append(torch.cuda.Stream(device=device, priority=0))
         STREAMS[device] = ss
         s = ss[stream_counter]
-        stream_counter = (stream_counter + 1) % len(ss)
         stream_counters[device] = stream_counter
         return s
     elif is_device_xpu(device):
@@ -1051,7 +1050,6 @@ def get_offload_stream(device):
             ss.append(torch.xpu.Stream(device=device, priority=0))
         STREAMS[device] = ss
         s = ss[stream_counter]
-        stream_counter = (stream_counter + 1) % len(ss)
         stream_counters[device] = stream_counter
         return s
     return None