Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions climada/hazard/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,16 @@ def _check_sizes(self):
size(exp_len=num_entries, var=self.member, var_name="Forecast.member")
size(exp_len=num_entries, var=self.lead_time, var_name="Forecast.lead_time")

@classmethod
def concat(cls, haz_list: list):
"""Concatenate multiple HazardForecast instances and return a new object"""
if len(haz_list) == 0:
return cls()
hazard = Hazard.concat(haz_list)
lead_time = np.concatenate(tuple(haz.lead_time for haz in haz_list))
member = np.concatenate(tuple(haz.member for haz in haz_list))
return cls.from_hazard(hazard, lead_time=lead_time, member=member)

def select(
self,
member=None,
Expand Down
41 changes: 31 additions & 10 deletions climada/hazard/test/test_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,16 +95,37 @@ def test_from_hazard(lead_time, member, hazard, haz_kwargs):
assert_hazard_kwargs(haz_fc_from_haz, **haz_kwargs)


@pytest.mark.skip("Concat from base class does not work")
def test_hazard_forecast_concat(haz_fc, lead_time, member):
haz_fc1 = haz_fc.select(event_id=[1, 2])
haz_fc2 = haz_fc.select(event_id=[3, 4])
haz_fc_concat = HazardForecast.concat([haz_fc1, haz_fc2])
assert isinstance(haz_fc_concat, HazardForecast)
npt.assert_array_equal(
haz_fc_concat.lead_time, np.concatenate([lead_time, lead_time])
)
npt.assert_array_equal(haz_fc_concat.member, np.concatenate([member, member]))
class TestHazardForecastConcat:

def test_concat(self, haz_fc, lead_time, member, haz_kwargs):
haz_fc1 = haz_fc.select(event_id=[3])
haz_fc2 = HazardForecast(
haz_type=haz_kwargs["haz_type"], frequency_unit=haz_kwargs["frequency_unit"]
) # Empty hazard
haz_fc3 = haz_fc.select(event_id=[1, 2])
haz_fc_concat = HazardForecast.concat([haz_fc1, haz_fc2, haz_fc3])
assert isinstance(haz_fc_concat, HazardForecast)
assert haz_fc_concat.size == 3
npt.assert_array_equal(
haz_fc_concat.lead_time, np.concatenate((lead_time[2:3], lead_time[0:2]))
)
npt.assert_array_equal(
haz_fc_concat.member, np.concatenate((member[2:3], member[0:2]))
)
npt.assert_array_equal(haz_fc_concat.event_id, [3, 1, 2])

def test_empty_list(self):
haz_concat = HazardForecast.concat([])
assert isinstance(haz_concat, HazardForecast)
assert haz_concat.size == 0
npt.assert_array_equal(haz_concat.lead_time, [])
npt.assert_array_equal(haz_concat.event_id, [])

def test_type_fail(self, haz_fc, hazard):
with pytest.raises(TypeError, match="different classes"):
HazardForecast.concat([haz_fc, hazard])
with pytest.raises(TypeError, match="different classes"):
Hazard.concat([haz_fc, hazard])


class TestSelect:
Expand Down
8 changes: 6 additions & 2 deletions climada/util/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,13 @@ def __init__(
"""

self.lead_time = (
np.asarray(lead_time) if lead_time is not None else np.array([])
np.asarray(lead_time)
if lead_time is not None
else np.array([], dtype="timedelta64[ns]")
)
self.member = (
np.asarray(member) if member is not None else np.array([], dtype="int")
)
self.member = np.asarray(member) if member is not None else np.array([])
super().__init__(**kwargs)

def idx_member(self, member: np.ndarray) -> np.ndarray:
Expand Down
Loading