Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
61fcf32
format: frontend css
raymondwjang Nov 10, 2025
6fdcb71
feat: add downsample
raymondwjang Nov 11, 2025
c0440d4
tests: benchmarking moved out of tests
raymondwjang Nov 11, 2025
fbc4065
debug: retain older traces than peek_size
raymondwjang Nov 11, 2025
eb1ba05
feat: asset save / natsorting videos
raymondwjang Nov 11, 2025
72bebad
feat: long_recording.yaml
raymondwjang Nov 11, 2025
f74853e
init_order
raymondwjang Nov 11, 2025
dcc48b2
format: ruff
raymondwjang Nov 11, 2025
fa6301a
debug: handle sparse in zarr
raymondwjang Nov 11, 2025
352ff51
stop fucking switching the import order
raymondwjang Nov 11, 2025
1e1827c
feat: trace interval flushing to zarr functionality
raymondwjang Nov 11, 2025
b4eb5ad
feat: trace mostly supporting new grammar (except zarr component update)
raymondwjang Nov 11, 2025
31d3638
feat: implement zarr and in-memory caching in traces
raymondwjang Nov 11, 2025
c13a3b6
format: ruff
raymondwjang Nov 11, 2025
acd7cb4
debug: new component concat
raymondwjang Nov 11, 2025
f2c0c0b
debug: new component concat
raymondwjang Nov 11, 2025
c4b8b7b
debug: update noob for gather compatibility
raymondwjang Nov 11, 2025
2c0e10c
debug: asset zarr saving
raymondwjang Nov 11, 2025
f06ce84
test: motion correction crisp score
raymondwjang Nov 12, 2025
4f9b08e
debug: allow non-zarr
raymondwjang Nov 12, 2025
39bf41d
feat: continue for gui failure
raymondwjang Nov 12, 2025
17b2ccf
format: ruff
raymondwjang Nov 21, 2025
0724bd1
rename: detect to segment
raymondwjang Nov 21, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ repos:
name: pdm-format
entry: pdm format
language: system
types: [python]
types: [ python ]
pass_filenames: false
always_run: true
# - repo: https://github.com/pre-commit/mirrors-mypy
Expand Down
17 changes: 14 additions & 3 deletions pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ dependencies = [
"pyyaml>=6.0.2",
"typer>=0.15.3",
"xarray-validate>=0.0.2",
"noob @ git+https://github.com/miniscope/noob.git",
"noob @ git+https://github.com/miniscope/noob.git@scheduler-optimize",
"natsort>=8.4.0",
]
keywords = [
"pipeline",
Expand Down
251 changes: 174 additions & 77 deletions src/cala/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import shutil
from copy import deepcopy
from pathlib import Path
from typing import Any, ClassVar, Self, TypeVar
from typing import ClassVar, Self, TypeVar

import numpy as np
import xarray as xr
Expand All @@ -22,6 +22,8 @@ class Asset(BaseModel):
validate_schema: bool = False
array_: AssetType = None
sparsify: ClassVar[bool] = False
zarr_path: Path | None = None
"""relative to config.user_data_dir"""
_entity: ClassVar[Entity]

model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True)
Expand All @@ -46,6 +48,12 @@ def from_array(cls, array: xr.DataArray) -> Self:

def reset(self) -> None:
    """Drop the in-memory array and delete the on-disk zarr store, if any.

    When ``zarr_path`` is set, the directory it points to is removed
    recursively. A directory that does not exist is not treated as an error.
    """
    self.array_ = None
    if self.zarr_path:
        # ``suppress`` only takes effect as a context manager; a missing
        # directory is deliberately tolerated here.
        with contextlib.suppress(FileNotFoundError):
            shutil.rmtree(Path(self.zarr_path))

def __eq__(self, other: object) -> bool:
    """Compare two assets by their underlying arrays.

    Returns ``NotImplemented`` for anything that is not an ``Asset`` so that
    Python falls back to its default comparison instead of failing on a
    missing ``array`` attribute.
    """
    if not isinstance(other, Asset):
        return NotImplemented
    return self.array.equals(other.array)
Expand All @@ -61,6 +69,32 @@ def validate_array_schema(self) -> Self:

return self

