Why ΦML has Precision Management¶

Control over floating-point (FP) precision is essential for many scientific applications. For example, some linear systems of equations can only be solved in FP64, even if the desired tolerance lies within FP32 territory. To accommodate such requirements, ΦML provides its own precision management tools, which differ from those of common machine learning libraries.

In [1]:
%%capture
import numpy as np
import torch
import tensorflow as tf
import jax
from jax import numpy as jnp
from phiml import math

Precision in ML libraries¶

First, let's look at the behavior of the backend libraries that ΦML supports.

Tensor creation: Consider creating a float tensor from primitive floats. Can you guess what the data type will be for tensor(1.) (or the analogous operations) in NumPy, PyTorch, TensorFlow and Jax?

In [2]:
print(f"NumPy: {np.array(1.).dtype}\nPyTorch: {torch.tensor(1.).dtype}\nTensorFlow: {tf.constant(1.).dtype}\nJax: {jnp.array(1.).dtype}")
NumPy: float64
PyTorch: torch.float32
TensorFlow: <dtype: 'float32'>
Jax: float32

If you guessed float64 for NumPy, float32 for PyTorch and TensorFlow, and "it depends on the configuration" for Jax, you are correct! Yes, Jax disables FP64 by default. Let's repeat that with FP64 enabled.

In [3]:
jax.config.update("jax_enable_x64", True)
print(f"Jax: {jnp.array(1.).dtype}")
Jax: float64

Now, Jax behaves like NumPy! Or does it...?

Combining different precisions: What do you think will happen in each of the base libraries if we sum an FP64 and an FP32 tensor? Let's try it!

In [4]:
(np.array(1., dtype=np.float32) + np.array(1., dtype=np.float64)).dtype  # NumPy
Out[4]:
dtype('float64')
In [5]:
(torch.tensor(1., dtype=torch.float32) + torch.tensor(1., dtype=torch.float64)).dtype  # PyTorch
Out[5]:
torch.float64

NumPy and PyTorch automatically upgrade to the highest precision. However, unlike NumPy, PyTorch does not upgrade its dtype when adding a primitive float.

In [6]:
(np.array(1., dtype=np.float32) + 1.).dtype  # NumPy
Out[6]:
dtype('float64')
In [7]:
(torch.tensor(1., dtype=torch.float32) + 1.).dtype
Out[7]:
torch.float32

Let's look at TensorFlow and Jax next.

In [8]:
try:
    (tf.constant(1., dtype=tf.float32) + tf.constant(1., dtype=tf.float64)).dtype  # TensorFlow
except tf.errors.InvalidArgumentError as err:
    print(err)
cannot compute AddV2 as input #1(zero-based) was expected to be a float tensor but is a double tensor [Op:AddV2] name: 

TensorFlow outright refuses to mix different precisions and requires manual casting. This is not the case when passing a Python float, even though it is also FP64: here, TensorFlow keeps the tensor's dtype.

In [9]:
(tf.constant(1., dtype=tf.float32) + 1.).dtype  # TensorFlow
Out[9]:
tf.float32

At first glance, Jax seems to upgrade to the higher precision, like NumPy.

In [10]:
(jnp.array(1., dtype=jnp.float32) + jnp.array(1., dtype=jnp.float64)).dtype  # Jax
Out[10]:
dtype('float64')

Let's modify the expression a bit.

In [11]:
t64 = jnp.array(1.)
print(t64.dtype)
(jnp.array(1., dtype=jnp.float32) + t64).dtype
float64
Out[11]:
dtype('float32')

Here, we also add a float64 tensor to a float32 tensor, but the result is now float32. Jax remembers that we did not explicitly specify the dtype of t64 and treats it differently.
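
Under the hood, Jax marks arrays with an inferred dtype as weakly typed. Here is a minimal check, reusing the imports from the first cell; the weak_type attribute is assumed to be available in recent Jax versions:

t_inferred = jnp.array(1.)                        # dtype inferred -> treated as weakly typed
t_explicit = jnp.array(1., dtype=jnp.float64)     # dtype given explicitly -> strongly typed
print(t_inferred.weak_type, t_explicit.weak_type)            # True False (assumed attribute)
print((jnp.array(1., dtype=jnp.float32) + t_explicit).dtype)  # float64: the explicit dtype wins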

Also, Jax does not upgrade the precision when adding a float.

In [12]:
(jnp.array(1., dtype=jnp.float32) + 1.).dtype  # Jax
Out[12]:
dtype('float32')

Converting integer tensors: Let's look at the behavior when combining a float32 and an int tensor in the different libraries. Can you guess what the result type will be?

In [13]:
(np.array(1., dtype=np.float32) + np.array(1)).dtype  # NumPy
Out[13]:
dtype('float64')
In [14]:
(torch.tensor(1., dtype=torch.float32) + torch.tensor(1)).dtype  # PyTorch
Out[14]:
torch.float32
In [15]:
try:
    (tf.constant(1., dtype=tf.float32) + tf.constant(1)).dtype  # TensorFlow
except tf.errors.InvalidArgumentError as err:
    print(err)
cannot compute AddV2 as input #1(zero-based) was expected to be a float tensor but is a int32 tensor [Op:AddV2] name: 
In [16]:
(jnp.array(1., dtype=jnp.float32) + jnp.array(1)).dtype  # Jax
Out[16]:
dtype('float32')

We see that NumPy upgrades to 64 bits while PyTorch and Jax keep 32. Like before, TensorFlow refuses to combine different types. When adding a primitive int instead, however, TensorFlow performs the operation.

In [17]:
(tf.constant(1., dtype=tf.float32) + 1).dtype
Out[17]:
tf.float32

Observations¶

We have seen that there is no consistent type handling across the four libraries. In fact, no two of them behave the same.

  • NumPy defaults to float64 and upgrades when combining tensors and primitives, including int.
  • PyTorch defaults to float32 and upgrades only for float tensors, not primitives or integer tensors.
  • Jax defaults to the precision specified by its configuration and uses involved upgrading rules that take into account whether the initial precision was set or inferred.
  • TensorFlow defaults to float32 but requires all tensors to have the same precision, except for Python primitives.
Library      f32 + f64   f32 + primitive f64   f32 + i32
NumPy        f64         f64                   f64
PyTorch      f64         f32                   f32
TensorFlow   Error       f32                   Error
Jax          Depends     f32                   f32

These inconsistencies indicate that there is no single, obviously correct way to handle precision under the data type system these libraries employ, i.e., one where the output dtype is determined solely by the input types.

Precision Management in ΦML¶

In ΦML, the output precision of an operation is independent of its inputs. Instead, the precision can be set globally or per context. The default precision is FP32.

Tensor creation: Let's create a tensor like above. Can you guess the resulting dtype?

In [18]:
math.tensor(1.).dtype
Out[18]:
float32

Since we have not changed the precision, ΦML creates an FP32 tensor.

Combining different precisions: Can you guess what will happen if we add a float32 and float64 tensor?

In [19]:
(math.ones(dtype=(float, 32)) + math.ones(dtype=(float, 64))).dtype
Out[19]:
float32

The precision is still set to 32 bits, so that's what we get. Of course, this also applies to adding int tensors or Python primitives, as shown below.

In [20]:
(math.ones(dtype=(float, 32)) + math.ones(dtype=int)).dtype
Out[20]:
float32
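
The same holds for Python primitives. A quick check, reusing the helpers from the cells above:

print((math.ones(dtype=(float, 32)) + 1.).dtype)  # float32: a Python float does not upgrade the result
print((math.ones(dtype=(float, 32)) + 1).dtype)   # float32: neither does a Python int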

If we want to use FP64, we can either set the global precision or execute the code within a precision context. The following line sets the global precision to 64 bits.

In [21]:
math.set_global_precision(64)

Executing the above cells now yields float64 in all cases. Likewise, the precision can be set to 16 bits. In that case, we get float16 even when adding a float32 and a float64 tensor.
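
As a minimal sketch of the 16-bit case, we can use the precision context that also appears in the example below:

with math.precision(16):  # overrides the global precision inside this block
    result = math.ones(dtype=(float, 32)) + math.ones(dtype=(float, 64))
print(result.dtype)  # float16: the context precision determines the output, not the input dtypes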

As you can see, this system is much simpler and more predictable than the alternatives. It also makes writing code much easier. Upgrading a script that was written for FP32 to FP64 is as simple as setting the global precision, and executing parts of your code with a different precision is as simple as embedding it into a precision block (see example below).

An Example of Mixing Precisions¶

Let's look at a simple application where we want to run operations with both FP32 and FP64, specifically iterating the map x ← 35 · (1 − cos(x))². The operation 1 − cos(x) is much more sensitive to rounding errors than the multiplication, so we wish to compute it using FP64. The expected values after 5 iterations are 0.2659 (FP64), 0.2663 (FP32), and 0.2657 (mixed).

Here's the ΦML code. We use a precision context to execute the inner part with FP64.

In [22]:
math.set_global_precision(32)  # reset precision to 32 bit
x = math.tensor(.5)
for i in range(5):
    with math.precision(64):
        x = 1 - math.cos(x)
    x = x ** 2 * 35
x
Out[22]:
float64 0.26586727874017735

Next, let's implement this using PyTorch. Here, we need to manually convert x between FP32 and FP64.

In [23]:
x = torch.tensor(.5)
for i in range(5):
    x = x.double()
    x = 1 - torch.cos(x)
    x = x.float()
    x = x ** 2 * 35
x
Out[23]:
tensor(0.2657)

These conversions seem relatively tame here, but imagine we had many variables to keep track of! Making sure they all have the correct precision can be a time sink, especially since a single variable with too high a precision can silently upgrade all following intermediate results. The danger of this going unnoticed is why TensorFlow and Jax have taken the extreme measures of banning operations with mixed inputs and disabling FP64 by default, respectively.
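
For illustration, here is a hypothetical NumPy snippet (the variable names are made up) in which one forgotten cast silently upgrades everything downstream:

positions = np.random.rand(1000).astype(np.float32)  # simulation state, intended to be FP32
offsets = np.linspace(0., 1., 1000)                   # forgot to cast: float64 by default
shifted = positions + offsets                         # NumPy upgrades the result to float64
velocities = shifted * 0.5                            # ...and every later intermediate stays float64
print(shifted.dtype, velocities.dtype)                # float64 float64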

Further Reading¶

Data types in ΦML
