Module phiml.backend.jax
Sub-modules
phiml.backend.jax.stax_nets
Stax implementation of the unified machine learning API. Equivalent functions also exist for the other frameworks …
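For orientation, a brief usage sketch of the unified API that this sub-module backs with Stax. The names below (phiml.nn, dense_net, adam) are assumptions based on phiml's network API and are not confirmed by this page:

    # Hypothetical sketch; function names and signatures are assumed, not confirmed by this page.
    from phiml import nn

    net = nn.dense_net(in_channels=2, out_channels=1, layers=[64, 64])  # assumed signature
    optimizer = nn.adam(net, learning_rate=1e-3)                        # assumed helper
    # With the Jax backend selected, these would be implemented via Stax.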
Classes
class JaxBackend
    class JaxBackend(Backend):

        def __init__(self):
            devices = []
            for device_type in ['cpu', 'gpu', 'tpu']:
                try:
                    for jax_dev in jax.devices(device_type):
                        devices.append(ComputeDevice(self, device_type.upper(), jax_dev.platform.upper(), -1, -1, f"id={jax_dev.id}", jax_dev))
                except RuntimeError as err:
                    pass  # this is just Jax not finding anything. jaxlib.xla_client._get_local_backends() could help but isn't currently available on GitHub actions
            Backend.__init__(self, 'jax', devices, devices[-1])
            try:
                self.rnd_key = jax.random.PRNGKey(seed=0)
            except RuntimeError as err:
                warnings.warn(f"{err}", RuntimeWarning)
                self.rnd_key = None

        def prefers_channels_last(self) -> bool:
            return True

        def requires_fixed_shapes_when_tracing(self) -> bool:
            return True

        def nn_library(self):
            from . import stax_nets
            return stax_nets

        def _check_float64(self):
            if self.precision == 64:
                if not jax.config.read('jax_enable_x64'):
                    jax.config.update('jax_enable_x64', True)
                assert jax.config.read('jax_enable_x64'), "FP64 is disabled for Jax."

        def seed(self, seed: int):
            self.rnd_key = jax.random.PRNGKey(seed)

        def as_tensor(self, x, convert_external=True):
            self._check_float64()
            if self.is_tensor(x, only_native=convert_external):
                array = x
            else:
                array = jnp.array(x)
            # --- Enforce Precision ---
            if not isinstance(array, numbers.Number):
                if self.dtype(array).kind == float:
                    array = self.to_float(array)
                elif self.dtype(array).kind == complex:
                    array = self.to_complex(array)
            return array

        def is_module(self, obj):
            return False

        def is_tensor(self, x, only_native=False):
            if isinstance(x, jnp.ndarray) and not isinstance(x, np.ndarray):  # NumPy arrays inherit from Jax arrays
                return True
            if isinstance(x, jnp.bool_) and not isinstance(x, np.bool_):
                return True
            if self.is_sparse(x):
                return True
            # --- Above considered native ---
            if only_native:
                return False
            # --- Non-native types ---
            if isinstance(x, np.ndarray):
                return True
            if isinstance(x, np.bool_):
                return True
            if isinstance(x, (numbers.Number, bool)):
                return True
            if isinstance(x, (tuple, list)):
                return all([self.is_tensor(item, False) for item in x])
            return False

        def sizeof(self, tensor) -> int:
            return tensor.nbytes

        def is_sparse(self, x) -> bool:
            return isinstance(x, (COO, BCOO, CSR, CSC))

        def get_sparse_format(self, x) -> str:
            format_names = {COO: 'coo', BCOO: 'coo', CSR: 'csr', CSC: 'csc'}
            return format_names.get(type(x), 'dense')

        def is_available(self, tensor):
            if isinstance(tensor, JVPTracer):
                tensor = tensor.primal
            return not isinstance(tensor, Tracer)

        def numpy(self, tensor):
            if isinstance(tensor, JVPTracer):
                tensor = tensor.primal
            if self.is_sparse(tensor):
                assemble, parts = self.disassemble(tensor)
                return assemble(NUMPY, *[self.numpy(t) for t in parts])
            return np.array(tensor)

        def disassemble(self, x) -> Tuple[Callable, Sequence[TensorType]]:
            if self.is_sparse(x):
                if isinstance(x, COO):
                    raise NotImplementedError
                    # return lambda b, i, v: b.sparse_coo_tensor(i, v, x.shape), (np.stack([x.row, x.col], -1), x.data)
                if isinstance(x, BCOO):
                    return lambda b, i, v: b.sparse_coo_tensor(i, v, x.shape), (x.indices, x.data)
                if isinstance(x, CSR):
                    raise NotImplementedError
                    # return lambda b, v, i, p: b.csr_matrix(i, p, v, x.shape), (x.data, x.indices, x.indptr)
                elif isinstance(x, CSC):
                    raise NotImplementedError
                    # return lambda b, v, i, p: b.csc_matrix(p, i, v, x.shape), (x.data, x.indices, x.indptr)
                raise NotImplementedError
            else:
                return lambda b, t: t, (x,)

        def to_dlpack(self, tensor):
            if version.parse(jax.__version__) < version.parse("0.7.0"):
                from jax import dlpack
                return dlpack.to_dlpack(tensor)
            else:
                return tensor.__dlpack__()

        def from_dlpack(self, external_array):
            from jax import dlpack
            return dlpack.from_dlpack(external_array)

        if version.parse(jax.__version__) >= version.parse("0.7.2"):
            def from_external_array_dlpack(self, external_array):
                from jax import dlpack
                return dlpack.from_dlpack(external_array)

        def copy(self, tensor, only_mutable=False):
            return jnp.array(tensor, copy=True)

        def get_device(self, tensor: TensorType) -> ComputeDevice:
            if hasattr(tensor, 'devices'):
                return self.get_device_by_ref(next(iter(tensor.devices())))
            if hasattr(tensor, 'device'):
                return self.get_device_by_ref(tensor.device())
            raise AssertionError(f"tensor {type(tensor)} has no device attribute")

        def allocate_on_device(self, tensor: TensorType, device: ComputeDevice) -> TensorType:
            return jax.device_put(tensor, device.ref)

        sqrt = staticmethod(jnp.sqrt)
        exp = staticmethod(jnp.exp)
        erf = staticmethod(scipy.special.erf)
        softplus = staticmethod(jax.nn.softplus)
        sin = staticmethod(jnp.sin)
        arcsin = staticmethod(jnp.arcsin)
        cos = staticmethod(jnp.cos)
        arccos = staticmethod(jnp.arccos)
        tan = staticmethod(jnp.tan)
        arctan = staticmethod(np.arctan)
        arctan2 = staticmethod(np.arctan2)
        sinh = staticmethod(np.sinh)
        arcsinh = staticmethod(np.arcsinh)
        cosh = staticmethod(np.cosh)
        arccosh = staticmethod(np.arccosh)
        tanh = staticmethod(np.tanh)
        arctanh = staticmethod(np.arctanh)
        log = staticmethod(jnp.log)
        log2 = staticmethod(jnp.log2)
        log10 = staticmethod(jnp.log10)
        isfinite = staticmethod(jnp.isfinite)
        isnan = staticmethod(jnp.isnan)
        isinf = staticmethod(jnp.isinf)
        abs = staticmethod(jnp.abs)
        sign = staticmethod(jnp.sign)
        round = staticmethod(jnp.round)
        ceil = staticmethod(jnp.ceil)
        floor = staticmethod(jnp.floor)
        flip = staticmethod(jnp.flip)
        stop_gradient = staticmethod(jax.lax.stop_gradient)
        transpose = staticmethod(jnp.transpose)
        equal = staticmethod(jnp.equal)
        tile = staticmethod(jnp.tile)
        stack = staticmethod(jnp.stack)
        concat = staticmethod(jnp.concatenate)
        maximum = staticmethod(jnp.maximum)
        minimum = staticmethod(jnp.minimum)
        clip = staticmethod(jnp.clip)
        argmax = staticmethod(np.argmax)
        argmin = staticmethod(np.argmin)
        shape = staticmethod(jnp.shape)
        staticshape = staticmethod(jnp.shape)
        imag = staticmethod(jnp.imag)
        real = staticmethod(jnp.real)
        conj = staticmethod(jnp.conjugate)
        einsum = staticmethod(jnp.einsum)
        cumsum = staticmethod(jnp.cumsum)

        def nonzero(self, values, length=None, fill_value=-1):
            result = jnp.nonzero(values, size=length, fill_value=fill_value)
            return jnp.stack(result, -1)

        def vectorized_call(self, f, *args, output_dtypes=None, **aux_args):
            batch_size = self.determine_size(args, 0)
            args = [self.tile_to(t, 0, batch_size) for t in args]
            def f_positional(*args):
                return f(*args, **aux_args)
            vec_f = jax.vmap(f_positional, 0, 0)
            return vec_f(*args)

        def numpy_call(self, f, output_shapes, output_dtypes, *args, **aux_args):
            if all([self.is_available(arg) for arg in args]):
                args = [self.numpy(arg) for arg in args]
                output = f(*args, **aux_args)
                result = map_structure(self.as_tensor, output)
                return result
            @dataclasses.dataclass
            class OutputTensor:
                shape: Tuple[int]
                dtype: np.dtype
            output_specs = map_structure(lambda t, s: OutputTensor(s, to_numpy_dtype(t)), output_dtypes, output_shapes)
            if hasattr(jax, 'pure_callback'):
                def aux_f(*args):
                    return f(*args, **aux_args)
                return jax.pure_callback(aux_f, output_specs, *args)
            else:
                def aux_f(args):
                    if isinstance(args, tuple):
                        return f(*args, **aux_args)
                    else:
                        return f(args, **aux_args)
                from jax.experimental.host_callback import call
                return call(aux_f, args, result_shape=output_specs)

        def jit_compile(self, f: Callable) -> Callable:
            def run_jit_f(*args):
                # print(jax.make_jaxpr(f)(*args))
                ML_LOGGER.debug(f"JaxBackend: running jit-compiled '{f.__name__}' with shapes {[self.shape(arg) for arg in args]} and dtypes {[self.dtype(arg) for arg in args]}")
                return self.as_registered.call(jit_f, *args, name=f"run jit-compiled '{f.__name__}'")
            run_jit_f.__name__ = f"Jax-Jit({f.__name__})"
            jit_f = jax.jit(f, device=self._default_device.ref)
            return run_jit_f

        def block_until_ready(self, values):
            if hasattr(values, 'block_until_ready'):
                values.block_until_ready()
            if isinstance(values, (tuple, list)):
                for v in values:
                    self.block_until_ready(v)

        def jacobian(self, f, wrt: Union[tuple, list], get_output: bool, is_f_scalar: bool):
            if get_output:
                jax_grad_f = jax.value_and_grad(f, argnums=wrt, has_aux=True)
                @wraps(f)
                def unwrap_outputs(*args):
                    args = [self.to_float(arg) if self.dtype(arg).kind in (bool, int) and i in wrt else arg for i, arg in enumerate(args)]
                    (_, output_tuple), grads = jax_grad_f(*args)
                    return (*output_tuple, *[jnp.conjugate(g) for g in grads])
                return unwrap_outputs
            else:
                @wraps(f)
                def nonaux_f(*args):
                    loss, output = f(*args)
                    return loss
                jax_grad = jax.grad(nonaux_f, argnums=wrt, has_aux=False)
                @wraps(f)
                def call_jax_grad(*args):
                    args = [self.to_float(arg) if self.dtype(arg).kind in (bool, int) and i in wrt else arg for i, arg in enumerate(args)]
                    grads = jax_grad(*args)
                    return tuple([jnp.conjugate(g) for g in grads])
                return call_jax_grad

        def custom_gradient(self, f: Callable, gradient: Callable, get_external_cache: Callable = None, on_call_skipped: Callable = None) -> Callable:
            jax_fun = jax.custom_vjp(f)  # custom vector-Jacobian product (reverse-mode differentiation)
            def forward(*x):
                y = f(*x)
                return y, (x, y)
            def backward(x_y, dy):
                x, y = x_y
                dx = gradient(x, y, dy)
                return tuple(dx)
            jax_fun.defvjp(forward, backward)
            return jax_fun

        def divide_no_nan(self, x, y):
            return jnp.where(y == 0, 0, x / y)  # jnp.nan_to_num(x / y, copy=True, nan=0) covers up NaNs from before

        def random_uniform(self, shape, low, high, dtype: Union[DType, None]):
            self._check_float64()
            self.rnd_key, subkey = jax.random.split(self.rnd_key)
            dtype = dtype or self.float_type
            jdt = to_numpy_dtype(dtype)
            if dtype.kind == float:
                tensor = random.uniform(subkey, shape, minval=low, maxval=high, dtype=jdt)
            elif dtype.kind == complex:
                real = random.uniform(subkey, shape, minval=low.real, maxval=high.real, dtype=to_numpy_dtype(DType.by_precision(float, dtype.precision)))
                imag = random.uniform(subkey, shape, minval=low.imag, maxval=high.imag, dtype=to_numpy_dtype(DType.by_precision(float, dtype.precision)))
                return real + 1j * imag
            elif dtype.kind == int:
                tensor = random.randint(subkey, shape, low, high, dtype=jdt)
                if tensor.dtype != jdt:
                    warnings.warn(f"Jax failed to sample random integers with dtype {dtype}, returned {tensor.dtype} instead.", RuntimeWarning)
            else:
                raise ValueError(dtype)
            return jax.device_put(tensor, self._default_device.ref)

        def random_normal(self, shape, dtype: DType):
            self._check_float64()
            self.rnd_key, subkey = jax.random.split(self.rnd_key)
            dtype = dtype or self.float_type
            return jax.device_put(random.normal(subkey, shape, dtype=to_numpy_dtype(dtype)), self._default_device.ref)

        def random_permutations(self, permutations: int, n: int):
            self.rnd_key, subkey = jax.random.split(self.rnd_key)
            result = jnp.stack([jax.random.permutation(subkey, n) for _ in range(permutations)])
            return jax.device_put(result, self._default_device.ref)

        def range(self, start, limit=None, delta=1, dtype: DType = INT32):
            if limit is None:
                start, limit = 0, start
            return jnp.arange(start, limit, delta, to_numpy_dtype(dtype))

        def pad(self, value, pad_width, mode='constant', constant_values=0):
            assert mode in ('constant', 'symmetric', 'periodic', 'reflect', 'boundary'), mode
            if mode == 'constant':
                constant_values = jnp.array(constant_values, dtype=value.dtype)
                return jnp.pad(value, pad_width, 'constant', constant_values=constant_values)
            else:
                if mode in ('periodic', 'boundary'):
                    mode = {'periodic': 'wrap', 'boundary': 'edge'}[mode]
                return jnp.pad(value, pad_width, mode)

        def reshape(self, value, shape):
            return jnp.reshape(value, shape)

        def sum(self, value, axis=None, keepdims=False):
            if isinstance(value, (tuple, list)):
                assert axis == 0
                return sum(value[1:], value[0])
            return jnp.sum(value, axis=axis, keepdims=keepdims)

        def prod(self, value, axis=None):
            if not isinstance(value, jnp.ndarray):
                value = jnp.array(value)
            if value.dtype == bool:
                return jnp.all(value, axis=axis)
            return jnp.prod(value, axis=axis)

        def where(self, condition, x=None, y=None):
            if x is None or y is None:
                return jnp.argwhere(condition)
            return jnp.where(condition, x, y)

        def zeros(self, shape, dtype: DType = None):
            self._check_float64()
            return jax.device_put(jnp.zeros(shape, dtype=to_numpy_dtype(dtype or self.float_type)), self._default_device.ref)

        def zeros_like(self, tensor):
            return jax.device_put(jnp.zeros_like(tensor), self._default_device.ref)

        def ones(self, shape, dtype: DType = None):
            self._check_float64()
            return jax.device_put(jnp.ones(shape, dtype=to_numpy_dtype(dtype or self.float_type)), self._default_device.ref)

        def ones_like(self, tensor):
            return jax.device_put(jnp.ones_like(tensor), self._default_device.ref)

        def meshgrid(self, *coordinates):
            self._check_float64()
            coordinates = [self.as_tensor(c) for c in coordinates]
            return [jax.device_put(c, self._default_device.ref) for c in jnp.meshgrid(*coordinates, indexing='ij')]

        def linspace(self, start, stop, number):
            self._check_float64()
            return jax.device_put(jnp.linspace(start, stop, number, dtype=to_numpy_dtype(self.float_type)), self._default_device.ref)

        def linspace_without_last(self, start, stop, number):
            self._check_float64()
            return jax.device_put(jnp.linspace(start, stop, number, endpoint=False, dtype=to_numpy_dtype(self.float_type)), self._default_device.ref)

        def mean(self, value, axis=None, keepdims=False):
            return jnp.mean(value, axis, keepdims=keepdims)

        def log_gamma(self, x):
            return jax.lax.lgamma(self.to_float(x))

        def gamma_inc_l(self, a, x):
            return scipy.special.gammainc(a, x)

        def gamma_inc_u(self, a, x):
            return scipy.special.gammaincc(a, x)

        def tensordot(self, a, a_axes: Union[tuple, list], b, b_axes: Union[tuple, list]):
            return jnp.tensordot(a, b, (a_axes, b_axes))

        def mul(self, a, b):
            # if scipy.sparse.issparse(a):  # TODO sparse?
            #     return a.multiply(b)
            # elif scipy.sparse.issparse(b):
            #     return b.multiply(a)
            # else:
            return Backend.mul(self, a, b)

        def mul_matrix_batched_vector(self, A, b):
            from jax.experimental.sparse import BCOO
            if isinstance(A, BCOO):
                return (A @ b.T).T
            return jnp.stack([A.dot(b[i]) for i in range(b.shape[0])])

        def get_diagonal(self, matrices, offset=0):
            result = jnp.diagonal(matrices, offset=offset, axis1=1, axis2=2)
            return jnp.transpose(result, [0, 2, 1])

        def while_loop(self, loop: Callable, values: tuple, max_iter: Union[int, Tuple[int, ...], List[int]]):
            if all(self.is_available(t) for t in values):
                return self.stop_gradient_tree(Backend.while_loop(self, loop, values, max_iter))
            if isinstance(max_iter, (tuple, list)):  # stack traced trajectory, unroll until max_iter
                values = self.stop_gradient_tree(values)
                trj = [values] if 0 in max_iter else []
                for i in range(1, max(max_iter) + 1):
                    values = loop(*values)
                    if i in max_iter:
                        trj.append(values)  # values are not mutable so no need to copy
                return self.stop_gradient_tree(self.stack_leaves(trj))
            else:
                if max_iter is None:
                    cond = lambda vals: jnp.any(vals[0])
                    body = lambda vals: loop(*vals)
                    return jax.lax.while_loop(cond, body, values)
                else:
                    cond = lambda vals: jnp.any(vals[1][0]) & (vals[0] < max_iter)
                    body = lambda vals: (vals[0] + 1, loop(*vals[1]))
                    return jax.lax.while_loop(cond, body, (self.as_tensor(0), values))[1]

        def max(self, x, axis=None, keepdims=False):
            return jnp.max(x, axis, keepdims=keepdims)

        def min(self, x, axis=None, keepdims=False):
            return jnp.min(x, axis, keepdims=keepdims)

        def conv(self, value, kernel, strides: Sequence[int], out_sizes: Sequence[int], transpose: bool):
            assert not transpose, "transpose conv not yet supported for Jax"
            assert kernel.shape[0] in (1, value.shape[0])
            assert value.shape[1] == kernel.shape[2], f"value has {value.shape[1]} channels but kernel has {kernel.shape[2]}"
            assert value.ndim + 1 == kernel.ndim
            value, kernel = self.auto_cast(value, kernel, bool_to_int=True)
            ndim = len(value.shape) - 2  # Number of spatial dimensions
            assert len(strides) == ndim, f"Expected {ndim} stride values, got {len(strides)}"
            # --- Determine padding ---
            # default_size = [int(np.ceil((vs - ks + 1) / st)) for vs, ks, st in zip(value.shape[2:], kernel.shape[3:], strides)]  # size if no padding is used
            lr_padding = [max(0, st * (os - 1) - vs + ks) for st, os, vs, ks in zip(strides, out_sizes, value.shape[2:], kernel.shape[3:])]
            padding = [((p + 1) // 2, p // 2) for p in lr_padding]
            # --- Run the (transposed) convolution ---
            sp = ''.join(['WHD'[i] for i in range(len(strides))])
            dim_num = jax.lax.conv_dimension_numbers(value.shape, kernel.shape[1:], ('NC' + sp, 'OI' + sp, 'NC' + sp))
            if kernel.shape[0] == 1:
                return jax.lax.conv_general_dilated(value, kernel[0], strides, padding, None, None, dim_num)
            else:
                result = []
                for b in range(kernel.shape[0]):
                    result.append(jax.lax.conv_general_dilated(value[b:b + 1], kernel[b], strides, padding, None, None, dim_num))
                return jnp.concatenate(result, 0)

        def expand_dims(self, a, axis=0, number=1):
            for _i in range(number):
                a = jnp.expand_dims(a, axis)
            return a

        def cast(self, x, dtype: DType):
            if self.is_tensor(x, only_native=True) and from_numpy_dtype(x.dtype) == dtype:
                return x
            else:
                return jnp.array(x, to_numpy_dtype(dtype))

        def unravel_index(self, flat_index, shape):
            return jnp.stack(jnp.unravel_index(flat_index, shape), -1)

        def ravel_multi_index(self, multi_index, shape, mode: Union[str, int] = 'undefined'):
            if not self.is_available(shape):
                return Backend.ravel_multi_index(self, multi_index, shape, mode)
            mode = mode if isinstance(mode, int) else {'undefined': 'clip', 'periodic': 'wrap', 'clamp': 'clip'}[mode]
            idx_first = jnp.transpose(multi_index, (self.ndims(multi_index) - 1,) + tuple(range(self.ndims(multi_index) - 1)))
            result = jnp.ravel_multi_index(idx_first, shape, mode='wrap' if isinstance(mode, int) else mode)
            if isinstance(mode, int):
                outside = self.any((multi_index < 0) | (multi_index >= jnp.asarray(shape, dtype=multi_index.dtype)), -1)
                result = self.where(outside, mode, result)
            return result

        def gather(self, values, indices, axis: int):
            slices = [indices if i == axis else slice(None) for i in range(self.ndims(values))]
            return values[tuple(slices)]

        def batched_gather_nd(self, values, indices):
            values = self.as_tensor(values)
            indices = self.as_tensor(indices)
            assert indices.shape[-1] == self.ndims(values) - 2
            batch_size = combined_dim(values.shape[0], indices.shape[0])
            indices_list = [indices[..., i] for i in range(indices.shape[-1])]
            batch_range = self.expand_dims(np.arange(batch_size), -1, number=self.ndims(indices) - 2)
            slices = (batch_range, *indices_list)
            return values[slices]

        def batched_gather_1d(self, values, indices):
            batch_size = combined_dim(values.shape[0], indices.shape[0])
            return values[np.arange(batch_size)[:, None], indices]

        def repeat(self, x, repeats, axis: int, new_length=None):
            return jnp.repeat(x, self.as_tensor(repeats), axis, total_repeat_length=new_length)

        def std(self, x, axis=None, keepdims=False):
            return jnp.std(x, axis, keepdims=keepdims)

        def boolean_mask(self, x, mask, axis=0, new_length=None, fill_value=0):
            if new_length is None:
                slices = [mask if i == axis else slice(None) for i in range(len(x.shape))]
                return x[tuple(slices)]
            else:
                indices = jnp.argwhere(mask, size=new_length, fill_value=-1)[..., 0]
                valid = indices >= 0
                valid = valid[tuple([slice(None) if i == axis else None for i in range(len(x.shape))])]
                result = self.gather(x, jnp.maximum(0, indices), axis)
                return jnp.where(valid, result, fill_value)

        def any(self, boolean_tensor, axis=None, keepdims=False):
            if isinstance(boolean_tensor, (tuple, list)):
                boolean_tensor = jnp.stack(boolean_tensor)
            return jnp.any(boolean_tensor, axis=axis, keepdims=keepdims)

        def all(self, boolean_tensor, axis=None, keepdims=False):
            if isinstance(boolean_tensor, (tuple, list)):
                boolean_tensor = jnp.stack(boolean_tensor)
            return jnp.all(boolean_tensor, axis=axis, keepdims=keepdims)

        def scatter(self, base_grid, indices, values, mode: str):
            out_kind = self.combine_types(self.dtype(base_grid), self.dtype(values)).kind
            base_grid, values = self.auto_cast(base_grid, values, bool_to_int=True)
            batch_size = combined_dim(combined_dim(indices.shape[0], values.shape[0]), base_grid.shape[0])
            spatial_dims = tuple(range(base_grid.ndim - 2))
            dnums = jax.lax.ScatterDimensionNumbers(
                update_window_dims=(1,),  # channel dim of updates (batch dim removed)
                inserted_window_dims=spatial_dims,  # no idea what this does but spatial_dims seems to work
                scatter_dims_to_operand_dims=spatial_dims)  # spatial dims of base_grid (batch dim removed)
            scatter = {'add': jax.lax.scatter_add, 'update': jax.lax.scatter, 'max': jax.lax.scatter_max, 'min': jax.lax.scatter_min}[mode]
            def scatter_single(base_grid, indices, values):
                return scatter(base_grid, indices, values, dnums)
            if self.staticshape(indices)[0] == 1:
                indices = self.tile(indices, [batch_size, 1, 1])
            if self.staticshape(values)[0] == 1:
                values = self.tile(values, [batch_size, 1, 1])
            result = self.vectorized_call(scatter_single, base_grid, indices, values)
            if self.dtype(result).kind != out_kind:
                if out_kind == bool:
                    result = self.cast(result, BOOL)
            return result

        def histogram1d(self, values, weights, bin_edges):
            def unbatched_hist(values, weights, bin_edges):
                hist, _ = jnp.histogram(values, bin_edges, weights=weights)
                return hist
            return jax.vmap(unbatched_hist)(values, weights, bin_edges)

        def bincount(self, x, weights: Optional[TensorType], bins: int, x_sorted=False):
            if x_sorted:
                return jax.ops.segment_sum(weights or 1, x, bins, indices_are_sorted=True)
            else:
                return jnp.bincount(x, weights=weights, minlength=bins, length=bins)

        def unique(self, x: TensorType, return_inverse: bool, return_counts: bool, axis: int) -> Tuple[TensorType, ...]:
            return jnp.unique(x, return_inverse=return_inverse, return_counts=return_counts, axis=axis)

        def quantile(self, x, quantiles):
            return jnp.quantile(x, quantiles, axis=-1)

        def argsort(self, x, axis=-1):
            return jnp.argsort(x, axis)

        def sort(self, x, axis=-1):
            return jnp.sort(x, axis)

        def searchsorted(self, sorted_sequence, search_values, side: str, dtype=INT32):
            if self.ndims(sorted_sequence) == 1:
                return jnp.searchsorted(sorted_sequence, search_values, side=side).astype(to_numpy_dtype(dtype))
            else:
                return jax.vmap(partial(self.searchsorted, side=side, dtype=dtype))(sorted_sequence, search_values)

        def fft(self, x, axes: Union[tuple, list]):
            x = self.to_complex(x)
            if not axes:
                return x
            if len(axes) == 1:
                return jnp.fft.fft(x, axis=axes[0]).astype(x.dtype)
            elif len(axes) == 2:
                return jnp.fft.fft2(x, axes=axes).astype(x.dtype)
            else:
                return jnp.fft.fftn(x, axes=axes).astype(x.dtype)

        def ifft(self, k, axes: Union[tuple, list]):
            if not axes:
                return k
            if len(axes) == 1:
                return jnp.fft.ifft(k, axis=axes[0]).astype(k.dtype)
            elif len(axes) == 2:
                return jnp.fft.ifft2(k, axes=axes).astype(k.dtype)
            else:
                return jnp.fft.ifftn(k, axes=axes).astype(k.dtype)

        def dtype(self, array) -> DType:
            if isinstance(array, bool):
                return BOOL
            if isinstance(array, int):
                return INT32
            if isinstance(array, float):
                return FLOAT64
            if isinstance(array, complex):
                return COMPLEX128
            if not isinstance(array, jnp.ndarray):
                array = jnp.array(array)
            return from_numpy_dtype(array.dtype)

        def matrix_solve_least_squares(self, matrix: TensorType, rhs: TensorType) -> Tuple[TensorType, TensorType, TensorType, TensorType]:
            solution, residuals, rank, singular_values = lstsq_batched(matrix, rhs)
            return solution, residuals, rank, singular_values

        def solve_triangular_dense(self, matrix, rhs, lower: bool, unit_diagonal: bool):
            matrix, rhs = self.auto_cast(matrix, rhs, int_to_float=True, bool_to_int=True)
            x = jax.lax.linalg.triangular_solve(matrix, rhs, lower=lower, unit_diagonal=unit_diagonal, left_side=True)
            return x

        def matrix_rank_dense(self, matrix, hermitian=False):
            try:
                return jnp.linalg.matrix_rank(matrix)
            except TypeError as err:
                if err.args[0] == "array should have 2 or fewer dimensions":  # this is a Jax bug on some distributions/versions
                    warnings.warn("You are using a broken version of JAX. matrix_rank for dense matrices will fall back to NumPy.")
                    return self.as_tensor(NUMPY.matrix_rank_dense(self.numpy(matrix), hermitian=hermitian))
                else:
                    raise err

        def eigvals(self, matrix: TensorType) -> TensorType:
            return jnp.linalg.eigvals(matrix)

        def eig(self, matrix: TensorType) -> TensorType:
            return jnp.linalg.eig(matrix)

        def svd(self, matrix: TensorType, full_matrices=True) -> Tuple[TensorType, TensorType, TensorType]:
            result = jnp.linalg.svd(matrix, full_matrices=full_matrices)
            return result[0], result[1], result[2]

        def sparse_coo_tensor(self, indices: Union[tuple, list], values, shape: tuple):
            return BCOO((values, indices), shape=shape)

Backends delegate low-level operations to an ML or numerics library or emulate them. The methods of Backend form a comprehensive list of available operations.

To support a library, subclass Backend and register it by adding it to BACKENDS.

Args
    name: Human-readable string.
    default_device: ComputeDevice being used by default.
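A minimal sketch of the registration pattern described above. The import path for JaxBackend is assumed from this module's name; in practice, phiml constructs and registers its backends itself:

    from phiml.backend import BACKENDS          # registry named in the docstring
    from phiml.backend.jax import JaxBackend    # import path assumed from this page

    backend = JaxBackend()        # device discovery happens in __init__ (see source above)
    if backend not in BACKENDS:
        BACKENDS.append(backend)  # register so unified operations can dispatch to JAX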
Ancestors
- phiml.backend._backend.Backend
Class variables
var arccosh
var arcsinh
var arctan
var arctan2
var arctanh
var cosh
var maximum
var minimum
var sinh
var tanh
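These variables, like the static methods below, are plain staticmethod aliases of the corresponding numpy / jax.numpy functions (see the class source above), so they can be called directly on a backend instance. A brief sketch:

    import jax.numpy as jnp

    b = JaxBackend()
    b.sqrt(jnp.array([4.0, 9.0]))    # staticmethod alias of jnp.sqrt
    b.maximum(jnp.array([1, 5]), 3)  # staticmethod alias of jnp.maximum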
Static methods
def abs(x: ArrayLike, /) -> jax.jaxlib._jax.Array
Alias of jax.numpy.absolute.
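For instance (output assumes JAX's default 32-bit integer dtype):

    >>> jnp.abs(jnp.array([-2, 0, 3]))
    Array([2, 0, 3], dtype=int32)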
def arccos(x: ArrayLike, /) -> jax.jaxlib._jax.Array
Compute element-wise inverse of trigonometric cosine of input.
JAX implementation of numpy.arccos.

Args
    x: input array or scalar.

Returns
    An array containing the inverse trigonometric cosine of each element of x in radians in the range [0, pi], promoting to inexact dtype.

Note
    - jnp.arccos returns nan when x is real-valued and not in the closed interval [-1, 1].
    - jnp.arccos follows the branch cut convention of numpy.arccos for complex inputs.

See also
    - jax.numpy.cos: Computes a trigonometric cosine of each element of input.
    - jax.numpy.arcsin and jax.numpy.asin: Computes the inverse of trigonometric sine of each element of input.
    - jax.numpy.arctan and jax.numpy.atan: Computes the inverse of trigonometric tangent of each element of input.

Examples

    >>> x = jnp.array([-2, -1, -0.5, 0, 0.5, 1, 2])
    >>> with jnp.printoptions(precision=3, suppress=True):
    ...     jnp.arccos(x)
    Array([  nan, 3.142, 2.094, 1.571, 1.047, 0.   ,   nan], dtype=float32)

    For complex inputs:

    >>> with jnp.printoptions(precision=3, suppress=True):
    ...     jnp.arccos(4-1j)
    Array(0.252+2.097j, dtype=complex64, weak_type=True)

def arcsin(x: ArrayLike, /) -> jax.jaxlib._jax.Array
Compute element-wise inverse of trigonometric sine of input.
JAX implementation of numpy.arcsin.

Args
    x: input array or scalar.

Returns
    An array containing the inverse trigonometric sine of each element of x in radians in the range [-pi/2, pi/2], promoting to inexact dtype.

Note
    - jnp.arcsin returns nan when x is real-valued and not in the closed interval [-1, 1].
    - jnp.arcsin follows the branch cut convention of numpy.arcsin for complex inputs.

See also
    - jax.numpy.sin: Computes a trigonometric sine of each element of input.
    - jax.numpy.arccos and jax.numpy.acos: Computes the inverse of trigonometric cosine of each element of input.
    - jax.numpy.arctan and jax.numpy.atan: Computes the inverse of trigonometric tangent of each element of input.

Examples

    >>> x = jnp.array([-2, -1, -0.5, 0, 0.5, 1, 2])
    >>> with jnp.printoptions(precision=3, suppress=True):
    ...     jnp.arcsin(x)
    Array([   nan, -1.571, -0.524,  0.   ,  0.524,  1.571,    nan], dtype=float32)

    For complex-valued inputs:

    >>> with jnp.printoptions(precision=3, suppress=True):
    ...     jnp.arcsin(3+4j)
    Array(0.634+2.306j, dtype=complex64, weak_type=True)

def argmax(a, axis=None, out=None, *, keepdims=<no value>)
Returns the indices of the maximum values along an axis.
Parameters
    a : array_like
        Input array.
    axis : int, optional
        By default, the index is into the flattened array, otherwise along the specified axis.
    out : array, optional
        If provided, the result will be inserted into this array. It should be of the appropriate shape and dtype.
    keepdims : bool, optional
        If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the array. (Added in version 1.22.0.)

Returns
    index_array : ndarray of ints
        Array of indices into the array. It has the same shape as a.shape with the dimension along axis removed. If keepdims is set to True, then the size of axis will be 1 with the resulting array having same shape as a.shape.

See Also
    ndarray.argmax, argmin
    amax : The maximum value along a given axis.
    unravel_index : Convert a flat index into an index tuple.
    take_along_axis : Apply np.expand_dims(index_array, axis) from argmax to an array as if by calling max.

Notes
    In case of multiple occurrences of the maximum values, the indices corresponding to the first occurrence are returned.

Examples

    >>> import numpy as np
    >>> a = np.arange(6).reshape(2,3) + 10
    >>> a
    array([[10, 11, 12],
           [13, 14, 15]])
    >>> np.argmax(a)
    5
    >>> np.argmax(a, axis=0)
    array([1, 1, 1])
    >>> np.argmax(a, axis=1)
    array([2, 2])

    Indexes of the maximal elements of a N-dimensional array:

    >>> a.flat[np.argmax(a)]
    15
    >>> ind = np.unravel_index(np.argmax(a, axis=None), a.shape)
    >>> ind
    (1, 2)
    >>> a[ind]
    15

    >>> b = np.arange(6)
    >>> b[1] = 5
    >>> b
    array([0, 5, 2, 3, 4, 5])
    >>> np.argmax(b)  # Only the first occurrence is returned.
    1

    >>> x = np.array([[4,2,3], [1,0,3]])
    >>> index_array = np.argmax(x, axis=-1)
    >>> # Same as np.amax(x, axis=-1, keepdims=True)
    >>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1)
    array([[4],
           [3]])
    >>> # Same as np.amax(x, axis=-1)
    >>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1),
    ...                    axis=-1).squeeze(axis=-1)
    array([4, 3])

    Setting keepdims to True,

    >>> x = np.arange(24).reshape((2, 3, 4))
    >>> res = np.argmax(x, axis=1, keepdims=True)
    >>> res.shape
    (2, 1, 4)

def argmin(a, axis=None, out=None, *, keepdims=<no value>)
Returns the indices of the minimum values along an axis.
Parameters
    a : array_like
        Input array.
    axis : int, optional
        By default, the index is into the flattened array, otherwise along the specified axis.
    out : array, optional
        If provided, the result will be inserted into this array. It should be of the appropriate shape and dtype.
    keepdims : bool, optional
        If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the array. (Added in version 1.22.0.)

Returns
    index_array : ndarray of ints
        Array of indices into the array. It has the same shape as a.shape with the dimension along axis removed. If keepdims is set to True, then the size of axis will be 1 with the resulting array having same shape as a.shape.

See Also
    ndarray.argmin, argmax
    amin : The minimum value along a given axis.
    unravel_index : Convert a flat index into an index tuple.
    take_along_axis : Apply np.expand_dims(index_array, axis) from argmin to an array as if by calling min.

Notes
    In case of multiple occurrences of the minimum values, the indices corresponding to the first occurrence are returned.

Examples

    >>> import numpy as np
    >>> a = np.arange(6).reshape(2,3) + 10
    >>> a
    array([[10, 11, 12],
           [13, 14, 15]])
    >>> np.argmin(a)
    0
    >>> np.argmin(a, axis=0)
    array([0, 0, 0])
    >>> np.argmin(a, axis=1)
    array([0, 0])

    Indices of the minimum elements of a N-dimensional array:

    >>> a.flat[np.argmin(a)]
    10
    >>> ind = np.unravel_index(np.argmin(a, axis=None), a.shape)
    >>> ind
    (0, 0)
    >>> a[ind]
    10

    >>> b = np.arange(6) + 10
    >>> b[4] = 10
    >>> b
    array([10, 11, 12, 13, 10, 15])
    >>> np.argmin(b)  # Only the first occurrence is returned.
    0

    >>> x = np.array([[4,2,3], [1,0,3]])
    >>> index_array = np.argmin(x, axis=-1)
    >>> # Same as np.amin(x, axis=-1, keepdims=True)
    >>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1)
    array([[2],
           [0]])
    >>> # Same as np.amin(x, axis=-1)
    >>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1),
    ...                    axis=-1).squeeze(axis=-1)
    array([2, 0])

    Setting keepdims to True,

    >>> x = np.arange(24).reshape((2, 3, 4))
    >>> res = np.argmin(x, axis=1, keepdims=True)
    >>> res.shape
    (2, 1, 4)

def ceil(x: ArrayLike, /) -> jax.jaxlib._jax.Array
Round input to the nearest integer upwards.
JAX implementation of numpy.ceil.

Args
    x: input array or scalar. Must not have complex dtype.

Returns
    An array with same shape and dtype as x containing the values rounded to the nearest integer that is greater than or equal to the value itself.

See also
    - jax.numpy.fix: Rounds the input to the nearest integer towards zero.
    - jax.numpy.trunc: Rounds the input to the nearest integer towards zero.
    - jax.numpy.floor: Rounds the input down to the nearest integer.

Examples

    >>> key = jax.random.key(1)
    >>> x = jax.random.uniform(key, (3, 3), minval=-5, maxval=5)
    >>> with jnp.printoptions(precision=2, suppress=True):
    ...     print(x)
    [[-0.61  0.34 -0.54]
     [-0.62  3.97  0.59]
     [ 4.84  3.42 -1.14]]
    >>> jnp.ceil(x)
    Array([[-0.,  1., -0.],
           [-0.,  4.,  1.],
           [ 5.,  4., -1.]], dtype=float32)

def clip(arr: ArrayLike | None = None, /, min: ArrayLike | None = None, max: ArrayLike | None = None) -> jax.jaxlib._jax.Array
Clip array values to a specified range.
JAX implementation of numpy.clip.

Args
    arr: N-dimensional array to be clipped.
    min: optional minimum value of the clipped range; if None (default) then the result will not be clipped to any minimum value. If specified, it should be broadcast-compatible with arr and max.
    max: optional maximum value of the clipped range; if None (default) then the result will not be clipped to any maximum value. If specified, it should be broadcast-compatible with arr and min.

Returns
    An array containing values from arr, with values smaller than min set to min, and values larger than max set to max. Wherever min is larger than max, the value of max is returned.

See also
    - jax.numpy.minimum: Compute the element-wise minimum value of two arrays.
    - jax.numpy.maximum: Compute the element-wise maximum value of two arrays.

Examples

    >>> arr = jnp.array([0, 1, 2, 3, 4, 5, 6, 7])
    >>> jnp.clip(arr, 2, 5)
    Array([2, 2, 2, 3, 4, 5, 5, 5], dtype=int32)

def concat(arrays: np.ndarray | Array | Sequence[ArrayLike], axis: int | None = 0, dtype: DTypeLike | None = None) -> jax.jaxlib._jax.Array
Join arrays along an existing axis.
JAX implementation of numpy.concatenate.

Args
    arrays: a sequence of arrays to concatenate; each must have the same shape except along the specified axis. If a single array is given it will be treated equivalently to arrays = unstack(arrays), but the implementation will avoid explicit unstacking.
    axis: specify the axis along which to concatenate. If None, the arrays are flattened before concatenation.
    dtype: optional dtype of the resulting array. If not specified, the dtype will be determined via type promotion rules described in type-promotion.

Returns
    the concatenated result.

See also
    - jax.lax.concatenate: XLA concatenation API.
    - jax.numpy.concat: Array API version of this function.
    - jax.numpy.stack: concatenate arrays along a new axis.

Examples

    One-dimensional concatenation:

    >>> x = jnp.arange(3)
    >>> y = jnp.zeros(3, dtype=int)
    >>> jnp.concatenate([x, y])
    Array([0, 1, 2, 0, 0, 0], dtype=int32)

    Two-dimensional concatenation:

    >>> x = jnp.ones((2, 3))
    >>> y = jnp.zeros((2, 1))
    >>> jnp.concatenate([x, y], axis=1)
    Array([[1., 1., 1., 0.],
           [1., 1., 1., 0.]], dtype=float32)

def conj(x: ArrayLike, /) -> jax.jaxlib._jax.Array
Return element-wise complex-conjugate of the input.
JAX implementation of numpy.conjugate.

Args
    x: input array or scalar.

Returns
    An array containing the complex-conjugate of x.

See also
    - jax.numpy.real: Returns the element-wise real part of the complex argument.
    - jax.numpy.imag: Returns the element-wise imaginary part of the complex argument.

Examples

    >>> jnp.conjugate(3)
    Array(3, dtype=int32, weak_type=True)
    >>> x = jnp.array([2-1j, 3+5j, 7])
    >>> jnp.conjugate(x)
    Array([2.+1.j, 3.-5.j, 7.-0.j], dtype=complex64)

def cos(x: ArrayLike, /) -> jax.jaxlib._jax.Array
Compute a trigonometric cosine of each element of input.
JAX implementation of numpy.cos.

Args
    x: scalar or array. Angle in radians.

Returns
    An array containing the cosine of each element in x, promotes to inexact dtype.

See also
    - jax.numpy.sin: Computes a trigonometric sine of each element of input.
    - jax.numpy.tan: Computes a trigonometric tangent of each element of input.
    - jax.numpy.arccos and jax.numpy.acos: Computes the inverse of trigonometric cosine of each element of input.

Examples

    >>> pi = jnp.pi
    >>> x = jnp.array([pi/4, pi/2, 3*pi/4, 5*pi/6])
    >>> with jnp.printoptions(precision=3, suppress=True):
    ...     print(jnp.cos(x))
    [ 0.707 -0.    -0.707 -0.866]

def cumsum(a: ArrayLike, axis: int | None = None, dtype: DTypeLike | None = None, out: None = None) -> jax.jaxlib._jax.Array
Cumulative sum of elements along an axis.
JAX implementation of numpy.cumsum.

Args
    a: N-dimensional array to be accumulated.
    axis: integer axis along which to accumulate. If None (default), then the array will be flattened and accumulated along the flattened axis.
    dtype: optionally specify the dtype of the output. If not specified, the output dtype will match the input dtype.
    out: unused by JAX.

Returns
    An array containing the accumulated sum along the given axis.

See also
    - jax.numpy.cumulative_sum: cumulative sum via the array API standard.
    - jax.numpy.add.accumulate: cumulative sum via ufunc methods.
    - jax.numpy.nancumsum: cumulative sum ignoring NaN values.
    - jax.numpy.sum: sum along axis.

Examples

    >>> x = jnp.array([[1, 2, 3],
    ...                [4, 5, 6]])
    >>> jnp.cumsum(x)  # flattened cumulative sum
    Array([ 1,  3,  6, 10, 15, 21], dtype=int32)
    >>> jnp.cumsum(x, axis=1)  # cumulative sum along axis 1
    Array([[ 1,  3,  6],
           [ 4,  9, 15]], dtype=int32)

def einsum(subscripts, /, *operands, out: None = None, optimize: str | bool | list[tuple[int, ...]] = 'auto', precision: str | jax._src.lax.lax.Precision | tuple[str, str] | tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision] | jax._src.lax.lax.DotAlgorithm | jax._src.lax.lax.DotAlgorithmPreset | None = None, preferred_element_type: str | type[typing.Any] | numpy.dtype | jax._src.typing.SupportsDType | None = None, out_sharding=None) -> jax.jaxlib._jax.Array
@export def einsum( subscripts, /, *operands, out: None = None, optimize: str | bool | list[tuple[int, ...]] = "auto", precision: lax.PrecisionLike = None, preferred_element_type: DTypeLike | None = None, _dot_general: Callable[..., Array] = lax.dot_general, out_sharding=None, ) -> Array: """Einstein summation JAX implementation of :func:`numpy.einsum`. ``einsum`` is a powerful and generic API for computing various reductions, inner products, outer products, axis reorderings, and combinations thereof across one or more input arrays. It has a somewhat complicated overloaded API; the arguments below reflect the most common calling convention. The Examples section below demonstrates some of the alternative calling conventions. Args: subscripts: string containing axes names separated by commas. *operands: sequence of one or more arrays corresponding to the subscripts. optimize: specify how to optimize the order of computation. In JAX this defaults to ``"auto"`` which produces optimized expressions via the opt_einsum_ package. Other options are ``True`` (same as ``"optimal"``), ``False`` (unoptimized), or any string supported by ``opt_einsum``, which includes ``"optimal"``, ``"greedy"``, ``"eager"``, and others. It may also be a pre-computed path (see :func:`~jax.numpy.einsum_path`). precision: either ``None`` (default), which means the default precision for the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``, ``Precision.HIGH`` or ``Precision.HIGHEST``). preferred_element_type: either ``None`` (default), which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype. out: unsupported by JAX _dot_general: optionally override the ``dot_general`` callable used by ``einsum``. This parameter is experimental, and may be removed without warning at any time. Returns: array containing the result of the einstein summation. See also: :func:`jax.numpy.einsum_path` Examples: The mechanics of ``einsum`` are perhaps best demonstrated by example. Here we show how to use ``einsum`` to compute a number of quantities from one or more arrays. For more discussion and examples of ``einsum``, see the documentation of :func:`numpy.einsum`. 
>>> M = jnp.arange(16).reshape(4, 4) >>> x = jnp.arange(4) >>> y = jnp.array([5, 4, 3, 2]) **Vector product** >>> jnp.einsum('i,i', x, y) Array(16, dtype=int32) >>> jnp.vecdot(x, y) Array(16, dtype=int32) Here are some alternative ``einsum`` calling conventions to compute the same result: >>> jnp.einsum('i,i->', x, y) # explicit form Array(16, dtype=int32) >>> jnp.einsum(x, (0,), y, (0,)) # implicit form via indices Array(16, dtype=int32) >>> jnp.einsum(x, (0,), y, (0,), ()) # explicit form via indices Array(16, dtype=int32) **Matrix product** >>> jnp.einsum('ij,j->i', M, x) # explicit form Array([14, 38, 62, 86], dtype=int32) >>> jnp.matmul(M, x) Array([14, 38, 62, 86], dtype=int32) Here are some alternative ``einsum`` calling conventions to compute the same result: >>> jnp.einsum('ij,j', M, x) # implicit form Array([14, 38, 62, 86], dtype=int32) >>> jnp.einsum(M, (0, 1), x, (1,), (0,)) # explicit form via indices Array([14, 38, 62, 86], dtype=int32) >>> jnp.einsum(M, (0, 1), x, (1,)) # implicit form via indices Array([14, 38, 62, 86], dtype=int32) **Outer product** >>> jnp.einsum("i,j->ij", x, y) Array([[ 0, 0, 0, 0], [ 5, 4, 3, 2], [10, 8, 6, 4], [15, 12, 9, 6]], dtype=int32) >>> jnp.outer(x, y) Array([[ 0, 0, 0, 0], [ 5, 4, 3, 2], [10, 8, 6, 4], [15, 12, 9, 6]], dtype=int32) Some other ways of computing outer products: >>> jnp.einsum("i,j", x, y) # implicit form Array([[ 0, 0, 0, 0], [ 5, 4, 3, 2], [10, 8, 6, 4], [15, 12, 9, 6]], dtype=int32) >>> jnp.einsum(x, (0,), y, (1,), (0, 1)) # explicit form via indices Array([[ 0, 0, 0, 0], [ 5, 4, 3, 2], [10, 8, 6, 4], [15, 12, 9, 6]], dtype=int32) >>> jnp.einsum(x, (0,), y, (1,)) # implicit form via indices Array([[ 0, 0, 0, 0], [ 5, 4, 3, 2], [10, 8, 6, 4], [15, 12, 9, 6]], dtype=int32) **1D array sum** >>> jnp.einsum("i->", x) # requires explicit form Array(6, dtype=int32) >>> jnp.einsum(x, (0,), ()) # explicit form via indices Array(6, dtype=int32) >>> jnp.sum(x) Array(6, dtype=int32) **Sum along an axis** >>> jnp.einsum("...j->...", M) # requires explicit form Array([ 6, 22, 38, 54], dtype=int32) >>> jnp.einsum(M, (..., 0), (...,)) # explicit form via indices Array([ 6, 22, 38, 54], dtype=int32) >>> M.sum(-1) Array([ 6, 22, 38, 54], dtype=int32) **Matrix transpose** >>> y = jnp.array([[1, 2, 3], ... 
[4, 5, 6]]) >>> jnp.einsum("ij->ji", y) # explicit form Array([[1, 4], [2, 5], [3, 6]], dtype=int32) >>> jnp.einsum("ji", y) # implicit form Array([[1, 4], [2, 5], [3, 6]], dtype=int32) >>> jnp.einsum(y, (1, 0)) # implicit form via indices Array([[1, 4], [2, 5], [3, 6]], dtype=int32) >>> jnp.einsum(y, (0, 1), (1, 0)) # explicit form via indices Array([[1, 4], [2, 5], [3, 6]], dtype=int32) >>> jnp.transpose(y) Array([[1, 4], [2, 5], [3, 6]], dtype=int32) **Matrix diagonal** >>> jnp.einsum("ii->i", M) Array([ 0, 5, 10, 15], dtype=int32) >>> jnp.diagonal(M) Array([ 0, 5, 10, 15], dtype=int32) **Matrix trace** >>> jnp.einsum("ii", M) Array(30, dtype=int32) >>> jnp.trace(M) Array(30, dtype=int32) **Tensor products** >>> x = jnp.arange(30).reshape(2, 3, 5) >>> y = jnp.arange(60).reshape(3, 4, 5) >>> jnp.einsum('ijk,jlk->il', x, y) # explicit form Array([[ 3340, 3865, 4390, 4915], [ 8290, 9940, 11590, 13240]], dtype=int32) >>> jnp.tensordot(x, y, axes=[(1, 2), (0, 2)]) Array([[ 3340, 3865, 4390, 4915], [ 8290, 9940, 11590, 13240]], dtype=int32) >>> jnp.einsum('ijk,jlk', x, y) # implicit form Array([[ 3340, 3865, 4390, 4915], [ 8290, 9940, 11590, 13240]], dtype=int32) >>> jnp.einsum(x, (0, 1, 2), y, (1, 3, 2), (0, 3)) # explicit form via indices Array([[ 3340, 3865, 4390, 4915], [ 8290, 9940, 11590, 13240]], dtype=int32) >>> jnp.einsum(x, (0, 1, 2), y, (1, 3, 2)) # implicit form via indices Array([[ 3340, 3865, 4390, 4915], [ 8290, 9940, 11590, 13240]], dtype=int32) **Chained dot products** >>> w = jnp.arange(5, 9).reshape(2, 2) >>> x = jnp.arange(6).reshape(2, 3) >>> y = jnp.arange(-2, 4).reshape(3, 2) >>> z = jnp.array([[2, 4, 6], [3, 5, 7]]) >>> jnp.einsum('ij,jk,kl,lm->im', w, x, y, z) Array([[ 481, 831, 1181], [ 651, 1125, 1599]], dtype=int32) >>> jnp.einsum(w, (0, 1), x, (1, 2), y, (2, 3), z, (3, 4)) # implicit, via indices Array([[ 481, 831, 1181], [ 651, 1125, 1599]], dtype=int32) >>> w @ x @ y @ z # direct chain of matmuls Array([[ 481, 831, 1181], [ 651, 1125, 1599]], dtype=int32) >>> jnp.linalg.multi_dot([w, x, y, z]) Array([[ 481, 831, 1181], [ 651, 1125, 1599]], dtype=int32) .. _opt_einsum: https://github.com/dgasmith/opt_einsum """ operands = (subscripts, *operands) if out is not None: raise NotImplementedError("The 'out' argument to jnp.einsum is not supported.") spec = operands[0] if isinstance(operands[0], str) else None path_type = 'optimal' if optimize is True else Unoptimized() if optimize is False else optimize # Extract __jax_array__ before passing to contract_path() operands = tuple(op.__jax_array__() if hasattr(op, "__jax_array__") else op for op in operands) # Allow handling of shape polymorphism non_constant_dim_types = { type(d) for op in operands if not isinstance(op, str) for d in np.shape(op) if not core.is_constant_dim(d) } if not non_constant_dim_types: contract_path: Any = opt_einsum.contract_path else: ty = next(iter(non_constant_dim_types)) contract_path = _poly_einsum_handlers.get(ty, _default_poly_einsum_handler) # using einsum_call=True here is an internal api for opt_einsum... 
sorry operands, contractions = contract_path( *operands, einsum_call=True, use_blas=True, optimize=path_type) contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions) num_contractions = len(contractions) out_sharding = canonicalize_sharding(out_sharding, 'einsum') if out_sharding is not None and not isinstance(out_sharding, NamedSharding): raise NotImplementedError( "`out_sharding` argument of `einsum` only supports NamedSharding" " instances.") jit_einsum = api.jit(_einsum, static_argnums=(1, 2, 3, 4, 5), inline=True) if spec is not None: jit_einsum = api.named_call(jit_einsum, name=spec) operand_arrays = list(util.ensure_arraylike_tuple("einsum", operands)) if num_contractions > 1 and out_sharding is not None: # TODO(yashkatariya): If the out_sharding is unreduced, figure out a way to # run the dot_general unreduced_rule on these einsums because right now we # drop into Auto mode skipping the checks happening in the rule. return auto_axes( jit_einsum, axes=out_sharding.mesh.explicit_axes, out_sharding=out_sharding, )(operand_arrays, contractions=contractions, precision=precision, preferred_element_type=preferred_element_type, _dot_general=_dot_general, out_sharding=None) else: return jit_einsum(operand_arrays, contractions, precision, preferred_element_type, _dot_general, out_sharding)
Einstein summation
JAX implementation of numpy.einsum.

einsum is a powerful and generic API for computing various reductions, inner products, outer products, axis reorderings, and combinations thereof across one or more input arrays. It has a somewhat complicated overloaded API; the arguments below reflect the most common calling convention. The Examples section below demonstrates some of the alternative calling conventions.

Args
subscripts- string containing axes names separated by commas.
*operands- sequence of one or more arrays corresponding to the subscripts.
optimize- specify how to optimize the order of computation. In JAX this defaults to "auto", which produces optimized expressions via the opt_einsum package. Other options are True (same as "optimal"), False (unoptimized), or any string supported by opt_einsum, which includes "optimal", "greedy", "eager", and others. It may also be a pre-computed path (see jax.numpy.einsum_path).
precision- either None (default), which means the default precision for the backend, or a jax.lax.Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST).
preferred_element_type- either None (default), which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype.
out- unsupported by JAX
_dot_general- optionally override the dot_general callable used by einsum. This parameter is experimental, and may be removed without warning at any time.

Returns
array containing the result of the einstein summation.

See also: jax.numpy.einsum_path

Examples
The mechanics of einsum are perhaps best demonstrated by example. Here we show how to use einsum to compute a number of quantities from one or more arrays. For more discussion and examples of einsum, see the documentation of numpy.einsum.

>>> M = jnp.arange(16).reshape(4, 4)
>>> x = jnp.arange(4)
>>> y = jnp.array([5, 4, 3, 2])

Vector product

>>> jnp.einsum('i,i', x, y)
Array(16, dtype=int32)
>>> jnp.vecdot(x, y)
Array(16, dtype=int32)

Here are some alternative einsum calling conventions to compute the same result:

>>> jnp.einsum('i,i->', x, y)  # explicit form
Array(16, dtype=int32)
>>> jnp.einsum(x, (0,), y, (0,))  # implicit form via indices
Array(16, dtype=int32)
>>> jnp.einsum(x, (0,), y, (0,), ())  # explicit form via indices
Array(16, dtype=int32)

Matrix product

>>> jnp.einsum('ij,j->i', M, x)  # explicit form
Array([14, 38, 62, 86], dtype=int32)
>>> jnp.matmul(M, x)
Array([14, 38, 62, 86], dtype=int32)

Here are some alternative einsum calling conventions to compute the same result:

>>> jnp.einsum('ij,j', M, x)  # implicit form
Array([14, 38, 62, 86], dtype=int32)
>>> jnp.einsum(M, (0, 1), x, (1,), (0,))  # explicit form via indices
Array([14, 38, 62, 86], dtype=int32)
>>> jnp.einsum(M, (0, 1), x, (1,))  # implicit form via indices
Array([14, 38, 62, 86], dtype=int32)

Outer product

>>> jnp.einsum("i,j->ij", x, y)
Array([[ 0,  0,  0,  0],
       [ 5,  4,  3,  2],
       [10,  8,  6,  4],
       [15, 12,  9,  6]], dtype=int32)
>>> jnp.outer(x, y)
Array([[ 0,  0,  0,  0],
       [ 5,  4,  3,  2],
       [10,  8,  6,  4],
       [15, 12,  9,  6]], dtype=int32)

Some other ways of computing outer products:

>>> jnp.einsum("i,j", x, y)  # implicit form
Array([[ 0,  0,  0,  0],
       [ 5,  4,  3,  2],
       [10,  8,  6,  4],
       [15, 12,  9,  6]], dtype=int32)
>>> jnp.einsum(x, (0,), y, (1,), (0, 1))  # explicit form via indices
Array([[ 0,  0,  0,  0],
       [ 5,  4,  3,  2],
       [10,  8,  6,  4],
       [15, 12,  9,  6]], dtype=int32)
>>> jnp.einsum(x, (0,), y, (1,))  # implicit form via indices
Array([[ 0,  0,  0,  0],
       [ 5,  4,  3,  2],
       [10,  8,  6,  4],
       [15, 12,  9,  6]], dtype=int32)

1D array sum

>>> jnp.einsum("i->", x)  # requires explicit form
Array(6, dtype=int32)
>>> jnp.einsum(x, (0,), ())  # explicit form via indices
Array(6, dtype=int32)
>>> jnp.sum(x)
Array(6, dtype=int32)

Sum along an axis

>>> jnp.einsum("...j->...", M)  # requires explicit form
Array([ 6, 22, 38, 54], dtype=int32)
>>> jnp.einsum(M, (..., 0), (...,))  # explicit form via indices
Array([ 6, 22, 38, 54], dtype=int32)
>>> M.sum(-1)
Array([ 6, 22, 38, 54], dtype=int32)

Matrix transpose

>>> y = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.einsum("ij->ji", y)  # explicit form
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
>>> jnp.einsum("ji", y)  # implicit form
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
>>> jnp.einsum(y, (1, 0))  # implicit form via indices
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
>>> jnp.einsum(y, (0, 1), (1, 0))  # explicit form via indices
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
>>> jnp.transpose(y)
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)

Matrix diagonal

>>> jnp.einsum("ii->i", M)
Array([ 0,  5, 10, 15], dtype=int32)
>>> jnp.diagonal(M)
Array([ 0,  5, 10, 15], dtype=int32)

Matrix trace

>>> jnp.einsum("ii", M)
Array(30, dtype=int32)
>>> jnp.trace(M)
Array(30, dtype=int32)

Tensor products

>>> x = jnp.arange(30).reshape(2, 3, 5)
>>> y = jnp.arange(60).reshape(3, 4, 5)
>>> jnp.einsum('ijk,jlk->il', x, y)  # explicit form
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=int32)
>>> jnp.tensordot(x, y, axes=[(1, 2), (0, 2)])
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=int32)
>>> jnp.einsum('ijk,jlk', x, y)  # implicit form
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=int32)
>>> jnp.einsum(x, (0, 1, 2), y, (1, 3, 2), (0, 3))  # explicit form via indices
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=int32)
>>> jnp.einsum(x, (0, 1, 2), y, (1, 3, 2))  # implicit form via indices
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=int32)

Chained dot products

>>> w = jnp.arange(5, 9).reshape(2, 2)
>>> x = jnp.arange(6).reshape(2, 3)
>>> y = jnp.arange(-2, 4).reshape(3, 2)
>>> z = jnp.array([[2, 4, 6], [3, 5, 7]])
>>> jnp.einsum('ij,jk,kl,lm->im', w, x, y, z)
Array([[ 481,  831, 1181],
       [ 651, 1125, 1599]], dtype=int32)
>>> jnp.einsum(w, (0, 1), x, (1, 2), y, (2, 3), z, (3, 4))  # implicit, via indices
Array([[ 481,  831, 1181],
       [ 651, 1125, 1599]], dtype=int32)
>>> w @ x @ y @ z  # direct chain of matmuls
Array([[ 481,  831, 1181],
       [ 651, 1125, 1599]], dtype=int32)
>>> jnp.linalg.multi_dot([w, x, y, z])
Array([[ 481,  831, 1181],
       [ 651, 1125, 1599]], dtype=int32)

opt_einsum: https://github.com/dgasmith/opt_einsum
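A contraction path computed once with jnp.einsum_path can be passed back through the optimize argument and reused across calls (a minimal sketch; the path value follows the opt_einsum format):

>>> x = jnp.arange(30).reshape(2, 3, 5)
>>> y = jnp.arange(60).reshape(3, 4, 5)
>>> path = jnp.einsum_path('ijk,jlk->il', x, y, optimize='optimal')[0]
>>> jnp.einsum('ijk,jlk->il', x, y, optimize=path)
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=int32)

Precomputing the path mainly pays off for expressions with three or more operands, where the contraction order matters.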
def equal(x: ArrayLike, y: ArrayLike, /) ‑> jax.jaxlib._jax.Array-
Expand source code
@export @jit(inline=True) def equal(x: ArrayLike, y: ArrayLike, /) -> Array: """Returns element-wise truth value of ``x == y``. JAX implementation of :obj:`numpy.equal`. This function provides the implementation of the ``==`` operator for JAX arrays. Args: x: input array or scalar. y: input array or scalar. ``x`` and ``y`` should either have same shape or be broadcast compatible. Returns: A boolean array containing ``True`` where the elements of ``x == y`` and ``False`` otherwise. See also: - :func:`jax.numpy.not_equal`: Returns element-wise truth value of ``x != y``. - :func:`jax.numpy.greater_equal`: Returns element-wise truth value of ``x >= y``. - :func:`jax.numpy.less_equal`: Returns element-wise truth value of ``x <= y``. - :func:`jax.numpy.greater`: Returns element-wise truth value of ``x > y``. - :func:`jax.numpy.less`: Returns element-wise truth value of ``x < y``. Examples: >>> jnp.equal(0., -0.) Array(True, dtype=bool, weak_type=True) >>> jnp.equal(1, 1.) Array(True, dtype=bool, weak_type=True) >>> jnp.equal(5, jnp.array(5)) Array(True, dtype=bool, weak_type=True) >>> jnp.equal(2, -2) Array(False, dtype=bool, weak_type=True) >>> x = jnp.array([[1, 2, 3], ... [4, 5, 6], ... [7, 8, 9]]) >>> y = jnp.array([1, 5, 9]) >>> jnp.equal(x, y) Array([[ True, False, False], [False, True, False], [False, False, True]], dtype=bool) >>> x == y Array([[ True, False, False], [False, True, False], [False, False, True]], dtype=bool) """ return lax.eq(*promote_args("equal", x, y))Returns element-wise truth value of
x == y.

JAX implementation of numpy.equal. This function provides the implementation of the == operator for JAX arrays.

Args
x- input array or scalar.
y- input array or scalar. x and y should either have same shape or be broadcast compatible.

Returns
A boolean array containing True where the elements of x == y and False otherwise.

See also:
- jax.numpy.not_equal: Returns element-wise truth value of x != y.
- jax.numpy.greater_equal: Returns element-wise truth value of x >= y.
- jax.numpy.less_equal: Returns element-wise truth value of x <= y.
- jax.numpy.greater: Returns element-wise truth value of x > y.
- jax.numpy.less: Returns element-wise truth value of x < y.

Examples
>>> jnp.equal(0., -0.)
Array(True, dtype=bool, weak_type=True)
>>> jnp.equal(1, 1.)
Array(True, dtype=bool, weak_type=True)
>>> jnp.equal(5, jnp.array(5))
Array(True, dtype=bool, weak_type=True)
>>> jnp.equal(2, -2)
Array(False, dtype=bool, weak_type=True)
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6],
...                [7, 8, 9]])
>>> y = jnp.array([1, 5, 9])
>>> jnp.equal(x, y)
Array([[ True, False, False],
       [False,  True, False],
       [False, False,  True]], dtype=bool)
>>> x == y
Array([[ True, False, False],
       [False,  True, False],
       [False, False,  True]], dtype=bool)

def erf(x: ArrayLike) ‑> jax.jaxlib._jax.Array-
Expand source code
def erf(x: ArrayLike) -> Array: r"""The error function JAX implementation of :obj:`scipy.special.erf`. .. math:: \mathrm{erf}(x) = \frac{2}{\sqrt\pi} \int_{0}^x e^{-t^2} \mathrm{d}t Args: x: arraylike, real-valued. Returns: array containing values of the error function. Notes: The JAX version only supports real-valued inputs. See also: - :func:`jax.scipy.special.erfc` - :func:`jax.scipy.special.erfinv` """ x, = promote_args_inexact("erf", x) return lax.erf(x)The error function
JAX implementation of scipy.special.erf.

\[ \mathrm{erf}(x) = \frac{2}{\sqrt\pi} \int_{0}^x e^{-t^2} \, \mathrm{d}t \]

Args
x- arraylike, real-valued.

Returns
array containing values of the error function.

Notes
The JAX version only supports real-valued inputs.
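For instance (a minimal usage sketch; outputs printed at three decimals):

>>> from jax.scipy.special import erf
>>> x = jnp.array([-1.0, 0.0, 1.0])
>>> with jnp.printoptions(precision=3, suppress=True):
...   print(erf(x))
[-0.843  0.     0.843]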
See also:
- jax.scipy.special.erfc
- jax.scipy.special.erfinv

def exp(x: ArrayLike, /) ‑> jax.jaxlib._jax.Array-
Expand source code
@export @jit(inline=True) def exp(x: ArrayLike, /) -> Array: """Calculate element-wise exponential of the input. JAX implementation of :obj:`numpy.exp`. Args: x: input array or scalar Returns: An array containing the exponential of each element in ``x``, promotes to inexact dtype. See also: - :func:`jax.numpy.log`: Calculates element-wise logarithm of the input. - :func:`jax.numpy.expm1`: Calculates :math:`e^x-1` of each element of the input. - :func:`jax.numpy.exp2`: Calculates base-2 exponential of each element of the input. Examples: ``jnp.exp`` follows the properties of exponential such as :math:`e^{(a+b)} = e^a * e^b`. >>> x1 = jnp.array([2, 4, 3, 1]) >>> x2 = jnp.array([1, 3, 2, 3]) >>> with jnp.printoptions(precision=2, suppress=True): ... print(jnp.exp(x1+x2)) [ 20.09 1096.63 148.41 54.6 ] >>> with jnp.printoptions(precision=2, suppress=True): ... print(jnp.exp(x1)*jnp.exp(x2)) [ 20.09 1096.63 148.41 54.6 ] This property holds for complex input also: >>> jnp.allclose(jnp.exp(3-4j), jnp.exp(3)*jnp.exp(-4j)) Array(True, dtype=bool) """ return lax.exp(*promote_args_inexact('exp', x))Calculate element-wise exponential of the input.
JAX implementation of numpy.exp.

Args
x- input array or scalar

Returns
An array containing the exponential of each element in x, promotes to inexact dtype.

See also:
- jax.numpy.log: Calculates element-wise logarithm of the input.
- jax.numpy.expm1: Calculates e^x - 1 of each element of the input.
- jax.numpy.exp2: Calculates base-2 exponential of each element of the input.

Examples
jnp.exp follows the properties of exponential such as e^(a+b) = e^a * e^b.

>>> x1 = jnp.array([2, 4, 3, 1])
>>> x2 = jnp.array([1, 3, 2, 3])
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.exp(x1+x2))
[  20.09 1096.63  148.41   54.6 ]
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.exp(x1)*jnp.exp(x2))
[  20.09 1096.63  148.41   54.6 ]

This property holds for complex input also:

>>> jnp.allclose(jnp.exp(3-4j), jnp.exp(3)*jnp.exp(-4j))
Array(True, dtype=bool)

def flip(m: ArrayLike, axis: int | Sequence[int] | None = None) ‑> jax.jaxlib._jax.Array-
Expand source code
@export def flip(m: ArrayLike, axis: int | Sequence[int] | None = None) -> Array: """Reverse the order of elements of an array along the given axis. JAX implementation of :func:`numpy.flip`. Args: m: Array. axis: integer or sequence of integers. Specifies along which axis or axes should the array elements be reversed. Default is ``None``, which flips along all axes. Returns: An array with the elements in reverse order along ``axis``. See Also: - :func:`jax.numpy.fliplr`: reverse the order along axis 1 (left/right) - :func:`jax.numpy.flipud`: reverse the order along axis 0 (up/down) Examples: >>> x1 = jnp.array([[1, 2], ... [3, 4]]) >>> jnp.flip(x1) Array([[4, 3], [2, 1]], dtype=int32) If ``axis`` is specified with an integer, then ``jax.numpy.flip`` reverses the array along that particular axis only. >>> jnp.flip(x1, axis=1) Array([[2, 1], [4, 3]], dtype=int32) >>> x2 = jnp.arange(1, 9).reshape(2, 2, 2) >>> x2 Array([[[1, 2], [3, 4]], <BLANKLINE> [[5, 6], [7, 8]]], dtype=int32) >>> jnp.flip(x2) Array([[[8, 7], [6, 5]], <BLANKLINE> [[4, 3], [2, 1]]], dtype=int32) When ``axis`` is specified with a sequence of integers, then ``jax.numpy.flip`` reverses the array along the specified axes. >>> jnp.flip(x2, axis=[1, 2]) Array([[[4, 3], [2, 1]], <BLANKLINE> [[8, 7], [6, 5]]], dtype=int32) """ arr = util.ensure_arraylike("flip", m) return _flip(arr, reductions._ensure_optional_axes(axis))Reverse the order of elements of an array along the given axis.
JAX implementation of numpy.flip.

Args
m- Array.
axis- integer or sequence of integers. Specifies along which axis or axes should the array elements be reversed. Default is None, which flips along all axes.

Returns
An array with the elements in reverse order along axis.

See also:
- jax.numpy.fliplr: reverse the order along axis 1 (left/right)
- jax.numpy.flipud: reverse the order along axis 0 (up/down)

Examples
>>> x1 = jnp.array([[1, 2],
...                 [3, 4]])
>>> jnp.flip(x1)
Array([[4, 3],
       [2, 1]], dtype=int32)

If axis is specified with an integer, then jax.numpy.flip reverses the array along that particular axis only.

>>> jnp.flip(x1, axis=1)
Array([[2, 1],
       [4, 3]], dtype=int32)

>>> x2 = jnp.arange(1, 9).reshape(2, 2, 2)
>>> x2
Array([[[1, 2],
        [3, 4]],
<BLANKLINE>
       [[5, 6],
        [7, 8]]], dtype=int32)
>>> jnp.flip(x2)
Array([[[8, 7],
        [6, 5]],
<BLANKLINE>
       [[4, 3],
        [2, 1]]], dtype=int32)

When axis is specified with a sequence of integers, then jax.numpy.flip reverses the array along the specified axes.

>>> jnp.flip(x2, axis=[1, 2])
Array([[[4, 3],
        [2, 1]],
<BLANKLINE>
       [[8, 7],
        [6, 5]]], dtype=int32)

def floor(x: ArrayLike, /) ‑> jax.jaxlib._jax.Array-
Expand source code
@export @jit(inline=True) def floor(x: ArrayLike, /) -> Array: """Round input to the nearest integer downwards. JAX implementation of :obj:`numpy.floor`. Args: x: input array or scalar. Must not have complex dtype. Returns: An array with same shape and dtype as ``x`` containing the values rounded to the nearest integer that is less than or equal to the value itself. See also: - :func:`jax.numpy.fix`: Rounds the input to the nearest integer towards zero. - :func:`jax.numpy.trunc`: Rounds the input to the nearest integer towards zero. - :func:`jax.numpy.ceil`: Rounds the input up to the nearest integer. Examples: >>> key = jax.random.key(42) >>> x = jax.random.uniform(key, (3, 3), minval=-5, maxval=5) >>> with jnp.printoptions(precision=2, suppress=True): ... print(x) [[-0.11 1.8 1.16] [ 0.61 -0.49 0.86] [-4.25 2.75 1.99]] >>> jnp.floor(x) Array([[-1., 1., 1.], [ 0., -1., 0.], [-5., 2., 1.]], dtype=float32) """ x = ensure_arraylike('floor', x) if dtypes.isdtype(x.dtype, ('integral', 'bool')): return x return lax.floor(*promote_args_inexact('floor', x))Round input to the nearest integer downwards.
JAX implementation of numpy.floor.

Args
x- input array or scalar. Must not have complex dtype.

Returns
An array with same shape and dtype as x containing the values rounded to the nearest integer that is less than or equal to the value itself.

See also:
- jax.numpy.fix: Rounds the input to the nearest integer towards zero.
- jax.numpy.trunc: Rounds the input to the nearest integer towards zero.
- jax.numpy.ceil: Rounds the input up to the nearest integer.

Examples
>>> key = jax.random.key(42)
>>> x = jax.random.uniform(key, (3, 3), minval=-5, maxval=5)
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(x)
[[-0.11  1.8   1.16]
 [ 0.61 -0.49  0.86]
 [-4.25  2.75  1.99]]
>>> jnp.floor(x)
Array([[-1.,  1.,  1.],
       [ 0., -1.,  0.],
       [-5.,  2.,  1.]], dtype=float32)

def imag(val: ArrayLike, /) ‑> jax.jaxlib._jax.Array-
Expand source code
@export @jit(inline=True) def imag(val: ArrayLike, /) -> Array: """Return element-wise imaginary of part of the complex argument. JAX implementation of :obj:`numpy.imag`. Args: val: input array or scalar. Returns: An array containing the imaginary part of the elements of ``val``. See also: - :func:`jax.numpy.conjugate` and :func:`jax.numpy.conj`: Returns the element-wise complex-conjugate of the input. - :func:`jax.numpy.real`: Returns the element-wise real part of the complex argument. Examples: >>> jnp.imag(4) Array(0, dtype=int32, weak_type=True) >>> jnp.imag(5j) Array(5., dtype=float32, weak_type=True) >>> x = jnp.array([2+3j, 5-1j, -3]) >>> jnp.imag(x) Array([ 3., -1., 0.], dtype=float32) """ val = ensure_arraylike("imag", val) return lax.imag(val) if np.iscomplexobj(val) else lax.full_like(val, 0)Return element-wise imaginary of part of the complex argument.
JAX implementation of numpy.imag.

Args
val- input array or scalar.

Returns
An array containing the imaginary part of the elements of val.

See also:
- jax.numpy.conjugate and jax.numpy.conj: Returns the element-wise complex-conjugate of the input.
- jax.numpy.real: Returns the element-wise real part of the complex argument.

Examples
>>> jnp.imag(4)
Array(0, dtype=int32, weak_type=True)
>>> jnp.imag(5j)
Array(5., dtype=float32, weak_type=True)
>>> x = jnp.array([2+3j, 5-1j, -3])
>>> jnp.imag(x)
Array([ 3., -1.,  0.], dtype=float32)

def isfinite(x: ArrayLike, /) ‑> jax.jaxlib._jax.Array-
Expand source code
@export @jit(inline=True) def isfinite(x: ArrayLike, /) -> Array: """Return a boolean array indicating whether each element of input is finite. JAX implementation of :obj:`numpy.isfinite`. Args: x: input array or scalar. Returns: A boolean array of same shape as ``x`` containing ``True`` where ``x`` is not ``inf``, ``-inf``, or ``NaN``, and ``False`` otherwise. See also: - :func:`jax.numpy.isinf`: Returns a boolean array indicating whether each element of input is either positive or negative infinity. - :func:`jax.numpy.isposinf`: Returns a boolean array indicating whether each element of input is positive infinity. - :func:`jax.numpy.isneginf`: Returns a boolean array indicating whether each element of input is negative infinity. - :func:`jax.numpy.isnan`: Returns a boolean array indicating whether each element of input is not a number (``NaN``). Examples: >>> x = jnp.array([-1, 3, jnp.inf, jnp.nan]) >>> jnp.isfinite(x) Array([ True, True, False, False], dtype=bool) >>> jnp.isfinite(3-4j) Array(True, dtype=bool, weak_type=True) """ x = ensure_arraylike("isfinite", x) dtype = x.dtype if dtypes.issubdtype(dtype, np.floating): return lax.is_finite(x) elif dtypes.issubdtype(dtype, np.complexfloating): return lax.bitwise_and(lax.is_finite(real(x)), lax.is_finite(imag(x))) else: return lax.full_like(x, True, dtype=np.bool_)Return a boolean array indicating whether each element of input is finite.
JAX implementation of numpy.isfinite.

Args
x- input array or scalar.

Returns
A boolean array of same shape as x containing True where x is not inf, -inf, or NaN, and False otherwise.

See also:
- jax.numpy.isinf: Returns a boolean array indicating whether each element of input is either positive or negative infinity.
- jax.numpy.isposinf: Returns a boolean array indicating whether each element of input is positive infinity.
- jax.numpy.isneginf: Returns a boolean array indicating whether each element of input is negative infinity.
- jax.numpy.isnan: Returns a boolean array indicating whether each element of input is not a number (NaN).

Examples
>>> x = jnp.array([-1, 3, jnp.inf, jnp.nan])
>>> jnp.isfinite(x)
Array([ True,  True, False, False], dtype=bool)
>>> jnp.isfinite(3-4j)
Array(True, dtype=bool, weak_type=True)

def isinf(x: ArrayLike, /) ‑> jax.jaxlib._jax.Array-
Expand source code
@export @jit def isinf(x: ArrayLike, /) -> Array: """Return a boolean array indicating whether each element of input is infinite. JAX implementation of :obj:`numpy.isinf`. Args: x: input array or scalar. Returns: A boolean array of same shape as ``x`` containing ``True`` where ``x`` is ``inf`` or ``-inf``, and ``False`` otherwise. See also: - :func:`jax.numpy.isposinf`: Returns a boolean array indicating whether each element of input is positive infinity. - :func:`jax.numpy.isneginf`: Returns a boolean array indicating whether each element of input is negative infinity. - :func:`jax.numpy.isfinite`: Returns a boolean array indicating whether each element of input is finite. - :func:`jax.numpy.isnan`: Returns a boolean array indicating whether each element of input is not a number (``NaN``). Examples: >>> jnp.isinf(jnp.inf) Array(True, dtype=bool) >>> x = jnp.array([2+3j, -jnp.inf, 6, jnp.inf, jnp.nan]) >>> jnp.isinf(x) Array([False, True, False, True, False], dtype=bool) """ x = ensure_arraylike("isinf", x) dtype = x.dtype if dtypes.issubdtype(dtype, np.floating): return lax.eq(lax.abs(x), _constant_like(x, np.inf)) elif dtypes.issubdtype(dtype, np.complexfloating): re = lax.real(x) im = lax.imag(x) return lax.bitwise_or(lax.eq(lax.abs(re), _constant_like(re, np.inf)), lax.eq(lax.abs(im), _constant_like(im, np.inf))) else: return lax.full_like(x, False, dtype=np.bool_)Return a boolean array indicating whether each element of input is infinite.
JAX implementation of numpy.isinf.

Args
x- input array or scalar.

Returns
A boolean array of same shape as x containing True where x is inf or -inf, and False otherwise.

See also:
- jax.numpy.isposinf: Returns a boolean array indicating whether each element of input is positive infinity.
- jax.numpy.isneginf: Returns a boolean array indicating whether each element of input is negative infinity.
- jax.numpy.isfinite: Returns a boolean array indicating whether each element of input is finite.
- jax.numpy.isnan: Returns a boolean array indicating whether each element of input is not a number (NaN).

Examples
>>> jnp.isinf(jnp.inf)
Array(True, dtype=bool)
>>> x = jnp.array([2+3j, -jnp.inf, 6, jnp.inf, jnp.nan])
>>> jnp.isinf(x)
Array([False,  True, False,  True, False], dtype=bool)

def isnan(x: ArrayLike, /) ‑> jax.jaxlib._jax.Array-
Expand source code
@export @jit(inline=True) def isnan(x: ArrayLike, /) -> Array: """Returns a boolean array indicating whether each element of input is ``NaN``. JAX implementation of :obj:`numpy.isnan`. Args: x: input array or scalar. Returns: A boolean array of same shape as ``x`` containing ``True`` where ``x`` is not a number (i.e. ``NaN``) and ``False`` otherwise. See also: - :func:`jax.numpy.isfinite`: Returns a boolean array indicating whether each element of input is finite. - :func:`jax.numpy.isinf`: Returns a boolean array indicating whether each element of input is either positive or negative infinity. - :func:`jax.numpy.isposinf`: Returns a boolean array indicating whether each element of input is positive infinity. - :func:`jax.numpy.isneginf`: Returns a boolean array indicating whether each element of input is negative infinity. Examples: >>> jnp.isnan(6) Array(False, dtype=bool, weak_type=True) >>> x = jnp.array([2, 1+4j, jnp.inf, jnp.nan]) >>> jnp.isnan(x) Array([False, False, False, True], dtype=bool) """ x = ensure_arraylike("isnan", x) return lax.ne(x, x)Returns a boolean array indicating whether each element of input is
NaN.

JAX implementation of numpy.isnan.

Args
x- input array or scalar.

Returns
A boolean array of same shape as x containing True where x is not a number (i.e. NaN) and False otherwise.

See also:
- jax.numpy.isfinite: Returns a boolean array indicating whether each element of input is finite.
- jax.numpy.isinf: Returns a boolean array indicating whether each element of input is either positive or negative infinity.
- jax.numpy.isposinf: Returns a boolean array indicating whether each element of input is positive infinity.
- jax.numpy.isneginf: Returns a boolean array indicating whether each element of input is negative infinity.

Examples
>>> jnp.isnan(6)
Array(False, dtype=bool, weak_type=True)
>>> x = jnp.array([2, 1+4j, jnp.inf, jnp.nan])
>>> jnp.isnan(x)
Array([False, False, False,  True], dtype=bool)

def log(x: ArrayLike, /) ‑> jax.jaxlib._jax.Array-
Expand source code
@export @jit(inline=True) def log(x: ArrayLike, /) -> Array: """Calculate element-wise natural logarithm of the input. JAX implementation of :obj:`numpy.log`. Args: x: input array or scalar. Returns: An array containing the logarithm of each element in ``x``, promotes to inexact dtype. See also: - :func:`jax.numpy.exp`: Calculates element-wise exponential of the input. - :func:`jax.numpy.log2`: Calculates base-2 logarithm of each element of input. - :func:`jax.numpy.log1p`: Calculates element-wise logarithm of one plus input. Examples: ``jnp.log`` and ``jnp.exp`` are inverse functions of each other. Applying ``jnp.log`` on the result of ``jnp.exp(x)`` yields the original input ``x``. >>> x = jnp.array([2, 3, 4, 5]) >>> jnp.log(jnp.exp(x)) Array([2., 3., 4., 5.], dtype=float32) Using ``jnp.log`` we can demonstrate well-known properties of logarithms, such as :math:`log(a*b) = log(a)+log(b)`. >>> x1 = jnp.array([2, 1, 3, 1]) >>> x2 = jnp.array([1, 3, 2, 4]) >>> jnp.allclose(jnp.log(x1*x2), jnp.log(x1)+jnp.log(x2)) Array(True, dtype=bool) """ out = lax.log(*promote_args_inexact('log', x)) jnp_error._set_error_if_nan(out) return outCalculate element-wise natural logarithm of the input.
JAX implementation of numpy.log.

Args
x- input array or scalar.

Returns
An array containing the logarithm of each element in x, promotes to inexact dtype.

See also:
- jax.numpy.exp: Calculates element-wise exponential of the input.
- jax.numpy.log2: Calculates base-2 logarithm of each element of input.
- jax.numpy.log1p: Calculates element-wise logarithm of one plus input.

Examples
jnp.log and jnp.exp are inverse functions of each other. Applying jnp.log on the result of jnp.exp(x) yields the original input x.

>>> x = jnp.array([2, 3, 4, 5])
>>> jnp.log(jnp.exp(x))
Array([2., 3., 4., 5.], dtype=float32)

Using jnp.log we can demonstrate well-known properties of logarithms, such as log(a*b) = log(a) + log(b).

>>> x1 = jnp.array([2, 1, 3, 1])
>>> x2 = jnp.array([1, 3, 2, 4])
>>> jnp.allclose(jnp.log(x1*x2), jnp.log(x1)+jnp.log(x2))
Array(True, dtype=bool)

def log10(x: ArrayLike, /) ‑> jax.jaxlib._jax.Array-
Expand source code
@export @jit(inline=True) def log10(x: ArrayLike, /) -> Array: """Calculates the base-10 logarithm of x element-wise JAX implementation of :obj:`numpy.log10`. Args: x: Input array Returns: An array containing the base-10 logarithm of each element in ``x``, promotes to inexact dtype. Examples: >>> x1 = jnp.array([0.01, 0.1, 1, 10, 100, 1000]) >>> with jnp.printoptions(precision=2, suppress=True): ... print(jnp.log10(x1)) [-2. -1. 0. 1. 2. 3.] """ x, = promote_args_inexact("log10", x) one_over_log10 = np.array(0.4342944819032518, # exact value of 1 / log(10) dtype=dtypes.finfo(x.dtype).dtype) if dtypes.issubdtype(x.dtype, np.complexfloating): r = lax.log(x) re = lax.real(r) im = lax.imag(r) return lax.complex(lax.mul(re, one_over_log10), lax.mul(im, one_over_log10)) out = lax.mul(lax.log(x), one_over_log10) jnp_error._set_error_if_nan(out) return outCalculates the base-10 logarithm of x element-wise
JAX implementation of numpy.log10.

Args
x- Input array

Returns
An array containing the base-10 logarithm of each element in x, promotes to inexact dtype.

Examples
>>> x1 = jnp.array([0.01, 0.1, 1, 10, 100, 1000])
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.log10(x1))
[-2. -1.  0.  1.  2.  3.]

def log2(x: ArrayLike, /) ‑> jax.jaxlib._jax.Array-
Expand source code
@export @jit(inline=True) def log2(x: ArrayLike, /) -> Array: """Calculates the base-2 logarithm of ``x`` element-wise. JAX implementation of :obj:`numpy.log2`. Args: x: Input array Returns: An array containing the base-2 logarithm of each element in ``x``, promotes to inexact dtype. Examples: >>> x1 = jnp.array([0.25, 0.5, 1, 2, 4, 8]) >>> jnp.log2(x1) Array([-2., -1., 0., 1., 2., 3.], dtype=float32) """ x, = promote_args_inexact("log2", x) if dtypes.issubdtype(x.dtype, np.complexfloating): r = lax.log(x) re = lax.real(r) im = lax.imag(r) ln2 = lax.log(_constant_like(re, 2)) return lax.complex(lax.div(re, ln2), lax.div(im, ln2)) out = lax.div(lax.log(x), lax.log(_constant_like(x, 2))) jnp_error._set_error_if_nan(out) return outCalculates the base-2 logarithm of
x element-wise.

JAX implementation of numpy.log2.

Args
x- Input array

Returns
An array containing the base-2 logarithm of each element in x, promotes to inexact dtype.

Examples
>>> x1 = jnp.array([0.25, 0.5, 1, 2, 4, 8])
>>> jnp.log2(x1)
Array([-2., -1.,  0.,  1.,  2.,  3.], dtype=float32)

def real(val: ArrayLike, /) ‑> jax.jaxlib._jax.Array-
Expand source code
@export @jit(inline=True) def real(val: ArrayLike, /) -> Array: """Return element-wise real part of the complex argument. JAX implementation of :obj:`numpy.real`. Args: val: input array or scalar. Returns: An array containing the real part of the elements of ``val``. See also: - :func:`jax.numpy.conjugate` and :func:`jax.numpy.conj`: Returns the element-wise complex-conjugate of the input. - :func:`jax.numpy.imag`: Returns the element-wise imaginary part of the complex argument. Examples: >>> jnp.real(5) Array(5, dtype=int32, weak_type=True) >>> jnp.real(2j) Array(0., dtype=float32, weak_type=True) >>> x = jnp.array([3-2j, 4+7j, -2j]) >>> jnp.real(x) Array([ 3., 4., -0.], dtype=float32) """ val = ensure_arraylike("real", val) return lax.real(val) if np.iscomplexobj(val) else lax.asarray(val)Return element-wise real part of the complex argument.
JAX implementation of numpy.real.

Args
val- input array or scalar.

Returns
An array containing the real part of the elements of val.

See also:
- jax.numpy.conjugate and jax.numpy.conj: Returns the element-wise complex-conjugate of the input.
- jax.numpy.imag: Returns the element-wise imaginary part of the complex argument.

Examples
>>> jnp.real(5)
Array(5, dtype=int32, weak_type=True)
>>> jnp.real(2j)
Array(0., dtype=float32, weak_type=True)
>>> x = jnp.array([3-2j, 4+7j, -2j])
>>> jnp.real(x)
Array([ 3.,  4., -0.], dtype=float32)

def round(a: ArrayLike, decimals: int = 0, out: None = None) ‑> jax.jaxlib._jax.Array-
Expand source code
@export @api.jit(static_argnames=('decimals',)) def round(a: ArrayLike, decimals: int = 0, out: None = None) -> Array: """Round input evenly to the given number of decimals. JAX implementation of :func:`numpy.round`. Args: a: input array or scalar. decimals: int, default=0. Number of decimal points to which the input needs to be rounded. It must be specified statically. Not implemented for ``decimals < 0``. out: Unused by JAX. Returns: An array containing the rounded values to the specified ``decimals`` with same shape and dtype as ``a``. Note: ``jnp.round`` rounds to the nearest even integer for the values exactly halfway between rounded decimal values. See also: - :func:`jax.numpy.floor`: Rounds the input to the nearest integer downwards. - :func:`jax.numpy.ceil`: Rounds the input to the nearest integer upwards. - :func:`jax.numpy.fix` and :func:numpy.trunc`: Rounds the input to the nearest integer towards zero. Examples: >>> x = jnp.array([1.532, 3.267, 6.149]) >>> jnp.round(x) Array([2., 3., 6.], dtype=float32) >>> jnp.round(x, decimals=2) Array([1.53, 3.27, 6.15], dtype=float32) For values exactly halfway between rounded values: >>> x1 = jnp.array([10.5, 21.5, 12.5, 31.5]) >>> jnp.round(x1) Array([10., 22., 12., 32.], dtype=float32) """ a = util.ensure_arraylike("round", a) decimals = core.concrete_or_error(operator.index, decimals, "'decimals' argument of jnp.round") if out is not None: raise NotImplementedError("The 'out' argument to jnp.round is not supported.") dtype = a.dtype if issubdtype(dtype, np.integer): if decimals < 0: raise NotImplementedError( "integer np.round not implemented for decimals < 0") return a # no-op on integer types def _round_float(x: ArrayLike) -> Array: if decimals == 0: return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN) # TODO(phawkins): the strategy of rescaling the value isn't necessarily a # good one since we may be left with an incorrectly rounded value at the # end due to precision problems. As a workaround for float16, convert to # float32, x = lax.convert_element_type(x, np.float32) if dtype == np.float16 else x factor = lax._const(x, 10 ** decimals) out = lax.div(lax.round(lax.mul(x, factor), lax.RoundingMethod.TO_NEAREST_EVEN), factor) return lax.convert_element_type(out, dtype) if dtype == np.float16 else out if decimals > np.log10(dtypes.finfo(dtype).max): # Rounding beyond the input precision is a no-op. return lax.asarray(a) if issubdtype(dtype, np.complexfloating): return lax.complex(_round_float(lax.real(a)), _round_float(lax.imag(a))) else: return _round_float(a)Round input evenly to the given number of decimals.
JAX implementation of numpy.round.

Args
a- input array or scalar.
decimals- int, default=0. Number of decimal points to which the input needs to be rounded. It must be specified statically. Not implemented for decimals < 0.
out- Unused by JAX.

Returns
An array containing the rounded values to the specified decimals with same shape and dtype as a.

Note
jnp.round rounds to the nearest even integer for the values exactly halfway between rounded decimal values.

See also:
- jax.numpy.floor: Rounds the input to the nearest integer downwards.
- jax.numpy.ceil: Rounds the input up to the nearest integer.
- jax.numpy.fix and jax.numpy.trunc: Rounds the input to the nearest integer towards zero.

Examples
>>> x = jnp.array([1.532, 3.267, 6.149])
>>> jnp.round(x)
Array([2., 3., 6.], dtype=float32)
>>> jnp.round(x, decimals=2)
Array([1.53, 3.27, 6.15], dtype=float32)

For values exactly halfway between rounded values:

>>> x1 = jnp.array([10.5, 21.5, 12.5, 31.5])
>>> jnp.round(x1)
Array([10., 22., 12., 32.], dtype=float32)

def shape(a: ArrayLike | SupportsShape) ‑> tuple[int, ...]-
Expand source code
@export def shape(a: ArrayLike | SupportsShape) -> tuple[int, ...]: """Return the shape an array. JAX implementation of :func:`numpy.shape`. Unlike ``np.shape``, this function raises a :class:`TypeError` if the input is a collection such as a list or tuple. Args: a: array-like object, or any object with a ``shape`` attribute. Returns: An tuple of integers representing the shape of ``a``. Examples: Shape for arrays: >>> x = jnp.arange(10) >>> jnp.shape(x) (10,) >>> y = jnp.ones((2, 3)) >>> jnp.shape(y) (2, 3) This also works for scalars: >>> jnp.shape(3.14) () For arrays, this can also be accessed via the :attr:`jax.Array.shape` property: >>> x.shape (10,) """ if hasattr(a, "shape"): return a.shape # Deprecation warning added 2025-2-20. check_arraylike("shape", a, emit_warning=True) if hasattr(a, "__jax_array__"): a = a.__jax_array__() # NumPy dispatches to a.shape if available. return np.shape(a)Return the shape an array.
JAX implementation of numpy.shape. Unlike np.shape, this function raises a TypeError if the input is a collection such as a list or tuple.

Args
a- array-like object, or any object with a shape attribute.

Returns
A tuple of integers representing the shape of a.

Examples
Shape for arrays:

>>> x = jnp.arange(10)
>>> jnp.shape(x)
(10,)
>>> y = jnp.ones((2, 3))
>>> jnp.shape(y)
(2, 3)

This also works for scalars:

>>> jnp.shape(3.14)
()

For arrays, this can also be accessed via the jax.Array.shape property:

>>> x.shape
(10,)

def sign(x: ArrayLike, /) ‑> jax.jaxlib._jax.Array-
Expand source code
@export @jit(inline=True) def sign(x: ArrayLike, /) -> Array: r"""Return an element-wise indication of sign of the input. JAX implementation of :obj:`numpy.sign`. The sign of ``x`` for real-valued input is: .. math:: \mathrm{sign}(x) = \begin{cases} 1, & x > 0\\ 0, & x = 0\\ -1, & x < 0 \end{cases} For complex valued input, ``jnp.sign`` returns a unit vector representing the phase. For generalized case, the sign of ``x`` is given by: .. math:: \mathrm{sign}(x) = \begin{cases} \frac{x}{abs(x)}, & x \ne 0\\ 0, & x = 0 \end{cases} Args: x: input array or scalar. Returns: An array with same shape and dtype as ``x`` containing the sign indication. See also: - :func:`jax.numpy.positive`: Returns element-wise positive values of the input. - :func:`jax.numpy.negative`: Returns element-wise negative values of the input. Examples: For Real-valued inputs: >>> x = jnp.array([0., -3., 7.]) >>> jnp.sign(x) Array([ 0., -1., 1.], dtype=float32) For complex-inputs: >>> x1 = jnp.array([1, 3+4j, 5j]) >>> jnp.sign(x1) Array([1. +0.j , 0.6+0.8j, 0. +1.j ], dtype=complex64) """ return lax.sign(*promote_args('sign', x))Return an element-wise indication of sign of the input.
JAX implementation of numpy.sign.

The sign of x for real-valued input is:

\[ \mathrm{sign}(x) = \begin{cases} 1, & x > 0 \\ 0, & x = 0 \\ -1, & x < 0 \end{cases} \]

For complex valued input, jnp.sign returns a unit vector representing the phase. For the generalized case, the sign of x is given by:

\[ \mathrm{sign}(x) = \begin{cases} \frac{x}{\mathrm{abs}(x)}, & x \ne 0 \\ 0, & x = 0 \end{cases} \]

Args
x- input array or scalar.

Returns
An array with same shape and dtype as x containing the sign indication.

See also:
- jax.numpy.positive: Returns element-wise positive values of the input.
- jax.numpy.negative: Returns element-wise negative values of the input.

Examples
For real-valued inputs:

>>> x = jnp.array([0., -3., 7.])
>>> jnp.sign(x)
Array([ 0., -1.,  1.], dtype=float32)

For complex inputs:

>>> x1 = jnp.array([1, 3+4j, 5j])
>>> jnp.sign(x1)
Array([1. +0.j , 0.6+0.8j, 0. +1.j ], dtype=complex64)

def sin(x: ArrayLike, /) ‑> jax.jaxlib._jax.Array-
Expand source code
@export @jit(inline=True) def sin(x: ArrayLike, /) -> Array: """Compute a trigonometric sine of each element of input. JAX implementation of :obj:`numpy.sin`. Args: x: array or scalar. Angle in radians. Returns: An array containing the sine of each element in ``x``, promotes to inexact dtype. See also: - :func:`jax.numpy.cos`: Computes a trigonometric cosine of each element of input. - :func:`jax.numpy.tan`: Computes a trigonometric tangent of each element of input. - :func:`jax.numpy.arcsin` and :func:`jax.numpy.asin`: Computes the inverse of trigonometric sine of each element of input. Examples: >>> pi = jnp.pi >>> x = jnp.array([pi/4, pi/2, 3*pi/4, pi]) >>> with jnp.printoptions(precision=3, suppress=True): ... print(jnp.sin(x)) [ 0.707 1. 0.707 -0. ] """ out = lax.sin(*promote_args_inexact('sin', x)) jnp_error._set_error_if_nan(out) return outCompute a trigonometric sine of each element of input.
JAX implementation of numpy.sin.

Args
x- array or scalar. Angle in radians.

Returns
An array containing the sine of each element in x, promotes to inexact dtype.

See also:
- jax.numpy.cos: Computes a trigonometric cosine of each element of input.
- jax.numpy.tan: Computes a trigonometric tangent of each element of input.
- jax.numpy.arcsin and jax.numpy.asin: Computes the inverse of trigonometric sine of each element of input.

Examples
>>> pi = jnp.pi
>>> x = jnp.array([pi/4, pi/2, 3*pi/4, pi])
>>> with jnp.printoptions(precision=3, suppress=True):
...   print(jnp.sin(x))
[ 0.707  1.     0.707 -0.   ]

def softplus(x: ArrayLike) ‑> jax.jaxlib._jax.Array-
Expand source code
@api.jit def softplus(x: ArrayLike) -> Array: r"""Softplus activation function. Computes the element-wise function .. math:: \mathrm{softplus}(x) = \log(1 + e^x) Args: x : input array """ return jnp.logaddexp(x, 0)Softplus activation function.
Computes the element-wise function
\[ \mathrm{softplus}(x) = \log(1 + e^x) \]
Args
x : input array
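A minimal usage sketch, assuming this is the re-exported jax.nn.softplus (outputs printed at four decimals):

>>> from jax.nn import softplus
>>> x = jnp.array([-10.0, 0.0, 10.0])
>>> with jnp.printoptions(precision=4, suppress=True):
...   print(softplus(x))
[ 0.      0.6931 10.    ]

Implementing softplus as logaddexp(x, 0), as the source above does, keeps the computation numerically stable: a naive log(1 + e^x) would overflow for large positive x.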
def sqrt(x: ArrayLike, /) ‑> jax.jaxlib._jax.Array-
Expand source code
@export @jit(inline=True) def sqrt(x: ArrayLike, /) -> Array: """Calculates element-wise non-negative square root of the input array. JAX implementation of :obj:`numpy.sqrt`. Args: x: input array or scalar. Returns: An array containing the non-negative square root of the elements of ``x``. Note: - For real-valued negative inputs, ``jnp.sqrt`` produces a ``nan`` output. - For complex-valued negative inputs, ``jnp.sqrt`` produces a ``complex`` output. See also: - :func:`jax.numpy.square`: Calculates the element-wise square of the input. - :func:`jax.numpy.power`: Calculates the element-wise base ``x1`` exponential of ``x2``. Examples: >>> x = jnp.array([-8-6j, 1j, 4]) >>> with jnp.printoptions(precision=3, suppress=True): ... jnp.sqrt(x) Array([1. -3.j , 0.707+0.707j, 2. +0.j ], dtype=complex64) >>> jnp.sqrt(-1) Array(nan, dtype=float32, weak_type=True) """ out = lax.sqrt(*promote_args_inexact('sqrt', x)) jnp_error._set_error_if_nan(out) return outCalculates element-wise non-negative square root of the input array.
JAX implementation of numpy.sqrt.

Args
x- input array or scalar.

Returns
An array containing the non-negative square root of the elements of x.

Note
- For real-valued negative inputs, jnp.sqrt produces a nan output.
- For complex-valued negative inputs, jnp.sqrt produces a complex output.

See also:
- jax.numpy.square: Calculates the element-wise square of the input.
- jax.numpy.power: Calculates the element-wise base x1 exponential of x2.

Examples
>>> x = jnp.array([-8-6j, 1j, 4])
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.sqrt(x)
Array([1.   -3.j   , 0.707+0.707j, 2.   +0.j   ], dtype=complex64)
>>> jnp.sqrt(-1)
Array(nan, dtype=float32, weak_type=True)

def stack(arrays: np.ndarray | Array | Sequence[ArrayLike],
axis: int = 0,
out: None = None,
dtype: DTypeLike | None = None) ‑> jax.jaxlib._jax.Array-
Expand source code
@export def stack(arrays: np.ndarray | Array | Sequence[ArrayLike], axis: int = 0, out: None = None, dtype: DTypeLike | None = None) -> Array: """Join arrays along a new axis. JAX implementation of :func:`numpy.stack`. Args: arrays: a sequence of arrays to stack; each must have the same shape. If a single array is given it will be treated equivalently to `arrays = unstack(arrays)`, but the implementation will avoid explicit unstacking. axis: specify the axis along which to stack. out: unused by JAX dtype: optional dtype of the resulting array. If not specified, the dtype will be determined via type promotion rules described in :ref:`type-promotion`. Returns: the stacked result. See also: - :func:`jax.numpy.unstack`: inverse of ``stack``. - :func:`jax.numpy.concatenate`: concatenation along existing axes. - :func:`jax.numpy.vstack`: stack vertically, i.e. along axis 0. - :func:`jax.numpy.hstack`: stack horizontally, i.e. along axis 1. - :func:`jax.numpy.dstack`: stack depth-wise, i.e. along axis 2. - :func:`jax.numpy.column_stack`: stack columns. Examples: >>> x = jnp.array([1, 2, 3]) >>> y = jnp.array([4, 5, 6]) >>> jnp.stack([x, y]) Array([[1, 2, 3], [4, 5, 6]], dtype=int32) >>> jnp.stack([x, y], axis=1) Array([[1, 4], [2, 5], [3, 6]], dtype=int32) :func:`~jax.numpy.unstack` performs the inverse operation: >>> arr = jnp.stack([x, y], axis=1) >>> x, y = jnp.unstack(arr, axis=1) >>> x Array([1, 2, 3], dtype=int32) >>> y Array([4, 5, 6], dtype=int32) """ if not len(arrays): raise ValueError("Need at least one array to stack.") if out is not None: raise NotImplementedError("The 'out' argument to jnp.stack is not supported.") if isinstance(arrays, (np.ndarray, Array)): axis = _canonicalize_axis(axis, arrays.ndim) return concatenate(expand_dims(arrays, axis + 1), axis=axis, dtype=dtype) else: arrays = util.ensure_arraylike_tuple("stack", arrays) shape0 = np.shape(arrays[0]) axis = _canonicalize_axis(axis, len(shape0) + 1) new_arrays = [] for a in arrays: if np.shape(a) != shape0: raise ValueError("All input arrays must have the same shape.") new_arrays.append(expand_dims(a, axis)) return concatenate(new_arrays, axis=axis, dtype=dtype)Join arrays along a new axis.
JAX implementation of numpy.stack.

Args
arrays- a sequence of arrays to stack; each must have the same shape. If a single array is given it will be treated equivalently to arrays = unstack(arrays), but the implementation will avoid explicit unstacking.
axis- specify the axis along which to stack.
out- unused by JAX
dtype- optional dtype of the resulting array. If not specified, the dtype will be determined via type promotion rules described in type-promotion.

Returns
the stacked result.

See also:
- jax.numpy.unstack: inverse of stack.
- jax.numpy.concatenate: concatenation along existing axes.
- jax.numpy.vstack: stack vertically, i.e. along axis 0.
- jax.numpy.hstack: stack horizontally, i.e. along axis 1.
- jax.numpy.dstack: stack depth-wise, i.e. along axis 2.
- jax.numpy.column_stack: stack columns.

Examples
>>> x = jnp.array([1, 2, 3])
>>> y = jnp.array([4, 5, 6])
>>> jnp.stack([x, y])
Array([[1, 2, 3],
       [4, 5, 6]], dtype=int32)
>>> jnp.stack([x, y], axis=1)
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
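When dtype is given, it fixes the result dtype directly instead of relying on type promotion (a small sketch, reusing x and y from above):

>>> jnp.stack([x, y], dtype=jnp.float32)
Array([[1., 2., 3.],
       [4., 5., 6.]], dtype=float32)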
jax.numpy.unstack performs the inverse operation:

>>> arr = jnp.stack([x, y], axis=1)
>>> x, y = jnp.unstack(arr, axis=1)
>>> x
Array([1, 2, 3], dtype=int32)
>>> y
Array([4, 5, 6], dtype=int32)

def staticshape(a: ArrayLike | SupportsShape) ‑> tuple[int, ...]-
Expand source code
@export def shape(a: ArrayLike | SupportsShape) -> tuple[int, ...]: """Return the shape an array. JAX implementation of :func:`numpy.shape`. Unlike ``np.shape``, this function raises a :class:`TypeError` if the input is a collection such as a list or tuple. Args: a: array-like object, or any object with a ``shape`` attribute. Returns: An tuple of integers representing the shape of ``a``. Examples: Shape for arrays: >>> x = jnp.arange(10) >>> jnp.shape(x) (10,) >>> y = jnp.ones((2, 3)) >>> jnp.shape(y) (2, 3) This also works for scalars: >>> jnp.shape(3.14) () For arrays, this can also be accessed via the :attr:`jax.Array.shape` property: >>> x.shape (10,) """ if hasattr(a, "shape"): return a.shape # Deprecation warning added 2025-2-20. check_arraylike("shape", a, emit_warning=True) if hasattr(a, "__jax_array__"): a = a.__jax_array__() # NumPy dispatches to a.shape if available. return np.shape(a)Return the shape an array.
JAX implementation of numpy.shape. Unlike np.shape, this function raises a TypeError if the input is a collection such as a list or tuple.

Args
a- array-like object, or any object with a shape attribute.

Returns
A tuple of integers representing the shape of a.

Examples
Shape for arrays:

>>> x = jnp.arange(10)
>>> jnp.shape(x)
(10,)
>>> y = jnp.ones((2, 3))
>>> jnp.shape(y)
(2, 3)

This also works for scalars:

>>> jnp.shape(3.14)
()

For arrays, this can also be accessed via the jax.Array.shape property:

>>> x.shape
(10,)

def stop_gradient(x: T) ‑> ~T-
Expand source code
def stop_gradient(x: T) -> T: """Stops gradient computation. Operationally ``stop_gradient`` is the identity function, that is, it returns argument `x` unchanged. However, ``stop_gradient`` prevents the flow of gradients during forward or reverse-mode automatic differentiation. If there are multiple nested gradient computations, ``stop_gradient`` stops gradients for all of them. For some discussion of where this is useful, refer to :ref:`stopping-gradients`. Args: x: array or pytree of arrays Returns: input value is returned unchanged, but within autodiff will be treated as a constant. Examples: Consider a simple function that returns the square of the input value: >>> def f1(x): ... return x ** 2 >>> x = jnp.float32(3.0) >>> f1(x) Array(9.0, dtype=float32) >>> jax.grad(f1)(x) Array(6.0, dtype=float32) The same function with ``stop_gradient`` around ``x`` will be equivalent under normal evaluation, but return a zero gradient because ``x`` is effectively treated as a constant: >>> def f2(x): ... return jax.lax.stop_gradient(x) ** 2 >>> f2(x) Array(9.0, dtype=float32) >>> jax.grad(f2)(x) Array(0.0, dtype=float32) This is used in a number of places within the JAX codebase; for example :func:`jax.nn.softmax` internally normalizes the input by its maximum value, and this maximum value is wrapped in ``stop_gradient`` for efficiency. Refer to :ref:`stopping-gradients` for more discussion of the applicability of ``stop_gradient``. """ return tree_util.tree_map(_stop_gradient, x)Stops gradient computation.
Operationally stop_gradient is the identity function, that is, it returns argument x unchanged. However, stop_gradient prevents the flow of gradients during forward or reverse-mode automatic differentiation. If there are multiple nested gradient computations, stop_gradient stops gradients for all of them. For some discussion of where this is useful, refer to stopping-gradients.

Args
x- array or pytree of arrays

Returns
input value is returned unchanged, but within autodiff will be treated as a constant.

Examples
Consider a simple function that returns the square of the input value:

>>> def f1(x):
...   return x ** 2
>>> x = jnp.float32(3.0)
>>> f1(x)
Array(9.0, dtype=float32)
>>> jax.grad(f1)(x)
Array(6.0, dtype=float32)

The same function with stop_gradient around x will be equivalent under normal evaluation, but return a zero gradient because x is effectively treated as a constant:

>>> def f2(x):
...   return jax.lax.stop_gradient(x) ** 2
>>> f2(x)
Array(9.0, dtype=float32)
>>> jax.grad(f2)(x)
Array(0.0, dtype=float32)
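Since stop_gradient maps over pytrees, it can also freeze selected leaves of a parameter structure; a minimal sketch with a hypothetical two-leaf dict:

>>> params = {'w': jnp.float32(2.0), 'b': jnp.float32(1.0)}
>>> def f3(p):
...   return p['w'] * 3.0 + jax.lax.stop_gradient(p['b'])
>>> jax.grad(f3)(params)
{'b': Array(0.0, dtype=float32), 'w': Array(3.0, dtype=float32)}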
This is used in a number of places within the JAX codebase; for example, jax.nn.softmax internally normalizes the input by its maximum value, and this maximum value is wrapped in stop_gradient for efficiency. Refer to stopping-gradients for more discussion of the applicability of stop_gradient.

def tan(x: ArrayLike, /) ‑> jax.jaxlib._jax.Array-
Expand source code
@export @jit(inline=True) def tan(x: ArrayLike, /) -> Array: """Compute a trigonometric tangent of each element of input. JAX implementation of :obj:`numpy.tan`. Args: x: scalar or array. Angle in radians. Returns: An array containing the tangent of each element in ``x``, promotes to inexact dtype. See also: - :func:`jax.numpy.sin`: Computes a trigonometric sine of each element of input. - :func:`jax.numpy.cos`: Computes a trigonometric cosine of each element of input. - :func:`jax.numpy.arctan` and :func:`jax.numpy.atan`: Computes the inverse of trigonometric tangent of each element of input. Examples: >>> pi = jnp.pi >>> x = jnp.array([0, pi/6, pi/4, 3*pi/4, 5*pi/6]) >>> with jnp.printoptions(precision=3, suppress=True): ... print(jnp.tan(x)) [ 0. 0.577 1. -1. -0.577] """ out = lax.tan(*promote_args_inexact('tan', x)) jnp_error._set_error_if_nan(out) return outCompute a trigonometric tangent of each element of input.
JAX implementation of numpy.tan.

Args
x- scalar or array. Angle in radians.

Returns
An array containing the tangent of each element in x, promotes to inexact dtype.

See also:
- jax.numpy.sin: Computes a trigonometric sine of each element of input.
- jax.numpy.cos: Computes a trigonometric cosine of each element of input.
- jax.numpy.arctan and jax.numpy.atan: Computes the inverse of trigonometric tangent of each element of input.

Examples
>>> pi = jnp.pi
>>> x = jnp.array([0, pi/6, pi/4, 3*pi/4, 5*pi/6])
>>> with jnp.printoptions(precision=3, suppress=True):
...   print(jnp.tan(x))
[ 0.     0.577  1.    -1.    -0.577]

def tile(A: ArrayLike, reps: DimSize | Sequence[DimSize]) ‑> jax.jaxlib._jax.Array-
Expand source code
@export
def tile(A: ArrayLike, reps: DimSize | Sequence[DimSize]) -> Array:
  """Construct an array by repeating ``A`` along specified dimensions.

  JAX implementation of :func:`numpy.tile`.

  If ``A`` is an array of shape ``(d1, d2, ..., dn)`` and ``reps`` is a sequence
  of integers, the resulting array will have a shape of
  ``(reps[0] * d1, reps[1] * d2, ..., reps[n] * dn)``, with ``A`` tiled along
  each dimension.

  Args:
    A: input array to be repeated. Can be of any shape or dimension.
    reps: specifies the number of repetitions along each axis.

  Returns:
    a new array where the input array has been repeated according to ``reps``.

  See also:
    - :func:`jax.numpy.repeat`: Construct an array from repeated elements.
    - :func:`jax.numpy.broadcast_to`: Broadcast an array to a specified shape.

  Examples:
    >>> arr = jnp.array([1, 2])
    >>> jnp.tile(arr, 2)
    Array([1, 2, 1, 2], dtype=int32)
    >>> arr = jnp.array([[1, 2],
    ...                  [3, 4]])
    >>> jnp.tile(arr, (2, 1))
    Array([[1, 2],
           [3, 4],
           [1, 2],
           [3, 4]], dtype=int32)
  """
  A = util.ensure_arraylike("tile", A)
  try:
    reps_tup = tuple(iter(reps))  # pyrefly: ignore[no-matching-overload]
  except TypeError:
    reps_tup: tuple[DimSize, ...] = (reps,)
  reps_tup = tuple(operator.index(rep) if core.is_constant_dim(rep) else rep
                   for rep in reps_tup)
  # lax.tile expects reps and A.shape to have the same rank.
  reps_tup = (1,) * (A.ndim - len(reps_tup)) + reps_tup
  if len(reps_tup) > np.ndim(A):
    A = lax.expand_dims(A, dimensions=tuple(range(len(reps_tup) - np.ndim(A))))
  return lax.tile(A, reps_tup)

Construct an array by repeating `A` along specified dimensions.
JAX implementation of :func:`numpy.tile`.
If `A` is an array of shape `(d1, d2, ..., dn)` and `reps` is a sequence of integers, the resulting array will have a shape of `(reps[0] * d1, reps[1] * d2, ..., reps[n] * dn)`, with `A` tiled along each dimension.
Args
A- input array to be repeated. Can be of any shape or dimension.
reps- specifies the number of repetitions along each axis.
Returns
a new array where the input array has been repeated according to `reps`.
See also
- :func:`jax.numpy.repeat`: Construct an array from repeated elements.
- :func:`jax.numpy.broadcast_to`: Broadcast an array to a specified shape.
Examples
>>> arr = jnp.array([1, 2])
>>> jnp.tile(arr, 2)
Array([1, 2, 1, 2], dtype=int32)
>>> arr = jnp.array([[1, 2],
...                  [3, 4]])
>>> jnp.tile(arr, (2, 1))
Array([[1, 2],
       [3, 4],
       [1, 2],
       [3, 4]], dtype=int32)
def transpose(a: ArrayLike, axes: Sequence[int] | None = None) ‑> jax.jaxlib._jax.Array-
Expand source code
@export
def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array:
  """Return a transposed version of an N-dimensional array.

  JAX implementation of :func:`numpy.transpose`, implemented in terms of
  :func:`jax.lax.transpose`.

  Args:
    a: input array
    axes: optionally specify the permutation using a length-`a.ndim` sequence
      of integers ``i`` satisfying ``0 <= i < a.ndim``. Defaults to
      ``range(a.ndim)[::-1]``, i.e. reverses the order of all axes.

  Returns:
    transposed copy of the array.

  See Also:
    - :func:`jax.Array.transpose`: equivalent function via an :class:`~jax.Array` method.
    - :attr:`jax.Array.T`: equivalent function via an :class:`~jax.Array` property.
    - :func:`jax.numpy.matrix_transpose`: transpose the last two axes of an array.
      This is suitable for working with batched 2D matrices.
    - :func:`jax.numpy.swapaxes`: swap any two axes in an array.
    - :func:`jax.numpy.moveaxis`: move an axis to another position in the array.

  Note:
    Unlike :func:`numpy.transpose`, :func:`jax.numpy.transpose` will return a
    copy rather than a view of the input array. However, under JIT, the compiler
    will optimize away such copies when possible, so this doesn't have performance
    impacts in practice.

  Examples:
    For a 1D array, the transpose is the identity:

    >>> x = jnp.array([1, 2, 3, 4])
    >>> jnp.transpose(x)
    Array([1, 2, 3, 4], dtype=int32)
  """
  a = util.ensure_arraylike("transpose", a)
  axes_ = list(range(a.ndim)[::-1]) if axes is None else axes
  axes_ = [_canonicalize_axis(i, np.ndim(a)) for i in axes_]
  return lax.transpose(a, axes_)

Return a transposed version of an N-dimensional array.
JAX implementation of :func:`numpy.transpose`, implemented in terms of :func:`jax.lax.transpose`.
Args
a- input array
axes- optionally specify the permutation using a length-`a.ndim` sequence of integers `i` satisfying `0 <= i < a.ndim`. Defaults to `range(a.ndim)[::-1]`, i.e. reverses the order of all axes.
Returns
transposed copy of the array.
See Also
- :func:`jax.Array.transpose`: equivalent function via an :class:`~jax.Array` method.
- :attr:`jax.Array.T`: equivalent function via an :class:`~jax.Array` property.
- :func:`jax.numpy.matrix_transpose`: transpose the last two axes of an array. This is suitable for working with batched 2D matrices.
- :func:`jax.numpy.swapaxes`: swap any two axes in an array.
- :func:`jax.numpy.moveaxis`: move an axis to another position in the array.
Note
Unlike :func:`numpy.transpose`, :func:`jax.numpy.transpose` will return a copy rather than a view of the input array. However, under JIT, the compiler will optimize away such copies when possible, so this doesn't have performance impacts in practice.
Examples
For a 1D array, the transpose is the identity:
>>> x = jnp.array([1, 2, 3, 4])
>>> jnp.transpose(x)
Array([1, 2, 3, 4], dtype=int32)
For a 2D array, the transpose is a matrix transpose:
>>> x = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.transpose(x)
Array([[1, 3],
       [2, 4]], dtype=int32)
For an N-dimensional array, the transpose reverses the order of the axes:
>>> x = jnp.zeros(shape=(3, 4, 5))
>>> jnp.transpose(x).shape
(5, 4, 3)
The `axes` argument can be specified to change this default behavior:
>>> jnp.transpose(x, (0, 2, 1)).shape
(3, 5, 4)
Since swapping the last two axes is a common operation, it can be done via its own API, :func:`jax.numpy.matrix_transpose`:
>>> jnp.matrix_transpose(x).shape
(3, 5, 4)
For convenience, transposes may also be performed using the :meth:`jax.Array.transpose` method or the :attr:`jax.Array.T` property:
>>> x = jnp.array([[1, 2],
...                [3, 4]])
>>> x.transpose()
Array([[1, 3],
       [2, 4]], dtype=int32)
>>> x.T
Array([[1, 3],
       [2, 4]], dtype=int32)
Methods
def all(self, boolean_tensor, axis=None, keepdims=False)-
Expand source code
def all(self, boolean_tensor, axis=None, keepdims=False):
    if isinstance(boolean_tensor, (tuple, list)):
        boolean_tensor = jnp.stack(boolean_tensor)
    return jnp.all(boolean_tensor, axis=axis, keepdims=keepdims)

def allocate_on_device(self, tensor: ~TensorType, device: phiml.backend._backend.ComputeDevice) ‑> ~TensorType-
Expand source code
def allocate_on_device(self, tensor: TensorType, device: ComputeDevice) -> TensorType:
    return jax.device_put(tensor, device.ref)

Moves `tensor` to `device`. May copy the tensor if it is already on the device.
Args
tensor- Existing tensor native to this backend.
device- Target device, associated with this backend.
def any(self, boolean_tensor, axis=None, keepdims=False)-
Expand source code
def any(self, boolean_tensor, axis=None, keepdims=False):
    if isinstance(boolean_tensor, (tuple, list)):
        boolean_tensor = jnp.stack(boolean_tensor)
    return jnp.any(boolean_tensor, axis=axis, keepdims=keepdims)

def argsort(self, x, axis=-1)-
Expand source code
def argsort(self, x, axis=-1):
    return jnp.argsort(x, axis)

def as_tensor(self, x, convert_external=True)-
Expand source code
def as_tensor(self, x, convert_external=True):
    self._check_float64()
    if self.is_tensor(x, only_native=convert_external):
        array = x
    else:
        array = jnp.array(x)
    # --- Enforce Precision ---
    if not isinstance(array, numbers.Number):
        if self.dtype(array).kind == float:
            array = self.to_float(array)
        elif self.dtype(array).kind == complex:
            array = self.to_complex(array)
    return array

Converts a tensor-like object to the native tensor representation of this backend. If `x` is a native tensor of this backend, it is returned without modification. If `x` is a Python number (a `numbers.Number` instance), `convert_external` decides whether to convert it, unless the backend cannot handle Python numbers.
Note: There may be objects that are considered tensors by this backend but are not native and thus will be converted by this method.
Args
x- tensor-like, e.g. list, tuple, Python number, tensor
convert_external- if False and `x` is a Python number that is understood by this backend, this method returns the number as-is. This can help prevent type clashes like int32 vs int64. (Default value = True)
Returns
tensor representation of `x`
def batched_gather_1d(self, values, indices)-
Expand source code
def batched_gather_1d(self, values, indices):
    batch_size = combined_dim(values.shape[0], indices.shape[0])
    return values[np.arange(batch_size)[:, None], indices]

Args
values- (batch, spatial)
indices- (batch, indices)
Returns
(batch, indices)
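Example
A minimal usage sketch. Constructing `JaxBackend` directly here is an assumption for illustration only; in practice, phiml manages the backend instance for you.
>>> import jax.numpy as jnp
>>> from phiml.backend.jax import JaxBackend
>>> b = JaxBackend()
>>> values = jnp.asarray([[10., 20., 30.], [40., 50., 60.]])  # (batch=2, spatial=3)
>>> indices = jnp.asarray([[2, 0], [1, 1]])                   # (batch=2, indices=2)
>>> b.batched_gather_1d(values, indices)  # row 0 picks elements 2 and 0, row 1 picks element 1 twice
Array([[30., 10.],
       [50., 50.]], dtype=float32)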
def batched_gather_nd(self, values, indices)-
Expand source code
def batched_gather_nd(self, values, indices):
    values = self.as_tensor(values)
    indices = self.as_tensor(indices)
    assert indices.shape[-1] == self.ndims(values) - 2
    batch_size = combined_dim(values.shape[0], indices.shape[0])
    indices_list = [indices[..., i] for i in range(indices.shape[-1])]
    batch_range = self.expand_dims(np.arange(batch_size), -1, number=self.ndims(indices) - 2)
    slices = (batch_range, *indices_list)
    return values[slices]

Gathers values from the tensor `values` at locations `indices`. The first dimension of `values` and `indices` is the batch dimension, which must be equal for both or 1 for either.
Args
values- tensor of shape (batch, spatial…, channel)
indices- int tensor of shape (batch, any…, multi_index) where the size of multi_index is values.rank - 2.
Returns
Gathered values as tensor of shape (batch, any…, channel)
def bincount(self, x, weights: ~TensorType | None, bins: int, x_sorted=False)-
Expand source code
def bincount(self, x, weights: Optional[TensorType], bins: int, x_sorted=False):
    if x_sorted:
        return jax.ops.segment_sum(weights or 1, x, bins, indices_are_sorted=True)
    else:
        return jnp.bincount(x, weights=weights, minlength=bins, length=bins)

Args
x- Bin indices, 1D int tensor.
weights- Weights corresponding to `x`, 1D tensor. All weights are 1 if `weights=None`.
bins- Number of bins.
x_sorted- Whether `x` is sorted from lowest to highest bin.
Returns
bin_counts
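Example
A sketch of the unsorted path, reusing the illustrative `b = JaxBackend()` instance from the example above.
>>> x = jnp.asarray([0, 1, 1, 3])  # bin index per element
>>> b.bincount(x, weights=None, bins=4)
Array([1, 2, 0, 1], dtype=int32)
>>> b.bincount(x, weights=jnp.asarray([.5, 1., 2., 4.]), bins=4)
Array([0.5, 3. , 0. , 4. ], dtype=float32)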
def block_until_ready(self, values)-
Expand source code
def block_until_ready(self, values):
    if hasattr(values, 'block_until_ready'):
        values.block_until_ready()
    if isinstance(values, (tuple, list)):
        for v in values:
            self.block_until_ready(v)

def boolean_mask(self, x, mask, axis=0, new_length=None, fill_value=0)-
Expand source code
def boolean_mask(self, x, mask, axis=0, new_length=None, fill_value=0):
    if new_length is None:
        slices = [mask if i == axis else slice(None) for i in range(len(x.shape))]
        return x[tuple(slices)]
    else:
        indices = jnp.argwhere(mask, size=new_length, fill_value=-1)[..., 0]
        valid = indices >= 0
        valid = valid[tuple([slice(None) if i == axis else None for i in range(len(x.shape))])]
        result = self.gather(x, jnp.maximum(0, indices), axis)
        return jnp.where(valid, result, fill_value)

Args
x- tensor with any number of dimensions
mask- 1D mask tensor
axis- Axis index >= 0
new_length- Maximum size of the output along `axis`. This must be set when jit-compiling with Jax.
fill_value- If `new_length` is larger than the filtered result, the remaining values will be set to `fill_value`.
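Example
Since Jax requires static output shapes under jit, `new_length` must be given when tracing. A sketch using the illustrative `b` instance from above:
>>> import jax
>>> x = jnp.asarray([1, 2, 3, 4])
>>> mask = jnp.asarray([True, False, True, False])
>>> b.boolean_mask(x, mask)  # eager: output shape depends on the mask
Array([1, 3], dtype=int32)
>>> jax.jit(lambda x, m: b.boolean_mask(x, m, new_length=3, fill_value=-1))(x, mask)
Array([ 1,  3, -1], dtype=int32)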
def cast(self, x, dtype: phiml.backend._dtype.DType)-
Expand source code
def cast(self, x, dtype: DType):
    if self.is_tensor(x, only_native=True) and from_numpy_dtype(x.dtype) == dtype:
        return x
    else:
        return jnp.array(x, to_numpy_dtype(dtype))

def conv(self, value, kernel, strides: Sequence[int], out_sizes: Sequence[int], transpose: bool)-
Expand source code
def conv(self, value, kernel, strides: Sequence[int], out_sizes: Sequence[int], transpose: bool):
    assert not transpose, "transpose conv not yet supported for Jax"
    assert kernel.shape[0] in (1, value.shape[0])
    assert value.shape[1] == kernel.shape[2], f"value has {value.shape[1]} channels but kernel has {kernel.shape[2]}"
    assert value.ndim + 1 == kernel.ndim
    value, kernel = self.auto_cast(value, kernel, bool_to_int=True)
    ndim = len(value.shape) - 2  # number of spatial dimensions
    assert len(strides) == ndim, f"Expected {ndim} stride values, got {len(strides)}"
    # --- Determine padding ---
    # default_size = [int(np.ceil((vs - ks + 1) / st)) for vs, ks, st in zip(value.shape[2:], kernel.shape[3:], strides)]  # size if no padding is used
    lr_padding = [max(0, st * (os - 1) - vs + ks) for st, os, vs, ks in zip(strides, out_sizes, value.shape[2:], kernel.shape[3:])]
    padding = [((p+1) // 2, p // 2) for p in lr_padding]
    # --- Run the (transposed) convolution ---
    sp = ''.join(['WHD'[i] for i in range(len(strides))])
    dim_num = jax.lax.conv_dimension_numbers(value.shape, kernel.shape[1:], ('NC'+sp, 'OI'+sp, 'NC'+sp))
    if kernel.shape[0] == 1:
        return jax.lax.conv_general_dilated(value, kernel[0], strides, padding, None, None, dim_num)
    else:
        result = []
        for b in range(kernel.shape[0]):
            result.append(jax.lax.conv_general_dilated(value[b:b + 1], kernel[b], strides, padding, None, None, dim_num))
        return jnp.concatenate(result, 0)

Convolve `value` with `kernel`. Depending on the tensor rank, the convolution is either 1D (rank=3), 2D (rank=4) or 3D (rank=5). Higher dimensions may not be supported.
Args
value- tensor of shape (batch_size, in_channel, spatial…)
kernel- tensor of shape (batch_size or 1, out_channel, in_channel, spatial…)
strides- Convolution strides, one `int` per spatial dim. For transposed convolutions, these act as upsampling factors.
out_sizes- Spatial shape of the output tensor. This determines how much zero-padding or slicing is used.
transpose- If `True`, performs a transposed convolution according to PyTorch's definition.
Returns
Convolution result as tensor of shape (batch_size, out_channel, spatial…)
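Example
A 1D sketch using the illustrative `b` instance from above. The padding is derived from `out_sizes`; here it yields a 'same'-sized output.
>>> value = jnp.ones((1, 1, 8))      # (batch, in_channel, x)
>>> kernel = jnp.ones((1, 2, 1, 3))  # (1, out_channel, in_channel, x)
>>> b.conv(value, kernel, strides=[1], out_sizes=[8], transpose=False).shape
(1, 2, 8)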
def copy(self, tensor, only_mutable=False)-
Expand source code
def copy(self, tensor, only_mutable=False):
    return jnp.array(tensor, copy=True)

def custom_gradient(self, f: Callable, gradient: Callable, get_external_cache: Callable = None, on_call_skipped: Callable = None) ‑> Callable-
Expand source code
def custom_gradient(self, f: Callable, gradient: Callable, get_external_cache: Callable = None, on_call_skipped: Callable = None) -> Callable:
    jax_fun = jax.custom_vjp(f)  # custom vector-Jacobian product (reverse-mode differentiation)

    def forward(*x):
        y = f(*x)
        return y, (x, y)

    def backward(x_y, dy):
        x, y = x_y
        dx = gradient(x, y, dy)
        return tuple(dx)

    jax_fun.defvjp(forward, backward)
    return jax_fun

Creates a function based on `f` that uses a custom gradient for backprop.
Args
f- Forward function.
gradient- Function for backprop. Will be called as `gradient(*d_out)` to compute the gradient of `f`.
Returns
Function with similar signature and return values as `f`. However, the returned function does not support keyword arguments.
def disassemble(self, x) ‑> Tuple[Callable, Sequence[~TensorType]]-
Expand source code
def disassemble(self, x) -> Tuple[Callable, Sequence[TensorType]]:
    if self.is_sparse(x):
        if isinstance(x, COO):
            raise NotImplementedError
            # return lambda b, i, v: b.sparse_coo_tensor(i, v, x.shape), (np.stack([x.row, x.col], -1), x.data)
        if isinstance(x, BCOO):
            return lambda b, i, v: b.sparse_coo_tensor(i, v, x.shape), (x.indices, x.data)
        if isinstance(x, CSR):
            raise NotImplementedError
            # return lambda b, v, i, p: b.csr_matrix(i, p, v, x.shape), (x.data, x.indices, x.indptr)
        elif isinstance(x, CSC):
            raise NotImplementedError
            # return lambda b, v, i, p: b.csc_matrix(p, i, v, x.shape), (x.data, x.indices, x.indptr)
        raise NotImplementedError
    else:
        return lambda b, t: t, (x,)

Disassemble a (sparse) tensor into its individual constituents, such as values and indices.
Args
x- Tensor
Returns
assemble- Function `assemble(backend, *constituents)` that reassembles `x` from the constituents.
constituents- Tensors contained in `x`.
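Example
A round-trip sketch for a BCOO matrix, reusing the illustrative `b` instance from above.
>>> from jax.experimental.sparse import BCOO
>>> matrix = BCOO.fromdense(jnp.asarray([[0., 1.], [2., 0.]]))
>>> assemble, parts = b.disassemble(matrix)  # parts = (indices, values)
>>> assemble(b, *parts).todense()            # reassembled via b.sparse_coo_tensor
Array([[0., 1.],
       [2., 0.]], dtype=float32)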
def divide_no_nan(self, x, y)-
Expand source code
def divide_no_nan(self, x, y):
    return jnp.where(y == 0, 0, x / y)  # jnp.nan_to_num(x / y, copy=True, nan=0) covers up NaNs from before

Computes x/y but returns 0 if y=0.
def dtype(self, array) ‑> phiml.backend._dtype.DType-
Expand source code
def dtype(self, array) -> DType:
    if isinstance(array, bool):
        return BOOL
    if isinstance(array, int):
        return INT32
    if isinstance(array, float):
        return FLOAT64
    if isinstance(array, complex):
        return COMPLEX128
    if not isinstance(array, jnp.ndarray):
        array = jnp.array(array)
    return from_numpy_dtype(array.dtype)

def eig(self, matrix: ~TensorType) ‑> ~TensorType-
Expand source code
def eig(self, matrix: TensorType) -> TensorType:
    return jnp.linalg.eig(matrix)

Args
matrix- (batch…, n, n)
Returns
eigenvalues- (batch…, n,)
eigenvectors- (batch…, n, n)
def eigvals(self, matrix: ~TensorType) ‑> ~TensorType-
Expand source code
def eigvals(self, matrix: TensorType) -> TensorType:
    return jnp.linalg.eigvals(matrix)

Args
matrix- (batch…, n, n)
Returns
eigenvalues as (batch…, n,)
def expand_dims(self, a, axis=0, number=1)-
Expand source code
def expand_dims(self, a, axis=0, number=1):
    for _i in range(number):
        a = jnp.expand_dims(a, axis)
    return a

def fft(self, x, axes: tuple | list)-
Expand source code
def fft(self, x, axes: Union[tuple, list]):
    x = self.to_complex(x)
    if not axes:
        return x
    if len(axes) == 1:
        return jnp.fft.fft(x, axis=axes[0]).astype(x.dtype)
    elif len(axes) == 2:
        return jnp.fft.fft2(x, axes=axes).astype(x.dtype)
    else:
        return jnp.fft.fftn(x, axes=axes).astype(x.dtype)

Computes the n-dimensional FFT along all but the first and last dimensions.
Args
x- tensor of dimension 3 or higher
axes- Along which axes to perform the FFT
Returns
Complex tensor `k`
def from_dlpack(self, external_array)-
Expand source code
def from_dlpack(self, external_array):
    from jax import dlpack
    return dlpack.from_dlpack(external_array)

def from_external_array_dlpack(self, external_array)-
Expand source code
def from_external_array_dlpack(self, external_array):
    from jax import dlpack
    return dlpack.from_dlpack(external_array)

def gamma_inc_l(self, a, x)-
Expand source code
def gamma_inc_l(self, a, x):
    return scipy.special.gammainc(a, x)

Regularized lower incomplete gamma function.
def gamma_inc_u(self, a, x)-
Expand source code
def gamma_inc_u(self, a, x):
    return scipy.special.gammaincc(a, x)

Regularized upper incomplete gamma function.
def gather(self, values, indices, axis: int)-
Expand source code
def gather(self, values, indices, axis: int):
    slices = [indices if i == axis else slice(None) for i in range(self.ndims(values))]
    return values[tuple(slices)]

Gathers values from the tensor `values` at locations `indices`.
Args
values- tensor
indices- 1D tensor
axis- Axis along which to gather slices
Returns
tensor, with size along `axis` being the length of `indices`
def get_device(self, tensor: ~TensorType) ‑> phiml.backend._backend.ComputeDevice-
Expand source code
def get_device(self, tensor: TensorType) -> ComputeDevice:
    if hasattr(tensor, 'devices'):
        return self.get_device_by_ref(next(iter(tensor.devices())))
    if hasattr(tensor, 'device'):
        return self.get_device_by_ref(tensor.device())
    raise AssertionError(f"tensor {type(tensor)} has no device attribute")

Returns the device `tensor` is located on.
def get_diagonal(self, matrices, offset=0)-
Expand source code
def get_diagonal(self, matrices, offset=0):
    result = jnp.diagonal(matrices, offset=offset, axis1=1, axis2=2)
    return jnp.transpose(result, [0, 2, 1])

Args
matrices- (batch, rows, cols, channels)
offset- 0=diagonal, positive=above diagonal, negative=below diagonal
Returns
diagonal- (batch, max(rows,cols), channels)
def get_sparse_format(self, x) ‑> str-
Expand source code
def get_sparse_format(self, x) -> str:
    format_names = {
        COO: 'coo',
        BCOO: 'coo',
        CSR: 'csr',
        CSC: 'csc',
    }
    return format_names.get(type(x), 'dense')

Returns a lower-case format string, such as 'coo', 'csr', 'csc'.
def histogram1d(self, values, weights, bin_edges)-
Expand source code
def histogram1d(self, values, weights, bin_edges):
    def unbatched_hist(values, weights, bin_edges):
        hist, _ = jnp.histogram(values, bin_edges, weights=weights)
        return hist
    return jax.vmap(unbatched_hist)(values, weights, bin_edges)

Args
values- (batch, values)
bin_edges- (batch, edges)
weights- (batch, values)
Returns
(batch, edges) with dtype matching weights
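Example
A sketch with a single batch entry, reusing the illustrative `b` instance from above.
>>> values = jnp.asarray([[.1, .5, .6, .9]])  # (batch=1, values)
>>> weights = jnp.ones_like(values)
>>> bin_edges = jnp.asarray([[0., .5, 1.]])   # (batch=1, edges)
>>> b.histogram1d(values, weights, bin_edges)
Array([[1., 3.]], dtype=float32)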
def ifft(self, k, axes: tuple | list)-
Expand source code
def ifft(self, k, axes: Union[tuple, list]):
    if not axes:
        return k
    if len(axes) == 1:
        return jnp.fft.ifft(k, axis=axes[0]).astype(k.dtype)
    elif len(axes) == 2:
        return jnp.fft.ifft2(k, axes=axes).astype(k.dtype)
    else:
        return jnp.fft.ifftn(k, axes=axes).astype(k.dtype)

Computes the n-dimensional inverse FFT along all but the first and last dimensions.
Args
k- tensor of dimension 3 or higher
axes- Along which axes to perform the inverse FFT
Returns
Complex tensor `x`
def is_available(self, tensor)-
Expand source code
def is_available(self, tensor):
    if isinstance(tensor, JVPTracer):
        tensor = tensor.primal
    return not isinstance(tensor, Tracer)

Tests if the value of the tensor is known and can be read at this point. If true, `numpy(tensor)` must return a valid NumPy representation of the value.
Tensors are typically available when the backend operates in eager mode.
Args
tensor- backend-compatible tensor
Returns
bool
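Example
A sketch showing that values are available eagerly but not while tracing, reusing the illustrative `b` instance from above.
>>> import jax
>>> b.is_available(jnp.ones(3))
True
>>> def traced(x):
...     assert not b.is_available(x)  # x is a Tracer inside jit; its value is unknown
...     return x
>>> _ = jax.jit(traced)(jnp.ones(3))  # runs without AssertionError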
def is_module(self, obj)-
Expand source code
def is_module(self, obj):
    return False

Tests if `obj` is of a type that is specific to this backend, e.g. a neural network. If `True`, this backend will be chosen for operations involving `obj`.
See Also: `Backend.is_tensor()`.
Args
obj- Object to test.
def is_sparse(self, x) ‑> bool-
Expand source code
def is_sparse(self, x) -> bool:
    return isinstance(x, (COO, BCOO, CSR, CSC))

Args
x- Tensor native to this `Backend`.
def is_tensor(self, x, only_native=False)-
Expand source code
def is_tensor(self, x, only_native=False):
    if isinstance(x, jnp.ndarray) and not isinstance(x, np.ndarray):  # NumPy arrays inherit from Jax arrays
        return True
    if isinstance(x, jnp.bool_) and not isinstance(x, np.bool_):
        return True
    if self.is_sparse(x):
        return True
    # --- Above considered native ---
    if only_native:
        return False
    # --- Non-native types ---
    if isinstance(x, np.ndarray):
        return True
    if isinstance(x, np.bool_):
        return True
    if isinstance(x, (numbers.Number, bool)):
        return True
    if isinstance(x, (tuple, list)):
        return all([self.is_tensor(item, False) for item in x])
    return False

An object is considered a native tensor by a backend if no internal conversion is required by backend methods. An object is considered a tensor (native or otherwise) by a backend if it is not a struct (e.g. tuple, list) and all methods of the backend accept it as a tensor argument.
If `True`, this backend will be chosen for operations involving `x`.
See Also: `Backend.is_module()`.
Args
x- object to check
only_native- If True, only accepts true native tensor representations, not Python numbers or others that are also supported as tensors (Default value = False)
Returns
bool- whether `x` is considered a tensor by this backend
def jacobian(self, f, wrt: tuple | list, get_output: bool, is_f_scalar: bool)-
Expand source code
def jacobian(self, f, wrt: Union[tuple, list], get_output: bool, is_f_scalar: bool):
    if get_output:
        jax_grad_f = jax.value_and_grad(f, argnums=wrt, has_aux=True)

        @wraps(f)
        def unwrap_outputs(*args):
            args = [self.to_float(arg) if self.dtype(arg).kind in (bool, int) and i in wrt else arg for i, arg in enumerate(args)]
            (_, output_tuple), grads = jax_grad_f(*args)
            return (*output_tuple, *[jnp.conjugate(g) for g in grads])

        return unwrap_outputs
    else:
        @wraps(f)
        def nonaux_f(*args):
            loss, output = f(*args)
            return loss

        jax_grad = jax.grad(nonaux_f, argnums=wrt, has_aux=False)

        @wraps(f)
        def call_jax_grad(*args):
            args = [self.to_float(arg) if self.dtype(arg).kind in (bool, int) and i in wrt else arg for i, arg in enumerate(args)]
            grads = jax_grad(*args)
            return tuple([jnp.conjugate(g) for g in grads])

        return call_jax_grad

Args
f- Function to differentiate. Returns a tuple containing `(reduced_loss, output)`.
wrt- Argument indices for which to compute the gradient.
get_output- Whether the derivative function should return the output of `f` in addition to the gradient.
is_f_scalar- Whether `f` is guaranteed to return a scalar output.
Returns
A function `g` with the same arguments as `f`. If `get_output=True`, `g` returns a `tuple` containing the outputs of `f` followed by the gradients. The gradients retain the dimensions of `reduced_loss` in order as outer (first) dimensions.
def jit_compile(self, f: Callable) ‑> Callable-
Expand source code
def jit_compile(self, f: Callable) -> Callable:
    def run_jit_f(*args):
        # print(jax.make_jaxpr(f)(*args))
        ML_LOGGER.debug(f"JaxBackend: running jit-compiled '{f.__name__}' with shapes {[self.shape(arg) for arg in args]} and dtypes {[self.dtype(arg) for arg in args]}")
        return self.as_registered.call(jit_f, *args, name=f"run jit-compiled '{f.__name__}'")
    run_jit_f.__name__ = f"Jax-Jit({f.__name__})"
    jit_f = jax.jit(f, device=self._default_device.ref)
    return run_jit_f

def linspace(self, start, stop, number)-
Expand source code
def linspace(self, start, stop, number):
    self._check_float64()
    return jax.device_put(jnp.linspace(start, stop, number, dtype=to_numpy_dtype(self.float_type)), self._default_device.ref)

def linspace_without_last(self, start, stop, number)-
Expand source code
def linspace_without_last(self, start, stop, number):
    self._check_float64()
    return jax.device_put(jnp.linspace(start, stop, number, endpoint=False, dtype=to_numpy_dtype(self.float_type)), self._default_device.ref)

def log_gamma(self, x)-
Expand source code
def log_gamma(self, x):
    return jax.lax.lgamma(self.to_float(x))

def matrix_rank_dense(self, matrix, hermitian=False)-
Expand source code
def matrix_rank_dense(self, matrix, hermitian=False):
    try:
        return jnp.linalg.matrix_rank(matrix)
    except TypeError as err:
        if err.args[0] == "array should have 2 or fewer dimensions":  # this is a Jax bug on some distributions/versions
            warnings.warn("You are using a broken version of JAX. matrix_rank for dense matrices will fall back to NumPy.")
            return self.as_tensor(NUMPY.matrix_rank_dense(self.numpy(matrix), hermitian=hermitian))
        else:
            raise err

Args
matrix- Dense matrix of shape (batch, rows, cols)
hermitian- Whether all matrices are guaranteed to be hermitian.
def matrix_solve_least_squares(self, matrix: ~TensorType, rhs: ~TensorType) ‑> Tuple[~TensorType, ~TensorType, ~TensorType, ~TensorType]-
Expand source code
def matrix_solve_least_squares(self, matrix: TensorType, rhs: TensorType) -> Tuple[TensorType, TensorType, TensorType, TensorType]:
    solution, residuals, rank, singular_values = lstsq_batched(matrix, rhs)
    return solution, residuals, rank, singular_values

Args
matrix- Shape (batch, vec, constraints)
rhs- Shape (batch, vec, batch_per_matrix)
Returns
solution- Solution vector of shape (batch, constraints, batch_per_matrix)
residuals- Optional, can be `None`
rank- Optional, can be `None`
singular_values- Optional, can be `None`
def max(self, x, axis=None, keepdims=False)-
Expand source code
def max(self, x, axis=None, keepdims=False):
    return jnp.max(x, axis, keepdims=keepdims)

def mean(self, value, axis=None, keepdims=False)-
Expand source code
def mean(self, value, axis=None, keepdims=False):
    return jnp.mean(value, axis, keepdims=keepdims)

def meshgrid(self, *coordinates)-
Expand source code
def meshgrid(self, *coordinates):
    self._check_float64()
    coordinates = [self.as_tensor(c) for c in coordinates]
    return [jax.device_put(c, self._default_device.ref) for c in jnp.meshgrid(*coordinates, indexing='ij')]

def min(self, x, axis=None, keepdims=False)-
Expand source code
def min(self, x, axis=None, keepdims=False):
    return jnp.min(x, axis, keepdims=keepdims)

def mul(self, a, b)-
Expand source code
def mul(self, a, b):
    # if scipy.sparse.issparse(a):  # TODO sparse?
    #     return a.multiply(b)
    # elif scipy.sparse.issparse(b):
    #     return b.multiply(a)
    # else:
    return Backend.mul(self, a, b)

def mul_matrix_batched_vector(self, A, b)-
Expand source code
def mul_matrix_batched_vector(self, A, b):
    from jax.experimental.sparse import BCOO
    if isinstance(A, BCOO):
        return (A @ b.T).T
    return jnp.stack([A.dot(b[i]) for i in range(b.shape[0])])

def nn_library(self)-
Expand source code
def nn_library(self):
    from . import stax_nets
    return stax_nets

def nonzero(self, values, length=None, fill_value=-1)-
Expand source code
def nonzero(self, values, length=None, fill_value=-1):
    result = jnp.nonzero(values, size=length, fill_value=fill_value)
    return jnp.stack(result, -1)

Args
values- Tensor with only spatial dimensions
length- (Optional) Length of the resulting array. If specified, the result array will be padded with `fill_value` or trimmed.
Returns
non-zero multi-indices as tensor of shape (nnz/length, vector)
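Example
A sketch reusing the illustrative `b` instance from above.
>>> values = jnp.asarray([[0, 1], [2, 0]])
>>> b.nonzero(values)
Array([[0, 1],
       [1, 0]], dtype=int32)
>>> b.nonzero(values, length=4)  # padded with fill_value rows
Array([[ 0,  1],
       [ 1,  0],
       [-1, -1],
       [-1, -1]], dtype=int32)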
def numpy(self, tensor)-
Expand source code
def numpy(self, tensor):
    if isinstance(tensor, JVPTracer):
        tensor = tensor.primal
    if self.is_sparse(tensor):
        assemble, parts = self.disassemble(tensor)
        return assemble(NUMPY, *[self.numpy(t) for t in parts])
    return np.array(tensor)

Returns a NumPy representation of the given tensor. If `tensor` is already a NumPy array, it is returned without modification.
This method raises an error if the value of the tensor is not known at this point, e.g. because it represents a node in a graph. Use `is_available(tensor)` to check if the value can be represented as a NumPy array.
Args
tensor- backend-compatible tensor or sparse tensor
Returns
NumPy representation of the values stored in the tensor
def numpy_call(self, f, output_shapes, output_dtypes, *args, **aux_args)-
Expand source code
def numpy_call(self, f, output_shapes, output_dtypes, *args, **aux_args):
    if all([self.is_available(arg) for arg in args]):
        args = [self.numpy(arg) for arg in args]
        output = f(*args, **aux_args)
        result = map_structure(self.as_tensor, output)
        return result

    @dataclasses.dataclass
    class OutputTensor:
        shape: Tuple[int]
        dtype: np.dtype

    output_specs = map_structure(lambda t, s: OutputTensor(s, to_numpy_dtype(t)), output_dtypes, output_shapes)
    if hasattr(jax, 'pure_callback'):
        def aux_f(*args):
            return f(*args, **aux_args)
        return jax.pure_callback(aux_f, output_specs, *args)
    else:
        def aux_f(args):
            if isinstance(args, tuple):
                return f(*args, **aux_args)
            else:
                return f(args, **aux_args)
        from jax.experimental.host_callback import call
        return call(aux_f, args, result_shape=output_specs)

This call can be used in jit-compiled code but is not differentiable.
Args
f- Function operating on numpy arrays.
output_shapes- Single shape `tuple` or tuple of shapes declaring the shapes of the tensors returned by `f`.
output_dtypes- Single `DType` or tuple of DTypes declaring the dtypes of the tensors returned by `f`.
*args- Tensor arguments to be converted to NumPy arrays and then passed to `f`.
**aux_args- Keyword arguments to be passed to `f` without conversion.
Returns
Returned arrays of `f` converted to tensors.
def ones(self, shape, dtype: phiml.backend._dtype.DType = None)-
Expand source code
def ones(self, shape, dtype: DType = None):
    self._check_float64()
    return jax.device_put(jnp.ones(shape, dtype=to_numpy_dtype(dtype or self.float_type)), self._default_device.ref)

def ones_like(self, tensor)-
Expand source code
def ones_like(self, tensor):
    return jax.device_put(jnp.ones_like(tensor), self._default_device.ref)

def pad(self, value, pad_width, mode='constant', constant_values=0)-
Expand source code
def pad(self, value, pad_width, mode='constant', constant_values=0):
    assert mode in ('constant', 'symmetric', 'periodic', 'reflect', 'boundary'), mode
    if mode == 'constant':
        constant_values = jnp.array(constant_values, dtype=value.dtype)
        return jnp.pad(value, pad_width, 'constant', constant_values=constant_values)
    else:
        if mode in ('periodic', 'boundary'):
            mode = {'periodic': 'wrap', 'boundary': 'edge'}[mode]
        return jnp.pad(value, pad_width, mode)

Pads a tensor with values as specified by `mode` and `constant_values`.
If the mode is not supported, returns `NotImplemented`.
Args
value- tensor
pad_width- 2D tensor specifying the number of values padded to the edges of each axis in the form [[axis 0 lower, axis 0 upper], …] including batch and component axes.
mode- One of 'constant', 'boundary', 'periodic', 'symmetric', 'reflect'. (Default value = 'constant')
constant_values- Scalar value used for out-of-bounds points if mode='constant'. Must be a Python primitive type or scalar tensor.
Returns
padded tensor or `NotImplemented`
def prefers_channels_last(self) ‑> bool-
Expand source code
def prefers_channels_last(self) -> bool:
    return True

def prod(self, value, axis=None)-
Expand source code
def prod(self, value, axis=None):
    if not isinstance(value, jnp.ndarray):
        value = jnp.array(value)
    if value.dtype == bool:
        return jnp.all(value, axis=axis)
    return jnp.prod(value, axis=axis)

def quantile(self, x, quantiles)-
Expand source code
def quantile(self, x, quantiles):
    return jnp.quantile(x, quantiles, axis=-1)

Reduces the last / inner axis of `x`.
Args
x- Tensor
quantiles- List or 1D tensor of quantiles to compute.
Returns
Tensor with shape (quantiles, *x.shape[:-1])
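Example
A sketch reusing the illustrative `b` instance from above; the quantile axis becomes the outer axis.
>>> x = jnp.asarray([[1., 2., 3., 4.], [10., 20., 30., 40.]])
>>> b.quantile(x, jnp.asarray([.5, 1.]))
Array([[ 2.5, 25. ],
       [ 4. , 40. ]], dtype=float32)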
def random_normal(self, shape, dtype: phiml.backend._dtype.DType)-
Expand source code
def random_normal(self, shape, dtype: DType):
    self._check_float64()
    self.rnd_key, subkey = jax.random.split(self.rnd_key)
    dtype = dtype or self.float_type
    return jax.device_put(random.normal(subkey, shape, dtype=to_numpy_dtype(dtype)), self._default_device.ref)

Float tensor of selected precision containing random values sampled from a normal distribution with mean 0 and std 1.
def random_permutations(self, permutations: int, n: int)-
Expand source code
def random_permutations(self, permutations: int, n: int):
    self.rnd_key, subkey = jax.random.split(self.rnd_key)
    result = jnp.stack([jax.random.permutation(subkey, n) for _ in range(permutations)])
    return jax.device_put(result, self._default_device.ref)

Generates `permutations` stacked arrays of shuffled integers between `0` and `n`.
def random_uniform(self, shape, low, high, dtype: phiml.backend._dtype.DType | None)-
Expand source code
def random_uniform(self, shape, low, high, dtype: Union[DType, None]):
    self._check_float64()
    self.rnd_key, subkey = jax.random.split(self.rnd_key)
    dtype = dtype or self.float_type
    jdt = to_numpy_dtype(dtype)
    if dtype.kind == float:
        tensor = random.uniform(subkey, shape, minval=low, maxval=high, dtype=jdt)
    elif dtype.kind == complex:
        real = random.uniform(subkey, shape, minval=low.real, maxval=high.real, dtype=to_numpy_dtype(DType.by_precision(float, dtype.precision)))
        imag = random.uniform(subkey, shape, minval=low.imag, maxval=high.imag, dtype=to_numpy_dtype(DType.by_precision(float, dtype.precision)))
        return real + 1j * imag
    elif dtype.kind == int:
        tensor = random.randint(subkey, shape, low, high, dtype=jdt)
        if tensor.dtype != jdt:
            warnings.warn(f"Jax failed to sample random integers with dtype {dtype}, returned {tensor.dtype} instead.", RuntimeWarning)
    else:
        raise ValueError(dtype)
    return jax.device_put(tensor, self._default_device.ref)

Tensor of selected precision containing random values in the range [low, high).
def range(self, start, limit=None, delta=1, dtype: phiml.backend._dtype.DType = int32)-
Expand source code
def range(self, start, limit=None, delta=1, dtype: DType = INT32):
    if limit is None:
        start, limit = 0, start
    return jnp.arange(start, limit, delta, to_numpy_dtype(dtype))

def ravel_multi_index(self, multi_index, shape, mode: str | int = 'undefined')-
Expand source code
def ravel_multi_index(self, multi_index, shape, mode: Union[str, int] = 'undefined'):
    if not self.is_available(shape):
        return Backend.ravel_multi_index(self, multi_index, shape, mode)
    mode = mode if isinstance(mode, int) else {'undefined': 'clip', 'periodic': 'wrap', 'clamp': 'clip'}[mode]
    idx_first = jnp.transpose(multi_index, (self.ndims(multi_index)-1,) + tuple(range(self.ndims(multi_index)-1)))
    result = jnp.ravel_multi_index(idx_first, shape, mode='wrap' if isinstance(mode, int) else mode)
    if isinstance(mode, int):
        outside = self.any((multi_index < 0) | (multi_index >= jnp.asarray(shape, dtype=multi_index.dtype)), -1)
        result = self.where(outside, mode, result)
    return result

Args
multi_index- (batch…, index_dim)
shape- 1D tensor or tuple/list
mode- `'undefined'`, `'periodic'`, `'clamp'` or an `int` to use for all invalid indices.
Returns
Integer tensor of shape (batch…) of same dtype as `multi_index`.
def repeat(self, x, repeats, axis: int, new_length=None)-
Expand source code
def repeat(self, x, repeats, axis: int, new_length=None):
    return jnp.repeat(x, self.as_tensor(repeats), axis, total_repeat_length=new_length)

Repeats the elements along `axis` `repeats` times.
Args
x- Tensor
repeats- How often to repeat each element. 1D tensor of length x.shape[axis]
axis- Which axis to repeat elements along
new_length- Set the length of `axis` after repeating. This is required for jit compilation with Jax.
Returns
repeated Tensor
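Example
A sketch reusing the illustrative `b` instance from above; under jit, `new_length` must be supplied.
>>> import jax
>>> x = jnp.asarray([10, 20, 30])
>>> b.repeat(x, [1, 0, 2], axis=0)
Array([10, 30, 30], dtype=int32)
>>> jax.jit(lambda x: b.repeat(x, [1, 0, 2], axis=0, new_length=3))(x)
Array([10, 30, 30], dtype=int32)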
def requires_fixed_shapes_when_tracing(self) ‑> bool-
Expand source code
def requires_fixed_shapes_when_tracing(self) -> bool:
    return True

def reshape(self, value, shape)-
Expand source code
def reshape(self, value, shape):
    return jnp.reshape(value, shape)

def scatter(self, base_grid, indices, values, mode: str)-
Expand source code
def scatter(self, base_grid, indices, values, mode: str):
    out_kind = self.combine_types(self.dtype(base_grid), self.dtype(values)).kind
    base_grid, values = self.auto_cast(base_grid, values, bool_to_int=True)
    batch_size = combined_dim(combined_dim(indices.shape[0], values.shape[0]), base_grid.shape[0])
    spatial_dims = tuple(range(base_grid.ndim - 2))
    dnums = jax.lax.ScatterDimensionNumbers(
        update_window_dims=(1,),  # channel dim of updates (batch dim removed)
        inserted_window_dims=spatial_dims,  # no idea what this does but spatial_dims seems to work
        scatter_dims_to_operand_dims=spatial_dims)  # spatial dims of base_grid (batch dim removed)
    scatter = {'add': jax.lax.scatter_add, 'update': jax.lax.scatter, 'max': jax.lax.scatter_max, 'min': jax.lax.scatter_min}[mode]
    def scatter_single(base_grid, indices, values):
        return scatter(base_grid, indices, values, dnums)
    if self.staticshape(indices)[0] == 1:
        indices = self.tile(indices, [batch_size, 1, 1])
    if self.staticshape(values)[0] == 1:
        values = self.tile(values, [batch_size, 1, 1])
    result = self.vectorized_call(scatter_single, base_grid, indices, values)
    if self.dtype(result).kind != out_kind:
        if out_kind == bool:
            result = self.cast(result, BOOL)
    return result

Batched n-dimensional scatter.
Args
base_grid- Tensor into which scatter values are inserted at indices. Tensor of shape (batch_size, spatial…, channels)
indices- Tensor of shape (batch_size or 1, update_count, index_vector)
values- Values to scatter at indices. Tensor of shape (batch_size or 1, update_count or 1, channels or 1)
mode- One of ('update', 'add', 'max', 'min')
Returns
Copy of `base_grid` with values at `indices` updated by `values`.
def searchsorted(self, sorted_sequence, search_values, side: str, dtype=int32)-
Expand source code
def searchsorted(self, sorted_sequence, search_values, side: str, dtype=INT32):
    if self.ndims(sorted_sequence) == 1:
        return jnp.searchsorted(sorted_sequence, search_values, side=side).astype(to_numpy_dtype(dtype))
    else:
        return jax.vmap(partial(self.searchsorted, side=side, dtype=dtype))(sorted_sequence, search_values)

def seed(self, seed: int)-
Expand source code
def seed(self, seed: int):
    self.rnd_key = jax.random.PRNGKey(seed)

def sizeof(self, tensor) ‑> int-
Expand source code
def sizeof(self, tensor) -> int:
    return tensor.nbytes

Returns the size in bytes.
def solve_triangular_dense(self, matrix, rhs, lower: bool, unit_diagonal: bool)-
Expand source code
def solve_triangular_dense(self, matrix, rhs, lower: bool, unit_diagonal: bool):
    matrix, rhs = self.auto_cast(matrix, rhs, int_to_float=True, bool_to_int=True)
    x = jax.lax.linalg.triangular_solve(matrix, rhs, lower=lower, unit_diagonal=unit_diagonal, left_side=True)
    return x

Args
matrix- (batch_size, rows, cols)
rhs- (batch_size, cols)
lower- Whether the matrices are lower-triangular.
unit_diagonal- Whether the diagonal entries are assumed to be 1.
Returns
(batch_size, cols)
def sort(self, x, axis=-1)-
Expand source code
def sort(self, x, axis=-1):
    return jnp.sort(x, axis)

def sparse_coo_tensor(self, indices: tuple | list, values, shape: tuple)-
Expand source code
def sparse_coo_tensor(self, indices: Union[tuple, list], values, shape: tuple):
    return BCOO((values, indices), shape=shape)

Create a sparse matrix in coordinate list (COO) format.
Optional feature.
See Also: `Backend.csr_matrix()`, `Backend.csc_matrix()`.
Args
indices- 2D tensor of shape `(nnz, dims)`.
values- 1D values tensor matching `indices`.
shape- Shape of the sparse matrix.
Returns
Native representation of the sparse matrix
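Example
A sketch reusing the illustrative `b` instance from above; the result is a Jax BCOO matrix.
>>> indices = jnp.asarray([[0, 0], [1, 2]])  # (nnz, dims)
>>> values = jnp.asarray([1., 2.])
>>> b.sparse_coo_tensor(indices, values, shape=(2, 3)).todense()
Array([[1., 0., 0.],
       [0., 0., 2.]], dtype=float32)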
def std(self, x, axis=None, keepdims=False)-
Expand source code
def std(self, x, axis=None, keepdims=False):
    return jnp.std(x, axis, keepdims=keepdims)

def sum(self, value, axis=None, keepdims=False)-
Expand source code
def sum(self, value, axis=None, keepdims=False):
    if isinstance(value, (tuple, list)):
        assert axis == 0
        return sum(value[1:], value[0])
    return jnp.sum(value, axis=axis, keepdims=keepdims)

def svd(self, matrix: ~TensorType, full_matrices=True) ‑> Tuple[~TensorType, ~TensorType, ~TensorType]-
Expand source code
def svd(self, matrix: TensorType, full_matrices=True) -> Tuple[TensorType, TensorType, TensorType]:
    result = jnp.linalg.svd(matrix, full_matrices=full_matrices)
    return result[0], result[1], result[2]

Args
matrix- (batch…, m, n)
Returns
u- Left singular vectors, (batch…, m, m) if `full_matrices` else (batch…, m, k) with k = min(m, n)
singular_values- (batch…, k)
vh- Right singular vectors (conjugate-transposed), (batch…, n, n) if `full_matrices` else (batch…, k, n)
def tensordot(self, a, a_axes: tuple | list, b, b_axes: tuple | list)-
Expand source code
def tensordot(self, a, a_axes: Union[tuple, list], b, b_axes: Union[tuple, list]):
    return jnp.tensordot(a, b, (a_axes, b_axes))

Multiply-sum-reduce `a_axes` of `a` with `b_axes` of `b`.
def to_dlpack(self, tensor)-
Expand source code
def to_dlpack(self, tensor):
    if version.parse(jax.__version__) < version.parse("0.7.0"):
        from jax import dlpack
        return dlpack.to_dlpack(tensor)
    else:
        return tensor.__dlpack__()

def unique(self, x: ~TensorType, return_inverse: bool, return_counts: bool, axis: int) ‑> Tuple[~TensorType, ...]-
Expand source code
def unique(self, x: TensorType, return_inverse: bool, return_counts: bool, axis: int) -> Tuple[TensorType, ...]:
    return jnp.unique(x, return_inverse=return_inverse, return_counts=return_counts, axis=axis)

Args
x- n-dimensional int array. Will compare `axis`-slices of `x` for multidimensional `x`.
return_inverse- Whether to return the inverse
return_counts- Whether to return the counts.
axis- Axis along which slices of `x` should be compared.
Returns
unique_slices- Sorted unique slices of `x`
unique_inverse- (optional) Index of the unique slice for each slice of `x`
unique_counts- (optional) Number of occurrences of each unique slice
def unravel_index(self, flat_index, shape)-
Expand source code
def unravel_index(self, flat_index, shape):
    return jnp.stack(jnp.unravel_index(flat_index, shape), -1)

def vectorized_call(self, f, *args, output_dtypes=None, **aux_args)-
Expand source code
def vectorized_call(self, f, *args, output_dtypes=None, **aux_args):
    batch_size = self.determine_size(args, 0)
    args = [self.tile_to(t, 0, batch_size) for t in args]
    def f_positional(*args):
        return f(*args, **aux_args)
    vec_f = jax.vmap(f_positional, 0, 0)
    return vec_f(*args)

Args
f- Function with only positional tensor arguments, returning one or multiple tensors.
*args- Batched inputs for `f`. The first dimension of all `args` is vectorized. All tensors in `args` must have the same size or `1` in their first dimension.
output_dtypes- Single `DType` or tuple of DTypes declaring the dtypes of the tensors returned by `f`.
**aux_args- Non-vectorized keyword arguments to be passed to `f`.
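Example
A sketch reusing the illustrative `b` instance from above; size-1 batch dimensions are tiled to match.
>>> def multiply(x, y):  # written for single (unbatched) inputs
...     return x * y
>>> x = jnp.asarray([[1., 2.], [3., 4.]])  # batch size 2
>>> y = jnp.asarray([[10., 20.]])          # batch size 1, tiled to match
>>> b.vectorized_call(multiply, x, y)
Array([[10., 40.],
       [30., 80.]], dtype=float32)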
def where(self, condition, x=None, y=None)-
Expand source code
def where(self, condition, x=None, y=None):
    if x is None or y is None:
        return jnp.argwhere(condition)
    return jnp.where(condition, x, y)

def while_loop(self, loop: Callable, values: tuple, max_iter: int | Tuple[int, ...] | List[int])-
Expand source code
def while_loop(self, loop: Callable, values: tuple, max_iter: Union[int, Tuple[int, ...], List[int]]):
    if all(self.is_available(t) for t in values):
        return self.stop_gradient_tree(Backend.while_loop(self, loop, values, max_iter))
    if isinstance(max_iter, (tuple, list)):  # stack traced trajectory, unroll until max_iter
        values = self.stop_gradient_tree(values)
        trj = [values] if 0 in max_iter else []
        for i in range(1, max(max_iter) + 1):
            values = loop(*values)
            if i in max_iter:
                trj.append(values)  # values are not mutable so no need to copy
        return self.stop_gradient_tree(self.stack_leaves(trj))
    else:
        if max_iter is None:
            cond = lambda vals: jnp.any(vals[0])
            body = lambda vals: loop(*vals)
            return jax.lax.while_loop(cond, body, values)
        else:
            cond = lambda vals: jnp.any(vals[1][0]) & (vals[0] < max_iter)
            body = lambda vals: (vals[0] + 1, loop(*vals[1]))
            return jax.lax.while_loop(cond, body, (self.as_tensor(0), values))[1]

If `max_iter is None`, runs

    while any(values[0]):
        values = loop(*values)
    return values

This operation does not support backpropagation.
Args
loop- Loop function, must return a `tuple` with entries equal to `values` in shape and data type.
values- Initial values of loop variables.
max_iter- Maximum number of iterations to run, single `int` or sequence of integers.
Returns
Loop variables upon loop completion if `max_iter` is a single integer. If `max_iter` is a sequence, stacks the variables after each entry in `max_iter`, adding an outer dimension of size `<= len(max_iter)`. If the condition is fulfilled before the maximum `max_iter` is reached, the loop may or may not be terminated early, depending on the implementation. If the loop is terminated early, the values returned by the last executed iteration are expected to be constant and filled.
def zeros(self, shape, dtype: phiml.backend._dtype.DType = None)-
Expand source code
def zeros(self, shape, dtype: DType = None):
    self._check_float64()
    return jax.device_put(jnp.zeros(shape, dtype=to_numpy_dtype(dtype or self.float_type)), self._default_device.ref)

def zeros_like(self, tensor)-
Expand source code
def zeros_like(self, tensor):
    return jax.device_put(jnp.zeros_like(tensor), self._default_device.ref)