diff --git a/climada/engine/impact_forecast.py b/climada/engine/impact_forecast.py index d4afc551d4..6b18c61659 100644 --- a/climada/engine/impact_forecast.py +++ b/climada/engine/impact_forecast.py @@ -24,6 +24,7 @@ import numpy as np from ..util import log_level +from ..util.checker import size from ..util.forecast import Forecast from .impact import Impact @@ -51,8 +52,8 @@ def __init__( impact_kwargs Keyword-arguments passed to ~:py:class`climada.engine.impact.Impact`. """ - # TODO: Maybe assert array lengths? super().__init__(lead_time=lead_time, member=member, **impact_kwargs) + self._check_sizes() @classmethod def from_impact( @@ -88,3 +89,16 @@ def from_impact( imp_mat=impact.imp_mat, haz_type=impact.haz_type, ) + + def _check_sizes(self): + """Check sizes of forecast data vs. impact data. + + Raises + ------ + ValueError + If the sizes of the forecast data do not match the + :py:attr:`~climada.engine.impact.Impact.event_id` + """ + num_entries = len(self.event_id) + size(exp_len=num_entries, var=self.member, var_name="Forecast.member") + size(exp_len=num_entries, var=self.lead_time, var_name="Forecast.lead_time") diff --git a/climada/engine/test/test_impact_forecast.py b/climada/engine/test/test_impact_forecast.py index 0d421152c2..6ada17777c 100644 --- a/climada/engine/test/test_impact_forecast.py +++ b/climada/engine/test/test_impact_forecast.py @@ -41,13 +41,15 @@ def impact(impact_kwargs): @pytest.fixture -def lead_time(): - return pd.timedelta_range(start="1 day", periods=6).to_numpy() +def lead_time(impact_kwargs): + return pd.timedelta_range( + start="1 day", periods=len(impact_kwargs["event_id"]) + ).to_numpy() @pytest.fixture -def member(): - return np.arange(6) +def member(impact_kwargs): + return np.arange(len(impact_kwargs["event_id"])) @pytest.fixture @@ -76,6 +78,12 @@ def test_impact_forecast_init(self, impact_kwargs, lead_time, member): npt.assert_array_equal(forecast1.member, member) self.assert_impact_kwargs(forecast1, **impact_kwargs) + def test_impact_forecast_init_error(self, impact, impact_kwargs, lead_time, member): + with pytest.raises(ValueError, match="Forecast.lead_time"): + ImpactForecast(lead_time=lead_time[:-2], member=member, **impact_kwargs) + with pytest.raises(ValueError, match="Forecast.member"): + ImpactForecast.from_impact(impact, lead_time=lead_time, member=member[1:]) + def test_impact_forecast_from_impact( self, impact_forecast, impact_kwargs, lead_time, member ): diff --git a/climada/hazard/forecast.py b/climada/hazard/forecast.py index c2705134a0..5130e66af1 100644 --- a/climada/hazard/forecast.py +++ b/climada/hazard/forecast.py @@ -23,8 +23,9 @@ import numpy as np -from climada.hazard.base import Hazard -from climada.util.forecast import Forecast +from ..util.checker import size +from ..util.forecast import Forecast +from .base import Hazard LOGGER = logging.getLogger(__name__) @@ -52,6 +53,7 @@ def __init__( py:meth`~climada.hazard.base.Hazard.__init__` for details. """ super().__init__(lead_time=lead_time, member=member, **hazard_kwargs) + self._check_sizes() @classmethod def from_hazard(cls, hazard: Hazard, lead_time: np.ndarray, member: np.ndarray): @@ -89,3 +91,16 @@ def from_hazard(cls, hazard: Hazard, lead_time: np.ndarray, member: np.ndarray): intensity=hazard.intensity, fraction=hazard.fraction, ) + + def _check_sizes(self): + """Check sizes of forecast data vs. hazard data. + + Raises + ------ + ValueError + If the sizes of the forecast data do not match the + :py:attr:`~climada.hazard.base.Hazard.event_id` + """ + num_entries = len(self.event_id) + size(exp_len=num_entries, var=self.member, var_name="Forecast.member") + size(exp_len=num_entries, var=self.lead_time, var_name="Forecast.lead_time") diff --git a/climada/hazard/test/test_forecast.py b/climada/hazard/test/test_forecast.py index 646ccaa0cf..54cc37a4e1 100644 --- a/climada/hazard/test/test_forecast.py +++ b/climada/hazard/test/test_forecast.py @@ -41,13 +41,13 @@ def hazard(haz_kwargs): @pytest.fixture -def lead_time(): - return pd.timedelta_range("1h", periods=6).to_numpy() +def lead_time(haz_kwargs): + return pd.timedelta_range("1h", periods=len(haz_kwargs["event_id"])).to_numpy() @pytest.fixture -def member(): - return np.arange(6) +def member(haz_kwargs): + return np.arange(len(haz_kwargs["event_id"])) @pytest.fixture @@ -78,6 +78,13 @@ def test_init_hazard_forecast(haz_fc, member, lead_time, haz_kwargs): assert_hazard_kwargs(haz_fc, **haz_kwargs) +def test_init_hazard_forecast_error(hazard, member, lead_time, haz_kwargs): + with pytest.raises(ValueError, match="Forecast.lead_time"): + HazardForecast(lead_time=lead_time[:-2], member=member, **haz_kwargs) + with pytest.raises(ValueError, match="Forecast.member"): + HazardForecast.from_hazard(hazard, lead_time=lead_time, member=member[1:]) + + def test_from_hazard(lead_time, member, hazard, haz_kwargs): haz_fc_from_haz = HazardForecast.from_hazard( hazard, lead_time=lead_time, member=member