Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
57e610c
GEM tabulations
pbrubeck Aug 14, 2025
e078883
Remove sympy from point_evaluation
pbrubeck Aug 20, 2025
2f7fb1c
Reuse code for tabulation at known and unknown points
pbrubeck Aug 20, 2025
2c54e72
GEM barycentric interpolation
pbrubeck Aug 21, 2025
26113d8
Fix cellwise constant case
pbrubeck Aug 21, 2025
7243809
add tests
pbrubeck Aug 21, 2025
fdefbbd
Fix transpose
pbrubeck Aug 21, 2025
9770f83
Merge branch 'main' into pbrubeck/gem-tabulation
pbrubeck Aug 21, 2025
6868df5
Fixup
pbrubeck Aug 22, 2025
4d528b0
Node.__neg__
pbrubeck Aug 26, 2025
572af82
Do not use barycentric interpolation at unknown points
pbrubeck Aug 26, 2025
d82ba64
comments
pbrubeck Aug 27, 2025
d743cb8
Grab Variables from ps.expression
pbrubeck Aug 27, 2025
accabdb
Fix up
pbrubeck Aug 27, 2025
3225043
Evaluate FlexiblyIndexed
pbrubeck Aug 27, 2025
8ab7002
Merge branch 'main' into pbrubeck/gem-tabulation
pbrubeck Aug 28, 2025
4d00815
replace indices symbolically
pbrubeck Aug 28, 2025
6cfe146
docstring
pbrubeck Aug 28, 2025
259787d
Restore gem/interpreter.py
pbrubeck Aug 28, 2025
1044442
Implement Node.__pow__
pbrubeck Aug 28, 2025
5b91496
style
pbrubeck Aug 28, 2025
394e8e0
Fix cellwise constant for non-simplices
pbrubeck Aug 28, 2025
fd0ae42
Implement Literal.__bool__
pbrubeck Aug 29, 2025
08d410f
Merge branch 'main' into pbrubeck/gem-tabulation
pbrubeck Oct 15, 2025
36575b5
Merge branch 'main' into pbrubeck/gem-tabulation
pbrubeck Oct 16, 2025
e7ccbb0
Update FIAT/barycentric_interpolation.py
pbrubeck Oct 18, 2025
b290e56
Update FIAT/expansions.py
pbrubeck Oct 18, 2025
8868ea1
import partial
pbrubeck Oct 18, 2025
a76df13
Update gem/gem.py
pbrubeck Oct 18, 2025
2b5eb2a
simplify point_evaluation
pbrubeck Oct 23, 2025
a935cf2
rounding
pbrubeck Oct 23, 2025
26d1308
Merge branch 'main' into pbrubeck/gem-tabulation
pbrubeck Oct 23, 2025
72eb7a1
Merge branch 'main' into pbrubeck/gem-tabulation
pbrubeck Dec 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 13 additions & 12 deletions FIAT/barycentric_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,23 @@ def barycentric_interpolation(nodes, wts, dmat, pts, order=0):
https://doi.org/10.1137/S0036144502417715 Eq. (4.2) & (9.4)
"""
if pts.dtype == object:
from sympy import simplify
sp_simplify = numpy.vectorize(simplify)
# Do not use barycentric interpolation at unknown points
phi = numpy.add.outer(-nodes, pts.flatten())
phis = [wi * numpy.prod(phi[:i], axis=0) * numpy.prod(phi[i+1:], axis=0) for i, wi in enumerate(wts)]
phi = numpy.asarray(phis)
else:
sp_simplify = lambda x: x
phi = numpy.add.outer(-nodes, pts.flatten())
with numpy.errstate(divide='ignore', invalid='ignore'):
numpy.reciprocal(phi, out=phi)
numpy.multiply(phi, wts[:, None], out=phi)
numpy.multiply(1.0 / numpy.sum(phi, axis=0), phi, out=phi)
phi[phi != phi] = 1.0
phi = phi.reshape(-1, *pts.shape[:-1])
# Use the second barycentric interpolation formula
phi = numpy.add.outer(-nodes, pts.flatten())
with numpy.errstate(divide='ignore', invalid='ignore'):
numpy.divide(wts[:, None], phi, out=phi)
numpy.divide(phi, numpy.sum(phi, axis=0, keepdims=True), out=phi)
# Replace nan with one
phi[phi != phi] = 1.0

phi = sp_simplify(phi)
phi = phi.reshape(-1, *pts.shape[:-1])
results = {(0,): phi}
for r in range(1, order+1):
phi = sp_simplify(numpy.dot(dmat, phi))
phi = numpy.dot(dmat, phi)
results[(r,)] = phi
return results

Expand Down
10 changes: 7 additions & 3 deletions FIAT/expansions.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def pad_jacobian(A, embedded_dim):
def jacobi_factors(x, y, z, dx, dy, dz):
fb = 0.5 * (y + z)
fa = x + (fb + 1.0)
fc = fb ** 2
fc = fb * fb
dfa = dfb = dfc = None
if dx is not None:
dfb = 0.5 * (dy + dz)
Expand Down Expand Up @@ -723,11 +723,15 @@ def compute_partition_of_unity(ref_el, pt, unique=True, tol=1E-12):
:kwarg tol: the absolute tolerance.
:returns: a list of (weighted) characteristic functions for each subcell.
"""
from sympy import Piecewise
import gem
sd = ref_el.get_spatial_dimension()
top = ref_el.get_topology()
# assert singleton point
pt = pt.reshape((sd,))
if isinstance(pt[0], gem.Node):
import gem as backend
else:
import sympy as backend

