Module phiml.dataclasses

PhiML makes it easy to work with custom classes. Any class decorated with @dataclass can be used with phiml.math functions, such as shape(), slice(), stack(), concat(), expand() and many more. We recommend always setting frozen=True.

PhiML's dataclass support explicitly handles properties decorated with functools.cached_property. Their cached values are preserved whenever possible, preventing costly re-computation when modifying unrelated fields, slicing, gathering, or stacking objects. This usually applies to all data fields, i.e. fields that hold a Tensor or a composite object.

Dataclass fields can additionally be declared as variable and/or value fields. This affects which data fields are optimized or traced by functions like jit_compile() or minimize().

Template for custom classes:

>>> from typing import Tuple
>>> from dataclasses import dataclass
>>> from phiml.dataclasses import sliceable, cached_property
>>> from phiml.math import Tensor, Shape, shape
>>>
>>> @sliceable
>>> @dataclass(frozen=True)
>>> class MyClass:
>>>     # --- Attributes ---
>>>     attribute1: Tensor
>>>     attribute2: 'MyClass' = None
>>>
>>>     # --- Additional fields ---
>>>     field1: str = 'x'
>>>
>>>     # --- Special fields declaring attribute types. Must be of type Tuple[str, ...] ---
>>>     variable_attrs: Tuple[str, ...] = ('attribute1', 'attribute2')
>>>     value_attrs: Tuple[str, ...] = ()
>>>
>>>     def __post_init__(self):
>>>         assert self.field1 in 'xyz'
>>>
>>>     @cached_property
>>>     def shape(self) -> Shape:  # override the default shape which is merged from all attribute shapes
>>>         return self.attribute1.shape & shape(self.attribute2)
>>>
>>>     @cached_property  # the cache will be copied to derived instances unless attribute1 changes (this is analyzed from the code)
>>>     def derived_property(self) -> Tensor:
>>>         return self.attribute1 + 1

Functions

def config_fields(obj) ‑> Sequence[dataclasses.Field]
def config_fields(obj) -> Sequence[dataclasses.Field]:
    """
    List all dataclass Fields of `obj` that are not considered data_fields or special.
    These cannot hold any Tensors or shaped objects.

    Args:
        obj: Dataclass type or instance.

    Returns:
        Sequence of `dataclasses.Field`.
    """
    return [f for f in dataclasses.fields(obj) if not is_data_field(f) and f.name not in ('variable_attrs', 'value_attrs')]

List all dataclass Fields of obj that are not considered data_fields or special. These cannot hold any Tensors or shaped objects.

Args

obj
Dataclass type or instance.

Returns

Sequence of dataclasses.Field.

def copy(obj: ~PhiMLDataclass, /, call_metaclass=False) ‑> ~PhiMLDataclass
def copy(obj: PhiMLDataclass, /, call_metaclass=False) -> PhiMLDataclass:
    """
    Create a copy of `obj`, including cached properties.

    Args:
        obj: Dataclass instance.
        call_metaclass: Whether to copy `obj` by invoking `type(obj).__call__`.
            If `obj` defines a metaclass, this will allow users to define custom constructors for dataclasses.
    """
    return replace(obj, call_metaclass=call_metaclass)

Create a copy of obj, including cached properties.

Args

obj
Dataclass instance.
call_metaclass
Whether to copy obj by invoking type(obj).__call__. If obj defines a metaclass, this will allow users to define custom constructors for dataclasses.

