66// Must be called with at least `max(stride_b * num_rows, krylov_dim *
77// num_cols)` threads in total.
88template <size_type block_size, typename ValueType>
9- void initialize_kernel (
10- size_type num_rows, size_type num_cols, size_type krylov_dim,
11- const ValueType *__restrict__ b, size_type stride_b,
12- ValueType *__restrict__ residual, size_type stride_residual,
13- ValueType *__restrict__ givens_sin, size_type stride_sin,
14- ValueType *__restrict__ givens_cos, size_type stride_cos,
15- stopping_status *__restrict__ stop_status, sycl::nd_item<3 > item_ct1)
9+ void initialize_kernel (size_type num_rows, size_type num_cols,
10+ size_type krylov_dim, const ValueType* __restrict__ b,
11+ size_type stride_b, ValueType* __restrict__ residual,
12+ size_type stride_residual,
13+ ValueType* __restrict__ givens_sin, size_type stride_sin,
14+ ValueType* __restrict__ givens_cos, size_type stride_cos,
15+ stopping_status* __restrict__ stop_status,
16+ sycl::nd_item<3 > item_ct1)
1617{
1718 const auto global_id = thread::get_thread_id_flat (item_ct1);
1819
@@ -39,15 +40,15 @@ void initialize_kernel(
3940
4041template <size_type block_size, typename ValueType>
4142void initialize_kernel (dim3 grid, dim3 block, size_type dynamic_shared_memory,
42- sycl::queue * queue, size_type num_rows,
43- size_type num_cols, size_type krylov_dim,
44- const ValueType * b, size_type stride_b,
45- ValueType * residual, size_type stride_residual,
46- ValueType * givens_sin, size_type stride_sin,
47- ValueType * givens_cos, size_type stride_cos,
48- stopping_status * stop_status)
43+ sycl::queue* queue, size_type num_rows,
44+ size_type num_cols, size_type krylov_dim,
45+ const ValueType* b, size_type stride_b,
46+ ValueType* residual, size_type stride_residual,
47+ ValueType* givens_sin, size_type stride_sin,
48+ ValueType* givens_cos, size_type stride_cos,
49+ stopping_status* stop_status)
4950{
50- queue->submit ([&](sycl::handler & cgh) {
51+ queue->submit ([&](sycl::handler& cgh) {
5152 cgh.parallel_for (
5253 sycl_nd_range (grid, block), [=](sycl::nd_item<3 > item_ct1) {
5354 initialize_kernel<block_size>(
@@ -61,12 +62,12 @@ void initialize_kernel(dim3 grid, dim3 block, size_type dynamic_shared_memory,
6162
6263template <typename ValueType>
6364void calculate_sin_and_cos_kernel (size_type col_idx, size_type num_cols,
64- size_type iter, const ValueType & this_hess,
65- const ValueType & next_hess,
66- ValueType * givens_sin, size_type stride_sin,
67- ValueType * givens_cos, size_type stride_cos,
68- ValueType & register_sin,
69- ValueType & register_cos)
65+ size_type iter, const ValueType& this_hess,
66+ const ValueType& next_hess,
67+ ValueType* givens_sin, size_type stride_sin,
68+ ValueType* givens_cos, size_type stride_cos,
69+ ValueType& register_sin,
70+ ValueType& register_cos)
7071{
7172 if (is_zero (this_hess)) {
7273 register_cos = zero<ValueType>();
@@ -89,10 +90,10 @@ void calculate_sin_and_cos_kernel(size_type col_idx, size_type num_cols,
8990template <typename ValueType>
9091void calculate_residual_norm_kernel (size_type col_idx, size_type num_cols,
9192 size_type iter,
92- const ValueType & register_sin,
93- const ValueType & register_cos,
94- remove_complex<ValueType> * residual_norm,
95- ValueType * residual_norm_collection,
93+ const ValueType& register_sin,
94+ const ValueType& register_cos,
95+ remove_complex<ValueType>* residual_norm,
96+ ValueType* residual_norm_collection,
9697 size_type stride_residual_norm_collection)
9798{
9899 const auto this_rnc =
@@ -112,13 +113,13 @@ void calculate_residual_norm_kernel(size_type col_idx, size_type num_cols,
112113template <size_type block_size, typename ValueType>
113114void givens_rotation_kernel (
114115 size_type num_rows, size_type num_cols, size_type iter,
115- ValueType * __restrict__ hessenberg_iter, size_type stride_hessenberg,
116- ValueType * __restrict__ givens_sin, size_type stride_sin,
117- ValueType * __restrict__ givens_cos, size_type stride_cos,
118- remove_complex<ValueType> * __restrict__ residual_norm,
119- ValueType * __restrict__ residual_norm_collection,
116+ ValueType* __restrict__ hessenberg_iter, size_type stride_hessenberg,
117+ ValueType* __restrict__ givens_sin, size_type stride_sin,
118+ ValueType* __restrict__ givens_cos, size_type stride_cos,
119+ remove_complex<ValueType>* __restrict__ residual_norm,
120+ ValueType* __restrict__ residual_norm_collection,
120121 size_type stride_residual_norm_collection,
121- const stopping_status * __restrict__ stop_status, sycl::nd_item<3 > item_ct1)
122+ const stopping_status* __restrict__ stop_status, sycl::nd_item<3 > item_ct1)
122123{
123124 const auto col_idx = thread::get_thread_id_flat (item_ct1);
124125
@@ -167,18 +168,18 @@ void givens_rotation_kernel(
167168
168169template <size_type block_size, typename ValueType>
169170void givens_rotation_kernel (dim3 grid, dim3 block,
170- size_type dynamic_shared_memory, sycl::queue * queue,
171+ size_type dynamic_shared_memory, sycl::queue* queue,
171172 size_type num_rows, size_type num_cols,
172- size_type iter, ValueType * hessenberg_iter,
173- size_type stride_hessenberg, ValueType * givens_sin,
174- size_type stride_sin, ValueType * givens_cos,
173+ size_type iter, ValueType* hessenberg_iter,
174+ size_type stride_hessenberg, ValueType* givens_sin,
175+ size_type stride_sin, ValueType* givens_cos,
175176 size_type stride_cos,
176- remove_complex<ValueType> * residual_norm,
177- ValueType * residual_norm_collection,
177+ remove_complex<ValueType>* residual_norm,
178+ ValueType* residual_norm_collection,
178179 size_type stride_residual_norm_collection,
179- const stopping_status * stop_status)
180+ const stopping_status* stop_status)
180181{
181- queue->submit ([&](sycl::handler & cgh) {
182+ queue->submit ([&](sycl::handler& cgh) {
182183 cgh.parallel_for (
183184 sycl_nd_range (grid, block), [=](sycl::nd_item<3 > item_ct1) {
184185 givens_rotation_kernel<block_size>(
@@ -195,11 +196,11 @@ void givens_rotation_kernel(dim3 grid, dim3 block,
195196template <size_type block_size, typename ValueType>
196197void solve_upper_triangular_kernel (
197198 size_type num_cols, size_type num_rhs,
198- const ValueType * __restrict__ residual_norm_collection,
199+ const ValueType* __restrict__ residual_norm_collection,
199200 size_type stride_residual_norm_collection,
200- const ValueType * __restrict__ hessenberg, size_type stride_hessenberg,
201- ValueType * __restrict__ y, size_type stride_y,
202- const size_type * __restrict__ final_iter_nums, sycl::nd_item<3 > item_ct1)
201+ const ValueType* __restrict__ hessenberg, size_type stride_hessenberg,
202+ ValueType* __restrict__ y, size_type stride_y,
203+ const size_type* __restrict__ final_iter_nums, sycl::nd_item<3 > item_ct1)
203204{
204205 const auto col_idx = thread::get_thread_id_flat (item_ct1);
205206
@@ -225,14 +226,14 @@ void solve_upper_triangular_kernel(
225226
226227template <size_type block_size, typename ValueType>
227228void solve_upper_triangular_kernel (
228- dim3 grid, dim3 block, size_type dynamic_shared_memory, sycl::queue * queue,
229+ dim3 grid, dim3 block, size_type dynamic_shared_memory, sycl::queue* queue,
229230 size_type num_cols, size_type num_rhs,
230- const ValueType * residual_norm_collection,
231- size_type stride_residual_norm_collection, const ValueType * hessenberg,
232- size_type stride_hessenberg, ValueType * y, size_type stride_y,
233- const size_type * final_iter_nums)
231+ const ValueType* residual_norm_collection,
232+ size_type stride_residual_norm_collection, const ValueType* hessenberg,
233+ size_type stride_hessenberg, ValueType* y, size_type stride_y,
234+ const size_type* final_iter_nums)
234235{
235- queue->submit ([&](sycl::handler & cgh) {
236+ queue->submit ([&](sycl::handler& cgh) {
236237 cgh.parallel_for (
237238 sycl_nd_range (grid, block), [=](sycl::nd_item<3 > item_ct1) {
238239 solve_upper_triangular_kernel<block_size>(
0 commit comments