forked from pcyin/NL2code
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtype_simulation.py
More file actions
398 lines (290 loc) · 9.66 KB
/
type_simulation.py
File metadata and controls
398 lines (290 loc) · 9.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
from pytokens import PYKEYWORDS
import ast
import re
## TYPES #######################################################################
# Coarse type labels that guess_types assigns to referencable tokens.
# Lists and tuples are deliberately collapsed into the single "seq" label.
TYPES = {
    "seq": "SEQ",
    "dict": "DICT",
    "class": "CLASS",
    "func": "FUNC",
    "gen": "GEN",            # can be either string or num
    "any": "ANY",            # can't determine
    "num": "NUM",
    "str": "STR",
    "funcdef": "FUNCDEF",
    "classdef": "CLASSDEF",
}
## AST NodeVisitors ############################################################
class CallVisitor(ast.NodeVisitor):
    """Type names that appear in call position.

    ``f(...)`` marks ``f`` as TYPES["func"]; ``obj.m(...)`` marks ``obj`` as
    TYPES["class"] and ``m`` as TYPES["func"]. Names already present in
    ``self.d`` and names in PYKEYWORDS are left untouched.
    """

    def __init__(self, d_=None):
        # d_: name -> type mapping that is filled in place. A fresh dict is
        # created per instance when omitted (avoids a shared mutable default).
        self.d = {} if d_ is None else d_

    def _set_if_unknown(self, identifier, type_label):
        """Record *type_label* for *identifier* unless already typed or a keyword."""
        if self.d.get(identifier) is None and identifier not in PYKEYWORDS:
            self.d[identifier] = type_label

    def visit_Call(self, node):
        func = node.func
        if isinstance(func, ast.Name):
            self._set_if_unknown(func.id, TYPES["func"])
        elif isinstance(func, ast.Attribute):
            # The receiver of a method call (``obj`` in ``obj.m()``) is
            # treated as a class/object. The old code tested ``node.func``
            # here instead of ``node.func.value``, so this never fired.
            if isinstance(func.value, ast.Name):
                self._set_if_unknown(func.value.id, TYPES["class"])
            # the deepest attribute is the function being called
            self._set_if_unknown(func.attr, TYPES["func"])
        # Keep walking so calls nested in arguments (``f(g(x))``) are seen too.
        self.generic_visit(node)
class AttributeVisitor(ast.NodeVisitor):
    """Type the names involved in attribute access ``a.b``.

    ``a`` (a plain name) becomes TYPES["class"]. ``b`` becomes TYPES["class"]
    when it is itself the object of another attribute access (its parent node
    is an ``ast.Attribute``, e.g. ``b`` in ``a.b.c``), otherwise TYPES["any"].
    Names already present in ``self.d`` and names in PYKEYWORDS are skipped.
    Reads the ``parent`` attribute that guess_types attaches to every node;
    when it is missing the attribute is treated as a leaf.
    """

    def __init__(self, d_=None):
        # d_: name -> type mapping that is filled in place; fresh dict if omitted.
        self.d = {} if d_ is None else d_

    def visit_Attribute(self, node):
        # Visit children first: the inner ``a.b`` of ``a.b.c`` is typed before
        # ``c`` (as the old explicit recursion did), and attributes nested in
        # calls or subscripts (``f(x.y).z``) are now reached as well.
        self.generic_visit(node)
        if isinstance(node.value, ast.Name):
            owner = node.value.id
            if self.d.get(owner) is None and owner not in PYKEYWORDS:
                self.d[owner] = TYPES["class"]
        member = node.attr
        if self.d.get(member) is None and member not in PYKEYWORDS:
            if isinstance(getattr(node, "parent", None), ast.Attribute):
                # further attributes hang off this one: treat it as a class
                self.d[member] = TYPES["class"]
            else:
                # likely a variable; can easily be a num or str
                self.d[member] = TYPES["any"]
