Module phiml.dataclasses
PhiML makes it easy to work with custom classes. Any class decorated with @dataclass can be used with phiml.math functions such as shape(), slice(), stack(), concat(), expand() and many more. We recommend always setting frozen=True.
PhiML's dataclass support explicitly handles properties decorated with functools.cached_property. Their cached values are preserved whenever possible, preventing costly re-computation when modifying unrelated properties, slicing, gathering, or stacking objects. This usually affects all data fields, i.e. fields that hold Tensor instances or composite properties.
Dataclass fields can additionally be specified as being variable and value. This affects which data_fields are optimized / traced by functions like jit_compile() or minimize().
Template for custom classes:
>>> from typing import Tuple
>>> from dataclasses import dataclass
>>> from phiml.dataclasses import sliceable, cached_property
>>> from phiml.math import Tensor, Shape, shape
>>>
>>> @sliceable
... @dataclass(frozen=True)
... class MyClass:
...     # --- Attributes ---
...     attribute1: Tensor
...     attribute2: 'MyClass' = None
...
...     # --- Additional fields ---
...     field1: str = 'x'
...
...     # --- Special fields declaring attribute types. Must be of type Tuple[str, ...] ---
...     variable_attrs: Tuple[str, ...] = ('attribute1', 'attribute2')
...     value_attrs: Tuple[str, ...] = ()
...
...     def __post_init__(self):
...         assert self.field1 in 'xyz'
...
...     @cached_property
...     def shape(self) -> Shape:  # override the default shape, which is merged from all attribute shapes
...         return self.attribute1.shape & shape(self.attribute2)
...
...     @cached_property  # the cache will be copied to derived instances unless attribute1 changes (this is analyzed from the code)
...     def derived_property(self) -> Tensor:
...         return self.attribute1 + 1
Functions
def config_fields(obj) ‑> Sequence[dataclasses.Field]
-
List all dataclass Fields of obj that are not considered data fields or special. These fields cannot hold any Tensors or shaped objects.
Args
obj
- Dataclass type or instance.
Returns
Sequence of dataclasses.Field.
def copy(obj: ~PhiMLDataclass, /, call_metaclass=False) ‑> ~PhiMLDataclass
-
Create a copy of obj, including cached properties.
Args
obj
- Dataclass instance.
call_metaclass
- Whether to copy obj by invoking type(obj).__call__. If obj defines a metaclass, this allows users to define custom constructors for dataclasses.
def data_eq(cls=None, /, *, rel_tolerance=0.0, abs_tolerance=0.0, equal_nan=True, compare_tensors_by_ref=False)
def data_fields(obj) ‑> Sequence[dataclasses.Field]
-
List all dataclass Fields of obj that are considered data, i.e. that can hold (directly or indirectly) one or multiple Tensor instances. This includes fields referencing other dataclasses.
Args
obj
- Dataclass type or instance.
Returns
Sequence of dataclasses.Field.
def equal(obj1, obj2, rel_tolerance=0.0, abs_tolerance=0.0, equal_nan=True)
-
Checks if two objects are equal, comparing tensor contents within the given tolerances.
Args
obj1
- First object to compare.
obj2
- Second object to compare.
rel_tolerance
- Relative tolerance for comparing tensor values.
abs_tolerance
- Absolute tolerance for comparing tensor values.
equal_nan
- Whether matching NaN values are considered equal.
Returns
Whether obj1 and obj2 are equal.
def get_cache_files(obj) ‑> Set[str]
-
Searches the data structure for all disk-cached tensors and returns all referenced files.
Args
obj
- Tensor, pytree or dataclass (PhiTreeNode).
Returns
Collection of file paths.
def getitem(obj: ~PhiMLDataclass, item, keepdims: Union[str, Sequence[+T_co], set, phiml.math._shape.Shape, Callable, None] = None) ‑> ~PhiMLDataclass
-
Slice / gather a dataclass by broadcasting the operation to its data_fields.
You may call this from __getitem__ to allow the syntax my_class[component_str], my_class[slicing_dict], my_class[boolean_tensor] and my_class[index_tensor]:
def __getitem__(self, item):
    return getitem(self, item)
