Skip to content
103 changes: 70 additions & 33 deletions src/scifem/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,37 +55,57 @@ def prepare_interpolation_data(
array_evaluated = compiled_expr.eval(mesh, np.arange(num_cells, dtype=np.int32))
assert np.prod(Q.value_shape) == np.prod(expr.ufl_shape)

im = Q.element.basix_element.interpolation_matrix

# Get data as (num_cells*num_points,1, expr_shape, num_test_basis_functions*test_block_size)
expr_size = int(np.prod(expr.ufl_shape))
array_evaluated = array_evaluated.reshape(
num_cells * q_points.shape[0], 1, expr_size, V.dofmap.bs * V.dofmap.dof_layout.num_dofs
)
jacobian = dolfinx.fem.Expression(ufl.Jacobian(mesh), q_points)
detJ = dolfinx.fem.Expression(ufl.JacobianDeterminant(mesh), q_points)
K = dolfinx.fem.Expression(ufl.JacobianInverse(mesh), q_points)
jacs = jacobian.eval(mesh, np.arange(num_cells, dtype=np.int32)).reshape(
num_cells * num_points, mesh.geometry.dim, mesh.topology.dim
)
detJs = detJ.eval(mesh, np.arange(num_cells, dtype=np.int32)).flatten()
Ks = K.eval(mesh, np.arange(num_cells, dtype=np.int32)).reshape(
num_cells * num_points, mesh.geometry.dim, mesh.topology.dim
)

Q_vs = Q.element.basix_element.value_size
# Check if we are dealing with a quadrature element or not.
# They do not have a complete DOLFINx API, which makes them tricky to use.
try:
basix_el = Q.element.basix_element
Q_vs = basix_el.value_size
pull_back = basix_el.pull_back
im = basix_el.interpolation_matrix
except RuntimeError:
Q_vs = 1 # If we do not have a basix element, assume value size is 1
assert isinstance(Q.ufl_element().pullback, ufl.pullback.IdentityPullback)
pull_back = lambda x: None
assert Q.element.interpolation_ident
im = None

new_array = np.zeros(
(num_cells * num_points, Q.dofmap.bs * Q_vs, V.dofmap.bs * V.dofmap.dof_layout.num_dofs),
dtype=np.float64,
)
for i in range(V.dofmap.bs * V.dofmap.dof_layout.num_dofs):
for q in range(Q.dofmap.bs):
new_array[:, q * Q_vs : (q + 1) * Q_vs, i] = Q.element.basix_element.pull_back(
array_evaluated[:, :, q * Q_vs : (q + 1) * Q_vs, i], jacs, detJs, Ks
).reshape(num_cells * num_points, Q_vs)
new_array = new_array.reshape(
num_cells, num_points, Q.dofmap.bs * Q_vs, V.dofmap.bs * V.dofmap.dof_layout.num_dofs
)

# Check if pullback is identity, then we can skip this step
if not isinstance(Q.ufl_element().pullback, ufl.pullback.IdentityPullback):
jacobian = dolfinx.fem.Expression(ufl.Jacobian(mesh), q_points)
detJ = dolfinx.fem.Expression(ufl.JacobianDeterminant(mesh), q_points)
K = dolfinx.fem.Expression(ufl.JacobianInverse(mesh), q_points)
jacs = jacobian.eval(mesh, np.arange(num_cells, dtype=np.int32)).reshape(
num_cells * num_points, mesh.geometry.dim, mesh.topology.dim
)
detJs = detJ.eval(mesh, np.arange(num_cells, dtype=np.int32)).flatten()
Ks = K.eval(mesh, np.arange(num_cells, dtype=np.int32)).reshape(
num_cells * num_points, mesh.geometry.dim, mesh.topology.dim
)

for i in range(V.dofmap.bs * V.dofmap.dof_layout.num_dofs):
for q in range(Q.dofmap.bs):
new_array[:, q * Q_vs : (q + 1) * Q_vs, i] = pull_back(
array_evaluated[:, :, q * Q_vs : (q + 1) * Q_vs, i], jacs, detJs, Ks
).reshape(num_cells * num_points, Q_vs)
new_array = new_array.reshape(
num_cells, num_points, Q.dofmap.bs * Q_vs, V.dofmap.bs * V.dofmap.dof_layout.num_dofs
)
else:
new_array = array_evaluated.reshape(
num_cells, num_points, Q.dofmap.bs * Q_vs, V.dofmap.bs * V.dofmap.dof_layout.num_dofs
)

interpolated_matrix = np.zeros(
(
num_cells,
Expand All @@ -94,19 +114,36 @@ def prepare_interpolation_data(
),
dtype=np.float64,
)
# Check if interpolation matrix of dual operator is identity, then we can use a vectorized
# version of this step
if Q.element.interpolation_ident:
# Smart vectorized version with identity mapping
if Q.dofmap.bs == 1:
interpolated_matrix = new_array.transpose(0, 2, 1, 3).reshape(
new_array.shape[0], new_array.shape[1] * new_array.shape[2], new_array.shape[3]
)
else:
i_scalar = new_array.transpose(0, 2, 1, 3)
interpolated_matrix = np.zeros(
(new_array.shape[0], new_array.shape[1] * new_array.shape[2], new_array.shape[3])
)
for q in range(Q.dofmap.bs):
interpolated_matrix[:, q :: Q.dofmap.bs, :] = i_scalar[:, q, :, :]

for c in range(num_cells):
for i in range(V.dofmap.bs * V.dofmap.dof_layout.num_dofs):
tmp_array = np.zeros((int(num_points), Q.dofmap.bs * Q_vs), dtype=np.float64)
for p in range(num_points):
tmp_array[p] = new_array[c, p, :, i]
if Q.dofmap.bs == 1:
interpolated_matrix[c, :, i] = (im @ tmp_array.T.flatten()).flatten()
else:
for q in range(Q.dofmap.bs):
interpolated_matrix[c, q :: Q.dofmap.bs, i] = (
im @ tmp_array.T[q].flatten()
).flatten()
else:
# Tedious non-identity version
for c in range(num_cells):
for i in range(V.dofmap.bs * V.dofmap.dof_layout.num_dofs):
tmp_array = np.zeros((int(num_points), Q.dofmap.bs * Q_vs), dtype=np.float64)
for p in range(num_points):
tmp_array[p] = new_array[c, p, :, i]
if Q.dofmap.bs == 1:
interpolated_matrix[c, :, i] = (im @ tmp_array.T.flatten()).flatten()
else:
for q in range(Q.dofmap.bs):
interpolated_matrix[c, q :: Q.dofmap.bs, i] = (
im @ tmp_array.T[q].flatten()
).flatten()
# Apply dof transformation to each column (using Piopla maps)
mesh.topology.create_entity_permutations()
if Q.element.needs_dof_transformations:
Expand Down
36 changes: 30 additions & 6 deletions tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
import ufl
import numpy as np
import basix.ufl


@pytest.mark.skipif(
Expand All @@ -20,8 +21,10 @@
],
)
@pytest.mark.parametrize("use_petsc", [True, False])
@pytest.mark.parametrize("degree", [1, 3, 5])
def test_interpolation_matrix(use_petsc, cell_type, degree):
@pytest.mark.parametrize("degree", [2, 4])
@pytest.mark.parametrize("out_family", ["Lagrange", "DG", "Quadrature"])
@pytest.mark.parametrize("value_shape", [(), (2,), (2, 3)])
def test_interpolation_matrix(use_petsc, cell_type, degree, out_family, value_shape):
if use_petsc:
pytest.importorskip("petsc4py")

Expand All @@ -33,14 +36,27 @@ def test_interpolation_matrix(use_petsc, cell_type, degree):
else:
raise ValueError("Unsupported cell type")

V = dolfinx.fem.functionspace(mesh, ("DG", degree))
Q = dolfinx.fem.functionspace(mesh, ("Lagrange", degree))
V = dolfinx.fem.functionspace(mesh, ("DG", degree, value_shape))
if out_family == "Quadrature":
el = basix.ufl.quadrature_element(mesh.basix_cell(), degree=degree, value_shape=value_shape)
else:
el = (out_family, degree, value_shape)
Q = dolfinx.fem.functionspace(mesh, el)

def f(x):
scalar_val = x[0] ** degree + x[1] if tdim == 2 else x[0] + x[1] + x[2] ** degree
vs = int(np.prod(value_shape))
f_rep = np.tile(scalar_val, vs).reshape(vs, -1)
for i in range(vs):
f_rep[i] += np.pi * (i + 1)
return f_rep

u = dolfinx.fem.Function(V)
u.interpolate(lambda x: x[0] ** degree + x[1] if tdim == 2 else x[0] + x[1] + x[2] ** degree)
u.interpolate(f)

q = dolfinx.fem.Function(Q)
expr = ufl.TrialFunction(V)

if use_petsc:
A = scifem.interpolation.petsc_interpolation_matrix(expr, Q)
A.mult(u.x.petsc_vec, q.x.petsc_vec)
Expand All @@ -59,7 +75,15 @@ def test_interpolation_matrix(use_petsc, cell_type, degree):
q.x.scatter_forward()

q_ref = dolfinx.fem.Function(Q)
q_ref.interpolate(u)
if out_family == "Quadrature":
try:
ip = Q.element.interpolation_points()
except TypeError:
ip = Q.element.interpolation_points
u_expr = dolfinx.fem.Expression(u, ip)
q_ref.interpolate(u_expr)
else:
q_ref.interpolate(u)

np.testing.assert_allclose(q.x.array, q_ref.x.array, rtol=1e-12, atol=1e-13)

Expand Down
Loading