-
Notifications
You must be signed in to change notification settings - Fork 47
Expand file tree
/
Copy pathfigure_doa_experiment_plot.py
More file actions
145 lines (118 loc) · 4.97 KB
/
figure_doa_experiment_plot.py
File metadata and controls
145 lines (118 loc) · 4.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from __future__ import division
import sys
import numpy as np
import getopt
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tools import polar_distance, polar_error
from experiment import arrays
if __name__ == "__main__":

    # ------------------------------------------------------------------
    # Command-line parsing: only -h and -f/--file <data_file> are accepted.
    # ------------------------------------------------------------------
    argv = sys.argv[1:]
    usage = '%s -f <data_file>' % sys.argv[0]

    # This is the output from `figure_doa_experiment.py`
    data_file = 'data/20160909-203344_doa_experiment.npz'

    try:
        opts, _ = getopt.getopt(argv, "hf:", ["file="])
    except getopt.GetoptError:
        print(usage)
        sys.exit(2)
    for opt, arg in opts:
        if opt == '-h':
            print(usage)
            sys.exit()
        elif opt in ("-f", "--file"):
            data_file = arg

    # ------------------------------------------------------------------
    # Get the speakers and microphones groundtruth locations
    # ------------------------------------------------------------------
    exp_folder = './recordings/20160908/'
    sys.path.append(exp_folder)
    from edm_to_positions import twitters

    # Get the microphone array locations
    array_str = 'pyramic'
    twitters.center(array_str)

    # Subset of pyramic_tetrahedron microphone columns that is used.
    # list(range(...)) keeps the concatenation valid on Python 3 as well,
    # where range() is not a list and cannot be added with `+`.
    R_flat_I = list(range(8, 16)) + list(range(24, 32)) + list(range(40, 48))
    mic_array = arrays['pyramic_tetrahedron'][:, R_flat_I].copy()
    mic_array += twitters[[array_str]]

    # set the reference point to center of pyramic array
    v = {array_str: np.mean(mic_array, axis=1)}
    twitters.correct(v)

    # ------------------------------------------------------------------
    # Load the experiment results.
    # 'out' stores Python objects (each entry's third field is a dict), so
    # NumPy >= 1.16.3 needs allow_pickle=True to read it. Unpickling runs
    # arbitrary code: only point this at files produced locally by
    # figure_doa_experiment.py, never at untrusted data.
    # ------------------------------------------------------------------
    with np.load(data_file, allow_pickle=True) as data:
        results = data['out']

    # Now loop and process the results
    columns = ['sources', 'SNR', 'Algorithm', 'Error']
    table = []
    close_sources = []

    for pt in results:
        SNR = pt[0]
        speakers = [s.replace("'", "") for s in pt[1]]
        K = len(speakers)

        # Get groundtruth for speaker (first column of the DOA returned by twitters)
        phi_gt = np.array([twitters.doa(array_str, s) for s in speakers])[:, 0]

        # We single out the reconstruction of the two closely located
        # sources; by construction '7' is always first and '16' second.
        has_close_pair = '7' in speakers and '16' in speakers
        # A recovered angle counts as a hit when it is nearer its groundtruth
        # than half the angular gap between the two close sources.
        half_gap = polar_error(phi_gt[0], phi_gt[1]) / 2 if has_close_pair else None

        for alg, phi_recon in pt[2].items():
            recon_err, sort_idx = polar_distance(phi_gt, phi_recon)
            table.append([K, SNR, alg, np.degrees(recon_err)])

            if has_close_pair:
                success = sum(
                    1
                    for p1, p2 in zip(phi_gt[:2], phi_recon[sort_idx[:2, 1]])
                    if polar_error(p1, p2) < half_gap
                )
                close_sources.append([
                    alg,
                    success,
                    phi_recon[sort_idx[0, 1]],
                    phi_recon[sort_idx[1, 1]],
                ])

    # Create pandas data frame
    df = pd.DataFrame(table, columns=columns)

    # ------------------------------------------------------------------
    # Statistics of the reconstructed angles for the close pair.
    # mu: circular mean (angle of the mean unit phasor).
    # std: mean polar_error to mu -- a spread measure, not a true std.
    # ------------------------------------------------------------------
    df_close_sources = pd.DataFrame(close_sources,
                                    columns=['Algorithm', 'Success', '7', '16'])
    mu = {'7': {}, '16': {}}
    std = {'7': {}, '16': {}}
    for alg in ['FRI', 'MUSIC', 'SRP']:
        phi_r = df_close_sources[['Algorithm', '7', '16']][df_close_sources['Algorithm'] == alg]
        for spkr in ['7', '16']:
            mu[spkr][alg] = np.angle(np.mean(np.exp(1j * phi_r[spkr])))
            std[spkr][alg] = np.mean([polar_error(p, mu[spkr][alg]) for p in phi_r[spkr]])

    # Single pre-formatted string so the output is identical on Python 2 and 3.
    for spkr in ['7', '16']:
        for alg in ['FRI', 'MUSIC', 'SRP']:
            print('%s %s mu= %s std= %s' % (spkr, alg,
                                            np.degrees(mu[spkr][alg]),
                                            np.degrees(std[spkr][alg])))

    # ------------------------------------------------------------------
    # Create the super plot comparing all algorithms
    # ------------------------------------------------------------------
    algo_plot = ['FRI', 'MUSIC', 'SRP', 'CSSM', 'TOPS', 'WAVES']

    sns.set(style='whitegrid', context='paper', font_scale=1.2,
            rc={
                'figure.figsize': (3.5, 3.15),
                'lines.linewidth': 1.5,
                'font.family': 'sans-serif',
                'font.sans-serif': [u'Helvetica'],
                'text.usetex': False,
            })

    pal = sns.cubehelix_palette(6, start=0.5, rot=-0.5, dark=0.3, light=.75,
                                reverse=True, hue=1.)

    plt.figure(figsize=(4.7, 3.15), dpi=90)

    # fliersize=0 hides outlier markers
    sns.boxplot(x="sources", y="Error", hue="Algorithm",
                hue_order=algo_plot, data=df,
                palette=pal,
                fliersize=0.)

    leg = plt.legend(loc='upper left', title='Algorithm',
                     bbox_to_anchor=[-0.02, 1.03],
                     frameon=False, framealpha=0.1)
    leg.get_frame().set_linewidth(0.0)

    sns.despine(offset=10, trim=True, left=True)

    plt.xlabel("Number of sources")
    # Raw string: "\c" is not a valid escape sequence; value is unchanged.
    plt.ylabel(r"Error $[^\circ]$")
    plt.yticks(np.arange(0, 80))
    plt.ylim([0.0, 3.3])
    plt.tight_layout(pad=0.1)

    plt.savefig('figures/experiment_error_box.pdf')
    plt.savefig('figures/experiment_error_box.png')