-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdetect.py
More file actions
101 lines (82 loc) · 3.34 KB
/
detect.py
File metadata and controls
101 lines (82 loc) · 3.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# Based on https://github.com/tensorflow/examples/blob/master/lite/examples/object_detection/raspberry_pi/README.md
import re
import cv2
from tflite_runtime.interpreter import Interpreter
import numpy as np
# Nominal camera frame size in pixels (width, height).
# NOTE(review): presumably the native resolution of the ECAM6 stream — verify.
CAMERA_WIDTH, CAMERA_HEIGHT = 1292, 964
def load_labels(path="labels.txt"):
    """Load a label map from *path*.

    Supports files with or without index numbers. A line of the form
    ``"<id> <name>"`` or ``"<id>: <name>"`` maps ``<id>`` to ``<name>``. Any
    other line maps its 0-based line number to the whole stripped line, so
    multi-word names such as ``"traffic light"`` are kept intact.

    Args:
        path: path to a UTF-8 text file with one label per line.

    Returns:
        dict mapping integer class id to label string.
    """
    labels = {}
    with open(path, "r", encoding="utf-8") as f:
        for row_number, line in enumerate(f):
            content = line.strip()
            pair = re.split(r"[:\s]+", content, maxsplit=1)
            # isdecimal() (not isdigit()) guarantees int() below cannot fail.
            if len(pair) == 2 and pair[0].isdecimal():
                labels[int(pair[0])] = pair[1].strip()
            else:
                # Keep the whole line: un-indexed names may contain spaces.
                labels[row_number] = content
    return labels
def set_input_tensor(interpreter, image, *, input_mean=0.0, input_std=255.0):
    """Copy *image* into the interpreter's first input tensor.

    For a floating-point input tensor the value written is
    ``(image - input_mean) / input_std`` computed in float32. The defaults
    scale 8-bit pixels into [0, 1]. For an integer (quantized) input tensor the
    raw pixel values are written unchanged, as in the upstream TensorFlow Lite
    example.

    Args:
        interpreter: TFLite interpreter whose tensors are already allocated.
        image: pixel array already resized to the model's input height/width
            (assumed HxWxC — confirm against the model).
        input_mean: value subtracted from each pixel (float models only).
        input_std: divisor applied after subtracting ``input_mean``
            (float models only). Models that expect [-1, 1] input typically
            use ``input_mean=127.5, input_std=127.5``.
    """
    details = interpreter.get_input_details()[0]
    dtype = details["dtype"]
    # [0] drops the batch axis so the image can be written in directly.
    input_tensor = interpreter.tensor(details["index"])()[0]
    if np.issubdtype(dtype, np.floating):
        # Convert to float before any arithmetic so uint8 pixels cannot wrap.
        data = (np.asarray(image, dtype=np.float32) - input_mean) / input_std
    else:
        # NOTE(review): int8 models may also need their zero point applied.
        data = np.asarray(image)
    input_tensor[...] = data.astype(dtype, copy=False)
def get_output_tensor(interpreter, index):
    """Fetch output number *index* from *interpreter* with every size-1 axis removed."""
    output_index = interpreter.get_output_details()[index]["index"]
    raw = interpreter.get_tensor(output_index)
    return np.squeeze(raw)
def detect_objects(interpreter, image, threshold):
    """Run one inference on *image* and return detections scoring >= *threshold*.

    Assumes the output order used by the upstream TensorFlow Lite detection
    example: 0 = boxes, 1 = class ids, 2 = scores, 3 = number of detections.
    NOTE(review): some exported models order these differently — confirm
    against ``interpreter.get_output_details()``.

    Args:
        interpreter: TFLite interpreter with tensors already allocated.
        image: pixel array already resized to the model's input size.
        threshold: minimum score for a detection to be returned.

    Returns:
        list of dicts, in model output order, with keys ``"bounding_box"``
        (``ymin, xmin, ymax, xmax``, presumably normalized to [0, 1]),
        ``"class_id"`` and ``"score"``.
    """
    set_input_tensor(interpreter, image)
    interpreter.invoke()

    # Reshape/ravel rather than trusting squeeze alone: with a single
    # detection slot, squeeze collapses boxes to (4,) and scores to a scalar.
    boxes = np.reshape(get_output_tensor(interpreter, 0), (-1, 4))
    classes = np.ravel(get_output_tensor(interpreter, 1))
    scores = np.ravel(get_output_tensor(interpreter, 2))
    # Never index past the arrays the model actually returned.
    count = min(int(get_output_tensor(interpreter, 3)), len(scores))

    return [
        {
            "bounding_box": boxes[i],
            "class_id": classes[i],
            "score": scores[i],
        }
        for i in range(count)
        if scores[i] >= threshold
    ]
def main(
    *,
    model_path="detect.tflite",
    source="http://bl23i-di-serv-02.diamond.ac.uk:8080/ECAM6.mjpg.mjpg",
    threshold=0.6,
    labels_path="labels.txt",
):
    """Run the detector on a video stream and show annotated frames.

    Each frame is converted to RGB, resized to the model's input size and
    passed to ``detect_objects``. Detected boxes and labels are drawn on the
    original frame in the "PinDet" window. The loop stops when the stream
    ends or fails, or when ``q`` is pressed. The capture and windows are
    always released.

    Args:
        model_path: path to the ``.tflite`` detection model.
        source: anything ``cv2.VideoCapture`` accepts (URL, file, device).
        threshold: minimum score for a detection to be drawn.
        labels_path: label file passed to ``load_labels``.
    """
    labels = load_labels(labels_path)
    interpreter = Interpreter(model_path)
    interpreter.allocate_tensors()
    _, input_height, input_width, _ = interpreter.get_input_details()[0]["shape"]
    # cv2.resize takes (width, height); cast numpy ints to plain ints for OpenCV.
    input_size = (int(input_width), int(input_height))

    cap = cv2.VideoCapture(source)
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                # Stream ended or dropped: frame is None and cannot be processed.
                break
            # Scale boxes to the real frame size, not an assumed camera size.
            frame_height, frame_width = frame.shape[:2]
            img = cv2.resize(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), input_size)
            res = detect_objects(interpreter, img, threshold)
            print(res)
            for result in res:
                ymin, xmin, ymax, xmax = result["bounding_box"]
                xmin = int(max(0, xmin * frame_width))
                xmax = int(min(frame_width - 1, xmax * frame_width))
                ymin = int(max(0, ymin * frame_height))
                ymax = int(min(frame_height - 1, ymax * frame_height))
                class_id = int(result["class_id"])
                cv2.rectangle(frame, (xmin, ymin), (xmax, ymax), (0, 255, 0), 3)
                cv2.putText(
                    frame,
                    # Fall back to the numeric id for classes missing from the label file.
                    labels.get(class_id, str(class_id)),
                    (xmin, min(ymax, frame_height - 20)),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.5,
                    (255, 255, 255),
                    2,
                    cv2.LINE_AA,
                )
            cv2.imshow("PinDet", frame)
            if cv2.waitKey(10) & 0xFF == ord("q"):
                break
    finally:
        cap.release()
        cv2.destroyAllWindows()
if __name__ == "__main__":
    # Start the capture/display loop only when run as a script, not on import.
    main()