class SubscriptVisitor(ast.NodeVisitor):
    """Guess containers from literal subscripts.

    ``x[<number>]`` marks ``x`` as TYPES["seq"] (likely a list/tuple) and
    ``x[<string>]`` marks it as TYPES["dict"]. Only plain names are typed;
    names already present in ``self.d`` and names in PYKEYWORDS are skipped.
    """

    def __init__(self, d_=None):
        # d_: name -> type mapping that is filled in place; fresh dict if omitted.
        self.d = {} if d_ is None else d_

    @staticmethod
    def _literal_kind(key):
        """Return "num" / "str" if *key* is a numeric / string literal, else None.

        Dispatches on the node class name so it works with the Python 2 AST
        (``Index``-wrapped ``Num``/``Str``) and the Python 3.8+/3.9+ AST
        (bare ``Constant``) without touching the deprecated ``ast.Num`` names.
        """
        node_type = type(key).__name__
        if node_type == "Index":  # Python < 3.9 wraps simple subscripts
            key = key.value
            node_type = type(key).__name__
        if node_type == "Num":
            return "num"
        if node_type == "Str":
            return "str"
        if node_type == "Constant":
            value = key.value
            if isinstance(value, bool):  # ast.Num never matched bools
                return None
            if isinstance(value, (int, float, complex)):
                return "num"
            if isinstance(value, str):
                return "str"
        return None

    def visit_Subscript(self, node):
        if isinstance(node.value, ast.Name):
            identifier = node.value.id
            # if the identifier is already typed (or a keyword), move on
            if identifier not in PYKEYWORDS and self.d.get(identifier) is None:
                kind = self._literal_kind(node.slice)
                if kind == "num":
                    self.d[identifier] = TYPES["seq"]   # likely a list
                elif kind == "str":
                    self.d[identifier] = TYPES["dict"]  # likely a dict
        # Reach nested subscripts such as ``a[0][1]`` or ``a[b['k']]``.
        self.generic_visit(node)
class FunctionDefVisitor(ast.NodeVisitor):
    """Mark the name of every ``def`` (including nested ones) as TYPES["funcdef"].

    Overwrites any type previously recorded for that name.
    """

    def __init__(self, d_=None):
        # d_: name -> type mapping that is filled in place; fresh dict if omitted.
        self.d = {} if d_ is None else d_

    def visit_FunctionDef(self, node):
        self.d[node.name] = TYPES["funcdef"]
        # keep walking so functions defined inside this one are recorded too
        self.generic_visit(node)

    # ``async def`` (Python 3.5+) is a function definition as well; this
    # handler is simply never dispatched on Python 2.
    visit_AsyncFunctionDef = visit_FunctionDef
class ClassDefVisitor(ast.NodeVisitor):
    """Mark the name of every ``class`` (including nested ones) as TYPES["classdef"].

    Overwrites any type previously recorded for that name.
    """

    def __init__(self, d_=None):
        # d_: name -> type mapping that is filled in place; fresh dict if omitted.
        self.d = {} if d_ is None else d_

    def visit_ClassDef(self, node):
        self.d[node.name] = TYPES["classdef"]
        # keep walking so classes nested in this one (or in its methods) count
        self.generic_visit(node)
class KWVisitor(ast.NodeVisitor):
    """Mark keyword-argument names in calls (``f(name=...)``) as TYPES["any"].

    Overwrites any type previously recorded for that name; names in
    PYKEYWORDS are skipped.
    """

    def __init__(self, d_=None):
        # d_: name -> type mapping that is filled in place; fresh dict if omitted.
        self.d = {} if d_ is None else d_

    def visit_keyword(self, node):
        argname = node.arg
        # Python 3 represents ``f(**kwargs)`` as a keyword whose arg is None;
        # that must not become a ``None`` entry in the type map.
        if argname is not None and argname not in PYKEYWORDS:
            self.d[argname] = TYPES["any"]  # really could be anything
        # the keyword's value may itself contain calls with keywords
        self.generic_visit(node)
