diff --git a/keras/src/utils/image_utils.py b/keras/src/utils/image_utils.py index ca8289c9f9b..f5bb63a5421 100644 --- a/keras/src/utils/image_utils.py +++ b/keras/src/utils/image_utils.py @@ -175,12 +175,24 @@ def save_img(path, x, data_format=None, file_format=None, scale=True, **kwargs): **kwargs: Additional keyword arguments passed to `PIL.Image.save()`. """ data_format = backend.standardize_data_format(data_format) + + # Infer format from path if not explicitly provided + if file_format is None and isinstance(path, (str, pathlib.Path)): + file_format = pathlib.Path(path).suffix[1:].lower() + + # Normalize jpg → jpeg for Pillow compatibility + if file_format and file_format.lower() == "jpg": + file_format = "jpeg" + img = array_to_img(x, data_format=data_format, scale=scale) - if img.mode == "RGBA" and (file_format == "jpg" or file_format == "jpeg"): + + # Handle RGBA → RGB conversion for JPEG + if img.mode == "RGBA" and file_format == "jpeg": warnings.warn( - "The JPG format does not support RGBA images, converting to RGB." + "The JPEG format does not support RGBA images, converting to RGB." ) img = img.convert("RGB") + img.save(path, format=file_format, **kwargs) diff --git a/keras/src/utils/image_utils_test.py b/keras/src/utils/image_utils_test.py new file mode 100644 index 00000000000..31fb30cf83c --- /dev/null +++ b/keras/src/utils/image_utils_test.py @@ -0,0 +1,36 @@ +import os + +import numpy as np +from absl.testing import parameterized + +from keras.src import testing +from keras.src.utils import img_to_array +from keras.src.utils import load_img +from keras.src.utils import save_img + + +class SaveImgTest(testing.TestCase, parameterized.TestCase): + @parameterized.named_parameters( + ("rgb_explicit_format", (50, 50, 3), "rgb.jpg", "jpg", True), + ("rgba_explicit_format", (50, 50, 4), "rgba.jpg", "jpg", True), + ("rgb_inferred_format", (50, 50, 3), "rgb_inferred.jpg", None, False), + ("rgba_inferred_format", (50, 50, 4), "rgba_inferred.jpg", None, False), + ) + def test_save_jpg(self, shape, name, file_format, use_explicit_format): + tmp_dir = self.get_temp_dir() + path = os.path.join(tmp_dir, name) + + img = np.random.randint(0, 256, size=shape, dtype=np.uint8) + + # Test the actual inferred case - don't pass file_format at all + if use_explicit_format: + save_img(path, img, file_format=file_format) + else: + save_img(path, img) # Let it infer from path + + self.assertTrue(os.path.exists(path)) + + # Verify saved image is correctly converted to RGB if needed + loaded_img = load_img(path) + loaded_array = img_to_array(loaded_img) + self.assertEqual(loaded_array.shape, (50, 50, 3))