@@ -152,6 +152,7 @@ def _pallas_call_to_lojax(
152152 out_avals : tuple [jax_core .AbstractValue , ...],
153153 backend : Backend | None ,
154154 metadata : FrozenDict [str , str ] | None ,
155+ name : str | None ,
155156):
156157 if any (jax_core .get_aval (x ).has_qdd for x in hi_args ):
157158 raise NotImplementedError ("pallas_call does not support QDD for inputs" )
@@ -221,6 +222,7 @@ def _pallas_call_to_lojax(
221222 interpret = interpret ,
222223 input_output_aliases = tuple (new_input_output_aliases ),
223224 out_avals = tuple (lo_out_avals ),
225+ name = name ,
224226 )
225227 return pe .raise_lo_outs (out_avals , lo_outs )
226228pallas_call_p .to_lojax = _pallas_call_to_lojax # type: ignore
@@ -241,6 +243,7 @@ def _pallas_call_jvp_rule(
241243 out_avals : tuple [jax_core .AbstractValue , ...],
242244 backend : Backend | None ,
243245 metadata : FrozenDict [str , str ] | None ,
246+ name : str | None ,
244247):
245248 debug_info = jaxpr .debug_info
246249 if grid_mapping .num_dynamic_grid_bounds :
@@ -308,6 +311,7 @@ def _pallas_call_jvp_rule(
308311 out_avals = (* out_avals , * out_avals ),
309312 backend = backend ,
310313 metadata = metadata ,
314+ name = name ,
311315 )
312316 out_primals , out_tangents = split_list (out_flat , [len (out_flat ) // 2 ])
313317 return out_primals , out_tangents
@@ -457,6 +461,7 @@ def _batch_with_explicit_loop(
457461 out_avals : tuple [jax_core .AbstractValue , ...],
458462 backend : Backend | None ,
459463 metadata : FrozenDict [str , str ] | None ,
464+ name : str | None ,
460465):
461466 """Batch the pallas_call by calling it in loop over the batch size.
462467
@@ -526,6 +531,7 @@ def body(batch_index: jax.Array, state: list[jax.Array]) -> list[jax.Array]:
526531 out_avals = out_avals ,
527532 backend = backend ,
528533 metadata = metadata ,
534+ name = name ,
529535 )
530536 for i , batch_out_array in enumerate (batch_out ):
531537 state [i ] = jax .lax .dynamic_update_index_in_dim (
@@ -557,6 +563,7 @@ def _pallas_call_batching_rule(
557563 out_avals : tuple [jax_core .AbstractValue , ...],
558564 backend : Backend | None ,
559565 metadata : FrozenDict [str , str ] | None = None ,
566+ name : str | None = None ,
560567):
561568 if mesh is not None :
562569 raise NotImplementedError (
@@ -596,6 +603,7 @@ def get_size(i, x, d):
596603 out_avals = out_avals ,
597604 backend = backend ,
598605 metadata = metadata ,
606+ name = name ,
599607 )
600608 return [jnp .expand_dims (x , 0 ) for x in out ], (0 ,) * len (out )
601609
@@ -631,6 +639,7 @@ def get_size(i, x, d):
631639 out_avals = out_avals ,
632640 backend = backend ,
633641 metadata = metadata ,
642+ name = name ,
634643 )
635644 else :
636645 pass # No dynamic grid dimensions
@@ -667,6 +676,7 @@ def get_size(i, x, d):
667676 out_avals = out_avals ,
668677 backend = backend ,
669678 metadata = metadata ,
679+ name = name ,
670680 )
671681
672682 if not dims :
@@ -1048,6 +1058,7 @@ def index_rewrite_kernel(*indexer_args):
10481058 out_avals = batched_out_avals ,
10491059 backend = backend ,
10501060 metadata = metadata ,
1061+ name = name ,
10511062 )
10521063 return out , (0 ,) * len (out )
10531064
@@ -1523,6 +1534,7 @@ def _pallas_call_state_discharge_rule(
15231534 out_avals : tuple [jax_core .AbstractValue , ...],
15241535 backend : Backend | None ,
15251536 metadata : FrozenDict [str , str ] | None ,
1537+ name : str | None ,
15261538):
15271539 del avals_out
15281540 assert all (isinstance (v .aval , state .AbstractRef ) for v in jaxpr .constvars )
@@ -1629,6 +1641,7 @@ def _rewritten_body(*args):
16291641 out_avals = new_out_avals ,
16301642 backend = backend ,
16311643 metadata = metadata ,
1644+ name = name ,
16321645 )
16331646 refs_out , rest = split_list (out_flat , [num_refs ])
16341647 updated_vals_in = refs_out + [None ] * len (rest_in_avals )
@@ -1909,6 +1922,7 @@ def wrapped(*args):
19091922 cost_estimate = cost_estimate ,
19101923 backend = backend ,
19111924 metadata = FrozenDict (metadata ) if metadata is not None else None ,
1925+ name = name ,
19121926 )
19131927 out = tree_util .tree_unflatten (out_tree , out_flat )
19141928 return out
0 commit comments