Simple Advection 1D Emulation
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import seaborn as sns
import apebench
# Instantiate the 1D advection scenario from APEBench's "difficulty" scenario
# module, using its default settings (no arguments are overridden).
advection_1d_scenario = apebench.scenarios.difficulty.Advection()
# Bare expression: shows the scenario's repr when it ends a notebook cell.
advection_1d_scenario
# Fetch the reference sample trajectories that the scenario provides.
ref_trjs = advection_1d_scenario.get_ref_sample_data()
# Bare expression: inspect the array shape (it is indexed with four axes below).
ref_trjs.shape
# Space-time heatmap of one reference trajectory.
plt.figure(figsize=(10, 3))

# Index 0 on the first and third axes leaves a 2D array.
# NOTE(review): presumably (sample, time, channel, space) -- confirm against
# the scenario's data layout.
first_ref_trj = ref_trjs[0, :, 0, :]

# Symmetric color limits with a diverging colormap; origin="lower" puts the
# first row at the bottom of the image.
ref_heatmap_style = {
    "vmin": -1,
    "vmax": 1,
    "cmap": "RdBu_r",
    "aspect": "auto",
    "origin": "lower",
}

# Transposed so that the axis labelled "Time" runs horizontally and the axis
# labelled "Space" runs vertically.
plt.imshow(first_ref_trj.T, **ref_heatmap_style)
plt.colorbar()
plt.xlabel("Time")
plt.ylabel("Space")
# Settings for running the scenario: task, network descriptor string, and the
# range of seeds (10 seeds starting at seed 0).
# NOTE(review): the meaning of the individual ";"-separated fields in the
# network descriptor is defined by APEBench -- confirm there before editing.
scenario_run_kwargs = {
    "task_config": "predict",
    "network_config": "Conv;34;10;relu",
    "start_seed": 0,
    "num_seeds": 10,
}

# Calling the scenario returns a pair: the recorded results and the trained
# model(s).
data, trained_model = advection_1d_scenario(**scenario_run_kwargs)
# Bare expressions: display the results and the model in a notebook.
data
trained_model
# Reshape the loss records in `data` into a long-format table with (at least)
# the "update_step" and "train_loss" columns plotted below.
loss_data = apebench.melt_loss(data)
# Bare expression: display the table in a notebook.
loss_data

# Training loss over update steps on a logarithmic y-axis. The Axes returned
# by seaborn is configured directly.
loss_ax = sns.lineplot(data=loss_data, x="update_step", y="train_loss")
loss_ax.set_yscale("log")
loss_ax.grid()
# Reshape the metric records in `data` into a long-format table with (at
# least) the "time_step" and "mean_nRMSE" columns plotted below.
metric_data = apebench.melt_metrics(data)

# Error metric over rollout time steps. The y-range is fixed slightly beyond
# [0, 1] so curves at either bound are not clipped by the frame.
metric_ax = sns.lineplot(data=metric_data, x="time_step", y="mean_nRMSE")
metric_ax.set_ylim(-0.05, 1.05)
metric_ax.grid()
# Reshape the sample rollouts stored in `data` into a table with (at least) a
# "sample_rollout" column.
sample_rollout_data = apebench.melt_sample_rollouts(data)
# Bare expression: display the table in a notebook.
sample_rollout_data
import numpy as np

# Entry with index label 0 of the "sample_rollout" column, as a NumPy array.
# NOTE(review): presumably the first seed's first sample -- confirm.
first_rollout = np.array(sample_rollout_data["sample_rollout"][0])

# Same display settings as the reference-trajectory heatmap: symmetric color
# limits, diverging colormap, first row at the bottom.
rollout_heatmap_style = {
    "vmin": -1,
    "vmax": 1,
    "cmap": "RdBu_r",
    "aspect": "auto",
    "origin": "lower",
}

# Index 0 on the middle axis leaves a 2D array, which is transposed for
# display. NOTE(review): presumably (time, channel, space) -- confirm.
plt.imshow(first_rollout[:, 0, :].T, **rollout_heatmap_style)
Running a study of experiments
Comparing the performance of a Conv Net and an FNO
# One experiment configuration per network descriptor. Every entry shares the
# same scenario, task, training setup and seed range (10 seeds starting at 0);
# only the "net" value differs.
CONFIGS = [
    dict(
        scenario="norm_adv",
        task="predict",
        net=net_descriptor,
        train="one",
        start_seed=0,
        num_seeds=10,
    )
    for net_descriptor in ("Conv;26;10;relu", "FNO;12;8;4;gelu")
]
# Run every configuration in CONFIGS and unpack the four results: metric
# table, loss table, sample-rollout table and the list of networks.
# NOTE(review): do_loss=True presumably makes the study also collect the
# training-loss history that fills df_loss -- confirm in the APEBench docs.
df_metric, df_loss, df_sample_rollout, network_list = (
    apebench.run_study_convenience(CONFIGS, do_loss=True)
)
# Bare expression: display the loss table in a notebook.
df_loss
# Training loss per update step, one colored line per network ("net" column).
loss_plot_spec = {
    "x": "update_step",
    "y": "train_loss",
    "hue": "net",
    "kind": "line",
}
facet = sns.relplot(data=df_loss, **loss_plot_spec)
# Logarithmic y-axis on the facet grid's axes.
facet.set(yscale="log")
plt.grid()
# Error metric per rollout time step, one colored line per network.
# The y-range is fixed slightly beyond [0, 1] and handed to the facet grid
# when it is created.
metric_ylim = (-0.05, 1.05)
facet = sns.relplot(
    data=df_metric,
    x="time_step",
    y="mean_nRMSE",
    hue="net",
    kind="line",
    aspect=2,  # facet twice as wide as it is tall
    facet_kws={"ylim": metric_ylim},
)
plt.grid()