
4.5. Loss Recorder

The loss_recorder module contains classes that record loss values during optimization. The momentum version of the ConFIG algorithm relies on them: it does not compute every loss in every iteration, yet some tasks, such as logging or calculating the length/weight model, need a value for every loss. A loss recorder stores one value per loss (the latest one, or a moving average) so that all of them are available at every step. You can create a custom loss recorder by inheriting from the LossRecorder class.
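As a minimal sketch of the situation described above, assume a training loop that evaluates only one of three losses per iteration, in round-robin order. The recorder class `TinyLatestRecorder` and the function `compute_loss` below are hypothetical stand-ins written for this illustration (the recorder mimics the behaviour of `LatestLossRecorder`); they are not part of the `conflictfree` API.

```python
from typing import List


class TinyLatestRecorder:
    """Hypothetical stand-in mimicking LatestLossRecorder: keeps the newest value of each loss."""

    def __init__(self, num_losses: int) -> None:
        self.current_losses: List[float] = [0.0] * num_losses

    def record_loss(self, index: int, loss: float) -> List[float]:
        self.current_losses[index] = loss
        return self.current_losses


def compute_loss(index: int, step: int) -> float:
    """Hypothetical toy value of loss `index` at iteration `step`."""
    return 1.0 / (step + index + 1)


num_losses = 3
recorder = TinyLatestRecorder(num_losses)
history = []
for step in range(6):
    index = step % num_losses  # only one loss is evaluated in this iteration
    all_losses = recorder.record_loss(index, compute_loss(index, step))
    # Every loss now has a value (possibly from an earlier iteration), which is
    # what logging or a length/weight model would consume.
    history.append(list(all_losses))  # copy, because the recorder reuses its list
```

After the first pass over all indexes, every entry of `all_losses` holds a real value, even though each iteration computed only one of them.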

Loss Recorder

conflictfree.loss_recorder.LatestLossRecorder

Bases: LossRecorder

A loss recorder that returns the latest recorded value of each loss.

Parameters:

num_losses (int, required): The number of losses to record.
Source code in conflictfree/loss_recorder.py
class LatestLossRecorder(LossRecorder):
    """
    A loss recorder that returns the latest recorded value of each loss.

    Args:
        num_losses (int): The number of losses to record.
    """

    def __init__(self, num_losses: int) -> None:
        super().__init__(num_losses)

    def record_loss(
        self, losses_indexes: Union[int, Sequence[int]], losses: Union[float, Sequence]
    ) -> list:
        """
        Records the given losses and returns the recorded losses.

        Args:
            losses_indexes (Union[int, Sequence[int]]): The index or indexes of the losses.
            losses (Union[float, Sequence]): The loss value(s), in the same order as losses_indexes.

        Returns:
            list: The latest recorded value of every loss.
        """
        losses_indexes, losses = self._preprocess_losses(losses_indexes, losses)
        # Overwrite only the given entries; the others keep their last value.
        for i, loss in zip(losses_indexes, losses):
            self.current_losses[i] = loss
        # This is the recorder's own list, not a copy.
        return self.current_losses

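The following sketch exercises `LatestLossRecorder` with a single index, with several indexes, and through `record_all_losses`. To stay runnable without the package installed, it re-declares condensed copies of `LossRecorder` and `LatestLossRecorder` from this page (docstrings and the abstract `record_loss` trimmed); with `conflictfree` installed you would instead write `from conflictfree.loss_recorder import LatestLossRecorder`.

```python
from typing import Sequence, Tuple, Union


class LossRecorder:
    """Condensed copy of the base class documented on this page."""

    def __init__(self, num_losses: int) -> None:
        self.num_losses = num_losses
        self.current_losses = [0.0 for _ in range(num_losses)]

    def record_all_losses(self, losses: Sequence) -> list:
        assert len(losses) == self.num_losses, "The number of losses does not match the number of losses to be recorded."
        return self.record_loss([i for i in range(self.num_losses)], losses)

    def _preprocess_losses(
        self, losses_indexes: Union[int, Sequence[int]], losses: Union[float, Sequence]
    ) -> Tuple[Sequence[int], Sequence]:
        # Wrap a single index and a single float so callers can always iterate.
        if isinstance(losses_indexes, int):
            losses_indexes = [losses_indexes]
        if isinstance(losses, float):
            losses = [losses]
        return losses_indexes, losses


class LatestLossRecorder(LossRecorder):
    """Condensed copy: overwrite the given entries, keep the others."""

    def record_loss(
        self, losses_indexes: Union[int, Sequence[int]], losses: Union[float, Sequence]
    ) -> list:
        losses_indexes, losses = self._preprocess_losses(losses_indexes, losses)
        for i, loss in zip(losses_indexes, losses):
            self.current_losses[i] = loss
        return self.current_losses  # the recorder's own list, not a copy


recorder = LatestLossRecorder(num_losses=3)
after_single = list(recorder.record_loss(1, 0.5))              # one index, one float
after_many = list(recorder.record_loss([0, 2], [2.0, 4.0]))    # several indexes at once
after_all = list(recorder.record_all_losses([1.0, 1.5, 3.0]))  # every loss, in index order
```

Because `record_loss` returns the recorder's own `current_losses` list, the sketch copies it with `list(...)` after each call; without the copies, all three names would refer to the same list and show only the final values.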
num_losses instance-attribute
num_losses = num_losses
current_losses instance-attribute
current_losses = [0.0 for i in range(num_losses)]
record_all_losses
record_all_losses(losses: Sequence) -> list

Records all the losses and returns the recorded losses.

Parameters:

losses (Sequence, required): The losses to record, one value per loss in index order. An assertion checks that len(losses) equals num_losses.

Returns:

list: The recorded losses.

__init__
__init__(num_losses: int) -> None
record_loss
record_loss(
    losses_indexes: Union[int, Sequence[int]],
    losses: Union[float, Sequence],
) -> list

Records the given losses and returns the recorded losses.

Parameters:

losses_indexes (Union[int, Sequence[int]], required): The index or indexes of the losses to record.
losses (Union[float, Sequence], required): The loss value(s), in the same order as losses_indexes.

Returns:

list: The latest recorded value of every loss.


conflictfree.loss_recorder.MomentumLossRecorder

Bases: LossRecorder

A loss recorder that returns a bias-corrected exponential moving average (momentum) of each loss.

Parameters:

num_losses (int, required): The number of losses to record.
betas (Union[float, Sequence[float]], default 0.9): The moving-average constant, either a single value shared by all losses or one value per loss.
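The computation in `record_loss` (see the source listing below) can be summarised as follows, writing ℓ_i^(t) for the t-th value recorded for loss i and β_i for its moving-average constant. Dividing by 1 − β_i^t is the same bias correction that Adam-style optimizers apply to their moment estimates; it makes the first report equal to the first recorded value instead of shrinking it towards the initial 0.0. The second line is a hand-worked example with β = 0.9.

```latex
m_i^{(0)} = 0, \qquad
m_i^{(t)} = \beta_i\, m_i^{(t-1)} + (1 - \beta_i)\, \ell_i^{(t)}, \qquad
\hat{m}_i^{(t)} = \frac{m_i^{(t)}}{1 - \beta_i^{\,t}}

\beta = 0.9,\; \ell^{(1)} = 1,\; \ell^{(2)} = 2:\quad
m^{(1)} = 0.1,\; \hat{m}^{(1)} = 1, \quad
m^{(2)} = 0.9 \cdot 0.1 + 0.1 \cdot 2 = 0.29,\; \hat{m}^{(2)} = \frac{0.29}{1 - 0.81} \approx 1.526
```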
Source code in conflictfree/loss_recorder.py
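class MomentumLossRecorder(LossRecorder):
    """
    A loss recorder that returns a bias-corrected exponential moving average
    (momentum) of each loss.

    Args:
        num_losses (int): The number of losses to record.
        betas (Union[float, Sequence[float]]): The moving-average constant, either
            a single value shared by all losses or one value per loss.
    """

    def __init__(self, num_losses: int, betas: Union[float, Sequence[float]] = 0.9):
        super().__init__(num_losses)
        if isinstance(betas, (int, float)):
            # One constant shared by all losses.
            self.betas = [float(betas)] * num_losses
        else:
            # One constant per loss.
            if len(betas) != num_losses:
                raise ValueError("The number of betas must match the number of losses.")
            self.betas = list(betas)
        self.m = [0.0 for i in range(num_losses)]  # moving average of each loss
        self.t = [0 for i in range(num_losses)]  # number of updates of each loss

    def record_loss(
        self, losses_indexes: Union[int, Sequence[int]], losses: Union[float, Sequence]
    ) -> list:
        """
        Records the given losses and returns the bias-corrected moving averages.

        Args:
            losses_indexes (Union[int, Sequence[int]]): The index or indexes of the losses.
            losses (Union[float, Sequence]): The loss value(s), in the same order as losses_indexes.

        Returns:
            list: The bias-corrected moving average of every loss.
        """
        losses_indexes, losses = self._preprocess_losses(losses_indexes, losses)
        for index, loss in zip(losses_indexes, losses):
            self.t[index] += 1
            # m <- beta * m + (1 - beta) * loss
            self.m[index] = (
                self.betas[index] * self.m[index] + (1 - self.betas[index]) * loss
            )
        # Divide by (1 - beta^t) to remove the bias towards the initial 0.0.
        # Losses that have never been recorded (t == 0) are reported as 0.0,
        # which also avoids a division by zero.
        self.current_losses = [
            self.m[index] / (1 - self.betas[index] ** self.t[index])
            if self.t[index] > 0
            else 0.0
            for index in range(len(self.m))
        ]
        return self.current_losses
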
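A runnable sketch of the momentum recorder's behaviour follows. It is a condensed, self-contained copy of `MomentumLossRecorder` with the base class inlined (so it runs without `conflictfree` installed), and it reproduces the hand-computed values: one update returns the raw loss, two updates with β = 0.9 give (0.9 · 1 + 2) / 1.9 ≈ 1.526, and a loss that was never recorded stays at 0.0. The second recorder illustrates per-loss `betas`.

```python
from typing import Sequence, Union


class MomentumLossRecorder:
    """Self-contained, condensed copy of MomentumLossRecorder with the base class inlined."""

    def __init__(self, num_losses: int, betas: Union[float, Sequence[float]] = 0.9):
        self.num_losses = num_losses
        self.current_losses = [0.0] * num_losses
        if isinstance(betas, (int, float)):
            self.betas = [float(betas)] * num_losses  # one constant shared by all losses
        else:
            self.betas = list(betas)  # one constant per loss
        self.m = [0.0] * num_losses  # exponential moving averages
        self.t = [0] * num_losses  # number of updates per loss

    def record_loss(
        self, losses_indexes: Union[int, Sequence[int]], losses: Union[float, Sequence]
    ) -> list:
        if isinstance(losses_indexes, int):
            losses_indexes = [losses_indexes]
        if isinstance(losses, float):
            losses = [losses]
        for index, loss in zip(losses_indexes, losses):
            self.t[index] += 1
            beta = self.betas[index]
            self.m[index] = beta * self.m[index] + (1 - beta) * loss
        # Bias correction; never-recorded losses (t == 0) are reported as 0.0.
        self.current_losses = [
            m / (1 - beta**t) if t > 0 else 0.0
            for m, beta, t in zip(self.m, self.betas, self.t)
        ]
        return self.current_losses


recorder = MomentumLossRecorder(num_losses=2, betas=0.9)
first = recorder.record_loss(0, 1.0)   # equals the raw loss after one update
second = recorder.record_loss(0, 2.0)  # (0.9 * 1.0 + 2.0) / 1.9, about 1.526

per_loss = MomentumLossRecorder(num_losses=2, betas=[0.5, 0.99])
per_loss.record_loss([0, 1], [4.0, 4.0])
decayed = per_loss.record_loss([0, 1], [0.0, 0.0])  # a smaller beta forgets the old 4.0 faster
```

Unlike `LatestLossRecorder`, this recorder rebinds `current_losses` to a new list on every call, so `first` and `second` remain distinct snapshots.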
num_losses instance-attribute
num_losses = num_losses
current_losses instance-attribute
current_losses = [0.0 for i in range(num_losses)]
betas instance-attribute
betas = [betas] * num_losses
m instance-attribute
m = [0.0 for i in range(num_losses)]
t instance-attribute
t = [0 for i in range(num_losses)]
record_all_losses
record_all_losses(losses: Sequence) -> list

Records all the losses and returns the recorded losses.

Parameters:

losses (Sequence, required): The losses to record, one value per loss in index order. An assertion checks that len(losses) equals num_losses.

Returns:

list: The recorded losses.

__init__
__init__(
    num_losses: int,
    betas: Union[float, Sequence[float]] = 0.9,
)
record_loss
record_loss(
    losses_indexes: Union[int, Sequence[int]],
    losses: Union[float, Sequence],
) -> list

Records the given losses and returns the bias-corrected moving average of every loss.

Parameters:

losses_indexes (Union[int, Sequence[int]], required): The index or indexes of the losses to record.
losses (Union[float, Sequence], required): The loss value(s), in the same order as losses_indexes.

Returns:

list: The bias-corrected moving average of every loss.


Base Class of Loss Recorder

conflictfree.loss_recorder.LossRecorder

Base class for loss recorders.

Parameters:

num_losses (int, required): The number of losses to record.
Source code in conflictfree/loss_recorder.py
from typing import Sequence, Tuple, Union


class LossRecorder:
    """
    Base class for loss recorders.

    Args:
        num_losses (int): The number of losses to record.
    """

    def __init__(self, num_losses: int) -> None:
        self.num_losses = num_losses
        self.current_losses = [0.0 for i in range(num_losses)]

    def record_loss(
        self, losses_indexes: Union[int, Sequence[int]], losses: Union[float, Sequence]
    ) -> list:
        """
        Records the given losses and returns the recorded losses.

        Args:
            losses_indexes (Union[int, Sequence[int]]): The index or indexes of the losses.
            losses (Union[float, Sequence]): The loss value(s), in the same order as losses_indexes.

        Returns:
            list: The recorded losses.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError("record_loss method must be implemented")

    def record_all_losses(self, losses: Sequence) -> list:
        """
        Records all the losses and returns the recorded losses.

        Args:
            losses (Sequence): The losses to record, one value per loss in index order.

        Returns:
            list: The recorded losses.
        """
        assert len(losses) == self.num_losses, "The number of losses does not match the number of losses to be recorded."
        return self.record_loss([i for i in range(self.num_losses)], losses)

    def _preprocess_losses(
        self, losses_indexes: Union[int, Sequence[int]], losses: Union[float, Sequence]
    ) -> Tuple[Sequence[int], Sequence]:
        """
        Wraps a single index and a single loss value into lists so that
        `record_loss` can always iterate over them. Recommended for use in
        `record_loss` implementations.

        Args:
            losses_indexes (Union[int, Sequence[int]]): The indexes of the losses.
            losses (Union[float, Sequence]): The losses.

        Returns:
            Tuple[Sequence[int], Sequence]: The preprocessed loss indexes and losses.
        """
        if isinstance(losses_indexes, int):
            losses_indexes = [losses_indexes]
        # Accept a plain Python int as well as a float for a single loss value.
        if isinstance(losses, (int, float)):
            losses = [losses]
        return losses_indexes, losses

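A custom recorder only needs to subclass `LossRecorder` and implement `record_loss`, ideally calling `_preprocess_losses` first. The sketch below shows a hypothetical `RunningMeanLossRecorder` that reports the arithmetic mean of all values seen for each loss; the class name and its behaviour are invented for illustration. The base class is a condensed copy of the listing above so the snippet runs on its own; with the package installed you would import it via `from conflictfree.loss_recorder import LossRecorder`.

```python
from typing import Sequence, Tuple, Union


class LossRecorder:
    """Condensed copy of the base class listed above."""

    def __init__(self, num_losses: int) -> None:
        self.num_losses = num_losses
        self.current_losses = [0.0 for _ in range(num_losses)]

    def record_loss(self, losses_indexes, losses) -> list:
        raise NotImplementedError("record_loss method must be implemented")

    def record_all_losses(self, losses: Sequence) -> list:
        assert len(losses) == self.num_losses, "The number of losses does not match the number of losses to be recorded."
        return self.record_loss([i for i in range(self.num_losses)], losses)

    def _preprocess_losses(
        self, losses_indexes: Union[int, Sequence[int]], losses: Union[float, Sequence]
    ) -> Tuple[Sequence[int], Sequence]:
        if isinstance(losses_indexes, int):
            losses_indexes = [losses_indexes]
        if isinstance(losses, float):
            losses = [losses]
        return losses_indexes, losses


class RunningMeanLossRecorder(LossRecorder):
    """Hypothetical custom recorder: reports the mean of all values seen per loss."""

    def __init__(self, num_losses: int) -> None:
        super().__init__(num_losses)
        self.sums = [0.0] * num_losses
        self.counts = [0] * num_losses

    def record_loss(
        self, losses_indexes: Union[int, Sequence[int]], losses: Union[float, Sequence]
    ) -> list:
        losses_indexes, losses = self._preprocess_losses(losses_indexes, losses)
        for index, loss in zip(losses_indexes, losses):
            self.sums[index] += loss
            self.counts[index] += 1
        # Losses that were never recorded keep the base-class default of 0.0.
        self.current_losses = [
            s / c if c > 0 else 0.0 for s, c in zip(self.sums, self.counts)
        ]
        return self.current_losses


recorder = RunningMeanLossRecorder(num_losses=2)
recorder.record_loss(0, 1.0)
recorder.record_loss(0, 3.0)
means = recorder.record_all_losses([5.0, 2.0])  # record_all_losses is inherited

base_raises = False
try:
    LossRecorder(2).record_loss(0, 1.0)  # the base class itself is abstract
except NotImplementedError:
    base_raises = True
```

The subclass inherits `record_all_losses` unchanged, because the base implementation simply forwards every index to `record_loss`.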
num_losses instance-attribute
num_losses = num_losses
current_losses instance-attribute
current_losses = [0.0 for i in range(num_losses)]
__init__
__init__(num_losses: int) -> None
record_loss
record_loss(
    losses_indexes: Union[int, Sequence[int]],
    losses: Union[float, Sequence],
) -> list

Records the given losses and returns the recorded losses.

Parameters:

losses_indexes (Union[int, Sequence[int]], required): The index or indexes of the losses to record.
losses (Union[float, Sequence], required): The loss value(s), in the same order as losses_indexes.

Returns:

list: The recorded losses.

Raises:

NotImplementedError: Always; subclasses must override this method.

record_all_losses
record_all_losses(losses: Sequence) -> list

Records all the losses and returns the recorded losses.

Parameters:

losses (Sequence, required): The losses to record, one value per loss in index order. An assertion checks that len(losses) equals num_losses.

Returns:

list: The recorded losses.
