
JVP treatment of mask as a leaf vs pytree #29

@rdyro

Description

The mask in jvp might not match the pytree structure of grad_y.

Unfortunately, I don't have an open-source repro at the moment, but these lines assume that the mask and the gradients share the same pytree structure. However, the mask is a single leaf, while the gradients are a pytree.

Is that correct?
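The lines in question are not reproduced here, so this is only a hypothetical sketch of the mismatch. It assumes the JVP rule combines grad_y and mask with jax.tree_util.tree_map. The names apply_mask_naive and apply_mask_broadcast are illustrative and do not come from the project. The naive version fails when grad_y is a dict but mask is a single array. The broadcast version first copies a leaf mask into grad_y's structure.

```python
import jax
import jax.numpy as jnp

# Hypothetical inputs: grad_y is a pytree (dict of arrays), mask is a single leaf.
grad_y = {"a": jnp.ones((4,)), "b": jnp.full((4,), 2.0)}
mask = jnp.array([True, False, True, False])


def apply_mask_naive(grad_y, mask):
    # Assumes mask has exactly the same pytree structure as grad_y.
    # tree_map flattens `mask` up to grad_y's treedef, so a leaf mask
    # paired with a dict of gradients raises an error.
    return jax.tree_util.tree_map(lambda g, m: jnp.where(m, g, 0.0), grad_y, mask)


def apply_mask_broadcast(grad_y, mask):
    # If mask is a single leaf, replicate it once per leaf of grad_y so the
    # two trees match; otherwise require the structures to match already.
    if jax.tree_util.treedef_is_leaf(jax.tree_util.tree_structure(mask)):
        mask = jax.tree_util.tree_map(lambda _: mask, grad_y)
    return jax.tree_util.tree_map(lambda g, m: jnp.where(m, g, 0.0), grad_y, mask)


masked = apply_mask_broadcast(grad_y, mask)
print(masked["a"], masked["b"])
```

Broadcasting the leaf mask keeps the existing behavior when a full mask pytree is passed. It also accepts the single-leaf case instead of failing inside tree_map.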

I believe you recently addressed a similar issue here:
5150e33#diff-21e634aa62155f577c8e87e1b851189b4791db79bdb2593cc957ca86e8cde5ccL328
