diff --git a/climada/hazard/forecast.py b/climada/hazard/forecast.py index b09e1a44e3..a8c5cdc543 100644 --- a/climada/hazard/forecast.py +++ b/climada/hazard/forecast.py @@ -105,6 +105,16 @@ def _check_sizes(self): size(exp_len=num_entries, var=self.member, var_name="Forecast.member") size(exp_len=num_entries, var=self.lead_time, var_name="Forecast.lead_time") + @classmethod + def concat(cls, haz_list: list): + """Concatenate multiple HazardForecast instances and return a new object""" + if len(haz_list) == 0: + return cls() + hazard = Hazard.concat(haz_list) + lead_time = np.concatenate(tuple(haz.lead_time for haz in haz_list)) + member = np.concatenate(tuple(haz.member for haz in haz_list)) + return cls.from_hazard(hazard, lead_time=lead_time, member=member) + def select( self, member=None, diff --git a/climada/hazard/test/test_forecast.py b/climada/hazard/test/test_forecast.py index ac1a726965..cb94767241 100644 --- a/climada/hazard/test/test_forecast.py +++ b/climada/hazard/test/test_forecast.py @@ -95,16 +95,37 @@ def test_from_hazard(lead_time, member, hazard, haz_kwargs): assert_hazard_kwargs(haz_fc_from_haz, **haz_kwargs) -@pytest.mark.skip("Concat from base class does not work") -def test_hazard_forecast_concat(haz_fc, lead_time, member): - haz_fc1 = haz_fc.select(event_id=[1, 2]) - haz_fc2 = haz_fc.select(event_id=[3, 4]) - haz_fc_concat = HazardForecast.concat([haz_fc1, haz_fc2]) - assert isinstance(haz_fc_concat, HazardForecast) - npt.assert_array_equal( - haz_fc_concat.lead_time, np.concatenate([lead_time, lead_time]) - ) - npt.assert_array_equal(haz_fc_concat.member, np.concatenate([member, member])) +class TestHazardForecastConcat: + + def test_concat(self, haz_fc, lead_time, member, haz_kwargs): + haz_fc1 = haz_fc.select(event_id=[3]) + haz_fc2 = HazardForecast( + haz_type=haz_kwargs["haz_type"], frequency_unit=haz_kwargs["frequency_unit"] + ) # Empty hazard + haz_fc3 = haz_fc.select(event_id=[1, 2]) + haz_fc_concat = HazardForecast.concat([haz_fc1, haz_fc2, haz_fc3]) + assert isinstance(haz_fc_concat, HazardForecast) + assert haz_fc_concat.size == 3 + npt.assert_array_equal( + haz_fc_concat.lead_time, np.concatenate((lead_time[2:3], lead_time[0:2])) + ) + npt.assert_array_equal( + haz_fc_concat.member, np.concatenate((member[2:3], member[0:2])) + ) + npt.assert_array_equal(haz_fc_concat.event_id, [3, 1, 2]) + + def test_empty_list(self): + haz_concat = HazardForecast.concat([]) + assert isinstance(haz_concat, HazardForecast) + assert haz_concat.size == 0 + npt.assert_array_equal(haz_concat.lead_time, []) + npt.assert_array_equal(haz_concat.event_id, []) + + def test_type_fail(self, haz_fc, hazard): + with pytest.raises(TypeError, match="different classes"): + HazardForecast.concat([haz_fc, hazard]) + with pytest.raises(TypeError, match="different classes"): + Hazard.concat([haz_fc, hazard]) class TestSelect: diff --git a/climada/util/forecast.py b/climada/util/forecast.py index 6884c90ef3..94b9751a8b 100644 --- a/climada/util/forecast.py +++ b/climada/util/forecast.py @@ -52,9 +52,13 @@ def __init__( """ self.lead_time = ( - np.asarray(lead_time) if lead_time is not None else np.array([]) + np.asarray(lead_time) + if lead_time is not None + else np.array([], dtype="timedelta64[ns]") + ) + self.member = ( + np.asarray(member) if member is not None else np.array([], dtype="int") ) - self.member = np.asarray(member) if member is not None else np.array([]) super().__init__(**kwargs) def idx_member(self, member: np.ndarray) -> np.ndarray: