Having control over the floating-point (FP) precision is essential for many scientific applications. For example, some linear systems of equations are only solvable with FP64, even if the desired tolerance lies within FP32 territory. To accommodate these requirements, ΦML provides custom precision management tools that differ from the common machine learning libraries.
%%capture
import numpy as np
import torch
import tensorflow as tf
import jax
from jax import numpy as jnp
from phiml import math
First, let's look at the behavior of the backend libraries that ΦML supports.
Tensor creation:
Consider creating a float tensor from primitive floats. Can you guess what the data type will be for tensor(1.)
(or the analogous operations) in NumPy, PyTorch, TensorFlow and Jax?
print(f"NumPy: {np.array(1.).dtype}\nPyTorch: {torch.tensor(1.).dtype}\nTensorFlow: {tf.constant(1.).dtype}\nJax: {jnp.array(1.).dtype}")
NumPy: float64
PyTorch: torch.float32
TensorFlow: <dtype: 'float32'>
Jax: float32
If you guessed float64 for NumPy, float32 for PyTorch and TensorFlow, and that it depends on the configuration for Jax, you are correct!
Yes, Jax disables FP64 by default! Let's repeat that with FP64 enabled.
jax.config.update("jax_enable_x64", True)
print(f"Jax: {jnp.array(1.).dtype}")
Jax: float64
Now, Jax behaves like NumPy! Or does it...?
Combining different precisions:
What do you think will happen in each of the base libraries if we sum an FP64 and an FP32 tensor? Let's try it!
(np.array(1., dtype=np.float32) + np.array(1., dtype=np.float64)).dtype # NumPy
dtype('float64')
(torch.tensor(1., dtype=torch.float32) + torch.tensor(1., dtype=torch.float64)).dtype # PyTorch
torch.float64
NumPy and PyTorch automatically upgrade to the higher precision. However, unlike NumPy, PyTorch does not upgrade its dtype when adding a primitive float.
(np.array(1., dtype=np.float32) + 1.).dtype # NumPy
dtype('float64')
(torch.tensor(1., dtype=torch.float32) + 1.).dtype
torch.float32
Let's look at TensorFlow and Jax next.
try:
    (tf.constant(1., dtype=tf.float32) + tf.constant(1., dtype=tf.float64)).dtype # TensorFlow
except tf.errors.InvalidArgumentError as err:
    print(err)
cannot compute AddV2 as input #1(zero-based) was expected to be a float tensor but is a double tensor [Op:AddV2] name:
TensorFlow outright refuses to mix different precisions and requires manual casting. This is not the case when passing a primitive float, which is also FP64: here, TensorFlow keeps the tensor's dtype.
(tf.constant(1., dtype=tf.float32) + 1.).dtype # TensorFlow
tf.float32
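To actually combine tensors of different precisions in TensorFlow, one of them has to be cast explicitly. A minimal sketch using the standard tf.cast:
t32 = tf.constant(1., dtype=tf.float32)
t64 = tf.constant(1., dtype=tf.float64)
(tf.cast(t32, tf.float64) + t64).dtype  # casting t32 up to FP64 first makes the addition valid -> tf.float64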
At first glance, Jax seems to upgrade the different precisions like NumPy.
(jnp.array(1., dtype=jnp.float32) + jnp.array(1., dtype=jnp.float64)).dtype # Jax
dtype('float64')
Let's modify the expression a bit.
t64 = jnp.array(1.)
print(t64.dtype)
(jnp.array(1., dtype=jnp.float32) + t64).dtype
float64
dtype('float32')
Here we also add a float64 to a float32 tensor, but the result now is float32. Jax remembers that we did not explicitly specify the dtype of the t64 tensor (it is "weakly typed") and treats it differently during promotion. Also, Jax does not upgrade the precision when adding a primitive float.
(jnp.array(1., dtype=jnp.float32) + 1.).dtype # Jax
dtype('float32')
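This is what Jax calls weak typing: values whose dtype was inferred rather than requested explicitly lose out during promotion. A small sketch (recent Jax versions expose this via the weak_type attribute; treat the attribute name as an assumption if your version differs):
print(jnp.array(1.).weak_type)                     # True – dtype was inferred from the Python float
print(jnp.array(1., dtype=jnp.float64).weak_type)  # False – dtype was requested explicitly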
Converting integer tensors:
Let's look at the behavior when combining a float32 and an int tensor in the different libraries. Can you guess what the result type will be?
(np.array(1., dtype=np.float32) + np.array(1)).dtype # NumPy
dtype('float64')
(torch.tensor(1., dtype=torch.float32) + torch.tensor(1)).dtype # PyTorch
torch.float32
try:
    (tf.constant(1., dtype=tf.float32) + tf.constant(1)).dtype # TensorFlow
except tf.errors.InvalidArgumentError as err:
    print(err)
cannot compute AddV2 as input #1(zero-based) was expected to be a float tensor but is a int32 tensor [Op:AddV2] name:
(jnp.array(1., dtype=jnp.float32) + jnp.array(1)).dtype # Jax
dtype('float32')
We see that NumPy upgrades to 64 bit while PyTorch and Jax keep 32. Like before, TensorFlow refuses to combine different types.
When adding a primitive int instead, however, TensorFlow can perform the operation.
(tf.constant(1., dtype=tf.float32) + 1).dtype
tf.float32
We have seen that there is no consistent type handling between the four libraries. In fact, no two libraries behave the same:

* NumPy defaults to float64 and upgrades when combining tensors and primitives, including int.
* PyTorch defaults to float32 and upgrades only for float tensors, not primitives or integer tensors.
* TensorFlow defaults to float32 but requires all tensors to have the same precision, except for Python primitives.
* Jax defaults to float32 unless FP64 is enabled, and its promotion depends on whether the dtypes were specified explicitly.

| Library | f32 + f64 | f32 + primitive f64 | f32 + i32 |
|---|---|---|---|
| NumPy | f64 | f64 | f64 |
| PyTorch | f64 | f32 | f32 |
| TensorFlow | Error | f32 | Error |
| Jax | Depends | f32 | f32 |
These inconsistencies indicate that there is no single obviously correct way to handle precision with the data type system these libraries employ, i.e. one where the output dtype is determined solely by the input types.
Precision management in ΦML:
In ΦML, the output precision of an operation is independent of its inputs. Instead, the precision is set globally or per context block. The default precision is FP32.
Tensor creation:
Let's create a tensor like above. Can you guess the resulting dtype?
math.tensor(1.).dtype
float32
Since we have not changed the precision, ΦML creates an FP32 tensor.
Combining different precisions:
Can you guess what will happen if we add a float32 and a float64 tensor?
(math.ones(dtype=(float, 32)) + math.ones(dtype=(float, 64))).dtype
float32
The precision is still set to float32, so that's what we get. Of course, this also applies to adding Python primitives or int tensors.
(math.ones(dtype=(float, 32)) + math.ones(dtype=int)).dtype
float32
If we want to use FP64, we can either set the global precision or execute the code within a precision context. The following line sets the global precision to 64 bit.
math.set_global_precision(64)
Executing the above cells now yields float64 in all cases.
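As a quick check, a short sketch re-using only the expressions from above:
print(math.tensor(1.).dtype)                                                 # float64
print((math.ones(dtype=(float, 32)) + math.ones(dtype=(float, 64))).dtype)  # float64
print((math.ones(dtype=(float, 32)) + math.ones(dtype=int)).dtype)          # float64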
Likewise, the precision can be set to 16 bit. In that case, we get float16 even when adding a float32 and a float64 tensor.
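A minimal sketch using the precision context manager that also appears in the example further down:
with math.precision(16):
    result = math.ones(dtype=(float, 32)) + math.ones(dtype=(float, 64))
result.dtype  # float16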
As you can see, this system is much simpler and more predictable than the alternatives. It also makes writing code much easier. Upgrading a script that was written for FP32 to FP64 is as simple as setting the global precision, and executing parts of your code with a different precision is as simple as embedding it into a precision block (see example below).
Let's look at a simple application where we want to run operations with both FP32 and FP64, specifically iterating the map 35 (1 - cos(x))^2. The operation 1 - cos(x) is much more sensitive to rounding errors than the multiplication, so we wish to compute it using FP64.
The expected values after 5 iterations are: 0.2659 (FP64), 0.2663 (FP32), 0.2657 (mixed).
Here's the ΦML code. We use a precision context to execute the inner part with FP64.
math.set_global_precision(32) # reset precision to 32 bit
x = math.tensor(.5)
for i in range(5):
    with math.precision(64):
        x = 1 - math.cos(x)
    x = x ** 2 * 35
x
0.265725
Next, let's implement this using PyTorch. Here, we need to manually convert x between FP32 and FP64.
x = torch.tensor(.5)
for i in range(5):
    x = x.double()
    x = 1 - torch.cos(x)
    x = x.float()
    x = x ** 2 * 35
x
tensor(0.2657)
These conversions seem relatively tame here, but imagine we had many variables to keep track of! Making sure they all have the correct precision can be a time sink, especially since a single variable with too high a precision can upgrade all following intermediate results. The danger of this going unnoticed is why TensorFlow and Jax have taken the extreme measures of banning operations with mixed tensor precisions and disabling FP64 by default, respectively.
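To illustrate that last point, here is a small NumPy sketch (the variable names are made up for the example): a single overlooked float64 value silently pushes everything computed from it to FP64.
leaked = np.array(0.5)                   # float64, perhaps unintentionally
x = np.array(0.5, dtype=np.float32)
y = np.cos(x * leaked) + np.float32(1.)
y.dtype  # dtype('float64') – the higher precision propagated through all intermediate results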