def data_eq(cls=None, /, *, rel_tolerance=0.0, abs_tolerance=0.0, equal_nan=True, compare_tensors_by_ref=False)
def data_eq(cls=None, /, *, rel_tolerance=0., abs_tolerance=0., equal_nan=True, compare_tensors_by_ref=False):
    """
    Decorator for dataclasses that overrides the default `__eq__` method to compare data fields by shape and value instead of equality.
    Non-data fields are compared by equality (`==`).

    See Also:
        `equal()`, `data_fields()`, `non_data_fields()`.

    Args:
        rel_tolerance: Relative tolerance for comparing floating point tensors.
        abs_tolerance: Absolute tolerance for comparing floating point tensors.
        equal_nan: Whether to consider NaN values as equal.
        compare_tensors_by_ref: If True, compares all tensors by reference (location in memory) instead of value.
            This avoids costly element-wise comparisons, but may count objects as unequal even if they hold the same data.
    """
    def wrap(cls):
        assert cls.__dataclass_params__.eq, f"@data_eq can only be used with dataclasses with eq=True."
        cls.__default_dataclass_eq__ = cls.__eq__
        def __tensor_eq__(obj, other):
            if compare_tensors_by_ref:
                with equality_by_ref():
                    return cls.__default_dataclass_eq__(obj, other)
            with equality_by_shape_and_value(rel_tolerance, abs_tolerance, equal_nan):
                return cls.__default_dataclass_eq__(obj, other)
        cls.__eq__ = __tensor_eq__
        # __ne__ calls `not __eq__()` by default
        return cls
    return wrap(cls) if cls is not None else wrap  # See if we're being called as @dataclass or @dataclass().

Decorator for dataclasses that overrides the default __eq__ method to compare data fields by shape and value instead of equality. Non-data fields are compared by equality (==).

See Also: equal(), data_fields(), non_data_fields().

Args

rel_tolerance
Relative tolerance for comparing floating point tensors.
abs_tolerance
Absolute tolerance for comparing floating point tensors.
equal_nan
Whether to consider NaN values as equal.
compare_tensors_by_ref
If True, compares all tensors by reference (location in memory) instead of value. This avoids costly element-wise comparisons, but may count objects as unequal even if they hold the same data.
def data_fields(obj) ‑> Sequence[dataclasses.Field]
def data_fields(obj) -> Sequence[dataclasses.Field]:
    """
    List all dataclass Fields of `obj` that are considered data, i.e. can hold (directly or indirectly) one or multiple `Tensor` instances.
    This includes fields referencing other dataclasses.

    Args:
        obj: Dataclass type or instance.

    Returns:
        Sequence of `dataclasses.Field`.
    """
    return [f for f in dataclasses.fields(obj) if is_data_field(f)]

List all dataclass Fields of obj that are considered data, i.e. can hold (directly or indirectly) one or multiple Tensor instances. This includes fields referencing other dataclasses.

Args

obj
Dataclass type or instance.

Returns

Sequence of dataclasses.Field.

def equal(obj1, obj2, rel_tolerance=0.0, abs_tolerance=0.0, equal_nan=True)
def equal(obj1, obj2, rel_tolerance=0., abs_tolerance=0., equal_nan=True):
    """
    Checks if two dataclass instances are equal by comparing their data fields by value and shape.
    Non-data fields are compared by equality (`==`).

    Args:
        obj1: First dataclass instance.
        obj2: Second dataclass instance.
        rel_tolerance: Relative tolerance for comparing floating point tensors.
        abs_tolerance: Absolute tolerance for comparing floating point tensors.
        equal_nan: Whether to consider NaN values as equal.

    Returns:
        `bool`
    """
    cls = type(obj1)
    eq_fn = cls.__default_dataclass_eq__ if hasattr(cls, '__default_dataclass_eq__') else cls.__eq__
    with equality_by_shape_and_value(rel_tolerance, abs_tolerance, equal_nan):
        return eq_fn(obj1, obj2)

Checks if two dataclass instances are equal by comparing their data fields by value and shape. Non-data fields are compared by equality (==).

Args

obj1
First dataclass instance.
obj2
Second dataclass instance.
rel_tolerance
Relative tolerance for comparing floating point tensors.
abs_tolerance
Absolute tolerance for comparing floating point tensors.
equal_nan
Whether to consider NaN values as equal.

Returns

bool

