%%capture
!pip install phiml
from phiml import math
When you create a new Tensor in ΦML without specifying which backend to use, it will be backed by NumPy.
You can check the corresponding backend of a Tensor using .default_backend.
math.random_uniform().default_backend
numpy
The same is true if you wrap a NumPy array, even after setting a different default backend.
math.use('jax')
import numpy
math.wrap(numpy.asarray([0, 1, 2])).default_backend
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
numpy
Tensors backed by NumPy are not differentiable, only run on the CPU, and functions acting on NumPy tensors cannot be JIT-compiled. So why would you ever want to use NumPy?
Put simply, tensors backed by NumPy represent constants in your computational graph. If you JIT-compile a function and call it with PyTorch tensors, all recorded PyTorch calls will be executed each time. All NumPy calls, however, will only be executed during tracing and never again (unless the function needs to be re-traced).
Here's an example:
@math.jit_compile
def fun(a):
    b = math.wrap([1, 2])
    print(f"Tracing with a = {a}, b = {b}")
    return a + math.sin(b ** 2)

fun(math.tensor([1, 2]))
Tracing with a = (vectorᶜ=2) int64 jax tracer, b = (1, 2) int64
(1.841, 1.243)
Since fun is JIT-compiled, the Python function is always called with tracer objects for a.
However, since b is a wrapped NumPy array, b ** 2 and sin(b ** 2) can be computed with NumPy at JIT-compile time.
Consequently, the computational graph generated by Jax, PyTorch or TensorFlow contains only a single addition.
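As a quick check, we can call fun again with a tensor of the same shape and dtype (a sketch; the example values are made up). The Python body is not re-executed, so nothing is printed, and the precomputed sin(b ** 2) constant baked into the graph is reused.
# Reuses the compiled graph: no re-tracing, no print, sin(b ** 2) is not recomputed.
fun(math.tensor([3, 4]))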
Since NumPy-backed tensors are represented the same way in ΦML as ML tensors, we can later make b depend on a variable without modifying any of the downstream code.
@math.jit_compile
def fun(a):
    b = a
    print(f"Tracing with a = {a}, b = {b}")
    return a + math.sin(b ** 2)

fun(math.tensor([1, 2]))
Tracing with a = (vectorᶜ=2) int64 jax tracer, b = (vectorᶜ=2) int64 jax tracer
(1.841, 1.243)
Now, b is a traced tensor, so sin(b ** 2) is recorded in the computational graph and executed on every call.
Importantly, this principle also applies to complex functions, such as simulations.
Say you have a simulation sim(x). Then, for a fixed (NumPy-backed) label, the loss |sim(prediction) - sim(label)| only needs to evaluate sim(prediction) at run time, while sim(label) is precomputed during tracing.
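Here is a minimal sketch of this pattern. The function sim and the label values are placeholders for illustration; the point is that sim(label) only involves NumPy tensors, so it is evaluated once during tracing, and only sim(prediction) ends up in the compiled graph.
def sim(x):
    return math.sin(x ** 2)  # stand-in for an expensive simulation

label = math.wrap([0.5, 1.0])  # NumPy-backed constant

@math.jit_compile
def loss(prediction):
    # sim(label) is NumPy-backed and therefore computed at trace time;
    # only sim(prediction) and the subtraction are recorded in the graph.
    return math.sum(abs(sim(prediction) - sim(label)))

loss(math.tensor([0.4, 1.1]))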
NumPy also plays an important role in tracing linear functions to obtain an explicit sparse matrix representation. If the linear function does not depend on any ML tensors, it will be represented as a NumPy array or SciPy sparse matrix, even when the default backend is not NumPy. New tensors created inside the linear function will also default to NumPy, overriding the global default backend.
def lin(x):
    two = math.tensor(2)
    return two * x.x[:-1]

matrix = math.matrix_from_function(lin, math.zeros(math.spatial(x=3)))[0]
math.print(matrix, f"Backend: {matrix.default_backend}")
matrix.native()
Backend: numpy
x=0    2.  0.  0.   along ~x
x=1    0.  2.  0.   along ~x
<2x3 sparse matrix of type '<class 'numpy.float32'>' with 2 stored elements in COOrdinate format>
This allows preconditioners to be computed at JIT-compile time. If the linear function does depend on ML tensors, the matrix will instead be represented as a sparse tensor of the corresponding backend.
def lin(x, a):
    return a * x.x[:-1]

matrix = math.matrix_from_function(lin, math.zeros(math.spatial(x=3)), math.tensor(2))[0]
math.print(matrix, f"Backend: {matrix.default_backend}")
matrix.native()
Backend: jax
x=0    2.  0.  0.   along ~x
x=1    0.  2.  0.   along ~x
BCOO(float32[2, 3], nse=2)