Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
265 changes: 254 additions & 11 deletions src/scifem/xdmf.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from collections import defaultdict
import logging
import typing
import xml.etree.ElementTree as ET
Expand All @@ -14,6 +15,8 @@
import numpy as np
import numpy.typing as npt
import dolfinx
import basix
import ufl

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -299,9 +302,10 @@ def create_function_space_data(V: dolfinx.fem.FunctionSpace) -> FunctionSpaceDat
bs = V.dofmap.index_map_bs
index_map = V.dofmap.index_map
comm = V.mesh.comm
gdim = V.mesh.geometry.dim

return FunctionSpaceData(
points=points,
points=points[:, :gdim],
bs=bs,
num_dofs_global=index_map.size_global,
num_dofs_local=index_map.size_local,
Expand Down Expand Up @@ -374,11 +378,122 @@ def create_pointcloud(
write_hdf5_h5py(functions=functions, h5name=h5name, data=data)


def read_time_values_from_xdmf(xdmfname: os.PathLike) -> dict[float, int]:
"""Read time values from an XDMF file.

Args:
xdmfname: The name of the XDMF file.

Returns:
A dictionary with time values as keys and step numbers as values.

"""
tree = ET.parse(xdmfname)
current_name = ""
all_times_values: dict[str, dict[float, int]] = defaultdict(dict)
index = 0
for elem in tree.iter():
if elem.tag == "Attribute":
new_current_name = elem.get("Name", "")
if new_current_name != current_name:
current_name = new_current_name
index = 0
if elem.tag == "Time":
time_value = float(elem.get("Value", 0.0))
if current_name == "":
raise ValueError(f"No name found for time value {time_value}")

all_times_values[current_name][time_value] = index
index += 1

# Check if all time values are the same
first_time_values = list(all_times_values.values())[0]
for name, time_values in all_times_values.items():
if time_values != first_time_values:
raise ValueError(f"Time values for {name} are not the same")

return first_time_values


# Taken from adios4dolfinx
adios_to_numpy_dtype = {
"float": np.float32,
"double": np.float64,
"float complex": np.complex64,
"double complex": np.complex128,
"uint32_t": np.uint32,
}


def create_point_mesh(
comm: MPI.Intracomm,
points: npt.NDArray[np.floating],
cells: npt.NDArray[np.floating] | None = None,
) -> dolfinx.mesh.Mesh:
"""
Create a mesh consisting of points only.

Note:
No nodes are shared between processes.

Args:
comm: MPI communicator to create the mesh on.
points: Points local to the process in the mesh.
"""
# Create mesh topology
if cells is None:
cells = np.arange(points.shape[0], dtype=np.int32).reshape(-1, 1)
topology = dolfinx.cpp.mesh.Topology(MPI.COMM_WORLD, dolfinx.mesh.CellType.point)
num_nodes_local = cells.shape[0]
imap = dolfinx.common.IndexMap(MPI.COMM_WORLD, num_nodes_local)
local_range = imap.local_range[0]
igi = np.arange(num_nodes_local, dtype=np.int64) + local_range
topology.set_index_map(0, imap)
topology.set_connectivity(dolfinx.graph.adjacencylist(cells.astype(np.int32)), 0, 0)

# Create mesh geometry
e = basix.ufl.element("Lagrange", "point", 0, shape=(points.shape[1],))
c_el = dolfinx.fem.coordinate_element(e.basix_element)
geometry = dolfinx.mesh.create_geometry(imap, cells, c_el._cpp_object, points, igi)

# Create DOLFINx mesh
if points.dtype == np.float64:
cpp_mesh = dolfinx.cpp.mesh.Mesh_float64(comm, topology, geometry._cpp_object)
elif points.dtype == np.float32:
cpp_mesh = dolfinx.cpp.mesh.Mesh_float32(comm, topology, geometry._cpp_object)
else:
raise RuntimeError(f"Unsupported dtype for mesh {points.dtype}")
# Wrap as Python object
return dolfinx.mesh.Mesh(cpp_mesh, domain=ufl.Mesh(e))


def compute_local_range(comm: MPI.Intracomm, N: np.int64):
"""
Divide a set of `N` objects into `M` partitions, where `M` is
the size of the MPI communicator `comm`.

NOTE: If N is not divisible by the number of ranks, the first `r`
processes gets an extra value

Returns the local range of values
"""
rank = comm.rank
size = comm.size
n = N // size
r = N % size
# First r processes has one extra value
if rank < r:
return [rank * (n + 1), (rank + 1) * (n + 1)]
else:
return [rank * n + r, (rank + 1) * n + r]


class BaseXDMFFile(abc.ABC):
filename: os.PathLike
filemode: typing.Literal["r", "a", "w"]
backend: typing.Literal["h5py", "adios2"]
_data: FunctionSpaceData
_time_values: dict[float, int]

@property
@abc.abstractmethod
Expand All @@ -400,12 +515,15 @@ def __post_init__(self) -> None:
raise FileNotFoundError(f"{self.h5name} does not exist")
if not self.xdmfname.exists():
raise FileNotFoundError(f"{self.xdmfname} does not exist")

# Read time values from XDMF file

elif self.filemode == "w":
# Overwrite existing files so make sure they don't exist
self.h5name.unlink(missing_ok=True)
self.xdmfname.unlink(missing_ok=True)

self._time_values: dict[float, int] = {}
else:
raise NotImplementedError(f"Filemode {self.filemode} not supported")

def _init_backend(self) -> None:
assert self.backend in [
Expand All @@ -423,11 +541,26 @@ def _init_backend(self) -> None:
if self.backend == "h5py":
self._init_h5py()

self._init_time_values()

def _init_time_values(self) -> None:
if self.filemode == "r":
self._time_values = read_time_values_from_xdmf(self.xdmfname)
elif self.filemode == "w":
self._time_values = {}
else:
raise NotImplementedError(f"Filemode {self.filemode} not supported")

def _init_h5py(self) -> None:
logger.debug("Initializing h5py")
self._outfile = h5pyfile(
h5name=self.h5name, filemode=self.filemode, comm=self._data.comm
).__enter__()
if self.filemode == "r":
assert "Step0" in self._outfile, "Step0 not found in HDF5 file"
self._step = self._outfile["Step0"]
return None

self._step = self._outfile.create_group(np.bytes_("Step0"))
points = self._step.create_dataset(
"Points",
Expand Down Expand Up @@ -458,24 +591,64 @@ def _write_h5py(self, index: int) -> None:
)
dset[self._data.local_range[0] : self._data.local_range[1], :] = array

def _read_h5py(self, index: int) -> None:
logger.debug(f"Writing h5py at time {index}")

cells = self._step["Cells"]
points = self._step["Points"]
assert cells.shape[0] == points.shape[0]
local_range = compute_local_range(self._data.comm, cells.shape[0])
cells_local = cells[local_range[0] : local_range[1]]
points_local = points[local_range[0] : local_range[1], :]
point_mesh = create_point_mesh(
comm=self._data.comm,
points=points_local,
cells=cells_local.reshape(-1, 1),
)

if self._data.bs == 1:
shape: tuple[int, ...] = ()
else:
shape = (self._data.bs,)

V = dolfinx.fem.functionspace(point_mesh, ("DG", 0, shape))
self.vs = []
for data_name in self.data_names:
v = dolfinx.fem.Function(V, name=data_name)
# Pad array to 3D if vector space with 2 components
key = f"Values_{data_name}_{index}"
if key not in self._step:
raise ValueError(f"Variable {data_name} not found in HDF5 file")

v.x.array[:] = self._step[f"Values_{data_name}_{index}"][
local_range[0] : local_range[1], : self._data.bs
].flatten()
self.vs.append(v)

def _close_h5py(self) -> None:
logger.debug("Closing HDF5 file")
self._outfile.close()

def _open_adios(self):
self._adios = self.adios2.ADIOS(self._data.comm)
self._io = self._adios.DeclareIO("Point cloud writer")
self._io.SetEngine("HDF5")
mode = self.adios2.Mode.Write if self.filemode == "w" else self.adios2.Mode.Read
self._outfile = self._io.Open(self.h5name.as_posix(), mode)

def _init_adios(self) -> None:
logger.debug("Initializing ADIOS2")
import adios2

def resolve_adios_scope(adios2):
return adios2 if not hasattr(adios2, "bindings") else adios2.bindings

adios2 = resolve_adios_scope(adios2)
self.adios2 = resolve_adios_scope(adios2)

# Create ADIOS2 reader
self._adios = adios2.ADIOS(self._data.comm)
self._io = self._adios.DeclareIO("Point cloud writer")
self._io.SetEngine("HDF5")
self._outfile = self._io.Open(self.h5name.as_posix(), adios2.Mode.Write)
if self.filemode == "r":
return None
self._open_adios()
pointvar = self._io.DefineVariable(
"Points",
self._data.points_out,
Expand Down Expand Up @@ -515,6 +688,38 @@ def _write_adios(self, index: int) -> None:
self._outfile.Put(valuevar, array)
self._outfile.PerformPuts()

def _read_adios(self, index: int) -> None:
logger.debug(f"Reading adios at time {index}")
self._open_adios()
hit = False
for data_name, data_array in zip(self.data_names, self.data_arrays):
variable_name = f"Values_{data_name}_{index}"
for i in range(self._outfile.Steps()):
self._outfile.BeginStep()
if variable_name in self._io.AvailableVariables().keys():
arr = self._io.InquireVariable(variable_name)
arr_shape = arr.Shape()
vals = np.empty(arr_shape, dtype=adios_to_numpy_dtype[arr.Type()])

self._outfile.Get(arr, vals, self.adios2.Mode.Sync)
start = self._data.local_range[0]
end = self._data.local_range[0] + self._data.num_dofs_local
data_array[: self._data.num_dofs_local * self._data.bs] = vals[
start:end, : self._data.bs
].flatten()
hit = True
self._outfile.EndStep()
break
else:
self._outfile.EndStep()
else:
self._outfile.EndStep()
break

self._close_adios()
if not hit:
raise ValueError(f"Variable {variable_name} not found in ADIOS2 file")

def _close_adios(self) -> None:
logger.debug("Closing ADIOS2 file")
try:
Expand Down Expand Up @@ -550,7 +755,12 @@ def _write_xdmf(self) -> None:
top_data.attrib["Format"] = "HDF"
top_data.text = f"{self.h5name.name}:/Step0/Cells"
geometry = ET.SubElement(grid, "Geometry")
geometry.attrib["GeometryType"] = "XYZ"
if self._data.points.shape[1] == 2:
geometry.attrib["GeometryType"] = "XY"
elif self._data.points.shape[1] == 3:
geometry.attrib["GeometryType"] = "XYZ"
else:
raise ValueError(f"Unsupported geometry type {self._data.points.shape[1]}")
it0 = ET.SubElement(geometry, "DataItem")
it0.attrib["Dimensions"] = f"{self._data.num_dofs_global} {self._data.points.shape[1]}"
it0.attrib["Format"] = "HDF"
Expand All @@ -569,10 +779,10 @@ def _write_xdmf(self) -> None:
xp.attrib["xpointer"] = (
"xpointer(/Xdmf/Domain/Grid[@GridType='Uniform'][1]/*[self::Topology or self::Geometry])" # noqa: E501
)
time = ET.SubElement(ugrid, "Time")
time.attrib["Value"] = str(time_value)
attrib = ET.SubElement(ugrid, "Attribute")
attrib.attrib["Name"] = name
time = ET.SubElement(ugrid, "Time")
time.attrib["Value"] = str(time_value)
out_bs = self._data.bs
if out_bs == 1:
attrib.attrib["AttributeType"] = "Scalar"
Expand Down Expand Up @@ -626,6 +836,27 @@ def write(self, time: float) -> None:
elif self.backend == "h5py":
self._write_h5py(index)

def read(self, time: float) -> None:
"""Read the point cloud at a given time.

Args:
time: The time value.

"""
logger.debug(f"Writing time {time}")
time = float(time)
if time not in self._time_values:
msg = f"Time {time} not found in file."
logger.warning(msg)
return

index = self._time_values[time]

if self.backend == "adios2":
self._read_adios(index)
elif self.backend == "h5py":
self._read_h5py(index)


class XDMFFile(BaseXDMFFile):
def __init__(
Expand Down Expand Up @@ -661,6 +892,18 @@ def data_names(self) -> list[str]:
def data_arrays(self) -> list[npt.NDArray[np.floating]]:
return [f.x.array for f in self.functions]

def _read_h5py(self, index: int) -> None:
super()._read_h5py(index)
for v, f in zip(self.vs, self.functions):
cell_map = f.function_space.mesh.topology.index_map(f.function_space.mesh.topology.dim)
num_cells = cell_map.size_local + cell_map.num_ghosts
cells = np.arange(num_cells, dtype=np.int32)
data = dolfinx.fem.create_interpolation_data(
f.function_space, v.function_space, cells, padding=1e-10
)

f.interpolate_nonmatching(v, cells, data)


class NumpyXDMFFile(BaseXDMFFile):
def __init__(
Expand Down
Loading
Loading