Skip to content

Commit 7d256f8

Browse files
yashk2810mattjj
authored andcommitted
Make MetaTy work with HiJAX. Before this HiTypes (with values) were being held in Traced without converting it to MetaTys.
This change makes MetaTys work with HiTypes. Co-authored-by: Matthew Johnson <mattjj@google.com> PiperOrigin-RevId: 826153440
1 parent d2ce04b commit 7d256f8

File tree

2 files changed

+30
-33
lines changed

2 files changed

+30
-33
lines changed

jax/_src/pjit.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from collections import defaultdict
1818
from collections.abc import Callable, Sequence, Iterable
19-
import dataclasses
19+
from dataclasses import dataclass, replace
2020
from functools import partial
2121
import inspect
2222
import logging
@@ -153,7 +153,7 @@ def _python_pjit_helper(fun: Callable, jit_info: PjitInfo, *args, **kwargs):
153153
except stages.DeviceAssignmentMismatchError as e:
154154
fails, = e.args
155155
fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun)))
156-
arg_types = map(partial(convert_to_lower_type, False), args_flat)
156+
arg_types = map(convert_to_metaty, args_flat)
157157
msg = stages._device_assignment_mismatch_error(
158158
fun_name, fails, arg_types, 'jit', p.arg_names)
159159
raise ValueError(msg) from None
@@ -304,8 +304,7 @@ def cache_miss(*args, **kwargs):
304304
@api_boundary
305305
def jit_trace(jit_func, *args, **kwargs) -> stages.Traced:
306306
p, args_flat = _infer_params(jit_func._fun, jit_func._jit_info, args, kwargs)
307-
arg_types = map(partial(convert_to_lower_type, p.params['jaxpr'].is_high),
308-
args_flat)
307+
arg_types = map(convert_to_metaty, args_flat)
309308
return stages.Traced(arg_types, p.params, p.in_tree, p.out_tree, p.consts)
310309

311310
@api_boundary
@@ -1253,7 +1252,7 @@ def _qdd_cache_update(fun, in_type, i, consts, aval_qdds):
12531252
if aval_qdd.has_qdd])
12541253

12551254

