Module phiml.backend.jax

Expand source code
"""Backend for Jax operations."""
from ._jax_backend import JaxBackend

JAX = JaxBackend()

__all__ = [key for key in globals().keys() if not key.startswith('_')]
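
A brief usage sketch of the module-level JAX backend instance; it calls only methods shown in the JaxBackend listing below, and the shapes and values are illustrative.

from phiml.backend.jax import JAX

JAX.seed(0)                                    # reset the internal PRNG key
x = JAX.random_uniform((2, 3), 0., 1., None)   # jax.Array on the default device
total = JAX.sum(x)                             # delegates to jnp.sum
print(JAX.numpy(total))                        # convert the result back to NumPy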

Sub-modules

phiml.backend.jax.stax_nets

Stax implementation of the unified machine learning API. Equivalent functions also exist for the other frameworks …

Classes

class JaxBackend

Backends delegate low-level operations to an ML or numerics library, or emulate them. The methods of Backend form a comprehensive list of available operations.

To support a library, subclass Backend and register it by adding it to BACKENDS (see the sketch after the argument list below).

Args

name
Human-readable string
default_device
ComputeDevice being used by default
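
A minimal registration sketch of this pattern, assuming Backend, ComputeDevice, and BACKENDS are importable from phiml.backend; the backend name and device values are illustrative only, and a real backend must implement the Backend operations (compare the JaxBackend source below).

from phiml.backend import Backend, BACKENDS, ComputeDevice

class MyBackend(Backend):  # hypothetical skeleton, not a working backend
    def __init__(self):
        # Device arguments mirror the ComputeDevice(...) call in JaxBackend.__init__ below.
        cpu = ComputeDevice(self, 'CPU', 'CPU', -1, -1, 'illustrative device', None)
        Backend.__init__(self, 'mylib', [cpu], cpu)
    # ... override the Backend operations (as_tensor, is_tensor, numpy, ...) here

BACKENDS.append(MyBackend())  # register so phiml can discover the backend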
Expand source code
class JaxBackend(Backend):

    def __init__(self):
        devices = []
        for device_type in ['cpu', 'gpu', 'tpu']:
            try:
                for jax_dev in jax.devices(device_type):
                    devices.append(ComputeDevice(self, device_type.upper(), jax_dev.platform.upper(), -1, -1, f"id={jax_dev.id}", jax_dev))
            except RuntimeError as err:
                pass  # this is just Jax not finding anything. jaxlib.xla_client._get_local_backends() could help but isn't currently available on GitHub actions
        Backend.__init__(self, 'jax', devices, devices[-1])
        try:
            self.rnd_key = jax.random.PRNGKey(seed=0)
        except RuntimeError as err:
            warnings.warn(f"{err}", RuntimeWarning)
            self.rnd_key = None

    def prefers_channels_last(self) -> bool:
        return True

    def requires_fixed_shapes_when_tracing(self) -> bool:
        return True

    def nn_library(self):
        from . import stax_nets
        return stax_nets

    def _check_float64(self):
        if self.precision == 64:
            if not jax.config.read('jax_enable_x64'):
                jax.config.update('jax_enable_x64', True)
            assert jax.config.read('jax_enable_x64'), "FP64 is disabled for Jax."

    def seed(self, seed: int):
        self.rnd_key = jax.random.PRNGKey(seed)

    def as_tensor(self, x, convert_external=True):
        self._check_float64()
        if self.is_tensor(x, only_native=convert_external):
            array = x
        else:
            array = jnp.array(x)
        # --- Enforce Precision ---
        if not isinstance(array, numbers.Number):
            if self.dtype(array).kind == float:
                array = self.to_float(array)
            elif self.dtype(array).kind == complex:
                array = self.to_complex(array)
        return array

    def is_module(self, obj):
        return False

    def is_tensor(self, x, only_native=False):
        if isinstance(x, jnp.ndarray) and not isinstance(x, np.ndarray):  # NumPy arrays inherit from Jax arrays
            return True
        if isinstance(x, jnp.bool_) and not isinstance(x, np.bool_):
            return True
        if self.is_sparse(x):
            return True
        # --- Above considered native ---
        if only_native:
            return False
        # --- Non-native types ---
        if isinstance(x, np.ndarray):
            return True
        if isinstance(x, np.bool_):
            return True
        if isinstance(x, (numbers.Number, bool)):
            return True
        if isinstance(x, (tuple, list)):
            return all([self.is_tensor(item, False) for item in x])
        return False

    def is_sparse(self, x) -> bool:
        return isinstance(x, (COO, BCOO, CSR, CSC))

    def get_sparse_format(self, x) -> str:
        format_names = {
            COO: 'coo',
            BCOO: 'coo',
            CSR: 'csr',
            CSC: 'csc',
        }
        return format_names.get(type(x), 'dense')

    def is_available(self, tensor):
        return not isinstance(tensor, Tracer)

    def numpy(self, tensor):
        if self.is_sparse(tensor):
            assemble, parts = self.disassemble(tensor)
            return assemble(NUMPY, *[self.numpy(t) for t in parts])
        return np.array(tensor)

    def disassemble(self, x) -> Tuple[Callable, Sequence[TensorType]]:
        if self.is_sparse(x):
            if isinstance(x, COO):
                raise NotImplementedError
                # return lambda b, i, v: b.sparse_coo_tensor(i, v, x.shape), (np.stack([x.row, x.col], -1), x.data)
            if isinstance(x, BCOO):
                return lambda b, i, v: b.sparse_coo_tensor(i, v, x.shape), (x.indices, x.data)
            if isinstance(x, CSR):
                raise NotImplementedError
                # return lambda b, v, i, p: b.csr_matrix(i, p, v, x.shape), (x.data, x.indices, x.indptr)
            elif isinstance(x, CSC):
                raise NotImplementedError
                # return lambda b, v, i, p: b.csc_matrix(p, i, v, x.shape), (x.data, x.indices, x.indptr)
            raise NotImplementedError
        else:
            return lambda b, t: t, (x,)

    def to_dlpack(self, tensor):
        from jax import dlpack
        return dlpack.to_dlpack(tensor)

    def from_dlpack(self, capsule):
        from jax import dlpack
        return dlpack.from_dlpack(capsule)

    def copy(self, tensor, only_mutable=False):
        return jnp.array(tensor, copy=True)

    def get_device(self, tensor: TensorType) -> ComputeDevice:
        return self.get_device_by_ref(tensor.device())

    def allocate_on_device(self, tensor: TensorType, device: ComputeDevice) -> TensorType:
        return jax.device_put(tensor, device.ref)

    sqrt = staticmethod(jnp.sqrt)
    exp = staticmethod(jnp.exp)
    erf = staticmethod(scipy.special.erf)
    softplus = staticmethod(jax.nn.softplus)
    sin = staticmethod(jnp.sin)
    arcsin = staticmethod(jnp.arcsin)
    cos = staticmethod(jnp.cos)
    arccos = staticmethod(jnp.arccos)
    tan = staticmethod(jnp.tan)
    arctan = staticmethod(np.arctan)
    arctan2 = staticmethod(np.arctan2)
    sinh = staticmethod(np.sinh)
    arcsinh = staticmethod(np.arcsinh)
    cosh = staticmethod(np.cosh)
    arccosh = staticmethod(np.arccosh)
    tanh = staticmethod(np.tanh)
    arctanh = staticmethod(np.arctanh)
    log = staticmethod(jnp.log)
    log2 = staticmethod(jnp.log2)
    log10 = staticmethod(jnp.log10)
    isfinite = staticmethod(jnp.isfinite)
    isnan = staticmethod(jnp.isnan)
    isinf = staticmethod(jnp.isinf)
    abs = staticmethod(jnp.abs)
    sign = staticmethod(jnp.sign)
    round = staticmethod(jnp.round)
    ceil = staticmethod(jnp.ceil)
    floor = staticmethod(jnp.floor)
    flip = staticmethod(jnp.flip)
    stop_gradient = staticmethod(jax.lax.stop_gradient)
    transpose = staticmethod(jnp.transpose)
    equal = staticmethod(jnp.equal)
    tile = staticmethod(jnp.tile)
    stack = staticmethod(jnp.stack)
    concat = staticmethod(jnp.concatenate)
    maximum = staticmethod(jnp.maximum)
    minimum = staticmethod(jnp.minimum)
    clip = staticmethod(jnp.clip)
    argmax = staticmethod(np.argmax)
    argmin = staticmethod(np.argmin)
    shape = staticmethod(jnp.shape)
    staticshape = staticmethod(jnp.shape)
    imag = staticmethod(jnp.imag)
    real = staticmethod(jnp.real)
    conj = staticmethod(jnp.conjugate)
    einsum = staticmethod(jnp.einsum)
    cumsum = staticmethod(jnp.cumsum)

    def nonzero(self, values, length=None, fill_value=-1):
        result = jnp.nonzero(values, size=length, fill_value=fill_value)
        return jnp.stack(result, -1)

    def vectorized_call(self, f, *args, output_dtypes=None, **aux_args):
        batch_size = self.determine_size(args, 0)
        args = [self.tile_to(t, 0, batch_size) for t in args]
        def f_positional(*args):
            return f(*args, **aux_args)
        vec_f = jax.vmap(f_positional, 0, 0)
        return vec_f(*args)

    def numpy_call(self, f, output_shapes, output_dtypes, *args, **aux_args):
        @dataclasses.dataclass
        class OutputTensor:
            shape: Tuple[int]
            dtype: np.dtype
        output_specs = map_structure(lambda t, s: OutputTensor(s, to_numpy_dtype(t)), output_dtypes, output_shapes)
        if hasattr(jax, 'pure_callback'):
            def aux_f(*args):
                return f(*args, **aux_args)
            return jax.pure_callback(aux_f, output_specs, *args)
        else:
            def aux_f(args):
                if isinstance(args, tuple):
                    return f(*args, **aux_args)
                else:
                    return f(args, **aux_args)
            from jax.experimental.host_callback import call
            return call(aux_f, args, result_shape=output_specs)

    def jit_compile(self, f: Callable) -> Callable:
        def run_jit_f(*args):
            # print(jax.make_jaxpr(f)(*args))
            ML_LOGGER.debug(f"JaxBackend: running jit-compiled '{f.__name__}' with shapes {[self.shape(arg) for arg in args]} and dtypes {[self.dtype(arg) for arg in args]}")
            return self.as_registered.call(jit_f, *args, name=f"run jit-compiled '{f.__name__}'")

        run_jit_f.__name__ = f"Jax-Jit({f.__name__})"
        jit_f = jax.jit(f, device=self._default_device.ref)
        return run_jit_f

    def block_until_ready(self, values):
        if hasattr(values, 'block_until_ready'):
            values.block_until_ready()
        if isinstance(values, (tuple, list)):
            for v in values:
                self.block_until_ready(v)

    def jacobian(self, f, wrt: Union[tuple, list], get_output: bool, is_f_scalar: bool):
        if get_output:
            jax_grad_f = jax.value_and_grad(f, argnums=wrt, has_aux=True)
            @wraps(f)
            def unwrap_outputs(*args):
                args = [self.to_float(arg) if self.dtype(arg).kind in (bool, int) and i in wrt else arg for i, arg in enumerate(args)]
                (_, output_tuple), grads = jax_grad_f(*args)
                return (*output_tuple, *[jnp.conjugate(g) for g in grads])
            return unwrap_outputs
        else:
            @wraps(f)
            def nonaux_f(*args):
                loss, output = f(*args)
                return loss
            jax_grad = jax.grad(nonaux_f, argnums=wrt, has_aux=False)
            @wraps(f)
            def call_jax_grad(*args):
                args = [self.to_float(arg) if self.dtype(arg).kind in (bool, int) and i in wrt else arg for i, arg in enumerate(args)]
                grads = jax_grad(*args)
                return tuple([jnp.conjugate(g) for g in grads])
            return call_jax_grad

    def custom_gradient(self, f: Callable, gradient: Callable, get_external_cache: Callable = None, on_call_skipped: Callable = None) -> Callable:
        jax_fun = jax.custom_vjp(f)  # custom vector-Jacobian product (reverse-mode differentiation)

        def forward(*x):
            y = f(*x)
            return y, (x, y)

        def backward(x_y, dy):
            x, y = x_y
            dx = gradient(x, y, dy)
            return tuple(dx)

        jax_fun.defvjp(forward, backward)
        return jax_fun

    def divide_no_nan(self, x, y):
        return jnp.where(y == 0, 0, x / y)
        # jnp.nan_to_num(x / y, copy=True, nan=0) covers up NaNs from before

    def random_uniform(self, shape, low, high, dtype: Union[DType, None]):
        self._check_float64()
        self.rnd_key, subkey = jax.random.split(self.rnd_key)

        dtype = dtype or self.float_type
        jdt = to_numpy_dtype(dtype)
        if dtype.kind == float:
            tensor = random.uniform(subkey, shape, minval=low, maxval=high, dtype=jdt)
        elif dtype.kind == complex:
            real = random.uniform(subkey, shape, minval=low.real, maxval=high.real, dtype=to_numpy_dtype(DType(float, dtype.precision)))
            imag = random.uniform(subkey, shape, minval=low.imag, maxval=high.imag, dtype=to_numpy_dtype(DType(float, dtype.precision)))
            return real + 1j * imag
        elif dtype.kind == int:
            tensor = random.randint(subkey, shape, low, high, dtype=jdt)
            if tensor.dtype != jdt:
                warnings.warn(f"Jax failed to sample random integers with dtype {dtype}, returned {tensor.dtype} instead.", RuntimeWarning)
        else:
            raise ValueError(dtype)
        return jax.device_put(tensor, self._default_device.ref)

    def random_normal(self, shape, dtype: DType):
        self._check_float64()
        self.rnd_key, subkey = jax.random.split(self.rnd_key)
        dtype = dtype or self.float_type
        return jax.device_put(random.normal(subkey, shape, dtype=to_numpy_dtype(dtype)), self._default_device.ref)

    def range(self, start, limit=None, delta=1, dtype: DType = DType(int, 32)):
        if limit is None:
            start, limit = 0, start
        return jnp.arange(start, limit, delta, to_numpy_dtype(dtype))

    def pad(self, value, pad_width, mode='constant', constant_values=0):
        assert mode in ('constant', 'symmetric', 'periodic', 'reflect', 'boundary'), mode
        if mode == 'constant':
            constant_values = jnp.array(constant_values, dtype=value.dtype)
            return jnp.pad(value, pad_width, 'constant', constant_values=constant_values)
        else:
            if mode in ('periodic', 'boundary'):
                mode = {'periodic': 'wrap', 'boundary': 'edge'}[mode]
            return jnp.pad(value, pad_width, mode)

    def reshape(self, value, shape):
        return jnp.reshape(value, shape)

    def sum(self, value, axis=None, keepdims=False):
        if isinstance(value, (tuple, list)):
            assert axis == 0
            return sum(value[1:], value[0])
        return jnp.sum(value, axis=axis, keepdims=keepdims)

    def prod(self, value, axis=None):
        if not isinstance(value, jnp.ndarray):
            value = jnp.array(value)
        if value.dtype == bool:
            return jnp.all(value, axis=axis)
        return jnp.prod(value, axis=axis)

    def where(self, condition, x=None, y=None):
        if x is None or y is None:
            return jnp.argwhere(condition)
        return jnp.where(condition, x, y)

    def zeros(self, shape, dtype: DType = None):
        self._check_float64()
        return jax.device_put(jnp.zeros(shape, dtype=to_numpy_dtype(dtype or self.float_type)), self._default_device.ref)

    def zeros_like(self, tensor):
        return jax.device_put(jnp.zeros_like(tensor), self._default_device.ref)

    def ones(self, shape, dtype: DType = None):
        self._check_float64()
        return jax.device_put(jnp.ones(shape, dtype=to_numpy_dtype(dtype or self.float_type)), self._default_device.ref)

    def ones_like(self, tensor):
        return jax.device_put(jnp.ones_like(tensor), self._default_device.ref)

    def meshgrid(self, *coordinates):
        self._check_float64()
        coordinates = [self.as_tensor(c) for c in coordinates]
        return [jax.device_put(c, self._default_device.ref) for c in jnp.meshgrid(*coordinates, indexing='ij')]

    def linspace(self, start, stop, number):
        self._check_float64()
        return jax.device_put(jnp.linspace(start, stop, number, dtype=to_numpy_dtype(self.float_type)), self._default_device.ref)

    def linspace_without_last(self, start, stop, number):
        self._check_float64()
        return jax.device_put(jnp.linspace(start, stop, number, endpoint=False, dtype=to_numpy_dtype(self.float_type)), self._default_device.ref)

    def mean(self, value, axis=None, keepdims=False):
        return jnp.mean(value, axis, keepdims=keepdims)

    def log_gamma(self, x):
        return jax.lax.lgamma(self.to_float(x))

    def gamma_inc_l(self, a, x):
        return scipy.special.gammainc(a, x)

    def gamma_inc_u(self, a, x):
        return scipy.special.gammaincc(a, x)

    def tensordot(self, a, a_axes: Union[tuple, list], b, b_axes: Union[tuple, list]):
        return jnp.tensordot(a, b, (a_axes, b_axes))

    def mul(self, a, b):
        # if scipy.sparse.issparse(a):  # TODO sparse?
        #     return a.multiply(b)
        # elif scipy.sparse.issparse(b):
        #     return b.multiply(a)
        # else:
            return Backend.mul(self, a, b)

    def mul_matrix_batched_vector(self, A, b):
        from jax.experimental.sparse import BCOO
        if isinstance(A, BCOO):
            return (A @ b.T).T
        return jnp.stack([A.dot(b[i]) for i in range(b.shape[0])])

    def get_diagonal(self, matrices, offset=0):
        result = jnp.diagonal(matrices, offset=offset, axis1=1, axis2=2)
        return jnp.transpose(result, [0, 2, 1])

    def while_loop(self, loop: Callable, values: tuple, max_iter: Union[int, Tuple[int, ...], List[int]]):
        if all(self.is_available(t) for t in values):
            return self.stop_gradient_tree(Backend.while_loop(self, loop, values, max_iter))
        if isinstance(max_iter, (tuple, list)):  # stack traced trajectory, unroll until max_iter
            values = self.stop_gradient_tree(values)
            trj = [values] if 0 in max_iter else []
            for i in range(1, max(max_iter) + 1):
                values = loop(*values)
                if i in max_iter:
                    trj.append(values)  # values are not mutable so no need to copy
            return self.stop_gradient_tree(self.stack_leaves(trj))
        else:
            if max_iter is None:
                cond = lambda vals: jnp.any(vals[0])
                body = lambda vals: loop(*vals)
                return jax.lax.while_loop(cond, body, values)
            else:
                cond = lambda vals: jnp.any(vals[1][0]) & (vals[0] < max_iter)
                body = lambda vals: (vals[0] + 1, loop(*vals[1]))
                return jax.lax.while_loop(cond, body, (self.as_tensor(0), values))[1]

    def max(self, x, axis=None, keepdims=False):
        return jnp.max(x, axis, keepdims=keepdims)

    def min(self, x, axis=None, keepdims=False):
        return jnp.min(x, axis, keepdims=keepdims)

    def conv(self, value, kernel, zero_padding=True):
        assert kernel.shape[0] in (1, value.shape[0])
        assert value.shape[1] == kernel.shape[2], f"value has {value.shape[1]} channels but kernel has {kernel.shape[2]}"
        assert value.ndim + 1 == kernel.ndim
        # AutoDiff may require jax.lax.conv_general_dilated
        result = []
        for b in range(value.shape[0]):
            b_kernel = kernel[min(b, kernel.shape[0] - 1)]
            result_b = []
            for o in range(kernel.shape[1]):
                result_b.append(0)
                for i in range(value.shape[1]):
                    # result.at[b, o, ...].set(scipy.signal.correlate(value[b, i, ...], b_kernel[o, i, ...], mode='same' if zero_padding else 'valid'))
                    result_b[-1] += scipy.signal.correlate(value[b, i, ...], b_kernel[o, i, ...], mode='same' if zero_padding else 'valid')
            result.append(jnp.stack(result_b, 0))
        return jnp.stack(result, 0)

    def expand_dims(self, a, axis=0, number=1):
        for _i in range(number):
            a = jnp.expand_dims(a, axis)
        return a

    def cast(self, x, dtype: DType):
        if self.is_tensor(x, only_native=True) and from_numpy_dtype(x.dtype) == dtype:
            return x
        else:
            return jnp.array(x, to_numpy_dtype(dtype))

    def unravel_index(self, flat_index, shape):
        return jnp.stack(jnp.unravel_index(flat_index, shape), -1)

    def ravel_multi_index(self, multi_index, shape, mode: Union[str, int] = 'undefined'):
        if not self.is_available(shape):
            return Backend.ravel_multi_index(self, multi_index, shape, mode)
        mode = mode if isinstance(mode, int) else {'undefined': 'clip', 'periodic': 'wrap', 'clamp': 'clip'}[mode]
        idx_first = jnp.transpose(multi_index, (self.ndims(multi_index)-1,) + tuple(range(self.ndims(multi_index)-1)))
        result = jnp.ravel_multi_index(idx_first, shape, mode='wrap' if isinstance(mode, int) else mode)
        if isinstance(mode, int):
            outside = self.any((multi_index < 0) | (multi_index >= jnp.asarray(shape, dtype=multi_index.dtype)), -1)
            result = self.where(outside, mode, result)
        return result

    def gather(self, values, indices, axis: int):
        slices = [indices if i == axis else slice(None) for i in range(self.ndims(values))]
        return values[tuple(slices)]

    def batched_gather_nd(self, values, indices):
        values = self.as_tensor(values)
        indices = self.as_tensor(indices)
        # batch_size = combined_dim(values.shape[0], indices.shape[0])
        assert indices.shape[-1] == self.ndims(values) - 2
        def unbatched_gather_nd(b_values, b_indices):
            b_indices = self.unstack(b_indices, -1)
            return b_values[b_indices]
        return self.vectorized_call(unbatched_gather_nd, values, indices)

    def repeat(self, x, repeats, axis: int, new_length=None):
        return jnp.repeat(x, self.as_tensor(repeats), axis, total_repeat_length=new_length)

    def std(self, x, axis=None, keepdims=False):
        return jnp.std(x, axis, keepdims=keepdims)

    def boolean_mask(self, x, mask, axis=0, new_length=None, fill_value=0):
        if new_length is None:
            slices = [mask if i == axis else slice(None) for i in range(len(x.shape))]
            return x[tuple(slices)]
        else:
            indices = jnp.argwhere(mask, size=new_length, fill_value=-1)[..., 0]
            valid = indices >= 0
            valid = valid[tuple([slice(None) if i == axis else None for i in range(len(x.shape))])]
            result = self.gather(x, jnp.maximum(0, indices), axis)
            return jnp.where(valid, result, fill_value)

    def any(self, boolean_tensor, axis=None, keepdims=False):
        if isinstance(boolean_tensor, (tuple, list)):
            boolean_tensor = jnp.stack(boolean_tensor)
        return jnp.any(boolean_tensor, axis=axis, keepdims=keepdims)

    def all(self, boolean_tensor, axis=None, keepdims=False):
        if isinstance(boolean_tensor, (tuple, list)):
            boolean_tensor = jnp.stack(boolean_tensor)
        return jnp.all(boolean_tensor, axis=axis, keepdims=keepdims)

    def scatter(self, base_grid, indices, values, mode: str):
        assert mode in ('add', 'update', 'max', 'min')
        base_grid, values = self.auto_cast(base_grid, values)
        batch_size = combined_dim(combined_dim(indices.shape[0], values.shape[0]), base_grid.shape[0])
        spatial_dims = tuple(range(base_grid.ndim - 2))
        dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(1,),  # channel dim of updates (batch dim removed)
                                                inserted_window_dims=spatial_dims,  # no idea what this does but spatial_dims seems to work
                                                scatter_dims_to_operand_dims=spatial_dims)  # spatial dims of base_grid (batch dim removed)
        scatter = {'add': jax.lax.scatter_add, 'update': jax.lax.scatter, 'max': jax.lax.scatter_max, 'min': jax.lax.scatter_min}[mode]
        def scatter_single(base_grid, indices, values):
            return scatter(base_grid, indices, values, dnums)
        if self.staticshape(indices)[0] == 1:
            indices = self.tile(indices, [batch_size, 1, 1])
        if self.staticshape(values)[0] == 1:
            values = self.tile(values, [batch_size, 1, 1])
        return self.vectorized_call(scatter_single, base_grid, indices, values)

    def histogram1d(self, values, weights, bin_edges):
        def unbatched_hist(values, weights, bin_edges):
            hist, _ = jnp.histogram(values, bin_edges, weights=weights)
            return hist
        return jax.vmap(unbatched_hist)(values, weights, bin_edges)

    def bincount(self, x, weights: Optional[TensorType], bins: int, x_sorted=False):
        if x_sorted:
            return jax.ops.segment_sum(weights or 1, x, bins, indices_are_sorted=True)
        else:
            return jnp.bincount(x, weights=weights, minlength=bins, length=bins)

    def unique(self, x: TensorType, return_inverse: bool, return_counts: bool, axis: int) -> Tuple[TensorType, ...]:
        return jnp.unique(x, return_inverse=return_inverse, return_counts=return_counts, axis=axis)

    def quantile(self, x, quantiles):
        return jnp.quantile(x, quantiles, axis=-1)

    def argsort(self, x, axis=-1):
        return jnp.argsort(x, axis)

    def sort(self, x, axis=-1):
        return jnp.sort(x, axis)

    def searchsorted(self, sorted_sequence, search_values, side: str, dtype=DType(int, 32)):
        if self.ndims(sorted_sequence) == 1:
            return jnp.searchsorted(sorted_sequence, search_values, side=side).astype(to_numpy_dtype(dtype))
        else:
            return jax.vmap(partial(self.searchsorted, side=side, dtype=dtype))(sorted_sequence, search_values)

    def fft(self, x, axes: Union[tuple, list]):
        x = self.to_complex(x)
        if not axes:
            return x
        if len(axes) == 1:
            return jnp.fft.fft(x, axis=axes[0]).astype(x.dtype)
        elif len(axes) == 2:
            return jnp.fft.fft2(x, axes=axes).astype(x.dtype)
        else:
            return jnp.fft.fftn(x, axes=axes).astype(x.dtype)

    def ifft(self, k, axes: Union[tuple, list]):
        if not axes:
            return k
        if len(axes) == 1:
            return jnp.fft.ifft(k, axis=axes[0]).astype(k.dtype)
        elif len(axes) == 2:
            return jnp.fft.ifft2(k, axes=axes).astype(k.dtype)
        else:
            return jnp.fft.ifftn(k, axes=axes).astype(k.dtype)

    def dtype(self, array) -> DType:
        if isinstance(array, bool):
            return DType(bool)
        if isinstance(array, int):
            return DType(int, 32)
        if isinstance(array, float):
            return DType(float, 64)
        if isinstance(array, complex):
            return DType(complex, 128)
        if not isinstance(array, jnp.ndarray):
            array = jnp.array(array)
        return from_numpy_dtype(array.dtype)

    def matrix_solve_least_squares(self, matrix: TensorType, rhs: TensorType) -> Tuple[TensorType, TensorType, TensorType, TensorType]:
        solution, residuals, rank, singular_values = lstsq_batched(matrix, rhs)
        return solution, residuals, rank, singular_values

    def solve_triangular_dense(self, matrix, rhs, lower: bool, unit_diagonal: bool):
        matrix, rhs = self.auto_cast(matrix, rhs, int_to_float=True, bool_to_int=True)
        x = jax.lax.linalg.triangular_solve(matrix, rhs, lower=lower, unit_diagonal=unit_diagonal, left_side=True)
        return x

    def sparse_coo_tensor(self, indices: Union[tuple, list], values, shape: tuple):
        return BCOO((values, indices), shape=shape)

Ancestors

  • phiml.backend._backend.Backend

Class variables

var arccosh
var arcsinh
var arctan
var arctan2
var arctanh
var cosh
var sinh
var tanh

Static methods

def abs(x: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], /) ‑> jax.Array
Expand source code
@_wraps(np.absolute, module='numpy')
@partial(jit, inline=True)
def absolute(x: ArrayLike, /) -> Array:
  check_arraylike('absolute', x)
  dt = dtypes.dtype(x)
  return lax.asarray(x) if dt == np.bool_ or dtypes.issubdtype(dt, np.unsignedinteger) else lax.abs(x)
def arccos(x, /)
Expand source code
fn = lambda x, /: lax_fn(*promote_args_inexact(numpy_fn.__name__, x))
def arcsin(x, /)
Expand source code
fn = lambda x, /: lax_fn(*promote_args_inexact(numpy_fn.__name__, x))
def argmax(a, axis=None, out=None, *, keepdims=<no value>)

Returns the indices of the maximum values along an axis.

Parameters

a : array_like
Input array.
axis : int, optional
By default, the index is into the flattened array, otherwise along the specified axis.
out : array, optional
If provided, the result will be inserted into this array. It should be of the appropriate shape and dtype.
keepdims : bool, optional

If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the array.

Added in version: 1.22.0

Returns

index_array : ndarray of ints
Array of indices into the array. It has the same shape as a.shape with the dimension along axis removed. If keepdims is set to True, then the size of axis will be 1 with the resulting array having same shape as a.shape.

See Also

ndarray.argmax, argmin
amax : The maximum value along a given axis.
unravel_index : Convert a flat index into an index tuple.
take_along_axis : Apply ``np.expand_dims(index_array, axis)`` from argmax to an array as if by calling max.

Notes

In case of multiple occurrences of the maximum values, the indices corresponding to the first occurrence are returned.

Examples

>>> a = np.arange(6).reshape(2,3) + 10
>>> a
array([[10, 11, 12],
       [13, 14, 15]])
>>> np.argmax(a)
5
>>> np.argmax(a, axis=0)
array([1, 1, 1])
>>> np.argmax(a, axis=1)
array([2, 2])

Indexes of the maximal elements of a N-dimensional array:

>>> ind = np.unravel_index(np.argmax(a, axis=None), a.shape)
>>> ind
(1, 2)
>>> a[ind]
15
>>> b = np.arange(6)
>>> b[1] = 5
>>> b
array([0, 5, 2, 3, 4, 5])
>>> np.argmax(b)  # Only the first occurrence is returned.
1
>>> x = np.array([[4,2,3], [1,0,3]])
>>> index_array = np.argmax(x, axis=-1)
>>> # Same as np.amax(x, axis=-1, keepdims=True)
>>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1)
array([[4],
       [3]])
>>> # Same as np.amax(x, axis=-1)
>>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1).squeeze(axis=-1)
array([4, 3])

Setting keepdims to True,

>>> x = np.arange(24).reshape((2, 3, 4))
>>> res = np.argmax(x, axis=1, keepdims=True)
>>> res.shape
(2, 1, 4)
Expand source code
@array_function_dispatch(_argmax_dispatcher)
def argmax(a, axis=None, out=None, *, keepdims=np._NoValue):
    """
    Returns the indices of the maximum values along an axis.

    Parameters
    ----------
    a : array_like
        Input array.
    axis : int, optional
        By default, the index is into the flattened array, otherwise
        along the specified axis.
    out : array, optional
        If provided, the result will be inserted into this array. It should
        be of the appropriate shape and dtype.
    keepdims : bool, optional
        If this is set to True, the axes which are reduced are left
        in the result as dimensions with size one. With this option,
        the result will broadcast correctly against the array.

        .. versionadded:: 1.22.0

    Returns
    -------
    index_array : ndarray of ints
        Array of indices into the array. It has the same shape as `a.shape`
        with the dimension along `axis` removed. If `keepdims` is set to True,
        then the size of `axis` will be 1 with the resulting array having same
        shape as `a.shape`.

    See Also
    --------
    ndarray.argmax, argmin
    amax : The maximum value along a given axis.
    unravel_index : Convert a flat index into an index tuple.
    take_along_axis : Apply ``np.expand_dims(index_array, axis)``
                      from argmax to an array as if by calling max.

    Notes
    -----
    In case of multiple occurrences of the maximum values, the indices
    corresponding to the first occurrence are returned.

    Examples
    --------
    >>> a = np.arange(6).reshape(2,3) + 10
    >>> a
    array([[10, 11, 12],
           [13, 14, 15]])
    >>> np.argmax(a)
    5
    >>> np.argmax(a, axis=0)
    array([1, 1, 1])
    >>> np.argmax(a, axis=1)
    array([2, 2])

    Indexes of the maximal elements of a N-dimensional array:

    >>> ind = np.unravel_index(np.argmax(a, axis=None), a.shape)
    >>> ind
    (1, 2)
    >>> a[ind]
    15

    >>> b = np.arange(6)
    >>> b[1] = 5
    >>> b
    array([0, 5, 2, 3, 4, 5])
    >>> np.argmax(b)  # Only the first occurrence is returned.
    1

    >>> x = np.array([[4,2,3], [1,0,3]])
    >>> index_array = np.argmax(x, axis=-1)
    >>> # Same as np.amax(x, axis=-1, keepdims=True)
    >>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1)
    array([[4],
           [3]])
    >>> # Same as np.amax(x, axis=-1)
    >>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1).squeeze(axis=-1)
    array([4, 3])

    Setting `keepdims` to `True`,

    >>> x = np.arange(24).reshape((2, 3, 4))
    >>> res = np.argmax(x, axis=1, keepdims=True)
    >>> res.shape
    (2, 1, 4)
    """
    kwds = {'keepdims': keepdims} if keepdims is not np._NoValue else {}
    return _wrapfunc(a, 'argmax', axis=axis, out=out, **kwds)
def argmin(a, axis=None, out=None, *, keepdims=<no value>)

Returns the indices of the minimum values along an axis.

Parameters

a : array_like
Input array.
axis : int, optional
By default, the index is into the flattened array, otherwise along the specified axis.
out : array, optional
If provided, the result will be inserted into this array. It should be of the appropriate shape and dtype.
keepdims : bool, optional

If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the array.

Added in version: 1.22.0

Returns

index_array : ndarray of ints
Array of indices into the array. It has the same shape as a.shape with the dimension along axis removed. If keepdims is set to True, then the size of axis will be 1 with the resulting array having same shape as a.shape.

See Also

ndarray.argmin, argmax
amin : The minimum value along a given axis.
unravel_index : Convert a flat index into an index tuple.
take_along_axis : Apply ``np.expand_dims(index_array, axis)`` from argmin to an array as if by calling min.

Notes

In case of multiple occurrences of the minimum values, the indices corresponding to the first occurrence are returned.

Examples

>>> a = np.arange(6).reshape(2,3) + 10
>>> a
array([[10, 11, 12],
       [13, 14, 15]])
>>> np.argmin(a)
0
>>> np.argmin(a, axis=0)
array([0, 0, 0])
>>> np.argmin(a, axis=1)
array([0, 0])

Indices of the minimum elements of a N-dimensional array:

>>> ind = np.unravel_index(np.argmin(a, axis=None), a.shape)
>>> ind
(0, 0)
>>> a[ind]
10
>>> b = np.arange(6) + 10
>>> b[4] = 10
>>> b
array([10, 11, 12, 13, 10, 15])
>>> np.argmin(b)  # Only the first occurrence is returned.
0
>>> x = np.array([[4,2,3], [1,0,3]])
>>> index_array = np.argmin(x, axis=-1)
>>> # Same as np.amin(x, axis=-1, keepdims=True)
>>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1)
array([[2],
       [0]])
>>> # Same as np.amax(x, axis=-1)
>>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1).squeeze(axis=-1)
array([2, 0])

Setting keepdims to True,

>>> x = np.arange(24).reshape((2, 3, 4))
>>> res = np.argmin(x, axis=1, keepdims=True)
>>> res.shape
(2, 1, 4)
Expand source code
@array_function_dispatch(_argmin_dispatcher)
def argmin(a, axis=None, out=None, *, keepdims=np._NoValue):
    """
    Returns the indices of the minimum values along an axis.

    Parameters
    ----------
    a : array_like
        Input array.
    axis : int, optional
        By default, the index is into the flattened array, otherwise
        along the specified axis.
    out : array, optional
        If provided, the result will be inserted into this array. It should
        be of the appropriate shape and dtype.
    keepdims : bool, optional
        If this is set to True, the axes which are reduced are left
        in the result as dimensions with size one. With this option,
        the result will broadcast correctly against the array.

        .. versionadded:: 1.22.0

    Returns
    -------
    index_array : ndarray of ints
        Array of indices into the array. It has the same shape as `a.shape`
        with the dimension along `axis` removed. If `keepdims` is set to True,
        then the size of `axis` will be 1 with the resulting array having same
        shape as `a.shape`.

    See Also
    --------
    ndarray.argmin, argmax
    amin : The minimum value along a given axis.
    unravel_index : Convert a flat index into an index tuple.
    take_along_axis : Apply ``np.expand_dims(index_array, axis)``
                      from argmin to an array as if by calling min.

    Notes
    -----
    In case of multiple occurrences of the minimum values, the indices
    corresponding to the first occurrence are returned.

    Examples
    --------
    >>> a = np.arange(6).reshape(2,3) + 10
    >>> a
    array([[10, 11, 12],
           [13, 14, 15]])
    >>> np.argmin(a)
    0
    >>> np.argmin(a, axis=0)
    array([0, 0, 0])
    >>> np.argmin(a, axis=1)
    array([0, 0])

    Indices of the minimum elements of a N-dimensional array:

    >>> ind = np.unravel_index(np.argmin(a, axis=None), a.shape)
    >>> ind
    (0, 0)
    >>> a[ind]
    10

    >>> b = np.arange(6) + 10
    >>> b[4] = 10
    >>> b
    array([10, 11, 12, 13, 10, 15])
    >>> np.argmin(b)  # Only the first occurrence is returned.
    0

    >>> x = np.array([[4,2,3], [1,0,3]])
    >>> index_array = np.argmin(x, axis=-1)
    >>> # Same as np.amin(x, axis=-1, keepdims=True)
    >>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1)
    array([[2],
           [0]])
    >>> # Same as np.amax(x, axis=-1)
    >>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1).squeeze(axis=-1)
    array([2, 0])

    Setting `keepdims` to `True`,

    >>> x = np.arange(24).reshape((2, 3, 4))
    >>> res = np.argmin(x, axis=1, keepdims=True)
    >>> res.shape
    (2, 1, 4)
    """
    kwds = {'keepdims': keepdims} if keepdims is not np._NoValue else {}
    return _wrapfunc(a, 'argmin', axis=axis, out=out, **kwds)
def ceil(x, /)
Expand source code
fn = lambda x, /: lax_fn(*promote_args_inexact(numpy_fn.__name__, x))
def clip(a: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], a_min: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex, None] = None, a_max: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex, None] = None, out: None = None) ‑> jax.Array
Expand source code
@util._wraps(np.clip, skip_params=['out'])
@jit
def clip(a: ArrayLike, a_min: Optional[ArrayLike] = None,
         a_max: Optional[ArrayLike] = None, out: None = None) -> Array:
  util.check_arraylike("clip", a)
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.clip is not supported.")
  if a_min is None and a_max is None:
    raise ValueError("At most one of a_min and a_max may be None")
  if a_min is not None:
    a = ufuncs.maximum(a_min, a)
  if a_max is not None:
    a = ufuncs.minimum(a_max, a)
  return asarray(a)
def concat(arrays: Union[numpy.ndarray, jax.Array, Sequence[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex]]], axis: Optional[int] = 0, dtype: Union[Any, str, numpy.dtype, jax._src.SupportsDType, None] = None) ‑> jax.Array

Join a sequence of arrays along an existing axis.

LAX-backend implementation of numpy.concatenate.

Original docstring below.

Parameters

axis : int, optional
The axis along which the arrays will be joined. If axis is None, arrays are flattened before use. Default is 0.
dtype : str or dtype
If provided, the destination array will have this dtype. Cannot be provided together with out.

Returns

res : ndarray
The concatenated array.
Expand source code
@util._wraps(np.concatenate)
def concatenate(arrays: Union[np.ndarray, Array, Sequence[ArrayLike]],
                axis: Optional[int] = 0, dtype: Optional[DTypeLike] = None) -> Array:
  if isinstance(arrays, (np.ndarray, Array)):
    return _concatenate_array(arrays, axis, dtype=dtype)
  util.check_arraylike("concatenate", *arrays)
  if not len(arrays):
    raise ValueError("Need at least one array to concatenate.")
  if ndim(arrays[0]) == 0:
    raise ValueError("Zero-dimensional arrays cannot be concatenated.")
  if axis is None:
    return concatenate([ravel(a) for a in arrays], axis=0, dtype=dtype)
  axis = _canonicalize_axis(axis, ndim(arrays[0]))
  if dtype is None:
    arrays_out = util.promote_dtypes(*arrays)
  else:
    arrays_out = [asarray(arr, dtype=dtype) for arr in arrays]
  # lax.concatenate can be slow to compile for wide concatenations, so form a
  # tree of concatenations as a workaround especially for op-by-op mode.
  # (https://github.com/google/jax/issues/653).
  k = 16
  while len(arrays_out) > 1:
    arrays_out = [lax.concatenate(arrays_out[i:i+k], axis)
                  for i in range(0, len(arrays_out), k)]
  return arrays_out[0]
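
A short example of the concatenation semantics described above (outputs assume JAX's default 32-bit integer dtype):

import jax.numpy as jnp

a = jnp.array([[1, 2]])
b = jnp.array([[3, 4]])
jnp.concatenate([a, b], axis=0)     # [[1, 2], [3, 4]]
jnp.concatenate([a, b], axis=None)  # flattened before joining: [1, 2, 3, 4]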
def conj(x: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], /) ‑> jax.Array
Expand source code
@_wraps(np.conjugate, module='numpy')
@partial(jit, inline=True)
def conjugate(x: ArrayLike, /) -> Array:
  check_arraylike("conjugate", x)
  return lax.conj(x) if np.iscomplexobj(x) else lax.asarray(x)
def cos(x, /)
Expand source code
fn = lambda x, /: lax_fn(*promote_args_inexact(numpy_fn.__name__, x))
def cumsum(a: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], axis: Union[None, int, Sequence[int]] = None, dtype: Union[Any, str, numpy.dtype, jax._src.SupportsDType] = None, out: None = None) ‑> jax.Array

Return the cumulative sum of the elements along a given axis.

LAX-backend implementation of numpy.cumsum.

Original docstring below.

Parameters

a : array_like
Input array.
axis : int, optional
Axis along which the cumulative sum is computed. The default (None) is to compute the cumsum over the flattened array.
dtype : dtype, optional
Type of the returned array and of the accumulator in which the elements are summed. If dtype is not specified, it defaults to the dtype of a, unless a has an integer dtype with a precision less than that of the default platform integer. In that case, the default platform integer is used.

Returns

cumsum_along_axis : ndarray
A new array holding the result is returned unless out is specified, in which case a reference to out is returned. The result has the same size as a, and the same shape as a if axis is not None or a is a 1-d array.

Expand source code
@_wraps(np_reduction, skip_params=['out'])
def cumulative_reduction(a: ArrayLike, axis: Axis = None,
                         dtype: DTypeLike = None, out: None = None) -> Array:
  return _cumulative_reduction(a, _ensure_optional_axes(axis), dtype, out)
def einsum(subscripts, /, *operands, out: None = None, optimize: str = 'optimal', precision: Union[None, str, jax._src.lax.lax.Precision, Tuple[str, str], Tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]] = None, preferred_element_type: Union[Any, str, numpy.dtype, jax._src.SupportsDType, None] = None) ‑> jax.Array

Evaluates the Einstein summation convention on the operands.

LAX-backend implementation of numpy.einsum.

In addition to the original NumPy arguments listed below, also supports precision for extra control over matrix-multiplication precision on supported devices. precision may be set to None, which means default precision for the backend, a jax.lax.Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST) or a tuple of two jax.lax.Precision enums indicating separate precision for each argument. A tuple precision does not necessarily map to multiple arguments of einsum(); rather, the specified precision is forwarded to each dot_general call used in the implementation.

Original docstring below.

Using the Einstein summation convention, many common multi-dimensional, linear algebraic array operations can be represented in a simple fashion. In implicit mode einsum computes these values.

In explicit mode, einsum provides further flexibility to compute other array operations that might not be considered classical Einstein summation operations, by disabling, or forcing summation over specified subscript labels.

See the notes and examples for clarification.

Parameters

subscripts : str
Specifies the subscripts for summation as comma separated list of subscript labels. An implicit (classical Einstein summation) calculation is performed unless the explicit indicator '->' is included as well as subscript labels of the precise output form.
operands : list of array_like
These are the arrays for the operation.
optimize : {False, True, 'greedy', 'optimal'}, optional
Controls if intermediate optimization should occur. No optimization will occur if False and True will default to the 'greedy' algorithm. Also accepts an explicit contraction list from the np.einsum_path function. See np.einsum_path for more details. Defaults to False.

Returns

output : ndarray
The calculation based on the Einstein summation convention.
Expand source code
@util._wraps(np.einsum, lax_description=_EINSUM_DOC, skip_params=['out'])
def einsum(
    subscripts, /,
    *operands,
    out: None = None,
    optimize: str = "optimal",
    precision: PrecisionLike = None,
    preferred_element_type: Optional[DTypeLike] = None,
    _use_xeinsum: bool = False,
    _dot_general: Callable[..., Array] = lax.dot_general,
) -> Array:
  operands = (subscripts, *operands)
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.einsum is not supported.")

  spec = operands[0] if isinstance(operands[0], str) else None

  if (_use_xeinsum or spec is not None and '{' in spec):
    return jax.named_call(lax.xeinsum, name=spec)(*operands)

  optimize = 'optimal' if optimize is True else optimize
  # using einsum_call=True here is an internal api for opt_einsum

  # Allow handling of shape polymorphism
  non_constant_dim_types = {
      type(d) for op in operands if not isinstance(op, str)
      for d in np.shape(op) if not core.is_constant_dim(d)
  }
  if not non_constant_dim_types:
    contract_path = opt_einsum.contract_path
  else:
    ty = next(iter(non_constant_dim_types))
    contract_path = _poly_einsum_handlers.get(ty, _default_poly_einsum_handler)
  operands, contractions = contract_path(
        *operands, einsum_call=True, use_blas=True, optimize=optimize)

  contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions)

  _einsum_computation = jax.named_call(
      _einsum, name=spec) if spec is not None else _einsum
  return _einsum_computation(operands, contractions, precision,  # type: ignore[operator]
                             preferred_element_type, _dot_general)
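
A small illustration of implicit versus explicit mode as described above; the arrays are arbitrary:

import jax.numpy as jnp

a = jnp.arange(6).reshape(2, 3)
b = jnp.arange(12).reshape(3, 4)
jnp.einsum('ij,jk->ik', a, b)   # explicit mode: matrix product, shape (2, 4)
jnp.einsum('ii', jnp.eye(3))    # implicit mode: trace of the identity, 3.0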
def equal(x1, x2, /)
Expand source code
fn = lambda x1, x2, /: lax_fn(*promote_args(numpy_fn.__name__, x1, x2))
def erf(x: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex]) ‑> jax.Array

Returns the error function of complex argument.

LAX-backend implementation of scipy.special.erf.

Note that the JAX version does not support complex inputs.

Original docstring below.

It is defined as 2/sqrt(pi)*integral(exp(-t**2), t=0..z).

Parameters

x : ndarray
Input array.

Returns

res : scalar or ndarray
The values of the error function at the given points x.

References

[1] https://en.wikipedia.org/wiki/Error_function
[2] Milton Abramowitz and Irene A. Stegun, eds. Handbook of Mathematical Functions with Formulas, Graphs, and Mathematical Tables. New York: Dover, 1972. http://www.math.sfu.ca/~cbm/aands/page_297.htm
[3] Steven G. Johnson, Faddeeva W function implementation. http://ab-initio.mit.edu/Faddeeva

Expand source code
@_wraps(osp_special.erf, module='scipy.special', skip_params=["out"],
        lax_description="Note that the JAX version does not support complex inputs.")
def erf(x: ArrayLike) -> Array:
  x, = promote_args_inexact("erf", x)
  return lax.erf(x)
def exp(x, /)
Expand source code
fn = lambda x, /: lax_fn(*promote_args_inexact(numpy_fn.__name__, x))
def flip(m: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], axis: Union[int, Tuple[int, ...], None] = None) ‑> jax.Array

Reverse the order of elements in an array along the given axis.

LAX-backend implementation of numpy.flip.

The JAX version of this function may in some cases return a copy rather than a view of the input.

Original docstring below.

The shape of the array is preserved, but the elements are reordered.

Added in version: 1.12.0

Parameters

m : array_like
Input array.
axis : None or int or tuple of ints, optional

Axis or axes along which to flip over. The default, axis=None, will flip over all of the axes of the input array. If axis is negative it counts from the last to the first axis.

If axis is a tuple of ints, flipping is performed on all of the axes specified in the tuple.

Changed in version 1.15.0: None and tuples of axes are supported.

Returns

out : array_like
A view of m with the entries of axis reversed. Since a view is returned, this operation is done in constant time.
Expand source code
@util._wraps(np.flip, lax_description=_ARRAY_VIEW_DOC)
def flip(m: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array:
  util.check_arraylike("flip", m)
  return _flip(asarray(m), reductions._ensure_optional_axes(axis))
def floor(x, /)
Expand source code
fn = lambda x, /: lax_fn(*promote_args_inexact(numpy_fn.__name__, x))
def imag(val: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], /) ‑> jax.Array
Expand source code
@_wraps(np.imag)
@partial(jit, inline=True)
def imag(val: ArrayLike, /) -> Array:
  check_arraylike("imag", val)
  return lax.imag(val) if np.iscomplexobj(val) else lax.full_like(val, 0)
def isfinite(x: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], /) ‑> jax.Array
Expand source code
@_wraps(np.isfinite, module='numpy')
@jit
def isfinite(x: ArrayLike, /) -> Array:
  check_arraylike("isfinite", x)
  dtype = dtypes.dtype(x)
  if dtypes.issubdtype(dtype, np.floating):
    return lax.is_finite(x)
  elif dtypes.issubdtype(dtype, np.complexfloating):
    return lax.bitwise_and(lax.is_finite(real(x)), lax.is_finite(imag(x)))
  else:
    return lax.full_like(x, True, dtype=np.bool_)
def isinf(x: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], /) ‑> jax.Array
Expand source code
@_wraps(np.isinf, module='numpy')
@jit
def isinf(x: ArrayLike, /) -> Array:
  check_arraylike("isinf", x)
  dtype = dtypes.dtype(x)
  if dtypes.issubdtype(dtype, np.floating):
    return lax.eq(lax.abs(x), _constant_like(x, np.inf))
  elif dtypes.issubdtype(dtype, np.complexfloating):
    re = lax.real(x)
    im = lax.imag(x)
    return lax.bitwise_or(lax.eq(lax.abs(re), _constant_like(re, np.inf)),
                          lax.eq(lax.abs(im), _constant_like(im, np.inf)))
  else:
    return lax.full_like(x, False, dtype=np.bool_)
def isnan(x: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], /) ‑> jax.Array
Expand source code
@_wraps(np.isnan, module='numpy')
@jit
def isnan(x: ArrayLike, /) -> Array:
  check_arraylike("isnan", x)
  return lax.ne(x, x)
def log(x, /)
Expand source code
fn = lambda x, /: lax_fn(*promote_args_inexact(numpy_fn.__name__, x))
def log10(x: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], /) ‑> jax.Array
Expand source code
@_wraps(np.log10, module='numpy')
@partial(jit, inline=True)
def log10(x: ArrayLike, /) -> Array:
  x, = promote_args_inexact("log10", x)
  return lax.div(lax.log(x), lax.log(_constant_like(x, 10)))
def log2(x: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], /) ‑> jax.Array
Expand source code
@_wraps(np.log2, module='numpy')
@partial(jit, inline=True)
def log2(x: ArrayLike, /) -> Array:
  x, = promote_args_inexact("log2", x)
  return lax.div(lax.log(x), lax.log(_constant_like(x, 2)))
def maximum(x1, x2, /)
Expand source code
fn = lambda x1, x2, /: lax_fn(*promote_args(numpy_fn.__name__, x1, x2))
def minimum(x1, x2, /)
Expand source code
fn = lambda x1, x2, /: lax_fn(*promote_args(numpy_fn.__name__, x1, x2))
def real(val: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], /) ‑> jax.Array
Expand source code
@_wraps(np.real)
@partial(jit, inline=True)
def real(val: ArrayLike, /) -> Array:
  check_arraylike("real", val)
  return lax.real(val) if np.iscomplexobj(val) else lax.asarray(val)
def round(a: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], decimals: int = 0, out: None = None) ‑> jax.Array
Expand source code
@util._wraps(np.around, skip_params=['out'])
@partial(jit, static_argnames=('decimals',))
def round(a: ArrayLike, decimals: int = 0, out: None = None) -> Array:
  util.check_arraylike("round", a)
  decimals = core.concrete_or_error(operator.index, decimals, "'decimals' argument of jnp.round")
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.round is not supported.")
  dtype = _dtype(a)
  if issubdtype(dtype, integer):
    if decimals < 0:
      raise NotImplementedError(
        "integer np.round not implemented for decimals < 0")
    return asarray(a)  # no-op on integer types

  def _round_float(x: ArrayLike) -> Array:
    if decimals == 0:
      return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN)

    # TODO(phawkins): the strategy of rescaling the value isn't necessarily a
    # good one since we may be left with an incorrectly rounded value at the
    # end due to precision problems. As a workaround for float16, convert to
    # float32,
    x = lax.convert_element_type(x, np.float32) if dtype == np.float16 else x
    factor = _lax_const(x, 10 ** decimals)
    out = lax.div(lax.round(lax.mul(x, factor),
                            lax.RoundingMethod.TO_NEAREST_EVEN), factor)
    return lax.convert_element_type(out, dtype) if dtype == np.float16 else out

  if issubdtype(dtype, complexfloating):
    return lax.complex(_round_float(lax.real(a)), _round_float(lax.imag(a)))
  else:
    return _round_float(a)
def shape(a)

Return the shape of an array.

Parameters

a : array_like
Input array.

Returns

shape : tuple of ints
The elements of the shape tuple give the lengths of the corresponding array dimensions.

See Also

len
len(a) is equivalent to np.shape(a)[0] for N-D arrays with N>=1.
ndarray.shape
Equivalent array method.

Examples

>>> np.shape(np.eye(3))
(3, 3)
>>> np.shape([[1, 3]])
(1, 2)
>>> np.shape([0])
(1,)
>>> np.shape(0)
()
>>> a = np.array([(1, 2), (3, 4), (5, 6)],
...              dtype=[('x', 'i4'), ('y', 'i4')])
>>> np.shape(a)
(3,)
>>> a.shape
(3,)
Expand source code
@array_function_dispatch(_shape_dispatcher)
def shape(a):
    """
    Return the shape of an array.

    Parameters
    ----------
    a : array_like
        Input array.

    Returns
    -------
    shape : tuple of ints
        The elements of the shape tuple give the lengths of the
        corresponding array dimensions.

    See Also
    --------
    len : ``len(a)`` is equivalent to ``np.shape(a)[0]`` for N-D arrays with
          ``N>=1``.
    ndarray.shape : Equivalent array method.

    Examples
    --------
    >>> np.shape(np.eye(3))
    (3, 3)
    >>> np.shape([[1, 3]])
    (1, 2)
    >>> np.shape([0])
    (1,)
    >>> np.shape(0)
    ()

    >>> a = np.array([(1, 2), (3, 4), (5, 6)],
    ...              dtype=[('x', 'i4'), ('y', 'i4')])
    >>> np.shape(a)
    (3,)
    >>> a.shape
    (3,)

    """
    try:
        result = a.shape
    except AttributeError:
        result = asarray(a).shape
    return result
def sign(x: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], /) ‑> jax.Array
Expand source code
@_wraps(np.sign, module='numpy')
@jit
def sign(x: ArrayLike, /) -> Array:
  check_arraylike('sign', x)
  dtype = dtypes.dtype(x)
  if dtypes.issubdtype(dtype, np.complexfloating):
    re = lax.real(x)
    return lax.complex(
      lax.sign(_where(re != 0, re, lax.imag(x))), _constant_like(re, 0))
  return lax.sign(x)
def sin(x, /)
Expand source code
fn = lambda x, /: lax_fn(*promote_args_inexact(numpy_fn.__name__, x))
def softplus(x: Any) ‑> Any

Softplus activation function.

Computes the element-wise function

softplus(x) = log(1 + exp(x))

Args

x : input array

Expand source code
@jax.jit
def softplus(x: Array) -> Array:
  r"""Softplus activation function.

  Computes the element-wise function

  .. math::
    \mathrm{softplus}(x) = \log(1 + e^x)

  Args:
    x : input array
  """
  return jnp.logaddexp(x, 0)
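
For reference, a quick numerical check of the formula above (values are approximate, float32):

import jax
import jax.numpy as jnp

jax.nn.softplus(jnp.array([0.0, 100.0]))  # ~[0.6931, 100.0]; log(1 + exp(0)) = log 2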
def sqrt(x, /)
Expand source code
fn = lambda x, /: lax_fn(*promote_args_inexact(numpy_fn.__name__, x))
def stack(arrays: Union[numpy.ndarray, jax.Array, Sequence[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex]]], axis: int = 0, out: None = None, dtype: Union[Any, str, numpy.dtype, jax._src.SupportsDType, None] = None) ‑> jax.Array

Join a sequence of arrays along a new axis.

LAX-backend implementation of numpy.stack.

Original docstring below.

The axis parameter specifies the index of the new axis in the dimensions of the result. For example, if axis=0 it will be the first dimension and if axis=-1 it will be the last dimension.

Added in version: 1.10.0

Parameters

arrays : sequence of array_like
Each array must have the same shape.
axis : int, optional
The axis in the result array along which the input arrays are stacked.
dtype : str or dtype
If provided, the destination array will have this dtype. Cannot be provided together with out.

Returns

stacked : ndarray
The stacked array has one more dimension than the input arrays.
Expand source code
@util._wraps(np.stack, skip_params=['out'])
def stack(arrays: Union[np.ndarray, Array, Sequence[ArrayLike]],
          axis: int = 0, out: None = None, dtype: Optional[DTypeLike] = None) -> Array:
  if not len(arrays):
    raise ValueError("Need at least one array to stack.")
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.stack is not supported.")
  if isinstance(arrays, (np.ndarray, Array)):
    axis = _canonicalize_axis(axis, arrays.ndim)
    return concatenate(expand_dims(arrays, axis + 1), axis=axis, dtype=dtype)
  else:
    util.check_arraylike("stack", *arrays)
    shape0 = shape(arrays[0])
    axis = _canonicalize_axis(axis, len(shape0) + 1)
    new_arrays = []
    for a in arrays:
      if shape(a) != shape0:
        raise ValueError("All input arrays must have the same shape.")
      new_arrays.append(expand_dims(a, axis))
    return concatenate(new_arrays, axis=axis, dtype=dtype)
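A minimal usage sketch (plain jax.numpy, illustrative only); the new axis is inserted at the requested position:

>>> import jax.numpy as jnp
>>> a, b = jnp.zeros((3,)), jnp.ones((3,))
>>> jnp.stack([a, b], axis=0).shape    # new leading axis -> (2, 3)
>>> jnp.stack([a, b], axis=-1).shape   # new trailing axis -> (3, 2)
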
def staticshape(a)

Return the shape of an array.

Parameters

a : array_like
Input array.

Returns

shape : tuple of ints
The elements of the shape tuple give the lengths of the corresponding array dimensions.

See Also

len
len(a) is equivalent to np.shape(a)[0] for N-D arrays with N>=1.
ndarray.shape
Equivalent array method.

Examples

>>> np.shape(np.eye(3))
(3, 3)
>>> np.shape([[1, 3]])
(1, 2)
>>> np.shape([0])
(1,)
>>> np.shape(0)
()
>>> a = np.array([(1, 2), (3, 4), (5, 6)],
...              dtype=[('x', 'i4'), ('y', 'i4')])
>>> np.shape(a)
(3,)
>>> a.shape
(3,)
Expand source code
@array_function_dispatch(_shape_dispatcher)
def shape(a):
    """
    Return the shape of an array.

    Parameters
    ----------
    a : array_like
        Input array.

    Returns
    -------
    shape : tuple of ints
        The elements of the shape tuple give the lengths of the
        corresponding array dimensions.

    See Also
    --------
    len : ``len(a)`` is equivalent to ``np.shape(a)[0]`` for N-D arrays with
          ``N>=1``.
    ndarray.shape : Equivalent array method.

    Examples
    --------
    >>> np.shape(np.eye(3))
    (3, 3)
    >>> np.shape([[1, 3]])
    (1, 2)
    >>> np.shape([0])
    (1,)
    >>> np.shape(0)
    ()

    >>> a = np.array([(1, 2), (3, 4), (5, 6)],
    ...              dtype=[('x', 'i4'), ('y', 'i4')])
    >>> np.shape(a)
    (3,)
    >>> a.shape
    (3,)

    """
    try:
        result = a.shape
    except AttributeError:
        result = asarray(a).shape
    return result
def stop_gradient(x: ~T) ‑> ~T

Stops gradient computation.

Operationally stop_gradient is the identity function, that is, it returns argument x unchanged. However, stop_gradient prevents the flow of gradients during forward or reverse-mode automatic differentiation. If there are multiple nested gradient computations, stop_gradient stops gradients for all of them.

For example:

>>> jax.grad(lambda x: x**2)(3.)
Array(6., dtype=float32, weak_type=True)
>>> jax.grad(lambda x: jax.lax.stop_gradient(x)**2)(3.)
Array(0., dtype=float32, weak_type=True)
>>> jax.grad(jax.grad(lambda x: x**2))(3.)
Array(2., dtype=float32, weak_type=True)
>>> jax.grad(jax.grad(lambda x: jax.lax.stop_gradient(x)**2))(3.)
Array(0., dtype=float32, weak_type=True)
Expand source code
def stop_gradient(x: T) -> T:
  """Stops gradient computation.

  Operationally ``stop_gradient`` is the identity function, that is, it returns
  argument `x` unchanged. However, ``stop_gradient`` prevents the flow of
  gradients during forward or reverse-mode automatic differentiation. If there
  are multiple nested gradient computations, ``stop_gradient`` stops gradients
  for all of them.

  For example:

  >>> jax.grad(lambda x: x**2)(3.)
  Array(6., dtype=float32, weak_type=True)
  >>> jax.grad(lambda x: jax.lax.stop_gradient(x)**2)(3.)
  Array(0., dtype=float32, weak_type=True)
  >>> jax.grad(jax.grad(lambda x: x**2))(3.)
  Array(2., dtype=float32, weak_type=True)
  >>> jax.grad(jax.grad(lambda x: jax.lax.stop_gradient(x)**2))(3.)
  Array(0., dtype=float32, weak_type=True)
  """
  def stop(x):
    # only bind primitive on inexact dtypes, to avoid some staging
    if core.has_opaque_dtype(x):
      return x
    elif (dtypes.issubdtype(_dtype(x), np.floating) or
        dtypes.issubdtype(_dtype(x), np.complexfloating)):
      return ad_util.stop_gradient_p.bind(x)
    else:
      return x
  return tree_map(stop, x)
def tan(x, /)
Expand source code
fn = lambda x, /: lax_fn(*promote_args_inexact(numpy_fn.__name__, x))
def tile(A: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], reps: Union[int, Any, Sequence[Union[int, Any]]]) ‑> jax.Array

Construct an array by repeating A the number of times given by reps.

LAX-backend implementation of :func:numpy.tile.

Original docstring below.

If reps has length d, the result will have dimension of max(d, A.ndim).

If A.ndim < d, A is promoted to be d-dimensional by prepending new axes. So a shape (3,) array is promoted to (1, 3) for 2-D replication, or shape (1, 1, 3) for 3-D replication. If this is not the desired behavior, promote A to d-dimensions manually before calling this function.

If A.ndim > d, reps is promoted to A.ndim by pre-pending 1's to it. Thus for an A of shape (2, 3, 4, 5), a reps of (2, 2) is treated as (1, 1, 2, 2).

Note: Although tile may be used for broadcasting, it is strongly recommended to use numpy's broadcasting operations and functions.

Parameters

A : array_like
The input array.
reps : array_like
The number of repetitions of A along each axis.

Returns

c : ndarray
The tiled output array.
Expand source code
@util._wraps(np.tile)
def tile(A: ArrayLike, reps: Union[DimSize, Sequence[DimSize]]) -> Array:
  util.check_arraylike("tile", A)
  try:
    iter(reps)  # type: ignore[arg-type]
  except TypeError:
    reps_tup: Tuple[DimSize, ...] = (reps,)
  else:
    reps_tup = tuple(reps)  # type: ignore[assignment,arg-type]
  reps_tup = tuple(operator.index(rep) if core.is_constant_dim(rep) else rep
                   for rep in reps_tup)
  A_shape = (1,) * (len(reps_tup) - ndim(A)) + shape(A)
  reps_tup = (1,) * (len(A_shape) - len(reps_tup)) + reps_tup
  result = broadcast_to(reshape(A, [j for i in A_shape for j in [1, i]]),
                        [k for pair in zip(reps_tup, A_shape) for k in pair])
  return reshape(result, tuple(np.multiply(A_shape, reps_tup)))
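A short illustrative example (plain jax.numpy):

>>> import jax.numpy as jnp
>>> jnp.tile(jnp.array([1, 2, 3]), 2)                      # -> [1, 2, 3, 1, 2, 3]
>>> jnp.tile(jnp.array([[1, 2], [3, 4]]), (2, 1)).shape    # reps applied per axis -> (4, 2)
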
def transpose(a: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], axes: Optional[Sequence[int]] = None) ‑> jax.Array

Returns an array with axes transposed.

LAX-backend implementation of :func:numpy.transpose.

The JAX version of this function may in some cases return a copy rather than a view of the input.

Original docstring below.

For a 1-D array, this returns an unchanged view of the original array, as a transposed vector is simply the same vector. To convert a 1-D array into a 2-D column vector, an additional dimension must be added, e.g., np.atleast_2d(a).T achieves this, as does a[:, np.newaxis]. For a 2-D array, this is the standard matrix transpose. For an n-D array, if axes are given, their order indicates how the axes are permuted (see Examples). If axes are not provided, then transpose(a).shape == a.shape[::-1].

Parameters

a : array_like
Input array.
axes : tuple or list of ints, optional
If specified, it must be a tuple or list which contains a permutation of [0,1,…,N-1] where N is the number of axes of a. The i'th axis of the returned array will correspond to the axis numbered axes[i] of the input. If not specified, defaults to range(a.ndim)[::-1], which reverses the order of the axes.

Returns

p : ndarray
a with its axes permuted. A view is returned whenever possible.
Expand source code
@util._wraps(np.transpose, lax_description=_ARRAY_VIEW_DOC)
def transpose(a: ArrayLike, axes: Optional[Sequence[int]] = None) -> Array:
  util.check_arraylike("transpose", a)
  axes_ = list(range(ndim(a))[::-1]) if axes is None else axes
  axes_ = [_canonicalize_axis(i, ndim(a)) for i in axes_]
  return lax.transpose(a, axes_)
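A short illustrative example (plain jax.numpy):

>>> import jax.numpy as jnp
>>> x = jnp.ones((2, 3, 4))
>>> jnp.transpose(x).shape             # axes reversed by default -> (4, 3, 2)
>>> jnp.transpose(x, (0, 2, 1)).shape  # explicit permutation -> (2, 4, 3)
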

Methods

def all(self, boolean_tensor, axis=None, keepdims=False)
Expand source code
def all(self, boolean_tensor, axis=None, keepdims=False):
    if isinstance(boolean_tensor, (tuple, list)):
        boolean_tensor = jnp.stack(boolean_tensor)
    return jnp.all(boolean_tensor, axis=axis, keepdims=keepdims)
def allocate_on_device(self, tensor: ~TensorType, device: phiml.backend._backend.ComputeDevice) ‑> ~TensorType

Moves tensor to device. May copy the tensor if it is already on the device.

Args

tensor
Existing tensor native to this backend.
device
Target device, associated with this backend.
Expand source code
def allocate_on_device(self, tensor: TensorType, device: ComputeDevice) -> TensorType:
    return jax.device_put(tensor, device.ref)
def any(self, boolean_tensor, axis=None, keepdims=False)
Expand source code
def any(self, boolean_tensor, axis=None, keepdims=False):
    if isinstance(boolean_tensor, (tuple, list)):
        boolean_tensor = jnp.stack(boolean_tensor)
    return jnp.any(boolean_tensor, axis=axis, keepdims=keepdims)
def argsort(self, x, axis=-1)
Expand source code
def argsort(self, x, axis=-1):
    return jnp.argsort(x, axis)
def as_tensor(self, x, convert_external=True)

Converts a tensor-like object to the native tensor representation of this backend. If x is a native tensor of this backend, it is returned without modification. If x is a Python number (numbers.Number instance), convert_external decides whether to convert it unless the backend cannot handle Python numbers.

Note: There may be objects that are considered tensors by this backend but are not native and thus, will be converted by this method.

Args

x
tensor-like, e.g. list, tuple, Python number, tensor
convert_external
if False and x is a Python number that is understood by this backend, this method returns the number as-is. This can help prevent type clashes like int32 vs int64. (Default value = True)

Returns

tensor representation of x

Expand source code
def as_tensor(self, x, convert_external=True):
    self._check_float64()
    if self.is_tensor(x, only_native=convert_external):
        array = x
    else:
        array = jnp.array(x)
    # --- Enforce Precision ---
    if not isinstance(array, numbers.Number):
        if self.dtype(array).kind == float:
            array = self.to_float(array)
        elif self.dtype(array).kind == complex:
            array = self.to_complex(array)
    return array
def batched_gather_nd(self, values, indices)

Gathers values from the tensor values at locations indices. The first dimension of values and indices is the batch dimension; its size must either match between the two or be 1 for one of them.

Args

values
tensor of shape (batch, spatial…, channel)
indices
int tensor of shape (batch, any…, multi_index) where the size of multi_index is values.rank - 2.

Returns

Gathered values as tensor of shape (batch, any…, channel)

Expand source code
def batched_gather_nd(self, values, indices):
    values = self.as_tensor(values)
    indices = self.as_tensor(indices)
    # batch_size = combined_dim(values.shape[0], indices.shape[0])
    assert indices.shape[-1] == self.ndims(values) - 2
    def unbatched_gather_nd(b_values, b_indices):
        b_indices = self.unstack(b_indices, -1)
        return b_values[b_indices]
    return self.vectorized_call(unbatched_gather_nd, values, indices)
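A minimal usage sketch, assuming the module-level JAX backend instance and eager execution; the index tensor selects spatial positions per batch:

>>> import numpy as np
>>> from phiml.backend.jax import JAX
>>> values = np.arange(12).reshape(1, 4, 3)     # (batch=1, x=4, channel=3)
>>> indices = np.array([[[1], [3]]])            # (batch=1, 2 points, multi_index=1)
>>> JAX.batched_gather_nd(values, indices)      # -> [[[3, 4, 5], [9, 10, 11]]], shape (1, 2, 3)
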
def bincount(self, x, weights: Optional[~TensorType], bins: int, x_sorted=False)

Args

x
Bin indices, 1D int tensor.
weights
Weights corresponding to x, 1D tensor. All weights are 1 if weights=None.
bins
Number of bins.
x_sorted
Whether x is sorted from lowest to highest bin.

Returns

bin_counts

Expand source code
def bincount(self, x, weights: Optional[TensorType], bins: int, x_sorted=False):
    if x_sorted:
        return jax.ops.segment_sum(weights or 1, x, bins, indices_are_sorted=True)
    else:
        return jnp.bincount(x, weights=weights, minlength=bins, length=bins)
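A minimal usage sketch, assuming the module-level JAX backend instance:

>>> import jax.numpy as jnp
>>> from phiml.backend.jax import JAX
>>> JAX.bincount(jnp.array([0, 1, 1, 3]), None, bins=4)    # -> [1, 2, 0, 1]
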
def block_until_ready(self, values)
Expand source code
def block_until_ready(self, values):
    if hasattr(values, 'block_until_ready'):
        values.block_until_ready()
    if isinstance(values, (tuple, list)):
        for v in values:
            self.block_until_ready(v)
def boolean_mask(self, x, mask, axis=0, new_length=None, fill_value=0)

Args

x
tensor with any number of dimensions
mask
1D mask tensor
axis
Axis index >= 0
new_length
Maximum size of the output along axis. This must be set when jit-compiling with Jax.
fill_value
If new_length is larger than the filtered result, the remaining values will be set to fill_value.
Expand source code
def boolean_mask(self, x, mask, axis=0, new_length=None, fill_value=0):
    if new_length is None:
        slices = [mask if i == axis else slice(None) for i in range(len(x.shape))]
        return x[tuple(slices)]
    else:
        indices = jnp.argwhere(mask, size=new_length, fill_value=-1)[..., 0]
        valid = indices >= 0
        valid = valid[tuple([slice(None) if i == axis else None for i in range(len(x.shape))])]
        result = self.gather(x, jnp.maximum(0, indices), axis)
        return jnp.where(valid, result, fill_value)
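A minimal usage sketch, assuming the module-level JAX backend instance; new_length fixes the output size so the call can be jit-compiled:

>>> import jax.numpy as jnp
>>> from phiml.backend.jax import JAX
>>> x = jnp.array([10, 11, 12, 13])
>>> mask = jnp.array([True, False, True, False])
>>> JAX.boolean_mask(x, mask)                                 # -> [10, 12]
>>> JAX.boolean_mask(x, mask, new_length=3, fill_value=-1)    # -> [10, 12, -1]
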
def cast(self, x, dtype: phiml.backend._dtype.DType)
Expand source code
def cast(self, x, dtype: DType):
    if self.is_tensor(x, only_native=True) and from_numpy_dtype(x.dtype) == dtype:
        return x
    else:
        return jnp.array(x, to_numpy_dtype(dtype))
def conv(self, value, kernel, zero_padding=True)

Convolve value with kernel. Depending on the tensor rank, the convolution is either 1D (rank=3), 2D (rank=4) or 3D (rank=5). Higher dimensions may not be supported.

Args

value
tensor of shape (batch_size, in_channel, spatial…)
kernel
tensor of shape (batch_size or 1, out_channel, in_channel, spatial…)
zero_padding
If True, pads the edges of value with zeros so that the result has the same shape as value.

Returns

Convolution result as tensor of shape (batch_size, out_channel, spatial…)

Expand source code
def conv(self, value, kernel, zero_padding=True):
    assert kernel.shape[0] in (1, value.shape[0])
    assert value.shape[1] == kernel.shape[2], f"value has {value.shape[1]} channels but kernel has {kernel.shape[2]}"
    assert value.ndim + 1 == kernel.ndim
    # AutoDiff may require jax.lax.conv_general_dilated
    result = []
    for b in range(value.shape[0]):
        b_kernel = kernel[min(b, kernel.shape[0] - 1)]
        result_b = []
        for o in range(kernel.shape[1]):
            result_b.append(0)
            for i in range(value.shape[1]):
                # result.at[b, o, ...].set(scipy.signal.correlate(value[b, i, ...], b_kernel[o, i, ...], mode='same' if zero_padding else 'valid'))
                result_b[-1] += scipy.signal.correlate(value[b, i, ...], b_kernel[o, i, ...], mode='same' if zero_padding else 'valid')
        result.append(jnp.stack(result_b, 0))
    return jnp.stack(result, 0)
def copy(self, tensor, only_mutable=False)
Expand source code
def copy(self, tensor, only_mutable=False):
    return jnp.array(tensor, copy=True)
def custom_gradient(self, f: Callable, gradient: Callable, get_external_cache: Callable = None, on_call_skipped: Callable = None) ‑> Callable

Creates a function based on f that uses a custom gradient for backprop.

Args

f
Forward function.
gradient
Function for backprop. Will be called as gradient(*d_out) to compute the gradient of f.

Returns

Function with similar signature and return values as f. However, the returned function does not support keyword arguments.

Expand source code
def custom_gradient(self, f: Callable, gradient: Callable, get_external_cache: Callable = None, on_call_skipped: Callable = None) -> Callable:
    jax_fun = jax.custom_vjp(f)  # custom vector-Jacobian product (reverse-mode differentiation)

    def forward(*x):
        y = f(*x)
        return y, (x, y)

    def backward(x_y, dy):
        x, y = x_y
        dx = gradient(x, y, dy)
        return tuple(dx)

    jax_fun.defvjp(forward, backward)
    return jax_fun
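An illustrative sketch following the call convention visible in the source above (gradient(x, y, dy), with x and y passed as tuples); the constant replacement gradient is purely for demonstration:

>>> import jax
>>> from phiml.backend.jax import JAX
>>> f = lambda x: (x * x,)                        # forward pass, returns a tuple
>>> grad_fn = lambda x, y, dy: (3. * dy[0],)      # replaces the true gradient 2*x*dy
>>> f_custom = JAX.custom_gradient(f, grad_fn)
>>> jax.grad(lambda x: f_custom(x)[0])(2.0)       # -> 3.0 instead of 4.0
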
def disassemble(self, x) ‑> Tuple[Callable, Sequence[~TensorType]]

Disassemble a (sparse) tensor into its individual constituents, such as values and indices.

Args

x
Tensor

Returns

assemble
Function assemble(backend, *constituents) that reassembles x from the constituents.
constituents
Tensors contained in x.
Expand source code
def disassemble(self, x) -> Tuple[Callable, Sequence[TensorType]]:
    if self.is_sparse(x):
        if isinstance(x, COO):
            raise NotImplementedError
            # return lambda b, i, v: b.sparse_coo_tensor(i, v, x.shape), (np.stack([x.row, x.col], -1), x.data)
        if isinstance(x, BCOO):
            return lambda b, i, v: b.sparse_coo_tensor(i, v, x.shape), (x.indices, x.data)
        if isinstance(x, CSR):
            raise NotImplementedError
            # return lambda b, v, i, p: b.csr_matrix(i, p, v, x.shape), (x.data, x.indices, x.indptr)
        elif isinstance(x, CSC):
            raise NotImplementedError
            # return lambda b, v, i, p: b.csc_matrix(p, i, v, x.shape), (x.data, x.indices, x.indptr)
        raise NotImplementedError
    else:
        return lambda b, t: t, (x,)
def divide_no_nan(self, x, y)

Computes x/y but returns 0 if y=0.

Expand source code
def divide_no_nan(self, x, y):
    return jnp.where(y == 0, 0, x / y)
    # jnp.nan_to_num(x / y, copy=True, nan=0) covers up NaNs from before
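A short illustrative example, assuming the module-level JAX backend instance:

>>> import jax.numpy as jnp
>>> from phiml.backend.jax import JAX
>>> JAX.divide_no_nan(jnp.array([1., 2., 3.]), jnp.array([2., 0., 1.]))   # -> [0.5, 0., 3.]
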
def dtype(self, array) ‑> phiml.backend._dtype.DType
Expand source code
def dtype(self, array) -> DType:
    if isinstance(array, bool):
        return DType(bool)
    if isinstance(array, int):
        return DType(int, 32)
    if isinstance(array, float):
        return DType(float, 64)
    if isinstance(array, complex):
        return DType(complex, 128)
    if not isinstance(array, jnp.ndarray):
        array = jnp.array(array)
    return from_numpy_dtype(array.dtype)
def expand_dims(self, a, axis=0, number=1)
Expand source code
def expand_dims(self, a, axis=0, number=1):
    for _i in range(number):
        a = jnp.expand_dims(a, axis)
    return a
def fft(self, x, axes: Union[tuple, list])

Computes the n-dimensional FFT along the given axes.

Args

x
tensor of dimension 3 or higher
axes
Along which axes to perform the FFT

Returns

Complex tensor k

Expand source code
def fft(self, x, axes: Union[tuple, list]):
    x = self.to_complex(x)
    if not axes:
        return x
    if len(axes) == 1:
        return jnp.fft.fft(x, axis=axes[0]).astype(x.dtype)
    elif len(axes) == 2:
        return jnp.fft.fft2(x, axes=axes).astype(x.dtype)
    else:
        return jnp.fft.fftn(x, axes=axes).astype(x.dtype)
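A minimal usage sketch, assuming the module-level JAX backend instance and a (batch, x, y, channel) layout:

>>> import jax.numpy as jnp
>>> from phiml.backend.jax import JAX
>>> x = jnp.ones((2, 16, 16, 1))
>>> k = JAX.fft(x, axes=[1, 2])        # complex tensor of the same shape
>>> JAX.ifft(k, axes=[1, 2])           # recovers x (as a complex tensor, up to rounding)
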
def from_dlpack(self, capsule)
Expand source code
def from_dlpack(self, capsule):
    from jax import dlpack
    return dlpack.from_dlpack(capsule)
def gamma_inc_l(self, a, x)

Regularized lower incomplete gamma function.

Expand source code
def gamma_inc_l(self, a, x):
    return scipy.special.gammainc(a, x)
def gamma_inc_u(self, a, x)

Regularized upper incomplete gamma function.

Expand source code
def gamma_inc_u(self, a, x):
    return scipy.special.gammaincc(a, x)
def gather(self, values, indices, axis: int)

Gathers values from the tensor values at locations indices.

Args

values
tensor
indices
1D tensor
axis
Axis along which to gather slices

Returns

tensor, with size along axis being the length of indices

Expand source code
def gather(self, values, indices, axis: int):
    slices = [indices if i == axis else slice(None) for i in range(self.ndims(values))]
    return values[tuple(slices)]
def get_device(self, tensor: ~TensorType) ‑> phiml.backend._backend.ComputeDevice

Returns the device tensor is located on.

Expand source code
def get_device(self, tensor: TensorType) -> ComputeDevice:
    return self.get_device_by_ref(tensor.device())
def get_diagonal(self, matrices, offset=0)

Args

matrices
(batch, rows, cols, channels)
offset
0=diagonal, positive=above diagonal, negative=below diagonal

Returns

diagonal
(batch, max(rows,cols), channels)
Expand source code
def get_diagonal(self, matrices, offset=0):
    result = jnp.diagonal(matrices, offset=offset, axis1=1, axis2=2)
    return jnp.transpose(result, [0, 2, 1])
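A short illustrative example, assuming the module-level JAX backend instance:

>>> import jax.numpy as jnp
>>> from phiml.backend.jax import JAX
>>> m = jnp.arange(9).reshape(1, 3, 3, 1)       # (batch, rows, cols, channels)
>>> JAX.get_diagonal(m)[..., 0]                 # main diagonal -> [[0, 4, 8]]
>>> JAX.get_diagonal(m, offset=1)[..., 0]       # above the diagonal -> [[1, 5]]
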
def get_sparse_format(self, x) ‑> str

Returns lower-case format string, such as 'coo', 'csr', 'csc'

Expand source code
def get_sparse_format(self, x) -> str:
    format_names = {
        COO: 'coo',
        BCOO: 'coo',
        CSR: 'csr',
        CSC: 'csc',
    }
    return format_names.get(type(x), 'dense')
def histogram1d(self, values, weights, bin_edges)

Args

values
(batch, values)
bin_edges
(batch, edges)
weights
(batch, values)

Returns

(batch, edges) with dtype matching weights

Expand source code
def histogram1d(self, values, weights, bin_edges):
    def unbatched_hist(values, weights, bin_edges):
        hist, _ = jnp.histogram(values, bin_edges, weights=weights)
        return hist
    return jax.vmap(unbatched_hist)(values, weights, bin_edges)
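A minimal usage sketch, assuming the module-level JAX backend instance:

>>> import jax.numpy as jnp
>>> from phiml.backend.jax import JAX
>>> values = jnp.array([[0.1, 0.4, 0.8, 0.9]])   # (batch, values)
>>> weights = jnp.ones((1, 4))
>>> edges = jnp.array([[0., 0.5, 1.0]])          # (batch, edges)
>>> JAX.histogram1d(values, weights, edges)      # -> [[2., 2.]]
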
def ifft(self, k, axes: Union[tuple, list])

Computes the n-dimensional inverse FFT along the given axes.

Args

k
tensor of dimension 3 or higher
axes
Along which axes to perform the inverse FFT

Returns

Complex tensor x

Expand source code
def ifft(self, k, axes: Union[tuple, list]):
    if not axes:
        return k
    if len(axes) == 1:
        return jnp.fft.ifft(k, axis=axes[0]).astype(k.dtype)
    elif len(axes) == 2:
        return jnp.fft.ifft2(k, axes=axes).astype(k.dtype)
    else:
        return jnp.fft.ifftn(k, axes=axes).astype(k.dtype)
def is_available(self, tensor)

Tests if the value of the tensor is known and can be read at this point. If true, numpy(tensor) must return a valid NumPy representation of the value.

Tensors are typically available when the backend operates in eager mode.

Args

tensor
backend-compatible tensor

Returns

bool

Expand source code
def is_available(self, tensor):
    return not isinstance(tensor, Tracer)
def is_module(self, obj)

Tests if obj is of a type that is specific to this backend, e.g. a neural network. If True, this backend will be chosen for operations involving obj.

See Also: Backend.is_tensor().

Args

obj
Object to test.
Expand source code
def is_module(self, obj):
    return False
def is_sparse(self, x) ‑> bool

Args

x
Tensor native to this Backend.
Expand source code
def is_sparse(self, x) -> bool:
    return isinstance(x, (COO, BCOO, CSR, CSC))
def is_tensor(self, x, only_native=False)

An object is considered a native tensor by a backend if no internal conversion is required by backend methods. An object is considered a tensor (native or otherwise) by a backend if it is not a struct (e.g. tuple, list) and all methods of the backend accept it as a tensor argument.

If True, this backend will be chosen for operations involving x.

See Also: Backend.is_module().

Args

x
object to check
only_native
If True, only accepts true native tensor representations, not Python numbers or others that are also supported as tensors (Default value = False)

Returns

bool
whether x is considered a tensor by this backend
Expand source code
def is_tensor(self, x, only_native=False):
    if isinstance(x, jnp.ndarray) and not isinstance(x, np.ndarray):  # NumPy arrays inherit from Jax arrays
        return True
    if isinstance(x, jnp.bool_) and not isinstance(x, np.bool_):
        return True
    if self.is_sparse(x):
        return True
    # --- Above considered native ---
    if only_native:
        return False
    # --- Non-native types ---
    if isinstance(x, np.ndarray):
        return True
    if isinstance(x, np.bool_):
        return True
    if isinstance(x, (numbers.Number, bool)):
        return True
    if isinstance(x, (tuple, list)):
        return all([self.is_tensor(item, False) for item in x])
    return False
def jacobian(self, f, wrt: Union[tuple, list], get_output: bool, is_f_scalar: bool)

Args

f
Function to differentiate. Returns a tuple containing (reduced_loss, output)
wrt
Argument indices for which to compute the gradient.
get_output
Whether the derivative function should return the output of f in addition to the gradient.
is_f_scalar
Whether f is guaranteed to return a scalar output.

Returns

A function g with the same arguments as f. If get_output=True, g returns a tuple containing the outputs of f followed by the gradients. The gradients retain the dimensions of reduced_loss in order as outer (first) dimensions.

Expand source code
def jacobian(self, f, wrt: Union[tuple, list], get_output: bool, is_f_scalar: bool):
    if get_output:
        jax_grad_f = jax.value_and_grad(f, argnums=wrt, has_aux=True)
        @wraps(f)
        def unwrap_outputs(*args):
            args = [self.to_float(arg) if self.dtype(arg).kind in (bool, int) and i in wrt else arg for i, arg in enumerate(args)]
            (_, output_tuple), grads = jax_grad_f(*args)
            return (*output_tuple, *[jnp.conjugate(g) for g in grads])
        return unwrap_outputs
    else:
        @wraps(f)
        def nonaux_f(*args):
            loss, output = f(*args)
            return loss
        jax_grad = jax.grad(nonaux_f, argnums=wrt, has_aux=False)
        @wraps(f)
        def call_jax_grad(*args):
            args = [self.to_float(arg) if self.dtype(arg).kind in (bool, int) and i in wrt else arg for i, arg in enumerate(args)]
            grads = jax_grad(*args)
            return tuple([jnp.conjugate(g) for g in grads])
        return call_jax_grad
def jit_compile(self, f: Callable) ‑> Callable
Expand source code
def jit_compile(self, f: Callable) -> Callable:
    def run_jit_f(*args):
        # print(jax.make_jaxpr(f)(*args))
        ML_LOGGER.debug(f"JaxBackend: running jit-compiled '{f.__name__}' with shapes {[self.shape(arg) for arg in args]} and dtypes {[self.dtype(arg) for arg in args]}")
        return self.as_registered.call(jit_f, *args, name=f"run jit-compiled '{f.__name__}'")

    run_jit_f.__name__ = f"Jax-Jit({f.__name__})"
    jit_f = jax.jit(f, device=self._default_device.ref)
    return run_jit_f
def linspace(self, start, stop, number)
Expand source code
def linspace(self, start, stop, number):
    self._check_float64()
    return jax.device_put(jnp.linspace(start, stop, number, dtype=to_numpy_dtype(self.float_type)), self._default_device.ref)
def linspace_without_last(self, start, stop, number)
Expand source code
def linspace_without_last(self, start, stop, number):
    self._check_float64()
    return jax.device_put(jnp.linspace(start, stop, number, endpoint=False, dtype=to_numpy_dtype(self.float_type)), self._default_device.ref)
def log_gamma(self, x)
Expand source code
def log_gamma(self, x):
    return jax.lax.lgamma(self.to_float(x))
def matrix_solve_least_squares(self, matrix: ~TensorType, rhs: ~TensorType) ‑> Tuple[~TensorType, ~TensorType, ~TensorType, ~TensorType]

Args

matrix
Shape (batch, vec, constraints)
rhs
Shape (batch, vec, batch_per_matrix)

Returns

solution
Solution vector of Shape (batch, constraints, batch_per_matrix)
residuals
Optional, can be None
rank
Optional, can be None
singular_values
Optional, can be None
Expand source code
def matrix_solve_least_squares(self, matrix: TensorType, rhs: TensorType) -> Tuple[TensorType, TensorType, TensorType, TensorType]:
    solution, residuals, rank, singular_values = lstsq_batched(matrix, rhs)
    return solution, residuals, rank, singular_values
def max(self, x, axis=None, keepdims=False)
Expand source code
def max(self, x, axis=None, keepdims=False):
    return jnp.max(x, axis, keepdims=keepdims)
def mean(self, value, axis=None, keepdims=False)
Expand source code
def mean(self, value, axis=None, keepdims=False):
    return jnp.mean(value, axis, keepdims=keepdims)
def meshgrid(self, *coordinates)
Expand source code
def meshgrid(self, *coordinates):
    self._check_float64()
    coordinates = [self.as_tensor(c) for c in coordinates]
    return [jax.device_put(c, self._default_device.ref) for c in jnp.meshgrid(*coordinates, indexing='ij')]
def min(self, x, axis=None, keepdims=False)
Expand source code
def min(self, x, axis=None, keepdims=False):
    return jnp.min(x, axis, keepdims=keepdims)
def mul(self, a, b)
Expand source code
def mul(self, a, b):
    # if scipy.sparse.issparse(a):  # TODO sparse?
    #     return a.multiply(b)
    # elif scipy.sparse.issparse(b):
    #     return b.multiply(a)
    # else:
        return Backend.mul(self, a, b)
def mul_matrix_batched_vector(self, A, b)
Expand source code
def mul_matrix_batched_vector(self, A, b):
    from jax.experimental.sparse import BCOO
    if isinstance(A, BCOO):
        return (A @ b.T).T
    return jnp.stack([A.dot(b[i]) for i in range(b.shape[0])])
def nn_library(self)
Expand source code
def nn_library(self):
    from . import stax_nets
    return stax_nets
def nonzero(self, values, length=None, fill_value=-1)

Args

values
Tensor with only spatial dimensions
length
(Optional) Length of the resulting array. If specified, the result array will be padded with fill_value or trimmed.

Returns

non-zero multi-indices as tensor of shape (nnz/length, vector)

Expand source code
def nonzero(self, values, length=None, fill_value=-1):
    result = jnp.nonzero(values, size=length, fill_value=fill_value)
    return jnp.stack(result, -1)
def numpy(self, tensor)

Returns a NumPy representation of the given tensor. If tensor is already a NumPy array, it is returned without modification.

This method raises an error if the value of the tensor is not known at this point, e.g. because it represents a node in a graph. Use is_available(tensor) to check if the value can be represented as a NumPy array.

Args

tensor
backend-compatible tensor or sparse tensor

Returns

NumPy representation of the values stored in the tensor

Expand source code
def numpy(self, tensor):
    if self.is_sparse(tensor):
        assemble, parts = self.disassemble(tensor)
        return assemble(NUMPY, *[self.numpy(t) for t in parts])
    return np.array(tensor)
def numpy_call(self, f, output_shapes, output_dtypes, *args, **aux_args)

This call can be used in jit-compiled code but is not differentiable.

Args

f
Function operating on numpy arrays.
output_shapes
Single shape tuple or tuple of shapes declaring the shapes of the tensors returned by f.
output_dtypes
Single DType or tuple of DTypes declaring the dtypes of the tensors returned by f.
*args
Tensor arguments to be converted to NumPy arrays and then passed to f.
**aux_args
Keyword arguments to be passed to f without conversion.

Returns

Returned arrays of f converted to tensors.

Expand source code
def numpy_call(self, f, output_shapes, output_dtypes, *args, **aux_args):
    @dataclasses.dataclass
    class OutputTensor:
        shape: Tuple[int]
        dtype: np.dtype
    output_specs = map_structure(lambda t, s: OutputTensor(s, to_numpy_dtype(t)), output_dtypes, output_shapes)
    if hasattr(jax, 'pure_callback'):
        def aux_f(*args):
            return f(*args, **aux_args)
        return jax.pure_callback(aux_f, output_specs, *args)
    else:
        def aux_f(args):
            if isinstance(args, tuple):
                return f(*args, **aux_args)
            else:
                return f(args, **aux_args)
        from jax.experimental.host_callback import call
        return call(aux_f, args, result_shape=output_specs)
def ones(self, shape, dtype: phiml.backend._dtype.DType = None)
Expand source code
def ones(self, shape, dtype: DType = None):
    self._check_float64()
    return jax.device_put(jnp.ones(shape, dtype=to_numpy_dtype(dtype or self.float_type)), self._default_device.ref)
def ones_like(self, tensor)
Expand source code
def ones_like(self, tensor):
    return jax.device_put(jnp.ones_like(tensor), self._default_device.ref)
def pad(self, value, pad_width, mode='constant', constant_values=0)

Pad a tensor with values as specified by mode and constant_values.

If the mode is not supported, returns NotImplemented.

Args

value
tensor
pad_width
2D tensor specifying the number of values padded to the edges of each axis in the form [[axis 0 lower, axis 0 upper], …] including batch and component axes.
mode
One of 'constant', 'boundary', 'periodic', 'symmetric', 'reflect'. (Default value = 'constant')
constant_values
Scalar value used for out-of-bounds points if mode='constant'. Must be a Python primitive type or scalar tensor.

Returns

padded tensor or NotImplemented

Expand source code
def pad(self, value, pad_width, mode='constant', constant_values=0):
    assert mode in ('constant', 'symmetric', 'periodic', 'reflect', 'boundary'), mode
    if mode == 'constant':
        constant_values = jnp.array(constant_values, dtype=value.dtype)
        return jnp.pad(value, pad_width, 'constant', constant_values=constant_values)
    else:
        if mode in ('periodic', 'boundary'):
            mode = {'periodic': 'wrap', 'boundary': 'edge'}[mode]
        return jnp.pad(value, pad_width, mode)
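A minimal usage sketch, assuming the module-level JAX backend instance and a (batch, x, channel) layout:

>>> import jax.numpy as jnp
>>> from phiml.backend.jax import JAX
>>> x = jnp.arange(4).reshape(1, 4, 1)
>>> JAX.pad(x, [[0, 0], [1, 2], [0, 0]]).shape               # zero-pad the spatial axis -> (1, 7, 1)
>>> JAX.pad(x, [[0, 0], [1, 1], [0, 0]], mode='periodic')    # maps to jnp.pad(..., 'wrap')
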
def prefers_channels_last(self) ‑> bool
Expand source code
def prefers_channels_last(self) -> bool:
    return True
def prod(self, value, axis=None)
Expand source code
def prod(self, value, axis=None):
    if not isinstance(value, jnp.ndarray):
        value = jnp.array(value)
    if value.dtype == bool:
        return jnp.all(value, axis=axis)
    return jnp.prod(value, axis=axis)
def quantile(self, x, quantiles)

Reduces the last / inner axis of x.

Args

x
Tensor
quantiles
List or 1D tensor of quantiles to compute.

Returns

Tensor with shape (quantiles, *x.shape[:-1])

Expand source code
def quantile(self, x, quantiles):
    return jnp.quantile(x, quantiles, axis=-1)
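A short illustrative example, assuming the module-level JAX backend instance:

>>> import jax.numpy as jnp
>>> from phiml.backend.jax import JAX
>>> x = jnp.array([[0., 1., 2., 3.]])
>>> JAX.quantile(x, jnp.array([0., 0.5, 1.]))    # shape (quantiles, batch) -> [[0.], [1.5], [3.]]
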
def random_normal(self, shape, dtype: phiml.backend._dtype.DType)

Float tensor of selected precision containing random values sampled from a normal distribution with mean 0 and std 1.

Expand source code
def random_normal(self, shape, dtype: DType):
    self._check_float64()
    self.rnd_key, subkey = jax.random.split(self.rnd_key)
    dtype = dtype or self.float_type
    return jax.device_put(random.normal(subkey, shape, dtype=to_numpy_dtype(dtype)), self._default_device.ref)
def random_uniform(self, shape, low, high, dtype: Optional[phiml.backend._dtype.DType])

Tensor of the given shape and dtype containing uniformly distributed random values in the range [low, high). If dtype is None, a float tensor of the selected precision is returned.

Expand source code
def random_uniform(self, shape, low, high, dtype: Union[DType, None]):
    self._check_float64()
    self.rnd_key, subkey = jax.random.split(self.rnd_key)

    dtype = dtype or self.float_type
    jdt = to_numpy_dtype(dtype)
    if dtype.kind == float:
        tensor = random.uniform(subkey, shape, minval=low, maxval=high, dtype=jdt)
    elif dtype.kind == complex:
        real = random.uniform(subkey, shape, minval=low.real, maxval=high.real, dtype=to_numpy_dtype(DType(float, dtype.precision)))
        imag = random.uniform(subkey, shape, minval=low.imag, maxval=high.imag, dtype=to_numpy_dtype(DType(float, dtype.precision)))
        return real + 1j * imag
    elif dtype.kind == int:
        tensor = random.randint(subkey, shape, low, high, dtype=jdt)
        if tensor.dtype != jdt:
            warnings.warn(f"Jax failed to sample random integers with dtype {dtype}, returned {tensor.dtype} instead.", RuntimeWarning)
    else:
        raise ValueError(dtype)
    return jax.device_put(tensor, self._default_device.ref)
def range(self, start, limit=None, delta=1, dtype: phiml.backend._dtype.DType = int32)
Expand source code
def range(self, start, limit=None, delta=1, dtype: DType = DType(int, 32)):
    if limit is None:
        start, limit = 0, start
    return jnp.arange(start, limit, delta, to_numpy_dtype(dtype))
def ravel_multi_index(self, multi_index, shape, mode: Union[str, int] = 'undefined')

Args

multi_index
(batch…, index_dim)
shape
1D tensor or tuple/list
mode
'undefined', 'periodic', 'clamp' or an int to use for all invalid indices.

Returns

Integer tensor of shape (batch…) of same dtype as multi_index.

Expand source code
def ravel_multi_index(self, multi_index, shape, mode: Union[str, int] = 'undefined'):
    if not self.is_available(shape):
        return Backend.ravel_multi_index(self, multi_index, shape, mode)
    mode = mode if isinstance(mode, int) else {'undefined': 'clip', 'periodic': 'wrap', 'clamp': 'clip'}[mode]
    idx_first = jnp.transpose(multi_index, (self.ndims(multi_index)-1,) + tuple(range(self.ndims(multi_index)-1)))
    result = jnp.ravel_multi_index(idx_first, shape, mode='wrap' if isinstance(mode, int) else mode)
    if isinstance(mode, int):
        outside = self.any((multi_index < 0) | (multi_index >= jnp.asarray(shape, dtype=multi_index.dtype)), -1)
        result = self.where(outside, mode, result)
    return result
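A minimal usage sketch, assuming the module-level JAX backend instance:

>>> import jax.numpy as jnp
>>> from phiml.backend.jax import JAX
>>> idx = jnp.array([[0, 1], [2, 3]])                             # (batch=2, index_dim=2)
>>> JAX.ravel_multi_index(idx, (4, 5))                            # row-major flat indices -> [1, 13]
>>> JAX.ravel_multi_index(jnp.array([[5, 0]]), (4, 5), mode=-1)   # out-of-bounds index -> [-1]
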
def repeat(self, x, repeats, axis: int, new_length=None)

Repeats the elements along axis repeats times.

Args

x
Tensor
repeats
How often to repeat each element. 1D tensor of length x.shape[axis]
axis
Which axis to repeat elements along
new_length
Set the length of axis after repeating. This is required for jit compilation with Jax.

Returns

repeated Tensor

Expand source code
def repeat(self, x, repeats, axis: int, new_length=None):
    return jnp.repeat(x, self.as_tensor(repeats), axis, total_repeat_length=new_length)
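A short illustrative example, assuming the module-level JAX backend instance; new_length fixes the output size for jit compilation:

>>> import jax.numpy as jnp
>>> from phiml.backend.jax import JAX
>>> x = jnp.array([10, 20, 30])
>>> JAX.repeat(x, [1, 0, 2], axis=0)                 # -> [10, 30, 30]
>>> JAX.repeat(x, [1, 0, 2], axis=0, new_length=3)   # same result with a static output length
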
def requires_fixed_shapes_when_tracing(self) ‑> bool
Expand source code
def requires_fixed_shapes_when_tracing(self) -> bool:
    return True
def reshape(self, value, shape)
Expand source code
def reshape(self, value, shape):
    return jnp.reshape(value, shape)
def scatter(self, base_grid, indices, values, mode: str)

Batched n-dimensional scatter.

Args

base_grid
Tensor into which scatter values are inserted at indices. Tensor of shape (batch_size, spatial…, channels)
indices
Tensor of shape (batch_size or 1, update_count, index_vector)
values
Values to scatter at indices. Tensor of shape (batch_size or 1, update_count or 1, channels or 1)
mode
One of ('update', 'add', 'max', 'min')

Returns

Copy of base_grid with values at indices updated by values.

Expand source code
def scatter(self, base_grid, indices, values, mode: str):
    assert mode in ('add', 'update', 'max', 'min')
    base_grid, values = self.auto_cast(base_grid, values)
    batch_size = combined_dim(combined_dim(indices.shape[0], values.shape[0]), base_grid.shape[0])
    spatial_dims = tuple(range(base_grid.ndim - 2))
    dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(1,),  # channel dim of updates (batch dim removed)
                                            inserted_window_dims=spatial_dims,  # no idea what this does but spatial_dims seems to work
                                            scatter_dims_to_operand_dims=spatial_dims)  # spatial dims of base_grid (batch dim removed)
    scatter = {'add': jax.lax.scatter_add, 'update': jax.lax.scatter, 'max': jax.lax.scatter_max, 'min': jax.lax.scatter_min}[mode]
    def scatter_single(base_grid, indices, values):
        return scatter(base_grid, indices, values, dnums)
    if self.staticshape(indices)[0] == 1:
        indices = self.tile(indices, [batch_size, 1, 1])
    if self.staticshape(values)[0] == 1:
        values = self.tile(values, [batch_size, 1, 1])
    return self.vectorized_call(scatter_single, base_grid, indices, values)
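A minimal usage sketch, assuming the module-level JAX backend instance and a (batch, x, channel) grid:

>>> import jax.numpy as jnp
>>> from phiml.backend.jax import JAX
>>> base = jnp.zeros((1, 4, 1))               # (batch, x, channel)
>>> idx = jnp.array([[[1], [3]]])             # (batch, update_count, index_vector)
>>> vals = jnp.ones((1, 2, 1))                # (batch, update_count, channel)
>>> JAX.scatter(base, idx, vals, mode='add')  # -> [[[0.], [1.], [0.], [1.]]]
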
def searchsorted(self, sorted_sequence, search_values, side: str, dtype=int32)
Expand source code
def searchsorted(self, sorted_sequence, search_values, side: str, dtype=DType(int, 32)):
    if self.ndims(sorted_sequence) == 1:
        return jnp.searchsorted(sorted_sequence, search_values, side=side).astype(to_numpy_dtype(dtype))
    else:
        return jax.vmap(partial(self.searchsorted, side=side, dtype=dtype))(sorted_sequence, search_values)
def seed(self, seed: int)
Expand source code
def seed(self, seed: int):
    self.rnd_key = jax.random.PRNGKey(seed)
def solve_triangular_dense(self, matrix, rhs, lower: bool, unit_diagonal: bool)

Args

matrix
(batch_size, rows, cols)
rhs
(batch_size, cols)

lower
Whether matrix is lower-triangular (True) or upper-triangular (False).
unit_diagonal
If True, the diagonal entries of matrix are assumed to be 1 and are not accessed.

Returns

(batch_size, cols)

Expand source code
def solve_triangular_dense(self, matrix, rhs, lower: bool, unit_diagonal: bool):
    matrix, rhs = self.auto_cast(matrix, rhs, int_to_float=True, bool_to_int=True)
    x = jax.lax.linalg.triangular_solve(matrix, rhs, lower=lower, unit_diagonal=unit_diagonal, left_side=True)
    return x
def sort(self, x, axis=-1)
Expand source code
def sort(self, x, axis=-1):
    return jnp.sort(x, axis)
def sparse_coo_tensor(self, indices: Union[tuple, list], values, shape: tuple)

Create a sparse matrix in coordinate list (COO) format.

Optional feature.

See Also: Backend.csr_matrix(), Backend.csc_matrix().

Args

indices
2D tensor of shape (nnz, dims).
values
1D values tensor matching indices
shape
Shape of the sparse matrix

Returns

Native representation of the sparse matrix

Expand source code
def sparse_coo_tensor(self, indices: Union[tuple, list], values, shape: tuple):
    return BCOO((values, indices), shape=shape)
def std(self, x, axis=None, keepdims=False)
Expand source code
def std(self, x, axis=None, keepdims=False):
    return jnp.std(x, axis, keepdims=keepdims)
def sum(self, value, axis=None, keepdims=False)
Expand source code
def sum(self, value, axis=None, keepdims=False):
    if isinstance(value, (tuple, list)):
        assert axis == 0
        return sum(value[1:], value[0])
    return jnp.sum(value, axis=axis, keepdims=keepdims)
def tensordot(self, a, a_axes: Union[tuple, list], b, b_axes: Union[tuple, list])

Multiply-sum-reduce a_axes of a with b_axes of b.

Expand source code
def tensordot(self, a, a_axes: Union[tuple, list], b, b_axes: Union[tuple, list]):
    return jnp.tensordot(a, b, (a_axes, b_axes))
def to_dlpack(self, tensor)
Expand source code
def to_dlpack(self, tensor):
    from jax import dlpack
    return dlpack.to_dlpack(tensor)
def unique(self, x: ~TensorType, return_inverse: bool, return_counts: bool, axis: int) ‑> Tuple[~TensorType, ...]

Args

x
n-dimensional int array. Will compare axis-slices of x for multidimensional x.
return_inverse
Whether to return the inverse
return_counts
Whether to return the counts.
axis
Axis along which slices of x should be compared.

Returns

unique_slices
Sorted unique slices of x
unique_inverse
(optional) index of the unique slice for each slice of x
unique_counts
Number of occurrences of each unique slice
Expand source code
def unique(self, x: TensorType, return_inverse: bool, return_counts: bool, axis: int) -> Tuple[TensorType, ...]:
    return jnp.unique(x, return_inverse=return_inverse, return_counts=return_counts, axis=axis)
def unravel_index(self, flat_index, shape)
Expand source code
def unravel_index(self, flat_index, shape):
    return jnp.stack(jnp.unravel_index(flat_index, shape), -1)
def vectorized_call(self, f, *args, output_dtypes=None, **aux_args)

Args

f
Function with only positional tensor argument, returning one or multiple tensors.
*args
Batched inputs for f. The first dimension of all args is vectorized. All tensors in args must have the same size or 1 in their first dimension.
output_dtypes
Single DType or tuple of DTypes declaring the dtypes of the tensors returned by f.
**aux_args
Non-vectorized keyword arguments to be passed to f.
Expand source code
def vectorized_call(self, f, *args, output_dtypes=None, **aux_args):
    batch_size = self.determine_size(args, 0)
    args = [self.tile_to(t, 0, batch_size) for t in args]
    def f_positional(*args):
        return f(*args, **aux_args)
    vec_f = jax.vmap(f_positional, 0, 0)
    return vec_f(*args)
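A minimal usage sketch, assuming the module-level JAX backend instance; arguments with batch size 1 are tiled to the common batch size before vmapping:

>>> import jax.numpy as jnp
>>> from phiml.backend.jax import JAX
>>> a = jnp.arange(6).reshape(3, 2)                  # batch size 3
>>> b = jnp.array([[10, 20]])                        # batch size 1, tiled to 3
>>> JAX.vectorized_call(lambda x, y: x + y, a, b)    # shape (3, 2)
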
def where(self, condition, x=None, y=None)
Expand source code
def where(self, condition, x=None, y=None):
    if x is None or y is None:
        return jnp.argwhere(condition)
    return jnp.where(condition, x, y)
def while_loop(self, loop: Callable, values: tuple, max_iter: Union[int, Tuple[int, ...], List[int]])

If max_iter is None, runs

while any(values[0]):
    values = loop(*values)
return values

This operation does not support backpropagation.

Args

loop
Loop function, must return a tuple with entries equal to values in shape and data type.
values
Initial values of loop variables.
max_iter
Maximum number of iterations to run, single int or sequence of integers.

Returns

Loop variables upon loop completion if max_iter is a single integer. If max_iter is a sequence, stacks the variables after each entry in max_iter, adding an outer dimension of size <= len(max_iter). If the condition is fulfilled before the largest entry of max_iter is reached, the loop may or may not be terminated early, depending on the implementation. If it terminates early, the values from the last executed iteration are expected to be held constant and used to fill the remaining entries.

Expand source code
def while_loop(self, loop: Callable, values: tuple, max_iter: Union[int, Tuple[int, ...], List[int]]):
    if all(self.is_available(t) for t in values):
        return self.stop_gradient_tree(Backend.while_loop(self, loop, values, max_iter))
    if isinstance(max_iter, (tuple, list)):  # stack traced trajectory, unroll until max_iter
        values = self.stop_gradient_tree(values)
        trj = [values] if 0 in max_iter else []
        for i in range(1, max(max_iter) + 1):
            values = loop(*values)
            if i in max_iter:
                trj.append(values)  # values are not mutable so no need to copy
        return self.stop_gradient_tree(self.stack_leaves(trj))
    else:
        if max_iter is None:
            cond = lambda vals: jnp.any(vals[0])
            body = lambda vals: loop(*vals)
            return jax.lax.while_loop(cond, body, values)
        else:
            cond = lambda vals: jnp.any(vals[1][0]) & (vals[0] < max_iter)
            body = lambda vals: (vals[0] + 1, loop(*vals[1]))
            return jax.lax.while_loop(cond, body, (self.as_tensor(0), values))[1]
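An illustrative sketch, assuming the module-level JAX backend instance; values[0] acts as the continue condition, and the function name here is only for demonstration:

>>> import jax.numpy as jnp
>>> from phiml.backend.jax import JAX
>>> def body(continue_, x):
...     x = x + 1
...     return x < 5, x
>>> JAX.while_loop(body, (jnp.array(True), jnp.array(0)), max_iter=100)   # -> (False, 5)

With eager inputs as above, the loop runs as a Python loop; with traced values, the implementation shown falls back to jax.lax.while_loop.
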
def zeros(self, shape, dtype: phiml.backend._dtype.DType = None)
Expand source code
def zeros(self, shape, dtype: DType = None):
    self._check_float64()
    return jax.device_put(jnp.zeros(shape, dtype=to_numpy_dtype(dtype or self.float_type)), self._default_device.ref)
def zeros_like(self, tensor)
Expand source code
def zeros_like(self, tensor):
    return jax.device_put(jnp.zeros_like(tensor), self._default_device.ref)