Module phiml.dataclasses

PhiML makes it easy to work with custom classes. Any class decorated with @dataclass can be used with phiml.math functions, such as shape(), slice(), stack(), concat(), expand() and many more. We recommend always setting frozen=True.

PhiML's dataclass support explicitly handles properties decorated with functools.cached_property. Their cached values are preserved whenever possible, avoiding costly re-computation when modifying unrelated properties or when slicing, gathering, or stacking objects. This usually affects all data fields, i.e. fields that hold Tensor or composite properties.

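Dataclass fields can additionally be declared as variable and/or value attributes by listing their names in the special fields variable_attrs and value_attrs, as in the template below. This affects which data_fields are optimized / traced by functions like jit_compile() or minimize().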

Template for custom classes:

>>> from typing import Tuple
>>> from dataclasses import dataclass
>>> from phiml.dataclasses import sliceable, cached_property
>>> from phiml.math import Tensor, Shape, shape
>>>
>>> @sliceable
>>> @dataclass(frozen=True)
>>> class MyClass:
>>>     # --- Attributes ---
>>>     attribute1: Tensor
>>>     attribute2: 'MyClass' = None
>>>
>>>     # --- Additional fields ---
>>>     field1: str = 'x'
>>>
>>>     # --- Special fields declaring attribute types. Must be of type Tuple[str, ...] ---
>>>     variable_attrs: Tuple[str, ...] = ('attribute1', 'attribute2')
>>>     value_attrs: Tuple[str, ...] = ()
>>>
>>>     def __post_init__(self):
>>>         assert self.field1 in 'xyz'
>>>
>>>     @cached_property
>>>     def shape(self) -> Shape:  # override the default shape which is merged from all attribute shapes
>>>         return self.attribute1.shape & shape(self.attribute2)
>>>
>>>     @cached_property  # the cache will be copied to derived instances unless attribute1 changes (this is analyzed from the code)
>>>     def derived_property(self) -> Tensor:
>>>         return self.attribute1 + 1

Functions

def config_fields(obj) ‑> Sequence[dataclasses.Field]

List all dataclass Fields of obj that are neither data_fields nor special_fields. These cannot hold any Tensors or shaped objects.

Args

obj
Dataclass type or instance.

Returns

Sequence of dataclasses.Field.

def data_eq(cls=None, /, *, rel_tolerance=0.0, abs_tolerance=0.0, equal_nan=True, compare_tensors_by_ref=False)

def data_fields(obj) ‑> Sequence[dataclasses.Field]

List all dataclass Fields of obj that are considered data, i.e. can hold (directly or indirectly) one or multiple Tensor instances. This includes fields referencing other dataclasses.

Args

obj
Dataclass type or instance.

Returns

Sequence of dataclasses.Field.

def equal(obj1, obj2, rel_tolerance=0.0, abs_tolerance=0.0, equal_nan=True)

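Checks if two objects are equal, comparing contained tensor values up to the given tolerances.

Args

obj1
First object to compare.
obj2
Second object to compare.
rel_tolerance
Relative tolerance for comparing numerical values.
abs_tolerance
Absolute tolerance for comparing numerical values.
equal_nan
Whether NaN values are considered equal to each other.

Returns

Whether obj1 and obj2 are equal.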
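To illustrate what rel_tolerance, abs_tolerance and equal_nan control, here is a minimal stdlib sketch of a field-wise, tolerance-based comparison. equal_sketch and Point are hypothetical names, plain floats stand in for Tensors, and math.isclose is used as the closeness test; PhiML's own tensor comparison may combine the two tolerances differently.

```python
import math
from dataclasses import dataclass, fields


def _values_close(a, b, rel_tolerance, abs_tolerance, equal_nan):
    """Compare two field values: floats use the tolerances, everything else uses ==."""
    if isinstance(a, float) and isinstance(b, float):
        if math.isnan(a) or math.isnan(b):
            # NaN only matches NaN, and only when equal_nan is set
            return equal_nan and math.isnan(a) and math.isnan(b)
        return math.isclose(a, b, rel_tol=rel_tolerance, abs_tol=abs_tolerance)
    return a == b


def equal_sketch(obj1, obj2, rel_tolerance=0.0, abs_tolerance=0.0, equal_nan=True):
    """Hypothetical field-wise equality check for two dataclass instances."""
    if type(obj1) is not type(obj2):
        return False
    return all(
        _values_close(getattr(obj1, f.name), getattr(obj2, f.name),
                      rel_tolerance, abs_tolerance, equal_nan)
        for f in fields(obj1)
    )


@dataclass(frozen=True)
class Point:
    x: float
    y: float
    label: str = 'p'


a = Point(1.0, float('nan'))
b = Point(1.0 + 1e-9, float('nan'))
print(equal_sketch(a, b))                                       # False: x differs, zero tolerance
print(equal_sketch(a, b, abs_tolerance=1e-6))                   # True
print(equal_sketch(a, b, abs_tolerance=1e-6, equal_nan=False))  # False: NaN != NaN
```

With both tolerances at their default of 0.0, any numerical difference makes the objects unequal, which matches the defaults in the signature above.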

def getitem(obj: ~PhiMLDataclass, item, keepdims: Union[str, Sequence, set, 'Shape', Callable, None] = None) ‑> ~PhiMLDataclass

Slice / gather a dataclass by broadcasting the operation to its data_fields.

You may call this from __getitem__ to allow the syntax my_class[component_str], my_class[slicing_dict], my_class[boolean_tensor] and my_class[index_tensor].

def __getitem__(self, item):
    return getitem(self, item)

Args

obj
Dataclass instance to slice / gather.
item
One of the supported tensor slicing / gathering values.
keepdims
Dimensions that will not be removed during slicing. When selecting a single slice, these dims will remain with size 1.

Returns

Slice of obj of same type.
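The broadcasting idea can be sketched with the standard library alone: apply the same index to every data field and rebuild the frozen instance with dataclasses.replace(). This is a minimal sketch, assuming that list-valued fields play the role of data fields (standing in for Tensors) and that all other fields are configuration; getitem_sketch and Particles are hypothetical names, and named dims, boolean masks and keepdims are not modeled.

```python
import dataclasses
from dataclasses import dataclass


def getitem_sketch(obj, item):
    """Apply the same index/slice to every list-valued field of a frozen dataclass.

    Assumption for this sketch: lists stand in for Tensor data fields; all other
    fields are copied unchanged.
    """
    changes = {}
    for f in dataclasses.fields(obj):
        value = getattr(obj, f.name)
        if isinstance(value, list):
            changes[f.name] = value[item]
    return dataclasses.replace(obj, **changes)  # new instance, obj is untouched


@dataclass(frozen=True)
class Particles:
    position: list
    mass: list
    name: str = 'particles'

    def __getitem__(self, item):  # delegate, as suggested for getitem()
        return getitem_sketch(self, item)


p = Particles(position=[0.0, 1.0, 2.0], mass=[1.0, 2.0, 3.0])
print(p[1:])  # Particles(position=[1.0, 2.0], mass=[2.0, 3.0], name='particles')
print(p[0])   # Particles(position=0.0, mass=1.0, name='particles')
```

Because every data field receives the same item, the fields stay consistent with each other after slicing.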

def non_data_fields(obj) ‑> Sequence[dataclasses.Field]

List all dataclass Fields of obj that cannot hold tensors (directly or indirectly).

Args

obj
Dataclass type or instance.

Returns

Sequence of dataclasses.Field.

def replace(obj: ~PhiMLDataclass, /, call_metaclass=False, **changes) ‑> ~PhiMLDataclass

Create a copy of obj with some fields replaced. Unlike dataclasses.replace(), this function also transfers @cached_property members if their dependencies are not affected.

Args

obj
Dataclass instance.
call_metaclass
Whether to copy obj by invoking type(obj).__call__. If the class of obj has a custom metaclass, this allows users to define custom constructors for dataclasses.
**changes
New field values to replace old ones.

Returns

Copy of obj with replaced values.
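The effect of transferring @cached_property values can be illustrated with a stdlib sketch. It assumes the dependencies of each cached property are passed in by hand through a hypothetical depends_on mapping, whereas PhiML's replace() works them out itself; replace_keep_cache and Body are illustrative names only. Cached values live in the instance __dict__, so copying them across works even for frozen dataclasses.

```python
import dataclasses
from dataclasses import dataclass
from functools import cached_property


def replace_keep_cache(obj, depends_on, **changes):
    """Hypothetical helper: dataclasses.replace() plus transfer of still-valid caches.

    depends_on maps each cached_property name to the set of field names it reads.
    """
    new = dataclasses.replace(obj, **changes)  # __init__ runs, caches are not copied
    for prop, deps in depends_on.items():
        still_valid = deps.isdisjoint(changes)  # none of its dependencies was replaced
        if prop in obj.__dict__ and still_valid:
            # Write straight into __dict__, as cached_property itself does;
            # this bypasses the frozen __setattr__ check.
            new.__dict__[prop] = obj.__dict__[prop]
    return new


@dataclass(frozen=True)
class Body:
    mass: float
    label: str = 'body'
    calls = 0  # class attribute (no annotation), not a dataclass field

    @cached_property
    def inverse_mass(self) -> float:
        Body.calls += 1  # count how often the "expensive" computation runs
        return 1.0 / self.mass


deps = {'inverse_mass': {'mass'}}
b = Body(2.0)
b.inverse_mass                                       # computed, Body.calls == 1
renamed = replace_keep_cache(b, deps, label='moon')
renamed.inverse_mass                                 # reused, Body.calls stays 1
heavier = replace_keep_cache(b, deps, mass=4.0)
print(heavier.inverse_mass, Body.calls)              # 0.25 2 (recomputed)
```

Plain dataclasses.replace() would recompute inverse_mass in both copies; the sketch only pays for it when mass actually changes.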

def sliceable(cls=None, /, *, dim_attrs=True, keepdims=None, dim_repr=True)

Decorator for frozen dataclasses, adding slicing functionality by defining __getitem__. This enables tensor-like slicing, gathering and boolean masking.

Args

dim_attrs
Whether to generate __getattr__ that allows slicing via the syntax instance.dim[…] where dim is the name of any dim present on instance.
keepdims
Which dimensions should be kept with size 1 when taking a single slice along them. This preserves item names.
dim_repr
Whether to replace the default repr of a dataclass by a simplified one based on the object's shape.

def special_fields(obj) ‑> Sequence[dataclasses.Field]

List all special dataclass Fields of obj, i.e. fields that don't store data related to the object but rather meta-information relevant to PhiML.

These include variable_attrs and value_attrs.

Args

obj
Dataclass type or instance.

Returns

Sequence of dataclasses.Field.
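To make the split between special, config and data fields concrete, here is a rough stdlib sketch that partitions the fields of a dataclass. The rule it applies is an assumption for illustration, not PhiML's actual logic: fields named variable_attrs or value_attrs count as special, fields annotated with plain str, int, float or bool count as config, and everything else is treated as data. split_fields and Example are hypothetical names.

```python
import dataclasses
from dataclasses import dataclass
from typing import Tuple, get_type_hints

# Assumed rule for this sketch only (not PhiML's implementation).
_PLAIN_TYPES = (str, int, float, bool)
_SPECIAL_NAMES = ('variable_attrs', 'value_attrs')


def split_fields(cls_or_obj):
    """Return (data, config, special) lists of dataclasses.Field.

    Under this sketch, the non-data fields are config + special.
    """
    cls = cls_or_obj if isinstance(cls_or_obj, type) else type(cls_or_obj)
    hints = get_type_hints(cls)  # resolve annotations to actual types
    data, config, special = [], [], []
    for f in dataclasses.fields(cls):
        hint = hints[f.name]
        if f.name in _SPECIAL_NAMES:
            special.append(f)
        elif isinstance(hint, type) and issubclass(hint, _PLAIN_TYPES):
            config.append(f)
        else:
            data.append(f)  # e.g. Tensor or nested dataclass annotations
    return data, config, special


@dataclass(frozen=True)
class Example:
    values: list  # stands in for a Tensor-valued field
    label: str = 'x'
    variable_attrs: Tuple[str, ...] = ('values',)
    value_attrs: Tuple[str, ...] = ()


data, config, special = split_fields(Example)
print([f.name for f in data])     # ['values']
print([f.name for f in config])   # ['label']
print([f.name for f in special])  # ['variable_attrs', 'value_attrs']
```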

Classes

class cached_property (func)
def __get__(self, instance, owner=None):
    if instance is None:
        return self
    if self.attrname is None:
        raise TypeError(
            "Cannot use cached_property instance without calling __set_name__ on it.")
    try:
        cache = instance.__dict__
    except AttributeError:  # not all objects have __dict__ (e.g. class defines slots)
        msg = (
            f"No '__dict__' attribute on {type(instance).__name__!r} "
            f"instance to cache {self.attrname!r} property."
        )
        raise TypeError(msg) from None
    val = cache.get(self.attrname, _NOT_FOUND)
    if val is _NOT_FOUND:
        with self.lock:
            # check if another thread filled cache while we awaited lock
            val = cache.get(self.attrname, _NOT_FOUND)
            if val is _NOT_FOUND:
                val = self.func(instance)
                try:
                    cache[self.attrname] = val
                except TypeError:
                    msg = (
                        f"The '__dict__' attribute on {type(instance).__name__!r} instance "
                        f"does not support item assignment for caching {self.attrname!r} property."
                    )
                    raise TypeError(msg) from None
    return val
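The source above stores the computed value in the instance __dict__ under the property's name. Because cached_property defines no __set__, it is a non-data descriptor: after the first access the instance attribute shadows it, so the method never runs again. Writing to __dict__ directly is also why it works with frozen dataclasses. A short demonstration using the standard functools.cached_property (Grid is an illustrative name):

```python
from dataclasses import dataclass, FrozenInstanceError
from functools import cached_property


@dataclass(frozen=True)
class Grid:
    size: int
    computations = 0  # class-level counter, not a dataclass field

    @cached_property
    def cells(self) -> int:
        Grid.computations += 1
        return self.size * self.size


g = Grid(3)
print(g.cells, g.cells)   # 9 9, computed only once
print(Grid.computations)  # 1
print(vars(g))            # {'size': 3, 'cells': 9}
try:
    g.size = 4            # regular fields remain immutable
except FrozenInstanceError:
    print('frozen')
```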