Skip to content

Commit 748e10c

Browse files
sharadmvGoogle-ML-Automation
authored andcommitted
[NFC] Add name argument to pallas_call primitive.
This change introduces an optional `name` string argument to the `pallas_call` primitive and its associated lowering and transformation rules. The `name` is currently passed through but not yet used by the backend-specific lowering implementations for TPU, GPU, and Triton. A follow up CL will use it to extend the name stack on TPU PiperOrigin-RevId: 825262144
1 parent 7d256f8 commit 748e10c

File tree

7 files changed

+25
-5
lines changed

7 files changed

+25
-5
lines changed

jax/_src/pallas/core.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1391,6 +1391,7 @@ def core_map(
13911391
interpret: Whether to run the function in interpret mode.
13921392
debug: Whether or not to out helpful debugging information.
13931393
cost_estimate: The cost estimate of the function.
1394+
name: The (optional) name of the kernel.
13941395
metadata: Optional dictionary of information about the kernel that will be
13951396
serialized as JSON in the HLO. Can be used for debugging and analysis.
13961397
"""

jax/_src/pallas/hlo_interpreter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,8 +352,9 @@ def pallas_call_hlo_interpret(
352352
cost_estimate: CostEstimate,
353353
out_avals: tuple[jax_core.AbstractValue, ...],
354354
metadata: frozen_dict.FrozenDict[str, str] | None,
355+
name: str | None,
355356
):
356-
del mesh, compiler_params, cost_estimate, out_avals, metadata
357+
del mesh, compiler_params, cost_estimate, out_avals, metadata, name
357358
debug_info = jaxpr.debug_info
358359
# If we're in interpret mode, we *scan* over the grid and eval the
359360
# discharged jaxpr.

jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1959,8 +1959,9 @@ def interpret_pallas_call(
19591959
out_avals: tuple[jax_core.AbstractValue, ...],
19601960
interpret_params: InterpretParams,
19611961
metadata: frozen_dict.FrozenDict[str, str] | None,
1962+
name: str | None,
19621963
):
1963-
del debug, cost_estimate, out_avals
1964+
del debug, cost_estimate, out_avals, name
19641965
del metadata # TODO(sharadmv): Add metadata to HLO.
19651966

19661967
if isinstance(mesh, mosaic_core.TensorCoreMesh):

jax/_src/pallas/mosaic/pallas_call_registration.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -372,9 +372,10 @@ def pallas_call_tpu_lowering_rule(
372372
cost_estimate: pallas_core.CostEstimate | None,
373373
out_avals: tuple[jax_core.AbstractValue, ...],
374374
metadata: frozen_dict.FrozenDict[str, str] | None,
375+
name: str | None,
375376
):
376377
"""Lowers a pallas_call to a Mosaic TPU custom call."""
377-
del interpret # Unused.
378+
del interpret, name # Unused.
378379

379380
debug_info = jaxpr.debug_info
380381
if debug:

jax/_src/pallas/mosaic_gpu/pallas_call_registration.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,9 @@ def pallas_call_lowering(
4949
cost_estimate: pallas_core.CostEstimate | None,
5050
out_avals: tuple[jax_core.AbstractValue, ...],
5151
metadata: frozen_dict.FrozenDict[str, str] | None,
52+
name: str | None,
5253
):
53-
del metadata # TODO(sharadmv): Add metadata to HLO.
54+
del metadata, name # TODO(sharadmv): Add metadata to HLO.
5455
debug_info = jaxpr.debug_info
5556
del interpret, out_avals
5657
if grid_mapping.num_dynamic_grid_bounds:

jax/_src/pallas/pallas_call.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ def _pallas_call_to_lojax(
152152
out_avals: tuple[jax_core.AbstractValue, ...],
153153
backend: Backend | None,
154154
metadata: FrozenDict[str, str] | None,
155+
name: str | None,
155156
):
156157
if any(jax_core.get_aval(x).has_qdd for x in hi_args):
157158
raise NotImplementedError("pallas_call does not support QDD for inputs")
@@ -221,6 +222,7 @@ def _pallas_call_to_lojax(
221222
interpret=interpret,
222223
input_output_aliases=tuple(new_input_output_aliases),
223224
out_avals=tuple(lo_out_avals),
225+
name=name,
224226
)
225227
return pe.raise_lo_outs(out_avals, lo_outs)
226228
pallas_call_p.to_lojax = _pallas_call_to_lojax # type: ignore
@@ -241,6 +243,7 @@ def _pallas_call_jvp_rule(
241243
out_avals: tuple[jax_core.AbstractValue, ...],
242244
backend: Backend | None,
243245
metadata: FrozenDict[str, str] | None,
246+
name: str | None,
244247
):
245248
debug_info = jaxpr.debug_info
246249
if grid_mapping.num_dynamic_grid_bounds:
@@ -308,6 +311,7 @@ def _pallas_call_jvp_rule(
308311
out_avals=(*out_avals, *out_avals),
309312
backend=backend,
310313
metadata=metadata,
314+
name=name,
311315
)
312316
out_primals, out_tangents = split_list(out_flat, [len(out_flat) // 2])
313317
return out_primals, out_tangents
@@ -457,6 +461,7 @@ def _batch_with_explicit_loop(
457461
out_avals: tuple[jax_core.AbstractValue, ...],
458462
backend: Backend | None,
459463
metadata: FrozenDict[str, str] | None,
464+
name: str | None,
460465
):
461466
"""Batch the pallas_call by calling it in loop over the batch size.
462467
@@ -526,6 +531,7 @@ def body(batch_index: jax.Array, state: list[jax.Array]) -> list[jax.Array]:
526531
out_avals=out_avals,
527532
backend=backend,
528533
metadata=metadata,
534+
name=name,
529535
)
530536
for i, batch_out_array in enumerate(batch_out):
531537
state[i] = jax.lax.dynamic_update_index_in_dim(
@@ -557,6 +563,7 @@ def _pallas_call_batching_rule(
557563
out_avals: tuple[jax_core.AbstractValue, ...],
558564
backend: Backend | None,
559565
metadata: FrozenDict[str, str] | None = None,
566+
name: str | None = None,
560567
):
561568
if mesh is not None:
562569
raise NotImplementedError(
@@ -596,6 +603,7 @@ def get_size(i, x, d):
596603
out_avals=out_avals,
597604
backend=backend,
598605
metadata=metadata,
606+
name=name,
599607
)
600608
return [jnp.expand_dims(x, 0) for x in out], (0,) * len(out)
601609

@@ -631,6 +639,7 @@ def get_size(i, x, d):
631639
out_avals=out_avals,
632640
backend=backend,
633641
metadata=metadata,
642+
name=name,
634643
)
635644
else:
636645
pass # No dynamic grid dimensions
@@ -667,6 +676,7 @@ def get_size(i, x, d):
667676
out_avals=out_avals,
668677
backend=backend,
669678
metadata=metadata,
679+
name=name,
670680
)
671681

