Content Overview
In TensorFlow 2, eager execution is turned on by default. The user interface is intuitive and flexible (running one-off operations is much easier and faster), but this can come at the expense of performance and deployability.
You can use tf.function to make graphs out of your programs. It is a transformation tool that creates Python-independent dataflow graphs out of your Python code. This will help you create performant and portable models, and it is required to use SavedModel.
This guide will help you conceptualize how tf.function works under the hood, so you can use it effectively.
The main takeaways and recommendations are:
- Debug in eager mode, then decorate with @tf.function.
- Don't rely on Python side effects like object mutation or list appends.
- tf.function works best with TensorFlow ops; NumPy and Python calls are converted to constants.
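The last point is easy to miss, so here is a minimal sketch (the function name add_noise is made up for illustration). A NumPy call inside a tf.function runs only while the function is being traced, and its result is embedded in the graph as a constant.

```python
import numpy as np
import tensorflow as tf

@tf.function
def add_noise(x):
    # np.random.rand() runs in Python during tracing only; its value is
    # then embedded in the graph as a constant.
    numpy_noise = np.random.rand()
    return x + numpy_noise

first = add_noise(tf.constant(0.0))
second = add_noise(tf.constant(0.0))  # Same input type: the trace is reused.

# Both calls return the same "random" number, because the NumPy value was
# frozen into the graph at trace time.
print(float(first) == float(second))
```

If you need a fresh random value on every call, use a TensorFlow op such as tf.random.uniform inside the function instead.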
Setup
import tensorflow as tf
Define a helper function to demonstrate the kinds of errors you might encounter:
import traceback
import contextlib

# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
    try:
        yield
    except error_class as e:
        print('Caught expected exception \n  {}:'.format(error_class))
        traceback.print_exc(limit=2)
    except Exception as e:
        raise e
    else:
        raise Exception('Expected {} to be raised but no error was raised!'.format(
            error_class))
Basics
Usage
A tf.function
that you define (for example by applying the @tf.function
decorator) is just like a core TensorFlow operation: You can execute it eagerly; you can compute gradients; and so on.
@tf.function  # The decorator converts `add` into a `PolymorphicFunction`.
def add(a, b):
    return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  # [[2., 2.], [2., 2.]]
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
[2., 2.]], dtype=float32)>
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
    result = add(v, 1.0)
tape.gradient(result, v)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>
You can use tf.functions inside other tf.functions.
@tf.function
def dense_layer(x, w, b):
    return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
[3., 3.],
[3., 3.]], dtype=float32)>
tf.functions can be faster than eager code, especially for graphs with many small ops. But for graphs with a few expensive ops (like convolutions), you may not see much speedup.
import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
    return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# Warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")
Eager conv: 0.011224052000216034
Function conv: 0.005400947000453016
Note how there's not much difference in performance for convolutions
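For contrast, here is a sketch of the case where tf.function tends to help more: a function made of many small ops. The function many_small_ops and the loop count of 200 are illustrative choices, not part of the benchmark above, and the measured times depend on your hardware.

```python
import timeit
import tensorflow as tf

def many_small_ops(x):
    # A Python loop: in eager mode each add is dispatched one by one;
    # under tf.function the loop is unrolled into 200 add ops in one graph.
    for _ in range(200):
        x = x + 1.0
    return x

many_small_ops_fn = tf.function(many_small_ops)
x = tf.constant(0.0)

# Warm up (this triggers the trace for the tf.function version).
many_small_ops(x); many_small_ops_fn(x)

print("Eager small ops:", timeit.timeit(lambda: many_small_ops(x), number=20))
print("Function small ops:", timeit.timeit(lambda: many_small_ops_fn(x), number=20))
```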
Tracing
This section exposes how tf.function works under the hood, including implementation details which may change in the future. However, once you understand why and when tracing happens, it's much easier to use tf.function effectively!
What is "tracing"?
A tf.function runs your program in a TensorFlow Graph. However, a tf.Graph cannot represent all the things that you'd write in an eager TensorFlow program. For instance, Python supports polymorphism, but tf.Graph requires its inputs to have a specified data type and dimension. Or you may perform side tasks like reading command-line arguments, raising an error, or working with a more complex Python object; none of these things can run in a tf.Graph.
tf.function bridges this gap by separating your code into two stages:
- In the first stage, referred to as "tracing", tf.function creates a new tf.Graph. Python code runs normally, but all TensorFlow operations (like adding two Tensors) are deferred: they are captured by the tf.Graph and not run.
- In the second stage, a tf.Graph which contains everything that was deferred in the first stage is run. This stage is much faster than the tracing stage.
Depending on its inputs, tf.function will not always run the first stage when it is called. See "Rules of tracing" below to get a better sense of how it makes that determination. Skipping the first stage and only executing the second stage is what gives you TensorFlow's high performance.
When tf.function does decide to trace, the tracing stage is immediately followed by the second stage, so calling the tf.function both creates and runs the tf.Graph. Later you will see how you can run only the tracing stage with get_concrete_function.
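The two stages can be observed with a small sketch. The names counts and graph_runs are hypothetical counters introduced here: the plain Python increment happens only while tracing, while the tf.Variable update is a deferred TensorFlow op that runs every time the graph executes.

```python
import tensorflow as tf

counts = {"python": 0}        # Incremented by ordinary Python code.
graph_runs = tf.Variable(0)   # Incremented by a TensorFlow op.

@tf.function
def times_two(x):
    counts["python"] += 1     # Stage 1 only: runs while tracing.
    graph_runs.assign_add(1)  # Deferred: recorded in the graph, run in stage 2.
    return x * 2

for i in range(3):
    times_two(tf.constant(i))  # Same dtype and shape on every call.

print("Python body ran:", counts["python"])
print("Graph ran:", int(graph_runs.numpy()))
```

The Python body runs once, for the single trace, while the graph runs on each of the three calls.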
When you pass arguments of different types into a tf.function, both stages are run:
@tf.function
def double(a):
    print("Tracing with", a)
    return a + a

print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()
Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)
Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)
Note that if you repeatedly call a tf.function with the same argument type, TensorFlow will skip the tracing stage and reuse a previously traced graph, as the generated graph would be identical.
# This doesn't print 'Tracing with ...'
print(double(tf.constant("b")))
tf.Tensor(b'bb', shape=(), dtype=string)
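If you'd rather check trace reuse without relying on print side effects, one option is the experimental_get_tracing_count method on the object returned by tf.function. This is a sketch; since the method is marked experimental, its behavior may change between TensorFlow versions. The function square is made up for illustration.

```python
import tensorflow as tf

@tf.function
def square(x):
    return x * x

square(tf.constant(2))
square(tf.constant(3))        # Same dtype and shape: no new trace.
count_after_ints = square.experimental_get_tracing_count()

square(tf.constant(2.0))      # New dtype: traced again.
count_after_float = square.experimental_get_tracing_count()

print(count_after_ints, count_after_float)
```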
You can use pretty_printed_concrete_signatures() to see all of the available traces:
print(double.pretty_printed_concrete_signatures())
Input Parameters:
a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.int32, name=None)
Output Type:
TensorSpec(shape=(), dtype=tf.int32, name=None)
Captures:
None
Input Parameters:
a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.float32, name=None)
Output Type:
TensorSpec(shape=(), dtype=tf.float32, name=None)
Captures:
None
So far, you've seen that tf.function creates a cached, dynamic dispatch layer over TensorFlow's graph tracing logic. To be more specific about the terminology:
- A tf.Graph is the raw, language-agnostic, portable representation of a TensorFlow computation.
- Tracing is the process through which new tf.Graphs are generated from Python code.
- An instance of tf.Graph is specialized to the specific input types it was traced with. Differing types require retracing.
- Each traced tf.Graph has a corresponding ConcreteFunction.
- A tf.function manages a cache of ConcreteFunctions and picks the right one for your inputs.
- tf.function wraps the Python function that will be traced, returning a tf.types.experimental.PolymorphicFunction object.
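These relationships can be inspected directly. The sketch below (the function scale_and_shift is made up for illustration) obtains a ConcreteFunction with get_concrete_function, which is covered later in this guide, and looks at the tf.Graph it wraps.

```python
import tensorflow as tf

@tf.function
def scale_and_shift(x):
    return 2.0 * x + 1.0

# Tracing for a float32 scalar yields one ConcreteFunction.
concrete = scale_and_shift.get_concrete_function(
    tf.TensorSpec(shape=[], dtype=tf.float32))

# The ConcreteFunction wraps a tf.Graph specialized to that input type.
graph = concrete.graph
print(isinstance(graph, tf.Graph))
print([op.type for op in graph.get_operations()])
```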
Rules of tracing
When called, a tf.function first evaluates the type of each input argument using the tf.types.experimental.TraceType of each argument. This is used to construct a tf.types.experimental.FunctionType describing the signature of the desired ConcreteFunction. We compare this FunctionType to the FunctionTypes of existing ConcreteFunctions. If a matching ConcreteFunction is found, the call is dispatched to it. If no match is found, a new ConcreteFunction is traced for the desired FunctionType.
If multiple matches are found, the most specific signature is chosen. Matching is done by subtyping, much like normal function calls in C++ or Java, for instance. For example, TensorShape([1, 2]) is a subtype of TensorShape([None, None]), and so a call to the tf.function with TensorShape([1, 2]) can be dispatched to the ConcreteFunction produced with TensorShape([None, None]), but if a ConcreteFunction with TensorShape([1, None]) also exists then it will be prioritized since it is more specific.
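The subtype relation in the example above can be checked with is_subtype_of, which tf.TensorSpec provides as part of the TraceType protocol in recent TensorFlow releases (an assumption about your installed version). A minimal sketch, with variable names chosen here for illustration:

```python
import tensorflow as tf

specific = tf.TensorSpec(shape=[1, 2], dtype=tf.int32)
partial = tf.TensorSpec(shape=[1, None], dtype=tf.int32)
general = tf.TensorSpec(shape=[None, None], dtype=tf.int32)

# A fully known shape is a subtype of the same shape with unknown dimensions.
print(specific.is_subtype_of(general))
print(specific.is_subtype_of(partial))
# The relation does not hold in the other direction.
print(general.is_subtype_of(specific))
```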
The TraceType is determined from input arguments as follows:
- For Tensor, the type is parameterized by the Tensor's dtype and shape; ranked shapes are a subtype of unranked shapes; fixed dimensions are a subtype of unknown dimensions.
- For Variable, the type is similar to Tensor, but also includes a unique resource ID of the variable, necessary to correctly wire control dependencies.
- For Python primitive values, the type corresponds to the value itself. For example, the TraceType of the value 3 is LiteralTraceType<3>, not int.
- For Python ordered containers such as list and tuple, etc., the type is parameterized by the types of their elements; for example, the type of [1, 2] is ListTraceType<LiteralTraceType<1>, LiteralTraceType<2>> and the type for [2, 1] is ListTraceType<LiteralTraceType<2>, LiteralTraceType<1>>, which is different.
- For Python mappings such as dict, the type is also a mapping from the same keys but to the types of values instead of the actual values. For example, the type of {1: 2, 3: 4} is MappingTraceType<<KeyValue<1, LiteralTraceType<2>>>, <KeyValue<3, LiteralTraceType<4>>>>. However, unlike ordered containers, {1: 2, 3: 4} and {3: 4, 1: 2} have equivalent types.
- For Python objects which implement the __tf_tracing_type__ method, the type is whatever that method returns.
- For any other Python objects, the type is a generic TraceType, and the matching procedure is:
  - First it checks if the object is the same object used in the previous trace (using Python id() or is). Note that this will still match if the object has changed, so if you use Python objects as tf.function arguments it's best to use immutable ones.
  - Next it checks if the object is equal to the object used in the previous trace (using Python ==).
  Note that this procedure only keeps a weakref to the object and hence only works as long as the object is in scope/not deleted.
Note: TraceType is based on the tf.function input parameters so changes to global and
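The rules for Python literals and ordered containers listed above can be observed by counting traces. This is a sketch that assumes the experimental_get_tracing_count method is available in your TensorFlow version; the function describe is made up for illustration.

```python
import tensorflow as tf

@tf.function
def describe(x):
    # str(x) runs in Python at trace time, so the result is a constant.
    return tf.constant(str(x))

describe(3)
describe(3)        # Same literal value: the trace is reused.
traces_same_literal = describe.experimental_get_tracing_count()

describe(4)        # Different literal value: a new trace.
traces_new_literal = describe.experimental_get_tracing_count()

describe([1, 2])
describe([2, 1])   # Same elements, different order: a different type.
traces_lists = describe.experimental_get_tracing_count()

print(traces_same_literal, traces_new_literal, traces_lists)
```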
Controlling retracing
Retracing, which is when your tf.function creates more than one trace, helps ensure that TensorFlow generates correct graphs for each set of inputs. However, tracing is an expensive operation! If your tf.function retraces a new graph for every call, you'll find that your code executes more slowly than if you didn't use tf.function.
To control the tracing behavior, you can use the following techniques:
Pass a fixed input_signature to tf.function
This forces tf.function to constrain itself to only one tf.types.experimental.FunctionType composed of the types enumerated by the input_signature. Calls that cannot be dispatched to this FunctionType will throw an error.
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
    print("Tracing with", x)
    return tf.where(x % 2 == 0, x // 2, 3 * x + 1)

print(next_collatz(tf.constant([1, 2])))
# You specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(TypeError):
    next_collatz(tf.constant([[1, 2], [3, 4]]))

# You specified an int32 dtype in the input signature, so this should fail.
with assert_raises(TypeError):
    next_collatz(tf.constant([1.0, 2.0]))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
Caught expected exception
<class 'TypeError'>:
Caught expected exception
<class 'TypeError'>:
Traceback (most recent call last):
File "/tmpfs/tmp/ipykernel_167534/3551158538.py", line 8, in assert_raises
yield
File "/tmpfs/tmp/ipykernel_167534/3657259638.py", line 9, in <module>
next_collatz(tf.constant([[1, 2], [3, 4]]))
TypeError: Binding inputs to tf.function failed due to `Can not cast TensorSpec(shape=(2, 2), dtype=tf.int32, name=None) to TensorSpec(shape=(None,), dtype=tf.int32, name=None)`. Received args: (<tf.Tensor: shape=(2, 2), dtype=int32, numpy=
array([[1, 2],
[3, 4]], dtype=int32)>,) and kwargs: {} for signature: (x: TensorSpec(shape=(None,), dtype=tf.int32, name=None)).
Traceback (most recent call last):
File "/tmpfs/tmp/ipykernel_167534/3551158538.py", line 8, in assert_raises
yield
File "/tmpfs/tmp/ipykernel_167534/3657259638.py", line 13, in <module>
next_collatz(tf.constant([1.0, 2.0]))
TypeError: Binding inputs to tf.function failed due to `Can not cast TensorSpec(shape=(2,), dtype=tf.float32, name=None) to TensorSpec(shape=(None,), dtype=tf.int32, name=None)`. Received args: (<tf.Tensor: shape=(2,), dtype=float32, numpy=array([1., 2.], dtype=float32)>,) and kwargs: {} for signature: (x: TensorSpec(shape=(None,), dtype=tf.int32, name=None)).
Use unknown dimensions for flexibility
Since TensorFlow matches tensors based on their shape, using a None dimension as a wildcard will allow tf.functions to reuse traces for variably-sized input. Variably-sized input can occur if you have sequences of different length, or images of different sizes for each batch.
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
    print('Tracing with', x)
    return x

# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([1 2 3], shape=(3,), dtype=int32)
tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)
Use reduce_retracing for automatic flexibility
When reduce_retracing is enabled, tf.function automatically identifies supertypes of the input types it is observing and chooses to trace more generalized graphs automatically. It is less efficient than setting the input_signature directly but useful when many types need to be supported.
@tf.function(reduce_retracing=True)
def g(x):
    print('Tracing with', x)
    return x

# Traces once.
print(g(tf.constant([1, 2, 3])))

# Traces again, but more generalized this time.
print(g(tf.constant([1, 2, 3, 4, 5])))

# No more tracing!
print(g(tf.constant([1, 2, 3, 4, 5, 6, 7])))
print(g(tf.constant([1, 2, 3, 4, 5, 6, 7, 8, 9])))
Tracing with Tensor("x:0", shape=(3,), dtype=int32)
tf.Tensor([1 2 3], shape=(3,), dtype=int32)
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)
tf.Tensor([1 2 3 4 5 6 7], shape=(7,), dtype=int32)
tf.Tensor([1 2 3 4 5 6 7 8 9], shape=(9,), dtype=int32)
Pass tensors instead of Python literals
Often, Python arguments are used to control hyperparameters and graph construction, for example, num_layers=10 or training=True or nonlinearity='relu'. So, if the Python argument changes, it makes sense that you'd have to retrace the graph.
However, it's possible that a Python argument is not being used to control graph construction. In these cases, a change in the Python value can trigger needless retracing. Take, for example, this training loop, which AutoGraph will dynamically unroll. Despite the multiple traces, the generated graph is actually identical, so retracing is unnecessary.
def train_one_step():
    pass

@tf.function
def train(num_steps):
    print("Tracing with num_steps = ", num_steps)
    tf.print("Executing with num_steps = ", num_steps)
    for _ in tf.range(num_steps):
        train_one_step()

print("Retracing occurs for different Python arguments.")
train(num_steps=10)
train(num_steps=20)

print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
Retracing occurs for different Python arguments.
Tracing with num_steps = 10
Executing with num_steps = 10
Tracing with num_steps = 20
Executing with num_steps = 20
Traces are reused for Tensor arguments.
Tracing with num_steps = Tensor("num_steps:0", shape=(), dtype=int32)
Executing with num_steps = 10
Executing with num_steps = 20
If you need to force retracing, create a new tf.function. Separate tf.function objects are guaranteed not to share traces.
def f():
    print('Tracing!')
    tf.print('Executing')

tf.function(f)()
tf.function(f)()
Tracing!
Executing
Tracing!
Executing
Use the tracing protocol
Where possible, you should prefer converting the Python type into a tf.experimental.ExtensionType instead. Moreover, the TraceType of an ExtensionType is the tf.TypeSpec associated with it. Therefore, if needed, you can simply override the default tf.TypeSpec to take control of an ExtensionType's Tracing Protocol.
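As a minimal sketch of the ExtensionType route (the class FruitType and the function mix_flavors are hypothetical names, not part of this guide's own example), instances with the same field dtypes and shapes share one trace, because tracing is keyed on the TypeSpec and not on object identity. The trace count is read with experimental_get_tracing_count, which is assumed to be available in your TensorFlow version.

```python
import tensorflow as tf

class FruitType(tf.experimental.ExtensionType):
    flavor: tf.Tensor

@tf.function
def mix_flavors(fruit_a, fruit_b):
    return fruit_a.flavor + fruit_b.flavor

apple = FruitType(flavor=tf.constant([1, 2]))
mango = FruitType(flavor=tf.constant([3, 4]))

first = mix_flavors(apple, mango)
# New objects with the same field dtype and shape reuse the existing trace.
second = mix_flavors(FruitType(flavor=tf.constant([5, 6])),
                     FruitType(flavor=tf.constant([7, 8])))

print(first.numpy(), second.numpy())
print(mix_flavors.experimental_get_tracing_count())
```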
Otherwise, for direct control over when tf.function should retrace in regards to a particular Python type, you can implement the Tracing Protocol for it yourself.
@tf.function
def get_mixed_flavor(fruit_a, fruit_b):
    return fruit_a.flavor + fruit_b.flavor

class Fruit:
    flavor = tf.constant([0, 0])

class Apple(Fruit):
    flavor = tf.constant([1, 2])

class Mango(Fruit):
    flavor = tf.constant([3, 4])

# As described in the above rules, a generic TraceType for `Apple` and `Mango`
# is generated (and a corresponding ConcreteFunction is traced) but it fails to
# match the second function call since the first pair of Apple() and Mango()
# have gone out of scope by then and been deleted.
get_mixed_flavor(Apple(), Mango())  # Traces a new concrete function
get_mixed_flavor(Apple(), Mango())  # Traces a new concrete function again

# However, each subclass of the `Fruit` class has a fixed flavor, and you
# can reuse an existing traced concrete function if it was the same
# subclass. Avoiding such unnecessary tracing of concrete functions
# can have significant performance benefits.

class FruitTraceType(tf.types.experimental.TraceType):
    def __init__(self, fruit):
        self.fruit_type = type(fruit)
        self.fruit_value = fruit

    def is_subtype_of(self, other):
        # True if self subtypes `other` and `other`'s type matches FruitTraceType.
        return (type(other) is FruitTraceType and
                self.fruit_type is other.fruit_type)

    def most_specific_common_supertype(self, others):
        # `self` is the specific common supertype if all input types match it.
        return self if all(self == other for other in others) else None

    def placeholder_value(self, placeholder_context=None):
        # Use the fruit itself instead of the type for correct tracing.
        return self.fruit_value

    def __eq__(self, other):
        return type(other) is FruitTraceType and self.fruit_type == other.fruit_type

    def __hash__(self):
        return hash(self.fruit_type)

class FruitWithTraceType:
    def __tf_tracing_type__(self, context):
        return FruitTraceType(self)

class AppleWithTraceType(FruitWithTraceType):
    flavor = tf.constant([1, 2])

class MangoWithTraceType(FruitWithTraceType):
    flavor = tf.constant([3, 4])

# Now if you try calling it again:
get_mixed_flavor(AppleWithTraceType(), MangoWithTraceType())  # Traces a new concrete function
get_mixed_flavor(AppleWithTraceType(), MangoWithTraceType())  # Re-uses the traced concrete function
<tf.Tensor: shape=(2,), dtype=int32, numpy=array([4, 6], dtype=int32)>
Obtaining concrete functions
Every time a function is traced, a new concrete function is created. You can directly obtain a concrete function by using get_concrete_function.
print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.constant("a"))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
Obtaining concrete trace
Executing traced function
tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)
# You can also call get_concrete_function on a tf.TensorSpec
double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))
print(double_strings_from_inputspec(tf.constant("c")))
tf.Tensor(b'cc', shape=(), dtype=string)
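As a minimal sketch of what a concrete function gives you, the example below defines a stand-in for `double` (called `double_sketch` here, a hypothetical name) and records tracing in a Python list, `trace_log`, which is also an illustrative name. Python side effects run only while the body is traced, so the length of the list should stay the same across calls to the concrete function.

```python
import tensorflow as tf

trace_log = []  # Python-side record; appended to only while the body is traced.

@tf.function
def double_sketch(a):
    trace_log.append(a.dtype)  # Python side effect: runs at trace time only.
    return a + a

# Request a concrete function for scalar strings from a TensorSpec.
double_str = double_sketch.get_concrete_function(
    tf.TensorSpec(shape=[], dtype=tf.string))
traces_after_get = len(trace_log)

# Calling the concrete function runs the already-traced graph.
first = double_str(tf.constant("a"))
second = double_str(tf.constant("b"))
traces_after_calls = len(trace_log)

print(first.numpy(), second.numpy())
print("Traces before/after the calls:", traces_after_get, traces_after_calls)
```

If the two counts match, the calls reused the existing trace instead of creating a new one.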
Printing a ConcreteFunction displays a summary of its input arguments (with types) and its output type.
print(double_strings)
ConcreteFunction Input Parameters:
a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.string, name=None)
Output Type:
TensorSpec(shape=(), dtype=tf.string, name=None)
Captures:
None
You can also directly retrieve a concrete function's signature.
print(double_strings.function_type)
(a: TensorSpec(shape=(), dtype=tf.string, name=None)) -> TensorSpec(shape=(), dtype=tf.string, name=None)
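If you want to inspect the signature programmatically rather than print it, one possible approach is to read `structured_input_signature` and `structured_outputs` from the concrete function. The sketch below uses a hypothetical stand-in function, `double_sketch`, so that it is self-contained.

```python
import tensorflow as tf

@tf.function
def double_sketch(a):
    return a + a

concrete = double_sketch.get_concrete_function(
    tf.TensorSpec(shape=[], dtype=tf.string))

# structured_input_signature is a (positional_args, keyword_args) pair.
arg_specs, kwarg_specs = concrete.structured_input_signature
for spec in arg_specs:
    print(spec.dtype, spec.shape)

# structured_outputs mirrors the Python structure returned by the function.
print(concrete.structured_outputs)
```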
Using a concrete trace with incompatible types raises an error:
with assert_raises(tf.errors.InvalidArgumentError):
  double_strings(tf.constant(1))
Caught expected exception
<class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>:
Traceback (most recent call last):
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/function_type_utils.py", line 442, in bind_function_inputs
bound_arguments = function_type.bind_with_defaults(
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/core/function/polymorphism/function_type.py", line 277, in bind_with_defaults
with_default_args[arg_name] = constraint.cast(
TypeError: Can not cast TensorSpec(shape=(), dtype=tf.int32, name=None) to TensorSpec(shape=(), dtype=tf.string, name=None)
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py", line 1179, in _call_impl
return self._call_with_structured_signature(args, kwargs)
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py", line 1259, in _call_with_structured_signature
function_type_utils.canonicalize_function_inputs(
TypeError: Binding inputs to tf.function failed due to `Can not cast TensorSpec(shape=(), dtype=tf.int32, name=None) to TensorSpec(shape=(), dtype=tf.string, name=None)`. Received args: (<tf.Tensor: shape=(), dtype=int32, numpy=1>,) and kwargs: {} for signature: (a: TensorSpec(shape=(), dtype=tf.string, name=None)) -> TensorSpec(shape=(), dtype=tf.string, name=None).
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/tmpfs/tmp/ipykernel_167534/3551158538.py", line 8, in assert_raises
yield
File "/tmpfs/tmp/ipykernel_167534/3196284684.py", line 2, in <module>
double_strings(tf.constant(1))
tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_189 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_189]
You may notice that Python arguments are given special treatment in a concrete function's input signature. Prior to TensorFlow 2.3, Python arguments were simply removed from the concrete function's signature. Starting with TensorFlow 2.3, Python arguments remain in the signature, but are constrained to take the value set during tracing.
@tf.function
def pow(a, b):
  return a ** b
square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
print(square)
ConcreteFunction Input Parameters:
a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=<unknown>, dtype=tf.float32, name=None)
b (POSITIONAL_OR_KEYWORD): Literal[2]
Output Type:
TensorSpec(shape=<unknown>, dtype=tf.float32, name=None)
Captures:
None
assert square(tf.constant(10.0)) == 100
with assert_raises(TypeError):
  square(tf.constant(10.0), b=3)
Caught expected exception
<class 'TypeError'>:
Traceback (most recent call last):
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/function_type_utils.py", line 442, in bind_function_inputs
bound_arguments = function_type.bind_with_defaults(
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/core/function/polymorphism/function_type.py", line 277, in bind_with_defaults
with_default_args[arg_name] = constraint.cast(
ValueError: Can not cast 3 to Literal[2]
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py", line 1179, in _call_impl
return self._call_with_structured_signature(args, kwargs)
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py", line 1259, in _call_with_structured_signature
function_type_utils.canonicalize_function_inputs(
TypeError: Binding inputs to tf.function failed due to `Can not cast 3 to Literal[2]`. Received args: (<tf.Tensor: shape=(), dtype=float32, numpy=10.0>,) and kwargs: {'b': 3} for signature: (a: TensorSpec(shape=<unknown>, dtype=tf.float32, name=None), b: Literal[2]) -> TensorSpec(shape=<unknown>, dtype=tf.float32, name=None).
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py", line 1182, in _call_impl
return self._call_with_flat_signature(args, kwargs)
File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py", line 1233, in _call_with_flat_signature
raise TypeError(f"{self._flat_signature_summary()} got unexpected "
TypeError: pow(a) got unexpected keyword arguments: b.
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/tmpfs/tmp/ipykernel_167534/3551158538.py", line 8, in assert_raises
yield
File "/tmpfs/tmp/ipykernel_167534/2310937119.py", line 4, in <module>
square(tf.constant(10.0), b=3)
TypeError: Binding inputs to tf.function failed due to `Can not cast 3 to Literal[2]`. Received args: (<tf.Tensor: shape=(), dtype=float32, numpy=10.0>,) and kwargs: {'b': 3} for signature: (a: TensorSpec(shape=<unknown>, dtype=tf.float32, name=None), b: Literal[2]) -> TensorSpec(shape=<unknown>, dtype=tf.float32, name=None).
Fallback to flat signature also failed due to: pow(a) got unexpected keyword arguments: b.
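Because the Python value is fixed at trace time, a different value needs its own concrete function. The following is a minimal sketch of that pattern; the function is named `power` here (a hypothetical name, chosen to avoid shadowing the built-in `pow`), and `scalar_f32` is an illustrative helper variable.

```python
import tensorflow as tf

@tf.function
def power(a, b):
    # `b` is a plain Python value here, so it is baked into each trace.
    return a ** b

scalar_f32 = tf.TensorSpec(shape=[], dtype=tf.float32)

# One concrete function per Python value of `b`.
square = power.get_concrete_function(a=scalar_f32, b=2)
cube = power.get_concrete_function(a=scalar_f32, b=3)

squared = square(tf.constant(3.0))
cubed = cube(tf.constant(2.0))
print(squared.numpy(), cubed.numpy())
```

If `b` needs to vary at run time, passing it as a tensor instead of a Python value avoids creating one concrete function per value.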
Obtaining graphs
Although retrieving the actual tf.Graph object is not something you'll normally need to do, you can obtain it easily from any concrete function.
graph = double_strings.graph
for node in graph.as_graph_def().node:
  print(f'{node.input} -> {node.name}')
[] -> a
['a', 'a'] -> add
['add'] -> Identity
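As an alternative to walking the GraphDef, you can iterate over the graph's `tf.Operation` objects, which also expose the op type. The sketch below assumes a float-valued stand-in function named `double_sketch` (a hypothetical name); the exact op types you see may differ between TensorFlow versions.

```python
import tensorflow as tf

@tf.function
def double_sketch(a):
    return a + a

concrete = double_sketch.get_concrete_function(
    tf.TensorSpec(shape=[], dtype=tf.float32))

# Each tf.Operation carries a name, an op type, and its input tensors.
op_types = []
for op in concrete.graph.get_operations():
    op_types.append(op.type)
    print(f"{op.name:10s} {op.type:12s} inputs={[t.name for t in op.inputs]}")
```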
In reality, tf.Graph objects are not directly callable. TensorFlow uses a tf.types.experimental.AtomicFunction to perform the computations described by the tf.Graph. You can access the AtomicFunction describing the traced tf.Graph and call it directly instead of the ConcreteFunction:
atomic_fn = double_strings.inference_fn
atomic_fn(tf.constant("a"))
<tf.Tensor: shape=(), dtype=string, numpy=b'aa'>
This has the advantage of having lower Python overhead for high-performance scenarios. But it should only be used for forward inference (no gradient support), and captured tensor values (if any) would need to be explicitly supplied.
Debugging
In general, debugging code is easier in eager mode than inside tf.function. You should ensure that your code executes error-free in eager mode before decorating with tf.function. To assist in the debugging process, you can call tf.config.run_functions_eagerly(True) to globally disable tf.function, and tf.config.run_functions_eagerly(False) to re-enable it.
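The sketch below illustrates the effect of this switch under a simple assumption: a Python list, `python_runs` (an illustrative name), is appended to whenever the Python body of a hypothetical function `add_one` actually executes. With eager execution forced, the body should run on every call; once graph execution is restored, it should run only while tracing.

```python
import tensorflow as tf

python_runs = []  # Appended to whenever the Python body actually executes.

@tf.function
def add_one(x):
    python_runs.append(1)  # Python side effect, visible only when the body runs.
    return x + 1

x = tf.constant(1)

tf.config.run_functions_eagerly(True)
try:
    add_one(x)
    add_one(x)
    runs_eager = len(python_runs)
finally:
    # Always restore graph execution, even if the calls above fail.
    tf.config.run_functions_eagerly(False)

add_one(x)  # First graph-mode call: the body runs once for tracing.
runs_after_first_graph_call = len(python_runs)
add_one(x)  # Same signature: the traced graph is reused.
runs_after_second_graph_call = len(python_runs)

print(runs_eager, runs_after_first_graph_call, runs_after_second_graph_call)
```

Wrapping the toggle in `try`/`finally` is a design choice for this sketch: the setting is global, so leaving it enabled by accident would silently slow down every other `tf.function` in the program.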
When tracking down issues that only appear within tf.function, here are some tips:
- Plain old Python print calls only execute during tracing, helping you track down when your function gets (re)traced.
- tf.print calls will execute every time, and can help you track down intermediate values during execution.
- tf.debugging.enable_check_numerics is an easy way to track down where NaNs and Inf are created.
- pdb (the Python debugger) can help you understand what's going on during tracing. (Caveat: pdb will drop you into AutoGraph-transformed source code.)
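The first two tips can be sketched as follows. The function `scale` and the list `traced_dtypes` are hypothetical names used only for illustration; the list records the same information as the Python `print` call so that retracing can be checked programmatically.

```python
import tensorflow as tf

traced_dtypes = []  # Filled in at trace time, like the Python print below.

@tf.function
def scale(x):
    print("Tracing scale with", x)      # Python print: runs only while tracing.
    traced_dtypes.append(x.dtype)       # Same rule as print.
    tf.print("Running scale with", x)   # tf.print: runs on every call.
    return x * 2

scale(tf.constant(1.0))  # First call: traces for float32.
scale(tf.constant(2.0))  # Same signature: no Python print, only tf.print.
scale(tf.constant(3))    # New dtype (int32): the function is retraced.

print("Traced for dtypes:", traced_dtypes)
```

Note that `tf.print` writes to standard error by default, so its output may appear separately from the Python `print` lines.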
Originally published on the