From b99d64f83299a7b8f65efefa573d9af9e9ce0dc2 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 17 Oct 2025 12:43:52 -0400 Subject: [PATCH 1/7] start neo support --- pynapple/io/neo.py | 373 +++++++++++++++++++++++++++++++++++++++++++++ pyproject.toml | 1 + 2 files changed, 374 insertions(+) create mode 100644 pynapple/io/neo.py diff --git a/pynapple/io/neo.py b/pynapple/io/neo.py new file mode 100644 index 000000000..feee0de3d --- /dev/null +++ b/pynapple/io/neo.py @@ -0,0 +1,373 @@ +import neo +import numpy as np +import pynapple as nap +from neo.io.proxyobjects import AnalogSignalProxy, SpikeTrainProxy +from neo.core.spiketrainlist import SpikeTrainList +import pathlib + + +class NEOSignalInterface: + + def __init__(self, signal, block, time_support, sig_num=None): + self.time_support = time_support + if isinstance(signal, (neo.AnalogSignal, AnalogSignalProxy)): + self.is_analog = True + self.nap_type = self._get_meta_analog(signal) + elif isinstance(signal, (neo.SpikeTrain, SpikeTrainProxy)): + self.nap_type = nap.Ts + self.is_analog = False + elif isinstance(signal, (list, SpikeTrainList)): + self.nap_type = nap.TsGroup + self.is_analog = False + else: + raise TypeError(f"signal type {type(signal)} not recognized.") + self._block = block + self._sig_num = sig_num + + if self.is_analog: + self.dt = (1 / signal.sampling_rate).rescale("s").magnitude + self.shape = signal.shape + if not issubclass(self.nap_type, nap.TsGroup): + self.start_time = signal.t_start.rescale("s").magnitude + self.end_time = signal.t_stop.rescale("s").magnitude + else: + self.start_time = [s.t_start.rescale("s").magnitude for s in signal] + self.end_time = [s.t_stop.rescale("s").magnitude for s in signal] + + @staticmethod + def _get_meta_analog(signal): + if len(signal.shape) == 1: + nap_type = nap.Tsd + elif len(signal.shape) == 2: + nap_type = nap.TsdFrame + else: + nap_type = nap.TsdTensor + return nap_type + + def __getitem__(self, item): + if isinstance(item, slice): + return self._get_from_slice(item) + raise ValueError(f"Cannot get item {item}.") + + def get(self, start, stop): + """Get data between start and stop times.""" + if self.is_analog: + return self._get_analog(start, stop) + elif issubclass(self.nap_type, nap.Ts): + return self._get_ts(self._sig_num, start, stop) + else: # TsGroup + return self._get_tsgroup(start, stop) + + def restrict(self, epoch): + """Restrict data to epochs.""" + if self.is_analog: + return self._restrict_analog(epoch) + elif issubclass(self.nap_type, nap.Ts): + return self._restrict_ts(epoch) + else: # TsGroup + return self._restrict_tsgroup(epoch) + + def _get_from_slice(self, slc): + start = slc.start if slc.start is not None else 0 + stop = slc.stop + step = slc.step if slc.step is not None else 1 + + if self.is_analog: + if stop is None: + stop = sum(s.analogsignals[self._sig_num].shape[0] for s in self._block.segments) + return self._slice_segment_analog(start, stop, step) + elif issubclass(self.nap_type, nap.Ts): + if stop is None: + stop = sum(len(s.spiketrains[self._sig_num]) for s in self._block.segments) + return self._slice_segment_ts(start, stop, step) + else: + raise ValueError("Cannot slice a TsGroup.") + + def _instantiate_nap(self, time, data, time_support): + return self.nap_type( + t=time, + d=data, + time_support=time_support, + ) + + def _concatenate_array(self, time_list, data_list): + if len(data_list) == 0: + return np.array([]), np.array([]).reshape((0, *self.shape[1:]) if len(self.shape) > 1 else (0, 1)) + else: + return 
np.concatenate(time_list), np.concatenate(data_list, axis=0) + + # ========== Analog Signal Methods ========== + + def _get_analog(self, start, stop, return_array=False): + """Get analog signal between start and stop times.""" + data = [] + time = [] + + for i, seg in enumerate(self._block.segments): + signal = seg.analogsignals[self._sig_num] + + # Get segment boundaries + seg_start = self.time_support.start[i] + seg_stop = self.time_support.end[i] + + # Check if requested time overlaps with this segment + if start >= seg_stop or stop <= seg_start: + continue # No overlap, skip this segment + + # Clip to segment bounds + chunk_start = max(start, seg_start) + chunk_stop = min(stop, seg_stop) + + chunk = signal.time_slice(chunk_start, chunk_stop) + + if chunk.shape[0] > 0: # Has data + data.append(chunk.magnitude) + time.append(chunk.times.rescale("s").magnitude) + + time, data = self._concatenate_array(time, data) + if not return_array: + return self._instantiate_nap(time, data, time_support=self.time_support) + else: + return time, data + + def _restrict_analog(self, epoch): + """Restrict analog signal to epochs.""" + time = [] + data = [] + + for start, end in epoch.values: + time_ep, data_ep = self._get_analog(start, end, return_array=True) + time.append(time_ep) + data.append(data_ep) + + time, data = self._concatenate_array(time, data) + return self._instantiate_nap(time, data, self.time_support).restrict(epoch) + + def _slice_segment_analog(self, start_idx, stop_idx, step): + """Load by exact indices from each segment.""" + data = [] + time = [] + + for i, seg in enumerate(self._block.segments): + signal = seg.analogsignals[self._sig_num] + + # Segment boundaries from time_support (already in seconds) + seg_start_time = self.time_support.start[i] + seg_end_time = self.time_support.end[i] + seg_duration = seg_end_time - seg_start_time + seg_n_samples = signal.shape[0] + + # Actual dt for this segment + dt = seg_duration / seg_n_samples + + # Clip indices to segment bounds + seg_start_idx = max(0, start_idx) + seg_stop_idx = min(seg_n_samples, stop_idx) + + if seg_start_idx >= seg_stop_idx: + continue # No overlap with this segment + + # Load full segment and slice exactly + try: + signal_loaded = signal.load() + chunk = signal_loaded[seg_start_idx:seg_stop_idx:step] + + except MemoryError: + # Fallback: use time_slice + chunk_start_time = seg_start_time + seg_start_idx * dt + chunk_stop_time = seg_start_time + seg_stop_idx * dt + chunk = signal.time_slice(chunk_start_time, chunk_stop_time) + + if step != 1: + chunk = chunk[::step] + + data.append(chunk.magnitude) + time.append(chunk.times.rescale("s").magnitude) + + time, data = self._concatenate_array(time, data) + return self._instantiate_nap(time, data, time_support=self.time_support) + + # ========== Spike Train (Ts) Methods ========== + + def _get_ts(self, unit_idx, start, stop, return_array=False): + """Get spike times for a unit within time range.""" + spikes = [] + + for i, seg in enumerate(self._block.segments): + spiketrain = seg.spiketrains[unit_idx] + + # Get segment boundaries + seg_start = self.time_support.start[i] + seg_stop = self.time_support.end[i] + + # Check if requested time overlaps with this segment + if start >= seg_stop or stop <= seg_start: + continue # No overlap + + # Clip to segment bounds + chunk_start = max(start, seg_start) + chunk_stop = min(stop, seg_stop) + + chunk = spiketrain.time_slice(chunk_start, chunk_stop) + + if len(chunk) > 0: # Has spikes + spikes.append(chunk.times.rescale("s").magnitude) 
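
The per-segment retrieval above follows one pattern throughout: intersect the requested window with each segment, slice the overlapping part, and concatenate the pieces. Below is a minimal sketch of that logic using synthetic in-memory arrays only (no Neo objects; the segment bounds and sampling rate are invented for illustration):

import numpy as np

# Hypothetical segment bounds (seconds) standing in for Neo segments.
segments = [(0.0, 10.0), (20.0, 30.0)]
fs = 100.0

def get_window(start, stop):
    """Clip [start, stop] against every segment and concatenate the overlaps."""
    times, data = [], []
    for seg_start, seg_stop in segments:
        if start >= seg_stop or stop <= seg_start:
            continue  # no overlap with this segment, same test as _get_analog
        chunk_start = max(start, seg_start)
        chunk_stop = min(stop, seg_stop)
        t = np.arange(chunk_start, chunk_stop, 1.0 / fs)
        times.append(t)
        data.append(np.sin(t))  # placeholder samples
    return np.concatenate(times), np.concatenate(data)

t, d = get_window(8.0, 22.0)  # spans the gap, so both segments contribute
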
+ + spike_times = np.concatenate(spikes) if spikes else np.array([]) + + if return_array: + return spike_times + else: + return nap.Ts(t=spike_times, time_support=self.time_support) + + def _restrict_ts(self, epoch): + """Restrict spike train to epochs.""" + spikes = [] + + for start, end in epoch.values: + spike_times = self._get_ts(self._sig_num, start, end, return_array=True) + if len(spike_times) > 0: + spikes.append(spike_times) + + spike_times = np.concatenate(spikes) if spikes else np.array([]) + return nap.Ts(t=spike_times, time_support=self.time_support).restrict(epoch) + + def _slice_segment_ts(self, start_idx, stop_idx, step): + """Slice spike trains by spike index.""" + spikes = [] + + for i, seg in enumerate(self._block.segments): + spiketrain = seg.spiketrains[self._sig_num] + + # Get number of spikes in this segment + n_spikes = len(spiketrain) + + # Clip indices to segment bounds + seg_start_idx = max(0, start_idx) + seg_stop_idx = min(n_spikes, stop_idx) + + if seg_start_idx >= seg_stop_idx: + continue # No overlap + + # Load and slice by spike index + spiketrain_loaded = spiketrain.load() if hasattr(spiketrain, 'load') else spiketrain + chunk = spiketrain_loaded[seg_start_idx:seg_stop_idx:step] + + spikes.append(chunk.times.rescale("s").magnitude) + + return nap.Ts( + t=np.concatenate(spikes) if spikes else np.array([]), + time_support=self.time_support + ) + + # ========== TsGroup Methods ========== + + def _get_tsgroup(self, start, stop): + """Get TsGroup (all units) within time range.""" + n_units = len(self._block.segments[0].spiketrains) + ts_dict = {} + + for unit_idx in range(n_units): + spike_times = self._get_ts(unit_idx, start, stop, return_array=True) + ts_dict[unit_idx] = nap.Ts(t=spike_times, time_support=self.time_support) + + return nap.TsGroup(ts_dict, time_support=self.time_support) + + def _restrict_tsgroup(self, epoch): + """Restrict TsGroup to epochs.""" + n_units = len(self._block.segments[0].spiketrains) + ts_dict = {} + + for unit_idx in range(n_units): + spikes = [] + for start, end in epoch.values: + spike_times = self._get_ts(unit_idx, start, end, return_array=True) + if len(spike_times) > 0: + spikes.append(spike_times) + + spike_times = np.concatenate(spikes) if spikes else np.array([]) + ts_dict[unit_idx] = nap.Ts(t=spike_times, time_support=self.time_support) + + return nap.TsGroup(ts_dict, time_support=self.time_support).restrict(epoch) + + +class NEOExperimentInterface: + def __init__(self, reader, lazy=False): + # block, aka experiments (contains multiple segments, aka trials) + self._reader = reader + self._lazy = lazy + self.experiment = self._collect_time_series_info() + self._reader = reader + + def _collect_time_series_info(self): + blocks = self._reader.read(lazy=self._lazy) + + experiments = {} + for i, block in enumerate(blocks): + name = f"block {i}" + if block.name: + name += ": " + block.name + experiments[name] = {} + # loop once to get the time support + starts, ends = np.empty(len(block.segments)), np.empty(len(block.segments)) + for trial_num, segment in enumerate(block.segments): + starts[trial_num] = segment.t_start.rescale("s").magnitude + ends[trial_num] = segment.t_stop.rescale("s").magnitude + + iset = nap.IntervalSet(starts, ends) + for trial_num, segment in enumerate(block.segments): + # segment may contain epoch (potentially overlapping) + # with fields: times, durations, labels. We may add them to metadata. 
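
The time support built here is just the segment boundaries expressed in seconds. A self-contained sketch with an in-memory block (the two ten-second trials are invented for illustration):

import numpy as np
import quantities as pq
import neo
import pynapple as nap

block = neo.Block(name="demo")
for i in range(2):
    seg = neo.Segment(name=f"trial {i}")
    # One signal per segment so that seg.t_start / seg.t_stop are defined.
    seg.analogsignals.append(
        neo.AnalogSignal(
            np.random.randn(1000, 1),
            units="mV",
            sampling_rate=100.0 * pq.Hz,
            t_start=(i * 10.0) * pq.s,
        )
    )
    block.segments.append(seg)

starts = np.array([s.t_start.rescale("s").magnitude for s in block.segments])
ends = np.array([s.t_stop.rescale("s").magnitude for s in block.segments])
iset = nap.IntervalSet(starts, ends)  # one interval per segment / trial
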
+ + # tsd/tsdFrame/TsdTensor + for signal_num, signal in enumerate(segment.analogsignals): + if signal.name: + signame = f" {signal_num}: " + signal.name + else: + signame = f" {signal_num}" + signal_interface = NEOSignalInterface(signal, block, iset, sig_num=signal_num) + signame = signal_interface.nap_type.__name__ + signame + experiments[name][signame] = signal_interface + + if len(segment.spiketrains) == 1: + signal = segment.spiketrains[0] + signal_interface = NEOSignalInterface(signal, block, iset, sig_num=0) + signame = f"Ts" + ": " + signal.name if signal.name else "Ts" + experiments[name][signame] = signal_interface + else: + signame = f"TsGroup" + experiments[name][signame] = NEOSignalInterface(segment.spiketrains, block, iset) + return experiments + + def __getitem__(self, item): + if isinstance(item, str): + return self.experiment[item] + else: + res = self.experiment + for it in item: + res = res[it] + return res + + def keys(self): + return [(k, k2) for k in self.experiment.keys() for k2 in self.experiment[k]] + + +def load_experiment(path: str | pathlib.Path, lazy: bool = True) -> NEOExperimentInterface: + """ + Load a neural recording experiment. + + Parameters + ---------- + path : str or Path + Path to the recording file + lazy : bool, default True + Whether to lazy load the data + + Returns + ------- + NEOExperimentInterface + """ + path = pathlib.Path(path) + reader = neo.io.get_io(path) + + return NEOExperimentInterface(reader, lazy=lazy) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 1470f95e2..38e13f95d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ "h5py", "rich", "xarray>=2023.1.0", + "neo", ] requires-python = ">=3.8" From ed7dcee7c3fbe2aa398003730ba91adb1eab658f Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Thu, 8 Jan 2026 17:07:00 -0500 Subject: [PATCH 2/7] NeoReader --- pynapple/io/__init__.py | 8 + pynapple/io/neo.py | 1230 ++++++++++++++++++++++++++++++++++++--- tests/test_neo.py | 563 ++++++++++++++++++ 3 files changed, 1725 insertions(+), 76 deletions(-) create mode 100644 tests/test_neo.py diff --git a/pynapple/io/__init__.py b/pynapple/io/__init__.py index f4eb2a70b..f7991d9ca 100644 --- a/pynapple/io/__init__.py +++ b/pynapple/io/__init__.py @@ -8,3 +8,11 @@ load_folder, load_session, ) +from .neo import ( + NeoReader, + load_file as load_neo_file, + # to_neo_analogsignal, + # to_neo_spiketrain, + # to_neo_epoch, + # to_neo_event, +) diff --git a/pynapple/io/neo.py b/pynapple/io/neo.py index feee0de3d..b6f52063d 100644 --- a/pynapple/io/neo.py +++ b/pynapple/io/neo.py @@ -1,71 +1,531 @@ -import neo +""" +Pynapple interface for Neo (neural electrophysiology objects). + +Neo is a Python package for working with electrophysiology data in Python, +supporting many file formats through a unified API. + +Data are lazy-loaded by default. The interface behaves like a dictionary. 
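
A short orientation sketch of that dictionary-style access (the file name is hypothetical and the keys shown are only examples; the actual keys depend on the file's contents):

import pynapple as nap

data = nap.io.neo.load_file("recording.plx")  # lazy by default
print(data.keys())                            # e.g. ["TsGroup", "Tsd 0: LFP"]
spikes = data["TsGroup"]                      # converted and cached on first access
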
+ +For more information on Neo, see: https://neo.readthedocs.io/ + +Neo to Pynapple Object Conversion +--------------------------------- +The following Neo objects are converted to their pynapple equivalents: + +- neo.AnalogSignal -> Tsd, TsdFrame, or TsdTensor (depending on shape) +- neo.IrregularlySampledSignal -> Tsd, TsdFrame, or TsdTensor (depending on shape) +- neo.SpikeTrain -> Ts +- neo.SpikeTrain (list) -> TsGroup +- neo.SpikeTrainList -> TsGroup +- neo.Epoch -> IntervalSet +- neo.Event -> Ts +""" + +import warnings +from collections import UserDict +from pathlib import Path +from typing import Union, Optional, Dict, Any, List + import numpy as np -import pynapple as nap -from neo.io.proxyobjects import AnalogSignalProxy, SpikeTrainProxy -from neo.core.spiketrainlist import SpikeTrainList -import pathlib +try: + import neo + from neo.io.proxyobjects import ( + AnalogSignalProxy, + SpikeTrainProxy, + EpochProxy, + EventProxy, + ) + from neo.core.spiketrainlist import SpikeTrainList + + HAS_NEO = True +except ImportError: + HAS_NEO = False + +try: + from tabulate import tabulate + + HAS_TABULATE = True +except ImportError: + HAS_TABULATE = False + +from .. import core as nap + + +def _check_neo_installed(): + """Check if neo is installed and raise ImportError if not.""" + if not HAS_NEO: + raise ImportError( + "Neo is required for this functionality. " + "Install it with: pip install neo" + ) + + +def _rescale_to_seconds(quantity): + """Convert a neo quantity to seconds. + + Parameters + ---------- + quantity : neo.core.baseneo.BaseNeo or quantities.Quantity + A quantity with time units + + Returns + ------- + float + Value in seconds + """ + return float(quantity.rescale("s").magnitude) + + +def _get_signal_type(signal) -> type: + """Determine the appropriate pynapple type for a Neo signal. + + Parameters + ---------- + signal : neo.AnalogSignal or neo.IrregularlySampledSignal + The Neo signal object + + Returns + ------- + type + The pynapple type (Tsd, TsdFrame, or TsdTensor) + """ + if len(signal.shape) == 1: + return nap.Tsd + elif len(signal.shape) == 2: + if signal.shape[1] == 1: + return nap.Tsd + return nap.TsdFrame + else: + return nap.TsdTensor + + +def _extract_annotations(obj) -> Dict[str, Any]: + """Extract annotations from a Neo object. + + Parameters + ---------- + obj : neo.core.baseneo.BaseNeo + Any Neo object with annotations + + Returns + ------- + dict + Dictionary of annotations + """ + annotations = {} + if hasattr(obj, "annotations") and obj.annotations: + annotations.update(obj.annotations) + if hasattr(obj, "name") and obj.name: + annotations["neo_name"] = obj.name + if hasattr(obj, "description") and obj.description: + annotations["neo_description"] = obj.description + return annotations + + +def _extract_array_annotations(obj) -> Dict[str, np.ndarray]: + """Extract array annotations from a Neo object. + + Parameters + ---------- + obj : neo.core.baseneo.BaseNeo + Any Neo object with array_annotations + + Returns + ------- + dict + Dictionary of array annotations + """ + if hasattr(obj, "array_annotations") and obj.array_annotations: + return dict(obj.array_annotations) + return {} + + +# ============================================================================= +# Conversion functions: Neo -> Pynapple +# ============================================================================= + + +def _make_intervalset_from_epoch(epoch, time_support: Optional[nap.IntervalSet] = None) -> nap.IntervalSet: + """Convert a Neo Epoch to a pynapple IntervalSet. 
+ + Parameters + ---------- + epoch : neo.Epoch or neo.io.proxyobjects.EpochProxy + Neo Epoch object + time_support : IntervalSet, optional + Time support for the IntervalSet + + Returns + ------- + IntervalSet + Pynapple IntervalSet + """ + if hasattr(epoch, "load"): + epoch = epoch.load() + + times = epoch.times.rescale("s").magnitude + durations = epoch.durations.rescale("s").magnitude + + starts = times + ends = times + durations + + # Extract labels as metadata if available + metadata = {} + if hasattr(epoch, "labels") and len(epoch.labels) > 0: + metadata["label"] = np.array(epoch.labels) + + # Add any other annotations + annotations = _extract_annotations(epoch) + + iset = nap.IntervalSet(start=starts, end=ends, metadata=metadata) + + return iset + + +def _make_ts_from_event(event, time_support: Optional[nap.IntervalSet] = None) -> nap.Ts: + """Convert a Neo Event to a pynapple Ts. + + Parameters + ---------- + event : neo.Event or neo.io.proxyobjects.EventProxy + Neo Event object + time_support : IntervalSet, optional + Time support for the Ts + + Returns + ------- + Ts + Pynapple Ts object + """ + if hasattr(event, "load"): + event = event.load() + + times = event.times.rescale("s").magnitude + + return nap.Ts(t=times, time_support=time_support) + + +def _make_tsd_from_analog( + signal, + time_support: Optional[nap.IntervalSet] = None, + column_names: Optional[List[str]] = None, +) -> Union[nap.Tsd, nap.TsdFrame, nap.TsdTensor]: + """Convert a Neo AnalogSignal to a pynapple Tsd/TsdFrame/TsdTensor. + + Parameters + ---------- + signal : neo.AnalogSignal or AnalogSignalProxy + Neo analog signal + time_support : IntervalSet, optional + Time support + column_names : list of str, optional + Column names for TsdFrame + + Returns + ------- + Tsd, TsdFrame, or TsdTensor + Appropriate pynapple time series object + """ + if hasattr(signal, "load"): + signal = signal.load() + + times = signal.times.rescale("s").magnitude + data = signal.magnitude + + nap_type = _get_signal_type(signal) + + if nap_type == nap.Tsd: + if len(data.shape) == 2: + data = data.squeeze() + return nap.Tsd(t=times, d=data, time_support=time_support) + elif nap_type == nap.TsdFrame: + if column_names is None: + # Try to get channel names from annotations + if hasattr(signal, "array_annotations"): + channel_names = signal.array_annotations.get("channel_names", None) + if channel_names is not None: + column_names = list(channel_names) + return nap.TsdFrame(t=times, d=data, columns=column_names, time_support=time_support) + else: + return nap.TsdTensor(t=times, d=data, time_support=time_support) + + +def _make_tsd_from_irregular( + signal, + time_support: Optional[nap.IntervalSet] = None, + column_names: Optional[List[str]] = None, +) -> Union[nap.Tsd, nap.TsdFrame, nap.TsdTensor]: + """Convert a Neo IrregularlySampledSignal to a pynapple Tsd/TsdFrame/TsdTensor. 
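
Both the analog and the irregularly sampled converters dispatch on the signal's shape. A hand-rolled sketch of that mapping for a synthetic three-channel AnalogSignal (a single-channel signal would map to Tsd instead):

import numpy as np
import quantities as pq
import neo
import pynapple as nap

sig = neo.AnalogSignal(
    np.random.randn(1000, 3),
    units="uV",
    sampling_rate=1000.0 * pq.Hz,
    t_start=0.0 * pq.s,
)
t = sig.times.rescale("s").magnitude   # timestamps in seconds
d = sig.magnitude                      # (n_samples, n_channels) array
lfp = nap.TsdFrame(t=t, d=d)           # 2D with more than one column -> TsdFrame
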
+ + Parameters + ---------- + signal : neo.IrregularlySampledSignal + Neo irregularly sampled signal + time_support : IntervalSet, optional + Time support + column_names : list of str, optional + Column names for TsdFrame + + Returns + ------- + Tsd, TsdFrame, or TsdTensor + Appropriate pynapple time series object + """ + if hasattr(signal, "load"): + signal = signal.load() + + times = signal.times.rescale("s").magnitude + data = signal.magnitude + + nap_type = _get_signal_type(signal) + + if nap_type == nap.Tsd: + if len(data.shape) == 2: + data = data.squeeze() + return nap.Tsd(t=times, d=data, time_support=time_support) + elif nap_type == nap.TsdFrame: + return nap.TsdFrame(t=times, d=data, columns=column_names, time_support=time_support) + else: + return nap.TsdTensor(t=times, d=data, time_support=time_support) + + +def _make_ts_from_spiketrain( + spiketrain, time_support: Optional[nap.IntervalSet] = None +) -> nap.Ts: + """Convert a Neo SpikeTrain to a pynapple Ts. + + Parameters + ---------- + spiketrain : neo.SpikeTrain or SpikeTrainProxy + Neo spike train + time_support : IntervalSet, optional + Time support + + Returns + ------- + Ts + Pynapple Ts object + """ + if hasattr(spiketrain, "load"): + spiketrain = spiketrain.load() + + times = spiketrain.times.rescale("s").magnitude + + return nap.Ts(t=times, time_support=time_support) + + +def _make_tsgroup_from_spiketrains( + spiketrains: Union[list, "SpikeTrainList"], + time_support: Optional[nap.IntervalSet] = None, +) -> nap.TsGroup: + """Convert a list of Neo SpikeTrains to a pynapple TsGroup. + + Parameters + ---------- + spiketrains : list of neo.SpikeTrain or SpikeTrainList + List of spike trains + time_support : IntervalSet, optional + Time support + + Returns + ------- + TsGroup + Pynapple TsGroup + """ + ts_dict = {} + metadata = {} + + for i, st in enumerate(spiketrains): + if hasattr(st, "load"): + st = st.load() + + times = st.times.rescale("s").magnitude + ts_dict[i] = nap.Ts(t=times, time_support=time_support) + + # Collect metadata from annotations + for key, value in _extract_annotations(st).items(): + if key not in metadata: + metadata[key] = [] + metadata[key].append(value) + + # Convert metadata lists to arrays + meta_arrays = {} + for key, values in metadata.items(): + try: + meta_arrays[key] = np.array(values) + except (ValueError, TypeError): + # Skip metadata that can't be converted to array + pass + + return nap.TsGroup(ts_dict, time_support=time_support, **meta_arrays) + + + +# ============================================================================= +# Signal Interface for lazy loading +# ============================================================================= -class NEOSignalInterface: + +class NeoSignalInterface: + """Interface for lazy-loading Neo signals into pynapple objects. + + This class provides lazy access to Neo signals, loading data only when + requested via `get()` or `restrict()` methods. + + Parameters + ---------- + signal : neo signal object + A Neo signal (AnalogSignal, SpikeTrain, etc.) 
+ block : neo.Block + The parent block containing the signal + time_support : IntervalSet + Time support for the data + sig_num : int, optional + Index of the signal within the segment + + Attributes + ---------- + nap_type : type + The pynapple type this signal will be converted to + is_analog : bool + Whether this is an analog signal + dt : float + Sampling interval (for analog signals) + shape : tuple + Shape of the data + start_time : float or list + Start time(s) + end_time : float or list + End time(s) + """ def __init__(self, signal, block, time_support, sig_num=None): self.time_support = time_support + self._block = block + self._sig_num = sig_num + + # Determine signal type and pynapple mapping if isinstance(signal, (neo.AnalogSignal, AnalogSignalProxy)): self.is_analog = True - self.nap_type = self._get_meta_analog(signal) + self.nap_type = _get_signal_type(signal) + self._signal_type = "analog" + elif hasattr(neo, "IrregularlySampledSignal") and isinstance( + signal, neo.IrregularlySampledSignal + ): + self.is_analog = False # Irregularly sampled + self.nap_type = _get_signal_type(signal) + self._signal_type = "irregular" elif isinstance(signal, (neo.SpikeTrain, SpikeTrainProxy)): self.nap_type = nap.Ts self.is_analog = False + self._signal_type = "spiketrain" elif isinstance(signal, (list, SpikeTrainList)): self.nap_type = nap.TsGroup self.is_analog = False + self._signal_type = "tsgroup" + elif isinstance(signal, (neo.Epoch,)) or (hasattr(neo.io, "proxyobjects") and isinstance(signal, EpochProxy)): + self.nap_type = nap.IntervalSet + self.is_analog = False + self._signal_type = "epoch" + elif isinstance(signal, (neo.Event,)) or (hasattr(neo.io, "proxyobjects") and isinstance(signal, EventProxy)): + self.nap_type = nap.Ts + self.is_analog = False + self._signal_type = "event" else: - raise TypeError(f"signal type {type(signal)} not recognized.") - self._block = block - self._sig_num = sig_num + raise TypeError(f"Signal type {type(signal)} not recognized.") + # Store timing info if self.is_analog: self.dt = (1 / signal.sampling_rate).rescale("s").magnitude self.shape = signal.shape - if not issubclass(self.nap_type, nap.TsGroup): - self.start_time = signal.t_start.rescale("s").magnitude - self.end_time = signal.t_stop.rescale("s").magnitude - else: - self.start_time = [s.t_start.rescale("s").magnitude for s in signal] - self.end_time = [s.t_stop.rescale("s").magnitude for s in signal] - - @staticmethod - def _get_meta_analog(signal): - if len(signal.shape) == 1: - nap_type = nap.Tsd - elif len(signal.shape) == 2: - nap_type = nap.TsdFrame + elif self._signal_type == "irregular": + self.shape = signal.shape + + if self._signal_type not in ("tsgroup",): + self.start_time = _rescale_to_seconds(signal.t_start) + self.end_time = _rescale_to_seconds(signal.t_stop) else: - nap_type = nap.TsdTensor - return nap_type + self.start_time = [_rescale_to_seconds(s.t_start) for s in signal] + self.end_time = [_rescale_to_seconds(s.t_stop) for s in signal] + + def __repr__(self): + return f"" def __getitem__(self, item): if isinstance(item, slice): return self._get_from_slice(item) raise ValueError(f"Cannot get item {item}.") - def get(self, start, stop): - """Get data between start and stop times.""" + def get(self, start: float, stop: float): + """Get data between start and stop times. 
+ + Parameters + ---------- + start : float + Start time in seconds + stop : float + Stop time in seconds + + Returns + ------- + pynapple object + Data restricted to the time range + """ if self.is_analog: return self._get_analog(start, stop) - elif issubclass(self.nap_type, nap.Ts): + elif self._signal_type == "irregular": + return self._get_irregular(start, stop) + elif self._signal_type == "spiketrain": return self._get_ts(self._sig_num, start, stop) - else: # TsGroup + elif self._signal_type == "tsgroup": return self._get_tsgroup(start, stop) - - def restrict(self, epoch): - """Restrict data to epochs.""" + elif self._signal_type == "epoch": + return self._get_epoch(start, stop) + elif self._signal_type == "event": + return self._get_event(start, stop) + + def load(self): + """Load all data. + + Returns + ------- + pynapple object + The fully loaded data + """ + start = float(self.time_support.start[0]) + end = float(self.time_support.end[-1]) + return self.get(start, end) + + def restrict(self, epoch: nap.IntervalSet): + """Restrict data to epochs. + + Parameters + ---------- + epoch : IntervalSet + Epochs to restrict to + + Returns + ------- + pynapple object + Data restricted to the epochs + """ if self.is_analog: return self._restrict_analog(epoch) - elif issubclass(self.nap_type, nap.Ts): + elif self._signal_type == "irregular": + return self._restrict_irregular(epoch) + elif self._signal_type == "spiketrain": return self._restrict_ts(epoch) - else: # TsGroup + elif self._signal_type == "tsgroup": return self._restrict_tsgroup(epoch) + elif self._signal_type == "epoch": + return self._get_epoch( + float(epoch.start[0]), float(epoch.end[-1]) + ).restrict(epoch) + elif self._signal_type == "event": + return self._get_event( + float(epoch.start[0]), float(epoch.end[-1]) + ).restrict(epoch) def _get_from_slice(self, slc): start = slc.start if slc.start is not None else 0 @@ -74,14 +534,19 @@ def _get_from_slice(self, slc): if self.is_analog: if stop is None: - stop = sum(s.analogsignals[self._sig_num].shape[0] for s in self._block.segments) + stop = sum( + s.analogsignals[self._sig_num].shape[0] + for s in self._block.segments + ) return self._slice_segment_analog(start, stop, step) - elif issubclass(self.nap_type, nap.Ts): + elif self._signal_type == "spiketrain": if stop is None: - stop = sum(len(s.spiketrains[self._sig_num]) for s in self._block.segments) + stop = sum( + len(s.spiketrains[self._sig_num]) for s in self._block.segments + ) return self._slice_segment_ts(start, stop, step) else: - raise ValueError("Cannot slice a TsGroup.") + raise ValueError(f"Cannot slice a {self._signal_type}.") def _instantiate_nap(self, time, data, time_support): return self.nap_type( @@ -92,7 +557,10 @@ def _instantiate_nap(self, time, data, time_support): def _concatenate_array(self, time_list, data_list): if len(data_list) == 0: - return np.array([]), np.array([]).reshape((0, *self.shape[1:]) if len(self.shape) > 1 else (0, 1)) + shape = getattr(self, "shape", (0, 1)) + return np.array([]), np.array([]).reshape( + (0, *shape[1:]) if len(shape) > 1 else (0,) + ) else: return np.concatenate(time_list), np.concatenate(data_list, axis=0) @@ -106,21 +574,18 @@ def _get_analog(self, start, stop, return_array=False): for i, seg in enumerate(self._block.segments): signal = seg.analogsignals[self._sig_num] - # Get segment boundaries seg_start = self.time_support.start[i] seg_stop = self.time_support.end[i] - # Check if requested time overlaps with this segment if start >= seg_stop or stop <= seg_start: 
- continue # No overlap, skip this segment + continue - # Clip to segment bounds chunk_start = max(start, seg_start) chunk_stop = min(stop, seg_stop) chunk = signal.time_slice(chunk_start, chunk_stop) - if chunk.shape[0] > 0: # Has data + if chunk.shape[0] > 0: data.append(chunk.magnitude) time.append(chunk.times.rescale("s").magnitude) @@ -151,33 +616,26 @@ def _slice_segment_analog(self, start_idx, stop_idx, step): for i, seg in enumerate(self._block.segments): signal = seg.analogsignals[self._sig_num] - # Segment boundaries from time_support (already in seconds) seg_start_time = self.time_support.start[i] seg_end_time = self.time_support.end[i] seg_duration = seg_end_time - seg_start_time seg_n_samples = signal.shape[0] - # Actual dt for this segment dt = seg_duration / seg_n_samples - # Clip indices to segment bounds seg_start_idx = max(0, start_idx) seg_stop_idx = min(seg_n_samples, stop_idx) if seg_start_idx >= seg_stop_idx: - continue # No overlap with this segment + continue - # Load full segment and slice exactly try: signal_loaded = signal.load() chunk = signal_loaded[seg_start_idx:seg_stop_idx:step] - except MemoryError: - # Fallback: use time_slice chunk_start_time = seg_start_time + seg_start_idx * dt chunk_stop_time = seg_start_time + seg_stop_idx * dt chunk = signal.time_slice(chunk_start_time, chunk_stop_time) - if step != 1: chunk = chunk[::step] @@ -187,6 +645,82 @@ def _slice_segment_analog(self, start_idx, stop_idx, step): time, data = self._concatenate_array(time, data) return self._instantiate_nap(time, data, time_support=self.time_support) + # ========== Irregularly Sampled Signal Methods ========== + + def _get_irregular(self, start, stop, return_array=False): + """Get irregularly sampled signal between start and stop times.""" + data = [] + time = [] + + for i, seg in enumerate(self._block.segments): + signal = seg.irregularlysampledsignals[self._sig_num] + + seg_start = self.time_support.start[i] + seg_stop = self.time_support.end[i] + + if start >= seg_stop or stop <= seg_start: + continue + + chunk_start = max(start, seg_start) + chunk_stop = min(stop, seg_stop) + + chunk = signal.time_slice(chunk_start, chunk_stop) + + if chunk.shape[0] > 0: + data.append(chunk.magnitude) + time.append(chunk.times.rescale("s").magnitude) + + if len(time) == 0: + time = np.array([]) + data = np.array([]) + else: + time = np.concatenate(time) + data = np.concatenate(data, axis=0) + + if not return_array: + if data.ndim == 1 or (data.ndim == 2 and data.shape[1] == 1): + if data.ndim == 2: + data = data.squeeze() + return nap.Tsd(t=time, d=data, time_support=self.time_support) + elif data.ndim == 2: + return nap.TsdFrame(t=time, d=data, time_support=self.time_support) + else: + return nap.TsdTensor(t=time, d=data, time_support=self.time_support) + else: + return time, data + + def _restrict_irregular(self, epoch): + """Restrict irregularly sampled signal to epochs.""" + time = [] + data = [] + + for start, end in epoch.values: + time_ep, data_ep = self._get_irregular(start, end, return_array=True) + if len(time_ep) > 0: + time.append(time_ep) + data.append(data_ep) + + if len(time) == 0: + return nap.Tsd(t=np.array([]), d=np.array([]), time_support=epoch) + + time = np.concatenate(time) + data = np.concatenate(data, axis=0) + + if data.ndim == 1 or (data.ndim == 2 and data.shape[1] == 1): + if data.ndim == 2: + data = data.squeeze() + return nap.Tsd(t=time, d=data, time_support=self.time_support).restrict( + epoch + ) + elif data.ndim == 2: + return nap.TsdFrame( + t=time, d=data, 
time_support=self.time_support + ).restrict(epoch) + else: + return nap.TsdTensor( + t=time, d=data, time_support=self.time_support + ).restrict(epoch) + # ========== Spike Train (Ts) Methods ========== def _get_ts(self, unit_idx, start, stop, return_array=False): @@ -196,21 +730,18 @@ def _get_ts(self, unit_idx, start, stop, return_array=False): for i, seg in enumerate(self._block.segments): spiketrain = seg.spiketrains[unit_idx] - # Get segment boundaries seg_start = self.time_support.start[i] seg_stop = self.time_support.end[i] - # Check if requested time overlaps with this segment if start >= seg_stop or stop <= seg_start: - continue # No overlap + continue - # Clip to segment bounds chunk_start = max(start, seg_start) chunk_stop = min(stop, seg_stop) chunk = spiketrain.time_slice(chunk_start, chunk_stop) - if len(chunk) > 0: # Has spikes + if len(chunk) > 0: spikes.append(chunk.times.rescale("s").magnitude) spike_times = np.concatenate(spikes) if spikes else np.array([]) @@ -239,25 +770,24 @@ def _slice_segment_ts(self, start_idx, stop_idx, step): for i, seg in enumerate(self._block.segments): spiketrain = seg.spiketrains[self._sig_num] - # Get number of spikes in this segment n_spikes = len(spiketrain) - # Clip indices to segment bounds seg_start_idx = max(0, start_idx) seg_stop_idx = min(n_spikes, stop_idx) if seg_start_idx >= seg_stop_idx: - continue # No overlap + continue - # Load and slice by spike index - spiketrain_loaded = spiketrain.load() if hasattr(spiketrain, 'load') else spiketrain + spiketrain_loaded = ( + spiketrain.load() if hasattr(spiketrain, "load") else spiketrain + ) chunk = spiketrain_loaded[seg_start_idx:seg_stop_idx:step] spikes.append(chunk.times.rescale("s").magnitude) return nap.Ts( t=np.concatenate(spikes) if spikes else np.array([]), - time_support=self.time_support + time_support=self.time_support, ) # ========== TsGroup Methods ========== @@ -290,14 +820,317 @@ def _restrict_tsgroup(self, epoch): return nap.TsGroup(ts_dict, time_support=self.time_support).restrict(epoch) + # ========== Epoch Methods ========== + + def _get_epoch(self, start, stop): + """Get epochs within time range.""" + all_starts = [] + all_ends = [] + all_labels = [] + + for i, seg in enumerate(self._block.segments): + for epoch in seg.epochs: + if hasattr(epoch, "load"): + epoch = epoch.load() + + times = epoch.times.rescale("s").magnitude + durations = epoch.durations.rescale("s").magnitude + + for t, d, lbl in zip(times, durations, epoch.labels): + ep_start = t + ep_end = t + d + + # Check overlap with requested range + if ep_end > start and ep_start < stop: + all_starts.append(max(ep_start, start)) + all_ends.append(min(ep_end, stop)) + all_labels.append(lbl) + + if len(all_starts) == 0: + return nap.IntervalSet(start=[], end=[]) + + return nap.IntervalSet( + start=np.array(all_starts), + end=np.array(all_ends), + metadata={"label": np.array(all_labels)} if all_labels else None, + ) + + # ========== Event Methods ========== + + def _get_event(self, start, stop): + """Get events within time range.""" + all_times = [] + + for i, seg in enumerate(self._block.segments): + for event in seg.events: + if hasattr(event, "load"): + event = event.load() + + times = event.times.rescale("s").magnitude + + mask = (times >= start) & (times <= stop) + all_times.extend(times[mask]) + + return nap.Ts(t=np.array(all_times), time_support=self.time_support) + + +# ============================================================================= +# Main Interface Class +# 
============================================================================= + + +class NeoReader(UserDict): + """Class for reading Neo-compatible files. + + This class provides a dictionary-like interface to Neo files, with + lazy-loading support. It automatically detects the appropriate IO + based on the file extension. + + Parameters + ---------- + file : str or Path + Path to the file to load + lazy : bool, default True + Whether to use lazy loading + + Examples + -------- + >>> import pynapple as nap + >>> data = nap.io.NeoReader("my_file.plx") + >>> print(data) + my_file + +---------------------+----------+ + | Key | Type | + +=====================+==========+ + | TsGroup | TsGroup | + | Tsd 0: LFP | Tsd | + +---------------------+----------+ + + >>> spikes = data["TsGroup"] + >>> lfp = data["Tsd 0: LFP"] + """ + + def __init__(self, file: Union[str, Path], lazy: bool = True): + _check_neo_installed() + + self.path = Path(file) + if not self.path.exists(): + raise FileNotFoundError(f"File not found: {file}") + + self.name = self.path.stem + self._lazy = lazy + + # Get appropriate IO + self._reader = neo.io.get_io(str(self.path)) + + # Read blocks + self._blocks = self._reader.read(lazy=lazy) + + # Build data dictionary + self.data = {} + self._data_info = {} # Store type info for display + self._interfaces = {} # Store NeoSignalInterface objects + + self._collect_data() + + UserDict.__init__(self, self.data) + + def _collect_data(self): + """Collect all data from Neo blocks into the dictionary.""" + for block_idx, block in enumerate(self._blocks): + block_prefix = "" if len(self._blocks) == 1 else f"block{block_idx}/" + + # Build time support from segments + starts = np.array( + [_rescale_to_seconds(seg.t_start) for seg in block.segments] + ) + ends = np.array( + [_rescale_to_seconds(seg.t_stop) for seg in block.segments] + ) + time_support = nap.IntervalSet(starts, ends) + + # Process first segment to get signal info + # (assuming consistent structure across segments) + if len(block.segments) > 0: + seg = block.segments[0] + + # Analog signals + for sig_idx, signal in enumerate(seg.analogsignals): + nap_type = _get_signal_type(signal) + name = signal.name if signal.name else f"signal{sig_idx}" + key = f"{block_prefix}{nap_type.__name__} {sig_idx}: {name}" + + interface = NeoSignalInterface( + signal, block, time_support, sig_num=sig_idx + ) + self._interfaces[key] = interface + self.data[key] = {"type": nap_type.__name__, "interface": interface} + self._data_info[key] = nap_type.__name__ + + # Irregularly sampled signals + for sig_idx, signal in enumerate(seg.irregularlysampledsignals): + nap_type = _get_signal_type(signal) + name = signal.name if signal.name else f"irregular{sig_idx}" + key = f"{block_prefix}{nap_type.__name__} (irregular) {sig_idx}: {name}" + + interface = NeoSignalInterface( + signal, block, time_support, sig_num=sig_idx + ) + self._interfaces[key] = interface + self.data[key] = {"type": nap_type.__name__, "interface": interface} + self._data_info[key] = nap_type.__name__ + + # Spike trains + if len(seg.spiketrains) == 1: + st = seg.spiketrains[0] + name = st.name if st.name else "spikes" + key = f"{block_prefix}Ts: {name}" + + interface = NeoSignalInterface( + st, block, time_support, sig_num=0 + ) + self._interfaces[key] = interface + self.data[key] = {"type": "Ts", "interface": interface} + self._data_info[key] = "Ts" + elif len(seg.spiketrains) > 1: + key = f"{block_prefix}TsGroup" + + interface = NeoSignalInterface( + seg.spiketrains, block, time_support + 
) + self._interfaces[key] = interface + self.data[key] = {"type": "TsGroup", "interface": interface} + self._data_info[key] = "TsGroup" + + # Epochs + for ep_idx, epoch in enumerate(seg.epochs): + name = epoch.name if hasattr(epoch, "name") and epoch.name else f"epoch{ep_idx}" + key = f"{block_prefix}IntervalSet {ep_idx}: {name}" + + interface = NeoSignalInterface( + epoch, block, time_support, sig_num=ep_idx + ) + self._interfaces[key] = interface + self.data[key] = {"type": "IntervalSet", "interface": interface} + self._data_info[key] = "IntervalSet" + + # Events + for ev_idx, event in enumerate(seg.events): + name = event.name if hasattr(event, "name") and event.name else f"event{ev_idx}" + key = f"{block_prefix}Ts (event) {ev_idx}: {name}" + + interface = NeoSignalInterface( + event, block, time_support, sig_num=ev_idx + ) + self._interfaces[key] = interface + self.data[key] = {"type": "Ts", "interface": interface} + self._data_info[key] = "Ts" + + def __str__(self): + """String representation showing available data.""" + title = self.name + view = [[k, self._data_info[k]] for k in self.data.keys()] + headers = ["Key", "Type"] + + if HAS_TABULATE: + return title + "\n" + tabulate(view, headers=headers, tablefmt="mixed_outline") + else: + # Simple fallback without tabulate + lines = [title, "-" * len(title)] + for k, t in view: + lines.append(f" {k}: {t}") + return "\n".join(lines) + + def __repr__(self): + return self.__str__() + + def __getitem__(self, key: str): + """Get data by key, loading if necessary. + + Parameters + ---------- + key : str + Key for the data item + + Returns + ------- + pynapple object + The requested data (Ts, Tsd, TsdFrame, TsGroup, IntervalSet, etc.) + """ + if key not in self.data: + raise KeyError(f"Key '{key}' not found. Available keys: {list(self.data.keys())}") + + item = self.data[key] + + # If already loaded, return it + if not isinstance(item, dict): + return item + + # Load via interface + interface = item["interface"] + loaded_data = interface.load() + + # Cache the loaded data + self.data[key] = loaded_data + + return loaded_data + + def keys(self): + """Return available data keys.""" + return list(self.data.keys()) + + def items(self): + """Return key-value pairs (loads data on access).""" + return [(k, self[k]) for k in self.keys()] + + def values(self): + """Return all values (loads all data).""" + return [self[k] for k in self.keys()] + + def get_time_support(self) -> nap.IntervalSet: + """Get the time support from the first interface. + + Returns + ------- + IntervalSet + Time support covering all segments + """ + if self._interfaces: + return list(self._interfaces.values())[0].time_support + return nap.IntervalSet(start=0, end=0) + + def close(self): + """Close the underlying Neo reader if it supports closing.""" + if hasattr(self._reader, "close"): + self._reader.close() + + +# ============================================================================= +# Legacy Interface (for backward compatibility) +# ============================================================================= + + +class NEOSignalInterface(NeoSignalInterface): + """Legacy alias for NeoSignalInterface.""" + pass + class NEOExperimentInterface: + """Legacy interface for Neo experiments. + + .. deprecated:: + Use :class:`NeoReader` instead. + """ + def __init__(self, reader, lazy=False): - # block, aka experiments (contains multiple segments, aka trials) + warnings.warn( + "NEOExperimentInterface is deprecated. 
Use NeoReader instead.", + DeprecationWarning, + stacklevel=2, + ) self._reader = reader self._lazy = lazy self.experiment = self._collect_time_series_info() - self._reader = reader def _collect_time_series_info(self): blocks = self._reader.read(lazy=self._lazy) @@ -308,35 +1141,41 @@ def _collect_time_series_info(self): if block.name: name += ": " + block.name experiments[name] = {} - # loop once to get the time support + starts, ends = np.empty(len(block.segments)), np.empty(len(block.segments)) for trial_num, segment in enumerate(block.segments): starts[trial_num] = segment.t_start.rescale("s").magnitude ends[trial_num] = segment.t_stop.rescale("s").magnitude iset = nap.IntervalSet(starts, ends) - for trial_num, segment in enumerate(block.segments): - # segment may contain epoch (potentially overlapping) - # with fields: times, durations, labels. We may add them to metadata. - # tsd/tsdFrame/TsdTensor + for trial_num, segment in enumerate(block.segments): + # Analog signals for signal_num, signal in enumerate(segment.analogsignals): if signal.name: signame = f" {signal_num}: " + signal.name else: signame = f" {signal_num}" - signal_interface = NEOSignalInterface(signal, block, iset, sig_num=signal_num) + signal_interface = NeoSignalInterface( + signal, block, iset, sig_num=signal_num + ) signame = signal_interface.nap_type.__name__ + signame experiments[name][signame] = signal_interface + # Spike trains if len(segment.spiketrains) == 1: signal = segment.spiketrains[0] - signal_interface = NEOSignalInterface(signal, block, iset, sig_num=0) + signal_interface = NeoSignalInterface( + signal, block, iset, sig_num=0 + ) signame = f"Ts" + ": " + signal.name if signal.name else "Ts" experiments[name][signame] = signal_interface else: signame = f"TsGroup" - experiments[name][signame] = NEOSignalInterface(segment.spiketrains, block, iset) + experiments[name][signame] = NeoSignalInterface( + segment.spiketrains, block, iset + ) + return experiments def __getitem__(self, item): @@ -352,9 +1191,63 @@ def keys(self): return [(k, k2) for k in self.experiment.keys() for k2 in self.experiment[k]] -def load_experiment(path: str | pathlib.Path, lazy: bool = True) -> NEOExperimentInterface: +def load_file(path: Union[str, Path], lazy: bool = True) -> NeoReader: + """Load a neural recording file using Neo. + + This function automatically detects the file format and uses the + appropriate Neo IO to load the data. + + Parameters + ---------- + path : str or Path + Path to the recording file + lazy : bool, default True + Whether to use lazy loading (recommended for large files) + + Returns + ------- + NeoReader + Interface to the loaded data + + Examples + -------- + >>> import pynapple as nap + >>> data = nap.io.neo.load_file("recording.plx") + >>> print(data) + recording + +---------------------+----------+ + | Key | Type | + +=====================+==========+ + | TsGroup | TsGroup | + | Tsd 0: LFP | Tsd | + +---------------------+----------+ + + >>> spikes = data["TsGroup"] + + See Also + -------- + NeoReader : Class for Neo file interface + + Notes + ----- + Supported formats depend on your Neo installation. Common formats include: + - Plexon (.plx, .pl2) + - Blackrock (.nev, .ns*) + - Spike2 (.smr) + - Neuralynx (.ncs, .nse, .ntt) + - OpenEphys + - Intan (.rhd, .rhs) + - And many more (see Neo documentation) """ - Load a neural recording experiment. 
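
Putting the reader interface together, a hedged end-to-end sketch (the path is hypothetical, and the "TsGroup" key assumes the file holds more than one spike train, following the key scheme above):

import pynapple as nap

data = nap.io.neo.load_file("recording.plx", lazy=True)
print(data)                      # tabulated listing of keys and pynapple types
ep = data.get_time_support()     # IntervalSet covering all Neo segments
spikes = data["TsGroup"]         # converted on first access, cached afterwards
spikes_ep = spikes.restrict(ep)  # regular pynapple operations from here on
data.close()                     # release the underlying Neo reader
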
+ return NeoReader(path, lazy=lazy) + + +# Legacy alias +def load_experiment(path: Union[str, Path], lazy: bool = True) -> NEOExperimentInterface: + """Load a neural recording experiment. + + .. deprecated:: + Use :func:`load_file` instead. Parameters ---------- @@ -367,7 +1260,192 @@ def load_experiment(path: str | pathlib.Path, lazy: bool = True) -> NEOExperimen ------- NEOExperimentInterface """ + import pathlib + path = pathlib.Path(path) reader = neo.io.get_io(path) - return NEOExperimentInterface(reader, lazy=lazy) \ No newline at end of file + return NEOExperimentInterface(reader, lazy=lazy) + + +# +# # ============================================================================= +# # Conversion functions: Pynapple -> Neo +# # ============================================================================= +# +# +# def to_neo_analogsignal( +# tsd: Union[nap.Tsd, nap.TsdFrame, nap.TsdTensor], +# units: str = "dimensionless", +# **kwargs, +# ) -> "neo.AnalogSignal": +# """Convert a pynapple Tsd/TsdFrame/TsdTensor to a Neo AnalogSignal. +# +# Parameters +# ---------- +# tsd : Tsd, TsdFrame, or TsdTensor +# Pynapple time series object +# units : str, default "dimensionless" +# Units for the signal (e.g., "mV", "uV") +# **kwargs +# Additional arguments passed to neo.AnalogSignal +# +# Returns +# ------- +# neo.AnalogSignal +# Neo analog signal object +# """ +# _check_neo_installed() +# import quantities as pq +# +# times = tsd.times() +# data = tsd.values +# +# # Ensure 2D for AnalogSignal +# if data.ndim == 1: +# data = data.reshape(-1, 1) +# +# # Calculate sampling rate from timestamps +# if len(times) > 1: +# dt = np.median(np.diff(times)) +# sampling_rate = 1.0 / dt +# else: +# sampling_rate = 1.0 # Default if only one sample +# +# signal = neo.AnalogSignal( +# data, +# units=units, +# sampling_rate=sampling_rate * pq.Hz, +# t_start=times[0] * pq.s, +# **kwargs, +# ) +# +# return signal +# +# +# def to_neo_spiketrain( +# ts: nap.Ts, +# t_stop: Optional[float] = None, +# units: str = "s", +# **kwargs, +# ) -> "neo.SpikeTrain": +# """Convert a pynapple Ts to a Neo SpikeTrain. +# +# Parameters +# ---------- +# ts : Ts +# Pynapple Ts object +# t_stop : float, optional +# Stop time for the spike train. If None, uses the end of time_support +# units : str, default "s" +# Time units +# **kwargs +# Additional arguments passed to neo.SpikeTrain +# +# Returns +# ------- +# neo.SpikeTrain +# Neo spike train object +# """ +# _check_neo_installed() +# import quantities as pq +# +# times = ts.times() +# +# if t_stop is None: +# t_stop = float(ts.time_support.end[-1]) +# +# t_start = float(ts.time_support.start[0]) if len(times) == 0 else min(times[0], float(ts.time_support.start[0])) +# +# spiketrain = neo.SpikeTrain( +# times, +# units=units, +# t_start=t_start * pq.s, +# t_stop=t_stop * pq.s, +# **kwargs, +# ) +# +# return spiketrain +# +# +# def to_neo_epoch( +# iset: nap.IntervalSet, +# labels: Optional[np.ndarray] = None, +# **kwargs, +# ) -> "neo.Epoch": +# """Convert a pynapple IntervalSet to a Neo Epoch. +# +# Parameters +# ---------- +# iset : IntervalSet +# Pynapple IntervalSet +# labels : array-like, optional +# Labels for each epoch. If None, uses integers. 
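
The commented-out converters above sketch the reverse direction (pynapple to Neo). A standalone version of the Tsd-to-AnalogSignal idea, with the sampling rate recovered from the timestamps as in the commented draft; the data here are synthetic:

import numpy as np
import quantities as pq
import neo
import pynapple as nap

tsd = nap.Tsd(t=np.arange(1000) / 1000.0, d=np.random.randn(1000))
d = tsd.values.reshape(-1, 1)               # AnalogSignal expects 2D data
fs = 1.0 / np.median(np.diff(tsd.times()))  # sampling rate from timestamps
sig = neo.AnalogSignal(
    d,
    units="mV",
    sampling_rate=fs * pq.Hz,
    t_start=tsd.times()[0] * pq.s,
)
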
+# **kwargs +# Additional arguments passed to neo.Epoch +# +# Returns +# ------- +# neo.Epoch +# Neo epoch object +# """ +# _check_neo_installed() +# import quantities as pq +# +# starts = iset.start +# ends = iset.end +# durations = ends - starts +# +# if labels is None: +# # Check if there's a 'label' column in metadata +# if hasattr(iset, "label"): +# labels = iset.label +# else: +# labels = np.arange(len(starts)).astype(str) +# +# epoch = neo.Epoch( +# times=starts * pq.s, +# durations=durations * pq.s, +# labels=labels, +# **kwargs, +# ) +# +# return epoch +# +# +# def to_neo_event( +# ts: nap.Ts, +# labels: Optional[np.ndarray] = None, +# **kwargs, +# ) -> "neo.Event": +# """Convert a pynapple Ts to a Neo Event. +# +# Parameters +# ---------- +# ts : Ts +# Pynapple Ts object +# labels : array-like, optional +# Labels for each event. If None, uses integers. +# **kwargs +# Additional arguments passed to neo.Event +# +# Returns +# ------- +# neo.Event +# Neo event object +# """ +# _check_neo_installed() +# import quantities as pq +# +# times = ts.times() +# +# if labels is None: +# labels = np.arange(len(times)).astype(str) +# +# event = neo.Event( +# times=times * pq.s, +# labels=labels, +# **kwargs, +# ) +# +# return event diff --git a/tests/test_neo.py b/tests/test_neo.py new file mode 100644 index 000000000..685b0ddee --- /dev/null +++ b/tests/test_neo.py @@ -0,0 +1,563 @@ +# -*- coding: utf-8 -*- +"""Tests of Neo interface for `pynapple` package.""" + +import warnings + +import numpy as np +import pytest + +import pynapple as nap + +# Check if neo is installed +try: + import neo + import quantities as pq + + HAS_NEO = True +except ImportError: + HAS_NEO = False + +pytestmark = pytest.mark.skipif(not HAS_NEO, reason="Neo is not installed") + + +# ============================================================================= +# Helper functions to create mock Neo objects +# ============================================================================= + + +def create_mock_analog_signal(n_samples=100, n_channels=3, sampling_rate=1000.0, t_start=0.0): + """Create a mock Neo AnalogSignal.""" + data = np.random.randn(n_samples, n_channels) + signal = neo.AnalogSignal( + data, + units="mV", + sampling_rate=sampling_rate * pq.Hz, + t_start=t_start * pq.s, + ) + return signal + + +def create_mock_spiketrain(n_spikes=50, t_start=0.0, t_stop=10.0): + """Create a mock Neo SpikeTrain.""" + spike_times = np.sort(np.random.uniform(t_start, t_stop, n_spikes)) + spiketrain = neo.SpikeTrain( + spike_times, + units="s", + t_start=t_start * pq.s, + t_stop=t_stop * pq.s, + ) + return spiketrain + + +def create_mock_epoch(n_epochs=5, t_start=0.0, max_duration=2.0): + """Create a mock Neo Epoch.""" + times = np.sort(np.random.uniform(t_start, 10.0, n_epochs)) + durations = np.random.uniform(0.1, max_duration, n_epochs) + labels = np.array([f"epoch_{i}" for i in range(n_epochs)]) + epoch = neo.Epoch( + times=times * pq.s, + durations=durations * pq.s, + labels=labels, + ) + return epoch + + +def create_mock_event(n_events=10, t_start=0.0, t_stop=10.0): + """Create a mock Neo Event.""" + times = np.sort(np.random.uniform(t_start, t_stop, n_events)) + labels = np.array([f"event_{i}" for i in range(n_events)]) + event = neo.Event( + times=times * pq.s, + labels=labels, + ) + return event + + +def create_mock_irregular_signal(n_samples=50, n_channels=2, t_start=0.0, t_stop=10.0): + """Create a mock Neo IrregularlySampledSignal.""" + times = np.sort(np.random.uniform(t_start, t_stop, n_samples)) + data = 
np.random.randn(n_samples, n_channels) + signal = neo.IrregularlySampledSignal( + times=times * pq.s, + signal=data, + units="mV", + time_units="s", + ) + return signal + + +def create_mock_block_with_segments(n_segments=2, n_analog=2, n_spiketrains=3): + """Create a mock Neo Block with multiple segments.""" + block = neo.Block(name="test_block") + + for seg_idx in range(n_segments): + seg = neo.Segment(name=f"segment_{seg_idx}") + + # Add analog signals + for sig_idx in range(n_analog): + signal = create_mock_analog_signal( + n_samples=100, + n_channels=3, + sampling_rate=1000.0, + t_start=seg_idx * 10.0, + ) + signal.name = f"analog_{sig_idx}" + seg.analogsignals.append(signal) + + # Add spike trains + for st_idx in range(n_spiketrains): + spiketrain = create_mock_spiketrain( + n_spikes=20, + t_start=seg_idx * 10.0, + t_stop=(seg_idx + 1) * 10.0, + ) + spiketrain.name = f"unit_{st_idx}" + seg.spiketrains.append(spiketrain) + + # Add an epoch + epoch = create_mock_epoch(n_epochs=3, t_start=seg_idx * 10.0) + epoch.name = "behavioral_states" + seg.epochs.append(epoch) + + # Add an event + event = create_mock_event( + n_events=5, + t_start=seg_idx * 10.0, + t_stop=(seg_idx + 1) * 10.0, + ) + event.name = "stimuli" + seg.events.append(event) + + # Set segment times + seg.t_start = seg_idx * 10.0 * pq.s + seg.t_stop = (seg_idx + 1) * 10.0 * pq.s + + block.segments.append(seg) + + return block + + +# ============================================================================= +# Tests for conversion functions: Neo -> Pynapple +# ============================================================================= + + +class TestNeoToPynapple: + """Test Neo to Pynapple conversion functions.""" + + def test_analog_to_tsd(self): + """Test AnalogSignal to Tsd conversion.""" + from pynapple.io.neo import _make_tsd_from_analog + + # Single channel -> Tsd + signal = create_mock_analog_signal(n_samples=100, n_channels=1) + tsd = _make_tsd_from_analog(signal) + assert isinstance(tsd, nap.Tsd) + assert len(tsd) == 100 + + def test_analog_to_tsdframe(self): + """Test AnalogSignal to TsdFrame conversion.""" + from pynapple.io.neo import _make_tsd_from_analog + + # Multi-channel -> TsdFrame + signal = create_mock_analog_signal(n_samples=100, n_channels=3) + tsdframe = _make_tsd_from_analog(signal) + assert isinstance(tsdframe, nap.TsdFrame) + assert len(tsdframe) == 100 + assert tsdframe.shape[1] == 3 + + def test_spiketrain_to_ts(self): + """Test SpikeTrain to Ts conversion.""" + from pynapple.io.neo import _make_ts_from_spiketrain + + spiketrain = create_mock_spiketrain(n_spikes=50) + ts = _make_ts_from_spiketrain(spiketrain) + assert isinstance(ts, nap.Ts) + assert len(ts) == 50 + + def test_spiketrains_to_tsgroup(self): + """Test multiple SpikeTrains to TsGroup conversion.""" + from pynapple.io.neo import _make_tsgroup_from_spiketrains + + spiketrains = [create_mock_spiketrain(n_spikes=30) for _ in range(5)] + tsgroup = _make_tsgroup_from_spiketrains(spiketrains) + assert isinstance(tsgroup, nap.TsGroup) + assert len(tsgroup) == 5 + + def test_epoch_to_intervalset(self): + """Test Epoch to IntervalSet conversion.""" + from pynapple.io.neo import _make_intervalset_from_epoch + + epoch = create_mock_epoch(n_epochs=5) + iset = _make_intervalset_from_epoch(epoch) + assert isinstance(iset, nap.IntervalSet) + assert len(iset) == 5 + + def test_event_to_ts(self): + """Test Event to Ts conversion.""" + from pynapple.io.neo import _make_ts_from_event + + event = create_mock_event(n_events=10) + ts = 
_make_ts_from_event(event) + assert isinstance(ts, nap.Ts) + assert len(ts) == 10 + + def test_irregular_signal_to_tsd(self): + """Test IrregularlySampledSignal to Tsd conversion.""" + from pynapple.io.neo import _make_tsd_from_irregular + + signal = create_mock_irregular_signal(n_samples=50, n_channels=1) + tsd = _make_tsd_from_irregular(signal) + assert isinstance(tsd, nap.Tsd) + assert len(tsd) == 50 + + def test_irregular_signal_to_tsdframe(self): + """Test IrregularlySampledSignal to TsdFrame conversion.""" + from pynapple.io.neo import _make_tsd_from_irregular + + signal = create_mock_irregular_signal(n_samples=50, n_channels=3) + tsdframe = _make_tsd_from_irregular(signal) + assert isinstance(tsdframe, nap.TsdFrame) + assert len(tsdframe) == 50 + assert tsdframe.shape[1] == 3 + + +# ============================================================================= +# Tests for conversion functions: Pynapple -> Neo +# ============================================================================= + + +class TestPynappleToNeo: + """Test Pynapple to Neo conversion functions.""" + + def test_tsd_to_analog(self): + """Test Tsd to AnalogSignal conversion.""" + from pynapple.io.neo import to_neo_analogsignal + + tsd = nap.Tsd(t=np.arange(100) / 1000.0, d=np.random.randn(100)) + signal = to_neo_analogsignal(tsd, units="mV") + + assert isinstance(signal, neo.AnalogSignal) + assert signal.shape[0] == 100 + assert signal.shape[1] == 1 # Tsd is converted to 2D + + def test_tsdframe_to_analog(self): + """Test TsdFrame to AnalogSignal conversion.""" + from pynapple.io.neo import to_neo_analogsignal + + tsdframe = nap.TsdFrame( + t=np.arange(100) / 1000.0, d=np.random.randn(100, 3) + ) + signal = to_neo_analogsignal(tsdframe, units="uV") + + assert isinstance(signal, neo.AnalogSignal) + assert signal.shape == (100, 3) + + def test_ts_to_spiketrain(self): + """Test Ts to SpikeTrain conversion.""" + from pynapple.io.neo import to_neo_spiketrain + + ts = nap.Ts(t=np.sort(np.random.uniform(0, 10, 50))) + spiketrain = to_neo_spiketrain(ts, t_stop=10.0) + + assert isinstance(spiketrain, neo.SpikeTrain) + assert len(spiketrain) == 50 + + def test_intervalset_to_epoch(self): + """Test IntervalSet to Epoch conversion.""" + from pynapple.io.neo import to_neo_epoch + + iset = nap.IntervalSet(start=[0, 5, 10], end=[2, 7, 12]) + epoch = to_neo_epoch(iset) + + assert isinstance(epoch, neo.Epoch) + assert len(epoch) == 3 + np.testing.assert_array_almost_equal( + epoch.times.rescale("s").magnitude, np.array([0, 5, 10]) + ) + np.testing.assert_array_almost_equal( + epoch.durations.rescale("s").magnitude, np.array([2, 2, 2]) + ) + + def test_ts_to_event(self): + """Test Ts to Event conversion.""" + from pynapple.io.neo import to_neo_event + + ts = nap.Ts(t=np.array([1.0, 2.5, 5.0, 7.5])) + event = to_neo_event(ts) + + assert isinstance(event, neo.Event) + assert len(event) == 4 + np.testing.assert_array_almost_equal( + event.times.rescale("s").magnitude, np.array([1.0, 2.5, 5.0, 7.5]) + ) + + +# ============================================================================= +# Tests for NeoSignalInterface +# ============================================================================= + + +class TestNeoSignalInterface: + """Test NeoSignalInterface class.""" + + def test_analog_interface_init(self): + """Test initialization with AnalogSignal.""" + from pynapple.io.neo import NeoSignalInterface + + block = create_mock_block_with_segments(n_segments=1, n_analog=1, n_spiketrains=0) + signal = block.segments[0].analogsignals[0] 
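+        # The mock block has a single segment; sig_num=0 selects its first analog signal.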
+ time_support = nap.IntervalSet(start=0, end=10) + + interface = NeoSignalInterface(signal, block, time_support, sig_num=0) + + assert interface.is_analog is True + assert interface.nap_type == nap.TsdFrame + assert interface._signal_type == "analog" + + def test_spiketrain_interface_init(self): + """Test initialization with SpikeTrain.""" + from pynapple.io.neo import NeoSignalInterface + + block = create_mock_block_with_segments(n_segments=1, n_analog=0, n_spiketrains=1) + spiketrain = block.segments[0].spiketrains[0] + time_support = nap.IntervalSet(start=0, end=10) + + interface = NeoSignalInterface(spiketrain, block, time_support, sig_num=0) + + assert interface.is_analog is False + assert interface.nap_type == nap.Ts + assert interface._signal_type == "spiketrain" + + def test_tsgroup_interface_init(self): + """Test initialization with multiple SpikeTrains.""" + from pynapple.io.neo import NeoSignalInterface + + block = create_mock_block_with_segments(n_segments=1, n_analog=0, n_spiketrains=3) + spiketrains = block.segments[0].spiketrains + time_support = nap.IntervalSet(start=0, end=10) + + interface = NeoSignalInterface(spiketrains, block, time_support) + + assert interface.is_analog is False + assert interface.nap_type == nap.TsGroup + assert interface._signal_type == "tsgroup" + + def test_interface_load(self): + """Test loading data through interface.""" + from pynapple.io.neo import NeoSignalInterface + + block = create_mock_block_with_segments(n_segments=1, n_analog=1, n_spiketrains=0) + signal = block.segments[0].analogsignals[0] + time_support = nap.IntervalSet(start=0, end=0.1) + + interface = NeoSignalInterface(signal, block, time_support, sig_num=0) + loaded = interface.load() + + assert isinstance(loaded, nap.TsdFrame) + + def test_interface_get_time_range(self): + """Test getting data for a time range.""" + from pynapple.io.neo import NeoSignalInterface + + block = create_mock_block_with_segments(n_segments=1, n_analog=1, n_spiketrains=0) + signal = block.segments[0].analogsignals[0] + time_support = nap.IntervalSet(start=0, end=0.1) + + interface = NeoSignalInterface(signal, block, time_support, sig_num=0) + data = interface.get(0.0, 0.05) + + assert isinstance(data, nap.TsdFrame) + + def test_interface_restrict(self): + """Test restricting data to epochs.""" + from pynapple.io.neo import NeoSignalInterface + + block = create_mock_block_with_segments(n_segments=1, n_analog=1, n_spiketrains=0) + signal = block.segments[0].analogsignals[0] + time_support = nap.IntervalSet(start=0, end=0.1) + + interface = NeoSignalInterface(signal, block, time_support, sig_num=0) + epoch = nap.IntervalSet(start=[0.01], end=[0.03]) + data = interface.restrict(epoch) + + assert isinstance(data, nap.TsdFrame) + + +# ============================================================================= +# Tests for legacy interface +# ============================================================================= + + +class TestLegacyInterface: + """Test legacy NEOExperimentInterface for backward compatibility.""" + + def test_legacy_deprecation_warning(self): + """Test that legacy interface raises deprecation warning.""" + from pynapple.io.neo import NEOExperimentInterface + + # Create a simple mock reader + class MockReader: + def read(self, lazy=False): + return [create_mock_block_with_segments(n_segments=1)] + + with pytest.warns(DeprecationWarning, match="NEOExperimentInterface is deprecated"): + NEOExperimentInterface(MockReader(), lazy=False) + + +# 
============================================================================= +# Tests for helper functions +# ============================================================================= + + +class TestHelperFunctions: + """Test helper functions.""" + + def test_rescale_to_seconds(self): + """Test rescaling quantities to seconds.""" + from pynapple.io.neo import _rescale_to_seconds + + # Test milliseconds + value_ms = 1000 * pq.ms + assert _rescale_to_seconds(value_ms) == 1.0 + + # Test seconds + value_s = 5.0 * pq.s + assert _rescale_to_seconds(value_s) == 5.0 + + def test_get_signal_type(self): + """Test signal type detection.""" + from pynapple.io.neo import _get_signal_type + + # 1D signal -> Tsd + signal_1d = neo.AnalogSignal( + np.random.randn(100, 1), units="mV", sampling_rate=1000 * pq.Hz + ) + assert _get_signal_type(signal_1d) == nap.Tsd + + # 2D signal -> TsdFrame + signal_2d = neo.AnalogSignal( + np.random.randn(100, 3), units="mV", sampling_rate=1000 * pq.Hz + ) + assert _get_signal_type(signal_2d) == nap.TsdFrame + + def test_extract_annotations(self): + """Test annotation extraction.""" + from pynapple.io.neo import _extract_annotations + + signal = neo.AnalogSignal( + np.random.randn(100, 1), + units="mV", + sampling_rate=1000 * pq.Hz, + name="test_signal", + description="A test signal", + custom_annotation="custom_value", + ) + + annotations = _extract_annotations(signal) + + assert "neo_name" in annotations + assert annotations["neo_name"] == "test_signal" + assert "neo_description" in annotations + assert annotations["neo_description"] == "A test signal" + assert "custom_annotation" in annotations + assert annotations["custom_annotation"] == "custom_value" + + +# ============================================================================= +# Tests for round-trip conversion +# ============================================================================= + + +class TestRoundTrip: + """Test round-trip conversion (pynapple -> neo -> pynapple).""" + + def test_tsd_roundtrip(self): + """Test Tsd round-trip conversion.""" + from pynapple.io.neo import to_neo_analogsignal, _make_tsd_from_analog + + original = nap.Tsd(t=np.arange(100) / 1000.0, d=np.random.randn(100)) + + # Convert to Neo + neo_signal = to_neo_analogsignal(original) + + # Convert back to pynapple + recovered = _make_tsd_from_analog(neo_signal) + + assert isinstance(recovered, nap.Tsd) + np.testing.assert_array_almost_equal(original.values, recovered.values.flatten()) + + def test_intervalset_roundtrip(self): + """Test IntervalSet round-trip conversion.""" + from pynapple.io.neo import to_neo_epoch, _make_intervalset_from_epoch + + original = nap.IntervalSet(start=[1.0, 5.0, 10.0], end=[2.0, 7.0, 12.0]) + + # Convert to Neo + neo_epoch = to_neo_epoch(original) + + # Convert back to pynapple + recovered = _make_intervalset_from_epoch(neo_epoch) + + assert isinstance(recovered, nap.IntervalSet) + np.testing.assert_array_almost_equal(original.start, recovered.start) + np.testing.assert_array_almost_equal(original.end, recovered.end) + + def test_ts_roundtrip(self): + """Test Ts round-trip conversion via SpikeTrain.""" + from pynapple.io.neo import to_neo_spiketrain, _make_ts_from_spiketrain + + spike_times = np.sort(np.random.uniform(0, 10, 50)) + original = nap.Ts(t=spike_times) + + # Convert to Neo + neo_spiketrain = to_neo_spiketrain(original, t_stop=10.0) + + # Convert back to pynapple + recovered = _make_ts_from_spiketrain(neo_spiketrain) + + assert isinstance(recovered, nap.Ts) + 
np.testing.assert_array_almost_equal(original.times(), recovered.times()) + + +# ============================================================================= +# Integration tests +# ============================================================================= + + +class TestIntegration: + """Integration tests for the full workflow.""" + + def test_block_with_all_data_types(self): + """Test processing a block with all data types.""" + block = create_mock_block_with_segments(n_segments=2, n_analog=2, n_spiketrains=3) + + # Verify block structure + assert len(block.segments) == 2 + + for seg in block.segments: + assert len(seg.analogsignals) == 2 + assert len(seg.spiketrains) == 3 + assert len(seg.epochs) == 1 + assert len(seg.events) == 1 + + def test_multi_segment_time_support(self): + """Test that time support correctly spans multiple segments.""" + from pynapple.io.neo import NeoSignalInterface + + block = create_mock_block_with_segments(n_segments=3, n_analog=1, n_spiketrains=0) + + # Build time support from segments + starts = np.array([seg.t_start.rescale("s").magnitude for seg in block.segments]) + ends = np.array([seg.t_stop.rescale("s").magnitude for seg in block.segments]) + time_support = nap.IntervalSet(starts, ends) + + assert len(time_support) == 3 + + interface = NeoSignalInterface( + block.segments[0].analogsignals[0], block, time_support, sig_num=0 + ) + + # Load all data + loaded = interface.load() + assert isinstance(loaded, nap.TsdFrame) \ No newline at end of file From f30807453cef310a6004ab74b7c007c36fef49e7 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Sat, 10 Jan 2026 17:57:36 -0500 Subject: [PATCH 3/7] Update --- doc/api.rst | 12 + pynapple/io/__init__.py | 2 +- pynapple/io/neo.py | 1313 ++++++++++++++++++++++++--------------- 3 files changed, 829 insertions(+), 498 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index 3f3b46ce8..5d7af0b0b 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -70,6 +70,18 @@ Input-Ouput NWBFile +.. currentmodule:: pynapple.io + +.. rubric:: Neo Reader / LFP + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :recursive: + + neo + + .. currentmodule:: pynapple.io .. rubric:: Numpy files diff --git a/pynapple/io/__init__.py b/pynapple/io/__init__.py index f7991d9ca..73bcd7c98 100644 --- a/pynapple/io/__init__.py +++ b/pynapple/io/__init__.py @@ -10,7 +10,7 @@ ) from .neo import ( NeoReader, - load_file as load_neo_file, + # load_file as load_neo_file, # to_neo_analogsignal, # to_neo_spiketrain, # to_neo_epoch, diff --git a/pynapple/io/neo.py b/pynapple/io/neo.py index b6f52063d..6bef4de23 100644 --- a/pynapple/io/neo.py +++ b/pynapple/io/neo.py @@ -4,7 +4,7 @@ Neo is a Python package for working with electrophysiology data in Python, supporting many file formats through a unified API. -Data are lazy-loaded by default. The interface behaves like a dictionary. +The interface behaves like a dictionary. 
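As the updated docstring above notes, `NeoReader` behaves like a dictionary whose values are converted to pynapple objects only when a key is first accessed, and cached afterwards. A minimal sketch of that access pattern, assuming a Plexon recording and the key names `NeoReader` derives from its contents:

```python
import pynapple as nap

# Hypothetical file name; any format supported by Neo should work.
data = nap.NeoReader("recording.plx")   # scans the file, converts nothing yet
print(data)                             # lists the available keys and their types

spikes = data["TsGroup"]                # conversion to a TsGroup happens on first access...
spikes_again = data["TsGroup"]          # ...and the converted object is cached afterwards
```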
For more information on Neo, see: https://neo.readthedocs.io/ @@ -12,13 +12,16 @@ --------------------------------- The following Neo objects are converted to their pynapple equivalents: -- neo.AnalogSignal -> Tsd, TsdFrame, or TsdTensor (depending on shape) -- neo.IrregularlySampledSignal -> Tsd, TsdFrame, or TsdTensor (depending on shape) +- 'neo.AnalogSignal' -> 'Tsd', `TsdFrame`, or `TsdTensor` (depending on shape) [lazy-loaded] +- neo.IrregularlySampledSignal -> Tsd, TsdFrame, or TsdTensor (depending on shape) [lazy-loaded] - neo.SpikeTrain -> Ts - neo.SpikeTrain (list) -> TsGroup - neo.SpikeTrainList -> TsGroup - neo.Epoch -> IntervalSet - neo.Event -> Ts + +Note: All data types support lazy loading. Data is only loaded when accessed +via __getitem__ (e.g., data["TsGroup"]). """ import warnings @@ -93,8 +96,6 @@ def _get_signal_type(signal) -> type: if len(signal.shape) == 1: return nap.Tsd elif len(signal.shape) == 2: - if signal.shape[1] == 1: - return nap.Tsd return nap.TsdFrame else: return nap.TsdTensor @@ -183,111 +184,225 @@ def _make_intervalset_from_epoch(epoch, time_support: Optional[nap.IntervalSet] return iset -def _make_ts_from_event(event, time_support: Optional[nap.IntervalSet] = None) -> nap.Ts: - """Convert a Neo Event to a pynapple Ts. +def _make_intervalset_from_epoch_multiseg( + block, ep_idx: int, time_support: Optional[nap.IntervalSet] = None +) -> nap.IntervalSet: + """Convert Neo Epochs from multiple segments to a pynapple IntervalSet. Parameters ---------- - event : neo.Event or neo.io.proxyobjects.EventProxy - Neo Event object + block : neo.Block + The Neo block containing the segments + ep_idx : int + Index of the epoch in each segment time_support : IntervalSet, optional - Time support for the Ts + Time support for the IntervalSet Returns ------- - Ts - Pynapple Ts object + IntervalSet + Pynapple IntervalSet """ - if hasattr(event, "load"): - event = event.load() + all_starts = [] + all_ends = [] + all_labels = [] - times = event.times.rescale("s").magnitude + for seg in block.segments: + if ep_idx >= len(seg.epochs): + continue - return nap.Ts(t=times, time_support=time_support) + epoch = seg.epochs[ep_idx] + if hasattr(epoch, "load"): + epoch = epoch.load() + times = epoch.times.rescale("s").magnitude + durations = epoch.durations.rescale("s").magnitude -def _make_tsd_from_analog( - signal, - time_support: Optional[nap.IntervalSet] = None, - column_names: Optional[List[str]] = None, -) -> Union[nap.Tsd, nap.TsdFrame, nap.TsdTensor]: - """Convert a Neo AnalogSignal to a pynapple Tsd/TsdFrame/TsdTensor. + all_starts.extend(times) + all_ends.extend(times + durations) + + if hasattr(epoch, "labels") and len(epoch.labels) > 0: + all_labels.extend(epoch.labels) + + if len(all_starts) == 0: + return nap.IntervalSet(start=[], end=[]) + + metadata = {} + if all_labels: + metadata["label"] = np.array(all_labels) + + return nap.IntervalSet( + start=np.array(all_starts), + end=np.array(all_ends), + metadata=metadata if metadata else None, + ) + + +def _make_ts_from_event_multiseg( + block, ev_idx: int, time_support: Optional[nap.IntervalSet] = None +) -> nap.Ts: + """Convert Neo Events from multiple segments to a pynapple Ts. 
Parameters ---------- - signal : neo.AnalogSignal or AnalogSignalProxy - Neo analog signal + block : neo.Block + The Neo block containing the segments + ev_idx : int + Index of the event in each segment time_support : IntervalSet, optional Time support - column_names : list of str, optional - Column names for TsdFrame Returns ------- - Tsd, TsdFrame, or TsdTensor - Appropriate pynapple time series object + Ts + Pynapple Ts object """ - if hasattr(signal, "load"): - signal = signal.load() + all_times = [] - times = signal.times.rescale("s").magnitude - data = signal.magnitude + for seg in block.segments: + if ev_idx >= len(seg.events): + continue - nap_type = _get_signal_type(signal) + event = seg.events[ev_idx] + if hasattr(event, "load"): + event = event.load() - if nap_type == nap.Tsd: - if len(data.shape) == 2: - data = data.squeeze() - return nap.Tsd(t=times, d=data, time_support=time_support) - elif nap_type == nap.TsdFrame: - if column_names is None: - # Try to get channel names from annotations - if hasattr(signal, "array_annotations"): - channel_names = signal.array_annotations.get("channel_names", None) - if channel_names is not None: - column_names = list(channel_names) - return nap.TsdFrame(t=times, d=data, columns=column_names, time_support=time_support) - else: - return nap.TsdTensor(t=times, d=data, time_support=time_support) + times = event.times.rescale("s").magnitude + all_times.extend(times) + return nap.Ts(t=np.array(all_times), time_support=time_support) -def _make_tsd_from_irregular( - signal, - time_support: Optional[nap.IntervalSet] = None, - column_names: Optional[List[str]] = None, -) -> Union[nap.Tsd, nap.TsdFrame, nap.TsdTensor]: - """Convert a Neo IrregularlySampledSignal to a pynapple Tsd/TsdFrame/TsdTensor. + +def _make_ts_from_event(event, time_support: Optional[nap.IntervalSet] = None) -> nap.Ts: + """Convert a Neo Event to a pynapple Ts. Parameters ---------- - signal : neo.IrregularlySampledSignal - Neo irregularly sampled signal + event : neo.Event or neo.io.proxyobjects.EventProxy + Neo Event object time_support : IntervalSet, optional - Time support - column_names : list of str, optional - Column names for TsdFrame + Time support for the Ts + + Returns + ------- + Ts + Pynapple Ts object + """ + if hasattr(event, "load"): + event = event.load() + + times = event.times.rescale("s").magnitude + + return nap.Ts(t=times, time_support=time_support) + +def _make_tsd_from_interface(interface) -> Union[nap.Tsd, nap.TsdFrame, nap.TsdTensor]: + """Convert a NeoSignalInterface to a pynapple Tsd/TsdFrame/TsdTensor. 
+ + Parameters + ---------- + interface : NeoSignalInterface + The NeoSignalInterface object Returns ------- Tsd, TsdFrame, or TsdTensor Appropriate pynapple time series object """ - if hasattr(signal, "load"): - signal = signal.load() - - times = signal.times.rescale("s").magnitude - data = signal.magnitude + times = interface.times + data = interface - nap_type = _get_signal_type(signal) + nap_type = interface.nap_type if nap_type == nap.Tsd: - if len(data.shape) == 2: - data = data.squeeze() - return nap.Tsd(t=times, d=data, time_support=time_support) + return nap.Tsd(t=times, d=data, time_support=interface.time_support) elif nap_type == nap.TsdFrame: - return nap.TsdFrame(t=times, d=data, columns=column_names, time_support=time_support) + return nap.TsdFrame(t=times, d=data, time_support=interface.time_support, load_array=False) else: - return nap.TsdTensor(t=times, d=data, time_support=time_support) + return nap.TsdTensor(t=times, d=data, time_support=interface.time_support) + +# +# def _make_tsd_from_analog( +# signal, +# time_support: Optional[nap.IntervalSet] = None, +# column_names: Optional[List[str]] = None, +# ) -> Union[nap.Tsd, nap.TsdFrame, nap.TsdTensor]: +# """Convert a Neo AnalogSignal to a pynapple Tsd/TsdFrame/TsdTensor. +# +# Parameters +# ---------- +# signal : neo.AnalogSignal or AnalogSignalProxy +# Neo analog signal +# time_support : IntervalSet, optional +# Time support +# column_names : list of str, optional +# Column names for TsdFrame +# +# Returns +# ------- +# Tsd, TsdFrame, or TsdTensor +# Appropriate pynapple time series object +# """ +# if hasattr(signal, "load"): +# signal = signal.load() +# +# times = signal.times.rescale("s").magnitude +# data = signal.magnitude +# +# nap_type = _get_signal_type(signal) +# +# if nap_type == nap.Tsd: +# if len(data.shape) == 2: +# data = data.squeeze() +# return nap.Tsd(t=times, d=data, time_support=time_support) +# elif nap_type == nap.TsdFrame: +# if column_names is None: +# # Try to get channel names from annotations +# if hasattr(signal, "array_annotations"): +# channel_names = signal.array_annotations.get("channel_names", None) +# if channel_names is not None: +# column_names = list(channel_names) +# return nap.TsdFrame(t=times, d=data, columns=column_names, time_support=time_support) +# else: +# return nap.TsdTensor(t=times, d=data, time_support=time_support) + + +# def _make_tsd_from_irregular( +# signal, +# time_support: Optional[nap.IntervalSet] = None, +# column_names: Optional[List[str]] = None, +# ) -> Union[nap.Tsd, nap.TsdFrame, nap.TsdTensor]: +# """Convert a Neo IrregularlySampledSignal to a pynapple Tsd/TsdFrame/TsdTensor. 
+# +# Parameters +# ---------- +# signal : neo.IrregularlySampledSignal +# Neo irregularly sampled signal +# time_support : IntervalSet, optional +# Time support +# column_names : list of str, optional +# Column names for TsdFrame +# +# Returns +# ------- +# Tsd, TsdFrame, or TsdTensor +# Appropriate pynapple time series object +# """ +# if hasattr(signal, "load"): +# signal = signal.load() +# +# times = signal.times.rescale("s").magnitude +# data = signal.magnitude +# +# nap_type = _get_signal_type(signal) +# +# if nap_type == nap.Tsd: +# if len(data.shape) == 2: +# data = data.squeeze() +# return nap.Tsd(t=times, d=data, time_support=time_support) +# elif nap_type == nap.TsdFrame: +# return nap.TsdFrame(t=times, d=data, columns=column_names, time_support=time_support) +# else: +# return nap.TsdTensor(t=times, d=data, time_support=time_support) def _make_ts_from_spiketrain( @@ -315,6 +430,39 @@ def _make_ts_from_spiketrain( return nap.Ts(t=times, time_support=time_support) +def _make_ts_from_spiketrain_multiseg( + block, unit_idx: int, time_support: Optional[nap.IntervalSet] = None +) -> nap.Ts: + """Convert a Neo SpikeTrain from multiple segments to a pynapple Ts. + + Parameters + ---------- + block : neo.Block + The Neo block containing the segments + unit_idx : int + Index of the spike train in each segment + time_support : IntervalSet, optional + Time support + + Returns + ------- + Ts + Pynapple Ts object + """ + all_times = [] + + for seg in block.segments: + spiketrain = seg.spiketrains[unit_idx] + if hasattr(spiketrain, "load"): + spiketrain = spiketrain.load() + + times = spiketrain.times.rescale("s").magnitude + all_times.append(times) + + spike_times = np.concatenate(all_times) if all_times else np.array([]) + return nap.Ts(t=spike_times, time_support=time_support) + + def _make_tsgroup_from_spiketrains( spiketrains: Union[list, "SpikeTrainList"], time_support: Optional[nap.IntervalSet] = None, @@ -361,6 +509,66 @@ def _make_tsgroup_from_spiketrains( return nap.TsGroup(ts_dict, time_support=time_support, **meta_arrays) +def _make_tsgroup_from_spiketrains_multiseg( + all_spiketrains: List[list], + time_support: Optional[nap.IntervalSet] = None, +) -> nap.TsGroup: + """Convert spike trains from multiple segments to a pynapple TsGroup. + + This function concatenates spike times across segments for each unit. + + Parameters + ---------- + all_spiketrains : list of lists + List of spike train lists, one per segment. Each inner list contains + the spike trains for that segment. 
+ time_support : IntervalSet, optional + Time support + + Returns + ------- + TsGroup + Pynapple TsGroup + """ + if len(all_spiketrains) == 0: + return nap.TsGroup({}, time_support=time_support) + + n_units = len(all_spiketrains[0]) + ts_dict = {} + metadata = {} + + for unit_idx in range(n_units): + all_times = [] + + for seg_spiketrains in all_spiketrains: + st = seg_spiketrains[unit_idx] + if hasattr(st, "load"): + st = st.load() + + times = st.times.rescale("s").magnitude + all_times.append(times) + + # Collect metadata from first segment only + if seg_spiketrains is all_spiketrains[0]: + for key, value in _extract_annotations(st).items(): + if key not in metadata: + metadata[key] = [] + metadata[key].append(value) + + spike_times = np.concatenate(all_times) if all_times else np.array([]) + ts_dict[unit_idx] = nap.Ts(t=spike_times, time_support=time_support) + + # Convert metadata lists to arrays + meta_arrays = {} + for key, values in metadata.items(): + try: + meta_arrays[key] = np.array(values) + except (ValueError, TypeError): + # Skip metadata that can't be converted to array + pass + + return nap.TsGroup(ts_dict, time_support=time_support, **meta_arrays) + # ============================================================================= # Signal Interface for lazy loading @@ -368,15 +576,21 @@ def _make_tsgroup_from_spiketrains( class NeoSignalInterface: - """Interface for lazy-loading Neo signals into pynapple objects. + """Interface for lazy-loading Neo analog signals into pynapple objects. - This class provides lazy access to Neo signals, loading data only when - requested via `get()` or `restrict()` methods. + This class provides lazy access to Neo analog signals (AnalogSignal, + IrregularlySampledSignal), loading data only when requested. It acts as + a pseudo memory-mapped array that can be passed directly to Tsd, TsdFrame, + or TsdTensor initialization with `load_array=False`. + + The interface is array-like (has shape, dtype, ndim, supports indexing + and iteration) so it can be used as a drop-in replacement for numpy arrays + in pynapple time series constructors. Parameters ---------- signal : neo signal object - A Neo signal (AnalogSignal, SpikeTrain, etc.) + A Neo analog signal (AnalogSignal or IrregularlySampledSignal) block : neo.Block The parent block containing the signal time_support : IntervalSet @@ -387,20 +601,34 @@ class NeoSignalInterface: Attributes ---------- nap_type : type - The pynapple type this signal will be converted to + The pynapple type this signal will be converted to (Tsd, TsdFrame, or TsdTensor) is_analog : bool - Whether this is an analog signal + Whether this is a regularly sampled analog signal dt : float Sampling interval (for analog signals) shape : tuple - Shape of the data - start_time : float or list - Start time(s) - end_time : float or list - End time(s) + Shape of the data (total samples across all segments, channels, ...) 
+ dtype : numpy.dtype + Data type of the signal + ndim : int + Number of dimensions + times : numpy.ndarray + Pre-loaded timestamps for all segments (in seconds) + start_time : float + Start time + end_time : float + End time + + Examples + -------- + >>> interface = NeoSignalInterface(signal, block, time_support, sig_num=0) + >>> # Use as array-like for lazy loading + >>> tsd = nap.Tsd(t=interface.times, d=interface, load_array=False) + >>> # Data is only loaded when accessed + >>> chunk = tsd[0:1000] # Loads only first 1000 samples """ - def __init__(self, signal, block, time_support, sig_num=None): + def __init__(self, signal, block, time_support=None, sig_num=0): self.time_support = time_support self._block = block self._sig_num = sig_num @@ -416,46 +644,256 @@ def __init__(self, signal, block, time_support, sig_num=None): self.is_analog = False # Irregularly sampled self.nap_type = _get_signal_type(signal) self._signal_type = "irregular" - elif isinstance(signal, (neo.SpikeTrain, SpikeTrainProxy)): - self.nap_type = nap.Ts - self.is_analog = False - self._signal_type = "spiketrain" - elif isinstance(signal, (list, SpikeTrainList)): - self.nap_type = nap.TsGroup - self.is_analog = False - self._signal_type = "tsgroup" - elif isinstance(signal, (neo.Epoch,)) or (hasattr(neo.io, "proxyobjects") and isinstance(signal, EpochProxy)): - self.nap_type = nap.IntervalSet - self.is_analog = False - self._signal_type = "epoch" - elif isinstance(signal, (neo.Event,)) or (hasattr(neo.io, "proxyobjects") and isinstance(signal, EventProxy)): - self.nap_type = nap.Ts - self.is_analog = False - self._signal_type = "event" else: raise TypeError(f"Signal type {type(signal)} not recognized.") + # Store dtype from signal + self.dtype = signal.dtype + + # Build segment info and compute total shape across all segments + self._segment_offsets = [] # Cumulative sample counts per segment + self._segment_n_samples = [] # Number of samples per segment + self._times_list = [] # Pre-load timestamps per segment (small memory footprint) + + total_samples = 0 + for seg in block.segments: + if self.is_analog: + seg_signal = seg.analogsignals[sig_num] + else: + seg_signal = seg.irregularlysampledsignals[sig_num] + + n_samples = seg_signal.shape[0] + self._segment_offsets.append(total_samples) + self._segment_n_samples.append(n_samples) + total_samples += n_samples + + # Pre-load timestamps (much smaller than data) + if hasattr(seg_signal, "times"): + self._times_list.append(seg_signal.times.rescale("s").magnitude) + else: + self._times_list.append( + np.linspace( + _rescale_to_seconds(seg_signal.t_start), + _rescale_to_seconds(seg_signal.t_stop), + n_samples, + endpoint=False, + ) + ) + + self._segment_offsets = np.array(self._segment_offsets) + self._segment_n_samples = np.array(self._segment_n_samples) + + # Concatenate all timestamps + if self._times_list: + self._times = np.concatenate(self._times_list) + else: + self._times = np.array([]) + + # Compute total shape (first dimension is total samples) + if len(signal.shape) == 1: + self.shape = (total_samples,) + else: + self.shape = (total_samples,) + signal.shape[1:] + # Store timing info if self.is_analog: self.dt = (1 / signal.sampling_rate).rescale("s").magnitude - self.shape = signal.shape - elif self._signal_type == "irregular": - self.shape = signal.shape - if self._signal_type not in ("tsgroup",): - self.start_time = _rescale_to_seconds(signal.t_start) - self.end_time = _rescale_to_seconds(signal.t_stop) - else: - self.start_time = 
[_rescale_to_seconds(s.t_start) for s in signal] - self.end_time = [_rescale_to_seconds(s.t_stop) for s in signal] + self.start_time = _rescale_to_seconds(signal.t_start) + self.end_time = _rescale_to_seconds(signal.t_stop) def __repr__(self): - return f"" + return f"" + + @property + def ndim(self): + """Number of dimensions.""" + return len(self.shape) + + @property + def times(self): + """Pre-loaded timestamps for all segments (in seconds).""" + return self._times + + def __len__(self): + """Return the number of samples (first dimension of shape).""" + return self.shape[0] + + def __iter__(self): + """Iterate over the first axis, loading data lazily.""" + for i in range(len(self)): + yield self[i] + + def _find_segment_for_index(self, idx): + """Find which segment contains the given global index. + + Returns + ------- + seg_idx : int + Index of the segment + local_idx : int + Index within that segment + """ + if idx < 0: + idx = len(self) + idx + if idx < 0 or idx >= len(self): + raise IndexError(f"Index {idx} out of bounds for size {len(self)}") + + # Find segment using binary search on offsets + seg_idx = np.searchsorted(self._segment_offsets, idx, side='right') - 1 + local_idx = idx - self._segment_offsets[seg_idx] + return seg_idx, local_idx + + def _load_data_range(self, start_idx, stop_idx, step=1): + """Load data for a range of global indices. + + Parameters + ---------- + start_idx : int + Start index (inclusive) + stop_idx : int + Stop index (exclusive) + step : int + Step size + + Returns + ------- + numpy.ndarray + The loaded data + """ + if start_idx >= stop_idx: + # Return empty array with correct shape + if len(self.shape) == 1: + return np.array([], dtype=self.dtype) + else: + return np.empty((0,) + self.shape[1:], dtype=self.dtype) + + data_chunks = [] + + for seg_idx, seg in enumerate(self._block.segments): + seg_start = self._segment_offsets[seg_idx] + seg_end = seg_start + self._segment_n_samples[seg_idx] + + # Check if this segment overlaps with requested range + if stop_idx <= seg_start or start_idx >= seg_end: + continue + + # Calculate local indices within this segment + local_start = max(0, start_idx - seg_start) + local_stop = min(self._segment_n_samples[seg_idx], stop_idx - seg_start) + + # Load data from this segment + if self.is_analog: + signal = seg.analogsignals[self._sig_num] + else: + signal = seg.irregularlysampledsignals[self._sig_num] + + # Try to load with indexing, fall back to time slicing + try: + if hasattr(signal, 'load'): + loaded = signal.load() + chunk = loaded[local_start:local_stop].magnitude + else: + chunk = signal[local_start:local_stop].magnitude + except (MemoryError, AttributeError): + # Fall back to time slicing + t_start = self._times_list[seg_idx][local_start] + t_stop = self._times_list[seg_idx][min(local_stop, len(self._times_list[seg_idx]) - 1)] + chunk = signal.time_slice(t_start, t_stop).magnitude + + data_chunks.append(chunk) + + if not data_chunks: + if len(self.shape) == 1: + return np.array([], dtype=self.dtype) + else: + return np.empty((0,) + self.shape[1:], dtype=self.dtype) + + result = np.concatenate(data_chunks, axis=0) + + # Apply step if needed + if step != 1: + result = result[::step] + + return result def __getitem__(self, item): + """Get data by index, loading lazily from Neo signals. + + Supports integer indexing, slicing, and tuple indexing for + multi-dimensional access. 
+ + Parameters + ---------- + item : int, slice, or tuple + Index specification + + Returns + ------- + numpy.ndarray or scalar + The requested data + """ + # Handle integer indexing + if isinstance(item, (int, np.integer)): + seg_idx, local_idx = self._find_segment_for_index(item) + + if self.is_analog: + signal = self._block.segments[seg_idx].analogsignals[self._sig_num] + else: + signal = self._block.segments[seg_idx].irregularlysampledsignals[self._sig_num] + + try: + if hasattr(signal, 'load'): + loaded = signal.load() + return loaded[local_idx].magnitude + else: + return signal[local_idx].magnitude + except (MemoryError, AttributeError): + # Fall back to time slicing for a single point + t = self._times_list[seg_idx][local_idx] + return signal.time_slice(t, t).magnitude[0] + + # Handle slice indexing if isinstance(item, slice): - return self._get_from_slice(item) - raise ValueError(f"Cannot get item {item}.") + start = item.start if item.start is not None else 0 + stop = item.stop if item.stop is not None else len(self) + step = item.step if item.step is not None else 1 + + # Handle negative indices + if start < 0: + start = len(self) + start + if stop < 0: + stop = len(self) + stop + + return self._load_data_range(start, stop, step) + + # Handle tuple indexing (e.g., interface[0:100, 0] for specific channel) + if isinstance(item, tuple): + # First index is for time dimension + time_idx = item[0] + rest = item[1:] + + # Get data for time dimension + data = self[time_idx] + + # Apply remaining indices + if rest: + data = data[(slice(None),) + rest] if isinstance(time_idx, slice) else data[rest] + + return data + + # Handle numpy array or list indexing + if isinstance(item, (np.ndarray, list)): + indices = np.asarray(item) + if indices.dtype == bool: + # Boolean indexing + indices = np.where(indices)[0] + + # Load each index and stack + result = np.stack([self[int(i)] for i in indices]) + return result + + raise TypeError(f"Invalid index type: {type(item)}") def get(self, start: float, stop: float): """Get data between start and stop times. @@ -476,14 +914,6 @@ def get(self, start: float, stop: float): return self._get_analog(start, stop) elif self._signal_type == "irregular": return self._get_irregular(start, stop) - elif self._signal_type == "spiketrain": - return self._get_ts(self._sig_num, start, stop) - elif self._signal_type == "tsgroup": - return self._get_tsgroup(start, stop) - elif self._signal_type == "epoch": - return self._get_epoch(start, stop) - elif self._signal_type == "event": - return self._get_event(start, stop) def load(self): """Load all data. 
@@ -512,41 +942,8 @@ def restrict(self, epoch: nap.IntervalSet): """ if self.is_analog: return self._restrict_analog(epoch) - elif self._signal_type == "irregular": - return self._restrict_irregular(epoch) - elif self._signal_type == "spiketrain": - return self._restrict_ts(epoch) - elif self._signal_type == "tsgroup": - return self._restrict_tsgroup(epoch) - elif self._signal_type == "epoch": - return self._get_epoch( - float(epoch.start[0]), float(epoch.end[-1]) - ).restrict(epoch) - elif self._signal_type == "event": - return self._get_event( - float(epoch.start[0]), float(epoch.end[-1]) - ).restrict(epoch) - - def _get_from_slice(self, slc): - start = slc.start if slc.start is not None else 0 - stop = slc.stop - step = slc.step if slc.step is not None else 1 - - if self.is_analog: - if stop is None: - stop = sum( - s.analogsignals[self._sig_num].shape[0] - for s in self._block.segments - ) - return self._slice_segment_analog(start, stop, step) - elif self._signal_type == "spiketrain": - if stop is None: - stop = sum( - len(s.spiketrains[self._sig_num]) for s in self._block.segments - ) - return self._slice_segment_ts(start, stop, step) else: - raise ValueError(f"Cannot slice a {self._signal_type}.") + return self._restrict_irregular(epoch) def _instantiate_nap(self, time, data, time_support): return self.nap_type( @@ -721,158 +1118,6 @@ def _restrict_irregular(self, epoch): t=time, d=data, time_support=self.time_support ).restrict(epoch) - # ========== Spike Train (Ts) Methods ========== - - def _get_ts(self, unit_idx, start, stop, return_array=False): - """Get spike times for a unit within time range.""" - spikes = [] - - for i, seg in enumerate(self._block.segments): - spiketrain = seg.spiketrains[unit_idx] - - seg_start = self.time_support.start[i] - seg_stop = self.time_support.end[i] - - if start >= seg_stop or stop <= seg_start: - continue - - chunk_start = max(start, seg_start) - chunk_stop = min(stop, seg_stop) - - chunk = spiketrain.time_slice(chunk_start, chunk_stop) - - if len(chunk) > 0: - spikes.append(chunk.times.rescale("s").magnitude) - - spike_times = np.concatenate(spikes) if spikes else np.array([]) - - if return_array: - return spike_times - else: - return nap.Ts(t=spike_times, time_support=self.time_support) - - def _restrict_ts(self, epoch): - """Restrict spike train to epochs.""" - spikes = [] - - for start, end in epoch.values: - spike_times = self._get_ts(self._sig_num, start, end, return_array=True) - if len(spike_times) > 0: - spikes.append(spike_times) - - spike_times = np.concatenate(spikes) if spikes else np.array([]) - return nap.Ts(t=spike_times, time_support=self.time_support).restrict(epoch) - - def _slice_segment_ts(self, start_idx, stop_idx, step): - """Slice spike trains by spike index.""" - spikes = [] - - for i, seg in enumerate(self._block.segments): - spiketrain = seg.spiketrains[self._sig_num] - - n_spikes = len(spiketrain) - - seg_start_idx = max(0, start_idx) - seg_stop_idx = min(n_spikes, stop_idx) - - if seg_start_idx >= seg_stop_idx: - continue - - spiketrain_loaded = ( - spiketrain.load() if hasattr(spiketrain, "load") else spiketrain - ) - chunk = spiketrain_loaded[seg_start_idx:seg_stop_idx:step] - - spikes.append(chunk.times.rescale("s").magnitude) - - return nap.Ts( - t=np.concatenate(spikes) if spikes else np.array([]), - time_support=self.time_support, - ) - - # ========== TsGroup Methods ========== - - def _get_tsgroup(self, start, stop): - """Get TsGroup (all units) within time range.""" - n_units = 
len(self._block.segments[0].spiketrains) - ts_dict = {} - - for unit_idx in range(n_units): - spike_times = self._get_ts(unit_idx, start, stop, return_array=True) - ts_dict[unit_idx] = nap.Ts(t=spike_times, time_support=self.time_support) - - return nap.TsGroup(ts_dict, time_support=self.time_support) - - def _restrict_tsgroup(self, epoch): - """Restrict TsGroup to epochs.""" - n_units = len(self._block.segments[0].spiketrains) - ts_dict = {} - - for unit_idx in range(n_units): - spikes = [] - for start, end in epoch.values: - spike_times = self._get_ts(unit_idx, start, end, return_array=True) - if len(spike_times) > 0: - spikes.append(spike_times) - - spike_times = np.concatenate(spikes) if spikes else np.array([]) - ts_dict[unit_idx] = nap.Ts(t=spike_times, time_support=self.time_support) - - return nap.TsGroup(ts_dict, time_support=self.time_support).restrict(epoch) - - # ========== Epoch Methods ========== - - def _get_epoch(self, start, stop): - """Get epochs within time range.""" - all_starts = [] - all_ends = [] - all_labels = [] - - for i, seg in enumerate(self._block.segments): - for epoch in seg.epochs: - if hasattr(epoch, "load"): - epoch = epoch.load() - - times = epoch.times.rescale("s").magnitude - durations = epoch.durations.rescale("s").magnitude - - for t, d, lbl in zip(times, durations, epoch.labels): - ep_start = t - ep_end = t + d - - # Check overlap with requested range - if ep_end > start and ep_start < stop: - all_starts.append(max(ep_start, start)) - all_ends.append(min(ep_end, stop)) - all_labels.append(lbl) - - if len(all_starts) == 0: - return nap.IntervalSet(start=[], end=[]) - - return nap.IntervalSet( - start=np.array(all_starts), - end=np.array(all_ends), - metadata={"label": np.array(all_labels)} if all_labels else None, - ) - - # ========== Event Methods ========== - - def _get_event(self, start, stop): - """Get events within time range.""" - all_times = [] - - for i, seg in enumerate(self._block.segments): - for event in seg.events: - if hasattr(event, "load"): - event = event.load() - - times = event.times.rescale("s").magnitude - - mask = (times >= start) & (times <= stop) - all_times.extend(times[mask]) - - return nap.Ts(t=np.array(all_times), time_support=self.time_support) - # ============================================================================= # Main Interface Class @@ -954,17 +1199,23 @@ def _collect_data(self): if len(block.segments) > 0: seg = block.segments[0] - # Analog signals + # Analog signals - deferred loading via NeoSignalInterface for sig_idx, signal in enumerate(seg.analogsignals): nap_type = _get_signal_type(signal) name = signal.name if signal.name else f"signal{sig_idx}" key = f"{block_prefix}{nap_type.__name__} {sig_idx}: {name}" - interface = NeoSignalInterface( - signal, block, time_support, sig_num=sig_idx - ) - self._interfaces[key] = interface - self.data[key] = {"type": nap_type.__name__, "interface": interface} + # interface = NeoSignalInterface( + # signal, block, time_support, sig_num=sig_idx + # ) + # self._interfaces[key] = interface + self.data[key] = { + "type": nap_type.__name__, + "loader": "analogsignal", + "block": block, + "sig_num": sig_idx, + "time_support": time_support, + } self._data_info[key] = nap_type.__name__ # Irregularly sampled signals @@ -973,57 +1224,74 @@ def _collect_data(self): name = signal.name if signal.name else f"irregular{sig_idx}" key = f"{block_prefix}{nap_type.__name__} (irregular) {sig_idx}: {name}" - interface = NeoSignalInterface( - signal, block, time_support, sig_num=sig_idx - 
) - self._interfaces[key] = interface - self.data[key] = {"type": nap_type.__name__, "interface": interface} + # interface = NeoSignalInterface( + # signal, block, time_support, sig_num=sig_idx + # ) + # self._interfaces[key] = interface + self.data[key] = { + "type": nap_type.__name__, + "loader": "irregularsignal", + "block": block, + "sig_num": sig_idx, + "time_support": time_support, + } self._data_info[key] = nap_type.__name__ - # Spike trains + # Spike trains - deferred loading if len(seg.spiketrains) == 1: st = seg.spiketrains[0] name = st.name if st.name else "spikes" key = f"{block_prefix}Ts: {name}" - interface = NeoSignalInterface( - st, block, time_support, sig_num=0 - ) - self._interfaces[key] = interface - self.data[key] = {"type": "Ts", "interface": interface} + # Store info for deferred loading + self.data[key] = { + "type": "Ts", + "loader": "spiketrain", + "block": block, + "unit_idx": 0, + "time_support": time_support, + } self._data_info[key] = "Ts" elif len(seg.spiketrains) > 1: key = f"{block_prefix}TsGroup" - interface = NeoSignalInterface( - seg.spiketrains, block, time_support - ) - self._interfaces[key] = interface - self.data[key] = {"type": "TsGroup", "interface": interface} + # Store info for deferred loading + self.data[key] = { + "type": "TsGroup", + "loader": "tsgroup", + "block": block, + "time_support": time_support, + } self._data_info[key] = "TsGroup" - # Epochs + # Epochs - deferred loading for ep_idx, epoch in enumerate(seg.epochs): name = epoch.name if hasattr(epoch, "name") and epoch.name else f"epoch{ep_idx}" key = f"{block_prefix}IntervalSet {ep_idx}: {name}" - interface = NeoSignalInterface( - epoch, block, time_support, sig_num=ep_idx - ) - self._interfaces[key] = interface - self.data[key] = {"type": "IntervalSet", "interface": interface} + # Store info for deferred loading + self.data[key] = { + "type": "IntervalSet", + "loader": "epoch", + "block": block, + "ep_idx": ep_idx, + "time_support": time_support, + } self._data_info[key] = "IntervalSet" - # Events + # Events - deferred loading for ev_idx, event in enumerate(seg.events): name = event.name if hasattr(event, "name") and event.name else f"event{ev_idx}" key = f"{block_prefix}Ts (event) {ev_idx}: {name}" - interface = NeoSignalInterface( - event, block, time_support, sig_num=ev_idx - ) - self._interfaces[key] = interface - self.data[key] = {"type": "Ts", "interface": interface} + # Store info for deferred loading + self.data[key] = { + "type": "Ts", + "loader": "event", + "block": block, + "ev_idx": ev_idx, + "time_support": time_support, + } self._data_info[key] = "Ts" def __str__(self): @@ -1055,7 +1323,7 @@ def __getitem__(self, key: str): Returns ------- pynapple object - The requested data (Ts, Tsd, TsdFrame, TsGroup, IntervalSet, etc.) + The requested data (Ts, Tsd, TsdFrame, TsdTensor, TsGroup, IntervalSet) """ if key not in self.data: raise KeyError(f"Key '{key}' not found. 
Available keys: {list(self.data.keys())}") @@ -1066,9 +1334,51 @@ def __getitem__(self, key: str): if not isinstance(item, dict): return item - # Load via interface - interface = item["interface"] - loaded_data = interface.load() + # Load based on loader type + loader = item.get("loader") + + if loader == "spiketrain": + # Load single spike train from all segments + loaded_data = _make_ts_from_spiketrain_multiseg( + item["block"], + unit_idx=item["unit_idx"], + time_support=item["time_support"], + ) + elif loader == "tsgroup": + # Load TsGroup from all segments + all_spiketrains = [s.spiketrains for s in item["block"].segments] + loaded_data = _make_tsgroup_from_spiketrains_multiseg( + all_spiketrains, + time_support=item["time_support"], + ) + elif loader == "epoch": + # Load IntervalSet from all segments + loaded_data = _make_intervalset_from_epoch_multiseg( + item["block"], + ep_idx=item["ep_idx"], + time_support=item["time_support"], + ) + elif loader == "event": + # Load Ts (event) from all segments + loaded_data = _make_ts_from_event_multiseg( + item["block"], + ev_idx=item["ev_idx"], + time_support=item["time_support"], + ) + elif loader in ["analogsignal", "irregularsignal"]: + # Load via NeoSignalInterface (deferred loading) + interface = NeoSignalInterface( + signal=item["block"].segments[0].analogsignals[item["sig_num"]] + if loader == "analogsignal" + else item["block"].segments[0].irregularlysampledsignals[item["sig_num"]], + block=item["block"], + time_support=item["time_support"], + sig_num=item["sig_num"], + ) + loaded_data = _make_tsd_from_interface(interface) + + else: + raise ValueError(f"Unknown loader type for key '{key}'") # Cache the loaded data self.data[key] = loaded_data @@ -1105,169 +1415,178 @@ def close(self): self._reader.close() -# ============================================================================= -# Legacy Interface (for backward compatibility) -# ============================================================================= -class NEOSignalInterface(NeoSignalInterface): - """Legacy alias for NeoSignalInterface.""" - pass -class NEOExperimentInterface: - """Legacy interface for Neo experiments. - .. deprecated:: - Use :class:`NeoReader` instead. - """ - - def __init__(self, reader, lazy=False): - warnings.warn( - "NEOExperimentInterface is deprecated. 
Use NeoReader instead.", - DeprecationWarning, - stacklevel=2, - ) - self._reader = reader - self._lazy = lazy - self.experiment = self._collect_time_series_info() - - def _collect_time_series_info(self): - blocks = self._reader.read(lazy=self._lazy) - - experiments = {} - for i, block in enumerate(blocks): - name = f"block {i}" - if block.name: - name += ": " + block.name - experiments[name] = {} - - starts, ends = np.empty(len(block.segments)), np.empty(len(block.segments)) - for trial_num, segment in enumerate(block.segments): - starts[trial_num] = segment.t_start.rescale("s").magnitude - ends[trial_num] = segment.t_stop.rescale("s").magnitude - - iset = nap.IntervalSet(starts, ends) - - for trial_num, segment in enumerate(block.segments): - # Analog signals - for signal_num, signal in enumerate(segment.analogsignals): - if signal.name: - signame = f" {signal_num}: " + signal.name - else: - signame = f" {signal_num}" - signal_interface = NeoSignalInterface( - signal, block, iset, sig_num=signal_num - ) - signame = signal_interface.nap_type.__name__ + signame - experiments[name][signame] = signal_interface - - # Spike trains - if len(segment.spiketrains) == 1: - signal = segment.spiketrains[0] - signal_interface = NeoSignalInterface( - signal, block, iset, sig_num=0 - ) - signame = f"Ts" + ": " + signal.name if signal.name else "Ts" - experiments[name][signame] = signal_interface - else: - signame = f"TsGroup" - experiments[name][signame] = NeoSignalInterface( - segment.spiketrains, block, iset - ) - - return experiments - - def __getitem__(self, item): - if isinstance(item, str): - return self.experiment[item] - else: - res = self.experiment - for it in item: - res = res[it] - return res - - def keys(self): - return [(k, k2) for k in self.experiment.keys() for k2 in self.experiment[k]] - - -def load_file(path: Union[str, Path], lazy: bool = True) -> NeoReader: - """Load a neural recording file using Neo. - - This function automatically detects the file format and uses the - appropriate Neo IO to load the data. - - Parameters - ---------- - path : str or Path - Path to the recording file - lazy : bool, default True - Whether to use lazy loading (recommended for large files) - - Returns - ------- - NeoReader - Interface to the loaded data - - Examples - -------- - >>> import pynapple as nap - >>> data = nap.io.neo.load_file("recording.plx") - >>> print(data) - recording - +---------------------+----------+ - | Key | Type | - +=====================+==========+ - | TsGroup | TsGroup | - | Tsd 0: LFP | Tsd | - +---------------------+----------+ - - >>> spikes = data["TsGroup"] - - See Also - -------- - NeoReader : Class for Neo file interface - - Notes - ----- - Supported formats depend on your Neo installation. Common formats include: - - Plexon (.plx, .pl2) - - Blackrock (.nev, .ns*) - - Spike2 (.smr) - - Neuralynx (.ncs, .nse, .ntt) - - OpenEphys - - Intan (.rhd, .rhs) - - And many more (see Neo documentation) - """ - return NeoReader(path, lazy=lazy) - - -# Legacy alias -def load_experiment(path: Union[str, Path], lazy: bool = True) -> NEOExperimentInterface: - """Load a neural recording experiment. - - .. deprecated:: - Use :func:`load_file` instead. 
- - Parameters - ---------- - path : str or Path - Path to the recording file - lazy : bool, default True - Whether to lazy load the data - - Returns - ------- - NEOExperimentInterface - """ - import pathlib - - path = pathlib.Path(path) - reader = neo.io.get_io(path) - - return NEOExperimentInterface(reader, lazy=lazy) +# +# +# # ============================================================================= +# # Legacy Interface (for backward compatibility) +# # ============================================================================= +# +# +# class NEOSignalInterface(NeoSignalInterface): +# """Legacy alias for NeoSignalInterface.""" +# pass +# +# +# class NEOExperimentInterface: +# """Legacy interface for Neo experiments. +# +# .. deprecated:: +# Use :class:`NeoReader` instead. +# """ +# +# def __init__(self, reader, lazy=False): +# warnings.warn( +# "NEOExperimentInterface is deprecated. Use NeoReader instead.", +# DeprecationWarning, +# stacklevel=2, +# ) +# self._reader = reader +# self._lazy = lazy +# self.experiment = self._collect_time_series_info() +# +# def _collect_time_series_info(self): +# blocks = self._reader.read(lazy=self._lazy) +# +# experiments = {} +# for i, block in enumerate(blocks): +# name = f"block {i}" +# if block.name: +# name += ": " + block.name +# experiments[name] = {} +# +# starts, ends = np.empty(len(block.segments)), np.empty(len(block.segments)) +# for trial_num, segment in enumerate(block.segments): +# starts[trial_num] = segment.t_start.rescale("s").magnitude +# ends[trial_num] = segment.t_stop.rescale("s").magnitude +# +# iset = nap.IntervalSet(starts, ends) +# +# for trial_num, segment in enumerate(block.segments): +# # Analog signals +# for signal_num, signal in enumerate(segment.analogsignals): +# if signal.name: +# signame = f" {signal_num}: " + signal.name +# else: +# signame = f" {signal_num}" +# signal_interface = NeoSignalInterface( +# signal, block, iset, sig_num=signal_num +# ) +# signame = signal_interface.nap_type.__name__ + signame +# experiments[name][signame] = signal_interface +# +# # Spike trains +# if len(segment.spiketrains) == 1: +# signal = segment.spiketrains[0] +# signal_interface = NeoSignalInterface( +# signal, block, iset, sig_num=0 +# ) +# signame = f"Ts" + ": " + signal.name if signal.name else "Ts" +# experiments[name][signame] = signal_interface +# else: +# signame = f"TsGroup" +# experiments[name][signame] = NeoSignalInterface( +# segment.spiketrains, block, iset +# ) +# +# return experiments +# +# def __getitem__(self, item): +# if isinstance(item, str): +# return self.experiment[item] +# else: +# res = self.experiment +# for it in item: +# res = res[it] +# return res +# +# def keys(self): +# return [(k, k2) for k in self.experiment.keys() for k2 in self.experiment[k]] +# +# +# def load_file(path: Union[str, Path], lazy: bool = True) -> NeoReader: +# """Load a neural recording file using Neo. +# +# This function automatically detects the file format and uses the +# appropriate Neo IO to load the data. 
+# +# Parameters +# ---------- +# path : str or Path +# Path to the recording file +# lazy : bool, default True +# Whether to use lazy loading (recommended for large files) +# +# Returns +# ------- +# NeoReader +# Interface to the loaded data +# +# Examples +# -------- +# >>> import pynapple as nap +# >>> data = nap.io.neo.load_file("recording.plx") +# >>> print(data) +# recording +# +---------------------+----------+ +# | Key | Type | +# +=====================+==========+ +# | TsGroup | TsGroup | +# | Tsd 0: LFP | Tsd | +# +---------------------+----------+ +# +# >>> spikes = data["TsGroup"] +# +# See Also +# -------- +# NeoReader : Class for Neo file interface +# +# Notes +# ----- +# Supported formats depend on your Neo installation. Common formats include: +# - Plexon (.plx, .pl2) +# - Blackrock (.nev, .ns*) +# - Spike2 (.smr) +# - Neuralynx (.ncs, .nse, .ntt) +# - OpenEphys +# - Intan (.rhd, .rhs) +# - And many more (see Neo documentation) +# """ +# return NeoReader(path, lazy=lazy) +# +# +# # Legacy alias +# def load_experiment(path: Union[str, Path], lazy: bool = True) -> NEOExperimentInterface: +# """Load a neural recording experiment. +# +# .. deprecated:: +# Use :func:`load_file` instead. +# +# Parameters +# ---------- +# path : str or Path +# Path to the recording file +# lazy : bool, default True +# Whether to lazy load the data +# +# Returns +# ------- +# NEOExperimentInterface +# """ +# import pathlib +# +# path = pathlib.Path(path) +# reader = neo.io.get_io(path) +# +# return NEOExperimentInterface(reader, lazy=lazy) +# +# # # # ============================================================================= # # Conversion functions: Pynapple -> Neo From 870e6d26d639c6cc428cfe57f23d6a4c7ab85c9d Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Mon, 12 Jan 2026 17:21:48 -0500 Subject: [PATCH 4/7] update neo --- doc/user_guide/02_input_output.md | 38 +++ pynapple/io/__init__.py | 7 +- pynapple/io/{neo.py => interface_neo.py} | 380 ++--------------------- tests/test_neo.py | 54 ++-- 4 files changed, 87 insertions(+), 392 deletions(-) rename pynapple/io/{neo.py => interface_neo.py} (78%) diff --git a/doc/user_guide/02_input_output.md b/doc/user_guide/02_input_output.md index 9588ff9bb..00aeabb73 100644 --- a/doc/user_guide/02_input_output.md +++ b/doc/user_guide/02_input_output.md @@ -133,6 +133,44 @@ z = data['z'] print(type(z.d)) ``` +## LFP loading & NEO compatibility + +Raw LFP data can be loaded with pynapple through the NEO library. +Internally, pynapple uses the NEO raw IO classes to read the data and convert them to one of the pynapple time series object. +This is done through the class [`nap.NeoReader`](pynapple.io.neo.NeoReader). + +See here the [list of supported formats](https://neo.readthedocs.io/en/stable/rawiolist.html)). 
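Because the analog signals are exposed through a lazy interface, indexing the returned object pulls the data only when it is accessed. A short sketch, assuming the file and the `'TsdFrame 0: V'` key used in the example below:

```python
import pynapple as nap

# Assumes "File_plexon_3.plx" has been downloaded as in the cells below.
data = nap.NeoReader("File_plexon_3.plx")
lfp = data["TsdFrame 0: V"]   # lazy TsdFrame backed by the Plexon file
chunk = lfp[0:1000]           # data for the first 1000 samples is read on access
```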
+ + +Here is a minimal working example : + +```{code-cell} ipython3 +:tags: [hide-cell] +import urllib +distantfile = "https://web.gin.g-node.org/NeuralEnsemble/ephy_testing_data/raw/master/plexon/File_plexon_3.plx" +urllib.request.urlretrieve(distantfile, "File_plexon_3.plx") +``` + +```{code-cell} ipython3 +data = nap.NeoReader("File_plexon_3.plx") +print(data) +``` + +```{code-cell} ipython3 +lfp = data['TsdFrame 0: V'] +spikes = data['TsGroup'] +``` + +```{code-cell} ipython3 +fig = pyplot.figure() +ax1 = fig.add_subplot(2, 1, 1) +ax2 = fig.add_subplot(2, 1, 2) +ax1.plot(lfp.t, lfp.d[:]) +colors = pyplot.cm.jet(np.linspace(0, 1, len(spikes))) +ax2.eventplot([spikes[i].t for i in spikes.keys()], colors=colors) +plt.show() +``` + ## Saving as NPZ Pynapple objects have [`save`](pynapple.Tsd.save) methods to save them as npz files. diff --git a/pynapple/io/__init__.py b/pynapple/io/__init__.py index 73bcd7c98..cf0b5da6b 100644 --- a/pynapple/io/__init__.py +++ b/pynapple/io/__init__.py @@ -8,11 +8,6 @@ load_folder, load_session, ) -from .neo import ( +from .interface_neo import ( NeoReader, - # load_file as load_neo_file, - # to_neo_analogsignal, - # to_neo_spiketrain, - # to_neo_epoch, - # to_neo_event, ) diff --git a/pynapple/io/neo.py b/pynapple/io/interface_neo.py similarity index 78% rename from pynapple/io/neo.py rename to pynapple/io/interface_neo.py index 6bef4de23..0f04de260 100644 --- a/pynapple/io/neo.py +++ b/pynapple/io/interface_neo.py @@ -295,116 +295,6 @@ def _make_ts_from_event(event, time_support: Optional[nap.IntervalSet] = None) - return nap.Ts(t=times, time_support=time_support) -def _make_tsd_from_interface(interface) -> Union[nap.Tsd, nap.TsdFrame, nap.TsdTensor]: - """Convert a NeoSignalInterface to a pynapple Tsd/TsdFrame/TsdTensor. - - Parameters - ---------- - interface : NeoSignalInterface - The NeoSignalInterface object - - Returns - ------- - Tsd, TsdFrame, or TsdTensor - Appropriate pynapple time series object - """ - times = interface.times - data = interface - - nap_type = interface.nap_type - - if nap_type == nap.Tsd: - return nap.Tsd(t=times, d=data, time_support=interface.time_support) - elif nap_type == nap.TsdFrame: - return nap.TsdFrame(t=times, d=data, time_support=interface.time_support, load_array=False) - else: - return nap.TsdTensor(t=times, d=data, time_support=interface.time_support) - -# -# def _make_tsd_from_analog( -# signal, -# time_support: Optional[nap.IntervalSet] = None, -# column_names: Optional[List[str]] = None, -# ) -> Union[nap.Tsd, nap.TsdFrame, nap.TsdTensor]: -# """Convert a Neo AnalogSignal to a pynapple Tsd/TsdFrame/TsdTensor. 
-# -# Parameters -# ---------- -# signal : neo.AnalogSignal or AnalogSignalProxy -# Neo analog signal -# time_support : IntervalSet, optional -# Time support -# column_names : list of str, optional -# Column names for TsdFrame -# -# Returns -# ------- -# Tsd, TsdFrame, or TsdTensor -# Appropriate pynapple time series object -# """ -# if hasattr(signal, "load"): -# signal = signal.load() -# -# times = signal.times.rescale("s").magnitude -# data = signal.magnitude -# -# nap_type = _get_signal_type(signal) -# -# if nap_type == nap.Tsd: -# if len(data.shape) == 2: -# data = data.squeeze() -# return nap.Tsd(t=times, d=data, time_support=time_support) -# elif nap_type == nap.TsdFrame: -# if column_names is None: -# # Try to get channel names from annotations -# if hasattr(signal, "array_annotations"): -# channel_names = signal.array_annotations.get("channel_names", None) -# if channel_names is not None: -# column_names = list(channel_names) -# return nap.TsdFrame(t=times, d=data, columns=column_names, time_support=time_support) -# else: -# return nap.TsdTensor(t=times, d=data, time_support=time_support) - - -# def _make_tsd_from_irregular( -# signal, -# time_support: Optional[nap.IntervalSet] = None, -# column_names: Optional[List[str]] = None, -# ) -> Union[nap.Tsd, nap.TsdFrame, nap.TsdTensor]: -# """Convert a Neo IrregularlySampledSignal to a pynapple Tsd/TsdFrame/TsdTensor. -# -# Parameters -# ---------- -# signal : neo.IrregularlySampledSignal -# Neo irregularly sampled signal -# time_support : IntervalSet, optional -# Time support -# column_names : list of str, optional -# Column names for TsdFrame -# -# Returns -# ------- -# Tsd, TsdFrame, or TsdTensor -# Appropriate pynapple time series object -# """ -# if hasattr(signal, "load"): -# signal = signal.load() -# -# times = signal.times.rescale("s").magnitude -# data = signal.magnitude -# -# nap_type = _get_signal_type(signal) -# -# if nap_type == nap.Tsd: -# if len(data.shape) == 2: -# data = data.squeeze() -# return nap.Tsd(t=times, d=data, time_support=time_support) -# elif nap_type == nap.TsdFrame: -# return nap.TsdFrame(t=times, d=data, columns=column_names, time_support=time_support) -# else: -# return nap.TsdTensor(t=times, d=data, time_support=time_support) - - def _make_ts_from_spiketrain( spiketrain, time_support: Optional[nap.IntervalSet] = None ) -> nap.Ts: @@ -569,6 +459,24 @@ def _make_tsgroup_from_spiketrains_multiseg( return nap.TsGroup(ts_dict, time_support=time_support, **meta_arrays) +def _make_tsd_from_interface(interface) -> Union[nap.Tsd, nap.TsdFrame, nap.TsdTensor]: + """Convert a NeoSignalInterface to a pynapple Tsd/TsdFrame/TsdTensor. 
+ + Parameters + ---------- + interface : NeoSignalInterface + The NeoSignalInterface object + + Returns + ------- + Tsd, TsdFrame, or TsdTensor + Appropriate pynapple time series object + """ + nap_type = interface.nap_type + + # return nap_type(t=times, d=data, time_support=interface.time_support, load_array=False) + return nap_type(t=interface.times, d=interface, load_array=False) + # ============================================================================= # Signal Interface for lazy loading @@ -691,9 +599,9 @@ def __init__(self, signal, block, time_support=None, sig_num=0): # Compute total shape (first dimension is total samples) if len(signal.shape) == 1: - self.shape = (total_samples,) + self.shape = (int(total_samples),) else: - self.shape = (total_samples,) + signal.shape[1:] + self.shape = (int(total_samples),) + signal.shape[1:] # Store timing info if self.is_analog: @@ -835,23 +743,7 @@ def __getitem__(self, item): """ # Handle integer indexing if isinstance(item, (int, np.integer)): - seg_idx, local_idx = self._find_segment_for_index(item) - - if self.is_analog: - signal = self._block.segments[seg_idx].analogsignals[self._sig_num] - else: - signal = self._block.segments[seg_idx].irregularlysampledsignals[self._sig_num] - - try: - if hasattr(signal, 'load'): - loaded = signal.load() - return loaded[local_idx].magnitude - else: - return signal[local_idx].magnitude - except (MemoryError, AttributeError): - # Fall back to time slicing for a single point - t = self._times_list[seg_idx][local_idx] - return signal.time_slice(t, t).magnitude[0] + return self._load_data_range(item, item + 1)[0] # Handle slice indexing if isinstance(item, slice): @@ -895,228 +787,6 @@ def __getitem__(self, item): raise TypeError(f"Invalid index type: {type(item)}") - def get(self, start: float, stop: float): - """Get data between start and stop times. - - Parameters - ---------- - start : float - Start time in seconds - stop : float - Stop time in seconds - - Returns - ------- - pynapple object - Data restricted to the time range - """ - if self.is_analog: - return self._get_analog(start, stop) - elif self._signal_type == "irregular": - return self._get_irregular(start, stop) - - def load(self): - """Load all data. - - Returns - ------- - pynapple object - The fully loaded data - """ - start = float(self.time_support.start[0]) - end = float(self.time_support.end[-1]) - return self.get(start, end) - - def restrict(self, epoch: nap.IntervalSet): - """Restrict data to epochs. 
- - Parameters - ---------- - epoch : IntervalSet - Epochs to restrict to - - Returns - ------- - pynapple object - Data restricted to the epochs - """ - if self.is_analog: - return self._restrict_analog(epoch) - else: - return self._restrict_irregular(epoch) - - def _instantiate_nap(self, time, data, time_support): - return self.nap_type( - t=time, - d=data, - time_support=time_support, - ) - - def _concatenate_array(self, time_list, data_list): - if len(data_list) == 0: - shape = getattr(self, "shape", (0, 1)) - return np.array([]), np.array([]).reshape( - (0, *shape[1:]) if len(shape) > 1 else (0,) - ) - else: - return np.concatenate(time_list), np.concatenate(data_list, axis=0) - - # ========== Analog Signal Methods ========== - - def _get_analog(self, start, stop, return_array=False): - """Get analog signal between start and stop times.""" - data = [] - time = [] - - for i, seg in enumerate(self._block.segments): - signal = seg.analogsignals[self._sig_num] - - seg_start = self.time_support.start[i] - seg_stop = self.time_support.end[i] - - if start >= seg_stop or stop <= seg_start: - continue - - chunk_start = max(start, seg_start) - chunk_stop = min(stop, seg_stop) - - chunk = signal.time_slice(chunk_start, chunk_stop) - - if chunk.shape[0] > 0: - data.append(chunk.magnitude) - time.append(chunk.times.rescale("s").magnitude) - - time, data = self._concatenate_array(time, data) - if not return_array: - return self._instantiate_nap(time, data, time_support=self.time_support) - else: - return time, data - - def _restrict_analog(self, epoch): - """Restrict analog signal to epochs.""" - time = [] - data = [] - - for start, end in epoch.values: - time_ep, data_ep = self._get_analog(start, end, return_array=True) - time.append(time_ep) - data.append(data_ep) - - time, data = self._concatenate_array(time, data) - return self._instantiate_nap(time, data, self.time_support).restrict(epoch) - - def _slice_segment_analog(self, start_idx, stop_idx, step): - """Load by exact indices from each segment.""" - data = [] - time = [] - - for i, seg in enumerate(self._block.segments): - signal = seg.analogsignals[self._sig_num] - - seg_start_time = self.time_support.start[i] - seg_end_time = self.time_support.end[i] - seg_duration = seg_end_time - seg_start_time - seg_n_samples = signal.shape[0] - - dt = seg_duration / seg_n_samples - - seg_start_idx = max(0, start_idx) - seg_stop_idx = min(seg_n_samples, stop_idx) - - if seg_start_idx >= seg_stop_idx: - continue - - try: - signal_loaded = signal.load() - chunk = signal_loaded[seg_start_idx:seg_stop_idx:step] - except MemoryError: - chunk_start_time = seg_start_time + seg_start_idx * dt - chunk_stop_time = seg_start_time + seg_stop_idx * dt - chunk = signal.time_slice(chunk_start_time, chunk_stop_time) - if step != 1: - chunk = chunk[::step] - - data.append(chunk.magnitude) - time.append(chunk.times.rescale("s").magnitude) - - time, data = self._concatenate_array(time, data) - return self._instantiate_nap(time, data, time_support=self.time_support) - - # ========== Irregularly Sampled Signal Methods ========== - - def _get_irregular(self, start, stop, return_array=False): - """Get irregularly sampled signal between start and stop times.""" - data = [] - time = [] - - for i, seg in enumerate(self._block.segments): - signal = seg.irregularlysampledsignals[self._sig_num] - - seg_start = self.time_support.start[i] - seg_stop = self.time_support.end[i] - - if start >= seg_stop or stop <= seg_start: - continue - - chunk_start = max(start, seg_start) - 
chunk_stop = min(stop, seg_stop) - - chunk = signal.time_slice(chunk_start, chunk_stop) - - if chunk.shape[0] > 0: - data.append(chunk.magnitude) - time.append(chunk.times.rescale("s").magnitude) - - if len(time) == 0: - time = np.array([]) - data = np.array([]) - else: - time = np.concatenate(time) - data = np.concatenate(data, axis=0) - - if not return_array: - if data.ndim == 1 or (data.ndim == 2 and data.shape[1] == 1): - if data.ndim == 2: - data = data.squeeze() - return nap.Tsd(t=time, d=data, time_support=self.time_support) - elif data.ndim == 2: - return nap.TsdFrame(t=time, d=data, time_support=self.time_support) - else: - return nap.TsdTensor(t=time, d=data, time_support=self.time_support) - else: - return time, data - - def _restrict_irregular(self, epoch): - """Restrict irregularly sampled signal to epochs.""" - time = [] - data = [] - - for start, end in epoch.values: - time_ep, data_ep = self._get_irregular(start, end, return_array=True) - if len(time_ep) > 0: - time.append(time_ep) - data.append(data_ep) - - if len(time) == 0: - return nap.Tsd(t=np.array([]), d=np.array([]), time_support=epoch) - - time = np.concatenate(time) - data = np.concatenate(data, axis=0) - - if data.ndim == 1 or (data.ndim == 2 and data.shape[1] == 1): - if data.ndim == 2: - data = data.squeeze() - return nap.Tsd(t=time, d=data, time_support=self.time_support).restrict( - epoch - ) - elif data.ndim == 2: - return nap.TsdFrame( - t=time, d=data, time_support=self.time_support - ).restrict(epoch) - else: - return nap.TsdTensor( - t=time, d=data, time_support=self.time_support - ).restrict(epoch) # ============================================================================= @@ -1205,10 +875,6 @@ def _collect_data(self): name = signal.name if signal.name else f"signal{sig_idx}" key = f"{block_prefix}{nap_type.__name__} {sig_idx}: {name}" - # interface = NeoSignalInterface( - # signal, block, time_support, sig_num=sig_idx - # ) - # self._interfaces[key] = interface self.data[key] = { "type": nap_type.__name__, "loader": "analogsignal", @@ -1224,10 +890,6 @@ def _collect_data(self): name = signal.name if signal.name else f"irregular{sig_idx}" key = f"{block_prefix}{nap_type.__name__} (irregular) {sig_idx}: {name}" - # interface = NeoSignalInterface( - # signal, block, time_support, sig_num=sig_idx - # ) - # self._interfaces[key] = interface self.data[key] = { "type": nap_type.__name__, "loader": "irregularsignal", diff --git a/tests/test_neo.py b/tests/test_neo.py index 685b0ddee..12debdff3 100644 --- a/tests/test_neo.py +++ b/tests/test_neo.py @@ -147,7 +147,7 @@ class TestNeoToPynapple: def test_analog_to_tsd(self): """Test AnalogSignal to Tsd conversion.""" - from pynapple.io.neo import _make_tsd_from_analog + from pynapple.io.interface_neo import _make_tsd_from_analog # Single channel -> Tsd signal = create_mock_analog_signal(n_samples=100, n_channels=1) @@ -157,7 +157,7 @@ def test_analog_to_tsd(self): def test_analog_to_tsdframe(self): """Test AnalogSignal to TsdFrame conversion.""" - from pynapple.io.neo import _make_tsd_from_analog + from pynapple.io.interface_neo import _make_tsd_from_analog # Multi-channel -> TsdFrame signal = create_mock_analog_signal(n_samples=100, n_channels=3) @@ -168,7 +168,7 @@ def test_analog_to_tsdframe(self): def test_spiketrain_to_ts(self): """Test SpikeTrain to Ts conversion.""" - from pynapple.io.neo import _make_ts_from_spiketrain + from pynapple.io.interface_neo import _make_ts_from_spiketrain spiketrain = create_mock_spiketrain(n_spikes=50) ts = 
_make_ts_from_spiketrain(spiketrain) @@ -177,7 +177,7 @@ def test_spiketrain_to_ts(self): def test_spiketrains_to_tsgroup(self): """Test multiple SpikeTrains to TsGroup conversion.""" - from pynapple.io.neo import _make_tsgroup_from_spiketrains + from pynapple.io.interface_neo import _make_tsgroup_from_spiketrains spiketrains = [create_mock_spiketrain(n_spikes=30) for _ in range(5)] tsgroup = _make_tsgroup_from_spiketrains(spiketrains) @@ -186,7 +186,7 @@ def test_spiketrains_to_tsgroup(self): def test_epoch_to_intervalset(self): """Test Epoch to IntervalSet conversion.""" - from pynapple.io.neo import _make_intervalset_from_epoch + from pynapple.io.interface_neo import _make_intervalset_from_epoch epoch = create_mock_epoch(n_epochs=5) iset = _make_intervalset_from_epoch(epoch) @@ -195,7 +195,7 @@ def test_epoch_to_intervalset(self): def test_event_to_ts(self): """Test Event to Ts conversion.""" - from pynapple.io.neo import _make_ts_from_event + from pynapple.io.interface_neo import _make_ts_from_event event = create_mock_event(n_events=10) ts = _make_ts_from_event(event) @@ -204,7 +204,7 @@ def test_event_to_ts(self): def test_irregular_signal_to_tsd(self): """Test IrregularlySampledSignal to Tsd conversion.""" - from pynapple.io.neo import _make_tsd_from_irregular + from pynapple.io.interface_neo import _make_tsd_from_irregular signal = create_mock_irregular_signal(n_samples=50, n_channels=1) tsd = _make_tsd_from_irregular(signal) @@ -213,7 +213,7 @@ def test_irregular_signal_to_tsd(self): def test_irregular_signal_to_tsdframe(self): """Test IrregularlySampledSignal to TsdFrame conversion.""" - from pynapple.io.neo import _make_tsd_from_irregular + from pynapple.io.interface_neo import _make_tsd_from_irregular signal = create_mock_irregular_signal(n_samples=50, n_channels=3) tsdframe = _make_tsd_from_irregular(signal) @@ -232,7 +232,7 @@ class TestPynappleToNeo: def test_tsd_to_analog(self): """Test Tsd to AnalogSignal conversion.""" - from pynapple.io.neo import to_neo_analogsignal + from pynapple.io.interface_neo import to_neo_analogsignal tsd = nap.Tsd(t=np.arange(100) / 1000.0, d=np.random.randn(100)) signal = to_neo_analogsignal(tsd, units="mV") @@ -243,7 +243,7 @@ def test_tsd_to_analog(self): def test_tsdframe_to_analog(self): """Test TsdFrame to AnalogSignal conversion.""" - from pynapple.io.neo import to_neo_analogsignal + from pynapple.io.interface_neo import to_neo_analogsignal tsdframe = nap.TsdFrame( t=np.arange(100) / 1000.0, d=np.random.randn(100, 3) @@ -255,7 +255,7 @@ def test_tsdframe_to_analog(self): def test_ts_to_spiketrain(self): """Test Ts to SpikeTrain conversion.""" - from pynapple.io.neo import to_neo_spiketrain + from pynapple.io.interface_neo import to_neo_spiketrain ts = nap.Ts(t=np.sort(np.random.uniform(0, 10, 50))) spiketrain = to_neo_spiketrain(ts, t_stop=10.0) @@ -265,7 +265,7 @@ def test_ts_to_spiketrain(self): def test_intervalset_to_epoch(self): """Test IntervalSet to Epoch conversion.""" - from pynapple.io.neo import to_neo_epoch + from pynapple.io.interface_neo import to_neo_epoch iset = nap.IntervalSet(start=[0, 5, 10], end=[2, 7, 12]) epoch = to_neo_epoch(iset) @@ -281,7 +281,7 @@ def test_intervalset_to_epoch(self): def test_ts_to_event(self): """Test Ts to Event conversion.""" - from pynapple.io.neo import to_neo_event + from pynapple.io.interface_neo import to_neo_event ts = nap.Ts(t=np.array([1.0, 2.5, 5.0, 7.5])) event = to_neo_event(ts) @@ -303,7 +303,7 @@ class TestNeoSignalInterface: def test_analog_interface_init(self): """Test 
initialization with AnalogSignal.""" - from pynapple.io.neo import NeoSignalInterface + from pynapple.io.interface_neo import NeoSignalInterface block = create_mock_block_with_segments(n_segments=1, n_analog=1, n_spiketrains=0) signal = block.segments[0].analogsignals[0] @@ -317,7 +317,7 @@ def test_analog_interface_init(self): def test_spiketrain_interface_init(self): """Test initialization with SpikeTrain.""" - from pynapple.io.neo import NeoSignalInterface + from pynapple.io.interface_neo import NeoSignalInterface block = create_mock_block_with_segments(n_segments=1, n_analog=0, n_spiketrains=1) spiketrain = block.segments[0].spiketrains[0] @@ -331,7 +331,7 @@ def test_spiketrain_interface_init(self): def test_tsgroup_interface_init(self): """Test initialization with multiple SpikeTrains.""" - from pynapple.io.neo import NeoSignalInterface + from pynapple.io.interface_neo import NeoSignalInterface block = create_mock_block_with_segments(n_segments=1, n_analog=0, n_spiketrains=3) spiketrains = block.segments[0].spiketrains @@ -345,7 +345,7 @@ def test_tsgroup_interface_init(self): def test_interface_load(self): """Test loading data through interface.""" - from pynapple.io.neo import NeoSignalInterface + from pynapple.io.interface_neo import NeoSignalInterface block = create_mock_block_with_segments(n_segments=1, n_analog=1, n_spiketrains=0) signal = block.segments[0].analogsignals[0] @@ -358,7 +358,7 @@ def test_interface_load(self): def test_interface_get_time_range(self): """Test getting data for a time range.""" - from pynapple.io.neo import NeoSignalInterface + from pynapple.io.interface_neo import NeoSignalInterface block = create_mock_block_with_segments(n_segments=1, n_analog=1, n_spiketrains=0) signal = block.segments[0].analogsignals[0] @@ -371,7 +371,7 @@ def test_interface_get_time_range(self): def test_interface_restrict(self): """Test restricting data to epochs.""" - from pynapple.io.neo import NeoSignalInterface + from pynapple.io.interface_neo import NeoSignalInterface block = create_mock_block_with_segments(n_segments=1, n_analog=1, n_spiketrains=0) signal = block.segments[0].analogsignals[0] @@ -394,7 +394,7 @@ class TestLegacyInterface: def test_legacy_deprecation_warning(self): """Test that legacy interface raises deprecation warning.""" - from pynapple.io.neo import NEOExperimentInterface + from pynapple.io.interface_neo import NEOExperimentInterface # Create a simple mock reader class MockReader: @@ -415,7 +415,7 @@ class TestHelperFunctions: def test_rescale_to_seconds(self): """Test rescaling quantities to seconds.""" - from pynapple.io.neo import _rescale_to_seconds + from pynapple.io.interface_neo import _rescale_to_seconds # Test milliseconds value_ms = 1000 * pq.ms @@ -427,7 +427,7 @@ def test_rescale_to_seconds(self): def test_get_signal_type(self): """Test signal type detection.""" - from pynapple.io.neo import _get_signal_type + from pynapple.io.interface_neo import _get_signal_type # 1D signal -> Tsd signal_1d = neo.AnalogSignal( @@ -443,7 +443,7 @@ def test_get_signal_type(self): def test_extract_annotations(self): """Test annotation extraction.""" - from pynapple.io.neo import _extract_annotations + from pynapple.io.interface_neo import _extract_annotations signal = neo.AnalogSignal( np.random.randn(100, 1), @@ -474,7 +474,7 @@ class TestRoundTrip: def test_tsd_roundtrip(self): """Test Tsd round-trip conversion.""" - from pynapple.io.neo import to_neo_analogsignal, _make_tsd_from_analog + from pynapple.io.interface_neo import to_neo_analogsignal, 
_make_tsd_from_analog original = nap.Tsd(t=np.arange(100) / 1000.0, d=np.random.randn(100)) @@ -489,7 +489,7 @@ def test_tsd_roundtrip(self): def test_intervalset_roundtrip(self): """Test IntervalSet round-trip conversion.""" - from pynapple.io.neo import to_neo_epoch, _make_intervalset_from_epoch + from pynapple.io.interface_neo import to_neo_epoch, _make_intervalset_from_epoch original = nap.IntervalSet(start=[1.0, 5.0, 10.0], end=[2.0, 7.0, 12.0]) @@ -505,7 +505,7 @@ def test_intervalset_roundtrip(self): def test_ts_roundtrip(self): """Test Ts round-trip conversion via SpikeTrain.""" - from pynapple.io.neo import to_neo_spiketrain, _make_ts_from_spiketrain + from pynapple.io.interface_neo import to_neo_spiketrain, _make_ts_from_spiketrain spike_times = np.sort(np.random.uniform(0, 10, 50)) original = nap.Ts(t=spike_times) @@ -543,7 +543,7 @@ def test_block_with_all_data_types(self): def test_multi_segment_time_support(self): """Test that time support correctly spans multiple segments.""" - from pynapple.io.neo import NeoSignalInterface + from pynapple.io.interface_neo import NeoSignalInterface block = create_mock_block_with_segments(n_segments=3, n_analog=1, n_spiketrains=0) From 31ee33aded886301900a3bb17c9b783d881f8eeb Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Thu, 29 Jan 2026 14:23:53 -0500 Subject: [PATCH 5/7] Update neo --- doc/user_guide/02_input_output.md | 13 +++++++++---- pynapple/io/__init__.py | 2 +- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/doc/user_guide/02_input_output.md b/doc/user_guide/02_input_output.md index 00aeabb73..448c56e6f 100644 --- a/doc/user_guide/02_input_output.md +++ b/doc/user_guide/02_input_output.md @@ -42,6 +42,8 @@ import pynapple as nap import os import requests, math import tqdm +import matplotlib.pyplot as plt +import numpy as np nwb_path = 'A2929-200711.nwb' @@ -139,7 +141,7 @@ Raw LFP data can be loaded with pynapple through the NEO library. Internally, pynapple uses the NEO raw IO classes to read the data and convert them to one of the pynapple time series object. This is done through the class [`nap.NeoReader`](pynapple.io.neo.NeoReader). -See here the [list of supported formats](https://neo.readthedocs.io/en/stable/rawiolist.html)). +See here the [list of supported formats](https://neo.readthedocs.io/en/stable/rawiolist.html). 
Here is a minimal working example : @@ -152,7 +154,10 @@ urllib.request.urlretrieve(distantfile, "File_plexon_3.plx") ``` ```{code-cell} ipython3 -data = nap.NeoReader("File_plexon_3.plx") +:tags: [hide-output] +data = nap.NeoReader("File_plexon_3.plx"); +``` +```{code-cell} ipython3 print(data) ``` @@ -162,11 +167,11 @@ spikes = data['TsGroup'] ``` ```{code-cell} ipython3 -fig = pyplot.figure() +fig = plt.figure() ax1 = fig.add_subplot(2, 1, 1) ax2 = fig.add_subplot(2, 1, 2) ax1.plot(lfp.t, lfp.d[:]) -colors = pyplot.cm.jet(np.linspace(0, 1, len(spikes))) +colors = plt.cm.jet(np.linspace(0, 1, len(spikes))) ax2.eventplot([spikes[i].t for i in spikes.keys()], colors=colors) plt.show() ``` diff --git a/pynapple/io/__init__.py b/pynapple/io/__init__.py index dce6e46ca..3f884bd9b 100644 --- a/pynapple/io/__init__.py +++ b/pynapple/io/__init__.py @@ -1,6 +1,6 @@ from .folder import Folder from .interface_npz import NPZFile from .interface_nwb import NWBFile -from .interface_neo import NeoReader +from .interface_neo import NeoReader, NeoSignalInterface from .misc import append_NWB_LFP, load_eeg, load_file, load_folder, load_session From 9c730ba36a094e6c192c0ed9ce98ea8551d7a984 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Thu, 29 Jan 2026 16:07:20 -0500 Subject: [PATCH 6/7] adding better docstrings --- doc/api.rst | 4 +- doc/user_guide/02_input_output.md | 7 +- pynapple/io/interface_neo.py | 142 +++++++++++++++++++++--------- 3 files changed, 109 insertions(+), 44 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index 47e98db4a..6ee2deed3 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -79,8 +79,8 @@ Input-Ouput :nosignatures: :recursive: - NeoReader - NeoSignalInterface + NeoReader + NeoSignalInterface .. currentmodule:: pynapple.io diff --git a/doc/user_guide/02_input_output.md b/doc/user_guide/02_input_output.md index 448c56e6f..d151bcc58 100644 --- a/doc/user_guide/02_input_output.md +++ b/doc/user_guide/02_input_output.md @@ -13,7 +13,7 @@ kernelspec: # Input-output & lazy-loading -Pynapple provides loaders for [NWB format](https://pynwb.readthedocs.io/en/stable/index.html#). +Pynapple provides multiple ways to load data. The two main formats are NWB and NPZ. In addition, raw LFP data can be loaded through the NEO library. Each pynapple objects can be saved as a [`npz`](https://numpy.org/devdocs/reference/generated/numpy.savez.html) with a special structure and loaded as a `npz`. @@ -155,7 +155,7 @@ urllib.request.urlretrieve(distantfile, "File_plexon_3.plx") ```{code-cell} ipython3 :tags: [hide-output] -data = nap.NeoReader("File_plexon_3.plx"); +data = nap.NeoReader("File_plexon_3.plx", format="PlexonIO"); ``` ```{code-cell} ipython3 print(data) @@ -171,8 +171,11 @@ fig = plt.figure() ax1 = fig.add_subplot(2, 1, 1) ax2 = fig.add_subplot(2, 1, 2) ax1.plot(lfp.t, lfp.d[:]) +ax1.set_xlabel("Time (s)") colors = plt.cm.jet(np.linspace(0, 1, len(spikes))) ax2.eventplot([spikes[i].t for i in spikes.keys()], colors=colors) +ax2.set_xlabel("Time (s)") +plt.tight_layout() plt.show() ``` diff --git a/pynapple/io/interface_neo.py b/pynapple/io/interface_neo.py index 0f04de260..000541f03 100644 --- a/pynapple/io/interface_neo.py +++ b/pynapple/io/interface_neo.py @@ -1,33 +1,8 @@ -""" -Pynapple interface for Neo (neural electrophysiology objects). +"""Pynapple interface to Neo for reading electrophysiology files.""" -Neo is a Python package for working with electrophysiology data in Python, -supporting many file formats through a unified API. - -The interface behaves like a dictionary. 
- -For more information on Neo, see: https://neo.readthedocs.io/ - -Neo to Pynapple Object Conversion ---------------------------------- -The following Neo objects are converted to their pynapple equivalents: - -- 'neo.AnalogSignal' -> 'Tsd', `TsdFrame`, or `TsdTensor` (depending on shape) [lazy-loaded] -- neo.IrregularlySampledSignal -> Tsd, TsdFrame, or TsdTensor (depending on shape) [lazy-loaded] -- neo.SpikeTrain -> Ts -- neo.SpikeTrain (list) -> TsGroup -- neo.SpikeTrainList -> TsGroup -- neo.Epoch -> IntervalSet -- neo.Event -> Ts - -Note: All data types support lazy loading. Data is only loaded when accessed -via __getitem__ (e.g., data["TsGroup"]). -""" - -import warnings from collections import UserDict from pathlib import Path -from typing import Union, Optional, Dict, Any, List +from typing import Union, Optional, Dict, Any, List, Type import numpy as np @@ -396,7 +371,7 @@ def _make_tsgroup_from_spiketrains( # Skip metadata that can't be converted to array pass - return nap.TsGroup(ts_dict, time_support=time_support, **meta_arrays) + return nap.TsGroup(ts_dict, time_support=time_support, metadata=meta_arrays) def _make_tsgroup_from_spiketrains_multiseg( @@ -457,7 +432,7 @@ def _make_tsgroup_from_spiketrains_multiseg( # Skip metadata that can't be converted to array pass - return nap.TsGroup(ts_dict, time_support=time_support, **meta_arrays) + return nap.TsGroup(ts_dict, time_support=time_support, metadata=meta_arrays) def _make_tsd_from_interface(interface) -> Union[nap.Tsd, nap.TsdFrame, nap.TsdTensor]: """Convert a NeoSignalInterface to a pynapple Tsd/TsdFrame/TsdTensor. @@ -795,37 +770,91 @@ def __getitem__(self, item): class NeoReader(UserDict): - """Class for reading Neo-compatible files. + """Read Neo-compatible electrophysiology files into pynapple objects. + + `Neo `_ is a Python package for working with + electrophysiology data, supporting many file formats through a unified API. This class provides a dictionary-like interface to Neo files, with lazy-loading support. It automatically detects the appropriate IO based on the file extension. + Neo to Pynapple Object Conversion + --------------------------------- + The following Neo objects are converted to their pynapple equivalents: + + .. list-table:: + :header-rows: 1 + :widths: 40 40 20 + + * - Neo Object + - Pynapple Object + - Notes + * - `AnalogSignal `_ + - :py:class:`~pynapple.Tsd`, :py:class:`~pynapple.TsdFrame`, or :py:class:`~pynapple.TsdTensor` + - Depends on shape; lazy-loaded + * - `IrregularlySampledSignal `_ + - :py:class:`~pynapple.Tsd`, :py:class:`~pynapple.TsdFrame`, or :py:class:`~pynapple.TsdTensor` + - Depends on shape; lazy-loaded + * - `SpikeTrain `_ + - :py:class:`~pynapple.Ts` + - Single unit + * - `SpikeTrain `_ (list) + - :py:class:`~pynapple.TsGroup` + - Multiple units + * - `Epoch `_ + - :py:class:`~pynapple.IntervalSet` + - + * - `Event `_ + - :py:class:`~pynapple.Ts` + - + + Note: All data types support lazy loading. Data is only loaded when accessed + via ``__getitem__`` (e.g., ``data["TsGroup"]``). + Parameters ---------- file : str or Path Path to the file to load lazy : bool, default True Whether to use lazy loading + format : str, type, or None, default None + Specify the Neo IO format to use. 
Can be: + + - ``None``: Automatically detect the format using ``neo.io.get_io`` + - ``str``: Name of the IO class (e.g., ``"PlexonIO"``, ``"Plexon"``, ``"plexon"``) + - ``type``: A class from ``neo.io.iolist`` (e.g., ``neo.io.PlexonIO``) + + When a string is provided, it is matched case-insensitively against IO class names. + The "IO" suffix is optional. Examples -------- >>> import pynapple as nap - >>> data = nap.io.NeoReader("my_file.plx") + >>> data = nap.NeoReader("my_file.plx") >>> print(data) my_file - +---------------------+----------+ - | Key | Type | - +=====================+==========+ - | TsGroup | TsGroup | - | Tsd 0: LFP | Tsd | - +---------------------+----------+ + ┍━━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━┑ + │ Key │ Type │ + ┝━━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━┥ + │ TsGroup │ TsGroup │ + │ Tsd 0: LFP │ Tsd │ + ┕━━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━┙ >>> spikes = data["TsGroup"] >>> lfp = data["Tsd 0: LFP"] + + To explicitly specify the file format: + + >>> data = nap.NeoReader("my_file.plx", format="PlexonIO") """ - def __init__(self, file: Union[str, Path], lazy: bool = True): + def __init__( + self, + file: Union[str, Path], + lazy: bool = True, + format: Union[str, Type, None] = None, + ): _check_neo_installed() self.path = Path(file) @@ -835,8 +864,41 @@ def __init__(self, file: Union[str, Path], lazy: bool = True): self.name = self.path.stem self._lazy = lazy - # Get appropriate IO - self._reader = neo.io.get_io(str(self.path)) + # Get appropriate IO based on format argument + if format is None: + # Auto-detect format + self._reader = neo.io.get_io(str(self.path)) + elif isinstance(format, str): + # Find the IO class by name (case-insensitive, with or without "IO" suffix) + io_class = None + format_lower = format.lower() + for io in neo.io.iolist: + io_name = io.__name__.lower() + io_name_no_suffix = io_name.replace("io", "") + if io_name == format_lower or io_name_no_suffix == format_lower: + io_class = io + break + if io_class is None: + available = [io.__name__ for io in neo.io.iolist] + raise ValueError( + f"Format '{format}' not found in neo.io.iolist. " + f"Available formats: {available}" + ) + self._reader = io_class(str(self.path)) + elif isinstance(format, type): + # Verify the class is in neo.io.iolist + if format not in neo.io.iolist: + available = [io.__name__ for io in neo.io.iolist] + raise ValueError( + f"Class {format.__name__} is not in neo.io.iolist. 
" + f"Available formats: {available}" + ) + self._reader = format(str(self.path)) + else: + raise TypeError( + f"format must be None, a string, or a class from neo.io.iolist, " + f"not {type(format).__name__}" + ) # Read blocks self._blocks = self._reader.read(lazy=lazy) From 88b3327a594be29a38954365f39c1823fd12e18c Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Mon, 2 Feb 2026 15:45:57 -0500 Subject: [PATCH 7/7] update --- pynapple/io/interface_neo.py | 355 ----------------------------------- 1 file changed, 355 deletions(-) diff --git a/pynapple/io/interface_neo.py b/pynapple/io/interface_neo.py index 000541f03..a1df15f68 100644 --- a/pynapple/io/interface_neo.py +++ b/pynapple/io/interface_neo.py @@ -1137,358 +1137,3 @@ def close(self): """Close the underlying Neo reader if it supports closing.""" if hasattr(self._reader, "close"): self._reader.close() - - - - - - - - - -# -# -# # ============================================================================= -# # Legacy Interface (for backward compatibility) -# # ============================================================================= -# -# -# class NEOSignalInterface(NeoSignalInterface): -# """Legacy alias for NeoSignalInterface.""" -# pass -# -# -# class NEOExperimentInterface: -# """Legacy interface for Neo experiments. -# -# .. deprecated:: -# Use :class:`NeoReader` instead. -# """ -# -# def __init__(self, reader, lazy=False): -# warnings.warn( -# "NEOExperimentInterface is deprecated. Use NeoReader instead.", -# DeprecationWarning, -# stacklevel=2, -# ) -# self._reader = reader -# self._lazy = lazy -# self.experiment = self._collect_time_series_info() -# -# def _collect_time_series_info(self): -# blocks = self._reader.read(lazy=self._lazy) -# -# experiments = {} -# for i, block in enumerate(blocks): -# name = f"block {i}" -# if block.name: -# name += ": " + block.name -# experiments[name] = {} -# -# starts, ends = np.empty(len(block.segments)), np.empty(len(block.segments)) -# for trial_num, segment in enumerate(block.segments): -# starts[trial_num] = segment.t_start.rescale("s").magnitude -# ends[trial_num] = segment.t_stop.rescale("s").magnitude -# -# iset = nap.IntervalSet(starts, ends) -# -# for trial_num, segment in enumerate(block.segments): -# # Analog signals -# for signal_num, signal in enumerate(segment.analogsignals): -# if signal.name: -# signame = f" {signal_num}: " + signal.name -# else: -# signame = f" {signal_num}" -# signal_interface = NeoSignalInterface( -# signal, block, iset, sig_num=signal_num -# ) -# signame = signal_interface.nap_type.__name__ + signame -# experiments[name][signame] = signal_interface -# -# # Spike trains -# if len(segment.spiketrains) == 1: -# signal = segment.spiketrains[0] -# signal_interface = NeoSignalInterface( -# signal, block, iset, sig_num=0 -# ) -# signame = f"Ts" + ": " + signal.name if signal.name else "Ts" -# experiments[name][signame] = signal_interface -# else: -# signame = f"TsGroup" -# experiments[name][signame] = NeoSignalInterface( -# segment.spiketrains, block, iset -# ) -# -# return experiments -# -# def __getitem__(self, item): -# if isinstance(item, str): -# return self.experiment[item] -# else: -# res = self.experiment -# for it in item: -# res = res[it] -# return res -# -# def keys(self): -# return [(k, k2) for k in self.experiment.keys() for k2 in self.experiment[k]] -# -# -# def load_file(path: Union[str, Path], lazy: bool = True) -> NeoReader: -# """Load a neural recording file using Neo. 
-# -# This function automatically detects the file format and uses the -# appropriate Neo IO to load the data. -# -# Parameters -# ---------- -# path : str or Path -# Path to the recording file -# lazy : bool, default True -# Whether to use lazy loading (recommended for large files) -# -# Returns -# ------- -# NeoReader -# Interface to the loaded data -# -# Examples -# -------- -# >>> import pynapple as nap -# >>> data = nap.io.neo.load_file("recording.plx") -# >>> print(data) -# recording -# +---------------------+----------+ -# | Key | Type | -# +=====================+==========+ -# | TsGroup | TsGroup | -# | Tsd 0: LFP | Tsd | -# +---------------------+----------+ -# -# >>> spikes = data["TsGroup"] -# -# See Also -# -------- -# NeoReader : Class for Neo file interface -# -# Notes -# ----- -# Supported formats depend on your Neo installation. Common formats include: -# - Plexon (.plx, .pl2) -# - Blackrock (.nev, .ns*) -# - Spike2 (.smr) -# - Neuralynx (.ncs, .nse, .ntt) -# - OpenEphys -# - Intan (.rhd, .rhs) -# - And many more (see Neo documentation) -# """ -# return NeoReader(path, lazy=lazy) -# -# -# # Legacy alias -# def load_experiment(path: Union[str, Path], lazy: bool = True) -> NEOExperimentInterface: -# """Load a neural recording experiment. -# -# .. deprecated:: -# Use :func:`load_file` instead. -# -# Parameters -# ---------- -# path : str or Path -# Path to the recording file -# lazy : bool, default True -# Whether to lazy load the data -# -# Returns -# ------- -# NEOExperimentInterface -# """ -# import pathlib -# -# path = pathlib.Path(path) -# reader = neo.io.get_io(path) -# -# return NEOExperimentInterface(reader, lazy=lazy) -# -# -# -# # ============================================================================= -# # Conversion functions: Pynapple -> Neo -# # ============================================================================= -# -# -# def to_neo_analogsignal( -# tsd: Union[nap.Tsd, nap.TsdFrame, nap.TsdTensor], -# units: str = "dimensionless", -# **kwargs, -# ) -> "neo.AnalogSignal": -# """Convert a pynapple Tsd/TsdFrame/TsdTensor to a Neo AnalogSignal. -# -# Parameters -# ---------- -# tsd : Tsd, TsdFrame, or TsdTensor -# Pynapple time series object -# units : str, default "dimensionless" -# Units for the signal (e.g., "mV", "uV") -# **kwargs -# Additional arguments passed to neo.AnalogSignal -# -# Returns -# ------- -# neo.AnalogSignal -# Neo analog signal object -# """ -# _check_neo_installed() -# import quantities as pq -# -# times = tsd.times() -# data = tsd.values -# -# # Ensure 2D for AnalogSignal -# if data.ndim == 1: -# data = data.reshape(-1, 1) -# -# # Calculate sampling rate from timestamps -# if len(times) > 1: -# dt = np.median(np.diff(times)) -# sampling_rate = 1.0 / dt -# else: -# sampling_rate = 1.0 # Default if only one sample -# -# signal = neo.AnalogSignal( -# data, -# units=units, -# sampling_rate=sampling_rate * pq.Hz, -# t_start=times[0] * pq.s, -# **kwargs, -# ) -# -# return signal -# -# -# def to_neo_spiketrain( -# ts: nap.Ts, -# t_stop: Optional[float] = None, -# units: str = "s", -# **kwargs, -# ) -> "neo.SpikeTrain": -# """Convert a pynapple Ts to a Neo SpikeTrain. -# -# Parameters -# ---------- -# ts : Ts -# Pynapple Ts object -# t_stop : float, optional -# Stop time for the spike train. 
If None, uses the end of time_support -# units : str, default "s" -# Time units -# **kwargs -# Additional arguments passed to neo.SpikeTrain -# -# Returns -# ------- -# neo.SpikeTrain -# Neo spike train object -# """ -# _check_neo_installed() -# import quantities as pq -# -# times = ts.times() -# -# if t_stop is None: -# t_stop = float(ts.time_support.end[-1]) -# -# t_start = float(ts.time_support.start[0]) if len(times) == 0 else min(times[0], float(ts.time_support.start[0])) -# -# spiketrain = neo.SpikeTrain( -# times, -# units=units, -# t_start=t_start * pq.s, -# t_stop=t_stop * pq.s, -# **kwargs, -# ) -# -# return spiketrain -# -# -# def to_neo_epoch( -# iset: nap.IntervalSet, -# labels: Optional[np.ndarray] = None, -# **kwargs, -# ) -> "neo.Epoch": -# """Convert a pynapple IntervalSet to a Neo Epoch. -# -# Parameters -# ---------- -# iset : IntervalSet -# Pynapple IntervalSet -# labels : array-like, optional -# Labels for each epoch. If None, uses integers. -# **kwargs -# Additional arguments passed to neo.Epoch -# -# Returns -# ------- -# neo.Epoch -# Neo epoch object -# """ -# _check_neo_installed() -# import quantities as pq -# -# starts = iset.start -# ends = iset.end -# durations = ends - starts -# -# if labels is None: -# # Check if there's a 'label' column in metadata -# if hasattr(iset, "label"): -# labels = iset.label -# else: -# labels = np.arange(len(starts)).astype(str) -# -# epoch = neo.Epoch( -# times=starts * pq.s, -# durations=durations * pq.s, -# labels=labels, -# **kwargs, -# ) -# -# return epoch -# -# -# def to_neo_event( -# ts: nap.Ts, -# labels: Optional[np.ndarray] = None, -# **kwargs, -# ) -> "neo.Event": -# """Convert a pynapple Ts to a Neo Event. -# -# Parameters -# ---------- -# ts : Ts -# Pynapple Ts object -# labels : array-like, optional -# Labels for each event. If None, uses integers. -# **kwargs -# Additional arguments passed to neo.Event -# -# Returns -# ------- -# neo.Event -# Neo event object -# """ -# _check_neo_installed() -# import quantities as pq -# -# times = ts.times() -# -# if labels is None: -# labels = np.arange(len(times)).astype(str) -# -# event = neo.Event( -# times=times * pq.s, -# labels=labels, -# **kwargs, -# ) -# -# return event
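For reference, a minimal usage sketch of the `NeoReader` interface as it stands at the end of this series. It assumes `neo` and the patched `pynapple` are installed and that the Plexon test file `File_plexon_3.plx` from the user-guide example is in the working directory; the keys `"TsGroup"` and `"TsdFrame 0: V"` are taken from that example and will differ for other recordings.

```python
import pynapple as nap

# The IO can be auto-detected (format=None, via neo.io.get_io) or named
# explicitly; string matching against neo.io.iolist is case-insensitive and
# the trailing "IO" is optional, so "PlexonIO", "Plexon" and "plexon" all work.
data = nap.NeoReader("File_plexon_3.plx", format="PlexonIO", lazy=True)

# NeoReader is a UserDict, so the usual dict methods are available.
print(list(data.keys()))

# Data is only read from disk when accessed through __getitem__.
spikes = data["TsGroup"]        # SpikeTrains  -> TsGroup
lfp = data["TsdFrame 0: V"]     # AnalogSignal -> TsdFrame (lazy-loaded values)

print(spikes)
print(lfp)

# Close the underlying Neo reader when done (no-op if the IO has no close()).
data.close()
```

With `lazy=True` (the default), large multi-channel recordings stay cheap to open: the analog values behind `lfp` are pulled from the Neo segments only when they are sliced or restricted.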