|
14 | 14 | # See the License for the specific language governing permissions and |
15 | 15 | # limitations under the License. |
16 | 16 |
|
17 | | -import numbers |
18 | | - |
19 | | -import numpy as np |
20 | | - |
21 | 17 | import dpctl |
22 | | -import dpctl.memory as dpm |
23 | 18 | import dpctl.tensor as dpt |
24 | 19 | import dpctl.tensor._tensor_impl as ti |
25 | 20 | from dpctl.tensor._manipulation_functions import _broadcast_shape_impl |
26 | | -from dpctl.tensor._usmarray import _is_object_with_buffer_protocol as _is_buffer |
27 | 21 | from dpctl.utils import ExecutionPlacementError, SequentialOrderManager |
28 | 22 |
|
29 | 23 | from ._copy_utils import _empty_like_orderK, _empty_like_pair_orderK |
| 24 | +from ._scalar_utils import ( |
| 25 | + _get_dtype, |
| 26 | + _get_queue_usm_type, |
| 27 | + _get_shape, |
| 28 | + _validate_dtype, |
| 29 | +) |
30 | 30 | from ._type_utils import ( |
31 | | - WeakBooleanType, |
32 | | - WeakComplexType, |
33 | | - WeakFloatingType, |
34 | | - WeakIntegralType, |
35 | 31 | _acceptance_fn_default_binary, |
36 | 32 | _acceptance_fn_default_unary, |
37 | 33 | _all_data_types, |
38 | 34 | _find_buf_dtype, |
39 | 35 | _find_buf_dtype2, |
40 | 36 | _find_buf_dtype_in_place_op, |
41 | 37 | _resolve_weak_types, |
42 | | - _to_device_supported_dtype, |
43 | 38 | ) |
44 | 39 |
|
45 | 40 |
|
@@ -289,78 +284,6 @@ def __call__(self, x, /, *, out=None, order="K"): |
289 | 284 | return out |
290 | 285 |
|
291 | 286 |
|
292 | | -def _get_queue_usm_type(o): |
293 | | - """Return SYCL device where object `o` allocated memory, or None.""" |
294 | | - if isinstance(o, dpt.usm_ndarray): |
295 | | - return o.sycl_queue, o.usm_type |
296 | | - elif hasattr(o, "__sycl_usm_array_interface__"): |
297 | | - try: |
298 | | - m = dpm.as_usm_memory(o) |
299 | | - return m.sycl_queue, m.get_usm_type() |
300 | | - except Exception: |
301 | | - return None, None |
302 | | - return None, None |
303 | | - |
304 | | - |
305 | | -def _get_dtype(o, dev): |
306 | | - if isinstance(o, dpt.usm_ndarray): |
307 | | - return o.dtype |
308 | | - if hasattr(o, "__sycl_usm_array_interface__"): |
309 | | - return dpt.asarray(o).dtype |
310 | | - if _is_buffer(o): |
311 | | - host_dt = np.array(o).dtype |
312 | | - dev_dt = _to_device_supported_dtype(host_dt, dev) |
313 | | - return dev_dt |
314 | | - if hasattr(o, "dtype"): |
315 | | - dev_dt = _to_device_supported_dtype(o.dtype, dev) |
316 | | - return dev_dt |
317 | | - if isinstance(o, bool): |
318 | | - return WeakBooleanType(o) |
319 | | - if isinstance(o, int): |
320 | | - return WeakIntegralType(o) |
321 | | - if isinstance(o, float): |
322 | | - return WeakFloatingType(o) |
323 | | - if isinstance(o, complex): |
324 | | - return WeakComplexType(o) |
325 | | - return np.object_ |
326 | | - |
327 | | - |
328 | | -def _validate_dtype(dt) -> bool: |
329 | | - return isinstance( |
330 | | - dt, |
331 | | - (WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType), |
332 | | - ) or ( |
333 | | - isinstance(dt, dpt.dtype) |
334 | | - and dt |
335 | | - in [ |
336 | | - dpt.bool, |
337 | | - dpt.int8, |
338 | | - dpt.uint8, |
339 | | - dpt.int16, |
340 | | - dpt.uint16, |
341 | | - dpt.int32, |
342 | | - dpt.uint32, |
343 | | - dpt.int64, |
344 | | - dpt.uint64, |
345 | | - dpt.float16, |
346 | | - dpt.float32, |
347 | | - dpt.float64, |
348 | | - dpt.complex64, |
349 | | - dpt.complex128, |
350 | | - ] |
351 | | - ) |
352 | | - |
353 | | - |
354 | | -def _get_shape(o): |
355 | | - if isinstance(o, dpt.usm_ndarray): |
356 | | - return o.shape |
357 | | - if _is_buffer(o): |
358 | | - return memoryview(o).shape |
359 | | - if isinstance(o, numbers.Number): |
360 | | - return tuple() |
361 | | - return getattr(o, "shape", tuple()) |
362 | | - |
363 | | - |
364 | 287 | class BinaryElementwiseFunc: |
365 | 288 | """ |
366 | 289 | Class that implements binary element-wise functions. |
|
0 commit comments