Fix bug when upranking passthrough inputs to RandAugment
- RandAugment sometimes chooses a "no augmentation" option and passes
  inputs through unaltered.
- The preprocessing normalization routines were not making copies of
  inputs and sometimes mutated the layer input directly (mutating the
  input dict to cast dtypes and uprank tensors).
- Under the passthrough option, RandAugment would return these mutated
  inputs directly.

The net effect was that a passthrough call could still end up upranking
its inputs, breaking tf.map_fn (the aliasing problem is illustrated in
the sketch below).
mattdangerw committed Nov 28, 2023
1 parent 32bef65 commit c87d895
Showing 4 changed files with 17 additions and 2 deletions.
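For illustration only (not part of this commit), here is a minimal sketch of the aliasing problem the added copies address. The helpers normalize_in_place and normalize_copy are hypothetical stand-ins for the layer's dtype-casting/upranking normalization; the point is that writing back into the caller's dict leaks the upranked tensor out of the layer, while copying the dict first does not.

import numpy as np
import tensorflow as tf


def normalize_in_place(inputs):
    # Buggy pattern (hypothetical): cast and uprank by writing back into
    # the caller's dict, silently modifying the caller's inputs.
    inputs["images"] = tf.expand_dims(
        tf.cast(tf.convert_to_tensor(inputs["images"]), tf.float32), axis=0
    )
    return inputs


def normalize_copy(inputs):
    # Fixed pattern: copy the dict before mutating it, so the caller's
    # inputs are left untouched.
    inputs = dict(inputs)
    inputs["images"] = tf.expand_dims(
        tf.cast(tf.convert_to_tensor(inputs["images"]), tf.float32), axis=0
    )
    return inputs


original = {"images": np.ones((512, 512, 3), dtype=np.uint8)}
normalize_in_place(original)
print(original["images"].shape)  # (1, 512, 512, 3) -- the caller's input was upranked

original = {"images": np.ones((512, 512, 3), dtype=np.uint8)}
normalize_copy(original)
print(original["images"].shape)  # (512, 512, 3) -- the caller's input is unchanged

The diff below applies this copy-before-mutate pattern in the preprocessing base layers and changes the skip branch of RandAugment's tf.cond to return result instead of the original inputs.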
@@ -571,6 +571,8 @@ def _ensure_inputs_are_compute_dtype(self, inputs):
                 inputs,
                 self.compute_dtype,
             )
+        # Copy the input dict before we mutate it.
+        inputs = dict(inputs)
         inputs[IMAGES] = preprocessing.ensure_tensor(
             inputs[IMAGES],
             self.compute_dtype,
10 changes: 9 additions & 1 deletion keras_cv/layers/preprocessing/rand_augment_test.py
@@ -21,8 +21,16 @@
 from keras_cv.tests.test_case import TestCase


 @pytest.mark.skipif(keras_3(), reason="imcompatible with Keras 3")
 class RandAugmentTest(TestCase):
+    def test_zero_rate_pass_through(self):
+        rand_augment = layers.RandAugment(
+            value_range=(0, 255),
+            rate=0.0,
+        )
+        xs = np.ones((2, 512, 512, 3))
+        ys = rand_augment(xs)
+        self.assertAllClose(ys, xs)

     @parameterized.named_parameters(
         ("0", 0),
         ("20", 0.2),
@@ -103,7 +103,7 @@ def _augment(self, inputs):
         )
         result = tf.cond(
             skip_augment > self.rate,
-            lambda: inputs,
+            lambda: result,
             lambda: self._random_choice(result),
         )
         return result
@@ -444,6 +444,9 @@ def _format_inputs(self, inputs):
             # single image input tensor
             metadata[IS_DICT] = False
             inputs = {IMAGES: inputs}
+        else:
+            # Copy the input dict before we mutate it.
+            inputs = dict(inputs)

         metadata[BATCHED] = inputs["images"].shape.rank == 4
         if inputs["images"].shape.rank == 3:
@@ -504,6 +507,8 @@ def _ensure_inputs_are_compute_dtype(self, inputs):
                 inputs,
                 self.compute_dtype,
             )
+        # Copy the input dict before we mutate it.
+        inputs = dict(inputs)
         inputs[IMAGES] = preprocessing.ensure_tensor(
             inputs[IMAGES],
             self.compute_dtype,
