Billiards in ΦFlow

This notebook reproduces the billiards example described in the DiffTaichi paper using ΦFlow.

In [1]:
# !pip install phiflow

The import statement determines which backend ΦFlow uses: TensorFlow, PyTorch, or JAX.

In [2]:
# from phi.tf.flow import *
from phi.torch.flow import *
# from phi.jax.stax.flow import *

First, we create the typical billiards triangle with four layers (10 balls) and a ball radius of 0.03. We use the ΦFlow geometry class Sphere to represent the balls.

In [3]:
def billiards_triangle(billiard_layers=4, radius=.03):
    coords = []
    for i in range(billiard_layers):
        for j in range(i + 1):
            coords.append(vec(x=i * 2 * radius + 0.5, y=j * 2 * radius + 0.5 - i * radius * 0.7))
    return Sphere(stack(coords, instance('balls')), radius=radius)
  
plot(billiards_triangle())
Out[3]:
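To see what billiards_triangle lays out, the stdlib-only sketch below recomputes the same coordinates as plain tuples instead of ΦFlow tensors (triangle_coords is a hypothetical helper, not part of ΦFlow) and checks the ball count and the closest spacing between centers. Balls within a layer touch exactly (spacing 2 × radius = 0.06), while the closest centers in adjacent layers are about 0.064 apart, so the initial configuration has no overlaps.

```python
import math
from itertools import combinations

def triangle_coords(layers=4, radius=0.03):
    """Hypothetical helper: same (x, y) formula as billiards_triangle, as plain tuples."""
    coords = []
    for i in range(layers):        # layer index, increasing along x
        for j in range(i + 1):     # ball index within the layer
            x = i * 2 * radius + 0.5
            y = j * 2 * radius + 0.5 - i * radius * 0.7
            coords.append((x, y))
    return coords

coords = triangle_coords()
# Smallest center-to-center distance over all pairs of balls
min_dist = min(math.dist(a, b) for a, b in combinations(coords, 2))
print(len(coords), round(min_dist, 4))  # 10 0.06
```

In general, n layers give n(n + 1) / 2 balls.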

Next, we define the dynamics, which consist of linear motion and collisions with an elasticity of 0.8. From here on, we represent the system as a phi.field.PointCloud object that stores the spheres and their velocities. Below, we test the step with a simple collision between two balls.

In [4]:
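def physics_step(v: PointCloud, dt: float, elasticity=0.8):
    # Tentatively advance all balls by one explicit Euler step to detect upcoming overlaps
    v_next = advect.points(v, v, dt)
    # Pairwise displacement vectors between the tentative positions (dims: balls, others, vector)
    dist = v_next.points - rename_dims(v_next.points, 'balls', 'others')
    dist_norm = math.vec_length(dist, eps=1e-4)  # eps to avoid NaN during backprop of sqrt
    # Pairwise relative velocities and unit contact normals
    rel_v = v.values - rename_dims(v.values, 'balls', 'others')
    dist_dir = math.safe_div(dist, dist_norm)
    # Relative velocity along the contact normal; negative means the pair is approaching
    projected_v = dist_dir.vector * rel_v.vector
    # A collision occurs when two balls approach each other and overlap after the tentative step
    has_impact = (projected_v < 0) & (dist_norm < 2 * v.geometry.radius)
    # Velocity change along the normal for equal masses, split evenly between both balls
    impulse = -(1 + elasticity) * .5 * projected_v * dist_dir
    # Estimated time of impact, used to correct the positions of colliding balls
    radius_sum = v.geometry.radius + rename_dims(v.geometry.radius, 'balls', 'others')
    impact_time = math.safe_div(dist_norm - radius_sum, projected_v)
    x_inc_contrib = math.sum(math.where(has_impact, math.minimum(impact_time - dt, 0) * impulse, 0), 'others')
    # Apply the position correction, then add the collision impulses to the velocities
    v = v.with_elements(v.geometry.shifted(x_inc_contrib))
    v += math.sum(math.where(has_impact, impulse, 0), 'others')
    # Advance with the post-collision velocities
    return advect.points(v, v, dt)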

# Simple animated test for physics_step
balls = Sphere(tensor([(0, .03), (.5, 0)], instance('balls'), channel(vector='x,y')), radius=.03)
ball_v = PointCloud(balls, tensor([(1., 0), (0, 0)], shape(balls)))
trj = iterate(physics_step, batch(t=20), ball_v, f_kwargs={'dt': .05})
plot(trj.geometry, animate='t')
/tmp/ipykernel_2099/1958517515.py:4: DeprecationWarning: phiml.math.length is deprecated in favor of phiml.math.norm
  dist_norm = math.vec_length(dist, eps=1e-4)  # eps to avoid NaN during backprop of sqrt
Out[4]:
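The velocity update in physics_step can be checked by hand for a single pair. The stdlib-only sketch below (pair_impulse is a hypothetical helper) applies the same formula, impulse = -(1 + elasticity) * 0.5 * projected_v * dist_dir, to a head-on collision between a moving ball and a resting ball of equal mass. It is simplified: it uses the current positions instead of the tentative Euler step and omits the overlap test and the position correction.

```python
import math

def pair_impulse(x_i, x_j, v_i, v_j, elasticity=0.8):
    """Hypothetical helper: velocity change of ball i caused by ball j (equal masses)."""
    dist = (x_i[0] - x_j[0], x_i[1] - x_j[1])
    norm = math.hypot(*dist)
    direction = (dist[0] / norm, dist[1] / norm)   # unit normal pointing from j to i
    rel_v = (v_i[0] - v_j[0], v_i[1] - v_j[1])
    projected_v = direction[0] * rel_v[0] + direction[1] * rel_v[1]
    if projected_v >= 0:                           # separating or resting: no impulse
        return (0.0, 0.0)
    scale = -(1 + elasticity) * 0.5 * projected_v
    return (scale * direction[0], scale * direction[1])

# Ball a moves right at speed 1 and touches ball b, which is at rest
xa, xb = (0.0, 0.0), (0.06, 0.0)
va, vb = (1.0, 0.0), (0.0, 0.0)
da = pair_impulse(xa, xb, va, vb)
db = pair_impulse(xb, xa, vb, va)
va_new = (va[0] + da[0], va[1] + da[1])
vb_new = (vb[0] + db[0], vb[1] + db[1])
print(va_new, vb_new)
```

After the collision, ball a moves at about 0.1 and ball b at 0.9: momentum is conserved, and the approach speed of 1 turns into a separation speed of 0.8, i.e. the elasticity times the approach speed. With elasticity=1 the two velocities would simply swap.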

Now we can define our objective. Following the DiffTaichi paper, we want the last ball of the triangle to reach the target position (x=0.9, y=0.75). The loss function sets up a system of the 10 balls in triangle formation plus the cue ball, whose initial position and velocity are specified via the parameters x0 and v0. The simulation runs for 1024 time steps of length dt = 0.003 before the $L^2$ loss of the distance between the last ball and the goal is computed and returned.

In [5]:
def loss_function(x0: Tensor, v0: Tensor, goal=vec(x=0.9, y=0.75), steps=1024):
    triangle_balls = PointCloud(billiards_triangle()) * (0, 0)
    controllable_ball = PointCloud(Sphere(expand(x0, instance(triangle_balls).with_size(1)), radius=triangle_balls.geometry.radius)) * v0
    all_balls = controllable_ball & triangle_balls
    trj = iterate(physics_step, batch(t=steps), all_balls, f_kwargs={'dt': 0.003})
    return math.l2_loss(trj.t[-1].balls[-1] - goal), trj
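For orientation, the loss setup involves some simple arithmetic: 1024 steps of dt = 0.003 cover about 3.07 seconds of simulated time, and the loss penalizes the distance between the last ball and the goal. The plain-Python sketch below writes this out. It assumes that math.l2_loss follows phiml's convention of half the sum of squares, as we understand its documentation; half_squared_distance is a hypothetical helper.

```python
import math

DT = 0.003        # time step used in loss_function
STEPS = 1024      # number of simulation steps in loss_function
GOAL = (0.9, 0.75)

def half_squared_distance(position, goal=GOAL):
    """Hypothetical helper: 0.5 * ||position - goal||^2 (assumed l2_loss convention)."""
    return 0.5 * sum((p - g) ** 2 for p, g in zip(position, goal))

simulated_time = STEPS * DT
print(f"simulated time: {simulated_time:.3f} s")  # simulated time: 3.072 s
print(half_squared_distance((0.0, 0.0)))          # loss for a ball at the origin
```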

Let's plot one simulated trajectory! We only show every 16th frame.

In [6]:
plot(loss_function(x0=vec(x=.1, y=.5), v0=vec(x=.3, y=0))[1].t[::16].geometry, animate='t')
Out[6]:

Now we can plot the loss landscape as a function of the angle $\alpha$ of v0, keeping the speed fixed at 0.3. We create a scan plot by sampling a CenteredGrid with 1000 cells over $\alpha \in [-\pi/4, \pi/4]$ from a function. The resulting plot matches Figure 11 of the DiffTaichi paper up to small deviations. We noticed that the loss landscape depends strongly on the number of simulation steps.

In [7]:
x0 = vec(x=.1, y=.5)
v0 = lambda alpha: vec(x=0.3 * math.cos(alpha), y=0.3 * math.sin(alpha))
plot(CenteredGrid(lambda alpha: loss_function(x0, v0(alpha))[0], alpha=1000, bounds=Box(alpha=(-PI/4, PI/4))))
Out[7]:
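The scan in the previous cell evaluates the loss once per grid cell. Assuming CenteredGrid samples its function at cell centers, the 1000 angles and the corresponding initial velocities can be listed with the stdlib sketch below. scan_angles is a hypothetical helper; v0 mirrors the notebook's lambda, so every velocity keeps the speed 0.3 and only its direction changes.

```python
import math

def scan_angles(n=1000, lower=-math.pi / 4, upper=math.pi / 4):
    """Hypothetical helper: centers of n equal cells spanning [lower, upper]."""
    width = (upper - lower) / n
    return [lower + (k + 0.5) * width for k in range(n)]

def v0(alpha, speed=0.3):
    # Same parameterization as the notebook: fixed speed, direction given by alpha
    return (speed * math.cos(alpha), speed * math.sin(alpha))

alphas = scan_angles()
velocities = [v0(a) for a in alphas]
print(len(alphas), alphas[0], alphas[-1])
```

Adjacent samples are π/2000 ≈ 0.0016 rad (about 0.09°) apart, which is the finest structure in the landscape this scan can resolve.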

Next, we perform plain gradient descent on x0 and v0. Because the loss landscape is chaotic, this performs rather poorly: over five iterations the loss fluctuates between about 0.620 and 0.650 instead of decreasing steadily, and the final loss is 0.617. We plot the resulting trajectory after 5 iterations along with the goal point (X).

In [8]:
loss_grad = math.gradient(loss_function, 'x0,v0')
x0 = vec(x=.1, y=.5)
v0 = vec(x=.3, y=0)
learning_rate = .01
for it in range(5):  # avoid shadowing the built-in iter()
    (loss, trj), (dx0, dv0) = loss_grad(x0, v0)
    print(f"Iter={it}  loss={loss:.3f}  x0={x0}  ∇={dx0}  v0={v0}  ∇={dv0}")
    x0 -= learning_rate * dx0
    v0 -= learning_rate * dv0
final_loss, trj = loss_function(x0, v0)
print(f"Final loss: {final_loss}")
plot(vis.overlay(trj.t[::16].geometry, vec(x=0.9, y=0.75)), animate='t')
/opt/hostedtoolcache/Python/3.12.9/x64/lib/python3.12/site-packages/phiml/math/_functional.py:630: RuntimeWarning: Using torch for gradient computation because numpy does not support jacobian()
  warnings.warn(f"Using {math.default_backend()} for gradient computation because {key.backend} does not support jacobian()", RuntimeWarning)
Iter=0  loss=0.6501343  x0=(x=0.100, y=0.500) float64  ∇=(x=-0.654, y=1.529)  v0=(x=0.300, y=0.000)  ∇=(x=-0.889, y=1.734)
Iter=1  loss=0.6250768  x0=(x=0.107, y=0.485)  ∇=(x=-0.656, y=-0.973)  v0=(x=0.309, y=-0.017)  ∇=(x=-0.980, y=-1.262)
Iter=2  loss=0.63563234  x0=(x=0.113, y=0.494)  ∇=(x=-1.158, y=0.959)  v0=(x=0.319, y=-0.005)  ∇=(x=-1.381, y=0.953)
Iter=3  loss=0.6203982  x0=(x=0.125, y=0.485)  ∇=(x=-1.096, y=-0.586)  v0=(x=0.333, y=-0.014)  ∇=(x=-1.306, y=-0.703)
Iter=4  loss=0.6252239  x0=(x=0.136, y=0.491)  ∇=(x=-1.408, y=0.618)  v0=(x=0.346, y=-0.007)  ∇=(x=-1.462, y=0.491)
Final loss: 0.6168283
Out[8]:
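The optimization loop above is plain gradient descent, x ← x − η∇L with η = 0.01. To show how small these steps are, the stdlib sketch below applies the same update for 5 iterations to a smooth toy loss L(x) = x², a stand-in chosen purely for illustration and not the billiards loss. Even here, each step only shrinks x by a factor 1 − 2η = 0.98.

```python
def gradient_descent(grad, x, learning_rate=0.01, iterations=5):
    """Plain gradient descent: repeatedly step against the gradient."""
    history = [x]
    for _ in range(iterations):
        x = x - learning_rate * grad(x)
        history.append(x)
    return history

# Toy loss L(x) = x**2 with gradient 2x (illustrative stand-in only)
history = gradient_descent(lambda x: 2 * x, x=1.0)
print([round(h, 4) for h in history])
```

In the logged run, the gradient components have magnitudes between about 0.5 and 1.7, so with η = 0.01 each coordinate of x0 and v0 moves by less than 0.02 per iteration, consistent with the small changes in the printed values.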