class JaxBackend

Backends delegate low-level operations to an ML or numerics library, or emulate them. The methods of Backend form a comprehensive list of available operations.

To support a library, subclass Backend and register it by adding it to BACKENDS.


name: Human-readable string.
default_device: ComputeDevice being used by default.
class JaxBackend(Backend):
    """
    Backend that delegates low-level tensor operations to JAX (`jax`, `jax.numpy`, `jax.lax`).

    Native tensors are JAX arrays and `jax.experimental.sparse` matrices (COO, BCOO, CSR, CSC).
    NumPy arrays, NumPy bools, Python numbers and (nested) tuples/lists of those are accepted
    as non-native tensors and converted by `as_tensor`.
    """

    def __init__(self):
        # Register every CPU / GPU / TPU device JAX can see; the last one found becomes the default.
        devices = []
        for device_type in ['cpu', 'gpu', 'tpu']:
            try:
                for jax_dev in jax.devices(device_type):
                    devices.append(ComputeDevice(self, device_type.upper(), jax_dev.platform.upper(), -1, -1, f"id={jax_dev.id}", jax_dev))
            except RuntimeError:
                pass  # this is just Jax not finding anything. jaxlib.xla_client._get_local_backends() could help but isn't currently available on GitHub actions
        Backend.__init__(self, 'jax', devices, devices[-1])
        try:
            self.rnd_key = jax.random.PRNGKey(seed=0)
        except RuntimeError as err:
            warnings.warn(f"{err}", RuntimeWarning)
            self.rnd_key = None

    def prefers_channels_last(self) -> bool:
        return True

    def requires_fixed_shapes_when_tracing(self) -> bool:
        return True

    def nn_library(self):
        from . import stax_nets
        return stax_nets

    def _check_float64(self):
        """If the backend precision is 64 bit, enable JAX's x64 mode and assert that it took effect."""
        if self.precision == 64:
            if not jax.config.jax_enable_x64:
                jax.config.update('jax_enable_x64', True)
            assert jax.config.jax_enable_x64, "FP64 is disabled for Jax."

    def seed(self, seed: int):
        """Reset the backend's PRNG key from an integer seed."""
        self.rnd_key = jax.random.PRNGKey(seed)

    def as_tensor(self, x, convert_external=True):
        """
        Convert `x` to a JAX array, leaving native tensors untouched.

        Args:
            x: Tensor-like value.
            convert_external: If True, non-native tensors (e.g. NumPy arrays) are converted with `jnp.array`.
                If False, any value accepted by `is_tensor` is passed through as-is.

        Returns:
            `x` or its converted form. Float / complex values are cast to the backend's precision.
        """
        if self.is_tensor(x, only_native=convert_external):
            array = x
        else:
            array = jnp.array(x)
        # --- Enforce Precision ---
        if not isinstance(array, numbers.Number):
            if self.dtype(array).kind == float:
                array = self.to_float(array)
            elif self.dtype(array).kind == complex:
                array = self.to_complex(array)
        return array

    def is_module(self, obj):
        return False

    def is_tensor(self, x, only_native=False):
        """
        Check whether `x` is a tensor for this backend.

        Native: JAX arrays / bools (excluding NumPy ones) and JAX sparse matrices.
        If `only_native=False`, NumPy arrays, NumPy bools, Python numbers/bools and tuples/lists
        consisting entirely of such values are also accepted.
        """
        if isinstance(x, jnp.ndarray) and not isinstance(x, np.ndarray):  # NumPy arrays inherit from Jax arrays
            return True
        if isinstance(x, jnp.bool_) and not isinstance(x, np.bool_):
            return True
        if self.is_sparse(x):
            return True
        # --- Above considered native ---
        if only_native:
            return False
        # --- Non-native types ---
        if isinstance(x, np.ndarray):
            return True
        if isinstance(x, np.bool_):
            return True
        if isinstance(x, (numbers.Number, bool)):
            return True
        if isinstance(x, (tuple, list)):
            return all([self.is_tensor(item, False) for item in x])
        return False

    def is_sparse(self, x) -> bool:
        return isinstance(x, (COO, BCOO, CSR, CSC))

    def get_sparse_format(self, x) -> str:
        """Return 'coo', 'csr' or 'csc' for JAX sparse matrices and 'dense' for anything else."""
        format_names = {
            COO: 'coo',
            BCOO: 'coo',
            CSR: 'csr',
            CSC: 'csc',
        }
        return format_names.get(type(x), 'dense')

    def is_available(self, tensor):
        """Return False if `tensor` is a JAX tracer (its values are unknown at trace time), True otherwise."""
        if isinstance(tensor, JVPTracer):
            tensor = tensor.primal
        return not isinstance(tensor, Tracer)

    def numpy(self, tensor):
        """Convert `tensor` (dense or sparse) to its NumPy representation. JVP tracers are unwrapped to their primal."""
        if isinstance(tensor, JVPTracer):
            tensor = tensor.primal
        if self.is_sparse(tensor):
            assemble, parts = self.disassemble(tensor)
            return assemble(NUMPY, *[self.numpy(t) for t in parts])
        return np.array(tensor)

    def disassemble(self, x) -> Tuple[Callable, Sequence[TensorType]]:
        """
        Split `x` into component tensors plus a function that reassembles them on any backend `b`.

        Only dense arrays and BCOO matrices are supported; COO, CSR and CSC raise NotImplementedError.
        """
        if self.is_sparse(x):
            if isinstance(x, COO):
                raise NotImplementedError
                # return lambda b, i, v: b.sparse_coo_tensor(i, v, x.shape), (np.stack([x.row, x.col], -1), x.data)
            if isinstance(x, BCOO):
                return lambda b, i, v: b.sparse_coo_tensor(i, v, x.shape), (x.indices, x.data)
            if isinstance(x, CSR):
                raise NotImplementedError
                # return lambda b, v, i, p: b.csr_matrix(i, p, v, x.shape), (x.data, x.indices, x.indptr)
            elif isinstance(x, CSC):
                raise NotImplementedError
                # return lambda b, v, i, p: b.csc_matrix(p, i, v, x.shape), (x.data, x.indices, x.indptr)
            raise NotImplementedError
        else:
            return lambda b, t: t, (x,)

    def to_dlpack(self, tensor):
        from jax import dlpack
        return dlpack.to_dlpack(tensor)

    def from_dlpack(self, capsule):
        from jax import dlpack
        return dlpack.from_dlpack(capsule)

    def copy(self, tensor, only_mutable=False):
        return jnp.array(tensor, copy=True)

    def get_device(self, tensor: TensorType) -> ComputeDevice:
        """Look up the ComputeDevice holding `tensor`, supporting both `.devices()` and the older `.device()` API."""
        if hasattr(tensor, 'devices'):
            return self.get_device_by_ref(next(iter(tensor.devices())))
        if hasattr(tensor, 'device'):
            return self.get_device_by_ref(tensor.device())
        raise AssertionError(f"tensor {type(tensor)} has no device attribute")

    def allocate_on_device(self, tensor: TensorType, device: ComputeDevice) -> TensorType:
        return jax.device_put(tensor, device.ref)

    # Element-wise and structural operations mapped directly onto jax.numpy / jax.lax.
    # All of them use jnp so that they stay traceable and differentiable; NumPy functions would
    # convert JAX arrays to host arrays and fail under jit.
    sqrt = staticmethod(jnp.sqrt)
    exp = staticmethod(jnp.exp)
    erf = staticmethod(scipy.special.erf)
    softplus = staticmethod(jax.nn.softplus)
    sin = staticmethod(jnp.sin)
    arcsin = staticmethod(jnp.arcsin)
    cos = staticmethod(jnp.cos)
    arccos = staticmethod(jnp.arccos)
    tan = staticmethod(jnp.tan)
    arctan = staticmethod(jnp.arctan)
    arctan2 = staticmethod(jnp.arctan2)
    sinh = staticmethod(jnp.sinh)
    arcsinh = staticmethod(jnp.arcsinh)
    cosh = staticmethod(jnp.cosh)
    arccosh = staticmethod(jnp.arccosh)
    tanh = staticmethod(jnp.tanh)
    arctanh = staticmethod(jnp.arctanh)
    log = staticmethod(jnp.log)
    log2 = staticmethod(jnp.log2)
    log10 = staticmethod(jnp.log10)
    isfinite = staticmethod(jnp.isfinite)
    isnan = staticmethod(jnp.isnan)
    isinf = staticmethod(jnp.isinf)
    abs = staticmethod(jnp.abs)
    sign = staticmethod(jnp.sign)
    round = staticmethod(jnp.round)
    ceil = staticmethod(jnp.ceil)
    floor = staticmethod(jnp.floor)
    flip = staticmethod(jnp.flip)
    stop_gradient = staticmethod(jax.lax.stop_gradient)
    transpose = staticmethod(jnp.transpose)
    equal = staticmethod(jnp.equal)
    tile = staticmethod(jnp.tile)
    stack = staticmethod(jnp.stack)
    concat = staticmethod(jnp.concatenate)
    maximum = staticmethod(jnp.maximum)
    minimum = staticmethod(jnp.minimum)
    clip = staticmethod(jnp.clip)
    argmax = staticmethod(jnp.argmax)
    argmin = staticmethod(jnp.argmin)
    shape = staticmethod(jnp.shape)
    staticshape = staticmethod(jnp.shape)
    imag = staticmethod(jnp.imag)
    real = staticmethod(jnp.real)
    conj = staticmethod(jnp.conjugate)
    einsum = staticmethod(jnp.einsum)
    cumsum = staticmethod(jnp.cumsum)

    def nonzero(self, values, length=None, fill_value=-1):
        """Indices of non-zero entries as an (count, ndim) array, padded with `fill_value` up to `length` if given."""
        result = jnp.nonzero(values, size=length, fill_value=fill_value)
        return jnp.stack(result, -1)

    def vectorized_call(self, f, *args, output_dtypes=None, **aux_args):
        """Call `f` for each entry along axis 0 via `jax.vmap`, tiling args to a common batch size first."""
        batch_size = self.determine_size(args, 0)
        args = [self.tile_to(t, 0, batch_size) for t in args]
        def f_positional(*args):
            return f(*args, **aux_args)
        vec_f = jax.vmap(f_positional, 0, 0)
        return vec_f(*args)

    def numpy_call(self, f, output_shapes, output_dtypes, *args, **aux_args):
        """
        Run a NumPy function `f` on backend tensors.

        If all args are concrete, `f` is called directly on NumPy copies. Under tracing, `f` is wrapped
        in `jax.pure_callback` (or `host_callback.call` on JAX versions without `pure_callback`), with
        output shapes/dtypes declared by `output_shapes` / `output_dtypes`.
        """
        if all([self.is_available(arg) for arg in args]):
            args = [self.numpy(arg) for arg in args]
            output = f(*args, **aux_args)
            result = map_structure(self.as_tensor, output)
            return result
        # Callbacks need a shape/dtype description of the outputs; ShapeDtypeStruct is JAX's spec type for this.
        output_specs = map_structure(lambda t, s: jax.ShapeDtypeStruct(s, to_numpy_dtype(t)), output_dtypes, output_shapes)
        if hasattr(jax, 'pure_callback'):
            def aux_f(*args):
                return f(*args, **aux_args)
            return jax.pure_callback(aux_f, output_specs, *args)
        else:
            def aux_f(args):
                # host_callback passes a single argument; unpack it if several tensors were given.
                if isinstance(args, tuple):
                    return f(*args, **aux_args)
                else:
                    return f(args, **aux_args)
            from jax.experimental.host_callback import call
            return call(aux_f, args, result_shape=output_specs)

    def jit_compile(self, f: Callable) -> Callable:
        """Compile `f` with `jax.jit` on the default device and return a wrapper that logs each invocation."""
        def run_jit_f(*args):
            # print(jax.make_jaxpr(f)(*args))
            ML_LOGGER.debug(f"JaxBackend: running jit-compiled '{f.__name__}' with shapes {[self.shape(arg) for arg in args]} and dtypes {[self.dtype(arg) for arg in args]}")
            # Dispatch through the registered backend's `call` (Backend API) so the call carries a name label.
            return self.as_registered.call(jit_f, *args, name=f"run jit-compiled '{f.__name__}'")

        run_jit_f.__name__ = f"Jax-Jit({f.__name__})"
        jit_f = jax.jit(f, device=self._default_device.ref)
        return run_jit_f

    def block_until_ready(self, values):
        """Wait for `values` (an array or nested tuples/lists of arrays) to finish computing."""
        if hasattr(values, 'block_until_ready'):
            values.block_until_ready()
        if isinstance(values, (tuple, list)):
            for v in values:
                self.block_until_ready(v)

    def jacobian(self, f, wrt: Union[tuple, list], get_output: bool, is_f_scalar: bool):
        """
        Build a gradient function for `f`, which must return `(loss, output)`.

        Args:
            f: Function returning `(loss, output_tuple)`.
            wrt: Argument indices to differentiate with respect to. bool/int arguments at these
                positions are cast to float before differentiation.
            get_output: If True, the returned function yields `(*output_tuple, *grads)`; otherwise only the grads.
            is_f_scalar: Unused by this backend.

        Returns:
            Function of the same arguments as `f`. Gradients are complex-conjugated before being returned.
        """
        if get_output:
            jax_grad_f = jax.value_and_grad(f, argnums=wrt, has_aux=True)
            def unwrap_outputs(*args):
                args = [self.to_float(arg) if self.dtype(arg).kind in (bool, int) and i in wrt else arg for i, arg in enumerate(args)]
                (_, output_tuple), grads = jax_grad_f(*args)
                return (*output_tuple, *[jnp.conjugate(g) for g in grads])
            return unwrap_outputs
        else:
            def nonaux_f(*args):
                loss, output = f(*args)
                return loss
            jax_grad = jax.grad(nonaux_f, argnums=wrt, has_aux=False)
            def call_jax_grad(*args):
                args = [self.to_float(arg) if self.dtype(arg).kind in (bool, int) and i in wrt else arg for i, arg in enumerate(args)]
                grads = jax_grad(*args)
                return tuple([jnp.conjugate(g) for g in grads])
            return call_jax_grad

    def custom_gradient(self, f: Callable, gradient: Callable, get_external_cache: Callable = None, on_call_skipped: Callable = None) -> Callable:
        """Wrap `f` with a custom reverse-mode gradient `gradient(x, y, dy) -> dx` using `jax.custom_vjp`."""
        jax_fun = jax.custom_vjp(f)  # custom vector-Jacobian product (reverse-mode differentiation)

        def forward(*x):
            y = f(*x)
            return y, (x, y)

        def backward(x_y, dy):
            x, y = x_y
            dx = gradient(x, y, dy)
            return tuple(dx)

        jax_fun.defvjp(forward, backward)
        return jax_fun

    def divide_no_nan(self, x, y):
        """
        Compute `x / y`, returning 0 wherever `y == 0`.

        The denominator is replaced by 1 before dividing so that the untaken branch of `where`
        contains no inf/NaN; otherwise its gradient would still propagate NaN.
        """
        y_is_zero = y == 0
        safe_y = jnp.where(y_is_zero, 1, y)
        return jnp.where(y_is_zero, 0, x / safe_y)
        # jnp.nan_to_num(x / y, copy=True, nan=0) covers up NaNs from before

    def random_uniform(self, shape, low, high, dtype: Union[DType, None]):
        """
        Sample uniformly in [low, high) on the default device, advancing the backend PRNG key.

        For complex dtypes, the real and imaginary parts are sampled independently between the
        real / imaginary parts of `low` and `high`. Raises ValueError for unsupported dtype kinds.
        """
        self.rnd_key, subkey = jax.random.split(self.rnd_key)

        dtype = dtype or self.float_type
        jdt = to_numpy_dtype(dtype)
        if dtype.kind == float:
            tensor = random.uniform(subkey, shape, minval=low, maxval=high, dtype=jdt)
        elif dtype.kind == complex:
            # Separate keys: reusing one key would make the real and imaginary parts correlated.
            real_key, imag_key = jax.random.split(subkey)
            part_dtype = to_numpy_dtype(DType(float, dtype.precision))
            real = random.uniform(real_key, shape, minval=low.real, maxval=high.real, dtype=part_dtype)
            imag = random.uniform(imag_key, shape, minval=low.imag, maxval=high.imag, dtype=part_dtype)
            tensor = real + 1j * imag
        elif dtype.kind == int:
            tensor = random.randint(subkey, shape, low, high, dtype=jdt)
            if tensor.dtype != jdt:
                warnings.warn(f"Jax failed to sample random integers with dtype {dtype}, returned {tensor.dtype} instead.", RuntimeWarning)
        else:
            raise ValueError(dtype)
        return jax.device_put(tensor, self._default_device.ref)

    def random_normal(self, shape, dtype: DType):
        """Sample from a standard normal distribution on the default device, advancing the backend PRNG key."""
        self.rnd_key, subkey = jax.random.split(self.rnd_key)
        dtype = dtype or self.float_type
        return jax.device_put(random.normal(subkey, shape, dtype=to_numpy_dtype(dtype)), self._default_device.ref)

    def random_permutations(self, permutations: int, n: int):
        """Return a (permutations, n) array whose rows are independent random permutations of range(n)."""
        self.rnd_key, subkey = jax.random.split(self.rnd_key)
        # One key per permutation; sharing a key would make all rows identical.
        keys = jax.random.split(subkey, permutations)
        result = jnp.stack([jax.random.permutation(key, n) for key in keys])
        return jax.device_put(result, self._default_device.ref)

    def range(self, start, limit=None, delta=1, dtype: DType = DType(int, 32)):
        """Like Python's `range`: a single argument is interpreted as the exclusive upper limit."""
        if limit is None:
            start, limit = 0, start
        return jnp.arange(start, limit, delta, to_numpy_dtype(dtype))

    def pad(self, value, pad_width, mode='constant', constant_values=0):
        """Pad `value`; 'periodic' and 'boundary' are mapped to jnp's 'wrap' and 'edge' modes."""
        assert mode in ('constant', 'symmetric', 'periodic', 'reflect', 'boundary'), mode
        if mode == 'constant':
            constant_values = jnp.array(constant_values, dtype=value.dtype)
            return jnp.pad(value, pad_width, 'constant', constant_values=constant_values)
        else:
            if mode in ('periodic', 'boundary'):
                mode = {'periodic': 'wrap', 'boundary': 'edge'}[mode]
            return jnp.pad(value, pad_width, mode)

    def reshape(self, value, shape):
        return jnp.reshape(value, shape)

    def sum(self, value, axis=None, keepdims=False):
        """Sum over `axis`. A tuple/list of tensors is summed element-wise (only `axis=0` is allowed then)."""
        if isinstance(value, (tuple, list)):
            assert axis == 0
            return sum(value[1:], value[0])
        return jnp.sum(value, axis=axis, keepdims=keepdims)

    def prod(self, value, axis=None):
        """Product over `axis`; boolean inputs are reduced with logical AND instead."""
        if not isinstance(value, jnp.ndarray):
            value = jnp.array(value)
        if value.dtype == bool:
            return jnp.all(value, axis=axis)
        return jnp.prod(value, axis=axis)

    def where(self, condition, x=None, y=None):
        """`jnp.where(condition, x, y)`, or the indices of True entries if `x` or `y` is missing."""
        if x is None or y is None:
            return jnp.argwhere(condition)
        return jnp.where(condition, x, y)

    def zeros(self, shape, dtype: DType = None):
        return jax.device_put(jnp.zeros(shape, dtype=to_numpy_dtype(dtype or self.float_type)), self._default_device.ref)

    def zeros_like(self, tensor):
        return jax.device_put(jnp.zeros_like(tensor), self._default_device.ref)

    def ones(self, shape, dtype: DType = None):
        return jax.device_put(jnp.ones(shape, dtype=to_numpy_dtype(dtype or self.float_type)), self._default_device.ref)

    def ones_like(self, tensor):
        return jax.device_put(jnp.ones_like(tensor), self._default_device.ref)

    def meshgrid(self, *coordinates):
        coordinates = [self.as_tensor(c) for c in coordinates]
        return [jax.device_put(c, self._default_device.ref) for c in jnp.meshgrid(*coordinates, indexing='ij')]

    def linspace(self, start, stop, number):
        return jax.device_put(jnp.linspace(start, stop, number, dtype=to_numpy_dtype(self.float_type)), self._default_device.ref)

    def linspace_without_last(self, start, stop, number):
        return jax.device_put(jnp.linspace(start, stop, number, endpoint=False, dtype=to_numpy_dtype(self.float_type)), self._default_device.ref)

    def mean(self, value, axis=None, keepdims=False):
        return jnp.mean(value, axis, keepdims=keepdims)

    def log_gamma(self, x):
        return jax.lax.lgamma(self.to_float(x))

    def gamma_inc_l(self, a, x):
        return scipy.special.gammainc(a, x)

    def gamma_inc_u(self, a, x):
        return scipy.special.gammaincc(a, x)

    def tensordot(self, a, a_axes: Union[tuple, list], b, b_axes: Union[tuple, list]):
        return jnp.tensordot(a, b, (a_axes, b_axes))

    def mul(self, a, b):
        # TODO sparse? e.g. a.multiply(b) if a is sparse, b.multiply(a) if b is sparse.
        return Backend.mul(self, a, b)

    def mul_matrix_batched_vector(self, A, b):
        """Multiply matrix `A` with each row of `b` (batch along axis 0 of `b`)."""
        from jax.experimental.sparse import BCOO
        if isinstance(A, BCOO):
            return (A @ b.T).T
        return jnp.stack([A @ b[i] for i in range(b.shape[0])])

    def get_diagonal(self, matrices, offset=0):
        """Diagonals of axes 1 and 2, moved so the diagonal index comes before the last remaining axis."""
        result = jnp.diagonal(matrices, offset=offset, axis1=1, axis2=2)
        return jnp.transpose(result, [0, 2, 1])

    def while_loop(self, loop: Callable, values: tuple, max_iter: Union[int, Tuple[int, ...], List[int]]):
        """
        Repeatedly apply `loop(*values)` while any element of the first value is truthy.

        Concrete inputs use the generic Backend implementation. Under tracing, a tuple/list `max_iter`
        unrolls the loop and stacks the states at those iteration counts; otherwise `jax.lax.while_loop`
        is used, optionally capped at `max_iter` iterations.
        """
        if all(self.is_available(t) for t in values):
            return self.stop_gradient_tree(Backend.while_loop(self, loop, values, max_iter))
        if isinstance(max_iter, (tuple, list)):  # stack traced trajectory, unroll until max_iter
            values = self.stop_gradient_tree(values)
            trj = [values] if 0 in max_iter else []
            for i in range(1, max(max_iter) + 1):
                values = loop(*values)
                if i in max_iter:
                    trj.append(values)  # values are not mutable so no need to copy
            return self.stop_gradient_tree(self.stack_leaves(trj))
        else:
            if max_iter is None:
                cond = lambda vals: jnp.any(vals[0])
                body = lambda vals: loop(*vals)
                return jax.lax.while_loop(cond, body, values)
            else:
                # Carry an iteration counter alongside the values to enforce max_iter.
                cond = lambda vals: jnp.any(vals[1][0]) & (vals[0] < max_iter)
                body = lambda vals: (vals[0] + 1, loop(*vals[1]))
                return jax.lax.while_loop(cond, body, (self.as_tensor(0), values))[1]

    def max(self, x, axis=None, keepdims=False):
        return jnp.max(x, axis, keepdims=keepdims)

    def min(self, x, axis=None, keepdims=False):
        return jnp.min(x, axis, keepdims=keepdims)

    def conv(self, value, kernel, zero_padding=True):
        """
        Cross-correlate `value` with `kernel` by summing per-channel correlations.

        `value` is indexed as (batch, in_channel, ...); `kernel` as (batch or 1, out_channel, in_channel, ...).
        """
        assert kernel.shape[0] in (1, value.shape[0])
        assert value.shape[1] == kernel.shape[2], f"value has {value.shape[1]} channels but kernel has {kernel.shape[2]}"
        assert value.ndim + 1 == kernel.ndim
        value, kernel = self.auto_cast(value, kernel, bool_to_int=True)
        # AutoDiff may require jax.lax.conv_general_dilated
        result = []
        for b in range(value.shape[0]):
            b_kernel = kernel[min(b, kernel.shape[0] - 1)]
            result_b = []
            for o in range(kernel.shape[1]):
                result_b.append(0)  # accumulator for output channel o
                for i in range(value.shape[1]):
                    # result.at[b, o, ...].set(scipy.signal.correlate(value[b, i, ...], b_kernel[o, i, ...], mode='same' if zero_padding else 'valid'))
                    result_b[-1] += scipy.signal.correlate(value[b, i, ...], b_kernel[o, i, ...], mode='same' if zero_padding else 'valid')
            result.append(jnp.stack(result_b, 0))
        return jnp.stack(result, 0)

    def expand_dims(self, a, axis=0, number=1):
        for _i in range(number):
            a = jnp.expand_dims(a, axis)
        return a

    def cast(self, x, dtype: DType):
        """Convert `x` to `dtype`; native tensors already of that dtype are returned unchanged."""
        if self.is_tensor(x, only_native=True) and from_numpy_dtype(x.dtype) == dtype:
            return x
        else:
            return jnp.array(x, to_numpy_dtype(dtype))

    def unravel_index(self, flat_index, shape):
        return jnp.stack(jnp.unravel_index(flat_index, shape), -1)

    def ravel_multi_index(self, multi_index, shape, mode: Union[str, int] = 'undefined'):
        """
        Flatten multi-indices (last axis = index components) into flat indices for `shape`.

        `mode` may be 'undefined'/'clamp' (clip), 'periodic' (wrap), or an int that is returned for
        out-of-bounds indices.
        """
        if not self.is_available(shape):
            return Backend.ravel_multi_index(self, multi_index, shape, mode)
        mode = mode if isinstance(mode, int) else {'undefined': 'clip', 'periodic': 'wrap', 'clamp': 'clip'}[mode]
        idx_first = jnp.transpose(multi_index, (self.ndims(multi_index)-1,) + tuple(range(self.ndims(multi_index)-1)))
        result = jnp.ravel_multi_index(idx_first, shape, mode='wrap' if isinstance(mode, int) else mode)
        if isinstance(mode, int):
            outside = self.any((multi_index < 0) | (multi_index >= jnp.asarray(shape, dtype=multi_index.dtype)), -1)
            result = self.where(outside, mode, result)
        return result

    def gather(self, values, indices, axis: int):
        slices = [indices if i == axis else slice(None) for i in range(self.ndims(values))]
        return values[tuple(slices)]

    def batched_gather_nd(self, values, indices):
        """Gather from `values` (batch, *spatial, channel) at `indices` (batch, ..., len(spatial)) per batch entry."""
        values = self.as_tensor(values)
        indices = self.as_tensor(indices)
        assert indices.shape[-1] == self.ndims(values) - 2
        batch_size = combined_dim(values.shape[0], indices.shape[0])
        indices_list = [indices[..., i] for i in range(indices.shape[-1])]
        batch_range = self.expand_dims(np.arange(batch_size), -1, number=self.ndims(indices) - 2)
        slices = (batch_range, *indices_list)
        return values[slices]

    def batched_gather_1d(self, values, indices):
        batch_size = combined_dim(values.shape[0], indices.shape[0])
        return values[np.arange(batch_size)[:, None], indices]

    def repeat(self, x, repeats, axis: int, new_length=None):
        return jnp.repeat(x, self.as_tensor(repeats), axis, total_repeat_length=new_length)

    def std(self, x, axis=None, keepdims=False):
        return jnp.std(x, axis, keepdims=keepdims)

    def boolean_mask(self, x, mask, axis=0, new_length=None, fill_value=0):
        """
        Select entries of `x` along `axis` where `mask` is True.

        With `new_length`, the output has a fixed size along `axis` (as required under tracing);
        missing entries are filled with `fill_value`.
        """
        if new_length is None:
            slices = [mask if i == axis else slice(None) for i in range(len(x.shape))]
            return x[tuple(slices)]
        else:
            indices = jnp.argwhere(mask, size=new_length, fill_value=-1)[..., 0]
            valid = indices >= 0
            valid = valid[tuple([slice(None) if i == axis else None for i in range(len(x.shape))])]
            result = self.gather(x, jnp.maximum(0, indices), axis)
            return jnp.where(valid, result, fill_value)

    def any(self, boolean_tensor, axis=None, keepdims=False):
        if isinstance(boolean_tensor, (tuple, list)):
            boolean_tensor = jnp.stack(boolean_tensor)
        return jnp.any(boolean_tensor, axis=axis, keepdims=keepdims)

    def all(self, boolean_tensor, axis=None, keepdims=False):
        if isinstance(boolean_tensor, (tuple, list)):
            boolean_tensor = jnp.stack(boolean_tensor)
        return jnp.all(boolean_tensor, axis=axis, keepdims=keepdims)

    def scatter(self, base_grid, indices, values, mode: str):
        """
        Scatter `values` into `base_grid` at `indices`, per batch entry.

        `mode` is one of 'add', 'update', 'max', 'min'. A batch dimension of size 1 in `indices` or
        `values` is tiled to the batch size. Boolean output types are restored after the scatter.
        """
        out_kind = self.combine_types(self.dtype(base_grid), self.dtype(values)).kind
        base_grid, values = self.auto_cast(base_grid, values, bool_to_int=True)
        batch_size = combined_dim(combined_dim(indices.shape[0], values.shape[0]), base_grid.shape[0])
        spatial_dims = tuple(range(base_grid.ndim - 2))
        dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(1,),  # channel dim of updates (batch dim removed)
                                                inserted_window_dims=spatial_dims,  # no idea what this does but spatial_dims seems to work
                                                scatter_dims_to_operand_dims=spatial_dims)  # spatial dims of base_grid (batch dim removed)
        scatter = {'add': jax.lax.scatter_add, 'update': jax.lax.scatter, 'max': jax.lax.scatter_max, 'min': jax.lax.scatter_min}[mode]
        def scatter_single(base_grid, indices, values):
            return scatter(base_grid, indices, values, dnums)
        if self.staticshape(indices)[0] == 1:
            indices = self.tile(indices, [batch_size, 1, 1])
        if self.staticshape(values)[0] == 1:
            values = self.tile(values, [batch_size, 1, 1])
        result = self.vectorized_call(scatter_single, base_grid, indices, values)
        if self.dtype(result).kind != out_kind:
            if out_kind == bool:
                result = self.cast(result, BOOL)
        return result

    def histogram1d(self, values, weights, bin_edges):
        """Per-batch weighted histogram; all three arguments are batched along axis 0."""
        def unbatched_hist(values, weights, bin_edges):
            hist, _ = jnp.histogram(values, bin_edges, weights=weights)
            return hist
        return jax.vmap(unbatched_hist)(values, weights, bin_edges)

    def bincount(self, x, weights: Optional[TensorType], bins: int, x_sorted=False):
        """
        Count (or sum `weights`) per bin index in `x`, producing exactly `bins` bins.

        `x_sorted=True` uses `jax.ops.segment_sum` with sorted indices. Without weights, each
        occurrence counts as 1.
        """
        if x_sorted:
            # `weights or 1` would call bool() on an array (ambiguous for >1 element); segment_sum also needs
            # data with one entry per index, so use an explicit array of ones.
            data = weights if weights is not None else jnp.ones_like(x)
            return jax.ops.segment_sum(data, x, bins, indices_are_sorted=True)
        else:
            return jnp.bincount(x, weights=weights, minlength=bins, length=bins)

    def unique(self, x: TensorType, return_inverse: bool, return_counts: bool, axis: int) -> Tuple[TensorType, ...]:
        return jnp.unique(x, return_inverse=return_inverse, return_counts=return_counts, axis=axis)

    def quantile(self, x, quantiles):
        return jnp.quantile(x, quantiles, axis=-1)

    def argsort(self, x, axis=-1):
        return jnp.argsort(x, axis)

    def sort(self, x, axis=-1):
        return jnp.sort(x, axis)

    def searchsorted(self, sorted_sequence, search_values, side: str, dtype=DType(int, 32)):
        """`jnp.searchsorted` cast to `dtype`; for >1-D inputs it is vmapped over the leading axis."""
        if self.ndims(sorted_sequence) == 1:
            return jnp.searchsorted(sorted_sequence, search_values, side=side).astype(to_numpy_dtype(dtype))
        else:
            return jax.vmap(partial(self.searchsorted, side=side, dtype=dtype))(sorted_sequence, search_values)

    def fft(self, x, axes: Union[tuple, list]):
        """Complex FFT over `axes` (identity if `axes` is empty), keeping the complex dtype of the input."""
        x = self.to_complex(x)
        if not axes:
            return x
        if len(axes) == 1:
            return jnp.fft.fft(x, axis=axes[0]).astype(x.dtype)
        elif len(axes) == 2:
            return jnp.fft.fft2(x, axes=axes).astype(x.dtype)
        else:
            return jnp.fft.fftn(x, axes=axes).astype(x.dtype)

    def ifft(self, k, axes: Union[tuple, list]):
        """Inverse FFT over `axes` (identity if `axes` is empty), keeping the dtype of `k`."""
        if not axes:
            return k
        if len(axes) == 1:
            return jnp.fft.ifft(k, axis=axes[0]).astype(k.dtype)
        elif len(axes) == 2:
            return jnp.fft.ifft2(k, axes=axes).astype(k.dtype)
        else:
            return jnp.fft.ifftn(k, axes=axes).astype(k.dtype)

    def dtype(self, array) -> DType:
        """DType of `array`; Python bool/int/float/complex map to BOOL/INT32/FLOAT64/COMPLEX128."""
        if isinstance(array, bool):  # checked before int since bool is a subclass of int
            return BOOL
        if isinstance(array, int):
            return INT32
        if isinstance(array, float):
            return FLOAT64
        if isinstance(array, complex):
            return COMPLEX128
        if not isinstance(array, jnp.ndarray):
            array = jnp.array(array)
        return from_numpy_dtype(array.dtype)

    def matrix_solve_least_squares(self, matrix: TensorType, rhs: TensorType) -> Tuple[TensorType, TensorType, TensorType, TensorType]:
        solution, residuals, rank, singular_values = lstsq_batched(matrix, rhs)
        return solution, residuals, rank, singular_values

    def solve_triangular_dense(self, matrix, rhs, lower: bool, unit_diagonal: bool):
        matrix, rhs = self.auto_cast(matrix, rhs, int_to_float=True, bool_to_int=True)
        x = jax.lax.linalg.triangular_solve(matrix, rhs, lower=lower, unit_diagonal=unit_diagonal, left_side=True)
        return x

    def matrix_rank_dense(self, matrix, hermitian=False):
        """Rank of dense matrices, falling back to the NumPy backend on JAX versions that reject batched input."""
        try:
            return jnp.linalg.matrix_rank(matrix)
        except TypeError as err:
            if err.args[0] == "array should have 2 or fewer dimensions":  # this is a Jax bug on some distributions/versions
                warnings.warn("You are using a broken version of JAX. matrix_rank for dense matrices will fall back to NumPy.")
                return self.as_tensor(NUMPY.matrix_rank_dense(self.numpy(matrix), hermitian=hermitian))
            else:
                raise err

    def eigvals(self, matrix: TensorType) -> TensorType:
        return jnp.linalg.eigvals(matrix)

    def eig(self, matrix: TensorType) -> TensorType:
        return jnp.linalg.eig(matrix)

    def svd(self, matrix: TensorType, full_matrices=True) -> Tuple[TensorType, TensorType, TensorType]:
        result = jnp.linalg.svd(matrix, full_matrices=full_matrices)
        return result[0], result[1], result[2]

    def sparse_coo_tensor(self, indices: Union[tuple, list], values, shape: tuple):
        """Build a JAX BCOO sparse matrix from `indices` and `values`."""
        return BCOO((values, indices), shape=shape)


