Skip to content

Commit 368c7c6

Browse files
committed
MNT: Improve Grouper
- add Grouper.get_siblings(..., include_self=...) parameter - `Axes._shared_axes` is a class variable. Therefore, we can make `get_shared_x/y_axes()` class methods. This is in preparations of matplotlib#30159 (comment).
1 parent a00d606 commit 368c7c6

File tree

7 files changed

+30
-16
lines changed

7 files changed

+30
-16
lines changed

lib/matplotlib/axes/_base.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4756,13 +4756,15 @@ def twiny(self, axes_class=None, **kwargs):
47564756
ax2.yaxis.units = self.yaxis.units
47574757
return ax2
47584758

4759-
def get_shared_x_axes(self):
4759+
@classmethod
4760+
def get_shared_x_axes(cls):
47604761
"""Return an immutable view on the shared x-axes Grouper."""
4761-
return cbook.GrouperView(self._shared_axes["x"])
4762+
return cbook.GrouperView(cls._shared_axes["x"])
47624763

4763-
def get_shared_y_axes(self):
4764+
@classmethod
4765+
def get_shared_y_axes(cls):
47644766
"""Return an immutable view on the shared y-axes Grouper."""
4765-
return cbook.GrouperView(self._shared_axes["y"])
4767+
return cbook.GrouperView(cls._shared_axes["y"])
47664768

47674769
def label_outer(self, remove_inner_ticks=False):
47684770
"""

lib/matplotlib/axes/_base.pyi

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -388,8 +388,10 @@ class _AxesBase(martist.Artist):
388388
) -> Bbox | None: ...
389389
def twinx(self, axes_class: Axes | None = ..., **kwargs) -> Axes: ...
390390
def twiny(self, axes_class: Axes | None = ..., **kwargs) -> Axes: ...
391-
def get_shared_x_axes(self) -> cbook.GrouperView: ...
392-
def get_shared_y_axes(self) -> cbook.GrouperView: ...
391+
@classmethod
392+
def get_shared_x_axes(cls) -> cbook.GrouperView: ...
393+
@classmethod
394+
def get_shared_y_axes(cls) -> cbook.GrouperView: ...
393395
def label_outer(self, remove_inner_ticks: bool = ...) -> None: ...
394396

395397
# The methods underneath this line are added via the `_axis_method_wrapper` class

lib/matplotlib/axis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -703,7 +703,7 @@ def isDefault_minfmt(self, value):
703703
self.minor._formatter_is_default = value
704704

705705
def _get_shared_axes(self):
706-
"""Return Grouper of shared Axes for current axis."""
706+
"""Return a list of shared Axes for current axis."""
707707
return self.axes._shared_axes[
708708
self._get_axis_name()].get_siblings(self.axes)
709709

lib/matplotlib/cbook.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -886,10 +886,17 @@ def __iter__(self):
886886
for group in unique_groups.values():
887887
yield sorted(group, key=self._ordering.__getitem__)
888888

889-
def get_siblings(self, a):
890-
"""Return all of the items joined with *a*, including itself."""
889+
def get_siblings(self, a, *, include_self=True):
890+
"""
891+
Return all the items joined with *a*.
892+
893+
*a* is included in the list if *include_self* is True.
894+
"""
891895
siblings = self._mapping.get(a, [a])
892-
return sorted(siblings, key=self._ordering.get)
896+
result = sorted(siblings, key=self._ordering.get)
897+
if not include_self:
898+
result.remove(a)
899+
return result
893900

894901

895902
class GrouperView:
@@ -905,11 +912,13 @@ def joined(self, a, b):
905912
"""
906913
return self._grouper.joined(a, b)
907914

908-
def get_siblings(self, a):
915+
def get_siblings(self, a, *, include_self=True):
909916
"""
910-
Return all of the items joined with *a*, including itself.
917+
Return all the items joined with *a*.
918+
919+
*a* is included in the list if *include_self* is True.
911920
"""
912-
return self._grouper.get_siblings(a)
921+
return self._grouper.get_siblings(a, include_self=include_self)
913922

914923

915924
def simple_linear_interpolation(a, steps):

lib/matplotlib/cbook.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,14 +107,14 @@ class Grouper(Generic[_T]):
107107
def joined(self, a: _T, b: _T) -> bool: ...
108108
def remove(self, a: _T) -> None: ...
109109
def __iter__(self) -> Iterator[list[_T]]: ...
110-
def get_siblings(self, a: _T) -> list[_T]: ...
110+
def get_siblings(self, a: _T, *, include_self: bool = True) -> list[_T]: ...
111111

112112
class GrouperView(Generic[_T]):
113113
def __init__(self, grouper: Grouper[_T]) -> None: ...
114114
def __contains__(self, item: _T) -> bool: ...
115115
def __iter__(self) -> Iterator[list[_T]]: ...
116116
def joined(self, a: _T, b: _T) -> bool: ...
117-
def get_siblings(self, a: _T) -> list[_T]: ...
117+
def get_siblings(self, a: _T, *, include_self: bool = True) -> list[_T]: ...
118118

119119
def simple_linear_interpolation(a: ArrayLike, steps: int) -> np.ndarray: ...
120120
def delete_masked_points(*args): ...

lib/matplotlib/figure.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -947,7 +947,7 @@ def _remove_axes(self, ax, owners):
947947

948948
for name in ax._axis_names: # Break link between any shared Axes
949949
grouper = ax._shared_axes[name]
950-
siblings = [other for other in grouper.get_siblings(ax) if other is not ax]
950+
siblings = grouper.get_siblings(ax, include_self=False)
951951
if not siblings: # Axes was not shared along this axis; we're done.
952952
continue
953953
grouper.remove(ax)

lib/matplotlib/tests/test_cbook.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -622,6 +622,7 @@ class Dummy:
622622
g.join(*objs)
623623
assert set(list(g)[0]) == set(objs)
624624
assert set(g.get_siblings(a)) == set(objs)
625+
assert a not in g.get_siblings(a, include_self=False)
625626

626627
for other in objs[1:]:
627628
assert g.joined(a, other)

0 commit comments

Comments
 (0)