class StrVisitor(ast.NodeVisitor):
    """Mark every string literal's contents as TYPES["str"].

    The literal text itself is the key. Overwrites earlier entries;
    strings contained in PYKEYWORDS are skipped.
    """

    def __init__(self, d_=None):
        # d_: name -> type mapping that is filled in place; fresh dict if omitted.
        self.d = {} if d_ is None else d_

    def _record(self, text):
        if text not in PYKEYWORDS:
            self.d[text] = TYPES["str"]

    def visit_Str(self, node):
        # Python 2 / Python 3 < 3.8 parsers produce ast.Str nodes.
        self._record(node.s)

    def visit_Constant(self, node):
        # Python 3.8+ parses every literal as ast.Constant; the implicit
        # dispatch to visit_Str is deprecated, so handle strings directly.
        if isinstance(node.value, str):
            self._record(node.value)
class NumVisitor(ast.NodeVisitor):
    """Mark every numeric literal's value as TYPES["num"] (overwrites earlier entries)."""

    def __init__(self, d_=None):
        # d_: name -> type mapping that is filled in place; fresh dict if omitted.
        self.d = {} if d_ is None else d_

    def visit_Num(self, node):
        # Python 2 / Python 3 < 3.8 parsers produce ast.Num nodes.
        self.d[node.n] = TYPES["num"]

    def visit_Constant(self, node):
        # Python 3.8+ parses literals as ast.Constant. Match what ast.Num
        # covered (int/float/complex) and exclude bools, which are ints.
        value = node.value
        if isinstance(value, (int, float, complex)) and not isinstance(value, bool):
            self.d[value] = TYPES["num"]
class AliasVisitor(ast.NodeVisitor):
    """Mark imported names (the ``x`` of ``import x`` / ``from m import x``) as TYPES["str"].

    Only ``alias.name`` is recorded, not the ``as`` name. Overwrites earlier
    entries; names in PYKEYWORDS are skipped.
    """

    def __init__(self, d_=None):
        # d_: name -> type mapping that is filled in place; fresh dict if omitted.
        self.d = {} if d_ is None else d_

    def visit_alias(self, node):
        name = node.name
        if name not in PYKEYWORDS:
            self.d[name] = TYPES["str"]
## Helpers and exports #########################################################
def get_unique_ref_types(ref_types):
    """Return a new dict giving each reference a per-type numbered label.

    For example, if "a" and "b" both have type "ANY", they map to "ANY_0" and
    "ANY_1". Numbering follows the iteration order of *ref_types*. The input
    dict is not modified.

    Args:
        ref_types: mapping of reference -> type label (a string).

    Returns:
        dict mapping each reference to ``"<type>_<k>"``.
    """
    counts = {}
    unique = {}
    # items() instead of the Python 2-only iteritems(): works on 2 and 3
    for ref, type_label in ref_types.items():
        index = counts.get(type_label, -1) + 1  # first occurrence gets 0
        counts[type_label] = index
        unique[ref] = type_label + "_" + str(index)
    return unique
def get_referencables(code):
    """Return the set of tokens in *code* that a query could refer to.

    Collected tokens are:
      * identifiers (``Name`` ids and, on Python 3, parameter names);
      * string/bytes literal contents;
      * numeric literal values (never keyword-filtered);
      * imported names and their ``as`` aliases.
    Identifier and string tokens found in PYKEYWORDS are skipped.

    Raises:
        SyntaxError: if *code* cannot be parsed (from ``ast.parse``).
    """
    root = ast.parse(code)
    refable = set()

    def add_unless_keyword(token):
        if token not in PYKEYWORDS:
            refable.add(token)

    for node in ast.walk(root):
        # Dispatch on the node class name so one code path handles the
        # Python 2 AST (Str/Num) and the Python 3.8+ AST (Constant). Probing
        # ``.s``/``.n`` on Constant nodes (deprecated aliases for ``.value``)
        # used to pull in None/True/False and bypass the keyword filter.
        node_type = type(node).__name__
        if node_type == "Name":
            # can definitely reference an identifier from the query
            add_unless_keyword(node.id)
        elif node_type == "arg":
            # Python 3 parameter names (Python 2 used Name nodes for these)
            add_unless_keyword(node.arg)
        elif node_type in ("Str", "Bytes"):
            add_unless_keyword(node.s)
        elif node_type == "Num":
            refable.add(node.n)
        elif node_type == "Constant":
            value = node.value
            if isinstance(value, (str, bytes)):
                add_unless_keyword(value)
            elif isinstance(value, (int, float, complex)) and not isinstance(value, bool):
                refable.add(value)
        elif node_type == "alias":
            add_unless_keyword(node.name)
            if node.asname is not None:
                add_unless_keyword(node.asname)
    return refable
def unique_ref_types(ref_types):
    """Rewrite every type label in *ref_types* to a per-type numbered label, in place.

    Two references typed "DICT" become "DICT_0" and "DICT_1", numbered in
    the dict's iteration order.

    Args:
        ref_types: mapping of reference -> type label (a string); mutated.

    Returns:
        The same dict object, for convenience.
    """
    type_counter = {}
    # Iterate over a snapshot because values are rewritten during the loop;
    # items() replaces the Python 2-only iteritems().
    for ref, type_label in list(ref_types.items()):
        index = type_counter.get(type_label, -1) + 1  # first occurrence gets 0
        type_counter[type_label] = index
        ref_types[ref] = type_label + "_" + str(index)
    return ref_types
def guess_types(code, query):
    """Guess a coarse, uniquely numbered type label for every token in *code*.

    Args:
        code: Python source to analyse; it must be parsable by ``ast.parse``.
        query: natural-language text paired with *code*. Spans quoted with
            double quotes, single quotes or backticks are typed as strings.

    Returns:
        dict mapping each identifier/literal to a label such as ``"FUNC_0"``
        or ``"ANY_2"`` (type name plus a per-type counter).

    Raises:
        SyntaxError: if *code* is not parsable (raised by ``ast.parse``).
    """
    # Parse once up front. Previously get_referencables() parsed first, so an
    # unparsable snippet already raised SyntaxError there. The old
    # ``assert (False and "...")`` fallback was unreachable, lost its message
    # and would vanish under ``python -O``.
    root = ast.parse(code)
    referencables = get_referencables(code)

    # Attach a ``parent`` attribute (a plain strong reference) to every node.
    # AttributeVisitor uses it to tell ``b`` in ``a.b.c`` from a leaf attribute.
    for node in ast.walk(root):
        for child in ast.iter_child_nodes(node):
            child.parent = node

    # Run the visitors in a fixed order over one shared dict. Subscript, Call
    # and Attribute only fill names that are still unknown. The later
    # visitors overwrite, so literals/definitions/keywords take precedence.
    ref_types = {}
    for visitor_cls in (SubscriptVisitor, CallVisitor, AttributeVisitor,
                        StrVisitor, NumVisitor, AliasVisitor,
                        FunctionDefVisitor, ClassDefVisitor, KWVisitor):
        visitor = visitor_cls(ref_types)
        visitor.visit(root)
        ref_types = visitor.d

    # Quoted spans in the query ("...", '...', `...`) are strings unless
    # some visitor already typed them.
    for pattern in (r'\"(.+?)\"', r'\'(.+?)\'', r'\`(.+?)\`'):
        for match in re.findall(pattern, query):
            if match not in ref_types:
                ref_types[match] = TYPES["str"]

    # Every referencable that is still untyped is ANY. Iterate in sorted
    # order so the ANY_<k> numbering does not depend on set/hash ordering.
    for refable in sorted(referencables, key=repr):
        if refable not in ref_types:
            ref_types[refable] = TYPES["any"]

    # make types unique: two DICT entries become DICT_0 and DICT_1
    return unique_ref_types(ref_types)