Ancestors

  • phiml.backend._backend.Backend

Class variables

var arccosh
var arcsinh
var arctan
var arctan2
var arctanh
var cosh
var sinh
var tanh

Static methods

def abs(x: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], /) ‑> jax.Array
def arccos(x, /)
def arcsin(x, /)
def argmax(a, axis=None, out=None, *, keepdims=<no value>)

Returns the indices of the maximum values along an axis.


Parameters

a : array_like
Input array.
axis : int, optional
By default, the index is into the flattened array, otherwise along the specified axis.
out : array, optional
If provided, the result will be inserted into this array. It should be of the appropriate shape and dtype.
keepdims : bool, optional

If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the array.

Added in version: 1.22.0


Returns

index_array : ndarray of ints
Array of indices into the array. It has the same shape as a.shape with the dimension along axis removed. If keepdims is set to True, then the size of axis will be 1 with the resulting array having same shape as a.shape.

See Also

ndarray.argmax, argmin
amax : The maximum value along a given axis.
unravel_index : Convert a flat index into an index tuple.
take_along_axis : Apply ``np.expand_dims(index_array, axis)`` from argmax to an array as if by calling max.


Notes

In case of multiple occurrences of the maximum values, the indices corresponding to the first occurrence are returned.


Examples

>>> a = np.arange(6).reshape(2,3) + 10
>>> a
array([[10, 11, 12],
       [13, 14, 15]])
