Description
Hi @alecheckert, I ran into some quirks when trying to subsample within a StateArrayDataset (SAD). Here's a code snippet:
```python
import os

import pandas as pd

import saspt
from saspt.dataset import StateArrayDataset

# Use the example trajectories that ship with the saspt repository
module_path = os.path.dirname(os.path.dirname(saspt.__file__))
sample_csv = os.path.join(module_path, "examples",
                          "u2os_ht_nls_7.48ms", "region_8_7ms_trajs.csv")

# A small sample_size forces random subsampling of each file's trajectories
settings = dict(pixel_size_um=0.16,
                frame_interval=0.00748,
                focal_depth=0.7,
                sample_size=10,
                progress_bar=True,
                likelihood_type='rbme',
                splitsize=10,
                start_frame=0)

# Three copies of the same file, all under a single condition
paths = dict(filepath=[sample_csv for _ in range(3)],
             condition=['test' for _ in range(3)])
SAD = StateArrayDataset.from_kwargs(pd.DataFrame(paths),
                                    path_col='filepath',
                                    condition_col='condition',
                                    **settings)

# Compare the per-file totals of the posterior occupancies with jumps_per_file
print("Sum of unnormalized posterior probabilities per file:",
      f"{SAD.posterior_occs.sum(axis=(1,2))}",
      sep="\n")
print("SAD.jumps_per_file attr:", SAD.jumps_per_file, sep="\n")
```
The problem is that a new random subsample is drawn every time StateArrayDataset._load_tracks is called. This can happen twice while using the object (unless the user calls clear()): once when calculating occupancies and again when computing the processed track statistics. jumps_per_file depends on the processed track statistics, so it is computed from a different subsample than posterior_occs and the two don't agree.
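To make the failure mode concrete, here's a minimal, self-contained toy sketch. It assumes only what's described above: each `_load_tracks()` call draws a fresh random subsample. `ToyDataset`, `occupancy_jumps` and `jumps_per_file` are hypothetical stand-ins, not saspt's actual internals. Both stand-ins count jumps, so they would agree if they saw the same trajectories.

```python
import random


class ToyDataset:
    """Hypothetical stand-in for StateArrayDataset (not saspt code).

    Every call to _load_tracks() draws a new random subsample,
    mimicking the behaviour described in this issue.
    """

    def __init__(self, trajectories, sample_size, seed=None):
        self.trajectories = trajectories
        self.sample_size = sample_size
        self.rng = random.Random(seed)

    def _load_tracks(self):
        # Fresh random subsample on every call
        k = min(self.sample_size, len(self.trajectories))
        return self.rng.sample(self.trajectories, k)

    def occupancy_jumps(self):
        # Stand-in for the occupancy calculation: first load
        return sum(len(t) - 1 for t in self._load_tracks())

    def jumps_per_file(self):
        # Stand-in for the processed track statistics: second, independent load
        return sum(len(t) - 1 for t in self._load_tracks())


# 100 toy trajectories with 1..10 jumps each, so subsamples differ in jump count
trajs = [[(0.0, 0.0)] * (2 + i % 10) for i in range(100)]
ds = ToyDataset(trajs, sample_size=10, seed=0)
print(ds.occupancy_jumps(), ds.jumps_per_file())  # usually two different numbers
```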
One solution would be to combine StateArrayDataset._get_processed_track_statistics() and StateArrayDataset.calc_marginal_posterior_occs() into a single function, so that the tracks are loaded and subsampled only once. Alternatively, the subsampled detections could be cached on the SAD object, though that could take up a lot of memory.
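A rough sketch of the "bundle" option, again on a hypothetical toy class rather than saspt itself: `compute_all()` calls `_load_tracks()` exactly once and derives both statistics from that one subsample. `BundledDataset` and `compute_all` are illustrative names only.

```python
import random


class BundledDataset:
    """Hypothetical sketch of the bundled fix (not saspt code)."""

    def __init__(self, trajectories, sample_size, seed=None):
        self.trajectories = trajectories
        self.sample_size = sample_size
        self.rng = random.Random(seed)

    def _load_tracks(self):
        # Still draws a fresh random subsample on every call
        k = min(self.sample_size, len(self.trajectories))
        return self.rng.sample(self.trajectories, k)

    def compute_all(self):
        # Single load, shared by both calculations below
        tracks = self._load_tracks()
        # Stand-ins for the occupancy and jumps_per_file calculations;
        # both count jumps, so they match when computed on the same tracks
        occupancy_jumps = sum(len(t) - 1 for t in tracks)
        jumps_per_file = sum(len(t) - 1 for t in tracks)
        return occupancy_jumps, jumps_per_file


trajs = [[(0.0, 0.0)] * (2 + i % 10) for i in range(100)]
ds = BundledDataset(trajs, sample_size=10, seed=0)
occ, jumps = ds.compute_all()
print(occ, jumps)  # always equal
```

Repeated calls to `compute_all()` still draw new subsamples, but both outputs are updated together. Caching the subsample on the object instead would keep results stable across calls, at the memory cost mentioned above.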
Happy to try to fix this; let me know what you think is the best way forward.