
3.3. Weight Model

The weight_model module contains classes that compute the direction weights the ConFIG algorithm uses to determine the direction of the final gradient vector. EqualWeight is the default weight model for ConFIG. To create a custom weight model, subclass WeightModel and override its get_weights method.
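The following sketch makes "direction weight" concrete. As we understand the ConFIG method, the update direction g_u is the normalized solution of [U(g_1), ..., U(g_m)]^T g_u = w, where U(.) rescales a vector to unit length and w is the all-ones vector in the basic method; the final update is (sum_i g_i^T g_u) g_u. Here we assume a weight model's output plays the role of w. The helper weighted_direction is hypothetical and solves the system with a pseudo-inverse; it is a simplified illustration, not the library's implementation.

```python
import torch


def weighted_direction(gradients: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: combine gradients (one per row) using direction weights."""
    # Rescale each loss-specific gradient to unit length
    unit = gradients / gradients.norm(dim=1, keepdim=True)
    # Minimum-norm solution d of unit @ d = weights: the projection of d
    # onto each unit gradient equals that gradient's weight
    direction = torch.linalg.pinv(unit) @ weights
    direction = direction / direction.norm()
    # Length of the update: sum of the raw gradients' projections onto the direction
    return (gradients @ direction).sum() * direction


gradients = torch.tensor([[1.0, 0.0], [0.0, 2.0]])
update = weighted_direction(gradients, torch.ones(2))
print(update)  # tensor([1.5000, 1.5000])
```

With equal weights, the resulting direction makes the same angle with every gradient (45 degrees with both gradients here); larger weights tilt the direction toward the corresponding gradients.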

Weight Model

conflictfree.weight_model.EqualWeight

Bases: WeightModel

A weight model that assigns equal weights to all gradients.

Source code in conflictfree/weight_model.py
class EqualWeight(WeightModel):
    """
    A weight model that assigns equal weights to all gradients.
    """

    def __init__(self):
        super().__init__()

    def get_weights(
        self,
        gradients: torch.Tensor,
        losses: Optional[Sequence] = None,
        device: Optional[Union[torch.device, str]] = None,
    ) -> torch.Tensor:
        """
        Calculate the weights for the given gradients.

        Args:
            gradients (torch.Tensor): The loss-specific gradients matrix, with one gradient per row.
            losses (Optional[Sequence]): The losses. Not used in this model.
            device (Optional[Union[torch.device, str]]): The device on which the weights are created.

        Returns:
            torch.Tensor: A tensor of equal weights (all ones) for all gradients.

        Raises:
            ValueError: If gradients is None.
        """
        if gradients is None:
            raise ValueError("The EqualWeight model requires gradients.")
        # One weight of 1.0 per row, i.e. per loss-specific gradient
        return torch.ones(gradients.shape[0], device=device)
__init__
__init__()
get_weights
get_weights(
    gradients: torch.Tensor,
    losses: Optional[Sequence] = None,
    device: Optional[Union[torch.device, str]] = None,
) -> torch.Tensor

Calculate the weights for the given gradients.

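Parameters:

gradients (Tensor, required): The loss-specific gradients matrix, with one gradient per row.
losses (Optional[Sequence], default None): The losses. Not used in this model.
device (Optional[Union[device, str]], default None): The device on which the returned weights are created. If None, PyTorch's default device is used.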

Returns:

Tensor: A 1-D tensor of ones with length gradients.shape[0], i.e. one equal weight per gradient.

Raises:

ValueError: If gradients is None.
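A minimal usage sketch for EqualWeight, assuming the package is importable under the module path shown on this page (conflictfree.weight_model). If it is not installed, the snippet falls back to a stand-in that reproduces the get_weights logic listed on this page. The random gradient matrix is illustrative only.

```python
from typing import Optional, Sequence, Union

import torch

try:
    from conflictfree.weight_model import EqualWeight
except ImportError:
    # Stand-in reproducing the get_weights logic listed on this page,
    # used only when the conflictfree package is not installed
    class EqualWeight:
        def get_weights(
            self,
            gradients: torch.Tensor,
            losses: Optional[Sequence] = None,
            device: Optional[Union[torch.device, str]] = None,
        ) -> torch.Tensor:
            if gradients is None:
                raise ValueError("The EqualWeight model requires gradients.")
            return torch.ones(gradients.shape[0], device=device)


# Illustrative gradient matrix: 3 losses, 5 parameters, one gradient per row
gradients = torch.randn(3, 5)

weights = EqualWeight().get_weights(gradients, device="cpu")
print(weights)  # tensor([1., 1., 1.])
```

The losses argument is ignored, so the weights depend only on the number of rows in gradients.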


Base Class of Weight Model

conflictfree.weight_model.WeightModel

Base class for weight models.

Source code in conflictfree/weight_model.py
class WeightModel:
    """
    Base class for weight models.
    """

    def __init__(self):
        pass

    def get_weights(self, gradients: Optional[torch.Tensor] = None, losses: Optional[Sequence] = None):
        """Calculate the weights for the given gradients.

        Args:
            gradients (Optional[torch.Tensor]): The loss-specific gradients matrix.
            losses (Optional[Sequence]): The losses.

        Raises:
            NotImplementedError: If the subclass does not override this method.
        """
        raise NotImplementedError("This method must be implemented by the subclass.")
__init__
__init__()
get_weights
get_weights(
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
)

Calculate the weights for the given gradients. Subclasses must override this method.

Parameters:

gradients (Optional[Tensor], default None): The loss-specific gradients matrix.
losses (Optional[Sequence], default None): The losses.

Raises:

NotImplementedError: If the subclass does not override this method.
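A sketch of a custom weight model built by subclassing WeightModel. The class LossProportionalWeight and its rule (weights proportional to the current losses, rescaled to sum to the number of losses) are hypothetical and chosen only for illustration; the extra device parameter follows the EqualWeight signature. The snippet falls back to a stand-in base class if conflictfree is not installed.

```python
from typing import Optional, Sequence, Union

import torch

try:
    from conflictfree.weight_model import WeightModel
except ImportError:
    # Stand-in with the same interface as the base class documented above,
    # used only when the conflictfree package is not installed
    class WeightModel:
        def __init__(self):
            pass

        def get_weights(self, gradients=None, losses=None):
            raise NotImplementedError("This method must be implemented by the subclass.")


class LossProportionalWeight(WeightModel):
    """Hypothetical weight model: each weight is proportional to its loss."""

    def __init__(self):
        super().__init__()

    def get_weights(
        self,
        gradients: Optional[torch.Tensor] = None,
        losses: Optional[Sequence] = None,
        device: Optional[Union[torch.device, str]] = None,
    ) -> torch.Tensor:
        if losses is None:
            raise ValueError("LossProportionalWeight requires losses.")
        # float() accepts plain numbers and one-element tensors and
        # returns a Python float, so no autograd history is kept
        values = torch.tensor([float(loss) for loss in losses], device=device)
        # Rescale so the weights sum to the number of losses, the same
        # total as the all-ones vector returned by EqualWeight
        return values / values.sum() * len(values)


weights = LossProportionalWeight().get_weights(losses=[1.0, 3.0])
print(weights)  # tensor([0.5000, 1.5000])
```

Calling get_weights on the base class itself raises NotImplementedError, so every subclass must provide its own implementation.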
