diff --git a/climada/util/forecast.py b/climada/util/forecast.py index 84225c47fe..1fbd090db9 100644 --- a/climada/util/forecast.py +++ b/climada/util/forecast.py @@ -19,6 +19,40 @@ Define Forecast base class. """ +import numpy as np + class Forecast: - pass + """Mixin class for forecast data. + + Attributes + ---------- + lead_time : np.ndarray + Array of forecast lead times, given as datetime64 objects. + Represents the time points for which forecasts are made. + member : np.ndarray + Array of ensemble member identifiers, given as integers. + Represents different forecast ensemble members. + """ + + def __init__( + self, + lead_time: np.ndarray | None = None, + member: np.ndarray | None = None, + **kwargs, + ): + """Initialize Forecast. + + Parameters + ---------- + lead_time : np.ndarray or None, optional + Forecast lead times. Default is empty array. + member : np.ndarray or None, optional + Ensemble member identifiers. Default is empty array. + """ + + self.lead_time = ( + np.asarray(lead_time) if lead_time is not None else np.array([]) + ) + self.member = np.asarray(member) if member is not None else np.array([]) + super().__init__(**kwargs) diff --git a/climada/util/test/test_forecast.py b/climada/util/test/test_forecast.py index 196573e583..f500c4ba88 100644 --- a/climada/util/test/test_forecast.py +++ b/climada/util/test/test_forecast.py @@ -18,3 +18,35 @@ Tests for Forecast base class. """ + +import numpy as np +import numpy.testing as npt + +from climada.util.forecast import Forecast + + +def test_forecast_init(): + """Test initialization of Forecast class.""" + forecast = Forecast() + npt.assert_array_equal(forecast.lead_time, np.array([])) + npt.assert_array_equal(forecast.member, np.array([])) + + forecast = Forecast(member=np.array([1, 2])) + npt.assert_array_equal(forecast.member, np.array([1, 2]), strict=True) + + forecast = Forecast(lead_time=np.array([1, 2])) + npt.assert_array_equal(forecast.lead_time, np.array([1, 2]), strict=True) + + forecast = Forecast(lead_time=np.array([1, 2]), member=[3, 4]) + npt.assert_array_equal(forecast.lead_time, np.array([1, 2]), strict=True) + npt.assert_array_equal(forecast.member, np.array([3, 4]), strict=True) + assert isinstance(forecast.member, np.ndarray) + + # Test with datetime64 including seconds + lead_times_seconds = np.array( + ["2024-01-01T00:00:00", "2024-01-01T00:01:00", "2024-01-01"], + dtype="datetime64[s]", + ) + forecast = Forecast(lead_time=lead_times_seconds, member=[1, 2, 3]) + npt.assert_array_equal(forecast.lead_time, lead_times_seconds, strict=True) + assert forecast.lead_time.dtype == np.dtype("datetime64[s]")