diff --git a/benchmarks/2d/uniaxial_nodal_forces/mpm-nodal-forces.toml b/benchmarks/2d/uniaxial_nodal_forces/mpm-nodal-forces.toml index 1e7ef1a..ffa4a96 100644 --- a/benchmarks/2d/uniaxial_nodal_forces/mpm-nodal-forces.toml +++ b/benchmarks/2d/uniaxial_nodal_forces/mpm-nodal-forces.toml @@ -33,7 +33,7 @@ type = "generator" nelements = [3, 1] element_length = [0.1, 0.1] particle_element_ids = [0] -element = "Quadrilateral4Node" +element = "Quad4N" entity_sets = "entity_sets.json" [[mesh.constraints]] @@ -46,7 +46,7 @@ id = 0 density = 1000 poisson_ratio = 0 youngs_modulus = 1000000 -type = "LinearElastic" +type = "linear_elastic" [[particles]] file = "particles-2d-nodal-force.json" diff --git a/benchmarks/2d/uniaxial_nodal_forces/test_benchmark.py b/benchmarks/2d/uniaxial_nodal_forces/test_benchmark.py index ae72923..2ece328 100644 --- a/benchmarks/2d/uniaxial_nodal_forces/test_benchmark.py +++ b/benchmarks/2d/uniaxial_nodal_forces/test_benchmark.py @@ -1,9 +1,12 @@ import os from pathlib import Path +import jax + +jax.config.update("jax_platform_name", "cpu") import jax.numpy as jnp -from diffmpm import MPM +from diffmpm.mpm import MPM def test_benchmarks(): @@ -32,3 +35,7 @@ def test_benchmarks(): result = jnp.load("results/uniaxial-nodal-forces/particles_0990.npz") assert jnp.round(result["stress"][0, :, 0].min() - 0.9999990078443788, 5) == 0.0 assert jnp.round(result["stress"][0, :, 0].max() - 0.9999990292713694, 5) == 0.0 + + +if __name__ == "__main__": + test_benchmarks() diff --git a/benchmarks/2d/uniaxial_particle_traction/mpm-particle-traction.toml b/benchmarks/2d/uniaxial_particle_traction/mpm-particle-traction.toml index 480ec4e..4bfec27 100644 --- a/benchmarks/2d/uniaxial_particle_traction/mpm-particle-traction.toml +++ b/benchmarks/2d/uniaxial_particle_traction/mpm-particle-traction.toml @@ -33,7 +33,7 @@ type = "generator" nelements = [3, 1] element_length = [0.1, 0.1] particle_element_ids = [0] -element = "Quadrilateral4Node" +element = "Quad4N" entity_sets = "entity_sets.json" [[mesh.constraints]] @@ -46,7 +46,7 @@ id = 0 density = 1000 poisson_ratio = 0 youngs_modulus = 1000000 -type = "LinearElastic" +type = "linear_elastic" [[particles]] file = "particles-2d-particle-traction.json" diff --git a/benchmarks/2d/uniaxial_particle_traction/test_benchmark.py b/benchmarks/2d/uniaxial_particle_traction/test_benchmark.py index 356d0a3..cec2a34 100644 --- a/benchmarks/2d/uniaxial_particle_traction/test_benchmark.py +++ b/benchmarks/2d/uniaxial_particle_traction/test_benchmark.py @@ -1,9 +1,12 @@ import os from pathlib import Path +import jax + +jax.config.update("jax_platform_name", "cpu") import jax.numpy as jnp -from diffmpm import MPM +from diffmpm.mpm import MPM def test_benchmarks(): @@ -31,3 +34,7 @@ def test_benchmarks(): result = jnp.load("results/uniaxial-particle-traction/particles_0990.npz") assert jnp.round(result["stress"][0, :, 0].min() - 0.750002924022295, 5) == 0.0 assert jnp.round(result["stress"][0, :, 0].max() - 0.9999997782938734, 5) == 0.0 + + +if __name__ == "__main__": + test_benchmarks() diff --git a/benchmarks/2d/uniaxial_stress/mpm-uniaxial-stress.toml b/benchmarks/2d/uniaxial_stress/mpm-uniaxial-stress.toml index 3e074cd..39c141c 100644 --- a/benchmarks/2d/uniaxial_stress/mpm-uniaxial-stress.toml +++ b/benchmarks/2d/uniaxial_stress/mpm-uniaxial-stress.toml @@ -29,7 +29,7 @@ type = "generator" nelements = [1, 1] element_length = [1, 1] particle_element_ids = [0] -element = "Quadrilateral4Node" +element = "Quad4N" entity_sets = "entity_sets.json" 
[[mesh.constraints]] @@ -47,7 +47,7 @@ id = 0 density = 1 poisson_ratio = 0 youngs_modulus = 1000 -type = "LinearElastic" +type = "linear_elastic" [[particles]] file = "particles-2d-uniaxial-stress.json" diff --git a/benchmarks/2d/uniaxial_stress/test_benchmark.py b/benchmarks/2d/uniaxial_stress/test_benchmark.py index f04e820..f17e4fe 100644 --- a/benchmarks/2d/uniaxial_stress/test_benchmark.py +++ b/benchmarks/2d/uniaxial_stress/test_benchmark.py @@ -1,9 +1,13 @@ import os from pathlib import Path +import jax + +jax.config.update("jax_platform_name", "cpu") + import jax.numpy as jnp -from diffmpm import MPM +from diffmpm.mpm import MPM def test_benchmarks(): @@ -19,3 +23,7 @@ def test_benchmarks(): assert jnp.round(result["stress"][0, :, 1].max() - true_stress_yy, 8) == 0.0 assert jnp.round(result["stress"][0, :, 0].max() - true_stress_xx, 8) == 0.0 + + +if __name__ == "__main__": + test_benchmarks() diff --git a/diffmpm/__init__.py b/diffmpm/__init__.py index faa8316..a138300 100644 --- a/diffmpm/__init__.py +++ b/diffmpm/__init__.py @@ -1,47 +1,5 @@ from importlib.metadata import version -from pathlib import Path -import diffmpm.writers as writers -from diffmpm.io import Config -from diffmpm.solver import MPMExplicit - -__all__ = ["MPM", "__version__"] +__all__ = ["__version__"] __version__ = version("diffmpm") - - -class MPM: - def __init__(self, filepath): - self._config = Config(filepath) - mesh = self._config.parse() - out_dir = Path(self._config.parsed_config["output"]["folder"]).joinpath( - self._config.parsed_config["meta"]["title"], - ) - - write_format = self._config.parsed_config["output"].get("format", None) - if write_format is None or write_format.lower() == "none": - writer_func = None - elif write_format == "npz": - writer_func = writers.NPZWriter().write - else: - raise ValueError(f"Specified output format not supported: {write_format}") - - if self._config.parsed_config["meta"]["type"] == "MPMExplicit": - self.solver = MPMExplicit( - mesh, - self._config.parsed_config["meta"]["dt"], - velocity_update=self._config.parsed_config["meta"]["velocity_update"], - sim_steps=self._config.parsed_config["meta"]["nsteps"], - out_steps=self._config.parsed_config["output"]["step_frequency"], - out_dir=out_dir, - writer_func=writer_func, - ) - else: - raise ValueError("Wrong type of solver specified.") - - def solve(self): - """Solve the MPM simulation using JIT solver.""" - arrays = self.solver.solve_jit( - self._config.parsed_config["external_loading"]["gravity"], - ) - return arrays diff --git a/diffmpm/cli/mpm.py b/diffmpm/cli/mpm.py index aebc4ba..0b4b9d7 100644 --- a/diffmpm/cli/mpm.py +++ b/diffmpm/cli/mpm.py @@ -1,6 +1,6 @@ import click -from diffmpm import MPM +from diffmpm.mpm import MPM @click.command() # type: ignore diff --git a/diffmpm/constraint.py b/diffmpm/constraint.py index 93f75bd..a5310cb 100644 --- a/diffmpm/constraint.py +++ b/diffmpm/constraint.py @@ -26,19 +26,44 @@ def tree_unflatten(cls, aux_data, children): del children return cls(*aux_data) - def apply(self, obj, ids): + def apply_vel(self, vel, ids): """Apply constraint values to the passed object. Parameters ---------- - obj : diffmpm.node.Nodes, diffmpm.particle.Particles + obj : diffmpm.node.Nodes, diffmpm.particle._ParticlesState Object on which the constraint is applied ids : array_like The indices of the container `obj` on which the constraint will be applied. 
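+        Returns
+        -------
+        ArrayLike
+            Copy of `vel` with the constrained velocity set along
+            `self.dir` at the given `ids`.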
""" - obj.velocity = obj.velocity.at[ids, :, self.dir].set(self.velocity) - obj.momentum = obj.momentum.at[ids, :, self.dir].set( - obj.mass[ids, :, 0] * self.velocity - ) - obj.acceleration = obj.acceleration.at[ids, :, self.dir].set(0) + velocity = vel.at[ids, :, self.dir].set(self.velocity) + return velocity + + def apply_mom(self, mom, mass, ids): + """Apply constraint values to the passed object. + + Parameters + ---------- + obj : diffmpm.node.Nodes, diffmpm.particle._ParticlesState + Object on which the constraint is applied + ids : array_like + The indices of the container `obj` on which the constraint + will be applied. + """ + momentum = mom.at[ids, :, self.dir].set(mass[ids, :, 0] * self.velocity) + return momentum + + def apply_acc(self, acc, ids): + """Apply constraint values to the passed object. + + Parameters + ---------- + obj : diffmpm.node.Nodes, diffmpm.particle._ParticlesState + Object on which the constraint is applied + ids : array_like + The indices of the container `obj` on which the constraint + will be applied. + """ + acceleration = acc.at[ids, :, self.dir].set(0) + return acceleration diff --git a/diffmpm/element.py b/diffmpm/element.py index 3eeff67..6108a31 100644 --- a/diffmpm/element.py +++ b/diffmpm/element.py @@ -2,49 +2,108 @@ import abc import itertools +from functools import partial from typing import TYPE_CHECKING, Optional, Sequence, Tuple if TYPE_CHECKING: - from diffmpm.particle import Particles + from diffmpm.particle import _ParticlesState import jax.numpy as jnp -from jax import Array, jacobian, jit, lax, vmap -from jax.tree_util import register_pytree_node_class +from jax import Array, jacobian, jit, lax, tree_util, vmap +from jax.tree_util import register_pytree_node_class, tree_map, tree_reduce, Partial from jax.typing import ArrayLike from diffmpm.constraint import Constraint -from diffmpm.node import Nodes +from diffmpm.forces import NodalForce +from diffmpm.node import _NodesState, init_node_state +import chex -__all__ = ["_Element", "Linear1D", "Quadrilateral4Node"] +@chex.dataclass() +class _ElementsState: + nodes: _NodesState + total_elements: int + volume: chex.ArrayDevice + constraints: Sequence[Tuple[ArrayLike, Constraint]] + concentrated_nodal_forces: Sequence[NodalForce] -class _Element(abc.ABC): - """Base element class that is inherited by all types of Elements.""" - nodes: Nodes - total_elements: int - concentrated_nodal_forces: Sequence - volume: Array +@chex.dataclass() +class Quad4NState(_ElementsState): + nelements: chex.ArrayDevice + el_len: chex.ArrayDevice - @abc.abstractmethod - def id_to_node_ids(self, id: ArrayLike) -> Array: - """Node IDs corresponding to element `id`. - This method is implemented by each of the subclass. +@chex.dataclass() +class Quad4N: + total_elements: int + + def init_state( + self, + nelements: int, + total_elements: int, + el_len: float, + constraints: Sequence[Tuple[ArrayLike, Constraint]], + nodes: Optional[_NodesState] = None, + concentrated_nodal_forces: Sequence = [], + initialized: Optional[bool] = None, + volume: Optional[ArrayLike] = None, + ) -> Quad4NState: + """Initialize Linear1D. Parameters ---------- - id : int - Element ID. - - Returns - ------- - ArrayLike - Nodal IDs of the element. + nelements : int + Number of elements. + total_elements : int + Total number of elements (product of all elements of `nelements`) + el_len : float + Length of each element. + constraints: list + A list of constraints where each element is a tuple of + type `(node_ids, diffmpm.Constraint)`. 
Here, `node_ids` + correspond to the node IDs where `diffmpm.Constraint` + should be applied. + nodes : Nodes, Optional + Nodes in the element object. + concentrated_nodal_forces: list + A list of `diffmpm.forces.NodalForce`s that are to be + applied. + initialized: bool, None + `True` if the class has been initialized, `None` if not. + This is required like this for using JAX flattening. + volume: ArrayLike + Volume of the elements. """ - ... + nelements = jnp.asarray(nelements) + el_len = jnp.asarray(el_len) + + total_nodes = jnp.prod(nelements + 1) + coords = jnp.asarray( + list( + itertools.product( + jnp.arange(nelements[1] + 1), + jnp.arange(nelements[0] + 1), + ) + ) + ) + node_locations = (jnp.asarray([coords[:, 1], coords[:, 0]]).T * el_len).reshape( + -1, 1, 2 + ) + nodes = init_node_state(int(total_nodes), node_locations) + + volume = jnp.ones((total_elements, 1, 1)) + return Quad4NState( + nodes=nodes, + total_elements=total_elements, + concentrated_nodal_forces=concentrated_nodal_forces, + volume=volume, + constraints=constraints, + nelements=nelements, + el_len=el_len, + ) - def id_to_node_loc(self, id: ArrayLike) -> Array: + def id_to_node_loc(self, elements: _ElementState, id: ArrayLike) -> Array: """Node locations corresponding to element `id`. Parameters @@ -58,10 +117,10 @@ def id_to_node_loc(self, id: ArrayLike) -> Array: Nodal locations for the element. Shape of returned array is `(nodes_in_element, 1, ndim)` """ - node_ids = self.id_to_node_ids(id).squeeze() - return self.nodes.loc[node_ids] + node_ids = self.id_to_node_ids(elements.nelements[0], id).squeeze() + return elements.nodes.loc[node_ids] - def id_to_node_vel(self, id: ArrayLike) -> Array: + def id_to_node_vel(self, elements: _ElementState, id: ArrayLike) -> Array: """Node velocities corresponding to element `id`. Parameters @@ -75,417 +134,205 @@ def id_to_node_vel(self, id: ArrayLike) -> Array: Nodal velocities for the element. Shape of returned array is `(nodes_in_element, 1, ndim)` """ - node_ids = self.id_to_node_ids(id).squeeze() - return self.nodes.velocity[node_ids] - - def tree_flatten(self): - children = (self.nodes, self.volume) - aux_data = ( - self.nelements, - self.total_elements, - self.el_len, - self.constraints, - self.concentrated_nodal_forces, - self.initialized, - ) - return children, aux_data + node_ids = self.id_to_node_ids(elements.nelements[0], id).squeeze() + return elements.nodes.velocity[node_ids] - @classmethod - def tree_unflatten(cls, aux_data, children): - return cls( - aux_data[0], - aux_data[1], - aux_data[2], - aux_data[3], - nodes=children[0], - concentrated_nodal_forces=aux_data[4], - initialized=aux_data[5], - volume=children[1], - ) - - @abc.abstractmethod - def shapefn(self, xi: ArrayLike): - """Evaluate Shape function for element type.""" - ... - - @abc.abstractmethod - def shapefn_grad(self, xi: ArrayLike, coords: ArrayLike): - """Evaluate gradient of shape function for element type.""" - ... - - @abc.abstractmethod - def set_particle_element_ids(self, particles: Particles): - """Set the element IDs that particles are present in.""" - ... - - # Mapping from particles to nodes (P2G) - def compute_nodal_mass(self, particles: Particles): - r"""Compute the nodal mass based on particle mass. + def id_to_node_ids(self, nelements_x, id: ArrayLike): + """Node IDs corresponding to element `id`. - The nodal mass is updated as a sum of particle mass for - all particles mapped to the node. 
+ 3----2 + | | + 0----1 - \[ - (m)_i = \sum_p N_i(x_p) m_p - \] + Node ids are returned in the order as shown in the figure. Parameters ---------- - particles: diffmpm.particle.Particles - Particles to map to the nodal values. - """ - - def _step(pid, args): - pmass, mass, mapped_pos, el_nodes = args - mass = mass.at[el_nodes[pid]].add(pmass[pid] * mapped_pos[pid]) - return pmass, mass, mapped_pos, el_nodes - - self.nodes.mass = self.nodes.mass.at[:].set(0) - mapped_positions = self.shapefn(particles.reference_loc) - mapped_nodes = vmap(self.id_to_node_ids)(particles.element_ids).squeeze(-1) - args = ( - particles.mass, - self.nodes.mass, - mapped_positions, - mapped_nodes, - ) - _, self.nodes.mass, _, _ = lax.fori_loop(0, len(particles), _step, args) - - def compute_nodal_momentum(self, particles: Particles): - r"""Compute the nodal mass based on particle mass. - - The nodal mass is updated as a sum of particle mass for - all particles mapped to the node. - - \[ - (mv)_i = \sum_p N_i(x_p) (mv)_p - \] + id : int + Element ID. - Parameters - ---------- - particles: diffmpm.particle.Particles - Particles to map to the nodal values. + Returns + ------- + ArrayLike + Nodal IDs of the element. Shape of returned + array is (4, 1) """ - - def _step(pid, args): - pmom, mom, mapped_pos, el_nodes = args - mom = mom.at[el_nodes[pid]].add(mapped_pos[pid] @ pmom[pid]) - return pmom, mom, mapped_pos, el_nodes - - self.nodes.momentum = self.nodes.momentum.at[:].set(0) - mapped_positions = self.shapefn(particles.reference_loc) - mapped_nodes = vmap(self.id_to_node_ids)(particles.element_ids).squeeze(-1) - args = ( - particles.mass * particles.velocity, - self.nodes.momentum, - mapped_positions, - mapped_nodes, - ) - _, self.nodes.momentum, _, _ = lax.fori_loop(0, len(particles), _step, args) - self.nodes.momentum = jnp.where( - jnp.abs(self.nodes.momentum) < 1e-12, - jnp.zeros_like(self.nodes.momentum), - self.nodes.momentum, - ) - - def compute_velocity(self, particles: Particles): - """Compute velocity using momentum.""" - self.nodes.velocity = jnp.where( - self.nodes.mass == 0, - self.nodes.velocity, - self.nodes.momentum / self.nodes.mass, - ) - self.nodes.velocity = jnp.where( - jnp.abs(self.nodes.velocity) < 1e-12, - jnp.zeros_like(self.nodes.velocity), - self.nodes.velocity, + lower_left = (id // nelements_x) * (nelements_x + 1) + id % nelements_x + result = jnp.asarray( + [ + lower_left, + lower_left + 1, + lower_left + nelements_x + 2, + lower_left + nelements_x + 1, + ] ) + return result.reshape(4, 1) - def compute_external_force(self, particles: Particles): - r"""Update the nodal external force based on particle f_ext. + @classmethod + def _get_mapped_nodes(cls, id, nelements_x): + """Node IDs corresponding to element `id`. - The nodal force is updated as a sum of particle external - force for all particles mapped to the node. + 3----2 + | | + 0----1 - \[ - f_{ext})_i = \sum_p N_i(x_p) f_{ext} - \] + Node ids are returned in the order as shown in the figure. Parameters ---------- - particles: diffmpm.particle.Particles - Particles to map to the nodal values. - """ - - def _step(pid, args): - f_ext, pf_ext, mapped_pos, el_nodes = args - f_ext = f_ext.at[el_nodes[pid]].add(mapped_pos[pid] @ pf_ext[pid]) - return f_ext, pf_ext, mapped_pos, el_nodes + id : int + Element ID. 
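+        nelements_x : int
+            Number of elements along the x-direction of the mesh, used
+            to convert between element and node numbering.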
- self.nodes.f_ext = self.nodes.f_ext.at[:].set(0) - mapped_positions = self.shapefn(particles.reference_loc) - mapped_nodes = vmap(self.id_to_node_ids)(particles.element_ids).squeeze(-1) - args = ( - self.nodes.f_ext, - particles.f_ext, - mapped_positions, - mapped_nodes, + Returns + ------- + ArrayLike + Nodal IDs of the element. Shape of returned + array is (4, 1) + """ + lower_left = (id // nelements_x) * (nelements_x + 1) + id % nelements_x + result = jnp.asarray( + [ + lower_left, + lower_left + 1, + lower_left + nelements_x + 2, + lower_left + nelements_x + 1, + ] ) - self.nodes.f_ext, _, _, _ = lax.fori_loop(0, len(particles), _step, args) - - def compute_body_force(self, particles: Particles, gravity: ArrayLike): - r"""Update the nodal external force based on particle mass. - - The nodal force is updated as a sum of particle body - force for all particles mapped to th + return result.reshape(4, 1) - \[ - (f_{ext})_i = (f_{ext})_i + \sum_p N_i(x_p) m_p g - \] + def shapefn(self, xi: ArrayLike): + """Evaluate linear shape function. Parameters ---------- - particles: diffmpm.particle.Particles - Particles to map to the nodal values. - """ + xi : float, array_like + Locations of particles in natural coordinates to evaluate + the function at. Expected shape is (npoints, 1, ndim) - def _step(pid, args): - f_ext, pmass, mapped_pos, el_nodes, gravity = args - f_ext = f_ext.at[el_nodes[pid]].add( - mapped_pos[pid] @ (pmass[pid] * gravity) + Returns + ------- + array_like + Evaluated shape function values. The shape of the returned + array will depend on the input shape. For example, in the linear + case, if the input is a scalar, the returned array will be of + the shape `(1, 4, 1)` but if the input is a vector then the output will + be of the shape `(len(x), 4, 1)`. + """ + xi = jnp.asarray(xi) + if xi.ndim != 3: + raise ValueError( + f"`xi` should be of size (npoints, 1, ndim); found {xi.shape}" ) - return f_ext, pmass, mapped_pos, el_nodes, gravity - - mapped_positions = self.shapefn(particles.reference_loc) - mapped_nodes = vmap(self.id_to_node_ids)(particles.element_ids).squeeze(-1) - args = ( - self.nodes.f_ext, - particles.mass, - mapped_positions, - mapped_nodes, - gravity, + result = jnp.array( + [ + 0.25 * (1 - xi[:, :, 0]) * (1 - xi[:, :, 1]), + 0.25 * (1 + xi[:, :, 0]) * (1 - xi[:, :, 1]), + 0.25 * (1 + xi[:, :, 0]) * (1 + xi[:, :, 1]), + 0.25 * (1 - xi[:, :, 0]) * (1 + xi[:, :, 1]), + ] ) - self.nodes.f_ext, _, _, _, _ = lax.fori_loop(0, len(particles), _step, args) + result = result.transpose(1, 0, 2)[..., jnp.newaxis] + return result - def apply_concentrated_nodal_forces(self, particles: Particles, curr_time: float): - """Apply concentrated nodal forces. + @classmethod + def _shapefn(cls, xi: ArrayLike): + """Evaluate linear shape function. Parameters ---------- - particles: Particles - Particles in the simulation. - curr_time: float - Current time in the simulation. + xi : float, array_like + Locations of particles in natural coordinates to evaluate + the function at. Expected shape is (npoints, 1, ndim) + + Returns + ------- + array_like + Evaluated shape function values. The shape of the returned + array will depend on the input shape. For example, in the linear + case, if the input is a scalar, the returned array will be of + the shape `(1, 4, 1)` but if the input is a vector then the output will + be of the shape `(len(x), 4, 1)`. 
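+        For example, at the element centre, `xi = jnp.array([[[0.0, 0.0]]])`,
+        all four nodal shape functions evaluate to `0.25`.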
""" - for cnf in self.concentrated_nodal_forces: - factor = cnf.function.value(curr_time) - self.nodes.f_ext = self.nodes.f_ext.at[cnf.node_ids, 0, cnf.dir].add( - factor * cnf.force + xi = jnp.asarray(xi) + if xi.ndim != 3: + raise ValueError( + f"`xi` should be of size (npoints, 1, ndim); found {xi.shape}" ) + result = jnp.array( + [ + 0.25 * (1 - xi[:, :, 0]) * (1 - xi[:, :, 1]), + 0.25 * (1 + xi[:, :, 0]) * (1 - xi[:, :, 1]), + 0.25 * (1 + xi[:, :, 0]) * (1 + xi[:, :, 1]), + 0.25 * (1 - xi[:, :, 0]) * (1 + xi[:, :, 1]), + ] + ) + result = result.transpose(1, 0, 2)[..., jnp.newaxis] + return result - def apply_particle_traction_forces(self, particles: Particles): - """Apply concentrated nodal forces. + @classmethod + def _shapefn_natural_grad(cls, xi: ArrayLike): + """Calculate the gradient of shape function. + + This calculation is done in the natural coordinates. Parameters ---------- - particles: Particles - Particles in the simulation. - """ - - def _step(pid, args): - f_ext, ptraction, mapped_pos, el_nodes = args - f_ext = f_ext.at[el_nodes[pid]].add(mapped_pos[pid] @ ptraction[pid]) - return f_ext, ptraction, mapped_pos, el_nodes - - mapped_positions = self.shapefn(particles.reference_loc) - mapped_nodes = vmap(self.id_to_node_ids)(particles.element_ids).squeeze(-1) - args = (self.nodes.f_ext, particles.traction, mapped_positions, mapped_nodes) - self.nodes.f_ext, _, _, _ = lax.fori_loop(0, len(particles), _step, args) - - def update_nodal_acceleration_velocity( - self, particles: Particles, dt: float, *args - ): - """Update the nodal momentum based on total force on nodes.""" - total_force = self.nodes.get_total_force() - self.nodes.acceleration = self.nodes.acceleration.at[:].set( - jnp.nan_to_num(jnp.divide(total_force, self.nodes.mass)) - ) - self.nodes.velocity = self.nodes.velocity.at[:].add( - self.nodes.acceleration * dt - ) - self.apply_boundary_constraints() - self.nodes.momentum = self.nodes.momentum.at[:].set( - self.nodes.mass * self.nodes.velocity - ) - self.nodes.velocity = jnp.where( - jnp.abs(self.nodes.velocity) < 1e-12, - jnp.zeros_like(self.nodes.velocity), - self.nodes.velocity, - ) - self.nodes.acceleration = jnp.where( - jnp.abs(self.nodes.acceleration) < 1e-12, - jnp.zeros_like(self.nodes.acceleration), - self.nodes.acceleration, - ) - - def apply_boundary_constraints(self, *args): - """Apply boundary conditions for nodal velocity.""" - for ids, constraint in self.constraints: - constraint.apply(self.nodes, ids) - - def apply_force_boundary_constraints(self, *args): - """Apply boundary conditions for nodal forces.""" - self.nodes.f_int = self.nodes.f_int.at[self.constraints[0][0]].set(0) - self.nodes.f_ext = self.nodes.f_ext.at[self.constraints[0][0]].set(0) - self.nodes.f_damp = self.nodes.f_damp.at[self.constraints[0][0]].set(0) - - -@register_pytree_node_class -class Linear1D(_Element): - """Container for 1D line elements (and nodes). - - Element ID: 0 1 2 3 - Mesh: +-----+-----+-----+-----+ - Node IDs: 0 1 2 3 4 - - where - - + : Nodes - +-----+ : An element - - """ - - def __init__( - self, - nelements: int, - total_elements: int, - el_len: float, - constraints: Sequence[Tuple[ArrayLike, Constraint]], - nodes: Optional[Nodes] = None, - concentrated_nodal_forces: Sequence = [], - initialized: Optional[bool] = None, - volume: Optional[ArrayLike] = None, - ): - """Initialize Linear1D. - - Parameters - ---------- - nelements : int - Number of elements. 
- total_elements : int - Total number of elements (same as `nelements` for 1D) - el_len : float - Length of each element. - constraints: list - A list of constraints where each element is a tuple of type - `(node_ids, diffmpm.Constraint)`. Here, `node_ids` correspond to - the node IDs where `diffmpm.Constraint` should be applied. - nodes : Nodes, Optional - Nodes in the element object. - concentrated_nodal_forces: list - A list of `diffmpm.forces.NodalForce`s that are to be - applied. - initialized: bool, None - `True` if the class has been initialized, `None` if not. - This is required like this for using JAX flattening. - volume: ArrayLike - Volume of the elements. - """ - self.nelements = nelements - self.total_elements = nelements - self.el_len = el_len - if nodes is None: - self.nodes = Nodes( - nelements + 1, - jnp.arange(nelements + 1).reshape(-1, 1, 1) * el_len, - ) - else: - self.nodes = nodes - - # self.boundary_nodes = boundary_nodes - self.constraints = constraints - self.concentrated_nodal_forces = concentrated_nodal_forces - if initialized is None: - self.volume = jnp.ones((self.total_elements, 1, 1)) - else: - self.volume = jnp.asarray(volume) - self.initialized = True - - def id_to_node_ids(self, id: ArrayLike): - """Node IDs corresponding to element `id`. - - Parameters - ---------- - id : int - Element ID. + x : float, array_like + Locations of particles in natural coordinates to evaluate + the function at. Returns ------- - ArrayLike - Nodal IDs of the element. Shape of returned - array is `(2, 1)` + array_like + Evaluated gradient values of the shape function. The shape of + the returned array will depend on the input shape. For example, + in the linear case, if the input is a scalar, the returned array + will be of the shape `(4, 2)`. """ - return jnp.array([id, id + 1]).reshape(2, 1) + # result = vmap(jacobian(self.shapefn))(xi[..., jnp.newaxis]).squeeze() + xi = jnp.asarray(xi) + xi = xi.squeeze() + result = jnp.array( + [ + [-0.25 * (1 - xi[1]), -0.25 * (1 - xi[0])], + [0.25 * (1 - xi[1]), -0.25 * (1 + xi[0])], + [0.25 * (1 + xi[1]), 0.25 * (1 + xi[0])], + [-0.25 * (1 + xi[1]), 0.25 * (1 - xi[0])], + ], + ) + return result - def shapefn(self, xi: ArrayLike): - """Evaluate linear shape function. + def shapefn_grad(self, xi: ArrayLike, coords: ArrayLike): + """Gradient of shape function in physical coordinates. Parameters ---------- xi : float, array_like - Locations of particles in natural coordinates to evaluate - the function at. Expected shape is `(npoints, 1, ndim)` + Locations of particles to evaluate in natural coordinates. + Expected shape `(npoints, 1, ndim)`. + coords : array_like + Nodal coordinates to transform by. Expected shape + `(npoints, 1, ndim)` Returns ------- array_like - Evaluated shape function values. The shape of the returned - array will depend on the input shape. For example, in the linear - case, if the input is a scalar, the returned array will be of - the shape `(1, 2, 1)` but if the input is a vector then the output will - be of the shape `(len(x), 2, 1)`. + Gradient of the shape function in physical coordinates at `xi` """ xi = jnp.asarray(xi) + coords = jnp.asarray(coords) if xi.ndim != 3: raise ValueError( - f"`xi` should be of size (npoints, 1, ndim); found {xi.shape}" + f"`x` should be of size (npoints, 1, ndim); found {xi.shape}" ) - result = jnp.array([0.5 * (1 - xi), 0.5 * (1 + xi)]).transpose(1, 0, 2, 3) - return result - - def _shapefn_natural_grad(self, xi: ArrayLike): - """Calculate the gradient of shape function. 
- - This calculation is done in the natural coordinates. - - Parameters - ---------- - x : float, array_like - Locations of particles in natural coordinates to evaluate - the function at. + grad_sf = self._shapefn_natural_grad(xi) + _jacobian = grad_sf.T @ coords.squeeze() - Returns - ------- - array_like - Evaluated gradient values of the shape function. The shape of - the returned array will depend on the input shape. For example, - in the linear case, if the input is a scalar, the returned array - will be of the shape `(2, 1)`. - """ - xi = jnp.asarray(xi) - result = vmap(jacobian(self.shapefn))(xi[..., jnp.newaxis]).squeeze() - - # TODO: The following code tries to evaluate vmap even if - # the predicate condition is true, not sure why. - # result = lax.cond( - # jnp.isscalar(x), - # jacobian(self.shapefn), - # vmap(jacobian(self.shapefn)), - # xi - # ) - return result.reshape(2, 1) + result = grad_sf @ jnp.linalg.inv(_jacobian).T + return result - def shapefn_grad(self, xi: ArrayLike, coords: ArrayLike): + @classmethod + def _shapefn_grad(cls, xi: ArrayLike, coords: ArrayLike): """Gradient of shape function in physical coordinates. Parameters @@ -508,44 +355,63 @@ def shapefn_grad(self, xi: ArrayLike, coords: ArrayLike): raise ValueError( f"`x` should be of size (npoints, 1, ndim); found {xi.shape}" ) - grad_sf = self._shapefn_natural_grad(xi) - _jacobian = grad_sf.T @ coords + grad_sf = cls._shapefn_natural_grad(xi) + _jacobian = grad_sf.T @ coords.squeeze() result = grad_sf @ jnp.linalg.inv(_jacobian).T return result - def set_particle_element_ids(self, particles): + @classmethod + def _get_particles_element_ids(cls, particles, elements): """Set the element IDs for the particles. If the particle doesn't lie between the boundaries of any element, it sets the element index to -1. """ - @jit - def f(x): - idl = ( - len(self.nodes.loc) - - 1 - - jnp.asarray(self.nodes.loc[::-1] <= x).nonzero(size=1, fill_value=-1)[ - 0 - ][-1] - ) - idg = ( - jnp.asarray(self.nodes.loc > x).nonzero(size=1, fill_value=-1)[0][0] - 1 - ) - return (idl, idg) + def f(x, *, loc, nelements): + xidl = (loc[:, :, 0] <= x[0, 0]).nonzero(size=loc.shape[0], fill_value=-1)[ + 0 + ] + yidl = (loc[:, :, 1] <= x[0, 1]).nonzero(size=loc.shape[0], fill_value=-1)[ + 0 + ] + lower_left = jnp.where(jnp.isin(xidl, yidl), xidl, -1).max() + element_id = lower_left - lower_left // (nelements + 1) + return element_id - ids = vmap(f)(particles.loc) - particles.element_ids = jnp.where( - ids[0] == ids[1], ids[0], jnp.ones_like(ids[0]) * -1 - ) + pf = partial(f, loc=elements.nodes.loc, nelements=elements.nelements[0]) + ids = vmap(pf)(particles.loc) + return ids - def compute_volume(self, *args): - """Compute volume of all elements.""" - vol = jnp.ediff1d(self.nodes.loc) - self.volume = jnp.ones((self.total_elements, 1, 1)) * vol + def set_particle_element_ids( + self, elements: _ElementsState, particles: _ParticlesState + ): + """Set the element IDs for the particles. + + If the particle doesn't lie between the boundaries of any + element, it sets the element index to -1. 
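+        The element is found from its lower-left node: the highest-numbered
+        node whose x and y coordinates are both less than or equal to the
+        particle's position. That node index is converted to an element
+        index by subtracting one node per completed row,
+        `element_id = lower_left - lower_left // (nelements + 1)`.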
+ """ + + @jit + def f(x, *, loc, nelements): + xidl = (loc[:, :, 0] <= x[0, 0]).nonzero(size=loc.shape[0], fill_value=-1)[ + 0 + ] + yidl = (loc[:, :, 1] <= x[0, 1]).nonzero(size=loc.shape[0], fill_value=-1)[ + 0 + ] + lower_left = jnp.where(jnp.isin(xidl, yidl), xidl, -1).max() + element_id = lower_left - lower_left // (nelements + 1) + return element_id - def compute_internal_force(self, particles): + pf = partial(f, loc=elements.nodes.loc, nelements=elements.nelements[0]) + ids = vmap(pf)(particles.loc) + return particles.replace(element_ids=ids) + + def compute_internal_force( + self, elements: _ElementState, particles: _ParticlesState + ): r"""Update the nodal internal force based on particle mass. The nodal force is updated as a sum of internal forces for @@ -559,10 +425,11 @@ def compute_internal_force(self, particles): Parameters ---------- - particles: diffmpm.particle.Particles + particles: diffmpm.particle._ParticlesState Particles to map to the nodal values. """ + @jit def _step(pid, args): ( f_int, @@ -571,10 +438,17 @@ def _step(pid, args): el_nodes, pstress, ) = args - # TODO: correct matrix multiplication for n-d - # update = -(pvol[pid]) * pstress[pid] @ mapped_grads[pid] - update = -pvol[pid] * pstress[pid][0] * mapped_grads[pid] - f_int = f_int.at[el_nodes[pid]].add(update[..., jnp.newaxis]) + force = jnp.zeros((mapped_grads.shape[1], 1, 2)) + force = force.at[:, 0, 0].set( + mapped_grads[pid][:, 0] * pstress[pid][0] + + mapped_grads[pid][:, 1] * pstress[pid][3] + ) + force = force.at[:, 0, 1].set( + mapped_grads[pid][:, 1] * pstress[pid][1] + + mapped_grads[pid][:, 0] * pstress[pid][3] + ) + update = -pvol[pid] * force + f_int = f_int.at[el_nodes[pid]].add(update) return ( f_int, pvol, @@ -583,327 +457,839 @@ def _step(pid, args): pstress, ) - self.nodes.f_int = self.nodes.f_int.at[:].set(0) - mapped_nodes = vmap(self.id_to_node_ids)(particles.element_ids).squeeze(-1) - mapped_coords = vmap(self.id_to_node_loc)(particles.element_ids).squeeze(2) - mapped_grads = vmap(self.shapefn_grad)( + # f_int = self.nodes.f_int.at[:].set(0) + f_int = elements.nodes.f_int + mapped_nodes = vmap(Partial(self.id_to_node_ids, elements.nelements[0]))( + particles.element_ids + ).squeeze(-1) + mapped_coords = vmap(Partial(self.id_to_node_loc, elements))( + particles.element_ids + ).squeeze(2) + mapped_grads = vmap(jit(self.shapefn_grad))( particles.reference_loc[:, jnp.newaxis, ...], mapped_coords, ) args = ( - self.nodes.f_int, + f_int, particles.volume, mapped_grads, mapped_nodes, particles.stress, ) - self.nodes.f_int, _, _, _, _ = lax.fori_loop(0, len(particles), _step, args) - - -@register_pytree_node_class -class Quadrilateral4Node(_Element): - r"""Container for 2D quadrilateral elements with 4 nodes. - - Nodes and elements are numbered as + f_int, _, _, _, _ = lax.fori_loop(0, particles.nparticles, _step, args) + return f_int, "f_int" - 15 +---+---+---+---+ 19 - | 8 | 9 | 10| 11| - 10 +---+---+---+---+ 14 - | 4 | 5 | 6 | 7 | - 5 +---+---+---+---+ 9 - | 0 | 1 | 2 | 3 | - +---+---+---+---+ - 0 1 2 3 4 + @classmethod + def _compute_internal_force( + cls, nf_int, nloc, mapped_node_ids, pxi, pvol, pstress, pids + ): + r"""Update the nodal internal force based on particle mass. - where + The nodal force is updated as a sum of internal forces for + all particles mapped to the node. 
- + : Nodes - +---+ - | | : An element - +---+ - """ + \[ + (f_{int})_i = -\sum_p V_p \sigma_p \nabla N_i(x_p) + \] - def __init__( - self, - nelements: int, - total_elements: int, - el_len: float, - constraints: Sequence[Tuple[ArrayLike, Constraint]], - nodes: Optional[Nodes] = None, - concentrated_nodal_forces: Sequence = [], - initialized: Optional[bool] = None, - volume: Optional[ArrayLike] = None, - ) -> None: - """Initialize Linear1D. + where \(\sigma_p\) is the stress at particle \(p\). Parameters ---------- - nelements : int - Number of elements. - total_elements : int - Total number of elements (product of all elements of `nelements`) - el_len : float - Length of each element. - constraints: list - A list of constraints where each element is a tuple of - type `(node_ids, diffmpm.Constraint)`. Here, `node_ids` - correspond to the node IDs where `diffmpm.Constraint` - should be applied. - nodes : Nodes, Optional - Nodes in the element object. - concentrated_nodal_forces: list - A list of `diffmpm.forces.NodalForce`s that are to be - applied. - initialized: bool, None - `True` if the class has been initialized, `None` if not. - This is required like this for using JAX flattening. - volume: ArrayLike - Volume of the elements. + particles: diffmpm.particle._ParticlesState + Particles to map to the nodal values. """ - self.nelements = jnp.asarray(nelements) - self.el_len = jnp.asarray(el_len) - self.total_elements = total_elements - - if nodes is None: - total_nodes = jnp.prod(self.nelements + 1) - coords = jnp.asarray( - list( - itertools.product( - jnp.arange(self.nelements[1] + 1), - jnp.arange(self.nelements[0] + 1), - ) - ) - ) - node_locations = ( - jnp.asarray([coords[:, 1], coords[:, 0]]).T * self.el_len - ).reshape(-1, 1, 2) - self.nodes = Nodes(int(total_nodes), node_locations) - else: - self.nodes = nodes - - self.constraints = constraints - self.concentrated_nodal_forces = concentrated_nodal_forces - if initialized is None: - self.volume = jnp.ones((self.total_elements, 1, 1)) - else: - self.volume = jnp.asarray(volume) - self.initialized = True - - def id_to_node_ids(self, id: ArrayLike): - """Node IDs corresponding to element `id`. - - 3----2 - | | - 0----1 - - Node ids are returned in the order as shown in the figure. - Parameters - ---------- - id : int - Element ID. + @jit + def _step(pid, args): + ( + f_int, + pvol, + mapped_grads, + el_nodes, + pstress, + ) = args + force = jnp.zeros((mapped_grads.shape[1], 1, 2)) + force = force.at[:, 0, 0].set( + mapped_grads[pid][:, 0] * pstress[pid][0] + + mapped_grads[pid][:, 1] * pstress[pid][3] + ) + force = force.at[:, 0, 1].set( + mapped_grads[pid][:, 1] * pstress[pid][1] + + mapped_grads[pid][:, 0] * pstress[pid][3] + ) + update = -pvol[pid] * force + f_int = f_int.at[el_nodes[pid]].add(update) + return ( + f_int, + pvol, + mapped_grads, + el_nodes, + pstress, + ) - Returns - ------- - ArrayLike - Nodal IDs of the element. 
Shape of returned - array is (4, 1) - """ - lower_left = (id // self.nelements[0]) * ( - self.nelements[0] + 1 - ) + id % self.nelements[0] - result = jnp.asarray( - [ - lower_left, - lower_left + 1, - lower_left + self.nelements[0] + 2, - lower_left + self.nelements[0] + 1, - ] - ) - return result.reshape(4, 1) + def _scan_step(carry, pid): + ( + f_int, + pvol, + mapped_grads, + el_nodes, + pstress, + ) = carry + force = jnp.zeros((mapped_grads.shape[1], 1, 2)) + force = force.at[:, 0, 0].set( + mapped_grads[pid][:, 0] * pstress[pid][0] + + mapped_grads[pid][:, 1] * pstress[pid][3] + ) + force = force.at[:, 0, 1].set( + mapped_grads[pid][:, 1] * pstress[pid][1] + + mapped_grads[pid][:, 0] * pstress[pid][3] + ) + update = -pvol[pid] * force + f_int = f_int.at[el_nodes[pid]].add(update) + return ( + f_int, + pvol, + mapped_grads, + el_nodes, + pstress, + ), pid + + # f_int = self.nodes.f_int.at[:].set(0) + # f_int = elements.nodes.f_int + mapped_nodes = mapped_node_ids.squeeze(-1) + mapped_coords = nloc[mapped_nodes].squeeze(2) + mapped_grads = vmap(jit(cls._shapefn_grad))( + pxi[:, jnp.newaxis, ...], + mapped_coords, + ) + args = ( + nf_int, + pvol, + mapped_grads, + mapped_nodes, + pstress, + ) + # f_int, _, _, _, _ = lax.fori_loop(0, nparticles, _step, args) + final_carry, _ = lax.scan(_scan_step, args, pids) + f_int, _, _, _, _ = final_carry + return f_int - def shapefn(self, xi: ArrayLike): - """Evaluate linear shape function. + # Mapping from particles to nodes (P2G) + @classmethod + def _compute_nodal_mass(cls, mass, pmass, pxi, peids, mapped_node_ids, pids): + r"""Compute the nodal mass based on particle mass. + + The nodal mass is updated as a sum of particle mass for + all particles mapped to the node. + + \[ + (m)_i = \sum_p N_i(x_p) m_p + \] Parameters ---------- - xi : float, array_like - Locations of particles in natural coordinates to evaluate - the function at. Expected shape is (npoints, 1, ndim) - - Returns - ------- - array_like - Evaluated shape function values. The shape of the returned - array will depend on the input shape. For example, in the linear - case, if the input is a scalar, the returned array will be of - the shape `(1, 4, 1)` but if the input is a vector then the output will - be of the shape `(len(x), 4, 1)`. + particles: diffmpm.particle.Particles + Particles to map to the nodal values. 
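+        Note that this classmethod operates on raw arrays rather than a
+        particles object: `mass` is the nodal mass array being accumulated,
+        `pmass` and `pxi` are the particle masses and natural coordinates,
+        `peids` the particle element IDs, `mapped_node_ids` the node IDs
+        mapped to each particle, and `pids` the particle indices swept by
+        `lax.scan`.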
""" - xi = jnp.asarray(xi) - if xi.ndim != 3: - raise ValueError( - f"`xi` should be of size (npoints, 1, ndim); found {xi.shape}" - ) - result = jnp.array( - [ - 0.25 * (1 - xi[:, :, 0]) * (1 - xi[:, :, 1]), - 0.25 * (1 + xi[:, :, 0]) * (1 - xi[:, :, 1]), - 0.25 * (1 + xi[:, :, 0]) * (1 + xi[:, :, 1]), - 0.25 * (1 - xi[:, :, 0]) * (1 + xi[:, :, 1]), - ] + + @jit + def _step(pid, args): + pmass, mass, mapped_pos, el_nodes = args + mass = mass.at[el_nodes[pid]].add(pmass[pid] * mapped_pos[pid]) + return pmass, mass, mapped_pos, el_nodes + + def _scan_step(carry, pid): + pmass, mass, mapped_pos, el_nodes = carry + mass = mass.at[el_nodes[pid]].add(pmass[pid] * mapped_pos[pid]) + return (pmass, mass, mapped_pos, el_nodes), pid + + mapped_positions = cls._shapefn(pxi) + mapped_nodes = mapped_node_ids.squeeze(-1) + args = ( + pmass, + mass, + mapped_positions, + mapped_nodes, ) - result = result.transpose(1, 0, 2)[..., jnp.newaxis] - return result + # _, mass, _, _ = lax.fori_loop(0, len(pids), _step, args) + final_carry, _ = lax.scan(_scan_step, args, pids) + _, mass, _, _ = final_carry + return mass - def _shapefn_natural_grad(self, xi: ArrayLike): - """Calculate the gradient of shape function. + def compute_nodal_mass(self, elements, particles: _ParticlesState): + r"""Compute the nodal mass based on particle mass. - This calculation is done in the natural coordinates. + The nodal mass is updated as a sum of particle mass for + all particles mapped to the node. + + \[ + (m)_i = \sum_p N_i(x_p) m_p + \] Parameters ---------- - x : float, array_like - Locations of particles in natural coordinates to evaluate - the function at. + particles: diffmpm.particle.Particles + Particles to map to the nodal values. + """ - Returns - ------- - array_like - Evaluated gradient values of the shape function. The shape of - the returned array will depend on the input shape. For example, - in the linear case, if the input is a scalar, the returned array - will be of the shape `(4, 2)`. + @jit + def _step(pid, args): + pmass, mass, mapped_pos, el_nodes = args + mass = mass.at[el_nodes[pid]].add(pmass[pid] * mapped_pos[pid]) + return pmass, mass, mapped_pos, el_nodes + + # mass = self.nodes.mass.at[:].set(0) + mass = elements.nodes.mass + mapped_positions = self.shapefn(particles.reference_loc) + mapped_nodes = vmap(Partial(self.id_to_node_ids, elements.nelements[0]))( + particles.element_ids + ).squeeze(-1) + args = ( + particles.mass, + mass, + mapped_positions, + mapped_nodes, + ) + _, mass, _, _ = lax.fori_loop(0, particles.nparticles, _step, args) + # TODO: Return state instead of setting + return mass, "mass" + + @classmethod + def _compute_nodal_momentum( + cls, nmom, pmass, pvel, pxi, peids, mapped_node_ids, pids + ): + r"""Compute the nodal mass based on particle mass. + + The nodal mass is updated as a sum of particle mass for + all particles mapped to the node. + + \[ + (mv)_i = \sum_p N_i(x_p) (mv)_p + \] + + Parameters + ---------- + particles: diffmpm.particle.Particles + Particles to map to the nodal values. 
""" - # result = vmap(jacobian(self.shapefn))(xi[..., jnp.newaxis]).squeeze() - xi = jnp.asarray(xi) - xi = xi.squeeze() - result = jnp.array( - [ - [-0.25 * (1 - xi[1]), -0.25 * (1 - xi[0])], - [0.25 * (1 - xi[1]), -0.25 * (1 + xi[0])], - [0.25 * (1 + xi[1]), 0.25 * (1 + xi[0])], - [-0.25 * (1 + xi[1]), 0.25 * (1 - xi[0])], - ], + + @jit + def _step(pid, args): + pmom, mom, mapped_pos, el_nodes = args + new_mom = mom.at[el_nodes[pid]].add(mapped_pos[pid] @ pmom[pid]) + return pmom, new_mom, mapped_pos, el_nodes + + def _scan_step(carry, pid): + pmom, mom, mapped_pos, el_nodes = carry + new_mom = mom.at[el_nodes[pid]].add(mapped_pos[pid] @ pmom[pid]) + return (pmom, new_mom, mapped_pos, el_nodes), pid + + mapped_nodes = mapped_node_ids.squeeze(-1) + mapped_positions = cls._shapefn(pxi) + args = ( + pmass * pvel, + nmom, + mapped_positions, + mapped_nodes, ) - return result + _, new_momentum, _, _ = lax.fori_loop(0, len(pids), _step, args) + final_carry, _ = lax.scan(_scan_step, args, pids) + _, new_nmom, _, _ = final_carry + new_nmom = jnp.where(jnp.abs(new_nmom) < 1e-12, 0, new_nmom) + return new_nmom - def shapefn_grad(self, xi: ArrayLike, coords: ArrayLike): - """Gradient of shape function in physical coordinates. + def compute_nodal_momentum(self, elements, particles: _ParticlesState): + r"""Compute the nodal mass based on particle mass. + + The nodal mass is updated as a sum of particle mass for + all particles mapped to the node. + + \[ + (mv)_i = \sum_p N_i(x_p) (mv)_p + \] Parameters ---------- - xi : float, array_like - Locations of particles to evaluate in natural coordinates. - Expected shape `(npoints, 1, ndim)`. - coords : array_like - Nodal coordinates to transform by. Expected shape - `(npoints, 1, ndim)` + particles: diffmpm.particle.Particles + Particles to map to the nodal values. 
+ """ - Returns - ------- - array_like - Gradient of the shape function in physical coordinates at `xi` + @jit + def _step(pid, args): + pmom, mom, mapped_pos, el_nodes = args + new_mom = mom.at[el_nodes[pid]].add(mapped_pos[pid] @ pmom[pid]) + return pmom, new_mom, mapped_pos, el_nodes + + # curr_mom = elements.nodes.momentum.at[:].set(0) + curr_mom = elements.nodes.momentum + mapped_positions = self.shapefn(particles.reference_loc) + mapped_nodes = vmap(Partial(self.id_to_node_ids, elements.nelements[0]))( + particles.element_ids + ).squeeze(-1) + args = ( + particles.mass * particles.velocity, + curr_mom, + mapped_positions, + mapped_nodes, + ) + _, new_momentum, _, _ = lax.fori_loop(0, particles.nparticles, _step, args) + new_momentum = jnp.where(jnp.abs(new_momentum) < 1e-12, 0, new_momentum) + # TODO: Return state instead of setting + return new_momentum, "momentum" + + @classmethod + def _compute_nodal_velocity(cls, nmass, nmom, nvel): + """Compute velocity using momentum.""" + velocity = jnp.where( + nmass == 0, + nvel, + nmom / nmass, + ) + velocity = jnp.where( + jnp.abs(velocity) < 1e-12, + 0, + velocity, + ) + # TODO: Return state instead of setting + return velocity + + def compute_velocity(self, elements, particles: _ParticlesState): + """Compute velocity using momentum.""" + velocity = jnp.where( + elements.nodes.mass == 0, + elements.nodes.velocity, + elements.nodes.momentum / elements.nodes.mass, + ) + velocity = jnp.where( + jnp.abs(velocity) < 1e-12, + 0, + velocity, + ) + # TODO: Return state instead of setting + return velocity, "velocity" + + def compute_external_force(self, elements, particles: _ParticlesState): + r"""Update the nodal external force based on particle f_ext. + + The nodal force is updated as a sum of particle external + force for all particles mapped to the node. + + \[ + f_{ext})_i = \sum_p N_i(x_p) f_{ext} + \] + + Parameters + ---------- + particles: diffmpm.particle.Particles + Particles to map to the nodal values. """ - xi = jnp.asarray(xi) - coords = jnp.asarray(coords) - if xi.ndim != 3: - raise ValueError( - f"`x` should be of size (npoints, 1, ndim); found {xi.shape}" - ) - grad_sf = self._shapefn_natural_grad(xi) - _jacobian = grad_sf.T @ coords.squeeze() - result = grad_sf @ jnp.linalg.inv(_jacobian).T - return result + @jit + def _step(pid, args): + f_ext, pf_ext, mapped_pos, el_nodes = args + f_ext = f_ext.at[el_nodes[pid]].add(mapped_pos[pid] @ pf_ext[pid]) + return f_ext, pf_ext, mapped_pos, el_nodes - def set_particle_element_ids(self, particles: Particles): - """Set the element IDs for the particles. + # f_ext = elements.nodes.f_ext.at[:].set(0) + f_ext = elements.nodes.f_ext + mapped_positions = self.shapefn(particles.reference_loc) + mapped_nodes = vmap(Partial(self.id_to_node_ids, elements.nelements[0]))( + particles.element_ids + ).squeeze(-1) + args = ( + f_ext, + particles.f_ext, + mapped_positions, + mapped_nodes, + ) + f_ext, _, _, _ = lax.fori_loop(0, particles.nparticles, _step, args) + # TODO: Return state instead of setting + return f_ext, "f_ext" - If the particle doesn't lie between the boundaries of any - element, it sets the element index to -1. + @classmethod + def _compute_external_force(cls, f_ext, pf_ext, pxi, pids, mapped_node_ids): + r"""Update the nodal external force based on particle f_ext. + + The nodal force is updated as a sum of particle external + force for all particles mapped to the node. 
+ + \[ + f_{ext})_i = \sum_p N_i(x_p) f_{ext} + \] + + Parameters + ---------- + particles: diffmpm.particle.Particles + Particles to map to the nodal values. """ @jit - def f(x): - xidl = (self.nodes.loc[:, :, 0] <= x[0, 0]).nonzero( - size=len(self.nodes.loc), fill_value=-1 - )[0] - yidl = (self.nodes.loc[:, :, 1] <= x[0, 1]).nonzero( - size=len(self.nodes.loc), fill_value=-1 - )[0] - lower_left = jnp.where(jnp.isin(xidl, yidl), xidl, -1).max() - element_id = lower_left - lower_left // (self.nelements[0] + 1) - return element_id + def _step(pid, args): + f_ext, pf_ext, mapped_pos, el_nodes = args + f_ext = f_ext.at[el_nodes[pid]].add(mapped_pos[pid] @ pf_ext[pid]) + return f_ext, pf_ext, mapped_pos, el_nodes - ids = vmap(f)(particles.loc) - particles.element_ids = ids + def _scan_step(carry, pid): + f_ext, pf_ext, mapped_pos, el_nodes = carry + f_ext = f_ext.at[el_nodes[pid]].add(mapped_pos[pid] @ pf_ext[pid]) + return (f_ext, pf_ext, mapped_pos, el_nodes), pid - def compute_internal_force(self, particles: Particles): - r"""Update the nodal internal force based on particle mass. + mapped_positions = cls._shapefn(pxi) + mapped_nodes = mapped_node_ids.squeeze(-1) + args = ( + f_ext, + pf_ext, + mapped_positions, + mapped_nodes, + ) + # f_ext, _, _, _ = lax.fori_loop(0, len(pids), _step, args) + final_carry, _ = lax.scan(_scan_step, args, pids) + f_ext, _, _, _ = final_carry + return f_ext - The nodal force is updated as a sum of internal forces for - all particles mapped to the node. + def compute_body_force( + self, elements, particles: _ParticlesState, gravity: ArrayLike + ): + r"""Update the nodal external force based on particle mass. + + The nodal force is updated as a sum of particle body + force for all particles mapped to th \[ - (f_{int})_i = -\sum_p V_p \sigma_p \nabla N_i(x_p) + (f_{ext})_i = (f_{ext})_i + \sum_p N_i(x_p) m_p g \] - where \(\sigma_p\) is the stress at particle \(p\). + Parameters + ---------- + particles: diffmpm.particle._ParticlesState + Particles to map to the nodal values. + """ + + @jit + def _step(pid, args): + f_ext, pmass, mapped_pos, el_nodes, gravity = args + f_ext = f_ext.at[el_nodes[pid]].add( + mapped_pos[pid] @ (pmass[pid] * gravity) + ) + return f_ext, pmass, mapped_pos, el_nodes, gravity + + mapped_positions = self.shapefn(particles.reference_loc) + mapped_nodes = vmap(Partial(self.id_to_node_ids, elements.nelements[0]))( + particles.element_ids + ).squeeze(-1) + args = ( + elements.nodes.f_ext, + particles.mass, + mapped_positions, + mapped_nodes, + gravity, + ) + f_ext, _, _, _, _ = lax.fori_loop(0, particles.nparticles, _step, args) + # TODO: Return state instead of setting + return f_ext, "f_ext" + + @classmethod + def _compute_body_force( + cls, nf_ext, pmass, pxi, mapped_node_ids, pids, gravity: ArrayLike + ): + r"""Update the nodal external force based on particle mass. + + The nodal force is updated as a sum of particle body + force for all particles mapped to th + + \[ + (f_{ext})_i = (f_{ext})_i + \sum_p N_i(x_p) m_p g + \] Parameters ---------- - particles: diffmpm.particle.Particles + particles: diffmpm.particle._ParticlesState Particles to map to the nodal values. 
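+        Here `nf_ext` is the nodal external force array being accumulated,
+        `pmass` and `pxi` the particle masses and natural coordinates,
+        `mapped_node_ids` the node IDs mapped to each particle, `pids` the
+        particle indices swept by `lax.scan`, and `gravity` the gravity
+        vector applied to each particle.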
""" + @jit def _step(pid, args): - ( - f_int, - pvol, - mapped_grads, - el_nodes, - pstress, - ) = args - force = jnp.zeros((mapped_grads.shape[1], 1, 2)) - force = force.at[:, 0, 0].set( - mapped_grads[pid][:, 0] * pstress[pid][0] - + mapped_grads[pid][:, 1] * pstress[pid][3] + f_ext, pmass, mapped_pos, el_nodes, gravity = args + f_ext = f_ext.at[el_nodes[pid]].add( + mapped_pos[pid] @ (pmass[pid] * gravity) ) - force = force.at[:, 0, 1].set( - mapped_grads[pid][:, 1] * pstress[pid][1] - + mapped_grads[pid][:, 0] * pstress[pid][3] + return f_ext, pmass, mapped_pos, el_nodes, gravity + + def _scan_step(carry, pid): + f_ext, pmass, mapped_pos, el_nodes, gravity = args + f_ext = f_ext.at[el_nodes[pid]].add( + mapped_pos[pid] @ (pmass[pid] * gravity) ) - update = -pvol[pid] * force - f_int = f_int.at[el_nodes[pid]].add(update) - return ( - f_int, - pvol, - mapped_grads, - el_nodes, - pstress, + return (f_ext, pmass, mapped_pos, el_nodes, gravity), pid + + mapped_positions = cls._shapefn(pxi) + mapped_nodes = mapped_node_ids.squeeze(-1) + args = ( + nf_ext, + pmass, + mapped_positions, + mapped_nodes, + gravity, + ) + # f_ext, _, _, _, _ = lax.fori_loop(0, nparticles, _step, args) + final_carry, _ = lax.scan(_scan_step, args, pids) + f_ext, _, _, _, _ = final_carry + return f_ext + + def apply_concentrated_nodal_forces( + self, elements, particles: _ParticlesState, curr_time: float + ): + """Apply concentrated nodal forces. + + Parameters + ---------- + particles: _ParticlesState + Particles in the simulation. + curr_time: float + Current time in the simulation. + """ + + def _func(cnf, *, f_ext): + factor = cnf.function.value(curr_time) + f_ext = f_ext.at[cnf.node_ids, 0, cnf.dir].add(factor * cnf.force) + return f_ext + + if elements.concentrated_nodal_forces: + partial_func = partial(_func, f_ext=elements.nodes.f_ext) + _out = tree_map( + partial_func, + elements.concentrated_nodal_forces, + is_leaf=lambda x: isinstance(x, NodalForce), ) - self.nodes.f_int = self.nodes.f_int.at[:].set(0) - mapped_nodes = vmap(self.id_to_node_ids)(particles.element_ids).squeeze(-1) - mapped_coords = vmap(self.id_to_node_loc)(particles.element_ids).squeeze(2) - mapped_grads = vmap(self.shapefn_grad)( - particles.reference_loc[:, jnp.newaxis, ...], - mapped_coords, + def _f(x, *, orig): + return jnp.where(x == orig, 0, x) + + # This assumes that the nodal forces are not overlapping, i.e. + # no node will be acted by 2 forces in the same direction. + _step_1 = tree_map(partial(_f, orig=elements.nodes.f_ext), _out) + _step_2 = tree_reduce(lambda x, y: x + y, _step_1) + f_ext = jnp.where(_step_2 == 0, elements.nodes.f_ext, _step_2) + # TODO: Return state instead of setting + return f_ext, "f_ext" + + @classmethod + def _apply_concentrated_nodal_forces( + self, nf_ext, concentrated_forces, curr_time: float + ): + """Apply concentrated nodal forces. + + Parameters + ---------- + particles: _ParticlesState + Particles in the simulation. + curr_time: float + Current time in the simulation. + """ + + def _func(cnf, f_ext): + factor = cnf.function.value(curr_time) + f_ext = f_ext.at[cnf.node_ids, 0, cnf.dir].add(factor * cnf.force) + return f_ext + + _out = tree_map( + _func, + concentrated_forces, + [nf_ext] * len(concentrated_forces), + is_leaf=lambda x: isinstance(x, NodalForce) or isinstance(x, Array), ) + + def _f(x, *, orig): + return jnp.where(x == orig, 0, x) + + # This assumes that the nodal forces are not overlapping, i.e. + # no node will be acted by 2 forces in the same direction. 
+ _step_1 = tree_map(partial(_f, orig=nf_ext), _out) + _step_2 = tree_reduce(lambda x, y: x + y, _step_1) + f_ext = jnp.where(_step_2 == 0, nf_ext, _step_2) + return f_ext + + @classmethod + def _apply_particle_traction_forces( + cls, pxi, mapped_node_ids, nf_ext, ptraction, pids + ): + """Apply concentrated nodal forces. + + Parameters + ---------- + particles: Particles + Particles in the simulation. + """ + + @jit + def _step(pid, args): + f_ext, ptraction, mapped_pos, el_nodes = args + f_ext = f_ext.at[el_nodes[pid]].add(mapped_pos[pid] @ ptraction[pid]) + return f_ext, ptraction, mapped_pos, el_nodes + + def _scan_step(carry, pid): + f_ext, ptraction, mapped_pos, el_nodes = carry + f_ext = f_ext.at[el_nodes[pid]].add(mapped_pos[pid] @ ptraction[pid]) + return (f_ext, ptraction, mapped_pos, el_nodes), pid + + mapped_positions = cls._shapefn(pxi) + mapped_nodes = mapped_node_ids.squeeze(-1) args = ( - self.nodes.f_int, - particles.volume, - mapped_grads, + nf_ext, + ptraction, + mapped_positions, mapped_nodes, - particles.stress, ) - self.nodes.f_int, _, _, _, _ = lax.fori_loop(0, len(particles), _step, args) + # f_ext, _, _, _ = lax.fori_loop(0, nparticles, _step, args) + final_carry, _ = lax.scan(_scan_step, args, pids) + f_ext, _, _, _ = final_carry + return f_ext + + def apply_particle_traction_forces(self, elements, particles: _ParticlesState): + """Apply concentrated nodal forces. + + Parameters + ---------- + particles: Particles + Particles in the simulation. + """ + + @jit + def _step(pid, args): + f_ext, ptraction, mapped_pos, el_nodes = args + f_ext = f_ext.at[el_nodes[pid]].add(mapped_pos[pid] @ ptraction[pid]) + return f_ext, ptraction, mapped_pos, el_nodes + + mapped_positions = self.shapefn(particles.reference_loc) + mapped_nodes = vmap(Partial(self.id_to_node_ids, elements.nelements[0]))( + particles.element_ids + ).squeeze(-1) + args = ( + elements.nodes.f_ext, + particles.traction, + mapped_positions, + mapped_nodes, + ) + f_ext, _, _, _ = lax.fori_loop(0, particles.nparticles, _step, args) + # TODO: Return state instead of setting + return f_ext, "f_ext" + + def update_nodal_acceleration( + self, elements, particles: _ParticlesState, dt: float, *args + ): + """Update the nodal momentum based on total force on nodes.""" + total_force = ( + elements.nodes.f_int + elements.nodes.f_ext + elements.nodes.f_damp + ) + acceleration = elements.nodes.acceleration.at[:].set( + jnp.nan_to_num(jnp.divide(total_force, elements.nodes.mass)) + ) + if elements.constraints: + acceleration = self._apply_boundary_constraints_acc(elements, acceleration) + acceleration = jnp.where( + jnp.abs(acceleration) < 1e-12, + 0, + acceleration, + ) + return acceleration, "acceleration" + + @classmethod + def _update_nodal_acceleration( + cls, + total_force, + nacc, + nmass, + constraints, + tol, + ): + """Update the nodal momentum based on total force on nodes.""" + acceleration = jnp.nan_to_num(jnp.divide(total_force, nmass)) + if constraints: + acceleration = cls._apply_boundary_constraints_acc( + constraints, acceleration + ) + acceleration = jnp.where( + jnp.abs(acceleration) < tol, + 0, + acceleration, + ) + return acceleration + + def update_nodal_velocity( + self, elements, particles: _ParticlesState, dt: float, *args + ): + """Update the nodal momentum based on total force on nodes.""" + total_force = ( + elements.nodes.f_int + elements.nodes.f_ext + elements.nodes.f_damp + ) + acceleration = jnp.nan_to_num(jnp.divide(total_force, elements.nodes.mass)) + + velocity = elements.nodes.velocity 
+ acceleration * dt + if elements.constraints: + velocity = self._apply_boundary_constraints_vel(elements, velocity) + velocity = jnp.where( + jnp.abs(velocity) < 1e-12, + 0, + velocity, + ) + return velocity, "velocity" + + @classmethod + def _update_nodal_velocity(cls, total_force, nvel, nmass, constraints, dt, tol): + """Update the nodal momentum based on total force on nodes.""" + acceleration = jnp.nan_to_num(jnp.divide(total_force, nmass)) + + velocity = nvel + acceleration * dt + if constraints: + velocity = cls._apply_boundary_constraints_vel(constraints, velocity) + velocity = jnp.where( + jnp.abs(velocity) < tol, + 0, + velocity, + ) + return velocity + + def update_nodal_momentum( + self, elements, particles: _ParticlesState, dt: float, *args + ): + """Update the nodal momentum based on total force on nodes.""" + momentum = elements.nodes.momentum.at[:].set( + elements.nodes.mass * elements.nodes.velocity + ) + momentum = jnp.where( + jnp.abs(momentum) < 1e-12, + 0, + momentum, + ) + return momentum, "momentum" + + @classmethod + def _update_nodal_momentum(cls, nmass, nvel, constraints, tol): + """Update the nodal momentum based on total force on nodes.""" + momentum = nmass * nvel + momentum = jnp.where( + jnp.abs(momentum) < tol, + 0, + momentum, + ) + return momentum + + @classmethod + def _apply_boundary_constraints_vel(cls, constraints, vel, *args): + """Apply boundary conditions for nodal velocity.""" + + # This assumes that the constraints don't have overlapping + # conditions. In case it does, only the first constraint will + # be applied. + def _func2(constraint, *, orig): + return constraint[1].apply_vel(orig, constraint[0]) + + partial_func = partial(_func2, orig=vel) + _out = tree_map( + partial_func, constraints, is_leaf=lambda x: isinstance(x, tuple) + ) + + def _f(x, *, orig): + return jnp.where(x == orig, jnp.nan, x) + + _pf = partial(_f, orig=vel) + _step_1 = tree_map(_pf, _out) + vel = tree_reduce( + lambda x, y: jnp.where(jnp.isnan(y), x, y), + [vel, _step_1], + ) + return vel + + @classmethod + def _apply_boundary_constraints_mom(cls, constraints, mom, mass, *args): + """Apply boundary conditions for nodal momentum.""" + + # This assumes that the constraints don't have overlapping + # conditions. In case it does, only the first constraint will + # be applied. + def _func2(constraint, *, mom, mass): + return constraint[1].apply_mom(mom, mass, constraint[0]) + + partial_func = partial(_func2, mom=mom, mass=mass) + _out = tree_map( + partial_func, constraints, is_leaf=lambda x: isinstance(x, tuple) + ) + + def _f(x, *, orig): + return jnp.where(x == orig, jnp.nan, x) + + _pf = partial(_f, orig=mom) + _step_1 = tree_map(_pf, _out) + mom = tree_reduce( + lambda x, y: jnp.where(jnp.isnan(y), x, y), + [mom, _step_1], + ) + return mom + + @classmethod + def _apply_boundary_constraints_acc(cls, constraints, orig, *args): + """Apply boundary conditions for nodal acceleration.""" + + # This assumes that the constraints don't have overlapping + # conditions. In case it does, only the first constraint will + # be applied. 
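+        # The merge below works as follows: each constraint produces a candidate
+        # array via apply_acc, entries a constraint did not change are masked to
+        # NaN, and tree_reduce keeps the constrained (non-NaN) value per entry,
+        # falling back to the original acceleration everywhere else.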
+ def _func2(constraint, *, orig): + return constraint[1].apply_acc(orig, constraint[0]) + + partial_func = partial(_func2, orig=orig) + _out = tree_map( + partial_func, constraints, is_leaf=lambda x: isinstance(x, tuple) + ) + + def _f(x, *, orig): + return jnp.where(x == orig, jnp.nan, x) + + _pf = partial(_f, orig=orig) + _step_1 = tree_map(_pf, _out) + acc = tree_reduce( + lambda x, y: jnp.where(jnp.isnan(y), x, y), + [orig, _step_1], + ) + return acc + + @classmethod + def _apply_boundary_constraints(cls, nvel, nmom, nacc, nmass, constraints, *args): + if constraints: + vel = cls._apply_boundary_constraints_vel(constraints, nvel, *args) + mom = cls._apply_boundary_constraints_mom(constraints, nmom, nmass, *args) + acc = cls._apply_boundary_constraints_acc(constraints, nacc, *args) + return vel, mom, acc + + def apply_boundary_constraints(self, elements, *args): + if elements.constraints: + vel = self._apply_boundary_constraints_vel( + elements, elements.nodes.velocity, *args + ) + mom = self._apply_boundary_constraints_mom( + elements, elements.nodes.momentum, elements.nodes.mass, *args + ) + acc = self._apply_boundary_constraints_acc( + elements, elements.nodes.acceleration, *args + ) + + return elements.nodes.replace(velocity=vel, momentum=mom, acceleration=acc) + + @classmethod + def _compute_volume(cls, el_len, evol): + """Compute volume of all elements.""" + a = c = el_len[1] + b = d = el_len[0] + p = q = jnp.sqrt(a**2 + b**2) + vol = 0.25 * jnp.sqrt(4 * p * p * q * q - (a * a + c * c - b * b - d * d) ** 2) + volume = jnp.ones_like(evol) * vol + return volume - def compute_volume(self, *args): + def compute_volume(self, elements, *args): """Compute volume of all elements.""" - a = c = self.el_len[1] - b = d = self.el_len[0] + a = c = elements.el_len[1] + b = d = elements.el_len[0] p = q = jnp.sqrt(a**2 + b**2) vol = 0.25 * jnp.sqrt(4 * p * p * q * q - (a * a + c * c - b * b - d * d) ** 2) - self.volume = self.volume.at[:].set(vol) + volume = jnp.ones_like(elements.volume) * vol + return elements.replace(volume=volume) diff --git a/diffmpm/explicit.py b/diffmpm/explicit.py new file mode 100644 index 0000000..800c8ef --- /dev/null +++ b/diffmpm/explicit.py @@ -0,0 +1,527 @@ +import abc +from dataclasses import dataclass +from functools import partial +from typing import Callable, NamedTuple, Optional, Sequence + +from jax import Array, vmap +from jax.tree_util import tree_map, tree_reduce, tree_structure, tree_transpose + +from diffmpm.element import Quad4N, _ElementsState +from diffmpm.forces import ParticleTraction +from diffmpm.node import _reset_node_props +from diffmpm.particle import ( + _assign_traction, + _compute_particle_volume, + _compute_strain, + _compute_stress, + _get_natural_coords, + _ParticlesState, + _update_particle_position_velocity, + _update_particle_volume, + _zero_traction, +) + + +class MeshState(NamedTuple): + elements: _ElementsState + particles: _ParticlesState + particle_tractions: Sequence[ParticleTraction] + + @classmethod + def _apply_traction_on_particles( + cls, particles, particle_tractions, curr_time: float + ): + """Apply tractions on particles. + + Parameters + ---------- + curr_time: float + Current time in the simulation. 
+ """ + pass + + +class Solver(abc.ABC): + @abc.abstractmethod + def init_state(*args, **kwargs): + pass + + @abc.abstractmethod + def update(*args, **kwargs): + pass + + def run(*args, **kwargs): + pass + + +def _reduce_attr(state_1, state_2, *, orig): + new_val = state_1 + state_2 - orig + return new_val + + +def _tree_transpose(pytree): + _out = tree_transpose( + tree_structure([0 for e in pytree]), tree_structure(pytree[0]), pytree + ) + return _out + + +@dataclass(eq=False) +class ExplicitSolver(Solver): + el_type: Quad4N + tol: float + dt: float + sim_steps: int + out_steps: int + out_dir: str + gravity: Array + scheme: str = "usf" + velocity_update: Optional[bool] = False + writer_func: Optional[Callable] = None + + def init_state(self, config): + elements = config["elements"] + particles = config["particles"] + new_peids: list = tree_map( + self.el_type._get_particles_element_ids, + particles, + [elements] * len(particles), + is_leaf=lambda x: isinstance(x, _ParticlesState) + or isinstance(x, _ElementsState), + ) + new_evol = self.el_type._compute_volume(elements.el_len, elements.volume) + + temp_pprops = tree_map( + _compute_particle_volume, + new_peids, + [self.el_type.total_elements] * len(particles), + [new_evol] * len(particles), + [p.volume for p in particles], + [p.size for p in particles], + [p.mass for p in particles], + [p.density for p in particles], + ) + new_pprops = _tree_transpose(temp_pprops) + elements = elements.replace(volume=new_evol) + particles = [ + p.replace( + element_ids=new_peids, + mass=new_pprops["mass"][i], + size=new_pprops["size"][i], + volume=new_pprops["volume"][i], + ) + for i, p in enumerate(particles) + ] + return MeshState( + elements=elements, + particles=particles, + particle_tractions=config["particle_surface_traction"], + ) + + def update(self, state: MeshState, step, *args, **kwargs): + _elements, _particles = state.elements, state.particles + # Nodal properties that are to be reset at the beginning of the + # update step. + new_nmass, new_nmom, new_nfint, new_nfext, new_nfdamp = _reset_node_props( + _elements.nodes + ) + + # New Element IDs for particles in each particle set. + # This is a `tree_map` function so that each particle set gets + # new EIDs. + new_peids: list = tree_map( + self.el_type._get_particles_element_ids, + _particles, + [_elements] * len(_particles), + is_leaf=lambda x: isinstance(x, _ParticlesState) + or isinstance(x, _ElementsState), + ) + map_fn = vmap(self.el_type._get_mapped_nodes, (0, None)) + new_pmapped_node_ids = tree_map( + map_fn, + new_peids, + [_elements.nelements[0]] * len(_particles), + is_leaf=lambda x: isinstance(x, _ParticlesState) + or isinstance(x, _ElementsState), + ) + + # New natural coordinates of the particles. + # This is again a `tree_map`-ed function for each particle set. + # The signature of the function is + # `get_natural_coords(particles.loc, elements)` + # Attributes required: + # - Element IDs of the particles + # - Nodal coords of the elements corresponding to the above element ids. 
+ def _leaf_fn(x): + return isinstance(x, _ParticlesState) or isinstance(x, Array) + + new_pxi = tree_map( + _get_natural_coords, + _particles, + new_pmapped_node_ids, + [_elements.nodes.loc] * len(_particles), + is_leaf=_leaf_fn, + ) + + # New nodal mass based on particle mass + # Required: + # - Nodal mass (new_nmass) + # - Particle natural coords (new_pxi) + # - Mapped nodes + # - Particle element IDs (new_peids) (list) + # new_nmass = self.el_type._compute_nodal_mass(new_nmass, new_pxi, new_peids) + temp_nmass = tree_map( + self.el_type._compute_nodal_mass, + [new_nmass] * len(_particles), + [p.mass for p in _particles], + new_pxi, + new_peids, + new_pmapped_node_ids, + [p.ids for p in _particles], + is_leaf=lambda x: isinstance(x, _ParticlesState) + or isinstance(x, Array) + or isinstance(x, int), + ) + partial_reduce_attr = partial(_reduce_attr, orig=new_nmass) + new_nmass = tree_reduce(partial_reduce_attr, temp_nmass) + + # New nodal momentum based on particle momentum + # Required: + # - Nodal momentum (new_nmom) + # - Particle natural coords (new_pxi) + # - Mapped nodes + # - Particle element IDs (new_peids) (list) + # new_nmom = _compute_nodal_momentum(new_nmom, new_xi, new_peids) + temp_nmom = tree_map( + self.el_type._compute_nodal_momentum, + [new_nmom] * len(_particles), + [p.mass for p in _particles], + [p.velocity for p in _particles], + new_pxi, + new_peids, + new_pmapped_node_ids, + [p.ids for p in _particles], + is_leaf=lambda x: isinstance(x, _ParticlesState) + or isinstance(x, Array) + or isinstance(x, int), + ) + partial_reduce_attr = partial(_reduce_attr, orig=new_nmom) + new_nmom = tree_reduce(partial_reduce_attr, temp_nmom) + + # New nodal velocity based on nodal momentum + # Required: + # - Nodal mass (new_nmass) + # - Current nodal velocity (_elements.nodes.velocity) + # - Nodal momentum (new_nmom) + # - Tolerance (tol) + # new_nvel = _compute_nodal_velocity( + # new_nmass, new_nmom, _elements.nodes.velocity, self.tol + # ) + temp_nvel = tree_map( + self.el_type._compute_nodal_velocity, + new_nmass, + new_nmom, + _elements.nodes.velocity, + ) + partial_reduce_attr = partial(_reduce_attr, orig=_elements.nodes.velocity) + new_nvel = tree_reduce(partial_reduce_attr, temp_nvel) + + # Apply boundary constraints on newly calculated props. + # Since nodal acceleration hasn't been updated yet, we + # use the current states nodal acceleration. 
+ # Required: + # - Constraints (_elements.constraints) + # - Nodal velocity (new_nvel) + # - Nodal momentum (new_nmom) + # - Nodal acceleration (new_nacc) + new_nvel, new_nmom, new_nacc = self.el_type._apply_boundary_constraints( + new_nvel, + new_nmom, + _elements.nodes.acceleration, + new_nmass, + _elements.constraints, + ) + + if self.scheme == "usf": + # Compute particle strain + # Required: + # - Mapped node ids + # - Mapped node locs + # - Mapped node vels + # - Particle natural coords (new_pxi) + # - Current particle strains (_particles.strain) + # - Particles locs + # - Particle volumetric strains (_particles.volumetric_strain_centroid) + _temp = tree_map( + _compute_strain, + [p.strain for p in _particles], + new_pxi, + [p.loc for p in _particles], + [p.volumetric_strain_centroid for p in _particles], + [p.ids for p in _particles], + new_pmapped_node_ids, + [_elements.nodes.loc] * len(_particles), + [new_nvel] * len(_particles), + [self.el_type] * len(_particles), + [self.dt] * len(_particles), + is_leaf=lambda x: isinstance(x, _ParticlesState) + or isinstance(x, Quad4N) + or isinstance(x, Array) + or isinstance(x, float), + ) + + _strains = _tree_transpose(_temp) + new_pstrain_rate = _strains["strain_rate"] + new_pdstrain = _strains["dstrain"] + new_pstrain = _strains["strain"] + new_pdvolumetric_strain = _strains["dvolumetric_strain"] + new_pvolumetric_strain_centroid = _strains["volumetric_strain_centroid"] + + # Compute new particle volumes based on updated strain + # Required: + # - Particle volumetric dstrain (new_pdvolumetric_strain) + # new_pvol, new_pdensity = _update_particle_volume(new_pdvolumetric_strain) + _temp = tree_map( + _update_particle_volume, + [p.volume for p in _particles], + [p.density for p in _particles], + new_pdvolumetric_strain, + ) + + new_pvol, new_pdensity = _tree_transpose(_temp) + # Compute particle stress + # Required: + # - Particle state since different materials need different + # particle properties to calculate stress. 
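+        # For the linear elastic material in this package, compute_stress returns
+        # the increment de @ dstrain, where de is the 6x6 elastic tensor assembled
+        # in init_linear_elastic (a1 = K + 4G/3 on the diagonal, a2 = K - 2G/3
+        # off-diagonal, G on the shear terms). Other materials, e.g. Newtonian,
+        # additionally use their own state variables.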
+        # new_pstress = _compute_stress(_particles)
+        new_pstress = tree_map(
+            _compute_stress,
+            [p.stress for p in _particles],
+            new_pstrain,
+            new_pdstrain,
+            [p.material for p in _particles],
+            is_leaf=lambda x: isinstance(x, _ParticlesState),
+        )
+
+        # Compute external forces on nodes
+        # Required:
+        # - Nodal external forces (new_nfext)
+        # - Particle natural coords (new_pxi)
+        # - Mapped Node ids
+        # new_nfext = self.el_type._compute_external_force(new_nfext, new_pxi, *args)
+        temp_nfext = tree_map(
+            self.el_type._compute_external_force,
+            [new_nfext] * len(_particles),
+            [p.f_ext for p in _particles],
+            new_pxi,
+            [p.ids for p in _particles],
+            new_pmapped_node_ids,
+        )
+        partial_reduce_attr = partial(_reduce_attr, orig=new_nfext)
+        new_nfext = tree_reduce(partial_reduce_attr, temp_nfext)
+
+        # Compute body forces on nodes
+        # Required:
+        # - Nodal external forces (new_nfext)
+        # - Particle natural coords (new_pxi)
+        # - Mapped Node ids
+        # - gravity
+        # new_nfext = _compute_body_force(new_nfext, new_pxi, gravity, *args)
+        temp_nfext = tree_map(
+            self.el_type._compute_body_force,
+            [new_nfext] * len(_particles),
+            [p.mass for p in _particles],
+            new_pxi,
+            new_pmapped_node_ids,
+            [p.ids for p in _particles],
+            [self.gravity] * len(_particles),
+        )
+        partial_reduce_attr = partial(_reduce_attr, orig=new_nfext)
+        new_nfext = tree_reduce(partial_reduce_attr, temp_nfext)
+
+        # TODO: Apply traction on particles
+        new_ptraction = tree_map(
+            _zero_traction,
+            [p.traction for p in _particles],
+            is_leaf=lambda x: isinstance(x, _ParticlesState),
+        )
+
+        def func(ptract_, ptraction, pvol, psize, *, curr_time):
+            def f(ptraction, pvol, psize, *, ptract_, traction_val):
+                return _assign_traction(
+                    ptraction, pvol, psize, ptract_.pids, ptract_.dir, traction_val
+                )
+
+            factor = ptract_.function.value(curr_time)
+            traction_val = factor * ptract_.traction
+            partial_f = partial(f, ptract_=ptract_, traction_val=traction_val)
+            traction_sets = tree_map(
+                partial_f,
+                ptraction,
+                pvol,
+                psize,
+                is_leaf=lambda x: isinstance(x, _ParticlesState),
+            )
+            return tuple(traction_sets)
+
+        partial_func = partial(
+            func, ptract_=state.particle_tractions, curr_time=step * self.dt
+        )
+        if state.particle_tractions:
+            _out = tree_map(
+                partial_func,
+                state.particle_tractions,
+                new_ptraction,
+                new_pvol,
+                [p.size for p in _particles],
+                is_leaf=lambda x: isinstance(x, ParticleTraction)
+                or isinstance(x, Array),
+            )
+            _temp = _tree_transpose(_out)
+            new_ptraction = tree_reduce(
+                lambda x, y: x + y, _temp, is_leaf=lambda x: isinstance(x, list)
+            )
+
+        temp_nfext = tree_map(
+            self.el_type._apply_particle_traction_forces,
+            new_pxi,
+            new_pmapped_node_ids,
+            [new_nfext] * len(_particles),
+            new_ptraction,
+            [p.ids for p in _particles],
+        )
+        partial_reduce_attr = partial(_reduce_attr, orig=new_nfext)
+        new_nfext = tree_reduce(partial_reduce_attr, temp_nfext)
+
+        # Apply nodal concentrated forces
+        # Required:
+        # - Concentrated forces on nodes (_elements.concentrated_nodal_forces)
+        # - Nodal external forces (new_nfext)
+        # - current time
+        if _elements.concentrated_nodal_forces:
+            new_nfext = self.el_type._apply_concentrated_nodal_forces(
+                new_nfext, _elements.concentrated_nodal_forces, self.dt * step
+            )
+        # Compute internal forces on nodes
+        # Required:
+        # - Mapped node ids
+        # - Mapped node locs
+        # - Nodal internal forces (new_nfint)
+        # - Particle natural coords (new_pxi)
+        # - Particle volume (new_pvol)
+        # - Particle stress (new_pstress)
+        temp_nfint = tree_map(
+
self.el_type._compute_internal_force, + [new_nfint] * len(_particles), + [_elements.nodes.loc] * len(_particles), + new_pmapped_node_ids, + new_pxi, + new_pvol, + new_pstress, + [p.ids for p in _particles], + ) + partial_reduce_attr = partial(_reduce_attr, orig=new_nfint) + new_nfint = tree_reduce(partial_reduce_attr, temp_nfint) + + if self.scheme == "usl": + # TODO: Calculate strains and stresses + pass + + # Update nodal acceleration based on nodal forces + # Required: + # - Nodal forces (new_nfint, new_nfext, new_nfdamp) + # - Nodal mass + # - Constraints (_elements.constraints) + # - Tolerance (self.tol) + total_force = new_nfint + new_nfext + new_nfdamp + new_nacc = self.el_type._update_nodal_acceleration( + total_force, new_nacc, new_nmass, _elements.constraints, self.tol + ) + # Update nodal acceleration based on nodal forces + # Required: + # - Nodal forces (new_nfint, new_nfext, new_nfdamp) + # - Nodal mass + # - Constraints (_elements.constraints) + # - Tolerance (self.tol) + new_nvel = self.el_type._update_nodal_velocity( + total_force, new_nvel, new_nmass, _elements.constraints, self.dt, self.tol + ) + + # Update nodal momentum based on nodal forces + # Required: + # - Nodal mass (new_nmass) + # - Nodal velocity (new_nvel) + # - Tolerance (self.tol) + new_nmom = self.el_type._update_nodal_momentum( + new_nmass, new_nvel, _elements.constraints, self.tol + ) + + # Update particle position and velocity + # Required: + # - Particle natural coords (new_pxi) + # - Timestep (self.dt) + # - self.velocity_update + # - Mapped node ids + # - Mapped node vels + # - Mapped node accelerations + # - Particle locs + _temp_new_vals = tree_map( + _update_particle_position_velocity, + [self.el_type] * len(_particles), + [p.loc for p in _particles], + [p.velocity for p in _particles], + [p.momentum for p in _particles], + [p.mass for p in _particles], + new_pxi, + new_pmapped_node_ids, + [new_nvel] * len(_particles), + [new_nacc] * len(_particles), + [self.velocity_update] * len(_particles), + [self.dt] * len(_particles), + is_leaf=lambda x: isinstance(x, Array) + or isinstance(x, Quad4N) + or isinstance(x, float) + or isinstance(x, bool), + ) + _new_vals = _tree_transpose(_temp_new_vals) + new_pvel = _new_vals["velocity"] + new_ploc = _new_vals["loc"] + new_pmom = _new_vals["momentum"] + + new_node_state = _elements.nodes.replace( + velocity=new_nvel, + acceleration=new_nacc, + mass=new_nmass, + momentum=new_nmom, + f_int=new_nfint, + f_ext=new_nfext, + f_damp=new_nfdamp, + ) + new_element_state = _elements.replace(nodes=new_node_state) + new_particle_states = [ + _p.replace( + loc=new_ploc[i], + element_ids=new_peids[i], + density=new_pdensity[i], + volume=new_pvol[i], + velocity=new_pvel[i], + momentum=new_pmom[i], + strain=new_pstrain[i], + stress=new_pstress[i], + strain_rate=new_pstrain_rate[i], + dstrain=new_pdstrain[i], + reference_loc=new_pxi[i], + dvolumetric_strain=new_pdvolumetric_strain[i], + volumetric_strain_centroid=new_pvolumetric_strain_centroid[i], + ) + for i, _p in enumerate(_particles) + ] + + new_mesh_state = MeshState( + elements=new_element_state, + particles=new_particle_states, + particle_tractions=state.particle_tractions, + ) + return new_mesh_state diff --git a/diffmpm/functions.py b/diffmpm/functions.py index 90b55c4..ef6b3e1 100644 --- a/diffmpm/functions.py +++ b/diffmpm/functions.py @@ -22,7 +22,7 @@ def value(self, x): return 1.0 def tree_flatten(self): - return ((), (self.id)) + return ((), (self.id,)) @classmethod def tree_unflatten(cls, aux_data, children): 
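The one-character change in the functions.py hunk above matters because (self.id) is just self.id, not a one-element tuple, so unflattening code that unpacks aux_data would fail. A minimal, self-contained sketch of the same pytree round trip (the ExampleFn class below is illustrative only, not part of diffmpm):

    import jax
    from jax.tree_util import register_pytree_node_class

    @register_pytree_node_class
    class ExampleFn:
        def __init__(self, id):
            self.id = id

        def tree_flatten(self):
            # aux_data must be a tuple; (self.id) would just be the bare int
            return ((), (self.id,))

        @classmethod
        def tree_unflatten(cls, aux_data, children):
            del children
            return cls(*aux_data)  # unpacking here breaks if aux_data is a bare int

    leaves, treedef = jax.tree_util.tree_flatten(ExampleFn(3))
    assert jax.tree_util.tree_unflatten(treedef, leaves).id == 3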
diff --git a/diffmpm/io.py b/diffmpm/io.py index d6e4573..92330f8 100644 --- a/diffmpm/io.py +++ b/diffmpm/io.py @@ -1,21 +1,23 @@ import json import tomllib as tl -from collections import namedtuple import jax.numpy as jnp from diffmpm import element as mpel -from diffmpm import material as mpmat -from diffmpm import mesh as mpmesh +from diffmpm import materials as mpmat + +# from diffmpm import mesh as mpmesh from diffmpm.constraint import Constraint from diffmpm.forces import NodalForce, ParticleTraction from diffmpm.functions import Linear, Unit -from diffmpm.particle import Particles +from diffmpm.particle import _ParticlesState, init_particle_state +from pathlib import Path class Config: def __init__(self, filepath): - self._filepath = filepath + self._filepath = Path(filepath).absolute() + self._basedir = self._filepath.parent self.parsed_config = {} self.parse() @@ -23,7 +25,9 @@ def parse(self): with open(self._filepath, "rb") as f: self._fileconfig = tl.load(f) - self.entity_sets = json.load(open(self._fileconfig["mesh"]["entity_sets"])) + self.entity_sets = json.load( + open(self._basedir.joinpath(self._fileconfig["mesh"]["entity_sets"])) + ) self._parse_meta(self._fileconfig) self._parse_output(self._fileconfig) self._parse_materials(self._fileconfig) @@ -32,7 +36,8 @@ def parse(self): self._parse_math_functions(self._fileconfig) self._parse_external_loading(self._fileconfig) mesh = self._parse_mesh(self._fileconfig) - return mesh + # return mesh + return self.parsed_config def _get_node_set_ids(self, set_ids): all_ids = [] @@ -59,20 +64,22 @@ def _parse_materials(self, config): materials = [] for mat_config in config["materials"]: mat_type = mat_config.pop("type") - mat_cls = getattr(mpmat, mat_type) - mat = mat_cls(mat_config) - materials.append(mat) + # mat_cls = getattr(mpmat, mat_type) + # mat = mat_cls(mat_config) + mat_fun = getattr(mpmat, f"init_{mat_type}") + materials.append(mat_fun(mat_config)) self.parsed_config["materials"] = materials def _parse_particles(self, config): particle_sets = [] for pset_config in config["particles"]: pmat = self.parsed_config["materials"][pset_config["material_id"]] - with open(pset_config["file"], "r") as f: + with open(self._basedir.joinpath(pset_config["file"]), "r") as f: ploc = jnp.asarray(json.load(f)) peids = jnp.zeros(ploc.shape[0], dtype=jnp.int32) - pset = Particles(ploc, pmat, peids) - pset.velocity = pset.velocity.at[:].set(pset_config["init_velocity"]) + pset = init_particle_state( + ploc, pmat, peids, init_vel=jnp.asarray(pset_config["init_velocity"]) + ) particle_sets.append(pset) self.parsed_config["particles"] = particle_sets @@ -140,20 +147,24 @@ def _parse_external_loading(self, config): def _parse_mesh(self, config): element_cls = getattr(mpel, config["mesh"]["element"]) - mesh_cls = getattr(mpmesh, f"Mesh{config['meta']['dimension']}D") - - constraints = [ - ( - self._get_node_set_ids(c["nset_ids"]), - Constraint(c["dir"], c["velocity"]), - ) - for c in config["mesh"]["constraints"] - ] + # mesh_cls = getattr(mpmesh, f"Mesh{config['meta']['dimension']}D") + + constraints = [] + if "constraints" in config["mesh"]: + constraints = [ + ( + self._get_node_set_ids(c["nset_ids"]), + Constraint(c["dir"], c["velocity"]), + ) + for c in config["mesh"]["constraints"] + ] if config["mesh"]["type"] == "generator": - elements = element_cls( + total_elements = jnp.prod(jnp.array(config["mesh"]["nelements"])) + elementor = element_cls(total_elements=total_elements) + elements = elementor.init_state( config["mesh"]["nelements"], - 
jnp.prod(jnp.array(config["mesh"]["nelements"])), + total_elements, config["mesh"]["element_length"], constraints, concentrated_nodal_forces=self.parsed_config["external_loading"][ @@ -166,5 +177,6 @@ def _parse_mesh(self, config): ) self.parsed_config["elements"] = elements - mesh = mesh_cls(self.parsed_config) - return mesh + self.parsed_config["elementor"] = elementor + # mesh = mesh_cls(self.parsed_config) + # return mesh diff --git a/diffmpm/material.py b/diffmpm/material.py deleted file mode 100644 index 09230d4..0000000 --- a/diffmpm/material.py +++ /dev/null @@ -1,131 +0,0 @@ -import abc -from typing import Tuple - -import jax.numpy as jnp -from jax.tree_util import register_pytree_node_class - - -class Material(abc.ABC): - """Base material class.""" - - _props: Tuple[str, ...] - - def __init__(self, material_properties): - """Initialize material properties. - - Parameters - ---------- - material_properties: dict - A key-value map for various material properties. - """ - self.properties = material_properties - - # @abc.abstractmethod - def tree_flatten(self): - """Flatten this class as PyTree Node.""" - return (tuple(), self.properties) - - # @abc.abstractmethod - @classmethod - def tree_unflatten(cls, aux_data, children): - """Unflatten this class as PyTree Node.""" - del children - return cls(aux_data) - - @abc.abstractmethod - def __repr__(self): - """Repr for Material class.""" - ... - - @abc.abstractmethod - def compute_stress(self): - """Compute stress for the material.""" - ... - - def validate_props(self, material_properties): - for key in self._props: - if key not in material_properties: - raise KeyError( - f"'{key}' should be present in `material_properties` " - f"for {self.__class__.__name__} materials." - ) - - -@register_pytree_node_class -class LinearElastic(Material): - """Linear Elastic Material.""" - - _props = ("density", "youngs_modulus", "poisson_ratio") - - def __init__(self, material_properties): - """Create a Linear Elastic material. - - Parameters - ---------- - material_properties: dict - Dictionary with material properties. For linear elastic - materials, 'density' and 'youngs_modulus' are required keys. 
- """ - self.validate_props(material_properties) - youngs_modulus = material_properties["youngs_modulus"] - poisson_ratio = material_properties["poisson_ratio"] - density = material_properties["density"] - bulk_modulus = youngs_modulus / (3 * (1 - 2 * poisson_ratio)) - constrained_modulus = ( - youngs_modulus - * (1 - poisson_ratio) - / ((1 + poisson_ratio) * (1 - 2 * poisson_ratio)) - ) - shear_modulus = youngs_modulus / (2 * (1 + poisson_ratio)) - # Wave velocities - vp = jnp.sqrt(constrained_modulus / density) - vs = jnp.sqrt(shear_modulus / density) - self.properties = { - **material_properties, - "bulk_modulus": bulk_modulus, - "pwave_velocity": vp, - "swave_velocity": vs, - } - self._compute_elastic_tensor() - - def __repr__(self): - return f"LinearElastic(props={self.properties})" - - def _compute_elastic_tensor(self): - G = self.properties["youngs_modulus"] / ( - 2 * (1 + self.properties["poisson_ratio"]) - ) - - a1 = self.properties["bulk_modulus"] + (4 * G / 3) - a2 = self.properties["bulk_modulus"] - (2 * G / 3) - - self.de = jnp.array( - [ - [a1, a2, a2, 0, 0, 0], - [a2, a1, a2, 0, 0, 0], - [a2, a2, a1, 0, 0, 0], - [0, 0, 0, G, 0, 0], - [0, 0, 0, 0, G, 0], - [0, 0, 0, 0, 0, G], - ] - ) - - def compute_stress(self, dstrain): - """Compute material stress.""" - dstress = self.de @ dstrain - return dstress - - -@register_pytree_node_class -class SimpleMaterial(Material): - _props = ("E", "density") - - def __init__(self, material_properties): - self.validate_props(material_properties) - self.properties = material_properties - - def __repr__(self): - return f"SimpleMaterial(props={self.properties})" - - def compute_stress(self, dstrain): - return dstrain * self.properties["E"] diff --git a/diffmpm/materials/__init__.py b/diffmpm/materials/__init__.py new file mode 100644 index 0000000..a7bc501 --- /dev/null +++ b/diffmpm/materials/__init__.py @@ -0,0 +1,8 @@ +from diffmpm.materials._base import _Material + +# from diffmpm.materials.linear_elastic import LinearElastic +from diffmpm.materials.linear_elastic import _LinearElasticState, init_linear_elastic +from diffmpm.materials.newtonian import Newtonian + +# from diffmpm.materials.simple import SimpleMaterial +from diffmpm.materials.simple import _SimpleMaterialState, init_simple diff --git a/diffmpm/materials/_base.py b/diffmpm/materials/_base.py new file mode 100644 index 0000000..896b206 --- /dev/null +++ b/diffmpm/materials/_base.py @@ -0,0 +1,49 @@ +import abc +from typing import Tuple + + +class _Material(abc.ABC): + """Base material class.""" + + _props: Tuple[str, ...] + properties: dict + + def __init__(self, material_properties): + """Initialize material properties. + + Parameters + ---------- + material_properties: dict + A key-value map for various material properties. + """ + self.properties = material_properties + + # @abc.abstractmethod + def tree_flatten(self): + """Flatten this class as PyTree Node.""" + return (tuple(), self.properties) + + # @abc.abstractmethod + @classmethod + def tree_unflatten(cls, aux_data, children): + """Unflatten this class as PyTree Node.""" + del children + return cls(aux_data) + + @abc.abstractmethod + def __repr__(self): + """Repr for Material class.""" + ... + + @abc.abstractmethod + def compute_stress(self, particles): + """Compute stress for the material.""" + ... 
+ + def validate_props(self, material_properties): + for key in self._props: + if key not in material_properties: + raise KeyError( + f"'{key}' should be present in `material_properties` " + f"for {self.__class__.__name__} materials." + ) diff --git a/diffmpm/materials/linear_elastic.py b/diffmpm/materials/linear_elastic.py new file mode 100644 index 0000000..6cca66f --- /dev/null +++ b/diffmpm/materials/linear_elastic.py @@ -0,0 +1,137 @@ +import jax.numpy as jnp +from jax.tree_util import register_pytree_node_class + +from ._base import _Material + +import chex + + +@chex.dataclass() +class _LinearElasticState: + id: int + state_vars: tuple + density: float + youngs_modulus: float + poisson_ratio: float + bulk_modulus: float + pwave_velocity: float + swave_velocity: float + de: chex.ArrayDevice + + def compute_stress(self, strain, dstrain): + """Compute material stress.""" + dstress = self.de @ dstrain + return dstress + + +def init_linear_elastic(material_properties): + """Create a Linear Elastic material. + + Parameters + ---------- + material_properties: dict + Dictionary with material properties. For linear elastic + materials, 'density' and 'youngs_modulus' are required keys. + """ + state_vars = () + youngs_modulus = material_properties["youngs_modulus"] + poisson_ratio = material_properties["poisson_ratio"] + density = material_properties["density"] + bulk_modulus = youngs_modulus / (3 * (1 - 2 * poisson_ratio)) + constrained_modulus = ( + youngs_modulus + * (1 - poisson_ratio) + / ((1 + poisson_ratio) * (1 - 2 * poisson_ratio)) + ) + shear_modulus = youngs_modulus / (2 * (1 + poisson_ratio)) + # Wave velocities + vp = jnp.sqrt(constrained_modulus / density) + vs = jnp.sqrt(shear_modulus / density) + properties = { + **material_properties, + "bulk_modulus": bulk_modulus, + "pwave_velocity": vp, + "swave_velocity": vs, + } + G = youngs_modulus / (2 * (1 + poisson_ratio)) + + a1 = bulk_modulus + (4 * G / 3) + a2 = bulk_modulus - (2 * G / 3) + + de = jnp.array( + [ + [a1, a2, a2, 0, 0, 0], + [a2, a1, a2, 0, 0, 0], + [a2, a2, a1, 0, 0, 0], + [0, 0, 0, G, 0, 0], + [0, 0, 0, 0, G, 0], + [0, 0, 0, 0, 0, G], + ] + ) + return _LinearElasticState(**properties, de=de, state_vars=state_vars) + + +# @register_pytree_node_class +# class LinearElastic(_Material): +# """Linear Elastic Material.""" + +# _props = ("density", "youngs_modulus", "poisson_ratio") +# state_vars = () + +# def __init__(self, material_properties): +# """Create a Linear Elastic material. + +# Parameters +# ---------- +# material_properties: dict +# Dictionary with material properties. For linear elastic +# materials, 'density' and 'youngs_modulus' are required keys. 
+# """ +# self.validate_props(material_properties) +# youngs_modulus = material_properties["youngs_modulus"] +# poisson_ratio = material_properties["poisson_ratio"] +# density = material_properties["density"] +# bulk_modulus = youngs_modulus / (3 * (1 - 2 * poisson_ratio)) +# constrained_modulus = ( +# youngs_modulus +# * (1 - poisson_ratio) +# / ((1 + poisson_ratio) * (1 - 2 * poisson_ratio)) +# ) +# shear_modulus = youngs_modulus / (2 * (1 + poisson_ratio)) +# # Wave velocities +# vp = jnp.sqrt(constrained_modulus / density) +# vs = jnp.sqrt(shear_modulus / density) +# self.properties = { +# **material_properties, +# "bulk_modulus": bulk_modulus, +# "pwave_velocity": vp, +# "swave_velocity": vs, +# } +# self._compute_elastic_tensor() + +# def __repr__(self): +# return f"LinearElastic(props={self.properties})" + +# def _compute_elastic_tensor(self): +# G = self.properties["youngs_modulus"] / ( +# 2 * (1 + self.properties["poisson_ratio"]) +# ) + +# a1 = self.properties["bulk_modulus"] + (4 * G / 3) +# a2 = self.properties["bulk_modulus"] - (2 * G / 3) + +# self.de = jnp.array( +# [ +# [a1, a2, a2, 0, 0, 0], +# [a2, a1, a2, 0, 0, 0], +# [a2, a2, a1, 0, 0, 0], +# [0, 0, 0, G, 0, 0], +# [0, 0, 0, 0, G, 0], +# [0, 0, 0, 0, 0, G], +# ] +# ) + +# def compute_stress(self, particles): +# """Compute material stress.""" +# dstress = self.de @ particles.dstrain +# return dstress diff --git a/diffmpm/materials/newtonian.py b/diffmpm/materials/newtonian.py new file mode 100644 index 0000000..558832f --- /dev/null +++ b/diffmpm/materials/newtonian.py @@ -0,0 +1,114 @@ +import jax.numpy as jnp +from jax import Array, lax +from jax.typing import ArrayLike + +from ._base import _Material + + +class Newtonian(_Material): + """Newtonian fluid material model.""" + + _props = ("density", "bulk_modulus", "dynamic_viscosity") + state_vars = ("pressure",) + + def __init__(self, material_properties: dict): + """Create a Newtonian material. + + Parameters + ---------- + material_properties: dict + Dictionary with material properties. For newtonian + materials, `density`, `bulk_modulus` and `dynamic_viscosity` + are required keys. + """ + self.validate_props(material_properties) + compressibility = 1 + + if material_properties.get("incompressible", False): + compressibility = 0 + + self.properties = { + **material_properties, + "compressibility": compressibility, + } + + def __repr__(self): + return f"Newtonian(props={self.properties})" + + def initialize_state_variables(self, nparticles: int) -> dict: + """Return initial state variables dictionary. + + Parameters + ---------- + nparticles : int + Number of particles being simulated with this material. + + Returns + ------- + dict + Dictionary of state variables initialized with values + decided by material type. 
+ """ + state_vars_dict = {var: jnp.zeros((nparticles, 1)) for var in self.state_vars} + return state_vars_dict + + def _thermodynamic_pressure(self, volumetric_strain: ArrayLike) -> Array: + return -self.properties["bulk_modulus"] * volumetric_strain + + def compute_stress(self, particles): + """Compute material stress.""" + ndim = particles.loc.shape[-1] + if ndim not in {2, 3}: + raise ValueError(f"Cannot compute stress for {ndim}-d Newotonian material.") + volumetric_strain_rate = ( + particles.strain_rate[:, 0] + particles.strain_rate[:, 1] + ) + particles.state_vars["pressure"] = ( + particles.state_vars["pressure"] + .at[:] + .add( + self.properties["compressibility"] + * self._thermodynamic_pressure(particles.dvolumetric_strain) + ) + ) + + volumetric_stress_component = self.properties["compressibility"] * ( + -particles.state_vars["pressure"] + - (2 * self.properties["dynamic_viscosity"] * volumetric_strain_rate / 3) + ) + + stress = jnp.zeros_like(particles.stress) + stress = stress.at[:, 0].set( + volumetric_stress_component + + 2 * self.properties["dynamic_viscosity"] * particles.strain_rate[:, 0] + ) + stress = stress.at[:, 1].set( + volumetric_stress_component + + 2 * self.properties["dynamic_viscosity"] * particles.strain_rate[:, 1] + ) + + extra_component_2 = lax.select( + ndim == 3, + 2 * self.properties["dynamic_viscosity"] * particles.strain_rate[:, 2], + jnp.zeros_like(particles.strain_rate[:, 2]), + ) + stress = stress.at[:, 2].set(volumetric_stress_component + extra_component_2) + + stress = stress.at[:, 3].set( + self.properties["dynamic_viscosity"] * particles.strain_rate[:, 3] + ) + + component_4 = lax.select( + ndim == 3, + self.properties["dynamic_viscosity"] * particles.strain_rate[:, 4], + jnp.zeros_like(particles.strain_rate[:, 4]), + ) + stress = stress.at[:, 4].set(component_4) + component_5 = lax.select( + ndim == 3, + self.properties["dynamic_viscosity"] * particles.strain_rate[:, 5], + jnp.zeros_like(particles.strain_rate[:, 5]), + ) + stress = stress.at[:, 5].set(component_5) + + return stress diff --git a/diffmpm/materials/simple.py b/diffmpm/materials/simple.py new file mode 100644 index 0000000..142ead6 --- /dev/null +++ b/diffmpm/materials/simple.py @@ -0,0 +1,36 @@ +from jax.tree_util import register_pytree_node_class + +from ._base import _Material + +import chex + + +@chex.dataclass() +class _SimpleMaterialState: + id: int + E: float + density: float + state_vars: () + + def compute_stress(self, strain, dstrain): + return dstrain * self.E + + +def init_simple(material_properties): + return _SimpleMaterialState(**material_properties, state_vars=()) + + +@register_pytree_node_class +class SimpleMaterial(_Material): + _props = ("E", "density") + state_vars = () + + def __init__(self, material_properties): + self.validate_props(material_properties) + self.properties = material_properties + + def __repr__(self): + return f"SimpleMaterial(props={self.properties})" + + def compute_stress(self, strain, dstrain): + return dstrain * self.properties["E"] diff --git a/diffmpm/mesh.py b/diffmpm/mesh.py index 23bc6de..dde457b 100644 --- a/diffmpm/mesh.py +++ b/diffmpm/mesh.py @@ -1,11 +1,17 @@ import abc +from functools import partial from typing import Callable, Sequence, Tuple import jax.numpy as jnp -from jax.tree_util import register_pytree_node_class +from jax import lax, jit, tree_util +from jax.tree_util import register_pytree_node_class, tree_map -from diffmpm.element import _Element -from diffmpm.particle import Particles +from diffmpm.element import 
_ElementsState +import diffmpm.element as dfel +from diffmpm.node import _NodesState +from diffmpm.particle import _ParticlesState +import diffmpm.particle as dpart +from diffmpm.forces import ParticleTraction __all__ = ["_MeshBase", "Mesh1D", "Mesh2D"] @@ -23,11 +29,12 @@ class _MeshBase(abc.ABC): def __init__(self, config: dict): """Initialize mesh using configuration.""" - self.particles: Sequence[Particles] = config["particles"] - self.elements: _Element = config["elements"] + self.particles: Sequence[_ParticlesState] = config["particles"] + self.elements: _ElementState = config["elements"] self.particle_tractions = config["particle_surface_traction"] + self.elementor = config["elementor"] - # TODO: Convert to using jax directives for loop + # TODO: Change to allow called functions to return outputs def apply_on_elements(self, function: str, args: Tuple = ()): """Apply a given function to elements. @@ -38,24 +45,69 @@ def apply_on_elements(self, function: str, args: Tuple = ()): args: tuple Parameters to be passed to the function. """ - f = getattr(self.elements, function) - for particle_set in self.particles: - f(particle_set, *args) + f = getattr(self.elementor, function) + + def _func(particles, *, func, elements, fargs): + return func(elements, particles, *fargs) + + partial_func = partial(_func, func=f, elements=self.elements, fargs=args) + _out = tree_map( + partial_func, + self.particles, + is_leaf=lambda x: isinstance(x, _ParticlesState), + ) + if function == "set_particle_element_ids": + self.particles = _out + elif function == "apply_boundary_constraints": + self.elements.nodes = _out[0] + elif function == "compute_volume": + self.elements = _out[0] + elif _out[0] is not None: + _temp = tree_util.tree_transpose( + tree_util.tree_structure([0 for e in _out]), + tree_util.tree_structure(_out[0]), + _out, + ) + + def reduce_attr(state_1, state_2, *, orig): + new_val = state_1 + state_2 - orig + return new_val + + attr = _temp[1][0] + p_reduce_attr = partial( + reduce_attr, + attr=attr, + orig=getattr(self.elements.nodes, attr), + ) + new_val = tree_util.tree_reduce( + p_reduce_attr, _temp[0], is_leaf=lambda x: isinstance(x, _NodesState) + ) + self.elements.nodes = self.elements.nodes.replace(**{attr: new_val}) - # TODO: Convert to using jax directives for loop def apply_on_particles(self, function: str, args: Tuple = ()): """Apply a given function to particles. Parameters ---------- function: str - A string corresponding to a function name in `Particles`. + A string corresponding to a function name in `_ParticlesState`. args: tuple Parameters to be passed to the function. """ - for particle_set in self.particles: - f = getattr(particle_set, function) - f(self.elements, *args) + + def _func(particles, *, elements, fname, fargs): + f = getattr(dpart, fname) + return f(particles, elements, *fargs) + + partial_func = partial( + _func, elements=self.elements, fname=function, fargs=(self.elementor, *args) + ) + new_states = tree_map( + partial_func, + self.particles, + is_leaf=lambda x: isinstance(x, _ParticlesState), + ) + self.particles = new_states def apply_traction_on_particles(self, curr_time: float): """Apply tractions on particles. @@ -66,19 +118,48 @@ def apply_traction_on_particles(self, curr_time: float): Current time in the simulation. 
""" self.apply_on_particles("zero_traction") - for ptraction in self.particle_tractions: - factor = ptraction.function.value(curr_time) - traction_val = factor * ptraction.traction - for i, pset_id in enumerate(ptraction.pset): - self.particles[pset_id].assign_traction( - ptraction.pids, ptraction.dir, traction_val + + def func(ptraction, *, particle_sets): + def f(particles, *, ptraction, traction_val): + return dpart.assign_traction( + particles, ptraction.pids, ptraction.dir, traction_val ) - self.apply_on_elements("apply_particle_traction_forces") + factor = ptraction.function.value(curr_time) + traction_val = factor * ptraction.traction + partial_f = partial(f, ptraction=ptraction, traction_val=traction_val) + traction_sets = tree_map( + partial_f, + particle_sets, + is_leaf=lambda x: isinstance(x, _ParticlesState), + ) + return tuple(traction_sets) + + partial_func = partial(func, particle_sets=self.particles) + if self.particle_tractions: + _out = tree_map( + partial_func, + self.particle_tractions, + is_leaf=lambda x: isinstance(x, ParticleTraction), + ) + _temp = tree_util.tree_transpose( + tree_util.tree_structure([0 for e in _out]), + tree_util.tree_structure(_out[0]), + _out, + ) + tractions_ = tree_util.tree_reduce( + lambda x, y: x + y, _temp, is_leaf=lambda x: isinstance(x, list) + ) + self.particles = [ + pset.replace(traction=traction) + for pset, traction in zip(self.particles, tractions_) + ] + + self.apply_on_elements("apply_particle_traction_forces") def tree_flatten(self): children = (self.particles, self.elements) - aux_data = self.particle_tractions + aux_data = (self.elementor, self.particle_tractions) return (children, aux_data) @classmethod @@ -87,7 +168,8 @@ def tree_unflatten(cls, aux_data, children): { "particles": children[0], "elements": children[1], - "particle_surface_traction": aux_data, + "elementor": aux_data[0], + "particle_surface_traction": aux_data[1], } ) diff --git a/diffmpm/mpm.py b/diffmpm/mpm.py new file mode 100644 index 0000000..b06eaee --- /dev/null +++ b/diffmpm/mpm.py @@ -0,0 +1,42 @@ +from pathlib import Path + +import diffmpm.writers as writers +from diffmpm.io import Config +from diffmpm.solver import MPMExplicit + + +class MPM: + def __init__(self, filepath): + self._config = Config(filepath) + mesh = self._config.parse() + out_dir = Path(self._config.parsed_config["output"]["folder"]).joinpath( + self._config.parsed_config["meta"]["title"], + ) + + write_format = self._config.parsed_config["output"].get("format", None) + if write_format is None or write_format.lower() == "none": + writer_func = None + elif write_format == "npz": + writer_func = writers.NPZWriter().write + else: + raise ValueError(f"Specified output format not supported: {write_format}") + + if self._config.parsed_config["meta"]["type"] == "MPMExplicit": + self.solver = MPMExplicit( + mesh, + self._config.parsed_config["meta"]["dt"], + velocity_update=self._config.parsed_config["meta"]["velocity_update"], + sim_steps=self._config.parsed_config["meta"]["nsteps"], + out_steps=self._config.parsed_config["output"]["step_frequency"], + out_dir=out_dir, + writer_func=writer_func, + ) + else: + raise ValueError("Wrong type of solver specified.") + + def solve(self): + """Solve the MPM simulation using JIT solver.""" + arrays = self.solver.solve_jit( + self._config.parsed_config["external_loading"]["gravity"], + ) + return arrays diff --git a/diffmpm/node.py b/diffmpm/node.py index 46e2a60..815851b 100644 --- a/diffmpm/node.py +++ b/diffmpm/node.py @@ -1,125 +1,70 @@ -from 
typing import Optional, Sized, Tuple - import jax.numpy as jnp -from jax.tree_util import register_pytree_node_class -from jax.typing import ArrayLike +import chex + + +@chex.dataclass(frozen=True) +class _NodesState: + nnodes: int + loc: chex.ArrayDevice + velocity: chex.ArrayDevice + acceleration: chex.ArrayDevice + mass: chex.ArrayDevice + momentum: chex.ArrayDevice + f_int: chex.ArrayDevice + f_ext: chex.ArrayDevice + f_damp: chex.ArrayDevice -@register_pytree_node_class -class Nodes(Sized): - """Nodes container class. - Keeps track of all values required for nodal points. +def init_node_state( + nnodes: int, + loc: chex.ArrayDevice, +): + """Initialize container for Nodes. - Attributes + Parameters ---------- nnodes : int Number of nodes stored. loc : ArrayLike - Location of all the nodes. - velocity : array_like - Velocity of all the nodes. - mass : ArrayLike - Mass of all the nodes. - momentum : array_like - Momentum of all the nodes. - f_int : ArrayLike - Internal forces on all the nodes. - f_ext : ArrayLike - External forces present on all the nodes. - f_damp : ArrayLike - Damping forces on the nodes. + Locations of all the nodes. Expected shape (nnodes, 1, ndim) + initialized: bool + `False` if node property arrays like mass need to be initialized. + If `True`, they are set to values from `data`. + data: tuple + Tuple of length 7 that sets arrays for mass, density, volume, + and forces. Mainly used by JAX while unflattening. """ - - def __init__( - self, - nnodes: int, - loc: ArrayLike, - initialized: Optional[bool] = None, - data: Tuple[ArrayLike, ...] = tuple(), - ): - """Initialize container for Nodes. - - Parameters - ---------- - nnodes : int - Number of nodes stored. - loc : ArrayLike - Locations of all the nodes. Expected shape (nnodes, 1, ndim) - initialized: bool - `False` if node property arrays like mass need to be initialized. - If `True`, they are set to values from `data`. - data: tuple - Tuple of length 7 that sets arrays for mass, density, volume, - and forces. Mainly used by JAX while unflattening. 
- """ - self.nnodes = nnodes - loc = jnp.asarray(loc, dtype=jnp.float32) - if loc.ndim != 3: - raise ValueError( - f"`loc` should be of size (nnodes, 1, ndim); found {loc.shape}" - ) - self.loc = loc - - if initialized is None: - self.velocity = jnp.zeros_like(self.loc, dtype=jnp.float32) - self.acceleration = jnp.zeros_like(self.loc, dtype=jnp.float32) - self.mass = jnp.ones((self.loc.shape[0], 1, 1), dtype=jnp.float32) - self.momentum = jnp.zeros_like(self.loc, dtype=jnp.float32) - self.f_int = jnp.zeros_like(self.loc, dtype=jnp.float32) - self.f_ext = jnp.zeros_like(self.loc, dtype=jnp.float32) - self.f_damp = jnp.zeros_like(self.loc, dtype=jnp.float32) - else: - ( - self.velocity, - self.acceleration, - self.mass, - self.momentum, - self.f_int, - self.f_ext, - self.f_damp, - ) = data # type: ignore - self.initialized = True - - def tree_flatten(self): - """Flatten class as Pytree type.""" - children = ( - self.loc, - self.initialized, - self.velocity, - self.acceleration, - self.mass, - self.momentum, - self.f_int, - self.f_ext, - self.f_damp, + loc = jnp.asarray(loc, dtype=jnp.float32) + if loc.ndim != 3 or nnodes != loc.shape[0]: + raise ValueError( + f"`loc` should be of size (nnodes, 1, ndim); found {loc.shape}" ) - aux_data = (self.nnodes,) - return (children, aux_data) - - @classmethod - def tree_unflatten(cls, aux_data, children): - """Unflatten class from Pytree type.""" - return cls(aux_data[0], children[0], initialized=children[1], data=children[2:]) - - def reset_values(self): - """Reset nodal parameter values except location.""" - self.velocity = self.velocity.at[:].set(0) - self.acceleration = self.velocity.at[:].set(0) - self.mass = self.mass.at[:].set(0) - self.momentum = self.momentum.at[:].set(0) - self.f_int = self.f_int.at[:].set(0) - self.f_ext = self.f_ext.at[:].set(0) - self.f_damp = self.f_damp.at[:].set(0) - def __len__(self): - """Set length of class as number of nodes.""" - return self.nnodes + velocity = jnp.zeros_like(loc, dtype=jnp.float32) + acceleration = jnp.zeros_like(loc, dtype=jnp.float32) + mass = jnp.zeros((loc.shape[0], 1, 1), dtype=jnp.float32) + momentum = jnp.zeros_like(loc, dtype=jnp.float32) + f_int = jnp.zeros_like(loc, dtype=jnp.float32) + f_ext = jnp.zeros_like(loc, dtype=jnp.float32) + f_damp = jnp.zeros_like(loc, dtype=jnp.float32) + return _NodesState( + nnodes=nnodes, + loc=loc, + velocity=velocity, + acceleration=acceleration, + mass=mass, + momentum=momentum, + f_int=f_int, + f_ext=f_ext, + f_damp=f_damp, + ) - def __repr__(self): - """Repr containing number of nodes.""" - return f"Nodes(n={self.nnodes})" - def get_total_force(self): - """Calculate total force on the nodes.""" - return self.f_int + self.f_ext + self.f_damp +def _reset_node_props(state: _NodesState): + mass = jnp.zeros_like(state.mass) + momentum = jnp.zeros_like(state.momentum) + f_int = jnp.zeros_like(state.f_int) + f_ext = jnp.zeros_like(state.f_ext) + f_damp = jnp.zeros_like(state.f_damp) + return mass, momentum, f_int, f_ext, f_damp diff --git a/diffmpm/particle.py b/diffmpm/particle.py index 1bb3d70..b8cac78 100644 --- a/diffmpm/particle.py +++ b/diffmpm/particle.py @@ -1,347 +1,581 @@ from typing import Optional, Sized, Tuple import jax.numpy as jnp -from jax import lax, vmap -from jax.tree_util import register_pytree_node_class +from jax import jit, lax, vmap +from jax.tree_util import register_pytree_node_class, Partial from jax.typing import ArrayLike -from diffmpm.element import _Element -from diffmpm.material import Material - - -@register_pytree_node_class 
-class Particles(Sized): - """Container class for a set of particles.""" - - def __init__( - self, - loc: ArrayLike, - material: Material, - element_ids: ArrayLike, - initialized: Optional[bool] = None, - data: Optional[Tuple[ArrayLike, ...]] = None, - ): - """Initialize a container of particles. - - Parameters - ---------- - loc: ArrayLike - Location of the particles. Expected shape (nparticles, 1, ndim) - material: diffmpm.material.Material - Type of material for the set of particles. - element_ids: ArrayLike - The element ids that the particles belong to. This contains - information that will make sense only with the information of - the mesh that is being considered. - initialized: bool - `False` if particle property arrays like mass need to be initialized. - If `True`, they are set to values from `data`. - data: tuple - Tuple of length 13 that sets arrays for mass, density, volume, - velocity, acceleration, momentum, strain, stress, strain_rate, - dstrain, f_ext, reference_loc and volumetric_strain_centroid. - """ - self.material = material - self.element_ids = element_ids - loc = jnp.asarray(loc, dtype=jnp.float32) - if loc.ndim != 3: - raise ValueError( - f"`loc` should be of size (nparticles, 1, ndim); " f"found {loc.shape}" - ) - self.loc = loc - - if initialized is None: - self.mass = jnp.ones((self.loc.shape[0], 1, 1)) - self.density = ( - jnp.ones_like(self.mass) * self.material.properties["density"] - ) - self.volume = jnp.ones_like(self.mass) - self.size = jnp.zeros_like(self.loc) - self.velocity = jnp.zeros_like(self.loc) - self.acceleration = jnp.zeros_like(self.loc) - self.momentum = jnp.zeros_like(self.loc) - self.strain = jnp.zeros((self.loc.shape[0], 6, 1)) - self.stress = jnp.zeros((self.loc.shape[0], 6, 1)) - self.strain_rate = jnp.zeros((self.loc.shape[0], 6, 1)) - self.dstrain = jnp.zeros((self.loc.shape[0], 6, 1)) - self.f_ext = jnp.zeros_like(self.loc) - self.traction = jnp.zeros_like(self.loc) - self.reference_loc = jnp.zeros_like(self.loc) - self.dvolumetric_strain = jnp.zeros((self.loc.shape[0], 1)) - self.volumetric_strain_centroid = jnp.zeros((self.loc.shape[0], 1)) - else: - ( - self.mass, - self.density, - self.volume, - self.size, - self.velocity, - self.acceleration, - self.momentum, - self.strain, - self.stress, - self.strain_rate, - self.dstrain, - self.f_ext, - self.traction, - self.reference_loc, - self.dvolumetric_strain, - self.volumetric_strain_centroid, - ) = data # type: ignore - self.initialized = True - - def tree_flatten(self): - """Flatten class as Pytree type.""" - children = ( - self.loc, - self.element_ids, - self.initialized, - self.mass, - self.density, - self.volume, - self.size, - self.velocity, - self.acceleration, - self.momentum, - self.strain, - self.stress, - self.strain_rate, - self.dstrain, - self.f_ext, - self.traction, - self.reference_loc, - self.dvolumetric_strain, - self.volumetric_strain_centroid, - ) - aux_data = (self.material,) - return (children, aux_data) - - @classmethod - def tree_unflatten(cls, aux_data, children): - """Unflatten class from Pytree type.""" - return cls( - children[0], - aux_data[0], - children[1], - initialized=children[2], - data=children[3:], +from diffmpm.element import _ElementsState +from diffmpm.materials import _Material + +import chex + + +@chex.dataclass(frozen=True) +class _ParticlesState: + ids: chex.ArrayDevice + nparticles: int + loc: chex.ArrayDevice + material: _Material + element_ids: chex.ArrayDevice + mass: chex.ArrayDevice + density: chex.ArrayDevice + volume: chex.ArrayDevice + 
size: chex.ArrayDevice + velocity: chex.ArrayDevice + acceleration: chex.ArrayDevice + momentum: chex.ArrayDevice + strain: chex.ArrayDevice + stress: chex.ArrayDevice + strain_rate: chex.ArrayDevice + dstrain: chex.ArrayDevice + f_ext: chex.ArrayDevice + traction: chex.ArrayDevice + reference_loc: chex.ArrayDevice + dvolumetric_strain: chex.ArrayDevice + volumetric_strain_centroid: chex.ArrayDevice + state_vars: dict + + +def init_particle_state( + loc: chex.ArrayDevice, + material: _Material, + element_ids: chex.ArrayDevice, + init_vel: chex.ArrayDevice = 0, +): + """Initialize a container of particles. + + Parameters + ---------- + loc: ArrayLike + Location of the particles. Expected shape (nparticles, 1, ndim) + material: diffmpm.materials._Material + Type of material for the set of particles. + element_ids: ArrayLike + The element ids that the particles belong to. This contains + information that will make sense only with the information of + the mesh that is being considered. + initialized: bool + `False` if particle property arrays like mass need to be initialized. + If `True`, they are set to values from `data`. + data: tuple + Tuple of length 13 that sets arrays for mass, density, volume, + velocity, acceleration, momentum, strain, stress, strain_rate, + dstrain, f_ext, reference_loc and volumetric_strain_centroid. + """ + loc = jnp.asarray(loc, dtype=jnp.float32) + if loc.ndim != 3: + raise ValueError( + f"`loc` should be of size (nparticles, 1, ndim); " f"found {loc.shape}" ) - def __len__(self) -> int: - """Set length of the class as number of particles.""" - return self.loc.shape[0] - - def __repr__(self) -> str: - """Informative repr showing number of particles.""" - return f"Particles(nparticles={len(self)})" - - def set_mass_volume(self, m: ArrayLike): - """Set particle mass. - - Parameters - ---------- - m: float, array_like - Mass to be set for particles. If scalar, mass for all - particles is set to this value. - """ - m = jnp.asarray(m) - if jnp.isscalar(m): - self.mass = jnp.ones_like(self.loc) * m - elif m.shape == self.mass.shape: - self.mass = m - else: - raise ValueError( - f"Incompatible shapes. Expected {self.mass.shape}, " f"found {m.shape}." - ) - self.volume = jnp.divide(self.mass, self.material.properties["density"]) - - def compute_volume(self, elements: _Element, total_elements: int): - """Compute volume of all particles. - - Parameters - ---------- - elements: diffmpm._Element - Elements that the particles are present in, and are used to - compute the particles' volumes. - total_elements: int - Total elements present in `elements`. - """ - particles_per_element = jnp.bincount( - self.element_ids, length=elements.total_elements - ) - vol = ( - elements.volume.squeeze((1, 2))[self.element_ids] # type: ignore - / particles_per_element[self.element_ids] - ) - self.volume = self.volume.at[:, 0, 0].set(vol) - self.size = self.size.at[:].set(self.volume ** (1 / self.size.shape[-1])) - self.mass = self.mass.at[:, 0, 0].set(vol * self.density.squeeze()) - - def update_natural_coords(self, elements: _Element): - r"""Update natural coordinates for the particles. - - Whenever the particles' physical coordinates change, their - natural coordinates need to be updated. This function updates - the natural coordinates of the particles based on the element - a particle is a part of. The update formula is - - \[ - \xi = (2x - (x_1^e + x_2^e)) / (x_2^e - x_1^e) - \] - - where \(x_i^e\) are the nodal coordinates of the element the - particle is in. 
If a particle is not in any element - (element_id = -1), its natural coordinate is set to 0. - - Parameters - ---------- - elements: diffmpm.element._Element - Elements based on which to update the natural coordinates - of the particles. - """ - t = vmap(elements.id_to_node_loc)(self.element_ids) - xi_coords = (self.loc - (t[:, 0, ...] + t[:, 2, ...]) / 2) * ( - 2 / (t[:, 2, ...] - t[:, 0, ...]) - ) - self.reference_loc = xi_coords - - def update_position_velocity( - self, elements: _Element, dt: float, velocity_update: bool - ): - """Transfer nodal velocity to particles and update particle position. - - The velocity is calculated based on the total force at nodes. - - Parameters - ---------- - elements: diffmpm.element._Element - Elements whose nodes are used to transfer the velocity. - dt: float - Timestep. - velocity_update: bool - If True, velocity is directly used as nodal velocity, else - velocity is calculated is interpolated nodal acceleration - multiplied by dt. Default is False. - """ - mapped_positions = elements.shapefn(self.reference_loc) - mapped_ids = vmap(elements.id_to_node_ids)(self.element_ids).squeeze(-1) - nodal_velocity = jnp.sum( - mapped_positions * elements.nodes.velocity[mapped_ids], axis=1 + ids = jnp.arange(loc.shape[0]) + mass = jnp.ones((loc.shape[0], 1, 1)) + density = jnp.ones_like(mass) * material.density + volume = jnp.ones_like(mass) + size = jnp.zeros_like(loc) + velocity = jnp.ones_like(loc) * init_vel + acceleration = jnp.zeros_like(loc) + momentum = jnp.zeros_like(loc) + strain = jnp.zeros((loc.shape[0], 6, 1)) + stress = jnp.zeros((loc.shape[0], 6, 1)) + strain_rate = jnp.zeros((loc.shape[0], 6, 1)) + dstrain = jnp.zeros((loc.shape[0], 6, 1)) + f_ext = jnp.zeros_like(loc) + traction = jnp.zeros_like(loc) + reference_loc = jnp.zeros_like(loc) + dvolumetric_strain = jnp.zeros((loc.shape[0], 1)) + volumetric_strain_centroid = jnp.zeros((loc.shape[0], 1)) + state_vars = {} + if material.state_vars: + state_vars = material.initialize_state_variables(loc.shape[0]) + return _ParticlesState( + ids=ids, + nparticles=loc.shape[0], + loc=loc, + material=material, + element_ids=element_ids, + mass=mass, + density=density, + volume=volume, + size=size, + velocity=velocity, + acceleration=acceleration, + momentum=momentum, + strain=strain, + stress=stress, + strain_rate=strain_rate, + dstrain=dstrain, + f_ext=f_ext, + traction=traction, + reference_loc=reference_loc, + dvolumetric_strain=dvolumetric_strain, + volumetric_strain_centroid=volumetric_strain_centroid, + state_vars=state_vars, + ) + + +# TODO: Can these methods just return the updated arrays to +# a single function which then generates the new state? + + +def set_mass_volume(state, m: ArrayLike) -> _ParticlesState: + """Set particle mass. + + Parameters + ---------- + m: float, array_like + Mass to be set for particles. If scalar, mass for all + particles is set to this value. + """ + m = jnp.asarray(m) + if jnp.isscalar(m): + mass = jnp.ones_like(state.loc) * m + elif m.shape == state.mass.shape: + mass = m + else: + raise ValueError( + f"Incompatible shapes. Expected {state.mass.shape}, " f"found {m.shape}." ) - nodal_acceleration = jnp.sum( - mapped_positions * elements.nodes.acceleration[mapped_ids], - axis=1, + volume = jnp.divide(mass, state.material.properties["density"]) + return state.replace(mass=mass, volume=volume) + + +def _compute_particle_volume( + element_ids, total_elements, evol, pvol, psize, pmass, pdensity +): + """Compute volume of all particles. 
+ + Parameters + ---------- + element_ids: ArrayLike + Element ids that each particle belongs to. + total_elements: int + Total number of elements in the mesh. + evol: ArrayLike + Element volumes. + pvol, psize, pmass, pdensity: ArrayLike + Particle volume, size, mass and density arrays to be updated. + """ + particles_per_element = jnp.bincount(element_ids, length=total_elements) + vol = ( + evol.squeeze((1, 2))[element_ids] # type: ignore + / particles_per_element[element_ids] + ) + volume = pvol.at[:, 0, 0].set(vol) + size = psize.at[:].set(volume ** (1 / psize.shape[-1])) + mass = pmass.at[:, 0, 0].set(vol * pdensity.squeeze()) + return {"mass": mass, "size": size, "volume": volume} + + +def compute_volume(state, elements: _ElementsState, elementor, total_elements: int): + """Compute volume of all particles. + + Parameters + ---------- + state: + Current particle state. + elements: diffmpm.element._ElementsState + Elements that the particles are present in, and are used to + compute the particles' volumes. + elementor: + Element-type object (e.g. Quad4N) that provides `total_elements`. + total_elements: int + Total elements present in `elements`. + """ + particles_per_element = jnp.bincount( + state.element_ids, length=elementor.total_elements + ) + vol = ( + elements.volume.squeeze((1, 2))[state.element_ids] # type: ignore + / particles_per_element[state.element_ids] + ) + volume = state.volume.at[:, 0, 0].set(vol) + size = state.size.at[:].set(volume ** (1 / state.size.shape[-1])) + mass = state.mass.at[:, 0, 0].set(vol * state.density.squeeze()) + return state.replace(mass=mass, size=size, volume=volume) + + +def _get_natural_coords(particles, p_mapped_ids, eloc): + r"""Compute natural coordinates for the particles. + + Whenever the particles' physical coordinates change, their + natural coordinates need to be updated. This function computes + the natural coordinates of the particles based on the element + a particle is a part of. The update formula is + + \[ + \xi = (2x - (x_1^e + x_2^e)) / (x_2^e - x_1^e) + \] + + where \(x_i^e\) are the nodal coordinates of the element the + particle is in. If a particle is not in any element + (element_id = -1), its natural coordinate is set to 0. + + Parameters + ---------- + particles: _ParticlesState + Current particle state. + p_mapped_ids: ArrayLike + Node ids mapped to each particle's element. + eloc: ArrayLike + Nodal locations of the mesh. + """ + t = eloc[p_mapped_ids.squeeze(-1)] + xi_coords = (particles.loc - (t[:, 0, ...] + t[:, 2, ...]) / 2) * ( + 2 / (t[:, 2, ...] - t[:, 0, ...]) + ) + return xi_coords + + +def update_natural_coords(state, elements: _ElementsState, elementor, *args): + r"""Update natural coordinates for the particles. + + Whenever the particles' physical coordinates change, their + natural coordinates need to be updated. This function updates + the natural coordinates of the particles based on the element + a particle is a part of. The update formula is + + \[ + \xi = (2x - (x_1^e + x_2^e)) / (x_2^e - x_1^e) + \] + + where \(x_i^e\) are the nodal coordinates of the element the + particle is in. If a particle is not in any element + (element_id = -1), its natural coordinate is set to 0. + + Parameters + ---------- + elements: diffmpm.element._ElementsState + Elements based on which to update the natural coordinates + of the particles. + elementor: + Element-type object used to map element ids to node locations. + """ + t = vmap(Partial(elementor.id_to_node_loc, elements))(state.element_ids) + xi_coords = (state.loc - (t[:, 0, ...] + t[:, 2, ...]) / 2) * ( + 2 / (t[:, 2, ...]
- t[:, 0, ...]) + ) + return state.replace(reference_loc=xi_coords) + + +def update_position_velocity( + state, elements: _ElementsState, elementor, dt: float, velocity_update: bool, *args +): + """Transfer nodal velocity to particles and update particle position. + + The velocity is calculated based on the total force at nodes. + + Parameters + ---------- + elements: diffmpm.element._ElementsState + Elements whose nodes are used to transfer the velocity. + elementor: + Element-type object providing the shape functions. + dt: float + Timestep. + velocity_update: bool + If True, velocity is directly used as nodal velocity, else + velocity is calculated as the interpolated nodal acceleration + multiplied by dt. Default is False. + """ + mapped_positions = elementor.shapefn(state.reference_loc) + mapped_ids = vmap(Partial(elementor.id_to_node_ids, elements.nelements[0]))( + state.element_ids + ).squeeze(-1) + nodal_velocity = jnp.sum( + mapped_positions * elements.nodes.velocity[mapped_ids], axis=1 + ) + nodal_acceleration = jnp.sum( + mapped_positions * elements.nodes.acceleration[mapped_ids], + axis=1, + ) + velocity = state.velocity.at[:].set( + lax.cond( + velocity_update, + lambda sv, nv, na, t: nv, + lambda sv, nv, na, t: sv + na * t, + state.velocity, + nodal_velocity, + nodal_acceleration, + dt, ) - self.velocity = self.velocity.at[:].set( - lax.cond( - velocity_update, - lambda sv, nv, na, t: nv, - lambda sv, nv, na, t: sv + na * t, - self.velocity, - nodal_velocity, - nodal_acceleration, - dt, - ) - ) - self.loc = self.loc.at[:].add(nodal_velocity * dt) - self.momentum = self.momentum.at[:].set(self.mass * self.velocity) - - def compute_strain(self, elements: _Element, dt: float): - """Compute the strain on all particles. - - This is done by first calculating the strain rate for the particles - and then calculating strain as `strain += strain rate * dt`. - - Parameters - ---------- - elements: diffmpm.element._Element - Elements whose nodes are used to calculate the strain. - dt : float - Timestep. - """ - mapped_coords = vmap(elements.id_to_node_loc)(self.element_ids).squeeze(2) - dn_dx_ = vmap(elements.shapefn_grad)( - self.reference_loc[:, jnp.newaxis, ...], mapped_coords - ) - self.strain_rate = self._compute_strain_rate(dn_dx_, elements) - self.dstrain = self.dstrain.at[:].set(self.strain_rate * dt) - - self.strain = self.strain.at[:].add(self.dstrain) - centroids = jnp.zeros_like(self.loc) - dn_dx_centroid_ = vmap(elements.shapefn_grad)( - centroids[:, jnp.newaxis, ...], mapped_coords - ) - strain_rate_centroid = self._compute_strain_rate(dn_dx_centroid_, elements) - ndim = self.loc.shape[-1] - self.dvolumetric_strain = dt * strain_rate_centroid[:, :ndim].sum(axis=1) - self.volumetric_strain_centroid = self.volumetric_strain_centroid.at[:].add( - self.dvolumetric_strain - ) - - def _compute_strain_rate(self, dn_dx: ArrayLike, elements: _Element): - """Compute the strain rate for particles. - - Parameters - ---------- - dn_dx: ArrayLike - The gradient of the shape function. Expected shape - `(nparticles, 1, ndim)` - elements: diffmpm.element._Element - Elements whose nodes are used to calculate the strain rate.
- """ - dn_dx = jnp.asarray(dn_dx) - strain_rate = jnp.zeros((dn_dx.shape[0], 6, 1)) # (nparticles, 6, 1) - mapped_vel = vmap(elements.id_to_node_vel)( - self.element_ids - ) # (nparticles, 2, 1) - - temp = mapped_vel.squeeze(2) - - def _step(pid, args): - dndx, nvel, strain_rate = args - matmul = dndx[pid].T @ nvel[pid] - strain_rate = strain_rate.at[pid, 0].add(matmul[0, 0]) - strain_rate = strain_rate.at[pid, 1].add(matmul[1, 1]) - strain_rate = strain_rate.at[pid, 3].add(matmul[0, 1] + matmul[1, 0]) - return dndx, nvel, strain_rate - - args = (dn_dx, temp, strain_rate) - _, _, strain_rate = lax.fori_loop(0, self.loc.shape[0], _step, args) - strain_rate = jnp.where( - jnp.abs(strain_rate) < 1e-12, jnp.zeros_like(strain_rate), strain_rate - ) - return strain_rate - - def compute_stress(self, *args): - """Compute the strain on all particles. - - This calculation is governed by the material of the - particles. The stress calculated by the material is then - added to the particles current stress values. - """ - self.stress = self.stress.at[:].add(self.material.compute_stress(self.dstrain)) - - def update_volume(self, *args): - """Update volume based on central strain rate.""" - self.volume = self.volume.at[:, 0, :].multiply(1 + self.dvolumetric_strain) - self.density = self.density.at[:, 0, :].divide(1 + self.dvolumetric_strain) - - def assign_traction(self, pids: ArrayLike, dir: int, traction_: float): - """Assign traction to particles. - - Parameters - ---------- - pids: ArrayLike - IDs of the particles to which traction should be applied. - dir: int - The direction in which traction should be applied. - traction_: float - Traction value to be applied in the direction. - """ - self.traction = self.traction.at[pids, 0, dir].add( - traction_ * self.volume[pids, 0, 0] / self.size[pids, 0, dir] - ) - - def zero_traction(self, *args): - """Set all traction values to 0.""" - self.traction = self.traction.at[:].set(0) + ) + loc = state.loc.at[:].add(nodal_velocity * dt) + momentum = state.momentum.at[:].set(state.mass * state.velocity) + return state.replace(velocity=velocity, loc=loc, momentum=momentum) + + +def _update_particle_position_velocity( + el_type, + ploc, + pvel, + pmom, + pmass, + pxi, + mapped_node_ids, + nvel, + nacc, + velocity_update, + dt, +): + """Transfer nodal velocity to particles and update particle position. + + The velocity is calculated based on the total force at nodes. + + Parameters + ---------- + elements: diffmpm.element._ElementState + Elements whose nodes are used to transfer the velocity. + dt: float + Timestep. + velocity_update: bool + If True, velocity is directly used as nodal velocity, else + velocity is calculated is interpolated nodal acceleration + multiplied by dt. Default is False. + """ + mapped_positions = el_type._shapefn(pxi) + mapped_ids = mapped_node_ids.squeeze(-1) + nodal_velocity = jnp.sum(mapped_positions * nvel[mapped_ids], axis=1) + nodal_acceleration = jnp.sum( + mapped_positions * nacc[mapped_ids], + axis=1, + ) + velocity = lax.cond( + velocity_update, + lambda sv, nv, na, t: nv, + lambda sv, nv, na, t: sv + na * t, + pvel, + nodal_velocity, + nodal_acceleration, + dt, + ) + loc = ploc.at[:].add(nodal_velocity * dt) + momentum = pmass * pvel + return {"velocity": velocity, "loc": loc, "momentum": momentum} + + +def _compute_strain_rate(mapped_vel, pids, dn_dx: ArrayLike): + """Compute the strain rate for particles. + + Parameters + ---------- + dn_dx: ArrayLike + The gradient of the shape function. 
Expected shape + `(nparticles, 1, ndim)` + """ + dn_dx = jnp.asarray(dn_dx) + strain_rate = jnp.zeros((dn_dx.shape[0], 6, 1)) # (nparticles, 6, 1) + temp = mapped_vel.squeeze(2) + + def _scan_step(carry, pid): + dndx, nvel, strain_rate = carry + matmul = dndx[pid].T @ nvel[pid] + strain_rate = strain_rate.at[pid, 0].add(matmul[0, 0]) + strain_rate = strain_rate.at[pid, 1].add(matmul[1, 1]) + strain_rate = strain_rate.at[pid, 3].add(matmul[0, 1] + matmul[1, 0]) + return (dndx, nvel, strain_rate), pid + + args = (dn_dx, temp, strain_rate) + final_carry, _ = lax.scan(_scan_step, args, pids) + _, _, strain_rate = final_carry + strain_rate = jnp.where(jnp.abs(strain_rate) < 1e-12, 0, strain_rate) + return strain_rate + + +def _compute_strain( + pstrain, + pxi, + ploc, + pvolumetric_strain_centroid, + pids, + mapped_node_ids, + nloc, + nvel, + el_type, + dt, +): + """Compute the strain on all particles. + + This is done by first calculating the strain rate for the particles + and then calculating strain as `strain += strain rate * dt`. + + Parameters + ---------- + pstrain, pxi, ploc, pvolumetric_strain_centroid: ArrayLike + Current particle strain, natural coordinates, locations and + centroid volumetric strain. + pids: ArrayLike + Indices of the particles. + mapped_node_ids: ArrayLike + Node ids mapped to each particle's element. + nloc, nvel: ArrayLike + Nodal locations and velocities. + el_type: + Element-type object providing the shape-function gradients. + dt : float + Timestep. + """ + mapped_nodes = mapped_node_ids.squeeze(-1) + mapped_coords = nloc[mapped_nodes] + mapped_vel = nvel[mapped_nodes] + dn_dx_ = vmap(el_type._shapefn_grad)(pxi[:, jnp.newaxis, ...], mapped_coords) + new_strain_rate = _compute_strain_rate(mapped_vel, pids, dn_dx_) + new_dstrain = new_strain_rate * dt + + new_strain = pstrain + new_dstrain + centroids = jnp.zeros_like(ploc) + dn_dx_centroid_ = vmap(jit(el_type._shapefn_grad))( + centroids[:, jnp.newaxis, ...], mapped_coords + ) + strain_rate_centroid = _compute_strain_rate( + mapped_vel, + pids, + dn_dx_centroid_, + ) + ndim = ploc.shape[-1] + new_dvolumetric_strain = dt * strain_rate_centroid[:, :ndim].sum(axis=1) + new_volumetric_strain_centroid = ( + pvolumetric_strain_centroid + new_dvolumetric_strain + ) + return { + "strain_rate": new_strain_rate, + "dstrain": new_dstrain, + "strain": new_strain, + "dvolumetric_strain": new_dvolumetric_strain, + "volumetric_strain_centroid": new_volumetric_strain_centroid, + } + + +def compute_strain(state, elements: _ElementsState, elementor, dt: float, *args): + """Compute the strain on all particles. + + This is done by first calculating the strain rate for the particles + and then calculating strain as `strain += strain rate * dt`. + + Parameters + ---------- + elements: diffmpm.element._ElementsState + Elements whose nodes are used to calculate the strain. + elementor: + Element-type object providing the shape-function gradients. + dt : float + Timestep.
+ """ + # breakpoint() + mapped_coords = vmap(Partial(elementor.id_to_node_loc, elements))( + state.element_ids + ).squeeze(2) + mapped_vel = vmap(Partial(elementor.id_to_node_vel, elements))(state.element_ids) + dn_dx_ = vmap(jit(elementor.shapefn_grad))( + state.reference_loc[:, jnp.newaxis, ...], mapped_coords + ) + # strain_rate = _compute_strain_rate(state, dn_dx_, elements, elementor) + strain_rate = _compute_strain_rate(mapped_vel, state.nparticles, dn_dx_) + dstrain = state.dstrain.at[:].set(strain_rate * dt) + + strain = state.strain.at[:].add(dstrain) + centroids = jnp.zeros_like(state.loc) + dn_dx_centroid_ = vmap(jit(elementor.shapefn_grad))( + centroids[:, jnp.newaxis, ...], mapped_coords + ) + # strain_rate_centroid = _compute_strain_rate( + # state, dn_dx_centroid_, elements, elementor + # ) + strain_rate_centroid = _compute_strain_rate( + mapped_vel, state.nparticles, dn_dx_centroid_ + ) + ndim = state.loc.shape[-1] + dvolumetric_strain = dt * strain_rate_centroid[:, :ndim].sum(axis=1) + volumetric_strain_centroid = state.volumetric_strain_centroid.at[:].add( + dvolumetric_strain + ) + return state.replace( + strain_rate=strain_rate, + dstrain=dstrain, + strain=strain, + dvolumetric_strain=dvolumetric_strain, + volumetric_strain_centroid=volumetric_strain_centroid, + ) + + +def compute_stress(state, *args): + """Compute the strain on all particles. + + This calculation is governed by the material of the + particles. The stress calculated by the material is then + added to the particles current stress values. + """ + stress = state.stress.at[:].add(state.material.compute_stress(state)) + return state.replace(stress=stress) + + +def _compute_stress(stress, strain, dstrain, material, *args): + """Compute the strain on all particles. + + This calculation is governed by the material of the + particles. The stress calculated by the material is then + added to the particles current stress values. + """ + new_stress = stress + material.compute_stress(strain, dstrain) + return new_stress + + +def update_volume(state, *args): + """Update volume based on central strain rate.""" + volume = state.volume.at[:, 0, :].multiply(1 + state.dvolumetric_strain) + density = state.density.at[:, 0, :].divide(1 + state.dvolumetric_strain) + return state.replace(volume=volume, density=density) + + +def _update_particle_volume(pvol, pdensity, pdvolumetric_strain): + """Update volume based on central strain rate.""" + new_volume = pvol.at[:, 0, :].multiply(1 + pdvolumetric_strain) + new_density = pdensity.at[:, 0, :].divide(1 + pdvolumetric_strain) + return new_volume, new_density + + +def assign_traction(state, pids: ArrayLike, dir: int, traction_: float): + """Assign traction to particles. + + Parameters + ---------- + pids: ArrayLike + IDs of the particles to which traction should be applied. + dir: int + The direction in which traction should be applied. + traction_: float + Traction value to be applied in the direction. + """ + traction = state.traction.at[pids, 0, dir].add( + traction_ * state.volume[pids, 0, 0] / state.size[pids, 0, dir] + ) + return traction + + +def _assign_traction( + ptraction, + pvol, + psize, + pids: ArrayLike, + dir: int, + traction_val_: float, +): + """Assign traction to particles. + + Parameters + ---------- + pids: ArrayLike + IDs of the particles to which traction should be applied. + dir: int + The direction in which traction should be applied. + traction_: float + Traction value to be applied in the direction. 
+ """ + traction = ptraction.at[pids, 0, dir].add( + traction_val_ * pvol[pids, 0, 0] / psize[pids, 0, dir] + ) + return traction + + +def _zero_traction(traction): + """Set all traction values to 0.""" + traction = jnp.zeros_like(traction) + return traction + + +def zero_traction(state, *args): + """Set all traction values to 0.""" + traction = state.traction.at[:].set(0) + return state.replace(traction=traction) diff --git a/diffmpm/scheme.py b/diffmpm/scheme.py index 61a062e..385dd7b 100644 --- a/diffmpm/scheme.py +++ b/diffmpm/scheme.py @@ -12,6 +12,8 @@ _schemes = ("usf", "usl") +from diffmpm.node import reset_node_state + class _MPMScheme(abc.ABC): def __init__(self, mesh, dt, velocity_update): @@ -21,6 +23,7 @@ def __init__(self, mesh, dt, velocity_update): def compute_nodal_kinematics(self): """Compute nodal kinematics - map mass and momentum to mesh nodes.""" + self.mesh.elements.nodes = reset_node_state(self.mesh.elements.nodes) self.mesh.apply_on_elements("set_particle_element_ids") self.mesh.apply_on_particles("update_natural_coords") self.mesh.apply_on_elements("compute_nodal_mass") @@ -57,9 +60,9 @@ def compute_forces(self, gravity: ArrayLike, step: int): def compute_particle_kinematics(self): """Compute particle location, acceleration and velocity.""" - self.mesh.apply_on_elements( - "update_nodal_acceleration_velocity", args=(self.dt,) - ) + self.mesh.apply_on_elements("update_nodal_acceleration", args=(self.dt,)) + self.mesh.apply_on_elements("update_nodal_velocity", args=(self.dt,)) + self.mesh.apply_on_elements("update_nodal_momentum", args=(self.dt,)) self.mesh.apply_on_particles( "update_position_velocity", args=(self.dt, self.velocity_update), diff --git a/diffmpm/solver.py b/diffmpm/solver.py index 3b1ae01..7245dad 100644 --- a/diffmpm/solver.py +++ b/diffmpm/solver.py @@ -4,11 +4,12 @@ from typing import TYPE_CHECKING, Callable, Optional import jax.numpy as jnp -from jax import lax +from jax import lax, profiler from jax.experimental.host_callback import id_tap from jax.tree_util import register_pytree_node_class from jax.typing import ArrayLike +from diffmpm.pbar import loop_tqdm from diffmpm.scheme import USF, USL, _MPMScheme, _schemes if TYPE_CHECKING: @@ -176,6 +177,7 @@ def solve_jit(self, gravity: ArrayLike) -> dict: final state of the simulation after completing all steps. """ + @loop_tqdm(self.sim_steps, print_rate=1) def _step(i, data): self = data self.mpm_scheme.compute_nodal_kinematics() @@ -210,6 +212,8 @@ def _write(self, i): ) return self + # with profiler.trace("/tmp/jax-trace", create_perfetto_link=True): + # self = lax.fori_loop(0, self.sim_steps, _step, self) self = lax.fori_loop(0, self.sim_steps, _step, self) arrays = {} for name in self.__particle_props: diff --git a/diffmpm/writers.py b/diffmpm/writers.py index fdc5cd2..b594fd7 100644 --- a/diffmpm/writers.py +++ b/diffmpm/writers.py @@ -1,10 +1,10 @@ import abc import logging from pathlib import Path +from typing import Annotated, Any, Tuple -from typing import Tuple, Annotated, Any -from jax.typing import ArrayLike import numpy as np +from jax.typing import ArrayLike logger = logging.getLogger(__file__) diff --git a/examples/mpm-nodal-forces.toml b/examples/mpm-nodal-forces.toml deleted file mode 100644 index cf01f1e..0000000 --- a/examples/mpm-nodal-forces.toml +++ /dev/null @@ -1,73 +0,0 @@ -# The `meta` group contains top level attributes that govern the -# behaviour of the MPM Solver. -# -# Attributes: -# title: The title of the experiment. This is just for the user's -# reference. 
-# type: The type of simulation to be used. Allowed values are -# {"MPMExplicit"} -# scheme: The MPM Scheme used for simulation. Allowed values are -# {"usl", "usf"} -# dt: Timestep used in the simulation. -# nsteps: Number of steps to run the simulation for. -[meta] -title = "uniaxial-nodal-traction" -type = "MPMExplicit" -dimension = 2 -scheme = "usf" -dt = 0.001 -nsteps = 301 -velocity_update = true - -[output] -type = "hdf5" -file = "results/example_2d_out.hdf5" -step_frequency = 5 - -[mesh] -# type = "file" -# file = "mesh-1d.txt" -# boundary_nodes = "boundary-1d.txt" -# particle_element_ids = "particles-elements.txt" -type = "generator" -nelements = [3, 1] -element_length = [0.1, 0.1] -particle_element_ids = [0] -element = "Quadrilateral4Node" - -[[mesh.constraints]] -node_ids = [0, 4] -dir = 0 -velocity = 0.0 - -[[materials]] -id = 0 -density = 1000 -poisson_ratio = 0 -youngs_modulus = 1000000 -type = "LinearElastic" - -[[particles]] -file = "examples/particles-2d-nodal-force.json" -material_id = 0 -init_velocity = 0.0 - -[external_loading] -gravity = [0, 0] - -[[external_loading.concentrated_nodal_forces]] -node_ids = [3, 7] -math_function_id = 0 -dir = 0 -force = 0.05 - -[[external_loading.particle_surface_traction]] -pset = [1] -dir = 1 -math_function_id = 0 -traction = 10.5 - -[[math_functions]] -type = "Linear" -xvalues = [0.0, 0.5, 1.0] -fxvalues = [0.0, 1.0, 1.0] diff --git a/examples/mpm-uniaxial-stress.toml b/examples/mpm-uniaxial-stress.toml deleted file mode 100644 index 4f8065e..0000000 --- a/examples/mpm-uniaxial-stress.toml +++ /dev/null @@ -1,61 +0,0 @@ -# The `meta` group contains top level attributes that govern the -# behaviour of the MPM Solver. -# -# Attributes: -# title: The title of the experiment. This is just for the user's -# reference. -# type: The type of simulation to be used. Allowed values are -# {"MPMExplicit"} -# scheme: The MPM Scheme used for simulation. Allowed values are -# {"usl", "usf"} -# dt: Timestep used in the simulation. -# nsteps: Number of steps to run the simulation for. 
-[meta] -title = "uniaxial-stress" -type = "MPMExplicit" -dimension = 2 -scheme = "usf" -dt = 0.01 -nsteps = 10 -velocity_update = false - -[output] -format = "npz" -folder = "results/" -step_frequency = 5 - -[mesh] -# type = "file" -# file = "mesh-1d.txt" -# boundary_nodes = "boundary-1d.txt" -# particle_element_ids = "particles-elements.txt" -type = "generator" -nelements = [1, 1] -element_length = [1, 1] -particle_element_ids = [0] -element = "Quadrilateral4Node" - -[[mesh.constraints]] -node_ids = [0, 1] -dir = 1 -velocity = 0.0 - -[[mesh.constraints]] -node_ids = [2, 3] -dir = 1 -velocity = -0.01 - -[[materials]] -id = 0 -density = 1 -poisson_ratio = 0 -youngs_modulus = 1000 -type = "LinearElastic" - -[[particles]] -file = "examples/particles-2d-uniaxial-stress.json" -material_id = 0 -init_velocity = [1.0, 0.0] - -[external_loading] -gravity = [0, 0] diff --git a/examples/optim_benchmark.py b/examples/optim_benchmark.py new file mode 100644 index 0000000..48f2655 --- /dev/null +++ b/examples/optim_benchmark.py @@ -0,0 +1,165 @@ +from typing import NamedTuple +from functools import partial +import matplotlib.pyplot as plt + +import jax +import jax.numpy as jnp +import optax +from tqdm import tqdm + +from diffmpm.constraint import Constraint +from diffmpm.element import Quad4N, Quad4NState +from diffmpm.explicit import ExplicitSolver +from diffmpm.forces import NodalForce +from diffmpm.functions import Unit +from diffmpm.io import Config +from diffmpm.materials import init_simple, init_linear_elastic +from diffmpm.particle import _ParticlesState, init_particle_state + +jax.config.update("jax_platform_name", "cpu") + +config = Config("./benchmarks/2d/uniaxial_stress/mpm-uniaxial-stress.toml") +# config = Config("./benchmarks/2d/uniaxial_particle_traction/mpm-particle-traction.toml") +# config = Config("./benchmarks/2d/uniaxial_nodal_forces/mpm-nodal-forces.toml") +# config = Config("./benchmarks/2d/hydrostatic_column/mpm.toml") +# parsed_config = config.parse() +# cnf = [NodalForce(node_ids=jnp.array([0, 1]), function=Unit(-1), dir=1, force=1.5)] +# material = NamedTuple("Simple", density=1, E=1, state_vars={}) +# ploc = jnp.array([[0.5, 0.5], [0.5, 0.5]]).reshape(2, 1, 2) +# pmat = material(density=1.0, E=1.0, state_vars={}) +# pmat = init_simple({"density": 1, "E": 100, "id": 1}) +# peids = jnp.array([1]) +# particles = [init_particle_state(ploc, pmat, peids)] + +# cls = Quad4N(total_elements=1) +# elements = cls.init_state( +# (1, 1), +# 1, +# (1, 1), +# [(jnp.array([0]), Constraint(0, 2))], +# concentrated_nodal_forces=cnf, +# ) + +solver = ExplicitSolver( + el_type=config.parsed_config["elementor"], + tol=1e-12, + scheme=config.parsed_config["meta"]["scheme"], + dt=config.parsed_config["meta"]["dt"], + velocity_update=config.parsed_config["meta"]["velocity_update"], + sim_steps=config.parsed_config["meta"]["nsteps"], + out_steps=config.parsed_config["output"]["step_frequency"], + out_dir=config.parsed_config["output"]["format"], + gravity=config.parsed_config["external_loading"]["gravity"], +) + +init_vals = solver.init_state( + { + "elements": config.parsed_config["elements"], + "particles": config.parsed_config["particles"], + "particle_surface_traction": config.parsed_config["particle_surface_traction"], + } +) + +jit_updated = init_vals +jitted_update = jax.jit(solver.update) +for step in tqdm(range(20)): + jit_updated = jitted_update(jit_updated, step + 1) + +true_vel = jit_updated.particles[0].stress + + +def compute_loss(params, *, solver, target_vel, config): + material = 
init_linear_elastic( + { + "youngs_modulus": params["ym"], + "density": 1, + "poisson_ratio": 0, + "id": -1, + } + ) + particles_ = [ + init_particle_state( + config.parsed_config["particles"][0].loc, + material, + config.parsed_config["particles"][0].element_ids, + init_vel=jnp.asarray([1.0, 0.0]), + ) + ] + init_vals = solver.init_state( + { + "elements": config.parsed_config["elements"], + "particles": particles_, + "particle_surface_traction": config.parsed_config[ + "particle_surface_traction" + ], + } + ) + result = init_vals + for step in tqdm(range(20), leave=False): + result = jitted_update(result, step + 1) + vel = result.particles[0].stress + loss = jnp.linalg.norm(vel - target_vel) + return loss + + +def optax_adam(params, niter, mpm, target_vel, config): + # Initialize parameters of the model + optimizer. + start_learning_rate = 4 + optimizer = optax.adam(start_learning_rate) + opt_state = optimizer.init(params) + + param_list = {"ym": [], "pr": []} + loss_list = [] + # A simple update loop. + t = tqdm(range(niter), desc=f"E: {params}") + partial_f = partial(compute_loss, solver=mpm, target_vel=target_vel, config=config) + for _ in t: + lo, grads = jax.value_and_grad(partial_f, argnums=0)(params) + updates, opt_state = optimizer.update(grads, opt_state) + params = optax.apply_updates(params, updates) + t.set_description(f"YM: {params['ym']:.2f}") + param_list["ym"].append(params["ym"]) + # param_list["pr"].append(params["pr"]) + loss_list.append(lo) + return param_list, loss_list + + +# params = {"pr": 0.4} +params = {"ym": 1101.0} +# material = init_simple({"E": params, "density": 1, "id": -1}) +material = init_linear_elastic( + { + "youngs_modulus": params["ym"], + "density": 1, + "poisson_ratio": 0, + "id": -1, + } +) +particles = [ + init_particle_state( + config.parsed_config["particles"][0].loc, + material, + config.parsed_config["particles"][0].element_ids, + ) +] + +init_vals = solver.init_state( + { + "elements": config.parsed_config["elements"], + "particles": particles, + "particle_surface_traction": config.parsed_config["particle_surface_traction"], + } +) +param_list, loss_list = optax_adam( + params, 200, solver, true_vel, config +) # ADAM optimizer + +fig, ax = plt.subplots(1, 2, figsize=(16, 6)) +ax[0].plot(param_list["ym"], "ko", markersize=2, label="Youngs Modulus") +ax[0].grid() +ax[0].legend() +ax[1].plot(loss_list, "ko", markersize=2, label="Loss") +ax[1].grid() +ax[1].legend() +# plt.show() +fig.savefig("./examples/optim_uniaxial_stress.png") diff --git a/examples/optim_uniaxial_stress.png b/examples/optim_uniaxial_stress.png new file mode 100644 index 0000000..2a6fccd Binary files /dev/null and b/examples/optim_uniaxial_stress.png differ diff --git a/examples/particles-2d-nodal-force.json b/examples/particles-2d-nodal-force.json deleted file mode 100644 index f0143b8..0000000 --- a/examples/particles-2d-nodal-force.json +++ /dev/null @@ -1,74 +0,0 @@ -[ - [ - [ - 0.025, - 0.025 - ] - ], - [ - [ - 0.075, - 0.025 - ] - ], - [ - [ - 0.125, - 0.025 - ] - ], - [ - [ - 0.175, - 0.025 - ] - ], - [ - [ - 0.225, - 0.025 - ] - ], - [ - [ - 0.275, - 0.025 - ] - ], - [ - [ - 0.025, - 0.075 - ] - ], - [ - [ - 0.075, - 0.075 - ] - ], - [ - [ - 0.125, - 0.075 - ] - ], - [ - [ - 0.175, - 0.075 - ] - ], - [ - [ - 0.225, - 0.075 - ] - ], - [ - [ - 0.275, - 0.075 - ] - ] -] \ No newline at end of file diff --git a/examples/particles-2d-uniaxial-stress.json b/examples/particles-2d-uniaxial-stress.json deleted file mode 100644 index 3b22d51..0000000 --- 
a/examples/particles-2d-uniaxial-stress.json +++ /dev/null @@ -1,26 +0,0 @@ -[ - [ - [ - 0.25, - 0.25 - ] - ], - [ - [ - 0.75, - 0.25 - ] - ], - [ - [ - 0.75, - 0.75 - ] - ], - [ - [ - 0.25, - 0.75 - ] - ] -] \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index ad356c2..6016ce1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,8 +12,8 @@ authors = [ readme = "README.md" version = "0.0.1" dependencies = [ - "jax[cpu]", - "click" + # "jax[cuda11_pip]", + # "click" ] classifiers = [ "Programming Language :: Python :: 3", @@ -26,3 +26,6 @@ mpm = "diffmpm.cli.mpm:mpm" [tool.black] line-length = 88 + +[tool.ruff] +line-length = 88 diff --git a/tests/newtonian.py b/tests/newtonian.py new file mode 100644 index 0000000..518a246 --- /dev/null +++ b/tests/newtonian.py @@ -0,0 +1,90 @@ +import jax.numpy as jnp +import pytest +from diffmpm.constraint import Constraint +from diffmpm.element import Quadrilateral4Node +from diffmpm.materials import Newtonian +from diffmpm.node import Nodes +from diffmpm.particle import Particles + +particles_element_targets = [ + ( + Particles( + jnp.array([[0.5, 0.5]]).reshape(1, 1, 2), + Newtonian( + { + "density": 1000, + "bulk_modulus": 8333333.333333333, + "dynamic_viscosity": 8.9e-4, + } + ), + jnp.array([0]), + ), + Quadrilateral4Node( + (1, 1), + 1, + (4.0, 4.0), + [(0, Constraint(0, 0.02)), (0, Constraint(1, 0.03))], + Nodes(4, jnp.array([-2, -2, 2, -2, -2, 2, 2, 2]).reshape((4, 1, 2))), + ), + jnp.array( + [ + -52083.3333338896, + -52083.3333355583, + -52083.3333305521, + -0.0000041719, + 0, + 0, + ] + ).reshape(1, 6, 1), + ), + ( + Particles( + jnp.array([[0.5, 0.5]]).reshape(1, 1, 2), + Newtonian( + { + "density": 1000, + "bulk_modulus": 8333333.333333333, + "dynamic_viscosity": 8.9e-4, + "incompressible": True, + } + ), + jnp.array([0]), + ), + Quadrilateral4Node( + (1, 1), + 1, + (4.0, 4.0), + [(0, Constraint(0, 0.02)), (0, Constraint(1, 0.03))], + Nodes(4, jnp.array([-2, -2, 2, -2, -2, 2, 2, 2]).reshape((4, 1, 2))), + ), + jnp.array( + [ + -0.0000033375, + -0.00000500625, + 0, + -0.0000041719, + 0, + 0, + ] + ).reshape(1, 6, 1), + ), +] + + +@pytest.mark.parametrize( + "particles, element, target", + particles_element_targets, +) +def test_compute_stress(particles, element, target): + dt = 1 + particles.update_natural_coords(element) + if element.constraints: + element.apply_boundary_constraints() + particles.compute_strain(element, dt) + stress = particles.material.compute_stress(particles) + assert jnp.allclose(stress, target) + + +def test_init(): + with pytest.raises(KeyError): + Newtonian({"dynamic_viscosity": 1, "density": 1}) diff --git a/tests/test_element.py b/tests/test_element.py index 50881d9..72adc6c 100644 --- a/tests/test_element.py +++ b/tests/test_element.py @@ -5,8 +5,9 @@ from diffmpm.element import Quadrilateral4Node from diffmpm.forces import NodalForce from diffmpm.functions import Unit -from diffmpm.material import SimpleMaterial -from diffmpm.particle import Particles +from diffmpm.materials import init_simple +from diffmpm.particle import init_particle_state +import diffmpm.particle as dpar class TestLinear1D: @@ -21,8 +22,8 @@ def elements(self): @pytest.fixture def particles(self): loc = jnp.array([[0.5, 0.5], [0.5, 0.5]]).reshape(2, 1, 2) - material = SimpleMaterial({"E": 1, "density": 1}) - return Particles(loc, material, jnp.array([0, 0])) + material = init_simple({"id": 0, "E": 1, "density": 1}) + return init_particle_state(loc, material, jnp.array([0, 0])) @pytest.mark.parametrize( 
"particle_coords, expected", @@ -99,28 +100,30 @@ def test_element_node_loc(self, elements): assert jnp.all(node_loc == true_loc) def test_element_node_vel(self, elements): - elements.nodes.velocity += jnp.array([1, 1]) + elements.nodes = elements.nodes.replace( + velocity=elements.nodes.velocity + jnp.array([1, 1]) + ) node_vel = elements.id_to_node_vel(0) true_vel = jnp.array([[1.0, 1.0], [1, 1], [1, 1], [1, 1]]).reshape(4, 1, 2) assert jnp.all(node_vel == true_vel) def test_compute_nodal_mass(self, elements, particles): - particles.mass += 1 - elements.compute_nodal_mass(particles) + particles = particles.replace(mass=particles.mass + 1) + nodal_mass, _ = elements.compute_nodal_mass(particles) true_mass = jnp.ones((4, 1, 1)) - assert jnp.all(elements.nodes.mass == true_mass) + assert jnp.all(nodal_mass == true_mass) def test_compute_nodal_momentum(self, elements, particles): - particles.velocity += 1 - elements.compute_nodal_momentum(particles) + particles = particles.replace(velocity=particles.velocity + 1) + nodal_momentum, _ = elements.compute_nodal_momentum(particles) true_momentum = jnp.ones((4, 1, 1)) * 0.5 - assert jnp.all(elements.nodes.momentum == true_momentum) + assert jnp.all(nodal_momentum == true_momentum) def test_compute_external_force(self, elements, particles): - particles.f_ext += 1 - elements.compute_external_force(particles) + particles = particles.replace(f_ext=particles.f_ext + 1) + nodal_f_ext, _ = elements.compute_external_force(particles) true_fext = jnp.ones((4, 1, 1)) * 0.5 - assert jnp.all(elements.nodes.f_ext == true_fext) + assert jnp.all(nodal_f_ext == true_fext) @pytest.mark.parametrize( "gravity, expected", @@ -133,9 +136,9 @@ def test_compute_external_force(self, elements, particles): ], ) def test_compute_body_force(self, elements, particles, gravity, expected): - particles.mass += 1 - elements.compute_body_force(particles, gravity) - assert jnp.all(elements.nodes.f_ext == expected) + particles = particles.replace(mass=particles.mass + 1) + nodal_f_ext, _ = elements.compute_body_force(particles, gravity) + assert jnp.all(nodal_f_ext == expected) def test_apply_concentrated_nodal_force(self, particles): cnf_1 = NodalForce( @@ -153,50 +156,57 @@ def test_apply_concentrated_nodal_force(self, particles): elements = Quadrilateral4Node( (1, 1), 1, 1, [], concentrated_nodal_forces=[cnf_1, cnf_2] ) - elements.apply_concentrated_nodal_forces(particles, 1) + elements.nodes = elements.nodes.replace(f_ext=elements.nodes.f_ext + 2) + nodal_f_ext, _ = elements.apply_concentrated_nodal_forces(particles, 1) assert jnp.all( - elements.nodes.f_ext - == jnp.array([[1, 0], [0, 0], [1, 1], [0, 0]]).reshape(4, 1, 2) + nodal_f_ext == jnp.array([[3, 2], [2, 2], [3, 3], [2, 2]]).reshape(4, 1, 2) ) def test_apply_boundary_constraints(self): - cons = [(jnp.array([0]), Constraint(0, 0))] + cons = [ + (jnp.array([0, 1]), Constraint(0, 0)), + (jnp.array([0]), Constraint(1, 2)), + ] elements = Quadrilateral4Node((1, 1), 1, (1.0, 1.0), cons) - elements.nodes.velocity += 1 - elements.apply_boundary_constraints() + elements.nodes = elements.nodes.replace(velocity=elements.nodes.velocity + 1) + node_state = elements.apply_boundary_constraints() assert jnp.all( - elements.nodes.velocity - == jnp.array([[0, 1], [1, 1], [1, 1], [1, 1]]).reshape(4, 1, 2) + node_state.velocity + == jnp.array([[0, 2], [0, 1], [1, 1], [1, 1]]).reshape(4, 1, 2) ) def test_update_nodal_acceleration_velocity(self, elements, particles): - elements.nodes.f_ext += jnp.array([1, 0]) - elements.nodes.mass = 
elements.nodes.mass.at[:].set(2) - elements.update_nodal_acceleration_velocity(particles, 0.1) + f_ext = elements.nodes.f_ext + jnp.array([1, 0]) + mass = elements.nodes.mass.at[:].set(2) + elements.nodes = elements.nodes.replace(mass=mass, f_ext=f_ext) + nodal_acc, _ = elements.update_nodal_acceleration(particles, 0.1) assert jnp.allclose( - elements.nodes.acceleration, + nodal_acc, jnp.array([[0.5, 0.0], [0.5, 0], [0.5, 0], [0.5, 0]]), ) + nodal_vel, _ = elements.update_nodal_velocity(particles, 0.1) assert jnp.allclose( - elements.nodes.velocity, + nodal_vel, jnp.array([[0.05, 0.0], [0.05, 0], [0.05, 0], [0.05, 0]]), ) + elements.nodes = elements.nodes.replace(velocity=nodal_vel) + nodal_mom, _ = elements.update_nodal_momentum(particles, 0.1) assert jnp.allclose( - elements.nodes.momentum, + nodal_mom, jnp.array([[0.1, 0.0], [0.1, 0], [0.1, 0], [0.1, 0]]), ) def test_set_particle_element_ids(self, elements, particles): - particles.element_ids = jnp.array([-1, -1]) - elements.set_particle_element_ids(particles) + particles = particles.replace(element_ids=jnp.array([-1, -1])) + particles = elements.set_particle_element_ids(particles) assert jnp.all(particles.element_ids == jnp.array([0, 0])) def test_compute_internal_force(self, elements, particles): - particles.compute_volume(elements, 1) - particles.stress += 1 - elements.compute_internal_force(particles) + particles = dpar.compute_volume(particles, elements, 1) + particles = particles.replace(stress=particles.stress + 1) + nodal_f_int, _ = elements.compute_internal_force(particles) assert jnp.allclose( - elements.nodes.f_int, + nodal_f_int, jnp.array([[1, 1], [0, 0], [0, 0], [-1, -1]]).reshape(4, 1, 2), ) @@ -205,9 +215,9 @@ def test_compute_volume(self, elements): assert jnp.allclose(elements.volume, jnp.array([1]).reshape(1, 1, 1)) def test_apply_particle_traction_forces(self, elements, particles): - particles.traction += jnp.array([1, 0]) - elements.apply_particle_traction_forces(particles) + particles = particles.replace(traction=particles.traction + jnp.array([1, 0])) + nodal_f_ext, _ = elements.apply_particle_traction_forces(particles) assert jnp.allclose( - elements.nodes.f_ext, + nodal_f_ext, jnp.array([[0.5, 0], [0.5, 0], [0.5, 0], [0.5, 0]]).reshape(4, 1, 2), ) diff --git a/tests/test_material.py b/tests/test_material.py index 2e041d7..075b9f7 100644 --- a/tests/test_material.py +++ b/tests/test_material.py @@ -1,28 +1,50 @@ import jax.numpy as jnp import pytest +from diffmpm.materials import init_linear_elastic, init_simple +from diffmpm.particle import init_particle_state -from diffmpm.material import LinearElastic, SimpleMaterial - -material_dstrain_stress_targets = [ +particles_dstrain_stress_targets = [ ( - SimpleMaterial({"E": 10, "density": 1}), + init_particle_state( + jnp.array([[0.5, 0.5]]).reshape(1, 1, 2), + init_simple({"id": 0, "E": 10, "density": 1}), + jnp.array([0]), + ), jnp.ones((1, 6, 1)), jnp.ones((1, 6, 1)) * 10, ), ( - LinearElastic({"density": 1, "youngs_modulus": 10, "poisson_ratio": 1}), + init_particle_state( + jnp.array([[0.5, 0.5]]).reshape(1, 1, 2), + init_linear_elastic( + {"id": 0, "density": 1, "youngs_modulus": 10, "poisson_ratio": 1} + ), + jnp.array([0]), + ), jnp.ones((1, 6, 1)), jnp.array([-10, -10, -10, 2.5, 2.5, 2.5]).reshape(1, 6, 1), ), ( - LinearElastic({"density": 1000, "youngs_modulus": 1e7, "poisson_ratio": 0.3}), + init_particle_state( + jnp.array([[0.5, 0.5]]).reshape(1, 1, 2), + init_linear_elastic( + {"id": 0, "density": 1000, "youngs_modulus": 1e7, "poisson_ratio": 0.3} + ), 
+ jnp.array([0]), + ), jnp.array([0.001, 0.0005, 0, 0, 0, 0]).reshape(1, 6, 1), jnp.array([1.63461538461538e4, 12500, 0.86538461538462e4, 0, 0, 0]).reshape( 1, 6, 1 ), ), ( - LinearElastic({"density": 1000, "youngs_modulus": 1e7, "poisson_ratio": 0.3}), + init_particle_state( + jnp.array([[0.5, 0.5]]).reshape(1, 1, 2), + init_linear_elastic( + {"id": 0, "density": 1000, "youngs_modulus": 1e7, "poisson_ratio": 0.3} + ), + jnp.array([0]), + ), jnp.array([0.001, 0.0005, 0, 0.00001, 0, 0]).reshape(1, 6, 1), jnp.array( [1.63461538461538e4, 12500, 0.86538461538462e4, 3.84615384615385e01, 0, 0] @@ -31,7 +53,8 @@ ] -@pytest.mark.parametrize("material, dstrain, target", material_dstrain_stress_targets) -def test_compute_stress(material, dstrain, target): - stress = material.compute_stress(dstrain) +@pytest.mark.parametrize("particles, dstrain, target", particles_dstrain_stress_targets) +def test_compute_stress(particles, dstrain, target): + particles = particles.replace(dstrain=dstrain) + stress = particles.material.compute_stress(particles) assert jnp.allclose(stress, target) diff --git a/tests/test_particle.py b/tests/test_particle.py index d7dedaa..6dd7b1d 100644 --- a/tests/test_particle.py +++ b/tests/test_particle.py @@ -1,21 +1,26 @@ import jax.numpy as jnp +from functools import partial import pytest +from jax import vmap -from diffmpm.element import Quadrilateral4Node -from diffmpm.material import SimpleMaterial -from diffmpm.particle import Particles +from diffmpm.element import Quad4N +from diffmpm.materials import init_simple +from diffmpm.particle import init_particle_state +import diffmpm.particle as dpar class TestParticles: + elementor = Quad4N(total_elements=1) + @pytest.fixture def elements(self): - return Quadrilateral4Node((1, 1), 1, (1.0, 1.0), []) + return self.elementor.init_state((1, 1), 1, (1.0, 1.0), []) @pytest.fixture def particles(self): loc = jnp.array([[0.5, 0.5], [0.5, 0.5]]).reshape(2, 1, 2) - material = SimpleMaterial({"E": 1, "density": 1}) - return Particles(loc, material, jnp.array([0, 0])) + material = init_simple({"id": 0, "E": 1, "density": 1}) + return init_particle_state(loc, material, jnp.array([0, 0])) @pytest.mark.parametrize( "velocity_update, expected", @@ -25,38 +30,72 @@ def particles(self): ], ) def test_update_velocity(self, elements, particles, velocity_update, expected): - particles.update_natural_coords(elements) - elements.nodes.acceleration += 1 - elements.nodes.velocity += 1 - particles.update_position_velocity(elements, 0.1, velocity_update) - assert jnp.allclose(particles.velocity, expected) + dpar.update_natural_coords(particles, elements, self.elementor) + elements.nodes = elements.nodes.replace( + acceleration=elements.nodes.acceleration + 1 + ) + elements.nodes = elements.nodes.replace(velocity=elements.nodes.velocity + 1) + updated = dpar._update_particle_position_velocity( + Quad4N, + particles.loc, + particles.velocity, + particles.momentum, + particles.mass, + particles.reference_loc, + vmap(partial(self.elementor.id_to_node_ids, 1))(particles.element_ids), + elements.nodes.velocity, + elements.nodes.acceleration, + velocity_update, + 0.1, + ) + assert jnp.allclose(updated["velocity"], expected) def test_compute_strain(self, elements, particles): - elements.nodes.velocity = jnp.array([[0, 1], [0, 2], [0, 3], [0, 4]]).reshape( - 4, 1, 2 + elements.nodes = elements.nodes.replace( + velocity=jnp.array([[0, 1], [0, 2], [0, 3], [0, 4]]).reshape(4, 1, 2) ) - particles.update_natural_coords(elements) - particles.compute_strain(elements, 
0.1) - assert jnp.allclose( + particles = dpar.update_natural_coords(particles, elements, self.elementor) + updated = dpar._compute_strain( particles.strain, + particles.reference_loc, + particles.loc, + particles.volumetric_strain_centroid, + particles.nparticles, + vmap(partial(self.elementor.id_to_node_ids, 1))(particles.element_ids), + elements.nodes.loc, + elements.nodes.velocity, + Quad4N, + 0.1, + ) + assert jnp.allclose( + updated["strain"], jnp.array([[0, 0.2, 0, 0.1, 0, 0], [0, 0.2, 0, 0.1, 0, 0]]).reshape( 2, 6, 1 ), ) - assert jnp.allclose(particles.volumetric_strain_centroid, jnp.array([0.2])) + assert jnp.allclose(updated["volumetric_strain_centroid"], jnp.array([0.2])) def test_compute_volume(self, elements, particles): - particles.compute_volume(elements, elements.total_elements) - assert jnp.allclose(particles.volume, jnp.array([0.5, 0.5]).reshape(2, 1, 1)) + props = dpar._compute_particle_volume( + particles.element_ids, + self.elementor.total_elements, + elements.volume, + particles.volume, + particles.size, + particles.mass, + particles.density, + ) + assert jnp.allclose(props["volume"], jnp.array([0.5, 0.5]).reshape(2, 1, 1)) + @pytest.mark.skip() def test_assign_traction(self, elements, particles): - particles.compute_volume(elements, elements.total_elements) - particles.assign_traction(jnp.array([0]), 1, 10) + particles = dpar.compute_volume(particles, elements, elements.total_elements) + traction = dpar.assign_traction(particles, jnp.array([0]), 1, 10) assert jnp.allclose( - particles.traction, jnp.array([[0, 7.071068], [0, 0]]).reshape(2, 1, 2) + traction, jnp.array([[0, 7.071068], [0, 0]]).reshape(2, 1, 2) ) def test_zero_traction(self, particles): - particles.traction += 1 - particles.zero_traction() - assert jnp.all(particles.traction == 0) + particles = particles.replace(traction=particles.traction + 1) + traction = dpar._zero_traction(particles.traction) + assert jnp.all(traction == 0)
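
A minimal usage sketch of the refactored functional particle API follows; it is not part of the diff itself, and it only reuses the Quad4N, init_simple, init_particle_state and diffmpm.particle calls exercised in the updated tests above.

import jax.numpy as jnp
import diffmpm.particle as dpar
from diffmpm.element import Quad4N
from diffmpm.materials import init_simple
from diffmpm.particle import init_particle_state

# Two particles at the centre of a single Quad4N element.
elementor = Quad4N(total_elements=1)
elements = elementor.init_state((1, 1), 1, (1.0, 1.0), [])
material = init_simple({"id": 0, "E": 1, "density": 1})
particles = init_particle_state(
    jnp.array([[0.5, 0.5], [0.5, 0.5]]).reshape(2, 1, 2), material, jnp.array([0, 0])
)

# _ParticlesState is a frozen chex dataclass, so every operation returns a new
# state instead of mutating in place.
particles = dpar.update_natural_coords(particles, elements, elementor)
particles = particles.replace(traction=particles.traction + 1.0)
particles = dpar.zero_traction(particles)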
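As a quick check of the natural-coordinate formula used by update_natural_coords, consider a hypothetical 1D element spanning x1 = 0.0 to x2 = 1.0 (values chosen for illustration, not taken from the diff): a particle at x = 0.75 maps to xi = (2*0.75 - (0.0 + 1.0)) / (1.0 - 0.0) = 0.5.

import jax.numpy as jnp

# Illustrative numbers only: element nodes at x1, x2 and a particle at x.
x, x1, x2 = 0.75, 0.0, 1.0
xi = (2 * x - (x1 + x2)) / (x2 - x1)
assert jnp.isclose(xi, 0.5)  # the midpoint maps to 0, the element edges to -1 and +1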