diff --git a/climada/engine/impact_forecast.py b/climada/engine/impact_forecast.py index 1406f4ae5f..2160b18da7 100644 --- a/climada/engine/impact_forecast.py +++ b/climada/engine/impact_forecast.py @@ -184,3 +184,59 @@ def _check_sizes(self): num_entries = len(self.event_id) size(exp_len=num_entries, var=self.member, var_name="Forecast.member") size(exp_len=num_entries, var=self.lead_time, var_name="Forecast.lead_time") + + def select( + self, + event_ids=None, + event_names=None, + dates=None, + coord_exp=None, + reset_frequency=False, + member=None, + lead_time=None, + ): + """Select entries based on the parameters and return a new instance. + The selection will contain the intersection of all given parameters. + + Parameters + ---------- + member : Sequence of ints + Ensemble members to select + lead_time : Sequence of numpy.timedelta64 + Lead times to select + + Returns + ------- + ImpactForecast + + See Also + -------- + :py:meth:`~climada.engine.impact.Impact.select` + """ + if member is not None or lead_time is not None: + mask_member = ( + self.idx_member(member) + if member is not None + else np.full_like(self.member, True, dtype=bool) + ) + mask_lead_time = ( + self.idx_lead_time(lead_time) + if lead_time is not None + else np.full_like(self.lead_time, True, dtype=bool) + ) + event_id_from_forecast_mask = np.asarray(self.event_id)[ + (mask_member & mask_lead_time) + ] + event_ids = ( + np.intersect1d(event_ids, event_id_from_forecast_mask) + if event_ids is not None + else event_id_from_forecast_mask + ) + + return super().select( + event_ids=event_ids, + event_names=event_names, + dates=dates, + coord_exp=coord_exp, + reset_frequency=reset_frequency, + ) diff --git a/climada/engine/test/test_impact_forecast.py b/climada/engine/test/test_impact_forecast.py index 33566acd5a..94655c8b17 100644 --- a/climada/engine/test/test_impact_forecast.py +++ b/climada/engine/test/test_impact_forecast.py @@ -92,58 +92,114 @@ def test_impact_forecast_from_impact( self.assert_impact_kwargs(impact_forecast, **impact_kwargs) -@pytest.mark.parametrize( - "var, var_select", - [("event_id", "event_ids"), ("event_name", "event_names"), ("date", "dates")], -) -def test_impact_forecast_select_events( - impact_forecast, lead_time, member, impact_kwargs, var, var_select -): - """Check if Impact.select works on the derived class""" - select_mask = np.array([2, 1]) - ordered_select_mask = np.array([1, 2]) - if var == "date": - # Date needs to be a valid delta - select_mask = np.array([1, 2]) - ordered_select_mask = np.array([1, 2]) +class TestSelect: - var_value = np.array(impact_kwargs[var])[select_mask] - # event_name is a list, convert to numpy array for indexing - impact_fc = impact_forecast.select(**{var_select: var_value}) - # NOTE: Events keep their original order - npt.assert_array_equal( - impact_fc.event_id, - impact_forecast.event_id[ordered_select_mask], - ) - npt.assert_array_equal( - impact_fc.event_name, - np.array(impact_forecast.event_name)[ordered_select_mask], - ) - npt.assert_array_equal(impact_fc.date, impact_forecast.date[ordered_select_mask]) - npt.assert_array_equal( - impact_fc.frequency, impact_forecast.frequency[ordered_select_mask] - ) - npt.assert_array_equal(impact_fc.member, member[ordered_select_mask]) - npt.assert_array_equal(impact_fc.lead_time, lead_time[ordered_select_mask]) - npt.assert_array_equal( - impact_fc.imp_mat.todense(), - impact_forecast.imp_mat.todense()[ordered_select_mask], + @pytest.mark.parametrize( + "var, var_select", + [("event_id", "event_ids"), ("event_name", "event_names"), ("date", "dates")], ) + def test_base_class_select( + self, impact_forecast, lead_time, member, impact_kwargs, var, var_select + ): + """Check if Impact.select works on the derived class""" + select_mask = np.array([2, 1]) + ordered_select_mask = np.array([1, 2]) + if var == "date": + # Date needs to be a valid delta + select_mask = np.array([1, 2]) + ordered_select_mask = np.array([1, 2]) + + var_value = np.array(impact_kwargs[var])[select_mask] + # event_name is a list, convert to numpy array for indexing + impact_fc = impact_forecast.select(**{var_select: var_value}) + # NOTE: Events keep their original order + npt.assert_array_equal( + impact_fc.event_id, + impact_forecast.event_id[ordered_select_mask], + ) + npt.assert_array_equal( + impact_fc.event_name, + np.array(impact_forecast.event_name)[ordered_select_mask], + ) + npt.assert_array_equal( + impact_fc.date, impact_forecast.date[ordered_select_mask] + ) + npt.assert_array_equal( + impact_fc.frequency, impact_forecast.frequency[ordered_select_mask] + ) + npt.assert_array_equal(impact_fc.member, member[ordered_select_mask]) + npt.assert_array_equal(impact_fc.lead_time, lead_time[ordered_select_mask]) + npt.assert_array_equal( + impact_fc.imp_mat.todense(), + impact_forecast.imp_mat.todense()[ordered_select_mask], + ) + def test_impact_forecast_select_exposure( + self, impact_forecast, lead_time, member, impact_kwargs + ): + """Check if Impact.select works on the derived class""" + exp_col = 0 + select_mask = np.array([exp_col]) + coord_exp = impact_kwargs["coord_exp"][select_mask] + impact_fc = impact_forecast.select(coord_exp=coord_exp) + npt.assert_array_equal(impact_fc.member, member) + npt.assert_array_equal(impact_fc.lead_time, lead_time) + npt.assert_array_equal( + impact_fc.imp_mat.todense(), impact_forecast.imp_mat.todense()[:, exp_col] + ) -def test_impact_forecast_select_exposure( - impact_forecast, lead_time, member, impact_kwargs -): - """Check if Impact.select works on the derived class""" - exp_col = 0 - select_mask = np.array([exp_col]) - coord_exp = impact_kwargs["coord_exp"][select_mask] - impact_fc = impact_forecast.select(coord_exp=coord_exp) - npt.assert_array_equal(impact_fc.member, member) - npt.assert_array_equal(impact_fc.lead_time, lead_time) - npt.assert_array_equal( - impact_fc.imp_mat.todense(), impact_forecast.imp_mat.todense()[:, exp_col] - ) + def test_derived_select_single(self, impact_forecast, lead_time, member): + imp_fc_select = impact_forecast.select(member=[2, 0]) + idx = np.array([0, 2]) + npt.assert_array_equal(imp_fc_select.event_id, impact_forecast.event_id[idx]) + npt.assert_array_equal(imp_fc_select.member, member[idx]) + npt.assert_array_equal(imp_fc_select.lead_time, lead_time[idx]) + + imp_fc_select = impact_forecast.select(lead_time=lead_time[np.array([2, 0])]) + npt.assert_array_equal(imp_fc_select.event_id, impact_forecast.event_id[idx]) + npt.assert_array_equal(imp_fc_select.member, member[idx]) + npt.assert_array_equal(imp_fc_select.lead_time, lead_time[idx]) + + def test_derived_select_intersections( + self, impact_forecast, lead_time, member, impact_kwargs + ): + imp_fc_select = impact_forecast.select(event_ids=[10, 14], member=[0, 1, 2]) + npt.assert_array_equal( + imp_fc_select.event_id, impact_forecast.event_id[np.array([0])] + ) + + imp_fc_select = impact_forecast.select( + event_ids=[10, 11, 13], member=[0, 1, 2], lead_time=lead_time[1:3] + ) + npt.assert_array_equal( + imp_fc_select.event_id, impact_forecast.event_id[np.array([1])] + ) + + # Test "outer" + impact_forecast2 = ImpactForecast( + lead_time=lead_time, + member=np.zeros_like(member, dtype="int"), + **impact_kwargs, + ) + imp_fc_select = impact_forecast2.select(event_ids=[10, 11, 13], member=[0]) + npt.assert_array_equal(imp_fc_select.event_id, [10, 11, 13]) + npt.assert_array_equal(imp_fc_select.member, [0, 0, 0]) + + def test_no_select(self, impact_forecast, impact_kwargs): + imp_fc_select = impact_forecast.select() + npt.assert_array_equal( + imp_fc_select.imp_mat.todense(), impact_forecast.imp_mat.todense() + ) + + num_centroids = len(impact_kwargs["coord_exp"]) + imp_fc_select = impact_forecast.select(event_names=["aaaaa", "foo"]) + assert imp_fc_select.imp_mat.shape == (0, num_centroids) + imp_fc_select = impact_forecast.select(event_ids=[-1, 1002]) + assert imp_fc_select.imp_mat.shape == (0, num_centroids) + imp_fc_select = impact_forecast.select(member=[-1]) + assert imp_fc_select.imp_mat.shape == (0, num_centroids) + imp_fc_select = impact_forecast.select(np.timedelta64("3", "Y")) + assert imp_fc_select.imp_mat.shape == (0, num_centroids) @pytest.mark.skip("Concat from base class does not work")