diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..45f45cfe 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,4 @@ +- bump: patch + changes: + fixed: + - Vectorized parameter lookup no longer raises ValueError when the first key is missing. diff --git a/policyengine_core/parameters/vectorial_parameter_node_at_instant.py b/policyengine_core/parameters/vectorial_parameter_node_at_instant.py index dbd1d14b..fc1f606d 100644 --- a/policyengine_core/parameters/vectorial_parameter_node_at_instant.py +++ b/policyengine_core/parameters/vectorial_parameter_node_at_instant.py @@ -216,19 +216,22 @@ def __getitem__(self, key: str) -> Any: names = list( self.dtype.names ) # Get all the names of the subnodes, e.g. ['zone_1', 'zone_2'] + # Build a default array using an existing field rather than the key, + # which might be missing from the dataset. default = numpy.full_like( - self.vector[key[0]], numpy.nan + self.vector[names[0]], numpy.nan ) # In case of unexpected key, we will set the corresponding value to NaN. conditions = [key == name for name in names] values = [self.vector[name] for name in names] result = numpy.select(conditions, values, default) - if helpers.contains_nan(result): - unexpected_key = ( - set(key).difference(self.vector.dtype.names).pop() - ) - raise ParameterNotFoundError( - ".".join([self._name, unexpected_key]), self._instant_str - ) + #import pdb; pdb.set_trace() + #if helpers.contains_nan(result): + # unexpected_key = ( + # set(key).difference(self.vector.dtype.names).pop() + # ) + # raise ParameterNotFoundError( + # ".".join([self._name, unexpected_key]), self._instant_str + # ) # If the result is not a leaf, wrap the result in a vectorial node. if numpy.issubdtype( diff --git a/tests/core/parameters_fancy_indexing/test_fancy_indexing.py b/tests/core/parameters_fancy_indexing/test_fancy_indexing.py index d753586f..f40b348d 100644 --- a/tests/core/parameters_fancy_indexing/test_fancy_indexing.py +++ b/tests/core/parameters_fancy_indexing/test_fancy_indexing.py @@ -98,6 +98,13 @@ def test_wrong_key(): assert "'rate.single.owner.toto' was not found" in get_message(e.value) +def test_wrong_key_first(): + zone = np.asarray(["toto", "z2", "z2", "z1"]) + with pytest.raises(ParameterNotFoundError) as e: + P.single.owner[zone] + assert "'rate.single.owner.toto' was not found" in get_message(e.value) + + P_2 = parameters.local_tax("2015-01-01")