Open
Description
The mask in jvp might not match the pytree structure of grad_y.
Unfortunately, I don't have an open-source repro at the moment. The lines below appear to assume that both arguments are pytrees with the same structure, but in my case the mask is a single leaf while the gradients are a pytree.
Line 183 in d3bf210:
new_masks = broadcast_mask_to_jacobian(out_mask, grad_y)
Line 248 in d3bf210:
result_mask = broadcast_mask_to_jacobian(result_mask, grad_y)
Is that correct?
I believe you recently addressed a similar issue here:
5150e33#diff-21e634aa62155f577c8e87e1b851189b4791db79bdb2593cc957ca86e8cde5ccL328
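In place of a real repro, here is a minimal sketch of the kind of mismatch I mean. It is not taken from the library: the values of `out_mask` and `grad_y` are made-up stand-ins, and `expand_leaf_mask` is a hypothetical helper, not an existing function. I have not checked how `broadcast_mask_to_jacobian` treats this case internally, so this only illustrates the structural difference and one possible way to reconcile it.

```python
import jax
import jax.numpy as jnp

# Made-up stand-ins: the mask is one array (a single leaf),
# while the gradient is a dict of arrays (a pytree with two leaves).
out_mask = jnp.array([True, False, True])
grad_y = {"a": jnp.ones((3, 2)), "b": jnp.ones((3, 4))}

mask_def = jax.tree_util.tree_structure(out_mask)
grad_def = jax.tree_util.tree_structure(grad_y)

# A lone leaf and a two-leaf dict do not share a tree structure.
structures_match = mask_def == grad_def


def expand_leaf_mask(mask, tree):
    """Hypothetical helper: reuse the same mask for every leaf of `tree`."""
    if jax.tree_util.tree_structure(mask) == jax.tree_util.tree_structure(tree):
        return mask  # structures already agree, nothing to do
    # Build a pytree shaped like `tree` whose leaves are all `mask`.
    return jax.tree_util.tree_map(lambda _: mask, tree)


expanded = expand_leaf_mask(out_mask, grad_y)
```

After this step `expanded` has the same tree structure as `grad_y`, so any code that pairs mask leaves with gradient leaves one-to-one sees matching structures. Whether copying the mask to every leaf is the right semantics here is an assumption on my side.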