@@ -46,21 +46,23 @@ void overdrive_core_loop_cpu(
4646 Tensor last_in,
4747 const Tensor last_out,
4848 Tensor output_waveform) {
49- if (waveform.dtype () == aoti_torch_dtype_float64 ()) {
49+ int32_t dtype;
50+ aoti_torch_get_dtype (waveform.get (), &dtype);
51+ if (dtype == aoti_torch_dtype_float64 ()) {
5052 overdrive_cpu_kernel<double >(
5153 Accessor<2 , double >(wave_acc),
5254 Accessor<2 , double >(temp_acc),
5355 Accessor<1 , double >(last_in_acc),
5456 Accessor<1 , double >(last_out_acc),
5557 Accessor<2 , double >(out_acc));
56- } else if (waveform. dtype () == aoti_torch_dtype_float32 ()) {
58+ } else if (dtype == aoti_torch_dtype_float32 ()) {
5759 overdrive_cpu_kernel<float >(
5860 Accessor<2 , float >(wave_acc),
5961 Accessor<2 , float >(temp_acc),
6062 Accessor<1 , float >(last_in_acc),
6163 Accessor<1 , float >(last_out_acc),
6264 Accessor<2 , float >(out_acc));
63- } else if (waveform. dtype () == aoti_torch_dtype_float16 ()) {
65+ } else if (dtype == aoti_torch_dtype_float16 ()) {
6466 overdrive_cpu_kernel<c10::Half>(
6567 Accessor<2 , c10::Half>(wave_acc),
6668 Accessor<2 , c10::Half>(temp_acc),
0 commit comments