4.4. Length Model

The length_model module contains classes that rescale the magnitude of the final gradient vector. ProjectionLength is the default length model for the ConFIG algorithm. To create a custom length model, inherit from the LengthModel class and implement its get_length method.
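As a rough illustration of the subclassing pattern, the sketch below defines a hypothetical `MinNormLength` model that uses the smallest gradient norm as the length. To stay runnable without PyTorch, it works on plain Python lists and uses a stand-in base class, `LengthModelSketch`, that mirrors the `LengthModel` interface documented on this page. Neither name is part of the library, and `unit_vector` is assumed to divide by the Euclidean norm. With the real library you would inherit from `conflictfree.length_model.LengthModel` and work with tensors instead.

```python
import math
from typing import Sequence

Vector = Sequence[float]


def unit_vector(vector: Vector) -> list:
    """Return vector / ||vector|| (assumes a nonzero vector)."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector]


class LengthModelSketch:
    """Stand-in for LengthModel: subclasses only implement get_length."""

    def get_length(self, target_vector=None, unit_target_vector=None,
                   gradients=None, losses=None) -> float:
        raise NotImplementedError("This method must be implemented by the subclass.")

    def rescale_length(self, target_vector, gradients=None, losses=None) -> list:
        # Same flow as the documented rescale_length: length times unit vector.
        unit_target_vector = unit_vector(target_vector)
        length = self.get_length(target_vector=target_vector,
                                 unit_target_vector=unit_target_vector,
                                 gradients=gradients,
                                 losses=losses)
        return [length * x for x in unit_target_vector]


class MinNormLength(LengthModelSketch):
    """Hypothetical custom model: the length is the smallest gradient norm."""

    def get_length(self, target_vector=None, unit_target_vector=None,
                   gradients=None, losses=None) -> float:
        if gradients is None:
            raise ValueError("MinNormLength requires gradients information.")
        return min(math.sqrt(sum(x * x for x in grad)) for grad in gradients)


rescaled = MinNormLength().rescale_length(
    target_vector=[3.0, 4.0],
    gradients=[[3.0, 4.0], [0.0, 1.0]],
)
print(rescaled)
```

Only `get_length` is overridden; the inherited `rescale_length` keeps the direction of the target vector and swaps in the new magnitude.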

Length Model

conflictfree.length_model.ProjectionLength

Bases: LengthModel

Rescale the length of the target vector based on the projection of the gradients on the target vector.
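Written out as a formula, this reads as follows. The symbols are introduced here for illustration: $g_i$ is the $i$-th loss-specific gradient, $g_t$ the target vector, and $m$ the number of losses. The expression is read off the source listing of this class and assumes that `unit_vector` normalizes by the Euclidean norm, since its implementation is not shown on this page.

```latex
\hat{u} = \frac{g_t}{\lVert g_t \rVert}, \qquad
\ell = \sum_{i=1}^{m} g_i^{\top} \hat{u}, \qquad
g_{\text{rescaled}} = \ell \, \hat{u}
```

Here $\ell$ is the value returned by `get_length`, and $g_{\text{rescaled}}$ is the vector returned by `rescale_length`.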

Source code in conflictfree/length_model.py
class ProjectionLength(LengthModel):
    """
    Rescale the length of the target vector based on the projection of the gradients on the target vector.
    """

    def __init__(self):
        super().__init__()

    def get_length(
        self,
        target_vector: Optional[torch.Tensor] = None,
        unit_target_vector: Optional[torch.Tensor] = None,
        gradients: Optional[torch.Tensor] = None,
        losses: Optional[Sequence] = None,
    ) -> torch.Tensor:
        """
        Calculates the length based on the given parameters. Not all parameters are required.

        Args:
            target_vector (Optional[torch.Tensor]): The final update gradient vector.
                Either `target_vector` or `unit_target_vector` must be provided.
            unit_target_vector (Optional[torch.Tensor]): The unit vector of the target vector.
                Either `target_vector` or `unit_target_vector` must be provided.
            gradients (Optional[torch.Tensor]): The loss-specific gradients matrix. Required by this model.
            losses (Optional[Sequence]): The losses. Not used in this model.

        Returns:
            torch.Tensor: The calculated length.
        """
        if gradients is None:
            raise ValueError("The ProjectLength model requires gradients information.")
        if unit_target_vector is None:
            unit_target_vector = unit_vector(target_vector)
        # Sum the scalar projections of each loss-specific gradient onto the unit target vector.
        return torch.sum(torch.stack([torch.dot(grad_i, unit_target_vector) for grad_i in gradients]))
rescale_length
rescale_length(
    target_vector: torch.Tensor,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor

Rescales the length of the target vector based on the given parameters. It calls the get_length method to calculate the length and then rescales the target vector.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| target_vector | Tensor | The final update gradient vector. | required |
| gradients | Optional[Tensor] | The loss-specific gradients matrix. | None |
| losses | Optional[Sequence] | The losses. | None |

Returns:

| Type | Description |
| --- | --- |
| Tensor | The rescaled target vector. |

Source code in conflictfree/length_model.py
def rescale_length(self, 
                   target_vector:torch.Tensor,
                   gradients:Optional[torch.Tensor]=None,
                   losses:Optional[Sequence]=None)->torch.Tensor:
    """
    Rescales the length of the target vector based on the given parameters.
    It calls the get_length method to calculate the length and then rescales the target vector.

    Args:
        target_vector (torch.Tensor): The final update gradient vector.
        gradients (Optional[torch.Tensor]): The loss-specific gradients matrix.
        losses (Optional[Sequence]): The losses.

    Returns:
        torch.Tensor: The rescaled target vector.
    """
    unit_target_vector = unit_vector(target_vector)
    return self.get_length(target_vector=target_vector,
                           unit_target_vector=unit_target_vector,
                           gradients=gradients,
                           losses=losses) * unit_target_vector
__init__
__init__()
Source code in conflictfree/length_model.py
def __init__(self):
    super().__init__()
get_length
get_length(
    target_vector: Optional[torch.Tensor] = None,
    unit_target_vector: Optional[torch.Tensor] = None,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor

Calculates the length based on the given parameters. Not all parameters are required.
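A small worked example may help. The sketch below mirrors the sum of dot products in plain Python, using lists instead of tensors so that it runs without PyTorch. The helper names `dot`, `unit_vector` and `projection_length` are illustrative and not part of the library, and `unit_vector` is assumed to divide by the Euclidean norm.

```python
import math


def dot(a, b):
    """Dot product of two equal-length sequences."""
    return sum(x * y for x, y in zip(a, b))


def unit_vector(v):
    """Return v / ||v|| (assumes a nonzero vector)."""
    norm = math.sqrt(dot(v, v))
    return [x / norm for x in v]


def projection_length(gradients, target_vector=None, unit_target_vector=None):
    """Sum of the scalar projections of each gradient onto the unit target vector."""
    if gradients is None:
        raise ValueError("gradients are required")
    if unit_target_vector is None:
        unit_target_vector = unit_vector(target_vector)
    return sum(dot(grad, unit_target_vector) for grad in gradients)


gradients = [[1.0, 0.0], [0.0, 1.0]]  # one row per loss
target = [1.0, 1.0]

length = projection_length(gradients, target_vector=target)
rescaled = [length * x for x in unit_vector(target)]
print(length, rescaled)
```

Both gradients make a 45° angle with the target, so each contributes 1/√2 ≈ 0.7071 and the total length is √2 ≈ 1.4142. Multiplying the unit target vector by that length gives approximately [1.0, 1.0].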

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| target_vector | Optional[Tensor] | The final update gradient vector. Either target_vector or unit_target_vector must be provided. | None |
| unit_target_vector | Optional[Tensor] | The unit vector of the target vector. Either target_vector or unit_target_vector must be provided. | None |
| gradients | Optional[Tensor] | The loss-specific gradients matrix. Required by this model. | None |
| losses | Optional[Sequence] | The losses. Not used in this model. | None |

Returns:

| Type | Description |
| --- | --- |
| Tensor | The calculated length. |

Source code in conflictfree/length_model.py
def get_length(
    self,
    target_vector: Optional[torch.Tensor] = None,
    unit_target_vector: Optional[torch.Tensor] = None,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor:
    """
    Calculates the length based on the given parameters. Not all parameters are required.

    Args:
        target_vector (Optional[torch.Tensor]): The final update gradient vector.
            Either `target_vector` or `unit_target_vector` must be provided.
        unit_target_vector (Optional[torch.Tensor]): The unit vector of the target vector.
            Either `target_vector` or `unit_target_vector` must be provided.
        gradients (Optional[torch.Tensor]): The loss-specific gradients matrix. Required by this model.
        losses (Optional[Sequence]): The losses. Not used in this model.

    Returns:
        torch.Tensor: The calculated length.
    """
    if gradients is None:
        raise ValueError("The ProjectLength model requires gradients information.")
    if unit_target_vector is None:
        unit_target_vector = unit_vector(target_vector)
    # Sum the scalar projections of each loss-specific gradient onto the unit target vector.
    return torch.sum(torch.stack([torch.dot(grad_i, unit_target_vector) for grad_i in gradients]))

Base Class of Length Model

conflictfree.length_model.LengthModel

The base class for length models. Subclasses must implement get_length; rescale_length is inherited.

Methods:

| Name | Description |
| --- | --- |
| get_length | Calculates the length based on the given parameters. |
| rescale_length | Rescales the length of the target vector based on the given parameters. |

Source code in conflictfree/length_model.py
class LengthModel:
    """
    The base class for length model.

    Methods:
        get_length: Calculates the length based on the given parameters.
        rescale_length: Rescales the length of the target vector based on the given parameters.
    """
    def __init__(self):
        pass

    def get_length(self, 
                   target_vector:Optional[torch.Tensor]=None,
                   unit_target_vector:Optional[torch.Tensor]=None,
                   gradients:Optional[torch.Tensor]=None,
                   losses:Optional[Sequence]=None)-> Union[torch.Tensor, float]:
        """
        Calculates the length based on the given parameters. Not all parameters are required.

        Args:
            target_vector (Optional[torch.Tensor]): The final update gradient vector.
            unit_target_vector (Optional[torch.Tensor]): The unit vector of the target vector.
            gradients (Optional[torch.Tensor]): The loss-specific gradients matrix.
            losses (Optional[Sequence]): The losses.

        Returns:
            Union[torch.Tensor, float]: The calculated length.
        """
        raise NotImplementedError("This method must be implemented by the subclass.")

    def rescale_length(self, 
                       target_vector:torch.Tensor,
                       gradients:Optional[torch.Tensor]=None,
                       losses:Optional[Sequence]=None)->torch.Tensor:
        """
        Rescales the length of the target vector based on the given parameters.
        It calls the get_length method to calculate the length and then rescales the target vector.

        Args:
            target_vector (torch.Tensor): The final update gradient vector.
            gradients (Optional[torch.Tensor]): The loss-specific gradients matrix.
            losses (Optional[Sequence]): The losses.

        Returns:
            torch.Tensor: The rescaled target vector.
        """
        unit_target_vector = unit_vector(target_vector)
        return self.get_length(target_vector=target_vector,
                               unit_target_vector=unit_target_vector,
                               gradients=gradients,
                               losses=losses) * unit_target_vector
__init__
__init__()
Source code in conflictfree/length_model.py
def __init__(self):
    pass
get_length
get_length(
    target_vector: Optional[torch.Tensor] = None,
    unit_target_vector: Optional[torch.Tensor] = None,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> Union[torch.Tensor, float]

Calculates the length based on the given parameters. Not all parameters are required.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| target_vector | Optional[Tensor] | The final update gradient vector. | None |
| unit_target_vector | Optional[Tensor] | The unit vector of the target vector. | None |
| gradients | Optional[Tensor] | The loss-specific gradients matrix. | None |
| losses | Optional[Sequence] | The losses. | None |

Returns:

| Type | Description |
| --- | --- |
| Union[Tensor, float] | The calculated length. |

Source code in conflictfree/length_model.py
def get_length(self, 
               target_vector:Optional[torch.Tensor]=None,
               unit_target_vector:Optional[torch.Tensor]=None,
               gradients:Optional[torch.Tensor]=None,
               losses:Optional[Sequence]=None)-> Union[torch.Tensor, float]:
    """
    Calculates the length based on the given parameters. Not all parameters are required.

    Args:
        target_vector (Optional[torch.Tensor]): The final update gradient vector.
        unit_target_vector (Optional[torch.Tensor]): The unit vector of the target vector.
        gradients (Optional[torch.Tensor]): The loss-specific gradients matrix.
        losses (Optional[Sequence]): The losses.

    Returns:
        Union[torch.Tensor, float]: The calculated length.
    """
    raise NotImplementedError("This method must be implemented by the subclass.")
rescale_length
rescale_length(
    target_vector: torch.Tensor,
    gradients: Optional[torch.Tensor] = None,
    losses: Optional[Sequence] = None,
) -> torch.Tensor

Rescales the length of the target vector based on the given parameters. It calls the get_length method to calculate the length and then rescales the target vector.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| target_vector | Tensor | The final update gradient vector. | required |
| gradients | Optional[Tensor] | The loss-specific gradients matrix. | None |
| losses | Optional[Sequence] | The losses. | None |

Returns:

| Type | Description |
| --- | --- |
| Tensor | The rescaled target vector. |

Source code in conflictfree/length_model.py
def rescale_length(self, 
                   target_vector:torch.Tensor,
                   gradients:Optional[torch.Tensor]=None,
                   losses:Optional[Sequence]=None)->torch.Tensor:
    """
    Rescales the length of the target vector based on the given parameters.
    It calls the get_length method to calculate the length and then rescales the target vector.

    Args:
        target_vector (torch.Tensor): The final update gradient vector.
        gradients (Optional[torch.Tensor]): The loss-specific gradients matrix.
        losses (Optional[Sequence]): The losses.

    Returns:
        torch.Tensor: The rescaled target vector.
    """
    unit_target_vector = unit_vector(target_vector)
    return self.get_length(target_vector=target_vector,
                           unit_target_vector=unit_target_vector,
                           gradients=gradients,
                           losses=losses) * unit_target_vector