# The distance to the nearest cell is equal to the distance to the parent cell
best = ref_el.get_parent().distance_to_point_l1(pt, rescale=True)
Expand All @@ -739,7 +743,7 @@ def compute_partition_of_unity(ref_el, pt, unique=True, tol=1E-12):
for cell in sorted(top[sd]):
# Bin points based on l1 distance
pt_near_cell = ref_el.distance_to_point_l1(pt, entity=(sd, cell), rescale=True) < tol
masks.append(Piecewise(*otherwise, (1.0, pt_near_cell), (0.0, True)))
masks.append(backend.Piecewise(*otherwise, (1.0, pt_near_cell), (0.0, True)))
if unique:
otherwise.append((0.0, pt_near_cell))
# If the point is on a facet, divide the characteristic function by the facet multiplicity
Expand Down
6 changes: 3 additions & 3 deletions finat/argyris.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def _edge_transform(V, vorder, eorder, fiat_cell, coordinate_mapping, avg=False)
V[s, v1id] = P1 * Bnt
V[s, v0id] = P0 * Bnt
if k > 0:
V[s, s + eorder] = -1 * Bnt
V[s, s + eorder] = -Bnt


class Argyris(PhysicallyMappedElement, ScalarFiatElement):
Expand Down Expand Up @@ -139,7 +139,7 @@ def basis_transformation(self, coordinate_mapping):

# vertex points
V[s, v1id] = 15/8 * Bnt
V[s, v0id] = -1 * V[s, v1id]
V[s, v0id] = -V[s, v1id]

# vertex derivatives
for i in range(sd):
Expand All @@ -150,7 +150,7 @@ def basis_transformation(self, coordinate_mapping):
tau = [Jt[0]*Jt[0], 2*Jt[0]*Jt[1], Jt[1]*Jt[1]]
for i in range(len(tau)):
V[s, v1id+3+i] = 1/32 * Bnt * tau[i]
V[s, v0id+3+i] = -1 * V[s, v1id+3+i]
V[s, v0id+3+i] = -V[s, v1id+3+i]

# Patch up conditioning
h = coordinate_mapping.cell_size()
Expand Down
4 changes: 2 additions & 2 deletions finat/bell.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def basis_transformation(self, coordinate_mapping):

# vertex points
V[s, v1id] = 1/21 * Bnt
V[s, v0id] = -1 * V[s, v1id]
V[s, v0id] = -V[s, v1id]

# vertex derivatives
for i in range(sd):
Expand All @@ -55,7 +55,7 @@ def basis_transformation(self, coordinate_mapping):
tau = [Jt[0]*Jt[0], 2*Jt[0]*Jt[1], Jt[1]*Jt[1]]
for i in range(len(tau)):
V[s, v1id+3+i] = 1/252 * Bnt * tau[i]
V[s, v0id+3+i] = -1 * V[s, v1id+3+i]
V[s, v0id+3+i] = -V[s, v1id+3+i]

# Patch up conditioning
h = coordinate_mapping.cell_size()
Expand Down
145 changes: 56 additions & 89 deletions finat/fiat_elements.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import FIAT
import gem
import numpy as np
import sympy as sp
from gem.utils import cached_property

from finat.finiteelementbase import FiniteElementBase
from finat.point_set import PointSet
from finat.sympy2gem import sympy2gem
from finat.point_set import PointSet, PointSingleton


