Skip to content

Commit a00a23b

Browse files
authored
Merge pull request #711 from agnesnatasya/squeezenet
Implement Squeezenet using Squeezenet1.1
2 parents e4082c6 + 1f88218 commit a00a23b

File tree

1 file changed

+118
-0
lines changed

1 file changed

+118
-0
lines changed

examples/onnx/squeezenet.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under th
18+
19+
import os
20+
import numpy as np
21+
from PIL import Image
22+
23+
from singa import device
24+
from singa import tensor
25+
from singa import autograd
26+
from singa import sonnx
27+
import onnx
28+
from utils import download_model, update_batch_size, check_exist_or_download
29+
30+
import logging
31+
logging.basicConfig(level=logging.INFO, format='%(asctime)-15s %(message)s')
32+
33+
34+
def preprocess(img):
35+
img = img.resize((224, 224))
36+
img = img.crop((0, 0, 224, 224))
37+
img = np.array(img).astype(np.float32) / 255.
38+
img = np.rollaxis(img, 2, 0)
39+
for channel, mean, std in zip(range(3), [0.485, 0.456, 0.406],
40+
[0.229, 0.224, 0.225]):
41+
img[channel, :, :] -= mean
42+
img[channel, :, :] /= std
43+
img = np.expand_dims(img, axis=0)
44+
return img
45+
46+
47+
def get_image_label():
48+
# download label
49+
label_url = 'https://s3.amazonaws.com/onnx-model-zoo/synset.txt'
50+
with open(check_exist_or_download(label_url), 'r') as f:
51+
labels = [l.rstrip() for l in f]
52+
53+
# download image
54+
image_url = 'https://s3.amazonaws.com/model-server/inputs/kitten.jpg'
55+
img = Image.open(check_exist_or_download(image_url))
56+
return img, labels
57+
58+
59+
class Infer:
60+
61+
def __init__(self, sg_ir):
62+
self.sg_ir = sg_ir
63+
for idx, tens in sg_ir.tensor_map.items():
64+
# allow the tensors to be updated
65+
tens.requires_grad = True
66+
tens.stores_grad = True
67+
sg_ir.tensor_map[idx] = tens
68+
69+
def forward(self, x):
70+
return sg_ir.run([x])[0]
71+
72+
73+
if __name__ == "__main__":
74+
75+
url = 'https://github.com/onnx/models/raw/master/vision/classification/squeezenet/model/squeezenet1.1-7.tar.gz'
76+
download_dir = '/tmp/'
77+
model_path = os.path.join(download_dir, 'squeezenet1.1',
78+
'squeezenet1.1.onnx')
79+
80+
logging.info("onnx load model...")
81+
download_model(url)
82+
onnx_model = onnx.load(model_path)
83+
84+
# set batch size
85+
onnx_model = update_batch_size(onnx_model, 1)
86+
87+
# prepare the model
88+
logging.info("prepare model...")
89+
dev = device.create_cuda_gpu()
90+
sg_ir = sonnx.prepare(onnx_model, device=dev)
91+
autograd.training = False
92+
model = Infer(sg_ir)
93+
94+
# verify the test
95+
# from utils import load_dataset
96+
# inputs, ref_outputs = load_dataset(
97+
# os.path.join('/tmp', 'squeezenet1.1', 'test_data_set_0'))
98+
# x_batch = tensor.Tensor(device=dev, data=inputs[0])
99+
# outputs = model.forward(x_batch)
100+
# for ref_o, o in zip(ref_outputs, outputs):
101+
# np.testing.assert_almost_equal(ref_o, tensor.to_numpy(o), 4)
102+
103+
# inference
104+
logging.info("preprocessing...")
105+
img, labels = get_image_label()
106+
img = preprocess(img)
107+
108+
logging.info("model running...")
109+
x_batch = tensor.Tensor(device=dev, data=img)
110+
y = model.forward(x_batch)
111+
112+
logging.info("postprocessing...")
113+
y = tensor.softmax(y)
114+
scores = tensor.to_numpy(y)
115+
scores = np.squeeze(scores)
116+
a = np.argsort(scores)[::-1]
117+
for i in a[0:5]:
118+
logging.info('class=%s ; probability=%f' % (labels[i], scores[i]))

0 commit comments

Comments
 (0)