-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathargs.py
More file actions
124 lines (92 loc) · 3.37 KB
/
args.py
File metadata and controls
124 lines (92 loc) · 3.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import numpy as np
import itertools
from createTreeModel import _classification_ids
import os
from collections import defaultdict
# Layout of the result path built by add_path(): one entry per directory
# level, each entry a whitespace-separated list of argument names.  Within a
# level, every name that is present in a configuration is rendered as
# "name=value" and the pieces are joined with "-".
path_structure = [
    'dataset_id',
    'n_estimators use_predicted_class',
    'sample_id',
    'method T_max lr optimizer',
]
def process_arg_dict(arg_dict):
    """Turn an argument specification into an ordered list of configurations.

    The specification is flattened (nested alternatives resolved), expanded
    into the Cartesian product of all listed values, filtered, given a
    ``'path_results'`` entry, and finally grouped by dataset in the order
    given by ``arg_dict['dataset_id']``.

    Args:
        arg_dict: Argument specification.  It must contain ``'dataset_id'``,
            and every resulting configuration must carry a ``'root'`` entry,
            which is consumed when the result path is built.

    Returns:
        A list of configuration dicts, grouped by ``'dataset_id'``.
    """
    arg_comb = dict2comb(flatten_arg_dict(arg_dict))
    arg_comb = filter_arg(arg_comb)
    add_path(arg_comb)  # works in place: drops 'root', adds 'path_results'

    # The ordering passed to sort() must be iterable.  A lone id (int, str,
    # tuple, ...) is treated as a single value by dict2comb, so it is wrapped
    # here as well; otherwise it would only work when the specification
    # happened to be rewritten in place during expansion.  Containers are
    # handed over unchanged.
    dataset_ids = arg_dict['dataset_id']
    if not isinstance(dataset_ids, (list, range, np.ndarray, dict)):
        dataset_ids = [dataset_ids]
    return sort(arg_comb, dataset_ids)
def flatten_arg_dict(arg_dict):
    """Expand nested option groups into a list of flat argument dicts.

    A value that is itself a dict is read as a set of alternatives for its
    key.  Each ``name: component`` entry of that dict yields one new
    argument dict:

    * ``component`` is a list: the key maps to that list (``name`` is not
      kept);
    * ``component`` is a dict: the key maps to ``name`` and the entries of
      ``component`` are merged in as additional arguments.

    Expanded dicts go back onto the work stack, so dict values introduced by
    a component are expanded in turn.  A key whose dict of alternatives is
    empty contributes no output for that branch.

    Args:
        arg_dict: Argument specification, possibly containing dict values.
            It is not modified.

    Returns:
        A list of dicts, none of which has a dict value.  Alternatives come
        out in reverse declaration order because pending dicts are handled
        last-in, first-out.  A specification without any dict value is
        returned as the very same object, not a copy.

    Raises:
        ValueError: If an alternative is neither a list nor a dict.
    """
    pending = [arg_dict]
    flattened = []
    while pending:
        current = pending.pop()

        # Locate the first dict-valued entry; a dict without one is final.
        for key, value in current.items():
            if isinstance(value, dict):
                break
        else:
            flattened.append(current)
            continue

        # Work on a copy so the caller's dict keeps its nested entry.
        base = current.copy()
        del base[key]

        for name, component in value.items():
            expanded = base.copy()
            if isinstance(component, list):
                expanded[key] = component
            elif isinstance(component, dict):
                expanded[key] = name
                expanded.update(component)
            else:
                raise ValueError(
                    f'alternative {name!r} of {key!r} must be a list or a '
                    f'dict, got {type(component).__name__}'
                )
            pending.append(expanded)
    return flattened
def dict2comb(arg_dict_flatten):
    """Expand flat argument grids into every concrete combination.

    Args:
        arg_dict_flatten: A dict, or a list of dicts, mapping argument names
            to candidate values.  NumPy arrays and ranges are converted to
            lists, and any other non-list value is wrapped in a one-element
            list.  This normalisation is written back into the given dicts.

    Returns:
        A list with one dict per element of the Cartesian product of the
        candidate lists, for each input dict in turn.
    """
    if isinstance(arg_dict_flatten, list):
        grids = arg_dict_flatten
    else:
        assert isinstance(arg_dict_flatten, dict)
        grids = [arg_dict_flatten]

    combinations = []
    for grid in grids:
        # Normalise every entry to a plain list, in place.
        for name in list(grid):
            candidates = grid[name]
            if isinstance(candidates, np.ndarray):
                grid[name] = candidates.tolist()
            elif isinstance(candidates, range):
                grid[name] = list(candidates)
            elif not isinstance(candidates, list):
                grid[name] = [candidates]

        names = list(grid)
        combinations.extend(
            dict(zip(names, choice))
            for choice in itertools.product(*grid.values())
        )
    return combinations
def filter_arg(arg_comb):
    """Drop configurations that should not be run.

    A configuration is removed when

    * its ``'dataset_id'`` is not in ``_classification_ids`` while
      ``'use_predicted_class'`` is truthy, or
    * it has a ``'T_max'`` entry and the pair ``(T_max, lr)`` is not one of
      the whitelisted pairs below.

    Args:
        arg_comb: List of configuration dicts.

    Returns:
        A new list with the surviving configurations, in their original
        order; the dicts themselves are not copied.
    """
    allowed_schedules = ((10, 5), (50, 5), (100, 5), (100, 1), (10, 1))

    def keep(arg):
        if arg['dataset_id'] not in _classification_ids and arg['use_predicted_class']:
            return False
        if 'T_max' in arg and (arg['T_max'], arg['lr']) not in allowed_schedules:
            return False
        return True

    return [arg for arg in arg_comb if keep(arg)]
def add_path(arg_comb):
    """Attach a result-file path to every configuration, in place.

    For each configuration the ``'root'`` entry is removed and used as the
    base directory.  Each entry of ``path_structure`` then adds one path
    level made of ``name=value`` pieces (for the names present in the
    configuration) joined by ``-``.  The final path, with ``.npz`` appended,
    is stored under ``'path_results'``.

    NOTE: a level with no matching names contributes an empty component to
    ``os.path.join``.

    Args:
        arg_comb: List of configuration dicts; each must contain ``'root'``.
    """
    levels = [spec.split() for spec in path_structure]
    for arg in arg_comb:
        result_path = arg.pop('root')
        for level_keys in levels:
            segment = '-'.join(
                f'{key}={arg[key]}' for key in level_keys if key in arg
            )
            result_path = os.path.join(result_path, segment)
        arg['path_results'] = result_path + '.npz'
def sort(arg_comb, dataset_ids):
    """Group configurations by dataset, following a given dataset order.

    Args:
        arg_comb: List of configuration dicts, each with a ``'dataset_id'``.
        dataset_ids: Iterable of dataset ids giving the output order.

    Returns:
        A new list holding, for each id in ``dataset_ids`` in turn, the
        configurations with that id in their original relative order.
        Configurations whose id is not listed are left out; an id listed
        more than once contributes its configurations each time.
    """
    buckets = {}
    for arg in arg_comb:
        buckets.setdefault(arg['dataset_id'], []).append(arg)

    ordered = []
    for dataset_id in dataset_ids:
        ordered.extend(buckets.get(dataset_id, ()))
    return ordered