55#
66# -----------------------------------------------------------------------------
77
8+ import importlib
9+ import platform
10+ import sys
811from pathlib import Path
912from typing import Dict , List , Optional , Union
1013from warnings import warn
1114
1215import numpy as np
1316
14- try :
15- import qaicrt
16- except ImportError :
17- import platform
18- import sys
1917
20- sys .path .append (f"/opt/qti-aic/dev/lib/{ platform .machine ()} " )
21- import qaicrt
22-
23- try :
24- import QAicApi_pb2 as aicapi
25- except ImportError :
26- import sys
27-
28- sys .path .append ("/opt/qti-aic/dev/python" )
29- import QAicApi_pb2 as aicapi
18+ class QAICInferenceSession :
19+ _qaicrt = None
20+ _aicapi = None
3021
31- aic_to_np_dtype_mapping = {
32- aicapi .FLOAT_TYPE : np .dtype (np .float32 ),
33- aicapi .FLOAT_16_TYPE : np .dtype (np .float16 ),
34- aicapi .INT8_Q_TYPE : np .dtype (np .int8 ),
35- aicapi .UINT8_Q_TYPE : np .dtype (np .uint8 ),
36- aicapi .INT16_Q_TYPE : np .dtype (np .int16 ),
37- aicapi .INT32_Q_TYPE : np .dtype (np .int32 ),
38- aicapi .INT32_I_TYPE : np .dtype (np .int32 ),
39- aicapi .INT64_I_TYPE : np .dtype (np .int64 ),
40- aicapi .INT8_TYPE : np .dtype (np .int8 ),
41- }
22+ @property
23+ def qaicrt (self ):
24+ if QAICInferenceSession ._qaicrt is None :
25+ try :
26+ QAICInferenceSession ._qaicrt = importlib .import_module ("qaicrt" )
27+ except ImportError :
28+ sys .path .append (f"/opt/qti-aic/dev/lib/{ platform .machine ()} " )
29+ QAICInferenceSession ._qaicrt = importlib .import_module ("qaicrt" )
30+ return QAICInferenceSession ._qaicrt
4231
32+ @property
33+ def aicapi (self ):
34+ if QAICInferenceSession ._aicapi is None :
35+ try :
36+ QAICInferenceSession ._aicapi = importlib .import_module ("QAicApi_pb2" )
37+ except ImportError :
38+ sys .path .append ("/opt/qti-aic/dev/python" )
39+ QAICInferenceSession ._aicapi = importlib .import_module ("QAicApi_pb2" )
40+ return QAICInferenceSession ._aicapi
4341
44- class QAICInferenceSession :
4542 def __init__ (
4643 self ,
4744 qpc_path : Union [Path , str ],
@@ -58,59 +55,81 @@ def __init__(
5855 :activate: bool. If false, activation will be disabled. Default=True.
5956 :enable_debug_logs: bool. If True, It will enable debug logs. Default=False.
6057 """
58+
59+ # Build the dtype map one time, not on every property access
60+ self .aic_to_np_dtype_mapping = {
61+ self .aicapi .FLOAT_TYPE : np .dtype (np .float32 ),
62+ self .aicapi .FLOAT_16_TYPE : np .dtype (np .float16 ),
63+ self .aicapi .INT8_Q_TYPE : np .dtype (np .int8 ),
64+ self .aicapi .UINT8_Q_TYPE : np .dtype (np .uint8 ),
65+ self .aicapi .INT16_Q_TYPE : np .dtype (np .int16 ),
66+ self .aicapi .INT32_Q_TYPE : np .dtype (np .int32 ),
67+ self .aicapi .INT32_I_TYPE : np .dtype (np .int32 ),
68+ self .aicapi .INT64_I_TYPE : np .dtype (np .int64 ),
69+ self .aicapi .INT8_TYPE : np .dtype (np .int8 ),
70+ }
71+
6172 # Load QPC
6273 if device_ids is not None :
63- devices = qaicrt .QIDList (device_ids )
64- self .context = qaicrt .Context (devices )
65- self .queue = qaicrt .Queue (self .context , device_ids [0 ])
74+ devices = self . qaicrt .QIDList (device_ids )
75+ self .context = self . qaicrt .Context (devices )
76+ self .queue = self . qaicrt .Queue (self .context , device_ids [0 ])
6677 else :
67- self .context = qaicrt .Context ()
68- self .queue = qaicrt .Queue (self .context , 0 ) # Async API
78+ self .context = self .qaicrt .Context ()
79+ self .queue = self .qaicrt .Queue (self .context , 0 ) # Async API
80+
6981 if enable_debug_logs :
70- if self .context .setLogLevel (qaicrt .QLogLevel .QL_DEBUG ) != qaicrt .QStatus .QS_SUCCESS :
82+ if self .context .setLogLevel (self . qaicrt .QLogLevel .QL_DEBUG ) != self . qaicrt .QStatus .QS_SUCCESS :
7183 raise RuntimeError ("Failed to setLogLevel" )
72- qpc = qaicrt .Qpc (str (qpc_path ))
84+
85+ qpc = self .qaicrt .Qpc (str (qpc_path ))
86+
7387 # Load IO Descriptor
74- iodesc = aicapi .IoDesc ()
88+ iodesc = self . aicapi .IoDesc ()
7589 status , iodesc_data = qpc .getIoDescriptor ()
76- if status != qaicrt .QStatus .QS_SUCCESS :
90+ if status != self . qaicrt .QStatus .QS_SUCCESS :
7791 raise RuntimeError ("Failed to getIoDescriptor" )
7892 iodesc .ParseFromString (bytes (iodesc_data ))
93+
7994 self .allowed_shapes = [
80- [(aic_to_np_dtype_mapping [x .type ].itemsize , list (x .dims )) for x in allowed_shape .shapes ]
95+ [(self . aic_to_np_dtype_mapping [x .type ].itemsize , list (x .dims )) for x in allowed_shape .shapes ]
8196 for allowed_shape in iodesc .allowed_shapes
8297 ]
8398 self .bindings = iodesc .selected_set .bindings
8499 self .binding_index_map = {binding .name : binding .index for binding in self .bindings }
100+
85101 # Create and load Program
86- prog_properties = qaicrt .QAicProgramProperties ()
102+ prog_properties = self . qaicrt .QAicProgramProperties ()
87103 prog_properties .SubmitRetryTimeoutMs = 60_000
88104 if device_ids and len (device_ids ) > 1 :
89105 prog_properties .devMapping = ":" .join (map (str , device_ids ))
90- self .program = qaicrt .Program (self .context , None , qpc , prog_properties )
91- if self .program .load () != qaicrt .QStatus .QS_SUCCESS :
106+
107+ self .program = self .qaicrt .Program (self .context , None , qpc , prog_properties )
108+ if self .program .load () != self .qaicrt .QStatus .QS_SUCCESS :
92109 raise RuntimeError ("Failed to load program" )
110+
93111 if activate :
94112 self .activate ()
113+
95114 # Create input qbuffers and buf_dims
96- self .qbuffers = [qaicrt .QBuffer (bytes (binding .size )) for binding in self .bindings ]
97- self .buf_dims = qaicrt .BufferDimensionsVecRef (
98- [(aic_to_np_dtype_mapping [binding .type ].itemsize , list (binding .dims )) for binding in self .bindings ]
115+ self .qbuffers = [self . qaicrt .QBuffer (bytes (binding .size )) for binding in self .bindings ]
116+ self .buf_dims = self . qaicrt .BufferDimensionsVecRef (
117+ [(self . aic_to_np_dtype_mapping [binding .type ].itemsize , list (binding .dims )) for binding in self .bindings ]
99118 )
100119
101120 @property
102121 def input_names (self ) -> List [str ]:
103- return [binding .name for binding in self .bindings if binding .dir == aicapi .BUFFER_IO_TYPE_INPUT ]
122+ return [binding .name for binding in self .bindings if binding .dir == self . aicapi .BUFFER_IO_TYPE_INPUT ]
104123
105124 @property
106125 def output_names (self ) -> List [str ]:
107- return [binding .name for binding in self .bindings if binding .dir == aicapi .BUFFER_IO_TYPE_OUTPUT ]
126+ return [binding .name for binding in self .bindings if binding .dir == self . aicapi .BUFFER_IO_TYPE_OUTPUT ]
108127
109128 def activate (self ):
110129 """Activate qpc"""
111130
112131 self .program .activate ()
113- self .execObj = qaicrt .ExecObj (self .context , self .program )
132+ self .execObj = self . qaicrt .ExecObj (self .context , self .program )
114133
115134 def deactivate (self ):
116135 """Deactivate qpc"""
@@ -131,7 +150,7 @@ def set_buffers(self, buffers: Dict[str, np.ndarray]):
131150 warn (f'Buffer: "{ buffer_name } " not found' )
132151 continue
133152 buffer_index = self .binding_index_map [buffer_name ]
134- self .qbuffers [buffer_index ] = qaicrt .QBuffer (buffer .tobytes ())
153+ self .qbuffers [buffer_index ] = self . qaicrt .QBuffer (buffer .tobytes ())
135154 self .buf_dims [buffer_index ] = (
136155 buffer .itemsize ,
137156 buffer .shape if len (buffer .shape ) > 0 else (1 ,),
@@ -157,21 +176,19 @@ def run(self, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
157176 Return:
158177 :Dict[str, np.ndarray]:
159178 """
160- # Set inputs
179+
161180 self .set_buffers (inputs )
162- if self .execObj .setData (self .qbuffers , self .buf_dims ) != qaicrt .QStatus .QS_SUCCESS :
181+ if self .execObj .setData (self .qbuffers , self .buf_dims ) != self . qaicrt .QStatus .QS_SUCCESS :
163182 raise MemoryError ("Failed to setData" )
164- # # Run with sync API
165- # if self.execObj.run(self.qbuffers) != qaicrt.QStatus.QS_SUCCESS:
166- # Run with async API
167- if self .queue .enqueue (self .execObj ) != qaicrt .QStatus .QS_SUCCESS :
183+
184+ if self .queue .enqueue (self .execObj ) != self .qaicrt .QStatus .QS_SUCCESS :
168185 raise MemoryError ("Failed to enqueue" )
169- if self .execObj .waitForCompletion () != qaicrt .QStatus .QS_SUCCESS :
186+
187+ if self .execObj .waitForCompletion () != self .qaicrt .QStatus .QS_SUCCESS :
170188 error_message = "Failed to run"
171- # Print additional error messages for unmatched dimension error
189+
172190 if self .allowed_shapes :
173- error_message += "\n \n "
174- error_message += '(Only if "No matching dimension found" error is present above)'
191+ error_message += "\n \n (Only if 'No matching dimension found' error is present above)"
175192 error_message += "\n Allowed shapes:"
176193 for i , allowed_shape in enumerate (self .allowed_shapes ):
177194 error_message += f"\n { i } \n "
@@ -189,18 +206,18 @@ def run(self, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
189206 continue
190207 error_message += f"{ binding .name } :\t { elemsize } \t { shape } \n "
191208 raise ValueError (error_message )
192- # Get output buffers
209+
193210 status , output_qbuffers = self .execObj .getData ()
194- if status != qaicrt .QStatus .QS_SUCCESS :
211+ if status != self . qaicrt .QStatus .QS_SUCCESS :
195212 raise MemoryError ("Failed to getData" )
196- # Build output
213+
197214 outputs = {}
198215 for output_name in self .output_names :
199216 buffer_index = self .binding_index_map [output_name ]
200217 if self .qbuffers [buffer_index ].size == 0 :
201218 continue
202219 outputs [output_name ] = np .frombuffer (
203220 bytes (output_qbuffers [buffer_index ]),
204- aic_to_np_dtype_mapping [self .bindings [buffer_index ].type ],
221+ self . aic_to_np_dtype_mapping [self .bindings [buffer_index ].type ],
205222 ).reshape (self .buf_dims [buffer_index ][1 ])
206223 return outputs
0 commit comments