Module phiml.backend.jax.stax_nets

Stax implementation of the unified machine learning API. Equivalent functions also exist for the other frameworks.

For API documentation, see https://tum-pbs.github.io/PhiML/Network_API .

Functions

def Dense_resnet_block(in_channels: int,
mid_channels: int,
batch_norm: bool = False,
activation: str | Callable = 'ReLU')
Expand source code
def Dense_resnet_block(in_channels: int,
                       mid_channels: int,
                       batch_norm: bool = False,
                       activation: Union[str, Callable] = 'ReLU'):
    """
    Create a fully connected residual block as a Stax `(init_fn, apply_fn)` pair.

    The block applies `Dense(mid_channels)` and then `Dense(in_channels)`, each followed by optional batch normalization and the activation, and finally adds the block input to the result.

    Args:
        in_channels: Number of input features. The second dense layer maps back to this size so that the skip connection can be added.
        mid_channels: Number of features produced by the first dense layer.
        batch_norm: Whether to insert batch normalization after each dense layer.
        activation: Key into `ACTIVATIONS` or a Stax `(init_fn, apply_fn)` layer pair.

    Returns:
        `(net_init, net_apply)` following the Stax layer convention.
    """
    activation = ACTIVATIONS[activation] if isinstance(activation, str) else activation

    def dense_stage(features):
        # Honor `batch_norm`, mirroring the `... if batch_norm else stax.Identity` pattern of the conv blocks.
        norm = stax.BatchNorm(axis=(0,)) if batch_norm else stax.Identity
        return stax.serial(stax.Dense(features), norm, activation)

    init_fn, apply_fn = {}, {}
    init_fn['dense1'], apply_fn['dense1'] = dense_stage(mid_channels)
    init_fn['dense2'], apply_fn['dense2'] = dense_stage(in_channels)
    init_activation, _ = activation

    def net_init(rng, input_shape):
        params = {}
        # One independent key per initialized sub-layer (dense1, dense2, activation).
        rngs = random.split(rng, 3)
        shape, params['dense1'] = init_fn['dense1'](rngs[0], input_shape)
        shape, params['dense2'] = init_fn['dense2'](rngs[1], shape)
        # The 'activation' entry is kept in the parameter tree although net_apply does not use it.
        shape, params['activation'] = init_activation(rngs[2], shape)
        return shape, params

    def net_apply(params, inputs, **kwargs):
        out = apply_fn['dense1'](params['dense1'], inputs)
        out = apply_fn['dense2'](params['dense2'], out)
        return jnp.add(out, inputs)  # residual connection

    return net_init, net_apply
def adagrad(net: StaxNet,
learning_rate: float = 0.001,
lr_decay=0.0,
weight_decay=0.0,
initial_accumulator_value=0.0,
eps=1e-10)
Expand source code
def adagrad(net: StaxNet, learning_rate: float = 1e-3, lr_decay=0., weight_decay=0., initial_accumulator_value=0., eps=1e-10):
    """
    Create an Adagrad optimizer and initialize it with the parameters of `net`.

    Only `learning_rate` is forwarded to `optim.adagrad`. All other arguments must keep their default values, otherwise an `AssertionError` is raised.

    Args:
        net: Network whose `parameters` are used to initialize the optimizer.
        learning_rate: Step size passed to `optim.adagrad`.
        lr_decay: Must be 0.
        weight_decay: Must be 0.
        initial_accumulator_value: Must be 0.
        eps: Must be 1e-10.

    Returns:
        Initialized `JaxOptimizer`.
    """
    # Checked in signature order; each entry is (given value, only accepted value).
    for given, accepted in ((lr_decay, 0), (weight_decay, 0), (initial_accumulator_value, 0), (eps, 1e-10)):
        assert given == accepted
    optimizer = JaxOptimizer(*optim.adagrad(learning_rate))
    optimizer.initialize(net.parameters)
    return optimizer
def adam(net: StaxNet,
learning_rate: float = 0.001,
betas=(0.9, 0.999),
epsilon=1e-07)
Expand source code
def adam(net: StaxNet, learning_rate: float = 1e-3, betas=(0.9, 0.999), epsilon=1e-07):
    """
    Create an Adam optimizer and initialize it with the parameters of `net`.

    Args:
        net: Network whose `parameters` are used to initialize the optimizer.
        learning_rate: Step size passed to `optim.adam`.
        betas: Sequence whose first two entries are passed to `optim.adam` after the learning rate.
        epsilon: Passed to `optim.adam` as last argument.

    Returns:
        Initialized `JaxOptimizer`.
    """
    beta1 = betas[0]
    beta2 = betas[1]
    optimizer = JaxOptimizer(*optim.adam(learning_rate, beta1, beta2, epsilon))
    optimizer.initialize(net.parameters)
    return optimizer
def conv_classifier(in_features: int,
in_spatial: tuple | list,
num_classes: int,
blocks=(64, 128, 256, 256, 512, 512),
block_sizes=(2, 2, 3, 3, 3),
dense_layers=(4096, 4096, 100),
batch_norm=True,
activation='ReLU',
softmax=True,
periodic=False)
Expand source code
def conv_classifier(in_features: int,
                    in_spatial: Union[tuple, list],
                    num_classes: int,
                    blocks=(64, 128, 256, 256, 512, 512),
                    block_sizes=(2, 2, 3, 3, 3),
                    dense_layers=(4096, 4096, 100),
                    batch_norm=True,
                    activation='ReLU',
                    softmax=True,
                    periodic=False):
    """
    Based on VGG16.

    Builds a stack of convolutional blocks, each followed by max pooling, then a flatten step and a dense classification head.

    Args:
        in_features: Number of input channels.
        in_spatial: Spatial input size as `tuple` or `list`, or the number of spatial dimensions as `int`.
        num_classes: Number of output features of the final dense layer.
        blocks: Output channels of the convolutional blocks.
        block_sizes: Number of convolutions per block. It is paired with `blocks` via `zip`, so surplus entries of the longer sequence are ignored.
        dense_layers: Neuron counts of the dense layers preceding the output layer.
        batch_norm: Whether to use batch normalization in the convolutional and dense parts.
        activation: Key into `ACTIVATIONS` or a Stax `(init_fn, apply_fn)` layer pair.
        softmax: Whether to apply a softmax over the last axis of the output.
        periodic: Pad with mode `'wrap'` instead of `'constant'` before each convolution.

    Returns:
        Initialized `StaxNet`.
    """
    if isinstance(in_spatial, int):
        d = in_spatial
        in_spatial = (1,) * d
    else:
        assert isinstance(in_spatial, (tuple, list))
        # Lists are accepted per the annotation; a tuple is required for the shape concatenation below.
        in_spatial = tuple(in_spatial)
        d = len(in_spatial)
    activation = ACTIVATIONS[activation] if isinstance(activation, str) else activation
    init_fn, apply_fn = {}, {}
    net_list = []  # names of conv / pooling layers in application order
    for i, (out_features, block_size) in enumerate(zip(blocks, block_sizes)):
        for j in range(block_size):
            name = f'conv{i+1}_{j}'
            net_list.append(name)
            init_fn[name], apply_fn[name] = stax.serial(CONV[d](out_features, (3,) * d, padding='valid'),
                                                        stax.BatchNorm(axis=tuple(range(d + 1))) if batch_norm else stax.Identity,
                                                        activation)
        name = f'max_pool{i+1}'
        net_list.append(name)
        init_fn[name], apply_fn[name] = stax.MaxPool((2,) * d, padding='valid', strides=(2,) * d)
    init_fn['flatten'], apply_fn['flatten'] = stax.Flatten
    stax_dense_layers = []
    for neuron_count in dense_layers:
        stax_dense_layers.append(stax.Dense(neuron_count))
        stax_dense_layers.append(activation)
        if batch_norm:
            stax_dense_layers.append(stax.BatchNorm(axis=(0,)))
    stax_dense_layers.append(stax.Dense(num_classes))
    if softmax:
        stax_dense_layers.append(stax.elementwise(stax.softmax, axis=-1))
    dense_init, dense_apply = stax.serial(*stax_dense_layers)

    def net_init(rng, input_shape):
        params = {}
        N = len(net_list)
        # One key per conv / pooling layer plus one each for the flatten step and the dense head.
        rngs = random.split(rng, N + 2)
        shape = input_shape
        for i, layer in enumerate(net_list):
            shape, params[layer] = init_fn[layer](rngs[i], shape)
        shape, params['flatten'] = init_fn['flatten'](rngs[N], shape)
        # The dense input size is derived from `in_spatial` instead of `shape` because the padding of net_apply is not part of the shape inference above.
        # NOTE(review): this uses len(blocks) although only min(len(blocks), len(block_sizes)) pooling layers are created; the default arguments differ in length - confirm the intended size.
        flat_size = int(np.prod(in_spatial) * blocks[-1] / (2**d) ** len(blocks))
        shape, params['dense'] = dense_init(rngs[N + 1], (1,) + (flat_size,))
        return shape, params

    def net_apply(params, inputs, **kwargs):
        x = inputs
        # Pad every spatial axis by one on each side so that the 'valid' 3-wide convolutions keep the spatial size.
        pad_tuple = [[0, 0]] + [[1, 1]] * d + [[0, 0]]
        for layer in net_list:
            if layer.startswith('conv'):
                x = jnp.pad(x, pad_width=pad_tuple, mode='wrap' if periodic else 'constant')
            x = apply_fn[layer](params[layer], x)
        x = apply_fn['flatten'](params['flatten'], x)
        return dense_apply(params['dense'], x, **kwargs)

    net = StaxNet(net_init, net_apply, (1,) + in_spatial + (in_features,))
    net.initialize()
    return net

