Content Overview

Setup

!pip install -q tf_nightly
import tensorflow as tf
import numpy as np
from typing import Tuple, List, Mapping, Union, Optional
import tempfile

Extension types

User-defined types can make projects more readable, modular, and maintainable. However, most TensorFlow APIs have very limited support for user-defined Python types. This includes both high-level APIs (such as Keras, tf.function, and tf.SavedModel) and lower-level APIs (such as tf.while_loop and tf.concat). TensorFlow extension types can be used to create user-defined object-oriented types that work seamlessly with TensorFlow's APIs. To create an extension type, define a Python class with tf.experimental.ExtensionType as its base, and use type annotations to specify the type for each field.

class TensorGraph(tf.experimental.ExtensionType):
  """A collection of labeled nodes connected by weighted edges."""
  edge_weights: tf.Tensor               # shape=[num_nodes, num_nodes]
  node_labels: Mapping[str, tf.Tensor]  # shape=[num_nodes]; dtype=any

class MaskedTensor(tf.experimental.ExtensionType):
  """A tensor paired with a boolean mask, indicating which values are valid."""
  values: tf.Tensor
  mask: tf.Tensor       # shape=values.shape; false for missing/invalid values.

class CSRSparseMatrix(tf.experimental.ExtensionType):
  """Compressed sparse row matrix (https://en.wikipedia.org/wiki/Sparse_matrix)."""
  values: tf.Tensor     # shape=[num_nonzero]; dtype=any
  col_index: tf.Tensor  # shape=[num_nonzero]; dtype=int64
  row_index: tf.Tensor  # shape=[num_rows+1]; dtype=int64

The tf.experimental.ExtensionType base class works similarly to typing.NamedTuple and @dataclasses.dataclass from the standard Python library. In particular, it automatically adds a constructor and special methods (such as __repr__ and __eq__) based on the field type annotations.
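For comparison, the following minimal sketch shows the standard-library behavior that the paragraph above refers to. It uses only dataclasses (no TensorFlow), and the Point class is a hypothetical example that does not appear in the original guide.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Point:
  """Hypothetical value type whose fields are declared with annotations."""
  x: int
  y: int = 0  # A class-level assignment provides a default value.


p = Point(1, 2)
print(p)                 # Uses the auto-generated __repr__.
print(p == Point(1, 2))  # Uses the auto-generated __eq__, which compares fields.

try:
  p.x = 5                # frozen=True rejects attribute assignment.
except dataclasses.FrozenInstanceError as e:
  print(f"Got expected FrozenInstanceError: {e}")
```

A frozen dataclass is the closer analogy, because extension types are also immutable once constructed.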

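Typically, extension types tend to fall into one of two categories: data structures, which group together a collection of related values and can provide useful operations based on those values (such as the TensorGraph example above); and tensor-like types, which specialize or extend the concept of a Tensor (such as the MaskedTensor and CSRSparseMatrix examples above).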

Supported APIs

Extension types are supported by the following TensorFlow APIs:

For more details, see the section on "TensorFlow APIs that support ExtensionTypes" below.

Requirements

Field types

All fields—instance variables—must be declared, and a type annotation must be provided for each field. The following type annotations are supported:

Type                Example
------------------  ------------------------------------------------
Python integers     i: int
Python floats       f: float
Python strings      s: str
Python booleans     b: bool
Python None         n: None
Tensor shapes       shape: tf.TensorShape
Tensor dtypes       dtype: tf.DType
Tensors             t: tf.Tensor
Extension types     mt: MyMaskedTensor
Ragged tensors      rt: tf.RaggedTensor
Sparse tensors      st: tf.SparseTensor
Indexed slices      s: tf.IndexedSlices
Optional tensors    o: tf.experimental.Optional
Type unions         int_or_float: typing.Union[int, float]
Tuples              params: typing.Tuple[int, float, tf.Tensor, int]
Var-length tuples   lengths: typing.Tuple[int, ...]
Mappings            tags: typing.Mapping[str, tf.Tensor]
Optional values     weight: typing.Optional[tf.Tensor]
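As a rough illustration of how several of these annotations can be combined in one class, here is a small sketch. The Recording type and its field names are hypothetical and were chosen only for this example.

```python
import tensorflow as tf
from typing import Mapping, Optional, Tuple


class Recording(tf.experimental.ExtensionType):
  """Hypothetical type mixing Python-valued and Tensor-valued fields."""
  title: str
  sample_rate: int
  samples: tf.Tensor
  channel_names: Tuple[str, ...]
  tags: Mapping[str, tf.Tensor]
  weight: Optional[tf.Tensor] = None  # Fields with defaults come last.


rec = Recording(
    title="demo",
    sample_rate=16000,
    samples=[0.0, 0.5, -0.5],
    channel_names=("left", "right"),
    tags={"loudness": 0.8})

print(rec.samples)        # Stored as a Tensor.
print(rec.channel_names)  # Stored as a tuple of Python strings.
print(rec.weight)         # Falls back to the declared default.
```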

Mutability

Extension types are required to be immutable. This ensures that they can be properly tracked by TensorFlow's graph-tracing mechanisms. If you find yourself wanting to mutate an extension type value, consider instead defining methods that transform values. For example, rather than defining a set_mask method to mutate a MaskedTensor, you could define a replace_mask method that returns a new MaskedTensor:

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def replace_mask(self, new_mask):
    self.values.shape.assert_is_compatible_with(new_mask.shape)
    return MaskedTensor(self.values, new_mask)
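A short usage sketch may help show what "transforming instead of mutating" looks like in practice. The class is repeated so the example is self-contained; the variable names original and updated are illustrative only.

```python
import tensorflow as tf


class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def replace_mask(self, new_mask):
    self.values.shape.assert_is_compatible_with(new_mask.shape)
    return MaskedTensor(self.values, new_mask)


original = MaskedTensor([1, 2, 3], [True, True, False])
# `new_mask` must expose a `.shape`, so a Tensor is passed rather than a list.
updated = original.replace_mask(tf.constant([False, True, True]))

print(original.mask.numpy())  # The original value keeps its mask.
print(updated.mask.numpy())   # The new value carries the replacement mask.
```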

Functionality added by ExtensionType

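The ExtensionType base class provides the following functionality: a constructor (__init__), a printable representation method (__repr__), equality and inequality operators (__eq__ and __ne__), a validation method (__validate__), enforced immutability, and a nested TypeSpec. Each is described in the sections below.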

Go to the "Customizing ExtensionTypes" section below for more information on customizing this functionality.

Constructor

The constructor added by ExtensionType takes each field as a named argument (in the order they were listed in the class definition). This constructor will type-check each parameter, and convert them where necessary. In particular, Tensor fields are converted using tf.convert_to_tensor; Tuple fields are converted to tuples; and Mapping fields are converted to immutable dicts.

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

# Constructor takes one parameter for each field.
mt = MaskedTensor(values=[[1, 2, 3], [4, 5, 6]],
                  mask=[[True, True, False], [True, False, True]])

# Fields are type-checked and converted to the declared types.
# For example, `mt.values` is converted to a Tensor.
print(mt.values)
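The Tuple and Mapping conversions can be sketched in the same way. The Route type below is hypothetical, and the example assumes that the immutable dict rejects item assignment with a TypeError, as read-only mappings in Python generally do.

```python
import tensorflow as tf
from typing import Mapping, Tuple


class Route(tf.experimental.ExtensionType):
  """Hypothetical type with a tuple field and a mapping field."""
  stops: Tuple[str, ...]
  distances: Mapping[str, tf.Tensor]


# A list and a plain dict are passed in; both are converted on construction.
route = Route(stops=["a", "b"], distances={"a-b": 1.5})

print(type(route.stops))         # The list was converted to a tuple.
print(route.distances["a-b"])    # The dict value was converted to a Tensor.

mutation_blocked = False
try:
  route.distances["b-c"] = 2.0   # The stored mapping is read-only.
except TypeError as e:
  mutation_blocked = True
  print(f"Got expected TypeError: {e}")
```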

The constructor raises a TypeError if a field value cannot be converted to its declared type:

try:
  MaskedTensor([1, 2, 3], None)
except TypeError as e:
  print(f"Got expected TypeError: {e}")

The default value for a field can be specified by setting its value at the class level:

class Pencil(tf.experimental.ExtensionType):
  color: str = "black"
  has_eraser: bool = True
  length: tf.Tensor = 1.0

Pencil()

Pencil(length=0.5, color="blue")

Printable representation

ExtensionType adds a default printable representation method (__repr__) that includes the class name and the value for each field:

print(MaskedTensor(values=[1, 2, 3], mask=[True, True, False]))

Equality operators

ExtensionType adds default equality operators (__eq__ and __ne__) that consider two values equal if they have the same type and all their fields are equal. Tensor fields are considered equal if they have the same shape and are elementwise equal for all elements.

a = MaskedTensor([1, 2], [True, False])
b = MaskedTensor([[3, 4], [5, 6]], [[False, True], [True, True]])
print(f"a == a: {a==a}")
print(f"a == b: {a==b}")
print(f"a == a.values: {a==a.values}")

Note: If any field contains a Tensor, then __eq__ may return a scalar boolean Tensor (rather than a Python boolean value).
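When running eagerly, one way to handle either kind of result is to wrap the comparison in bool(). The sketch below assumes eager execution; inside a tf.function the result would be a symbolic Tensor and could not be converted this way.

```python
import tensorflow as tf


class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor


a = MaskedTensor([1, 2], [True, False])
b = MaskedTensor([1, 2], [True, False])
c = MaskedTensor([1, 3], [True, False])

# bool() accepts both a Python boolean and an eager scalar boolean Tensor.
print(bool(a == b))
print(bool(a == c))
print(bool(a != c))
```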

Validation method

ExtensionType adds a __validate__ method, which can be overridden to perform validation checks on fields. It is run after the constructor is called, and after fields have been type-checked and converted to their declared types, so it can assume that all fields have their declared types.

The following example updates MaskedTensor to validate the shapes and dtypes of its fields:

class MaskedTensor(tf.experimental.ExtensionType):
  """A tensor paired with a boolean mask, indicating which values are valid."""
  values: tf.Tensor
  mask: tf.Tensor
  def __validate__(self):
    self.values.shape.assert_is_compatible_with(self.mask.shape)
    assert self.mask.dtype.is_bool, 'mask.dtype must be bool'

try:
  MaskedTensor([1, 2, 3], [0, 1, 0])  # Wrong `dtype` for mask.
except AssertionError as e:
  print(f"Got expected AssertionError: {e}")

try:
  MaskedTensor([1, 2, 3], [True, False])  # shapes don't match.
except ValueError as e:
  print(f"Got expected ValueError: {e}")

Enforced immutability

ExtensionType overrides the __setattr__ and __delattr__ methods to prevent mutation, ensuring that extension type values are immutable.

mt = MaskedTensor([1, 2, 3], [True, False, True])

try:
  mt.mask = [True, True, True]
except AttributeError as e:
  print(f"Got expected AttributeError: {e}")

try:
  mt.mask[0] = False
except TypeError as e:
  print(f"Got expected TypeError: {e}")

try:
  del mt.mask
except AttributeError as e:
  print(f"Got expected AttributeError: {e}")

Nested TypeSpec

Each ExtensionType class has a corresponding TypeSpec class, which is created automatically and stored as <extension_type_name>.Spec.

This class captures all the information from a value except for the values of any nested tensors. In particular, the TypeSpec for a value is created by replacing any nested Tensor, ExtensionType, or CompositeTensor with its TypeSpec.

class Player(tf.experimental.ExtensionType):
  name: tf.Tensor
  attributes: Mapping[str, tf.Tensor]

anne = Player("Anne", {"height": 8.3, "speed": 28.1})
anne_spec = tf.type_spec_from_value(anne)
print(anne_spec.name)  # Records `dtype` and `shape`, but not the string value.
print(anne_spec.attributes)  # Records keys and TensorSpecs for values.

TypeSpec values can be constructed explicitly, or they can be built from an ExtensionType value using tf.type_spec_from_value:

spec1 = Player.Spec(name=tf.TensorSpec([], tf.string), attributes={})
spec2 = tf.type_spec_from_value(anne)
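To see what a TypeSpec does and does not record, the specs of several values can be compared directly. This is a minimal sketch; the bart and chuck values mirror the players used elsewhere in this guide.

```python
import tensorflow as tf
from typing import Mapping


class Player(tf.experimental.ExtensionType):
  name: tf.Tensor
  attributes: Mapping[str, tf.Tensor]


anne = Player("Anne", {"height": 8.3, "speed": 28.1})
bart = Player("Bart", {"height": 8.1, "speed": 25.3})
chuck = Player("Chuck", {"height": 11.0, "jump": 5.3})

anne_spec = tf.type_spec_from_value(anne)
bart_spec = tf.type_spec_from_value(bart)
chuck_spec = tf.type_spec_from_value(chuck)

# Only tensor values differ, so the specs compare equal.
print(anne_spec == bart_spec)
# The attribute keys differ, so the specs do not compare equal.
print(anne_spec == chuck_spec)
```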

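TypeSpecs are used by TensorFlow to divide values into a static component and a dynamic component. The static component (which is fixed at graph-construction time) is encoded with a tf.TypeSpec, and the dynamic component (which can vary each time the graph is run) is encoded as a list of tf.Tensors.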

For example, tf.function retraces its wrapped function whenever an argument has a previously unseen TypeSpec:

@tf.function
def anonymize_player(player):
  print("<<TRACING>>")
  return Player("<anonymous>", player.attributes)

# Function gets traced (first time the function has been called):
anonymize_player(Player("Anne", {"height": 8.3, "speed": 28.1}))

# Function does NOT get traced (same TypeSpec: just tensor values changed)
anonymize_player(Player("Bart", {"height": 8.1, "speed": 25.3}))

# Function gets traced (new TypeSpec: keys for attributes changed):
anonymize_player(Player("Chuck", {"height": 11.0, "jump": 5.3}))

For more information, see the tf.function Guide.

Customizing ExtensionTypes

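In addition to simply declaring fields and their types, extension types may override the default printable representation, define methods, define classmethods and staticmethods, define properties, override the default constructor, override the default equality operator, use forward references, define subclasses, define private fields, and customize their TypeSpec.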

Overriding the default printable representation

You can override the default printable representation (__repr__) for extension types. The following example updates the MaskedTensor class to generate a more readable string representation when values are printed in Eager mode.

class MaskedTensor(tf.experimental.ExtensionType):
  """A tensor paired with a boolean mask, indicating which values are valid."""
  values: tf.Tensor
  mask: tf.Tensor       # shape=values.shape; false for invalid values.

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

def masked_tensor_str(values, mask):
  if isinstance(values, tf.Tensor):
    if hasattr(values, 'numpy') and hasattr(mask, 'numpy'):
      return f'<MaskedTensor {masked_tensor_str(values.numpy(), mask.numpy())}>'
    else:
      return f'MaskedTensor(values={values}, mask={mask})'
  if len(values.shape) == 1:
    items = [repr(v) if m else '_' for (v, m) in zip(values, mask)]
  else:
    items = [masked_tensor_str(v, m) for (v, m) in zip(values, mask)]
  return '[%s]' % ', '.join(items)

mt = MaskedTensor(values=[[1, 2, 3], [4, 5, 6]],
                  mask=[[True, True, False], [True, False, True]])
print(mt)

Defining methods

Extension types may define methods, just like any normal Python class. For example, the MaskedTensor type could define a with_default method that returns a copy of self with masked values replaced by a given default value. Methods may optionally be annotated with the @tf.function decorator.

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def with_default(self, default):
    return tf.where(self.mask, self.values, default)

MaskedTensor([1, 2, 3], [True, False, True]).with_default(0)

Defining classmethods and staticmethods

Extension types may define methods using the @classmethod and @staticmethod decorators. For example, the MaskedTensor type could define a factory method that masks any element with a given value:

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  @staticmethod
  def from_tensor_and_value_to_mask(values, value_to_mask):
    return MaskedTensor(values, values != value_to_mask)

x = tf.constant([[1, 0, 2], [3, 0, 0]])
MaskedTensor.from_tensor_and_value_to_mask(x, 0)

Defining properties

Extension types may define properties using the @property decorator, just like any normal Python class. For example, the MaskedTensor type could define a dtype property that's a shorthand for the dtype of the values:

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  @property
  def dtype(self):
    return self.values.dtype

MaskedTensor([1, 2, 3], [True, False, True]).dtype

Overriding the default constructor

You can override the default constructor for extension types. Custom constructors must set a value for every declared field; and after the custom constructor returns, all fields will be type-checked, and values will be converted as described above.

class Toy(tf.experimental.ExtensionType):
  name: str
  price: tf.Tensor
  def __init__(self, name, price, discount=0):
    self.name = name
    self.price = price * (1 - discount)

print(Toy("ball", 5.0, discount=0.2))  # On sale -- 20% off!

Alternatively, you might consider leaving the default constructor as-is, but adding one or more factory methods. For example:

class Toy(tf.experimental.ExtensionType):
  name: str
  price: tf.Tensor

  @staticmethod
  def new_toy_with_discount(name, price, discount):
    return Toy(name, price * (1 - discount))

print(Toy.new_toy_with_discount("ball", 5.0, discount=0.2))

Overriding the default equality operator (__eq__)

You can override the default __eq__ operator for extension types. The following example updates MaskedTensor to ignore masked elements when comparing for equality.

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  def __eq__(self, other):
    result = tf.math.equal(self.values, other.values)
    result = result | ~(self.mask & other.mask)
    return tf.reduce_all(result)

x = MaskedTensor([1, 2, 3, 4], [True, True, False, True])
y = MaskedTensor([5, 2, 0, 4], [False, True, False, True])
print(x == y)

Note: You generally don't need to override __ne__, since its default implementation simply calls __eq__ and negates the result.

Using forward references

If the type for a field has not been defined yet, you may use a string containing the name of the type instead. In the following example, the string "Node" is used to annotate the children field because the Node type hasn't been (fully) defined yet.

class Node(tf.experimental.ExtensionType):
  value: tf.Tensor
  children: Tuple["Node", ...] = ()

Node(3, [Node(5), Node(2)])
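Because the children field refers back to Node, values of this type form a tree that ordinary recursive Python code can walk. The tree_sum helper below is hypothetical and is only meant to illustrate that.

```python
import tensorflow as tf
from typing import Tuple


class Node(tf.experimental.ExtensionType):
  value: tf.Tensor
  children: Tuple["Node", ...] = ()


def tree_sum(node):
  """Hypothetical helper: adds up `value` over a node and its descendants."""
  total = node.value
  for child in node.children:
    total = total + tree_sum(child)
  return total


tree = Node(3, [Node(5), Node(2, [Node(1)])])
print(tree_sum(tree))  # A scalar Tensor holding 3 + 5 + 2 + 1.
```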

Defining subclasses

Extension types may be subclassed using the standard Python syntax. Extension type subclasses may add new fields, methods, and properties; and may override the constructor, the printable representation, and the equality operator. The following example defines a basic TensorGraph class that uses three Tensor fields to encode a set of edges between nodes. It then defines a subclass that adds a Tensor field to record a "feature value" for each node. The subclass also defines a method to propagate the feature values along the edges.

class TensorGraph(tf.experimental.ExtensionType):
  num_nodes: tf.Tensor
  edge_src: tf.Tensor   # edge_src[e] = index of src node for edge e.
  edge_dst: tf.Tensor   # edge_dst[e] = index of dst node for edge e.

class TensorGraphWithNodeFeature(TensorGraph):
  node_features: tf.Tensor  # node_features[n] = feature value for node n.

  def propagate_features(self, weight=1.0) -> 'TensorGraphWithNodeFeature':
    updates = tf.gather(self.node_features, self.edge_src) * weight
    new_node_features = tf.tensor_scatter_nd_add(
        self.node_features, tf.expand_dims(self.edge_dst, 1), updates)
    return TensorGraphWithNodeFeature(
        self.num_nodes, self.edge_src, self.edge_dst, new_node_features)

g = TensorGraphWithNodeFeature(  # Edges: 0->1, 4->3, 2->2, 2->1
    num_nodes=5, edge_src=[0, 4, 2, 2], edge_dst=[1, 3, 2, 1],
    node_features=[10.0, 0.0, 2.0, 5.0, -1.0])

print("Original features:", g.node_features)
print("After propagating:", g.propagate_features().node_features)

Defining private fields

An extension type's fields may be marked private by prefixing them with an underscore (following standard Python conventions). This does not impact the way that TensorFlow treats the fields in any way; but simply serves as a signal to any users of the extension type that those fields are private.
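A minimal sketch of this convention follows. The Inventory type, its _counts field, and its total property are hypothetical names chosen for illustration.

```python
import tensorflow as tf


class Inventory(tf.experimental.ExtensionType):
  """Hypothetical type that exposes a derived value instead of its raw field."""
  _counts: tf.Tensor  # Private by convention only; TensorFlow treats it normally.

  @property
  def total(self):
    return tf.reduce_sum(self._counts)


inv = Inventory([3, 4, 5])
print(inv.total)    # Public, derived view of the data.
print(inv._counts)  # Still accessible; the underscore is just a signal to users.
```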

Customizing the ExtensionType's TypeSpec

Each ExtensionType class has a corresponding TypeSpec class, which is created automatically and stored as <extension_type_name>.Spec. For more information, see the section "Nested TypeSpec" above.

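To customize the TypeSpec, define your own nested class named Spec, and ExtensionType will use that as the basis for the automatically constructed TypeSpec. You can customize the Spec class by overriding the default printable representation, overriding the default constructor, and defining methods, classmethods, staticmethods, and properties.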

The following example customizes the MaskedTensor.Spec class to make it easier to use:

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  shape = property(lambda self: self.values.shape)
  dtype = property(lambda self: self.values.dtype)

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  def with_values(self, new_values):
    return MaskedTensor(new_values, self.mask)

  class Spec:
    def __init__(self, shape, dtype=tf.float32):
      self.values = tf.TensorSpec(shape, dtype)
      self.mask = tf.TensorSpec(shape, tf.bool)

    def __repr__(self):
      return f"MaskedTensor.Spec(shape={self.shape}, dtype={self.dtype})"

    shape = property(lambda self: self.values.shape)
    dtype = property(lambda self: self.values.dtype)

Note: The custom Spec class may not use any instance variables that were not declared in the original ExtensionType.
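The following sketch shows how such a customized Spec might be used. It repeats a trimmed-down version of the class (without the __repr__ overrides) so that it runs on its own; the variable names are illustrative.

```python
import tensorflow as tf


class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  class Spec:
    def __init__(self, shape, dtype=tf.float32):
      self.values = tf.TensorSpec(shape, dtype)
      self.mask = tf.TensorSpec(shape, tf.bool)

    shape = property(lambda self: self.values.shape)
    dtype = property(lambda self: self.values.dtype)


# The custom constructor needs only a shape and a dtype.
spec = MaskedTensor.Spec([None, 2], tf.int32)
print(spec.shape, spec.dtype, spec.mask.dtype)

# Specs derived from values also pick up the custom properties.
mt = MaskedTensor([[1, 2]], [[True, False]])
value_spec = tf.type_spec_from_value(mt)
print(value_spec.shape, value_spec.dtype)
```

Building both TensorSpecs from one shape argument keeps the values and mask specs consistent with each other.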


Originally published on the TensorFlow website, this article appears here under a new headline and is licensed under CC BY 4.0. Code samples shared under the Apache 2.0 License.