11import altair as alt
22import pandas as pd
3- import mplaltair ._convert as convert
3+ import mplaltair ._data as _data
44import pytest
55
6- df_quantitative = pd .DataFrame ({
7- "a" : [1 , 2 , 3 , 4 , 5 ], "b" : [1.1 , 2.2 , 3.3 , 4.4 , 5.5 ], "c" : [1 , 2.2 , 3 , 4.4 , 5 ]
8- })
9-
10- df_temporal = pd .DataFrame ({
6+ df = pd .DataFrame ({
7+ "a" : [1 , 2 , 3 , 4 , 5 ], "b" : [1.1 , 2.2 , 3.3 , 4.4 , 5.5 ], "c" : [1 , 2.2 , 3 , 4.4 , 5 ],
8+ "nom" : ['a' , 'b' , 'c' , 'd' , 'e' ], "ord" : [1 , 2 , 3 , 4 , 5 ],
119 "years" : pd .date_range ('01/01/2015' , periods = 5 , freq = 'Y' ), "months" : pd .date_range ('1/1/2015' , periods = 5 , freq = 'M' ),
1210 "days" : pd .date_range ('1/1/2015' , periods = 5 , freq = 'D' ), "hrs" : pd .date_range ('1/1/2015' , periods = 5 , freq = 'H' ),
1311 "combination" : pd .to_datetime (['1/1/2015' , '1/1/2015 10:00:00' , '1/2/2015 00:00' , '1/4/2016 10:00' , '5/1/2016' ]),
1412 "quantitative" : [1.1 , 2.1 , 3.1 , 4.1 , 5.1 ]
1513})
1614
1715
18- @pytest .mark .parametrize ("column" , ['a' , 'b' , 'c' ])
19- def test_data_field_quantitative (column ):
20- chart = alt .Chart (df_quantitative ).mark_point ().encode (alt .X (field = column , type = 'quantitative' ))
16+ # _locate_channel_data() tests
17+
18+ @pytest .mark .parametrize ("column, dtype" , [
19+ ('a' , 'quantitative' ), ('b' , 'quantitative' ), ('c' , 'quantitative' ), ('combination' , 'temporal' )
20+ ])
21+ def test_data_field_quantitative (column , dtype ):
22+ chart = alt .Chart (df ).mark_point ().encode (alt .X (field = column , type = dtype ))
2123 for channel in chart .to_dict ()['encoding' ]:
22- data = convert ._locate_channel_data (chart .to_dict ()['encoding' ][channel ], chart .data )
23- assert list (data ) == list (df_quantitative [column ].values )
24+ data = _data ._locate_channel_data (chart .to_dict ()['encoding' ][channel ], chart .data )
25+ assert list (data ) == list (df [column ].values )
2426
2527
26- @pytest .mark .parametrize ("column" , ['a' , 'b' , 'c' ])
28+ @pytest .mark .parametrize ("column" , ['a' , 'b' , 'c' , 'combination' ])
2729def test_data_shorthand_quantitative (column ):
28- chart = alt .Chart (df_quantitative ).mark_point ().encode (alt .X (column ))
30+ chart = alt .Chart (df ).mark_point ().encode (alt .X (column ))
2931 for channel in chart .to_dict ()['encoding' ]:
30- data = convert ._locate_channel_data (chart .to_dict ()['encoding' ][channel ], chart .data )
31- assert list (data ) == list (df_quantitative [column ].values )
32+ data = _data ._locate_channel_data (chart .to_dict ()['encoding' ][channel ], chart .data )
33+ assert list (data ) == list (df [column ].values )
3234
3335
3436def test_data_value_quantitative ():
35- chart = alt .Chart (df_quantitative ).mark_point ().encode (opacity = alt .value (0.5 ))
37+ chart = alt .Chart (df ).mark_point ().encode (opacity = alt .value (0.5 ))
3638 for channel in chart .to_dict ()['encoding' ]:
37- data = convert ._locate_channel_data (chart .to_dict ()['encoding' ][channel ], chart .data )
39+ data = _data ._locate_channel_data (chart .to_dict ()['encoding' ][channel ], chart .data )
3840 assert data == 0.5
3941
4042
4143@pytest .mark .parametrize ("column" , ['a' , 'b' , 'c' ])
4244@pytest .mark .xfail (raises = NotImplementedError )
4345def test_data_aggregate_quantitative (column ):
44- chart = alt .Chart (df_quantitative ).mark_point ().encode (alt .X (field = column , type = 'quantitative' , aggregate = 'average' ))
46+ chart = alt .Chart (df ).mark_point ().encode (alt .X (field = column , type = 'quantitative' , aggregate = 'average' ))
4547 for channel in chart .to_dict ()['encoding' ]:
46- data = convert ._locate_channel_data (chart .to_dict ()['encoding' ][channel ], chart .data )
48+ data = _data ._locate_channel_data (chart .to_dict ()['encoding' ][channel ], chart .data )
4749
4850
49- def test_data_field_temporal ():
50- chart = alt .Chart (df_temporal ).mark_point ().encode (alt .X (field = 'combination' , type = 'temporal' ))
51+ @pytest .mark .xfail (raises = NotImplementedError )
52+ def test_data_timeUnit_shorthand_temporal ():
53+ chart = alt .Chart (df ).mark_point ().encode (alt .X ('month(combination):T' ))
5154 for channel in chart .to_dict ()['encoding' ]:
52- data = convert ._locate_channel_data (chart .to_dict ()['encoding' ][channel ], chart .data )
53- assert list (data ) == list (df_temporal ['combination' ].values )
55+ data = _data ._locate_channel_data (chart .to_dict ()['encoding' ][channel ], chart .data )
5456
5557
56- def test_data_shorthand_temporal ():
57- chart = alt .Chart (df_temporal ).mark_point ().encode (alt .X ('combination' ))
58+ @pytest .mark .xfail (raises = NotImplementedError )
59+ def test_data_timeUnit_field_temporal ():
60+ chart = alt .Chart (df ).mark_point ().encode (alt .X (field = 'combination' , type = 'temporal' , timeUnit = 'month' ))
5861 for channel in chart .to_dict ()['encoding' ]:
59- data = convert ._locate_channel_data (chart .to_dict ()['encoding' ][channel ], chart .data )
60- assert list (data ) == list (df_temporal ['combination' ].values )
62+ data = _data ._locate_channel_data (chart .to_dict ()['encoding' ][channel ], chart .data )
6163
6264
63- @pytest .mark .xfail (raises = NotImplementedError )
64- def test_data_timeUnit_shorthand_temporal ():
65- chart = alt .Chart (df_temporal ).mark_point ().encode (alt .X ('month(combination):T' ))
65+ # _locate_channel_dtype() tests
66+
67+ @pytest .mark .parametrize ('column, expected' , [
68+ ('a:Q' , 'quantitative' ), ('nom:N' , 'nominal' ), ('ord:O' , 'ordinal' ), ('combination:T' , 'temporal' )
69+ ])
70+ def test_data_dtype (column , expected ):
71+ chart = alt .Chart (df ).mark_point ().encode (alt .X (column ))
6672 for channel in chart .to_dict ()['encoding' ]:
67- data = convert ._locate_channel_data (chart .to_dict ()['encoding' ][channel ], chart .data )
73+ dtype = _data ._locate_channel_dtype (chart .to_dict ()['encoding' ][channel ], chart .data )
74+ assert dtype == expected
6875
6976
7077@pytest .mark .xfail (raises = NotImplementedError )
71- def test_data_timeUnit_field_temporal ():
72- chart = alt .Chart (df_temporal ).mark_point ().encode (alt .X ( field = 'combination' , type = 'temporal' , timeUnit = 'month' ))
78+ def test_data_dtype_fail ():
79+ chart = alt .Chart (df ).mark_point ().encode (opacity = alt .value ( .5 ))
7380 for channel in chart .to_dict ()['encoding' ]:
74- data = convert ._locate_channel_data (chart .to_dict ()['encoding' ][channel ], chart .data )
81+ dtype = _data ._locate_channel_dtype (chart .to_dict ()['encoding' ][channel ], chart .data )
82+ assert dtype == 'quantitative'
0 commit comments