Content Overview
- Tensor API dispatch
- Dispatch for a single API
- Dispatch for all unary elementwise APIs
- Dispatch for all binary elementwise APIs
- Batchable ExtensionTypes
- BatchableExtensionType example: Network
- TensorFlow APIs that support ExtensionTypes
- @tf.function
- Control flow operations
- Autograph control flow
- Keras
- SavedModel
- Datasets
Tensor API dispatch
Extension types can be "tensor-like", in the sense that they specialize or extend the interface defined by the tf.Tensor type. Examples of tensor-like extension types include RaggedTensor, SparseTensor, and MaskedTensor. Dispatch decorators can be used to override the default behavior of TensorFlow operations when applied to tensor-like extension types. TensorFlow currently defines three dispatch decorators:
@tf.experimental.dispatch_for_api(tf_api)
@tf.experimental.dispatch_for_unary_elementwise_apis(x_type)
@tf.experimental.dispatch_for_binary_elementwise_apis(x_type, y_type)
Dispatch for a single API
The tf.experimental.dispatch_for_api decorator overrides the default behavior of a specified TensorFlow operation when it is called with the specified signature. For example, you can use this decorator to specify how tf.stack should process MaskedTensor values:
from typing import List

@tf.experimental.dispatch_for_api(tf.stack)
def masked_stack(values: List[MaskedTensor], axis=0):
  return MaskedTensor(tf.stack([v.values for v in values], axis),
                      tf.stack([v.mask for v in values], axis))
This overrides the default implementation for tf.stack whenever it is called with a list of MaskedTensor values (since the values argument is annotated with typing.List[MaskedTensor]):
x = MaskedTensor([1, 2, 3], [True, True, False])
y = MaskedTensor([4, 5, 6], [False, True, True])
tf.stack([x, y])
To allow tf.stack to handle lists of mixed MaskedTensor and Tensor values, you can refine the type annotation for the values parameter and update the body of the function appropriately:
from typing import List, Union

tf.experimental.unregister_dispatch_for(masked_stack)

def convert_to_masked_tensor(x):
  if isinstance(x, MaskedTensor):
    return x
  else:
    return MaskedTensor(x, tf.ones_like(x, tf.bool))

@tf.experimental.dispatch_for_api(tf.stack)
def masked_stack_v2(values: List[Union[MaskedTensor, tf.Tensor]], axis=0):
  values = [convert_to_masked_tensor(v) for v in values]
  return MaskedTensor(tf.stack([v.values for v in values], axis),
                      tf.stack([v.mask for v in values], axis))
x = MaskedTensor([1, 2, 3], [True, True, False])
y = tf.constant([4, 5, 6])
tf.stack([x, y, x])
For a list of APIs that can be overridden, see the API documentation for tf.experimental.dispatch_for_api.
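The same decorator works for other dispatchable APIs. As a minimal sketch, the following overrides tf.concat for lists of masked tensors. It assumes a stripped-down MaskedTensor with only values and mask fields, and masked_concat is a hypothetical name; the dispatch target's parameters mirror tf.concat(values, axis, name), with name omitted just as masked_stack omits it above.

```python
from typing import List

import tensorflow as tf


class MaskedTensor(tf.experimental.ExtensionType):
  """Stripped-down stand-in: a values tensor plus a boolean mask."""
  values: tf.Tensor
  mask: tf.Tensor


# Parameter names follow tf.concat(values, axis, name); `name` is omitted.
@tf.experimental.dispatch_for_api(tf.concat)
def masked_concat(values: List[MaskedTensor], axis):
  # Concatenate the values and the masks independently along `axis`.
  return MaskedTensor(tf.concat([v.values for v in values], axis),
                      tf.concat([v.mask for v in values], axis))


x = MaskedTensor([1, 2], [True, False])
y = MaskedTensor([3, 4, 5], [False, True, True])
z = tf.concat([x, y], axis=0)  # routed to masked_concat
print(z.values.numpy(), z.mask.numpy())
```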
Dispatch for all unary elementwise APIs
The tf.experimental.dispatch_for_unary_elementwise_apis decorator overrides the default behavior of all unary elementwise ops (such as tf.math.cos) whenever the value for the first argument (typically named x) matches the type annotation x_type. The decorated function should take two arguments:
- api_func: A function that takes a single parameter and performs the elementwise operation (for example, tf.abs).
- x: The first argument to the elementwise operation.
The following example updates all unary elementwise operations to handle the MaskedTensor type:
@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
def masked_tensor_unary_elementwise_api_handler(api_func, x):
  return MaskedTensor(api_func(x.values), x.mask)
This function will now be used whenever a unary elementwise operation is called on a MaskedTensor.
x = MaskedTensor([1, -2, -3], [True, False, True])
print(tf.abs(x))
print(tf.ones_like(x, dtype=tf.float32))
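As a self-contained sketch of what the decorator receives, the following registers a handler for a stripped-down MaskedTensor (only values and mask fields, an assumption for brevity) and applies two unary elementwise ops. In each call, api_func is the op being invoked, so the handler transforms the values and passes the mask through unchanged.

```python
import tensorflow as tf


class MaskedTensor(tf.experimental.ExtensionType):
  """Stripped-down stand-in: a values tensor plus a boolean mask."""
  values: tf.Tensor
  mask: tf.Tensor


@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
def masked_unary_handler(api_func, x):
  # api_func is the elementwise op that was called (e.g. tf.math.square).
  # Apply it to the values only; the mask is carried through as-is.
  return MaskedTensor(api_func(x.values), x.mask)


x = MaskedTensor([1.0, -2.0, 3.0], [True, False, True])
squared = tf.math.square(x)   # dispatched to masked_unary_handler
negated = tf.math.negative(x)  # same handler, different api_func
print(squared.values.numpy(), squared.mask.numpy())
print(negated.values.numpy(), negated.mask.numpy())
```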
Dispatch for all binary elementwise APIs
Similarly, tf.experimental.dispatch_for_binary_elementwise_apis can be used to update all binary elementwise operations to handle the MaskedTensor type:
@tf.experimental.dispatch_for_binary_elementwise_apis(MaskedTensor, MaskedTensor)
def masked_tensor_binary_elementwise_api_handler(api_func, x, y):
  return MaskedTensor(api_func(x.values, y.values), x.mask & y.mask)
x = MaskedTensor([1, -2, -3], [True, False, True])
y = MaskedTensor([[4], [5]], [[True], [False]])
tf.math.add(x, y)
For a list of the elementwise APIs that are overridden, go to the API documentation for tf.experimental.dispatch_for_unary_elementwise_apis and tf.experimental.dispatch_for_binary_elementwise_apis.
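A self-contained sketch of the binary case, again assuming a stripped-down MaskedTensor with only values and mask fields: the handler applies api_func to both value tensors and keeps an element unmasked only where both inputs are unmasked.

```python
import tensorflow as tf


class MaskedTensor(tf.experimental.ExtensionType):
  """Stripped-down stand-in: a values tensor plus a boolean mask."""
  values: tf.Tensor
  mask: tf.Tensor


@tf.experimental.dispatch_for_binary_elementwise_apis(MaskedTensor, MaskedTensor)
def masked_binary_handler(api_func, x, y):
  # Combine the values with the requested op; AND the masks together.
  return MaskedTensor(api_func(x.values, y.values), x.mask & y.mask)


x = MaskedTensor([1.0, 2.0, 3.0], [True, True, False])
y = MaskedTensor([10.0, 20.0, 30.0], [True, False, True])
product = tf.math.multiply(x, y)
difference = tf.math.subtract(x, y)
print(product.values.numpy(), product.mask.numpy())
print(difference.values.numpy(), difference.mask.numpy())
```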
Batchable ExtensionTypes
An ExtensionType is batchable if a single instance can be used to represent a batch of values. Typically, this is accomplished by adding batch dimensions to all nested Tensors. The following TensorFlow APIs require that any extension type inputs be batchable:
- tf.data.Dataset (batch, unbatch, from_tensor_slices)
- tf.keras (fit, evaluate, predict)
- tf.map_fn
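As a rough sketch of the tf.data case, the following uses a hypothetical two-field batchable type named Pair (not part of TensorFlow). Slicing splits each nested tensor along its first dimension, and batch stacks consecutive elements back together.

```python
import tensorflow as tf


class Pair(tf.experimental.BatchableExtensionType):
  """Hypothetical batchable type: two tensors sharing a leading batch dim."""
  first: tf.Tensor
  second: tf.Tensor


batch = Pair(first=[1, 2, 3, 4], second=[10, 20, 30, 40])

# from_tensor_slices yields one Pair per batch element, with scalar fields.
elements = list(tf.data.Dataset.from_tensor_slices(batch))

# batch(2) stacks the nested tensors of consecutive elements.
rebatched = list(tf.data.Dataset.from_tensor_slices(batch).batch(2))
for pair in rebatched:
  print(pair.first.numpy(), pair.second.numpy())
```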
By default, BatchableExtensionType creates batched values by batching any nested Tensors, CompositeTensors, and ExtensionTypes. If this is not appropriate for your class, then you will need to use tf.experimental.ExtensionTypeBatchEncoder to override this default behavior. For example, it would not be appropriate to create a batch of tf.SparseTensor values by simply stacking individual sparse tensors' values, indices, and dense_shape fields: in most cases, you can't stack these tensors, since they have incompatible shapes; and even if you could, the result would not be a valid SparseTensor.
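To see the SparseTensor problem concretely, the following sketch builds two sparse vectors with the same dense shape but different numbers of nonzero entries. Their indices tensors have shapes [2, 1] and [3, 1], so stacking them field by field fails; and even with matching sizes, stacked indices of shape [2, nnz, 1] would not have the [total_nnz, 2] layout a rank-2 SparseTensor requires.

```python
import tensorflow as tf

# Same dense shape [5], but 2 vs. 3 nonzero entries.
sp1 = tf.sparse.SparseTensor(indices=[[0], [3]], values=[1.0, 2.0],
                             dense_shape=[5])
sp2 = tf.sparse.SparseTensor(indices=[[1], [2], [4]], values=[3.0, 4.0, 5.0],
                             dense_shape=[5])
print(sp1.indices.shape, sp2.indices.shape)

# Naive field-by-field batching: stack the indices tensors directly.
try:
  tf.stack([sp1.indices, sp2.indices])
  stack_failed = False
except (tf.errors.OpError, ValueError, TypeError):
  stack_failed = True
print("naive stacking failed:", stack_failed)
```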
Note: BatchableExtensionTypes do not automatically define dispatchers for tf.stack, tf.concat, tf.slice, etc. If your class needs to be supported by these APIs, then use the dispatch decorators described above.
BatchableExtensionType example: Network
As an example, consider a simple Network class used for load balancing, which tracks how much work is left to do at each node, and how much bandwidth is available to move work between nodes:
class Network(tf.experimental.ExtensionType):  # This version is not batchable.
  work: tf.Tensor       # work[n] = work left to do at node n
  bandwidth: tf.Tensor  # bandwidth[n1, n2] = bandwidth from n1->n2

net1 = Network([5., 3, 8], [[0., 2, 0], [2, 0, 3], [0, 3, 0]])
net2 = Network([3., 4, 2], [[0., 2, 2], [2, 0, 2], [2, 2, 0]])
To make this type batchable, change the base type to BatchableExtensionType, and adjust the shape of each field to include optional batch dimensions. The following example also adds a shape field to keep track of the batch shape. This shape field is not required by tf.data.Dataset or tf.map_fn, but it is required by tf.keras.
class Network(tf.experimental.BatchableExtensionType):
  shape: tf.TensorShape  # batch shape. A single network has shape=[].
  work: tf.Tensor        # work[*shape, n] = work left to do at node n
  bandwidth: tf.Tensor   # bandwidth[*shape, n1, n2] = bandwidth from n1->n2

  def __init__(self, work, bandwidth):
    self.work = tf.convert_to_tensor(work)
    self.bandwidth = tf.convert_to_tensor(bandwidth)
    work_batch_shape = self.work.shape[:-1]
    bandwidth_batch_shape = self.bandwidth.shape[:-2]
    self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)

  def __repr__(self):
    return network_repr(self)

def network_repr(network):
  work = network.work
  bandwidth = network.bandwidth
  if hasattr(work, 'numpy'):
    work = ' '.join(str(work.numpy()).split())
  if hasattr(bandwidth, 'numpy'):
    bandwidth = ' '.join(str(bandwidth.numpy()).split())
  return (f"<Network shape={network.shape} work={work} bandwidth={bandwidth}>")

net1 = Network([5., 3, 8], [[0., 2, 0], [2, 0, 3], [0, 3, 0]])
net2 = Network([3., 4, 2], [[0., 2, 2], [2, 0, 2], [2, 2, 0]])
batch_of_networks = Network(
    work=tf.stack([net1.work, net2.work]),
    bandwidth=tf.stack([net1.bandwidth, net2.bandwidth]))
print(f"net1={net1}")
print(f"net2={net2}")
print(f"batch={batch_of_networks}")
You can then use tf.data.Dataset to iterate through a batch of networks:
dataset = tf.data.Dataset.from_tensor_slices(batch_of_networks)
for i, network in enumerate(dataset):
  print(f"Batch element {i}: {network}")
You can also use tf.map_fn to apply a function to each batch element:
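def balance_work_greedy(network):
  # delta[..., i, j] > 0 means node i has more work than node j. A quarter of
  # the difference is proposed for transfer from i to j, clipped to the
  # bandwidth available on that link in either direction.
  delta = (tf.expand_dims(network.work, -1) - tf.expand_dims(network.work, -2))
  delta /= 4
  delta = tf.maximum(tf.minimum(delta, network.bandwidth), -network.bandwidth)
  # Subtract each node's net outflow, so work moves from busier nodes to less
  # busy ones.
  new_work = network.work - tf.reduce_sum(delta, -1)
  return Network(new_work, network.bandwidth)

tf.map_fn(balance_work_greedy, batch_of_networks)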
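To make the arithmetic concrete, here is a plain-Python re-computation of one greedy step for net1 (no TensorFlow involved). It assumes the update subtracts each node's net outflow, so that work moves from busier nodes toward less busy ones, as the BalanceNetworkLayer docstring below describes. Note that the total amount of work is conserved.

```python
# net1 from the example above: 3 nodes.
work = [5.0, 3.0, 8.0]
bandwidth = [[0.0, 2.0, 0.0],
             [2.0, 0.0, 3.0],
             [0.0, 3.0, 0.0]]

n = len(work)
# delta[i][j]: proposed transfer from node i to node j (a quarter of the
# work gap), clipped to [-bandwidth[i][j], bandwidth[i][j]].
delta = [[max(min((work[i] - work[j]) / 4, bandwidth[i][j]), -bandwidth[i][j])
          for j in range(n)]
         for i in range(n)]

# Net outflow per node is the row sum; subtracting it moves work downhill.
outflow = [sum(row) for row in delta]
new_work = [w - out for w, out in zip(work, outflow)]
print(new_work)  # [4.5, 4.75, 6.75]
```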
TensorFlow APIs that support ExtensionTypes
@tf.function
tf.function is a decorator that precomputes TensorFlow graphs for Python functions, which can substantially improve the performance of your TensorFlow code. Extension type values can be used transparently with @tf.function-decorated functions.
class Pastry(tf.experimental.ExtensionType):
  sweetness: tf.Tensor  # 2d embedding that encodes sweetness
  chewiness: tf.Tensor  # 2d embedding that encodes chewiness

@tf.function
def combine_pastry_features(x: Pastry):
  return (x.sweetness + x.chewiness) / 2

cookie = Pastry(sweetness=[1.2, 0.4], chewiness=[0.8, 0.2])
combine_pastry_features(cookie)
If you wish to explicitly specify the input_signature for tf.function, then you can do so using the extension type's TypeSpec.
pastry_spec = Pastry.Spec(tf.TensorSpec([2]), tf.TensorSpec([2]))

@tf.function(input_signature=[pastry_spec])
def increase_sweetness(x: Pastry, delta=1.0):
  return Pastry(x.sweetness + delta, x.chewiness)

increase_sweetness(cookie)
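Because an extension type value's input signature is its TypeSpec, tf.function can generally reuse one traced graph for values whose fields have the same shapes and dtypes, and traces again when they differ. The following sketch checks this with experimental_get_tracing_count(); the muffin and croissant values are illustrative additions.

```python
import tensorflow as tf


class Pastry(tf.experimental.ExtensionType):
  sweetness: tf.Tensor
  chewiness: tf.Tensor


@tf.function
def combine_pastry_features(x: Pastry):
  return (x.sweetness + x.chewiness) / 2


cookie = Pastry(sweetness=[1.2, 0.4], chewiness=[0.8, 0.2])
muffin = Pastry(sweetness=[0.5, 0.9], chewiness=[0.3, 0.7])

combine_pastry_features(cookie)
combine_pastry_features(muffin)  # same TypeSpec as cookie: graph is reused
traces_same_spec = combine_pastry_features.experimental_get_tracing_count()

# Fields of shape [3] give a different TypeSpec, so a new trace is needed.
croissant = Pastry(sweetness=[0.1, 0.2, 0.3], chewiness=[0.4, 0.5, 0.6])
combine_pastry_features(croissant)
traces_after_new_spec = combine_pastry_features.experimental_get_tracing_count()
print(traces_same_spec, traces_after_new_spec)
```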
Concrete functions
Concrete functions encapsulate individual traced graphs that are built by tf.function. Extension types can be used transparently with concrete functions.
cf = combine_pastry_features.get_concrete_function(pastry_spec)
cf(cookie)
Control flow operations
Extension types are supported by TensorFlow's control-flow operations:
tf.cond
tf.case
tf.while_loop
tf.identity
# Example: using tf.cond to select between two MaskedTensors. Note that the
# two MaskedTensors don't need to have the same shape.
a = MaskedTensor([1., 2, 3], [True, False, True])
b = MaskedTensor([22., 33, 108, 55], [True, True, True, False])
condition = tf.constant(True)
print(tf.cond(condition, lambda: a, lambda: b))
# Example: using tf.while_loop with MaskedTensor.
cond = lambda i, _: i < 10
def body(i, mt):
  return i + 1, mt.with_values(mt.values + 3 / 7)
print(tf.while_loop(cond, body, [0, b])[1])
Autograph control flow
Extension types are also supported by control flow statements in tf.function (using autograph). In the following example, the if and for statements are automatically converted to tf.cond and tf.while_loop operations, which support extension types.
@tf.function
def fn(x, b):
  if b:
    x = MaskedTensor(x, tf.less(x, 0))
  else:
    x = MaskedTensor(x, tf.greater(x, 0))
  for i in tf.range(5 if b else 7):
    x = x.with_values(x.values + 1 / 2)
  return x

print(fn(tf.constant([1., -2, 3]), tf.constant(True)))
print(fn(tf.constant([1., -2, 3]), tf.constant(False)))
Keras
tf.keras is TensorFlow's high-level API for building and training deep learning models. Extension types may be passed as inputs to a Keras model, passed between Keras layers, and returned by Keras models. Keras currently puts two requirements on extension types:
- They must be batchable (go to "Batchable ExtensionTypes" above).
- They must have a field or property named shape. shape[0] is assumed to be the batch dimension.
The following two subsections give examples showing how extension types can be used with Keras.
Keras example: Network
For the first example, consider the Network class defined in the "Batchable ExtensionTypes" section above, which can be used for load balancing work between nodes. Its definition is repeated here:
class Network(tf.experimental.BatchableExtensionType):
  shape: tf.TensorShape  # batch shape. A single network has shape=[].
  work: tf.Tensor        # work[*shape, n] = work left to do at node n
  bandwidth: tf.Tensor   # bandwidth[*shape, n1, n2] = bandwidth from n1->n2

  def __init__(self, work, bandwidth):
    self.work = tf.convert_to_tensor(work)
    self.bandwidth = tf.convert_to_tensor(bandwidth)
    work_batch_shape = self.work.shape[:-1]
    bandwidth_batch_shape = self.bandwidth.shape[:-2]
    self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)

  def __repr__(self):
    return network_repr(self)
single_network = Network( # A single network with 4 nodes.
work=[8.0, 5, 12, 2],
bandwidth=[[0.0, 1, 2, 2], [1, 0, 0, 2], [2, 0, 0, 1], [2, 2, 1, 0]])
batch_of_networks = Network( # Batch of 2 networks, each w/ 2 nodes.
work=[[8.0, 5], [3, 2]],
bandwidth=[[[0.0, 1], [1, 0]], [[0, 2], [2, 0]]])
You can define a new Keras layer that processes Networks.
class BalanceNetworkLayer(tf.keras.layers.Layer):
  """Layer that balances work between nodes in a network.

  Shifts work from more busy nodes to less busy nodes, constrained by bandwidth.
  """
  def call(self, inputs):
    # balance_work_greedy is defined above in the "Batchable ExtensionTypes" section.
    return balance_work_greedy(inputs)
You can then use this layer to create a simple model. To feed an ExtensionType into a model, you can use a tf.keras.layers.Input layer with type_spec set to the extension type's TypeSpec. If the Keras model will be used to process batches, then the type_spec must include the batch dimension.
input_spec = Network.Spec(shape=None,
                          work=tf.TensorSpec(None, tf.float32),
                          bandwidth=tf.TensorSpec(None, tf.float32))
model = tf.keras.Sequential([
    tf.keras.layers.Input(type_spec=input_spec),
    BalanceNetworkLayer(),
    ])
Finally, you can apply the model to a single network and to a batch of networks.
model(single_network)
model(batch_of_networks)
Keras example: MaskedTensor
In this example, MaskedTensor is extended to support Keras. shape is defined as a property that is calculated from the values field. Keras requires that you add this property to both the extension type and its TypeSpec. MaskedTensor also defines a __name__ variable, which will be required for SavedModel serialization (below).
class MaskedTensor(tf.experimental.BatchableExtensionType):
  # __name__ is required for serialization in SavedModel; see below for details.
  __name__ = 'extension_type_colab.MaskedTensor'

  values: tf.Tensor
  mask: tf.Tensor

  shape = property(lambda self: self.values.shape)
  dtype = property(lambda self: self.values.dtype)

  def with_default(self, default):
    return tf.where(self.mask, self.values, default)

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  class Spec:
    def __init__(self, shape, dtype=tf.float32):
      self.values = tf.TensorSpec(shape, dtype)
      self.mask = tf.TensorSpec(shape, tf.bool)

    shape = property(lambda self: self.values.shape)
    dtype = property(lambda self: self.values.dtype)

    def with_shape(self, shape):
      # Returns a spec with the given shape and this spec's values dtype.
      return MaskedTensor.Spec(shape, self.values.dtype)
Next, the dispatch decorators are used to override the default behavior of several TensorFlow APIs. Since these APIs are used by standard Keras layers (such as the Dense layer), overriding them allows you to use those layers with MaskedTensor. For the purposes of this example, matmul for masked tensors is defined to treat the masked values as zeros (that is, to not include them in the product).
from typing import Union

@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
def unary_elementwise_op_handler(op, x):
  return MaskedTensor(op(x.values), x.mask)

@tf.experimental.dispatch_for_binary_elementwise_apis(
    Union[MaskedTensor, tf.Tensor],
    Union[MaskedTensor, tf.Tensor])
def binary_elementwise_op_handler(op, x, y):
  # convert_to_masked_tensor is defined in the "Dispatch for a single API" section.
  x = convert_to_masked_tensor(x)
  y = convert_to_masked_tensor(y)
  return MaskedTensor(op(x.values, y.values), x.mask & y.mask)

# The parameter names must match those of tf.matmul in your TensorFlow version.
@tf.experimental.dispatch_for_api(tf.matmul)
def masked_matmul(a: MaskedTensor, b,
                  transpose_a=False, transpose_b=False,
                  adjoint_a=False, adjoint_b=False,
                  a_is_sparse=False, b_is_sparse=False,
                  output_type=None):
  # Replace masked entries with zeros so they don't contribute to the product.
  if isinstance(a, MaskedTensor):
    a = a.with_default(0)
  if isinstance(b, MaskedTensor):
    b = b.with_default(0)
  return tf.matmul(a, b, transpose_a, transpose_b, adjoint_a,
                   adjoint_b, a_is_sparse, b_is_sparse, output_type)
You can then construct a Keras model that accepts MaskedTensor inputs, using standard Keras layers:
input_spec = MaskedTensor.Spec([None, 2], tf.float32)
masked_tensor_model = tf.keras.Sequential([
tf.keras.layers.Input(type_spec=input_spec),
tf.keras.layers.Dense(16, activation="relu"),
tf.keras.layers.Dense(1)])
masked_tensor_model.compile(loss='binary_crossentropy', optimizer='rmsprop')
a = MaskedTensor([[1., 2], [3, 4], [5, 6]],
[[True, False], [False, True], [True, True]])
masked_tensor_model.fit(a, tf.constant([[1], [0], [1]]), epochs=3)
print(masked_tensor_model(a))
SavedModel
A SavedModel is a serialized TensorFlow program, including both weights and computation. It can be built from a Keras model or from a custom model. In either case, extension types can be used transparently with the functions and methods defined by a SavedModel.
SavedModel can save models, layers, and functions that process extension types, as long as the extension types have a __name__ field. This name is used to register the extension type, so it can be located when the model is loaded.
Example: saving a Keras model
Keras models that use extension types may be saved using SavedModel.
import tempfile

masked_tensor_model_path = tempfile.mkdtemp()
tf.saved_model.save(masked_tensor_model, masked_tensor_model_path)
imported_model = tf.saved_model.load(masked_tensor_model_path)
imported_model(a)
Example: saving a custom model
SavedModel can also be used to save custom tf.Module subclasses with functions that process extension types.
class CustomModule(tf.Module):
  def __init__(self, variable_value):
    super().__init__()
    self.v = tf.Variable(variable_value)

  @tf.function
  def grow(self, x: MaskedTensor):
    """Increase values in `x` by multiplying them by `self.v`."""
    return MaskedTensor(x.values * self.v, x.mask)

module = CustomModule(100.0)

module.grow.get_concrete_function(MaskedTensor.Spec(shape=None,
                                                    dtype=tf.float32))
custom_module_path = tempfile.mkdtemp()
tf.saved_model.save(module, custom_module_path)
imported_model = tf.saved_model.load(custom_module_path)
imported_model.grow(MaskedTensor([1., 2, 3], [False, True, False]))
Loading a SavedModel when the ExtensionType is unavailable
If you load a SavedModel that uses an ExtensionType, but that ExtensionType is not available (that is, it has not been imported), then you will get a warning and TensorFlow will fall back to using an "anonymous extension type" object. This object will have the same fields as the original type, but will lack any further customization you have added for the type, such as custom methods or properties.
Using ExtensionTypes with TensorFlow Serving
Currently, TensorFlow Serving (and other consumers of the SavedModel "signatures" dictionary) requires that all inputs and outputs be raw tensors. If you wish to use TensorFlow Serving with a model that uses extension types, then you can add wrapper methods that compose or decompose extension type values from tensors. For example:
class CustomModuleWrapper(tf.Module):
  def __init__(self, variable_value):
    super().__init__()
    self.v = tf.Variable(variable_value)

  @tf.function
  def var_weighted_mean(self, x: MaskedTensor):
    """Mean value of unmasked values in x, weighted by self.v."""
    x = MaskedTensor(x.values * self.v, x.mask)
    return (tf.reduce_sum(x.with_default(0)) /
            tf.reduce_sum(tf.cast(x.mask, x.dtype)))

  @tf.function()
  def var_weighted_mean_wrapper(self, x_values, x_mask):
    """Raw tensor wrapper for var_weighted_mean."""
    return self.var_weighted_mean(MaskedTensor(x_values, x_mask))

module = CustomModuleWrapper([3., 2., 8., 5.])

module.var_weighted_mean_wrapper.get_concrete_function(
    tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.bool))
custom_module_path = tempfile.mkdtemp()
tf.saved_model.save(module, custom_module_path)
imported_model = tf.saved_model.load(custom_module_path)
x = MaskedTensor([1., 2., 3., 4.], [False, True, False, True])
imported_model.var_weighted_mean_wrapper(x.values, x.mask)
Datasets
tf.data is an API that enables you to build complex input pipelines from simple, reusable pieces. Its core data structure is tf.data.Dataset, which represents a sequence of elements, in which each element consists of one or more components.
Building Datasets with extension types
Datasets can be built from extension type values using Dataset.from_tensors, Dataset.from_tensor_slices, or Dataset.from_generator:
ds = tf.data.Dataset.from_tensors(Pastry(5, 5))
next(iter(ds))
mt = MaskedTensor(tf.reshape(range(20), [5, 4]), tf.ones([5, 4], tf.bool))
ds = tf.data.Dataset.from_tensor_slices(mt)
for value in ds:
  print(value)
def value_gen():
  for i in range(2, 7):
    yield MaskedTensor(range(10), [j%i != 0 for j in range(10)])

ds = tf.data.Dataset.from_generator(
    value_gen, output_signature=MaskedTensor.Spec(shape=[10], dtype=tf.int32))
for value in ds:
  print(value)
Batching and unbatching Datasets with extension types
Datasets with extension types can be batched and unbatched using Dataset.batch and Dataset.unbatch.
batched_ds = ds.batch(2)
for value in batched_ds:
  print(value)

unbatched_ds = batched_ds.unbatch()
for value in unbatched_ds:
  print(value)
Originally published on the