4.6. Utils

The utils module contains utility functions for the ConFIG algorithm.

Network Utility Functions¤

conflictfree.utils.get_para_vector ¤

get_para_vector(network: torch.nn.Module) -> torch.Tensor

Returns the parameter vector of the given network.

Parameters:

    network (torch.nn.Module, required): The network for which to compute the parameter vector.

Returns:

    torch.Tensor: The parameter vector of the network.

Source code in conflictfree/utils.py
def get_para_vector(network: torch.nn.Module) -> torch.Tensor:
    """
    Returns the parameter vector of the given network.

    Args:
        network (torch.nn.Module): The network for which to compute the parameter vector.

    Returns:
        torch.Tensor: The parameter vector of the network.
    """
    with torch.no_grad():
        para_vec = None
        for par in network.parameters():
            viewed = par.data.view(-1)
            if para_vec is None:
                para_vec = viewed
            else:
                para_vec = torch.cat((para_vec, viewed))
        return para_vec
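
A minimal usage sketch (the two-layer network here is made up for illustration; any torch.nn.Module works):

import torch
from conflictfree.utils import get_para_vector

net = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 2))

# Flatten all parameters into a single 1-D tensor.
para_vec = get_para_vector(net)
assert para_vec.shape[0] == sum(p.numel() for p in net.parameters())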

conflictfree.utils.apply_para_vector ¤

apply_para_vector(
    network: torch.nn.Module, para_vec: torch.Tensor
) -> None

Applies a parameter vector to the network's parameters.

Parameters:

    network (torch.nn.Module, required): The network to apply the parameter vector to.
    para_vec (torch.Tensor, required): The parameter vector to apply.
Source code in conflictfree/utils.py
def apply_para_vector(network: torch.nn.Module, para_vec: torch.Tensor) -> None:
    """
    Applies a parameter vector to the network's parameters.

    Args:
        network (torch.nn.Module): The network to apply the parameter vector to.
        para_vec (torch.Tensor): The parameter vector to apply.
    """
    with torch.no_grad():
        start = 0
        for par in network.parameters():
            end = start + par.data.view(-1).shape[0]
            par.data = para_vec[start:end].view(par.data.shape)
            start = end
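
A sketch of a round trip: read the parameters, perturb them, and write them back. The perturbation is purely illustrative.

import torch
from conflictfree.utils import get_para_vector, apply_para_vector

net = torch.nn.Linear(4, 2)

# Read the flattened parameters and build a slightly perturbed copy.
para_vec = get_para_vector(net)
perturbed = para_vec + 0.01 * torch.randn_like(para_vec)

# Write the modified vector back into the network's parameters.
apply_para_vector(net, perturbed)
assert torch.equal(get_para_vector(net), perturbed)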

conflictfree.utils.get_gradient_vector ¤

get_gradient_vector(
    network: torch.nn.Module,
    none_grad_mode: Literal[
        "raise", "zero", "skip"
    ] = "skip",
) -> torch.Tensor

Returns the gradient vector of the given network.

Parameters:

    network (torch.nn.Module, required): The network for which to compute the gradient vector.
    none_grad_mode (Literal['raise', 'zero', 'skip'], default 'skip'): The mode to handle None gradients.
        - 'raise': Raise an error when the gradient of a parameter is None.
        - 'zero': Replace the None gradient with a zero tensor.
        - 'skip': Skip the None gradient.
        A None gradient usually occurs when part of the network is not trainable (e.g., during fine-tuning) or when a weight is not used to calculate the current loss (e.g., different parts of the network calculate different losses). If all of your losses are calculated using the same part of the network, set none_grad_mode to 'skip'. If your losses are calculated using different parts of the network, set none_grad_mode to 'zero' to ensure the gradient vectors have the same shape.

Returns:

    torch.Tensor: The gradient vector of the network.

Source code in conflictfree/utils.py
def get_gradient_vector(
    network: torch.nn.Module, none_grad_mode: Literal["raise", "zero", "skip"] = "skip"
) -> torch.Tensor:
    """
    Returns the gradient vector of the given network.

    Args:
        network (torch.nn.Module): The network for which to compute the gradient vector.
        none_grad_mode (Literal['raise', 'zero', 'skip']): The mode to handle None gradients. Default: 'skip'.
            - 'raise': Raise an error when the gradient of a parameter is None.
            - 'zero': Replace the None gradient with a zero tensor.
            - 'skip': Skip the None gradient.
            A None gradient usually occurs when part of the network is not trainable (e.g., during fine-tuning)
            or when a weight is not used to calculate the current loss (e.g., different parts of the network calculate different losses).
            If all of your losses are calculated using the same part of the network, set none_grad_mode to 'skip'.
            If your losses are calculated using different parts of the network, set none_grad_mode to 'zero' to ensure the gradient vectors have the same shape.

    Returns:
        torch.Tensor: The gradient vector of the network.
    """
    with torch.no_grad():
        grad_vec = None
        for par in network.parameters():
            if par.grad is None:
                if none_grad_mode == "raise":
                    raise RuntimeError("None gradient detected.")
                elif none_grad_mode == "zero":
                    viewed = torch.zeros_like(par.data.view(-1))
                elif none_grad_mode == "skip":
                    continue
                else:
                    raise ValueError(f"Invalid none_grad_mode '{none_grad_mode}'.")
            else:
                viewed = par.grad.data.view(-1)
            if grad_vec is None:
                grad_vec = viewed
            else:
                grad_vec = torch.cat((grad_vec, viewed))
        return grad_vec
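
A sketch of collecting one gradient vector per loss; the two losses below are made up for illustration:

import torch
from conflictfree.utils import get_gradient_vector

net = torch.nn.Linear(4, 1)
x = torch.randn(8, 4)

grads = []
for loss_fn in (lambda y: (y ** 2).mean(), lambda y: y.abs().mean()):
    net.zero_grad()
    loss_fn(net(x)).backward()
    # Every parameter receives a gradient here, so the default 'skip' mode is safe.
    grads.append(get_gradient_vector(net))

grads = torch.stack(grads)  # shape: (num_losses, num_parameters)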

conflictfree.utils.apply_gradient_vector ¤

apply_gradient_vector(
    network: torch.nn.Module,
    grad_vec: torch.Tensor,
    none_grad_mode: Literal["zero", "skip"] = "skip",
    zero_grad_mode: Literal[
        "skip", "pad_zero", "pad_value"
    ] = "pad_value",
) -> None

Applies a gradient vector to the network's parameters. This function requires that the network already contain gradient information in order to apply the gradient vector. If your network does not contain gradient information, consider using the apply_gradient_vector_para_based function instead.

Parameters:

    network (torch.nn.Module, required): The network to apply the gradient vector to.
    grad_vec (torch.Tensor, required): The gradient vector to apply.
    none_grad_mode (Literal['zero', 'skip'], default 'skip'): The mode to handle None gradients. Set this parameter to the same value as the one used in the get_gradient_vector method.
    zero_grad_mode (Literal['skip', 'pad_zero', 'pad_value'], default 'pad_value'): How to set the gradient of a parameter whose gradient was None, used when none_grad_mode is 'zero'.
        - 'skip': Leave the gradient as None.
        - 'pad_zero': Replace the None gradient with a zero tensor.
        - 'pad_value': Set the gradient using the corresponding values in grad_vec.
        If you set none_grad_mode to 'zero', zeros were padded into grad_vec wherever a parameter's gradient was None when the gradient vector was collected. When you apply the gradient vector back to the network, the values in grad_vec corresponding to those previously None gradients may no longer be zero due to the gradient operations that were applied. You therefore need to decide whether to keep the original None value, set it to zero, or set it according to the values in grad_vec. If you are not sure, it is safest to use 'pad_value'.
Source code in conflictfree/utils.py
def apply_gradient_vector(
    network: torch.nn.Module,
    grad_vec: torch.Tensor,
    none_grad_mode: Literal["zero", "skip"] = "skip",
    zero_grad_mode: Literal["skip", "pad_zero", "pad_value"] = "pad_value",
) -> None:
    """
    Applies a gradient vector to the network's parameters.
    This function requires that the network already contain gradient information in order to apply the gradient vector.
    If your network does not contain gradient information, consider using the `apply_gradient_vector_para_based` function instead.

    Args:
        network (torch.nn.Module): The network to apply the gradient vector to.
        grad_vec (torch.Tensor): The gradient vector to apply.
        none_grad_mode (Literal['zero', 'skip']): The mode to handle None gradients.
            Set this parameter to the same value as the one used in the `get_gradient_vector` method.
        zero_grad_mode (Literal['skip', 'pad_zero', 'pad_value']): How to set the gradient of a parameter whose gradient was None, used when `none_grad_mode` is 'zero'. Default: 'pad_value'.
            - 'skip': Leave the gradient as None.
            - 'pad_zero': Replace the None gradient with a zero tensor.
            - 'pad_value': Set the gradient using the corresponding values in `grad_vec`.
            If you set `none_grad_mode` to 'zero', zeros were padded into `grad_vec` wherever a parameter's gradient was None when the gradient vector was collected.
            When you apply the gradient vector back to the network, the values in `grad_vec` corresponding to those previously None gradients may no longer be zero due to the gradient operations that were applied.
            You therefore need to decide whether to keep the original None value, set it to zero, or set it according to the values in `grad_vec`.
            If you are not sure, it is safest to use 'pad_value'.

    """
    if none_grad_mode == "zero" and zero_grad_mode == "pad_value":
        apply_gradient_vector_para_based(network, grad_vec)
        return
    with torch.no_grad():
        start = 0
        for par in network.parameters():
            if par.grad is None:
                if none_grad_mode == "skip":
                    continue
                elif none_grad_mode == "zero":
                    start = start + par.data.view(-1).shape[0]
                    if zero_grad_mode == "pad_zero":
                        par.grad = torch.zeros_like(par.data)
                    elif zero_grad_mode == "skip":
                        continue
                    else:
                        raise ValueError(f"Invalid zero_grad_mode '{zero_grad_mode}'.")
                else:
                    raise ValueError(f"Invalid none_grad_mode '{none_grad_mode}'.")
            else:
                end = start + par.data.view(-1).shape[0]
                par.grad.data = grad_vec[start:end].view(par.data.shape)
                start = end
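
A sketch of the intended workflow with an optimizer. The rescaling below stands in for a real gradient operation (e.g., a conflict-free update):

import torch
from conflictfree.utils import get_gradient_vector, apply_gradient_vector

net = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(net.parameters(), lr=1e-2)
x = torch.randn(8, 4)

net.zero_grad()
(net(x) ** 2).mean().backward()

# Read the gradient, modify it, and write it back before stepping.
g = get_gradient_vector(net)
apply_gradient_vector(net, 0.5 * g)
opt.step()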

conflictfree.utils.apply_gradient_vector_para_based ¤

apply_gradient_vector_para_based(
    network: torch.nn.Module, grad_vec: torch.Tensor
) -> None

Applies a gradient vector to the network's parameters. Only use this function when you are sure that the length of grad_vec matches the total number of parameters in your network. This is the case when you use get_gradient_vector with none_grad_mode set to 'zero', or when none_grad_mode is 'skip' but all of the parameters in your network are involved in the loss calculation.

Parameters:

    network (torch.nn.Module, required): The network to apply the gradient vector to.
    grad_vec (torch.Tensor, required): The gradient vector to apply.
Source code in conflictfree/utils.py
def apply_gradient_vector_para_based(
    network: torch.nn.Module,
    grad_vec: torch.Tensor,
) -> None:
    """
    Applies a gradient vector to the network's parameters.
    Only use this function when you are sure that the length of `grad_vec` matches the total number of parameters in your network.
    This is the case when you use `get_gradient_vector` with `none_grad_mode` set to 'zero',
    or when `none_grad_mode` is 'skip' but all of the parameters in your network are involved in the loss calculation.

    Args:
        network (torch.nn.Module): The network to apply the gradient vector to.
        grad_vec (torch.Tensor): The gradient vector to apply.
    """
    with torch.no_grad():
        start = 0
        for par in network.parameters():
            end = start + par.data.view(-1).shape[0]
            par.grad = grad_vec[start:end].view(par.data.shape)
            start = end
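
A sketch for the parameter-based variant, which also works on a network without existing .grad fields as long as grad_vec covers every parameter; the random vector is purely illustrative:

import torch
from conflictfree.utils import apply_gradient_vector_para_based

net = torch.nn.Linear(4, 2)
n_params = sum(p.numel() for p in net.parameters())

# A full-length gradient vector; every parameter gets a matching slice.
apply_gradient_vector_para_based(net, torch.randn(n_params))
assert all(p.grad is not None for p in net.parameters())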

Math Utility Functions¤

conflictfree.utils.get_cos_similarity ¤

get_cos_similarity(
    vector1: torch.Tensor, vector2: torch.Tensor
) -> torch.Tensor

Calculates the cosine similarity between two vectors.

Parameters:

    vector1 (torch.Tensor, required): The first vector.
    vector2 (torch.Tensor, required): The second vector.

Returns:

    torch.Tensor: The cosine similarity of the two vectors.

Source code in conflictfree/utils.py
def get_cos_similarity(vector1: torch.Tensor, vector2: torch.Tensor) -> torch.Tensor:
    """
    Calculates the cosine similarity between two vectors.

    Args:
        vector1 (torch.Tensor): The first vector.
        vector2 (torch.Tensor): The second vector.

    Returns:
        torch.Tensor: The cosine similarity of the two vectors.
    """
    with torch.no_grad():
        return torch.dot(vector1, vector2) / vector1.norm() / vector2.norm()
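
A quick sketch; a negative value indicates conflicting gradient directions:

import torch
from conflictfree.utils import get_cos_similarity

g1 = torch.tensor([1.0, 0.0])
g2 = torch.tensor([-1.0, 1.0])

# The angle between the two vectors is 135 degrees, so the similarity is -sqrt(2)/2.
print(get_cos_similarity(g1, g2))  # tensor(-0.7071)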

conflictfree.utils.unit_vector ¤

unit_vector(
    vector: torch.Tensor, warn_zero: bool = False
) -> torch.Tensor

Compute the unit vector of a given tensor.

Parameters:

    vector (torch.Tensor, required): The input tensor.
    warn_zero (bool, default False): Whether to print a warning when the input tensor is zero.

Returns:

    torch.Tensor: The unit vector of the input tensor.

Source code in conflictfree/utils.py
def unit_vector(vector: torch.Tensor, warn_zero: bool = False) -> torch.Tensor:
    """
    Compute the unit vector of a given tensor.

    Args:
        vector (torch.Tensor): The input tensor.
        warn_zero (bool): Whether to print a warning when the input tensor is zero. default: False

    Returns:
        torch.Tensor: The unit vector of the input tensor.
    """
    with torch.no_grad():
        if vector.norm() == 0:
            if warn_zero:
                print("Detected zero vector when doing normalization.")
            return torch.zeros_like(vector)
        else:
            return vector / vector.norm()
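
A short sketch, including the zero-vector edge case:

import torch
from conflictfree.utils import unit_vector

print(unit_vector(torch.tensor([3.0, 4.0])))        # tensor([0.6000, 0.8000])

# A zero vector is returned as-is instead of raising a division error.
print(unit_vector(torch.zeros(2), warn_zero=True))  # tensor([0., 0.])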

Conflict Utility Functions¤

conflictfree.utils.estimate_conflict ¤

estimate_conflict(gradients: torch.Tensor) -> torch.Tensor

Estimates the degree of conflict of gradients.

Parameters:

    gradients (torch.Tensor, required): A tensor containing gradients (one gradient per row).

Returns:

    torch.Tensor: A tensor consisting of the dot products between the normalized sum of the gradients and each normalized sub-gradient.

Source code in conflictfree/utils.py
def estimate_conflict(gradients: torch.Tensor) -> torch.Tensor:
    """
    Estimates the degree of conflict of gradients.

    Args:
        gradients (torch.Tensor): A tensor containing gradients.

    Returns:
        torch.Tensor: A tensor consisting of the dot products between the normalized sum of the gradients and each normalized sub-gradient.
    """
    direct_sum = unit_vector(gradients.sum(dim=0))
    unit_grads = gradients / torch.norm(gradients, dim=1).view(-1, 1)
    return unit_grads @ direct_sum
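
A sketch with two hand-picked gradients. Each output entry is the cosine similarity between the summed direction and one gradient; values near 1 indicate agreement and negative values indicate conflict:

import torch
from conflictfree.utils import estimate_conflict

grads = torch.stack([
    torch.tensor([1.0, 0.0]),  # points along x
    torch.tensor([0.0, 1.0]),  # orthogonal to the first gradient
])

print(estimate_conflict(grads))  # tensor([0.7071, 0.7071])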

Slice Selector Classes¤

conflictfree.utils.OrderedSliceSelector ¤

Selects a slice of the source sequence in order. Usually used for selecting loss functions/gradients/losses in momentum-based methods when you want to update more than one gradient in a single iteration.

Source code in conflictfree/utils.py
class OrderedSliceSelector:
    """
    Selects a slice of the source sequence in order.
    Usually used for selecting loss functions/gradients/losses in momentum-based methods when you want to update more than one gradient in a single iteration.

    """

    def __init__(self):
        self.start_index = 0

    def select(
        self, n: int, source_sequence: Sequence
    ) -> Tuple[Sequence, Union[float, Sequence]]:
        """
        Selects a slice of the source sequence in order.

        Args:
            n (int): The length of the target slice.
            source_sequence (Sequence): The source sequence to select from.

        Returns:
            Tuple[Sequence,Union[float,Sequence]]: A tuple containing the indexes of the selected slice and the selected slice.
        """
        if n > len(source_sequence):
            raise ValueError(
                "n must be less than or equal to the length of the source sequence"
            )
        end_index = self.start_index + n
        if end_index > len(source_sequence) - 1:
            new_start = end_index - len(source_sequence)
            indexes = list(range(self.start_index, len(source_sequence))) + list(
                range(0, new_start)
            )
            self.start_index = new_start
        else:
            indexes = list(range(self.start_index, end_index))
            self.start_index = end_index
        if len(indexes) == 1:
            return indexes, source_sequence[indexes[0]]
        else:
            return indexes, [source_sequence[i] for i in indexes]
start_index instance-attribute ¤
start_index = 0
__init__ ¤
__init__()
Source code in conflictfree/utils.py
def __init__(self):
    self.start_index = 0
select ¤
select(
    n: int, source_sequence: Sequence
) -> Tuple[Sequence, Union[float, Sequence]]

Selects a slice of the source sequence in order.

Parameters:

    n (int, required): The length of the target slice.
    source_sequence (Sequence, required): The source sequence to select from.

Returns:

    Tuple[Sequence, Union[float, Sequence]]: A tuple containing the indexes of the selected slice and the selected slice.

Source code in conflictfree/utils.py
def select(
    self, n: int, source_sequence: Sequence
) -> Tuple[Sequence, Union[float, Sequence]]:
    """
    Selects a slice of the source sequence in order.

    Args:
        n (int): The length of the target slice.
        source_sequence (Sequence): The source sequence to select from.

    Returns:
        Tuple[Sequence,Union[float,Sequence]]: A tuple containing the indexes of the selected slice and the selected slice.
    """
    if n > len(source_sequence):
        raise ValueError(
            "n must be less than or equal to the length of the source sequence"
        )
    end_index = self.start_index + n
    if end_index > len(source_sequence) - 1:
        new_start = end_index - len(source_sequence)
        indexes = list(range(self.start_index, len(source_sequence))) + list(
            range(0, new_start)
        )
        self.start_index = new_start
    else:
        indexes = list(range(self.start_index, end_index))
        self.start_index = end_index
    if len(indexes) == 1:
        return indexes, source_sequence[indexes[0]]
    else:
        return indexes, [source_sequence[i] for i in indexes]
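
A sketch of the wrap-around behavior with a three-element source:

from conflictfree.utils import OrderedSliceSelector

selector = OrderedSliceSelector()
losses = ["loss_a", "loss_b", "loss_c"]

# Slices are taken in order and wrap around the end of the sequence.
print(selector.select(2, losses))  # ([0, 1], ['loss_a', 'loss_b'])
print(selector.select(2, losses))  # ([2, 0], ['loss_c', 'loss_a'])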

conflictfree.utils.RandomSliceSelector ¤

Selects a slice of the source sequence randomly. Usually used for selecting loss functions/gradients/losses in momentum-based methods when you want to update more than one gradient in a single iteration.

Source code in conflictfree/utils.py
class RandomSliceSelector:
    """
    Selects a slice of the source sequence randomly.
    Usually used for selecting loss functions/gradients/losses in momentum-based methods when you want to update more than one gradient in a single iteration.
    """

    def select(
        self, n: int, source_sequence: Sequence
    ) -> Tuple[Sequence, Union[float, Sequence]]:
        """
        Selects a slice of the source sequence randomly.

        Args:
            n (int): The length of the target slice.
            source_sequence (Sequence): The source sequence to select from.

        Returns:
            Tuple[Sequence,Union[float,Sequence]]: A tuple containing the indexes of the selected slice and the selected slice.
        """
        assert n <= len(
            source_sequence
        ), "n cannot be larger than the length of the source sequence"
        indexes = np.random.choice(len(source_sequence), n, replace=False)
        if len(indexes) == 1:
            return indexes, source_sequence[indexes[0]]
        else:
            return indexes, [source_sequence[i] for i in indexes]
select ¤
select(
    n: int, source_sequence: Sequence
) -> Tuple[Sequence, Union[float, Sequence]]

Selects a slice of the source sequence randomly.

Parameters:

    n (int, required): The length of the target slice.
    source_sequence (Sequence, required): The source sequence to select from.

Returns:

    Tuple[Sequence, Union[float, Sequence]]: A tuple containing the indexes of the selected slice and the selected slice.

Source code in conflictfree/utils.py
def select(
    self, n: int, source_sequence: Sequence
) -> Tuple[Sequence, Union[float, Sequence]]:
    """
    Selects a slice of the source sequence randomly.

    Args:
        n (int): The length of the target slice.
        source_sequence (Sequence): The source sequence to select from.

    Returns:
        Tuple[Sequence,Union[float,Sequence]]: A tuple containing the indexes of the selected slice and the selected slice.
    """
    assert n <= len(
        source_sequence
    ), "n cannot be larger than the length of the source sequence"
    indexes = np.random.choice(len(source_sequence), n, replace=False)
    if len(indexes) == 1:
        return indexes, source_sequence[indexes[0]]
    else:
        return indexes, [source_sequence[i] for i in indexes]
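
A sketch; the selected indexes vary from call to call, so the seed below only makes the illustration repeatable:

import numpy as np
from conflictfree.utils import RandomSliceSelector

np.random.seed(0)
selector = RandomSliceSelector()
losses = ["loss_a", "loss_b", "loss_c"]

# Two distinct elements are drawn without replacement.
indexes, selected = selector.select(2, losses)
print(indexes, selected)  # e.g. [2 0] ['loss_c', 'loss_a']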

Others¤

conflictfree.utils.has_zero ¤

has_zero(lists: Sequence) -> bool

Check if any element in the list is zero.

Parameters:

    lists (Sequence, required): A list of elements.

Returns:

    bool: True if any element is zero, False otherwise.

Source code in conflictfree/utils.py
def has_zero(lists: Sequence) -> bool:
    """
    Check if any element in the list is zero.

    Args:
        lists (Sequence): A list of elements.

    Returns:
        bool: True if any element is zero, False otherwise.
    """
    for i in lists:
        if i == 0:
            return True
    return False
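
A one-line sketch:

from conflictfree.utils import has_zero

print(has_zero([1.0, 0.0, 2.0]))  # True
print(has_zero([1.0, 2.0]))       # False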