diff --git a/climada/util/forecast.py b/climada/util/forecast.py index eb7cc7fc14..6884c90ef3 100644 --- a/climada/util/forecast.py +++ b/climada/util/forecast.py @@ -56,3 +56,35 @@ def __init__( ) self.member = np.asarray(member) if member is not None else np.array([]) super().__init__(**kwargs) + + def idx_member(self, member: np.ndarray) -> np.ndarray: + """Return boolean array where self.member == member using numpy.isin() + + Parameters + ---------- + member : np.ndarray + Array of ensemble members (ints) for which to return an indexer + + Returns + ------- + np.ndarray + Boolean array where self.member is in member. + """ + + return np.isin(self.member, member) + + def idx_lead_time(self, lead_time: np.ndarray) -> np.ndarray: + """Return boolean array where self.lead_time == lead_time using numpy.isin() + + Parameters + ---------- + lead_time : np.ndarray + Array of lead times (numpy.timedelta64) for which to return an indexer + + Returns + ------- + np.ndarray + Boolean array where self.lead_time is in lead_time. + """ + + return np.isin(self.lead_time, lead_time) diff --git a/climada/util/test/test_forecast.py b/climada/util/test/test_forecast.py index 54d11e6622..8f1fcab0e5 100644 --- a/climada/util/test/test_forecast.py +++ b/climada/util/test/test_forecast.py @@ -50,3 +50,46 @@ def test_forecast_init(): forecast = Forecast(lead_time=lead_times_seconds, member=[1, 2, 3]) npt.assert_array_equal(forecast.lead_time, lead_times_seconds, strict=True) assert forecast.lead_time.dtype == np.dtype("timedelta64[ns]") + + +def test_idx_member(): + """Test idx_member method of Forecast class.""" + forecast = Forecast(member=np.array([1, 2, 3, 4])) + + idx = forecast.idx_member(1) + npt.assert_array_equal(idx, np.array([True, False, False, False]), strict=True) + + idx = forecast.idx_member(np.array([2, 4])) + npt.assert_array_equal(idx, np.array([False, True, False, True]), strict=True) + + idx = forecast.idx_member([2, 4]) + npt.assert_array_equal(idx, np.array([False, True, False, True]), strict=True) + + idx = forecast.idx_member(None) + npt.assert_array_equal(idx, np.array([False, False, False, False]), strict=True) + + # Try once with inconsitent types + forecast = Forecast(member=np.array(["1", -2, np.nan])) + npt.assert_array_equal( + forecast.idx_member([np.nan, "1"]), np.array([True, False, True]), strict=True + ) + + +def test_idx_lead_time(): + """Test idx_lead_time method of Forecast class.""" + forecast = Forecast( + lead_time=pd.timedelta_range(start="1 day", periods=4).to_numpy() + ) + + idx = forecast.idx_lead_time( + pd.timedelta_range(start="1 day", periods=4).to_numpy()[::2] + ) + npt.assert_array_equal(idx, np.array([True, False, True, False]), strict=True) + + idx = forecast.idx_lead_time( + pd.timedelta_range(start="1 day", periods=4).to_numpy()[0] + ) + npt.assert_array_equal(idx, np.array([True, False, False, False]), strict=True) + + idx = forecast.idx_lead_time(None) + npt.assert_array_equal(idx, np.array([False, False, False, False]), strict=True)