Based on VGG16.

def conv_net(in_channels: int,
out_channels: int,
layers: Sequence[int],
batch_norm: bool = False,
activation: str | Callable = 'ReLU',
periodic=False,
in_spatial: int | tuple = 2) ‑> StaxNet
Expand source code
def conv_net(in_channels: int,
             out_channels: int,
             layers: Sequence[int],
             batch_norm: bool = False,
             activation: Union[str, Callable] = 'ReLU',
             periodic=False,
             in_spatial: Union[int, tuple] = 2) -> StaxNet:
    """
    Built-in convolutional network. Contrary to classical convolutional neural networks, the spatial size of the feature map remains the same throughout the layers:
    the input of every convolution with filter size 3 is padded by one cell on each side of each spatial axis.
    Each intermediate layer consists of a convolution, optional batch normalization and the activation. A final convolution with filter size 1 maps to `out_channels`.

    Args:
        in_channels: input channels of the feature map, dtype : int
        out_channels: output channels of the feature map, dtype : int
        layers: list or tuple of output channels for each intermediate layer between the input and final output channels. If empty, a single intermediate layer with `out_channels` channels is used. The sequence is not modified.
        batch_norm: use of batchnorm after each conv layer, dtype : bool
        activation: activation function used within the layers, key into `ACTIVATIONS` or Stax layer pair
        periodic: pad with mode `'wrap'` instead of `'constant'`, dtype : bool
        in_spatial: spatial size of the input feature map as tuple, or number of spatial dimensions as int

    Returns:
        Conv-net model as specified by input arguments
    """
    if isinstance(in_spatial, int):
        d = in_spatial
        in_spatial = (1,) * d
    else:
        assert isinstance(in_spatial, tuple)
        d = len(in_spatial)
    activation = ACTIVATIONS[activation] if isinstance(activation, str) else activation
    # Work on a copy so that the caller's sequence is never mutated and tuples are supported.
    layers = list(layers) if len(layers) > 0 else [out_channels]
    # Single source of truth for the layer names, used for construction, initialization and application.
    hidden_names = ['conv_in'] + [f'conv{i}' for i in range(1, len(layers))]
    init_fn, apply_fn = {}, {}
    for name, channels in zip(hidden_names, layers):
        init_fn[name], apply_fn[name] = stax.serial(
            CONV[d](channels, (3,) * d, padding='valid'),
            stax.BatchNorm(axis=tuple(range(d + 1))) if batch_norm else stax.Identity,
            activation)
    init_fn['conv_out'], apply_fn['conv_out'] = CONV[d](out_channels, (1,) * d)

    def net_init(rng, input_shape):
        params = {}
        # One key per hidden layer plus one for the output convolution.
        rngs = random.split(rng, len(hidden_names) + 1)
        shape = input_shape
        for i, name in enumerate(hidden_names):
            shape, params[name] = init_fn[name](rngs[i], shape)
        shape, params['conv_out'] = init_fn['conv_out'](rngs[len(hidden_names)], shape)
        return shape, params

    def net_apply(params, inputs):
        # No padding on the first and last axis, one cell on each side of every spatial axis.
        pad_tuple = [(0, 0)] + [(1, 1)] * d + [(0, 0)]
        out = inputs
        for name in hidden_names:
            out = jnp.pad(out, pad_width=pad_tuple, mode='wrap' if periodic else 'constant')
            out = apply_fn[name](params[name], out)
        return apply_fn['conv_out'](params['conv_out'], out)

    net = StaxNet(net_init, net_apply, (1,) + in_spatial + (in_channels,))
    net.initialize()
    return net

Built-in Conv-Nets are also provided. Contrary to classical convolutional neural networks, the spatial size of the feature map remains the same throughout the layers. Each layer of the network is essentially a convolutional block comprising two conv layers. A filter size of 3 is used in the convolutional layers.

Arguments

- in_channels : input channels of the feature map, dtype : int
- out_channels : output channels of the feature map, dtype : int
- layers : list or tuple of output channels for each intermediate layer between the input and final output channels, dtype : list or tuple
- activation : activation function used within the layers, dtype : string
- batch_norm : use of batchnorm after each conv layer, dtype : bool
- in_spatial : spatial dimensions of the input feature map, dtype : int

Returns

Conv-net model as specified by input arguments

def conv_net_unit(in_channels: int,
out_channels: int,
layers: Sequence[int],
periodic: bool = False,
batch_norm: bool = False,
activation: str | Callable = 'ReLU',
in_spatial: int | tuple = 2,
**kwargs)
Expand source code
def conv_net_unit(in_channels: int,
                  out_channels: int,
                  layers: Sequence[int],
                  periodic: bool = False,
                  batch_norm: bool = False,
                  activation: Union[str, Callable] = 'ReLU',
                  in_spatial: Union[int, tuple] = 2, **kwargs):
    """
    Conv-net unit for Invertible Nets

    Builds the same layer stack as a convolutional network (padded convolutions with filter size 3, optional batch normalization, activation, final convolution with filter size 1) but returns the raw Stax pair instead of an initialized network.

    Args:
        in_channels: Not used by this function.
        out_channels: Output channels of the final convolution.
        layers: Output channels of the intermediate layers. If empty, a single intermediate layer with `out_channels` channels is used. The sequence is not modified.
        periodic: Pad with mode `'wrap'` instead of `'constant'`.
        batch_norm: Whether to use batch normalization after each intermediate convolution.
        activation: Key into `ACTIVATIONS` or a Stax `(init_fn, apply_fn)` layer pair.
        in_spatial: Spatial size as tuple or number of spatial dimensions as int. Only the number of dimensions is used.
        **kwargs: Ignored.

    Returns:
        `(net_init, net_apply)` following the Stax layer convention.
    """
    if isinstance(in_spatial, int):
        d = in_spatial
    else:
        assert isinstance(in_spatial, tuple)
        d = len(in_spatial)
    if isinstance(activation, str):
        activation = ACTIVATIONS[activation]
    # Work on a copy so that the caller's sequence is never mutated and tuples are supported.
    layers = list(layers) if len(layers) > 0 else [out_channels]
    # Single source of truth for the layer names, used for construction, initialization and application.
    hidden_names = ['conv_in'] + [f'conv{i}' for i in range(1, len(layers))]
    init_fn, apply_fn = {}, {}
    for name, channels in zip(hidden_names, layers):
        init_fn[name], apply_fn[name] = stax.serial(
            CONV[d](channels, (3,) * d, padding='valid'),
            stax.BatchNorm(axis=tuple(range(d + 1))) if batch_norm else stax.Identity,
            activation)
    init_fn['conv_out'], apply_fn['conv_out'] = CONV[d](out_channels, (1,) * d)

    def net_init(rng, input_shape):
        params = {}
        # One key per hidden layer plus one for the output convolution.
        rngs = random.split(rng, len(hidden_names) + 1)
        shape = input_shape
        for i, name in enumerate(hidden_names):
            shape, params[name] = init_fn[name](rngs[i], shape)
        shape, params['conv_out'] = init_fn['conv_out'](rngs[len(hidden_names)], shape)
        return shape, params

    def net_apply(params, inputs, **apply_kwargs):
        # Extra keyword arguments are accepted, like in the other layer pairs of this module, and ignored.
        # No padding on the first and last axis, one cell on each side of every spatial axis.
        pad_tuple = [(0, 0)] + [(1, 1)] * d + [(0, 0)]
        out = inputs
        for name in hidden_names:
            out = jnp.pad(out, pad_width=pad_tuple, mode='wrap' if periodic else 'constant')
            out = apply_fn[name](params[name], out)
        return apply_fn['conv_out'](params['conv_out'], out)

    return net_init, net_apply

