Skip to content

Commit da0533a

Browse files
authored
Merge pull request #10 from kdorr/encodings
Merge convert-numeric into encodings
2 parents 8acd496 + fbe3b61 commit da0533a

File tree

7 files changed

+513
-17
lines changed

7 files changed

+513
-17
lines changed

.coveragerc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
branch = true
33
source =
44
mplaltair
5+
omit = *tests*
56

67
[report]
78
exclude_lines =

mplaltair/__init__.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,23 @@
11
import matplotlib
22
import altair
33

4+
from ._convert import _convert
45

5-
# TODO rename this?
6-
def convert(encoding, *, figure=None):
6+
def convert(chart):
77
"""Convert an altair encoding to a Matplotlib figure
88
99
1010
Parameters
1111
----------
12-
encoding
13-
The Altair encoding of the plot.
14-
15-
figure : matplotib.figure.Figure, optional
16-
# TODO: generalize this to 'thing that supports gridspec slicing?
12+
chart
13+
The Altair chart object generated by Altair
1714
1815
Returns
1916
-------
20-
figure : matplotlib.figure.Figure
21-
The Figure with all artists in it (ready to be saved or shown)
22-
2317
mapping : dict
2418
Mapping from parts of the encoding to the Matplotlib artists. This is
2519
for later customization.
2620
2721
2822
"""
29-
if figure is None:
30-
from matplotlib import pyplot as plt
31-
figure = plt.figure()
32-
33-
mapping = {}
34-
35-
return figure, mapping
23+
return _convert(chart)

mplaltair/_convert.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,129 @@
1+
import matplotlib.dates as mdates
2+
from ._data import _locate_channel_data, _locate_channel_dtype
13

4+
def _allowed_ranged_marks(enc_channel, mark):
5+
"""TODO: DOCS
6+
"""
7+
return mark in ['area', 'bar', 'rect', 'rule'] if enc_channel in ['x2', 'y2'] else True
8+
9+
def _process_x(dtype, data):
10+
"""Returns the MPL encoding equivalent for Altair x channel
11+
"""
12+
return ('x', data)
13+
14+
15+
def _process_y(dtype, data):
16+
"""Returns the MPL encoding equivalent for Altair y channel
17+
"""
18+
return ('y', data)
19+
20+
21+
def _process_x2(dtype, data):
22+
"""Returns the MPL encoding equivalent for Altair x2 channel
23+
"""
24+
raise NotImplementedError
25+
26+
27+
def _process_y2(dtype, data):
28+
"""Returns the MPL encoding equivalent for Altair y2 channel
29+
"""
30+
raise NotImplementedError
31+
32+
33+
def _process_color(dtype, data):
34+
"""Returns the MPL encoding equivalent for Altair color channel
35+
"""
36+
if dtype == 'quantitative':
37+
return ('c', data)
38+
elif dtype == 'nominal':
39+
raise NotImplementedError
40+
elif dtype == 'ordinal':
41+
return ('c', data)
42+
else: # temporal
43+
return ('c', data)
44+
45+
46+
def _process_fill(dtype, data):
47+
"""Returns the MPL encoding equivalent for Altair fill channel
48+
"""
49+
return _process_color(dtype, data)
50+
51+
52+
def _process_shape(dtype, data):
53+
"""Returns the MPL encoding equivalent for Altair shape channel
54+
"""
55+
raise NotImplementedError
56+
57+
58+
def _process_opacity(dtype, data):
59+
"""Returns the MPL encoding equivalent for Altair opacity channel
60+
"""
61+
raise NotImplementedError
62+
63+
64+
def _process_size(dtype, data):
65+
"""Returns the MPL encoding equivalent for Altair size channel
66+
"""
67+
if dtype == 'quantitative':
68+
return ('s', data)
69+
elif dtype == 'nominal':
70+
raise NotImplementedError
71+
elif dtype == 'ordinal':
72+
return ('s', data)
73+
elif dtype == 'temporal':
74+
raise NotImplementedError
75+
76+
77+
def _process_stroke(dtype, data):
78+
"""Returns the MPL encoding equivalent for Altair stroke channel
79+
"""
80+
raise NotImplementedError
81+
82+
_mappings = {
83+
'x': _process_x,
84+
'y': _process_y,
85+
'x2': _process_x2,
86+
'y2': _process_y2,
87+
'color': _process_color,
88+
'fill': _process_fill,
89+
'shape': _process_shape,
90+
'opacity': _process_opacity,
91+
'size': _process_size,
92+
'stroke': _process_stroke,
93+
}
94+
95+
def _convert(chart):
96+
"""Convert an altair encoding to a Matplotlib figure
97+
98+
99+
Parameters
100+
----------
101+
chart
102+
The Altair chart.
103+
104+
Returns
105+
-------
106+
mapping : dict
107+
Mapping from parts of the encoding to the Matplotlib artists. This is
108+
for later customization.
109+
"""
110+
mapping = {}
111+
112+
if not chart.to_dict().get('encoding'):
113+
raise ValueError("Encoding not provided with the chart specification")
114+
115+
for enc_channel, enc_spec in chart.to_dict()['encoding'].items():
116+
if not _allowed_ranged_marks(enc_channel, chart.to_dict()['mark']):
117+
raise ValueError("Ranged encoding channels like x2, y2 not allowed for Mark: {}".format(chart['mark']))
118+
119+
for channel in chart.to_dict()['encoding']:
120+
data = _locate_channel_data(chart, channel)
121+
dtype = _locate_channel_dtype(chart, channel)
122+
if dtype == 'temporal':
123+
try:
124+
data = mdates.date2num(data) # Convert dates to Matplotlib dates
125+
except AttributeError:
126+
raise
127+
mapping[_mappings[channel](dtype, data)[0]] = _mappings[channel](dtype, data)[1]
128+
129+
return mapping

