-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathensemble.py
More file actions
214 lines (180 loc) · 7.92 KB
/
ensemble.py
File metadata and controls
214 lines (180 loc) · 7.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
"""
Ensemble the outputs of multiple models (qwen, roberta, gemma3, llama, plus the closed-source gpt and gemini); the prediction files are listed in model_configs.
Strategy:
1. Merge the Quadruplets from all models
2. For each (Aspect, Category, Opinion) combination, compute the weighted average VA (valence#arousal) value using the model weights
3. Apply weighted voting: keep a combination only if the sum of the weights of the models that predicted it is >= the threshold
"""
import json
import math
import os
import zipfile
from collections import defaultdict
# Prediction files to ensemble. Each row is (file path, raw weight, display
# name); the raw weight is the model's CF1 score.
model_configs = [
    {"file": path, "weight": score, "name": label}
    for path, score, label in (
        # qwen models
        ("test/qwen32B-3e.jsonl", 0.5852, "qwen32-3e"),
        ("test/qwen32B_best_loss.jsonl", 0.5795, "qwen32-loss"),
        ("test/qwen32B_best_cF1.jsonl", 0.5723, "qwen32-cF1"),
        ("test/qwen14B-3e.jsonl", 0.5848, "qwen14-3e"),
        ("test/qwen14B_best_loss.jsonl", 0.5748, "qwen14-loss"),
        # roberta model
        ("test/robertwwm.jsonl", 0.5673, "roberta"),
        # gemma3 models
        ("test/gemma-3e.jsonl", 0.5501, "gemma-3e"),
        ("test/gemma-4e.jsonl", 0.5379, "gemma-4e"),
        ("test/gemma-5e.jsonl", 0.5375, "gemma-5e"),
        # llama models
        ("test/llama-3e.jsonl", 0.5413, "llama-3e"),
        ("test/llama_best_loss.jsonl", 0.5313, "llama-loss"),
        # Other closed-source models
        ("test/gpt.jsonl", 0.2745, "gpt"),
        ("test/gemini.jsonl", 0.3516, "gemini"),
    )
]

# Normalize the raw weights so that the weights of all configured models sum
# to 1. The sum runs over every configured entry, whether or not its file
# exists on disk.
total_weight = sum(entry["weight"] for entry in model_configs)
for cfg in model_configs:
    cfg.update(norm_weight=cfg["weight"] / total_weight)

# Minimum sum of normalized weights a quadruplet needs to be kept; a lower
# threshold favours recall.
VOTE_THRESHOLD = 0.3421

# Path of the ensembled JSONL output
output_file = "output/ensemble_withClose.jsonl"
def load_jsonl(filepath):
    """Load a JSONL file into a dict keyed by each record's ``ID`` field.

    Args:
        filepath: Path of the JSONL file; it is read as UTF-8.

    Returns:
        dict mapping ``item['ID']`` to the decoded JSON object. When the same
        ID appears more than once, the last record wins.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record has no ``ID`` field.
    """
    data = {}
    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            # Blank lines (e.g. a trailing empty line) are not records;
            # json.loads('') would raise on them.
            if not line:
                continue
            item = json.loads(line)
            data[item['ID']] = item
    return data
def parse_va(va_str):
    """Parse a ``"valence#arousal"`` string into a ``(valence, arousal)`` tuple.

    Args:
        va_str: String such as ``"7.50#6.25"``. Fields after the second one
            are ignored.

    Returns:
        ``(float, float)`` on success. ``(None, None)`` when the input is not
        a string, has fewer than two ``#``-separated fields, contains a
        non-numeric field, or contains a non-finite value (NaN / infinity).
    """
    try:
        parts = va_str.split('#')
        valence, arousal = float(parts[0]), float(parts[1])
    except (AttributeError, TypeError, ValueError, IndexError):
        # Only parsing failures are treated as "no VA"; anything else
        # (e.g. KeyboardInterrupt) must propagate.
        return None, None
    # float() accepts "nan" and "inf"; such values would corrupt any average
    # computed from them, so they are treated as unparseable.
    if not (math.isfinite(valence) and math.isfinite(arousal)):
        return None, None
    return valence, arousal
def format_va(valence, arousal):
    """Render a valence/arousal pair as ``"V#A"``, two decimals each."""
    return "#".join(format(score, ".2f") for score in (valence, arousal))
def get_quad_key(quad):
    """Build the case-insensitive comparison key of a quadruplet.

    Args:
        quad: Mapping that may contain the keys ``'Aspect'``, ``'Category'``
            and ``'Opinion'``.

    Returns:
        Tuple ``(aspect, category, opinion)`` with surrounding whitespace
        stripped, aspect and opinion lower-cased and category upper-cased.
        A missing field or a ``None`` value becomes ``''``; any other
        non-string value is converted with ``str()``.
    """
    def _text(field):
        value = quad.get(field)
        # A JSON null arrives as None and has no .strip(); treat it like a
        # missing field instead of crashing.
        return '' if value is None else str(value).strip()

    return (
        _text('Aspect').lower(),
        _text('Category').upper(),
        _text('Opinion').lower(),
    )
