diff --git a/src/cedalion/imagereco/forward_model.py b/src/cedalion/imagereco/forward_model.py
index 80fc952d..5f7a5ce9 100644
--- a/src/cedalion/imagereco/forward_model.py
+++ b/src/cedalion/imagereco/forward_model.py
@@ -1124,13 +1124,16 @@ def compute_stacked_sensitivity(sensitivity: xr.DataArray):
 
 
 def apply_inv_sensitivity(
-    od: cdt.NDTimeSeries, inv_sens: xr.DataArray
+    od: cdt.NDTimeSeries, inv_sens: xr.DataArray, chunk: bool = True,
 ) -> tuple[xr.DataArray, xr.DataArray]:
     """Apply the inverted sensitivity matrix to optical density data.
 
     Args:
         od: time series of optical density data
         inv_sens: the inverted sensitivity matrix
+        chunk: if True (default), the matrix multiplication is applied piecewise
+            whenever more than 1000 time samples have to be converted. If False,
+            chunking is always skipped.
 
     Returns:
         Two DataArrays for the brain and scalp with the reconcstructed time series per
@@ -1142,7 +1145,19 @@ def apply_inv_sensitivity(
     od_stacked = od.stack({"flat_channel": ["wavelength", "channel"]})
     od_stacked = od_stacked.pint.dequantify()
 
-    delta_conc = inv_sens @ od_stacked
+    # for image recon we have time series data with either a "time" or a "reltime" dimension
+    sample_dim = next((d for d in ["time", "reltime"] if d in od_stacked.dims), None)
+
+    # if od_stacked has more than 1000 samples along sample_dim, chunk it
+    if (od_stacked.sizes[sample_dim] > 1000) and chunk:
+        delta_conc = xrutils.chunked_eff_xr_matmult(
+            od_stacked,
+            inv_sens,
+            contract_dim="flat_channel",
+            sample_dim=sample_dim,
+            chunksize=1000)
+    else:
+        delta_conc = inv_sens @ od_stacked
 
     # Construct a multiindex for dimension flat_vertex from chromo and vertex.
     # Afterwards use this multiindex to unstack flat_vertex. The resulting array
diff --git a/src/cedalion/xrutils.py b/src/cedalion/xrutils.py
index 8b06183b..8e37ab66 100644
--- a/src/cedalion/xrutils.py
+++ b/src/cedalion/xrutils.py
@@ -5,6 +5,9 @@
 import numpy as np
 import pint
 import xarray as xr
+import os
+import shutil
+import tempfile
 
 
 def pinv(array: xr.DataArray) -> xr.DataArray:
@@ -231,3 +234,114 @@ def unit_stripping_is_error(is_error : bool = True):
             if f[0] =="error" and f[2] == pint.errors.UnitStrippedWarning:
                 del warnings.filters[i]
                 break
+
+
+def chunked_eff_xr_matmult(
+    A: xr.DataArray,
+    B: xr.DataArray,
+    contract_dim: str,
+    sample_dim: str,
+    chunksize: int = 5000,
+    tmpdir: str | None = None
+) -> xr.DataArray:
+    """Chunked, low-memory matrix multiplication of two DataArrays.
+
+    A is contracted with B over `contract_dim`. To avoid memory issues, A is split
+    into chunks along `sample_dim`, each partial result is streamed to disk, and
+    the full result is rebuilt from the stored chunks afterwards.
+
+    Args:
+        A: DataArray to multiply (dims include `contract_dim` and `sample_dim` among others)
+        B: DataArray to multiply with (dims include `contract_dim` and others)
+        contract_dim: name of the dimension to contract (e.g. "flat_channel")
+        sample_dim: name of the dimension along which to chunk (e.g. "time")
+        chunksize: max size of each chunk along dimension `sample_dim`
+        tmpdir: optional path to a temp directory (auto-created and removed if None)
+
+    Returns:
+        A new DataArray containing the result of the matrix multiplication over
+        `contract_dim`, with coords, dims, and attrs preserved. Should yield the same
+        result as `xr.dot(A, B, dims=[contract_dim])`, but at increased speed and with
+        a much lower memory footprint.
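+
+    Example:
+        A minimal, illustrative sketch; `od_stacked` and `inv_sens` stand in for the
+        arrays used by `apply_inv_sensitivity` and are assumptions, not requirements::
+
+            delta_conc = chunked_eff_xr_matmult(
+                od_stacked,
+                inv_sens,
+                contract_dim="flat_channel",
+                sample_dim="time",
+                chunksize=1000,
+            )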
+
+    Initial Contributors:
+        - Alexander von Lühmann | vonluehmann@tu-berlin.de | 2025
+    """
+    # Total samples & number of chunks
+    N = A.sizes[sample_dim]
+    n_chunks = int(np.ceil(N / chunksize))
+
+    # Build a "shell" result for metadata by doing the dot on the first sample
+    A0 = A.isel({sample_dim: slice(0, 1)})
+    Xres = xr.dot(B, A0, dims=[contract_dim])
+
+    # Prepare for raw numpy multiply
+    dims_B_not = [d for d in B.dims if d != contract_dim]
+    dims_A_not = [d for d in A.dims if d != contract_dim]
+    B_mat = B.transpose(*dims_B_not, contract_dim).values
+    A2 = A.transpose(contract_dim, *dims_A_not)
+
+    # Create temp directory
+    cleanup = False
+    if tmpdir is None:
+        tmpdir = tempfile.mkdtemp()
+        cleanup = True
+    else:
+        os.makedirs(tmpdir, exist_ok=True)
+
+    print(f"Large Matrix Multiplication: Processing {n_chunks} chunks...")
+
+    # Stream-compute each chunk and save it to disk
+    file_paths = []
+    for i in range(n_chunks):
+        start = i * chunksize
+        stop = min((i + 1) * chunksize, N)
+        A_chunk = A2.isel({sample_dim: slice(start, stop)})
+        C_chunk = B_mat.dot(A_chunk.values)  # raw array (out_dim, chunk_len, ...)
+        fn = os.path.join(tmpdir, f"chunk_{i:04d}.npy")
+        np.save(fn, C_chunk)
+        file_paths.append(fn)
+        del A_chunk, C_chunk
+        print(f"Chunk {i+1}/{n_chunks} done.")
+
+    # Read back & concatenate along the sample axis
+    arrs = [np.load(fp) for fp in sorted(file_paths)]
+    axis = Xres.get_axis_num(sample_dim)
+    full_arr = np.concatenate(arrs, axis=axis)
+
+    if cleanup:
+        shutil.rmtree(tmpdir)
+
+    # Create result coordinates: non-sample coords from Xres, sample coords from A
+    coords = {
+        name: coord
+        for name, coord in Xres.coords.items()
+        if sample_dim not in coord.dims
+    }
+    sample_coords = {
+        name: coord
+        for name, coord in A.coords.items()
+        if sample_dim in coord.dims
+    }
+    coords.update(sample_coords)
+
+    # Rebuild the DataArray using the Xres metadata
+    result = xr.DataArray(
+        data=full_arr,
+        dims=Xres.dims,
+        coords=coords,
+        attrs=Xres.attrs,
+    )
+    return result