3.3. Weight Model
The weight_model module contains classes for calculating the direction weight of the final gradient vector. The EqualWeight class is the default weight model for the ConFIG algorithm. You can create a custom weight model by inheriting from the WeightModel class.
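The inheritance pattern described above can be sketched as follows. This is a minimal illustration, not the library's code. To stay dependency-free it defines a local stand-in for WeightModel and uses plain Python lists where the library uses torch.Tensor. The LossProportionalWeight class and its weighting rule are hypothetical, chosen only to show how get_weights is overridden.

```python
from typing import List, Optional, Sequence


class WeightModel:
    """Local stand-in for conflictfree.weight_model.WeightModel (assumption)."""

    def get_weights(
        self,
        gradients: Optional[Sequence[Sequence[float]]] = None,
        losses: Optional[Sequence[float]] = None,
    ) -> List[float]:
        # The base class defines the interface only.
        raise NotImplementedError("get_weights must be implemented by a subclass")


class LossProportionalWeight(WeightModel):
    """Hypothetical custom model: weight each gradient by its share of the total loss."""

    def get_weights(self, gradients=None, losses=None) -> List[float]:
        if gradients is None or losses is None:
            raise ValueError("Both gradients and losses are required.")
        if len(gradients) != len(losses):
            raise ValueError("Expected one loss per gradient.")
        total = float(sum(losses))
        if total == 0.0:
            # Avoid dividing by zero: fall back to uniform weights.
            return [1.0 / len(losses)] * len(losses)
        return [float(loss) / total for loss in losses]


# Two gradients (rows) with two elements each, and one loss per gradient.
gradients = [[1.0, 0.0], [0.0, 2.0]]
weights = LossProportionalWeight().get_weights(gradients, losses=[1.0, 3.0])
print(weights)  # [0.25, 0.75]
```

With the real library, the subclass would inherit from conflictfree.weight_model.WeightModel and return a torch.Tensor with one entry per gradient.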
Weight Model
conflictfree.weight_model.EqualWeight
Bases: WeightModel
A weight model that assigns equal weights to all gradients.
Source code in conflictfree/weight_model.py
__init__
__init__()
get_weights
get_weights(
gradients: torch.Tensor,
losses: Optional[Sequence] = None,
device: Optional[Union[torch.device, str]] = None,
) -> torch.Tensor
Calculate the weights for the given gradients.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
gradients | Tensor | The loss-specific gradients matrix. | required |
losses | Optional[Sequence] | The losses. Not used in this model. | None |
Returns:
Type | Description |
---|---|
Tensor | A tensor of equal weights for all gradients. |
Raises:
Type | Description |
---|---|
ValueError | If gradients is None. |
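The documented behavior of get_weights (one equal weight per gradient, and a ValueError when gradients is None) can be sketched as below. This is an illustration under stated assumptions, not the library's implementation: the function name equal_weights is hypothetical, plain lists stand in for torch.Tensor, and the weight value 1.0 is an assumption, since the reference only states that the weights are equal.

```python
from typing import List, Optional, Sequence


def equal_weights(gradients: Optional[Sequence[Sequence[float]]]) -> List[float]:
    """Return one identical weight per gradient row (assumed value: 1.0)."""
    if gradients is None:
        # Mirrors the documented ValueError for missing gradients.
        raise ValueError("The equal weight model requires gradients.")
    # The number of weights equals the number of gradients (rows).
    return [1.0] * len(gradients)


# Three loss-specific gradients, each with two elements.
three_gradients = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
print(equal_weights(three_gradients))  # [1.0, 1.0, 1.0]
```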
Base Class of Weight Model
conflictfree.weight_model.WeightModel
Base class for weight models.
get_weights
get_weights(
gradients: Optional[torch.Tensor] = None,
losses: Optional[Sequence] = None,
)
Calculate the weights for the given gradients. Subclasses must override this method.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
gradients | Optional[Tensor] | The loss-specific gradients matrix. | None |
losses | Optional[Sequence] | The losses. | None |
Raises:
Type | Description |
---|---|
NotImplementedError | If the method is not overridden by a subclass. |