class FiatElement(FiniteElementBase):
Expand Down Expand Up @@ -67,62 +65,61 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None):
:param ps: the point set.
:param entity: the cell entity on which to tabulate.
'''
space_dimension = self._element.space_dimension()
value_size = np.prod(self._element.value_shape(), dtype=int)
fiat_result = self._element.tabulate(order, ps.points, entity)
result = {}
fiat_element = self._element
fiat_result = fiat_element.tabulate(order, ps.points, entity)
# In almost all cases, we have
# self.space_dimension() == self._element.space_dimension()
# But for Bell, FIAT reports 21 basis functions,
# but FInAT only 18 (because there are actually 18
# basis functions, and the additional 3 are for
# dealing with transformations between physical
# and reference space).
index_shape = (self._element.space_dimension(),)
value_shape = self.value_shape
space_dimension = fiat_element.space_dimension()
if self.space_dimension() == space_dimension:
beta = self.get_indices()
index_shape = tuple(index.extent for index in beta)
else:
index_shape = (space_dimension,)
beta = tuple(gem.Index(extent=i) for i in index_shape)
assert len(beta) == len(self.get_indices())

zeta = self.get_value_indices()
basis_indices = beta + zeta

result = {}
for alpha, fiat_table in fiat_result.items():
if isinstance(fiat_table, Exception):
result[alpha] = gem.Failure(self.index_shape + self.value_shape, fiat_table)
result[alpha] = gem.Failure(index_shape + value_shape, fiat_table)
continue

point_indices = ()
replace_indices = ()
derivative = sum(alpha)
shp = (space_dimension, value_size, *ps.points.shape[:-1])
table_roll = np.moveaxis(fiat_table.reshape(shp), 0, -1)

exprs = []
for table in table_roll:
if derivative == self.degree and not self.complex.is_macrocell():
# Make sure numerics satisfies theory
exprs.append(gem.Literal(table[0]))
elif derivative > self.degree:
# Make sure numerics satisfies theory
assert np.allclose(table, 0.0)
exprs.append(gem.Literal(np.zeros(self.index_shape)))
if derivative == self.degree and self.complex.is_simplex():
# Ensure a cellwise constant tabulation
if fiat_table.dtype == object:
replace_indices = tuple((i, 0) for i in ps.expression.free_indices)
else:
point_indices = ps.indices
point_shape = tuple(index.extent for index in point_indices)

exprs.append(gem.partial_indexed(
gem.Literal(table.reshape(point_shape + index_shape)),
point_indices
))
if self.value_shape:
# As above, this extent may be different from that
# advertised by the finat element.
beta = tuple(gem.Index(extent=i) for i in index_shape)
assert len(beta) == len(self.get_indices())

zeta = self.get_value_indices()
result[alpha] = gem.ComponentTensor(
gem.Indexed(
gem.ListTensor(np.array(
[gem.Indexed(expr, beta) for expr in exprs]
).reshape(self.value_shape)),
zeta),
beta + zeta
)
fiat_table = fiat_table.reshape(*index_shape, *value_shape, -1)
assert np.allclose(fiat_table, fiat_table[..., 0, None])
fiat_table = fiat_table[..., 0]
elif derivative > self.degree:
# Ensure a zero tabulation
if fiat_table.dtype != object:
assert np.allclose(fiat_table, 0.0)
fiat_table = np.zeros(index_shape + value_shape)
else:
expr, = exprs
result[alpha] = expr
point_indices = ps.indices

point_shape = tuple(i.extent for i in point_indices)
fiat_table = fiat_table.reshape(index_shape + value_shape + point_shape)
gem_table = gem.as_gem(fiat_table)
expr = gem.Indexed(gem_table, basis_indices + point_indices)
expr = gem.ComponentTensor(expr, basis_indices)
if replace_indices:
expr, = gem.optimise.remove_componenttensors((expr,), subst=replace_indices)
result[alpha] = expr
return result

def point_evaluation(self, order, refcoords, entity=None, coordinate_mapping=None):
Expand All @@ -147,7 +144,20 @@ def point_evaluation(self, order, refcoords, entity=None, coordinate_mapping=Non
esd = self.cell.construct_subelement(entity_dim).get_spatial_dimension()
assert isinstance(refcoords, gem.Node) and refcoords.shape == (esd,)

return point_evaluation(self._element, order, refcoords, (entity_dim, entity_i))
# Coordinates on the reference entity (GEM)
Xi = tuple(gem.Indexed(refcoords, i) for i in np.ndindex(refcoords.shape))
ps = PointSingleton(Xi)
result = self.basis_evaluation(order, ps, entity=entity,
coordinate_mapping=coordinate_mapping)

# Apply symbolic simplification
vals = result.values()
vals = map(gem.optimise.ffc_rounding, vals, [1E-15]*len(vals))
vals = gem.optimise.constant_fold_zero(vals)
vals = map(gem.optimise.aggressive_unroll, vals)
vals = gem.optimise.remove_componenttensors(vals)
result = dict(zip(result.keys(), vals))
return result

@cached_property
def _dual_basis(self):
Expand Down Expand Up @@ -255,49 +265,6 @@ def mapping(self):
return result


def point_evaluation(fiat_element, order, refcoords, entity):
# Coordinates on the reference entity (SymPy)
esd, = refcoords.shape
Xi = sp.symbols('X Y Z')[:esd]

space_dimension = fiat_element.space_dimension()
value_size = np.prod(fiat_element.value_shape(), dtype=int)
fiat_result = fiat_element.tabulate(order, [Xi], entity)
result = {}
for alpha, fiat_table in fiat_result.items():
if isinstance(fiat_table, Exception):
result[alpha] = gem.Failure((space_dimension,) + fiat_element.value_shape(), fiat_table)
continue

# Convert SymPy expression to GEM
mapper = gem.node.Memoizer(sympy2gem)
mapper.bindings = {s: gem.Indexed(refcoords, (i,))
for i, s in enumerate(Xi)}
gem_table = np.vectorize(mapper)(fiat_table)

table_roll = gem_table.reshape(space_dimension, value_size).transpose()

exprs = []
for table in table_roll:
exprs.append(gem.ListTensor(table.reshape(space_dimension)))
if fiat_element.value_shape():
beta = (gem.Index(extent=space_dimension),)
zeta = tuple(gem.Index(extent=d)
for d in fiat_element.value_shape())
result[alpha] = gem.ComponentTensor(
gem.Indexed(
gem.ListTensor(np.array(
[gem.Indexed(expr, beta) for expr in exprs]
).reshape(fiat_element.value_shape())),
zeta),
beta + zeta
)
else:
expr, = exprs
result[alpha] = expr
return result


class Regge(FiatElement): # symmetric matrix valued
def __init__(self, cell, degree, **kwargs):
super().__init__(FIAT.Regge(cell, degree, **kwargs))
Expand Down
2 changes: 1 addition & 1 deletion finat/hct.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def basis_transformation(self, coordinate_mapping):

# vertex points
V[s, v0id] = 1/5 * Bnt
V[s, v1id] = -1 * V[s, v0id]
V[s, v1id] = -V[s, v0id]

# vertex derivatives
for i in range(sd):
Expand Down
12 changes: 5 additions & 7 deletions finat/physically_mapped.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ class MappedTabulation(Mapping):
on the requested derivatives."""

def __init__(self, M, ref_tabulation):
M = gem.optimise.aggressive_unroll(M)
M, = gem.optimise.constant_fold_zero((M,))

self.M = M
self.ref_tabulation = ref_tabulation
# we expect M to be sparse with O(1) nonzeros per row
Expand All @@ -35,9 +38,8 @@ def matvec(self, table):
exprs = [gem.ComponentTensor(gem.Sum(*(self.M.array[i, j] * phi[j] for j in js)), ii)
for i, js in enumerate(self.csr)]

val = gem.ListTensor(exprs)
# val = self.M @ table
return gem.optimise.aggressive_unroll(val)
# return gem.optimise.aggressive_unroll(self.M @ table)
return gem.ListTensor(exprs)

def __getitem__(self, alpha):
try:
Expand Down Expand Up @@ -78,10 +80,6 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None):
result = super().basis_evaluation(order, ps, entity=entity)
return self.map_tabulation(result, coordinate_mapping)

def point_evaluation(self, order, refcoords, entity=None, coordinate_mapping=None):
result = super().point_evaluation(order, refcoords, entity=entity)
return self.map_tabulation(result, coordinate_mapping)


class DirectlyDefinedElement(NeedsCoordinateMappingElement):
"""Base class for directly defined elements such as direct
Expand Down
6 changes: 3 additions & 3 deletions finat/piola_mapped.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def normal_tangential_edge_transform(fiat_cell, J, detJ, f):
alpha = Jn @ Jt
beta = Jt @ Jt
# Compute the last row of inv([[1, 0], [alpha/detJ, beta/detJ]])
row = (-1 * alpha / beta, detJ / beta)
row = (-alpha / beta, detJ / beta)
return row


Expand All @@ -84,8 +84,8 @@ def normal_tangential_face_transform(fiat_cell, J, detJ, f):
det1 = A[2, 0] * A[0, 1] - A[2, 1] * A[0, 0]
det2 = A[0, 0] * A[1, 1] - A[0, 1] * A[1, 0]
scale = detJ / det0
rows = ((-1 * det1 / det0, -1 * scale * A[2, 1], scale * A[2, 0]),
(-1 * det2 / det0, scale * A[1, 1], -1 * scale * A[1, 0]))
rows = ((-det1 / det0, -scale * A[2, 1], scale * A[2, 0]),
(-det2 / det0, scale * A[1, 1], -scale * A[1, 0]))
return rows


Expand Down
2 changes: 1 addition & 1 deletion finat/point_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def points(self):

@cached_property
def expression(self):
return gem.Literal(self.point)
return gem.as_gem(self.point)


class UnknownPointsArray():
Expand Down
Loading