def getitem(obj: ~PhiMLDataclass, item, keepdims: str | Sequence | set | phiml.math._shape.Shape | Callable | None = None) ‑> ~PhiMLDataclass
def getitem(obj: PhiMLDataclass, item, keepdims: DimFilter = None) -> PhiMLDataclass:
    """
    Slice / gather a dataclass by broadcasting the operation to its data_fields.

    You may call this from `__getitem__` to allow the syntax `my_class[component_str]`, `my_class[slicing_dict]`, `my_class[boolean_tensor]` and `my_class[index_tensor]`.

    ```python
    def __getitem__(self, item):
        return getitem(self, item)
    ```

    Args:
        obj: Dataclass instance to slice / gather.
        item: One of the supported tensor slicing / gathering values.
        keepdims: Dimensions that will not be removed during slicing.
            When selecting a single slice, these dims will remain with size 1.

    Returns:
        Slice of `obj` of same type.
    """
    assert dataclasses.is_dataclass(obj), f"obj must be a dataclass but got {type(obj)}"
    item = slicing_dict(obj, item)
    if keepdims:
        keep = shape(obj).only(keepdims)
        for dim, sel in item.items():
            if dim in keep:
                if isinstance(sel, int):
                    item[dim] = slice(sel, sel+1)
                elif isinstance(sel, str) and ',' not in sel:
                    item[dim] = [sel]
    if not item:
        return obj
    attrs = data_fields(obj)
    kwargs = {f.name: slice_(getattr(obj, f.name), item) if f in attrs else getattr(obj, f.name) for f in dataclasses.fields(obj)}
    cls = type(obj)
    new_obj = cls.__new__(cls, **kwargs)
    new_obj.__init__(**kwargs)
    cache = {k: slice_(v, item) for k, v in obj.__dict__.items() if isinstance(getattr(type(obj), k, None), cached_property) and not isinstance(v, SHAPE_TYPES)}
    new_obj.__dict__.update(cache)
    return new_obj

Slice / gather a dataclass by broadcasting the operation to its data_fields.

You may call this from __getitem__ to allow the syntax my_class[component_str], my_class[slicing_dict], my_class[boolean_tensor] and my_class[index_tensor].

def __getitem__(self, item):
    return getitem(self, item)

Args

obj
Dataclass instance to slice / gather.
item
One of the supported tensor slicing / gathering values.
keepdims
Dimensions that will not be removed during slicing. When selecting a single slice, these dims will remain with size 1.

Returns

Slice of obj of same type.

def non_data_fields(obj) ‑> Sequence[dataclasses.Field]
def non_data_fields(obj) -> Sequence[dataclasses.Field]:
    """
    List all dataclass Fields of `obj` that cannot hold tensors (directly or indirectly).

    Args:
        obj: Dataclass type or instance.

    Returns:
        Sequence of `dataclasses.Field`.
    """
    return [f for f in dataclasses.fields(obj) if not is_data_field(f)]

List all dataclass Fields of obj that cannot hold tensors (directly or indirectly).

Args

obj
Dataclass type or instance.

Returns

Sequence of dataclasses.Field.

def replace(obj: ~PhiMLDataclass, /, call_metaclass=False, **changes) ‑> ~PhiMLDataclass
def replace(obj: PhiMLDataclass, /, call_metaclass=False, **changes) -> PhiMLDataclass:
    """
    Create a copy of `obj` with some fields replaced.
    Unlike `dataclasses.replace()`, this function also transfers `@cached_property` members if their dependencies are not affected.

    Args:
        obj: Dataclass instance.
        call_metaclass: Whether to copy `obj` by invoking `type(obj).__call__`.
            If `obj` defines a metaclass, this will allow users to define custom constructors for dataclasses.
        **changes: New field values to replace old ones.

    Returns:
        Copy of `obj` with replaced values.
    """
    cls = obj.__class__
    kwargs = {f.name: getattr(obj, f.name) for f in dataclasses.fields(obj)}
    kwargs.update(**changes)
    if call_metaclass:
        new_obj = cls(**kwargs)
    else:  # This allows us to override the dataclass constructor with a metaclass for user convenience, but not call it internally.
        new_obj = cls.__new__(cls)
        new_obj.__init__(**kwargs)
    cache = get_unchanged_cache(obj, set(changes.keys()))
    new_obj.__dict__.update(cache)
    return new_obj

