Skip to content

4.6. Utils

The utils module contains utility functions for the ConFIG algorithm.

Network Utility Functions¤

conflictfree.utils.get_para_vector ¤

get_para_vector(network) -> torch.Tensor

Returns the parameter vector of the given network.

Parameters:

Name Type Description Default
network Module

The network whose parameters are flattened into the parameter vector.

required

Returns:

Type Description
Tensor

torch.Tensor: The parameter vector of the network.

Source code in conflictfree/utils.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def get_para_vector(network) -> torch.Tensor:
    """
    Returns the parameter vector of the given network.

    Every parameter is flattened and the results are concatenated in the order
    yielded by ``network.parameters()``. The returned tensor is always a fresh
    copy, so modifying it never modifies the network's parameters.

    Args:
        network (torch.nn.Module): The network whose parameters are flattened into a vector.

    Returns:
        torch.Tensor: The parameter vector of the network, or ``None`` if the
        network has no parameters.
    """
    with torch.no_grad():
        # Gather all flattened parameters first and concatenate once; calling
        # torch.cat inside the loop re-copies the growing vector on every step.
        # reshape (unlike view) also copes with non-contiguous parameters.
        flat_params = [par.data.reshape(-1) for par in network.parameters()]
        if not flat_params:
            return None
        return torch.cat(flat_params)

conflictfree.utils.apply_para_vector ¤

apply_para_vector(
    network: torch.nn.Module, para_vec: torch.Tensor
) -> None

Applies a parameter vector to the network's parameters.

Parameters:

Name Type Description Default
network Module

The network to apply the parameter vector to.

required
para_vec Tensor

The parameter vector to apply.

required
Source code in conflictfree/utils.py
63
64
65
66
67
68
69
70
71
72
73
74
75
76
def apply_para_vector(network:torch.nn.Module,para_vec:torch.Tensor)->None:
    """
    Applies a parameter vector to the network's parameters.

    The vector is split in the order of ``network.parameters()``. Each chunk is
    reshaped to the corresponding parameter's shape and installed as that
    parameter's data. The parameters then share storage with ``para_vec``
    (they are views into it), so later in-place changes to ``para_vec`` are
    visible in the network.

    Args:
        network (torch.nn.Module): The network to apply the parameter vector to.
        para_vec (torch.Tensor): The parameter vector to apply. Its number of
            elements must equal the total number of parameter elements.

    Raises:
        ValueError: If the number of elements in ``para_vec`` does not match
            the total number of parameter elements in ``network``.
    """
    with torch.no_grad():
        params = list(network.parameters())
        expected = sum(par.numel() for par in params)
        # Validate before touching any parameter: a mismatched vector must be
        # neither silently truncated nor leave the network half-updated.
        if para_vec.numel() != expected:
            raise ValueError(
                f"para_vec has {para_vec.numel()} elements but the network has {expected} parameters"
            )
        start = 0
        for par in params:
            end = start + par.numel()
            par.data = para_vec[start:end].view(par.data.shape)
            start = end

conflictfree.utils.get_gradient_vector ¤

get_gradient_vector(network) -> torch.Tensor

Returns the gradient vector of the given network.

Parameters:

Name Type Description Default
network Module

The network for which to compute the gradient vector.

required

Returns:

Type Description
Tensor

torch.Tensor: The gradient vector of the network.

Source code in conflictfree/utils.py
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
def get_gradient_vector(network)->torch.Tensor:
    """
    Returns the gradient vector of the given network.

    Gradients are flattened and concatenated in the order of
    ``network.parameters()``. A parameter whose ``.grad`` is ``None``
    contributes a block of zeros of its size. This keeps the layout of the
    vector identical to the parameter layout instead of raising.

    Args:
        network (torch.nn.Module): The network for which to compute the gradient vector.

    Returns:
        torch.Tensor: The gradient vector of the network, or ``None`` if the
        network has no parameters.
    """
    with torch.no_grad():
        # Collect all flattened gradients and concatenate once (a torch.cat per
        # parameter would re-copy the growing vector on every iteration).
        flat_grads = [
            torch.zeros(par.numel(), dtype=par.dtype, device=par.device)
            if par.grad is None
            else par.grad.data.reshape(-1)
            for par in network.parameters()
        ]
        if not flat_grads:
            return None
        return torch.cat(flat_grads)

conflictfree.utils.apply_gradient_vector ¤

apply_gradient_vector(
    network: torch.nn.Module, grad_vec: torch.Tensor
) -> None

Applies a gradient vector to the network's parameters.

Parameters:

Name Type Description Default
network Module

The network to apply the gradient vector to.

required
grad_vec Tensor

The gradient vector to apply.

required
Source code in conflictfree/utils.py
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
def apply_gradient_vector(network:torch.nn.Module,grad_vec:torch.Tensor)->None:
    """
    Applies a gradient vector to the network's parameters.

    The vector is split in the order of ``network.parameters()``. Each chunk is
    reshaped to the corresponding parameter's shape and installed as that
    parameter's gradient. Installed gradients are views into ``grad_vec``
    unless a dtype/device conversion was needed for a parameter that had no
    gradient yet.

    Args:
        network (torch.nn.Module): The network to apply the gradient vector to.
        grad_vec (torch.Tensor): The gradient vector to apply. Its number of
            elements must equal the total number of parameter elements.

    Raises:
        ValueError: If the number of elements in ``grad_vec`` does not match
            the total number of parameter elements in ``network``.
    """
    with torch.no_grad():
        params = list(network.parameters())
        expected = sum(par.numel() for par in params)
        # Check up front so a mismatched vector is neither silently truncated
        # nor applied to only some of the parameters before failing.
        if grad_vec.numel() != expected:
            raise ValueError(
                f"grad_vec has {grad_vec.numel()} elements but the network has {expected} parameters"
            )
        start = 0
        for par in params:
            end = start + par.numel()
            chunk = grad_vec[start:end].view(par.shape)
            if par.grad is None:
                # No gradient tensor exists yet, so there is no .data to rebind;
                # attach the chunk as a new gradient matching the parameter's
                # dtype/device (a no-op conversion when they already match).
                par.grad = chunk.to(dtype=par.dtype, device=par.device)
            else:
                par.grad.data = chunk
            start = end

Math Utility Functions¤

conflictfree.utils.get_cos_similarity ¤

get_cos_similarity(
    vector1: torch.Tensor, vector2: torch.Tensor
) -> torch.Tensor

Calculates the cosine similarity (the cosine of the angle) between two vectors.

Parameters:

Name Type Description Default
vector1 Tensor

The first vector.

required
vector2 Tensor

The second vector.

required

Returns:

Type Description
Tensor

torch.Tensor: The cosine of the angle between the two vectors.

Source code in conflictfree/utils.py
78
79
80
81
82
83
84
85
86
87
88
89
90
def get_cos_similarity(vector1:torch.Tensor,vector2:torch.Tensor)->torch.Tensor:
    """
    Calculates the cosine of the angle between two vectors.

    Args:
        vector1 (torch.Tensor): The first vector.
        vector2 (torch.Tensor): The second vector.

    Returns:
        torch.Tensor: ``dot(vector1, vector2) / (|vector1| * |vector2|)``.
        No epsilon is added, so a zero-norm input yields NaN.
    """
    with torch.no_grad():
        inner = torch.dot(vector1, vector2)
        # Divide by each norm in turn (same operation order as a single chained expression).
        scaled_by_first = inner / vector1.norm()
        return scaled_by_first / vector2.norm()

conflictfree.utils.unit_vector ¤

unit_vector(
    vector: torch.Tensor, warn_zero=False
) -> torch.Tensor

Compute the unit vector of a given tensor.

Parameters:

Name Type Description Default
vector Tensor

The input tensor.

required
warn_zero bool

Whether to print a warning when the input tensor is zero. default: False

False

Returns:

Type Description
Tensor

torch.Tensor: The unit vector of the input tensor.

Source code in conflictfree/utils.py
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
def unit_vector(vector: torch.Tensor, warn_zero=False)->torch.Tensor:
    """
    Normalize a tensor to unit Euclidean norm.

    Parameters:
        vector (torch.Tensor): The input tensor.
        warn_zero (bool): If True, print a message when the input has zero norm. default: False

    Returns:
        torch.Tensor: ``vector / vector.norm()``, or a tensor of zeros shaped
        like ``vector`` when its norm is zero.
    """
    with torch.no_grad():
        length = vector.norm()
        if length != 0:
            return vector / length
        # Zero vector: there is no direction to normalize to.
        if warn_zero:
            print("Detected zero vector when doing normalization.")
        return torch.zeros_like(vector)

Conflict Utility Functions¤

conflictfree.utils.estimate_conflict ¤