>>> np.argmax(a)
5
>>> np.argmax(a, axis=0)
array([1, 1, 1])
>>> np.argmax(a, axis=1)
array([2, 2])

Indexes of the maximal elements of a N-dimensional array:

>>> ind = np.unravel_index(np.argmax(a, axis=None), a.shape)
>>> ind
(1, 2)
>>> a[ind]
15
>>> b = np.arange(6)
>>> b[1] = 5
>>> b
array([0, 5, 2, 3, 4, 5])
>>> np.argmax(b)  # Only the first occurrence is returned.
1
>>> x = np.array([[4,2,3], [1,0,3]])
>>> index_array = np.argmax(x, axis=-1)
>>> # Same as np.amax(x, axis=-1, keepdims=True)
>>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1)
array([[4],
       [3]])
>>> # Same as np.amax(x, axis=-1)
>>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1).squeeze(axis=-1)
array([4, 3])

Setting keepdims to True,

>>> x = np.arange(24).reshape((2, 3, 4))
>>> res = np.argmax(x, axis=1, keepdims=True)
>>> res.shape
(2, 1, 4)
def argmin(a, axis=None, out=None, *, keepdims=<no value>)

Returns the indices of the minimum values along an axis.


Parameters

a : array_like
Input array.
axis : int, optional
By default, the index is into the flattened array, otherwise along the specified axis.
out : array, optional
If provided, the result will be inserted into this array. It should be of the appropriate shape and dtype.
keepdims : bool, optional

