From 06cfaeccc10a0aebba9045f5defb573c342ef3a8 Mon Sep 17 00:00:00 2001 From: SanjayUG Date: Tue, 18 Mar 2025 17:39:27 +0530 Subject: [PATCH] feat: Add Polars DataFrame support --- README.md | 33 ++++++++ pyopenms_viz/__init__.py | 30 ++++++- pyopenms_viz/_bokeh/core.py | 9 +- pyopenms_viz/_core.py | 118 +++++++++++++------------- pyopenms_viz/_dataframe.py | 141 +++++++++++++++++++++++++++++++ pyopenms_viz/_matplotlib/core.py | 8 +- pyopenms_viz/_plotly/core.py | 3 +- requirements.txt | 1 + 8 files changed, 273 insertions(+), 70 deletions(-) create mode 100644 pyopenms_viz/_dataframe.py diff --git a/README.md b/README.md index 96f0524f..a28141f5 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,9 @@ pyOpenMS-Viz is a Python library that provides a simple interface for extending - Versatile column selection for easy adaptation to different data formats - Consistent API across different plotting backends for easy switching between static and interactive plots - Suitable for use in scripts, Jupyter notebooks, and web applications +- Now supports both pandas and polars DataFrames! +- Interactive plots with zoom, pan, and hover capabilities +- Customizable plot styling and annotations ## Suported Plots | **Plot Type** | **Required Dimensions** | **pyopenms_viz Name** | **Matplotlib** | **Bokeh** | **Plotly** | @@ -57,3 +60,33 @@ Documentation can be found [here](https://pyopenms-viz.readthedocs.io/en/latest/ - Pfeuffer, J., Bielow, C., Wein, S. et al. OpenMS 3 enables reproducible analysis of large-scale mass spectrometry data. Nat Methods 21, 365–367 (2024). [https://doi.org/10.1038/s41592-024-02197-7](https://doi.org/10.1038/s41592-024-02197-7) - Röst HL, Schmitt U, Aebersold R, Malmström L. pyOpenMS: a Python-based interface to the OpenMS mass-spectrometry algorithm library. Proteomics. 2014 Jan;14(1):74-7. [https://doi.org/10.1002/pmic.201300246](https://doi.org/10.1002/pmic.201300246). PMID: [24420968](https://pubmed.ncbi.nlm.nih.gov/24420968/). + +## Quick Start + +```python +import pandas as pd +import polars as pl +from pyopenms_viz import plot + +# Using pandas DataFrame +df = pd.DataFrame({ + 'mz': [100, 200, 300], + 'intensity': [1000, 2000, 3000] +}) +plot(df, x='mz', y='intensity', kind='spectrum') + +# Using polars DataFrame +df_pl = pl.DataFrame({ + 'mz': [100, 200, 300], + 'intensity': [1000, 2000, 3000] +}) +plot(df_pl, x='mz', y='intensity', kind='spectrum') +``` + +## Contributing + +Contributions are welcome! Please feel free to submit a Pull Request. + +## License + +This project is licensed under the MIT License - see the LICENSE file for details. diff --git a/pyopenms_viz/__init__.py b/pyopenms_viz/__init__.py index e822cc9f..27f0a210 100644 --- a/pyopenms_viz/__init__.py +++ b/pyopenms_viz/__init__.py @@ -2,9 +2,13 @@ init """ -from pandas.plotting._core import PlotAccessor +from __future__ import annotations + +from typing import Any, Union +import pandas as pd +import polars as pl from pandas.core.frame import DataFrame -from typing import Any +from pandas.plotting._core import PlotAccessor from pandas.core.dtypes.generic import ABCDataFrame import importlib import types @@ -197,4 +201,26 @@ def _get_plot_backend(backend: str | None = None): return module +def plot(data: Union[pd.DataFrame, pl.DataFrame], *args, **kwargs): + """ + Make plots of MassSpec data using pandas or polars DataFrames. + + Parameters + ---------- + data : pandas.DataFrame or polars.DataFrame + The data to be plotted. + *args : tuple + Variable length argument list. + **kwargs : dict + Arbitrary keyword arguments. + + Returns + ------- + figure + The plot figure object. + """ + plot_obj = PlotAccessor(data) + return plot_obj(*args, **kwargs) + + __all__ = ["PlotAccessor"] diff --git a/pyopenms_viz/_bokeh/core.py b/pyopenms_viz/_bokeh/core.py index 0450d71f..b7060e92 100644 --- a/pyopenms_viz/_bokeh/core.py +++ b/pyopenms_viz/_bokeh/core.py @@ -2,7 +2,7 @@ from abc import ABC -from typing import Tuple, Iterator +from typing import Tuple, Iterator, Any, Dict, List, Optional, Union from dataclasses import dataclass from bokeh.plotting import figure @@ -21,6 +21,7 @@ from pandas.core.frame import DataFrame from numpy import nan +import numpy as np # pyopenms_viz imports from .._core import ( @@ -35,6 +36,7 @@ SpectrumPlot, APPEND_PLOT_DOC, ) +from .._dataframe import DataFrameWrapper from .._misc import ColorGenerator, MarkerShapeGenerator, is_latex_formatted from ..constants import PEAK_BOUNDARY_ICON, FEATURE_BOUNDARY_ICON @@ -294,7 +296,7 @@ def plot(self): Plot a line plot """ if self.by is None: - source = ColumnDataSource(self.data) + source = ColumnDataSource(self.data.to_pandas()) line = self.fig.line( x=self.x, y=self.y, @@ -303,10 +305,9 @@ def plot(self): line_width=self.line_width, ) else: - legend_items = [] for group, df in self.data.groupby(self.by, sort=False): - source = ColumnDataSource(df) + source = ColumnDataSource(df.to_pandas()) line = self.fig.line( x=self.x, y=self.y, diff --git a/pyopenms_viz/_core.py b/pyopenms_viz/_core.py index 62d51c02..e9f56756 100644 --- a/pyopenms_viz/_core.py +++ b/pyopenms_viz/_core.py @@ -31,9 +31,12 @@ sturges_rule, freedman_diaconis_rule, mz_tolerance_binning, + MarkerShapeGenerator, + is_latex_formatted, ) -from .constants import IS_SPHINX_BUILD, IS_NOTEBOOK +from .constants import IS_SPHINX_BUILD, IS_NOTEBOOK, PEAK_BOUNDARY_ICON, FEATURE_BOUNDARY_ICON import warnings +from ._dataframe import DataFrameType, DataFrameWrapper, wrap_dataframe _common_kinds = ("line", "vline", "scatter") @@ -126,8 +129,8 @@ def canvas(self): def canvas(self, value): self._config.canvas = value - def __init__(self, data: DataFrame, config: BasePlotConfig = None, **kwargs): - self.data = data.copy() + def __init__(self, data: Union[DataFrame, DataFrameType], config: BasePlotConfig = None, **kwargs): + self.data = wrap_dataframe(data) if config is None: self._config = self._configClass(**kwargs) else: @@ -145,7 +148,7 @@ def __init__(self, data: DataFrame, config: BasePlotConfig = None, **kwargs): if self.by is not None: # Ensure by column data is string self.by = self._verify_column(self.by, "by") - self.data[self.by] = self.data[self.by].astype(str) + self.data.set_column(self.by, self.data.get_column(self.by).astype(str)) # only value that needs to be dynamically set def _copy_config_attributes(self): @@ -178,7 +181,7 @@ def holds_integer(column) -> bool: raise ValueError(f"For `{self._kind}` plot, `{name}` must be set") # if integer is supplied get the corresponding column associated with that index - if is_integer(colname) and not holds_integer(self.data.columns): + if is_integer(colname) and not holds_integer(self.data.get_column(colname)): if colname >= len(self.data.columns): raise ValueError( f"Column index `{colname}` out of range, `{name}` could not be set" @@ -229,11 +232,10 @@ def _check_and_aggregate_duplicates(self): col for col in self.known_columns if col != self.y ] - if self.data[known_columns_without_int].duplicated().any(): + if self.data.get_column(known_columns_without_int).duplicated().any(): if self.aggregate_duplicates: self.data = ( - self.data[self.known_columns] - .groupby(known_columns_without_int) + self.data.groupby(known_columns_without_int) .sum() .reset_index() ) @@ -573,7 +575,7 @@ def __init__(self, data, config: ChromatogramConfig = None, **kwargs) -> None: # Convert to relative intensity if required if self.relative_intensity: - self.data[self.y] = self.data[self.y] / self.data[self.y].max() * 100 + self.data.set_column(self.y, self.data.get_column(self.y) / self.data.get_column(self.y).max() * 100) self.plot() @@ -593,7 +595,7 @@ def plot(self): linePlot = self.get_line_renderer(data=self.data, config=self._config) self.canvas = linePlot.generate(tooltips, custom_hover_data) - self._modify_y_range((0, self.data[self.y].max()), (0, 0.1)) + self._modify_y_range((0, self.data.get_column(self.y).max()), (0, 0.1)) if self._interactive: self.manual_boundary_renderer = self._add_bounding_vertical_drawer() @@ -620,9 +622,8 @@ def compute_apex_intensity(self, annotation_data): Compute the apex intensity of the peak group based on the peak boundaries """ for idx, feature in annotation_data.iterrows(): - annotation_data.loc[idx, "apexIntensity"] = self.data.loc[ - self.data[self.x].between(feature["leftWidth"], feature["rightWidth"]), - self.y, + annotation_data.loc[idx, "apexIntensity"] = self.data.get_column(self.y).loc[ + self.data.get_column(self.x).between(feature["leftWidth"], feature["rightWidth"]) ].max() @@ -646,7 +647,7 @@ def load_config(self, **kwargs): def plot(self): fig = super().plot() - self._modify_y_range((0, self.data[self.y].max()), (0, 0.1)) + self._modify_y_range((0, self.data.get_column(self.y).max()), (0, 0.1)) class SpectrumPlot(BaseMSPlot, ABC): @@ -701,11 +702,10 @@ def _check_and_aggregate_duplicates(self): super()._check_and_aggregate_duplicates() if self.reference_spectrum is not None: - if self.reference_spectrum[self.known_columns].duplicated().any(): + if self.reference_spectrum.get_column(self.known_columns).duplicated().any(): if self.aggregate_duplicates: self.reference_spectrum = ( - self.reference_spectrum[self.known_columns] - .groupby(self.known_columns) + self.reference_spectrum.groupby(self.known_columns) .sum() .reset_index() ) @@ -800,7 +800,7 @@ def plot(self): if self.mirror_spectrum and self.reference_spectrum is not None: ## create a mirror spectrum # Set intensity to negative values - reference_spectrum[self.y] = reference_spectrum[self.y] * -1 + reference_spectrum.set_column(self.y, reference_spectrum.get_column(self.y) * -1) color_mirror = self._get_colors(reference_spectrum, kind="peak") reference_spectrum = self.convert_for_line_plots( @@ -828,19 +828,19 @@ def plot(self): self.plot_x_axis_line(self.canvas, line_width=2) # Adjust x axis padding (Plotly cuts outermost peaks) - min_values = [spectrum[self.x].min()] - max_values = [spectrum[self.x].max()] + min_values = [spectrum.get_column(self.x).min()] + max_values = [spectrum.get_column(self.x).max()] if reference_spectrum is not None: - min_values.append(reference_spectrum[self.x].min()) - max_values.append(reference_spectrum[self.x].max()) + min_values.append(reference_spectrum.get_column(self.x).min()) + max_values.append(reference_spectrum.get_column(self.x).max()) self._modify_x_range((min(min_values), max(max_values)), padding=(0.20, 0.20)) # Adjust y axis padding (annotations should stay inside plot) - max_value = spectrum[self.y].max() + max_value = spectrum.get_column(self.y).max() min_value = 0 min_padding = 0 max_padding = 0.15 if reference_spectrum is not None and self.mirror_spectrum: - min_value = reference_spectrum[self.y].min() + min_value = reference_spectrum.get_column(self.y).min() min_padding = -0.2 max_padding = 0.4 @@ -869,9 +869,9 @@ def assign_bin(value): return nan # For values that don't fall into any bin # Apply the binning - df[self.x] = df[self.x].apply(assign_bin) + df.set_column(self.x, df.get_column(self.x).apply(assign_bin)) else: # use computed number of bins, bins evenly spaced - bins = np.histogram_bin_edges(df[self.x], self._computed_num_bins) + bins = np.histogram_bin_edges(df.get_column(self.x), self._computed_num_bins) def assign_bin(value): for low_idx in range(len(bins) - 1): @@ -880,7 +880,7 @@ def assign_bin(value): return nan # For values that don't fall into any bin # Apply the binning - df[self.x] = df[self.x].apply(assign_bin) + df.set_column(self.x, df.get_column(self.x).apply(assign_bin)) # TODO I am not sure why "cut" method seems to be failing with plotly so created a workaround for now # error is that object is not JSON serializable because of Interval type @@ -916,7 +916,7 @@ def convert_to_numeric(value): else: return value - df[self.x] = df[self.x].apply(convert_to_numeric).astype(float) + df.set_column(self.x, df.get_column(self.x).apply(convert_to_numeric).astype(float)) df = df.fillna(0) return df @@ -935,7 +935,7 @@ def _prepare_data(self, df, label_suffix=""): # Convert to relative intensity if required if self.relative_intensity or self.mirror_spectrum: - df[self.y] = df[self.y] / df[self.y].max() * 100 + df.set_column(self.y, df.get_column(self.y) / df.get_column(self.y).max() * 100) # Bin peaks if required if self.bin_peaks == True or (self.bin_peaks == "auto"): @@ -953,31 +953,31 @@ def _get_colors( self.annotation_color is not None and self.annotation_color in data.columns ): - return ColorGenerator(data[self.annotation_color]) + return ColorGenerator(data.get_column(self.annotation_color)) # Ion annotation colors elif ( self.ion_annotation is not None and self.ion_annotation in data.columns ): # Generate colors based on ion annotations return ColorGenerator( - self._get_ion_color_annotation(data[self.ion_annotation]) + self._get_ion_color_annotation(data.get_column(self.ion_annotation)) ) # Grouped by colors (from default color map) elif self.by is not None: # Get unique values to determine number of distinct colors - uniques = data[self.by].unique() + uniques = data.get_column(self.by).unique().tolist() color_gen = ColorGenerator() # Generate a list of colors equal to the number of unique values colors = [next(color_gen) for _ in range(len(uniques))] # Create a mapping of unique values to their corresponding colors color_map = {uniques[i]: colors[i] for i in range(len(colors))} # Apply the color mapping to the specified column in the data and turn it into a ColorGenerator - return ColorGenerator(data[self.by].apply(lambda x: color_map[x])) + return ColorGenerator(data.get_column(self.by).apply(lambda x: color_map[x])) # Fallback ColorGenerator with one color return ColorGenerator(n=1) else: # Peaks if self.by: - uniques = data[self.by].unique().tolist() + uniques = data.get_column(self.by).unique().tolist() # Custom colors with top priority if self.peak_color is not None: return ColorGenerator(uniques) @@ -990,7 +990,7 @@ def _get_colors( def _get_annotations(self, data: DataFrame, x: str, y: str): """Create annotations for each peak. Return lists of texts, x and y locations and colors.""" - data["color"] = ["black" for _ in range(len(data))] + data.set_column("color", ["black" for _ in range(len(data))]) ann_texts = [] top_n = self.annotate_top_n_peaks @@ -1000,25 +1000,25 @@ def _get_annotations(self, data: DataFrame, x: str, y: str): top_n = 0 # sort values for top intensity peaks on top (ascending for reference spectra with negative values) data = data.sort_values( - y, ascending=True if data[y].min() < 0 else False + y, ascending=True if data.get_column(y).min() < 0 else False ).reset_index() for i, row in data.iterrows(): texts = [] if i < top_n: if self.annotate_mz: - texts.append(str(round(row[x], 4))) + texts.append(str(round(row.get_column(x), 4))) if self.ion_annotation and self.ion_annotation in data.columns: - texts.append(str(row[self.ion_annotation])) + texts.append(str(row.get_column(self.ion_annotation))) if ( self.sequence_annotation and self.sequence_annotation in data.columns ): - texts.append(str(row[self.sequence_annotation])) + texts.append(str(row.get_column(self.sequence_annotation))) if self.custom_annotation and self.custom_annotation in data.columns: - texts.append(str(row[self.custom_annotation])) + texts.append(str(row.get_column(self.custom_annotation))) ann_texts.append("\n".join(texts)) - return ann_texts, data[x].tolist(), data[y].tolist(), data["color"].tolist() + return ann_texts, data.get_column(x).tolist(), data.get_column(y).tolist(), data.get_column("color").tolist() def _get_ion_color_annotation(self, ion_annotations: str) -> str: """Retrieve the color associated with a specific ion annotation from a predefined colormap.""" @@ -1063,12 +1063,12 @@ def to_line(self, x, y): def convert_for_line_plots(self, data: DataFrame, x: str, y: str) -> DataFrame: if self.by is None: - x_data, y_data = self.to_line(data[x], data[y]) + x_data, y_data = self.to_line(data.get_column(x), data.get_column(y)) return DataFrame({x: x_data, y: y_data}) else: dfs = [] for name, df in data.groupby(self.by, sort=False): - x_data, y_data = self.to_line(df[x], df[y]) + x_data, y_data = self.to_line(df.get_column(x), df.get_column(y)) dfs.append(DataFrame({x: x_data, y: y_data, self.by: name})) return concat(dfs) @@ -1127,15 +1127,15 @@ def __init__(self, data, **kwargs) -> None: def prepare_data(self): # Convert intensity values to relative intensity if required if self.relative_intensity and self.z is not None: - self.data[self.z] = self.data[self.z] / max(self.data[self.z]) * 100 + self.data.set_column(self.z, self.data.get_column(self.z) / max(self.data.get_column(self.z)) * 100) # Bin peaks if required if self.bin_peaks == True or ( self.data.shape[0] > self.num_x_bins * self.num_y_bins and self.bin_peaks == "auto" ): - self.data[self.x] = cut(self.data[self.x], bins=self.num_x_bins) - self.data[self.y] = cut(self.data[self.y], bins=self.num_y_bins) + self.data.set_column(self.x, cut(self.data.get_column(self.x), bins=self.num_x_bins)) + self.data.set_column(self.y, cut(self.data.get_column(self.y), bins=self.num_y_bins)) if self.z is not None: if self.by is not None: # Group by x, y and by columns and calculate the mean intensity within each bin @@ -1151,17 +1151,13 @@ def prepare_data(self): .agg({self.z: "mean"}) .reset_index() ) - self.data[self.x] = ( - self.data[self.x].apply(lambda interval: interval.mid).astype(float) - ) - self.data[self.y] = ( - self.data[self.y].apply(lambda interval: interval.mid).astype(float) - ) + self.data.set_column(self.x, self.data.get_column(self.x).apply(lambda interval: interval.mid).astype(float)) + self.data.set_column(self.y, self.data.get_column(self.y).apply(lambda interval: interval.mid).astype(float)) self.data = self.data.fillna(0) # Log intensity scale if self.z_log_scale: - self.data[self.z] = log1p(self.data[self.z]) + self.data.set_column(self.z, log1p(self.data.get_column(self.z))) # Sort values by intensity in ascending order to plot highest intensity peaks last if self.z is not None: @@ -1312,21 +1308,21 @@ def center_of_gravity(x, m): t = feature[self.annotation_names] c = feature[self.annotation_colors] selected_data = self.data[ - (self.data[x] > x0) - & (self.data[x] < x1) - & (self.data[y] > y0) - & (self.data[y] < y1) + (self.data.get_column(x) > x0) + & (self.data.get_column(x) < x1) + & (self.data.get_column(y) > y0) + & (self.data.get_column(y) < y1) ] if len(selected_data) == 0: annotations_3d.append( - (np.mean((x0, x1)), np.mean((y0, y1)), np.mean(self.data[z]), t, c) + (np.mean((x0, x1)), np.mean((y0, y1)), np.mean(self.data.get_column(z)), t, c) ) else: annotations_3d.append( ( - center_of_gravity(selected_data[x], selected_data[z]), - center_of_gravity(selected_data[y], selected_data[z]), - np.max(selected_data[z]) * 1.05, + center_of_gravity(selected_data.get_column(x), selected_data.get_column(z)), + center_of_gravity(selected_data.get_column(y), selected_data.get_column(z)), + np.max(selected_data.get_column(z)) * 1.05, t, c, ) diff --git a/pyopenms_viz/_dataframe.py b/pyopenms_viz/_dataframe.py new file mode 100644 index 00000000..a50a467a --- /dev/null +++ b/pyopenms_viz/_dataframe.py @@ -0,0 +1,141 @@ +from __future__ import annotations + +from typing import Any, Union, List, Optional +import pandas as pd +import polars as pl +import numpy as np + +DataFrameType = Union[pd.DataFrame, pl.DataFrame] + +class DataFrameWrapper: + """ + A wrapper class that provides a unified interface for both pandas and polars DataFrames. + This allows pyopenms_viz to work with either type without modifying the existing API. + """ + + def __init__(self, data: DataFrameType): + self._data = data + self._is_polars = isinstance(data, pl.DataFrame) + + @property + def data(self) -> DataFrameType: + """Get the underlying DataFrame.""" + return self._data + + def to_pandas(self) -> pd.DataFrame: + """Convert the DataFrame to pandas if needed.""" + if self._is_polars: + return self._data.to_pandas() + return self._data + + def to_polars(self) -> pl.DataFrame: + """Convert the DataFrame to polars if needed.""" + if not self._is_polars: + return pl.from_pandas(self._data) + return self._data + + def copy(self) -> 'DataFrameWrapper': + """Create a copy of the DataFrame.""" + if self._is_polars: + return DataFrameWrapper(self._data.clone()) + return DataFrameWrapper(self._data.copy()) + + def get_column(self, col: str) -> np.ndarray: + """Get a column as numpy array.""" + if self._is_polars: + return self._data[col].to_numpy() + return self._data[col].to_numpy() + + def set_column(self, col: str, value: Any) -> None: + """Set a column value.""" + if self._is_polars: + self._data = self._data.with_columns(pl.Series(col, value)) + else: + self._data[col] = value + + def groupby(self, by: str) -> 'GroupByWrapper': + """Group the DataFrame by a column.""" + if self._is_polars: + return GroupByWrapper(self._data.groupby(by), is_polars=True) + return GroupByWrapper(self._data.groupby(by), is_polars=False) + + def fillna(self, value: Any) -> 'DataFrameWrapper': + """Fill NA/null values.""" + if self._is_polars: + return DataFrameWrapper(self._data.fill_null(value)) + return DataFrameWrapper(self._data.fillna(value)) + + def between(self, col: str, left: float, right: float) -> 'DataFrameWrapper': + """Select rows where column values are between left and right.""" + if self._is_polars: + mask = (self._data[col] >= left) & (self._data[col] <= right) + return DataFrameWrapper(self._data.filter(mask)) + return DataFrameWrapper(self._data[self._data[col].between(left, right)]) + + def max(self, col: str) -> float: + """Get maximum value of a column.""" + if self._is_polars: + return self._data[col].max() + return self._data[col].max() + + def min(self, col: str) -> float: + """Get minimum value of a column.""" + if self._is_polars: + return self._data[col].min() + return self._data[col].min() + + def iterrows(self): + """Iterate over DataFrame rows.""" + if self._is_polars: + for row in self._data.iter_rows(named=True): + yield row + else: + for idx, row in self._data.iterrows(): + yield row + + @property + def columns(self) -> List[str]: + """Get column names.""" + if self._is_polars: + return self._data.columns + return list(self._data.columns) + + def __getitem__(self, key: str) -> np.ndarray: + """Get a column by name.""" + return self.get_column(key) + + +class GroupByWrapper: + """Wrapper for grouped DataFrame operations.""" + + def __init__(self, grouped, is_polars: bool): + self._grouped = grouped + self._is_polars = is_polars + + def __iter__(self): + """Iterate over groups.""" + if self._is_polars: + for name, group in self._grouped.groups(): + yield name, DataFrameWrapper(group) + else: + for name, group in self._grouped: + yield name, DataFrameWrapper(group) + + def agg(self, func: dict) -> DataFrameWrapper: + """Aggregate using the specified functions.""" + if self._is_polars: + agg_exprs = [] + for col, agg_func in func.items(): + if isinstance(agg_func, str): + agg_exprs.append(pl.col(col).agg(agg_func)) + else: + agg_exprs.append(pl.col(col).agg(lambda x: agg_func(x.to_numpy()))) + result = self._grouped.agg(agg_exprs) + return DataFrameWrapper(result) + else: + return DataFrameWrapper(self._grouped.agg(func)) + + +def wrap_dataframe(data: DataFrameType) -> DataFrameWrapper: + """Create a DataFrameWrapper instance from a pandas or polars DataFrame.""" + return DataFrameWrapper(data) \ No newline at end of file diff --git a/pyopenms_viz/_matplotlib/core.py b/pyopenms_viz/_matplotlib/core.py index 11e7c737..703243de 100644 --- a/pyopenms_viz/_matplotlib/core.py +++ b/pyopenms_viz/_matplotlib/core.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABC -from typing import Tuple +from typing import Tuple, Any, Dict, List, Optional, Union import re from numpy import nan import matplotlib.pyplot as plt @@ -9,6 +9,9 @@ from matplotlib.patches import Rectangle from matplotlib.axes import Axes from matplotlib.figure import Figure +import numpy as np +from pandas import DataFrame +from mpl_toolkits.mplot3d import Axes3D from .._config import LegendConfig @@ -25,6 +28,7 @@ PeakMapPlot, APPEND_PLOT_DOC, ) +from .._dataframe import DataFrameWrapper class MATPLOTLIBPlot(BasePlot, ABC): @@ -307,7 +311,7 @@ def plot(self): ) else: - for group, df in self.data.groupby(self.by, sort=True): + for group, df in self.data.groupby(self.by): (line,) = self.ax.plot( df[self.x], df[self.y], color=self.current_color, **kwargs ) diff --git a/pyopenms_viz/_plotly/core.py b/pyopenms_viz/_plotly/core.py index b106a044..84d027d2 100644 --- a/pyopenms_viz/_plotly/core.py +++ b/pyopenms_viz/_plotly/core.py @@ -2,7 +2,7 @@ from abc import ABC -from typing import List, Tuple, Union +from typing import List, Tuple, Union, Any, Dict, Optional import plotly.graph_objects as go from plotly.graph_objs import Figure @@ -28,6 +28,7 @@ from .._config import bokeh_line_dash_mapper from .._misc import ColorGenerator, MarkerShapeGenerator, is_latex_formatted from ..constants import PEAK_BOUNDARY_ICON, FEATURE_BOUNDARY_ICON +from .._dataframe import DataFrameWrapper class PLOTLYPlot(BasePlot, ABC): diff --git a/requirements.txt b/requirements.txt index e99ee052..f5bc5871 100644 --- a/requirements.txt +++ b/requirements.txt @@ -217,6 +217,7 @@ webencodings==0.5.1 # tinycss2 xyzservices==2024.9.0 # via bokeh +polars>=0.20.7 # pyopenms-viz git+https://github.com/OpenMS/pyopenms_viz.git