Module phiml.parallel

This package contains utilities for running PhiML code in parallel. To use it, you need to declare the code to be parallelized as properties of a dataclass. A tutorial can be found here. The following is a simple example that runs a batch of FFTs in parallel:

>>> from dataclasses import dataclass
>>> from functools import cached_property
>>> from phiml.math import Tensor, spatial, fft, batch
>>> from phiml.parallel import parallel_property, parallel_compute, INFER, MIXED
>>>
>>> @dataclass(frozen=True)
... class ParallelComputation:
...     data: Tensor
...
...     @parallel_property(requires=spatial, out=INFER, on_direct_eval='raise')
...     def result(self) -> Tensor:
...         return fft(self.data)
...
>>> data = ...
>>> computation = ParallelComputation(data)
>>> parallel_compute(computation, [ParallelComputation.result], parallel_dims=batch, max_workers=4)
>>> computation.result

Properties declared as @cached_property behave as @parallel_property(requires=INFER, out=INFER, on_direct_eval='host-compute').

Computational Graph

PhiML performs static code analysis (of the source code in 1.14, of the bytecode from 1.15 onwards) to determine the dependencies between all properties declared with either @parallel_property or @cached_property. The analysis traces into methods and properties defined in the same class, but not into functions defined outside it. Do not pass self to external functions, since any dependencies accessed there will not be captured.
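Dependency detection can be illustrated with a toy analysis. The sketch below uses Python's ast module on a source string, roughly in the spirit of the 1.14 source-based analysis (1.15 and later analyze bytecode instead), and collects only direct self.<name> accesses. The helper self_reads and the class in SOURCE are hypothetical. Note that opaque reports no dependencies because self is handed to an external function.

```python
import ast
from typing import Set

# Hypothetical class, parsed only (never executed), so undefined names are fine
SOURCE = '''
class Spectrum:
    @cached_property
    def spectrum(self):
        return fft(self.data)

    @cached_property
    def power(self):
        return abs(self.spectrum) ** 2

    @cached_property
    def opaque(self):
        return external_helper(self)  # reads inside external_helper are invisible
'''


def self_reads(source: str, method: str) -> Set[str]:
    """Return the names accessed as `self.<name>` inside `method` (direct accesses only)."""
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.FunctionDef) and node.name == method:
            return {n.attr for n in ast.walk(node)
                    if isinstance(n, ast.Attribute)
                    and isinstance(n.value, ast.Name) and n.value.id == "self"}
    raise KeyError(method)


power_deps = self_reads(SOURCE, "power")
spectrum_deps = self_reads(SOURCE, "spectrum")
opaque_deps = self_reads(SOURCE, "opaque")
```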

The resulting computational graph is split into computation stages according to the requires values of the involved properties. The requires argument declares the dims that must be present on the data, i.e. kept within one process, in order to compute the result. Properties without required dims can be parallelized across all dims passed to parallel_compute(). Properties that cannot be parallelized at all (because all parallel dims are listed in their requires) are computed on the host process.
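The staging rule can be sketched in plain Python under simplifying assumptions: each property's requires is a set of dim names, properties are visited in dependency order, and consecutive properties with the same distributed dims share a stage. build_toy_stages and the property names are hypothetical; PhiML's actual stage builder may group properties differently.

```python
from graphlib import TopologicalSorter  # Python 3.9+
from typing import Dict, List, Sequence, Set, Tuple


def build_toy_stages(requires: Dict[str, Set[str]], deps: Dict[str, Set[str]],
                     parallel_dims: Sequence[str]) -> List[Tuple[Tuple[str, ...], List[str]]]:
    """Group properties into stages of (distributed dims, property names).

    A property is distributed over every parallel dim it does not require;
    an empty tuple means it runs on the host.
    """
    stages: List[Tuple[Tuple[str, ...], List[str]]] = []
    for name in TopologicalSorter(deps).static_order():  # dependencies come first
        distributed = tuple(d for d in parallel_dims if d not in requires.get(name, set()))
        if stages and stages[-1][0] == distributed:
            stages[-1][1].append(name)  # same split as the previous stage: merge
        else:
            stages.append((distributed, [name]))
    return stages


# Hypothetical chain: spectrum -> power -> channel_sum -> total
requires = {"spectrum": {"x"}, "power": set(), "channel_sum": {"channel"}, "total": {"example", "channel"}}
deps = {"spectrum": set(), "power": {"spectrum"}, "channel_sum": {"power"}, "total": {"channel_sum"}}
stages = build_toy_stages(requires, deps, parallel_dims=("example", "channel"))
```

Here channel_sum runs with fewer workers (split over example only), and total, which requires every parallel dim, lands on the host.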

INFER via Dynamic Traces

When either requires or out is set to INFER, PhiML performs a dynamic trace to infer the value. Such properties may only use PhiML functions, and not all PhiML functions are supported yet. Use phiml.set_logging_level to make failed traces visible.

MIXED Parallelization

For properties that only use supported PhiML calls, you can set requires=MIXED, which allows PhiML to split the computation of the property into multiple stages. For example, if example is a parallel dim, the expression math.sum(batched_data * 2, 'example,x,y') is split into three parts:

  • Multiplication (parallel)
  • Sum over x,y (parallel)
  • Sum over example (on host)

Data transfers between workers and host are performed as needed.
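To make the three-part split concrete, here is a plain-Python sketch of the same decomposition, assuming batched_data is a nested list indexed as [example][x][y]. It only mirrors the arithmetic: threads stand in for PhiML's worker processes, and mixed_sum is a hypothetical name, not part of PhiML.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import List

Example = List[List[float]]  # one example, indexed [x][y]


def mixed_sum(batched_data: List[Example], max_workers: int = 2) -> float:
    """Toy split of sum(batched_data * 2, 'example,x,y') with `example` parallelized."""
    def worker(example: Example) -> float:
        doubled = [[2 * v for v in row] for row in example]  # part 1: multiplication (parallel)
        return sum(sum(row) for row in doubled)              # part 2: sum over x, y (parallel)

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        partial_sums = list(pool.map(worker, batched_data))  # one partial result per example
    return sum(partial_sums)                                 # part 3: sum over example (host)


data = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]  # 2 examples, each 2 x 2
total = mixed_sum(data)
```

Only one scalar per example travels back to the host, which is why splitting the reduction pays off.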

Currently supported are tensor operators (a+b), simple single-tensor functions (abs, exp, sin, round, is_nan, …) and reductions (sum, prod, max, finite_mean, …). Do not enable MIXED on properties that call unsupported functions: although calling an unsupported function directly on a tracer fails with an error, there can still be undesirable effects.

Disk Caching

By setting memory_limit and cache_dir in parallel_compute(), you can have workers write results to disk instead of serializing the full results for transfer. Use this if system memory is limited or if data needs to be passed between processes many times. See this example for a demonstration.
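The core idea, workers writing results to files and passing around only lightweight references, can be sketched with the standard library. DiskRef is a hypothetical stand-in for a disk-backed tensor; it pickles as a file path, so a stacked result made of several references stays small no matter how much data the files hold. The s{stage}_i{index} file names mirror the naming in the parallel_compute() source listed below, but the pickle file format is an assumption, not PhiML's on-disk format. Run it as a script so pickle can locate the class.

```python
import os
import pickle
import tempfile


class DiskRef:
    """Hypothetical disk-backed value: the object itself only holds a file path."""

    loads = 0  # class-wide counter of disk reads, for illustration

    def __init__(self, path: str):
        self.path = path

    @classmethod
    def store(cls, value, path: str) -> "DiskRef":
        # Worker side: write the result to disk and hand back a lightweight reference
        with open(path, "wb") as f:
            pickle.dump(value, f)
        return cls(path)

    def load(self):
        # Bring the values back into memory only when they are actually needed
        DiskRef.loads += 1
        with open(self.path, "rb") as f:
            return pickle.load(f)


cache_dir = tempfile.mkdtemp()
# Two "workers" each write one slice of the result
refs = [DiskRef.store(list(range(i * 50_000, (i + 1) * 50_000)), os.path.join(cache_dir, f"s0_i{i}"))
        for i in range(2)]
blob = pickle.dumps(refs)               # serializes two paths, not 100,000 numbers
restored = pickle.loads(blob)           # e.g. in another process
loads_after_transfer = DiskRef.loads    # nothing has been read from disk yet
last_value = restored[1].load()[-1]
```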

Functions

def get_cache_files(obj) ‑> Set[str]
def get_cache_files(obj) -> Set[str]:
    """
    Searches the data structure for all disk-cached tensors and returns all referenced files.

    Args:
        obj: `Tensor` or pytree or dataclass (`phiml.math.magic.PhiTreeNode`).

    Returns:
        Collection of file paths.
    """
    result = set()
    _recursive_add_cache_files(obj, result)
    return result

Searches the data structure for all disk-cached tensors and returns all referenced files.

Args

obj
Tensor or pytree or dataclass (PhiTreeNode).

Returns

Collection of file paths.
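For illustration, a toy version of such a traversal over plain Python containers and dataclasses is sketched below. CachedArray and collect_cache_files are hypothetical names; the real get_cache_files() also understands PhiML tensors and PhiTreeNode structures, which this sketch does not.

```python
from dataclasses import dataclass, fields, is_dataclass
from typing import Any, Optional, Set, Tuple


@dataclass(frozen=True)
class CachedArray:
    """Hypothetical disk-backed value; a stacked result may span several files."""
    files: Tuple[str, ...]


def collect_cache_files(obj: Any, result: Optional[Set[str]] = None) -> Set[str]:
    """Walk dicts, lists, tuples and dataclass fields, gathering every referenced file."""
    result = set() if result is None else result
    if isinstance(obj, CachedArray):  # checked before the generic dataclass branch
        result.update(obj.files)
    elif isinstance(obj, dict):
        for value in obj.values():
            collect_cache_files(value, result)
    elif isinstance(obj, (list, tuple)):
        for value in obj:
            collect_cache_files(value, result)
    elif is_dataclass(obj) and not isinstance(obj, type):
        for field in fields(obj):
            collect_cache_files(getattr(obj, field.name), result)
    return result


@dataclass
class Simulation:
    velocity: Any
    history: list


sim = Simulation(
    velocity=CachedArray(("s0_i0", "s0_i1")),
    history=[CachedArray(("s1_i0",)), {"pressure": CachedArray(("s0_i0",))}, 3.0],
)
files = collect_cache_files(sim)  # duplicates collapse because the result is a set
```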

def load_cache_as(host_backend: str | phiml.backend._backend.Backend,
worker_backend: str | phiml.backend._backend.Backend = None)
@contextlib.contextmanager
def load_cache_as(host_backend: Union[str, Backend], worker_backend: Union[str, Backend] = None):
    """
    Context manager to temporarily set the backend for loading disk-cached tensors into memory.

    Args:
        host_backend: Backend name or instance to use in the current process.
            This also applies to `parallel_compute()` stages that cannot be parallelized and thus run in the host process.
        worker_backend: Backend name or instance to use in worker processes, see `parallel_compute()`. If `None`, the workers will use the same backend per tensor as the host.

    Usage:

        >>> with load_cache_as('torch', worker_backend='numpy'):
    """
    _LOAD_AS.append(get_backend(host_backend))
    if worker_backend is not None:
        _WORKER_LOAD_AS.append(get_backend(worker_backend))
    try:
        yield
    finally:
        _LOAD_AS.pop()
        if worker_backend is not None:
            _WORKER_LOAD_AS.pop()

Context manager to temporarily set the backend for loading disk-cached tensors into memory.

Args

host_backend
Backend name or instance to use in the current process. This also applies to parallel_compute() stages that cannot be parallelized and thus run in the host process.
worker_backend
Backend name or instance to use in worker processes, see parallel_compute(). If None, the workers will use the same backend per tensor as the host.

Usage

>>> with load_cache_as('torch', worker_backend='numpy'):
...     pass  # placeholder: code that accesses disk-cached tensors
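Nesting follows from the stack-based implementation in the source above: each with block pushes an override and the finally clause pops it, even if the block raises. A standalone sketch of that pattern, with hypothetical names and backends given as plain strings:

```python
import contextlib
from typing import List, Optional, Tuple

_HOST_LOAD_AS: List[str] = []    # hypothetical override stack for the current process
_WORKER_LOAD_AS: List[str] = []  # hypothetical override stack forwarded to workers


@contextlib.contextmanager
def toy_load_as(host: str, worker: Optional[str] = None):
    """Push backend overrides for the duration of the block and always pop them again."""
    _HOST_LOAD_AS.append(host)
    if worker is not None:
        _WORKER_LOAD_AS.append(worker)
    try:
        yield
    finally:
        _HOST_LOAD_AS.pop()
        if worker is not None:
            _WORKER_LOAD_AS.pop()


def active_backends() -> Tuple[Optional[str], Optional[str]]:
    """Innermost overrides; None means no override is active."""
    host = _HOST_LOAD_AS[-1] if _HOST_LOAD_AS else None
    worker = _WORKER_LOAD_AS[-1] if _WORKER_LOAD_AS else None
    return host, worker


with toy_load_as("torch", worker="numpy"):
    outer = active_backends()
    with toy_load_as("jax"):   # no worker override: the outer "numpy" stays active
        inner = active_backends()
    restored = active_backends()
after = active_backends()
```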
def on_load_into_memory(callback: Callable[[phiml.math._tensors.Tensor], None])
def on_load_into_memory(callback: Callable[[Tensor], None]):
    """
    Register a function to be called every time a cached tensor is loaded from disk into memory in this process.

    Tensors are stored on disk to save memory, e.g. in `parallel_compute()` if a memory limit is specified.

    When accessing the tensor values or using them in any computation, the values are temporarily loaded into memory and `callback` is called.

    Args:
        callback: Function `callback(Tensor)`.
    """
    _ON_LOAD_FROM_DISK.append(callback)

Register a function to be called every time a cached tensor is loaded from disk into memory in this process.

Tensors are stored on disk to save memory, e.g. in parallel_compute() if a memory limit is specified.

When accessing the tensor values or using them in any computation, the values are temporarily loaded into memory and callback is called.

Args

callback
Function callback(Tensor).
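A minimal sketch of the callback pattern, assuming a hypothetical read_cached() loader that passes the file path to each callback (the real callbacks receive the loaded Tensor instead):

```python
import os
import tempfile
from typing import Callable, List

_ON_LOAD: List[Callable[[str], None]] = []  # hypothetical module-level registry (one per process)


def on_load(callback: Callable[[str], None]) -> None:
    """Register callback(path); it fires on every later load in this process."""
    _ON_LOAD.append(callback)


def read_cached(path: str) -> bytes:
    """Temporarily load cached data into memory and notify all registered callbacks."""
    with open(path, "rb") as f:
        data = f.read()
    for callback in _ON_LOAD:
        callback(path)
    return data


# Usage: record how often a cached file is pulled back into memory
loaded: List[str] = []
on_load(loaded.append)
path = os.path.join(tempfile.mkdtemp(), "s0_i0")
with open(path, "wb") as f:
    f.write(bytes(1024))
first = read_cached(path)
second = read_cached(path)
```

Because the registry is module state, a callback registered in the host process does not fire for loads inside worker processes, which matches the "in this process" wording above.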
def parallel_compute(instance,
properties: Sequence,
parallel_dims=batch,
max_workers=multiprocessing.cpu_count(),
memory_limit: float | None = None,
cache_dir: str = None,
keep_intermediate=False)
def parallel_compute(instance, properties: Sequence, parallel_dims=batch,
                     max_workers=multiprocessing.cpu_count(), memory_limit: Optional[float] = None, cache_dir: str = None, keep_intermediate=False):
    """
    Compute the values of properties decorated with `@cached_property` or `@parallel_property` of a dataclass instance in parallel.

    **Multiple stages via `requires=...`**
    If `@parallel_property` are computed whose `requires` overlaps with `parallel_dims`, a separate computation stage is set up to compute these properties with fewer parallel workers.
    In the presence of different `requires`, the computation is split into different stages in accordance with the property dependency graph.
    Properties that cannot be parallelized (because it requires all `parallel_dims`) are computed on the host process.

    **Caching tensors to disk**
    When `memory_limit` and `cache_dir` are set, the evaluation will try to adhere to the given memory limits by moving tensors out of memory onto disk.
    This is only applied to the outputs of `@cached_property` and `@parallel_property` calls, not to intermediate values used in their computation.
    The per-process memory limit is calculated per stage, dividing the total memory by the active worker count.
    Cached tensors behave like regular tensors and are temporarily loaded back into memory when accessing their values or using them in a computation.
    When parallelizing, the full result is assembled by stacking multiple disk-backed tensors from different files created by different processes.
    These composite tensors will reference multiple binary files and can be pickled/unpickled safely without loading the data into memory.
    This enables passing large data references to different processes or saving the structure to a file without the data content.

    See Also:
        `cached_property`, `parallel_property`, `get_cache_files()`, `on_load_into_memory()`.

    Warnings:
        `parallel_compute` breaks automatic differentiation.

    Args:
        instance: Dataclass instance for which to compute the values of `@cached_property` or `@parallel_property` fields.
        properties: References to the unbound properties. These must be `cached_property` or `parallel_property`.
        parallel_dims: Dimensions to parallelize over.
        max_workers: Number of processes to spawn.
        memory_limit: Limit to the total memory consumption from `Tensor` instances on property outputs.
        cache_dir: Directory path to store cached tensors in if `memory_limit` is set.
        keep_intermediate: Whether the outputs of cached properties required to compute `properties` but not contained in `properties` should be kept in memory.
            If `False`, these values will not be cached on `instance` after this call.
    """
    assert hasattr(instance, '__dict__'), f"parallel_compute requires instance to have __dict__. Slots are not supported."
    if memory_limit is not None:
        assert cache_dir is not None, "cache_dir must be specified if memory_limit is set"
    dims = shape(instance).only(parallel_dims)
    # --- Build graph of relevant properties ---
    cls = type(instance)
    if is_class_from_notebook(cls):
        class_example = (cls.__name__, class_to_string(cls))
    else:
        class_example = cls.__new__(cls)
    nodes: Dict[str, PGraphNode] = {}
    output_user = PGraphNode('<output>', EMPTY_SHAPE, parallel_dims, None, False, set(), None, False, [], 999999)
    for p in properties:
        recursive_add_node(instance, cls, property_name(p), p, dims, nodes).users.append(output_user)
    # nodes = merge_duplicate_nodes(nodes.values())
    for node in nodes.values():
        if node.name in instance.__dict__:
            node.done = True
            node.dependencies = []
    stages = build_stages(nodes)
    ML_LOGGER.debug(f"Assembled {len(stages)} stages containing {sum(len(ns) for ns in stages)} properties for parallel computation.")
    # --- Execute stages ---
    any_parallel = any(stage_nodes[0].distributed for stage_nodes in stages) and max_workers > 0
    max_workers = min(max_workers, max(stage_nodes[0].distributed.volume for stage_nodes in stages))
    init_args = (_WORKER_LOAD_AS[-1].name if _WORKER_LOAD_AS else None,)
    with ProcessPoolExecutor(initializer=init_worker, initargs=init_args, max_workers=max_workers) if any_parallel else nullcontext() as pool:
        for stage_idx, stage_nodes in enumerate(stages):
            parallel_dims = stage_nodes[0].distributed
            if parallel_dims and any_parallel:
                ML_LOGGER.debug(f"Parallel | {parallel_dims} | {[n.name for n in stage_nodes]}")
                property_names = [n.name for n in stage_nodes if n.is_used_later]
                programs = {n.name: n.program for n in stage_nodes if n.program is not None}
                required_caches = set.union(*[n.prior_dep_names for n in stage_nodes])
                required_fields = set.union(*[n.field_dep_names for n in stage_nodes])
                # --- Split data ---
                instances = unstack(instance, parallel_dims)
                n = len(instances)
                caches = [unstack(instance.__dict__[c], parallel_dims, expand=True) for c in required_caches]
                data = []
                for i, inst_i in enumerate(instances):
                    data_i = {}
                    data_i.update(**{f_name: inst_i.__dict__[f_name] for f_name in required_fields})
                    if caches:
                        data_i.update(**{c: caches[j][i] for j, c in enumerate(required_caches)})
                    data.append(data_i)
                keep_intermediate or delete_intermediate_caches(instance, stages, stage_idx)
                # --- Submit to pool ---
                mem_per_item = memory_limit / min(max_workers, len(instances)) if memory_limit is not None else None
                cache_dir is not None and os.makedirs(cache_dir, exist_ok=True)
                cache_file_suggestions = [os.path.join(cache_dir, f"s{stage_idx}_i{i}") for i in range(len(instances))] if cache_dir is not None else [None] * len(instances)
                results = list(pool.map(_evaluate_properties, [class_example]*n, [property_names]*n, [programs]*n, data, [mem_per_item]*n, cache_file_suggestions))
                for name, *outputs in zip(property_names, *results):
                    output = stack(outputs, parallel_dims)
                    instance.__dict__[name] = output
            else:  # No parallelization in this stage
                ML_LOGGER.debug(f"Host | {[n.name for n in stage_nodes]}")
                _EXECUTION_STATUS.append('host')
                for n in stage_nodes:
                    get_property_value(instance, n.name, n.program)
                assert _EXECUTION_STATUS.pop() == 'host', "Host execution status mismatch. (internal error)"
                keep_intermediate or delete_intermediate_caches(instance, stages, stage_idx)

Compute the values of properties decorated with @cached_property or @parallel_property of a dataclass instance in parallel.

Multiple stages via requires=...
If any @parallel_property being computed has requires dims that overlap with parallel_dims, a separate computation stage is set up to compute these properties with fewer parallel workers. When the computed properties have different requires, the computation is split into different stages in accordance with the property dependency graph. Properties that cannot be parallelized (because they require all parallel_dims) are computed on the host process.

Caching tensors to disk
When memory_limit and cache_dir are set, the evaluation will try to adhere to the given memory limits by moving tensors out of memory onto disk. This is only applied to the outputs of @cached_property and @parallel_property calls, not to intermediate values used in their computation. The per-process memory limit is calculated per stage, dividing the total memory by the active worker count. Cached tensors behave like regular tensors and are temporarily loaded back into memory when accessing their values or using them in a computation. When parallelizing, the full result is assembled by stacking multiple disk-backed tensors from different files created by different processes. These composite tensors reference multiple binary files and can be pickled/unpickled safely without loading the data into memory. This enables passing large data references to different processes or saving the structure to a file without the data content.

See Also: cached_property, parallel_property(), get_cache_files(), on_load_into_memory().

Warnings

parallel_compute() breaks automatic differentiation.

Args

instance
Dataclass instance for which to compute the values of @cached_property or @parallel_property fields.
properties
References to the unbound properties. These must be cached_property or parallel_property().
parallel_dims
Dimensions to parallelize over.
max_workers
Number of processes to spawn.
memory_limit
Limit to the total memory consumption from Tensor instances on property outputs.
cache_dir
Directory path to store cached tensors in if memory_limit is set.
keep_intermediate
Whether the outputs of cached properties that are needed to compute the requested properties, but are not themselves listed in properties, should be kept in memory. If False, these values will not be cached on instance after this call.
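The basic shape of one parallel stage, evaluating each slice along the parallel dims in a pool and re-assembling the outputs in order, can be sketched as follows. This is a simplified, hypothetical model: it uses threads and plain lists where parallel_compute() uses a ProcessPoolExecutor together with unstack() and stack(), and it computes the per-worker memory budget essentially as the source above does (memory_limit divided by the number of active workers).

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Optional, Tuple


def toy_parallel_stage(slices: List[float], prop: Callable[[float], float], max_workers: int = 4,
                       memory_limit: Optional[float] = None) -> Tuple[List[float], Optional[float]]:
    """Evaluate `prop` once per slice and return (outputs in slice order, memory per worker)."""
    workers = max(1, min(max_workers, len(slices)))  # never more workers than slices
    mem_per_worker = memory_limit / workers if memory_limit is not None else None
    with ThreadPoolExecutor(max_workers=workers) as pool:  # PhiML uses worker processes instead
        outputs = list(pool.map(prop, slices))  # map keeps input order, like stacking along the dim
    return outputs, mem_per_worker


outputs, mem = toy_parallel_stage([1.0, 2.0, 3.0], lambda v: v * v, max_workers=4, memory_limit=6e9)
```

With three slices and max_workers=4, only three workers are active, so each gets a third of the 6e9 budget.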
def parallel_property(func: Callable = None,
/,
requires: str | Sequence | set | phiml.math._shape.Shape | Callable | object | None = None,
out: Any = INFER,
persistent: bool = False,
on_direct_eval='raise')
def parallel_property(func: Callable = None, /,
                      requires: Union[DimFilter, object] = None,
                      out: Any = INFER,
                      persistent: bool = False,
                      on_direct_eval='raise'):
    """
    Similar to `@cached_property` but with additional controls over parallelization.

    See Also:
        `parallel_compute()`.

    Args:
        func: Method to wrap.
        requires: Dimensions which must be present within one process. These cannot be parallelized when computing this property.
        out: Declare output shapes and dtypes in the same tree structure as the output of `func`.
            Placeholders for `shape` and `dtype` can be created using `shape * dtype`.
            `Shape` instances will be assumed to be of floating-point type.
        persistent: If `True` the output of this property will be available after `parallel_compute` even if it was not specified as a property to be computed,
            as long as its computation is necessary to compute any of the requested properties.
        on_direct_eval: What to do when the property is accessed normally (outside `parallel_compute`) before it has been computed.
            Option:

            * `'raise'`: Raise an error.
            * `'host-compute'`: Compute the property directly, without using multi-threading.
    """
    assert on_direct_eval in {'host-compute', 'raise'}
    if out is not INFER:
        out = to_tracers(out)
    if func is None:
        return partial(parallel_property, requires=requires, out=out, persistent=persistent, on_direct_eval=on_direct_eval)
    return ParallelProperty(func, requires, out, persistent, on_direct_eval)

Similar to @cached_property but with additional controls over parallelization.

See Also: parallel_compute().

Args

func
Method to wrap.
requires
Dimensions which must be present within one process. These cannot be parallelized when computing this property.
out
Declare output shapes and dtypes in the same tree structure as the output of func. Placeholders for shape and dtype can be created using shape * dtype. Shape instances will be assumed to be of floating-point type.
persistent
If True, the output of this property will be available after parallel_compute() even if it was not specified as a property to be computed, as long as its computation is necessary to compute any of the requested properties.
on_direct_eval

What to do when the property is accessed normally (outside parallel_compute()) before it has been computed. Options:

  • 'raise': Raise an error.
  • 'host-compute': Compute the property directly in the accessing process, without parallelization.
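The on_direct_eval semantics can be modeled with a small descriptor that stores computed values in the instance __dict__, as cached_property does. ToyParallelProperty and toy_parallel_property are hypothetical and ignore requires, out and persistent entirely; they only show what happens when a property is accessed before parallel_compute() has filled it in.

```python
from functools import partial


class ToyParallelProperty:
    """Hypothetical non-data descriptor modeling on_direct_eval (not PhiML's ParallelProperty)."""

    def __init__(self, func, on_direct_eval="raise"):
        assert on_direct_eval in {"raise", "host-compute"}
        self.func = func
        self.name = func.__name__
        self.on_direct_eval = on_direct_eval

    def __get__(self, instance, owner=None):
        if instance is None:  # accessed on the class, e.g. to pass it to parallel_compute()
            return self
        # Only reached while instance.__dict__ has no value yet; once set, the stored value wins.
        if self.on_direct_eval == "raise":
            raise RuntimeError(f"'{self.name}' has not been computed yet")
        value = self.func(instance)  # 'host-compute': evaluate directly in this process
        instance.__dict__[self.name] = value  # cache like functools.cached_property
        return value


def toy_parallel_property(func=None, /, on_direct_eval="raise"):
    """Usable both as @toy_parallel_property and @toy_parallel_property(on_direct_eval=...)."""
    if func is None:
        return partial(toy_parallel_property, on_direct_eval=on_direct_eval)
    return ToyParallelProperty(func, on_direct_eval)


class Demo:
    def __init__(self, x):
        self.x = x

    @toy_parallel_property(on_direct_eval="host-compute")
    def squared(self):
        return self.x ** 2

    @toy_parallel_property
    def cubed(self):
        return self.x ** 3


d = Demo(3)
squared = d.squared  # computed on access and cached
try:
    d.cubed
    raised = False
except RuntimeError:
    raised = True
d.__dict__["cubed"] = 27  # what a parallel run would store on the instance
```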
def set_cache_ttl(ttl_seconds: float | None)
def set_cache_ttl(ttl_seconds: Optional[float]):
    """
    Sets the time to live (TTL) for data loaded into memory for disk-backed tensors.
    This function should be called before the tensor cache is used.

    Args:
        ttl_seconds: Time to live. If `None`, data will be unallocated immediately after use.
    """
    _CACHE.ttl = ttl_seconds

Sets the time to live (TTL) for data loaded into memory for disk-backed tensors. This function should be called before the tensor cache is used.

Args

ttl_seconds
Time to live. If None, data is freed immediately after use.
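One possible TTL policy is sketched below, assuming expiry is measured from the time the data was loaded (whether PhiML refreshes the deadline on access is not stated here). TTLCache is hypothetical; the clock is injectable so the behavior can be checked without sleeping.

```python
import time
from typing import Callable, Dict, List, Optional, Tuple


class TTLCache:
    """Hypothetical in-memory cache for data read from disk-backed tensors."""

    def __init__(self, ttl: Optional[float] = None, clock: Callable[[], float] = time.monotonic):
        self.ttl = ttl
        self.clock = clock
        self._entries: Dict[str, Tuple[float, bytes]] = {}  # path -> (expiry time, data)

    def get(self, path: str, load: Callable[[str], bytes]) -> bytes:
        now = self.clock()
        self._evict_expired(now)
        if path in self._entries:
            return self._entries[path][1]  # still alive: no disk read
        data = load(path)
        if self.ttl is not None:  # ttl=None: keep nothing once the caller is done with it
            self._entries[path] = (now + self.ttl, data)
        return data

    def _evict_expired(self, now: float) -> None:
        for path in [p for p, (expiry, _) in self._entries.items() if expiry <= now]:
            del self._entries[path]


# Usage with a fake clock and a loader that records disk reads
now = [0.0]
reads: List[str] = []


def fake_load(path: str) -> bytes:
    reads.append(path)
    return bytes(8)


cache = TTLCache(ttl=10.0, clock=lambda: now[0])
cache.get("s0_i0", fake_load)   # t=0: read from disk
now[0] = 5.0
cache.get("s0_i0", fake_load)   # t=5: served from memory
reads_before_expiry = len(reads)
now[0] = 11.0
cache.get("s0_i0", fake_load)   # t=11: expired, read again
reads_after_expiry = len(reads)

uncached = TTLCache(ttl=None)
uncached.get("s0_i1", fake_load)
uncached.get("s0_i1", fake_load)  # no TTL: every access hits the disk
```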