Module phi.jax.stax.flow
Standard import for Jax + Stax mode.
Extends the import from phi.flow import * with Jax-related functions and modules.
The following Jax modules are included: jax, jax.numpy as jnp, jax.scipy as jsp.
Importing this module registers the Jax backend as the default backend unless the import happens within a backend context.
New tensors created via phiml.math functions will be backed by Jax tensors.
See phi.flow, phi.torch.flow, phi.tf.flow.