Skip to content

Utilities

apebench.check_for_nan

check_for_nan(scene: BaseScenario)

Check for NaNs in the train and test data of a scenario. Also checks the train and test data sets produced by the coarse stepper if the scenario supports a correction mode. Raises an AssertionError if NaNs are found.

Source code in apebench/_utils.py (lines 210–252)
def _raise_if_nan(data, label: str) -> None:
    """Raise an AssertionError if ``count_nan_trjs(data)`` reports any NaN trajectories.

    Raised explicitly rather than via ``assert`` so the check is not stripped
    when Python runs with optimizations (``-O``).
    """
    num_nans = count_nan_trjs(data)
    if num_nans != 0:
        raise AssertionError(f"{label} has {num_nans} trajectories with NaNs")


def check_for_nan(scene: BaseScenario):
    """
    Check for NaNs in the train and test data of a scenario. Also checks the
    train and test data sets produced by the coarse stepper if the scenario
    supports a correction mode.

    Each data set is requested, checked, and released before the next one is
    requested, so only one data set is referenced by this function at a time.

    Args:
        scene: The scenario whose data sets are checked.

    Raises:
        AssertionError: If any checked data set contains trajectories with
            NaNs. Raised explicitly (not via ``assert``) so the check also runs
            under ``python -O``.
    """
    # The argument reference is dropped as soon as each helper call returns,
    # mirroring the explicit ``del`` of the previous implementation.
    _raise_if_nan(scene.get_train_data(), "Train data")
    _raise_if_nan(scene.get_test_data(), "Test data")

    # Some scenarios might not support a correction mode; the coarse getters
    # signal that with NotImplementedError. Only those calls are guarded so a
    # NotImplementedError from anywhere else is not silently swallowed.
    try:
        train_data_coarse = scene.get_train_data_coarse()
    except NotImplementedError:
        return
    _raise_if_nan(train_data_coarse, "Train data coarse")
    del train_data_coarse

    try:
        test_data_coarse = scene.get_test_data_coarse()
    except NotImplementedError:
        return
    _raise_if_nan(test_data_coarse, "Test data coarse")
    del test_data_coarse