`vectorized_map` causes `tf.function` retracing.
Problem description
Applying layers that use BaseImageAugmentationLayer with self.auto_vectorize=True to batched input causes tf.function retracing:
layer = Solarization()  # or Equalization()
rng = tf.random.Generator.from_seed(1234)
for _ in range(50):
    dummy_input = rng.uniform(
        shape=(1, 224, 224, 3), minval=0, maxval=255
    )
    layer(dummy_input)
This emits:
WARNING:tensorflow:5 out of the last 5 calls to <function pfor.<locals>.f at 0x7f80a2a544c0> triggered tf.function retracing. (...)
WARNING:tensorflow:6 out of the last 6 calls to <function pfor.<locals>.f at 0x7f80a2a544c0> triggered tf.function retracing. (...)
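The mechanism can be seen in isolation, independent of KerasCV: wrapping a fresh Python callable in tf.function on every call defeats the trace cache, whereas reusing a single traced function does not. This is a minimal sketch of that effect (my own repro, not code from the issue); tf.vectorized_map internally builds a new pfor closure per call, which plays the role of the fresh lambda here.

```python
import tensorflow as tf

# A single reused tf.function traces once for a fixed input signature.
reused = tf.function(lambda x: x + 1.0)
for _ in range(3):
    reused(tf.ones((2, 2)))

# A fresh tf.function per call retraces every time, even though the
# computation and input shape never change.
fresh_traces = 0
for _ in range(3):
    f = tf.function(lambda x: x + 1.0)  # new Python callable each iteration
    f(tf.ones((2, 2)))
    fresh_traces += f.experimental_get_tracing_count()

reused_traces = reused.experimental_get_tracing_count()
print(reused_traces, fresh_traces)  # 1 trace vs. 3 traces
```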
Benchmarks
Running a simple benchmark confirms the performance degradation with tf.function and batched input:
use_tf_function = False
rng = tf.random.Generator.from_seed(1234)
layer = Solarization()
results = []

if use_tf_function:
    layer.augment_image = tf.function(layer.augment_image, jit_compile=True)

# Warmup
for _ in range(10):
    layer(rng.uniform(shape=(24, 224, 224, 3), maxval=256))

# Benchmark
for _ in range(100):
    dummy_input = rng.uniform(shape=(24, 224, 224, 3), maxval=256)
    start = time.perf_counter()
    layer(dummy_input)
    stop = time.perf_counter()
    results.append(stop - start)

print(tf.reduce_mean(results))
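The warmup-and-measure loop above can be factored into a small reusable helper. This is a sketch; the names `benchmark` and `make_input` are mine, not from the issue.

```python
import time

def benchmark(fn, make_input, warmup=10, iters=100):
    """Average wall-clock time of fn over iters calls, after a warmup
    phase so one-time tracing/compilation cost is excluded."""
    for _ in range(warmup):
        fn(make_input())
    total = 0.0
    for _ in range(iters):
        x = make_input()  # build input outside the timed region
        start = time.perf_counter()
        fn(x)
        total += time.perf_counter() - start
    return total / iters
```

With it, the measurement above becomes a one-liner, e.g. `benchmark(layer, lambda: rng.uniform(shape=(24, 224, 224, 3), maxval=256))`.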
Case 1: auto_vectorize=True
Without tf.function: 0.067 ms.
With: 0.079 ms.
Case 2: auto_vectorize=False
The issue doesn’t pop up with non-batched input, e.g. (224, 224, 3), or if one sets self.auto_vectorize=False in the layer. Setting self.auto_vectorize=False yields:
Without tf.function: 0.017 ms
With: 0.013 ms.
Case 3: override _batch_augment (if possible)
For vectorized operations, the fastest option is still overriding _batch_augment to return self._augment(inputs). This yields:
Without tf.function: 0.0059 ms
With: 0.0016 ms
Issue Analytics
- State:
- Created a year ago
- Comments: 20 (14 by maintainers)
Read more >Top Related Medium Post
No results found
Top Related StackOverflow Question
No results found
Troubleshoot Live Code
Lightrun enables developers to add logs, metrics and snapshots to live code - no restarts or redeploys required.
Start FreeTop Related Reddit Thread
No results found
Top Related Hackernoon Post
No results found
Top Related Tweet
No results found
Top Related Dev.to Post
No results found
Top Related Hashnode Post
No results found
Top GitHub Comments
I don’t want to close it yet because I feel we still need to figure out how to effectively communicate this recommendation to users 🤔
Thanks for the detailed report @sebastian-sz
FYI @qlzh727