diff --git a/zetta_utils/layer/volumetric/cloudvol/backend.py b/zetta_utils/layer/volumetric/cloudvol/backend.py index 155d1e975..25f60d505 100644 --- a/zetta_utils/layer/volumetric/cloudvol/backend.py +++ b/zetta_utils/layer/volumetric/cloudvol/backend.py @@ -24,7 +24,7 @@ _cv_cache: cachetools.LRUCache = cachetools.LRUCache(maxsize=2048) _cv_cached: Dict[str, set] = {} -IN_MEM_CACHE_NUM_BYTES_PER_CV = 128 * 1024 ** 2 +IN_MEM_CACHE_NUM_BYTES_PER_CV = 1 * 1024 ** 3 def _serialize_kwargs(kwargs: Dict[str, Any]) -> str: diff --git a/zetta_utils/mazepa_layer_processing/common/stacked_volumetric_operations.py b/zetta_utils/mazepa_layer_processing/common/stacked_volumetric_operations.py index 7a432f684..5ed56381e 100644 --- a/zetta_utils/mazepa_layer_processing/common/stacked_volumetric_operations.py +++ b/zetta_utils/mazepa_layer_processing/common/stacked_volumetric_operations.py @@ -1,5 +1,6 @@ from __future__ import annotations +import copy import time from typing import Any, Generic, Sequence, cast @@ -10,7 +11,9 @@ from zetta_utils import builder, log, mazepa from zetta_utils.geometry import Vec3D +from zetta_utils.layer.layer_base import Layer from zetta_utils.layer.volumetric import VolumetricIndex, VolumetricLayer +from zetta_utils.mazepa import semaphore from ..operation_protocols import StackableVolumetricOpProtocol @@ -55,6 +58,31 @@ def get_input_resolution(self, dst_resolution: Vec3D) -> Vec3D: def with_added_crop_pad(self, crop_pad: Vec3D[int]) -> StackedVolumetricOperations[P]: return attrs.evolve(self, base_op=self.base_op.with_added_crop_pad(crop_pad)) + def _prefetch_region( + self, + indices: Sequence[VolumetricIndex], + **kwargs: Any, + ) -> None: + """ + Prefetch the supremum bounding box of all indices from source layers. + + Issues a single large read per source layer to warm the CloudVolume cache, + so that subsequent per-chunk reads are cache hits instead of remote fetches. + """ + sup_idx = indices[0] + for idx in indices[1:]: + sup_idx = sup_idx.supremum(idx) + + sup_idx_input = copy.deepcopy(sup_idx) + sup_idx_input.resolution = self.base_op.get_input_resolution(sup_idx.resolution) + input_crop_pad = getattr(self.base_op, "input_crop_pad", (0, 0, 0)) + sup_idx_input = sup_idx_input.padded(input_crop_pad) + + with semaphore("read"): + for v in kwargs.values(): + if isinstance(v, Layer): + v.read_with_procs(sup_idx_input) + def __call__( # pylint: disable=keyword-arg-before-vararg,too-many-branches self, indices: Sequence[VolumetricIndex], @@ -85,9 +113,16 @@ def __call__( # pylint: disable=keyword-arg-before-vararg,too-many-branches f"length of indices ({len(indices)})" ) - # Read all data + # Prefetch the entire region to warm CloudVolume cache + prefetch_start = time.time() + self._prefetch_region(indices, **kwargs) + prefetch_time = time.time() - prefetch_start + + # Read all data (should be cache hits after prefetch) read_start = time.time() - data_list = [self.base_op.read(idx, *args, **kwargs) for idx in indices] + data_list = [ + self.base_op.read(idx, *args, use_semaphore=False, **kwargs) for idx in indices + ] read_time = time.time() - read_start # Stack tensors by key @@ -119,23 +154,25 @@ def __call__( # pylint: disable=keyword-arg-before-vararg,too-many-branches # Unstack and write results write_start = time.time() - for i, (idx, dst) in enumerate(zip(indices, dsts_list)): - result: Any - if isinstance(batched_result, torch.Tensor): - result = batched_result[i] - elif isinstance(batched_result, np.ndarray): - result = batched_result[i] - else: - raise TypeError( - f"Function returned unsupported type: {type(batched_result)}. " - f"Only torch.Tensor and np.ndarray are supported." - ) - - self.base_op.write(idx, dst, result, *args, **kwargs) + with semaphore("write"): + for i, (idx, dst) in enumerate(zip(indices, dsts_list)): + result: Any + if isinstance(batched_result, torch.Tensor): + result = batched_result[i] + elif isinstance(batched_result, np.ndarray): + result = batched_result[i] + else: + raise TypeError( + f"Function returned unsupported type: {type(batched_result)}. " + f"Only torch.Tensor and np.ndarray are supported." + ) + + self.base_op.write(idx, dst, result, *args, use_semaphore=False, **kwargs) write_time = time.time() - write_start - total_time = read_time + process_time + write_time + total_time = prefetch_time + read_time + process_time + write_time logger.info( f"StackedVolumetricOperations: Total time for {len(indices)} chunks: {total_time:.2f}s" - f" (read: {read_time:.2f}s, process: {process_time:.2f}s, write: {write_time:.2f}s)" + f" (prefetch: {prefetch_time:.2f}s, read: {read_time:.2f}s," + f" process: {process_time:.2f}s, write: {write_time:.2f}s)" ) diff --git a/zetta_utils/mazepa_layer_processing/common/subchunkable_apply_flow.py b/zetta_utils/mazepa_layer_processing/common/subchunkable_apply_flow.py index 9009d1a2b..154322556 100644 --- a/zetta_utils/mazepa_layer_processing/common/subchunkable_apply_flow.py +++ b/zetta_utils/mazepa_layer_processing/common/subchunkable_apply_flow.py @@ -1302,6 +1302,7 @@ def _build_subchunkable_apply_flow( # pylint: disable=keyword-arg-before-vararg Iteratively build the hierarchy of schemas """ for level in range(1, num_levels): + is_top_level = level == num_levels - 1 flow_schema = VolumetricApplyFlowSchema( op=DelegatedSubchunkedOperation( # type:ignore #readability over typing here flow_schema, @@ -1314,7 +1315,7 @@ def _build_subchunkable_apply_flow( # pylint: disable=keyword-arg-before-vararg roi_crop_pad=roi_crop_pads[-level - 1], processing_blend_pad=processing_blend_pads[-level - 1], processing_blend_mode=processing_blend_modes[-level - 1], - processing_gap=processing_gap if level == num_levels - 1 else None, + processing_gap=processing_gap if is_top_level else None, intermediaries_dir=_path_join_if_not_none( level_intermediaries_dirs[-level - 1], f"chunks_level_{level}" ), diff --git a/zetta_utils/mazepa_layer_processing/common/volumetric_apply_flow.py b/zetta_utils/mazepa_layer_processing/common/volumetric_apply_flow.py index 1957d1cf7..e55920b74 100644 --- a/zetta_utils/mazepa_layer_processing/common/volumetric_apply_flow.py +++ b/zetta_utils/mazepa_layer_processing/common/volumetric_apply_flow.py @@ -1,5 +1,8 @@ from __future__ import annotations +import time +from contextlib import nullcontext + import itertools import multiprocessing from abc import ABC @@ -56,9 +59,11 @@ def __call__( dst: VolumetricBasedLayerProtocol, idx: VolumetricIndex, ) -> None: - with semaphore("read"): + read_ctx = nullcontext() if src.backend.is_local else semaphore("read") + with read_ctx: data = src[idx] - with semaphore("write"): + write_ctx = nullcontext() if dst.backend.is_local else semaphore("write") + with write_ctx: dst[idx] = data @@ -96,6 +101,7 @@ def __call__( with suppress_type_checks(): if len(src_layers) == 0: return + reduce_start = time.time() res = np.zeros( (dst.backend.num_channels, *red_idx.shape), dtype=dst.backend.dtype, @@ -105,17 +111,29 @@ def __call__( for src_idx, layer in zip(src_idxs, src_layers): intscn, subidx = src_idx.get_intersection_and_subindex(red_idx) subidx_channels = (slice(0, res.shape[0]), *subidx) - with semaphore("read"): + read_ctx = nullcontext() if layer.backend.is_local else semaphore("read") + with read_ctx: res[subidx_channels] = np.maximum(res[subidx_channels], layer[intscn]) else: for src_idx, layer in zip(src_idxs, src_layers): intscn, subidx = src_idx.get_intersection_and_subindex(red_idx) subidx_channels = (slice(0, res.shape[0]), *subidx) - with semaphore("read"): + read_ctx = nullcontext() if layer.backend.is_local else semaphore("read") + with read_ctx: res[subidx_channels] = layer[intscn] + reduce_time = time.time() - reduce_start if np.any(res): - with semaphore("write"): + write_start = time.time() + write_ctx = nullcontext() if dst.backend.is_local else semaphore("write") + with write_ctx: dst[red_idx] = res + write_time = time.time() - write_start + else: + write_time = 0.0 + logger.info( + f"ReduceNaive: {len(src_layers)} sources, " + f"reduce: {reduce_time:.3f}s, write: {write_time:.3f}s" + ) def is_floating_point_dtype(dtype: np.dtype) -> bool: @@ -145,6 +163,7 @@ def __call__( with suppress_type_checks(): if len(src_layers) == 0: return + reduce_start = time.time() if not is_floating_point_dtype(dst.backend.dtype) and processing_blend_pad != Vec3D[ int ](0, 0, 0): @@ -168,7 +187,8 @@ def __call__( ) intscn, subidx = src_idx.get_intersection_and_subindex(red_idx) subidx_channels = [slice(0, res.shape[0])] + list(subidx) - with semaphore("read"): + read_ctx = nullcontext() if layer.backend.is_local else semaphore("read") + with read_ctx: if not is_floating_point_dtype(dst.backend.dtype): # Temporarily convert integer cutout to float for rounding res[tuple(subidx_channels)] = ( @@ -186,11 +206,22 @@ def __call__( for src_idx, layer in zip(src_idxs, src_layers): intscn, subidx = src_idx.get_intersection_and_subindex(red_idx) subidx_channels = [slice(0, res.shape[0])] + list(subidx) - with semaphore("read"): + read_ctx = nullcontext() if layer.backend.is_local else semaphore("read") + with read_ctx: res.numpy()[tuple(subidx_channels)] = layer[intscn] + reduce_time = time.time() - reduce_start if res.any(): - with semaphore("write"): + write_start = time.time() + write_ctx = nullcontext() if dst.backend.is_local else semaphore("write") + with write_ctx: dst[red_idx] = res + write_time = time.time() - write_start + else: + write_time = 0.0 + logger.info( + f"ReduceByWeightedSum: {len(src_layers)} sources, " + f"reduce: {reduce_time:.3f}s, write: {write_time:.3f}s" + ) @cachetools.cached(_weights_cache) @@ -490,6 +521,7 @@ def _get_temp_dst( self.dst_resolution, ), enforce_chunk_aligned_writes=False, + overwrite_partial_chunks=True, allow_cache=allow_cache, use_compression=False, ) @@ -889,30 +921,46 @@ def _flow_with_checkerboarding( # pylint: disable=too-many-locals ) e.args = (error_str + e.args[0],) raise e - if not self.max_reduction_chunk_size_final >= dst.backend.get_chunk_size( - self.dst_resolution - ): - raise ValueError( - "`max_reduction_chunk_size` (which defaults to `processing_chunk_size` when" - " not specified)` must be at least as large as the `dst` layer's" - f" chunk size; received {self.max_reduction_chunk_size_final}, which is" - f" smaller than {dst.backend.get_chunk_size(self.dst_resolution)}" - ) - reduction_chunker = VolumetricIndexChunker( - chunk_size=dst.backend.get_chunk_size(self.dst_resolution), - resolution=self.dst_resolution, - max_superchunk_size=self.max_reduction_chunk_size_final, - ) - logger.debug( - f"Breaking {idx} into reduction chunks with checkerboarding" - f" with {reduction_chunker}. Processing chunks will use the padded index" - f" {idx.padded(self.roi_crop_pad)} and be chunked with {self.processing_chunker}." - ) - stride_start_offset = dst.backend.get_voxel_offset(self.dst_resolution) - red_chunks = reduction_chunker(idx, mode="exact", stride_start_offset=stride_start_offset) - red_shape = reduction_chunker.get_shape( - idx, mode="exact", stride_start_offset=stride_start_offset + reduce_whole_roi = ( + not dst.backend.enforce_chunk_aligned_writes + and dst.backend.overwrite_partial_chunks + and idx.shape <= self.max_reduction_chunk_size_final ) + if reduce_whole_roi: + red_chunks = [idx] + red_shape = Vec3D[int](1, 1, 1) + logger.debug( + f"Reducing entire ROI {idx} as a single chunk. Processing chunks will" + f" use the padded index {idx.padded(self.roi_crop_pad)} and be chunked" + f" with {self.processing_chunker}." + ) + else: + if not self.max_reduction_chunk_size_final >= dst.backend.get_chunk_size( + self.dst_resolution + ): + raise ValueError( + "`max_reduction_chunk_size` (which defaults to `processing_chunk_size` when" + " not specified)` must be at least as large as the `dst` layer's" + f" chunk size; received {self.max_reduction_chunk_size_final}, which is" + f" smaller than {dst.backend.get_chunk_size(self.dst_resolution)}" + ) + reduction_chunker = VolumetricIndexChunker( + chunk_size=dst.backend.get_chunk_size(self.dst_resolution), + resolution=self.dst_resolution, + max_superchunk_size=self.max_reduction_chunk_size_final, + ) + logger.debug( + f"Breaking {idx} into reduction chunks with checkerboarding" + f" with {reduction_chunker}. Processing chunks will use the padded index" + f" {idx.padded(self.roi_crop_pad)} and be chunked with {self.processing_chunker}." + ) + stride_start_offset = dst.backend.get_voxel_offset(self.dst_resolution) + red_chunks = reduction_chunker( + idx, mode="exact", stride_start_offset=stride_start_offset + ) + red_shape = reduction_chunker.get_shape( + idx, mode="exact", stride_start_offset=stride_start_offset + ) ( tasks, red_chunks_task_idxs, diff --git a/zetta_utils/mazepa_layer_processing/common/volumetric_callable_operation.py b/zetta_utils/mazepa_layer_processing/common/volumetric_callable_operation.py index 3fd9dc731..8271ffd5e 100644 --- a/zetta_utils/mazepa_layer_processing/common/volumetric_callable_operation.py +++ b/zetta_utils/mazepa_layer_processing/common/volumetric_callable_operation.py @@ -80,6 +80,7 @@ def read( # pylint: disable=keyword-arg-before-vararg self, idx: VolumetricIndex, *args: P.args, + use_semaphore: bool = True, **kwargs: P.kwargs, ) -> dict[str, Any]: """Read all source data for this operation.""" @@ -88,10 +89,11 @@ def read( # pylint: disable=keyword-arg-before-vararg idx_input.resolution = self.get_input_resolution(idx.resolution) idx_input_padded = idx_input.padded(Vec3D[int](*self.input_crop_pad)) - with semaphore("read"): - task_kwargs = _process_callable_kwargs(idx_input_padded, kwargs) + if use_semaphore: + with semaphore("read"): + return _process_callable_kwargs(idx_input_padded, kwargs) + return _process_callable_kwargs(idx_input_padded, kwargs) - return task_kwargs def write( # pylint: disable=keyword-arg-before-vararg,unused-argument self, @@ -99,6 +101,7 @@ def write( # pylint: disable=keyword-arg-before-vararg,unused-argument dst: VolumetricLayer, tensor: Any, *args: P.args, + use_semaphore: bool = True, **kwargs: P.kwargs, ) -> None: """Write tensor data to destination.""" @@ -114,7 +117,10 @@ def write( # pylint: disable=keyword-arg-before-vararg,unused-argument else: dst_with_crop = dst - with semaphore("write"): + if use_semaphore: + with semaphore("write"): + dst_with_crop[idx] = tensor + else: dst_with_crop[idx] = tensor def processing_fn(self, **kwargs: Any) -> Any: diff --git a/zetta_utils/mazepa_layer_processing/operation_protocols.py b/zetta_utils/mazepa_layer_processing/operation_protocols.py index a9fcf1267..ab3627a11 100644 --- a/zetta_utils/mazepa_layer_processing/operation_protocols.py +++ b/zetta_utils/mazepa_layer_processing/operation_protocols.py @@ -147,6 +147,7 @@ def read( self, idx: VolumetricIndex, *args: P.args, + use_semaphore: bool = True, **kwargs: P.kwargs, ) -> dict[str, Tensor]: """Read all source data for this operation and return as named tensors.""" @@ -157,6 +158,7 @@ def write( dst: DstLayerT_contra, tensor: Tensor, *args: P.args, + use_semaphore: bool = True, **kwargs: P.kwargs, ) -> None: """Write tensor data to destination."""