Module phiml.backend.jax.stax_nets
Stax implementation of the unified machine learning API. Equivalent functions also exist for the other frameworks.
For API documentation, see https://tum-pbs.github.io/PhiML/Network_API .
Functions
def Dense_resnet_block(in_channels: int, mid_channels: int, batch_norm: bool = False, activation: Union[str, Callable] = 'ReLU')
def adagrad(net: StaxNet, learning_rate: float = 0.001, lr_decay=0.0, weight_decay=0.0, initial_accumulator_value=0.0, eps=1e-10)
def adam(net: StaxNet, learning_rate: float = 0.001, betas=(0.9, 0.999), epsilon=1e-07)
def conv_classifier(in_features: int, in_spatial: Union[tuple, list], num_classes: int, blocks=(64, 128, 256, 256, 512, 512), block_sizes=(2, 2, 3, 3, 3), dense_layers=(4096, 4096, 100), batch_norm=True, activation='ReLU', softmax=True, periodic=False)
-
Based on VGG16.
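A minimal usage sketch with illustrative argument values (3-channel 32×32 inputs, 10 classes); these sizes are assumptions, not defaults:

net = conv_classifier(in_features=3, in_spatial=(32, 32), num_classes=10)  # returns a StaxNet (see below)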
def conv_net(in_channels: int, out_channels: int, layers: Sequence[int], batch_norm: bool = False, activation: Union[str, Callable] = 'ReLU', periodic=False, in_spatial: Union[int, tuple] = 2) ‑> StaxNet
-
Built-in convolutional network. Unlike classical convolutional neural networks, the spatial size of the feature map remains the same throughout the layers. Each layer of the network is essentially a convolutional block consisting of two conv layers with a filter size of 3.
Arguments
in_channels : input channels of the feature map, dtype: int
out_channels : output channels of the feature map, dtype: int
layers : list or tuple of output channels for each intermediate layer between the input and final output channels, dtype: list or tuple
activation : activation function used within the layers, dtype: string
batch_norm : use of batchnorm after each conv layer, dtype: bool
in_spatial : spatial dimensions of the input feature map, dtype: int
Returns
Conv-net model as specified by the input arguments.
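A minimal sketch with illustrative arguments: a 2D network mapping 1 input channel to 2 output channels through three intermediate blocks (the layer sizes are assumptions):

net = conv_net(in_channels=1, out_channels=2, layers=[16, 32, 16], activation='ReLU', in_spatial=2)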
def conv_net_unit(in_channels: int, out_channels: int, layers: Sequence[int], periodic: bool = False, batch_norm: bool = False, activation: Union[str, Callable] = 'ReLU', in_spatial: Union[int, tuple] = 2, **kwargs)
-
Conv-net unit for Invertible Nets
def coupling_layer(in_channels: int, activation: Union[str, Callable] = 'ReLU', batch_norm: bool = False, in_spatial: Union[int, tuple] = 2, net: str = 'u_net', reverse_mask: bool = False, **kwargs)
def create_double_conv(d: int, out_channels: int, mid_channels: int, batch_norm: bool, activation: Callable, periodic: bool, kernel_size=3)
def create_upsample()
def fno(in_channels: int, out_channels: int, mid_channels: int, modes: Sequence[int], activation: Union[str, type] = 'ReLU', batch_norm: bool = False, in_spatial: int = 2)
def get_mask(inputs, reverse_mask, data_format='NHWC')
-
Compute mask for slicing input feature map for Invertible Nets
def get_parameters(model: StaxNet, wrap=True) ‑> dict
def invertible_net(num_blocks: int, construct_net: Union[str, Callable], **construct_kwargs)
def load_state(obj: Union[StaxNet, JaxOptimizer], path: str)
def mlp(in_channels: int, out_channels: int, layers: Sequence[int], batch_norm=False, activation='ReLU', softmax=False) ‑> StaxNet
def res_net(in_channels: int, out_channels: int, layers: Sequence[int], batch_norm: bool = False, activation: Union[str, Callable] = 'ReLU', periodic=False, in_spatial: Union[int, tuple] = 2) ‑> StaxNet
def res_net_unit(in_channels: int, out_channels: int, layers: Sequence[int], batch_norm: bool = False, activation: Union[str, Callable] = 'ReLU', periodic=False, in_spatial: Union[int, tuple] = 2, **kwargs)
-
Res-net unit for Invertible Nets
def resnet_block(in_channels: int, out_channels: int, periodic: bool, batch_norm: bool, activation: Union[str, Callable] = 'ReLU', d: Union[int, tuple] = 2, kernel_size=3)
def rmsprop(net: StaxNet, learning_rate: float = 0.001, alpha=0.99, eps=1e-08, weight_decay=0.0, momentum=0.0, centered=False)
def save_state(obj: Union[StaxNet, JaxOptimizer], path: str)
def sgd(net: StaxNet, learning_rate: float = 0.001, momentum=0.0, dampening=0.0, weight_decay=0.0, nesterov=False)
def u_net(in_channels: int, out_channels: int, levels: int = 4, filters: Union[int, Sequence] = 16, batch_norm: bool = True, activation='ReLU', in_spatial: Union[int, tuple] = 2, periodic=False, use_res_blocks: bool = False, down_kernel_size=3, up_kernel_size=3) ‑> StaxNet
def u_net_unit(in_channels: int, out_channels: int, levels: int = 4, filters: Union[int, Sequence] = 16, batch_norm: bool = True, activation='ReLU', periodic=False, in_spatial: Union[int, tuple] = 2, use_res_blocks: bool = False, **kwargs)
-
U-net unit for Invertible Nets
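The *_unit constructors and get_mask() above are building blocks for invertible_net(). A hedged sketch; that the unit arguments are forwarded through construct_kwargs is an assumption based on the signatures:

net = invertible_net(num_blocks=2, construct_net='u_net', in_channels=2, in_spatial=2)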
def update_weights(net: StaxNet, optimizer: JaxOptimizer, loss_function: Callable, *loss_args, **loss_kwargs)
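A hedged training-loop sketch combining the optimizer constructors above with update_weights(). The loss function, the random training data, and the use of math.native_call / math.l2_loss are illustrative assumptions; in practice these backend functions are usually reached through the unified API linked above.

from phiml import math

net = mlp(in_channels=2, out_channels=1, layers=[64, 64])
optimizer = adam(net, learning_rate=1e-3)

# Illustrative random training data as PhiML tensors (assumption).
x_train = math.random_normal(math.batch(examples=100), math.channel(vector=2))
y_train = math.random_normal(math.batch(examples=100), math.channel(vector=1))

def loss_fn(x, y):
    prediction = math.native_call(net, x)  # apply the StaxNet to the tensor's native (channels-last) values
    return math.l2_loss(prediction - y)

for step in range(10):
    loss = update_weights(net, optimizer, loss_fn, x_train, y_train)  # returns the loss value for this step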
Classes
class JaxOptimizer (initialize: Callable, update: Callable, get_params: Callable)
-
class JaxOptimizer:

    def __init__(self, initialize: Callable, update: Callable, get_params: Callable):
        self._initialize, self._update, self._get_params = initialize, update, get_params  # Stax functions
        self._state = None  # List[Tuple[T,T,T]]: (parameter, m, v)
        self._step_i = 0
        self._update_function_cache = {}

    def initialize(self, net: tuple):
        self._state = self._initialize(net)

    def update_step(self, grads: tuple):
        self._state = self._update(self._step_i, grads, self._state)
        self._step_i += 1

    def get_network_parameters(self):
        return self._get_params(self._state)

    def update(self, net: StaxNet, loss_function, wrt, loss_args, loss_kwargs):
        if loss_function not in self._update_function_cache:
            # Build (and cache) an update function that differentiates the loss w.r.t. the network parameters.
            @functools.wraps(loss_function)
            def update(packed_current_state, *loss_args, **loss_kwargs):
                @functools.wraps(loss_function)
                def loss_depending_on_net(params_tracer: tuple, *args, **kwargs):
                    net._tracers = params_tracer  # route the traced parameters through the net while evaluating the loss
                    loss_function_non_jit = loss_function.f if isinstance(loss_function, JitFunction) else loss_function
                    result = loss_function_non_jit(*args, **kwargs)
                    net._tracers = None
                    return result
                gradient_function = math.gradient(loss_depending_on_net)
                current_state = OptimizerState(packed_current_state, self._state.tree_def, self._state.subtree_defs)
                current_params = self._get_params(current_state)
                value, grads = gradient_function(current_params, *loss_args, **loss_kwargs)
                next_state = self._update(self._step_i, grads[0], self._state)
                return next_state.packed_state, value
            if isinstance(loss_function, JitFunction):
                update = math.jit_compile(update)
            self._update_function_cache[loss_function] = update
        xs, _ = tree_flatten(net.parameters)
        packed_state = [(x, *pt[1:]) if isinstance(pt, (tuple, list)) else (x,) for x, pt in zip(xs, self._state.packed_state)]
        next_packed_state, loss_output = self._update_function_cache[loss_function](packed_state, *loss_args, **loss_kwargs)
        self._state = OptimizerState(next_packed_state, self._state.tree_def, self._state.subtree_defs)
        net.parameters = self.get_network_parameters()
        return loss_output
Methods
def get_network_parameters(self)
def initialize(self, net: tuple)
def update(self, net: StaxNet, loss_function, wrt, loss_args, loss_kwargs)
def update_step(self, grads: tuple)
class StaxNet (initialize: Callable, apply: Callable, input_shape: tuple)
-
class StaxNet:

    def __init__(self, initialize: Callable, apply: Callable, input_shape: tuple):
        self._initialize = initialize
        self._apply = apply
        self._input_shape = input_shape
        self.parameters = None
        self._tracers = None

    def initialize(self):
        rnd_key = JAX.rnd_key
        JAX.rnd_key, init_key = random.split(rnd_key)  # advance the global random key
        out_shape, params64 = self._initialize(init_key, input_shape=self._input_shape)
        if math.get_precision() < 64:
            self.parameters = _recursive_to_float32(params64)

    def __call__(self, *args, **kwargs):
        if self._tracers is not None:  # during differentiation, use the traced parameters set by JaxOptimizer.update()
            return self._apply(self._tracers, *args)
        else:
            return self._apply(self.parameters, *args)
Methods
def initialize(self)
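A hedged sketch of the StaxNet lifecycle based on the source above. The input array and its channels-last layout are assumptions for illustration, and the builder functions in this module may already call initialize() internally:

import jax.numpy as jnp

net = u_net(in_channels=1, out_channels=1, levels=3, in_spatial=2)
net.initialize()                   # (re)samples parameters from JAX.rnd_key
params = get_parameters(net)       # parameters as a dict, wrapped by default
x = jnp.zeros((1, 64, 64, 1))      # assumed channels-last batch of one 64x64 single-channel field
y = net(x)                         # applies the stax apply function with the current parameters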