Module phiml.backend.jax.stax_nets

Stax implementation of the unified machine learning API. Equivalent functions also exist for the other frameworks.

For API documentation, see https://tum-pbs.github.io/PhiML/Network_API .

Functions

def Dense_resnet_block(in_channels: int, mid_channels: int, batch_norm: bool = False, activation: Union[str, Callable] = 'ReLU')
def adagrad(net: StaxNet, learning_rate: float = 0.001, lr_decay=0.0, weight_decay=0.0, initial_accumulator_value=0.0, eps=1e-10)
def adam(net: StaxNet, learning_rate: float = 0.001, betas=(0.9, 0.999), epsilon=1e-07)
def conv_classifier(in_features: int, in_spatial: Union[tuple, list], num_classes: int, blocks=(64, 128, 256, 256, 512, 512), block_sizes=(2, 2, 3, 3, 3), dense_layers=(4096, 4096, 100), batch_norm=True, activation='ReLU', softmax=True, periodic=False)

Based on VGG16.
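
A minimal construction sketch; the channel count, resolution and number of classes below are placeholder values, and the spatial resolution is chosen so that it survives the VGG-style downsampling of the default blocks:

from phiml.backend.jax.stax_nets import conv_classifier

# Hypothetical 10-class classifier on single-channel 32x32 inputs, using the default VGG16-style blocks.
net = conv_classifier(in_features=1, in_spatial=(32, 32), num_classes=10)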

def conv_net(in_channels: int, out_channels: int, layers: Sequence[int], batch_norm: bool = False, activation: Union[str, Callable] = 'ReLU', periodic=False, in_spatial: Union[int, tuple] = 2) ‑> StaxNet

Builds a convolutional network in which, unlike in classical convolutional neural networks, the spatial size of the feature map stays the same throughout the layers. Each layer of the network is a convolutional block consisting of two conv layers with a kernel size of 3.

Arguments

in_channels : input channels of the feature map, dtype : int
out_channels : output channels of the feature map, dtype : int
layers : list or tuple of output channels for each intermediate layer between the input and final output channels, dtype : list or tuple
activation : activation function used within the layers, dtype : string
batch_norm : use of batch norm after each conv layer, dtype : bool
in_spatial : spatial dimensions of the input feature map, dtype : int

Returns

Conv-net model as specified by input arguments
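
A minimal construction sketch; the channel counts and intermediate layer sizes are placeholders, not recommended values:

from phiml.backend.jax.stax_nets import conv_net

# 2D conv net mapping 1 input channel to 2 output channels through three intermediate blocks.
net = conv_net(in_channels=1, out_channels=2, layers=[16, 32, 16], activation='ReLU', batch_norm=False, in_spatial=2)
# The returned StaxNet can be passed to the optimizer factories (adam, sgd, ...) and to update_weights below.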

def conv_net_unit(in_channels: int, out_channels: int, layers: Sequence[int], periodic: bool = False, batch_norm: bool = False, activation: Union[str, Callable] = 'ReLU', in_spatial: Union[int, tuple] = 2, **kwargs)

Conv-net unit for Invertible Nets

def coupling_layer(in_channels: int, activation: Union[str, Callable] = 'ReLU', batch_norm: bool = False, in_spatial: Union[int, tuple] = 2, net: str = 'u_net', reverse_mask: bool = False, **kwargs)
def create_double_conv(d: int, out_channels: int, mid_channels: int, batch_norm: bool, activation: Callable, periodic: bool, kernel_size=3)
def create_upsample()
def fno(in_channels: int, out_channels: int, mid_channels: int, modes: Sequence[int], activation: Union[str, type] = 'ReLU', batch_norm: bool = False, in_spatial: int = 2)
def get_mask(inputs, reverse_mask, data_format='NHWC')

Compute mask for slicing input feature map for Invertible Nets

def get_parameters(model: StaxNet, wrap=True) ‑> dict
def invertible_net(num_blocks: int, construct_net: Union[str, Callable], **construct_kwargs)
def load_state(obj: Union[StaxNet, JaxOptimizer], path: str)
def mlp(in_channels: int, out_channels: int, layers: Sequence[int], batch_norm=False, activation='ReLU', softmax=False) ‑> StaxNet
def res_net(in_channels: int, out_channels: int, layers: Sequence[int], batch_norm: bool = False, activation: Union[str, Callable] = 'ReLU', periodic=False, in_spatial: Union[int, tuple] = 2) ‑> StaxNet
def res_net_unit(in_channels: int, out_channels: int, layers: Sequence[int], batch_norm: bool = False, activation: Union[str, Callable] = 'ReLU', periodic=False, in_spatial: Union[int, tuple] = 2, **kwargs)

Res-net unit for Invertible Nets

def resnet_block(in_channels: int, out_channels: int, periodic: bool, batch_norm: bool, activation: Union[str, Callable] = 'ReLU', d: Union[int, tuple] = 2, kernel_size=3)
def rmsprop(net: StaxNet, learning_rate: float = 0.001, alpha=0.99, eps=1e-08, weight_decay=0.0, momentum=0.0, centered=False)
def save_state(obj: Union[StaxNet, JaxOptimizer], path: str)
def sgd(net: StaxNet, learning_rate: float = 0.001, momentum=0.0, dampening=0.0, weight_decay=0.0, nesterov=False)
def u_net(in_channels: int, out_channels: int, levels: int = 4, filters: Union[int, Sequence] = 16, batch_norm: bool = True, activation='ReLU', in_spatial: Union[int, tuple] = 2, periodic=False, use_res_blocks: bool = False, down_kernel_size=3, up_kernel_size=3) ‑> StaxNet
def u_net_unit(in_channels: int, out_channels: int, levels: int = 4, filters: Union[int, Sequence] = 16, batch_norm: bool = True, activation='ReLU', periodic=False, in_spatial: Union[int, tuple] = 2, use_res_blocks: bool = False, **kwargs)

U-net unit for Invertible Nets

def update_weights(net: StaxNet, optimizer: JaxOptimizer, loss_function: Callable, *loss_args, **loss_kwargs)
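
Computes the gradients of loss_function with respect to the parameters of net, performs one optimizer step, writes the updated parameters back to net and returns the output of loss_function. A minimal training-step sketch, assuming phiml.math tensors; the dimension names (examples, vector) and all sizes are placeholders:

from phiml import math
from phiml.backend.jax.stax_nets import mlp, adam, update_weights

net = mlp(in_channels=2, out_channels=1, layers=[32, 32])
optimizer = adam(net, learning_rate=1e-3)

x = math.random_uniform(math.batch(examples=64), math.channel(vector=2))
y = math.random_uniform(math.batch(examples=64), math.channel(vector=1))

def loss_function(x, y):
    prediction = math.native_call(net, x)  # calls the StaxNet with backend-native tensors
    return math.l2_loss(prediction - y)

for step in range(10):
    loss = update_weights(net, optimizer, loss_function, x, y)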

Classes

class JaxOptimizer (initialize: Callable, update: Callable, get_params: Callable)
class JaxOptimizer:

    def __init__(self, initialize: Callable, update: Callable, get_params: Callable):
        self._initialize, self._update, self._get_params = initialize, update, get_params  # Stax functions
        self._state = None  # List[Tuple[T,T,T]]: (parameter, m, v)
        self._step_i = 0
        self._update_function_cache = {}

    def initialize(self, net: tuple):
        self._state = self._initialize(net)

    def update_step(self, grads: tuple):
        self._state = self._update(self._step_i, grads, self._state)
        self._step_i += 1

    def get_network_parameters(self):
        return self._get_params(self._state)

    def update(self, net: StaxNet, loss_function, wrt, loss_args, loss_kwargs):
        if loss_function not in self._update_function_cache:
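            # Build and cache an update function specific to this loss; it is JIT-compiled below if the loss itself is a JitFunction.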
            @functools.wraps(loss_function)
            def update(packed_current_state, *loss_args, **loss_kwargs):
                @functools.wraps(loss_function)
                def loss_depending_on_net(params_tracer: tuple, *args, **kwargs):
                    net._tracers = params_tracer
                    loss_function_non_jit = loss_function.f if isinstance(loss_function, JitFunction) else loss_function
                    result = loss_function_non_jit(*args, **kwargs)
                    net._tracers = None
                    return result
                gradient_function = math.gradient(loss_depending_on_net)
                current_state = OptimizerState(packed_current_state, self._state.tree_def, self._state.subtree_defs)
                current_params = self._get_params(current_state)
                value, grads = gradient_function(current_params, *loss_args, **loss_kwargs)
                next_state = self._update(self._step_i, grads[0], self._state)
                return next_state.packed_state, value
            if isinstance(loss_function, JitFunction):
                update = math.jit_compile(update)
            self._update_function_cache[loss_function] = update
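        # Flatten the current network parameters and re-pack the optimizer state so its parameter leaves are up to date while keeping any auxiliary state (e.g. momentum terms).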
        xs, _ = tree_flatten(net.parameters)
        packed_state = [(x, *pt[1:]) if isinstance(pt, (tuple, list)) else (x,) for x, pt in zip(xs, self._state.packed_state)]
        next_packed_state, loss_output = self._update_function_cache[loss_function](packed_state, *loss_args, **loss_kwargs)
        self._state = OptimizerState(next_packed_state, self._state.tree_def, self._state.subtree_defs)
        net.parameters = self.get_network_parameters()
        return loss_output

Methods

def get_network_parameters(self)
def initialize(self, net: tuple)
def update(self, net: StaxNet, loss_function, wrt, loss_args, loss_kwargs)
def update_step(self, grads: tuple)
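
JaxOptimizer wraps the (init, update, get_params) triple of a Stax-style optimizer. A low-level sketch of how such an optimizer can be driven manually, assuming jax.example_libraries.optimizers and a dense StaxNet; the data, layer sizes and the direct use of the private _apply attribute are for illustration only:

import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers
from phiml.backend.jax.stax_nets import JaxOptimizer, mlp

net = mlp(in_channels=2, out_channels=1, layers=[32])
opt = JaxOptimizer(*optimizers.adam(step_size=1e-3))   # (initialize, update, get_params)
opt.initialize(net.parameters)                         # call net.initialize() first if parameters are not yet set

def loss(params, x, y):
    prediction = net._apply(params, x)                 # plain Stax apply function on the parameter pytree
    return jnp.mean((prediction - y) ** 2)

x, y = jnp.ones((8, 2)), jnp.zeros((8, 1))
grads = jax.grad(loss)(opt.get_network_parameters(), x, y)
opt.update_step(grads)                                 # advances the internal step counter and optimizer state
net.parameters = opt.get_network_parameters()
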
class StaxNet (initialize: Callable, apply: Callable, input_shape: tuple)
class StaxNet:

    def __init__(self, initialize: Callable, apply: Callable, input_shape: tuple):
        self._initialize = initialize
        self._apply = apply
        self._input_shape = input_shape
        self.parameters = None
        self._tracers = None

    def initialize(self):
        rnd_key = JAX.rnd_key
        JAX.rnd_key, init_key = random.split(rnd_key)
        out_shape, params64 = self._initialize(init_key, input_shape=self._input_shape)
        if math.get_precision() < 64:
            self.parameters = _recursive_to_float32(params64)  # cast parameters to float32 when running below 64-bit precision
        else:
            self.parameters = params64

    def __call__(self, *args, **kwargs):
        if self._tracers is not None:
            return self._apply(self._tracers, *args)
        else:
            return self._apply(self.parameters, *args)

Methods

def initialize(self)
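
StaxNet is a thin wrapper around a Stax (init, apply) pair plus the expected input shape. A construction sketch using jax.example_libraries.stax directly; the layer sizes and input shape are placeholders:

import jax.numpy as jnp
from jax.example_libraries import stax
from phiml.backend.jax.stax_nets import StaxNet

init_fn, apply_fn = stax.serial(stax.Dense(64), stax.Relu, stax.Dense(1))
net = StaxNet(init_fn, apply_fn, input_shape=(-1, 8))  # 8 input features, any batch size
net.initialize()                                       # draws parameters using the module's random key
y = net(jnp.ones((4, 8)))                              # __call__ applies the stored parameters

Normally, StaxNet instances are created by the factory functions above (mlp, conv_net, res_net, u_net, ...) rather than constructed directly.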