forked from pcyin/NL2code
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcanon_utils.py
More file actions
92 lines (67 loc) · 1.92 KB
/
canon_utils.py
File metadata and controls
92 lines (67 loc) · 1.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from file_utils import file_contents
import ast
## Common utilities ############################################################
def _read_example_bodies(path):
    """Split the example file at *path* into per-example body lines.

    The text returned by ``file_contents(path)`` (presumably the file's full
    contents -- ``file_contents`` lives in ``file_utils``) is cut into chunks
    at every "\\n\\n" separator.  Inside a chunk, whitespace-only lines are
    dropped; the first remaining line must be an ``example <id>`` header and
    the remaining lines form the body.  Whitespace-only chunks (the trailing
    separator, or runs of extra blank lines) are skipped instead of being
    treated as examples.

    Returns:
        One entry per example, in file order; each entry is the list of
        non-blank body lines that follow the header.

    Raises:
        AssertionError: if a chunk does not start with an ``example`` header.
        IndexError / ValueError: if the header has no integer second token.
    """
    bodies = []
    # NOTE(review): a body that itself contains an empty line is split into
    # two chunks by the "\n\n" separator and the second half then fails the
    # header check -- presumably the data files never contain such lines.
    for chunk in file_contents(path).split("\n\n"):
        # A real list (not filter()) so len() also works on Python 3.
        lines = [line for line in chunk.split("\n") if line.strip() != ""]
        if not lines:
            continue
        header = lines[0]
        assert header.startswith("example"), (
            "%s: expected an 'example <id>' header, got %r" % (path, header))
        # The id itself is not used, but parsing it still rejects malformed
        # headers exactly as before.
        int(header.split(" ")[1])
        bodies.append(lines[1:])
    return bodies


def fetch_queries_codes(annot_file, code_file):
    """Load the parallel natural-language queries and code snippets.

    Both files use the chunk layout parsed by ``_read_example_bodies``:
    examples separated by blank lines, each starting with an
    ``example <id>`` header line.

    Args:
        annot_file: path of the annotation file; every example must have
            exactly one non-blank line after its header (the query).
        code_file: path of the code file; every example must have at least
            one non-blank line after its header (the code, re-joined with
            newlines; blank lines inside the snippet are dropped).

    Returns:
        ``(annots, codes)``: two lists of strings in file order.  Their
        lengths are not cross-checked here.

    Raises:
        AssertionError: on a chunk violating the layout above (disabled
            under ``python -O``).
        IndexError / ValueError: on an ``example`` header without an
            integer id.
    """
    annots = []
    for body in _read_example_bodies(annot_file):
        assert len(body) == 1, (
            "%s: expected exactly one query line, got %d"
            % (annot_file, len(body)))
        annots.append(body[0])

    codes = []
    for body in _read_example_bodies(code_file):
        assert len(body) >= 1, "%s: example has no code lines" % (code_file,)
        codes.append("\n".join(body))

    return annots, codes
def parsable_code(code):
    """Return a version of *code* that ``ast.parse`` accepts, or "".

    *code* is tried as-is first.  If that fails, a tab-indented ``pass``
    line is appended and the padded text is tried instead.  The padding
    completes a snippet that ends in a bare block header such as
    ``def f():``.

    Returns:
        *code* unchanged if it parses; otherwise ``code + "\\n\\tpass"`` if
        that parses; otherwise "".
    """
    # ``except Exception`` rather than a bare ``except:`` so KeyboardInterrupt
    # and SystemExit propagate instead of being reported as "unparsable";
    # ordinary failures (SyntaxError, TypeError for a non-string, ...) are
    # still treated as "does not parse".
    try:
        ast.parse(code)
    except Exception:
        pass
    else:
        return code

    try:
        # Concatenation stays inside the try: a non-string *code* must fall
        # through to "" rather than raise.
        padded = code + "\n" + "\tpass"
        ast.parse(padded)
    except Exception:
        pass
    else:
        return padded

    # not parsable
    return ""
def _source_generator():
    """Return a callable that renders an AST back to Python source text.

    Prefers ``codegen.to_source`` (the third-party generator this module's
    ``reproducable_code`` has always referenced but never imported), loaded
    lazily here.  Falls back to the standard library's ``ast.unparse``
    (Python 3.9+).

    Raises:
        ImportError: if neither generator is available.
    """
    try:
        import codegen
        return codegen.to_source
    except Exception:
        # Missing package, or one that fails to import / lacks to_source:
        # fall back to the standard library.
        pass
    unparse = getattr(ast, "unparse", None)
    if unparse is None:
        raise ImportError(
            "reproducable_code needs the 'codegen' package "
            "or Python 3.9+ (ast.unparse)")
    return unparse


def reproducable_code(code):
    """Return a version of *code* that survives a parse/unparse round trip, or "".

    The stripped text is parsed, regenerated to source, and parsed again.
    *code* is tried as-is first; if that fails, a tab-indented ``pass``
    line is appended and the round trip is retried.

    Returns:
        *code* unchanged if it round-trips; otherwise ``code + "\\n\\tpass"``
        if that does; otherwise "".

    Raises:
        ImportError: if no AST-to-source generator is available.  The
            original code never imported ``codegen``; the NameError was
            swallowed, so every input silently returned "".
    """
    # Resolved outside the try blocks so a missing dependency is reported
    # instead of being mistaken for "not reproducable".
    to_source = _source_generator()

    def _round_trip(src):
        # parse -> regenerate source -> parse again; raises on any failure
        tree = ast.fix_missing_locations(ast.parse(src.strip()))
        ast.parse(to_source(tree))

    try:
        _round_trip(code)
    except Exception:
        pass
    else:
        return code

    try:
        # retry with a trailing ``pass`` to complete a dangling block header
        padded = code + "\n" + "\tpass"
        _round_trip(padded)
    except Exception:
        pass
    else:
        return padded

    # not reproducable
    return ""
def give_me_5(seq, desc=""):
    """Print a quick debugging preview: a banner with *desc*, then ``seq[:5]``.

    Each of the (at most) five leading elements is printed on its own line.

    Single-argument ``print(...)`` calls print exactly the same text under
    Python 2 (a print statement applied to a parenthesized expression) and
    Python 3.  The former bare print statements made this module a
    SyntaxError on Python 3.
    """
    print("\n\n")
    print(desc + " - \n\n")
    for item in seq[:5]:
        print(item)