Skip to content
This repository was archived by the owner on Jan 22, 2024. It is now read-only.

Commit 0878f69

Browse files
committed
Add a prediction visualization script.
Work in progress.
1 parent 80b6776 commit 0878f69

File tree

3 files changed

+199
-1
lines changed

3 files changed

+199
-1
lines changed

core/data/data_io.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def get_padded_shapes(max_tokens, max_num_nodes, max_num_edges, include_strings=
139139
'problem_id': [1],
140140
'submission_id': [1],
141141
})
142-
142+
143143
return shapes
144144

145145

File renamed without changes.

scripts/visualize_predictions.py

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
"""Visualize model predictions."""
2+
3+
import os

from absl import app
from absl import flags
from flax.training import checkpoints
from flax.training import common_utils
import jax
import jax.numpy as jnp
from ml_collections.config_flags import config_flags
import tensorflow_datasets as tfds

from core.data import codenet
from core.data import codenet_paths
from core.data import error_kinds
from core.data import info as info_lib
from core.data import process
from core.lib import trainer
19+
20+
# Default locations for the dataset and the run config, used as flag defaults.
DEFAULT_DATASET_PATH = codenet_paths.DEFAULT_DATASET_PATH
DEFAULT_CONFIG_PATH = codenet_paths.DEFAULT_CONFIG_PATH


# Command-line flags. The config flag loads a full ml_collections config file;
# set_config() below then overrides parts of it with hard-coded values.
flags.DEFINE_string('dataset_path', DEFAULT_DATASET_PATH, 'Dataset path.')
config_flags.DEFINE_config_file(
    name='config', default=DEFAULT_CONFIG_PATH, help_string='Config file.'
)
FLAGS = flags.FLAGS
29+
30+
31+
def get_raise_contribution_at_step(instruction_pointer, raise_decisions, raise_index):
  """Returns each node's contribution of raise mass at a single step.

  Args:
    instruction_pointer: Probability mass at each node. Shape: num_nodes.
    raise_decisions: Per-node [raise, no-raise] probabilities.
      Shape: num_nodes, 2.
    raise_index: Index of the raise node, which is excluded from the result.

  Returns:
    Per-node raise contributions. Shape: num_nodes.
  """
  # Column 0 of raise_decisions holds the probability of raising at the node.
  contribution = raise_decisions[:, 0] * instruction_pointer
  # The raise node itself does not contribute to its own incoming mass.
  return contribution.at[raise_index].set(0)
# Vectorized over the leading (step) axis of the first two arguments.
get_raise_contribution_at_steps = jax.vmap(
    get_raise_contribution_at_step, in_axes=(0, 0, None))
41+
42+
43+
def get_raise_contribution(instruction_pointer, raise_decisions, raise_index, step_limit):
  """Returns each node's total raise contribution across valid steps.

  Args:
    instruction_pointer: Probability mass per step and node.
      Shape: steps, num_nodes.
    raise_decisions: Per-step, per-node [raise, no-raise] probabilities.
      Shape: steps, num_nodes, 2.
    raise_index: Index of the raise node, which is excluded from the result.
    step_limit: Scalar; steps at or beyond this limit are ignored.

  Returns:
    Per-node raise contributions summed over steps. Shape: num_nodes.
  """
  # Per-step contribution of each node: P(raise at node) * mass at node,
  # with the raise node's own entry zeroed out.
  p_raise = raise_decisions[:, :, 0]
  # per_step.shape: steps, num_nodes
  per_step = (p_raise * instruction_pointer).at[:, raise_index].set(0)
  # Zero out steps past the step limit before summing.
  valid = jnp.arange(instruction_pointer.shape[0]) < step_limit
  # valid.shape: steps
  per_step = jnp.where(valid[:, None], per_step, 0)
  return jnp.sum(per_step, axis=0)
# Vectorized over a leading batch axis of all four arguments.
get_raise_contribution_batch = jax.vmap(get_raise_contribution)
58+
59+
60+
def print_spans(raw):
  """Prints the source text covered by each node's span in `raw`."""
  source = raw.source
  spans = zip(raw.node_span_starts, raw.node_span_ends)
  for index, (start, end) in enumerate(spans):
    print(f'Span {index}: {source[start:end]}')
65+
66+
67+
def set_config(config):
  """This function is hard-coded to load a particular checkpoint.

  It also sets the model part of the config to match the config of that
  checkpoint. Everything related to parameter construction must match.

  Args:
    config: The config object to update (attributes are assigned in place).

  Returns:
    The same config object, mutated in place.
  """
  config.multidevice = False
  config.batch_size = 2
  config.raise_in_ipagnn = True
  # NOTE(review): A pretrain checkpoint was assigned here and then immediately
  # overwritten by the finetune checkpoint below, making the assignments dead
  # stores. They are kept as a comment for reference while this script is a
  # work in progress.
  # config.restore_checkpoint_dir = (
  #     '/mnt/runtime-error-problems-experiments/experiments/2021-09-24-pretrain-004-copy/6/'
  #     'I1466,o=sgd,bs=32,lr=0.3,gc=2,hs=256,span=max,'
  #     'tdr=0,tadr=0,pe=False,T=default/checkpoints/'
  # )
  # config.span_encoding_method = 'max'

  # Model settings; these must match the checkpoint's training configuration
  # or parameter restoration will fail.
  config.optimizer = 'sgd'
  config.hidden_size = 256
  config.permissive_node_embeddings = False
  config.transformer_emb_dim = 512
  config.transformer_num_heads = 8
  config.transformer_num_layers = 6
  config.transformer_qkv_dim = 512
  config.transformer_mlp_dim = 2048

  config.restore_checkpoint_dir = (
      '/mnt/runtime-error-problems-experiments/experiments/2021-09-27-finetune-001-copy/8/'
      'E055,o=sgd,bs=32,lr=0.1,gc=2,hs=256,span=mean,'
      'tdr=0.1,tadr=0.1,pe=False,T=default/checkpoints'
  )
  config.span_encoding_method = 'mean'
  return config