If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the array.

Added in version: 1.22.0


Returns

index_array : ndarray of ints
Array of indices into the array. It has the same shape as a.shape with the dimension along axis removed. If keepdims is set to True, then the size of axis will be 1 with the resulting array having same shape as a.shape.

See Also

ndarray.argmin, argmax
amin : The minimum value along a given axis.
unravel_index : Convert a flat index into an index tuple.
take_along_axis : Apply ``np.expand_dims(index_array, axis)`` from argmin to an array as if by calling min.


Notes

In case of multiple occurrences of the minimum values, the indices corresponding to the first occurrence are returned.


Examples

>>> a = np.arange(6).reshape(2,3) + 10
>>> a
array([[10, 11, 12],
       [13, 14, 15]])
>>> np.argmin(a)
0
>>> np.argmin(a, axis=0)
array([0, 0, 0])
>>> np.argmin(a, axis=1)
array([0, 0])

Indices of the minimum elements of a N-dimensional array:

>>> ind = np.unravel_index(np.argmin(a, axis=None), a.shape)
>>> ind
(0, 0)
>>> a[ind]
10
>>> b = np.arange(6) + 10
>>> b[4] = 10
>>> b
array([10, 11, 12, 13, 10, 15])
>>> np.argmin(b)  # Only the first occurrence is returned.
0
>>> x = np.array([[4,2,3], [1,0,3]])
>>> index_array = np.argmin(x, axis=-1)
>>> # Same as np.amin(x, axis=-1, keepdims=True)
>>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1)
array([[2],
       [0]])
