Linear Solves in ΦML

Colab   •   🌐 ΦML   •   📖 Documentation   •   🔗 API   •   ▶ Videos   •   Examples

Linear solves are a vital part of many simulations and machine learning applications. ΦML provides an easy-to-use interface for performing linear solves that supports backpropagation via implicit gradients. Dense, sparse, and matrix-free linear systems can be solved this way.

In [1]:
%%capture
!pip install phiml
In [2]:
from phiml import math
from phiml.math import wrap, channel, dual, spatial, Solve, tensor

Linear solves and sparse matrices are supported on all backends. Feel free to change the line below to use jax, tensorflow or numpy instead.

In [3]:
math.use('torch')
Out[3]:
torch

We can perform a linear solve by passing a matrix A, right-hand-side vector b and initial guess x0 to solve_linear().

We recommend passing ΦML tensors. Then, the dual dimensions of the matrix must match the initial guess and the primal dimensions must match the right-hand-side.

Alternatively, solve_linear() can be called with native tensors (see below).

In [4]:
A = tensor([[0, 1], [1, 0]], channel('b_vec'), dual('x_vec'))
b = tensor([2, 3], channel('b_vec'))
x0 = tensor([0, 0], channel('x_vec'))
math.solve_linear(A, b, Solve(x0=x0))
Out[4]:
(3.000, 2.000) along x_vecᶜ
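
As a quick sanity check, the result can be verified by hand: multiplying A by the returned vector should reproduce b. The plain-Python sketch below does this without ΦML. The helper name `matvec` is an illustrative assumption and not part of the library.

```python
# Plain-Python check of the solve above (no ΦML required).
A = [[0, 1], [1, 0]]   # rows index b_vec, columns index x_vec
b = [2, 3]
x = [3, 2]             # solution reported by solve_linear()

def matvec(matrix, vector):
    """Dense matrix-vector product on nested lists (illustrative helper)."""
    return [sum(a_ij * v_j for a_ij, v_j in zip(row, vector)) for row in matrix]

# The residual A·x - b vanishes for an exact solution.
residual = [r - b_i for r, b_i in zip(matvec(A, x), b)]
print(residual)  # [0, 0]
```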

ΦML implements multiple algorithms to solve linear systems, such as the conjugate gradient method (CG) and the stabilized bi-conjugate gradient method (biCG-stab). All SciPy solvers are also available. For a full list, see the API documentation.

In [5]:
math.solve_linear(A, b, Solve('CG', x0=x0))
Out[5]:
(3.000, 2.000) along x_vecᶜ
In [6]:
math.solve_linear(A, b, Solve('biCG-stab', x0=x0))
Out[6]:
(3.000, 2.000) along x_vecᶜ
In [7]:
math.solve_linear(A, b, Solve('scipy-GMres', x0=x0))
Out[7]:
(3.000, 2.000) along x_vecᶜ

Matrix-free Solves

Instead of passing a matrix, you can also specify a linear Python function that computes the matrix-vector product. This will typically be slower unless the function is compiled to a matrix.

In [8]:
def linear_function(x):
    return x * (2, 1)

# linear_function keeps the dimensions of its input, so the initial guess
# must share the dimensions of b (b_vec), unlike the x0 defined above (x_vec).
# The system is diag(2, 1) · x = b, whose exact solution is x = (1, 3).
math.solve_linear(linear_function, b, Solve(x0=0 * b))
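
To make the idea of a matrix-free solve concrete, here is a minimal plain-Python sketch of the textbook conjugate gradient method, which only ever calls a function that computes the matrix-vector product. It is not ΦML's implementation. The names `conjugate_gradient` and `apply_A` are assumptions made for this illustration, and the operator must be symmetric positive-definite.

```python
def conjugate_gradient(apply_A, b, x0, rtol=1e-5, max_iter=100):
    """Textbook CG for a symmetric positive-definite operator given as a function."""
    dot = lambda u, v: sum(ui * vi for ui, vi in zip(u, v))
    x = list(x0)
    r = [bi - ai for bi, ai in zip(b, apply_A(x))]   # initial residual b - A·x0
    p = list(r)                                      # first search direction
    rs_old = dot(r, r)
    b_norm = dot(b, b) ** 0.5
    for _ in range(max_iter):
        if rs_old ** 0.5 <= rtol * b_norm:           # relative residual small enough
            break
        Ap = apply_A(p)                              # the only place the operator is used
        alpha = rs_old / dot(p, Ap)
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, Ap)]
        rs_new = dot(r, r)
        p = [ri + (rs_new / rs_old) * pi for ri, pi in zip(r, p)]
        rs_old = rs_new
    return x

def apply_A(x):
    # Element-wise scaling by (2, 1), i.e. the operator diag(2, 1)
    return [2 * x[0], 1 * x[1]]

x = conjugate_gradient(apply_A, b=[2.0, 3.0], x0=[0.0, 0.0])
print(x)  # approximately [1.0, 3.0]
```

No matrix is ever stored here; the solver only needs the result of applying the operator to a vector.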

Explicit Matrices from Python Functions

ΦML can also build an explicit matrix representation of the provided Python function. You can do this either by explicitly obtaining the matrix first using matrix_from_function or by annotating the linear function with jit_compile_linear. If the function adds a constant offset to the output, this will automatically be subtracted from the right-hand-side vector.

In [9]:
from phiml.math import jit_compile_linear

math.solve_linear(jit_compile_linear(linear_function), b, Solve(x0=x0))
/opt/hostedtoolcache/Python/3.11.15/x64/lib/python3.11/site-packages/phiml/backend/torch/_torch_backend.py:875: UserWarning: Sparse invariant checks are implicitly disabled. Memory errors (e.g. SEGFAULT) will occur when operating on a sparse tensor which violates the invariants, but checks incur performance overhead. To silence this warning, explicitly opt in or out. See `torch.sparse.check_sparse_tensor_invariants.__doc__` for guidance.  (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:760.)
  return torch.sparse_csr_tensor(row_pointers, column_indices, values, shape, device=values.device)
/opt/hostedtoolcache/Python/3.11.15/x64/lib/python3.11/site-packages/phiml/backend/torch/_torch_backend.py:875: UserWarning: Sparse CSR tensor support is in beta state. If you miss a functionality in the sparse tensor support, please submit a feature request to https://github.com/pytorch/pytorch/issues. (Triggered internally at /pytorch/aten/src/ATen/SparseCsrTensorImpl.cpp:49.)
  return torch.sparse_csr_tensor(row_pointers, column_indices, values, shape, device=values.device)
Out[9]:
(1.000, 2.000, 1.500, 3.000) (b_vecᶜ=2, x_vecᶜ=2)
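
The following plain-Python sketch illustrates what an explicit matrix representation with a constant offset means. It recovers the matrix of an affine function by probing it with unit vectors, which is a deliberately simple stand-in and should not be read as the mechanism ΦML uses. The names `probe_matrix` and `affine_function` are hypothetical.

```python
def probe_matrix(f, n):
    """Recover (A, offset) from an affine function f(x) = A·x + offset by probing."""
    offset = f([0.0] * n)                         # f(0) is the constant part
    columns = []
    for j in range(n):
        unit = [1.0 if i == j else 0.0 for i in range(n)]
        # f(e_j) - f(0) is the j-th column of A
        columns.append([fi - oi for fi, oi in zip(f(unit), offset)])
    # Turn the list of columns into a list of rows
    A = [[columns[j][i] for j in range(n)] for i in range(len(offset))]
    return A, offset

def affine_function(x):
    # Element-wise scaling by (2, 1) plus a constant offset of (1, 1)
    return [2 * x[0] + 1, 1 * x[1] + 1]

A, offset = probe_matrix(affine_function, n=2)
b = [2.0, 3.0]
rhs = [bi - oi for bi, oi in zip(b, offset)]      # offset moved to the right-hand side
print(A, rhs)  # [[2.0, 0.0], [0.0, 1.0]] [1.0, 2.0]
```

Solving `A·x = rhs` is then equivalent to solving `affine_function(x) = b`.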

Preconditioned Linear Solves

ΦML includes an ILU preconditioner and an experimental cluster preconditioner. To use a preconditioner, simply specify preconditioner='ilu' when creating the Solve object.

In [10]:
math.solve_linear(jit_compile_linear(linear_function), b, Solve('scipy-CG', x0=x0, preconditioner='ilu'))
/opt/hostedtoolcache/Python/3.11.15/x64/lib/python3.11/site-packages/phiml/backend/_numpy_backend.py:580: SparseEfficiencyWarning: CSC or CSR matrix format is required. Converting to CSC matrix.
  return spsolve_triangular(matrix, rhs.T, lower=lower, unit_diagonal=unit_diagonal).T
Out[10]:
(1.000, 2.000, 1.500, 3.000) (b_vecᶜ=2, x_vecᶜ=2)
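
To show why preconditioning helps, the sketch below runs a textbook preconditioned CG in plain Python on a badly scaled diagonal system. Note the swap: it uses a Jacobi (diagonal) preconditioner rather than ILU, because Jacobi fits in a few lines. The function `pcg` and its arguments are assumptions of this illustration and do not mirror ΦML's internals.

```python
def pcg(A_diag, b, precondition, rtol=1e-10, max_iter=50):
    """Preconditioned CG for a diagonal SPD matrix; returns (x, iterations)."""
    dot = lambda u, v: sum(ui * vi for ui, vi in zip(u, v))
    apply_A = lambda v: [a * vi for a, vi in zip(A_diag, v)]
    x = [0.0] * len(b)
    r = list(b)                       # residual for the initial guess x = 0
    z = precondition(r)               # preconditioned residual
    p = list(z)
    rz_old = dot(r, z)
    b_norm = dot(b, b) ** 0.5
    for iteration in range(max_iter):
        if dot(r, r) ** 0.5 <= rtol * b_norm:
            return x, iteration
        Ap = apply_A(p)
        alpha = rz_old / dot(p, Ap)
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, Ap)]
        z = precondition(r)
        rz_new = dot(r, z)
        p = [zi + (rz_new / rz_old) * pi for zi, pi in zip(z, p)]
        rz_old = rz_new
    return x, max_iter

A_diag = [1.0, 10.0, 100.0]
b = [1.0, 1.0, 1.0]
# Identity preconditioner = plain CG
x_plain, it_plain = pcg(A_diag, b, precondition=lambda r: list(r))
# Jacobi preconditioner: divide the residual by the diagonal of A
x_jacobi, it_jacobi = pcg(A_diag, b, precondition=lambda r: [ri / a for ri, a in zip(r, A_diag)])
print(it_plain, it_jacobi)
```

For a diagonal matrix the Jacobi preconditioner is an exact inverse, so the preconditioned run needs fewer iterations than plain CG. For general sparse matrices, an ILU factorization plays the same role approximately.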

The ILU preconditioner always runs on the CPU and should be paired with a SciPy linear solver for optimal efficiency. Available SciPy solvers include 'scipy-direct', 'scipy-CG', 'scipy-GMres', 'scipy-biCG', 'scipy-biCG-stab', 'scipy-CGS', 'scipy-QMR', 'scipy-GCrotMK' (see the API).

If the matrix or linear function is constant, i.e. only depends on NumPy arrays, the preconditioner computation can be performed during JIT compilation.


In [11]:
@math.jit_compile
def jit_perform_solve(b):
    return math.solve_linear(jit_compile_linear(linear_function), b, Solve('scipy-CG', x0=x0, preconditioner='ilu'))

jit_perform_solve(b)
/opt/hostedtoolcache/Python/3.11.15/x64/lib/python3.11/site-packages/phiml/math/_optimize.py:785: RuntimeWarning: Preconditioners are not supported for sparse scipy-CG in torch JIT mode. Disabling preconditioner. Use Jax or TensorFlow to enable preconditioners in JIT mode.
  warnings.warn(f"Preconditioners are not supported for sparse {method} in {y.default_backend} JIT mode. Disabling preconditioner. Use Jax or TensorFlow to enable preconditioners in JIT mode.", RuntimeWarning)
Out[11]:
(1.000, 2.000, 1.500, 3.000) (b_vecᶜ=2, x_vecᶜ=2)

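When the backend supports it, the ILU preconditioner is computed at JIT-compile time since the linear function does not depend on b. With PyTorch, as the warning above states, preconditioners are not supported in JIT mode and the preconditioner is disabled; use Jax or TensorFlow to enable it.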

Implicit Differentiation

ΦML enables backpropagation through linear solves. Instead of backpropagating through the unrolled loop (which can lead to inaccurate results and cause high memory consumption), ΦML runs an adjoint linear solve for the pullback operation.

In [12]:
def loss_function(b):
    x = math.solve_linear(jit_compile_linear(linear_function), b, Solve(x0=x0))
    return math.l2_loss(x)

gradient_function = math.gradient(loss_function, 'b', get_output=False)
gradient_function(b)
Out[12]:
(2.500, 3.750) along b_vecᶜ
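
The adjoint idea can be sketched in plain Python on a small dense system: for x = A⁻¹·b and a loss L(x), the gradient with respect to b is obtained from a single solve with the transposed matrix, ∂L/∂b = A⁻ᵀ·∂L/∂x. The example uses its own 2×2 matrix and the loss ½‖x‖², and compares the result against finite differences. The helpers `solve2` and `loss` are assumptions of this sketch, not ΦML functions.

```python
def solve2(A, b):
    """Solve a 2x2 system exactly with Cramer's rule."""
    det = A[0][0] * A[1][1] - A[0][1] * A[1][0]
    return [(b[0] * A[1][1] - A[0][1] * b[1]) / det,
            (A[0][0] * b[1] - A[1][0] * b[0]) / det]

def loss(A, b):
    x = solve2(A, b)
    return 0.5 * (x[0] ** 2 + x[1] ** 2)

A = [[2.0, 1.0], [0.0, 1.0]]   # deliberately non-symmetric so the transpose matters
b = [2.0, 3.0]

# Forward solve, then ONE adjoint solve with the transposed matrix
x = solve2(A, b)
A_T = [[A[0][0], A[1][0]], [A[0][1], A[1][1]]]
grad_b = solve2(A_T, x)        # dL/db = A^-T · dL/dx, with dL/dx = x for this loss

# Central finite differences as an independent reference
eps = 1e-6
fd = [(loss(A, [b[0] + eps, b[1]]) - loss(A, [b[0] - eps, b[1]])) / (2 * eps),
      (loss(A, [b[0], b[1] + eps]) - loss(A, [b[0], b[1] - eps])) / (2 * eps)]
print(grad_b, fd)
```

The cost of the backward pass is one extra linear solve, independent of how many iterations the forward solver needed.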

Matrix Gradients

ΦML can also compute gradients for the (sparse) matrix used in a linear solve, which allows differentiating w.r.t. parameters that influenced the matrix values via backpropagation. To enable this, pass grad_for_f=True to the solve_linear() call.

In [13]:
@math.jit_compile_linear
def conditioned_linear_function(x, conditioning):
    return x * conditioning

def loss_function(conditioning):
    b = math.ones_like(conditioning)
    x = math.solve_linear(conditioned_linear_function, b, Solve(x0=x0), conditioning=conditioning, grad_for_f=True)
    return math.l2_loss(x)

gradient_function = math.gradient(loss_function, 'conditioning', get_output=False)
gradient_function(tensor([1., 2.], channel('x_vec')))
Out[13]:
(-1.000, -0.125) along x_vecᶜ
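
The result above can be reproduced by hand with the adjoint formula for matrix gradients, ∂L/∂A = -λ·xᵀ with Aᵀ·λ = ∂L/∂x. The plain-Python sketch below assumes the loss is ½‖x‖², which is consistent with the output shown, and exploits that the matrix here is diag(conditioning), so both solves reduce to element-wise divisions.

```python
conditioning = [1.0, 2.0]
b = [1.0, 1.0]

# Forward solve: diag(c) · x = b
x = [bi / ci for bi, ci in zip(b, conditioning)]
# Adjoint solve: diag(c)^T · lam = dL/dx, where dL/dx = x for the loss 0.5 * |x|^2
lam = [xi / ci for xi, ci in zip(x, conditioning)]
# dL/dA = -lam · x^T; only the diagonal entries depend on the conditioning
grad_c = [-li * xi for li, xi in zip(lam, x)]
print(grad_c)  # [-1.0, -0.125]
```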

Handling Failed Optimizations

When a linear solve (or minimize call) does not find a solution, a subclass of ConvergenceException is raised, depending on the reason.

  • If the maximum number of iterations was reached, NotConverged is raised.
  • If the solve diverged or failed prematurely, Diverged is raised.

These exceptions can also be raised during backpropagation if the adjoint solve fails (except for TensorFlow).

You can deal with failed solves using Python's try-except clause.

In [14]:
try:
    solution = math.solve_linear(lambda x: 0 * x, wrap([1, 2, 3], spatial('x')), solve=math.Solve(x0=math.zeros(spatial(x=3))))
    print("converged", solution)
except math.ConvergenceException as exc:
    print(exc)
    print(f"Last estimate: {exc.result.x}")
Φ-ML CG-adaptive (torch)  did not converge to rel_tol=1e-05, abs_tol=1e-05 within 1000 iterations. Max residual: 3.0
Last estimate: (0.000, 0.000, 0.000) along xˢ

If you want the regular execution flow to continue despite non-convergence, you can pass suppress=[math.Diverged, math.NotConverged] to the Solve constructor.
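
The control flow behind raising versus suppressing can be sketched with a toy solver in plain Python. Everything here is hypothetical: the `NotConverged` class, the `toy_solve` function and its `suppress` argument only mimic the pattern described above, and returning the last estimate in the suppressed case is an assumption of this sketch rather than a statement about ΦML.

```python
class NotConverged(RuntimeError):
    """Hypothetical exception that carries the last estimate."""
    def __init__(self, message, x):
        super().__init__(message)
        self.x = x

def toy_solve(diag, b, max_iter=1000, atol=1e-5, step=0.2, suppress=()):
    """Gradient-descent solve of diag(d)·x = b; illustrative only."""
    x = [0.0] * len(b)
    for _ in range(max_iter):
        residual = [bi - di * xi for bi, di, xi in zip(b, diag, x)]
        if max(abs(r) for r in residual) <= atol:
            return x
        x = [xi + step * di * ri for xi, di, ri in zip(x, diag, residual)]
    error = NotConverged(f"did not converge within {max_iter} iterations", x)
    if isinstance(error, tuple(suppress)):
        return x                      # suppressed: continue with the last estimate
    raise error

# A singular (all-zero) operator can never reach the tolerance
try:
    toy_solve([0.0, 0.0, 0.0], [1.0, 2.0, 3.0])
except NotConverged as exc:
    print(exc, "| last estimate:", exc.x)

# With suppression, the same call returns instead of raising
last = toy_solve([0.0, 0.0, 0.0], [1.0, 2.0, 3.0], suppress=[NotConverged])
print(last)  # [0.0, 0.0, 0.0]
```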

Obtaining Additional Information about a Solve

All solves are logged internally and can be shown by setting phiml.set_logging_level('debug').

Additional solve properties can be recorded using a SolveTape. Recording the full optimization trajectory requires setting record_trajectories=True.

In [15]:
import phiml
phiml.set_logging_level('debug')

with math.SolveTape() as solves:
    math.solve_linear(jit_compile_linear(linear_function), b, Solve('scipy-CG', x0=0 * b, preconditioner='ilu'))
factor_ilu: auto-selecting iterations=1 (eager mode) for matrix (2, 0); (0, 1) (b_vecᶜ=2, ~b_vecᵈ=2) int64 (DEBUG), 2026-05-10 11:15:19,393n

TorchScript -> run compiled forward '_matrix_solve_forward' with args [((2,), False)] (DEBUG), 2026-05-10 11:15:19,396n

Running forward pass of custom op forward '_matrix_solve_forward' given args ('y',) containing 1 native tensors (DEBUG), 2026-05-10 11:15:19,396n

Performing linear solve scipy-CG with tolerance float64 1e-05 (rel), float64 1e-05 (abs), max_iterations=1000 with backend torch (DEBUG), 2026-05-10 11:15:19,398n

/opt/hostedtoolcache/Python/3.11.15/x64/lib/python3.11/site-packages/phiml/backend/_numpy_backend.py:580: SparseEfficiencyWarning: CSC or CSR matrix format is required. Converting to CSC matrix.
  return spsolve_triangular(matrix, rhs.T, lower=lower, unit_diagonal=unit_diagonal).T

Information about the recorded solves can then be obtained by indexing the tape, either by position or by Solve object.

In [16]:
print(solves[0].solve)
print("Solution", solves[0].x)
print("Residual", solves[0].residual)
print("Fun.evals", solves[0].function_evaluations)
print("Iterations", solves[0].iterations)
print("Diverged", solves[0].diverged)
print("Converged", solves[0].converged)
scipy-CG with tolerance float64 1e-05 (rel), float64 1e-05 (abs), max_iterations=1000
Solution (1.000, 3.000) along b_vecᶜ
Residual (0.000, 0.000) along b_vecᶜ
Fun.evals 2
Iterations 1
Diverged False
Converged True

Linear Solves with Native Tensors

When performing a linear solve without ΦML tensors, the matrix must have shape (..., N, N) and x0 and b must have shape (..., N) where ... denotes the batch dimensions. This matches the signatures of the native solve functions like torch.linalg.solve or jax.numpy.linalg.solve.

In [17]:
import torch

A = torch.tensor([[0., 1], [1, 0]])
b = torch.tensor([2., 3])
x0 = torch.tensor([0., 0])
math.solve_linear(A, b, Solve(x0=x0))
TorchScript -> run compiled forward '_matrix_solve_forward' with args [((2,), False), ((2, 2), False)] (DEBUG), 2026-05-10 11:15:19,415n

Running forward pass of custom op forward '_matrix_solve_forward' given args ('y', 'matrix') containing 2 native tensors (DEBUG), 2026-05-10 11:15:19,415n

Performing linear solve auto with tolerance float64 1e-05 (rel), float64 1e-05 (abs), max_iterations=1000 with backend torch (DEBUG), 2026-05-10 11:15:19,417n

Out[17]:
tensor([3., 2.])

Further Reading

We will upload a whitepaper describing the implemented algorithms to arXiv shortly.
