From 61fcf32b4d485d1557f0fa95ee7812fe5807b291 Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 10 Nov 2025 13:57:17 -0800 Subject: [PATCH 01/23] format: frontend css --- src/cala/gui/templates/index.html | 1 + src/cala/main.py | 8 +++++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/cala/gui/templates/index.html b/src/cala/gui/templates/index.html index 35d9531c..18ed7b2d 100644 --- a/src/cala/gui/templates/index.html +++ b/src/cala/gui/templates/index.html @@ -5,6 +5,7 @@ + diff --git a/src/cala/main.py b/src/cala/main.py index cd7004b7..16251372 100644 --- a/src/cala/main.py +++ b/src/cala/main.py @@ -1,3 +1,4 @@ +from datetime import datetime from pathlib import Path from typing import Annotated @@ -23,6 +24,11 @@ def main(spec: str, gui: Annotated[bool, typer.Option()] = False) -> None: if gui: uvicorn.run("cala.main:app", reload=False, reload_dirs=[Path(__file__).parent]) else: + start = datetime.now() tube = Tube.from_specification(spec) runner = SynchronousRunner(tube=tube) - runner.run() + try: + runner.run() + finally: + end = datetime.now() + print(f"Finished in {end - start}") From 6fdcb713222144da93c836411818bc355d43ff56 Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 10 Nov 2025 16:09:31 -0800 Subject: [PATCH 02/23] feat: add downsample --- src/cala/nodes/prep/__init__.py | 2 + src/cala/nodes/prep/downsample.py | 30 +++ tests/data/pipelines/long_recording.yaml | 222 +++++++++++++++++++++++ 3 files changed, 254 insertions(+) create mode 100644 src/cala/nodes/prep/downsample.py create mode 100644 tests/data/pipelines/long_recording.yaml diff --git a/src/cala/nodes/prep/__init__.py b/src/cala/nodes/prep/__init__.py index f0753b87..178b95c1 100644 --- a/src/cala/nodes/prep/__init__.py +++ b/src/cala/nodes/prep/__init__.py @@ -1,5 +1,6 @@ from .background_removal import remove_background from .denoise import Restore, blur +from .downsample import downsample from .flatten import butter from .glow_removal import GlowRemover from .lines import remove_freq, remove_mean @@ -19,4 +20,5 @@ "Restore", "package_frame", "counter", + "downsample", ] diff --git a/src/cala/nodes/prep/downsample.py b/src/cala/nodes/prep/downsample.py new file mode 100644 index 00000000..1b37f553 --- /dev/null +++ b/src/cala/nodes/prep/downsample.py @@ -0,0 +1,30 @@ +from typing import Annotated as A + +import numpy as np +from noob import Name + +from cala.assets import Frame +from cala.models import AXIS +from cala.nodes.prep import package_frame + + +def downsample( + frames: list[Frame], x_range: tuple[int, int], y_range: tuple[int, int], t_downsample: int = 1 +) -> A[Frame, Name("frame")]: + """ + Downsampling in time and cropping in space. Must be followed by gather node, and + t_downsample has to be same as gather's parameter n value. + + :param frames: + :param x_range: + :param y_range: + :param t_downsample: + :return: + """ + arrays = [] + for frame in frames: + arrays.append(frame.array[x_range[0] : x_range[1], y_range[0] : y_range[1]]) + + return package_frame( + np.mean(arrays, axis=0), arrays[-1][AXIS.frame_coord].item() // t_downsample + ) diff --git a/tests/data/pipelines/long_recording.yaml b/tests/data/pipelines/long_recording.yaml new file mode 100644 index 00000000..53c51c8f --- /dev/null +++ b/tests/data/pipelines/long_recording.yaml @@ -0,0 +1,222 @@ +noob_id: long-recording +noob_model: noob.tube.TubeSpecification +noob_version: 0.1.1.dev118+g64d81b7 + +assets: + buffer: + type: cala.assets.Buffer + params: + size: 100 + scope: runner + footprints: + type: cala.assets.Footprints + params: + sparsify: True + scope: runner + traces: + type: cala.assets.Traces + scope: runner + params: + # zarr_path: traces/ + peek_size: 100 + pix_stats: + type: cala.assets.PixStats + scope: runner + comp_stats: + type: cala.assets.CompStats + scope: runner + overlaps: + type: cala.assets.Overlaps + scope: runner + residuals: + type: cala.assets.Buffer + params: + size: 100 + scope: runner + + +nodes: + source: + type: cala.nodes.io.stream + params: + files: + - long_recording/1.avi + - long_recording/2.avi + - long_recording/3.avi + - long_recording/4.avi + - long_recording/5.avi + - long_recording/6.avi + counter: + type: cala.nodes.prep.counter + frame: + type: cala.nodes.prep.package_frame + depends: + - frame: source.value + - index: counter.idx + + #PREPROCESS BEGINS + hotpix: + type: cala.nodes.prep.blur + params: + method: median + kwargs: + ksize: 3 + depends: + - frame: frame.value + flatten: + type: cala.nodes.prep.butter + params: + kwargs: + cutoff_frequency_ratio: 0.010 + depends: + - frame: hotpix.frame + lines: + type: cala.nodes.prep.remove_mean + params: + orient: both + depends: + - frame: flatten.frame + motion: + type: cala.nodes.prep.Anchor + depends: + - frame: lines.frame + glow: + type: cala.nodes.prep.GlowRemover + depends: + - frame: motion.frame + gather: + type: gather + params: + n: 4 + depends: + - glow.frame + downsample: # 20fps -> 5fps + type: cala.nodes.prep.downsample + params: + x_range: [ 250, 350 ] + y_range: [ 250, 350 ] + t_downsample: 4 + depends: + - frames: gather.value + size_est: + type: cala.nodes.prep.SizeEst + params: + hardset_radius: 8 + depends: + - frame: downsample.frame + cache: + type: cala.nodes.buffer.fill_buffer + depends: + - buffer: assets.buffer + - frame: downsample.frame + #PREPROCESS ENDS + + # FRAME UPDATE BEGINS + trace_frame: + type: cala.nodes.traces.Tracer + params: + tol: 0.001 + max_iter: 100 + depends: + - traces: assets.traces + - footprints: assets.footprints + - frame: downsample.frame + - overlaps: assets.overlaps + pix_frame: + type: cala.nodes.pixel_stats.ingest_frame + depends: + - pixel_stats: assets.pix_stats + - frame: downsample.frame + - new_traces: trace_frame.latest_trace + - footprints: assets.footprints + comp_frame: + type: cala.nodes.component_stats.ingest_frame + depends: + - component_stats: assets.comp_stats + - frame: downsample.frame + - new_traces: trace_frame.latest_trace + footprints_frame: + type: cala.nodes.footprints.Footprinter + params: + bep: 0 + tol: 0.0001 + max_iter: 5 + ratio_lb: 0.10 + depends: + - footprints: assets.footprints + - pixel_stats: pix_frame.value + - component_stats: comp_frame.value + + residual: + type: cala.nodes.residual.Residuer + depends: + - frame: downsample.frame + - footprints: footprints_frame.footprints + - traces: assets.traces + - residuals: assets.residuals + # FRAME UPDATE ENDS + + # DETECT BEGINS + nmf: + type: cala.nodes.detect.SliceNMF + params: + min_frames: 100 + detect_thresh: 10.0 + reprod_tol: 0.005 + depends: + - residuals: residual.movie + - energy: residual.std + - detect_radius: size_est.radius + catalog: + type: cala.nodes.detect.Cataloger + params: + age_limit: 100 + shape_smooth_kwargs: + ksize: [ 5, 5 ] + sigmaX: 0 + trace_smooth_kwargs: + sigma: 2 + merge_threshold: 0.9 + val_threshold: 0.5 + cnt_threshold: 5 + depends: + - new_fps: nmf.new_fps + - new_trs: nmf.new_trs + - existing_fp: assets.footprints + - existing_tr: assets.traces + detect_update: + type: cala.nodes.detect.update_assets + depends: + - new_footprints: catalog.new_footprints + - new_traces: catalog.new_traces + - footprints: assets.footprints + - traces: assets.traces + - pixel_stats: assets.pix_stats + - component_stats: assets.comp_stats + - overlaps: assets.overlaps + - buffer: assets.buffer + # DETECT ENDS + +# merge: +# type: cala.nodes.merge.merge_existing +# params: +# merge_interval: 500 +# merge_threshold: 0.95 +# smooth_kwargs: +# sigma: 2 +# depends: +# - shapes: assets.footprints +# - traces: assets.traces +# - overlaps: assets.overlaps +# - trigger: detect_update.footprints +# merge_update: +# type: cala.nodes.detect.update_assets +# depends: +# - new_footprints: merge.footprints +# - new_traces: merge.traces +# - footprints: assets.footprints +# - traces: assets.traces +# - pixel_stats: assets.pix_stats +# - component_stats: assets.comp_stats +# - overlaps: assets.overlaps +# - buffer: assets.buffer From c0440d4dc7e2bc7afa5eec9e687125e5c4615e8d Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 10 Nov 2025 16:10:33 -0800 Subject: [PATCH 03/23] tests: benchmarking moved out of tests --- tests/test_prep/test_motion.py | 34 ---------------------------------- 1 file changed, 34 deletions(-) diff --git a/tests/test_prep/test_motion.py b/tests/test_prep/test_motion.py index 6d47a7ea..bff48895 100644 --- a/tests/test_prep/test_motion.py +++ b/tests/test_prep/test_motion.py @@ -64,37 +64,3 @@ def test_motion_estimation(params) -> None: # Allow 1 pixel absolute tolerance np.testing.assert_allclose(estimate, expected[1:], atol=1.0) - - -# def test_motion_with_movie(): -# """ -# For testing how well the motion correction performs with real movie -# """ -# gen = stream( -# [ -# "cala/msCam1.avi", -# "cala/msCam2.avi", -# "cala/msCam3.avi", -# "cala/msCam4.avi", -# "cala/msCam5.avi", -# ] -# ) -# -# fourcc = cv2.VideoWriter_fourcc(*"mp4v") -# out = cv2.VideoWriter("mc_test.avi", fourcc, 24.0, (752, 960)) -# -# stab = Anchor() -# -# for idx, arr in enumerate(gen): -# frame = package_frame(arr, idx) -# frame = blur(frame, method="median", kwargs={"ksize": 3}) -# frame = butter(frame, {}) -# frame = remove_mean(frame, orient="both") -# matched = stab.stabilize(frame) -# -# combined = np.concat([frame.array.values, matched.array.values], axis=0) -# -# frame_bgr = cv2.cvtColor(combined.astype(np.uint8), cv2.COLOR_GRAY2BGR) -# out.write(frame_bgr) -# -# out.release() From fbc4065094b4d80f1884901a5a142f09edab9a68 Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 10 Nov 2025 16:47:02 -0800 Subject: [PATCH 04/23] debug: retain older traces than peek_size --- src/cala/nodes/prep/__init__.py | 4 ++-- src/cala/nodes/traces.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/cala/nodes/prep/__init__.py b/src/cala/nodes/prep/__init__.py index 178b95c1..00f5376f 100644 --- a/src/cala/nodes/prep/__init__.py +++ b/src/cala/nodes/prep/__init__.py @@ -1,12 +1,12 @@ from .background_removal import remove_background from .denoise import Restore, blur -from .downsample import downsample +from .downsample import downsample # noqa from .flatten import butter from .glow_removal import GlowRemover from .lines import remove_freq, remove_mean from .motion import Anchor from .r_estimate import SizeEst -from .wrap import counter, package_frame +from .wrap import counter, package_frame # isort:skip __all__ = [ "blur", diff --git a/src/cala/nodes/traces.py b/src/cala/nodes/traces.py index 520886ce..ac8ed617 100644 --- a/src/cala/nodes/traces.py +++ b/src/cala/nodes/traces.py @@ -81,7 +81,7 @@ def ingest_frame( ) traces.update(updated_tr, append_dim=AXIS.frames_dim) else: - traces.array = xr.concat([traces.array, updated_traces], dim=AXIS.frames_dim) + traces.array = xr.concat([traces.full_array(), updated_traces], dim=AXIS.frames_dim) return PopSnap.from_array(updated_traces) From eb1ba0551816a0ad6fe6ee720367b65d6ad12491 Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 10 Nov 2025 19:17:10 -0800 Subject: [PATCH 05/23] feat: asset save / natsorting videos --- pdm.lock | 12 +++++- pyproject.toml | 1 + src/cala/assets.py | 24 ++++++------ src/cala/nodes/io.py | 50 ++++++++++++++++++++++-- tests/data/pipelines/long_recording.yaml | 14 +++++++ 5 files changed, 84 insertions(+), 17 deletions(-) diff --git a/pdm.lock b/pdm.lock index a3d86080..66aeef00 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev", "docs", "gui", "tests"] strategy = [] lock_version = "4.5.0" -content_hash = "sha256:68f2cd4f887a77af253a18c01590866b2fd460982b397836133b8fa09d808e17" +content_hash = "sha256:b791ad4b9d500b62299dfc14d6b10c6b75e16de030f5c33a1234c6b6c7372f5d" [[metadata.targets]] requires_python = ">=3.11" @@ -1715,6 +1715,16 @@ files = [ {file = "myst_parser-4.0.1.tar.gz", hash = "sha256:5cfea715e4f3574138aecbf7d54132296bfd72bb614d31168f48c477a830a7c4"}, ] +[[package]] +name = "natsort" +version = "8.4.0" +requires_python = ">=3.7" +summary = "Simple yet flexible natural sorting in Python." +files = [ + {file = "natsort-8.4.0-py3-none-any.whl", hash = "sha256:4732914fb471f56b5cce04d7bae6f164a592c7712e1c85f9ef585e197299521c"}, + {file = "natsort-8.4.0.tar.gz", hash = "sha256:45312c4a0e5507593da193dedd04abb1469253b601ecaf63445ad80f0a1ea581"}, +] + [[package]] name = "nbclient" version = "0.10.2" diff --git a/pyproject.toml b/pyproject.toml index dbdf603d..95c38602 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ "typer>=0.15.3", "xarray-validate>=0.0.2", "noob @ git+https://github.com/miniscope/noob.git", + "natsort>=8.4.0", ] keywords = [ "pipeline", diff --git a/src/cala/assets.py b/src/cala/assets.py index 74092b21..2a8168c5 100644 --- a/src/cala/assets.py +++ b/src/cala/assets.py @@ -22,6 +22,8 @@ class Asset(BaseModel): validate_schema: bool = False array_: AssetType = None sparsify: ClassVar[bool] = False + zarr_path: Path | None = None + """relative to config.user_data_dir""" _entity: ClassVar[Entity] model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True) @@ -61,6 +63,16 @@ def validate_array_schema(self) -> Self: return self + @field_validator("zarr_path", mode="after") + @classmethod + def validate_zarr_path(cls, value: Path | None) -> Path | None: + if value is None: + return value + zarr_dir = (config.user_dir / value).resolve() + zarr_dir.mkdir(parents=True, exist_ok=True) + clear_dir(zarr_dir) + return zarr_dir + class Footprint(Asset): _entity: ClassVar[Entity] = PrivateAttr( @@ -109,8 +121,6 @@ class Footprints(Asset): class Traces(Asset): - zarr_path: Path | None = None - """relative to config.user_data_dir""" peek_size: int | None = None """ Traces(array=array, path=path) -> saves to zarr (should be set in this asset, and leave @@ -182,16 +192,6 @@ def from_array( new_cls.array = array return new_cls - @field_validator("zarr_path", mode="after") - @classmethod - def validate_zarr_path(cls, value: Path | None) -> Path | None: - if value is None: - return value - zarr_dir = (config.user_dir / value).resolve() - zarr_dir.mkdir(parents=True, exist_ok=True) - clear_dir(zarr_dir) - return zarr_dir - @model_validator(mode="after") def check_zarr_setting(self) -> "Traces": if self.zarr_path: diff --git a/src/cala/nodes/io.py b/src/cala/nodes/io.py index a658e3a4..f55f999a 100644 --- a/src/cala/nodes/io.py +++ b/src/cala/nodes/io.py @@ -1,13 +1,17 @@ from abc import abstractmethod from collections.abc import Generator +from glob import glob from pathlib import Path -from typing import Protocol +from typing import Protocol, Literal import cv2 +from natsort import natsorted from numpy.typing import NDArray from skimage import io +from cala.assets import Asset from cala.config import config +from cala.util import clear_dir class Stream(Protocol): @@ -80,18 +84,30 @@ def stream_videos(video_paths: list[Path]) -> Generator[NDArray]: current_stream.close() -def stream(files: list[str | Path]) -> Generator[NDArray, None, None]: +def stream( + files: list[str | Path] = None, + subdir: str | Path = None, + extension: Literal[".avi"] = None, + prefix: str | None = None, +) -> Generator[NDArray, None, None]: """ Create a video stream from the provided video files. Args: files: List of file paths + subdir: Directory path. Ignored if files is populated. + extension: File extension. Ignored if files is populated. + prefix: File prefix. Ignored if files is populated. Returns: Stream: A stream that yields video frames """ - file_paths = [Path(f) if isinstance(f, str) else f for f in files] - suffix = {path.suffix.lower() for path in file_paths} + if files: + file_paths = [Path(f) if isinstance(f, str) else f for f in files] + suffix = {path.suffix.lower() for path in file_paths} + else: + files = natsort_paths(subdir, extension, prefix) + suffix = {extension} image_format = {".tif", ".tiff"} video_format = {".mp4", ".avi", ".webm"} @@ -102,3 +118,29 @@ def stream(files: list[str | Path]) -> Generator[NDArray, None, None]: yield from stream_images(files) else: raise ValueError(f"Unsupported file format: {suffix}") + + +def save_asset( + asset: Asset, target_epoch: int, curr_epoch: int, path: str | Path | None = None +) -> None: + if target_epoch == curr_epoch: + if path: + zarr_dir = (config.user_dir / path).resolve() + else: + zarr_dir = config.user_dir / asset.zarr_path + + zarr_dir.mkdir(parents=True, exist_ok=True) + clear_dir(zarr_dir) + try: + asset.full_array().to_zarr(zarr_dir) # for Traces + except AttributeError: + asset.array.to_zarr(zarr_dir, mode="w") + return None + + +def natsort_paths( + subdir: str | Path, extension: Literal[".avi"], prefix: str | None = None +) -> list[str]: + prefix = prefix or "" + video_dir = config.video_dir / subdir + return natsorted(glob(f"{str(video_dir)}/{prefix}*{extension}")) diff --git a/tests/data/pipelines/long_recording.yaml b/tests/data/pipelines/long_recording.yaml index 53c51c8f..1a1b47ff 100644 --- a/tests/data/pipelines/long_recording.yaml +++ b/tests/data/pipelines/long_recording.yaml @@ -197,6 +197,20 @@ nodes: - buffer: assets.buffer # DETECT ENDS + persist: + type: cala.nodes.io.save_asset + params: + target_epoch: 800 #5999 formula: 1000 * n_chunks - 1 + path: traces/ + depends: + - asset: assets.traces + - curr_epoch: counter.idx + + + return: + type: return + depends: downsample.frame + # merge: # type: cala.nodes.merge.merge_existing # params: From 72bebade6524136748ca8f3cedb00a0353c356f7 Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 10 Nov 2025 19:19:13 -0800 Subject: [PATCH 06/23] feat: long_recording.yaml --- tests/data/pipelines/long_recording.yaml | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/tests/data/pipelines/long_recording.yaml b/tests/data/pipelines/long_recording.yaml index 1a1b47ff..e9813798 100644 --- a/tests/data/pipelines/long_recording.yaml +++ b/tests/data/pipelines/long_recording.yaml @@ -39,13 +39,8 @@ nodes: source: type: cala.nodes.io.stream params: - files: - - long_recording/1.avi - - long_recording/2.avi - - long_recording/3.avi - - long_recording/4.avi - - long_recording/5.avi - - long_recording/6.avi + subdir: long_recording + extension: .avi counter: type: cala.nodes.prep.counter frame: @@ -161,7 +156,7 @@ nodes: type: cala.nodes.detect.SliceNMF params: min_frames: 100 - detect_thresh: 10.0 + detect_thresh: 8.0 reprod_tol: 0.005 depends: - residuals: residual.movie From f74853e2484e7ba430802044f8b6decd88b61c4f Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 10 Nov 2025 19:20:41 -0800 Subject: [PATCH 07/23] init_order --- src/cala/nodes/prep/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cala/nodes/prep/__init__.py b/src/cala/nodes/prep/__init__.py index 00f5376f..053af4cb 100644 --- a/src/cala/nodes/prep/__init__.py +++ b/src/cala/nodes/prep/__init__.py @@ -1,12 +1,12 @@ +from .wrap import counter, package_frame # isort:skip from .background_removal import remove_background from .denoise import Restore, blur -from .downsample import downsample # noqa from .flatten import butter from .glow_removal import GlowRemover from .lines import remove_freq, remove_mean from .motion import Anchor from .r_estimate import SizeEst -from .wrap import counter, package_frame # isort:skip +from .downsample import downsample # noqa __all__ = [ "blur", From dcc48b22369b7a8d3b1c7c42ef2d09cc94151232 Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 10 Nov 2025 19:47:16 -0800 Subject: [PATCH 08/23] format: ruff --- src/cala/nodes/io.py | 8 ++------ src/cala/nodes/prep/__init__.py | 4 ++-- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/cala/nodes/io.py b/src/cala/nodes/io.py index f55f999a..30e08993 100644 --- a/src/cala/nodes/io.py +++ b/src/cala/nodes/io.py @@ -2,7 +2,7 @@ from collections.abc import Generator from glob import glob from pathlib import Path -from typing import Protocol, Literal +from typing import Literal, Protocol import cv2 from natsort import natsorted @@ -124,11 +124,7 @@ def save_asset( asset: Asset, target_epoch: int, curr_epoch: int, path: str | Path | None = None ) -> None: if target_epoch == curr_epoch: - if path: - zarr_dir = (config.user_dir / path).resolve() - else: - zarr_dir = config.user_dir / asset.zarr_path - + zarr_dir = (config.user_dir / path).resolve() if path else config.user_dir / asset.zarr_path zarr_dir.mkdir(parents=True, exist_ok=True) clear_dir(zarr_dir) try: diff --git a/src/cala/nodes/prep/__init__.py b/src/cala/nodes/prep/__init__.py index 053af4cb..fe1b4188 100644 --- a/src/cala/nodes/prep/__init__.py +++ b/src/cala/nodes/prep/__init__.py @@ -1,12 +1,12 @@ -from .wrap import counter, package_frame # isort:skip from .background_removal import remove_background from .denoise import Restore, blur +from .downsample import downsample from .flatten import butter from .glow_removal import GlowRemover from .lines import remove_freq, remove_mean from .motion import Anchor from .r_estimate import SizeEst -from .downsample import downsample # noqa +from .wrap import counter, package_frame # noqa: I001 __all__ = [ "blur", From fa6301ad300e5ff1af2d1a197d522db81deee076 Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 10 Nov 2025 22:37:58 -0800 Subject: [PATCH 09/23] debug: handle sparse in zarr --- src/cala/nodes/io.py | 2 +- tests/data/pipelines/long_recording.yaml | 12 ++++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/cala/nodes/io.py b/src/cala/nodes/io.py index 30e08993..09a60f8e 100644 --- a/src/cala/nodes/io.py +++ b/src/cala/nodes/io.py @@ -130,7 +130,7 @@ def save_asset( try: asset.full_array().to_zarr(zarr_dir) # for Traces except AttributeError: - asset.array.to_zarr(zarr_dir, mode="w") + asset.array.as_numpy().to_zarr(zarr_dir, mode="w") return None diff --git a/tests/data/pipelines/long_recording.yaml b/tests/data/pipelines/long_recording.yaml index e9813798..318cc810 100644 --- a/tests/data/pipelines/long_recording.yaml +++ b/tests/data/pipelines/long_recording.yaml @@ -192,15 +192,23 @@ nodes: - buffer: assets.buffer # DETECT ENDS - persist: + save_trace: type: cala.nodes.io.save_asset params: - target_epoch: 800 #5999 formula: 1000 * n_chunks - 1 + target_epoch: 999 #5999 formula: 1000 * n_chunks - 1 path: traces/ depends: - asset: assets.traces - curr_epoch: counter.idx + save_shape: + type: cala.nodes.io.save_asset + params: + target_epoch: 999 #5999 formula: 1000 * n_chunks - 1 + path: footprints/ + depends: + - asset: assets.footprints + - curr_epoch: counter.idx return: type: return From 352ff5138b7a4a0c34b40fa683c5d7cf8ba3fe1c Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 10 Nov 2025 22:49:57 -0800 Subject: [PATCH 10/23] stop fucking switching the import order --- src/cala/nodes/prep/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cala/nodes/prep/__init__.py b/src/cala/nodes/prep/__init__.py index fe1b4188..ae7dafff 100644 --- a/src/cala/nodes/prep/__init__.py +++ b/src/cala/nodes/prep/__init__.py @@ -1,3 +1,4 @@ +from .wrap import counter, package_frame # noqa: I001 from .background_removal import remove_background from .denoise import Restore, blur from .downsample import downsample @@ -6,7 +7,6 @@ from .lines import remove_freq, remove_mean from .motion import Anchor from .r_estimate import SizeEst -from .wrap import counter, package_frame # noqa: I001 __all__ = [ "blur", From 1e1827cbc46dc91e8ce2c27292b5f6287cea1d32 Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 11 Nov 2025 03:27:59 -0800 Subject: [PATCH 11/23] feat: trace interval flushing to zarr functionality --- src/cala/assets.py | 180 +++++++++++++++++++++++++++---------------- tests/test_assets.py | 163 ++++++++++++++++++++++++++++----------- 2 files changed, 234 insertions(+), 109 deletions(-) diff --git a/src/cala/assets.py b/src/cala/assets.py index 2a8168c5..ac0dcc40 100644 --- a/src/cala/assets.py +++ b/src/cala/assets.py @@ -2,7 +2,7 @@ import shutil from copy import deepcopy from pathlib import Path -from typing import Any, ClassVar, Self, TypeVar +from typing import ClassVar, Self, TypeVar import numpy as np import xarray as xr @@ -48,6 +48,12 @@ def from_array(cls, array: xr.DataArray) -> Self: def reset(self) -> None: self.array_ = None + if self.zarr_path: + path = Path(self.zarr_path) + try: + shutil.rmtree(path) + except FileNotFoundError: + contextlib.suppress(FileNotFoundError) def __eq__(self, other: "Asset") -> bool: return self.array.equals(other.array) @@ -73,6 +79,22 @@ def validate_zarr_path(cls, value: Path | None) -> Path | None: clear_dir(zarr_dir) return zarr_dir + def load_zarr(self, isel_filter: dict = None, sel_filter: dict = None) -> xr.DataArray: + da = ( + xr.open_zarr(self.zarr_path) + .isel(isel_filter) + .sel(sel_filter) + .to_dataarray() + .drop_vars(["variable"]) + .isel(variable=0) + ) + return da.assign_coords( + { + AXIS.id_coord: lambda ds: da[AXIS.id_coord].astype(str), + AXIS.timestamp_coord: lambda ds: da[AXIS.timestamp_coord].astype(str), + } + ) + class Footprint(Asset): _entity: ClassVar[Entity] = PrivateAttr( @@ -121,92 +143,118 @@ class Footprints(Asset): class Traces(Asset): - peek_size: int | None = None + peek_size: int """ - Traces(array=array, path=path) -> saves to zarr (should be set in this asset, and leave - untouched in nodes.) - Traces.array -> loads from zarr + How many epochs to return when called. + """ + flush_interval: int | None = None + _entity: ClassVar[Entity] = PrivateAttr( + Group( + name="trace-group", + member=Trace.entity(), + group_by=Dims.component, + checks=[is_non_negative], + allow_extra_coords=False, + ) + ) + + @model_validator(mode="after") + def flush_conditions(self) -> Self: + assert (self.flush_interval and self.zarr_path) or ( + not self.flush_interval and not self.zarr_path + ), "zarr_path and flush_interval should either be both provided or neither." + if self.flush_interval: + assert self.flush_interval > self.peek_size, ( + f"flush_interval must be larger than peek_size. " + f"Provided: {self.flush_interval = }, {self.peek_size = }" + ) + return self + + @property + def sizes(self) -> dict[str, int]: + if self.zarr_path: + total_size = {} + for key, val in self.array_.sizes.items(): + if key == AXIS.frames_dim: + total_size[key] = val + self.load_zarr().sizes[key] + else: + total_size[key] = val + return total_size + else: + return self.array_.sizes @property def array(self) -> xr.DataArray: - peek_filter = {AXIS.frames_dim: slice(-self.peek_size, None)} if self.peek_size else None - return self.full_array(isel_filter=peek_filter) + return self.array_.isel({AXIS.frames_dim: slice(-self.peek_size, None)}) @array.setter def array(self, array: xr.DataArray) -> None: + """ + In case zarr_path is defined, if array is larger than peek_size, + the epochs older than -peek_size gets flushed to zarr array. + + """ + if self.validate_schema: + array.validate.against_schema(self._entity.model) if self.zarr_path: - if self.validate_schema: - array.validate.against_schema(self._entity.model) - array.to_zarr(self.zarr_path, mode="w") # need to make sure it can overwrite + self.array_ = array.isel({AXIS.frames_dim: slice(-self.peek_size, None)}) + array.isel({AXIS.frames_dim: slice(None, -self.peek_size)}).to_zarr( + self.zarr_path, mode="w" + ) else: self.array_ = array - def reset(self) -> None: - self.array_ = None - if self.zarr_path: - path = Path(self.zarr_path) - try: - shutil.rmtree(path) - except FileNotFoundError: - contextlib.suppress(FileNotFoundError) + def zarr_append(self, array: xr.DataArray, dim: str) -> None: + """ + Since we cannot simply append to zarr array in memory using xarray syntax, + we provide a convenience method for appending to zarr array and in-memory array + in a streamlined manner. - def full_array(self, isel_filter: dict = None, sel_filter: dict = None) -> xr.DataArray: - if self.zarr_path: - try: - return self.load_zarr(isel_filter=isel_filter, sel_filter=sel_filter).compute() - except FileNotFoundError: - pass - return ( - self.array_.isel(isel_filter).sel(sel_filter) - if self.array_ is not None - else self.array_ - ) + Incoming arrays have to be 2-dimensional. - def load_zarr(self, isel_filter: dict = None, sel_filter: dict = None) -> xr.DataArray: - da = ( - xr.open_zarr(self.zarr_path) - .isel(isel_filter) - .sel(sel_filter) - .to_dataarray() - .drop_vars(["variable"]) - .isel(variable=0) - ) - return da.assign_coords( - { - AXIS.id_coord: lambda ds: da[AXIS.id_coord].astype(str), - AXIS.timestamp_coord: lambda ds: da[AXIS.timestamp_coord].astype(str), - } - ) + """ - def update(self, array: xr.DataArray, **kwargs: Any) -> None: - if self.validate_schema: - array.validate.against_schema(self._entity.model) - array.to_zarr(self.zarr_path, **kwargs) + if dim == AXIS.frames_dim: + self.array_ = xr.concat([self.array_, array], dim=AXIS.frames_dim) + if self.array_.sizes[AXIS.frames_dim] > self.flush_interval: + self._flush_zarr() + elif dim == AXIS.component_dim: + self.array_ = xr.concat( + [self.array_, array.isel({AXIS.frames_dim: slice(-self.peek_size, None)})], dim=dim + ) + array.isel({AXIS.frames_dim: slice(None, -self.peek_size)}).to_zarr( + self.zarr_path, append_dim=dim + ) + + def _flush_zarr(self) -> None: + """ + Flushes traces older than peek_size to zarr array. + + """ + self.array_.isel({AXIS.frames_dim: slice(None, -self.peek_size)}).to_zarr( + self.zarr_path, append_dim=AXIS.frames_dim + ) + self.array_ = self.array_.isel({AXIS.frames_dim: slice(-self.peek_size, None)}) @classmethod - def from_array( - cls, array: xr.DataArray, zarr_path: Path | str | None = None, peek_size: int | None = None - ) -> "Traces": - new_cls = cls(zarr_path=zarr_path, peek_size=peek_size) + def from_array(cls, array: xr.DataArray) -> "Traces": + """ + This is only really used for typing / auto-validation purposes, + so we don't really have to worry about specifying the parameters. + + """ + new_cls = cls(peek_size=array.sizes[AXIS.frames_dim]) new_cls.array = array return new_cls - @model_validator(mode="after") - def check_zarr_setting(self) -> "Traces": + def full_array(self, isel_filter: dict = None, sel_filter: dict = None) -> xr.DataArray: if self.zarr_path: - assert self.peek_size, "peek_size must be set for zarr." - return self - - _entity: ClassVar[Entity] = PrivateAttr( - Group( - name="trace-group", - member=Trace.entity(), - group_by=Dims.component, - checks=[is_non_negative], - allow_extra_coords=False, - ) - ) + return xr.concat( + [self.load_zarr(isel_filter, sel_filter), self.array_], dim=AXIS.frames_dim + ).compute() + else: + return self.array_.isel(isel_filter).sel(sel_filter) class Movie(Asset): diff --git a/tests/test_assets.py b/tests/test_assets.py index 2897eb77..9f5ea726 100644 --- a/tests/test_assets.py +++ b/tests/test_assets.py @@ -1,6 +1,11 @@ -import os +""" +Since assets without zarr operations is very straightforward, +this test file is mainly focused on testing zarr-integrated assets, +with the exception of Buffer. + +""" + from datetime import datetime -from pathlib import Path import pytest import xarray as xr @@ -9,71 +14,143 @@ from cala.models import AXIS -@pytest.fixture -def path() -> Path: - return Path("assets") +@pytest.mark.parametrize("peek_size", [30, 49, 50, 51, 70]) +def test_array_assignment(tmp_path, four_connected_cells, peek_size): + """ + 1. array should get assigned to memory if smaller than or equal to peek_size + 2. array should get split to memory and zarr_path at peek_size if + array is larger than peek_size - -def test_assign_zarr(path, four_connected_cells): - zarr_traces = Traces(zarr_path=path, peek_size=100) + """ traces = four_connected_cells.traces.array + n_frames = traces.sizes[AXIS.frames_dim] # 50 frames + + zarr_traces = Traces( + zarr_path=tmp_path, peek_size=peek_size, flush_interval=max(1000, peek_size) + ) zarr_traces.array = traces - print(os.listdir(zarr_traces.zarr_path)) - assert zarr_traces.array_ is None # not in memory - assert zarr_traces.array.equals(traces) + assert zarr_traces.array_.sizes[AXIS.frames_dim] == min(n_frames, peek_size) + assert zarr_traces.load_zarr().sizes[AXIS.frames_dim] == max(0, n_frames - peek_size) -def test_from_array(four_connected_cells, path): +@pytest.mark.parametrize("peek_size", [30, 50, 70]) +def test_array_peek(tmp_path, four_connected_cells, peek_size): + """ + .array property returns correctly for peek_size smaller, equal, larger + than the saved array. + + """ traces = four_connected_cells.traces.array - zarr_traces = Traces.from_array(traces, path, peek_size=four_connected_cells.n_frames) - assert zarr_traces.array_ is None - assert zarr_traces.array.equals(traces) + n_frames = traces.sizes[AXIS.frames_dim] # 50 frames + + zarr_traces = Traces( + zarr_path=tmp_path, peek_size=peek_size, flush_interval=max(1000, peek_size) + ) + zarr_traces.array = traces + + assert zarr_traces.array.sizes[AXIS.frames_dim] == min(peek_size, n_frames) -@pytest.mark.parametrize("peek_shift", [-1, 0, 1]) -def test_peek(four_connected_cells, path, peek_shift): +@pytest.mark.parametrize("peek_size", [30, 50, 70]) +def test_flush_zarr(four_connected_cells, tmp_path, peek_size): + """ + _flush_zarr method flushes epochs older than peek_size to zarr_path, + and leaves only the newest epochs in memory. + + """ traces = four_connected_cells.traces.array - zarr_traces = Traces.from_array( - traces, path, peek_size=four_connected_cells.n_frames + peek_shift + n_frames = traces.sizes[AXIS.frames_dim] # 50 frames + + zarr_traces = Traces( + zarr_path=tmp_path, peek_size=peek_size, flush_interval=max(1000, peek_size) ) - if peek_shift >= 0: - assert zarr_traces.array.equals(traces) - else: - with pytest.raises(AssertionError): - assert zarr_traces.array.equals(traces) + + # need to initialize zarr file first to append with _flush_zarr + zarr_traces.array = traces[:, :0] + zarr_traces.array_ = traces # add 50 frames + + zarr_traces._flush_zarr() + # only peek_size left in memory + assert zarr_traces.array_.sizes[AXIS.frames_dim] == min(n_frames, peek_size) + # the rest is in zarr + assert zarr_traces.load_zarr().sizes[AXIS.frames_dim] == max(0, n_frames - peek_size) -def test_ingest_frame(path, four_connected_cells): +@pytest.mark.parametrize("peek_size, flush_interval", [(30, 70)]) +def test_zarr_append_frame(four_connected_cells, tmp_path, peek_size, flush_interval): + """ + Test that when in-memory array size hits flush_interval, old epochs + get flushed. + + """ traces = four_connected_cells.traces.array - old_traces = traces.isel({AXIS.frames_dim: slice(None, -1)}) - zarr_traces = Traces.from_array(old_traces, path, peek_size=four_connected_cells.n_frames) - new_traces = four_connected_cells.traces.array.isel({AXIS.frames_dim: [-1]}) + n_frames = traces.sizes[AXIS.frames_dim] # 50 frames - zarr_traces.update(new_traces, append_dim=AXIS.frames_dim) - # new_traces.to_zarr(zarr_traces.zarr_path, append_dim=AXIS.frames_dim) + zarr_traces = Traces(zarr_path=tmp_path, peek_size=peek_size, flush_interval=flush_interval) + zarr_traces.array = traces[:, :0] # just initializing zarr + + # array smaller than flush_interval. does not flush. + zarr_traces.zarr_append(traces, dim=AXIS.frames_dim) + assert zarr_traces.array_.sizes[AXIS.frames_dim] == n_frames + + # array larger than flush_interval. flushes down to peek_size. + zarr_traces.zarr_append(traces, dim=AXIS.frames_dim) + assert zarr_traces.array_.sizes[AXIS.frames_dim] == peek_size - assert zarr_traces.array.equals(traces) +@pytest.mark.parametrize("flush_interval", [30]) +def test_zarr_append_component(four_connected_cells, tmp_path, flush_interval): + """ + Test that when adding components, + (a) it gets appropriately divided between in-memory and zarr arrays. + (b) in-memory and zarr arrays can be concatenated together afterward. -def test_ingest_component(four_connected_cells, path): + """ traces = four_connected_cells.traces.array - old_traces = traces.isel({AXIS.component_dim: slice(None, -1)}) - zarr_traces = Traces.from_array(old_traces, path, peek_size=four_connected_cells.n_frames) - new_traces = four_connected_cells.traces.array.isel({AXIS.component_dim: [-1]}) - zarr_traces.update(new_traces, append_dim=AXIS.component_dim) - # new_traces.to_zarr(zarr_traces.zarr_path, append_dim=AXIS.component_dim) + zarr_traces = Traces(zarr_path=tmp_path, peek_size=20, flush_interval=flush_interval) + zarr_traces.array = traces[:-1, :] # forgot the last component! Also, got flushed. + zarr_traces.zarr_append(traces[-1:, :], dim=AXIS.component_dim) # appendee needs to be 2D + assert zarr_traces.array_[AXIS.component_dim].equals(traces[AXIS.component_dim]) + assert zarr_traces.load_zarr()[AXIS.component_dim].equals(traces[AXIS.component_dim]) + result = xr.concat([zarr_traces.load_zarr(), zarr_traces.array_], dim=AXIS.frames_dim).compute() + assert result.equals(traces) + + +def test_from_array(four_connected_cells): + """ + .from_array method can correctly reproduce the array with .array + + """ + traces = four_connected_cells.traces.array + zarr_traces = Traces.from_array(traces) assert zarr_traces.array.equals(traces) -def test_overwrite(four_connected_cells, four_separate_cells, path): - conn_traces = four_connected_cells.traces.array - zarr_traces = Traces.from_array(conn_traces, path, peek_size=four_connected_cells.n_frames) +@pytest.mark.parametrize("peek_size", [30, 50, 70]) +def test_sizes(four_connected_cells, tmp_path, peek_size): + """ + The sizes property of the asset combines the sizes + of the in-memory array and the zarr array. + + """ + traces = four_connected_cells.traces.array + + zarr_traces = Traces( + zarr_path=tmp_path, peek_size=peek_size, flush_interval=max(1000, peek_size) + ) + zarr_traces.array = traces + + assert zarr_traces.sizes == traces.sizes + + +@pytest.mark.xfail +def test_overwrite(four_connected_cells, four_separate_cells): + """ + test that zarr array can get overwritten. - sep_traces = four_separate_cells.traces.array - zarr_traces.array = sep_traces - assert zarr_traces.array.equals(sep_traces) + """ # two cases of init: From b4eb5ad65720de51ce1a83168850823b352a32f3 Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 11 Nov 2025 05:45:02 -0800 Subject: [PATCH 12/23] feat: trace mostly supporting new grammar (except zarr component update) --- src/cala/assets.py | 48 ++++++++++++++++++------ src/cala/nodes/traces.py | 67 ++++++++++++++++++---------------- tests/test_assets.py | 6 +-- tests/test_iter/test_traces.py | 45 ++++++++++++++++++++--- 4 files changed, 113 insertions(+), 53 deletions(-) diff --git a/src/cala/assets.py b/src/cala/assets.py index ac0dcc40..e8346e27 100644 --- a/src/cala/assets.py +++ b/src/cala/assets.py @@ -144,11 +144,22 @@ class Footprints(Asset): class Traces(Asset): peek_size: int + """How many epochs to return when called.""" + flush_interval: int | None = None + """How many epochs to wait until next flush""" + + _deprecated: list = PrivateAttr(default_factory=list) """ - How many epochs to return when called. - + Deprecated, or replaced component idx. + Since zarr does not support efficiently removing rows and columns, + there's no easy way to remove a column when a component has been + removed or replaced. Instead, we "mask" it with this "deprecated" + flag. + + When arrays are called, these are filtered out. When new epochs are + added, these are added in with nan values. """ - flush_interval: int | None = None + _entity: ClassVar[Entity] = PrivateAttr( Group( name="trace-group", @@ -205,7 +216,7 @@ def array(self, array: xr.DataArray) -> None: else: self.array_ = array - def zarr_append(self, array: xr.DataArray, dim: str) -> None: + def append(self, array: xr.DataArray, dim: str) -> None: """ Since we cannot simply append to zarr array in memory using xarray syntax, we provide a convenience method for appending to zarr array and in-memory array @@ -217,15 +228,22 @@ def zarr_append(self, array: xr.DataArray, dim: str) -> None: if dim == AXIS.frames_dim: self.array_ = xr.concat([self.array_, array], dim=AXIS.frames_dim) - if self.array_.sizes[AXIS.frames_dim] > self.flush_interval: - self._flush_zarr() + + if self.zarr_path: + if self.array_.sizes[AXIS.frames_dim] > self.flush_interval: + self._flush_zarr() + elif dim == AXIS.component_dim: - self.array_ = xr.concat( - [self.array_, array.isel({AXIS.frames_dim: slice(-self.peek_size, None)})], dim=dim - ) - array.isel({AXIS.frames_dim: slice(None, -self.peek_size)}).to_zarr( - self.zarr_path, append_dim=dim - ) + if self.zarr_path: + self.array_ = xr.concat( + [self.array_, array.isel({AXIS.frames_dim: slice(-self.peek_size, None)})], + dim=dim, + ) + array.isel({AXIS.frames_dim: slice(None, -self.peek_size)}).to_zarr( + self.zarr_path, append_dim=dim + ) + else: + self.array_ = xr.concat([self.array_, array], dim=dim, combine_attrs="drop") def _flush_zarr(self) -> None: """ @@ -237,6 +255,12 @@ def _flush_zarr(self) -> None: ) self.array_ = self.array_.isel({AXIS.frames_dim: slice(-self.peek_size, None)}) + def deprecate(self, del_mask: np.ndarray) -> None: + if self.zarr_path: + ... + else: + self.array_ = self.array_[~del_mask] + @classmethod def from_array(cls, array: xr.DataArray) -> "Traces": """ diff --git a/src/cala/nodes/traces.py b/src/cala/nodes/traces.py index ac8ed617..c115ab6f 100644 --- a/src/cala/nodes/traces.py +++ b/src/cala/nodes/traces.py @@ -79,9 +79,9 @@ def ingest_frame( updated_tr = updated_traces.volumize.dim_with_coords( dim=AXIS.frames_dim, coords=[AXIS.frame_coord, AXIS.timestamp_coord] ) - traces.update(updated_tr, append_dim=AXIS.frames_dim) + traces.append(updated_tr, dim=AXIS.frames_dim) else: - traces.array = xr.concat([traces.full_array(), updated_traces], dim=AXIS.frames_dim) + traces.append(updated_traces, dim=AXIS.frames_dim) return PopSnap.from_array(updated_traces) @@ -204,44 +204,47 @@ def ingest_component(traces: Traces, new_traces: Traces) -> Traces: :param new_traces: Can be either a newly registered trace or an updated existing one. """ - c = traces.full_array() - c_det = new_traces.array + c_new = new_traces.array - if c_det is None: + if c_new is None: return traces - if c is None: - traces.array = c_det + if traces.array is None: + traces.array = c_new return traces - if c.sizes[AXIS.frames_dim] > c_det.sizes[AXIS.frames_dim]: - # if newly detected cells are truncated, pad with np.nans - c_new = xr.DataArray( - np.full((c_det.sizes[AXIS.component_dim], c.sizes[AXIS.frames_dim]), np.nan), - dims=[AXIS.component_dim, AXIS.frames_dim], - coords=c.isel({AXIS.component_dim: 0}).coords, - ) - c_new[AXIS.id_coord] = c_det[AXIS.id_coord] - c_new[AXIS.detect_coord] = c_det[AXIS.detect_coord] - - c_new.loc[ - {AXIS.frames_dim: slice(c.sizes[AXIS.frames_dim] - c_det.sizes[AXIS.frames_dim], None)} - ] = c_det - else: - c_new = c_det + total_frames = traces.sizes[AXIS.frames_dim] + new_n_frames = c_new.sizes[AXIS.frames_dim] - merged_ids = c_det.attrs.get("replaces") + merged_ids = c_new.attrs.get("replaces") if merged_ids: - if traces.zarr_path: - invalid = c[AXIS.id_coord].isin(merged_ids) - traces.array = c.where(~invalid.compute(), drop=True).compute() - else: - intact_mask = ~np.isin(c[AXIS.id_coord].values, merged_ids) - c = c[intact_mask] + del_mask = np.isin(traces.array[AXIS.id_coord].values, merged_ids) + traces.deprecate(del_mask) - if traces.zarr_path: - traces.update(c_new, append_dim=AXIS.component_dim) + if total_frames > new_n_frames: + # if newly detected cells are truncated, pad with np.nans + c_pad = _pad_history(c_new, total_frames, np.nan) else: - traces.array = xr.concat([c, c_new], dim=AXIS.component_dim, combine_attrs="drop") + c_pad = c_new + + traces.append(c_pad, dim=AXIS.component_dim) return traces + + +def _pad_history(traces: xr.DataArray, total_nframes: int, value: float = np.nan) -> xr.DataArray: + """ + Pad unknown historical epochs with values... + + """ + new_nframes = traces.sizes[AXIS.frames_dim] + + c_new = xr.DataArray( + np.full((traces.sizes[AXIS.component_dim], total_nframes), value), + dims=[AXIS.component_dim, AXIS.frames_dim], + coords=traces[AXIS.component_dim].coords, + ) + + c_new.loc[{AXIS.frames_dim: slice(total_nframes - new_nframes, None)}] = traces + + return c_new diff --git a/tests/test_assets.py b/tests/test_assets.py index 9f5ea726..04a8f97a 100644 --- a/tests/test_assets.py +++ b/tests/test_assets.py @@ -90,11 +90,11 @@ def test_zarr_append_frame(four_connected_cells, tmp_path, peek_size, flush_inte zarr_traces.array = traces[:, :0] # just initializing zarr # array smaller than flush_interval. does not flush. - zarr_traces.zarr_append(traces, dim=AXIS.frames_dim) + zarr_traces.append(traces, dim=AXIS.frames_dim) assert zarr_traces.array_.sizes[AXIS.frames_dim] == n_frames # array larger than flush_interval. flushes down to peek_size. - zarr_traces.zarr_append(traces, dim=AXIS.frames_dim) + zarr_traces.append(traces, dim=AXIS.frames_dim) assert zarr_traces.array_.sizes[AXIS.frames_dim] == peek_size @@ -110,7 +110,7 @@ def test_zarr_append_component(four_connected_cells, tmp_path, flush_interval): zarr_traces = Traces(zarr_path=tmp_path, peek_size=20, flush_interval=flush_interval) zarr_traces.array = traces[:-1, :] # forgot the last component! Also, got flushed. - zarr_traces.zarr_append(traces[-1:, :], dim=AXIS.component_dim) # appendee needs to be 2D + zarr_traces.append(traces[-1:, :], dim=AXIS.component_dim) # appendee needs to be 2D assert zarr_traces.array_[AXIS.component_dim].equals(traces[AXIS.component_dim]) assert zarr_traces.load_zarr()[AXIS.component_dim].equals(traces[AXIS.component_dim]) diff --git a/tests/test_iter/test_traces.py b/tests/test_iter/test_traces.py index bd5057b6..9139f89e 100644 --- a/tests/test_iter/test_traces.py +++ b/tests/test_iter/test_traces.py @@ -18,15 +18,29 @@ def frame_update() -> Node: ) +@pytest.mark.parametrize( + "zarr_setup", + [ + {"zarr_path": None, "peek_size": 50, "flush_interval": None}, + {"zarr_path": "tmp", "peek_size": 30, "flush_interval": 40}, + ], +) @pytest.mark.parametrize("toy", ["four_separate_cells", "four_connected_cells"]) -def test_update_traces(frame_update, toy, request) -> None: +def test_ingest_frame(frame_update, toy, zarr_setup, request, tmp_path) -> None: + """ + Frame ingestion step adds new traces to the Traces instance + that matches true brightness, overlapping and non-overlapping alike. + + """ toy = request.getfixturevalue(toy) + zarr_setup["zarr_path"] = tmp_path if zarr_setup["zarr_path"] else None xray = Node.from_specification( spec=NodeSpecification(id="test", type="cala.nodes.overlap.initialize") ) - traces = Traces.from_array(toy.traces.array.isel({AXIS.frames_dim: slice(None, -1)})) + traces = Traces(array_=None, **zarr_setup) + traces.array = toy.traces.array.isel({AXIS.frames_dim: slice(None, -1)}) frame = Frame.from_array(toy.make_movie().array.isel({AXIS.frames_dim: -1})) overlap = xray.process(overlaps=Overlaps(), footprints=toy.footprints) @@ -46,14 +60,33 @@ def comp_update() -> Node: ) +@pytest.mark.parametrize( + "zarr_setup", + [ + {"zarr_path": None, "peek_size": 40, "flush_interval": None}, + {"zarr_path": "tmp", "peek_size": 30, "flush_interval": 40}, + ], +) @pytest.mark.parametrize("toy", ["four_separate_cells"]) -def test_ingest_component(comp_update, toy, request) -> None: +def test_ingest_component(comp_update, toy, request, zarr_setup, tmp_path) -> None: + """ + *New component is always the same length as peek_size. + + I can add components that... + 1. is single and new + 2. is single and replacing + 3. is multiple + 4. is multiple and replacing + + """ toy = request.getfixturevalue(toy) + zarr_setup["zarr_path"] = tmp_path if zarr_setup["zarr_path"] else None - traces = Traces.from_array(toy.traces.array.isel({AXIS.component_dim: slice(None, -1)})) + traces = Traces(array_=None, **zarr_setup) + traces.array = toy.traces.array.isel({AXIS.component_dim: slice(None, -1)}) new_traces = toy.traces.array.isel( - {AXIS.component_dim: [-1], AXIS.frames_dim: slice(-40, None)} + {AXIS.component_dim: [-1], AXIS.frames_dim: slice(-zarr_setup["peek_size"], None)} ) new_traces.attrs["replaces"] = ["cell_0"] @@ -63,4 +96,4 @@ def test_ingest_component(comp_update, toy, request) -> None: expected = toy.traces.array.drop_sel({AXIS.component_dim: 0}) expected.loc[{AXIS.component_dim: -1, AXIS.frames_dim: slice(None, 10)}] = np.nan - assert result.array.equals(expected) + assert result.array_.equals(expected) From 31d36384c0ceab8579f3fbad3bb7a42844c9a15b Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 11 Nov 2025 11:28:46 -0800 Subject: [PATCH 13/23] feat: implement zarr and in-memory caching in traces --- src/cala/assets.py | 51 +++++++++++++++++++++++++--------- src/cala/nodes/traces.py | 4 +-- tests/data/pipelines/odl.yaml | 4 +++ tests/test_assets.py | 18 ++++++++++++ tests/test_iter/test_traces.py | 11 ++++---- 5 files changed, 68 insertions(+), 20 deletions(-) diff --git a/src/cala/assets.py b/src/cala/assets.py index e8346e27..e423cf72 100644 --- a/src/cala/assets.py +++ b/src/cala/assets.py @@ -143,12 +143,12 @@ class Footprints(Asset): class Traces(Asset): - peek_size: int + peek_size: int = None """How many epochs to return when called.""" flush_interval: int | None = None """How many epochs to wait until next flush""" - _deprecated: list = PrivateAttr(default_factory=list) + _deprecated: list[str] = PrivateAttr(default_factory=list) """ Deprecated, or replaced component idx. Since zarr does not support efficiently removing rows and columns, @@ -197,7 +197,11 @@ def sizes(self) -> dict[str, int]: @property def array(self) -> xr.DataArray: - return self.array_.isel({AXIS.frames_dim: slice(-self.peek_size, None)}) + return ( + self.array_.isel({AXIS.frames_dim: slice(-self.peek_size, None)}) + if self.array_ is not None + else self.array_ + ) @array.setter def array(self, array: xr.DataArray) -> None: @@ -248,18 +252,37 @@ def append(self, array: xr.DataArray, dim: str) -> None: def _flush_zarr(self) -> None: """ Flushes traces older than peek_size to zarr array. + Needs to append nans to deprecated components, since they get deleted + in in-memory array, but persist in zarr array. + + Could do this much more elegantly by pre-allocating .array_ """ - self.array_.isel({AXIS.frames_dim: slice(None, -self.peek_size)}).to_zarr( - self.zarr_path, append_dim=AXIS.frames_dim - ) + raw_zarr = self.load_zarr() + to_flush = self.array_.isel({AXIS.frames_dim: slice(None, -self.peek_size)}) + if self._deprecated: + zarr_ids = raw_zarr[AXIS.id_coord].values + zarr_detects = raw_zarr[AXIS.detect_coord].values + intact_mask = ~np.isin(zarr_ids, self._deprecated) + n_flush = to_flush.sizes[AXIS.frames_dim] + prealloc = xr.DataArray( + np.full((raw_zarr.sizes[AXIS.component_dim], n_flush), np.nan), + dims=[AXIS.component_dim, AXIS.frames_dim], + coords={ + AXIS.id_coord: (AXIS.component_dim, zarr_ids), + AXIS.detect_coord: (AXIS.component_dim, zarr_detects), + }, + ).assign_coords(to_flush[AXIS.frames_dim].coords) + prealloc.loc[intact_mask] = to_flush + prealloc.to_zarr(self.zarr_path, append_dim=AXIS.frames_dim) + else: + to_flush.to_zarr(self.zarr_path, append_dim=AXIS.frames_dim) self.array_ = self.array_.isel({AXIS.frames_dim: slice(-self.peek_size, None)}) - def deprecate(self, del_mask: np.ndarray) -> None: + def keep(self, intact_mask: np.ndarray) -> None: if self.zarr_path: - ... - else: - self.array_ = self.array_[~del_mask] + self._deprecated.extend(self.array_[AXIS.id_coord].values[~intact_mask]) + self.array_ = self.array_[intact_mask] @classmethod def from_array(cls, array: xr.DataArray) -> "Traces": @@ -274,9 +297,11 @@ def from_array(cls, array: xr.DataArray) -> "Traces": def full_array(self, isel_filter: dict = None, sel_filter: dict = None) -> xr.DataArray: if self.zarr_path: - return xr.concat( - [self.load_zarr(isel_filter, sel_filter), self.array_], dim=AXIS.frames_dim - ).compute() + raw_zarr = self.load_zarr(isel_filter, sel_filter) + zarr_ids = raw_zarr[AXIS.id_coord].values + intact_mask = ~np.isin(zarr_ids, self._deprecated) + + return xr.concat([raw_zarr[intact_mask], self.array_], dim=AXIS.frames_dim).compute() else: return self.array_.isel(isel_filter).sel(sel_filter) diff --git a/src/cala/nodes/traces.py b/src/cala/nodes/traces.py index c115ab6f..e34f3bf8 100644 --- a/src/cala/nodes/traces.py +++ b/src/cala/nodes/traces.py @@ -218,8 +218,8 @@ def ingest_component(traces: Traces, new_traces: Traces) -> Traces: merged_ids = c_new.attrs.get("replaces") if merged_ids: - del_mask = np.isin(traces.array[AXIS.id_coord].values, merged_ids) - traces.deprecate(del_mask) + intact_mask = ~np.isin(traces.array[AXIS.id_coord].values, merged_ids) + traces.keep(intact_mask) if total_frames > new_n_frames: # if newly detected cells are truncated, pad with np.nans diff --git a/tests/data/pipelines/odl.yaml b/tests/data/pipelines/odl.yaml index b980e57a..a07a1cd3 100644 --- a/tests/data/pipelines/odl.yaml +++ b/tests/data/pipelines/odl.yaml @@ -15,6 +15,10 @@ assets: scope: runner traces: type: cala.assets.Traces + params: + zarr_path: traces/ + peek_size: 100 + flush_interval: 200 scope: runner pix_stats: type: cala.assets.PixStats diff --git a/tests/test_assets.py b/tests/test_assets.py index 04a8f97a..7b6b9b4a 100644 --- a/tests/test_assets.py +++ b/tests/test_assets.py @@ -7,6 +7,7 @@ from datetime import datetime +import numpy as np import pytest import xarray as xr @@ -118,6 +119,23 @@ def test_zarr_append_component(four_connected_cells, tmp_path, flush_interval): assert result.equals(traces) +@pytest.mark.parametrize("flush_interval", [30]) +def test_flush_after_deprecated(four_connected_cells, tmp_path, flush_interval) -> None: + traces = four_connected_cells.traces.array + peek_size = 20 + zarr_traces = Traces(zarr_path=tmp_path, peek_size=peek_size, flush_interval=flush_interval) + zarr_traces.array = traces + + merged_ids = zarr_traces.array[AXIS.id_coord].values[0] + intact_mask = ~np.isin(zarr_traces.array[AXIS.id_coord].values, merged_ids) + zarr_traces.keep(intact_mask) + zarr_traces.append(traces[intact_mask], dim=AXIS.frames_dim) + + assert zarr_traces.full_array().equals( + xr.concat([traces] * 2, dim=AXIS.frames_dim)[intact_mask] + ) + + def test_from_array(four_connected_cells): """ .from_array method can correctly reproduce the array with .array diff --git a/tests/test_iter/test_traces.py b/tests/test_iter/test_traces.py index 9139f89e..a84c1aa5 100644 --- a/tests/test_iter/test_traces.py +++ b/tests/test_iter/test_traces.py @@ -83,17 +83,18 @@ def test_ingest_component(comp_update, toy, request, zarr_setup, tmp_path) -> No zarr_setup["zarr_path"] = tmp_path if zarr_setup["zarr_path"] else None traces = Traces(array_=None, **zarr_setup) - traces.array = toy.traces.array.isel({AXIS.component_dim: slice(None, -1)}) + traces.array = toy.traces.array_.isel({AXIS.component_dim: slice(None, -1)}) - new_traces = toy.traces.array.isel( + new_traces = toy.traces.array_.isel( {AXIS.component_dim: [-1], AXIS.frames_dim: slice(-zarr_setup["peek_size"], None)} ) new_traces.attrs["replaces"] = ["cell_0"] - result = comp_update.process(traces, Traces.from_array(new_traces)) expected = toy.traces.array.drop_sel({AXIS.component_dim: 0}) - expected.loc[{AXIS.component_dim: -1, AXIS.frames_dim: slice(None, 10)}] = np.nan + expected.loc[ + {AXIS.component_dim: -1, AXIS.frames_dim: slice(None, -zarr_setup["peek_size"])} + ] = np.nan - assert result.array_.equals(expected) + assert result.full_array().equals(expected) From c13a3b6d3b7db1971180dcba28f7970394f6d4d9 Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 11 Nov 2025 11:36:30 -0800 Subject: [PATCH 14/23] format: ruff --- src/cala/assets.py | 6 +++--- src/cala/nodes/traces.py | 6 +----- tests/data/pipelines/long_recording.yaml | 3 ++- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/src/cala/assets.py b/src/cala/assets.py index e423cf72..89ea5d38 100644 --- a/src/cala/assets.py +++ b/src/cala/assets.py @@ -233,9 +233,9 @@ def append(self, array: xr.DataArray, dim: str) -> None: if dim == AXIS.frames_dim: self.array_ = xr.concat([self.array_, array], dim=AXIS.frames_dim) - if self.zarr_path: - if self.array_.sizes[AXIS.frames_dim] > self.flush_interval: - self._flush_zarr() + is_overflowing = self.array_.sizes[AXIS.frames_dim] > self.flush_interval + if self.zarr_path and is_overflowing: + self._flush_zarr() elif dim == AXIS.component_dim: if self.zarr_path: diff --git a/src/cala/nodes/traces.py b/src/cala/nodes/traces.py index e34f3bf8..f8bcafcc 100644 --- a/src/cala/nodes/traces.py +++ b/src/cala/nodes/traces.py @@ -221,11 +221,7 @@ def ingest_component(traces: Traces, new_traces: Traces) -> Traces: intact_mask = ~np.isin(traces.array[AXIS.id_coord].values, merged_ids) traces.keep(intact_mask) - if total_frames > new_n_frames: - # if newly detected cells are truncated, pad with np.nans - c_pad = _pad_history(c_new, total_frames, np.nan) - else: - c_pad = c_new + c_pad = _pad_history(c_new, total_frames, np.nan) if total_frames > new_n_frames else c_new traces.append(c_pad, dim=AXIS.component_dim) diff --git a/tests/data/pipelines/long_recording.yaml b/tests/data/pipelines/long_recording.yaml index 318cc810..5ee8152e 100644 --- a/tests/data/pipelines/long_recording.yaml +++ b/tests/data/pipelines/long_recording.yaml @@ -17,8 +17,9 @@ assets: type: cala.assets.Traces scope: runner params: - # zarr_path: traces/ + zarr_path: traces/ peek_size: 100 + flush_interval: 1000 pix_stats: type: cala.assets.PixStats scope: runner From acd7cb4e8c3e040060597d9f4b5f53922837ac26 Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 11 Nov 2025 11:43:37 -0800 Subject: [PATCH 15/23] debug: new component concat --- src/cala/assets.py | 5 ++-- .../pipelines/{with_src.yaml => minian.yaml} | 23 ++++++++++--------- 2 files changed, 15 insertions(+), 13 deletions(-) rename tests/data/pipelines/{with_src.yaml => minian.yaml} (93%) diff --git a/src/cala/assets.py b/src/cala/assets.py index 89ea5d38..62f0df83 100644 --- a/src/cala/assets.py +++ b/src/cala/assets.py @@ -239,11 +239,12 @@ def append(self, array: xr.DataArray, dim: str) -> None: elif dim == AXIS.component_dim: if self.zarr_path: + n_in_memory = self.array_.sizes[AXIS.frames_dim] self.array_ = xr.concat( - [self.array_, array.isel({AXIS.frames_dim: slice(-self.peek_size, None)})], + [self.array_, array.isel({AXIS.frames_dim: slice(-n_in_memory, None)})], dim=dim, ) - array.isel({AXIS.frames_dim: slice(None, -self.peek_size)}).to_zarr( + array.isel({AXIS.frames_dim: slice(None, -n_in_memory)}).to_zarr( self.zarr_path, append_dim=dim ) else: diff --git a/tests/data/pipelines/with_src.yaml b/tests/data/pipelines/minian.yaml similarity index 93% rename from tests/data/pipelines/with_src.yaml rename to tests/data/pipelines/minian.yaml index 3595697f..d66895d9 100644 --- a/tests/data/pipelines/with_src.yaml +++ b/tests/data/pipelines/minian.yaml @@ -1,4 +1,4 @@ -noob_id: cala-with-movie +noob_id: with-minian noob_model: noob.tube.TubeSpecification noob_version: 0.1.1.dev118+g64d81b7 @@ -17,8 +17,9 @@ assets: type: cala.assets.Traces scope: runner params: - # zarr_path: traces/ + zarr_path: traces/ peek_size: 100 + flush_interval: 1000 pix_stats: type: cala.assets.PixStats scope: runner @@ -41,15 +42,15 @@ nodes: params: files: - minian/msCam1.avi - - minian/msCam2.avi - - minian/msCam3.avi - - minian/msCam4.avi - - minian/msCam5.avi - - minian/msCam6.avi - - minian/msCam7.avi - - minian/msCam8.avi - - minian/msCam9.avi - - minian/msCam10.avi + # - minian/msCam2.avi + # - minian/msCam3.avi + # - minian/msCam4.avi + # - minian/msCam5.avi + # - minian/msCam6.avi + # - minian/msCam7.avi + # - minian/msCam8.avi + # - minian/msCam9.avi + # - minian/msCam10.avi counter: type: cala.nodes.prep.counter frame: From f2c0c0b13f1010728a14465b833bc7dbe1b00333 Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 11 Nov 2025 11:46:51 -0800 Subject: [PATCH 16/23] debug: new component concat --- tests/data/pipelines/minian.yaml | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/data/pipelines/minian.yaml b/tests/data/pipelines/minian.yaml index d66895d9..32eb17b4 100644 --- a/tests/data/pipelines/minian.yaml +++ b/tests/data/pipelines/minian.yaml @@ -42,15 +42,15 @@ nodes: params: files: - minian/msCam1.avi - # - minian/msCam2.avi - # - minian/msCam3.avi - # - minian/msCam4.avi - # - minian/msCam5.avi - # - minian/msCam6.avi - # - minian/msCam7.avi - # - minian/msCam8.avi - # - minian/msCam9.avi - # - minian/msCam10.avi + - minian/msCam2.avi + - minian/msCam3.avi + - minian/msCam4.avi + - minian/msCam5.avi + - minian/msCam6.avi + - minian/msCam7.avi + - minian/msCam8.avi + - minian/msCam9.avi + - minian/msCam10.avi counter: type: cala.nodes.prep.counter frame: From c4b8b7bb6633ba45f00a2689bf8ae948ce5956d7 Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 11 Nov 2025 14:52:33 -0800 Subject: [PATCH 17/23] debug: update noob for gather compatibility --- pdm.lock | 7 ++++--- pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/pdm.lock b/pdm.lock index 66aeef00..16127fbd 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev", "docs", "gui", "tests"] strategy = [] lock_version = "4.5.0" -content_hash = "sha256:b791ad4b9d500b62299dfc14d6b10c6b75e16de030f5c33a1234c6b6c7372f5d" +content_hash = "sha256:d02754f2363ec34c50db0fe7d4a94b017e21816863471a002e007ed0fd85e128" [[metadata.targets]] requires_python = ">=3.11" @@ -1821,10 +1821,11 @@ files = [ [[package]] name = "noob" -version = "0.1.1.dev208" +version = "0.1.1.dev209" requires_python = ">=3.11" git = "https://github.com/miniscope/noob.git" -revision = "6f9512e57e10e142335ccddf64a599d85edc73d6" +ref = "scheduler-optimize" +revision = "1cafe617372795c5bfe9c6953416b309572e1676" summary = "Default template for PDM package" dependencies = [ "PyYAML>=6.0.2", diff --git a/pyproject.toml b/pyproject.toml index 95c38602..ebd58be2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ dependencies = [ "pyyaml>=6.0.2", "typer>=0.15.3", "xarray-validate>=0.0.2", - "noob @ git+https://github.com/miniscope/noob.git", + "noob @ git+https://github.com/miniscope/noob.git@scheduler-optimize", "natsort>=8.4.0", ] keywords = [ From 2c0e10ce260ec7fd45be2e73ad44d523ac88ba77 Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 11 Nov 2025 14:52:47 -0800 Subject: [PATCH 18/23] debug: asset zarr saving --- src/cala/nodes/io.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/src/cala/nodes/io.py b/src/cala/nodes/io.py index 09a60f8e..7b35cbd3 100644 --- a/src/cala/nodes/io.py +++ b/src/cala/nodes/io.py @@ -11,7 +11,6 @@ from cala.assets import Asset from cala.config import config -from cala.util import clear_dir class Stream(Protocol): @@ -120,18 +119,14 @@ def stream( raise ValueError(f"Unsupported file format: {suffix}") -def save_asset( - asset: Asset, target_epoch: int, curr_epoch: int, path: str | Path | None = None -) -> None: +def save_asset(asset: Asset, target_epoch: int, curr_epoch: int, path: str | Path) -> Asset: if target_epoch == curr_epoch: - zarr_dir = (config.user_dir / path).resolve() if path else config.user_dir / asset.zarr_path - zarr_dir.mkdir(parents=True, exist_ok=True) - clear_dir(zarr_dir) + zarr_dir = config.user_dir try: - asset.full_array().to_zarr(zarr_dir) # for Traces + asset.full_array().to_zarr(zarr_dir / path, mode="w") # for Traces except AttributeError: - asset.array.as_numpy().to_zarr(zarr_dir, mode="w") - return None + asset.array.as_numpy().to_zarr(zarr_dir / path, mode="w") + return asset def natsort_paths( From f06ce84fcce829c0ca434ae09f16f234285ce330 Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 11 Nov 2025 20:08:20 -0800 Subject: [PATCH 19/23] test: motion correction crisp score --- src/cala/testing/util.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/src/cala/testing/util.py b/src/cala/testing/util.py index 02f87b37..1d8b6d70 100644 --- a/src/cala/testing/util.py +++ b/src/cala/testing/util.py @@ -110,3 +110,26 @@ def expand_boundary(footprints: xr.DataArray) -> xr.DataArray: vectorize=True, dask="parallelized", ) + + +def total_gradient_magnitude(I: np.ndarray) -> float: + """ + Calculates the total gradient magnitude c(I) = || |nabla I| ||_F for a 2D array I. + + This function computes the Frobenius norm of the pixel-wise gradient magnitude map, + which is a common measure of total image variation or signal energy. + + Args: + I (np.ndarray): The 2D input array (e.g., an image). + + Returns: + float: The Frobenius norm of the gradient magnitude map. + """ + if I.ndim != 2: + raise ValueError("Input array I must be 2-dimensional.") + + grad_y, grad_x = np.gradient(I) + grad_magnitude_map = np.sqrt(grad_x**2 + grad_y**2) + c_I = np.linalg.norm(grad_magnitude_map, ord="fro") + + return c_I From 4f9b08eb784f061c2e4877819c0a7e1cf753264b Mon Sep 17 00:00:00 2001 From: Raymond Date: Wed, 12 Nov 2025 10:49:02 -0800 Subject: [PATCH 20/23] debug: allow non-zarr --- src/cala/assets.py | 3 +-- tests/data/pipelines/long_recording.yaml | 12 ++++-------- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/src/cala/assets.py b/src/cala/assets.py index 62f0df83..1e1f2663 100644 --- a/src/cala/assets.py +++ b/src/cala/assets.py @@ -233,8 +233,7 @@ def append(self, array: xr.DataArray, dim: str) -> None: if dim == AXIS.frames_dim: self.array_ = xr.concat([self.array_, array], dim=AXIS.frames_dim) - is_overflowing = self.array_.sizes[AXIS.frames_dim] > self.flush_interval - if self.zarr_path and is_overflowing: + if self.zarr_path and self.array_.sizes[AXIS.frames_dim] > self.flush_interval: self._flush_zarr() elif dim == AXIS.component_dim: diff --git a/tests/data/pipelines/long_recording.yaml b/tests/data/pipelines/long_recording.yaml index 5ee8152e..4e3a3970 100644 --- a/tests/data/pipelines/long_recording.yaml +++ b/tests/data/pipelines/long_recording.yaml @@ -196,8 +196,8 @@ nodes: save_trace: type: cala.nodes.io.save_asset params: - target_epoch: 999 #5999 formula: 1000 * n_chunks - 1 - path: traces/ + target_epoch: 855999 # formula: 1000 * n_chunks - 1 + path: traces_fin/ depends: - asset: assets.traces - curr_epoch: counter.idx @@ -205,16 +205,12 @@ nodes: save_shape: type: cala.nodes.io.save_asset params: - target_epoch: 999 #5999 formula: 1000 * n_chunks - 1 - path: footprints/ + target_epoch: 855999 # formula: 1000 * n_chunks - 1 + path: footprints_fin/ depends: - asset: assets.footprints - curr_epoch: counter.idx - return: - type: return - depends: downsample.frame - # merge: # type: cala.nodes.merge.merge_existing # params: From 39bf41dc3c2a0ae8af9f9313b41dd536c90cb6a8 Mon Sep 17 00:00:00 2001 From: Raymond Date: Wed, 12 Nov 2025 11:37:00 -0800 Subject: [PATCH 21/23] feat: continue for gui failure --- src/cala/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cala/main.py b/src/cala/main.py index 16251372..7038d2bd 100644 --- a/src/cala/main.py +++ b/src/cala/main.py @@ -15,7 +15,7 @@ try: app = get_app() -except TypeError as e: +except (TypeError, RuntimeError) as e: logger.warning(f"Failed to load gui app: {e}") From 17b2ccfaacda28fbfb169b13f4121f20b4553730 Mon Sep 17 00:00:00 2001 From: Raymond Date: Fri, 21 Nov 2025 11:56:20 -0800 Subject: [PATCH 22/23] format: ruff --- src/cala/testing/util.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/cala/testing/util.py b/src/cala/testing/util.py index 1d8b6d70..7e73069e 100644 --- a/src/cala/testing/util.py +++ b/src/cala/testing/util.py @@ -112,7 +112,7 @@ def expand_boundary(footprints: xr.DataArray) -> xr.DataArray: ) -def total_gradient_magnitude(I: np.ndarray) -> float: +def total_gradient_magnitude(image: np.ndarray) -> float: """ Calculates the total gradient magnitude c(I) = || |nabla I| ||_F for a 2D array I. @@ -120,15 +120,15 @@ def total_gradient_magnitude(I: np.ndarray) -> float: which is a common measure of total image variation or signal energy. Args: - I (np.ndarray): The 2D input array (e.g., an image). + image (np.ndarray): The 2D input array (e.g., an image). Returns: float: The Frobenius norm of the gradient magnitude map. """ - if I.ndim != 2: + if image.ndim != 2: raise ValueError("Input array I must be 2-dimensional.") - grad_y, grad_x = np.gradient(I) + grad_y, grad_x = np.gradient(image) grad_magnitude_map = np.sqrt(grad_x**2 + grad_y**2) c_I = np.linalg.norm(grad_magnitude_map, ord="fro") From 0724bd1ec765636f0903338b0f562e17fb5642be Mon Sep 17 00:00:00 2001 From: Raymond Date: Fri, 21 Nov 2025 13:01:28 -0800 Subject: [PATCH 23/23] rename: detect to segment --- .pre-commit-config.yaml | 2 +- src/cala/nodes/cleanup.py | 16 ---------------- src/cala/nodes/merge.py | 2 +- src/cala/nodes/{detect => segment}/__init__.py | 0 src/cala/nodes/{detect => segment}/catalog.py | 0 src/cala/nodes/{detect => segment}/slice_nmf.py | 0 src/cala/nodes/{detect => segment}/update.py | 0 tests/data/pipelines/long_recording.yaml | 8 ++++---- tests/data/pipelines/minian.yaml | 8 ++++---- tests/data/pipelines/odl.yaml | 8 ++++---- tests/test_iter/test_catalog.py | 8 ++++---- tests/test_iter/test_slice_nmf.py | 12 ++++++------ 12 files changed, 24 insertions(+), 40 deletions(-) rename src/cala/nodes/{detect => segment}/__init__.py (100%) rename src/cala/nodes/{detect => segment}/catalog.py (100%) rename src/cala/nodes/{detect => segment}/slice_nmf.py (100%) rename src/cala/nodes/{detect => segment}/update.py (100%) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1b1171cc..59ac6b0e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,7 +17,7 @@ repos: name: pdm-format entry: pdm format language: system - types: [python] + types: [ python ] pass_filenames: false always_run: true # - repo: https://github.com/pre-commit/mirrors-mypy diff --git a/src/cala/nodes/cleanup.py b/src/cala/nodes/cleanup.py index 7319946d..fe52da99 100644 --- a/src/cala/nodes/cleanup.py +++ b/src/cala/nodes/cleanup.py @@ -169,19 +169,3 @@ def _filter_redundant( keep_ids.append(a[AXIS.id_coord].item()) return keep_ids - - -def merge_components( - footprints: Footprints, - traces: Traces, -) -> A[Footprints, Name("footprints")]: - """ - Merge existing components - - 1. dilate footprints (to consider adjacent components) - 2. find overlaps - 3. send to cataloger?? - 4. then send to all component ingestion - - nvm, let's just do it in cataloger. - """ diff --git a/src/cala/nodes/merge.py b/src/cala/nodes/merge.py index ff8adf48..737e0f25 100644 --- a/src/cala/nodes/merge.py +++ b/src/cala/nodes/merge.py @@ -8,7 +8,7 @@ from cala.assets import Footprints, Overlaps, Traces from cala.models import AXIS -from cala.nodes.detect.catalog import _recompose, _register +from cala.nodes.segment.catalog import _recompose, _register from cala.util import combine_attr_replaces diff --git a/src/cala/nodes/detect/__init__.py b/src/cala/nodes/segment/__init__.py similarity index 100% rename from src/cala/nodes/detect/__init__.py rename to src/cala/nodes/segment/__init__.py diff --git a/src/cala/nodes/detect/catalog.py b/src/cala/nodes/segment/catalog.py similarity index 100% rename from src/cala/nodes/detect/catalog.py rename to src/cala/nodes/segment/catalog.py diff --git a/src/cala/nodes/detect/slice_nmf.py b/src/cala/nodes/segment/slice_nmf.py similarity index 100% rename from src/cala/nodes/detect/slice_nmf.py rename to src/cala/nodes/segment/slice_nmf.py diff --git a/src/cala/nodes/detect/update.py b/src/cala/nodes/segment/update.py similarity index 100% rename from src/cala/nodes/detect/update.py rename to src/cala/nodes/segment/update.py diff --git a/tests/data/pipelines/long_recording.yaml b/tests/data/pipelines/long_recording.yaml index 4e3a3970..fd7ffe66 100644 --- a/tests/data/pipelines/long_recording.yaml +++ b/tests/data/pipelines/long_recording.yaml @@ -154,7 +154,7 @@ nodes: # DETECT BEGINS nmf: - type: cala.nodes.detect.SliceNMF + type: cala.nodes.segment.SliceNMF params: min_frames: 100 detect_thresh: 8.0 @@ -164,7 +164,7 @@ nodes: - energy: residual.std - detect_radius: size_est.radius catalog: - type: cala.nodes.detect.Cataloger + type: cala.nodes.segment.Cataloger params: age_limit: 100 shape_smooth_kwargs: @@ -181,7 +181,7 @@ nodes: - existing_fp: assets.footprints - existing_tr: assets.traces detect_update: - type: cala.nodes.detect.update_assets + type: cala.nodes.segment.update_assets depends: - new_footprints: catalog.new_footprints - new_traces: catalog.new_traces @@ -224,7 +224,7 @@ nodes: # - overlaps: assets.overlaps # - trigger: detect_update.footprints # merge_update: -# type: cala.nodes.detect.update_assets +# type: cala.nodes.segment.update_assets # depends: # - new_footprints: merge.footprints # - new_traces: merge.traces diff --git a/tests/data/pipelines/minian.yaml b/tests/data/pipelines/minian.yaml index 32eb17b4..48d1a779 100644 --- a/tests/data/pipelines/minian.yaml +++ b/tests/data/pipelines/minian.yaml @@ -157,7 +157,7 @@ nodes: # DETECT BEGINS nmf: - type: cala.nodes.detect.SliceNMF + type: cala.nodes.segment.SliceNMF params: min_frames: 100 detect_thresh: 3.0 @@ -167,7 +167,7 @@ nodes: - energy: residual.std - detect_radius: size_est.radius catalog: - type: cala.nodes.detect.Cataloger + type: cala.nodes.segment.Cataloger params: age_limit: 100 shape_smooth_kwargs: @@ -184,7 +184,7 @@ nodes: - existing_fp: assets.footprints - existing_tr: assets.traces detect_update: - type: cala.nodes.detect.update_assets + type: cala.nodes.segment.update_assets depends: - new_footprints: catalog.new_footprints - new_traces: catalog.new_traces @@ -209,7 +209,7 @@ nodes: # - overlaps: assets.overlaps # - trigger: detect_update.footprints # merge_update: -# type: cala.nodes.detect.update_assets +# type: cala.nodes.segment.update_assets # depends: # - new_footprints: merge.footprints # - new_traces: merge.traces diff --git a/tests/data/pipelines/odl.yaml b/tests/data/pipelines/odl.yaml index a07a1cd3..79308e3b 100644 --- a/tests/data/pipelines/odl.yaml +++ b/tests/data/pipelines/odl.yaml @@ -111,7 +111,7 @@ nodes: # DETECT BEGINS nmf: - type: cala.nodes.detect.SliceNMF + type: cala.nodes.segment.SliceNMF params: min_frames: 10 detect_thresh: 1.0 @@ -121,7 +121,7 @@ nodes: - energy: residual.std - detect_radius: size_est.radius catalog: - type: cala.nodes.detect.Cataloger + type: cala.nodes.segment.Cataloger params: age_limit: 100 shape_smooth_kwargs: @@ -138,7 +138,7 @@ nodes: - existing_fp: assets.footprints - existing_tr: assets.traces detect_update: - type: cala.nodes.detect.update_assets + type: cala.nodes.segment.update_assets depends: - new_footprints: catalog.new_footprints - new_traces: catalog.new_traces @@ -163,7 +163,7 @@ nodes: # - overlaps: assets.overlaps # - trigger: detect_update.footprints # merge_update: - # type: cala.nodes.detect.update_assets + # type: cala.nodes.segment.update_assets # depends: # - new_footprints: merge.footprints # - new_traces: merge.traces diff --git a/tests/test_iter/test_catalog.py b/tests/test_iter/test_catalog.py index b8d045b1..9143b9bb 100644 --- a/tests/test_iter/test_catalog.py +++ b/tests/test_iter/test_catalog.py @@ -4,8 +4,8 @@ from noob.node import NodeSpecification from cala.assets import AXIS, Buffer, Footprints, Traces -from cala.nodes.detect import Cataloger, SliceNMF -from cala.nodes.detect.catalog import _register +from cala.nodes.segment import Cataloger, SliceNMF +from cala.nodes.segment.catalog import _register from cala.testing.catalog_depr import CatalogerDepr from cala.testing.util import expand_boundary, split_footprint @@ -15,7 +15,7 @@ def slice_nmf(): return SliceNMF.from_specification( spec=NodeSpecification( id="test_slice_nmf", - type="cala.nodes.detect.SliceNMF", + type="cala.nodes.segment.SliceNMF", params={"min_frames": 10, "detect_thresh": 1, "reprod_tol": 0.001}, ) ) @@ -26,7 +26,7 @@ def cataloger(): return Cataloger.from_specification( spec=NodeSpecification( id="test", - type="cala.nodes.detect.Cataloger", + type="cala.nodes.segment.Cataloger", params={ "age_limit": 100, "trace_smooth_kwargs": {"sigma": 2}, diff --git a/tests/test_iter/test_slice_nmf.py b/tests/test_iter/test_slice_nmf.py index 5f3f773b..10a38b55 100644 --- a/tests/test_iter/test_slice_nmf.py +++ b/tests/test_iter/test_slice_nmf.py @@ -5,8 +5,8 @@ from sklearn.decomposition import NMF from cala.assets import AXIS, Buffer -from cala.nodes.detect import SliceNMF -from cala.nodes.detect.slice_nmf import rank1nmf +from cala.nodes.segment import SliceNMF +from cala.nodes.segment.slice_nmf import rank1nmf from cala.testing.util import assert_scalar_multiple_arrays @@ -15,7 +15,7 @@ def slice_nmf(): return SliceNMF.from_specification( spec=NodeSpecification( id="test_slice_nmf", - type="cala.nodes.detect.SliceNMF", + type="cala.nodes.segment.SliceNMF", params={"min_frames": 10, "detect_thresh": 1, "reprod_tol": 0.001}, ) ) @@ -30,7 +30,7 @@ def test_process(slice_nmf, single_cell): if new_component: new_fp, new_tr = new_component else: - raise AssertionError("Failed to detect a new component") + raise AssertionError("Failed to segment a new component") for new, old in zip([new_fp[0], new_tr[0]], [single_cell.footprints, single_cell.traces]): assert_scalar_multiple_arrays(new.array.as_numpy(), old.array.as_numpy()) @@ -40,7 +40,7 @@ def test_chunks(single_cell): nmf = SliceNMF.from_specification( spec=NodeSpecification( id="test_slice_nmf", - type="cala.nodes.detect.SliceNMF", + type="cala.nodes.segment.SliceNMF", params={"min_frames": 10, "detect_thresh": 1, "reprod_tol": 0.001}, ) ) @@ -50,7 +50,7 @@ def test_chunks(single_cell): detect_radius=10, ) if not fpts or not trcs: - raise AssertionError("Failed to detect a new component") + raise AssertionError("Failed to segment a new component") factors = [trc.array.data.max() for trc in trcs] fpt_arr = xr.concat([f.array * m for f, m in zip(fpts, factors)], dim="component")