Description
Inside a tf.data pipeline, is it possible to use keras.ops in a backend-agnostic way? I raised a related issue here: #20722, which focused on augmentation.
Currently, a function that uses keras.ops works inside tf.data with the TensorFlow backend, but with other backends such as PyTorch or JAX it fails, as in the traceback below. One solution is to reimplement the function using pure TensorFlow operations, but I'm wondering whether there is a way to make it work with keras.ops across different backends.
NotImplementedError: in user code:

    File "/tmp/ipykernel_37/4081972615.py", line 56, in train_transformation *
        result = pipeline(data, meta)
    File "/usr/local/lib/python3.11/dist-packages/medicai/transforms/base.py", line 53, in __call__ *
        x = transform(x)
    File "/tmp/ipykernel_37/3272635510.py", line 87, in __call__ *
        resample_image = self.spacingd_resample(
    File "/tmp/ipykernel_37/3272635510.py", line 133, in spacingd_resample *
        resized_dhw = resize_volumes(image[None, ...], *spatial_shape, method="trilinear", align_corners=False)
    File "/usr/local/lib/python3.11/dist-packages/medicai/utils/image.py", line 7, in trilinear_resize *
        volumes = ops.cast(volumes, "float32")
    File "/usr/local/lib/python3.11/dist-packages/keras/src/ops/core.py", line 803, in cast **
        return backend.core.cast(x, dtype)
    File "/usr/local/lib/python3.11/dist-packages/keras/src/backend/torch/core.py", line 283, in cast
        return convert_to_tensor(x, dtype)
    File "/usr/local/lib/python3.11/dist-packages/keras/src/backend/torch/core.py", line 215, in convert_to_tensor
        x = np.array(x)

    NotImplementedError: Cannot convert a symbolic tf.Tensor (strided_slice_34:0) to a numpy array. This error may indicate that you're trying to pass a Tensor to a NumPy call, which is not supported.

Case:
- def some_method: uses keras.ops
- class KerasLayer or KerasModel: uses some_method
- tf.data API: uses some_method

If we run the code with the torch or jax backend, some_method works on that backend inside the Keras layer or model, but raises the error above inside the tf.data pipeline. To make the code work, some_method would have to be reimplemented in TensorFlow.
Alternatively, is there a way to isolate some_method from the Keras layers or models when it is used in the tf.data API, regardless of the backend? In other words, if keras.ops detects that it is running inside a tf.data pipeline, it could automatically switch to a dynamic (local) TensorFlow backend whenever possible.