diff --git a/CHANGELOG b/CHANGELOG index af31b26..4d058c2 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,3 +1,5 @@ +4.2.3 + - fix: cache ancillary params per Indentation curve (#27) 4.2.2 - ref: rewrite asserts to ValueErrors 4.2.1 diff --git a/src/nanite/indent.py b/src/nanite/indent.py index f8cceb1..1892fc2 100644 --- a/src/nanite/indent.py +++ b/src/nanite/indent.py @@ -32,6 +32,9 @@ def __init__(self, data, metadata, diskcache=None): # Curve rating (see `self.rate_quality`) self._rating = None + # ancillary param caching + self._anc_cache = None + @property def data(self): warnings.warn("Please use __getitem__ instead!", DeprecationWarning) @@ -248,6 +251,9 @@ def fit_model(self, **kwargs): # properties are the same. pass else: + # invalidate the cache + self._anc_cache = None + fitter = IndentationFitter(self) # Perform fitting # Note: if `fitter.fp["success"]` is `False`, then @@ -260,13 +266,22 @@ def fit_model(self, **kwargs): def get_ancillary_parameters(self, model_key=None): """Compute ancillary parameters for the current model""" + if self._anc_cache: + return self._anc_cache + if model_key is None: if "model_key" in self.fit_properties: model_key = self.fit_properties["model_key"] else: model_key = FP_DEFAULT["model_key"] - return model.compute_anc_parms(idnt=self, - model_key=model_key) + + anc = model.compute_anc_parms(idnt=self, + model_key=model_key) + # handle ancill cache + if self.fit_properties.get("success", False): + self._anc_cache = anc + + return anc def get_initial_fit_parameters(self, model_key=None, common_ancillaries=True, diff --git a/src/nanite/model/core.py b/src/nanite/model/core.py index fdb3d58..8c786fc 100644 --- a/src/nanite/model/core.py +++ b/src/nanite/model/core.py @@ -205,12 +205,8 @@ def compute_ancillaries(self, fd): ------- ancillaries: collections.OrderedDict key-value dictionary of ancillary parameters + """ - # TODO: - # - ancillaries are not cached yet (some ancillaries might depend on - # fitting interval or other initial parameters - take that into - # account) - # - "max_indent" actually belongs to "common_ancillaries" (see fit.py) anc_ord = OrderedDict() # general for key in ANCILLARY_COMMON: diff --git a/tests/test_fit_ancillary.py b/tests/test_fit_ancillary.py index a612478..197c4d5 100644 --- a/tests/test_fit_ancillary.py +++ b/tests/test_fit_ancillary.py @@ -8,7 +8,6 @@ from common import MockModelModule - data_path = pathlib.Path(__file__).parent / "data" jpkfile = data_path / "fmt-jpk-fd_spot3-0192.jpk-force" @@ -22,10 +21,10 @@ def compute_ancillaries(*args, **kwargs): raise ValueError("Not computed") with MockModelModule( - compute_ancillaries=compute_ancillaries, - parameter_anc_keys=["J"], - parameter_anc_names=["ancillary J guess"], - parameter_anc_units=["Pa"], + compute_ancillaries=compute_ancillaries, + parameter_anc_keys=["J"], + parameter_anc_names=["ancillary J guess"], + parameter_anc_units=["Pa"], model_key="test2"): # We need to perform preprocessing first, if we want to get the # correct initial contact point. @@ -45,10 +44,10 @@ def test_simple_ancillary_override(): idnt = ds1[0] with MockModelModule( - compute_ancillaries=lambda x: {"E": 1580}, - parameter_anc_keys=["E"], - parameter_anc_names=["ancillary E guess"], - parameter_anc_units=["Pa"], + compute_ancillaries=lambda x: {"E": 1580}, + parameter_anc_keys=["E"], + parameter_anc_names=["ancillary E guess"], + parameter_anc_units=["Pa"], model_key="test1"): # We need to perform preprocessing first, if we want to get the # correct initial contact point. @@ -71,10 +70,10 @@ def test_simple_ancillary_override_nan(): idnt = ds1[0] with MockModelModule( - compute_ancillaries=lambda x: {"E": np.nan}, - parameter_anc_keys=["E"], - parameter_anc_names=["ancillary E guess"], - parameter_anc_units=["Pa"], + compute_ancillaries=lambda x: {"E": np.nan}, + parameter_anc_keys=["E"], + parameter_anc_names=["ancillary E guess"], + parameter_anc_units=["Pa"], model_key="test2"): # We need to perform preprocessing first, if we want to get the # correct initial contact point. @@ -89,3 +88,48 @@ def test_simple_ancillary_override_nan(): 1584.8876592662375, atol=1, rtol=0) + + +def test_request_ancillary_parameters(): + """request the ancillary parameters after fitting""" + ds1 = nanite.IndentationGroup(jpkfile) + idnt = ds1[0] + + model_key = "test4" + with MockModelModule( + compute_ancillaries=lambda x: {"amazing_ancillary": 42.314}, + parameter_anc_keys=["amazing_ancillary"], + parameter_anc_names=["Amazing Ancillary"], + parameter_anc_units=["m/s"], + model_key=model_key): + # We need to perform preprocessing first, if we want to get the + # correct initial contact point. + idnt.apply_preprocessing(["compute_tip_position"]) + # We set the baseline fixed, because this test was written so) + params_initial = idnt.get_initial_fit_parameters(model_key=model_key) + params_initial["baseline"].set(vary=False) + # the old ancillary cache is invalidated just before fitting + idnt.fit_model(model_key=model_key, + params_initial=params_initial) + + # ancillary parameters are not yet requested + assert idnt._anc_cache is None + + # actually request the ancillary parameters, as done in pyjibe + anc = idnt.get_ancillary_parameters( + model_key=model_key) + + # check ancillaries + assert idnt._anc_cache is not None + assert idnt._anc_cache == anc + # max_indent is a common ancillary + assert np.allclose(anc["max_indent"], 3.669487775650337e-07, + atol=1e-10, rtol=0) + # new ancillary for this model + assert np.allclose(anc["amazing_ancillary"], 42.314, + atol=1, rtol=0) + + # check params_initial and params_fitted + assert idnt.fit_properties["params_initial"]["E"].value == 3000 + assert np.allclose(idnt.fit_properties["params_fitted"]["E"].value, + 1584.8876592662375, atol=1, rtol=0)