
4.4. Length Model

The length_model module contains classes for rescaling the magnitude of the final gradient vector. The ProjectionLength class is the default length model for the ConFIG algorithm. You can create a custom length model by inheriting from the LengthModel class.
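Below is a minimal sketch of the subclassing pattern. It assumes that the base class works the way the rescale_length source further down this page suggests: it normalizes the target vector and multiplies it by whatever get_length returns. Under that assumption, a custom model only has to implement get_length. The example runs without torch or conflictfree, so LengthModelSketch is a hypothetical pure-Python stand-in for LengthModel that works on lists, and MeanNormLength is an invented example model. A real implementation would subclass conflictfree.length_model.LengthModel and operate on torch.Tensor inputs.

```python
import math
from abc import ABC, abstractmethod
from typing import List, Optional, Sequence

Vector = List[float]


def norm(v: Vector) -> float:
    """Euclidean length of a vector."""
    return math.sqrt(sum(x * x for x in v))


def unit_vector(v: Vector) -> Vector:
    """Scale v to length 1 (assumes v is non-zero)."""
    n = norm(v)
    return [x / n for x in v]


class LengthModelSketch(ABC):
    """Hypothetical pure-Python stand-in for the LengthModel interface."""

    @abstractmethod
    def get_length(
        self,
        target_vector: Optional[Vector] = None,
        unit_target_vector: Optional[Vector] = None,
        gradients: Optional[List[Vector]] = None,
        losses: Optional[Sequence[float]] = None,
    ) -> float:
        """Return the new length of the target vector."""

    def rescale_length(
        self,
        target_vector: Vector,
        gradients: Optional[List[Vector]] = None,
        losses: Optional[Sequence[float]] = None,
    ) -> Vector:
        # Same two steps as the documented rescale_length:
        # normalize the direction, then scale it by get_length(...).
        unit = unit_vector(target_vector)
        length = self.get_length(
            target_vector=target_vector,
            unit_target_vector=unit,
            gradients=gradients,
            losses=losses,
        )
        return [length * x for x in unit]


class MeanNormLength(LengthModelSketch):
    """Hypothetical custom model: use the mean gradient norm as the new length."""

    def get_length(self, target_vector=None, unit_target_vector=None,
                   gradients=None, losses=None) -> float:
        assert gradients is not None, "MeanNormLength requires gradients."
        return sum(norm(g) for g in gradients) / len(gradients)


model = MeanNormLength()
print(model.rescale_length([1.0, 1.0], gradients=[[3.0, 0.0], [0.0, 4.0]]))
```

The sketch overrides only get_length and inherits the normalize-then-scale step. The built-in models on this page are structured the same way.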

Length Model

conflictfree.length_model.ProjectionLength

Bases: LengthModel

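Rescale the length of the target vector based on the projections of the gradients onto the target vector:

\[ |\mathbf{g}_c|=\sum_{i=1}^m|\mathbf{g}_i|\mathcal{S}_c(\mathbf{g}_i,\mathbf{g}_c) \]

where \(\mathcal{S}_c(\mathbf{g}_i,\mathbf{g}_c)\) is the cosine similarity between \(\mathbf{g}_i\) and \(\mathbf{g}_c\). Each term is therefore the signed length of the projection of \(\mathbf{g}_i\) onto the direction of \(\mathbf{g}_c\).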
Source code in conflictfree/length_model.py
class ProjectionLength(LengthModel):
    """
    Rescale the length of the target vector based on the projection of the gradients on the target vector:

    $$
    |\mathbf{g}_c|=\sum_{i=1}^m|\mathbf{g}_i|\mathcal{S}_c(\mathbf{g}_i,\mathbf{g}_c)
    $$
    """

    def __init__(self):
        super().__init__()

    def get_length(
        self,
        target_vector: Optional[torch.Tensor] = None,
        unit_target_vector: Optional[torch.Tensor] = None,
        gradients: Optional[torch.Tensor] = None,
        losses: Optional[Sequence] = None,
    ) -> torch.Tensor:
        """
        Calculates the length based on the given parameters. Not all parameters are required.

        Args:
            target_vector (Optional[torch.Tensor]): The final update gradient vector.
                One of the `target_vector` or `unit_target_vector` parameter need to be provided.
            unit_target_vector (Optional[torch.Tensor]): The unit vector of the target vector.
                One of the `target_vector` or `unit_target_vector` parameter need to be provided.
            gradients (Optional[torch.Tensor]): The loss-specific gradients matrix. The shape of this tensor should be (m,N) where m is the number of gradients and N is the number of elements of each gradients.
            losses (Optional[Sequence]): The losses. Not used in this model.

        Returns:
            Union[torch.Tensor, float]: The calculated length.
        """
        assert gradients is not None, "The ProjectionLength model requires gradients information."
        if unit_target_vector is None:
            unit_target_vector = unit_vector(target_vector)
        return torch.sum(
            torch.stack([torch.dot(grad_i, unit_target_vector)
                        for grad_i in gradients])
        )
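The sketch below recomputes ProjectionLength.get_length with plain Python lists instead of torch tensors, in two ways. The first follows the code: dot each gradient with the unit target vector. The second follows the formula as written: gradient norm times cosine similarity. The helper names projection_length and projection_length_cosine are hypothetical. The point is only that the two forms agree.

```python
import math
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def norm(v: Vector) -> float:
    return math.sqrt(dot(v, v))


def projection_length(target: Vector, gradients: List[Vector]) -> float:
    """Mirror of ProjectionLength.get_length: sum_i g_i . unit(g_c)."""
    n = norm(target)
    unit_target = [x / n for x in target]
    return sum(dot(g, unit_target) for g in gradients)


def projection_length_cosine(target: Vector, gradients: List[Vector]) -> float:
    """The formula as written: sum_i |g_i| * cos(g_i, g_c)."""
    return sum(
        norm(g) * dot(g, target) / (norm(g) * norm(target)) for g in gradients
    )


grads = [[3.0, 0.0], [0.0, 4.0]]
target = [1.0, 1.0]
print(projection_length(target, grads))         # approx. 4.9497 (= 7 / sqrt(2))
print(projection_length_cosine(target, grads))  # same value
```

A gradient that points against the target direction contributes a negative term. When the gradients disagree with the target, the resulting length therefore shrinks.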
rescale_length
rescale_length(
    target_vector: Tensor,
    gradients: Optional[Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor

Rescales the target vector in three steps. It normalizes target_vector to a unit vector, calls get_length to compute the new length, and returns the unit vector multiplied by that length.

Parameters:

target_vector (Tensor): The final update gradient vector. Required.

gradients (Optional[Tensor]): The loss-specific gradient matrix of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient. Default: None. It is required in practice, because get_length asserts that gradients are given.

losses (Optional[Sequence]): The losses, forwarded to get_length. Default: None.

Returns:

Tensor: The rescaled target vector.

Source code in conflictfree/length_model.py
def rescale_length(
    self,
    target_vector: torch.Tensor,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor:
    """
    Rescales the length of the target vector based on the given parameters.
    It calls the get_length method to calculate the length and then rescales the target vector.

    Args:
        target_vector (torch.Tensor): The final update gradient vector.
        gradients (Optional[torch.Tensor]): The loss-specific gradients matrix. The shape of this tensor should be (m,N) where m is the number of gradients and N is the number of elements of each gradients.
        losses (Optional[Sequence]): The losses.

    Returns:
        torch.Tensor: The rescaled target vector.
    """
    unit_target_vector = unit_vector(target_vector)
    return (
        self.get_length(
            target_vector=target_vector,
            unit_target_vector=unit_target_vector,
            gradients=gradients,
            losses=losses,
        )
        * unit_target_vector
    )
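In effect, rescale_length computes "unit direction times computed length". The plain-Python sketch below pairs that step with a projection-based length in the spirit of ProjectionLength. Here rescale is a hypothetical helper, and lists stand in for torch tensors. The example also shows that the original magnitude of target_vector is discarded: only its direction matters.

```python
import math
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def norm(v: Vector) -> float:
    return math.sqrt(dot(v, v))


def rescale(target: Vector, gradients: List[Vector]) -> Vector:
    """Unit direction of `target` times a projection-based length."""
    n = norm(target)
    unit_target = [x / n for x in target]                 # like unit_vector(target_vector)
    length = sum(dot(g, unit_target) for g in gradients)  # like ProjectionLength.get_length
    return [length * x for x in unit_target]              # length * unit_target_vector


grads = [[3.0, 0.0], [0.0, 4.0]]
print(rescale([1.0, 1.0], grads))    # approx. [3.5, 3.5]
print(rescale([10.0, 10.0], grads))  # same: only the direction of target is used
```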
__init__
__init__()
Source code in conflictfree/length_model.py
def __init__(self):
    super().__init__()
get_length
get_length(
    target_vector: Optional[Tensor] = None,
    unit_target_vector: Optional[Tensor] = None,
    gradients: Optional[Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor

Calculates the new length of the target vector. Not all parameters are required: gradients must be provided, and so must one of target_vector or unit_target_vector.

Parameters:

target_vector (Optional[Tensor]): The final update gradient vector. Either target_vector or unit_target_vector must be provided. Default: None.

unit_target_vector (Optional[Tensor]): The unit vector of the target vector. Either target_vector or unit_target_vector must be provided. Default: None.

gradients (Optional[Tensor]): The loss-specific gradient matrix of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient. Default: None, although this model asserts that it is provided.

losses (Optional[Sequence]): The losses. Not used by this model. Default: None.

Returns:

Tensor: The calculated length (a scalar tensor).

Source code in conflictfree/length_model.py
def get_length(
    self,
    target_vector: Optional[torch.Tensor] = None,
    unit_target_vector: Optional[torch.Tensor] = None,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor:
    """
    Calculates the length based on the given parameters. Not all parameters are required.

    Args:
        target_vector (Optional[torch.Tensor]): The final update gradient vector.
            One of the `target_vector` or `unit_target_vector` parameter need to be provided.
        unit_target_vector (Optional[torch.Tensor]): The unit vector of the target vector.
            One of the `target_vector` or `unit_target_vector` parameter need to be provided.
        gradients (Optional[torch.Tensor]): The loss-specific gradients matrix. The shape of this tensor should be (m,N) where m is the number of gradients and N is the number of elements of each gradients.
        losses (Optional[Sequence]): The losses. Not used in this model.

    Returns:
        Union[torch.Tensor, float]: The calculated length.
    """
    assert gradients is not None, "The ProjectionLength model requires gradients information."
    if unit_target_vector is None:
        unit_target_vector = unit_vector(target_vector)
    return torch.sum(
        torch.stack([torch.dot(grad_i, unit_target_vector)
                    for grad_i in gradients])
    )

conflictfree.length_model.TrackMinimum

Bases: _FlexibleTrackProjectionLength

Rescale the length of the target vector based on the projections of the gradients onto the target vector. Before projection, every gradient is rescaled to the length of the shortest gradient, so the minimum gradient norm sets the scale of the result.

\[ |\mathbf{g}_c|=\sum_{i=1}^m|\mathbf{g}_{min}|\mathcal{S}_c(\mathbf{g}_i,\mathbf{g}_c) \]
Source code in conflictfree/length_model.py
class TrackMinimum(_FlexibleTrackProjectionLength):
    """
    Rescale the length of the target vector based on the projection of the gradients on the target vector.
    All the gradients will be rescaled to the same length as the minimum gradient before projection, i.e., the minimum gradient will be the same length as the target vector.

    $$
    |\mathbf{g}_c|=\sum_{i=1}^m|\mathbf{g}_{min}|\mathcal{S}_c(\mathbf{g}_i,\mathbf{g}_c)
    $$
    """

    def __init__(self):
        super().__init__()

    def _tracked_value(self, grad_norms: Tensor) -> Tensor:
        return grad_norms.min()
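TrackMinimum does not override get_length. It inherits it from _FlexibleTrackProjectionLength, as the get_length source listed below for this class shows. That method normalizes each gradient, projects it onto the unit target vector, and multiplies the result by one tracked norm. The sketch mirrors that logic in plain Python and takes the tracked statistic as a function argument. Passing min reproduces TrackMinimum._tracked_value. The name tracked_projection_length is hypothetical, and lists stand in for tensors.

```python
import math
from typing import Callable, List, Sequence

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def norm(v: Vector) -> float:
    return math.sqrt(dot(v, v))


def tracked_projection_length(
    target: Vector,
    gradients: List[Vector],
    tracked: Callable[[Sequence[float]], float],
) -> float:
    """sum_i tracked(norms) * cos(g_i, g_c), mirroring the shared get_length."""
    n = norm(target)
    unit_target = [x / n for x in target]
    norms = [norm(g) for g in gradients]
    value = tracked(norms)  # plays the role of _tracked_value(norms)
    # Each gradient is normalized before projection, as in grad_i / norm_i.
    return sum(
        dot([x / g_norm for x in g], unit_target) * value
        for g, g_norm in zip(gradients, norms)
    )


grads = [[3.0, 0.0], [0.0, 4.0]]
print(tracked_projection_length([1.0, 1.0], grads, min))  # approx. 4.2426 (= 3 * sqrt(2))
```

Like the library code, this sketch divides by each gradient norm, so a zero gradient breaks it. Python raises ZeroDivisionError here, whereas floating-point tensor division would yield non-finite values instead.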
get_length
get_length(
    target_vector: Optional[Tensor] = None,
    unit_target_vector: Optional[Tensor] = None,
    gradients: Optional[Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor

Calculates the new length of the target vector. Not all parameters are required: gradients must be provided, and so must one of target_vector or unit_target_vector.

Parameters:

target_vector (Optional[Tensor]): The final update gradient vector. Either target_vector or unit_target_vector must be provided. Default: None.

unit_target_vector (Optional[Tensor]): The unit vector of the target vector. Either target_vector or unit_target_vector must be provided. Default: None.

gradients (Optional[Tensor]): The loss-specific gradient matrix of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient. Default: None, although this model asserts that it is provided.

losses (Optional[Sequence]): The losses. Not used by this model. Default: None.

Returns:

Tensor: The calculated length (a scalar tensor).

Source code in conflictfree/length_model.py
def get_length(
    self,
    target_vector: Optional[torch.Tensor] = None,
    unit_target_vector: Optional[torch.Tensor] = None,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor:
    """
    Calculates the length based on the given parameters. Not all parameters are required.

    Args:
        target_vector (Optional[torch.Tensor]): The final update gradient vector.
            One of the `target_vector` or `unit_target_vector` parameter need to be provided.
        unit_target_vector (Optional[torch.Tensor]): The unit vector of the target vector.
            One of the `target_vector` or `unit_target_vector` parameter need to be provided.
        gradients (Optional[torch.Tensor]): The loss-specific gradients matrix. The shape of this tensor should be (m,N) where m is the number of gradients and N is the number of elements of each gradients.
        losses (Optional[Sequence]): The losses. Not used in this model.

    Returns:
        Union[torch.Tensor, float]: The calculated length.
    """
    assert gradients is not None, "The ProjectLength model requires gradients information."
    if unit_target_vector is None:
        unit_target_vector = unit_vector(target_vector)
    norms = torch.norm(gradients, dim=1)
    tracked_value = self._tracked_value(norms)
    return sum(
        [
            torch.dot(grad_i / norm_i, unit_target_vector) * tracked_value
            for grad_i, norm_i in zip(gradients, norms)
        ]
    )
rescale_length
rescale_length(
    target_vector: Tensor,
    gradients: Optional[Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor

Rescales the target vector in three steps. It normalizes target_vector to a unit vector, calls get_length to compute the new length, and returns the unit vector multiplied by that length.

Parameters:

target_vector (Tensor): The final update gradient vector. Required.

gradients (Optional[Tensor]): The loss-specific gradient matrix of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient. Default: None. It is required in practice, because get_length asserts that gradients are given.

losses (Optional[Sequence]): The losses, forwarded to get_length. Default: None.

Returns:

Tensor: The rescaled target vector.

Source code in conflictfree/length_model.py
def rescale_length(
    self,
    target_vector: torch.Tensor,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor:
    """
    Rescales the length of the target vector based on the given parameters.
    It calls the get_length method to calculate the length and then rescales the target vector.

    Args:
        target_vector (torch.Tensor): The final update gradient vector.
        gradients (Optional[torch.Tensor]): The loss-specific gradients matrix. The shape of this tensor should be (m,N) where m is the number of gradients and N is the number of elements of each gradients.
        losses (Optional[Sequence]): The losses.

    Returns:
        torch.Tensor: The rescaled target vector.
    """
    unit_target_vector = unit_vector(target_vector)
    return (
        self.get_length(
            target_vector=target_vector,
            unit_target_vector=unit_target_vector,
            gradients=gradients,
            losses=losses,
        )
        * unit_target_vector
    )
__init__
__init__()
Source code in conflictfree/length_model.py
def __init__(self):
    super().__init__()
_tracked_value
_tracked_value(grad_norms: Tensor) -> Tensor
Source code in conflictfree/length_model.py
def _tracked_value(self, grad_norms: Tensor) -> Tensor:
    return grad_norms.min()

conflictfree.length_model.TrackMaximum

Bases: _FlexibleTrackProjectionLength

Rescale the length of the target vector based on the projections of the gradients onto the target vector. Before projection, every gradient is rescaled to the length of the longest gradient, so the maximum gradient norm sets the scale of the result.

\[ |\mathbf{g}_c|=\sum_{i=1}^m|\mathbf{g}_{max}|\mathcal{S}_c(\mathbf{g}_i,\mathbf{g}_c) \]
Source code in conflictfree/length_model.py
class TrackMaximum(_FlexibleTrackProjectionLength):
    """
    Rescale the length of the target vector based on the projection of the gradients on the target vector.
    All the gradients will be rescaled to the same length as the maximum gradient before projection, i.e., the maximum gradient will be the same length as the target vector.

    $$
    |\mathbf{g}_c|=\sum_{i=1}^m|\mathbf{g}_{max}|\mathcal{S}_c(\mathbf{g}_i,\mathbf{g}_c)
    $$
    """

    def __init__(self):
        super().__init__()

    def _tracked_value(self, grad_norms: Tensor) -> Tensor:
        return grad_norms.max()
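Here is a plain-Python sketch of TrackMaximum next to ProjectionLength. Lists stand in for torch tensors, and the helper names are hypothetical. The second print illustrates a property that follows directly from the formulas. When all gradients have the same norm, the tracked statistic equals every |g_i|, so any tracked variant reduces to ProjectionLength.

```python
import math
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def norm(v: Vector) -> float:
    return math.sqrt(dot(v, v))


def unit(v: Vector) -> Vector:
    n = norm(v)
    return [x / n for x in v]


def projection_length(target: Vector, gradients: List[Vector]) -> float:
    """ProjectionLength: sum_i g_i . unit(g_c)."""
    u = unit(target)
    return sum(dot(g, u) for g in gradients)


def track_maximum_length(target: Vector, gradients: List[Vector]) -> float:
    """TrackMaximum: every gradient gets the largest norm before projection."""
    u = unit(target)
    g_max = max(norm(g) for g in gradients)  # TrackMaximum._tracked_value
    return sum(dot(unit(g), u) * g_max for g in gradients)


target = [1.0, 1.0]
unequal = [[3.0, 0.0], [0.0, 4.0]]
equal = [[4.0, 0.0], [0.0, 4.0]]
print(track_maximum_length(target, unequal))  # approx. 5.6569 (= 4 * sqrt(2))
print(track_maximum_length(target, equal), projection_length(target, equal))  # equal values
```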
get_length
get_length(
    target_vector: Optional[Tensor] = None,
    unit_target_vector: Optional[Tensor] = None,
    gradients: Optional[Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor

Calculates the new length of the target vector. Not all parameters are required: gradients must be provided, and so must one of target_vector or unit_target_vector.

Parameters:

target_vector (Optional[Tensor]): The final update gradient vector. Either target_vector or unit_target_vector must be provided. Default: None.

unit_target_vector (Optional[Tensor]): The unit vector of the target vector. Either target_vector or unit_target_vector must be provided. Default: None.

gradients (Optional[Tensor]): The loss-specific gradient matrix of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient. Default: None, although this model asserts that it is provided.

losses (Optional[Sequence]): The losses. Not used by this model. Default: None.

Returns:

Tensor: The calculated length (a scalar tensor).

Source code in conflictfree/length_model.py
def get_length(
    self,
    target_vector: Optional[torch.Tensor] = None,
    unit_target_vector: Optional[torch.Tensor] = None,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor:
    """
    Calculates the length based on the given parameters. Not all parameters are required.

    Args:
        target_vector (Optional[torch.Tensor]): The final update gradient vector.
            One of the `target_vector` or `unit_target_vector` parameter need to be provided.
        unit_target_vector (Optional[torch.Tensor]): The unit vector of the target vector.
            One of the `target_vector` or `unit_target_vector` parameter need to be provided.
        gradients (Optional[torch.Tensor]): The loss-specific gradients matrix. The shape of this tensor should be (m,N) where m is the number of gradients and N is the number of elements of each gradients.
        losses (Optional[Sequence]): The losses. Not used in this model.

    Returns:
        Union[torch.Tensor, float]: The calculated length.
    """
    assert gradients is not None, "The ProjectLength model requires gradients information."
    if unit_target_vector is None:
        unit_target_vector = unit_vector(target_vector)
    norms = torch.norm(gradients, dim=1)
    tracked_value = self._tracked_value(norms)
    return sum(
        [
            torch.dot(grad_i / norm_i, unit_target_vector) * tracked_value
            for grad_i, norm_i in zip(gradients, norms)
        ]
    )
rescale_length
rescale_length(
    target_vector: Tensor,
    gradients: Optional[Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor

Rescales the target vector in three steps. It normalizes target_vector to a unit vector, calls get_length to compute the new length, and returns the unit vector multiplied by that length.

Parameters:

target_vector (Tensor): The final update gradient vector. Required.

gradients (Optional[Tensor]): The loss-specific gradient matrix of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient. Default: None. It is required in practice, because get_length asserts that gradients are given.

losses (Optional[Sequence]): The losses, forwarded to get_length. Default: None.

Returns:

Tensor: The rescaled target vector.

Source code in conflictfree/length_model.py
def rescale_length(
    self,
    target_vector: torch.Tensor,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor:
    """
    Rescales the length of the target vector based on the given parameters.
    It calls the get_length method to calculate the length and then rescales the target vector.

    Args:
        target_vector (torch.Tensor): The final update gradient vector.
        gradients (Optional[torch.Tensor]): The loss-specific gradients matrix. The shape of this tensor should be (m,N) where m is the number of gradients and N is the number of elements of each gradients.
        losses (Optional[Sequence]): The losses.

    Returns:
        torch.Tensor: The rescaled target vector.
    """
    unit_target_vector = unit_vector(target_vector)
    return (
        self.get_length(
            target_vector=target_vector,
            unit_target_vector=unit_target_vector,
            gradients=gradients,
            losses=losses,
        )
        * unit_target_vector
    )
__init__
__init__()
Source code in conflictfree/length_model.py
def __init__(self):
    super().__init__()
_tracked_value
_tracked_value(grad_norms: Tensor) -> Tensor
Source code in conflictfree/length_model.py
def _tracked_value(self, grad_norms: Tensor) -> Tensor:
    return grad_norms.max()

conflictfree.length_model.TrackHarmonicAverage

Bases: _FlexibleTrackProjectionLength

Rescale the length of the target vector based on the projections of the gradients onto the target vector. Before projection, every gradient is rescaled to the harmonic average of the lengths of all gradients.

\[ |\mathbf{g}_c|=\sum_{i=1}^m\overline{|\mathbf{g}|}_{harm}\mathcal{S}_c(\mathbf{g}_i,\mathbf{g}_c) \]

where

\[ \overline{|\mathbf{g}|}_{harm}=\frac{m}{\sum_{i=1}^m \frac{1}{|\mathbf{g}_i|}} \]

Because the harmonic average is dominated by the smallest lengths, it limits the influence of unusually large gradients.

Source code in conflictfree/length_model.py
class TrackHarmonicAverage(_FlexibleTrackProjectionLength):
    """
    Rescale the length of the target vector based on the projection of the gradients on the target vector.
    All the gradients will be rescaled to the harmonic average of the lengths of all gradients before projection, i.e., the minimum gradient will be the same length as the target vector.

    $$
    |\mathbf{g}_c|=\sum_{i=1}^m\overline{|\mathbf{g}|}_{harm}\mathcal{S}_c(\mathbf{g}_i,\mathbf{g}_c)
    $$

    where

    $$
    \overline{|\mathbf{g}|}_{harm}=\frac{m}{\sum_{i=1}^m \frac{1}{|\mathbf{g}_i|}}
    $$

    The harmonic average can be used to avoid the influence of the large gradients.
    """

    def __init__(self):
        super().__init__()

    def _tracked_value(self, grad_norms: Tensor) -> Tensor:
        return grad_norms.shape[0] / torch.sum(1 / grad_norms)
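The following is a quick check of the harmonic-average expression, using invented gradient norms that include one large outlier. harmonic_tracked_value is a hypothetical helper that reproduces the expression in TrackHarmonicAverage._tracked_value with plain floats. The standard library's statistics.harmonic_mean serves as a cross-check. The outlier barely moves the harmonic average, while the arithmetic mean jumps to 34.

```python
import statistics
from typing import Sequence


def harmonic_tracked_value(norms: Sequence[float]) -> float:
    """m / sum(1 / |g_i|): the expression used by TrackHarmonicAverage._tracked_value."""
    return len(norms) / sum(1.0 / n for n in norms)


norms = [1.0, 1.0, 100.0]  # hypothetical norms with one unusually large gradient
print(harmonic_tracked_value(norms))    # approx. 1.4925
print(statistics.harmonic_mean(norms))  # stdlib cross-check, same value
print(statistics.mean(norms))           # arithmetic mean for comparison: 34.0
```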
get_length
get_length(
    target_vector: Optional[Tensor] = None,
    unit_target_vector: Optional[Tensor] = None,
    gradients: Optional[Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor

Calculates the new length of the target vector. Not all parameters are required: gradients must be provided, and so must one of target_vector or unit_target_vector.

Parameters:

target_vector (Optional[Tensor]): The final update gradient vector. Either target_vector or unit_target_vector must be provided. Default: None.

unit_target_vector (Optional[Tensor]): The unit vector of the target vector. Either target_vector or unit_target_vector must be provided. Default: None.

gradients (Optional[Tensor]): The loss-specific gradient matrix of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient. Default: None, although this model asserts that it is provided.

losses (Optional[Sequence]): The losses. Not used by this model. Default: None.

Returns:

Tensor: The calculated length (a scalar tensor).

Source code in conflictfree/length_model.py
def get_length(
    self,
    target_vector: Optional[torch.Tensor] = None,
    unit_target_vector: Optional[torch.Tensor] = None,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor:
    """
    Calculates the length based on the given parameters. Not all parameters are required.

    Args:
        target_vector (Optional[torch.Tensor]): The final update gradient vector.
            One of the `target_vector` or `unit_target_vector` parameter need to be provided.
        unit_target_vector (Optional[torch.Tensor]): The unit vector of the target vector.
            One of the `target_vector` or `unit_target_vector` parameter need to be provided.
        gradients (Optional[torch.Tensor]): The loss-specific gradients matrix. The shape of this tensor should be (m,N) where m is the number of gradients and N is the number of elements of each gradients.
        losses (Optional[Sequence]): The losses. Not used in this model.

    Returns:
        Union[torch.Tensor, float]: The calculated length.
    """
    assert gradients is not None, "The ProjectLength model requires gradients information."
    if unit_target_vector is None:
        unit_target_vector = unit_vector(target_vector)
    norms = torch.norm(gradients, dim=1)
    tracked_value = self._tracked_value(norms)
    return sum(
        [
            torch.dot(grad_i / norm_i, unit_target_vector) * tracked_value
            for grad_i, norm_i in zip(gradients, norms)
        ]
    )
rescale_length
rescale_length(
    target_vector: Tensor,
    gradients: Optional[Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor

Rescales the target vector in three steps. It normalizes target_vector to a unit vector, calls get_length to compute the new length, and returns the unit vector multiplied by that length.

Parameters:

target_vector (Tensor): The final update gradient vector. Required.

gradients (Optional[Tensor]): The loss-specific gradient matrix of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient. Default: None. It is required in practice, because get_length asserts that gradients are given.

losses (Optional[Sequence]): The losses, forwarded to get_length. Default: None.

Returns:

Tensor: The rescaled target vector.

Source code in conflictfree/length_model.py
def rescale_length(
    self,
    target_vector: torch.Tensor,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor:
    """
    Rescales the length of the target vector based on the given parameters.
    It calls the get_length method to calculate the length and then rescales the target vector.

    Args:
        target_vector (torch.Tensor): The final update gradient vector.
        gradients (Optional[torch.Tensor]): The loss-specific gradients matrix. The shape of this tensor should be (m,N) where m is the number of gradients and N is the number of elements of each gradients.
        losses (Optional[Sequence]): The losses.

    Returns:
        torch.Tensor: The rescaled target vector.
    """
    unit_target_vector = unit_vector(target_vector)
    return (
        self.get_length(
            target_vector=target_vector,
            unit_target_vector=unit_target_vector,
            gradients=gradients,
            losses=losses,
        )
        * unit_target_vector
    )
__init__
__init__()
Source code in conflictfree/length_model.py
def __init__(self):
    super().__init__()
_tracked_value
_tracked_value(grad_norms: Tensor) -> Tensor
Source code in conflictfree/length_model.py
def _tracked_value(self, grad_norms: Tensor) -> Tensor:
    return grad_norms.shape[0] / torch.sum(1 / grad_norms)

conflictfree.length_model.TrackArithmeticAverage

Bases: _FlexibleTrackProjectionLength

Rescale the length of the target vector based on the projections of the gradients onto the target vector. Before projection, every gradient is rescaled to the arithmetic average of the lengths of all gradients.

\[ |\mathbf{g}_c|=\sum_{i=1}^m\overline{|\mathbf{g}|}_{arith}\mathcal{S}_c(\mathbf{g}_i,\mathbf{g}_c) \]

where

\[ \overline{|\mathbf{g}|}_{arith}=\frac{1}{m}\sum_{i=1}^m |\mathbf{g}_i| \]
Source code in conflictfree/length_model.py
class TrackArithmeticAverage(_FlexibleTrackProjectionLength):
    """
    Rescale the length of the target vector based on the projection of the gradients on the target vector.
    All the gradients will be rescaled to the arithmetic average of the lengths of all gradients before projection, i.e., the minimum gradient will be the same length as the target vector.

    $$
    |\mathbf{g}_c|=\sum_{i=1}^m\overline{|\mathbf{g}|}_{arith}\mathcal{S}_c(\mathbf{g}_i,\mathbf{g}_c)
    $$

    where

    $$
    \overline{|\mathbf{g}|}_{arith}=\frac{1}{m}\sum_{i=1}^m |\mathbf{g}_i|
    $$
    """

    def __init__(self):
        super().__init__()

    def _tracked_value(self, grad_norms: Tensor) -> Tensor:
        return grad_norms.mean()
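Writing TrackArithmeticAverage and ProjectionLength side by side in plain Python shows when the two differ. The helper names are hypothetical, and lists stand in for tensors. Suppose every gradient makes the same angle with the target vector, so each cosine equals some c. Then sum_i |g_i| * c = c * m * mean(|g|), and the two lengths coincide. When the angles differ, the lengths generally differ too.

```python
import math
import statistics
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def norm(v: Vector) -> float:
    return math.sqrt(dot(v, v))


def cosine(a: Vector, b: Vector) -> float:
    return dot(a, b) / (norm(a) * norm(b))


def projection_length(target: Vector, gradients: List[Vector]) -> float:
    """ProjectionLength: sum_i |g_i| * cos(g_i, g_c)."""
    return sum(norm(g) * cosine(g, target) for g in gradients)


def track_arithmetic_length(target: Vector, gradients: List[Vector]) -> float:
    """TrackArithmeticAverage: sum_i mean(|g|) * cos(g_i, g_c)."""
    mean_norm = statistics.mean([norm(g) for g in gradients])  # _tracked_value
    return sum(mean_norm * cosine(g, target) for g in gradients)


grads = [[3.0, 0.0], [0.0, 4.0]]
# Both gradients sit at 45 degrees to [1, 1], so the two models agree.
print(track_arithmetic_length([1.0, 1.0], grads), projection_length([1.0, 1.0], grads))
# Against [1, 0] the angles differ, and so do the lengths (3.5 vs 3.0).
print(track_arithmetic_length([1.0, 0.0], grads), projection_length([1.0, 0.0], grads))
```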
get_length
get_length(
    target_vector: Optional[Tensor] = None,
    unit_target_vector: Optional[Tensor] = None,
    gradients: Optional[Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor

Calculates the new length of the target vector. Not all parameters are required: gradients must be provided, and so must one of target_vector or unit_target_vector.

Parameters:

target_vector (Optional[Tensor]): The final update gradient vector. Either target_vector or unit_target_vector must be provided. Default: None.

unit_target_vector (Optional[Tensor]): The unit vector of the target vector. Either target_vector or unit_target_vector must be provided. Default: None.

gradients (Optional[Tensor]): The loss-specific gradient matrix of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient. Default: None, although this model asserts that it is provided.

losses (Optional[Sequence]): The losses. Not used by this model. Default: None.

Returns:

Tensor: The calculated length (a scalar tensor).

Source code in conflictfree/length_model.py
def get_length(
    self,
    target_vector: Optional[torch.Tensor] = None,
    unit_target_vector: Optional[torch.Tensor] = None,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor:
    """
    Calculates the length based on the given parameters. Not all parameters are required.

    Args:
        target_vector (Optional[torch.Tensor]): The final update gradient vector.
            One of the `target_vector` or `unit_target_vector` parameter need to be provided.
        unit_target_vector (Optional[torch.Tensor]): The unit vector of the target vector.
            One of the `target_vector` or `unit_target_vector` parameter need to be provided.
        gradients (Optional[torch.Tensor]): The loss-specific gradients matrix. The shape of this tensor should be (m,N) where m is the number of gradients and N is the number of elements of each gradients.
        losses (Optional[Sequence]): The losses. Not used in this model.

    Returns:
        Union[torch.Tensor, float]: The calculated length.
    """
    assert gradients is not None, "The ProjectLength model requires gradients information."
    if unit_target_vector is None:
        unit_target_vector = unit_vector(target_vector)
    norms = torch.norm(gradients, dim=1)
    tracked_value = self._tracked_value(norms)
    return sum(
        [
            torch.dot(grad_i / norm_i, unit_target_vector) * tracked_value
            for grad_i, norm_i in zip(gradients, norms)
        ]
    )
rescale_length
rescale_length(
    target_vector: Tensor,
    gradients: Optional[Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor

Rescales the target vector in three steps. It normalizes target_vector to a unit vector, calls get_length to compute the new length, and returns the unit vector multiplied by that length.

Parameters:

target_vector (Tensor): The final update gradient vector. Required.

gradients (Optional[Tensor]): The loss-specific gradient matrix of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient. Default: None. It is required in practice, because get_length asserts that gradients are given.

losses (Optional[Sequence]): The losses, forwarded to get_length. Default: None.

Returns:

Tensor: The rescaled target vector.

Source code in conflictfree/length_model.py
def rescale_length(
    self,
    target_vector: torch.Tensor,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor:
    """
    Rescales the length of the target vector based on the given parameters.
    It calls the get_length method to calculate the length and then rescales the target vector.

    Args:
        target_vector (torch.Tensor): The final update gradient vector.
        gradients (Optional[torch.Tensor]): The loss-specific gradients matrix. The shape of this tensor should be (m,N) where m is the number of gradients and N is the number of elements of each gradients.
        losses (Optional[Sequence]): The losses.

    Returns:
        torch.Tensor: The rescaled target vector.
    """
    unit_target_vector = unit_vector(target_vector)
    return (
        self.get_length(
            target_vector=target_vector,
            unit_target_vector=unit_target_vector,
            gradients=gradients,
            losses=losses,
        )
        * unit_target_vector
    )
__init__
__init__()
Source code in conflictfree/length_model.py
def __init__(self):
    super().__init__()
_tracked_value
_tracked_value(grad_norms: Tensor) -> Tensor
Source code in conflictfree/length_model.py
def _tracked_value(self, grad_norms: Tensor) -> Tensor:
    return grad_norms.mean()
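The _tracked_value listed just above returns the arithmetic mean of the gradient norms. The plain-Python sketch below (no PyTorch; the function and variable names are illustrative, not library API) shows how this differs from ProjectionLength: every cosine similarity is weighted by the same mean norm instead of by the gradient's own norm.

```python
import math


def norms_and_cosines(gradients, target):
    """Return |g_i| and the cosine similarity between each g_i and the target."""
    target_norm = math.sqrt(sum(t * t for t in target))
    norms, cosines = [], []
    for g in gradients:
        g_norm = math.sqrt(sum(x * x for x in g))
        norms.append(g_norm)
        cosines.append(sum(x * t for x, t in zip(g, target)) / (g_norm * target_norm))
    return norms, cosines


gradients = [[3.0, 0.0], [0.0, 5.0]]  # norms 3 and 5
target = [3.0, 4.0]                   # direction of the final update g_c
norms, cosines = norms_and_cosines(gradients, target)  # cosines 0.6 and 0.8

# ProjectionLength: each cosine is weighted by that gradient's own norm.
projection_length = sum(n * c for n, c in zip(norms, cosines))
# Mean-tracked variant: every cosine is weighted by the mean norm.
mean_norm = sum(norms) / len(norms)
mean_tracked_length = sum(mean_norm * c for c in cosines)

print(projection_length, mean_tracked_length)  # approximately 5.8 and 5.6
```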

conflictfree.length_model.TrackGeometricAverage

Bases: _FlexibleTrackProjectionLength

Rescale the length of the target vector based on the projection of the gradients onto the target vector. Before projection, all gradients are rescaled to the geometric average of the lengths of all gradients.

\[ |\mathbf{g}_c|=\sum_{i=1}^m\overline{|\mathbf{g}|}_{geom}\mathcal{S}_c(\mathbf{g}_i,\mathbf{g}_c) \]

where

\[ \overline{|\mathbf{g}|}_{geom}=\left(\prod_{i=1}^m |\mathbf{g}_i|\right)^{\frac{1}{m}} \]

The geometric average reduces the influence of large gradients on the final length.
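A quick numeric check of this, in plain Python with made-up norms: with one outlier, the arithmetic mean is dominated by the large gradient, while the geometric mean stays much closer to the small ones (by the AM-GM inequality, the geometric mean never exceeds the arithmetic mean).

```python
import math

norms = [1.0, 1.0, 100.0]  # hypothetical gradient norms with one outlier

arithmetic = sum(norms) / len(norms)
geometric = math.prod(norms) ** (1 / len(norms))

print(arithmetic)           # 34.0
print(round(geometric, 4))  # 4.6416
```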

Source code in conflictfree/length_model.py
class TrackGeometricAverage(_FlexibleTrackProjectionLength):
    # Raw docstring: in a plain string, "\right" and "\frac" would be parsed as the escape sequences \r and \f.
    r"""
    Rescale the length of the target vector based on the projection of the gradients on the target vector.
    All the gradients will be rescaled to the geometric average of the lengths of all gradients before projection.

    $$
    |\mathbf{g}_c|=\sum_{i=1}^m\overline{|\mathbf{g}|}_{geom}\mathcal{S}_c(\mathbf{g}_i,\mathbf{g}_c)
    $$

    where

    $$
    \overline{|\mathbf{g}|}_{geom}=\left(\prod_{i=1}^m |\mathbf{g}_i|\right)^{\frac{1}{m}}
    $$

    The geometric average reduces the influence of large gradients.
    """

    def __init__(self):
        super().__init__()

    def _tracked_value(self, grad_norms: Tensor) -> Tensor:
        return torch.prod(grad_norms) ** (1 / grad_norms.shape[0])
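The expression above multiplies all norms before taking the m-th root, so with many small (or large) norms the intermediate product can underflow (or overflow) even when the geometric mean itself is representable. The sketch below, assuming PyTorch, compares that expression with the mathematically equivalent log-space form exp(mean(log |g_i|)). The log-space variant and both function names are suggestions for illustration, not what the library currently does.

```python
import torch


def geometric_mean_direct(grad_norms: torch.Tensor) -> torch.Tensor:
    # Same expression as TrackGeometricAverage._tracked_value above.
    return torch.prod(grad_norms) ** (1 / grad_norms.shape[0])


def geometric_mean_log_space(grad_norms: torch.Tensor) -> torch.Tensor:
    # (prod x_i)^(1/m) == exp(mean(log x_i)); averaging logs avoids forming
    # the full product. A zero norm gives log(0) = -inf and exp(-inf) = 0,
    # which matches the direct form.
    return torch.exp(torch.log(grad_norms).mean())


# 200 small norms in float32: the exact product is 1e-600, which underflows.
norms = torch.full((200,), 1e-3, dtype=torch.float32)
print(geometric_mean_direct(norms))     # underflows to zero
print(geometric_mean_log_space(norms))  # close to 1e-3
```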
get_length
get_length(
    target_vector: Optional[Tensor] = None,
    unit_target_vector: Optional[Tensor] = None,
    gradients: Optional[Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor

Calculates the length based on the given parameters. Not all parameters are required.

Parameters:

target_vector (Optional[Tensor], default None): The final update gradient vector. Either target_vector or unit_target_vector must be provided.
unit_target_vector (Optional[Tensor], default None): The unit vector of the target vector. Either target_vector or unit_target_vector must be provided.
gradients (Optional[Tensor], default None): The loss-specific gradient matrix, of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient. Although the default is None, this model requires gradients and fails its assertion without them.
losses (Optional[Sequence], default None): The losses. Not used in this model.

Returns:

Tensor: The calculated length.

Source code in conflictfree/length_model.py
def get_length(
    self,
    target_vector: Optional[torch.Tensor] = None,
    unit_target_vector: Optional[torch.Tensor] = None,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor:
    """
    Calculates the length based on the given parameters. Not all parameters are required.

    Args:
        target_vector (Optional[torch.Tensor]): The final update gradient vector.
            Either `target_vector` or `unit_target_vector` must be provided.
        unit_target_vector (Optional[torch.Tensor]): The unit vector of the target vector.
            Either `target_vector` or `unit_target_vector` must be provided.
        gradients (Optional[torch.Tensor]): The loss-specific gradient matrix, of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient.
        losses (Optional[Sequence]): The losses. Not used in this model.

    Returns:
        Union[torch.Tensor, float]: The calculated length.
    """
    assert gradients is not None, "The ProjectionLength model requires gradient information."
    if unit_target_vector is None:
        unit_target_vector = unit_vector(target_vector)
    norms = torch.norm(gradients, dim=1)
    tracked_value = self._tracked_value(norms)
    return sum(
        [
            torch.dot(grad_i / norm_i, unit_target_vector) * tracked_value
            for grad_i, norm_i in zip(gradients, norms)
        ]
    )
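The list comprehension in get_length can also be written as a single matrix-vector product: normalize each row of the (m, N) gradient matrix, multiply by the unit target vector to get all m cosine similarities at once, then sum. The sketch below, assuming PyTorch, checks that both forms agree on random data; tracked_length_loop and tracked_length_vectorized are hypothetical helper names, and the arithmetic-mean tracked value is only one possible choice.

```python
import torch


def tracked_length_loop(gradients, unit_target, tracked_value):
    # Mirrors the loop in get_length above: one dot product per gradient.
    norms = torch.norm(gradients, dim=1)
    return sum(
        torch.dot(g / n, unit_target) * tracked_value
        for g, n in zip(gradients, norms)
    )


def tracked_length_vectorized(gradients, unit_target, tracked_value):
    # Normalize every row, project all rows onto the target at once, then sum.
    norms = torch.linalg.vector_norm(gradients, dim=1, keepdim=True)  # (m, 1)
    cosines = (gradients / norms) @ unit_target                       # (m,)
    return cosines.sum() * tracked_value


torch.manual_seed(0)
grads = torch.randn(4, 10)  # m = 4 gradients, N = 10 elements each
target = grads.sum(dim=0)
unit_target = target / torch.linalg.vector_norm(target)
tracked = torch.linalg.vector_norm(grads, dim=1).mean()

loop = tracked_length_loop(grads, unit_target, tracked)
vec = tracked_length_vectorized(grads, unit_target, tracked)
print(loop.item(), vec.item())  # the two values agree up to float32 rounding
```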
rescale_length
rescale_length(
    target_vector: Tensor,
    gradients: Optional[Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor

Rescales the target vector to a new length while keeping its direction. The target vector is normalized to a unit vector, get_length computes the new length, and the unit vector is multiplied by that length.

Parameters:

target_vector (Tensor, required): The final update gradient vector.
gradients (Optional[Tensor], default None): The loss-specific gradient matrix, of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient.
losses (Optional[Sequence], default None): The losses.

Returns:

Tensor: The rescaled target vector.

Source code in conflictfree/length_model.py
def rescale_length(
    self,
    target_vector: torch.Tensor,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor:
    """
    Rescales the length of the target vector based on the given parameters.
    It calls the get_length method to calculate the length and then rescales the target vector.

    Args:
        target_vector (torch.Tensor): The final update gradient vector.
        gradients (Optional[torch.Tensor]): The loss-specific gradient matrix, of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient.
        losses (Optional[Sequence]): The losses.

    Returns:
        torch.Tensor: The rescaled target vector.
    """
    unit_target_vector = unit_vector(target_vector)
    return (
        self.get_length(
            target_vector=target_vector,
            unit_target_vector=unit_target_vector,
            gradients=gradients,
            losses=losses,
        )
        * unit_target_vector
    )
__init__
__init__()
Source code in conflictfree/length_model.py
def __init__(self):
    super().__init__()
_tracked_value
_tracked_value(grad_norms: Tensor) -> Tensor
Source code in conflictfree/length_model.py
def _tracked_value(self, grad_norms: Tensor) -> Tensor:
    return torch.prod(grad_norms) ** (1 / grad_norms.shape[0])

conflictfree.length_model.TrackSpecific

Bases: _FlexibleTrackProjectionLength

Rescale the length of the target vector based on the projection of the gradients onto the target vector. Before projection, all gradients are rescaled to the length of one specific gradient, selected by track_id. For example, if track_id is 2, all gradients are rescaled to the length of the third gradient.

\[ |\mathbf{g}_c|=\sum_{i=1}^m\overline{|\mathbf{g}|}_{\mathrm{track\_id}}\mathcal{S}_c(\mathbf{g}_i,\mathbf{g}_c) \]
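A plain-Python worked example of this formula, reusing the track_id = 2 case from the description. The gradients are made-up 2-D vectors chosen so that their cosine similarities with the target are 0.6, 0.8 and 1.0, and track_specific_length is a hypothetical helper, not library API.

```python
import math


def track_specific_length(gradients, target, track_id):
    """Sketch of |g_c| = sum_i |g_track_id| * cos(g_i, g_c)."""
    def norm(v):
        return math.sqrt(sum(x * x for x in v))

    target_norm = norm(target)
    tracked = norm(gradients[track_id])  # plays the role of grad_norms[self.track_id]
    return sum(
        tracked * sum(x * t for x, t in zip(g, target)) / (norm(g) * target_norm)
        for g in gradients
    )


gradients = [[3.0, 0.0], [0.0, 5.0], [6.0, 8.0]]  # norms 3, 5 and 10
target = [3.0, 4.0]  # cosines with the gradients: 0.6, 0.8, 1.0 (sum 2.4)
print(track_specific_length(gradients, target, track_id=2))  # approximately 24.0
print(track_specific_length(gradients, target, track_id=0))  # approximately 7.2
```

Since _tracked_value uses plain indexing, a negative track_id such as -1 should select the last gradient, as with Python lists.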
Source code in conflictfree/length_model.py
class TrackSpecific(_FlexibleTrackProjectionLength):
    """
    Rescale the length of the target vector based on the projection of the gradients on the target vector.
    All the gradients will be rescaled to the same length as the specific gradient before projection.
    E.g., if the track_id is 2, then all the gradients will be rescaled to the same length as the third gradient before projection.

    $$
    |\mathbf{g}_c|=\sum_{i=1}^m\overline{|\mathbf{g}|}_{track_id}\mathcal{S}_c(\mathbf{g}_i,\mathbf{g}_c)
    $$

    """

    def __init__(self, track_id: int):
        super().__init__()
        self.track_id = track_id

    def _tracked_value(self, grad_norms: Tensor) -> Tensor:
        return grad_norms[self.track_id]
track_id instance-attribute
track_id = track_id
get_length
get_length(
    target_vector: Optional[Tensor] = None,
    unit_target_vector: Optional[Tensor] = None,
    gradients: Optional[Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor

Calculates the length based on the given parameters. Not all parameters are required.

Parameters:

target_vector (Optional[Tensor], default None): The final update gradient vector. Either target_vector or unit_target_vector must be provided.
unit_target_vector (Optional[Tensor], default None): The unit vector of the target vector. Either target_vector or unit_target_vector must be provided.
gradients (Optional[Tensor], default None): The loss-specific gradient matrix, of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient. Although the default is None, this model requires gradients and fails its assertion without them.
losses (Optional[Sequence], default None): The losses. Not used in this model.

Returns:

Tensor: The calculated length.

Source code in conflictfree/length_model.py
def get_length(
    self,
    target_vector: Optional[torch.Tensor] = None,
    unit_target_vector: Optional[torch.Tensor] = None,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor:
    """
    Calculates the length based on the given parameters. Not all parameters are required.

    Args:
        target_vector (Optional[torch.Tensor]): The final update gradient vector.
            Either `target_vector` or `unit_target_vector` must be provided.
        unit_target_vector (Optional[torch.Tensor]): The unit vector of the target vector.
            Either `target_vector` or `unit_target_vector` must be provided.
        gradients (Optional[torch.Tensor]): The loss-specific gradient matrix, of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient.
        losses (Optional[Sequence]): The losses. Not used in this model.

    Returns:
        Union[torch.Tensor, float]: The calculated length.
    """
    assert gradients is not None, "The ProjectionLength model requires gradient information."
    if unit_target_vector is None:
        unit_target_vector = unit_vector(target_vector)
    norms = torch.norm(gradients, dim=1)
    tracked_value = self._tracked_value(norms)
    return sum(
        [
            torch.dot(grad_i / norm_i, unit_target_vector) * tracked_value
            for grad_i, norm_i in zip(gradients, norms)
        ]
    )
rescale_length
rescale_length(
    target_vector: Tensor,
    gradients: Optional[Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor

Rescales the target vector to a new length while keeping its direction. The target vector is normalized to a unit vector, get_length computes the new length, and the unit vector is multiplied by that length.

Parameters:

target_vector (Tensor, required): The final update gradient vector.
gradients (Optional[Tensor], default None): The loss-specific gradient matrix, of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient.
losses (Optional[Sequence], default None): The losses.

Returns:

Tensor: The rescaled target vector.

Source code in conflictfree/length_model.py
def rescale_length(
    self,
    target_vector: torch.Tensor,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor:
    """
    Rescales the length of the target vector based on the given parameters.
    It calls the get_length method to calculate the length and then rescales the target vector.

    Args:
        target_vector (torch.Tensor): The final update gradient vector.
        gradients (Optional[torch.Tensor]): The loss-specific gradient matrix, of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient.
        losses (Optional[Sequence]): The losses.

    Returns:
        torch.Tensor: The rescaled target vector.
    """
    unit_target_vector = unit_vector(target_vector)
    return (
        self.get_length(
            target_vector=target_vector,
            unit_target_vector=unit_target_vector,
            gradients=gradients,
            losses=losses,
        )
        * unit_target_vector
    )
__init__
__init__(track_id: int)
Source code in conflictfree/length_model.py
def __init__(self, track_id: int):
    super().__init__()
    self.track_id = track_id
_tracked_value
_tracked_value(grad_norms: Tensor) -> Tensor
Source code in conflictfree/length_model.py
def _tracked_value(self, grad_norms: Tensor) -> Tensor:
    return grad_norms[self.track_id]

Base Class of Length Model

conflictfree.length_model.LengthModel

The base class for length models.

Methods:

get_length: Calculates the length based on the given parameters.
rescale_length: Rescales the length of the target vector based on the given parameters.

Source code in conflictfree/length_model.py
class LengthModel:
    """
    The base class for length models.

    Methods:
        get_length: Calculates the length based on the given parameters.
        rescale_length: Rescales the length of the target vector based on the given parameters.
    """

    def __init__(self):
        pass

    def get_length(
        self,
        target_vector: Optional[torch.Tensor] = None,
        unit_target_vector: Optional[torch.Tensor] = None,
        gradients: Optional[torch.Tensor] = None,
        losses: Optional[Sequence] = None,
    ) -> Union[torch.Tensor, float]:
        """
        Calculates the length based on the given parameters. Not all parameters are required.

        Args:
            target_vector (Optional[torch.Tensor]): The final update gradient vector.
            unit_target_vector (Optional[torch.Tensor]): The unit vector of the target vector.
            gradients (Optional[torch.Tensor]): The loss-specific gradient matrix, of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient.
            losses (Optional[Sequence]): The losses.

        Returns:
            Union[torch.Tensor, float]: The calculated length.
        """
        raise NotImplementedError(
            "This method must be implemented by the subclass.")

    def rescale_length(
        self,
        target_vector: torch.Tensor,
        gradients: Optional[torch.Tensor] = None,
        losses: Optional[Sequence] = None,
    ) -> torch.Tensor:
        """
        Rescales the length of the target vector based on the given parameters.
        It calls the get_length method to calculate the length and then rescales the target vector.

        Args:
            target_vector (torch.Tensor): The final update gradient vector.
            gradients (Optional[torch.Tensor]): The loss-specific gradient matrix, of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient.
            losses (Optional[Sequence]): The losses.

        Returns:
            torch.Tensor: The rescaled target vector.
        """
        unit_target_vector = unit_vector(target_vector)
        return (
            self.get_length(
                target_vector=target_vector,
                unit_target_vector=unit_target_vector,
                gradients=gradients,
                losses=losses,
            )
            * unit_target_vector
        )
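A custom length model is created by inheriting from LengthModel and implementing get_length; rescale_length is then inherited unchanged. The sketch below assumes PyTorch. MaxNormLength is a hypothetical example that uses the largest gradient norm as the length, and when conflictfree is not installed the snippet falls back to a minimal stand-in that mirrors the base class listed above.

```python
from typing import Optional, Sequence, Union

import torch

try:
    from conflictfree.length_model import LengthModel
except ImportError:
    # Minimal stand-in mirroring the LengthModel source above, so the
    # sketch runs without the package installed.
    class LengthModel:
        def get_length(self, target_vector=None, unit_target_vector=None,
                       gradients=None, losses=None):
            raise NotImplementedError("This method must be implemented by the subclass.")

        def rescale_length(self, target_vector, gradients=None, losses=None):
            unit_target_vector = target_vector / torch.linalg.vector_norm(target_vector)
            return self.get_length(
                target_vector=target_vector,
                unit_target_vector=unit_target_vector,
                gradients=gradients,
                losses=losses,
            ) * unit_target_vector


class MaxNormLength(LengthModel):
    """Hypothetical length model: use the largest gradient norm as the length."""

    def get_length(
        self,
        target_vector: Optional[torch.Tensor] = None,
        unit_target_vector: Optional[torch.Tensor] = None,
        gradients: Optional[torch.Tensor] = None,
        losses: Optional[Sequence] = None,
    ) -> Union[torch.Tensor, float]:
        assert gradients is not None, "MaxNormLength requires gradients."
        return torch.linalg.vector_norm(gradients, dim=1).max()


grads = torch.tensor([[3.0, 0.0], [0.0, 5.0]])
update = torch.tensor([1.0, 1.0])
rescaled = MaxNormLength().rescale_length(update, gradients=grads)
print(torch.linalg.vector_norm(rescaled))  # approximately 5.0
```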
__init__
__init__()
Source code in conflictfree/length_model.py
def __init__(self):
    pass
get_length
get_length(
    target_vector: Optional[Tensor] = None,
    unit_target_vector: Optional[Tensor] = None,
    gradients: Optional[Tensor] = None,
    losses: Optional[Sequence] = None,
) -> Union[torch.Tensor, float]

Calculates the length based on the given parameters. Not all parameters are required.

Parameters:

target_vector (Optional[Tensor], default None): The final update gradient vector.
unit_target_vector (Optional[Tensor], default None): The unit vector of the target vector.
gradients (Optional[Tensor], default None): The loss-specific gradient matrix, of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient.
losses (Optional[Sequence], default None): The losses.

Returns:

Union[Tensor, float]: The calculated length.

Source code in conflictfree/length_model.py
def get_length(
    self,
    target_vector: Optional[torch.Tensor] = None,
    unit_target_vector: Optional[torch.Tensor] = None,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> Union[torch.Tensor, float]:
    """
    Calculates the length based on the given parameters. Not all parameters are required.

    Args:
        target_vector (Optional[torch.Tensor]): The final update gradient vector.
        unit_target_vector (Optional[torch.Tensor]): The unit vector of the target vector.
        gradients (Optional[torch.Tensor]): The loss-specific gradient matrix, of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient.
        losses (Optional[Sequence]): The losses.

    Returns:
        Union[torch.Tensor, float]: The calculated length.
    """
    raise NotImplementedError(
        "This method must be implemented by the subclass.")
rescale_length
rescale_length(
    target_vector: Tensor,
    gradients: Optional[Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor

Rescales the target vector to a new length while keeping its direction. The target vector is normalized to a unit vector, get_length computes the new length, and the unit vector is multiplied by that length.

Parameters:

target_vector (Tensor, required): The final update gradient vector.
gradients (Optional[Tensor], default None): The loss-specific gradient matrix, of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient.
losses (Optional[Sequence], default None): The losses.

Returns:

Tensor: The rescaled target vector.

Source code in conflictfree/length_model.py
def rescale_length(
    self,
    target_vector: torch.Tensor,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor:
    """
    Rescales the length of the target vector based on the given parameters.
    It calls the get_length method to calculate the length and then rescales the target vector.

    Args:
        target_vector (torch.Tensor): The final update gradient vector.
        gradients (Optional[torch.Tensor]): The loss-specific gradient matrix, of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient.
        losses (Optional[Sequence]): The losses.

    Returns:
        torch.Tensor: The rescaled target vector.
    """
    unit_target_vector = unit_vector(target_vector)
    return (
        self.get_length(
            target_vector=target_vector,
            unit_target_vector=unit_target_vector,
            gradients=gradients,
            losses=losses,
        )
        * unit_target_vector
    )