diff --git a/src/scifem/interpolation.py b/src/scifem/interpolation.py index ed32a33..effb244 100644 --- a/src/scifem/interpolation.py +++ b/src/scifem/interpolation.py @@ -55,37 +55,57 @@ def prepare_interpolation_data( array_evaluated = compiled_expr.eval(mesh, np.arange(num_cells, dtype=np.int32)) assert np.prod(Q.value_shape) == np.prod(expr.ufl_shape) - im = Q.element.basix_element.interpolation_matrix - # Get data as (num_cells*num_points,1, expr_shape, num_test_basis_functions*test_block_size) expr_size = int(np.prod(expr.ufl_shape)) array_evaluated = array_evaluated.reshape( num_cells * q_points.shape[0], 1, expr_size, V.dofmap.bs * V.dofmap.dof_layout.num_dofs ) - jacobian = dolfinx.fem.Expression(ufl.Jacobian(mesh), q_points) - detJ = dolfinx.fem.Expression(ufl.JacobianDeterminant(mesh), q_points) - K = dolfinx.fem.Expression(ufl.JacobianInverse(mesh), q_points) - jacs = jacobian.eval(mesh, np.arange(num_cells, dtype=np.int32)).reshape( - num_cells * num_points, mesh.geometry.dim, mesh.topology.dim - ) - detJs = detJ.eval(mesh, np.arange(num_cells, dtype=np.int32)).flatten() - Ks = K.eval(mesh, np.arange(num_cells, dtype=np.int32)).reshape( - num_cells * num_points, mesh.geometry.dim, mesh.topology.dim - ) - Q_vs = Q.element.basix_element.value_size + # Check if we are dealing with a quadrature element or not. + # They do not have a complete DOLFINx API, which makes them tricky to use. + try: + basix_el = Q.element.basix_element + Q_vs = basix_el.value_size + pull_back = basix_el.pull_back + im = basix_el.interpolation_matrix + except RuntimeError: + Q_vs = 1 # If we do not have a basix element, assume value size is 1 + assert isinstance(Q.ufl_element().pullback, ufl.pullback.IdentityPullback) + pull_back = lambda x: None + assert Q.element.interpolation_ident + im = None + new_array = np.zeros( (num_cells * num_points, Q.dofmap.bs * Q_vs, V.dofmap.bs * V.dofmap.dof_layout.num_dofs), dtype=np.float64, ) - for i in range(V.dofmap.bs * V.dofmap.dof_layout.num_dofs): - for q in range(Q.dofmap.bs): - new_array[:, q * Q_vs : (q + 1) * Q_vs, i] = Q.element.basix_element.pull_back( - array_evaluated[:, :, q * Q_vs : (q + 1) * Q_vs, i], jacs, detJs, Ks - ).reshape(num_cells * num_points, Q_vs) - new_array = new_array.reshape( - num_cells, num_points, Q.dofmap.bs * Q_vs, V.dofmap.bs * V.dofmap.dof_layout.num_dofs - ) + + # Check if pullback is identity, then we can skip this step + if not isinstance(Q.ufl_element().pullback, ufl.pullback.IdentityPullback): + jacobian = dolfinx.fem.Expression(ufl.Jacobian(mesh), q_points) + detJ = dolfinx.fem.Expression(ufl.JacobianDeterminant(mesh), q_points) + K = dolfinx.fem.Expression(ufl.JacobianInverse(mesh), q_points) + jacs = jacobian.eval(mesh, np.arange(num_cells, dtype=np.int32)).reshape( + num_cells * num_points, mesh.geometry.dim, mesh.topology.dim + ) + detJs = detJ.eval(mesh, np.arange(num_cells, dtype=np.int32)).flatten() + Ks = K.eval(mesh, np.arange(num_cells, dtype=np.int32)).reshape( + num_cells * num_points, mesh.geometry.dim, mesh.topology.dim + ) + + for i in range(V.dofmap.bs * V.dofmap.dof_layout.num_dofs): + for q in range(Q.dofmap.bs): + new_array[:, q * Q_vs : (q + 1) * Q_vs, i] = pull_back( + array_evaluated[:, :, q * Q_vs : (q + 1) * Q_vs, i], jacs, detJs, Ks + ).reshape(num_cells * num_points, Q_vs) + new_array = new_array.reshape( + num_cells, num_points, Q.dofmap.bs * Q_vs, V.dofmap.bs * V.dofmap.dof_layout.num_dofs + ) + else: + new_array = array_evaluated.reshape( + num_cells, num_points, Q.dofmap.bs * Q_vs, V.dofmap.bs * V.dofmap.dof_layout.num_dofs + ) + interpolated_matrix = np.zeros( ( num_cells, @@ -94,19 +114,36 @@ def prepare_interpolation_data( ), dtype=np.float64, ) + # Check if interpolation matrix of dual operator is identity, then we can use a vectorized + # version of this step + if Q.element.interpolation_ident: + # Smart vectorized version with identity mapping + if Q.dofmap.bs == 1: + interpolated_matrix = new_array.transpose(0, 2, 1, 3).reshape( + new_array.shape[0], new_array.shape[1] * new_array.shape[2], new_array.shape[3] + ) + else: + i_scalar = new_array.transpose(0, 2, 1, 3) + interpolated_matrix = np.zeros( + (new_array.shape[0], new_array.shape[1] * new_array.shape[2], new_array.shape[3]) + ) + for q in range(Q.dofmap.bs): + interpolated_matrix[:, q :: Q.dofmap.bs, :] = i_scalar[:, q, :, :] - for c in range(num_cells): - for i in range(V.dofmap.bs * V.dofmap.dof_layout.num_dofs): - tmp_array = np.zeros((int(num_points), Q.dofmap.bs * Q_vs), dtype=np.float64) - for p in range(num_points): - tmp_array[p] = new_array[c, p, :, i] - if Q.dofmap.bs == 1: - interpolated_matrix[c, :, i] = (im @ tmp_array.T.flatten()).flatten() - else: - for q in range(Q.dofmap.bs): - interpolated_matrix[c, q :: Q.dofmap.bs, i] = ( - im @ tmp_array.T[q].flatten() - ).flatten() + else: + # Tedious non-identity version + for c in range(num_cells): + for i in range(V.dofmap.bs * V.dofmap.dof_layout.num_dofs): + tmp_array = np.zeros((int(num_points), Q.dofmap.bs * Q_vs), dtype=np.float64) + for p in range(num_points): + tmp_array[p] = new_array[c, p, :, i] + if Q.dofmap.bs == 1: + interpolated_matrix[c, :, i] = (im @ tmp_array.T.flatten()).flatten() + else: + for q in range(Q.dofmap.bs): + interpolated_matrix[c, q :: Q.dofmap.bs, i] = ( + im @ tmp_array.T[q].flatten() + ).flatten() # Apply dof transformation to each column (using Piopla maps) mesh.topology.create_entity_permutations() if Q.element.needs_dof_transformations: diff --git a/tests/test_interpolation.py b/tests/test_interpolation.py index f4b2cb5..3d48732 100644 --- a/tests/test_interpolation.py +++ b/tests/test_interpolation.py @@ -5,6 +5,7 @@ import pytest import ufl import numpy as np +import basix.ufl @pytest.mark.skipif( @@ -20,8 +21,10 @@ ], ) @pytest.mark.parametrize("use_petsc", [True, False]) -@pytest.mark.parametrize("degree", [1, 3, 5]) -def test_interpolation_matrix(use_petsc, cell_type, degree): +@pytest.mark.parametrize("degree", [2, 4]) +@pytest.mark.parametrize("out_family", ["Lagrange", "DG", "Quadrature"]) +@pytest.mark.parametrize("value_shape", [(), (2,), (2, 3)]) +def test_interpolation_matrix(use_petsc, cell_type, degree, out_family, value_shape): if use_petsc: pytest.importorskip("petsc4py") @@ -33,14 +36,27 @@ def test_interpolation_matrix(use_petsc, cell_type, degree): else: raise ValueError("Unsupported cell type") - V = dolfinx.fem.functionspace(mesh, ("DG", degree)) - Q = dolfinx.fem.functionspace(mesh, ("Lagrange", degree)) + V = dolfinx.fem.functionspace(mesh, ("DG", degree, value_shape)) + if out_family == "Quadrature": + el = basix.ufl.quadrature_element(mesh.basix_cell(), degree=degree, value_shape=value_shape) + else: + el = (out_family, degree, value_shape) + Q = dolfinx.fem.functionspace(mesh, el) + + def f(x): + scalar_val = x[0] ** degree + x[1] if tdim == 2 else x[0] + x[1] + x[2] ** degree + vs = int(np.prod(value_shape)) + f_rep = np.tile(scalar_val, vs).reshape(vs, -1) + for i in range(vs): + f_rep[i] += np.pi * (i + 1) + return f_rep u = dolfinx.fem.Function(V) - u.interpolate(lambda x: x[0] ** degree + x[1] if tdim == 2 else x[0] + x[1] + x[2] ** degree) + u.interpolate(f) q = dolfinx.fem.Function(Q) expr = ufl.TrialFunction(V) + if use_petsc: A = scifem.interpolation.petsc_interpolation_matrix(expr, Q) A.mult(u.x.petsc_vec, q.x.petsc_vec) @@ -59,7 +75,15 @@ def test_interpolation_matrix(use_petsc, cell_type, degree): q.x.scatter_forward() q_ref = dolfinx.fem.Function(Q) - q_ref.interpolate(u) + if out_family == "Quadrature": + try: + ip = Q.element.interpolation_points() + except TypeError: + ip = Q.element.interpolation_points + u_expr = dolfinx.fem.Expression(u, ip) + q_ref.interpolate(u_expr) + else: + q_ref.interpolate(u) np.testing.assert_allclose(q.x.array, q_ref.x.array, rtol=1e-12, atol=1e-13)