Skip to content

Low-Level Postprocessing Utilities

apebench.melt_data

melt_data(
    wide_data: pd.DataFrame,
    quantity_name: Union[str, list[str]],
    uniquifier_name: str,
    *,
    base_columns: list[str] = BASE_NAMES
) -> pd.DataFrame

Melt a wide APEBench result DataFrame into a long format suitable for visualization (e.g. with seaborn or plotly).

Arguments:

  • wide_data: The wide DataFrame to melt. It must contain every column in base_columns plus, for each quantity, columns named {quantity_name}_{k}, where k is a numeric suffix.
  • quantity_name: The name of the column(s) to melt.
  • uniquifier_name: The name of the new column that stores the numeric suffix k. Together with base_columns, it uniquely identifies each melted row.
  • base_columns: The columns that should be kept as is in the melted DataFrame.

Returns:

  • A long DataFrame whose columns are base_columns, uniquifier_name, and one column per melted quantity_name.
Source code in apebench/_utils.py (lines 27–67)
def melt_data(
    wide_data: pd.DataFrame,
    quantity_name: Union[str, list[str]],
    uniquifier_name: str,
    *,
    base_columns: list[str] = BASE_NAMES,
) -> pd.DataFrame:
    """
    Convert a wide APEBench result DataFrame into long format, the layout
    that plotting libraries such as seaborn or plotly expect.

    **Arguments:**

    * `wide_data`: The wide DataFrame. It must contain every column listed in
        `base_columns` plus, for each quantity, columns named
        `{quantity}_{k}` where `k` is a numeric suffix.
    * `quantity_name`: A single column stub, or a list of stubs, to melt.
    * `uniquifier_name`: Name of the new column that receives the numeric
        suffixes `k`; together with `base_columns` it identifies each row of
        the result.
    * `base_columns`: Identifier columns carried over unchanged. Their
        combination must be unique per row of `wide_data`, as required by
        `pd.wide_to_long`.

    **Returns:**

    * A long DataFrame with columns `base_columns`, `uniquifier_name`, and
        one column per melted quantity.
    """
    stubs = [quantity_name] if isinstance(quantity_name, str) else quantity_name

    # wide_to_long moves base_columns + [uniquifier_name] into the index; keep
    # only the melted quantity columns, then restore the index as columns.
    long_frame = pd.wide_to_long(
        wide_data,
        stubnames=stubs,
        i=base_columns,
        j=uniquifier_name,
        sep="_",
    )
    return long_frame[stubs].reset_index()

apebench.melt_loss

melt_loss(
    wide_data: pd.DataFrame, loss_name: str = "train_loss"
) -> pd.DataFrame

Melt the loss from a wide DataFrame.

Source code in apebench/_utils.py (lines 84–92)
def melt_loss(wide_data: pd.DataFrame, loss_name: str = "train_loss") -> pd.DataFrame:
    """
    Reshape the loss columns of a wide DataFrame into long format.

    Columns named `{loss_name}_{k}` are collapsed into a single `loss_name`
    column, with the suffix `k` stored in a new `update_step` column. The
    default `base_columns` of `melt_data` are kept as identifier columns.
    """
    step_column = "update_step"
    return melt_data(
        wide_data,
        uniquifier_name=step_column,
        quantity_name=loss_name,
    )

apebench.melt_metrics

melt_metrics(
    wide_data: pd.DataFrame,
    metric_name: Union[str, list[str]] = "mean_nRMSE",
) -> pd.DataFrame

Melt the metrics from a wide DataFrame.

Source code in apebench/_utils.py (lines 70–81)
def melt_metrics(
    wide_data: pd.DataFrame,
    metric_name: Union[str, list[str]] = "mean_nRMSE",
) -> pd.DataFrame:
    """
    Reshape one or more metric column groups of a wide DataFrame into long
    format.

    For each metric stub, columns named `{metric}_{k}` are collapsed into a
    single column, and the suffix `k` is stored in a new `time_step` column.
    The default `base_columns` of `melt_data` are kept as identifier columns.
    """
    index_column = "time_step"
    return melt_data(
        wide_data,
        uniquifier_name=index_column,
        quantity_name=metric_name,
    )

apebench.melt_sample_rollouts

melt_sample_rollouts(
    wide_data: pd.DataFrame,
    sample_rollout_name: str = "sample_rollout",
) -> pd.DataFrame

Melt the sample rollouts from a wide DataFrame.

Source code in apebench/_utils.py (lines 95–106)
def melt_sample_rollouts(
    wide_data: pd.DataFrame,
    sample_rollout_name: str = "sample_rollout",
) -> pd.DataFrame:
    """
    Reshape the sample-rollout columns of a wide DataFrame into long format.

    Columns named `{sample_rollout_name}_{k}` are collapsed into a single
    column, and the suffix `k` is stored in a new `sample_index` column. The
    default `base_columns` of `melt_data` are kept as identifier columns.
    """
    sample_column = "sample_index"
    return melt_data(
        wide_data,
        uniquifier_name=sample_column,
        quantity_name=sample_rollout_name,
    )