@field_validator("zarr_path", mode="after")
@classmethod
def validate_zarr_path(cls, value: Path | None) -> Path | None:
    """Resolve a configured zarr path against the user directory.

    ``None`` passes through untouched. Any other value is joined onto
    ``config.user_dir`` and resolved; that directory is created if it is
    missing and then handed to ``clear_dir`` before the resolved path is
    returned.
    """
    if value is not None:
        resolved = (config.user_dir / value).resolve()
        resolved.mkdir(parents=True, exist_ok=True)
        clear_dir(resolved)
        return resolved
    return value

def load_zarr(self, isel_filter: dict = None, sel_filter: dict = None) -> xr.DataArray:
    """Open the zarr store at ``zarr_path`` and return it as a DataArray.

    ``isel_filter`` and ``sel_filter`` are forwarded to ``Dataset.isel`` and
    ``Dataset.sel`` respectively before the dataset is converted to a
    DataArray. The id and timestamp coordinates are cast to ``str`` on the
    way out.
    """
    dataset = xr.open_zarr(self.zarr_path)
    subset = dataset.isel(isel_filter).sel(sel_filter)
    # Collapse the dataset's "variable" axis down to its first entry.
    da = subset.to_dataarray().drop_vars(["variable"]).isel(variable=0)
    as_text = {
        AXIS.id_coord: da[AXIS.id_coord].astype(str),
        AXIS.timestamp_coord: da[AXIS.timestamp_coord].astype(str),
    }
    return da.assign_coords(as_text)


