Module phi.torch.flow
Standard import for PyTorch mode.
Extends the import `from phi.flow import *` by PyTorch-related functions and modules.
The following PyTorch modules are included: `torch`, `torch.nn.functional` as `torchf`, `optim`.
Importing this module registers the PyTorch backend as the default backend unless called within a backend context.
New tensors created via `phiml.math` functions will be backed by PyTorch tensors.
See `phi.flow`, `phi.tf.flow`, `phi.jax.flow`.
# pylint: disable-msg = wildcard-import, unused-wildcard-import, unused-import
"""
Standard import for PyTorch mode.
Extends the import `from phi.flow import *` by PyTorch-related functions and modules.
The following PyTorch modules are included: `torch`, *torch.nn.functional* as `torchf`, `optim`.
Importing this module registers the PyTorch backend as the default backend unless called within a backend context.
New tensors created via `phiml.math` functions will be backed by PyTorch tensors.
See `phi.flow`, `phi.tf.flow`, `phi.jax.flow`.
"""
from phi.flow import *
from . import TORCH
from . import nets
from .nets import parameter_count, get_parameters, save_state, load_state, dense_net, u_net, update_weights, adam, conv_net, res_net, sgd, sgd as SGD, rmsprop, adagrad, conv_classifier, invertible_net
import torch
import torch.nn.functional as torchf
import torch.optim as optim
if not backend.context_backend():
    backend.set_global_default_backend(TORCH)
else:
    from phiml.backend import ML_LOGGER as _LOGGER
    _LOGGER.warning(f"Importing '{__name__}' within a backend context will not set the default backend.")
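
The conditional at the end of the source above sets the global default only when no backend context is active. The following is a minimal, pure-Python model of that rule, not phiml's implementation: `backend_context`, `register_on_import`, `default_backend`, and the initial `"numpy"` default are hypothetical stand-ins chosen for illustration, and backends are represented by plain strings.

```python
from contextlib import contextmanager

# Hypothetical stand-ins (not phiml's API): a global default and a stack of
# backends activated through `with` blocks.
_global_default = "numpy"
_context_stack = []


@contextmanager
def backend_context(name):
    """Temporarily make `name` the active backend."""
    _context_stack.append(name)
    try:
        yield
    finally:
        _context_stack.pop()


def context_backend():
    """Return the innermost context backend, or None outside any context."""
    return _context_stack[-1] if _context_stack else None


def register_on_import(name):
    """Model of the import-time check: change the global default only when
    no backend context is active. Returns True if the default was changed."""
    global _global_default
    if context_backend() is None:
        _global_default = name
        return True
    # The real module logs a warning in this branch instead of returning a flag.
    return False


def default_backend():
    """A context backend, if present, takes precedence over the global default."""
    return context_backend() or _global_default
```

In this model, calling `register_on_import("torch")` inside `with backend_context("jax"):` leaves the global default untouched, while the same call outside any context switches the default to `"torch"`.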