Skip to content

Commit 3bc1d49

Browse files
committed
Add string formatter for temporal and quantitative
1 parent b453e3a commit 3bc1d49

File tree

5 files changed

+86
-5
lines changed

5 files changed

+86
-5
lines changed

mplaltair/_axis.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def _set_scale_type(channel, scale):
107107
elif scale['type'] == 'utc':
108108
raise NotImplementedError
109109
elif scale['type'] == 'sequential':
110-
raise NotImplementedError
110+
raise NotImplementedError("sequential scales used primarily for continuous colors")
111111
else:
112112
raise NotImplementedError
113113
return lims
@@ -133,19 +133,49 @@ def _set_tick_locator(channel, axis):
133133

134134

135135
def _set_tick_formatter(channel, axis):
136+
current_axis = {'x': channel['ax'].xaxis, 'y': channel['ax'].yaxis}
137+
format_str = ''
138+
139+
if 'format' in axis:
140+
format_str = axis['format']
141+
136142
if channel['dtype'] == 'temporal':
137-
formatter = mdates.DateFormatter('%b %d, %Y')
143+
if not format_str:
144+
format_str = '%b %d, %Y'
145+
146+
current_axis[channel['axis']].set_major_formatter(mdates.DateFormatter(format_str))
138147

139-
current_axis = {'x': channel['ax'].xaxis, 'y': channel['ax'].yaxis}
140-
current_axis[channel['axis']].set_major_formatter(formatter)
148+
try:
149+
current_axis[channel['axis']].get_major_formatter().__call__(1)
150+
except ValueError:
151+
raise ValueError("Matplotlib only supports `strftime` formatting for dates."
152+
"Currently, %L, %Q, and %s are allowed in Altair, but not allowed in Matplotlib."
153+
"Please use a :func:`strftime` compliant format string.")
141154

155+
156+
# TODO: move rotation to another function?
142157
if channel['axis'] == 'x':
143158
for label in channel['ax'].get_xticklabels():
144159
# Rotate the labels on the x-axis so they don't run into each other.
145160
label.set_rotation(30)
146161
label.set_ha('right')
162+
163+
elif channel['dtype'] == 'quantitative':
164+
if format_str:
165+
current_axis[channel['axis']].set_major_formatter(ticker.StrMethodFormatter('{x:' + format_str + '}'))
166+
167+
# Verify that the format string is valid for Matplotlib and exit nicely if not.
168+
try:
169+
current_axis[channel['axis']].get_major_formatter().__call__(1)
170+
except ValueError:
171+
raise ValueError("Matplotlib only supports format strings as used by `str.format()`."
172+
"Some format strings that work in Altair may not work in Matplotlib."
173+
"Please use a different format string.")
174+
else:
175+
# Use the default formatter for quantitative (it has similar, if not the same settings as Altair)
176+
pass
147177
else:
148-
pass # Use the auto formatter for quantitative (it has similar, if not the same settings as Altair)
178+
pass
149179

150180

151181
def convert_axis(ax, chart):
11.5 KB
Loading
27.4 KB
Loading

mplaltair/tests/test_axis.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,3 +180,29 @@ def test_axis_scale_NotImplemented_quantitative(type):
180180
fig, ax = plt.subplots()
181181
ax.scatter(**mapping)
182182
convert_axis(ax, chart)
183+
184+
185+
@pytest.mark.mpl_image_compare(baseline_dir='baseline_images/test_axis')
186+
def test_axis_formatter_quantitative():
187+
chart = alt.Chart(df_quant).mark_point().encode(
188+
alt.X('c', axis=alt.Axis(format='-.2g')),
189+
alt.Y('b', axis=alt.Axis(format='+.3g'))
190+
)
191+
mapping = convert(chart)
192+
fig, ax = plt.subplots()
193+
ax.scatter(**mapping)
194+
convert_axis(ax, chart)
195+
fig.tight_layout()
196+
return fig
197+
198+
199+
@pytest.mark.xfail(raises=ValueError)
200+
def test_axis_formatter_quantitative_fail():
201+
chart = alt.Chart(df_quant).mark_point().encode(
202+
alt.X('c', axis=alt.Axis(format='-$.2g')),
203+
alt.Y('b', axis=alt.Axis(format='+.3r'))
204+
)
205+
mapping = convert(chart)
206+
fig, ax = plt.subplots()
207+
ax.scatter(**mapping)
208+
convert_axis(ax, chart)

mplaltair/tests/test_axis_temporal.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,3 +129,28 @@ def test_axis_temporal_timezone(x):
129129
convert_axis(ax, chart)
130130
fig.tight_layout()
131131
return fig
132+
133+
@pytest.mark.mpl_image_compare(baseline_dir='baseline_images/test_axis_temporal')
134+
def test_axis_temporal_formatter():
135+
chart = alt.Chart(df).mark_point().encode(
136+
alt.X('months:T', axis=alt.Axis(format='%b %Y')),
137+
alt.Y('hrs:T', axis=alt.Axis(format='%H:%M:%S'))
138+
)
139+
mapping = convert(chart)
140+
fig, ax = plt.subplots()
141+
ax.scatter(**mapping)
142+
convert_axis(ax, chart)
143+
fig.tight_layout()
144+
return fig
145+
146+
@pytest.mark.xfail(raises=ValueError)
147+
def test_axis_formatter_temporal_fail():
148+
chart = alt.Chart(df).mark_point().encode(
149+
alt.X('months:T', axis=alt.Axis(format='%L')),
150+
alt.Y('months:T', axis=alt.Axis(format='%s'))
151+
)
152+
mapping = convert(chart)
153+
fig, ax = plt.subplots()
154+
ax.scatter(**mapping)
155+
convert_axis(ax, chart)
156+
plt.show()

0 commit comments

Comments
 (0)