Conv-net unit for Invertible Nets

def coupling_layer(in_channels: int,
activation: str | Callable = 'ReLU',
batch_norm: bool = False,
in_spatial: int | tuple = 2,
net: str = 'u_net',
reverse_mask: bool = False,
**kwargs)
Expand source code
def coupling_layer(in_channels: int,
                   activation: Union[str, Callable] = 'ReLU',
                   batch_norm: bool = False,
                   in_spatial: Union[int, tuple] = 2,
                   net: str = 'u_net',
                   reverse_mask: bool = False,
                   **kwargs):
    """
    Create an invertible coupling layer as a Stax-style `(init_fn, apply_fn)` pair.

    The input is split into two parts using the mask from `get_mask`. Each part is scaled by `exp(tanh(s))` and shifted by `t`, where `s` and `t` are computed by sub-networks from the other part.
    Calling `net_apply` with `invert=True` applies the corresponding inverse operations in reverse order.

    Args:
        in_channels: Number of input and output channels of the sub-networks.
        activation: Key into `ACTIVATIONS` or a Stax `(init_fn, apply_fn)` layer pair.
        batch_norm: Forwarded to the sub-network constructors.
        in_spatial: Spatial size as tuple or number of spatial dimensions as int. For 0 dimensions, `Dense_resnet_block` sub-networks are used.
        net: Key into `NET` selecting the sub-network constructor for one or more spatial dimensions.
        reverse_mask: Forwarded to `get_mask`.
        **kwargs: Forwarded to the `NET[net]` constructor.

    Returns:
        `(net_init, net_apply)` where `net_apply(params, inputs, invert=False)`.
    """
    if isinstance(in_spatial, int):
        d = in_spatial
    else:
        assert isinstance(in_spatial, tuple)
        d = len(in_spatial)
    activation = ACTIVATIONS[activation] if isinstance(activation, str) else activation
    sub_nets = ('s1', 't1', 's2', 't2')
    init_fn, apply_fn = {}, {}
    for name in sub_nets:
        if d == 0:
            block = Dense_resnet_block(in_channels, in_channels, batch_norm, activation)
            # Scale networks ('s*') get an additional Tanh layer, translation networks ('t*') do not.
            init_fn[name], apply_fn[name] = stax.serial(block, stax.Tanh) if name.startswith('s') else block
        else:
            init_fn[name], apply_fn[name] = NET[net](in_channels=in_channels, out_channels=in_channels, layers=[], batch_norm=batch_norm, activation=activation, in_spatial=in_spatial, **kwargs)

    def net_init(rng, input_shape):
        params = {}
        # One independent key per sub-network.
        rngs = random.split(rng, len(sub_nets))
        shape = input_shape
        for i, name in enumerate(sub_nets):
            shape, params[name] = init_fn[name](rngs[i], input_shape)
        return shape, params

    def net_apply(params, inputs, invert=False):
        x = inputs
        # NOTE(review): the mask is requested in 'NCHW' layout while the padding code elsewhere in this module treats the last axis as channels - confirm the intended layout.
        mask = get_mask(x, reverse_mask, 'NCHW')
        if invert:
            v1 = x * mask
            v2 = x * (1 - mask)
            s1 = apply_fn['s1'](params['s1'], v1)
            t1 = apply_fn['t1'](params['t1'], v1)
            u2 = (1 - mask) * (v2 - t1) * jnp.exp(-jnp.tanh(s1))
            s2 = apply_fn['s2'](params['s2'], u2)
            t2 = apply_fn['t2'](params['t2'], u2)
            u1 = mask * (v1 - t2) * jnp.exp(-jnp.tanh(s2))
            return u1 + u2
        else:
            u1 = x * mask
            u2 = x * (1 - mask)
            s2 = apply_fn['s2'](params['s2'], u2)
            t2 = apply_fn['t2'](params['t2'], u2)
            v1 = mask * (u1 * jnp.exp(jnp.tanh(s2)) + t2)
            s1 = apply_fn['s1'](params['s1'], v1)
            t1 = apply_fn['t1'](params['t1'], v1)
            v2 = (1 - mask) * (u2 * jnp.exp(jnp.tanh(s1)) + t1)
            return v1 + v2
    return net_init, net_apply
def create_double_conv(d: int,
out_channels: int,
mid_channels: int,
batch_norm: bool,
activation: Callable,
periodic: bool,
kernel_size=3)
Expand source code
def create_double_conv(d: int, out_channels: int, mid_channels: int, batch_norm: bool, activation: Callable, periodic: bool, kernel_size=3):
    """
    Create two consecutive convolution stages as a Stax-style `(init_fn, apply_fn)` pair.

    Each stage consists of a convolution, optional batch normalization and the activation.
    The first stage produces `mid_channels` channels, the second one `out_channels`.

    Args:
        d: Number of spatial dimensions, used as key into `CONV`.
        out_channels: Output channels of the second convolution.
        mid_channels: Output channels of the first convolution.
        batch_norm: Whether to insert batch normalization after each convolution.
        activation: Stax `(init_fn, apply_fn)` layer pair.
        periodic: Pad with mode `'wrap'` instead of `'constant'`.
        kernel_size: Filter size of both convolutions along every spatial axis.

    Returns:
        `(net_init, net_apply)`.
    """
    def conv_stage(channels):
        norm = stax.BatchNorm(axis=tuple(range(d + 1))) if batch_norm else stax.Identity
        return stax.serial(CONV[d](channels, (kernel_size,) * d, padding='valid'), norm, activation)

    init_fn, apply_fn = {}, {}
    init_fn['conv1'], apply_fn['conv1'] = conv_stage(mid_channels)
    # The second stage maps to `out_channels`; previously this argument was ignored in favor of `mid_channels`.
    init_fn['conv2'], apply_fn['conv2'] = conv_stage(out_channels)

    def net_init(rng, input_shape):
        params = {}
        rngs = random.split(rng, 2)
        shape, params['conv1'] = init_fn['conv1'](rngs[0], input_shape)
        shape, params['conv2'] = init_fn['conv2'](rngs[1], shape)
        return shape, params

    def net_apply(params, inputs):
        # One cell of padding on each side of every spatial axis before each 'valid' convolution.
        pad_tuple = [[0, 0]] + [[1, 1] for _ in range(d)] + [[0, 0]]
        pad_mode = 'wrap' if periodic else 'constant'
        out = jnp.pad(inputs, pad_width=pad_tuple, mode=pad_mode)
        out = apply_fn['conv1'](params['conv1'], out)
        out = jnp.pad(out, pad_width=pad_tuple, mode=pad_mode)
        out = apply_fn['conv2'](params['conv2'], out)
        return out

    return net_init, net_apply
def create_upsample()
Expand source code
def create_upsample():
    """
    Create an upsampling layer in the form of an `(init_fn, apply_fn)` pair.

    No init function is provided: the first entry of the returned pair is `NotImplemented`.

    Returns:
        `(NotImplemented, upsample_apply)` where `upsample_apply(params, inputs, **kwargs)` wraps `inputs`, applies `math.upsample2x` and returns the native result. `params` and `kwargs` are not used.
    """
    def upsample_apply(params, inputs, **kwargs):
        # First axis is named 'batch', last axis 'vector', all axes in between are spatial and named by their index.
        batch_dim = math.batch('batch')
        spatial_dims = [math.spatial(f'{axis}') for axis in range(len(inputs.shape) - 2)]
        channel_dim = math.channel('vector')
        dims = [batch_dim, *spatial_dims, channel_dim]
        upsampled = math.upsample2x(math.wrap(inputs, dims))
        names = [dim.name for dim in dims]
        return upsampled.native(names)

    return NotImplemented, upsample_apply
