Tensor API dispatch

Extension types can be "tensor-like", in the sense that they specialize or extend the interface defined by the tf.Tensor type. Examples of tensor-like extension types include RaggedTensor, SparseTensor, and MaskedTensor. Dispatch decorators can be used to override the default behavior of TensorFlow operations when applied to tensor-like extension types. TensorFlow currently defines three dispatch decorators: tf.experimental.dispatch_for_api, tf.experimental.dispatch_for_unary_elementwise_apis, and tf.experimental.dispatch_for_binary_elementwise_apis.

Dispatch for a single API

The tf.experimental.dispatch_for_api decorator overrides the default behavior of a specified TensorFlow operation when it is called with the specified signature. For example, you can use this decorator to specify how tf.stack should process MaskedTensor values:

@tf.experimental.dispatch_for_api(tf.stack)
def masked_stack(values: List[MaskedTensor], axis = 0):
  return MaskedTensor(tf.stack([v.values for v in values], axis),
                      tf.stack([v.mask for v in values], axis))

This overrides the default implementation for tf.stack whenever it is called with a list of MaskedTensor values (since the values argument is annotated with typing.List[MaskedTensor]):

x = MaskedTensor([1, 2, 3], [True, True, False])
y = MaskedTensor([4, 5, 6], [False, True, True])
tf.stack([x, y])

To allow tf.stack to handle lists of mixed MaskedTensor and Tensor values, you can refine the type annotation for the values parameter and update the body of the function appropriately:

tf.experimental.unregister_dispatch_for(masked_stack)

def convert_to_masked_tensor(x):
  if isinstance(x, MaskedTensor):
    return x
  else:
    return MaskedTensor(x, tf.ones_like(x, tf.bool))

@tf.experimental.dispatch_for_api(tf.stack)
def masked_stack_v2(values: List[Union[MaskedTensor, tf.Tensor]], axis = 0):
  values = [convert_to_masked_tensor(v) for v in values]
  return MaskedTensor(tf.stack([v.values for v in values], axis),
                      tf.stack([v.mask for v in values], axis))

x = MaskedTensor([1, 2, 3], [True, True, False])
y = tf.constant([4, 5, 6])
tf.stack([x, y, x])

For a list of APIs that can be overridden, see the API documentation for tf.experimental.dispatch_for_api.
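
The role of the type annotation can be illustrated without TensorFlow. The following is a minimal pure-Python sketch of annotation-driven dispatch, not TensorFlow's implementation: Masked, matches, dispatch_for_stack, and stack are hypothetical names introduced here, and the matching rules are simplified to List[...], Union[...], and plain classes.

```python
from typing import List, Union, get_args, get_origin, get_type_hints


class Masked:
    """Hypothetical stand-in for MaskedTensor, backed by plain Python lists."""

    def __init__(self, values, mask):
        self.values, self.mask = list(values), list(mask)


def matches(value, annotation):
    """Returns True if `value` fits a (simplified) type annotation."""
    origin = get_origin(annotation)
    if origin is list:  # List[T]: every element must match T.
        (elem_type,) = get_args(annotation)
        return isinstance(value, list) and all(
            matches(v, elem_type) for v in value)
    if origin is Union:  # Union[A, B]: any alternative may match.
        return any(matches(value, alt) for alt in get_args(annotation))
    return isinstance(value, annotation)


_overrides = []  # Registered (annotation, handler) pairs for `stack`.


def dispatch_for_stack(handler):
    """Registers `handler`, keyed by the annotation on its first parameter."""
    hints = get_type_hints(handler)
    first_param = next(iter(hints))
    _overrides.append((hints[first_param], handler))
    return handler


def stack(values):
    """Toy API: uses the first matching override, else a default."""
    for annotation, handler in _overrides:
        if matches(values, annotation):
            return handler(values)
    return [list(v) for v in values]  # Default: a nested list copy.


@dispatch_for_stack
def masked_stack(values: List[Masked]):
    return Masked([v.values for v in values], [v.mask for v in values])


result = stack([Masked([1, 2], [True, False]), Masked([3, 4], [True, True])])
default = stack([[1, 2], [3, 4]])
print(result.values, result.mask)
print(default)
```

In this sketch, a list that mixes Masked and plain values does not match List[Masked] and falls through to the default, which mirrors why the annotation has to be widened to a Union before mixed lists are handled.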

Dispatch for all unary elementwise APIs

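The tf.experimental.dispatch_for_unary_elementwise_apis decorator overrides the default behavior of all unary elementwise ops (such as tf.math.cos) whenever the value for the first argument (typically named x) matches the type annotation x_type. The decorated function should take two arguments: api_func, a function that takes a single parameter and performs the elementwise operation (for example, tf.abs), and x, the first argument to the elementwise operation.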

The following example updates all unary elementwise operations to handle the MaskedTensor type:

@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
def masked_tensor_unary_elementwise_api_handler(api_func, x):
  return MaskedTensor(api_func(x.values), x.mask)

This function will now be used whenever a unary elementwise operation is called on a MaskedTensor.

x = MaskedTensor([1, -2, -3], [True, False, True])
print(tf.abs(x))

print(tf.ones_like(x, dtype=tf.float32))
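
To make the (api_func, x) calling convention concrete, here is a small pure-Python sketch of one handler serving several unary operations. It is an illustration under stated assumptions, not TensorFlow's mechanism: Masked, elementwise, with_masked_support, and the unary_apis registry are hypothetical names, and values are plain lists rather than tensors.

```python
import math


class Masked:
    """Hypothetical stand-in for MaskedTensor, backed by plain Python lists."""

    def __init__(self, values, mask):
        self.values, self.mask = list(values), list(mask)


def elementwise(fn):
    """Lifts a scalar function to a function over lists."""
    return lambda xs: [fn(x) for x in xs]


def masked_unary_handler(api_func, x):
    """Applies `api_func` to the values and passes the mask through unchanged."""
    return Masked(api_func(x.values), x.mask)


def with_masked_support(api_func):
    """Wraps one unary API so that Masked inputs go to the shared handler."""
    def wrapper(x):
        if isinstance(x, Masked):
            return masked_unary_handler(api_func, x)
        return api_func(x)
    return wrapper


# A single handler serves every unary API in this toy registry.
unary_apis = {name: with_masked_support(elementwise(fn))
              for name, fn in {"abs": abs, "floor": math.floor}.items()}

x = Masked([1, -2, -3], [True, False, True])
abs_x = unary_apis["abs"](x)
print(abs_x.values, abs_x.mask)  # [1, 2, 3] [True, False, True]
```

The handler never needs to know which operation it was called for, because the operation arrives as api_func.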

Dispatch for all binary elementwise APIs

Similarly, tf.experimental.dispatch_for_binary_elementwise_apis can be used to update all binary elementwise operations to handle the MaskedTensor type:

@tf.experimental.dispatch_for_binary_elementwise_apis(MaskedTensor, MaskedTensor)
def masked_tensor_binary_elementwise_api_handler(api_func, x, y):
  return MaskedTensor(api_func(x.values, y.values), x.mask & y.mask)

x = MaskedTensor([1, -2, -3], [True, False, True])
y = MaskedTensor([[4], [5]], [[True], [False]])
tf.math.add(x, y)

For a list of the elementwise APIs that are overridden, go to the API documentation for tf.experimental.dispatch_for_unary_elementwise_apis and tf.experimental.dispatch_for_binary_elementwise_apis.
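
The tf.math.add(x, y) call above combines a value of shape [3] with one of shape [2, 1], so both the values and the masks are broadcast to shape [2, 3]. The following pure-Python sketch works through that arithmetic by hand with nested lists; it assumes standard broadcasting semantics and does not call TensorFlow.

```python
# Plain-list copies of the x and y fields from the example above.
x_values, x_mask = [1, -2, -3], [True, False, True]   # shape [3]
y_values, y_mask = [[4], [5]], [[True], [False]]      # shape [2, 1]

# Broadcasting [3] against [2, 1] gives shape [2, 3]: row i uses the single
# entry of y's row i, and column j uses the j-th entry of x.
sum_values = [[x_v + y_row[0] for x_v in x_values] for y_row in y_values]
sum_mask = [[x_m and y_row[0] for x_m in x_mask] for y_row in y_mask]

print(sum_values)  # [[5, 2, 1], [6, 3, 2]]
print(sum_mask)    # [[True, False, True], [False, False, False]]
```

Because the handler combines masks with a logical AND, the second row is fully masked: its y entry is masked, so no result in that row is considered valid.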

Batchable ExtensionTypes

An ExtensionType is batchable if a single instance can be used to represent a batch of values. Typically, this is accomplished by adding batch dimensions to all nested Tensors. The following TensorFlow APIs require that any extension type inputs be batchable: tf.data.Dataset (batch, unbatch, from_tensor_slices), tf.keras (fit, evaluate, predict), and tf.map_fn.

By default, BatchableExtensionType creates batched values by batching any nested Tensors, CompositeTensors, and ExtensionTypes. If this is not appropriate for your class, then you will need to use tf.experimental.ExtensionTypeBatchEncoder to override this default behavior. For example, it would not be appropriate to create a batch of tf.SparseTensor values by simply stacking the values, indices, and dense_shape fields of the individual sparse tensors. In most cases, you can't stack these tensors, since they have incompatible shapes; and even if you could, the result would not be a valid SparseTensor.
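
The sparse tensor case can be sketched in plain Python. The example below is only an illustration of why field-wise stacking fails and what a valid batched encoding might look like; it does not use tf.experimental.ExtensionTypeBatchEncoder, and the dict layout and the batch_sparse helper are hypothetical.

```python
def batch_sparse(sparse_tensors):
    """Combines sparse tensors of equal dense_shape into one batched tensor.

    Each input is a dict with `indices`, `values`, and `dense_shape` entries
    (a hypothetical plain-Python layout, used only for illustration).
    """
    dense_shape = sparse_tensors[0]["dense_shape"]
    indices, values = [], []
    for batch_index, st in enumerate(sparse_tensors):
        if st["dense_shape"] != dense_shape:
            raise ValueError("all inputs must share one dense_shape")
        # Prefix every index with the position of its tensor in the batch.
        indices.extend([batch_index] + list(idx) for idx in st["indices"])
        values.extend(st["values"])
    return {"indices": indices, "values": values,
            "dense_shape": [len(sparse_tensors)] + list(dense_shape)}


a = {"indices": [[0, 1]], "values": [10], "dense_shape": [2, 3]}
b = {"indices": [[0, 0], [1, 2]], "values": [20, 30], "dense_shape": [2, 3]}

# a stores one value and b stores two, so their `indices` have shapes [1, 2]
# and [2, 2] and cannot be stacked along a new axis.
batched = batch_sparse([a, b])
print(batched["indices"])      # [[0, 0, 1], [1, 0, 0], [1, 1, 2]]
print(batched["dense_shape"])  # [2, 2, 3]
```

The batched result concatenates the stored entries and records the batch position inside each index, rather than adding a leading dimension to every field.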

Note: BatchableExtensionTypes do not automatically define dispatchers for tf.stack, tf.concat, tf.slice, etc. If your class needs to be supported by these APIs, then use the dispatch decorators described above.

BatchableExtensionType example: Network

As an example, consider a simple Network class used for load balancing, which tracks how much work is left to do at each node, and how much bandwidth is available to move work between nodes:

class Network(tf.experimental.ExtensionType):  # This version is not batchable.
  work: tf.Tensor       # work[n] = work left to do at node n
  bandwidth: tf.Tensor  # bandwidth[n1, n2] = bandwidth from n1->n2

net1 = Network([5., 3, 8], [[0., 2, 0], [2, 0, 3], [0, 3, 0]])
net2 = Network([3., 4, 2], [[0., 2, 2], [2, 0, 2], [2, 2, 0]])

To make this type batchable, change the base type to BatchableExtensionType, and adjust the shape of each field to include optional batch dimensions. The following example also adds a shape field to keep track of the batch shape. This shape field is not required by tf.data.Dataset or tf.map_fn, but it is required by tf.keras.

class Network(tf.experimental.BatchableExtensionType):
  shape: tf.TensorShape  # batch shape. A single network has shape=[].
  work: tf.Tensor        # work[*shape, n] = work left to do at node n
  bandwidth: tf.Tensor   # bandwidth[*shape, n1, n2] = bandwidth from n1->n2

  def __init__(self, work, bandwidth):
    self.work = tf.convert_to_tensor(work)
    self.bandwidth = tf.convert_to_tensor(bandwidth)
    work_batch_shape = self.work.shape[:-1]
    bandwidth_batch_shape = self.bandwidth.shape[:-2]
    self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)

  def __repr__(self):
    return network_repr(self)

def network_repr(network):
  work = network.work
  bandwidth = network.bandwidth
  if hasattr(work, 'numpy'):
    work = ' '.join(str(work.numpy()).split())
  if hasattr(bandwidth, 'numpy'):
    bandwidth = ' '.join(str(bandwidth.numpy()).split())
  return (f"<Network shape={network.shape} work={work} bandwidth={bandwidth}>")

net1 = Network([5., 3, 8], [[0., 2, 0], [2, 0, 3], [0, 3, 0]])
net2 = Network([3., 4, 2], [[0., 2, 2], [2, 0, 2], [2, 2, 0]])
batch_of_networks = Network(
    work=tf.stack([net1.work, net2.work]),
    bandwidth=tf.stack([net1.bandwidth, net2.bandwidth]))
print(f"net1={net1}")
print(f"net2={net2}")
print(f"batch={batch_of_networks}")
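
The constructor above derives the batch shape by dropping the last dimension of work and the last two dimensions of bandwidth, then merging the two results. As a rough pure-Python sketch of that logic, with shapes written as tuples and None standing for an unknown dimension (merge_shapes and network_batch_shape are hypothetical helpers, simplified relative to tf.TensorShape.merge_with):

```python
def merge_shapes(a, b):
    """Merges two shapes of equal rank; None marks an unknown dimension."""
    if len(a) != len(b):
        raise ValueError(f"incompatible ranks: {a} vs {b}")
    merged = []
    for dim_a, dim_b in zip(a, b):
        if dim_a is not None and dim_b is not None and dim_a != dim_b:
            raise ValueError(f"incompatible dimensions: {dim_a} vs {dim_b}")
        merged.append(dim_b if dim_a is None else dim_a)
    return tuple(merged)


def network_batch_shape(work_shape, bandwidth_shape):
    """Batch shape implied by work[*shape, n] and bandwidth[*shape, n1, n2]."""
    return merge_shapes(tuple(work_shape[:-1]), tuple(bandwidth_shape[:-2]))


print(network_batch_shape((3,), (3, 3)))          # ()
print(network_batch_shape((2, 3), (2, 3, 3)))     # (2,)
print(network_batch_shape((None, 3), (2, 3, 3)))  # (2,)
```

A single network therefore has an empty batch shape, and a stack of two networks has batch shape (2,).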

You can then use tf.data.Dataset to iterate through a batch of networks:

dataset = tf.data.Dataset.from_tensor_slices(batch_of_networks)
for i, network in enumerate(dataset):
  print(f"Batch element {i}: {network}")

You can also use tf.map_fn to apply a function to each batch element:

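def balance_work_greedy(network):
  # delta[..., i, j] = a quarter of the work difference between nodes i and j.
  delta = (tf.expand_dims(network.work, -1) - tf.expand_dims(network.work, -2))
  delta /= 4
  # Clamp each transfer to the bandwidth available between the two nodes.
  delta = tf.maximum(tf.minimum(delta, network.bandwidth), -network.bandwidth)
  # Subtract, so that a node with more work than its neighbors sheds work.
  new_work = network.work - tf.reduce_sum(delta, -1)
  return Network(new_work, network.bandwidth)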

tf.map_fn(balance_work_greedy, batch_of_networks)

TensorFlow APIs that support ExtensionTypes

@tf.function

tf.function is a decorator that precomputes TensorFlow graphs for Python functions, which can substantially improve the performance of your TensorFlow code. Extension type values can be used transparently with @tf.function-decorated functions.

class Pastry(tf.experimental.ExtensionType):
  sweetness: tf.Tensor  # 2d embedding that encodes sweetness
  chewiness: tf.Tensor  # 2d embedding that encodes chewiness

@tf.function
def combine_pastry_features(x: Pastry):
  return (x.sweetness + x.chewiness) / 2

cookie = Pastry(sweetness=[1.2, 0.4], chewiness=[0.8, 0.2])
combine_pastry_features(cookie)

If you wish to explicitly specify the input_signature for tf.function, then you can do so using the extension type's TypeSpec.

pastry_spec = Pastry.Spec(tf.TensorSpec([2]), tf.TensorSpec([2]))

@tf.function(input_signature=[pastry_spec])
def increase_sweetness(x: Pastry, delta=1.0):
  return Pastry(x.sweetness + delta, x.chewiness)

increase_sweetness(cookie)

Concrete functions

Concrete functions encapsulate individual traced graphs that are built by tf.function. Extension types can be used transparently with concrete functions.

cf = combine_pastry_features.get_concrete_function(pastry_spec)
cf(cookie)

Control flow operations

Extension types are supported by TensorFlow's control-flow operations: tf.cond, tf.case, tf.while_loop, and tf.switch_case.

# Example: using tf.cond to select between two MaskedTensors. Note that the
# two MaskedTensors don't need to have the same shape.
a = MaskedTensor([1., 2, 3], [True, False, True])
b = MaskedTensor([22., 33, 108, 55], [True, True, True, False])
condition = tf.constant(True)
print(tf.cond(condition, lambda: a, lambda: b))

# Example: using tf.while_loop with MaskedTensor.
cond = lambda i, _: i < 10
def body(i, mt):
  return i + 1, mt.with_values(mt.values + 3 / 7)
print(tf.while_loop(cond, body, [0, b])[1])

Autograph control flow

Extension types are also supported by control flow statements in tf.function (using autograph). In the following example, the if statement and for statements are automatically converted to tf.cond and tf.while_loop operations, which support extension types.

@tf.function
def fn(x, b):
  if b:
    x = MaskedTensor(x, tf.less(x, 0))
  else:
    x = MaskedTensor(x, tf.greater(x, 0))
  for i in tf.range(5 if b else 7):
    x = x.with_values(x.values + 1 / 2)
  return x

print(fn(tf.constant([1., -2, 3]), tf.constant(True)))
print(fn(tf.constant([1., -2, 3]), tf.constant(False)))
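
The rewrite that autograph performs can be pictured with a pure-Python analogue. The sketch below is not autograph's actual output: cond and while_loop are simplified stand-ins for tf.cond and tf.while_loop, and Masked is a hypothetical list-backed substitute for MaskedTensor. It shows the same fn written in functional form, with the structured value carried through the loop as a single loop variable.

```python
from dataclasses import dataclass, replace
from typing import List


@dataclass(frozen=True)
class Masked:
    """Hypothetical stand-in for MaskedTensor, backed by plain Python lists."""
    values: List[float]
    mask: List[bool]


def cond(pred, true_fn, false_fn):
    """Functional form of `if`: calls exactly one of the branch functions."""
    return true_fn() if pred else false_fn()


def while_loop(cond_fn, body_fn, loop_vars):
    """Functional form of a loop: threads `loop_vars` through `body_fn`."""
    while cond_fn(*loop_vars):
        loop_vars = body_fn(*loop_vars)
    return loop_vars


def fn(x, b):
    # `if b: ... else: ...` rewritten as a cond over two branch functions.
    mt = cond(b,
              lambda: Masked(x, [v < 0 for v in x]),
              lambda: Masked(x, [v > 0 for v in x]))
    # `for i in range(n): ...` rewritten as a while_loop whose state is the
    # pair (i, mt); the structured value travels through the loop as one unit.
    n = 5 if b else 7
    _, mt = while_loop(
        lambda i, _: i < n,
        lambda i, m: (i + 1, replace(m, values=[v + 0.5 for v in m.values])),
        (0, mt))
    return mt


print(fn([1.0, -2.0, 3.0], True))
print(fn([1.0, -2.0, 3.0], False))
```

Both branches of the cond return the same kind of structured value, and the loop body returns loop variables with the same structure it received, which is what the control-flow operations require.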

Keras

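tf.keras is TensorFlow's high-level API for building and training deep learning models. Extension types may be passed as inputs to a Keras model, passed between Keras layers, and returned by Keras models. Keras currently puts two requirements on extension types: they must be batchable (see "Batchable ExtensionTypes" above), and they must have a field or property named shape, where shape[0] is assumed to be the batch dimension.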

The following two subsections give examples showing how extension types can be used with Keras.

Keras example: Network

For the first example, consider the Network class defined in the "Batchable ExtensionTypes" section above, which can be used for load balancing work between nodes. Its definition is repeated here:

class Network(tf.experimental.BatchableExtensionType):
  shape: tf.TensorShape  # batch shape. A single network has shape=[].
  work: tf.Tensor        # work[*shape, n] = work left to do at node n
  bandwidth: tf.Tensor   # bandwidth[*shape, n1, n2] = bandwidth from n1->n2

  def __init__(self, work, bandwidth):
    self.work = tf.convert_to_tensor(work)
    self.bandwidth = tf.convert_to_tensor(bandwidth)
    work_batch_shape = self.work.shape[:-1]
    bandwidth_batch_shape = self.bandwidth.shape[:-2]
    self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)

  def __repr__(self):
    return network_repr(self)

single_network = Network(  # A single network with 4 nodes.
    work=[8.0, 5, 12, 2],
    bandwidth=[[0.0, 1, 2, 2], [1, 0, 0, 2], [2, 0, 0, 1], [2, 2, 1, 0]])

batch_of_networks = Network(  # Batch of 2 networks, each w/ 2 nodes.
    work=[[8.0, 5], [3, 2]],
    bandwidth=[[[0.0, 1], [1, 0]], [[0, 2], [2, 0]]])

You can define a new Keras layer that processes Networks.

class BalanceNetworkLayer(tf.keras.layers.Layer):
  """Layer that balances work between nodes in a network.

  Shifts work from more busy nodes to less busy nodes, constrained by bandwidth.
  """
  def call(self, inputs):
    # This function is defined above in the "Batchable `ExtensionType`s" section.
    return balance_work_greedy(inputs)

You can then use these layers to create a simple model. To feed an ExtensionType into a model, you can use a tf.keras.layers.Input layer with type_spec set to the extension type's TypeSpec. If the Keras model will be used to process batches, then the type_spec must include the batch dimension.

input_spec = Network.Spec(shape=None,
                          work=tf.TensorSpec(None, tf.float32),
                          bandwidth=tf.TensorSpec(None, tf.float32))
model = tf.keras.Sequential([
    tf.keras.layers.Input(type_spec=input_spec),
    BalanceNetworkLayer(),
    ])

Finally, you can apply the model to a single network and to a batch of networks.

model(single_network)

model(batch_of_networks)

Keras example: MaskedTensor

In this example, MaskedTensor is extended to support Keras. shape is defined as a property that is calculated from the values field. Keras requires that you add this property to both the extension type and its TypeSpec. MaskedTensor also defines a __name__ variable, which will be required for SavedModel serialization (below).

class MaskedTensor(tf.experimental.BatchableExtensionType):
  # __name__ is required for serialization in SavedModel; see below for details.
  __name__ = 'extension_type_colab.MaskedTensor'

  values: tf.Tensor
  mask: tf.Tensor

  shape = property(lambda self: self.values.shape)
  dtype = property(lambda self: self.values.dtype)

  def with_default(self, default):
    return tf.where(self.mask, self.values, default)

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  class Spec:
    def __init__(self, shape, dtype=tf.float32):
      self.values = tf.TensorSpec(shape, dtype)
      self.mask = tf.TensorSpec(shape, tf.bool)

    shape = property(lambda self: self.values.shape)
    dtype = property(lambda self: self.values.dtype)

    def with_shape(self, shape):
      # Spec.__init__ takes (shape, dtype); the mask spec is always tf.bool.
      return MaskedTensor.Spec(shape, self.values.dtype)

Next, the dispatch decorators are used to override the default behavior of several TensorFlow APIs. Since these APIs are used by standard Keras layers (such as the Dense layer), overriding these will allow us to use those layers with MaskedTensor. For the purposes of this example, matmul for masked tensors is defined to treat the masked values as zeros (that is, to not include them in the product).

@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
def unary_elementwise_op_handler(op, x):
  return MaskedTensor(op(x.values), x.mask)

@tf.experimental.dispatch_for_binary_elementwise_apis(
    Union[MaskedTensor, tf.Tensor],
    Union[MaskedTensor, tf.Tensor])
def binary_elementwise_op_handler(op, x, y):
  x = convert_to_masked_tensor(x)
  y = convert_to_masked_tensor(y)
  return MaskedTensor(op(x.values, y.values), x.mask & y.mask)

@tf.experimental.dispatch_for_api(tf.matmul)
def masked_matmul(a: MaskedTensor, b,
                  transpose_a=False, transpose_b=False,
                  adjoint_a=False, adjoint_b=False,
                  a_is_sparse=False, b_is_sparse=False,
                  output_type=None):
  if isinstance(a, MaskedTensor):
    a = a.with_default(0)
  if isinstance(b, MaskedTensor):
    b = b.with_default(0)
  return tf.matmul(a, b, transpose_a, transpose_b, adjoint_a,
                   adjoint_b, a_is_sparse, b_is_sparse, output_type)
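
To see what "treat the masked values as zeros" means numerically, here is a small worked example in plain Python. It is a sketch with nested lists rather than tensors: with_default and matmul are local helper functions written for this illustration, and the input numbers are made up.

```python
def with_default(values, mask, default):
    """Replaces masked-out entries (mask False) with `default`."""
    return [[v if m else default for v, m in zip(v_row, m_row)]
            for v_row, m_row in zip(values, mask)]


def matmul(a, b):
    """Plain nested-list matrix product."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]


a_values = [[1.0, 2.0], [3.0, 4.0]]
a_mask = [[True, False], [False, True]]
b = [[10.0, 20.0], [30.0, 40.0]]

a_filled = with_default(a_values, a_mask, 0.0)
print(a_filled)             # [[1.0, 0.0], [0.0, 4.0]]
print(matmul(a_filled, b))  # [[10.0, 20.0], [120.0, 160.0]]
```

The masked entries 2.0 and 3.0 contribute nothing to the product, and the result is an ordinary dense matrix rather than a masked one, as in masked_matmul above.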

You can then construct a Keras model that accepts MaskedTensor inputs, using standard Keras layers:

input_spec = MaskedTensor.Spec([None, 2], tf.float32)

masked_tensor_model = tf.keras.Sequential([
    tf.keras.layers.Input(type_spec=input_spec),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(1)])
masked_tensor_model.compile(loss='binary_crossentropy', optimizer='rmsprop')

a = MaskedTensor([[1., 2], [3, 4], [5, 6]],
                  [[True, False], [False, True], [True, True]])
masked_tensor_model.fit(a, tf.constant([[1], [0], [1]]), epochs=3)
print(masked_tensor_model(a))

SavedModel

SavedModel is a serialized TensorFlow program, including both weights and computation. It can be built from a Keras model or from a custom model. In either case, extension types can be used transparently with the functions and methods defined by a SavedModel.

SavedModel can save models, layers, and functions that process extension types, as long as the extension types have a __name__ field. This name is used to register the extension type, so it can be located when the model is loaded.

Example: saving a Keras model

Keras models that use extension types may be saved using SavedModel.

masked_tensor_model_path = tempfile.mkdtemp()
tf.saved_model.save(masked_tensor_model, masked_tensor_model_path)
imported_model = tf.saved_model.load(masked_tensor_model_path)
imported_model(a)

Example: saving a custom model

SavedModel can also be used to save custom tf.Module subclasses with functions that process extension types.

class CustomModule(tf.Module):
  def __init__(self, variable_value):
    super().__init__()
    self.v = tf.Variable(variable_value)

  @tf.function
  def grow(self, x: MaskedTensor):
    """Increase values in `x` by multiplying them by `self.v`."""
    return MaskedTensor(x.values * self.v, x.mask)

module = CustomModule(100.0)

module.grow.get_concrete_function(MaskedTensor.Spec(shape=None,
                                                    dtype=tf.float32))
custom_module_path = tempfile.mkdtemp()
tf.saved_model.save(module, custom_module_path)
imported_model = tf.saved_model.load(custom_module_path)
imported_model.grow(MaskedTensor([1., 2, 3], [False, True, False]))

Loading a SavedModel when the ExtensionType is unavailable

If you load a SavedModel that uses an ExtensionType, but that ExtensionType is not available (that is, it has not been imported), then you will get a warning and TensorFlow will fall back to using an "anonymous extension type" object. This object will have the same fields as the original type, but will lack any further customization you have added for the type, such as custom methods or properties.
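
The fallback behavior can be pictured as a lookup in a name registry. The following pure-Python sketch is a loose analogy rather than a description of SavedModel internals: the registry, the register and restore helpers, and the type names are all hypothetical.

```python
from types import SimpleNamespace

_registry = {}  # Maps a registered type name to its Python class.


def register(cls):
    """Registers `cls` under its `type_name` so that a loader can find it."""
    _registry[cls.type_name] = cls
    return cls


def restore(type_name, fields):
    """Rebuilds a value from its saved type name and field values."""
    cls = _registry.get(type_name)
    if cls is None:
        # Unknown type: keep the fields, but none of the custom behavior.
        return SimpleNamespace(**fields)
    return cls(**fields)


@register
class Masked:
    type_name = "example.Masked"  # Hypothetical name, for illustration only.

    def __init__(self, values, mask):
        self.values, self.mask = values, mask

    def with_default(self, default):
        return [v if m else default for v, m in zip(self.values, self.mask)]


saved = {"values": [1, 2, 3], "mask": [True, False, True]}
known = restore("example.Masked", saved)
unknown = restore("example.NotImported", saved)

print(known.with_default(0))             # [1, 0, 3]
print(unknown.values, unknown.mask)      # [1, 2, 3] [True, False, True]
print(hasattr(unknown, "with_default"))  # False
```

In this analogy, importing the module that defines the type corresponds to running the register step before the model is loaded.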

Using ExtensionTypes with TensorFlow Serving

Currently, TensorFlow Serving (and other consumers of the SavedModel "signatures" dictionary) require that all inputs and outputs be raw tensors. If you wish to use TensorFlow Serving with a model that uses extension types, then you can add wrapper methods that compose or decompose extension type values from tensors. For example:

class CustomModuleWrapper(tf.Module):
  def __init__(self, variable_value):
    super().__init__()
    self.v = tf.Variable(variable_value)

  @tf.function
  def var_weighted_mean(self, x: MaskedTensor):
    """Mean value of unmasked values in x, weighted by self.v."""
    x = MaskedTensor(x.values * self.v, x.mask)
    return (tf.reduce_sum(x.with_default(0)) /
            tf.reduce_sum(tf.cast(x.mask, x.dtype)))

  @tf.function()
  def var_weighted_mean_wrapper(self, x_values, x_mask):
    """Raw tensor wrapper for var_weighted_mean."""
    return self.var_weighted_mean(MaskedTensor(x_values, x_mask))

module = CustomModuleWrapper([3., 2., 8., 5.])

module.var_weighted_mean_wrapper.get_concrete_function(
    tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.bool))
custom_module_path = tempfile.mkdtemp()
tf.saved_model.save(module, custom_module_path)
imported_model = tf.saved_model.load(custom_module_path)
x = MaskedTensor([1., 2., 3., 4.], [False, True, False, True])
imported_model.var_weighted_mean_wrapper(x.values, x.mask)
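
As a check on the arithmetic, the same computation can be written in plain Python. This sketch mirrors the wrapper pattern (raw fields in, structured value composed inside) with a hypothetical list-backed Masked class; passing v as an explicit argument is a simplification of the module's variable.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Masked:
    """Hypothetical stand-in for MaskedTensor, backed by plain Python lists."""
    values: List[float]
    mask: List[bool]


def var_weighted_mean(x: Masked, v: List[float]) -> float:
    """Mean of v[i] * x.values[i] over the positions where x.mask[i] is True."""
    weighted = [value * weight for value, weight in zip(x.values, v)]
    total = sum(w for w, m in zip(weighted, x.mask) if m)
    return total / sum(x.mask)  # sum of booleans = number of unmasked entries


def var_weighted_mean_wrapper(x_values, x_mask, v):
    """Raw-input entry point: composes the structured value, then delegates."""
    return var_weighted_mean(Masked(x_values, x_mask), v)


result = var_weighted_mean_wrapper([1.0, 2.0, 3.0, 4.0],
                                   [False, True, False, True],
                                   [3.0, 2.0, 8.0, 5.0])
print(result)  # 12.0
```

The weighted values are [3, 4, 24, 20]; only positions 1 and 3 are unmasked, so the result is (4 + 20) / 2 = 12.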

Datasets

tf.data is an API that enables you to build complex input pipelines from simple, reusable pieces. Its core data structure is tf.data.Dataset, which represents a sequence of elements, in which each element consists of one or more components.

Building Datasets with extension types

Datasets can be built from extension type values using Dataset.from_tensorsDataset.from_tensor_slices, or Dataset.from_generator:

ds = tf.data.Dataset.from_tensors(Pastry(5, 5))
next(iter(ds))

mt = MaskedTensor(tf.reshape(range(20), [5, 4]), tf.ones([5, 4], tf.bool))
ds = tf.data.Dataset.from_tensor_slices(mt)
for value in ds:
  print(value)

def value_gen():
  for i in range(2, 7):
    yield MaskedTensor(range(10), [j%i != 0 for j in range(10)])

ds = tf.data.Dataset.from_generator(
    value_gen, output_signature=MaskedTensor.Spec(shape=[10], dtype=tf.int32))
for value in ds:
  print(value)

Batching and unbatching Datasets with extension types

Datasets with extension types can be batched and unbatched using Dataset.batch and Dataset.unbatch.

batched_ds = ds.batch(2)
for value in batched_ds:
  print(value)

unbatched_ds = batched_ds.unbatch()
for value in unbatched_ds:
  print(value)
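
Conceptually, batching a simple extension type stacks each field across the elements, and unbatching slices each field back apart. The pure-Python sketch below illustrates that round trip with nested lists; batch, unbatch, and the list-backed Masked class are hypothetical helpers, not the tf.data implementation.

```python
from dataclasses import dataclass, fields
from typing import Any, List


@dataclass
class Masked:
    """Hypothetical stand-in for MaskedTensor, backed by plain Python lists."""
    values: Any
    mask: Any


def batch(elements: List[Masked]) -> Masked:
    """Stacks each field across elements, adding a leading batch dimension."""
    return Masked(**{f.name: [getattr(e, f.name) for e in elements]
                     for f in fields(Masked)})


def unbatch(batched: Masked) -> List[Masked]:
    """Slices each field along the leading dimension, one element per slice."""
    size = len(batched.values)
    return [Masked(**{f.name: getattr(batched, f.name)[i]
                      for f in fields(Masked)})
            for i in range(size)]


elements = [Masked([0, 1, 2], [False, True, False]),
            Masked([0, 1, 2], [False, True, True])]
batched = batch(elements)
print(batched.values)  # [[0, 1, 2], [0, 1, 2]]
print(batched.mask)    # [[False, True, False], [False, True, True]]
print(unbatch(batched) == elements)  # True
```

A batch is itself a single Masked value whose fields have one extra leading dimension, which is the property that makes a type batchable.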

Originally published on the TensorFlow website, this article appears here under a new headline and is licensed under CC BY 4.0. Code samples shared under the Apache 2.0 License.