Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions comfy/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,11 +310,13 @@ def __init__(self, in_features: int, out_features: int, bias: bool = True,
self.bias = None

def forward(self, input):
weight, bias = comfy.ops.cast_bias_weight(self, input)
weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
if self.up is not None:
return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
x = torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
else:
return torch.nn.functional.linear(input, weight, bias)
x = torch.nn.functional.linear(input, weight, bias)
comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
return x

class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
def __init__(
Expand Down Expand Up @@ -350,12 +352,13 @@ def __init__(


def forward(self, input):
weight, bias = comfy.ops.cast_bias_weight(self, input)
weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
if self.up is not None:
return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
x = torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
else:
return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)

x = torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
return x

class ControlLora(ControlNet):
def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options
Expand Down
28 changes: 15 additions & 13 deletions comfy/model_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -1013,6 +1013,16 @@ def force_channels_last():
NUM_STREAMS = 2
logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))

def current_stream(device):
if device is None:
return None
if is_device_cuda(device):
return torch.cuda.current_stream()
elif is_device_xpu(device):
return torch.xpu.current_stream()
else:
return None

stream_counters = {}
def get_offload_stream(device):
stream_counter = stream_counters.get(device, 0)
Expand All @@ -1021,21 +1031,17 @@ def get_offload_stream(device):

if device in STREAMS:
ss = STREAMS[device]
s = ss[stream_counter]
#Sync the oldest stream in the queue with the current
ss[stream_counter].wait_stream(current_stream(device))
stream_counter = (stream_counter + 1) % len(ss)
if is_device_cuda(device):
ss[stream_counter].wait_stream(torch.cuda.current_stream())
elif is_device_xpu(device):
ss[stream_counter].wait_stream(torch.xpu.current_stream())
stream_counters[device] = stream_counter
return s
return ss[stream_counter]
elif is_device_cuda(device):
ss = []
for k in range(NUM_STREAMS):
ss.append(torch.cuda.Stream(device=device, priority=0))
STREAMS[device] = ss
s = ss[stream_counter]
stream_counter = (stream_counter + 1) % len(ss)
stream_counters[device] = stream_counter
return s
elif is_device_xpu(device):
Expand All @@ -1044,18 +1050,14 @@ def get_offload_stream(device):
ss.append(torch.xpu.Stream(device=device, priority=0))
STREAMS[device] = ss
s = ss[stream_counter]
stream_counter = (stream_counter + 1) % len(ss)
stream_counters[device] = stream_counter
return s
return None

def sync_stream(device, stream):
if stream is None:
if stream is None or current_stream(device) is None:
return
if is_device_cuda(device):
torch.cuda.current_stream().wait_stream(stream)
elif is_device_xpu(device):
torch.xpu.current_stream().wait_stream(stream)
current_stream(device).wait_stream(stream)

def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
if device is None or weight.device == device:
Expand Down
121 changes: 89 additions & 32 deletions comfy/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,12 @@ def scaled_dot_product_attention(q, k, v, *args, **kwargs):
def cast_to_input(weight, input, non_blocking=False, copy=True):
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)


@torch.compiler.disable()
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
# NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
# will add async-offload support to your cast and improve performance.
if input is not None:
if dtype is None:
dtype = input.dtype
Expand All @@ -80,7 +84,11 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
if device is None:
device = input.device

offload_stream = comfy.model_management.get_offload_stream(device)
if offloadable:
offload_stream = comfy.model_management.get_offload_stream(device)
else:
offload_stream = None

if offload_stream is not None:
wf_context = offload_stream
else:
Expand All @@ -105,7 +113,24 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
weight = f(weight)

comfy.model_management.sync_stream(device, offload_stream)
return weight, bias
if offloadable:
return weight, bias, offload_stream
else:
#Legacy function signature
return weight, bias


def uncast_bias_weight(s, weight, bias, offload_stream):
if offload_stream is None:
return
if weight is not None:
device = weight.device
else:
if bias is None:
return
device = bias.device
offload_stream.wait_stream(comfy.model_management.current_stream(device))


class CastWeightBiasOp:
comfy_cast_weights = False
Expand All @@ -118,8 +143,10 @@ def reset_parameters(self):
return None