def fno(in_channels: int,
out_channels: int,
mid_channels: int,
modes: Sequence[int],
activation: str | type = 'ReLU',
batch_norm: bool = False,
in_spatial: int = 2)
Expand source code
def fno(in_channels: int,
        out_channels: int,
        mid_channels: int,
        modes: Sequence[int],
        activation: Union[str, type] = 'ReLU',
        batch_norm: bool = False,
        in_spatial: int = 2):
    """
    Placeholder for the `fno` network constructor, which is not available in this backend.

    The signature is provided so that the function exists in this module; none of the arguments are used.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError("fno is not implemented for Jax")
def get_mask(inputs, reverse_mask, data_format='NHWC')
Expand source code
def get_mask(inputs, reverse_mask, data_format='NHWC'):
    """
    Compute mask for slicing input feature map for Invertible Nets

    For 2D inputs, the mask alternates between 0 and 1 along the last axis.
    For 4D inputs, the mask is a checkerboard over the two axes selected by `data_format` and has size 1 along the remaining axes.

    Args:
        inputs: Array with 2 or 4 dimensions. Only its `shape` is used.
        reverse_mask: If true, `1 - mask` is returned.
        data_format: `'NCHW'` to use axes 2 and 3 as height and width of a 4D input, otherwise axes 1 and 2 are used. Ignored for 2D inputs.

    Returns:
        Mask of shape `(1, N)` for 2D inputs, `(1, 1, H, W)` for `'NCHW'` and `(1, H, W, 1)` otherwise.

    Raises:
        ValueError: If `inputs` has neither 2 nor 4 dimensions.
    """
    shape = inputs.shape
    if len(shape) == 2:
        N = shape[-1]
        checker = jnp.reshape(jnp.arange(0, N) % 2, (-1, N))
    elif len(shape) == 4:
        channels_first = data_format == 'NCHW'
        H = shape[2] if channels_first else shape[1]
        W = shape[3] if channels_first else shape[2]
        odd_h = (jnp.arange(0, H) % 2).astype(bool)
        odd_w = (jnp.arange(0, W) % 2).astype(bool)
        ind_h = jnp.tile(jnp.expand_dims(odd_h, -1), [1, W])
        ind_w = jnp.tile(jnp.expand_dims(odd_w, 0), [H, 1])
        # A cell is set when exactly one of its row and column indices is odd, which yields the checkerboard.
        checker = jnp.logical_xor(ind_h, ind_w)
        checker = jnp.reshape(checker, [-1, 1, H, W] if channels_first else [-1, H, W, 1])
        checker = checker.astype(jnp.float32)
    else:
        # Report the shape via the `shape` attribute; Jax arrays have no `get_shape()` method.
        raise ValueError('Invalid tensor shape. Dimension of the tensor shape must be 2 (NxD) or 4 (NxCxHxW or NxHxWxC), got {}.'.format(list(shape)))
    if reverse_mask:
        checker = 1 - checker
    return checker

Compute mask for slicing input feature map for Invertible Nets

def get_parameters(model: StaxNet,
wrap=True) ‑> dict
Expand source code
def get_parameters(model: StaxNet, wrap=True) -> dict:
    """
    Collect the parameters of `model` in a dictionary.

    Args:
        model: Network whose `parameters` are passed to `_recursive_add_parameters`.
        wrap: Forwarded to `_recursive_add_parameters`.

    Returns:
        The dictionary filled by `_recursive_add_parameters`, starting from an empty dictionary and an empty prefix tuple.
    """
    collected = {}
    prefix = ()
    _recursive_add_parameters(model.parameters, wrap, prefix, collected)
    return collected
def invertible_net(num_blocks: int, construct_net: str | Callable, **construct_kwargs)
Expand source code
def invertible_net(num_blocks: int,
                   construct_net: Union[str, Callable],
                   **construct_kwargs):  # mlp, u_net, res_net, conv_net
    """
    Placeholder for the invertible network constructor, which is not available in this backend.

    None of the arguments are used.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError("invertible_net is not implemented for Jax")
    # Disabled draft implementation, kept for reference. It refers to names that are not defined in this function (`in_channels`, `activation`, `batch_norm`, `d`, `net`, `kwargs`).
    # if construct_net == 'mlp':
    #     construct_net = '_inv_net_dense_resnet_block'
    # if isinstance(construct_net, str):
    #     construct_net = globals()[construct_net]
    # if 'in_channels' in construct_kwargs and 'out_channels' not in construct_kwargs:
    #     construct_kwargs['out_channels'] = construct_kwargs['in_channels']
    # init_fn, apply_fn = {}, {}
    # for i in range(num_blocks):
    #     init_fn[f'CouplingLayer{i + 1}'], apply_fn[f'CouplingLayer{i + 1}'] = coupling_layer(in_channels, activation, batch_norm, d, net, (i % 2 == 0), **kwargs)
    #
    # def net_init(rng, input_shape):
    #     params = {}
    #     rngs = random.split(rng, 2)
    #     for i in range(num_blocks):
    #         shape, params[f'CouplingLayer{i + 1}'] = init_fn[f'CouplingLayer{i + 1}'](rngs[i], input_shape)
    #     return shape, params
    #
    # def net_apply(params, inputs, invert=False):
    #     out = inputs
    #     if invert:
    #         for i in range(num_blocks, 0, -1):
    #             out = apply_fn[f'CouplingLayer{i}'](params[f'CouplingLayer{i}'], out, invert)
    #     else:
    #         for i in range(1, num_blocks + 1):
    #             out = apply_fn[f'CouplingLayer{i}'](params[f'CouplingLayer{i}'], out)
    #     return out
    #
    # if d == 0:
    #     net = StaxNet(net_init, net_apply, (1,) + (in_channels,))
    # else:
    #     net = StaxNet(net_init, net_apply, (1,) + (1,) * d + (in_channels,))
    # net.initialize()
    # return net
def load_state(obj: StaxNet | JaxOptimizer,
path: str)
Expand source code
def load_state(obj: Union[StaxNet, JaxOptimizer], path: str):
    """
    Restore the state of a network or optimizer from an `.npz` file.

    Args:
        obj: `StaxNet` whose `parameters` are replaced, or otherwise an optimizer whose `_state` is rebuilt from its current network parameters and the stored state.
        path: File path. The suffix `.npz` is appended if missing.
    """
    if not path.endswith('.npz'):
        path += '.npz'
    # SECURITY: `allow_pickle=True` unpickles object arrays, which can execute arbitrary code. Only load files from trusted sources.
    if isinstance(obj, StaxNet):
        # The archive is opened in a `with` block so that the file handle is released right after reading.
        with numpy.load(path, allow_pickle=True) as data:
            x = data['x'].tolist()
        x_flat, tree = tree_flatten(x)
        x_flat = [jnp.array(f) for f in x_flat]
        obj.parameters = tree_unflatten(tree, x_flat)
    else:
        xs, _ = tree_flatten(obj.get_network_parameters())
        with np.load(path, allow_pickle=True) as data:
            state = data['state'].tolist()
        # Each packed entry consists of the current parameter followed by its stored optimizer values.
        packed_state = [(x, *pt) for x, pt in zip(xs, state)]
        obj._state = OptimizerState(packed_state, obj._state.tree_def, obj._state.subtree_defs)
def mlp(in_channels: int,
out_channels: int,
layers: Sequence[int],
batch_norm=False,
activation='ReLU',
softmax=False) ‑> StaxNet
Expand source code
def mlp(in_channels: int,
              out_channels: int,
              layers: Sequence[int],
              batch_norm=False,
              activation='ReLU',
              softmax=False) -> StaxNet:
    """
    Create a fully connected network.

    Every entry of `layers` adds a dense layer followed by the activation and, if enabled, batch normalization. A final dense layer maps to `out_channels`.

    Args:
        in_channels: Number of input features.
        out_channels: Number of output features.
        layers: Neuron counts of the hidden layers.
        batch_norm: Whether to add batch normalization after each hidden activation.
        activation: One of `'ReLU'`, `'Sigmoid'`, `'tanh'`, or a Stax `(init_fn, apply_fn)` layer pair like the other network constructors in this module accept.
        softmax: Whether to apply a softmax over the last axis of the output.

    Returns:
        Initialized `StaxNet`.
    """
    if isinstance(activation, str):
        activation = {'ReLU': stax.Relu, 'Sigmoid': stax.Sigmoid, 'tanh': stax.Tanh}[activation]
    stax_layers = []
    for neuron_count in layers:
        stax_layers.append(stax.Dense(neuron_count))
        stax_layers.append(activation)
        if batch_norm:
            stax_layers.append(stax.BatchNorm(axis=(0,)))
    stax_layers.append(stax.Dense(out_channels))
    if softmax:
        stax_layers.append(stax.elementwise(stax.softmax, axis=-1))
    net_init, net_apply = stax.serial(*stax_layers)
    net = StaxNet(net_init, net_apply, (-1, in_channels))
    net.initialize()
    return net
def res_net(in_channels: int,
out_channels: int,
layers: Sequence[int],
batch_norm: bool = False,
activation: str | Callable = 'ReLU',
periodic=False,
in_spatial: int | tuple = 2) ‑> StaxNet
Expand source code
def res_net(in_channels: int,
            out_channels: int,
            layers: Sequence[int],
            batch_norm: bool = False,
            activation: Union[str, Callable] = 'ReLU',
            periodic=False,
            in_spatial: Union[int, tuple] = 2) -> StaxNet:
    """
    Create a network consisting of a chain of `resnet_block`s.

    The channel counts along the chain are `in_channels`, then the entries of `layers`, then `out_channels`; one block is created per consecutive pair.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels of the last block.
        layers: Output channels of the blocks preceding the last one. May be empty.
        batch_norm: Forwarded to `resnet_block`.
        activation: Key into `ACTIVATIONS` or a Stax `(init_fn, apply_fn)` layer pair.
        periodic: Forwarded to `resnet_block`.
        in_spatial: Spatial input size as tuple, or number of spatial dimensions as int.

    Returns:
        Initialized `StaxNet`.
    """
    if isinstance(in_spatial, int):
        d = in_spatial
        in_spatial = (1,) * d
    else:
        assert isinstance(in_spatial, tuple)
        d = len(in_spatial)
    if isinstance(activation, str):
        activation = ACTIVATIONS[activation]
    # Channel count before the first block, between blocks and after the last block.
    channels = [in_channels, *layers, out_channels]
    blocks = [resnet_block(c_in, c_out, periodic, batch_norm, activation, d)
              for c_in, c_out in zip(channels[:-1], channels[1:])]
    net_init, net_apply = stax.serial(*blocks)
    input_shape = (1,) + in_spatial + (in_channels,)
    net = StaxNet(net_init, net_apply, input_shape)
    net.initialize()
    return net
def res_net_unit(in_channels: int,
out_channels: int,
layers: Sequence[int],
batch_norm: bool = False,
activation: str | Callable = 'ReLU',
periodic=False,
in_spatial: int | tuple = 2,
**kwargs)
Expand source code
def res_net_unit(in_channels: int,
                 out_channels: int,
                 layers: Sequence[int],
                 batch_norm: bool = False,
                 activation: Union[str, Callable] = 'ReLU',
                 periodic=False,
                 in_spatial: Union[int, tuple] = 2, **kwargs):
    """
    Res-net unit for Invertible Nets

    Chains one `resnet_block` per entry of `layers` and appends a convolution with filter size 1 that maps to `out_channels`.

    Args:
        in_channels: Number of input channels of the first block.
        out_channels: Output channels of the final convolution.
        layers: Output channels of the blocks. If empty, a single block with `out_channels` channels is used. The sequence is not modified.
        batch_norm: Forwarded to `resnet_block`.
        activation: Key into `ACTIVATIONS` or a Stax `(init_fn, apply_fn)` layer pair.
        periodic: Forwarded to `resnet_block`.
        in_spatial: Spatial size as tuple or number of spatial dimensions as int. Only the number of dimensions is used.
        **kwargs: Ignored.

    Returns:
        Stax `(init_fn, apply_fn)` pair created by `stax.serial`.
    """
    if isinstance(in_spatial, int):
        d = in_spatial
    else:
        assert isinstance(in_spatial, tuple)
        d = len(in_spatial)
    activation = ACTIVATIONS[activation] if isinstance(activation, str) else activation
    # Work on a copy so that the caller's sequence is never mutated and tuples are supported.
    layers = list(layers) if len(layers) > 0 else [out_channels]
    block_inputs = [in_channels] + layers[:-1]
    stax_layers = [resnet_block(c_in, c_out, periodic, batch_norm, activation, d)
                   for c_in, c_out in zip(block_inputs, layers)]
    stax_layers.append(CONV[d](out_channels, (1,) * d))
    return stax.serial(*stax_layers)

Res-net unit for Invertible Nets

def resnet_block(in_channels: int,
out_channels: int,
periodic: bool,
batch_norm: bool,
activation: str | Callable = 'ReLU',
d: int | tuple = 2,
kernel_size=3)
Expand source code
def resnet_block(in_channels: int,
                 out_channels: int,
                 periodic: bool,
                 batch_norm: bool,
                 activation: Union[str, Callable] = 'ReLU',
                 d: Union[int, tuple] = 2,
                 kernel_size=3):
    """
    Create a convolutional residual block as a Stax `(init_fn, apply_fn)` pair.

    The main path consists of two stages of padded convolution, optional batch normalization and activation.
    The skip path is the identity if `in_channels == out_channels`, otherwise a convolution with filter size 1 and optional batch normalization. Both paths are added.

    Args:
        in_channels: Number of input channels; only used to decide whether the skip path needs a convolution.
        out_channels: Output channels of both convolutions and of the skip path.
        periodic: Pad with mode `'wrap'` instead of `'constant'`.
        batch_norm: Whether to insert batch normalization after the convolutions.
        activation: Key into `ACTIVATIONS` or a Stax `(init_fn, apply_fn)` layer pair.
        d: Number of spatial dimensions, used as key into `CONV`.
        kernel_size: Filter size of the main-path convolutions along every spatial axis.

    Returns:
        `(net_init, net_apply)` following the Stax layer convention.
    """
    activation = ACTIVATIONS[activation] if isinstance(activation, str) else activation

    def batch_norm_or_identity():
        return stax.BatchNorm(axis=tuple(range(d + 1))) if batch_norm else stax.Identity

    init_fn, apply_fn = {}, {}
    for name in ('conv1', 'conv2'):
        init_fn[name], apply_fn[name] = stax.serial(
            CONV[d](out_channels, (kernel_size,) * d, padding='valid'),
            batch_norm_or_identity(),
            activation)
    init_activation, _ = activation
    if in_channels != out_channels:
        init_fn['sample_conv'], apply_fn['sample_conv'] = stax.serial(
            CONV[d](out_channels, (1,) * d, padding='VALID'),
            batch_norm_or_identity())
    else:
        init_fn['sample_conv'], apply_fn['sample_conv'] = stax.Identity

    def net_init(rng, input_shape):
        params = {}
        # One independent key per initialized sub-layer (conv1, conv2, sample_conv, activation).
        rngs = random.split(rng, 4)
        shape, params['conv1'] = init_fn['conv1'](rngs[0], input_shape)
        shape, params['conv2'] = init_fn['conv2'](rngs[1], shape)
        shape, params['sample_conv'] = init_fn['sample_conv'](rngs[2], input_shape)
        # The 'activation' entry is kept in the parameter tree although net_apply does not use it.
        shape, params['activation'] = init_activation(rngs[3], shape)
        return shape, params

    def net_apply(params, inputs, **kwargs):
        x = inputs
        # One cell of padding on each side of every spatial axis before each 'valid' convolution.
        pad_tuple = [[0, 0]] + [[1, 1] for _ in range(d)] + [[0, 0]]
        pad_mode = 'wrap' if periodic else 'constant'
        out = jnp.pad(x, pad_width=pad_tuple, mode=pad_mode)
        out = apply_fn['conv1'](params['conv1'], out)
        out = jnp.pad(out, pad_width=pad_tuple, mode=pad_mode)
        out = apply_fn['conv2'](params['conv2'], out)
        skip_x = apply_fn['sample_conv'](params['sample_conv'], x, **kwargs)
        # No activation is applied after the addition.
        return jnp.add(out, skip_x)

    return net_init, net_apply
def rmsprop(net: StaxNet,
learning_rate: float = 0.001,
alpha=0.99,
eps=1e-08,
weight_decay=0.0,
momentum=0.0,
centered=False)
Expand source code
def rmsprop(net: StaxNet, learning_rate: float = 1e-3, alpha=0.99, eps=1e-08, weight_decay=0., momentum=0., centered=False):
    """
    Create an RMSprop optimizer and initialize it with the parameters of `net`.

    Uses `optim.rmsprop` if `momentum == 0` and `optim.rmsprop_momentum` otherwise.

    Args:
        net: Network whose `parameters` are used to initialize the optimizer.
        learning_rate: Step size, passed as first argument.
        alpha: Passed as second argument.
        eps: Passed as third argument.
        weight_decay: Must be 0, otherwise an `AssertionError` is raised.
        momentum: If non-zero, passed as fourth argument to `optim.rmsprop_momentum`.
        centered: Must be false, otherwise an `AssertionError` is raised.

    Returns:
        Initialized `JaxOptimizer`.
    """
    assert weight_decay == 0
    assert not centered
    if momentum == 0:
        optimizer_functions = optim.rmsprop(learning_rate, alpha, eps)
    else:
        optimizer_functions = optim.rmsprop_momentum(learning_rate, alpha, eps, momentum)
    optimizer = JaxOptimizer(*optimizer_functions)
    optimizer.initialize(net.parameters)
    return optimizer
def save_state(obj: StaxNet | JaxOptimizer,
path: str)
Expand source code
def save_state(obj: Union[StaxNet, JaxOptimizer], path: str):
    """
    Write the state of a network or optimizer to an `.npz` file.

    Args:
        obj: `StaxNet` whose `parameters` are stored under the key `x`, or `JaxOptimizer` whose packed state is stored under the key `state`.
        path: File path. The suffix `.npz` is appended if missing.

    Returns:
        The path including the `.npz` suffix.

    Raises:
        ValueError: If `obj` is neither a `StaxNet` nor a `JaxOptimizer`.
    """
    if not path.endswith('.npz'):
        path = path + '.npz'
    if isinstance(obj, StaxNet):
        numpy.savez(path, x=np.array(obj.parameters, dtype=object), allow_pickle=True)
        return path
    if isinstance(obj, JaxOptimizer):
        data = []
        for packed in obj._state.packed_state:
            if isinstance(packed, (tuple, list)):
                # The first entry of each packed state is skipped; only the remaining entries are stored.
                data.append([np.asarray(value) for value in packed[1:]])
            else:
                data.append([])
        np.savez(path, state=np.array(data, dtype=object), allow_pickle=True)
        return path
    raise ValueError(f"obj must be a network or optimizer but got {type(obj)}")
def sgd(net: StaxNet,
learning_rate: float = 0.001,
momentum=0.0,
dampening=0.0,
weight_decay=0.0,
nesterov=False)
Expand source code
def sgd(net: StaxNet, learning_rate: float = 1e-3, momentum=0., dampening=0., weight_decay=0., nesterov=False):
    """
    Create a stochastic gradient descent optimizer and initialize it with the parameters of `net`.

    Uses `optim.sgd` if `momentum == 0` and `optim.momentum` otherwise.

    Args:
        net: Network whose `parameters` are used to initialize the optimizer.
        learning_rate: Step size, passed as first argument.
        momentum: If non-zero, passed as second argument to `optim.momentum`.
        dampening: Must be 0, otherwise an `AssertionError` is raised.
        weight_decay: Must be 0, otherwise an `AssertionError` is raised.
        nesterov: Must be false, otherwise an `AssertionError` is raised.

    Returns:
        Initialized `JaxOptimizer`.
    """
    assert dampening == 0
    assert weight_decay == 0
    assert not nesterov
    if momentum == 0:
        optimizer_functions = optim.sgd(learning_rate)
    else:
        optimizer_functions = optim.momentum(learning_rate, momentum)
    optimizer = JaxOptimizer(*optimizer_functions)
    optimizer.initialize(net.parameters)
    return optimizer
def u_net(in_channels: int,
out_channels: int,
levels: int = 4,
filters: int | Sequence = 16,
batch_norm: bool = True,
activation='ReLU',
in_spatial: int | tuple = 2,
periodic=False,
use_res_blocks: bool = False,
down_kernel_size=3,
up_kernel_size=3) ‑> StaxNet
Expand source code
def u_net(in_channels: int,
          out_channels: int,
          levels: int = 4,
          filters: Union[int, Sequence] = 16,
          batch_norm: bool = True,
          activation='ReLU',
          in_spatial: Union[tuple, int] = 2,
          periodic=False,
          use_res_blocks: bool = False,
          down_kernel_size=3,
          up_kernel_size=3) -> StaxNet:
    """
    Build and initialize a U-Net as a `StaxNet`.

    The network consists of an input block, `levels - 1` max-pool + `down{i}` blocks,
    `levels - 1` upsample + skip-concatenation + `up{i}` blocks and a final 1x..x1 convolution.

    Args:
        in_channels: Number of input channels, used as last entry of the input shape.
        out_channels: Number of output channels of the final convolution.
        levels: Number of resolution levels.
        filters: Filter count per level. A single `int` is repeated for all levels;
            a tuple/list must have exactly `levels` entries.
        batch_norm: Forwarded to the block constructors.
        activation: Key into `ACTIVATIONS` if given as `str`, otherwise used as-is.
        in_spatial: Number of spatial dimensions (`int`, spatial sizes default to 1)
            or a tuple of spatial sizes.
        periodic: Forwarded to the block constructors.
        use_res_blocks: If true, blocks are built with `resnet_block`, otherwise with `create_double_conv`.
        down_kernel_size: Kernel size forwarded to the input block and all `down{i}` blocks.
        up_kernel_size: Kernel size forwarded to all `up{i}` blocks.

    Returns:
        Initialized `StaxNet` with input shape `(1,) + in_spatial + (in_channels,)`.
    """
    if isinstance(filters, (tuple, list)):
        assert len(filters) == levels, f"List of filters has length {len(filters)} but u-net has {levels} levels."
    else:
        filters = (filters,) * levels
    activation = ACTIVATIONS[activation] if isinstance(activation, str) else activation
    if isinstance(in_spatial, int):
        d = in_spatial
        in_spatial = (1,) * d
    else:
        assert isinstance(in_spatial, tuple)
        d = len(in_spatial)
    # Create layers
    if use_res_blocks:
        inc_init, inc_apply = resnet_block(in_channels, filters[0], periodic, batch_norm, activation, d, down_kernel_size)
    else:
        inc_init, inc_apply = create_double_conv(d, filters[0], filters[0], batch_norm, activation, periodic, down_kernel_size)
    init_functions, apply_functions = {}, {}
    for i in range(1, levels):
        # Encoder blocks use `down_kernel_size`, decoder blocks use `up_kernel_size` in both variants.
        if use_res_blocks:
            init_functions[f'down{i}'], apply_functions[f'down{i}'] = resnet_block(filters[i - 1], filters[i], periodic, batch_norm, activation, d, down_kernel_size)
            init_functions[f'up{i}'], apply_functions[f'up{i}'] = resnet_block(filters[i] + filters[i - 1], filters[i - 1], periodic, batch_norm, activation, d, up_kernel_size)
        else:
            init_functions[f'down{i}'], apply_functions[f'down{i}'] = create_double_conv(d, filters[i], filters[i], batch_norm, activation, periodic, down_kernel_size)
            init_functions[f'up{i}'], apply_functions[f'up{i}'] = create_double_conv(d, filters[i - 1], filters[i - 1], batch_norm, activation, periodic, up_kernel_size)
    outc_init, outc_apply = CONV[d](out_channels, (1,) * d, padding='same')
    max_pool_init, max_pool_apply = stax.MaxPool((2,) * d, padding='same', strides=(2,) * d)
    _, up_apply = create_upsample()

    def net_init(rng, input_shape):
        params = {}
        # Keys are indexed 0 (inc), 1..levels-1 (down), levels+1..2*levels-1 (up) and -1 (outc).
        # JAX clamps out-of-range indices instead of raising, so splitting into too few keys
        # would silently reuse the same key for several layers.
        rngs = random.split(rng, 2 * levels + 1)
        shape = input_shape
        # Layers
        shape, params['inc'] = inc_init(rngs[0], shape)
        shapes = [shape]
        for i in range(1, levels):
            shape, _ = max_pool_init(None, shape)
            shape, params[f'down{i}'] = init_functions[f'down{i}'](rngs[i], shape)
            shapes.insert(0, shape)  # deepest level first, so shapes[i] is the skip connection for up{i}
        for i in range(1, levels):
            # Channel count after concatenating the skip connection with the current features.
            shape = shapes[i][:-1] + (shapes[i][-1] + shape[-1],)
            shape, params[f'up{i}'] = init_functions[f'up{i}'](rngs[levels + i], shape)
        shape, params['outc'] = outc_init(rngs[-1], shape)
        return shape, params

    # no @jax.jit needed here since the user can jit this in the loss_function
    def net_apply(params, inputs, **kwargs):
        x = inputs
        x = inc_apply(params['inc'], x, **kwargs)
        xs = [x]
        for i in range(1, levels):
            x = max_pool_apply(None, x, **kwargs)
            x = apply_functions[f'down{i}'](params[f'down{i}'], x, **kwargs)
            xs.insert(0, x)  # deepest level first, mirrors `shapes` in net_init
        for i in range(1, levels):
            x = up_apply(None, x, **kwargs)
            x = jnp.concatenate([x, xs[i]], axis=-1)
            x = apply_functions[f'up{i}'](params[f'up{i}'], x, **kwargs)
        x = outc_apply(params['outc'], x, **kwargs)
        return x

    net = StaxNet(net_init, net_apply, (1,) + in_spatial + (in_channels,))
    net.initialize()
    return net
