Module phiml.backend.jax
Sub-modules
phiml.backend.jax.stax_nets
-
Stax implementation of the unified machine learning API. Equivalent functions also exist for the other frameworks …
Classes
class JaxBackend
-
Backends delegate low-level operations to an ML or numerics library or emulate them. The methods of
Backend
form a comprehensive list of available operations. To support a library, subclass
Backend
and register it by adding it to
BACKENDS
.
Args
name
- Human-readable string
default_device
- ComputeDevice being used by default
Expand source code
class JaxBackend(Backend): def __init__(self): devices = [] for device_type in ['cpu', 'gpu', 'tpu']: try: for jax_dev in jax.devices(device_type): devices.append(ComputeDevice(self, device_type.upper(), jax_dev.platform.upper(), -1, -1, f"id={jax_dev.id}", jax_dev)) except RuntimeError as err: pass # this is just Jax not finding anything. jaxlib.xla_client._get_local_backends() could help but isn't currently available on GitHub actions Backend.__init__(self, 'jax', devices, devices[-1]) try: self.rnd_key = jax.random.PRNGKey(seed=0) except RuntimeError as err: warnings.warn(f"{err}", RuntimeWarning) self.rnd_key = None def prefers_channels_last(self) -> bool: return True def requires_fixed_shapes_when_tracing(self) -> bool: return True def nn_library(self): from . import stax_nets return stax_nets def _check_float64(self): if self.precision == 64: if not jax.config.read('jax_enable_x64'): jax.config.update('jax_enable_x64', True) assert jax.config.read('jax_enable_x64'), "FP64 is disabled for Jax." def seed(self, seed: int): self.rnd_key = jax.random.PRNGKey(seed) def as_tensor(self, x, convert_external=True): self._check_float64() if self.is_tensor(x, only_native=convert_external): array = x else: array = jnp.array(x) # --- Enforce Precision --- if not isinstance(array, numbers.Number): if self.dtype(array).kind == float: array = self.to_float(array) elif self.dtype(array).kind == complex: array = self.to_complex(array) return array def is_module(self, obj): return False def is_tensor(self, x, only_native=False): if isinstance(x, jnp.ndarray) and not isinstance(x, np.ndarray): # NumPy arrays inherit from Jax arrays return True if isinstance(x, jnp.bool_) and not isinstance(x, np.bool_): return True if self.is_sparse(x): return True # --- Above considered native --- if only_native: return False # --- Non-native types --- if isinstance(x, np.ndarray): return True if isinstance(x, np.bool_): return True if isinstance(x, (numbers.Number, bool)): return True if isinstance(x, (tuple, list)): return all([self.is_tensor(item, False) for item in x]) return False def is_sparse(self, x) -> bool: return isinstance(x, (COO, BCOO, CSR, CSC)) def get_sparse_format(self, x) -> str: format_names = { COO: 'coo', BCOO: 'coo', CSR: 'csr', CSC: 'csc', } return format_names.get(type(x), 'dense') def is_available(self, tensor): if isinstance(tensor, JVPTracer): tensor = tensor.primal return not isinstance(tensor, Tracer) def numpy(self, tensor): if isinstance(tensor, JVPTracer): tensor = tensor.primal if self.is_sparse(tensor): assemble, parts = self.disassemble(tensor) return assemble(NUMPY, *[self.numpy(t) for t in parts]) return np.array(tensor) def disassemble(self, x) -> Tuple[Callable, Sequence[TensorType]]: if self.is_sparse(x): if isinstance(x, COO): raise NotImplementedError # return lambda b, i, v: b.sparse_coo_tensor(i, v, x.shape), (np.stack([x.row, x.col], -1), x.data) if isinstance(x, BCOO): return lambda b, i, v: b.sparse_coo_tensor(i, v, x.shape), (x.indices, x.data) if isinstance(x, CSR): raise NotImplementedError # return lambda b, v, i, p: b.csr_matrix(i, p, v, x.shape), (x.data, x.indices, x.indptr) elif isinstance(x, CSC): raise NotImplementedError # return lambda b, v, i, p: b.csc_matrix(p, i, v, x.shape), (x.data, x.indices, x.indptr) raise NotImplementedError else: return lambda b, t: t, (x,) def to_dlpack(self, tensor): from jax import dlpack return dlpack.to_dlpack(tensor) def from_dlpack(self, capsule): from jax import dlpack return dlpack.from_dlpack(capsule) def copy(self, tensor, 
only_mutable=False): return jnp.array(tensor, copy=True) def get_device(self, tensor: TensorType) -> ComputeDevice: if hasattr(tensor, 'devices'): return self.get_device_by_ref(next(iter(tensor.devices()))) if hasattr(tensor, 'device'): return self.get_device_by_ref(tensor.device()) raise AssertionError(f"tensor {type(tensor)} has no device attribute") def allocate_on_device(self, tensor: TensorType, device: ComputeDevice) -> TensorType: return jax.device_put(tensor, device.ref) sqrt = staticmethod(jnp.sqrt) exp = staticmethod(jnp.exp) erf = staticmethod(scipy.special.erf) softplus = staticmethod(jax.nn.softplus) sin = staticmethod(jnp.sin) arcsin = staticmethod(jnp.arcsin) cos = staticmethod(jnp.cos) arccos = staticmethod(jnp.arccos) tan = staticmethod(jnp.tan) arctan = staticmethod(np.arctan) arctan2 = staticmethod(np.arctan2) sinh = staticmethod(np.sinh) arcsinh = staticmethod(np.arcsinh) cosh = staticmethod(np.cosh) arccosh = staticmethod(np.arccosh) tanh = staticmethod(np.tanh) arctanh = staticmethod(np.arctanh) log = staticmethod(jnp.log) log2 = staticmethod(jnp.log2) log10 = staticmethod(jnp.log10) isfinite = staticmethod(jnp.isfinite) isnan = staticmethod(jnp.isnan) isinf = staticmethod(jnp.isinf) abs = staticmethod(jnp.abs) sign = staticmethod(jnp.sign) round = staticmethod(jnp.round) ceil = staticmethod(jnp.ceil) floor = staticmethod(jnp.floor) flip = staticmethod(jnp.flip) stop_gradient = staticmethod(jax.lax.stop_gradient) transpose = staticmethod(jnp.transpose) equal = staticmethod(jnp.equal) tile = staticmethod(jnp.tile) stack = staticmethod(jnp.stack) concat = staticmethod(jnp.concatenate) maximum = staticmethod(jnp.maximum) minimum = staticmethod(jnp.minimum) clip = staticmethod(jnp.clip) argmax = staticmethod(np.argmax) argmin = staticmethod(np.argmin) shape = staticmethod(jnp.shape) staticshape = staticmethod(jnp.shape) imag = staticmethod(jnp.imag) real = staticmethod(jnp.real) conj = staticmethod(jnp.conjugate) einsum = staticmethod(jnp.einsum) cumsum = staticmethod(jnp.cumsum) def nonzero(self, values, length=None, fill_value=-1): result = jnp.nonzero(values, size=length, fill_value=fill_value) return jnp.stack(result, -1) def vectorized_call(self, f, *args, output_dtypes=None, **aux_args): batch_size = self.determine_size(args, 0) args = [self.tile_to(t, 0, batch_size) for t in args] def f_positional(*args): return f(*args, **aux_args) vec_f = jax.vmap(f_positional, 0, 0) return vec_f(*args) def numpy_call(self, f, output_shapes, output_dtypes, *args, **aux_args): if all([self.is_available(arg) for arg in args]): args = [self.numpy(arg) for arg in args] output = f(*args, **aux_args) result = map_structure(self.as_tensor, output) return result @dataclasses.dataclass class OutputTensor: shape: Tuple[int] dtype: np.dtype output_specs = map_structure(lambda t, s: OutputTensor(s, to_numpy_dtype(t)), output_dtypes, output_shapes) if hasattr(jax, 'pure_callback'): def aux_f(*args): return f(*args, **aux_args) return jax.pure_callback(aux_f, output_specs, *args) else: def aux_f(args): if isinstance(args, tuple): return f(*args, **aux_args) else: return f(args, **aux_args) from jax.experimental.host_callback import call return call(aux_f, args, result_shape=output_specs) def jit_compile(self, f: Callable) -> Callable: def run_jit_f(*args): # print(jax.make_jaxpr(f)(*args)) ML_LOGGER.debug(f"JaxBackend: running jit-compiled '{f.__name__}' with shapes {[self.shape(arg) for arg in args]} and dtypes {[self.dtype(arg) for arg in args]}") return self.as_registered.call(jit_f, *args, 
name=f"run jit-compiled '{f.__name__}'") run_jit_f.__name__ = f"Jax-Jit({f.__name__})" jit_f = jax.jit(f, device=self._default_device.ref) return run_jit_f def block_until_ready(self, values): if hasattr(values, 'block_until_ready'): values.block_until_ready() if isinstance(values, (tuple, list)): for v in values: self.block_until_ready(v) def jacobian(self, f, wrt: Union[tuple, list], get_output: bool, is_f_scalar: bool): if get_output: jax_grad_f = jax.value_and_grad(f, argnums=wrt, has_aux=True) @wraps(f) def unwrap_outputs(*args): args = [self.to_float(arg) if self.dtype(arg).kind in (bool, int) and i in wrt else arg for i, arg in enumerate(args)] (_, output_tuple), grads = jax_grad_f(*args) return (*output_tuple, *[jnp.conjugate(g) for g in grads]) return unwrap_outputs else: @wraps(f) def nonaux_f(*args): loss, output = f(*args) return loss jax_grad = jax.grad(nonaux_f, argnums=wrt, has_aux=False) @wraps(f) def call_jax_grad(*args): args = [self.to_float(arg) if self.dtype(arg).kind in (bool, int) and i in wrt else arg for i, arg in enumerate(args)] grads = jax_grad(*args) return tuple([jnp.conjugate(g) for g in grads]) return call_jax_grad def custom_gradient(self, f: Callable, gradient: Callable, get_external_cache: Callable = None, on_call_skipped: Callable = None) -> Callable: jax_fun = jax.custom_vjp(f) # custom vector-Jacobian product (reverse-mode differentiation) def forward(*x): y = f(*x) return y, (x, y) def backward(x_y, dy): x, y = x_y dx = gradient(x, y, dy) return tuple(dx) jax_fun.defvjp(forward, backward) return jax_fun def divide_no_nan(self, x, y): return jnp.where(y == 0, 0, x / y) # jnp.nan_to_num(x / y, copy=True, nan=0) covers up NaNs from before def random_uniform(self, shape, low, high, dtype: Union[DType, None]): self._check_float64() self.rnd_key, subkey = jax.random.split(self.rnd_key) dtype = dtype or self.float_type jdt = to_numpy_dtype(dtype) if dtype.kind == float: tensor = random.uniform(subkey, shape, minval=low, maxval=high, dtype=jdt) elif dtype.kind == complex: real = random.uniform(subkey, shape, minval=low.real, maxval=high.real, dtype=to_numpy_dtype(DType(float, dtype.precision))) imag = random.uniform(subkey, shape, minval=low.imag, maxval=high.imag, dtype=to_numpy_dtype(DType(float, dtype.precision))) return real + 1j * imag elif dtype.kind == int: tensor = random.randint(subkey, shape, low, high, dtype=jdt) if tensor.dtype != jdt: warnings.warn(f"Jax failed to sample random integers with dtype {dtype}, returned {tensor.dtype} instead.", RuntimeWarning) else: raise ValueError(dtype) return jax.device_put(tensor, self._default_device.ref) def random_normal(self, shape, dtype: DType): self._check_float64() self.rnd_key, subkey = jax.random.split(self.rnd_key) dtype = dtype or self.float_type return jax.device_put(random.normal(subkey, shape, dtype=to_numpy_dtype(dtype)), self._default_device.ref) def range(self, start, limit=None, delta=1, dtype: DType = DType(int, 32)): if limit is None: start, limit = 0, start return jnp.arange(start, limit, delta, to_numpy_dtype(dtype)) def pad(self, value, pad_width, mode='constant', constant_values=0): assert mode in ('constant', 'symmetric', 'periodic', 'reflect', 'boundary'), mode if mode == 'constant': constant_values = jnp.array(constant_values, dtype=value.dtype) return jnp.pad(value, pad_width, 'constant', constant_values=constant_values) else: if mode in ('periodic', 'boundary'): mode = {'periodic': 'wrap', 'boundary': 'edge'}[mode] return jnp.pad(value, pad_width, mode) def reshape(self, value, 
shape): return jnp.reshape(value, shape) def sum(self, value, axis=None, keepdims=False): if isinstance(value, (tuple, list)): assert axis == 0 return sum(value[1:], value[0]) return jnp.sum(value, axis=axis, keepdims=keepdims) def prod(self, value, axis=None): if not isinstance(value, jnp.ndarray): value = jnp.array(value) if value.dtype == bool: return jnp.all(value, axis=axis) return jnp.prod(value, axis=axis) def where(self, condition, x=None, y=None): if x is None or y is None: return jnp.argwhere(condition) return jnp.where(condition, x, y) def zeros(self, shape, dtype: DType = None): self._check_float64() return jax.device_put(jnp.zeros(shape, dtype=to_numpy_dtype(dtype or self.float_type)), self._default_device.ref) def zeros_like(self, tensor): return jax.device_put(jnp.zeros_like(tensor), self._default_device.ref) def ones(self, shape, dtype: DType = None): self._check_float64() return jax.device_put(jnp.ones(shape, dtype=to_numpy_dtype(dtype or self.float_type)), self._default_device.ref) def ones_like(self, tensor): return jax.device_put(jnp.ones_like(tensor), self._default_device.ref) def meshgrid(self, *coordinates): self._check_float64() coordinates = [self.as_tensor(c) for c in coordinates] return [jax.device_put(c, self._default_device.ref) for c in jnp.meshgrid(*coordinates, indexing='ij')] def linspace(self, start, stop, number): self._check_float64() return jax.device_put(jnp.linspace(start, stop, number, dtype=to_numpy_dtype(self.float_type)), self._default_device.ref) def linspace_without_last(self, start, stop, number): self._check_float64() return jax.device_put(jnp.linspace(start, stop, number, endpoint=False, dtype=to_numpy_dtype(self.float_type)), self._default_device.ref) def mean(self, value, axis=None, keepdims=False): return jnp.mean(value, axis, keepdims=keepdims) def log_gamma(self, x): return jax.lax.lgamma(self.to_float(x)) def gamma_inc_l(self, a, x): return scipy.special.gammainc(a, x) def gamma_inc_u(self, a, x): return scipy.special.gammaincc(a, x) def tensordot(self, a, a_axes: Union[tuple, list], b, b_axes: Union[tuple, list]): return jnp.tensordot(a, b, (a_axes, b_axes)) def mul(self, a, b): # if scipy.sparse.issparse(a): # TODO sparse? 
# return a.multiply(b) # elif scipy.sparse.issparse(b): # return b.multiply(a) # else: return Backend.mul(self, a, b) def mul_matrix_batched_vector(self, A, b): from jax.experimental.sparse import BCOO if isinstance(A, BCOO): return(A @ b.T).T return jnp.stack([A.dot(b[i]) for i in range(b.shape[0])]) def get_diagonal(self, matrices, offset=0): result = jnp.diagonal(matrices, offset=offset, axis1=1, axis2=2) return jnp.transpose(result, [0, 2, 1]) def while_loop(self, loop: Callable, values: tuple, max_iter: Union[int, Tuple[int, ...], List[int]]): if all(self.is_available(t) for t in values): return self.stop_gradient_tree(Backend.while_loop(self, loop, values, max_iter)) if isinstance(max_iter, (tuple, list)): # stack traced trajectory, unroll until max_iter values = self.stop_gradient_tree(values) trj = [values] if 0 in max_iter else [] for i in range(1, max(max_iter) + 1): values = loop(*values) if i in max_iter: trj.append(values) # values are not mutable so no need to copy return self.stop_gradient_tree(self.stack_leaves(trj)) else: if max_iter is None: cond = lambda vals: jnp.any(vals[0]) body = lambda vals: loop(*vals) return jax.lax.while_loop(cond, body, values) else: cond = lambda vals: jnp.any(vals[1][0]) & (vals[0] < max_iter) body = lambda vals: (vals[0] + 1, loop(*vals[1])) return jax.lax.while_loop(cond, body, (self.as_tensor(0), values))[1] def max(self, x, axis=None, keepdims=False): return jnp.max(x, axis, keepdims=keepdims) def min(self, x, axis=None, keepdims=False): return jnp.min(x, axis, keepdims=keepdims) def conv(self, value, kernel, zero_padding=True): assert kernel.shape[0] in (1, value.shape[0]) assert value.shape[1] == kernel.shape[2], f"value has {value.shape[1]} channels but kernel has {kernel.shape[2]}" assert value.ndim + 1 == kernel.ndim value, kernel = self.auto_cast(value, kernel, bool_to_int=True) # AutoDiff may require jax.lax.conv_general_dilated result = [] for b in range(value.shape[0]): b_kernel = kernel[min(b, kernel.shape[0] - 1)] result_b = [] for o in range(kernel.shape[1]): result_b.append(0) for i in range(value.shape[1]): # result.at[b, o, ...].set(scipy.signal.correlate(value[b, i, ...], b_kernel[o, i, ...], mode='same' if zero_padding else 'valid')) result_b[-1] += scipy.signal.correlate(value[b, i, ...], b_kernel[o, i, ...], mode='same' if zero_padding else 'valid') result.append(jnp.stack(result_b, 0)) return jnp.stack(result, 0) def expand_dims(self, a, axis=0, number=1): for _i in range(number): a = jnp.expand_dims(a, axis) return a def cast(self, x, dtype: DType): if self.is_tensor(x, only_native=True) and from_numpy_dtype(x.dtype) == dtype: return x else: return jnp.array(x, to_numpy_dtype(dtype)) def unravel_index(self, flat_index, shape): return jnp.stack(jnp.unravel_index(flat_index, shape), -1) def ravel_multi_index(self, multi_index, shape, mode: Union[str, int] = 'undefined'): if not self.is_available(shape): return Backend.ravel_multi_index(self, multi_index, shape, mode) mode = mode if isinstance(mode, int) else {'undefined': 'clip', 'periodic': 'wrap', 'clamp': 'clip'}[mode] idx_first = jnp.transpose(multi_index, (self.ndims(multi_index)-1,) + tuple(range(self.ndims(multi_index)-1))) result = jnp.ravel_multi_index(idx_first, shape, mode='wrap' if isinstance(mode, int) else mode) if isinstance(mode, int): outside = self.any((multi_index < 0) | (multi_index >= jnp.asarray(shape, dtype=multi_index.dtype)), -1) result = self.where(outside, mode, result) return result def gather(self, values, indices, axis: int): slices = 
[indices if i == axis else slice(None) for i in range(self.ndims(values))] return values[tuple(slices)] def batched_gather_nd(self, values, indices): values = self.as_tensor(values) indices = self.as_tensor(indices) assert indices.shape[-1] == self.ndims(values) - 2 batch_size = combined_dim(values.shape[0], indices.shape[0]) indices_list = [indices[..., i] for i in range(indices.shape[-1])] batch_range = self.expand_dims(np.arange(batch_size), -1, number=self.ndims(indices) - 2) slices = (batch_range, *indices_list) return values[slices] def batched_gather_1d(self, values, indices): batch_size = combined_dim(values.shape[0], indices.shape[0]) return values[np.arange(batch_size)[:, None], indices] def repeat(self, x, repeats, axis: int, new_length=None): return jnp.repeat(x, self.as_tensor(repeats), axis, total_repeat_length=new_length) def std(self, x, axis=None, keepdims=False): return jnp.std(x, axis, keepdims=keepdims) def boolean_mask(self, x, mask, axis=0, new_length=None, fill_value=0): if new_length is None: slices = [mask if i == axis else slice(None) for i in range(len(x.shape))] return x[tuple(slices)] else: indices = jnp.argwhere(mask, size=new_length, fill_value=-1)[..., 0] valid = indices >= 0 valid = valid[tuple([slice(None) if i == axis else None for i in range(len(x.shape))])] result = self.gather(x, jnp.maximum(0, indices), axis) return jnp.where(valid, result, fill_value) def any(self, boolean_tensor, axis=None, keepdims=False): if isinstance(boolean_tensor, (tuple, list)): boolean_tensor = jnp.stack(boolean_tensor) return jnp.any(boolean_tensor, axis=axis, keepdims=keepdims) def all(self, boolean_tensor, axis=None, keepdims=False): if isinstance(boolean_tensor, (tuple, list)): boolean_tensor = jnp.stack(boolean_tensor) return jnp.all(boolean_tensor, axis=axis, keepdims=keepdims) def scatter(self, base_grid, indices, values, mode: str): out_kind = self.combine_types(self.dtype(base_grid), self.dtype(values)).kind base_grid, values = self.auto_cast(base_grid, values, bool_to_int=True) batch_size = combined_dim(combined_dim(indices.shape[0], values.shape[0]), base_grid.shape[0]) spatial_dims = tuple(range(base_grid.ndim - 2)) dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(1,), # channel dim of updates (batch dim removed) inserted_window_dims=spatial_dims, # no idea what this does but spatial_dims seems to work scatter_dims_to_operand_dims=spatial_dims) # spatial dims of base_grid (batch dim removed) scatter = {'add': jax.lax.scatter_add, 'update': jax.lax.scatter, 'max': jax.lax.scatter_max, 'min': jax.lax.scatter_min}[mode] def scatter_single(base_grid, indices, values): return scatter(base_grid, indices, values, dnums) if self.staticshape(indices)[0] == 1: indices = self.tile(indices, [batch_size, 1, 1]) if self.staticshape(values)[0] == 1: values = self.tile(values, [batch_size, 1, 1]) result = self.vectorized_call(scatter_single, base_grid, indices, values) if self.dtype(result).kind != out_kind: if out_kind == bool: result = self.cast(result, DType(bool)) return result def histogram1d(self, values, weights, bin_edges): def unbatched_hist(values, weights, bin_edges): hist, _ = jnp.histogram(values, bin_edges, weights=weights) return hist return jax.vmap(unbatched_hist)(values, weights, bin_edges) def bincount(self, x, weights: Optional[TensorType], bins: int, x_sorted=False): if x_sorted: return jax.ops.segment_sum(weights or 1, x, bins, indices_are_sorted=True) else: return jnp.bincount(x, weights=weights, minlength=bins, length=bins) def unique(self, x: 
TensorType, return_inverse: bool, return_counts: bool, axis: int) -> Tuple[TensorType, ...]: return jnp.unique(x, return_inverse=return_inverse, return_counts=return_counts, axis=axis) def quantile(self, x, quantiles): return jnp.quantile(x, quantiles, axis=-1) def argsort(self, x, axis=-1): return jnp.argsort(x, axis) def sort(self, x, axis=-1): return jnp.sort(x, axis) def searchsorted(self, sorted_sequence, search_values, side: str, dtype=DType(int, 32)): if self.ndims(sorted_sequence) == 1: return jnp.searchsorted(sorted_sequence, search_values, side=side).astype(to_numpy_dtype(dtype)) else: return jax.vmap(partial(self.searchsorted, side=side, dtype=dtype))(sorted_sequence, search_values) def fft(self, x, axes: Union[tuple, list]): x = self.to_complex(x) if not axes: return x if len(axes) == 1: return jnp.fft.fft(x, axis=axes[0]).astype(x.dtype) elif len(axes) == 2: return jnp.fft.fft2(x, axes=axes).astype(x.dtype) else: return jnp.fft.fftn(x, axes=axes).astype(x.dtype) def ifft(self, k, axes: Union[tuple, list]): if not axes: return k if len(axes) == 1: return jnp.fft.ifft(k, axis=axes[0]).astype(k.dtype) elif len(axes) == 2: return jnp.fft.ifft2(k, axes=axes).astype(k.dtype) else: return jnp.fft.ifftn(k, axes=axes).astype(k.dtype) def dtype(self, array) -> DType: if isinstance(array, bool): return DType(bool) if isinstance(array, int): return DType(int, 32) if isinstance(array, float): return DType(float, 64) if isinstance(array, complex): return DType(complex, 128) if not isinstance(array, jnp.ndarray): array = jnp.array(array) return from_numpy_dtype(array.dtype) def matrix_solve_least_squares(self, matrix: TensorType, rhs: TensorType) -> Tuple[TensorType, TensorType, TensorType, TensorType]: solution, residuals, rank, singular_values = lstsq_batched(matrix, rhs) return solution, residuals, rank, singular_values def solve_triangular_dense(self, matrix, rhs, lower: bool, unit_diagonal: bool): matrix, rhs = self.auto_cast(matrix, rhs, int_to_float=True, bool_to_int=True) x = jax.lax.linalg.triangular_solve(matrix, rhs, lower=lower, unit_diagonal=unit_diagonal, left_side=True) return x def matrix_rank_dense(self, matrix, hermitian=False): try: return jnp.linalg.matrix_rank(matrix) except TypeError as err: if err.args[0] == "array should have 2 or fewer dimensions": # this is a Jax bug on some distributions/versions warnings.warn("You are using a broken version of JAX. matrix_rank for dense matrices will fall back to NumPy.") return self.as_tensor(NUMPY.matrix_rank_dense(self.numpy(matrix), hermitian=hermitian)) else: raise err def eigvals(self, matrix: TensorType) -> TensorType: return jnp.linalg.eigvals(matrix) def eig(self, matrix: TensorType) -> TensorType: return jnp.linalg.eig(matrix) def svd(self, matrix: TensorType, full_matrices=True) -> Tuple[TensorType, TensorType, TensorType]: result = jnp.linalg.svd(matrix, full_matrices=full_matrices) return result[0], result[1], result[2] def sparse_coo_tensor(self, indices: Union[tuple, list], values, shape: tuple): return BCOO((values, indices), shape=shape)
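A minimal usage sketch (not part of the generated reference): it constructs a JaxBackend directly and exercises a few of the methods documented below. This is an illustrative assumption of direct use; in normal phiml code, backend instances are selected and managed for you.
from phiml.backend.jax import JaxBackend

backend = JaxBackend()                           # enumerates available CPU/GPU/TPU devices
x = backend.as_tensor([1.0, 2.0, 3.0])           # convert a Python list to a jax.Array
print(backend.is_tensor(x, only_native=True))    # True
print(backend.sum(x))                            # 6.0
print(backend.numpy(x))                          # NumPy copy of the values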
Ancestors
- phiml.backend._backend.Backend
Class variables
var arccosh
var arcsinh
var arctan
var arctan2
var arctanh
var cosh
var sinh
var tanh
Static methods
def abs(x: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], /) ‑> jax.Array
def arccos(x, /)
def arcsin(x, /)
def argmax(a, axis=None, out=None, *, keepdims=<no value>)
-
Returns the indices of the maximum values along an axis.
Parameters
a
:array_like
- Input array.
axis
:int
, optional- By default, the index is into the flattened array, otherwise along the specified axis.
out
:array
, optional- If provided, the result will be inserted into this array. It should be of the appropriate shape and dtype.
keepdims
:bool
, optional-
If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the array.
Added in version: 1.22.0
Returns
index_array
:ndarray
ofints
- Array of indices into the array. It has the same shape as a.shape with the dimension along axis removed. If keepdims is set to True, then the size of axis will be 1 with the resulting array having same shape as a.shape.
See Also
ndarray.argmax, argmin
amax : The maximum value along a given axis.
unravel_index : Convert a flat index into an index tuple.
take_along_axis : Apply ``np.expand_dims(index_array, axis)`` from argmax to an array as if by calling max.
Notes
In case of multiple occurrences of the maximum values, the indices corresponding to the first occurrence are returned.
Examples
>>> a = np.arange(6).reshape(2,3) + 10
>>> a
array([[10, 11, 12],
       [13, 14, 15]])
>>> np.argmax(a)
5
>>> np.argmax(a, axis=0)
array([1, 1, 1])
>>> np.argmax(a, axis=1)
array([2, 2])
Indexes of the maximal elements of a N-dimensional array:
>>> ind = np.unravel_index(np.argmax(a, axis=None), a.shape)
>>> ind
(1, 2)
>>> a[ind]
15
>>> b = np.arange(6)
>>> b[1] = 5
>>> b
array([0, 5, 2, 3, 4, 5])
>>> np.argmax(b)  # Only the first occurrence is returned.
1
>>> x = np.array([[4,2,3], [1,0,3]])
>>> index_array = np.argmax(x, axis=-1)
>>> # Same as np.amax(x, axis=-1, keepdims=True)
>>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1)
array([[4],
       [3]])
>>> # Same as np.amax(x, axis=-1)
>>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1).squeeze(axis=-1)
array([4, 3])
Setting keepdims to True,
>>> x = np.arange(24).reshape((2, 3, 4))
>>> res = np.argmax(x, axis=1, keepdims=True)
>>> res.shape
(2, 1, 4)
def argmin(a, axis=None, out=None, *, keepdims=<no value>)
-
Returns the indices of the minimum values along an axis.
Parameters
a
:array_like
- Input array.
axis
:int
, optional- By default, the index is into the flattened array, otherwise along the specified axis.
out
:array
, optional- If provided, the result will be inserted into this array. It should be of the appropriate shape and dtype.
keepdims
:bool
, optional-
If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the array.
Added in version: 1.22.0
Returns
index_array
:ndarray
ofints
- Array of indices into the array. It has the same shape as a.shape with the dimension along axis removed. If keepdims is set to True, then the size of axis will be 1 with the resulting array having same shape as a.shape.
See Also
ndarray.argmin, argmax
amin : The minimum value along a given axis.
unravel_index : Convert a flat index into an index tuple.
take_along_axis : Apply ``np.expand_dims(index_array, axis)`` from argmin to an array as if by calling min.
Notes
In case of multiple occurrences of the minimum values, the indices corresponding to the first occurrence are returned.
Examples
>>> a = np.arange(6).reshape(2,3) + 10
>>> a
array([[10, 11, 12],
       [13, 14, 15]])
>>> np.argmin(a)
0
>>> np.argmin(a, axis=0)
array([0, 0, 0])
>>> np.argmin(a, axis=1)
array([0, 0])
Indices of the minimum elements of a N-dimensional array:
>>> ind = np.unravel_index(np.argmin(a, axis=None), a.shape)
>>> ind
(0, 0)
>>> a[ind]
10
>>> b = np.arange(6) + 10
>>> b[4] = 10
>>> b
array([10, 11, 12, 13, 10, 15])
>>> np.argmin(b)  # Only the first occurrence is returned.
0
>>> x = np.array([[4,2,3], [1,0,3]])
>>> index_array = np.argmin(x, axis=-1)
>>> # Same as np.amin(x, axis=-1, keepdims=True)
>>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1)
array([[2],
       [0]])
>>> # Same as np.amax(x, axis=-1)
>>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1).squeeze(axis=-1)
array([2, 0])
Setting keepdims to True,
>>> x = np.arange(24).reshape((2, 3, 4))
>>> res = np.argmin(x, axis=1, keepdims=True)
>>> res.shape
(2, 1, 4)
def ceil(x, /)
def clip(a: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], a_min: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex, None] = None, a_max: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex, None] = None, out: None = None) ‑> jax.Array
def concat(arrays: Union[numpy.ndarray, jax.Array, Sequence[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex]]], axis: Optional[int] = 0, dtype: Union[Any, str, numpy.dtype, jax._src.SupportsDType, None] = None) ‑> jax.Array
-
Join a sequence of arrays along an existing axis.
LAX-backend implementation of :func:`numpy.concatenate`.
Original docstring below.
Parameters
axis
:int
, optional- The axis along which the arrays will be joined. If axis is None, arrays are flattened before use. Default is 0.
dtype
:str
ordtype
- If provided, the destination array will have this dtype. Cannot be
provided together with
out
.
Returns
res
:ndarray
- The concatenated array.
def conj(x: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], /) ‑> jax.Array
def cos(x, /)
def cumsum(a: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], axis: Union[None, int, Sequence[int]] = None, dtype: Union[Any, str, numpy.dtype, jax._src.SupportsDType] = None, out: None = None) ‑> jax.Array
-
Return the cumulative sum of the elements along a given axis.
LAX-backend implementation of :func:`numpy.cumsum`.
Original docstring below.
Parameters
a
:array_like
- Input array.
axis
:int
, optional- Axis along which the cumulative sum is computed. The default (None) is to compute the cumsum over the flattened array.
dtype
:dtype
, optional- Type of the returned array and of the accumulator in which the
elements are summed.
If dtype is not specified, it defaults to the dtype of a, unless a has an integer dtype with a precision less than that of the default platform integer. In that case, the default platform integer is used.
Returns
cumsum_along_axis : ndarray. A new array holding the result is returned unless out is specified, in which case a reference to out is returned. The result has the same size as a, and the same shape as a if axis is not None or a is a 1-d array.
def einsum(subscripts, /, *operands, out: None = None, optimize: str = 'optimal', precision: Union[None, str, jax._src.lax.lax.Precision, Tuple[str, str], Tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]] = None, preferred_element_type: Union[Any, str, numpy.dtype, jax._src.SupportsDType, None] = None) ‑> jax.Array
-
Evaluates the Einstein summation convention on the operands.
LAX-backend implementation of :func:`numpy.einsum`.
In addition to the original NumPy arguments listed below, also supports precision for extra control over matrix-multiplication precision on supported devices. precision may be set to None, which means default precision for the backend, a :class:`~jax.lax.Precision` enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST) or a tuple of two :class:`~jax.lax.Precision` enums indicating separate precision for each argument. A tuple precision does not necessarily map to multiple arguments of einsum(); rather, the specified precision is forwarded to each dot_general call used in the implementation.
Original docstring below.
Using the Einstein summation convention, many common multi-dimensional, linear algebraic array operations can be represented in a simple fashion. In implicit mode einsum computes these values.
In explicit mode, einsum provides further flexibility to compute other array operations that might not be considered classical Einstein summation operations, by disabling, or forcing summation over specified subscript labels.
See the notes and examples for clarification.
Parameters
subscripts
:str
- Specifies the subscripts for summation as comma separated list of subscript labels. An implicit (classical Einstein summation) calculation is performed unless the explicit indicator '->' is included as well as subscript labels of the precise output form.
operands
:list
ofarray_like
- These are the arrays for the operation.
optimize
:{False, True, 'greedy', 'optimal'}
, optional- Controls if intermediate optimization should occur. No optimization will occur if False and True will default to the 'greedy' algorithm. Also accepts an explicit contraction list from the np.einsum_path function. See np.einsum_path for more details. Defaults to False.
Returns
output
:ndarray
- The calculation based on the Einstein summation convention.
def equal(x1, x2, /)
def erf(x: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex]) ‑> jax.Array
-
Returns the error function of complex argument.
LAX-backend implementation of :func:`scipy.special.erf`.
Note that the JAX version does not support complex inputs.
Original docstring below.
It is defined as 2/sqrt(pi)*integral(exp(-t**2), t=0..z).
Parameters
x
:ndarray
- Input array.
Returns
res
:scalar
orndarray
- The values of the error function at the given points
x
.
References
.. [1] https://en.wikipedia.org/wiki/Error_function .. [2] Milton Abramowitz and Irene A. Stegun, eds. Handbook of Mathematical Functions with Formulas, Graphs, and Mathematical Tables. New York: Dover, 1972. http://www.math.sfu.ca/~cbm/aands/page_297.htm .. [3] Steven G. Johnson, Faddeeva W function implementation. http://ab-initio.mit.edu/Faddeeva
def exp(x, /)
def flip(m: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], axis: Union[int, Tuple[int, ...], None] = None) ‑> jax.Array
-
Reverse the order of elements in an array along the given axis.
LAX-backend implementation of :func:`numpy.flip`.
The JAX version of this function may in some cases return a copy rather than a view of the input.
Original docstring below.
The shape of the array is preserved, but the elements are reordered.
Added in version: 1.12.0
Parameters
m
:array_like
- Input array.
axis
:None
orint
ortuple
ofints
, optional-
Axis or axes along which to flip over. The default, axis=None, will flip over all of the axes of the input array. If axis is negative it counts from the last to the first axis.
If axis is a tuple of ints, flipping is performed on all of the axes specified in the tuple.
Changed in version 1.15.0: None and tuples of axes are supported.
Returns
out
:array_like
- A view of
m
with the entries of axis reversed. Since a view is returned, this operation is done in constant time.
def floor(x, /)
def imag(val: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], /) ‑> jax.Array
def isfinite(x: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], /) ‑> jax.Array
def isinf(x: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], /) ‑> jax.Array
def isnan(x: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], /) ‑> jax.Array
def log(x, /)
def log10(x: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], /) ‑> jax.Array
def log2(x: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], /) ‑> jax.Array
def maximum(x1, x2, /)
def minimum(x1, x2, /)
def real(val: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], /) ‑> jax.Array
def round(a: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], decimals: int = 0, out: None = None) ‑> jax.Array
def shape(a)
-
Return the shape of an array.
Parameters
a
:array_like
- Input array.
Returns
shape
:tuple
ofints
- The elements of the shape tuple give the lengths of the corresponding array dimensions.
See Also
len
len(a) is equivalent to np.shape(a)[0] for N-D arrays with N>=1.
ndarray.shape
- Equivalent array method.
Examples
>>> np.shape(np.eye(3))
(3, 3)
>>> np.shape([[1, 3]])
(1, 2)
>>> np.shape([0])
(1,)
>>> np.shape(0)
()
>>> a = np.array([(1, 2), (3, 4), (5, 6)],
...              dtype=[('x', 'i4'), ('y', 'i4')])
>>> np.shape(a)
(3,)
>>> a.shape
(3,)
def sign(x: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], /) ‑> jax.Array
def sin(x, /)
def softplus(x: Any) ‑> Any
-
Softplus activation function.
Computes the element-wise function
\[ \mathrm{softplus}(x) = \log(1 + e^x) \]
Args
x : input array
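As a quick illustrative check of the formula above (not part of the original docstring), the static softplus can be compared against log(1 + exp(x)):
import jax
import jax.numpy as jnp

x = jnp.array([-2.0, 0.0, 2.0])
print(jax.nn.softplus(x))        # element-wise log(1 + exp(x))
print(jnp.log(1 + jnp.exp(x)))   # agrees up to floating-point rounding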
def sqrt(x, /)
def stack(arrays: Union[numpy.ndarray, jax.Array, Sequence[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex]]], axis: int = 0, out: None = None, dtype: Union[Any, str, numpy.dtype, jax._src.SupportsDType, None] = None) ‑> jax.Array
-
Join a sequence of arrays along a new axis.
LAX-backend implementation of :func:`numpy.stack`.
Original docstring below.
The axis parameter specifies the index of the new axis in the dimensions of the result. For example, if axis=0 it will be the first dimension and if axis=-1 it will be the last dimension.
Added in version: 1.10.0
Parameters
arrays
:sequence
ofarray_like
- Each array must have the same shape.
axis
:int
, optional- The axis in the result array along which the input arrays are stacked.
dtype
:str
ordtype
- If provided, the destination array will have this dtype. Cannot be
provided together with
out
.
Returns
stacked
:ndarray
- The stacked array has one more dimension than the input arrays.
def staticshape(a)
-
Return the shape of an array.
Parameters
a
:array_like
- Input array.
Returns
shape
:tuple
ofints
- The elements of the shape tuple give the lengths of the corresponding array dimensions.
See Also
len
len(a) is equivalent to np.shape(a)[0] for N-D arrays with N>=1.
ndarray.shape
- Equivalent array method.
Examples
>>> np.shape(np.eye(3))
(3, 3)
>>> np.shape([[1, 3]])
(1, 2)
>>> np.shape([0])
(1,)
>>> np.shape(0)
()
>>> a = np.array([(1, 2), (3, 4), (5, 6)],
...              dtype=[('x', 'i4'), ('y', 'i4')])
>>> np.shape(a)
(3,)
>>> a.shape
(3,)
def stop_gradient(x: ~T) ‑> ~T
-
Stops gradient computation.
Operationally stop_gradient is the identity function, that is, it returns argument x unchanged. However, stop_gradient prevents the flow of gradients during forward or reverse-mode automatic differentiation. If there are multiple nested gradient computations, stop_gradient stops gradients for all of them.
For example:
>>> jax.grad(lambda x: x**2)(3.)
Array(6., dtype=float32, weak_type=True)
>>> jax.grad(lambda x: jax.lax.stop_gradient(x)**2)(3.)
Array(0., dtype=float32, weak_type=True)
>>> jax.grad(jax.grad(lambda x: x**2))(3.)
Array(2., dtype=float32, weak_type=True)
>>> jax.grad(jax.grad(lambda x: jax.lax.stop_gradient(x)**2))(3.)
Array(0., dtype=float32, weak_type=True)
def tan(x, /)
def tile(A: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], reps: Union[int, Any, Sequence[Union[int, Any]]]) ‑> jax.Array
-
Construct an array by repeating A the number of times given by reps.
LAX-backend implementation of :func:`numpy.tile`.
Original docstring below.
If reps has length d, the result will have dimension of max(d, A.ndim).
If A.ndim < d, A is promoted to be d-dimensional by prepending new axes. So a shape (3,) array is promoted to (1, 3) for 2-D replication, or shape (1, 1, 3) for 3-D replication. If this is not the desired behavior, promote A to d-dimensions manually before calling this function.
If A.ndim > d, reps is promoted to A.ndim by pre-pending 1's to it. Thus for an A of shape (2, 3, 4, 5), a reps of (2, 2) is treated as (1, 1, 2, 2).
Note: Although tile may be used for broadcasting, it is strongly recommended to use numpy's broadcasting operations and functions.
Parameters
A
:array_like
- The input array.
reps
:array_like
- The number of repetitions of
A
along each axis.
Returns
c
:ndarray
- The tiled output array.
def transpose(a: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], axes: Optional[Sequence[int]] = None) ‑> jax.Array
-
Returns an array with axes transposed.
LAX-backend implementation of :func:`numpy.transpose`.
The JAX version of this function may in some cases return a copy rather than a view of the input.
Original docstring below.
For a 1-D array, this returns an unchanged view of the original array, as a transposed vector is simply the same vector. To convert a 1-D array into a 2-D column vector, an additional dimension must be added, e.g., np.atleast2d(a).T achieves this, as does a[:, np.newaxis]. For a 2-D array, this is the standard matrix transpose. For an n-D array, if axes are given, their order indicates how the axes are permuted (see Examples). If axes are not provided, then transpose(a).shape == a.shape[::-1].
Parameters
a
:array_like
- Input array.
axes
:tuple
orlist
ofints
, optional- If specified, it must be a tuple or list which contains a permutation of [0,1,…,N-1] where N is the number of axes of a. The i'th axis of the returned array will correspond to the axis numbered axes[i] of the input. If not specified, defaults to range(a.ndim)[::-1], which reverses the order of the axes.
Returns
p
:ndarray
a
with its axes permuted. A view is returned whenever possible.
Methods
def all(self, boolean_tensor, axis=None, keepdims=False)
def allocate_on_device(self, tensor: ~TensorType, device: phiml.backend._backend.ComputeDevice) ‑> ~TensorType
-
Moves tensor to device. May copy the tensor if it is already on the device.
Args
tensor
- Existing tensor native to this backend.
device
- Target device, associated with this backend.
def any(self, boolean_tensor, axis=None, keepdims=False)
def argsort(self, x, axis=-1)
def as_tensor(self, x, convert_external=True)
-
Converts a tensor-like object to the native tensor representation of this backend. If x is a native tensor of this backend, it is returned without modification. If x is a Python number (numbers.Number instance), convert_external decides whether to convert it unless the backend cannot handle Python numbers.
Note: There may be objects that are considered tensors by this backend but are not native and thus will be converted by this method.
Args
x
- tensor-like, e.g. list, tuple, Python number, tensor
convert_external
- if False and
x
is a Python number that is understood by this backend, this method returns the number as-is. This can help prevent type clashes like int32 vs int64. (Default value = True)
Returns
tensor representation of
x
def batched_gather_1d(self, values, indices)
-
Args
values
- (batch, spatial)
indices
- (batch, indices)
Returns
(batch, indices)
def batched_gather_nd(self, values, indices)
-
Gathers values from the tensor values at locations indices. The first dimension of values and indices is the batch dimension which must be either equal for both or one for either.
Args
values
- tensor of shape (batch, spatial…, channel)
indices
- int tensor of shape (batch, any…, multi_index) where the size of multi_index is values.rank - 2.
Returns
Gathered values as tensor of shape (batch, any…, channel)
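A small sketch of the shape convention described above, assuming a JaxBackend instance constructed directly as in the earlier example; the concrete values are illustrative only.
import jax.numpy as jnp
from phiml.backend.jax import JaxBackend

backend = JaxBackend()
values = jnp.arange(2 * 3 * 3 * 1, dtype=jnp.float32).reshape(2, 3, 3, 1)  # (batch, y, x, channel)
indices = jnp.array([[[0, 0], [2, 2]],
                     [[1, 0], [0, 1]]])                                    # (batch, points, multi_index)
gathered = backend.batched_gather_nd(values, indices)                      # (batch, points, channel)
print(gathered.shape)  # (2, 2, 1)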
def bincount(self, x, weights: Optional[~TensorType], bins: int, x_sorted=False)
-
Args
x
- Bin indices, 1D int tensor.
weights
- Weights corresponding to x, 1D tensor. All weights are 1 if weights=None.
bins
- Number of bins.
x_sorted
- Whether
x
is sorted from lowest to highest bin.
Returns
bin_counts
def block_until_ready(self, values)
def boolean_mask(self, x, mask, axis=0, new_length=None, fill_value=0)
-
Args
x
- tensor with any number of dimensions
mask
- 1D mask tensor
axis
- Axis index >= 0
new_length
- Maximum size of the output along axis. This must be set when jit-compiling with Jax.
fill_value
- If new_length is larger than the filtered result, the remaining values will be set to fill_value.
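Hedged sketch of the new_length behaviour described above (required when jit-compiling with Jax); the backend construction and values are illustrative assumptions.
import jax.numpy as jnp
from phiml.backend.jax import JaxBackend

backend = JaxBackend()
x = jnp.array([10., 20., 30., 40.])
mask = jnp.array([True, False, True, False])
print(backend.boolean_mask(x, mask, axis=0, new_length=3, fill_value=0))
# [10. 30.  0.]  -- two kept values, padded with fill_value up to new_length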
def cast(self, x, dtype: phiml.backend._dtype.DType)
def conv(self, value, kernel, zero_padding=True)
-
Convolve value with kernel. Depending on the tensor rank, the convolution is either 1D (rank=3), 2D (rank=4) or 3D (rank=5). Higher dimensions may not be supported.
Args
value
- tensor of shape (batch_size, in_channel, spatial…)
kernel
- tensor of shape (batch_size or 1, out_channel, in_channel, spatial…)
zero_padding
- If True, pads the edges of value with zeros so that the result has the same shape as value.
Returns
Convolution result as tensor of shape (batch_size, out_channel, spatial…)
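A minimal shape sketch for the 1D case (rank-3 value, rank-4 kernel), assuming a directly constructed JaxBackend as above; shapes are illustrative.
import jax.numpy as jnp
from phiml.backend.jax import JaxBackend

backend = JaxBackend()
value = jnp.ones((1, 2, 8))        # (batch_size, in_channel, x)
kernel = jnp.ones((1, 4, 2, 3))    # (batch_size or 1, out_channel, in_channel, x)
out = backend.conv(value, kernel, zero_padding=True)
print(out.shape)  # (1, 4, 8) -- same spatial size as value because of zero padding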
def copy(self, tensor, only_mutable=False)
def custom_gradient(self, f: Callable, gradient: Callable, get_external_cache: Callable = None, on_call_skipped: Callable = None) ‑> Callable
-
Creates a function based on f that uses a custom gradient for backprop.
Args
f
- Forward function.
gradient
- Function for backprop. Will be called as gradient(*d_out) to compute the gradient of f.
Returns
Function with similar signature and return values as f. However, the returned function does not support keyword arguments.
def disassemble(self, x) ‑> Tuple[Callable, Sequence[~TensorType]]
-
Disassemble a (sparse) tensor into its individual constituents, such as values and indices.
Args
x
- Tensor
Returns
assemble
- Function assemble(backend, *constituents) that reassembles x from the constituents.
constituents
- Tensors contained in x.
def divide_no_nan(self, x, y)
-
Computes x/y but returns 0 if y=0.
def dtype(self, array) ‑> phiml.backend._dtype.DType
def eig(self, matrix: ~TensorType) ‑> ~TensorType
-
Args
matrix
- (batch…, n, n)
Returns
eigenvalues
- (batch…, n,)
eigenvectors
- (batch…, n, n)
def eigvals(self, matrix: ~TensorType) ‑> ~TensorType
-
Args
matrix
- (batch…, n, n)
Returns
eigenvalues as (batch…, n,)
def expand_dims(self, a, axis=0, number=1)
def fft(self, x, axes: Union[tuple, list])
-
Computes the n-dimensional FFT along all but the first and last dimensions.
Args
x
- tensor of dimension 3 or higher
axes
- Along which axes to perform the FFT
Returns
Complex tensor
k
def from_dlpack(self, capsule)
def gamma_inc_l(self, a, x)
-
Regularized lower incomplete gamma function.
def gamma_inc_u(self, a, x)
-
Regularized upper incomplete gamma function.
def gather(self, values, indices, axis: int)
-
Gathers values from the tensor values at locations indices.
Args
values
- tensor
indices
- 1D tensor
axis
- Axis along which to gather slices
Returns
tensor, with size along axis being the length of indices
def get_device(self, tensor: ~TensorType) ‑> phiml.backend._backend.ComputeDevice
-
Returns the device tensor is located on.
def get_diagonal(self, matrices, offset=0)
-
Args
matrices
- (batch, rows, cols, channels)
offset
- 0=diagonal, positive=above diagonal, negative=below diagonal
Returns
diagonal
- (batch, max(rows,cols), channels)
def get_sparse_format(self, x) ‑> str
-
Returns lower-case format string, such as 'coo', 'csr', 'csc'
def histogram1d(self, values, weights, bin_edges)
-
Args
values
- (batch, values)
bin_edges
- (batch, edges)
weights
- (batch, values)
Returns
(batch, edges) with dtype matching weights
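Hedged sketch of the batched histogram semantics described above, assuming a directly constructed JaxBackend; values are illustrative.
import jax.numpy as jnp
from phiml.backend.jax import JaxBackend

backend = JaxBackend()
values = jnp.array([[0.1, 0.4, 0.6, 0.8]])     # (batch, values)
weights = jnp.ones((1, 4))                     # (batch, values)
bin_edges = jnp.array([[0.0, 0.5, 1.0]])       # (batch, edges)
print(backend.histogram1d(values, weights, bin_edges))  # [[2. 2.]]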
def ifft(self, k, axes: Union[tuple, list])
-
Computes the n-dimensional inverse FFT along all but the first and last dimensions.
Args
k
- tensor of dimension 3 or higher
axes
- Along which axes to perform the inverse FFT
Returns
Complex tensor
x
def is_available(self, tensor)
-
Tests if the value of the tensor is known and can be read at this point. If true, numpy(tensor) must return a valid NumPy representation of the value.
Tensors are typically available when the backend operates in eager mode.
Args
tensor
- backend-compatible tensor
Returns
bool
def is_module(self, obj)
-
Tests if obj is of a type that is specific to this backend, e.g. a neural network. If True, this backend will be chosen for operations involving obj.
See Also: Backend.is_tensor().
Args
obj
- Object to test.
def is_sparse(self, x) ‑> bool
-
Args
x
- Tensor native to this
Backend
.
def is_tensor(self, x, only_native=False)
-
An object is considered a native tensor by a backend if no internal conversion is required by backend methods. An object is considered a tensor (native or otherwise) by a backend if it is not a struct (e.g. tuple, list) and all methods of the backend accept it as a tensor argument.
If True, this backend will be chosen for operations involving x.
See Also: Backend.is_module().
Args
x
- object to check
only_native
- If True, only accepts true native tensor representations, not Python numbers or others that are also supported as tensors (Default value = False)
Returns
bool
- whether
x
is considered a tensor by this backend
def jacobian(self, f, wrt: Union[tuple, list], get_output: bool, is_f_scalar: bool)
-
Args
f
- Function to differentiate. Returns a tuple containing
(reduced_loss, output)
wrt
- Argument indices for which to compute the gradient.
get_output
- Whether the derivative function should return the output of f in addition to the gradient.
is_f_scalar
- Whether
f
is guaranteed to return a scalar output.
Returns
A function g with the same arguments as f. If get_output=True, g returns a tuple containing the outputs of f followed by the gradients. The gradients retain the dimensions of reduced_loss in order as outer (first) dimensions.
def jit_compile(self, f: Callable) ‑> Callable
def linspace(self, start, stop, number)
def linspace_without_last(self, start, stop, number)
def log_gamma(self, x)
def matrix_rank_dense(self, matrix, hermitian=False)
-
Args
matrix
- Dense matrix of shape (batch, rows, cols)
hermitian
- Whether all matrices are guaranteed to be hermitian.
def matrix_solve_least_squares(self, matrix: ~TensorType, rhs: ~TensorType) ‑> Tuple[~TensorType, ~TensorType, ~TensorType, ~TensorType]
-
Args
matrix
- Shape (batch, vec, constraints)
rhs
- Shape (batch, vec, batch_per_matrix)
Returns
solution
- Solution vector of Shape (batch, constraints, batch_per_matrix)
residuals
- Optional, can be
None
rank
- Optional, can be
None
singular_values
- Optional, can be
None
def max(self, x, axis=None, keepdims=False)
def mean(self, value, axis=None, keepdims=False)
def meshgrid(self, *coordinates)
def min(self, x, axis=None, keepdims=False)
def mul(self, a, b)
def mul_matrix_batched_vector(self, A, b)
def nn_library(self)
def nonzero(self, values, length=None, fill_value=-1)
-
Args
values
- Tensor with only spatial dimensions
length
- (Optional) Length of the resulting array. If specified, the result array will be padded with
fill_value
or trimmed.
Returns
non-zero multi-indices as tensor of shape (nnz/length, vector)
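Sketch of the length/fill_value behaviour described above, assuming a directly constructed JaxBackend; values are illustrative.
import jax.numpy as jnp
from phiml.backend.jax import JaxBackend

backend = JaxBackend()
values = jnp.array([[0, 1, 0],
                    [2, 0, 0]])
print(backend.nonzero(values, length=4, fill_value=-1))
# [[ 0  1]
#  [ 1  0]
#  [-1 -1]
#  [-1 -1]]  -- non-zero multi-indices, padded to the requested length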
def numpy(self, tensor)
-
Returns a NumPy representation of the given tensor. If tensor is already a NumPy array, it is returned without modification.
This method raises an error if the value of the tensor is not known at this point, e.g. because it represents a node in a graph. Use is_available(tensor) to check if the value can be represented as a NumPy array.
Args
tensor
- backend-compatible tensor or sparse tensor
Returns
NumPy representation of the values stored in the tensor
def numpy_call(self, f, output_shapes, output_dtypes, *args, **aux_args)
-
This call can be used in jit-compiled code but is not differentiable.
Args
f
- Function operating on numpy arrays.
output_shapes
- Single shape tuple or tuple of shapes declaring the shapes of the tensors returned by f.
output_dtypes
- Single DType or tuple of DTypes declaring the dtypes of the tensors returned by f.
*args
- Tensor arguments to be converted to NumPy arrays and then passed to f.
**aux_args
- Keyword arguments to be passed to
f
without conversion.
Returns
Returned arrays of f converted to tensors.
def ones(self, shape, dtype: phiml.backend._dtype.DType = None)
def ones_like(self, tensor)
def pad(self, value, pad_width, mode='constant', constant_values=0)
-
Pad a tensor with values as specified by mode and constant_values.
If the mode is not supported, returns NotImplemented.
Args
value
- tensor
pad_width
- 2D tensor specifying the number of values padded to the edges of each axis in the form [[axis 0 lower, axis 0 upper], …] including batch and component axes.
mode
- 'constant', 'boundary', 'periodic', 'symmetric', 'reflect'
constant_values
- Scalar value used for out-of-bounds points if mode='constant'. Must be a Python primitive type or scalar tensor.
mode
- str: (Default value = 'constant')
Returns
padded tensor or
NotImplemented
def prefers_channels_last(self) ‑> bool
def prod(self, value, axis=None)
def quantile(self, x, quantiles)
-
Reduces the last / inner axis of x.
Args
x
- Tensor
quantiles
- List or 1D tensor of quantiles to compute.
Returns
Tensor with shape (quantiles, *x.shape[:-1])
def random_normal(self, shape, dtype: phiml.backend._dtype.DType)
-
Float tensor of selected precision containing random values sampled from a normal distribution with mean 0 and std 1.
def random_uniform(self, shape, low, high, dtype: Optional[phiml.backend._dtype.DType])
-
Float tensor of selected precision containing random values in the range [0, 1)
def range(self, start, limit=None, delta=1, dtype: phiml.backend._dtype.DType = int32)
def ravel_multi_index(self, multi_index, shape, mode: Union[str, int] = 'undefined')
-
Args
multi_index
- (batch…, index_dim)
shape
- 1D tensor or tuple/list
mode
- 'undefined', 'periodic', 'clamp' or an int to use for all invalid indices.
Returns
Integer tensor of shape (batch…) of same dtype as multi_index.
def repeat(self, x, repeats, axis: int, new_length=None)
-
Repeats the elements along axis, repeats times.
Args
x
- Tensor
repeats
- How often to repeat each element. 1D tensor of length x.shape[axis]
axis
- Which axis to repeat elements along
new_length
- Set the length of
axis
after repeating. This is required for jit compilation with Jax.
Returns
repeated Tensor
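Hedged sketch of repeat with new_length, which the docstring notes is required for jit compilation with Jax; the backend construction and values are illustrative.
import jax.numpy as jnp
from phiml.backend.jax import JaxBackend

backend = JaxBackend()
x = jnp.array([1., 2., 3.])
repeats = jnp.array([2, 0, 1])                           # how often to repeat each element
print(backend.repeat(x, repeats, axis=0, new_length=3))  # [1. 1. 3.]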
def requires_fixed_shapes_when_tracing(self) ‑> bool
def reshape(self, value, shape)
def scatter(self, base_grid, indices, values, mode: str)
-
Batched n-dimensional scatter.
Args
base_grid
- Tensor into which scatter values are inserted at indices. Tensor of shape (batch_size, spatial…, channels)
indices
- Tensor of shape (batch_size or 1, update_count, index_vector)
values
- Values to scatter at indices. Tensor of shape (batch_size or 1, update_count or 1, channels or 1)
mode
- One of ('update', 'add', 'max', 'min')
Returns
Copy of base_grid with values at indices updated by values.
def searchsorted(self, sorted_sequence, search_values, side: str, dtype=int32)
def seed(self, seed: int)
def solve_triangular_dense(self, matrix, rhs, lower: bool, unit_diagonal: bool)
-
Args
matrix
- (batch_size, rows, cols)
rhs
- (batch_size, cols)
lower
unit_diagonal
Returns
(batch_size, cols)
def sort(self, x, axis=-1)
def sparse_coo_tensor(self, indices: Union[tuple, list], values, shape: tuple)
-
Create a sparse matrix in coordinate list (COO) format.
Optional feature.
See Also: Backend.csr_matrix(), Backend.csc_matrix().
Args
indices
- 2D tensor of shape (nnz, dims).
values
- 1D values tensor matching indices
shape
- Shape of the sparse matrix
Returns
Native representation of the sparse matrix
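A small sketch of building a COO matrix through this backend (the Jax implementation returns a jax.experimental.sparse.BCOO matrix, as shown in the source above); the backend construction and values are illustrative.
import jax.numpy as jnp
from phiml.backend.jax import JaxBackend

backend = JaxBackend()
indices = jnp.array([[0, 0], [1, 2]])     # (nnz, dims)
values = jnp.array([1.0, 5.0])            # (nnz,)
matrix = backend.sparse_coo_tensor(indices, values, shape=(2, 3))
print(backend.is_sparse(matrix))          # True
print(backend.get_sparse_format(matrix))  # 'coo'
print(matrix.todense())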
def std(self, x, axis=None, keepdims=False)
def sum(self, value, axis=None, keepdims=False)
def svd(self, matrix: ~TensorType, full_matrices=True) ‑> Tuple[~TensorType, ~TensorType, ~TensorType]
-
Args
matrix
- (batch…, m, n)
Returns
u
- (batch…, m, m)
singular_values
- (batch…, min(m, n))
vh
- (batch…, n, n)
def tensordot(self, a, a_axes: Union[tuple, list], b, b_axes: Union[tuple, list])
-
Multiply-sum-reduce a_axes of a with b_axes of b.
def to_dlpack(self, tensor)
def unique(self, x: ~TensorType, return_inverse: bool, return_counts: bool, axis: int) ‑> Tuple[~TensorType, ...]
-
Args
x
- n-dimensional int array. Will compare axis-slices of x for multidimensional x.
return_inverse
- Whether to return the inverse
return_counts
- Whether to return the counts.
axis
- Axis along which slices of
x
should be compared.
Returns
unique_slices
- Sorted unique slices of
x
unique_inverse
- (optional) index of the unique slice for each slice of
x
unique_counts
- Number of occurrences of each unique slice
def unravel_index(self, flat_index, shape)
def vectorized_call(self, f, *args, output_dtypes=None, **aux_args)
-
Args
f
- Function with only positional tensor argument, returning one or multiple tensors.
*args
- Batched inputs for f. The first dimension of all args is vectorized. All tensors in args must have the same size or 1 in their first dimension.
output_dtypes
- Single DType or tuple of DTypes declaring the dtypes of the tensors returned by f.
**aux_args
- Non-vectorized keyword arguments to be passed to
f
.
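Hedged sketch of vectorized_call: the function below is written for a single sample and mapped over the leading batch dimension; arguments with batch size 1 are tiled to the common batch size, as described above. The backend construction and values are illustrative.
import jax.numpy as jnp
from phiml.backend.jax import JaxBackend

backend = JaxBackend()

def weighted_sum(x, w):
    return jnp.sum(x * w)   # operates on a single sample (no batch dimension)

x = jnp.ones((4, 3))                 # batch of 4
w = jnp.array([[1., 2., 3.]])        # batch of 1, tiled to 4 internally
print(backend.vectorized_call(weighted_sum, x, w))  # [6. 6. 6. 6.]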
def where(self, condition, x=None, y=None)
def while_loop(self, loop: Callable, values: tuple, max_iter: Union[int, Tuple[int, ...], List[int]])
-
If max_iter is None, runs
while any(values[0]):
    values = loop(*values)
return values
This operation does not support backpropagation.
Args
loop
- Loop function, must return a tuple with entries equal to values in shape and data type.
values
- Initial values of loop variables.
max_iter
- Maximum number of iterations to run, single
int
or sequence of integers.
Returns
Loop variables upon loop completion if max_iter is a single integer. If max_iter is a sequence, stacks the variables after each entry in max_iter, adding an outer dimension of size <= len(max_iter). If the condition is fulfilled before the maximum max_iter is reached, the loop may be broken or not, depending on the implementation. If the loop is broken, the values returned by the last loop are expected to be constant and filled.
def zeros(self, shape, dtype: phiml.backend._dtype.DType = None)
def zeros_like(self, tensor)