-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexample.py
More file actions
305 lines (248 loc) · 10.3 KB
/
example.py
File metadata and controls
305 lines (248 loc) · 10.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
import os
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
from utils import model_configs, prepare_Q_A_FROM_ARC, download_data
# Number of top attended tokens extracted per node when tracing attention (Step 4).
K = 5
# Depth of the reasoning-dependency tree built in Step 4.
DEPTH = 1
# Identifier appended to every saved image filename.
SAMPLE_ID = "sky_color"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
device = "cuda"  # the device to load the model onto
# Model Config
MODEL_ID = "llama"  # NOTE(review): original comment said 'OR "llama"' but the value is already "llama" — valid keys live in utils.model_configs; confirm alternatives there
configs = model_configs.get(MODEL_ID)
model_path = configs.get("model_path")
# Token id of the "Answer" keyword, used in Step 3 to locate the final answer.
answer_token_id = configs.get("answer_token_id")
def plot_attention_heatmap(attn_tensor, sample_id):
    """
    Plot and save a heatmap of a (H, N) attention tensor — one row per head.

    Args:
        attn_tensor (torch.Tensor): Tensor of shape (H, N): heads x sequence positions.
        sample_id (str): Identifier appended to the saved image filename.
    """
    # Guard: make sure the target directory exists even when this function is
    # called before the __main__ block has created it.
    os.makedirs("images", exist_ok=True)
    plt.figure(figsize=(10, 6))
    # .float().cpu() handles half-precision tensors still resident on the GPU.
    sns.heatmap(attn_tensor.float().cpu().numpy(), cmap='viridis', cbar=True)
    plt.title('Attention Map of Answer Token')
    plt.xlabel('Sequence Length')
    plt.ylabel('Attention Heads ')
    # Explicit .png extension: the original path had none and relied on
    # matplotlib's implicit default savefig format.
    plt.savefig(f"images/Attention Map of Answer Token_{sample_id}.png", bbox_inches='tight')
    plt.show()
def find_token_index(ids, token_id):
    """
    Return the index of the first occurrence of ``token_id`` in ``ids``.

    :param ids: Sequence of token ids — tensor elements (with ``.item()``) or plain ints.
    :param token_id: The token id to search for.
    :return: Index of the first match, or None when the id is absent
             (callers such as the __main__ block check for None explicitly).
    """
    for i, tok in enumerate(ids):  # avoid shadowing the builtin `id`
        # Tensor elements need .item() to compare as Python ints;
        # plain ints compare directly.
        value = tok.item() if hasattr(tok, "item") else tok
        if value == token_id:
            return i
    return None
def aggregate_answer_attentions(output_ids, attentions_lst, input_length, end_think_id):
    """
    For each of the answer tokens, we will extract the attentions to the thinking part. Then aggregate them.
    :param output_ids: All output ids, including thinking and final
    :param attentions_lst: A tuple of all attentions for generated tokens
    :param end_think_id: The token id of the ending thinking token </think>
    :param input_length: The length of the model's prompt
    :return: The aggregated (average) of attention maps, shape (H, end_think_index)
    """
    # Position of </think> within the generated tokens: everything at or after
    # it is the "answer" region whose attention over the thinking span we average.
    # NOTE(review): if </think> is absent, find_token_index returns None and the
    # range() below raises TypeError — caller must ensure the token exists.
    end_think_index = find_token_index(output_ids, end_think_id)
    result = []
    for i in range(end_think_index, len(output_ids)):
        # attentions_lst[i][-1] is the last layer's attention for generation
        # step i; indexing [0, :, 0, ...] presumably selects batch 0, all heads,
        # query position 0 — assumes shape (1, H, 1, seq_len), TODO confirm.
        # The column slice keeps only the thinking span of the generated text.
        think_attentions = attentions_lst[i][-1][0, :, 0,
                           input_length:end_think_index + input_length]  # Shape: (H, end_think_index)
        result.append(think_attentions)
    # Stack over answer tokens, then average them into one (H, end_think_index) map.
    result = torch.stack(result, dim=0)
    return torch.mean(result, dim=0)
