@@ -17,10 +17,10 @@ class BaseSorting(BaseExtractor):
     Abstract class representing several segment several units and relative spiketrains.
     """

-    def __init__(self, sampling_frequency: float, unit_ids: List):
+    def __init__(self, sampling_frequency: float, unit_ids: list):
         BaseExtractor.__init__(self, unit_ids)
         self._sampling_frequency = float(sampling_frequency)
-        self._sorting_segments: List[BaseSortingSegment] = []
+        self._sorting_segments: list[BaseSortingSegment] = []
         # this weak link is to handle times from a recording object
         self._recording = None
         self._sorting_info = None
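The hunk above only shows the constructor in diff context, so here is a minimal, hedged sketch of how a concrete subclass would exercise it. The `DictSorting`/`DictSortingSegment` names are invented for illustration, and the sketch assumes `BaseSortingSegment`, `add_sorting_segment`, and the `get_unit_spike_train(unit_id, start_frame, end_frame)` segment contract behave as they do in `spikeinterface.core`:

```python
import numpy as np
from spikeinterface.core import BaseSorting, BaseSortingSegment  # assumed import path


class DictSortingSegment(BaseSortingSegment):
    """Illustrative segment backed by a plain dict of unit_id -> spike frame array."""

    def __init__(self, spikes_by_unit):
        BaseSortingSegment.__init__(self)
        self._spikes_by_unit = spikes_by_unit

    def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None):
        # return spike frames for one unit, optionally clipped to [start_frame, end_frame)
        spike_frames = self._spikes_by_unit[unit_id]
        if start_frame is not None:
            spike_frames = spike_frames[spike_frames >= start_frame]
        if end_frame is not None:
            spike_frames = spike_frames[spike_frames < end_frame]
        return spike_frames


class DictSorting(BaseSorting):
    """Illustrative concrete sorting: one segment, unit ids passed as a built-in list."""

    def __init__(self, spikes_by_unit, sampling_frequency):
        # unit_ids is now annotated as the built-in `list`, so a plain list works directly
        BaseSorting.__init__(self, sampling_frequency, unit_ids=list(spikes_by_unit.keys()))
        self.add_sorting_segment(DictSortingSegment(spikes_by_unit))


sorting = DictSorting({0: np.array([10, 55, 300]), 1: np.array([], dtype="int64")}, 30000.0)
print(sorting.unit_ids, sorting.get_unit_spike_train(unit_id=0, segment_index=0))
```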
@@ -212,7 +212,7 @@ def set_sorting_info(self, recording_dict, params_dict, log_dict):
         sorting_info = dict(recording=recording_dict, params=params_dict, log=log_dict)
         self.annotate(__sorting_info__=sorting_info)

-    def has_recording(self):
+    def has_recording(self) -> bool:
         return self._recording is not None

     def has_time_vector(self, segment_index=None) -> bool:
@@ -302,14 +302,6 @@ def get_unit_property(self, unit_id, key):
         v = values[self.id_to_index(unit_id)]
         return v

-    def get_total_num_spikes(self):
-        warnings.warn(
-            "Sorting.get_total_num_spikes() is deprecated and will be removed in spikeinterface 0.102, use sorting.count_num_spikes_per_unit()",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-        return self.count_num_spikes_per_unit(outputs="dict")
-
     def count_num_spikes_per_unit(self, outputs="dict"):
         """
         For each unit : get number of spikes across segments.
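For context on the removal above: the deprecation message already points at `count_num_spikes_per_unit()`, and the removed body simply forwarded to it, so a rough migration sketch looks like this. The `sorting` object can be any `BaseSorting` instance; `generate_sorting` and its keyword arguments are assumed here only to make the snippet runnable:

```python
import spikeinterface.core as si

# any BaseSorting works; generate_sorting is assumed as a convenient toy constructor
sorting = si.generate_sorting(num_units=3, sampling_frequency=30000.0, durations=[5.0])

# old (removed): per_unit = sorting.get_total_num_spikes()
# despite its name, that method returned the same per-unit dict as below
per_unit = sorting.count_num_spikes_per_unit(outputs="dict")
total_spikes = sum(per_unit.values())  # recover a true total if that is what you need
print(per_unit, total_spikes)
```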
@@ -451,12 +443,34 @@ def remove_empty_units(self):
         non_empty_units = self.get_non_empty_unit_ids()
         return self.select_units(non_empty_units)

-    def get_non_empty_unit_ids(self):
+    def get_non_empty_unit_ids(self) -> np.ndarray:
+        """
+        Return the unit IDs that have at least one spike across all segments.
+
+        This method computes the number of spikes for each unit using
+        `count_num_spikes_per_unit` and filters out units with zero spikes.
+
+        Returns
+        -------
+        np.ndarray
+            Array of unit IDs (same dtype as self.unit_ids) for which at least one spike exists.
+        """
         num_spikes_per_unit = self.count_num_spikes_per_unit()

         return np.array([unit_id for unit_id in self.unit_ids if num_spikes_per_unit[unit_id] != 0])

-    def get_empty_unit_ids(self):
+    def get_empty_unit_ids(self) -> np.ndarray:
+        """
+        Return the unit IDs that have zero spikes across all segments.
+
+        This method returns the complement of `get_non_empty_unit_ids` with respect
+        to all unit IDs in the sorting.
+
+        Returns
+        -------
+        np.ndarray
+            Array of unit IDs (same dtype as self.unit_ids) for which no spikes exist.
+        """
         unit_ids = self.unit_ids
         empty_units = unit_ids[~np.isin(unit_ids, self.get_non_empty_unit_ids())]
         return empty_units
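To make the newly documented behaviour concrete, here is a hedged usage sketch. It assumes `NumpySorting.from_unit_dict` (one dict of unit_id to spike frames per segment) is available as a convenience constructor; only the methods shown in the hunk are otherwise relied on, and the "expected" outputs are illustrative:

```python
import numpy as np
from spikeinterface.core import NumpySorting  # from_unit_dict is assumed here

# toy sorting with one segment where unit "b" never fires
spikes = {"a": np.array([100, 200, 900]), "b": np.array([], dtype="int64")}
sorting = NumpySorting.from_unit_dict([spikes], sampling_frequency=30000.0)

print(sorting.get_non_empty_unit_ids())  # expected: array containing "a"
print(sorting.get_empty_unit_ids())      # expected: array containing "b"

# remove_empty_units() (shown just above the hunk) builds a new sorting
# restricted to the non-empty ids via select_units()
clean_sorting = sorting.remove_empty_units()
print(clean_sorting.unit_ids)
```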
@@ -506,44 +520,6 @@ def time_to_sample_index(self, time, segment_index=0):

         return sample_index

-    def get_all_spike_trains(self, outputs="unit_id"):
-        """
-        Return all spike trains concatenated.
-        This is deprecated and will be removed in spikeinterface 0.102 use sorting.to_spike_vector() instead
-        """
-
-        warnings.warn(
-            "Sorting.get_all_spike_trains() will be deprecated. Sorting.to_spike_vector() instead",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-
-        assert outputs in ("unit_id", "unit_index")
-        spikes = []
-        for segment_index in range(self.get_num_segments()):
-            spike_times = []
-            spike_labels = []
-            for i, unit_id in enumerate(self.unit_ids):
-                st = self.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index)
-                spike_times.append(st)
-                if outputs == "unit_id":
-                    spike_labels.append(np.array([unit_id] * st.size))
-                elif outputs == "unit_index":
-                    spike_labels.append(np.zeros(st.size, dtype="int64") + i)
-
-            if len(spike_times) > 0:
-                spike_times = np.concatenate(spike_times)
-                spike_labels = np.concatenate(spike_labels)
-                order = np.argsort(spike_times)
-                spike_times = spike_times[order]
-                spike_labels = spike_labels[order]
-            else:
-                spike_times = np.array([], dtype=np.int64)
-                spike_labels = np.array([], dtype=np.int64)
-
-            spikes.append((spike_times, spike_labels))
-        return spikes
-
     def precompute_spike_trains(self, from_spike_vector=None):
         """
         Pre-computes and caches all spike trains for this sorting
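Since both the removed docstring and the removed warning point at `to_spike_vector()`, a rough migration sketch follows. It assumes the spike vector is a structured NumPy array with `sample_index`, `unit_index`, and `segment_index` fields, as documented for recent spikeinterface versions, and uses `generate_sorting` only to have something runnable:

```python
import spikeinterface.core as si

sorting = si.generate_sorting(num_units=4, sampling_frequency=30000.0, durations=[5.0, 5.0])

# old (removed): spikes = sorting.get_all_spike_trains(outputs="unit_index")
# new: one structured array with all spikes, already sorted per segment
spike_vector = sorting.to_spike_vector()

# rebuild per-segment (sample_index, unit_index) pairs, mimicking the old return shape;
# the field names below are assumptions about the spike-vector layout
for segment_index in range(sorting.get_num_segments()):
    mask = spike_vector["segment_index"] == segment_index
    spike_samples = spike_vector["sample_index"][mask]
    unit_indices = spike_vector["unit_index"][mask]
    print(segment_index, spike_samples.size, unit_indices.size)
```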