
Conversation

@kaixuanliu (Contributor) commented Oct 17, 2025

When we run a unit test such as pytest -rA tests/pipelines/wan/test_wan_22.py::Wan22PipelineFastTests::test_save_load_float16, we found that the pipeline runs with the fp16 dtype throughout, but after save and reload, some parts of the text encoder in pipe_loaded use fp32, even though we set torch_dtype to fp16 explicitly. Deeper investigation found that the root cause is here: L783. In this PR we adjust the test case to manually add the component = component.to(torch_device).half() operation so that it aligns exactly with the behavior in pipe.
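For context, a minimal reproduction of the dtype mismatch might look like the sketch below; the saved-pipeline path is a placeholder, and the attribute path is the one printed later in this discussion:

```python
import torch
from diffusers import WanPipeline

# Hypothetical reproduction sketch: reload a saved Wan pipeline with torch_dtype=torch.float16
# and inspect the T5 feed-forward output projection, which transformers deliberately keeps in
# fp32 via _keep_in_fp32_modules (see the modeling_t5.py#L783 reference above).
pipe_loaded = WanPipeline.from_pretrained("path/to/saved/pipeline", torch_dtype=torch.float16)

print(pipe_loaded.transformer.dtype)  # torch.float16, as requested
print(pipe_loaded.text_encoder.encoder.block[0].layer[1].DenseReluDense.wo.weight.dtype)  # torch.float32
```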

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
@kaixuanliu (Contributor, Author)

@a-r-r-o-w @DN6 pls help review, thx!

@regisss (Contributor) commented Oct 22, 2025

Not sure I understand the issue here. This specific T5 module is kept in fp32 on purpose, so why force an fp16 cast in the test?

@kaixuanliu (Contributor, Author) commented Oct 23, 2025

@regisss Hi, the purpose of this test case is to compare the output of a pipeline using the fp16 dtype (pipe) with the output of a pipeline reloaded from a previously saved copy (pipe_loaded); they should be identical. However, all components of pipe are set to the fp16 dtype in L1424~L1426, while for pipe_loaded some parts are kept in fp32, which does not match the computation in the pipe forward pass exactly.
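For reference, a simplified paraphrase of what the base test does on each side (assuming the pipe, torch_device, and tmpdir objects from the test context; this is not the verbatim test code):

```python
import torch

def build_fp16_pipes(pipe, torch_device, tmpdir):
    # pipe: every component is cast with .half(), so *all* weights end up in fp16
    for component in pipe.components.values():
        if hasattr(component, "half"):
            component.to(torch_device).half()

    # pipe_loaded: saved and reloaded with torch_dtype; modules listed in
    # _keep_in_fp32_modules (such as T5's DenseReluDense.wo) remain in fp32
    pipe.save_pretrained(tmpdir)
    pipe_loaded = pipe.__class__.from_pretrained(tmpdir, torch_dtype=torch.float16)
    return pipe, pipe_loaded
```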

Comment on lines 1449 to 1451
if hasattr(component, "half"):
    # Although all components of pipe_loaded should be float16 now, some submodules still use fp32 (e.g. https://github.com/huggingface/transformers/blob/v4.57.1/src/transformers/models/t5/modeling_t5.py#L783), so we need to do the conversion again manually to align exactly with the dtype we use in pipe
    component = component.to(torch_device).half()
Member:

This doesn't seem right at all. torch_dtype should be able to take care of it. I just ran it on my GPU for SD and it worked fine.

@kaixuanliu (Contributor, Author) commented Oct 28, 2025:

Hi @sayakpaul, I tested on an A100, and when I print pipe_loaded.text_encoder.encoder.block[0].layer[1].DenseReluDense.wo.weight.dtype at L1455, it returns torch.float32, not torch.float16, and the max_diff at L1456 is np.float16(0.0004883). When we apply this PR to align exactly with the behavior in pipe, the max_diff is 0. I think it's better to adjust the test case so that the output comparison between pipe and pipe_loaded is apples to apples. WDYT?

Member:

My point is that torch_dtype in from_pretrained() should be enough for the model to be in fp16. Casting with half() after loading the model with the fp16 torch_dtype seems erroneous to me.

I also ran the test on an A100, and it wasn't a problem. So, I am not sure if this test fix is correct at all.

@kaixuanliu (Contributor, Author):

I printed pipe_loaded.text_encoder.encoder.block[0].layer[1].DenseReluDense.wo.weight.dtype after pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, torch_dtype=torch.float16), and it returns torch.float32; the root cause is at L783, so I manually added .half() to pipe_loaded, although it looks a bit weird... On an A100 the tolerance value is OK, but from a fundamentals perspective the output of a pipeline reloaded from a previously saved one should be exactly the same, i.e. max_diff should be 0, right?

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
@kaixuanliu (Contributor, Author)

@sayakpaul Hi, I adjusted the test code to pass dtype to get_dummy_components instead of adding .half() to every component; do you think it's OK now?
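A minimal, hypothetical sketch of the idea (the real helper builds the actual Wan components; plain nn.Linear modules stand in for them here, and this is not the actual diff):

```python
import torch
import torch.nn as nn

def get_dummy_components(dtype=torch.float32):
    # Each dummy module is created and immediately converted, so the components
    # handed to the pipeline start out entirely in the requested dtype.
    text_encoder = nn.Linear(8, 8).to(dtype)   # placeholder for the dummy T5 text encoder
    transformer = nn.Linear(8, 8).to(dtype)    # placeholder for the dummy WanTransformer3DModel
    return {"text_encoder": text_encoder, "transformer": transformer}

components = get_dummy_components(dtype=torch.float16)
assert all(p.dtype == torch.float16 for m in components.values() for p in m.parameters())
```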

    supports_dduf = False

-   def get_dummy_components(self):
+   def get_dummy_components(self, dtype=torch.float32):
Member:

Why do we need this?

@kaixuanliu (Contributor, Author) commented Oct 30, 2025:

Please refer to L246-L256 (sorry, I only found the Chinese version of this explanation). The torch.Tensor.to method converts all weights, while the torch_dtype parameter of from_pretrained preserves the layers listed in _keep_in_fp32_modules. For Wan models, all components of pipe end up in the fp16 dtype, while that is not the case for pipe_loaded. Here I override the test_save_load_float16 function separately for Wan models.
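An illustrative sketch of that difference, using a tiny T5 checkpoint (the repo id is an assumption; any T5 checkpoint should show the same behavior, since T5's _keep_in_fp32_modules lists the wo projection):

```python
import torch
from transformers import T5EncoderModel

ckpt = "hf-internal-testing/tiny-random-t5"  # assumed tiny checkpoint for illustration

# from_pretrained with torch_dtype: layers listed in _keep_in_fp32_modules are preserved in fp32
model_a = T5EncoderModel.from_pretrained(ckpt, torch_dtype=torch.float16)
print(model_a.encoder.block[0].layer[1].DenseReluDense.wo.weight.dtype)  # torch.float32

# torch.Tensor.to / .half(): every weight is converted, nothing is preserved
model_b = T5EncoderModel.from_pretrained(ckpt).half()
print(model_b.encoder.block[0].layer[1].DenseReluDense.wo.weight.dtype)  # torch.float16
```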

@sayakpaul (Member) left a comment

I am honestly not sure about the changes introduced in this PR. We have gone over multiple comments, and so far I haven't been able to verify for myself the failures this PR tries to solve.

@kaixuanliu kaixuanliu marked this pull request as draft October 30, 2025 08:05
@kaixuanliu kaixuanliu marked this pull request as ready for review October 30, 2025 08:47
        pass

    @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
    def test_save_load_float16(self, expected_max_diff=1e-2):
Member:

I still don't know, then, how the tests are passing on my end.

@kaixuanliu (Contributor, Author):

I think it is related to the input. When I set all the seeds in get_dummy_components to 1, the max_diff on an A100 is np.float16(0.2366), and when I set the seed to 42, the output is all NaN values. After this PR, the max_diff is 0 for every seed.

@kaixuanliu kaixuanliu marked this pull request as draft October 30, 2025 10:10
Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
# Use from_pretrained with a tiny model to ensure proper dtype handling
# This ensures _keep_in_fp32_modules and _skip_layerwise_casting_patterns are respected
transformer = WanTransformer3DModel.from_pretrained(
    "Kaixuanliu/tiny-random-wan-transformer",
@kaixuanliu (Contributor, Author):

Please replace my model repo here. We have to use from_pretrained so that all the submodules' dtypes are loaded correctly.

qk_norm="rms_norm_across_heads",
rope_max_seq_len=32,
transformer_2 = WanTransformer3DModel.from_pretrained(
    "Kaixuanliu/tiny-random-wan-transformer",
@kaixuanliu (Contributor, Author):

Same as above
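For reference, a minimal sketch of the approach these hunks take, assuming the tiny checkpoint above and a dtype argument threaded in from get_dummy_components:

```python
import torch
from diffusers import WanTransformer3DModel

def load_dummy_transformer(repo_id, dtype=torch.float32):
    # Loading via from_pretrained (rather than constructing the module and calling .to(dtype))
    # lets class-level attributes such as _keep_in_fp32_modules and
    # _skip_layerwise_casting_patterns take effect during loading.
    return WanTransformer3DModel.from_pretrained(repo_id, torch_dtype=dtype)

transformer = load_dummy_transformer("Kaixuanliu/tiny-random-wan-transformer", dtype=torch.float16)
transformer_2 = load_dummy_transformer("Kaixuanliu/tiny-random-wan-transformer", dtype=torch.float16)
```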

@kaixuanliu kaixuanliu marked this pull request as ready for review October 30, 2025 11:18
@kaixuanliu (Contributor, Author)

CC @yao-matrix
