This notebook reproduces the billiards example described in the DiffTaichi paper using ΦFlow.
# !pip install phiflow
The import statement selects the backend: TensorFlow, PyTorch or JAX.
# from phi.tf.flow import *
from phi.torch.flow import *
# from phi.jax.stax.flow import *
First, we create a billiards triangle with four layers (10 balls) and a ball radius of 0.03.
We use the ΦFlow geometry class Sphere to represent the balls.
def billiards_triangle(billiard_layers=4, radius=.03):
    coords = []
    for i in range(billiard_layers):
        for j in range(i + 1):
            coords.append(vec(x=i * 2 * radius + 0.5, y=j * 2 * radius + 0.5 - i * radius * 0.7))
    return Sphere(stack(coords, instance('balls')), radius=radius)
plot(billiards_triangle())
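To check the layout without ΦFlow, the same coordinate formula can be evaluated in plain Python. In this sketch, triangle_coords is a hypothetical helper, not part of ΦFlow. Layer i holds i + 1 balls, so four layers give 1 + 2 + 3 + 4 = 10 balls.

```python
def triangle_coords(layers=4, radius=0.03):
    """Ball centers from the same formula as billiards_triangle (hypothetical helper)."""
    coords = []
    for i in range(layers):          # layer i is shifted right by one ball diameter
        for j in range(i + 1):       # layer i holds i + 1 balls
            x = i * 2 * radius + 0.5
            y = j * 2 * radius + 0.5 - i * radius * 0.7
            coords.append((x, y))
    return coords


coords = triangle_coords()
print(len(coords))  # 10
print(coords[0])    # (0.5, 0.5)
```

The last entry, at roughly (0.68, 0.617), is the ball that the objective defined later tries to move to the target.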
Next, we define the dynamics, consisting of linear motion and collisions with an elasticity (coefficient of restitution) of 0.8.
From here on, we represent the system as a phi.field.PointCloud object that stores the spheres and their velocities.
Below, we test the step with a simple pair collision.
def physics_step(v: PointCloud, dt: float, elasticity=0.8):
    # Tentative positions after free flight, used to detect contacts
    v_next = advect.points(v, v, dt)
    # Pairwise separation vectors between all balls ('balls' x 'others')
    dist = v_next.points - rename_dims(v_next.points, 'balls', 'others')
    # Newer ΦML versions emit a DeprecationWarning here, recommending math.norm
    dist_norm = math.vec_length(dist, eps=1e-4)  # eps to avoid NaN during backprop of sqrt
    rel_v = v.values - rename_dims(v.values, 'balls', 'others')
    dist_dir = math.safe_div(dist, dist_norm)
    # Relative velocity along the line of centers; negative means the pair is approaching
    projected_v = dist_dir.vector * rel_v.vector
    has_impact = (projected_v < 0) & (dist_norm < 2 * v.geometry.radius)
    # Velocity change per ball for equal masses and the given elasticity
    impulse = -(1 + elasticity) * .5 * projected_v * dist_dir
    radius_sum = v.geometry.radius + rename_dims(v.geometry.radius, 'balls', 'others')
    impact_time = math.safe_div(dist_norm - radius_sum, projected_v)
    # Position correction based on the estimated impact time
    x_inc_contrib = math.sum(math.where(has_impact, math.minimum(impact_time - dt, 0) * impulse, 0), 'others')
    v = v.with_elements(v.geometry.shifted(x_inc_contrib))
    v += math.sum(math.where(has_impact, impulse, 0), 'others')
    return advect.points(v, v, dt)
# Simple animated test for physics_step
balls = Sphere(tensor([(0, .03), (.5, 0)], instance('balls'), channel(vector='x,y')), radius=.03)
ball_v = PointCloud(balls, tensor([(1., 0), (0, 0)], shape(balls)))
trj = iterate(physics_step, batch(t=20), ball_v, f_kwargs={'dt': .05})
plot(trj.geometry, animate='t')
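The impulse in physics_step can be checked by hand for a single pair. The following plain-Python sketch (collide_pair is a hypothetical helper) mirrors only the contact test and velocity update for two equal-mass balls; it omits the impact-time position correction and the final advection. For a head-on hit with elasticity 0.8, the relative velocity flips sign and shrinks to 80 %, while total momentum is conserved.

```python
import math


def collide_pair(p1, p2, v1, v2, dt, radius, elasticity=0.8):
    """Return updated velocities (v1, v2) of two equal-mass balls.

    Mirrors the contact test and impulse of physics_step (hypothetical helper).
    Positions and velocities are (x, y) tuples.
    """
    # Tentative positions after free flight, like v_next in physics_step
    q1 = (p1[0] + v1[0] * dt, p1[1] + v1[1] * dt)
    q2 = (p2[0] + v2[0] * dt, p2[1] + v2[1] * dt)
    dx, dy = q1[0] - q2[0], q1[1] - q2[1]
    dist = math.hypot(dx, dy)
    if dist == 0.0:
        return v1, v2  # coincident centers: no contact normal, skipped in this sketch
    nx, ny = dx / dist, dy / dist              # unit normal pointing from ball 2 to ball 1
    rvx, rvy = v1[0] - v2[0], v1[1] - v2[1]    # relative velocity of ball 1
    approach = nx * rvx + ny * rvy             # negative when the balls move toward each other
    if approach < 0 and dist < 2 * radius:
        j = -(1 + elasticity) * 0.5 * approach  # velocity change per ball along the normal
        v1 = (v1[0] + j * nx, v1[1] + j * ny)
        v2 = (v2[0] - j * nx, v2[1] - j * ny)
    return v1, v2


# Head-on hit: ball 1 moves right into resting ball 2
v1, v2 = collide_pair((0.0, 0.0), (0.05, 0.0), (1.0, 0.0), (0.0, 0.0), dt=0.01, radius=0.03)
print(v1, v2)  # v1 ≈ (0.1, 0.0), v2 ≈ (0.9, 0.0)
```

Balls that are already separating keep their velocities, which corresponds to the projected_v < 0 condition in has_impact.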
Now we can define our objective. Following the DiffTaichi paper, we want the last ball in the triangle to reach the target position (x=0.9, y=0.75).
The loss function sets up our system, consisting of the 10 balls in triangle formation plus the cue ball. The initial position and velocity of the cue ball are specified via the parameters x0 and v0.
The simulation runs for 1024 time steps of size dt = 0.003 before the $L^2$ loss of the distance to the target is computed and returned.
def loss_function(x0: Tensor, v0: Tensor, goal=vec(x=0.9, y=0.75), steps=1024):
    # Resting balls in triangle formation (zero velocity)
    triangle_balls = PointCloud(billiards_triangle()) * (0, 0)
    # Cue ball at x0 with velocity v0 and the same radius as the triangle balls
    controllable_ball = PointCloud(Sphere(expand(x0, instance(triangle_balls).with_size(1)), radius=triangle_balls.geometry.radius)) * v0
    # The cue ball comes first, so balls[-1] is the last ball of the triangle
    all_balls = controllable_ball & triangle_balls
    trj = iterate(physics_step, batch(t=steps), all_balls, f_kwargs={'dt': 0.003})
    return math.l2_loss(trj.t[-1].balls[-1] - goal), trj
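For a sense of scale, here are the two ingredients of the objective in plain Python: the simulated duration (1024 steps of dt = 0.003) and an $L^2$ loss. ΦML documents math.l2_loss as half the sum of squares over non-batch dimensions; this sketch assumes that convention, and its l2_loss is a hypothetical stand-in, not the ΦML function.

```python
def l2_loss(residual):
    """Half the sum of squared components (assumed to match ΦML's math.l2_loss convention)."""
    return 0.5 * sum(c * c for c in residual)


steps, dt = 1024, 0.003
duration = steps * dt  # total simulated time in simulation units, about 3.07

# Residual between a hypothetical final position and the goal (x=0.9, y=0.75)
final_pos = (0.6, 0.35)
goal = (0.9, 0.75)
loss = l2_loss([f - g for f, g in zip(final_pos, goal)])
print(duration, loss)  # loss is about 0.125 = 0.5 * (0.3**2 + 0.4**2)
```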
Let's plot one simulated trajectory! We only show every 16th frame.
plot(loss_function(x0=vec(x=.1, y=.5), v0=vec(x=.3, y=0))[1].t[::16].geometry, animate='t')
Now we can plot the loss landscape as a function of the angle $\alpha$ of v0. We create a scan plot by sampling a CenteredGrid from a function.
The resulting plot matches Figure 11 of the DiffTaichi paper up to small deviations.
We noticed that the loss landscape depends strongly on the number of simulation steps.
x0 = vec(x=.1, y=.5)
v0 = lambda alpha: vec(x=0.3 * math.cos(alpha), y=0.3 * math.sin(alpha))
plot(CenteredGrid(lambda alpha: loss_function(x0, v0(alpha))[0], alpha=1000, bounds=Box(alpha=(-PI/4, PI/4))))
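The scan evaluates loss_function once per grid cell. Assuming the CenteredGrid samples the function at cell centers, as its name suggests, 1000 cells on $[-\pi/4, \pi/4]$ give an angular resolution of $\pi/2000 \approx 0.0016$ rad, and every sampled v0 has speed 0.3. This plain-Python sketch reproduces those sample points; cell_centers is a hypothetical helper.

```python
import math


def cell_centers(lo, hi, n):
    """Centers of n equal-width cells covering [lo, hi] (hypothetical helper)."""
    width = (hi - lo) / n
    return [lo + (k + 0.5) * width for k in range(n)]


def v0(alpha, speed=0.3):
    """Cue-ball velocity for launch angle alpha, as in the notebook's lambda."""
    return (speed * math.cos(alpha), speed * math.sin(alpha))


alphas = cell_centers(-math.pi / 4, math.pi / 4, 1000)
spacing = alphas[1] - alphas[0]                 # about 0.00157 rad between samples
speeds = [math.hypot(*v0(a)) for a in alphas]   # all equal to 0.3 up to rounding
```

Since the landscape depends strongly on the number of steps, features narrower than this spacing would not show up in the scan.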
Next, we perform simple gradient descent on the loss landscape. Due to the chaotic loss landscape, however, this performs rather poorly. We plot the trajectory found after 5 iterations along with the goal point (X).
loss_grad = math.gradient(loss_function, 'x0,v0')
x0 = vec(x=.1, y=.5)
v0 = vec(x=.3, y=0)
learning_rate = .01
for i in range(5):  # avoid shadowing the built-in iter()
    (loss, trj), (dx0, dv0) = loss_grad(x0, v0)
    print(f"Iter={i} loss={loss:.3f} x0={x0} ∇={dx0} v0={v0} ∇={dv0}")
    # Gradient descent update for both cue-ball parameters
    x0 -= learning_rate * dx0
    v0 -= learning_rate * dv0
final_loss, trj = loss_function(x0, v0)
print(f"Final loss: {final_loss}")
plot(vis.overlay(trj.t[::16].geometry, vec(x=0.9, y=0.75)), animate='t')
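The loop applies the plain gradient-descent update $p \leftarrow p - \eta \nabla_p L$ with $\eta = 0.01$ to both x0 and v0. The sketch below shows the same update in plain Python. It swaps ΦFlow's automatic differentiation for central finite differences and uses a hypothetical smooth toy loss; numerical_grad and toy_loss are illustrative names. On such a smooth bowl every step lowers the loss, which the chaotic billiards landscape does not guarantee.

```python
def numerical_grad(f, params, h=1e-6):
    """Central finite-difference gradient of f at params (a list of floats)."""
    grad = []
    for k in range(len(params)):
        up, down = list(params), list(params)
        up[k] += h
        down[k] -= h
        grad.append((f(up) - f(down)) / (2 * h))
    return grad


def toy_loss(p):
    """Hypothetical smooth stand-in for loss_function: half squared distance to (0.9, 0.75)."""
    return 0.5 * ((p[0] - 0.9) ** 2 + (p[1] - 0.75) ** 2)


params = [0.1, 0.5]
learning_rate = 0.01
history = []
for i in range(5):
    history.append(toy_loss(params))
    grad = numerical_grad(toy_loss, params)
    # Same update rule as the notebook loop: step against the gradient
    params = [p - learning_rate * g for p, g in zip(params, grad)]
    print(f"Iter={i} loss={history[-1]:.5f}")
```

For this quadratic, each step shrinks the loss by the factor $(1 - \eta)^2 = 0.9801$; with the billiards loss, the gradients printed above instead change sign from one iteration to the next.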
Iter=0 loss=0.6501343 x0=(x=0.100, y=0.500) ∇=(x=-0.654, y=1.529) v0=(x=0.300, y=0.000) ∇=(x=-0.889, y=1.734)
Iter=1 loss=0.6250768 x0=(x=0.107, y=0.485) ∇=(x=-0.656, y=-0.973) v0=(x=0.309, y=-0.017) ∇=(x=-0.980, y=-1.262)
Iter=2 loss=0.63563234 x0=(x=0.113, y=0.494) ∇=(x=-1.158, y=0.959) v0=(x=0.319, y=-0.005) ∇=(x=-1.381, y=0.953)
Iter=3 loss=0.6203982 x0=(x=0.125, y=0.485) ∇=(x=-1.096, y=-0.586) v0=(x=0.333, y=-0.014) ∇=(x=-1.306, y=-0.703)
Iter=4 loss=0.6252239 x0=(x=0.136, y=0.491) ∇=(x=-1.408, y=0.618) v0=(x=0.346, y=-0.007) ∇=(x=-1.462, y=0.491)
Final loss: 0.6168283