>>> # Same as np.amin(x, axis=-1)
>>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1).squeeze(axis=-1)
array([2, 0])

Setting keepdims to True,

>>> x = np.arange(24).reshape((2, 3, 4))
>>> res = np.argmin(x, axis=1, keepdims=True)
>>> res.shape
(2, 1, 4)
def ceil(x, /)
def clip(a: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], a_min: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex, None] = None, a_max: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex, None] = None, out: None = None) ‑> jax.Array
def concat(arrays: Union[numpy.ndarray, jax.Array, Sequence[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex]]], axis: Optional[int] = 0, dtype: Union[Any, str, numpy.dtype, jax._src.SupportsDType, None] = None) ‑> jax.Array

Join a sequence of arrays along an existing axis.

LAX-backend implementation of :func:numpy.concatenate.

Original docstring below.


axis : int, optional
The axis along which the arrays will be joined. If axis is None, arrays are flattened before use. Default is 0.
dtype : str or dtype
If provided, the destination array will have this dtype. Cannot be provided together with out.


res : ndarray
The concatenated array.
def conj(x: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], /) ‑> jax.Array
def cos(x, /)
def cumsum(a: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], axis: Union[None, int, Sequence[int]] = None, dtype: Union[Any, str, numpy.dtype, jax._src.SupportsDType] = None, out: None = None) ‑> jax.Array

