Content Overview
- Setup
- Extension types
- Supported APIs
- Requirements
- Field types
- Mutability
- Functionality added by ExtensionType
- Constructor
- Printable representation
- Equality operators
- Validation method
- Enforced immutability
- Nested TypeSpec
- Customizing ExtensionTypes
- Overriding the default printable representation
- Defining methods
- Defining classmethods and staticmethods
- Defining properties
- Overriding the default constructor
- Overriding the default equality operator (__eq__)
- Using forward references
- Defining subclasses
- Defining private fields
- Customizing the ExtensionType’s TypeSpec
Setup
!pip install -q tf_nightly
import tensorflow as tf
import numpy as np
from typing import Tuple, List, Mapping, Union, Optional
import tempfile
Extension types
User-defined types can make projects more readable, modular, and maintainable. However, most TensorFlow APIs have very limited support for user-defined Python types. This includes both high-level APIs (such as Keras, tf.function, and SavedModel) and lower-level APIs (such as tf.while_loop and tf.concat). TensorFlow extension types can be used to create user-defined object-oriented types that work seamlessly with TensorFlow's APIs. To create an extension type, define a Python class with tf.experimental.ExtensionType as its base, and use type annotations to specify the type for each field.
class TensorGraph(tf.experimental.ExtensionType):
  """A collection of labeled nodes connected by weighted edges."""
  edge_weights: tf.Tensor               # shape=[num_nodes, num_nodes]
  node_labels: Mapping[str, tf.Tensor]  # shape=[num_nodes]; dtype=any

class MaskedTensor(tf.experimental.ExtensionType):
  """A tensor paired with a boolean mask, indicating which values are valid."""
  values: tf.Tensor
  mask: tf.Tensor       # shape=values.shape; false for missing/invalid values.

class CSRSparseMatrix(tf.experimental.ExtensionType):
  """Compressed sparse row matrix (https://en.wikipedia.org/wiki/Sparse_matrix)."""
  values: tf.Tensor     # shape=[num_nonzero]; dtype=any
  col_index: tf.Tensor  # shape=[num_nonzero]; dtype=int64
  row_index: tf.Tensor  # shape=[num_rows+1]; dtype=int64
The tf.experimental.ExtensionType base class works similarly to typing.NamedTuple and @dataclasses.dataclass from the standard Python library. In particular, it automatically adds a constructor and special methods (such as __repr__ and __eq__) based on the field type annotations.
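For readers less familiar with the standard-library analogue, the following minimal sketch shows what a frozen dataclass generates from field annotations. It uses only the dataclasses module; the Point class is a hypothetical illustration and is not part of TensorFlow.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Point:
  """Hypothetical example type; fields are declared with annotations."""
  x: float
  y: float = 0.0  # A class-level value acts as the default.


p = Point(1.5)
print(p)                      # Uses the generated __repr__.
print(p == Point(1.5, 0.0))   # Uses the generated __eq__.
```

ExtensionType follows the same pattern, with the addition of type checking and conversion of field values.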
Extension types typically fall into one of two categories:
- Data structures, which group together a collection of related values, and can provide useful operations based on those values. Data structures may be fairly general (such as the TensorGraph example above), or they may be highly customized to a specific model.
- Tensor-like types, which specialize or extend the concept of "Tensor." Types in this category have a rank, a shape, and usually a dtype; and it makes sense to use them with Tensor operations (such as tf.stack, tf.add, or tf.matmul). MaskedTensor and CSRSparseMatrix are examples of tensor-like types.
Supported APIs
Extension types are supported by the following TensorFlow APIs:
- Keras: Extension types can be used as inputs and outputs for Keras Models and Layers.
- tf.data.Dataset: Extension types can be included in Datasets, and returned by dataset Iterators.
- TensorFlow Hub: Extension types can be used as inputs and outputs for tf.hub modules.
- SavedModel: Extension types can be used as inputs and outputs for SavedModel functions.
- tf.function: Extension types can be used as arguments and return values for functions wrapped with the @tf.function decorator.
- While loops: Extension types can be used as loop variables in tf.while_loop, and can be used as arguments and return values for the while-loop's body.
- Conditionals: Extension types can be conditionally selected using tf.cond and tf.case.
- tf.py_function: Extension types can be used as arguments and return values for the func argument to tf.py_function.
- Tensor ops: Extension types can be extended to support most TensorFlow ops that accept Tensor inputs (such as tf.matmul, tf.gather, and tf.reduce_sum). Go to the "Dispatch" section below for more information.
- Distribution strategy: Extension types can be used as per-replica values.
For more details, see the section on "TensorFlow APIs that support ExtensionTypes" below.
Requirements
Field types
All fields—instance variables—must be declared, and a type annotation must be provided for each field. The following type annotations are supported:
Type | Example
---|---
Python integers | int
Python floats | float
Python strings | str
Python booleans | bool
Python None | None
Tensors | tf.Tensor
Tuples | Tuple[X, ...]
Mappings | Mapping[K, V]
Mutability
Extension types are required to be immutable. This ensures that they can be properly tracked by TensorFlow's graph-tracing mechanisms. If you find yourself wanting to mutate an extension type value, consider instead defining methods that transform values. For example, rather than defining a set_mask method to mutate a MaskedTensor, you could define a replace_mask method that returns a new MaskedTensor:
class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def replace_mask(self, new_mask):
    self.values.shape.assert_is_compatible_with(new_mask.shape)
    return MaskedTensor(self.values, new_mask)
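As a short usage sketch, the following shows that calling replace_mask leaves the original value untouched and returns a new one. The explicit tf.convert_to_tensor call is an assumption added here so that plain Python lists are accepted as the new mask.

```python
import tensorflow as tf


class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def replace_mask(self, new_mask):
    # Assumption: convert first, so lists are accepted as well as tensors.
    new_mask = tf.convert_to_tensor(new_mask, dtype=tf.bool)
    self.values.shape.assert_is_compatible_with(new_mask.shape)
    return MaskedTensor(self.values, new_mask)


original = MaskedTensor([1, 2, 3], [True, True, False])
updated = original.replace_mask([False, True, True])

print(original.mask.numpy())  # The original value keeps its mask.
print(updated.mask.numpy())   # The new value carries the new mask.
```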
Functionality added by ExtensionType
The ExtensionType base class provides the following functionality:
- A constructor (__init__).
- A printable representation method (__repr__).
- Equality and inequality operators (__eq__).
- A validation method (__validate__).
- Enforced immutability.
- A nested TypeSpec.
- Tensor API dispatch support.
Go to the "Customizing ExtensionTypes" section below for more information on customizing this functionality.
Constructor
The constructor added by ExtensionType takes each field as a named argument (in the order they were listed in the class definition). This constructor will type-check each parameter, and convert them where necessary. In particular, Tensor fields are converted using tf.convert_to_tensor; Tuple fields are converted to tuples; and Mapping fields are converted to immutable dicts.
class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

# Constructor takes one parameter for each field.
mt = MaskedTensor(values=[[1, 2, 3], [4, 5, 6]],
                  mask=[[True, True, False], [True, False, True]])

# Fields are type-checked and converted to the declared types.
# For example, `mt.values` is converted to a Tensor.
print(mt.values)
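The conversions for Tuple and Mapping fields can be sketched in the same way. The Recipe class and its field names below are hypothetical and exist only to illustrate the conversion rules described above.

```python
from typing import Mapping, Tuple

import tensorflow as tf


class Recipe(tf.experimental.ExtensionType):
  """Hypothetical type used only to illustrate field conversion."""
  amounts: tf.Tensor
  steps: Tuple[str, ...]
  notes: Mapping[str, str]


r = Recipe(amounts=[1.0, 2.5], steps=["mix", "bake"], notes={"oven": "hot"})

print(isinstance(r.amounts, tf.Tensor))  # The list was converted to a Tensor.
print(isinstance(r.steps, tuple))        # The list was converted to a tuple.
try:
  r.notes["oven"] = "cold"               # Mapping fields cannot be modified.
except TypeError as e:
  print(f"Got expected TypeError: {e}")
```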
The constructor raises a TypeError if a field value cannot be converted to its declared type:
try:
  MaskedTensor([1, 2, 3], None)
except TypeError as e:
  print(f"Got expected TypeError: {e}")
The default value for a field can be specified by setting its value at the class level:
class Pencil(tf.experimental.ExtensionType):
  color: str = "black"
  has_eraser: bool = True
  length: tf.Tensor = 1.0

Pencil()
Pencil(length=0.5, color="blue")
Printable representation
ExtensionType adds a default printable representation method (__repr__) that includes the class name and the value for each field:
print(MaskedTensor(values=[1, 2, 3], mask=[True, True, False]))
Equality operators
ExtensionType adds default equality operators (__eq__ and __ne__) that consider two values equal if they have the same type and all their fields are equal. Tensor fields are considered equal if they have the same shape and are elementwise equal for all elements.
a = MaskedTensor([1, 2], [True, False])
b = MaskedTensor([[3, 4], [5, 6]], [[False, True], [True, True]])
print(f"a == a: {a==a}")
print(f"a == b: {a==b}")
print(f"a == a.values: {a==a.values}")
Note: If any field contains a Tensor, then __eq__ may return a scalar boolean Tensor (rather than a Python boolean value).
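Because the result may be either a Python boolean or a scalar boolean Tensor, code that branches on equality in eager mode can normalize it with bool(). The following is a minimal sketch under that assumption; inside a tf.function, a Tensor result would instead need graph-compatible control flow.

```python
import tensorflow as tf


class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor


a = MaskedTensor([1, 2], [True, False])
same = MaskedTensor([1, 2], [True, False])
different = MaskedTensor([1, 9], [True, False])

# In eager mode, bool() accepts Python booleans and scalar boolean Tensors.
print(bool(a == same))
print(bool(a == different))
```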
Validation method
ExtensionType adds a __validate__ method, which can be overridden to perform validation checks on fields. It is run after the constructor is called, and after fields have been type-checked and converted to their declared types, so it can assume that all fields have their declared types.
The following example updates MaskedTensor to validate the shapes and dtypes of its fields:
class MaskedTensor(tf.experimental.ExtensionType):
  """A tensor paired with a boolean mask, indicating which values are valid."""
  values: tf.Tensor
  mask: tf.Tensor

  def __validate__(self):
    self.values.shape.assert_is_compatible_with(self.mask.shape)
    assert self.mask.dtype.is_bool, 'mask.dtype must be bool'

try:
  MaskedTensor([1, 2, 3], [0, 1, 0])  # Wrong `dtype` for mask.
except AssertionError as e:
  print(f"Got expected AssertionError: {e}")

try:
  MaskedTensor([1, 2, 3], [True, False])  # Shapes don't match.
except ValueError as e:
  print(f"Got expected ValueError: {e}")
Enforced immutability
ExtensionType overrides the __setattr__ and __delattr__ methods to prevent mutation, ensuring that extension type values are immutable.
mt = MaskedTensor([1, 2, 3], [True, False, True])

try:
  mt.mask = [True, True, True]
except AttributeError as e:
  print(f"Got expected AttributeError: {e}")

try:
  mt.mask[0] = False
except TypeError as e:
  print(f"Got expected TypeError: {e}")

try:
  del mt.mask
except AttributeError as e:
  print(f"Got expected AttributeError: {e}")
Nested TypeSpec
Each ExtensionType class has a corresponding TypeSpec class, which is created automatically and stored as <extension_type_name>.Spec.
This class captures all the information from a value except for the values of any nested tensors. In particular, the TypeSpec for a value is created by replacing any nested Tensor, ExtensionType, or CompositeTensor with its TypeSpec.
class Player(tf.experimental.ExtensionType):
  name: tf.Tensor
  attributes: Mapping[str, tf.Tensor]

anne = Player("Anne", {"height": 8.3, "speed": 28.1})
anne_spec = tf.type_spec_from_value(anne)
print(anne_spec.name)  # Records `dtype` and `shape`, but not the string value.
print(anne_spec.attributes)  # Records keys and TensorSpecs for values.
TypeSpec values can be constructed explicitly, or they can be built from an ExtensionType value using tf.type_spec_from_value:
spec1 = Player.Spec(name=tf.TensorSpec([], tf.float32), attributes={})
spec2 = tf.type_spec_from_value(anne)
TypeSpecs are used by TensorFlow to divide values into a static component and a dynamic component:
- The static component (which is fixed at graph-construction time) is encoded with a tf.TypeSpec.
- The dynamic component (which can vary each time the graph is run) is encoded as a list of tf.Tensors.
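The split into a static and a dynamic component can be inspected directly. The sketch below compares the TypeSpecs of three Player values and then lists the tensors held by one of them; using tf.nest.flatten with expand_composites=True to expose the dynamic component is a choice made here for illustration.

```python
from typing import Mapping

import tensorflow as tf


class Player(tf.experimental.ExtensionType):
  name: tf.Tensor
  attributes: Mapping[str, tf.Tensor]


anne = Player("Anne", {"height": 8.3, "speed": 28.1})
bart = Player("Bart", {"height": 8.1, "speed": 25.3})
chuck = Player("Chuck", {"height": 11.0, "jump": 5.3})

# Static component: values with the same structure, dtypes, and shapes
# share a TypeSpec, even though their tensor values differ.
anne_spec = tf.type_spec_from_value(anne)
print(anne_spec == tf.type_spec_from_value(bart))
print(anne_spec == tf.type_spec_from_value(chuck))

# Dynamic component: the flat list of tensors held by the value.
tensors = tf.nest.flatten(anne, expand_composites=True)
print(len(tensors))
```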
For example, tf.function retraces its wrapped function whenever an argument has a previously unseen TypeSpec:
@tf.function
def anonymize_player(player):
  print("<<TRACING>>")
  return Player("<anonymous>", player.attributes)

# Function gets traced (first time the function has been called):
anonymize_player(Player("Anne", {"height": 8.3, "speed": 28.1}))
# Function does NOT get traced (same TypeSpec: just tensor values changed)
anonymize_player(Player("Bart", {"height": 8.1, "speed": 25.3}))
# Function gets traced (new TypeSpec: keys for attributes changed):
anonymize_player(Player("Chuck", {"height": 11.0, "jump": 5.3}))
For more information, see the tf.function Guide.
Customizing ExtensionTypes
In addition to simply declaring fields and their types, extension types may:
- Override the default printable representation (__repr__).
- Define methods.
- Define classmethods and staticmethods.
- Define properties.
- Override the default constructor (__init__).
- Override the default equality operator (__eq__).
- Define operators (such as __add__ and __lt__).
- Declare default values for fields.
- Define subclasses.
Overriding the default printable representation
You can override this default string conversion operator for extension types. The following example updates the MaskedTensor class to generate a more readable string representation when values are printed in Eager mode.
class MaskedTensor(tf.experimental.ExtensionType):
  """A tensor paired with a boolean mask, indicating which values are valid."""
  values: tf.Tensor
  mask: tf.Tensor       # shape=values.shape; false for invalid values.

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

def masked_tensor_str(values, mask):
  if isinstance(values, tf.Tensor):
    if hasattr(values, 'numpy') and hasattr(mask, 'numpy'):
      return f'<MaskedTensor {masked_tensor_str(values.numpy(), mask.numpy())}>'
    else:
      return f'MaskedTensor(values={values}, mask={mask})'
  if len(values.shape) == 1:
    items = [repr(v) if m else '_' for (v, m) in zip(values, mask)]
  else:
    items = [masked_tensor_str(v, m) for (v, m) in zip(values, mask)]
  return '[%s]' % ', '.join(items)

mt = MaskedTensor(values=[[1, 2, 3], [4, 5, 6]],
                  mask=[[True, True, False], [True, False, True]])
print(mt)
Defining methods
Extension types may define methods, just like any normal Python class. For example, the MaskedTensor type could define a with_default method that returns a copy of self with masked values replaced by a given default value. Methods may optionally be annotated with the @tf.function decorator.
class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def with_default(self, default):
    return tf.where(self.mask, self.values, default)

MaskedTensor([1, 2, 3], [True, False, True]).with_default(0)
Defining classmethods and staticmethods
Extension types may define methods using the @classmethod and @staticmethod decorators. For example, the MaskedTensor type could define a factory method that masks any element with a given value:
class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  @staticmethod
  def from_tensor_and_value_to_mask(values, value_to_mask):
    return MaskedTensor(values, values != value_to_mask)

x = tf.constant([[1, 0, 2], [3, 0, 0]])
MaskedTensor.from_tensor_and_value_to_mask(x, 0)
Defining properties
Extension types may define properties using the @property decorator, just like any normal Python class. For example, the MaskedTensor type could define a dtype property that's a shorthand for the dtype of the values:
class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  @property
  def dtype(self):
    return self.values.dtype

MaskedTensor([1, 2, 3], [True, False, True]).dtype
Overriding the default constructor
You can override the default constructor for extension types. Custom constructors must set a value for every declared field; and after the custom constructor returns, all fields will be type-checked, and values will be converted as described above.
class Toy(tf.experimental.ExtensionType):
  name: str
  price: tf.Tensor

  def __init__(self, name, price, discount=0):
    self.name = name
    self.price = price * (1 - discount)

print(Toy("ball", 5.0, discount=0.2))  # On sale -- 20% off!
Alternatively, you might consider leaving the default constructor as-is, but adding one or more factory methods. For example:
class Toy(tf.experimental.ExtensionType):
  name: str
  price: tf.Tensor

  @staticmethod
  def new_toy_with_discount(name, price, discount):
    return Toy(name, price * (1 - discount))

print(Toy.new_toy_with_discount("ball", 5.0, discount=0.2))
Overriding the default equality operator (__eq__)
You can override the default __eq__ operator for extension types. The following example updates MaskedTensor to ignore masked elements when comparing for equality.
class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  def __eq__(self, other):
    result = tf.math.equal(self.values, other.values)
    result = result | ~(self.mask & other.mask)
    return tf.reduce_all(result)

x = MaskedTensor([1, 2, 3, 4], [True, True, False, True])
y = MaskedTensor([5, 2, 0, 4], [False, True, False, True])
print(x == y)
Note: You generally don't need to override __ne__, since its default implementation simply calls __eq__ and negates the result.
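A minimal sketch of this behavior follows: only __eq__ is overridden, and the inherited != operator reflects the custom comparison. The both_valid variable name is introduced here for readability.

```python
import tensorflow as tf


class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def __eq__(self, other):
    # Elements count as equal unless both are valid and their values differ.
    both_valid = self.mask & other.mask
    same_value = tf.math.equal(self.values, other.values)
    return tf.reduce_all(same_value | ~both_valid)


x = MaskedTensor([1, 2, 3, 4], [True, True, False, True])
y = MaskedTensor([5, 2, 0, 4], [False, True, False, True])

print(x == y)  # Uses the overridden __eq__.
print(x != y)  # Uses the default __ne__, which negates __eq__.
```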
Using forward references
If the type for a field has not been defined yet, you may use a string containing the name of the type instead. In the following example, the string "Node" is used to annotate the children field because the Node type hasn't been (fully) defined yet.
class Node(tf.experimental.ExtensionType):
  value: tf.Tensor
  children: Tuple["Node", ...] = ()

Node(3, [Node(5), Node(2)])
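Once the forward reference is resolved, the children field behaves like any other tuple of Node values. The sketch below walks the tree recursively; tree_sum is a hypothetical helper written for this illustration.

```python
from typing import Tuple

import tensorflow as tf


class Node(tf.experimental.ExtensionType):
  value: tf.Tensor
  children: Tuple["Node", ...] = ()


def tree_sum(node):
  """Hypothetical helper: adds up `value` over a node and its descendants."""
  total = node.value
  for child in node.children:
    total = total + tree_sum(child)
  return total


tree = Node(3, [Node(5), Node(2)])
print(isinstance(tree.children, tuple))  # The list was converted to a tuple.
print(int(tree_sum(tree)))
```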
Defining subclasses
Extension types may be subclassed using the standard Python syntax. Extension type subclasses may add new fields, methods, and properties; and may override the constructor, the printable representation, and the equality operator. The following example defines a basic TensorGraph class that uses three Tensor fields to encode a set of edges between nodes. It then defines a subclass that adds a Tensor field to record a "feature value" for each node. The subclass also defines a method to propagate the feature values along the edges.
class TensorGraph(tf.experimental.ExtensionType):
  num_nodes: tf.Tensor
  edge_src: tf.Tensor   # edge_src[e] = index of src node for edge e.
  edge_dst: tf.Tensor   # edge_dst[e] = index of dst node for edge e.

class TensorGraphWithNodeFeature(TensorGraph):
  node_features: tf.Tensor  # node_features[n] = feature value for node n.

  def propagate_features(self, weight=1.0) -> 'TensorGraphWithNodeFeature':
    updates = tf.gather(self.node_features, self.edge_src) * weight
    new_node_features = tf.tensor_scatter_nd_add(
        self.node_features, tf.expand_dims(self.edge_dst, 1), updates)
    return TensorGraphWithNodeFeature(
        self.num_nodes, self.edge_src, self.edge_dst, new_node_features)

g = TensorGraphWithNodeFeature(  # Edges: 0->1, 4->3, 2->2, 2->1
    num_nodes=5, edge_src=[0, 4, 2, 2], edge_dst=[1, 3, 2, 1],
    node_features=[10.0, 0.0, 2.0, 5.0, -1.0])  # One feature per node.

print("Original features:", g.node_features)
print("After propagating:", g.propagate_features().node_features)
Defining private fields
An extension type's fields may be marked private by prefixing them with an underscore (following standard Python conventions). This does not impact the way that TensorFlow treats the fields in any way; but simply serves as a signal to any users of the extension type that those fields are private.
Customizing the ExtensionType's TypeSpec
Each ExtensionType class has a corresponding TypeSpec class, which is created automatically and stored as <extension_type_name>.Spec. For more information, see the section "Nested TypeSpec" above.
To customize the TypeSpec, simply define your own nested class named Spec, and ExtensionType will use that as the basis for the automatically constructed TypeSpec. You can customize the Spec class by:
- Overriding the default printable representation.
- Overriding the default constructor.
- Defining methods, classmethods, staticmethods, and properties.
The following example customizes the MaskedTensor.Spec class to make it easier to use:
class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  shape = property(lambda self: self.values.shape)
  dtype = property(lambda self: self.values.dtype)

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  def with_values(self, new_values):
    return MaskedTensor(new_values, self.mask)

  class Spec:

    def __init__(self, shape, dtype=tf.float32):
      self.values = tf.TensorSpec(shape, dtype)
      self.mask = tf.TensorSpec(shape, tf.bool)

    def __repr__(self):
      return f"MaskedTensor.Spec(shape={self.shape}, dtype={self.dtype})"

    shape = property(lambda self: self.values.shape)
    dtype = property(lambda self: self.values.dtype)
Note: The custom Spec class may not use any instance variables that were not declared in the original ExtensionType.
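As a usage sketch, a spec with a custom constructor can be built from just a shape and a dtype, and the added properties read through to the underlying TensorSpec for values. The class below is a trimmed copy of the example above, reduced to what the sketch needs.

```python
import tensorflow as tf


class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  class Spec:

    def __init__(self, shape, dtype=tf.float32):
      # Both fields share one shape; the mask is always boolean.
      self.values = tf.TensorSpec(shape, dtype)
      self.mask = tf.TensorSpec(shape, tf.bool)

    shape = property(lambda self: self.values.shape)
    dtype = property(lambda self: self.values.dtype)


spec = MaskedTensor.Spec([None, 2], tf.int32)
print(spec.shape)
print(spec.dtype)
print(spec.mask)
```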
Originally published on the