Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
4.2.3
- fix: cache ancillary params per Indentation curve (#27)
4.2.2
- ref: rewrite asserts to ValueErrors
4.2.1
Expand Down
19 changes: 17 additions & 2 deletions src/nanite/indent.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ def __init__(self, data, metadata, diskcache=None):
# Curve rating (see `self.rate_quality`)
self._rating = None

# ancillary param caching
self._anc_cache = None

@property
def data(self):
warnings.warn("Please use __getitem__ instead!", DeprecationWarning)
Expand Down Expand Up @@ -248,6 +251,9 @@ def fit_model(self, **kwargs):
# properties are the same.
pass
else:
# invalidate the cache
self._anc_cache = None

fitter = IndentationFitter(self)
# Perform fitting
# Note: if `fitter.fp["success"]` is `False`, then
Expand All @@ -260,13 +266,22 @@ def fit_model(self, **kwargs):

def get_ancillary_parameters(self, model_key=None):
"""Compute ancillary parameters for the current model"""
if self._anc_cache:
return self._anc_cache

if model_key is None:
if "model_key" in self.fit_properties:
model_key = self.fit_properties["model_key"]
else:
model_key = FP_DEFAULT["model_key"]
return model.compute_anc_parms(idnt=self,
model_key=model_key)

anc = model.compute_anc_parms(idnt=self,
model_key=model_key)
# handle ancill cache
if self.fit_properties.get("success", False):
self._anc_cache = anc

return anc

def get_initial_fit_parameters(self, model_key=None,
common_ancillaries=True,
Expand Down
6 changes: 1 addition & 5 deletions src/nanite/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,12 +205,8 @@ def compute_ancillaries(self, fd):
-------
ancillaries: collections.OrderedDict
key-value dictionary of ancillary parameters

"""
# TODO:
# - ancillaries are not cached yet (some ancillaries might depend on
# fitting interval or other initial parameters - take that into
# account)
# - "max_indent" actually belongs to "common_ancillaries" (see fit.py)
anc_ord = OrderedDict()
# general
for key in ANCILLARY_COMMON:
Expand Down
70 changes: 57 additions & 13 deletions tests/test_fit_ancillary.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from common import MockModelModule


data_path = pathlib.Path(__file__).parent / "data"
jpkfile = data_path / "fmt-jpk-fd_spot3-0192.jpk-force"

Expand All @@ -22,10 +21,10 @@ def compute_ancillaries(*args, **kwargs):
raise ValueError("Not computed")

with MockModelModule(
compute_ancillaries=compute_ancillaries,
parameter_anc_keys=["J"],
parameter_anc_names=["ancillary J guess"],
parameter_anc_units=["Pa"],
compute_ancillaries=compute_ancillaries,
parameter_anc_keys=["J"],
parameter_anc_names=["ancillary J guess"],
parameter_anc_units=["Pa"],
model_key="test2"):
# We need to perform preprocessing first, if we want to get the
# correct initial contact point.
Expand All @@ -45,10 +44,10 @@ def test_simple_ancillary_override():
idnt = ds1[0]

with MockModelModule(
compute_ancillaries=lambda x: {"E": 1580},
parameter_anc_keys=["E"],
parameter_anc_names=["ancillary E guess"],
parameter_anc_units=["Pa"],
compute_ancillaries=lambda x: {"E": 1580},
parameter_anc_keys=["E"],
parameter_anc_names=["ancillary E guess"],
parameter_anc_units=["Pa"],
model_key="test1"):
# We need to perform preprocessing first, if we want to get the
# correct initial contact point.
Expand All @@ -71,10 +70,10 @@ def test_simple_ancillary_override_nan():
idnt = ds1[0]

with MockModelModule(
compute_ancillaries=lambda x: {"E": np.nan},
parameter_anc_keys=["E"],
parameter_anc_names=["ancillary E guess"],
parameter_anc_units=["Pa"],
compute_ancillaries=lambda x: {"E": np.nan},
parameter_anc_keys=["E"],
parameter_anc_names=["ancillary E guess"],
parameter_anc_units=["Pa"],
model_key="test2"):
# We need to perform preprocessing first, if we want to get the
# correct initial contact point.
Expand All @@ -89,3 +88,48 @@ def test_simple_ancillary_override_nan():
1584.8876592662375,
atol=1,
rtol=0)


def test_request_ancillary_parameters():
"""request the ancillary parameters after fitting"""
ds1 = nanite.IndentationGroup(jpkfile)
idnt = ds1[0]

model_key = "test4"
with MockModelModule(
compute_ancillaries=lambda x: {"amazing_ancillary": 42.314},
parameter_anc_keys=["amazing_ancillary"],
parameter_anc_names=["Amazing Ancillary"],
parameter_anc_units=["m/s"],
model_key=model_key):
# We need to perform preprocessing first, if we want to get the
# correct initial contact point.
idnt.apply_preprocessing(["compute_tip_position"])
# We set the baseline fixed, because this test was written so)
params_initial = idnt.get_initial_fit_parameters(model_key=model_key)
params_initial["baseline"].set(vary=False)
# the old ancillary cache is invalidated just before fitting
idnt.fit_model(model_key=model_key,
params_initial=params_initial)

# ancillary parameters are not yet requested
assert idnt._anc_cache is None

# actually request the ancillary parameters, as done in pyjibe
anc = idnt.get_ancillary_parameters(
model_key=model_key)

# check ancillaries
assert idnt._anc_cache is not None
assert idnt._anc_cache == anc
# max_indent is a common ancillary
assert np.allclose(anc["max_indent"], 3.669487775650337e-07,
atol=1e-10, rtol=0)
# new ancillary for this model
assert np.allclose(anc["amazing_ancillary"], 42.314,
atol=1, rtol=0)

# check params_initial and params_fitted
assert idnt.fit_properties["params_initial"]["E"].value == 3000
assert np.allclose(idnt.fit_properties["params_fitted"]["E"].value,
1584.8876592662375, atol=1, rtol=0)