From 12ba0ddb2f0ed2fc00f5a8c75028517c43ddf05d Mon Sep 17 00:00:00 2001
From: Asandei Stefan-Alexandru
Date: Wed, 29 Oct 2025 15:06:11 +0200
Subject: [PATCH 1/2] fix FakeTextEncoder bug when loading

---
 .../diffusion_models/chroma/chroma_model.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/extensions_built_in/diffusion_models/chroma/chroma_model.py b/extensions_built_in/diffusion_models/chroma/chroma_model.py
index 236d9508b..aa0fc0f97 100644
--- a/extensions_built_in/diffusion_models/chroma/chroma_model.py
+++ b/extensions_built_in/diffusion_models/chroma/chroma_model.py
@@ -420,9 +420,13 @@ def get_model_has_grad(self):
         return self.model.final_layer.linear.weight.requires_grad
 
     def get_te_has_grad(self):
-        # return from a weight if it has grad
-        return self.text_encoder[1].encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad
-
+        from toolkit.unloader import FakeTextEncoder
+
+        te = self.text_encoder[1]
+        if isinstance(te, FakeTextEncoder):
+            return False
+        return te.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad
+
     def save_model(self, output_path, meta, save_dtype):
         if not output_path.endswith(".safetensors"):
             output_path = output_path + ".safetensors"

From ef8e5790ce90574acbdd5d865e72005940e845b0 Mon Sep 17 00:00:00 2001
From: Asandei Stefan-Alexandru
Date: Wed, 29 Oct 2025 15:07:08 +0200
Subject: [PATCH 2/2] enable Chroma loading within 32gb RAM & 24gb VRAM

---
 .../diffusion_models/chroma/chroma_model.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/extensions_built_in/diffusion_models/chroma/chroma_model.py b/extensions_built_in/diffusion_models/chroma/chroma_model.py
index aa0fc0f97..83ec1a669 100644
--- a/extensions_built_in/diffusion_models/chroma/chroma_model.py
+++ b/extensions_built_in/diffusion_models/chroma/chroma_model.py
@@ -167,8 +167,13 @@ def load_model(self):
 
         chroma_params.depth = double_blocks
         chroma_params.depth_single_blocks = single_blocks
+
+        # load Chroma into RAM in bfloat16, go back to fp32 afterwards
+        def_dtype = torch.get_default_dtype()
+        torch.set_default_dtype(torch.bfloat16)
         transformer = Chroma(chroma_params)
-
+        torch.set_default_dtype(def_dtype)
+
         # add dtype, not sure why it doesnt have it
         transformer.dtype = dtype
         # load the state dict into the model