Skip to content

Commit 7539c70

Browse files
committed
Add tests for _data._locate_channel_dtype
1 parent afa3868 commit 7539c70

File tree

1 file changed

+42
-34
lines changed

1 file changed

+42
-34
lines changed

mplaltair/tests/test_data.py

Lines changed: 42 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,74 +1,82 @@
11
import altair as alt
22
import pandas as pd
3-
import mplaltair._convert as convert
3+
import mplaltair._data as _data
44
import pytest
55

6-
df_quantitative = pd.DataFrame({
7-
"a": [1, 2, 3, 4, 5], "b": [1.1, 2.2, 3.3, 4.4, 5.5], "c": [1, 2.2, 3, 4.4, 5]
8-
})
9-
10-
df_temporal = pd.DataFrame({
6+
df = pd.DataFrame({
7+
"a": [1, 2, 3, 4, 5], "b": [1.1, 2.2, 3.3, 4.4, 5.5], "c": [1, 2.2, 3, 4.4, 5],
8+
"nom": ['a', 'b', 'c', 'd', 'e'], "ord": [1, 2, 3, 4, 5],
119
"years": pd.date_range('01/01/2015', periods=5, freq='Y'), "months": pd.date_range('1/1/2015', periods=5, freq='M'),
1210
"days": pd.date_range('1/1/2015', periods=5, freq='D'), "hrs": pd.date_range('1/1/2015', periods=5, freq='H'),
1311
"combination": pd.to_datetime(['1/1/2015', '1/1/2015 10:00:00', '1/2/2015 00:00', '1/4/2016 10:00', '5/1/2016']),
1412
"quantitative": [1.1, 2.1, 3.1, 4.1, 5.1]
1513
})
1614

1715

18-
@pytest.mark.parametrize("column", ['a', 'b', 'c'])
19-
def test_data_field_quantitative(column):
20-
chart = alt.Chart(df_quantitative).mark_point().encode(alt.X(field=column, type='quantitative'))
16+
# _locate_channel_data() tests
17+
18+
@pytest.mark.parametrize("column, dtype", [
19+
('a', 'quantitative'), ('b', 'quantitative'), ('c', 'quantitative'), ('combination', 'temporal')
20+
])
21+
def test_data_field_quantitative(column, dtype):
22+
chart = alt.Chart(df).mark_point().encode(alt.X(field=column, type=dtype))
2123
for channel in chart.to_dict()['encoding']:
22-
data = convert._locate_channel_data(chart.to_dict()['encoding'][channel], chart.data)
23-
assert list(data) == list(df_quantitative[column].values)
24+
data = _data._locate_channel_data(chart.to_dict()['encoding'][channel], chart.data)
25+
assert list(data) == list(df[column].values)
2426

2527

26-
@pytest.mark.parametrize("column", ['a', 'b', 'c'])
28+
@pytest.mark.parametrize("column", ['a', 'b', 'c', 'combination'])
2729
def test_data_shorthand_quantitative(column):
28-
chart = alt.Chart(df_quantitative).mark_point().encode(alt.X(column))
30+
chart = alt.Chart(df).mark_point().encode(alt.X(column))
2931
for channel in chart.to_dict()['encoding']:
30-
data = convert._locate_channel_data(chart.to_dict()['encoding'][channel], chart.data)
31-
assert list(data) == list(df_quantitative[column].values)
32+
data = _data._locate_channel_data(chart.to_dict()['encoding'][channel], chart.data)
33+
assert list(data) == list(df[column].values)
3234

3335

3436
def test_data_value_quantitative():
35-
chart = alt.Chart(df_quantitative).mark_point().encode(opacity=alt.value(0.5))
37+
chart = alt.Chart(df).mark_point().encode(opacity=alt.value(0.5))
3638
for channel in chart.to_dict()['encoding']:
37-
data = convert._locate_channel_data(chart.to_dict()['encoding'][channel], chart.data)
39+
data = _data._locate_channel_data(chart.to_dict()['encoding'][channel], chart.data)
3840
assert data == 0.5
3941

4042

4143
@pytest.mark.parametrize("column", ['a', 'b', 'c'])
4244
@pytest.mark.xfail(raises=NotImplementedError)
4345
def test_data_aggregate_quantitative(column):
44-
chart = alt.Chart(df_quantitative).mark_point().encode(alt.X(field=column, type='quantitative', aggregate='average'))
46+
chart = alt.Chart(df).mark_point().encode(alt.X(field=column, type='quantitative', aggregate='average'))
4547
for channel in chart.to_dict()['encoding']:
46-
data = convert._locate_channel_data(chart.to_dict()['encoding'][channel], chart.data)
48+
data = _data._locate_channel_data(chart.to_dict()['encoding'][channel], chart.data)
4749

4850

49-
def test_data_field_temporal():
50-
chart = alt.Chart(df_temporal).mark_point().encode(alt.X(field='combination', type='temporal'))
51+
@pytest.mark.xfail(raises=NotImplementedError)
52+
def test_data_timeUnit_shorthand_temporal():
53+
chart = alt.Chart(df).mark_point().encode(alt.X('month(combination):T'))
5154
for channel in chart.to_dict()['encoding']:
52-
data = convert._locate_channel_data(chart.to_dict()['encoding'][channel], chart.data)
53-
assert list(data) == list(df_temporal['combination'].values)
55+
data = _data._locate_channel_data(chart.to_dict()['encoding'][channel], chart.data)
5456

5557

56-
def test_data_shorthand_temporal():
57-
chart = alt.Chart(df_temporal).mark_point().encode(alt.X('combination'))
58+
@pytest.mark.xfail(raises=NotImplementedError)
59+
def test_data_timeUnit_field_temporal():
60+
chart = alt.Chart(df).mark_point().encode(alt.X(field='combination', type='temporal', timeUnit='month'))
5861
for channel in chart.to_dict()['encoding']:
59-
data = convert._locate_channel_data(chart.to_dict()['encoding'][channel], chart.data)
60-
assert list(data) == list(df_temporal['combination'].values)
62+
data = _data._locate_channel_data(chart.to_dict()['encoding'][channel], chart.data)
6163

6264

63-
@pytest.mark.xfail(raises=NotImplementedError)
64-
def test_data_timeUnit_shorthand_temporal():
65-
chart = alt.Chart(df_temporal).mark_point().encode(alt.X('month(combination):T'))
65+
# _locate_channel_dtype() tests
66+
67+
@pytest.mark.parametrize('column, expected', [
68+
('a:Q', 'quantitative'), ('nom:N', 'nominal'), ('ord:O', 'ordinal'), ('combination:T', 'temporal')
69+
])
70+
def test_data_dtype(column, expected):
71+
chart = alt.Chart(df).mark_point().encode(alt.X(column))
6672
for channel in chart.to_dict()['encoding']:
67-
data = convert._locate_channel_data(chart.to_dict()['encoding'][channel], chart.data)
73+
dtype = _data._locate_channel_dtype(chart.to_dict()['encoding'][channel], chart.data)
74+
assert dtype == expected
6875

6976

7077
@pytest.mark.xfail(raises=NotImplementedError)
71-
def test_data_timeUnit_field_temporal():
72-
chart = alt.Chart(df_temporal).mark_point().encode(alt.X(field='combination', type='temporal', timeUnit='month'))
78+
def test_data_dtype_fail():
79+
chart = alt.Chart(df).mark_point().encode(opacity=alt.value(.5))
7380
for channel in chart.to_dict()['encoding']:
74-
data = convert._locate_channel_data(chart.to_dict()['encoding'][channel], chart.data)
81+
dtype = _data._locate_channel_dtype(chart.to_dict()['encoding'][channel], chart.data)
82+
assert dtype == 'quantitative'

0 commit comments

Comments
 (0)