• 🌐 ΦML • 📖 Documentation • 🔗 API • ▶ Videos
All examples listed in this notebook can be run with any of the three supported ML backends: jax, torch and tensorflow. You can select your preferred backend by changing the math.use() call below.
# !pip install phiml
from phiml import math, nn
from phiml.math import channel, spatial, batch, instance
math.use('jax')
from matplotlib import pyplot as plt
net = nn.mlp(1, 1, layers=[128, 128, 128], activation='ReLU')
optimizer = nn.adam(net, learning_rate=1e-3)
data_x = math.random_uniform(batch(batch=128), low=-2, high=2)
data_y = math.sin(data_x) + math.random_normal(batch(batch=128)) * .2
def loss_function(x, y):
    return math.l2_loss(y - math.native_call(net, x))
for i in range(10):
    loss = nn.update_weights(net, optimizer, loss_function, data_x, data_y)
    print(loss)
(batchᵇ=128) 0.370 ± 0.261 (2e-04...1e+00)
(batchᵇ=128) 0.312 ± 0.223 (9e-05...1e+00)
(batchᵇ=128) 0.248 ± 0.184 (3e-05...8e-01)
(batchᵇ=128) 0.189 ± 0.149 (6e-07...6e-01)
(batchᵇ=128) 0.134 ± 0.116 (2e-05...5e-01)
(batchᵇ=128) 0.088 ± 0.088 (1e-04...4e-01)
(batchᵇ=128) 0.053 ± 0.062 (1e-06...3e-01)
(batchᵇ=128) 0.032 ± 0.039 (9e-06...2e-01)
(batchᵇ=128) 0.031 ± 0.040 (3e-06...2e-01)
(batchᵇ=128) 0.047 ± 0.069 (2e-05...3e-01)
We did not even have to import jax (or any other backend library) directly in this example, since all calls were routed through ΦML.
See the introduction to network training for more details.
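To gauge how well the network fits the data, we can evaluate it on freshly sampled inputs and measure the error against the noise-free target. The following is a minimal sketch reusing only functions from the cells above; the printed statistics will vary between runs.
test_x = math.random_uniform(batch(batch=256), low=-2, high=2)  # new samples in the training range
test_loss = math.l2_loss(math.sin(test_x) - math.native_call(net, test_x))  # error w.r.t. the noise-free target sin(x)
print(test_loss)  # per-sample statistics, analogous to the training losses printed above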
The heat equation $\partial_t u = \nu \nabla^2 u$ describes the diffusion of heat, where $\nu$ denotes the diffusivity and $\nabla^2$ is the Laplace operator.
Discretizing the time derivative as $\partial_t u \approx (u_{t+1} - u_t) / \Delta t$ yields the explicit update $u_{t+1} = u_t + \Delta t \cdot \nu \nabla^2 u_t$.
Decorating the step function with math.jit_compile_linear makes it run faster, as it is compiled to a sparse matrix.
@math.jit_compile_linear
def explicit_heat_step(u, dt, dx, diffusivity):
    return u + dt * diffusivity * math.laplace(u, dx, 'periodic')
Rewriting the above equation in implicit form, we have $u_{t+1} - \Delta t \cdot \nu \nabla^2 u_{t+1} = u_t$. Notice how this looks like a time-reversed explicit step with $u_t$ and $u_{t+1}$ swapped: for an Euler discretization, the implicit update is simply the explicit update with negated time step, applied to the (unknown) future state. This is why we can reuse explicit_heat_step below.
Rewriting the equation to $(1 - \Delta t \cdot \nu \nabla^2) u_{t+1} = u_t$ yields a linear system of equations for $u_{t+1}$.
Advancing the simulation then becomes solving this linear system where the matrix encodes the time-reversed explicit step.
While the implicit step as a whole is also a linear function of $u_t$, its matrix (the inverse of the system matrix) would be dense and typically very large. Therefore, we use the generic math.jit_compile here and let math.solve_linear construct the sparse system matrix from the explicit step.
@math.jit_compile
def implicit_heat_step(u, dx, dt, diffusivity):
    return math.solve_linear(explicit_heat_step, u, math.Solve(x0=u), dt=-dt, dx=dx, diffusivity=diffusivity)
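To see why the time step is negated: calling explicit_heat_step with $-\Delta t$ evaluates $u_{t+1} \mapsto u_{t+1} - \Delta t \cdot \nu \nabla^2 u_{t+1}$, and math.solve_linear finds the $u_{t+1}$ for which this equals the right-hand side $u_t$, i.e. it solves exactly the system $(1 - \Delta t \cdot \nu \nabla^2) u_{t+1} = u_t$ derived above.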
Let's run the simulation in 1D.
x = math.linspace(0, 1, spatial(x=100))
u0 = math.exp(-.5 * (x-.5) ** 2 / .1 ** 2)
u_trj = math.iterate(implicit_heat_step, batch(t=8), u0, dx=1/100, dt=.01, diffusivity=1.)
plt.plot(u_trj.numpy('x,t'));
Let's run the simulation in 2D. In fact, the simulation works in n dimensions.
x = math.linspace(0, 1, spatial(x=100, y=100))
u0 = math.exp(-.5 * math.squared_norm(x - .5) / .1 ** 2)
u_trj = math.iterate(implicit_heat_step, batch(t=8), u0, dx=1/100, dt=.01, diffusivity=1.)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
ax1.imshow(u0.numpy('y,x'))
ax2.imshow(u_trj.t[-1].numpy('y,x'));
Burgers' equation $\partial_t u = \nu \nabla^2 u - u \nabla u$ is a convection-diffusion equation. Notice the non-linear convection term $u \nabla u$. To write Burgers' equation as a matrix equation, we need to linearize it. This can be done by evaluating the first $u$ factor, i.e. the advecting velocity, at the previous state (a single Picard iteration), so that the implicit step solves $u_{t+1} - \Delta t \, (\nu \nabla^2 u_{t+1} - u_t \nabla u_{t+1}) = u_t$.
As with the heat equation above, we first write an explicit Euler step. We use an upwind scheme.
@math.jit_compile_linear
def explicit_burgers(u, u_prev, dt, diffusivity, boundary='periodic'):
    dx = 1 / math.wrap(spatial(u), channel(u))
    left, right = math.shift(u_prev, (-1, 1), padding=boundary, stack_dim=channel('adv'))
    positive_winds = math.maximum(0, (u_prev + left) / 2)
    negative_winds = math.minimum(0, (u_prev + right) / 2)
    forward_diff = math.spatial_gradient(u, dx, 'forward', padding=boundary, stack_dim=channel('adv'))
    backward_diff = math.spatial_gradient(u, dx, 'backward', padding=boundary, stack_dim=channel('adv'))
    advection = math.sum(backward_diff * positive_winds + forward_diff * negative_winds, 'adv')  # sum required for n-d
    diffusion = diffusivity * math.laplace(u, padding=boundary)
    return u + dt * (diffusion - advection)
@math.jit_compile
def implicit_burgers(u_prev, dt, diffusivity=1e-4):
    return math.solve_linear(explicit_burgers, u_prev, math.Solve(x0=u_prev), u_prev=u_prev, dt=-dt, diffusivity=diffusivity)
Let's run the implicit simulation in 1D.
x = math.linspace(0, 1, spatial(x=100))
u0 = math.sin(x * 2 * math.PI) + 0.3
u_trj = math.iterate(implicit_burgers, batch(t=80), u0, dt=0.01)
plt.plot(u_trj.t[::10].numpy('x,t'));
The following function takes a (possibly batched) tensor of positions and computes the distance matrix.
from phiml import math # uses NumPy by default
def pairwise_distances(x: math.Tensor):
    dx = math.rename_dims(x, 'points', 'others') - x
    return math.norm(dx)
x = math.random_normal(math.instance(points=3), math.channel(vector="x,y"))
math.print(pairwise_distances(x))
[[0.       , 3.3728998, 1.9318242],
 [3.3728998, 0.       , 1.8491019],
 [1.9318242, 1.8491019, 0.       ]]
Inside pairwise_distances, we rename 'points' to 'others'. When taking the difference, ΦML automatically expands both operands by the missing dimensions, adding 'points' to the first argument and 'others' to the second.
An explanation of this automatic reshaping is given here.
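To see this expansion explicitly, we can inspect the shape of the intermediate difference tensor. This is a quick sketch; the exact formatting of the printed shape may differ between ΦML versions.
dx = math.rename_dims(x, 'points', 'others') - x  # both operands are expanded before subtraction
print(dx.shape)  # contains the instance dims 'points' and 'others' plus the 'vector' channel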
from phiml import math
math.use('jax')
def function(x, y):
    return x ** 2 * y
gradient_x = math.gradient(function, wrt='x', get_output=False)
print(gradient_x(2, 1))
4.0
ΦML provides two types of JIT compilation: the generic jit_compile calls the corresponding library function, while jit_compile_linear builds an explicit representation for linear functions.
from phiml import math
math.use('tensorflow')
@math.jit_compile(auxiliary_args='divide_by_y')
def function(x, y, divide_by_y=False):
    if divide_by_y:
        return x ** 2 / y
    else:
        return x ** 2 * y
function(math.tensor(2), 2, False)
8
Here, we declare divide_by_y as an auxiliary argument to force the function to be re-traced when its value changes. Otherwise, its concrete value would not be available inside the function and could not be used within an if clause.
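Calling the function with the other value of divide_by_y triggers a second trace. As a quick check, with $x=2$ and $y=2$ the divided branch evaluates to $4 / 2 = 2$ (the exact output type depends on the backend):
function(math.tensor(2), 2, True)  # re-traced; this graph computes x ** 2 / y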
JIT compilation of linear functions is also supported on NumPy.
from phiml import math
math.use('numpy')
@math.jit_compile_linear(auxiliary_args='compute_laplace')
def optional_laplace(x, compute_laplace):
    if compute_laplace:
        return math.laplace(x)
    else:
        return -x

optional_laplace(math.linspace(0, 1, math.spatial(x=10)), True)
Here, an explicit sparse matrix representation of optional_laplace is computed each time a new value of compute_laplace is passed.
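For instance, calling the function again with compute_laplace=False builds a second, separate matrix for the -x branch. This is a minimal sketch following the call above; the returned value is simply the negated input.
optional_laplace(math.linspace(0, 1, math.spatial(x=10)), False)  # traces a new matrix representing negation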
ΦML supports solving dense as well as sparse linear systems and can build an explicit matrix representation from linear Python functions in order to compute preconditioners.
We recommend using ΦML's tensors, but you can also pass native tensors to solve_linear().
The following example solves the 1D Poisson problem $\nabla^2 x = b$ with $b = 1$, using conjugate gradients with an incomplete LU (ILU) preconditioner.
from phiml import math
import numpy as np
def laplace_1d(x):
    return math.pad(x[1:], (0, 1)) + math.pad(x[:-1], (1, 0)) - 2 * x
b = np.ones((6,))
solve = math.Solve('scipy-CG', rel_tol=1e-5, x0=0*b, preconditioner='ilu')
sol = math.solve_linear(math.jit_compile_linear(laplace_1d), b, solve)
Decorating the linear function with math.jit_compile_linear lets ΦML compute the sparse matrix inside solve_linear(). In this example, the matrix is a tridiagonal band matrix.
Note that if you JIT-compile the math.solve_linear() call, the sparsity pattern and the incomplete LU preconditioner are computed at JIT time. The L and U matrices then enter the computational graph as constants and are not recomputed every time the function is called.
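For example, the solve can be wrapped as follows. This is a minimal sketch reusing laplace_1d and b from above; solve_poisson is a name chosen here for illustration, and the caching benefit applies once a JIT-capable backend is selected.
@math.jit_compile
def solve_poisson(b):
    return math.solve_linear(math.jit_compile_linear(laplace_1d), b, math.Solve('scipy-CG', rel_tol=1e-5, x0=0*b, preconditioner='ilu'))

sol = solve_poisson(math.wrap(b, math.spatial('x')))  # matrix and ILU factors are built during tracing, then reused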
With ΦML, you can write functions that work in any number of spatial dimensions. The function neighbor_mean below computes the average of all direct neighbor cells (2 in 1D, 4 in 2D, 6 in 3D).
We can inspect the derived factors by printing the function as a matrix.
def neighbor_mean(grid):
    left, right = math.shift(grid, (-1, 1), padding=math.extrapolation.PERIODIC)
    return math.mean([left, right], math.non_spatial)
math.print(math.matrix_from_function(neighbor_mean, math.zeros(math.spatial(x=5)))[0])
x=0    0.   0.5  0.   0.   0.5   along ~x
x=1    0.5  0.   0.5  0.   0.    along ~x
x=2    0.   0.5  0.   0.5  0.    along ~x
x=3    0.   0.   0.5  0.   0.5   along ~x
x=4    0.5  0.   0.   0.5  0.    along ~x
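The same function can be applied to a grid of any dimensionality without modification. The following sketch builds the matrix for a 2D grid; grid_2d and matrix_2d are illustrative names, and we only print the shape because the full matrix (16 rows for a 4x4 grid) is verbose.
grid_2d = math.zeros(math.spatial(x=4, y=4))  # same function, now on a 2D grid
matrix_2d = math.matrix_from_function(neighbor_mean, grid_2d)[0]
print(matrix_2d.shape)  # rows and columns each index the 4x4 grid cells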
Check out the 🚀 quickstart guide for an introduction to tensors and dimensions.