From b067e21bbe49dcafb9bfe07f54f8941d238ff146 Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 10 Nov 2025 13:57:17 -0800 Subject: [PATCH 01/33] format: frontend css --- src/cala/gui/templates/index.html | 1 + src/cala/main.py | 8 +++++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/cala/gui/templates/index.html b/src/cala/gui/templates/index.html index 35d9531c..18ed7b2d 100644 --- a/src/cala/gui/templates/index.html +++ b/src/cala/gui/templates/index.html @@ -5,6 +5,7 @@ + diff --git a/src/cala/main.py b/src/cala/main.py index cd7004b7..16251372 100644 --- a/src/cala/main.py +++ b/src/cala/main.py @@ -1,3 +1,4 @@ +from datetime import datetime from pathlib import Path from typing import Annotated @@ -23,6 +24,11 @@ def main(spec: str, gui: Annotated[bool, typer.Option()] = False) -> None: if gui: uvicorn.run("cala.main:app", reload=False, reload_dirs=[Path(__file__).parent]) else: + start = datetime.now() tube = Tube.from_specification(spec) runner = SynchronousRunner(tube=tube) - runner.run() + try: + runner.run() + finally: + end = datetime.now() + print(f"Finished in {end - start}") From 4d3c32043c5e568723222aa00dfd53277fee2800 Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 10 Nov 2025 16:09:31 -0800 Subject: [PATCH 02/33] feat: add downsample --- src/cala/nodes/prep/__init__.py | 2 + src/cala/nodes/prep/downsample.py | 30 +++ tests/data/pipelines/long_recording.yaml | 222 +++++++++++++++++++++++ 3 files changed, 254 insertions(+) create mode 100644 src/cala/nodes/prep/downsample.py create mode 100644 tests/data/pipelines/long_recording.yaml diff --git a/src/cala/nodes/prep/__init__.py b/src/cala/nodes/prep/__init__.py index f0753b87..178b95c1 100644 --- a/src/cala/nodes/prep/__init__.py +++ b/src/cala/nodes/prep/__init__.py @@ -1,5 +1,6 @@ from .background_removal import remove_background from .denoise import Restore, blur +from .downsample import downsample from .flatten import butter from .glow_removal import GlowRemover from .lines import remove_freq, remove_mean @@ -19,4 +20,5 @@ "Restore", "package_frame", "counter", + "downsample", ] diff --git a/src/cala/nodes/prep/downsample.py b/src/cala/nodes/prep/downsample.py new file mode 100644 index 00000000..1b37f553 --- /dev/null +++ b/src/cala/nodes/prep/downsample.py @@ -0,0 +1,30 @@ +from typing import Annotated as A + +import numpy as np +from noob import Name + +from cala.assets import Frame +from cala.models import AXIS +from cala.nodes.prep import package_frame + + +def downsample( + frames: list[Frame], x_range: tuple[int, int], y_range: tuple[int, int], t_downsample: int = 1 +) -> A[Frame, Name("frame")]: + """ + Downsampling in time and cropping in space. Must be followed by gather node, and + t_downsample has to be same as gather's parameter n value. + + :param frames: + :param x_range: + :param y_range: + :param t_downsample: + :return: + """ + arrays = [] + for frame in frames: + arrays.append(frame.array[x_range[0] : x_range[1], y_range[0] : y_range[1]]) + + return package_frame( + np.mean(arrays, axis=0), arrays[-1][AXIS.frame_coord].item() // t_downsample + ) diff --git a/tests/data/pipelines/long_recording.yaml b/tests/data/pipelines/long_recording.yaml new file mode 100644 index 00000000..53c51c8f --- /dev/null +++ b/tests/data/pipelines/long_recording.yaml @@ -0,0 +1,222 @@ +noob_id: long-recording +noob_model: noob.tube.TubeSpecification +noob_version: 0.1.1.dev118+g64d81b7 + +assets: + buffer: + type: cala.assets.Buffer + params: + size: 100 + scope: runner + footprints: + type: cala.assets.Footprints + params: + sparsify: True + scope: runner + traces: + type: cala.assets.Traces + scope: runner + params: + # zarr_path: traces/ + peek_size: 100 + pix_stats: + type: cala.assets.PixStats + scope: runner + comp_stats: + type: cala.assets.CompStats + scope: runner + overlaps: + type: cala.assets.Overlaps + scope: runner + residuals: + type: cala.assets.Buffer + params: + size: 100 + scope: runner + + +nodes: + source: + type: cala.nodes.io.stream + params: + files: + - long_recording/1.avi + - long_recording/2.avi + - long_recording/3.avi + - long_recording/4.avi + - long_recording/5.avi + - long_recording/6.avi + counter: + type: cala.nodes.prep.counter + frame: + type: cala.nodes.prep.package_frame + depends: + - frame: source.value + - index: counter.idx + + #PREPROCESS BEGINS + hotpix: + type: cala.nodes.prep.blur + params: + method: median + kwargs: + ksize: 3 + depends: + - frame: frame.value + flatten: + type: cala.nodes.prep.butter + params: + kwargs: + cutoff_frequency_ratio: 0.010 + depends: + - frame: hotpix.frame + lines: + type: cala.nodes.prep.remove_mean + params: + orient: both + depends: + - frame: flatten.frame + motion: + type: cala.nodes.prep.Anchor + depends: + - frame: lines.frame + glow: + type: cala.nodes.prep.GlowRemover + depends: + - frame: motion.frame + gather: + type: gather + params: + n: 4 + depends: + - glow.frame + downsample: # 20fps -> 5fps + type: cala.nodes.prep.downsample + params: + x_range: [ 250, 350 ] + y_range: [ 250, 350 ] + t_downsample: 4 + depends: + - frames: gather.value + size_est: + type: cala.nodes.prep.SizeEst + params: + hardset_radius: 8 + depends: + - frame: downsample.frame + cache: + type: cala.nodes.buffer.fill_buffer + depends: + - buffer: assets.buffer + - frame: downsample.frame + #PREPROCESS ENDS + + # FRAME UPDATE BEGINS + trace_frame: + type: cala.nodes.traces.Tracer + params: + tol: 0.001 + max_iter: 100 + depends: + - traces: assets.traces + - footprints: assets.footprints + - frame: downsample.frame + - overlaps: assets.overlaps + pix_frame: + type: cala.nodes.pixel_stats.ingest_frame + depends: + - pixel_stats: assets.pix_stats + - frame: downsample.frame + - new_traces: trace_frame.latest_trace + - footprints: assets.footprints + comp_frame: + type: cala.nodes.component_stats.ingest_frame + depends: + - component_stats: assets.comp_stats + - frame: downsample.frame + - new_traces: trace_frame.latest_trace + footprints_frame: + type: cala.nodes.footprints.Footprinter + params: + bep: 0 + tol: 0.0001 + max_iter: 5 + ratio_lb: 0.10 + depends: + - footprints: assets.footprints + - pixel_stats: pix_frame.value + - component_stats: comp_frame.value + + residual: + type: cala.nodes.residual.Residuer + depends: + - frame: downsample.frame + - footprints: footprints_frame.footprints + - traces: assets.traces + - residuals: assets.residuals + # FRAME UPDATE ENDS + + # DETECT BEGINS + nmf: + type: cala.nodes.detect.SliceNMF + params: + min_frames: 100 + detect_thresh: 10.0 + reprod_tol: 0.005 + depends: + - residuals: residual.movie + - energy: residual.std + - detect_radius: size_est.radius + catalog: + type: cala.nodes.detect.Cataloger + params: + age_limit: 100 + shape_smooth_kwargs: + ksize: [ 5, 5 ] + sigmaX: 0 + trace_smooth_kwargs: + sigma: 2 + merge_threshold: 0.9 + val_threshold: 0.5 + cnt_threshold: 5 + depends: + - new_fps: nmf.new_fps + - new_trs: nmf.new_trs + - existing_fp: assets.footprints + - existing_tr: assets.traces + detect_update: + type: cala.nodes.detect.update_assets + depends: + - new_footprints: catalog.new_footprints + - new_traces: catalog.new_traces + - footprints: assets.footprints + - traces: assets.traces + - pixel_stats: assets.pix_stats + - component_stats: assets.comp_stats + - overlaps: assets.overlaps + - buffer: assets.buffer + # DETECT ENDS + +# merge: +# type: cala.nodes.merge.merge_existing +# params: +# merge_interval: 500 +# merge_threshold: 0.95 +# smooth_kwargs: +# sigma: 2 +# depends: +# - shapes: assets.footprints +# - traces: assets.traces +# - overlaps: assets.overlaps +# - trigger: detect_update.footprints +# merge_update: +# type: cala.nodes.detect.update_assets +# depends: +# - new_footprints: merge.footprints +# - new_traces: merge.traces +# - footprints: assets.footprints +# - traces: assets.traces +# - pixel_stats: assets.pix_stats +# - component_stats: assets.comp_stats +# - overlaps: assets.overlaps +# - buffer: assets.buffer From 9c35757a8a9e21f991253466f199a43c9944d229 Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 10 Nov 2025 16:10:33 -0800 Subject: [PATCH 03/33] tests: benchmarking moved out of tests --- tests/test_prep/test_motion.py | 34 ---------------------------------- 1 file changed, 34 deletions(-) diff --git a/tests/test_prep/test_motion.py b/tests/test_prep/test_motion.py index 6d47a7ea..bff48895 100644 --- a/tests/test_prep/test_motion.py +++ b/tests/test_prep/test_motion.py @@ -64,37 +64,3 @@ def test_motion_estimation(params) -> None: # Allow 1 pixel absolute tolerance np.testing.assert_allclose(estimate, expected[1:], atol=1.0) - - -# def test_motion_with_movie(): -# """ -# For testing how well the motion correction performs with real movie -# """ -# gen = stream( -# [ -# "cala/msCam1.avi", -# "cala/msCam2.avi", -# "cala/msCam3.avi", -# "cala/msCam4.avi", -# "cala/msCam5.avi", -# ] -# ) -# -# fourcc = cv2.VideoWriter_fourcc(*"mp4v") -# out = cv2.VideoWriter("mc_test.avi", fourcc, 24.0, (752, 960)) -# -# stab = Anchor() -# -# for idx, arr in enumerate(gen): -# frame = package_frame(arr, idx) -# frame = blur(frame, method="median", kwargs={"ksize": 3}) -# frame = butter(frame, {}) -# frame = remove_mean(frame, orient="both") -# matched = stab.stabilize(frame) -# -# combined = np.concat([frame.array.values, matched.array.values], axis=0) -# -# frame_bgr = cv2.cvtColor(combined.astype(np.uint8), cv2.COLOR_GRAY2BGR) -# out.write(frame_bgr) -# -# out.release() From c988dc7f831fcf230827f50b68fd1147adc56cc1 Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 10 Nov 2025 16:47:02 -0800 Subject: [PATCH 04/33] debug: retain older traces than peek_size --- src/cala/nodes/prep/__init__.py | 4 ++-- src/cala/nodes/traces.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/cala/nodes/prep/__init__.py b/src/cala/nodes/prep/__init__.py index 178b95c1..00f5376f 100644 --- a/src/cala/nodes/prep/__init__.py +++ b/src/cala/nodes/prep/__init__.py @@ -1,12 +1,12 @@ from .background_removal import remove_background from .denoise import Restore, blur -from .downsample import downsample +from .downsample import downsample # noqa from .flatten import butter from .glow_removal import GlowRemover from .lines import remove_freq, remove_mean from .motion import Anchor from .r_estimate import SizeEst -from .wrap import counter, package_frame +from .wrap import counter, package_frame # isort:skip __all__ = [ "blur", diff --git a/src/cala/nodes/traces.py b/src/cala/nodes/traces.py index 520886ce..ac8ed617 100644 --- a/src/cala/nodes/traces.py +++ b/src/cala/nodes/traces.py @@ -81,7 +81,7 @@ def ingest_frame( ) traces.update(updated_tr, append_dim=AXIS.frames_dim) else: - traces.array = xr.concat([traces.array, updated_traces], dim=AXIS.frames_dim) + traces.array = xr.concat([traces.full_array(), updated_traces], dim=AXIS.frames_dim) return PopSnap.from_array(updated_traces) From 6bdd3d332bf5248eb573e36bc2ec4790d184542e Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 10 Nov 2025 19:17:10 -0800 Subject: [PATCH 05/33] feat: asset save / natsorting videos --- pdm.lock | 12 +++++- pyproject.toml | 1 + src/cala/assets.py | 24 ++++++------ src/cala/nodes/io.py | 50 ++++++++++++++++++++++-- tests/data/pipelines/long_recording.yaml | 14 +++++++ 5 files changed, 84 insertions(+), 17 deletions(-) diff --git a/pdm.lock b/pdm.lock index a3d86080..66aeef00 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev", "docs", "gui", "tests"] strategy = [] lock_version = "4.5.0" -content_hash = "sha256:68f2cd4f887a77af253a18c01590866b2fd460982b397836133b8fa09d808e17" +content_hash = "sha256:b791ad4b9d500b62299dfc14d6b10c6b75e16de030f5c33a1234c6b6c7372f5d" [[metadata.targets]] requires_python = ">=3.11" @@ -1715,6 +1715,16 @@ files = [ {file = "myst_parser-4.0.1.tar.gz", hash = "sha256:5cfea715e4f3574138aecbf7d54132296bfd72bb614d31168f48c477a830a7c4"}, ] +[[package]] +name = "natsort" +version = "8.4.0" +requires_python = ">=3.7" +summary = "Simple yet flexible natural sorting in Python." +files = [ + {file = "natsort-8.4.0-py3-none-any.whl", hash = "sha256:4732914fb471f56b5cce04d7bae6f164a592c7712e1c85f9ef585e197299521c"}, + {file = "natsort-8.4.0.tar.gz", hash = "sha256:45312c4a0e5507593da193dedd04abb1469253b601ecaf63445ad80f0a1ea581"}, +] + [[package]] name = "nbclient" version = "0.10.2" diff --git a/pyproject.toml b/pyproject.toml index dbdf603d..95c38602 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ "typer>=0.15.3", "xarray-validate>=0.0.2", "noob @ git+https://github.com/miniscope/noob.git", + "natsort>=8.4.0", ] keywords = [ "pipeline", diff --git a/src/cala/assets.py b/src/cala/assets.py index 74092b21..2a8168c5 100644 --- a/src/cala/assets.py +++ b/src/cala/assets.py @@ -22,6 +22,8 @@ class Asset(BaseModel): validate_schema: bool = False array_: AssetType = None sparsify: ClassVar[bool] = False + zarr_path: Path | None = None + """relative to config.user_data_dir""" _entity: ClassVar[Entity] model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True) @@ -61,6 +63,16 @@ def validate_array_schema(self) -> Self: return self + @field_validator("zarr_path", mode="after") + @classmethod + def validate_zarr_path(cls, value: Path | None) -> Path | None: + if value is None: + return value + zarr_dir = (config.user_dir / value).resolve() + zarr_dir.mkdir(parents=True, exist_ok=True) + clear_dir(zarr_dir) + return zarr_dir + class Footprint(Asset): _entity: ClassVar[Entity] = PrivateAttr( @@ -109,8 +121,6 @@ class Footprints(Asset): class Traces(Asset): - zarr_path: Path | None = None - """relative to config.user_data_dir""" peek_size: int | None = None """ Traces(array=array, path=path) -> saves to zarr (should be set in this asset, and leave @@ -182,16 +192,6 @@ def from_array( new_cls.array = array return new_cls - @field_validator("zarr_path", mode="after") - @classmethod - def validate_zarr_path(cls, value: Path | None) -> Path | None: - if value is None: - return value - zarr_dir = (config.user_dir / value).resolve() - zarr_dir.mkdir(parents=True, exist_ok=True) - clear_dir(zarr_dir) - return zarr_dir - @model_validator(mode="after") def check_zarr_setting(self) -> "Traces": if self.zarr_path: diff --git a/src/cala/nodes/io.py b/src/cala/nodes/io.py index a658e3a4..f55f999a 100644 --- a/src/cala/nodes/io.py +++ b/src/cala/nodes/io.py @@ -1,13 +1,17 @@ from abc import abstractmethod from collections.abc import Generator +from glob import glob from pathlib import Path -from typing import Protocol +from typing import Protocol, Literal import cv2 +from natsort import natsorted from numpy.typing import NDArray from skimage import io +from cala.assets import Asset from cala.config import config +from cala.util import clear_dir class Stream(Protocol): @@ -80,18 +84,30 @@ def stream_videos(video_paths: list[Path]) -> Generator[NDArray]: current_stream.close() -def stream(files: list[str | Path]) -> Generator[NDArray, None, None]: +def stream( + files: list[str | Path] = None, + subdir: str | Path = None, + extension: Literal[".avi"] = None, + prefix: str | None = None, +) -> Generator[NDArray, None, None]: """ Create a video stream from the provided video files. Args: files: List of file paths + subdir: Directory path. Ignored if files is populated. + extension: File extension. Ignored if files is populated. + prefix: File prefix. Ignored if files is populated. Returns: Stream: A stream that yields video frames """ - file_paths = [Path(f) if isinstance(f, str) else f for f in files] - suffix = {path.suffix.lower() for path in file_paths} + if files: + file_paths = [Path(f) if isinstance(f, str) else f for f in files] + suffix = {path.suffix.lower() for path in file_paths} + else: + files = natsort_paths(subdir, extension, prefix) + suffix = {extension} image_format = {".tif", ".tiff"} video_format = {".mp4", ".avi", ".webm"} @@ -102,3 +118,29 @@ def stream(files: list[str | Path]) -> Generator[NDArray, None, None]: yield from stream_images(files) else: raise ValueError(f"Unsupported file format: {suffix}") + + +def save_asset( + asset: Asset, target_epoch: int, curr_epoch: int, path: str | Path | None = None +) -> None: + if target_epoch == curr_epoch: + if path: + zarr_dir = (config.user_dir / path).resolve() + else: + zarr_dir = config.user_dir / asset.zarr_path + + zarr_dir.mkdir(parents=True, exist_ok=True) + clear_dir(zarr_dir) + try: + asset.full_array().to_zarr(zarr_dir) # for Traces + except AttributeError: + asset.array.to_zarr(zarr_dir, mode="w") + return None + + +def natsort_paths( + subdir: str | Path, extension: Literal[".avi"], prefix: str | None = None +) -> list[str]: + prefix = prefix or "" + video_dir = config.video_dir / subdir + return natsorted(glob(f"{str(video_dir)}/{prefix}*{extension}")) diff --git a/tests/data/pipelines/long_recording.yaml b/tests/data/pipelines/long_recording.yaml index 53c51c8f..1a1b47ff 100644 --- a/tests/data/pipelines/long_recording.yaml +++ b/tests/data/pipelines/long_recording.yaml @@ -197,6 +197,20 @@ nodes: - buffer: assets.buffer # DETECT ENDS + persist: + type: cala.nodes.io.save_asset + params: + target_epoch: 800 #5999 formula: 1000 * n_chunks - 1 + path: traces/ + depends: + - asset: assets.traces + - curr_epoch: counter.idx + + + return: + type: return + depends: downsample.frame + # merge: # type: cala.nodes.merge.merge_existing # params: From 76a60bc62d1fdcabbc31efb7285810ec417c8523 Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 10 Nov 2025 19:19:13 -0800 Subject: [PATCH 06/33] feat: long_recording.yaml --- tests/data/pipelines/long_recording.yaml | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/tests/data/pipelines/long_recording.yaml b/tests/data/pipelines/long_recording.yaml index 1a1b47ff..e9813798 100644 --- a/tests/data/pipelines/long_recording.yaml +++ b/tests/data/pipelines/long_recording.yaml @@ -39,13 +39,8 @@ nodes: source: type: cala.nodes.io.stream params: - files: - - long_recording/1.avi - - long_recording/2.avi - - long_recording/3.avi - - long_recording/4.avi - - long_recording/5.avi - - long_recording/6.avi + subdir: long_recording + extension: .avi counter: type: cala.nodes.prep.counter frame: @@ -161,7 +156,7 @@ nodes: type: cala.nodes.detect.SliceNMF params: min_frames: 100 - detect_thresh: 10.0 + detect_thresh: 8.0 reprod_tol: 0.005 depends: - residuals: residual.movie From 7d648e451a101583572c460796db6caed3a83e7b Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 10 Nov 2025 19:20:41 -0800 Subject: [PATCH 07/33] init_order --- src/cala/nodes/prep/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cala/nodes/prep/__init__.py b/src/cala/nodes/prep/__init__.py index 00f5376f..053af4cb 100644 --- a/src/cala/nodes/prep/__init__.py +++ b/src/cala/nodes/prep/__init__.py @@ -1,12 +1,12 @@ +from .wrap import counter, package_frame # isort:skip from .background_removal import remove_background from .denoise import Restore, blur -from .downsample import downsample # noqa from .flatten import butter from .glow_removal import GlowRemover from .lines import remove_freq, remove_mean from .motion import Anchor from .r_estimate import SizeEst -from .wrap import counter, package_frame # isort:skip +from .downsample import downsample # noqa __all__ = [ "blur", From 730b934f89a22a974de42efa5d526296df35f3aa Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 10 Nov 2025 19:47:16 -0800 Subject: [PATCH 08/33] format: ruff --- src/cala/nodes/io.py | 8 ++------ src/cala/nodes/prep/__init__.py | 4 ++-- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/cala/nodes/io.py b/src/cala/nodes/io.py index f55f999a..30e08993 100644 --- a/src/cala/nodes/io.py +++ b/src/cala/nodes/io.py @@ -2,7 +2,7 @@ from collections.abc import Generator from glob import glob from pathlib import Path -from typing import Protocol, Literal +from typing import Literal, Protocol import cv2 from natsort import natsorted @@ -124,11 +124,7 @@ def save_asset( asset: Asset, target_epoch: int, curr_epoch: int, path: str | Path | None = None ) -> None: if target_epoch == curr_epoch: - if path: - zarr_dir = (config.user_dir / path).resolve() - else: - zarr_dir = config.user_dir / asset.zarr_path - + zarr_dir = (config.user_dir / path).resolve() if path else config.user_dir / asset.zarr_path zarr_dir.mkdir(parents=True, exist_ok=True) clear_dir(zarr_dir) try: diff --git a/src/cala/nodes/prep/__init__.py b/src/cala/nodes/prep/__init__.py index 053af4cb..fe1b4188 100644 --- a/src/cala/nodes/prep/__init__.py +++ b/src/cala/nodes/prep/__init__.py @@ -1,12 +1,12 @@ -from .wrap import counter, package_frame # isort:skip from .background_removal import remove_background from .denoise import Restore, blur +from .downsample import downsample from .flatten import butter from .glow_removal import GlowRemover from .lines import remove_freq, remove_mean from .motion import Anchor from .r_estimate import SizeEst -from .downsample import downsample # noqa +from .wrap import counter, package_frame # noqa: I001 __all__ = [ "blur", From 22e883849c0c979f17f4fe55695d7d02bc38f284 Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 10 Nov 2025 22:37:58 -0800 Subject: [PATCH 09/33] debug: handle sparse in zarr --- src/cala/nodes/io.py | 2 +- tests/data/pipelines/long_recording.yaml | 12 ++++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/cala/nodes/io.py b/src/cala/nodes/io.py index 30e08993..09a60f8e 100644 --- a/src/cala/nodes/io.py +++ b/src/cala/nodes/io.py @@ -130,7 +130,7 @@ def save_asset( try: asset.full_array().to_zarr(zarr_dir) # for Traces except AttributeError: - asset.array.to_zarr(zarr_dir, mode="w") + asset.array.as_numpy().to_zarr(zarr_dir, mode="w") return None diff --git a/tests/data/pipelines/long_recording.yaml b/tests/data/pipelines/long_recording.yaml index e9813798..318cc810 100644 --- a/tests/data/pipelines/long_recording.yaml +++ b/tests/data/pipelines/long_recording.yaml @@ -192,15 +192,23 @@ nodes: - buffer: assets.buffer # DETECT ENDS - persist: + save_trace: type: cala.nodes.io.save_asset params: - target_epoch: 800 #5999 formula: 1000 * n_chunks - 1 + target_epoch: 999 #5999 formula: 1000 * n_chunks - 1 path: traces/ depends: - asset: assets.traces - curr_epoch: counter.idx + save_shape: + type: cala.nodes.io.save_asset + params: + target_epoch: 999 #5999 formula: 1000 * n_chunks - 1 + path: footprints/ + depends: + - asset: assets.footprints + - curr_epoch: counter.idx return: type: return From 46caa9aea0fa425f82937f81ac95665a6c2fdc07 Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 10 Nov 2025 22:49:57 -0800 Subject: [PATCH 10/33] stop fucking switching the import order --- src/cala/nodes/prep/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cala/nodes/prep/__init__.py b/src/cala/nodes/prep/__init__.py index fe1b4188..ae7dafff 100644 --- a/src/cala/nodes/prep/__init__.py +++ b/src/cala/nodes/prep/__init__.py @@ -1,3 +1,4 @@ +from .wrap import counter, package_frame # noqa: I001 from .background_removal import remove_background from .denoise import Restore, blur from .downsample import downsample @@ -6,7 +7,6 @@ from .lines import remove_freq, remove_mean from .motion import Anchor from .r_estimate import SizeEst -from .wrap import counter, package_frame # noqa: I001 __all__ = [ "blur", From c351f76603d401787dadc54355dc3cf4e5ffd8f1 Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 11 Nov 2025 03:27:59 -0800 Subject: [PATCH 11/33] feat: trace interval flushing to zarr functionality --- src/cala/assets.py | 180 +++++++++++++++++++++++++++---------------- tests/test_assets.py | 163 ++++++++++++++++++++++++++++----------- 2 files changed, 234 insertions(+), 109 deletions(-) diff --git a/src/cala/assets.py b/src/cala/assets.py index 2a8168c5..ac0dcc40 100644 --- a/src/cala/assets.py +++ b/src/cala/assets.py @@ -2,7 +2,7 @@ import shutil from copy import deepcopy from pathlib import Path -from typing import Any, ClassVar, Self, TypeVar +from typing import ClassVar, Self, TypeVar import numpy as np import xarray as xr @@ -48,6 +48,12 @@ def from_array(cls, array: xr.DataArray) -> Self: def reset(self) -> None: self.array_ = None + if self.zarr_path: + path = Path(self.zarr_path) + try: + shutil.rmtree(path) + except FileNotFoundError: + contextlib.suppress(FileNotFoundError) def __eq__(self, other: "Asset") -> bool: return self.array.equals(other.array) @@ -73,6 +79,22 @@ def validate_zarr_path(cls, value: Path | None) -> Path | None: clear_dir(zarr_dir) return zarr_dir + def load_zarr(self, isel_filter: dict = None, sel_filter: dict = None) -> xr.DataArray: + da = ( + xr.open_zarr(self.zarr_path) + .isel(isel_filter) + .sel(sel_filter) + .to_dataarray() + .drop_vars(["variable"]) + .isel(variable=0) + ) + return da.assign_coords( + { + AXIS.id_coord: lambda ds: da[AXIS.id_coord].astype(str), + AXIS.timestamp_coord: lambda ds: da[AXIS.timestamp_coord].astype(str), + } + ) + class Footprint(Asset): _entity: ClassVar[Entity] = PrivateAttr( @@ -121,92 +143,118 @@ class Footprints(Asset): class Traces(Asset): - peek_size: int | None = None + peek_size: int """ - Traces(array=array, path=path) -> saves to zarr (should be set in this asset, and leave - untouched in nodes.) - Traces.array -> loads from zarr + How many epochs to return when called. + """ + flush_interval: int | None = None + _entity: ClassVar[Entity] = PrivateAttr( + Group( + name="trace-group", + member=Trace.entity(), + group_by=Dims.component, + checks=[is_non_negative], + allow_extra_coords=False, + ) + ) + + @model_validator(mode="after") + def flush_conditions(self) -> Self: + assert (self.flush_interval and self.zarr_path) or ( + not self.flush_interval and not self.zarr_path + ), "zarr_path and flush_interval should either be both provided or neither." + if self.flush_interval: + assert self.flush_interval > self.peek_size, ( + f"flush_interval must be larger than peek_size. " + f"Provided: {self.flush_interval = }, {self.peek_size = }" + ) + return self + + @property + def sizes(self) -> dict[str, int]: + if self.zarr_path: + total_size = {} + for key, val in self.array_.sizes.items(): + if key == AXIS.frames_dim: + total_size[key] = val + self.load_zarr().sizes[key] + else: + total_size[key] = val + return total_size + else: + return self.array_.sizes @property def array(self) -> xr.DataArray: - peek_filter = {AXIS.frames_dim: slice(-self.peek_size, None)} if self.peek_size else None - return self.full_array(isel_filter=peek_filter) + return self.array_.isel({AXIS.frames_dim: slice(-self.peek_size, None)}) @array.setter def array(self, array: xr.DataArray) -> None: + """ + In case zarr_path is defined, if array is larger than peek_size, + the epochs older than -peek_size gets flushed to zarr array. + + """ + if self.validate_schema: + array.validate.against_schema(self._entity.model) if self.zarr_path: - if self.validate_schema: - array.validate.against_schema(self._entity.model) - array.to_zarr(self.zarr_path, mode="w") # need to make sure it can overwrite + self.array_ = array.isel({AXIS.frames_dim: slice(-self.peek_size, None)}) + array.isel({AXIS.frames_dim: slice(None, -self.peek_size)}).to_zarr( + self.zarr_path, mode="w" + ) else: self.array_ = array - def reset(self) -> None: - self.array_ = None - if self.zarr_path: - path = Path(self.zarr_path) - try: - shutil.rmtree(path) - except FileNotFoundError: - contextlib.suppress(FileNotFoundError) + def zarr_append(self, array: xr.DataArray, dim: str) -> None: + """ + Since we cannot simply append to zarr array in memory using xarray syntax, + we provide a convenience method for appending to zarr array and in-memory array + in a streamlined manner. - def full_array(self, isel_filter: dict = None, sel_filter: dict = None) -> xr.DataArray: - if self.zarr_path: - try: - return self.load_zarr(isel_filter=isel_filter, sel_filter=sel_filter).compute() - except FileNotFoundError: - pass - return ( - self.array_.isel(isel_filter).sel(sel_filter) - if self.array_ is not None - else self.array_ - ) + Incoming arrays have to be 2-dimensional. - def load_zarr(self, isel_filter: dict = None, sel_filter: dict = None) -> xr.DataArray: - da = ( - xr.open_zarr(self.zarr_path) - .isel(isel_filter) - .sel(sel_filter) - .to_dataarray() - .drop_vars(["variable"]) - .isel(variable=0) - ) - return da.assign_coords( - { - AXIS.id_coord: lambda ds: da[AXIS.id_coord].astype(str), - AXIS.timestamp_coord: lambda ds: da[AXIS.timestamp_coord].astype(str), - } - ) + """ - def update(self, array: xr.DataArray, **kwargs: Any) -> None: - if self.validate_schema: - array.validate.against_schema(self._entity.model) - array.to_zarr(self.zarr_path, **kwargs) + if dim == AXIS.frames_dim: + self.array_ = xr.concat([self.array_, array], dim=AXIS.frames_dim) + if self.array_.sizes[AXIS.frames_dim] > self.flush_interval: + self._flush_zarr() + elif dim == AXIS.component_dim: + self.array_ = xr.concat( + [self.array_, array.isel({AXIS.frames_dim: slice(-self.peek_size, None)})], dim=dim + ) + array.isel({AXIS.frames_dim: slice(None, -self.peek_size)}).to_zarr( + self.zarr_path, append_dim=dim + ) + + def _flush_zarr(self) -> None: + """ + Flushes traces older than peek_size to zarr array. + + """ + self.array_.isel({AXIS.frames_dim: slice(None, -self.peek_size)}).to_zarr( + self.zarr_path, append_dim=AXIS.frames_dim + ) + self.array_ = self.array_.isel({AXIS.frames_dim: slice(-self.peek_size, None)}) @classmethod - def from_array( - cls, array: xr.DataArray, zarr_path: Path | str | None = None, peek_size: int | None = None - ) -> "Traces": - new_cls = cls(zarr_path=zarr_path, peek_size=peek_size) + def from_array(cls, array: xr.DataArray) -> "Traces": + """ + This is only really used for typing / auto-validation purposes, + so we don't really have to worry about specifying the parameters. + + """ + new_cls = cls(peek_size=array.sizes[AXIS.frames_dim]) new_cls.array = array return new_cls - @model_validator(mode="after") - def check_zarr_setting(self) -> "Traces": + def full_array(self, isel_filter: dict = None, sel_filter: dict = None) -> xr.DataArray: if self.zarr_path: - assert self.peek_size, "peek_size must be set for zarr." - return self - - _entity: ClassVar[Entity] = PrivateAttr( - Group( - name="trace-group", - member=Trace.entity(), - group_by=Dims.component, - checks=[is_non_negative], - allow_extra_coords=False, - ) - ) + return xr.concat( + [self.load_zarr(isel_filter, sel_filter), self.array_], dim=AXIS.frames_dim + ).compute() + else: + return self.array_.isel(isel_filter).sel(sel_filter) class Movie(Asset): diff --git a/tests/test_assets.py b/tests/test_assets.py index 2897eb77..9f5ea726 100644 --- a/tests/test_assets.py +++ b/tests/test_assets.py @@ -1,6 +1,11 @@ -import os +""" +Since assets without zarr operations is very straightforward, +this test file is mainly focused on testing zarr-integrated assets, +with the exception of Buffer. + +""" + from datetime import datetime -from pathlib import Path import pytest import xarray as xr @@ -9,71 +14,143 @@ from cala.models import AXIS -@pytest.fixture -def path() -> Path: - return Path("assets") +@pytest.mark.parametrize("peek_size", [30, 49, 50, 51, 70]) +def test_array_assignment(tmp_path, four_connected_cells, peek_size): + """ + 1. array should get assigned to memory if smaller than or equal to peek_size + 2. array should get split to memory and zarr_path at peek_size if + array is larger than peek_size - -def test_assign_zarr(path, four_connected_cells): - zarr_traces = Traces(zarr_path=path, peek_size=100) + """ traces = four_connected_cells.traces.array + n_frames = traces.sizes[AXIS.frames_dim] # 50 frames + + zarr_traces = Traces( + zarr_path=tmp_path, peek_size=peek_size, flush_interval=max(1000, peek_size) + ) zarr_traces.array = traces - print(os.listdir(zarr_traces.zarr_path)) - assert zarr_traces.array_ is None # not in memory - assert zarr_traces.array.equals(traces) + assert zarr_traces.array_.sizes[AXIS.frames_dim] == min(n_frames, peek_size) + assert zarr_traces.load_zarr().sizes[AXIS.frames_dim] == max(0, n_frames - peek_size) -def test_from_array(four_connected_cells, path): +@pytest.mark.parametrize("peek_size", [30, 50, 70]) +def test_array_peek(tmp_path, four_connected_cells, peek_size): + """ + .array property returns correctly for peek_size smaller, equal, larger + than the saved array. + + """ traces = four_connected_cells.traces.array - zarr_traces = Traces.from_array(traces, path, peek_size=four_connected_cells.n_frames) - assert zarr_traces.array_ is None - assert zarr_traces.array.equals(traces) + n_frames = traces.sizes[AXIS.frames_dim] # 50 frames + + zarr_traces = Traces( + zarr_path=tmp_path, peek_size=peek_size, flush_interval=max(1000, peek_size) + ) + zarr_traces.array = traces + + assert zarr_traces.array.sizes[AXIS.frames_dim] == min(peek_size, n_frames) -@pytest.mark.parametrize("peek_shift", [-1, 0, 1]) -def test_peek(four_connected_cells, path, peek_shift): +@pytest.mark.parametrize("peek_size", [30, 50, 70]) +def test_flush_zarr(four_connected_cells, tmp_path, peek_size): + """ + _flush_zarr method flushes epochs older than peek_size to zarr_path, + and leaves only the newest epochs in memory. + + """ traces = four_connected_cells.traces.array - zarr_traces = Traces.from_array( - traces, path, peek_size=four_connected_cells.n_frames + peek_shift + n_frames = traces.sizes[AXIS.frames_dim] # 50 frames + + zarr_traces = Traces( + zarr_path=tmp_path, peek_size=peek_size, flush_interval=max(1000, peek_size) ) - if peek_shift >= 0: - assert zarr_traces.array.equals(traces) - else: - with pytest.raises(AssertionError): - assert zarr_traces.array.equals(traces) + + # need to initialize zarr file first to append with _flush_zarr + zarr_traces.array = traces[:, :0] + zarr_traces.array_ = traces # add 50 frames + + zarr_traces._flush_zarr() + # only peek_size left in memory + assert zarr_traces.array_.sizes[AXIS.frames_dim] == min(n_frames, peek_size) + # the rest is in zarr + assert zarr_traces.load_zarr().sizes[AXIS.frames_dim] == max(0, n_frames - peek_size) -def test_ingest_frame(path, four_connected_cells): +@pytest.mark.parametrize("peek_size, flush_interval", [(30, 70)]) +def test_zarr_append_frame(four_connected_cells, tmp_path, peek_size, flush_interval): + """ + Test that when in-memory array size hits flush_interval, old epochs + get flushed. + + """ traces = four_connected_cells.traces.array - old_traces = traces.isel({AXIS.frames_dim: slice(None, -1)}) - zarr_traces = Traces.from_array(old_traces, path, peek_size=four_connected_cells.n_frames) - new_traces = four_connected_cells.traces.array.isel({AXIS.frames_dim: [-1]}) + n_frames = traces.sizes[AXIS.frames_dim] # 50 frames - zarr_traces.update(new_traces, append_dim=AXIS.frames_dim) - # new_traces.to_zarr(zarr_traces.zarr_path, append_dim=AXIS.frames_dim) + zarr_traces = Traces(zarr_path=tmp_path, peek_size=peek_size, flush_interval=flush_interval) + zarr_traces.array = traces[:, :0] # just initializing zarr + + # array smaller than flush_interval. does not flush. + zarr_traces.zarr_append(traces, dim=AXIS.frames_dim) + assert zarr_traces.array_.sizes[AXIS.frames_dim] == n_frames + + # array larger than flush_interval. flushes down to peek_size. + zarr_traces.zarr_append(traces, dim=AXIS.frames_dim) + assert zarr_traces.array_.sizes[AXIS.frames_dim] == peek_size - assert zarr_traces.array.equals(traces) +@pytest.mark.parametrize("flush_interval", [30]) +def test_zarr_append_component(four_connected_cells, tmp_path, flush_interval): + """ + Test that when adding components, + (a) it gets appropriately divided between in-memory and zarr arrays. + (b) in-memory and zarr arrays can be concatenated together afterward. -def test_ingest_component(four_connected_cells, path): + """ traces = four_connected_cells.traces.array - old_traces = traces.isel({AXIS.component_dim: slice(None, -1)}) - zarr_traces = Traces.from_array(old_traces, path, peek_size=four_connected_cells.n_frames) - new_traces = four_connected_cells.traces.array.isel({AXIS.component_dim: [-1]}) - zarr_traces.update(new_traces, append_dim=AXIS.component_dim) - # new_traces.to_zarr(zarr_traces.zarr_path, append_dim=AXIS.component_dim) + zarr_traces = Traces(zarr_path=tmp_path, peek_size=20, flush_interval=flush_interval) + zarr_traces.array = traces[:-1, :] # forgot the last component! Also, got flushed. + zarr_traces.zarr_append(traces[-1:, :], dim=AXIS.component_dim) # appendee needs to be 2D + assert zarr_traces.array_[AXIS.component_dim].equals(traces[AXIS.component_dim]) + assert zarr_traces.load_zarr()[AXIS.component_dim].equals(traces[AXIS.component_dim]) + result = xr.concat([zarr_traces.load_zarr(), zarr_traces.array_], dim=AXIS.frames_dim).compute() + assert result.equals(traces) + + +def test_from_array(four_connected_cells): + """ + .from_array method can correctly reproduce the array with .array + + """ + traces = four_connected_cells.traces.array + zarr_traces = Traces.from_array(traces) assert zarr_traces.array.equals(traces) -def test_overwrite(four_connected_cells, four_separate_cells, path): - conn_traces = four_connected_cells.traces.array - zarr_traces = Traces.from_array(conn_traces, path, peek_size=four_connected_cells.n_frames) +@pytest.mark.parametrize("peek_size", [30, 50, 70]) +def test_sizes(four_connected_cells, tmp_path, peek_size): + """ + The sizes property of the asset combines the sizes + of the in-memory array and the zarr array. + + """ + traces = four_connected_cells.traces.array + + zarr_traces = Traces( + zarr_path=tmp_path, peek_size=peek_size, flush_interval=max(1000, peek_size) + ) + zarr_traces.array = traces + + assert zarr_traces.sizes == traces.sizes + + +@pytest.mark.xfail +def test_overwrite(four_connected_cells, four_separate_cells): + """ + test that zarr array can get overwritten. - sep_traces = four_separate_cells.traces.array - zarr_traces.array = sep_traces - assert zarr_traces.array.equals(sep_traces) + """ # two cases of init: From 1a148f54d6ce69e6d0dc409f59cb107087d6f52d Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 11 Nov 2025 05:45:02 -0800 Subject: [PATCH 12/33] feat: trace mostly supporting new grammar (except zarr component update) --- src/cala/assets.py | 48 ++++++++++++++++++------ src/cala/nodes/traces.py | 67 ++++++++++++++++++---------------- tests/test_assets.py | 6 +-- tests/test_iter/test_traces.py | 45 ++++++++++++++++++++--- 4 files changed, 113 insertions(+), 53 deletions(-) diff --git a/src/cala/assets.py b/src/cala/assets.py index ac0dcc40..e8346e27 100644 --- a/src/cala/assets.py +++ b/src/cala/assets.py @@ -144,11 +144,22 @@ class Footprints(Asset): class Traces(Asset): peek_size: int + """How many epochs to return when called.""" + flush_interval: int | None = None + """How many epochs to wait until next flush""" + + _deprecated: list = PrivateAttr(default_factory=list) """ - How many epochs to return when called. - + Deprecated, or replaced component idx. + Since zarr does not support efficiently removing rows and columns, + there's no easy way to remove a column when a component has been + removed or replaced. Instead, we "mask" it with this "deprecated" + flag. + + When arrays are called, these are filtered out. When new epochs are + added, these are added in with nan values. """ - flush_interval: int | None = None + _entity: ClassVar[Entity] = PrivateAttr( Group( name="trace-group", @@ -205,7 +216,7 @@ def array(self, array: xr.DataArray) -> None: else: self.array_ = array - def zarr_append(self, array: xr.DataArray, dim: str) -> None: + def append(self, array: xr.DataArray, dim: str) -> None: """ Since we cannot simply append to zarr array in memory using xarray syntax, we provide a convenience method for appending to zarr array and in-memory array @@ -217,15 +228,22 @@ def zarr_append(self, array: xr.DataArray, dim: str) -> None: if dim == AXIS.frames_dim: self.array_ = xr.concat([self.array_, array], dim=AXIS.frames_dim) - if self.array_.sizes[AXIS.frames_dim] > self.flush_interval: - self._flush_zarr() + + if self.zarr_path: + if self.array_.sizes[AXIS.frames_dim] > self.flush_interval: + self._flush_zarr() + elif dim == AXIS.component_dim: - self.array_ = xr.concat( - [self.array_, array.isel({AXIS.frames_dim: slice(-self.peek_size, None)})], dim=dim - ) - array.isel({AXIS.frames_dim: slice(None, -self.peek_size)}).to_zarr( - self.zarr_path, append_dim=dim - ) + if self.zarr_path: + self.array_ = xr.concat( + [self.array_, array.isel({AXIS.frames_dim: slice(-self.peek_size, None)})], + dim=dim, + ) + array.isel({AXIS.frames_dim: slice(None, -self.peek_size)}).to_zarr( + self.zarr_path, append_dim=dim + ) + else: + self.array_ = xr.concat([self.array_, array], dim=dim, combine_attrs="drop") def _flush_zarr(self) -> None: """ @@ -237,6 +255,12 @@ def _flush_zarr(self) -> None: ) self.array_ = self.array_.isel({AXIS.frames_dim: slice(-self.peek_size, None)}) + def deprecate(self, del_mask: np.ndarray) -> None: + if self.zarr_path: + ... + else: + self.array_ = self.array_[~del_mask] + @classmethod def from_array(cls, array: xr.DataArray) -> "Traces": """ diff --git a/src/cala/nodes/traces.py b/src/cala/nodes/traces.py index ac8ed617..c115ab6f 100644 --- a/src/cala/nodes/traces.py +++ b/src/cala/nodes/traces.py @@ -79,9 +79,9 @@ def ingest_frame( updated_tr = updated_traces.volumize.dim_with_coords( dim=AXIS.frames_dim, coords=[AXIS.frame_coord, AXIS.timestamp_coord] ) - traces.update(updated_tr, append_dim=AXIS.frames_dim) + traces.append(updated_tr, dim=AXIS.frames_dim) else: - traces.array = xr.concat([traces.full_array(), updated_traces], dim=AXIS.frames_dim) + traces.append(updated_traces, dim=AXIS.frames_dim) return PopSnap.from_array(updated_traces) @@ -204,44 +204,47 @@ def ingest_component(traces: Traces, new_traces: Traces) -> Traces: :param new_traces: Can be either a newly registered trace or an updated existing one. """ - c = traces.full_array() - c_det = new_traces.array + c_new = new_traces.array - if c_det is None: + if c_new is None: return traces - if c is None: - traces.array = c_det + if traces.array is None: + traces.array = c_new return traces - if c.sizes[AXIS.frames_dim] > c_det.sizes[AXIS.frames_dim]: - # if newly detected cells are truncated, pad with np.nans - c_new = xr.DataArray( - np.full((c_det.sizes[AXIS.component_dim], c.sizes[AXIS.frames_dim]), np.nan), - dims=[AXIS.component_dim, AXIS.frames_dim], - coords=c.isel({AXIS.component_dim: 0}).coords, - ) - c_new[AXIS.id_coord] = c_det[AXIS.id_coord] - c_new[AXIS.detect_coord] = c_det[AXIS.detect_coord] - - c_new.loc[ - {AXIS.frames_dim: slice(c.sizes[AXIS.frames_dim] - c_det.sizes[AXIS.frames_dim], None)} - ] = c_det - else: - c_new = c_det + total_frames = traces.sizes[AXIS.frames_dim] + new_n_frames = c_new.sizes[AXIS.frames_dim] - merged_ids = c_det.attrs.get("replaces") + merged_ids = c_new.attrs.get("replaces") if merged_ids: - if traces.zarr_path: - invalid = c[AXIS.id_coord].isin(merged_ids) - traces.array = c.where(~invalid.compute(), drop=True).compute() - else: - intact_mask = ~np.isin(c[AXIS.id_coord].values, merged_ids) - c = c[intact_mask] + del_mask = np.isin(traces.array[AXIS.id_coord].values, merged_ids) + traces.deprecate(del_mask) - if traces.zarr_path: - traces.update(c_new, append_dim=AXIS.component_dim) + if total_frames > new_n_frames: + # if newly detected cells are truncated, pad with np.nans + c_pad = _pad_history(c_new, total_frames, np.nan) else: - traces.array = xr.concat([c, c_new], dim=AXIS.component_dim, combine_attrs="drop") + c_pad = c_new + + traces.append(c_pad, dim=AXIS.component_dim) return traces + + +def _pad_history(traces: xr.DataArray, total_nframes: int, value: float = np.nan) -> xr.DataArray: + """ + Pad unknown historical epochs with values... + + """ + new_nframes = traces.sizes[AXIS.frames_dim] + + c_new = xr.DataArray( + np.full((traces.sizes[AXIS.component_dim], total_nframes), value), + dims=[AXIS.component_dim, AXIS.frames_dim], + coords=traces[AXIS.component_dim].coords, + ) + + c_new.loc[{AXIS.frames_dim: slice(total_nframes - new_nframes, None)}] = traces + + return c_new diff --git a/tests/test_assets.py b/tests/test_assets.py index 9f5ea726..04a8f97a 100644 --- a/tests/test_assets.py +++ b/tests/test_assets.py @@ -90,11 +90,11 @@ def test_zarr_append_frame(four_connected_cells, tmp_path, peek_size, flush_inte zarr_traces.array = traces[:, :0] # just initializing zarr # array smaller than flush_interval. does not flush. - zarr_traces.zarr_append(traces, dim=AXIS.frames_dim) + zarr_traces.append(traces, dim=AXIS.frames_dim) assert zarr_traces.array_.sizes[AXIS.frames_dim] == n_frames # array larger than flush_interval. flushes down to peek_size. - zarr_traces.zarr_append(traces, dim=AXIS.frames_dim) + zarr_traces.append(traces, dim=AXIS.frames_dim) assert zarr_traces.array_.sizes[AXIS.frames_dim] == peek_size @@ -110,7 +110,7 @@ def test_zarr_append_component(four_connected_cells, tmp_path, flush_interval): zarr_traces = Traces(zarr_path=tmp_path, peek_size=20, flush_interval=flush_interval) zarr_traces.array = traces[:-1, :] # forgot the last component! Also, got flushed. - zarr_traces.zarr_append(traces[-1:, :], dim=AXIS.component_dim) # appendee needs to be 2D + zarr_traces.append(traces[-1:, :], dim=AXIS.component_dim) # appendee needs to be 2D assert zarr_traces.array_[AXIS.component_dim].equals(traces[AXIS.component_dim]) assert zarr_traces.load_zarr()[AXIS.component_dim].equals(traces[AXIS.component_dim]) diff --git a/tests/test_iter/test_traces.py b/tests/test_iter/test_traces.py index bd5057b6..9139f89e 100644 --- a/tests/test_iter/test_traces.py +++ b/tests/test_iter/test_traces.py @@ -18,15 +18,29 @@ def frame_update() -> Node: ) +@pytest.mark.parametrize( + "zarr_setup", + [ + {"zarr_path": None, "peek_size": 50, "flush_interval": None}, + {"zarr_path": "tmp", "peek_size": 30, "flush_interval": 40}, + ], +) @pytest.mark.parametrize("toy", ["four_separate_cells", "four_connected_cells"]) -def test_update_traces(frame_update, toy, request) -> None: +def test_ingest_frame(frame_update, toy, zarr_setup, request, tmp_path) -> None: + """ + Frame ingestion step adds new traces to the Traces instance + that matches true brightness, overlapping and non-overlapping alike. + + """ toy = request.getfixturevalue(toy) + zarr_setup["zarr_path"] = tmp_path if zarr_setup["zarr_path"] else None xray = Node.from_specification( spec=NodeSpecification(id="test", type="cala.nodes.overlap.initialize") ) - traces = Traces.from_array(toy.traces.array.isel({AXIS.frames_dim: slice(None, -1)})) + traces = Traces(array_=None, **zarr_setup) + traces.array = toy.traces.array.isel({AXIS.frames_dim: slice(None, -1)}) frame = Frame.from_array(toy.make_movie().array.isel({AXIS.frames_dim: -1})) overlap = xray.process(overlaps=Overlaps(), footprints=toy.footprints) @@ -46,14 +60,33 @@ def comp_update() -> Node: ) +@pytest.mark.parametrize( + "zarr_setup", + [ + {"zarr_path": None, "peek_size": 40, "flush_interval": None}, + {"zarr_path": "tmp", "peek_size": 30, "flush_interval": 40}, + ], +) @pytest.mark.parametrize("toy", ["four_separate_cells"]) -def test_ingest_component(comp_update, toy, request) -> None: +def test_ingest_component(comp_update, toy, request, zarr_setup, tmp_path) -> None: + """ + *New component is always the same length as peek_size. + + I can add components that... + 1. is single and new + 2. is single and replacing + 3. is multiple + 4. is multiple and replacing + + """ toy = request.getfixturevalue(toy) + zarr_setup["zarr_path"] = tmp_path if zarr_setup["zarr_path"] else None - traces = Traces.from_array(toy.traces.array.isel({AXIS.component_dim: slice(None, -1)})) + traces = Traces(array_=None, **zarr_setup) + traces.array = toy.traces.array.isel({AXIS.component_dim: slice(None, -1)}) new_traces = toy.traces.array.isel( - {AXIS.component_dim: [-1], AXIS.frames_dim: slice(-40, None)} + {AXIS.component_dim: [-1], AXIS.frames_dim: slice(-zarr_setup["peek_size"], None)} ) new_traces.attrs["replaces"] = ["cell_0"] @@ -63,4 +96,4 @@ def test_ingest_component(comp_update, toy, request) -> None: expected = toy.traces.array.drop_sel({AXIS.component_dim: 0}) expected.loc[{AXIS.component_dim: -1, AXIS.frames_dim: slice(None, 10)}] = np.nan - assert result.array.equals(expected) + assert result.array_.equals(expected) From 4d0e5f42058f07c6b6e4f339661c2041508b49a2 Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 11 Nov 2025 11:28:46 -0800 Subject: [PATCH 13/33] feat: implement zarr and in-memory caching in traces --- src/cala/assets.py | 51 +++++++++++++++++++++++++--------- src/cala/nodes/traces.py | 4 +-- tests/data/pipelines/odl.yaml | 4 +++ tests/test_assets.py | 18 ++++++++++++ tests/test_iter/test_traces.py | 11 ++++---- 5 files changed, 68 insertions(+), 20 deletions(-) diff --git a/src/cala/assets.py b/src/cala/assets.py index e8346e27..e423cf72 100644 --- a/src/cala/assets.py +++ b/src/cala/assets.py @@ -143,12 +143,12 @@ class Footprints(Asset): class Traces(Asset): - peek_size: int + peek_size: int = None """How many epochs to return when called.""" flush_interval: int | None = None """How many epochs to wait until next flush""" - _deprecated: list = PrivateAttr(default_factory=list) + _deprecated: list[str] = PrivateAttr(default_factory=list) """ Deprecated, or replaced component idx. Since zarr does not support efficiently removing rows and columns, @@ -197,7 +197,11 @@ def sizes(self) -> dict[str, int]: @property def array(self) -> xr.DataArray: - return self.array_.isel({AXIS.frames_dim: slice(-self.peek_size, None)}) + return ( + self.array_.isel({AXIS.frames_dim: slice(-self.peek_size, None)}) + if self.array_ is not None + else self.array_ + ) @array.setter def array(self, array: xr.DataArray) -> None: @@ -248,18 +252,37 @@ def append(self, array: xr.DataArray, dim: str) -> None: def _flush_zarr(self) -> None: """ Flushes traces older than peek_size to zarr array. + Needs to append nans to deprecated components, since they get deleted + in in-memory array, but persist in zarr array. + + Could do this much more elegantly by pre-allocating .array_ """ - self.array_.isel({AXIS.frames_dim: slice(None, -self.peek_size)}).to_zarr( - self.zarr_path, append_dim=AXIS.frames_dim - ) + raw_zarr = self.load_zarr() + to_flush = self.array_.isel({AXIS.frames_dim: slice(None, -self.peek_size)}) + if self._deprecated: + zarr_ids = raw_zarr[AXIS.id_coord].values + zarr_detects = raw_zarr[AXIS.detect_coord].values + intact_mask = ~np.isin(zarr_ids, self._deprecated) + n_flush = to_flush.sizes[AXIS.frames_dim] + prealloc = xr.DataArray( + np.full((raw_zarr.sizes[AXIS.component_dim], n_flush), np.nan), + dims=[AXIS.component_dim, AXIS.frames_dim], + coords={ + AXIS.id_coord: (AXIS.component_dim, zarr_ids), + AXIS.detect_coord: (AXIS.component_dim, zarr_detects), + }, + ).assign_coords(to_flush[AXIS.frames_dim].coords) + prealloc.loc[intact_mask] = to_flush + prealloc.to_zarr(self.zarr_path, append_dim=AXIS.frames_dim) + else: + to_flush.to_zarr(self.zarr_path, append_dim=AXIS.frames_dim) self.array_ = self.array_.isel({AXIS.frames_dim: slice(-self.peek_size, None)}) - def deprecate(self, del_mask: np.ndarray) -> None: + def keep(self, intact_mask: np.ndarray) -> None: if self.zarr_path: - ... - else: - self.array_ = self.array_[~del_mask] + self._deprecated.extend(self.array_[AXIS.id_coord].values[~intact_mask]) + self.array_ = self.array_[intact_mask] @classmethod def from_array(cls, array: xr.DataArray) -> "Traces": @@ -274,9 +297,11 @@ def from_array(cls, array: xr.DataArray) -> "Traces": def full_array(self, isel_filter: dict = None, sel_filter: dict = None) -> xr.DataArray: if self.zarr_path: - return xr.concat( - [self.load_zarr(isel_filter, sel_filter), self.array_], dim=AXIS.frames_dim - ).compute() + raw_zarr = self.load_zarr(isel_filter, sel_filter) + zarr_ids = raw_zarr[AXIS.id_coord].values + intact_mask = ~np.isin(zarr_ids, self._deprecated) + + return xr.concat([raw_zarr[intact_mask], self.array_], dim=AXIS.frames_dim).compute() else: return self.array_.isel(isel_filter).sel(sel_filter) diff --git a/src/cala/nodes/traces.py b/src/cala/nodes/traces.py index c115ab6f..e34f3bf8 100644 --- a/src/cala/nodes/traces.py +++ b/src/cala/nodes/traces.py @@ -218,8 +218,8 @@ def ingest_component(traces: Traces, new_traces: Traces) -> Traces: merged_ids = c_new.attrs.get("replaces") if merged_ids: - del_mask = np.isin(traces.array[AXIS.id_coord].values, merged_ids) - traces.deprecate(del_mask) + intact_mask = ~np.isin(traces.array[AXIS.id_coord].values, merged_ids) + traces.keep(intact_mask) if total_frames > new_n_frames: # if newly detected cells are truncated, pad with np.nans diff --git a/tests/data/pipelines/odl.yaml b/tests/data/pipelines/odl.yaml index b980e57a..a07a1cd3 100644 --- a/tests/data/pipelines/odl.yaml +++ b/tests/data/pipelines/odl.yaml @@ -15,6 +15,10 @@ assets: scope: runner traces: type: cala.assets.Traces + params: + zarr_path: traces/ + peek_size: 100 + flush_interval: 200 scope: runner pix_stats: type: cala.assets.PixStats diff --git a/tests/test_assets.py b/tests/test_assets.py index 04a8f97a..7b6b9b4a 100644 --- a/tests/test_assets.py +++ b/tests/test_assets.py @@ -7,6 +7,7 @@ from datetime import datetime +import numpy as np import pytest import xarray as xr @@ -118,6 +119,23 @@ def test_zarr_append_component(four_connected_cells, tmp_path, flush_interval): assert result.equals(traces) +@pytest.mark.parametrize("flush_interval", [30]) +def test_flush_after_deprecated(four_connected_cells, tmp_path, flush_interval) -> None: + traces = four_connected_cells.traces.array + peek_size = 20 + zarr_traces = Traces(zarr_path=tmp_path, peek_size=peek_size, flush_interval=flush_interval) + zarr_traces.array = traces + + merged_ids = zarr_traces.array[AXIS.id_coord].values[0] + intact_mask = ~np.isin(zarr_traces.array[AXIS.id_coord].values, merged_ids) + zarr_traces.keep(intact_mask) + zarr_traces.append(traces[intact_mask], dim=AXIS.frames_dim) + + assert zarr_traces.full_array().equals( + xr.concat([traces] * 2, dim=AXIS.frames_dim)[intact_mask] + ) + + def test_from_array(four_connected_cells): """ .from_array method can correctly reproduce the array with .array diff --git a/tests/test_iter/test_traces.py b/tests/test_iter/test_traces.py index 9139f89e..a84c1aa5 100644 --- a/tests/test_iter/test_traces.py +++ b/tests/test_iter/test_traces.py @@ -83,17 +83,18 @@ def test_ingest_component(comp_update, toy, request, zarr_setup, tmp_path) -> No zarr_setup["zarr_path"] = tmp_path if zarr_setup["zarr_path"] else None traces = Traces(array_=None, **zarr_setup) - traces.array = toy.traces.array.isel({AXIS.component_dim: slice(None, -1)}) + traces.array = toy.traces.array_.isel({AXIS.component_dim: slice(None, -1)}) - new_traces = toy.traces.array.isel( + new_traces = toy.traces.array_.isel( {AXIS.component_dim: [-1], AXIS.frames_dim: slice(-zarr_setup["peek_size"], None)} ) new_traces.attrs["replaces"] = ["cell_0"] - result = comp_update.process(traces, Traces.from_array(new_traces)) expected = toy.traces.array.drop_sel({AXIS.component_dim: 0}) - expected.loc[{AXIS.component_dim: -1, AXIS.frames_dim: slice(None, 10)}] = np.nan + expected.loc[ + {AXIS.component_dim: -1, AXIS.frames_dim: slice(None, -zarr_setup["peek_size"])} + ] = np.nan - assert result.array_.equals(expected) + assert result.full_array().equals(expected) From 8edf5c4f8437291f8968b544ec69e84900f88eac Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 11 Nov 2025 11:36:30 -0800 Subject: [PATCH 14/33] format: ruff --- src/cala/assets.py | 6 +++--- src/cala/nodes/traces.py | 6 +----- tests/data/pipelines/long_recording.yaml | 3 ++- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/src/cala/assets.py b/src/cala/assets.py index e423cf72..89ea5d38 100644 --- a/src/cala/assets.py +++ b/src/cala/assets.py @@ -233,9 +233,9 @@ def append(self, array: xr.DataArray, dim: str) -> None: if dim == AXIS.frames_dim: self.array_ = xr.concat([self.array_, array], dim=AXIS.frames_dim) - if self.zarr_path: - if self.array_.sizes[AXIS.frames_dim] > self.flush_interval: - self._flush_zarr() + is_overflowing = self.array_.sizes[AXIS.frames_dim] > self.flush_interval + if self.zarr_path and is_overflowing: + self._flush_zarr() elif dim == AXIS.component_dim: if self.zarr_path: diff --git a/src/cala/nodes/traces.py b/src/cala/nodes/traces.py index e34f3bf8..f8bcafcc 100644 --- a/src/cala/nodes/traces.py +++ b/src/cala/nodes/traces.py @@ -221,11 +221,7 @@ def ingest_component(traces: Traces, new_traces: Traces) -> Traces: intact_mask = ~np.isin(traces.array[AXIS.id_coord].values, merged_ids) traces.keep(intact_mask) - if total_frames > new_n_frames: - # if newly detected cells are truncated, pad with np.nans - c_pad = _pad_history(c_new, total_frames, np.nan) - else: - c_pad = c_new + c_pad = _pad_history(c_new, total_frames, np.nan) if total_frames > new_n_frames else c_new traces.append(c_pad, dim=AXIS.component_dim) diff --git a/tests/data/pipelines/long_recording.yaml b/tests/data/pipelines/long_recording.yaml index 318cc810..5ee8152e 100644 --- a/tests/data/pipelines/long_recording.yaml +++ b/tests/data/pipelines/long_recording.yaml @@ -17,8 +17,9 @@ assets: type: cala.assets.Traces scope: runner params: - # zarr_path: traces/ + zarr_path: traces/ peek_size: 100 + flush_interval: 1000 pix_stats: type: cala.assets.PixStats scope: runner From 1555a2bd9e24c9e23c96aa33b371898a17b8097d Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 11 Nov 2025 11:43:37 -0800 Subject: [PATCH 15/33] debug: new component concat --- src/cala/assets.py | 5 ++-- .../pipelines/{with_src.yaml => minian.yaml} | 23 ++++++++++--------- 2 files changed, 15 insertions(+), 13 deletions(-) rename tests/data/pipelines/{with_src.yaml => minian.yaml} (93%) diff --git a/src/cala/assets.py b/src/cala/assets.py index 89ea5d38..62f0df83 100644 --- a/src/cala/assets.py +++ b/src/cala/assets.py @@ -239,11 +239,12 @@ def append(self, array: xr.DataArray, dim: str) -> None: elif dim == AXIS.component_dim: if self.zarr_path: + n_in_memory = self.array_.sizes[AXIS.frames_dim] self.array_ = xr.concat( - [self.array_, array.isel({AXIS.frames_dim: slice(-self.peek_size, None)})], + [self.array_, array.isel({AXIS.frames_dim: slice(-n_in_memory, None)})], dim=dim, ) - array.isel({AXIS.frames_dim: slice(None, -self.peek_size)}).to_zarr( + array.isel({AXIS.frames_dim: slice(None, -n_in_memory)}).to_zarr( self.zarr_path, append_dim=dim ) else: diff --git a/tests/data/pipelines/with_src.yaml b/tests/data/pipelines/minian.yaml similarity index 93% rename from tests/data/pipelines/with_src.yaml rename to tests/data/pipelines/minian.yaml index 3595697f..d66895d9 100644 --- a/tests/data/pipelines/with_src.yaml +++ b/tests/data/pipelines/minian.yaml @@ -1,4 +1,4 @@ -noob_id: cala-with-movie +noob_id: with-minian noob_model: noob.tube.TubeSpecification noob_version: 0.1.1.dev118+g64d81b7 @@ -17,8 +17,9 @@ assets: type: cala.assets.Traces scope: runner params: - # zarr_path: traces/ + zarr_path: traces/ peek_size: 100 + flush_interval: 1000 pix_stats: type: cala.assets.PixStats scope: runner @@ -41,15 +42,15 @@ nodes: params: files: - minian/msCam1.avi - - minian/msCam2.avi - - minian/msCam3.avi - - minian/msCam4.avi - - minian/msCam5.avi - - minian/msCam6.avi - - minian/msCam7.avi - - minian/msCam8.avi - - minian/msCam9.avi - - minian/msCam10.avi + # - minian/msCam2.avi + # - minian/msCam3.avi + # - minian/msCam4.avi + # - minian/msCam5.avi + # - minian/msCam6.avi + # - minian/msCam7.avi + # - minian/msCam8.avi + # - minian/msCam9.avi + # - minian/msCam10.avi counter: type: cala.nodes.prep.counter frame: From d7dffc9b96328e9e4cc610ff0d4a04e7dc237d0d Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 11 Nov 2025 11:46:51 -0800 Subject: [PATCH 16/33] debug: new component concat --- tests/data/pipelines/minian.yaml | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/data/pipelines/minian.yaml b/tests/data/pipelines/minian.yaml index d66895d9..32eb17b4 100644 --- a/tests/data/pipelines/minian.yaml +++ b/tests/data/pipelines/minian.yaml @@ -42,15 +42,15 @@ nodes: params: files: - minian/msCam1.avi - # - minian/msCam2.avi - # - minian/msCam3.avi - # - minian/msCam4.avi - # - minian/msCam5.avi - # - minian/msCam6.avi - # - minian/msCam7.avi - # - minian/msCam8.avi - # - minian/msCam9.avi - # - minian/msCam10.avi + - minian/msCam2.avi + - minian/msCam3.avi + - minian/msCam4.avi + - minian/msCam5.avi + - minian/msCam6.avi + - minian/msCam7.avi + - minian/msCam8.avi + - minian/msCam9.avi + - minian/msCam10.avi counter: type: cala.nodes.prep.counter frame: From 93653aa793567ec05213a523a30bb97f3239b642 Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 11 Nov 2025 14:52:33 -0800 Subject: [PATCH 17/33] debug: update noob for gather compatibility --- pdm.lock | 7 ++++--- pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/pdm.lock b/pdm.lock index 66aeef00..16127fbd 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev", "docs", "gui", "tests"] strategy = [] lock_version = "4.5.0" -content_hash = "sha256:b791ad4b9d500b62299dfc14d6b10c6b75e16de030f5c33a1234c6b6c7372f5d" +content_hash = "sha256:d02754f2363ec34c50db0fe7d4a94b017e21816863471a002e007ed0fd85e128" [[metadata.targets]] requires_python = ">=3.11" @@ -1821,10 +1821,11 @@ files = [ [[package]] name = "noob" -version = "0.1.1.dev208" +version = "0.1.1.dev209" requires_python = ">=3.11" git = "https://github.com/miniscope/noob.git" -revision = "6f9512e57e10e142335ccddf64a599d85edc73d6" +ref = "scheduler-optimize" +revision = "1cafe617372795c5bfe9c6953416b309572e1676" summary = "Default template for PDM package" dependencies = [ "PyYAML>=6.0.2", diff --git a/pyproject.toml b/pyproject.toml index 95c38602..ebd58be2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ dependencies = [ "pyyaml>=6.0.2", "typer>=0.15.3", "xarray-validate>=0.0.2", - "noob @ git+https://github.com/miniscope/noob.git", + "noob @ git+https://github.com/miniscope/noob.git@scheduler-optimize", "natsort>=8.4.0", ] keywords = [ From a369cd1392ae7ded52a56e3c3491fc282eec1dbd Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 11 Nov 2025 14:52:47 -0800 Subject: [PATCH 18/33] debug: asset zarr saving --- src/cala/nodes/io.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/src/cala/nodes/io.py b/src/cala/nodes/io.py index 09a60f8e..7b35cbd3 100644 --- a/src/cala/nodes/io.py +++ b/src/cala/nodes/io.py @@ -11,7 +11,6 @@ from cala.assets import Asset from cala.config import config -from cala.util import clear_dir class Stream(Protocol): @@ -120,18 +119,14 @@ def stream( raise ValueError(f"Unsupported file format: {suffix}") -def save_asset( - asset: Asset, target_epoch: int, curr_epoch: int, path: str | Path | None = None -) -> None: +def save_asset(asset: Asset, target_epoch: int, curr_epoch: int, path: str | Path) -> Asset: if target_epoch == curr_epoch: - zarr_dir = (config.user_dir / path).resolve() if path else config.user_dir / asset.zarr_path - zarr_dir.mkdir(parents=True, exist_ok=True) - clear_dir(zarr_dir) + zarr_dir = config.user_dir try: - asset.full_array().to_zarr(zarr_dir) # for Traces + asset.full_array().to_zarr(zarr_dir / path, mode="w") # for Traces except AttributeError: - asset.array.as_numpy().to_zarr(zarr_dir, mode="w") - return None + asset.array.as_numpy().to_zarr(zarr_dir / path, mode="w") + return asset def natsort_paths( From f9f89d19dbec249315664212976281bb41b50792 Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 11 Nov 2025 20:08:20 -0800 Subject: [PATCH 19/33] test: motion correction crisp score --- src/cala/testing/util.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/src/cala/testing/util.py b/src/cala/testing/util.py index 02f87b37..1d8b6d70 100644 --- a/src/cala/testing/util.py +++ b/src/cala/testing/util.py @@ -110,3 +110,26 @@ def expand_boundary(footprints: xr.DataArray) -> xr.DataArray: vectorize=True, dask="parallelized", ) + + +def total_gradient_magnitude(I: np.ndarray) -> float: + """ + Calculates the total gradient magnitude c(I) = || |nabla I| ||_F for a 2D array I. + + This function computes the Frobenius norm of the pixel-wise gradient magnitude map, + which is a common measure of total image variation or signal energy. + + Args: + I (np.ndarray): The 2D input array (e.g., an image). + + Returns: + float: The Frobenius norm of the gradient magnitude map. + """ + if I.ndim != 2: + raise ValueError("Input array I must be 2-dimensional.") + + grad_y, grad_x = np.gradient(I) + grad_magnitude_map = np.sqrt(grad_x**2 + grad_y**2) + c_I = np.linalg.norm(grad_magnitude_map, ord="fro") + + return c_I From bea740697cd69ba002b4f11c1f63f016470e2a4e Mon Sep 17 00:00:00 2001 From: Raymond Date: Wed, 12 Nov 2025 10:49:02 -0800 Subject: [PATCH 20/33] debug: allow non-zarr --- src/cala/assets.py | 3 +-- tests/data/pipelines/long_recording.yaml | 12 ++++-------- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/src/cala/assets.py b/src/cala/assets.py index 62f0df83..1e1f2663 100644 --- a/src/cala/assets.py +++ b/src/cala/assets.py @@ -233,8 +233,7 @@ def append(self, array: xr.DataArray, dim: str) -> None: if dim == AXIS.frames_dim: self.array_ = xr.concat([self.array_, array], dim=AXIS.frames_dim) - is_overflowing = self.array_.sizes[AXIS.frames_dim] > self.flush_interval - if self.zarr_path and is_overflowing: + if self.zarr_path and self.array_.sizes[AXIS.frames_dim] > self.flush_interval: self._flush_zarr() elif dim == AXIS.component_dim: diff --git a/tests/data/pipelines/long_recording.yaml b/tests/data/pipelines/long_recording.yaml index 5ee8152e..4e3a3970 100644 --- a/tests/data/pipelines/long_recording.yaml +++ b/tests/data/pipelines/long_recording.yaml @@ -196,8 +196,8 @@ nodes: save_trace: type: cala.nodes.io.save_asset params: - target_epoch: 999 #5999 formula: 1000 * n_chunks - 1 - path: traces/ + target_epoch: 855999 # formula: 1000 * n_chunks - 1 + path: traces_fin/ depends: - asset: assets.traces - curr_epoch: counter.idx @@ -205,16 +205,12 @@ nodes: save_shape: type: cala.nodes.io.save_asset params: - target_epoch: 999 #5999 formula: 1000 * n_chunks - 1 - path: footprints/ + target_epoch: 855999 # formula: 1000 * n_chunks - 1 + path: footprints_fin/ depends: - asset: assets.footprints - curr_epoch: counter.idx - return: - type: return - depends: downsample.frame - # merge: # type: cala.nodes.merge.merge_existing # params: From 59c708eb883340151d5cf7da4b9f98d10c35e185 Mon Sep 17 00:00:00 2001 From: Raymond Date: Wed, 12 Nov 2025 11:37:00 -0800 Subject: [PATCH 21/33] feat: continue for gui failure --- src/cala/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cala/main.py b/src/cala/main.py index 16251372..7038d2bd 100644 --- a/src/cala/main.py +++ b/src/cala/main.py @@ -15,7 +15,7 @@ try: app = get_app() -except TypeError as e: +except (TypeError, RuntimeError) as e: logger.warning(f"Failed to load gui app: {e}") From d903419ed0be3292d2913f870c9b4a112875c927 Mon Sep 17 00:00:00 2001 From: Raymond Date: Fri, 21 Nov 2025 11:56:20 -0800 Subject: [PATCH 22/33] format: ruff --- src/cala/testing/util.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/cala/testing/util.py b/src/cala/testing/util.py index 1d8b6d70..7e73069e 100644 --- a/src/cala/testing/util.py +++ b/src/cala/testing/util.py @@ -112,7 +112,7 @@ def expand_boundary(footprints: xr.DataArray) -> xr.DataArray: ) -def total_gradient_magnitude(I: np.ndarray) -> float: +def total_gradient_magnitude(image: np.ndarray) -> float: """ Calculates the total gradient magnitude c(I) = || |nabla I| ||_F for a 2D array I. @@ -120,15 +120,15 @@ def total_gradient_magnitude(I: np.ndarray) -> float: which is a common measure of total image variation or signal energy. Args: - I (np.ndarray): The 2D input array (e.g., an image). + image (np.ndarray): The 2D input array (e.g., an image). Returns: float: The Frobenius norm of the gradient magnitude map. """ - if I.ndim != 2: + if image.ndim != 2: raise ValueError("Input array I must be 2-dimensional.") - grad_y, grad_x = np.gradient(I) + grad_y, grad_x = np.gradient(image) grad_magnitude_map = np.sqrt(grad_x**2 + grad_y**2) c_I = np.linalg.norm(grad_magnitude_map, ord="fro") From dd306e4e4bcf725dd483804646576cfcfe14ac54 Mon Sep 17 00:00:00 2001 From: Raymond Date: Fri, 21 Nov 2025 13:01:28 -0800 Subject: [PATCH 23/33] rename: detect to segment --- .pre-commit-config.yaml | 2 +- src/cala/nodes/cleanup.py | 16 ---------------- src/cala/nodes/merge.py | 2 +- src/cala/nodes/{detect => segment}/__init__.py | 0 src/cala/nodes/{detect => segment}/catalog.py | 0 src/cala/nodes/{detect => segment}/slice_nmf.py | 0 src/cala/nodes/{detect => segment}/update.py | 0 tests/data/pipelines/long_recording.yaml | 8 ++++---- tests/data/pipelines/minian.yaml | 8 ++++---- tests/data/pipelines/odl.yaml | 8 ++++---- tests/test_iter/test_catalog.py | 8 ++++---- tests/test_iter/test_slice_nmf.py | 12 ++++++------ 12 files changed, 24 insertions(+), 40 deletions(-) rename src/cala/nodes/{detect => segment}/__init__.py (100%) rename src/cala/nodes/{detect => segment}/catalog.py (100%) rename src/cala/nodes/{detect => segment}/slice_nmf.py (100%) rename src/cala/nodes/{detect => segment}/update.py (100%) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1b1171cc..59ac6b0e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,7 +17,7 @@ repos: name: pdm-format entry: pdm format language: system - types: [python] + types: [ python ] pass_filenames: false always_run: true # - repo: https://github.com/pre-commit/mirrors-mypy diff --git a/src/cala/nodes/cleanup.py b/src/cala/nodes/cleanup.py index 7319946d..fe52da99 100644 --- a/src/cala/nodes/cleanup.py +++ b/src/cala/nodes/cleanup.py @@ -169,19 +169,3 @@ def _filter_redundant( keep_ids.append(a[AXIS.id_coord].item()) return keep_ids - - -def merge_components( - footprints: Footprints, - traces: Traces, -) -> A[Footprints, Name("footprints")]: - """ - Merge existing components - - 1. dilate footprints (to consider adjacent components) - 2. find overlaps - 3. send to cataloger?? - 4. then send to all component ingestion - - nvm, let's just do it in cataloger. - """ diff --git a/src/cala/nodes/merge.py b/src/cala/nodes/merge.py index ff8adf48..737e0f25 100644 --- a/src/cala/nodes/merge.py +++ b/src/cala/nodes/merge.py @@ -8,7 +8,7 @@ from cala.assets import Footprints, Overlaps, Traces from cala.models import AXIS -from cala.nodes.detect.catalog import _recompose, _register +from cala.nodes.segment.catalog import _recompose, _register from cala.util import combine_attr_replaces diff --git a/src/cala/nodes/detect/__init__.py b/src/cala/nodes/segment/__init__.py similarity index 100% rename from src/cala/nodes/detect/__init__.py rename to src/cala/nodes/segment/__init__.py diff --git a/src/cala/nodes/detect/catalog.py b/src/cala/nodes/segment/catalog.py similarity index 100% rename from src/cala/nodes/detect/catalog.py rename to src/cala/nodes/segment/catalog.py diff --git a/src/cala/nodes/detect/slice_nmf.py b/src/cala/nodes/segment/slice_nmf.py similarity index 100% rename from src/cala/nodes/detect/slice_nmf.py rename to src/cala/nodes/segment/slice_nmf.py diff --git a/src/cala/nodes/detect/update.py b/src/cala/nodes/segment/update.py similarity index 100% rename from src/cala/nodes/detect/update.py rename to src/cala/nodes/segment/update.py diff --git a/tests/data/pipelines/long_recording.yaml b/tests/data/pipelines/long_recording.yaml index 4e3a3970..fd7ffe66 100644 --- a/tests/data/pipelines/long_recording.yaml +++ b/tests/data/pipelines/long_recording.yaml @@ -154,7 +154,7 @@ nodes: # DETECT BEGINS nmf: - type: cala.nodes.detect.SliceNMF + type: cala.nodes.segment.SliceNMF params: min_frames: 100 detect_thresh: 8.0 @@ -164,7 +164,7 @@ nodes: - energy: residual.std - detect_radius: size_est.radius catalog: - type: cala.nodes.detect.Cataloger + type: cala.nodes.segment.Cataloger params: age_limit: 100 shape_smooth_kwargs: @@ -181,7 +181,7 @@ nodes: - existing_fp: assets.footprints - existing_tr: assets.traces detect_update: - type: cala.nodes.detect.update_assets + type: cala.nodes.segment.update_assets depends: - new_footprints: catalog.new_footprints - new_traces: catalog.new_traces @@ -224,7 +224,7 @@ nodes: # - overlaps: assets.overlaps # - trigger: detect_update.footprints # merge_update: -# type: cala.nodes.detect.update_assets +# type: cala.nodes.segment.update_assets # depends: # - new_footprints: merge.footprints # - new_traces: merge.traces diff --git a/tests/data/pipelines/minian.yaml b/tests/data/pipelines/minian.yaml index 32eb17b4..48d1a779 100644 --- a/tests/data/pipelines/minian.yaml +++ b/tests/data/pipelines/minian.yaml @@ -157,7 +157,7 @@ nodes: # DETECT BEGINS nmf: - type: cala.nodes.detect.SliceNMF + type: cala.nodes.segment.SliceNMF params: min_frames: 100 detect_thresh: 3.0 @@ -167,7 +167,7 @@ nodes: - energy: residual.std - detect_radius: size_est.radius catalog: - type: cala.nodes.detect.Cataloger + type: cala.nodes.segment.Cataloger params: age_limit: 100 shape_smooth_kwargs: @@ -184,7 +184,7 @@ nodes: - existing_fp: assets.footprints - existing_tr: assets.traces detect_update: - type: cala.nodes.detect.update_assets + type: cala.nodes.segment.update_assets depends: - new_footprints: catalog.new_footprints - new_traces: catalog.new_traces @@ -209,7 +209,7 @@ nodes: # - overlaps: assets.overlaps # - trigger: detect_update.footprints # merge_update: -# type: cala.nodes.detect.update_assets +# type: cala.nodes.segment.update_assets # depends: # - new_footprints: merge.footprints # - new_traces: merge.traces diff --git a/tests/data/pipelines/odl.yaml b/tests/data/pipelines/odl.yaml index a07a1cd3..79308e3b 100644 --- a/tests/data/pipelines/odl.yaml +++ b/tests/data/pipelines/odl.yaml @@ -111,7 +111,7 @@ nodes: # DETECT BEGINS nmf: - type: cala.nodes.detect.SliceNMF + type: cala.nodes.segment.SliceNMF params: min_frames: 10 detect_thresh: 1.0 @@ -121,7 +121,7 @@ nodes: - energy: residual.std - detect_radius: size_est.radius catalog: - type: cala.nodes.detect.Cataloger + type: cala.nodes.segment.Cataloger params: age_limit: 100 shape_smooth_kwargs: @@ -138,7 +138,7 @@ nodes: - existing_fp: assets.footprints - existing_tr: assets.traces detect_update: - type: cala.nodes.detect.update_assets + type: cala.nodes.segment.update_assets depends: - new_footprints: catalog.new_footprints - new_traces: catalog.new_traces @@ -163,7 +163,7 @@ nodes: # - overlaps: assets.overlaps # - trigger: detect_update.footprints # merge_update: - # type: cala.nodes.detect.update_assets + # type: cala.nodes.segment.update_assets # depends: # - new_footprints: merge.footprints # - new_traces: merge.traces diff --git a/tests/test_iter/test_catalog.py b/tests/test_iter/test_catalog.py index b8d045b1..9143b9bb 100644 --- a/tests/test_iter/test_catalog.py +++ b/tests/test_iter/test_catalog.py @@ -4,8 +4,8 @@ from noob.node import NodeSpecification from cala.assets import AXIS, Buffer, Footprints, Traces -from cala.nodes.detect import Cataloger, SliceNMF -from cala.nodes.detect.catalog import _register +from cala.nodes.segment import Cataloger, SliceNMF +from cala.nodes.segment.catalog import _register from cala.testing.catalog_depr import CatalogerDepr from cala.testing.util import expand_boundary, split_footprint @@ -15,7 +15,7 @@ def slice_nmf(): return SliceNMF.from_specification( spec=NodeSpecification( id="test_slice_nmf", - type="cala.nodes.detect.SliceNMF", + type="cala.nodes.segment.SliceNMF", params={"min_frames": 10, "detect_thresh": 1, "reprod_tol": 0.001}, ) ) @@ -26,7 +26,7 @@ def cataloger(): return Cataloger.from_specification( spec=NodeSpecification( id="test", - type="cala.nodes.detect.Cataloger", + type="cala.nodes.segment.Cataloger", params={ "age_limit": 100, "trace_smooth_kwargs": {"sigma": 2}, diff --git a/tests/test_iter/test_slice_nmf.py b/tests/test_iter/test_slice_nmf.py index 5f3f773b..10a38b55 100644 --- a/tests/test_iter/test_slice_nmf.py +++ b/tests/test_iter/test_slice_nmf.py @@ -5,8 +5,8 @@ from sklearn.decomposition import NMF from cala.assets import AXIS, Buffer -from cala.nodes.detect import SliceNMF -from cala.nodes.detect.slice_nmf import rank1nmf +from cala.nodes.segment import SliceNMF +from cala.nodes.segment.slice_nmf import rank1nmf from cala.testing.util import assert_scalar_multiple_arrays @@ -15,7 +15,7 @@ def slice_nmf(): return SliceNMF.from_specification( spec=NodeSpecification( id="test_slice_nmf", - type="cala.nodes.detect.SliceNMF", + type="cala.nodes.segment.SliceNMF", params={"min_frames": 10, "detect_thresh": 1, "reprod_tol": 0.001}, ) ) @@ -30,7 +30,7 @@ def test_process(slice_nmf, single_cell): if new_component: new_fp, new_tr = new_component else: - raise AssertionError("Failed to detect a new component") + raise AssertionError("Failed to segment a new component") for new, old in zip([new_fp[0], new_tr[0]], [single_cell.footprints, single_cell.traces]): assert_scalar_multiple_arrays(new.array.as_numpy(), old.array.as_numpy()) @@ -40,7 +40,7 @@ def test_chunks(single_cell): nmf = SliceNMF.from_specification( spec=NodeSpecification( id="test_slice_nmf", - type="cala.nodes.detect.SliceNMF", + type="cala.nodes.segment.SliceNMF", params={"min_frames": 10, "detect_thresh": 1, "reprod_tol": 0.001}, ) ) @@ -50,7 +50,7 @@ def test_chunks(single_cell): detect_radius=10, ) if not fpts or not trcs: - raise AssertionError("Failed to detect a new component") + raise AssertionError("Failed to segment a new component") factors = [trc.array.data.max() for trc in trcs] fpt_arr = xr.concat([f.array * m for f, m in zip(fpts, factors)], dim="component") From eb077b91da56878df2589b6e40c86423c212fbce Mon Sep 17 00:00:00 2001 From: Raymond Date: Fri, 21 Nov 2025 17:27:08 -0800 Subject: [PATCH 24/33] refactor: restructure filetree to group omf --- src/cala/nodes/buffer.py | 17 -- src/cala/nodes/cleanup.py | 171 ------------------ src/cala/nodes/omf/__init__.py | 0 src/cala/nodes/{ => omf}/component_stats.py | 0 src/cala/nodes/{ => omf}/footprints.py | 0 src/cala/nodes/{ => omf}/overlap.py | 0 src/cala/nodes/{ => omf}/pixel_stats.py | 16 +- src/cala/nodes/{ => omf}/residual.py | 0 src/cala/nodes/{ => omf}/traces.py | 0 src/cala/nodes/segment/__init__.py | 2 +- src/cala/nodes/segment/cleanup.py | 72 ++++++++ src/cala/nodes/{ => segment}/merge.py | 0 .../nodes/segment/{update.py => persist.py} | 10 +- tests/data/pipelines/long_recording.yaml | 12 +- tests/data/pipelines/minian.yaml | 12 +- tests/data/pipelines/odl.yaml | 12 +- tests/test_iter/test_cleanup.py | 33 +--- tests/test_iter/test_component_stats.py | 8 +- tests/test_iter/test_footprints.py | 8 +- tests/test_iter/test_overlaps.py | 4 +- tests/test_iter/test_pixel_stats.py | 8 +- tests/test_iter/test_residual.py | 4 +- tests/test_iter/test_traces.py | 6 +- 23 files changed, 133 insertions(+), 262 deletions(-) delete mode 100644 src/cala/nodes/buffer.py delete mode 100644 src/cala/nodes/cleanup.py create mode 100644 src/cala/nodes/omf/__init__.py rename src/cala/nodes/{ => omf}/component_stats.py (100%) rename src/cala/nodes/{ => omf}/footprints.py (100%) rename src/cala/nodes/{ => omf}/overlap.py (100%) rename src/cala/nodes/{ => omf}/pixel_stats.py (92%) rename src/cala/nodes/{ => omf}/residual.py (100%) rename src/cala/nodes/{ => omf}/traces.py (100%) create mode 100644 src/cala/nodes/segment/cleanup.py rename src/cala/nodes/{ => segment}/merge.py (100%) rename src/cala/nodes/segment/{update.py => persist.py} (79%) diff --git a/src/cala/nodes/buffer.py b/src/cala/nodes/buffer.py deleted file mode 100644 index 8bb9a9d5..00000000 --- a/src/cala/nodes/buffer.py +++ /dev/null @@ -1,17 +0,0 @@ -from typing import Annotated as A - -from noob import Name - -from cala.assets import Buffer, Frame -from cala.models import AXIS - - -def fill_buffer(buffer: Buffer, frame: Frame) -> A[Buffer, Name("buffer")]: - if buffer.array is None: - buffer.array = frame.array.volumize.dim_with_coords( - dim=AXIS.frames_dim, coords=[AXIS.timestamp_coord] - ) - return buffer - - buffer.append(frame.array) - return buffer diff --git a/src/cala/nodes/cleanup.py b/src/cala/nodes/cleanup.py deleted file mode 100644 index fe52da99..00000000 --- a/src/cala/nodes/cleanup.py +++ /dev/null @@ -1,171 +0,0 @@ -from typing import Annotated as A - -import cv2 -import numpy as np -import xarray as xr -from noob import Name - -from cala.assets import Buffer, CompStats, Footprints, Overlaps, PixStats, Traces -from cala.models import AXIS - - -def clear_overestimates( - footprints: Footprints, residuals: Buffer, nmf_error: float -) -> A[Footprints, Name("footprints")]: - """ - Remove all sections of the footprints that cause negative residuals. - - This occurs by: - 1. find "significant" negative residual spots that is more than a noise level, and thus - cannot be clipped to zero. !!!! (only of the latest frame, and then go back to trace update..?) - 2. all footprint values at these spots go to zero. - """ - if residuals.array is None: - return footprints - R_min = residuals.array.isel({AXIS.frames_dim: -1}).reset_coords( - [AXIS.frame_coord, AXIS.timestamp_coord], drop=True - ) - tuned_fp = footprints.array.where(R_min > -nmf_error, 0, drop=False) - - return tuned_fp - - -def purge_razed_components( - footprints: Footprints, - traces: Traces, - pix_stats: PixStats, - comp_stats: CompStats, - overlaps: Overlaps, - min_thicc: int, - trigger: bool, -) -> tuple[ - A[Footprints, Name("footprints")], - A[Traces, Name("traces")], - A[PixStats, Name("pix_stats")], - A[CompStats, Name("comp_stats")], - A[Overlaps, Name("overlaps")], -]: - keep_ids = _filter_razed_ids(footprints=footprints, min_thicc=min_thicc) - return _filter_components( - footprints=footprints, - traces=traces, - pix_stats=pix_stats, - comp_stats=comp_stats, - overlaps=overlaps, - keep_ids=keep_ids, - ) - - -def _filter_razed_ids(footprints: Footprints, min_thicc: int) -> A[list[str], Name("keep_ids")]: - """ - :param min_thicc: minimum number of pixel thickness to keep the cell - :return: - """ - A = footprints.array - - if A is None: - return [] - - kernel = np.ones((min_thicc, min_thicc), np.uint8) - - eroded = xr.apply_ufunc( - cv2.erode, - (A > 0).as_numpy().astype(np.uint8), - kwargs={"kernel": kernel}, - vectorize=True, - input_core_dims=[AXIS.spatial_dims], - output_core_dims=[AXIS.spatial_dims], - ) - - keep_idx = np.where(eroded.sum(dim=AXIS.spatial_dims).values.tolist())[0] - return A.isel({AXIS.component_dim: keep_idx})[AXIS.id_coord].values.tolist() - - -def _filter_components( - footprints: Footprints, - traces: Traces, - pix_stats: PixStats, - comp_stats: CompStats, - overlaps: Overlaps, - keep_ids: list[str], -) -> tuple[ - A[Footprints, Name("footprints")], - A[Traces, Name("traces")], - A[PixStats, Name("pix_stats")], - A[CompStats, Name("comp_stats")], - A[Overlaps, Name("overlaps")], -]: - if len(keep_ids) == 0 or footprints.array is None: - footprints.reset() - traces.reset() - pix_stats.reset() - comp_stats.reset() - overlaps.reset() - - elif footprints.array[AXIS.id_coord].values.tolist() != keep_ids: - footprints.array = ( - footprints.array.set_xindex(AXIS.id_coord) - .sel({AXIS.id_coord: keep_ids}) - .reset_index(AXIS.id_coord) - ) - traces.array = ( - traces.array.set_xindex(AXIS.id_coord) - .sel({AXIS.id_coord: keep_ids}) - .reset_index(AXIS.id_coord) - ) - pix_stats.array = ( - pix_stats.array.set_xindex(AXIS.id_coord) - .sel({AXIS.id_coord: keep_ids}) - .reset_index(AXIS.id_coord) - ) - comp_stats.array = ( - comp_stats.array.set_xindex(AXIS.id_coord) - .set_xindex(f"{AXIS.id_coord}'") - .sel({AXIS.id_coord: keep_ids, f"{AXIS.id_coord}'": keep_ids}) - .reset_index([AXIS.id_coord, f"{AXIS.id_coord}'"]) - ) - overlaps.array = ( - overlaps.array.set_xindex(AXIS.id_coord) - .set_xindex(f"{AXIS.id_coord}'") - .sel({AXIS.id_coord: keep_ids, f"{AXIS.id_coord}'": keep_ids}) - .reset_index([AXIS.id_coord, f"{AXIS.id_coord}'"]) - ) - - return footprints, traces, pix_stats, comp_stats, overlaps - - -def _filter_redundant( - footprints: Footprints, - traces: Traces, - min_life_in_frames: int, - quantile: float = 0.8, - rel_threshold: float = 0.9, - abs_threshold: float = 1.0, -) -> list[str]: - """ - Remove redundant components - Tested with SplitOffSource - - 1. should have been some time since discovery - 2. max of residual over the last however many frames is very similar to y / trending up % wise - - :param quantile: the higher, the more stringent - """ - A = footprints.array.as_numpy() - c_t = traces.array.isel({AXIS.frames_dim: -1}) - y_t = A @ c_t - # not sure whether to use y_t or reconstructed. recon probably makes more sense. - # y_t = frame.array - - keep_ids = [] - for a, c in zip(A.transpose(AXIS.component_dim, ...), c_t.transpose(AXIS.component_dim, ...)): - if y_t[AXIS.frame_coord] - a[AXIS.detect_coord] < min_life_in_frames: - keep_ids.append(a[AXIS.id_coord].item()) - - ratio = (a @ c / y_t).where(a, np.nan, drop=True).quantile(1 - quantile) - diff = np.abs((a @ c - y_t).where(a, np.nan, drop=True)).quantile(quantile) - - if ratio > rel_threshold and diff < abs_threshold: - keep_ids.append(a[AXIS.id_coord].item()) - - return keep_ids diff --git a/src/cala/nodes/omf/__init__.py b/src/cala/nodes/omf/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cala/nodes/component_stats.py b/src/cala/nodes/omf/component_stats.py similarity index 100% rename from src/cala/nodes/component_stats.py rename to src/cala/nodes/omf/component_stats.py diff --git a/src/cala/nodes/footprints.py b/src/cala/nodes/omf/footprints.py similarity index 100% rename from src/cala/nodes/footprints.py rename to src/cala/nodes/omf/footprints.py diff --git a/src/cala/nodes/overlap.py b/src/cala/nodes/omf/overlap.py similarity index 100% rename from src/cala/nodes/overlap.py rename to src/cala/nodes/omf/overlap.py diff --git a/src/cala/nodes/pixel_stats.py b/src/cala/nodes/omf/pixel_stats.py similarity index 92% rename from src/cala/nodes/pixel_stats.py rename to src/cala/nodes/omf/pixel_stats.py index ad4b333d..df70cdeb 100644 --- a/src/cala/nodes/pixel_stats.py +++ b/src/cala/nodes/omf/pixel_stats.py @@ -1,8 +1,11 @@ +from typing import Annotated as A + import numpy as np import xarray as xr +from noob import Name from scipy.sparse import csr_matrix -from cala.assets import Footprints, Frame, Movie, PixStats, PopSnap, Traces +from cala.assets import Buffer, Footprints, Frame, Movie, PixStats, PopSnap, Traces from cala.models import AXIS @@ -118,6 +121,17 @@ def ingest_component( return pixel_stats +def fill_buffer(buffer: Buffer, frame: Frame) -> A[Buffer, Name("buffer")]: + if buffer.array is None: + buffer.array = frame.array.volumize.dim_with_coords( + dim=AXIS.frames_dim, coords=[AXIS.timestamp_coord] + ) + return buffer + + buffer.append(frame.array) + return buffer + + def initialize( traces: xr.DataArray, frames: xr.DataArray, footprints: xr.DataArray ) -> xr.DataArray: diff --git a/src/cala/nodes/residual.py b/src/cala/nodes/omf/residual.py similarity index 100% rename from src/cala/nodes/residual.py rename to src/cala/nodes/omf/residual.py diff --git a/src/cala/nodes/traces.py b/src/cala/nodes/omf/traces.py similarity index 100% rename from src/cala/nodes/traces.py rename to src/cala/nodes/omf/traces.py diff --git a/src/cala/nodes/segment/__init__.py b/src/cala/nodes/segment/__init__.py index cbb17884..b91a89be 100644 --- a/src/cala/nodes/segment/__init__.py +++ b/src/cala/nodes/segment/__init__.py @@ -1,5 +1,5 @@ from .catalog import Cataloger +from .persist import update_assets from .slice_nmf import SliceNMF -from .update import update_assets __all__ = ["SliceNMF", "Cataloger", "update_assets"] diff --git a/src/cala/nodes/segment/cleanup.py b/src/cala/nodes/segment/cleanup.py new file mode 100644 index 00000000..d1a99fcc --- /dev/null +++ b/src/cala/nodes/segment/cleanup.py @@ -0,0 +1,72 @@ +from typing import Annotated as A + +import numpy as np +from noob import Name + +from cala.assets import Buffer, CompStats, Footprints, Overlaps, PixStats, Traces +from cala.models import AXIS + + +def clear_overestimates( + footprints: Footprints, residuals: Buffer, nmf_error: float +) -> A[Footprints, Name("footprints")]: + """ + Remove all sections of the footprints that cause negative residuals. + + This occurs by: + 1. find "significant" negative residual spots that is more than a noise level, and thus + cannot be clipped to zero. !!!! (only of the latest frame, and then go back to trace update..?) + 2. all footprint values at these spots go to zero. + """ + if residuals.array is None: + return footprints + R_min = residuals.array.isel({AXIS.frames_dim: -1}).reset_coords( + [AXIS.frame_coord, AXIS.timestamp_coord], drop=True + ) + tuned_fp = footprints.array.where(R_min > -nmf_error, 0, drop=False) + + return tuned_fp + + +def deprecate_components( + footprints: Footprints, + traces: Traces, + pix_stats: PixStats, + comp_stats: CompStats, + overlaps: Overlaps, + remove_ids: list[str], +) -> tuple[ + A[Footprints, Name("footprints")], + A[Traces, Name("traces")], + A[PixStats, Name("pix_stats")], + A[CompStats, Name("comp_stats")], + A[Overlaps, Name("overlaps")], +]: + """ + Deprecate a list of components from all assets involved in omf. + + """ + keep_mask = ~np.isin(traces.array[AXIS.id_coord].values, remove_ids) + + traces.keep(keep_mask) + # the line below compiles numba. gotta do it like in footprints.ingest_component + # but then i need to redundantly convert COO -> csr -> COO -> csr -> COO + footprints.array = footprints.array[keep_mask] + pix_stats.array = pix_stats.array[keep_mask] + comp_stats.array = comp_stats.array[keep_mask].T[keep_mask] + overlaps.array = overlaps.array[keep_mask].T[keep_mask] + + return footprints, traces, pix_stats, comp_stats, overlaps + + +def find_inactive() -> list[str]: + """ + Deprecate inactive components + Component is deemed inactive if its own brightness contribution across + all of its footprint is below threshold. + + 1. has been some time since discovery + 2. within its own footprint, its brightness contribution is lower than + some % of the minimum of the total brightness contributions from all components? + - but what if the component is completely occluded sometimes? + """ diff --git a/src/cala/nodes/merge.py b/src/cala/nodes/segment/merge.py similarity index 100% rename from src/cala/nodes/merge.py rename to src/cala/nodes/segment/merge.py diff --git a/src/cala/nodes/segment/update.py b/src/cala/nodes/segment/persist.py similarity index 79% rename from src/cala/nodes/segment/update.py rename to src/cala/nodes/segment/persist.py index 98973a59..7a707057 100644 --- a/src/cala/nodes/segment/update.py +++ b/src/cala/nodes/segment/persist.py @@ -3,11 +3,11 @@ from noob import Name from cala.assets import CompStats, Footprints, Movie, Overlaps, PixStats, Traces -from cala.nodes.component_stats import ingest_component as update_component_stats -from cala.nodes.footprints import ingest_component as update_footprints -from cala.nodes.overlap import ingest_component as update_overlap -from cala.nodes.pixel_stats import ingest_component as update_pixel_stats -from cala.nodes.traces import ingest_component as update_traces +from cala.nodes.omf.component_stats import ingest_component as update_component_stats +from cala.nodes.omf.footprints import ingest_component as update_footprints +from cala.nodes.omf.overlap import ingest_component as update_overlap +from cala.nodes.omf.pixel_stats import ingest_component as update_pixel_stats +from cala.nodes.omf.traces import ingest_component as update_traces def update_assets( diff --git a/tests/data/pipelines/long_recording.yaml b/tests/data/pipelines/long_recording.yaml index fd7ffe66..3418ef82 100644 --- a/tests/data/pipelines/long_recording.yaml +++ b/tests/data/pipelines/long_recording.yaml @@ -101,7 +101,7 @@ nodes: depends: - frame: downsample.frame cache: - type: cala.nodes.buffer.fill_buffer + type: cala.nodes.omf.pixel_stats.fill_buffer depends: - buffer: assets.buffer - frame: downsample.frame @@ -109,7 +109,7 @@ nodes: # FRAME UPDATE BEGINS trace_frame: - type: cala.nodes.traces.Tracer + type: cala.nodes.omf.traces.Tracer params: tol: 0.001 max_iter: 100 @@ -119,20 +119,20 @@ nodes: - frame: downsample.frame - overlaps: assets.overlaps pix_frame: - type: cala.nodes.pixel_stats.ingest_frame + type: cala.nodes.omf.pixel_stats.ingest_frame depends: - pixel_stats: assets.pix_stats - frame: downsample.frame - new_traces: trace_frame.latest_trace - footprints: assets.footprints comp_frame: - type: cala.nodes.component_stats.ingest_frame + type: cala.nodes.omf.component_stats.ingest_frame depends: - component_stats: assets.comp_stats - frame: downsample.frame - new_traces: trace_frame.latest_trace footprints_frame: - type: cala.nodes.footprints.Footprinter + type: cala.nodes.omf.footprints.Footprinter params: bep: 0 tol: 0.0001 @@ -144,7 +144,7 @@ nodes: - component_stats: comp_frame.value residual: - type: cala.nodes.residual.Residuer + type: cala.nodes.omf.residual.Residuer depends: - frame: downsample.frame - footprints: footprints_frame.footprints diff --git a/tests/data/pipelines/minian.yaml b/tests/data/pipelines/minian.yaml index 48d1a779..09aa0602 100644 --- a/tests/data/pipelines/minian.yaml +++ b/tests/data/pipelines/minian.yaml @@ -104,7 +104,7 @@ nodes: depends: - frame: denoise.frame cache: - type: cala.nodes.buffer.fill_buffer + type: cala.nodes.omf.pixel_stats.fill_buffer depends: - buffer: assets.buffer - frame: glow.frame @@ -112,7 +112,7 @@ nodes: # FRAME UPDATE BEGINS trace_frame: - type: cala.nodes.traces.Tracer + type: cala.nodes.omf.traces.Tracer params: tol: 0.001 max_iter: 100 @@ -122,20 +122,20 @@ nodes: - frame: glow.frame - overlaps: assets.overlaps pix_frame: - type: cala.nodes.pixel_stats.ingest_frame + type: cala.nodes.omf.pixel_stats.ingest_frame depends: - pixel_stats: assets.pix_stats - frame: glow.frame - new_traces: trace_frame.latest_trace - footprints: assets.footprints comp_frame: - type: cala.nodes.component_stats.ingest_frame + type: cala.nodes.omf.component_stats.ingest_frame depends: - component_stats: assets.comp_stats - frame: glow.frame - new_traces: trace_frame.latest_trace footprints_frame: - type: cala.nodes.footprints.Footprinter + type: cala.nodes.omf.footprints.Footprinter params: bep: 0 tol: 0.0001 @@ -147,7 +147,7 @@ nodes: - component_stats: comp_frame.value residual: - type: cala.nodes.residual.Residuer + type: cala.nodes.omf.residual.Residuer depends: - frame: glow.frame - footprints: footprints_frame.footprints diff --git a/tests/data/pipelines/odl.yaml b/tests/data/pipelines/odl.yaml index 79308e3b..3362825c 100644 --- a/tests/data/pipelines/odl.yaml +++ b/tests/data/pipelines/odl.yaml @@ -58,7 +58,7 @@ nodes: depends: - frame: glow.frame cache: - type: cala.nodes.buffer.fill_buffer + type: cala.nodes.omf.pixel_stats.fill_buffer depends: - buffer: assets.buffer - frame: glow.frame @@ -66,7 +66,7 @@ nodes: # FRAME UPDATE BEGINS trace_frame: - type: cala.nodes.traces.Tracer + type: cala.nodes.omf.traces.Tracer params: tol: 0.001 max_iter: 100 @@ -76,20 +76,20 @@ nodes: - frame: glow.frame - overlaps: assets.overlaps pix_frame: - type: cala.nodes.pixel_stats.ingest_frame + type: cala.nodes.omf.pixel_stats.ingest_frame depends: - pixel_stats: assets.pix_stats - frame: glow.frame - new_traces: trace_frame.latest_trace - footprints: assets.footprints comp_frame: - type: cala.nodes.component_stats.ingest_frame + type: cala.nodes.omf.component_stats.ingest_frame depends: - component_stats: assets.comp_stats - frame: glow.frame - new_traces: trace_frame.latest_trace footprints_frame: - type: cala.nodes.footprints.Footprinter + type: cala.nodes.omf.footprints.Footprinter params: bep: 0 tol: 0.0001 @@ -101,7 +101,7 @@ nodes: - component_stats: comp_frame.value residual: - type: cala.nodes.residual.Residuer + type: cala.nodes.omf.residual.Residuer depends: - frame: glow.frame - footprints: footprints_frame.footprints diff --git a/tests/test_iter/test_cleanup.py b/tests/test_iter/test_cleanup.py index 2f64e1d0..ed01bda1 100644 --- a/tests/test_iter/test_cleanup.py +++ b/tests/test_iter/test_cleanup.py @@ -1,8 +1,6 @@ -import xarray as xr - from cala.assets import Buffer from cala.models import AXIS -from cala.nodes.cleanup import _filter_redundant, clear_overestimates +from cala.nodes.segment.cleanup import clear_overestimates def test_clear_overestimates(single_cell) -> None: @@ -16,32 +14,3 @@ def test_clear_overestimates(single_cell) -> None: expected.loc[{AXIS.width_coord: slice(single_cell.cell_positions[0].width, None)}] = 0 assert result.equals(expected) - - -def test_erase_redundant(splitoff_cells) -> None: - footprints = splitoff_cells.footprints - dead_footprint = xr.DataArray( - footprints.array.sum(dim=AXIS.component_dim) / 100, - dims=footprints.array.isel({AXIS.component_dim: 1}).dims, - coords=footprints.array.isel({AXIS.component_dim: 1}).coords, - ).assign_coords({AXIS.id_coord: "cell_2", AXIS.detect_coord: 77}) - footprints.array = xr.concat([footprints.array, dead_footprint], dim=AXIS.component_dim) - - traces = splitoff_cells.traces - dead_trace = xr.DataArray( - [0.1] * traces.array.sizes[AXIS.frames_dim], - dims=traces.array.isel({AXIS.component_dim: 1}).dims, - coords=traces.array.isel({AXIS.component_dim: 1}).coords, - ).assign_coords({AXIS.id_coord: "cell_2", AXIS.detect_coord: 77}) - traces.array = xr.concat([traces.array, dead_trace], dim=AXIS.component_dim) - - result = _filter_redundant( - footprints=footprints, traces=traces, min_life_in_frames=10, quantile=0.9 - ) - - expected = splitoff_cells.footprints.array[AXIS.id_coord].values.tolist() - - assert set(result) == set(expected) - - -def test_merge_components(splitoff_cells) -> None: ... diff --git a/tests/test_iter/test_component_stats.py b/tests/test_iter/test_component_stats.py index 005458cc..d32b7aab 100644 --- a/tests/test_iter/test_component_stats.py +++ b/tests/test_iter/test_component_stats.py @@ -10,7 +10,7 @@ @pytest.fixture def init() -> Node: return Node.from_specification( - spec=NodeSpecification(id="cs_init_test", type="cala.nodes.component_stats.initialize") + spec=NodeSpecification(id="cs_init_test", type="cala.nodes.omf.component_stats.initialize") ) @@ -42,7 +42,9 @@ def test_init(init, four_separate_cells) -> None: @pytest.fixture def frame_update() -> Node: return Node.from_specification( - spec=NodeSpecification(id="cs_frame_test", type="cala.nodes.component_stats.ingest_frame") + spec=NodeSpecification( + id="cs_frame_test", type="cala.nodes.omf.component_stats.ingest_frame" + ) ) @@ -65,7 +67,7 @@ def test_ingest_frame(init, frame_update, four_separate_cells) -> None: def comp_update() -> Node: return Node.from_specification( spec=NodeSpecification( - id="cs_comp_test", type="cala.nodes.component_stats.ingest_component" + id="cs_comp_test", type="cala.nodes.omf.component_stats.ingest_component" ) ) diff --git a/tests/test_iter/test_footprints.py b/tests/test_iter/test_footprints.py index 1258279d..6439563b 100644 --- a/tests/test_iter/test_footprints.py +++ b/tests/test_iter/test_footprints.py @@ -5,7 +5,7 @@ from cala.assets import CompStats, Footprints, PixStats from cala.models import AXIS -from cala.nodes.footprints import ingest_component +from cala.nodes.omf.footprints import ingest_component from cala.testing.toy import FrameDims, Position, Toy @@ -60,7 +60,7 @@ def fpter() -> Node: return Node.from_specification( NodeSpecification( id="test_footprinter", - type="cala.nodes.footprints.Footprinter", + type="cala.nodes.omf.footprints.Footprinter", params={"bep": 0, "tol": 1e-7}, ) ) @@ -71,12 +71,12 @@ def test_ingest_frame(fpter, toy, request): toy = request.getfixturevalue(toy) pixstats = Node.from_specification( - NodeSpecification(id="test_pixstats", type="cala.nodes.pixel_stats.initialize") + NodeSpecification(id="test_pixstats", type="cala.nodes.omf.pixel_stats.initialize") ).process( traces=toy.traces.array, frames=toy.make_movie().array, footprints=toy.footprints.array ) compstats = Node.from_specification( - NodeSpecification(id="test_compstats", type="cala.nodes.component_stats.initialize") + NodeSpecification(id="test_compstats", type="cala.nodes.omf.component_stats.initialize") ).process(traces=toy.traces.array) result = fpter.process( diff --git a/tests/test_iter/test_overlaps.py b/tests/test_iter/test_overlaps.py index 84aa1ada..9662030e 100644 --- a/tests/test_iter/test_overlaps.py +++ b/tests/test_iter/test_overlaps.py @@ -9,7 +9,7 @@ @pytest.fixture(scope="function") def init() -> Node: return Node.from_specification( - spec=NodeSpecification(id="ov_init_test", type="cala.nodes.overlap.initialize") + spec=NodeSpecification(id="ov_init_test", type="cala.nodes.omf.overlap.initialize") ) @@ -29,7 +29,7 @@ def test_init(init, four_separate_cells, four_connected_cells) -> None: @pytest.fixture(scope="function") def comp_update() -> Node: return Node.from_specification( - spec=NodeSpecification(id="ov_init_test", type="cala.nodes.overlap.ingest_component") + spec=NodeSpecification(id="ov_init_test", type="cala.nodes.omf.overlap.ingest_component") ) diff --git a/tests/test_iter/test_pixel_stats.py b/tests/test_iter/test_pixel_stats.py index aab20fa7..1f84621b 100644 --- a/tests/test_iter/test_pixel_stats.py +++ b/tests/test_iter/test_pixel_stats.py @@ -9,7 +9,7 @@ @pytest.fixture(scope="function") def init() -> Node: return Node.from_specification( - spec=NodeSpecification(id="ps_init_test", type="cala.nodes.pixel_stats.initialize") + spec=NodeSpecification(id="ps_init_test", type="cala.nodes.omf.pixel_stats.initialize") ) @@ -39,7 +39,7 @@ def test_init(init, four_separate_cells) -> None: @pytest.fixture(scope="function") def frame_update() -> Node: return Node.from_specification( - spec=NodeSpecification(id="ps_frame_test", type="cala.nodes.pixel_stats.ingest_frame") + spec=NodeSpecification(id="ps_frame_test", type="cala.nodes.omf.pixel_stats.ingest_frame") ) @@ -68,7 +68,9 @@ def test_ingest_frame(init, frame_update, four_separate_cells) -> None: @pytest.fixture(scope="function") def comp_update() -> Node: return Node.from_specification( - spec=NodeSpecification(id="ps_comp_test", type="cala.nodes.pixel_stats.ingest_component") + spec=NodeSpecification( + id="ps_comp_test", type="cala.nodes.omf.pixel_stats.ingest_component" + ) ) diff --git a/tests/test_iter/test_residual.py b/tests/test_iter/test_residual.py index 4ca1e06c..b2f2bf0c 100644 --- a/tests/test_iter/test_residual.py +++ b/tests/test_iter/test_residual.py @@ -5,7 +5,7 @@ from cala.assets import Buffer, Footprints, Frame, Traces from cala.models.axis import AXIS -from cala.nodes.residual import _align_overestimates, _find_unlayered_footprints +from cala.nodes.omf.residual import _align_overestimates, _find_unlayered_footprints from cala.testing.toy import FrameDims, Position, Toy @@ -38,7 +38,7 @@ def init() -> Node: return Node.from_specification( spec=NodeSpecification( id="res_init_test", - type="cala.nodes.residual.Residuer", + type="cala.nodes.omf.residual.Residuer", ) ) diff --git a/tests/test_iter/test_traces.py b/tests/test_iter/test_traces.py index a84c1aa5..d141c644 100644 --- a/tests/test_iter/test_traces.py +++ b/tests/test_iter/test_traces.py @@ -12,7 +12,7 @@ def frame_update() -> Node: return Node.from_specification( spec=NodeSpecification( id="frame_test", - type="cala.nodes.traces.Tracer", + type="cala.nodes.omf.traces.Tracer", params={"max_iter": 100, "tol": 1e-4}, ) ) @@ -36,7 +36,7 @@ def test_ingest_frame(frame_update, toy, zarr_setup, request, tmp_path) -> None: zarr_setup["zarr_path"] = tmp_path if zarr_setup["zarr_path"] else None xray = Node.from_specification( - spec=NodeSpecification(id="test", type="cala.nodes.overlap.initialize") + spec=NodeSpecification(id="test", type="cala.nodes.omf.overlap.initialize") ) traces = Traces(array_=None, **zarr_setup) @@ -56,7 +56,7 @@ def test_ingest_frame(frame_update, toy, zarr_setup, request, tmp_path) -> None: @pytest.fixture def comp_update() -> Node: return Node.from_specification( - NodeSpecification(id="comp_test", type="cala.nodes.traces.ingest_component") + NodeSpecification(id="comp_test", type="cala.nodes.omf.traces.ingest_component") ) From f7b301f5430230fada9a3ea3f50339e9fc7dcc0d Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 25 Nov 2025 19:07:06 -0800 Subject: [PATCH 25/33] chore: cleanup unused codes --- src/cala/nodes/segment/cleanup.py | 42 ++--- src/cala/nodes/segment/merge.py | 153 +--------------- src/cala/testing/catalog_depr.py | 293 ------------------------------ 3 files changed, 29 insertions(+), 459 deletions(-) diff --git a/src/cala/nodes/segment/cleanup.py b/src/cala/nodes/segment/cleanup.py index d1a99fcc..81cb5643 100644 --- a/src/cala/nodes/segment/cleanup.py +++ b/src/cala/nodes/segment/cleanup.py @@ -7,27 +7,6 @@ from cala.models import AXIS -def clear_overestimates( - footprints: Footprints, residuals: Buffer, nmf_error: float -) -> A[Footprints, Name("footprints")]: - """ - Remove all sections of the footprints that cause negative residuals. - - This occurs by: - 1. find "significant" negative residual spots that is more than a noise level, and thus - cannot be clipped to zero. !!!! (only of the latest frame, and then go back to trace update..?) - 2. all footprint values at these spots go to zero. - """ - if residuals.array is None: - return footprints - R_min = residuals.array.isel({AXIS.frames_dim: -1}).reset_coords( - [AXIS.frame_coord, AXIS.timestamp_coord], drop=True - ) - tuned_fp = footprints.array.where(R_min > -nmf_error, 0, drop=False) - - return tuned_fp - - def deprecate_components( footprints: Footprints, traces: Traces, @@ -70,3 +49,24 @@ def find_inactive() -> list[str]: some % of the minimum of the total brightness contributions from all components? - but what if the component is completely occluded sometimes? """ + + +def clear_overestimates( + footprints: Footprints, residuals: Buffer, nmf_error: float +) -> A[Footprints, Name("footprints")]: + """ + Remove all sections of the footprints that cause negative residuals. + + This occurs by: + 1. find "significant" negative residual spots that is more than a noise level, and thus + cannot be clipped to zero. !!!! (only of the latest frame, and then go back to trace update..?) + 2. all footprint values at these spots go to zero. + """ + if residuals.array is None: + return footprints + R_min = residuals.array.isel({AXIS.frames_dim: -1}).reset_coords( + [AXIS.frame_coord, AXIS.timestamp_coord], drop=True + ) + tuned_fp = footprints.array.where(R_min > -nmf_error, 0, drop=False) + + return tuned_fp diff --git a/src/cala/nodes/segment/merge.py b/src/cala/nodes/segment/merge.py index 737e0f25..c3263ffc 100644 --- a/src/cala/nodes/segment/merge.py +++ b/src/cala/nodes/segment/merge.py @@ -1,147 +1,10 @@ -from typing import Annotated as A +""" +Merge components that have already been registered with each other, +if they spatially overlap and temporally correlate significantly. -import numpy as np -import xarray as xr -from noob import Name -from scipy.ndimage import gaussian_filter1d -from scipy.sparse.csgraph import connected_components +This step is complementary to the catalog node, for the cases in which +two components should have been merged, but the buffered data's SNR was +temporarily too low to build a reliable merge matrix, causing the two +components to remain separated. -from cala.assets import Footprints, Overlaps, Traces -from cala.models import AXIS -from cala.nodes.segment.catalog import _recompose, _register -from cala.util import combine_attr_replaces - - -def merge_existing( - shapes: Footprints, - traces: Traces, - overlaps: Overlaps, - merge_interval: int, - merge_threshold: float, - smooth_kwargs: dict, - trigger: bool = None, -) -> tuple[A[Footprints, Name("footprints")], A[Traces, Name("traces")]]: - """ - merge old ass components with each other, if they missed their - natal window chance to get merged with their siamese twin - """ - if overlaps.array is None: - return Footprints(), Traces() - - idx = traces.array[AXIS.frame_coord].max().item() - - if idx % merge_interval != 0: - return Footprints(), Traces() - - # only merge old components - targets = traces.array[AXIS.detect_coord] <= ( - traces.array[AXIS.frame_coord].max() - merge_interval - ) - - if not any(targets): - return Footprints(), Traces() - - target_ids = targets.where(targets, drop=True)[AXIS.id_coord].values - - target_fps, target_trs, target_ovs = _filter_targets( - target_ids=target_ids, - shapes=shapes, - traces=traces, - overlaps=overlaps, - n_frames=merge_interval, - ) - - merge_mat = _merge_matrix( - traces=target_trs, - overlaps=target_ovs, - smooth_kwargs=smooth_kwargs, - threshold=merge_threshold, - ) - - num, label = connected_components(merge_mat.data) - combined_fps = [] - combined_trs = [] - - for lbl in set(label): - group = np.where(label == lbl)[0] - if len(group) <= 1: - continue - fps = target_fps.isel({AXIS.component_dim: group}) - trs = target_trs.isel({AXIS.component_dim: group}) - res = xr.DataArray( - np.matmul(fps.transpose(*AXIS.spatial_dims, ...).data, trs.data), - dims=[*AXIS.spatial_dims, AXIS.frames_dim], - ) - a_new, c_new = _recompose(res, target_fps[0].coords, target_trs[0].coords) - - a_new.attrs["replaces"] = fps[AXIS.id_coord].values.tolist() - c_new.attrs["replaces"] = trs[AXIS.id_coord].values.tolist() - - combined_fps.append(a_new) - combined_trs.append(c_new) - - if len(combined_fps) == 0: - return Footprints(), Traces() - - new_fps = xr.concat( - combined_fps, - dim=AXIS.component_dim, - coords=[AXIS.id_coord, AXIS.detect_coord], - combine_attrs=combine_attr_replaces, - ) - new_trs = xr.concat( - combined_trs, - dim=AXIS.component_dim, - coords=[AXIS.id_coord, AXIS.detect_coord], - combine_attrs=combine_attr_replaces, - ) - new_fps, new_trs = _register(new_fps, new_trs) - - return Footprints.from_array(new_fps), Traces.from_array(new_trs) - - -def _filter_targets( - target_ids: np.ndarray, shapes: Footprints, traces: Traces, overlaps: Overlaps, n_frames: int -) -> tuple[xr.DataArray, xr.DataArray, xr.DataArray]: - """ - filter assets for old enough components - """ - target_ovs = ( - overlaps.array.set_xindex([AXIS.id_coord]) - .set_xindex(f"{AXIS.id_coord}'") - .sel({AXIS.id_coord: target_ids, f"{AXIS.id_coord}'": target_ids}) - .reset_index(AXIS.id_coord) - .reset_index(f"{AXIS.id_coord}'") - ) - - target_trs = ( - traces.full_array(isel_filter={AXIS.frames_dim: slice(-n_frames, None)}) - .set_xindex(AXIS.id_coord) - .sel({AXIS.id_coord: target_ids}) - .reset_index(AXIS.id_coord) - ).transpose(AXIS.component_dim, ...) - - target_fps = ( - shapes.array.set_xindex(AXIS.id_coord) - .sel({AXIS.id_coord: target_ids}) - .reset_index(AXIS.id_coord) - ) - - return target_fps, target_trs, target_ovs - - -def _merge_matrix( - traces: xr.DataArray, overlaps: xr.DataArray, smooth_kwargs: dict, threshold: float -) -> xr.DataArray: - """ - get merge matrix to use for connected_components - """ - traces = xr.DataArray( - gaussian_filter1d(traces.transpose(AXIS.component_dim, ...), **smooth_kwargs), - dims=traces.dims, - coords=traces.coords, - ) - traces_base = traces.rename(AXIS.component_rename) - - corr = xr.corr(traces, traces_base, dim=AXIS.frames_dim) - return overlaps * corr > threshold +""" diff --git a/src/cala/testing/catalog_depr.py b/src/cala/testing/catalog_depr.py index 83ac35d3..ff6ebee4 100644 --- a/src/cala/testing/catalog_depr.py +++ b/src/cala/testing/catalog_depr.py @@ -3,23 +3,13 @@ """ -from collections.abc import Hashable, Iterable -from itertools import compress -from typing import Annotated as A - import numpy as np import xarray as xr -from noob import Name from noob.node import Node from pydantic import Field from scipy.ndimage import gaussian_filter1d -from scipy.sparse.csgraph import connected_components -from skimage.measure import label -from xarray import Coordinates -from cala.assets import Footprint, Footprints, Trace, Traces from cala.models import AXIS -from cala.util import combine_attr_replaces, create_id, rank1nmf class CatalogerDepr(Node): @@ -31,31 +21,6 @@ class CatalogerDepr(Node): cnt_threshold: int = Field(gt=0) """must have cnt-number of pixels that are above the val-value""" - def process( - self, - new_fps: list[Footprint], - new_trs: list[Trace], - existing_fp: Footprints | None = None, - existing_tr: Traces | None = None, - ) -> tuple[A[Footprints, Name("new_footprints")], A[Traces, Name("new_traces")]]: - - if not new_fps or not new_trs: - return Footprints(), Traces() - - new_fps = xr.concat([fp.array for fp in new_fps], dim=AXIS.component_dim) - new_trs = xr.concat([tr.array for tr in new_trs], dim=AXIS.component_dim) - merge_mat = self._merge_matrix(new_fps, new_trs) - new_fps, new_trs = _merge(new_fps, new_trs, merge_mat) - - known_fp, known_tr = _get_absorption_targets(existing_fp, existing_tr, self.age_limit) - merge_mat = self._merge_matrix(new_fps, new_trs, known_fp, known_tr) - footprints, traces = self._absorb(new_fps, new_trs, known_fp, known_tr, merge_mat) - # footprints = self._smooth(shapes) - - return Footprints.from_array(footprints), Traces.from_array(traces) - - def _smooth(self, shapes: xr.DataArray) -> xr.DataArray: ... - def _merge_matrix( self, fps: xr.DataArray, @@ -88,261 +53,3 @@ def _merge_matrix( # corr is fast. (~1ms to 4ms) corrs = xr.corr(trs, trs_base, dim=AXIS.frames_dim) > self.merge_threshold return xr.DataArray(overlaps * corrs.values, dims=corrs.dims, coords=corrs.coords) - - def _absorb( - self, - new_fps: xr.DataArray, - new_trs: xr.DataArray, - known_fps: xr.DataArray, - known_trs: xr.DataArray, - merge_matrix: xr.DataArray, - ) -> tuple[xr.DataArray | None, xr.DataArray | None]: - footprints = [] - traces = [] - - merge_matrix.data = label(merge_matrix.to_numpy(), background=0, connectivity=1) - merge_matrix = merge_matrix.assign_coords( - {AXIS.component_dim: range(merge_matrix.sizes[AXIS.component_dim])} - ).reset_index(AXIS.component_dim) - indep_idxs = ( - merge_matrix.where(merge_matrix.sum(f"{AXIS.component_dim}'") == 0, drop=True)[ - AXIS.component_dim - ].values - if known_fps is not None - else np.array(range(len(merge_matrix))) - ) - if indep_idxs.size > 0: - fps, trs = _register_batch( - new_fps=new_fps.isel({AXIS.component_dim: indep_idxs}), - new_trs=new_trs.isel({AXIS.component_dim: indep_idxs}), - ) - footprints.append(fps) - traces.append(trs) - - num = merge_matrix.max().item() - if num > 0 and known_fps is not None: - for lbl in range(1, num + 1): - new_idxs, _known_idxs = np.where(merge_matrix == lbl) - known_ids = merge_matrix.where(merge_matrix == lbl, drop=True)[ - f"{AXIS.id_coord}'" - ].values - fp = new_fps.sel({AXIS.component_dim: list(set(new_idxs))}) - tr = new_trs.sel({AXIS.component_dim: list(set(new_idxs))}) - footprint, trace = _merge_with(fp, tr, known_fps, known_trs, known_ids) - - footprints.append(footprint) - traces.append(trace) - - mask = [np.sum(fp.data > self.val_threshold) > self.cnt_threshold for fp in footprints] - footprints = list(compress(footprints, mask)) - traces = list(compress(traces, mask)) - - if not footprints: - return None, None - - footprints = xr.concat( - footprints, - dim=AXIS.component_dim, - coords=[AXIS.id_coord, AXIS.detect_coord], - combine_attrs=combine_attr_replaces, - ) - traces = xr.concat( - traces, - dim=AXIS.component_dim, - coords=[AXIS.id_coord, AXIS.detect_coord], - combine_attrs=combine_attr_replaces, - ) - - return footprints, traces - - -def _get_absorption_targets( - existing_fp: Footprints, existing_tr: Traces, age_limit: int -) -> tuple[xr.DataArray, xr.DataArray]: - if existing_fp.array is not None: - targets = existing_tr.array[AXIS.detect_coord] > ( - existing_tr.array[AXIS.frame_coord].max() - age_limit - ) - known_fp = existing_fp.array.where(targets, drop=True) - known_tr = existing_tr.array.where(targets, drop=True) - else: - known_fp = existing_fp.array - known_tr = existing_tr.array - return known_fp, known_tr - - -def _register(new_fp: xr.DataArray, new_tr: xr.DataArray) -> tuple[xr.DataArray, xr.DataArray]: - - new_id = create_id() - - footprint = ( - new_fp.expand_dims(AXIS.component_dim) - .assign_coords( - { - AXIS.id_coord: (AXIS.component_dim, [new_id]), - AXIS.detect_coord: ( - AXIS.component_dim, - [new_tr[AXIS.frame_coord].max().item()], - ), - } - ) - .isel({AXIS.component_dim: 0}) - ) - trace = ( - new_tr.expand_dims(AXIS.component_dim) - .assign_coords( - { - AXIS.id_coord: (AXIS.component_dim, [new_id]), - AXIS.detect_coord: ( - AXIS.component_dim, - [new_tr[AXIS.frame_coord].max().item()], - ), - } - ) - .isel({AXIS.component_dim: 0}) - ) - - return footprint, trace - - -def _register_batch( - new_fps: xr.DataArray, new_trs: xr.DataArray -) -> tuple[xr.DataArray, xr.DataArray]: - count = new_fps.sizes[AXIS.component_dim] - new_ids = [create_id() for _ in range(count)] - - footprints = new_fps.assign_coords( - { - AXIS.id_coord: (AXIS.component_dim, new_ids), - AXIS.detect_coord: ( - AXIS.component_dim, - [new_trs[AXIS.frame_coord].max().item()] * count, - ), - } - ) - traces = new_trs.assign_coords( - { - AXIS.id_coord: (AXIS.component_dim, new_ids), - AXIS.detect_coord: ( - AXIS.component_dim, - [new_trs[AXIS.frame_coord].max().item()] * count, - ), - } - ) - - return footprints, traces - - -def _recompose( - movie: xr.DataArray, fp_coords: Coordinates, tr_coords: Coordinates -) -> tuple[xr.DataArray, xr.DataArray]: - # Reshape neighborhood to 2D matrix (time × space) - movie = movie.assign_coords({ax: movie[ax] for ax in AXIS.spatial_dims}) - shape = xr.DataArray( - np.sum(movie.transpose(AXIS.frames_dim, ...).data, axis=0) > 0, dims=AXIS.spatial_dims - ) - slice_ = movie.where(shape.as_numpy(), 0, drop=True) - R = slice_.stack(space=AXIS.spatial_dims).transpose("space", AXIS.frames_dim) - - a, c, error = rank1nmf(R.values, np.mean(R.values, axis=1)) - - a_new, c_new = _reshape( - footprint=a, - trace=c, - fp_coords=fp_coords, - tr_coords=tr_coords, - slice_coords=slice_.coords, - ) - - factor = slice_.data.max() / c_new.data.max() - a_new = a_new / factor - c_new = c_new * factor - - return a_new, c_new - - -def _reshape( - footprint: np.ndarray, - trace: np.ndarray, - fp_coords: Coordinates, - tr_coords: Coordinates, - slice_coords: Coordinates, -) -> tuple[xr.DataArray, xr.DataArray]: - """Convert back to xarray with proper dimensions and coordinates""" - - c_new = xr.DataArray(trace.squeeze(), dims=[AXIS.frames_dim], coords=tr_coords) - - a_new = xr.DataArray( - np.zeros(tuple(fp_coords.sizes.values())), - dims=tuple(fp_coords.sizes.keys()), - coords=fp_coords, - ) - - a_new.loc[slice_coords] = xr.DataArray( - footprint.squeeze().reshape(list(slice_coords[ax].size for ax in AXIS.spatial_dims)), - dims=AXIS.spatial_dims, - coords=slice_coords, - ) - - return a_new, c_new - - -def _merge_with( - new_fp: xr.DataArray, - new_tr: xr.DataArray, - target_fps: xr.DataArray, - target_trs: xr.DataArray, - dupe_ids: Iterable[Hashable], -) -> tuple[xr.DataArray, xr.DataArray]: - target_fp = target_fps.set_xindex(AXIS.id_coord).sel({AXIS.id_coord: dupe_ids}) - target_tr = target_trs.set_xindex(AXIS.id_coord).sel({AXIS.id_coord: dupe_ids}) - - recreated_movie = np.matmul( - target_fp.transpose(*AXIS.spatial_dims, ...).data, - target_tr.dropna(dim=AXIS.frames_dim).data, - ) - new_movie = np.matmul( - new_fp.transpose(*AXIS.spatial_dims, ...).data, - new_tr.dropna(dim=AXIS.frames_dim).data, - ) - combined_movie = xr.DataArray( - recreated_movie + new_movie, dims=[*AXIS.spatial_dims, AXIS.frames_dim] - ) - - a_new, c_new = _recompose( - combined_movie, - new_fp.isel({AXIS.component_dim: 0}).coords, - new_tr.isel({AXIS.component_dim: 0}).coords, - ) - a_new.attrs["replaces"] = target_fp[AXIS.id_coord].values.tolist() - c_new.attrs["replaces"] = target_tr[AXIS.id_coord].values.tolist() - - return _register(a_new, c_new) - - -def _merge( - footprints: xr.DataArray, traces: xr.DataArray, merge_matrix: xr.DataArray -) -> tuple[xr.DataArray, xr.DataArray]: - num, label = connected_components(merge_matrix.data) - combined_fps = [] - combined_trs = [] - - for lbl in set(label): - group = np.where(label == lbl)[0] - fps = footprints.sel({AXIS.component_dim: group}) - trs = traces.sel({AXIS.component_dim: group}) - if len(group) > 1: - res = xr.DataArray( - np.matmul(fps.transpose(*AXIS.spatial_dims, ...).data, trs.data), - dims=[*AXIS.spatial_dims, AXIS.frames_dim], - ) - new_fp, new_tr = _recompose(res, footprints[0].coords, traces[0].coords) - else: - new_fp, new_tr = fps[0], trs[0] - combined_fps.append(new_fp) - combined_trs.append(new_tr) - - new_fps = xr.concat(combined_fps, dim=AXIS.component_dim) - new_trs = xr.concat(combined_trs, dim=AXIS.component_dim) - - return new_fps, new_trs From 3e4b85744115722aacce21b885ff2f4f579e740d Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 25 Nov 2025 19:07:36 -0800 Subject: [PATCH 26/33] feat: separate out segment loop --- tests/data/pipelines/long_recording.yaml | 24 -------- tests/data/pipelines/minian.yaml | 24 -------- tests/data/pipelines/odl.yaml | 66 +++------------------- tests/data/pipelines/segment.yaml | 72 ++++++++++++++++++++++++ 4 files changed, 80 insertions(+), 106 deletions(-) create mode 100644 tests/data/pipelines/segment.yaml diff --git a/tests/data/pipelines/long_recording.yaml b/tests/data/pipelines/long_recording.yaml index 3418ef82..a5392a4e 100644 --- a/tests/data/pipelines/long_recording.yaml +++ b/tests/data/pipelines/long_recording.yaml @@ -210,27 +210,3 @@ nodes: depends: - asset: assets.footprints - curr_epoch: counter.idx - -# merge: -# type: cala.nodes.merge.merge_existing -# params: -# merge_interval: 500 -# merge_threshold: 0.95 -# smooth_kwargs: -# sigma: 2 -# depends: -# - shapes: assets.footprints -# - traces: assets.traces -# - overlaps: assets.overlaps -# - trigger: detect_update.footprints -# merge_update: -# type: cala.nodes.segment.update_assets -# depends: -# - new_footprints: merge.footprints -# - new_traces: merge.traces -# - footprints: assets.footprints -# - traces: assets.traces -# - pixel_stats: assets.pix_stats -# - component_stats: assets.comp_stats -# - overlaps: assets.overlaps -# - buffer: assets.buffer diff --git a/tests/data/pipelines/minian.yaml b/tests/data/pipelines/minian.yaml index 09aa0602..a6204089 100644 --- a/tests/data/pipelines/minian.yaml +++ b/tests/data/pipelines/minian.yaml @@ -195,27 +195,3 @@ nodes: - overlaps: assets.overlaps - buffer: assets.buffer # DETECT ENDS - -# merge: -# type: cala.nodes.merge.merge_existing -# params: -# merge_interval: 500 -# merge_threshold: 0.95 -# smooth_kwargs: -# sigma: 2 -# depends: -# - shapes: assets.footprints -# - traces: assets.traces -# - overlaps: assets.overlaps -# - trigger: detect_update.footprints -# merge_update: -# type: cala.nodes.segment.update_assets -# depends: -# - new_footprints: merge.footprints -# - new_traces: merge.traces -# - footprints: assets.footprints -# - traces: assets.traces -# - pixel_stats: assets.pix_stats -# - component_stats: assets.comp_stats -# - overlaps: assets.overlaps -# - buffer: assets.buffer diff --git a/tests/data/pipelines/odl.yaml b/tests/data/pipelines/odl.yaml index 3362825c..eb5bed9e 100644 --- a/tests/data/pipelines/odl.yaml +++ b/tests/data/pipelines/odl.yaml @@ -51,7 +51,7 @@ nodes: type: cala.nodes.prep.GlowRemover depends: - frame: frame.value - size_est: + radius: type: cala.nodes.prep.SizeEst params: hardset_radius: 5 @@ -109,70 +109,20 @@ nodes: - residuals: assets.residuals # FRAME UPDATE ENDS - # DETECT BEGINS - nmf: - type: cala.nodes.segment.SliceNMF + segment: + type: tube params: - min_frames: 10 - detect_thresh: 1.0 - reprod_tol: 0.001 + tube: segment-loop depends: - - residuals: residual.movie + - residual: residual.movie - energy: residual.std - - detect_radius: size_est.radius - catalog: - type: cala.nodes.segment.Cataloger - params: - age_limit: 100 - shape_smooth_kwargs: - ksize: [ 1, 1 ] - sigmaX: 0 - trace_smooth_kwargs: - sigma: 2 - merge_threshold: 0.95 - val_threshold: 0.5 - cnt_threshold: 10 - depends: - - new_fps: nmf.new_fps - - new_trs: nmf.new_trs - - existing_fp: assets.footprints - - existing_tr: assets.traces - detect_update: - type: cala.nodes.segment.update_assets - depends: - - new_footprints: catalog.new_footprints - - new_traces: catalog.new_traces + - radius: radius.radius - footprints: assets.footprints - traces: assets.traces - - pixel_stats: assets.pix_stats - - component_stats: assets.comp_stats + - frame_sum: assets.pix_stats + - trace_sum: assets.comp_stats - overlaps: assets.overlaps - buffer: assets.buffer - # DETECT ENDS - - # merge: - # type: cala.nodes.merge.merge_existing - # params: - # merge_interval: 500 - # merge_threshold: 0.95 - # smooth_kwargs: - # sigma: 2 - # depends: - # - shapes: assets.footprints - # - traces: assets.traces - # - overlaps: assets.overlaps - # - trigger: detect_update.footprints - # merge_update: - # type: cala.nodes.segment.update_assets - # depends: - # - new_footprints: merge.footprints - # - new_traces: merge.traces - # - footprints: assets.footprints - # - traces: assets.traces - # - pixel_stats: assets.pix_stats - # - component_stats: assets.comp_stats - # - overlaps: assets.overlaps - # - buffer: assets.buffer return: type: return diff --git a/tests/data/pipelines/segment.yaml b/tests/data/pipelines/segment.yaml new file mode 100644 index 00000000..859855a9 --- /dev/null +++ b/tests/data/pipelines/segment.yaml @@ -0,0 +1,72 @@ +noob_id: segment-loop +noob_model: noob.tube.TubeSpecification +noob_version: 0.1.1.dev209+g1cafe61 +input: + residual: + scope: process + type: np.ndarray + energy: + scope: process + type: np.ndarray + radius: + scope: process + type: int + footprints: + scope: process + type: np.ndarray + traces: + scope: process + type: np.ndarray + frame_sum: + scope: process + type: np.ndarray + trace_sum: + scope: process + type: np.ndarray + overlaps: + scope: process + type: np.ndarray + buffer: + scope: process + type: np.ndarray +nodes: + nmf: + type: cala.nodes.segment.SliceNMF + params: + min_frames: 10 + detect_thresh: 1.0 + reprod_tol: 0.001 + depends: + - residuals: input.residual + - energy: input.energy + - detect_radius: input.radius + catalog: + type: cala.nodes.segment.Cataloger + params: + age_limit: 100 + shape_smooth_kwargs: + ksize: + - 1 + - 1 + sigmaX: 0 + trace_smooth_kwargs: + sigma: 2 + merge_threshold: 0.95 + val_threshold: 0.5 + cnt_threshold: 10 + depends: + - new_fps: nmf.new_fps + - new_trs: nmf.new_trs + - existing_fp: input.footprints + - existing_tr: input.traces + detect_update: + type: cala.nodes.segment.update_assets + depends: + - new_footprints: catalog.new_footprints + - new_traces: catalog.new_traces + - footprints: input.footprints + - traces: input.traces + - pixel_stats: input.frame_sum + - component_stats: input.trace_sum + - overlaps: input.overlaps + - buffer: input.buffer From fb21044740bbb569e28f4019f5e6dbdd899b0d53 Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 1 Dec 2025 16:14:33 -0800 Subject: [PATCH 27/33] feat: recurse cala --- src/cala/assets.py | 26 +++++++ src/cala/nodes/omf/footprints.py | 6 +- src/cala/nodes/omf/traces.py | 60 ---------------- tests/data/pipelines/ingest.yaml | 71 +++++++++++++++++++ tests/data/pipelines/minian.yaml | 1 - tests/data/pipelines/odl.yaml | 48 +++---------- tests/{test_iter => test_omf}/__init__.py | 0 .../test_component_stats.py | 0 .../test_footprints.py | 2 +- .../{test_iter => test_omf}/test_overlaps.py | 0 .../test_pixel_stats.py | 0 .../{test_iter => test_omf}/test_residual.py | 0 tests/{test_iter => test_omf}/test_traces.py | 0 tests/test_segment/__init__.py | 0 .../test_catalog.py | 0 .../test_cleanup.py | 0 .../{test_iter => test_segment}/test_merge.py | 0 .../test_slice_nmf.py | 0 18 files changed, 109 insertions(+), 105 deletions(-) create mode 100644 tests/data/pipelines/ingest.yaml rename tests/{test_iter => test_omf}/__init__.py (100%) rename tests/{test_iter => test_omf}/test_component_stats.py (100%) rename tests/{test_iter => test_omf}/test_footprints.py (98%) rename tests/{test_iter => test_omf}/test_overlaps.py (100%) rename tests/{test_iter => test_omf}/test_pixel_stats.py (100%) rename tests/{test_iter => test_omf}/test_residual.py (100%) rename tests/{test_iter => test_omf}/test_traces.py (100%) create mode 100644 tests/test_segment/__init__.py rename tests/{test_iter => test_segment}/test_catalog.py (100%) rename tests/{test_iter => test_segment}/test_cleanup.py (100%) rename tests/{test_iter => test_segment}/test_merge.py (100%) rename tests/{test_iter => test_segment}/test_slice_nmf.py (100%) diff --git a/src/cala/assets.py b/src/cala/assets.py index 1e1f2663..5fbc1a5b 100644 --- a/src/cala/assets.py +++ b/src/cala/assets.py @@ -460,3 +460,29 @@ def from_array(cls, array: xr.DataArray, size: int) -> Self: buffer = cls(size=size) buffer.array = array return buffer + + +class Energy(Asset): + _entity: ClassVar[Entity] = PrivateAttr( + Entity( + name="energy", + dims=(Dims.width.value, Dims.height.value), + dtype=None, # np.number, # gets converted to float64 in xarray-validate + checks=[is_non_negative, has_no_nan], + ) + ) + + _mean: np.ndarray = PrivateAttr(None) + _sq_mean: np.ndarray = PrivateAttr(None) + + def update_std(self, arr: xr.DataArray) -> xr.DataArray: + """median: https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=614292""" + eta = 1 / (arr[AXIS.frame_coord].item() + 1) + + if self._mean is None: + self._mean = arr.values + self._sq_mean = np.square(arr.values) + else: + self._mean += eta * (arr.values - self._mean) + self._sq_mean += eta * (np.square(arr.values) - self._sq_mean) + return xr.DataArray(np.sqrt(self._sq_mean - np.square(self._mean)), dims=arr.dims) diff --git a/src/cala/nodes/omf/footprints.py b/src/cala/nodes/omf/footprints.py index b3d74b32..a9742ba1 100644 --- a/src/cala/nodes/omf/footprints.py +++ b/src/cala/nodes/omf/footprints.py @@ -14,9 +14,7 @@ class Footprinter(BaseModel): - tol: float - max_iter: int | None = None - bep: int | None = None + max_iter: int ratio_lb: float = 0.15 _logger = init_logger(__name__) @@ -27,6 +25,8 @@ def ingest_frame( ) -> A[Footprints, Name("footprints")]: """ Update spatial footprints using sufficient statistics. + We don't use tolerance since Frobenius norm is too expensive to + calculate with the entire footprints stack. Ã[p, i] = max(Ã[p, i] + (W[p, i] - Ã[p, :]M[:, i])/M[i, i], 0) diff --git a/src/cala/nodes/omf/traces.py b/src/cala/nodes/omf/traces.py index f8bcafcc..702e0efb 100644 --- a/src/cala/nodes/omf/traces.py +++ b/src/cala/nodes/omf/traces.py @@ -85,66 +85,6 @@ def ingest_frame( return PopSnap.from_array(updated_traces) - def _update_traces( - self, A: xr.DataArray, y: xr.DataArray, c: xr.DataArray, clusters: list[np.ndarray] - ) -> xr.DataArray: - """ - Implementation of the temporal traces update algorithm. - - This function uses block coordinate descent to update temporal traces - for overlapping components together while maintaining non-negativity constraints. - - Args: - A (xr.DataArray): Spatial footprints matrix [A, b]. - Shape: (components × pixels) - y (xr.DataArray): Current data frame. - Shape: (pixels,) - c (xr.DataArray): Last value of temporal traces. (just used for shape) - Shape: (components,) - clusters (list[np.ndarray]): list of groups that each contain component indices that - have overlapping footprints. - - Returns: - xr.DataArray: Updated temporal traces satisfying non-negativity constraints. - Shape: (components,) - """ - # Step 1: Compute projection of current frame - u = (A @ y).as_numpy() - - # Step 2: Compute gram matrix of spatial components - V = (A @ A.rename({AXIS.component_dim: f"{AXIS.component_dim}'"})).as_numpy() - - # Step 3: Extract diagonal elements for normalization - V_diag = np.diag(V) - - cnt = 0 - - # Steps 4-9: Main iteration loop until convergence - while True: - c_old = c.copy() - - # Steps 6-8: Update each group using block coordinate descent - for cluster in clusters: - # Update traces for current group (division is pointwise) - - numerator = u.isel({AXIS.component_dim: cluster}) - ( - V.isel({f"{AXIS.component_dim}'": cluster}) @ c - ).rename({f"{AXIS.component_dim}'": AXIS.component_dim}) - - c.loc[{AXIS.component_dim: cluster}] = np.maximum( - c.isel({AXIS.component_dim: cluster}) + numerator / V_diag[cluster].T, 0 - ) - - cnt += 1 - maxed = self.max_iter and (cnt == self.max_iter) - - if np.linalg.norm(c - c_old) >= self.tol * np.linalg.norm(c_old) or maxed: - if maxed: - self._logger.debug(msg="max_iter reached before converging.") - return xr.DataArray( - c.values, dims=c.dims, coords=c[AXIS.component_dim].coords - ).assign_coords(y[AXIS.frames_dim].coords) - def _update_traces( y: np.ndarray, diff --git a/tests/data/pipelines/ingest.yaml b/tests/data/pipelines/ingest.yaml new file mode 100644 index 00000000..25fb9dc7 --- /dev/null +++ b/tests/data/pipelines/ingest.yaml @@ -0,0 +1,71 @@ +noob_id: ingest-loop +noob_model: noob.tube.TubeSpecification +noob_version: 0.1.1.dev209+g1cafe61 +input: + frame: + scope: process + type: np.ndarray + traces: + scope: process + type: np.ndarray + footprints: + scope: process + type: np.ndarray + overlaps: + scope: process + type: np.ndarray + pix_stats: + scope: process + type: np.ndarray + comp_stats: + scope: process + type: np.ndarray + residuals: + scope: process + type: np.ndarray + + +nodes: + trace_frame: + type: cala.nodes.omf.traces.Tracer + params: + tol: 0.001 + max_iter: 100 + depends: + - traces: input.traces + - footprints: input.footprints + - frame: input.frame + - overlaps: input.overlaps + pix_frame: + type: cala.nodes.omf.pixel_stats.ingest_frame + depends: + - pixel_stats: input.pix_stats + - frame: input.frame + - new_traces: trace_frame.latest_trace + - footprints: input.footprints + comp_frame: + type: cala.nodes.omf.component_stats.ingest_frame + depends: + - component_stats: input.comp_stats + - frame: input.frame + - new_traces: trace_frame.latest_trace + footprints_frame: + type: cala.nodes.omf.footprints.Footprinter + params: + max_iter: 5 + ratio_lb: 0.10 + depends: + - footprints: input.footprints + - pixel_stats: pix_frame.value + - component_stats: comp_frame.value + residual: + type: cala.nodes.omf.residual.Residuer + depends: + - frame: input.frame + - footprints: footprints_frame.footprints + - traces: input.traces + - residuals: input.residuals + return: + type: return + depends: + residual.std \ No newline at end of file diff --git a/tests/data/pipelines/minian.yaml b/tests/data/pipelines/minian.yaml index a6204089..285f8f61 100644 --- a/tests/data/pipelines/minian.yaml +++ b/tests/data/pipelines/minian.yaml @@ -137,7 +137,6 @@ nodes: footprints_frame: type: cala.nodes.omf.footprints.Footprinter params: - bep: 0 tol: 0.0001 max_iter: 5 ratio_lb: 0.10 diff --git a/tests/data/pipelines/odl.yaml b/tests/data/pipelines/odl.yaml index eb5bed9e..fad8359c 100644 --- a/tests/data/pipelines/odl.yaml +++ b/tests/data/pipelines/odl.yaml @@ -64,58 +64,26 @@ nodes: - frame: glow.frame #PREPROCESS ENDS - # FRAME UPDATE BEGINS - trace_frame: - type: cala.nodes.omf.traces.Tracer + ingest: + type: tube params: - tol: 0.001 - max_iter: 100 + tube: ingest-loop depends: + - frame: glow.frame - traces: assets.traces - footprints: assets.footprints - - frame: glow.frame - overlaps: assets.overlaps - pix_frame: - type: cala.nodes.omf.pixel_stats.ingest_frame - depends: - - pixel_stats: assets.pix_stats - - frame: glow.frame - - new_traces: trace_frame.latest_trace - - footprints: assets.footprints - comp_frame: - type: cala.nodes.omf.component_stats.ingest_frame - depends: - - component_stats: assets.comp_stats - - frame: glow.frame - - new_traces: trace_frame.latest_trace - footprints_frame: - type: cala.nodes.omf.footprints.Footprinter - params: - bep: 0 - tol: 0.0001 - max_iter: 5 - ratio_lb: 0.10 - depends: - - footprints: assets.footprints - - pixel_stats: pix_frame.value - - component_stats: comp_frame.value - - residual: - type: cala.nodes.omf.residual.Residuer - depends: - - frame: glow.frame - - footprints: footprints_frame.footprints - - traces: assets.traces + - pix_stats: assets.pix_stats + - comp_stats: assets.comp_stats - residuals: assets.residuals - # FRAME UPDATE ENDS segment: type: tube params: tube: segment-loop depends: - - residual: residual.movie - - energy: residual.std + - residual: assets.residuals + - energy: ingest.value - radius: radius.radius - footprints: assets.footprints - traces: assets.traces diff --git a/tests/test_iter/__init__.py b/tests/test_omf/__init__.py similarity index 100% rename from tests/test_iter/__init__.py rename to tests/test_omf/__init__.py diff --git a/tests/test_iter/test_component_stats.py b/tests/test_omf/test_component_stats.py similarity index 100% rename from tests/test_iter/test_component_stats.py rename to tests/test_omf/test_component_stats.py diff --git a/tests/test_iter/test_footprints.py b/tests/test_omf/test_footprints.py similarity index 98% rename from tests/test_iter/test_footprints.py rename to tests/test_omf/test_footprints.py index 6439563b..fd1433e6 100644 --- a/tests/test_iter/test_footprints.py +++ b/tests/test_omf/test_footprints.py @@ -61,7 +61,7 @@ def fpter() -> Node: NodeSpecification( id="test_footprinter", type="cala.nodes.omf.footprints.Footprinter", - params={"bep": 0, "tol": 1e-7}, + params={"max_iter": 5, "ratio_lb": 0.10}, ) ) diff --git a/tests/test_iter/test_overlaps.py b/tests/test_omf/test_overlaps.py similarity index 100% rename from tests/test_iter/test_overlaps.py rename to tests/test_omf/test_overlaps.py diff --git a/tests/test_iter/test_pixel_stats.py b/tests/test_omf/test_pixel_stats.py similarity index 100% rename from tests/test_iter/test_pixel_stats.py rename to tests/test_omf/test_pixel_stats.py diff --git a/tests/test_iter/test_residual.py b/tests/test_omf/test_residual.py similarity index 100% rename from tests/test_iter/test_residual.py rename to tests/test_omf/test_residual.py diff --git a/tests/test_iter/test_traces.py b/tests/test_omf/test_traces.py similarity index 100% rename from tests/test_iter/test_traces.py rename to tests/test_omf/test_traces.py diff --git a/tests/test_segment/__init__.py b/tests/test_segment/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_iter/test_catalog.py b/tests/test_segment/test_catalog.py similarity index 100% rename from tests/test_iter/test_catalog.py rename to tests/test_segment/test_catalog.py diff --git a/tests/test_iter/test_cleanup.py b/tests/test_segment/test_cleanup.py similarity index 100% rename from tests/test_iter/test_cleanup.py rename to tests/test_segment/test_cleanup.py diff --git a/tests/test_iter/test_merge.py b/tests/test_segment/test_merge.py similarity index 100% rename from tests/test_iter/test_merge.py rename to tests/test_segment/test_merge.py diff --git a/tests/test_iter/test_slice_nmf.py b/tests/test_segment/test_slice_nmf.py similarity index 100% rename from tests/test_iter/test_slice_nmf.py rename to tests/test_segment/test_slice_nmf.py From c1b237391cd53230ad0d5107aa7382ab8fe46e55 Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 1 Dec 2025 16:17:17 -0800 Subject: [PATCH 28/33] feat: scripts --- scripts/__init__.py | 0 scripts/benchmarking.py | 204 +++++++++++++++++++++++++++++++++++++++ scripts/memory_usage.py | 35 +++++++ scripts/yappi_profile.py | 13 +++ 4 files changed, 252 insertions(+) create mode 100644 scripts/__init__.py create mode 100644 scripts/benchmarking.py create mode 100644 scripts/memory_usage.py create mode 100644 scripts/yappi_profile.py diff --git a/scripts/__init__.py b/scripts/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/scripts/benchmarking.py b/scripts/benchmarking.py new file mode 100644 index 00000000..b3a3e9f3 --- /dev/null +++ b/scripts/benchmarking.py @@ -0,0 +1,204 @@ +from datetime import datetime + +import cv2 +import matplotlib +import numpy as np +import seaborn as sns +from matplotlib import pyplot as plt +from noob import SynchronousRunner, Tube + +from cala.nodes.io import stream +from cala.nodes.prep import Anchor, blur, butter, package_frame, remove_mean +from cala.testing.util import total_gradient_magnitude + +sns.set_style("whitegrid") +font = {"family": "normal", "weight": "regular", "size": 15} + +matplotlib.rc("font", **font) + + +VIDEOS = [ + "minian/msCam1.avi", + "minian/msCam2.avi", + "minian/msCam3.avi", + "minian/msCam4.avi", + "minian/msCam5.avi", + "minian/msCam6.avi", + "minian/msCam7.avi", + "minian/msCam8.avi", + "minian/msCam9.avi", + "minian/msCam10.avi", + # "long_recording/0.avi", + # "long_recording/1.avi", + # "long_recording/2.avi", + # "long_recording/3.avi", + # "long_recording/4.avi", +] + + +def preprocess(arr, idx): + frame = package_frame(arr, idx) + frame = blur(frame, method="median", kwargs={"ksize": 3}) + frame = butter(frame, {}) + return remove_mean(frame, orient="both") + + +def test_encode(): + gen = stream(VIDEOS) + + fourcc = cv2.VideoWriter_fourcc(*"FFV1") + out = cv2.VideoWriter("encode_test.avi", fourcc, 60.0, (600, 600)) + + for arr in gen: + frame_bgr = cv2.cvtColor(arr.astype(np.uint8), cv2.COLOR_GRAY2BGR) + out.write(frame_bgr) + + out.release() + + +def test_motion_movie(): + """ + For testing how well the motion correction performs with real movie + + """ + gen = stream(VIDEOS) + + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + out = cv2.VideoWriter("motion_test.avi", fourcc, 60.0, (600, 1200)) + + stab = Anchor() + + for idx, arr in enumerate(gen): + frame = preprocess(arr, idx) + matched = stab.stabilize(frame) + combined = np.concat([frame.array.values, matched.array.values], axis=0) + + frame_bgr = cv2.cvtColor(combined.astype(np.uint8), cv2.COLOR_GRAY2BGR) + out.write(frame_bgr) + + out.release() + + +def test_motion_crisp_pics(): + """ + For generating a mean summary frame across raw vs. motion-corrected video. + The motion-corrected video should have a much crisper summary picture. + + """ + + gen = stream(VIDEOS) + + stab = Anchor() + raws = [] + stabs = [] + + for idx, arr in enumerate(gen): + frame = preprocess(arr, idx) + matched = stab.stabilize(frame) + raws.append(frame.array.values) + stabs.append(matched.array.values) + + raw = np.stack(raws) + stab = np.stack(stabs) + + raw_mean = np.mean(raw, axis=0) + stab_mean = np.mean(stab, axis=0) + + crisp_raw = total_gradient_magnitude(raw_mean) + crisp_stab = total_gradient_magnitude(stab_mean) + + print(f"{crisp_raw = }, {crisp_stab = }") + + mean = np.concatenate((raw_mean, stab_mean), axis=0) + plt.imsave("motion_crisp_pics.png", mean, cmap="gray") + + +def test_motion_mean_corr(): + """ + For testing how well the motion correction performs with real movie + """ + gen = stream(VIDEOS) + + stab = Anchor() + raws = [] + stabs = [] + + for idx, arr in enumerate(gen): + frame = preprocess(arr, idx) + matched = stab.stabilize(frame) + raws.append(frame.array.values) + stabs.append(matched.array.values) + + raw = np.stack(raws) + stab = np.stack(stabs) + + raw_mean = np.mean(raw[:, 20:-20, 20:-20], axis=0) + stab_mean = np.mean(stab[:, 20:-20, 20:-20], axis=0) + + raw_cms = [np.corrcoef(r[20:-20, 20:-20].flatten(), raw_mean.flatten())[0, 1] for r in raws] + stab_cms = [np.corrcoef(s[20:-20, 20:-20].flatten(), stab_mean.flatten())[0, 1] for s in stabs] + + fig, ax = plt.subplots(figsize=(24, 10)) + plt.plot(raw_cms) + plt.plot(stab_cms) + plt.legend(["raw", "stabilized"], loc="upper right") + plt.title("Mean Correlation") + plt.xlabel("frame") + plt.ylabel("correlation") + plt.tight_layout() + plt.savefig("mc.png") + + assert False + + +def test_with_movie(): + tube = Tube.from_specification("with-minian") + runner = SynchronousRunner(tube=tube) + processed_vid = runner.run() + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + out = cv2.VideoWriter("motion_test.avi", fourcc, 20.0, (100, 100)) + + for arr in processed_vid: + frame_bgr = cv2.cvtColor(arr.array.values.astype(np.uint8), cv2.COLOR_GRAY2BGR) + out.write(frame_bgr) + out.release() + + +def test_processing_speed(): + tube = Tube.from_specification("with-minian") + runner = SynchronousRunner(tube=tube) + gen = runner.iter() + frame_speed = [] + i = 0 + while True: + try: + start = datetime.now() + next(gen) + duration = datetime.now() - start + frame_speed.append(round(duration.total_seconds(), 2)) + i += 1 + except RuntimeError: + break + fig, ax = plt.subplots(figsize=(20, 4)) + ax.set_yscale("log") + plt.plot(frame_speed) + plt.xlabel("frame", fontsize=20) + plt.ylabel("time taken (s)", fontsize=20) + plt.tight_layout() + plt.savefig("frame_speed.png") + + +def test_deglow(): + + gen = stream(VIDEOS[:3]) + + stab = Anchor() + stabs = [] + + for idx, arr in enumerate(gen): + frame = preprocess(arr, idx) + matched = stab.stabilize(frame) + stabs.append(matched.array.values) + + deglowed = stabs[-1] - np.min(stabs, axis=0) + plt.imsave("deglowed.png", deglowed, cmap="gray") diff --git a/scripts/memory_usage.py b/scripts/memory_usage.py new file mode 100644 index 00000000..f2048219 --- /dev/null +++ b/scripts/memory_usage.py @@ -0,0 +1,35 @@ +import os + +import psutil +from matplotlib import pyplot as plt +from noob import Tube, SynchronousRunner + + +def main(): + process = psutil.Process(os.getpid()) + tube = Tube.from_specification("test-memory") + runner = SynchronousRunner(tube=tube) + gen = runner.iter() + ram_use_frame = [] + i = 0 + while True: + try: + next(gen) + ram_used = process.memory_info().rss / (1024 * 1024) # in MB + ram_use_frame.append(round(ram_used, 2)) + i += 1 + if i % 100 == 0: + print(f"{i} frames processed") + except RuntimeError as e: + print(e) + break + fig, ax = plt.subplots(figsize=(40, 8)) + plt.plot(ram_use_frame) + plt.xlabel("frame") + plt.ylabel("memory used (MB)") + plt.tight_layout() + plt.savefig("ram_use.svg", format="svg") + + +if __name__ == "__main__": + main() diff --git a/scripts/yappi_profile.py b/scripts/yappi_profile.py new file mode 100644 index 00000000..f7b44da3 --- /dev/null +++ b/scripts/yappi_profile.py @@ -0,0 +1,13 @@ +import yappi + +from cala.main import main + +try: + yappi.set_clock_type("WALL") + yappi.start() + main(gui=True, spec="cala-odl") + yappi.stop() +finally: + stat = yappi.get_func_stats() + ps = yappi.convert2pstats(stat) + ps.dump_stats("prof/yappi3.prof") From 0a7a4da7c5ea1db54b3b67c89a04900cca8559ad Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 1 Dec 2025 21:50:45 -0800 Subject: [PATCH 29/33] mypy --- pyproject.toml | 36 ++++++++++++++++-------------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ebd58be2..bb68b05f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -139,6 +139,7 @@ test = "pytest" docs = "sphinx-autobuild docs docs/_build/html" start.cmd = "npm run dev" start.working_dir = "frontend" +mypy = "mypy" [tool.pytest.ini_options] @@ -147,10 +148,10 @@ pythonpath = ["src"] addopts = [ "-ra", "-q", -# "--cov=cala", -# "--cov-append", -# "--cov-report=term-missing", -# "--cov-report=html" + # "--cov=cala", + # "--cov-append", + # "--cov-report=term-missing", + # "--cov-report=html" ] markers = [ "timeout: marks tests that need a timeout failure to prevent falling into an infinite loop" @@ -235,22 +236,6 @@ ignore = [ "F401" ] -[tool.mypy] -python_version = "3.12" -warn_return_any = true -warn_unused_configs = true -disallow_untyped_defs = true -disallow_incomplete_defs = true -check_untyped_defs = true -disallow_untyped_decorators = false -no_implicit_optional = true -warn_redundant_casts = true -warn_unused_ignores = true -warn_no_return = true -warn_unreachable = true -strict_optional = true -exclude = ["tests/.*"] - [tool.coverage.run] source = ["src/cala"] omit = [ @@ -277,3 +262,14 @@ omit = [ "**/__init__.py", "**/conftest.py", ] +[tool.mypy] +mypy_path = "$MYPY_CONFIG_FILE_DIR/src" +packages = ["cala"] +warn_redundant_casts = true +warn_unused_ignores = true +show_error_context = true +show_column_numbers = true +show_error_code_links = true +pretty = true +color_output = true +plugins = ['pydantic.mypy'] \ No newline at end of file From c8168b3513b002dd2d32e26aa0f063a03f0ae897 Mon Sep 17 00:00:00 2001 From: Raymond Date: Wed, 3 Dec 2025 11:13:26 -0800 Subject: [PATCH 30/33] chore: intermediate changes to assets --- src/cala/__init__.py | 2 +- src/cala/assets/__init__.py | 4 + src/cala/{ => assets}/assets.py | 67 ++++++++-------- src/cala/assets/axis.py | 66 ++++++++++++++++ src/cala/assets/mixins.py | 0 .../{models/entity.py => assets/validate.py} | 54 ++++++++++++- .../{models/access.py => assets/xr_access.py} | 0 src/cala/gui/components/counter.py | 4 +- src/cala/gui/components/encoder.py | 4 +- src/cala/gui/components/stamper.py | 4 +- src/cala/models/__init__.py | 3 - src/cala/models/axis.py | 72 ----------------- src/cala/models/checks.py | 23 ------ src/cala/nodes/io.py | 2 +- src/cala/nodes/omf/component_stats.py | 4 +- src/cala/nodes/omf/footprints.py | 4 +- src/cala/nodes/omf/overlap.py | 4 +- src/cala/nodes/omf/pixel_stats.py | 8 +- src/cala/nodes/omf/residual.py | 12 +-- src/cala/nodes/omf/traces.py | 22 +++--- src/cala/nodes/prep/background_removal.py | 2 +- src/cala/nodes/prep/denoise.py | 2 +- src/cala/nodes/prep/downsample.py | 4 +- src/cala/nodes/prep/flatten.py | 2 +- src/cala/nodes/prep/glow_removal.py | 2 +- src/cala/nodes/prep/lines.py | 4 +- src/cala/nodes/prep/motion.py | 4 +- src/cala/nodes/prep/r_estimate.py | 4 +- src/cala/nodes/prep/wrap.py | 4 +- src/cala/nodes/segment/catalog.py | 18 ++--- src/cala/nodes/segment/cleanup.py | 6 +- src/cala/nodes/segment/persist.py | 2 +- src/cala/nodes/segment/slice_nmf.py | 16 ++-- src/cala/testing/catalog_depr.py | 4 +- src/cala/testing/nodes.py | 2 +- src/cala/testing/toy.py | 18 ++--- src/cala/testing/util.py | 2 +- src/cala/util.py | 2 +- tests/test_assets.py | 78 +++++++++---------- tests/test_config.py | 2 +- tests/test_gui.py | 4 +- tests/test_omf/test_component_stats.py | 10 +-- tests/test_omf/test_footprints.py | 4 +- tests/test_omf/test_overlaps.py | 4 +- tests/test_omf/test_pixel_stats.py | 12 +-- tests/test_omf/test_residual.py | 12 +-- tests/test_omf/test_traces.py | 14 ++-- tests/test_pipeline.py | 4 +- tests/test_prep/test_denoise.py | 2 +- tests/test_prep/test_glow_removal.py | 4 +- tests/test_prep/test_motion.py | 2 +- tests/test_prep/test_r_estimate.py | 6 +- tests/test_segment/test_catalog.py | 4 +- tests/test_segment/test_cleanup.py | 4 +- tests/test_segment/test_merge.py | 2 +- tests/test_segment/test_slice_nmf.py | 8 +- 56 files changed, 325 insertions(+), 308 deletions(-) create mode 100644 src/cala/assets/__init__.py rename src/cala/{ => assets}/assets.py (86%) create mode 100644 src/cala/assets/axis.py create mode 100644 src/cala/assets/mixins.py rename src/cala/{models/entity.py => assets/validate.py} (57%) rename src/cala/{models/access.py => assets/xr_access.py} (100%) delete mode 100644 src/cala/models/__init__.py delete mode 100644 src/cala/models/axis.py delete mode 100644 src/cala/models/checks.py diff --git a/src/cala/__init__.py b/src/cala/__init__.py index 92094731..c61ea9f4 100644 --- a/src/cala/__init__.py +++ b/src/cala/__init__.py @@ -1 +1 @@ -from cala.models import access as access +from cala.assets import xr_access as access diff --git a/src/cala/assets/__init__.py b/src/cala/assets/__init__.py new file mode 100644 index 00000000..eec23d69 --- /dev/null +++ b/src/cala/assets/__init__.py @@ -0,0 +1,4 @@ +from .assets import Traces, Footprints, PixStats, CompStats, Overlaps, Buffer +from .axis import AXIS # noqa: I001 + +__all__ = [AXIS, "Traces", "Footprints", "PixStats", "CompStats", "Overlaps", "Buffer"] diff --git a/src/cala/assets.py b/src/cala/assets/assets.py similarity index 86% rename from src/cala/assets.py rename to src/cala/assets/assets.py index 5fbc1a5b..0ccce115 100644 --- a/src/cala/assets.py +++ b/src/cala/assets/assets.py @@ -9,10 +9,9 @@ from pydantic import BaseModel, ConfigDict, PrivateAttr, field_validator, model_validator from sparse import COO +from cala.assets.axis import AXIS +from cala.assets.validate import Coords, Dims, Entity, Group, has_no_nan, is_non_negative from cala.config import config -from cala.models.axis import AXIS, Coords, Dims -from cala.models.checks import has_no_nan, is_non_negative -from cala.models.entity import Entity, Group from cala.util import clear_dir AssetType = TypeVar("AssetType", xr.DataArray, Path, None) @@ -187,7 +186,7 @@ def sizes(self) -> dict[str, int]: if self.zarr_path: total_size = {} for key, val in self.array_.sizes.items(): - if key == AXIS.frames_dim: + if key == AXIS.frame_dim: total_size[key] = val + self.load_zarr().sizes[key] else: total_size[key] = val @@ -198,7 +197,7 @@ def sizes(self) -> dict[str, int]: @property def array(self) -> xr.DataArray: return ( - self.array_.isel({AXIS.frames_dim: slice(-self.peek_size, None)}) + self.array_.isel({AXIS.frame_dim: slice(-self.peek_size, None)}) if self.array_ is not None else self.array_ ) @@ -213,8 +212,8 @@ def array(self, array: xr.DataArray) -> None: if self.validate_schema: array.validate.against_schema(self._entity.model) if self.zarr_path: - self.array_ = array.isel({AXIS.frames_dim: slice(-self.peek_size, None)}) - array.isel({AXIS.frames_dim: slice(None, -self.peek_size)}).to_zarr( + self.array_ = array.isel({AXIS.frame_dim: slice(-self.peek_size, None)}) + array.isel({AXIS.frame_dim: slice(None, -self.peek_size)}).to_zarr( self.zarr_path, mode="w" ) else: @@ -230,20 +229,20 @@ def append(self, array: xr.DataArray, dim: str) -> None: """ - if dim == AXIS.frames_dim: - self.array_ = xr.concat([self.array_, array], dim=AXIS.frames_dim) + if dim == AXIS.frame_dim: + self.array_ = xr.concat([self.array_, array], dim=AXIS.frame_dim) - if self.zarr_path and self.array_.sizes[AXIS.frames_dim] > self.flush_interval: + if self.zarr_path and self.array_.sizes[AXIS.frame_dim] > self.flush_interval: self._flush_zarr() elif dim == AXIS.component_dim: if self.zarr_path: - n_in_memory = self.array_.sizes[AXIS.frames_dim] + n_in_memory = self.array_.sizes[AXIS.frame_dim] self.array_ = xr.concat( - [self.array_, array.isel({AXIS.frames_dim: slice(-n_in_memory, None)})], + [self.array_, array.isel({AXIS.frame_dim: slice(-n_in_memory, None)})], dim=dim, ) - array.isel({AXIS.frames_dim: slice(None, -n_in_memory)}).to_zarr( + array.isel({AXIS.frame_dim: slice(None, -n_in_memory)}).to_zarr( self.zarr_path, append_dim=dim ) else: @@ -259,25 +258,25 @@ def _flush_zarr(self) -> None: Could do this much more elegantly by pre-allocating .array_ """ raw_zarr = self.load_zarr() - to_flush = self.array_.isel({AXIS.frames_dim: slice(None, -self.peek_size)}) + to_flush = self.array_.isel({AXIS.frame_dim: slice(None, -self.peek_size)}) if self._deprecated: zarr_ids = raw_zarr[AXIS.id_coord].values zarr_detects = raw_zarr[AXIS.detect_coord].values intact_mask = ~np.isin(zarr_ids, self._deprecated) - n_flush = to_flush.sizes[AXIS.frames_dim] + n_flush = to_flush.sizes[AXIS.frame_dim] prealloc = xr.DataArray( np.full((raw_zarr.sizes[AXIS.component_dim], n_flush), np.nan), - dims=[AXIS.component_dim, AXIS.frames_dim], + dims=[AXIS.component_dim, AXIS.frame_dim], coords={ AXIS.id_coord: (AXIS.component_dim, zarr_ids), AXIS.detect_coord: (AXIS.component_dim, zarr_detects), }, - ).assign_coords(to_flush[AXIS.frames_dim].coords) + ).assign_coords(to_flush[AXIS.frame_dim].coords) prealloc.loc[intact_mask] = to_flush - prealloc.to_zarr(self.zarr_path, append_dim=AXIS.frames_dim) + prealloc.to_zarr(self.zarr_path, append_dim=AXIS.frame_dim) else: - to_flush.to_zarr(self.zarr_path, append_dim=AXIS.frames_dim) - self.array_ = self.array_.isel({AXIS.frames_dim: slice(-self.peek_size, None)}) + to_flush.to_zarr(self.zarr_path, append_dim=AXIS.frame_dim) + self.array_ = self.array_.isel({AXIS.frame_dim: slice(-self.peek_size, None)}) def keep(self, intact_mask: np.ndarray) -> None: if self.zarr_path: @@ -291,7 +290,7 @@ def from_array(cls, array: xr.DataArray) -> "Traces": so we don't really have to worry about specifying the parameters. """ - new_cls = cls(peek_size=array.sizes[AXIS.frames_dim]) + new_cls = cls(peek_size=array.sizes[AXIS.frame_dim]) new_cls.array = array return new_cls @@ -301,7 +300,7 @@ def full_array(self, isel_filter: dict = None, sel_filter: dict = None) -> xr.Da zarr_ids = raw_zarr[AXIS.id_coord].values intact_mask = ~np.isin(zarr_ids, self._deprecated) - return xr.concat([raw_zarr[intact_mask], self.array_], dim=AXIS.frames_dim).compute() + return xr.concat([raw_zarr[intact_mask], self.array_], dim=AXIS.frame_dim).compute() else: return self.array_.isel(isel_filter).sel(sel_filter) @@ -380,7 +379,7 @@ class Overlaps(Asset): class Buffer(Asset): """ - Implements a fake ring buffer to avoid expensive copying that occurs with + Implements a bip buffer to avoid expensive copying that occurs with numpy concat, append, and stack. Works by preallocating a space twice the desired size. @@ -420,9 +419,9 @@ def array(self) -> xr.DataArray | None: if self.array_ is None: return None if self._full: - out = self.array_.isel({AXIS.frames_dim: slice(self._next, self._next + self.size)}) + out = self.array_.isel({AXIS.frame_dim: slice(self._next, self._next + self.size)}) else: - out = self.array_.isel({AXIS.frames_dim: slice(None, self._next)}) + out = self.array_.isel({AXIS.frame_dim: slice(None, self._next)}) # kinda expensive. maybe float is fine? return out # .assign_coords({AXIS.frame_coord: out[AXIS.frame_coord].astype(int)}) @@ -433,26 +432,26 @@ def array(self, array: xr.DataArray) -> None: """ array = ( array.volumize.dim_with_coords( - dim=AXIS.frames_dim, coords=[AXIS.frame_coord, AXIS.timestamp_coord] + dim=AXIS.frame_dim, coords=[AXIS.frame_coord, AXIS.timestamp_coord] ) - if AXIS.frames_dim not in array.dims - else array.isel({AXIS.frames_dim: slice(-self.size, None)}) + if AXIS.frame_dim not in array.dims + else array.isel({AXIS.frame_dim: slice(-self.size, None)}) ) fill_sizes = dict(array.sizes) - fill_sizes[AXIS.frames_dim] = self.size - array.sizes[AXIS.frames_dim] + fill_sizes[AXIS.frame_dim] = self.size - array.sizes[AXIS.frame_dim] fill = np.zeros(list(fill_sizes.values())) filler = xr.DataArray( fill, dims=array.dims, coords={ - AXIS.frame_coord: (AXIS.frames_dim, [np.nan] * (fill_sizes[AXIS.frames_dim])), - AXIS.timestamp_coord: (AXIS.frames_dim, [""] * (fill_sizes[AXIS.frames_dim])), + AXIS.frame_coord: (AXIS.frame_dim, [np.nan] * (fill_sizes[AXIS.frame_dim])), + AXIS.timestamp_coord: (AXIS.frame_dim, [""] * (fill_sizes[AXIS.frame_dim])), }, ) - buffer = xr.concat([array, filler] * 2, dim=AXIS.frames_dim) + buffer = xr.concat([array, filler] * 2, dim=AXIS.frame_dim) - self._full = array.sizes[AXIS.frames_dim] >= self.size - self._next = np.min((array.sizes[AXIS.frames_dim], self.size)) % self.size + self._full = array.sizes[AXIS.frame_dim] >= self.size + self._next = np.min((array.sizes[AXIS.frame_dim], self.size)) % self.size self.array_ = buffer @classmethod diff --git a/src/cala/assets/axis.py b/src/cala/assets/axis.py new file mode 100644 index 00000000..83891614 --- /dev/null +++ b/src/cala/assets/axis.py @@ -0,0 +1,66 @@ +from enum import StrEnum + + +class classproperty: + + def __init__(self, func): + self._func = func + + def __get__(self, obj, owner): + return self._func(owner) + + +class Dim(StrEnum): + frame = "frame" + height = "height" + width = "width" + component = "component" + """Name of the dimension representing individual components.""" + + @classproperty + def spatial(cls) -> tuple["Dim", "Dim"]: + return cls.height, cls.width + + +class Coord(StrEnum): + id = "id" + timestamp = "timestamp" + detect = "detected_on" + frame = "frame_idx" + width = "width" + height = "height" + + +class Axis: + """Mixin providing common axis-related attributes.""" + + frame_dim: str = "frame" + height_dim: str = "height" + width_dim: str = "width" + component_dim: str = "component" + + id_coord: str = "id_" + timestamp_coord: str = "timestamp" + detect_coord: str = "detected_on" + frame_coord: str = "frame_idx" + width_coord: str = "width" + height_coord: str = "height" + + @property + def spatial_dims(self) -> tuple[str, str]: + """Names of the dimensions representing 2-d spatial coordinates Default: (height, width).""" + return self.height_dim, self.width_dim + + @property + def component_rename(self) -> dict[str, str]: + return { + axis: self.duplicate(axis) + for axis in (self.component_dim, self.id_coord, self.detect_coord) + } + + @staticmethod + def duplicate(axis: str) -> str: + return f"{axis}'" + + +AXIS = Axis() diff --git a/src/cala/assets/mixins.py b/src/cala/assets/mixins.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cala/models/entity.py b/src/cala/assets/validate.py similarity index 57% rename from src/cala/models/entity.py rename to src/cala/assets/validate.py index c21cfc21..ea6c2846 100644 --- a/src/cala/models/entity.py +++ b/src/cala/assets/validate.py @@ -1,11 +1,63 @@ from collections.abc import Callable from copy import deepcopy +from enum import Enum from typing import Any +import numpy as np +import xarray as xr from pydantic import BaseModel, Field, PrivateAttr from xarray_validate import CoordsSchema, DataArraySchema, DimsSchema, DTypeSchema -from cala.models.axis import Coord, Dim, Dims +from cala.assets import AXIS + + +def is_non_negative(da: xr.DataArray) -> None: + if da.min() < 0: + raise ValueError("Array is not non-negative") + + +def is_unique(da: xr.DataArray) -> None: + elem, counts = np.unique(da, return_counts=True) + if counts.max() > 1: + raise ValueError(f"The values in DataArray are not unique : {elem[counts > 1]}") + + +def is_unit_interval(da: xr.DataArray) -> None: + if da.min() < 0 or da.max() > 1: + raise ValueError("The values in DataArray are not unit interval.") + + +def has_no_nan(da: xr.DataArray) -> None: + if np.isnan(da).any(): + raise ValueError("The DataArray has nan values.") + + +class Coord(BaseModel): + name: str + dtype: type + dim: str | None = None + checks: list[Callable] = Field(default_factory=list) + + +class Dim(BaseModel): + name: str + coords: list[Coord] = Field(default_factory=list) + + +class Coords(Enum): + id = Coord(name=AXIS.id_coord, dtype=str, checks=[is_unique]) + height = Coord(name=AXIS.height_coord, dtype=int, checks=[is_unique]) + width = Coord(name=AXIS.width_coord, dtype=int, checks=[is_unique]) + frame = Coord(name=AXIS.frame_coord, dtype=int, checks=[is_unique]) + timestamp = Coord(name=AXIS.timestamp_coord, dtype=str, checks=[is_unique]) + detected = Coord(name=AXIS.detect_coord, dtype=int, checks=[has_no_nan]) + + +class Dims(Enum): + width = Dim(name=AXIS.width_dim, coords=[Coords.width.value]) + height = Dim(name=AXIS.height_dim, coords=[Coords.height.value]) + frame = Dim(name=AXIS.frame_dim, coords=[Coords.frame.value, Coords.timestamp.value]) + component = Dim(name=AXIS.component_dim, coords=[Coords.id.value, Coords.detected.value]) class Entity(BaseModel): diff --git a/src/cala/models/access.py b/src/cala/assets/xr_access.py similarity index 100% rename from src/cala/models/access.py rename to src/cala/assets/xr_access.py diff --git a/src/cala/gui/components/counter.py b/src/cala/gui/components/counter.py index e8b7135b..2e22b810 100644 --- a/src/cala/gui/components/counter.py +++ b/src/cala/gui/components/counter.py @@ -1,5 +1,5 @@ -from cala.assets import Traces -from cala.models import AXIS +from cala.assets import AXIS +from cala.assets.assets import Traces def component_counter(index: int, traces: Traces) -> dict[str, int]: diff --git a/src/cala/gui/components/encoder.py b/src/cala/gui/components/encoder.py index 1b1e8539..97b50595 100644 --- a/src/cala/gui/components/encoder.py +++ b/src/cala/gui/components/encoder.py @@ -5,9 +5,9 @@ from av.video import VideoStream from noob.node import Node -from cala.assets import Frame +from cala.assets import AXIS +from cala.assets.assets import Frame from cala.config import config -from cala.models import AXIS from cala.util import clear_dir diff --git a/src/cala/gui/components/stamper.py b/src/cala/gui/components/stamper.py index d4179137..c92b5505 100644 --- a/src/cala/gui/components/stamper.py +++ b/src/cala/gui/components/stamper.py @@ -5,9 +5,9 @@ import xarray as xr from pydantic import BaseModel, ConfigDict -from cala.assets import Footprints, PopSnap +from cala.assets import AXIS +from cala.assets.assets import Footprints, PopSnap from cala.gui.components import Encoder -from cala.models import AXIS COLOR_MAP = { "red": (0, 0, 1), diff --git a/src/cala/models/__init__.py b/src/cala/models/__init__.py deleted file mode 100644 index 06c35bb2..00000000 --- a/src/cala/models/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .axis import AXIS - -__all__ = [AXIS] diff --git a/src/cala/models/axis.py b/src/cala/models/axis.py deleted file mode 100644 index 6430b9da..00000000 --- a/src/cala/models/axis.py +++ /dev/null @@ -1,72 +0,0 @@ -from collections.abc import Callable -from enum import Enum - -from pydantic import BaseModel, Field - -from cala.models.checks import has_no_nan, is_unique - - -class Axis: - """Mixin providing common axis-related attributes.""" - - frames_dim: str = "frame" - height_dim: str = "height" - width_dim: str = "width" - component_dim: str = "component" - """Name of the dimension representing individual components.""" - - id_coord: str = "id_" - timestamp_coord: str = "timestamp" - detect_coord: str = "detected_on" - frame_coord: str = "frame_idx" - width_coord: str = "width" - height_coord: str = "height" - - @property - def spatial_dims(self) -> tuple[str, str]: - """Names of the dimensions representing 2-d spatial coordinates Default: (height, width).""" - return self.height_dim, self.width_dim - - @property - def spatial_coords(self) -> tuple[str, str]: - """Names of the dimensions representing 2-d spatial coordinates Default: (height, width).""" - return self.height_coord, self.width_coord - - @property - def component_rename(self) -> dict[str, str]: - return { - AXIS.component_dim: f"{AXIS.component_dim}'", - AXIS.id_coord: f"{AXIS.id_coord}'", - AXIS.detect_coord: f"{AXIS.detect_coord}'", - } - - -AXIS = Axis() - - -class Coord(BaseModel): - name: str - dtype: type - dim: str | None = None - checks: list[Callable] = Field(default_factory=list) - - -class Dim(BaseModel): - name: str - coords: list[Coord] = Field(default_factory=list) - - -class Coords(Enum): - id = Coord(name=AXIS.id_coord, dtype=str, checks=[is_unique]) - height = Coord(name=AXIS.height_coord, dtype=int, checks=[is_unique]) - width = Coord(name=AXIS.width_coord, dtype=int, checks=[is_unique]) - frame = Coord(name=AXIS.frame_coord, dtype=int, checks=[is_unique]) - timestamp = Coord(name=AXIS.timestamp_coord, dtype=str, checks=[is_unique]) - detected = Coord(name=AXIS.detect_coord, dtype=int, checks=[has_no_nan]) - - -class Dims(Enum): - width = Dim(name=AXIS.width_dim, coords=[Coords.width.value]) - height = Dim(name=AXIS.height_dim, coords=[Coords.height.value]) - frame = Dim(name=AXIS.frames_dim, coords=[Coords.frame.value, Coords.timestamp.value]) - component = Dim(name=AXIS.component_dim, coords=[Coords.id.value, Coords.detected.value]) diff --git a/src/cala/models/checks.py b/src/cala/models/checks.py deleted file mode 100644 index a4ace313..00000000 --- a/src/cala/models/checks.py +++ /dev/null @@ -1,23 +0,0 @@ -import numpy as np -import xarray as xr - - -def is_non_negative(da: xr.DataArray) -> None: - if da.min() < 0: - raise ValueError("Array is not non-negative") - - -def is_unique(da: xr.DataArray) -> None: - elem, counts = np.unique(da, return_counts=True) - if counts.max() > 1: - raise ValueError(f"The values in DataArray are not unique : {elem[counts > 1]}") - - -def is_unit_interval(da: xr.DataArray) -> None: - if da.min() < 0 or da.max() > 1: - raise ValueError("The values in DataArray are not unit interval.") - - -def has_no_nan(da: xr.DataArray) -> None: - if np.isnan(da).any(): - raise ValueError("The DataArray has nan values.") diff --git a/src/cala/nodes/io.py b/src/cala/nodes/io.py index 7b35cbd3..d58b2c01 100644 --- a/src/cala/nodes/io.py +++ b/src/cala/nodes/io.py @@ -9,7 +9,7 @@ from numpy.typing import NDArray from skimage import io -from cala.assets import Asset +from cala.assets.assets import Asset from cala.config import config diff --git a/src/cala/nodes/omf/component_stats.py b/src/cala/nodes/omf/component_stats.py index 2d03117b..0c418002 100644 --- a/src/cala/nodes/omf/component_stats.py +++ b/src/cala/nodes/omf/component_stats.py @@ -1,8 +1,8 @@ import numpy as np import xarray as xr -from cala.assets import CompStats, Frame, PopSnap, Traces -from cala.models import AXIS +from cala.assets import AXIS +from cala.assets.assets import CompStats, Frame, PopSnap, Traces def ingest_frame(component_stats: CompStats, frame: Frame, new_traces: PopSnap) -> CompStats: diff --git a/src/cala/nodes/omf/footprints.py b/src/cala/nodes/omf/footprints.py index a9742ba1..b9523791 100644 --- a/src/cala/nodes/omf/footprints.py +++ b/src/cala/nodes/omf/footprints.py @@ -7,9 +7,9 @@ from scipy.sparse import csc_matrix, vstack from sparse import COO -from cala.assets import CompStats, Footprints, PixStats +from cala.assets import AXIS +from cala.assets.assets import CompStats, Footprints, PixStats from cala.logging import init_logger -from cala.models import AXIS from cala.util import concatenate_coordinates diff --git a/src/cala/nodes/omf/overlap.py b/src/cala/nodes/omf/overlap.py index 1e056c00..55f4f191 100644 --- a/src/cala/nodes/omf/overlap.py +++ b/src/cala/nodes/omf/overlap.py @@ -2,8 +2,8 @@ import xarray as xr from sparse import COO -from cala.assets import Footprints, Overlaps -from cala.models import AXIS +from cala.assets import AXIS +from cala.assets.assets import Footprints, Overlaps from cala.util import concatenate_coordinates, sp_matmul, stack_sparse diff --git a/src/cala/nodes/omf/pixel_stats.py b/src/cala/nodes/omf/pixel_stats.py index df70cdeb..db966f53 100644 --- a/src/cala/nodes/omf/pixel_stats.py +++ b/src/cala/nodes/omf/pixel_stats.py @@ -5,8 +5,8 @@ from noob import Name from scipy.sparse import csr_matrix -from cala.assets import Buffer, Footprints, Frame, Movie, PixStats, PopSnap, Traces -from cala.models import AXIS +from cala.assets import AXIS +from cala.assets.assets import Buffer, Footprints, Frame, Movie, PixStats, PopSnap, Traces def ingest_frame( @@ -124,7 +124,7 @@ def ingest_component( def fill_buffer(buffer: Buffer, frame: Frame) -> A[Buffer, Name("buffer")]: if buffer.array is None: buffer.array = frame.array.volumize.dim_with_coords( - dim=AXIS.frames_dim, coords=[AXIS.timestamp_coord] + dim=AXIS.frame_dim, coords=[AXIS.timestamp_coord] ) return buffer @@ -168,7 +168,7 @@ def initialize( def outer_with_sparse_mask( masks: xr.DataArray, target: xr.DataArray, right: xr.DataArray, scalar: int = None ) -> np.ndarray: - n_frames = target.sizes[AXIS.frames_dim] + n_frames = target.sizes[AXIS.frame_dim] target_flat = target.data.reshape((n_frames, -1)) n_components = masks.sizes[AXIS.component_dim] A_sparse = masks.data.reshape((n_components, -1)).tocsr() diff --git a/src/cala/nodes/omf/residual.py b/src/cala/nodes/omf/residual.py index a13e651b..37ec444e 100644 --- a/src/cala/nodes/omf/residual.py +++ b/src/cala/nodes/omf/residual.py @@ -6,8 +6,8 @@ from pydantic import BaseModel, PrivateAttr from scipy.sparse import csr_matrix -from cala.assets import Buffer, Footprints, Frame, Traces -from cala.models import AXIS +from cala.assets import AXIS +from cala.assets.assets import Buffer, Footprints, Frame, Traces class Residuer(BaseModel): @@ -47,7 +47,7 @@ def update( if footprints.array is None or traces.array is None: if residuals.array is None: - residuals.array = frame.array.expand_dims(dim=AXIS.frames_dim) + residuals.array = frame.array.expand_dims(dim=AXIS.frame_dim) else: residuals.append(frame.array) std = self._update_std(frame.array) @@ -55,7 +55,7 @@ def update( return residuals, std Y = frame.array - C = traces.array.isel({AXIS.frames_dim: -1}) # (components,) + C = traces.array.isel({AXIS.frame_dim: -1}) # (components,) A = footprints.array A_pix = ( A.transpose(AXIS.component_dim, ...) @@ -66,7 +66,7 @@ def update( R_curr, flag = _find_overestimates(Y=Y, A=A_pix, C=C) if flag: C = _align_overestimates(A_pix=A_pix, C_latest=C, R_latest=R_curr) - traces.array.loc[{AXIS.frames_dim: -1}] = C + traces.array.loc[{AXIS.frame_dim: -1}] = C # if recently discovered, set to zero (or a small number). otherwise, just append preserve_area = _get_new_estimators_area(A=A, C=C) @@ -98,7 +98,7 @@ def _update_std(self, arr: xr.DataArray) -> xr.DataArray: def _init_energy(res: xr.DataArray) -> xr.DataArray: - return res.std(dim=AXIS.frames_dim) + return res.std(dim=AXIS.frame_dim) def _get_residuals(Y: xr.DataArray, A: csr_matrix, C: xr.DataArray) -> xr.DataArray: diff --git a/src/cala/nodes/omf/traces.py b/src/cala/nodes/omf/traces.py index 702e0efb..9451430f 100644 --- a/src/cala/nodes/omf/traces.py +++ b/src/cala/nodes/omf/traces.py @@ -7,9 +7,9 @@ from pydantic import BaseModel from scipy.sparse.csgraph import connected_components -from cala.assets import Footprints, Frame, Overlaps, PopSnap, Traces +from cala.assets import AXIS +from cala.assets.assets import Footprints, Frame, Overlaps, PopSnap, Traces from cala.logging import init_logger -from cala.models import AXIS from cala.util import norm, stack_sparse @@ -57,7 +57,7 @@ def ingest_frame( # Prepare inputs for the update algorithm A = stack_sparse(footprints.array, AXIS.component_dim).tocsr().T y = frame.array.data.reshape((-1,)) - c = traces.array.isel({AXIS.frames_dim: -1}).copy() + c = traces.array.isel({AXIS.frame_dim: -1}).copy() AtA = (A.T @ A).toarray() @@ -77,11 +77,11 @@ def ingest_frame( if traces.zarr_path: updated_tr = updated_traces.volumize.dim_with_coords( - dim=AXIS.frames_dim, coords=[AXIS.frame_coord, AXIS.timestamp_coord] + dim=AXIS.frame_dim, coords=[AXIS.frame_coord, AXIS.timestamp_coord] ) - traces.append(updated_tr, dim=AXIS.frames_dim) + traces.append(updated_tr, dim=AXIS.frame_dim) else: - traces.append(updated_traces, dim=AXIS.frames_dim) + traces.append(updated_traces, dim=AXIS.frame_dim) return PopSnap.from_array(updated_traces) @@ -153,8 +153,8 @@ def ingest_component(traces: Traces, new_traces: Traces) -> Traces: traces.array = c_new return traces - total_frames = traces.sizes[AXIS.frames_dim] - new_n_frames = c_new.sizes[AXIS.frames_dim] + total_frames = traces.sizes[AXIS.frame_dim] + new_n_frames = c_new.sizes[AXIS.frame_dim] merged_ids = c_new.attrs.get("replaces") if merged_ids: @@ -173,14 +173,14 @@ def _pad_history(traces: xr.DataArray, total_nframes: int, value: float = np.nan Pad unknown historical epochs with values... """ - new_nframes = traces.sizes[AXIS.frames_dim] + new_nframes = traces.sizes[AXIS.frame_dim] c_new = xr.DataArray( np.full((traces.sizes[AXIS.component_dim], total_nframes), value), - dims=[AXIS.component_dim, AXIS.frames_dim], + dims=[AXIS.component_dim, AXIS.frame_dim], coords=traces[AXIS.component_dim].coords, ) - c_new.loc[{AXIS.frames_dim: slice(total_nframes - new_nframes, None)}] = traces + c_new.loc[{AXIS.frame_dim: slice(total_nframes - new_nframes, None)}] = traces return c_new diff --git a/src/cala/nodes/prep/background_removal.py b/src/cala/nodes/prep/background_removal.py index 419c6028..f028e239 100644 --- a/src/cala/nodes/prep/background_removal.py +++ b/src/cala/nodes/prep/background_removal.py @@ -8,7 +8,7 @@ from scipy.ndimage import uniform_filter from skimage.morphology import disk -from cala.assets import Frame +from cala.assets.assets import Frame def remove_background( diff --git a/src/cala/nodes/prep/denoise.py b/src/cala/nodes/prep/denoise.py index d7e6f5c7..b950bfe9 100644 --- a/src/cala/nodes/prep/denoise.py +++ b/src/cala/nodes/prep/denoise.py @@ -10,7 +10,7 @@ from pydantic import BaseModel from skimage.restoration import calibrate_denoiser -from cala.assets import Frame +from cala.assets.assets import Frame def _bilateral(arr: np.ndarray, **kwargs: Any) -> np.ndarray: diff --git a/src/cala/nodes/prep/downsample.py b/src/cala/nodes/prep/downsample.py index 1b37f553..bc7e969c 100644 --- a/src/cala/nodes/prep/downsample.py +++ b/src/cala/nodes/prep/downsample.py @@ -3,8 +3,8 @@ import numpy as np from noob import Name -from cala.assets import Frame -from cala.models import AXIS +from cala.assets import AXIS +from cala.assets.assets import Frame from cala.nodes.prep import package_frame diff --git a/src/cala/nodes/prep/flatten.py b/src/cala/nodes/prep/flatten.py index 8db07ba9..344c99a9 100644 --- a/src/cala/nodes/prep/flatten.py +++ b/src/cala/nodes/prep/flatten.py @@ -6,7 +6,7 @@ from skimage.filters import butterworth from skimage.restoration import rolling_ball -from cala.assets import Frame +from cala.assets.assets import Frame def butter(frame: Frame, kwargs: dict[str, Any]) -> A[Frame, Name("frame")]: diff --git a/src/cala/nodes/prep/glow_removal.py b/src/cala/nodes/prep/glow_removal.py index c5cd1d8d..917a3304 100644 --- a/src/cala/nodes/prep/glow_removal.py +++ b/src/cala/nodes/prep/glow_removal.py @@ -4,7 +4,7 @@ import xarray as xr from noob import Name -from cala.assets import Frame +from cala.assets.assets import Frame class GlowRemover: diff --git a/src/cala/nodes/prep/lines.py b/src/cala/nodes/prep/lines.py index 232a2115..84a8c018 100644 --- a/src/cala/nodes/prep/lines.py +++ b/src/cala/nodes/prep/lines.py @@ -6,8 +6,8 @@ from scipy.ndimage import convolve1d from scipy.signal import firwin, welch -from cala.assets import Frame -from cala.models import AXIS +from cala.assets import AXIS +from cala.assets.assets import Frame def remove_mean(frame: Frame, orient: Literal["horiz", "vert", "both"]) -> A[Frame, Name("frame")]: diff --git a/src/cala/nodes/prep/motion.py b/src/cala/nodes/prep/motion.py index 40e06a23..109ba279 100644 --- a/src/cala/nodes/prep/motion.py +++ b/src/cala/nodes/prep/motion.py @@ -10,8 +10,8 @@ from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, field_validator from skimage.filters import difference_of_gaussians -from cala.assets import Frame -from cala.models import AXIS +from cala.assets import AXIS +from cala.assets.assets import Frame from cala.testing.util import shift_by diff --git a/src/cala/nodes/prep/r_estimate.py b/src/cala/nodes/prep/r_estimate.py index 7124c09c..458c2b48 100644 --- a/src/cala/nodes/prep/r_estimate.py +++ b/src/cala/nodes/prep/r_estimate.py @@ -6,8 +6,8 @@ from pydantic import BaseModel, ConfigDict, Field, PrivateAttr from skimage.feature import blob_log -from cala.assets import Frame -from cala.models import AXIS +from cala.assets import AXIS +from cala.assets.assets import Frame class SizeEst(BaseModel): diff --git a/src/cala/nodes/prep/wrap.py b/src/cala/nodes/prep/wrap.py index 17932111..20dcdede 100644 --- a/src/cala/nodes/prep/wrap.py +++ b/src/cala/nodes/prep/wrap.py @@ -7,8 +7,8 @@ import xarray as xr from noob import Name -from cala.assets import Frame -from cala.models import AXIS +from cala.assets import AXIS +from cala.assets.assets import Frame def counter(start: int = 0, limit: int = 1e7) -> A[Generator[int], Name("idx")]: diff --git a/src/cala/nodes/segment/catalog.py b/src/cala/nodes/segment/catalog.py index 7b7c35c5..381debb8 100644 --- a/src/cala/nodes/segment/catalog.py +++ b/src/cala/nodes/segment/catalog.py @@ -14,8 +14,8 @@ from skimage.measure import label from xarray import Coordinates -from cala.assets import Footprint, Footprints, Trace, Traces -from cala.models import AXIS +from cala.assets import AXIS +from cala.assets.assets import Footprint, Footprints, Trace, Traces from cala.util import combine_attr_replaces, concat_components, create_id, rank1nmf @@ -86,7 +86,7 @@ def _merge_matrix( """ overlaps = shapes_1.data @ shapes_2.T > 0 # corr is fast. (~1ms to 4ms) - corrs = xr.corr(traces_1, traces_2, dim=AXIS.frames_dim) > self.merge_threshold + corrs = xr.corr(traces_1, traces_2, dim=AXIS.frame_dim) > self.merge_threshold return xr.DataArray(overlaps * corrs.values, dims=corrs.dims, coords=corrs.coords) def _monopartite_merge_matrix(self, fps: xr.DataArray, trs: xr.DataArray) -> xr.DataArray: @@ -121,7 +121,7 @@ def _merge_candidates( trs = traces.sel({AXIS.component_dim: group}) if len(group) > 1: mov = xr.DataArray( - _create_component_movie(fps, trs), dims=[*AXIS.spatial_dims, AXIS.frames_dim] + _create_component_movie(fps, trs), dims=[*AXIS.spatial_dims, AXIS.frame_dim] ) new_fp, new_tr = self._recompose(mov, footprints[0].coords, traces[0].coords) else: @@ -210,7 +210,7 @@ def _absorb_component( combined_movie = xr.DataArray( absorber_movie.reshape(insert_movie.shape) + insert_movie, - dims=[*AXIS.spatial_dims, AXIS.frames_dim], + dims=[*AXIS.spatial_dims, AXIS.frame_dim], ) a_new, c_new = self._recompose( @@ -246,10 +246,10 @@ def _recompose( """ movie = movie.assign_coords({ax: movie[ax] for ax in AXIS.spatial_dims}) shape = xr.DataArray( - np.sum(movie.transpose(AXIS.frames_dim, ...).data, axis=0) > 0, dims=AXIS.spatial_dims + np.sum(movie.transpose(AXIS.frame_dim, ...).data, axis=0) > 0, dims=AXIS.spatial_dims ) slice_ = movie.where(shape.as_numpy(), 0, drop=True) - R = slice_.stack(space=AXIS.spatial_dims).transpose("space", AXIS.frames_dim) + R = slice_.stack(space=AXIS.spatial_dims).transpose("space", AXIS.frame_dim) a, c, error = rank1nmf(R.values, np.mean(R.values, axis=1)) @@ -277,7 +277,7 @@ def _reshape( ) -> tuple[xr.DataArray, xr.DataArray]: """Convert back to xarray with proper dimensions and coordinates""" - c_new = xr.DataArray(trace.squeeze(), dims=[AXIS.frames_dim], coords=tr_coords) + c_new = xr.DataArray(trace.squeeze(), dims=[AXIS.frame_dim], coords=tr_coords) a_new = xr.DataArray( np.zeros(tuple(fp_coords.sizes.values())), @@ -385,7 +385,7 @@ def _create_component_movie( trace: xr.DataArray, ) -> np.ndarray: """Movie of a single component""" - clean_trace = trace.dropna(dim=AXIS.frames_dim).data + clean_trace = trace.dropna(dim=AXIS.frame_dim).data if isinstance(footprint, csr_matrix): # Target (CSR matrix) case: Transpose the footprint matrix diff --git a/src/cala/nodes/segment/cleanup.py b/src/cala/nodes/segment/cleanup.py index 81cb5643..9daef08a 100644 --- a/src/cala/nodes/segment/cleanup.py +++ b/src/cala/nodes/segment/cleanup.py @@ -3,8 +3,8 @@ import numpy as np from noob import Name -from cala.assets import Buffer, CompStats, Footprints, Overlaps, PixStats, Traces -from cala.models import AXIS +from cala.assets import AXIS +from cala.assets.assets import Buffer, CompStats, Footprints, Overlaps, PixStats, Traces def deprecate_components( @@ -64,7 +64,7 @@ def clear_overestimates( """ if residuals.array is None: return footprints - R_min = residuals.array.isel({AXIS.frames_dim: -1}).reset_coords( + R_min = residuals.array.isel({AXIS.frame_dim: -1}).reset_coords( [AXIS.frame_coord, AXIS.timestamp_coord], drop=True ) tuned_fp = footprints.array.where(R_min > -nmf_error, 0, drop=False) diff --git a/src/cala/nodes/segment/persist.py b/src/cala/nodes/segment/persist.py index 7a707057..924bab4c 100644 --- a/src/cala/nodes/segment/persist.py +++ b/src/cala/nodes/segment/persist.py @@ -2,7 +2,7 @@ from noob import Name -from cala.assets import CompStats, Footprints, Movie, Overlaps, PixStats, Traces +from cala.assets.assets import CompStats, Footprints, Movie, Overlaps, PixStats, Traces from cala.nodes.omf.component_stats import ingest_component as update_component_stats from cala.nodes.omf.footprints import ingest_component as update_footprints from cala.nodes.omf.overlap import ingest_component as update_overlap diff --git a/src/cala/nodes/segment/slice_nmf.py b/src/cala/nodes/segment/slice_nmf.py index 837289ff..45046fb8 100644 --- a/src/cala/nodes/segment/slice_nmf.py +++ b/src/cala/nodes/segment/slice_nmf.py @@ -8,9 +8,9 @@ from noob.node import Node from pydantic import Field -from cala.assets import Buffer, Footprint, Trace +from cala.assets import AXIS +from cala.assets.assets import Buffer, Footprint, Trace from cala.logging import init_logger -from cala.models import AXIS from cala.util import rank1nmf @@ -32,7 +32,7 @@ def process( self, residuals: Buffer, energy: xr.DataArray, detect_radius: int ) -> tuple[A[list[Footprint], Name("new_fps")], A[list[Trace], Name("new_trs")]]: - if residuals.array.sizes[AXIS.frames_dim] < self.min_frames: + if residuals.array.sizes[AXIS.frame_dim] < self.min_frames: return [], [] fps = [] @@ -113,11 +113,7 @@ def _local_nmf( - Temporal component c_new (frames) """ # Reshape neighborhood to 2D matrix (time × space) - R = ( - slice_.transpose(AXIS.frames_dim, ...) - .data.reshape((slice_.sizes[AXIS.frames_dim], -1)) - .T - ) + R = slice_.transpose(AXIS.frame_dim, ...).data.reshape((slice_.sizes[AXIS.frame_dim], -1)).T mean_R = np.mean(R, axis=1) # nan_mask = np.isnan(mean_R) @@ -127,8 +123,8 @@ def _local_nmf( # Convert back to xarray with proper dimensions and coordinates c_new = xr.DataArray( c.squeeze(), - dims=[AXIS.frames_dim], - coords=slice_[AXIS.frames_dim].coords, + dims=[AXIS.frame_dim], + coords=slice_[AXIS.frame_dim].coords, ) # Create full-frame zero array with proper coordinates diff --git a/src/cala/testing/catalog_depr.py b/src/cala/testing/catalog_depr.py index ff6ebee4..7dfcd2ad 100644 --- a/src/cala/testing/catalog_depr.py +++ b/src/cala/testing/catalog_depr.py @@ -9,7 +9,7 @@ from pydantic import Field from scipy.ndimage import gaussian_filter1d -from cala.models import AXIS +from cala.assets import AXIS class CatalogerDepr(Node): @@ -51,5 +51,5 @@ def _merge_matrix( overlaps = np.matmul(fps.data, fps_base.data.T) > 0 # corr is fast. (~1ms to 4ms) - corrs = xr.corr(trs, trs_base, dim=AXIS.frames_dim) > self.merge_threshold + corrs = xr.corr(trs, trs_base, dim=AXIS.frame_dim) > self.merge_threshold return xr.DataArray(overlaps * corrs.values, dims=corrs.dims, coords=corrs.coords) diff --git a/src/cala/testing/nodes.py b/src/cala/testing/nodes.py index 7c2cb874..fa5191d0 100644 --- a/src/cala/testing/nodes.py +++ b/src/cala/testing/nodes.py @@ -6,7 +6,7 @@ from noob import Name, process_method from pydantic import BaseModel, PrivateAttr, model_validator -from cala.assets import Frame +from cala.assets.assets import Frame from cala.testing.toy import FrameDims, Position, Toy diff --git a/src/cala/testing/toy.py b/src/cala/testing/toy.py index ec141671..46f308d6 100644 --- a/src/cala/testing/toy.py +++ b/src/cala/testing/toy.py @@ -7,8 +7,8 @@ from pydantic import BaseModel, ConfigDict, PrivateAttr, field_validator, model_validator from skimage.morphology import disk -from cala.assets import Footprints, Frame, Movie, Traces -from cala.models.axis import AXIS +from cala.assets.assets import Footprints, Frame, Movie, Traces +from cala.assets.axis import AXIS class FrameDims(BaseModel): @@ -128,7 +128,7 @@ def n_components(self) -> int: def _build_movie_template(self) -> xr.DataArray: return xr.DataArray( np.zeros((self.n_frames, self.frame_dims.height, self.frame_dims.width)), - dims=[AXIS.frames_dim, *AXIS.spatial_dims], + dims=[AXIS.frame_dim, *AXIS.spatial_dims], ) def _generate_footprint( @@ -171,14 +171,14 @@ def _format_trace(self, trace: np.ndarray, id_: str, detected_on: int) -> xr.Dat return ( xr.DataArray( trace, - dims=AXIS.frames_dim, + dims=AXIS.frame_dim, ) .expand_dims(AXIS.component_dim) .assign_coords( { AXIS.id_coord: (AXIS.component_dim, [id_]), AXIS.detect_coord: (AXIS.component_dim, [detected_on]), - AXIS.frame_coord: (AXIS.frames_dim, range(trace.size)), + AXIS.frame_coord: (AXIS.frame_dim, range(trace.size)), } ) ) @@ -191,7 +191,7 @@ def _build_traces(self) -> xr.DataArray: return xr.concat(traces, dim=AXIS.component_dim).assign_coords( { AXIS.timestamp_coord: ( - AXIS.frames_dim, + AXIS.frame_dim, [ (datetime.now() + i * timedelta(microseconds=20)).strftime("%H:%M:%S.%f") for i in range(self.n_frames) @@ -202,7 +202,7 @@ def _build_traces(self) -> xr.DataArray: def _build_movie(self, footprints: xr.DataArray, traces: xr.DataArray) -> xr.DataArray: movie = self._build_movie_template() - movie += (footprints @ traces).transpose(AXIS.frames_dim, *AXIS.spatial_dims).as_numpy() + movie += (footprints @ traces).transpose(AXIS.frame_dim, *AXIS.spatial_dims).as_numpy() return movie def make_movie(self) -> Movie: @@ -251,8 +251,8 @@ def traces(self) -> Traces: raise ValueError("No traces available") def movie_gen(self) -> Generator[_TFrame]: - for i in range(self._traces.sizes[AXIS.frames_dim]): - trace = self._traces.isel({AXIS.frames_dim: i}) + for i in range(self._traces.sizes[AXIS.frame_dim]): + trace = self._traces.isel({AXIS.frame_dim: i}) if not self.emit_frames: yield trace @ self._footprints else: diff --git a/src/cala/testing/util.py b/src/cala/testing/util.py index 7e73069e..0e1a1690 100644 --- a/src/cala/testing/util.py +++ b/src/cala/testing/util.py @@ -5,7 +5,7 @@ import numpy as np import xarray as xr -from cala.models import AXIS +from cala.assets import AXIS _TArray = TypeVar("_TArray", xr.DataArray, np.ndarray) diff --git a/src/cala/util.py b/src/cala/util.py index 1e73f3b3..1c8523f8 100644 --- a/src/cala/util.py +++ b/src/cala/util.py @@ -9,7 +9,7 @@ from sparse import COO from xarray import Coordinates -from cala.models import AXIS +from cala.assets import AXIS def create_id() -> str: diff --git a/tests/test_assets.py b/tests/test_assets.py index 7b6b9b4a..0dc26145 100644 --- a/tests/test_assets.py +++ b/tests/test_assets.py @@ -11,8 +11,8 @@ import pytest import xarray as xr -from cala.assets import Buffer, Traces -from cala.models import AXIS +from cala.assets import AXIS +from cala.assets.assets import Buffer, Traces @pytest.mark.parametrize("peek_size", [30, 49, 50, 51, 70]) @@ -24,14 +24,14 @@ def test_array_assignment(tmp_path, four_connected_cells, peek_size): """ traces = four_connected_cells.traces.array - n_frames = traces.sizes[AXIS.frames_dim] # 50 frames + n_frames = traces.sizes[AXIS.frame_dim] # 50 frames zarr_traces = Traces( zarr_path=tmp_path, peek_size=peek_size, flush_interval=max(1000, peek_size) ) zarr_traces.array = traces - assert zarr_traces.array_.sizes[AXIS.frames_dim] == min(n_frames, peek_size) - assert zarr_traces.load_zarr().sizes[AXIS.frames_dim] == max(0, n_frames - peek_size) + assert zarr_traces.array_.sizes[AXIS.frame_dim] == min(n_frames, peek_size) + assert zarr_traces.load_zarr().sizes[AXIS.frame_dim] == max(0, n_frames - peek_size) @pytest.mark.parametrize("peek_size", [30, 50, 70]) @@ -42,14 +42,14 @@ def test_array_peek(tmp_path, four_connected_cells, peek_size): """ traces = four_connected_cells.traces.array - n_frames = traces.sizes[AXIS.frames_dim] # 50 frames + n_frames = traces.sizes[AXIS.frame_dim] # 50 frames zarr_traces = Traces( zarr_path=tmp_path, peek_size=peek_size, flush_interval=max(1000, peek_size) ) zarr_traces.array = traces - assert zarr_traces.array.sizes[AXIS.frames_dim] == min(peek_size, n_frames) + assert zarr_traces.array.sizes[AXIS.frame_dim] == min(peek_size, n_frames) @pytest.mark.parametrize("peek_size", [30, 50, 70]) @@ -60,7 +60,7 @@ def test_flush_zarr(four_connected_cells, tmp_path, peek_size): """ traces = four_connected_cells.traces.array - n_frames = traces.sizes[AXIS.frames_dim] # 50 frames + n_frames = traces.sizes[AXIS.frame_dim] # 50 frames zarr_traces = Traces( zarr_path=tmp_path, peek_size=peek_size, flush_interval=max(1000, peek_size) @@ -72,9 +72,9 @@ def test_flush_zarr(four_connected_cells, tmp_path, peek_size): zarr_traces._flush_zarr() # only peek_size left in memory - assert zarr_traces.array_.sizes[AXIS.frames_dim] == min(n_frames, peek_size) + assert zarr_traces.array_.sizes[AXIS.frame_dim] == min(n_frames, peek_size) # the rest is in zarr - assert zarr_traces.load_zarr().sizes[AXIS.frames_dim] == max(0, n_frames - peek_size) + assert zarr_traces.load_zarr().sizes[AXIS.frame_dim] == max(0, n_frames - peek_size) @pytest.mark.parametrize("peek_size, flush_interval", [(30, 70)]) @@ -85,18 +85,18 @@ def test_zarr_append_frame(four_connected_cells, tmp_path, peek_size, flush_inte """ traces = four_connected_cells.traces.array - n_frames = traces.sizes[AXIS.frames_dim] # 50 frames + n_frames = traces.sizes[AXIS.frame_dim] # 50 frames zarr_traces = Traces(zarr_path=tmp_path, peek_size=peek_size, flush_interval=flush_interval) zarr_traces.array = traces[:, :0] # just initializing zarr # array smaller than flush_interval. does not flush. - zarr_traces.append(traces, dim=AXIS.frames_dim) - assert zarr_traces.array_.sizes[AXIS.frames_dim] == n_frames + zarr_traces.append(traces, dim=AXIS.frame_dim) + assert zarr_traces.array_.sizes[AXIS.frame_dim] == n_frames # array larger than flush_interval. flushes down to peek_size. - zarr_traces.append(traces, dim=AXIS.frames_dim) - assert zarr_traces.array_.sizes[AXIS.frames_dim] == peek_size + zarr_traces.append(traces, dim=AXIS.frame_dim) + assert zarr_traces.array_.sizes[AXIS.frame_dim] == peek_size @pytest.mark.parametrize("flush_interval", [30]) @@ -115,7 +115,7 @@ def test_zarr_append_component(four_connected_cells, tmp_path, flush_interval): assert zarr_traces.array_[AXIS.component_dim].equals(traces[AXIS.component_dim]) assert zarr_traces.load_zarr()[AXIS.component_dim].equals(traces[AXIS.component_dim]) - result = xr.concat([zarr_traces.load_zarr(), zarr_traces.array_], dim=AXIS.frames_dim).compute() + result = xr.concat([zarr_traces.load_zarr(), zarr_traces.array_], dim=AXIS.frame_dim).compute() assert result.equals(traces) @@ -129,11 +129,9 @@ def test_flush_after_deprecated(four_connected_cells, tmp_path, flush_interval) merged_ids = zarr_traces.array[AXIS.id_coord].values[0] intact_mask = ~np.isin(zarr_traces.array[AXIS.id_coord].values, merged_ids) zarr_traces.keep(intact_mask) - zarr_traces.append(traces[intact_mask], dim=AXIS.frames_dim) + zarr_traces.append(traces[intact_mask], dim=AXIS.frame_dim) - assert zarr_traces.full_array().equals( - xr.concat([traces] * 2, dim=AXIS.frames_dim)[intact_mask] - ) + assert zarr_traces.full_array().equals(xr.concat([traces] * 2, dim=AXIS.frame_dim)[intact_mask]) def test_from_array(four_connected_cells): @@ -183,51 +181,51 @@ def test_overwrite(four_connected_cells, four_separate_cells): def test_buffer_assign(four_connected_cells): movie = four_connected_cells.make_movie().array buff = Buffer(size=10) - buff.array = movie.isel({AXIS.frames_dim: -1}) - assert buff.array.equals(movie.isel({AXIS.frames_dim: [-1]})) + buff.array = movie.isel({AXIS.frame_dim: -1}) + assert buff.array.equals(movie.isel({AXIS.frame_dim: [-1]})) - buff.array = movie.isel({AXIS.frames_dim: slice(-5, None)}) - assert buff.array.equals(movie.isel({AXIS.frames_dim: slice(-5, None)})) + buff.array = movie.isel({AXIS.frame_dim: slice(-5, None)}) + assert buff.array.equals(movie.isel({AXIS.frame_dim: slice(-5, None)})) - buff.array = movie.isel({AXIS.frames_dim: slice(-10, None)}) - assert buff.array.equals(movie.isel({AXIS.frames_dim: slice(-10, None)})) + buff.array = movie.isel({AXIS.frame_dim: slice(-10, None)}) + assert buff.array.equals(movie.isel({AXIS.frame_dim: slice(-10, None)})) - buff.array = movie.isel({AXIS.frames_dim: slice(-15, None)}) - assert buff.array.equals(movie.isel({AXIS.frames_dim: slice(-10, None)})) + buff.array = movie.isel({AXIS.frame_dim: slice(-15, None)}) + assert buff.array.equals(movie.isel({AXIS.frame_dim: slice(-10, None)})) def test_buffer_append(four_connected_cells): movie = four_connected_cells.make_movie().array buff = Buffer(size=10) - buff.array = movie.isel({AXIS.frames_dim: 0}) - buff.append(movie.isel({AXIS.frames_dim: 1})) - assert buff.array.equals(movie.isel({AXIS.frames_dim: slice(0, 2)})) + buff.array = movie.isel({AXIS.frame_dim: 0}) + buff.append(movie.isel({AXIS.frame_dim: 1})) + assert buff.array.equals(movie.isel({AXIS.frame_dim: slice(0, 2)})) - buff.array = movie.isel({AXIS.frames_dim: slice(None, 9)}) - buff.append(movie.isel({AXIS.frames_dim: 9})) - assert buff.array.equals(movie.isel({AXIS.frames_dim: slice(0, 10)})) - buff.append(movie.isel({AXIS.frames_dim: 10})) - assert buff.array.equals(movie.isel({AXIS.frames_dim: slice(1, 11)})) + buff.array = movie.isel({AXIS.frame_dim: slice(None, 9)}) + buff.append(movie.isel({AXIS.frame_dim: 9})) + assert buff.array.equals(movie.isel({AXIS.frame_dim: slice(0, 10)})) + buff.append(movie.isel({AXIS.frame_dim: 10})) + assert buff.array.equals(movie.isel({AXIS.frame_dim: slice(1, 11)})) def test_buffer_speed(single_cell): movie = single_cell.make_movie().array - movie = xr.concat([movie, movie], dim=AXIS.frames_dim) + movie = xr.concat([movie, movie], dim=AXIS.frame_dim) buff = Buffer(size=100) buff.array = movie start = datetime.now() iter = 100 for _ in range(iter): - buff.append(movie.isel({AXIS.frames_dim: 0})) + buff.append(movie.isel({AXIS.frame_dim: 0})) _ = buff.array result = (datetime.now() - start) / iter start = datetime.now() for _ in range(iter): xr.concat( - [movie.isel({AXIS.frames_dim: slice(1, None)}), movie.isel({AXIS.frames_dim: 0})], - dim=AXIS.frames_dim, + [movie.isel({AXIS.frame_dim: slice(1, None)}), movie.isel({AXIS.frame_dim: 0})], + dim=AXIS.frame_dim, ) expected = (datetime.now() - start) / iter diff --git a/tests/test_config.py b/tests/test_config.py index 3773a0d7..a6a4549f 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -27,7 +27,7 @@ def test_set_config(set_config, tmp_path): def test_config_from_environment(tmp_path, set_env): """ - Setting environmental variables should set the config, including recursive models + Setting environmental variables should set the config, including recursive assets """ override_logdir = Path(tmp_path) / "fancylogdir" diff --git a/tests/test_gui.py b/tests/test_gui.py index eb21e1e2..4c960b62 100644 --- a/tests/test_gui.py +++ b/tests/test_gui.py @@ -1,6 +1,6 @@ import pytest -from cala.assets import AXIS, PopSnap +from cala.assets.assets import AXIS, PopSnap from cala.gui.components.stamper import stamp @@ -15,7 +15,7 @@ def test_deps_spec(): def test_stamper(four_connected_cells) -> None: fp = four_connected_cells.footprints trs = four_connected_cells.traces - tr = PopSnap.from_array(trs.array.isel({AXIS.frames_dim: -1})) + tr = PopSnap.from_array(trs.array.isel({AXIS.frame_dim: -1})) frame = stamp(fp, tr, gain=1) assert len(frame.array.dims) == 3 diff --git a/tests/test_omf/test_component_stats.py b/tests/test_omf/test_component_stats.py index d32b7aab..165d4988 100644 --- a/tests/test_omf/test_component_stats.py +++ b/tests/test_omf/test_component_stats.py @@ -3,8 +3,8 @@ import xarray as xr from noob.node import Node, NodeSpecification -from cala.assets import CompStats, Frame, PopSnap, Traces -from cala.models import AXIS +from cala.assets import AXIS +from cala.assets.assets import CompStats, Frame, PopSnap, Traces @pytest.fixture @@ -52,10 +52,10 @@ def test_ingest_frame(init, frame_update, four_separate_cells) -> None: result = frame_update.process( CompStats.from_array( - init.process(four_separate_cells.traces.array.isel({AXIS.frames_dim: slice(None, -1)})) + init.process(four_separate_cells.traces.array.isel({AXIS.frame_dim: slice(None, -1)})) ), - frame=Frame.from_array(four_separate_cells.make_movie().array.isel({AXIS.frames_dim: -1})), - new_traces=PopSnap.from_array(four_separate_cells.traces.array.isel({AXIS.frames_dim: -1})), + frame=Frame.from_array(four_separate_cells.make_movie().array.isel({AXIS.frame_dim: -1})), + new_traces=PopSnap.from_array(four_separate_cells.traces.array.isel({AXIS.frame_dim: -1})), ) expected = init.process(four_separate_cells.traces.array) diff --git a/tests/test_omf/test_footprints.py b/tests/test_omf/test_footprints.py index fd1433e6..c72d6008 100644 --- a/tests/test_omf/test_footprints.py +++ b/tests/test_omf/test_footprints.py @@ -3,8 +3,8 @@ import xarray as xr from noob.node import Node, NodeSpecification -from cala.assets import CompStats, Footprints, PixStats -from cala.models import AXIS +from cala.assets import AXIS +from cala.assets.assets import CompStats, Footprints, PixStats from cala.nodes.omf.footprints import ingest_component from cala.testing.toy import FrameDims, Position, Toy diff --git a/tests/test_omf/test_overlaps.py b/tests/test_omf/test_overlaps.py index 9662030e..055ea272 100644 --- a/tests/test_omf/test_overlaps.py +++ b/tests/test_omf/test_overlaps.py @@ -2,8 +2,8 @@ import pytest from noob.node import Node, NodeSpecification -from cala.assets import Footprint, Footprints, Overlaps -from cala.models import AXIS +from cala.assets import AXIS +from cala.assets.assets import Footprint, Footprints, Overlaps @pytest.fixture(scope="function") diff --git a/tests/test_omf/test_pixel_stats.py b/tests/test_omf/test_pixel_stats.py index 1f84621b..9310892e 100644 --- a/tests/test_omf/test_pixel_stats.py +++ b/tests/test_omf/test_pixel_stats.py @@ -2,8 +2,8 @@ import xarray as xr from noob.node import Node, NodeSpecification -from cala.assets import Footprints, Frame, PixStats, PopSnap, Traces -from cala.models import AXIS +from cala.assets import AXIS +from cala.assets.assets import Footprints, Frame, PixStats, PopSnap, Traces @pytest.fixture(scope="function") @@ -49,15 +49,15 @@ def test_ingest_frame(init, frame_update, four_separate_cells) -> None: footprints = four_separate_cells.footprints.array pre_ingest = init.process( - traces=traces.isel({AXIS.frames_dim: slice(None, -1)}), - frames=movie.isel({AXIS.frames_dim: slice(None, -1)}), + traces=traces.isel({AXIS.frame_dim: slice(None, -1)}), + frames=movie.isel({AXIS.frame_dim: slice(None, -1)}), footprints=footprints, ) result = frame_update.process( pixel_stats=PixStats.from_array(pre_ingest), - frame=Frame.from_array(movie.isel({AXIS.frames_dim: -1})), - new_traces=PopSnap.from_array(traces.isel({AXIS.frames_dim: -1})), + frame=Frame.from_array(movie.isel({AXIS.frame_dim: -1})), + new_traces=PopSnap.from_array(traces.isel({AXIS.frame_dim: -1})), footprints=Footprints.from_array(footprints), ).array diff --git a/tests/test_omf/test_residual.py b/tests/test_omf/test_residual.py index b2f2bf0c..0142bb8e 100644 --- a/tests/test_omf/test_residual.py +++ b/tests/test_omf/test_residual.py @@ -3,8 +3,8 @@ import xarray as xr from noob.node import Node, NodeSpecification -from cala.assets import Buffer, Footprints, Frame, Traces -from cala.models.axis import AXIS +from cala.assets.assets import Buffer, Footprints, Frame, Traces +from cala.assets.axis import AXIS from cala.nodes.omf.residual import _align_overestimates, _find_unlayered_footprints from cala.testing.toy import FrameDims, Position, Toy @@ -72,14 +72,14 @@ def test_align_overestimates(single_cell) -> None: Maybe this can be absorbed straight into trace frame_ingest as a constraint. """ movie = single_cell.make_movie() - last_frame = movie.array.isel({AXIS.frames_dim: -1}) + last_frame = movie.array.isel({AXIS.frame_dim: -1}) last_res = xr.zeros_like(last_frame) # we have negative residuals last_res.loc[{AXIS.width_coord: slice(single_cell.cell_positions[0].width, None)}] = -1 last_res = last_res.where(single_cell.footprints.array[0].to_numpy(), 0) - last_trace = single_cell.traces.array.isel({AXIS.frames_dim: -1}) + last_trace = single_cell.traces.array.isel({AXIS.frame_dim: -1}) footprints = single_cell.footprints.array shapes_sparse = footprints.data.reshape((footprints.sizes[AXIS.component_dim], -1)).tocsr() @@ -89,7 +89,7 @@ def test_align_overestimates(single_cell) -> None: ) # adjusted to lower than last_trace - assert single_cell.traces.array.isel({AXIS.frames_dim: -2}) < adjusted_traces < last_trace + assert single_cell.traces.array.isel({AXIS.frame_dim: -2}) < adjusted_traces < last_trace def test_find_exposed_footprints(connected_cells) -> None: @@ -117,7 +117,7 @@ def test_std(init, connected_cells) -> None: traces=Traces(), ) - expected = connected_cells.make_movie().array.std(dim=AXIS.frames_dim).values + expected = connected_cells.make_movie().array.std(dim=AXIS.frame_dim).values assert np.allclose(result, expected) diff --git a/tests/test_omf/test_traces.py b/tests/test_omf/test_traces.py index d141c644..4cec1cbf 100644 --- a/tests/test_omf/test_traces.py +++ b/tests/test_omf/test_traces.py @@ -3,8 +3,8 @@ import xarray as xr from noob.node import Node, NodeSpecification -from cala.assets import Frame, Overlaps, Traces -from cala.models import AXIS +from cala.assets import AXIS +from cala.assets.assets import Frame, Overlaps, Traces @pytest.fixture @@ -40,15 +40,15 @@ def test_ingest_frame(frame_update, toy, zarr_setup, request, tmp_path) -> None: ) traces = Traces(array_=None, **zarr_setup) - traces.array = toy.traces.array.isel({AXIS.frames_dim: slice(None, -1)}) + traces.array = toy.traces.array.isel({AXIS.frame_dim: slice(None, -1)}) - frame = Frame.from_array(toy.make_movie().array.isel({AXIS.frames_dim: -1})) + frame = Frame.from_array(toy.make_movie().array.isel({AXIS.frame_dim: -1})) overlap = xray.process(overlaps=Overlaps(), footprints=toy.footprints) result = frame_update.process( traces=traces, footprints=toy.footprints, frame=frame, overlaps=overlap ).array - expected = toy.traces.array.isel({AXIS.frames_dim: -1}) + expected = toy.traces.array.isel({AXIS.frame_dim: -1}) xr.testing.assert_allclose(result, expected, atol=1e-3) @@ -86,7 +86,7 @@ def test_ingest_component(comp_update, toy, request, zarr_setup, tmp_path) -> No traces.array = toy.traces.array_.isel({AXIS.component_dim: slice(None, -1)}) new_traces = toy.traces.array_.isel( - {AXIS.component_dim: [-1], AXIS.frames_dim: slice(-zarr_setup["peek_size"], None)} + {AXIS.component_dim: [-1], AXIS.frame_dim: slice(-zarr_setup["peek_size"], None)} ) new_traces.attrs["replaces"] = ["cell_0"] @@ -94,7 +94,7 @@ def test_ingest_component(comp_update, toy, request, zarr_setup, tmp_path) -> No expected = toy.traces.array.drop_sel({AXIS.component_dim: 0}) expected.loc[ - {AXIS.component_dim: -1, AXIS.frames_dim: slice(None, -zarr_setup["peek_size"])} + {AXIS.component_dim: -1, AXIS.frame_dim: slice(None, -zarr_setup["peek_size"])} ] = np.nan assert result.full_array().equals(expected) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index d3af70e6..e57117ef 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -4,7 +4,7 @@ from noob import SynchronousRunner, Tube from noob.node import Node, NodeSpecification -from cala.models import AXIS +from cala.assets import AXIS @pytest.fixture( @@ -89,7 +89,7 @@ def test_trace_correlation(results) -> None: tr_corr = xr.corr( results["model"].traces.array, results["traces"].rename(AXIS.component_rename), - dim=AXIS.frames_dim, + dim=AXIS.frame_dim, ) for corr in tr_corr: assert np.isclose(corr.max(), 1, atol=1e-5) diff --git a/tests/test_prep/test_denoise.py b/tests/test_prep/test_denoise.py index ad9d5b24..87f8ca22 100644 --- a/tests/test_prep/test_denoise.py +++ b/tests/test_prep/test_denoise.py @@ -6,7 +6,7 @@ import pytest import xarray as xr -from cala.models import AXIS +from cala.assets import AXIS from cala.nodes.prep.denoise import blur from cala.testing.toy import FrameDims, Position, Toy diff --git a/tests/test_prep/test_glow_removal.py b/tests/test_prep/test_glow_removal.py index 777a0a41..adf39f43 100644 --- a/tests/test_prep/test_glow_removal.py +++ b/tests/test_prep/test_glow_removal.py @@ -1,6 +1,6 @@ import numpy as np -from cala.models import AXIS +from cala.assets import AXIS from cala.nodes.prep.glow_removal import GlowRemover from cala.testing.toy import FrameDims, Position, Toy @@ -20,7 +20,7 @@ def test_glow_removal(): gen = toy.movie_gen() movie = toy.make_movie() - expected_base = movie.array.min(dim=AXIS.frames_dim) + expected_base = movie.array.min(dim=AXIS.frame_dim) res = [] for frame, br in zip(iter(gen), np.array([5, 4, 3, 2, 1, 1, 1, 1, 1, 1]) * 4): diff --git a/tests/test_prep/test_motion.py b/tests/test_prep/test_motion.py index bff48895..c3391c16 100644 --- a/tests/test_prep/test_motion.py +++ b/tests/test_prep/test_motion.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from cala.models import AXIS +from cala.assets import AXIS from cala.nodes.prep import blur from cala.nodes.prep.motion import Anchor, Shift from cala.testing.toy import FrameDims, Position, Toy diff --git a/tests/test_prep/test_r_estimate.py b/tests/test_prep/test_r_estimate.py index ac567ae5..0b50a0c0 100644 --- a/tests/test_prep/test_r_estimate.py +++ b/tests/test_prep/test_r_estimate.py @@ -1,6 +1,6 @@ import pytest -from cala.models import AXIS +from cala.assets import AXIS from cala.nodes.prep import package_frame from cala.nodes.prep.r_estimate import SizeEst from cala.testing.toy import Position @@ -18,7 +18,7 @@ def test_size_estim(four_separate_cells): node = SizeEst(n_frames=1, log_kwargs=kwargs) max_proj = package_frame( - four_separate_cells.make_movie().array.max(dim=AXIS.frames_dim).values, index=1 + four_separate_cells.make_movie().array.max(dim=AXIS.frame_dim).values, index=1 ) result = node.get_median_radius(max_proj) @@ -27,7 +27,7 @@ def test_size_estim(four_separate_cells): assert len(node.sizes_) == 3 max_proj = package_frame( - four_separate_cells.make_movie().array.max(dim=AXIS.frames_dim).values, index=3 + four_separate_cells.make_movie().array.max(dim=AXIS.frame_dim).values, index=3 ) result = node.get_median_radius(max_proj) diff --git a/tests/test_segment/test_catalog.py b/tests/test_segment/test_catalog.py index 9143b9bb..e177507a 100644 --- a/tests/test_segment/test_catalog.py +++ b/tests/test_segment/test_catalog.py @@ -3,7 +3,7 @@ import xarray as xr from noob.node import NodeSpecification -from cala.assets import AXIS, Buffer, Footprints, Traces +from cala.assets.assets import AXIS, Buffer, Footprints, Traces from cala.nodes.segment import Cataloger, SliceNMF from cala.nodes.segment.catalog import _register from cala.testing.catalog_depr import CatalogerDepr @@ -180,7 +180,7 @@ def test_absorb_component(slice_nmf, cataloger, single_cell): buff = Buffer(size=100) buff.array = single_cell.make_movie().array new_component = slice_nmf.process( - buff, energy=buff.array.std(dim=AXIS.frames_dim), detect_radius=10 + buff, energy=buff.array.std(dim=AXIS.frame_dim), detect_radius=10 ) A = single_cell.footprints.array diff --git a/tests/test_segment/test_cleanup.py b/tests/test_segment/test_cleanup.py index ed01bda1..ad5ef950 100644 --- a/tests/test_segment/test_cleanup.py +++ b/tests/test_segment/test_cleanup.py @@ -1,5 +1,5 @@ -from cala.assets import Buffer -from cala.models import AXIS +from cala.assets import AXIS +from cala.assets.assets import Buffer from cala.nodes.segment.cleanup import clear_overestimates diff --git a/tests/test_segment/test_merge.py b/tests/test_segment/test_merge.py index fe23902a..b9a8e92c 100644 --- a/tests/test_segment/test_merge.py +++ b/tests/test_segment/test_merge.py @@ -3,7 +3,7 @@ # from scipy.sparse.csgraph import connected_components # # from cala.assets import Overlaps -# from cala.models import AXIS +# from cala.assets import AXIS # from cala.nodes.merge import _filter_targets, _merge_matrix, merge_existing # from cala.testing.toy import FrameDims, Position, Toy # diff --git a/tests/test_segment/test_slice_nmf.py b/tests/test_segment/test_slice_nmf.py index 10a38b55..5fc64cac 100644 --- a/tests/test_segment/test_slice_nmf.py +++ b/tests/test_segment/test_slice_nmf.py @@ -4,7 +4,7 @@ from noob.node import NodeSpecification from sklearn.decomposition import NMF -from cala.assets import AXIS, Buffer +from cala.assets.assets import AXIS, Buffer from cala.nodes.segment import SliceNMF from cala.nodes.segment.slice_nmf import rank1nmf from cala.testing.util import assert_scalar_multiple_arrays @@ -24,7 +24,7 @@ def slice_nmf(): def test_process(slice_nmf, single_cell): new_component = slice_nmf.process( residuals=Buffer.from_array(single_cell.make_movie().array, size=100), - energy=single_cell.make_movie().array.std(dim=AXIS.frames_dim), + energy=single_cell.make_movie().array.std(dim=AXIS.frame_dim), detect_radius=single_cell.cell_radii[0] * 2, ) if new_component: @@ -46,7 +46,7 @@ def test_chunks(single_cell): ) fpts, trcs = nmf.process( residuals=Buffer.from_array(single_cell.make_movie().array, size=100), - energy=single_cell.make_movie().array.std(dim=AXIS.frames_dim), + energy=single_cell.make_movie().array.std(dim=AXIS.frame_dim), detect_radius=10, ) if not fpts or not trcs: @@ -65,7 +65,7 @@ def test_chunks(single_cell): def test_rank1nmf(single_cell): Y = single_cell.make_movie().array - R = Y.stack(space=AXIS.spatial_dims).transpose("space", AXIS.frames_dim) + R = Y.stack(space=AXIS.spatial_dims).transpose("space", AXIS.frame_dim) R += np.random.randint(0, 2, R.shape) shape = np.mean(R.values, axis=1).shape From c5b8a2ed957a48b192f9e0cd9fc36222d104ff3f Mon Sep 17 00:00:00 2001 From: Raymond Date: Wed, 3 Dec 2025 20:19:54 -0800 Subject: [PATCH 31/33] feat: clearer asset names --- src/cala/__init__.py | 2 ++ src/cala/assets/__init__.py | 2 +- src/cala/assets/assets.py | 62 ++++++++++++++++++------------------- src/cala/assets/axis.py | 33 -------------------- src/cala/assets/validate.py | 9 +++--- 5 files changed, 39 insertions(+), 69 deletions(-) diff --git a/src/cala/__init__.py b/src/cala/__init__.py index c61ea9f4..d861568d 100644 --- a/src/cala/__init__.py +++ b/src/cala/__init__.py @@ -1 +1,3 @@ from cala.assets import xr_access as access + +__all__ = ["access"] diff --git a/src/cala/assets/__init__.py b/src/cala/assets/__init__.py index eec23d69..64f87635 100644 --- a/src/cala/assets/__init__.py +++ b/src/cala/assets/__init__.py @@ -1,4 +1,4 @@ -from .assets import Traces, Footprints, PixStats, CompStats, Overlaps, Buffer from .axis import AXIS # noqa: I001 +from .assets import Buffer, CompStats, Footprints, Overlaps, PixStats, Traces __all__ = [AXIS, "Traces", "Footprints", "PixStats", "CompStats", "Overlaps", "Buffer"] diff --git a/src/cala/assets/assets.py b/src/cala/assets/assets.py index 0ccce115..753c65a2 100644 --- a/src/cala/assets/assets.py +++ b/src/cala/assets/assets.py @@ -10,7 +10,7 @@ from sparse import COO from cala.assets.axis import AXIS -from cala.assets.validate import Coords, Dims, Entity, Group, has_no_nan, is_non_negative +from cala.assets.validate import Bundle, Coords, Dims, Schema, has_no_nan, is_non_negative from cala.config import config from cala.util import clear_dir @@ -18,12 +18,12 @@ class Asset(BaseModel): - validate_schema: bool = False array_: AssetType = None sparsify: ClassVar[bool] = False zarr_path: Path | None = None """relative to config.user_data_dir""" - _entity: ClassVar[Entity] + xr_schema: ClassVar[Schema] + validate_schema: bool = False model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True) @@ -34,7 +34,7 @@ def array(self) -> AssetType: @array.setter def array(self, value: xr.DataArray) -> None: if self.validate_schema: - value.validate.against_schema(self._entity.model) + value.validate.against_schema(self.xr_schema.model) if self.sparsify and isinstance(value.data, np.ndarray): value.data = COO.from_numpy(value.data) self.array_ = value @@ -58,13 +58,13 @@ def __eq__(self, other: "Asset") -> bool: return self.array.equals(other.array) @classmethod - def entity(cls) -> Entity: - return cls._entity + def entity(cls) -> Schema: + return cls.xr_schema @model_validator(mode="after") def validate_array_schema(self) -> Self: if self.validate_schema and self.array_ is not None: - self.array_.validate.against_schema(self._entity.model) + self.array_.validate.against_schema(self.xr_schema.model) return self @@ -96,8 +96,8 @@ def load_zarr(self, isel_filter: dict = None, sel_filter: dict = None) -> xr.Dat class Footprint(Asset): - _entity: ClassVar[Entity] = PrivateAttr( - Entity( + xr_schema: ClassVar[Schema] = PrivateAttr( + Schema( name="footprint", dims=(Dims.width.value, Dims.height.value), dtype=float, @@ -107,8 +107,8 @@ class Footprint(Asset): class Trace(Asset): - _entity: ClassVar[Entity] = PrivateAttr( - Entity( + xr_schema: ClassVar[Schema] = PrivateAttr( + Schema( name="trace", dims=(Dims.frame.value,), dtype=float, @@ -118,8 +118,8 @@ class Trace(Asset): class Frame(Asset): - _entity: ClassVar[Entity] = PrivateAttr( - Entity( + xr_schema: ClassVar[Schema] = PrivateAttr( + Schema( name="frame", dims=(Dims.width.value, Dims.height.value), dtype=None, # np.number, # gets converted to float64 in xarray-validate @@ -129,8 +129,8 @@ class Frame(Asset): class Footprints(Asset): - _entity: ClassVar[Entity] = PrivateAttr( - Group( + xr_schema: ClassVar[Schema] = PrivateAttr( + Bundle( name="footprint-group", member=Footprint.entity(), group_by=Dims.component, @@ -159,8 +159,8 @@ class Traces(Asset): added, these are added in with nan values. """ - _entity: ClassVar[Entity] = PrivateAttr( - Group( + xr_schema: ClassVar[Schema] = PrivateAttr( + Bundle( name="trace-group", member=Trace.entity(), group_by=Dims.component, @@ -306,8 +306,8 @@ def full_array(self, isel_filter: dict = None, sel_filter: dict = None) -> xr.Da class Movie(Asset): - _entity: ClassVar[Entity] = PrivateAttr( - Group( + xr_schema: ClassVar[Schema] = PrivateAttr( + Bundle( name="movie", member=Frame.entity(), group_by=Dims.frame.value, @@ -324,8 +324,8 @@ class PopSnap(Asset): Mainly used for Traces that only has one frame. """ - _entity: ClassVar[Entity] = PrivateAttr( - Entity( + xr_schema: ClassVar[Schema] = PrivateAttr( + Schema( name="pop-snap", dims=(Dims.component.value,), dtype=float, @@ -342,8 +342,8 @@ class PopSnap(Asset): class CompStats(Asset): - _entity: ClassVar[Entity] = PrivateAttr( - Entity( + xr_schema: ClassVar[Schema] = PrivateAttr( + Schema( name="comp-stat", dims=comp_dims, dtype=float, @@ -354,8 +354,8 @@ class CompStats(Asset): class PixStats(Asset): - _entity: ClassVar[Entity] = PrivateAttr( - Entity( + xr_schema: ClassVar[Schema] = PrivateAttr( + Schema( name="pix-stat", dims=(Dims.width.value, Dims.height.value, Dims.component.value), dtype=float, @@ -366,8 +366,8 @@ class PixStats(Asset): class Overlaps(Asset): - _entity: ClassVar[Entity] = PrivateAttr( - Entity( + xr_schema: ClassVar[Schema] = PrivateAttr( + Schema( name="overlap", dims=comp_dims, dtype=bool, @@ -385,8 +385,8 @@ class Buffer(Asset): Works by preallocating a space twice the desired size. """ - _entity: ClassVar[Entity] = PrivateAttr( - Group( + xr_schema: ClassVar[Schema] = PrivateAttr( + Bundle( name="frame", member=Frame.entity(), group_by=Dims.frame.value, @@ -462,8 +462,8 @@ def from_array(cls, array: xr.DataArray, size: int) -> Self: class Energy(Asset): - _entity: ClassVar[Entity] = PrivateAttr( - Entity( + xr_schema: ClassVar[Schema] = PrivateAttr( + Schema( name="energy", dims=(Dims.width.value, Dims.height.value), dtype=None, # np.number, # gets converted to float64 in xarray-validate diff --git a/src/cala/assets/axis.py b/src/cala/assets/axis.py index 83891614..f197dad0 100644 --- a/src/cala/assets/axis.py +++ b/src/cala/assets/axis.py @@ -1,36 +1,3 @@ -from enum import StrEnum - - -class classproperty: - - def __init__(self, func): - self._func = func - - def __get__(self, obj, owner): - return self._func(owner) - - -class Dim(StrEnum): - frame = "frame" - height = "height" - width = "width" - component = "component" - """Name of the dimension representing individual components.""" - - @classproperty - def spatial(cls) -> tuple["Dim", "Dim"]: - return cls.height, cls.width - - -class Coord(StrEnum): - id = "id" - timestamp = "timestamp" - detect = "detected_on" - frame = "frame_idx" - width = "width" - height = "height" - - class Axis: """Mixin providing common axis-related attributes.""" diff --git a/src/cala/assets/validate.py b/src/cala/assets/validate.py index ea6c2846..282cee3c 100644 --- a/src/cala/assets/validate.py +++ b/src/cala/assets/validate.py @@ -60,9 +60,10 @@ class Dims(Enum): component = Dim(name=AXIS.component_dim, coords=[Coords.id.value, Coords.detected.value]) -class Entity(BaseModel): +class Schema(BaseModel): """ - A base entity describable with an xarray dataarray. + Wrapper around xarray-schema + """ name: str @@ -110,12 +111,12 @@ def _build_coord_schema(self, coords: list[Coord]) -> CoordsSchema: return CoordsSchema(spec, allow_extra_keys=self.allow_extra_coords) -class Group(Entity): +class Bundle(Schema): """ an xarray dataarray entity that is also a group of entities. """ - member: Entity + member: Schema group_by: Dims | None = None dims: tuple[Dim, ...] = Field(default=tuple()) dtype: type = Field(default=Any) From 3fdcc9e5a5055e96c39fc9100098a48ea82eac22 Mon Sep 17 00:00:00 2001 From: Raymond Date: Wed, 3 Dec 2025 20:31:45 -0800 Subject: [PATCH 32/33] feat: unify axis copy --- README.md | 11 +- pyproject.toml | 2 - src/cala/assets/assets.py | 165 +++++++++++-------------- src/cala/nodes/omf/component_stats.py | 4 +- src/cala/nodes/omf/overlap.py | 4 +- src/cala/nodes/segment/catalog.py | 4 +- src/cala/testing/catalog_depr.py | 4 +- src/cala/util.py | 2 +- tests/test_omf/test_component_stats.py | 4 +- 9 files changed, 89 insertions(+), 111 deletions(-) diff --git a/README.md b/README.md index d7b0e37c..68388206 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,9 @@ > [!CAUTION] > **NOT READY FOR USE: In active development in version alpha. Beta release scheduled by the end of 2025.** -Cala is a neural endoscope image processing tool designed for neuroscience research, with a focus on long-term massive recordings. It features a no-code approach through configuration files, making it accessible to researchers of all programming backgrounds. +Cala is a neural endoscope image processing tool designed for neuroscience research, with a focus on long-term massive +recordings. It features a no-code approach through configuration files, making it accessible to researchers of all +programming backgrounds. ## Requirements @@ -20,11 +22,10 @@ Cala is a neural endoscope image processing tool designed for neuroscience resea ## Architecture -Schematics of the architecture can be found [here](https://lucid.app/documents/embedded/808097f9-bf66-4ea8-9df0-e957e6bd0931). - - -3. **API Reference**: Available on [Read the Docs](https://cala.readthedocs.io/en/latest/) +Schematics of the architecture can be +found [here](https://lucid.app/documents/embedded/808097f9-bf66-4ea8-9df0-e957e6bd0931). +1. **API Reference**: Available on [Read the Docs](https://cala.readthedocs.io/en/latest/) ## Contact diff --git a/pyproject.toml b/pyproject.toml index bb68b05f..cefed70d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -208,8 +208,6 @@ select = [ "S608", "S701", ] ignore = [ - # needing to annotate `self` is ridiculous - "ANN101", #"special" methods like `__init__` don't need to be annotated "ANN204", # any types are semantically valid actually sometimes diff --git a/src/cala/assets/assets.py b/src/cala/assets/assets.py index 753c65a2..2c2167e5 100644 --- a/src/cala/assets/assets.py +++ b/src/cala/assets/assets.py @@ -96,48 +96,41 @@ def load_zarr(self, isel_filter: dict = None, sel_filter: dict = None) -> xr.Dat class Footprint(Asset): - xr_schema: ClassVar[Schema] = PrivateAttr( - Schema( - name="footprint", - dims=(Dims.width.value, Dims.height.value), - dtype=float, - checks=[is_non_negative, has_no_nan], - ) + xr_schema: ClassVar[Schema] = Schema( + name="footprint", + dims=(Dims.width.value, Dims.height.value), + dtype=float, + checks=[is_non_negative, has_no_nan], ) class Trace(Asset): - xr_schema: ClassVar[Schema] = PrivateAttr( - Schema( - name="trace", - dims=(Dims.frame.value,), - dtype=float, - checks=[is_non_negative], - ) + xr_schema: ClassVar[Schema] = Schema( + name="trace", + dims=(Dims.frame.value,), + dtype=float, + checks=[is_non_negative], ) class Frame(Asset): - xr_schema: ClassVar[Schema] = PrivateAttr( - Schema( - name="frame", - dims=(Dims.width.value, Dims.height.value), - dtype=None, # np.number, # gets converted to float64 in xarray-validate - checks=[is_non_negative, has_no_nan], - ) + xr_schema: ClassVar[Schema] = Schema( + name="frame", + dims=(Dims.width.value, Dims.height.value), + dtype=None, # np.number, # gets converted to float64 in xarray-validate + checks=[is_non_negative, has_no_nan], ) class Footprints(Asset): - xr_schema: ClassVar[Schema] = PrivateAttr( - Bundle( - name="footprint-group", - member=Footprint.entity(), - group_by=Dims.component, - checks=[is_non_negative, has_no_nan], - allow_extra_coords=False, - ) + xr_schema: ClassVar[Schema] = Bundle( + name="footprint-group", + member=Footprint.entity(), + group_by=Dims.component, + checks=[is_non_negative, has_no_nan], + allow_extra_coords=False, ) + sparsify = True @@ -159,14 +152,12 @@ class Traces(Asset): added, these are added in with nan values. """ - xr_schema: ClassVar[Schema] = PrivateAttr( - Bundle( - name="trace-group", - member=Trace.entity(), - group_by=Dims.component, - checks=[is_non_negative], - allow_extra_coords=False, - ) + xr_schema: ClassVar[Schema] = Bundle( + name="trace-group", + member=Trace.entity(), + group_by=Dims.component, + checks=[is_non_negative], + allow_extra_coords=False, ) @model_validator(mode="after") @@ -306,14 +297,12 @@ def full_array(self, isel_filter: dict = None, sel_filter: dict = None) -> xr.Da class Movie(Asset): - xr_schema: ClassVar[Schema] = PrivateAttr( - Bundle( - name="movie", - member=Frame.entity(), - group_by=Dims.frame.value, - checks=[is_non_negative, has_no_nan], - allow_extra_coords=False, - ) + xr_schema: ClassVar[Schema] = Bundle( + name="movie", + member=Frame.entity(), + group_by=Dims.frame.value, + checks=[is_non_negative, has_no_nan], + allow_extra_coords=False, ) @@ -324,56 +313,48 @@ class PopSnap(Asset): Mainly used for Traces that only has one frame. """ - xr_schema: ClassVar[Schema] = PrivateAttr( - Schema( - name="pop-snap", - dims=(Dims.component.value,), - dtype=float, - coords=[Coords.frame.value, Coords.timestamp.value], - checks=[is_non_negative, has_no_nan], - ) + xr_schema: ClassVar[Schema] = Schema( + name="pop-snap", + dims=(Dims.component.value,), + dtype=float, + coords=[Coords.frame.value, Coords.timestamp.value], + checks=[is_non_negative, has_no_nan], ) comp_dims = (Dims.component.value, deepcopy(Dims.component.value)) -comp_dims[1].name += "'" +comp_dims[1].name = AXIS.duplicate(comp_dims[1].name) for coord in comp_dims[1].coords: - coord.name += "'" + coord.name = AXIS.duplicate(coord.name) class CompStats(Asset): - xr_schema: ClassVar[Schema] = PrivateAttr( - Schema( - name="comp-stat", - dims=comp_dims, - dtype=float, - checks=[is_non_negative, has_no_nan], - allow_extra_coords=False, - ) + xr_schema: ClassVar[Schema] = Schema( + name="comp-stat", + dims=comp_dims, + dtype=float, + checks=[is_non_negative, has_no_nan], + allow_extra_coords=False, ) class PixStats(Asset): - xr_schema: ClassVar[Schema] = PrivateAttr( - Schema( - name="pix-stat", - dims=(Dims.width.value, Dims.height.value, Dims.component.value), - dtype=float, - checks=[is_non_negative, has_no_nan], - allow_extra_coords=False, - ) + xr_schema: ClassVar[Schema] = Schema( + name="pix-stat", + dims=(Dims.width.value, Dims.height.value, Dims.component.value), + dtype=float, + checks=[is_non_negative, has_no_nan], + allow_extra_coords=False, ) class Overlaps(Asset): - xr_schema: ClassVar[Schema] = PrivateAttr( - Schema( - name="overlap", - dims=comp_dims, - dtype=bool, - checks=[has_no_nan], - allow_extra_coords=False, - ) + xr_schema: ClassVar[Schema] = Schema( + name="overlap", + dims=comp_dims, + dtype=bool, + checks=[has_no_nan], + allow_extra_coords=False, ) @@ -385,14 +366,12 @@ class Buffer(Asset): Works by preallocating a space twice the desired size. """ - xr_schema: ClassVar[Schema] = PrivateAttr( - Bundle( - name="frame", - member=Frame.entity(), - group_by=Dims.frame.value, - checks=[is_non_negative, has_no_nan], - allow_extra_coords=False, - ) + xr_schema: ClassVar[Schema] = Bundle( + name="frame", + member=Frame.entity(), + group_by=Dims.frame.value, + checks=[is_non_negative, has_no_nan], + allow_extra_coords=False, ) validate_schema: bool = False @@ -462,13 +441,11 @@ def from_array(cls, array: xr.DataArray, size: int) -> Self: class Energy(Asset): - xr_schema: ClassVar[Schema] = PrivateAttr( - Schema( - name="energy", - dims=(Dims.width.value, Dims.height.value), - dtype=None, # np.number, # gets converted to float64 in xarray-validate - checks=[is_non_negative, has_no_nan], - ) + xr_schema: ClassVar[Schema] = Schema( + name="energy", + dims=(Dims.width.value, Dims.height.value), + dtype=None, # np.number, # gets converted to float64 in xarray-validate + checks=[is_non_negative, has_no_nan], ) _mean: np.ndarray = PrivateAttr(None) diff --git a/src/cala/nodes/omf/component_stats.py b/src/cala/nodes/omf/component_stats.py index 0c418002..5cb9abb5 100644 --- a/src/cala/nodes/omf/component_stats.py +++ b/src/cala/nodes/omf/component_stats.py @@ -102,7 +102,9 @@ def ingest_component(component_stats: CompStats, traces: Traces, new_traces: Tra # Bottom block: [cross_corr.T, auto_corr] bottom_block = xr.concat([bottom_left_corr, auto_corr], dim=AXIS.component_dim) # Combine blocks - component_stats.array = xr.concat([top_block, bottom_block], dim=f"{AXIS.component_dim}'") + component_stats.array = xr.concat( + [top_block, bottom_block], dim=AXIS.duplicate(AXIS.component_dim) + ) return component_stats diff --git a/src/cala/nodes/omf/overlap.py b/src/cala/nodes/omf/overlap.py index 55f4f191..6e25f69e 100644 --- a/src/cala/nodes/omf/overlap.py +++ b/src/cala/nodes/omf/overlap.py @@ -82,9 +82,9 @@ def overlap_format(array: COO, V_comp: xr.DataArray, a_new_comp: xr.DataArray) - ) return xr.DataArray( array, - dims=(AXIS.component_dim, f"{AXIS.component_dim}'"), + dims=(AXIS.component_dim, AXIS.duplicate(AXIS.component_dim)), coords={k: (AXIS.component_dim, v) for k, v in prim_coords.items()}, - ).assign_coords({k: (f"{AXIS.component_dim}'", v) for k, v in seco_coords.items()}) + ).assign_coords({k: (AXIS.duplicate(AXIS.component_dim), v) for k, v in seco_coords.items()}) def assemble_sparse_bool( diff --git a/src/cala/nodes/segment/catalog.py b/src/cala/nodes/segment/catalog.py index 381debb8..9f093b08 100644 --- a/src/cala/nodes/segment/catalog.py +++ b/src/cala/nodes/segment/catalog.py @@ -100,7 +100,7 @@ def _monopartite_merge_matrix(self, fps: xr.DataArray, trs: xr.DataArray) -> xr. fps2 = fps.data smooth_trs = _smooth_traces(trs, self.trace_smooth_kwargs) - trs2 = smooth_trs.rename({AXIS.component_dim: f"{AXIS.component_dim}'"}) + trs2 = smooth_trs.rename({AXIS.component_dim: AXIS.duplicate(AXIS.component_dim)}) return self._merge_matrix(fps, smooth_trs, fps2, trs2) def _merge_candidates( @@ -317,7 +317,7 @@ def _gather_discrete( """ if merge_groups is not None: discrete_idx = merge_groups.where( - merge_groups.sum(f"{AXIS.component_dim}'") == 0, drop=True + merge_groups.sum(AXIS.duplicate(AXIS.component_dim)) == 0, drop=True )[AXIS.component_dim].values else: discrete_idx = np.arange(fps.sizes[AXIS.component_dim]) diff --git a/src/cala/testing/catalog_depr.py b/src/cala/testing/catalog_depr.py index 7dfcd2ad..5abaf884 100644 --- a/src/cala/testing/catalog_depr.py +++ b/src/cala/testing/catalog_depr.py @@ -36,8 +36,8 @@ def _merge_matrix( ) if fps_base is None: - fps_base = fps.rename({AXIS.component_dim: f"{AXIS.component_dim}'"}) - trs_base = trs.rename({AXIS.component_dim: f"{AXIS.component_dim}'"}) + fps_base = fps.rename({AXIS.component_dim: AXIS.duplicate(AXIS.component_dim)}) + trs_base = trs.rename({AXIS.component_dim: AXIS.duplicate(AXIS.component_dim)}) else: fps_base = fps_base.stack(pixels=AXIS.spatial_dims).rename(AXIS.component_rename) trs_base = xr.DataArray( diff --git a/src/cala/util.py b/src/cala/util.py index 1c8523f8..7b3ce70e 100644 --- a/src/cala/util.py +++ b/src/cala/util.py @@ -53,7 +53,7 @@ def sp_matmul( val = ll @ rr.T return xr.DataArray( - COO.from_scipy_sparse(val), dims=[dim, f"{dim}'"], coords=left[dim].coords + COO.from_scipy_sparse(val), dims=[dim, AXIS.duplicate(dim)], coords=left[dim].coords ).assign_coords(right[dim].rename(rename_map).coords) diff --git a/tests/test_omf/test_component_stats.py b/tests/test_omf/test_component_stats.py index 165d4988..638d4b8a 100644 --- a/tests/test_omf/test_component_stats.py +++ b/tests/test_omf/test_component_stats.py @@ -29,8 +29,8 @@ def test_init(init, four_separate_cells) -> None: assert ( result.set_xindex(AXIS.id_coord) .sel({AXIS.id_coord: id1}) - .set_xindex(f"{AXIS.id_coord}'") - .sel({f"{AXIS.id_coord}'": id2}) + .set_xindex(AXIS.duplicate(AXIS.id_coord)) + .sel({AXIS.duplicate(AXIS.id_coord): id2}) .item() == (trace1 @ trace2).item() / four_separate_cells.n_frames ) From 0c104b2c3fcd17283d67846e27b6bba7aef62718 Mon Sep 17 00:00:00 2001 From: Raymond Date: Wed, 3 Dec 2025 20:49:13 -0800 Subject: [PATCH 33/33] feat: import cleanup --- src/cala/assets/__init__.py | 29 +++++++++++++++++++++-- src/cala/gui/components/counter.py | 3 +-- src/cala/gui/components/encoder.py | 3 +-- src/cala/gui/components/stamper.py | 3 +-- src/cala/nodes/omf/component_stats.py | 3 +-- src/cala/nodes/omf/footprints.py | 3 +-- src/cala/nodes/omf/overlap.py | 3 +-- src/cala/nodes/omf/pixel_stats.py | 3 +-- src/cala/nodes/omf/residual.py | 3 +-- src/cala/nodes/omf/traces.py | 3 +-- src/cala/nodes/prep/background_removal.py | 2 +- src/cala/nodes/prep/denoise.py | 2 +- src/cala/nodes/prep/downsample.py | 3 +-- src/cala/nodes/prep/flatten.py | 2 +- src/cala/nodes/prep/glow_removal.py | 2 +- src/cala/nodes/prep/lines.py | 3 +-- src/cala/nodes/prep/motion.py | 3 +-- src/cala/nodes/prep/r_estimate.py | 3 +-- src/cala/nodes/prep/wrap.py | 3 +-- src/cala/nodes/segment/catalog.py | 3 +-- src/cala/nodes/segment/cleanup.py | 3 +-- src/cala/nodes/segment/persist.py | 2 +- src/cala/nodes/segment/slice_nmf.py | 3 +-- src/cala/testing/nodes.py | 2 +- src/cala/testing/toy.py | 3 +-- tests/test_assets.py | 3 +-- tests/test_gui.py | 2 +- tests/test_omf/test_component_stats.py | 3 +-- tests/test_omf/test_footprints.py | 3 +-- tests/test_omf/test_overlaps.py | 3 +-- tests/test_omf/test_pixel_stats.py | 3 +-- tests/test_omf/test_residual.py | 3 +-- tests/test_omf/test_traces.py | 3 +-- tests/test_segment/test_catalog.py | 2 +- tests/test_segment/test_cleanup.py | 3 +-- tests/test_segment/test_slice_nmf.py | 2 +- 36 files changed, 62 insertions(+), 63 deletions(-) diff --git a/src/cala/assets/__init__.py b/src/cala/assets/__init__.py index 64f87635..398ba95f 100644 --- a/src/cala/assets/__init__.py +++ b/src/cala/assets/__init__.py @@ -1,4 +1,29 @@ from .axis import AXIS # noqa: I001 -from .assets import Buffer, CompStats, Footprints, Overlaps, PixStats, Traces +from .assets import ( + Buffer, + CompStats, + Footprints, + Overlaps, + PixStats, + Traces, + Trace, + Footprint, + Frame, + PopSnap, + Movie, +) -__all__ = [AXIS, "Traces", "Footprints", "PixStats", "CompStats", "Overlaps", "Buffer"] +__all__ = [ + "AXIS", + "Traces", + "Footprints", + "PixStats", + "CompStats", + "Overlaps", + "Buffer", + "Trace", + "Footprint", + "Frame", + "PopSnap", + "Movie", +] diff --git a/src/cala/gui/components/counter.py b/src/cala/gui/components/counter.py index 2e22b810..3e862aa0 100644 --- a/src/cala/gui/components/counter.py +++ b/src/cala/gui/components/counter.py @@ -1,5 +1,4 @@ -from cala.assets import AXIS -from cala.assets.assets import Traces +from cala.assets import AXIS, Traces def component_counter(index: int, traces: Traces) -> dict[str, int]: diff --git a/src/cala/gui/components/encoder.py b/src/cala/gui/components/encoder.py index 97b50595..5a07b9cb 100644 --- a/src/cala/gui/components/encoder.py +++ b/src/cala/gui/components/encoder.py @@ -5,8 +5,7 @@ from av.video import VideoStream from noob.node import Node -from cala.assets import AXIS -from cala.assets.assets import Frame +from cala.assets import AXIS, Frame from cala.config import config from cala.util import clear_dir diff --git a/src/cala/gui/components/stamper.py b/src/cala/gui/components/stamper.py index c92b5505..27fa8b18 100644 --- a/src/cala/gui/components/stamper.py +++ b/src/cala/gui/components/stamper.py @@ -5,8 +5,7 @@ import xarray as xr from pydantic import BaseModel, ConfigDict -from cala.assets import AXIS -from cala.assets.assets import Footprints, PopSnap +from cala.assets import AXIS, Footprints, PopSnap from cala.gui.components import Encoder COLOR_MAP = { diff --git a/src/cala/nodes/omf/component_stats.py b/src/cala/nodes/omf/component_stats.py index 5cb9abb5..3c60b0b5 100644 --- a/src/cala/nodes/omf/component_stats.py +++ b/src/cala/nodes/omf/component_stats.py @@ -1,8 +1,7 @@ import numpy as np import xarray as xr -from cala.assets import AXIS -from cala.assets.assets import CompStats, Frame, PopSnap, Traces +from cala.assets import AXIS, CompStats, Frame, PopSnap, Traces def ingest_frame(component_stats: CompStats, frame: Frame, new_traces: PopSnap) -> CompStats: diff --git a/src/cala/nodes/omf/footprints.py b/src/cala/nodes/omf/footprints.py index b9523791..baa497d5 100644 --- a/src/cala/nodes/omf/footprints.py +++ b/src/cala/nodes/omf/footprints.py @@ -7,8 +7,7 @@ from scipy.sparse import csc_matrix, vstack from sparse import COO -from cala.assets import AXIS -from cala.assets.assets import CompStats, Footprints, PixStats +from cala.assets import AXIS, CompStats, Footprints, PixStats from cala.logging import init_logger from cala.util import concatenate_coordinates diff --git a/src/cala/nodes/omf/overlap.py b/src/cala/nodes/omf/overlap.py index 6e25f69e..f52c1f0c 100644 --- a/src/cala/nodes/omf/overlap.py +++ b/src/cala/nodes/omf/overlap.py @@ -2,8 +2,7 @@ import xarray as xr from sparse import COO -from cala.assets import AXIS -from cala.assets.assets import Footprints, Overlaps +from cala.assets import AXIS, Footprints, Overlaps from cala.util import concatenate_coordinates, sp_matmul, stack_sparse diff --git a/src/cala/nodes/omf/pixel_stats.py b/src/cala/nodes/omf/pixel_stats.py index db966f53..5435927b 100644 --- a/src/cala/nodes/omf/pixel_stats.py +++ b/src/cala/nodes/omf/pixel_stats.py @@ -5,8 +5,7 @@ from noob import Name from scipy.sparse import csr_matrix -from cala.assets import AXIS -from cala.assets.assets import Buffer, Footprints, Frame, Movie, PixStats, PopSnap, Traces +from cala.assets import AXIS, Buffer, Footprints, Frame, Movie, PixStats, PopSnap, Traces def ingest_frame( diff --git a/src/cala/nodes/omf/residual.py b/src/cala/nodes/omf/residual.py index 37ec444e..529eaa12 100644 --- a/src/cala/nodes/omf/residual.py +++ b/src/cala/nodes/omf/residual.py @@ -6,8 +6,7 @@ from pydantic import BaseModel, PrivateAttr from scipy.sparse import csr_matrix -from cala.assets import AXIS -from cala.assets.assets import Buffer, Footprints, Frame, Traces +from cala.assets import AXIS, Buffer, Footprints, Frame, Traces class Residuer(BaseModel): diff --git a/src/cala/nodes/omf/traces.py b/src/cala/nodes/omf/traces.py index 9451430f..bd3c00ea 100644 --- a/src/cala/nodes/omf/traces.py +++ b/src/cala/nodes/omf/traces.py @@ -7,8 +7,7 @@ from pydantic import BaseModel from scipy.sparse.csgraph import connected_components -from cala.assets import AXIS -from cala.assets.assets import Footprints, Frame, Overlaps, PopSnap, Traces +from cala.assets import AXIS, Footprints, Frame, Overlaps, PopSnap, Traces from cala.logging import init_logger from cala.util import norm, stack_sparse diff --git a/src/cala/nodes/prep/background_removal.py b/src/cala/nodes/prep/background_removal.py index f028e239..419c6028 100644 --- a/src/cala/nodes/prep/background_removal.py +++ b/src/cala/nodes/prep/background_removal.py @@ -8,7 +8,7 @@ from scipy.ndimage import uniform_filter from skimage.morphology import disk -from cala.assets.assets import Frame +from cala.assets import Frame def remove_background( diff --git a/src/cala/nodes/prep/denoise.py b/src/cala/nodes/prep/denoise.py index b950bfe9..d7e6f5c7 100644 --- a/src/cala/nodes/prep/denoise.py +++ b/src/cala/nodes/prep/denoise.py @@ -10,7 +10,7 @@ from pydantic import BaseModel from skimage.restoration import calibrate_denoiser -from cala.assets.assets import Frame +from cala.assets import Frame def _bilateral(arr: np.ndarray, **kwargs: Any) -> np.ndarray: diff --git a/src/cala/nodes/prep/downsample.py b/src/cala/nodes/prep/downsample.py index bc7e969c..6efbb8c8 100644 --- a/src/cala/nodes/prep/downsample.py +++ b/src/cala/nodes/prep/downsample.py @@ -3,8 +3,7 @@ import numpy as np from noob import Name -from cala.assets import AXIS -from cala.assets.assets import Frame +from cala.assets import AXIS, Frame from cala.nodes.prep import package_frame diff --git a/src/cala/nodes/prep/flatten.py b/src/cala/nodes/prep/flatten.py index 344c99a9..8db07ba9 100644 --- a/src/cala/nodes/prep/flatten.py +++ b/src/cala/nodes/prep/flatten.py @@ -6,7 +6,7 @@ from skimage.filters import butterworth from skimage.restoration import rolling_ball -from cala.assets.assets import Frame +from cala.assets import Frame def butter(frame: Frame, kwargs: dict[str, Any]) -> A[Frame, Name("frame")]: diff --git a/src/cala/nodes/prep/glow_removal.py b/src/cala/nodes/prep/glow_removal.py index 917a3304..c5cd1d8d 100644 --- a/src/cala/nodes/prep/glow_removal.py +++ b/src/cala/nodes/prep/glow_removal.py @@ -4,7 +4,7 @@ import xarray as xr from noob import Name -from cala.assets.assets import Frame +from cala.assets import Frame class GlowRemover: diff --git a/src/cala/nodes/prep/lines.py b/src/cala/nodes/prep/lines.py index 84a8c018..f9cda48d 100644 --- a/src/cala/nodes/prep/lines.py +++ b/src/cala/nodes/prep/lines.py @@ -6,8 +6,7 @@ from scipy.ndimage import convolve1d from scipy.signal import firwin, welch -from cala.assets import AXIS -from cala.assets.assets import Frame +from cala.assets import AXIS, Frame def remove_mean(frame: Frame, orient: Literal["horiz", "vert", "both"]) -> A[Frame, Name("frame")]: diff --git a/src/cala/nodes/prep/motion.py b/src/cala/nodes/prep/motion.py index 109ba279..fd3e9b7d 100644 --- a/src/cala/nodes/prep/motion.py +++ b/src/cala/nodes/prep/motion.py @@ -10,8 +10,7 @@ from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, field_validator from skimage.filters import difference_of_gaussians -from cala.assets import AXIS -from cala.assets.assets import Frame +from cala.assets import AXIS, Frame from cala.testing.util import shift_by diff --git a/src/cala/nodes/prep/r_estimate.py b/src/cala/nodes/prep/r_estimate.py index 458c2b48..d9b42c8c 100644 --- a/src/cala/nodes/prep/r_estimate.py +++ b/src/cala/nodes/prep/r_estimate.py @@ -6,8 +6,7 @@ from pydantic import BaseModel, ConfigDict, Field, PrivateAttr from skimage.feature import blob_log -from cala.assets import AXIS -from cala.assets.assets import Frame +from cala.assets import AXIS, Frame class SizeEst(BaseModel): diff --git a/src/cala/nodes/prep/wrap.py b/src/cala/nodes/prep/wrap.py index 20dcdede..3d423c28 100644 --- a/src/cala/nodes/prep/wrap.py +++ b/src/cala/nodes/prep/wrap.py @@ -7,8 +7,7 @@ import xarray as xr from noob import Name -from cala.assets import AXIS -from cala.assets.assets import Frame +from cala.assets import AXIS, Frame def counter(start: int = 0, limit: int = 1e7) -> A[Generator[int], Name("idx")]: diff --git a/src/cala/nodes/segment/catalog.py b/src/cala/nodes/segment/catalog.py index 9f093b08..a6ad2a98 100644 --- a/src/cala/nodes/segment/catalog.py +++ b/src/cala/nodes/segment/catalog.py @@ -14,8 +14,7 @@ from skimage.measure import label from xarray import Coordinates -from cala.assets import AXIS -from cala.assets.assets import Footprint, Footprints, Trace, Traces +from cala.assets import AXIS, Footprint, Footprints, Trace, Traces from cala.util import combine_attr_replaces, concat_components, create_id, rank1nmf diff --git a/src/cala/nodes/segment/cleanup.py b/src/cala/nodes/segment/cleanup.py index 9daef08a..7a32495a 100644 --- a/src/cala/nodes/segment/cleanup.py +++ b/src/cala/nodes/segment/cleanup.py @@ -3,8 +3,7 @@ import numpy as np from noob import Name -from cala.assets import AXIS -from cala.assets.assets import Buffer, CompStats, Footprints, Overlaps, PixStats, Traces +from cala.assets import AXIS, Buffer, CompStats, Footprints, Overlaps, PixStats, Traces def deprecate_components( diff --git a/src/cala/nodes/segment/persist.py b/src/cala/nodes/segment/persist.py index 924bab4c..7a707057 100644 --- a/src/cala/nodes/segment/persist.py +++ b/src/cala/nodes/segment/persist.py @@ -2,7 +2,7 @@ from noob import Name -from cala.assets.assets import CompStats, Footprints, Movie, Overlaps, PixStats, Traces +from cala.assets import CompStats, Footprints, Movie, Overlaps, PixStats, Traces from cala.nodes.omf.component_stats import ingest_component as update_component_stats from cala.nodes.omf.footprints import ingest_component as update_footprints from cala.nodes.omf.overlap import ingest_component as update_overlap diff --git a/src/cala/nodes/segment/slice_nmf.py b/src/cala/nodes/segment/slice_nmf.py index 45046fb8..c09f67b5 100644 --- a/src/cala/nodes/segment/slice_nmf.py +++ b/src/cala/nodes/segment/slice_nmf.py @@ -8,8 +8,7 @@ from noob.node import Node from pydantic import Field -from cala.assets import AXIS -from cala.assets.assets import Buffer, Footprint, Trace +from cala.assets import AXIS, Buffer, Footprint, Trace from cala.logging import init_logger from cala.util import rank1nmf diff --git a/src/cala/testing/nodes.py b/src/cala/testing/nodes.py index fa5191d0..7c2cb874 100644 --- a/src/cala/testing/nodes.py +++ b/src/cala/testing/nodes.py @@ -6,7 +6,7 @@ from noob import Name, process_method from pydantic import BaseModel, PrivateAttr, model_validator -from cala.assets.assets import Frame +from cala.assets import Frame from cala.testing.toy import FrameDims, Position, Toy diff --git a/src/cala/testing/toy.py b/src/cala/testing/toy.py index 46f308d6..d074adb6 100644 --- a/src/cala/testing/toy.py +++ b/src/cala/testing/toy.py @@ -7,8 +7,7 @@ from pydantic import BaseModel, ConfigDict, PrivateAttr, field_validator, model_validator from skimage.morphology import disk -from cala.assets.assets import Footprints, Frame, Movie, Traces -from cala.assets.axis import AXIS +from cala.assets import AXIS, Footprints, Frame, Movie, Traces class FrameDims(BaseModel): diff --git a/tests/test_assets.py b/tests/test_assets.py index 0dc26145..141520cc 100644 --- a/tests/test_assets.py +++ b/tests/test_assets.py @@ -11,8 +11,7 @@ import pytest import xarray as xr -from cala.assets import AXIS -from cala.assets.assets import Buffer, Traces +from cala.assets import AXIS, Buffer, Traces @pytest.mark.parametrize("peek_size", [30, 49, 50, 51, 70]) diff --git a/tests/test_gui.py b/tests/test_gui.py index 4c960b62..d70adf66 100644 --- a/tests/test_gui.py +++ b/tests/test_gui.py @@ -1,6 +1,6 @@ import pytest -from cala.assets.assets import AXIS, PopSnap +from cala.assets import AXIS, PopSnap from cala.gui.components.stamper import stamp diff --git a/tests/test_omf/test_component_stats.py b/tests/test_omf/test_component_stats.py index 638d4b8a..fda463a0 100644 --- a/tests/test_omf/test_component_stats.py +++ b/tests/test_omf/test_component_stats.py @@ -3,8 +3,7 @@ import xarray as xr from noob.node import Node, NodeSpecification -from cala.assets import AXIS -from cala.assets.assets import CompStats, Frame, PopSnap, Traces +from cala.assets import AXIS, CompStats, Frame, PopSnap, Traces @pytest.fixture diff --git a/tests/test_omf/test_footprints.py b/tests/test_omf/test_footprints.py index c72d6008..12140924 100644 --- a/tests/test_omf/test_footprints.py +++ b/tests/test_omf/test_footprints.py @@ -3,8 +3,7 @@ import xarray as xr from noob.node import Node, NodeSpecification -from cala.assets import AXIS -from cala.assets.assets import CompStats, Footprints, PixStats +from cala.assets import AXIS, CompStats, Footprints, PixStats from cala.nodes.omf.footprints import ingest_component from cala.testing.toy import FrameDims, Position, Toy diff --git a/tests/test_omf/test_overlaps.py b/tests/test_omf/test_overlaps.py index 055ea272..0b6ebfc8 100644 --- a/tests/test_omf/test_overlaps.py +++ b/tests/test_omf/test_overlaps.py @@ -2,8 +2,7 @@ import pytest from noob.node import Node, NodeSpecification -from cala.assets import AXIS -from cala.assets.assets import Footprint, Footprints, Overlaps +from cala.assets import AXIS, Footprint, Footprints, Overlaps @pytest.fixture(scope="function") diff --git a/tests/test_omf/test_pixel_stats.py b/tests/test_omf/test_pixel_stats.py index 9310892e..73131b57 100644 --- a/tests/test_omf/test_pixel_stats.py +++ b/tests/test_omf/test_pixel_stats.py @@ -2,8 +2,7 @@ import xarray as xr from noob.node import Node, NodeSpecification -from cala.assets import AXIS -from cala.assets.assets import Footprints, Frame, PixStats, PopSnap, Traces +from cala.assets import AXIS, Footprints, Frame, PixStats, PopSnap, Traces @pytest.fixture(scope="function") diff --git a/tests/test_omf/test_residual.py b/tests/test_omf/test_residual.py index 0142bb8e..392fee06 100644 --- a/tests/test_omf/test_residual.py +++ b/tests/test_omf/test_residual.py @@ -3,8 +3,7 @@ import xarray as xr from noob.node import Node, NodeSpecification -from cala.assets.assets import Buffer, Footprints, Frame, Traces -from cala.assets.axis import AXIS +from cala.assets import AXIS, Buffer, Footprints, Frame, Traces from cala.nodes.omf.residual import _align_overestimates, _find_unlayered_footprints from cala.testing.toy import FrameDims, Position, Toy diff --git a/tests/test_omf/test_traces.py b/tests/test_omf/test_traces.py index 4cec1cbf..2cd3c523 100644 --- a/tests/test_omf/test_traces.py +++ b/tests/test_omf/test_traces.py @@ -3,8 +3,7 @@ import xarray as xr from noob.node import Node, NodeSpecification -from cala.assets import AXIS -from cala.assets.assets import Frame, Overlaps, Traces +from cala.assets import AXIS, Frame, Overlaps, Traces @pytest.fixture diff --git a/tests/test_segment/test_catalog.py b/tests/test_segment/test_catalog.py index e177507a..41eec2a7 100644 --- a/tests/test_segment/test_catalog.py +++ b/tests/test_segment/test_catalog.py @@ -3,7 +3,7 @@ import xarray as xr from noob.node import NodeSpecification -from cala.assets.assets import AXIS, Buffer, Footprints, Traces +from cala.assets import AXIS, Buffer, Footprints, Traces from cala.nodes.segment import Cataloger, SliceNMF from cala.nodes.segment.catalog import _register from cala.testing.catalog_depr import CatalogerDepr diff --git a/tests/test_segment/test_cleanup.py b/tests/test_segment/test_cleanup.py index ad5ef950..18b3f076 100644 --- a/tests/test_segment/test_cleanup.py +++ b/tests/test_segment/test_cleanup.py @@ -1,5 +1,4 @@ -from cala.assets import AXIS -from cala.assets.assets import Buffer +from cala.assets import AXIS, Buffer from cala.nodes.segment.cleanup import clear_overestimates diff --git a/tests/test_segment/test_slice_nmf.py b/tests/test_segment/test_slice_nmf.py index 5fc64cac..94a09cd5 100644 --- a/tests/test_segment/test_slice_nmf.py +++ b/tests/test_segment/test_slice_nmf.py @@ -4,7 +4,7 @@ from noob.node import NodeSpecification from sklearn.decomposition import NMF -from cala.assets.assets import AXIS, Buffer +from cala.assets import AXIS, Buffer from cala.nodes.segment import SliceNMF from cala.nodes.segment.slice_nmf import rank1nmf from cala.testing.util import assert_scalar_multiple_arrays