estimate_conflict(gradients: torch.Tensor) -> torch.Tensor

Estimates the degree of conflict of gradients.

Parameters:

Name Type Description Default
gradients Tensor

A tensor containing gradients.

required

Returns:

Type Description
Tensor

torch.Tensor: A tensor consisting of the dot products between the normalized sum of the gradients and each normalized sub-gradient, i.e. their cosine similarities.

Source code in conflictfree/utils.py
139
140
141
142
143
144
145
146
147
148
149
150
151
def estimate_conflict(gradients: torch.Tensor)->torch.Tensor:
    """
    Estimates the degree of conflict of gradients.

    Each row of ``gradients`` is normalized to unit length and projected onto
    the unit direction of the summed gradient. The result is therefore the
    cosine similarity between every individual gradient and the sum. Rows with
    zero norm yield 0 instead of NaN, matching how ``unit_vector`` treats the
    zero vector.

    Args:
        gradients (torch.Tensor): A 2-D tensor whose rows are the individual gradients.

    Returns:
        torch.Tensor: A 1-D tensor holding, for each gradient, its cosine
        similarity with the sum of all gradients.
    """
    direct_sum = unit_vector(gradients.sum(dim=0))
    row_norms = torch.norm(gradients, dim=1).view(-1, 1)
    # Divide zero-norm rows by 1 instead of 0 so they stay all-zero rather
    # than becoming NaN (0/0) and poisoning the result.
    safe_norms = row_norms.masked_fill(row_norms == 0, 1.0)
    unit_grads = gradients / safe_norms
    return unit_grads @ direct_sum

Slice Selector Classes¤

conflictfree.utils.OrderedSliceSelector ¤

Selects a slice of the source sequence in order. Usually used for selecting loss functions/gradients/losses in momentum-based method if you want to update more than one gradient in a single iteration.

Source code in conflictfree/utils.py
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
class OrderedSliceSelector:
    """
    Selects a slice of the source sequence in order.
    Usually used for selecting loss functions/gradients/losses in momentum-based method if you want to update more than one gradient in a single iteration.

    Each call continues where the previous one stopped and wraps around to the
    start of the sequence once the end is reached (round-robin).
    """
    def __init__(self):
        # Index in the source sequence where the next slice begins.
        self.start_index = 0

    def select(self, n:int, source_sequence:Sequence) -> Tuple[Sequence,Union[float,Sequence]]:
        """
        Selects the next ``n`` elements of the source sequence, wrapping around.

        Args:
            n (int): The length of the target slice.
            source_sequence (Sequence): The source sequence to select from.

        Returns:
            Tuple[Sequence,Union[float,Sequence]]: The list of selected indexes and
            either the single selected element (when exactly one index was
            selected) or a list of the selected elements.

        Raises:
            ValueError: If ``n`` exceeds the length of ``source_sequence``.
        """
        total = len(source_sequence)
        if n > total:
            raise ValueError("n must be less than or equal to the length of the source sequence")
        first = self.start_index
        stop = first + n
        if stop >= total:
            # Slice reaches the end: take the tail, then continue from index 0.
            self.start_index = stop - total
            indexes = [*range(first, total), *range(self.start_index)]
        else:
            self.start_index = stop
            indexes = list(range(first, stop))
        if len(indexes) == 1:
            return indexes, source_sequence[indexes[0]]
        return indexes, [source_sequence[i] for i in indexes]
start_index instance-attribute ¤
start_index = 0
__init__ ¤
__init__()
Source code in conflictfree/utils.py
174
175
def __init__(self):
    """Start the round-robin selection at the first element of the sequence."""
    self.start_index = 0
select ¤
select(
    n: int, source_sequence: Sequence
) -> Tuple[Sequence, Union[float, Sequence]]

Selects a slice of the source sequence in order.

Parameters:

Name Type Description Default
n int

The length of the target slice.

required
source_sequence Sequence

The source sequence to select from.

required

Returns:

Type Description
Tuple[Sequence, Union[float, Sequence]]

Tuple[Sequence,Union[float,Sequence]]: A tuple containing the indexes of the selected slice and the selected slice.