def find_top_k_attended_tokens(think_token_index, k, output_ids, attentions_lst, input_length):
    """
    Find the top-k most-attended preceding thinking tokens for one thinking token.

    :param think_token_index: Index (relative to the generated output) of a thinking token.
    :param k: Number of top attended tokens to return. Clamped to the number of
              preceding tokens so ``torch.topk`` cannot fail on short prefixes.
    :param output_ids: Generated token ids (unused here; kept for interface parity).
    :param attentions_lst: Per-step attention tuples; ``attentions_lst[i][-1]`` is the
                           last layer's attention, indexed as (1, H, 1, seq_len).
    :param input_length: Prompt length, used to offset into the sequence axis.
    :return: (values, indices) from ``torch.topk`` over the head-summed attention;
             a pair of empty tensors when the token has no predecessors.
    """
    if think_token_index == 0:
        # The very first thinking token has no predecessors to attend to.
        return torch.tensor([]), torch.tensor([])
    previous_think_attentions = attentions_lst[think_token_index][-1][0, :, 0,
                                input_length:think_token_index + input_length]  # (H, Pi)
    aggregated = previous_think_attentions.sum(dim=0)  # Shape: (Pi,)
    # Clamp k: requesting more elements than the prefix holds would raise
    # in torch.topk for early thinking tokens (Pi < k).
    top_k = torch.topk(aggregated, k=min(k, aggregated.numel()))
    return top_k.values, top_k.indices
def average_pool_attention(attn_weights, window_size=10, stride=10):
    """
    Average-pool attention scores over a sliding token window.

    Parameters:
    - attn_weights: 1D list or numpy array of attention scores (length N tokens)
    - window_size: Number of tokens per segment
    - stride: Step size for moving the window

    Returns:
    - pooled_attn: List of pooled (mean) attention values, one per window
    - segments: List of (start, end) token index pairs for each window
    """
    starts = range(0, len(attn_weights) - window_size + 1, stride)
    pooled_attn = [np.mean(attn_weights[start:start + window_size]) for start in starts]
    segments = [(start, start + window_size) for start in starts]
    return pooled_attn, segments
def average_pool_on_segments(attn, num_segments=3):
    """
    Divide the attention values into `num_segments` contiguous chunks and
    average-pool each one. The final chunk absorbs any remainder.

    Parameters:
    - attn: The list or array of pooled attention values.
    - num_segments: The number of segments to divide the attention list into.

    Returns:
    - pooled_attn_segments: The average pooled attention for each segment.
    """
    size = len(attn) // num_segments
    # (start, end) bounds per segment; the last segment runs to the end of attn.
    bounds = [
        (i * size, len(attn) if i == num_segments - 1 else (i + 1) * size)
        for i in range(num_segments)
    ]
    return [np.mean(attn[lo:hi]) for lo, hi in bounds]
