diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1b1171cc..59ac6b0e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,7 +17,7 @@ repos: name: pdm-format entry: pdm format language: system - types: [python] + types: [ python ] pass_filenames: false always_run: true # - repo: https://github.com/pre-commit/mirrors-mypy diff --git a/pdm.lock b/pdm.lock index a3d86080..16127fbd 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev", "docs", "gui", "tests"] strategy = [] lock_version = "4.5.0" -content_hash = "sha256:68f2cd4f887a77af253a18c01590866b2fd460982b397836133b8fa09d808e17" +content_hash = "sha256:d02754f2363ec34c50db0fe7d4a94b017e21816863471a002e007ed0fd85e128" [[metadata.targets]] requires_python = ">=3.11" @@ -1715,6 +1715,16 @@ files = [ {file = "myst_parser-4.0.1.tar.gz", hash = "sha256:5cfea715e4f3574138aecbf7d54132296bfd72bb614d31168f48c477a830a7c4"}, ] +[[package]] +name = "natsort" +version = "8.4.0" +requires_python = ">=3.7" +summary = "Simple yet flexible natural sorting in Python." +files = [ + {file = "natsort-8.4.0-py3-none-any.whl", hash = "sha256:4732914fb471f56b5cce04d7bae6f164a592c7712e1c85f9ef585e197299521c"}, + {file = "natsort-8.4.0.tar.gz", hash = "sha256:45312c4a0e5507593da193dedd04abb1469253b601ecaf63445ad80f0a1ea581"}, +] + [[package]] name = "nbclient" version = "0.10.2" @@ -1811,10 +1821,11 @@ files = [ [[package]] name = "noob" -version = "0.1.1.dev208" +version = "0.1.1.dev209" requires_python = ">=3.11" git = "https://github.com/miniscope/noob.git" -revision = "6f9512e57e10e142335ccddf64a599d85edc73d6" +ref = "scheduler-optimize" +revision = "1cafe617372795c5bfe9c6953416b309572e1676" summary = "Default template for PDM package" dependencies = [ "PyYAML>=6.0.2", diff --git a/pyproject.toml b/pyproject.toml index dbdf603d..ebd58be2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,8 @@ dependencies = [ "pyyaml>=6.0.2", "typer>=0.15.3", "xarray-validate>=0.0.2", - "noob @ git+https://github.com/miniscope/noob.git", + "noob @ git+https://github.com/miniscope/noob.git@scheduler-optimize", + "natsort>=8.4.0", ] keywords = [ "pipeline", diff --git a/src/cala/assets.py b/src/cala/assets.py index 74092b21..1e1f2663 100644 --- a/src/cala/assets.py +++ b/src/cala/assets.py @@ -2,7 +2,7 @@ import shutil from copy import deepcopy from pathlib import Path -from typing import Any, ClassVar, Self, TypeVar +from typing import ClassVar, Self, TypeVar import numpy as np import xarray as xr @@ -22,6 +22,8 @@ class Asset(BaseModel): validate_schema: bool = False array_: AssetType = None sparsify: ClassVar[bool] = False + zarr_path: Path | None = None + """relative to config.user_data_dir""" _entity: ClassVar[Entity] model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True) @@ -46,6 +48,12 @@ def from_array(cls, array: xr.DataArray) -> Self: def reset(self) -> None: self.array_ = None + if self.zarr_path: + path = Path(self.zarr_path) + try: + shutil.rmtree(path) + except FileNotFoundError: + contextlib.suppress(FileNotFoundError) def __eq__(self, other: "Asset") -> bool: return self.array.equals(other.array) @@ -61,6 +69,32 @@ def validate_array_schema(self) -> Self: return self + @field_validator("zarr_path", mode="after") + @classmethod + def validate_zarr_path(cls, value: Path | None) -> Path | None: + if value is None: + return value + zarr_dir = (config.user_dir / value).resolve() + zarr_dir.mkdir(parents=True, exist_ok=True) + clear_dir(zarr_dir) + return zarr_dir + + def load_zarr(self, isel_filter: dict = None, sel_filter: dict = None) -> xr.DataArray: + da = ( + xr.open_zarr(self.zarr_path) + .isel(isel_filter) + .sel(sel_filter) + .to_dataarray() + .drop_vars(["variable"]) + .isel(variable=0) + ) + return da.assign_coords( + { + AXIS.id_coord: lambda ds: da[AXIS.id_coord].astype(str), + AXIS.timestamp_coord: lambda ds: da[AXIS.timestamp_coord].astype(str), + } + ) + class Footprint(Asset): _entity: ClassVar[Entity] = PrivateAttr( @@ -109,104 +143,167 @@ class Footprints(Asset): class Traces(Asset): - zarr_path: Path | None = None - """relative to config.user_data_dir""" - peek_size: int | None = None + peek_size: int = None + """How many epochs to return when called.""" + flush_interval: int | None = None + """How many epochs to wait until next flush""" + + _deprecated: list[str] = PrivateAttr(default_factory=list) """ - Traces(array=array, path=path) -> saves to zarr (should be set in this asset, and leave - untouched in nodes.) - Traces.array -> loads from zarr + Deprecated, or replaced component idx. + Since zarr does not support efficiently removing rows and columns, + there's no easy way to remove a column when a component has been + removed or replaced. Instead, we "mask" it with this "deprecated" + flag. + + When arrays are called, these are filtered out. When new epochs are + added, these are added in with nan values. """ - @property - def array(self) -> xr.DataArray: - peek_filter = {AXIS.frames_dim: slice(-self.peek_size, None)} if self.peek_size else None - return self.full_array(isel_filter=peek_filter) + _entity: ClassVar[Entity] = PrivateAttr( + Group( + name="trace-group", + member=Trace.entity(), + group_by=Dims.component, + checks=[is_non_negative], + allow_extra_coords=False, + ) + ) - @array.setter - def array(self, array: xr.DataArray) -> None: - if self.zarr_path: - if self.validate_schema: - array.validate.against_schema(self._entity.model) - array.to_zarr(self.zarr_path, mode="w") # need to make sure it can overwrite - else: - self.array_ = array + @model_validator(mode="after") + def flush_conditions(self) -> Self: + assert (self.flush_interval and self.zarr_path) or ( + not self.flush_interval and not self.zarr_path + ), "zarr_path and flush_interval should either be both provided or neither." + if self.flush_interval: + assert self.flush_interval > self.peek_size, ( + f"flush_interval must be larger than peek_size. " + f"Provided: {self.flush_interval = }, {self.peek_size = }" + ) + return self - def reset(self) -> None: - self.array_ = None + @property + def sizes(self) -> dict[str, int]: if self.zarr_path: - path = Path(self.zarr_path) - try: - shutil.rmtree(path) - except FileNotFoundError: - contextlib.suppress(FileNotFoundError) + total_size = {} + for key, val in self.array_.sizes.items(): + if key == AXIS.frames_dim: + total_size[key] = val + self.load_zarr().sizes[key] + else: + total_size[key] = val + return total_size + else: + return self.array_.sizes - def full_array(self, isel_filter: dict = None, sel_filter: dict = None) -> xr.DataArray: - if self.zarr_path: - try: - return self.load_zarr(isel_filter=isel_filter, sel_filter=sel_filter).compute() - except FileNotFoundError: - pass + @property + def array(self) -> xr.DataArray: return ( - self.array_.isel(isel_filter).sel(sel_filter) + self.array_.isel({AXIS.frames_dim: slice(-self.peek_size, None)}) if self.array_ is not None else self.array_ ) - def load_zarr(self, isel_filter: dict = None, sel_filter: dict = None) -> xr.DataArray: - da = ( - xr.open_zarr(self.zarr_path) - .isel(isel_filter) - .sel(sel_filter) - .to_dataarray() - .drop_vars(["variable"]) - .isel(variable=0) - ) - return da.assign_coords( - { - AXIS.id_coord: lambda ds: da[AXIS.id_coord].astype(str), - AXIS.timestamp_coord: lambda ds: da[AXIS.timestamp_coord].astype(str), - } - ) + @array.setter + def array(self, array: xr.DataArray) -> None: + """ + In case zarr_path is defined, if array is larger than peek_size, + the epochs older than -peek_size gets flushed to zarr array. - def update(self, array: xr.DataArray, **kwargs: Any) -> None: + """ if self.validate_schema: array.validate.against_schema(self._entity.model) - array.to_zarr(self.zarr_path, **kwargs) + if self.zarr_path: + self.array_ = array.isel({AXIS.frames_dim: slice(-self.peek_size, None)}) + array.isel({AXIS.frames_dim: slice(None, -self.peek_size)}).to_zarr( + self.zarr_path, mode="w" + ) + else: + self.array_ = array + + def append(self, array: xr.DataArray, dim: str) -> None: + """ + Since we cannot simply append to zarr array in memory using xarray syntax, + we provide a convenience method for appending to zarr array and in-memory array + in a streamlined manner. + + Incoming arrays have to be 2-dimensional. + + """ + + if dim == AXIS.frames_dim: + self.array_ = xr.concat([self.array_, array], dim=AXIS.frames_dim) + + if self.zarr_path and self.array_.sizes[AXIS.frames_dim] > self.flush_interval: + self._flush_zarr() + + elif dim == AXIS.component_dim: + if self.zarr_path: + n_in_memory = self.array_.sizes[AXIS.frames_dim] + self.array_ = xr.concat( + [self.array_, array.isel({AXIS.frames_dim: slice(-n_in_memory, None)})], + dim=dim, + ) + array.isel({AXIS.frames_dim: slice(None, -n_in_memory)}).to_zarr( + self.zarr_path, append_dim=dim + ) + else: + self.array_ = xr.concat([self.array_, array], dim=dim, combine_attrs="drop") + + def _flush_zarr(self) -> None: + """ + Flushes traces older than peek_size to zarr array. + Needs to append nans to deprecated components, since they get deleted + in in-memory array, but persist in zarr array. + + + Could do this much more elegantly by pre-allocating .array_ + """ + raw_zarr = self.load_zarr() + to_flush = self.array_.isel({AXIS.frames_dim: slice(None, -self.peek_size)}) + if self._deprecated: + zarr_ids = raw_zarr[AXIS.id_coord].values + zarr_detects = raw_zarr[AXIS.detect_coord].values + intact_mask = ~np.isin(zarr_ids, self._deprecated) + n_flush = to_flush.sizes[AXIS.frames_dim] + prealloc = xr.DataArray( + np.full((raw_zarr.sizes[AXIS.component_dim], n_flush), np.nan), + dims=[AXIS.component_dim, AXIS.frames_dim], + coords={ + AXIS.id_coord: (AXIS.component_dim, zarr_ids), + AXIS.detect_coord: (AXIS.component_dim, zarr_detects), + }, + ).assign_coords(to_flush[AXIS.frames_dim].coords) + prealloc.loc[intact_mask] = to_flush + prealloc.to_zarr(self.zarr_path, append_dim=AXIS.frames_dim) + else: + to_flush.to_zarr(self.zarr_path, append_dim=AXIS.frames_dim) + self.array_ = self.array_.isel({AXIS.frames_dim: slice(-self.peek_size, None)}) + + def keep(self, intact_mask: np.ndarray) -> None: + if self.zarr_path: + self._deprecated.extend(self.array_[AXIS.id_coord].values[~intact_mask]) + self.array_ = self.array_[intact_mask] @classmethod - def from_array( - cls, array: xr.DataArray, zarr_path: Path | str | None = None, peek_size: int | None = None - ) -> "Traces": - new_cls = cls(zarr_path=zarr_path, peek_size=peek_size) + def from_array(cls, array: xr.DataArray) -> "Traces": + """ + This is only really used for typing / auto-validation purposes, + so we don't really have to worry about specifying the parameters. + + """ + new_cls = cls(peek_size=array.sizes[AXIS.frames_dim]) new_cls.array = array return new_cls - @field_validator("zarr_path", mode="after") - @classmethod - def validate_zarr_path(cls, value: Path | None) -> Path | None: - if value is None: - return value - zarr_dir = (config.user_dir / value).resolve() - zarr_dir.mkdir(parents=True, exist_ok=True) - clear_dir(zarr_dir) - return zarr_dir - - @model_validator(mode="after") - def check_zarr_setting(self) -> "Traces": + def full_array(self, isel_filter: dict = None, sel_filter: dict = None) -> xr.DataArray: if self.zarr_path: - assert self.peek_size, "peek_size must be set for zarr." - return self + raw_zarr = self.load_zarr(isel_filter, sel_filter) + zarr_ids = raw_zarr[AXIS.id_coord].values + intact_mask = ~np.isin(zarr_ids, self._deprecated) - _entity: ClassVar[Entity] = PrivateAttr( - Group( - name="trace-group", - member=Trace.entity(), - group_by=Dims.component, - checks=[is_non_negative], - allow_extra_coords=False, - ) - ) + return xr.concat([raw_zarr[intact_mask], self.array_], dim=AXIS.frames_dim).compute() + else: + return self.array_.isel(isel_filter).sel(sel_filter) class Movie(Asset): diff --git a/src/cala/gui/templates/index.html b/src/cala/gui/templates/index.html index 35d9531c..18ed7b2d 100644 --- a/src/cala/gui/templates/index.html +++ b/src/cala/gui/templates/index.html @@ -5,6 +5,7 @@ + diff --git a/src/cala/main.py b/src/cala/main.py index cd7004b7..7038d2bd 100644 --- a/src/cala/main.py +++ b/src/cala/main.py @@ -1,3 +1,4 @@ +from datetime import datetime from pathlib import Path from typing import Annotated @@ -14,7 +15,7 @@ try: app = get_app() -except TypeError as e: +except (TypeError, RuntimeError) as e: logger.warning(f"Failed to load gui app: {e}") @@ -23,6 +24,11 @@ def main(spec: str, gui: Annotated[bool, typer.Option()] = False) -> None: if gui: uvicorn.run("cala.main:app", reload=False, reload_dirs=[Path(__file__).parent]) else: + start = datetime.now() tube = Tube.from_specification(spec) runner = SynchronousRunner(tube=tube) - runner.run() + try: + runner.run() + finally: + end = datetime.now() + print(f"Finished in {end - start}") diff --git a/src/cala/nodes/cleanup.py b/src/cala/nodes/cleanup.py index 7319946d..fe52da99 100644 --- a/src/cala/nodes/cleanup.py +++ b/src/cala/nodes/cleanup.py @@ -169,19 +169,3 @@ def _filter_redundant( keep_ids.append(a[AXIS.id_coord].item()) return keep_ids - - -def merge_components( - footprints: Footprints, - traces: Traces, -) -> A[Footprints, Name("footprints")]: - """ - Merge existing components - - 1. dilate footprints (to consider adjacent components) - 2. find overlaps - 3. send to cataloger?? - 4. then send to all component ingestion - - nvm, let's just do it in cataloger. - """ diff --git a/src/cala/nodes/io.py b/src/cala/nodes/io.py index a658e3a4..7b35cbd3 100644 --- a/src/cala/nodes/io.py +++ b/src/cala/nodes/io.py @@ -1,12 +1,15 @@ from abc import abstractmethod from collections.abc import Generator +from glob import glob from pathlib import Path -from typing import Protocol +from typing import Literal, Protocol import cv2 +from natsort import natsorted from numpy.typing import NDArray from skimage import io +from cala.assets import Asset from cala.config import config @@ -80,18 +83,30 @@ def stream_videos(video_paths: list[Path]) -> Generator[NDArray]: current_stream.close() -def stream(files: list[str | Path]) -> Generator[NDArray, None, None]: +def stream( + files: list[str | Path] = None, + subdir: str | Path = None, + extension: Literal[".avi"] = None, + prefix: str | None = None, +) -> Generator[NDArray, None, None]: """ Create a video stream from the provided video files. Args: files: List of file paths + subdir: Directory path. Ignored if files is populated. + extension: File extension. Ignored if files is populated. + prefix: File prefix. Ignored if files is populated. Returns: Stream: A stream that yields video frames """ - file_paths = [Path(f) if isinstance(f, str) else f for f in files] - suffix = {path.suffix.lower() for path in file_paths} + if files: + file_paths = [Path(f) if isinstance(f, str) else f for f in files] + suffix = {path.suffix.lower() for path in file_paths} + else: + files = natsort_paths(subdir, extension, prefix) + suffix = {extension} image_format = {".tif", ".tiff"} video_format = {".mp4", ".avi", ".webm"} @@ -102,3 +117,21 @@ def stream(files: list[str | Path]) -> Generator[NDArray, None, None]: yield from stream_images(files) else: raise ValueError(f"Unsupported file format: {suffix}") + + +def save_asset(asset: Asset, target_epoch: int, curr_epoch: int, path: str | Path) -> Asset: + if target_epoch == curr_epoch: + zarr_dir = config.user_dir + try: + asset.full_array().to_zarr(zarr_dir / path, mode="w") # for Traces + except AttributeError: + asset.array.as_numpy().to_zarr(zarr_dir / path, mode="w") + return asset + + +def natsort_paths( + subdir: str | Path, extension: Literal[".avi"], prefix: str | None = None +) -> list[str]: + prefix = prefix or "" + video_dir = config.video_dir / subdir + return natsorted(glob(f"{str(video_dir)}/{prefix}*{extension}")) diff --git a/src/cala/nodes/merge.py b/src/cala/nodes/merge.py index ff8adf48..737e0f25 100644 --- a/src/cala/nodes/merge.py +++ b/src/cala/nodes/merge.py @@ -8,7 +8,7 @@ from cala.assets import Footprints, Overlaps, Traces from cala.models import AXIS -from cala.nodes.detect.catalog import _recompose, _register +from cala.nodes.segment.catalog import _recompose, _register from cala.util import combine_attr_replaces diff --git a/src/cala/nodes/prep/__init__.py b/src/cala/nodes/prep/__init__.py index f0753b87..ae7dafff 100644 --- a/src/cala/nodes/prep/__init__.py +++ b/src/cala/nodes/prep/__init__.py @@ -1,11 +1,12 @@ +from .wrap import counter, package_frame # noqa: I001 from .background_removal import remove_background from .denoise import Restore, blur +from .downsample import downsample from .flatten import butter from .glow_removal import GlowRemover from .lines import remove_freq, remove_mean from .motion import Anchor from .r_estimate import SizeEst -from .wrap import counter, package_frame __all__ = [ "blur", @@ -19,4 +20,5 @@ "Restore", "package_frame", "counter", + "downsample", ] diff --git a/src/cala/nodes/prep/downsample.py b/src/cala/nodes/prep/downsample.py new file mode 100644 index 00000000..1b37f553 --- /dev/null +++ b/src/cala/nodes/prep/downsample.py @@ -0,0 +1,30 @@ +from typing import Annotated as A + +import numpy as np +from noob import Name + +from cala.assets import Frame +from cala.models import AXIS +from cala.nodes.prep import package_frame + + +def downsample( + frames: list[Frame], x_range: tuple[int, int], y_range: tuple[int, int], t_downsample: int = 1 +) -> A[Frame, Name("frame")]: + """ + Downsampling in time and cropping in space. Must be followed by gather node, and + t_downsample has to be same as gather's parameter n value. + + :param frames: + :param x_range: + :param y_range: + :param t_downsample: + :return: + """ + arrays = [] + for frame in frames: + arrays.append(frame.array[x_range[0] : x_range[1], y_range[0] : y_range[1]]) + + return package_frame( + np.mean(arrays, axis=0), arrays[-1][AXIS.frame_coord].item() // t_downsample + ) diff --git a/src/cala/nodes/detect/__init__.py b/src/cala/nodes/segment/__init__.py similarity index 100% rename from src/cala/nodes/detect/__init__.py rename to src/cala/nodes/segment/__init__.py diff --git a/src/cala/nodes/detect/catalog.py b/src/cala/nodes/segment/catalog.py similarity index 100% rename from src/cala/nodes/detect/catalog.py rename to src/cala/nodes/segment/catalog.py diff --git a/src/cala/nodes/detect/slice_nmf.py b/src/cala/nodes/segment/slice_nmf.py similarity index 100% rename from src/cala/nodes/detect/slice_nmf.py rename to src/cala/nodes/segment/slice_nmf.py diff --git a/src/cala/nodes/detect/update.py b/src/cala/nodes/segment/update.py similarity index 100% rename from src/cala/nodes/detect/update.py rename to src/cala/nodes/segment/update.py diff --git a/src/cala/nodes/traces.py b/src/cala/nodes/traces.py index 520886ce..f8bcafcc 100644 --- a/src/cala/nodes/traces.py +++ b/src/cala/nodes/traces.py @@ -79,9 +79,9 @@ def ingest_frame( updated_tr = updated_traces.volumize.dim_with_coords( dim=AXIS.frames_dim, coords=[AXIS.frame_coord, AXIS.timestamp_coord] ) - traces.update(updated_tr, append_dim=AXIS.frames_dim) + traces.append(updated_tr, dim=AXIS.frames_dim) else: - traces.array = xr.concat([traces.array, updated_traces], dim=AXIS.frames_dim) + traces.append(updated_traces, dim=AXIS.frames_dim) return PopSnap.from_array(updated_traces) @@ -204,44 +204,43 @@ def ingest_component(traces: Traces, new_traces: Traces) -> Traces: :param new_traces: Can be either a newly registered trace or an updated existing one. """ - c = traces.full_array() - c_det = new_traces.array + c_new = new_traces.array - if c_det is None: + if c_new is None: return traces - if c is None: - traces.array = c_det + if traces.array is None: + traces.array = c_new return traces - if c.sizes[AXIS.frames_dim] > c_det.sizes[AXIS.frames_dim]: - # if newly detected cells are truncated, pad with np.nans - c_new = xr.DataArray( - np.full((c_det.sizes[AXIS.component_dim], c.sizes[AXIS.frames_dim]), np.nan), - dims=[AXIS.component_dim, AXIS.frames_dim], - coords=c.isel({AXIS.component_dim: 0}).coords, - ) - c_new[AXIS.id_coord] = c_det[AXIS.id_coord] - c_new[AXIS.detect_coord] = c_det[AXIS.detect_coord] - - c_new.loc[ - {AXIS.frames_dim: slice(c.sizes[AXIS.frames_dim] - c_det.sizes[AXIS.frames_dim], None)} - ] = c_det - else: - c_new = c_det + total_frames = traces.sizes[AXIS.frames_dim] + new_n_frames = c_new.sizes[AXIS.frames_dim] - merged_ids = c_det.attrs.get("replaces") + merged_ids = c_new.attrs.get("replaces") if merged_ids: - if traces.zarr_path: - invalid = c[AXIS.id_coord].isin(merged_ids) - traces.array = c.where(~invalid.compute(), drop=True).compute() - else: - intact_mask = ~np.isin(c[AXIS.id_coord].values, merged_ids) - c = c[intact_mask] + intact_mask = ~np.isin(traces.array[AXIS.id_coord].values, merged_ids) + traces.keep(intact_mask) - if traces.zarr_path: - traces.update(c_new, append_dim=AXIS.component_dim) - else: - traces.array = xr.concat([c, c_new], dim=AXIS.component_dim, combine_attrs="drop") + c_pad = _pad_history(c_new, total_frames, np.nan) if total_frames > new_n_frames else c_new + + traces.append(c_pad, dim=AXIS.component_dim) return traces + + +def _pad_history(traces: xr.DataArray, total_nframes: int, value: float = np.nan) -> xr.DataArray: + """ + Pad unknown historical epochs with values... + + """ + new_nframes = traces.sizes[AXIS.frames_dim] + + c_new = xr.DataArray( + np.full((traces.sizes[AXIS.component_dim], total_nframes), value), + dims=[AXIS.component_dim, AXIS.frames_dim], + coords=traces[AXIS.component_dim].coords, + ) + + c_new.loc[{AXIS.frames_dim: slice(total_nframes - new_nframes, None)}] = traces + + return c_new diff --git a/src/cala/testing/util.py b/src/cala/testing/util.py index 02f87b37..7e73069e 100644 --- a/src/cala/testing/util.py +++ b/src/cala/testing/util.py @@ -110,3 +110,26 @@ def expand_boundary(footprints: xr.DataArray) -> xr.DataArray: vectorize=True, dask="parallelized", ) + + +def total_gradient_magnitude(image: np.ndarray) -> float: + """ + Calculates the total gradient magnitude c(I) = || |nabla I| ||_F for a 2D array I. + + This function computes the Frobenius norm of the pixel-wise gradient magnitude map, + which is a common measure of total image variation or signal energy. + + Args: + image (np.ndarray): The 2D input array (e.g., an image). + + Returns: + float: The Frobenius norm of the gradient magnitude map. + """ + if image.ndim != 2: + raise ValueError("Input array I must be 2-dimensional.") + + grad_y, grad_x = np.gradient(image) + grad_magnitude_map = np.sqrt(grad_x**2 + grad_y**2) + c_I = np.linalg.norm(grad_magnitude_map, ord="fro") + + return c_I diff --git a/tests/data/pipelines/long_recording.yaml b/tests/data/pipelines/long_recording.yaml new file mode 100644 index 00000000..fd7ffe66 --- /dev/null +++ b/tests/data/pipelines/long_recording.yaml @@ -0,0 +1,236 @@ +noob_id: long-recording +noob_model: noob.tube.TubeSpecification +noob_version: 0.1.1.dev118+g64d81b7 + +assets: + buffer: + type: cala.assets.Buffer + params: + size: 100 + scope: runner + footprints: + type: cala.assets.Footprints + params: + sparsify: True + scope: runner + traces: + type: cala.assets.Traces + scope: runner + params: + zarr_path: traces/ + peek_size: 100 + flush_interval: 1000 + pix_stats: + type: cala.assets.PixStats + scope: runner + comp_stats: + type: cala.assets.CompStats + scope: runner + overlaps: + type: cala.assets.Overlaps + scope: runner + residuals: + type: cala.assets.Buffer + params: + size: 100 + scope: runner + + +nodes: + source: + type: cala.nodes.io.stream + params: + subdir: long_recording + extension: .avi + counter: + type: cala.nodes.prep.counter + frame: + type: cala.nodes.prep.package_frame + depends: + - frame: source.value + - index: counter.idx + + #PREPROCESS BEGINS + hotpix: + type: cala.nodes.prep.blur + params: + method: median + kwargs: + ksize: 3 + depends: + - frame: frame.value + flatten: + type: cala.nodes.prep.butter + params: + kwargs: + cutoff_frequency_ratio: 0.010 + depends: + - frame: hotpix.frame + lines: + type: cala.nodes.prep.remove_mean + params: + orient: both + depends: + - frame: flatten.frame + motion: + type: cala.nodes.prep.Anchor + depends: + - frame: lines.frame + glow: + type: cala.nodes.prep.GlowRemover + depends: + - frame: motion.frame + gather: + type: gather + params: + n: 4 + depends: + - glow.frame + downsample: # 20fps -> 5fps + type: cala.nodes.prep.downsample + params: + x_range: [ 250, 350 ] + y_range: [ 250, 350 ] + t_downsample: 4 + depends: + - frames: gather.value + size_est: + type: cala.nodes.prep.SizeEst + params: + hardset_radius: 8 + depends: + - frame: downsample.frame + cache: + type: cala.nodes.buffer.fill_buffer + depends: + - buffer: assets.buffer + - frame: downsample.frame + #PREPROCESS ENDS + + # FRAME UPDATE BEGINS + trace_frame: + type: cala.nodes.traces.Tracer + params: + tol: 0.001 + max_iter: 100 + depends: + - traces: assets.traces + - footprints: assets.footprints + - frame: downsample.frame + - overlaps: assets.overlaps + pix_frame: + type: cala.nodes.pixel_stats.ingest_frame + depends: + - pixel_stats: assets.pix_stats + - frame: downsample.frame + - new_traces: trace_frame.latest_trace + - footprints: assets.footprints + comp_frame: + type: cala.nodes.component_stats.ingest_frame + depends: + - component_stats: assets.comp_stats + - frame: downsample.frame + - new_traces: trace_frame.latest_trace + footprints_frame: + type: cala.nodes.footprints.Footprinter + params: + bep: 0 + tol: 0.0001 + max_iter: 5 + ratio_lb: 0.10 + depends: + - footprints: assets.footprints + - pixel_stats: pix_frame.value + - component_stats: comp_frame.value + + residual: + type: cala.nodes.residual.Residuer + depends: + - frame: downsample.frame + - footprints: footprints_frame.footprints + - traces: assets.traces + - residuals: assets.residuals + # FRAME UPDATE ENDS + + # DETECT BEGINS + nmf: + type: cala.nodes.segment.SliceNMF + params: + min_frames: 100 + detect_thresh: 8.0 + reprod_tol: 0.005 + depends: + - residuals: residual.movie + - energy: residual.std + - detect_radius: size_est.radius + catalog: + type: cala.nodes.segment.Cataloger + params: + age_limit: 100 + shape_smooth_kwargs: + ksize: [ 5, 5 ] + sigmaX: 0 + trace_smooth_kwargs: + sigma: 2 + merge_threshold: 0.9 + val_threshold: 0.5 + cnt_threshold: 5 + depends: + - new_fps: nmf.new_fps + - new_trs: nmf.new_trs + - existing_fp: assets.footprints + - existing_tr: assets.traces + detect_update: + type: cala.nodes.segment.update_assets + depends: + - new_footprints: catalog.new_footprints + - new_traces: catalog.new_traces + - footprints: assets.footprints + - traces: assets.traces + - pixel_stats: assets.pix_stats + - component_stats: assets.comp_stats + - overlaps: assets.overlaps + - buffer: assets.buffer + # DETECT ENDS + + save_trace: + type: cala.nodes.io.save_asset + params: + target_epoch: 855999 # formula: 1000 * n_chunks - 1 + path: traces_fin/ + depends: + - asset: assets.traces + - curr_epoch: counter.idx + + save_shape: + type: cala.nodes.io.save_asset + params: + target_epoch: 855999 # formula: 1000 * n_chunks - 1 + path: footprints_fin/ + depends: + - asset: assets.footprints + - curr_epoch: counter.idx + +# merge: +# type: cala.nodes.merge.merge_existing +# params: +# merge_interval: 500 +# merge_threshold: 0.95 +# smooth_kwargs: +# sigma: 2 +# depends: +# - shapes: assets.footprints +# - traces: assets.traces +# - overlaps: assets.overlaps +# - trigger: detect_update.footprints +# merge_update: +# type: cala.nodes.segment.update_assets +# depends: +# - new_footprints: merge.footprints +# - new_traces: merge.traces +# - footprints: assets.footprints +# - traces: assets.traces +# - pixel_stats: assets.pix_stats +# - component_stats: assets.comp_stats +# - overlaps: assets.overlaps +# - buffer: assets.buffer diff --git a/tests/data/pipelines/with_src.yaml b/tests/data/pipelines/minian.yaml similarity index 95% rename from tests/data/pipelines/with_src.yaml rename to tests/data/pipelines/minian.yaml index 3595697f..48d1a779 100644 --- a/tests/data/pipelines/with_src.yaml +++ b/tests/data/pipelines/minian.yaml @@ -1,4 +1,4 @@ -noob_id: cala-with-movie +noob_id: with-minian noob_model: noob.tube.TubeSpecification noob_version: 0.1.1.dev118+g64d81b7 @@ -17,8 +17,9 @@ assets: type: cala.assets.Traces scope: runner params: - # zarr_path: traces/ + zarr_path: traces/ peek_size: 100 + flush_interval: 1000 pix_stats: type: cala.assets.PixStats scope: runner @@ -156,7 +157,7 @@ nodes: # DETECT BEGINS nmf: - type: cala.nodes.detect.SliceNMF + type: cala.nodes.segment.SliceNMF params: min_frames: 100 detect_thresh: 3.0 @@ -166,7 +167,7 @@ nodes: - energy: residual.std - detect_radius: size_est.radius catalog: - type: cala.nodes.detect.Cataloger + type: cala.nodes.segment.Cataloger params: age_limit: 100 shape_smooth_kwargs: @@ -183,7 +184,7 @@ nodes: - existing_fp: assets.footprints - existing_tr: assets.traces detect_update: - type: cala.nodes.detect.update_assets + type: cala.nodes.segment.update_assets depends: - new_footprints: catalog.new_footprints - new_traces: catalog.new_traces @@ -208,7 +209,7 @@ nodes: # - overlaps: assets.overlaps # - trigger: detect_update.footprints # merge_update: -# type: cala.nodes.detect.update_assets +# type: cala.nodes.segment.update_assets # depends: # - new_footprints: merge.footprints # - new_traces: merge.traces diff --git a/tests/data/pipelines/odl.yaml b/tests/data/pipelines/odl.yaml index b980e57a..79308e3b 100644 --- a/tests/data/pipelines/odl.yaml +++ b/tests/data/pipelines/odl.yaml @@ -15,6 +15,10 @@ assets: scope: runner traces: type: cala.assets.Traces + params: + zarr_path: traces/ + peek_size: 100 + flush_interval: 200 scope: runner pix_stats: type: cala.assets.PixStats @@ -107,7 +111,7 @@ nodes: # DETECT BEGINS nmf: - type: cala.nodes.detect.SliceNMF + type: cala.nodes.segment.SliceNMF params: min_frames: 10 detect_thresh: 1.0 @@ -117,7 +121,7 @@ nodes: - energy: residual.std - detect_radius: size_est.radius catalog: - type: cala.nodes.detect.Cataloger + type: cala.nodes.segment.Cataloger params: age_limit: 100 shape_smooth_kwargs: @@ -134,7 +138,7 @@ nodes: - existing_fp: assets.footprints - existing_tr: assets.traces detect_update: - type: cala.nodes.detect.update_assets + type: cala.nodes.segment.update_assets depends: - new_footprints: catalog.new_footprints - new_traces: catalog.new_traces @@ -159,7 +163,7 @@ nodes: # - overlaps: assets.overlaps # - trigger: detect_update.footprints # merge_update: - # type: cala.nodes.detect.update_assets + # type: cala.nodes.segment.update_assets # depends: # - new_footprints: merge.footprints # - new_traces: merge.traces diff --git a/tests/test_assets.py b/tests/test_assets.py index 2897eb77..7b6b9b4a 100644 --- a/tests/test_assets.py +++ b/tests/test_assets.py @@ -1,7 +1,13 @@ -import os +""" +Since assets without zarr operations is very straightforward, +this test file is mainly focused on testing zarr-integrated assets, +with the exception of Buffer. + +""" + from datetime import datetime -from pathlib import Path +import numpy as np import pytest import xarray as xr @@ -9,71 +15,160 @@ from cala.models import AXIS -@pytest.fixture -def path() -> Path: - return Path("assets") - +@pytest.mark.parametrize("peek_size", [30, 49, 50, 51, 70]) +def test_array_assignment(tmp_path, four_connected_cells, peek_size): + """ + 1. array should get assigned to memory if smaller than or equal to peek_size + 2. array should get split to memory and zarr_path at peek_size if + array is larger than peek_size -def test_assign_zarr(path, four_connected_cells): - zarr_traces = Traces(zarr_path=path, peek_size=100) + """ traces = four_connected_cells.traces.array + n_frames = traces.sizes[AXIS.frames_dim] # 50 frames + + zarr_traces = Traces( + zarr_path=tmp_path, peek_size=peek_size, flush_interval=max(1000, peek_size) + ) zarr_traces.array = traces - print(os.listdir(zarr_traces.zarr_path)) - assert zarr_traces.array_ is None # not in memory - assert zarr_traces.array.equals(traces) + assert zarr_traces.array_.sizes[AXIS.frames_dim] == min(n_frames, peek_size) + assert zarr_traces.load_zarr().sizes[AXIS.frames_dim] == max(0, n_frames - peek_size) + +@pytest.mark.parametrize("peek_size", [30, 50, 70]) +def test_array_peek(tmp_path, four_connected_cells, peek_size): + """ + .array property returns correctly for peek_size smaller, equal, larger + than the saved array. -def test_from_array(four_connected_cells, path): + """ traces = four_connected_cells.traces.array - zarr_traces = Traces.from_array(traces, path, peek_size=four_connected_cells.n_frames) - assert zarr_traces.array_ is None - assert zarr_traces.array.equals(traces) + n_frames = traces.sizes[AXIS.frames_dim] # 50 frames + + zarr_traces = Traces( + zarr_path=tmp_path, peek_size=peek_size, flush_interval=max(1000, peek_size) + ) + zarr_traces.array = traces + + assert zarr_traces.array.sizes[AXIS.frames_dim] == min(peek_size, n_frames) + +@pytest.mark.parametrize("peek_size", [30, 50, 70]) +def test_flush_zarr(four_connected_cells, tmp_path, peek_size): + """ + _flush_zarr method flushes epochs older than peek_size to zarr_path, + and leaves only the newest epochs in memory. -@pytest.mark.parametrize("peek_shift", [-1, 0, 1]) -def test_peek(four_connected_cells, path, peek_shift): + """ traces = four_connected_cells.traces.array - zarr_traces = Traces.from_array( - traces, path, peek_size=four_connected_cells.n_frames + peek_shift + n_frames = traces.sizes[AXIS.frames_dim] # 50 frames + + zarr_traces = Traces( + zarr_path=tmp_path, peek_size=peek_size, flush_interval=max(1000, peek_size) ) - if peek_shift >= 0: - assert zarr_traces.array.equals(traces) - else: - with pytest.raises(AssertionError): - assert zarr_traces.array.equals(traces) + # need to initialize zarr file first to append with _flush_zarr + zarr_traces.array = traces[:, :0] + zarr_traces.array_ = traces # add 50 frames + + zarr_traces._flush_zarr() + # only peek_size left in memory + assert zarr_traces.array_.sizes[AXIS.frames_dim] == min(n_frames, peek_size) + # the rest is in zarr + assert zarr_traces.load_zarr().sizes[AXIS.frames_dim] == max(0, n_frames - peek_size) -def test_ingest_frame(path, four_connected_cells): + +@pytest.mark.parametrize("peek_size, flush_interval", [(30, 70)]) +def test_zarr_append_frame(four_connected_cells, tmp_path, peek_size, flush_interval): + """ + Test that when in-memory array size hits flush_interval, old epochs + get flushed. + + """ traces = four_connected_cells.traces.array - old_traces = traces.isel({AXIS.frames_dim: slice(None, -1)}) - zarr_traces = Traces.from_array(old_traces, path, peek_size=four_connected_cells.n_frames) - new_traces = four_connected_cells.traces.array.isel({AXIS.frames_dim: [-1]}) + n_frames = traces.sizes[AXIS.frames_dim] # 50 frames - zarr_traces.update(new_traces, append_dim=AXIS.frames_dim) - # new_traces.to_zarr(zarr_traces.zarr_path, append_dim=AXIS.frames_dim) + zarr_traces = Traces(zarr_path=tmp_path, peek_size=peek_size, flush_interval=flush_interval) + zarr_traces.array = traces[:, :0] # just initializing zarr + + # array smaller than flush_interval. does not flush. + zarr_traces.append(traces, dim=AXIS.frames_dim) + assert zarr_traces.array_.sizes[AXIS.frames_dim] == n_frames + + # array larger than flush_interval. flushes down to peek_size. + zarr_traces.append(traces, dim=AXIS.frames_dim) + assert zarr_traces.array_.sizes[AXIS.frames_dim] == peek_size - assert zarr_traces.array.equals(traces) +@pytest.mark.parametrize("flush_interval", [30]) +def test_zarr_append_component(four_connected_cells, tmp_path, flush_interval): + """ + Test that when adding components, + (a) it gets appropriately divided between in-memory and zarr arrays. + (b) in-memory and zarr arrays can be concatenated together afterward. -def test_ingest_component(four_connected_cells, path): + """ traces = four_connected_cells.traces.array - old_traces = traces.isel({AXIS.component_dim: slice(None, -1)}) - zarr_traces = Traces.from_array(old_traces, path, peek_size=four_connected_cells.n_frames) - new_traces = four_connected_cells.traces.array.isel({AXIS.component_dim: [-1]}) - zarr_traces.update(new_traces, append_dim=AXIS.component_dim) - # new_traces.to_zarr(zarr_traces.zarr_path, append_dim=AXIS.component_dim) + zarr_traces = Traces(zarr_path=tmp_path, peek_size=20, flush_interval=flush_interval) + zarr_traces.array = traces[:-1, :] # forgot the last component! Also, got flushed. + zarr_traces.append(traces[-1:, :], dim=AXIS.component_dim) # appendee needs to be 2D + assert zarr_traces.array_[AXIS.component_dim].equals(traces[AXIS.component_dim]) + assert zarr_traces.load_zarr()[AXIS.component_dim].equals(traces[AXIS.component_dim]) + result = xr.concat([zarr_traces.load_zarr(), zarr_traces.array_], dim=AXIS.frames_dim).compute() + assert result.equals(traces) + + +@pytest.mark.parametrize("flush_interval", [30]) +def test_flush_after_deprecated(four_connected_cells, tmp_path, flush_interval) -> None: + traces = four_connected_cells.traces.array + peek_size = 20 + zarr_traces = Traces(zarr_path=tmp_path, peek_size=peek_size, flush_interval=flush_interval) + zarr_traces.array = traces + + merged_ids = zarr_traces.array[AXIS.id_coord].values[0] + intact_mask = ~np.isin(zarr_traces.array[AXIS.id_coord].values, merged_ids) + zarr_traces.keep(intact_mask) + zarr_traces.append(traces[intact_mask], dim=AXIS.frames_dim) + + assert zarr_traces.full_array().equals( + xr.concat([traces] * 2, dim=AXIS.frames_dim)[intact_mask] + ) + + +def test_from_array(four_connected_cells): + """ + .from_array method can correctly reproduce the array with .array + + """ + traces = four_connected_cells.traces.array + zarr_traces = Traces.from_array(traces) assert zarr_traces.array.equals(traces) -def test_overwrite(four_connected_cells, four_separate_cells, path): - conn_traces = four_connected_cells.traces.array - zarr_traces = Traces.from_array(conn_traces, path, peek_size=four_connected_cells.n_frames) +@pytest.mark.parametrize("peek_size", [30, 50, 70]) +def test_sizes(four_connected_cells, tmp_path, peek_size): + """ + The sizes property of the asset combines the sizes + of the in-memory array and the zarr array. + + """ + traces = four_connected_cells.traces.array + + zarr_traces = Traces( + zarr_path=tmp_path, peek_size=peek_size, flush_interval=max(1000, peek_size) + ) + zarr_traces.array = traces + + assert zarr_traces.sizes == traces.sizes + + +@pytest.mark.xfail +def test_overwrite(four_connected_cells, four_separate_cells): + """ + test that zarr array can get overwritten. - sep_traces = four_separate_cells.traces.array - zarr_traces.array = sep_traces - assert zarr_traces.array.equals(sep_traces) + """ # two cases of init: diff --git a/tests/test_iter/test_catalog.py b/tests/test_iter/test_catalog.py index b8d045b1..9143b9bb 100644 --- a/tests/test_iter/test_catalog.py +++ b/tests/test_iter/test_catalog.py @@ -4,8 +4,8 @@ from noob.node import NodeSpecification from cala.assets import AXIS, Buffer, Footprints, Traces -from cala.nodes.detect import Cataloger, SliceNMF -from cala.nodes.detect.catalog import _register +from cala.nodes.segment import Cataloger, SliceNMF +from cala.nodes.segment.catalog import _register from cala.testing.catalog_depr import CatalogerDepr from cala.testing.util import expand_boundary, split_footprint @@ -15,7 +15,7 @@ def slice_nmf(): return SliceNMF.from_specification( spec=NodeSpecification( id="test_slice_nmf", - type="cala.nodes.detect.SliceNMF", + type="cala.nodes.segment.SliceNMF", params={"min_frames": 10, "detect_thresh": 1, "reprod_tol": 0.001}, ) ) @@ -26,7 +26,7 @@ def cataloger(): return Cataloger.from_specification( spec=NodeSpecification( id="test", - type="cala.nodes.detect.Cataloger", + type="cala.nodes.segment.Cataloger", params={ "age_limit": 100, "trace_smooth_kwargs": {"sigma": 2}, diff --git a/tests/test_iter/test_slice_nmf.py b/tests/test_iter/test_slice_nmf.py index 5f3f773b..10a38b55 100644 --- a/tests/test_iter/test_slice_nmf.py +++ b/tests/test_iter/test_slice_nmf.py @@ -5,8 +5,8 @@ from sklearn.decomposition import NMF from cala.assets import AXIS, Buffer -from cala.nodes.detect import SliceNMF -from cala.nodes.detect.slice_nmf import rank1nmf +from cala.nodes.segment import SliceNMF +from cala.nodes.segment.slice_nmf import rank1nmf from cala.testing.util import assert_scalar_multiple_arrays @@ -15,7 +15,7 @@ def slice_nmf(): return SliceNMF.from_specification( spec=NodeSpecification( id="test_slice_nmf", - type="cala.nodes.detect.SliceNMF", + type="cala.nodes.segment.SliceNMF", params={"min_frames": 10, "detect_thresh": 1, "reprod_tol": 0.001}, ) ) @@ -30,7 +30,7 @@ def test_process(slice_nmf, single_cell): if new_component: new_fp, new_tr = new_component else: - raise AssertionError("Failed to detect a new component") + raise AssertionError("Failed to segment a new component") for new, old in zip([new_fp[0], new_tr[0]], [single_cell.footprints, single_cell.traces]): assert_scalar_multiple_arrays(new.array.as_numpy(), old.array.as_numpy()) @@ -40,7 +40,7 @@ def test_chunks(single_cell): nmf = SliceNMF.from_specification( spec=NodeSpecification( id="test_slice_nmf", - type="cala.nodes.detect.SliceNMF", + type="cala.nodes.segment.SliceNMF", params={"min_frames": 10, "detect_thresh": 1, "reprod_tol": 0.001}, ) ) @@ -50,7 +50,7 @@ def test_chunks(single_cell): detect_radius=10, ) if not fpts or not trcs: - raise AssertionError("Failed to detect a new component") + raise AssertionError("Failed to segment a new component") factors = [trc.array.data.max() for trc in trcs] fpt_arr = xr.concat([f.array * m for f, m in zip(fpts, factors)], dim="component") diff --git a/tests/test_iter/test_traces.py b/tests/test_iter/test_traces.py index bd5057b6..a84c1aa5 100644 --- a/tests/test_iter/test_traces.py +++ b/tests/test_iter/test_traces.py @@ -18,15 +18,29 @@ def frame_update() -> Node: ) +@pytest.mark.parametrize( + "zarr_setup", + [ + {"zarr_path": None, "peek_size": 50, "flush_interval": None}, + {"zarr_path": "tmp", "peek_size": 30, "flush_interval": 40}, + ], +) @pytest.mark.parametrize("toy", ["four_separate_cells", "four_connected_cells"]) -def test_update_traces(frame_update, toy, request) -> None: +def test_ingest_frame(frame_update, toy, zarr_setup, request, tmp_path) -> None: + """ + Frame ingestion step adds new traces to the Traces instance + that matches true brightness, overlapping and non-overlapping alike. + + """ toy = request.getfixturevalue(toy) + zarr_setup["zarr_path"] = tmp_path if zarr_setup["zarr_path"] else None xray = Node.from_specification( spec=NodeSpecification(id="test", type="cala.nodes.overlap.initialize") ) - traces = Traces.from_array(toy.traces.array.isel({AXIS.frames_dim: slice(None, -1)})) + traces = Traces(array_=None, **zarr_setup) + traces.array = toy.traces.array.isel({AXIS.frames_dim: slice(None, -1)}) frame = Frame.from_array(toy.make_movie().array.isel({AXIS.frames_dim: -1})) overlap = xray.process(overlaps=Overlaps(), footprints=toy.footprints) @@ -46,21 +60,41 @@ def comp_update() -> Node: ) +@pytest.mark.parametrize( + "zarr_setup", + [ + {"zarr_path": None, "peek_size": 40, "flush_interval": None}, + {"zarr_path": "tmp", "peek_size": 30, "flush_interval": 40}, + ], +) @pytest.mark.parametrize("toy", ["four_separate_cells"]) -def test_ingest_component(comp_update, toy, request) -> None: +def test_ingest_component(comp_update, toy, request, zarr_setup, tmp_path) -> None: + """ + *New component is always the same length as peek_size. + + I can add components that... + 1. is single and new + 2. is single and replacing + 3. is multiple + 4. is multiple and replacing + + """ toy = request.getfixturevalue(toy) + zarr_setup["zarr_path"] = tmp_path if zarr_setup["zarr_path"] else None - traces = Traces.from_array(toy.traces.array.isel({AXIS.component_dim: slice(None, -1)})) + traces = Traces(array_=None, **zarr_setup) + traces.array = toy.traces.array_.isel({AXIS.component_dim: slice(None, -1)}) - new_traces = toy.traces.array.isel( - {AXIS.component_dim: [-1], AXIS.frames_dim: slice(-40, None)} + new_traces = toy.traces.array_.isel( + {AXIS.component_dim: [-1], AXIS.frames_dim: slice(-zarr_setup["peek_size"], None)} ) new_traces.attrs["replaces"] = ["cell_0"] - result = comp_update.process(traces, Traces.from_array(new_traces)) expected = toy.traces.array.drop_sel({AXIS.component_dim: 0}) - expected.loc[{AXIS.component_dim: -1, AXIS.frames_dim: slice(None, 10)}] = np.nan + expected.loc[ + {AXIS.component_dim: -1, AXIS.frames_dim: slice(None, -zarr_setup["peek_size"])} + ] = np.nan - assert result.array.equals(expected) + assert result.full_array().equals(expected) diff --git a/tests/test_prep/test_motion.py b/tests/test_prep/test_motion.py index 6d47a7ea..bff48895 100644 --- a/tests/test_prep/test_motion.py +++ b/tests/test_prep/test_motion.py @@ -64,37 +64,3 @@ def test_motion_estimation(params) -> None: # Allow 1 pixel absolute tolerance np.testing.assert_allclose(estimate, expected[1:], atol=1.0) - - -# def test_motion_with_movie(): -# """ -# For testing how well the motion correction performs with real movie -# """ -# gen = stream( -# [ -# "cala/msCam1.avi", -# "cala/msCam2.avi", -# "cala/msCam3.avi", -# "cala/msCam4.avi", -# "cala/msCam5.avi", -# ] -# ) -# -# fourcc = cv2.VideoWriter_fourcc(*"mp4v") -# out = cv2.VideoWriter("mc_test.avi", fourcc, 24.0, (752, 960)) -# -# stab = Anchor() -# -# for idx, arr in enumerate(gen): -# frame = package_frame(arr, idx) -# frame = blur(frame, method="median", kwargs={"ksize": 3}) -# frame = butter(frame, {}) -# frame = remove_mean(frame, orient="both") -# matched = stab.stabilize(frame) -# -# combined = np.concat([frame.array.values, matched.array.values], axis=0) -# -# frame_bgr = cv2.cvtColor(combined.astype(np.uint8), cv2.COLOR_GRAY2BGR) -# out.write(frame_bgr) -# -# out.release()