Module phiml.backend.tensorflow.nets

TensorFlow implementation of the unified machine learning API. Equivalent functions also exist for the other frameworks.

For API documentation, see phiml.nn.
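A minimal usage sketch (assuming TensorFlow and Keras are installed; application code should normally go through the phiml.nn wrappers instead of this backend module):

from phiml.backend.tensorflow.nets import u_net, adam

net = u_net(in_channels=1, out_channels=1, levels=3)  # a regular keras.Model
optimizer = adam(net, learning_rate=1e-3)             # a regular keras.optimizers.Adam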

Functions

def adagrad(net: keras.src.models.model.Model,
            learning_rate: float = 0.001,
            lr_decay=0.0,
            weight_decay=0.0,
            initial_accumulator_value=0.0,
            eps=1e-10)
def adagrad(net: keras.Model, learning_rate: float = 1e-3, lr_decay=0., weight_decay=0., initial_accumulator_value=0., eps=1e-10):
    assert lr_decay == 0
    assert weight_decay == 0
    return keras.optimizers.Adagrad(learning_rate, initial_accumulator_value, eps)
def adam(net: keras.src.models.model.Model,
         learning_rate: float = 0.001,
         betas=(0.9, 0.999),
         epsilon=1e-07)
def adam(net: keras.Model, learning_rate: float = 1e-3, betas=(0.9, 0.999), epsilon=1e-07):
    return keras.optimizers.Adam(learning_rate, betas[0], betas[1], epsilon)
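The betas tuple is forwarded positionally as Keras' beta_1 and beta_2. A hypothetical call (net can be any keras.Model; it is not inspected here):

optimizer = adam(net, learning_rate=1e-4, betas=(0.9, 0.99), epsilon=1e-8)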
def conv_classifier(in_features: int,
                    in_spatial: tuple | list,
                    num_classes: int,
                    blocks=(64, 128, 256, 256, 512, 512),
                    block_sizes=(2, 2, 3, 3, 3),
                    dense_layers=(4096, 4096, 100),
                    batch_norm=True,
                    activation='ReLU',
                    softmax=True,
                    periodic=False)
def conv_classifier(in_features: int,
                    in_spatial: Union[tuple, list],
                    num_classes: int,
                    blocks=(64, 128, 256, 256, 512, 512),
                    block_sizes=(2, 2, 3, 3, 3),
                    dense_layers=(4096, 4096, 100),
                    batch_norm=True,
                    activation='ReLU',
                    softmax=True,
                    periodic=False):
    assert isinstance(in_spatial, (tuple, list))
    activation = ACTIVATIONS[activation] if isinstance(activation, str) else activation
    d = len(in_spatial)
    x = inputs = keras.Input(shape=tuple(in_spatial) + (in_features,))  # tuple() so list inputs also concatenate
    for filters, block_size in zip(blocks, block_sizes):  # zip truncates to the shorter of blocks / block_sizes
        for _ in range(block_size):
            x = CONV[d](filters, 3, padding='valid')(pad_periodic(x)) if periodic else CONV[d](filters, 3, padding='same')(x)
            if batch_norm:
                x = kl.BatchNormalization()(x)
            x = activation(x)
        x = MAX_POOL[d](2)(x)
    x = kl.Flatten()(x)
    flat_size = int(x.shape[-1])  # statically known because in_spatial must be a concrete tuple
    x = mlp(flat_size, num_classes, dense_layers, batch_norm, activation, softmax)(x)
    return keras.Model(inputs, x)
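A hypothetical sketch; with the default five pooled stages, each entry of in_spatial should be divisible by 2**5 = 32 so the max-pooling steps divide evenly:

classifier = conv_classifier(in_features=3, in_spatial=(64, 64), num_classes=10)
classifier.summary()  # VGG-like convolution blocks followed by dense layers and a softmax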
def conv_net(in_channels: int,
             out_channels: int,
             layers: Sequence[int],
             batch_norm: bool = False,
             activation: str | Callable = 'ReLU',
             periodic=False,
             in_spatial: int | tuple = 2) -> keras.src.models.model.Model
def conv_net(in_channels: int,
             out_channels: int,
             layers: Sequence[int],
             batch_norm: bool = False,
             activation: Union[str, Callable] = 'ReLU',
             periodic=False,
             in_spatial: Union[int, tuple] = 2) -> keras.Model:
    activation = ACTIVATIONS[activation] if isinstance(activation, str) else activation
    if isinstance(in_spatial, int):
        d = in_spatial
        in_spatial = (None,) * d
    else:
        assert isinstance(in_spatial, tuple)
        d = len(in_spatial)
    x = inputs = keras.Input(shape=in_spatial + (in_channels,))
    layers = list(layers)  # copy so tuples work and the caller's sequence is not mutated
    if not layers:
        layers.append(out_channels)
    for i in range(len(layers)):
        x = CONV[d](layers[i], 3, padding='valid')(pad_periodic(x)) if periodic else CONV[d](layers[i], 3, padding='same')(x)
        if batch_norm:
            x = kl.BatchNormalization()(x)
        x = activation(x)
    x = CONV[d](out_channels, 1)(x)
    return keras.Model(inputs, x)
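A hypothetical sketch of a fully convolutional network that accepts 2D fields of any size:

net = conv_net(in_channels=2, out_channels=2, layers=[16, 32, 16], in_spatial=2)
net.summary()  # input (None, None, None, 2) -> output (None, None, None, 2)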
def double_conv(x,
                d: int,
                out_channels: int,
                mid_channels: int,
                batch_norm: bool,
                activation: Callable,
                periodic: bool,
                kernel_size=3)
def double_conv(x, d: int, out_channels: int, mid_channels: int, batch_norm: bool, activation: Callable, periodic: bool, kernel_size=3):
    x = CONV[d](mid_channels, kernel_size, padding='valid')(pad_periodic(x)) if periodic else CONV[d](mid_channels, kernel_size, padding='same')(x)
    if batch_norm:
        x = kl.BatchNormalization()(x)
    x = activation(x)
    x = CONV[d](out_channels, kernel_size, padding='valid')(pad_periodic(x)) if periodic else CONV[d](out_channels, kernel_size, padding='same')(x)
    if batch_norm:
        x = kl.BatchNormalization()(x)
    x = activation(x)
    return x
def get_learning_rate(optimizer: keras.src.optimizers.optimizer.Optimizer)
def get_learning_rate(optimizer: keras.optimizers.Optimizer):
    """
    Gets the global learning rate for the given optimizer.

    Args:
        optimizer (optim.Optimizer): The optimizer whose learning rate needs to be updated.
    """
    return keras.backend.get_value(optimizer.lr)

Returns the global learning rate of the given optimizer.

Args

optimizer : keras.optimizers.Optimizer
The optimizer whose learning rate should be read.
def get_mask(inputs, reverse_mask, data_format='NHWC')
def get_mask(inputs, reverse_mask, data_format='NHWC'):
    """ Compute mask for slicing input feature map for Invertible Nets """
    shape = inputs.shape
    if len(shape) == 2:
        N = shape[-1]
        range_n = tf.range(0, N)
        even_ind = range_n % 2
        checker = tf.reshape(even_ind, (-1, N))
    elif len(shape) == 4:
        H = shape[2] if data_format == 'NCHW' else shape[1]
        W = shape[3] if data_format == 'NCHW' else shape[2]
        range_h = tf.range(0, H)
        range_w = tf.range(0, W)
        even_ind_h = tf.cast(range_h % 2, dtype=tf.bool)
        even_ind_w = tf.cast(range_w % 2, dtype=tf.bool)
        ind_h = tf.tile(tf.expand_dims(even_ind_h, -1), [1, W])
        ind_w = tf.tile(tf.expand_dims(even_ind_w, 0), [H, 1])
        # ind_h = even_ind_h.unsqueeze(-1).repeat(1, W)
        # ind_w = even_ind_w.unsqueeze( 0).repeat(H, 1)
        checker = tf.math.logical_xor(ind_h, ind_w)
        reshape = [-1, 1, H, W] if data_format == 'NCHW' else [-1, H, W, 1]
        checker = tf.reshape(checker, reshape)
        checker = tf.cast(checker, dtype=tf.float32)
    else:
        raise ValueError('Invalid tensor shape. The tensor rank must be 2 (NxD) or 4 (NxCxHxW or NxHxWxC), got {}.'.format(inputs.get_shape().as_list()))
    if reverse_mask:
        checker = 1 - checker
    return checker

Compute a checkerboard mask for splitting the input feature map of invertible nets.
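For a flat (N, D) input the mask simply alternates along the feature axis; a hypothetical check:

import tensorflow as tf
get_mask(tf.zeros([1, 6]), reverse_mask=False)  # [[0, 1, 0, 1, 0, 1]]
get_mask(tf.zeros([1, 6]), reverse_mask=True)   # [[1, 0, 1, 0, 1, 0]]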

def get_parameters(model: keras.src.models.model.Model, wrap=True) -> dict
def get_parameters(model: keras.Model, wrap=True) -> dict:
    result = {}
    for var in model.trainable_weights:
        if hasattr(var, 'path'):  # .path replaces .name in tensorflow >= 2.16
            parts = var.path.split('/')
            layer = '.'.join(parts[:-1]).replace('dense', 'linear').replace('_', '')
            try:
                int(layer[-1:])
            except ValueError:
                layer += '0'
            prop = parts[-1].replace('kernel', 'weight')
        else:
            name: str = var.name
            layer = name[:name.index('/')].replace('_', '').replace('dense', 'linear')
            try:
                int(layer[-1:])
            except ValueError:
                layer += '0'
            prop = name[name.index('/') + 1:].replace('kernel', 'weight')
            if prop.endswith(':0'):
                prop = prop[:-2]
        name = f"{layer}.{prop}"
        var = var.numpy()
        if not wrap:
            result[name] = var
        else:
            if name.endswith('.weight'):
                if var.ndim == 2:
                    uml_tensor = math.wrap(var, math.channel('input,output'))
                elif var.ndim == 3:
                    uml_tensor = math.wrap(var, math.channel('x,input,output'))
                elif var.ndim == 4:
                    uml_tensor = math.wrap(var, math.channel('x,y,input,output'))
                elif var.ndim == 5:
                    uml_tensor = math.wrap(var, math.channel('x,y,z,input,output'))
                else:  # avoid an unbound uml_tensor below
                    raise NotImplementedError(name, var)
            elif name.endswith('.bias'):
                uml_tensor = math.wrap(var, math.channel('output'))
            elif var.ndim == 1:
                uml_tensor = math.wrap(var, math.channel('output'))
            else:
                raise NotImplementedError(name, var)
            result[name] = uml_tensor
    return result
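A hypothetical sketch; with wrap=True the values are phiml tensors whose channel dims are named input, output and, for convolution kernels, x, y, z:

net = mlp(4, 2, [8])
for name, value in get_parameters(net).items():
    print(name, value.shape)  # key names follow a linearN.weight / linearN.bias pattern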
def invertible_net(num_blocks: int, construct_net: str | Callable, **construct_kwargs)
def invertible_net(num_blocks: int,
                   construct_net: Union[str, Callable],
                   **construct_kwargs):  # mlp, u_net, res_net, conv_net
    if construct_net == 'mlp':
        construct_net = '_inv_net_dense_resnet_block'
    if isinstance(construct_net, str):
        construct_net = globals()[construct_net]
    if 'in_channels' in construct_kwargs and 'out_channels' not in construct_kwargs:
        construct_kwargs['out_channels'] = construct_kwargs['in_channels']
    return InvertibleNet(num_blocks, construct_net, construct_kwargs)
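A hypothetical sketch using conv_net coupling blocks (out_channels is filled in automatically to match in_channels):

inn = invertible_net(num_blocks=2, construct_net='conv_net',
                     in_channels=2, layers=[16, 16], in_spatial=2)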
def load_state(obj: keras.src.models.model.Model | keras.src.optimizers.optimizer.Optimizer,
               path: str)
def load_state(obj: Union[keras.models.Model, keras.optimizers.Optimizer], path: str):
    if isinstance(obj, keras.models.Model):
        if not path.endswith('.h5'):
            path += '.h5'
        obj.load_weights(path)
    elif isinstance(obj, keras.optimizers.Optimizer):
        if not path.endswith('.pkl'):
            path += '.pkl'
        with open(path, 'rb') as f:
            weights = pickle.load(f)
        obj.set_weights(weights)
    else:
        raise ValueError("obj must be a Keras model or optimizer")
def mlp(in_channels: int,
        out_channels: int,
        layers: Sequence[int],
        batch_norm=False,
        activation='ReLU',
        softmax=False) -> keras.src.models.model.Model
def mlp(in_channels: int,
        out_channels: int,
        layers: Sequence[int],
        batch_norm=False,
        activation='ReLU',
        softmax=False) -> keras.Model:
    activation = ACTIVATIONS[activation] if isinstance(activation, str) else activation
    keras_layers = []
    for neuron_count in layers:
        keras_layers.append(kl.Dense(neuron_count, activation=activation))
        if batch_norm:
            keras_layers.append(kl.BatchNormalization())
    return keras.models.Sequential([kl.InputLayer(input_shape=(in_channels,)),
                                    *keras_layers,
                                    kl.Dense(out_channels, activation='linear'),
                                    *([kl.Softmax()] if softmax else [])])
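A hypothetical sketch of a small classification head:

head = mlp(in_channels=128, out_channels=10, layers=[64, 64], softmax=True)
head.summary()  # Dense(64) -> Dense(64) -> Dense(10) -> Softmax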
def pad_periodic(x: tensorflow.python.framework.tensor.Tensor)
def pad_periodic(x: Tensor):
    return PeriodicPad()(x)
def res_net(in_channels: int,
            out_channels: int,
            layers: Sequence[int],
            batch_norm: bool = False,
            activation: str | Callable = 'ReLU',
            periodic=False,
            in_spatial: int | tuple = 2)
def res_net(in_channels: int,
            out_channels: int,
            layers: Sequence[int],
            batch_norm: bool = False,
            activation: Union[str, Callable] = 'ReLU',
            periodic=False,
            in_spatial: Union[int, tuple] = 2):
    activation = ACTIVATIONS[activation] if isinstance(activation, str) else activation
    if isinstance(in_spatial, int):
        d = in_spatial
        in_spatial = (None,) * d
    else:
        assert isinstance(in_spatial, tuple)
        d = len(in_spatial)
    x = inputs = keras.Input(shape=in_spatial + (in_channels,))
    layers = list(layers)  # copy so tuples work and the caller's sequence is not mutated
    if not layers:
        layers.append(out_channels)
    out = resnet_block(in_channels, layers[0], periodic, batch_norm, activation, d)(x)
    for i in range(1, len(layers)):
        out = resnet_block(layers[i - 1], layers[i], periodic, batch_norm, activation, d)(out)
    out = CONV[d](out_channels, 1)(out)
    return keras.Model(inputs, out)
def resnet_block(in_channels: int,
                 out_channels: int,
                 periodic: bool,
                 batch_norm: bool = False,
                 activation: str | Callable = 'ReLU',
                 in_spatial: int | tuple = 2,
                 kernel_size=3)
def resnet_block(in_channels: int,
                 out_channels: int,
                 periodic: bool,
                 batch_norm: bool = False,
                 activation: Union[str, Callable] = 'ReLU',
                 in_spatial: Union[int, tuple] = 2,
                 kernel_size=3):
    activation = ACTIVATIONS[activation] if isinstance(activation, str) else activation
    if isinstance(in_spatial, int):
        d = in_spatial
    else:
        assert isinstance(in_spatial, tuple)
        d = len(in_spatial)
    x = x_1 = inputs = keras.Input(shape=(None,) * d + (in_channels,))
    x = CONV[d](out_channels, kernel_size, padding='valid')(pad_periodic(x)) if periodic else CONV[d](out_channels, kernel_size, padding='same')(x)
    if batch_norm:
        x = kl.BatchNormalization()(x)
    x = activation(x)
    x = CONV[d](out_channels, kernel_size, padding='valid')(pad_periodic(x)) if periodic else CONV[d](out_channels, kernel_size, padding='same')(x)
    if batch_norm:
        x = kl.BatchNormalization()(x)
    x = activation(x)
    if in_channels != out_channels:
        x_1 = CONV[d](out_channels, 1)(x_1)
        if batch_norm:
            x_1 = kl.BatchNormalization()(x_1)
    x = kl.Add()([x, x_1])
    return keras.Model(inputs, x)
def rmsprop(net: keras.src.models.model.Model,
            learning_rate: float = 0.001,
            alpha=0.99,
            eps=1e-08,
            weight_decay=0.0,
            momentum=0.0,
            centered=False)
def rmsprop(net: keras.Model, learning_rate: float = 1e-3, alpha=0.99, eps=1e-08, weight_decay=0., momentum=0., centered=False):
    assert weight_decay == 0
    return keras.optimizers.RMSprop(learning_rate, alpha, momentum, eps, centered)
def save_state(obj: keras.src.models.model.Model | keras.src.optimizers.optimizer.Optimizer,
               path: str)
def save_state(obj: Union[keras.models.Model, keras.optimizers.Optimizer], path: str):
    if isinstance(obj, keras.models.Model):
        if not path.endswith('.h5'):
            path += '.h5'
        obj.save_weights(path)
        return path
    elif isinstance(obj, keras.optimizers.Optimizer):
        if not path.endswith('.pkl'):
            path += '.pkl'
        weights = obj.get_weights()
        with open(path, 'wb') as f:
            pickle.dump(weights, f)
        return path
    else:
        raise ValueError("obj must be a Keras model or optimizer")
def set_learning_rate(optimizer: keras.src.optimizers.optimizer.Optimizer, learning_rate: float)
def set_learning_rate(optimizer: keras.optimizers.Optimizer, learning_rate: float):
    """
    Sets the global learning rate for the given optimizer.

    Args:
        optimizer (keras.optimizers.Optimizer): The optimizer whose learning rate needs to be updated.
        learning_rate (float): The new learning rate to set.
    """
    keras.backend.set_value(optimizer.lr, learning_rate)

Sets the global learning rate for the given optimizer.

Args

optimizer : keras.optimizers.Optimizer
The optimizer whose learning rate needs to be updated.
learning_rate : float
The new learning rate to set.
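A hypothetical step-decay schedule built from the two helpers above (num_epochs and train_one_epoch are placeholders for the actual training loop):

for epoch in range(num_epochs):
    train_one_epoch()  # placeholder: run the optimizer over the training data
    if (epoch + 1) % 5 == 0:
        set_learning_rate(optimizer, get_learning_rate(optimizer) * 0.5)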
def sgd(net: keras.src.models.model.Model,
        learning_rate: float = 0.001,
        momentum=0.0,
        dampening=0.0,
        weight_decay=0.0,
        nesterov=False)
def sgd(net: keras.Model, learning_rate: float = 1e-3, momentum=0., dampening=0., weight_decay=0., nesterov=False):
    assert dampening == 0
    assert weight_decay == 0
    return keras.optimizers.SGD(learning_rate, momentum, nesterov)
def u_net(in_channels: int,
          out_channels: int,
          levels: int = 4,
          filters: int | Sequence = 16,
          batch_norm: bool = True,
          activation: str | Callable = 'ReLU',
          in_spatial: int | tuple = 2,
          periodic=False,
          use_res_blocks: bool = False,
          down_kernel_size=3,
          up_kernel_size=3) -> keras.src.models.model.Model
def u_net(in_channels: int,
          out_channels: int,
          levels: int = 4,
          filters: Union[int, Sequence] = 16,
          batch_norm: bool = True,
          activation: Union[str, Callable] = 'ReLU',
          in_spatial: Union[tuple, int] = 2,
          periodic=False,
          use_res_blocks: bool = False,
          down_kernel_size=3,
          up_kernel_size=3) -> keras.Model:
    activation = ACTIVATIONS[activation] if isinstance(activation, str) else activation
    if isinstance(in_spatial, int):
        d = in_spatial
        in_spatial = (None,) * d
    else:
        assert isinstance(in_spatial, tuple)
        d = len(in_spatial)
    if isinstance(filters, (tuple, list)):
        assert len(filters) == levels, f"List of filters has length {len(filters)} but u-net has {levels} levels."
    else:
        filters = (filters,) * levels
    # --- Construct the U-Net ---
    x = inputs = keras.Input(shape=in_spatial + (in_channels,))
    x = resnet_block(x.shape[-1], filters[0], periodic, batch_norm, activation, d, down_kernel_size)(x) if use_res_blocks else double_conv(x, d, filters[0], filters[0], batch_norm, activation, periodic, down_kernel_size)
    xs = [x]
    for i in range(1, levels):
        x = MAX_POOL[d](2, padding="same")(x)
        x = resnet_block(x.shape[-1], filters[i], periodic, batch_norm, activation, d, down_kernel_size)(x) if use_res_blocks else double_conv(x, d, filters[i], filters[i], batch_norm, activation, periodic, down_kernel_size)
        xs.insert(0, x)
    for i in range(1, levels):
        x = UPSAMPLE[d](2)(x)
        x = kl.Concatenate()([x, xs[i]])
        x = resnet_block(x.shape[-1], filters[i - 1], periodic, batch_norm, activation, d, up_kernel_size)(x) if use_res_blocks else double_conv(x, d, filters[i - 1], filters[i - 1], batch_norm, activation, periodic, up_kernel_size)
    x = CONV[d](out_channels, 1)(x)
    return keras.Model(inputs, x)
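A hypothetical sketch; since every level pools by 2 and the decoder concatenates skip connections, concrete spatial sizes should be divisible by 2**(levels - 1):

net = u_net(in_channels=1, out_channels=1, levels=3, filters=(8, 16, 32), in_spatial=(64, 64))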
def update_weights(net: keras.src.models.model.Model,
                   optimizer: keras.src.optimizers.optimizer.Optimizer,
                   loss_function: Callable,
                   *loss_args,
                   **loss_kwargs)
def update_weights(net: keras.Model, optimizer: keras.optimizers.Optimizer, loss_function: Callable, *loss_args, **loss_kwargs):
    with tf.GradientTape() as tape:
        output = loss_function(*loss_args, **loss_kwargs)
        loss = output[0] if isinstance(output, tuple) else output
        gradients = tape.gradient(loss.sum, net.trainable_variables)
    optimizer.apply_gradients(zip(gradients, net.trainable_variables))
    return output
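A hypothetical training step. The loss function must run the network itself so the GradientTape records the forward pass, and should return a phiml tensor, since the loss is reduced via .sum (x_batch, y_batch and the dimension names are placeholders; math refers to phiml.math as imported in this module):

def loss_fn(x, y):
    prediction = net(x)
    return math.l2_loss(math.wrap(prediction - y, math.batch('examples'), math.channel('vector')))

output = update_weights(net, optimizer, loss_fn, x_batch, y_batch)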

Classes

class CouplingLayer (construct_net: Callable, construction_kwargs: dict, reverse_mask)
class CouplingLayer(keras.Model):

    def __init__(self, construct_net: Callable, construction_kwargs: dict, reverse_mask):
        super().__init__()
        self.reverse_mask = reverse_mask
        self.s1 = construct_net(**construction_kwargs)
        self.t1 = construct_net(**construction_kwargs)
        self.s2 = construct_net(**construction_kwargs)
        self.t2 = construct_net(**construction_kwargs)

    def call(self, x, invert=False):
        mask = tf.cast(get_mask(x, self.reverse_mask, 'NCHW'), x.dtype)
        if invert:
            v1 = x * mask
            v2 = x * (1 - mask)
            u2 = (1 - mask) * (v2 - self.t1(v1)) * tf.math.exp(tf.tanh(-self.s1(v1)))
            u1 = mask * (v1 - self.t2(u2)) * tf.math.exp(tf.tanh(-self.s2(u2)))
            return u1 + u2
        else:
            u1 = x * mask
            u2 = x * (1 - mask)
            v1 = mask * (u1 * tf.math.exp(tf.tanh(self.s2(u2))) + self.t2(u2))
            v2 = (1 - mask) * (u2 * tf.math.exp(tf.tanh(self.s1(v1))) + self.t1(v1))
            return v1 + v2
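In formulas, the forward branch of call() corresponds to a RealNVP-style affine coupling with a tanh-bounded log-scale, where m is the checkerboard mask from get_mask():

v_1 = m \odot \left( u_1 \, e^{\tanh(s_2(u_2))} + t_2(u_2) \right)
v_2 = (1 - m) \odot \left( u_2 \, e^{\tanh(s_1(v_1))} + t_1(v_1) \right)

The invert branch applies the exact inverse; because tanh is odd, the e^{\tanh(-s)} factors in the code equal e^{-\tanh(s)}.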

A model grouping layers into an object with training/inference features.

There are three ways to instantiate a Model:

With the "Functional API"

You start from Input, you chain layer calls to specify the model's forward pass, and finally, you create your model from inputs and outputs:

inputs = keras.Input(shape=(37,))
x = keras.layers.Dense(32, activation="relu")(inputs)
outputs = keras.layers.Dense(5, activation="softmax")(x)
model = keras.Model(inputs=inputs, outputs=outputs)

Note: Only dicts, lists, and tuples of input tensors are supported. Nested inputs are not supported (e.g. lists of lists or dicts of dicts).

A new Functional API model can also be created by using the intermediate tensors. This enables you to quickly extract sub-components of the model.

Example:

inputs = keras.Input(shape=(None, None, 3))
processed = keras.layers.RandomCrop(width=128, height=128)(inputs)
conv = keras.layers.Conv2D(filters=32, kernel_size=3)(processed)
pooling = keras.layers.GlobalAveragePooling2D()(conv)
feature = keras.layers.Dense(10)(pooling)

full_model = keras.Model(inputs, feature)
backbone = keras.Model(processed, conv)
activations = keras.Model(conv, feature)

Note that the backbone and activations models are not created with keras.Input objects, but with the tensors that originate from keras.Input objects. Under the hood, the layers and weights will be shared across these models, so that the user can train the full_model and use backbone or activations for feature extraction. The inputs and outputs of the model can be nested structures of tensors as well, and the created models are standard Functional API models that support all the existing APIs.

By subclassing the Model class

In that case, you should define your layers in __init__() and you should implement the model's forward pass in call().

class MyModel(keras.Model):
    def __init__(self):
        super().__init__()
        self.dense1 = keras.layers.Dense(32, activation="relu")
        self.dense2 = keras.layers.Dense(5, activation="softmax")

    def call(self, inputs):
        x = self.dense1(inputs)
        return self.dense2(x)

model = MyModel()

If you subclass Model, you can optionally have a training argument (boolean) in call(), which you can use to specify a different behavior in training and inference:

class MyModel(keras.Model):
    def __init__(self):
        super().__init__()
        self.dense1 = keras.layers.Dense(32, activation="relu")
        self.dense2 = keras.layers.Dense(5, activation="softmax")
        self.dropout = keras.layers.Dropout(0.5)

    def call(self, inputs, training=False):
        x = self.dense1(inputs)
        x = self.dropout(x, training=training)
        return self.dense2(x)

model = MyModel()

Once the model is created, you can configure it with losses and metrics via model.compile(), train it with model.fit(), or use it for prediction with model.predict().

With the Sequential class

In addition, keras.Sequential is a special case of model where the model is purely a stack of single-input, single-output layers.

model = keras.Sequential([
    keras.Input(shape=(None, None, 3)),
    keras.layers.Conv2D(filters=32, kernel_size=3),
])

Ancestors

  • keras.src.models.model.Model
  • keras.src.backend.tensorflow.trainer.TensorFlowTrainer
  • keras.src.trainers.trainer.Trainer
  • keras.src.layers.layer.Layer
  • keras.src.backend.tensorflow.layer.TFLayer
  • keras.src.backend.tensorflow.trackable.KerasAutoTrackable
  • tensorflow.python.trackable.autotrackable.AutoTrackable
  • tensorflow.python.trackable.base.Trackable
  • keras.src.ops.operation.Operation
  • keras.src.saving.keras_saveable.KerasSaveable

Methods

def call(self, x, invert=False)
def call(self, x, invert=False):
    mask = tf.cast(get_mask(x, self.reverse_mask, 'NCHW'), x.dtype)
    if invert:
        v1 = x * mask
        v2 = x * (1 - mask)
        u2 = (1 - mask) * (v2 - self.t1(v1)) * tf.math.exp(tf.tanh(-self.s1(v1)))
        u1 = mask * (v1 - self.t2(u2)) * tf.math.exp(tf.tanh(-self.s2(u2)))
        return u1 + u2
    else:
        u1 = x * mask
        u2 = x * (1 - mask)
        v1 = mask * (u1 * tf.math.exp(tf.tanh(self.s2(u2))) + self.t2(u2))
        v2 = (1 - mask) * (u2 * tf.math.exp(tf.tanh(self.s1(v1))) + self.t1(v1))
        return v1 + v2
class InvertibleNet (num_blocks: int, construct_net, construction_kwargs: dict)
class InvertibleNet(keras.Model):

    def __init__(self, num_blocks: int, construct_net, construction_kwargs: dict):
        super(InvertibleNet, self).__init__()
        self.num_blocks = num_blocks
        self.layer_dict = {}
        for i in range(num_blocks):
            self.layer_dict[f'coupling_block{i + 1}'] = CouplingLayer(construct_net, construction_kwargs, (i % 2 == 0))

    def call(self, x, backward=False):
        if backward:
            for i in range(self.num_blocks, 0, -1):
                x = self.layer_dict[f'coupling_block{i}'](x, backward)
        else:
            for i in range(1, self.num_blocks + 1):
                x = self.layer_dict[f'coupling_block{i}'](x)
        return x

Ancestors

  • keras.src.models.model.Model
  • keras.src.backend.tensorflow.trainer.TensorFlowTrainer
  • keras.src.trainers.trainer.Trainer
  • keras.src.layers.layer.Layer
  • keras.src.backend.tensorflow.layer.TFLayer
  • keras.src.backend.tensorflow.trackable.KerasAutoTrackable
  • tensorflow.python.trackable.autotrackable.AutoTrackable
  • tensorflow.python.trackable.base.Trackable
  • keras.src.ops.operation.Operation
  • keras.src.saving.keras_saveable.KerasSaveable

Methods

def call(self, x, backward=False)
def call(self, x, backward=False):
    if backward:
        for i in range(self.num_blocks, 0, -1):
            x = self.layer_dict[f'coupling_block{i}'](x, backward)
    else:
        for i in range(1, self.num_blocks + 1):
            x = self.layer_dict[f'coupling_block{i}'](x)
    return x
class PeriodicPad (*,
                   activity_regularizer=None,
                   trainable=True,
                   dtype=None,
                   autocast=True,
                   name=None,
                   **kwargs)
class PeriodicPad(kl.Layer):
    def call(self, x):
        d = len(x.shape) - 2
        if d >= 1:
            x = tf.concat([tf.expand_dims(x[:, -1, ...], axis=1), x, tf.expand_dims(x[:, 0, ...], axis=1)], axis=1)
        if d >= 2:
            x = tf.concat([tf.expand_dims(x[:, :, -1, ...], axis=2), x, tf.expand_dims(x[:, :, 0, ...], axis=2)], axis=2)
        if d >= 3:
            x = tf.concat([tf.expand_dims(x[:, :, :, -1, ...], axis=3), x, tf.expand_dims(x[:, :, :, 0, ...], axis=3)], axis=3)
        return x

This is the class from which all layers inherit.

A layer is a callable object that takes as input one or more tensors and that outputs one or more tensors. It involves computation, defined in the call() method, and a state (weight variables). State can be created:

  • in __init__(), for instance via self.add_weight();
  • in the optional build() method, which is invoked by the first __call__() to the layer, and supplies the shape(s) of the input(s), which may not have been known at initialization time.

Layers are recursively composable: If you assign a Layer instance as an attribute of another Layer, the outer layer will start tracking the weights created by the inner layer. Nested layers should be instantiated in the __init__() method or build() method.

Users will just instantiate a layer and then treat it as a callable.

Args

trainable
Boolean, whether the layer's variables should be trainable.
name
String name of the layer.
dtype
The dtype of the layer's computations and weights. Can also be a keras.DTypePolicy, which allows the computation and weight dtype to differ. Defaults to None. None means to use keras.config.dtype_policy(), which is a float32 policy unless set to different value (via keras.config.set_dtype_policy()).

Attributes

name
The name of the layer (string).
dtype
Dtype of the layer's weights. Alias of layer.variable_dtype.
variable_dtype
Dtype of the layer's weights.
compute_dtype
The dtype of the layer's computations. Layers automatically cast inputs to this dtype, which causes the computations and output to also be in this dtype. When mixed precision is used with a keras.DTypePolicy, this will be different than variable_dtype.
trainable_weights
List of variables to be included in backprop.
non_trainable_weights
List of variables that should not be included in backprop.
weights
The concatenation of the lists trainable_weights and non_trainable_weights (in this order).
trainable
Whether the layer should be trained (boolean), i.e. whether its potentially-trainable weights should be returned as part of layer.trainable_weights.
input_spec
Optional (list of) InputSpec object(s) specifying the constraints on inputs that can be accepted by the layer.

We recommend that descendants of Layer implement the following methods:

  • __init__(): Defines custom layer attributes, and creates layer weights that do not depend on input shapes, using add_weight(), or other state.
  • build(self, input_shape): This method can be used to create weights that depend on the shape(s) of the input(s), using add_weight(), or other state. __call__() will automatically build the layer (if it has not been built yet) by calling build().
  • call(self, *args, **kwargs): Called in __call__ after making sure build() has been called. call() performs the logic of applying the layer to the input arguments. Two reserved keyword arguments you can optionally use in call() are: 1. training (boolean, whether the call is in inference mode or training mode). 2. mask (boolean tensor encoding masked timesteps in the input, used e.g. in RNN layers). A typical signature for this method is call(self, inputs), and users can optionally add training and mask if the layer needs them.
  • get_config(self): Returns a dictionary containing the configuration used to initialize this layer. If the keys differ from the arguments in __init__(), then override from_config(self) as well. This method is used when saving the layer or a model that contains this layer.

Examples:

Here's a basic example: a layer with two variables, w and b, that returns y = w . x + b. It shows how to implement build() and call(). Variables set as attributes of a layer are tracked as weights of the layers (in layer.weights).

class SimpleDense(Layer):
    def __init__(self, units=32):
        super().__init__()
        self.units = units

    # Create the state of the layer (weights)
    def build(self, input_shape):
        self.kernel = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="glorot_uniform",
            trainable=True,
            name="kernel",
        )
        self.bias = self.add_weight(
            shape=(self.units,),
            initializer="zeros",
            trainable=True,
            name="bias",
        )

    # Defines the computation
    def call(self, inputs):
        return ops.matmul(inputs, self.kernel) + self.bias

# Instantiates the layer.
linear_layer = SimpleDense(4)

# This will also call `build(input_shape)` and create the weights.
y = linear_layer(ops.ones((2, 2)))
assert len(linear_layer.weights) == 2

# These weights are trainable, so they're listed in `trainable_weights`:
assert len(linear_layer.trainable_weights) == 2

Besides trainable weights, updated via backpropagation during training, layers can also have non-trainable weights. These weights are meant to be updated manually during call(). Here's an example layer that computes the running sum of its inputs:

class ComputeSum(Layer):

  def __init__(self, input_dim):
      super(ComputeSum, self).__init__()
      # Create a non-trainable weight.
      self.total = self.add_weight(
        shape=(),
        initializer="zeros",
        trainable=False,
        name="total",
      )

  def call(self, inputs):
      self.total.assign(self.total + ops.sum(inputs))
      return self.total

my_sum = ComputeSum(2)
x = ops.ones((2, 2))
y = my_sum(x)

assert my_sum.weights == [my_sum.total]
assert my_sum.non_trainable_weights == [my_sum.total]
assert my_sum.trainable_weights == []

Ancestors

  • keras.src.layers.layer.Layer
  • keras.src.backend.tensorflow.layer.TFLayer
  • keras.src.backend.tensorflow.trackable.KerasAutoTrackable
  • tensorflow.python.trackable.autotrackable.AutoTrackable
  • tensorflow.python.trackable.base.Trackable
  • keras.src.ops.operation.Operation
  • keras.src.saving.keras_saveable.KerasSaveable

Methods

def call(self, x)
def call(self, x):
    d = len(x.shape) - 2
    if d >= 1:
        x = tf.concat([tf.expand_dims(x[:, -1, ...], axis=1), x, tf.expand_dims(x[:, 0, ...], axis=1)], axis=1)
    if d >= 2:
        x = tf.concat([tf.expand_dims(x[:, :, -1, ...], axis=2), x, tf.expand_dims(x[:, :, 0, ...], axis=2)], axis=2)
    if d >= 3:
        x = tf.concat([tf.expand_dims(x[:, :, :, -1, ...], axis=3), x, tf.expand_dims(x[:, :, :, 0, ...], axis=3)], axis=3)
    return x
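A hypothetical check; each spatial dimension grows by two because one slice is wrapped around from each side, so a kernel-size-3 'valid' convolution after this layer amounts to a periodic convolution:

import tensorflow as tf
x = tf.random.uniform([1, 4, 4, 2])  # NHWC
PeriodicPad()(x).shape               # (1, 6, 6, 2)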