-
Notifications
You must be signed in to change notification settings - Fork 19
Expand file tree
/
Copy pathtrytheano.py
More file actions
70 lines (60 loc) · 1.78 KB
/
trytheano.py
File metadata and controls
70 lines (60 loc) · 1.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# (C) William W. Cohen and Carnegie Mellon University, 2016
import theano
import theano.tensor as T
import theano.sparse as S
import theano.sparse.basic as B
from . import matrixdb
import numpy
def debugVar(v, depth=0, maxdepth=10):
    """Print a Theano variable and, recursively, the graph that defines it.

    The variable is printed on one line, prefixed by ``'| '`` repeated
    ``depth + 1`` times, showing its name, its Python type and the
    ``theano.pp`` rendering of its definition.  Every node returned by
    ``v.get_parents()`` is then printed one level deeper via ``debugApply``.
    Once ``depth`` exceeds ``maxdepth`` only ``...`` is printed and the
    recursion stops.

    No visited set is kept, so a sub-graph reachable along several paths is
    printed once per path.

    :param v: the Theano variable to print.
    :param depth: current nesting level (0 for the root call).
    :param maxdepth: deepest level that is printed in full.
    """
    if depth > maxdepth:
        print('...')
        return
    indent = '| ' * (depth + 1)
    print(indent, end=' ')
    print('var: name', v.name, 'type', type(v), 'def', theano.pp(v))
    child_depth = depth + 1
    for parent in v.get_parents():
        debugApply(parent, depth=child_depth, maxdepth=maxdepth)
def debugApply(a, depth=0, maxdepth=10):
    """Print a Theano Apply node and, recursively, the variables feeding it.

    The node is printed on one line, prefixed by ``'| '`` repeated
    ``depth + 1`` times, showing the node itself, the Python type of its
    ``op`` and the list of Python types of its ``outputs``.  Every entry of
    ``a.inputs`` is then printed one level deeper via ``debugVar``.  Once
    ``depth`` exceeds ``maxdepth`` only ``...`` is printed and the recursion
    stops.

    :param a: the Apply node to print.
    :param depth: current nesting level (0 for the root call).
    :param maxdepth: deepest level that is printed in full.
    """
    if depth > maxdepth:
        print('...')
        return
    print('| ' * (depth + 1), end=' ')
    # Evaluate a.op before a.outputs, matching the original argument order.
    op_type = type(a.op)
    output_types = [type(out) for out in a.outputs]
    print('apply: ', a, 'op', op_type, 'output types', output_types)
    for inp in a.inputs:
        debugVar(inp, depth=depth + 1, maxdepth=maxdepth)
if __name__ == "__main__":
    # Scratch experiments with Theano sparse ops on one-hot vectors taken
    # from a MatrixDB.
    # NOTE(review): this module uses a package-relative import
    # (``from . import matrixdb``), so it presumably has to be run as a
    # module (``python -m <package>.trytheano``), not as a plain script --
    # TODO confirm.  The fact-file path below is relative to the current
    # working directory.
    db = matrixdb.MatrixDB.loadFile("test/fam.cfacts")
    va = db.onehot('william')
    vb = db.onehot('sarah')
    print('a', va)
    print('b', vb)
    print('shape', va.shape)

    # f1: x is a symbolic CSR matrix input; r1 is B.sp_sum over x+x+x with
    # no axis given, and the result s scales x by r1.
    print('f1: s = x*((x+x)+x)')
    tx = S.csr_matrix('x')
    r1 = B.sp_sum(tx + tx + tx, sparse_grad=True)
    s = tx * r1
    s.name = 's'
    f1 = theano.function(inputs=[tx], outputs=[s])
    # outputs was given as a one-element list, so w[0] is the value of s.
    w = f1(va)
    print(w[0])
    debugVar(s)
    # print(db.rowAsSymbolDict(w[0]))

    # Disabled experiments, kept for reference.  They are written in
    # Python 3 syntax so they can be re-enabled simply by uncommenting.
    #
    # f2: both the weighted vector w and the multiplier c are inputs.
    # print('f2(w=a,c=b)')
    # tw = S.csr_matrix('w')  # weighter
    # tc = S.csr_matrix('c')  # constant
    # r2 = B.sp_sum(tw*1.7, sparse_grad=True)
    # s2 = tc*r2
    # f2 = theano.function(inputs=[tw, tc], outputs=[s2])
    # w = f2(va, vb)
    # print(w[0])
    #
    # f3: like f2, but c is built from vb as a constant instead of an input.
    # Its banner and input variable live here too, so the script no longer
    # announces (or builds state for) a test that does not run.
    # print('f3(w=a), b constant')
    # tw3 = S.csr_matrix('w')  # weighter
    # # y = sparse.CSR(data, indices, indptr, shape)
    # tc3 = S.CSR(vb.data, vb.indices, vb.indptr, vb.shape)
    # r3 = B.sp_sum(tw3*1.7, sparse_grad=True)
    # s3 = tc3*r3
    # f3 = theano.function(inputs=[tw3], outputs=[s3])
    # w = f3(va)
    # print(w[0])
    # debugVar(tw3, maxdepth=5)