Check for NaNs in the train and test data of a scenario. If the scenario
supports a correction mode, also check the train and test data sets produced
by the coarse stepper. Raises an AssertionError if NaNs are found.
Source code in `apebench/_utils.py` (lines 210–252):
def _assert_no_nan_trjs(data, label: str) -> None:
    """
    Raise an AssertionError if ``count_nan_trjs`` reports any trajectories in
    ``data`` containing NaNs. ``label`` prefixes the error message.
    """
    num_nans = count_nan_trjs(data)
    # Explicit raise instead of an ``assert`` statement so the check is not
    # stripped when Python runs with optimizations enabled (``python -O``).
    if num_nans != 0:
        raise AssertionError(f"{label} has {num_nans} trajectories with NaNs")


def check_for_nan(scene: BaseScenario):
    """
    Check for NaNs in the train and test data of a scenario. If the scenario
    supports a correction mode, also check the train and test data sets
    produced by the coarse stepper.

    Args:
        scene: The scenario whose generated data sets are checked.

    Raises:
        AssertionError: If any of the checked data sets contains trajectories
            with NaNs. The message names the offending data set and the number
            of affected trajectories.
    """
    # Each data set is passed straight into the helper so no reference to it
    # outlives the check; only one data set needs to be held at a time.
    _assert_no_nan_trjs(scene.get_train_data(), "Train data")
    _assert_no_nan_trjs(scene.get_test_data(), "Test data")

    # Some scenarios might not support a correction mode. In that case the
    # coarse getters raise NotImplementedError and the coarse checks are
    # skipped. Only the getter calls are guarded, so a NotImplementedError
    # raised elsewhere is not silently swallowed.
    try:
        train_data_coarse = scene.get_train_data_coarse()
    except NotImplementedError:
        return
    _assert_no_nan_trjs(train_data_coarse, "Train data coarse")
    del train_data_coarse  # free before producing the next data set

    try:
        test_data_coarse = scene.get_test_data_coarse()
    except NotImplementedError:
        return
    _assert_no_nan_trjs(test_data_coarse, "Test data coarse")