Module phiml.dataclasses

PhiML makes it easy to work with custom classes. Any class decorated with @dataclass can be used with phiml.math functions, such as shape(), slice(), stack(), concat(), expand() and many more. We recommend always setting frozen=True.

PhiML's dataclass support explicitly handles properties decorated with functools.cached_property. Their cached values are preserved whenever possible, preventing costly re-computation when unrelated fields are modified or when objects are sliced, gathered, or stacked. This usually applies to all data fields, i.e. fields that hold Tensor instances or composite (dataclass) values.

Dataclass fields can additionally be classified as variable and value attributes. This affects which data_fields are optimized / traced by functions like jit_compile() or minimize().

Template for custom classes:

>>> from typing import Tuple
>>> from dataclasses import dataclass
>>> from phiml.dataclasses import sliceable, cached_property
>>> from phiml.math import Tensor, Shape, shape
>>>
>>> @sliceable
>>> @dataclass(frozen=True)
>>> class MyClass:
>>>     # --- Attributes ---
>>>     attribute1: Tensor
>>>     attribute2: 'MyClass' = None
>>>
>>>     # --- Additional fields ---
>>>     field1: str = 'x'
>>>
>>>     # --- Special fields declaring attribute types. Must be of type Tuple[str, ...] ---
>>>     variable_attrs: Tuple[str, ...] = ('attribute1', 'attribute2')
>>>     value_attrs: Tuple[str, ...] = ()
>>>
>>>     def __post_init__(self):
>>>         assert self.field1 in 'xyz'
>>>
>>>     @cached_property
>>>     def shape(self) -> Shape:  # override the default shape which is merged from all attribute shapes
>>>         return self.attribute1.shape & shape(self.attribute2)
>>>
>>>     @cached_property  # the cache will be copied to derived instances unless attribute1 changes (this is analyzed from the code)
>>>     def derived_property(self) -> Tensor:
>>>         return self.attribute1 + 1

Functions

def config_fields(obj) ‑> Sequence[dataclasses.Field]

List all dataclass Fields of obj that are not considered data_fields or special. These cannot hold any Tensors or shaped objects.

Args

obj
Dataclass type or instance.

Returns

Sequence of dataclasses.Field.
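The split between config fields and data fields can be illustrated with a stdlib-only sketch. This is not phiml's actual classification logic (which inspects types recursively); Tensor here is a hypothetical stand-in class and the type check is a crude approximation:

```python
from dataclasses import dataclass, fields

class Tensor:  # hypothetical stand-in for phiml.math.Tensor, for illustration only
    pass

@dataclass(frozen=True)
class Grid:
    values: Tensor                # data field: holds a Tensor
    extrapolation: str = 'zeros'  # config field: cannot hold tensors

def sketch_data_fields(cls):
    # Crude approximation: classify by the annotated type.
    return [f for f in fields(cls) if f.type is Tensor]

def sketch_config_fields(cls):
    return [f for f in fields(cls) if f.type is not Tensor]

print([f.name for f in sketch_data_fields(Grid)])    # ['values']
print([f.name for f in sketch_config_fields(Grid)])  # ['extrapolation']
```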

def copy(obj: ~PhiMLDataclass, /, call_metaclass=False) ‑> ~PhiMLDataclass

Create a copy of obj, including cached properties.

Args

obj
Dataclass instance.
call_metaclass
Whether to copy obj by invoking type(obj).__call__. If obj defines a metaclass, this will allow users to define custom constructors for dataclasses.
def data_eq(cls=None, /, *, rel_tolerance=0.0, abs_tolerance=0.0, equal_nan=True, compare_tensors_by_ref=False)
def data_fields(obj) ‑> Sequence[dataclasses.Field]

List all dataclass Fields of obj that are considered data, i.e. can hold (directly or indirectly) one or multiple Tensor instances. This includes fields referencing other dataclasses.

Args

obj
Dataclass type or instance.

Returns

Sequence of dataclasses.Field.

def equal(obj1, obj2, rel_tolerance=0.0, abs_tolerance=0.0, equal_nan=True)

Checks if two objects are equal within the given tolerances.

Args

obj1
First object to compare.
obj2
Second object to compare.
rel_tolerance
Relative tolerance for comparing tensor values.
abs_tolerance
Absolute tolerance for comparing tensor values.
equal_nan
Whether NaN values count as equal to each other.

Returns

Whether obj1 and obj2 are equal within the given tolerances.

def get_cache_files(obj) ‑> Set[str]

Searches the data structure for all disk-cached tensors and returns all referenced files.

Args

obj
Tensor or pytree or dataclass (PhiTreeNode).

Returns

Collection of file paths.

def getitem(obj: ~PhiMLDataclass, item, keepdims: Union[str, Sequence[+T_co], set, phiml.math._shape.Shape, Callable, None] = None) ‑> ~PhiMLDataclass

Slice / gather a dataclass by broadcasting the operation to its data_fields.

You may call this from __getitem__ to allow the syntax my_class[component_str], my_class[slicing_dict], my_class[boolean_tensor] and my_class[index_tensor].

def __getitem__(self, item):
    return getitem(self, item)

Args

obj
Dataclass instance to slice / gather.
item
One of the supported tensor slicing / gathering values.
keepdims
Dimensions that will not be removed during slicing. When selecting a single slice, these dims will remain with size 1.

Returns

Slice of obj of same type.
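The delegation pattern above can be sketched with plain Python, independent of phiml. Points is a made-up example class, and slicing here is simple sequence indexing rather than tensor gathering:

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class Points:
    x: list
    y: list
    label: str = 'points'  # non-data field, passed through unchanged

    def __getitem__(self, item):
        # Broadcast the slice to all data fields; config fields stay as-is.
        return replace(self, x=self.x[item], y=self.y[item])

p = Points([1, 2, 3], [4, 5, 6])
print(p[1:])  # Points(x=[2, 3], y=[5, 6], label='points')
```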

def non_data_fields(obj) ‑> Sequence[dataclasses.Field]

List all dataclass Fields of obj that cannot hold tensors (directly or indirectly).

Args

obj
Dataclass type or instance.

Returns

Sequence of dataclasses.Field.

def on_load_into_memory(callback: Callable[[phiml.math._tensors.Tensor], None])

Register a function to be called every time a cached tensor is loaded from disk into memory in this process.

Tensors are stored on disk to save memory, e.g. in parallel_compute() if a memory limit is specified.

When accessing the tensor values or using them in any computation, the values are temporarily loaded into memory and callback is called.

Args

callback
Function callback(Tensor).
def parallel_compute(instance, properties: Sequence[+T_co], parallel_dims=<function batch>, max_workers=4, memory_limit: Optional[float] = None, cache_dir: str = None, keep_intermediate=False)

Compute the values of properties decorated with @cached_property or @parallel_property of a dataclass instance in parallel.

Multiple stages via requires=...

If properties decorated with @parallel_property are computed whose requires overlaps with parallel_dims, a separate computation stage is set up to compute these properties with fewer parallel workers. In the presence of differing requires, the computation is split into multiple stages in accordance with the property dependency graph. Properties that cannot be parallelized at all (because they require all parallel_dims) are computed on the host process.

Caching tensors to disk

When memory_limit and cache_dir are set, the evaluation tries to adhere to the given memory limit by moving tensors out of memory onto disk. This only applies to the outputs of @cached_property and @parallel_property calls, not to intermediate values used in their computation. The per-process memory limit is calculated per stage, dividing the total memory by the active worker count.

Cached tensors behave like regular tensors and are temporarily loaded back into memory when their values are accessed or used in a computation. When parallelizing, the full result is assembled by stacking multiple disk-backed tensors from different files created by different processes. These composite tensors reference multiple binary files and can be pickled and unpickled safely without loading the data into memory. This enables passing large data references to other processes or saving the structure to a file without the data content.

See Also: cached_property, parallel_property(), get_cache_files(), on_load_into_memory().

Warnings

parallel_compute() breaks automatic differentiation.

Args

instance
Dataclass instance for which to compute the values of @cached_property or @parallel_property fields.
properties
References to the unbound properties. These must be cached_property or parallel_property().
parallel_dims
Dimensions to parallelize over.
max_workers
Number of processes to spawn.
memory_limit
Limit to the total memory consumption from Tensor instances on property outputs.
cache_dir
Directory path to store cached tensors in if memory_limit is set.
keep_intermediate
Whether the outputs of cached properties that are required to compute properties but are not themselves contained in properties should be kept in memory. If False, these values will not be cached on instance after this call.
def parallel_property(func: Callable = None, /, requires: Union[str, Sequence[+T_co], set, phiml.math._shape.Shape, Callable, None] = None, on_direct_eval='raise')

Similar to @cached_property but with additional controls over parallelization.

See Also: parallel_compute().

Args

func
Method to wrap.
requires
Dimensions which must be present within one process. These cannot be parallelized when computing this property.
on_direct_eval

What to do when the property is accessed normally (outside parallel_compute()) before it has been computed. Options:

  • 'raise': Raise an error.
  • 'host-compute': Compute the property directly, without using multi-threading.
def replace(obj: ~PhiMLDataclass, /, call_metaclass=False, **changes) ‑> ~PhiMLDataclass

Create a copy of obj with some fields replaced. Unlike dataclasses.replace(), this function also transfers @cached_property members if their dependencies are not affected.

Args

obj
Dataclass instance.
call_metaclass
Whether to copy obj by invoking type(obj).__call__. If obj defines a metaclass, this will allow users to define custom constructors for dataclasses.
**changes
New field values to replace old ones.

Returns

Copy of obj with replaced values.
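A stdlib-only demonstration of the problem this function solves: dataclasses.replace() creates a fresh instance with an empty __dict__, so cached properties are recomputed even when the fields they depend on are unchanged. The Item class below is a made-up example:

```python
from dataclasses import dataclass, replace
from functools import cached_property

calls = []  # records every evaluation of the property body

@dataclass(frozen=True)
class Item:
    a: int
    b: int

    @cached_property
    def expensive(self):
        calls.append(self.a)
        return self.a * 10

x = Item(a=1, b=2)
_ = x.expensive      # computed once and cached on x
y = replace(x, b=3)  # stdlib replace: the cache is NOT carried over
_ = y.expensive      # recomputed, although it only depends on `a`
print(calls)         # [1, 1] -> two evaluations
```

phiml.dataclasses.replace() avoids the second evaluation by transferring the cached value, since expensive does not depend on the changed field b.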

def set_cache_ttl(ttl_seconds: Optional[float])

Sets the time to live (TTL) for data loaded into memory for disk-backed tensors. This function should be called before the tensor cache is used.

Args

ttl_seconds
Time to live. If None, data will be unallocated immediately after use.
def sliceable(cls=None, /, *, dim_attrs=True, t_props=True, keepdims=None, dim_repr=True, lazy_dims=True)

Decorator for frozen dataclasses, adding slicing functionality by defining __getitem__ and enabling the instance.dim syntax. This enables slicing similar to tensors, gathering and boolean masking.

Args

dim_attrs
Whether to generate __getattr__ that allows slicing via the syntax instance.dim[…] where dim is the name of any dim present on instance.
t_props
Whether to generate the properties Tc, Ts and Ti for transposing channel/spatial/instance dims.
keepdims
Which dimensions should be kept with size 1 when taking a single slice along them. This preserves labels.
dim_repr
Whether to replace the default repr of a dataclass by a simplified one based on the object's shape.
lazy_dims
If False, instantiates all dims of shape(self) as member variables during construction. Dataclass must have slots=False. If True, implements __getattr__ to instantiate accessed dims on demand. This will be skipped if a user-defined __getattr__ is found.
def special_fields(obj) ‑> Sequence[dataclasses.Field]

List all special dataclass Fields of obj, i.e. fields that don't store data related to the object but rather meta-information relevant to PhiML.

These include variable_attrs and value_attrs.

Args

obj
Dataclass type or instance.

Returns

Sequence of dataclasses.Field.

Classes

class cached_property (func)
def __get__(self, instance, owner=None):
    if instance is None:
        return self
    if self.attrname is None:
        raise TypeError(
            "Cannot use cached_property instance without calling __set_name__ on it.")
    try:
        cache = instance.__dict__
    except AttributeError:  # not all objects have __dict__ (e.g. class defines slots)
        msg = (
            f"No '__dict__' attribute on {type(instance).__name__!r} "
            f"instance to cache {self.attrname!r} property."
        )
        raise TypeError(msg) from None
    val = cache.get(self.attrname, _NOT_FOUND)
    if val is _NOT_FOUND:
        with self.lock:
            # check if another thread filled cache while we awaited lock
            val = cache.get(self.attrname, _NOT_FOUND)
            if val is _NOT_FOUND:
                val = self.func(instance)
                try:
                    cache[self.attrname] = val
                except TypeError:
                    msg = (
                        f"The '__dict__' attribute on {type(instance).__name__!r} instance "
                        f"does not support item assignment for caching {self.attrname!r} property."
                    )
                    raise TypeError(msg) from None
    return val
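The source above follows the standard functools.cached_property protocol: the first access computes the value and stores it in instance.__dict__ (bypassing __setattr__, which is why it works on frozen dataclasses without slots); later accesses hit the cache. A minimal demonstration with a made-up Counter class:

```python
from functools import cached_property

class Counter:
    def __init__(self):
        self.computed = 0  # counts evaluations of the property body

    @cached_property
    def value(self):
        self.computed += 1
        return 42

c = Counter()
assert c.value == 42
assert c.value == 42   # second access hits the cache in c.__dict__
print(c.computed)      # 1 -> the function body ran only once
```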

Subclasses

  • phiml.dataclasses._parallel.ParallelProperty