diff --git a/examples/elastoplasticity/mechanical_2d_elastoplasticity.py b/examples/elastoplasticity/mechanical_2d_elastoplasticity.py new file mode 100644 index 0000000..c8d2607 --- /dev/null +++ b/examples/elastoplasticity/mechanical_2d_elastoplasticity.py @@ -0,0 +1,845 @@ +import sys +import os +import shutil +import jax +import numpy as np + +from fol.loss_functions.mechanical_elastoplasticity import ElastoplasticityLoss2DQuad +from fol.mesh_input_output.mesh import Mesh +from fol.controls.voronoi_control2D import VoronoiControl2D +from fol.solvers.fe_nonlinear_residual_based_solver_with_history_update import FiniteElementNonLinearResidualBasedSolverWithStateUpdate +from fol.tools.usefull_functions import * +from fol.tools.logging_functions import Logger +import matplotlib.pyplot as plt +import pickle, time + + + +def main(solve_FE=True, clean_dir=False): + # directory & save handling + working_directory_name = 'mechanical_2d_elastoplasticity' + case_dir = os.path.join('.', working_directory_name) + create_clean_directory(working_directory_name) + sys.stdout = Logger(os.path.join(case_dir, working_directory_name + ".log")) + + # problem setup + model_settings = { + "L": 1, "N": 40, + "Ux_left": 0.0, "Ux_right": 0.08 + , + "Uy_left": 0.0, "Uy_right": 0.08 + } + + # creation of the model + fe_mesh = create_2D_square_mesh(L=model_settings["L"], N=model_settings["N"]) + + # create FE-based loss function (Dirichlet BCs and material params for stress viz) + bc_dict = { + "Ux": {"left": model_settings["Ux_left"], "right": model_settings["Ux_right"]}, + "Uy": {"left": model_settings["Uy_left"], "right": model_settings["Uy_right"]} + } + material_dict = {"young_modulus": 3.0, "poisson_ratio": 0.3, "iso_hardening_parameter_1": 0.4, "iso_hardening_param_2" :10.0, "yield_limit" :0.2} + + mechanical_loss_2d = ElastoplasticityLoss2DQuad( + "mechanical_loss_2d", + loss_settings={ + "dirichlet_bc_dict": bc_dict, + "num_gp": 2, + "material_dict": material_dict + }, + fe_mesh=fe_mesh + ) + + with open(f'voroni_control_dict.pkl', 'rb') as f: + voronoi_control_settings = pickle.load(f) + voronoi_control = VoronoiControl2D("first_voronoi_control",voronoi_control_settings,fe_mesh) + + fe_mesh.Initialize() + mechanical_loss_2d.Initialize() + voronoi_control.Initialize() + + coeffs_matrix = voronoi_control_settings["coeffs_matrix"] + K_matrix = voronoi_control.ComputeBatchControlledVariables(coeffs_matrix) + + # specify id of the K of interest + eval_id = 25 + + # classical FE solve (no ML) + if solve_FE: + fe_setting = { + "linear_solver_settings": { + "solver": "JAX-direct", + "tol": 1e-6, + "atol": 1e-6, + "maxiter": 1000, + "pre-conditioner": "ilu" + }, + "nonlinear_solver_settings": { + "rel_tol": 1e-5, + "abs_tol": 1e-5, + "maxiter": 100, + "load_incr": 100 + } + } + + nonlinear_fe_solver = FiniteElementNonLinearResidualBasedSolverWithStateUpdate( + "nonlinear_fe_solver", + mechanical_loss_2d, + fe_setting, + history_plot_settings={"plot":True,"save_directory":case_dir} + ) + nonlinear_fe_solver.Initialize() + + + # Solve for the chosen K-field and zero initial guess + load_steps_solutions, load_steps_states, solution_history_dict = nonlinear_fe_solver.Solve( + K_matrix[eval_id], np.zeros(2 * fe_mesh.GetNumberOfNodes()),return_all_steps=True) + + # @Rishabh Please adjust the rest accordingly from here + + exit() + + # 3) Scatter-add to global nodes and average at shared nodes + nelem = fe_mesh.GetNumberOfElements(mechanical_loss_2d.element_type) + nnod = fe_mesh.GetNumberOfNodes() + conn = fe_mesh.GetElementsNodes(mechanical_loss_2d.element_type) + gp_points, gp_weights = mechanical_loss_2d.fe_element.GaussIntegration2() + H = jnp.stack([mechanical_loss_2d.fe_element.ShapeFunctionsValues(p) for p in gp_points], axis=0) # (4,4) + + def extrapolate_gp_to_nodes_vectorized(gp_data, H_matrix, conn_jax, nnod): + """ + Vectorized extrapolation from Gauss points to nodes + + Args: + gp_data: (nelem, ngp) array - values at Gauss points + H_matrix: (ngp, nnodes_per_elem) array - shape functions at Gauss points + conn_jax: (nelem, nnodes_per_elem) array - connectivity + nnod: int - total number of nodes + + Returns: + (nnod,) array - nodal values averaged at shared nodes + """ + + # Solve H.T @ nodal = gp for each element (avoids computing inverse) + # Vectorized over all elements using vmap + def solve_element(gp_elem): + return jnp.linalg.solve(H_matrix.T, gp_elem) + + elem_nodal = jax.vmap(solve_element)(gp_data) # (nelem, nnodes_per_elem) + + # Vectorized scatter-add using segment_sum (more efficient than .at[].add()) + node_indices = conn_jax.flatten() # (nelem * nnodes_per_elem,) + values = elem_nodal.flatten() # (nelem * nnodes_per_elem,) + + # Accumulate contributions to each node + nodal_sum = jax.ops.segment_sum(values, node_indices, num_segments=nnod) + nodal_count = jnp.bincount(node_indices, length=nnod) + + return nodal_sum / jnp.maximum(nodal_count, 1.0) + + def extrapolate_gp_to_nodes_vectorized_tensor(gp_data, H_matrix, conn_jax, nnod): + """ + Vectorized extrapolation for tensor quantities (e.g., strain, stress). + + Args: + gp_data: (nelem, ngp, ...) array - tensor values at Gauss points + H_matrix: (ngp, nnodes_per_elem) array + conn_jax: (nelem, nnodes_per_elem) array + nnod: int + + Returns: + (nnod, ...) array - nodal tensor values + """ + nelem, ngp = gp_data.shape[:2] + extra_dims = gp_data.shape[2:] # e.g., (2, 2) for strain + nnodes_per_elem = H_matrix.shape[1] + + def solve_element(gp_elem): + # gp_elem: (ngp, *extra_dims) -> (nnodes_per_elem, *extra_dims) + flat_shape = (ngp, -1) + gp_flat = gp_elem.reshape(flat_shape) # (ngp, prod(extra_dims)) + nodal_flat = jnp.linalg.solve(H_matrix.T, gp_flat) # (nnodes_per_elem, prod(extra_dims)) + return nodal_flat.reshape(nnodes_per_elem, *extra_dims) + + elem_nodal = jax.vmap(solve_element)(gp_data) # (nelem, nnodes_per_elem, *extra_dims) + + # Flatten for scatter-add + elem_nodal_flat = elem_nodal.reshape(-1, *extra_dims) # (nelem*nnodes_per_elem, *extra_dims) + node_indices = conn_jax.flatten() + + # Accumulate using segment_sum with extra dimensions + nodal_sum = jax.ops.segment_sum(elem_nodal_flat, node_indices, num_segments=nnod) + nodal_count = jnp.bincount(node_indices, length=nnod) + + # Broadcast count for division + count_shape = (nnod,) + (1,) * len(extra_dims) + nodal_count = nodal_count.reshape(count_shape) + + return nodal_sum / jnp.maximum(nodal_count, 1.0) + + def plot_mesh_res_1(vectors_list:list, file_name:str="plot",loss_settings:dict={}): + fontsize = 16 + fig, axs = plt.subplots(3, 3, figsize=(20, 12)) # Adjusted to 4 columns + + # Plot the first entity in the first row + data = vectors_list[0] + N = int((data.reshape(-1, 1).shape[0]) ** 0.5) + im = axs[0, 0].imshow(data.reshape(N, N), cmap='viridis', aspect='equal') + axs[0, 0].set_xticks([]) + axs[0, 0].set_yticks([]) + axs[0, 0].set_title('Elasticity Morph.', fontsize=fontsize) + cbar = fig.colorbar(im, ax=axs[0, 0], pad=0.02, shrink=0.7) + cbar.ax.tick_params(labelsize=fontsize) + cbar.ax.yaxis.labelpad = 5 + cbar.ax.tick_params(length=5, width=1) + + # Plot the same entity with mesh grid in the first row, second column + im = axs[0, 1].imshow(data.reshape(N, N), cmap='bone', aspect='equal',extent=[0, N, 0, N]) + axs[0, 1].set_xticks([]) + axs[0, 1].set_yticks([]) + axs[0, 1].set_xticklabels([]) # Remove text on x-axis + axs[0, 1].set_yticklabels([]) # Remove text on y-axis + axs[0, 1].set_title(f'Mesh Grid: {N} x {N}', fontsize=fontsize) + axs[0, 1].grid(True, color='red', linestyle='-', linewidth=1) # Adding solid grid lines with red color + axs[0, 1].xaxis.grid(True) + axs[0, 1].yaxis.grid(True) + + x_ticks = np.linspace(0, N, N) + y_ticks = np.linspace(0, N, N) + axs[0, 1].set_xticks(x_ticks) + axs[0, 1].set_yticks(y_ticks) + + cbar = fig.colorbar(im, ax=axs[0, 1], pad=0.02, shrink=0.7) + cbar.ax.tick_params(labelsize=fontsize) + cbar.ax.yaxis.labelpad = 5 + cbar.ax.tick_params(length=5, width=1) + + # Plot the fourth entity in the second row + data = vectors_list[1][::2] + im = axs[0, 2].imshow(data.reshape(N, N), cmap='coolwarm', aspect='equal',origin='lower') + axs[0, 2].set_xticks([]) + axs[0, 2].set_yticks([]) + axs[0, 2].set_title(f'U FEM', fontsize=fontsize) + cbar = fig.colorbar(im, ax=axs[0, 2], pad=0.02, shrink=0.7) + cbar.ax.tick_params(labelsize=fontsize) + cbar.ax.yaxis.labelpad = 5 + cbar.ax.tick_params(length=5, width=1) + + # Plot the fourth entity in the second row + data = vectors_list[1][1::2] + im = axs[1,0].imshow( + data.reshape(N, N), # or see note below re: transpose + cmap='coolwarm', + aspect='equal', + origin='lower', # <-- key line + ) + axs[1,0].set_xticks([]) + axs[1,0].set_yticks([]) + axs[1,0].set_title(f'V FEM', fontsize=fontsize) + cbar = fig.colorbar(im, ax=axs[1, 0], pad=0.02, shrink=0.7) + cbar.ax.tick_params(labelsize=fontsize) + cbar.ax.yaxis.labelpad = 5 + cbar.ax.tick_params(length=5, width=1) + + data = vectors_list[0] + L = 1 + N = int((data.reshape(-1, 1).shape[0])**0.5) + nu = loss_settings["poisson_ratio"] + e = loss_settings["young_modulus"] + + # PLANE STRAIN constitutive constants (not plane stress!) + factor = e / ((1 + nu) * (1 - 2*nu)) + c11 = factor * (1 - nu) + c12 = factor * nu + + dx = L / (N - 1) + # Get the data + data = vectors_list[1] + U_fem = data[::2] + V_fem = data[1::2] + coords = vectors_list[3] + plastic_strain_nodal = vectors_list[2] + + # 1) Build a consistent node ordering by (y, x) + x = coords[:, 0] + y = coords[:, 1] + idx = np.lexsort((x, y)) + + # 2) Apply the SAME idx to all nodal fields + U_sorted = U_fem[idx] + V_sorted = V_fem[idx] + exx_plastic_sorted = plastic_strain_nodal[idx, 0, 0] + eyy_plastic_sorted = plastic_strain_nodal[idx, 1, 1] + domain_map_sorted = vectors_list[0][idx] + + # 3) Build grid sizes from unique coordinates + x_unique = np.unique(x) + y_unique = np.unique(y) + Nx_grid = len(x_unique) + Ny_grid = len(y_unique) + + # 4) Reshape EVERYTHING consistently as (Ny, Nx) + U_grid = U_sorted.reshape(Ny_grid, Nx_grid) + V_grid = V_sorted.reshape(Ny_grid, Nx_grid) + exx_plastic = exx_plastic_sorted.reshape(Ny_grid, Nx_grid) + eyy_plastic = eyy_plastic_sorted.reshape(Ny_grid, Nx_grid) + domain_map_matrix = domain_map_sorted.reshape(Ny_grid, Nx_grid) + + # 5) Compute grid spacings from actual coordinates + dx = np.diff(x_unique).mean() if Nx_grid > 1 else 1.0 + dy = np.diff(y_unique).mean() if Ny_grid > 1 else 1.0 + + + # 6) Total SMALL strains from displacement gradients + dU_dx = np.gradient(U_grid, dx, axis=1) + dV_dy = np.gradient(V_grid, dy, axis=0) + + exx_total = dU_dx + eyy_total = dV_dy + + + # 7) Elastic = Total − Plastic + exx_elastic = exx_total - exx_plastic + eyy_elastic = eyy_total - eyy_plastic + + + # 8) Plane strain stresses + factor = e / ((1 + nu) * (1 - 2*nu)) + c11 = factor * (1 - nu) + c12 = factor * nu + + stress_xx_fem = domain_map_matrix * (c11 * exx_elastic + c12 * eyy_elastic) + stress_yy_fem = domain_map_matrix * (c12 * exx_elastic + c11 * eyy_elastic) + + # 9) Plot + im = axs[1, 1].imshow(stress_xx_fem, cmap='plasma', origin='lower') + axs[1, 1].set_xticks([]) + axs[1, 1].set_yticks([]) + axs[1, 1].set_title('$\sigma_{xx}$, FEM', fontsize=fontsize) + cbar = fig.colorbar(im, ax=axs[1, 1], pad=0.02, shrink=0.7) + cbar.ax.tick_params(labelsize=fontsize) + cbar.ax.yaxis.labelpad = 5 + cbar.ax.tick_params(length=5, width=1) + + im = axs[1, 2].imshow(stress_yy_fem, cmap='plasma', origin='lower') + axs[1, 2].set_xticks([]) + axs[1, 2].set_yticks([]) + axs[1, 2].set_title('$\sigma_{yy}$, FEM', fontsize=fontsize) + cbar = fig.colorbar(im, ax=axs[1, 2], pad=0.02, shrink=0.7) + cbar.ax.tick_params(labelsize=fontsize) + cbar.ax.yaxis.labelpad = 5 + cbar.ax.tick_params(length=5, width=1) + + # ---- compute equivalent strain per node ---- + strain_nodes=vectors_list[2] + #eq_nodes=von_mises_equivalent_strain(strain_nodes) + eq_nodes=strain_nodes[:,0,0] + # --- sort nodes onto a structured (Ny, Nx) grid --- + # sort by y first, then x, so within each row (fixed y) x increases + coords=vectors_list[3] + U_full=vectors_list[4] + idx = np.lexsort((coords[:, 0], coords[:, 1])) + coords_s = coords[idx] + U_s = U_full[idx] + vals_s = np.asarray(eq_nodes, dtype=float)[idx] + + # infer grid shape from unique coordinates + x_unique = np.unique(coords_s[:, 0]) + y_unique = np.unique(coords_s[:, 1]) + Nx, Ny = len(x_unique), len(y_unique) + + + # reshape to (Ny, Nx): y is slow index (rows), x is fast index (cols) + Z = vals_s.reshape(Ny, Nx) + Ux = U_s[:, 0].reshape(Ny, Nx) + Uy = U_s[:, 1].reshape(Ny, Nx) + + # build *undeformed* grid from unique x/y + X, Y = np.meshgrid(x_unique, y_unique, indexing='xy') + + # deformed grid + def_scale = 2.0 + X_def = X + def_scale * Ux + Y_def = Y + def_scale * Uy + + contour = axs[2,0].contourf(X_def, Y_def, Z, levels=20, cmap='viridis') + cbar = fig.colorbar(contour, ax=axs[2,0],pad=0.02, shrink=0.7) + cbar.set_label(r'$\varepsilon_{\mathrm{xx}}$') + + axs[2,0].set_aspect('equal', adjustable='box') + axs[2,0].set_xlabel('x'); axs[1,2].set_ylabel('y') + axs[2,0].set_title('Plastic strain in XX direction') + + eq_nodes=strain_nodes[:,1,1] + # --- sort nodes onto a structured (Ny, Nx) grid --- + # sort by y first, then x, so within each row (fixed y) x increases + coords=vectors_list[3] + U_full=vectors_list[4] + idx = np.lexsort((coords[:, 0], coords[:, 1])) + coords_s = coords[idx] + U_s = U_full[idx] + vals_s = np.asarray(eq_nodes, dtype=float)[idx] + + # infer grid shape from unique coordinates + x_unique = np.unique(coords_s[:, 0]) + y_unique = np.unique(coords_s[:, 1]) + Nx, Ny = len(x_unique), len(y_unique) + + + # reshape to (Ny, Nx): y is slow index (rows), x is fast index (cols) + Z = vals_s.reshape(Ny, Nx) + Ux = U_s[:, 0].reshape(Ny, Nx) + Uy = U_s[:, 1].reshape(Ny, Nx) + + # build *undeformed* grid from unique x/y + X, Y = np.meshgrid(x_unique, y_unique, indexing='xy') + + # deformed grid + def_scale = 2.0 + X_def = X + def_scale * Ux + Y_def = Y + def_scale * Uy + + + contour = axs[2,1].contourf(X_def, Y_def, Z, levels=20, cmap='viridis') + cbar = fig.colorbar(contour, ax=axs[2,1],pad=0.02, shrink=0.7) + cbar.set_label(r'$\varepsilon_{\mathrm{yy}}$') + + axs[2,1].set_aspect('equal', adjustable='box') + axs[2,1].set_xlabel('x'); axs[2,1].set_ylabel('y') + axs[2,1].set_title('Plastic strain in YY direction') + + + + # Simple scatter plot with interpolation + scatter = axs[2,2].tricontourf(coords[:, 0], coords[:, 1], np.array(vectors_list[5]), + levels=20, cmap='plasma') + cbar = plt.colorbar(scatter, ax=axs[2,2]) + cbar.set_label('Accumulated Plastic Strain (ξ)', fontsize=12) + + # Add contour lines + axs[2,2].tricontour(coords[:, 0], coords[:, 1], np.array(vectors_list[5]), + levels=20, colors='k', alpha=0.3, linewidths=0.5) + + # Plot mesh edges + for e in range(nelem): + nodes = conn[e] + for i in range(4): + n1 = nodes[i] + n2 = nodes[(i + 1) % 4] + axs[2,2].plot([coords[n1, 0], coords[n2, 0]], + [coords[n1, 1], coords[n2, 1]], + 'k-', linewidth=0.3, alpha=0.5) + + axs[2,2].set_xlabel('X', fontsize=12) + axs[2,2].set_ylabel('Y', fontsize=12) + axs[2,2].set_title('Accumulated Plastic Strain (ξ) - Final Increment', + fontsize=14, fontweight='bold') + axs[2,2].set_aspect('equal') + axs[2,2].grid(True, alpha=0.3) + + plt.tight_layout() + + plt.savefig(file_name, dpi=300) + + def plot_all_convergence_metrics(convergence_history, file_name, abs_tol=None, rel_tol=None): + """ + Convergence plot with 2 subplots: Iterations per Load Step and Final Residual vs Load Factor + """ + fig = plt.figure(figsize=(14, 6)) + gs = fig.add_gridspec(1, 2, hspace=0.3, wspace=0.3) + + load_steps = [h['load_step'] for h in convergence_history] + iterations = [h['iterations'] for h in convergence_history] + final_residuals = [h['final_residual'] for h in convergence_history] + load_factors = [h['load_factor'] for h in convergence_history] + + # 1. Iterations per load step (bar chart) + ax1 = fig.add_subplot(gs[0, 0]) + bars = ax1.bar(load_steps, iterations, color='steelblue', alpha=0.7, edgecolor='black') + ax1.set_xlabel('Load Step', fontsize=11, fontweight='bold') + ax1.set_ylabel('Iterations', fontsize=11, fontweight='bold') + ax1.set_title('Iterations per Load Step', fontsize=12, fontweight='bold') + ax1.grid(True, alpha=0.3, linestyle='--') + for bar, itr in zip(bars, iterations): + height = bar.get_height() + ax1.text(bar.get_x() + bar.get_width()/2., height, + f'{int(itr)}', ha='center', va='bottom', fontsize=9) + + # 2. Final residual vs load factor + ax2 = fig.add_subplot(gs[0, 1]) + ax2.plot(load_factors, final_residuals, 'o-', color='darkgreen', + linewidth=2, markersize=8, markerfacecolor='lightgreen') + if abs_tol: + ax2.axhline(y=abs_tol, color='r', linestyle='--', linewidth=2, label='Abs Tol') + ax2.set_xlabel('Load Factor (λ)', fontsize=11, fontweight='bold') + ax2.set_ylabel('Final Residual Norm', fontsize=11, fontweight='bold') + ax2.set_title('Final Residual vs Load Factor', fontsize=12, fontweight='bold') + ax2.grid(True, alpha=0.3) + ax2.set_yscale('log') + if abs_tol: + ax2.legend() + + # Add summary statistics + total_itr = sum(iterations) + avg_itr = np.mean(iterations) + stats_text = f'Total: {total_itr} | Avg: {avg_itr:.1f} | Max: {max(iterations)}' + fig.text(0.5, 0.02, stats_text, ha='center', fontsize=11, + bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.7)) + + plt.savefig(file_name, dpi=300, bbox_inches='tight') + plt.close() + + + def plot_residual_convergence(convergence_history, file_name, rel_tol=None): + """ + Separate plot for residual norm vs iteration (cumulative x-axis across all load steps) + """ + fig, ax = plt.subplots(figsize=(12, 7)) + + # Color map for different load steps + colors = plt.cm.viridis(np.linspace(0, 1, len(convergence_history))) + + # Plot each step one after another on the x-axis + x_start = 1 + for i, history in enumerate(convergence_history): + residuals = np.asarray(history['residuals']) + n = len(residuals) + x = np.arange(x_start, x_start + n) # cumulative x + ax.semilogy( + x, residuals, 'o-', color=colors[i], + linewidth=2, markersize=5, alpha=0.8, + label=f"Step {history['load_step']}" + ) + x_start += n # next step continues where this one ended + + # Tolerance line + if rel_tol is not None: + ax.axhline(y=rel_tol, color='r', linestyle='--', linewidth=2, label='Rel Tol') + + # Labels & title + ax.set_xlabel('Cumulative Iteration', fontsize=12, fontweight='bold') + ax.set_ylabel('Residual Norm', fontsize=12, fontweight='bold') + ax.set_title('Residual Convergence Curves (Log Scale)', fontsize=14, fontweight='bold') + + + # Grid + ax.grid(True, which='both', linewidth=0.5, alpha=0.3) + + plt.tight_layout() + plt.savefig(file_name, dpi=300, bbox_inches='tight') + plt.close() + + + + + def plot_reac_disp(vectors_list: list, state_snapshots: list, conn, H_matrix, + displacement_history, file_name: str = "plot", + loss_settings: dict = {}): + """ + OPTIMIZED: Uses vectorized JAX operations instead of Python loops + Computes reaction force on LEFT boundary (fixed side) + and plots vs equivalent displacement sqrt(Ux^2 + Uy^2) on RIGHT side + """ + + # Extract material properties + nu = loss_settings["poisson_ratio"] + e = loss_settings["young_modulus"] + + # Material stiffness (plane strain) + factor = e / ((1 + nu) * (1 - 2*nu)) + c11 = factor * (1 - nu) + c12 = factor * nu + c33 = factor * (1 - 2*nu) / 2 + + # Extract from vectors_list + coords = vectors_list[3] + Ux_right = vectors_list[6] + Uy_right = vectors_list[7] + + # Get mesh info + nelem = conn.shape[0] + nnod = coords.shape[0] + + # Convert conn to JAX array for vectorized operations + conn_jax = jnp.array(conn, dtype=jnp.int32) + + # Boundary nodes + y_coords = coords[:, 1] + x_coords = coords[:, 0] + x_min, x_max = x_coords.min(), x_coords.max() + + tol = 1e-6 + left_nodes = np.where(np.abs(x_coords - x_min) < tol)[0] + right_nodes = np.where(np.abs(x_coords - x_max) < tol)[0] + + + # ================================================================== + # PRE-COMPUTE: Strain computation function (vectorized) + # ================================================================== + def compute_total_strain_at_nodes_vectorized(U_k, coords, conn_jax): + """Vectorized strain computation - much faster than element loops""" + nelem = conn_jax.shape[0] + nnod = coords.shape[0] + + # Reshape displacement vector + U_nodes = U_k.reshape((nnod, 2)) + + # Get element nodal coordinates and displacements + elem_coords = coords[conn_jax] # (nelem, 4, 2) + elem_U = U_nodes[conn_jax] # (nelem, 4, 2) + + # Compute strain at element centroids using simple averaging + # This is approximate but much faster than full FE strain computation + # For more accuracy, you'd need to implement full shape function derivatives + + # Average nodal strains (simple finite difference approach) + nodal_exx = jnp.zeros(nnod) + nodal_eyy = jnp.zeros(nnod) + nodal_exy = jnp.zeros(nnod) + nodal_count = jnp.zeros(nnod) + + # For each element, compute approximate strain and distribute to nodes + for e in range(nelem): + nodes = conn_jax[e] + x = elem_coords[e, :, 0] + y = elem_coords[e, :, 1] + ux = elem_U[e, :, 0] + uy = elem_U[e, :, 1] + + # Simple finite difference (this is a rough approximation) + dx = x.max() - x.min() + dy = y.max() - y.min() + + if dx > 1e-10: + exx = (ux.max() - ux.min()) / dx + else: + exx = 0.0 + + if dy > 1e-10: + eyy = (uy.max() - uy.min()) / dy + else: + eyy = 0.0 + + exy = 0.0 + if dx > 1e-10 and dy > 1e-10: + exy = 0.5 * ((ux.max() - ux.min()) / dy + (uy.max() - uy.min()) / dx) + + # Distribute to element nodes + for a in range(4): + n = nodes[a] + nodal_exx = nodal_exx.at[n].add(exx) + nodal_eyy = nodal_eyy.at[n].add(eyy) + nodal_exy = nodal_exy.at[n].add(exy) + nodal_count = nodal_count.at[n].add(1.0) + + nodal_exx = nodal_exx / jnp.maximum(nodal_count, 1.0) + nodal_eyy = nodal_eyy / jnp.maximum(nodal_count, 1.0) + nodal_exy = nodal_exy / jnp.maximum(nodal_count, 1.0) + + return nodal_exx, nodal_eyy, nodal_exy + + # ================================================================== + # OPTIMIZED: Vectorized scatter-add for plastic strain + # ================================================================== + equiv_disp_history = [] + reaction_x_history = [] + reaction_y_history = [] + reaction_total_history = [] + + # Process each load increment + for k, U_k in enumerate(displacement_history): + + # ================================================================== + # STEP 1: Extrapolate PLASTIC strain from GP to nodes (VECTORIZED) + # ================================================================== + state_k = state_snapshots[k] # (nelem, ngp, nstate) + exx_plastic_gp = state_k[..., 0] + eyy_plastic_gp = state_k[..., 1] + exy_plastic_gp = state_k[..., 2] + + # Vectorized extrapolation + nodal_exx_plastic = extrapolate_gp_to_nodes_vectorized( + exx_plastic_gp, H_matrix, conn_jax, nnod) + nodal_eyy_plastic = extrapolate_gp_to_nodes_vectorized( + eyy_plastic_gp, H_matrix, conn_jax, nnod) + nodal_exy_plastic = extrapolate_gp_to_nodes_vectorized( + exy_plastic_gp, H_matrix, conn_jax, nnod) + + # ================================================================== + # STEP 2: Compute TOTAL strain at nodes from displacements + # ================================================================== + nodal_exx_total, nodal_eyy_total, nodal_exy_total = \ + compute_total_strain_at_nodes_vectorized(U_k, coords, conn_jax) + + # ================================================================== + # STEP 3: Compute ELASTIC strain at nodes + # ================================================================== + nodal_exx_elastic = nodal_exx_total - nodal_exx_plastic + nodal_eyy_elastic = nodal_eyy_total - nodal_eyy_plastic + nodal_exy_elastic = nodal_exy_total - nodal_exy_plastic + + # ================================================================== + # STEP 4: Compute STRESS at nodes from elastic strain + # ================================================================== + nodal_sxx = c11 * nodal_exx_elastic + c12 * nodal_eyy_elastic + nodal_syy = c12 * nodal_exx_elastic + c11 * nodal_eyy_elastic + nodal_sxy = 2 * c33 * nodal_exy_elastic + + # Convert to numpy for integration + nodal_sxx = np.array(nodal_sxx) + nodal_syy = np.array(nodal_syy) + nodal_sxy = np.array(nodal_sxy) + + # ================================================================== + # STEP 5: Integrate reaction force on LEFT boundary + # ================================================================== + reaction_x = 0.0 + reaction_y = 0.0 + + if len(left_nodes) > 0: + y_left = coords[left_nodes, 1] + sxx_left = nodal_sxx[left_nodes] + sxy_left = nodal_sxy[left_nodes] + + # Sort by y-coordinate + sort_idx = np.argsort(y_left) + y_left = y_left[sort_idx] + sxx_left = sxx_left[sort_idx] + sxy_left = sxy_left[sort_idx] + + # Integrate + reaction_x = -np.trapezoid(sxx_left, y_left) + reaction_y = -np.trapezoid(sxy_left, y_left) + + reaction_total = np.sqrt(reaction_x**2 + reaction_y**2) + + # ================================================================== + # STEP 6: Compute equivalent displacement on RIGHT boundary + # ================================================================== + if len(right_nodes) > 0: + Ux_right_nodes = U_k[2*right_nodes] + Uy_right_nodes = U_k[2*right_nodes + 1] + Ux_mean = float(Ux_right_nodes.mean()) + Uy_mean = float(Uy_right_nodes.mean()) + U_equiv = np.sqrt(Ux_mean**2 + Uy_mean**2) + else: + scale_factor = (k + 1) / len(displacement_history) + Ux_mean = scale_factor * Ux_right + Uy_mean = scale_factor * Uy_right + U_equiv = np.sqrt(Ux_mean**2 + Uy_mean**2) + + equiv_disp_history.append(float(U_equiv)) + reaction_x_history.append(float(reaction_x)) + reaction_y_history.append(float(reaction_y)) + reaction_total_history.append(float(reaction_total)) + + + # ================================================================== + # STEP 7: Create single plot - Total Reaction vs Equivalent Displacement + # ================================================================== + plt.figure(figsize=(10, 7)) + + # Plot with dots at data points and connecting lines (no hollow circles) + plt.plot(equiv_disp_history, reaction_total_history, '-', linewidth=2, + color='#2E86AB', label='Total Reaction Force') + plt.plot(equiv_disp_history, reaction_total_history, '.', markersize=8, + color='#2E86AB') + + plt.xlabel('Equivalent Displacement √(Ux² + Uy²)', fontsize=13, fontweight='bold') + plt.ylabel('Total Reaction Force (Left Boundary)', fontsize=13, fontweight='bold') + plt.title('Total Reaction Force vs Equivalent Displacement', fontsize=15, fontweight='bold') + plt.grid(True, alpha=0.3, linestyle='--') + plt.axhline(y=0, color='k', linestyle='-', alpha=0.2, linewidth=0.8) + plt.axvline(x=0, color='k', linestyle='-', alpha=0.2, linewidth=0.8) + plt.legend(fontsize=11, loc='best') + plt.tight_layout() + plt.savefig(file_name, dpi=300, bbox_inches='tight') + plt.close() + + + snap = np.stack(state_snapshots, axis=0) + + # Use only the first three comps (0: exx, 1: eyy, 2: exy). Ignore comp 3+. + exx = snap[..., 0] # (n_inc, n_elem, n_gp) + eyy = snap[..., 1] + exy = snap[..., 2] + + # Allocate output + e_hist = np.zeros(snap.shape[:3] + (2, 2), dtype=snap.dtype) + + # Fill symmetric 2×2 tensor + e_hist[..., 0, 0] = exx + e_hist[..., 1, 1] = eyy + e_hist[..., 0, 1] = exy + e_hist[..., 1, 0] = exy + + + # 1) Build E from your own FE element APIs (robust to reordering) + + last_inc = e_hist[-1] + strain_nodes = extrapolate_gp_to_nodes_vectorized_tensor(last_inc, H, conn, nnod) + coords = np.asarray(fe_mesh.GetNodesCoordinates())[:, :2] # (nnodes, 2) + U_full = FE_UV.reshape((fe_mesh.GetNumberOfNodes(), 2)) # (nnodes, 2) + fe_mesh['U_FE'] = FE_UV.reshape((fe_mesh.GetNumberOfNodes(), 2)) + + + xi_history = [snapshot[..., 3] for snapshot in state_snapshots] + xi_history = np.array(xi_history) + xi_final = xi_history[-1] # (n_elem, n_gp) + xi_nodes = extrapolate_gp_to_nodes_vectorized(xi_final, H, conn, nnod) + vectors_list_U = [K_matrix[eval_id],FE_UV,strain_nodes,coords,U_full,xi_nodes] + + + plot_mesh_res_1( + vectors_list_U, + file_name=os.path.join(case_dir, 'plot_inputs_outputs.png'), + loss_settings=material_dict + ) + + n_incr = fe_setting["nonlinear_solver_settings"]["load_incr"] + Ux_right = float(model_settings["Ux_right"]) + Uy_right = float(model_settings["Uy_right"]) + + # MODIFIED: Pass both Ux_right and Uy_right + vectors_list_U_1 = [ + K_matrix[eval_id], FE_UV, strain_nodes, coords, U_full, + n_incr, Ux_right, Uy_right + ] + + plot_reac_disp( + vectors_list_U_1, state_snapshots, conn, H, displacement_history, + file_name=os.path.join(case_dir, 'plot_reaction_force.png'), + loss_settings=material_dict + ) + # NEW: Comprehensive convergence plots + abs_tol = nonlinear_fe_solver.nonlinear_solver_settings.get("abs_tol", None) + rel_tol = nonlinear_fe_solver.nonlinear_solver_settings.get("rel_tol", None) + + # First plot: 2 metrics + plot_all_convergence_metrics( + convergence_history, + file_name=os.path.join(case_dir, 'convergence_metrics.png'), + abs_tol=abs_tol, + rel_tol=rel_tol + ) + + # Second plot: Residual curves + plot_residual_convergence( + + convergence_history, + file_name=os.path.join(case_dir, 'residual_convergence.png'), + rel_tol=rel_tol + ) + + # finalize and export mesh data + fe_mesh.Finalize(export_dir=case_dir) + + if clean_dir: + shutil.rmtree(case_dir) + +if __name__ == "__main__": + # Defaults + solve_FE = True + clean_dir = False + + main(solve_FE, clean_dir) diff --git a/examples/elastoplasticity/voroni_control_dict.pkl b/examples/elastoplasticity/voroni_control_dict.pkl new file mode 100644 index 0000000..f623e9d Binary files /dev/null and b/examples/elastoplasticity/voroni_control_dict.pkl differ diff --git a/fol/loss_functions/mechanical_elastoplasticity.py b/fol/loss_functions/mechanical_elastoplasticity.py new file mode 100644 index 0000000..ac58a2b --- /dev/null +++ b/fol/loss_functions/mechanical_elastoplasticity.py @@ -0,0 +1,208 @@ +""" + Authors: Rishabh Arora, https://github.com/rishabharora236-cell + Date: Oct, 2025 + License: FOL/LICENSE +""" +from .mechanical import MechanicalLoss +import jax +import jax.numpy as jnp +from jax.experimental import sparse +from jax import jit +from functools import partial +from fol.tools.fem_utilities import * +from fol.tools.decoration_functions import * +from fol.mesh_input_output.mesh import Mesh + +class ElastoplasticityLoss(MechanicalLoss): + + def Initialize(self) -> None: + + super().Initialize() + + if self.dim == 2: + self.material_model = ElastoplasticityModel2D() + else: + fol_error("3D Elastoplasticity analysis is not supported yet !") + + @partial(jit, static_argnums=(0,)) + def ComputeElement(self,xyze,de,uvwe,element_state_gps): + @jit + def compute_at_gauss_point(gp_point,gp_weight,gp_state_vector): + N_vec = self.fe_element.ShapeFunctionsValues(gp_point) + N_mat = self.CalculateNMatrix(N_vec) + DN_DX = self.fe_element.ShapeFunctionsGlobalGradients(xyze,gp_point) + B_mat = self.CalculateBMatrix(DN_DX) + J = self.fe_element.Jacobian(xyze,gp_point) + detJ = jnp.linalg.det(J) + strain_gp = B_mat @ uvwe + strain_matrix = jnp.array([ + [strain_gp[0], strain_gp[2]], # [ε_xx, ε_xy] + [strain_gp[2], strain_gp[1]] # [ε_xy, ε_yy] + ]) + strain_matrix_2x2= strain_matrix.squeeze() + tgMM,stress_gp_v,gp_state_up = self.material_model.evaluate(self.loss_settings["material_dict"]["young_modulus"],self.loss_settings["material_dict"]["poisson_ratio"],self.loss_settings["material_dict"]["iso_hardening_parameter_1"],self.loss_settings["material_dict"]["iso_hardening_param_2"],self.loss_settings["material_dict"]["yield_limit"], strain_matrix_2x2, gp_state_vector) + stress_gp_v = stress_gp_v.reshape(3,1) + gp_stiffness = gp_weight * detJ * (B_mat.T @ (tgMM @ B_mat)) + gp_f_int = (gp_weight * detJ * (B_mat.T @ stress_gp_v)) + gp_f_body = (gp_weight * detJ * (N_mat.T @ self.body_force)) + + return gp_stiffness,gp_f_body,gp_f_int,gp_state_up + + gp_points,gp_weights = self.fe_element.GetIntegrationData() + + k_gps,f_gps,f_gps_int,gps_state = jax.vmap(compute_at_gauss_point,in_axes=(0,0,0))(gp_points,gp_weights,element_state_gps) + Se = jnp.sum(k_gps, axis=0, keepdims=False) + Fe = jnp.sum(f_gps, axis=0, keepdims=False) + Fe_int= jnp.sum(f_gps_int, axis=0) + residual = (Fe_int-Fe) + element_residuals = jax.lax.stop_gradient(residual) + return ((uvwe.T @ element_residuals)[0,0]),gps_state, residual, Se + + def ComputeElementResidualAndJacobian( + self, + elem_xyz: jnp.array, + elem_controls: jnp.array, + elem_dofs: jnp.array, + elem_BC: jnp.array, + elem_mask_BC: jnp.array, + transpose_jac: bool, + elem_state_gps: jnp.array + ): + """ + Compute element residual and jacobian, with optional state update. + """ + + _, elem_state_up_gps, re, ke = self.ComputeElement( + elem_xyz, + elem_controls, + elem_dofs, + elem_state_gps + ) + + index = jnp.asarray(transpose_jac, dtype=jnp.int32) + + # Define the two branches for switch + branches = [ + lambda _: ke, # Case 0: No transpose + lambda _: jnp.transpose(ke) # Case 1: Transpose ke + ] + + # Apply the switch operation + ke = jax.lax.switch(index, branches, None) + + # Apply Dirichlet boundary conditions + r_e, k_e = self.ApplyDirichletBCOnElementResidualAndJacobian(re, ke, elem_BC, elem_mask_BC) + + return r_e, k_e, elem_state_up_gps + + def ComputeElementResidualAndJacobianVmapCompatible(self,element_id:jnp.integer, + elements_nodes:jnp.array, + xyz:jnp.array, + full_control_vector:jnp.array, + full_dof_vector:jnp.array, + full_dirichlet_BC_vec:jnp.array, + full_mask_dirichlet_BC_vec:jnp.array, + transpose_jac:bool, + full_state_gps: jnp.array): + return self.ComputeElementResidualAndJacobian(xyz[elements_nodes[element_id],:], + full_control_vector[elements_nodes[element_id]], + full_dof_vector[((self.number_dofs_per_node*elements_nodes[element_id])[:, jnp.newaxis] + + jnp.arange(self.number_dofs_per_node))].reshape(-1,1), + full_dirichlet_BC_vec[((self.number_dofs_per_node*elements_nodes[element_id])[:, jnp.newaxis] + + jnp.arange(self.number_dofs_per_node))].reshape(-1,1), + full_mask_dirichlet_BC_vec[((self.number_dofs_per_node*elements_nodes[element_id])[:, jnp.newaxis] + + jnp.arange(self.number_dofs_per_node))].reshape(-1,1), + transpose_jac, + full_state_gps[element_id, :, :]) + + @print_with_timestamp_and_execution_time + @partial(jit, static_argnums=(0,)) + def ComputeJacobianMatrixAndResidualVector( + self, + total_control_vars: jnp.array, + total_primal_vars: jnp.array, + old_state_gps: jnp.array, + transpose_jacobian: bool = False + ): + BC_vector = jnp.ones((self.total_number_of_dofs)) + BC_vector = BC_vector.at[self.dirichlet_indices].set(0) + mask_BC_vector = jnp.zeros((self.total_number_of_dofs)) + mask_BC_vector = mask_BC_vector.at[self.dirichlet_indices].set(1) + + num_nodes_per_elem = len(self.fe_mesh.GetElementsNodes(self.element_type)[0]) + element_matrix_size = self.number_dofs_per_node * num_nodes_per_elem + elements_jacobian_flat = jnp.zeros( + self.fe_mesh.GetNumberOfElements(self.element_type) * element_matrix_size * element_matrix_size + ) + + template_element_indices = jnp.arange(0, self.adjusted_batch_size) + template_elem_res_indices = jnp.arange(0, element_matrix_size, self.number_dofs_per_node) + template_elem_jac_indices = jnp.arange(0, self.adjusted_batch_size * element_matrix_size * element_matrix_size) + + residuals_vector = jnp.zeros((self.total_number_of_dofs)) + + new_state_gps = jnp.zeros_like(old_state_gps) + + @jit + def fill_arrays(batch_index, batch_arrays): + + glob_res_vec, elem_jac_vec, new_state_buf = batch_arrays + + batch_element_indices = (batch_index * self.adjusted_batch_size) + template_element_indices + batch_elem_jac_indices = (batch_index * self.adjusted_batch_size * element_matrix_size**2) + template_elem_jac_indices + + element_nodes = self.fe_mesh.GetElementsNodes(self.element_type) + node_coords = self.fe_mesh.GetNodesCoordinates() + + batch_elements_residuals, batch_elements_stiffness, batch_state_up_gps = jax.vmap( + self.ComputeElementResidualAndJacobianVmapCompatible, (0, None, None, None, None, None, None, None, None) + )( + batch_element_indices, + element_nodes, + node_coords, + total_control_vars, + total_primal_vars, + BC_vector, + mask_BC_vector, + transpose_jacobian, + old_state_gps, + ) + + elem_jac_vec = elem_jac_vec.at[batch_elem_jac_indices].set(batch_elements_stiffness.ravel()) + + @jit + def fill_res_vec(dof_idx, glob_res_vec): + glob_res_vec = glob_res_vec.at[ + self.number_dofs_per_node * element_nodes[batch_element_indices] + dof_idx + ].add(jnp.squeeze(batch_elements_residuals[:, template_elem_res_indices + dof_idx])) + return glob_res_vec + + glob_res_vec = jax.lax.fori_loop(0, self.number_dofs_per_node, fill_res_vec, glob_res_vec) + + new_state_buf = new_state_buf.at[batch_element_indices, :, :].set(batch_state_up_gps) + return glob_res_vec, elem_jac_vec, new_state_buf + + # Run loop + residuals_vector, elements_jacobian_flat, new_state_gps = jax.lax.fori_loop( + 0, self.num_element_batches, fill_arrays, (residuals_vector, elements_jacobian_flat, new_state_gps) + ) + + # Assemble sparse Jacobian + jacobian_indices = jax.vmap(self.ComputeElementJacobianIndices)( + self.fe_mesh.GetElementsNodes(self.element_type) + ).reshape(-1, 2) + + sparse_jacobian = sparse.BCOO( + (elements_jacobian_flat, jacobian_indices), + shape=(self.total_number_of_dofs, self.total_number_of_dofs), + ) + + return new_state_gps, sparse_jacobian, residuals_vector + +class ElastoplasticityLoss2DQuad(ElastoplasticityLoss): + def __init__(self, name: str, loss_settings: dict, fe_mesh: Mesh): + if not "num_gp" in loss_settings.keys(): + loss_settings["num_gp"] = 2 + super().__init__(name,{**loss_settings,"compute_dims":2, + "ordered_dofs": ["Ux","Uy"], + "element_type":"quad"},fe_mesh) diff --git a/fol/solvers/fe_nonlinear_residual_based_solver.py b/fol/solvers/fe_nonlinear_residual_based_solver.py index 763242d..c033b33 100644 --- a/fol/solvers/fe_nonlinear_residual_based_solver.py +++ b/fol/solvers/fe_nonlinear_residual_based_solver.py @@ -13,17 +13,36 @@ from fol.loss_functions.fe_loss import FiniteElementLoss class FiniteElementNonLinearResidualBasedSolver(FiniteElementLinearResidualBasedSolver): - """Nonlinear solver class. + """ + Nonlinear finite element solver implementing an incremental, residual-based + Newton–Raphson method. + + This class extends `FiniteElementLinearResidualBasedSolver` to support nonlinear + problems by repeatedly assembling the tangent stiffness matrix (Jacobian) and + residual vector at each Newton iteration. Load is applied incrementally, and + convergence is monitored using absolute residual tolerance, relative DOF update + tolerance, and a maximum number of Newton iterations. + + The solver is intended for problems where the global residual depends + nonlinearly on the displacement field, such as: + • geometric nonlinearities + • material nonlinearities (hyperelasticity, plasticity, damage, etc.) + • boundary-condition-dependent nonlinearities + Features + -------- + • Incremental load stepping based on `nonlinear_solver_settings["load_incr"]` + • Newton–Raphson iterations with user-configurable tolerances + • Detection of NaN residuals with informative diagnostic output + • Optional plotting of convergence history (`PlotHistoryDict`) """ @print_with_timestamp_and_execution_time - def __init__(self, fe_solver_name: str, fe_loss_function: FiniteElementLoss, fe_solver_settings:dict={}) -> None: + def __init__(self, fe_solver_name: str, fe_loss_function: FiniteElementLoss, fe_solver_settings:dict={}, history_plot_settings:dict={}) -> None: super().__init__(fe_solver_name,fe_loss_function,fe_solver_settings) - self.nonlinear_solver_settings = {"rel_tol":1e-8, - "abs_tol":1e-8, - "maxiter":20, - "load_incr":5} + self.nonlinear_solver_settings = {"rel_tol":1e-8,"abs_tol":1e-8,"maxiter":20,"load_incr":5} + self.default_history_plot_settings = {"plot":False,"criteria":["res_norm","delta_dofs_norm"],"save_directory":"."} + self.history_plot_settings = history_plot_settings @print_with_timestamp_and_execution_time def Initialize(self) -> None: @@ -31,42 +50,121 @@ def Initialize(self) -> None: if "nonlinear_solver_settings" in self.fe_solver_settings.keys(): self.nonlinear_solver_settings = UpdateDefaultDict(self.nonlinear_solver_settings, self.fe_solver_settings["nonlinear_solver_settings"]) + self.history_plot_settings = UpdateDefaultDict(self.default_history_plot_settings,self.history_plot_settings) + + def PlotHistoryDict(self,history_dict:dict): + + if self.history_plot_settings["plot"]: + plot_dict = {key: [] for key in self.history_plot_settings["criteria"]} + for plot_criterion,plot_values in plot_dict.items(): + for step,step_dict in history_dict.items(): + plot_values.extend(step_dict[plot_criterion]) + + plt.figure(figsize=(10, 5)) + for key,value in plot_dict.items(): + plt.semilogy(value, marker='o', markersize=6, linewidth=1.5, label=f"{key}") + + plt.title("Newton-Raphson History with Load Stepping") + plt.xlabel("Cumulative Iteration") + plt.ylabel("Log Value") + plt.legend() + plt.grid(True) + plt.savefig(os.path.join(self.history_plot_settings["save_directory"],"newton_raphson_history.png"), bbox_inches='tight') + plt.close() + @print_with_timestamp_and_execution_time - def Solve(self,current_control_vars,current_dofs_np:np.array): - current_dofs = jnp.array(current_dofs_np) - load_increament = self.nonlinear_solver_settings["load_incr"] - for load_fac in range(load_increament): - fol_info(f"loadStep; increment:{load_fac+1}") - applied_BC_dofs = self.fe_loss_function.ApplyDirichletBCOnDofVector(current_dofs,(load_fac+1)/load_increament) - for i in range(self.nonlinear_solver_settings["maxiter"]): + def Solve(self,current_control_vars:jnp.array,current_dofs_np:jnp.array): + """ + Solve the nonlinear finite element system using an incremental, + residual-based Newton–Raphson method. + + The load is applied in a prescribed number of increments and, for each + load step, a Newton–Raphson loop is performed. At each iteration the + global residual vector and tangent (Jacobian) matrix are assembled via + `fe_loss_function.ComputeJacobianMatrixAndResidualVector`, the linear + system is solved, and the DOFs are updated until convergence criteria + are satisfied or the maximum number of iterations is reached. + + Parameters + ---------- + current_control_vars : jax.numpy.ndarray + Array of control variables (e.g., load parameters, material or + design variables) passed to the FE loss function. Shape is + problem-dependent. + + current_dofs_np : jax.numpy.ndarray + Initial displacement (DOF) vector at the beginning of the analysis. + This is converted to a JAX array internally. + Shape: (n_dofs,). + + Returns + ------- + current_dofs : jax.numpy.ndarray + Converged displacement vector at the end of all load increments. + Shape: (n_dofs,). + + """ + current_dofs = jnp.asarray(current_dofs_np) + current_control_vars = jnp.asarray(current_control_vars) + num_load_steps = self.nonlinear_solver_settings["load_incr"] + convergence_history = {} + for load_step in range(1,num_load_steps+1): + load_step_value = (load_step)/num_load_steps + # increment load + current_dofs = self.fe_loss_function.ApplyDirichletBCOnDofVector(current_dofs,load_step_value) + newton_converged = False + convergence_history[load_step] = {"res_norm":[],"delta_dofs_norm":[]} + for i in range(1,self.nonlinear_solver_settings["maxiter"]+1): BC_applied_jac,BC_applied_r = self.fe_loss_function.ComputeJacobianMatrixAndResidualVector( - current_control_vars,applied_BC_dofs) + current_control_vars,current_dofs) + + # check residuals norm res_norm = jnp.linalg.norm(BC_applied_r,ord=2) if jnp.isnan(res_norm): - fol_info("Residual norm is NaN, check inputs!") - raise(ValueError("res_norm contains nan values!")) - if res_norm 4th-order tensor in 2D (xx,yy,xy_tensor ordering) + def voigt33_to_C4_2D(Cv): + # Cv order: [xx,yy,xy]x[xx,yy,xy] with tensor-shear in xy + # Build 4th-order C_ijkl (i,j,k,l in {x,y}) consistent with our Voigt convention + C = jnp.zeros((2,2,2,2), dtype=Cv.dtype) + # helper: map (xx->0, yy->1, xy->2) + def idx_pair(a): + return (0,0) if a==0 else (1,1) if a==1 else (0,1) + def sym_set(C, i,j,k,l, val): + # enforce minor symmetries: ij and kl + C = C.at[i,j,k,l].set(val) + C = C.at[j,i,k,l].set(val) + C = C.at[i,j,l,k].set(val) + C = C.at[j,i,l,k].set(val) + return C + for a in range(3): + for b in range(3): + i,j = idx_pair(a) + k,l = idx_pair(b) + C = sym_set(C, i,j,k,l, Cv[a,b]) + return C + + # ---------- Elasticity in 3D ---------- + @staticmethod + def C_elastic_3D(E, nu): + I3 = jnp.eye(3) + lam = E*nu / ((1+nu)*(1-2*nu)) + G = E/(2*(1+nu)) + def C_dot(e3): + tr = jnp.trace(e3) + return lam*tr*I3 + 2.0*G*e3 + return lam, G, C_dot + + # ---------- Newton solver made JAX-differentiable ---------- + @staticmethod + def newton_solve(R, x0, maxit=60, tol=1e-3): + + # run fixed number of iterations but allow early stop with no-op steps + def while_cond(state): + x, k = state + r = R(x) + return jnp.logical_and(jnp.linalg.norm(r) > tol, k < maxit) + + def while_body(state): + x, k = state + r = R(x) + J = jax.jacfwd(R)(x) + dx = jnp.linalg.solve(J, -r) + return (x + dx, k + 1) + + x_final, _ = jax.lax.while_loop(while_cond, while_body, (x0, 0)) + return x_final + + # ---------- Plasticity helpers ---------- + @staticmethod + def tr3(A): return jnp.trace(A) + @staticmethod + def frob(A): return jnp.sqrt(jnp.tensordot(A, A)) + @staticmethod + def dev3(A): + I3 = jnp.eye(3, dtype=A.dtype) + return A - I3 * (ElastoplasticityModel2D.tr3(A) / 3.0) + @staticmethod + def flow_normal(sig3): + s = ElastoplasticityModel2D.dev3(sig3) + seq = jnp.sqrt(1.5) * ElastoplasticityModel2D.frob(s) + inv = jnp.where(seq > 0.0, 1.0/seq, 0.0) + return 1.5 * s * inv + + # ---------- Core: one local integration step as a pure function ---------- + @staticmethod + def local_update(ts2, ps2_n, xi_n, E, nu, h1, h2, y0): + """ + Inputs: + ts2 : total strain (2x2 tensor) at current step + ps2_n : plastic strain at previous step (2x2) + xi_n : cum. plastic strain at previous step (scalar) + Returns: + sig2_v : Cauchy stress (2D Voigt, length 3) + ps2_new : updated plastic strain (2x2) + xi_n1 : updated cum. plastic strain (scalar) + """ + _, _, C_dot = ElastoplasticityModel2D.C_elastic_3D(E, nu) + + # plane strain embeddings + ps_zz = -(ps2_n[0,0] + ps2_n[1,1]) + ps3_n = ElastoplasticityModel2D.to3D_from_2D_plane_strain(ps2_n, ps_zz) + ts3 = ElastoplasticityModel2D.to3D_from_2D_plane_strain(ts2, 0.0) + + # trial state + es_trial = ts3 - ps3_n + sig_trial = C_dot(es_trial) + s_trial = ElastoplasticityModel2D.dev3(sig_trial) + sig_eq = jnp.sqrt(1.5) * ElastoplasticityModel2D.frob(s_trial) + yl = y0 + h1*(1.0 - jnp.exp(-h2*xi_n)) + f_yield = sig_eq - yl + + # Elastic return: no plastic flow + def elastic_return(): + sig2 = sig_trial[0:2, 0:2] + sig2_v = ElastoplasticityModel2D.tensor2_to_voigt(sig2) + return sig2_v, ps2_n, xi_n + + # Plastic return: solve for plastic multiplier + def plastic_return(): + def make_residual(): + def R(dx): + depsp_v = dx[:-1] + dp = dx[-1] + + # Update plastic strain + depsp2 = ElastoplasticityModel2D.voigt2_to_tensor(depsp_v) + epsp2 = ps2_n + depsp2 + epsp_zz = -(epsp2[0,0] + epsp2[1,1]) + epsp3 = ElastoplasticityModel2D.to3D_from_2D_plane_strain(epsp2, epsp_zz) + + # Compute stress + eps3 = ts3 + ee3 = eps3 - epsp3 + sig3 = C_dot(ee3) + + # Deviatoric stress and equivalent stress + s3 = ElastoplasticityModel2D.dev3(sig3) + seq = jnp.sqrt(1.5) * ElastoplasticityModel2D.frob(s3) + + # Flow normal + n3 = s3 / (seq + 1e-12) # Avoid division by zero + n2 = n3[0:2, 0:2] + n_v = jnp.array([n2[0,0], n2[1,1], n2[0,1]], dtype=n2.dtype) + + # Residuals + r_flow = depsp_v - n_v * dp + r_yield = seq - (y0 + h1*(1.0 - jnp.exp(-h2*(xi_n + dp)))) + + return jnp.concatenate([r_flow, jnp.array([r_yield])], axis=0) + + return R + + R = make_residual() + x0 = jnp.zeros((4,)) + x = ElastoplasticityModel2D.newton_solve(R, x0, maxit=60, tol=1e-3) + + dps_v = x[:-1] + dp = x[-1] + + # Update plastic variables + dps2 = ElastoplasticityModel2D.voigt2_to_tensor(dps_v) + ps2_new = ps2_n + dps2 + xi_n1 = xi_n + dp + + # Compute final stress + ps_zz_new = -(ps2_new[0,0] + ps2_new[1,1]) + ps3_new = ElastoplasticityModel2D.to3D_from_2D_plane_strain(ps2_new, ps_zz_new) + es = ts3 - ps3_new + sig3 = C_dot(es) + sig2 = sig3[0:2, 0:2] + sig2_v = ElastoplasticityModel2D.tensor2_to_voigt(sig2) + + return sig2_v, ps2_new, xi_n1 + + # Choose elastic or plastic return + return jax.lax.cond(f_yield <= 0.0, elastic_return, plastic_return) + + @partial(jit, static_argnums=(0,)) + def evaluate(self, E, nu, h1, h2, y0, ts, state): + """ + Returns: + C_alg_4 : 4th-order tangent in 2D (i,j,k,l over x,y) OR you can return the 3x3 Voigt if you prefer + sig2_v : stress in 2D Voigt (length 3) + ps_new : updated plastic strain tensor (2x2) + xi_n1 : updated cum. plastic strain (scalar) + """ + ps_v = state[:3] + xi_n=state[3] + ps=ElastoplasticityModel2D.voigt2_to_tensor(ps_v) + + # 1) do the usual local update to get stress at the given ts + sig2_v, ps_new, xi_n1 = ElastoplasticityModel2D.local_update(ts, ps, xi_n, E, nu, h1, h2, y0) + + # 2) define a pure function "stress(ts_v)" so we can differentiate w.r.t. total strain + def stress_voigt_from_ts_voigt(ts_v_flat): + ts2 = ElastoplasticityModel2D.voigt2_to_tensor(ts_v_flat) + sig2_v_loc, _, _ = ElastoplasticityModel2D.local_update(ts2, ps, xi_n, E, nu, h1, h2, y0) + return sig2_v_loc # length-3 Voigt vector + + # 3) automatic differentiation: d(sig_voigt)/d(eps_voigt) -> (3x3) plane-strain tangent in Voigt + ts_v = ElastoplasticityModel2D.tensor2_to_voigt(ts) + C_voigt = jax.jacfwd(stress_voigt_from_ts_voigt)(ts_v) # shape (3,3) + + # 4) (optional) lift to a true 4th-order tensor in 2D + + ps_v = ElastoplasticityModel2D.tensor2_to_voigt(ps_new) # shape (3,) + state_vec = jnp.concatenate([ps_v, jnp.array([xi_n1])]) # shape (4,) + + return C_voigt, sig2_v, state_vec + +class JAXNewton: + def __init__(self, maxit=50, tol=1e-3): + self.maxit = maxit + self.tol = tol + + def solve(self, R, x0): + """ + R: residual function R(x) -> (m,), with m == len(x) + x0: initial guess + returns: x, info dict + """ + x = x0 + for k in range(self.maxit): + r = R(x) + nrm = jnp.linalg.norm(r) + + # Use lax.cond() to check the residual norm condition + def condition_met_fn(x): + return x + + def continue_iteration_fn(x): + J = jax.jacfwd(R)(x) # Compute Jacobian + dx = jnp.linalg.solve(J, -r) # Solve for the step + x_new = x + dx + return x_new + + # Apply the condition and either exit the loop or continue + x= lax.cond(nrm < self.tol, condition_met_fn, continue_iteration_fn, x) + + # Final residual after the loop + r = R(x) + nrm = jnp.linalg.norm(r) + return x class NeoHookianModelAD(MaterialModel): """ Material model. @@ -573,4 +824,4 @@ def tangent(C_voigt): xsie = strain_energy(C_voigt) Se_voigt = second_piola(C_voigt) C_tangent = tangent(C_voigt) - return xsie, Se_voigt, C_tangent.squeeze() \ No newline at end of file + return xsie, Se_voigt, C_tangent.squeeze()