
4.1. Gradient Operator

The grad_operator module contains the main operators of the ConFIG algorithm. You can use these operators to perform the ConFIG update step in your optimization problem.

Operation Functions¤

conflictfree.grad_operator.ConFIG_update ¤

ConFIG_update(
    grads: Union[torch.Tensor, Sequence[torch.Tensor]],
    weight_model: WeightModel = EqualWeight(),
    length_model: LengthModel = ProjectionLength(),
    use_least_square: bool = True,
    losses: Optional[Sequence] = None,
) -> torch.Tensor

Performs the standard ConFIG update step.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `grads` | `Union[Tensor, Sequence[Tensor]]` | The gradients to update. It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors. | *required* |
| `weight_model` | `WeightModel` | The weight model for calculating the direction weights. Defaults to `EqualWeight()`, which makes the final update gradient not biased towards any gradient. | `EqualWeight()` |
| `length_model` | `LengthModel` | The length model for rescaling the length of the final gradient. Defaults to `ProjectionLength()`, which projects each gradient vector onto the final gradient vector to get the final length. | `ProjectionLength()` |
| `use_least_square` | `bool` | Whether to use the least-squares method for calculating the best direction. If set to `False`, the pseudo-inverse of the gradient matrix is calculated directly. See `torch.linalg.pinv` and `torch.linalg.lstsq` for more details. Setting this to `True` is recommended. | `True` |
| `losses` | `Optional[Sequence]` | The losses associated with the gradients. The losses will be passed to the weight and length model. If your weight/length model doesn't require loss information, you can set this value to `None`. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `torch.Tensor` | The final update gradient. |

Examples:

```python
import torch
from conflictfree.grad_operator import ConFIG_update
from conflictfree.utils import get_gradient_vector, apply_gradient_vector

optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)
for input_i in dataset:
    grads = []  # we record gradients rather than losses
    for loss_fn in loss_fns:
        optimizer.zero_grad()
        loss_i = loss_fn(input_i)
        loss_i.backward()
        grads.append(get_gradient_vector(network))  # get the loss-specific gradient
    g_config = ConFIG_update(grads)  # calculate the conflict-free direction
    apply_gradient_vector(network, g_config)  # set the conflict-free direction to the network
    optimizer.step()
```
Source code in conflictfree/grad_operator.py
def ConFIG_update(
    grads: Union[torch.Tensor, Sequence[torch.Tensor]],
    weight_model: WeightModel = EqualWeight(),
    length_model: LengthModel = ProjectionLength(),
    use_least_square: bool = True,
    losses: Optional[Sequence] = None,
) -> torch.Tensor:
    """
    Performs the standard ConFIG update step.

    Args:
        grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to update.
            It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.
        weight_model (WeightModel, optional): The weight model for calculating the direction weights.
            Defaults to EqualWeight(), which will make the final update gradient not biased towards any gradient.
        length_model (LengthModel, optional): The length model for rescaling the length of the final gradient.
            Defaults to ProjectionLength(), which will project each gradient vector onto the final gradient vector to get the final length.
        use_least_square (bool, optional): Whether to use the least square method for calculating the best direction.
            If set to False, we will directly calculate the pseudo-inverse of the gradient matrix. See `torch.linalg.pinv` and `torch.linalg.lstsq` for more details.
            Recommended to set to True. Defaults to True.
        losses (Optional[Sequence], optional): The losses associated with the gradients.
            The losses will be passed to the weight and length model. If your weight/length model doesn't require loss information,
            you can set this value as None. Defaults to None.

    Returns:
        torch.Tensor: The final update gradient.

    Examples:
        ```python
        from conflictfree.grad_operator import ConFIG_update
        from conflictfree.utils import get_gradient_vector,apply_gradient_vector
        optimizer=torch.optim.Adam(network.parameters(),lr=1e-3)
        for input_i in dataset:
            grads=[] # we record gradients rather than losses
            for loss_fn in loss_fns:
                optimizer.zero_grad()
                loss_i=loss_fn(input_i)
                loss_i.backward()
                grads.append(get_gradient_vector(network)) # get loss-specific gradient
            g_config=ConFIG_update(grads) # calculate the conflict-free direction
            apply_gradient_vector(network,g_config) # set the conflict-free direction to the network
            optimizer.step()
        ```
    """
    if not isinstance(grads, torch.Tensor):
        grads = torch.stack(grads)
    with torch.no_grad():
        weights = weight_model.get_weights(
            gradients=grads, losses=losses, device=grads.device
        )
        units = torch.nan_to_num((grads / (grads.norm(dim=1)).unsqueeze(1)), 0)
        if use_least_square:
            best_direction = torch.linalg.lstsq(units, weights).solution
        else:
            best_direction = torch.linalg.pinv(units) @ weights
        return length_model.rescale_length(
            target_vector=best_direction,
            gradients=grads,
            losses=losses,
        )
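
The `use_least_square` flag only changes how the best direction is solved for: a `torch.linalg.lstsq` solve on the unit-gradient matrix versus an explicit pseudo-inverse. On well-conditioned inputs the two paths should agree up to numerical tolerance; a minimal sanity-check sketch with random stand-in gradients (shapes chosen arbitrarily, not from the library):

```python
import torch
from conflictfree.grad_operator import ConFIG_update

torch.manual_seed(0)
grads = torch.randn(3, 10)  # three stand-in loss-specific gradient vectors

g_lstsq = ConFIG_update(grads, use_least_square=True)   # least-squares solve (recommended)
g_pinv = ConFIG_update(grads, use_least_square=False)   # explicit pseudo-inverse

# Both paths solve the same linear system for the best direction,
# so the results should match up to floating-point tolerance.
print(torch.allclose(g_lstsq, g_pinv, atol=1e-5))
```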

conflictfree.grad_operator.ConFIG_update_double ¤

ConFIG_update_double(
    grad_1: torch.Tensor,
    grad_2: torch.Tensor,
    weight_model: WeightModel = EqualWeight(),
    length_model: LengthModel = ProjectionLength(),
    losses: Optional[Sequence] = None,
) -> torch.Tensor

ConFIG update for two gradients where no inverse calculation is needed.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `grad_1` | `Tensor` | The first gradient. | *required* |
| `grad_2` | `Tensor` | The second gradient. | *required* |
| `weight_model` | `WeightModel` | The weight model for calculating the direction weights. Defaults to `EqualWeight()`, which makes the final update gradient not biased towards any gradient. | `EqualWeight()` |
| `length_model` | `LengthModel` | The length model for rescaling the length of the final gradient. Defaults to `ProjectionLength()`, which projects each gradient vector onto the final gradient vector to get the final length. | `ProjectionLength()` |
| `losses` | `Optional[Sequence]` | The losses associated with the gradients. The losses will be passed to the weight and length model. If your weight/length model doesn't require loss information, you can set this value to `None`. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `torch.Tensor` | The final update gradient. |

Examples:

```python
import torch
from conflictfree.grad_operator import ConFIG_update_double
from conflictfree.utils import get_gradient_vector, apply_gradient_vector

optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)
for input_i in dataset:
    grads = []  # we record gradients rather than losses
    for loss_fn in [loss_fn1, loss_fn2]:
        optimizer.zero_grad()
        loss_i = loss_fn(input_i)
        loss_i.backward()
        grads.append(get_gradient_vector(network))  # get the loss-specific gradient
    g_config = ConFIG_update_double(grads[0], grads[1])  # calculate the conflict-free direction
    apply_gradient_vector(network, g_config)  # set the conflict-free direction to the network
    optimizer.step()
```
Source code in conflictfree/grad_operator.py
def ConFIG_update_double(
    grad_1: torch.Tensor,
    grad_2: torch.Tensor,
    weight_model: WeightModel = EqualWeight(),
    length_model: LengthModel = ProjectionLength(),
    losses: Optional[Sequence] = None,
) -> torch.Tensor:
    """
    ConFIG update for two gradients where no inverse calculation is needed.

    Args:
        grad_1 (torch.Tensor): The first gradient.
        grad_2 (torch.Tensor): The second gradient.
        weight_model (WeightModel, optional): The weight model for calculating the direction weights.
            Defaults to EqualWeight(), which will make the final update gradient not biased towards any gradient.
        length_model (LengthModel, optional): The length model for rescaling the length of the final gradient.
            Defaults to ProjectionLength(), which will project each gradient vector onto the final gradient vector to get the final length.
        losses (Optional[Sequence], optional): The losses associated with the gradients.
            The losses will be passed to the weight and length model. If your weight/length model doesn't require loss information,
            you can set this value as None. Defaults to None.

    Returns:
        torch.Tensor: The final update gradient.

    Examples:
        ```python
        from conflictfree.grad_operator import ConFIG_update_double
        from conflictfree.utils import get_gradient_vector,apply_gradient_vector
        optimizer=torch.optim.Adam(network.parameters(),lr=1e-3)
        for input_i in dataset:
            grads=[] # we record gradients rather than losses
            for loss_fn in [loss_fn1, loss_fn2]:
                optimizer.zero_grad()
                loss_i=loss_fn(input_i)
                loss_i.backward()
                grads.append(get_gradient_vector(network)) # get loss-specific gradient
            g_config=ConFIG_update_double(grads[0],grads[1]) # calculate the conflict-free direction
            apply_gradient_vector(network,g_config) # set the conflict-free direction to the network
            optimizer.step()
        ```

    """
    with torch.no_grad():
        norm_1 = grad_1.norm()
        norm_2 = grad_2.norm()
        unit_1 = grad_1 / norm_1
        unit_2 = grad_2 / norm_2
        cos_angle = get_cos_similarity(grad_1, grad_2)
        or_2 = grad_1 - norm_1 * cos_angle * unit_2
        or_1 = grad_2 - norm_2 * cos_angle * unit_1
        unit_or1 = unit_vector(or_1)
        unit_or2 = unit_vector(or_2)
        coef_1, coef_2 = transfer_coef_double(
            weight_model.get_weights(
                gradients=torch.stack([grad_1, grad_2]),
                losses=losses,
                device=grad_1.device,
            ),
            unit_1,
            unit_2,
            unit_or1,
            unit_or2,
        )
        best_direction = coef_1 * unit_or1 + coef_2 * unit_or2
        return length_model.rescale_length(
            target_vector=best_direction,
            gradients=torch.stack([grad_1, grad_2]),
            losses=losses,
        )
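
By construction, the returned update does not conflict with either input gradient: its inner product with each loss-specific gradient should be positive. A small self-contained check with random stand-in gradients:

```python
import torch
from conflictfree.grad_operator import ConFIG_update_double

torch.manual_seed(0)
g1, g2 = torch.randn(100), torch.randn(100)  # stand-ins for two loss-specific gradients

g = ConFIG_update_double(g1, g2)

# The update direction should not conflict with either gradient.
print(torch.dot(g, g1) > 0, torch.dot(g, g2) > 0)
```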

Operator Classes¤

conflictfree.grad_operator.ConFIGOperator ¤

Bases: GradientOperator

Operator for the ConFIG algorithm.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `weight_model` | `WeightModel` | The weight model for calculating the direction weights. Defaults to `EqualWeight()`, which makes the final update gradient not biased towards any gradient. | `EqualWeight()` |
| `length_model` | `LengthModel` | The length model for rescaling the length of the final gradient. Defaults to `ProjectionLength()`, which projects each gradient vector onto the final gradient vector to get the final length. | `ProjectionLength()` |
| `allow_simplified_model` | `bool` | Whether to allow the simplified model for calculating the gradient. If set to `True`, the simplified form of the ConFIG method (`ConFIG_update_double`) is used when there are only two losses. Defaults to `True`. | `True` |
| `use_least_square` | `bool` | Whether to use the least-squares method for calculating the best direction. If set to `False`, the pseudo-inverse of the gradient matrix is calculated directly. See `torch.linalg.pinv` and `torch.linalg.lstsq` for more details. Setting this to `True` is recommended. | `True` |

Examples:

```python
import torch
from conflictfree.grad_operator import ConFIGOperator
from conflictfree.utils import get_gradient_vector, apply_gradient_vector

optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)
operator = ConFIGOperator()  # initialize the operator
for input_i in dataset:
    grads = []
    for loss_fn in loss_fns:
        optimizer.zero_grad()
        loss_i = loss_fn(input_i)
        loss_i.backward()
        grads.append(get_gradient_vector(network))
    g_config = operator.calculate_gradient(grads)  # calculate the conflict-free direction
    apply_gradient_vector(network, g_config)  # or simply use `operator.update_gradient(network, grads)` to calculate and set the conflict-free direction in one call
    optimizer.step()
```
Source code in conflictfree/grad_operator.py
class ConFIGOperator(GradientOperator):
    """
    Operator for the ConFIG algorithm.

    Args:
        weight_model (WeightModel, optional): The weight model for calculating the direction weights.
            Defaults to EqualWeight(), which will make the final update gradient not biased towards any gradient.
        length_model (LengthModel, optional): The length model for rescaling the length of the final gradient.
            Defaults to ProjectionLength(), which will project each gradient vector onto the final gradient vector to get the final length.
        allow_simplified_model (bool, optional): Whether to allow simplified model for calculating the gradient.
            If set to True, will use simplified form of ConFIG method when there are only two losses (ConFIG_update_double). Defaults to True.
        use_least_square (bool, optional): Whether to use the least square method for calculating the best direction.
            If set to False, we will directly calculate the pseudo-inverse of the gradient matrix. See `torch.linalg.pinv` and `torch.linalg.lstsq` for more details.
            Recommended to set to True. Defaults to True.

    Examples:
        ```python
        from conflictfree.grad_operator import ConFIGOperator
        from conflictfree.utils import get_gradient_vector,apply_gradient_vector
        optimizer=torch.optim.Adam(network.parameters(),lr=1e-3)
        operator=ConFIGOperator() # initialize operator
        for input_i in dataset:
            grads=[]
            for loss_fn in loss_fns:
                optimizer.zero_grad()
                loss_i=loss_fn(input_i)
                loss_i.backward()
                grads.append(get_gradient_vector(network))
            g_config=operator.calculate_gradient(grads) # calculate the conflict-free direction
            apply_gradient_vector(network,g_config) # or simply use `operator.update_gradient(network,grads)` to calculate and set the conflict-free direction to the network
            optimizer.step()
        ```

    """

    def __init__(
        self,
        weight_model: WeightModel = EqualWeight(),
        length_model: LengthModel = ProjectionLength(),
        allow_simplified_model: bool = True,
        use_least_square: bool = True,
    ):
        super().__init__()
        self.weight_model = weight_model
        self.length_model = length_model
        self.allow_simplified_model = allow_simplified_model
        self.use_least_square = use_least_square

    def calculate_gradient(
        self,
        grads: Union[torch.Tensor, Sequence[torch.Tensor]],
        losses: Optional[Sequence] = None,
    ) -> torch.Tensor:
        """
        Calculates the gradient using the ConFIG algorithm.

        Args:
            grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to update.
                It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.
            losses (Optional[Sequence], optional): The losses associated with the gradients.
                The losses will be passed to the weight and length model. If your weight/length model doesn't require loss information,
                you can set this value as None. Defaults to None.

        Returns:
            torch.Tensor: The calculated gradient.
        """
        if not isinstance(grads, torch.Tensor):
            grads = torch.stack(grads)
        if grads.shape[0] == 2 and self.allow_simplified_model:
            return ConFIG_update_double(
                grads[0],
                grads[1],
                weight_model=self.weight_model,
                length_model=self.length_model,
                losses=losses,
            )
        else:
            return ConFIG_update(
                grads,
                weight_model=self.weight_model,
                length_model=self.length_model,
                use_least_square=self.use_least_square,
                losses=losses,
            )
weight_model instance-attribute ¤
weight_model = weight_model
length_model instance-attribute ¤
length_model = length_model
allow_simplified_model instance-attribute ¤
allow_simplified_model = allow_simplified_model
use_least_square instance-attribute ¤
use_least_square = use_least_square
update_gradient ¤
update_gradient(
    network: torch.nn.Module,
    grads: Union[torch.Tensor, Sequence[torch.Tensor]],
    losses: Optional[Sequence] = None,
) -> None

Calculate the gradient and apply the gradient to the network.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `network` | `Module` | The target network. | *required* |
| `grads` | `Union[Tensor, Sequence[Tensor]]` | The gradients to update. It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors. | *required* |
| `losses` | `Optional[Sequence]` | The losses associated with the gradients. The losses will be passed to the weight and length model. If your weight/length model doesn't require loss information, you can set this value to `None`. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `None` | None |

Source code in conflictfree/grad_operator.py
def update_gradient(
    self,
    network: torch.nn.Module,
    grads: Union[torch.Tensor, Sequence[torch.Tensor]],
    losses: Optional[Sequence] = None,
) -> None:
    """
    Calculate the gradient and apply the gradient to the network.

    Args:
        network (torch.nn.Module): The target network.
        grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to update.
            It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.
        losses (Optional[Sequence], optional): The losses associated with the gradients.
            The losses will be passed to the weight and length model. If your weight/length model doesn't require loss information,
            you can set this value as None. Defaults to None.

    Returns:
        None

    """
    apply_gradient_vector(network, self.calculate_gradient(grads, losses))
__init__ ¤
__init__(
    weight_model: WeightModel = EqualWeight(),
    length_model: LengthModel = ProjectionLength(),
    allow_simplified_model: bool = True,
    use_least_square: bool = True,
)
Source code in conflictfree/grad_operator.py
def __init__(
    self,
    weight_model: WeightModel = EqualWeight(),
    length_model: LengthModel = ProjectionLength(),
    allow_simplified_model: bool = True,
    use_least_square: bool = True,
):
    super().__init__()
    self.weight_model = weight_model
    self.length_model = length_model
    self.allow_simplified_model = allow_simplified_model
    self.use_least_square = use_least_square
calculate_gradient ¤
calculate_gradient(
    grads: Union[torch.Tensor, Sequence[torch.Tensor]],
    losses: Optional[Sequence] = None,
) -> torch.Tensor

Calculates the gradient using the ConFIG algorithm.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `grads` | `Union[Tensor, Sequence[Tensor]]` | The gradients to update. It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors. | *required* |
| `losses` | `Optional[Sequence]` | The losses associated with the gradients. The losses will be passed to the weight and length model. If your weight/length model doesn't require loss information, you can set this value to `None`. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `torch.Tensor` | The calculated gradient. |

Source code in conflictfree/grad_operator.py
def calculate_gradient(
    self,
    grads: Union[torch.Tensor, Sequence[torch.Tensor]],
    losses: Optional[Sequence] = None,
) -> torch.Tensor:
    """
    Calculates the gradient using the ConFIG algorithm.

    Args:
        grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to update.
            It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.
        losses (Optional[Sequence], optional): The losses associated with the gradients.
            The losses will be passed to the weight and length model. If your weight/length model doesn't require loss information,
            you can set this value as None. Defaults to None.

    Returns:
        torch.Tensor: The calculated gradient.
    """
    if not isinstance(grads, torch.Tensor):
        grads = torch.stack(grads)
    if grads.shape[0] == 2 and self.allow_simplified_model:
        return ConFIG_update_double(
            grads[0],
            grads[1],
            weight_model=self.weight_model,
            length_model=self.length_model,
            losses=losses,
        )
    else:
        return ConFIG_update(
            grads,
            weight_model=self.weight_model,
            length_model=self.length_model,
            use_least_square=self.use_least_square,
            losses=losses,
        )
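
As noted in the example above, `update_gradient` wraps `calculate_gradient` and `apply_gradient_vector` in a single call. A self-contained sketch of that shortcut (the tiny model, batch, and the two loss terms below are placeholder assumptions, not part of the library):

```python
import torch
from conflictfree.grad_operator import ConFIGOperator
from conflictfree.utils import get_gradient_vector

network = torch.nn.Linear(4, 1)                       # placeholder model
optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)
x = torch.randn(16, 4)                                # placeholder batch
loss_fns = [lambda: network(x).pow(2).mean(),         # hypothetical loss terms
            lambda: (network(x) - 1.0).abs().mean()]

operator = ConFIGOperator()
grads = []
for loss_fn in loss_fns:
    optimizer.zero_grad()
    loss_fn().backward()
    grads.append(get_gradient_vector(network))

operator.update_gradient(network, grads)  # calculate_gradient + apply_gradient_vector in one call
optimizer.step()
```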

conflictfree.grad_operator.PCGradOperator ¤

Bases: GradientOperator

The PCGradOperator class represents a gradient operator for the PCGrad algorithm.

@inproceedings{yu2020gradient, title={Gradient surgery for multi-task learning}, author={Yu, Tianhe and Kumar, Saurabh and Gupta, Abhishek and Levine, Sergey and Hausman, Karol and Finn, Chelsea}, booktitle={34th International Conference on Neural Information Processing Systems}, year={2020}, url={https://dl.acm.org/doi/abs/10.5555/3495724.3496213} }

Source code in conflictfree/grad_operator.py
class PCGradOperator(GradientOperator):
    """
    PCGradOperator class represents a gradient operator for the PCGrad algorithm.

    @inproceedings{yu2020gradient,
    title={Gradient surgery for multi-task learning},
    author={Yu, Tianhe and Kumar, Saurabh and Gupta, Abhishek and Levine, Sergey and Hausman, Karol and Finn, Chelsea},
    booktitle={34th International Conference on Neural Information Processing Systems},
    year={2020},
    url={https://dl.acm.org/doi/abs/10.5555/3495724.3496213}
    }

    """

    def calculate_gradient(
        self,
        grads: Union[torch.Tensor, Sequence[torch.Tensor]],
        losses: Optional[Sequence] = None,
    ) -> torch.Tensor:
        """
        Calculates the gradient using the PCGrad algorithm.

        Args:
            grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to update.
                It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.
            losses (Optional[Sequence], optional): This parameter should not be set for current operator. Defaults to None.

        Returns:
            torch.Tensor: The calculated gradient using PCGrad method.
        """
        if not isinstance(grads, torch.Tensor):
            grads = torch.stack(grads)
        with torch.no_grad():
            grads_pc = torch.clone(grads)
            length = grads.shape[0]
            for i in range(length):
                for j in range(length):
                    if j != i:
                        dot = grads_pc[i].dot(grads[j])
                        if dot < 0:
                            grads_pc[i] -= dot * grads[j] / ((grads[j].norm()) ** 2)
            return torch.sum(grads_pc, dim=0)
__init__ ¤
__init__()
Source code in conflictfree/grad_operator.py
def __init__(self):
    pass
update_gradient ¤
update_gradient(
    network: torch.nn.Module,
    grads: Union[torch.Tensor, Sequence[torch.Tensor]],
    losses: Optional[Sequence] = None,
) -> None

Calculate the gradient and apply the gradient to the network.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `network` | `Module` | The target network. | *required* |
| `grads` | `Union[Tensor, Sequence[Tensor]]` | The gradients to update. It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors. | *required* |
| `losses` | `Optional[Sequence]` | The losses associated with the gradients. The losses will be passed to the weight and length model. If your weight/length model doesn't require loss information, you can set this value to `None`. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `None` | None |

Source code in conflictfree/grad_operator.py
def update_gradient(
    self,
    network: torch.nn.Module,
    grads: Union[torch.Tensor, Sequence[torch.Tensor]],
    losses: Optional[Sequence] = None,
) -> None:
    """
    Calculate the gradient and apply the gradient to the network.

    Args:
        network (torch.nn.Module): The target network.
        grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to update.
            It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.
        losses (Optional[Sequence], optional): The losses associated with the gradients.
            The losses will be passed to the weight and length model. If your weight/length model doesn't require loss information,
            you can set this value as None. Defaults to None.

    Returns:
        None

    """
    apply_gradient_vector(network, self.calculate_gradient(grads, losses))
calculate_gradient ¤
calculate_gradient(
    grads: Union[torch.Tensor, Sequence[torch.Tensor]],
    losses: Optional[Sequence] = None,
) -> torch.Tensor

Calculates the gradient using the PCGrad algorithm.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `grads` | `Union[Tensor, Sequence[Tensor]]` | The gradients to update. It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors. | *required* |
| `losses` | `Optional[Sequence]` | This parameter should not be set for the current operator. Defaults to `None`. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `torch.Tensor` | The calculated gradient using the PCGrad method. |

Source code in conflictfree/grad_operator.py
def calculate_gradient(
    self,
    grads: Union[torch.Tensor, Sequence[torch.Tensor]],
    losses: Optional[Sequence] = None,
) -> torch.Tensor:
    """
    Calculates the gradient using the PCGrad algorithm.

    Args:
        grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to update.
            It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.
        losses (Optional[Sequence], optional): This parameter should not be set for current operator. Defaults to None.

    Returns:
        torch.Tensor: The calculated gradient using PCGrad method.
    """
    if not isinstance(grads, torch.Tensor):
        grads = torch.stack(grads)
    with torch.no_grad():
        grads_pc = torch.clone(grads)
        length = grads.shape[0]
        for i in range(length):
            for j in range(length):
                if j != i:
                    dot = grads_pc[i].dot(grads[j])
                    if dot < 0:
                        grads_pc[i] -= dot * grads[j] / ((grads[j].norm()) ** 2)
        return torch.sum(grads_pc, dim=0)
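
`PCGradOperator` follows the same `GradientOperator` interface, so it can be dropped into the training loops shown above in place of `ConFIGOperator`. A minimal sketch with random stand-in gradients:

```python
import torch
from conflictfree.grad_operator import PCGradOperator

torch.manual_seed(0)
grads = torch.randn(3, 10)  # stand-in for three loss-specific gradient vectors

operator = PCGradOperator()
g_pc = operator.calculate_gradient(grads)  # note: no losses argument for this operator
print(g_pc.shape)  # torch.Size([10]) -- same length as a single gradient vector
```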

conflictfree.grad_operator.IMTLGOperator ¤

Bases: GradientOperator

The IMTLGOperator class represents a gradient operator for the IMTL-G algorithm.

@inproceedings{ liu2021towards, title={Towards Impartial Multi-task Learning}, author={Liyang Liu and Yi Li and Zhanghui Kuang and Jing-Hao Xue and Yimin Chen and Wenming Yang and Qingmin Liao and Wayne Zhang}, booktitle={International Conference on Learning Representations}, year={2021}, url={https://openreview.net/forum?id=IMPnRXEWpvr} }

Source code in conflictfree/grad_operator.py
class IMTLGOperator(GradientOperator):
    """
    IMTLGOperator class represents a gradient operator for the IMTL-G algorithm.

    @inproceedings{
    liu2021towards,
    title={Towards Impartial Multi-task Learning},
    author={Liyang Liu and Yi Li and Zhanghui Kuang and Jing-Hao Xue and Yimin Chen and Wenming Yang and Qingmin Liao and Wayne Zhang},
    booktitle={International Conference on Learning Representations},
    year={2021},
    url={https://openreview.net/forum?id=IMPnRXEWpvr}
    }

    """

    def calculate_gradient(
        self,
        grads: Union[torch.Tensor, Sequence[torch.Tensor]],
        losses: Optional[Sequence] = None,
    ) -> torch.Tensor:
        """
        Calculates the gradient using the IMTL-G algorithm.

        Args:
            grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to update.
                It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.
            losses (Optional[Sequence], optional): This parameter should not be set for current operator. Defaults to None.

        Returns:
            torch.Tensor: The calculated gradient using IMTL-G method.
        """
        if not isinstance(grads, torch.Tensor):
            grads = torch.stack(grads)
        with torch.no_grad():
            ut_norm = grads / grads.norm(dim=1).unsqueeze(1)
            ut_norm = torch.nan_to_num(ut_norm, 0)
            ut = torch.stack(
                [ut_norm[0] - ut_norm[i + 1] for i in range(grads.shape[0] - 1)], dim=0
            ).T
            d = torch.stack(
                [grads[0] - grads[i + 1] for i in range(grads.shape[0] - 1)], dim=0
            )
            at = grads[0] @ ut @ torch.linalg.pinv(d @ ut)
            return (1 - torch.sum(at)) * grads[0] + torch.sum(
                at.unsqueeze(1) * grads[1:], dim=0
            )
__init__ ¤
__init__()
Source code in conflictfree/grad_operator.py
def __init__(self):
    pass
update_gradient ¤
update_gradient(
    network: torch.nn.Module,
    grads: Union[torch.Tensor, Sequence[torch.Tensor]],
    losses: Optional[Sequence] = None,
) -> None

Calculate the gradient and apply the gradient to the network.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `network` | `Module` | The target network. | *required* |
| `grads` | `Union[Tensor, Sequence[Tensor]]` | The gradients to update. It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors. | *required* |
| `losses` | `Optional[Sequence]` | The losses associated with the gradients. The losses will be passed to the weight and length model. If your weight/length model doesn't require loss information, you can set this value to `None`. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `None` | None |

Source code in conflictfree/grad_operator.py
def update_gradient(
    self,
    network: torch.nn.Module,
    grads: Union[torch.Tensor, Sequence[torch.Tensor]],
    losses: Optional[Sequence] = None,
) -> None:
    """
    Calculate the gradient and apply the gradient to the network.

    Args:
        network (torch.nn.Module): The target network.
        grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to update.
            It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.
        losses (Optional[Sequence], optional): The losses associated with the gradients.
            The losses will be passed to the weight and length model. If your weight/length model doesn't require loss information,
            you can set this value as None. Defaults to None.

    Returns:
        None

    """
    apply_gradient_vector(network, self.calculate_gradient(grads, losses))
calculate_gradient ¤
calculate_gradient(
    grads: Union[torch.Tensor, Sequence[torch.Tensor]],
    losses: Optional[Sequence] = None,
) -> torch.Tensor

Calculates the gradient using the IMTL-G algorithm.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `grads` | `Union[Tensor, Sequence[Tensor]]` | The gradients to update. It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors. | *required* |
| `losses` | `Optional[Sequence]` | This parameter should not be set for the current operator. Defaults to `None`. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `torch.Tensor` | The calculated gradient using the IMTL-G method. |

Source code in conflictfree/grad_operator.py
def calculate_gradient(
    self,
    grads: Union[torch.Tensor, Sequence[torch.Tensor]],
    losses: Optional[Sequence] = None,
) -> torch.Tensor:
    """
    Calculates the gradient using the IMTL-G algorithm.

    Args:
        grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to update.
            It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.
        losses (Optional[Sequence], optional): This parameter should not be set for current operator. Defaults to None.

    Returns:
        torch.Tensor: The calculated gradient using IMTL-G method.
    """
    if not isinstance(grads, torch.Tensor):
        grads = torch.stack(grads)
    with torch.no_grad():
        ut_norm = grads / grads.norm(dim=1).unsqueeze(1)
        ut_norm = torch.nan_to_num(ut_norm, 0)
        ut = torch.stack(
            [ut_norm[0] - ut_norm[i + 1] for i in range(grads.shape[0] - 1)], dim=0
        ).T
        d = torch.stack(
            [grads[0] - grads[i + 1] for i in range(grads.shape[0] - 1)], dim=0
        )
        at = grads[0] @ ut @ torch.linalg.pinv(d @ ut)
        return (1 - torch.sum(at)) * grads[0] + torch.sum(
            at.unsqueeze(1) * grads[1:], dim=0
        )
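
`IMTLGOperator` likewise shares the `GradientOperator` interface and needs no loss information. A minimal sketch with random stand-in gradients:

```python
import torch
from conflictfree.grad_operator import IMTLGOperator

torch.manual_seed(0)
grads = [torch.randn(10) for _ in range(3)]  # a sequence of stand-in gradient vectors

operator = IMTLGOperator()
g_imtl = operator.calculate_gradient(grads)  # a stacked tensor works as well
print(g_imtl.shape)  # torch.Size([10])
```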

Base Class of Operators¤

conflictfree.grad_operator.GradientOperator ¤

A base class that represents a gradient operator.

Methods:

| Name | Description |
| --- | --- |
| `calculate_gradient` | Calculates the gradient based on the given gradients and losses. |
| `update_gradient` | Updates the gradient of the network based on the calculated gradient. |

Source code in conflictfree/grad_operator.py
class GradientOperator:
    """
    A base class that represents a gradient operator.

    Methods:
        calculate_gradient: Calculates the gradient based on the given gradients and losses.
        update_gradient: Updates the gradient of the network based on the calculated gradient.

    """

    def __init__(self):
        pass

    def calculate_gradient(
        self,
        grads: Union[torch.Tensor, Sequence[torch.Tensor]],
        losses: Optional[Sequence] = None,
    ) -> torch.Tensor:
        """
        Calculates the gradient based on the given gradients and losses.

        Args:
            grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to update.
                It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.
            losses (Optional[Sequence], optional): The losses associated with the gradients.
                The losses will be passed to the weight and length model. If your weight/length model doesn't require loss information,
                you can set this value as None. Defaults to None.

        Returns:
            torch.Tensor: The calculated gradient.

        Raises:
            NotImplementedError: If the method is not implemented.

        """
        raise NotImplementedError("calculate_gradient method must be implemented")

    def update_gradient(
        self,
        network: torch.nn.Module,
        grads: Union[torch.Tensor, Sequence[torch.Tensor]],
        losses: Optional[Sequence] = None,
    ) -> None:
        """
        Calculate the gradient and apply the gradient to the network.

        Args:
            network (torch.nn.Module): The target network.
            grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to update.
                It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.
            losses (Optional[Sequence], optional): The losses associated with the gradients.
                The losses will be passed to the weight and length model. If your weight/length model doesn't require loss information,
                you can set this value as None. Defaults to None.

        Returns:
            None

        """
        apply_gradient_vector(network, self.calculate_gradient(grads, losses))
__init__ ¤
__init__()
Source code in conflictfree/grad_operator.py
def __init__(self):
    pass
calculate_gradient ¤
calculate_gradient(
    grads: Union[torch.Tensor, Sequence[torch.Tensor]],
    losses: Optional[Sequence] = None,
) -> torch.Tensor

Calculates the gradient based on the given gradients and losses.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `grads` | `Union[Tensor, Sequence[Tensor]]` | The gradients to update. It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors. | *required* |
| `losses` | `Optional[Sequence]` | The losses associated with the gradients. The losses will be passed to the weight and length model. If your weight/length model doesn't require loss information, you can set this value to `None`. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `torch.Tensor` | The calculated gradient. |

Raises:

| Type | Description |
| --- | --- |
| `NotImplementedError` | If the method is not implemented. |

Source code in conflictfree/grad_operator.py
def calculate_gradient(
    self,
    grads: Union[torch.Tensor, Sequence[torch.Tensor]],
    losses: Optional[Sequence] = None,
) -> torch.Tensor:
    """
    Calculates the gradient based on the given gradients and losses.

    Args:
        grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to update.
            It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.
        losses (Optional[Sequence], optional): The losses associated with the gradients.
            The losses will be passed to the weight and length model. If your weight/length model doesn't require loss information,
            you can set this value as None. Defaults to None.

    Returns:
        torch.Tensor: The calculated gradient.

    Raises:
        NotImplementedError: If the method is not implemented.

    """
    raise NotImplementedError("calculate_gradient method must be implemented")
update_gradient ¤
update_gradient(
    network: torch.nn.Module,
    grads: Union[torch.Tensor, Sequence[torch.Tensor]],
    losses: Optional[Sequence] = None,
) -> None

Calculate the gradient and apply the gradient to the network.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `network` | `Module` | The target network. | *required* |
| `grads` | `Union[Tensor, Sequence[Tensor]]` | The gradients to update. It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors. | *required* |
| `losses` | `Optional[Sequence]` | The losses associated with the gradients. The losses will be passed to the weight and length model. If your weight/length model doesn't require loss information, you can set this value to `None`. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `None` | None |

Source code in conflictfree/grad_operator.py
def update_gradient(
    self,
    network: torch.nn.Module,
    grads: Union[torch.Tensor, Sequence[torch.Tensor]],
    losses: Optional[Sequence] = None,
) -> None:
    """
    Calculate the gradient and apply the gradient to the network.

    Args:
        network (torch.nn.Module): The target network.
        grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to update.
            It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.
        losses (Optional[Sequence], optional): The losses associated with the gradients.
            The losses will be passed to the weight and length model. If your weight/length model doesn't require loss information,
            you can set this value as None. Defaults to None.

    Returns:
        None

    """
    apply_gradient_vector(network, self.calculate_gradient(grads, losses))
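
Because `calculate_gradient` is the only method a subclass must provide (`update_gradient` is inherited and simply applies the result to the network), adding a custom combination strategy is straightforward. A hedged sketch of a toy averaging operator, purely for illustration and not part of the library:

```python
import torch
from typing import Optional, Sequence, Union
from conflictfree.grad_operator import GradientOperator


class MeanGradOperator(GradientOperator):
    """Toy operator that simply averages the loss-specific gradients."""

    def calculate_gradient(
        self,
        grads: Union[torch.Tensor, Sequence[torch.Tensor]],
        losses: Optional[Sequence] = None,
    ) -> torch.Tensor:
        if not isinstance(grads, torch.Tensor):
            grads = torch.stack(grads)
        return grads.mean(dim=0)


operator = MeanGradOperator()
print(operator.calculate_gradient([torch.ones(5), torch.zeros(5)]))  # -> tensor of 0.5s
```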