```
%%capture
!pip install phiml
```

```python
from phiml import math
math.use('jax')
```
In ΦML, you can JIT-compile a function using the `math.jit_compile()` decorator.
```python
@math.jit_compile
def fun(x):
    print(f"Tracing fun with x = {x}")
    return 2 * x
```
The first time the function is called with new arguments, it is traced, i.e. all tensor operations are recorded.
During tracing, the passed arguments have concrete shapes but no concrete values.
Consequently, traced tensors cannot be used in control flow, such as `if` statements or loop conditions.
Replace `if` statements with `math.where()`.
Depending on the backend used, the function may be called multiple times during tracing.
```python
fun(math.tensor(1.))
```

```text
Tracing fun with x = () float32 jax tracer
2.0
```
Whenever the function is called with arguments similar to those of a previous call (same shapes and data types), the compiled version of the function is evaluated without running the Python code. Instead, the previously recorded tensor operations are performed again on the new input.
```python
fun(math.tensor(1.))
```

```text
2.0
```
Note that the `print` statement was not executed since `fun` was not actually called.
If we call the function with different shapes or dtypes, it will be traced again.
```python
fun(math.tensor([1, 2]))
```

```text
Tracing fun with x = (vectorᶜ=2) int64 jax tracer
(2, 4) int64
```
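You can think of the retracing rule as a cache whose keys are the shapes and dtypes of the traced arguments. The sketch below is a hypothetical toy written in plain NumPy and is not ΦML's implementation. `TraceCache` and `signature_key` are made-up names, and the toy only decides *when* a new trace would be needed, without recording any operations. The `aux` tuple stands in for auxiliary arguments, which are described further down and must match exactly.

```python
from typing import Dict, Hashable, Tuple

import numpy as np


def signature_key(*args, aux: Tuple[Hashable, ...] = ()) -> tuple:
    """Hypothetical cache key: shape and dtype of each traced argument,
    plus the exact values of auxiliary arguments."""
    traced = tuple((np.shape(a), np.asarray(a).dtype.str) for a in args)
    return traced, aux


class TraceCache:
    """Toy model of a JIT trace cache (illustration only, not ΦML internals)."""

    def __init__(self) -> None:
        self._traces: Dict[tuple, str] = {}
        self.trace_count = 0

    def lookup_or_trace(self, *args, aux: Tuple[Hashable, ...] = ()) -> bool:
        """Return True if an existing trace is reused, False if a new one is recorded."""
        key = signature_key(*args, aux=aux)
        if key in self._traces:
            return True  # same shapes, dtypes and aux values: reuse the trace
        self.trace_count += 1
        self._traces[key] = f"trace #{self.trace_count}"
        return False


cache = TraceCache()
print(cache.lookup_or_trace(np.float32(1.)))    # new signature, traced
print(cache.lookup_or_trace(np.float32(5.)))    # only the value differs, reused
print(cache.lookup_or_trace(np.array([1, 2])))  # new shape and dtype, traced
```

In this model, changing only the value of a traced argument reuses the existing entry. A new shape or dtype, or any change in `aux`, creates a new entry.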
All NumPy operations are performed at JIT-compile time and will not be executed once the function is compiled, similar to the `print` statement.
NumPy-backed tensors always have concrete values and can be used in `if` statements as well as loop conditions.
```python
@math.jit_compile
def fun(x):
    print(f"Tracing fun with x = {x}")
    y = math.wrap(2)
    z = math.sin(y ** 2)
    print(f"z = {z}")
    if z > 1:
        return z * x
    else:
        return z / x
```
```python
fun(math.tensor(1.))
```

```text
Tracing fun with x = () float32 jax tracer
z = float64 -0.7568025
-0.7568025
```
Here, the control flow can depend on `z` because it is backed by NumPy and therefore has a concrete value during tracing.
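If we want the control flow to depend on a function argument, we must declare it as an auxiliary argument. Auxiliary arguments are not traced, so their concrete values are available inside the function.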
```python
@math.jit_compile(auxiliary_args='y')
def fun(x, y):
    print(f"Tracing fun with x = {x}, y = {y}")
    z = math.sin(y ** 2)
    print(f"z = {z}")
    if (z > 1).all:
        return z * x
    else:
        return z / x
```
```python
fun(math.tensor(1.), math.wrap(2))
```

```text
Tracing fun with x = () float32 jax tracer, y = 2
z = float64 -0.7568025
-0.7568025
```
The function always needs to be re-traced if an auxiliary argument changes in any way.
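You can use [`math.trace_check()`](phiml/math#phiml.math.trace_check) to check whether an existing trace can be reused or whether the function would have to be traced again. It returns a boolean, which is `True` if a matching trace exists, together with a message explaining why a re-trace would be needed.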
```python
math.trace_check(fun, math.tensor(1.), math.wrap(2))
```

```text
(True, '')
```

```python
math.trace_check(fun, math.tensor(1.), math.wrap(-1))
```

```text
(False, 'Auxiliary arguments do not match')
```