98+
99+
100+
def main(argv):
  """Runs the model over the validation split and visualizes predictions.

  For each submission found on local disk, prints the source, the node spans,
  the per-node raise contributions, and the target vs. predicted error kind,
  pausing for user input between submissions.
  """
  del argv  # Unused.

  dataset_path = FLAGS.dataset_path
  config = FLAGS.config
  config = set_config(config)

  jnp.set_printoptions(threshold=config.printoptions_threshold)
  info = info_lib.get_dataset_info(dataset_path)
  t = trainer.Trainer(config=config, info=info)

  split = 'valid'
  dataset = t.load_dataset(
      dataset_path=dataset_path, split=split, include_strings=True)

  # Initialize / Load the model state.
  rng = jax.random.PRNGKey(0)
  rng, init_rng = jax.random.split(rng)
  model = t.make_model(deterministic=False)
  state = t.create_train_state(init_rng, model)
  if config.restore_checkpoint_dir:
    state = checkpoints.restore_checkpoint(config.restore_checkpoint_dir, state)

  train_step = t.make_train_step()
  for batch in tfds.as_numpy(dataset):
    # We do not allow multidevice in this script.
    assert not config.multidevice
    # if config.multidevice:
    #   batch = common_utils.shard(batch)

    # Strings cannot flow through the jitted step; pop them first.
    problem_ids = batch.pop('problem_id')
    submission_ids = batch.pop('submission_id')
    # NOTE(review): make_train_step presumably applies gradient updates, so
    # `state` drifts while iterating the validation split — confirm intended.
    state, aux = train_step(state, batch)

    instruction_pointer = aux['instruction_pointer_orig']
    # instruction_pointer.shape: steps, batch_size, num_nodes
    instruction_pointer = jnp.transpose(instruction_pointer, [1, 0, 2])
    # instruction_pointer.shape: batch_size, steps, num_nodes
    exit_index = batch['exit_index']
    # The raise node immediately follows the exit node.
    raise_index = exit_index + 1
    raise_decisions = aux['raise_decisions']
    # raise_decisions.shape: steps, batch_size, num_nodes, 2
    raise_decisions = jnp.transpose(raise_decisions, [1, 0, 2, 3])
    # raise_decisions.shape: batch_size, steps, num_nodes, 2
    contributions = get_raise_contribution_batch(
        instruction_pointer, raise_decisions, raise_index, batch['step_limit'])
    # contributions.shape: batch_size, num_nodes

    for index, (problem_id, submission_id, contribution) in enumerate(
        zip(problem_ids, submission_ids, contributions)):
      problem_id = problem_id[0].decode('utf-8')
      submission_id = submission_id[0].decode('utf-8')
      python_path = codenet.get_python_path(problem_id, submission_id)
      r_index = int(raise_index[index])
      num_nodes = int(raise_index[index]) + 1
      target = int(batch['target'][index])
      target_error = error_kinds.to_error(target)
      prediction = int(jnp.argmax(aux['logits'][index]))
      prediction_error = error_kinds.to_error(prediction)

      # Sanity-check values: the summed contributions should approximate the
      # mass actually at the raise node at the final step.
      total_contribution = jnp.sum(contribution)
      actual_value = instruction_pointer[index, -1, r_index]
      max_contributor = int(jnp.argmax(contribution))
      max_contribution = contribution[max_contributor]

      # Not all submissions are in the copy of the dataset in
      # gs://project-codenet-data. So we only visualize those that are in
      # the copy.
      if os.path.exists(python_path):
        with open(python_path, 'r') as f:
          source = f.read()
        error_lineno = codenet.get_error_lineno(problem_id, submission_id)
        raw = process.make_rawruntimeerrorproblem(
            source, target,
            target_lineno=error_lineno, problem_id=problem_id,
            submission_id=submission_id)

        # Visualize the data.
        print('---')
        print(f'Problem: {problem_id} {submission_id} ({split})')
        print(f'Batch index: {index}')
        print(f'Target: {target} ({target_error})')
        print(f'Prediction: {prediction} ({prediction_error})')
        print()
        print(source.strip() + '\n')
        print_spans(raw)
        print(contribution[:num_nodes])
        print(f'Main contributor: Node {max_contributor} ({max_contribution})')
        print(f'Total contribution: {total_contribution} (Actual: {actual_value})')

        if error_lineno:
          nodes_at_error = process.get_nodes_at_lineno(raw, error_lineno)
          print(f'Error lineno: {error_lineno} (nodes {nodes_at_error})')
          print(source.split('\n')[error_lineno - 1])  # -1 for line index.

        # Wait for the user to press enter, then continue visualizing.
        input()


if __name__ == '__main__':
  app.run(main)

0 commit comments

Comments
 (0)