Skip to content

Latest commit

 

History

History
 
 

Folders and files

NameName
Last commit message
Last commit date

parent directory

..
 
 
 
 
 
 
 
 
 
 
 
 

README.md

Purpose:

This example illustrates solving differential equations numerically in Julia. Specifically, it solves an ODE after an example taken from this Python notebook, using DifferentialEquations.jl

Contents:

  • ode_test.jl: Julia source code
  • run.sbatch: Batch-job submission script
  • results.dat: Numeric results
  • figure.png: Figure of ODE's solution
  • figure.py: Python script for generating the figure

Julia code:

#++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# Program: ode_test.jl 
#++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
using DifferentialEquations
using SimpleDiffEq
using Printf
using Plots
using Plots.PlotMeasures

# --- Define the problem ---
f(u, p, t) = t - u
u0 = 1.0             # Initial conditions
tspan = (0.0, 5.0)   # Time span
prob = ODEProblem(f, u0, tspan)

# --- Solve the problem ---
dt  = 0.2
alg = RK4()
sol = solve(prob, alg, saveat=dt)

# --- Exact solution ---
function f_ex(t)
   f = t - 1.0 + ( 2.0 * exp(-t) )
   return f
end

# --- Print out solution ---
fname = "results.dat"
fo = open(fname, "w")
println(fo, "   Time     Numeric     Exact   ")
println(fo, " ------------------------------ ")
for i = 1: size(sol.t)[1]
   x = sol.t[i]
   y = sol.u[i]
   y_exact = f_ex(x)
   @printf(fo, "%8.4f %10.6f %10.6f\n", x, y, y_exact)
end
close(fo)

Example Batch-Job Submission Script:

#!/bin/bash
#SBATCH -J ode_test
#SBATCH -o ode_test.out
#SBATCH -e ode_test.err
#SBATCH -p shared
#SBATCH -N 1
#SBATCH -c 1
#SBATCH -t 0-00:30
#SBATCH --mem=4G

# Load required software modules
module load julia/1.1.1-fasrc01
srun -n 1 -c 1 julia ode_test.jl

Example Usage:

sbatch run.sbatch

Example Output:

$ cat results.dat 
   Time     Numeric     Exact   
 ------------------------------ 
  0.0000   1.000000   1.000000
  0.2000   0.837462   0.837462
  0.4000   0.740639   0.740640
  0.6000   0.697624   0.697623
  0.8000   0.698660   0.698658
  1.0000   0.735763   0.735759
  1.2000   0.802393   0.802388
  1.4000   0.893198   0.893194
  1.6000   1.003802   1.003793
  1.8000   1.130613   1.130598
  2.0000   1.270681   1.270671
  2.2000   1.421625   1.421606
  2.4000   1.581455   1.581436
  2.6000   1.748569   1.748547
  2.8000   1.921645   1.921620
  3.0000   2.099603   2.099574
  3.2000   2.281551   2.281524
  3.4000   2.466785   2.466747
  3.6000   2.654677   2.654647
  3.8000   2.844781   2.844742
  4.0000   3.036672   3.036631
  4.2000   3.230027   3.229991
  4.4000   3.424601   3.424555
  4.6000   3.620150   3.620104
  4.8000   3.816498   3.816459
  5.0000   4.013510   4.013476

Figure of Solution:

solution

Python script for generating the figure:

"""
Program: figure.py
"""
import os
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
plt.style.use('seaborn-white')
from matplotlib.ticker import MultipleLocator, FormatStrFormatter

def rc_params():
    """Set rcParams for this plot"""
    params = {
        'axes.linewidth':2.0,
        'xtick.major.size':6.5,
        'xtick.major.width':1.5,
        'xtick.minor.size':3.5,
        'xtick.minor.width':1.5,
        'ytick.major.size':6.5,
        'ytick.major.width':1.5,
        'ytick.minor.size':3.5,
        'ytick.minor.width':1.5,
        'xtick.labelsize':22,
        'ytick.labelsize':22,
        'xtick.direction':'in',
        'ytick.direction':'in',
        'xtick.top':True,
        'ytick.right':True
        }
    mpl.rcParams.update(params)

# Set rcParams    
rc_params()

# Data
cwd = os.getcwd()
data_path = os.path.join(cwd, 'results.dat')

# Figure
fig_path  = os.path.join(cwd, 'figure.png')

# Load data
darr = np.loadtxt(data_path, skiprows=2)
t = darr[:,0]
y = darr[:,1]
y_ex = darr[:,2]

# Plot results
fig, ax = plt.subplots(figsize=(8,6))

p1, = ax.plot(t, y_ex, linewidth = 3.0, color="red", alpha=0.5,
              linestyle='--', label='Exact Solution')

p2, = ax.plot(t, y_ex, linewidth = 3.0, color="blue", alpha=0.5,
              marker='o', markersize=7, linestyle='', label='Numeric Solution')


plt.xlim([0.0, 5.0])
plt.ylim([0.0, 4.3])
plt.xlabel('Time', fontsize=22)
plt.ylabel('u(t)', fontsize=22)
plt.legend(fontsize=15, loc="upper left", shadow=True, fancybox=True)

plt.savefig(fig_path, format='png', dpi=100, bbox_inches='tight')

References: