diff --git a/pyproject.toml b/pyproject.toml
index bcf70f22..4535426d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,7 +35,7 @@ dependencies = [
   "graphviz",
   "plopp>=24.7.0",
   "sciline>=24.06.0",
-  "scipp>=23.8.0",
+  "scipp>=25.3.0",
   "scippnexus>=23.12.0",
   "pooch>=1.5",
   "pandas>=2.1.2",
@@ -49,6 +49,7 @@ dynamic = ["version"]
 
 [project.scripts]
 essnmx_reduce_mcstas = "ess.nmx.mcstas.executables:main"
+essnmx-reduce = "ess.nmx.executables:main"
 
 [project.optional-dependencies]
 test = [
diff --git a/src/ess/nmx/_executable_helper.py b/src/ess/nmx/_executable_helper.py
new file mode 100644
index 00000000..bf8aec32
--- /dev/null
+++ b/src/ess/nmx/_executable_helper.py
@@ -0,0 +1,61 @@
+# SPDX-License-Identifier: BSD-3-Clause
+# Copyright (c) 2025 Scipp contributors (https://github.com/scipp)
+import argparse
+import logging
+import sys
+
+from .types import Compression
+
+
+def build_reduction_arg_parser() -> argparse.ArgumentParser:
+    """Build the argument parser shared by the NMX reduction executables."""
+    parser = argparse.ArgumentParser(
+        description="Command line arguments for the NMX reduction. "
+        "It assumes 14 Hz pulse speed."
+    )
+    input_arg_group = parser.add_argument_group("Input Options")
+    input_arg_group.add_argument(
+        "--input_file", type=str, help="Path to the input file", required=True
+    )
+    input_arg_group.add_argument(
+        "--nbins",
+        type=int,
+        default=50,
+        help="Number of TOF bins",
+    )
+    input_arg_group.add_argument(
+        "--detector_ids",
+        type=int,
+        nargs="+",
+        default=[0, 1, 2],
+        help="Detector indices to process",
+    )
+
+    output_arg_group = parser.add_argument_group("Output Options")
+    output_arg_group.add_argument(
+        "--output_file",
+        type=str,
+        default="scipp_output.h5",
+        help="Path to the output file",
+    )
+    output_arg_group.add_argument(
+        "--compression",
+        type=str,
+        default=Compression.BITSHUFFLE_LZ4.name,
+        choices=[compression_key.name for compression_key in Compression],
+        help="Compress option of reduced output file. Default: BITSHUFFLE_LZ4",
+    )
+    output_arg_group.add_argument(
+        "--verbose", "-v", action="store_true", help="Increase output verbosity"
+    )
+
+    return parser
+
+
+def build_logger(args: argparse.Namespace) -> logging.Logger:
+    """Return this module's logger; ``--verbose`` enables INFO output on stdout."""
+    logger = logging.getLogger(__name__)
+    if args.verbose:
+        logger.setLevel(logging.INFO)
+        logger.addHandler(logging.StreamHandler(sys.stdout))
+    return logger
diff --git a/src/ess/nmx/data/__init__.py b/src/ess/nmx/data/__init__.py
index f8803dbb..36919908 100644
--- a/src/ess/nmx/data/__init__.py
+++ b/src/ess/nmx/data/__init__.py
@@ -21,6 +21,7 @@ def _make_pooch() -> pooch.Pooch:
             "small_mcstas_3_sample.h5": "md5:2afaac205d13ee857ee5364e3f1957a7",
             "mtz_samples.tar.gz": "md5:bed1eaf604bbe8725c1f6a20ca79fcc0",
             "mtz_random_samples.tar.gz": "md5:c8259ae2e605560ab88959e7109613b6",
+            "small_nmx_nexus.hdf": "md5:42cffb85e4ce7c1aaa5f7e81469b865e",
         },
     )
 
@@ -89,3 +90,9 @@ def get_small_random_mtz_samples() -> list[pathlib.Path]:
         pathlib.Path(file_path)
         for file_path in _pooch.fetch("mtz_random_samples.tar.gz", processor=Untar())
     ]
+
+
+def get_small_nmx_nexus() -> str:
+    """Return the path to a small NMX NeXus file."""
+
+    return _pooch.fetch("small_nmx_nexus.hdf")
diff --git a/src/ess/nmx/executables.py b/src/ess/nmx/executables.py
new file mode 100644
index 00000000..61d8e634
--- /dev/null
+++ b/src/ess/nmx/executables.py
@@ -0,0 +1,419 @@
+# SPDX-License-Identifier: BSD-3-Clause
+# Copyright (c) 2025 Scipp contributors (https://github.com/scipp)
+import argparse
+import logging
+import pathlib
+from collections.abc import Callable
+from dataclasses import dataclass
+from typing import Literal
+
+import scipp as sc
+import scippnexus as snx
+
+from .nexus import (
+    _compute_positions,
+    _export_detector_metadata_as_nxlauetof,
+    _export_reduced_data_as_nxlauetof,
+    _export_static_metadata_as_nxlauetof,
+)
+from .streaming import _validate_chunk_size
+from .types import Compression, NMXDetectorMetadata, NMXExperimentMetadata
+
+
+def _retrieve_source_position(file: snx.File) -> sc.Variable:
+    """Return the source position computed from ``entry/instrument/source``."""
+    da = file['entry/instrument/source'][()]
+    return _compute_positions(da, auto_fix_transformations=True)['position']
+
+
+def _retrieve_sample_position(file: snx.File) -> sc.Variable:
+    """Return the sample position computed from ``entry/sample``."""
+    da = file['entry/sample'][()]
+    return _compute_positions(da, auto_fix_transformations=True)['position']
+
+
+def _decide_fast_axis(da: sc.DataArray) -> str:
+    """Infer the fast axis from detector numbers along the first row and column."""
+    x_slice = da['x_pixel_offset', 0].coords['detector_number']
+    y_slice = da['y_pixel_offset', 0].coords['detector_number']
+
+    if (x_slice.max() < y_slice.max()).value:
+        return 'y'
+    elif (x_slice.max() > y_slice.max()).value:
+        return 'x'
+    else:
+        raise ValueError(
+            "Cannot decide fast axis based on pixel offsets. "
+            "Please specify the fast axis explicitly."
+        )
+
+
+def _decide_step(offsets: sc.Variable) -> sc.Variable:
+    """Decide the step size based on the offsets assuming at least 2 values."""
+    sorted_offsets = sc.sort(offsets, key=offsets.dim, order='ascending')
+    return sorted_offsets[1] - sorted_offsets[0]
+
+
+@dataclass
+class DetectorDesc:
+    """Detector description; comments give the matching McStas instrument xml keys."""
+
+    name: str
+    id_start: int  # 'idstart'
+    num_x: int  # 'xpixels'
+    num_y: int  # 'ypixels'
+    step_x: sc.Variable  # 'xstep'
+    step_y: sc.Variable  # 'ystep'
+    start_x: float  # 'xstart'
+    start_y: float  # 'ystart'
+    position: sc.Variable  # 'x', 'y', 'z'
+    # Calculated fields
+    rotation_matrix: sc.Variable
+    fast_axis_name: str
+    slow_axis_name: str
+    fast_axis: sc.Variable
+    slow_axis: sc.Variable
+
+
+def build_detector_desc(
+    name: str, dg: sc.DataGroup, *, fast_axis: Literal['x', 'y'] | None = None
+) -> DetectorDesc:
+    """Build a ``DetectorDesc`` from a detector data group with computed positions."""
+    da: sc.DataArray = dg['data']
+    _fast_axis = fast_axis if fast_axis is not None else _decide_fast_axis(da)
+    transformation_matrix = dg['transform_matrix']
+    t_unit = transformation_matrix.unit
+    fast_axis_vector = (
+        sc.vector([1, 0, 0], unit=t_unit)
+        if _fast_axis == 'x'
+        else sc.vector([0, 1, 0], unit=t_unit)
+    )
+    slow_axis_vector = (
+        sc.vector([0, 1, 0], unit=t_unit)
+        if _fast_axis == 'x'
+        else sc.vector([1, 0, 0], unit=t_unit)
+    )
+    return DetectorDesc(
+        name=name,
+        id_start=da.coords['detector_number'].min().value,
+        num_x=da.sizes['x_pixel_offset'],
+        num_y=da.sizes['y_pixel_offset'],
+        start_x=da.coords['x_pixel_offset'].min().value,
+        start_y=da.coords['y_pixel_offset'].min().value,
+        position=dg['position'],
+        rotation_matrix=transformation_matrix,
+        fast_axis_name=_fast_axis,
+        slow_axis_name='x' if _fast_axis == 'y' else 'y',
+        fast_axis=fast_axis_vector,
+        slow_axis=slow_axis_vector,
+        step_x=_decide_step(da.coords['x_pixel_offset']),
+        step_y=_decide_step(da.coords['y_pixel_offset']),
+    )
+
+
+def calculate_number_of_chunks(detector_gr: snx.Group, *, chunk_size: int = 0) -> int:
+    """Return how many ``chunk_size``-pulse reads cover all ``event_time_zero``."""
+    _validate_chunk_size(chunk_size)
+    event_time_zero_size = detector_gr.sizes['event_time_zero']
+    if chunk_size <= 0:
+        return 1  # Read all at once (also avoids dividing by zero)
+    else:
+        return event_time_zero_size // chunk_size + int(
+            event_time_zero_size % chunk_size != 0
+        )
+
+
+def build_toa_bin_edges(
+    *,
+    min_toa: sc.Variable | int = 0,
+    max_toa: sc.Variable | int = int((1 / 14) * 1_000),  # Default for ESS NMX
+    toa_bin_edges: sc.Variable | int = 250,
+) -> sc.Variable:
+    """Return TOA bin edges along ``event_time_offset``.
+
+    A ``sc.Variable`` is returned unchanged; an ``int`` is taken as the number of
+    bins between ``min_toa`` and ``max_toa`` (plain integers are read as ms).
+    """
+    if isinstance(toa_bin_edges, sc.Variable):
+        return toa_bin_edges
+    elif isinstance(toa_bin_edges, int):
+        min_toa = sc.scalar(min_toa, unit='ms') if isinstance(min_toa, int) else min_toa
+        max_toa = sc.scalar(max_toa, unit='ms') if isinstance(max_toa, int) else max_toa
+        return sc.linspace(
+            dim='event_time_offset',
+            start=min_toa.value,
+            stop=max_toa.to(unit=min_toa.unit).value,
+            unit=min_toa.unit,
+            num=toa_bin_edges + 1,
+        )
+    else:
+        raise TypeError(
+            "toa_bin_edges must be an int or a scipp Variable, "
+            f"got {type(toa_bin_edges).__name__}."
+        )
+
+
+def reduction(
+    *,
+    input_file: pathlib.Path,
+    output_file: pathlib.Path,
+    chunk_size: int = 1_000,
+    detector_ids: list[int | str],
+    compression: Compression = Compression.BITSHUFFLE_LZ4,
+    logger: logging.Logger | None = None,
+    min_toa: sc.Variable | int = 0,
+    max_toa: sc.Variable | int = int((1 / 14) * 1_000),  # Default for ESS NMX
+    toa_bin_edges: sc.Variable | int = 250,
+    fast_axis: Literal['x', 'y'] | None = None,  # 'x', 'y', or None to auto-detect
+    display: Callable | None = None,  # For Jupyter notebook display
+) -> sc.DataGroup:
+    """Reduce NMX data from a Nexus file and export to NXLauetof(ESS NMX specific) file.
+
+    This workflow is written as a flatten function without using sciline Pipeline.
+    It is because the first part of NMX reduction only requires
+    a few steps of processing and it is overkill to use a Pipeline or GenericWorkflow.
+
+    We also do not apply frame unwrapping or pulse skipping here,
+    as it is not expected from NMX experiments.
+
+    Frame unwrapping may be applied later on the result of this function if needed
+    however, then the whole range of `event_time_offset` should have been histogrammed
+    so that the unwrapping can be applied.
+    i.e. `min_toa` should be 0 and `max_toa` should be 1/14 seconds
+    for 14 Hz pulse frequency.
+    TODO: Implement tof/wavelength workflow for NMX.
+
+    Parameters
+    ----------
+    input_file:
+        Path to the input Nexus file containing NMX data.
+    output_file:
+        Path to the output file where reduced data will be saved.
+    chunk_size:
+        Number of pulses to process in each chunk. If <= 0, all data is processed
+        at once. It represents the number of event_time_zero entries to read at once.
+    detector_ids:
+        List of detector IDs (as integers or names) to process.
+    compression:
+        Compression applied to the counts dataset of the output file.
+    logger:
+        Logger to use for logging messages. If None, a default logger is created.
+    min_toa:
+        Minimum time of arrival (TOA) in milliseconds. Default is 0 ms.
+    max_toa:
+        Maximum time of arrival (TOA) in milliseconds. Default is 1/14 seconds,
+        typical for ESS NMX.
+    toa_bin_edges:
+        Number of time of arrival (TOA) bins, or a scipp Variable defining the
+        bin edges. Default is 250 bins.
+    fast_axis:
+        Fast axis of the detector, 'x' or 'y'. If None, it is auto-detected.
+    display:
+        Callable for displaying messages, useful in Jupyter notebooks. If None,
+        defaults to logger.info.
+
+    Returns
+    -------
+    sc.DataGroup:
+        A DataGroup containing the reduced data for each selected detector.
+
+    """
+    if logger is None:
+        logger = logging.getLogger(__name__)
+    if display is None:
+        display = logger.info
+
+    toa_bin_edges = build_toa_bin_edges(
+        min_toa=min_toa, max_toa=max_toa, toa_bin_edges=toa_bin_edges
+    )
+    with snx.File(input_file) as f:
+        instrument_group = f['entry/instrument']
+        dets = instrument_group[snx.NXdetector]
+        detector_group_keys = list(dets.keys())
+        display(f"Found NXdetectors: {detector_group_keys}")
+        detector_id_map = {
+            det_name: dets[det_name]
+            for i, det_name in enumerate(detector_group_keys)
+            if i in detector_ids or det_name in detector_ids
+        }
+        if len(detector_id_map) != len(detector_ids):
+            raise ValueError(
+                f"Requested detector ids {detector_ids} not found in the file.\n"
+                f"Found {detector_group_keys}\n"
+                f"Try using integer indices instead of names."
+            )
+        display(f"Selected detectors: {list(detector_id_map.keys())}")
+        source_position = _retrieve_source_position(f)
+        sample_position = _retrieve_sample_position(f)
+        experiment_metadata = NMXExperimentMetadata(
+            sc.DataGroup(
+                {
+                    # Placeholder for crystal rotation
+                    'crystal_rotation': sc.vector([0, 0, 0], unit='deg'),
+                    'sample_position': sample_position,
+                    'source_position': source_position,
+                    'sample_name': sc.scalar(f['entry/sample/name'][()]),
+                }
+            )
+        )
+        display(experiment_metadata)
+        display("Experiment metadata component:")
+        for name, component in experiment_metadata.items():
+            display(f"{name}: {component}")
+
+        _export_static_metadata_as_nxlauetof(
+            experiment_metadata=experiment_metadata,
+            output_file=output_file,
+        )
+        detector_grs = {}
+        for det_name, det_group in detector_id_map.items():
+            display(f"Processing {det_name}")
+            if chunk_size <= 0:
+                dg = det_group[()]
+            else:
+                # Slice the first chunk for metadata extraction
+                dg = det_group['event_time_zero', 0:chunk_size]
+
+            display("Computing detector positions...")
+            display(dg := _compute_positions(dg, auto_fix_transformations=True))
+            detector = build_detector_desc(det_name, dg, fast_axis=fast_axis)
+            detector_meta = sc.DataGroup(
+                {
+                    'fast_axis': detector.fast_axis,
+                    'slow_axis': detector.slow_axis,
+                    'origin_position': sc.vector([0, 0, 0], unit='m'),
+                    'position': detector.position,
+                    'detector_shape': sc.scalar(
+                        (
+                            dg['data'].sizes['x_pixel_offset'],
+                            dg['data'].sizes['y_pixel_offset'],
+                        )
+                    ),
+                    'x_pixel_size': detector.step_x,
+                    'y_pixel_size': detector.step_y,
+                    'detector_name': sc.scalar(detector.name),
+                }
+            )
+            _export_detector_metadata_as_nxlauetof(
+                NMXDetectorMetadata(detector_meta), output_file=output_file
+            )
+
+            da: sc.DataArray = dg['data']
+            event_time_offset_unit = da.bins.coords['event_time_offset'].bins.unit
+            display(f"Event time offset unit: {event_time_offset_unit}")
+            toa_bin_edges = toa_bin_edges.to(unit=event_time_offset_unit, copy=False)
+            if chunk_size <= 0:
+                counts = da.hist(event_time_offset=toa_bin_edges).rename_dims(
+                    x_pixel_offset='x', y_pixel_offset='y', event_time_offset='t'
+                )
+                counts.coords['t'] = counts.coords['event_time_offset']
+
+            else:
+                num_chunks = calculate_number_of_chunks(
+                    det_group, chunk_size=chunk_size
+                )
+                display(f"Number of chunks: {num_chunks}")
+                counts = da.hist(event_time_offset=toa_bin_edges).rename_dims(
+                    x_pixel_offset='x', y_pixel_offset='y', event_time_offset='t'
+                )
+                counts.coords['t'] = counts.coords['event_time_offset']
+                for chunk_index in range(1, num_chunks):
+                    cur_chunk = det_group[
+                        'event_time_zero',
+                        chunk_index * chunk_size : (chunk_index + 1) * chunk_size,
+                    ]
+                    display(f"Processing chunk {chunk_index + 1} of {num_chunks}")
+                    cur_chunk = _compute_positions(
+                        cur_chunk, auto_fix_transformations=True
+                    )
+                    cur_counts = (
+                        cur_chunk['data']
+                        .hist(event_time_offset=toa_bin_edges)
+                        .rename_dims(
+                            x_pixel_offset='x',
+                            y_pixel_offset='y',
+                            event_time_offset='t',
+                        )
+                    )
+                    cur_counts.coords['t'] = cur_counts.coords['event_time_offset']
+                    counts += cur_counts
+                    display("Accumulated counts:")
+                    display(counts.sum().data)
+
+            dg = sc.DataGroup(
+                counts=counts,
+                detector_shape=detector_meta['detector_shape'],
+                detector_name=detector_meta['detector_name'],
+            )
+            display("Final data group:")
+            display(dg)
+            display("Saving reduced data to Nexus file...")
+            _export_reduced_data_as_nxlauetof(
+                dg,
+                output_file=output_file,
+                compress_counts=(compression != Compression.NONE),
+            )
+            detector_grs[det_name] = dg
+
+        display("Reduction completed successfully.")
+        return sc.DataGroup(detector_grs)
+
+
+def _add_ess_reduction_args(arg: argparse.ArgumentParser) -> None:
+    """Add ESS-specific reduction options to the argument parser ``arg``."""
+    argument_group = arg.add_argument_group("ESS Reduction Options")
+    argument_group.add_argument(
+        "--chunk_size",
+        type=int,
+        default=-1,
+        help="Chunk size for processing (number of pulses per chunk).",
+    )
+    argument_group.add_argument(
+        "--min-toa",
+        type=int,
+        default=0,
+        help="Minimum time of arrival (TOA) in ms.",
+    )
+    argument_group.add_argument(
+        "--max-toa",
+        type=int,
+        default=int((1 / 14) * 1_000),
+        help="Maximum time of arrival (TOA) in ms.",
+    )
+    argument_group.add_argument(
+        "--fast-axis",
+        type=str,
+        choices=['x', 'y'],
+        default=None,
+        help="Specify the fast axis of the detector. If not given, it will be "
+        "determined automatically based on the pixel offsets.",
+    )
+
+
+def main() -> None:
+    """Command line entry point of ``essnmx-reduce``."""
+    from ._executable_helper import build_logger, build_reduction_arg_parser
+
+    parser = build_reduction_arg_parser()
+    _add_ess_reduction_args(parser)
+    args = parser.parse_args()
+
+    input_file = pathlib.Path(args.input_file).resolve()
+    output_file = pathlib.Path(args.output_file).resolve()
+    logger = build_logger(args)
+
+    logger.info("Input file: %s", input_file)
+    logger.info("Output file: %s", output_file)
+
+    reduction(
+        input_file=input_file,
+        output_file=output_file,
+        chunk_size=args.chunk_size,
+        detector_ids=args.detector_ids,
+        compression=Compression[args.compression],
+        toa_bin_edges=args.nbins,
+        min_toa=sc.scalar(args.min_toa, unit='ms'),
+        max_toa=sc.scalar(args.max_toa, unit='ms'),
+        fast_axis=args.fast_axis,
+        logger=logger,
+    )
diff --git a/src/ess/nmx/mcstas/executables.py b/src/ess/nmx/mcstas/executables.py
index 47abe419..14ec7f68 100644
--- a/src/ess/nmx/mcstas/executables.py
+++ b/src/ess/nmx/mcstas/executables.py
@@ -3,7 +3,6 @@
 import argparse
 import logging
 import pathlib
-import sys
 from collections.abc import Callable
 from functools import partial
 
@@ -24,6 +23,7 @@
 )
 from ..streaming import calculate_number_of_chunks
 from ..types import (
+    Compression,
     DetectorIndex,
     DetectorName,
     FilePath,
@@ -139,10 +139,10 @@ def reduction(
     input_file: pathlib.Path,
     output_file: pathlib.Path,
     chunk_size: int = 10_000_000,
-    nbins: int = 51,
+    nbins: int = 50,
     max_counts: int | None = None,
     detector_ids: list[int | str],
-    compression: bool = True,
+    compression: Compression = Compression.BITSHUFFLE_LZ4,
     wf: sl.Pipeline | None = None,
     logger: logging.Logger | None = None,
     toa_min_max_prob: tuple[float] | None = None,
@@ -161,7 +161,10 @@ def reduction(
             logger.info("Metadata retrieved: %s", data_metadata)
 
         toa_bin_edges = sc.linspace(
-            dim='t', start=data_metadata.min_toa, stop=data_metadata.max_toa, num=nbins
+            dim='t',
+            start=data_metadata.min_toa,
+            stop=data_metadata.max_toa,
+            num=nbins + 1,
         )
         scale_factor = mcstas_weight_to_probability_scalefactor(
             max_counts=wf.compute(MaximumCounts),
@@ -173,7 +176,7 @@ def reduction(
         toa_min = sc.scalar(toa_min_max_prob[0], unit='s')
         toa_max = sc.scalar(toa_min_max_prob[1], unit='s')
         prob_max = sc.scalar(toa_min_max_prob[2])
-        toa_bin_edges = sc.linspace(dim='t', start=toa_min, stop=toa_max, num=nbins)
+        toa_bin_edges = sc.linspace(dim='t', start=toa_min, stop=toa_max, num=nbins + 1)
         scale_factor = mcstas_weight_to_probability_scalefactor(
             max_counts=wf.compute(MaximumCounts),
             max_probability=prob_max,
@@ -253,69 +256,46 @@ def reduction(
         result_list.append(result)
         if logger is not None:
             logger.info("Appending reduced data into the output file %s", output_file)
+
         _export_reduced_data_as_nxlauetof(
-            result, output_file=output_file, compress_counts=compression
+            result,
+            output_file=output_file,
+            compress_counts=(compression != Compression.NONE),
         )
     from ess.nmx.reduction import merge_panels
 
     return merge_panels(*result_list)
 
 
-def main() -> None:
-    parser = argparse.ArgumentParser(description="McStas Data Reduction.")
-    parser.add_argument(
-        "--input_file", type=str, help="Path to the input file", required=True
-    )
-    parser.add_argument(
-        "--output_file",
-        type=str,
-        default="scipp_output.h5",
-        help="Path to the output file",
-    )
-    parser.add_argument(
-        "--verbose", action="store_true", help="Increase output verbosity"
-    )
-    parser.add_argument(
-        "--chunk_size",
-        type=int,
-        default=10_000_000,
-        help="Chunk size for processing",
-    )
-    parser.add_argument(
-        "--nbins",
-        type=int,
-        default=51,
-        help="Number of TOF bins",
-    )
-    parser.add_argument(
+def _add_mcstas_args(parser: argparse.ArgumentParser) -> None:
+    """Add McStas-specific reduction options to ``parser``."""
+    mcstas_arg_group = parser.add_argument_group("McStas Data Reduction Options")
+    mcstas_arg_group.add_argument(
         "--max_counts",
         type=int,
         default=None,
         help="Maximum Counts",
     )
-    parser.add_argument(
-        "--detector_ids",
+    mcstas_arg_group.add_argument(
+        "--chunk_size",
         type=int,
-        nargs="+",
-        default=[0, 1, 2],
-        help="Detector indices to process",
-    )
-    parser.add_argument(
-        "--compression",
-        type=bool,
-        default=True,
-        help="Compress reduced output with bitshuffle/lz4",
+        default=10_000_000,
+        help="Chunk size for processing (number of events per chunk)",
     )
 
+
+def main() -> None:
+    """Command line entry point of ``essnmx_reduce_mcstas``."""
+    from .._executable_helper import build_logger, build_reduction_arg_parser
+
+    parser = build_reduction_arg_parser()
+    _add_mcstas_args(parser)
     args = parser.parse_args()
 
     input_file = pathlib.Path(args.input_file).resolve()
     output_file = pathlib.Path(args.output_file).resolve()
 
-    logger = logging.getLogger(__name__)
-    if args.verbose:
-        logger.setLevel(logging.INFO)
-        logger.addHandler(logging.StreamHandler(sys.stdout))
+    logger = build_logger(args)
 
     wf = McStasWorkflow()
     reduction(
@@ -325,7 +305,7 @@ def main() -> None:
         nbins=args.nbins,
         max_counts=args.max_counts,
         detector_ids=args.detector_ids,
-        compression=args.compression,
+        compression=Compression[args.compression],
         logger=logger,
         wf=wf,
     )
diff --git a/src/ess/nmx/nexus.py b/src/ess/nmx/nexus.py
index 42a11156..64f31b36 100644
--- a/src/ess/nmx/nexus.py
+++ b/src/ess/nmx/nexus.py
@@ -22,6 +22,74 @@
 )
 
 
+def _fallback_compute_positions(dg: sc.DataGroup) -> sc.DataGroup:
+    """Replace empty transformation logs with zero, then compute positions."""
+    import warnings
+
+    import scippnexus as snx
+
+    warnings.warn(
+        "Using fallback compute_positions due to empty log entries. "
+        "This may lead to incorrect results. Please check the data carefully. "
+        "The fallback will replace empty logs with a scalar value of zero.",
+        UserWarning,
+        stacklevel=2,
+    )
+
+    empty_transformations = [
+        transformation
+        for transformation in dg['depends_on'].transformations.values()
+        if 'time' in transformation.value.dims
+        and transformation.sizes['time'] == 0  # empty log
+    ]
+    for transformation in empty_transformations:
+        orig_value = transformation.value
+        orig_value = sc.scalar(0, unit=orig_value.unit, dtype=orig_value.dtype)
+        transformation.value = orig_value
+    return snx.compute_positions(dg, store_transform='transform_matrix')
+
+
+def _compute_positions(
+    dg: sc.DataGroup, auto_fix_transformations: bool = False
+) -> sc.DataGroup:
+    """Compute positions of the data group from transformations.
+
+    Wraps the `scippnexus.compute_positions` function
+    and provides a fallback for cases where the transformations
+    contain empty logs.
+
+    Parameters
+    ----------
+    dg:
+        Data group containing the transformations and data.
+    auto_fix_transformations:
+        If `True`, it will attempt to fix empty transformations.
+        It will replace them with a scalar value of zero.
+        It is because adding a time dimension will make it not possible
+        to compute positions of children due to time-dependent transformations.
+
+    Returns
+    -------
+    :
+        Data group with computed positions.
+
+    Warnings
+    --------
+    If `auto_fix_transformations` is `True`, it will warn about the fallback
+    being used due to empty logs or scalar transformations.
+    This is because the fallback may lead to incorrect results.
+
+    """
+    import scippnexus as snx
+
+    try:
+        return snx.compute_positions(dg, store_transform='transform_matrix')
+    except ValueError:
+        if auto_fix_transformations:
+            return _fallback_compute_positions(dg)
+        raise
+
+
 def _create_dataset_from_string(*, root_entry: h5py.Group, name: str, var: str) -> None:
     root_entry.create_dataset(name, dtype=h5py.string_dtype(), data=var)
 
@@ -428,6 +496,17 @@ def _export_detector_metadata_as_nxlauetof(
         _add_lauetof_detector_group(detector_metadata, nx_instrument)
 
 
+def _extract_counts(dg: sc.DataGroup) -> sc.Variable:
+    """Return the counts of ``dg``, folding a flat ``id`` dim into ``x`` and ``y``."""
+    counts: sc.Variable = dg['counts'].data
+    if 'id' in counts.dims:
+        num_x, num_y = dg["detector_shape"].value
+        return sc.fold(counts, dim='id', sizes={'x': num_x, 'y': num_y})
+    else:
+        # If there is no 'id' dimension, we assume it is already in the correct shape
+        return counts
+
+
 def _export_reduced_data_as_nxlauetof(
     dg: NMXReducedDataGroup,
     output_file: str | pathlib.Path | io.BytesIO,
@@ -471,9 +550,7 @@ def _export_reduced_data_as_nxlauetof(
             data_dset = _create_compressed_dataset(
                 name="data",
                 root_entry=nx_detector,
-                var=sc.fold(
-                    dg['counts'].data, dim='id', sizes={'x': num_x, 'y': num_y}
-                ),
+                var=_extract_counts(dg),
                 chunks=(num_x, num_y, 1),
                 dtype=np.uint,
             )
@@ -481,9 +558,7 @@ def _export_reduced_data_as_nxlauetof(
             data_dset = _create_dataset_from_var(
                 name="data",
                 root_entry=nx_detector,
-                var=sc.fold(
-                    dg['counts'].data, dim='id', sizes={'x': num_x, 'y': num_y}
-                ),
+                var=_extract_counts(dg),
                 dtype=np.uint,
             )
         data_dset.attrs["signal"] = 1
diff --git a/src/ess/nmx/types.py b/src/ess/nmx/types.py
index 0d629021..356117e0 100644
--- a/src/ess/nmx/types.py
+++ b/src/ess/nmx/types.py
@@ -1,3 +1,4 @@
+import enum
 from dataclasses import dataclass
 from typing import Any, NewType
 
@@ -76,3 +77,13 @@ class NMXRawDataMetadata:
     max_probability: MaximumProbability
     min_toa: MinimumTimeOfArrival
     max_toa: MaximumTimeOfArrival
+
+
+class Compression(enum.Enum):
+    """Compression type of the output file.
+
+    These options are written as enum for future extensibility.
+    """
+
+    NONE = 0
+    BITSHUFFLE_LZ4 = 1
diff --git a/tests/executable_test.py b/tests/executable_test.py
new file mode 100644
index 00000000..f96a56f9
--- /dev/null
+++ b/tests/executable_test.py
@@ -0,0 +1,73 @@
+# SPDX-License-Identifier: BSD-3-Clause
+# Copyright (c) 2025 Scipp contributors (https://github.com/scipp)
+
+import pathlib
+import subprocess
+
+import pytest
+import scipp as sc
+import scippnexus as snx
+from scipp.testing import assert_allclose
+
+
+@pytest.fixture(scope="session")
+def small_nmx_nexus_path():
+    """Fixture to provide the path to the small NMX NeXus file."""
+    from ess.nmx.data import get_small_nmx_nexus
+
+    return get_small_nmx_nexus()
+
+
+def _check_output_file(
+    output_file_path: pathlib.Path, expected_toa_output: sc.Variable
+) -> None:
+    """Assert each detector panel in the output has the expected ``time_of_flight``."""
+    detector_names = [f'detector_panel_{i}' for i in range(3)]
+    with snx.File(output_file_path, 'r') as f:
+        # Every panel must exist and expose the expected time-of-flight values
+        for name in detector_names:
+            det_gr = f[f'entry/instrument/{name}']
+            assert det_gr is not None
+            toa_edges = det_gr['time_of_flight'][()]
+            assert_allclose(toa_edges, expected_toa_output)
+
+
+def test_executable_runs(small_nmx_nexus_path, tmp_path: pathlib.Path):
+    """Test that the executable runs and returns the expected output."""
+    output_file = tmp_path / "output.h5"
+    assert not output_file.exists()
+
+    nbins = 20  # Small number of bins for testing.
+    # The output has 1280x1280 pixels per detector per time bin.
+    expected_toa_bins = sc.linspace(
+        dim='dim_0',
+        start=2,  # Unrealistic number for testing
+        stop=int((1 / 15) * 1_000),  # Unrealistic number for testing
+        num=nbins + 1,
+        unit='ms',
+    )
+    expected_toa_output = sc.midpoints(expected_toa_bins, dim='dim_0').to(unit='ns')
+
+    commands = (
+        'essnmx-reduce',
+        '--input_file',
+        small_nmx_nexus_path,
+        '--nbins',
+        str(nbins),
+        '--output_file',
+        output_file.as_posix(),
+        '--min-toa',
+        str(int(expected_toa_bins.min().value)),
+        '--max-toa',
+        str(int(expected_toa_bins.max().value)),
+    )
+    # The arguments are passed as a sequence (no shell), built from fixed test values
+    result = subprocess.run(  # noqa: S603 - We are not accepting arbitrary input here.
+        commands,
+        text=True,
+        capture_output=True,
+        check=False,
+    )
+    assert result.returncode == 0, result.stderr
+    assert output_file.exists()
+    _check_output_file(output_file, expected_toa_output=expected_toa_output)