Module phiml.parallel
This package contains utilities for running PhiML code in parallel. To use it, you need to declare the code to be parallelized as properties of a dataclass. A tutorial can be found here. The following is a simple example that runs a batch of FFTs in parallel:
>>> from dataclasses import dataclass
>>> from functools import cached_property
>>> from phiml.math import Tensor, spatial, fft, batch
>>> from phiml.parallel import parallel_property, parallel_compute, INFER, MIXED
>>>
>>> @dataclass(frozen=True)
... class ParallelComputation:
...     data: Tensor
...
...     @parallel_property(requires=spatial, out=INFER, on_direct_eval='raise')
...     def result(self) -> Tensor:
...         return fft(self.data)
>>>
>>> data = ...
>>> computation = ParallelComputation(data)
>>> parallel_compute(computation, [ParallelComputation.result], parallel_dims=batch, max_workers=4)
>>> computation.result
Properties declared as `@cached_property` behave like `@parallel_property(requires=INFER, out=INFER, on_direct_eval='host-compute')`.
Computational Graph
PhiML performs static code analysis (on source code in version 1.14, on bytecode from 1.15 onwards) to determine dependencies between
all properties declared as either `@parallel_property` or `@cached_property`.
Static analysis traces into methods and properties defined in the same class but does not trace into external functions.
Do not pass a reference to `self` in external calls, as this could lead to dependencies not being captured properly.
The resulting computational graph is split into computation stages depending on the `requires` values of the involved properties.
The `requires` parameter declares dims that must be present on the data in order to compute the result.
Properties that have no required dims can be parallelized across all dims specified in `parallel_compute()`.
Properties that cannot be parallelized at all (because all parallel dims are listed in `requires`) are computed on the host process.
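To illustrate, here is a minimal sketch (the class, field, and property names are made up): `per_item` requires only spatial dims and can be distributed across workers, while `summary` requires the batch dim and therefore runs on the host when parallelizing over batch dims.

```python
from dataclasses import dataclass
from phiml import math
from phiml.math import Tensor, batch, spatial, fft
from phiml.parallel import parallel_property

@dataclass(frozen=True)
class Pipeline:
    data: Tensor  # e.g. dims: example (batch), x (spatial)

    @parallel_property(requires=spatial)  # no parallel dim required -> computed on workers
    def per_item(self) -> Tensor:
        return fft(self.data)

    @parallel_property(requires=batch)  # requires the parallelized dim -> computed on host
    def summary(self) -> Tensor:
        return math.sum(self.per_item, 'example')
```

Requesting `Pipeline.summary` via `parallel_compute(..., parallel_dims=batch)` then yields two stages: `per_item` in parallel, followed by `summary` on the host.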
INFER via Dynamic Traces
When either `requires` or `out` is set to `INFER`, PhiML performs a dynamic trace to infer their values.
These properties may only use PhiML functions, and not all functions are supported yet.
Use `phiml.set_logging_level` to catch failed traces.
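For example, to see why a dynamic trace failed, raise the log verbosity before running `parallel_compute()`:

```python
import phiml

phiml.set_logging_level('debug')  # failed dynamic traces are then reported in the log
```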
MIXED Parallelization
For properties that only use supported PhiML calls, you can set `requires=MIXED`, which allows PhiML to split the computation of the property into multiple stages.
For example, the expression `math.sum(batched_data * 2, 'example,x,y')` would be split into three parts if `example` is parallelized over:
- Multiplication (parallel)
- Sum over `x,y` (parallel)
- Sum over `example` (on host)
Data transfers between workers and host are performed as needed.
Currently, tensor operators (`a + b`), simple one-tensor functions (`abs`, `exp`, `sin`, `round`, `is_nan`, …) and reduction functions (`sum`, `prod`, `max`, `finite_mean`, …) are supported.
Do not enable MIXED for properties that call unsupported functions. While direct calls on tracers will fail, there can still be undesirable side effects.
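A minimal sketch of the example above (the class and property names are made up):

```python
from dataclasses import dataclass
from phiml import math
from phiml.math import Tensor
from phiml.parallel import parallel_property, MIXED

@dataclass(frozen=True)
class BatchedSum:
    batched_data: Tensor  # dims: example, x, y

    @parallel_property(requires=MIXED)  # may be split into parallel and host stages
    def total(self) -> Tensor:
        return math.sum(self.batched_data * 2, 'example,x,y')
```

With `parallel_dims='example'`, the multiplication and the sum over `x,y` run on the workers, and only the final sum over `example` is computed on the host.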
Disk Caching
By configuring parallel_compute(), you can have workers write results onto disk instead of serializing the full result in transfers.
This should be used if system memory is limited or data needs to be passed between processes many times.
See this example for a demonstration.
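A minimal sketch, reusing `computation` and `ParallelComputation` from the introduction and assuming `memory_limit` is specified in bytes:

```python
# Move property outputs exceeding ~4 GB total out of memory onto disk.
parallel_compute(computation, [ParallelComputation.result], parallel_dims=batch,
                 max_workers=4, memory_limit=4e9, cache_dir='phiml_cache')
```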
Functions
def get_cache_files(obj) -> Set[str]
```python
def get_cache_files(obj) -> Set[str]:
    """Searches the data structure for all disk-cached tensors and returns all referenced files."""
    result = set()
    _recursive_add_cache_files(obj, result)
    return result
```

Searches the data structure for all disk-cached tensors and returns all referenced files.
Args
obj: `Tensor` or pytree or dataclass (`phiml.math.magic.PhiTreeNode`).
Returns
Collection of file paths.
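A usage sketch, assuming `computation` holds disk-cached tensors from a previous `parallel_compute()` call with `cache_dir` set:

```python
from phiml.parallel import get_cache_files

for path in get_cache_files(computation):
    print(path)  # each binary file referenced by a disk-backed tensor
```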
def load_cache_as(host_backend: str | phiml.backend._backend.Backend,
                  worker_backend: str | phiml.backend._backend.Backend = None)
```python
@contextlib.contextmanager
def load_cache_as(host_backend: Union[str, Backend], worker_backend: Union[str, Backend] = None):
    """Context manager to temporarily set the backend for loading disk-cached tensors into memory."""
    _LOAD_AS.append(get_backend(host_backend))
    if worker_backend is not None:
        _WORKER_LOAD_AS.append(get_backend(worker_backend))
    try:
        yield
    finally:
        _LOAD_AS.pop()
        if worker_backend is not None:
            _WORKER_LOAD_AS.pop()
```

Context manager to temporarily set the backend for loading disk-cached tensors into memory.
Args
host_backend: Backend name or instance to use in the current process. This also applies to `parallel_compute()` stages that cannot be parallelized and thus run in the host process.
worker_backend: Backend name or instance to use in worker processes, see `parallel_compute()`. If `None`, the workers will use the same backend per tensor as the host.
Usage
>>> with load_cache_as('torch', worker_backend='numpy'):
...     ...

def on_load_into_memory(callback: Callable[[phiml.math._tensors.Tensor], None])
```python
def on_load_into_memory(callback: Callable[[Tensor], None]):
    """Register a function to be called every time a cached tensor is loaded from disk into memory in this process."""
    _ON_LOAD_FROM_DISK.append(callback)
```

Register a function to be called every time a cached tensor is loaded from disk into memory in this process.
Tensors are stored on disk to save memory, e.g. in `parallel_compute()` if a memory limit is specified. When accessing the tensor values or using them in any computation, the values are temporarily loaded into memory and `callback` is called.
Args
callback: Function `callback(Tensor)`.
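For example, to log each time disk-cached data is pulled back into memory (the callback here is just an illustration):

```python
from phiml.parallel import on_load_into_memory

on_load_into_memory(lambda t: print(f"Loaded tensor of shape {t.shape} into memory"))
```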
def parallel_compute(instance,
                     properties: Sequence,
                     parallel_dims=batch,
                     max_workers=multiprocessing.cpu_count(),
                     memory_limit: float | None = None,
                     cache_dir: str = None,
                     keep_intermediate=False)
```python
def parallel_compute(instance, properties: Sequence, parallel_dims=batch, max_workers=multiprocessing.cpu_count(), memory_limit: Optional[float] = None, cache_dir: str = None, keep_intermediate=False):
    """Compute the values of properties decorated with `@cached_property` or `@parallel_property` of a dataclass instance in parallel."""
    assert hasattr(instance, '__dict__'), f"parallel_compute requires instance to have __dict__. Slots are not supported."
    if memory_limit is not None:
        assert cache_dir is not None, "cache_dir must be specified if memory_limit is set"
    dims = shape(instance).only(parallel_dims)
    # --- Build graph of relevant properties ---
    cls = type(instance)
    if is_class_from_notebook(cls):
        class_example = (cls.__name__, class_to_string(cls))
    else:
        class_example = cls.__new__(cls)
    nodes: Dict[str, PGraphNode] = {}
    output_user = PGraphNode('<output>', EMPTY_SHAPE, parallel_dims, None, False, set(), None, False, [], 999999)
    for p in properties:
        recursive_add_node(instance, cls, property_name(p), p, dims, nodes).users.append(output_user)
    # nodes = merge_duplicate_nodes(nodes.values())
    for node in nodes.values():
        if node.name in instance.__dict__:
            node.done = True
            node.dependencies = []
    stages = build_stages(nodes)
    ML_LOGGER.debug(f"Assembled {len(stages)} stages containing {sum(len(ns) for ns in stages)} properties for parallel computation.")
    # --- Execute stages ---
    any_parallel = any(stage_nodes[0].distributed for stage_nodes in stages) and max_workers > 0
    max_workers = min(max_workers, max(stage_nodes[0].distributed.volume for stage_nodes in stages))
    init_args = (_WORKER_LOAD_AS[-1].name if _WORKER_LOAD_AS else None,)
    with ProcessPoolExecutor(initializer=init_worker, initargs=init_args, max_workers=max_workers) if any_parallel else nullcontext() as pool:
        for stage_idx, stage_nodes in enumerate(stages):
            parallel_dims = stage_nodes[0].distributed
            if parallel_dims and any_parallel:
                ML_LOGGER.debug(f"Parallel | {parallel_dims} | {[n.name for n in stage_nodes]}")
                property_names = [n.name for n in stage_nodes if n.is_used_later]
                programs = {n.name: n.program for n in stage_nodes if n.program is not None}
                required_caches = set.union(*[n.prior_dep_names for n in stage_nodes])
                required_fields = set.union(*[n.field_dep_names for n in stage_nodes])
                # --- Split data ---
                instances = unstack(instance, parallel_dims)
                n = len(instances)
                caches = [unstack(instance.__dict__[c], parallel_dims, expand=True) for c in required_caches]
                data = []
                for i, inst_i in enumerate(instances):
                    data_i = {}
                    data_i.update(**{f_name: inst_i.__dict__[f_name] for f_name in required_fields})
                    if caches:
                        data_i.update(**{c: caches[j][i] for j, c in enumerate(required_caches)})
                    data.append(data_i)
                keep_intermediate or delete_intermediate_caches(instance, stages, stage_idx)
                # --- Submit to pool ---
                mem_per_item = memory_limit / min(max_workers, len(instances)) if memory_limit is not None else None
                cache_dir is not None and os.makedirs(cache_dir, exist_ok=True)
                cache_file_suggestions = [os.path.join(cache_dir, f"s{stage_idx}_i{i}") for i in range(len(instances))] if cache_dir is not None else [None] * len(instances)
                results = list(pool.map(_evaluate_properties, [class_example]*n, [property_names]*n, [programs]*n, data, [mem_per_item]*n, cache_file_suggestions))
                for name, *outputs in zip(property_names, *results):
                    output = stack(outputs, parallel_dims)
                    instance.__dict__[name] = output
            else:  # No parallelization in this stage
                ML_LOGGER.debug(f"Host | {[n.name for n in stage_nodes]}")
                _EXECUTION_STATUS.append('host')
                for n in stage_nodes:
                    get_property_value(instance, n.name, n.program)
                assert _EXECUTION_STATUS.pop() == 'host', "Host execution status mismatch. (internal error)"
                keep_intermediate or delete_intermediate_caches(instance, stages, stage_idx)
```

Compute the values of properties decorated with `@cached_property` or `@parallel_property` of a dataclass instance in parallel.

Multiple stages via `requires=...`

If `@parallel_property` fields are computed whose `requires` overlaps with `parallel_dims`, a separate computation stage is set up to compute these properties with fewer parallel workers. In the presence of different `requires`, the computation is split into different stages in accordance with the property dependency graph. Properties that cannot be parallelized (because they require all `parallel_dims`) are computed on the host process.

Caching tensors to disk

When `memory_limit` and `cache_dir` are set, the evaluation will try to adhere to the given memory limits by moving tensors out of memory onto disk. This is only applied to the outputs of `@cached_property` and `@parallel_property` calls, not to intermediate values used in their computation. The per-process memory limit is calculated per stage, dividing the total memory by the active worker count. Cached tensors behave like regular tensors and are temporarily loaded back into memory when accessing their values or using them in a computation. When parallelizing, the full result is assembled by stacking multiple disk-backed tensors from different files created by different processes. These composite tensors will reference multiple binary files and can be pickled/unpickled safely without loading the data into memory. This enables passing large data references to different processes or saving the structure to a file without the data content.

See Also
`cached_property`, `parallel_property()`, `get_cache_files()`, `on_load_into_memory()`.

Warnings
`parallel_compute()` breaks automatic differentiation.

Args
instance: Dataclass instance for which to compute the values of `@cached_property` or `@parallel_property` fields.
properties: References to the unbound properties. These must be `cached_property` or `parallel_property`.
parallel_dims: Dimensions to parallelize over.
max_workers: Number of processes to spawn.
memory_limit: Limit to the total memory consumption from `Tensor` instances on property outputs.
cache_dir: Directory path to store cached tensors in if `memory_limit` is set.
keep_intermediate: Whether the outputs of cached properties required to compute `properties` but not contained in `properties` should be kept in memory. If `False`, these values will not be cached on `instance` after this call.
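A usage sketch based on the introductory example, additionally keeping the caches of intermediate properties that `result` may depend on:

```python
parallel_compute(computation, [ParallelComputation.result], parallel_dims=batch,
                 max_workers=4, keep_intermediate=True)
```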
def parallel_property(func: Callable = None,
                      /,
                      requires: str | Sequence | set | phiml.math._shape.Shape | Callable | object | None = None,
                      out: Any = INFER,
                      persistent: bool = False,
                      on_direct_eval='raise')
```python
def parallel_property(func: Callable = None, /, requires: Union[DimFilter, object] = None, out: Any = INFER, persistent: bool = False, on_direct_eval='raise'):
    """Similar to `@cached_property` but with additional controls over parallelization."""
    assert on_direct_eval in {'host-compute', 'raise'}
    if out is not INFER:
        out = to_tracers(out)
    if func is None:
        return partial(parallel_property, requires=requires, out=out, persistent=persistent, on_direct_eval=on_direct_eval)
    return ParallelProperty(func, requires, out, persistent, on_direct_eval)
```

Similar to `@cached_property` but with additional controls over parallelization.

See Also
`parallel_compute()`.

Args
func: Method to wrap.
requires: Dimensions which must be present within one process. These cannot be parallelized when computing this property.
out: Declare output shapes and dtypes in the same tree structure as the output of `func`. Placeholders for `shape` and `dtype` can be created using `shape * dtype`. `Shape` instances will be assumed to be of floating-point type.
persistent: If `True`, the output of this property will be available after `parallel_compute()` even if it was not specified as a property to be computed, as long as its computation is necessary to compute any of the requested properties.
on_direct_eval: What to do when the property is accessed normally (outside `parallel_compute()`) before it has been computed. Options:
- `'raise'`: Raise an error.
- `'host-compute'`: Compute the property directly, without using multi-threading.
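A sketch combining these options; the `spectrum` property is hypothetical:

```python
from dataclasses import dataclass
from phiml.math import Tensor, spatial, fft
from phiml.parallel import parallel_property

@dataclass(frozen=True)
class ParallelComputation:
    data: Tensor

    @parallel_property(requires=spatial, persistent=True, on_direct_eval='host-compute')
    def spectrum(self) -> Tensor:
        # persistent=True: kept on the instance after parallel_compute() even if
        # only computed as a dependency of another requested property.
        # on_direct_eval='host-compute': direct access computes the value in the
        # current process instead of raising an error.
        return abs(fft(self.data))
```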
def set_cache_ttl(ttl_seconds: float | None)
```python
def set_cache_ttl(ttl_seconds: Optional[float]):
    """Sets the time to live (TTL) for data loaded into memory for disk-backed tensors."""
    _CACHE.ttl = ttl_seconds
```

Sets the time to live (TTL) for data loaded into memory for disk-backed tensors. This function should be called before the tensor cache is used.
Args
ttl_seconds: Time to live. If `None`, data will be unallocated immediately after use.
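A brief usage sketch illustrating both modes:

```python
from phiml.parallel import set_cache_ttl

set_cache_ttl(30.0)  # keep loaded data in memory for 30 seconds after use
set_cache_ttl(None)  # release loaded data immediately after each use
```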