diff --git a/climada/engine/test/test_impact_forecast.py b/climada/engine/test/test_impact_forecast.py index 6ada17777c..cc461b6101 100644 --- a/climada/engine/test/test_impact_forecast.py +++ b/climada/engine/test/test_impact_forecast.py @@ -92,16 +92,58 @@ def test_impact_forecast_from_impact( self.assert_impact_kwargs(impact_forecast, **impact_kwargs) -def test_impact_forecast_select(impact_forecast, lead_time, member, impact_kwargs): +@pytest.mark.parametrize( + "var, var_select", + [("event_id", "event_ids"), ("event_name", "event_names"), ("date", "dates")], +) +def test_impact_forecast_select_events( + impact_forecast, lead_time, member, impact_kwargs, var, var_select +): """Check if Impact.select works on the derived class""" - event_ids = impact_kwargs["event_id"][np.array([2, 0])] - impact_fc = impact_forecast.select(event_ids=event_ids) + select_mask = np.array([2, 1]) + ordered_select_mask = np.array([1, 2]) + if var == "date": + # Date needs to be a valid delta + select_mask = np.array([1, 2]) + ordered_select_mask = np.array([1, 2]) + + var_value = np.array(impact_kwargs[var])[select_mask] + # event_name is a list, convert to numpy array for indexing + impact_fc = impact_forecast.select(**{var_select: var_value}) # NOTE: Events keep their original order npt.assert_array_equal( - impact_fc.event_id, impact_forecast.event_id[np.array([0, 2])] + impact_fc.event_id, + impact_forecast.event_id[ordered_select_mask], + ) + npt.assert_array_equal( + impact_fc.event_name, + np.array(impact_forecast.event_name)[ordered_select_mask], + ) + npt.assert_array_equal(impact_fc.date, impact_forecast.date[ordered_select_mask]) + npt.assert_array_equal( + impact_fc.frequency, impact_forecast.frequency[ordered_select_mask] + ) + npt.assert_array_equal(impact_fc.member, member[ordered_select_mask]) + npt.assert_array_equal(impact_fc.lead_time, lead_time[ordered_select_mask]) + npt.assert_array_equal( + impact_fc.imp_mat.todense(), + impact_forecast.imp_mat.todense()[ordered_select_mask], + ) + + +def test_impact_forecast_select_exposure( + impact_forecast, lead_time, member, impact_kwargs +): + """Check if Impact.select works on the derived class""" + exp_col = 0 + select_mask = np.array([exp_col]) + coord_exp = impact_kwargs["coord_exp"][select_mask] + impact_fc = impact_forecast.select(coord_exp=coord_exp) + npt.assert_array_equal(impact_fc.member, member) + npt.assert_array_equal(impact_fc.lead_time, lead_time) + npt.assert_array_equal( + impact_fc.imp_mat.todense(), impact_forecast.imp_mat.todense()[:, exp_col] ) - npt.assert_array_equal(impact_fc.member, member[np.array([0, 2])]) - npt.assert_array_equal(impact_fc.lead_time, lead_time[np.array([0, 2])]) @pytest.mark.skip("Concat from base class does not work") diff --git a/climada/hazard/test/test_forecast.py b/climada/hazard/test/test_forecast.py index 5e975c2885..b102ee2d17 100644 --- a/climada/hazard/test/test_forecast.py +++ b/climada/hazard/test/test_forecast.py @@ -107,13 +107,46 @@ def test_hazard_forecast_concat(haz_fc, lead_time, member): npt.assert_array_equal(haz_fc_concat.member, np.concatenate([member, member])) -def test_hazard_forecast_select(haz_fc, lead_time, member): +@pytest.mark.parametrize( + "var, var_select", + [("event_id", "event_id"), ("event_name", "event_names"), ("date", "date")], +) +def test_hazard_forecast_select(haz_fc, lead_time, member, haz_kwargs, var, var_select): """Check if Hazard.select works on the derived class""" - haz_fc_select = haz_fc.select(event_id=[4, 1]) - # NOTE: Events keep their original order - npt.assert_array_equal(haz_fc_select.event_id, haz_fc.event_id[np.array([3, 0])]) - npt.assert_array_equal(haz_fc_select.member, member[np.array([3, 0])]) - npt.assert_array_equal(haz_fc_select.lead_time, lead_time[np.array([3, 0])]) + + select_mask = np.array([3, 2]) + ordered_select_mask = np.array([3, 2]) + if var == "date": + # Date needs to be a valid delta + select_mask = np.array([2, 3]) + ordered_select_mask = np.array([2, 3]) + + var_value = np.array(haz_kwargs[var])[select_mask] + # event_name is a list, convert to numpy array for indexing + haz_fc_sel = haz_fc.select(**{var_select: var_value}) + # Note: order is preserved + npt.assert_array_equal( + haz_fc_sel.event_id, + haz_fc.event_id[ordered_select_mask], + ) + npt.assert_array_equal( + haz_fc_sel.event_name, + np.array(haz_fc.event_name)[ordered_select_mask], + ) + npt.assert_array_equal(haz_fc_sel.date, haz_fc.date[ordered_select_mask]) + npt.assert_array_equal(haz_fc_sel.frequency, haz_fc.frequency[ordered_select_mask]) + npt.assert_array_equal(haz_fc_sel.member, member[ordered_select_mask]) + npt.assert_array_equal(haz_fc_sel.lead_time, lead_time[ordered_select_mask]) + npt.assert_array_equal( + haz_fc_sel.intensity.todense(), + haz_fc.intensity.todense()[ordered_select_mask], + ) + npt.assert_array_equal( + haz_fc_sel.fraction.todense(), + haz_fc.fraction.todense()[ordered_select_mask], + ) + + assert haz_fc_sel.centroids == haz_fc.centroids def test_write_read_hazard_forecast(haz_fc, tmp_path):