Skip to content

Commit 95238df

Browse files
committed
SYCL_KHR_GROUP_INTERFACE prototype
1 parent 8db38cf commit 95238df

File tree

3 files changed

+413
-0
lines changed

3 files changed

+413
-0
lines changed
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
#include "sycl/sycl.hpp" // IWYU pragma: keep
2+
3+
#define SYCL_KHR_GROUP_INTERFACE 1
4+
5+
namespace simsycl::sycl::khr {
6+
7+
template<typename ParentGroup>
8+
class member_item {
9+
public:
10+
using id_type = typename ParentGroup::id_type;
11+
using linear_id_type = typename ParentGroup::linear_id_type;
12+
using range_type = typename ParentGroup::range_type;
13+
// using extents_type = /* extents of all 1s with ParentGroup's index type */; // C++23
14+
using size_type = typename ParentGroup::size_type;
15+
static constexpr int dimensions = ParentGroup::dimensions;
16+
static constexpr memory_scope fence_scope = memory_scope::work_item;
17+
18+
/* -- common by-value interface members -- */
19+
20+
id_type id() const noexcept { return m_parent_group.get_local_id(); }
21+
linear_id_type linear_id() const noexcept { return m_parent_group.get_local_linear_id(); }
22+
23+
range_type range() const noexcept { return m_parent_group.get_local_range(); }
24+
25+
// constexpr extents_type extents() const noexcept; // C++23
26+
// constexpr extents_type::index_type extent(extents_type::rank_type r) const noexcept; // C++23
27+
28+
// static constexpr extents_type::rank_type rank() noexcept; // C++23
29+
// static constexpr extents_type::rank_type rank_dynamic() noexcept; // C++23
30+
// static constexpr size_t static_extent(rank_type r) noexcept; // C++23
31+
32+
constexpr size_type size() const noexcept { return 1; }
33+
34+
private:
35+
ParentGroup m_parent_group;
36+
member_item(ParentGroup g) noexcept : m_parent_group(g) {}
37+
38+
linear_id_type get_local_linear_id() const noexcept { return m_parent_group.get_local_linear_id(); }
39+
40+
template<typename Group>
41+
friend member_item<Group> get_member_item(Group g) noexcept;
42+
template<typename Group>
43+
friend bool leader_of(Group g) noexcept;
44+
};
45+
46+
template<int Dimensions = 1>
47+
class work_group {
48+
public:
49+
using id_type = sycl::id<Dimensions>;
50+
using linear_id_type = size_t;
51+
using range_type = sycl::range<Dimensions>;
52+
// using extents_type = std::dextents<size_t, Dimensions>; // C++23
53+
using size_type = size_t;
54+
static constexpr int dimensions = Dimensions;
55+
static constexpr memory_scope fence_scope = memory_scope::work_group;
56+
57+
work_group(group<Dimensions> g) noexcept : m_group(g) {}
58+
59+
operator group<Dimensions>() const noexcept { return m_group; }
60+
61+
/* -- common by-value interface members -- */
62+
63+
id_type id() const noexcept { return m_group.get_group_id(); }
64+
linear_id_type linear_id() const noexcept { return m_group.get_group_linear_id(); }
65+
66+
range_type range() const noexcept { return m_group.get_group_range(); }
67+
68+
// extents_type extents() const noexcept; // C++23
69+
// extents_type::index_type extent(extents_type::rank_type r) const noexcept; // C++23
70+
71+
// static constexpr extents_type::rank_type rank() noexcept; // C++23
72+
// static constexpr extents_type::rank_type rank_dynamic() noexcept; // C++23
73+
// static constexpr size_t static_extent(rank_type r) noexcept; // C++23
74+
75+
size_type size() const noexcept { return m_group.get_local_range().size(); }
76+
77+
private:
78+
group<Dimensions> m_group;
79+
80+
id_type get_local_id() const noexcept { return m_group.get_local_id(); }
81+
linear_id_type get_local_linear_id() const noexcept { return m_group.get_local_linear_id(); }
82+
range_type get_local_range() const noexcept { return m_group.get_local_range(); }
83+
friend class member_item<work_group>;
84+
template<typename Group>
85+
friend bool leader_of(Group g) noexcept;
86+
};
87+
88+
class sub_group {
89+
public:
90+
using id_type = sycl::id<1>;
91+
using linear_id_type = uint32_t;
92+
using range_type = sycl::range<1>;
93+
// using extents_type = std::dextents<uint32_t, 1>; // C++23
94+
using size_type = uint32_t;
95+
static constexpr int dimensions = 1;
96+
static constexpr memory_scope fence_scope = memory_scope::sub_group;
97+
98+
sub_group(sycl::sub_group sg) noexcept : m_sub_group(sg) {}
99+
100+
operator sycl::sub_group() const noexcept { return m_sub_group; }
101+
102+
/* -- common by-value interface members -- */
103+
104+
id_type id() const noexcept { return m_sub_group.get_group_id(); }
105+
linear_id_type linear_id() const noexcept { return m_sub_group.get_group_linear_id(); }
106+
107+
range_type range() const noexcept { return m_sub_group.get_group_range(); }
108+
109+
// extents_type extents() const noexcept; // C++23
110+
// extents_type::index_type extent(extents_type::rank_type r) const noexcept; // C++23
111+
112+
// static constexpr extents_type::rank_type rank() noexcept; // C++23
113+
// static constexpr extents_type::rank_type rank_dynamic() noexcept; // C++23
114+
// static constexpr size_t static_extent(rank_type r) noexcept; // C++23
115+
116+
size_type size() const noexcept { return m_sub_group.get_local_range().size(); }
117+
size_type max_size() const noexcept { return m_sub_group.get_max_local_range().size(); }
118+
119+
private:
120+
sycl::sub_group m_sub_group;
121+
122+
id_type get_local_id() const noexcept { return m_sub_group.get_local_id(); }
123+
linear_id_type get_local_linear_id() const noexcept { return m_sub_group.get_local_linear_id(); }
124+
range_type get_local_range() const noexcept { return m_sub_group.get_local_range(); }
125+
friend class member_item<sub_group>;
126+
template<typename Group>
127+
friend bool leader_of(Group g) noexcept;
128+
};
129+
130+
template<typename Group>
131+
member_item<Group> get_member_item(Group g) noexcept {
132+
return member_item<Group>(g);
133+
}
134+
135+
template<typename Group>
136+
bool leader_of(Group g) noexcept {
137+
return g.get_local_linear_id() == 0;
138+
}
139+
140+
} // namespace simsycl::sycl::khr

test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ add_executable(tests
2222
simulation_tests.cc
2323
alloc_tests.cc
2424
vec_tests.cc
25+
extensions/khr_group_interface_tests.cc
2526
)
2627

2728
add_sycl_to_target(TARGET tests SIMSYCL_ALL_WARNINGS)

0 commit comments

Comments
 (0)