Skip to content
This repository was archived by the owner on Dec 2, 2025. It is now read-only.

Commit 1d3582c

Browse files
authored
Remove assumption that geometry uses real scalars (#110)
* remove real from reference_cell * default f64 * replace more real with TGeo scalar type * remove another real * remove real from bindings Also cargo fmt on other files * remove more real * clippy * fix dtypes * geo_dtype
1 parent 2c3a3bc commit 1d3582c

File tree

14 files changed

+831
-493
lines changed

14 files changed

+831
-493
lines changed

examples/element_family.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use ndelement::types::{Continuity, ReferenceCellType};
55
fn main() {
66
// Create the degree 2 Lagrange element family. A family is a set of finite elements with the
77
// same family type, degree, and continuity across a set of cells
8-
let family = LagrangeElementFamily::<f64>::new(2, Continuity::Standard);
8+
let family = LagrangeElementFamily::<f64, f64>::new(2, Continuity::Standard);
99

1010
// Get the element in the family on a triangle
1111
let element = family.element(ReferenceCellType::Triangle);

examples/lagrange_element.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ use rlst::{DynArray, rlst_dynamic_array};
77

88
fn main() {
99
// Create a P2 element on a triangle
10-
let element = lagrange::create::<f64>(ReferenceCellType::Triangle, 2, Continuity::Standard);
10+
let element =
11+
lagrange::create::<f64, f64>(ReferenceCellType::Triangle, 2, Continuity::Standard);
1112

1213
println!("This element has {} basis functions.", element.dim());
1314

examples/test_high_degree.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ fn main() {
1010
paste! {
1111
for d in 1..[<$max_degree>] {
1212
println!("Constructing Lagrange(degree={d}, cell={:?})", ReferenceCellType::[<$cell>]);
13-
let _e = lagrange::create::<f64>(ReferenceCellType::[<$cell>], d, Continuity::Standard);
13+
let _e = lagrange::create::<f64, f64>(ReferenceCellType::[<$cell>], d, Continuity::Standard);
1414
}
1515
}
1616
};
@@ -21,7 +21,7 @@ fn main() {
2121
paste! {
2222
for d in 1..[<$max_degree>] {
2323
println!("Constructing RaviartThomas(degree={d}, cell={:?})", ReferenceCellType::[<$cell>]);
24-
let _e = raviart_thomas::create::<f64>(ReferenceCellType::[<$cell>], d, Continuity::Standard);
24+
let _e = raviart_thomas::create::<f64, f64>(ReferenceCellType::[<$cell>], d, Continuity::Standard);
2525
}
2626
}
2727
};
@@ -32,7 +32,7 @@ fn main() {
3232
paste! {
3333
for d in 1..[<$max_degree>] {
3434
println!("Constructing Nedelec(degree={d}, cell={:?})", ReferenceCellType::[<$cell>]);
35-
let _e = nedelec::create::<f64>(ReferenceCellType::[<$cell>], d, Continuity::Standard);
35+
let _e = nedelec::create::<f64, f64>(ReferenceCellType::[<$cell>], d, Continuity::Standard);
3636
}
3737
}
3838
};

python/ndelement/ciarlet.py

Lines changed: 66 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@ def dtype(self) -> typing.Type[np.floating]:
5252
"""Data type."""
5353
return _dtypes[_lib.ciarlet_element_dtype(self._rs_element)]
5454

55+
@property
56+
def geo_dtype(self) -> typing.Type[np.floating]:
57+
"""Data type."""
58+
return _dtypes[_lib.ciarlet_element_geo_dtype(self._rs_element)]
59+
5560
@property
5661
def value_size(self) -> int:
5762
"""Value size of the element."""
@@ -123,7 +128,7 @@ def interpolation_points(self) -> list[list[npt.NDArray]]:
123128
points_d = []
124129
for i in range(n):
125130
shape = (_lib.ciarlet_element_interpolation_npoints(self._rs_element, d, i), tdim)
126-
points_di = np.empty(shape, dtype=self.dtype(0).real.dtype)
131+
points_di = np.empty(shape, dtype=self.geo_dtype)
127132
_lib.ciarlet_element_interpolation_points(
128133
self._rs_element, d, i, _ffi.cast("void*", points_di.ctypes.data)
129134
)
@@ -152,20 +157,31 @@ def interpolation_weights(self) -> list[list[npt.NDArray]]:
152157

153158
def tabulate(self, points: npt.NDArray[np.floating], nderivs: int) -> npt.NDArray:
154159
"""Tabulate the basis functions at a set of points."""
155-
if points.dtype != self.dtype(0).real.dtype:
156-
raise TypeError("points has incorrect type")
157160
shape = np.empty(4, dtype=np.uintp)
158161
_lib.ciarlet_element_tabulate_array_shape(
159162
self._rs_element, nderivs, points.shape[0], _ffi.cast("uintptr_t*", shape.ctypes.data)
160163
)
161164
data = np.empty(shape[::-1], dtype=self.dtype)
162-
_lib.ciarlet_element_tabulate(
163-
self._rs_element,
164-
_ffi.cast("void*", points.ctypes.data),
165-
points.shape[0],
166-
nderivs,
167-
_ffi.cast("void*", data.ctypes.data),
168-
)
165+
166+
if points.dtype == np.float64:
167+
_lib.ciarlet_element_tabulate_f64(
168+
self._rs_element,
169+
_ffi.cast("double*", points.ctypes.data),
170+
points.shape[0],
171+
nderivs,
172+
_ffi.cast("void*", data.ctypes.data),
173+
)
174+
elif points.dtype == np.float32:
175+
_lib.ciarlet_element_tabulate_f32(
176+
self._rs_element,
177+
_ffi.cast("float*", points.ctypes.data),
178+
points.shape[0],
179+
nderivs,
180+
_ffi.cast("void*", data.ctypes.data),
181+
)
182+
else:
183+
raise TypeError(f"Unsupported dtype: {points.dtype}")
184+
169185
return data
170186

171187
def physical_value_size(self, gdim: int) -> int:
@@ -189,13 +205,12 @@ def push_forward(
189205
inverse_jacobians: npt.NDArray[np.floating],
190206
) -> npt.NDArray[np.floating]:
191207
"""Push values forward to a physical cell."""
192-
if reference_values.dtype != self.dtype(0):
193-
raise TypeError("reference_values has incorrect type")
194-
if jacobians.dtype != self.dtype(0).real.dtype:
208+
geo_dtype = reference_values.dtype
209+
if jacobians.dtype != geo_dtype:
195210
raise TypeError("jacobians has incorrect type")
196-
if jacobian_determinants.dtype != self.dtype(0).real.dtype:
211+
if jacobian_determinants.dtype != geo_dtype:
197212
raise TypeError("jacobian_determinants has incorrect type")
198-
if inverse_jacobians.dtype != self.dtype(0).real.dtype:
213+
if inverse_jacobians.dtype != geo_dtype:
199214
raise TypeError("inverse_jacobians has incorrect type")
200215

201216
gdim = jacobians.shape[1]
@@ -205,18 +220,29 @@ def push_forward(
205220

206221
shape = (pvs,) + reference_values.shape[1:]
207222
data = np.empty(shape, dtype=self.dtype)
208-
_lib.ciarlet_element_push_forward(
223+
224+
if reference_values.dtype == np.float64:
225+
push_function = _lib.ciarlet_element_push_forward_f64
226+
geo_type = "double*"
227+
elif reference_values.dtype == np.float32:
228+
push_function = _lib.ciarlet_element_push_forward_f32
229+
geo_type = "float*"
230+
else:
231+
raise TypeError(f"Unsupported dtype: {reference_values.dtype}")
232+
233+
push_function(
209234
self._rs_element,
210235
npts,
211236
nfuncs,
212237
gdim,
213-
_ffi.cast("void*", reference_values.ctypes.data),
238+
_ffi.cast(geo_type, reference_values.ctypes.data),
214239
nderivs,
215-
_ffi.cast("void*", jacobians.ctypes.data),
216-
_ffi.cast("void*", jacobian_determinants.ctypes.data),
217-
_ffi.cast("void*", inverse_jacobians.ctypes.data),
240+
_ffi.cast(geo_type, jacobians.ctypes.data),
241+
_ffi.cast(geo_type, jacobian_determinants.ctypes.data),
242+
_ffi.cast(geo_type, inverse_jacobians.ctypes.data),
218243
_ffi.cast("void*", data.ctypes.data),
219244
)
245+
220246
return data
221247

222248
def pull_back(
@@ -228,13 +254,12 @@ def pull_back(
228254
inverse_jacobians: npt.NDArray[np.floating],
229255
) -> npt.NDArray[np.floating]:
230256
"""Push values back from a physical cell."""
231-
if physical_values.dtype != self.dtype(0):
232-
raise TypeError("physical_values has incorrect type")
233-
if jacobians.dtype != self.dtype(0).real.dtype:
257+
geo_dtype = physical_values.dtype
258+
if jacobians.dtype != geo_dtype:
234259
raise TypeError("jacobians has incorrect type")
235-
if jacobian_determinants.dtype != self.dtype(0).real.dtype:
260+
if jacobian_determinants.dtype != geo_dtype:
236261
raise TypeError("jacobian_determinants has incorrect type")
237-
if inverse_jacobians.dtype != self.dtype(0).real.dtype:
262+
if inverse_jacobians.dtype != geo_dtype:
238263
raise TypeError("inverse_jacobians has incorrect type")
239264

240265
gdim = jacobians.shape[1]
@@ -244,18 +269,29 @@ def pull_back(
244269

245270
shape = (vs,) + physical_values.shape[1:]
246271
data = np.empty(shape, dtype=self.dtype)
247-
_lib.ciarlet_element_pull_back(
272+
273+
if physical_values.dtype == np.float64:
274+
pull_function = _lib.ciarlet_element_pull_back_f64
275+
geo_type = "double*"
276+
elif physical_values.dtype == np.float32:
277+
pull_function = _lib.ciarlet_element_pull_back_f32
278+
geo_type = "float*"
279+
else:
280+
raise TypeError(f"Unsupported dtype: {physical_values.dtype}")
281+
282+
pull_function(
248283
self._rs_element,
249284
npts,
250285
nfuncs,
251286
gdim,
252-
_ffi.cast("void*", physical_values.ctypes.data),
287+
_ffi.cast(geo_type, physical_values.ctypes.data),
253288
nderivs,
254-
_ffi.cast("void*", jacobians.ctypes.data),
255-
_ffi.cast("void*", jacobian_determinants.ctypes.data),
256-
_ffi.cast("void*", inverse_jacobians.ctypes.data),
289+
_ffi.cast(geo_type, jacobians.ctypes.data),
290+
_ffi.cast(geo_type, jacobian_determinants.ctypes.data),
291+
_ffi.cast(geo_type, inverse_jacobians.ctypes.data),
257292
_ffi.cast("void*", data.ctypes.data),
258293
)
294+
259295
return data
260296

261297
def apply_dof_permutations(self, data: npt.NDArray, orientation: np.int32):

python/test/test_element.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -27,25 +27,6 @@ def test_value_size(cell, degree):
2727
assert element.value_size == 1
2828

2929

30-
@pytest.mark.parametrize(
31-
"dt0,dt1",
32-
[
33-
(np.float32, np.float64),
34-
(np.float64, np.float32),
35-
(np.complex64, np.float64),
36-
(np.complex128, np.float32),
37-
]
38-
+ [(dt0, dt1) for dt0 in dtypes for dt1 in [np.complex64, np.complex128]],
39-
)
40-
def test_incompatible_types(dt0, dt1):
41-
family = create_family(Family.Lagrange, 2, dtype=dt0)
42-
element = family.element(ReferenceCellType.Triangle)
43-
points = np.array([[0.0, 0.0], [0.2, 0.1], [0.8, 0.05]], dtype=dt1)
44-
45-
with pytest.raises(TypeError):
46-
element.tabulate(points, 1)
47-
48-
4930
@pytest.mark.parametrize("dtype", dtypes)
5031
def test_lagrange_2_triangle_tabulate(dtype):
5132
family = create_family(Family.Lagrange, 2, dtype=dtype)

0 commit comments

Comments
 (0)