diff --git a/doc/source/users/figures/geos_partial.py b/doc/source/users/figures/geos_partial.py index c3c6ad9cf..3d4892e83 100644 --- a/doc/source/users/figures/geos_partial.py +++ b/doc/source/users/figures/geos_partial.py @@ -2,8 +2,9 @@ import numpy as np import matplotlib as mpl import matplotlib.pyplot as plt -mpl_version = tuple(map(int, mpl.__version__.split("."))) -axkwds = {"axisbg" if mpl_version < (2,) else "facecolor": "k"} +from packaging.version import Version +mpl_version = Version(mpl.__version__) +axkwds = {"axisbg" if mpl_version < Version("2") else "facecolor": "k"} fig = plt.figure() # global geostationary map centered on lon_0 diff --git a/doc/source/users/figures/nsper_partial.py b/doc/source/users/figures/nsper_partial.py index 0e9762169..dc26de972 100644 --- a/doc/source/users/figures/nsper_partial.py +++ b/doc/source/users/figures/nsper_partial.py @@ -2,8 +2,9 @@ import numpy as np import matplotlib as mpl import matplotlib.pyplot as plt -mpl_version = tuple(map(int, mpl.__version__.split("."))) -axkwds = {"axisbg" if mpl_version < (2,) else "facecolor": "k"} +from packaging.version import Version +mpl_version = Version(mpl.__version__) +axkwds = {"axisbg" if mpl_version < Version("2") else "facecolor": "k"} fig = plt.figure() # global ortho map centered on lon_0,lat_0 diff --git a/doc/source/users/figures/ortho_partial.py b/doc/source/users/figures/ortho_partial.py index c41656311..556bf3cd3 100644 --- a/doc/source/users/figures/ortho_partial.py +++ b/doc/source/users/figures/ortho_partial.py @@ -2,8 +2,9 @@ import numpy as np import matplotlib as mpl import matplotlib.pyplot as plt -mpl_version = tuple(map(int, mpl.__version__.split("."))) -axkwds = {"axisbg" if mpl_version < (2,) else "facecolor": "k"} +from packaging.version import Version +mpl_version = Version(mpl.__version__) +axkwds = {"axisbg" if mpl_version < Version("2") else "facecolor": "k"} fig = plt.figure() # global ortho map centered on lon_0,lat_0 diff --git a/src/mpl_toolkits/basemap/__init__.py b/src/mpl_toolkits/basemap/__init__.py index 46458b95b..a9c78e987 100644 --- a/src/mpl_toolkits/basemap/__init__.py +++ b/src/mpl_toolkits/basemap/__init__.py @@ -43,6 +43,8 @@ from matplotlib.transforms import Bbox from mpl_toolkits.axes_grid1 import make_axes_locatable +from packaging.version import Version + import pyproj import _geoslib from . proj import Proj @@ -1663,8 +1665,8 @@ def drawmapboundary(self,color='k',linewidth=1.0,fill_color=None,\ # if no fill_color given, use axes background color. # if fill_color is string 'none', really don't fill. if fill_color is None: - mpl_version = tuple(map(int, mpl.__version__.split(".")[:2])) - if mpl_version >= (2, 0): + mpl_version = Version(mpl.__version__) + if mpl_version >= Version("2.0"): fill_color = ax.get_facecolor() else: fill_color = ax.get_axis_bgcolor() @@ -1762,8 +1764,8 @@ def fillcontinents(self,color='0.8',lake_color=None,ax=None,zorder=None,alpha=No # get current axes instance (if none specified). ax = ax or self._check_ax() # get axis background color. - mpl_version = tuple(map(int, mpl.__version__.split(".")[:2])) - if mpl_version >= (2, 0): + mpl_version = Version(mpl.__version__) + if mpl_version >= Version("2.0"): axisbgc = ax.get_facecolor() else: axisbgc = ax.get_axis_bgcolor() diff --git a/test/mpl_toolkits/basemap/test_Basemap.py b/test/mpl_toolkits/basemap/test_Basemap.py index f18b2c29e..9e29f2897 100644 --- a/test/mpl_toolkits/basemap/test_Basemap.py +++ b/test/mpl_toolkits/basemap/test_Basemap.py @@ -16,13 +16,15 @@ from mpl_toolkits.basemap import Basemap from mpl_toolkits.basemap import shiftgrid +from packaging.version import Version + try: import PIL except ImportError: PIL = None -mpl_version = tuple(map(int, mpl.__version__.split(".")[:2])) +mpl_version = Version(mpl.__version__) class TestMplToolkitsBasemapBasemap(unittest.TestCase): @@ -199,7 +201,7 @@ def _test_basemap_data_warpimage(self, method, axs=None, axslen0=10): img = getattr(bmap, method)(ax=axs, scale=0.1) self.assertIsInstance(img, AxesImage) - flag = int(mpl_version < (3, 5)) + flag = int(mpl_version < Version("3.5")) axs_children = axs_obj.get_children() self.assertEqual(len(axs_children), axslen0 + 3) self.assertIsInstance(axs_children[1 - flag], Polygon)