From dc3222bbd57f9a0681ff130dfb8911a4fb4a411f Mon Sep 17 00:00:00 2001 From: Joshua Gould Date: Thu, 2 Apr 2026 15:02:12 -0400 Subject: [PATCH 1/9] removed shuffle --- scallops/features/agg.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/scallops/features/agg.py b/scallops/features/agg.py index ebece64..fbaffde 100644 --- a/scallops/features/agg.py +++ b/scallops/features/agg.py @@ -46,8 +46,6 @@ def agg_features( xdata = xr.DataArray(data=data.X, dims=("obs", "var"), coords=coords, name="") if group_by_multi: xdata = xdata.set_xindex(by, PandasMultiIndex) - if agg_func == "median" and isinstance(xdata.data, da.Array): - xdata = xdata.groupby("obs").shuffle_to_chunks() grouped = xdata.groupby("obs") xp = get_namespace(xdata.data) From 3df10f0b2fad3d86e7c8efebc93aa74dd95e89cb Mon Sep 17 00:00:00 2001 From: Joshua Gould Date: Thu, 2 Apr 2026 15:32:59 -0400 Subject: [PATCH 2/9] handle agg multi --- scallops/features/agg.py | 34 ++++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/scallops/features/agg.py b/scallops/features/agg.py index fbaffde..f4b00aa 100644 --- a/scallops/features/agg.py +++ b/scallops/features/agg.py @@ -8,7 +8,6 @@ import xarray as xr from array_api_compat import get_namespace from dask.array.numpy_compat import NUMPY_GE_200 -from pandas import MultiIndex from statsmodels.stats.weightstats import DescrStatsW from xarray.core.indexes import PandasMultiIndex @@ -102,19 +101,34 @@ def weighted_agg(x): groups.append(group) counts.append(count) + counts = [] + groups = [] + for group in grouped.groups: + val = grouped.groups[group] + if isinstance(val, slice): + count = ( + val.stop - val.start + if val.step is None + else len(val.indices(X.shape[0])) + ) + else: + count = len(val) + groups.append(group) + counts.append(count) + obs = result.coords["obs"].to_dataframe() group_counts = pd.DataFrame( data={"count": counts}, - index=pd.MultiIndex.from_tuples(groups, names=obs.index.names) - if isinstance(obs.index, MultiIndex) - else pd.Index(groups), - ) - obs = ( - obs.drop("obs", errors="ignore", axis=1) - .join(group_counts, rsuffix="_1") - .reset_index() + index=groups, ) - if not group_by_multi and "obs" in obs.columns: + obs = obs.join(group_counts, rsuffix="_1").reset_index(drop=True) + if group_by_multi: + new_obs = pd.DataFrame(obs["obs"].tolist(), columns=by) + for c in obs.columns: + if c != "obs" and c not in new_obs.columns: + new_obs[c] = obs[c] + obs = new_obs + else: obs = obs.rename({"obs": by}, axis=1) obs = obs.set_index(pd.RangeIndex(len(obs)).astype(str)) return anndata.AnnData( From 40ca8f18f68b72b9eb233c9c30d27b098827c3fe Mon Sep 17 00:00:00 2001 From: Joshua Gould Date: Thu, 2 Apr 2026 15:45:58 -0400 Subject: [PATCH 3/9] handle agg multi --- scallops/features/agg.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/scallops/features/agg.py b/scallops/features/agg.py index f4b00aa..72040fe 100644 --- a/scallops/features/agg.py +++ b/scallops/features/agg.py @@ -9,7 +9,6 @@ from array_api_compat import get_namespace from dask.array.numpy_compat import NUMPY_GE_200 from statsmodels.stats.weightstats import DescrStatsW -from xarray.core.indexes import PandasMultiIndex def _weighted_median(x, weights): @@ -43,8 +42,6 @@ def agg_features( if weights_col is not None: coords[weights_col] = ("obs", data.obs[weights_col]) xdata = xr.DataArray(data=data.X, dims=("obs", "var"), coords=coords, name="") - if group_by_multi: - xdata = xdata.set_xindex(by, PandasMultiIndex) grouped = xdata.groupby("obs") xp = get_namespace(xdata.data) From bc5b4b7641b7d175e9fec8e97aed8818ef5be045 Mon Sep 17 00:00:00 2001 From: Joshua Gould Date: Thu, 2 Apr 2026 15:47:26 -0400 Subject: [PATCH 4/9] handle agg multi --- scallops/features/agg.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/scallops/features/agg.py b/scallops/features/agg.py index 72040fe..a364757 100644 --- a/scallops/features/agg.py +++ b/scallops/features/agg.py @@ -36,9 +36,7 @@ def agg_features( if not group_by_multi: coords = {"obs": data.obs[by]} else: - coords = {"obs": np.arange(data.shape[0])} - for col in by: - coords[col] = ("obs", data.obs[col]) + coords = {"obs": data.obs[by].apply(tuple, axis=1)} if weights_col is not None: coords[weights_col] = ("obs", data.obs[weights_col]) xdata = xr.DataArray(data=data.X, dims=("obs", "var"), coords=coords, name="") From 60d69b3a69ab00952c182ebc958b3acf1739b918 Mon Sep 17 00:00:00 2001 From: Joshua Gould Date: Thu, 2 Apr 2026 15:58:44 -0400 Subject: [PATCH 5/9] convert by to list --- scallops/features/agg.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scallops/features/agg.py b/scallops/features/agg.py index a364757..d3cfaa6 100644 --- a/scallops/features/agg.py +++ b/scallops/features/agg.py @@ -36,6 +36,7 @@ def agg_features( if not group_by_multi: coords = {"obs": data.obs[by]} else: + by = list(by) coords = {"obs": data.obs[by].apply(tuple, axis=1)} if weights_col is not None: coords[weights_col] = ("obs", data.obs[weights_col]) From 4007871ea3fd93cedab13091d1bf9ae4b57ec873 Mon Sep 17 00:00:00 2001 From: Joshua Gould Date: Thu, 2 Apr 2026 16:00:47 -0400 Subject: [PATCH 6/9] check for list of length 1 --- scallops/features/agg.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/scallops/features/agg.py b/scallops/features/agg.py index d3cfaa6..e577074 100644 --- a/scallops/features/agg.py +++ b/scallops/features/agg.py @@ -33,10 +33,15 @@ def agg_features( assert agg_func in ("mean", "median") group_by_multi = not isinstance(by, str) and isinstance(by, Sequence) + if not group_by_multi: + by = list(by) + if len(by) == 1: + by = by[0] + group_by_multi = False + if not group_by_multi: coords = {"obs": data.obs[by]} else: - by = list(by) coords = {"obs": data.obs[by].apply(tuple, axis=1)} if weights_col is not None: coords[weights_col] = ("obs", data.obs[weights_col]) From 4c3d38345036c510e21ffe43e23c75d0751ffc59 Mon Sep 17 00:00:00 2001 From: Joshua Gould Date: Thu, 2 Apr 2026 16:11:21 -0400 Subject: [PATCH 7/9] count column --- scallops/features/agg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scallops/features/agg.py b/scallops/features/agg.py index e577074..50dc145 100644 --- a/scallops/features/agg.py +++ b/scallops/features/agg.py @@ -126,7 +126,7 @@ def weighted_agg(x): if group_by_multi: new_obs = pd.DataFrame(obs["obs"].tolist(), columns=by) for c in obs.columns: - if c != "obs" and c not in new_obs.columns: + if c.startswith("count") and c not in new_obs.columns: new_obs[c] = obs[c] obs = new_obs else: From 66fa3a081690207455d76150fb60aa84e70c4d41 Mon Sep 17 00:00:00 2001 From: Joshua Gould Date: Thu, 2 Apr 2026 21:41:39 -0400 Subject: [PATCH 8/9] fix check for multi --- scallops/features/agg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scallops/features/agg.py b/scallops/features/agg.py index 50dc145..c7bf4ef 100644 --- a/scallops/features/agg.py +++ b/scallops/features/agg.py @@ -33,7 +33,7 @@ def agg_features( assert agg_func in ("mean", "median") group_by_multi = not isinstance(by, str) and isinstance(by, Sequence) - if not group_by_multi: + if group_by_multi: by = list(by) if len(by) == 1: by = by[0] From 487e8aa5dbe7dce0c600e3a9f679a7c7b4064f75 Mon Sep 17 00:00:00 2001 From: Joshua Gould Date: Fri, 3 Apr 2026 12:13:55 -0400 Subject: [PATCH 9/9] removed duplicate code --- scallops/features/agg.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/scallops/features/agg.py b/scallops/features/agg.py index c7bf4ef..3a69f02 100644 --- a/scallops/features/agg.py +++ b/scallops/features/agg.py @@ -102,26 +102,11 @@ def weighted_agg(x): groups.append(group) counts.append(count) - counts = [] - groups = [] - for group in grouped.groups: - val = grouped.groups[group] - if isinstance(val, slice): - count = ( - val.stop - val.start - if val.step is None - else len(val.indices(X.shape[0])) - ) - else: - count = len(val) - groups.append(group) - counts.append(count) - - obs = result.coords["obs"].to_dataframe() group_counts = pd.DataFrame( data={"count": counts}, index=groups, ) + obs = result.coords["obs"].to_dataframe() obs = obs.join(group_counts, rsuffix="_1").reset_index(drop=True) if group_by_multi: new_obs = pd.DataFrame(obs["obs"].tolist(), columns=by)