CARDIAX-NNFE is a GPU-accelerated scientific machine learning framework, built on JAX, for the Neural Network Finite Element (NNFE) method. Besides JAX, the major dependencies are Equinox for building networks, Optax for optimization, and CARDIAX for computing finite element residuals. The package is actively maintained by the Willerson Center for Cardiovascular Modeling and Simulation (WCCMS) and keeps evolving to accommodate the range of problems we are interested in solving. Development targets GPUs only.
Before installing NNFE, install JAX by following the JAX Install instructions. To verify that JAX sees the GPU, run the following and check that CUDA devices are listed:
import jax
print(jax.devices())
Once the JAX installation works, the easiest option is to build all the dependencies in a conda environment from environment.yaml, which also installs JAX with CUDA support. These files set up the PyPI dependencies. CARDIAX is not yet on PyPI, so install it from GitHub at CARDIAX. Then, to install NNFE, clone the repository, change into the ../NNFE directory, and run
pip install -e .
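A slightly more thorough post-install check, sketched here under stated assumptions, filters the JAX device list for GPU devices and looks up the Python dependencies without importing them. The helper names are hypothetical, and cardiax as the import name of CARDIAX is a guess; adjust it to whatever the CARDIAX repository uses.

```python
import importlib.util

import jax


def gpu_devices():
    """Return the JAX devices whose platform is 'gpu' (CUDA devices report 'gpu')."""
    return [d for d in jax.devices() if d.platform == "gpu"]


def missing_packages(names=("equinox", "optax", "cardiax")):
    """Return the package names that cannot be found on the import path.

    'cardiax' is an assumed import name for CARDIAX; change it if it differs.
    """
    return [n for n in names if importlib.util.find_spec(n) is None]


if __name__ == "__main__":
    print("JAX backend:", jax.default_backend())
    print("GPU devices:", gpu_devices() or "none found")
    print("Missing packages:", missing_packages() or "none")
```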
The documentation includes examples that walk through how to use the code. They are under demos and are written in Markdown to explain the functionality; the corresponding *.py files live in the NNFE/demos directory. The main demo is currently the prolate spheroid, which is the illustrative example in the SoftwareX submission.
While JAX supports CPUs, NNFE is not tested in CPU environments. We built this codebase to fully leverage GPUs, but the functionality should remain the same on CPU. Multi-GPU execution is also not available: the problems we currently solve fit in the memory of a single GPU, so we will not develop this parallelization until it is needed.
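One way to try the untested CPU path is to restrict JAX to its CPU backend through the JAX_PLATFORMS environment variable, which must be set before JAX is imported. This is a generic JAX setting rather than an NNFE option, and quadratic_energy is a purely hypothetical function used only to show that jitted code runs on the CPU backend.

```python
import os

# Must be set before JAX initializes its backends; "cpu" restricts JAX to the host CPU.
os.environ["JAX_PLATFORMS"] = "cpu"

import jax
import jax.numpy as jnp


@jax.jit
def quadratic_energy(u):
    # Toy stand-in for a computation one might compare between CPU and GPU runs.
    return 0.5 * jnp.dot(u, u)


print(jax.default_backend())                 # cpu
print(quadratic_energy(jnp.arange(4.0)))     # 7.0
```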
The finite element limitations are inherited from CARDIAX. The rule is: if a problem can be solved with traditional FE in CARDIAX, then you can attempt to train a network to solve the parameterized version of that problem.
This project is licensed under the GNU General Public License v3; see the LICENSE file for details.
If you use this project, please cite this work here.
We will add a list of other papers built upon this framework below:
