diff --git a/scallops/features/agg.py b/scallops/features/agg.py index ebece64..3a69f02 100644 --- a/scallops/features/agg.py +++ b/scallops/features/agg.py @@ -8,9 +8,7 @@ import xarray as xr from array_api_compat import get_namespace from dask.array.numpy_compat import NUMPY_GE_200 -from pandas import MultiIndex from statsmodels.stats.weightstats import DescrStatsW -from xarray.core.indexes import PandasMultiIndex def _weighted_median(x, weights): @@ -35,19 +33,19 @@ def agg_features( assert agg_func in ("mean", "median") group_by_multi = not isinstance(by, str) and isinstance(by, Sequence) + if group_by_multi: + by = list(by) + if len(by) == 1: + by = by[0] + group_by_multi = False + if not group_by_multi: coords = {"obs": data.obs[by]} else: - coords = {"obs": np.arange(data.shape[0])} - for col in by: - coords[col] = ("obs", data.obs[col]) + coords = {"obs": data.obs[by].apply(tuple, axis=1)} if weights_col is not None: coords[weights_col] = ("obs", data.obs[weights_col]) xdata = xr.DataArray(data=data.X, dims=("obs", "var"), coords=coords, name="") - if group_by_multi: - xdata = xdata.set_xindex(by, PandasMultiIndex) - if agg_func == "median" and isinstance(xdata.data, da.Array): - xdata = xdata.groupby("obs").shuffle_to_chunks() grouped = xdata.groupby("obs") xp = get_namespace(xdata.data) @@ -104,19 +102,19 @@ def weighted_agg(x): groups.append(group) counts.append(count) - obs = result.coords["obs"].to_dataframe() group_counts = pd.DataFrame( data={"count": counts}, - index=pd.MultiIndex.from_tuples(groups, names=obs.index.names) - if isinstance(obs.index, MultiIndex) - else pd.Index(groups), + index=groups, ) - obs = ( - obs.drop("obs", errors="ignore", axis=1) - .join(group_counts, rsuffix="_1") - .reset_index() - ) - if not group_by_multi and "obs" in obs.columns: + obs = result.coords["obs"].to_dataframe() + obs = obs.join(group_counts, rsuffix="_1").reset_index(drop=True) + if group_by_multi: + new_obs = pd.DataFrame(obs["obs"].tolist(), columns=by) + for c in obs.columns: + if c.startswith("count") and c not in new_obs.columns: + new_obs[c] = obs[c] + obs = new_obs + else: obs = obs.rename({"obs": by}, axis=1) obs = obs.set_index(pd.RangeIndex(len(obs)).astype(str)) return anndata.AnnData(