%%capture
!pip install phiml
ΦML provides many conveniences to make your code more concise and expressive. It also includes various performance optimization checks. All of this comes at the cost of added overhead for dimension matching, under-the-hood reshaping, etc. However, this overhead gets eliminated once your code is JIT-compiled with PyTorch, TensorFlow or Jax.
In this notebook, we measure the differences on a rigid-body simulation inspired by billiards. Spheres move on a table and can collide with the boundary as well as with other spheres. The function below defines the physics step using the Φ-ML API.
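This is our reading of what one step computes, written out per sphere pair (i, j). It is derived from the code below and from the native versions later in the notebook. It assumes equal radii r, elasticity e and time step Δt. The factor (1+e)/2 is the usual impulse split for a collision between two equal masses. Self-pairs contribute nothing because their direction n_ii evaluates to 0. The native versions additionally clamp the squared distance at 10⁻⁴ before taking the square root.

```latex
\begin{aligned}
\Delta_{ij} &= (x_i + v_i\,\Delta t) - (x_j + v_j\,\Delta t), \qquad
d_{ij} = \lVert \Delta_{ij} \rVert, \qquad
n_{ij} = \Delta_{ij} / d_{ij}, \\
p_{ij} &= n_{ij} \cdot (v_i - v_j), \qquad
I_{ij} = [\, p_{ij} < 0 \ \wedge\ d_{ij} < 2r \,], \\
J_{ij} &= -\tfrac{1+e}{2}\, p_{ij}\, n_{ij}, \qquad
t_{ij} = \frac{d_{ij} - 2r}{p_{ij}}, \\
x_i &\leftarrow x_i + \textstyle\sum_j I_{ij}\, \min(t_{ij} - \Delta t,\, 0)\, J_{ij}, \qquad
v_i \leftarrow v_i + \textstyle\sum_j I_{ij}\, J_{ij}, \\
v_{i,k} &\leftarrow -v_{i,k} \quad \text{if } x_{i,k} < 0 \text{ or } x_{i,k} > 2, \\
(x_i, v_i) &\mapsto (x_i + v_i\,\Delta t,\ v_i).
\end{aligned}
```

A negative p_ij means the two spheres are approaching each other. Only such pairs that also overlap (d_ij < 2r) exchange an impulse.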
import time
from phiml.math import math, rename_dims, instance, dual, iterate, batch, channel

@math.jit_compile
def physics_step(x, v, dt: float, elasticity=0.8, radius=.03):
    x_next = x + v * dt
    deltas = -math.pairwise_distances(x_next)
    # eps to avoid NaN during backprop of sqrt; Φ-ML emits a DeprecationWarning here recommending math.norm
    dist = math.vec_length(deltas, eps=1e-4)
    rel_v = -math.pairwise_distances(v)
    dist_dir = math.safe_div(deltas, dist)
    projected_v = dist_dir.vector * rel_v.vector  # dot product along the 'vector' dimension
    has_impact = (projected_v < 0) & (dist < 2 * radius)
    impulse = -(1 + elasticity) * .5 * projected_v * dist_dir
    radius_sum = radius + rename_dims(radius, instance, dual)
    impact_time = math.safe_div(dist - radius_sum, projected_v)
    x_inc_contrib = math.sum(math.where(has_impact, math.minimum(impact_time - dt, 0) * impulse, 0), dual)
    x += x_inc_contrib
    v += math.sum(math.where(has_impact, impulse, 0), dual)
    v = math.where((x < 0) | (x > 2), -v, v)
    return x + v * dt, v

for backend in ['torch', 'jax', 'tensorflow']:
    math.use(backend)
    x0 = math.random_uniform(instance(points=1000), channel(vector='x,y'), high=2)
    v0 = math.random_normal(x0.shape)
    (x_trj, v_trj), dt = iterate(physics_step, batch(t=200), x0, v0, f_kwargs={'dt': .05}, measure=time.perf_counter)
    print(f"Φ-ML + {backend} JIT compilation: {dt.t[0]}")
    print(f"Φ-ML + {backend} execution average: {dt.t[2:].mean} +- {dt.t[2:].std}")
Φ-ML + torch JIT compilation: float64 0.2283915
Φ-ML + torch execution average: 0.035401053726673126 +- 0.004196503199636936
Φ-ML + jax JIT compilation: float64 0.1806284
Φ-ML + jax execution average: 0.012228960171341896 +- 0.001105412608012557
Φ-ML + tensorflow JIT compilation: float64 17.477774
Φ-ML + tensorflow execution average: 0.05659555643796921 +- 0.002101015066727996
The performance numbers printed above were measured on GitHub Actions and can fluctuate considerably depending on the CPU resources allotted to the runner. Reference GPU numbers are given below.
Next, let's implement the same simulation natively, without Φ-ML, to compare performance.
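Each of the native implementations below repeats the same timing loop. Here is a sketch of how that loop could be factored into one backend-agnostic helper. `time_steps` and `summarize` are our own hypothetical names, not part of Φ-ML or any backend. The `sync` hook is where a backend's synchronization call would go, for example `block_until_ready()` on JAX arrays. `warmup=2` mirrors the `[2:]` slicing used throughout this notebook.

```python
import statistics
import time

def time_steps(step, x, v, n_steps=200, sync=lambda *arrays: None):
    """Call `step(x, v) -> (x, v)` n_steps times; return the final state and per-step durations (s).

    The first duration includes any JIT compilation triggered by the first call.
    `sync` is called on each result so that asynchronous backends can be made to
    finish the computation before the clock is read (the default does nothing).
    Unlike the notebook loops, this helper does not store the trajectory.
    """
    durations = []
    t0 = time.perf_counter()
    for _ in range(n_steps):
        x, v = step(x, v)
        sync(x, v)
        t = time.perf_counter()
        durations.append(t - t0)
        t0 = t
    return (x, v), durations

def summarize(durations, warmup=2):
    """Return (first-call time, mean, std), where mean and std skip the first `warmup` steps."""
    steady = durations[warmup:]
    return durations[0], statistics.mean(steady), statistics.stdev(steady)
```

With the JAX version, one might call `time_steps(lambda x, v: physics_step(x, v, dt=.05), x0, v0, sync=lambda *arrays: [a.block_until_ready() for a in arrays])` and pass the durations to `summarize`.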
import time
import jax
from jax import numpy as np

def safe_div(x, y):
    return np.where(y == 0, 0, x / y)

@jax.jit
def physics_step(x, v, dt: float, elasticity=0.8, radius=.03):
    x_next = x + v * dt
    deltas = x_next[..., None, :] - x_next[..., None, :, :]  # pairwise differences x_i - x_j
    dist = np.sqrt(np.maximum(np.sum(deltas ** 2, -1), 1e-4))  # eps=1e-4 to avoid NaN during backprop of sqrt
    rel_v = v[..., None, :] - v[..., None, :, :]
    dist_dir = safe_div(deltas, dist[..., None])
    projected_v = np.sum(dist_dir * rel_v, -1)
    has_impact = (projected_v < 0) & (dist < 2 * radius)
    impulse = -(1 + elasticity) * .5 * projected_v[..., None] * dist_dir
    radius_sum = radius + radius  # this only supports equal radii
    impact_time = safe_div(dist - radius_sum, projected_v)
    x_inc_contrib = np.sum(np.where(has_impact[..., None], np.minimum(impact_time[..., None] - dt, 0) * impulse, 0), -2)
    x += x_inc_contrib
    v += np.sum(np.where(has_impact[..., None], impulse, 0), -2)
    v = np.where((x < 0) | (x > 2), -v, v)
    return x + v * dt, v

x0 = jax.random.uniform(jax.random.PRNGKey(0), shape=(1000, 2)) * 2
v0 = jax.random.normal(jax.random.PRNGKey(1), shape=x0.shape)
x_trj = [x0]
v_trj = [v0]
dt_jax = []
t0 = time.perf_counter()
for i in range(200):
    x, v = physics_step(x_trj[-1], v_trj[-1], dt=.05)
    x.block_until_ready()  # JAX dispatches asynchronously; wait for the result before reading the clock
    x_trj.append(x)
    v_trj.append(v)
    t = time.perf_counter()
    dt_jax.append(t - t0)
    t0 = t
x_trj = np.stack(x_trj)
v_trj = np.stack(v_trj)
print(f"jax JIT compilation: {dt_jax[0]}")
print(f"jax execution average: {np.mean(np.asarray(dt_jax[2:]))}")
jax JIT compilation: 0.14002356800006055
jax execution average: 0.010168181111110916
import time
import torch

def safe_div(x, y):
    return torch.where(y == 0, torch.zeros_like(x), x / y)

def physics_step_torch(x, v, dt: float = .05, elasticity=0.8, radius=0.03):
    x_next = x + v * dt
    deltas = x_next.unsqueeze(-2) - x_next.unsqueeze(-3)  # pairwise differences x_i - x_j
    dist = torch.sqrt(torch.clamp(torch.sum(deltas ** 2, -1), min=1e-4))  # eps=1e-4 to avoid NaN during backprop of sqrt
    rel_v = v.unsqueeze(-2) - v.unsqueeze(-3)
    dist_dir = safe_div(deltas, dist.unsqueeze(-1))
    projected_v = torch.sum(dist_dir * rel_v, -1)
    has_impact = (projected_v < 0) & (dist < 2 * radius)
    impulse = -(1 + elasticity) * 0.5 * projected_v.unsqueeze(-1) * dist_dir
    radius_sum = radius + radius  # this only supports equal radii
    impact_time = safe_div(dist - radius_sum, projected_v)
    x_inc_contrib = torch.sum(torch.where(has_impact.unsqueeze(-1), torch.clamp(impact_time.unsqueeze(-1) - dt, max=0) * impulse, torch.zeros_like(impulse)), -2)
    # out-of-place updates: `x += ...` would overwrite the input tensors stored in x_trj / v_trj
    x = x + x_inc_contrib
    v = v + torch.sum(torch.where(has_impact.unsqueeze(-1), impulse, torch.zeros_like(impulse)), -2)
    v = torch.where((x < 0) | (x > 2), -v, v)
    return x + v * dt, v

x0 = torch.rand((1000, 2)) * 2
v0 = torch.randn_like(x0)
# tracing runs the function once here, before the timer starts, so dt_torch[0] excludes the tracing time
physics_step_torch = torch.jit.trace(physics_step_torch, (x0, v0))
x_trj = [x0]
v_trj = [v0]
dt_torch = []
t0 = time.perf_counter()
for i in range(200):
    x, v = physics_step_torch(x_trj[-1], v_trj[-1])
    x_trj.append(x)
    v_trj.append(v)
    t = time.perf_counter()
    dt_torch.append(t - t0)
    t0 = t
x_trj = torch.stack(x_trj)
v_trj = torch.stack(v_trj)
dt_torch = torch.tensor(dt_torch)
print(f"torch JIT compilation: {dt_torch[0]}")
print(f"torch execution average: {dt_torch[2:].mean()}")
torch JIT compilation: 0.0365150049328804
torch execution average: 0.033097729086875916
import time
import tensorflow as tf

@tf.function
def physics_step(x, v, dt: float, elasticity=0.8, radius=0.03):
    x_next = x + v * dt
    deltas = tf.expand_dims(x_next, -2) - tf.expand_dims(x_next, -3)  # pairwise differences x_i - x_j
    dist = tf.sqrt(tf.maximum(tf.reduce_sum(deltas ** 2, -1), 1e-4))  # eps=1e-4 to avoid NaN during backprop of sqrt
    rel_v = tf.expand_dims(v, -2) - tf.expand_dims(v, -3)
    dist_dir = tf.math.divide_no_nan(deltas, tf.expand_dims(dist, -1))
    projected_v = tf.reduce_sum(dist_dir * rel_v, -1)
    has_impact = tf.logical_and(projected_v < 0, dist < 2 * radius)
    impulse = -(1 + elasticity) * 0.5 * tf.expand_dims(projected_v, -1) * dist_dir
    radius_sum = radius + radius  # this only supports equal radii
    impact_time = tf.math.divide_no_nan(dist - radius_sum, projected_v)
    x_inc_contrib = tf.reduce_sum(tf.where(tf.expand_dims(has_impact, -1), tf.minimum(tf.expand_dims(impact_time, -1) - dt, 0) * impulse, 0), -2)
    x += x_inc_contrib
    v += tf.reduce_sum(tf.where(tf.expand_dims(has_impact, -1), impulse, 0), -2)
    v = tf.where(tf.logical_or(x < 0, x > 2), -v, v)
    return x + v * dt, v

x0 = tf.random.uniform((1000, 2)) * 2
v0 = tf.random.normal(shape=x0.shape)
x_trj = [x0]
v_trj = [v0]
dt_tf = []
t0 = time.perf_counter()
for i in range(200):
    x, v = physics_step(x_trj[-1], v_trj[-1], dt=.05)
    x_trj.append(x)
    v_trj.append(v)
    t = time.perf_counter()
    dt_tf.append(t - t0)
    t0 = t
x_trj = tf.stack(x_trj)
v_trj = tf.stack(v_trj)
dt_tf = tf.constant(dt_tf)
print(f"tensorflow JIT compilation: {dt_tf[0]}")
print(f"tensorflow execution average: {tf.reduce_mean(dt_tf[2:])}")  # skip the first steps, as for the other backends
tensorflow JIT compilation: 0.20959796011447906
tensorflow execution average: 0.040268540382385254
Let's compare the JIT compilation performance first. The following numbers were captured on a 12-core AMD processor.
Library | Native | Φ-ML |
---|---|---|
PyTorch | 59 | 123 |
Jax | 152 | 165 |
TensorFlow | 187 | 938 |
Evidently, Φ-ML adds some overhead to JIT compilation with PyTorch (123 vs. 59) and only a small overhead with Jax (165 vs. 152). This is due to the additional shape-handling operations.
The overhead with TensorFlow is much larger. This is due to the way `@tf.function` is implemented: instead of simply tracing the function with placeholder tensors, it utilizes a custom Python code interpreter. This allows it to capture control flow (`if`/`else`/`for`) more accurately. However, it also means that a large part of the Φ-ML codebase is executed at a much reduced speed.
Importantly, this compilation usually needs to be performed only once. Later invocations can reuse the already-compiled code as long as the input shapes and non-tensor arguments stay the same.
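The toy below illustrates why the compilation cost is paid only once, and why Python-side work such as Φ-ML's shape handling disappears from later calls. It is a trace-and-replay "JIT" in plain Python. This is a deliberately simplified sketch, not how PyTorch, Jax or TensorFlow implement compilation:

- It only records `+` and `*`.
- It keys its cache on the number of arguments, whereas real JITs key on shapes, dtypes and static arguments.
- `toy_jit`, `Tracer` and `Graph` are hypothetical names.

```python
import operator

class Graph:
    """Recorded operations; slots 0..n_inputs-1 are inputs, later slots are op results."""
    def __init__(self, n_inputs):
        self.n_inputs = n_inputs
        self.ops = []  # (op, lhs_slot, rhs) where rhs is a slot index or ('const', value)

class Tracer:
    """Placeholder value: arithmetic on it is recorded into the graph instead of computed."""
    def __init__(self, graph, slot):
        self.graph, self.slot = graph, slot

    def _record(self, op, other):
        rhs = other.slot if isinstance(other, Tracer) else ('const', other)
        self.graph.ops.append((op, self.slot, rhs))
        return Tracer(self.graph, self.graph.n_inputs + len(self.graph.ops) - 1)

    def __add__(self, other): return self._record(operator.add, other)
    def __mul__(self, other): return self._record(operator.mul, other)
    __radd__ = __add__  # + and * are commutative, so the reflected forms can reuse them
    __rmul__ = __mul__

def replay(graph, out_slot, inputs):
    """Execute the recorded ops on concrete inputs without running the original Python code."""
    values = list(inputs)
    for op, lhs, rhs in graph.ops:
        rhs_value = rhs[1] if isinstance(rhs, tuple) else values[rhs]
        values.append(op(values[lhs], rhs_value))
    return values[out_slot]

def toy_jit(fn):
    cache = {}
    def compiled(*args):
        key = len(args)  # toy cache key; real JITs use shapes, dtypes and static arguments
        if key not in cache:  # "compile": run the Python body once on tracers
            graph = Graph(len(args))
            out = fn(*(Tracer(graph, i) for i in range(len(args))))
            cache[key] = (graph, out.slot)
        graph, out_slot = cache[key]
        return replay(graph, out_slot, args)
    return compiled

python_calls = [0]

@toy_jit
def step(x, v):
    python_calls[0] += 1  # stands in for Python-side work such as shape bookkeeping
    return x + v * 0.05

print(step(1.0, 2.0), step(3.0, 4.0), python_calls[0])
```

The counter increases only on the first call. The second call replays the recorded operations and never re-enters the Python body.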
For the execution performance, we measured the following numbers on an NVIDIA RTX 2070 SUPER (ms/step).
Library | Native | Φ-ML |
---|---|---|
PyTorch | 45.2 | 5.0 |
Jax | 19.5 | 16.1 |
TensorFlow | 23.6 | 24.3 |
Here, Φ-ML beats our native PyTorch implementation, is on par with native TensorFlow, and is slightly faster than native Jax. As discussed above, the extra Python-level code of Φ-ML is completely optimized out during JIT compilation, resulting in similar compiled code.
The fact that Φ-ML is faster than native PyTorch may be due to some inefficiency in the native PyTorch implementation above. This implementation was largely generated by ChatGPT based on our function, but we had to make some changes for it to output the correct result. If you find a problem with the code, please let us know!
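One way to check any of the backend versions is to compare it against a plain NumPy transcription on a small case that can be verified by hand. The sketch below, `physics_step_numpy`, is our own hypothetical transcription of the native implementations above. It assumes a single unbatched `(N, 2)` array and equal radii. Because the clamped distance is at least 0.01, `dist_dir` needs no safe division here.

```python
import numpy as np

def physics_step_numpy(x, v, dt, elasticity=0.8, radius=0.03):
    """NumPy transcription of the native physics_step (unbatched (N, 2) arrays, equal radii)."""
    x_next = x + v * dt
    deltas = x_next[:, None, :] - x_next[None, :, :]  # (N, N, 2): x_i - x_j
    dist = np.sqrt(np.maximum(np.sum(deltas ** 2, -1), 1e-4))  # same clamp as the native versions
    rel_v = v[:, None, :] - v[None, :, :]
    dist_dir = deltas / dist[..., None]  # dist >= 0.01, so no division by zero
    projected_v = np.sum(dist_dir * rel_v, -1)
    has_impact = (projected_v < 0) & (dist < 2 * radius)
    impulse = -(1 + elasticity) * 0.5 * projected_v[..., None] * dist_dir
    with np.errstate(divide='ignore', invalid='ignore'):  # mimic safe_div: 0 where projected_v == 0
        impact_time = np.where(projected_v == 0, 0.0, (dist - 2 * radius) / projected_v)
    x_corr = np.where(has_impact[..., None], np.minimum(impact_time[..., None] - dt, 0) * impulse, 0).sum(-2)
    x = x + x_corr
    v = v + np.where(has_impact[..., None], impulse, 0).sum(-2)
    v = np.where((x < 0) | (x > 2), -v, v)  # reflect velocity components at the table walls
    return x + v * dt, v

# Two spheres approaching head-on along the x-axis
x0 = np.array([[1.0, 1.0], [1.05, 1.0]])
v0 = np.array([[1.0, 0.0], [-1.0, 0.0]])
x1, v1 = physics_step_numpy(x0, v0, dt=0.01)
print(v1)
```

For this head-on case, the velocities of ±1 become ∓0.8 (elasticity 0.8) and the total momentum stays zero. The same inputs can be fed to a backend version and the results compared with `np.allclose`.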
As discussed, optimal performance requires the use of just-in-time compilation.
ΦML is compatible with PyTorch, TensorFlow, and Jax. Which backend and device you use has a major impact on performance.