Module phiml.dataclasses
PhiML makes it easy to work with custom classes.
Any class decorated with @dataclass can be used with phiml.math functions, such as shape(), slice(), stack(), concat(), expand() and many more.
We recommend always setting frozen=True.
PhiML's dataclass support explicitly handles properties decorated with functools.cached_property.
Their cached values are preserved whenever possible, preventing costly re-computation when modifying unrelated fields, or when slicing, gathering, or stacking objects.
This usually applies to all data fields, i.e. fields that hold a Tensor or a composite object containing tensors.
Dataclass fields can additionally be marked as variable and/or value.
This affects which data_fields are optimized / traced by functions like jit_compile() or minimize().
Template for custom classes:
>>> from typing import Tuple
>>> from dataclasses import dataclass
>>> from phiml.dataclasses import sliceable, cached_property
>>> from phiml.math import Tensor, Shape, shape
>>>
>>> @sliceable
>>> @dataclass(frozen=True)
>>> class MyClass:
>>>     # --- Attributes ---
>>>     attribute1: Tensor
>>>     attribute2: 'MyClass' = None
>>>
>>>     # --- Additional fields ---
>>>     field1: str = 'x'
>>>
>>>     # --- Special fields declaring attribute types. Must be of type Tuple[str, ...] ---
>>>     variable_attrs: Tuple[str, ...] = ('attribute1', 'attribute2')
>>>     value_attrs: Tuple[str, ...] = ()
>>>
>>>     def __post_init__(self):
>>>         assert self.field1 in 'xyz'
>>>
>>>     @cached_property
>>>     def shape(self) -> Shape:  # override the default shape, which is merged from all attribute shapes
>>>         return self.attribute1.shape & shape(self.attribute2)
>>>
>>>     @cached_property  # the cache will be copied to derived instances unless attribute1 changes (this is analyzed from the code)
>>>     def derived_property(self) -> Tensor:
>>>         return self.attribute1 + 1
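The cache preservation described above builds on the standard functools.cached_property mechanism, which stores computed values in the instance's `__dict__`. A stdlib-only sketch of that mechanism (no phiml required; the `Point` class and its names are illustrative, not part of phiml):

```python
from dataclasses import dataclass
from functools import cached_property

@dataclass(frozen=True)  # frozen blocks __setattr__, but cached_property writes to __dict__ directly
class Point:
    x: float
    y: float

    @cached_property
    def norm_sq(self) -> float:
        # computed once per instance, then served from self.__dict__
        return self.x ** 2 + self.y ** 2

p = Point(3.0, 4.0)
assert 'norm_sq' not in p.__dict__   # not computed yet
assert p.norm_sq == 25.0
assert 'norm_sq' in p.__dict__       # now cached; phiml can transfer such entries to derived instances
```

Because the cache lives in `__dict__`, this works for frozen dataclasses as long as they do not declare `__slots__`.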
Functions
def config_fields(obj) ‑> Sequence[dataclasses.Field]

    def config_fields(obj) -> Sequence[dataclasses.Field]:
        """
        List all dataclass Fields of `obj` that are not considered data_fields or special.
        These cannot hold any Tensors or shaped objects.

        Args:
            obj: Dataclass type or instance.

        Returns:
            Sequence of `dataclasses.Field`.
        """
        return [f for f in dataclasses.fields(obj)
                if not is_data_field(f) and f.name not in ('variable_attrs', 'value_attrs')]

List all dataclass Fields of `obj` that are not considered data_fields or special. These cannot hold any Tensors or shaped objects.

Args
    obj: Dataclass type or instance.

Returns
    Sequence of `dataclasses.Field`.

def copy(obj: ~PhiMLDataclass, /, call_metaclass=False) ‑> ~PhiMLDataclass
    def copy(obj: PhiMLDataclass, /, call_metaclass=False) -> PhiMLDataclass:
        """
        Create a copy of `obj`, including cached properties.

        Args:
            obj: Dataclass instance.
            call_metaclass: Whether to copy `obj` by invoking `type(obj).__call__`.
                If `obj` defines a metaclass, this will allow users to define custom constructors for dataclasses.
        """
        return replace(obj, call_metaclass=call_metaclass)

Create a copy of `obj`, including cached properties.

Args
    obj: Dataclass instance.
    call_metaclass: Whether to copy `obj` by invoking `type(obj).__call__`. If `obj` defines a metaclass, this will allow users to define custom constructors for dataclasses.
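The cache transfer that copy() performs can be sketched with the stdlib alone: create a new instance without re-running field computation and carry over the instance `__dict__`, which holds both the fields and any cached_property values. This is an illustrative analog (the `Box` class and `copy_with_cache` are hypothetical names), not phiml's actual implementation:

```python
from dataclasses import dataclass
from functools import cached_property

@dataclass(frozen=True)
class Box:
    size: float

    @cached_property
    def volume(self) -> float:
        return self.size ** 3

def copy_with_cache(obj):
    # Bypass __init__, then transfer fields and cached_property entries,
    # both of which live in the instance __dict__ (no __slots__).
    cls = type(obj)
    new_obj = cls.__new__(cls)
    new_obj.__dict__.update(obj.__dict__)
    return new_obj

b = Box(2.0)
_ = b.volume                        # populate the cache
c = copy_with_cache(b)
assert c.size == 2.0
assert c.__dict__['volume'] == 8.0  # cache carried over, no re-computation
```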
def data_eq(cls=None, /, *, rel_tolerance=0.0, abs_tolerance=0.0, equal_nan=True, compare_tensors_by_ref=False)
    def data_eq(cls=None, /, *, rel_tolerance=0., abs_tolerance=0., equal_nan=True, compare_tensors_by_ref=False):
        """
        Decorator for dataclasses that overrides the default `__eq__` method to compare data fields by shape and value instead of equality.
        Non-data fields are compared by equality (`==`).

        See Also:
            `equal()`, `data_fields()`, `non_data_fields()`.

        Args:
            rel_tolerance: Relative tolerance for comparing floating point tensors.
            abs_tolerance: Absolute tolerance for comparing floating point tensors.
            equal_nan: Whether to consider NaN values as equal.
            compare_tensors_by_ref: If True, compares all tensors by reference (location in memory) instead of value.
                This avoids costly element-wise comparisons, but may count objects as unequal even if they hold the same data.
        """
        def wrap(cls):
            assert cls.__dataclass_params__.eq, f"@data_eq can only be used with dataclasses with eq=True."
            cls.__default_dataclass_eq__ = cls.__eq__
            def __tensor_eq__(obj, other):
                if compare_tensors_by_ref:
                    with equality_by_ref():
                        return cls.__default_dataclass_eq__(obj, other)
                with equality_by_shape_and_value(rel_tolerance, abs_tolerance, equal_nan):
                    return cls.__default_dataclass_eq__(obj, other)
            cls.__eq__ = __tensor_eq__  # __ne__ calls `not __eq__()` by default
            return cls
        return wrap(cls) if cls is not None else wrap  # See if we're being called as @dataclass or @dataclass().

Decorator for dataclasses that overrides the default `__eq__` method to compare data fields by shape and value instead of equality. Non-data fields are compared by equality (`==`).

See Also: `equal()`, `data_fields()`, `non_data_fields()`.

Args
    rel_tolerance: Relative tolerance for comparing floating point tensors.
    abs_tolerance: Absolute tolerance for comparing floating point tensors.
    equal_nan: Whether to consider NaN values as equal.
    compare_tensors_by_ref: If True, compares all tensors by reference (location in memory) instead of value. This avoids costly element-wise comparisons, but may count objects as unequal even if they hold the same data.
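The decorator pattern used by data_eq — overriding the dataclass-generated `__eq__` with a tolerance-aware comparison — can be illustrated with a stdlib-only analog that compares float fields via `math.isclose`. The names `approx_eq` and `Reading` are hypothetical; phiml's real implementation compares tensors by shape and value instead:

```python
import math
from dataclasses import dataclass, fields

def approx_eq(cls=None, /, *, rel_tolerance=0.0, abs_tolerance=1e-9):
    def wrap(cls):
        def __eq__(self, other):
            if type(other) is not type(self):
                return NotImplemented
            # compare every field within the given tolerances
            return all(
                math.isclose(getattr(self, f.name), getattr(other, f.name),
                             rel_tol=rel_tolerance, abs_tol=abs_tolerance)
                for f in fields(self))
        cls.__eq__ = __eq__  # __ne__ negates __eq__ by default
        return cls
    return wrap(cls) if cls is not None else wrap  # support @approx_eq and @approx_eq(...)

@approx_eq(abs_tolerance=1e-6)
@dataclass(frozen=True, eq=True)
class Reading:
    value: float

assert Reading(1.0) == Reading(1.0 + 1e-9)  # within tolerance
assert Reading(1.0) != Reading(1.1)
```

The `wrap(cls) if cls is not None else wrap` return mirrors the trick in data_eq's source that lets the decorator be applied with or without parentheses.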
def data_fields(obj) ‑> Sequence[dataclasses.Field]

    def data_fields(obj) -> Sequence[dataclasses.Field]:
        """
        List all dataclass Fields of `obj` that are considered data, i.e. can hold (directly or indirectly)
        one or multiple `Tensor` instances. This includes fields referencing other dataclasses.

        Args:
            obj: Dataclass type or instance.

        Returns:
            Sequence of `dataclasses.Field`.
        """
        return [f for f in dataclasses.fields(obj) if is_data_field(f)]

List all dataclass Fields of `obj` that are considered data, i.e. can hold (directly or indirectly) one or multiple `Tensor` instances. This includes fields referencing other dataclasses.

Args
    obj: Dataclass type or instance.

Returns
    Sequence of `dataclasses.Field`.

def equal(obj1, obj2, rel_tolerance=0.0, abs_tolerance=0.0, equal_nan=True)
    def equal(obj1, obj2, rel_tolerance=0., abs_tolerance=0., equal_nan=True):
        """
        Checks if two dataclass instances are equal by comparing their data fields by value and shape.
        Non-data fields are compared by equality (`==`).

        Args:
            obj1: First dataclass instance.
            obj2: Second dataclass instance.
            rel_tolerance: Relative tolerance for comparing floating point tensors.
            abs_tolerance: Absolute tolerance for comparing floating point tensors.
            equal_nan: Whether to consider NaN values as equal.

        Returns:
            `bool`
        """
        cls = type(obj1)
        eq_fn = cls.__default_dataclass_eq__ if hasattr(cls, '__default_dataclass_eq__') else cls.__eq__
        with equality_by_shape_and_value(rel_tolerance, abs_tolerance, equal_nan):
            return eq_fn(obj1, obj2)

Checks if two dataclass instances are equal by comparing their data fields by value and shape. Non-data fields are compared by equality (`==`).

Args
    obj1: First dataclass instance.
    obj2: Second dataclass instance.
    rel_tolerance: Relative tolerance for comparing floating point tensors.
    abs_tolerance: Absolute tolerance for comparing floating point tensors.
    equal_nan: Whether to consider NaN values as equal.

Returns
    `bool`

def getitem(obj: ~PhiMLDataclass, item, keepdims: str | Sequence | set | phiml.math._shape.Shape | Callable | None = None) ‑> ~PhiMLDataclass
    def getitem(obj: PhiMLDataclass, item, keepdims: DimFilter = None) -> PhiMLDataclass:
        """
        Slice / gather a dataclass by broadcasting the operation to its data_fields.

        You may call this from `__getitem__` to allow the syntax `my_class[component_str]`,
        `my_class[slicing_dict]`, `my_class[boolean_tensor]` and `my_class[index_tensor]`.

            def __getitem__(self, item):
                return getitem(self, item)

        Args:
            obj: Dataclass instance to slice / gather.
            item: One of the supported tensor slicing / gathering values.
            keepdims: Dimensions that will not be removed during slicing.
                When selecting a single slice, these dims will remain with size 1.

        Returns:
            Slice of `obj` of same type.
        """
        assert dataclasses.is_dataclass(obj), f"obj must be a dataclass but got {type(obj)}"
        item = slicing_dict(obj, item)
        if keepdims:
            keep = shape(obj).only(keepdims)
            for dim, sel in item.items():
                if dim in keep:
                    if isinstance(sel, int):
                        item[dim] = slice(sel, sel+1)
                    elif isinstance(sel, str) and ',' not in sel:
                        item[dim] = [sel]
        if not item:
            return obj
        attrs = data_fields(obj)
        kwargs = {f.name: slice_(getattr(obj, f.name), item) if f in attrs else getattr(obj, f.name)
                  for f in dataclasses.fields(obj)}
        cls = type(obj)
        new_obj = cls.__new__(cls, **kwargs)
        new_obj.__init__(**kwargs)
        cache = {k: slice_(v, item) for k, v in obj.__dict__.items()
                 if isinstance(getattr(type(obj), k, None), cached_property) and not isinstance(v, SHAPE_TYPES)}
        new_obj.__dict__.update(cache)
        return new_obj

Slice / gather a dataclass by broadcasting the operation to its data_fields.

You may call this from `__getitem__` to allow the syntax `my_class[component_str]`, `my_class[slicing_dict]`, `my_class[boolean_tensor]` and `my_class[index_tensor]`.

    def __getitem__(self, item):
        return getitem(self, item)

Args
    obj: Dataclass instance to slice / gather.
    item: One of the supported tensor slicing / gathering values.
    keepdims: Dimensions that will not be removed during slicing. When selecting a single slice, these dims will remain with size 1.

Returns
    Slice of `obj` of same type.

def non_data_fields(obj) ‑> Sequence[dataclasses.Field]
    def non_data_fields(obj) -> Sequence[dataclasses.Field]:
        """
        List all dataclass Fields of `obj` that cannot hold tensors (directly or indirectly).

        Args:
            obj: Dataclass type or instance.

        Returns:
            Sequence of `dataclasses.Field`.
        """
        return [f for f in dataclasses.fields(obj) if not is_data_field(f)]

List all dataclass Fields of `obj` that cannot hold tensors (directly or indirectly).

Args
    obj: Dataclass type or instance.

Returns
    Sequence of `dataclasses.Field`.

def replace(obj: ~PhiMLDataclass, /, call_metaclass=False, **changes) ‑> ~PhiMLDataclass
    def replace(obj: PhiMLDataclass, /, call_metaclass=False, **changes) -> PhiMLDataclass:
        """
        Create a copy of `obj` with some fields replaced.
        Unlike `dataclasses.replace()`, this function also transfers `@cached_property` members if their dependencies are not affected.

        Args:
            obj: Dataclass instance.
            call_metaclass: Whether to copy `obj` by invoking `type(obj).__call__`.
                If `obj` defines a metaclass, this will allow users to define custom constructors for dataclasses.
            **changes: New field values to replace old ones.

        Returns:
            Copy of `obj` with replaced values.
        """
        cls = obj.__class__
        kwargs = {f.name: getattr(obj, f.name) for f in dataclasses.fields(obj)}
        kwargs.update(**changes)
        if call_metaclass:
            new_obj = cls(**kwargs)
        else:  # This allows us to override the dataclass constructor with a metaclass for user convenience, but not call it internally.
            new_obj = cls.__new__(cls)
            new_obj.__init__(**kwargs)
        cache = get_unchanged_cache(obj, set(changes.keys()))
        new_obj.__dict__.update(cache)
        return new_obj

Create a copy of `obj` with some fields replaced. Unlike `dataclasses.replace()`, this function also transfers `@cached_property` members if their dependencies are not affected.

Args
    obj: Dataclass instance.
    call_metaclass: Whether to copy `obj` by invoking `type(obj).__call__`. If `obj` defines a metaclass, this will allow users to define custom constructors for dataclasses.
    **changes: New field values to replace old ones.

Returns
    Copy of `obj` with replaced values.

def sliceable(cls=None, /, *, dim_attrs=True, t_props=True, keepdims=None, dim_repr=True, lazy_dims=True)
    def sliceable(cls=None, /, *, dim_attrs=True, t_props=True, keepdims=None, dim_repr=True, lazy_dims=True):
        """
        Decorator for frozen dataclasses, adding slicing functionality by defining `__getitem__` and enabling the `instance.dim` syntax.
        This enables slicing similar to tensors, gathering and boolean masking.

        Args:
            dim_attrs: Whether to generate `__getattr__` that allows slicing via the syntax `instance.dim[...]`
                where `dim` is the name of any dim present on `instance`.
            t_props: Whether to generate the properties `Tc`, `Ts` and `Ti` for transposing channel/spatial/instance dims.
            keepdims: Which dimensions should be kept with size 1 when taking a single slice along them. This will preserve labels.
            dim_repr: Whether to replace the default `repr` of a dataclass by a simplified one based on the object's shape.
            lazy_dims: If `False`, instantiates all dims of `shape(self)` as member variables during construction.
                Dataclass must have `slots=False`. If `True`, implements `__getattr__` to instantiate accessed dims on demand.
                This will be skipped if a user-defined `__getattr__` is found.
        """
        def wrap(cls):
            assert dataclasses.is_dataclass(cls), f"@sliceable must be used on a @dataclass, i.e. declared above it."
            assert cls.__dataclass_params__.frozen, f"@sliceable dataclasses must be frozen. Declare as @dataclass(frozen=True)"
            assert data_fields(cls), f"PhiML dataclasses must have at least one field storing a Shaped object, such as a Tensor, tree of Tensors or compatible dataclass."
            if not implements(cls, '__getitem__', exclude_metaclass=True):
                def __dataclass_getitem__(obj, item):
                    return getitem(obj, item, keepdims=keepdims)
                cls.__getitem__ = __dataclass_getitem__
            if t_props:
                def transpose(obj, dim_type):
                    old_shape = shape(obj)
                    new_shape = old_shape.transpose(dim_type)
                    return rename_dims(obj, old_shape, new_shape)
                cls.Tc = property(partial(transpose, dim_type=CHANNEL_DIM))
                cls.Ts = property(partial(transpose, dim_type=SPATIAL_DIM))
                cls.Ti = property(partial(transpose, dim_type=INSTANCE_DIM))
            if not lazy_dims:  # instantiate BoundDims in constructor
                assert not hasattr(cls, '__slots__'), f"front-loading dims is not supported for dataclasses using slots."
                dc_init = cls.__init__
                def __dataclass_init__(self, *args, **kwargs):
                    dc_init(self, *args, **kwargs)
                    for dim in shape(self):
                        object.__setattr__(self, dim.name, BoundDim(self, dim.name))  # object.__setattr__ also works for frozen dataclasses
                cls.__init__ = __dataclass_init__
            else:  # instantiate BoundDims lazily via __getattr__
                if dim_attrs and not implements(cls, '__getattr__'):
                    def __dataclass_getattr__(obj, name: str):
                        if name == 'shape' or (name.startswith('__') and name.endswith('__')):  # __setstate__, __deepcopy__ can cause infinite recursion
                            raise AttributeError(f"'{type(obj)}' instance has no attribute '{name}'")
                        if name in shape(obj):
                            return BoundDim(obj, name)
                        elif hasattr(type(obj), name):
                            raise RuntimeError(f"Evaluation of property '{type(obj).__name__}.{name}' failed.")
                        else:
                            raise AttributeError(f"'{type(obj)}' instance has no attribute '{name}'")
                    cls.__getattr__ = __dataclass_getattr__
            if dim_repr:
                def __dataclass_repr__(obj):
                    try:
                        content = shape(obj)
                        if not content:
                            content = f"{', '.join([f'{f.name}={getattr(obj, f.name)}' for f in dataclasses.fields(cls)])}"
                    except BaseException as err:
                        content = f"Unknown shape: {type(err).__name__}"
                    return f"{type(obj).__name__}[{content}]"
                cls.__repr__ = __dataclass_repr__
            return cls
        return wrap(cls) if cls is not None else wrap  # See if we're being called as @dataclass or @dataclass().

Decorator for frozen dataclasses, adding slicing functionality by defining `__getitem__` and enabling the `instance.dim` syntax. This enables slicing similar to tensors, gathering and boolean masking.

Args
    dim_attrs: Whether to generate `__getattr__` that allows slicing via the syntax `instance.dim[...]` where `dim` is the name of any dim present on `instance`.
    t_props: Whether to generate the properties `Tc`, `Ts` and `Ti` for transposing channel/spatial/instance dims.
    keepdims: Which dimensions should be kept with size 1 when taking a single slice along them. This will preserve labels.
    dim_repr: Whether to replace the default `repr` of a dataclass by a simplified one based on the object's shape.
    lazy_dims: If `False`, instantiates all dims of `shape(self)` as member variables during construction. Dataclass must have `slots=False`. If `True`, implements `__getattr__` to instantiate accessed dims on demand. This will be skipped if a user-defined `__getattr__` is found.
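The `instance.dim[...]` syntax that @sliceable generates can be sketched in plain Python: `__getattr__` returns a small helper whose `__getitem__` forwards the selection back to the owner's `__getitem__` as a dict. This is an illustrative toy (the `Signal` class is hypothetical; phiml's real BoundDim does considerably more):

```python
from dataclasses import dataclass

class BoundDim:
    # helper returned by __getattr__; forwards slicing back to the owner
    def __init__(self, obj, dim):
        self.obj, self.dim = obj, dim

    def __getitem__(self, item):
        return self.obj[{self.dim: item}]

@dataclass(frozen=True)
class Signal:
    values: tuple  # a single dim named 'time'

    def __getitem__(self, item: dict):
        picked = self.values[item['time']]
        # keep the dim even for single-index selection, mirroring keepdims
        return Signal(picked if isinstance(picked, tuple) else (picked,))

    def __getattr__(self, name):  # only called when normal attribute lookup fails
        if name == 'time':
            return BoundDim(self, name)
        raise AttributeError(name)

s = Signal((10, 20, 30))
assert s.time[1].values == (20,)
assert s.time[0:2].values == (10, 20)
```

Note how `__getattr__` is only invoked for names not found through normal lookup, so the `values` field itself is unaffected; the real decorator additionally guards against dunder names to avoid infinite recursion during pickling and copying.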
def special_fields(obj) ‑> Sequence[dataclasses.Field]

    def special_fields(obj) -> Sequence[dataclasses.Field]:
        """
        List all special dataclass Fields of `obj`, i.e. fields that don't store data related to the object
        but rather meta-information relevant to PhiML. These include `variable_attrs` and `value_attrs`.

        Args:
            obj: Dataclass type or instance.

        Returns:
            Sequence of `dataclasses.Field`.
        """
        return [f for f in dataclasses.fields(obj) if f.name in ('variable_attrs', 'value_attrs')]

List all special dataclass Fields of `obj`, i.e. fields that don't store data related to the object but rather meta-information relevant to PhiML. These include `variable_attrs` and `value_attrs`.

Args
    obj: Dataclass type or instance.

Returns
    Sequence of `dataclasses.Field`.
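Together, data_fields(), config_fields() and special_fields() partition a dataclass's fields. A stdlib-only sketch of that three-way split, using a crude stand-in for phiml's is_data_field (which really inspects field types recursively; the annotation check below is a deliberate simplification):

```python
from dataclasses import dataclass, fields
from typing import Tuple

SPECIAL = ('variable_attrs', 'value_attrs')

def is_data_field_sketch(f):
    # crude stand-in: treat fields annotated with 'Tensor' as data fields
    return 'Tensor' in str(f.type)

@dataclass(frozen=True)
class MyClass:
    attribute1: 'Tensor'                               # data field (can hold tensors)
    field1: str = 'x'                                  # config field (never holds tensors)
    variable_attrs: Tuple[str, ...] = ('attribute1',)  # special field (meta-information)

names = {
    'data': [f.name for f in fields(MyClass) if is_data_field_sketch(f)],
    'config': [f.name for f in fields(MyClass)
               if not is_data_field_sketch(f) and f.name not in SPECIAL],
    'special': [f.name for f in fields(MyClass) if f.name in SPECIAL],
}
assert names == {'data': ['attribute1'], 'config': ['field1'], 'special': ['variable_attrs']}
```

Every field falls into exactly one of the three groups, which is why config_fields() is defined as "not data and not special".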
Classes

class cached_property (func)

    def __get__(self, instance, owner=None):
        if instance is None:
            return self
        if self.attrname is None:
            raise TypeError(
                "Cannot use cached_property instance without calling __set_name__ on it.")
        try:
            cache = instance.__dict__
        except AttributeError:  # not all objects have __dict__ (e.g. class defines slots)
            msg = (
                f"No '__dict__' attribute on {type(instance).__name__!r} "
                f"instance to cache {self.attrname!r} property."
            )
            raise TypeError(msg) from None
        val = cache.get(self.attrname, _NOT_FOUND)
        if val is _NOT_FOUND:
            with self.lock:
                # check if another thread filled cache while we awaited lock
                val = cache.get(self.attrname, _NOT_FOUND)
                if val is _NOT_FOUND:
                    val = self.func(instance)
                    try:
                        cache[self.attrname] = val
                    except TypeError:
                        msg = (
                            f"The '__dict__' attribute on {type(instance).__name__!r} instance "
                            f"does not support item assignment for caching {self.attrname!r} property."
                        )
                        raise TypeError(msg) from None
        return val

Subclasses

- phiml.parallel._parallel.ParallelProperty