
Using MTP + packing + full recompute causes exception in save_for_backward #3643

@arvyanh

Description


Describe the bug
In multi_token_prediction.py, packed_seq_params is passed in for the roll operation and is then forwarded into the _checkpointed_forward function, which causes the following exception:

/usr/local/lib/python3.12/dist-packages/Megatron-LM/megatron/core/tensor_parallel/random.py", line 480, in checkpoint
    return CheckpointFunction.apply(function, distribute_saved_activations, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py", line 576, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: save_for_backward can only save variables, but argument 10 is of type PackedSeqParams

Steps/Code to reproduce bug

Run training with MTP (multi-token prediction), sequence packing (packed_seq_params), and full activation recompute enabled, as in the title; the checkpointed forward then raises the TypeError above.

Expected behavior

PackedSeqParams should not be passed into save_for_backward(); only tensor arguments should flow through CheckpointFunction.apply.
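Since PackedSeqParams is a plain dataclass rather than a tensor, torch.autograd.Function.save_for_backward rejects it when it arrives as a positional argument to the checkpoint machinery. A common workaround pattern (a sketch of the general idea, not the actual Megatron-LM fix; the stub classes below stand in for the real torch and Megatron types) is to bind the non-tensor argument into the checkpointed function with functools.partial so that only tensors are passed positionally:

```python
from dataclasses import dataclass
from functools import partial

@dataclass
class PackedSeqParams:
    # stand-in for megatron.core.packed_seq_params.PackedSeqParams
    cu_seqlens_q: list
    cu_seqlens_kv: list

class Tensor:
    # minimal stand-in for torch.Tensor
    def __init__(self, value):
        self.value = value

def checkpoint(function, *args):
    # mimics CheckpointFunction.apply: every positional argument is
    # handed to save_for_backward, which can only save variables
    for i, arg in enumerate(args):
        if not isinstance(arg, Tensor):
            raise TypeError(
                f"save_for_backward can only save variables, "
                f"but argument {i} is of type {type(arg).__name__}"
            )
    return function(*args)

def forward(hidden_states, packed_seq_params=None):
    return hidden_states

psp = PackedSeqParams(cu_seqlens_q=[0, 4], cu_seqlens_kv=[0, 4])
x = Tensor([1.0, 2.0])

# buggy call: the dataclass flows into checkpoint() -> TypeError,
# matching the traceback in this issue
try:
    checkpoint(forward, x, psp)
except TypeError as e:
    print(e)

# workaround: bind the non-tensor argument via partial so only
# tensors are passed positionally into the checkpoint machinery
out = checkpoint(partial(forward, packed_seq_params=psp), x)
```

The same shape of fix applies to the real code path: keep packed_seq_params out of the argument list that _checkpointed_forward forwards to CheckpointFunction.apply, and close over it instead.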