Return the cumulative sum of the elements along a given axis.

LAX-backend implementation of :func:numpy.cumsum.

Original docstring below.


a : array_like
Input array.
axis : int, optional
Axis along which the cumulative sum is computed. The default (None) is to compute the cumsum over the flattened array.
dtype : dtype, optional
Type of the returned array and of the accumulator in which the elements are summed. If dtype is not specified, it defaults to the dtype of a, unless a has an integer dtype with a precision less than that of the default platform integer. In that case, the default platform integer is used.


cumsum_along_axis : ndarray. A new array holding the result is returned unless out is specified, in which case a reference to out is returned. The result has the same size as a, and the same shape as a if axis is not None or a is a 1-d array.

def einsum(subscripts, /, *operands, out: None = None, optimize: str = 'optimal', precision: Union[None, str, jax._src.lax.lax.Precision, Tuple[str, str], Tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]] = None, preferred_element_type: Union[Any, str, numpy.dtype, jax._src.SupportsDType, None] = None) ‑> jax.Array

Evaluates the Einstein summation convention on the operands.

LAX-backend implementation of :func:numpy.einsum.

In addition to the original NumPy arguments listed below, also supports precision for extra control over matrix-multiplication precision on supported devices. precision may be set to None, which means default precision for the backend, a :class:~jax.lax.Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST) or a tuple of two :class:~jax.lax.Precision enums indicating separate precision for each argument. A tuple precision does not necessarily map to multiple arguments of einsum(); rather, the specified precision is forwarded to each dot_general call used in the implementation.

Original docstring below.

Using the Einstein summation convention, many common multi-dimensional, linear algebraic array operations can be represented in a simple fashion. In implicit mode einsum computes these values.

In explicit mode, einsum provides further flexibility to compute other array operations that might not be considered classical Einstein summation operations, by disabling, or forcing summation over specified subscript labels.

See the notes and examples for clarification.


subscripts : str
Specifies the subscripts for summation as comma separated list of subscript labels. An implicit (classical Einstein summation) calculation is performed unless the explicit indicator '->' is included as well as subscript labels of the precise output form.
operands : list of array_like
These are the arrays for the operation.
optimize : {False, True, 'greedy', 'optimal'}, optional
Controls if intermediate optimization should occur. No optimization will occur if False and True will default to the 'greedy' algorithm. Also accepts an explicit contraction list from the np.einsum_path function. See np.einsum_path for more details. Defaults to False in NumPy; note that the JAX signature above defaults to 'optimal'.


output : ndarray
The calculation based on the Einstein summation convention.
def equal(x1, x2, /)
def erf(x: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex]) ‑> jax.Array

Returns the error function of complex argument.

LAX-backend implementation of :func:scipy.special.erf.

Note that the JAX version does not support complex inputs.

Original docstring below.

It is defined as $\operatorname{erf}(z) = \frac{2}{\sqrt{\pi}} \int_0^z e^{-t^2}\,dt$.


x : ndarray
Input array.


res : scalar or ndarray
The values of the error function at the given points x.


.. [1] Wikipedia, "Error function".
.. [2] Milton Abramowitz and Irene A. Stegun, eds. Handbook of Mathematical Functions with Formulas, Graphs, and Mathematical Tables. New York: Dover, 1972.
.. [3] Steven G. Johnson, Faddeeva W function implementation.

def exp(x, /)
def flip(m: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], axis: Union[int, Tuple[int, ...], None] = None) ‑> jax.Array

Reverse the order of elements in an array along the given axis.

LAX-backend implementation of :func:numpy.flip.

The JAX version of this function may in some cases return a copy rather than a view of the input.

Original docstring below.

The shape of the array is preserved, but the elements are reordered.

Added in version: 1.12.0


m : array_like
Input array.
axis : None or int or tuple of ints, optional

Axis or axes along which to flip over. The default, axis=None, will flip over all of the axes of the input array. If axis is negative it counts from the last to the first axis.

If axis is a tuple of ints, flipping is performed on all of the axes specified in the tuple.

!!! versionchanged "Changed in version: 1.15.0" None and tuples of axes are supported


out : array_like
A view of m with the entries of axis reversed. Since a view is returned, this operation is done in constant time.
def floor(x, /)
def imag(val: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], /) ‑> jax.Array
def isfinite(x: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], /) ‑> jax.Array
def isinf(x: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], /) ‑> jax.Array
def isnan(x: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], /) ‑> jax.Array
def log(x, /)
def log10(x: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], /) ‑> jax.Array
def log2(x: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], /) ‑> jax.Array
def maximum(x1, x2, /)
def minimum(x1, x2, /)
def real(val: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], /) ‑> jax.Array
def round(a: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], decimals: int = 0, out: None = None) ‑> jax.Array
def shape(a)

Return the shape of an array.


a : array_like
Input array.


shape : tuple of ints
The elements of the shape tuple give the lengths of the corresponding array dimensions.

See Also

len : len(a) is equivalent to np.shape(a)[0] for N-D arrays with N>=1.
ndarray.shape : Equivalent array method.


>>> np.shape(np.eye(3))
(3, 3)
>>> np.shape([[1, 3]])
(1, 2)
>>> np.shape([0])
(1,)
>>> np.shape(0)
()
>>> a = np.array([(1, 2), (3, 4), (5, 6)],
...              dtype=[('x', 'i4'), ('y', 'i4')])
>>> np.shape(a)
(3,)
>>> a.shape
(3,)
def sign(x: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], /) ‑> jax.Array
def sin(x, /)
def softplus(x: Any) ‑> Any

Softplus activation function.

Computes the element-wise function

$$\mathrm{softplus}(x) = \log(1 + e^x)$$


x : input array

def sqrt(x, /)
def stack(arrays: Union[numpy.ndarray, jax.Array, Sequence[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex]]], axis: int = 0, out: None = None, dtype: Union[Any, str, numpy.dtype, jax._src.SupportsDType, None] = None) ‑> jax.Array

Join a sequence of arrays along a new axis.

LAX-backend implementation of :func:numpy.stack.

Original docstring below.

The axis parameter specifies the index of the new axis in the dimensions of the result. For example, if axis=0 it will be the first dimension and if axis=-1 it will be the last dimension.

Added in version: 1.10.0


arrays : sequence of array_like
Each array must have the same shape.
axis : int, optional
The axis in the result array along which the input arrays are stacked.
dtype : str or dtype
If provided, the destination array will have this dtype. Cannot be provided together with out.


stacked : ndarray
The stacked array has one more dimension than the input arrays.
def staticshape(a)

Return the shape of an array.


a : array_like
Input array.


shape : tuple of ints
The elements of the shape tuple give the lengths of the corresponding array dimensions.

See Also

len : len(a) is equivalent to np.shape(a)[0] for N-D arrays with N>=1.
ndarray.shape : Equivalent array method.


>>> np.shape(np.eye(3))
(3, 3)
>>> np.shape([[1, 3]])
(1, 2)
>>> np.shape([0])
(1,)
>>> np.shape(0)
()
>>> a = np.array([(1, 2), (3, 4), (5, 6)],
...              dtype=[('x', 'i4'), ('y', 'i4')])
>>> np.shape(a)
(3,)
>>> a.shape
(3,)
def stop_gradient(x: ~T) ‑> ~T

Stops gradient computation.

Operationally stop_gradient is the identity function, that is, it returns argument x unchanged. However, stop_gradient prevents the flow of gradients during forward or reverse-mode automatic differentiation. If there are multiple nested gradient computations, stop_gradient stops gradients for all of them.

For example:

>>> jax.grad(lambda x: x**2)(3.)
Array(6., dtype=float32, weak_type=True)
>>> jax.grad(lambda x: jax.lax.stop_gradient(x)**2)(3.)
Array(0., dtype=float32, weak_type=True)
>>> jax.grad(jax.grad(lambda x: x**2))(3.)
Array(2., dtype=float32, weak_type=True)
>>> jax.grad(jax.grad(lambda x: jax.lax.stop_gradient(x)**2))(3.)
Array(0., dtype=float32, weak_type=True)
def tan(x, /)
def tile(A: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], reps: Union[int, Any, Sequence[Union[int, Any]]]) ‑> jax.Array