mplaltair/_data.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
from ._exceptions import ValidationError
2+
3+
def _locate_channel_dtype(chart, channel):
4+
"""Locates dtype used for each channel
5+
Parameters
6+
----------
7+
chart
8+
The Altair chart
9+
channel
10+
The Altair channel being examined
11+
12+
Returns
13+
-------
14+
A string representing the data type from the Altair chart ('quantitative', 'ordinal', 'numeric', 'temporal')
15+
"""
16+
17+
channel_val = chart.to_dict()['encoding'][channel]
18+
if channel_val.get('type'):
19+
return channel_val.get('type')
20+
else:
21+
# TODO: find some way to deal with 'value' so that, opacity, for instance, can be plotted with a value defined
22+
if channel_val.get('value'):
23+
raise NotImplementedError
24+
raise NotImplementedError
25+
26+
27+
def _locate_channel_data(chart, channel):
28+
"""Locates data used for each channel
29+
30+
Parameters
31+
----------
32+
chart
33+
The Altair chart
34+
channel
35+
The Altair channel being examined
36+
37+
Returns
38+
-------
39+
A numpy ndarray containing the data used for the channel
40+
41+
Raises
42+
------
43+
ValidationError
44+
Raised when the specification does not contain any data attribute
45+
46+
"""
47+
48+
channel_val = chart.to_dict()['encoding'][channel]
49+
if channel_val.get('value'):
50+
return channel_val.get('value')
51+
elif channel_val.get('aggregate'):
52+
return _aggregate_channel()
53+
elif channel_val.get('timeUnit'):
54+
return _handle_timeUnit()
55+
else: # field is required if the above are not present.
56+
return chart.data[channel_val.get('field')].values
57+
58+
59+
def _aggregate_channel():
60+
raise NotImplementedError
61+
62+
63+
def _handle_timeUnit():
64+
raise NotImplementedError

mplaltair/_exceptions.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
class ValidationError(Exception):
2+
pass

0 commit comments

Comments
 (0)