From c87d895399703e1e36a13f030b8ece4d63bf5444 Mon Sep 17 00:00:00 2001 From: Matt Watson Date: Mon, 27 Nov 2023 18:51:06 -0800 Subject: [PATCH] Fix bug when upranking passthrough inputs to RandAugment - RandAugment sometimes will choose a "no augmentation" option and passthrough inputs unaltered. - Preprocessing normalization routines were not making copies of inputs and sometimes mutating layer input directly (mutating the input dict to cast dtypes and uprank tensors). - RandAugment under the passthrough option would return these inputs directly. The net effect was sometimes attempting to uprank during a passthrough call, breaking tf.map_fn --- .../preprocessing/base_image_augmentation_layer.py | 2 ++ keras_cv/layers/preprocessing/rand_augment_test.py | 10 +++++++++- .../preprocessing/random_augmentation_pipeline.py | 2 +- .../vectorized_base_image_augmentation_layer.py | 5 +++++ 4 files changed, 17 insertions(+), 2 deletions(-) diff --git a/keras_cv/layers/preprocessing/base_image_augmentation_layer.py b/keras_cv/layers/preprocessing/base_image_augmentation_layer.py index 3f42a804a3..0a365891b7 100644 --- a/keras_cv/layers/preprocessing/base_image_augmentation_layer.py +++ b/keras_cv/layers/preprocessing/base_image_augmentation_layer.py @@ -571,6 +571,8 @@ def _ensure_inputs_are_compute_dtype(self, inputs): inputs, self.compute_dtype, ) + # Copy the input dict before we mutate it. + inputs = dict(inputs) inputs[IMAGES] = preprocessing.ensure_tensor( inputs[IMAGES], self.compute_dtype, diff --git a/keras_cv/layers/preprocessing/rand_augment_test.py b/keras_cv/layers/preprocessing/rand_augment_test.py index 9ef0ea9fbd..bfecec86b2 100644 --- a/keras_cv/layers/preprocessing/rand_augment_test.py +++ b/keras_cv/layers/preprocessing/rand_augment_test.py @@ -21,8 +21,16 @@ from keras_cv.tests.test_case import TestCase -@pytest.mark.skipif(keras_3(), reason="imcompatible with Keras 3") class RandAugmentTest(TestCase): + def test_zero_rate_pass_through(self): + rand_augment = layers.RandAugment( + value_range=(0, 255), + rate=0.0, + ) + xs = np.ones((2, 512, 512, 3)) + ys = rand_augment(xs) + self.assertAllClose(ys, xs) + @parameterized.named_parameters( ("0", 0), ("20", 0.2), diff --git a/keras_cv/layers/preprocessing/random_augmentation_pipeline.py b/keras_cv/layers/preprocessing/random_augmentation_pipeline.py index a525680fca..e437a90147 100644 --- a/keras_cv/layers/preprocessing/random_augmentation_pipeline.py +++ b/keras_cv/layers/preprocessing/random_augmentation_pipeline.py @@ -103,7 +103,7 @@ def _augment(self, inputs): ) result = tf.cond( skip_augment > self.rate, - lambda: inputs, + lambda: result, lambda: self._random_choice(result), ) return result diff --git a/keras_cv/layers/preprocessing/vectorized_base_image_augmentation_layer.py b/keras_cv/layers/preprocessing/vectorized_base_image_augmentation_layer.py index c46b5f81b2..1f420437f4 100644 --- a/keras_cv/layers/preprocessing/vectorized_base_image_augmentation_layer.py +++ b/keras_cv/layers/preprocessing/vectorized_base_image_augmentation_layer.py @@ -444,6 +444,9 @@ def _format_inputs(self, inputs): # single image input tensor metadata[IS_DICT] = False inputs = {IMAGES: inputs} + else: + # Copy the input dict before we mutate it. + inputs = dict(inputs) metadata[BATCHED] = inputs["images"].shape.rank == 4 if inputs["images"].shape.rank == 3: @@ -504,6 +507,8 @@ def _ensure_inputs_are_compute_dtype(self, inputs): inputs, self.compute_dtype, ) + # Copy the input dict before we mutate it. + inputs = dict(inputs) inputs[IMAGES] = preprocessing.ensure_tensor( inputs[IMAGES], self.compute_dtype,