def u_net_unit(in_channels: int,
out_channels: int,
levels: int = 4,
filters: int | Sequence = 16,
batch_norm: bool = True,
activation='ReLU',
periodic=False,
in_spatial: int | tuple = 2,
use_res_blocks: bool = False,
**kwargs)
Expand source code
def u_net_unit(in_channels: int,
               out_channels: int,
               levels: int = 4,
               filters: Union[int, Sequence] = 16,
               batch_norm: bool = True,
               activation='ReLU',
               periodic=False,
               in_spatial: Union[tuple, int] = 2,
               use_res_blocks: bool = False, **kwargs):
    """
    U-net unit for Invertible Nets.

    Builds the same layer structure as a U-Net (input block, `levels - 1` down blocks,
    `levels - 1` up blocks with skip concatenation, final 1x..x1 convolution) but returns
    the raw Stax function pair instead of an initialized network.

    Args:
        in_channels: Number of input channels, forwarded to the input block when `use_res_blocks` is set.
        out_channels: Number of output channels of the final convolution.
        levels: Number of resolution levels.
        filters: Filter count per level. A single `int` is repeated for all levels;
            a tuple/list must have exactly `levels` entries.
        batch_norm: Forwarded to the block constructors.
        activation: Key into `ACTIVATIONS` if given as `str`, otherwise used as-is.
        periodic: Forwarded to the block constructors.
        in_spatial: Number of spatial dimensions (`int`) or a tuple whose length gives that number.
        use_res_blocks: If true, blocks are built with `resnet_block`, otherwise with `create_double_conv`.
        **kwargs: Accepted but not used by this function.

    Returns:
        Tuple `(net_init, net_apply)`.
    """
    if isinstance(filters, (tuple, list)):
        assert len(filters) == levels, f"List of filters has length {len(filters)} but u-net has {levels} levels."
    else:
        filters = (filters,) * levels
    activation = ACTIVATIONS[activation] if isinstance(activation, str) else activation
    if isinstance(in_spatial, int):
        d = in_spatial
    else:
        assert isinstance(in_spatial, tuple)
        d = len(in_spatial)
    # Create layers
    if use_res_blocks:
        inc_init, inc_apply = resnet_block(in_channels, filters[0], periodic, batch_norm, activation, d)
    else:
        inc_init, inc_apply = create_double_conv(d, filters[0], filters[0], batch_norm, activation, periodic)
    init_functions, apply_functions = {}, {}
    for i in range(1, levels):
        if use_res_blocks:
            init_functions[f'down{i}'], apply_functions[f'down{i}'] = resnet_block(filters[i - 1], filters[i], periodic, batch_norm, activation, d)
            init_functions[f'up{i}'], apply_functions[f'up{i}'] = resnet_block(filters[i] + filters[i - 1], filters[i - 1], periodic, batch_norm, activation, d)
        else:
            init_functions[f'down{i}'], apply_functions[f'down{i}'] = create_double_conv(d, filters[i], filters[i], batch_norm, activation, periodic)
            init_functions[f'up{i}'], apply_functions[f'up{i}'] = create_double_conv(d, filters[i - 1], filters[i - 1], batch_norm, activation, periodic)
    outc_init, outc_apply = CONV[d](out_channels, (1,) * d, padding='same')
    max_pool_init, max_pool_apply = stax.MaxPool((2,) * d, padding='same', strides=(2,) * d)
    _, up_apply = create_upsample()

    def net_init(rng, input_shape):
        params = {}
        # Keys are indexed 0 (inc), 1..levels-1 (down), levels+1..2*levels-1 (up) and -1 (outc).
        # JAX clamps out-of-range indices instead of raising, so splitting into too few keys
        # would silently reuse the same key for several layers.
        rngs = random.split(rng, 2 * levels + 1)
        shape = input_shape
        # Layers
        shape, params['inc'] = inc_init(rngs[0], shape)
        shapes = [shape]
        for i in range(1, levels):
            shape, _ = max_pool_init(None, shape)
            shape, params[f'down{i}'] = init_functions[f'down{i}'](rngs[i], shape)
            shapes.insert(0, shape)  # deepest level first, so shapes[i] is the skip connection for up{i}
        for i in range(1, levels):
            # Channel count after concatenating the skip connection with the current features.
            shape = shapes[i][:-1] + (shapes[i][-1] + shape[-1],)
            shape, params[f'up{i}'] = init_functions[f'up{i}'](rngs[levels + i], shape)
        shape, params['outc'] = outc_init(rngs[-1], shape)
        return shape, params

    # no @jax.jit needed here since the user can jit this in the loss_function
    def net_apply(params, inputs, **kwargs):
        x = inputs
        x = inc_apply(params['inc'], x, **kwargs)
        xs = [x]
        for i in range(1, levels):
            x = max_pool_apply(None, x, **kwargs)
            x = apply_functions[f'down{i}'](params[f'down{i}'], x, **kwargs)
            xs.insert(0, x)  # deepest level first, mirrors `shapes` in net_init
        for i in range(1, levels):
            x = up_apply(None, x, **kwargs)
            x = jnp.concatenate([x, xs[i]], axis=-1)
            x = apply_functions[f'up{i}'](params[f'up{i}'], x, **kwargs)
        x = outc_apply(params['outc'], x, **kwargs)
        return x

    return net_init, net_apply

U-net unit for Invertible Nets

def update_weights(net: StaxNet,
optimizer: JaxOptimizer,
loss_function: Callable,
*loss_args,
**loss_kwargs)
Expand source code
def update_weights(net: StaxNet, optimizer: JaxOptimizer, loss_function: Callable, *loss_args, **loss_kwargs):
    """
    Run one optimization step by delegating to `optimizer.update`.

    Args:
        net: Network to optimize; its current `parameters` are passed to the optimizer.
        optimizer: Optimizer performing the update.
        loss_function: Loss function handed to the optimizer.
        *loss_args: Positional arguments, passed on as a tuple.
        **loss_kwargs: Keyword arguments, passed on as a dict.

    Returns:
        Whatever `optimizer.update` returns.
    """
    current_parameters = net.parameters
    loss_output = optimizer.update(net, loss_function, current_parameters, loss_args, loss_kwargs)
    return loss_output

Classes

