JIT Compilation in ΦML

Colab   •   🌐 ΦML   •   📖 Documentation   •   🔗 API   •   ▶ Videos   •   Examples

Just-in-time (JIT) compilation can drastically speed up your code because Python overhead is eliminated and the backend compiler can optimize the recorded operations.

In [1]:
%%capture
!pip install phiml

from phiml import math

math.use('jax')

In ΦML, you can JIT-compile a function using the math.jit_compile() decorator.

In [2]:
@math.jit_compile
def fun(x):
    print(f"Tracing fun with x = {x}")
    return 2 * x

The first time the function is called with new arguments, it is traced, i.e. all tensor operations are recorded. During tracing, the passed arguments have concrete shapes but no concrete values. Consequently, traced tensors cannot be used in control flow, such as if conditions or loop conditions. To choose between values based on a traced tensor, replace if statements with math.where().

Depending on the backend, the Python function may be executed more than once during tracing.

In [3]:
fun(math.tensor(1.))
Tracing fun with x = () float32 jax tracer
Out[3]:
2.0

Whenever the function is called with arguments whose shapes and dtypes match those of a previous call, the compiled version of the function is evaluated without running the Python code. Instead, the recorded tensor operations are executed again on the new input.

In [4]:
fun(math.tensor(1.))
Out[4]:
2.0

Note that the print statement was not executed since the Python body of fun was not actually run. If we call the function with arguments of different shapes or dtypes, it is traced again.

In [5]:
fun(math.tensor([1, 2]))
Tracing fun with x = (vectorᶜ=2) int64 jax tracer
Out[5]:
(2, 4) int64

NumPy Operations

Operations on NumPy-backed tensors are executed at trace time, just like the print statement, and are not repeated when the compiled function is called; their results enter the compiled function as constants. Since NumPy-backed tensors always have concrete values, they can be used in if statements as well as loop conditions.

In [6]:
@math.jit_compile
def fun(x):
    print(f"Tracing fun with x = {x}")
    y = math.wrap(2)
    z = math.sin(y ** 2)
    print(f"z = {z}")
    if z > 1:
        return z * x
    else:
        return z / x

fun(math.tensor(1.))
Tracing fun with x = () float32 jax tracer
z = float64 -0.7568024953079282
Out[6]:
-0.7568025

Here, the control flow can depend on z since it is a NumPy-backed tensor with a concrete value.

Auxiliary Arguments

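If we want the control flow to depend on a function argument, we must declare it as an auxiliary argument. Auxiliary arguments are not traced, so their concrete values are available inside the function.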

In [7]:
@math.jit_compile(auxiliary_args='y')
def fun(x, y):
    print(f"Tracing fun with x = {x}, y = {y}")
    z = math.sin(y ** 2)
    print(f"z = {z}")
    if (z > 1).all:
        return z * x
    else:
        return z / x

fun(math.tensor(1.), math.wrap(2))
Tracing fun with x = () float32 jax tracer, y = 2
z = float64 -0.7568024953079282
Out[7]:
-0.7568025

Whenever an auxiliary argument changes in any way, including a change of its value, the function must be re-traced.

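You can check whether an existing trace can be reused for given arguments using math.trace_check(). It returns a tuple (result, message): result is True if the function can be called without re-tracing; otherwise it is False and message states why a new trace would be needed.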

In [8]:
math.trace_check(fun, math.tensor(1.), math.wrap(2))
Out[8]:
(True, '')
In [9]:
math.trace_check(fun, math.tensor(1.), math.wrap(-1))
Out[9]:
(False, 'Auxiliary arguments do not match')

Further Reading

🌐 ΦML   •   📖 Documentation   •   🔗 API   •   ▶ Videos   •   Examples