%%capture
!pip install phiml
ΦML can execute your instructions using Jax, PyTorch, TensorFlow or NumPy. Which library is used generally depends on what tensors you pass to it. Calling a function with PyTorch tensors will always invoke the corresponding PyTorch routine.
Let's first look at the function math.use(), which lets you set a global default backend.
from phiml import math
math.use('jax')
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
jax
From now on, new tensors created by ΦML will be backed by Jax arrays.
math.random_normal().default_backend
jax
jax_tensor = math.tensor([0, 1])
jax_tensor.default_backend
jax
However, passing a tensor backed by a different backend will not automatically convert it to a Jax tensor.
numpy_tensor = math.wrap([0, 1])
math.sin(numpy_tensor).default_backend
numpy
When calling a function with both NumPy and non-NumPy tensors, the NumPy tensors will be automatically converted to the ML backend.
(numpy_tensor + math.tensor([2, 3])).default_backend
jax
However, conversion between ML backends needs to be performed by the user.
import torch
torch_tensor = math.wrap(torch.tensor([2, 3]))
torch_tensor.default_backend
torch
try:
    (jax_tensor + torch_tensor).default_backend
except Exception as exc:
    print(exc)
Could not resolve backend for native types ['ArrayImpl', 'Tensor']
We can move tensors between different backends using math.convert().
This uses DLPack under the hood when converting between ML backends.
For the target backend, you can pass the module, module name or Backend object.
import torch
math.convert(jax_tensor, torch).default_backend
torch
math.convert(jax_tensor, 'torch').default_backend
torch
from phiml.backend.torch import TORCH
from phiml.backend.jax import JAX
from phiml.backend.tensorflow import TENSORFLOW
math.convert(jax_tensor, TORCH).default_backend
torch
If you change the global default backend or use a with backend:
block, convert()
will convert to that backend by default.
with TORCH:
jax_to_torch = math.convert(jax_tensor)
jax_to_torch.default_backend
torch
NumPy functions are not differentiable, but NumPy nevertheless plays an important role in representing constants in your code.
🌐 ΦML • 📖 Documentation • 🔗 API • ▶ Videos • Examples