Custom Extensions in APEBench - Example of a new metric
This is an example of how to extend the components system in APEBench.
import jax
import jax.numpy as jnp
import apebench
import seaborn as sns
Register a new entry in the corresponding dictionary of the component. Here, this is apebench.components.metric_dict. Ensure that your constructor complies with the required interface. Here, this is supposed to be a function which takes the "metric_config" (for example, to supply further arguments to your metric). It must then return another function which maps a pred array and a ref array to a scalar value. These arrays always have a leading batch axis (can be singleton), a subsequent channel axis (can be singleton), and then one, two, or three spatial axes for 1d, 2d, or 3d data, respectively.
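To make the interface concrete, below is a minimal sketch of a constructor that actually makes use of the config string. The name my_weighted_mae and the semicolon-separated weight are hypothetical conventions invented for this illustration; they are not part of APEBench, and this snippet is independent of the rest of the notebook.
def my_weighted_mae_constructor(metric_config: str):
    # Hypothetical convention: an optional weight after the first ";",
    # e.g. "my_weighted_mae;2.5". This parsing scheme is only for illustration.
    parts = metric_config.split(";")
    weight = float(parts[1]) if len(parts) > 1 else 1.0

    def metric(pred, ref):
        # pred and ref have shape (batch, channel, *spatial); reduce to a scalar
        return weight * jnp.mean(jnp.abs(pred - ref))

    return metric

apebench.components.metric_dict["my_weighted_mae"] = my_weighted_mae_constructor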
Let's implement a very unnecessary metric which computes the difference between the two arrays, extracts the zeroth channel and the third spatial point (assuming we are in 1d), and then takes the mean over the batch axis.
apebench.components.metric_dict[
    "my_crazy_metric"
] = lambda metric_config: lambda pred, ref: jnp.mean((pred - ref)[..., 0, 3])
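As a quick sanity check (not part of the original workflow), we can call the registered constructor on dummy arrays with the (batch, channel, spatial) layout described above and confirm that a scalar comes out. The array names are made up for this example, and since our lambda ignores the config argument, any string works there.
dummy_pred = jnp.ones((5, 1, 10))   # (batch, channel, spatial) for 1d data
dummy_ref = jnp.zeros((5, 1, 10))
my_metric = apebench.components.metric_dict["my_crazy_metric"]("my_crazy_metric")
my_metric(dummy_pred, dummy_ref)  # -> 1.0, since every entry of the difference is 1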
Let's instantiate the difficulty advection scenario, reduce the number of training steps to 100 (to have the notebook run quicker), and have this new metric as the only metric to be reported.
adv_scene = apebench.scenarios.difficulty.Advection(
    optim_config="adam;100;constant;1e-3",
    report_metrics="my_crazy_metric",
)
Execute the scenario
data, nets = adv_scene()
Then we can melt the new metric out of the results
metrics_data = apebench.melt_metrics(data, metric_name="my_crazy_metric")
metrics_data
And visualize its rollout
sns.lineplot(
    data=metrics_data,
    x="time_step",
    y="my_crazy_metric",
)