Module phiml.dataclasses
PhiML makes it easy to work with custom classes.
Any class decorated with @dataclass can be used with phiml.math functions, such as shape(), slice(), stack(), concat(), expand() and many more.
We recommend always setting frozen=True.
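To picture what it means for a phiml.math function to work on a dataclass, here is a stdlib-only sketch of a field-wise stack: the operation is applied to each data field and the results are packed into a new instance of the same class. The helper stack_fields and the class Particle are hypothetical, and plain Python lists stand in for PhiML tensors. PhiML's actual stack() operates on Tensor fields along named dimensions.

```python
from dataclasses import dataclass, replace
from typing import Any, List, Sequence


@dataclass(frozen=True)
class Particle:
    # Hypothetical example class; in PhiML these fields would typically hold Tensors.
    position: Any
    mass: Any
    label: str = 'p'  # non-data field, taken over from the first instance


def stack_fields(objs: List[Particle], data_names: Sequence[str] = ('position', 'mass')) -> Particle:
    """Hypothetical sketch: collect each listed data field of several instances into a list."""
    stacked = {name: [getattr(o, name) for o in objs] for name in data_names}
    # replace() builds a new frozen instance; fields not in `stacked` are copied from objs[0].
    return replace(objs[0], **stacked)


a = Particle(position=0.0, mass=1.0)
b = Particle(position=2.5, mass=3.0)
s = stack_fields([a, b])
print(s)  # Particle(position=[0.0, 2.5], mass=[1.0, 3.0], label='p')
```

Because the class is frozen, the inputs a and b are never modified; every operation returns a new object.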
PhiML's dataclass support explicitly handles properties decorated with functools.cached_property.
Their cached values are preserved whenever possible, preventing costly re-computation when modifying unrelated properties, slicing, gathering, or stacking objects.
This will usually affect all data fields, i.e. fields that hold Tensors or composite properties.
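The following stdlib-only sketch shows why preserving caches matters. functools.cached_property stores its value in the instance __dict__ (which works even with frozen=True), but the standard dataclasses.replace() creates a fresh instance with an empty cache. The helper replace_keeping_cache and the explicit DEPENDS_ON mapping are hypothetical; according to the text above, PhiML determines the dependencies by analyzing the property code rather than requiring such a mapping.

```python
import dataclasses
from dataclasses import dataclass
from functools import cached_property

COMPUTE_LOG = []  # records every actual computation of Box.volume


@dataclass(frozen=True)
class Box:
    size: float
    name: str = 'box'

    # Hypothetical: manually declared dependencies (no annotation, so not a dataclass field).
    DEPENDS_ON = {'volume': ('size',)}

    @cached_property
    def volume(self) -> float:
        COMPUTE_LOG.append(self.size)
        return self.size ** 3


def replace_keeping_cache(obj, **changes):
    """Hypothetical sketch: like dataclasses.replace(), but carries over cached values
    whose declared dependencies are not among the changed fields."""
    new = dataclasses.replace(obj, **changes)
    for prop, deps in type(obj).DEPENDS_ON.items():
        if prop in obj.__dict__ and set(deps).isdisjoint(changes):
            # frozen=True only blocks __setattr__; cached_property also writes to __dict__ directly.
            new.__dict__[prop] = obj.__dict__[prop]
    return new


box = Box(2.0)
box.volume                                          # computed
renamed = replace_keeping_cache(box, name='other')
renamed.volume                                      # reused from the cache
resized = replace_keeping_cache(box, size=3.0)
resized.volume                                      # recomputed, 'size' changed
print(COMPUTE_LOG)  # [2.0, 3.0]
```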
Dataclass fields can additionally be declared as variable or value attributes (see the special fields variable_attrs and value_attrs in the template below).
This affects which data fields are optimized / traced by functions like jit_compile() or minimize().
Template for custom classes:
>>> from typing import Tuple
>>> from dataclasses import dataclass
>>> from phiml.dataclasses import sliceable, cached_property
>>> from phiml.math import Tensor, Shape, shape
>>>
>>> @sliceable
>>> @dataclass(frozen=True)
>>> class MyClass:
>>>     # --- Attributes ---
>>>     attribute1: Tensor
>>>     attribute2: 'MyClass' = None
>>>
>>>     # --- Additional fields ---
>>>     field1: str = 'x'
>>>
>>>     # --- Special fields declaring attribute types. Must be of type Tuple[str, ...] ---
>>>     variable_attrs: Tuple[str, ...] = ('attribute1', 'attribute2')
>>>     value_attrs: Tuple[str, ...] = ()
>>>
>>>     def __post_init__(self):
>>>         assert self.field1 in ('x', 'y', 'z')  # tuple membership, so '' or 'xy' are rejected
>>>
>>>     @cached_property
>>>     def shape(self) -> Shape:  # override the default shape which is merged from all attribute shapes
>>>         return self.attribute1.shape & shape(self.attribute2)
>>>
>>>     @cached_property  # the cache will be copied to derived instances unless attribute1 changes (this is analyzed from the code)
>>>     def derived_property(self) -> Tensor:
>>>         return self.attribute1 + 1
Functions
def config_fields(obj) -> Sequence[dataclasses.Field]
List all dataclass fields of obj that are not considered data_fields or special. These cannot hold any Tensors or shaped objects.
Args
obj - Dataclass type or instance.
Returns
Sequence of dataclasses.Field
def data_eq(cls=None, /, *, rel_tolerance=0.0, abs_tolerance=0.0, equal_nan=True, compare_tensors_by_ref=False)
def data_fields(obj) -> Sequence[dataclasses.Field]
List all dataclass fields of obj that are considered data, i.e. that can hold (directly or indirectly) one or multiple Tensor instances. This includes fields referencing other dataclasses.
Args
obj - Dataclass type or instance.
Returns
Sequence of dataclasses.Field
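The split between data fields, special fields and config fields can be approximated with the standard dataclasses and typing modules. This is a simplified sketch, not PhiML's implementation: it assumes a field counts as data when its annotation is a (hypothetical) Tensor stand-in class or another dataclass, possibly wrapped in Optional, and it treats only variable_attrs and value_attrs as special. Going by the definitions on this page, non_data_fields() would then cover both the special and the config fields.

```python
import dataclasses
from dataclasses import dataclass
from typing import Optional, Tuple, get_type_hints


class Tensor:
    """Hypothetical stand-in for phiml.math.Tensor."""


SPECIAL_NAMES = ('variable_attrs', 'value_attrs')


def _is_data_type(tp) -> bool:
    # Assumption: Tensors and nested dataclasses can hold tensors; Optional[...] / Union[...] are unwrapped.
    args = getattr(tp, '__args__', None)
    if args:
        return any(_is_data_type(a) for a in args)
    return tp is Tensor or (isinstance(tp, type) and dataclasses.is_dataclass(tp))


def classify_fields(cls) -> dict:
    """Simplified sketch: sort the fields of a dataclass into data / special / config."""
    hints = get_type_hints(cls)
    result = {'data': [], 'special': [], 'config': []}
    for f in dataclasses.fields(cls):
        if f.name in SPECIAL_NAMES:
            result['special'].append(f.name)
        elif _is_data_type(hints[f.name]):
            result['data'].append(f.name)
        else:
            result['config'].append(f.name)
    return result


@dataclass(frozen=True)
class Inner:
    values: Tensor


@dataclass(frozen=True)
class Outer:
    attribute1: Tensor
    child: Optional[Inner] = None  # references another dataclass, so it counts as data
    field1: str = 'x'
    variable_attrs: Tuple[str, ...] = ('attribute1',)


print(classify_fields(Outer))
# {'data': ['attribute1', 'child'], 'special': ['variable_attrs'], 'config': ['field1']}
```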
def equal(obj1, obj2, rel_tolerance=0.0, abs_tolerance=0.0, equal_nan=True)
Checks if two objects are equal, up to the given relative and absolute tolerances.
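For intuition about the tolerance arguments of equal(), here is a stdlib-only sketch that compares two dataclass instances field by field using math.isclose(). Whether PhiML combines rel_tolerance and abs_tolerance in the same way as math.isclose is an assumption; the helper approx_equal is hypothetical and only handles scalar float fields.

```python
import dataclasses
import math
from dataclasses import dataclass


def approx_equal(obj1, obj2, rel_tolerance=0.0, abs_tolerance=0.0, equal_nan=True) -> bool:
    """Hypothetical sketch of tolerance-based dataclass equality (scalar float fields only)."""
    if type(obj1) is not type(obj2):
        return False
    for f in dataclasses.fields(obj1):
        a, b = getattr(obj1, f.name), getattr(obj2, f.name)
        if isinstance(a, float) and isinstance(b, float):
            if math.isnan(a) and math.isnan(b):
                if not equal_nan:  # NaN == NaN only if requested
                    return False
            elif not math.isclose(a, b, rel_tol=rel_tolerance, abs_tol=abs_tolerance):
                return False
        elif a != b:  # non-float fields must match exactly
            return False
    return True


@dataclass(frozen=True)
class Point:
    x: float
    y: float


print(approx_equal(Point(1.0, 2.0), Point(1.0, 2.0 + 1e-9)))                      # False
print(approx_equal(Point(1.0, 2.0), Point(1.0, 2.0 + 1e-9), abs_tolerance=1e-6))  # True
print(approx_equal(Point(float('nan'), 0.0), Point(float('nan'), 0.0)))            # True
```

With both tolerances at their default of 0.0, only exactly equal values compare as equal, which matches the defaults in the signature above.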
def getitem(obj: PhiMLDataclass, item, keepdims: Union[str, Sequence, set, 'Shape', Callable, None] = None) -> PhiMLDataclass
Slice / gather a dataclass by broadcasting the operation to its data_fields.
You may call this from __getitem__ to allow the syntax my_class[component_str], my_class[slicing_dict], my_class[boolean_tensor] and my_class[index_tensor]:
    def __getitem__(self, item):
        return getitem(self, item)
Args
obj - Dataclass instance to slice / gather.
item - One of the supported tensor slicing / gathering values.
keepdims - Dimensions that will not be removed during slicing. When selecting a single slice, these dims will remain with size 1.
Returns
Slice of obj of the same type.
def non_data_fields(obj) -> Sequence[dataclasses.Field]
List all dataclass fields of obj that cannot hold tensors (directly or indirectly).
Args
obj - Dataclass type or instance.
Returns
Sequence of dataclasses.Field
def replace(obj: PhiMLDataclass, /, call_metaclass=False, **changes) -> PhiMLDataclass
Create a copy of obj with some fields replaced. Unlike dataclasses.replace(), this function also transfers @cached_property members if their dependencies are not affected.
Args
obj - Dataclass instance.
call_metaclass - Whether to copy obj by invoking type(obj).__call__. If obj defines a metaclass, this allows users to define custom constructors for dataclasses.
**changes - New field values to replace old ones.
Returns
Copy of obj with replaced values.
def sliceable(cls=None, /, *, dim_attrs=True, keepdims=None, dim_repr=True)
Decorator for frozen dataclasses, adding slicing functionality by defining __getitem__. This enables slicing similar to tensors, as well as gathering and boolean masking.
Args
dim_attrs - Whether to generate __getattr__ that allows slicing via the syntax instance.dim[…] where dim is the name of any dim present on instance.
keepdims - Which dimensions should be kept with size 1 when taking a single slice along them. This will preserve item names.
dim_repr - Whether to replace the default repr of a dataclass by a simplified one based on the object's shape.
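The broadcasting idea behind getitem() and @sliceable can be sketched with the standard library, using plain lists in place of tensors and positional indices in place of PhiML's named-dimension slicing dicts. getitem_sketch, sliceable_sketch and the DATA_NAMES class attribute are hypothetical; the real decorator works on the data_fields itself and additionally supports gathering, boolean masks and the instance.dim[…] syntax described above.

```python
from dataclasses import dataclass, replace
from typing import List, Tuple


def getitem_sketch(obj, item, data_names: Tuple[str, ...]):
    """Hypothetical sketch: apply the same index / slice to every listed data field."""
    return replace(obj, **{name: getattr(obj, name)[item] for name in data_names})


def sliceable_sketch(cls):
    """Hypothetical decorator: adds a __getitem__ that delegates to getitem_sketch()."""
    def __getitem__(self, item):
        return getitem_sketch(self, item, cls.DATA_NAMES)
    cls.__getitem__ = __getitem__
    return cls


@sliceable_sketch
@dataclass(frozen=True)
class Cloud:
    x: List[float]
    y: List[float]
    name: str = 'cloud'  # not sliced, copied unchanged
    DATA_NAMES = ('x', 'y')  # class attribute (no annotation), so not a dataclass field


c = Cloud(x=[0.0, 1.0, 2.0], y=[5.0, 6.0, 7.0])
print(c[1:])  # Cloud(x=[1.0, 2.0], y=[6.0, 7.0], name='cloud')
print(c[0])   # Cloud(x=0.0, y=5.0, name='cloud')
```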
def special_fields(obj) -> Sequence[dataclasses.Field]
List all special dataclass fields of obj, i.e. fields that don't store data related to the object but rather meta-information relevant to PhiML. These include variable_attrs and value_attrs.
Args
obj - Dataclass type or instance.
Returns
Sequence of dataclasses.Field
Classes
class cached_property (func)
    # __get__ of cached_property: _NOT_FOUND is a module-level sentinel object, self.lock a lock created in __init__
    def __get__(self, instance, owner=None):
        if instance is None:
            return self
        if self.attrname is None:
            raise TypeError(
                "Cannot use cached_property instance without calling __set_name__ on it.")
        try:
            cache = instance.__dict__
        except AttributeError:  # not all objects have __dict__ (e.g. class defines slots)
            msg = (
                f"No '__dict__' attribute on {type(instance).__name__!r} "
                f"instance to cache {self.attrname!r} property."
            )
            raise TypeError(msg) from None
        val = cache.get(self.attrname, _NOT_FOUND)
        if val is _NOT_FOUND:
            with self.lock:
                # check if another thread filled cache while we awaited lock
                val = cache.get(self.attrname, _NOT_FOUND)
                if val is _NOT_FOUND:
                    val = self.func(instance)
                    try:
                        cache[self.attrname] = val
                    except TypeError:
                        msg = (
                            f"The '__dict__' attribute on {type(instance).__name__!r} instance "
                            f"does not support item assignment for caching {self.attrname!r} property."
                        )
                        raise TypeError(msg) from None
        return val
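As the first error branch in the __get__ source above shows, cached_property needs a writable per-instance __dict__. The short check below uses the standard functools.cached_property, whose __get__ follows the same logic, on a plain class with __slots__ to trigger that TypeError; the class Slotted is purely illustrative. A dataclass declared with slots=True (Python 3.10+) has no instance __dict__ either and would fail in the same way.

```python
from functools import cached_property


class Slotted:
    __slots__ = ('size',)  # no per-instance __dict__

    def __init__(self, size: float):
        self.size = size

    @cached_property
    def area(self) -> float:
        return self.size ** 2


message = None
try:
    Slotted(3.0).area
except TypeError as err:
    message = str(err)
print(message)  # No '__dict__' attribute on 'Slotted' instance to cache 'area' property.
```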