Create a copy of obj with some fields replaced. Unlike dataclasses.replace(), this function also transfers @cached_property members if their dependencies are not affected.

Args

obj
Dataclass instance.
call_metaclass
Whether to copy obj by invoking type(obj).__call__. If obj defines a metaclass, this will allow users to define custom constructors for dataclasses.
**changes
New field values to replace old ones.

Returns

Copy of obj with replaced values.

def sliceable(cls=None, /, *, dim_attrs=True, t_props=True, keepdims=None, dim_repr=True, lazy_dims=True)
def sliceable(cls=None, /, *, dim_attrs=True, t_props=True, keepdims=None, dim_repr=True, lazy_dims=True):
    """
    Decorator for frozen dataclasses, adding slicing functionality by defining `__getitem__` and enabling the `instance.dim` syntax.
    This enables slicing similar to tensors, gathering and boolean masking.

    Args:
        dim_attrs: Whether to generate `__getattr__` that allows slicing via the syntax `instance.dim[...]` where `dim` is the name of any dim present on `instance`.
        t_props: Whether to generate the properties `Tc`, `Ts` and `Ti` for transposing channel/spatial/instance dims.
        keepdims: Which dimensions should be kept with size 1 when taking a single slice along them. This will preserve labels.
        dim_repr: Whether to replace the default `repr` of a dataclass by a simplified one based on the object's shape.
        lazy_dims: If `False`, instantiates all dims of `shape(self)` as member variables during construction. Dataclass must have `slots=False`.
            If `True`, implements `__getattr__` to instantiate accessed dims on demand. This will be skipped if a user-defined `__getattr__` is found.
    """
    def wrap(cls):
        assert dataclasses.is_dataclass(cls), f"@sliceable must be used on a @dataclass, i.e. declared above it."
        assert cls.__dataclass_params__.frozen, f"@sliceable dataclasses must be frozen. Declare as @dataclass(frozen=True)"
        assert data_fields(cls), f"PhiML dataclasses must have at least one field storing a Shaped object, such as a Tensor, tree of Tensors or compatible dataclass."
        if not implements(cls, '__getitem__', exclude_metaclass=True):
            def __dataclass_getitem__(obj, item):
                return getitem(obj, item, keepdims=keepdims)
            cls.__getitem__ = __dataclass_getitem__
        if t_props:
            def transpose(obj, dim_type):
                old_shape = shape(obj)
                new_shape = old_shape.transpose(dim_type)
                return rename_dims(obj, old_shape, new_shape)
            cls.Tc = property(partial(transpose, dim_type=CHANNEL_DIM))
            cls.Ts = property(partial(transpose, dim_type=SPATIAL_DIM))
            cls.Ti = property(partial(transpose, dim_type=INSTANCE_DIM))
        if not lazy_dims:  # instantiate BoundDims in constructor
            assert not hasattr(cls, '__slots__'), f"front-loading dims is not supported for dataclasses using slots."
            dc_init = cls.__init__
            def __dataclass_init__(self, *args, **kwargs):
                dc_init(self, *args, **kwargs)
                for dim in shape(self):
                    object.__setattr__(self, dim.name, BoundDim(self, dim.name))  # object.__setattr__ also works for frozen dataclasses
            cls.__init__ = __dataclass_init__
        else:  # instantiate BoundDims lazily via __getattr__
            if dim_attrs and not implements(cls, '__getattr__'):
                def __dataclass_getattr__(obj, name: str):
                    if name == 'shape' or (name.startswith('__') and name.endswith('__')):  # __setstate__, __deepcopy__ can cause infinite recursion
                        raise AttributeError(f"'{type(obj)}' instance has no attribute '{name}'")
                    if name in shape(obj):
                        return BoundDim(obj, name)
                    elif hasattr(type(obj), name):
                        raise RuntimeError(f"Evaluation of property '{type(obj).__name__}.{name}' failed.")
                    else:
                        raise AttributeError(f"'{type(obj)}' instance has no attribute '{name}'")
                cls.__getattr__ = __dataclass_getattr__
        if dim_repr:
            def __dataclass_repr__(obj):
                try:
                    content = shape(obj)
                    if not content:
                        content = f"{', '.join([f'{f.name}={getattr(obj, f.name)}' for f in dataclasses.fields(cls)])}"
                except BaseException as err:
                    content = f"Unknown shape: {type(err).__name__}"
                return f"{type(obj).__name__}[{content}]"
            cls.__repr__ = __dataclass_repr__
        return cls
    return wrap(cls) if cls is not None else wrap  # See if we're being called as @dataclass or @dataclass().