def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.linear(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x

def forward(self, *args, **kwargs):
run_every_op()
Expand All @@ -133,8 +160,10 @@ def reset_parameters(self):
return None

def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = self._conv_forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x

def forward(self, *args, **kwargs):
run_every_op()
Expand All @@ -148,8 +177,10 @@ def reset_parameters(self):
return None

def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = self._conv_forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x

def forward(self, *args, **kwargs):
run_every_op()
Expand All @@ -172,8 +203,10 @@ def _conv_forward(self, input, weight, bias, *args, **kwargs):
return super()._conv_forward(input, weight, bias, *args, **kwargs)

def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = self._conv_forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x

def forward(self, *args, **kwargs):
run_every_op()
Expand All @@ -187,8 +220,10 @@ def reset_parameters(self):
return None

def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
uncast_bias_weight(self, weight, bias, offload_stream)
return x

def forward(self, *args, **kwargs):
run_every_op()
Expand All @@ -203,11 +238,14 @@ def reset_parameters(self):

def forward_comfy_cast_weights(self, input):
if self.weight is not None:
weight, bias = cast_bias_weight(self, input)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
else:
weight = None
bias = None
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
offload_stream = None
x = torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
uncast_bias_weight(self, weight, bias, offload_stream)
return x

def forward(self, *args, **kwargs):
run_every_op()
Expand All @@ -223,11 +261,15 @@ def reset_parameters(self):

def forward_comfy_cast_weights(self, input):
if self.weight is not None:
weight, bias = cast_bias_weight(self, input)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
else:
weight = None
return comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
# return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
bias = None
offload_stream = None
x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
# x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
uncast_bias_weight(self, weight, bias, offload_stream)
return x

def forward(self, *args, **kwargs):
run_every_op()
Expand All @@ -246,10 +288,12 @@ def forward_comfy_cast_weights(self, input, output_size=None):
input, output_size, self.stride, self.padding, self.kernel_size,
num_spatial_dims, self.dilation)

weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.conv_transpose2d(
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.conv_transpose2d(
input, weight, bias, self.stride, self.padding,
output_padding, self.groups, self.dilation)
uncast_bias_weight(self, weight, bias, offload_stream)
return x

def forward(self, *args, **kwargs):
run_every_op()
Expand All @@ -268,10 +312,12 @@ def forward_comfy_cast_weights(self, input, output_size=None):
input, output_size, self.stride, self.padding, self.kernel_size,
num_spatial_dims, self.dilation)

weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.conv_transpose1d(
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.conv_transpose1d(
input, weight, bias, self.stride, self.padding,
output_padding, self.groups, self.dilation)
uncast_bias_weight(self, weight, bias, offload_stream)
return x

def forward(self, *args, **kwargs):
run_every_op()
Expand All @@ -289,8 +335,11 @@ def forward_comfy_cast_weights(self, input, out_dtype=None):
output_dtype = out_dtype
if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
out_dtype = None
weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
weight, bias, offload_stream = cast_bias_weight(self, device=input.device, dtype=out_dtype, offloadable=True)
x = torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
uncast_bias_weight(self, weight, bias, offload_stream)
return x


def forward(self, *args, **kwargs):
run_every_op()
Expand Down Expand Up @@ -361,7 +410,7 @@ def fp8_linear(self, input):
input_dtype = input.dtype

if len(input.shape) == 3:
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)

scale_weight = self.scale_weight
scale_input = self.scale_input
Expand All @@ -382,6 +431,8 @@ def fp8_linear(self, input):
quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype)
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)

uncast_bias_weight(self, w, bias, offload_stream)

if tensor_2d:
return o.reshape(input_shape[0], -1)
return o.reshape((-1, input_shape[1], self.weight.shape[0]))
Expand All @@ -404,8 +455,10 @@ def forward_comfy_cast_weights(self, input):
except Exception as e:
logging.info("Exception during fp8 op: {}".format(e))

weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.linear(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x

def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
Expand Down Expand Up @@ -433,12 +486,14 @@ def forward_comfy_cast_weights(self, input):
if out is not None:
return out

weight, bias = cast_bias_weight(self, input)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)

if weight.numel() < input.numel(): #TODO: optimize
return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
else:
return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x

def convert_weight(self, weight, inplace=False, **kwargs):
if inplace:
Expand Down Expand Up @@ -577,8 +632,10 @@ def _forward(self, input, weight, bias):
return torch.nn.functional.linear(input, weight, bias)

def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._forward(input, weight, bias)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = self._forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x

def forward(self, input, *args, **kwargs):
run_every_op()
Expand Down