Source code in conflictfree/utils.py
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
def select(self, n:int, source_sequence:Sequence) -> Tuple[Sequence,Union[float,Sequence]]:
    """
    Selects the next ``n`` elements of the source sequence, wrapping around.

    Args:
        n (int): The length of the target slice.
        source_sequence (Sequence): The source sequence to select from.

    Returns:
        Tuple[Sequence,Union[float,Sequence]]: The list of selected indexes and
        either the single selected element (when exactly one index was
        selected) or a list of the selected elements.

    Raises:
        ValueError: If ``n`` exceeds the length of ``source_sequence``.
    """
    total = len(source_sequence)
    if n > total:
        raise ValueError("n must be less than or equal to the length of the source sequence")
    first = self.start_index
    stop = first + n
    if stop >= total:
        # Slice reaches the end: take the tail, then continue from index 0.
        self.start_index = stop - total
        indexes = [*range(first, total), *range(self.start_index)]
    else:
        self.start_index = stop
        indexes = list(range(first, stop))
    if len(indexes) == 1:
        return indexes, source_sequence[indexes[0]]
    return indexes, [source_sequence[i] for i in indexes]

conflictfree.utils.RandomSliceSelector ¤

Selects a slice of the source sequence randomly. Usually used for selecting loss functions/gradients/losses in momentum-based method if you want to update more than one gradient in a single iteration.

Source code in conflictfree/utils.py
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
class RandomSliceSelector:
    """
    Selects a slice of the source sequence randomly.
    Usually used for selecting loss functions/gradients/losses in momentum-based method if you want to update more than one gradient in a single iteration.

    Indexes are drawn without replacement via ``np.random.choice``, i.e. from
    NumPy's global random state.
    """

    def select(self, n:int, source_sequence:Sequence)-> Tuple[Sequence,Union[float,Sequence]]:
        """
        Selects ``n`` distinct elements of the source sequence at random.

        Args:
            n (int): The length of the target slice.
            source_sequence (Sequence): The source sequence to select from.

        Returns:
            Tuple[Sequence,Union[float,Sequence]]: The array of selected indexes
            and either the single selected element (when exactly one index was
            drawn) or a list of the selected elements.

        Raises:
            ValueError: If ``n`` exceeds the length of ``source_sequence``.
        """
        size = len(source_sequence)
        if n > size:
            raise ValueError("n must be less than or equal to the length of the source sequence")
        picked = np.random.choice(size, n, replace=False)
        if len(picked) != 1:
            return picked, [source_sequence[k] for k in picked]
        return picked, source_sequence[picked[0]]
select ¤
select(
    n: int, source_sequence: Sequence
) -> Tuple[Sequence, Union[float, Sequence]]

Selects a slice of the source sequence randomly.

Parameters:

Name Type Description Default
n int

The length of the target slice.

required
source_sequence Sequence

The source sequence to select from.

required

Returns:

Type Description
Tuple[Sequence, Union[float, Sequence]]

Tuple[Sequence,Union[float,Sequence]]: A tuple containing the indexes of the selected slice and the selected slice.

Source code in conflictfree/utils.py
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
def select(self, n:int, source_sequence:Sequence)-> Tuple[Sequence,Union[float,Sequence]]:
    """
    Selects ``n`` distinct elements of the source sequence at random.

    Args:
        n (int): The length of the target slice.
        source_sequence (Sequence): The source sequence to select from.

    Returns:
        Tuple[Sequence,Union[float,Sequence]]: The array of selected indexes
        and either the single selected element (when exactly one index was
        drawn) or a list of the selected elements.

    Raises:
        ValueError: If ``n`` exceeds the length of ``source_sequence``.
    """
    size = len(source_sequence)
    if n > size:
        raise ValueError("n must be less than or equal to the length of the source sequence")
    # Draw without replacement from NumPy's global random state.
    picked = np.random.choice(size, n, replace=False)
    if len(picked) != 1:
        return picked, [source_sequence[k] for k in picked]
    return picked, source_sequence[picked[0]]

Others¤

conflictfree.utils.has_zero ¤

has_zero(lists: Sequence) -> bool

Check if any element in the list is zero.

Parameters:

Name Type Description Default
lists Sequence

A list of elements.

required

Returns:

Name Type Description
bool bool

True if any element is zero, False otherwise.

Source code in conflictfree/utils.py
153
154
155
156
157
158
159
160
161
162
163
164
165
166
def has_zero(lists:Sequence)->bool:
    """
    Report whether the sequence contains an element equal to zero.

    Args:
        lists (Sequence): A list of elements.

    Returns:
        bool: True as soon as an element compares equal to 0, False otherwise
        (including for an empty sequence).
    """
    return any(item == 0 for item in lists)