Construct an array by repeating A the number of times given by reps.

LAX-backend implementation of :func:numpy.tile.

Original docstring below.

If reps has length d, the result will have dimension of max(d, A.ndim).

If A.ndim < d, A is promoted to be d-dimensional by prepending new axes. So a shape (3,) array is promoted to (1, 3) for 2-D replication, or shape (1, 1, 3) for 3-D replication. If this is not the desired behavior, promote A to d-dimensions manually before calling this function.

If A.ndim > d, reps is promoted to A.ndim by pre-pending 1's to it. Thus for an A of shape (2, 3, 4, 5), a reps of (2, 2) is treated as (1, 1, 2, 2).

Note : Although tile may be used for broadcasting, it is strongly recommended to use numpy's broadcasting operations and functions.


A : array_like
The input array.
reps : array_like
The number of repetitions of A along each axis.


c : ndarray
The tiled output array.
def transpose(a: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], axes: Optional[Sequence[int]] = None) ‑> jax.Array

Returns an array with axes transposed.

LAX-backend implementation of :func:numpy.transpose.

The JAX version of this function may in some cases return a copy rather than a view of the input.

Original docstring below.

For a 1-D array, this returns an unchanged view of the original array, as a transposed vector is simply the same vector. To convert a 1-D array into a 2-D column vector, an additional dimension must be added, e.g., np.atleast_2d(a).T achieves this, as does a[:, np.newaxis]. For a 2-D array, this is the standard matrix transpose. For an n-D array, if axes are given, their order indicates how the axes are permuted (see Examples). If axes are not provided, then transpose(a).shape == a.shape[::-1].


a : array_like
Input array.
axes : tuple or list of ints, optional
If specified, it must be a tuple or list which contains a permutation of [0,1,…,N-1] where N is the number of axes of a. The i'th axis of the returned array will correspond to the axis numbered axes[i] of the input. If not specified, defaults to range(a.ndim)[::-1], which reverses the order of the axes.


p : ndarray
a with its axes permuted. A view is returned whenever possible.


def all(self, boolean_tensor, axis=None, keepdims=False)
def allocate_on_device(self, tensor: ~TensorType, device: phiml.backend._backend.ComputeDevice) ‑> ~TensorType

Moves tensor to device. May copy the tensor if it is already on the device.


Existing tensor native to this backend.
Target device, associated with this backend.
def any(self, boolean_tensor, axis=None, keepdims=False)
def argsort(self, x, axis=-1)
def as_tensor(self, x, convert_external=True)

Converts a tensor-like object to the native tensor representation of this backend. If x is a native tensor of this backend, it is returned without modification. If x is a Python number (numbers.Number instance), convert_external decides whether to convert it unless the backend cannot handle Python numbers.

Note: There may be objects that are considered tensors by this backend but are not native and thus, will be converted by this method.


tensor-like, e.g. list, tuple, Python number, tensor
if False and x is a Python number that is understood by this backend, this method returns the number as-is. This can help prevent type clashes like int32 vs int64. (Default value = True)


tensor representation of x

def batched_gather_1d(self, values, indices)


(batch, spatial)
(batch, indices)


(batch, indices)

def batched_gather_nd(self, values, indices)

Gathers values from the tensor values at locations indices. The first dimension of values and indices is the batch dimension which must be either equal for both or one for either.


tensor of shape (batch, spatial…, channel)
int tensor of shape (batch, any…, multi_index) where the size of multi_index is values.rank - 2.


Gathered values as tensor of shape (batch, any…, channel)

def bincount(self, x, weights: Optional[~TensorType], bins: int, x_sorted=False)


Bin indices, 1D int tensor.
Weights corresponding to x, 1D tensor. All weights are 1 if weights=None.
Number of bins.
Whether x is sorted from lowest to highest bin.



def block_until_ready(self, values)
def boolean_mask(self, x, mask, axis=0, new_length=None, fill_value=0)


tensor with any number of dimensions
1D mask tensor
Axis index >= 0
Maximum size of the output along axis. This must be set when jit-compiling with Jax.
If new_length is larger than the filtered result, the remaining values will be set to fill_value.
def cast(self, x, dtype: phiml.backend._dtype.DType)
def conv(self, value, kernel, zero_padding=True)

Convolve value with kernel. Depending on the tensor rank, the convolution is either 1D (rank=3), 2D (rank=4) or 3D (rank=5). Higher dimensions may not be supported.


tensor of shape (batch_size, in_channel, spatial…)
tensor of shape (batch_size or 1, out_channel, in_channel, spatial…)
If True, pads the edges of value with zeros so that the result has the same shape as value.


Convolution result as tensor of shape (batch_size, out_channel, spatial…)

def copy(self, tensor, only_mutable=False)
def custom_gradient(self, f: Callable, gradient: Callable, get_external_cache: Callable = None, on_call_skipped: Callable = None) ‑> Callable

Creates a function based on f that uses a custom gradient for backprop.


Forward function.
Function for backprop. Will be called as gradient(*d_out) to compute the gradient of f.


Function with similar signature and return values as f. However, the returned function does not support keyword arguments.

def disassemble(self, x) ‑> Tuple[Callable, Sequence[~TensorType]]

Disassemble a (sparse) tensor into its individual constituents, such as values and indices.




Function assemble(backend, *constituents) that reassembles x from the constituents.
Tensors contained in x.
def divide_no_nan(self, x, y)

Computes x/y but returns 0 if y=0.

def dtype(self, array) ‑> phiml.backend._dtype.DType
def eig(self, matrix: ~TensorType) ‑> ~TensorType


(batch…, n, n)


(batch…, n,)
(batch…, n, n)
def eigvals(self, matrix: ~TensorType) ‑> ~TensorType


(batch…, n, n)


eigenvalues as (batch…, n,)

def expand_dims(self, a, axis=0, number=1)
def fft(self, x, axes: Union[tuple, list])

Computes the n-dimensional FFT along the dimensions listed in axes.


tensor of dimension 3 or higher
Along which axes to perform the FFT


Complex tensor k

def from_dlpack(self, capsule)
def gamma_inc_l(self, a, x)

Regularized lower incomplete gamma function.

def gamma_inc_u(self, a, x)

Regularized upper incomplete gamma function.

def gather(self, values, indices, axis: int)

Gathers values from the tensor values at locations indices.


1D tensor
Axis along which to gather slices


tensor, with size along axis being the length of indices

def get_device(self, tensor: ~TensorType) ‑> phiml.backend._backend.ComputeDevice

Returns the device tensor is located on.

def get_diagonal(self, matrices, offset=0)


(batch, rows, cols, channels)
0=diagonal, positive=above diagonal, negative=below diagonal


(batch, diagonal_length, channels), where diagonal_length is at most min(rows, cols) (exactly min(rows, cols) for offset=0)
def get_sparse_format(self, x) ‑> str

Returns lower-case format string, such as 'coo', 'csr', 'csc'

def histogram1d(self, values, weights, bin_edges)


(batch, values)
(batch, edges)
(batch, values)


(batch, edges) with dtype matching weights

def ifft(self, k, axes: Union[tuple, list])

Computes the n-dimensional inverse FFT along the dimensions listed in axes.


tensor of dimension 3 or higher
Along which axes to perform the inverse FFT


Complex tensor x

def is_available(self, tensor)

Tests if the value of the tensor is known and can be read at this point. If true, numpy(tensor) must return a valid NumPy representation of the value.

Tensors are typically available when the backend operates in eager mode.


backend-compatible tensor



def is_module(self, obj)

Tests if obj is of a type that is specific to this backend, e.g. a neural network. If True, this backend will be chosen for operations involving obj.

See Also: Backend.is_tensor().


Object to test.
def is_sparse(self, x) ‑> bool


Tensor native to this Backend.
def is_tensor(self, x, only_native=False)

An object is considered a native tensor by a backend if no internal conversion is required by backend methods. An object is considered a tensor (native or otherwise) by a backend if it is not a struct (e.g. tuple, list) and all methods of the backend accept it as a tensor argument.

If True, this backend will be chosen for operations involving x.