def ensemble_quadruplets_weighted(quad_lists_with_weights, vote_threshold=0.5):
    """Merge the quadruplets of several models by weighted voting.

    Quadruplets are grouped by ``get_quad_key``. Each model votes at most
    once per key: if a model repeats the same key, only its first occurrence
    with a parseable VA is counted, so a model cannot reach the threshold by
    duplicating a prediction.

    Args:
        quad_lists_with_weights: Iterable of ``(quads, weight, name)`` tuples,
            one per model. ``quads`` is a list of quadruplet dicts or
            ``None`` (skipped); ``name`` is not used here.
        vote_threshold: A key is kept only if the summed weight of the models
            that voted for it is >= this value.

    Returns:
        List of dicts with keys ``'Aspect'``, ``'Category'``, ``'Opinion'``
        (taken from the first model that voted for the key, original casing)
        and ``'VA'`` (the weight-averaged valence/arousal, formatted by
        ``format_va``), in order of first appearance.
    """
    quad_votes = defaultdict(list)  # key -> [(valence, arousal, weight), ...]
    quad_original = {}  # key -> fields as first seen (preserves original case)

    for quads, weight, _name in quad_lists_with_weights:
        if quads is None:
            continue
        voted_keys = set()  # keys this model has already voted for
        for quad in quads:
            key = get_quad_key(quad)
            if key in voted_keys:
                continue
            valence, arousal = parse_va(quad.get('VA', ''))
            if valence is None or arousal is None:
                # Quadruplets without a usable VA do not vote.
                continue
            voted_keys.add(key)
            quad_votes[key].append((valence, arousal, weight))
            if key not in quad_original:
                quad_original[key] = {
                    'Aspect': quad.get('Aspect', ''),
                    'Category': quad.get('Category', ''),
                    'Opinion': quad.get('Opinion', ''),
                }

    result = []
    for key, va_list in quad_votes.items():
        weight_sum = sum(w for _, _, w in va_list)
        # The positivity check prevents a division by zero when the threshold
        # is <= 0 and every voter has zero weight.
        if weight_sum < vote_threshold or weight_sum <= 0:
            continue
        weighted_valence = sum(v * w for v, _, w in va_list) / weight_sum
        weighted_arousal = sum(a * w for _, a, w in va_list) / weight_sum
        merged = quad_original[key].copy()
        merged['VA'] = format_va(weighted_valence, weighted_arousal)
        result.append(merged)
    return result
def main():
    """Ensemble every configured model's predictions and write the result.

    Steps:
        1. Load each file listed in ``model_configs`` (a missing file is
           reported and contributes no predictions).
        2. For every sample ID seen in any file, merge the models'
           ``'Quadruplet'`` lists with ``ensemble_quadruplets_weighted``
           using ``VOTE_THRESHOLD``.
        3. Write one JSON object per sample to ``output_file`` and pack that
           file into a zip next to it under the member name
           ``pred_zho_restaurant.jsonl``.

    Progress and summary statistics are printed to stdout.
    """
    # Load prediction results from all models
    print("Loading model prediction results...")
    print("\nModel weight settings (based on CF1 score):")
    all_data = []
    for cfg in model_configs:
        filepath = cfg["file"]
        if os.path.exists(filepath):
            data = load_jsonl(filepath)
            all_data.append((data, cfg["norm_weight"], cfg["name"]))
            print(f" ✅ {cfg['name']:10s}: CF1={cfg['weight']:.4f}, Normalized Weight={cfg['norm_weight']:.4f}, {len(data)} items")
        else:
            print(f" ❌ {filepath}: File does not exist")
            # Keep a placeholder with zero weight so the model never votes.
            all_data.append(({}, 0, cfg["name"]))
    print(f"\nVoting Threshold: {VOTE_THRESHOLD} (Sum of weights must be >= {VOTE_THRESHOLD})")

    # Union of the sample IDs seen by any model, in sorted order
    all_ids = set()
    for data, _weight, _name in all_data:
        all_ids.update(data)
    all_ids = sorted(all_ids)
    print(f"Total {len(all_ids)} samples")

    # Ensemble
    print("\nStarting weighted Ensemble...")
    results = []
    stats = {'total': 0, 'with_quads': 0, 'empty': 0}
    for sample_id in all_ids:
        # Collect this sample's predictions from each model (with weights)
        quad_lists_with_weights = []
        text = None
        for data, weight, name in all_data:
            if sample_id in data:
                sample = data[sample_id]
                quad_lists_with_weights.append((sample.get('Quadruplet', []), weight, name))
                # The first model that carries the text supplies it.
                if text is None and 'Text' in sample:
                    text = sample['Text']
            else:
                quad_lists_with_weights.append(([], weight, name))

        ensemble_quads = ensemble_quadruplets_weighted(
            quad_lists_with_weights, vote_threshold=VOTE_THRESHOLD
        )
        result = {
            'ID': sample_id,
            'Quadruplet': ensemble_quads,
        }
        if text:
            result['Text'] = text
        results.append(result)

        stats['total'] += 1
        if ensemble_quads:
            stats['with_quads'] += 1
        else:
            stats['empty'] += 1

    # Save results. os.makedirs('') raises, so only create the directory when
    # the output path actually has a directory component.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(output_file, 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(item, ensure_ascii=False) + '\n' for item in results)

    # Pack the output into a zip under the member name below. Writing it
    # directly with ``arcname`` avoids creating (and then deleting) a
    # temporary renamed copy, which could clobber an existing file of that
    # name in the output directory.
    new_filename = "pred_zho_restaurant.jsonl"
    # splitext swaps only the final extension; str.replace would rewrite any
    # '.jsonl' occurring elsewhere in the path.
    zip_filename = os.path.splitext(output_file)[0] + '.zip'
    with zipfile.ZipFile(zip_filename, 'w', zipfile.ZIP_DEFLATED) as zf:
        zf.write(output_file, arcname=new_filename)

    print("\n✅ Weighted Ensemble completed!")
    print(f" Output file: {output_file}")
    print(f" Zip file: {zip_filename}")
    print(f" Zip member name: {new_filename}")
    print(f" Total samples: {stats['total']}")
    print(f" With Quadruplet: {stats['with_quads']}")
    print(f" Empty Quadruplet: {stats['empty']}")
# Run the ensemble only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()