ΦML Performance¶

Colab   •   🌐 ΦML   •   📖 Documentation   •   🔗 API   •   ▶ Videos   •   Examples

In [1]:
%%capture
!pip install phiml

Performance and JIT-compilation¶

ΦML provides many conveniences that make your code more concise and expressive, and it performs various checks and optimizations under the hood. This comes at the cost of added overhead for dimension matching, reshaping and similar bookkeeping. However, this overhead is eliminated once your code is JIT-compiled with PyTorch, TensorFlow or Jax.
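
As a minimal sketch of this pattern (the function and variable names here are purely illustrative; the full example follows below):

```python
from phiml.math import math, instance, channel

math.use('torch')                # select the backend that will trace and compile the function

@math.jit_compile                # first call traces the function, later calls reuse the compiled code
def scale(x, factor: float = 2.):
    return x * factor            # dimension matching happens during tracing only

x = math.random_uniform(instance(points=4), channel(vector='x,y'))
scale(x)   # traces + compiles for this backend and input shape
scale(x)   # runs the already-compiled code
```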

In this notebook, we measure the performance difference on a rigid-body simulation inspired by billiards: spheres move on a table and can collide with the boundary as well as with each other. The function below defines the physics step using the Φ-ML API.

In [2]:
import time
from phiml.math import math, rename_dims, instance, dual, iterate, batch, channel


@math.jit_compile
def physics_step(x, v, dt: float, elasticity=0.8, radius=.03):
    x_next = x + v * dt
    deltas = -math.pairwise_distances(x_next)
    dist = math.vec_length(deltas, eps=1e-4)  # eps to avoid NaN during backprop of sqrt
    rel_v = -math.pairwise_distances(v)
    dist_dir = math.safe_div(deltas, dist)
    projected_v = dist_dir.vector * rel_v.vector
    has_impact = (projected_v < 0) & (dist < 2 * radius)
    impulse = -(1 + elasticity) * .5 * projected_v * dist_dir
    radius_sum = radius + rename_dims(radius, instance, dual)
    impact_time = math.safe_div(dist - radius_sum, projected_v)
    x_inc_contrib = math.sum(math.where(has_impact, math.minimum(impact_time - dt, 0) * impulse, 0), dual)
    x += x_inc_contrib
    v += math.sum(math.where(has_impact, impulse, 0), dual)
    v = math.where((x < 0) | (x > 2), -v, v)
    return x + v * dt, v


for backend in ['torch', 'jax', 'tensorflow']:
    math.use(backend)
    x0 = math.random_uniform(instance(points=1000), channel(vector='x,y'), high=2)
    v0 = math.random_normal(x0.shape)
    (x_trj, v_trj), dt = iterate(physics_step, batch(t=200), x0, v0, f_kwargs={'dt': .05}, measure=time.perf_counter)
    print(f"Φ-ML + {backend} JIT compilation: {dt.t[0]}")
    print(f"Φ-ML + {backend} execution average: {dt.t[2:].mean} +- {dt.t[2:].std}")
Φ-ML + torch JIT compilation: float64 0.17295930900002077
Φ-ML + torch execution average: 0.03068203723737388 +- 0.0015217601792174532
Φ-ML + jax JIT compilation: float64 0.1507311049999771
Φ-ML + jax execution average: 0.010426139373737327 +- 0.004407840839214921
Φ-ML + tensorflow JIT compilation: float64 14.91661241600002
Φ-ML + tensorflow execution average: 0.07933342476262613 +- 0.0007065608764509045

The performance numbers printed above were measured on GitHub Actions and may fluctuate a lot, depending on the allotted CPU processing power. We provide reference GPU numbers below.
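
If you benchmark this on your own machine, it helps to discard the first iterations (which include tracing and compilation) and to average over several trials. A possible sketch, reusing `physics_step`, `x0` and `v0` from the cell above (the helper name `benchmark_rollout` is ours):

```python
import time
import numpy
from phiml.math import iterate, batch

def benchmark_rollout(n_trials=5, n_steps=200):
    """Average per-step execution time over several rollouts, excluding warm-up steps."""
    per_step = []
    for _ in range(n_trials):
        (_, _), dt = iterate(physics_step, batch(t=n_steps), x0, v0,
                             f_kwargs={'dt': .05}, measure=time.perf_counter)
        per_step.append(float(dt.t[2:].mean))  # skip the first two (compile / warm-up) steps
    return numpy.mean(per_step), numpy.std(per_step)
```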

Native Implementations¶

Next, let's implement the same simulation natively, i.e. without Φ-ML, in order to compare performance.

In [3]:
import time
import jax
from jax import numpy as np


def safe_div(x, y):
    return np.where(y == 0, 0, x / y)


@jax.jit
def physics_step(x, v, dt: float, elasticity=0.8, radius=.03):
    x_next = x + v * dt
    deltas = x_next[..., None, :] - x_next[..., None, :, :]
    dist = np.sqrt(np.maximum(np.sum(deltas ** 2, -1), 1e-4))  # eps=1e-4 to avoid NaN during backprop of sqrt
    rel_v = v[..., None, :] - v[..., None, :, :]
    dist_dir = safe_div(deltas, dist[..., None])
    projected_v = np.sum(dist_dir * rel_v, -1)
    has_impact = (projected_v < 0) & (dist < 2 * radius)
    impulse = -(1 + elasticity) * .5 * projected_v[..., None] * dist_dir
    radius_sum = radius + radius  # this only supports equal radii
    impact_time = safe_div(dist - radius_sum, projected_v)
    x_inc_contrib = np.sum(np.where(has_impact[..., None], np.minimum(impact_time[..., None] - dt, 0) * impulse, 0), -2)
    x += x_inc_contrib
    v += np.sum(np.where(has_impact[..., None], impulse, 0), -2)
    v = np.where((x < 0) | (x > 2), -v, v)
    return x + v * dt, v


x0 = jax.random.uniform(jax.random.PRNGKey(0), shape=(1000, 2)) * 2
v0 = jax.random.normal(jax.random.PRNGKey(1), shape=x0.shape)
x_trj = [x0]
v_trj = [v0]
dt_jax = []
t0 = time.perf_counter()
for i in range(200):
    x, v = physics_step(x_trj[-1], v_trj[-1], dt=.05)
    x_trj.append(x)
    v_trj.append(v)
    t = time.perf_counter()
    dt_jax.append(t - t0)
    t0 = t
x_trj = np.stack(x_trj)
v_trj = np.stack(v_trj)
print(f"jax JIT compilation: {dt_jax[0]}")
print(f"jax execution average: {np.mean(np.asarray(dt_jax[2:]))}")
jax JIT compilation: 0.13584853800000474
jax execution average: 0.009756988762626062
In [4]:
import time
import torch


def safe_div(x, y):
    return torch.where(y == 0, torch.zeros_like(x), x / y)


def physics_step_torch(x, v, dt: float = .05, elasticity=0.8, radius=0.03):
    x_next = x + v * dt
    deltas = x_next.unsqueeze(-2) - x_next.unsqueeze(-3)
    dist = torch.sqrt(torch.maximum(torch.sum(deltas ** 2, -1), torch.tensor(1e-4)))  # eps=1e-4 to avoid NaN during backprop of sqrt
    rel_v = v.unsqueeze(-2) - v.unsqueeze(-3)
    dist_dir = safe_div(deltas, dist.unsqueeze(-1))
    projected_v = torch.sum(dist_dir * rel_v, -1)
    has_impact = (projected_v < 0) & (dist < 2 * radius)
    impulse = -(1 + elasticity) * 0.5 * projected_v.unsqueeze(-1) * dist_dir
    radius_sum = radius + radius  # this only supports equal radii
    impact_time = safe_div(dist - radius_sum, projected_v)
    x_inc_contrib = torch.sum(torch.where(has_impact.unsqueeze(-1), torch.minimum(impact_time.unsqueeze(-1) - dt, torch.tensor(0.0)) * impulse, torch.tensor(0.0)), -2)
    x += x_inc_contrib
    v += torch.sum(torch.where(has_impact.unsqueeze(-1), impulse, torch.tensor(0.0)), -2)
    v = torch.where((x < 0) | (x > 2), -v, v)
    return x + v * dt, v


x0 = torch.rand((1000, 2)) * 2
v0 = torch.randn_like(x0)
physics_step_torch = torch.jit.trace(physics_step_torch, [x0, v0])
x_trj = [x0]
v_trj = [v0]
dt_torch = []
t0 = time.perf_counter()
for i in range(200):
    x, v = physics_step_torch(x_trj[-1], v_trj[-1])
    x_trj.append(x)
    v_trj.append(v)
    t = time.perf_counter()
    dt_torch.append(t - t0)
    t0 = t
x_trj = torch.stack(x_trj)
v_trj = torch.stack(v_trj)
dt_torch = torch.tensor(dt_torch)
print(f"torch JIT compilation: {dt_torch[0]}")
print(f"torch execution average: {torch.mean(torch.tensor(dt_torch[2:]))}")
torch JIT compilation: 0.03101467154920101
torch execution average: 0.026800286024808884
In [5]:
import time
import tensorflow as tf

@tf.function
def physics_step(x, v, dt: float, elasticity=0.8, radius=0.03):
    x_next = x + v * dt
    deltas = tf.expand_dims(x_next, -2) - tf.expand_dims(x_next, -3)
    dist = tf.sqrt(tf.maximum(tf.reduce_sum(deltas ** 2, -1), 1e-4))  # eps=1e-4 to avoid NaN during backprop of sqrt
    rel_v = tf.expand_dims(v, -2) - tf.expand_dims(v, -3)
    dist_dir = tf.math.divide_no_nan(deltas, tf.expand_dims(dist, -1))
    projected_v = tf.reduce_sum(dist_dir * rel_v, -1)
    has_impact = tf.logical_and(projected_v < 0, dist < 2 * radius)
    impulse = -(1 + elasticity) * 0.5 * tf.expand_dims(projected_v, -1) * dist_dir
    radius_sum = radius + radius  # this only supports equal radii
    impact_time = tf.math.divide_no_nan(dist - radius_sum, projected_v)
    x_inc_contrib = tf.reduce_sum(tf.where(tf.expand_dims(has_impact, -1), tf.minimum(tf.expand_dims(impact_time, -1) - dt, 0) * impulse, 0), -2)
    x += x_inc_contrib
    v += tf.reduce_sum(tf.where(tf.expand_dims(has_impact, -1), impulse, 0), -2)
    v = tf.where(tf.logical_or(x < 0, x > 2), -v, v)
    return x + v * dt, v

x0 = tf.random.uniform((1000, 2)) * 2
v0 = tf.random.normal(shape=x0.shape)
x_trj = [x0]
v_trj = [v0]
dt_tf = []
t0 = time.perf_counter()
for i in range(200):
    x, v = physics_step(x_trj[-1], v_trj[-1], dt=.05)
    x_trj.append(x)
    v_trj.append(v)
    t = time.perf_counter()
    dt_tf.append(t - t0)
    t0 = t
x_trj = tf.stack(x_trj)
v_trj = tf.stack(v_trj)
dt_tf = tf.constant(dt_tf)
print(f"tensorflow JIT compilation: {dt_tf[0]}")
print(f"tensorflow execution average: {tf.reduce_mean(dt_tf)}")
tensorflow JIT compilation: 0.17361828684806824
tensorflow execution average: 0.04744492471218109

Summary¶

Let's compare the JIT compilation performance first. The following numbers were captured on a 12-core AMD processor.

| Library | Native | Φ-ML |
| --- | --- | --- |
| PyTorch | 59 | 123 |
| Jax | 152 | 165 |
| TensorFlow | 187 | 938 |

Evidently, Φ-ML adds a small overhead to JIT compilation in PyTorch and Jax, due to the additional shape-handling operations. The overhead with TensorFlow is much larger. This is due to the way @tf.function is implemented: rather than simply tracing the function with placeholder tensors, it runs it through a custom Python code interpreter. While this allows it to capture control flow (if/else/for) more accurately, it also means that a large part of the Φ-ML codebase is executed at much reduced speed during tracing. Importantly, this compilation usually needs to be performed only once; all later invocations call the already-compiled code.
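
The trace-once behavior can be illustrated with a tiny, generic example (not specific to Φ-ML): the Python body of a `@tf.function` only runs while the function is being traced; later calls with a compatible input signature reuse the compiled graph.

```python
import tensorflow as tf

@tf.function
def double(x):
    print("tracing")  # Python side effects only happen during tracing
    return x * 2

double(tf.constant(1.))  # prints "tracing", builds the graph
double(tf.constant(3.))  # same signature: reuses the graph, no retracing
```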

For the execution performance, we measured the following numbers on an NVIDIA RTX 2070 SUPER (ms/step).

| Library | Native | Φ-ML |
| --- | --- | --- |
| PyTorch | 45.2 | 5.0 |
| Jax | 19.5 | 16.1 |
| TensorFlow | 23.6 | 24.3 |

Here, Φ-ML beats our native PyTorch implementation and is on par with TensorFlow and Jax. As discussed above, the extra operations performed by Φ-ML are optimized out during JIT compilation, resulting in similar compiled code.

The fact that Φ-ML is faster than PyTorch may be due to some inefficiency in the native PyTorch implementation above. That implementation was largely generated by ChatGPT based on our function, and we had to make some changes for it to produce the correct result. If you find a problem with the code, please let us know!
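
Also note that GPU timings are only meaningful if all asynchronously dispatched kernels have finished before the clock is read. A hedged sketch of how one might enforce this per backend (the helper `synchronize` is our own, not part of any of the libraries):

```python
import torch
import jax
import tensorflow as tf

def synchronize(backend: str, result=None):
    """Block until queued GPU work has finished so that wall-clock timings are accurate."""
    if backend == 'torch' and torch.cuda.is_available():
        torch.cuda.synchronize()
    elif backend == 'jax' and result is not None:
        jax.block_until_ready(result)   # Jax dispatches asynchronously
    elif backend == 'tensorflow' and result is not None:
        result.numpy()                  # forcing a host copy waits for the computation
```

In the rollout loops above, one would call this on the outputs of each step (or once at the end of the rollout) before reading `time.perf_counter()`.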

Further Reading¶

As discussed, optimal performance requires the use of just-in-time compilation.

ΦML is compatible with PyTorch, TensorFlow, and Jax. Which backend and device you use has a major impact on performance.

🌐 ΦML   •   📖 Documentation   •   🔗 API   •   ▶ Videos   •   Examples