See Also: Backend.is_module().


object to check
If True, only accepts true native tensor representations, not Python numbers or others that are also supported as tensors (Default value = False)


whether x is considered a tensor by this backend
def jacobian(self, f, wrt: Union[tuple, list], get_output: bool, is_f_scalar: bool)


Function to differentiate. Returns a tuple containing (reduced_loss, output)
Argument indices for which to compute the gradient.
Whether the derivative function should return the output of f in addition to the gradient.
Whether f is guaranteed to return a scalar output.


A function g with the same arguments as f. If get_output=True, g returns a tuple containing the outputs of f followed by the gradients. The gradients retain the dimensions of reduced_loss in order as outer (first) dimensions.

def jit_compile(self, f: Callable) ‑> Callable
def linspace(self, start, stop, number)
def linspace_without_last(self, start, stop, number)
def log_gamma(self, x)
def matrix_rank_dense(self, matrix, hermitian=False)


Dense matrix of shape (batch, rows, cols)
Whether all matrices are guaranteed to be hermitian.
def matrix_solve_least_squares(self, matrix: ~TensorType, rhs: ~TensorType) ‑> Tuple[~TensorType, ~TensorType, ~TensorType, ~TensorType]


Shape (batch, vec, constraints)
Shape (batch, vec, batch_per_matrix)


Solution vector of Shape (batch, constraints, batch_per_matrix)
Optional, can be None
Optional, can be None
Optional, can be None
def max(self, x, axis=None, keepdims=False)
def mean(self, value, axis=None, keepdims=False)
def meshgrid(self, *coordinates)
def min(self, x, axis=None, keepdims=False)
def mul(self, a, b)
def mul_matrix_batched_vector(self, A, b)
def nn_library(self)
def nonzero(self, values, length=None, fill_value=-1)


Tensor with only spatial dimensions
(Optional) Length of the resulting array. If specified, the result array will be padded with fill_value or trimmed.


non-zero multi-indices as tensor of shape (nnz/length, vector)

def numpy(self, tensor)

Returns a NumPy representation of the given tensor. If tensor is already a NumPy array, it is returned without modification.

This method raises an error if the value of the tensor is not known at this point, e.g. because it represents a node in a graph. Use is_available(tensor) to check if the value can be represented as a NumPy array.


backend-compatible tensor or sparse tensor


NumPy representation of the values stored in the tensor

def numpy_call(self, f, output_shapes, output_dtypes, *args, **aux_args)

This call can be used in jit-compiled code but is not differentiable.


Function operating on numpy arrays.
Single shape tuple or tuple of shapes declaring the shapes of the tensors returned by f.
Single DType or tuple of DTypes declaring the dtypes of the tensors returned by f.
Tensor arguments to be converted to NumPy arrays and then passed to f.
Keyword arguments to be passed to f without conversion.


Returned arrays of f converted to tensors.

def ones(self, shape, dtype: phiml.backend._dtype.DType = None)
def ones_like(self, tensor)
def pad(self, value, pad_width, mode='constant', constant_values=0)

Pad a tensor with values as specified by mode and constant_values.

If the mode is not supported, returns NotImplemented.


2D tensor specifying the number of values padded to the edges of each axis in the form [[axis 0 lower, axis 0 upper], …] including batch and component axes.
'constant', 'boundary', 'periodic', 'symmetric', 'reflect'
Scalar value used for out-of-bounds points if mode='constant'. Must be a Python primitive type or scalar tensor.
str: (Default value = 'constant')


padded tensor or NotImplemented

def prefers_channels_last(self) ‑> bool
def prod(self, value, axis=None)
def quantile(self, x, quantiles)

Reduces the last / inner axis of x.


List or 1D tensor of quantiles to compute.


Tensor with shape (quantiles, *x.shape[:-1])

def random_normal(self, shape, dtype: phiml.backend._dtype.DType)

Float tensor of selected precision containing random values sampled from a normal distribution with mean 0 and std 1.

def random_permutations(self, permutations: int, n: int)

Generates a stack of `permutations` arrays, each containing the integers between 0 and n in shuffled order.

def random_uniform(self, shape, low, high, dtype: Optional[phiml.backend._dtype.DType])

Tensor of the selected precision containing random values sampled uniformly from the range [low, high).

def range(self, start, limit=None, delta=1, dtype: phiml.backend._dtype.DType = int32)
def ravel_multi_index(self, multi_index, shape, mode: Union[str, int] = 'undefined')


(batch…, index_dim)
1D tensor or tuple/list
'undefined', 'periodic', 'clamp' or an int to use for all invalid indices.


Integer tensor of shape (batch…) of same dtype as multi_index.

def repeat(self, x, repeats, axis: int, new_length=None)

Repeats the elements along axis repeats times.


How often to repeat each element. 1D tensor of length x.shape[axis]
Which axis to repeat elements along
Set the length of axis after repeating. This is required for jit compilation with Jax.


repeated Tensor

def requires_fixed_shapes_when_tracing(self) ‑> bool
def reshape(self, value, shape)
def scatter(self, base_grid, indices, values, mode: str)

Batched n-dimensional scatter.


Tensor into which scatter values are inserted at indices. Tensor of shape (batch_size, spatial…, channels)
Tensor of shape (batch_size or 1, update_count, index_vector)
Values to scatter at indices. Tensor of shape (batch_size or 1, update_count or 1, channels or 1)
One of ('update', 'add', 'max', 'min')


Copy of base_grid with values at indices updated by values.

def searchsorted(self, sorted_sequence, search_values, side: str, dtype=int32)
def seed(self, seed: int)
def solve_triangular_dense(self, matrix, rhs, lower: bool, unit_diagonal: bool)


(batch_size, rows, cols)
(batch_size, cols)

lower: unit_diagonal:


(batch_size, cols)

def sort(self, x, axis=-1)
def sparse_coo_tensor(self, indices: Union[tuple, list], values, shape: tuple)

Create a sparse matrix in coordinate list (COO) format.

Optional feature.

See Also: Backend.csr_matrix(), Backend.csc_matrix().


2D tensor of shape (nnz, dims).
1D values tensor matching indices
Shape of the sparse matrix


Native representation of the sparse matrix

def std(self, x, axis=None, keepdims=False)
def sum(self, value, axis=None, keepdims=False)
def svd(self, matrix: ~TensorType, full_matrices=True) ‑> Tuple[~TensorType, ~TensorType, ~TensorType]


(batch…, m, n)


(batch…, n,)
(batch…, n, n)
def tensordot(self, a, a_axes: Union[tuple, list], b, b_axes: Union[tuple, list])

Multiply-sum-reduce a_axes of a with b_axes of b.

def to_dlpack(self, tensor)
def unique(self, x: ~TensorType, return_inverse: bool, return_counts: bool, axis: int) ‑> Tuple[~TensorType, ...]


n-dimensional int array. Will compare axis-slices of x for multidimensional x.
Whether to return the inverse
Whether to return the counts.
Axis along which slices of x should be compared.


Sorted unique slices of x
(optional) index of the unique slice for each slice of x
Number of occurrences of each unique slices
def unravel_index(self, flat_index, shape)
def vectorized_call(self, f, *args, output_dtypes=None, **aux_args)


Function with only positional tensor argument, returning one or multiple tensors.
Batched inputs for f. The first dimension of all args is vectorized. All tensors in args must have the same size or 1 in their first dimension.
Single DType or tuple of DTypes declaring the dtypes of the tensors returned by f.
Non-vectorized keyword arguments to be passed to f.
def where(self, condition, x=None, y=None)
def while_loop(self, loop: Callable, values: tuple, max_iter: Union[int, Tuple[int, ...], List[int]])

If max_iter is None, runs

while any(values[0]):
    values = loop(*values)
return values

This operation does not support backpropagation.


Loop function, must return a tuple with entries equal to values in shape and data type.
Initial values of loop variables.
Maximum number of iterations to run, single int or sequence of integers.


Loop variables upon loop completion if max_iter is a single integer. If max_iter is a sequence, stacks the variables after each entry in max_iter, adding an outer dimension of size <= len(max_iter). If the condition is fulfilled before the maximum max_iter is reached, the loop may be broken or not, depending on the implementation. If the loop is broken, the values returned by the last loop are expected to be constant and filled.

def zeros(self, shape, dtype: phiml.backend._dtype.DType = None)
def zeros_like(self, tensor)