3434import enum
3535from collections .abc import Sequence
3636from dataclasses import dataclass
37- from functools import partial
3837import itertools as it
3938from typing import Any , NamedTuple , Protocol , Union , runtime_checkable
4039
@@ -404,10 +403,10 @@ class Traced(Stage):
404403 traced representation with the remaining information needed to later
405404 lower, compile, and execute it.
406405 """
407- __slots__ = ['_args_flat ' , '_params' , '_in_tree' , 'out_tree' , '_consts' ]
406+ __slots__ = ['_meta_tys_flat ' , '_params' , '_in_tree' , 'out_tree' , '_consts' ]
408407
409- def __init__ (self , args_flat , params , in_tree , out_tree , consts ):
410- self ._args_flat = args_flat
408+ def __init__ (self , meta_tys_flat , params , in_tree , out_tree , consts ):
409+ self ._meta_tys_flat = meta_tys_flat
411410 self ._params = params
412411 self ._in_tree = in_tree
413412 self .out_tree = out_tree
@@ -425,13 +424,12 @@ def out_avals(self):
425424
426425 def fall (self ):
427426 if not self .jaxpr .is_high :
428- return Fallen (self ._args_flat , self ._params , self ._in_tree , self . out_tree ,
429- (self ._in_tree , self .jaxpr .in_avals ),
427+ return Fallen (self ._meta_tys_flat , self ._params , self ._in_tree ,
428+ self . out_tree , (self ._in_tree , self .jaxpr .in_avals ),
430429 (self .out_tree , self .jaxpr .out_avals ),
431430 self ._consts )
432431
433432 # TODO(mattjj): when pmap is deleted, merge with pjit.py BUILD rule
434- from jax ._src .pjit import convert_to_lower_type # type: ignore
435433 from jax ._src .interpreters import partial_eval as pe # type:ignore
436434 hi_jaxpr = self .jaxpr
437435 _ , closed_over_himutables = pe .convert_const_himutables (hi_jaxpr )
@@ -440,11 +438,11 @@ def fall(self):
440438 in_tree = lojax_pytree (hi_jaxpr .in_aval_qdds , self ._in_tree )
441439 out_tree = lojax_pytree (hi_jaxpr .out_avals , self .out_tree )
442440 params = dict (lojax_expand_params (hi_jaxpr , self ._params ), jaxpr = lo_jaxpr )
443- lo_args = [lo_val for aval , x in zip ( hi_jaxpr . in_aval_qdds , self . _args_flat )
444- for lo_val in ( aval . read_loval ( x ) if aval . has_qdd
445- else aval .lower_val ( x ))]
446- lo_arg_types = map ( partial ( convert_to_lower_type , False ), lo_args )
447- return Fallen (lo_arg_types , params , in_tree , out_tree ,
441+ lo_meta_tys = [mty . replace ( aval = lo_ty )
442+ for mty , aq in zip ( self . _meta_tys_flat , hi_jaxpr . in_aval_qdds )
443+ for lo_ty in ( mty . aval .lo_ty_qdd ( aq . qdd )
444+ if mty . aval . has_qdd else mty . aval . lo_ty ())]
445+ return Fallen (lo_meta_tys , params , in_tree , out_tree ,
448446 (self ._in_tree , hi_jaxpr .final_aval_qdds ),
449447 (self .out_tree , hi_jaxpr .out_avals ),
450448 self ._consts )
@@ -472,12 +470,12 @@ def lojax_pytree(hi_avals, tree):
472470
473471class Fallen (Stage ):
474472 """True leader of the Decepticons."""
475- __slots__ = ['_args_flat ' , '_params' , '_in_tree' , 'out_tree' ,
473+ __slots__ = ['_meta_tys_flat ' , '_params' , '_in_tree' , 'out_tree' ,
476474 '_consts' , '_in_types' , '_out_types' ]
477475
478- def __init__ (self , args_flat , params , in_tree , out_tree , in_types , out_types ,
476+ def __init__ (self , meta_tys_flat , params , in_tree , out_tree , in_types , out_types ,
479477 consts ):
480- self ._args_flat = args_flat
478+ self ._meta_tys_flat = meta_tys_flat
481479 self ._params = params
482480 self ._in_tree = in_tree
483481 self .out_tree = out_tree
@@ -503,12 +501,12 @@ def lower(self, *, lowering_platforms: tuple[str, ...] | None = None,
503501 try :
504502 from jax ._src .pjit import _resolve_and_lower # type: ignore
505503 lowering = _resolve_and_lower (
506- self ._args_flat , ** self ._params , lowering_platforms = lowering_platforms ,
504+ self ._meta_tys_flat , ** self ._params , lowering_platforms = lowering_platforms ,
507505 lowering_parameters = _private_parameters , pgle_profiler = None )
508506 except DeviceAssignmentMismatchError as e :
509507 fails , = e .args
510508 msg = _device_assignment_mismatch_error (
511- self ._params ['name' ], fails , self ._args_flat , 'jit' ,
509+ self ._params ['name' ], fails , self ._meta_tys_flat , 'jit' ,
512510 self .jaxpr .debug_info .safe_arg_names (len (self .jaxpr .in_avals )))
513511 raise ValueError (msg ) from None
514512 return Lowered (lowering , self .args_info , self .out_tree ,
0 commit comments