class JaxOptimizer (initialize: Callable, update: Callable, get_params: Callable)
Expand source code
class JaxOptimizer:
    """
    Stateful wrapper around a triple of optimizer functions `(initialize, update, get_params)`.

    Holds the current optimizer state, a step counter and a cache of update functions
    keyed by loss function.
    """

    def __init__(self, initialize: Callable, update: Callable, get_params: Callable):
        """Store the three optimizer functions; the state stays `None` until `initialize()` is called."""
        self._initialize, self._update, self._get_params = initialize, update, get_params  # Stax functions
        self._state = None  # List[Tuple[T,T,T]]: (parameter, m, v)
        self._step_i = 0  # step index passed to `self._update`
        self._update_function_cache = {}  # loss_function -> (possibly jit-compiled) update function

    def initialize(self, net: tuple):
        """Create the optimizer state by calling the stored `initialize` function on `net`."""
        self._state = self._initialize(net)

    def update_step(self, grads: tuple):
        """Apply `grads` to the current state and advance the step counter by one."""
        self._state = self._update(self._step_i, grads, self._state)
        self._step_i += 1

    def get_network_parameters(self):
        """Return the parameters extracted from the current state via the stored `get_params` function."""
        return self._get_params(self._state)

    def update(self, net: StaxNet, loss_function, wrt, loss_args, loss_kwargs):
        """
        Evaluate `loss_function`, compute gradients with respect to the network parameters
        and write the updated parameters back to `net.parameters`.

        Args:
            net: Network whose parameters are optimized.
            loss_function: Loss function; also used as key of the update-function cache.
            wrt: Not referenced in this method.
            loss_args: Positional arguments for `loss_function`.
            loss_kwargs: Keyword arguments for `loss_function`.

        Returns:
            The value returned by the gradient function for the loss.
        """
        if loss_function not in self._update_function_cache:
            @functools.wraps(loss_function)
            def update(packed_current_state, *loss_args, **loss_kwargs):
                @functools.wraps(loss_function)
                def loss_depending_on_net(params_tracer: tuple, *args, **kwargs):
                    # While set, `net._tracers` is what the network is expected to use as parameters.
                    net._tracers = params_tracer
                    # presumably `.f` is the un-jitted function wrapped by JitFunction - TODO confirm
                    loss_function_non_jit = loss_function.f if isinstance(loss_function, JitFunction) else loss_function
                    result = loss_function_non_jit(*args, **kwargs)
                    net._tracers = None
                    return result
                gradient_function = math.gradient(loss_depending_on_net)
                current_state = OptimizerState(packed_current_state, self._state.tree_def, self._state.subtree_defs)
                current_params = self._get_params(current_state)
                value, grads = gradient_function(current_params, *loss_args, **loss_kwargs)
                # NOTE(review): the update is applied to `self._state`, not to `current_state` built above,
                # and `self._step_i` is not advanced in this code path - confirm this is intended.
                # presumably grads[0] is the gradient w.r.t. `current_params` - TODO confirm
                next_state = self._update(self._step_i, grads[0], self._state)
                return next_state.packed_state, value
            if isinstance(loss_function, JitFunction):
                update = math.jit_compile(update)
            self._update_function_cache[loss_function] = update
        xs, _ = tree_flatten(net.parameters)
        # Rebuild the packed state with the network's current parameters as first element of each entry,
        # keeping the remaining elements of the stored state.
        packed_state = [(x, *pt[1:]) if isinstance(pt, (tuple, list)) else (x,) for x, pt in zip(xs, self._state.packed_state)]
        next_packed_state, loss_output = self._update_function_cache[loss_function](packed_state, *loss_args, **loss_kwargs)
        self._state = OptimizerState(next_packed_state, self._state.tree_def, self._state.subtree_defs)
        net.parameters = self.get_network_parameters()
        return loss_output

Methods

def get_network_parameters(self)
Expand source code
def get_network_parameters(self):
    """Return the parameters extracted from the current state via the stored `get_params` function."""
    current_state = self._state
    return self._get_params(current_state)
def initialize(self, net: tuple)
Expand source code
def initialize(self, net: tuple):
    """Create the optimizer state by calling the stored `initialize` function on `net`."""
    initial_state = self._initialize(net)
    self._state = initial_state
def update(self,
net: StaxNet,
loss_function,
wrt,
loss_args,
loss_kwargs)
Expand source code
def update(self, net: StaxNet, loss_function, wrt, loss_args, loss_kwargs):
    """
    Evaluate `loss_function`, compute gradients with respect to the network parameters
    and write the updated parameters back to `net.parameters`.

    Args:
        net: Network whose parameters are optimized.
        loss_function: Loss function; also used as key of the update-function cache.
        wrt: Not referenced in this method.
        loss_args: Positional arguments for `loss_function`.
        loss_kwargs: Keyword arguments for `loss_function`.

    Returns:
        The value returned by the gradient function for the loss.
    """
    if loss_function not in self._update_function_cache:
        @functools.wraps(loss_function)
        def update(packed_current_state, *loss_args, **loss_kwargs):
            @functools.wraps(loss_function)
            def loss_depending_on_net(params_tracer: tuple, *args, **kwargs):
                # While set, `net._tracers` is what the network is expected to use as parameters.
                net._tracers = params_tracer
                # presumably `.f` is the un-jitted function wrapped by JitFunction - TODO confirm
                loss_function_non_jit = loss_function.f if isinstance(loss_function, JitFunction) else loss_function
                result = loss_function_non_jit(*args, **kwargs)
                net._tracers = None
                return result
            gradient_function = math.gradient(loss_depending_on_net)
            current_state = OptimizerState(packed_current_state, self._state.tree_def, self._state.subtree_defs)
            current_params = self._get_params(current_state)
            value, grads = gradient_function(current_params, *loss_args, **loss_kwargs)
            # NOTE(review): the update is applied to `self._state`, not to `current_state` built above,
            # and `self._step_i` is not advanced in this code path - confirm this is intended.
            # presumably grads[0] is the gradient w.r.t. `current_params` - TODO confirm
            next_state = self._update(self._step_i, grads[0], self._state)
            return next_state.packed_state, value
        if isinstance(loss_function, JitFunction):
            update = math.jit_compile(update)
        self._update_function_cache[loss_function] = update
    xs, _ = tree_flatten(net.parameters)
    # Rebuild the packed state with the network's current parameters as first element of each entry,
    # keeping the remaining elements of the stored state.
    packed_state = [(x, *pt[1:]) if isinstance(pt, (tuple, list)) else (x,) for x, pt in zip(xs, self._state.packed_state)]
    next_packed_state, loss_output = self._update_function_cache[loss_function](packed_state, *loss_args, **loss_kwargs)
    self._state = OptimizerState(next_packed_state, self._state.tree_def, self._state.subtree_defs)
    net.parameters = self.get_network_parameters()
    return loss_output
def update_step(self, grads: tuple)
Expand source code
def update_step(self, grads: tuple):
    """Apply `grads` to the current state and advance the step counter by one."""
    step = self._step_i
    self._state = self._update(step, grads, self._state)
    # The counter only advances after the update succeeded.
    self._step_i = step + 1
class StaxNet (initialize: Callable, apply: Callable, input_shape: tuple)
Expand source code
class StaxNet:
    """
    Callable network built from a Stax-style `(initialize, apply)` function pair.

    `parameters` holds the current network parameters. While `_tracers` is set,
    calls use `_tracers` instead of `parameters`.
    """

    def __init__(self, initialize: Callable, apply: Callable, input_shape: tuple):
        """
        Args:
            initialize: Function called as `initialize(rng_key, input_shape=...)`, returning `(output_shape, params)`.
            apply: Function called as `apply(params, *args)`.
            input_shape: Input shape passed to `initialize`.
        """
        self._initialize = initialize
        self._apply = apply
        self._input_shape = input_shape
        self.parameters = None  # set by initialize()
        self._tracers = None  # when not None, used in place of `parameters` in __call__

    def initialize(self):
        """Draw a fresh key from `JAX.rnd_key` and (re-)create `parameters`."""
        rnd_key = JAX.rnd_key
        JAX.rnd_key, init_key = random.split(rnd_key)
        _, params64 = self._initialize(init_key, input_shape=self._input_shape)
        if math.get_precision() < 64:
            self.parameters = _recursive_to_float32(params64)
        else:
            # Without this branch the parameters would stay None at 64-bit precision.
            self.parameters = params64

    def __call__(self, *args, **kwargs):
        """
        Apply the network to `args`.

        Keyword arguments are accepted but not forwarded to the apply function.
        """
        active_params = self.parameters if self._tracers is None else self._tracers
        return self._apply(active_params, *args)

Methods

def initialize(self)
Expand source code
def initialize(self):
    """Draw a fresh key from `JAX.rnd_key` and (re-)create `parameters`."""
    rnd_key = JAX.rnd_key
    JAX.rnd_key, init_key = random.split(rnd_key)
    _, params64 = self._initialize(init_key, input_shape=self._input_shape)
    if math.get_precision() < 64:
        self.parameters = _recursive_to_float32(params64)
    else:
        # Without this branch the parameters would stay None at 64-bit precision.
        self.parameters = params64