4.6. Utils

The utils module contains utility functions for the ConFIG algorithm.

Network Utility Functions¤

conflictfree.utils.get_para_vector ¤

get_para_vector(network) -> torch.Tensor

Returns the parameter vector of the given network.

Parameters:

Name Type Description Default
network Module

The network for which to compute the parameter vector.

required

Returns:

Type Description
Tensor

torch.Tensor: The parameter vector of the network.

Source code in conflictfree/utils.py
def get_para_vector(network) -> torch.Tensor:
    """
    Returns the parameter vector of the given network.

    Args:
        network (torch.nn.Module): The network for which to compute the parameter vector.

    Returns:
        torch.Tensor: The parameter vector of the network.
    """
    with torch.no_grad():
        para_vec = None
        for par in network.parameters():
            viewed=par.data.view(-1)
            if para_vec is None:
                para_vec = viewed
            else:
                para_vec = torch.cat((para_vec, viewed))
        return para_vec

conflictfree.utils.apply_para_vector ¤

apply_para_vector(
    network: torch.nn.Module, para_vec: torch.Tensor
) -> None

Applies a parameter vector to the network's parameters.

Parameters:

Name Type Description Default
network Module

The network to apply the parameter vector to.

required
para_vec Tensor

The parameter vector to apply.

required
Source code in conflictfree/utils.py
def apply_para_vector(network:torch.nn.Module,para_vec:torch.Tensor)->None:
    """
    Applies a parameter vector to the network's parameters.

    Args:
        network (torch.nn.Module): The network to apply the parameter vector to.
        para_vec (torch.Tensor): The parameter vector to apply.
    """
    with torch.no_grad():
        start=0
        for par in network.parameters():
            end=start+par.data.view(-1).shape[0]
            par.data=para_vec[start:end].view(par.data.shape)
            start=end
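
As a rough illustration of how the two parameter helpers fit together, the sketch below takes a snapshot of a network's parameters, overwrites them, and restores the snapshot. The two functions are condensed restatements of the listings so the snippet runs standalone (they assume the network has at least one parameter), and the toy `net` is a hypothetical example, not part of the library.

```python
import torch


def get_para_vector(network: torch.nn.Module) -> torch.Tensor:
    # Condensed from the listing: flatten every parameter and concatenate in order.
    with torch.no_grad():
        return torch.cat([par.data.view(-1) for par in network.parameters()])


def apply_para_vector(network: torch.nn.Module, para_vec: torch.Tensor) -> None:
    # Condensed from the listing: hand consecutive slices of para_vec back
    # to the parameters, reshaped to each parameter's shape.
    with torch.no_grad():
        start = 0
        for par in network.parameters():
            end = start + par.numel()
            par.data = para_vec[start:end].view(par.data.shape)
            start = end


# Hypothetical toy network: 3 weights + 1 bias = 4 parameters.
net = torch.nn.Linear(3, 1)

# Clone the snapshot so it cannot share memory with the live parameters.
snapshot = get_para_vector(net).clone()

# Overwrite every parameter with zeros, then restore the snapshot.
apply_para_vector(net, torch.zeros(4))
zeroed = get_para_vector(net)
apply_para_vector(net, snapshot)
restored = get_para_vector(net)
```

Note that `apply_para_vector` assigns views of `para_vec` to the parameters, so the parameters share memory with the vector that was applied; cloning the vector first avoids surprises if it is modified later.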

conflictfree.utils.get_gradient_vector ¤

get_gradient_vector(
    network, jump_none=True
) -> torch.Tensor

Returns the gradient vector of the given network.

Parameters:

Name Type Description Default
network Module

The network for which to compute the gradient vector.

required
jump_none bool

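Whether to skip parameters whose gradient is None. This is useful when part of your neural network is frozen or not trainable. Pass the same value to apply_gradient_vector when applying the gradient vector. Note that in the listed source, jump_none=False does not substitute a value for a missing gradient: the code then reads par.grad.data and fails on a None gradient.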

True

Returns:

Type Description
Tensor

torch.Tensor: The gradient vector of the network.

Source code in conflictfree/utils.py
def get_gradient_vector(network,jump_none=True)->torch.Tensor:
    """
    Returns the gradient vector of the given network.

    Args:
        network (torch.nn.Module): The network for which to compute the gradient vector.
        jump_none (bool): Whether to skip the None gradients. default: True
            This is useful when part of your neural network is frozen or not trainable.
            You should set the same value to `apply_gradient_vector` when applying the gradient vector.

    Returns:
        torch.Tensor: The gradient vector of the network.
    """
    with torch.no_grad():
        grad_vec = None
        for par in network.parameters():
            if par.grad is None:
                if jump_none:
                    continue
            viewed=par.grad.data.view(-1)
            if grad_vec is None:
                grad_vec = viewed
            else:
                grad_vec = torch.cat((grad_vec, viewed))
        return grad_vec

conflictfree.utils.apply_gradient_vector ¤

apply_gradient_vector(
    network: torch.nn.Module,
    grad_vec: torch.Tensor,
    jump_none=True,
) -> None

Applies a gradient vector to the network's parameters.

Parameters:

Name Type Description Default
network Module

The network to apply the gradient vector to.

required
grad_vec Tensor

The gradient vector to apply.

required
jump_none bool

Whether to skip parameters whose gradient is None. This is useful when part of your neural network is frozen or not trainable. Use the same value that was passed to get_gradient_vector, so that the slices of grad_vec line up with the same parameters.

True
Source code in conflictfree/utils.py
def apply_gradient_vector(network:torch.nn.Module,grad_vec:torch.Tensor,jump_none=True)->None:
    """
    Applies a gradient vector to the network's parameters.

    Args:
        network (torch.nn.Module): The network to apply the gradient vector to.
        grad_vec (torch.Tensor): The gradient vector to apply.
        jump_none (bool): Whether to skip the None gradients. default: True
            This is useful when part of your neural network is frozen or not trainable.
            You should set the same value to `get_gradient_vector` when applying the gradient vector.

    """
    with torch.no_grad():
        start=0
        for par in network.parameters():
            if par.grad is None:
                if jump_none:
                    continue
            end=start+par.grad.data.view(-1).shape[0]
            par.grad.data=grad_vec[start:end].view(par.grad.data.shape)
            start=end
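
The following sketch shows one way the two gradient helpers might be used together when part of a network is frozen. The helper bodies are condensed from the listings; the model, data, and both losses are hypothetical, and the plain averaging step is only a stand-in for a gradient-combination rule, not the ConFIG update itself.

```python
import torch


def get_gradient_vector(network, jump_none=True):
    # Condensed from the listing: concatenate flattened gradients in parameter order.
    grads = []
    for par in network.parameters():
        if par.grad is None and jump_none:
            continue
        grads.append(par.grad.data.view(-1))
    return torch.cat(grads)


def apply_gradient_vector(network, grad_vec, jump_none=True):
    # Condensed from the listing: write consecutive slices back into par.grad.
    start = 0
    for par in network.parameters():
        if par.grad is None and jump_none:
            continue
        end = start + par.grad.numel()
        par.grad.data = grad_vec[start:end].view(par.grad.shape)
        start = end


torch.manual_seed(0)
# Hypothetical model: the first layer is frozen, so its gradients stay None.
net = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.Linear(3, 1))
for par in net[0].parameters():
    par.requires_grad_(False)

x, y = torch.randn(8, 2), torch.randn(8, 1)
loss_fns = [
    lambda: ((net(x) - y) ** 2).mean(),  # hypothetical loss 1
    lambda: net(x).abs().mean(),         # hypothetical loss 2
]

grad_vecs = []
for loss_fn in loss_fns:
    net.zero_grad()
    loss_fn().backward()
    grad_vecs.append(get_gradient_vector(net).clone())

# Stand-in combination rule (simple mean), NOT the ConFIG update.
combined = torch.stack(grad_vecs).mean(dim=0)
apply_gradient_vector(net, combined)
```

Because both calls use the default `jump_none=True`, the frozen layer is skipped consistently and the combined vector only covers the 4 trainable values of the second layer.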

Math Utility Functions¤

conflictfree.utils.get_cos_similarity ¤

get_cos_similarity(
    vector1: torch.Tensor, vector2: torch.Tensor
) -> torch.Tensor

Calculates the cosine similarity (the cosine of the angle) between two vectors.

Parameters:

Name Type Description Default
vector1 Tensor

The first vector.

required
vector2 Tensor

The second vector.

required

Returns:

Type Description
Tensor

torch.Tensor: The cosine similarity between the two vectors.

Source code in conflictfree/utils.py
def get_cos_similarity(vector1:torch.Tensor,vector2:torch.Tensor)->torch.Tensor:
    """
    Calculates the cosine similarity (the cosine of the angle) between two vectors.

    Args:
        vector1 (torch.Tensor): The first vector.
        vector2 (torch.Tensor): The second vector.

    Returns:
        torch.Tensor: The cosine similarity between the two vectors.
    """
    with torch.no_grad():
        return torch.dot(vector1,vector2)/vector1.norm()/vector2.norm()
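
A small worked example may help to read the returned value. The function body is restated from the listing so the snippet is self-contained; the three vectors are arbitrary illustrative inputs.

```python
import torch


def get_cos_similarity(vector1: torch.Tensor, vector2: torch.Tensor) -> torch.Tensor:
    # Same formula as the listing: dot product divided by both Euclidean norms.
    with torch.no_grad():
        return torch.dot(vector1, vector2) / vector1.norm() / vector2.norm()


a = torch.tensor([1.0, 0.0])
b = torch.tensor([0.0, 2.0])
c = torch.tensor([-3.0, 0.0])

cos_ab = get_cos_similarity(a, b)  # orthogonal vectors
cos_ac = get_cos_similarity(a, c)  # opposite directions
cos_aa = get_cos_similarity(a, a)  # identical direction
```

Orthogonal vectors give 0, opposite vectors give -1, and identical directions give 1, regardless of the vectors' lengths. Since the function divides by both norms, a zero vector as input produces `nan`.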

conflictfree.utils.unit_vector ¤

unit_vector(
    vector: torch.Tensor, warn_zero=False
) -> torch.Tensor

Compute the unit vector of a given tensor.

Parameters:

Name Type Description Default
vector Tensor

The input tensor.

required
warn_zero bool

Whether to print a warning when the input tensor is zero. default: False

False

Returns:

Type Description
Tensor

torch.Tensor: The unit vector of the input tensor, or a zero tensor of the same shape if the input has zero norm.

Source code in conflictfree/utils.py
def unit_vector(vector: torch.Tensor, warn_zero=False)->torch.Tensor:
    """
    Compute the unit vector of a given tensor.

    Parameters:
        vector (torch.Tensor): The input tensor.
        warn_zero (bool): Whether to print a warning when the input tensor is zero. default: False

    Returns:
        torch.Tensor: The unit vector of the input tensor.
    """
    with torch.no_grad():
        if vector.norm()==0:
            if warn_zero:
                print("Detected zero vector when doing normalization.")
            return torch.zeros_like(vector)
        else:
            return vector / vector.norm()
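
To illustrate both branches, the sketch below normalizes a non-zero vector and a zero vector. The function is restated from the listing so the snippet runs on its own; the inputs are illustrative.

```python
import torch


def unit_vector(vector: torch.Tensor, warn_zero: bool = False) -> torch.Tensor:
    # Restated from the listing: guard against division by a zero norm.
    with torch.no_grad():
        if vector.norm() == 0:
            if warn_zero:
                print("Detected zero vector when doing normalization.")
            return torch.zeros_like(vector)
        return vector / vector.norm()


u = unit_vector(torch.tensor([3.0, 4.0]))  # norm is 5, so u is [0.6, 0.8]
z = unit_vector(torch.zeros(2))            # zero input comes back as zeros
```

Returning zeros instead of dividing by zero keeps `nan` values out of later computations; pass `warn_zero=True` if such a case should be reported.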

Conflict Utility Functions¤

conflictfree.utils.estimate_conflict ¤

estimate_conflict(gradients: torch.Tensor) -> torch.Tensor

Estimates the degree of conflict of gradients.

Parameters:

Name Type Description Default
gradients Tensor

A 2-D tensor containing the gradients, one gradient vector per row.

required

Returns:

Type Description
Tensor

torch.Tensor: A tensor consisting of the dot products between the unit vector of the summed gradients and each normalized gradient, i.e. their cosine similarities.

Source code in conflictfree/utils.py
def estimate_conflict(gradients: torch.Tensor)->torch.Tensor:
    """
    Estimates the degree of conflict of gradients.

    Args:
        gradients (torch.Tensor): A tensor containing gradients.

    Returns:
        torch.Tensor: A tensor consisting of the dot products between the sum of gradients and each sub-gradient.
    """
    direct_sum = unit_vector(gradients.sum(dim=0))
    unit_grads = gradients / torch.norm(gradients, dim=1).view(-1, 1)
    return unit_grads @ direct_sum
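
The sketch below evaluates the conflict estimate on two hand-picked gradient sets so the numbers can be checked by hand. Both functions are restated from the listings; the gradient values are illustrative.

```python
import torch


def unit_vector(vector: torch.Tensor) -> torch.Tensor:
    # Restated from the listing (without the optional warning).
    if vector.norm() == 0:
        return torch.zeros_like(vector)
    return vector / vector.norm()


def estimate_conflict(gradients: torch.Tensor) -> torch.Tensor:
    # Restated from the listing: cosine between each row and the summed gradient.
    direct_sum = unit_vector(gradients.sum(dim=0))
    unit_grads = gradients / torch.norm(gradients, dim=1).view(-1, 1)
    return unit_grads @ direct_sum


# Two orthogonal gradients: each makes a 45 degree angle with their sum.
aligned = estimate_conflict(torch.tensor([[1.0, 0.0], [0.0, 1.0]]))

# The sum is [0, 1]: the first gradient is orthogonal to it,
# the second makes a 45 degree angle with it.
mixed = estimate_conflict(torch.tensor([[1.0, 0.0], [-1.0, 1.0]]))
```

Here `aligned` is approximately `[0.7071, 0.7071]` and `mixed` is approximately `[0.0, 0.7071]`. Smaller (or negative) entries indicate that a gradient disagrees more strongly with the direction of the plain sum.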

Slice Selector Classes¤

conflictfree.utils.OrderedSliceSelector ¤

Selects a slice of the source sequence in order. Typically used to select loss functions, gradients, or losses in momentum-based methods when you want to update more than one gradient in a single iteration.

Source code in conflictfree/utils.py
class OrderedSliceSelector:
    """
    Selects a slice of the source sequence in order.
    Usually used for selecting loss functions/gradients/losses in momentum-based methods if you want to update more than one gradient in a single iteration.

    """
    def __init__(self):
        self.start_index=0

    def select(self, n:int, source_sequence:Sequence) -> Tuple[Sequence,Union[float,Sequence]]:
        """
        Selects a slice of the source sequence in order.

        Args:
            n (int): The length of the target slice.
            source_sequence (Sequence): The source sequence to select from.

        Returns:
            Tuple[Sequence,Union[float,Sequence]]: A tuple containing the indexes of the selected slice and the selected slice.
        """
        if n > len(source_sequence):
            raise ValueError("n must be less than or equal to the length of the source sequence")
        end_index = self.start_index + n
        if end_index > len(source_sequence)-1:
            new_start=end_index-len(source_sequence)
            indexes = list(range(self.start_index,len(source_sequence)))+list(range(0,new_start))
            self.start_index=new_start
        else:
            indexes = list(range(self.start_index,end_index))
            self.start_index=end_index
        if len(indexes)==1:
            return indexes,source_sequence[indexes[0]]
        else:
            return indexes,[source_sequence[i] for i in indexes]
start_index instance-attribute ¤
start_index = 0
__init__ ¤
__init__()
select ¤
select(
    n: int, source_sequence: Sequence
) -> Tuple[Sequence, Union[float, Sequence]]

Selects a slice of the source sequence in order.

Parameters:

Name Type Description Default
n int

The length of the target slice.

required
source_sequence Sequence

The source sequence to select from.

required

Returns:

Type Description
Tuple[Sequence, Union[float, Sequence]]

Tuple[Sequence,Union[float,Sequence]]: A tuple containing the indexes of the selected slice and the selected slice. When n is 1, the second item is the single selected element rather than a list.
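
The wrap-around behavior is easiest to see on a short sequence. The class below is restated from the listing so the sketch runs standalone; the loss names are hypothetical placeholders.

```python
from typing import Sequence


class OrderedSliceSelector:
    def __init__(self):
        self.start_index = 0

    def select(self, n: int, source_sequence: Sequence):
        if n > len(source_sequence):
            raise ValueError("n must be less than or equal to the length of the source sequence")
        end_index = self.start_index + n
        if end_index > len(source_sequence) - 1:
            # Reached the end: take the tail, then continue from the front.
            new_start = end_index - len(source_sequence)
            indexes = list(range(self.start_index, len(source_sequence))) + list(range(0, new_start))
            self.start_index = new_start
        else:
            indexes = list(range(self.start_index, end_index))
            self.start_index = end_index
        if len(indexes) == 1:
            return indexes, source_sequence[indexes[0]]
        return indexes, [source_sequence[i] for i in indexes]


losses = ["loss_a", "loss_b", "loss_c"]  # hypothetical placeholders
selector = OrderedSliceSelector()

first = selector.select(2, losses)   # ([0, 1], ["loss_a", "loss_b"])
second = selector.select(2, losses)  # wraps around: ([2, 0], ["loss_c", "loss_a"])
third = selector.select(1, losses)   # single element: ([1], "loss_b")
```

The selector keeps its position between calls, so successive selections cycle through the sequence without skipping entries.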

conflictfree.utils.RandomSliceSelector ¤

Selects a slice of the source sequence randomly. Typically used to select loss functions, gradients, or losses in momentum-based methods when you want to update more than one gradient in a single iteration.

Source code in conflictfree/utils.py
class RandomSliceSelector:
    """
    Selects a slice of the source sequence randomly.
    Usually used for selecting loss functions/gradients/losses in momentum-based methods if you want to update more than one gradient in a single iteration.
    """

    def select(self, n:int, source_sequence:Sequence)-> Tuple[Sequence,Union[float,Sequence]]:
        """
        Selects a slice of the source sequence randomly.

        Args:
            n (int): The length of the target slice.
            source_sequence (Sequence): The source sequence to select from.

        Returns:
            Tuple[Sequence,Union[float,Sequence]]: A tuple containing the indexes of the selected slice and the selected slice.
        """
        if n > len(source_sequence):
            raise ValueError("n must be less than or equal to the length of the source sequence")
        indexes = np.random.choice(len(source_sequence),n,replace=False)
        if len(indexes)==1:
            return indexes,source_sequence[indexes[0]]
        else:
            return indexes,[source_sequence[i] for i in indexes]
select ¤
select(
    n: int, source_sequence: Sequence
) -> Tuple[Sequence, Union[float, Sequence]]

Selects a slice of the source sequence randomly.

Parameters:

Name Type Description Default
n int

The length of the target slice.

required
source_sequence Sequence

The source sequence to select from.

required

Returns:

Type Description
Tuple[Sequence, Union[float, Sequence]]

Tuple[Sequence,Union[float,Sequence]]: A tuple containing the indexes of the selected slice and the selected slice. When n is 1, the second item is the single selected element rather than a list.
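
A brief usage sketch for the random selector follows. The class is restated from the listing; the loss names and the fixed seed are assumptions made only to keep the example reproducible.

```python
from typing import Sequence

import numpy as np


class RandomSliceSelector:
    def select(self, n: int, source_sequence: Sequence):
        if n > len(source_sequence):
            raise ValueError("n must be less than or equal to the length of the source sequence")
        # Draw n distinct positions; the result is a NumPy array, not a list.
        indexes = np.random.choice(len(source_sequence), n, replace=False)
        if len(indexes) == 1:
            return indexes, source_sequence[indexes[0]]
        return indexes, [source_sequence[i] for i in indexes]


np.random.seed(0)  # assumption: seeded only for reproducibility
losses = ["loss_a", "loss_b", "loss_c", "loss_d"]  # hypothetical placeholders
selector = RandomSliceSelector()

idx, picked = selector.select(2, losses)   # two distinct entries
idx_one, single = selector.select(1, losses)  # one entry, returned unwrapped
```

Unlike the ordered selector, the indexes come back as a NumPy array, and sampling without replacement guarantees that no entry is picked twice within one call.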

Others¤

conflictfree.utils.has_zero ¤

has_zero(lists: Sequence) -> bool

Check if any element in the list is zero.

Parameters:

Name Type Description Default
lists Sequence

A list of elements.

required

Returns:

Name Type Description
bool bool

True if any element is zero, False otherwise.

Source code in conflictfree/utils.py
def has_zero(lists:Sequence)->bool:
    """
    Check if any element in the list is zero.

    Args:
        lists (Sequence): A list of elements.

    Returns:
        bool: True if any element is zero, False otherwise.
    """
    for i in lists:
        if i==0:
            return True
    return False
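
For completeness, a minimal usage sketch with illustrative inputs; the function is restated from the listing so the snippet runs on its own.

```python
from typing import Sequence


def has_zero(lists: Sequence) -> bool:
    # Restated from the listing: stop at the first element equal to zero.
    for i in lists:
        if i == 0:
            return True
    return False


with_zero = has_zero([0.5, 0.0, 2.0])     # True
without_zero = has_zero([0.5, 1.0, 2.0])  # False
empty = has_zero([])                      # False: nothing to compare
```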