4.3. Weight Model
The weight_model module contains classes for calculating the direction weight of the final gradient vector.
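To make "direction weight" concrete, the sketch below shows one way a weight vector can fix the direction of a combined gradient. It assumes a construction in the spirit of ConFIG: the direction v is the minimum-norm solution of U v = w, where the rows of U are the unit-normalized gradients. When the gradients are linearly independent, each gradient's cosine similarity with v is then proportional to its weight. The helper `weighted_direction` is hypothetical and is not part of conflictfree; this illustrates the idea, not the library's implementation.

```python
import torch


def weighted_direction(gradients: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: unit vector whose cosine with each gradient is proportional to its weight.

    gradients: (m, N) matrix, one loss-specific gradient per row.
    weights:   (m,) direction weights, e.g. the output of a weight model.
    """
    # Normalize each row so that only the direction of each gradient matters.
    unit = gradients / gradients.norm(dim=1, keepdim=True)
    # Minimum-norm solution of unit @ v = weights, via the Moore-Penrose pseudo-inverse.
    v = torch.linalg.pinv(unit) @ weights
    # Keep only the direction; in this sketch the update magnitude is chosen separately.
    return v / v.norm()


# Two conflicting gradients in 3-D (their dot product is negative), equal weights.
g = torch.tensor([[1.0, 0.0, 0.0], [-0.5, 1.0, 0.0]])
d = weighted_direction(g, torch.ones(2))
cos = (g @ d) / g.norm(dim=1)
print(cos)  # the two cosine similarities are equal and positive
```

With equal weights, the resulting direction makes the same angle with every gradient, so no single loss is favoured even though the two raw gradients conflict.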
The EqualWeight class is the default weight model for the ConFIG algorithm. You can create a custom weight model by inheriting from the WeightModel class.
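A minimal sketch of a custom weight model, assuming only the get_weights(gradients, losses=None, device=None) signature documented below. The class LossProportionalWeight and its rule (weighting each gradient by its share of the total loss) are hypothetical illustrations. If conflictfree is not installed, a stand-in WeightModel with the documented method signature is defined so the example still runs; that stand-in is an assumption, not the library's class.

```python
from typing import Optional, Sequence, Union

import torch

try:
    # Use the real base class when conflictfree is installed.
    from conflictfree.weight_model import WeightModel
except ImportError:
    # Stand-in mirroring the documented base-class interface (assumption, not the library's class).
    class WeightModel:
        def get_weights(self, gradients=None, losses=None):
            raise NotImplementedError


class LossProportionalWeight(WeightModel):
    """Hypothetical weight model: weight_i = loss_i / sum(losses)."""

    def get_weights(
        self,
        gradients: torch.Tensor,
        losses: Optional[Sequence] = None,
        device: Optional[Union[torch.device, str]] = None,
    ) -> torch.Tensor:
        if losses is None:
            raise ValueError("LossProportionalWeight requires the losses.")
        dtype = gradients.dtype if gradients is not None else torch.float32
        # float() turns scalar tensors into plain numbers, dropping autograd history.
        values = torch.tensor([float(loss) for loss in losses], dtype=dtype, device=device)
        if gradients is not None and values.numel() != gradients.shape[0]:
            raise ValueError("Expected one loss per gradient row.")
        return values / values.sum()


model = LossProportionalWeight()
w = model.get_weights(torch.randn(3, 10), losses=[1.0, 2.0, 1.0])
print(w)  # tensor([0.2500, 0.5000, 0.2500])
```

Whether loss-proportional weights suit a particular ConFIG setup is a modelling choice; the sketch only demonstrates the inheritance pattern and the expected return shape (one weight per gradient row).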
Weight Model
conflictfree.weight_model.EqualWeight
Bases: WeightModel
A weight model that assigns equal weights to all gradients.
Source code in conflictfree/weight_model.py
__init__
__init__()
get_weights
get_weights(
gradients: torch.Tensor,
losses: Optional[Sequence] = None,
device: Optional[Union[torch.device, str]] = None,
) -> torch.Tensor
Calculate the weights for the given gradients.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
gradients | Tensor | The loss-specific gradient matrix. The shape of this tensor should be (m, N), where m is the number of gradients and N is the number of elements of each gradient. | required |
losses | Optional[Sequence] | The losses. Not used in this model. | None |
device | Optional[Union[device, str]] | The device on which the returned weights are created. | None |
Returns:
Type | Description |
---|---|
Tensor | A tensor of equal weights for all gradients. |
Raises:
Type | Description |
---|---|
ValueError | If gradients is None. |
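The documented behaviour of EqualWeight.get_weights can be sketched as a standalone function: one identical weight per row of the (m, N) gradient matrix, losses ignored, and a ValueError when gradients is None. The function name equal_weights is hypothetical, and the choice of 1.0 as the common weight (rather than, say, 1/m) is an assumption; this is not the library's source.

```python
from typing import Optional, Sequence, Union

import torch


def equal_weights(
    gradients: Optional[torch.Tensor],
    losses: Optional[Sequence] = None,  # accepted for interface parity, ignored
    device: Optional[Union[torch.device, str]] = None,
) -> torch.Tensor:
    """Hypothetical stand-in for EqualWeight.get_weights (weight value 1.0 assumed)."""
    if gradients is None:
        raise ValueError("gradients must be provided")
    # One weight per loss-specific gradient, i.e. per row of the (m, N) matrix.
    return torch.ones(gradients.shape[0], dtype=gradients.dtype, device=device)


w = equal_weights(torch.randn(4, 100))
print(w)  # tensor([1., 1., 1., 1.])
```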
Base Class of Weight Model
conflictfree.weight_model.WeightModel
Base class for weight models.
get_weights
get_weights(
gradients: Optional[torch.Tensor] = None,
losses: Optional[Sequence] = None,
)
Calculate the weights for the given gradients. Subclasses must override this method.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
gradients | Optional[Tensor] | The loss-specific gradient matrix. The shape of this tensor should be (m, N), where m is the number of gradients and N is the number of elements of each gradient. | None |
losses | Optional[Sequence] | The losses. | None |
Raises:
Type | Description |
---|---|
NotImplementedError | If the method is not overridden by a subclass. |