Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion climada/hazard/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,6 +917,10 @@ def write_hdf5(self, file_name, todense=False):
# Centroids have their own write_hdf5 method,
# which is invoked at the end of this method (s.b.)
continue
elif var_name == "lead_time":
hf_data.create_dataset(
var_name, data=var_val.astype("timedelta64[ns]").astype("int64")
)
elif isinstance(var_val, sparse.csr_matrix):
if todense:
hf_data.create_dataset(var_name, data=var_val.toarray())
Expand Down Expand Up @@ -987,7 +991,11 @@ def from_hdf5(cls, file_name):
continue
if var_name == "centroids":
continue
if isinstance(var_val, np.ndarray) and var_val.ndim == 1:
if var_name == "lead_time":
hazard_kwargs[var_name] = np.array(hf_data.get(var_name)).astype(
"timedelta64[ns]"
)
elif isinstance(var_val, np.ndarray) and var_val.ndim == 1:
hazard_kwargs[var_name] = np.array(hf_data.get(var_name))
elif isinstance(var_val, sparse.csr_matrix):
hf_csr = hf_data.get(var_name)
Expand Down
17 changes: 17 additions & 0 deletions climada/hazard/test/test_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,20 @@ def test_hazard_forecast_select(haz_fc, lead_time, member):
npt.assert_array_equal(haz_fc_select.event_id, haz_fc.event_id[np.array([3, 0])])
npt.assert_array_equal(haz_fc_select.member, member[np.array([3, 0])])
npt.assert_array_equal(haz_fc_select.lead_time, lead_time[np.array([3, 0])])


def test_write_read_hazard_forecast(haz_fc, tmp_path):

file_name = tmp_path / "test_hazard_forecast.h5"

haz_fc.write_hdf5(file_name)
haz_fc_read = HazardForecast.from_hdf5(file_name)

assert haz_fc_read.lead_time.dtype.kind == np.dtype("timedelta64").kind

for key in haz_fc.__dict__.keys():
if key in ["intensity", "fraction"]:
(haz_fc.__dict__[key] != haz_fc_read.__dict__[key]).nnz == 0
else:
# npt.assert_array_equal also works for comparing int, float or list
npt.assert_array_equal(haz_fc.__dict__[key], haz_fc_read.__dict__[key])
Loading