Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion zetta_utils/layer/volumetric/cloudvol/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
_cv_cache: cachetools.LRUCache = cachetools.LRUCache(maxsize=2048)
_cv_cached: Dict[str, set] = {}

IN_MEM_CACHE_NUM_BYTES_PER_CV = 128 * 1024 ** 2
IN_MEM_CACHE_NUM_BYTES_PER_CV = 1 * 1024 ** 3


def _serialize_kwargs(kwargs: Dict[str, Any]) -> str:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import copy
import time
from typing import Any, Generic, Sequence, cast

Expand All @@ -10,7 +11,9 @@

from zetta_utils import builder, log, mazepa
from zetta_utils.geometry import Vec3D
from zetta_utils.layer.layer_base import Layer
from zetta_utils.layer.volumetric import VolumetricIndex, VolumetricLayer
from zetta_utils.mazepa import semaphore

from ..operation_protocols import StackableVolumetricOpProtocol

Expand Down Expand Up @@ -55,6 +58,31 @@ def get_input_resolution(self, dst_resolution: Vec3D) -> Vec3D:
def with_added_crop_pad(self, crop_pad: Vec3D[int]) -> StackedVolumetricOperations[P]:
return attrs.evolve(self, base_op=self.base_op.with_added_crop_pad(crop_pad))

def _prefetch_region(
    self,
    indices: Sequence[VolumetricIndex],
    **kwargs: Any,
) -> None:
    """
    Prefetch the supremum bounding box of all indices from source layers.

    Issues a single large read per source layer to warm the CloudVolume cache,
    so that subsequent per-chunk reads are cache hits instead of remote fetches.

    :param indices: Chunk indices whose union (supremum) defines the prefetch
        region. If empty, this is a no-op.
    :param kwargs: Operation kwargs; any value that is a ``Layer`` is treated
        as a source layer and read once over the whole region.
    """
    # Nothing to prefetch for an empty batch; the original code would have
    # raised IndexError on ``indices[0]`` here.
    if not indices:
        return

    # Union of all chunk indices — one bounding box covering the batch.
    sup_idx = indices[0]
    for idx in indices[1:]:
        sup_idx = sup_idx.supremum(idx)

    # Work on a copy: mutating ``resolution`` in place must not leak back
    # into the caller's index.
    sup_idx_input = copy.deepcopy(sup_idx)
    sup_idx_input.resolution = self.base_op.get_input_resolution(sup_idx.resolution)
    # Mirror the per-chunk read path, which pads with
    # ``Vec3D[int](*self.input_crop_pad)``; wrapping here keeps the prefetch
    # region consistent with what each chunk read will actually request.
    # NOTE(review): ``getattr`` fallback assumes ops without an
    # ``input_crop_pad`` attribute need no extra padding — confirm.
    input_crop_pad = Vec3D[int](*getattr(self.base_op, "input_crop_pad", (0, 0, 0)))
    sup_idx_input = sup_idx_input.padded(input_crop_pad)

    # One large read per source layer under the read semaphore; results are
    # discarded — the point is purely to populate the backend cache.
    with semaphore("read"):
        for v in kwargs.values():
            if isinstance(v, Layer):
                v.read_with_procs(sup_idx_input)

