Skip to content

Commit 1fd464b

Browse files
samukwekusamuel.oranyeli
andauthored
[ENH] conditional_join refactor single join (#1520)
* single equi join migration * changelog * fix relative imports * fix relative imports * fix relative imports --------- Co-authored-by: samuel.oranyeli <samuel.oranyeli@grow.inc>
1 parent 82e0399 commit 1fd464b

File tree

9 files changed

+802
-2586
lines changed

9 files changed

+802
-2586
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,3 +143,4 @@ tags
143143
*.profraw
144144
/scratch.py
145145
midpoint.csv
146+
examples/notebooks/cond_join.ipynb

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Changelog
22

33
## [Unreleased]
4-
- [ENH] Added `row_count` parameter for janitor.conditional_join - Issue #1269 @samukweku
4+
- [ENH] `return_ragged_arrays` deprecated; get_join_indices function now returns a dictionary - Issue #520 @samukweku
55
- [ENH] Reverse deprecation of `pivot_wider()` -- Issue #1464
66
- [ENH] Add accessor and method for pandas DataFrameGroupBy objects. - Issue #587 @samukweku
77
- [ENH] Call mutate/summarise directly on groupby objects instead. Also add `ungroup` method to expose underlying dataframe of a grouped object. - Issue #1511 @samukweku
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# helper functions for >/>=
2+
import numpy as np
3+
import pandas as pd
4+
5+
from janitor.functions._conditional_join._helpers import (
6+
_null_checks_cond_join,
7+
_sort_if_not_monotonic,
8+
)
9+
10+
11+
def _ge_gt_indices(
12+
left: pd.array,
13+
left_index: np.ndarray,
14+
right: pd.array,
15+
strict: bool,
16+
) -> tuple | None:
17+
"""
18+
Use binary search to get indices where left
19+
is greater than or equal to right.
20+
21+
If strict is True, then only indices
22+
where `left` is greater than
23+
(but not equal to) `right` are returned.
24+
"""
25+
search_indices = right.searchsorted(left, side="right")
26+
# if any of the positions in `search_indices`
27+
# is equal to 0 (less than 1), it implies that
28+
# left[position] is not greater than any value
29+
# in right
30+
booleans = search_indices > 0
31+
if not booleans.any():
32+
return None
33+
if not booleans.all():
34+
left = left[booleans]
35+
left_index = left_index[booleans]
36+
search_indices = search_indices[booleans]
37+
# the idea here is that if there are any equal values
38+
# shift downwards to the immediate next position
39+
# that is not equal
40+
if strict:
41+
booleans = left == right[search_indices - 1]
42+
# replace positions where rows are equal with
43+
# searchsorted('left');
44+
# this works fine since we will be using the value
45+
# as the right side of a slice, which is not included
46+
# in the final computed value
47+
if booleans.any():
48+
replacements = right.searchsorted(left, side="left")
49+
# now we can safely replace values
50+
# with strictly greater than positions
51+
search_indices = np.where(booleans, replacements, search_indices)
52+
# any value less than 1 should be discarded
53+
# since the lowest value for binary search
54+
# with side='right' should be 1
55+
booleans = search_indices > 0
56+
if not booleans.any():
57+
return None
58+
if not booleans.all():
59+
left_index = left_index[booleans]
60+
search_indices = search_indices[booleans]
61+
return left_index, search_indices
62+
63+
64+
def _greater_than_indices(
65+
left: pd.Series,
66+
right: pd.Series,
67+
strict: bool,
68+
keep: str,
69+
return_matching_indices: bool,
70+
) -> dict | None:
71+
"""
72+
Use binary search to get indices where left
73+
is greater than or equal to right.
74+
75+
If strict is True, then only indices
76+
where `left` is greater than
77+
(but not equal to) `right` are returned.
78+
"""
79+
# quick break, avoiding the hassle
80+
if left.max() < right.min():
81+
return {
82+
"left_index": np.array([], dtype=np.intp),
83+
"right_index": np.array([], dtype=np.intp),
84+
}
85+
outcome = _null_checks_cond_join(series=left)
86+
if outcome is None:
87+
return {
88+
"left_index": np.array([], dtype=np.intp),
89+
"right_index": np.array([], dtype=np.intp),
90+
}
91+
left, _ = outcome
92+
outcome = _null_checks_cond_join(series=right)
93+
if outcome is None:
94+
return {
95+
"left_index": np.array([], dtype=np.intp),
96+
"right_index": np.array([], dtype=np.intp),
97+
}
98+
right, any_nulls = outcome
99+
right, right_is_sorted = _sort_if_not_monotonic(series=right)
100+
outcome = _ge_gt_indices(
101+
left=left.array,
102+
right=right.array,
103+
left_index=left.index._values,
104+
strict=strict,
105+
)
106+
if outcome is None:
107+
return {
108+
"left_index": np.array([], dtype=np.intp),
109+
"right_index": np.array([], dtype=np.intp),
110+
}
111+
left_index, search_indices = outcome
112+
right_index = right.index._values
113+
if right_is_sorted & (keep == "first"):
114+
indexer = np.zeros_like(search_indices)
115+
return {"left_index": left_index, "right_index": right_index[indexer]}
116+
if right_is_sorted & (keep == "last") & any_nulls:
117+
return {
118+
"left_index": left_index,
119+
"right_index": right_index[search_indices - 1],
120+
}
121+
if right_is_sorted & (keep == "last"):
122+
return {"left_index": left_index, "right_index": search_indices - 1}
123+
if keep == "first":
124+
right = [right_index[:ind] for ind in search_indices]
125+
right = [arr.min() for arr in right]
126+
return {"left_index": left_index, "right_index": right}
127+
if keep == "last":
128+
right = [right_index[:ind] for ind in search_indices]
129+
right = [arr.max() for arr in right]
130+
return {"left_index": left_index, "right_index": right}
131+
if return_matching_indices:
132+
return dict(
133+
left_index=left_index,
134+
right_index=right_index,
135+
starts=np.repeat(0, search_indices.size),
136+
ends=search_indices,
137+
)
138+
right = [right_index[:ind] for ind in search_indices]
139+
right = np.concatenate(right)
140+
left = left_index.repeat(search_indices)
141+
return {"left_index": left, "right_index": right}
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# helper functions for conditional_join.py
2+
3+
from enum import Enum
4+
from typing import Sequence
5+
6+
import numpy as np
7+
import pandas as pd
8+
9+
10+
class _JoinOperator(Enum):
11+
"""
12+
List of operators used in conditional_join.
13+
"""
14+
15+
GREATER_THAN = ">"
16+
LESS_THAN = "<"
17+
GREATER_THAN_OR_EQUAL = ">="
18+
LESS_THAN_OR_EQUAL = "<="
19+
STRICTLY_EQUAL = "=="
20+
NOT_EQUAL = "!="
21+
22+
23+
less_than_join_types = {
24+
_JoinOperator.LESS_THAN.value,
25+
_JoinOperator.LESS_THAN_OR_EQUAL.value,
26+
}
27+
greater_than_join_types = {
28+
_JoinOperator.GREATER_THAN.value,
29+
_JoinOperator.GREATER_THAN_OR_EQUAL.value,
30+
}
31+
32+
33+
def _maybe_remove_nulls_from_dataframe(
34+
df: pd.DataFrame, columns: Sequence, return_bools: bool = False
35+
):
36+
"""
37+
Remove nulls if op is not !=;
38+
"""
39+
any_nulls = df.loc[:, [*columns]].isna().any(axis=1)
40+
if any_nulls.all():
41+
return None
42+
if return_bools:
43+
any_nulls = ~any_nulls
44+
return any_nulls
45+
if any_nulls.any():
46+
df = df.loc[~any_nulls]
47+
return df
48+
49+
50+
def _null_checks_cond_join(series: pd.Series) -> tuple | None:
51+
"""
52+
Checks for nulls in the pandas series before conducting binary search.
53+
"""
54+
any_nulls = series.isna()
55+
if any_nulls.all():
56+
return None
57+
if any_nulls.any():
58+
series = series[~any_nulls]
59+
return series, any_nulls.any()
60+
61+
62+
def _sort_if_not_monotonic(series: pd.Series) -> pd.Series | None:
63+
"""
64+
Sort the pandas `series` if it is not monotonic increasing
65+
"""
66+
67+
is_sorted = series.is_monotonic_increasing
68+
if not is_sorted:
69+
series = series.sort_values(kind="stable")
70+
return series, is_sorted
71+
72+
73+
def _keep_output(keep: str, left: np.ndarray, right: np.ndarray):
74+
"""return indices for left and right index based on the value of `keep`."""
75+
if keep == "all":
76+
return left, right
77+
grouped = pd.Series(right).groupby(left, sort=False)
78+
if keep == "first":
79+
grouped = grouped.min()
80+
return grouped.index, grouped._values
81+
grouped = grouped.max()
82+
return grouped.index, grouped._values
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# helper functions for </<=
2+
import numpy as np
3+
import pandas as pd
4+
5+
from janitor.functions._conditional_join._helpers import (
6+
_null_checks_cond_join,
7+
_sort_if_not_monotonic,
8+
)
9+
10+
11+
def _le_lt_indices(
12+
left: pd.array,
13+
left_index: np.ndarray,
14+
right: pd.array,
15+
strict: bool,
16+
) -> tuple | None:
17+
"""
18+
Use binary search to get indices where left
19+
is less than or equal to right.
20+
21+
If strict is True, then only indices
22+
where `left` is less than
23+
(but not equal to) `right` are returned.
24+
25+
Returns the left index and the binary search positions for left in right.
26+
"""
27+
search_indices = right.searchsorted(left, side="left")
28+
# if any of the positions in `search_indices`
29+
# is equal to the length of `right_keys`
30+
# that means the respective position in `left`
31+
# has no values from `right` that are less than
32+
# or equal, and should therefore be discarded
33+
len_right = right.size
34+
booleans = search_indices < len_right
35+
if not booleans.any():
36+
return None
37+
if not booleans.all():
38+
left = left[booleans]
39+
left_index = left_index[booleans]
40+
search_indices = search_indices[booleans]
41+
# the idea here is that if there are any equal values
42+
# shift to the right to the immediate next position
43+
# that is not equal
44+
if strict:
45+
booleans = left == right[search_indices]
46+
# replace positions where rows are equal
47+
# with positions from searchsorted('right')
48+
# positions from searchsorted('right') will never
49+
# be equal and will be the furthermost in terms of position
50+
# example : right -> [2, 2, 2, 3], and we need
51+
# positions where values are not equal for 2;
52+
# the furthermost will be 3, and searchsorted('right')
53+
# will return position 3.
54+
if booleans.any():
55+
replacements = right.searchsorted(left, side="right")
56+
# now we can safely replace values
57+
# with strictly less than positions
58+
search_indices = np.where(booleans, replacements, search_indices)
59+
# check again if any of the values
60+
# have become equal to length of right
61+
# and get rid of them
62+
booleans = search_indices < len_right
63+
if not booleans.any():
64+
return None
65+
if not booleans.all():
66+
left_index = left_index[booleans]
67+
search_indices = search_indices[booleans]
68+
return left_index, search_indices
69+
70+
71+
def _less_than_indices(
72+
left: pd.Series,
73+
right: pd.Series,
74+
strict: bool,
75+
keep: str,
76+
return_matching_indices: bool,
77+
) -> dict | None:
78+
"""
79+
Use binary search to get indices where left
80+
is less than or equal to right.
81+
82+
If strict is True, then only indices
83+
where `left` is less than
84+
(but not equal to) `right` are returned.
85+
"""
86+
# no point going through all the hassle
87+
if left.min() > right.max():
88+
return {
89+
"left_index": np.array([], dtype=np.intp),
90+
"right_index": np.array([], dtype=np.intp),
91+
}
92+
outcome = _null_checks_cond_join(series=left)
93+
if not outcome:
94+
return {
95+
"left_index": np.array([], dtype=np.intp),
96+
"right_index": np.array([], dtype=np.intp),
97+
}
98+
left, _ = outcome
99+
outcome = _null_checks_cond_join(series=right)
100+
if not outcome:
101+
return {
102+
"left_index": np.array([], dtype=np.intp),
103+
"right_index": np.array([], dtype=np.intp),
104+
}
105+
right, any_nulls = outcome
106+
right, right_is_sorted = _sort_if_not_monotonic(series=right)
107+
outcome = _le_lt_indices(
108+
left=left.array,
109+
right=right.array,
110+
left_index=left.index._values,
111+
strict=strict,
112+
)
113+
if not outcome:
114+
return {
115+
"left_index": np.array([], dtype=np.intp),
116+
"right_index": np.array([], dtype=np.intp),
117+
}
118+
left_index, search_indices = outcome
119+
len_right = right.size
120+
right_index = right.index._values
121+
if right_is_sorted & (keep == "last"):
122+
indexer = np.empty_like(search_indices)
123+
indexer[:] = len_right - 1
124+
return {"left_index": left_index, "right_index": right_index[indexer]}
125+
if right_is_sorted & (keep == "first") & any_nulls:
126+
return {
127+
"left_index": left_index,
128+
"right_index": right_index[search_indices],
129+
}
130+
if right_is_sorted & (keep == "first"):
131+
return {"left_index": left_index, "right_index": search_indices}
132+
if keep == "first":
133+
right = [right_index[ind:len_right] for ind in search_indices]
134+
right = [arr.min() for arr in right]
135+
return {"left_index": left_index, "right_index": right}
136+
if keep == "last":
137+
right = [right_index[ind:len_right] for ind in search_indices]
138+
right = [arr.max() for arr in right]
139+
return {"left_index": left_index, "right_index": right}
140+
if return_matching_indices:
141+
return dict(
142+
left_index=left_index,
143+
right_index=right_index,
144+
starts=search_indices,
145+
ends=np.repeat(len_right, search_indices.size),
146+
)
147+
right = [right_index[ind:len_right] for ind in search_indices]
148+
right = np.concatenate(right)
149+
left = left_index.repeat(len_right - search_indices)
150+
return {"left_index": left, "right_index": right}

0 commit comments

Comments
 (0)