Module phiml.backend.jax

Sub-modules

phiml.backend.jax.stax_nets

Stax implementation of the unified machine learning API. Equivalent functions also exist for the other frameworks …

Classes

class JaxBackend
Expand source code
class JaxBackend(Backend):
    """
    Backend implementation that delegates to JAX (`jax`, `jax.numpy`, `jax.lax`, `jax.random`).

    CPU, GPU and TPU devices are enumerated on construction and the last one found becomes the default device.
    Random sampling splits the PRNG key stored in `self.rnd_key`, so results are reproducible after `seed()`.
    """

    def __init__(self):
        devices = []
        for device_type in ['cpu', 'gpu', 'tpu']:
            try:
                for jax_dev in jax.devices(device_type):
                    devices.append(ComputeDevice(self, device_type.upper(), jax_dev.platform.upper(), -1, -1, f"id={jax_dev.id}", jax_dev))
            except RuntimeError:
                pass  # this is just Jax not finding anything. jaxlib.xla_client._get_local_backends() could help but isn't currently available on GitHub actions
        # Enumeration order is cpu, gpu, tpu, so the most specialized device found becomes the default.
        Backend.__init__(self, 'jax', devices, devices[-1])
        try:
            self.rnd_key = jax.random.PRNGKey(seed=0)
        except RuntimeError as err:
            warnings.warn(f"{err}", RuntimeWarning)
            self.rnd_key = None  # random sampling will fail until seed() sets a valid key

    def prefers_channels_last(self) -> bool:
        return True

    def requires_fixed_shapes_when_tracing(self) -> bool:
        return True

    def nn_library(self):
        """Returns the Stax-based neural network module used with this backend."""
        from . import stax_nets
        return stax_nets

    def _check_float64(self):
        """Turns on Jax's `jax_enable_x64` flag if this backend runs at 64-bit precision."""
        if self.precision == 64:
            if not jax.config.read('jax_enable_x64'):
                jax.config.update('jax_enable_x64', True)
            assert jax.config.read('jax_enable_x64'), "FP64 is disabled for Jax."

    def seed(self, seed: int):
        """Resets the internal PRNG key; all subsequent random calls derive from it."""
        self.rnd_key = jax.random.PRNGKey(seed)

    def as_tensor(self, x, convert_external=True):
        """
        Converts `x` to a Jax array unless it is already accepted as a tensor.

        Floating point and complex results are cast to the backend precision. Python numbers passed through
        unchanged (when `convert_external=False`) are returned as-is.
        """
        self._check_float64()
        if self.is_tensor(x, only_native=convert_external):
            array = x
        else:
            array = jnp.array(x)
        # --- Enforce Precision ---
        if not isinstance(array, numbers.Number):
            if self.dtype(array).kind == float:
                array = self.to_float(array)
            elif self.dtype(array).kind == complex:
                array = self.to_complex(array)
        return array

    def is_module(self, obj):
        return False

    def is_tensor(self, x, only_native=False):
        """
        Checks whether `x` is a tensor for this backend.

        Native: Jax arrays, Jax bools and Jax sparse matrices.
        Non-native (only if `only_native=False`): NumPy arrays/bools, Python numbers/bools and nested tuples/lists thereof.
        """
        # Exclude NumPy arrays explicitly; they may pass the jnp.ndarray isinstance check in some Jax versions.
        if isinstance(x, jnp.ndarray) and not isinstance(x, np.ndarray):
            return True
        if isinstance(x, jnp.bool_) and not isinstance(x, np.bool_):
            return True
        if self.is_sparse(x):
            return True
        # --- Above considered native ---
        if only_native:
            return False
        # --- Non-native types ---
        if isinstance(x, np.ndarray):
            return True
        if isinstance(x, np.bool_):
            return True
        if isinstance(x, (numbers.Number, bool)):
            return True
        if isinstance(x, (tuple, list)):
            return all(self.is_tensor(item, False) for item in x)
        return False

    def sizeof(self, tensor) -> int:
        return tensor.nbytes

    def is_sparse(self, x) -> bool:
        return isinstance(x, (COO, BCOO, CSR, CSC))

    def get_sparse_format(self, x) -> str:
        """Returns 'coo', 'csr' or 'csc' for Jax sparse matrices and 'dense' for anything else."""
        format_names = {
            COO: 'coo',
            BCOO: 'coo',
            CSR: 'csr',
            CSC: 'csc',
        }
        return format_names.get(type(x), 'dense')

    def is_available(self, tensor):
        """Returns `False` if `tensor` is a tracer (abstract value) rather than concrete data."""
        if isinstance(tensor, JVPTracer):
            tensor = tensor.primal
        return not isinstance(tensor, Tracer)

    def numpy(self, tensor):
        """Converts `tensor` (dense or sparse) to a NumPy array / NumPy-backend sparse matrix."""
        if isinstance(tensor, JVPTracer):
            tensor = tensor.primal
        if self.is_sparse(tensor):
            assemble, parts = self.disassemble(tensor)
            return assemble(NUMPY, *[self.numpy(t) for t in parts])
        return np.array(tensor)

    def disassemble(self, x) -> Tuple[Callable, Sequence[TensorType]]:
        """
        Splits `x` into its component arrays and a function `assemble(backend, *parts)` that rebuilds it.

        Only dense arrays and BCOO matrices are supported; other sparse formats raise `NotImplementedError`.
        """
        if self.is_sparse(x):
            if isinstance(x, COO):
                raise NotImplementedError
                # return lambda b, i, v: b.sparse_coo_tensor(i, v, x.shape), (np.stack([x.row, x.col], -1), x.data)
            if isinstance(x, BCOO):
                return lambda b, i, v: b.sparse_coo_tensor(i, v, x.shape), (x.indices, x.data)
            if isinstance(x, CSR):
                raise NotImplementedError
                # return lambda b, v, i, p: b.csr_matrix(i, p, v, x.shape), (x.data, x.indices, x.indptr)
            elif isinstance(x, CSC):
                raise NotImplementedError
                # return lambda b, v, i, p: b.csc_matrix(p, i, v, x.shape), (x.data, x.indices, x.indptr)
            raise NotImplementedError
        else:
            return lambda b, t: t, (x,)

    def to_dlpack(self, tensor):
        if version.parse(jax.__version__) < version.parse("0.7.0"):
            from jax import dlpack
            return dlpack.to_dlpack(tensor)
        else:
            return tensor.__dlpack__()

    def from_dlpack(self, external_array):
        from jax import dlpack
        return dlpack.from_dlpack(external_array)

    if version.parse(jax.__version__) >= version.parse("0.7.2"):
        def from_external_array_dlpack(self, external_array):
            from jax import dlpack
            return dlpack.from_dlpack(external_array)

    def copy(self, tensor, only_mutable=False):
        return jnp.array(tensor, copy=True)

    def get_device(self, tensor: TensorType) -> ComputeDevice:
        """Returns the `ComputeDevice` holding `tensor` (the first one if it spans several devices)."""
        if hasattr(tensor, 'devices'):
            return self.get_device_by_ref(next(iter(tensor.devices())))
        if hasattr(tensor, 'device'):
            device = tensor.device
            # `device` may be a method or a property depending on the Jax version.
            return self.get_device_by_ref(device() if callable(device) else device)
        raise AssertionError(f"tensor {type(tensor)} has no device attribute")

    def allocate_on_device(self, tensor: TensorType, device: ComputeDevice) -> TensorType:
        return jax.device_put(tensor, device.ref)

    # --- Direct aliases of Jax functions ---
    # jax.numpy (not NumPy) is used throughout so these also work on traced values inside jit / grad / vmap
    # and return native Jax arrays.
    sqrt = staticmethod(jnp.sqrt)
    exp = staticmethod(jnp.exp)
    erf = staticmethod(scipy.special.erf)
    softplus = staticmethod(jax.nn.softplus)
    sin = staticmethod(jnp.sin)
    arcsin = staticmethod(jnp.arcsin)
    cos = staticmethod(jnp.cos)
    arccos = staticmethod(jnp.arccos)
    tan = staticmethod(jnp.tan)
    arctan = staticmethod(jnp.arctan)
    arctan2 = staticmethod(jnp.arctan2)
    sinh = staticmethod(jnp.sinh)
    arcsinh = staticmethod(jnp.arcsinh)
    cosh = staticmethod(jnp.cosh)
    arccosh = staticmethod(jnp.arccosh)
    tanh = staticmethod(jnp.tanh)
    arctanh = staticmethod(jnp.arctanh)
    log = staticmethod(jnp.log)
    log2 = staticmethod(jnp.log2)
    log10 = staticmethod(jnp.log10)
    isfinite = staticmethod(jnp.isfinite)
    isnan = staticmethod(jnp.isnan)
    isinf = staticmethod(jnp.isinf)
    abs = staticmethod(jnp.abs)
    sign = staticmethod(jnp.sign)
    round = staticmethod(jnp.round)
    ceil = staticmethod(jnp.ceil)
    floor = staticmethod(jnp.floor)
    flip = staticmethod(jnp.flip)
    stop_gradient = staticmethod(jax.lax.stop_gradient)
    transpose = staticmethod(jnp.transpose)
    equal = staticmethod(jnp.equal)
    tile = staticmethod(jnp.tile)
    stack = staticmethod(jnp.stack)
    concat = staticmethod(jnp.concatenate)
    maximum = staticmethod(jnp.maximum)
    minimum = staticmethod(jnp.minimum)
    clip = staticmethod(jnp.clip)
    argmax = staticmethod(jnp.argmax)
    argmin = staticmethod(jnp.argmin)
    shape = staticmethod(jnp.shape)
    staticshape = staticmethod(jnp.shape)
    imag = staticmethod(jnp.imag)
    real = staticmethod(jnp.real)
    conj = staticmethod(jnp.conjugate)
    einsum = staticmethod(jnp.einsum)
    cumsum = staticmethod(jnp.cumsum)

    def nonzero(self, values, length=None, fill_value=-1):
        """Returns the indices of non-zero entries as an array of shape (count, ndim), padded to `length` with `fill_value`."""
        result = jnp.nonzero(values, size=length, fill_value=fill_value)
        return jnp.stack(result, -1)

    def vectorized_call(self, f, *args, output_dtypes=None, **aux_args):
        """Maps `f` over dimension 0 of all `args` using `jax.vmap`, tiling args to a common batch size first."""
        batch_size = self.determine_size(args, 0)
        args = [self.tile_to(t, 0, batch_size) for t in args]
        def f_positional(*args):
            return f(*args, **aux_args)
        vec_f = jax.vmap(f_positional, 0, 0)
        return vec_f(*args)

    def numpy_call(self, f, output_shapes, output_dtypes, *args, **aux_args):
        """
        Calls the NumPy function `f` on `args`.

        If all args are concrete, `f` is called directly. Otherwise (tracing), it is registered as a host callback
        whose results are described by `output_shapes` / `output_dtypes`.
        """
        if all([self.is_available(arg) for arg in args]):
            args = [self.numpy(arg) for arg in args]
            output = f(*args, **aux_args)
            result = map_structure(self.as_tensor, output)
            return result
        @dataclasses.dataclass
        class OutputTensor:
            shape: Tuple[int]
            dtype: np.dtype
        output_specs = map_structure(lambda t, s: OutputTensor(s, to_numpy_dtype(t)), output_dtypes, output_shapes)
        if hasattr(jax, 'pure_callback'):
            def aux_f(*args):
                return f(*args, **aux_args)
            return jax.pure_callback(aux_f, output_specs, *args)
        else:
            def aux_f(args):
                if isinstance(args, tuple):
                    return f(*args, **aux_args)
                else:
                    return f(args, **aux_args)
            from jax.experimental.host_callback import call
            return call(aux_f, args, result_shape=output_specs)

    def jit_compile(self, f: Callable) -> Callable:
        """Compiles `f` with `jax.jit` on the default device. Calls are logged at debug level and routed through `as_registered.call`."""
        def run_jit_f(*args):
            # print(jax.make_jaxpr(f)(*args))
            ML_LOGGER.debug(f"JaxBackend: running jit-compiled '{f.__name__}' with shapes {[self.shape(arg) for arg in args]} and dtypes {[self.dtype(arg) for arg in args]}")
            return self.as_registered.call(jit_f, *args, name=f"run jit-compiled '{f.__name__}'")

        run_jit_f.__name__ = f"Jax-Jit({f.__name__})"
        jit_f = jax.jit(f, device=self._default_device.ref)
        return run_jit_f

    def block_until_ready(self, values):
        """Blocks until `values` (or every element of a tuple/list of values, recursively) are computed."""
        if hasattr(values, 'block_until_ready'):
            values.block_until_ready()
        if isinstance(values, (tuple, list)):
            for v in values:
                self.block_until_ready(v)

    def _float_wrt_args(self, args, wrt):
        """Casts bool/int arguments at positions `wrt` to float, since Jax only differentiates w.r.t. inexact dtypes."""
        return [self.to_float(arg) if i in wrt and self.dtype(arg).kind in (bool, int) else arg for i, arg in enumerate(args)]

    def jacobian(self, f, wrt: Union[tuple, list], get_output: bool, is_f_scalar: bool):
        """
        Builds the gradient function of `f` w.r.t. the argument positions `wrt`.

        `f` must return `(loss, output)` where `output` is a tuple. With `get_output=True` the returned function
        yields `(*output, *grads)`, otherwise the tuple of grads. Gradients are complex-conjugated before being returned.
        """
        if get_output:
            jax_grad_f = jax.value_and_grad(f, argnums=wrt, has_aux=True)
            @wraps(f)
            def unwrap_outputs(*args):
                (_, output_tuple), grads = jax_grad_f(*self._float_wrt_args(args, wrt))
                return (*output_tuple, *[jnp.conjugate(g) for g in grads])
            return unwrap_outputs
        else:
            @wraps(f)
            def nonaux_f(*args):
                loss, output = f(*args)
                return loss
            jax_grad = jax.grad(nonaux_f, argnums=wrt, has_aux=False)
            @wraps(f)
            def call_jax_grad(*args):
                grads = jax_grad(*self._float_wrt_args(args, wrt))
                return tuple([jnp.conjugate(g) for g in grads])
            return call_jax_grad

    def custom_gradient(self, f: Callable, gradient: Callable, get_external_cache: Callable = None, on_call_skipped: Callable = None) -> Callable:
        """
        Wraps `f` with `jax.custom_vjp` so that reverse-mode differentiation uses `gradient(x, y, dy)`.

        `gradient` receives the input tuple `x`, the output `y` and the output cotangent `dy`, and must return one
        gradient per input. `get_external_cache` and `on_call_skipped` are not used by this backend.
        """
        jax_fun = jax.custom_vjp(f)  # custom vector-Jacobian product (reverse-mode differentiation)

        def forward(*x):
            y = f(*x)
            return y, (x, y)

        def backward(x_y, dy):
            x, y = x_y
            dx = gradient(x, y, dy)
            return tuple(dx)

        jax_fun.defvjp(forward, backward)
        return jax_fun

    def divide_no_nan(self, x, y):
        """Computes `x / y`, returning 0 wherever `y == 0`."""
        is_zero = y == 0
        # Divide by a safe denominator so the discarded branch contains no inf/NaN. Otherwise NaNs would
        # propagate into gradients through jnp.where even though the forward value is masked.
        # (jnp.nan_to_num(x / y, copy=True, nan=0) is avoided because it would also hide NaNs present in x.)
        safe_y = jnp.where(is_zero, 1, y)
        return jnp.where(is_zero, 0, x / safe_y)

    def random_uniform(self, shape, low, high, dtype: Union[DType, None]):
        """
        Samples uniformly from [low, high) with the given `dtype` (default: backend float type).

        For complex dtypes, real and imaginary parts are sampled independently from the corresponding parts of `low` / `high`.
        """
        self._check_float64()
        self.rnd_key, subkey = jax.random.split(self.rnd_key)

        dtype = dtype or self.float_type
        jdt = to_numpy_dtype(dtype)
        if dtype.kind == float:
            tensor = random.uniform(subkey, shape, minval=low, maxval=high, dtype=jdt)
        elif dtype.kind == complex:
            # Independent keys for both parts; a shared key would make them perfectly correlated.
            real_key, imag_key = jax.random.split(subkey)
            component_dtype = to_numpy_dtype(DType.by_precision(float, dtype.precision))
            real = random.uniform(real_key, shape, minval=low.real, maxval=high.real, dtype=component_dtype)
            imag = random.uniform(imag_key, shape, minval=low.imag, maxval=high.imag, dtype=component_dtype)
            tensor = real + 1j * imag
        elif dtype.kind == int:
            tensor = random.randint(subkey, shape, low, high, dtype=jdt)
            if tensor.dtype != jdt:
                warnings.warn(f"Jax failed to sample random integers with dtype {dtype}, returned {tensor.dtype} instead.", RuntimeWarning)
        else:
            raise ValueError(dtype)
        return jax.device_put(tensor, self._default_device.ref)

    def random_normal(self, shape, dtype: DType):
        """Samples from a standard normal distribution with the given `dtype` (default: backend float type)."""
        self._check_float64()
        self.rnd_key, subkey = jax.random.split(self.rnd_key)
        dtype = dtype or self.float_type
        return jax.device_put(random.normal(subkey, shape, dtype=to_numpy_dtype(dtype)), self._default_device.ref)

    def random_permutations(self, permutations: int, n: int):
        """Returns `permutations` independent random permutations of `range(n)`, stacked along dimension 0."""
        self.rnd_key, subkey = jax.random.split(self.rnd_key)
        # One key per permutation; reusing the same key would make every row identical.
        keys = jax.random.split(subkey, permutations)
        result = jnp.stack([jax.random.permutation(key, n) for key in keys])
        return jax.device_put(result, self._default_device.ref)

    def range(self, start, limit=None, delta=1, dtype: DType = INT32):
        """Like Python's `range`: a single argument is interpreted as `limit`, starting from 0."""
        if limit is None:
            start, limit = 0, start
        return jnp.arange(start, limit, delta, to_numpy_dtype(dtype))

    def pad(self, value, pad_width, mode='constant', constant_values=0):
        """Pads `value`; the modes 'periodic' and 'boundary' map to Jax's 'wrap' and 'edge'."""
        assert mode in ('constant', 'symmetric', 'periodic', 'reflect', 'boundary'), mode
        if mode == 'constant':
            constant_values = jnp.array(constant_values, dtype=value.dtype)
            return jnp.pad(value, pad_width, 'constant', constant_values=constant_values)
        else:
            if mode in ('periodic', 'boundary'):
                mode = {'periodic': 'wrap', 'boundary': 'edge'}[mode]
            return jnp.pad(value, pad_width, mode)

    def reshape(self, value, shape):
        return jnp.reshape(value, shape)

    def sum(self, value, axis=None, keepdims=False):
        """Sums `value`. A tuple/list is summed element-wise with Python's builtin `sum` (only `axis=0` allowed)."""
        if isinstance(value, (tuple, list)):
            assert axis == 0
            return sum(value[1:], value[0])
        return jnp.sum(value, axis=axis, keepdims=keepdims)

    def prod(self, value, axis=None):
        """Product along `axis`; boolean inputs are reduced with logical AND instead."""
        if not isinstance(value, jnp.ndarray):
            value = jnp.array(value)
        if value.dtype == bool:
            return jnp.all(value, axis=axis)
        return jnp.prod(value, axis=axis)

    def where(self, condition, x=None, y=None):
        """Element-wise select; if `x` or `y` is missing, returns the indices of True entries instead."""
        if x is None or y is None:
            return jnp.argwhere(condition)
        return jnp.where(condition, x, y)

    def zeros(self, shape, dtype: DType = None):
        self._check_float64()
        return jax.device_put(jnp.zeros(shape, dtype=to_numpy_dtype(dtype or self.float_type)), self._default_device.ref)

    def zeros_like(self, tensor):
        return jax.device_put(jnp.zeros_like(tensor), self._default_device.ref)

    def ones(self, shape, dtype: DType = None):
        self._check_float64()
        return jax.device_put(jnp.ones(shape, dtype=to_numpy_dtype(dtype or self.float_type)), self._default_device.ref)

    def ones_like(self, tensor):
        return jax.device_put(jnp.ones_like(tensor), self._default_device.ref)

    def meshgrid(self, *coordinates):
        """Returns coordinate grids with matrix ('ij') indexing, placed on the default device."""
        self._check_float64()
        coordinates = [self.as_tensor(c) for c in coordinates]
        return [jax.device_put(c, self._default_device.ref) for c in jnp.meshgrid(*coordinates, indexing='ij')]

    def linspace(self, start, stop, number):
        self._check_float64()
        return jax.device_put(jnp.linspace(start, stop, number, dtype=to_numpy_dtype(self.float_type)), self._default_device.ref)

    def linspace_without_last(self, start, stop, number):
        self._check_float64()
        return jax.device_put(jnp.linspace(start, stop, number, endpoint=False, dtype=to_numpy_dtype(self.float_type)), self._default_device.ref)

    def mean(self, value, axis=None, keepdims=False):
        return jnp.mean(value, axis, keepdims=keepdims)

    def log_gamma(self, x):
        return jax.lax.lgamma(self.to_float(x))

    def gamma_inc_l(self, a, x):
        return scipy.special.gammainc(a, x)

    def gamma_inc_u(self, a, x):
        return scipy.special.gammaincc(a, x)

    def tensordot(self, a, a_axes: Union[tuple, list], b, b_axes: Union[tuple, list]):
        return jnp.tensordot(a, b, (a_axes, b_axes))

    def mul(self, a, b):
        # TODO sparse-aware multiplication (e.g. a.multiply(b) when a or b is sparse)
        return Backend.mul(self, a, b)

    def mul_matrix_batched_vector(self, A, b):
        """Multiplies the matrix `A` with each vector `b[i]` along dimension 0 of `b`."""
        if isinstance(A, BCOO):
            return (A @ b.T).T
        return jnp.stack([A.dot(b[i]) for i in range(b.shape[0])])

    def get_diagonal(self, matrices, offset=0):
        """
        Extracts diagonals over axes 1 and 2 and moves the diagonal axis to position 1.

        The 3-axis transpose implies `matrices` is 4D, presumably (batch, rows, cols, channels).
        """
        result = jnp.diagonal(matrices, offset=offset, axis1=1, axis2=2)
        return jnp.transpose(result, [0, 2, 1])

    def while_loop(self, loop: Callable, values: tuple, max_iter: Union[int, Tuple[int, ...], List[int]]):
        """
        Repeatedly applies `loop(*values)` while any element of `values[0]` is True.

        * Concrete inputs: runs the generic `Backend.while_loop` eagerly and stops gradients.
        * `max_iter` tuple/list: unrolls up to `max(max_iter)` steps and returns the stacked states recorded at
          the iterations listed in `max_iter`, without gradients.
        * Otherwise uses `jax.lax.while_loop`, additionally bounded by `max_iter` if it is not `None`.
        """
        if all(self.is_available(t) for t in values):
            return self.stop_gradient_tree(Backend.while_loop(self, loop, values, max_iter))
        if isinstance(max_iter, (tuple, list)):  # stack traced trajectory, unroll until max_iter
            values = self.stop_gradient_tree(values)
            trj = [values] if 0 in max_iter else []
            for i in range(1, max(max_iter) + 1):
                values = loop(*values)
                if i in max_iter:
                    trj.append(values)  # values are not mutable so no need to copy
            return self.stop_gradient_tree(self.stack_leaves(trj))
        else:
            if max_iter is None:
                cond = lambda vals: jnp.any(vals[0])
                body = lambda vals: loop(*vals)
                return jax.lax.while_loop(cond, body, values)
            else:
                # Carry is (iteration counter, values) so the loop can stop after max_iter steps.
                cond = lambda vals: jnp.any(vals[1][0]) & (vals[0] < max_iter)
                body = lambda vals: (vals[0] + 1, loop(*vals[1]))
                return jax.lax.while_loop(cond, body, (self.as_tensor(0), values))[1]

    def max(self, x, axis=None, keepdims=False):
        return jnp.max(x, axis, keepdims=keepdims)

    def min(self, x, axis=None, keepdims=False):
        return jnp.min(x, axis, keepdims=keepdims)

    def conv(self, value, kernel, strides: Sequence[int], out_sizes: Sequence[int], transpose: bool):
        """
        Batched N-D convolution.

        `value` has layout (batch, in_channels, *spatial) and `kernel` has layout (kernel_batch, out_channels,
        in_channels, *spatial), where kernel_batch is 1 or equal to the value batch size. Padding is chosen such
        that the output has spatial size `out_sizes`; odd padding puts the extra element on the lower side.
        Transposed convolution is not supported.
        """
        assert not transpose, "transpose conv not yet supported for Jax"
        assert kernel.shape[0] in (1, value.shape[0])
        assert value.shape[1] == kernel.shape[2], f"value has {value.shape[1]} channels but kernel has {kernel.shape[2]}"
        assert value.ndim + 1 == kernel.ndim
        value, kernel = self.auto_cast(value, kernel, bool_to_int=True)
        ndim = len(value.shape) - 2  # Number of spatial dimensions
        assert len(strides) == ndim, f"Expected {ndim} stride values, got {len(strides)}"
        # --- Determine padding ---
        # default_size = [int(np.ceil((vs - ks + 1) / st)) for vs, ks, st in zip(value.shape[2:], kernel.shape[3:], strides)]  # size if no padding is used
        lr_padding = [max(0, st * (os - 1) - vs + ks) for st, os, vs, ks in zip(strides, out_sizes, value.shape[2:], kernel.shape[3:])]
        padding = [((p+1) // 2, p // 2) for p in lr_padding]
        # --- Run the convolution (one call per batch entry if kernels differ per batch entry) ---
        sp = ''.join(['WHD'[i] for i in range(len(strides))])
        dim_num = jax.lax.conv_dimension_numbers(value.shape, kernel.shape[1:], ('NC'+sp, 'OI'+sp, 'NC'+sp))
        if kernel.shape[0] == 1:
            return jax.lax.conv_general_dilated(value, kernel[0], strides, padding, None, None, dim_num)
        else:
            result = []
            for b in range(kernel.shape[0]):
                result.append(jax.lax.conv_general_dilated(value[b:b + 1], kernel[b], strides, padding, None, None, dim_num))
            return jnp.concatenate(result, 0)

    def expand_dims(self, a, axis=0, number=1):
        """Inserts `number` singleton dimensions at `axis`."""
        for _i in range(number):
            a = jnp.expand_dims(a, axis)
        return a

    def cast(self, x, dtype: DType):
        """Casts `x` to `dtype`; native arrays that already have this dtype are returned unchanged."""
        if self.is_tensor(x, only_native=True) and from_numpy_dtype(x.dtype) == dtype:
            return x
        else:
            return jnp.array(x, to_numpy_dtype(dtype))

    def unravel_index(self, flat_index, shape):
        return jnp.stack(jnp.unravel_index(flat_index, shape), -1)

    def ravel_multi_index(self, multi_index, shape, mode: Union[str, int] = 'undefined'):
        """
        Converts multi-indices (index vectors along the last dimension) to flat indices.

        `mode` may be 'undefined'/'clamp' (clip), 'periodic' (wrap) or an int that replaces out-of-bounds results.
        """
        if not self.is_available(shape):
            return Backend.ravel_multi_index(self, multi_index, shape, mode)
        mode = mode if isinstance(mode, int) else {'undefined': 'clip', 'periodic': 'wrap', 'clamp': 'clip'}[mode]
        # jnp.ravel_multi_index expects the index vector dimension first.
        idx_first = jnp.transpose(multi_index, (self.ndims(multi_index)-1,) + tuple(range(self.ndims(multi_index)-1)))
        result = jnp.ravel_multi_index(idx_first, shape, mode='wrap' if isinstance(mode, int) else mode)
        if isinstance(mode, int):
            outside = self.any((multi_index < 0) | (multi_index >= jnp.asarray(shape, dtype=multi_index.dtype)), -1)
            result = self.where(outside, mode, result)
        return result

    def gather(self, values, indices, axis: int):
        slices = [indices if i == axis else slice(None) for i in range(self.ndims(values))]
        return values[tuple(slices)]

    def batched_gather_nd(self, values, indices):
        """
        Gathers from `values` (batch, *spatial, channels) at the index vectors in `indices` (batch, ..., spatial_rank).

        A batch size of 1 in either argument is broadcast against the other.
        """
        values = self.as_tensor(values)
        indices = self.as_tensor(indices)
        assert indices.shape[-1] == self.ndims(values) - 2
        batch_size = combined_dim(values.shape[0], indices.shape[0])
        indices_list = [indices[..., i] for i in range(indices.shape[-1])]
        batch_range = self.expand_dims(np.arange(batch_size), -1, number=self.ndims(indices) - 2)
        slices = (batch_range, *indices_list)
        return values[slices]

    def batched_gather_1d(self, values, indices):
        batch_size = combined_dim(values.shape[0], indices.shape[0])
        return values[np.arange(batch_size)[:, None], indices]

    def repeat(self, x, repeats, axis: int, new_length=None):
        return jnp.repeat(x, self.as_tensor(repeats), axis, total_repeat_length=new_length)

    def std(self, x, axis=None, keepdims=False):
        return jnp.std(x, axis, keepdims=keepdims)

    def boolean_mask(self, x, mask, axis=0, new_length=None, fill_value=0):
        """
        Selects the entries of `x` along `axis` where `mask` is True.

        If `new_length` is given, the result has exactly that length along `axis` (required under tracing);
        unused slots are filled with `fill_value`.
        """
        if new_length is None:
            slices = [mask if i == axis else slice(None) for i in range(len(x.shape))]
            return x[tuple(slices)]
        else:
            indices = jnp.argwhere(mask, size=new_length, fill_value=-1)[..., 0]
            valid = indices >= 0
            valid = valid[tuple([slice(None) if i == axis else None for i in range(len(x.shape))])]
            result = self.gather(x, jnp.maximum(0, indices), axis)
            return jnp.where(valid, result, fill_value)

    def any(self, boolean_tensor, axis=None, keepdims=False):
        if isinstance(boolean_tensor, (tuple, list)):
            boolean_tensor = jnp.stack(boolean_tensor)
        return jnp.any(boolean_tensor, axis=axis, keepdims=keepdims)

    def all(self, boolean_tensor, axis=None, keepdims=False):
        if isinstance(boolean_tensor, (tuple, list)):
            boolean_tensor = jnp.stack(boolean_tensor)
        return jnp.all(boolean_tensor, axis=axis, keepdims=keepdims)

    def scatter(self, base_grid, indices, values, mode: str):
        """
        Scatters `values` into `base_grid` at `indices` using `mode` ('add', 'update', 'max' or 'min').

        The scatter runs per batch entry via `vectorized_call`; a batch size of 1 in `indices` or `values` is tiled
        to the common batch size. Boolean results are cast back to bool after the integer scatter.
        """
        out_kind = self.combine_types(self.dtype(base_grid), self.dtype(values)).kind
        base_grid, values = self.auto_cast(base_grid, values, bool_to_int=True)
        batch_size = combined_dim(combined_dim(indices.shape[0], values.shape[0]), base_grid.shape[0])
        spatial_dims = tuple(range(base_grid.ndim - 2))
        # Inside scatter_single the batch dim is removed. Update dim 1 is the channel window, and each index vector
        # addresses a single point in the spatial dims of base_grid, which are therefore inserted (size-1) window dims.
        dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(1,),
                                                inserted_window_dims=spatial_dims,
                                                scatter_dims_to_operand_dims=spatial_dims)
        scatter = {'add': jax.lax.scatter_add, 'update': jax.lax.scatter, 'max': jax.lax.scatter_max, 'min': jax.lax.scatter_min}[mode]
        def scatter_single(base_grid, indices, values):
            return scatter(base_grid, indices, values, dnums)
        if self.staticshape(indices)[0] == 1:
            indices = self.tile(indices, [batch_size, 1, 1])
        if self.staticshape(values)[0] == 1:
            values = self.tile(values, [batch_size, 1, 1])
        result = self.vectorized_call(scatter_single, base_grid, indices, values)
        if self.dtype(result).kind != out_kind:
            if out_kind == bool:
                result = self.cast(result, BOOL)
        return result

    def histogram1d(self, values, weights, bin_edges):
        """Computes one weighted histogram per batch entry (dimension 0 of all arguments)."""
        def unbatched_hist(values, weights, bin_edges):
            hist, _ = jnp.histogram(values, bin_edges, weights=weights)
            return hist
        return jax.vmap(unbatched_hist)(values, weights, bin_edges)

    def bincount(self, x, weights: Optional[TensorType], bins: int, x_sorted=False):
        """Counts (or sums `weights` of) occurrences of each non-negative int in `x`, returning exactly `bins` entries."""
        if x_sorted:
            # `weights or 1` would raise for array weights (ambiguous truth value), and segment_sum needs
            # per-element data, so unweighted counts use ones matching `x`.
            data = jnp.ones_like(x) if weights is None else weights
            return jax.ops.segment_sum(data, x, bins, indices_are_sorted=True)
        else:
            return jnp.bincount(x, weights=weights, minlength=bins, length=bins)

    def unique(self, x: TensorType, return_inverse: bool, return_counts: bool, axis: int) -> Tuple[TensorType, ...]:
        return jnp.unique(x, return_inverse=return_inverse, return_counts=return_counts, axis=axis)

    def quantile(self, x, quantiles):
        return jnp.quantile(x, quantiles, axis=-1)

    def argsort(self, x, axis=-1):
        return jnp.argsort(x, axis)

    def sort(self, x, axis=-1):
        return jnp.sort(x, axis)

    def searchsorted(self, sorted_sequence, search_values, side: str, dtype=INT32):
        """Searchsorted along the last dimension; leading dimensions are treated as batch dims via `jax.vmap`."""
        if self.ndims(sorted_sequence) == 1:
            return jnp.searchsorted(sorted_sequence, search_values, side=side).astype(to_numpy_dtype(dtype))
        else:
            return jax.vmap(partial(self.searchsorted, side=side, dtype=dtype))(sorted_sequence, search_values)

    def fft(self, x, axes: Union[tuple, list]):
        """Complex FFT over `axes`, keeping the complex dtype of the (converted) input."""
        x = self.to_complex(x)
        if not axes:
            return x
        if len(axes) == 1:
            return jnp.fft.fft(x, axis=axes[0]).astype(x.dtype)
        elif len(axes) == 2:
            return jnp.fft.fft2(x, axes=axes).astype(x.dtype)
        else:
            return jnp.fft.fftn(x, axes=axes).astype(x.dtype)

    def ifft(self, k, axes: Union[tuple, list]):
        """Inverse FFT over `axes`, cast back to the dtype of `k`."""
        if not axes:
            return k
        if len(axes) == 1:
            return jnp.fft.ifft(k, axis=axes[0]).astype(k.dtype)
        elif len(axes) == 2:
            return jnp.fft.ifft2(k, axes=axes).astype(k.dtype)
        else:
            return jnp.fft.ifftn(k, axes=axes).astype(k.dtype)

    def dtype(self, array) -> DType:
        """Returns the `DType` of `array`. Python bool/int/float/complex map to BOOL/INT32/FLOAT64/COMPLEX128."""
        if isinstance(array, bool):  # checked before int because bool is a subclass of int
            return BOOL
        if isinstance(array, int):
            return INT32
        if isinstance(array, float):
            return FLOAT64
        if isinstance(array, complex):
            return COMPLEX128
        if not isinstance(array, jnp.ndarray):
            array = jnp.array(array)
        return from_numpy_dtype(array.dtype)

    def matrix_solve_least_squares(self, matrix: TensorType, rhs: TensorType) -> Tuple[TensorType, TensorType, TensorType, TensorType]:
        solution, residuals, rank, singular_values = lstsq_batched(matrix, rhs)
        return solution, residuals, rank, singular_values

    def solve_triangular_dense(self, matrix, rhs, lower: bool, unit_diagonal: bool):
        """Solves `matrix @ x = rhs` for triangular `matrix`; int/bool inputs are promoted to float first."""
        matrix, rhs = self.auto_cast(matrix, rhs, int_to_float=True, bool_to_int=True)
        x = jax.lax.linalg.triangular_solve(matrix, rhs, lower=lower, unit_diagonal=unit_diagonal, left_side=True)
        return x

    def matrix_rank_dense(self, matrix, hermitian=False):
        """Matrix rank via `jnp.linalg.matrix_rank`, falling back to NumPy on Jax versions that reject batched input."""
        try:
            return jnp.linalg.matrix_rank(matrix)
        except TypeError as err:
            if err.args[0] == "array should have 2 or fewer dimensions":  # this is a Jax bug on some distributions/versions
                warnings.warn("You are using a broken version of JAX. matrix_rank for dense matrices will fall back to NumPy.")
                return self.as_tensor(NUMPY.matrix_rank_dense(self.numpy(matrix), hermitian=hermitian))
            else:
                raise

    def eigvals(self, matrix: TensorType) -> TensorType:
        return jnp.linalg.eigvals(matrix)

    def eig(self, matrix: TensorType) -> TensorType:
        return jnp.linalg.eig(matrix)

    def svd(self, matrix: TensorType, full_matrices=True) -> Tuple[TensorType, TensorType, TensorType]:
        result = jnp.linalg.svd(matrix, full_matrices=full_matrices)
        return result[0], result[1], result[2]

    def sparse_coo_tensor(self, indices: Union[tuple, list], values, shape: tuple):
        """Creates a BCOO sparse matrix from `indices` and `values`."""
        return BCOO((values, indices), shape=shape)

Backends delegate low-level operations to an ML or numerics library, or emulate them. The methods of Backend form a comprehensive list of the available operations.

To support a library, subclass Backend and register it by adding it to BACKENDS.

Args

name
Human-readable string
default_device
ComputeDevice being used by default

Ancestors

  • phiml.backend._backend.Backend

Class variables

var arccosh
var arcsinh
var arctan
var arctan2
var arctanh
var cosh
var maximum
var minimum
var sinh
var tanh

Static methods

def abs(x: ArrayLike, /) ‑> jax.jaxlib._jax.Array
Expand source code
@export
@jit(inline=True)
def abs(x: ArrayLike, /) -> Array:
  """Element-wise absolute value; an alias of :func:`jax.numpy.absolute`."""
  magnitude = absolute(x)
  return magnitude

Alias of `jax.numpy.absolute`.

def arccos(x: ArrayLike, /) ‑> jax.jaxlib._jax.Array
Expand source code
@export
@jit(inline=True)
def arccos(x: ArrayLike, /) -> Array:
  """Element-wise inverse cosine (arc cosine) of the input.

  JAX implementation of :obj:`numpy.arccos`.

  Args:
    x: array or scalar input.

  Returns:
    An array with the arc cosine of every element of ``x``, in radians and within
    ``[0, pi]``. Non-inexact inputs are promoted to an inexact dtype.

  Note:
    - For real-valued ``x`` outside the closed interval ``[-1, 1]``, ``jnp.arccos``
      yields ``nan``.
    - Complex inputs use the same branch cut convention as :obj:`numpy.arccos`.

  See also:
    - :func:`jax.numpy.cos`: element-wise trigonometric cosine.
    - :func:`jax.numpy.arcsin` and :func:`jax.numpy.asin`: element-wise inverse
      trigonometric sine.
    - :func:`jax.numpy.arctan` and :func:`jax.numpy.atan`: element-wise inverse
      trigonometric tangent.

  Examples:
    >>> x = jnp.array([-2, -1, -0.5, 0, 0.5, 1, 2])
    >>> with jnp.printoptions(precision=3, suppress=True):
    ...   jnp.arccos(x)
    Array([  nan, 3.142, 2.094, 1.571, 1.047, 0.   ,   nan], dtype=float32)

    Complex input:

    >>> with jnp.printoptions(precision=3, suppress=True):
    ...   jnp.arccos(4-1j)
    Array(0.252+2.097j, dtype=complex64, weak_type=True)
  """
  (operand,) = promote_args_inexact('arccos', x)
  result = lax.acos(operand)
  jnp_error._set_error_if_nan(result)
  return result

Compute element-wise inverse of trigonometric cosine of input.

JAX implementation of `numpy.arccos`.

Args

x
input array or scalar.

Returns

An array containing the inverse trigonometric cosine of each element of x in radians in the range [0, pi], promoting to inexact dtype.

Note

  • jnp.arccos returns nan when x is real-valued and not in the closed interval [-1, 1].
  • jnp.arccos follows the branch cut convention of `numpy.arccos` for complex inputs.

See also

  • `jax.numpy.cos`: Computes the trigonometric cosine of each element of the input.
  • `jax.numpy.arcsin` and `jax.numpy.asin`: Compute the inverse trigonometric sine of each element of the input.
  • `jax.numpy.arctan` and `jax.numpy.atan`: Compute the inverse trigonometric tangent of each element of the input.

Examples

>>> x = jnp.array([-2, -1, -0.5, 0, 0.5, 1, 2])
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.arccos(x)
Array([  nan, 3.142, 2.094, 1.571, 1.047, 0.   ,   nan], dtype=float32)

For complex inputs:

>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.arccos(4-1j)
Array(0.252+2.097j, dtype=complex64, weak_type=True)
def arcsin(x: ArrayLike, /) ‑> jax.jaxlib._jax.Array
Expand source code
@export
@jit(inline=True)
def arcsin(x: ArrayLike, /) -> Array:
  r"""Compute the element-wise inverse sine (arcsine) of the input.

  JAX implementation of :obj:`numpy.arcsin`.

  Args:
    x: input array or scalar.

  Returns:
    An array holding the arcsine of every element of ``x``, in radians within
    ``[-pi/2, pi/2]``, promoting to inexact dtype.

  Note:
    - For real-valued ``x`` outside the closed interval ``[-1, 1]`` the result
      is ``nan``.
    - Complex inputs use the same branch-cut convention as :obj:`numpy.arcsin`.

  See also:
    - :func:`jax.numpy.sin`: element-wise trigonometric sine.
    - :func:`jax.numpy.arccos` and :func:`jax.numpy.acos`: element-wise inverse
      trigonometric cosine.
    - :func:`jax.numpy.arctan` and :func:`jax.numpy.atan`: element-wise inverse
      trigonometric tangent.

  Examples:
    >>> x = jnp.array([-2, -1, -0.5, 0, 0.5, 1, 2])
    >>> with jnp.printoptions(precision=3, suppress=True):
    ...   jnp.arcsin(x)
    Array([   nan, -1.571, -0.524,  0.   ,  0.524,  1.571,    nan], dtype=float32)

    For complex-valued inputs:

    >>> with jnp.printoptions(precision=3, suppress=True):
    ...   jnp.arcsin(3+4j)
    Array(0.634+2.306j, dtype=complex64, weak_type=True)
  """
  # Promote ints/bools to an inexact dtype before handing off to lax.
  promoted = promote_args_inexact('arcsin', x)
  result = lax.asin(*promoted)
  # Hand the output to jnp_error so NaNs can be flagged (checking is configured
  # elsewhere).
  jnp_error._set_error_if_nan(result)
  return result

Compute element-wise inverse of trigonometric sine of input.

JAX implementation of :obj:numpy.arcsin.

Args

x
input array or scalar.

Returns

An array containing the inverse trigonometric sine of each element of x in radians in the range [-pi/2, pi/2], promoting to inexact dtype.

Note

  • jnp.arcsin returns nan when x is real-valued and not in the closed interval [-1, 1].
  • jnp.arcsin follows the branch cut convention of :obj:numpy.arcsin for complex inputs.

See also:

  • jax.numpy.sin: Computes a trigonometric sine of each element of input.
  • jax.numpy.arccos and jax.numpy.acos: Computes the inverse of trigonometric cosine of each element of input.
  • jax.numpy.arctan and jax.numpy.atan: Computes the inverse of trigonometric tangent of each element of input.

Examples

>>> x = jnp.array([-2, -1, -0.5, 0, 0.5, 1, 2])
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.arcsin(x)
Array([   nan, -1.571, -0.524,  0.   ,  0.524,  1.571,    nan], dtype=float32)

For complex-valued inputs:

>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.arcsin(3+4j)
Array(0.634+2.306j, dtype=complex64, weak_type=True)
def argmax(a, axis=None, out=None, *, keepdims=<no value>)
Expand source code
@array_function_dispatch(_argmax_dispatcher)
def argmax(a, axis=None, out=None, *, keepdims=np._NoValue):
    """
    Returns the indices of the maximum values along an axis.

    Parameters
    ----------
    a : array_like
        Input array.
    axis : int, optional
        By default, the index is into the flattened array, otherwise
        along the specified axis.
    out : array, optional
        If provided, the result will be inserted into this array. It should
        be of the appropriate shape and dtype.
    keepdims : bool, optional
        If this is set to True, the axes which are reduced are left
        in the result as dimensions with size one. With this option,
        the result will broadcast correctly against the array.

        .. versionadded:: 1.22.0

    Returns
    -------
    index_array : ndarray of ints
        Array of indices into the array. It has the same shape as ``a.shape``
        with the dimension along `axis` removed. If `keepdims` is set to True,
        then the size of `axis` will be 1 with the resulting array having same
        shape as ``a.shape``.

    See Also
    --------
    ndarray.argmax, argmin
    amax : The maximum value along a given axis.
    unravel_index : Convert a flat index into an index tuple.
    take_along_axis : Apply ``np.expand_dims(index_array, axis)``
                      from argmax to an array as if by calling max.

    Notes
    -----
    In case of multiple occurrences of the maximum values, the indices
    corresponding to the first occurrence are returned.

    Examples
    --------
    >>> import numpy as np
    >>> a = np.arange(6).reshape(2,3) + 10
    >>> a
    array([[10, 11, 12],
           [13, 14, 15]])
    >>> np.argmax(a)
    5
    >>> np.argmax(a, axis=0)
    array([1, 1, 1])
    >>> np.argmax(a, axis=1)
    array([2, 2])

    Indices of the maximal elements of an N-dimensional array:

    >>> a.flat[np.argmax(a)]
    15
    >>> ind = np.unravel_index(np.argmax(a, axis=None), a.shape)
    >>> ind
    (1, 2)
    >>> a[ind]
    15

    >>> b = np.arange(6)
    >>> b[1] = 5
    >>> b
    array([0, 5, 2, 3, 4, 5])
    >>> np.argmax(b)  # Only the first occurrence is returned.
    1

    >>> x = np.array([[4,2,3], [1,0,3]])
    >>> index_array = np.argmax(x, axis=-1)
    >>> # Same as np.amax(x, axis=-1, keepdims=True)
    >>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1)
    array([[4],
           [3]])
    >>> # Same as np.amax(x, axis=-1)
    >>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1),
    ...     axis=-1).squeeze(axis=-1)
    array([4, 3])

    Setting `keepdims` to `True`,

    >>> x = np.arange(24).reshape((2, 3, 4))
    >>> res = np.argmax(x, axis=1, keepdims=True)
    >>> res.shape
    (2, 1, 4)
    """
    # Only forward `keepdims` when the caller supplied it, so array-likes whose
    # own `argmax` predates the keyword keep working.
    extra_kwargs = {}
    if keepdims is not np._NoValue:
        extra_kwargs['keepdims'] = keepdims
    return _wrapfunc(a, 'argmax', axis=axis, out=out, **extra_kwargs)

Returns the indices of the maximum values along an axis.

Parameters

a : array_like
Input array.
axis : int, optional
By default, the index is into the flattened array, otherwise along the specified axis.
out : array, optional
If provided, the result will be inserted into this array. It should be of the appropriate shape and dtype.
keepdims : bool, optional

If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the array.

Added in version: 1.22.0

Returns

index_array : ndarray of ints
Array of indices into the array. It has the same shape as a.shape with the dimension along axis removed. If keepdims is set to True, then the size of axis will be 1 with the resulting array having same shape as a.shape.

See Also

ndarray.argmax, argmin
amax : The maximum value along a given axis.
unravel_index : Convert a flat index into an index tuple.
take_along_axis : Apply np.expand_dims(index_array, axis) from argmax to an array as if by calling max.

Notes

In case of multiple occurrences of the maximum values, the indices corresponding to the first occurrence are returned.

Examples

>>> import numpy as np
>>> a = np.arange(6).reshape(2,3) + 10
>>> a
array([[10, 11, 12],
       [13, 14, 15]])
>>> np.argmax(a)
5
>>> np.argmax(a, axis=0)
array([1, 1, 1])
>>> np.argmax(a, axis=1)
array([2, 2])

Indices of the maximal elements of an N-dimensional array:

>>> a.flat[np.argmax(a)]
15
>>> ind = np.unravel_index(np.argmax(a, axis=None), a.shape)
>>> ind
(1, 2)
>>> a[ind]
15
>>> b = np.arange(6)
>>> b[1] = 5
>>> b
array([0, 5, 2, 3, 4, 5])
>>> np.argmax(b)  # Only the first occurrence is returned.
1
>>> x = np.array([[4,2,3], [1,0,3]])
>>> index_array = np.argmax(x, axis=-1)
>>> # Same as np.amax(x, axis=-1, keepdims=True)
>>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1)
array([[4],
       [3]])
>>> # Same as np.amax(x, axis=-1)
>>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1),
...     axis=-1).squeeze(axis=-1)
array([4, 3])

Setting keepdims to True,

>>> x = np.arange(24).reshape((2, 3, 4))
>>> res = np.argmax(x, axis=1, keepdims=True)
>>> res.shape
(2, 1, 4)
def argmin(a, axis=None, out=None, *, keepdims=<no value>)
Expand source code
@array_function_dispatch(_argmin_dispatcher)
def argmin(a, axis=None, out=None, *, keepdims=np._NoValue):
    """
    Returns the indices of the minimum values along an axis.

    Parameters
    ----------
    a : array_like
        Input array.
    axis : int, optional
        By default, the index is into the flattened array, otherwise
        along the specified axis.
    out : array, optional
        If provided, the result will be inserted into this array. It should
        be of the appropriate shape and dtype.
    keepdims : bool, optional
        If this is set to True, the axes which are reduced are left
        in the result as dimensions with size one. With this option,
        the result will broadcast correctly against the array.

        .. versionadded:: 1.22.0

    Returns
    -------
    index_array : ndarray of ints
        Array of indices into the array. It has the same shape as `a.shape`
        with the dimension along `axis` removed. If `keepdims` is set to True,
        then the size of `axis` will be 1 with the resulting array having same
        shape as `a.shape`.

    See Also
    --------
    ndarray.argmin, argmax
    amin : The minimum value along a given axis.
    unravel_index : Convert a flat index into an index tuple.
    take_along_axis : Apply ``np.expand_dims(index_array, axis)``
                      from argmin to an array as if by calling min.

    Notes
    -----
    In case of multiple occurrences of the minimum values, the indices
    corresponding to the first occurrence are returned.

    Examples
    --------
    >>> import numpy as np
    >>> a = np.arange(6).reshape(2,3) + 10
    >>> a
    array([[10, 11, 12],
           [13, 14, 15]])
    >>> np.argmin(a)
    0
    >>> np.argmin(a, axis=0)
    array([0, 0, 0])
    >>> np.argmin(a, axis=1)
    array([0, 0])

    Indices of the minimum elements of an N-dimensional array:

    >>> a.flat[np.argmin(a)]
    10
    >>> ind = np.unravel_index(np.argmin(a, axis=None), a.shape)
    >>> ind
    (0, 0)
    >>> a[ind]
    10

    >>> b = np.arange(6) + 10
    >>> b[4] = 10
    >>> b
    array([10, 11, 12, 13, 10, 15])
    >>> np.argmin(b)  # Only the first occurrence is returned.
    0

    >>> x = np.array([[4,2,3], [1,0,3]])
    >>> index_array = np.argmin(x, axis=-1)
    >>> # Same as np.amin(x, axis=-1, keepdims=True)
    >>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1)
    array([[2],
           [0]])
    >>> # Same as np.amin(x, axis=-1)
    >>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1),
    ...     axis=-1).squeeze(axis=-1)
    array([2, 0])

    Setting `keepdims` to `True`,

    >>> x = np.arange(24).reshape((2, 3, 4))
    >>> res = np.argmin(x, axis=1, keepdims=True)
    >>> res.shape
    (2, 1, 4)
    """
    # Forward `keepdims` only when explicitly given; the _NoValue sentinel
    # distinguishes "not passed" from an explicit False.
    kwds = {'keepdims': keepdims} if keepdims is not np._NoValue else {}
    return _wrapfunc(a, 'argmin', axis=axis, out=out, **kwds)

Returns the indices of the minimum values along an axis.

Parameters

a : array_like
Input array.
axis : int, optional
By default, the index is into the flattened array, otherwise along the specified axis.
out : array, optional
If provided, the result will be inserted into this array. It should be of the appropriate shape and dtype.
keepdims : bool, optional

If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the array.

Added in version: 1.22.0

Returns

index_array : ndarray of ints
Array of indices into the array. It has the same shape as a.shape with the dimension along axis removed. If keepdims is set to True, then the size of axis will be 1 with the resulting array having same shape as a.shape.

See Also

ndarray.argmin, argmax
amin : The minimum value along a given axis.
unravel_index : Convert a flat index into an index tuple.
take_along_axis : Apply np.expand_dims(index_array, axis) from argmin to an array as if by calling min.

Notes

In case of multiple occurrences of the minimum values, the indices corresponding to the first occurrence are returned.

Examples

>>> import numpy as np
>>> a = np.arange(6).reshape(2,3) + 10
>>> a
array([[10, 11, 12],
       [13, 14, 15]])
>>> np.argmin(a)
0
>>> np.argmin(a, axis=0)
array([0, 0, 0])
>>> np.argmin(a, axis=1)
array([0, 0])

Indices of the minimum elements of an N-dimensional array:

>>> a.flat[np.argmin(a)]
10
>>> ind = np.unravel_index(np.argmin(a, axis=None), a.shape)
>>> ind
(0, 0)
>>> a[ind]
10
>>> b = np.arange(6) + 10
>>> b[4] = 10
>>> b
array([10, 11, 12, 13, 10, 15])
>>> np.argmin(b)  # Only the first occurrence is returned.
0
>>> x = np.array([[4,2,3], [1,0,3]])
>>> index_array = np.argmin(x, axis=-1)
>>> # Same as np.amin(x, axis=-1, keepdims=True)
>>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1)
array([[2],
       [0]])
>>> # Same as np.amin(x, axis=-1)
>>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1),
...     axis=-1).squeeze(axis=-1)
array([2, 0])

Setting keepdims to True,

>>> x = np.arange(24).reshape((2, 3, 4))
>>> res = np.argmin(x, axis=1, keepdims=True)
>>> res.shape
(2, 1, 4)
def ceil(x: ArrayLike, /) ‑> jax.jaxlib._jax.Array
Expand source code
@export
@jit(inline=True)
def ceil(x: ArrayLike, /) -> Array:
  """Round each element of the input up to the nearest integer.

  JAX implementation of :obj:`numpy.ceil`.

  Args:
    x: input array or scalar. Must not have complex dtype.

  Returns:
    An array with the same shape and dtype as ``x`` in which every value is
    replaced by the smallest integer greater than or equal to it. Integer and
    boolean inputs are already integral and are returned unchanged.

  See also:
    - :func:`jax.numpy.fix`: Rounds the input to the nearest integer towards zero.
    - :func:`jax.numpy.trunc`: Rounds the input to the nearest integer towards
      zero.
    - :func:`jax.numpy.floor`: Rounds the input down to the nearest integer.

  Examples:
    >>> key = jax.random.key(1)
    >>> x = jax.random.uniform(key, (3, 3), minval=-5, maxval=5)
    >>> with jnp.printoptions(precision=2, suppress=True):
    ...     print(x)
    [[-0.61  0.34 -0.54]
     [-0.62  3.97  0.59]
     [ 4.84  3.42 -1.14]]
    >>> jnp.ceil(x)
    Array([[-0.,  1., -0.],
           [-0.,  4.,  1.],
           [ 5.,  4., -1.]], dtype=float32)
  """
  arr = ensure_arraylike('ceil', x)
  already_integral = dtypes.isdtype(arr.dtype, ('integral', 'bool'))
  if not already_integral:
    return lax.ceil(*promote_args_inexact('ceil', arr))
  # Nothing to round for integer/bool dtypes: return the values as-is.
  return lax.asarray(arr)

Round input to the nearest integer upwards.

JAX implementation of :obj:numpy.ceil.

Args

x
input array or scalar. Must not have complex dtype.

Returns

An array with same shape and dtype as x containing the values rounded to the nearest integer that is greater than or equal to the value itself.

See also:

  • jax.numpy.fix: Rounds the input to the nearest integer towards zero.
  • jax.numpy.trunc: Rounds the input to the nearest integer towards zero.
  • jax.numpy.floor: Rounds the input down to the nearest integer.

Examples

>>> key = jax.random.key(1)
>>> x = jax.random.uniform(key, (3, 3), minval=-5, maxval=5)
>>> with jnp.printoptions(precision=2, suppress=True):
...     print(x)
[[-0.61  0.34 -0.54]
 [-0.62  3.97  0.59]
 [ 4.84  3.42 -1.14]]
>>> jnp.ceil(x)
Array([[-0.,  1., -0.],
       [-0.,  4.,  1.],
       [ 5.,  4., -1.]], dtype=float32)
def clip(arr: ArrayLike | None = None,
/,
min: ArrayLike | None = None,
max: ArrayLike | None = None) ‑> jax.jaxlib._jax.Array
Expand source code
@export
@api.jit
def clip(
  arr: ArrayLike | None = None,
  /,
  min: ArrayLike | None = None,
  max: ArrayLike | None = None,
) -> Array:
  """Limit array values to a given range.

  JAX implementation of :func:`numpy.clip`.

  Args:
    arr: N-dimensional array whose values are clipped.
    min: optional lower bound. If ``None`` (default) no lower bound is applied.
      When given, it must broadcast against ``arr`` and ``max``.
    max: optional upper bound. If ``None`` (default) no upper bound is applied.
      When given, it must broadcast against ``arr`` and ``min``.

  Returns:
    An array with the values of ``arr``, where entries below ``min`` become
    ``min`` and entries above ``max`` become ``max``. Wherever ``min`` exceeds
    ``max``, the result is ``max`` (the upper bound is applied last).

  Raises:
    ValueError: if ``arr`` is ``None``, or if any of ``arr``, ``min`` or ``max``
      is complex-valued.

  See also:
    - :func:`jax.numpy.minimum`: Compute the element-wise minimum value of two arrays.
    - :func:`jax.numpy.maximum`: Compute the element-wise maximum value of two arrays.

  Examples:
    >>> arr = jnp.array([0, 1, 2, 3, 4, 5, 6, 7])
    >>> jnp.clip(arr, 2, 5)
    Array([2, 2, 2, 3, 4, 5, 5, 5], dtype=int32)
  """
  if arr is None:
    raise ValueError("No input was provided to the clip function.")

  util.check_arraylike("clip", arr)
  # Complex numbers have no total order, so clipping them is rejected.
  has_complex = any(iscomplexobj(t) for t in (arr, min, max))
  if has_complex:
    raise ValueError(
      "Clip received a complex value either through the input or the min/max "
      "keywords. Complex values have no ordering and cannot be clipped. "
      "Please convert to a real value or array by taking the real or "
      "imaginary components via jax.numpy.real/imag respectively.")
  clipped = arr
  if min is not None:
    clipped = ufuncs.maximum(min, clipped)
  if max is not None:
    clipped = ufuncs.minimum(max, clipped)
  return asarray(clipped)

Clip array values to a specified range.

JAX implementation of :func:numpy.clip.

Args

arr
N-dimensional array to be clipped.
min
optional minimum value of the clipped range; if None (default) then result will not be clipped to any minimum value. If specified, it should be broadcast-compatible with arr and max.
max
optional maximum value of the clipped range; if None (default) then result will not be clipped to any maximum value. If specified, it should be broadcast-compatible with arr and min.

Returns

An array containing values from arr, with values smaller than min set to min, and values larger than max set to max. Wherever min is larger than max, the value of max is returned.

See also:

  • jax.numpy.minimum: Compute the element-wise minimum value of two arrays.
  • jax.numpy.maximum: Compute the element-wise maximum value of two arrays.

Examples

>>> arr = jnp.array([0, 1, 2, 3, 4, 5, 6, 7])
>>> jnp.clip(arr, 2, 5)
Array([2, 2, 2, 3, 4, 5, 5, 5], dtype=int32)
def concat(arrays: np.ndarray | Array | Sequence[ArrayLike],
axis: int | None = 0,
dtype: DTypeLike | None = None) ‑> jax.jaxlib._jax.Array
Expand source code
@export
def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike],
                axis: int | None = 0, dtype: DTypeLike | None = None) -> Array:
  """Join a sequence of arrays along an existing axis.

  JAX implementation of :func:`numpy.concatenate`.

  Args:
    arrays: the arrays to join; all must share a shape except along ``axis``.
      A single (stacked) array is also accepted and behaves like
      ``arrays = unstack(arrays)``, but the implementation avoids explicitly
      unstacking it.
    axis: the axis to join along. When ``None``, every input is flattened
      first and the flattened arrays are joined along axis 0.
    dtype: optional dtype of the result. When omitted, the dtype is chosen by
      the type promotion rules described in :ref:`type-promotion`.

  Returns:
    the concatenated result.

  Raises:
    ValueError: if ``arrays`` is an empty sequence, or if ``axis`` is not
      ``None`` and the first array is zero-dimensional.

  See also:
    - :func:`jax.lax.concatenate`: XLA concatenation API.
    - :func:`jax.numpy.concat`: Array API version of this function.
    - :func:`jax.numpy.stack`: concatenate arrays along a new axis.

  Examples:
    One-dimensional concatenation:

    >>> x = jnp.arange(3)
    >>> y = jnp.zeros(3, dtype=int)
    >>> jnp.concatenate([x, y])
    Array([0, 1, 2, 0, 0, 0], dtype=int32)

    Two-dimensional concatenation:

    >>> x = jnp.ones((2, 3))
    >>> y = jnp.zeros((2, 1))
    >>> jnp.concatenate([x, y], axis=1)
    Array([[1., 1., 1., 0.],
           [1., 1., 1., 0.]], dtype=float32)
  """
  # A single stacked array takes a dedicated path that avoids unstacking.
  if isinstance(arrays, (np.ndarray, Array)):
    return _concatenate_array(arrays, axis, dtype=dtype)
  arrays = util.ensure_arraylike_tuple("concatenate", arrays)
  if len(arrays) == 0:
    raise ValueError("Need at least one array to concatenate.")
  if axis is None:
    flattened = [ravel(a) for a in arrays]
    return concatenate(flattened, axis=0, dtype=dtype)
  first_ndim = np.ndim(arrays[0])
  if first_ndim == 0:
    raise ValueError("Zero-dimensional arrays cannot be concatenated.")
  axis = _canonicalize_axis(axis, first_ndim)
  if dtype is not None:
    pending = [asarray(arr, dtype=dtype) for arr in arrays]
  else:
    pending = util.promote_dtypes(*arrays)
  # lax.concatenate can be slow to compile for wide concatenations, so reduce
  # the inputs as a tree of fixed-width concatenations instead, which helps
  # especially in op-by-op mode (https://github.com/jax-ml/jax/issues/653).
  fan_in = 16
  while len(pending) > 1:
    pending = [lax.concatenate(pending[start:start + fan_in], axis)
               for start in range(0, len(pending), fan_in)]
  return pending[0]

Join arrays along an existing axis.

JAX implementation of :func:numpy.concatenate.

Args

arrays
a sequence of arrays to concatenate; each must have the same shape except along the specified axis. If a single array is given it will be treated equivalently to arrays = unstack(arrays), but the implementation will avoid explicit unstacking.
axis
specify the axis along which to concatenate. If None, the arrays are flattened before concatenation.
dtype
optional dtype of the resulting array. If not specified, the dtype will be determined via type promotion rules described in :ref:type-promotion.

Returns

the concatenated result.

See also:

  • jax.lax.concatenate: XLA concatenation API.
  • jax.numpy.concat: Array API version of this function.
  • jax.numpy.stack: concatenate arrays along a new axis.

Examples

One-dimensional concatenation:

>>> x = jnp.arange(3)
>>> y = jnp.zeros(3, dtype=int)
>>> jnp.concatenate([x, y])
Array([0, 1, 2, 0, 0, 0], dtype=int32)

Two-dimensional concatenation:

>>> x = jnp.ones((2, 3))
>>> y = jnp.zeros((2, 1))
>>> jnp.concatenate([x, y], axis=1)
Array([[1., 1., 1., 0.],
       [1., 1., 1., 0.]], dtype=float32)
def conj(x: ArrayLike, /) ‑> jax.jaxlib._jax.Array
Expand source code
@export
@jit(inline=True)
def conjugate(x: ArrayLike, /) -> Array:
  """Return element-wise complex-conjugate of the input.

  JAX implementation of :obj:`numpy.conjugate`.

  Args:
    x: input array or scalar.

  Returns:
    An array containing the complex-conjugate of ``x``. Inputs with a
    non-complex dtype are returned unchanged (converted to an array).

  See also:
    - :func:`jax.numpy.real`: Returns the element-wise real part of the complex
      argument.
    - :func:`jax.numpy.imag`: Returns the element-wise imaginary part of the
      complex argument.

  Examples:
    >>> jnp.conjugate(3)
    Array(3, dtype=int32, weak_type=True)
    >>> x = jnp.array([2-1j, 3+5j, 7])
    >>> jnp.conjugate(x)
    Array([2.+1.j, 3.-5.j, 7.-0.j], dtype=complex64)
  """
  x = ensure_arraylike("conjugate", x)
  # Conjugation is the identity for real/integer/bool dtypes, so skip the op.
  return lax.conj(x) if np.iscomplexobj(x) else lax.asarray(x)

Return element-wise complex-conjugate of the input.

JAX implementation of :obj:numpy.conjugate.

Args

x
input array or scalar.

Returns

An array containing the complex-conjugate of x.

See also:

  • jax.numpy.real: Returns the element-wise real part of the complex argument.
  • jax.numpy.imag: Returns the element-wise imaginary part of the complex argument.

Examples

>>> jnp.conjugate(3)
Array(3, dtype=int32, weak_type=True)
>>> x = jnp.array([2-1j, 3+5j, 7])
>>> jnp.conjugate(x)
Array([2.+1.j, 3.-5.j, 7.-0.j], dtype=complex64)
def cos(x: ArrayLike, /) ‑> jax.jaxlib._jax.Array
Expand source code
@export
@jit(inline=True)
def cos(x: ArrayLike, /) -> Array:
  """Compute the trigonometric cosine of every element of the input.

  JAX implementation of :obj:`numpy.cos`.

  Args:
    x: scalar or array of angles, in radians.

  Returns:
    An array holding the cosine of each element of ``x``, promoted to an
    inexact dtype.

  See also:
    - :func:`jax.numpy.sin`: element-wise trigonometric sine.
    - :func:`jax.numpy.tan`: element-wise trigonometric tangent.
    - :func:`jax.numpy.arccos` and :func:`jax.numpy.acos`: element-wise inverse
      trigonometric cosine.

  Examples:
    >>> pi = jnp.pi
    >>> x = jnp.array([pi/4, pi/2, 3*pi/4, 5*pi/6])
    >>> with jnp.printoptions(precision=3, suppress=True):
    ...   print(jnp.cos(x))
    [ 0.707 -0.    -0.707 -0.866]
  """
  # Integer/bool inputs are promoted so lax.cos receives an inexact dtype.
  promoted = promote_args_inexact('cos', x)
  result = lax.cos(*promoted)
  # Hand the output to jnp_error so NaNs can be flagged (checking is configured
  # elsewhere).
  jnp_error._set_error_if_nan(result)
  return result

Compute a trigonometric cosine of each element of input.

JAX implementation of :obj:numpy.cos.

Args

x
scalar or array. Angle in radians.

Returns

An array containing the cosine of each element in x, promotes to inexact dtype.

See also:

  • jax.numpy.sin: Computes a trigonometric sine of each element of input.
  • jax.numpy.tan: Computes a trigonometric tangent of each element of input.
  • jax.numpy.arccos and jax.numpy.acos: Computes the inverse of trigonometric cosine of each element of input.

Examples

>>> pi = jnp.pi
>>> x = jnp.array([pi/4, pi/2, 3*pi/4, 5*pi/6])
>>> with jnp.printoptions(precision=3, suppress=True):
...   print(jnp.cos(x))
[ 0.707 -0.    -0.707 -0.866]
def cumsum(a: ArrayLike,
axis: int | None = None,
dtype: DTypeLike | None = None,
out: None = None) ‑> jax.jaxlib._jax.Array
Expand source code
@export
@api.jit(static_argnames=('axis', 'dtype'))
def cumsum(a: ArrayLike, axis: int | None = None,
           dtype: DTypeLike | None = None, out: None = None) -> Array:
  """Running (cumulative) sum of the elements along an axis.

  JAX implementation of :func:`numpy.cumsum`.

  Args:
    a: N-dimensional array to accumulate.
    axis: integer axis to accumulate along. If ``None`` (default), the array is
      flattened and accumulated along that single flattened axis.
    dtype: optional dtype of the output. If not given, the output dtype matches
      the input dtype.
    out: unused by JAX

  Returns:
    An array containing the running sum along the given axis.

  See also:
    - :func:`jax.numpy.cumulative_sum`: cumulative sum via the array API standard.
    - :meth:`jax.numpy.add.accumulate`: cumulative sum via ufunc methods.
    - :func:`jax.numpy.nancumsum`: cumulative sum ignoring NaN values.
    - :func:`jax.numpy.sum`: sum along axis

  Examples:
    >>> x = jnp.array([[1, 2, 3],
    ...                [4, 5, 6]])
    >>> jnp.cumsum(x)  # flattened cumulative sum
    Array([ 1,  3,  6, 10, 15, 21], dtype=int32)
    >>> jnp.cumsum(x, axis=1)  # cumulative sum along axis 1
    Array([[ 1,  3,  6],
           [ 4,  9, 15]], dtype=int32)
  """
  # The shared helper handles axis/dtype/out; only the scan primitive differs.
  scan_primitive = control_flow.cumsum
  return _cumulative_reduction("cumsum", scan_primitive, a, axis, dtype, out)

Cumulative sum of elements along an axis.

JAX implementation of :func:numpy.cumsum.

Args

a
N-dimensional array to be accumulated.
axis
integer axis along which to accumulate. If None (default), then array will be flattened and accumulated along the flattened axis.
dtype
optionally specify the dtype of the output. If not specified, then the output dtype will match the input dtype.
out
unused by JAX

Returns

An array containing the accumulated sum along the given axis.

See also:

  • jax.numpy.cumulative_sum: cumulative sum via the array API standard.
  • jax.numpy.add.accumulate: cumulative sum via ufunc methods.
  • jax.numpy.nancumsum: cumulative sum ignoring NaN values.
  • jax.numpy.sum: sum along axis

Examples

>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.cumsum(x)  # flattened cumulative sum
Array([ 1,  3,  6, 10, 15, 21], dtype=int32)
>>> jnp.cumsum(x, axis=1)  # cumulative sum along axis 1
Array([[ 1,  3,  6],
       [ 4,  9, 15]], dtype=int32)
def einsum(subscripts,
/,
*operands,
out: None = None,
optimize: str | bool | list[tuple[int, ...]] = 'auto',
precision: str | jax._src.lax.lax.Precision | tuple[str, str] | tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision] | jax._src.lax.lax.DotAlgorithm | jax._src.lax.lax.DotAlgorithmPreset | None = None,
preferred_element_type: str | type[typing.Any] | numpy.dtype | jax._src.typing.SupportsDType | None = None,
out_sharding=None) ‑> jax.jaxlib._jax.Array
Expand source code
@export
def einsum(
    subscripts, /,
    *operands,
    out: None = None,
    optimize: str | bool | list[tuple[int, ...]] = "auto",
    precision: lax.PrecisionLike = None,
    preferred_element_type: DTypeLike | None = None,
    _dot_general: Callable[..., Array] = lax.dot_general,
    out_sharding=None,
) -> Array:
  """Einstein summation

  JAX implementation of :func:`numpy.einsum`.

  ``einsum`` is a powerful and generic API for computing various reductions,
  inner products, outer products, axis reorderings, and combinations thereof
  across one or more input arrays. It has a somewhat complicated overloaded API;
  the arguments below reflect the most common calling convention. The Examples
  section below demonstrates some of the alternative calling conventions.

  Args:
    subscripts: string containing axes names separated by commas.
    *operands: sequence of one or more arrays corresponding to the subscripts.
    optimize: specify how to optimize the order of computation. In JAX this defaults
      to ``"auto"`` which produces optimized expressions via the opt_einsum_
      package. Other options are ``True`` (same as ``"optimal"``), ``False``
      (unoptimized), or any string supported by ``opt_einsum``, which
      includes ``"optimal"``, ``"greedy"``, ``"eager"``, and others. It may also
      be a pre-computed path (see :func:`~jax.numpy.einsum_path`).
    precision: either ``None`` (default), which means the default precision for
      the backend, or a :class:`~jax.lax.Precision` enum value
      (``Precision.DEFAULT``, ``Precision.HIGH`` or ``Precision.HIGHEST``).
    preferred_element_type: either ``None`` (default), which means the default
      accumulation type for the input types, or a datatype, indicating to
      accumulate results to and return a result with that datatype.
    out: unsupported by JAX; must be ``None``.
    _dot_general: optionally override the ``dot_general`` callable used by ``einsum``.
      This parameter is experimental, and may be removed without warning at any time.
    out_sharding: optional sharding for the result. After canonicalization it
      must be ``None`` or a :class:`~jax.sharding.NamedSharding`.

  Returns:
    array containing the result of the Einstein summation.

  Raises:
    NotImplementedError: if ``out`` is not ``None``, or if ``out_sharding``
      is not a ``NamedSharding`` after canonicalization.

  See also:
    :func:`jax.numpy.einsum_path`

  Examples:
    The mechanics of ``einsum`` are perhaps best demonstrated by example. Here we
    show how to use ``einsum`` to compute a number of quantities from one or more
    arrays. For more discussion and examples of ``einsum``, see the documentation
    of :func:`numpy.einsum`.

    >>> M = jnp.arange(16).reshape(4, 4)
    >>> x = jnp.arange(4)
    >>> y = jnp.array([5, 4, 3, 2])

    **Vector product**

    >>> jnp.einsum('i,i', x, y)
    Array(16, dtype=int32)
    >>> jnp.vecdot(x, y)
    Array(16, dtype=int32)

    Here are some alternative ``einsum`` calling conventions to compute the same
    result:

    >>> jnp.einsum('i,i->', x, y)  # explicit form
    Array(16, dtype=int32)
    >>> jnp.einsum(x, (0,), y, (0,))  # implicit form via indices
    Array(16, dtype=int32)
    >>> jnp.einsum(x, (0,), y, (0,), ())  # explicit form via indices
    Array(16, dtype=int32)

    **Matrix product**

    >>> jnp.einsum('ij,j->i', M, x)  # explicit form
    Array([14, 38, 62, 86], dtype=int32)
    >>> jnp.matmul(M, x)
    Array([14, 38, 62, 86], dtype=int32)

    Here are some alternative ``einsum`` calling conventions to compute the same
    result:

    >>> jnp.einsum('ij,j', M, x) # implicit form
    Array([14, 38, 62, 86], dtype=int32)
    >>> jnp.einsum(M, (0, 1), x, (1,), (0,)) # explicit form via indices
    Array([14, 38, 62, 86], dtype=int32)
    >>> jnp.einsum(M, (0, 1), x, (1,))  # implicit form via indices
    Array([14, 38, 62, 86], dtype=int32)

    **Outer product**

    >>> jnp.einsum("i,j->ij", x, y)
    Array([[ 0,  0,  0,  0],
           [ 5,  4,  3,  2],
           [10,  8,  6,  4],
           [15, 12,  9,  6]], dtype=int32)
    >>> jnp.outer(x, y)
    Array([[ 0,  0,  0,  0],
           [ 5,  4,  3,  2],
           [10,  8,  6,  4],
           [15, 12,  9,  6]], dtype=int32)

    Some other ways of computing outer products:

    >>> jnp.einsum("i,j", x, y)  # implicit form
    Array([[ 0,  0,  0,  0],
           [ 5,  4,  3,  2],
           [10,  8,  6,  4],
           [15, 12,  9,  6]], dtype=int32)
    >>> jnp.einsum(x, (0,), y, (1,), (0, 1))  # explicit form via indices
    Array([[ 0,  0,  0,  0],
           [ 5,  4,  3,  2],
           [10,  8,  6,  4],
           [15, 12,  9,  6]], dtype=int32)
    >>> jnp.einsum(x, (0,), y, (1,))  # implicit form via indices
    Array([[ 0,  0,  0,  0],
           [ 5,  4,  3,  2],
           [10,  8,  6,  4],
           [15, 12,  9,  6]], dtype=int32)

    **1D array sum**

    >>> jnp.einsum("i->", x)  # requires explicit form
    Array(6, dtype=int32)
    >>> jnp.einsum(x, (0,), ())  # explicit form via indices
    Array(6, dtype=int32)
    >>> jnp.sum(x)
    Array(6, dtype=int32)

    **Sum along an axis**

    >>> jnp.einsum("...j->...", M)  # requires explicit form
    Array([ 6, 22, 38, 54], dtype=int32)
    >>> jnp.einsum(M, (..., 0), (...,))  # explicit form via indices
    Array([ 6, 22, 38, 54], dtype=int32)
    >>> M.sum(-1)
    Array([ 6, 22, 38, 54], dtype=int32)

    **Matrix transpose**

    >>> y = jnp.array([[1, 2, 3],
    ...                [4, 5, 6]])
    >>> jnp.einsum("ij->ji", y)  # explicit form
    Array([[1, 4],
           [2, 5],
           [3, 6]], dtype=int32)
    >>> jnp.einsum("ji", y)  # implicit form
    Array([[1, 4],
           [2, 5],
           [3, 6]], dtype=int32)
    >>> jnp.einsum(y, (1, 0))  # implicit form via indices
    Array([[1, 4],
           [2, 5],
           [3, 6]], dtype=int32)
    >>> jnp.einsum(y, (0, 1), (1, 0))  # explicit form via indices
    Array([[1, 4],
           [2, 5],
           [3, 6]], dtype=int32)
    >>> jnp.transpose(y)
    Array([[1, 4],
           [2, 5],
           [3, 6]], dtype=int32)

    **Matrix diagonal**

    >>> jnp.einsum("ii->i", M)
    Array([ 0,  5, 10, 15], dtype=int32)
    >>> jnp.diagonal(M)
    Array([ 0,  5, 10, 15], dtype=int32)

    **Matrix trace**

    >>> jnp.einsum("ii", M)
    Array(30, dtype=int32)
    >>> jnp.trace(M)
    Array(30, dtype=int32)

    **Tensor products**

    >>> x = jnp.arange(30).reshape(2, 3, 5)
    >>> y = jnp.arange(60).reshape(3, 4, 5)
    >>> jnp.einsum('ijk,jlk->il', x, y)  # explicit form
    Array([[ 3340,  3865,  4390,  4915],
           [ 8290,  9940, 11590, 13240]], dtype=int32)
    >>> jnp.tensordot(x, y, axes=[(1, 2), (0, 2)])
    Array([[ 3340,  3865,  4390,  4915],
           [ 8290,  9940, 11590, 13240]], dtype=int32)
    >>> jnp.einsum('ijk,jlk', x, y)  # implicit form
    Array([[ 3340,  3865,  4390,  4915],
           [ 8290,  9940, 11590, 13240]], dtype=int32)
    >>> jnp.einsum(x, (0, 1, 2), y, (1, 3, 2), (0, 3))  # explicit form via indices
    Array([[ 3340,  3865,  4390,  4915],
           [ 8290,  9940, 11590, 13240]], dtype=int32)
    >>> jnp.einsum(x, (0, 1, 2), y, (1, 3, 2))  # implicit form via indices
    Array([[ 3340,  3865,  4390,  4915],
           [ 8290,  9940, 11590, 13240]], dtype=int32)

    **Chained dot products**

    >>> w = jnp.arange(5, 9).reshape(2, 2)
    >>> x = jnp.arange(6).reshape(2, 3)
    >>> y = jnp.arange(-2, 4).reshape(3, 2)
    >>> z = jnp.array([[2, 4, 6], [3, 5, 7]])
    >>> jnp.einsum('ij,jk,kl,lm->im', w, x, y, z)
    Array([[ 481,  831, 1181],
           [ 651, 1125, 1599]], dtype=int32)
    >>> jnp.einsum(w, (0, 1), x, (1, 2), y, (2, 3), z, (3, 4))  # implicit, via indices
    Array([[ 481,  831, 1181],
           [ 651, 1125, 1599]], dtype=int32)
    >>> w @ x @ y @ z  # direct chain of matmuls
    Array([[ 481,  831, 1181],
           [ 651, 1125, 1599]], dtype=int32)
    >>> jnp.linalg.multi_dot([w, x, y, z])
    Array([[ 481,  831, 1181],
           [ 651, 1125, 1599]], dtype=int32)

  .. _opt_einsum: https://github.com/dgasmith/opt_einsum
  """
  operands = (subscripts, *operands)
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.einsum is not supported.")
  # In the string calling convention the first positional argument is the spec;
  # in the index-based convention it is the first operand array instead.
  spec = operands[0] if isinstance(operands[0], str) else None
  # Map the boolean shorthands: True -> "optimal", False -> unoptimized path.
  path_type = 'optimal' if optimize is True else Unoptimized() if optimize is False else optimize

  # Extract __jax_array__ before passing to contract_path()
  operands = tuple(op.__jax_array__() if hasattr(op, "__jax_array__") else op
                   for op in operands)

  # Allow handling of shape polymorphism
  non_constant_dim_types = {
      type(d) for op in operands if not isinstance(op, str)
      for d in np.shape(op) if not core.is_constant_dim(d)
  }
  # All dimensions constant: use opt_einsum directly. Otherwise dispatch on the
  # type of the first non-constant dimension found, with a default fallback.
  if not non_constant_dim_types:
    contract_path: Any = opt_einsum.contract_path
  else:
    ty = next(iter(non_constant_dim_types))
    contract_path = _poly_einsum_handlers.get(ty, _default_poly_einsum_handler)
  # using einsum_call=True here is an internal api for opt_einsum... sorry
  operands, contractions = contract_path(
        *operands, einsum_call=True, use_blas=True, optimize=path_type)

  # Keep only the first three fields of each contraction and freeze the index
  # set, so `contractions` is hashable (it is passed as a static jit arg below).
  contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions)
  num_contractions = len(contractions)

  out_sharding = canonicalize_sharding(out_sharding, 'einsum')
  if out_sharding is not None and not isinstance(out_sharding, NamedSharding):
    raise NotImplementedError(
        "`out_sharding` argument of `einsum` only supports NamedSharding"
        " instances.")

  # Positional args 1-5 (contractions, precision, preferred_element_type,
  # _dot_general, out_sharding) are static; only the operand list is traced.
  jit_einsum = api.jit(_einsum, static_argnums=(1, 2, 3, 4, 5), inline=True)
  if spec is not None:
    # Label the wrapped call with the subscripts string.
    jit_einsum = api.named_call(jit_einsum, name=spec)
  operand_arrays = list(util.ensure_arraylike_tuple("einsum", operands))

  if num_contractions > 1 and out_sharding is not None:
    # TODO(yashkatariya): If the out_sharding is unreduced, figure out a way to
    # run the dot_general unreduced_rule on these einsums because right now we
    # drop into Auto mode skipping the checks happening in the rule.
    return auto_axes(
        jit_einsum,
        axes=out_sharding.mesh.explicit_axes,
        out_sharding=out_sharding,
    )(operand_arrays, contractions=contractions, precision=precision,
      preferred_element_type=preferred_element_type, _dot_general=_dot_general,
      out_sharding=None)
  else:
    return jit_einsum(operand_arrays, contractions, precision,
                      preferred_element_type, _dot_general, out_sharding)

Einstein summation

JAX implementation of :func:`numpy.einsum`.

einsum is a powerful and generic API for computing various reductions, inner products, outer products, axis reorderings, and combinations thereof across one or more input arrays. It has a somewhat complicated overloaded API; the arguments below reflect the most common calling convention. The Examples section below demonstrates some of the alternative calling conventions.

Args

subscripts
string containing axes names separated by commas.
*operands
sequence of one or more arrays corresponding to the subscripts.
optimize
specify how to optimize the order of computation. In JAX this defaults to "auto", which produces optimized expressions via the opt_einsum package (https://github.com/dgasmith/opt_einsum). Other options are True (same as "optimal"), False (unoptimized), or any string supported by opt_einsum, which includes "optimal", "greedy", "eager", and others. It may also be a pre-computed path (see :func:`~jax.numpy.einsum_path`).
precision
either None (default), which means the default precision for the backend, or a :class:`~jax.lax.Precision` enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST).
preferred_element_type
either None (default), which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype.
out
unsupported by JAX
_dot_general
optionally override the dot_general callable used by einsum. This parameter is experimental, and may be removed without warning at any time.

Returns

Array containing the result of the Einstein summation.

See also: :func:`jax.numpy.einsum_path`

Examples

The mechanics of einsum are perhaps best demonstrated by example. Here we show how to use einsum to compute a number of quantities from one or more arrays. For more discussion and examples of einsum, see the documentation of :func:numpy.einsum.

>>> M = jnp.arange(16).reshape(4, 4)
>>> x = jnp.arange(4)
>>> y = jnp.array([5, 4, 3, 2])

Vector product

>>> jnp.einsum('i,i', x, y)
Array(16, dtype=int32)
>>> jnp.vecdot(x, y)
Array(16, dtype=int32)

Here are some alternative einsum calling conventions to compute the same result:

>>> jnp.einsum('i,i->', x, y)  # explicit form
Array(16, dtype=int32)
>>> jnp.einsum(x, (0,), y, (0,))  # implicit form via indices
Array(16, dtype=int32)
>>> jnp.einsum(x, (0,), y, (0,), ())  # explicit form via indices
Array(16, dtype=int32)

Matrix product

>>> jnp.einsum('ij,j->i', M, x)  # explicit form
Array([14, 38, 62, 86], dtype=int32)
>>> jnp.matmul(M, x)
Array([14, 38, 62, 86], dtype=int32)

Here are some alternative einsum calling conventions to compute the same result:

>>> jnp.einsum('ij,j', M, x) # implicit form
Array([14, 38, 62, 86], dtype=int32)
>>> jnp.einsum(M, (0, 1), x, (1,), (0,)) # explicit form via indices
Array([14, 38, 62, 86], dtype=int32)
>>> jnp.einsum(M, (0, 1), x, (1,))  # implicit form via indices
Array([14, 38, 62, 86], dtype=int32)

Outer product

>>> jnp.einsum("i,j->ij", x, y)
Array([[ 0,  0,  0,  0],
       [ 5,  4,  3,  2],
       [10,  8,  6,  4],
       [15, 12,  9,  6]], dtype=int32)
>>> jnp.outer(x, y)
Array([[ 0,  0,  0,  0],
       [ 5,  4,  3,  2],
       [10,  8,  6,  4],
       [15, 12,  9,  6]], dtype=int32)

Some other ways of computing outer products:

>>> jnp.einsum("i,j", x, y)  # implicit form
Array([[ 0,  0,  0,  0],
       [ 5,  4,  3,  2],
       [10,  8,  6,  4],
       [15, 12,  9,  6]], dtype=int32)
>>> jnp.einsum(x, (0,), y, (1,), (0, 1))  # explicit form via indices
Array([[ 0,  0,  0,  0],
       [ 5,  4,  3,  2],
       [10,  8,  6,  4],
       [15, 12,  9,  6]], dtype=int32)
>>> jnp.einsum(x, (0,), y, (1,))  # implicit form via indices
Array([[ 0,  0,  0,  0],
       [ 5,  4,  3,  2],
       [10,  8,  6,  4],
       [15, 12,  9,  6]], dtype=int32)

1D array sum

>>> jnp.einsum("i->", x)  # requires explicit form
Array(6, dtype=int32)
>>> jnp.einsum(x, (0,), ())  # explicit form via indices
Array(6, dtype=int32)
>>> jnp.sum(x)
Array(6, dtype=int32)

Sum along an axis

>>> jnp.einsum("...j->...", M)  # requires explicit form
Array([ 6, 22, 38, 54], dtype=int32)
>>> jnp.einsum(M, (..., 0), (...,))  # explicit form via indices
Array([ 6, 22, 38, 54], dtype=int32)
>>> M.sum(-1)
Array([ 6, 22, 38, 54], dtype=int32)

Matrix transpose

>>> y = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.einsum("ij->ji", y)  # explicit form
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
>>> jnp.einsum("ji", y)  # implicit form
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
>>> jnp.einsum(y, (1, 0))  # implicit form via indices
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
>>> jnp.einsum(y, (0, 1), (1, 0))  # explicit form via indices
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
>>> jnp.transpose(y)
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)

Matrix diagonal

>>> jnp.einsum("ii->i", M)
Array([ 0,  5, 10, 15], dtype=int32)
>>> jnp.diagonal(M)
Array([ 0,  5, 10, 15], dtype=int32)

Matrix trace

>>> jnp.einsum("ii", M)
Array(30, dtype=int32)
>>> jnp.trace(M)
Array(30, dtype=int32)

Tensor products

>>> x = jnp.arange(30).reshape(2, 3, 5)
>>> y = jnp.arange(60).reshape(3, 4, 5)
>>> jnp.einsum('ijk,jlk->il', x, y)  # explicit form
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=int32)
>>> jnp.tensordot(x, y, axes=[(1, 2), (0, 2)])
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=int32)
>>> jnp.einsum('ijk,jlk', x, y)  # implicit form
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=int32)
>>> jnp.einsum(x, (0, 1, 2), y, (1, 3, 2), (0, 3))  # explicit form via indices
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=int32)
>>> jnp.einsum(x, (0, 1, 2), y, (1, 3, 2))  # implicit form via indices
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=int32)

Chained dot products

>>> w = jnp.arange(5, 9).reshape(2, 2)
>>> x = jnp.arange(6).reshape(2, 3)
>>> y = jnp.arange(-2, 4).reshape(3, 2)
>>> z = jnp.array([[2, 4, 6], [3, 5, 7]])
>>> jnp.einsum('ij,jk,kl,lm->im', w, x, y, z)
Array([[ 481,  831, 1181],
       [ 651, 1125, 1599]], dtype=int32)
>>> jnp.einsum(w, (0, 1), x, (1, 2), y, (2, 3), z, (3, 4))  # implicit, via indices
Array([[ 481,  831, 1181],
       [ 651, 1125, 1599]], dtype=int32)
>>> w @ x @ y @ z  # direct chain of matmuls
Array([[ 481,  831, 1181],
       [ 651, 1125, 1599]], dtype=int32)
>>> jnp.linalg.multi_dot([w, x, y, z])
Array([[ 481,  831, 1181],
       [ 651, 1125, 1599]], dtype=int32)

opt_einsum: https://github.com/dgasmith/opt_einsum

def equal(x: ArrayLike, y: ArrayLike, /) ‑> jax.jaxlib._jax.Array
Expand source code
@export
@jit(inline=True)
def equal(x: ArrayLike, y: ArrayLike, /) -> Array:
  """Element-wise equality test ``x == y``.

  JAX implementation of :obj:`numpy.equal`. The ``==`` operator on JAX arrays
  is implemented by this function.

  Args:
    x: first operand; an array or a scalar.
    y: second operand; an array or a scalar. ``x`` and ``y`` must either share
      a shape or be broadcast-compatible with each other.

  Returns:
    A boolean array holding ``True`` at every position where ``x == y`` and
    ``False`` everywhere else.

  See also:
    - :func:`jax.numpy.not_equal`: Returns element-wise truth value of ``x != y``.
    - :func:`jax.numpy.greater_equal`: Returns element-wise truth value of
      ``x >= y``.
    - :func:`jax.numpy.less_equal`: Returns element-wise truth value of ``x <= y``.
    - :func:`jax.numpy.greater`: Returns element-wise truth value of ``x > y``.
    - :func:`jax.numpy.less`: Returns element-wise truth value of ``x < y``.

  Examples:
    >>> jnp.equal(0., -0.)
    Array(True, dtype=bool, weak_type=True)
    >>> jnp.equal(1, 1.)
    Array(True, dtype=bool, weak_type=True)
    >>> jnp.equal(5, jnp.array(5))
    Array(True, dtype=bool, weak_type=True)
    >>> jnp.equal(2, -2)
    Array(False, dtype=bool, weak_type=True)
    >>> x = jnp.array([[1, 2, 3],
    ...                [4, 5, 6],
    ...                [7, 8, 9]])
    >>> y = jnp.array([1, 5, 9])
    >>> jnp.equal(x, y)
    Array([[ True, False, False],
           [False,  True, False],
           [False, False,  True]], dtype=bool)
    >>> x == y
    Array([[ True, False, False],
           [False,  True, False],
           [False, False,  True]], dtype=bool)
  """
  # Run both operands through promote_args first, then compare the results.
  lhs, rhs = promote_args("equal", x, y)
  return lax.eq(lhs, rhs)

Returns element-wise truth value of x == y.

JAX implementation of :obj:`numpy.equal`. This function provides the implementation of the == operator for JAX arrays.

Args

x
input array or scalar.
y
input array or scalar. x and y should either have same shape or be broadcast compatible.

Returns

A boolean array containing True where x == y holds element-wise, and False otherwise.

See also:

- :func:`jax.numpy.not_equal`: Returns element-wise truth value of x != y.
- :func:`jax.numpy.greater_equal`: Returns element-wise truth value of x >= y.
- :func:`jax.numpy.less_equal`: Returns element-wise truth value of x <= y.
- :func:`jax.numpy.greater`: Returns element-wise truth value of x > y.
- :func:`jax.numpy.less`: Returns element-wise truth value of x < y.

Examples

>>> jnp.equal(0., -0.)
Array(True, dtype=bool, weak_type=True)
>>> jnp.equal(1, 1.)
Array(True, dtype=bool, weak_type=True)
>>> jnp.equal(5, jnp.array(5))
Array(True, dtype=bool, weak_type=True)
>>> jnp.equal(2, -2)
Array(False, dtype=bool, weak_type=True)
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6],
...                [7, 8, 9]])
>>> y = jnp.array([1, 5, 9])
>>> jnp.equal(x, y)
Array([[ True, False, False],
       [False,  True, False],
       [False, False,  True]], dtype=bool)
>>> x == y
Array([[ True, False, False],
       [False,  True, False],
       [False, False,  True]], dtype=bool)
def erf(x: ArrayLike) ‑> jax.jaxlib._jax.Array
Expand source code
def erf(x: ArrayLike) -> Array:
  r"""Error function.

  JAX implementation of :obj:`scipy.special.erf`, defined as

  .. math::

     \mathrm{erf}(x) = \frac{2}{\sqrt\pi} \int_{0}^x e^{-t^2} \mathrm{d}t

  Args:
    x: real-valued array-like input.

  Returns:
    Array of error-function values, one per element of ``x``.

  Notes:
     Only real-valued inputs are supported by the JAX version.

  See also:
    - :func:`jax.scipy.special.erfc`
    - :func:`jax.scipy.special.erfinv`
  """
  # promote_args_inexact returns a one-element sequence for a single argument.
  (promoted,) = promote_args_inexact("erf", x)
  return lax.erf(promoted)

The error function

JAX implementation of :obj:`scipy.special.erf`.

\[ \mathrm{erf}(x) = \frac{2}{\sqrt{\pi}} \int_{0}^{x} e^{-t^2} \, \mathrm{d}t \]

Args

x
arraylike, real-valued.

Returns

array containing values of the error function.

Notes

The JAX version only supports real-valued inputs.

See also:

- :func:`jax.scipy.special.erfc`
- :func:`jax.scipy.special.erfinv`

def exp(x: ArrayLike, /) ‑> jax.jaxlib._jax.Array
Expand source code
@export
@jit(inline=True)
def exp(x: ArrayLike, /) -> Array:
  """Element-wise exponential :math:`e^x`.

  JAX implementation of :obj:`numpy.exp`.

  Args:
    x: an array or a scalar.

  Returns:
    An array holding :math:`e^x` for every element of ``x``; the input is
    promoted to an inexact dtype first.

  See also:
    - :func:`jax.numpy.log`: Calculates element-wise logarithm of the input.
    - :func:`jax.numpy.expm1`: Calculates :math:`e^x-1` of each element of the
      input.
    - :func:`jax.numpy.exp2`: Calculates base-2 exponential of each element of
      the input.

  Examples:
    ``jnp.exp`` follows the properties of exponential such as :math:`e^{(a+b)}
    = e^a * e^b`.

    >>> x1 = jnp.array([2, 4, 3, 1])
    >>> x2 = jnp.array([1, 3, 2, 3])
    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jnp.exp(x1+x2))
    [  20.09 1096.63  148.41   54.6 ]
    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jnp.exp(x1)*jnp.exp(x2))
    [  20.09 1096.63  148.41   54.6 ]

    This property holds for complex input also:

    >>> jnp.allclose(jnp.exp(3-4j), jnp.exp(3)*jnp.exp(-4j))
    Array(True, dtype=bool)
  """
  (x_inexact,) = promote_args_inexact('exp', x)
  return lax.exp(x_inexact)

Calculate element-wise exponential of the input.

JAX implementation of :obj:`numpy.exp`.

Args

x
input array or scalar

Returns

An array containing the exponential of each element in x; the input is promoted to an inexact dtype.

See also:

- :func:`jax.numpy.log`: Calculates element-wise logarithm of the input.
- :func:`jax.numpy.expm1`: Calculates \(e^x - 1\) of each element of the input.
- :func:`jax.numpy.exp2`: Calculates base-2 exponential of each element of the input.

Examples

jnp.exp follows the properties of the exponential function, such as \(e^{a+b} = e^{a} \cdot e^{b}\).

>>> x1 = jnp.array([2, 4, 3, 1])
>>> x2 = jnp.array([1, 3, 2, 3])
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.exp(x1+x2))
[  20.09 1096.63  148.41   54.6 ]
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.exp(x1)*jnp.exp(x2))
[  20.09 1096.63  148.41   54.6 ]

This property holds for complex input also:

>>> jnp.allclose(jnp.exp(3-4j), jnp.exp(3)*jnp.exp(-4j))
Array(True, dtype=bool)
def flip(m: ArrayLike, axis: int | Sequence[int] | None = None) ‑> jax.jaxlib._jax.Array
Expand source code
@export
def flip(m: ArrayLike, axis: int | Sequence[int] | None = None) -> Array:
  """Reverse an array's elements along one or more axes.

  JAX implementation of :func:`numpy.flip`.

  Args:
    m: Array.
    axis: an integer or a sequence of integers naming the axis or axes along
      which the elements are reversed. ``None`` (the default) reverses along
      every axis.

  Returns:
    A copy of ``m`` whose elements appear in reverse order along ``axis``.

  See Also:
    - :func:`jax.numpy.fliplr`: reverse the order along axis 1 (left/right)
    - :func:`jax.numpy.flipud`: reverse the order along axis 0 (up/down)

  Examples:
    >>> x1 = jnp.array([[1, 2],
    ...                 [3, 4]])
    >>> jnp.flip(x1)
    Array([[4, 3],
           [2, 1]], dtype=int32)

    If ``axis`` is specified with an integer, then ``jax.numpy.flip`` reverses
    the array along that particular axis only.

    >>> jnp.flip(x1, axis=1)
    Array([[2, 1],
           [4, 3]], dtype=int32)

    >>> x2 = jnp.arange(1, 9).reshape(2, 2, 2)
    >>> x2
    Array([[[1, 2],
            [3, 4]],
    <BLANKLINE>
           [[5, 6],
            [7, 8]]], dtype=int32)
    >>> jnp.flip(x2)
    Array([[[8, 7],
            [6, 5]],
    <BLANKLINE>
           [[4, 3],
            [2, 1]]], dtype=int32)

    When ``axis`` is specified with a sequence of integers, then
    ``jax.numpy.flip`` reverses the array along the specified axes.

    >>> jnp.flip(x2, axis=[1, 2])
    Array([[[4, 3],
            [2, 1]],
    <BLANKLINE>
           [[8, 7],
            [6, 5]]], dtype=int32)
  """
  # Validate the input first, then normalize the axis spec, then reverse.
  values = util.ensure_arraylike("flip", m)
  axes = reductions._ensure_optional_axes(axis)
  return _flip(values, axes)

Reverse the order of elements of an array along the given axis.

JAX implementation of :func:`numpy.flip`.

Args

m
Array.
axis
integer or sequence of integers. Specifies along which axis or axes should the array elements be reversed. Default is None, which flips along all axes.

Returns

An array with the elements in reverse order along axis.

See Also:

- :func:`jax.numpy.fliplr`: reverse the order along axis 1 (left/right)
- :func:`jax.numpy.flipud`: reverse the order along axis 0 (up/down)

Examples

>>> x1 = jnp.array([[1, 2],
...                 [3, 4]])
>>> jnp.flip(x1)
Array([[4, 3],
       [2, 1]], dtype=int32)

If axis is specified with an integer, then jax.numpy.flip reverses the array along that particular axis only.

>>> jnp.flip(x1, axis=1)
Array([[2, 1],
       [4, 3]], dtype=int32)
>>> x2 = jnp.arange(1, 9).reshape(2, 2, 2)
>>> x2
Array([[[1, 2],
        [3, 4]],
<BLANKLINE>
       [[5, 6],
        [7, 8]]], dtype=int32)
>>> jnp.flip(x2)
Array([[[8, 7],
        [6, 5]],
<BLANKLINE>
       [[4, 3],
        [2, 1]]], dtype=int32)

When axis is specified with a sequence of integers, then jax.numpy.flip reverses the array along the specified axes.

>>> jnp.flip(x2, axis=[1, 2])
Array([[[4, 3],
        [2, 1]],
<BLANKLINE>
       [[8, 7],
        [6, 5]]], dtype=int32)
def floor(x: ArrayLike, /) ‑> jax.jaxlib._jax.Array
Expand source code
@export
@jit(inline=True)
def floor(x: ArrayLike, /) -> Array:
  """Round each element down to the nearest integer.

  JAX implementation of :obj:`numpy.floor`.

  Args:
    x: input array or scalar; complex dtypes are not allowed.

  Returns:
    An array with the same shape and dtype as ``x`` whose entries are the
    largest integers less than or equal to the corresponding entries of ``x``.
    Integer and boolean inputs are returned unchanged.

  See also:
    - :func:`jax.numpy.fix`: Rounds the input to the nearest integer towards zero.
    - :func:`jax.numpy.trunc`: Rounds the input to the nearest integer towards
      zero.
    - :func:`jax.numpy.ceil`: Rounds the input up to the nearest integer.

  Examples:
    >>> key = jax.random.key(42)
    >>> x = jax.random.uniform(key, (3, 3), minval=-5, maxval=5)
    >>> with jnp.printoptions(precision=2, suppress=True):
    ...     print(x)
    [[-0.11  1.8   1.16]
     [ 0.61 -0.49  0.86]
     [-4.25  2.75  1.99]]
    >>> jnp.floor(x)
    Array([[-1.,  1.,  1.],
           [ 0., -1.,  0.],
           [-5.,  2.,  1.]], dtype=float32)
  """
  arr = ensure_arraylike('floor', x)
  if not dtypes.isdtype(arr.dtype, ('integral', 'bool')):
    (arr,) = promote_args_inexact('floor', arr)
    return lax.floor(arr)
  # Integral and boolean values are already whole numbers: nothing to round.
  return arr

Round input to the nearest integer downwards.

JAX implementation of :obj:`numpy.floor`.

Args

x
input array or scalar. Must not have complex dtype.

Returns

An array with the same shape and dtype as x, containing the values rounded to the nearest integer that is less than or equal to the value itself.

See also:

- :func:`jax.numpy.fix`: Rounds the input to the nearest integer towards zero.
- :func:`jax.numpy.trunc`: Rounds the input to the nearest integer towards zero.
- :func:`jax.numpy.ceil`: Rounds the input up to the nearest integer.

Examples

>>> key = jax.random.key(42)
>>> x = jax.random.uniform(key, (3, 3), minval=-5, maxval=5)
>>> with jnp.printoptions(precision=2, suppress=True):
...     print(x)
[[-0.11  1.8   1.16]
 [ 0.61 -0.49  0.86]
 [-4.25  2.75  1.99]]
>>> jnp.floor(x)
Array([[-1.,  1.,  1.],
       [ 0., -1.,  0.],
       [-5.,  2.,  1.]], dtype=float32)
def imag(val: ArrayLike, /) ‑> jax.jaxlib._jax.Array
Expand source code
@export
@jit(inline=True)
def imag(val: ArrayLike, /) -> Array:
  """Return the element-wise imaginary part of the complex argument.

  JAX implementation of :obj:`numpy.imag`.

  Args:
    val: an array or a scalar.

  Returns:
    An array holding the imaginary part of each element of ``val``. For
    non-complex input this is an array of zeros shaped like ``val``.

  See also:
    - :func:`jax.numpy.conjugate` and :func:`jax.numpy.conj`: Returns the element-wise
      complex-conjugate of the input.
    - :func:`jax.numpy.real`: Returns the element-wise real part of the complex
      argument.

  Examples:
    >>> jnp.imag(4)
    Array(0, dtype=int32, weak_type=True)
    >>> jnp.imag(5j)
    Array(5., dtype=float32, weak_type=True)
    >>> x = jnp.array([2+3j, 5-1j, -3])
    >>> jnp.imag(x)
    Array([ 3., -1.,  0.], dtype=float32)
  """
  arr = ensure_arraylike("imag", val)
  if not np.iscomplexobj(arr):
    # Non-complex input has no imaginary component: fill with zeros.
    return lax.full_like(arr, 0)
  return lax.imag(arr)

Return the element-wise imaginary part of the complex argument.

JAX implementation of :obj:`numpy.imag`.

Args

val
input array or scalar.

Returns

An array containing the imaginary part of the elements of val.

See also:

- :func:`jax.numpy.conjugate` and :func:`jax.numpy.conj`: Returns the element-wise complex conjugate of the input.
- :func:`jax.numpy.real`: Returns the element-wise real part of the complex argument.

Examples

>>> jnp.imag(4)
Array(0, dtype=int32, weak_type=True)
>>> jnp.imag(5j)
Array(5., dtype=float32, weak_type=True)
>>> x = jnp.array([2+3j, 5-1j, -3])
>>> jnp.imag(x)
Array([ 3., -1.,  0.], dtype=float32)
def isfinite(x: ArrayLike, /) ‑> jax.jaxlib._jax.Array
Expand source code
@export
@jit(inline=True)
def isfinite(x: ArrayLike, /) -> Array:
  """Test each element of the input for finiteness.

  JAX implementation of :obj:`numpy.isfinite`.

  Args:
    x: an array or a scalar.

  Returns:
    A boolean array shaped like ``x`` that is ``True`` wherever the element is
    neither ``inf``, ``-inf`` nor ``NaN`` and ``False`` elsewhere.

  See also:
    - :func:`jax.numpy.isinf`: Returns a boolean array indicating whether each
      element of input is either positive or negative infinity.
    - :func:`jax.numpy.isposinf`: Returns a boolean array indicating whether each
      element of input is positive infinity.
    - :func:`jax.numpy.isneginf`: Returns a boolean array indicating whether each
      element of input is negative infinity.
    - :func:`jax.numpy.isnan`: Returns a boolean array indicating whether each
      element of input is not a number (``NaN``).

  Examples:
    >>> x = jnp.array([-1, 3, jnp.inf, jnp.nan])
    >>> jnp.isfinite(x)
    Array([ True,  True, False, False], dtype=bool)
    >>> jnp.isfinite(3-4j)
    Array(True, dtype=bool, weak_type=True)
  """
  arr = ensure_arraylike("isfinite", x)
  if dtypes.issubdtype(arr.dtype, np.floating):
    return lax.is_finite(arr)
  if dtypes.issubdtype(arr.dtype, np.complexfloating):
    # A complex element is finite only when both of its components are.
    real_ok = lax.is_finite(real(arr))
    imag_ok = lax.is_finite(imag(arr))
    return lax.bitwise_and(real_ok, imag_ok)
  # Any other dtype (e.g. integer or bool) is reported as finite everywhere.
  return lax.full_like(arr, True, dtype=np.bool_)

Return a boolean array indicating whether each element of input is finite.

JAX implementation of :obj:`numpy.isfinite`.

Args

x
input array or scalar.

Returns

A boolean array of the same shape as x, containing True where x is not inf, -inf, or NaN, and False otherwise.

See also:

- :func:`jax.numpy.isinf`: Returns a boolean array indicating whether each element of input is either positive or negative infinity.
- :func:`jax.numpy.isposinf`: Returns a boolean array indicating whether each element of input is positive infinity.
- :func:`jax.numpy.isneginf`: Returns a boolean array indicating whether each element of input is negative infinity.
- :func:`jax.numpy.isnan`: Returns a boolean array indicating whether each element of input is not a number (NaN).

Examples

>>> x = jnp.array([-1, 3, jnp.inf, jnp.nan])
>>> jnp.isfinite(x)
Array([ True,  True, False, False], dtype=bool)
>>> jnp.isfinite(3-4j)
Array(True, dtype=bool, weak_type=True)
def isinf(x: ArrayLike, /) ‑> jax.jaxlib._jax.Array
Expand source code
@export
@jit
def isinf(x: ArrayLike, /) -> Array:
  """Test each element of the input for positive or negative infinity.

  JAX implementation of :obj:`numpy.isinf`.

  Args:
    x: input array or scalar.

  Returns:
    A boolean array with the same shape as ``x`` that is ``True`` wherever an
    element equals ``inf`` or ``-inf`` and ``False`` everywhere else. A complex
    element counts as infinite when its real or its imaginary part is infinite;
    inputs of any non-floating, non-complex dtype are never infinite.

  See also:
    - :func:`jax.numpy.isposinf`: Returns a boolean array indicating whether each
      element of input is positive infinity.
    - :func:`jax.numpy.isneginf`: Returns a boolean array indicating whether each
      element of input is negative infinity.
    - :func:`jax.numpy.isfinite`: Returns a boolean array indicating whether each
      element of input is finite.
    - :func:`jax.numpy.isnan`: Returns a boolean array indicating whether each
      element of input is not a number (``NaN``).

  Examples:
    >>> jnp.isinf(jnp.inf)
    Array(True, dtype=bool)
    >>> x = jnp.array([2+3j, -jnp.inf, 6, jnp.inf, jnp.nan])
    >>> jnp.isinf(x)
    Array([False,  True, False,  True, False], dtype=bool)
  """
  arr = ensure_arraylike("isinf", x)
  dt = arr.dtype

  def _abs_is_inf(part):
    # |part| == inf matches +inf and -inf with a single comparison.
    return lax.eq(lax.abs(part), _constant_like(part, np.inf))

  if dtypes.issubdtype(dt, np.floating):
    return _abs_is_inf(arr)
  if dtypes.issubdtype(dt, np.complexfloating):
    return lax.bitwise_or(_abs_is_inf(lax.real(arr)), _abs_is_inf(lax.imag(arr)))
  # Any other dtype (e.g. integers, booleans) cannot hold an infinity.
  return lax.full_like(arr, False, dtype=np.bool_)

Return a boolean array indicating whether each element of input is infinite.

JAX implementation of :obj:numpy.isinf.

Args

x
input array or scalar.

Returns

A boolean array of the same shape as x containing True where x is inf or -inf, and False otherwise.

See also:

- `jax.numpy.isposinf`: Returns a boolean array indicating whether each element of input is positive infinity.
- `jax.numpy.isneginf`: Returns a boolean array indicating whether each element of input is negative infinity.
- `jax.numpy.isfinite`: Returns a boolean array indicating whether each element of input is finite.
- `jax.numpy.isnan`: Returns a boolean array indicating whether each element of input is not a number (NaN).

Examples

>>> jnp.isinf(jnp.inf)
Array(True, dtype=bool)
>>> x = jnp.array([2+3j, -jnp.inf, 6, jnp.inf, jnp.nan])
>>> jnp.isinf(x)
Array([False,  True, False,  True, False], dtype=bool)
def isnan(x: ArrayLike, /) ‑> jax.jaxlib._jax.Array
Expand source code
@export
@jit(inline=True)
def isnan(x: ArrayLike, /) -> Array:
  """Test each element of the input for being ``NaN``.

  JAX implementation of :obj:`numpy.isnan`.

  Args:
    x: input array or scalar.

  Returns:
    A boolean array with the same shape as ``x``; an entry is ``True`` exactly
    when the corresponding element of ``x`` is not a number (``NaN``).

  See also:
    - :func:`jax.numpy.isfinite`: Returns a boolean array indicating whether each
      element of input is finite.
    - :func:`jax.numpy.isinf`: Returns a boolean array indicating whether each
      element of input is either positive or negative infinity.
    - :func:`jax.numpy.isposinf`: Returns a boolean array indicating whether each
      element of input is positive infinity.
    - :func:`jax.numpy.isneginf`: Returns a boolean array indicating whether each
      element of input is negative infinity.

  Examples:
    >>> jnp.isnan(6)
    Array(False, dtype=bool, weak_type=True)
    >>> x = jnp.array([2, 1+4j, jnp.inf, jnp.nan])
    >>> jnp.isnan(x)
    Array([False, False, False,  True], dtype=bool)
  """
  arr = ensure_arraylike("isnan", x)
  # NaN is the only value that compares unequal to itself, so ``arr != arr``
  # is True precisely at the NaN entries.
  return lax.ne(arr, arr)

Returns a boolean array indicating whether each element of input is NaN.

JAX implementation of :obj:numpy.isnan.

Args

x
input array or scalar.

Returns

A boolean array of the same shape as x containing True where x is not a number (i.e. NaN) and False otherwise.

See also:

- `jax.numpy.isfinite`: Returns a boolean array indicating whether each element of input is finite.
- `jax.numpy.isinf`: Returns a boolean array indicating whether each element of input is either positive or negative infinity.
- `jax.numpy.isposinf`: Returns a boolean array indicating whether each element of input is positive infinity.
- `jax.numpy.isneginf`: Returns a boolean array indicating whether each element of input is negative infinity.

Examples

>>> jnp.isnan(6)
Array(False, dtype=bool, weak_type=True)
>>> x = jnp.array([2, 1+4j, jnp.inf, jnp.nan])
>>> jnp.isnan(x)
Array([False, False, False,  True], dtype=bool)
def log(x: ArrayLike, /) ‑> jax.jaxlib._jax.Array
Expand source code
@export
@jit(inline=True)
def log(x: ArrayLike, /) -> Array:
  """Compute the natural logarithm of the input, element by element.

  JAX implementation of :obj:`numpy.log`.

  Args:
    x: input array or scalar.

  Returns:
    An array holding the natural logarithm of every element of ``x``; the
    input is first promoted to an inexact (floating or complex) dtype.

  See also:
    - :func:`jax.numpy.exp`: Calculates element-wise exponential of the input.
    - :func:`jax.numpy.log2`: Calculates base-2 logarithm of each element of input.
    - :func:`jax.numpy.log1p`: Calculates element-wise logarithm of one plus input.

  Examples:
    ``jnp.log`` and ``jnp.exp`` are inverse functions of each other. Applying
    ``jnp.log`` on the result of ``jnp.exp(x)`` yields the original input ``x``.

    >>> x = jnp.array([2, 3, 4, 5])
    >>> jnp.log(jnp.exp(x))
    Array([2., 3., 4., 5.], dtype=float32)

    Using ``jnp.log`` we can demonstrate well-known properties of logarithms, such
    as :math:`log(a*b) = log(a)+log(b)`.

    >>> x1 = jnp.array([2, 1, 3, 1])
    >>> x2 = jnp.array([1, 3, 2, 4])
    >>> jnp.allclose(jnp.log(x1*x2), jnp.log(x1)+jnp.log(x2))
    Array(True, dtype=bool)
  """
  (arg,) = promote_args_inexact('log', x)
  result = lax.log(arg)
  # Pass the result through jnp_error's NaN hook (JAX's optional error
  # checking; presumably a no-op unless that feature is enabled).
  jnp_error._set_error_if_nan(result)
  return result

Calculate element-wise natural logarithm of the input.

JAX implementation of :obj:numpy.log.

Args

x
input array or scalar.

Returns

An array containing the logarithm of each element in x, promoted to an inexact dtype.

See also:

- `jax.numpy.exp`: Calculates element-wise exponential of the input.
- `jax.numpy.log2`: Calculates base-2 logarithm of each element of input.
- `jax.numpy.log1p`: Calculates element-wise logarithm of one plus input.

Examples

jnp.log and jnp.exp are inverse functions of each other. Applying jnp.log on the result of jnp.exp(x) yields the original input x.

>>> x = jnp.array([2, 3, 4, 5])
>>> jnp.log(jnp.exp(x))
Array([2., 3., 4., 5.], dtype=float32)

Using jnp.log we can demonstrate well-known properties of logarithms, such as :math:log(a*b) = log(a)+log(b).

>>> x1 = jnp.array([2, 1, 3, 1])
>>> x2 = jnp.array([1, 3, 2, 4])
>>> jnp.allclose(jnp.log(x1*x2), jnp.log(x1)+jnp.log(x2))
Array(True, dtype=bool)
def log10(x: ArrayLike, /) ‑> jax.jaxlib._jax.Array
Expand source code
@export
@jit(inline=True)
def log10(x: ArrayLike, /) -> Array:
  """Calculates the base-10 logarithm of ``x`` element-wise.

  JAX implementation of :obj:`numpy.log10`.

  Args:
    x: Input array

  Returns:
    An array of base-10 logarithms, one per element of ``x``, computed after
    promoting ``x`` to an inexact dtype.

  Examples:
    >>> x1 = jnp.array([0.01, 0.1, 1, 10, 100, 1000])
    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jnp.log10(x1))
    [-2. -1.  0.  1.  2.  3.]
  """
  arr, = promote_args_inexact("log10", x)
  # 1 / ln(10), stored in the real dtype whose precision matches ``arr``
  # (for complex input, finfo yields the underlying real component dtype).
  inv_ln10 = np.array(0.4342944819032518,  # exact value of 1 / log(10)
                      dtype=dtypes.finfo(arr.dtype).dtype)
  natural = lax.log(arr)
  if not dtypes.issubdtype(arr.dtype, np.complexfloating):
    scaled = lax.mul(natural, inv_ln10)
    jnp_error._set_error_if_nan(scaled)
    return scaled
  # The constant is real-valued, so scale the real and imaginary parts of the
  # complex logarithm separately and reassemble.
  return lax.complex(lax.mul(lax.real(natural), inv_ln10),
                     lax.mul(lax.imag(natural), inv_ln10))

Calculates the base-10 logarithm of x element-wise.

JAX implementation of :obj:numpy.log10.

Args

x
Input array

Returns

An array containing the base-10 logarithm of each element in x, promotes to inexact dtype.

Examples

>>> x1 = jnp.array([0.01, 0.1, 1, 10, 100, 1000])
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.log10(x1))
[-2. -1.  0.  1.  2.  3.]
def log2(x: ArrayLike, /) ‑> jax.jaxlib._jax.Array
Expand source code
@export
@jit(inline=True)
def log2(x: ArrayLike, /) -> Array:
  """Calculates the base-2 logarithm of ``x`` element-wise.

  JAX implementation of :obj:`numpy.log2`.

  Args:
    x: Input array

  Returns:
    An array of base-2 logarithms, one per element of ``x``, computed after
    promoting ``x`` to an inexact dtype.

  Examples:
    >>> x1 = jnp.array([0.25, 0.5, 1, 2, 4, 8])
    >>> jnp.log2(x1)
    Array([-2., -1.,  0.,  1.,  2.,  3.], dtype=float32)
  """
  arr, = promote_args_inexact("log2", x)
  if not dtypes.issubdtype(arr.dtype, np.complexfloating):
    # Change of base: log2(v) = ln(v) / ln(2).
    quotient = lax.div(lax.log(arr), lax.log(_constant_like(arr, 2)))
    jnp_error._set_error_if_nan(quotient)
    return quotient
  # Complex input: divide the real and imaginary parts of ln(x) by the real
  # scalar ln(2) separately, then rebuild the complex result.
  natural = lax.log(arr)
  re_part, im_part = lax.real(natural), lax.imag(natural)
  ln2 = lax.log(_constant_like(re_part, 2))
  return lax.complex(lax.div(re_part, ln2), lax.div(im_part, ln2))

Calculates the base-2 logarithm of x element-wise.

JAX implementation of :obj:numpy.log2.

Args

x
Input array

Returns

An array containing the base-2 logarithm of each element in x, promotes to inexact dtype.

Examples

>>> x1 = jnp.array([0.25, 0.5, 1, 2, 4, 8])
>>> jnp.log2(x1)
Array([-2., -1.,  0.,  1.,  2.,  3.], dtype=float32)
def real(val: ArrayLike, /) ‑> jax.jaxlib._jax.Array
Expand source code
@export
@jit(inline=True)
def real(val: ArrayLike, /) -> Array:
  """Return the real part of each element of the input.

  JAX implementation of :obj:`numpy.real`.

  Args:
    val: input array or scalar.

  Returns:
    An array with the real component of every element of ``val``. Inputs that
    are not complex are returned as an array unchanged.

  See also:
    - :func:`jax.numpy.conjugate` and :func:`jax.numpy.conj`: Returns the element-wise
      complex-conjugate of the input.
    - :func:`jax.numpy.imag`: Returns the element-wise imaginary part of the
      complex argument.

  Examples:
    >>> jnp.real(5)
    Array(5, dtype=int32, weak_type=True)
    >>> jnp.real(2j)
    Array(0., dtype=float32, weak_type=True)
    >>> x = jnp.array([3-2j, 4+7j, -2j])
    >>> jnp.real(x)
    Array([ 3.,  4., -0.], dtype=float32)
  """
  arr = ensure_arraylike("real", val)
  if np.iscomplexobj(arr):
    return lax.real(arr)
  # Real (or integer/bool) input already is its own real part.
  return lax.asarray(arr)

Return element-wise real part of the complex argument.

JAX implementation of :obj:numpy.real.

Args

val
input array or scalar.

Returns

An array containing the real part of the elements of val.

See also:

- `jax.numpy.conjugate` and `jax.numpy.conj`: Returns the element-wise complex-conjugate of the input.
- `jax.numpy.imag`: Returns the element-wise imaginary part of the complex argument.

Examples

>>> jnp.real(5)
Array(5, dtype=int32, weak_type=True)
>>> jnp.real(2j)
Array(0., dtype=float32, weak_type=True)
>>> x = jnp.array([3-2j, 4+7j, -2j])
>>> jnp.real(x)
Array([ 3.,  4., -0.], dtype=float32)
def round(a: ArrayLike, decimals: int = 0, out: None = None) ‑> jax.jaxlib._jax.Array
Expand source code
@export
@api.jit(static_argnames=('decimals',))
def round(a: ArrayLike, decimals: int = 0, out: None = None) -> Array:
  """Round input evenly to the given number of decimals.

  JAX implementation of :func:`numpy.round`.

  Args:
    a: input array or scalar.
    decimals: int, default=0. Number of decimal points to which the input needs
      to be rounded. It must be specified statically. For integer inputs,
      ``decimals < 0`` is not implemented.
    out: Unused by JAX.

  Returns:
    An array containing the rounded values to the specified ``decimals`` with
    same shape and dtype as ``a``.

  Raises:
    NotImplementedError: if ``out`` is given, or if ``a`` has an integer dtype
      and ``decimals < 0``.

  Note:
    ``jnp.round`` rounds to the nearest even integer for the values exactly halfway
    between rounded decimal values.

  See also:
    - :func:`jax.numpy.floor`: Rounds the input to the nearest integer downwards.
    - :func:`jax.numpy.ceil`: Rounds the input to the nearest integer upwards.
    - :func:`jax.numpy.fix` and :func:`jax.numpy.trunc`: Rounds the input to the
      nearest integer towards zero.

  Examples:
    >>> x = jnp.array([1.532, 3.267, 6.149])
    >>> jnp.round(x)
    Array([2., 3., 6.], dtype=float32)
    >>> jnp.round(x, decimals=2)
    Array([1.53, 3.27, 6.15], dtype=float32)

    For values exactly halfway between rounded values:

    >>> x1 = jnp.array([10.5, 21.5, 12.5, 31.5])
    >>> jnp.round(x1)
    Array([10., 22., 12., 32.], dtype=float32)
  """
  a = util.ensure_arraylike("round", a)
  # ``decimals`` is static; concrete_or_error presumably raises if it is traced.
  decimals = core.concrete_or_error(operator.index, decimals, "'decimals' argument of jnp.round")
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.round is not supported.")
  dtype = a.dtype
  if issubdtype(dtype, np.integer):
    if decimals < 0:
      raise NotImplementedError(
        "integer np.round not implemented for decimals < 0")
    return a  # no-op on integer types

  def _round_float(x: ArrayLike) -> Array:
    """Round a real floating-point array half-to-even at ``decimals`` places."""
    if decimals == 0:
      return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN)

    # TODO(phawkins): the strategy of rescaling the value isn't necessarily a
    # good one since we may be left with an incorrectly rounded value at the
    # end due to precision problems. As a workaround for float16, convert to
    # float32,
    x = lax.convert_element_type(x, np.float32) if dtype == np.float16 else x
    factor = lax._const(x, 10 ** decimals)
    # Scale, round to an integer (ties to even), then scale back.
    rounded = lax.div(lax.round(lax.mul(x, factor),
                                lax.RoundingMethod.TO_NEAREST_EVEN), factor)
    return lax.convert_element_type(rounded, dtype) if dtype == np.float16 else rounded

  # Past this bound 10**decimals would exceed the dtype's largest finite value,
  # so the rescaling in _round_float cannot work.
  if decimals > np.log10(dtypes.finfo(dtype).max):
    # Rounding beyond the input precision is a no-op.
    return lax.asarray(a)
  if issubdtype(dtype, np.complexfloating):
    # Round the real and imaginary components independently.
    return lax.complex(_round_float(lax.real(a)), _round_float(lax.imag(a)))
  else:
    return _round_float(a)

Round input evenly to the given number of decimals.

JAX implementation of :func:numpy.round.

Args

a
input array or scalar.
decimals
int, default=0. Number of decimal points to which the input needs to be rounded. It must be specified statically. For integer inputs, decimals < 0 is not implemented.
out
Unused by JAX.

Returns

An array containing the rounded values to the specified decimals with same shape and dtype as a.

Note

jnp.round rounds to the nearest even integer for the values exactly halfway between rounded decimal values.

See also:

- `jax.numpy.floor`: Rounds the input to the nearest integer downwards.
- `jax.numpy.ceil`: Rounds the input to the nearest integer upwards.
- `jax.numpy.fix` and `jax.numpy.trunc`: Rounds the input to the nearest integer towards zero.

Examples

>>> x = jnp.array([1.532, 3.267, 6.149])
>>> jnp.round(x)
Array([2., 3., 6.], dtype=float32)
>>> jnp.round(x, decimals=2)
Array([1.53, 3.27, 6.15], dtype=float32)

For values exactly halfway between rounded values:

>>> x1 = jnp.array([10.5, 21.5, 12.5, 31.5])
>>> jnp.round(x1)
Array([10., 22., 12., 32.], dtype=float32)
def shape(a: ArrayLike | SupportsShape) ‑> tuple[int, ...]
Expand source code
@export
def shape(a: ArrayLike | SupportsShape) -> tuple[int, ...]:
  """Return the shape of an array.

  JAX implementation of :func:`numpy.shape`. Unlike ``np.shape``, this function
  raises a :class:`TypeError` if the input is a collection such as a list or
  tuple.

  Args:
    a: array-like object, or any object with a ``shape`` attribute.

  Returns:
    A tuple of integers representing the shape of ``a``.

  Examples:
    Shape for arrays:

    >>> x = jnp.arange(10)
    >>> jnp.shape(x)
    (10,)
    >>> y = jnp.ones((2, 3))
    >>> jnp.shape(y)
    (2, 3)

    This also works for scalars:

    >>> jnp.shape(3.14)
    ()

    For arrays, this can also be accessed via the :attr:`jax.Array.shape` property:

    >>> x.shape
    (10,)
  """
  # Fast path: arrays and any other object exposing ``.shape`` report it directly.
  if hasattr(a, "shape"):
    return a.shape
  # Deprecation warning added 2025-2-20.
  # NOTE(review): with emit_warning=True, check_arraylike presumably warns
  # rather than raising for collections such as lists, so the TypeError
  # promised in the docstring may only apply once the deprecation ends.
  check_arraylike("shape", a, emit_warning=True)
  if hasattr(a, "__jax_array__"):
    a = a.__jax_array__()
  # NumPy dispatches to a.shape if available.
  return np.shape(a)

Return the shape of an array.

JAX implementation of :func:numpy.shape. Unlike np.shape, this function raises a :class:TypeError if the input is a collection such as a list or tuple.

Args

a
array-like object, or any object with a shape attribute.

Returns

A tuple of integers representing the shape of a.

Examples

Shape for arrays:

>>> x = jnp.arange(10)
>>> jnp.shape(x)
(10,)
>>> y = jnp.ones((2, 3))
>>> jnp.shape(y)
(2, 3)

This also works for scalars:

>>> jnp.shape(3.14)
()

For arrays, this can also be accessed via the :attr:jax.Array.shape property:

>>> x.shape
(10,)
def sign(x: ArrayLike, /) ‑> jax.jaxlib._jax.Array
Expand source code
@export
@jit(inline=True)
def sign(x: ArrayLike, /) -> Array:
  r"""Return an element-wise indication of the sign of the input.

  JAX implementation of :obj:`numpy.sign`.

  For real-valued input the result is:

  .. math::
    \mathrm{sign}(x) = \begin{cases}
      1, & x > 0\\
      0, & x = 0\\
      -1, & x < 0
    \end{cases}

  For complex-valued input, ``jnp.sign`` returns the unit-magnitude value with
  the same phase as ``x`` (or zero at the origin), i.e. in general:

  .. math::
    \mathrm{sign}(x) = \begin{cases}
      \frac{x}{abs(x)}, & x \ne 0\\
      0, & x = 0
    \end{cases}

  Args:
    x: input array or scalar.

  Returns:
    An array with the same shape and dtype as ``x`` holding the sign of each
    element.

  See also:
    - :func:`jax.numpy.positive`: Returns element-wise positive values of the input.
    - :func:`jax.numpy.negative`: Returns element-wise negative values of the input.

  Examples:
    For Real-valued inputs:

    >>> x = jnp.array([0., -3., 7.])
    >>> jnp.sign(x)
    Array([ 0., -1.,  1.], dtype=float32)

    For complex-inputs:

    >>> x1 = jnp.array([1, 3+4j, 5j])
    >>> jnp.sign(x1)
    Array([1. +0.j , 0.6+0.8j, 0. +1.j ], dtype=complex64)
  """
  (arg,) = promote_args('sign', x)
  return lax.sign(arg)

Return an element-wise indication of sign of the input.

JAX implementation of :obj:numpy.sign.

The sign of x for real-valued input is:

\[ \mathrm{sign}(x) = \begin{cases} 1, & x > 0 \\ 0, & x = 0 \\ -1, & x < 0 \end{cases} \]

For complex valued input, jnp.sign returns a unit vector representing the phase. For generalized case, the sign of x is given by:

\[ \mathrm{sign}(x) = \begin{cases} \frac{x}{|x|}, & x \ne 0 \\ 0, & x = 0 \end{cases} \]

Args

x
input array or scalar.

Returns

An array with the same shape and dtype as x containing the sign indication.

See also:

- `jax.numpy.positive`: Returns element-wise positive values of the input.
- `jax.numpy.negative`: Returns element-wise negative values of the input.

Examples

For Real-valued inputs:

>>> x = jnp.array([0., -3., 7.])
>>> jnp.sign(x)
Array([ 0., -1.,  1.], dtype=float32)

For complex-inputs:

>>> x1 = jnp.array([1, 3+4j, 5j])
>>> jnp.sign(x1)
Array([1. +0.j , 0.6+0.8j, 0. +1.j ], dtype=complex64)
def sin(x: ArrayLike, /) ‑> jax.jaxlib._jax.Array
Expand source code
@export
@jit(inline=True)
def sin(x: ArrayLike, /) -> Array:
  """Compute the trigonometric sine of each element of the input.

  JAX implementation of :obj:`numpy.sin`.

  Args:
    x: array or scalar. Angle in radians.

  Returns:
    An array with the sine of every element of ``x``, computed after promoting
    ``x`` to an inexact dtype.

  See also:
    - :func:`jax.numpy.cos`: Computes a trigonometric cosine of each element of
      input.
    - :func:`jax.numpy.tan`: Computes a trigonometric tangent of each element of
      input.
    - :func:`jax.numpy.arcsin` and :func:`jax.numpy.asin`: Computes the inverse of
      trigonometric sine of each element of input.

  Examples:
    >>> pi = jnp.pi
    >>> x = jnp.array([pi/4, pi/2, 3*pi/4, pi])
    >>> with jnp.printoptions(precision=3, suppress=True):
    ...   print(jnp.sin(x))
    [ 0.707  1.     0.707 -0.   ]
  """
  (angle,) = promote_args_inexact('sin', x)
  result = lax.sin(angle)
  # Pass the result through jnp_error's NaN hook (presumably a no-op unless
  # JAX error checking is enabled).
  jnp_error._set_error_if_nan(result)
  return result

Compute a trigonometric sine of each element of input.

JAX implementation of :obj:numpy.sin.

Args

x
array or scalar. Angle in radians.

Returns

An array containing the sine of each element in x, promoted to an inexact dtype.

See also:

- `jax.numpy.cos`: Computes a trigonometric cosine of each element of input.
- `jax.numpy.tan`: Computes a trigonometric tangent of each element of input.
- `jax.numpy.arcsin` and `jax.numpy.asin`: Computes the inverse of trigonometric sine of each element of input.

Examples

>>> pi = jnp.pi
>>> x = jnp.array([pi/4, pi/2, 3*pi/4, pi])
>>> with jnp.printoptions(precision=3, suppress=True):
...   print(jnp.sin(x))
[ 0.707  1.     0.707 -0.   ]
def softplus(x: ArrayLike) ‑> jax.jaxlib._jax.Array
Expand source code
@api.jit
def softplus(x: ArrayLike) -> Array:
  r"""Softplus activation function.

  Computes the element-wise function

  .. math::
    \mathrm{softplus}(x) = \log(1 + e^x)

  Args:
    x : input array
  """
  return jnp.logaddexp(x, 0)

Softplus activation function.

Computes the element-wise function

\[ \mathrm{softplus}(x) = \log(1 + e^x) \]

Args

x : input array

def sqrt(x: ArrayLike, /) ‑> jax.jaxlib._jax.Array
Expand source code
@export
@jit(inline=True)
def sqrt(x: ArrayLike, /) -> Array:
  """Compute the non-negative square root of each element of the input.

  JAX implementation of :obj:`numpy.sqrt`.

  Args:
    x: input array or scalar.

  Returns:
    An array with the principal (non-negative) square root of every element of
    ``x``, computed after promoting ``x`` to an inexact dtype.

  Note:
    - For real-valued negative inputs, ``jnp.sqrt`` produces a ``nan`` output.
    - For complex-valued negative inputs, ``jnp.sqrt`` produces a ``complex`` output.

  See also:
    - :func:`jax.numpy.square`: Calculates the element-wise square of the input.
    - :func:`jax.numpy.power`: Calculates the element-wise base ``x1`` exponential
      of ``x2``.

  Examples:
    >>> x = jnp.array([-8-6j, 1j, 4])
    >>> with jnp.printoptions(precision=3, suppress=True):
    ...   jnp.sqrt(x)
    Array([1.   -3.j   , 0.707+0.707j, 2.   +0.j   ], dtype=complex64)
    >>> jnp.sqrt(-1)
    Array(nan, dtype=float32, weak_type=True)
  """
  (radicand,) = promote_args_inexact('sqrt', x)
  root = lax.sqrt(radicand)
  # Negative real inputs yield NaN here; report them through jnp_error's NaN
  # hook (presumably a no-op unless JAX error checking is enabled).
  jnp_error._set_error_if_nan(root)
  return root

Calculates element-wise non-negative square root of the input array.

JAX implementation of :obj:numpy.sqrt.

Args

x
input array or scalar.

Returns

An array containing the non-negative square root of the elements of x.

Note

  • For real-valued negative inputs, jnp.sqrt produces a nan output.
  • For complex-valued negative inputs, jnp.sqrt produces a complex output.

See also:

- `jax.numpy.square`: Calculates the element-wise square of the input.
- `jax.numpy.power`: Calculates the element-wise base x1 exponential of x2.

Examples

>>> x = jnp.array([-8-6j, 1j, 4])
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.sqrt(x)
Array([1.   -3.j   , 0.707+0.707j, 2.   +0.j   ], dtype=complex64)
>>> jnp.sqrt(-1)
Array(nan, dtype=float32, weak_type=True)
def stack(arrays: np.ndarray | Array | Sequence[ArrayLike],
axis: int = 0,
out: None = None,
dtype: DTypeLike | None = None) ‑> jax.jaxlib._jax.Array
Expand source code
@export
def stack(arrays: np.ndarray | Array | Sequence[ArrayLike],
          axis: int = 0, out: None = None, dtype: DTypeLike | None = None) -> Array:
  """Join arrays along a new axis.

  JAX implementation of :func:`numpy.stack`.

  Args:
    arrays: a sequence of arrays to stack; each must have the same shape. If a
      single array is given it will be treated equivalently to
      ``arrays = unstack(arrays)``, but the implementation will avoid explicit
      unstacking.
    axis: specify the axis along which to stack.
    out: unused by JAX
    dtype: optional dtype of the resulting array. If not specified, the dtype
      will be determined via type promotion rules described in :ref:`type-promotion`.

  Returns:
    the stacked result.

  Raises:
    ValueError: if ``arrays`` is empty, or if the given arrays do not all have
      the same shape.
    NotImplementedError: if ``out`` is not ``None``.

  See also:
    - :func:`jax.numpy.unstack`: inverse of ``stack``.
    - :func:`jax.numpy.concatenate`: concatenation along existing axes.
    - :func:`jax.numpy.vstack`: stack vertically, i.e. along axis 0.
    - :func:`jax.numpy.hstack`: stack horizontally, i.e. along axis 1.
    - :func:`jax.numpy.dstack`: stack depth-wise, i.e. along axis 2.
    - :func:`jax.numpy.column_stack`: stack columns.

  Examples:
    >>> x = jnp.array([1, 2, 3])
    >>> y = jnp.array([4, 5, 6])
    >>> jnp.stack([x, y])
    Array([[1, 2, 3],
           [4, 5, 6]], dtype=int32)
    >>> jnp.stack([x, y], axis=1)
    Array([[1, 4],
           [2, 5],
           [3, 6]], dtype=int32)

    :func:`~jax.numpy.unstack` performs the inverse operation:

    >>> arr = jnp.stack([x, y], axis=1)
    >>> x, y = jnp.unstack(arr, axis=1)
    >>> x
    Array([1, 2, 3], dtype=int32)
    >>> y
    Array([4, 5, 6], dtype=int32)
  """
  if not len(arrays):
    raise ValueError("Need at least one array to stack.")
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.stack is not supported.")
  if isinstance(arrays, (np.ndarray, Array)):
    # A single array stands for the sequence of its slices along the leading
    # axis, so its ndim already equals the ndim of the stacked result. Adding a
    # unit axis at ``axis + 1`` (offset by the leading axis) and concatenating
    # along ``axis`` moves the leading axis into position ``axis``.
    axis = _canonicalize_axis(axis, arrays.ndim)
    return concatenate(expand_dims(arrays, axis + 1), axis=axis, dtype=dtype)
  else:
    arrays = util.ensure_arraylike_tuple("stack", arrays)
    shape0 = np.shape(arrays[0])
    # The result has one more dimension than each input.
    axis = _canonicalize_axis(axis, len(shape0) + 1)
    new_arrays = []
    for a in arrays:
      if np.shape(a) != shape0:
        raise ValueError("All input arrays must have the same shape.")
      new_arrays.append(expand_dims(a, axis))
    return concatenate(new_arrays, axis=axis, dtype=dtype)

Join arrays along a new axis.

JAX implementation of :func:numpy.stack.

Args

arrays
a sequence of arrays to stack; each must have the same shape. If a single array is given it will be treated equivalently to arrays = unstack(arrays), but the implementation will avoid explicit unstacking.
axis
specify the axis along which to stack.
out
unused by JAX
dtype
optional dtype of the resulting array. If not specified, the dtype will be determined via type promotion rules described in :ref:type-promotion.

Returns

the stacked result.

See also:

- `jax.numpy.unstack`: inverse of stack.
- `jax.numpy.concatenate`: concatenation along existing axes.
- `jax.numpy.vstack`: stack vertically, i.e. along axis 0.
- `jax.numpy.hstack`: stack horizontally, i.e. along axis 1.
- `jax.numpy.dstack`: stack depth-wise, i.e. along axis 2.
- `jax.numpy.column_stack`: stack columns.

Examples

>>> x = jnp.array([1, 2, 3])
>>> y = jnp.array([4, 5, 6])
>>> jnp.stack([x, y])
Array([[1, 2, 3],
       [4, 5, 6]], dtype=int32)
>>> jnp.stack([x, y], axis=1)
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)

:func:~jax.numpy.unstack performs the inverse operation:

>>> arr = jnp.stack([x, y], axis=1)
>>> x, y = jnp.unstack(arr, axis=1)
>>> x
Array([1, 2, 3], dtype=int32)
>>> y
Array([4, 5, 6], dtype=int32)
def staticshape(a: ArrayLike | SupportsShape) ‑> tuple[int, ...]
Expand source code
@export
def shape(a: ArrayLike | SupportsShape) -> tuple[int, ...]:
  """Return the shape of an array.

  JAX implementation of :func:`numpy.shape`. Unlike ``np.shape``, this function
  raises a :class:`TypeError` if the input is a collection such as a list or
  tuple.

  Args:
    a: array-like object, or any object with a ``shape`` attribute.

  Returns:
    A tuple of integers representing the shape of ``a``.

  Examples:
    Shape for arrays:

    >>> x = jnp.arange(10)
    >>> jnp.shape(x)
    (10,)
    >>> y = jnp.ones((2, 3))
    >>> jnp.shape(y)
    (2, 3)

    This also works for scalars:

    >>> jnp.shape(3.14)
    ()

    For arrays, this can also be accessed via the :attr:`jax.Array.shape` property:

    >>> x.shape
    (10,)
  """
  # Fast path: arrays and any other object exposing ``.shape`` report it directly.
  if hasattr(a, "shape"):
    return a.shape
  # Deprecation warning added 2025-2-20.
  # NOTE(review): with emit_warning=True, check_arraylike presumably warns
  # rather than raising for collections such as lists, so the TypeError
  # promised in the docstring may only apply once the deprecation ends.
  check_arraylike("shape", a, emit_warning=True)
  if hasattr(a, "__jax_array__"):
    a = a.__jax_array__()
  # NumPy dispatches to a.shape if available.
  return np.shape(a)

Return the shape of an array.

JAX implementation of :func:numpy.shape. Unlike np.shape, this function raises a :class:TypeError if the input is a collection such as a list or tuple.

Args

a
array-like object, or any object with a shape attribute.

Returns

A tuple of integers representing the shape of a.

Examples

Shape for arrays:

>>> x = jnp.arange(10)
>>> jnp.shape(x)
(10,)
>>> y = jnp.ones((2, 3))
>>> jnp.shape(y)
(2, 3)

This also works for scalars:

>>> jnp.shape(3.14)
()

For arrays, this can also be accessed via the :attr:jax.Array.shape property:

>>> x.shape
(10,)
def stop_gradient(x: T) ‑> ~T
Expand source code
def stop_gradient(x: T) -> T:
  """Treat ``x`` as a constant for automatic differentiation.

  ``stop_gradient`` evaluates to its argument `x` unchanged, but blocks the
  flow of gradients through it in both forward- and reverse-mode automatic
  differentiation. When gradient computations are nested, the block applies
  to every level of nesting. For some discussion of where this is useful,
  refer to :ref:`stopping-gradients`.

  Args:
    x: array or pytree of arrays

  Returns:
    The input, unchanged in value, but treated as a constant by autodiff.

  Examples:
    Consider a simple function that returns the square of the input value:

    >>> def f1(x):
    ...   return x ** 2
    >>> x = jnp.float32(3.0)
    >>> f1(x)
    Array(9.0, dtype=float32)
    >>> jax.grad(f1)(x)
    Array(6.0, dtype=float32)

    The same function with ``stop_gradient`` around ``x`` will be equivalent
    under normal evaluation, but return a zero gradient because ``x`` is
    effectively treated as a constant:

    >>> def f2(x):
    ...   return jax.lax.stop_gradient(x) ** 2
    >>> f2(x)
    Array(9.0, dtype=float32)
    >>> jax.grad(f2)(x)
    Array(0.0, dtype=float32)

    This is used in a number of places within the JAX codebase; for example
    :func:`jax.nn.softmax` internally normalizes the input by its maximum
    value, and this maximum value is wrapped in ``stop_gradient`` for
    efficiency. Refer to :ref:`stopping-gradients` for more discussion of
    the applicability of ``stop_gradient``.
  """
  # Apply the per-leaf helper to every leaf, so pytrees (tuples, dicts, ...)
  # of arrays are handled as well as bare arrays.
  stopped = tree_util.tree_map(_stop_gradient, x)
  return stopped

Stops gradient computation.

Operationally stop_gradient is the identity function, that is, it returns argument x unchanged. However, stop_gradient prevents the flow of gradients during forward or reverse-mode automatic differentiation. If there are multiple nested gradient computations, stop_gradient stops gradients for all of them. For some discussion of where this is useful, refer to :ref:stopping-gradients.

Args

x
array or pytree of arrays

Returns

input value is returned unchanged, but within autodiff will be treated as a constant.

Examples

Consider a simple function that returns the square of the input value:

>>> def f1(x):
...   return x ** 2
>>> x = jnp.float32(3.0)
>>> f1(x)
Array(9.0, dtype=float32)
>>> jax.grad(f1)(x)
Array(6.0, dtype=float32)

The same function with stop_gradient around x will be equivalent under normal evaluation, but return a zero gradient because x is effectively treated as a constant:

>>> def f2(x):
...   return jax.lax.stop_gradient(x) ** 2
>>> f2(x)
Array(9.0, dtype=float32)
>>> jax.grad(f2)(x)
Array(0.0, dtype=float32)

This is used in a number of places within the JAX codebase; for example :func:jax.nn.softmax internally normalizes the input by its maximum value, and this maximum value is wrapped in stop_gradient for efficiency. Refer to :ref:stopping-gradients for more discussion of the applicability of stop_gradient.

def tan(x: ArrayLike, /) ‑> jax.jaxlib._jax.Array
Expand source code
@export
@jit(inline=True)
def tan(x: ArrayLike, /) -> Array:
  """Compute the trigonometric tangent of each element of the input.

  JAX implementation of :obj:`numpy.tan`.

  Args:
    x: scalar or array. Angle in radians.

  Returns:
    An array with the tangent of every element of ``x``, computed after
    promoting ``x`` to an inexact dtype.

  See also:
    - :func:`jax.numpy.sin`: Computes a trigonometric sine of each element of input.
    - :func:`jax.numpy.cos`: Computes a trigonometric cosine of each element of
      input.
    - :func:`jax.numpy.arctan` and :func:`jax.numpy.atan`: Computes the inverse of
      trigonometric tangent of each element of input.

  Examples:
    >>> pi = jnp.pi
    >>> x = jnp.array([0, pi/6, pi/4, 3*pi/4, 5*pi/6])
    >>> with jnp.printoptions(precision=3, suppress=True):
    ...   print(jnp.tan(x))
    [ 0.     0.577  1.    -1.    -0.577]
  """
  (angle,) = promote_args_inexact('tan', x)
  result = lax.tan(angle)
  # Pass the result through jnp_error's NaN hook (presumably a no-op unless
  # JAX error checking is enabled).
  jnp_error._set_error_if_nan(result)
  return result

Compute a trigonometric tangent of each element of input.

JAX implementation of :obj:numpy.tan.

Args

x
scalar or array. Angle in radians.

Returns

An array containing the tangent of each element in x, promotes to inexact dtype. See also: - :func:jax.numpy.sin: Computes a trigonometric sine of each element of input. - :func:jax.numpy.cos: Computes a trigonometric cosine of each element of input. - :func:jax.numpy.arctan and :func:jax.numpy.atan: Computes the inverse of trigonometric tangent of each element of input.

Examples

>>> pi = jnp.pi
>>> x = jnp.array([0, pi/6, pi/4, 3*pi/4, 5*pi/6])
>>> with jnp.printoptions(precision=3, suppress=True):
...   print(jnp.tan(x))
[ 0.     0.577  1.    -1.    -0.577]
def tile(A: ArrayLike, reps: DimSize | Sequence[DimSize]) ‑> jax.jaxlib._jax.Array
Expand source code
@export
def tile(A: ArrayLike, reps: DimSize | Sequence[DimSize]) -> Array:
  """Construct an array by repeating ``A`` along specified dimensions.

  JAX implementation of :func:`numpy.tile`.

  If ``A`` is an array of shape ``(d1, d2, ..., dn)`` and ``reps`` is a sequence of
  ``n`` integers, the resulting array will have a shape of
  ``(reps[0] * d1, reps[1] * d2, ..., reps[n-1] * dn)``, with ``A`` tiled along each
  dimension. A scalar ``reps`` is treated as ``(reps,)``. If ``reps`` has fewer than
  ``A.ndim`` entries it is padded with leading 1s; if it has more, ``A`` is promoted
  by prepending size-1 axes before tiling.

  Args:
    A: input array to be repeated. Can be of any shape or dimension.
    reps: specifies the number of repetitions along each axis.

  Returns:
    a new array where the input array has been repeated according to ``reps``.

  See also:
    - :func:`jax.numpy.repeat`: Construct an array from repeated elements.
    - :func:`jax.numpy.broadcast_to`: Broadcast an array to a specified shape.

  Examples:
    >>> arr = jnp.array([1, 2])
    >>> jnp.tile(arr, 2)
    Array([1, 2, 1, 2], dtype=int32)
    >>> arr = jnp.array([[1, 2],
    ...                  [3, 4,]])
    >>> jnp.tile(arr, (2, 1))
    Array([[1, 2],
           [3, 4],
           [1, 2],
           [3, 4]], dtype=int32)
  """
  A = util.ensure_arraylike("tile", A)
  try:
    reps_tup = tuple(iter(reps))  # pyrefly: ignore[no-matching-overload]
  except TypeError:
    # A scalar repetition count applies to the last axis only.
    reps_tup: tuple[DimSize, ...] = (reps,)
  reps_tup = tuple(operator.index(rep) if core.is_constant_dim(rep) else rep
                   for rep in reps_tup)
  # lax.tile expects reps and A.shape to have the same rank.
  reps_tup = (1,) * (A.ndim - len(reps_tup)) + reps_tup
  if len(reps_tup) > np.ndim(A):
    # More reps than dimensions: prepend size-1 axes to A.
    A = lax.expand_dims(
        A, dimensions=tuple(range(len(reps_tup) - np.ndim(A))))
  return lax.tile(A, reps_tup)

Construct an array by repeating A along specified dimensions.

JAX implementation of :func:numpy.tile.

If A is an array of shape (d1, d2, …, dn) and reps is a sequence of n integers, the resulting array will have a shape of (reps[0] * d1, reps[1] * d2, ..., reps[n-1] * dn), with A tiled along each dimension.

Args

A
input array to be repeated. Can be of any shape or dimension.
reps
specifies the number of repetitions along each axis.

Returns

a new array where the input array has been repeated according to reps. See also: - :func:jax.numpy.repeat: Construct an array from repeated elements. - :func:jax.numpy.broadcast_to: Broadcast an array to a specified shape.

Examples

>>> arr = jnp.array([1, 2])
>>> jnp.tile(arr, 2)
Array([1, 2, 1, 2], dtype=int32)
>>> arr = jnp.array([[1, 2],
...                  [3, 4,]])
>>> jnp.tile(arr, (2, 1))
Array([[1, 2],
       [3, 4],
       [1, 2],
       [3, 4]], dtype=int32)
def transpose(a: ArrayLike, axes: Sequence[int] | None = None) ‑> jax.jaxlib._jax.Array
Expand source code
@export
def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array:
  """Permute the axes of an N-dimensional array.

  JAX implementation of :func:`numpy.transpose`, built on top of
  :func:`jax.lax.transpose`.

  Args:
    a: input array
    axes: optional permutation, given as a length-``a.ndim`` sequence of integers
      ``i`` with ``0 <= i < a.ndim``. When omitted, ``range(a.ndim)[::-1]`` is used,
      which reverses the order of all axes.

  Returns:
    transposed copy of the array.

  See Also:
    - :func:`jax.Array.transpose`: equivalent function via an :class:`~jax.Array` method.
    - :attr:`jax.Array.T`: equivalent function via an :class:`~jax.Array`  property.
    - :func:`jax.numpy.matrix_transpose`: transpose the last two axes of an array. This is
      suitable for working with batched 2D matrices.
    - :func:`jax.numpy.swapaxes`: swap any two axes in an array.
    - :func:`jax.numpy.moveaxis`: move an axis to another position in the array.

  Note:
    Where :func:`numpy.transpose` returns a view, :func:`jax.numpy.transpose` returns a
    copy. Under JIT the compiler elides such copies when it can, so in practice this
    has no performance cost.

  Examples:
    For a 1D array, the transpose is the identity:

    >>> x = jnp.array([1, 2, 3, 4])
    >>> jnp.transpose(x)
    Array([1, 2, 3, 4], dtype=int32)

    For a 2D array, the transpose is a matrix transpose:

    >>> x = jnp.array([[1, 2],
    ...                [3, 4]])
    >>> jnp.transpose(x)
    Array([[1, 3],
           [2, 4]], dtype=int32)

    For an N-dimensional array, the transpose reverses the order of the axes:

    >>> x = jnp.zeros(shape=(3, 4, 5))
    >>> jnp.transpose(x).shape
    (5, 4, 3)

    The ``axes`` argument can be specified to change this default behavior:

    >>> jnp.transpose(x, (0, 2, 1)).shape
    (3, 5, 4)

    Since swapping the last two axes is a common operation, it can be done
    via its own API, :func:`jax.numpy.matrix_transpose`:

    >>> jnp.matrix_transpose(x).shape
    (3, 5, 4)

    For convenience, transposes may also be performed using the :meth:`jax.Array.transpose`
    method or the :attr:`jax.Array.T` property:

    >>> x = jnp.array([[1, 2],
    ...                [3, 4]])
    >>> x.transpose()
    Array([[1, 3],
           [2, 4]], dtype=int32)
    >>> x.T
    Array([[1, 3],
           [2, 4]], dtype=int32)
  """
  a = util.ensure_arraylike("transpose", a)
  rank = np.ndim(a)
  if axes is None:
    permutation = list(range(rank - 1, -1, -1))
  else:
    permutation = axes
  permutation = [_canonicalize_axis(ax, rank) for ax in permutation]
  return lax.transpose(a, permutation)

Return a transposed version of an N-dimensional array.

JAX implementation of :func:numpy.transpose, implemented in terms of :func:jax.lax.transpose.

Args

a
input array
axes
optionally specify the permutation using a length-a.ndim sequence of integers i satisfying 0 <= i < a.ndim. Defaults to range(a.ndim)[::-1], i.e. reverses the order of all axes.

Returns

transposed copy of the array. See Also: - :func:jax.Array.transpose: equivalent function via an :class:~jax.Array method. - :attr:jax.Array.T: equivalent function via an :class:~jax.Array property. - :func:jax.numpy.matrix_transpose: transpose the last two axes of an array. This is suitable for working with batched 2D matrices. - :func:jax.numpy.swapaxes: swap any two axes in an array. - :func:jax.numpy.moveaxis: move an axis to another position in the array.

Note

Unlike :func:numpy.transpose, :func:jax.numpy.transpose will return a copy rather than a view of the input array. However, under JIT, the compiler will optimize-away such copies when possible, so this doesn't have performance impacts in practice.

Examples

For a 1D array, the transpose is the identity:

>>> x = jnp.array([1, 2, 3, 4])
>>> jnp.transpose(x)
Array([1, 2, 3, 4], dtype=int32)

For a 2D array, the transpose is a matrix transpose:

>>> x = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.transpose(x)
Array([[1, 3],
       [2, 4]], dtype=int32)

For an N-dimensional array, the transpose reverses the order of the axes:

>>> x = jnp.zeros(shape=(3, 4, 5))
>>> jnp.transpose(x).shape
(5, 4, 3)

The axes argument can be specified to change this default behavior:

>>> jnp.transpose(x, (0, 2, 1)).shape
(3, 5, 4)

Since swapping the last two axes is a common operation, it can be done via its own API, :func:jax.numpy.matrix_transpose:

>>> jnp.matrix_transpose(x).shape
(3, 5, 4)

For convenience, transposes may also be performed using the :meth:jax.Array.transpose method or the :attr:jax.Array.T property:

>>> x = jnp.array([[1, 2],
...                [3, 4]])
>>> x.transpose()
Array([[1, 3],
       [2, 4]], dtype=int32)
>>> x.T
Array([[1, 3],
       [2, 4]], dtype=int32)

Methods

def all(self, boolean_tensor, axis=None, keepdims=False)
Expand source code
def all(self, boolean_tensor, axis=None, keepdims=False):
    """Logical-AND reduction of `boolean_tensor` over `axis`.

    A tuple or list of tensors is first stacked along a new leading axis.
    """
    stacked = jnp.stack(boolean_tensor) if isinstance(boolean_tensor, (tuple, list)) else boolean_tensor
    return jnp.all(stacked, axis=axis, keepdims=keepdims)
def allocate_on_device(self, tensor: ~TensorType, device: phiml.backend._backend.ComputeDevice) ‑> ~TensorType
Expand source code
def allocate_on_device(self, tensor: TensorType, device: ComputeDevice) -> TensorType:
    """Place `tensor` on the JAX device referenced by `device.ref` via `jax.device_put`."""
    target_device = device.ref
    return jax.device_put(tensor, target_device)

Moves tensor to device. May copy the tensor if it is already on the device.

Args

tensor
Existing tensor native to this backend.
device
Target device, associated with this backend.
def any(self, boolean_tensor, axis=None, keepdims=False)
Expand source code
def any(self, boolean_tensor, axis=None, keepdims=False):
    """Logical-OR reduction of `boolean_tensor` over `axis`.

    A tuple or list of tensors is first stacked along a new leading axis.
    """
    stacked = jnp.stack(boolean_tensor) if isinstance(boolean_tensor, (tuple, list)) else boolean_tensor
    return jnp.any(stacked, axis=axis, keepdims=keepdims)
def argsort(self, x, axis=-1)
Expand source code
def argsort(self, x, axis=-1):
    """Indices that sort `x` along `axis` (default: the last axis)."""
    return jnp.argsort(x, axis=axis)
def as_tensor(self, x, convert_external=True)
Expand source code
def as_tensor(self, x, convert_external=True):
    """Convert `x` to a JAX array, enforcing the backend's float/complex precision.

    `x` is kept as-is when `is_tensor(x, only_native=convert_external)` accepts it;
    otherwise it goes through `jnp.array`. Python numbers are returned unchanged.
    Float and complex results are passed through `to_float` / `to_complex`.
    """
    self._check_float64()
    array = x if self.is_tensor(x, only_native=convert_external) else jnp.array(x)
    if isinstance(array, numbers.Number):
        return array
    # --- Enforce Precision ---
    kind = self.dtype(array).kind
    if kind == float:
        return self.to_float(array)
    if kind == complex:
        return self.to_complex(array)
    return array

Converts a tensor-like object to the native tensor representation of this backend. If x is a native tensor of this backend, it is returned without modification. If x is a Python number (numbers.Number instance), convert_external decides whether to convert it unless the backend cannot handle Python numbers.

Note: There may be objects that are considered tensors by this backend but are not native and thus, will be converted by this method.

Args

x
tensor-like, e.g. list, tuple, Python number, tensor
convert_external
if False and x is a Python number that is understood by this backend, this method returns the number as-is. This can help prevent type clashes like int32 vs int64. (Default value = True)

Returns

tensor representation of x

def batched_gather_1d(self, values, indices)
Expand source code
def batched_gather_1d(self, values, indices):
    """Per-batch gather: `result[b, j] = values[b, indices[b, j]]`.

    `values` is (batch, spatial) and `indices` is (batch, indices). The batch size is
    combined via `combined_dim`, so either side may presumably have batch size 1 —
    NOTE(review): confirm `combined_dim` semantics.
    """
    batch = combined_dim(values.shape[0], indices.shape[0])
    row_ids = np.arange(batch)[:, None]  # column vector so it broadcasts against `indices`
    return values[row_ids, indices]

Args

values
(batch, spatial)
indices
(batch, indices)

Returns

(batch, indices)

def batched_gather_nd(self, values, indices)
Expand source code
def batched_gather_nd(self, values, indices):
    """Gather from `values` (batch, spatial..., channel) at `indices` (batch, any..., multi_index).

    The last axis of `indices` holds one index per spatial dimension of `values`.
    The result has shape (batch, any..., channel).
    """
    values = self.as_tensor(values)
    indices = self.as_tensor(indices)
    assert indices.shape[-1] == self.ndims(values) - 2
    batch = combined_dim(values.shape[0], indices.shape[0])
    # Batch index, expanded so it broadcasts against the `any...` axes of `indices`.
    batch_index = self.expand_dims(np.arange(batch), -1, number=self.ndims(indices) - 2)
    spatial_index = tuple(indices[..., d] for d in range(indices.shape[-1]))
    return values[(batch_index, *spatial_index)]

Gathers values from the tensor values at locations indices. The first dimension of values and indices is the batch dimension which must be either equal for both or one for either.

Args

values
tensor of shape (batch, spatial…, channel)
indices
int tensor of shape (batch, any…, multi_index) where the size of multi_index is values.rank - 2.

Returns

Gathered values as tensor of shape (batch, any…, channel)

def bincount(self, x, weights: ~TensorType | None, bins: int, x_sorted=False)
Expand source code
def bincount(self, x, weights: Optional[TensorType], bins: int, x_sorted=False):
    """Sum `weights` into `bins` bins, using `x` as the bin index of each entry.

    Args:
        x: Bin indices, 1D int tensor.
        weights: Weights corresponding to `x`, 1D tensor, or `None` to count each entry once.
        bins: Number of bins (the output length).
        x_sorted: Whether `x` is sorted from lowest to highest bin. If so, the sorted
            segment-sum path is used.

    Returns:
        1D tensor of length `bins` with the per-bin sums.
    """
    if x_sorted:
        # Test for None explicitly. `weights or 1` would call bool() on a JAX array:
        # - a multi-element array raises;
        # - a tracer raises;
        # - a single element equal to 0 is silently replaced by 1.
        data = 1 if weights is None else weights
        return jax.ops.segment_sum(data, x, bins, indices_are_sorted=True)
    else:
        return jnp.bincount(x, weights=weights, minlength=bins, length=bins)

Args

x
Bin indices, 1D int tensor.
weights
Weights corresponding to x, 1D tensor. All weights are 1 if weights=None.
bins
Number of bins.
x_sorted
Whether x is sorted from lowest to highest bin.

Returns

bin_counts

def block_until_ready(self, values)
Expand source code
def block_until_ready(self, values):
    """Wait for pending JAX computations on `values`, recursing into tuples and lists."""
    if hasattr(values, 'block_until_ready'):
        values.block_until_ready()
    if not isinstance(values, (tuple, list)):
        return
    for item in values:
        self.block_until_ready(item)
def boolean_mask(self, x, mask, axis=0, new_length=None, fill_value=0)
Expand source code
def boolean_mask(self, x, mask, axis=0, new_length=None, fill_value=0):
    """Keep the entries of `x` along `axis` where the 1D `mask` is true.

    Without `new_length` the output size is data-dependent, which is not traceable
    under jit. With `new_length`, the result has exactly that size along `axis`:
    extra slots are filled with `fill_value`, and surplus selected entries are dropped.
    """
    rank = len(x.shape)
    if new_length is None:
        return x[tuple(mask if d == axis else slice(None) for d in range(rank))]
    # Fixed-size selection: missing positions get index -1.
    selected = jnp.argwhere(mask, size=new_length, fill_value=-1)[..., 0]
    broadcast_along_axis = tuple(slice(None) if d == axis else None for d in range(rank))
    is_valid = (selected >= 0)[broadcast_along_axis]
    gathered = self.gather(x, jnp.maximum(0, selected), axis)
    return jnp.where(is_valid, gathered, fill_value)

Args

x
tensor with any number of dimensions
mask
1D mask tensor
axis
Axis index >= 0
new_length
Maximum size of the output along axis. This must be set when jit-compiling with Jax.
fill_value
If new_length is larger than the filtered result, the remaining values will be set to fill_value.
def cast(self, x, dtype: phiml.backend._dtype.DType)
Expand source code
def cast(self, x, dtype: DType):
    """Return `x` as a JAX array of `dtype`; native arrays already of that dtype are returned as-is."""
    if self.is_tensor(x, only_native=True) and from_numpy_dtype(x.dtype) == dtype:
        return x
    target = to_numpy_dtype(dtype)
    return jnp.array(x, target)
def conv(self,
value,
kernel,
strides: Sequence[int],
out_sizes: Sequence[int],
transpose: bool)
Expand source code
def conv(self, value, kernel, strides: Sequence[int], out_sizes: Sequence[int], transpose: bool):
    """Convolve `value` with `kernel` using `jax.lax.conv_general_dilated`.

    Args:
        value: (batch, in_channel, spatial...) tensor.
        kernel: (batch or 1, out_channel, in_channel, spatial...) tensor. The leading axis
            must be 1 (one kernel shared by all batch entries) or match `value`'s batch.
        strides: One stride per spatial dimension.
        out_sizes: Requested spatial output size; used only to compute zero-padding.
        transpose: Must be False; transposed convolution is not implemented here.

    Returns:
        (batch, out_channel, spatial...) tensor.

    Only up to 3 spatial dims are supported: the dimension letters come from 'WHD'.
    NOTE(review): required padding is clamped to >= 0 and no slicing is performed,
    so if `out_sizes` is smaller than the unpadded output the result is larger than
    requested — confirm callers never ask for that.
    """
    assert not transpose, "transpose conv not yet supported for Jax"
    assert kernel.shape[0] in (1, value.shape[0])
    assert value.shape[1] == kernel.shape[2], f"value has {value.shape[1]} channels but kernel has {kernel.shape[2]}"
    assert value.ndim + 1 == kernel.ndim
    value, kernel = self.auto_cast(value, kernel, bool_to_int=True)
    ndim = len(value.shape) - 2  # Number of spatial dimensions
    assert len(strides) == ndim, f"Expected {ndim} stride values, got {len(strides)}"
    # --- Determine padding ---
    # default_size = [int(np.ceil((vs - ks + 1) / st)) for vs, ks, st in zip(value.shape[2:], kernel.shape[3:], strides)]  # size if no padding is used
    # Total padding per spatial dim: st * (os - 1) - vs + ks, clamped at 0.
    lr_padding = [max(0, st * (os - 1) - vs + ks) for st, os, vs, ks in zip(strides, out_sizes, value.shape[2:], kernel.shape[3:])]
    # Split total padding into (low, high); odd totals put the extra element on the low side.
    padding = [((p+1) // 2, p // 2) for p in lr_padding]
    # --- Run the convolution (transpose was rejected above) ---
    sp = ''.join(['WHD'[i] for i in range(len(strides))])
    # Layouts: value 'NC<sp>', kernel (without its batch axis) 'OI<sp>', output 'NC<sp>'.
    dim_num = jax.lax.conv_dimension_numbers(value.shape, kernel.shape[1:], ('NC'+sp, 'OI'+sp, 'NC'+sp))
    if kernel.shape[0] == 1:
        # One kernel shared by the whole batch. The two None args mean no lhs/rhs dilation.
        return jax.lax.conv_general_dilated(value, kernel[0], strides, padding, None, None, dim_num)
    else:
        # Per-batch kernels: convolve each batch entry with its own kernel, then re-join along axis 0.
        result = []
        for b in range(kernel.shape[0]):
            result.append(jax.lax.conv_general_dilated(value[b:b + 1], kernel[b], strides, padding, None, None, dim_num))
        return jnp.concatenate(result, 0)

Convolve value with kernel. Depending on the tensor rank, the convolution is either 1D (rank=3), 2D (rank=4) or 3D (rank=5). Higher dimensions may not be supported.

Args

value
tensor of shape (batch_size, in_channel, spatial…)
kernel
tensor of shape (batch_size or 1, out_channel, in_channel, spatial…)
strides
Convolution strides, one int for each spatial dim. For transpose, they act as upsampling factors.
out_sizes
Spatial shape of the output tensor. This determines how much zero-padding or slicing is used.
transpose
If True, performs a transposed convolution, according to PyTorch's definition.

Returns

Convolution result as tensor of shape (batch_size, out_channel, spatial…)

def copy(self, tensor, only_mutable=False)
Expand source code
def copy(self, tensor, only_mutable=False):
    """Return a fresh copy of `tensor`. `only_mutable` is ignored by this backend."""
    duplicate = jnp.array(tensor, copy=True)
    return duplicate
def custom_gradient(self,
f: Callable,
gradient: Callable,
get_external_cache: Callable = None,
on_call_skipped: Callable = None) ‑> Callable
Expand source code
def custom_gradient(self, f: Callable, gradient: Callable, get_external_cache: Callable = None, on_call_skipped: Callable = None) -> Callable:
    """Wrap `f` with `jax.custom_vjp` so its reverse-mode gradient comes from `gradient`.

    `gradient(x, y, dy)` receives the tuple of inputs `x`, the output `y` of `f`,
    and the output cotangent `dy`. It must return one gradient per input.
    `get_external_cache` and `on_call_skipped` are not used by this backend.
    """
    def fwd(*args):
        out = f(*args)
        # Save inputs and output as residuals for the backward pass.
        return out, (args, out)

    def bwd(residuals, d_out):
        args, out = residuals
        return tuple(gradient(args, out, d_out))

    wrapped = jax.custom_vjp(f)  # custom vector-Jacobian product (reverse-mode differentiation)
    wrapped.defvjp(fwd, bwd)
    return wrapped

Creates a function based on f that uses a custom gradient for backprop.

Args

f
Forward function.
gradient
Function for backprop. Will be called as gradient(x, y, dy), where x is the tuple of inputs passed to f, y is the output of f and dy is the gradient with respect to that output.

Returns

Function with similar signature and return values as f. However, the returned function does not support keyword arguments.

def disassemble(self, x) ‑> Tuple[Callable, Sequence[~TensorType]]
Expand source code
def disassemble(self, x) -> Tuple[Callable, Sequence[TensorType]]:
    """Split `x` into an assemble function and its constituent tensors.

    Dense tensors are their own single constituent. Of the sparse formats, only
    BCOO (indices, data) is supported; COO, CSR and CSC raise NotImplementedError.
    """
    if not self.is_sparse(x):
        return (lambda b, t: t), (x,)
    if isinstance(x, BCOO):
        return (lambda b, i, v: b.sparse_coo_tensor(i, v, x.shape)), (x.indices, x.data)
    # TODO: COO (row/col/data), CSR and CSC (data/indices/indptr) are not supported yet.
    raise NotImplementedError

Disassemble a (sparse) tensor into its individual constituents, such as values and indices.

Args

x
Tensor

Returns

assemble
Function assemble(backend, *constituents) that reassembles x from the constituents.
constituents
Tensors contained in x.
def divide_no_nan(self, x, y)
Expand source code
def divide_no_nan(self, x, y):
    """Compute `x / y`, returning 0 wherever `y == 0`.

    The divisor is replaced by 1 at the zero positions *before* dividing. Masking only
    the quotient (`jnp.where(y == 0, 0, x / y)`) gives the right forward value but NaN
    gradients: autodiff still differentiates the masked-out `x / 0` branch and
    multiplies its infinite derivative by 0. It would also raise ZeroDivisionError
    when both arguments are Python scalars. NaNs already present in `x` or `y` are
    propagated, not hidden (unlike `jnp.nan_to_num(x / y, nan=0)`).

    Args:
        x: Numerator.
        y: Denominator, broadcastable against `x`.

    Returns:
        `x / y` with 0 at the positions where `y == 0`.
    """
    is_zero = y == 0
    safe_y = jnp.where(is_zero, 1, y)
    return jnp.where(is_zero, 0, x / safe_y)

Computes x/y but returns 0 if y=0.

def dtype(self, array) ‑> phiml.backend._dtype.DType
Expand source code
def dtype(self, array) -> DType:
    """Return the phiml `DType` of `array`.

    Python scalars map to fixed types: bool -> BOOL, int -> INT32, float -> FLOAT64,
    complex -> COMPLEX128. `bool` is tested before `int` because it is a subclass.
    Anything else is converted with `jnp.array` if needed and its NumPy dtype is translated.
    """
    for py_type, phi_type in ((bool, BOOL), (int, INT32), (float, FLOAT64), (complex, COMPLEX128)):
        if isinstance(array, py_type):
            return phi_type
    native = array if isinstance(array, jnp.ndarray) else jnp.array(array)
    return from_numpy_dtype(native.dtype)
def eig(self, matrix: ~TensorType) ‑> ~TensorType
Expand source code
def eig(self, matrix: TensorType) -> TensorType:
    """Eigen-decomposition of (batch..., n, n) `matrix`: returns (eigenvalues, eigenvectors) from `jnp.linalg.eig`."""
    decomposition = jnp.linalg.eig(matrix)
    return decomposition

Args

matrix
(batch…, n, n)

Returns

eigenvalues
(batch…, n,)
eigenvectors
(batch…, n, n)
def eigvals(self, matrix: ~TensorType) ‑> ~TensorType
Expand source code
def eigvals(self, matrix: TensorType) -> TensorType:
    """Eigenvalues of (batch..., n, n) `matrix` as (batch..., n), via `jnp.linalg.eigvals`."""
    eigenvalues = jnp.linalg.eigvals(matrix)
    return eigenvalues

Args

matrix
(batch…, n, n)

Returns

eigenvalues as (batch…, n,)

def expand_dims(self, a, axis=0, number=1)
Expand source code
def expand_dims(self, a, axis=0, number=1):
    """Insert `number` singleton axes into `a`, each one inserted at position `axis`."""
    remaining = number
    while remaining > 0:
        a = jnp.expand_dims(a, axis)
        remaining -= 1
    return a
def fft(self, x, axes: tuple | list)
Expand source code
def fft(self, x, axes: Union[tuple, list]):
    """FFT of `x` along `axes`, after converting `x` to the backend's complex dtype.

    Empty `axes` returns the complex-converted input unchanged. The result is cast
    back to the dtype of the converted input.
    """
    complex_x = self.to_complex(x)
    if not axes:
        return complex_x
    if len(axes) == 1:
        transformed = jnp.fft.fft(complex_x, axis=axes[0])
    elif len(axes) == 2:
        transformed = jnp.fft.fft2(complex_x, axes=axes)
    else:
        transformed = jnp.fft.fftn(complex_x, axes=axes)
    return transformed.astype(complex_x.dtype)

Computes the n-dimensional FFT along the given axes.

Args

x
tensor of dimension 3 or higher
axes
Along which axes to perform the FFT

Returns

Complex tensor k

def from_dlpack(self, external_array)
Expand source code
def from_dlpack(self, external_array):
    """Import `external_array` into JAX through the DLPack protocol."""
    import jax.dlpack as jax_dlpack
    return jax_dlpack.from_dlpack(external_array)
def from_external_array_dlpack(self, external_array)
Expand source code
def from_external_array_dlpack(self, external_array):
    """Import an array from another framework into JAX through the DLPack protocol."""
    import jax.dlpack as jax_dlpack
    return jax_dlpack.from_dlpack(external_array)
def gamma_inc_l(self, a, x)
Expand source code
def gamma_inc_l(self, a, x):
    """Regularized lower incomplete gamma function P(a, x)."""
    lower = scipy.special.gammainc(a, x)
    return lower

Regularized lower incomplete gamma function.

def gamma_inc_u(self, a, x)
Expand source code
def gamma_inc_u(self, a, x):
    """Regularized upper incomplete gamma function Q(a, x)."""
    upper = scipy.special.gammaincc(a, x)
    return upper

Regularized upper incomplete gamma function.

def gather(self, values, indices, axis: int)
Expand source code
def gather(self, values, indices, axis: int):
    """Select `indices` from `values` along `axis`; every other axis is taken whole."""
    index = tuple(indices if dim == axis else slice(None) for dim in range(self.ndims(values)))
    return values[index]

Gathers values from the tensor values at locations indices.

Args

values
tensor
indices
1D tensor
axis
Axis along which to gather slices

Returns

tensor, with size along axis being the length of indices

def get_device(self, tensor: ~TensorType) ‑> phiml.backend._backend.ComputeDevice
Expand source code
def get_device(self, tensor: TensorType) -> ComputeDevice:
    """Look up the ComputeDevice holding `tensor`.

    Tries `tensor.devices()` (first entry) first, then `tensor.device()`. Raises
    AssertionError if neither attribute exists.
    """
    if hasattr(tensor, 'devices'):
        first_device = next(iter(tensor.devices()))
        return self.get_device_by_ref(first_device)
    if not hasattr(tensor, 'device'):
        raise AssertionError(f"tensor {type(tensor)} has no device attribute")
    return self.get_device_by_ref(tensor.device())

Returns the device tensor is located on.

def get_diagonal(self, matrices, offset=0)
Expand source code
def get_diagonal(self, matrices, offset=0):
    """Extract the `offset` diagonal of (batch, rows, cols, channels) `matrices`.

    Returns (batch, diagonal_length, channels).
    """
    # jnp.diagonal appends the diagonal as the last axis: (batch, channels, diagonal).
    diag = jnp.diagonal(matrices, offset=offset, axis1=1, axis2=2)
    return jnp.transpose(diag, (0, 2, 1))

Args

matrices
(batch, rows, cols, channels)
offset
0=diagonal, positive=above diagonal, negative=below diagonal

Returns

diagonal
(batch, diagonal_length, channels), where diagonal_length is at most min(rows, cols)
def get_sparse_format(self, x) ‑> str
Expand source code
def get_sparse_format(self, x) -> str:
    """Return 'coo', 'csr' or 'csc' for the JAX sparse types (exact type match), else 'dense'."""
    kind = type(x)
    if kind is COO or kind is BCOO:
        return 'coo'
    if kind is CSR:
        return 'csr'
    if kind is CSC:
        return 'csc'
    return 'dense'

Returns lower-case format string, such as 'coo', 'csr', 'csc'

def histogram1d(self, values, weights, bin_edges)
Expand source code
def histogram1d(self, values, weights, bin_edges):
    """Weighted 1D histogram per batch entry, vectorized with `jax.vmap` over axis 0."""
    def _single(v, w, edges):
        counts, _edges = jnp.histogram(v, edges, weights=w)
        return counts

    batched = jax.vmap(_single)
    return batched(values, weights, bin_edges)

Args

values
(batch, values)
bin_edges
(batch, edges)
weights
(batch, values)

Returns

(batch, edges) with dtype matching weights

def ifft(self, k, axes: tuple | list)
Expand source code
def ifft(self, k, axes: Union[tuple, list]):
    """Inverse FFT of `k` along `axes`.

    Mirrors `fft`: the input is first converted with `self.to_complex`. A real-valued
    `k` therefore yields a complex result. Previously the final `.astype(k.dtype)`
    cast the transform back to the real input dtype and dropped its imaginary part.

    Args:
        k: Tensor to transform.
        axes: Axes along which to perform the inverse FFT. If empty, the
            complex-converted input is returned unchanged.

    Returns:
        Complex tensor with the dtype of the complex-converted input.
    """
    k = self.to_complex(k)
    if not axes:
        return k
    if len(axes) == 1:
        return jnp.fft.ifft(k, axis=axes[0]).astype(k.dtype)
    elif len(axes) == 2:
        return jnp.fft.ifft2(k, axes=axes).astype(k.dtype)
    else:
        return jnp.fft.ifftn(k, axes=axes).astype(k.dtype)

Computes the n-dimensional inverse FFT along the given axes.

Args

k
tensor of dimension 3 or higher
axes
Along which axes to perform the inverse FFT

Returns

Complex tensor x

def is_available(self, tensor)
Expand source code
def is_available(self, tensor):
    """True unless `tensor` (or the primal of a JVPTracer) is an abstract JAX Tracer."""
    unwrapped = tensor.primal if isinstance(tensor, JVPTracer) else tensor
    return not isinstance(unwrapped, Tracer)

Tests if the value of the tensor is known and can be read at this point. If true, numpy(tensor) must return a valid NumPy representation of the value.

Tensors are typically available when the backend operates in eager mode.

Args

tensor
backend-compatible tensor

Returns

bool

def is_module(self, obj)
Expand source code
def is_module(self, obj):
    """The JAX backend owns no module types, so this is always False."""
    return False

Tests if obj is of a type that is specific to this backend, e.g. a neural network. If True, this backend will be chosen for operations involving obj.

See Also: Backend.is_tensor().

Args

obj
Object to test.
def is_sparse(self, x) ‑> bool
Expand source code
def is_sparse(self, x) -> bool:
    """True if `x` is one of the JAX sparse matrix types COO, BCOO, CSR or CSC."""
    sparse_types = (COO, BCOO, CSR, CSC)
    return isinstance(x, sparse_types)

Args

x
Tensor native to this Backend.
def is_tensor(self, x, only_native=False)
Expand source code
def is_tensor(self, x, only_native=False):
    """Whether `x` is a tensor for this backend.

    Native: JAX arrays, JAX bool scalars and JAX sparse matrices. With
    `only_native=False`, NumPy arrays and bools, Python numbers and bools, and
    tuples/lists whose items all qualify are also accepted.
    """
    # NumPy arrays/bools are excluded from the native case because they can pass
    # the JAX isinstance checks (per the original note); they are handled below.
    if isinstance(x, jnp.ndarray) and not isinstance(x, np.ndarray):
        return True
    if isinstance(x, jnp.bool_) and not isinstance(x, np.bool_):
        return True
    if self.is_sparse(x):
        return True
    # --- Above considered native ---
    if only_native:
        return False
    # --- Non-native types ---
    if isinstance(x, (np.ndarray, np.bool_, numbers.Number, bool)):
        return True
    if isinstance(x, (tuple, list)):
        return all([self.is_tensor(item, False) for item in x])
    return False

An object is considered a native tensor by a backend if no internal conversion is required by backend methods. An object is considered a tensor (native or otherwise) by a backend if it is not a struct (e.g. tuple, list) and all methods of the backend accept it as a tensor argument.

If True, this backend will be chosen for operations involving x.

See Also: Backend.is_module().

Args

x
object to check
only_native
If True, only accepts true native tensor representations, not Python numbers or others that are also supported as tensors (Default value = False)

Returns

bool
whether x is considered a tensor by this backend
def jacobian(self, f, wrt: tuple | list, get_output: bool, is_f_scalar: bool)
Expand source code
def jacobian(self, f, wrt: Union[tuple, list], get_output: bool, is_f_scalar: bool):
    """Build a gradient function for `f`, which returns `(loss, output)`.

    Bool/int arguments at positions in `wrt` are converted with `to_float` so JAX can
    differentiate them. Gradients are complex-conjugated. With `get_output`, the
    returned function yields `(*output, *grads)`; otherwise only the grads tuple.
    `is_f_scalar` is not used by this backend.
    """
    def prepare(args):
        return [self.to_float(arg) if self.dtype(arg).kind in (bool, int) and i in wrt else arg for i, arg in enumerate(args)]

    if get_output:
        value_and_grad_f = jax.value_and_grad(f, argnums=wrt, has_aux=True)

        @wraps(f)
        def unwrap_outputs(*args):
            (_, outputs), grads = value_and_grad_f(*prepare(args))
            return (*outputs, *map(jnp.conjugate, grads))
        return unwrap_outputs

    @wraps(f)
    def loss_only(*args):
        loss, _ = f(*args)
        return loss

    grad_f = jax.grad(loss_only, argnums=wrt, has_aux=False)

    @wraps(f)
    def call_jax_grad(*args):
        grads = grad_f(*prepare(args))
        return tuple(jnp.conjugate(g) for g in grads)
    return call_jax_grad

Args

f
Function to differentiate. Returns a tuple containing (reduced_loss, output)
wrt
Argument indices for which to compute the gradient.
get_output
Whether the derivative function should return the output of f in addition to the gradient.
is_f_scalar
Whether f is guaranteed to return a scalar output.

Returns

A function g with the same arguments as f. If get_output=True, g returns a tuple containing the outputs of f followed by the gradients. The gradients retain the dimensions of reduced_loss in order as outer (first) dimensions.

def jit_compile(self, f: Callable) ‑> Callable
Expand source code
def jit_compile(self, f: Callable) -> Callable:
    """Jit-compile `f` for the default device and return a logging wrapper named 'Jax-Jit(<name>)'."""
    # NOTE(review): the `device=` argument of jax.jit is deprecated in recent JAX
    # releases — confirm it is still accepted by the pinned JAX version.
    jit_f = jax.jit(f, device=self._default_device.ref)

    def run_jit_f(*args):
        ML_LOGGER.debug(f"JaxBackend: running jit-compiled '{f.__name__}' with shapes {[self.shape(arg) for arg in args]} and dtypes {[self.dtype(arg) for arg in args]}")
        return self.as_registered.call(jit_f, *args, name=f"run jit-compiled '{f.__name__}'")

    run_jit_f.__name__ = f"Jax-Jit({f.__name__})"
    return run_jit_f
def linspace(self, start, stop, number)
Expand source code
def linspace(self, start, stop, number):
    """`number` evenly spaced values from `start` to `stop` (inclusive), in the backend float dtype, on the default device."""
    self._check_float64()
    float_dtype = to_numpy_dtype(self.float_type)
    values = jnp.linspace(start, stop, number, dtype=float_dtype)
    return jax.device_put(values, self._default_device.ref)
def linspace_without_last(self, start, stop, number)
Expand source code
def linspace_without_last(self, start, stop, number):
    """Like `linspace` but excluding `stop` (`endpoint=False`)."""
    self._check_float64()
    float_dtype = to_numpy_dtype(self.float_type)
    values = jnp.linspace(start, stop, number, endpoint=False, dtype=float_dtype)
    return jax.device_put(values, self._default_device.ref)
def log_gamma(self, x)
Expand source code
def log_gamma(self, x):
    """Log-gamma of `x`, computed after converting `x` to the backend float dtype."""
    as_float = self.to_float(x)
    return jax.lax.lgamma(as_float)
def matrix_rank_dense(self, matrix, hermitian=False)
Expand source code
def matrix_rank_dense(self, matrix, hermitian=False):
    """Rank of dense `matrix` via `jnp.linalg.matrix_rank`.

    `hermitian` is only forwarded in the NumPy fallback. That fallback is used when
    JAX raises the specific TypeError "array should have 2 or fewer dimensions".
    """
    try:
        return jnp.linalg.matrix_rank(matrix)
    except TypeError as err:
        is_known_jax_bug = err.args[0] == "array should have 2 or fewer dimensions"  # this is a Jax bug on some distributions/versions
        if not is_known_jax_bug:
            raise err
        warnings.warn("You are using a broken version of JAX. matrix_rank for dense matrices will fall back to NumPy.")
        return self.as_tensor(NUMPY.matrix_rank_dense(self.numpy(matrix), hermitian=hermitian))

Args

matrix
Dense matrix of shape (batch, rows, cols)
hermitian
Whether all matrices are guaranteed to be hermitian.
def matrix_solve_least_squares(self, matrix: ~TensorType, rhs: ~TensorType) ‑> Tuple[~TensorType, ~TensorType, ~TensorType, ~TensorType]
Expand source code
def matrix_solve_least_squares(self, matrix: TensorType, rhs: TensorType) -> Tuple[TensorType, TensorType, TensorType, TensorType]:
    """Batched least-squares solve; returns (solution, residuals, rank, singular_values) from `lstsq_batched`."""
    sol, res, rnk, sv = lstsq_batched(matrix, rhs)
    return sol, res, rnk, sv

Args

matrix
Shape (batch, vec, constraints)
rhs
Shape (batch, vec, batch_per_matrix)

Returns

solution
Solution vector of Shape (batch, constraints, batch_per_matrix)
residuals
Optional, can be None
rank
Optional, can be None
singular_values
Optional, can be None
def max(self, x, axis=None, keepdims=False)
Expand source code
def max(self, x, axis=None, keepdims=False):
    """Maximum of `x` over `axis` (all axes if None)."""
    return jnp.max(x, axis=axis, keepdims=keepdims)
def mean(self, value, axis=None, keepdims=False)
Expand source code
def mean(self, value, axis=None, keepdims=False):
    """Arithmetic mean of `value` over `axis` (all axes if None)."""
    return jnp.mean(value, axis=axis, keepdims=keepdims)
def meshgrid(self, *coordinates)
Expand source code
def meshgrid(self, *coordinates):
    """'ij'-indexed meshgrid of the given 1D coordinates, each grid placed on the default device."""
    self._check_float64()
    axes_values = [self.as_tensor(c) for c in coordinates]
    grids = jnp.meshgrid(*axes_values, indexing='ij')
    return [jax.device_put(grid, self._default_device.ref) for grid in grids]
def min(self, x, axis=None, keepdims=False)
Expand source code
def min(self, x, axis=None, keepdims=False):
    """Minimum of `x` along `axis` (all axes when `axis` is None), computed by `jnp.min`."""
    return jnp.min(x, axis=axis, keepdims=keepdims)
def mul(self, a, b)
Expand source code
def mul(self, a, b):
    """
    Element-wise product of `a` and `b`; delegates entirely to the generic `Backend.mul`.

    TODO: sparse operands (e.g. scipy.sparse via `a.multiply(b)`) are not special-cased yet.
    """
    return Backend.mul(self, a, b)
def mul_matrix_batched_vector(self, A, b)
Expand source code
def mul_matrix_batched_vector(self, A, b):
    """
    Multiply a single matrix with a batch of vectors.

    Args:
        A: Dense array or `BCOO` sparse matrix. NOTE(review): assumed to be one 2D matrix of shape
            (rows, cols), matching the original per-row `A.dot(b[i])` semantics.
        b: Batch of vectors of shape (batch, cols).

    Returns:
        Array of shape (batch, rows) whose i-th row equals `A @ b[i]`.
    """
    from jax.experimental.sparse import BCOO
    if isinstance(A, BCOO):
        return (A @ b.T).T
    # One batched matmul instead of a Python loop of `A.dot(b[i])`: this avoids tracing one op per
    # batch entry under jit and also handles an empty batch (jnp.stack([]) would raise).
    return jnp.matmul(b, jnp.transpose(A))
def nn_library(self)
Expand source code
def nn_library(self):
    """Return the `stax_nets` sub-module, imported lazily on first call."""
    from . import stax_nets as nets_module
    return nets_module
def nonzero(self, values, length=None, fill_value=-1)
Expand source code
def nonzero(self, values, length=None, fill_value=-1):
    """
    Multi-indices of all non-zero entries of `values`.

    Args:
        values: Tensor with only spatial dimensions.
        length: Optional fixed number of result rows, forwarded as `size` to `jnp.nonzero`
            (the result is padded with `fill_value` or trimmed).
        fill_value: Index value used for padded rows.

    Returns:
        Integer tensor of shape (nnz or length, values.ndim).
    """
    per_axis_indices = jnp.nonzero(values, size=length, fill_value=fill_value)
    return jnp.stack(per_axis_indices, axis=-1)

Args

values
Tensor with only spatial dimensions
length
(Optional) Length of the resulting array. If specified, the result array will be padded with fill_value or trimmed.

Returns

non-zero multi-indices as tensor of shape (nnz/length, vector)

def numpy(self, tensor)
Expand source code
def numpy(self, tensor):
    """
    Convert a backend tensor (dense or sparse) to its NumPy representation.

    A `JVPTracer` is replaced by its `primal` value first. Sparse tensors are split with
    `self.disassemble`, each part is converted recursively, and the parts are reassembled with
    the NumPy backend. Dense values are converted with `np.array`.
    """
    if isinstance(tensor, JVPTracer):
        tensor = tensor.primal
    if not self.is_sparse(tensor):
        return np.array(tensor)
    assemble, parts = self.disassemble(tensor)
    numpy_parts = [self.numpy(part) for part in parts]
    return assemble(NUMPY, *numpy_parts)

Returns a NumPy representation of the given tensor. If tensor is already a NumPy array, it is returned without modification.

This method raises an error if the value of the tensor is not known at this point, e.g. because it represents a node in a graph. Use is_available(tensor) to check if the value can be represented as a NumPy array.

Args

tensor
backend-compatible tensor or sparse tensor

Returns

NumPy representation of the values stored in the tensor

def numpy_call(self, f, output_shapes, output_dtypes, *args, **aux_args)
Expand source code
def numpy_call(self, f, output_shapes, output_dtypes, *args, **aux_args):
    """
    Call the NumPy function `f` on tensor arguments, also from within traced code.

    If every positional argument has a concrete value (`self.is_available`), the arguments are
    converted with `self.numpy`, `f` is called directly, and its output structure is converted
    back with `self.as_tensor`.

    Otherwise `f` is staged as a host callback. `jax.pure_callback` is used when this JAX version
    provides it; otherwise the legacy `jax.experimental.host_callback.call` is used. In this path
    the output shapes/dtypes must be declared up front via `output_shapes` / `output_dtypes`.

    Per the API docs, this call is not differentiable.

    Args:
        f: Function operating on NumPy arrays.
        output_shapes: Shape tuple or structure of shape tuples for the outputs of `f`.
        output_dtypes: DType or structure of DTypes for the outputs of `f`.
        *args: Tensor arguments passed positionally to `f`.
        **aux_args: Keyword arguments passed to `f` without conversion.
    """
    if all([self.is_available(arg) for arg in args]):
        # Eager path: all inputs are concrete, so just run `f` on NumPy copies.
        args = [self.numpy(arg) for arg in args]
        output = f(*args, **aux_args)
        result = map_structure(self.as_tensor, output)
        return result
    # Minimal shape/dtype spec object for the callback APIs.
    # NOTE(review): presumably JAX only reads `.shape` and `.dtype` (duck-typed like jax.ShapeDtypeStruct).
    @dataclasses.dataclass
    class OutputTensor:
        shape: Tuple[int]
        dtype: np.dtype
    # The lambda receives (dtype, shape) pairs because `output_dtypes` is passed to map_structure first.
    output_specs = map_structure(lambda t, s: OutputTensor(s, to_numpy_dtype(t)), output_dtypes, output_shapes)
    if hasattr(jax, 'pure_callback'):
        # pure_callback passes the operands positionally; keyword args are bound via closure.
        def aux_f(*args):
            return f(*args, **aux_args)
        return jax.pure_callback(aux_f, output_specs, *args)
    else:
        # Legacy host_callback passes all operands as one object; unpack it when it is a tuple.
        def aux_f(args):
            if isinstance(args, tuple):
                return f(*args, **aux_args)
            else:
                return f(args, **aux_args)
        from jax.experimental.host_callback import call
        return call(aux_f, args, result_shape=output_specs)

This call can be used in jit-compiled code but is not differentiable.

Args

f
Function operating on numpy arrays.
output_shapes
Single shape tuple or tuple of shapes declaring the shapes of the tensors returned by f.
output_dtypes
Single DType or tuple of DTypes declaring the dtypes of the tensors returned by f.
*args
Tensor arguments to be converted to NumPy arrays and then passed to f.
**aux_args
Keyword arguments to be passed to f without conversion.

Returns

Returned arrays of f converted to tensors.

def ones(self, shape, dtype: phiml.backend._dtype.DType = None)
Expand source code
def ones(self, shape, dtype: DType = None):
    """Array of ones with `shape`, using `dtype` or the backend's float type, placed on the default device."""
    self._check_float64()
    np_dtype = to_numpy_dtype(dtype or self.float_type)
    one_array = jnp.ones(shape, dtype=np_dtype)
    return jax.device_put(one_array, self._default_device.ref)
def ones_like(self, tensor)
Expand source code
def ones_like(self, tensor):
    """Array of ones with the shape and dtype of `tensor`, placed on the default device."""
    filled = jnp.ones_like(tensor)
    return jax.device_put(filled, self._default_device.ref)
def pad(self, value, pad_width, mode='constant', constant_values=0)
Expand source code
def pad(self, value, pad_width, mode='constant', constant_values=0):
    """
    Pad `value` with `jnp.pad`.

    Args:
        value: Tensor to pad.
        pad_width: Per-axis `[lower, upper]` pad amounts.
        mode: 'constant', 'symmetric', 'periodic', 'reflect' or 'boundary'.
        constant_values: Fill value for mode 'constant', cast to `value.dtype`.

    Returns:
        Padded tensor.
    """
    assert mode in ('constant', 'symmetric', 'periodic', 'reflect', 'boundary'), mode
    if mode != 'constant':
        # Translate phiml mode names to jnp.pad names; 'symmetric' and 'reflect' pass through unchanged.
        jnp_mode = {'periodic': 'wrap', 'boundary': 'edge'}.get(mode, mode)
        return jnp.pad(value, pad_width, jnp_mode)
    fill = jnp.array(constant_values, dtype=value.dtype)
    return jnp.pad(value, pad_width, 'constant', constant_values=fill)

Pad a tensor with values as specified by mode and constant_values.

If the mode is not supported, returns NotImplemented.

Args

value
tensor
pad_width
2D tensor specifying the number of values padded to the edges of each axis in the form [[axis 0 lower, axis 0 upper], …] including batch and component axes.
mode
One of 'constant', 'boundary', 'periodic', 'symmetric', 'reflect'.
constant_values
Scalar value used for out-of-bounds points if mode='constant'. Must be a Python primitive type or scalar tensor.
mode
str: (Default value = 'constant')

Returns

padded tensor or NotImplemented

def prefers_channels_last(self) ‑> bool
Expand source code
def prefers_channels_last(self) -> bool:
    """This backend always reports a preference for channels-last layouts."""
    channels_last = True
    return channels_last
def prod(self, value, axis=None)
Expand source code
def prod(self, value, axis=None):
    """
    Product of `value` along `axis`.

    Non-JAX inputs are converted with `jnp.array` first. Boolean input is reduced with
    `jnp.all`, which is the logical product.
    """
    array = value if isinstance(value, jnp.ndarray) else jnp.array(value)
    reducer = jnp.all if array.dtype == bool else jnp.prod
    return reducer(array, axis=axis)
def quantile(self, x, quantiles)
Expand source code
def quantile(self, x, quantiles):
    """Compute `quantiles` over the last axis of `x`; result shape is (len(quantiles), *x.shape[:-1])."""
    return jnp.quantile(x, q=quantiles, axis=-1)

Reduces the last / inner axis of x.

Args

x
Tensor
quantiles
List or 1D tensor of quantiles to compute.

Returns

Tensor with shape (quantiles, *x.shape[:-1])

def random_normal(self, shape, dtype: phiml.backend._dtype.DType)
Expand source code
def random_normal(self, shape, dtype: DType):
    """
    Sample from the standard normal distribution (mean 0, std 1).

    Advances the backend key by one split. A falsy `dtype` selects `self.float_type`.
    The samples are placed on the backend's default device.
    """
    self._check_float64()
    self.rnd_key, sample_key = jax.random.split(self.rnd_key)
    target_dtype = to_numpy_dtype(dtype or self.float_type)
    samples = random.normal(sample_key, shape, dtype=target_dtype)
    return jax.device_put(samples, self._default_device.ref)

Float tensor of selected precision containing random values sampled from a normal distribution with mean 0 and std 1.

def random_permutations(self, permutations: int, n: int)
Expand source code
def random_permutations(self, permutations: int, n: int):
    """
    Generate `permutations` independent random shuffles of the integers 0 .. n-1.

    The backend key is split into one replacement key for the backend plus one key per
    permutation, so every row is sampled independently.

    Returns:
        Integer tensor of shape (permutations, n) placed on the default device.
    """
    keys = jax.random.split(self.rnd_key, permutations + 1)
    self.rnd_key = keys[0]
    # One key per row: reusing a single subkey (as before) made every row the same permutation.
    rows = [jax.random.permutation(key, n) for key in keys[1:]]
    result = jnp.stack(rows)
    return jax.device_put(result, self._default_device.ref)

Generate `permutations` stacked arrays, each a random shuffle of the integers 0, 1, …, n-1.

def random_uniform(self, shape, low, high, dtype: phiml.backend._dtype.DType | None)
Expand source code
def random_uniform(self, shape, low, high, dtype: Union[DType, None]):
    """
    Sample uniformly distributed values in [low, high).

    Args:
        shape: Shape of the result.
        low: Inclusive lower bound. For complex dtypes, `low.real` / `low.imag` bound the
            real / imaginary parts.
        high: Exclusive upper bound. For complex dtypes, `high.real` / `high.imag` bound the
            real / imaginary parts.
        dtype: Result dtype; a falsy value selects `self.float_type`.

    Returns:
        Tensor placed on the backend's default device.

    Raises:
        ValueError: If `dtype.kind` is not float, complex or int.
    """
    self._check_float64()
    self.rnd_key, subkey = jax.random.split(self.rnd_key)

    dtype = dtype or self.float_type
    jdt = to_numpy_dtype(dtype)
    if dtype.kind == float:
        tensor = random.uniform(subkey, shape, minval=low, maxval=high, dtype=jdt)
    elif dtype.kind == complex:
        # Separate keys per part. Sharing `subkey` made the imaginary part an affine copy of the real part.
        real_key, imag_key = jax.random.split(subkey)
        part_dtype = to_numpy_dtype(DType.by_precision(float, dtype.precision))
        real = random.uniform(real_key, shape, minval=low.real, maxval=high.real, dtype=part_dtype)
        imag = random.uniform(imag_key, shape, minval=low.imag, maxval=high.imag, dtype=part_dtype)
        tensor = real + 1j * imag  # falls through to device_put like the other branches
    elif dtype.kind == int:
        tensor = random.randint(subkey, shape, low, high, dtype=jdt)
        if tensor.dtype != jdt:
            warnings.warn(f"Jax failed to sample random integers with dtype {dtype}, returned {tensor.dtype} instead.", RuntimeWarning)
    else:
        raise ValueError(dtype)
    return jax.device_put(tensor, self._default_device.ref)

Tensor of the selected dtype (the backend float type by default) containing random values sampled uniformly from the range [low, high).

def range(self, start, limit=None, delta=1, dtype: phiml.backend._dtype.DType = int32)
Expand source code
def range(self, start, limit=None, delta=1, dtype: DType = INT32):
    """
    Evenly spaced values like Python's `range`, returned as a JAX array of `dtype`.

    With only `start` given, it is treated as the (exclusive) stop and counting starts at 0.
    """
    lo, hi = (0, start) if limit is None else (start, limit)
    return jnp.arange(lo, hi, delta, dtype=to_numpy_dtype(dtype))
def ravel_multi_index(self, multi_index, shape, mode: str | int = 'undefined')
Expand source code
def ravel_multi_index(self, multi_index, shape, mode: Union[str, int] = 'undefined'):
    """
    Convert multi-indices of shape (batch..., index_dim) to flat indices of shape (batch...).

    If `shape` is not concretely available, the generic `Backend` implementation is used.
    Otherwise, 'undefined' and 'clamp' clip out-of-range indices and 'periodic' wraps them.
    An int `mode` is written for every index that lies outside `shape`.
    """
    if not self.is_available(shape):
        return Backend.ravel_multi_index(self, multi_index, shape, mode)
    use_fill_value = isinstance(mode, int)
    if not use_fill_value:
        mode = {'undefined': 'clip', 'periodic': 'wrap', 'clamp': 'clip'}[mode]
    # jnp.ravel_multi_index expects the index-vector axis first.
    idx_first = jnp.moveaxis(multi_index, -1, 0)
    result = jnp.ravel_multi_index(idx_first, shape, mode='wrap' if use_fill_value else mode)
    if not use_fill_value:
        return result
    upper = jnp.asarray(shape, dtype=multi_index.dtype)
    outside = self.any((multi_index < 0) | (multi_index >= upper), -1)
    return self.where(outside, mode, result)

Args

multi_index
(batch…, index_dim)
shape
1D tensor or tuple/list
mode
'undefined', 'periodic', 'clamp' or an int to use for all invalid indices.

Returns

Integer tensor of shape (batch…) of same dtype as multi_index.

def repeat(self, x, repeats, axis: int, new_length=None)
Expand source code
def repeat(self, x, repeats, axis: int, new_length=None):
    """
    Repeat each element of `x` along `axis`, `repeats[i]` times for element i.

    `new_length` is forwarded as `total_repeat_length`, which fixes the output length
    (needed for jit compilation).
    """
    counts = self.as_tensor(repeats)
    return jnp.repeat(x, counts, axis=axis, total_repeat_length=new_length)

Repeats the elements along axis repeats times.

Args

x
Tensor
repeats
How often to repeat each element. 1D tensor of length x.shape[axis]
axis
Which axis to repeat elements along
new_length
Set the length of axis after repeating. This is required for jit compilation with Jax.

Returns

repeated Tensor

def requires_fixed_shapes_when_tracing(self) ‑> bool
Expand source code
def requires_fixed_shapes_when_tracing(self) -> bool:
    """This backend always reports that traced functions need fixed shapes."""
    fixed_shapes_required = True
    return fixed_shapes_required
def reshape(self, value, shape)
Expand source code
def reshape(self, value, shape):
    """Reshape `value` to `shape` via `jnp.reshape`."""
    reshaped = jnp.reshape(value, shape)
    return reshaped
def scatter(self, base_grid, indices, values, mode: str)
Expand source code
def scatter(self, base_grid, indices, values, mode: str):
    """
    Batched n-dimensional scatter of `values` into `base_grid` at `indices`.

    Args:
        base_grid: Tensor of shape (batch_size, spatial..., channels) receiving the updates.
        indices: Integer tensor of shape (batch_size or 1, update_count, index_vector).
        values: Tensor whose per-batch slice is (update_count, channels), per the dimension
            numbers below. NOTE(review): the API docs also allow `update_count or 1` and
            `channels or 1`, but only the batch dimension is tiled here — confirm callers
            broadcast the other dimensions beforehand.
        mode: One of 'add', 'update', 'max', 'min'; selects the matching `jax.lax.scatter*` primitive.

    Returns:
        Copy of `base_grid` with `values` scattered at `indices`. If the combined kind of the
        inputs is bool, the result is cast back to bool.

    Raises:
        KeyError: If `mode` is not one of the supported names.
    """
    # Remember the result kind before auto_cast possibly converts bools to ints.
    out_kind = self.combine_types(self.dtype(base_grid), self.dtype(values)).kind
    base_grid, values = self.auto_cast(base_grid, values, bool_to_int=True)
    batch_size = combined_dim(combined_dim(indices.shape[0], values.shape[0]), base_grid.shape[0])
    spatial_dims = tuple(range(base_grid.ndim - 2))
    # Dimension numbers describe ONE batch entry (the batch dim is removed by vectorized_call).
    dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(1,),  # channel dim of updates (batch dim removed)
                                            inserted_window_dims=spatial_dims,  # operand dims with size-1 windows absent from updates: each update writes a single spatial cell
                                            scatter_dims_to_operand_dims=spatial_dims)  # spatial dims of base_grid (batch dim removed)
    scatter = {'add': jax.lax.scatter_add, 'update': jax.lax.scatter, 'max': jax.lax.scatter_max, 'min': jax.lax.scatter_min}[mode]
    # Unbatched scatter for a single batch entry; mapped over the batch dim below.
    def scatter_single(base_grid, indices, values):
        return scatter(base_grid, indices, values, dnums)
    # Broadcast a size-1 batch dim of indices / values to the full batch size.
    if self.staticshape(indices)[0] == 1:
        indices = self.tile(indices, [batch_size, 1, 1])
    if self.staticshape(values)[0] == 1:
        values = self.tile(values, [batch_size, 1, 1])
    result = self.vectorized_call(scatter_single, base_grid, indices, values)
    # Undo the bool->int conversion from auto_cast.
    if self.dtype(result).kind != out_kind:
        if out_kind == bool:
            result = self.cast(result, BOOL)
    return result

Batched n-dimensional scatter.

Args

base_grid
Tensor into which scatter values are inserted at indices. Tensor of shape (batch_size, spatial…, channels)
indices
Tensor of shape (batch_size or 1, update_count, index_vector)
values
Values to scatter at indices. Tensor of shape (batch_size or 1, update_count or 1, channels or 1)
mode
One of ('update', 'add', 'max', 'min')

Returns

Copy of base_grid with values at indices updated by values.

def searchsorted(self, sorted_sequence, search_values, side: str, dtype=int32)
Expand source code
def searchsorted(self, sorted_sequence, search_values, side: str, dtype=INT32):
    """
    Insertion indices of `search_values` into `sorted_sequence`, cast to `dtype`.

    For a 1D sequence, `jnp.searchsorted` is used directly. Higher-rank inputs are handled by
    mapping this method with `jax.vmap` over the leading axis of both arguments.
    """
    if self.ndims(sorted_sequence) != 1:
        per_row = partial(self.searchsorted, side=side, dtype=dtype)
        return jax.vmap(per_row)(sorted_sequence, search_values)
    positions = jnp.searchsorted(sorted_sequence, search_values, side=side)
    return positions.astype(to_numpy_dtype(dtype))
def seed(self, seed: int)
Expand source code
def seed(self, seed: int):
    """Reset the backend's random state to `jax.random.PRNGKey(seed)`."""
    new_key = jax.random.PRNGKey(seed)
    self.rnd_key = new_key
def sizeof(self, tensor) ‑> int
Expand source code
def sizeof(self, tensor) -> int:
    """Size of `tensor` in bytes, as reported by its `nbytes` attribute."""
    byte_count = tensor.nbytes
    return byte_count

Returns the size in bytes

def solve_triangular_dense(self, matrix, rhs, lower: bool, unit_diagonal: bool)
Expand source code
def solve_triangular_dense(self, matrix, rhs, lower: bool, unit_diagonal: bool):
    """
    Solve `matrix @ x = rhs` for x with a triangular `matrix`.

    Integer/bool inputs are promoted by `auto_cast` before `jax.lax.linalg.triangular_solve`
    is called with `left_side=True`.
    NOTE(review): with left_side=True JAX expects rhs of shape (..., m, k); confirm how callers
    shape `rhs` (the API docs list (batch_size, cols)).
    """
    a, b = self.auto_cast(matrix, rhs, int_to_float=True, bool_to_int=True)
    return jax.lax.linalg.triangular_solve(a, b, left_side=True, lower=lower, unit_diagonal=unit_diagonal)

Args

matrix
(batch_size, rows, cols)
rhs
(batch_size, cols)

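lower
Whether matrix is lower triangular (otherwise it is treated as upper triangular).
unit_diagonal
Whether the diagonal of matrix is assumed to consist of ones (its stored diagonal values are not used).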

Returns

(batch_size, cols)

def sort(self, x, axis=-1)
Expand source code
def sort(self, x, axis=-1):
    """Sort `x` along `axis` (last axis by default) in ascending order via `jnp.sort`."""
    return jnp.sort(x, axis=axis)
def sparse_coo_tensor(self, indices: tuple | list, values, shape: tuple)
Expand source code
def sparse_coo_tensor(self, indices: Union[tuple, list], values, shape: tuple):
    """
    Create a JAX `BCOO` sparse matrix from coordinate-list data.

    Args:
        indices: Index array of shape (nnz, dims).
        values: Values matching `indices`.
        shape: Dense shape of the sparse matrix.
    """
    from jax.experimental.sparse import BCOO
    data_and_indices = (values, indices)
    return BCOO(data_and_indices, shape=shape)

Create a sparse matrix in coordinate list (COO) format.

Optional feature.

See Also: Backend.csr_matrix(), Backend.csc_matrix().

Args

indices
2D tensor of shape (nnz, dims).
values
1D values tensor matching indices
shape
Shape of the sparse matrix

Returns

Native representation of the sparse matrix

def std(self, x, axis=None, keepdims=False)
Expand source code
def std(self, x, axis=None, keepdims=False):
    """Standard deviation of `x` along `axis` (all axes when `axis` is None), via `jnp.std`."""
    return jnp.std(x, axis=axis, keepdims=keepdims)
def sum(self, value, axis=None, keepdims=False)
Expand source code
def sum(self, value, axis=None, keepdims=False):
    """
    Sum `value` along `axis`.

    A Python tuple/list of tensors is summed element-wise (only `axis == 0` is supported);
    anything else is reduced with `jnp.sum`.
    """
    if not isinstance(value, (tuple, list)):
        return jnp.sum(value, axis=axis, keepdims=keepdims)
    assert axis == 0
    # Left fold with `+`, starting from the first element.
    total = value[0]
    for term in value[1:]:
        total = total + term
    return total
def svd(self, matrix: ~TensorType, full_matrices=True) ‑> Tuple[~TensorType, ~TensorType, ~TensorType]
Expand source code
def svd(self, matrix: TensorType, full_matrices=True) -> Tuple[TensorType, TensorType, TensorType]:
    """
    Singular value decomposition `matrix = u @ diag(s) @ vh`, computed by `jnp.linalg.svd`.

    Returns:
        The triple `(u, s, vh)`.
    """
    u, s, vh = jnp.linalg.svd(matrix, full_matrices=full_matrices)
    return u, s, vh

Args

matrix
(batch…, m, n)

Returns

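u
Left singular vectors, (batch…, m, m) if full_matrices else (batch…, m, k) with k = min(m, n)
s
Singular values, (batch…, k)
vh
Right singular vectors (conjugate-transposed), (batch…, n, n) if full_matrices else (batch…, k, n)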
def tensordot(self, a, a_axes: tuple | list, b, b_axes: tuple | list)
Expand source code
def tensordot(self, a, a_axes: Union[tuple, list], b, b_axes: Union[tuple, list]):
    """Contract `a_axes` of `a` with `b_axes` of `b` (multiply and sum-reduce) via `jnp.tensordot`."""
    contraction = (a_axes, b_axes)
    return jnp.tensordot(a, b, axes=contraction)

Multiply-sum-reduce a_axes of a with b_axes of b.

def to_dlpack(self, tensor)
Expand source code
def to_dlpack(self, tensor):
    """
    Export `tensor` as a DLPack capsule.

    JAX >= 0.7.0 uses the array's `__dlpack__` protocol; older versions use `jax.dlpack.to_dlpack`.
    """
    if version.parse(jax.__version__) >= version.parse("0.7.0"):
        return tensor.__dlpack__()
    from jax import dlpack
    return dlpack.to_dlpack(tensor)
def unique(self, x: ~TensorType, return_inverse: bool, return_counts: bool, axis: int) ‑> Tuple[~TensorType, ...]
Expand source code
def unique(self, x: TensorType, return_inverse: bool, return_counts: bool, axis: int) -> Tuple[TensorType, ...]:
    """
    Sorted unique slices of `x` along `axis`, via `jnp.unique`.

    The optional inverse indices and counts are included in the result when requested.
    """
    return jnp.unique(x, axis=axis, return_counts=return_counts, return_inverse=return_inverse)

Args

x
n-dimensional int array. Will compare axis-slices of x for multidimensional x.
return_inverse
Whether to return the inverse
return_counts
Whether to return the counts.
axis
Axis along which slices of x should be compared.

Returns

unique_slices
Sorted unique slices of x
unique_inverse
(optional) index of the unique slice for each slice of x
unique_counts
(optional) Number of occurrences of each unique slice
def unravel_index(self, flat_index, shape)
Expand source code
def unravel_index(self, flat_index, shape):
    """Convert flat indices into multi-indices for `shape`, stacked along a new last axis."""
    coords = jnp.unravel_index(flat_index, shape)
    return jnp.stack(coords, axis=-1)
def vectorized_call(self, f, *args, output_dtypes=None, **aux_args)
Expand source code
def vectorized_call(self, f, *args, output_dtypes=None, **aux_args):
    """
    Apply `f` over the first dimension of all positional tensors using `jax.vmap`.

    Arguments are first tiled along dim 0 to a common batch size determined by
    `self.determine_size`. `aux_args` are passed unbatched to every call.
    `output_dtypes` is accepted for API compatibility but is not used here.
    """
    batch_size = self.determine_size(args, 0)
    tiled = [self.tile_to(arg, 0, batch_size) for arg in args]
    batched_f = jax.vmap(lambda *positional: f(*positional, **aux_args), in_axes=0, out_axes=0)
    return batched_f(*tiled)

Args

f
Function with only positional tensor argument, returning one or multiple tensors.
*args
Batched inputs for f. The first dimension of all args is vectorized. All tensors in args must have the same size or 1 in their first dimension.
output_dtypes
Single DType or tuple of DTypes declaring the dtypes of the tensors returned by f.
**aux_args
Non-vectorized keyword arguments to be passed to f.
def where(self, condition, x=None, y=None)
Expand source code
def where(self, condition, x=None, y=None):
    """
    Select from `x` where `condition` holds and from `y` elsewhere.

    If either `x` or `y` is None, the indices of truthy `condition` entries are returned instead
    (`jnp.argwhere`).
    """
    if x is not None and y is not None:
        return jnp.where(condition, x, y)
    return jnp.argwhere(condition)
def while_loop(self, loop: Callable, values: tuple, max_iter: int | Tuple[int, ...] | List[int])
Expand source code
def while_loop(self, loop: Callable, values: tuple, max_iter: Union[int, Tuple[int, ...], List[int], None]):
    """
    Repeatedly apply `loop` while any element of the first loop variable is truthy.

    One of three strategies is used:
    * All `values` concrete (`self.is_available`): run the generic `Backend.while_loop` and stop
      gradients on its result.
    * `max_iter` is a tuple/list (traced values): unroll `max(max_iter)` iterations in Python
      WITHOUT evaluating the loop condition. The state after each iteration listed in `max_iter`
      is recorded (0 = initial state), and the recorded states are stacked.
    * Otherwise (traced values): use `jax.lax.while_loop`, additionally capped at `max_iter`
      iterations when `max_iter` is not None.

    Args:
        loop: Loop body; must return a tuple matching `values` in structure, shapes and dtypes.
        values: Initial loop variables. `values[0]` acts as the continue-condition.
        max_iter: None, a single iteration cap, or a sequence of iteration counts to record.

    Returns:
        Final loop variables, or (sequence `max_iter`) the stacked recorded states.
    """
    if all(self.is_available(t) for t in values):
        return self.stop_gradient_tree(Backend.while_loop(self, loop, values, max_iter))
    if isinstance(max_iter, (tuple, list)):  # stack traced trajectory, unroll until max_iter
        values = self.stop_gradient_tree(values)
        trj = [values] if 0 in max_iter else []
        for i in range(1, max(max_iter) + 1):
            values = loop(*values)
            if i in max_iter:
                trj.append(values)  # values are not mutable so no need to copy
        return self.stop_gradient_tree(self.stack_leaves(trj))
    else:
        if max_iter is None:
            # Unbounded: the carry is the loop-variable tuple itself.
            cond = lambda vals: jnp.any(vals[0])
            body = lambda vals: loop(*vals)
            return jax.lax.while_loop(cond, body, values)
        else:
            # Bounded: the carry is (iteration counter, loop variables); only the variables are returned.
            cond = lambda vals: jnp.any(vals[1][0]) & (vals[0] < max_iter)
            body = lambda vals: (vals[0] + 1, loop(*vals[1]))
            return jax.lax.while_loop(cond, body, (self.as_tensor(0), values))[1]

If max_iter is None, runs

while any(values[0]):
    values = loop(*values)
return values

This operation does not support backpropagation.

Args

loop
Loop function, must return a tuple with entries equal to values in shape and data type.
values
Initial values of loop variables.
max_iter
Maximum number of iterations to run, single int or sequence of integers.

Returns

Loop variables upon loop completion if max_iter is a single integer. If max_iter is a sequence, stacks the variables after each entry in max_iter, adding an outer dimension of size <= len(max_iter). If the condition is fulfilled before the maximum max_iter is reached, the loop may be broken or not, depending on the implementation. If the loop is broken, the values returned by the last loop are expected to be constant and filled.

def zeros(self, shape, dtype: phiml.backend._dtype.DType = None)
Expand source code
def zeros(self, shape, dtype: DType = None):
    """Array of zeros with `shape`, using `dtype` or the backend's float type, placed on the default device."""
    self._check_float64()
    np_dtype = to_numpy_dtype(dtype or self.float_type)
    zero_array = jnp.zeros(shape, dtype=np_dtype)
    return jax.device_put(zero_array, self._default_device.ref)
def zeros_like(self, tensor)
Expand source code
def zeros_like(self, tensor):
    """Array of zeros with the shape and dtype of `tensor`, placed on the default device."""
    empty = jnp.zeros_like(tensor)
    return jax.device_put(empty, self._default_device.ref)