File tree Expand file tree Collapse file tree 1 file changed +6
-4
lines changed Expand file tree Collapse file tree 1 file changed +6
-4
lines changed Original file line number Diff line number Diff line change @@ -306,10 +306,12 @@ async def _async_serialize(
306306 and arr_inp .is_fully_addressable
307307 )
308308 # pylint: disable=protected-access
309- spec_has_metadata = {
310- "0.6.2" : lambda : serialization .ts_impl ._spec_has_metadata ,
311- "0.5.3" : lambda : serialization ._spec_has_metadata ,
312- }[jax .__version__ ]()
309+ if jax .__version__ .startswith ("0.8.0" ) or jax .__version__ == "0.6.2" :
310+ spec_has_metadata = serialization .ts_impl ._spec_has_metadata
311+ elif jax .__version__ == "0.5.3" :
312+ spec_has_metadata = serialization ._spec_has_metadata
313+ else :
314+ raise ValueError (f"Unsupported JAX version for spec_has_metadata: { jax .__version__ } " )
313315 if not spec_has_metadata (tensorstore_spec ):
314316 # pylint: disable-next=protected-access
315317 tensorstore_spec ["metadata" ] = serialization ._get_metadata (arr_inp )
You can’t perform that action at this time.
0 commit comments