Decorator for frozen dataclasses, adding slicing functionality by defining __getitem__ and enabling the instance.dim syntax. This enables slicing similar to tensors, gathering and boolean masking.

Args

dim_attrs
Whether to generate __getattr__ that allows slicing via the syntax instance.dim[…] where dim is the name of any dim present on instance.
t_props
Whether to generate the properties Tc, Ts and Ti for transposing channel/spatial/instance dims.
keepdims
Which dimensions should be kept with size 1 when taking a single slice along them. This will preserve labels.
dim_repr
Whether to replace the default repr of a dataclass by a simplified one based on the object's shape.
lazy_dims
If False, instantiates all dims of shape(self) as member variables during construction. Dataclass must have slots=False. If True, implements __getattr__ to instantiate accessed dims on demand. This will be skipped if a user-defined __getattr__ is found.
def special_fields(obj) ‑> Sequence[dataclasses.Field]
def special_fields(obj) -> Sequence[dataclasses.Field]:
    """
    List all special dataclass Fields of `obj`, i.e. fields that don't store data related to the object but rather meta-information relevant to PhiML.

    These include `variable_attrs` and `value_attrs`.

    Args:
        obj: Dataclass type or instance.

    Returns:
        Sequence of `dataclasses.Field`.
    """
    return [f for f in dataclasses.fields(obj) if f.name in ('variable_attrs', 'value_attrs')]

List all special dataclass Fields of obj, i.e. fields that don't store data related to the object but rather meta-information relevant to PhiML.

These include variable_attrs and value_attrs.

Args

obj
Dataclass type or instance.

Returns

Sequence of dataclasses.Field.

Classes

class cached_property (func)
def __get__(self, instance, owner=None):
    if instance is None:
        return self
    if self.attrname is None:
        raise TypeError(
            "Cannot use cached_property instance without calling __set_name__ on it.")
    try:
        cache = instance.__dict__
    except AttributeError:  # not all objects have __dict__ (e.g. class defines slots)
        msg = (
            f"No '__dict__' attribute on {type(instance).__name__!r} "
            f"instance to cache {self.attrname!r} property."
        )
        raise TypeError(msg) from None
    val = cache.get(self.attrname, _NOT_FOUND)
    if val is _NOT_FOUND:
        with self.lock:
            # check if another thread filled cache while we awaited lock
            val = cache.get(self.attrname, _NOT_FOUND)
            if val is _NOT_FOUND:
                val = self.func(instance)
                try:
                    cache[self.attrname] = val
                except TypeError:
                    msg = (
                        f"The '__dict__' attribute on {type(instance).__name__!r} instance "
                        f"does not support item assignment for caching {self.attrname!r} property."
                    )
                    raise TypeError(msg) from None
    return val

Subclasses

  • phiml.parallel._parallel.ParallelProperty