Skip to content

Commit 9d8ba4d

Browse files
committed
Date converter takes Altair DateTime and single values
1 parent 454025e commit 9d8ba4d

File tree

4 files changed

+113
-50
lines changed

4 files changed

+113
-50
lines changed

mplaltair/_axis.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,18 +45,14 @@ def _set_limits(channel, scale):
4545
pass # use default
4646

4747
elif channel['dtype'] == 'temporal':
48-
if 'domain' in scale and scale['type'] != 'time':
49-
"""Work is currently being done in Altair to modify what date types are allowed for domain specification.
50-
Right now, Altair can only date Altair DateTime objects for the domain.
51-
At this point, mpl-altair's date converter cannot convert Altair DateTime objects.
52-
"""
48+
if 'domain' in scale:
5349
try:
5450
domain = _convert_to_mpl_date(scale['domain'])
5551
except NotImplementedError:
5652
raise NotImplementedError
5753
lims[_axis_kwargs[channel['axis']].get('min')] = domain[0]
5854
lims[_axis_kwargs[channel['axis']].get('max')] = domain[1]
59-
elif 'type' in scale:
55+
elif 'type' in scale and scale['type'] != 'time':
6056
lims = _set_scale_type(channel, scale)
6157
else:
6258
pass # use default
@@ -72,8 +68,8 @@ def _set_limits(channel, scale):
7268

7369
def _set_scale_type(channel, scale):
7470
"""If the scale is non-linear, change the scale and return appropriate axis limits.
75-
Note: 'linear' and 'time' scale types are not included here because quantitative defaults to 'linear'
76-
and temporal defaults to 'time'.
71+
Note: The 'linear' and 'time' scale types are not included here because quantitative defaults to 'linear'
72+
and temporal defaults to 'time'. The 'utc' and 'sequential' are currently not supported.
7773
"""
7874
lims = {}
7975
if scale['type'] == 'log':

mplaltair/_data.py

Lines changed: 70 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from ._exceptions import ValidationError
22
import matplotlib.dates as mdates
3+
import matplotlib.cbook as cbook
4+
from datetime import datetime
35
import numpy as np
46

57
def _locate_channel_dtype(chart, channel):
@@ -110,24 +112,74 @@ def _locate_channel_axis(chart, channel):
110112
return {}
111113

112114
def _convert_to_mpl_date(data):
113-
"""Converts datetime, datetime64, strings, and Altair DateTime objects to matplotlib dates"""
114-
115-
# TODO: parse both single values and sequences/iterables
116-
new_data = []
117-
for i in data:
118-
if isinstance(i, str): # string format for dates
119-
new_data.append(mdates.datestr2num(i))
120-
elif isinstance(i, np.datetime64): # sequence of datetimes, datetime64s
121-
new_data.append(mdates.date2num(i))
122-
elif isinstance(i, dict): # Altair DateTime
123-
"""Allowed formats (for domain):
124-
YYYY,
125-
YYYY-MM(-01), YYYY-MM-DD, YYYY(-01)-DD,
126-
^ plus hh, hh:mm, hh:mm:ss, hh(:00):ss, (0):mm:ss
127-
Could turn dict into iso datetime string and then use dateutil.parser.isoparse() or datestr2num()
128-
"""
129-
raise NotImplementedError
115+
"""Converts datetime, datetime64, strings, and Altair DateTime objects to Matplotlib dates.
116+
117+
Parameters
118+
----------
119+
data
120+
The data to be converted to a Matplotlib date.
121+
122+
Returns
123+
-------
124+
new_data : list
125+
A list containing the converted date(s).
126+
"""
127+
128+
if cbook.iterable(data) and not isinstance(data, str) and not isinstance(data, dict):
129+
if len(data) == 0:
130+
return []
131+
else:
132+
return [_convert_to_mpl_date(i) for i in data]
133+
else:
134+
if isinstance(data, str): # string format for dates
135+
data = mdates.datestr2num(data)
136+
elif isinstance(data, np.datetime64): # sequence of datetimes, datetime64s
137+
data = mdates.date2num(data)
138+
elif isinstance(data, dict): # Altair DateTime
139+
data = mdates.date2num(_altair_DateTime_to_datetime(data))
130140
else:
131141
raise TypeError
142+
return data
143+
132144

133-
return new_data
145+
def _altair_DateTime_to_datetime(dt):
146+
"""Convert dictionary representation of an Altair DateTime to datetime object.
147+
Parameters
148+
----------
149+
dt : dict
150+
The dictionary representation of the Altair DateTime object to be converted.
151+
152+
Returns
153+
-------
154+
A datetime object
155+
"""
156+
MONTHS = {'Jan': 1, 'January': 1, 'Feb': 2, 'February': 2, 'Mar': 3, 'March': 3, 'Apr': 4, 'April': 4,
157+
'May': 5, 'May': 5, 'Jun': 6, 'June': 6, 'Jul': 7, 'July': 7, 'Aug': 8, 'August': 8,
158+
'Sep': 9, 'Sept': 9, 'September': 9, 'Oct': 10, 'October': 10, 'Nov': 11, 'November': 11,
159+
'Dec': 12, 'December': 12}
160+
161+
alt_to_datetime_kw_mapping = {'date': 'day', 'hours': 'hour', 'milliseconds': 'microsecond', 'minutes': 'minute',
162+
'month': 'month', 'seconds': 'second', 'year': 'year'}
163+
164+
datetime_kwargs = {'year': 0, 'month': 1, 'day': 1, 'hour': 0, 'minute': 0, 'second': 0, 'microsecond': 0}
165+
166+
if 'day' in dt or 'quarter' in dt:
167+
raise NotImplementedError
168+
if 'year' not in dt:
169+
raise KeyError('A year must be provided.')
170+
if 'month' not in dt:
171+
dt['month'] = 1 # Default to January
172+
else:
173+
if isinstance(dt['month'], str): # convert from str to number form for months
174+
dt['month'] = MONTHS[dt['month']]
175+
if 'date' not in dt:
176+
dt['date'] = 1 # Default to the first of the month
177+
if 'milliseconds' in dt:
178+
dt['milliseconds'] = dt['milliseconds']*1000 # convert to microseconds
179+
if 'utc' in dt:
180+
raise NotImplementedError("mpl-altair currently doesn't support timezones.")
181+
182+
for k, v in dt.items():
183+
datetime_kwargs[alt_to_datetime_kw_mapping[k]] = v
184+
185+
return datetime(**datetime_kwargs)

mplaltair/tests/test_axis_temporal.py

Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -60,33 +60,23 @@ def test_axis(x, y):
6060

6161
@pytest.mark.parametrize('domain', [
6262
(['2014-12-25', '2015-03-01']),
63-
pytest.param([alt.DateTime(year=2014, month="Dec", date=25), alt.DateTime(year=2015, month="March", date=1)], marks=pytest.mark.xfail)
63+
([alt.DateTime(year=2014, month="Dec", date=25), alt.DateTime(year=2015, month="March", date=1)])
6464
])
6565
def test_axis_temporal_domain(domain):
66-
chart = alt.Chart(df).mark_point().encode(alt.X('a'), alt.Y('days'))
66+
chart = alt.Chart(df).mark_point().encode(alt.X('a'), alt.Y('days', scale=alt.Scale(domain=domain)))
6767
mapping = convert(chart)
6868
fig, ax = plt.subplots()
6969
ax.scatter(**mapping)
70+
convert_axis(ax, chart)
7071

71-
for channel in chart.to_dict()['encoding']:
72-
if channel in ['x', 'y']:
73-
chart_info = {'ax': ax, 'axis': channel,
74-
'data': _locate_channel_data(chart, channel),
75-
'dtype': _locate_channel_dtype(chart, channel)}
76-
if chart_info['dtype'] == 'temporal':
77-
chart_info['data'] = _convert_to_mpl_date(chart_info['data'])
78-
79-
scale_info = _locate_channel_scale(chart, channel)
80-
if channel == 'y':
81-
scale_info['domain'] = domain
82-
axis_info = _locate_channel_axis(chart, channel)
83-
84-
_set_limits(chart_info, scale_info)
85-
_set_tick_locator(chart_info, axis_info)
86-
_set_tick_formatter(chart_info, axis_info)
8772
yvmin, yvmax = ax.yaxis.get_view_interval()
88-
assert yvmin == _convert_to_mpl_date(domain)[0]
89-
assert yvmax == _convert_to_mpl_date(domain)[1]
73+
try:
74+
expected_domain = [domain[0].to_dict(), domain[1].to_dict()]
75+
except:
76+
expected_domain = domain
77+
assert yvmin == _convert_to_mpl_date(expected_domain)[0]
78+
assert yvmax == _convert_to_mpl_date(expected_domain)[1]
79+
plt.show()
9080

9181
@pytest.mark.parametrize('x,tickCount', [
9282
('years', 1), ('years', 3), ('years', 5), ('years', 10),
@@ -131,8 +121,7 @@ def test_axis_temporal_values():
131121
assert list(ax.yaxis.get_major_locator().tick_values(1, 1)) == list(_convert_to_mpl_date(['1/12/2015', '3/1/2015', '4/18/2015', '5/3/2015']))
132122

133123

134-
@pytest.mark.xfail(raises=NotImplementedError)
135-
@pytest.mark.parametrize('type', ['time', 'utc'])
124+
@pytest.mark.parametrize('type', ['time', pytest.param('utc', marks=pytest.mark.xfail(raises=NotImplementedError))])
136125
def test_axis_scale_NotImplemented_temporal(type):
137126
chart = alt.Chart(df).mark_point().encode(
138127
alt.X('years:T', scale=alt.Scale(type=type)),

mplaltair/tests/test_data.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,34 @@ def test_data_dtype_fail():
100100
'e': pd.to_datetime(['1/4/2016 10:00', '5/1/2016 10:10', '3/3/2016'])
101101
})
102102

103-
def test_str():
103+
def test_convert_to_mpl_str():
104104
assert list(_data._convert_to_mpl_date(df_nonstandard['c'].values)) == list(mdates.datestr2num(df_nonstandard['c']))
105105

106-
def test_datetime64():
106+
def test_convert_to_mpl_datetime64():
107107
assert list(_data._convert_to_mpl_date(df_nonstandard['e'].values)) == list(mdates.date2num(df_nonstandard['e']))
108+
109+
def test_convert_to_mpl_altair_datetime():
110+
dates = [alt.DateTime(year=2015, date=7).to_dict(), alt.DateTime(year=2015, month="March", date=20).to_dict()]
111+
assert list(_data._convert_to_mpl_date(dates)) == list(mdates.datestr2num(['2015-01-07', '2015-03-20']))
112+
113+
114+
@pytest.mark.parametrize('date,expected', [
115+
(df_nonstandard['c'].values[0], mdates.datestr2num(df_nonstandard['c'].values[0])),
116+
(df_nonstandard['e'].values[0], mdates.date2num(df_nonstandard['e'].values[0])),
117+
(alt.DateTime(year=2015, month="March", date=7).to_dict(), mdates.datestr2num('2015-03-07'))
118+
])
119+
def test_convert_to_mpl_single_vals(date, expected):
120+
assert _data._convert_to_mpl_date(date) == expected
121+
122+
@pytest.mark.parametrize('date,expected', [
123+
(alt.DateTime(year=2015, month="March", date=7).to_dict(), '2015-03-07'),
124+
(alt.DateTime(year=2015, date=7).to_dict(), '2015-01-07'),
125+
(alt.DateTime(year=2015, month=3).to_dict(), '2015-03-01'),
126+
(alt.DateTime(year=2015, date=7, milliseconds=1).to_dict(), '2015-01-07 00:00:00.001'),
127+
pytest.param(alt.DateTime(day="Mon").to_dict(), '2015-01-07', marks=pytest.mark.xfail(raises=NotImplementedError)),
128+
pytest.param(alt.DateTime(year=2015, date=20, quarter=1).to_dict(), '2015-01-20', marks=pytest.mark.xfail(raises=NotImplementedError)),
129+
pytest.param(alt.DateTime(year=2015, date=20, utc=True).to_dict(), '2015-01-07', marks=pytest.mark.xfail(raises=NotImplementedError)),
130+
pytest.param(alt.DateTime(date=20).to_dict(), '2015-01-07', marks=pytest.mark.xfail(raises=KeyError)),
131+
])
132+
def test_altair_datetime(date, expected):
133+
assert mdates.date2num(_data._altair_DateTime_to_datetime(date)) == mdates.datestr2num(expected)

0 commit comments

Comments
 (0)