def __call__( # pylint: disable=keyword-arg-before-vararg,too-many-branches
self,
indices: Sequence[VolumetricIndex],
Expand Down Expand Up @@ -85,9 +113,16 @@ def __call__( # pylint: disable=keyword-arg-before-vararg,too-many-branches
f"length of indices ({len(indices)})"
)

# Read all data
# Prefetch the entire region to warm CloudVolume cache
prefetch_start = time.time()
self._prefetch_region(indices, **kwargs)
prefetch_time = time.time() - prefetch_start

# Read all data (should be cache hits after prefetch)
read_start = time.time()
data_list = [self.base_op.read(idx, *args, **kwargs) for idx in indices]
data_list = [
self.base_op.read(idx, *args, use_semaphore=False, **kwargs) for idx in indices
]
read_time = time.time() - read_start

# Stack tensors by key
Expand Down Expand Up @@ -119,23 +154,25 @@ def __call__( # pylint: disable=keyword-arg-before-vararg,too-many-branches

# Unstack and write results
write_start = time.time()
for i, (idx, dst) in enumerate(zip(indices, dsts_list)):
result: Any
if isinstance(batched_result, torch.Tensor):
result = batched_result[i]
elif isinstance(batched_result, np.ndarray):
result = batched_result[i]
else:
raise TypeError(
f"Function returned unsupported type: {type(batched_result)}. "
f"Only torch.Tensor and np.ndarray are supported."
)

self.base_op.write(idx, dst, result, *args, **kwargs)
with semaphore("write"):
for i, (idx, dst) in enumerate(zip(indices, dsts_list)):
result: Any
if isinstance(batched_result, torch.Tensor):
result = batched_result[i]
elif isinstance(batched_result, np.ndarray):
result = batched_result[i]
else:
raise TypeError(
f"Function returned unsupported type: {type(batched_result)}. "
f"Only torch.Tensor and np.ndarray are supported."
)

self.base_op.write(idx, dst, result, *args, use_semaphore=False, **kwargs)
write_time = time.time() - write_start

total_time = read_time + process_time + write_time
total_time = prefetch_time + read_time + process_time + write_time
logger.info(
f"StackedVolumetricOperations: Total time for {len(indices)} chunks: {total_time:.2f}s"
f" (read: {read_time:.2f}s, process: {process_time:.2f}s, write: {write_time:.2f}s)"
f" (prefetch: {prefetch_time:.2f}s, read: {read_time:.2f}s,"
f" process: {process_time:.2f}s, write: {write_time:.2f}s)"
)
Original file line number Diff line number Diff line change
Expand Up @@ -1302,6 +1302,7 @@ def _build_subchunkable_apply_flow( # pylint: disable=keyword-arg-before-vararg
Iteratively build the hierarchy of schemas
"""
for level in range(1, num_levels):
is_top_level = level == num_levels - 1
flow_schema = VolumetricApplyFlowSchema(
op=DelegatedSubchunkedOperation( # type:ignore #readability over typing here
flow_schema,
Expand All @@ -1314,7 +1315,7 @@ def _build_subchunkable_apply_flow( # pylint: disable=keyword-arg-before-vararg
roi_crop_pad=roi_crop_pads[-level - 1],
processing_blend_pad=processing_blend_pads[-level - 1],
processing_blend_mode=processing_blend_modes[-level - 1],
processing_gap=processing_gap if level == num_levels - 1 else None,
processing_gap=processing_gap if is_top_level else None,
intermediaries_dir=_path_join_if_not_none(
level_intermediaries_dirs[-level - 1], f"chunks_level_{level}"
),
Expand Down
110 changes: 79 additions & 31 deletions zetta_utils/mazepa_layer_processing/common/volumetric_apply_flow.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

import time
from contextlib import nullcontext

import itertools
import multiprocessing
from abc import ABC
Expand Down Expand Up @@ -56,9 +59,11 @@ def __call__(
dst: VolumetricBasedLayerProtocol,
idx: VolumetricIndex,
) -> None:
with semaphore("read"):
read_ctx = nullcontext() if src.backend.is_local else semaphore("read")
with read_ctx:
data = src[idx]
with semaphore("write"):
write_ctx = nullcontext() if dst.backend.is_local else semaphore("write")
with write_ctx:
dst[idx] = data


Expand Down Expand Up @@ -96,6 +101,7 @@ def __call__(
with suppress_type_checks():
if len(src_layers) == 0:
return
reduce_start = time.time()
res = np.zeros(
(dst.backend.num_channels, *red_idx.shape),
dtype=dst.backend.dtype,
Expand All @@ -105,17 +111,29 @@ def __call__(
for src_idx, layer in zip(src_idxs, src_layers):
intscn, subidx = src_idx.get_intersection_and_subindex(red_idx)
subidx_channels = (slice(0, res.shape[0]), *subidx)
with semaphore("read"):
read_ctx = nullcontext() if layer.backend.is_local else semaphore("read")
with read_ctx:
res[subidx_channels] = np.maximum(res[subidx_channels], layer[intscn])
else:
for src_idx, layer in zip(src_idxs, src_layers):
intscn, subidx = src_idx.get_intersection_and_subindex(red_idx)
subidx_channels = (slice(0, res.shape[0]), *subidx)
with semaphore("read"):
read_ctx = nullcontext() if layer.backend.is_local else semaphore("read")
with read_ctx:
res[subidx_channels] = layer[intscn]
reduce_time = time.time() - reduce_start
if np.any(res):
with semaphore("write"):
write_start = time.time()
write_ctx = nullcontext() if dst.backend.is_local else semaphore("write")
with write_ctx:
dst[red_idx] = res
write_time = time.time() - write_start
else:
write_time = 0.0
logger.info(
f"ReduceNaive: {len(src_layers)} sources, "
f"reduce: {reduce_time:.3f}s, write: {write_time:.3f}s"
)


def is_floating_point_dtype(dtype: np.dtype) -> bool:
Expand Down Expand Up @@ -145,6 +163,7 @@ def __call__(
with suppress_type_checks():
if len(src_layers) == 0:
return
reduce_start = time.time()
if not is_floating_point_dtype(dst.backend.dtype) and processing_blend_pad != Vec3D[
int
](0, 0, 0):
Expand All @@ -168,7 +187,8 @@ def __call__(
)
intscn, subidx = src_idx.get_intersection_and_subindex(red_idx)
subidx_channels = [slice(0, res.shape[0])] + list(subidx)
with semaphore("read"):
read_ctx = nullcontext() if layer.backend.is_local else semaphore("read")
with read_ctx:
if not is_floating_point_dtype(dst.backend.dtype):
# Temporarily convert integer cutout to float for rounding
res[tuple(subidx_channels)] = (
Expand All @@ -186,11 +206,22 @@ def __call__(
for src_idx, layer in zip(src_idxs, src_layers):
intscn, subidx = src_idx.get_intersection_and_subindex(red_idx)
subidx_channels = [slice(0, res.shape[0])] + list(subidx)
with semaphore("read"):
read_ctx = nullcontext() if layer.backend.is_local else semaphore("read")
with read_ctx:
res.numpy()[tuple(subidx_channels)] = layer[intscn]
reduce_time = time.time() - reduce_start
if res.any():
with semaphore("write"):
write_start = time.time()
write_ctx = nullcontext() if dst.backend.is_local else semaphore("write")
with write_ctx:
dst[red_idx] = res
write_time = time.time() - write_start
else:
write_time = 0.0
logger.info(
f"ReduceByWeightedSum: {len(src_layers)} sources, "
f"reduce: {reduce_time:.3f}s, write: {write_time:.3f}s"
)


@cachetools.cached(_weights_cache)
Expand Down Expand Up @@ -490,6 +521,7 @@ def _get_temp_dst(
self.dst_resolution,
),
enforce_chunk_aligned_writes=False,
overwrite_partial_chunks=True,
allow_cache=allow_cache,
use_compression=False,
)
Expand Down Expand Up @@ -889,30 +921,46 @@ def _flow_with_checkerboarding( # pylint: disable=too-many-locals
)
e.args = (error_str + e.args[0],)
raise e
if not self.max_reduction_chunk_size_final >= dst.backend.get_chunk_size(
self.dst_resolution
):
raise ValueError(
"`max_reduction_chunk_size` (which defaults to `processing_chunk_size` when"
" not specified)` must be at least as large as the `dst` layer's"
f" chunk size; received {self.max_reduction_chunk_size_final}, which is"
f" smaller than {dst.backend.get_chunk_size(self.dst_resolution)}"
)
reduction_chunker = VolumetricIndexChunker(
chunk_size=dst.backend.get_chunk_size(self.dst_resolution),
resolution=self.dst_resolution,
max_superchunk_size=self.max_reduction_chunk_size_final,
)
logger.debug(
f"Breaking {idx} into reduction chunks with checkerboarding"
f" with {reduction_chunker}. Processing chunks will use the padded index"
f" {idx.padded(self.roi_crop_pad)} and be chunked with {self.processing_chunker}."
)
stride_start_offset = dst.backend.get_voxel_offset(self.dst_resolution)
red_chunks = reduction_chunker(idx, mode="exact", stride_start_offset=stride_start_offset)
red_shape = reduction_chunker.get_shape(
idx, mode="exact", stride_start_offset=stride_start_offset
reduce_whole_roi = (
not dst.backend.enforce_chunk_aligned_writes
and dst.backend.overwrite_partial_chunks
and idx.shape <= self.max_reduction_chunk_size_final
)
if reduce_whole_roi:
red_chunks = [idx]
red_shape = Vec3D[int](1, 1, 1)
logger.debug(
f"Reducing entire ROI {idx} as a single chunk. Processing chunks will"
f" use the padded index {idx.padded(self.roi_crop_pad)} and be chunked"
f" with {self.processing_chunker}."
)
else:
if not self.max_reduction_chunk_size_final >= dst.backend.get_chunk_size(
self.dst_resolution
):
raise ValueError(
"`max_reduction_chunk_size` (which defaults to `processing_chunk_size` when"
" not specified)` must be at least as large as the `dst` layer's"
f" chunk size; received {self.max_reduction_chunk_size_final}, which is"
f" smaller than {dst.backend.get_chunk_size(self.dst_resolution)}"
)
reduction_chunker = VolumetricIndexChunker(
chunk_size=dst.backend.get_chunk_size(self.dst_resolution),
resolution=self.dst_resolution,
max_superchunk_size=self.max_reduction_chunk_size_final,
)
logger.debug(
f"Breaking {idx} into reduction chunks with checkerboarding"
f" with {reduction_chunker}. Processing chunks will use the padded index"
f" {idx.padded(self.roi_crop_pad)} and be chunked with {self.processing_chunker}."
)
stride_start_offset = dst.backend.get_voxel_offset(self.dst_resolution)
red_chunks = reduction_chunker(
idx, mode="exact", stride_start_offset=stride_start_offset
)
red_shape = reduction_chunker.get_shape(
idx, mode="exact", stride_start_offset=stride_start_offset
)
(
tasks,
red_chunks_task_idxs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def read( # pylint: disable=keyword-arg-before-vararg
self,
idx: VolumetricIndex,
*args: P.args,
use_semaphore: bool = True,
**kwargs: P.kwargs,
) -> dict[str, Any]:
"""Read all source data for this operation."""
Expand All @@ -88,17 +89,19 @@ def read( # pylint: disable=keyword-arg-before-vararg
idx_input.resolution = self.get_input_resolution(idx.resolution)
idx_input_padded = idx_input.padded(Vec3D[int](*self.input_crop_pad))

with semaphore("read"):
task_kwargs = _process_callable_kwargs(idx_input_padded, kwargs)
if use_semaphore:
with semaphore("read"):
return _process_callable_kwargs(idx_input_padded, kwargs)
return _process_callable_kwargs(idx_input_padded, kwargs)

return task_kwargs

def write( # pylint: disable=keyword-arg-before-vararg,unused-argument
self,
idx: VolumetricIndex,
dst: VolumetricLayer,
tensor: Any,
*args: P.args,
use_semaphore: bool = True,
**kwargs: P.kwargs,
) -> None:
"""Write tensor data to destination."""
Expand All @@ -114,7 +117,10 @@ def write( # pylint: disable=keyword-arg-before-vararg,unused-argument
else:
dst_with_crop = dst

with semaphore("write"):
if use_semaphore:
with semaphore("write"):
dst_with_crop[idx] = tensor
else:
dst_with_crop[idx] = tensor

def processing_fn(self, **kwargs: Any) -> Any:
Expand Down
2 changes: 2 additions & 0 deletions zetta_utils/mazepa_layer_processing/operation_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def read(
self,
idx: VolumetricIndex,
*args: P.args,
use_semaphore: bool = True,
**kwargs: P.kwargs,
) -> dict[str, Tensor]:
"""Read all source data for this operation and return as named tensors."""
Expand All @@ -157,6 +158,7 @@ def write(
dst: DstLayerT_contra,
tensor: Tensor,
*args: P.args,
use_semaphore: bool = True,
**kwargs: P.kwargs,
) -> None:
"""Write tensor data to destination."""
Loading