672682
if not dims:
@@ -1048,6 +1058,7 @@ def index_rewrite_kernel(*indexer_args):
10481058
out_avals=batched_out_avals,
10491059
backend=backend,
10501060
metadata=metadata,
1061+
name=name,
10511062
)
10521063
return out, (0,) * len(out)
10531064

@@ -1523,6 +1534,7 @@ def _pallas_call_state_discharge_rule(
15231534
out_avals: tuple[jax_core.AbstractValue, ...],
15241535
backend: Backend | None,
15251536
metadata: FrozenDict[str, str] | None,
1537+
name: str | None,
15261538
):
15271539
del avals_out
15281540
assert all(isinstance(v.aval, state.AbstractRef) for v in jaxpr.constvars)
@@ -1629,6 +1641,7 @@ def _rewritten_body(*args):
16291641
out_avals=new_out_avals,
16301642
backend=backend,
16311643
metadata=metadata,
1644+
name=name,
16321645
)
16331646
refs_out, rest = split_list(out_flat, [num_refs])
16341647
updated_vals_in = refs_out + [None] * len(rest_in_avals)
@@ -1909,6 +1922,7 @@ def wrapped(*args):
19091922
cost_estimate=cost_estimate,
19101923
backend=backend,
19111924
metadata=FrozenDict(metadata) if metadata is not None else None,
1925+
name=name,
19121926
)
19131927
out = tree_util.tree_unflatten(out_tree, out_flat)
19141928
return out

jax/_src/pallas/triton/pallas_call_registration.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,9 @@ def pallas_call_lowering(
5858
cost_estimate: pallas_core.CostEstimate | None,
5959
out_avals: tuple[jax_core.AbstractValue, ...],
6060
metadata: frozen_dict.FrozenDict[str, str] | None,
61+
name: str | None,
6162
):
62-
del interpret, out_avals, cost_estimate
63+
del interpret, out_avals, cost_estimate, name
6364
debug_info = jaxpr.debug_info
6465
if grid_mapping.num_dynamic_grid_bounds:
6566
raise NotImplementedError(

0 commit comments

Comments
 (0)