if __name__ == '__main__':
    # Create output directory
    img_dir = "images"
    os.makedirs(img_dir, exist_ok=True)
    print(f"--- Starting Analysis for Model: {MODEL_ID} ---")
    print(f"Device: {device}")
    # --- Step 1: Load Model and Tokenizer ---
    print("\n[1/5] Loading model and tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # device_map="auto" shards the model across the GPUs exposed via
    # CUDA_VISIBLE_DEVICES set at module top.
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype="auto",
        device_map="auto"
    )
    # --- Step 2: Prepare Input & Inference ---
    print("\n[2/5] Running inference...")
    prompt = (
        "What is the color of the sky at night?\n"
        "A) Blue\n"
        "B) White\n"
        "C) Pink\n"
        "D) Black\n"
        "Select the correct answer option (A, B, C, or D) ONLY and put your final answer in this format: Answer:"
        " STRICTLY FOLLOW THIS FORMAT."
    )
    messages = [{"role": "user", "content": prompt}]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(device)
    input_length = len(model_inputs.input_ids[0])
    # output_attentions=True keeps per-step attention maps, which Steps 3-5
    # consume via output_dict.attentions.
    output_dict = model.generate(
        **model_inputs,
        max_new_tokens=2048,
        output_attentions=True,
        return_dict_in_generate=True
    )
    # Strip the prompt: keep only the generated tokens.
    output_ids = output_dict.sequences[0, input_length:]
    full_response = tokenizer.decode(output_ids)
    print("-" * 40)
    print("Full Model Output:\n", full_response)
    print("-" * 40)
    # --- Step 3: Analyze Answer and Attention ---
    print("\n[3/5] analyzing attention maps...")
    # NOTE(review): assumes the tokenizer defines <think>/</think> as single
    # tokens (reasoning-style model) — confirm for the configured MODEL_ID.
    start_thinking_id = tokenizer.convert_tokens_to_ids("<think>")
    end_thinking_id = tokenizer.convert_tokens_to_ids("</think>")
    # Find the specific answer token
    ans_start_index = find_token_index(output_ids, answer_token_id)
    if ans_start_index is None:
        print("Warning: Answer keyword token not found in output. Check prompt format.")
    else:
        # The final answer letter (e.g., "D") is expected 2 tokens after "Answer" (index + 2)
        # Assuming structure: "Answer" -> ":" -> " D"
        target_ans_index = ans_start_index + 2
        if target_ans_index < len(output_ids):
            ans_token_str = tokenizer.decode(output_ids[target_ans_index].item())
            print(f"Detected Answer Token: '{ans_token_str}' at index {target_ans_index}")
        else:
            print("Warning: Answer index out of bounds.")
    # Aggregate attentions (Case 2 logic): mean attention of all answer tokens
    # over the thinking span; shape (H, end_think_index).
    thinking_attentions = aggregate_answer_attentions(
        output_ids,
        output_dict.attentions,
        input_length,
        end_thinking_id
    )
    # Sum attention across heads
    aggregated_attentions = thinking_attentions.sum(dim=0)
    # Plot primary heatmap
    print(f"Saving main attention heatmap to {img_dir}...")
    plot_attention_heatmap(thinking_attentions, sample_id=SAMPLE_ID)
    # --- Step 4: Trace Reasoning Tree ---
    print("\n[4/5] Tracing reasoning tree...")
    # Get Top K attended tokens for the final answer
    top_k_attended_tokens = torch.topk(aggregated_attentions, k=K)
    print(f"\nTop {K} tokens attended by the final answer:")
    for i, index in enumerate(top_k_attended_tokens.indices):
        token_str = tokenizer.decode(output_ids[index].item()).strip()
        print(f" {i + 1}. '{token_str}' (Index: {index})")
    # Build depth tree: level 0 is the answer's top-K tokens; each further
    # level expands every token into its own top-K attended predecessors.
    tree = []
    top_k_answer_tokens = top_k_attended_tokens.indices
    tree.append(top_k_answer_tokens)
    d = 0
    while d < DEPTH:
        current_level = []
        for thinking_id in tree[-1]:
            _, indices = find_top_k_attended_tokens(
                thinking_id, K, output_ids, output_dict.attentions, input_length
            )
            current_level += indices.tolist()
        tree.append(current_level)
        d += 1
    print("\nReasoning Dependency Tree:")
    for level_idx, level_indices in enumerate(tree):
        tokens_in_level = [tokenizer.decode(output_ids[idx].item()).strip() for idx in level_indices]
        print(f" Level {level_idx}: {tokens_in_level}")
    # --- Step 5: Segment Analysis Visualizations ---
    print("\n[5/5] Generating segment visualization plots...")
    # 5a. Rolling Window Heatmap: pool the head-summed attention in windows of 5.
    attn_weights = aggregated_attentions.cpu().float().numpy()
    window_size = 5
    stride = 5
    pooled_attn, segments = average_pool_attention(attn_weights, window_size, stride)
    heatmap_data = np.array(pooled_attn).reshape(1, -1)
    plt.figure(figsize=(10, 2))
    # Label only every `step`-th segment tick to keep the axis readable.
    step = 10
    xtick_labels = [f"{seg[0]}-{seg[1]}" if i % step == 0 else "" for i, seg in enumerate(segments)]
    sns.heatmap(
        heatmap_data,
        annot=False,
        cmap="Blues",
        xticklabels=xtick_labels,
        yticklabels=["Pooled Attn"],
        cbar=True
    )
    plt.xlabel("Token Segments")
    plt.title("Attention Heatmap Over Token Segments")
    plt.savefig(f"{img_dir}/attention_heatmap_{SAMPLE_ID}.png", bbox_inches='tight')
    plt.show()
    # 5b. Broad Segment Heatmap: re-pool the windowed values into 3 coarse segments.
    pooled_attn_inputs = [pooled_attn]  # List wrapper for consistency
    pooled_attn_segments_per_input = [
        average_pool_on_segments(input_attn, num_segments=3) for input_attn in pooled_attn_inputs
    ]
    heatmap_data_broad = np.array(pooled_attn_segments_per_input).T
    plt.figure(figsize=(10, 6))
    sns.heatmap(
        heatmap_data_broad,
        annot=True,
        cmap="Blues",
        xticklabels=[f"Input {i + 1}" for i in range(len(pooled_attn_inputs))],
        yticklabels=[f"Segment {i + 1}" for i in range(3)],
        cbar=True
    )
    plt.xlabel("Inputs")
    plt.ylabel("Segments")
    plt.title("Average Pooled Attention Across 3 Broad Segments")
    plt.savefig(f"{img_dir}/Average_Pooled_Attention_Across_{SAMPLE_ID}.png", bbox_inches='tight')
    plt.show()
    print("\nDone! Check the 'images' directory for outputs.")