diff --git a/doc/changes/devel/13228.newfeature.rst b/doc/changes/devel/13228.newfeature.rst new file mode 100644 index 00000000000..a242762d2f6 --- /dev/null +++ b/doc/changes/devel/13228.newfeature.rst @@ -0,0 +1 @@ +Add an ``extras`` attribute to :class:`mne.Annotations` for storing arbitrary metadata, by `Pierre Guetschel`_. \ No newline at end of file diff --git a/doc/changes/names.inc b/doc/changes/names.inc index d20931b7a51..9aee59d9a3f 100644 --- a/doc/changes/names.inc +++ b/doc/changes/names.inc @@ -234,6 +234,7 @@ .. _Peter Molfese: https://github.com/pmolfese .. _Phillip Alday: https://palday.bitbucket.io .. _Pierre Ablin: https://pierreablin.com +.. _Pierre Guetschel: https://github.com/PierreGtch .. _Pierre-Antoine Bannier: https://github.com/PABannier .. _Ping-Keng Jao: https://github.com/nafraw .. _Proloy Das: https://github.com/proloyd diff --git a/mne/annotations.py b/mne/annotations.py index 629ee7b20cb..171b9a510fb 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -5,7 +5,7 @@ import json import re import warnings -from collections import Counter, OrderedDict +from collections import Counter, OrderedDict, UserDict, UserList from collections.abc import Iterable from copy import deepcopy from datetime import datetime, timedelta, timezone @@ -58,7 +58,105 @@ _datetime = datetime -def _check_o_d_s_c(onset, duration, description, ch_names): +class _AnnotationsExtrasDict(UserDict): + """A dictionary for storing extra fields of annotations. + + The keys of the dictionary are strings, and the values can be + strings, integers, floats, or None. + """ + + def __setitem__(self, key: str, value: str | int | float | None) -> None: + _validate_type(key, str, "key") + if key in ("onset", "duration", "description", "ch_names"): + raise ValueError(f"Key '{key}' is reserved and cannot be used in extras.") + _validate_type( + value, + (str, int, float, None), + "value", + ) + super().__setitem__(key, value) + + +class _AnnotationsExtrasList(UserList): + """A list of dictionaries for storing extra fields of annotations. + + Each dictionary in the list corresponds to an annotation and contains + extra fields. + The keys of the dictionaries are strings, and the values can be + strings, integers, floats, or None. + """ + + def __repr__(self): + return repr(self.data) + + @staticmethod + def _validate_value( + value: dict | _AnnotationsExtrasDict | None, + ) -> _AnnotationsExtrasDict: + _validate_type( + value, + (dict, _AnnotationsExtrasDict, None), + "extras dict value", + "dict or None", + ) + return ( + value + if isinstance(value, _AnnotationsExtrasDict) + else _AnnotationsExtrasDict(value or {}) + ) + + def __init__(self, initlist=None): + if not (isinstance(initlist, _AnnotationsExtrasList) or initlist is None): + initlist = [self._validate_value(v) for v in initlist] + super().__init__(initlist) + + def __setitem__( # type: ignore[override] + self, + key: int | slice, + value, + ) -> None: + _validate_type(key, (int, slice), "key", "int or slice") + if isinstance(key, int): + iterable = False + value = [value] + else: + _validate_type(value, Iterable, "value", "Iterable when key is a slice") + iterable = True + + new_values = [self._validate_value(v) for v in value] + if not iterable: + new_values = new_values[0] + super().__setitem__(key, new_values) + + def __iadd__(self, other): + if not isinstance(other, _AnnotationsExtrasList): + other = _AnnotationsExtrasList(other) + super().__iadd__(other) + + def append(self, item): + super().append(self._validate_value(item)) + + def insert(self, i, item): + super().insert(i, self._validate_value(item)) + + def extend(self, other): + if not isinstance(other, _AnnotationsExtrasList): + other = _AnnotationsExtrasList(other) + super().extend(other) + + +def _validate_extras(extras, length: int): + _validate_type(extras, (None, list, _AnnotationsExtrasList), "extras") + if extras is not None and len(extras) != length: + raise ValueError( + f"extras must be None or a list of length {length}, got {len(extras)}." + ) + if isinstance(extras, _AnnotationsExtrasList): + return extras + return _AnnotationsExtrasList(extras or [None] * length) + + +def _check_o_d_s_c_e(onset, duration, description, ch_names, extras): onset = np.atleast_1d(np.array(onset, dtype=float)) if onset.ndim != 1: raise ValueError( @@ -100,7 +198,9 @@ def _check_o_d_s_c(onset, duration, description, ch_names): f"equal in sizes, got {len(onset)}, {len(duration)}, " f"{len(description)}, and {len(ch_names)}." ) - return onset, duration, description, ch_names + + extras = _validate_extras(extras, len(onset)) + return onset, duration, description, ch_names, extras def _ndarray_ch_names(ch_names): @@ -146,6 +246,11 @@ class Annotations: %(ch_names_annot)s .. versionadded:: 0.23 + extras : list[dict[str, int | float | str | None] | None] | None + Optional list of dicts containing extra fields for each annotation. + The number of items must match the number of annotations. + + .. versionadded:: 1.10 See Also -------- @@ -274,10 +379,19 @@ class Annotations: :meth:`Raw.save() ` notes for details. """ # noqa: E501 - def __init__(self, onset, duration, description, orig_time=None, ch_names=None): + def __init__( + self, + onset, + duration, + description, + orig_time=None, + ch_names=None, + *, + extras=None, + ): self._orig_time = _handle_meas_date(orig_time) - self.onset, self.duration, self.description, self.ch_names = _check_o_d_s_c( - onset, duration, description, ch_names + self.onset, self.duration, self.description, self.ch_names, self._extras = ( + _check_o_d_s_c_e(onset, duration, description, ch_names, extras) ) self._sort() # ensure we're sorted @@ -286,6 +400,25 @@ def orig_time(self): """The time base of the Annotations.""" return self._orig_time + @property + def extras(self): + """The extras of the Annotations. + + The ``extras`` attribute is a list of dictionaries. + It can easily be converted to a pandas DataFrame using: + ``pd.DataFrame(extras)``. + """ + return self._extras + + @extras.setter + def extras(self, extras): + self._extras = _validate_extras(extras, len(self.onset)) + + @property + def _extras_columns(self) -> set[str]: + """The set containing all the keys in all extras dicts.""" + return {k for d in self.extras for k in d} + def __eq__(self, other): """Compare to another Annotations instance.""" if not isinstance(other, Annotations): @@ -339,7 +472,11 @@ def __iadd__(self, other): f"{self.orig_time} != {other.orig_time})" ) return self.append( - other.onset, other.duration, other.description, other.ch_names + other.onset, + other.duration, + other.description, + other.ch_names, + extras=other.extras, ) def __iter__(self): @@ -350,7 +487,7 @@ def __iter__(self): for idx in range(len(self.onset)): yield self.__getitem__(idx, with_ch_names=with_ch_names) - def __getitem__(self, key, *, with_ch_names=None): + def __getitem__(self, key, *, with_ch_names=None, with_extras=True): """Propagate indexing and slicing to the underlying numpy structure.""" if isinstance(key, int_like): out_keys = ("onset", "duration", "description", "orig_time") @@ -363,6 +500,9 @@ def __getitem__(self, key, *, with_ch_names=None): if with_ch_names or (with_ch_names is None and self._any_ch_names()): out_keys += ("ch_names",) out_vals += (self.ch_names[key],) + if with_extras: + out_keys += ("extras",) + out_vals += (self.extras[key],) return OrderedDict(zip(out_keys, out_vals)) else: key = list(key) if isinstance(key, tuple) else key @@ -372,10 +512,11 @@ def __getitem__(self, key, *, with_ch_names=None): description=self.description[key], orig_time=self.orig_time, ch_names=self.ch_names[key], + extras=[self.extras[i] for i in np.arange(len(self.extras))[key]], ) @fill_doc - def append(self, onset, duration, description, ch_names=None): + def append(self, onset, duration, description, ch_names=None, *, extras=None): """Add an annotated segment. Operates inplace. Parameters @@ -391,6 +532,11 @@ def append(self, onset, duration, description, ch_names=None): %(ch_names_annot)s .. versionadded:: 0.23 + extras : list[dict[str, int | float | str | None] | None] | None + Optional list of dicts containing extras fields for each annotation. + The number of items must match the number of annotations. + + .. versionadded:: 1.10 Returns ------- @@ -403,13 +549,14 @@ def append(self, onset, duration, description, ch_names=None): to not only ``list.append``, but also `list.extend `__. """ # noqa: E501 - onset, duration, description, ch_names = _check_o_d_s_c( - onset, duration, description, ch_names + onset, duration, description, ch_names, extras = _check_o_d_s_c_e( + onset, duration, description, ch_names, extras ) self.onset = np.append(self.onset, onset) self.duration = np.append(self.duration, duration) self.description = np.append(self.description, description) self.ch_names = np.append(self.ch_names, ch_names) + self.extras.extend(extras) self._sort() return self @@ -436,6 +583,12 @@ def delete(self, idx): self.duration = np.delete(self.duration, idx) self.description = np.delete(self.description, idx) self.ch_names = np.delete(self.ch_names, idx) + if isinstance(idx, int_like): + del self.extras[idx] + elif len(idx) > 0: + # convert slice-like idx to ints, and delete list items in reverse order + for i in np.sort(np.arange(len(self.extras))[idx])[::-1]: + del self.extras[i] @fill_doc def to_data_frame(self, time_format="datetime"): @@ -466,6 +619,8 @@ def to_data_frame(self, time_format="datetime"): if self._any_ch_names(): df.update(ch_names=self.ch_names) df = pd.DataFrame(df) + extras_df = pd.DataFrame(self.extras) + df = pd.concat([df, extras_df], axis=1) return df def count(self): @@ -567,6 +722,7 @@ def _sort(self): self.duration = self.duration[order] self.description = self.description[order] self.ch_names = self.ch_names[order] + self.extras = [self.extras[i] for i in order] @verbose def crop( @@ -619,10 +775,10 @@ def crop( ) logger.debug(f"Cropping annotations {absolute_tmin} - {absolute_tmax}") - onsets, durations, descriptions, ch_names = [], [], [], [] + onsets, durations, descriptions, ch_names, extras = [], [], [], [], [] out_of_bounds, clip_left_elem, clip_right_elem = [], [], [] - for idx, (onset, duration, description, ch) in enumerate( - zip(self.onset, self.duration, self.description, self.ch_names) + for idx, (onset, duration, description, ch, extra) in enumerate( + zip(self.onset, self.duration, self.description, self.ch_names, self.extras) ): # if duration is NaN behave like a zero if np.isnan(duration): @@ -660,12 +816,14 @@ def crop( ) descriptions.append(description) ch_names.append(ch) + extras.append(extra) logger.debug(f"Cropping complete (kept {len(onsets)})") self.onset = np.array(onsets, float) self.duration = np.array(durations, float) assert (self.duration >= 0).all() self.description = np.array(descriptions, dtype=str) self.ch_names = _ndarray_ch_names(ch_names) + self.extras = extras if emit_warning: omitted = np.array(out_of_bounds).sum() @@ -822,9 +980,17 @@ def set_annotations(self, annotations, on_missing="raise", *, verbose=None): self._annotations = new_annotations return self - def get_annotations_per_epoch(self): + def get_annotations_per_epoch(self, *, with_extras=False): """Get a list of annotations that occur during each epoch. + Parameters + ---------- + with_extras : bool + Whether to include the annotations extra fields in the output, + as an additional last element of the tuple. Default is False. + + .. versionadded:: 1.10 + Returns ------- epoch_annots : list @@ -893,11 +1059,13 @@ def get_annotations_per_epoch(self): this_annot["duration"], this_annot["description"], ) + if with_extras: + annot += (this_annot["extras"],) # ...then add it to the correct sublist of `epoch_annot_list` epoch_annot_list[epo_ix].append(annot) return epoch_annot_list - def add_annotations_to_metadata(self, overwrite=False): + def add_annotations_to_metadata(self, overwrite=False, *, with_extras=True): """Add raw annotations into the Epochs metadata data frame. Adds three columns to the ``metadata`` consisting of a list @@ -914,6 +1082,11 @@ def add_annotations_to_metadata(self, overwrite=False): overwrite : bool Whether to overwrite existing columns in metadata or not. Default is False. + with_extras : bool + Whether to include the annotations extra fields in the output, + as an additional last element of the tuple. Default is True. + + .. versionadded:: 1.10 Returns ------- @@ -955,8 +1128,9 @@ def add_annotations_to_metadata(self, overwrite=False): # get the Epoch annotations, then convert to separate lists for # onsets, durations, and descriptions - epoch_annot_list = self.get_annotations_per_epoch() + epoch_annot_list = self.get_annotations_per_epoch(with_extras=with_extras) onset, duration, description = [], [], [] + extras = {k: [] for k in self.annotations._extras_columns} for epoch_annot in epoch_annot_list: for ix, annot_prop in enumerate((onset, duration, description)): entry = [annot[ix] for annot in epoch_annot] @@ -966,12 +1140,17 @@ def add_annotations_to_metadata(self, overwrite=False): entry = np.round(entry, decimals=12).tolist() annot_prop.append(entry) + for k in extras.keys(): + entry = [annot[3].get(k, None) for annot in epoch_annot] + extras[k].append(entry) # Create a new Annotations column that is instantiated as an empty # list per Epoch. metadata["annot_onset"] = pd.Series(onset) metadata["annot_duration"] = pd.Series(duration) metadata["annot_description"] = pd.Series(description) + for k, v in extras.items(): + metadata[f"annot_{k}"] = pd.Series(v) # reset the metadata self.metadata = metadata @@ -1100,6 +1279,12 @@ def _write_annotations(fid, annotations): write_string( fid, FIFF.FIFF_MNE_EPOCHS_DROP_LOG, json.dumps(tuple(annotations.ch_names)) ) + if any(d is not None for d in annotations.extras): + write_string( + fid, + FIFF.FIFF_FREE_LIST, + json.dumps([extra.data for extra in annotations.extras]), + ) end_block(fid, FIFF.FIFFB_MNE_ANNOTATIONS) @@ -1110,6 +1295,18 @@ def _write_annotations_csv(fname, annot): _safe_name_list(ch, "write", name=f'annot["ch_names"][{ci}') for ci, ch in enumerate(annot["ch_names"]) ] + extras_columns = set(annot.columns) - { + "onset", + "duration", + "description", + "ch_names", + } + for col in extras_columns: + if len(dtypes := annot[col].apply(type).unique()) > 1: + warn( + f"Extra field '{col}' contains heterogeneous dtypes ({dtypes}). " + "Loading these CSV annotations may not return the original dtypes." + ) annot.to_csv(fname, index=False) @@ -1119,8 +1316,10 @@ def _write_annotations_txt(fname, annot): # for backward compat, we do not write tzinfo (assumed UTC) content += f"# orig_time : {annot.orig_time.replace(tzinfo=None)}\n" content += "# onset, duration, description" + n_cols = 3 data = [annot.onset, annot.duration, annot.description] if annot._any_ch_names(): + n_cols += 1 content += ", ch_names" data.append( [ @@ -1128,11 +1327,22 @@ def _write_annotations_txt(fname, annot): for ci, ch in enumerate(annot.ch_names) ] ) + if len(extras_columns := annot._extras_columns) > 0: + n_cols += len(extras_columns) + for column in extras_columns: + content += f", {column}" + values = [extra.get(column, None) for extra in annot.extras] + if len(dtypes := set(type(v) for v in values)) > 1: + warn( + f"Extra field '{column}' contains heterogeneous dtypes ({dtypes}). " + "Loading these TXT annotations may not return the original dtypes." + ) + data.append([val if val is not None else "" for val in values]) content += "\n" data = np.array(data, dtype=str).T assert data.ndim == 2 assert data.shape[0] == len(annot.onset) - assert data.shape[1] in (3, 4) + assert data.shape[1] == n_cols with open(fname, "wb") as fid: fid.write(content.encode()) np.savetxt(fid, data, delimiter=",", fmt="%s") @@ -1242,6 +1452,13 @@ def read_annotations( def _read_annotations_csv(fname): """Read annotations from csv. + The dtypes of the extra fields will automatically be inferred + by pandas. If some fields have heterogeneous types on the + different rows, this automatic inference may return unexpected + types. + If you need to save heterogeneous extra dtypes, we recommend + saving to FIF. + Parameters ---------- fname : path-like @@ -1275,7 +1492,13 @@ def _read_annotations_csv(fname): _safe_name_list(val, "read", "annotation channel name") for val in df["ch_names"].values ] - return Annotations(onset, duration, description, orig_time, ch_names) + extra_columns = list( + df.columns.difference(["onset", "duration", "description", "ch_names"]) + ) + extras = None + if len(extra_columns) > 0: + extras = df[extra_columns].to_dict(orient="records") + return Annotations(onset, duration, description, orig_time, ch_names, extras=extras) def _read_brainstorm_annotations(fname, orig_time=None): @@ -1328,28 +1551,89 @@ def _read_annotations_txt_parse_header(fname): def is_orig_time(x): return x.startswith("# orig_time :") + def is_columns(x): + return x.startswith("# onset, duration, description") + with open(fname) as fid: header = list(takewhile(lambda x: x.startswith("#"), fid)) orig_values = [h[13:].strip() for h in header if is_orig_time(h)] orig_values = [_handle_meas_date(orig) for orig in orig_values if _is_iso8601(orig)] - return None if not orig_values else orig_values[0] + columns = [[c.strip() for c in h[2:].split(",")] for h in header if is_columns(h)] + + return ( + None if not orig_values else orig_values[0], + (None if not columns else columns[0]), + len(header), + ) def _read_annotations_txt(fname): with warnings.catch_warnings(record=True): warnings.simplefilter("ignore") out = np.loadtxt(fname, delimiter=",", dtype=np.bytes_, unpack=True) - ch_names = None + orig_time, columns, n_rows_header = _read_annotations_txt_parse_header(fname) + ch_names = extras = None if len(out) == 0: onset, duration, desc = [], [], [] else: - _check_option("text header", len(out), (3, 4)) - if len(out) == 3: - onset, duration, desc = out - else: - onset, duration, desc, ch_names = out + if columns is None: + # No column names were present in the header + # We assume the first three columns are onset, duration, description + # And eventually a fourth column with ch_names + _check_option("text header", len(out), (3, 4)) + columns = ["onset", "duration", "description"] + ( + ["ch_names"] if len(out) == 4 else [] + ) + col_map = {col: i for i, col in enumerate(columns)} + if len(col_map) != len(columns): + raise ValueError( + "Duplicate column names found in header. Please check the file format." + ) + if missing := {"onset", "duration", "description"} - set(col_map.keys()): + raise ValueError( + f"Column(s) {missing} not found in header. " + "Please check the file format." + ) + _check_option("text header len", len(out), (len(columns),)) + onset = out[col_map["onset"]] + duration = out[col_map["duration"]] + desc = out[col_map["description"]] + if "ch_names" in col_map: + ch_names = out[col_map["ch_names"]] + extra_columns = set(col_map.keys()) - { + "onset", + "duration", + "description", + "ch_names", + } + if extra_columns: + pd = _check_pandas_installed(strict=False) + if pd: + df = pd.read_csv( + fname, + delimiter=",", + names=columns, + usecols=extra_columns, + skiprows=n_rows_header, + header=None, + keep_default_na=False, + ) + extras = df.to_dict(orient="records") + else: + warn( + "Extra fields found in the header but pandas is not installed. " + "Therefore the dtypes of the extra fields can not automatically " + "be inferred so they will be loaded as strings." + ) + extras = [ + { + col_name: out[col_map[col_name]][i].decode("UTF-8") + for col_name in extra_columns + } + for i in range(len(onset)) + ] onset = [float(o.decode()) for o in np.atleast_1d(onset)] duration = [float(d.decode()) for d in np.atleast_1d(duration)] @@ -1360,14 +1644,13 @@ def _read_annotations_txt(fname): for ci, ch in enumerate(ch_names) ] - orig_time = _read_annotations_txt_parse_header(fname) - annotations = Annotations( onset=onset, duration=duration, description=desc, orig_time=orig_time, ch_names=ch_names, + extras=extras, ) return annotations @@ -1380,7 +1663,7 @@ def _read_annotations_fif(fid, tree): annotations = None else: annot_data = annot_data[0] - orig_time = ch_names = None + orig_time = ch_names = extras = None onset, duration, description = list(), list(), list() for ent in annot_data["directory"]: kind = ent.kind @@ -1402,8 +1685,14 @@ def _read_annotations_fif(fid, tree): orig_time = tuple(orig_time) # new way elif kind == FIFF.FIFF_MNE_EPOCHS_DROP_LOG: ch_names = tuple(tuple(x) for x in json.loads(tag.data)) + elif kind == FIFF.FIFF_FREE_LIST: + extras = json.loads(tag.data) assert len(onset) == len(duration) == len(description) - annotations = Annotations(onset, duration, description, orig_time, ch_names) + if extras is not None: + assert len(extras) == len(onset) + annotations = Annotations( + onset, duration, description, orig_time, ch_names, extras=extras + ) return annotations diff --git a/mne/epochs.py b/mne/epochs.py index 96f247875d9..c042715e6ae 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -2465,11 +2465,13 @@ def equalize_event_counts( # 2b. for non-tag ids, just pass them directly # 3. do this for every input event_ids = [ - [ - k for k in ids if all(tag in k.split("/") for tag in id_) - ] # ids matching all tags - if all(id__ not in ids for id__ in id_) - else id_ # straight pass for non-tag inputs + ( + [ + k for k in ids if all(tag in k.split("/") for tag in id_) + ] # ids matching all tags + if all(id__ not in ids for id__ in id_) + else id_ + ) # straight pass for non-tag inputs for id_ in event_ids ] for ii, id_ in enumerate(event_ids): @@ -3575,6 +3577,18 @@ def __init__( raw, events, event_id, annotations, on_missing ) + # add the annotations.extras to the metadata + if not all(len(d) == 0 for d in annotations.extras): + pd = _check_pandas_installed(strict=True) + extras_df = pd.DataFrame(annotations.extras) + if metadata is None: + metadata = extras_df + else: + extras_df.set_index(metadata.index, inplace=True) + metadata = pd.concat( + [metadata, extras_df], axis=1, ignore_index=False + ) + # call BaseEpochs constructor super().__init__( info, diff --git a/mne/io/egi/tests/test_egi.py b/mne/io/egi/tests/test_egi.py index 923c5ce925a..8e3275a733e 100644 --- a/mne/io/egi/tests/test_egi.py +++ b/mne/io/egi/tests/test_egi.py @@ -212,7 +212,7 @@ def test_io_egi_mff(events_as_annotations): if events_as_annotations: # Grab the first annotation. Should be the first "DIN1" event. assert len(raw.annotations) - onset, dur, desc, _ = raw.annotations[0].values() + onset, dur, desc, _, _ = raw.annotations[0].values() assert_allclose(onset, 2.438) assert np.isclose(dur, 0) assert desc == "DIN1" diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py index 4d0db170e2a..7a9a0faea43 100644 --- a/mne/tests/test_annotations.py +++ b/mne/tests/test_annotations.py @@ -29,6 +29,8 @@ read_annotations, ) from mne.annotations import ( + _AnnotationsExtrasDict, + _AnnotationsExtrasList, _handle_meas_date, _read_annotations_txt_parse_header, _sync_onset, @@ -630,10 +632,24 @@ def test_annotation_epoching(): assert_equal([0, 2, 4], epochs.selection) -def test_annotation_concat(): +@pytest.mark.parametrize("with_extras", [True, False]) +def test_annotation_concat(with_extras): """Test if two Annotations objects can be concatenated.""" + extras = None + if with_extras: + extras = [ + {"foo1": 1, "foo2": 1.1, "foo3": "a", "foo4": None}, + None, + None, + ] a = Annotations([1, 2, 3], [5, 5, 8], ["a", "b", "c"], ch_names=[["1"], ["2"], []]) - b = Annotations([11, 12, 13], [1, 2, 2], ["x", "y", "z"], ch_names=[[], ["3"], []]) + b = Annotations( + [11, 12, 13], + [1, 2, 2], + ["x", "y", "z"], + ch_names=[[], ["3"], []], + extras=extras, + ) # test + operator (does not modify a or b) c = a + b @@ -656,6 +672,10 @@ def test_annotation_concat(): assert_equal(len(a), 6) assert_equal(len(b), 3) + if with_extras: + all_extras = [extra or {} for extra in [None] * 3 + extras] + assert all(c.extras[i] == all_extras[i] for i in range(len(all_extras))) + # test += operator (modifies a in place) b._orig_time = _handle_meas_date(1038942070.7201) with pytest.raises(ValueError, match="orig_time should be the same"): @@ -949,7 +969,7 @@ def _constant_id(*args, **kwargs): # Test for IO with .csv files -def _assert_annotations_equal(a, b, tol=0): +def _assert_annotations_equal(a, b, tol=0, comp_extras_as_str=False): __tracebackhide__ = True assert_allclose(a.onset, b.onset, rtol=0, atol=tol, err_msg="onset") assert_allclose(a.duration, b.duration, rtol=0, atol=tol, err_msg="duration") @@ -958,14 +978,24 @@ def _assert_annotations_equal(a, b, tol=0): a_orig_time = a.orig_time b_orig_time = b.orig_time assert a_orig_time == b_orig_time, "orig_time" + extras_columns = a._extras_columns.union(b._extras_columns) + for col in extras_columns: + for i, extra in enumerate(a.extras): + exa = extra.get(col, None) + exb = b.extras[i].get(col, None) + if comp_extras_as_str: + exa = str(exa) if exa is not None else "" + exb = str(exb) if exb is not None else "" + assert exa == exb, f"extras[{i}][{col}]" _ORIG_TIME = datetime.fromtimestamp(1038942071.7201, timezone.utc) -@pytest.fixture(scope="function", params=("ch_names", "fmt")) -def dummy_annotation_file(tmp_path_factory, ch_names, fmt): +@pytest.fixture(scope="function", params=("ch_names", "fmt", "with_extras")) +def dummy_annotation_file(tmp_path_factory, ch_names, fmt, with_extras): """Create csv file for testing.""" + extras_row0 = {"foo1": 1, "foo2": 1.1, "foo3": "a", "foo4": None} if fmt == "csv": content = ( "onset,duration,description\n" @@ -982,7 +1012,10 @@ def dummy_annotation_file(tmp_path_factory, ch_names, fmt): ) else: assert fmt == "fif" - content = Annotations([0, 9], [1, 2.425], ["AA", "BB"], orig_time=_ORIG_TIME) + extras = [extras_row0, None] if with_extras else None + content = Annotations( + [0, 9], [1, 2.425], ["AA", "BB"], orig_time=_ORIG_TIME, extras=extras + ) if ch_names: if isinstance(content, Annotations): @@ -994,6 +1027,14 @@ def dummy_annotation_file(tmp_path_factory, ch_names, fmt): content[-2] += "," content[-1] += ",MEG0111:MEG2563" content = "\n".join(content) + if with_extras and fmt != "fif": + content = content.splitlines() + content[-3] += "," + ",".join(extras_row0.keys()) + content[-2] += "," + ",".join( + ["" if v is None else str(v) for v in extras_row0.values()] + ) + content[-1] += ",,,," + content = "\n".join(content) fname = tmp_path_factory.mktemp("data") / f"annotations-annot.{fmt}" if isinstance(content, str): @@ -1004,17 +1045,27 @@ def dummy_annotation_file(tmp_path_factory, ch_names, fmt): return fname +@pytest.mark.filterwarnings("ignore:.*heterogeneous dtypes.*") @pytest.mark.parametrize("ch_names", (False, True)) @pytest.mark.parametrize("fmt", [pytest.param("csv", marks=needs_pandas), "txt", "fif"]) -def test_io_annotation(dummy_annotation_file, tmp_path, fmt, ch_names): +@pytest.mark.parametrize("with_extras", [True, False]) +def test_io_annotation(dummy_annotation_file, tmp_path, fmt, ch_names, with_extras): """Test CSV, TXT, and FIF input/output (which support ch_names).""" annot = read_annotations(dummy_annotation_file) assert annot.orig_time == _ORIG_TIME kwargs = dict(orig_time=_ORIG_TIME) if ch_names: kwargs["ch_names"] = ((), ("MEG0111", "MEG2563")) + if with_extras: + kwargs["extras"] = [ + {"foo1": 1, "foo2": 1.1, "foo3": "a", "foo4": None}, + None, + ] _assert_annotations_equal( - annot, Annotations([0.0, 9.0], [1.0, 2.425], ["AA", "BB"], **kwargs), tol=1e-6 + annot, + Annotations([0.0, 9.0], [1.0, 2.425], ["AA", "BB"], **kwargs), + tol=1e-6, + comp_extras_as_str=fmt in ["csv", "txt"], ) # Now test writing @@ -1030,6 +1081,52 @@ def test_io_annotation(dummy_annotation_file, tmp_path, fmt, ch_names): _assert_annotations_equal(annot, annot2) +@pytest.mark.parametrize("fmt", [pytest.param("csv", marks=needs_pandas), "txt"]) +def test_write_annotation_warn_heterogeneous(tmp_path, fmt): + """Test that CSV, and TXT annotation writers warn on heterogeneous dtypes.""" + annot = Annotations( + onset=[0.0, 9.0], + duration=[1.0, 2.425], + description=["AA", "BB"], + orig_time=_ORIG_TIME, + extras=[ + {"foo1": "a", "foo2": "a"}, + {"foo1": 1, "foo2": None}, + ], + ) + fname = tmp_path / f"annotations-annot.{fmt}" + with ( + pytest.warns(RuntimeWarning, match="'foo2' contains heterogeneous dtypes"), + pytest.warns(RuntimeWarning, match="'foo1' contains heterogeneous dtypes"), + ): + annot.save(fname) + + +def test_write_annotation_warn_heterogeneous_b(tmp_path): + """Additional cases for test_write_annotation_warn_heterogeneous. + + These cases are only compatible with the TXT writer. + """ + fmt = "txt" + annot = Annotations( + onset=[0.0, 9.0], + duration=[1.0, 2.425], + description=["AA", "BB"], + orig_time=_ORIG_TIME, + extras=[ + {"foo3": 1, "foo4": 1, "foo5": 1.0}, + {"foo3": 1.0, "foo4": None, "foo5": None}, + ], + ) + fname = tmp_path / f"annotations-annot.{fmt}" + with ( + pytest.warns(RuntimeWarning, match="'foo5' contains heterogeneous dtypes"), + pytest.warns(RuntimeWarning, match="'foo4' contains heterogeneous dtypes"), + pytest.warns(RuntimeWarning, match="'foo3' contains heterogeneous dtypes"), + ): + annot.save(fname) + + def test_broken_csv(tmp_path): """Test broken .csv that does not use timestamps.""" pytest.importorskip("pandas") @@ -1123,9 +1220,10 @@ def test_read_annotation_txt_header(tmp_path): fname = tmp_path / "header.txt" with open(fname, "w") as f: f.write(content) - orig_time = _read_annotations_txt_parse_header(fname) + orig_time, _, n_rows_header = _read_annotations_txt_parse_header(fname) want = datetime.fromtimestamp(1038942071.7201, timezone.utc) assert orig_time == want + assert n_rows_header == 5 def test_read_annotation_txt_one_segment(tmp_path): @@ -1178,29 +1276,34 @@ def test_annotations_slices(): NUM_ANNOT = 5 EXPECTED_ONSETS = EXPECTED_DURATIONS = [x for x in range(NUM_ANNOT)] EXPECTED_DESCS = [x.__repr__() for x in range(NUM_ANNOT)] + EXTRAS_ROW = {"foo1": 1, "foo2": 1.1, "foo3": "a", "foo4": None} + EXPECTED_EXTRAS = [EXTRAS_ROW] * NUM_ANNOT annot = Annotations( onset=EXPECTED_ONSETS, duration=EXPECTED_DURATIONS, description=EXPECTED_DESCS, orig_time=None, + extras=EXPECTED_EXTRAS, ) # Indexing returns a copy. So this has no effect in annot annot[0]["onset"] = 42 annot[0]["duration"] = 3.14 annot[0]["description"] = "foobar" + annot[0]["extras"] = EXTRAS_ROW annot[:1].onset[0] = 42 annot[:1].duration[0] = 3.14 annot[:1].description[0] = "foobar" + annot[:1].extras[0] = EXTRAS_ROW # Slicing with single element returns a dictionary for ii in EXPECTED_ONSETS: assert annot[ii] == dict( zip( - ["onset", "duration", "description", "orig_time"], - [ii, ii, str(ii), None], + ["onset", "duration", "description", "orig_time", "extras"], + [ii, ii, str(ii), None, EXTRAS_ROW], ) ) @@ -1825,3 +1928,59 @@ def test_append_splits_boundary(tmp_path, split_size): assert len(raw.annotations) == 2 assert raw.annotations.description[0] == "BAD boundary" assert_allclose(raw.annotations.onset, [onset] * 2) + + +@pytest.mark.parametrize( + "key, value, expected_error, match", + ( + ("onset", 1, ValueError, "reserved"), + ("duration", 1, ValueError, "reserved"), + ("description", 1, ValueError, "reserved"), + ("ch_names", 1, ValueError, "reserved"), + ("valid_key", [], TypeError, "value must be an instance of"), + (1, 1, TypeError, "key must be an instance of"), + ), +) +def test_extras_dict_raises(key, value, expected_error, match): + """Test that _AnnotationsExtrasDict raises errors for invalid keys/values.""" + extras_dict = _AnnotationsExtrasDict() + with pytest.raises(expected_error, match=match): + extras_dict[key] = value + with pytest.raises(expected_error, match=match): + extras_dict.update({key: value}) + with pytest.raises(expected_error, match=match): + _AnnotationsExtrasDict({key: value}) + if isinstance(key, str): + with pytest.raises(expected_error, match=match): + _AnnotationsExtrasDict(**{key: value}) + + +@pytest.mark.parametrize( + "key, value, expected_error, match", + ( + ("onset", 1, ValueError, "reserved"), + ("duration", 1, ValueError, "reserved"), + ("description", 1, ValueError, "reserved"), + ("ch_names", 1, ValueError, "reserved"), + ("valid_key", [], TypeError, "value must be an instance of"), + (1, 1, TypeError, "key must be an instance of"), + ), +) +def test_extras_list_raises(key, value, expected_error, match): + """Test that _AnnotationsExtrasList raises errors for invalid keys/values.""" + extras = _AnnotationsExtrasList([None]) + assert all(isinstance(extra, _AnnotationsExtrasDict) for extra in extras) + with pytest.raises(expected_error, match=match): + extras[0] = {key: value} + with pytest.raises(expected_error, match=match): + extras[:1] = [{key: value}] + with pytest.raises(expected_error, match=match): + extras[0].update({key: value}) + with pytest.raises(expected_error, match=match): + _AnnotationsExtrasList([{key: value}]) + with pytest.raises(expected_error, match=match): + extras.append({key: value}) + with pytest.raises(expected_error, match=match): + extras.extend([{key: value}]) + with pytest.raises(expected_error, match=match): + extras += [{key: value}] diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index 88f2d9cdc13..fc1a95c9ba3 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -4917,9 +4917,15 @@ def test_add_channels_picks(): @pytest.mark.parametrize("first_samp", [0, 10]) @pytest.mark.parametrize( - "meas_date, orig_date", [[None, None], [np.pi, None], [np.pi, timedelta(seconds=1)]] + "meas_date, orig_date, with_extras", + [ + [None, None, False], + [np.pi, None, False], + [np.pi, timedelta(seconds=1), False], + [None, None, True], + ], ) -def test_epoch_annotations(first_samp, meas_date, orig_date, tmp_path): +def test_epoch_annotations(first_samp, meas_date, orig_date, with_extras, tmp_path): """Test Epoch Annotations from RawArray with dates. Tests the following cases crossed with each other: @@ -4942,21 +4948,26 @@ def test_epoch_annotations(first_samp, meas_date, orig_date, tmp_path): if orig_date is not None: orig_date = meas_date + orig_date ant_dur = 0.1 + extras_row0 = {"foo1": 1, "foo2": 1.1, "foo3": "a", "foo4": None} + extras = [extras_row0, None, None] if with_extras else None ants = Annotations( onset=[1.1, 1.2, 2.1], duration=[ant_dur, ant_dur, ant_dur], description=["x", "y", "z"], orig_time=orig_date, + extras=extras, ) raw.set_annotations(ants) epochs = make_fixed_length_epochs(raw, duration=1, overlap=0.5) # add Annotations to Epochs metadata - epochs.add_annotations_to_metadata() + epochs.add_annotations_to_metadata(with_extras=with_extras) metadata = epochs.metadata assert "annot_onset" in metadata.columns assert "annot_duration" in metadata.columns assert "annot_description" in metadata.columns + if with_extras: + assert all(f"annot_{k}" in metadata.columns for k in extras_row0.keys()) # Test that writing and reading back these new metadata works temp_fname = tmp_path / "test-epo.fif" diff --git a/mne/utils/check.py b/mne/utils/check.py index 085c51b6996..550903b6641 100644 --- a/mne/utils/check.py +++ b/mne/utils/check.py @@ -609,11 +609,13 @@ def _validate_type(item, types=None, item_name=None, type_name=None, *, extra="" check_types = sum( ( - (type(None),) - if type_ is None - else (type_,) - if not isinstance(type_, str) - else _multi[type_] + ( + (type(None),) + if type_ is None + else (type_,) + if not isinstance(type_, str) + else _multi[type_] + ) for type_ in types ), (), @@ -622,11 +624,13 @@ def _validate_type(item, types=None, item_name=None, type_name=None, *, extra="" if not isinstance(item, check_types): if type_name is None: type_name = [ - "None" - if cls_ is None - else cls_.__name__ - if not isinstance(cls_, str) - else cls_ + ( + "None" + if cls_ is None + else cls_.__name__ + if not isinstance(cls_, str) + else cls_ + ) for cls_ in types ] if len(type_name) == 1: