
`vectorized_map` causes `tf.function` retracing.


Problem description

It seems that applying layers which subclass BaseImageAugmentationLayer with self.auto_vectorize=True to batched input causes tf.function retracing:

import tensorflow as tf
from keras_cv.layers import Solarization  # Equalization shows the same behavior

layer = Solarization()
rng = tf.random.Generator.from_seed(1234)

for _ in range(50):
    dummy_input = rng.uniform(
        shape=(1, 224, 224, 3), minval=0, maxval=255
    )
    layer(dummy_input)

emits retracing warnings:

WARNING:tensorflow:5 out of the last 5 calls to <function pfor.<locals>.f at 0x7f80a2a544c0> triggered tf.function retracing. (...)
WARNING:tensorflow:6 out of the last 6 calls to <function pfor.<locals>.f at 0x7f80a2a544c0> triggered tf.function retracing. (...)
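For context, a tf.function creates a new concrete trace whenever it sees an input signature (shape/dtype, or Python-object value) it has not traced before; the trace count can be inspected with experimental_get_tracing_count(). A minimal sketch of shape-driven retracing, unrelated to keras_cv:

```python
import tensorflow as tf

@tf.function
def double(x):
    return x * 2

double(tf.constant([1.0]))       # first call: traces for shape (1,)
double(tf.constant([1.0, 2.0]))  # new shape (2,): triggers a second trace
double(tf.constant([3.0, 4.0]))  # shape (2,) seen before: reuses the trace

print(double.experimental_get_tracing_count())  # 2
```

The warning in this issue fires because the internal pfor function used by tf.vectorized_map keeps getting re-created, so TensorFlow never gets to reuse an existing trace.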

Benchmarks

Running a simple benchmark confirms the performance degradation with tf.function and batched input:

import time

import tensorflow as tf
from keras_cv.layers import Solarization

use_tf_function = False

rng = tf.random.Generator.from_seed(1234)
layer = Solarization()
results = []

if use_tf_function:
    layer.augment_image = tf.function(layer.augment_image, jit_compile=True)

# Warmup
for _ in range(10):
    layer(rng.uniform(shape=(24, 224, 224, 3), maxval=256))

# Benchmark
for _ in range(100):
    dummy_input = rng.uniform(shape=(24, 224, 224, 3), maxval=256)
    start = time.perf_counter()
    layer(dummy_input)
    stop = time.perf_counter()
    results.append(stop - start)

print(tf.reduce_mean(results))

Case 1: auto_vectorize=True

Without tf.function: 0.067 ms. With: 0.079 ms.

Case 2: auto_vectorize=False

The issue doesn’t occur with non-batched input, e.g. shape (224, 224, 3), or if one sets self.auto_vectorize=False in the layer.

Setting self.auto_vectorize=False yields: without tf.function: 0.017 ms; with: 0.013 ms.

Case 3: override _batch_augment (if possible)

For natively vectorized operations, the fastest option is still overriding _batch_augment to return self._augment(inputs). This yields: without tf.function: 0.0059 ms; with: 0.0016 ms.
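To illustrate that override pattern, here is a hypothetical keras_cv-style layer (the class and method names below are illustrative, not the actual BaseImageAugmentationLayer API). Its per-image logic is built from elementwise ops that already broadcast over a batch dimension, so the batched path can forward to it directly instead of going through tf.vectorized_map:

```python
import tensorflow as tf

class SolarizeSketch(tf.keras.layers.Layer):
    """Hypothetical sketch of the _batch_augment override pattern."""

    def _augment(self, images):
        # Solarization: invert pixel values above a threshold. tf.where is
        # elementwise, so the same code handles (H, W, C) and (B, H, W, C).
        return tf.where(images < 128.0, images, 255.0 - images)

    def _batch_augment(self, inputs):
        # The augmentation above is natively batched, so we can skip
        # tf.vectorized_map entirely and apply it to the whole batch.
        return self._augment(inputs)

    def call(self, inputs):
        return self._batch_augment(inputs)

out = SolarizeSketch()(tf.fill((2, 8, 8, 3), 200.0))
# every input pixel is 200 (>= 128), so every output pixel is 255 - 200 = 55
```

This only works when every op in _augment tolerates an extra leading batch dimension; per-image control flow or image-dependent random parameters still need a mapped path.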

Issue Analytics

  • State: open
  • Created a year ago
  • Comments: 20 (14 by maintainers)

Top GitHub Comments

LukeWood commented on Apr 3, 2022 (1 reaction)

It seems that wrapping the entire layer in tf.function (even better with jit_compile=True) also silences the warnings and provides decent performance in eager mode:

@tf.function(jit_compile=True)
def apply(x):
    return layer(x)

  • 0.0015 ms for self.auto_vectorize=True
  • 0.0022 ms for self.auto_vectorize=False
  • 0.0015 ms for native vectorization

This issue can be closed from my end. If no further comments appear I will close this issue starting next week. Thanks for the help!

I don’t want to close it yet because I feel we need to figure out how to effectively communicate this recommendation to users 🤔

LukeWood commented on Mar 31, 2022 (1 reaction)

Thanks for the detailed report @sebastian-sz

FYI @qlzh727


Top Results From Across the Web

executing vectorized_map on batches triggers retracing #43710
Passing a new lambda to tf.function each time will cause that to retrace, so you may want to hang onto a single lambda...

custom keras metric caused tf.function retracing warning ...
This warning occurs when a TF function is retraced because its arguments change in shape or dtype (for Tensors) or even in value...

Release 2.12.0 - Google Git
Custom classes used as arguments for tf.function can now specify rules regarding when retracing needs to occur by implementing the Tracing Protocol ...

Better performance with tf.function | TensorFlow Core
Retracing, which is when your Function creates more than one trace, helps ensures that TensorFlow generates correct graphs for each set of inputs....

How to debug "triggered tf.function retracing" warnings?
It then says that there are three common causes (creating functions in a loop, passing different shaped tensors, and passing Python objects ...
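The first result above points at the retracing trap most relevant to this issue: constructing a new Python callable on every call, so TensorFlow can never match an existing trace. A minimal sketch of the anti-pattern and the fix (hang onto a single tf.function):

```python
import tensorflow as tf

def add_one(x):
    return x + 1

# Anti-pattern: each iteration wraps a brand-new lambda, so every call
# traces from scratch (and eventually triggers the retracing warning).
for _ in range(3):
    tf.function(lambda x: add_one(x))(tf.constant(1.0))

# Fix: create the tf.function once and reuse it across calls.
fn = tf.function(add_one)
for _ in range(3):
    fn(tf.constant(1.0))

print(fn.experimental_get_tracing_count())  # 1
```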