1256-
@dataclasses.dataclass(frozen=True)
1255+
@dataclass(frozen=True)
12571256
class IgnoreKey:
12581257
val: Any
12591258
def __hash__(self):
@@ -1488,7 +1487,7 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding]
14881487
if isinstance(pjit_in_s, UnspecifiedValue):
14891488
resolved_in_shardings.append(finalize_arg_sharding(arg_s, committed))
14901489
else:
1491-
if (arg.is_nd_array and not pjit_in_s.is_fully_replicated and # type: ignore[union-attr]
1490+
if (arg.is_np_array and not pjit_in_s.is_fully_replicated and # type: ignore[union-attr]
14921491
xb.process_count() > 1):
14931492
raise ValueError(
14941493
'Passing non-trivial shardings for numpy '
@@ -1542,13 +1541,15 @@ def _resolve_and_lower(
15421541
_pgle_profiler_dict = weakref.WeakKeyDictionary() # type: ignore
15431542

15441543

1545-
@dataclasses.dataclass(frozen=True)
1546-
class LowerType:
1544+
@dataclass(frozen=True)
1545+
class MetaTy:
1546+
aval: Any
15471547
sharding: Any
15481548
format: Any
15491549
committed: bool
1550-
aval: Any
1551-
is_nd_array: bool
1550+
is_np_array: bool
1551+
1552+
replace = replace # type: ignore
15521553

15531554
@property
15541555
def shape(self):
@@ -1559,15 +1560,13 @@ def ndim(self):
15591560
return self.aval.ndim
15601561

15611562

1562-
def convert_to_lower_type(is_high, arg):
1563-
if is_high:
1564-
return arg
1563+
def convert_to_metaty(arg):
15651564
arg_sharding = getattr(arg, 'sharding', None)
15661565
arg_format = getattr(arg, 'format', None)
15671566
arg_committed = getattr(arg, '_committed', True)
15681567
aval = core.shaped_abstractify(arg)
1569-
is_nd_array = isinstance(arg, np.ndarray)
1570-
return LowerType(arg_sharding, arg_format, arg_committed, aval, is_nd_array)
1568+
is_np_array = isinstance(arg, np.ndarray)
1569+
return MetaTy(aval, arg_sharding, arg_format, arg_committed, is_np_array)
15711570

15721571

15731572
def _pjit_call_impl_python(
@@ -1597,7 +1596,7 @@ def _pjit_call_impl_python(
15971596
compiler_options_kvs = compiler_options_kvs + tuple(pgle_compile_options.items())
15981597
# Passing mutable PGLE profile here since it should be extracted by JAXPR to
15991598
# initialize the fdo_profile compile option.
1600-
arg_types = map(partial(convert_to_lower_type, jaxpr.is_high), args)
1599+
arg_types = map(convert_to_metaty, args)
16011600
computation = _resolve_and_lower(
16021601
arg_types, jaxpr=jaxpr, in_shardings=in_shardings,
16031602
out_shardings=out_shardings, in_layouts=in_layouts,

jax/_src/stages.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
import enum
3535
from collections.abc import Sequence
3636
from dataclasses import dataclass
37-
from functools import partial
3837
import itertools as it
3938
from typing import Any, NamedTuple, Protocol, Union, runtime_checkable
4039

@@ -404,10 +403,10 @@ class Traced(Stage):
404403
traced representation with the remaining information needed to later
405404
lower, compile, and execute it.
406405
"""
407-
__slots__ = ['_args_flat', '_params', '_in_tree', 'out_tree', '_consts']
406+
__slots__ = ['_meta_tys_flat', '_params', '_in_tree', 'out_tree', '_consts']
408407

409-
def __init__(self, args_flat, params, in_tree, out_tree, consts):
410-
self._args_flat = args_flat
408+
def __init__(self, meta_tys_flat, params, in_tree, out_tree, consts):
409+
self._meta_tys_flat = meta_tys_flat
411410
self._params = params
412411
self._in_tree = in_tree
413412
self.out_tree = out_tree
@@ -425,13 +424,12 @@ def out_avals(self):
425424

426425
def fall(self):
427426
if not self.jaxpr.is_high:
428-
return Fallen(self._args_flat, self._params, self._in_tree, self.out_tree,
429-
(self._in_tree, self.jaxpr.in_avals),
427+
return Fallen(self._meta_tys_flat, self._params, self._in_tree,
428+
self.out_tree, (self._in_tree, self.jaxpr.in_avals),
430429
(self.out_tree, self.jaxpr.out_avals),
431430
self._consts)
432431

433432
# TODO(mattjj): when pmap is deleted, merge with pjit.py BUILD rule
434-
from jax._src.pjit import convert_to_lower_type # type: ignore
435433
from jax._src.interpreters import partial_eval as pe # type:ignore
436434
hi_jaxpr = self.jaxpr
437435
_, closed_over_himutables = pe.convert_const_himutables(hi_jaxpr)
@@ -440,11 +438,11 @@ def fall(self):
440438
in_tree = lojax_pytree(hi_jaxpr.in_aval_qdds, self._in_tree)
441439
out_tree = lojax_pytree(hi_jaxpr.out_avals, self.out_tree)
442440
params = dict(lojax_expand_params(hi_jaxpr, self._params), jaxpr=lo_jaxpr)
443-
lo_args = [lo_val for aval, x in zip(hi_jaxpr.in_aval_qdds, self._args_flat)
444-
for lo_val in (aval.read_loval(x) if aval.has_qdd
445-
else aval.lower_val(x))]
446-
lo_arg_types = map(partial(convert_to_lower_type, False), lo_args)
447-
return Fallen(lo_arg_types, params, in_tree, out_tree,
441+
lo_meta_tys = [mty.replace(aval=lo_ty)
442+
for mty, aq in zip(self._meta_tys_flat, hi_jaxpr.in_aval_qdds)
443+
for lo_ty in (mty.aval.lo_ty_qdd(aq.qdd)
444+
if mty.aval.has_qdd else mty.aval.lo_ty())]
445+
return Fallen(lo_meta_tys, params, in_tree, out_tree,
448446
(self._in_tree, hi_jaxpr.final_aval_qdds),
449447
(self.out_tree, hi_jaxpr.out_avals),
450448
self._consts)
@@ -472,12 +470,12 @@ def lojax_pytree(hi_avals, tree):
472470

473471
class Fallen(Stage):
474472
"""True leader of the Decepticons."""
475-
__slots__ = ['_args_flat', '_params', '_in_tree', 'out_tree',
473+
__slots__ = ['_meta_tys_flat', '_params', '_in_tree', 'out_tree',
476474
'_consts', '_in_types', '_out_types']
477475

478-
def __init__(self, args_flat, params, in_tree, out_tree, in_types, out_types,
476+
def __init__(self, meta_tys_flat, params, in_tree, out_tree, in_types, out_types,
479477
consts):
480-
self._args_flat = args_flat
478+
self._meta_tys_flat = meta_tys_flat
481479
self._params = params
482480
self._in_tree = in_tree
483481
self.out_tree = out_tree
@@ -503,12 +501,12 @@ def lower(self, *, lowering_platforms: tuple[str, ...] | None = None,
503501
try:
504502
from jax._src.pjit import _resolve_and_lower # type: ignore
505503
lowering = _resolve_and_lower(
506-
self._args_flat, **self._params, lowering_platforms=lowering_platforms,
504+
self._meta_tys_flat, **self._params, lowering_platforms=lowering_platforms,
507505
lowering_parameters=_private_parameters, pgle_profiler=None)
508506
except DeviceAssignmentMismatchError as e:
509507
fails, = e.args
510508
msg = _device_assignment_mismatch_error(
511-
self._params['name'], fails, self._args_flat, 'jit',
509+
self._params['name'], fails, self._meta_tys_flat, 'jit',
512510
self.jaxpr.debug_info.safe_arg_names(len(self.jaxpr.in_avals)))
513511
raise ValueError(msg) from None
514512
return Lowered(lowering, self.args_info, self.out_tree,

0 commit comments

Comments
 (0)