Using Multiple Backends via ΦML


In [1]:
%%capture
!pip install phiml

How Backends are Chosen

ΦML can execute your instructions using Jax, PyTorch, TensorFlow, or NumPy. Which library is used generally depends on the tensors you pass in: calling a function with PyTorch tensors will always invoke the corresponding PyTorch routine.
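
For example, wrapping a native PyTorch tensor and applying a ΦML function dispatches to PyTorch. A minimal sketch, assuming PyTorch is installed (it is not part of the install cell above):

In [ ]:
import torch
from phiml import math

# Wrapping keeps the tensor backed by its native library, here PyTorch.
torch_t = math.wrap(torch.tensor([0., 1.]))
math.sin(torch_t).default_backend  # dispatches to the PyTorch routine; expected: torch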

Let's first look at the function math.use(), which lets you set the global default backend.

In [2]:
from phiml import math

math.use('jax')
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Out[2]:
jax

From now on, new tensors created by ΦML will be backed by Jax arrays.

In [3]:
math.random_normal().default_backend
Out[3]:
jax
In [4]:
jax_tensor = math.tensor([0, 1])
jax_tensor.default_backend
Out[4]:
jax

However, passing a tensor backed by a different backend will not automatically convert it to a Jax tensor.

In [5]:
numpy_tensor = math.wrap([0, 1])
math.sin(numpy_tensor).default_backend
Out[5]:
numpy

When calling a function with both NumPy and non-NumPy tensors, the NumPy tensors are automatically converted to the ML backend of the other arguments.

In [6]:
(numpy_tensor + math.tensor([2, 3])).default_backend
Out[6]:
jax

Conversion between different ML backends, on the other hand, must be performed explicitly by the user.

In [7]:
import torch

torch_tensor = math.wrap(torch.tensor([2, 3]))
torch_tensor.default_backend
Out[7]:
torch
In [8]:
try:
    (jax_tensor + torch_tensor).default_backend
except Exception as exc:
    print(exc)
Could not resolve backend for native types ['ArrayImpl', 'Tensor']

Converting Tensors

We can move tensors between different backends using math.convert(). Under the hood, this uses DLPack when converting between ML backends.

For the target backend, you can pass the module, the module name, or a Backend object.

In [9]:
import torch

math.convert(jax_tensor, torch).default_backend
Out[9]:
torch
In [10]:
math.convert(jax_tensor, 'torch').default_backend
Out[10]:
torch
In [11]:
from phiml.backend.torch import TORCH
from phiml.backend.jax import JAX
from phiml.backend.tensorflow import TENSORFLOW

math.convert(jax_tensor, TORCH).default_backend
2025-05-06 16:14:06.300205: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Out[11]:
torch

If you change the global default backend or use a with backend: block, convert() will target that backend by default.

In [12]:
with TORCH:
    jax_to_torch = math.convert(jax_tensor)
jax_to_torch.default_backend
Out[12]:
torch
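
The same applies after changing the global default with math.use(). A sketch, not executed in this notebook:

In [ ]:
math.use('torch')
math.convert(jax_tensor).default_backend  # expected: torch
math.use('jax')  # restore the default used throughout this notebook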

Further Reading

NumPy functions are not differentiable, but NumPy nevertheless plays an important role in representing constants in your code.
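
For instance, constants created with math.wrap() stay NumPy-backed until they are combined with an ML tensor, at which point they are converted automatically, as shown in cell 6 above. A minimal sketch:

In [ ]:
# A NumPy-backed constant: cheap to create and backend-agnostic.
TWO_PI = math.wrap(6.28318530718)
# Combined with a Jax-backed tensor, the constant is converted automatically.
(TWO_PI * math.random_normal()).default_backend  # expected: jax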
