4.6. Utils
The utils module contains utility functions for the ConFIG algorithm.
Network Utility Functions¤
conflictfree.utils.get_para_vector
get_para_vector(network: torch.nn.Module) -> torch.Tensor
Returns the parameter vector of the given network.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
network | Module | The network for which to compute the parameter vector. | required |
Returns:
Type | Description |
---|---|
Tensor | torch.Tensor: The parameter vector of the network. |
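The documented behavior can be sketched in plain PyTorch. This is an illustrative re-implementation, not the library's actual source; `para_vector_sketch` is a hypothetical name:

```python
import torch

# Illustrative sketch: flatten every parameter of a network into one vector.
def para_vector_sketch(network: torch.nn.Module) -> torch.Tensor:
    return torch.cat([p.detach().reshape(-1) for p in network.parameters()])

net = torch.nn.Linear(3, 2)   # 3*2 weights + 2 biases = 8 parameters
vec = para_vector_sketch(net)
assert vec.numel() == 8
```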
conflictfree.utils.apply_para_vector
apply_para_vector(
    network: torch.nn.Module, para_vec: torch.Tensor
) -> None
Applies a parameter vector to the network's parameters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
network | Module | The network to apply the parameter vector to. | required |
para_vec | Tensor | The parameter vector to apply. | required |
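A minimal sketch of what applying a flat parameter vector amounts to (an assumption about the slicing order, not the library's source): copy consecutive slices of the vector into each parameter tensor in iteration order.

```python
import torch

# Hypothetical re-implementation: write a flat vector back into a network.
def apply_para_vector_sketch(network: torch.nn.Module, para_vec: torch.Tensor) -> None:
    offset = 0
    with torch.no_grad():
        for p in network.parameters():
            n = p.numel()
            p.copy_(para_vec[offset:offset + n].view_as(p))
            offset += n

src = torch.nn.Linear(4, 2)
dst = torch.nn.Linear(4, 2)
flat = torch.cat([p.detach().reshape(-1) for p in src.parameters()])
apply_para_vector_sketch(dst, flat)  # dst now carries src's weights
```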
conflictfree.utils.get_gradient_vector
get_gradient_vector(
    network: torch.nn.Module,
    none_grad_mode: Literal["raise", "zero", "skip"] = "skip",
) -> torch.Tensor
Returns the gradient vector of the given network.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
network | Module | The network for which to compute the gradient vector. | required |
none_grad_mode | Literal['raise', 'zero', 'skip'] | How to handle None gradients: 'raise' raises an error when a parameter's gradient is None; 'zero' replaces a None gradient with a zero tensor; 'skip' skips it. | 'skip' |
A None gradient usually occurs when part of the network is not trainable (e.g., fine-tuning) or a weight is not used to calculate the current loss (e.g., different parts of the network calculate different losses). If all of your losses are calculated using the same part of the network, set none_grad_mode to 'skip'. If your losses are calculated using different parts of the network, set none_grad_mode to 'zero' to ensure the gradient vectors all have the same shape.
Returns:
Type | Description |
---|---|
Tensor | torch.Tensor: The gradient vector of the network. |
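The difference between 'skip' and 'zero' can be sketched with a two-branch network where only one branch contributes to the loss. This is an illustrative re-implementation of the documented modes, not the library's source:

```python
import torch

# Hypothetical sketch of the documented none_grad_mode handling.
def gradient_vector_sketch(network: torch.nn.Module, none_grad_mode: str = "skip") -> torch.Tensor:
    grads = []
    for p in network.parameters():
        if p.grad is None:
            if none_grad_mode == "raise":
                raise RuntimeError("parameter has no gradient")
            if none_grad_mode == "zero":
                grads.append(torch.zeros(p.numel()))
            # "skip": contribute nothing for this parameter
        else:
            grads.append(p.grad.reshape(-1))
    return torch.cat(grads)

net = torch.nn.ModuleDict({"a": torch.nn.Linear(2, 1), "b": torch.nn.Linear(2, 1)})
net["a"](torch.ones(2)).sum().backward()       # only branch "a" receives gradients
skipped = gradient_vector_sketch(net, "skip")  # branch "a" only: 3 entries
padded = gradient_vector_sketch(net, "zero")   # branch "b" zero-padded: 6 entries
```

With losses computed on different sub-networks, 'zero' is what keeps the vectors shape-compatible across losses.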
conflictfree.utils.apply_gradient_vector
apply_gradient_vector(
    network: torch.nn.Module,
    grad_vec: torch.Tensor,
    none_grad_mode: Literal["zero", "skip"] = "skip",
    zero_grad_mode: Literal[
        "skip", "pad_zero", "pad_value"
    ] = "pad_value",
) -> None
Applies a gradient vector to the network's parameters.
This function requires the network to already contain gradient information in order to apply the gradient vector.
If your network does not contain gradient information, consider using the apply_gradient_vector_para_based function instead.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
network | Module | The network to apply the gradient vector to. | required |
grad_vec | Tensor | The gradient vector to apply. | required |
none_grad_mode | Literal['zero', 'skip'] | How to handle None gradients. Set this parameter to the same value as the one used in get_gradient_vector. | 'skip' |
zero_grad_mode | Literal['skip', 'pad_zero', 'pad_value'] | How to set the value of the gradient if none_grad_mode is 'zero'. | 'pad_value' |
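A typical usage pattern looks like this sketch: run backward, flatten the gradients, replace them with a modified direction, write the result back into the `.grad` fields, and let the optimizer step. The write-back loop is a hypothetical stand-in for the documented function:

```python
import torch

net = torch.nn.Linear(3, 1)
opt = torch.optim.SGD(net.parameters(), lr=0.1)

net(torch.ones(3)).sum().backward()
flat = torch.cat([p.grad.reshape(-1) for p in net.parameters()])
flat = flat * 0.5  # stand-in for a real conflict-free update direction

offset = 0
for p in net.parameters():  # write the modified vector back into .grad
    n = p.numel()
    p.grad = flat[offset:offset + n].view_as(p).clone()
    offset += n
opt.step()
```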
conflictfree.utils.apply_gradient_vector_para_based
apply_gradient_vector_para_based(
    network: torch.nn.Module, grad_vec: torch.Tensor
) -> None
Applies a gradient vector to the network's parameters.
Please only use this function when you are sure that the length of grad_vec matches the total number of parameters in your network.
This happens when you use get_gradient_vector with none_grad_mode set to 'zero', or when none_grad_mode is 'skip' but all of the parameters in your network are involved in the loss calculation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
network | Module | The network to apply the gradient vector to. | required |
grad_vec | Tensor | The gradient vector to apply. | required |
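The key property is that the parameter-based variant slices grad_vec by parameter shape alone, so it works on a network that has no existing `.grad` fields, as long as the vector covers every parameter. A hypothetical sketch:

```python
import torch

# Illustrative sketch: assign grad slices purely from parameter shapes.
def apply_para_based_sketch(network: torch.nn.Module, grad_vec: torch.Tensor) -> None:
    total = sum(p.numel() for p in network.parameters())
    assert grad_vec.numel() == total, "grad_vec must match the full parameter count"
    offset = 0
    for p in network.parameters():
        n = p.numel()
        p.grad = grad_vec[offset:offset + n].view_as(p).clone()
        offset += n

net = torch.nn.Linear(2, 2)  # 4 weights + 2 biases; no backward() has run
apply_para_based_sketch(net, torch.arange(6.0))
```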
Math Utility Functions
conflictfree.utils.get_cos_similarity
get_cos_similarity(
    vector1: torch.Tensor, vector2: torch.Tensor
) -> torch.Tensor
Calculates the cosine similarity between two vectors.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
vector1 | Tensor | The first vector. | required |
vector2 | Tensor | The second vector. | required |
Returns:
Type | Description |
---|---|
Tensor | torch.Tensor: The cosine similarity between the two vectors. |
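The quantity is the standard cosine of the angle between two flat gradient vectors; a quick sketch (illustrative, not the library's source):

```python
import torch

# Cosine similarity: dot product of the vectors over the product of norms.
def cos_similarity_sketch(v1: torch.Tensor, v2: torch.Tensor) -> torch.Tensor:
    return torch.dot(v1, v2) / (v1.norm() * v2.norm())

a = torch.tensor([1.0, 0.0])
b = torch.tensor([0.0, 1.0])
orthogonal = cos_similarity_sketch(a, b)  # orthogonal gradients -> 0
parallel = cos_similarity_sketch(a, a)    # identical gradients -> 1
```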
conflictfree.utils.unit_vector
unit_vector(
    vector: torch.Tensor, warn_zero: bool = False
) -> torch.Tensor
Compute the unit vector of a given tensor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
vector | Tensor | The input tensor. | required |
warn_zero | bool | Whether to print a warning when the input tensor is zero. | False |
Returns:
Type | Description |
---|---|
Tensor | torch.Tensor: The unit vector of the input tensor. |
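A sketch of the normalization, with the zero-input handling shown here being an assumption (returning the input unchanged rather than dividing by zero):

```python
import torch
import warnings

# Illustrative sketch: normalize a tensor to unit length.
def unit_vector_sketch(vector: torch.Tensor, warn_zero: bool = False) -> torch.Tensor:
    norm = vector.norm()
    if norm == 0:
        if warn_zero:
            warnings.warn("input vector is zero")
        return vector  # assumed behavior for a zero input
    return vector / norm

v = unit_vector_sketch(torch.tensor([3.0, 4.0]))  # norm 5 -> [0.6, 0.8]
```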
Conflict Utility Functions
conflictfree.utils.estimate_conflict
estimate_conflict(gradients: torch.Tensor) -> torch.Tensor
Estimates the degree of conflict of gradients.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
gradients | Tensor | A tensor containing gradients. | required |
Returns:
Type | Description |
---|---|
Tensor | torch.Tensor: A tensor consisting of the dot products between the sum of the gradients and each sub-gradient. |
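One plausible reading of this quantity, sketched below under the assumption that both the summed direction and the sub-gradients are unit-normalized before taking dot products (values near 1 mean agreement, lower values mean conflict):

```python
import torch

# Hypothetical sketch: dot products of each unit sub-gradient with the
# unit vector of the direct gradient sum.
def estimate_conflict_sketch(gradients: torch.Tensor) -> torch.Tensor:
    direct_sum = gradients.sum(dim=0)
    direct_sum = direct_sum / direct_sum.norm()
    unit_grads = gradients / gradients.norm(dim=1, keepdim=True)
    return unit_grads @ direct_sum

g = torch.tensor([[1.0, 0.0], [0.0, 1.0]])  # two orthogonal gradients
conflict = estimate_conflict_sketch(g)       # both entries cos(45°) ≈ 0.707
```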
Slice Selector Classes
conflictfree.utils.OrderedSliceSelector
Selects a slice of the source sequence in order. Usually used for selecting loss functions/gradients/losses in momentum-based methods if you want to update more than one gradient in a single iteration.
__init__
__init__()
select
select(
    n: int, source_sequence: Sequence
) -> Tuple[Sequence, Union[float, Sequence]]
Selects a slice of the source sequence in order.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
n | int | The length of the target slice. | required |
source_sequence | Sequence | The source sequence to select from. | required |
Returns:
Type | Description |
---|---|
Tuple[Sequence, Union[float, Sequence]] | A tuple containing the indexes of the selected slice and the selected slice. |
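The selector's contract can be sketched as follows. The wrap-around bookkeeping and the single-element return for n == 1 are assumptions for illustration, not the library's source:

```python
from typing import Sequence, Tuple, Union

# Hypothetical sketch: each call returns the next n indexes (wrapping
# around the sequence) together with the corresponding items.
class OrderedSliceSelectorSketch:
    def __init__(self) -> None:
        self.start = 0

    def select(self, n: int, source_sequence: Sequence) -> Tuple[Sequence, Union[float, Sequence]]:
        indexes = [(self.start + i) % len(source_sequence) for i in range(n)]
        self.start = (self.start + n) % len(source_sequence)
        selected = [source_sequence[i] for i in indexes]
        return indexes, selected[0] if n == 1 else selected

selector = OrderedSliceSelectorSketch()
idx1, sel1 = selector.select(2, ["loss_a", "loss_b", "loss_c"])  # [0, 1]
idx2, sel2 = selector.select(2, ["loss_a", "loss_b", "loss_c"])  # wraps: [2, 0]
```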
conflictfree.utils.RandomSliceSelector
Selects a slice of the source sequence randomly. Usually used for selecting loss functions/gradients/losses in momentum-based methods if you want to update more than one gradient in a single iteration.
select
select(
    n: int, source_sequence: Sequence
) -> Tuple[Sequence, Union[float, Sequence]]
Selects a slice of the source sequence randomly.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
n | int | The length of the target slice. | required |
source_sequence | Sequence | The source sequence to select from. | required |
Returns:
Type | Description |
---|---|
Tuple[Sequence, Union[float, Sequence]] | A tuple containing the indexes of the selected slice and the selected slice. |
Others
conflictfree.utils.has_zero
has_zero(lists: Sequence) -> bool
Check if any element in the list is zero.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
lists | Sequence | A list of elements. | required |
Returns:
Name | Type | Description |
---|---|---|
bool | bool | True if any element is zero, False otherwise. |
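The documented check is equivalent to this one-liner in plain Python (an illustrative sketch, not the library's source):

```python
from typing import Sequence

# Illustrative sketch: True if any element of the sequence equals zero.
def has_zero_sketch(lists: Sequence) -> bool:
    return any(x == 0 for x in lists)

assert has_zero_sketch([1, 0, 2]) is True
assert has_zero_sketch([1, 2, 3]) is False
```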