
4.2. Momentum Operator

The momentum_operator module contains the main operators for the momentum version of the ConFIG algorithm.

Operator Classes¤

conflictfree.momentum_operator.PseudoMomentumOperator ¤

Bases: MomentumOperator

The major momentum version. In this operator, the second momentum is estimated from a pseudo gradient constructed from the result of the gradient operator. NOTE: A momentum-based optimizer, e.g., Adam, is not recommended when using this operator. Please consider using the SGD optimizer instead.
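The update rule can be summarized as follows. This is a condensed, hedged sketch of the `calculate_gradient` source shown below, simplified by assuming every momentum is updated at each step (so the per-loss counters all equal `t`) and that the base operator does not need loss values:

```python
import torch

def pseudo_momentum_step(grads, m, s, fake_m, t, base_operator,
                         beta_1=0.9, beta_2=0.999, eps=1e-8):
    """One conceptual step of PseudoMomentumOperator (simplified sketch)."""
    # Update the first momentum of each loss and apply bias correction.
    m = [beta_1 * m_i + (1 - beta_1) * g_i for m_i, g_i in zip(m, grads)]
    m_hats = torch.stack([m_i / (1 - beta_1 ** t) for m_i in m], dim=0)
    # Conflict-free direction computed from the bias-corrected momenta.
    final_grad = base_operator.calculate_gradient(m_hats)
    # Treat final_grad as a bias-corrected momentum and invert the EMA update
    # to recover the "pseudo gradient" that would have produced it.
    new_fake_m = final_grad * (1 - beta_1 ** t)
    pseudo_grad = (new_fake_m - beta_1 * fake_m) / (1 - beta_1)
    # Estimate the second momentum from the pseudo gradient, Adam-style.
    s = beta_2 * s + (1 - beta_2) * pseudo_grad ** 2
    s_hat = s / (1 - beta_2 ** t)
    return final_grad / (torch.sqrt(s_hat) + eps), m, s, new_fake_m
```

Because the second momentum is estimated from this single pseudo gradient, only one `s` vector is kept regardless of how many loss terms there are.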

Parameters:

Name Type Description Default
num_vectors int

The number of gradient vectors.

required
beta_1 float

The moving average constant for the first momentum.

0.9
beta_2 float

The moving average constant for the second momentum.

0.999
gradient_operator GradientOperator

The base gradient operator. Defaults to ConFIGOperator().

ConFIGOperator()
loss_recorder LossRecorder

The loss recorder object. If you want to pass loss information to "update_gradient" method or "apply_gradient" method, you need to specify a loss recorder. Defaults to None.

None
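For reference, a minimal construction sketch with a non-default base operator and a loss recorder. The import paths for `ConFIGOperator` and `LatestLossRecorder` are assumptions (`conflictfree.grad_operator` and `conflictfree.loss_recorder`); adjust them to your installation:

```python
from conflictfree.momentum_operator import PseudoMomentumOperator
from conflictfree.grad_operator import ConFIGOperator      # assumed import path
from conflictfree.loss_recorder import LatestLossRecorder  # assumed import path

n_losses = 3
operator = PseudoMomentumOperator(
    num_vectors=n_losses,
    beta_1=0.9,
    beta_2=0.999,
    gradient_operator=ConFIGOperator(),          # base operator for the conflict-free direction
    loss_recorder=LatestLossRecorder(n_losses),  # only needed if losses are passed to update_gradient
)
```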

Methods:

Name Description
calculate_gradient

Calculates the gradient based on the given indexes, gradients, and losses.

update_gradient

Updates the gradient of the given network with the calculated gradient.

Examples:

import torch
from conflictfree.momentum_operator import PseudoMomentumOperator
from conflictfree.utils import get_gradient_vector,apply_gradient_vector
optimizer=torch.optim.Adam(network.parameters(),lr=1e-3) # note: the class note above recommends SGD over momentum-based optimizers such as Adam
operator=PseudoMomentumOperator(num_vectors=len(loss_fns)) # initialize operator; the only difference here is that we need to specify the number of gradient vectors
for input_i in dataset:
    grads=[]
    for loss_fn in loss_fns:
        optimizer.zero_grad()
        loss_i=loss_fn(input_i)
        loss_i.backward()
        grads.append(get_gradient_vector(network))
    indexes=list(range(len(grads))) # which momenta to update with these gradients
    g_config=operator.calculate_gradient(indexes,grads) # calculate the conflict-free direction
    apply_gradient_vector(network,g_config) # or simply use `operator.update_gradient(network,indexes,grads)` to calculate and set the conflict-free direction on the network
    optimizer.step()

Source code in conflictfree/momentum_operator.py
class PseudoMomentumOperator(MomentumOperator):
    """
    The major momentum version.
    In this operator, the second momentum is estimated by a pseudo gradient based on the result of the gradient operator.
    NOTE: A momentum-based optimizer, e.g., Adam, is not recommended when using this operator. Please consider using the SGD optimizer instead.

    Args:
        num_vectors (int): The number of gradient vectors.
        beta_1 (float): The moving average constant for the first momentum.
        beta_2 (float): The moving average constant for the second momentum.
        gradient_operator (GradientOperator, optional): The base gradient operator. Defaults to ConFIGOperator().
        loss_recorder (LossRecorder, optional): The loss recorder object.
            If you want to pass loss information to "update_gradient" method or "apply_gradient" method, you need to specify a loss recorder. Defaults to None.

    Methods:
        calculate_gradient(indexes, grads, losses=None):
            Calculates the gradient based on the given indexes, gradients, and losses.
        update_gradient(network, indexes, grads, losses=None):
            Updates the gradient of the given network with the calculated gradient.

    Examples:
    ```python
    import torch
    from conflictfree.momentum_operator import PseudoMomentumOperator
    from conflictfree.utils import get_gradient_vector,apply_gradient_vector
    optimizer=torch.optim.Adam(network.parameters(),lr=1e-3) # note: the class note above recommends SGD over momentum-based optimizers such as Adam
    operator=PseudoMomentumOperator(num_vectors=len(loss_fns)) # initialize operator; the only difference here is that we need to specify the number of gradient vectors
    for input_i in dataset:
        grads=[]
        for loss_fn in loss_fns:
            optimizer.zero_grad()
            loss_i=loss_fn(input_i)
            loss_i.backward()
            grads.append(get_gradient_vector(network))
        indexes=list(range(len(grads))) # which momenta to update with these gradients
        g_config=operator.calculate_gradient(indexes,grads) # calculate the conflict-free direction
        apply_gradient_vector(network,g_config) # or simply use `operator.update_gradient(network,indexes,grads)` to calculate and set the conflict-free direction on the network
        optimizer.step()
    ```
    """

    def __init__(
        self,
        num_vectors: int,
        beta_1: float = 0.9,
        beta_2: float = 0.999,
        gradient_operator: GradientOperator = ConFIGOperator(),
        loss_recorder: Optional[LossRecorder] = None,
    ) -> None:
        super().__init__(num_vectors, beta_1, beta_2, gradient_operator, loss_recorder)
        self.m = None
        self.s = None
        self.fake_m = None
        self.t = 0
        self.t_grads = [0] * self.num_vectors
        self.all_initialized = False

    def _preprocess_gradients_losses(
        self,
        indexes: Union[int, Sequence[int]],
        grads: Union[torch.Tensor, Sequence[torch.Tensor]],
        losses: Optional[Union[float, Sequence]] = None,
    ):
        indexes, grads, losses = super()._preprocess_gradients_losses(
            indexes, grads, losses
        )
        if self.m is None or self.s is None or self.fake_m is None:
            self.m = [
                torch.zeros(self.len_vectors, device=self.device)
                for i in range(self.num_vectors)
            ]
            self.s = torch.zeros(self.len_vectors, device=self.device)
            self.fake_m = torch.zeros(self.len_vectors, device=self.device)
        return indexes, grads, losses

    def calculate_gradient(
        self,
        indexes: Union[int, Sequence[int]],
        grads: Union[torch.Tensor, Sequence[torch.Tensor]],
        losses: Optional[Union[float, Sequence]] = None,
    ) -> torch.Tensor:
        """
        Calculates the gradient based on the given indexes, gradients, and losses.

        Args:
            indexes (Union[int,Sequence[int]]): The indexes of the gradient vectors and losses to be updated.
                The momentum with the given indexes will be updated based on the given gradients.
            grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to update.
                It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.
            losses (Optional[Sequence], optional): The losses associated with the gradients.
                The losses will be passed to base gradient operator. If the base gradient operator doesn't require loss information,
                you can set this value as None. Defaults to None.

        Raises:
            NotImplementedError: This method must be implemented in a subclass.

        Returns:
            torch.Tensor: The calculated gradient.
        """
        with torch.no_grad():
            indexes, grads, losses = self._preprocess_gradients_losses(
                indexes, grads, losses
            )
            for i in range(len(indexes)):
                self.t_grads[indexes[i]] += 1
                self.m[indexes[i]] = (
                    self.beta_1 * self.m[indexes[i]] + (1 - self.beta_1) * grads[i]
                )
            if not self.all_initialized:
                if has_zero(self.t_grads):
                    return torch.zeros_like(self.s)
                else:
                    self.all_initialized = True
            self.t += 1
            m_hats = torch.stack(
                [
                    self.m[i] / (1 - self.beta_1 ** self.t_grads[i])
                    for i in range(self.num_vectors)
                ],
                dim=0,
            )
            final_grad = self.gradient_operator.calculate_gradient(m_hats, losses)
            fake_m = final_grad * (1 - self.beta_1**self.t)
            fake_grad = (fake_m - self.beta_1 * self.fake_m) / (1 - self.beta_1)
            self.fake_m = fake_m
            self.s = self.beta_2 * self.s + (1 - self.beta_2) * fake_grad**2
            s_hat = self.s / (1 - self.beta_2**self.t)
            final_grad = final_grad / (torch.sqrt(s_hat) + 1e-8)
        return final_grad
len_vectors instance-attribute ¤
len_vectors = None
device instance-attribute ¤
device = None
beta_1 instance-attribute ¤
beta_1 = beta_1
beta_2 instance-attribute ¤
beta_2 = beta_2
num_vectors instance-attribute ¤
num_vectors = num_vectors
gradient_operator instance-attribute ¤
gradient_operator = gradient_operator
loss_recorder instance-attribute ¤
loss_recorder = loss_recorder
m instance-attribute ¤
m = None
s instance-attribute ¤
s = None
fake_m instance-attribute ¤
fake_m = None
t instance-attribute ¤
t = 0
t_grads instance-attribute ¤
t_grads = [0] * num_vectors
all_initialized instance-attribute ¤
all_initialized = False
update_gradient ¤
update_gradient(
    network: torch.nn.Module,
    indexes: Union[int, Sequence[int]],
    grads: Union[torch.Tensor, Sequence[torch.Tensor]],
    losses: Optional[Union[float, Sequence]] = None,
) -> None

Updates the gradient of the given network with the calculated gradient.

Parameters:

Name Type Description Default
network Module

The network to update the gradient.

required
indexes Union[int, Sequence[int]]

The indexes of the gradient vectors and losses to be updated. The momentum with the given indexes will be updated based on the given gradients.

required
grads Union[Tensor, Sequence[Tensor]]

The gradients to update. It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.

required
losses Optional[Sequence]

The losses associated with the gradients. The losses will be passed to base gradient operator. If the base gradient operator doesn't require loss information, you can set this value as None. Defaults to None.

None

Raises:

Type Description
NotImplementedError

This method must be implemented in a subclass.

Returns:

Type Description
None

None
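A minimal usage sketch, assuming `network`, `loss_fns`, and `dataset` are defined elsewhere. Only one loss is back-propagated per step here, so only the momentum of that index is refreshed while the stored momenta of the other losses are reused:

```python
import torch
from conflictfree.momentum_operator import PseudoMomentumOperator
from conflictfree.utils import get_gradient_vector

operator = PseudoMomentumOperator(num_vectors=len(loss_fns))
optimizer = torch.optim.SGD(network.parameters(), lr=1e-3)  # SGD, as recommended above

for step, input_i in enumerate(dataset):
    idx = step % len(loss_fns)          # rotate through the losses, one per step
    optimizer.zero_grad()
    loss_fns[idx](input_i).backward()
    # update the momentum of loss `idx`, recompute the conflict-free direction,
    # and write it into the parameters' .grad fields
    operator.update_gradient(network, idx, get_gradient_vector(network))
    optimizer.step()
```

Until every momentum has received at least one gradient, the applied direction is a zero vector (see `calculate_gradient` below).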

Source code in conflictfree/momentum_operator.py
def update_gradient(
    self,
    network: torch.nn.Module,
    indexes: Union[int, Sequence[int]],
    grads: Union[torch.Tensor, Sequence[torch.Tensor]],
    losses: Optional[Union[float, Sequence]] = None,
) -> None:
    """
    Updates the gradient of the given network with the calculated gradient.

    Args:
        network (torch.nn.Module): The network to update the gradient.
        indexes (Union[int,Sequence[int]]): The indexes of the gradient vectors and losses to be updated.
            The momentum with the given indexes will be updated based on the given gradients.
        grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to update.
            It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.
        losses (Optional[Sequence], optional): The losses associated with the gradients.
            The losses will be passed to base gradient operator. If the base gradient operator doesn't require loss information,
            you can set this value as None. Defaults to None.

    Raises:
        NotImplementedError: This method must be implemented in a subclass.

    Returns:
        None
    """
    apply_gradient_vector(network, self.calculate_gradient(indexes, grads, losses))
__init__ ¤
__init__(
    num_vectors: int,
    beta_1: float = 0.9,
    beta_2: float = 0.999,
    gradient_operator: GradientOperator = ConFIGOperator(),
    loss_recorder: Optional[LossRecorder] = None,
) -> None
Source code in conflictfree/momentum_operator.py
def __init__(
    self,
    num_vectors: int,
    beta_1: float = 0.9,
    beta_2: float = 0.999,
    gradient_operator: GradientOperator = ConFIGOperator(),
    loss_recorder: Optional[LossRecorder] = None,
) -> None:
    super().__init__(num_vectors, beta_1, beta_2, gradient_operator, loss_recorder)
    self.m = None
    self.s = None
    self.fake_m = None
    self.t = 0
    self.t_grads = [0] * self.num_vectors
    self.all_initialized = False
calculate_gradient ¤
calculate_gradient(
    indexes: Union[int, Sequence[int]],
    grads: Union[torch.Tensor, Sequence[torch.Tensor]],
    losses: Optional[Union[float, Sequence]] = None,
) -> torch.Tensor

Calculates the gradient based on the given indexes, gradients, and losses.

Parameters:

Name Type Description Default
indexes Union[int, Sequence[int]]

The indexes of the gradient vectors and losses to be updated. The momentum with the given indexes will be updated based on the given gradients.

required
grads Union[Tensor, Sequence[Tensor]]

The gradients to update. It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.

required
losses Optional[Sequence]

The losses associated with the gradients. The losses will be passed to base gradient operator. If the base gradient operator doesn't require loss information, you can set this value as None. Defaults to None.

None

Raises:

Type Description
NotImplementedError

This method must be implemented in a subclass.

Returns:

Type Description
Tensor

torch.Tensor: The calculated gradient.
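For example, with two previously collected gradient vectors `g_0` and `g_1` (hypothetical tensors), both momenta can be updated in one call, or a single momentum can be refreshed while the other keeps its stored value:

```python
g_cf = operator.calculate_gradient([0, 1], [g_0, g_1])  # update both momenta
g_cf = operator.calculate_gradient(0, g_0)              # update only the momentum of loss 0
```

Note that the operator returns a zero vector until every momentum has received at least one gradient (see the source below).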

Source code in conflictfree/momentum_operator.py
def calculate_gradient(
    self,
    indexes: Union[int, Sequence[int]],
    grads: Union[torch.Tensor, Sequence[torch.Tensor]],
    losses: Optional[Union[float, Sequence]] = None,
) -> torch.Tensor:
    """
    Calculates the gradient based on the given indexes, gradients, and losses.

    Args:
        indexes (Union[int,Sequence[int]]): The indexes of the gradient vectors and losses to be updated.
            The momentum with the given indexes will be updated based on the given gradients.
        grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to update.
            It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.
        losses (Optional[Sequence], optional): The losses associated with the gradients.
            The losses will be passed to base gradient operator. If the base gradient operator doesn't require loss information,
            you can set this value as None. Defaults to None.

    Raises:
        NotImplementedError: This method must be implemented in a subclass.

    Returns:
        torch.Tensor: The calculated gradient.
    """
    with torch.no_grad():
        indexes, grads, losses = self._preprocess_gradients_losses(
            indexes, grads, losses
        )
        for i in range(len(indexes)):
            self.t_grads[indexes[i]] += 1
            self.m[indexes[i]] = (
                self.beta_1 * self.m[indexes[i]] + (1 - self.beta_1) * grads[i]
            )
        if not self.all_initialized:
            if has_zero(self.t_grads):
                return torch.zeros_like(self.s)
            else:
                self.all_initialized = True
        self.t += 1
        m_hats = torch.stack(
            [
                self.m[i] / (1 - self.beta_1 ** self.t_grads[i])
                for i in range(self.num_vectors)
            ],
            dim=0,
        )
        final_grad = self.gradient_operator.calculate_gradient(m_hats, losses)
        fake_m = final_grad * (1 - self.beta_1**self.t)
        fake_grad = (fake_m - self.beta_1 * self.fake_m) / (1 - self.beta_1)
        self.fake_m = fake_m
        self.s = self.beta_2 * self.s + (1 - self.beta_2) * fake_grad**2
        s_hat = self.s / (1 - self.beta_2**self.t)
        final_grad = final_grad / (torch.sqrt(s_hat) + 1e-8)
    return final_grad

conflictfree.momentum_operator.SeparateMomentumOperator ¤

Bases: MomentumOperator

In this operator, each gradient vector has its own second momentum, and the gradient operator is applied to the rescaled momenta. NOTE: Please consider using the PseudoMomentumOperator instead, since this operator does not give good performance according to our research. A momentum-based optimizer, e.g., Adam, is not recommended when using this operator. Please consider using the SGD optimizer instead.

Parameters:

Name Type Description Default
num_vectors int

The number of gradient vectors.

required
beta_1 float

The moving average constant for the first momentum.

0.9
beta_2 float

The moving average constant for the second momentum.

0.999
gradient_operator GradientOperator

The base gradient operator. Defaults to ConFIGOperator().

ConFIGOperator()
loss_recorder LossRecorder

The loss recorder object. If you want to pass loss information to "update_gradient" method or "apply_gradient" method, you need to specify a loss recorder. Defaults to None.

None

Methods:

Name Description
calculate_gradient

Calculates the gradient based on the given indexes, gradients, and losses.

update_gradient

Updates the gradient of the given network with the calculated gradient.
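The source does not include a usage example for this operator; the training loop is the same as for `PseudoMomentumOperator`. A minimal sketch, assuming `network`, `loss_fns`, and `dataset` are defined elsewhere and using SGD as recommended above:

```python
import torch
from conflictfree.momentum_operator import SeparateMomentumOperator
from conflictfree.utils import get_gradient_vector, apply_gradient_vector

operator = SeparateMomentumOperator(num_vectors=len(loss_fns))
optimizer = torch.optim.SGD(network.parameters(), lr=1e-3)

for input_i in dataset:
    grads = []
    for loss_fn in loss_fns:
        optimizer.zero_grad()
        loss_i = loss_fn(input_i)
        loss_i.backward()
        grads.append(get_gradient_vector(network))
    indexes = list(range(len(grads)))
    g_config = operator.calculate_gradient(indexes, grads)  # conflict-free direction from the rescaled momenta
    apply_gradient_vector(network, g_config)
    optimizer.step()
```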

Source code in conflictfree/momentum_operator.py
class SeparateMomentumOperator(MomentumOperator):
    """
    In this operator, each gradient vector has its own second momentum, and the gradient operator is applied to the rescaled momenta.
    NOTE: Please consider using the PseudoMomentumOperator instead, since this operator does not give good performance according to our research.
    A momentum-based optimizer, e.g., Adam, is not recommended when using this operator. Please consider using the SGD optimizer instead.

    Args:
        num_vectors (int): The number of gradient vectors.
        beta_1 (float): The moving average constant for the first momentum.
        beta_2 (float): The moving average constant for the second momentum.
        gradient_operator (GradientOperator, optional): The base gradient operator. Defaults to ConFIGOperator().
        loss_recorder (LossRecorder, optional): The loss recorder object.
            If you want to pass loss information to "update_gradient" method or "apply_gradient" method, you need to specify a loss recorder. Defaults to None.

    Methods:
        calculate_gradient(indexes, grads, losses=None):
            Calculates the gradient based on the given indexes, gradients, and losses.
        update_gradient(network, indexes, grads, losses=None):
            Updates the gradient of the given network with the calculated gradient.

    """

    def __init__(
        self,
        num_vectors: int,
        beta_1: float = 0.9,
        beta_2: float = 0.999,
        gradient_operator: GradientOperator = ConFIGOperator(),
        loss_recorder: Optional[LossRecorder] = None,
    ) -> None:
        super().__init__(num_vectors, beta_1, beta_2, gradient_operator, loss_recorder)
        self.m = None
        self.s = None
        self.t_grads = [0] * self.num_vectors
        self.all_initialized = False

    def _preprocess_gradients_losses(
        self,
        indexes: Union[int, Sequence[int]],
        grads: Union[torch.Tensor, Sequence[torch.Tensor]],
        losses: Optional[Union[float, Sequence]] = None,
    ):
        indexes, grads, losses = super()._preprocess_gradients_losses(
            indexes, grads, losses
        )
        if self.m is None or self.s is None:
            self.m = [
                torch.zeros(self.len_vectors, device=self.device)
                for i in range(self.num_vectors)
            ]
            self.s = [
                torch.zeros(self.len_vectors, device=self.device)
                for i in range(self.num_vectors)
            ]
        return indexes, grads, losses

    def calculate_gradient(
        self,
        indexes: Union[int, Sequence[int]],
        grads: Union[torch.Tensor, Sequence[torch.Tensor]],
        losses: Optional[Union[float, Sequence]] = None,
    ) -> torch.Tensor:
        """
        Calculates the gradient based on the given indexes, gradients, and losses.

        Args:
            indexes (Union[int,Sequence[int]]): The indexes of the gradient vectors and losses to be updated.
                The momentum with the given indexes will be updated based on the given gradients.
            grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to update.
                It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.
            losses (Optional[Sequence], optional): The losses associated with the gradients.
                The losses will be passed to base gradient operator. If the base gradient operator doesn't require loss information,
                you can set this value as None. Defaults to None.

        Raises:
            NotImplementedError: This method must be implemented in a subclass.

        Returns:
            torch.Tensor: The calculated gradient.
        """
        with torch.no_grad():
            indexes, grads, losses = self._preprocess_gradients_losses(
                indexes, grads, losses
            )
            for i in range(len(indexes)):
                self.t_grads[indexes[i]] += 1
                self.m[indexes[i]] = (
                    self.beta_1 * self.m[indexes[i]] + (1 - self.beta_1) * grads[i]
                )
                self.s[indexes[i]] = (
                    self.beta_2 * self.s[indexes[i]] + (1 - self.beta_2) * grads[i] ** 2
                )
            if not self.all_initialized:
                if has_zero(self.t_grads):
                    return torch.zeros_like(self.s[0])
                else:
                    self.all_initialized = True
            m_hats = torch.stack(
                [
                    self.m[i] / (1 - self.beta_1 ** self.t_grads[i])
                    for i in range(self.num_vectors)
                ],
                dim=0,
            )
            s_hats = torch.stack(
                [
                    self.s[i] / (1 - self.beta_2 ** self.t_grads[i])
                    for i in range(self.num_vectors)
                ],
                dim=0,
            )
        return self.gradient_operator.calculate_gradient(
            m_hats / (torch.sqrt(s_hats) + 1e-8),
            losses,
        )
len_vectors instance-attribute ¤
len_vectors = None
device instance-attribute ¤
device = None
beta_1 instance-attribute ¤
beta_1 = beta_1
beta_2 instance-attribute ¤
beta_2 = beta_2
num_vectors instance-attribute ¤
num_vectors = num_vectors
gradient_operator instance-attribute ¤
gradient_operator = gradient_operator
loss_recorder instance-attribute ¤
loss_recorder = loss_recorder
m instance-attribute ¤
m = None
s instance-attribute ¤
s = None
t_grads instance-attribute ¤
t_grads = [0] * num_vectors
all_initialized instance-attribute ¤
all_initialized = False
update_gradient ¤
update_gradient(
    network: torch.nn.Module,
    indexes: Union[int, Sequence[int]],
    grads: Union[torch.Tensor, Sequence[torch.Tensor]],
    losses: Optional[Union[float, Sequence]] = None,
) -> None

Updates the gradient of the given network with the calculated gradient.

Parameters:

Name Type Description Default
network Module

The network to update the gradient.

required
indexes Union[int, Sequence[int]]

The indexes of the gradient vectors and losses to be updated. The momentum with the given indexes will be updated based on the given gradients.

required
grads Union[Tensor, Sequence[Tensor]]

The gradients to update. It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.

required
losses Optional[Sequence]

The losses associated with the gradients. The losses will be passed to base gradient operator. If the base gradient operator doesn't require loss information, you can set this value as None. Defaults to None.

None

Raises:

Type Description
NotImplementedError

This method must be implemented in a subclass.

Returns:

Type Description
None

None

Source code in conflictfree/momentum_operator.py
def update_gradient(
    self,
    network: torch.nn.Module,
    indexes: Union[int, Sequence[int]],
    grads: Union[torch.Tensor, Sequence[torch.Tensor]],
    losses: Optional[Union[float, Sequence]] = None,
) -> None:
    """
    Updates the gradient of the given network with the calculated gradient.

    Args:
        network (torch.nn.Module): The network to update the gradient.
        indexes (Union[int,Sequence[int]]): The indexes of the gradient vectors and losses to be updated.
            The momentum with the given indexes will be updated based on the given gradients.
        grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to update.
            It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.
        losses (Optional[Sequence], optional): The losses associated with the gradients.
            The losses will be passed to base gradient operator. If the base gradient operator doesn't require loss information,
            you can set this value as None. Defaults to None.

    Raises:
        NotImplementedError: This method must be implemented in a subclass.

    Returns:
        None
    """
    apply_gradient_vector(network, self.calculate_gradient(indexes, grads, losses))
__init__ ¤
__init__(
    num_vectors: int,
    beta_1: float = 0.9,
    beta_2: float = 0.999,
    gradient_operator: GradientOperator = ConFIGOperator(),
    loss_recorder: Optional[LossRecorder] = None,
) -> None
Source code in conflictfree/momentum_operator.py
def __init__(
    self,
    num_vectors: int,
    beta_1: float = 0.9,
    beta_2: float = 0.999,
    gradient_operator: GradientOperator = ConFIGOperator(),
    loss_recorder: Optional[LossRecorder] = None,
) -> None:
    super().__init__(num_vectors, beta_1, beta_2, gradient_operator, loss_recorder)
    self.m = None
    self.s = None
    self.t_grads = [0] * self.num_vectors
    self.all_initialized = False
calculate_gradient ¤
calculate_gradient(
    indexes: Union[int, Sequence[int]],
    grads: Union[torch.Tensor, Sequence[torch.Tensor]],
    losses: Optional[Union[float, Sequence]] = None,
) -> torch.Tensor

Calculates the gradient based on the given indexes, gradients, and losses.

Parameters:

Name Type Description Default
indexes Union[int, Sequence[int]]

The indexes of the gradient vectors and losses to be updated. The momentum with the given indexes will be updated based on the given gradients.

required
grads Union[Tensor, Sequence[Tensor]]

The gradients to update. It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.

required
losses Optional[Sequence]

The losses associated with the gradients. The losses will be passed to base gradient operator. If the base gradient operator doesn't require loss information, you can set this value as None. Defaults to None.

None

Raises:

Type Description
NotImplementedError

This method must be implemented in a subclass.

Returns:

Type Description
Tensor

torch.Tensor: The calculated gradient.

Source code in conflictfree/momentum_operator.py
def calculate_gradient(
    self,
    indexes: Union[int, Sequence[int]],
    grads: Union[torch.Tensor, Sequence[torch.Tensor]],
    losses: Optional[Union[float, Sequence]] = None,
) -> torch.Tensor:
    """
    Calculates the gradient based on the given indexes, gradients, and losses.

    Args:
        indexes (Union[int,Sequence[int]]): The indexes of the gradient vectors and losses to be updated.
            The momentum with the given indexes will be updated based on the given gradients.
        grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to update.
            It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.
        losses (Optional[Sequence], optional): The losses associated with the gradients.
            The losses will be passed to base gradient operator. If the base gradient operator doesn't require loss information,
            you can set this value as None. Defaults to None.

    Raises:
        NotImplementedError: This method must be implemented in a subclass.

    Returns:
        torch.Tensor: The calculated gradient.
    """
    with torch.no_grad():
        indexes, grads, losses = self._preprocess_gradients_losses(
            indexes, grads, losses
        )
        for i in range(len(indexes)):
            self.t_grads[indexes[i]] += 1
            self.m[indexes[i]] = (
                self.beta_1 * self.m[indexes[i]] + (1 - self.beta_1) * grads[i]
            )
            self.s[indexes[i]] = (
                self.beta_2 * self.s[indexes[i]] + (1 - self.beta_2) * grads[i] ** 2
            )
        if not self.all_initialized:
            if has_zero(self.t_grads):
                return torch.zeros_like(self.s[0])
            else:
                self.all_initialized = True
        m_hats = torch.stack(
            [
                self.m[i] / (1 - self.beta_1 ** self.t_grads[i])
                for i in range(self.num_vectors)
            ],
            dim=0,
        )
        s_hats = torch.stack(
            [
                self.s[i] / (1 - self.beta_2 ** self.t_grads[i])
                for i in range(self.num_vectors)
            ],
            dim=0,
        )
    return self.gradient_operator.calculate_gradient(
        m_hats / (torch.sqrt(s_hats) + 1e-8),
        losses,
    )

Base Class of Operators¤

conflictfree.momentum_operator.LatestLossRecorder ¤

Bases: LossRecorder

A loss recorder that returns the latest losses.

Parameters:

Name Type Description Default
num_losses int

The number of losses to record

required
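A minimal sketch of the recorder on its own; the values are hypothetical and the import follows the heading above (the class is defined in conflictfree/loss_recorder.py):

```python
from conflictfree.momentum_operator import LatestLossRecorder

recorder = LatestLossRecorder(num_losses=3)
recorder.record_all_losses([0.5, 1.2, 0.8])  # record every loss at once
recorder.record_loss(1, 0.9)                 # overwrite only the loss at index 1
print(recorder.current_losses)               # [0.5, 0.9, 0.8]
```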
Source code in conflictfree/loss_recorder.py
class LatestLossRecorder(LossRecorder):
    """
    A loss recorder that returns the latest losses.

    Args:
        num_losses (int): The number of losses to record
    """

    def __init__(self, num_losses: int) -> None:
        super().__init__(num_losses)

    def record_loss(
        self, losses_indexes: Union[int, Sequence[int]], losses: Union[float, Sequence]
    ) -> list:
        """
        Records the given loss and returns the recorded loss.

        Args:
            losses_indexes: The index of the loss.
            losses (Union[float, Sequence]): The loss to record.

        Returns:
            list: The recorded loss.

        """
        losses_indexes, losses = self._preprocess_losses(losses_indexes, losses)
        for i in losses_indexes:
            self.current_losses[i] = losses[losses_indexes.index(i)]
        return self.current_losses
num_losses instance-attribute ¤
num_losses = num_losses
current_losses instance-attribute ¤
current_losses = [0.0 for i in range(num_losses)]
record_all_losses ¤
record_all_losses(losses: Sequence) -> list

Records all the losses and returns the recorded losses.

Parameters:

Name Type Description Default
losses Sequence

The losses to record.

required

Returns:

Name Type Description
list list

The recorded losses.

Source code in conflictfree/loss_recorder.py
def record_all_losses(self, losses: Sequence) -> list:
    """
    Records all the losses and returns the recorded losses.

    Args:
        losses (Sequence): The losses to record.

    Returns:
        list: The recorded losses.

    """
    assert len(losses) == self.num_losses, "The number of losses does not match the number of losses to be recorded."
    return self.record_loss([i for i in range(self.num_losses)], losses)
__init__ ¤
__init__(num_losses: int) -> None
Source code in conflictfree/loss_recorder.py
def __init__(self, num_losses: int) -> None:
    super().__init__(num_losses)
record_loss ¤
record_loss(
    losses_indexes: Union[int, Sequence[int]],
    losses: Union[float, Sequence],
) -> list

Records the given loss and returns the recorded loss.

Parameters:

Name Type Description Default
losses_indexes Union[int, Sequence[int]]

The index of the loss.

required
losses Union[float, Sequence]

The loss to record.

required

Returns:

Name Type Description
list list

The recorded loss.
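Continuing the sketch above, updating a subset of indexes leaves the other recorded values untouched:

```python
recorder.record_loss([0, 2], [0.4, 0.7])  # only indexes 0 and 2 change
print(recorder.current_losses)            # [0.4, 0.9, 0.7]
```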

Source code in conflictfree/loss_recorder.py
def record_loss(
    self, losses_indexes: Union[int, Sequence[int]], losses: Union[float, Sequence]
) -> list:
    """
    Records the given loss and returns the recorded loss.

    Args:
        losses_indexes: The index of the loss.
        losses (Union[float, Sequence]): The loss to record.

    Returns:
        list: The recorded loss.

    """
    losses_indexes, losses = self._preprocess_losses(losses_indexes, losses)
    for i in losses_indexes:
        self.current_losses[i] = losses[losses_indexes.index(i)]
    return self.current_losses