Linear solves are a vital part of many simulations and machine learning applications. ΦML provides an easy-to-use interface for performing linear solves that supports backpropagation via implicit gradients. Dense, sparse, and matrix-free linear systems can all be solved this way.
%%capture
!pip install phiml
from phiml import math
from phiml.math import wrap, channel, dual, spatial, Solve, tensor
Linear solves and sparse matrices are supported on all backends.
Feel free to change the line below to use jax, tensorflow or numpy instead.
math.use('torch')
torch
We can perform a linear solve by passing a matrix A, a right-hand-side vector b and an initial guess x0 to solve_linear().
We recommend passing ΦML tensors. In that case, the dual dimensions of the matrix must match the initial guess and the primal dimensions must match the right-hand side.
Alternatively, solve_linear() can also be called with native tensors (see below).
A = tensor([[0, 1], [1, 0]], channel('b_vec'), dual('x_vec'))
b = tensor([2, 3], channel('b_vec'))
x0 = tensor([0, 0], channel('x_vec'))
math.solve_linear(A, b, Solve(x0=x0))
(3.000, 2.000) along b_vecᶜ
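As a point of comparison, the same 2×2 system can be written in plain NumPy. This is only an analogy, assuming that the primal dimension b_vec corresponds to the rows of A and the dual dimension x_vec to its columns; NumPy itself has no notion of named or dual dimensions.

```python
import numpy as np

# Same system as above: rows of A play the role of the primal dimension b_vec,
# columns the role of the dual dimension x_vec.
A = np.array([[0.0, 1.0],
              [1.0, 0.0]])
b = np.array([2.0, 3.0])   # one entry per row (b_vec)

x = np.linalg.solve(A, b)  # one entry per column (x_vec)
print(x)                   # [3. 2.]
```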
ΦML implements multiple algorithms to solve linear systems, such as the conjugate gradient method (CG) and the stabilized bi-conjugate gradient method (biCG-stab).
All SciPy solvers are also available.
For a full list, see the API documentation.
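To illustrate what the CG option does, here is a textbook conjugate gradient loop in NumPy. It is a sketch, not ΦML's implementation; the stopping rule (residual norm below max(rtol·|b|, atol)) is an assumption chosen to mirror the rel/abs tolerances that appear in ΦML's logs. CG is only guaranteed to converge for symmetric positive-definite matrices, so the example uses such a matrix rather than the permutation matrix A from above.

```python
import numpy as np


def conjugate_gradient(matvec, b, x0, rtol=1e-5, atol=1e-5, max_iter=1000):
    """Textbook CG for a symmetric positive-definite operator given as a matrix-vector product."""
    x = np.array(x0, dtype=float)
    r = b - matvec(x)                 # initial residual
    tol = max(rtol * np.linalg.norm(b), atol)
    if np.linalg.norm(r) <= tol:
        return x, 0
    p = r.copy()                      # first search direction
    rr = r @ r
    for iteration in range(1, max_iter + 1):
        Ap = matvec(p)
        alpha = rr / (p @ Ap)         # optimal step length along p
        x = x + alpha * p
        r = r - alpha * Ap
        rr_new = r @ r
        if np.sqrt(rr_new) <= tol:
            return x, iteration
        p = r + (rr_new / rr) * p     # next direction, A-conjugate to the previous ones
        rr = rr_new
    raise RuntimeError(f"CG did not converge within {max_iter} iterations")


A = np.array([[4.0, 1.0],
              [1.0, 3.0]])            # symmetric positive definite
b = np.array([1.0, 2.0])
x, iterations = conjugate_gradient(lambda v: A @ v, b, np.zeros(2))
print(x, iterations)
```

Because the solver only ever calls matvec, it never needs A to be stored explicitly; in exact arithmetic CG finishes an N×N system in at most N iterations.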
math.solve_linear(A, b, Solve('CG', x0=x0))
(3.000, 2.000) along b_vecᶜ
math.solve_linear(A, b, Solve('biCG-stab', x0=x0))
(3.000, 2.000) along b_vecᶜ
math.solve_linear(A, b, Solve('scipy-GMres', x0=x0))
(3, 2) along b_vecᶜ int64
Instead of passing a matrix, you can also specify a linear Python function that computes the matrix-vector product. This will typically be slower unless the function is compiled to a matrix.
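A matrix-free linear function only has to compute the product A·x. The hypothetical laplace_1d below applies a 1D finite-difference Laplacian stencil without ever storing the matrix; the dense matrix is built only to check that both give the same product. This is plain NumPy, not ΦML code.

```python
import numpy as np


def laplace_1d(x):
    """Matrix-free 1D Laplacian with zero boundary values: (A x)_i = x_{i-1} - 2 x_i + x_{i+1}."""
    padded = np.pad(x, 1)  # constant (zero) padding on both ends
    return padded[:-2] - 2 * padded[1:-1] + padded[2:]


n = 5
# Equivalent explicit matrix, built here only for comparison.
A = -2 * np.eye(n) + np.eye(n, k=1) + np.eye(n, k=-1)
x = np.arange(n, dtype=float)
matrix_free = laplace_1d(x)  # equals A @ x without materializing A
```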
def linear_function(x):
    return x * (2, 1)
# linear_function keeps the dimensions of its input, so the initial guess needs the same dimension (b_vec) as the right-hand side.
math.solve_linear(linear_function, b, Solve(x0=0 * b))
ΦML can also build an explicit matrix representation of the provided Python function.
You can do this either by explicitly obtaining the matrix first using matrix_from_function or by annotating the linear function with jit_compile_linear.
If the function adds a constant offset to the output, this offset is automatically subtracted from the right-hand-side vector.
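One simple way to turn a linear (or affine) function into an explicit matrix is to evaluate it at zero and at each unit vector. The NumPy sketch below does exactly that, including the offset handling mentioned above: the constant part f(0) is moved to the right-hand side. This dense probing needs one function call per column and is only an illustration; it is not necessarily how matrix_from_function or jit_compile_linear build their (sparse) matrices. affine_function and matrix_and_offset are hypothetical names.

```python
import numpy as np


def affine_function(x):
    """Hypothetical affine map f(x) = M x + c."""
    M = np.array([[2.0, 0.0], [0.0, 1.0]])
    c = np.array([1.0, -1.0])
    return M @ x + c


def matrix_and_offset(f, n):
    """Recover M and c from an affine function by evaluating it at 0 and at the unit vectors."""
    offset = f(np.zeros(n))
    columns = [f(e) - offset for e in np.eye(n)]  # column i = f(e_i) - f(0)
    return np.stack(columns, axis=1), offset


M, c = matrix_and_offset(affine_function, 2)
b = np.array([2.0, 3.0])
# Solving f(x) = b is equivalent to solving M x = b - c.
x = np.linalg.solve(M, b - c)
```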
from phiml.math import jit_compile_linear
math.solve_linear(jit_compile_linear(linear_function), b, Solve(x0=x0))
/opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/phiml/backend/torch/_torch_backend.py:805: UserWarning: Sparse CSR tensor support is in beta state. If you miss a functionality in the sparse tensor support, please submit a feature request to https://github.com/pytorch/pytorch/issues. (Triggered internally at ../aten/src/ATen/SparseCsrTensorImpl.cpp:53.) return torch.sparse_csr_tensor(row_pointers, column_indices, values, shape, device=values.device)
(1.000, 2.000, 1.500, 3.000) (b_vecᶜ=2, x_vecᶜ=2)
ΦML includes an ILU preconditioner and an experimental cluster preconditioner.
To use a preconditioner, specify it when creating the Solve object, e.g. preconditioner='ilu'.
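For intuition about what an ILU preconditioner computes, here is a dense textbook ILU(0) (incomplete LU without fill-in) in NumPy, together with the forward and backward substitution that applies it. It is a sketch for illustration, not ΦML's ILU implementation, which works on sparse matrices. For a tridiagonal matrix no fill-in occurs, so ILU(0) coincides with the exact LU factorization and applying it solves the system exactly; for general sparse matrices it only approximates the inverse.

```python
import numpy as np


def ilu0(A):
    """ILU(0): only entries that are nonzero in A are updated.
    Returns the unit lower factor L (strictly below the diagonal) and U packed into one array."""
    LU = np.array(A, dtype=float)
    n = LU.shape[0]
    nonzero = LU != 0                 # sparsity pattern of A, kept fixed (no fill-in)
    for i in range(1, n):
        for k in range(i):
            if not nonzero[i, k]:
                continue
            LU[i, k] /= LU[k, k]
            for j in range(k + 1, n):
                if nonzero[i, j]:
                    LU[i, j] -= LU[i, k] * LU[k, j]
    return LU


def apply_ilu(LU, r):
    """Approximately solve A z = r by forward (L) and backward (U) substitution."""
    n = len(r)
    y = np.array(r, dtype=float)
    for i in range(n):                # L y = r, L has a unit diagonal
        y[i] -= LU[i, :i] @ y[:i]
    z = y.copy()
    for i in reversed(range(n)):      # U z = y
        z[i] = (y[i] - LU[i, i + 1:] @ z[i + 1:]) / LU[i, i]
    return z


n = 4
A = 4 * np.eye(n) - np.eye(n, k=1) - np.eye(n, k=-1)  # tridiagonal, so no fill-in
LU = ilu0(A)
r = np.array([1.0, 2.0, 3.0, 4.0])
z = apply_ilu(LU, r)
```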
math.solve_linear(jit_compile_linear(linear_function), b, Solve('scipy-CG', x0=x0, preconditioner='ilu'))
/opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/scipy/sparse/linalg/_dsolve/linsolve.py:640: SparseEfficiencyWarning: CSR matrix format is required. Converting to CSR matrix. warn('CSR matrix format is required. Converting to CSR matrix.',
(1, 2, 1, 3) (b_vecᶜ=2, x_vecᶜ=2) int64
The ILU preconditioner always runs on the CPU and should be paired with a SciPy linear solver for optimal efficiency.
Available SciPy solvers include 'scipy-direct', 'scipy-CG', 'scipy-GMres', 'scipy-biCG', 'scipy-biCG-stab', 'scipy-CGS', 'scipy-QMR' and 'scipy-GCrotMK' (see the API).
If the matrix or linear function is constant, i.e. only depends on NumPy arrays, the preconditioner computation can be performed during JIT compilation.
@math.jit_compile
def jit_perform_solve(b):
    return math.solve_linear(jit_compile_linear(linear_function), b, Solve('scipy-CG', x0=x0, preconditioner='ilu'))
jit_perform_solve(b)
/opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/phiml/math/_optimize.py:717: RuntimeWarning: Preconditioners are not supported for sparse scipy-CG in torch JIT mode. Disabling preconditioner. Use Jax or TensorFlow to enable preconditioners in JIT mode. warnings.warn(f"Preconditioners are not supported for sparse {method} in {y.default_backend} JIT mode. Disabling preconditioner. Use Jax or TensorFlow to enable preconditioners in JIT mode.", RuntimeWarning)
(1, 2, 1, 3) (b_vecᶜ=2, x_vecᶜ=2) int64
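Here, the ILU preconditioner can be computed at JIT-compile time since the linear function does not depend on b.
With PyTorch, however, preconditioners for sparse SciPy solves are disabled in JIT mode, as the warning above shows; use Jax or TensorFlow to keep them enabled.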
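The benefit of computing a preconditioner ahead of time comes from separating work that depends only on the matrix from work that depends on the right-hand side. The NumPy sketch below shows the same idea with a Cholesky factorization: the factor of a constant symmetric positive-definite matrix is computed once and reused for every b. It is an analogy, not a description of how ΦML's JIT compilation works internally; make_solver is a hypothetical helper.

```python
import numpy as np


def make_solver(A):
    """Factor a constant SPD matrix once and return a function that reuses the factor for any b."""
    L = np.linalg.cholesky(A)  # A = L @ L.T, computed a single time

    def solve(b):
        # np.linalg.solve does not exploit triangularity; a dedicated triangular solver would be faster.
        y = np.linalg.solve(L, b)       # L y = b
        return np.linalg.solve(L.T, y)  # L^T x = y

    return solve


A = np.array([[4.0, 1.0],
              [1.0, 3.0]])
solve = make_solver(A)
rhs = [np.array([1.0, 2.0]), np.array([0.0, 1.0])]
solutions = [solve(b) for b in rhs]  # only the cheap substitution steps run per b
```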
ΦML enables backpropagation through linear solves. Instead of backpropagating through the unrolled solver loop (which can lead to inaccurate results and high memory consumption), ΦML runs an adjoint linear solve for the pullback operation.
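The adjoint approach can be checked in a few lines of NumPy. For x = A⁻¹b and the squared-norm loss L = ½‖x‖², the gradient is ∂L/∂b = A⁻ᵀx: a single solve with the transposed matrix, with no need to differentiate through solver iterations. The sketch compares this against central finite differences; it uses a direct NumPy solve and is not ΦML code.

```python
import numpy as np

A = np.array([[2.0, 1.0],
              [0.0, 1.0]])  # deliberately non-symmetric so that the transpose matters
b = np.array([2.0, 3.0])


def loss(b):
    x = np.linalg.solve(A, b)
    return 0.5 * x @ x      # half the sum of squares


# Implicit gradient: one adjoint solve with A.T, using dL/dx = x.
x = np.linalg.solve(A, b)
grad_b = np.linalg.solve(A.T, x)

# Central finite differences as an independent check.
eps = 1e-6
finite_diff = np.array([(loss(b + eps * e) - loss(b - eps * e)) / (2 * eps) for e in np.eye(2)])
```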
def loss_function(b):
    x = math.solve_linear(jit_compile_linear(linear_function), b, Solve(x0=x0))
    return math.l2_loss(x)
gradient_function = math.gradient(loss_function, 'b', get_output=False)
gradient_function(b)
(2.500, 3.750) along b_vecᶜ
ΦML can also compute gradients for the (sparse) matrix used in a linear solve, which allows differentiating w.r.t. parameters that influenced the matrix values via backpropagation.
To enable this, pass grad_for_f=True to the solve_linear() call.
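The same adjoint solve also yields the gradient with respect to the matrix. Differentiating A x = b gives ∂L/∂A = -λxᵀ, where λ solves Aᵀλ = ∂L/∂x; for a diagonal matrix diag(c), only the diagonal of this outer product depends on c. The NumPy sketch below (again with L = ½‖x‖² and a hypothetical helper name) reproduces the gradient printed below for conditioning = (1, 2) and b = 1, i.e. (-1, -0.125).

```python
import numpy as np


def solve_with_adjoint_grads(A, b):
    """Solve A x = b and return gradients of L = 0.5 * |x|^2 w.r.t. b and A via one adjoint solve."""
    x = np.linalg.solve(A, b)
    lam = np.linalg.solve(A.T, x)  # adjoint solve: A^T lam = dL/dx = x
    grad_b = lam
    grad_A = -np.outer(lam, x)     # dL/dA = -lam x^T
    return x, grad_b, grad_A


conditioning = np.array([1.0, 2.0])
x, grad_b, grad_A = solve_with_adjoint_grads(np.diag(conditioning), np.ones(2))
grad_conditioning = np.diag(grad_A)  # A = diag(conditioning): only diagonal entries depend on it
```

Analytically, x = 1/c and the gradient is -1/c³, which gives -1 and -0.125 for c = 1 and c = 2.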
@math.jit_compile_linear
def conditioned_linear_function(x, conditioning):
    return x * conditioning
def loss_function(conditioning):
    b = math.ones_like(conditioning)
    x = math.solve_linear(conditioned_linear_function, b, Solve(x0=x0), conditioning=conditioning, grad_for_f=True)
    return math.l2_loss(x)
gradient_function = math.gradient(loss_function, 'conditioning', get_output=False)
gradient_function(tensor([1., 2.], channel('x_vec')))
(-1.000, -0.125) along x_vecᶜ
When a linear solve (or minimize call) does not find a solution, a subclass of ConvergenceException is thrown, depending on the reason.
If the maximum number of iterations is reached before the solve converges, NotConverged is thrown.
If the solve diverges, Diverged is thrown.
These exceptions can also be thrown during backpropagation if the adjoint solve fails (except for TensorFlow).
You can deal with failed solves using Python's try-except clause.
try:
    solution = math.solve_linear(lambda x: 0 * x, wrap([1, 2, 3], spatial('x')), solve=math.Solve(x0=math.zeros(spatial(x=3))))
    print("converged", solution)
except math.ConvergenceException as exc:
    print(exc)
    print(f"Last estimate: {exc.result.x}")
Φ-ML CG-adaptive (torch) did not converge to rel_tol=1e-05, abs_tol=1e-05 within 1000 iterations. Max residual: 3, ., 0
Last estimate: (0.000, 0.000, 0.000) along xˢ
If you want the regular execution flow to continue despite non-convergence, you can pass suppress=[math.Diverged, math.NotConverged] to the Solve constructor.
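The exception-based control flow can be mimicked outside ΦML. In the sketch below, ConvergenceError, NotConvergedError and DivergedError are hypothetical stand-ins, not ΦML's ConvergenceException, NotConverged and Diverged. Like the ΦML exceptions they carry the last estimate, and a suppress argument, loosely modeled on the one described above, returns that estimate instead of raising. The solver is a plain Jacobi iteration chosen for brevity.

```python
import numpy as np


class ConvergenceError(Exception):
    """Hypothetical base class: a failed solve that still carries its last estimate."""
    def __init__(self, message, x):
        super().__init__(message)
        self.x = x


class NotConvergedError(ConvergenceError):
    """Raised when the iteration limit is reached."""


class DivergedError(ConvergenceError):
    """Raised when the residual stops being finite."""


def jacobi_solve(A, b, x0, tol=1e-5, max_iter=1000, suppress=()):
    """Jacobi iteration x <- x + D^-1 (b - A x); raises unless the failure type is in `suppress`."""
    diagonal = np.diag(A)
    x = np.array(x0, dtype=float)
    with np.errstate(over="ignore", invalid="ignore"):  # overflow is expected when diverging
        for iteration in range(max_iter + 1):
            residual = b - A @ x
            if not np.all(np.isfinite(residual)):
                error = DivergedError(f"diverged after {iteration} iterations", x)
                break
            if np.max(np.abs(residual)) <= tol:
                return x
            if iteration < max_iter:
                x = x + residual / diagonal
        else:
            error = NotConvergedError(f"not converged within {max_iter} iterations", x)
    if isinstance(error, tuple(suppress)):
        return x  # continue with the last estimate instead of raising
    raise error


A_good = np.array([[4.0, 1.0], [1.0, 3.0]])  # diagonally dominant: Jacobi converges
b_good = np.array([1.0, 2.0])
x_ok = jacobi_solve(A_good, b_good, np.zeros(2))
```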
All solves are logged internally and can be shown by setting phiml.set_logging_level('debug').
Additional solve properties can be recorded using a SolveTape.
Recording the full optimization trajectory requires setting record_trajectories=True when creating the SolveTape.
import phiml
phiml.set_logging_level('debug')
with math.SolveTape() as solves:
    math.solve_linear(jit_compile_linear(linear_function), b, Solve('scipy-CG', x0=0 * b, preconditioner='ilu'))
factor_ilu: auto-selecting iterations=1 (eager mode) for matrix (2.000, 0.000); (0.000, 1.000) (b_vecᶜ=2, ~b_vecᵈ=2) (DEBUG), 2024-12-19 17:10:01,898
TorchScript -> run compiled forward '_matrix_solve_forward' with args [((2,), False)] (DEBUG), 2024-12-19 17:10:01,913
Running forward pass of custom op forward '_matrix_solve_forward' given args ('y',) containing 1 native tensors (DEBUG), 2024-12-19 17:10:01,914
Performing linear solve scipy-CG with tolerance float64 1e-05 (rel), float64 1e-05 (abs), max_iterations=1000 with backend torch (DEBUG), 2024-12-19 17:10:01,918
/opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/scipy/sparse/linalg/_dsolve/linsolve.py:640: SparseEfficiencyWarning: CSR matrix format is required. Converting to CSR matrix. warn('CSR matrix format is required. Converting to CSR matrix.',
Information about the performed solves can then be obtained by indexing the tape, either with an integer index or with a Solve object.
print(solves[0].solve)
print("Solution", solves[0].x)
print("Residual", solves[0].residual)
print("Fun.evals", solves[0].function_evaluations)
print("Iterations", solves[0].iterations)
print("Diverged", solves[0].diverged)
print("Converged", solves[0].converged)
scipy-CG with tolerance float64 1e-05 (rel), float64 1e-05 (abs), max_iterations=1000
Solution (1, 3) along b_vecᶜ int64
Residual (0, 0) along b_vecᶜ int64
Fun.evals 2
Iterations 1
Diverged False
Converged True
When performing a linear solve without ΦML tensors, the matrix must have shape (..., N, N), and x0 and b must have shape (..., N), where ... denotes the batch dimensions.
This matches the signatures of native solve functions like torch.linalg.solve or jax.numpy.linalg.solve.
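The (..., N, N) and (..., N) convention can be tried out directly in NumPy. One caveat, assuming NumPy 2.0 or later: np.linalg.solve treats b as a stack of vectors only when b is one-dimensional, so a batched right-hand side of shape (batch, N) is given an explicit trailing axis that is removed again afterwards. torch.linalg.solve, by contrast, accepts batched vectors of shape (..., N) directly.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, n = 3, 4
# Shapes follow the (..., N, N) / (..., N) convention described above.
A = rng.normal(size=(batch, n, n)) + n * np.eye(n)  # diagonal shift keeps the matrices well-conditioned
b = rng.normal(size=(batch, n))

# Add a trailing axis so each right-hand side is an (N, 1) column, then drop it again.
x = np.linalg.solve(A, b[..., None])[..., 0]
```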
import torch
A = torch.tensor([[0., 1], [1, 0]])
b = torch.tensor([2., 3])
x0 = torch.tensor([0., 0])
math.solve_linear(A, b, Solve(x0=x0))
TorchScript -> run compiled forward '_matrix_solve_forward' with args [((2,), False), ((2, 2), False)] (DEBUG), 2024-12-19 17:10:01,940
Running forward pass of custom op forward '_matrix_solve_forward' given args ('y', 'matrix') containing 2 native tensors (DEBUG), 2024-12-19 17:10:01,941
Performing linear solve auto with tolerance float64 1e-05 (rel), float64 1e-05 (abs), max_iterations=1000 with backend torch (DEBUG), 2024-12-19 17:10:01,944
tensor([3., 2.])