From 57e610c7768bae2f96e3e5f80af693ddfd0aa43c Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 14 Aug 2025 16:33:43 +0100 Subject: [PATCH 01/27] GEM tabulations --- FIAT/expansions.py | 2 +- finat/fiat_elements.py | 95 ++++++++++++++++++++++++++++-------------- gem/gem.py | 2 + 3 files changed, 66 insertions(+), 33 deletions(-) diff --git a/FIAT/expansions.py b/FIAT/expansions.py index b7d846f26..1b1e0245e 100644 --- a/FIAT/expansions.py +++ b/FIAT/expansions.py @@ -53,7 +53,7 @@ def pad_jacobian(A, embedded_dim): def jacobi_factors(x, y, z, dx, dy, dz): fb = 0.5 * (y + z) fa = x + (fb + 1.0) - fc = fb ** 2 + fc = fb * fb dfa = dfb = dfc = None if dx is not None: dfb = 0.5 * (dy + dz) diff --git a/finat/fiat_elements.py b/finat/fiat_elements.py index a5d61b3b5..cf4fc0da6 100644 --- a/finat/fiat_elements.py +++ b/finat/fiat_elements.py @@ -178,7 +178,7 @@ def point_evaluation(self, order, refcoords, entity=None, coordinate_mapping=Non esd = self.cell.construct_subelement(entity_dim).get_spatial_dimension() assert isinstance(refcoords, gem.Node) and refcoords.shape == (esd,) - return point_evaluation(self._element, order, refcoords, (entity_dim, entity_i)) + return point_evaluation(self, order, refcoords, (entity_dim, entity_i)) @cached_property def _dual_basis(self): @@ -272,46 +272,77 @@ def mapping(self): return result -def point_evaluation(fiat_element, order, refcoords, entity): - # Coordinates on the reference entity (SymPy) - esd, = refcoords.shape - Xi = sp.symbols('X Y Z')[:esd] +def point_evaluation(finat_element, order, ps, entity): + fiat_element = finat_element._element + # Coordinates on the reference entity (GEM) + Xi = tuple(gem.Indexed(ps, i) for i in np.ndindex(ps.shape)) + if finat_element.complex.is_macrocell(): + # Coordinates on the reference entity (SymPy) + points = [sp.symbols('X Y Z')[:len(Xi)]] + # Convert SymPy expression to GEM + mapper = gem.node.Memoizer(sympy2gem) + mapper.bindings = dict(zip(points[0], Xi)) + mapper = np.vectorize(mapper) + else: + points = [Xi] + mapper = None + + fiat_result = fiat_element.tabulate(order, points, entity) + + value_shape = finat_element.value_shape + value_size = np.prod(value_shape, dtype=int) space_dimension = fiat_element.space_dimension() - value_size = np.prod(fiat_element.value_shape(), dtype=int) - fiat_result = fiat_element.tabulate(order, [Xi], entity) + + if finat_element.space_dimension() == space_dimension: + beta = finat_element.get_indices() + index_shape = tuple(index.extent for index in beta) + else: + index_shape = (space_dimension,) + beta = tuple(gem.Index(extent=i) for i in index_shape) + assert len(beta) == len(finat_element.get_indices()) + + zeta = finat_element.get_value_indices() + basis_indices = beta + zeta + result = {} for alpha, fiat_table in fiat_result.items(): if isinstance(fiat_table, Exception): result[alpha] = gem.Failure((space_dimension,) + fiat_element.value_shape(), fiat_table) continue - # Convert SymPy expression to GEM - mapper = gem.node.Memoizer(sympy2gem) - mapper.bindings = {s: gem.Indexed(refcoords, (i,)) - for i, s in enumerate(Xi)} - gem_table = np.vectorize(mapper)(fiat_table) - - table_roll = gem_table.reshape(space_dimension, value_size).transpose() - - exprs = [] - for table in table_roll: - exprs.append(gem.ListTensor(table.reshape(space_dimension))) - if fiat_element.value_shape(): - beta = (gem.Index(extent=space_dimension),) - zeta = tuple(gem.Index(extent=d) - for d in fiat_element.value_shape()) - result[alpha] = gem.ComponentTensor( - gem.Indexed( - gem.ListTensor(np.array( - [gem.Indexed(expr, beta) for expr in exprs] - ).reshape(fiat_element.value_shape())), - zeta), - beta + zeta - ) + derivative = sum(alpha) + fiat_table = fiat_table.reshape(space_dimension, value_size, -1) + if mapper is not None: + fiat_table = mapper(fiat_table) + + point_indices = () + if fiat_table.dtype == object: + assert len(points) == 1 + elif derivative == finat_element.degree and not finat_element.complex.is_macrocell(): + # Make sure numerics satisfies theory + fiat_table = fiat_table[..., 0] + elif derivative > finat_element.degree: + # Make sure numerics satisfies theory + assert np.allclose(fiat_table, 0.0) + fiat_table = np.zeros(fiat_table.shape[:-1]) else: - expr, = exprs - result[alpha] = expr + point_indices = ps.indices + + point_shape = tuple(index.extent for index in point_indices) + table_shape = index_shape + value_shape + point_shape + table_indices = basis_indices + point_indices + + fiat_table = fiat_table.reshape(table_shape) + if fiat_table.dtype == object: + gem_table = gem.ListTensor(fiat_table) + else: + gem_table = gem.Literal(fiat_table) + + expr = gem.Indexed(gem_table, table_indices) + expr = gem.ComponentTensor(expr, basis_indices) + result[alpha] = expr + return result diff --git a/gem/gem.py b/gem/gem.py index 4bae48e89..823f91830 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -1275,6 +1275,8 @@ def as_gem(expr): return expr elif isinstance(expr, Number): return Literal(expr) + elif isinstance(expr, numpy.ndarray): + return ListTensor(expr) if expr.dtype == object else Literal(expr) else: raise ValueError("Do not know how to convert %r to GEM" % expr) From e078883e94ac59421d94bc9ff23118d25bc29404 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 20 Aug 2025 19:06:54 +0100 Subject: [PATCH 02/27] Remove sympy from point_evaluation --- FIAT/expansions.py | 11 ++++++++--- finat/fiat_elements.py | 15 +-------------- finat/sympy2gem.py | 14 +------------- gem/gem.py | 43 +++++++++++++++++++++++++++++++++++++++++- 4 files changed, 52 insertions(+), 31 deletions(-) diff --git a/FIAT/expansions.py b/FIAT/expansions.py index 1b1e0245e..d4a4a34c4 100644 --- a/FIAT/expansions.py +++ b/FIAT/expansions.py @@ -723,11 +723,16 @@ def compute_partition_of_unity(ref_el, pt, unique=True, tol=1E-12): :kwarg tol: the absolute tolerance. :returns: a list of (weighted) characteristic functions for each subcell. """ - from sympy import Piecewise + import gem sd = ref_el.get_spatial_dimension() top = ref_el.get_topology() # assert singleton point - pt = pt.reshape((sd,)) + pt = numpy.reshape(pt, (sd,)) + + if isinstance(pt[0], gem.Node): + import gem as backend + else: + import sympy as backend # The distance to the nearest cell is equal to the distance to the parent cell best = ref_el.get_parent().distance_to_point_l1(pt, rescale=True) @@ -739,7 +744,7 @@ def compute_partition_of_unity(ref_el, pt, unique=True, tol=1E-12): for cell in sorted(top[sd]): # Bin points based on l1 distance pt_near_cell = ref_el.distance_to_point_l1(pt, entity=(sd, cell), rescale=True) < tol - masks.append(Piecewise(*otherwise, (1.0, pt_near_cell), (0.0, True))) + masks.append(backend.Piecewise(*otherwise, (1.0, pt_near_cell), (0.0, True))) if unique: otherwise.append((0.0, pt_near_cell)) # If the point is on a facet, divide the characteristic function by the facet multiplicity diff --git a/finat/fiat_elements.py b/finat/fiat_elements.py index cf4fc0da6..bf2f16728 100644 --- a/finat/fiat_elements.py +++ b/finat/fiat_elements.py @@ -1,12 +1,10 @@ import FIAT import gem import numpy as np -import sympy as sp from gem.utils import cached_property from finat.finiteelementbase import FiniteElementBase from finat.point_set import PointSet -from finat.sympy2gem import sympy2gem try: from firedrake_citations import Citations @@ -277,16 +275,7 @@ def point_evaluation(finat_element, order, ps, entity): # Coordinates on the reference entity (GEM) Xi = tuple(gem.Indexed(ps, i) for i in np.ndindex(ps.shape)) - if finat_element.complex.is_macrocell(): - # Coordinates on the reference entity (SymPy) - points = [sp.symbols('X Y Z')[:len(Xi)]] - # Convert SymPy expression to GEM - mapper = gem.node.Memoizer(sympy2gem) - mapper.bindings = dict(zip(points[0], Xi)) - mapper = np.vectorize(mapper) - else: - points = [Xi] - mapper = None + points = [Xi] fiat_result = fiat_element.tabulate(order, points, entity) @@ -313,8 +302,6 @@ def point_evaluation(finat_element, order, ps, entity): derivative = sum(alpha) fiat_table = fiat_table.reshape(space_dimension, value_size, -1) - if mapper is not None: - fiat_table = mapper(fiat_table) point_indices = () if fiat_table.dtype == object: diff --git a/finat/sympy2gem.py b/finat/sympy2gem.py index 9d613a307..440585cc0 100644 --- a/finat/sympy2gem.py +++ b/finat/sympy2gem.py @@ -1,6 +1,5 @@ from functools import singledispatch, reduce -import numpy import sympy try: import symengine @@ -130,18 +129,7 @@ def sympy2gem_le(node, self): @sympy2gem.register(sympy.Piecewise) @sympy2gem.register(symengine.Piecewise) def sympy2gem_conditional(node, self): - expr = None - pieces = [] - for v, c in node.args: - if isinstance(c, (bool, numpy.bool, sympy.logic.boolalg.BooleanTrue)) and c: - expr = self(v) - break - pieces.append((v, c)) - if expr is None: - expr = gem.Literal(float("nan")) - for v, c in reversed(pieces): - expr = gem.Conditional(self(c), self(v), expr) - return expr + return gem.Piecewise(*[(self(v), self(c)) for v, c in node.args]) @sympy2gem.register(sympy.ITE) diff --git a/gem/gem.py b/gem/gem.py index 823f91830..e73978d7b 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -36,7 +36,7 @@ 'IndexSum', 'ListTensor', 'Concatenate', 'Delta', 'OrientationVariableIndex', 'index_sum', 'partial_indexed', 'reshape', 'view', 'indices', 'as_gem', 'FlexiblyIndexed', - 'Inverse', 'Solve', 'extract_type', 'uint_type'] + 'Inverse', 'Solve', 'extract_type', 'uint_type', 'Piecewise'] uint_type = numpy.dtype(numpy.uintc) @@ -124,6 +124,21 @@ def __matmul__(self, other): def __rmatmul__(self, other): return as_gem(other).__matmul__(self) + def __abs__(self): + return componentwise(lambda x: MathFunction("abs", x), self) + + def __lt__(self, other): + return componentwise(lambda x, y: Comparison("<", x, y), self, other) + + def __gt__(self, other): + return componentwise(lambda x, y: Comparison(">", x, y), self, other) + + def __le__(self, other): + return componentwise(lambda x, y: Comparison("<=", x, y), self, other) + + def __ge__(self, other): + return componentwise(lambda x, y: Comparison(">=", x, y), self, other) + @property def T(self): i = indices(len(self.shape)) @@ -1275,6 +1290,8 @@ def as_gem(expr): return expr elif isinstance(expr, Number): return Literal(expr) + elif isinstance(expr, (bool, numpy.bool)): + return Literal(bool(expr)) elif isinstance(expr, numpy.ndarray): return ListTensor(expr) if expr.dtype == object else Literal(expr) else: @@ -1311,3 +1328,27 @@ def as_gem_uint(expr): def extract_type(expressions, klass): """Collects objects of type klass in expressions.""" return tuple(node for node in traversal(expressions) if isinstance(node, klass)) + + +def Piecewise(*args): + """Represents a piecewise function. + + Each argument is a 2-tuple defining an expression and condition. + + Returns + ------- + Node + A GEM nested Conditional. + """ + expr = None + pieces = [] + for v, c in args: + if isinstance(c, (bool, numpy.bool)) and c: + expr = as_gem(v) + break + pieces.append((as_gem(v), as_gem(c))) + if expr is None: + expr = Literal(float("nan")) + for v, c in reversed(pieces): + expr = Conditional(c, v, expr) + return expr From 2f7fb1cdeb459b8cae9b0f0e2323301265a423fa Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 20 Aug 2025 19:26:13 +0100 Subject: [PATCH 03/27] Reuse code for tabulation at known and unknown points --- finat/fiat_elements.py | 157 ++++++++++++------------------------- finat/physically_mapped.py | 4 - gem/gem.py | 8 +- 3 files changed, 55 insertions(+), 114 deletions(-) diff --git a/finat/fiat_elements.py b/finat/fiat_elements.py index bf2f16728..3a5bec03e 100644 --- a/finat/fiat_elements.py +++ b/finat/fiat_elements.py @@ -4,7 +4,7 @@ from gem.utils import cached_property from finat.finiteelementbase import FiniteElementBase -from finat.point_set import PointSet +from finat.point_set import PointSet, PointSingleton try: from firedrake_citations import Citations @@ -96,10 +96,12 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): :param ps: the point set. :param entity: the cell entity on which to tabulate. ''' - space_dimension = self._element.space_dimension() - value_size = np.prod(self._element.value_shape(), dtype=int) - fiat_result = self._element.tabulate(order, ps.points, entity) - result = {} + fiat_element = self._element + fiat_result = fiat_element.tabulate(order, ps.points, entity) + + value_shape = self.value_shape + value_size = np.prod(value_shape, dtype=int) + space_dimension = fiat_element.space_dimension() # In almost all cases, we have # self.space_dimension() == self._element.space_dimension() # But for Bell, FIAT reports 21 basis functions, @@ -107,51 +109,49 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): # basis functions, and the additional 3 are for # dealing with transformations between physical # and reference space). - index_shape = (self._element.space_dimension(),) + if self.space_dimension() == space_dimension: + beta = self.get_indices() + index_shape = tuple(index.extent for index in beta) + else: + index_shape = (space_dimension,) + beta = tuple(gem.Index(extent=i) for i in index_shape) + assert len(beta) == len(self.get_indices()) + + zeta = self.get_value_indices() + basis_indices = beta + zeta + + result = {} for alpha, fiat_table in fiat_result.items(): if isinstance(fiat_table, Exception): - result[alpha] = gem.Failure(self.index_shape + self.value_shape, fiat_table) + shape = ps.points.shape[:-1] + index_shape + value_shape + result[alpha] = gem.Failure(shape, fiat_table) continue derivative = sum(alpha) - shp = (space_dimension, value_size, *ps.points.shape[:-1]) - table_roll = np.moveaxis(fiat_table.reshape(shp), 0, -1) - - exprs = [] - for table in table_roll: - if derivative == self.degree and not self.complex.is_macrocell(): - # Make sure numerics satisfies theory - exprs.append(gem.Literal(table[0])) - elif derivative > self.degree: - # Make sure numerics satisfies theory - assert np.allclose(table, 0.0) - exprs.append(gem.Literal(np.zeros(self.index_shape))) - else: - point_indices = ps.indices - point_shape = tuple(index.extent for index in point_indices) - - exprs.append(gem.partial_indexed( - gem.Literal(table.reshape(point_shape + index_shape)), - point_indices - )) - if self.value_shape: - # As above, this extent may be different from that - # advertised by the finat element. - beta = tuple(gem.Index(extent=i) for i in index_shape) - assert len(beta) == len(self.get_indices()) - - zeta = self.get_value_indices() - result[alpha] = gem.ComponentTensor( - gem.Indexed( - gem.ListTensor(np.array( - [gem.Indexed(expr, beta) for expr in exprs] - ).reshape(self.value_shape)), - zeta), - beta + zeta - ) + fiat_table = fiat_table.reshape(space_dimension, value_size, -1) + + point_indices = () + if derivative == self.degree and not self.complex.is_macrocell(): + # Make sure numerics satisfies theory + fiat_table = fiat_table[..., 0] + elif derivative > self.degree: + # Make sure numerics satisfies theory + assert np.allclose(fiat_table, 0.0) + fiat_table = np.zeros(fiat_table.shape[:-1]) else: - expr, = exprs - result[alpha] = expr + point_indices = ps.indices + + point_shape = tuple(index.extent for index in point_indices) + table_shape = index_shape + value_shape + point_shape + table_indices = basis_indices + point_indices + + fiat_table = fiat_table.reshape(table_shape) + gem_table = gem.as_gem(fiat_table) + + expr = gem.Indexed(gem_table, table_indices) + expr = gem.ComponentTensor(expr, basis_indices) + result[alpha] = expr + return result def point_evaluation(self, order, refcoords, entity=None, coordinate_mapping=None): @@ -176,7 +176,11 @@ def point_evaluation(self, order, refcoords, entity=None, coordinate_mapping=Non esd = self.cell.construct_subelement(entity_dim).get_spatial_dimension() assert isinstance(refcoords, gem.Node) and refcoords.shape == (esd,) - return point_evaluation(self, order, refcoords, (entity_dim, entity_i)) + # Coordinates on the reference entity (GEM) + Xi = tuple(gem.Indexed(refcoords, i) for i in np.ndindex(refcoords.shape)) + ps = PointSingleton(Xi) + return self.basis_evaluation(order, ps, entity=entity, + coordinate_mapping=coordinate_mapping) @cached_property def _dual_basis(self): @@ -270,69 +274,6 @@ def mapping(self): return result -def point_evaluation(finat_element, order, ps, entity): - fiat_element = finat_element._element - - # Coordinates on the reference entity (GEM) - Xi = tuple(gem.Indexed(ps, i) for i in np.ndindex(ps.shape)) - points = [Xi] - - fiat_result = fiat_element.tabulate(order, points, entity) - - value_shape = finat_element.value_shape - value_size = np.prod(value_shape, dtype=int) - space_dimension = fiat_element.space_dimension() - - if finat_element.space_dimension() == space_dimension: - beta = finat_element.get_indices() - index_shape = tuple(index.extent for index in beta) - else: - index_shape = (space_dimension,) - beta = tuple(gem.Index(extent=i) for i in index_shape) - assert len(beta) == len(finat_element.get_indices()) - - zeta = finat_element.get_value_indices() - basis_indices = beta + zeta - - result = {} - for alpha, fiat_table in fiat_result.items(): - if isinstance(fiat_table, Exception): - result[alpha] = gem.Failure((space_dimension,) + fiat_element.value_shape(), fiat_table) - continue - - derivative = sum(alpha) - fiat_table = fiat_table.reshape(space_dimension, value_size, -1) - - point_indices = () - if fiat_table.dtype == object: - assert len(points) == 1 - elif derivative == finat_element.degree and not finat_element.complex.is_macrocell(): - # Make sure numerics satisfies theory - fiat_table = fiat_table[..., 0] - elif derivative > finat_element.degree: - # Make sure numerics satisfies theory - assert np.allclose(fiat_table, 0.0) - fiat_table = np.zeros(fiat_table.shape[:-1]) - else: - point_indices = ps.indices - - point_shape = tuple(index.extent for index in point_indices) - table_shape = index_shape + value_shape + point_shape - table_indices = basis_indices + point_indices - - fiat_table = fiat_table.reshape(table_shape) - if fiat_table.dtype == object: - gem_table = gem.ListTensor(fiat_table) - else: - gem_table = gem.Literal(fiat_table) - - expr = gem.Indexed(gem_table, table_indices) - expr = gem.ComponentTensor(expr, basis_indices) - result[alpha] = expr - - return result - - class Regge(FiatElement): # symmetric matrix valued def __init__(self, cell, degree, variant=None): super().__init__(FIAT.Regge(cell, degree, variant=variant)) diff --git a/finat/physically_mapped.py b/finat/physically_mapped.py index 54cce6d4e..dff76cb3b 100644 --- a/finat/physically_mapped.py +++ b/finat/physically_mapped.py @@ -327,10 +327,6 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): result = super().basis_evaluation(order, ps, entity=entity) return self.map_tabulation(result, coordinate_mapping) - def point_evaluation(self, order, refcoords, entity=None, coordinate_mapping=None): - result = super().point_evaluation(order, refcoords, entity=entity) - return self.map_tabulation(result, coordinate_mapping) - class DirectlyDefinedElement(NeedsCoordinateMappingElement): """Base class for directly defined elements such as direct diff --git a/gem/gem.py b/gem/gem.py index e73978d7b..efa46cdd9 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -1333,12 +1333,16 @@ def extract_type(expressions, klass): def Piecewise(*args): """Represents a piecewise function. - Each argument is a 2-tuple defining an expression and condition. + Parameters + ---------- + *args + Each argument is a 2-tuple defining an expression and condition. Returns ------- Node - A GEM nested Conditional. + A nested Conditional. + """ expr = None pieces = [] From 2c54e72483013e6b7a6f0fa53825afc66a628a73 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 21 Aug 2025 08:22:18 +0100 Subject: [PATCH 04/27] GEM barycentric interpolation --- FIAT/barycentric_interpolation.py | 15 ++++++++++++--- finat/fiat_elements.py | 3 ++- finat/physically_mapped.py | 8 +++++--- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/FIAT/barycentric_interpolation.py b/FIAT/barycentric_interpolation.py index 45b0287e3..0f44f8327 100644 --- a/FIAT/barycentric_interpolation.py +++ b/FIAT/barycentric_interpolation.py @@ -6,6 +6,7 @@ # # Written by Pablo D. Brubeck (brubeck@protonmail.com), 2021 +import gem import numpy from FIAT import reference_element, expansions, polynomial_set @@ -24,19 +25,27 @@ def barycentric_interpolation(nodes, wts, dmat, pts, order=0): via the second barycentric interpolation formula. See Berrut and Trefethen (2004) https://doi.org/10.1137/S0036144502417715 Eq. (4.2) & (9.4) """ - if pts.dtype == object: + if pts.dtype == object and not isinstance(pts.flat[0], gem.Node): from sympy import simplify sp_simplify = numpy.vectorize(simplify) else: sp_simplify = lambda x: x + phi = numpy.add.outer(-nodes, pts.flatten()) with numpy.errstate(divide='ignore', invalid='ignore'): numpy.reciprocal(phi, out=phi) numpy.multiply(phi, wts[:, None], out=phi) numpy.multiply(1.0 / numpy.sum(phi, axis=0), phi, out=phi) - phi[phi != phi] = 1.0 - phi = phi.reshape(-1, *pts.shape[:-1]) + # Replace 0/0 with 1.0 + if isinstance(pts.flat[0], gem.Node): + one = gem.Literal(1.0) + for i in numpy.ndindex(phi.shape): + phi[i] = gem.Conditional(gem.Comparison("!=", phi[i], phi[i]), one, phi[i]) + else: + phi[phi != phi] = 1.0 + + phi = phi.reshape(-1, *pts.shape[:-1]) phi = sp_simplify(phi) results = {(0,): phi} for r in range(1, order+1): diff --git a/finat/fiat_elements.py b/finat/fiat_elements.py index 3a5bec03e..0dbb29180 100644 --- a/finat/fiat_elements.py +++ b/finat/fiat_elements.py @@ -136,7 +136,8 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): fiat_table = fiat_table[..., 0] elif derivative > self.degree: # Make sure numerics satisfies theory - assert np.allclose(fiat_table, 0.0) + if fiat_table.dtype != object: + assert np.allclose(fiat_table, 0.0) fiat_table = np.zeros(fiat_table.shape[:-1]) else: point_indices = ps.indices diff --git a/finat/physically_mapped.py b/finat/physically_mapped.py index dff76cb3b..7ea409f0d 100644 --- a/finat/physically_mapped.py +++ b/finat/physically_mapped.py @@ -266,6 +266,9 @@ class MappedTabulation(Mapping): on the requested derivatives.""" def __init__(self, M, ref_tabulation): + M = gem.optimise.aggressive_unroll(M) + M, = gem.optimise.constant_fold_zero((M,)) + self.M = M self.ref_tabulation = ref_tabulation # we expect M to be sparse with O(1) nonzeros per row @@ -283,9 +286,8 @@ def matvec(self, table): exprs = [gem.ComponentTensor(gem.Sum(*(self.M.array[i, j] * phi[j] for j in js)), ii) for i, js in enumerate(self.csr)] - val = gem.ListTensor(exprs) - # val = self.M @ table - return gem.optimise.aggressive_unroll(val) + # return gem.optimise.aggressive_unroll(self.M @ table) + return gem.ListTensor(exprs) def __getitem__(self, alpha): try: From 26113d8fa62fde2c4873a88f17d14728a88c6c60 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 21 Aug 2025 13:13:53 +0100 Subject: [PATCH 05/27] Fix cellwise constant case --- finat/fiat_elements.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/finat/fiat_elements.py b/finat/fiat_elements.py index 0dbb29180..cff69ea45 100644 --- a/finat/fiat_elements.py +++ b/finat/fiat_elements.py @@ -133,6 +133,14 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): point_indices = () if derivative == self.degree and not self.complex.is_macrocell(): # Make sure numerics satisfies theory + if fiat_table.dtype == object: + bindings = {X: np.zeros(X.shape) + for pt in ps.points + for X in gem.extract_type(pt, gem.Variable)} + gem_table = gem.as_gem(fiat_table) + val, = gem.interpreter.evaluate((gem_table,), bindings=bindings) + fiat_table = val.arr.T + fiat_table = fiat_table[..., 0] elif derivative > self.degree: # Make sure numerics satisfies theory From 724380923f54bf8080811df53c8d72b1c84b2a08 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 21 Aug 2025 14:54:38 +0100 Subject: [PATCH 06/27] add tests --- FIAT/barycentric_interpolation.py | 9 +++++---- test/FIAT/unit/test_macro.py | 33 +++++++++++++++++++++++++++++-- 2 files changed, 36 insertions(+), 6 deletions(-) diff --git a/FIAT/barycentric_interpolation.py b/FIAT/barycentric_interpolation.py index 0f44f8327..835c9db42 100644 --- a/FIAT/barycentric_interpolation.py +++ b/FIAT/barycentric_interpolation.py @@ -38,10 +38,11 @@ def barycentric_interpolation(nodes, wts, dmat, pts, order=0): numpy.multiply(1.0 / numpy.sum(phi, axis=0), phi, out=phi) # Replace 0/0 with 1.0 - if isinstance(pts.flat[0], gem.Node): - one = gem.Literal(1.0) - for i in numpy.ndindex(phi.shape): - phi[i] = gem.Conditional(gem.Comparison("!=", phi[i], phi[i]), one, phi[i]) + if pts.dtype == object: + if isinstance(pts.flat[0], gem.Node): + one = gem.Literal(1.0) + for i in numpy.ndindex(phi.shape): + phi[i] = gem.Conditional(gem.Comparison("!=", phi[i], phi[i]), one, phi[i]) else: phi[phi != phi] = 1.0 diff --git a/test/FIAT/unit/test_macro.py b/test/FIAT/unit/test_macro.py index abfec9504..97c53b918 100644 --- a/test/FIAT/unit/test_macro.py +++ b/test/FIAT/unit/test_macro.py @@ -428,10 +428,11 @@ def test_macro_sympy(cell, element): variant = "spectral,alfeld" K = IsoSplit(cell) ebig = element(K, 3, variant=variant) - pts = get_lagrange_points(ebig.dual_basis()) + pts = numpy.asarray(get_lagrange_points(ebig.dual_basis())) + + X = tuple(sympy.Symbol(f"X[{i}]") for i in range(pts.shape[1])) dim = cell.get_spatial_dimension() - X = tuple(sympy.Symbol(f"X[{i}]") for i in range(dim)) degrees = range(1, 3) if element is Lagrange else range(3) for degree in degrees: fe = element(cell, degree, variant=variant) @@ -443,6 +444,34 @@ def test_macro_sympy(cell, element): assert numpy.allclose(results, tab_numpy) +@pytest.mark.parametrize("element", (DiscontinuousLagrange, Lagrange)) +def test_macro_gem(cell, element): + import gem + from gem.interpreter import evaluate + + variant = "spectral,alfeld" + K = IsoSplit(cell) + ebig = element(K, 3, variant=variant) + pts = numpy.asarray(get_lagrange_points(ebig.dual_basis())) + + coords = gem.Variable('X', pts.shape) + bindings = {coords: pts} + index = gem.Index() + point = gem.partial_indexed(coords, (index,)) + X = tuple(gem.Indexed(point, i) for i in numpy.ndindex(point.shape)) + + dim = cell.get_spatial_dimension() + degrees = range(1, 3) if element is Lagrange else range(3) + for degree in degrees: + fe = element(cell, degree, variant=variant) + tab_gem = fe.tabulate(0, X)[(0,) * dim] + + results, = evaluate((gem.as_gem(tab_gem),), bindings=bindings) + results = results.arr.T + tab_numpy = fe.tabulate(0, pts)[(0,) * dim] + assert numpy.allclose(results, tab_numpy) + + @pytest.mark.parametrize("element,degree", [ (Lagrange, 1), (Nedelec, 1), (RaviartThomas, 1), (DiscontinuousLagrange, 0), (Regge, 0), (HellanHerrmannJohnson, 0), From fdefbbd73579f033d105cf25c94f9a35d15d0583 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 21 Aug 2025 16:25:06 +0100 Subject: [PATCH 07/27] Fix transpose --- finat/fiat_elements.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/finat/fiat_elements.py b/finat/fiat_elements.py index cff69ea45..328e295cf 100644 --- a/finat/fiat_elements.py +++ b/finat/fiat_elements.py @@ -128,7 +128,6 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): continue derivative = sum(alpha) - fiat_table = fiat_table.reshape(space_dimension, value_size, -1) point_indices = () if derivative == self.degree and not self.complex.is_macrocell(): @@ -139,14 +138,15 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): for X in gem.extract_type(pt, gem.Variable)} gem_table = gem.as_gem(fiat_table) val, = gem.interpreter.evaluate((gem_table,), bindings=bindings) - fiat_table = val.arr.T + fiat_table = val.arr.transpose((*range(1, val.arr.ndim), 0)) + fiat_table = fiat_table.reshape(space_dimension, value_size, -1) fiat_table = fiat_table[..., 0] elif derivative > self.degree: # Make sure numerics satisfies theory if fiat_table.dtype != object: assert np.allclose(fiat_table, 0.0) - fiat_table = np.zeros(fiat_table.shape[:-1]) + fiat_table = np.zeros((space_dimension, value_size)) else: point_indices = ps.indices From 6868df51a4910aadb90a761818a856e752706b1c Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 22 Aug 2025 15:33:56 +0100 Subject: [PATCH 08/27] Fixup --- FIAT/barycentric_interpolation.py | 25 +++++++++++-------------- finat/fiat_elements.py | 21 +++++++-------------- gem/gem.py | 12 +++++++----- 3 files changed, 25 insertions(+), 33 deletions(-) diff --git a/FIAT/barycentric_interpolation.py b/FIAT/barycentric_interpolation.py index 835c9db42..6b3636773 100644 --- a/FIAT/barycentric_interpolation.py +++ b/FIAT/barycentric_interpolation.py @@ -6,7 +6,6 @@ # # Written by Pablo D. Brubeck (brubeck@protonmail.com), 2021 -import gem import numpy from FIAT import reference_element, expansions, polynomial_set @@ -25,28 +24,26 @@ def barycentric_interpolation(nodes, wts, dmat, pts, order=0): via the second barycentric interpolation formula. See Berrut and Trefethen (2004) https://doi.org/10.1137/S0036144502417715 Eq. (4.2) & (9.4) """ - if pts.dtype == object and not isinstance(pts.flat[0], gem.Node): - from sympy import simplify - sp_simplify = numpy.vectorize(simplify) - else: - sp_simplify = lambda x: x - + sp_simplify = lambda x: x phi = numpy.add.outer(-nodes, pts.flatten()) with numpy.errstate(divide='ignore', invalid='ignore'): - numpy.reciprocal(phi, out=phi) - numpy.multiply(phi, wts[:, None], out=phi) - numpy.multiply(1.0 / numpy.sum(phi, axis=0), phi, out=phi) + numpy.divide(wts[:, None], phi, out=phi) + numpy.divide(phi, numpy.sum(phi, axis=0, keepdims=True), out=phi) + phi = phi.reshape(-1, *pts.shape[:-1]) # Replace 0/0 with 1.0 if pts.dtype == object: - if isinstance(pts.flat[0], gem.Node): + import gem + if any(isinstance(Xi, gem.Node) for Xi in pts.flat): one = gem.Literal(1.0) - for i in numpy.ndindex(phi.shape): - phi[i] = gem.Conditional(gem.Comparison("!=", phi[i], phi[i]), one, phi[i]) + for i, u in numpy.ndenumerate(phi): + phi[i] = gem.Conditional(gem.Comparison("!=", u, u), one, u) + else: + from sympy import simplify + sp_simplify = numpy.vectorize(simplify) else: phi[phi != phi] = 1.0 - phi = phi.reshape(-1, *pts.shape[:-1]) phi = sp_simplify(phi) results = {(0,): phi} for r in range(1, order+1): diff --git a/finat/fiat_elements.py b/finat/fiat_elements.py index 328e295cf..902664205 100644 --- a/finat/fiat_elements.py +++ b/finat/fiat_elements.py @@ -98,10 +98,6 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): ''' fiat_element = self._element fiat_result = fiat_element.tabulate(order, ps.points, entity) - - value_shape = self.value_shape - value_size = np.prod(value_shape, dtype=int) - space_dimension = fiat_element.space_dimension() # In almost all cases, we have # self.space_dimension() == self._element.space_dimension() # But for Bell, FIAT reports 21 basis functions, @@ -109,6 +105,8 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): # basis functions, and the additional 3 are for # dealing with transformations between physical # and reference space). + value_shape = self.value_shape + space_dimension = fiat_element.space_dimension() if self.space_dimension() == space_dimension: beta = self.get_indices() index_shape = tuple(index.extent for index in beta) @@ -128,7 +126,6 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): continue derivative = sum(alpha) - point_indices = () if derivative == self.degree and not self.complex.is_macrocell(): # Make sure numerics satisfies theory @@ -140,27 +137,23 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): val, = gem.interpreter.evaluate((gem_table,), bindings=bindings) fiat_table = val.arr.transpose((*range(1, val.arr.ndim), 0)) - fiat_table = fiat_table.reshape(space_dimension, value_size, -1) + fiat_table = fiat_table.reshape(*index_shape, *value_shape, -1) fiat_table = fiat_table[..., 0] elif derivative > self.degree: # Make sure numerics satisfies theory if fiat_table.dtype != object: assert np.allclose(fiat_table, 0.0) - fiat_table = np.zeros((space_dimension, value_size)) + fiat_table = np.zeros(index_shape + value_shape) else: point_indices = ps.indices - point_shape = tuple(index.extent for index in point_indices) - table_shape = index_shape + value_shape + point_shape - table_indices = basis_indices + point_indices + point_shape = tuple(i.extent for i in point_indices) + fiat_table = fiat_table.reshape(index_shape + value_shape + point_shape) - fiat_table = fiat_table.reshape(table_shape) gem_table = gem.as_gem(fiat_table) - - expr = gem.Indexed(gem_table, table_indices) + expr = gem.Indexed(gem_table, basis_indices + point_indices) expr = gem.ComponentTensor(expr, basis_indices) result[alpha] = expr - return result def point_evaluation(self, order, refcoords, entity=None, coordinate_mapping=None): diff --git a/gem/gem.py b/gem/gem.py index efa46cdd9..46a85c4e0 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -287,7 +287,6 @@ class Literal(Constant): __back__ = ('dtype',) def __new__(cls, array, dtype=None): - array = asarray(array) return super(Literal, cls).__new__(cls) def __init__(self, array, dtype=None): @@ -568,8 +567,8 @@ def __init__(self, a, b): self.children = a, b -class Conditional(Node): - __slots__ = ('children', 'shape') +class Conditional(Scalar): + __slots__ = ('children',) def __new__(cls, condition, then, else_): assert not condition.shape @@ -582,7 +581,6 @@ def __new__(cls, condition, then, else_): self = super(Conditional, cls).__new__(cls) self.children = condition, then, else_ - self.shape = then.shape self.dtype = Node.inherit_dtype_from_children((then, else_)) return self @@ -1293,7 +1291,11 @@ def as_gem(expr): elif isinstance(expr, (bool, numpy.bool)): return Literal(bool(expr)) elif isinstance(expr, numpy.ndarray): - return ListTensor(expr) if expr.dtype == object else Literal(expr) + if expr.dtype == object: + expr = numpy.vectorize(as_gem)(expr) + return ListTensor(expr) + else: + return Literal(expr) else: raise ValueError("Do not know how to convert %r to GEM" % expr) From 4d528b0565e9d9430990d93e56537e97512b741f Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Tue, 26 Aug 2025 10:19:39 +0100 Subject: [PATCH 09/27] Node.__neg__ --- FIAT/barycentric_interpolation.py | 2 +- finat/argyris.py | 6 +++--- finat/bell.py | 4 ++-- finat/hct.py | 2 +- finat/piola_mapped.py | 6 +++--- gem/gem.py | 8 +++++--- 6 files changed, 15 insertions(+), 13 deletions(-) diff --git a/FIAT/barycentric_interpolation.py b/FIAT/barycentric_interpolation.py index 6b3636773..ba958b6ea 100644 --- a/FIAT/barycentric_interpolation.py +++ b/FIAT/barycentric_interpolation.py @@ -35,7 +35,7 @@ def barycentric_interpolation(nodes, wts, dmat, pts, order=0): if pts.dtype == object: import gem if any(isinstance(Xi, gem.Node) for Xi in pts.flat): - one = gem.Literal(1.0) + one = gem.as_gem(1.0) for i, u in numpy.ndenumerate(phi): phi[i] = gem.Conditional(gem.Comparison("!=", u, u), one, u) else: diff --git a/finat/argyris.py b/finat/argyris.py index 562ae41b2..e582d0177 100644 --- a/finat/argyris.py +++ b/finat/argyris.py @@ -95,7 +95,7 @@ def _edge_transform(V, vorder, eorder, fiat_cell, coordinate_mapping, avg=False) V[s, v1id] = P1 * Bnt V[s, v0id] = P0 * Bnt if k > 0: - V[s, s + eorder] = -1 * Bnt + V[s, s + eorder] = -Bnt class Argyris(PhysicallyMappedElement, ScalarFiatElement): @@ -139,7 +139,7 @@ def basis_transformation(self, coordinate_mapping): # vertex points V[s, v1id] = 15/8 * Bnt - V[s, v0id] = -1 * V[s, v1id] + V[s, v0id] = -V[s, v1id] # vertex derivatives for i in range(sd): @@ -150,7 +150,7 @@ def basis_transformation(self, coordinate_mapping): tau = [Jt[0]*Jt[0], 2*Jt[0]*Jt[1], Jt[1]*Jt[1]] for i in range(len(tau)): V[s, v1id+3+i] = 1/32 * Bnt * tau[i] - V[s, v0id+3+i] = -1 * V[s, v1id+3+i] + V[s, v0id+3+i] = -V[s, v1id+3+i] # Patch up conditioning h = coordinate_mapping.cell_size() diff --git a/finat/bell.py b/finat/bell.py index 98d4aecc6..69422c44b 100644 --- a/finat/bell.py +++ b/finat/bell.py @@ -44,7 +44,7 @@ def basis_transformation(self, coordinate_mapping): # vertex points V[s, v1id] = 1/21 * Bnt - V[s, v0id] = -1 * V[s, v1id] + V[s, v0id] = -V[s, v1id] # vertex derivatives for i in range(sd): @@ -55,7 +55,7 @@ def basis_transformation(self, coordinate_mapping): tau = [Jt[0]*Jt[0], 2*Jt[0]*Jt[1], Jt[1]*Jt[1]] for i in range(len(tau)): V[s, v1id+3+i] = 1/252 * Bnt * tau[i] - V[s, v0id+3+i] = -1 * V[s, v1id+3+i] + V[s, v0id+3+i] = -V[s, v1id+3+i] # Patch up conditioning h = coordinate_mapping.cell_size() diff --git a/finat/hct.py b/finat/hct.py index efcb131b6..722d36dd8 100644 --- a/finat/hct.py +++ b/finat/hct.py @@ -72,7 +72,7 @@ def basis_transformation(self, coordinate_mapping): # vertex points V[s, v0id] = 1/5 * Bnt - V[s, v1id] = -1 * V[s, v0id] + V[s, v1id] = -V[s, v0id] # vertex derivatives for i in range(sd): diff --git a/finat/piola_mapped.py b/finat/piola_mapped.py index c9557224f..b948ce31f 100644 --- a/finat/piola_mapped.py +++ b/finat/piola_mapped.py @@ -61,7 +61,7 @@ def normal_tangential_edge_transform(fiat_cell, J, detJ, f): alpha = Jn @ Jt beta = Jt @ Jt # Compute the last row of inv([[1, 0], [alpha/detJ, beta/detJ]]) - row = (-1 * alpha / beta, detJ / beta) + row = (-alpha / beta, detJ / beta) return row @@ -84,8 +84,8 @@ def normal_tangential_face_transform(fiat_cell, J, detJ, f): det1 = A[2, 0] * A[0, 1] - A[2, 1] * A[0, 0] det2 = A[0, 0] * A[1, 1] - A[0, 1] * A[1, 0] scale = detJ / det0 - rows = ((-1 * det1 / det0, -1 * scale * A[2, 1], scale * A[2, 0]), - (-1 * det2 / det0, scale * A[1, 1], -1 * scale * A[1, 0])) + rows = ((-det1 / det0, -scale * A[2, 1], scale * A[2, 0]), + (-det2 / det0, scale * A[1, 1], -scale * A[1, 0])) return rows diff --git a/gem/gem.py b/gem/gem.py index 46a85c4e0..55cd4f014 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -88,6 +88,9 @@ def __getitem__(self, indices): indices = (indices, ) return Indexed(self, indices) + def __neg__(self): + return componentwise(Product, minus, self) + def __add__(self, other): return componentwise(Sum, self, as_gem(other)) @@ -95,9 +98,7 @@ def __radd__(self, other): return as_gem(other).__add__(self) def __sub__(self, other): - return componentwise( - Sum, self, - componentwise(Product, Literal(-1), as_gem(other))) + return componentwise(Sum, self, -as_gem(other)) def __rsub__(self, other): return as_gem(other).__sub__(self) @@ -1231,6 +1232,7 @@ def view(expression, *slices): # Static one object for quicker constant folding one = Literal(1) +minus = Literal(-1) # Syntax sugar From 572af821b85086099270e5fd992f8c95ed68297d Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Tue, 26 Aug 2025 16:39:42 +0100 Subject: [PATCH 10/27] Do not use barycentric interpolation at unknown points --- FIAT/barycentric_interpolation.py | 32 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/FIAT/barycentric_interpolation.py b/FIAT/barycentric_interpolation.py index ba958b6ea..c2807a74f 100644 --- a/FIAT/barycentric_interpolation.py +++ b/FIAT/barycentric_interpolation.py @@ -24,30 +24,26 @@ def barycentric_interpolation(nodes, wts, dmat, pts, order=0): via the second barycentric interpolation formula. See Berrut and Trefethen (2004) https://doi.org/10.1137/S0036144502417715 Eq. (4.2) & (9.4) """ - sp_simplify = lambda x: x - phi = numpy.add.outer(-nodes, pts.flatten()) - with numpy.errstate(divide='ignore', invalid='ignore'): - numpy.divide(wts[:, None], phi, out=phi) - numpy.divide(phi, numpy.sum(phi, axis=0, keepdims=True), out=phi) - phi = phi.reshape(-1, *pts.shape[:-1]) - - # Replace 0/0 with 1.0 if pts.dtype == object: - import gem - if any(isinstance(Xi, gem.Node) for Xi in pts.flat): - one = gem.as_gem(1.0) - for i, u in numpy.ndenumerate(phi): - phi[i] = gem.Conditional(gem.Comparison("!=", u, u), one, u) - else: - from sympy import simplify - sp_simplify = numpy.vectorize(simplify) + # Do not use barycentric interpolation at unknown points + phi = numpy.add.outer(-nodes, pts.flatten()) + n = len(nodes) + phis = [wts[i] * numpy.prod(phi[[*range(0, i), *range(i+1, n)]], axis=0) + for i in range(n)] + phi = numpy.asarray(phis) else: + # Use the second barycentric interpolation formula + phi = numpy.add.outer(-nodes, pts.flatten()) + with numpy.errstate(divide='ignore', invalid='ignore'): + numpy.divide(wts[:, None], phi, out=phi) + numpy.divide(phi, numpy.sum(phi, axis=0, keepdims=True), out=phi) + # Replace nan with one phi[phi != phi] = 1.0 - phi = sp_simplify(phi) + phi = phi.reshape(-1, *pts.shape[:-1]) results = {(0,): phi} for r in range(1, order+1): - phi = sp_simplify(numpy.dot(dmat, phi)) + phi = numpy.dot(dmat, phi) results[(r,)] = phi return results From d82ba64dcbc03194eccea9aeb8d5ceb94500b361 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 27 Aug 2025 10:40:09 +0100 Subject: [PATCH 11/27] comments --- finat/fiat_elements.py | 9 ++++++--- finat/point_set.py | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/finat/fiat_elements.py b/finat/fiat_elements.py index 902664205..b4794cb17 100644 --- a/finat/fiat_elements.py +++ b/finat/fiat_elements.py @@ -128,19 +128,22 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): derivative = sum(alpha) point_indices = () if derivative == self.degree and not self.complex.is_macrocell(): - # Make sure numerics satisfies theory + # Ensure a cellwise constant tabulation if fiat_table.dtype == object: + # Eliminate Variables by forcing numerical evaluation bindings = {X: np.zeros(X.shape) for pt in ps.points for X in gem.extract_type(pt, gem.Variable)} gem_table = gem.as_gem(fiat_table) + ndim = len(gem_table.free_indices) val, = gem.interpreter.evaluate((gem_table,), bindings=bindings) - fiat_table = val.arr.transpose((*range(1, val.arr.ndim), 0)) + fiat_table = val.arr.transpose((*range(ndim, val.arr.ndim), *range(ndim))) fiat_table = fiat_table.reshape(*index_shape, *value_shape, -1) + assert np.allclose(fiat_table, fiat_table[..., 0, None]) fiat_table = fiat_table[..., 0] elif derivative > self.degree: - # Make sure numerics satisfies theory + # Ensure a zero tabulation if fiat_table.dtype != object: assert np.allclose(fiat_table, 0.0) fiat_table = np.zeros(index_shape + value_shape) diff --git a/finat/point_set.py b/finat/point_set.py index 9e91faae4..6781576c1 100644 --- a/finat/point_set.py +++ b/finat/point_set.py @@ -76,7 +76,7 @@ def points(self): @cached_property def expression(self): - return gem.Literal(self.point) + return gem.as_gem(self.point) class UnknownPointsArray(): From d743cb897656aa66528c54724cf99dd1e16941e5 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 27 Aug 2025 11:45:58 +0100 Subject: [PATCH 12/27] Grab Variables from ps.expression --- finat/fiat_elements.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/finat/fiat_elements.py b/finat/fiat_elements.py index b4794cb17..5a0a2b50f 100644 --- a/finat/fiat_elements.py +++ b/finat/fiat_elements.py @@ -132,8 +132,7 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): if fiat_table.dtype == object: # Eliminate Variables by forcing numerical evaluation bindings = {X: np.zeros(X.shape) - for pt in ps.points - for X in gem.extract_type(pt, gem.Variable)} + for X in gem.extract_type(ps.expression, gem.Variable)} gem_table = gem.as_gem(fiat_table) ndim = len(gem_table.free_indices) val, = gem.interpreter.evaluate((gem_table,), bindings=bindings) From accabdb2a13a50fa5ed3337f8f1370749f106696 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 27 Aug 2025 11:50:27 +0100 Subject: [PATCH 13/27] Fix up --- finat/fiat_elements.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/finat/fiat_elements.py b/finat/fiat_elements.py index 5a0a2b50f..0277e77f7 100644 --- a/finat/fiat_elements.py +++ b/finat/fiat_elements.py @@ -121,8 +121,7 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): result = {} for alpha, fiat_table in fiat_result.items(): if isinstance(fiat_table, Exception): - shape = ps.points.shape[:-1] + index_shape + value_shape - result[alpha] = gem.Failure(shape, fiat_table) + result[alpha] = gem.Failure(index_shape + value_shape, fiat_table) continue derivative = sum(alpha) From 32250438847afcb1ea6ca322853ab64378b03292 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 27 Aug 2025 15:41:20 +0100 Subject: [PATCH 14/27] Evaluate FlexiblyIndexed --- finat/fiat_elements.py | 2 +- gem/interpreter.py | 24 ++++++++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/finat/fiat_elements.py b/finat/fiat_elements.py index 0277e77f7..805845b64 100644 --- a/finat/fiat_elements.py +++ b/finat/fiat_elements.py @@ -130,7 +130,7 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): # Ensure a cellwise constant tabulation if fiat_table.dtype == object: # Eliminate Variables by forcing numerical evaluation - bindings = {X: np.zeros(X.shape) + bindings = {X: np.random.random_sample(X.shape) for X in gem.extract_type(ps.expression, gem.Variable)} gem_table = gem.as_gem(fiat_table) ndim = len(gem_table.free_indices) diff --git a/gem/interpreter.py b/gem/interpreter.py index 13eeb44a2..a8ac48b6e 100644 --- a/gem/interpreter.py +++ b/gem/interpreter.py @@ -4,6 +4,7 @@ import numpy import operator from collections import OrderedDict +from numbers import Integral from functools import singledispatch import itertools @@ -286,6 +287,29 @@ def _evaluate_indexed(e, self): return Result(val[idx], val.fids + fids) +@_evaluate.register(gem.FlexiblyIndexed) +def _evaluate_flexiblyindexed(e, self): + val = self(e.children[0]) + + assert all(isinstance(index, gem.Index) and isinstance(stride, Integral) + for (offset, idxs) in e.dim2idxs for (index, stride) in idxs) + + fids = e.index_ordering() + shape = tuple(i.extent for i in fids) + strides = (1, *numpy.cumprod(val.arr.shape, dtype=int)) + dim2idxs = e.dim2idxs + + def index_mapping(idx): + indices = iter(idx) + cur = sum(s * sum((j*stride for (index, stride), j in zip(idxs, indices)), offset) + for (offset, idxs), s in zip(dim2idxs, strides)) + return cur + + vidx = list(map(index_mapping, numpy.ndindex(shape))) + sub = val.arr.flatten()[vidx].reshape(shape) + return Result(sub, val.fids + fids) + + @_evaluate.register(gem.ComponentTensor) def _evaluate_componenttensor(e, self): """Component tensors map free indices to shape.""" From 4d008154cfab9bc4d16fff3dbc9e43d2dc6564ad Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 28 Aug 2025 16:03:05 +0100 Subject: [PATCH 15/27] replace indices symbolically --- finat/fiat_elements.py | 23 ++++++++---------- gem/interpreter.py | 55 ++++++++++++++++++++++-------------------- gem/optimise.py | 4 +-- 3 files changed, 41 insertions(+), 41 deletions(-) diff --git a/finat/fiat_elements.py b/finat/fiat_elements.py index 805845b64..b32e81add 100644 --- a/finat/fiat_elements.py +++ b/finat/fiat_elements.py @@ -124,22 +124,17 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): result[alpha] = gem.Failure(index_shape + value_shape, fiat_table) continue - derivative = sum(alpha) point_indices = () + replace_indices = () + derivative = sum(alpha) if derivative == self.degree and not self.complex.is_macrocell(): # Ensure a cellwise constant tabulation if fiat_table.dtype == object: - # Eliminate Variables by forcing numerical evaluation - bindings = {X: np.random.random_sample(X.shape) - for X in gem.extract_type(ps.expression, gem.Variable)} - gem_table = gem.as_gem(fiat_table) - ndim = len(gem_table.free_indices) - val, = gem.interpreter.evaluate((gem_table,), bindings=bindings) - fiat_table = val.arr.transpose((*range(ndim, val.arr.ndim), *range(ndim))) - - fiat_table = fiat_table.reshape(*index_shape, *value_shape, -1) - assert np.allclose(fiat_table, fiat_table[..., 0, None]) - fiat_table = fiat_table[..., 0] + replace_indices = tuple((i, 0) for i in ps.expression.free_indices) + else: + fiat_table = fiat_table.reshape(*index_shape, *value_shape, -1) + assert np.allclose(fiat_table, fiat_table[..., 0, None]) + fiat_table = fiat_table[..., 0] elif derivative > self.degree: # Ensure a zero tabulation if fiat_table.dtype != object: @@ -150,10 +145,12 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): point_shape = tuple(i.extent for i in point_indices) fiat_table = fiat_table.reshape(index_shape + value_shape + point_shape) - gem_table = gem.as_gem(fiat_table) expr = gem.Indexed(gem_table, basis_indices + point_indices) expr = gem.ComponentTensor(expr, basis_indices) + if replace_indices: + expr, = gem.optimise.remove_componenttensors((expr,), subst=replace_indices) + result[alpha] = expr return result diff --git a/gem/interpreter.py b/gem/interpreter.py index a8ac48b6e..0bc4081a9 100644 --- a/gem/interpreter.py +++ b/gem/interpreter.py @@ -4,7 +4,6 @@ import numpy import operator from collections import OrderedDict -from numbers import Integral from functools import singledispatch import itertools @@ -287,29 +286,6 @@ def _evaluate_indexed(e, self): return Result(val[idx], val.fids + fids) -@_evaluate.register(gem.FlexiblyIndexed) -def _evaluate_flexiblyindexed(e, self): - val = self(e.children[0]) - - assert all(isinstance(index, gem.Index) and isinstance(stride, Integral) - for (offset, idxs) in e.dim2idxs for (index, stride) in idxs) - - fids = e.index_ordering() - shape = tuple(i.extent for i in fids) - strides = (1, *numpy.cumprod(val.arr.shape, dtype=int)) - dim2idxs = e.dim2idxs - - def index_mapping(idx): - indices = iter(idx) - cur = sum(s * sum((j*stride for (index, stride), j in zip(idxs, indices)), offset) - for (offset, idxs), s in zip(dim2idxs, strides)) - return cur - - vidx = list(map(index_mapping, numpy.ndindex(shape))) - sub = val.arr.flatten()[vidx].reshape(shape) - return Result(sub, val.fids + fids) - - @_evaluate.register(gem.ComponentTensor) def _evaluate_componenttensor(e, self): """Component tensors map free indices to shape.""" @@ -324,12 +300,39 @@ def _evaluate_componenttensor(e, self): # Now the bound free indices for i in e.multiindex: axes.append(val.fids.index(i)) - # Now the existing shape - axes.extend(range(len(val.fshape), len(val.tshape))) return Result(numpy.transpose(val.arr, axes=axes), tuple(fids)) +@_evaluate.register(gem.FlexiblyIndexed) +def _evaluate_flexiblyindexed(e, self): + """FlexiblyIndexed first slices and then reshapes.""" + val = self(e.children[0]) + assert len(val.fids) == 0 + + idx = [] + axes = [] + for offset, idxs in e.dim2idxs: + if isinstance(offset, gem.Node): + offset = self(offset) + if len(idxs) == 0: + idx.append(offset) + continue + + indices, strides = zip(*idxs) + strides = tuple(self(stride) if isinstance(stride, gem.Node) else stride for stride in strides) + assert all(isinstance(i, gem.Index) for i in indices) + last = sum(((i.extent-1) * stride for i, stride in zip(indices, strides)), offset) + idx.append(slice(offset, last + 1)) + ndim = len(axes) + axes.extend(sorted(range(ndim, ndim + len(strides)), key=lambda i: strides[i], reverse=True)) + + fids = e.index_ordering() + shape = tuple(i.extent for i in fids) + arr = val[idx].reshape(numpy.asarray(shape)[axes]).transpose(numpy.argsort(axes)) + return Result(arr, fids) + + @_evaluate.register(gem.IndexSum) def _evaluate_indexsum(e, self): """Index sums reduce over the given axis.""" diff --git a/gem/optimise.py b/gem/optimise.py index 289ccff7d..f75a92d27 100644 --- a/gem/optimise.py +++ b/gem/optimise.py @@ -161,10 +161,10 @@ def filtered_replace_indices(node, self, subst): return replace_indices(node, self, filtered_subst) -def remove_componenttensors(expressions): +def remove_componenttensors(expressions, subst=()): """Removes all ComponentTensors in multi-root expression DAG.""" mapper = MemoizerArg(filtered_replace_indices) - return [mapper(expression, ()) for expression in expressions] + return [mapper(expression, subst) for expression in expressions] @singledispatch From 6cfe14631561e091bfee1ae0f82ce5aa2e7c6195 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 28 Aug 2025 17:30:18 +0100 Subject: [PATCH 16/27] docstring --- gem/interpreter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gem/interpreter.py b/gem/interpreter.py index 0bc4081a9..b0bc8cbbe 100644 --- a/gem/interpreter.py +++ b/gem/interpreter.py @@ -306,7 +306,7 @@ def _evaluate_componenttensor(e, self): @_evaluate.register(gem.FlexiblyIndexed) def _evaluate_flexiblyindexed(e, self): - """FlexiblyIndexed first slices and then reshapes.""" + """Flexibly indexed first slices and then reshapes.""" val = self(e.children[0]) assert len(val.fids) == 0 From 259787d4eef86f43e0ace0c5eeac8a77337a621e Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 28 Aug 2025 18:54:27 +0100 Subject: [PATCH 17/27] Restore gem/interpreter.py --- gem/interpreter.py | 31 ++----------------------------- 1 file changed, 2 insertions(+), 29 deletions(-) diff --git a/gem/interpreter.py b/gem/interpreter.py index b0bc8cbbe..13eeb44a2 100644 --- a/gem/interpreter.py +++ b/gem/interpreter.py @@ -300,39 +300,12 @@ def _evaluate_componenttensor(e, self): # Now the bound free indices for i in e.multiindex: axes.append(val.fids.index(i)) + # Now the existing shape + axes.extend(range(len(val.fshape), len(val.tshape))) return Result(numpy.transpose(val.arr, axes=axes), tuple(fids)) -@_evaluate.register(gem.FlexiblyIndexed) -def _evaluate_flexiblyindexed(e, self): - """Flexibly indexed first slices and then reshapes.""" - val = self(e.children[0]) - assert len(val.fids) == 0 - - idx = [] - axes = [] - for offset, idxs in e.dim2idxs: - if isinstance(offset, gem.Node): - offset = self(offset) - if len(idxs) == 0: - idx.append(offset) - continue - - indices, strides = zip(*idxs) - strides = tuple(self(stride) if isinstance(stride, gem.Node) else stride for stride in strides) - assert all(isinstance(i, gem.Index) for i in indices) - last = sum(((i.extent-1) * stride for i, stride in zip(indices, strides)), offset) - idx.append(slice(offset, last + 1)) - ndim = len(axes) - axes.extend(sorted(range(ndim, ndim + len(strides)), key=lambda i: strides[i], reverse=True)) - - fids = e.index_ordering() - shape = tuple(i.extent for i in fids) - arr = val[idx].reshape(numpy.asarray(shape)[axes]).transpose(numpy.argsort(axes)) - return Result(arr, fids) - - @_evaluate.register(gem.IndexSum) def _evaluate_indexsum(e, self): """Index sums reduce over the given axis.""" From 1044442ef970b9f7308d26371ad669ae83807e0e Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 28 Aug 2025 21:55:37 +0100 Subject: [PATCH 18/27] Implement Node.__pow__ --- gem/gem.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/gem/gem.py b/gem/gem.py index b0894c5e6..62108f80f 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -128,17 +128,20 @@ def __rmatmul__(self, other): def __abs__(self): return componentwise(lambda x: MathFunction("abs", x), self) + def __pow__(self, other): + return componentwise(lambda x, y: Power(x, y), self, as_gem(other)) + def __lt__(self, other): - return componentwise(lambda x, y: Comparison("<", x, y), self, other) + return componentwise(lambda x, y: Comparison("<", x, y), self, as_gem(other)) def __gt__(self, other): - return componentwise(lambda x, y: Comparison(">", x, y), self, other) + return componentwise(lambda x, y: Comparison(">", x, y), self, as_gem(other)) def __le__(self, other): - return componentwise(lambda x, y: Comparison("<=", x, y), self, other) + return componentwise(lambda x, y: Comparison("<=", x, y), self, as_gem(other)) def __ge__(self, other): - return componentwise(lambda x, y: Comparison(">=", x, y), self, other) + return componentwise(lambda x, y: Comparison(">=", x, y), self, as_gem(other)) @property def T(self): From 5b914960a6442704104386922f90fa8e5e391afb Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 28 Aug 2025 22:04:39 +0100 Subject: [PATCH 19/27] style --- FIAT/expansions.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/FIAT/expansions.py b/FIAT/expansions.py index d4a4a34c4..dc50055a1 100644 --- a/FIAT/expansions.py +++ b/FIAT/expansions.py @@ -727,8 +727,7 @@ def compute_partition_of_unity(ref_el, pt, unique=True, tol=1E-12): sd = ref_el.get_spatial_dimension() top = ref_el.get_topology() # assert singleton point - pt = numpy.reshape(pt, (sd,)) - + pt = pt.reshape((sd,)) if isinstance(pt[0], gem.Node): import gem as backend else: From 394e8e0289829063883234a96e722a0b6f6089fa Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 28 Aug 2025 23:09:11 +0100 Subject: [PATCH 20/27] Fix cellwise constant for non-simplices --- finat/fiat_elements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/finat/fiat_elements.py b/finat/fiat_elements.py index b32e81add..bcd1b0798 100644 --- a/finat/fiat_elements.py +++ b/finat/fiat_elements.py @@ -127,7 +127,7 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): point_indices = () replace_indices = () derivative = sum(alpha) - if derivative == self.degree and not self.complex.is_macrocell(): + if derivative == self.degree and self.complex.is_simplex(): # Ensure a cellwise constant tabulation if fiat_table.dtype == object: replace_indices = tuple((i, 0) for i in ps.expression.free_indices) From fd0ae42b13e98f6e701609f66a1560979ad91b5e Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 29 Aug 2025 10:56:24 +0100 Subject: [PATCH 21/27] Implement Literal.__bool__ --- gem/gem.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/gem/gem.py b/gem/gem.py index 62108f80f..a80b9603d 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -325,6 +325,9 @@ def value(self): def shape(self): return self.array.shape + def __bool__(self): + return bool(self.value) + class Variable(Terminal): """Symbolic variable tensor""" @@ -1346,7 +1349,7 @@ def Piecewise(*args): expr = None pieces = [] for v, c in args: - if isinstance(c, (bool, numpy.bool)) and c: + if isinstance(c, (bool, numpy.bool, Literal)) and c: expr = as_gem(v) break pieces.append((as_gem(v), as_gem(c))) From e7ccbb0b22d336060534cc1a4d08be1c4226b1dc Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Sat, 18 Oct 2025 10:24:51 +0100 Subject: [PATCH 22/27] Update FIAT/barycentric_interpolation.py --- FIAT/barycentric_interpolation.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/FIAT/barycentric_interpolation.py b/FIAT/barycentric_interpolation.py index c2807a74f..879f6ffab 100644 --- a/FIAT/barycentric_interpolation.py +++ b/FIAT/barycentric_interpolation.py @@ -27,9 +27,7 @@ def barycentric_interpolation(nodes, wts, dmat, pts, order=0): if pts.dtype == object: # Do not use barycentric interpolation at unknown points phi = numpy.add.outer(-nodes, pts.flatten()) - n = len(nodes) - phis = [wts[i] * numpy.prod(phi[[*range(0, i), *range(i+1, n)]], axis=0) - for i in range(n)] + phis = [wi * numpy.prod(phi[:i], axis=0) * numpy.prod(phi[i+1:], axis=0) for i, wi in enumerate(wts)] phi = numpy.asarray(phis) else: # Use the second barycentric interpolation formula From b290e56682a81cefcba2c62123d9f6859ae9ed34 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Sat, 18 Oct 2025 10:25:39 +0100 Subject: [PATCH 23/27] Update FIAT/expansions.py --- FIAT/expansions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/FIAT/expansions.py b/FIAT/expansions.py index dc50055a1..09582d6ea 100644 --- a/FIAT/expansions.py +++ b/FIAT/expansions.py @@ -53,7 +53,7 @@ def pad_jacobian(A, embedded_dim): def jacobi_factors(x, y, z, dx, dy, dz): fb = 0.5 * (y + z) fa = x + (fb + 1.0) - fc = fb * fb + fc = fb ** 2 dfa = dfb = dfc = None if dx is not None: dfb = 0.5 * (dy + dz) From 8868ea1fce3c01e6cdac3fe7d806b8655d3246b7 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Sat, 18 Oct 2025 10:28:37 +0100 Subject: [PATCH 24/27] import partial --- gem/gem.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gem/gem.py b/gem/gem.py index a80b9603d..c1ad6ce37 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -16,7 +16,7 @@ from abc import ABCMeta from itertools import chain, repeat -from functools import reduce +from functools import partial, reduce from operator import attrgetter from numbers import Integral, Number From a76df133fa19458b9a5c3a9cb58dad24c8e5a0ce Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Sat, 18 Oct 2025 10:29:01 +0100 Subject: [PATCH 25/27] Update gem/gem.py --- gem/gem.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/gem/gem.py b/gem/gem.py index c1ad6ce37..77f6e1dca 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -126,22 +126,22 @@ def __rmatmul__(self, other): return as_gem(other).__matmul__(self) def __abs__(self): - return componentwise(lambda x: MathFunction("abs", x), self) + return componentwise(partial(MathFunction, "abs"), self) def __pow__(self, other): - return componentwise(lambda x, y: Power(x, y), self, as_gem(other)) + return componentwise(Power, self, as_gem(other)) def __lt__(self, other): - return componentwise(lambda x, y: Comparison("<", x, y), self, as_gem(other)) + return componentwise(partial(Comparison, "<"), self, as_gem(other)) def __gt__(self, other): - return componentwise(lambda x, y: Comparison(">", x, y), self, as_gem(other)) + return componentwise(partial(Comparison, ">"), self, as_gem(other)) def __le__(self, other): - return componentwise(lambda x, y: Comparison("<=", x, y), self, as_gem(other)) + return componentwise(partial(Comparison, "<="), self, as_gem(other)) def __ge__(self, other): - return componentwise(lambda x, y: Comparison(">=", x, y), self, as_gem(other)) + return componentwise(partial(Comparison, ">="), self, as_gem(other)) @property def T(self): From 2b5eb2a921be75eb50fb9e0a0a09748f759af55a Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 23 Oct 2025 12:00:19 +0100 Subject: [PATCH 26/27] simplify point_evaluation --- finat/fiat_elements.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/finat/fiat_elements.py b/finat/fiat_elements.py index cd6952711..cb15ef00a 100644 --- a/finat/fiat_elements.py +++ b/finat/fiat_elements.py @@ -119,7 +119,6 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): expr = gem.ComponentTensor(expr, basis_indices) if replace_indices: expr, = gem.optimise.remove_componenttensors((expr,), subst=replace_indices) - result[alpha] = expr return result @@ -148,8 +147,15 @@ def point_evaluation(self, order, refcoords, entity=None, coordinate_mapping=Non # Coordinates on the reference entity (GEM) Xi = tuple(gem.Indexed(refcoords, i) for i in np.ndindex(refcoords.shape)) ps = PointSingleton(Xi) - return self.basis_evaluation(order, ps, entity=entity, - coordinate_mapping=coordinate_mapping) + result = self.basis_evaluation(order, ps, entity=entity, coordinate_mapping=coordinate_mapping) + + # Apply symbolic simplification + vals = result.values() + vals = gem.optimise.constant_fold_zero(vals) + vals = map(gem.optimise.aggressive_unroll, vals) + vals = gem.optimise.remove_componenttensors(vals) + result = dict(zip(result.keys(), vals)) + return result @cached_property def _dual_basis(self): From a935cf23088ef230c7dfb4be6e5e1a28edbf733a Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 23 Oct 2025 14:52:21 +0100 Subject: [PATCH 27/27] rounding --- FIAT/expansions.py | 2 +- finat/fiat_elements.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/FIAT/expansions.py b/FIAT/expansions.py index 09582d6ea..dc50055a1 100644 --- a/FIAT/expansions.py +++ b/FIAT/expansions.py @@ -53,7 +53,7 @@ def pad_jacobian(A, embedded_dim): def jacobi_factors(x, y, z, dx, dy, dz): fb = 0.5 * (y + z) fa = x + (fb + 1.0) - fc = fb ** 2 + fc = fb * fb dfa = dfb = dfc = None if dx is not None: dfb = 0.5 * (dy + dz) diff --git a/finat/fiat_elements.py b/finat/fiat_elements.py index cb15ef00a..e78ff91ad 100644 --- a/finat/fiat_elements.py +++ b/finat/fiat_elements.py @@ -147,10 +147,12 @@ def point_evaluation(self, order, refcoords, entity=None, coordinate_mapping=Non # Coordinates on the reference entity (GEM) Xi = tuple(gem.Indexed(refcoords, i) for i in np.ndindex(refcoords.shape)) ps = PointSingleton(Xi) - result = self.basis_evaluation(order, ps, entity=entity, coordinate_mapping=coordinate_mapping) + result = self.basis_evaluation(order, ps, entity=entity, + coordinate_mapping=coordinate_mapping) # Apply symbolic simplification vals = result.values() + vals = map(gem.optimise.ffc_rounding, vals, [1E-15]*len(vals)) vals = gem.optimise.constant_fold_zero(vals) vals = map(gem.optimise.aggressive_unroll, vals) vals = gem.optimise.remove_componenttensors(vals)