
4.3. Weight Model

The weight_model module contains classes that calculate the weights which determine the direction of the final gradient vector. The EqualWeight class is the default weight model for the ConFIG algorithm. To create a custom weight model, inherit from the WeightModel class and implement its get_weights method.
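To make the role of these weights concrete, here is a minimal sketch of a ConFIG-style update following the formulation in the ConFIG paper: each gradient is normalized to unit length, the pseudo-inverse of the stacked unit gradients is applied to the weight vector to obtain the update direction, and that direction is scaled by the sum of the gradients' projections onto it. The function name `config_update_sketch` is hypothetical and this is not the conflictfree implementation; it only illustrates why equal weights give a direction with the same cosine similarity to every gradient (assuming non-zero, linearly independent gradients).

```python
import torch


def config_update_sketch(gradients: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    """Hypothetical ConFIG-style update for an (m, N) gradient matrix and m weights."""
    # Normalize each loss-specific gradient (one row per loss) to unit length.
    # Assumes no gradient is all zeros.
    unit_grads = gradients / torch.linalg.norm(gradients, dim=1, keepdim=True)
    # Solve unit_grads @ d ~= weights in the least-squares sense. When the system is
    # exactly solvable, the cosine similarity between d and gradient i is weights[i] / |d|,
    # i.e. proportional to weights[i].
    direction = torch.linalg.pinv(unit_grads) @ weights
    direction = direction / torch.linalg.norm(direction)
    # Scale the unit direction by the summed projections of all gradients onto it.
    length = (gradients @ direction).sum()
    return length * direction


# Two orthogonal gradients with norms 1 and 2.
g = torch.tensor([[1.0, 0.0], [0.0, 2.0]])
equal_update = config_update_sketch(g, torch.ones(2))
tilted_update = config_update_sketch(g, torch.tensor([2.0, 1.0]))
```

With equal weights the update is approximately [1.5, 1.5], which makes the same angle with both gradients; the weights [2.0, 1.0] tilt it toward the first gradient, giving approximately [1.6, 0.8].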

Weight Model

conflictfree.weight_model.EqualWeight

Bases: WeightModel

A weight model that assigns equal weights to all gradients.

Source code in conflictfree/weight_model.py
class EqualWeight(WeightModel):
    """
    A weight model that assigns equal weights to all gradients.
    """

    def __init__(self):
        super().__init__()

    def get_weights(
        self,
        gradients: torch.Tensor,
        losses: Optional[Sequence] = None,
        device: Optional[Union[torch.device, str]] = None,
    ) -> torch.Tensor:
        """
        Calculate the weights for the given gradients.

        Args:
            gradients (torch.Tensor): The loss-specific gradients matrix. The shape of this tensor should be (m,N) where m is the number of gradients and N is the number of elements of each gradient.
            losses (Optional[Sequence]): The losses. Not used in this model.
            device (Optional[Union[torch.device, str]]): The device on which the weights tensor is created. Defaults to None, i.e. PyTorch's default device.

        Returns:
            torch.Tensor: A tensor of ones with shape (m,), i.e. an equal weight for every gradient.

        Raises:
            ValueError: If gradients is None.
        """
        # Explicit check instead of assert, so the documented ValueError is raised
        # (assert statements are skipped when Python runs with -O).
        if gradients is None:
            raise ValueError("The EqualWeight model requires gradients")
        # One weight per gradient, i.e. per row of the (m, N) matrix.
        return torch.ones(gradients.shape[0], device=device)
__init__
__init__()
get_weights
get_weights(
    gradients: torch.Tensor,
    losses: Optional[Sequence] = None,
    device: Optional[Union[torch.device, str]] = None,
) -> torch.Tensor

Calculate the weights for the given gradients.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| gradients | Tensor | The loss-specific gradients matrix of shape (m, N), where m is the number of gradients and N is the number of elements of each gradient. | required |
| losses | Optional[Sequence] | The losses. Not used in this model. | None |
| device | Optional[Union[device, str]] | The device on which the weights tensor is created. | None |

Returns:

| Type | Description |
| --- | --- |
| Tensor | A tensor of ones with shape (m,), i.e. an equal weight for every gradient. |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If gradients is None. |
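A short usage sketch of the documented behavior: a gradient matrix of shape (m, N) yields m equal weights. It imports EqualWeight from conflictfree when available; otherwise it falls back to a minimal stand-in written here (an assumption that simply mirrors the documented contract) so the snippet runs on its own.

```python
import torch

try:
    # Use the library class if conflictfree is installed.
    from conflictfree.weight_model import EqualWeight
except ImportError:
    # Hypothetical stand-in that mirrors the documented behavior of EqualWeight.get_weights.
    class EqualWeight:
        def get_weights(self, gradients, losses=None, device=None):
            if gradients is None:
                raise ValueError("The EqualWeight model requires gradients")
            return torch.ones(gradients.shape[0], device=device)


# Three loss-specific gradients, each flattened to 5 elements: shape (m, N) = (3, 5).
grads = torch.randn(3, 5)
weights = EqualWeight().get_weights(grads, device="cpu")
print(weights)  # tensor([1., 1., 1.])
```

The losses argument is accepted for interface compatibility but ignored, so the weights depend only on the number of gradient rows.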


Base Class of Weight Model

conflictfree.weight_model.WeightModel

Base class for weight models. Subclasses must implement get_weights; the base implementation raises NotImplementedError.

Source code in conflictfree/weight_model.py
from typing import Optional, Sequence, Union

import torch


class WeightModel:
    """
    Base class for weight models.
    """

    def __init__(self):
        pass

    def get_weights(
        self,
        gradients: Optional[torch.Tensor] = None,
        losses: Optional[Sequence] = None,
    ):
        """Calculate the weights for the given gradients and/or losses.

        Args:
            gradients (Optional[torch.Tensor]): The loss-specific gradients matrix. The shape of this tensor should be (m,N) where m is the number of gradients and N is the number of elements of each gradient.
            losses (Optional[Sequence]): The losses.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError("This method must be implemented by the subclass.")
__init__
__init__()
get_weights
get_weights(
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
)

Calculate the weights for the given gradients and/or losses.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| gradients | Optional[Tensor] | The loss-specific gradients matrix of shape (m, N), where m is the number of gradients and N is the number of elements of each gradient. | None |
| losses | Optional[Sequence] | The losses. | None |

Raises:

| Type | Description |
| --- | --- |
| NotImplementedError | Always raised by the base class; subclasses must override this method. |

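A sketch of a custom weight model built by inheriting from WeightModel and overriding get_weights. LossProportionalWeight is a hypothetical example, not part of conflictfree: it gives each gradient a weight proportional to its current loss, normalized so the weights average to 1 like EqualWeight. It assumes positive losses and that the calling code passes them in; the extra device keyword mirrors EqualWeight.get_weights, but check how your ConFIG call site invokes get_weights before relying on either. If conflictfree is not installed, a stand-in base class with the documented behavior is used.

```python
from typing import Optional, Sequence, Union

import torch

try:
    from conflictfree.weight_model import WeightModel
except ImportError:
    # Hypothetical stand-in mirroring the documented base class.
    class WeightModel:
        def get_weights(self, gradients=None, losses=None):
            raise NotImplementedError("This method must be implemented by the subclass.")


class LossProportionalWeight(WeightModel):
    """Hypothetical weight model: weights proportional to the current loss values."""

    def get_weights(
        self,
        gradients: Optional[torch.Tensor] = None,
        losses: Optional[Sequence] = None,
        device: Optional[Union[torch.device, str]] = None,
    ) -> torch.Tensor:
        if losses is None:
            raise ValueError("LossProportionalWeight requires losses")
        # float() detaches each loss from the autograd graph (works for Python floats
        # and 0-d tensors), so the weights themselves carry no gradient.
        values = torch.as_tensor([float(loss) for loss in losses], device=device)
        # Normalize so the weights average to 1, matching the scale of EqualWeight.
        return values / values.mean()


weights = LossProportionalWeight().get_weights(losses=[0.5, 1.5])
```

Here the mean loss is 1.0, so the weights are [0.5, 1.5]: the second loss, being three times larger, receives three times the target alignment.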