Args
obj
- Dataclass instance to slice / gather.
item
- One of the supported tensor slicing / gathering values.
keepdims
- Dimensions that will not be removed during slicing. When selecting a single slice, these dims will remain with size 1.
Returns
Slice of obj of the same type.
def non_data_fields(obj) ‑> Sequence[dataclasses.Field]
-
List all dataclass Fields of obj that cannot hold tensors (directly or indirectly).
Args
obj
- Dataclass type or instance.
Returns
Sequence of dataclasses.Field.
def on_load_into_memory(callback: Callable[[phiml.math._tensors.Tensor], None])
-
Register a function to be called every time a cached tensor is loaded from disk into memory in this process.
Tensors are stored on disk to save memory, e.g. in parallel_compute() if a memory limit is specified. When accessing the tensor values or using them in any computation, the values are temporarily loaded into memory and callback is called.
Args
callback
- Function callback(Tensor).
def parallel_compute(instance, properties: Sequence[+T_co], parallel_dims=<function batch>, max_workers=4, memory_limit: Optional[float] = None, cache_dir: str = None, keep_intermediate=False)
-
Compute the values of properties decorated with @cached_property or @parallel_property of a dataclass instance in parallel.
Multiple stages via requires=...
If properties decorated with @parallel_property are computed whose requires overlaps with parallel_dims, a separate computation stage is set up to compute these properties with fewer parallel workers. In the presence of different requires, the computation is split into multiple stages in accordance with the property dependency graph. Properties that cannot be parallelized (because they require all parallel_dims) are computed on the host process.
Caching tensors to disk
When memory_limit and cache_dir are set, the evaluation will try to adhere to the given memory limit by moving tensors out of memory onto disk. This is only applied to the outputs of @cached_property and @parallel_property calls, not to intermediate values used in their computation. The per-process memory limit is calculated per stage, dividing the total memory by the active worker count. Cached tensors behave like regular tensors and are temporarily loaded back into memory when accessing their values or using them in a computation. When parallelizing, the full result is assembled by stacking multiple disk-backed tensors from different files created by different processes. These composite tensors will reference multiple binary files and can be pickled/unpickled safely without loading the data into memory. This enables passing large data references to different processes or saving the structure to a file without the data content.
See Also: cached_property, parallel_property(), get_cache_files(), on_load_into_memory().
Warnings
parallel_compute() breaks automatic differentiation.
Args
instance
- Dataclass instance for which to compute the values of @cached_property or @parallel_property fields.
properties
- References to the unbound properties. These must be cached_property or parallel_property().
parallel_dims
- Dimensions to parallelize over.
max_workers
- Number of processes to spawn.
memory_limit
- Limit on the total memory consumption from Tensor instances on property outputs.
cache_dir
- Directory path to store cached tensors in if memory_limit is set.
keep_intermediate
- Whether the outputs of cached properties required to compute properties but not contained in properties should be kept in memory. If False, these values will not be cached on instance after this call.
def parallel_property(func: Callable = None, /, requires: Union[str, Sequence[+T_co], set, phiml.math._shape.Shape, Callable, None] = None, on_direct_eval='raise')
-
Similar to @cached_property but with additional controls over parallelization.
See Also: parallel_compute().
Args
func
- Method to wrap.
requires
- Dimensions which must be present within one process. These cannot be parallelized when computing this property.
on_direct_eval
- What to do when the property is accessed normally (outside parallel_compute()) before it has been computed. Options:
'raise': Raise an error.
'host-compute': Compute the property directly, without using multi-threading.
def replace(obj: ~PhiMLDataclass, /, call_metaclass=False, **changes) ‑> ~PhiMLDataclass
-
Create a copy of obj with some fields replaced. Unlike dataclasses.replace(), this function also transfers @cached_property members if their dependencies are not affected.
Args
obj
- Dataclass instance.
call_metaclass
- Whether to copy obj by invoking type(obj).__call__. If obj defines a metaclass, this allows users to define custom constructors for dataclasses.
**changes
- New field values to replace old ones.
Returns
Copy of obj with replaced values.
def set_cache_ttl(ttl_seconds: Optional[float])
-
Sets the time to live (TTL) for data loaded into memory for disk-backed tensors. This function should be called before the tensor cache is used.
Args
ttl_seconds
- Time to live in seconds. If None, data will be deallocated immediately after use.
def sliceable(cls=None, /, *, dim_attrs=True, t_props=True, keepdims=None, dim_repr=True, lazy_dims=True)
-
Decorator for frozen dataclasses, adding slicing functionality by defining __getitem__ and enabling the instance.dim syntax. This enables slicing similar to tensors, as well as gathering and boolean masking.
Args
dim_attrs
- Whether to generate __getattr__ so that slicing is possible via the syntax instance.dim[...], where dim is the name of any dim present on instance.
t_props
- Whether to generate the properties Tc, Ts and Ti for transposing channel/spatial/instance dims.
keepdims
- Which dimensions should be kept with size 1 when taking a single slice along them. This will preserve labels.
dim_repr
- Whether to replace the default repr of a dataclass by a simplified one based on the object's shape.
lazy_dims
- If False, instantiates all dims of shape(self) as member variables during construction. The dataclass must have slots=False. If True, implements __getattr__ to instantiate accessed dims on demand. This will be skipped if a user-defined __getattr__ is found.
def special_fields(obj) ‑> Sequence[dataclasses.Field]
-
List all special dataclass Fields of obj, i.e. fields that don't store data related to the object but rather meta-information relevant to PhiML. These include variable_attrs and value_attrs.
Args
obj
- Dataclass type or instance.
Returns
Sequence of dataclasses.Field.
Classes
class cached_property (func)
-
Source code:
def __get__(self, instance, owner=None):
    if instance is None:
        return self
    if self.attrname is None:
        raise TypeError(
            "Cannot use cached_property instance without calling __set_name__ on it.")
    try:
        cache = instance.__dict__
    except AttributeError:  # not all objects have __dict__ (e.g. class defines slots)
        msg = (
            f"No '__dict__' attribute on {type(instance).__name__!r} "
            f"instance to cache {self.attrname!r} property."
        )
        raise TypeError(msg) from None
    val = cache.get(self.attrname, _NOT_FOUND)
    if val is _NOT_FOUND:
        with self.lock:
            # check if another thread filled cache while we awaited lock
            val = cache.get(self.attrname, _NOT_FOUND)
            if val is _NOT_FOUND:
                val = self.func(instance)
                try:
                    cache[self.attrname] = val
                except TypeError:
                    msg = (
                        f"The '__dict__' attribute on {type(instance).__name__!r} instance "
                        f"does not support item assignment for caching {self.attrname!r} property."
                    )
                    raise TypeError(msg) from None
    return val
Subclasses
- phiml.dataclasses._parallel.ParallelProperty