class Footprint(Asset):
_entity: ClassVar[Entity] = PrivateAttr(
Expand Down Expand Up @@ -109,104 +143,167 @@ class Footprints(Asset):


class Traces(Asset):
zarr_path: Path | None = None
"""relative to config.user_data_dir"""
peek_size: int | None = None
peek_size: int = None
"""How many epochs to return when called."""
flush_interval: int | None = None
"""How many epochs to wait until next flush"""

_deprecated: list[str] = PrivateAttr(default_factory=list)
"""
Traces(array=array, path=path) -> saves to zarr (should be set in this asset, and leave
untouched in nodes.)
Traces.array -> loads from zarr
Deprecated, or replaced component idx.
Since zarr does not support efficiently removing rows and columns,
there's no easy way to remove a column when a component has been
removed or replaced. Instead, we "mask" it with this "deprecated"
flag.

When arrays are called, these are filtered out. When new epochs are
added, these are added in with nan values.
"""

@property
def array(self) -> xr.DataArray:
peek_filter = {AXIS.frames_dim: slice(-self.peek_size, None)} if self.peek_size else None
return self.full_array(isel_filter=peek_filter)
_entity: ClassVar[Entity] = PrivateAttr(
Group(
name="trace-group",
member=Trace.entity(),
group_by=Dims.component,
checks=[is_non_negative],
allow_extra_coords=False,
)
)

@array.setter
def array(self, array: xr.DataArray) -> None:
if self.zarr_path:
if self.validate_schema:
array.validate.against_schema(self._entity.model)
array.to_zarr(self.zarr_path, mode="w") # need to make sure it can overwrite
else:
self.array_ = array
@model_validator(mode="after")
def flush_conditions(self) -> Self:
    """Check that the flushing-related settings are mutually consistent.

    ``zarr_path`` and ``flush_interval`` must be given together or not at
    all, and when flushing is enabled ``flush_interval`` must be strictly
    larger than ``peek_size``.

    Raises:
        ValueError: if only one of ``zarr_path`` / ``flush_interval`` is set,
            or if ``flush_interval`` is set while ``peek_size`` is missing
            or not smaller than it.
    """
    # Explicit raises rather than ``assert`` so the checks survive ``python -O``.
    if bool(self.flush_interval) != bool(self.zarr_path):
        raise ValueError(
            "zarr_path and flush_interval should either be both provided or neither."
        )
    if self.flush_interval and (
        self.peek_size is None or self.flush_interval <= self.peek_size
    ):
        raise ValueError(
            "flush_interval must be larger than peek_size. "
            f"Provided: {self.flush_interval = }, {self.peek_size = }"
        )
    return self

def reset(self) -> None:
self.array_ = None
@property
def sizes(self) -> dict[str, int]:
if self.zarr_path:
path = Path(self.zarr_path)
try:
shutil.rmtree(path)
except FileNotFoundError:
contextlib.suppress(FileNotFoundError)
total_size = {}
for key, val in self.array_.sizes.items():
if key == AXIS.frames_dim:
total_size[key] = val + self.load_zarr().sizes[key]
else:
total_size[key] = val
return total_size
else:
return self.array_.sizes

def full_array(self, isel_filter: dict = None, sel_filter: dict = None) -> xr.DataArray:
if self.zarr_path:
try:
return self.load_zarr(isel_filter=isel_filter, sel_filter=sel_filter).compute()
except FileNotFoundError:
pass
@property
def array(self) -> xr.DataArray:
return (
self.array_.isel(isel_filter).sel(sel_filter)
self.array_.isel({AXIS.frames_dim: slice(-self.peek_size, None)})
if self.array_ is not None
else self.array_
)

def load_zarr(self, isel_filter: dict = None, sel_filter: dict = None) -> xr.DataArray:
da = (
xr.open_zarr(self.zarr_path)
.isel(isel_filter)
.sel(sel_filter)
.to_dataarray()
.drop_vars(["variable"])
.isel(variable=0)
)
return da.assign_coords(
{
AXIS.id_coord: lambda ds: da[AXIS.id_coord].astype(str),
AXIS.timestamp_coord: lambda ds: da[AXIS.timestamp_coord].astype(str),
}
)
@array.setter
def array(self, array: xr.DataArray) -> None:
"""
In case zarr_path is defined, if array is larger than peek_size,
the epochs older than -peek_size gets flushed to zarr array.

def update(self, array: xr.DataArray, **kwargs: Any) -> None:
"""
if self.validate_schema:
array.validate.against_schema(self._entity.model)
array.to_zarr(self.zarr_path, **kwargs)
if self.zarr_path:
self.array_ = array.isel({AXIS.frames_dim: slice(-self.peek_size, None)})
array.isel({AXIS.frames_dim: slice(None, -self.peek_size)}).to_zarr(
self.zarr_path, mode="w"
)
else:
self.array_ = array

def append(self, array: xr.DataArray, dim: str) -> None:
    """
    Convenience method that grows the in-memory array and the zarr array
    together, since xarray syntax alone cannot append to a zarr array.

    Incoming arrays have to be 2-dimensional.

    Appending along the frames dimension extends the in-memory array and,
    when a zarr path is configured and the buffer has grown past
    ``flush_interval``, triggers a flush. Appending along the component
    dimension splits the incoming array by frames: the newest frames (as
    many as are currently held in memory) join the in-memory array, the
    older ones are appended to the zarr array.
    """
    if dim == AXIS.frames_dim:
        self.array_ = xr.concat([self.array_, array], dim=AXIS.frames_dim)
        if not self.zarr_path:
            return
        if self.array_.sizes[AXIS.frames_dim] > self.flush_interval:
            self._flush_zarr()
        return

    if dim != AXIS.component_dim:
        return

    if not self.zarr_path:
        self.array_ = xr.concat([self.array_, array], dim=dim, combine_attrs="drop")
        return

    n_in_memory = self.array_.sizes[AXIS.frames_dim]
    recent_part = array.isel({AXIS.frames_dim: slice(-n_in_memory, None)})
    older_part = array.isel({AXIS.frames_dim: slice(None, -n_in_memory)})
    self.array_ = xr.concat([self.array_, recent_part], dim=dim)
    older_part.to_zarr(self.zarr_path, append_dim=dim)

def _flush_zarr(self) -> None:
    """
    Flushes traces older than peek_size from the in-memory array to the zarr array.

    Everything except the newest ``peek_size`` frames of ``array_`` is appended
    to the zarr array along the frames dimension; ``array_`` is then trimmed
    down to those newest frames.

    Needs to append nans to deprecated components, since they get deleted
    in the in-memory array, but persist in the zarr array.

    Could do this much more elegantly by pre-allocating .array_
    """
    raw_zarr = self.load_zarr()
    to_flush = self.array_.isel({AXIS.frames_dim: slice(None, -self.peek_size)})
    if self._deprecated:
        zarr_ids = raw_zarr[AXIS.id_coord].values
        zarr_detects = raw_zarr[AXIS.detect_coord].values
        # True for zarr components whose id is not listed as deprecated.
        intact_mask = ~np.isin(zarr_ids, self._deprecated)
        n_flush = to_flush.sizes[AXIS.frames_dim]
        # NaN-filled block: one row per zarr component, one column per flushed frame.
        prealloc = xr.DataArray(
            np.full((raw_zarr.sizes[AXIS.component_dim], n_flush), np.nan),
            dims=[AXIS.component_dim, AXIS.frames_dim],
            coords={
                AXIS.id_coord: (AXIS.component_dim, zarr_ids),
                AXIS.detect_coord: (AXIS.component_dim, zarr_detects),
            },
        ).assign_coords(to_flush[AXIS.frames_dim].coords)
        # NOTE(review): presumably the in-memory components line up one-to-one,
        # in order, with the non-deprecated zarr components - confirm.
        prealloc.loc[intact_mask] = to_flush
        prealloc.to_zarr(self.zarr_path, append_dim=AXIS.frames_dim)
    else:
        to_flush.to_zarr(self.zarr_path, append_dim=AXIS.frames_dim)
    # Keep only the newest peek_size frames in memory.
    self.array_ = self.array_.isel({AXIS.frames_dim: slice(-self.peek_size, None)})

def keep(self, intact_mask: np.ndarray) -> None:
    """Retain only the in-memory entries selected by ``intact_mask``.

    When a zarr path is configured, the ids of the entries being dropped are
    first recorded in ``_deprecated``.
    """
    if self.zarr_path:
        dropped_ids = self.array_[AXIS.id_coord].values[~intact_mask]
        self._deprecated.extend(dropped_ids)
    self.array_ = self.array_[intact_mask]

@classmethod
def from_array(
cls, array: xr.DataArray, zarr_path: Path | str | None = None, peek_size: int | None = None
) -> "Traces":
new_cls = cls(zarr_path=zarr_path, peek_size=peek_size)
def from_array(cls, array: xr.DataArray) -> "Traces":
"""
This is only really used for typing / auto-validation purposes,
so we don't really have to worry about specifying the parameters.

"""
new_cls = cls(peek_size=array.sizes[AXIS.frames_dim])
new_cls.array = array
return new_cls

@field_validator("zarr_path", mode="after")
@classmethod
def validate_zarr_path(cls, value: Path | None) -> Path | None:
if value is None:
return value
zarr_dir = (config.user_dir / value).resolve()
zarr_dir.mkdir(parents=True, exist_ok=True)
clear_dir(zarr_dir)
return zarr_dir

@model_validator(mode="after")
def check_zarr_setting(self) -> "Traces":
def full_array(self, isel_filter: dict = None, sel_filter: dict = None) -> xr.DataArray:
if self.zarr_path:
assert self.peek_size, "peek_size must be set for zarr."
return self
raw_zarr = self.load_zarr(isel_filter, sel_filter)
zarr_ids = raw_zarr[AXIS.id_coord].values
intact_mask = ~np.isin(zarr_ids, self._deprecated)

_entity: ClassVar[Entity] = PrivateAttr(
Group(
name="trace-group",
member=Trace.entity(),
group_by=Dims.component,
checks=[is_non_negative],
allow_extra_coords=False,
)
)
return xr.concat([raw_zarr[intact_mask], self.array_], dim=AXIS.frames_dim).compute()
else:
return self.array_.isel(isel_filter).sel(sel_filter)


class Movie(Asset):
Expand Down
1 change: 1 addition & 0 deletions src/cala/gui/templates/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
<link rel="icon"
href="data:image/svg+xml,<svg xmlns=%22http://www.w3.org/2000/svg%22 viewBox=%220 0 100 100%22><text y=%22.9em%22 font-size=%2290%22>🎯</text></svg>">
<link href="../static/main.css" rel="stylesheet"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/water.css@2/out/water.css">
<link href="https://fonts.googleapis.com/icon?family=Material+Icons" rel="stylesheet">
</head>
<body>
Expand Down
Loading
Loading