
4.4. Length Model

The length_model module contains classes for rescaling the magnitude of the final gradient vector. The ProjectionLength class is the default length model for the ConFIG algorithm. You can create a custom length model by inheriting from the LengthModel class.
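A custom model only needs to implement get_length. The rescale_length method documented on this page multiplies the returned length by the unit target vector. The following is a minimal sketch under that assumption. So that it runs without the package installed, it defines two hypothetical stand-ins, `unit_vector` and `LengthModelStub`, that mirror the helper and the base-class method shown on this page. In real code you would subclass `conflictfree.length_model.LengthModel` instead. `MeanProjectionLength` is a hypothetical example, not a model shipped with the library.

```python
from typing import Optional, Sequence

import torch


def unit_vector(vector: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the library's unit_vector helper.
    return vector / torch.norm(vector)


class LengthModelStub:
    # Hypothetical stand-in for LengthModel; rescale_length mirrors the source on this page.
    def get_length(
        self,
        target_vector: Optional[torch.Tensor] = None,
        unit_target_vector: Optional[torch.Tensor] = None,
        gradients: Optional[torch.Tensor] = None,
        losses: Optional[Sequence] = None,
    ) -> torch.Tensor:
        raise NotImplementedError

    def rescale_length(
        self,
        target_vector: torch.Tensor,
        gradients: Optional[torch.Tensor] = None,
        losses: Optional[Sequence] = None,
    ) -> torch.Tensor:
        unit_target_vector = unit_vector(target_vector)
        length = self.get_length(
            target_vector=target_vector,
            unit_target_vector=unit_target_vector,
            gradients=gradients,
            losses=losses,
        )
        return length * unit_target_vector


class MeanProjectionLength(LengthModelStub):
    # Hypothetical custom model: the mean (instead of the sum) of the projections.
    def get_length(
        self,
        target_vector: Optional[torch.Tensor] = None,
        unit_target_vector: Optional[torch.Tensor] = None,
        gradients: Optional[torch.Tensor] = None,
        losses: Optional[Sequence] = None,
    ) -> torch.Tensor:
        assert gradients is not None, "MeanProjectionLength requires gradient information."
        if unit_target_vector is None:
            unit_target_vector = unit_vector(target_vector)
        # (m, N) @ (N,) -> (m,) projections, averaged over the m gradients.
        return (gradients @ unit_target_vector).mean()


gradients = torch.tensor([[3.0, 0.0], [0.0, 4.0]])  # hypothetical g_1, g_2
update = MeanProjectionLength().rescale_length(torch.tensor([1.0, 1.0]), gradients=gradients)
```

Here the projections are 3/√2 and 4/√2, their mean is 3.5/√2, and the rescaled update is (1.75, 1.75).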

Length Model

conflictfree.length_model.ProjectionLength

Bases: LengthModel

Rescale the length of the target vector based on the projection of the gradients on the target vector:

\[ |\mathbf{g}_c|=\sum_{i=1}^m|\mathbf{g}_i|\mathcal{S}_c(\mathbf{g}_i,\mathbf{g}_c) \]
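As a worked check of the formula, the sketch below takes two hypothetical gradients g_1 = (3, 0) and g_2 = (0, 4) and a target direction (1, 1). It reads S_c as cosine similarity, which is what the source code implies, since |g_i| cos θ_i = g_i · ĝ_c. Both forms then give 3/√2 + 4/√2 = 7/√2 ≈ 4.950.

```python
import torch
import torch.nn.functional as F

gradients = torch.tensor([[3.0, 0.0], [0.0, 4.0]])  # hypothetical g_1, g_2 (m = 2, N = 2)
g_c = torch.tensor([1.0, 1.0])                      # direction of the target vector

# Formula form: sum_i |g_i| * S_c(g_i, g_c), with S_c taken as cosine similarity.
norms = torch.norm(gradients, dim=1)
cosines = F.cosine_similarity(gradients, g_c.unsqueeze(0), dim=1)
length_from_formula = torch.sum(norms * cosines)

# Form used in the code: sum of dot products with the unit target vector.
unit_c = g_c / torch.norm(g_c)
length_from_projection = torch.sum(gradients @ unit_c)
```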
Source code in conflictfree/length_model.py
class ProjectionLength(LengthModel):
    """
    Rescale the length of the target vector based on the projection of the gradients on the target vector:

    $$
    |\mathbf{g}_c|=\sum_{i=1}^m|\mathbf{g}_i|\mathcal{S}_c(\mathbf{g}_i,\mathbf{g}_c)
    $$
    """

    def __init__(self):
        super().__init__()

    def get_length(
        self,
        target_vector: Optional[torch.Tensor] = None,
        unit_target_vector: Optional[torch.Tensor] = None,
        gradients: Optional[torch.Tensor] = None,
        losses: Optional[Sequence] = None,
    ) -> torch.Tensor:
        """
        Calculates the length based on the given parameters. Not all parameters are required.

        Args:
            target_vector (Optional[torch.Tensor]): The final update gradient vector.
                Either `target_vector` or `unit_target_vector` must be provided.
            unit_target_vector (Optional[torch.Tensor]): The unit vector of the target vector.
                Either `target_vector` or `unit_target_vector` must be provided.
            gradients (Optional[torch.Tensor]): The loss-specific gradient matrix, of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient.
            losses (Optional[Sequence]): The losses. Not used in this model.

        Returns:
            Union[torch.Tensor, float]: The calculated length.
        """
        assert gradients is not None, "The ProjectionLength model requires gradient information."
        if unit_target_vector is None:
            unit_target_vector = unit_vector(target_vector)
        return torch.sum(
            torch.stack([torch.dot(grad_i, unit_target_vector)
                        for grad_i in gradients])
        )
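The list comprehension and torch.stack above compute m separate dot products. As a sketch (not the library's code), the same quantity can be written as a single matrix-vector product, which avoids the Python loop. Both helper names below are hypothetical.

```python
import torch


def projection_length_loop(gradients: torch.Tensor, unit_target_vector: torch.Tensor) -> torch.Tensor:
    # Mirrors ProjectionLength.get_length: one dot product per gradient, stacked and summed.
    return torch.sum(torch.stack([torch.dot(g, unit_target_vector) for g in gradients]))


def projection_length_vectorized(gradients: torch.Tensor, unit_target_vector: torch.Tensor) -> torch.Tensor:
    # Same quantity as one matrix-vector product: (m, N) @ (N,) -> (m,), then sum.
    return torch.sum(gradients @ unit_target_vector)


torch.manual_seed(0)
gradients = torch.randn(5, 1000)  # hypothetical m = 5 gradients with N = 1000 elements
unit = torch.randn(1000)
unit = unit / torch.norm(unit)

loop_result = projection_length_loop(gradients, unit)
vectorized_result = projection_length_vectorized(gradients, unit)
```

The two results agree up to floating-point summation order.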
rescale_length
rescale_length(
    target_vector: torch.Tensor,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor

Rescales the target vector. It calls the get_length method to compute the new length and returns that length multiplied by the unit vector of target_vector.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| target_vector | Tensor | The final update gradient vector. | required |
| gradients | Optional[Tensor] | The loss-specific gradient matrix, of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient. | None |
| losses | Optional[Sequence] | The losses. | None |

Returns:

| Type | Description |
|---|---|
| Tensor | The rescaled target vector. |

Source code in conflictfree/length_model.py
def rescale_length(
    self,
    target_vector: torch.Tensor,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor:
    """
    Rescales the length of the target vector based on the given parameters.
    It calls the get_length method to calculate the length and then rescales the target vector.

    Args:
        target_vector (torch.Tensor): The final update gradient vector.
        gradients (Optional[torch.Tensor]): The loss-specific gradient matrix, of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient.
        losses (Optional[Sequence]): The losses.

    Returns:
        torch.Tensor: The rescaled target vector.
    """
    unit_target_vector = unit_vector(target_vector)
    return (
        self.get_length(
            target_vector=target_vector,
            unit_target_vector=unit_target_vector,
            gradients=gradients,
            losses=losses,
        )
        * unit_target_vector
    )
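Because rescale_length multiplies the result of get_length by the same unit vector û it passed in, the output for ProjectionLength is (Σ_i g_i · û) û. A consequence worth knowing: flipping the sign of target_vector flips both factors, so the output is unchanged (the same holds for the Track* models, whose length is also linear in û). The sketch below re-implements the computation with a hypothetical helper to show this; it is not the library's code.

```python
import torch


def projection_rescale(target_vector: torch.Tensor, gradients: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper reproducing ProjectionLength.rescale_length:
    # length = sum_i g_i . u, returned as length * u with u = target / |target|.
    unit = target_vector / torch.norm(target_vector)
    length = torch.sum(gradients @ unit)
    return length * unit


gradients = torch.tensor([[3.0, 0.0], [0.0, 4.0]])  # hypothetical g_1, g_2
aligned = projection_rescale(torch.tensor([1.0, 1.0]), gradients)
flipped = projection_rescale(torch.tensor([-1.0, -1.0]), gradients)
```

Both calls return (3.5, 3.5): the result always points to the side where the summed projection is non-negative.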
__init__
__init__()
Source code in conflictfree/length_model.py
def __init__(self):
    super().__init__()
get_length
get_length(
    target_vector: Optional[torch.Tensor] = None,
    unit_target_vector: Optional[torch.Tensor] = None,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor

Calculates the length based on the given parameters. Not all parameters are required.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| target_vector | Optional[Tensor] | The final update gradient vector. Either target_vector or unit_target_vector must be provided. | None |
| unit_target_vector | Optional[Tensor] | The unit vector of the target vector. Either target_vector or unit_target_vector must be provided. | None |
| gradients | Optional[Tensor] | The loss-specific gradient matrix, of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient. | None |
| losses | Optional[Sequence] | The losses. Not used in this model. | None |

Returns:

| Type | Description |
|---|---|
| Tensor | The calculated length. |

Source code in conflictfree/length_model.py
def get_length(
    self,
    target_vector: Optional[torch.Tensor] = None,
    unit_target_vector: Optional[torch.Tensor] = None,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor:
    """
    Calculates the length based on the given parameters. Not all parameters are required.

    Args:
        target_vector (Optional[torch.Tensor]): The final update gradient vector.
            Either `target_vector` or `unit_target_vector` must be provided.
        unit_target_vector (Optional[torch.Tensor]): The unit vector of the target vector.
            Either `target_vector` or `unit_target_vector` must be provided.
        gradients (Optional[torch.Tensor]): The loss-specific gradient matrix, of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient.
        losses (Optional[Sequence]): The losses. Not used in this model.

    Returns:
        Union[torch.Tensor, float]: The calculated length.
    """
    assert gradients is not None, "The ProjectionLength model requires gradient information."
    if unit_target_vector is None:
        unit_target_vector = unit_vector(target_vector)
    return torch.sum(
        torch.stack([torch.dot(grad_i, unit_target_vector)
                    for grad_i in gradients])
    )

conflictfree.length_model.TrackMinimum

Bases: _FlexibleTrackProjectionLength

Rescale the length of the target vector based on the projections of the gradients onto the target vector. Before projection, every gradient is rescaled to the length of the shortest gradient, i.e., each gradient contributes as if its norm were the minimum gradient norm.

\[ |\mathbf{g}_c|=\sum_{i=1}^m|\mathbf{g}_{min}|\mathcal{S}_c(\mathbf{g}_i,\mathbf{g}_c) \]
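A worked example of the formula, using hypothetical gradients g_1 = (3, 0) and g_2 = (0, 4) and a target direction (1, 1): both cosines equal 1/√2 and |g_min| = 3, so |g_c| = 3 · (2/√2) = 3√2 ≈ 4.243. The sketch recomputes this directly from the formula rather than through the library.

```python
import torch

gradients = torch.tensor([[3.0, 0.0], [0.0, 4.0]])                 # hypothetical g_1, g_2
unit_c = torch.tensor([1.0, 1.0]) / torch.sqrt(torch.tensor(2.0))  # unit target direction

norms = torch.norm(gradients, dim=1)                 # |g_1| = 3, |g_2| = 4
cosines = (gradients / norms.unsqueeze(1)) @ unit_c  # S_c(g_i, g_c) = 1/sqrt(2) for each i
length_min = norms.min() * cosines.sum()             # |g_min| * sum_i S_c(g_i, g_c)
```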
Source code in conflictfree/length_model.py
class TrackMinimum(_FlexibleTrackProjectionLength):
    """
    Rescale the length of the target vector based on the projections of the gradients onto the target vector.
    Before projection, every gradient is rescaled to the length of the shortest gradient, i.e., each gradient contributes as if its norm were the minimum gradient norm.

    $$
    |\mathbf{g}_c|=\sum_{i=1}^m|\mathbf{g}_{min}|\mathcal{S}_c(\mathbf{g}_i,\mathbf{g}_c)
    $$
    """

    def __init__(self):
        super().__init__()

    def _tracked_value(self, grad_norms: Tensor) -> Tensor:
        return grad_norms.min()
get_length
get_length(
    target_vector: Optional[torch.Tensor] = None,
    unit_target_vector: Optional[torch.Tensor] = None,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor

Calculates the length based on the given parameters. Not all parameters are required.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| target_vector | Optional[Tensor] | The final update gradient vector. Either target_vector or unit_target_vector must be provided. | None |
| unit_target_vector | Optional[Tensor] | The unit vector of the target vector. Either target_vector or unit_target_vector must be provided. | None |
| gradients | Optional[Tensor] | The loss-specific gradient matrix, of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient. | None |
| losses | Optional[Sequence] | The losses. Not used in this model. | None |

Returns:

| Type | Description |
|---|---|
| Tensor | The calculated length. |

Source code in conflictfree/length_model.py
def get_length(
    self,
    target_vector: Optional[torch.Tensor] = None,
    unit_target_vector: Optional[torch.Tensor] = None,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor:
    """
    Calculates the length based on the given parameters. Not all parameters are required.

    Args:
        target_vector (Optional[torch.Tensor]): The final update gradient vector.
            Either `target_vector` or `unit_target_vector` must be provided.
        unit_target_vector (Optional[torch.Tensor]): The unit vector of the target vector.
            Either `target_vector` or `unit_target_vector` must be provided.
        gradients (Optional[torch.Tensor]): The loss-specific gradient matrix, of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient.
        losses (Optional[Sequence]): The losses. Not used in this model.

    Returns:
        Union[torch.Tensor, float]: The calculated length.
    """
    assert gradients is not None, "This length model requires gradient information."
    if unit_target_vector is None:
        unit_target_vector = unit_vector(target_vector)
    norms = torch.norm(gradients, dim=1)
    tracked_value = self._tracked_value(norms)
    return sum(
        [
            torch.dot(grad_i / norm_i, unit_target_vector) * tracked_value
            for grad_i, norm_i in zip(gradients, norms)
        ]
    )
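The get_length above is shared by the Track* models: it divides each gradient by its norm, projects it onto the target direction, and scales the sum by a tracked norm statistic. As a sketch (not the library's code), the loop can be vectorized with torch.nn.functional.normalize. One design difference to be aware of: normalize clamps the norm at a small epsilon, so a zero-norm gradient contributes 0 instead of producing NaN as the plain division does. The helper names are hypothetical.

```python
import torch
import torch.nn.functional as F


def tracked_projection_length(gradients, unit_target_vector, tracked_value_fn):
    # Normalize each row, project onto the target direction, sum the cosines,
    # and scale by the tracked norm statistic (min, max, mean, ...).
    norms = torch.norm(gradients, dim=1)
    cosines = F.normalize(gradients, dim=1) @ unit_target_vector
    return tracked_value_fn(norms) * cosines.sum()


def tracked_projection_length_loop(gradients, unit_target_vector, tracked_value_fn):
    # Mirrors the loop in get_length above.
    norms = torch.norm(gradients, dim=1)
    tracked_value = tracked_value_fn(norms)
    return sum(torch.dot(g / n, unit_target_vector) * tracked_value
               for g, n in zip(gradients, norms))


torch.manual_seed(0)
gradients = torch.randn(4, 50)  # hypothetical m = 4 gradients with N = 50 elements
unit = torch.randn(50)
unit = unit / unit.norm()

vectorized = tracked_projection_length(gradients, unit, lambda n: n.min())
looped = tracked_projection_length_loop(gradients, unit, lambda n: n.min())
```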
rescale_length
rescale_length(
    target_vector: torch.Tensor,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor

Rescales the target vector. It calls the get_length method to compute the new length and returns that length multiplied by the unit vector of target_vector.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| target_vector | Tensor | The final update gradient vector. | required |
| gradients | Optional[Tensor] | The loss-specific gradient matrix, of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient. | None |
| losses | Optional[Sequence] | The losses. | None |

Returns:

| Type | Description |
|---|---|
| Tensor | The rescaled target vector. |

Source code in conflictfree/length_model.py
def rescale_length(
    self,
    target_vector: torch.Tensor,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor:
    """
    Rescales the length of the target vector based on the given parameters.
    It calls the get_length method to calculate the length and then rescales the target vector.

    Args:
        target_vector (torch.Tensor): The final update gradient vector.
        gradients (Optional[torch.Tensor]): The loss-specific gradient matrix, of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient.
        losses (Optional[Sequence]): The losses.

    Returns:
        torch.Tensor: The rescaled target vector.
    """
    unit_target_vector = unit_vector(target_vector)
    return (
        self.get_length(
            target_vector=target_vector,
            unit_target_vector=unit_target_vector,
            gradients=gradients,
            losses=losses,
        )
        * unit_target_vector
    )
__init__
__init__()
Source code in conflictfree/length_model.py
def __init__(self):
    super().__init__()

conflictfree.length_model.TrackMaximum

Bases: _FlexibleTrackProjectionLength

Rescale the length of the target vector based on the projections of the gradients onto the target vector. Before projection, every gradient is rescaled to the length of the longest gradient, i.e., each gradient contributes as if its norm were the maximum gradient norm.

\[ |\mathbf{g}_c|=\sum_{i=1}^m|\mathbf{g}_{max}|\mathcal{S}_c(\mathbf{g}_i,\mathbf{g}_c) \]
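TrackMaximum replaces every |g_i| with the largest norm. When all cosines S_c(g_i, g_c) are non-negative, as they are for a conflict-free direction, each |g_i| lies between the smallest and the largest norm, so the ProjectionLength result lies between the TrackMinimum and TrackMaximum results. The sketch checks this for hypothetical gradients (3, 0) and (0, 4) with target direction (1, 1), where the three lengths are 6/√2, 7/√2 and 8/√2; it evaluates the formulas directly, not the library classes.

```python
import torch

gradients = torch.tensor([[3.0, 0.0], [0.0, 4.0]])                 # hypothetical g_1, g_2
unit_c = torch.tensor([1.0, 1.0]) / torch.sqrt(torch.tensor(2.0))  # unit target direction

norms = torch.norm(gradients, dim=1)
cosines = (gradients / norms.unsqueeze(1)) @ unit_c  # S_c(g_i, g_c), all non-negative here

projection = torch.sum(norms * cosines)  # ProjectionLength: sum_i |g_i| S_c
track_min = norms.min() * cosines.sum()  # TrackMinimum: |g_min| sum_i S_c
track_max = norms.max() * cosines.sum()  # TrackMaximum: |g_max| sum_i S_c
```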
Source code in conflictfree/length_model.py
class TrackMaximum(_FlexibleTrackProjectionLength):
    """
    Rescale the length of the target vector based on the projections of the gradients onto the target vector.
    Before projection, every gradient is rescaled to the length of the longest gradient, i.e., each gradient contributes as if its norm were the maximum gradient norm.

    $$
    |\mathbf{g}_c|=\sum_{i=1}^m|\mathbf{g}_{max}|\mathcal{S}_c(\mathbf{g}_i,\mathbf{g}_c)
    $$
    """

    def __init__(self):
        super().__init__()

    def _tracked_value(self, grad_norms: Tensor) -> Tensor:
        return grad_norms.max()
get_length
get_length(
    target_vector: Optional[torch.Tensor] = None,
    unit_target_vector: Optional[torch.Tensor] = None,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor

Calculates the length based on the given parameters. Not all parameters are required.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| target_vector | Optional[Tensor] | The final update gradient vector. Either target_vector or unit_target_vector must be provided. | None |
| unit_target_vector | Optional[Tensor] | The unit vector of the target vector. Either target_vector or unit_target_vector must be provided. | None |
| gradients | Optional[Tensor] | The loss-specific gradient matrix, of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient. | None |
| losses | Optional[Sequence] | The losses. Not used in this model. | None |

Returns:

| Type | Description |
|---|---|
| Tensor | The calculated length. |

Source code in conflictfree/length_model.py
def get_length(
    self,
    target_vector: Optional[torch.Tensor] = None,
    unit_target_vector: Optional[torch.Tensor] = None,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor:
    """
    Calculates the length based on the given parameters. Not all parameters are required.

    Args:
        target_vector (Optional[torch.Tensor]): The final update gradient vector.
            Either `target_vector` or `unit_target_vector` must be provided.
        unit_target_vector (Optional[torch.Tensor]): The unit vector of the target vector.
            Either `target_vector` or `unit_target_vector` must be provided.
        gradients (Optional[torch.Tensor]): The loss-specific gradient matrix, of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient.
        losses (Optional[Sequence]): The losses. Not used in this model.

    Returns:
        Union[torch.Tensor, float]: The calculated length.
    """
    assert gradients is not None, "This length model requires gradient information."
    if unit_target_vector is None:
        unit_target_vector = unit_vector(target_vector)
    norms = torch.norm(gradients, dim=1)
    tracked_value = self._tracked_value(norms)
    return sum(
        [
            torch.dot(grad_i / norm_i, unit_target_vector) * tracked_value
            for grad_i, norm_i in zip(gradients, norms)
        ]
    )
rescale_length
rescale_length(
    target_vector: torch.Tensor,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor

Rescales the target vector. It calls the get_length method to compute the new length and returns that length multiplied by the unit vector of target_vector.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| target_vector | Tensor | The final update gradient vector. | required |
| gradients | Optional[Tensor] | The loss-specific gradient matrix, of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient. | None |
| losses | Optional[Sequence] | The losses. | None |

Returns:

| Type | Description |
|---|---|
| Tensor | The rescaled target vector. |

Source code in conflictfree/length_model.py
def rescale_length(
    self,
    target_vector: torch.Tensor,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor:
    """
    Rescales the length of the target vector based on the given parameters.
    It calls the get_length method to calculate the length and then rescales the target vector.

    Args:
        target_vector (torch.Tensor): The final update gradient vector.
        gradients (Optional[torch.Tensor]): The loss-specific gradient matrix, of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient.
        losses (Optional[Sequence]): The losses.

    Returns:
        torch.Tensor: The rescaled target vector.
    """
    unit_target_vector = unit_vector(target_vector)
    return (
        self.get_length(
            target_vector=target_vector,
            unit_target_vector=unit_target_vector,
            gradients=gradients,
            losses=losses,
        )
        * unit_target_vector
    )
__init__
__init__()
Source code in conflictfree/length_model.py
def __init__(self):
    super().__init__()

conflictfree.length_model.TrackHarmonicAverage

Bases: _FlexibleTrackProjectionLength

Rescale the length of the target vector based on the projections of the gradients onto the target vector. Before projection, every gradient is rescaled to the harmonic average of the lengths of all gradients, i.e., each gradient contributes as if its norm were that harmonic average.

\[ |\mathbf{g}_c|=\sum_{i=1}^m\overline{|\mathbf{g}|}_{harm}\mathcal{S}_c(\mathbf{g}_i,\mathbf{g}_c) \]

where

\[ \overline{|\mathbf{g}|}_{harm}=\frac{m}{\sum_{i=1}^m \frac{1}{|\mathbf{g}_i|}} \]

The harmonic average reduces the influence of large gradients on the final length.

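Numerically, with hypothetical norms 3 and 4 the harmonic average is 2 / (1/3 + 1/4) = 24/7 ≈ 3.43. Adding one much larger norm (1000) barely moves it (≈ 5.13), whereas the arithmetic average jumps to ≈ 335.7. The sketch uses the same expression as _tracked_value above. A caveat, stated as an assumption about typical use: a zero-norm gradient makes 1/|g_i| infinite, so the tracked value becomes 0 and the g_i/|g_i| division in get_length yields NaN; gradients are expected to be non-zero.

```python
import torch


def harmonic_tracked_value(grad_norms: torch.Tensor) -> torch.Tensor:
    # Same expression as TrackHarmonicAverage._tracked_value: m / sum(1 / |g_i|).
    return grad_norms.shape[0] / torch.sum(1 / grad_norms)


balanced = torch.tensor([3.0, 4.0])
skewed = torch.tensor([3.0, 4.0, 1000.0])  # one hypothetical very large gradient norm

harmonic_balanced = harmonic_tracked_value(balanced)  # 2 / (1/3 + 1/4) = 24/7
harmonic_skewed = harmonic_tracked_value(skewed)      # stays close to the small norms
arithmetic_skewed = skewed.mean()                     # dominated by the large norm
```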
Source code in conflictfree/length_model.py
class TrackHarmonicAverage(_FlexibleTrackProjectionLength):
    """
    Rescale the length of the target vector based on the projections of the gradients onto the target vector.
    Before projection, every gradient is rescaled to the harmonic average of the lengths of all gradients, i.e., each gradient contributes as if its norm were that harmonic average.

    $$
    |\mathbf{g}_c|=\sum_{i=1}^m\overline{|\mathbf{g}|}_{harm}\mathcal{S}_c(\mathbf{g}_i,\mathbf{g}_c)
    $$

    where

    $$
    \overline{|\mathbf{g}|}_{harm}=\frac{m}{\sum_{i=1}^m \frac{1}{|\mathbf{g}_i|}}
    $$

    The harmonic average reduces the influence of large gradients on the final length.
    """

    def __init__(self):
        super().__init__()

    def _tracked_value(self, grad_norms: Tensor) -> Tensor:
        return grad_norms.shape[0] / torch.sum(1 / grad_norms)
get_length
get_length(
    target_vector: Optional[torch.Tensor] = None,
    unit_target_vector: Optional[torch.Tensor] = None,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor

Calculates the length based on the given parameters. Not all parameters are required.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| target_vector | Optional[Tensor] | The final update gradient vector. Either target_vector or unit_target_vector must be provided. | None |
| unit_target_vector | Optional[Tensor] | The unit vector of the target vector. Either target_vector or unit_target_vector must be provided. | None |
| gradients | Optional[Tensor] | The loss-specific gradient matrix, of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient. | None |
| losses | Optional[Sequence] | The losses. Not used in this model. | None |

Returns:

| Type | Description |
|---|---|
| Tensor | The calculated length. |

Source code in conflictfree/length_model.py
def get_length(
    self,
    target_vector: Optional[torch.Tensor] = None,
    unit_target_vector: Optional[torch.Tensor] = None,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor:
    """
    Calculates the length based on the given parameters. Not all parameters are required.

    Args:
        target_vector (Optional[torch.Tensor]): The final update gradient vector.
            Either `target_vector` or `unit_target_vector` must be provided.
        unit_target_vector (Optional[torch.Tensor]): The unit vector of the target vector.
            Either `target_vector` or `unit_target_vector` must be provided.
        gradients (Optional[torch.Tensor]): The loss-specific gradient matrix, of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient.
        losses (Optional[Sequence]): The losses. Not used in this model.

    Returns:
        Union[torch.Tensor, float]: The calculated length.
    """
    assert gradients is not None, "This length model requires gradient information."
    if unit_target_vector is None:
        unit_target_vector = unit_vector(target_vector)
    norms = torch.norm(gradients, dim=1)
    tracked_value = self._tracked_value(norms)
    return sum(
        [
            torch.dot(grad_i / norm_i, unit_target_vector) * tracked_value
            for grad_i, norm_i in zip(gradients, norms)
        ]
    )
rescale_length
rescale_length(
    target_vector: torch.Tensor,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor

Rescales the target vector. It calls the get_length method to compute the new length and returns that length multiplied by the unit vector of target_vector.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| target_vector | Tensor | The final update gradient vector. | required |
| gradients | Optional[Tensor] | The loss-specific gradient matrix, of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient. | None |
| losses | Optional[Sequence] | The losses. | None |

Returns:

| Type | Description |
|---|---|
| Tensor | The rescaled target vector. |

Source code in conflictfree/length_model.py
def rescale_length(
    self,
    target_vector: torch.Tensor,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor:
    """
    Rescales the length of the target vector based on the given parameters.
    It calls the get_length method to calculate the length and then rescales the target vector.

    Args:
        target_vector (torch.Tensor): The final update gradient vector.
        gradients (Optional[torch.Tensor]): The loss-specific gradient matrix, of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient.
        losses (Optional[Sequence]): The losses.

    Returns:
        torch.Tensor: The rescaled target vector.
    """
    unit_target_vector = unit_vector(target_vector)
    return (
        self.get_length(
            target_vector=target_vector,
            unit_target_vector=unit_target_vector,
            gradients=gradients,
            losses=losses,
        )
        * unit_target_vector
    )
__init__
__init__()
Source code in conflictfree/length_model.py
def __init__(self):
    super().__init__()

conflictfree.length_model.TrackArithmeticAverage

Bases: _FlexibleTrackProjectionLength

Rescale the length of the target vector based on the projections of the gradients onto the target vector. Before projection, every gradient is rescaled to the arithmetic average of the lengths of all gradients, i.e., each gradient contributes as if its norm were that arithmetic average.

\[ |\mathbf{g}_c|=\sum_{i=1}^m\overline{|\mathbf{g}|}_{arith}\mathcal{S}_c(\mathbf{g}_i,\mathbf{g}_c) \]

where

\[ \overline{|\mathbf{g}|}_{arith}=\frac{1}{m}\sum_{i=1}^m |\mathbf{g}_i| \]
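For positive gradient norms the classical mean inequalities give min ≤ harmonic ≤ geometric ≤ arithmetic ≤ max, with equality only when all norms are equal. Every Track* model scales the same sum of cosines, so, assuming that sum is positive, the resulting lengths are ordered the same way. The sketch evaluates the five tracked values for hypothetical norms 0.5, 2 and 8, using the same expressions as the _tracked_value methods on this page.

```python
import torch

norms = torch.tensor([0.5, 2.0, 8.0])  # hypothetical gradient norms
m = norms.shape[0]

tracked = {
    "min": norms.min(),
    "harmonic": m / torch.sum(1 / norms),       # 3 / 2.625 ≈ 1.143
    "geometric": torch.prod(norms) ** (1 / m),  # 8 ** (1/3) = 2
    "arithmetic": norms.mean(),                 # 10.5 / 3 = 3.5
    "max": norms.max(),
}
values = [tracked[k].item() for k in ("min", "harmonic", "geometric", "arithmetic", "max")]
```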
Source code in conflictfree/length_model.py
class TrackArithmeticAverage(_FlexibleTrackProjectionLength):
    """
    Rescale the length of the target vector based on the projections of the gradients onto the target vector.
    Before projection, every gradient is rescaled to the arithmetic average of the lengths of all gradients, i.e., each gradient contributes as if its norm were that arithmetic average.

    $$
    |\mathbf{g}_c|=\sum_{i=1}^m\overline{|\mathbf{g}|}_{arith}\mathcal{S}_c(\mathbf{g}_i,\mathbf{g}_c)
    $$

    where

    $$
    \overline{|\mathbf{g}|}_{arith}=\frac{1}{m}\sum_{i=1}^m |\mathbf{g}_i|
    $$
    """

    def __init__(self):
        super().__init__()

    def _tracked_value(self, grad_norms: Tensor) -> Tensor:
        return grad_norms.mean()
get_length
get_length(
    target_vector: Optional[torch.Tensor] = None,
    unit_target_vector: Optional[torch.Tensor] = None,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor

Calculates the length based on the given parameters. Not all parameters are required.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| target_vector | Optional[Tensor] | The final update gradient vector. Either target_vector or unit_target_vector must be provided. | None |
| unit_target_vector | Optional[Tensor] | The unit vector of the target vector. Either target_vector or unit_target_vector must be provided. | None |
| gradients | Optional[Tensor] | The loss-specific gradient matrix, of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient. | None |
| losses | Optional[Sequence] | The losses. Not used in this model. | None |

Returns:

| Type | Description |
|---|---|
| Tensor | The calculated length. |

Source code in conflictfree/length_model.py
def get_length(
    self,
    target_vector: Optional[torch.Tensor] = None,
    unit_target_vector: Optional[torch.Tensor] = None,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor:
    """
    Calculates the length based on the given parameters. Not all parameters are required.

    Args:
        target_vector (Optional[torch.Tensor]): The final update gradient vector.
            Either `target_vector` or `unit_target_vector` must be provided.
        unit_target_vector (Optional[torch.Tensor]): The unit vector of the target vector.
            Either `target_vector` or `unit_target_vector` must be provided.
        gradients (Optional[torch.Tensor]): The loss-specific gradient matrix, of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient.
        losses (Optional[Sequence]): The losses. Not used in this model.

    Returns:
        Union[torch.Tensor, float]: The calculated length.
    """
    assert gradients is not None, "This length model requires gradient information."
    if unit_target_vector is None:
        unit_target_vector = unit_vector(target_vector)
    norms = torch.norm(gradients, dim=1)
    tracked_value = self._tracked_value(norms)
    return sum(
        [
            torch.dot(grad_i / norm_i, unit_target_vector) * tracked_value
            for grad_i, norm_i in zip(gradients, norms)
        ]
    )
rescale_length
rescale_length(
    target_vector: torch.Tensor,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor

Rescales the target vector. It calls the get_length method to compute the new length and returns that length multiplied by the unit vector of target_vector.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| target_vector | Tensor | The final update gradient vector. | required |
| gradients | Optional[Tensor] | The loss-specific gradient matrix, of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient. | None |
| losses | Optional[Sequence] | The losses. | None |

Returns:

| Type | Description |
|---|---|
| Tensor | The rescaled target vector. |

Source code in conflictfree/length_model.py
def rescale_length(
    self,
    target_vector: torch.Tensor,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor:
    """
    Rescales the length of the target vector based on the given parameters.
    It calls the get_length method to calculate the length and then rescales the target vector.

    Args:
        target_vector (torch.Tensor): The final update gradient vector.
        gradients (Optional[torch.Tensor]): The loss-specific gradient matrix, of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient.
        losses (Optional[Sequence]): The losses.

    Returns:
        torch.Tensor: The rescaled target vector.
    """
    unit_target_vector = unit_vector(target_vector)
    return (
        self.get_length(
            target_vector=target_vector,
            unit_target_vector=unit_target_vector,
            gradients=gradients,
            losses=losses,
        )
        * unit_target_vector
    )
__init__
__init__()
Source code in conflictfree/length_model.py
def __init__(self):
    super().__init__()

conflictfree.length_model.TrackGeometricAverage

Bases: _FlexibleTrackProjectionLength

Rescale the length of the target vector based on the projections of the gradients onto the target vector. Before projection, every gradient is rescaled to the geometric average of the lengths of all gradients, i.e., each gradient contributes as if its norm were that geometric average.

\[ |\mathbf{g}_c|=\sum_{i=1}^m\overline{|\mathbf{g}|}_{geom}\mathcal{S}_c(\mathbf{g}_i,\mathbf{g}_c) \]

where

\[ \overline{|\mathbf{g}|}_{geom}=\left(\prod_{i=1}^m |\mathbf{g}_i|\right)^{\frac{1}{m}} \]

The geometric average is less sensitive than the arithmetic average to a few unusually large gradients.
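To see how the geometric average damps large gradients, the following plain-Python sketch compares it with the arithmetic average for a hypothetical set of three gradient norms. It uses `math.prod` instead of torch so it runs without dependencies; the helper names are illustrative only.

```python
import math


def arithmetic_mean(norms):
    return sum(norms) / len(norms)


def geometric_mean(norms):
    # (prod_i |g_i|)^(1/m), as in the formula above
    return math.prod(norms) ** (1 / len(norms))


# Hypothetical gradient norms: two small gradients and one very large one.
norms = [1.0, 1.0, 1000.0]
arith = arithmetic_mean(norms)
geom = geometric_mean(norms)
print(f"arithmetic={arith:.1f}, geometric={geom:.1f}")  # arithmetic=334.0, geometric=10.0
```

With the arithmetic average, the single large gradient sets the common length to 334; with the geometric average it only raises it to about 10.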

Source code in conflictfree/length_model.py
class TrackGeometricAverage(_FlexibleTrackProjectionLength):
    """
    Rescale the length of the target vector based on the projection of the gradients on the target vector.
    All the gradients will be rescaled to the geometric average of the lengths of all gradients before projection, i.e., the minimum gradient will be the same length as the target vector.

    $$
    |\mathbf{g}_c|=\sum_{i=1}^m\overline{|\mathbf{g}|}_{geom}\mathcal{S}_c(\mathbf{g}_i,\mathbf{g}_c)
    $$

    where

    $$
    \overline{|\mathbf{g}|}_{geom}=\left(\prod_{i=1}^m |\mathbf{g}_i|\right)^{\frac{1}{m}}
    $$

    The geometric average can be used to avoid the influence of the large gradients.
    """

    def __init__(self):
        super().__init__()

    def _tracked_value(self, grad_norms: Tensor) -> Tensor:
        return torch.prod(grad_norms) ** (1 / grad_norms.shape[0])
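One caveat about `_tracked_value` above: taking the product of all norms before the m-th root can underflow (or overflow) when there are many gradients, especially in float32. The log-space form exp(mean(log |g_i|)) is mathematically equivalent; in torch it could be written as `torch.exp(torch.log(grad_norms).mean())`. This is a suggestion, not what the library does. The sketch below uses plain Python floats with hypothetical helper names and assumes all norms are strictly positive, since `math.log(0)` raises an error.

```python
import math


def geometric_mean_direct(norms):
    # Same order of operations as _tracked_value: product first, then the m-th root.
    return math.prod(norms) ** (1 / len(norms))


def geometric_mean_log(norms):
    # exp(mean(log |g_i|)): mathematically identical, but the sum of logs
    # stays in a representable range. Requires every norm to be > 0.
    return math.exp(sum(math.log(n) for n in norms) / len(norms))


# Hypothetical case: 400 gradients, each with norm 1e-3.
norms = [1e-3] * 400
print(geometric_mean_direct(norms))  # 0.0, because the product underflows
print(geometric_mean_log(norms))     # close to 0.001
```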
get_length
get_length(
    target_vector: Optional[torch.Tensor] = None,
    unit_target_vector: Optional[torch.Tensor] = None,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor

Calculates the length based on the given parameters. Not all parameters are required.

Parameters:

target_vector (Optional[Tensor], default: None): The final update gradient vector. Either target_vector or unit_target_vector must be provided.
unit_target_vector (Optional[Tensor], default: None): The unit vector of the target vector. Either target_vector or unit_target_vector must be provided.
gradients (Optional[Tensor], default: None): The loss-specific gradient matrix of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient. This model requires it.
losses (Optional[Sequence], default: None): The losses. Not used in this model.

Returns:

Tensor: The calculated length.

Source code in conflictfree/length_model.py
def get_length(
    self,
    target_vector: Optional[torch.Tensor] = None,
    unit_target_vector: Optional[torch.Tensor] = None,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor:
    """
    Calculates the length based on the given parameters. Not all parameters are required.

    Args:
        target_vector (Optional[torch.Tensor]): The final update gradient vector.
            One of the `target_vector` or `unit_target_vector` parameter need to be provided.
        unit_target_vector (Optional[torch.Tensor]): The unit vector of the target vector.
            One of the `target_vector` or `unit_target_vector` parameter need to be provided.
        gradients (Optional[torch.Tensor]): The loss-specific gradients matrix. The shape of this tensor should be (m,N) where m is the number of gradients and N is the number of elements of each gradients.
        losses (Optional[Sequence]): The losses. Not used in this model.

    Returns:
        Union[torch.Tensor, float]: The calculated length.
    """
    assert gradients is not None, "The ProjectLength model requires gradients information."
    if unit_target_vector is None:
        unit_target_vector = unit_vector(target_vector)
    norms = torch.norm(gradients, dim=1)
    tracked_value = self._tracked_value(norms)
    return sum(
        [
            torch.dot(grad_i / norm_i, unit_target_vector) * tracked_value
            for grad_i, norm_i in zip(gradients, norms)
        ]
    )
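The loop in get_length sums, over all gradients, the cosine similarity between g_i and the target direction (which corresponds to \mathcal{S}_c in the formulas above), each multiplied by the tracked value. A minimal plain-Python sketch of the same computation, assuming lists instead of tensors and the geometric average as the tracked value; the helper names are hypothetical:

```python
import math


def norm(v):
    return math.sqrt(sum(x * x for x in v))


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def projection_length(gradients, target_vector, tracked_value):
    """Sum over i of cos(g_i, g_c) * tracked_value, mirroring get_length."""
    t_norm = norm(target_vector)
    unit_target = [x / t_norm for x in target_vector]
    total = 0.0
    for g in gradients:
        g_norm = norm(g)  # a zero norm would divide by zero here, as in get_length
        total += dot([x / g_norm for x in g], unit_target) * tracked_value
    return total


gradients = [[3.0, 0.0], [0.0, 4.0]]
target = [1.0, 1.0]
norms = [norm(g) for g in gradients]          # [3.0, 4.0]
geom = math.prod(norms) ** (1 / len(norms))   # sqrt(12)
length = projection_length(gradients, target, geom)
print(length)  # approximately 4.899, i.e. sqrt(24)
```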
rescale_length
rescale_length(
    target_vector: torch.Tensor,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor

Rescales the target vector: it calls get_length to compute the new length and multiplies that length by the unit vector of the target vector.

Parameters:

target_vector (Tensor, required): The final update gradient vector.
gradients (Optional[Tensor], default: None): The loss-specific gradient matrix of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient.
losses (Optional[Sequence], default: None): The losses.

Returns:

Tensor: The rescaled target vector.

Source code in conflictfree/length_model.py
def rescale_length(
    self,
    target_vector: torch.Tensor,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor:
    """
    Rescales the length of the target vector based on the given parameters.
    It calls the get_length method to calculate the length and then rescales the target vector.

    Args:
        target_vector (torch.Tensor): The final update gradient vector.
        gradients (Optional[torch.Tensor]): The loss-specific gradients matrix. The shape of this tensor should be (m,N) where m is the number of gradients and N is the number of elements of each gradients.
        losses (Optional[Sequence]): The losses.

    Returns:
        torch.Tensor: The rescaled target vector.
    """
    unit_target_vector = unit_vector(target_vector)
    return (
        self.get_length(
            target_vector=target_vector,
            unit_target_vector=unit_target_vector,
            gradients=gradients,
            losses=losses,
        )
        * unit_target_vector
    )
__init__
__init__()
Source code in conflictfree/length_model.py
def __init__(self):
    super().__init__()

conflictfree.length_model.TrackSpecific

Bases: _FlexibleTrackProjectionLength

Rescale the length of the target vector based on the projection of the gradients onto the target vector. Before projection, every gradient is rescaled to the length of the gradient selected by the zero-based track_id. For example, if track_id is 2, all gradients are rescaled to the length of the third gradient.

\[ |\mathbf{g}_c|=\sum_{i=1}^m\overline{|\mathbf{g}|}_{\text{track\_id}}\mathcal{S}_c(\mathbf{g}_i,\mathbf{g}_c) \]
Source code in conflictfree/length_model.py
class TrackSpecific(_FlexibleTrackProjectionLength):
    """
    Rescale the length of the target vector based on the projection of the gradients on the target vector.
    All the gradients will be rescaled to the same length as the specific gradient before projection.
    E.g., if the track_id is 2, then all the gradients will be rescaled to the same length as the third gradient before projection.

    $$
    |\mathbf{g}_c|=\sum_{i=1}^m\overline{|\mathbf{g}|}_{track_id}\mathcal{S}_c(\mathbf{g}_i,\mathbf{g}_c)
    $$

    """

    def __init__(self, track_id: int):
        super().__init__()
        self.track_id = track_id

    def _tracked_value(self, grad_norms: Tensor) -> Tensor:
        return grad_norms[self.track_id]
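TrackSpecific simply indexes the vector of gradient norms, so track_id is zero-based. The plain-Python sketch below (hypothetical helpers, lists instead of tensors) tracks the third gradient with track_id=2 and plugs its norm into the projection sum:

```python
import math


def track_specific_value(grad_norms, track_id):
    # Mirrors TrackSpecific._tracked_value: plain zero-based indexing.
    return grad_norms[track_id]


def cosine(a, b):
    return sum(x * y for x, y in zip(a, b)) / (math.hypot(*a) * math.hypot(*b))


gradients = [[1.0, 0.0], [0.0, 2.0], [5.0, 5.0]]
target = [1.0, 1.0]
norms = [math.hypot(*g) for g in gradients]

tracked = track_specific_value(norms, track_id=2)  # norm of the third gradient
length = sum(cosine(g, target) for g in gradients) * tracked
print(tracked, length)  # about 7.071 and 17.071
```

Because every gradient is projected with the tracked gradient's length, the result scales with that one gradient's norm, regardless of how large the other gradients are.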
track_id instance-attribute
track_id = track_id
get_length
get_length(
    target_vector: Optional[torch.Tensor] = None,
    unit_target_vector: Optional[torch.Tensor] = None,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor

Calculates the length based on the given parameters. Not all parameters are required.

Parameters:

target_vector (Optional[Tensor], default: None): The final update gradient vector. Either target_vector or unit_target_vector must be provided.
unit_target_vector (Optional[Tensor], default: None): The unit vector of the target vector. Either target_vector or unit_target_vector must be provided.
gradients (Optional[Tensor], default: None): The loss-specific gradient matrix of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient. This model requires it.
losses (Optional[Sequence], default: None): The losses. Not used in this model.

Returns:

Tensor: The calculated length.

Source code in conflictfree/length_model.py
def get_length(
    self,
    target_vector: Optional[torch.Tensor] = None,
    unit_target_vector: Optional[torch.Tensor] = None,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor:
    """
    Calculates the length based on the given parameters. Not all parameters are required.

    Args:
        target_vector (Optional[torch.Tensor]): The final update gradient vector.
            One of the `target_vector` or `unit_target_vector` parameter need to be provided.
        unit_target_vector (Optional[torch.Tensor]): The unit vector of the target vector.
            One of the `target_vector` or `unit_target_vector` parameter need to be provided.
        gradients (Optional[torch.Tensor]): The loss-specific gradients matrix. The shape of this tensor should be (m,N) where m is the number of gradients and N is the number of elements of each gradients.
        losses (Optional[Sequence]): The losses. Not used in this model.

    Returns:
        Union[torch.Tensor, float]: The calculated length.
    """
    assert gradients is not None, "The ProjectLength model requires gradients information."
    if unit_target_vector is None:
        unit_target_vector = unit_vector(target_vector)
    norms = torch.norm(gradients, dim=1)
    tracked_value = self._tracked_value(norms)
    return sum(
        [
            torch.dot(grad_i / norm_i, unit_target_vector) * tracked_value
            for grad_i, norm_i in zip(gradients, norms)
        ]
    )
rescale_length
rescale_length(
    target_vector: torch.Tensor,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor

Rescales the target vector: it calls get_length to compute the new length and multiplies that length by the unit vector of the target vector.

Parameters:

target_vector (Tensor, required): The final update gradient vector.
gradients (Optional[Tensor], default: None): The loss-specific gradient matrix of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient.
losses (Optional[Sequence], default: None): The losses.

Returns:

Tensor: The rescaled target vector.

Source code in conflictfree/length_model.py
def rescale_length(
    self,
    target_vector: torch.Tensor,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor:
    """
    Rescales the length of the target vector based on the given parameters.
    It calls the get_length method to calculate the length and then rescales the target vector.

    Args:
        target_vector (torch.Tensor): The final update gradient vector.
        gradients (Optional[torch.Tensor]): The loss-specific gradients matrix. The shape of this tensor should be (m,N) where m is the number of gradients and N is the number of elements of each gradients.
        losses (Optional[Sequence]): The losses.

    Returns:
        torch.Tensor: The rescaled target vector.
    """
    unit_target_vector = unit_vector(target_vector)
    return (
        self.get_length(
            target_vector=target_vector,
            unit_target_vector=unit_target_vector,
            gradients=gradients,
            losses=losses,
        )
        * unit_target_vector
    )
__init__
__init__(track_id: int)
Source code in conflictfree/length_model.py
def __init__(self, track_id: int):
    super().__init__()
    self.track_id = track_id

Base Class of Length Model

conflictfree.length_model.LengthModel

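The base class for length models. Subclasses must implement get_length; the inherited rescale_length calls it and scales the unit target vector by the result.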

Methods:

get_length: Calculates the length based on the given parameters.
rescale_length: Rescales the length of the target vector based on the given parameters.

Source code in conflictfree/length_model.py
class LengthModel:
    """
    The base class for length model.

    Methods:
        get_length: Calculates the length based on the given parameters.
        rescale_length: Rescales the length of the target vector based on the given parameters.
    """

    def __init__(self):
        pass

    def get_length(
        self,
        target_vector: Optional[torch.Tensor] = None,
        unit_target_vector: Optional[torch.Tensor] = None,
        gradients: Optional[torch.Tensor] = None,
        losses: Optional[Sequence] = None,
    ) -> Union[torch.Tensor, float]:
        """
        Calculates the length based on the given parameters. Not all parameters are required.

        Args:
            target_vector (Optional[torch.Tensor]): The final update gradient vector.
            unit_target_vector (Optional[torch.Tensor]): The unit vector of the target vector.
            gradients (Optional[torch.Tensor]): The loss-specific gradients matrix. The shape of this tensor should be (m,N) where m is the number of gradients and N is the number of elements of each gradients.
            losses (Optional[Sequence]): The losses.

        Returns:
            Union[torch.Tensor, float]: The calculated length.
        """
        raise NotImplementedError(
            "This method must be implemented by the subclass.")

    def rescale_length(
        self,
        target_vector: torch.Tensor,
        gradients: Optional[torch.Tensor] = None,
        losses: Optional[Sequence] = None,
    ) -> torch.Tensor:
        """
        Rescales the length of the target vector based on the given parameters.
        It calls the get_length method to calculate the length and then rescales the target vector.

        Args:
            target_vector (torch.Tensor): The final update gradient vector.
            gradients (Optional[torch.Tensor]): The loss-specific gradients matrix. The shape of this tensor should be (m,N) where m is the number of gradients and N is the number of elements of each gradients.
            losses (Optional[Sequence]): The losses.

        Returns:
            torch.Tensor: The rescaled target vector.
        """
        unit_target_vector = unit_vector(target_vector)
        return (
            self.get_length(
                target_vector=target_vector,
                unit_target_vector=unit_target_vector,
                gradients=gradients,
                losses=losses,
            )
            * unit_target_vector
        )
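As the module overview notes, a custom length model is created by inheriting from LengthModel and implementing get_length. The sketch below shows that pattern with a hypothetical MeanNormLength model that uses the arithmetic mean of the gradient norms as the length. To keep it runnable without torch or conflictfree installed, it defines a simplified stand-in for LengthModel that works on Python lists; with the real package you would instead subclass `conflictfree.length_model.LengthModel` and operate on torch tensors.

```python
import math


def unit_vector(v):
    # Hypothetical helper standing in for the library's unit_vector on lists.
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]


class LengthModel:
    """Simplified stand-in mirroring the documented interface (not the real class)."""

    def get_length(self, target_vector=None, unit_target_vector=None,
                   gradients=None, losses=None):
        raise NotImplementedError("This method must be implemented by the subclass.")

    def rescale_length(self, target_vector, gradients=None, losses=None):
        unit_target_vector = unit_vector(target_vector)
        length = self.get_length(target_vector=target_vector,
                                 unit_target_vector=unit_target_vector,
                                 gradients=gradients, losses=losses)
        return [length * x for x in unit_target_vector]


class MeanNormLength(LengthModel):
    """Hypothetical model: the new length is the mean of the gradient norms."""

    def get_length(self, target_vector=None, unit_target_vector=None,
                   gradients=None, losses=None):
        assert gradients is not None, "MeanNormLength requires gradients."
        norms = [math.sqrt(sum(x * x for x in g)) for g in gradients]
        return sum(norms) / len(norms)


model = MeanNormLength()
rescaled = model.rescale_length([2.0, 0.0], gradients=[[3.0, 4.0], [0.0, 1.0]])
print(rescaled)  # [3.0, 0.0]
```

Only get_length is overridden; rescale_length is inherited unchanged, which is the same division of work the base class above uses.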
__init__
__init__()
Source code in conflictfree/length_model.py
def __init__(self):
    pass
get_length
get_length(
    target_vector: Optional[torch.Tensor] = None,
    unit_target_vector: Optional[torch.Tensor] = None,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> Union[torch.Tensor, float]

Calculates the length based on the given parameters. Not all parameters are required.

Parameters:

target_vector (Optional[Tensor], default: None): The final update gradient vector.
unit_target_vector (Optional[Tensor], default: None): The unit vector of the target vector.
gradients (Optional[Tensor], default: None): The loss-specific gradient matrix of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient.
losses (Optional[Sequence], default: None): The losses.

Returns:

Union[Tensor, float]: The calculated length.

Source code in conflictfree/length_model.py
def get_length(
    self,
    target_vector: Optional[torch.Tensor] = None,
    unit_target_vector: Optional[torch.Tensor] = None,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> Union[torch.Tensor, float]:
    """
    Calculates the length based on the given parameters. Not all parameters are required.

    Args:
        target_vector (Optional[torch.Tensor]): The final update gradient vector.
        unit_target_vector (Optional[torch.Tensor]): The unit vector of the target vector.
        gradients (Optional[torch.Tensor]): The loss-specific gradients matrix. The shape of this tensor should be (m,N) where m is the number of gradients and N is the number of elements of each gradients.
        losses (Optional[Sequence]): The losses.

    Returns:
        Union[torch.Tensor, float]: The calculated length.
    """
    raise NotImplementedError(
        "This method must be implemented by the subclass.")
rescale_length
rescale_length(
    target_vector: torch.Tensor,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor

Rescales the target vector: it calls get_length to compute the new length and multiplies that length by the unit vector of the target vector.

Parameters:

target_vector (Tensor, required): The final update gradient vector.
gradients (Optional[Tensor], default: None): The loss-specific gradient matrix of shape (m, N), where m is the number of gradients and N is the number of elements in each gradient.
losses (Optional[Sequence], default: None): The losses.

Returns:

Tensor: The rescaled target vector.

Source code in conflictfree/length_model.py
def rescale_length(
    self,
    target_vector: torch.Tensor,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor:
    """
    Rescales the length of the target vector based on the given parameters.
    It calls the get_length method to calculate the length and then rescales the target vector.

    Args:
        target_vector (torch.Tensor): The final update gradient vector.
        gradients (Optional[torch.Tensor]): The loss-specific gradients matrix. The shape of this tensor should be (m,N) where m is the number of gradients and N is the number of elements of each gradients.
        losses (Optional[Sequence]): The losses.

    Returns:
        torch.Tensor: The rescaled target vector.
    """
    unit_target_vector = unit_vector(target_vector)
    return (
        self.get_length(
            target_vector=target_vector,
            unit_target_vector=unit_target_vector,
            gradients=gradients,
            losses=losses,
        )
        * unit_target_vector
    )
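rescale_length only changes the magnitude: it multiplies the unit target vector by whatever get_length returns. A plain-Python sketch of that final step, using a hypothetical `rescale` helper and a fixed length in place of get_length, also shows that a negative length would reverse the direction of the vector:

```python
import math


def norm(v):
    return math.sqrt(sum(x * x for x in v))


def rescale(target_vector, length):
    # Same final step as rescale_length: length * unit_vector(target_vector)
    n = norm(target_vector)
    return [length * x / n for x in target_vector]


target = [3.0, 4.0]           # norm 5
out = rescale(target, 10.0)
print(out)                    # [6.0, 8.0]
print(rescale(target, -5.0))  # [-3.0, -4.0]
```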