ΦML Examples

Colab   •   🌐 ΦML   •   📖 Documentation   •   🔗 API   •   ▶ Videos

All examples in this notebook can be run with any of the three supported ML backends: jax, torch and tensorflow. You can select your preferred backend by changing the math.use() call below.

In [1]:
# !pip install phiml
from phiml import math, nn
from phiml.math import channel, spatial, batch, instance

math.use('jax')

from matplotlib import pyplot as plt
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

Training an MLP

The following script trains an MLP with three hidden layers to learn a noisy 1D sine function in the range [-2, 2].

In [2]:
net = nn.mlp(1, 1, layers=[128, 128, 128], activation='ReLU')
optimizer = nn.adam(net, learning_rate=1e-3)

data_x = math.random_uniform(batch(batch=128), low=-2, high=2)
data_y = math.sin(data_x) + math.random_normal(batch(batch=128)) * .2

def loss_function(x, y):
    return math.l2_loss(y - math.native_call(net, x))

for i in range(10):
    loss = nn.update_weights(net, optimizer, loss_function, data_x, data_y)
    print(loss)
(batchᵇ=128) 0.370 ± 0.261 (2e-04...1e+00)
(batchᵇ=128) 0.312 ± 0.223 (9e-05...1e+00)
(batchᵇ=128) 0.248 ± 0.184 (3e-05...8e-01)
(batchᵇ=128) 0.189 ± 0.149 (6e-07...6e-01)
(batchᵇ=128) 0.134 ± 0.116 (2e-05...5e-01)
(batchᵇ=128) 0.088 ± 0.088 (1e-04...4e-01)
(batchᵇ=128) 0.053 ± 0.062 (1e-06...3e-01)
(batchᵇ=128) 0.032 ± 0.039 (9e-06...2e-01)
(batchᵇ=128) 0.031 ± 0.040 (3e-06...2e-01)
(batchᵇ=128) 0.047 ± 0.069 (2e-05...3e-01)

We didn't even have to import jax in this example since all calls were routed through ΦML. See the introduction to network training for more details.
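The trained network can afterwards be evaluated the same way the loss function does it, via math.native_call. A minimal sketch, computing the test loss on noise-free samples (variable names chosen for illustration):

test_x = math.random_uniform(batch(batch=100), low=-2, high=2)
test_y = math.sin(test_x)  # noise-free targets
print(loss_function(test_x, test_y))  # per-sample L2 test loss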

Simulating the Heat Equation

The heat equation $\partial_t u = \nu \nabla^2 u$ describes heat diffusion, where $\nu$ denotes the diffusivity and $\nabla^2$ is the Laplace operator. Discretizing the time derivative as $\partial_t u \approx (u_{t+1} - u_t) / \Delta t$ yields the explicit update $u_{t+1} = u_t + \Delta t \cdot \nu \nabla^2 u_t$. Decorating the function with math.jit_compile_linear makes it run faster, as the linear function is compiled to a sparse matrix.

In [3]:
@math.jit_compile_linear
def explicit_heat_step(u, dt, dx, diffusivity):
    return u + dt * diffusivity * math.laplace(u, dx, 'periodic')

Rewriting the above update in implicit form, i.e. evaluating the Laplacian at the new time, gives $u_{t+1} - \Delta t \cdot \nu \nabla^2 u_{t+1} = u_t$. Notice how this equation looks like a time-reversed explicit step ($\Delta t \rightarrow -\Delta t$) with $u_t$ and $u_{t+1}$ swapped. This correspondence between implicit steps and time-reversed explicit steps holds generally for time discretizations of this form.

Rewriting the equation as $(1 - \Delta t \cdot \nu \nabla^2) \, u_{t+1} = u_t$ yields a linear system of equations for $u_{t+1}$. Advancing the simulation then means solving this linear system, whose matrix encodes the time-reversed explicit step. While the implicit step is technically a linear function of $u_t$, its matrix (the inverse of the system matrix) would be dense and typically very large. Therefore, we use math.jit_compile instead of math.jit_compile_linear here.

In [4]:
@math.jit_compile
def implicit_heat_step(u, dx, dt, diffusivity):
    return math.solve_linear(explicit_heat_step, u, math.Solve(x0=u), dt=-dt, dx=dx, diffusivity=diffusivity)

Let's run the simulation in 1D.

In [5]:
x = math.linspace(0, 1, spatial(x=100))
u0 = math.exp(-.5 * (x-.5) ** 2 / .1 ** 2)
u_trj = math.iterate(implicit_heat_step, batch(t=8), u0, dx=1/100, dt=.01, diffusivity=1.)
plt.plot(u_trj.numpy('x,t'));
/opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/phiml/math/_optimize.py:640: UserWarning: Possible rank deficiency detected. Matrix might be singular which can lead to convergence problems. Please specify using Solve(rank_deficiency=...).
  warnings.warn("Possible rank deficiency detected. Matrix might be singular which can lead to convergence problems. Please specify using Solve(rank_deficiency=...).")

Let's run the simulation in 2D. In fact, the simulation works in n dimensions.

In [6]:
x = math.linspace(0, 1, spatial(x=100, y=100))
u0 = math.exp(-.5 * math.squared_norm(x-.5) / .1 ** 2)
u_trj = math.iterate(implicit_heat_step, batch(t=8), u0, dx=1/100, dt=.01, diffusivity=1.)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
ax1.imshow(u0.numpy('y,x'))
ax2.imshow(u_trj.t[-1].numpy('y,x'));
/opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/phiml/math/_optimize.py:640: UserWarning: Possible rank deficiency detected. Matrix might be singular which can lead to convergence problems. Please specify using Solve(rank_deficiency=...).
  warnings.warn("Possible rank deficiency detected. Matrix might be singular which can lead to convergence problems. Please specify using Solve(rank_deficiency=...).")

Simulating Burgers' Equation

Burgers' equation $\partial_t u = \nu \nabla^2 u - u \nabla u$ is a convection-diffusion equation. Notice the non-linear convection term $u \nabla u$. When writing Burgers' equation as a matrix equation, we need to linearize it. This can be done by evaluating the advecting velocity at the previous state $u_t$ (a single Picard iteration), yielding the linear system $u_{t+1} - \Delta t \left( \nu \nabla^2 u_{t+1} - u_t \nabla u_{t+1} \right) = u_t$.

As with the heat equation above, we first write an explicit Euler step, using an upwind scheme for the advection term.

In [7]:
@math.jit_compile_linear
def explicit_burgers(u, u_prev, dt, diffusivity, boundary='periodic'):
    dx = 1 / math.wrap(spatial(u), channel(u))  # grid spacing, assuming a unit domain
    left, right = math.shift(u_prev, (-1, 1), padding=boundary, stack_dim=channel('adv'))  # neighbor values from the previous state
    positive_winds = math.maximum(0, (u_prev + left) / 2)   # face velocities, positive part
    negative_winds = math.minimum(0, (u_prev + right) / 2)  # face velocities, negative part
    forward_diff = math.spatial_gradient(u, dx, 'forward', padding=boundary, stack_dim=channel('adv'))
    backward_diff = math.spatial_gradient(u, dx, 'backward', padding=boundary, stack_dim=channel('adv'))
    advection = math.sum(backward_diff * positive_winds + forward_diff * negative_winds, 'adv')  # upwind: sum required for n-d
    diffusion = diffusivity * math.laplace(u, padding=boundary)
    return u + dt * (diffusion - advection)

In [8]:
@math.jit_compile
def implicit_burgers(u_prev, dt, diffusivity=1e-4):
    return math.solve_linear(explicit_burgers, u_prev, math.Solve(x0=u_prev), u_prev=u_prev, dt=-dt, diffusivity=diffusivity)

Let's run the implicit simulation in 1D.

In [9]:
x = math.linspace(0, 1, spatial(x=100))
u0 = math.sin(x * 2 * math.PI) + 0.3
u_trj = math.iterate(implicit_burgers, batch(t=80), u0, dt=0.01)
plt.plot(u_trj.t[::10].numpy('x,t'));

Pairwise Distances

The following function takes a (possibly batched) tensor of positions and computes the matrix of pairwise distances.

In [10]:
from phiml import math  # uses NumPy by default

def pairwise_distances(x: math.Tensor):
    dx = math.rename_dims(x, 'points', 'others') - x
    return math.norm(dx)

x = math.random_normal(math.instance(points=3), math.channel(vector="x,y"))
math.print(pairwise_distances(x))
[[0.       , 3.3728998, 1.9318242],
 [3.3728998, 0.       , 1.8491019],
 [1.9318242, 1.8491019, 0.       ]]

Inside pairwise_distances, we rename 'points' to 'others'. When taking the difference, ΦML automatically expands both operands by the missing dimensions, adding 'points' to the first argument and 'others' to the second. An explanation of this automatic reshaping is given here.
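A minimal sketch of this expansion, using the same dimension names as above:

x = math.random_normal(math.instance(points=3), math.channel(vector="x,y"))
dx = math.rename_dims(x, 'points', 'others') - x
print(dx.shape)  # both 'points' and 'others' are present after automatic expansion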

Automatic Differentiation

Next, let's compute the gradient of some function of (x,y) w.r.t. x.

In [11]:
from phiml import math
math.use('jax')

def function(x, y):
    return x ** 2 * y

gradient_x = math.gradient(function, wrt='x', get_output=False)
print(gradient_x(2, 1))
4.0
/opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/phiml/math/_functional.py:630: RuntimeWarning: Using jax for gradient computation because numpy does not support jacobian()
  warnings.warn(f"Using {math.default_backend()} for gradient computation because {key.backend} does not support jacobian()", RuntimeWarning)

JIT compilation

ΦML provides two types of JIT compilation: the generic math.jit_compile calls the JIT compiler of the backend in use (e.g. jax.jit or tf.function), while math.jit_compile_linear builds an explicit matrix representation of linear functions.

In [12]:
from phiml import math
math.use('tensorflow')

@math.jit_compile(auxiliary_args='divide_by_y')
def function(x, y, divide_by_y=False):
    if divide_by_y:
        return x ** 2 / y
    else:
        return x ** 2 * y

function(math.tensor(2), 2, False)
2025-05-06 16:14:29.363344: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Out[12]:
8

Here, we declare divide_by_y as an auxiliary argument to force the function to be re-traced when its value changes. Otherwise, its concrete value would not be available inside the function and could not be used within an if clause.
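If we call the function with divide_by_y=True, it is re-traced and the division branch is compiled; we would expect 2² / 2 = 2. A minimal sketch:

function(math.tensor(2), 2, True)  # re-traces, now compiling the division branch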

JIT compilation of linear functions is also supported on NumPy.

In [13]:
from phiml import math
math.use('numpy')

@math.jit_compile_linear(auxiliary_args='compute_gradient')
def optional_sp_grad(x, compute_gradient):
    if compute_gradient:
        return math.spatial_gradient(x)
    else:
        return -x

optional_sp_grad(math.linspace(0, 1, math.spatial(x=10)), True)

Here, an explicit sparse matrix representation of optional_sp_grad is computed each time a new value of compute_gradient is passed.
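For example, calling the function with compute_gradient=False would trace the -x branch instead, whose matrix is the negative identity. A minimal sketch:

optional_sp_grad(math.linspace(0, 1, math.spatial(x=10)), False)  # builds a -1 · identity matrix for this branch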

Solving a sparse linear system with preconditioners

ΦML supports solving dense as well as sparse linear systems and can build an explicit matrix representation from linear Python functions in order to compute preconditioners. We recommend using ΦML's tensors, but you can pass native tensors to solve_linear() as well. The following example solves the 1D Poisson problem $\nabla^2 x = b$ with $b = 1$, using an incomplete LU (ILU) preconditioner.

In [14]:
from phiml import math
import numpy as np

def laplace_1d(x):
    # 1D Laplace stencil (1, -2, 1) with zero-padded (Dirichlet) boundaries
    return math.pad(x[1:], (0, 1)) + math.pad(x[:-1], (1, 0)) - 2 * x

b = np.ones((6,))
solve = math.Solve('scipy-CG', rel_tol=1e-5, x0=0*b, preconditioner='ilu')
sol = math.solve_linear(math.jit_compile_linear(laplace_1d), b, solve)
/opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/scipy/sparse/linalg/_dsolve/linsolve.py:640: SparseEfficiencyWarning: CSR matrix format is required. Converting to CSR matrix.
  warn('CSR matrix format is required. Converting to CSR matrix.',

Decorating the linear function with math.jit_compile_linear lets ΦML compute the sparse matrix inside solve_linear(). In this example, the matrix is a tridiagonal band matrix. Note that if you JIT-compile the math.solve_linear() call, the sparsity pattern and incomplete LU preconditioner are computed at JIT time. The L and U matrices then enter the computational graph as constants and are not recomputed every time the function is called.
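A minimal sketch of such a jitted solve, assuming a backend with JIT support (e.g. math.use('jax')) and ΦML's built-in 'CG' solver; the function name jit_solve is chosen for illustration:

@math.jit_compile
def jit_solve(b):
    # sparsity pattern and ILU factorization are computed once, at trace time
    solve = math.Solve('CG', rel_tol=1e-5, x0=0*b, preconditioner='ilu')
    return math.solve_linear(math.jit_compile_linear(laplace_1d), b, solve)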

Average of Neighbor Cells

With ΦML, you can write functions that work in any number of spatial dimensions. The function neighbor_mean computes the average of all direct neighbor cells (2 in 1D, 4 in 2D, 6 in 3D). We can inspect the derived stencil coefficients by printing the function as a matrix.

In [15]:
def neighbor_mean(grid):
    left, right = math.shift(grid, (-1, 1), padding=math.extrapolation.PERIODIC)
    return math.mean([left, right], math.non_spatial)

math.print(math.matrix_from_function(neighbor_mean, math.zeros(math.spatial(x=5)))[0])
x=0     0.   0.5  0.   0.   0.5  along ~x
x=1     0.5  0.   0.5  0.   0.   along ~x
x=2     0.   0.5  0.   0.5  0.   along ~x
x=3     0.   0.   0.5  0.   0.5  along ~x
x=4     0.5  0.   0.   0.5  0.   along ~x

Further Reading

Check out the 🚀 quickstart guide for an introduction to tensors and dimensions.

🌐 ΦML   •   📖 Documentation   •   🔗 API   •   ▶ Videos