Skip to content

Commit

Permalink
Merge pull request #211 from ErikQQY/qqy/refactor_firk
Browse files Browse the repository at this point in the history
Refactor FIRK solvers
  • Loading branch information
ChrisRackauckas authored Sep 15, 2024
2 parents 4590c02 + ac46c48 commit 0a2b4ac
Show file tree
Hide file tree
Showing 10 changed files with 97 additions and 293 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "BoundaryValueDiffEq"
uuid = "764a87c0-6b3e-53db-9096-fe964310641d"
version = "5.9.1"
version = "5.10.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -51,7 +51,7 @@ LinearSolve = "2.21"
Logging = "1.10"
NonlinearSolve = "3.8.1"
ODEInterface = "0.5"
OrdinaryDiffEq = "6.88.1"
OrdinaryDiffEq = "6.89.0"
PreallocationTools = "0.4.24"
PrecompileTools = "1.2"
Preferences = "1.4"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ Precompilation can be controlled via `Preferences.jl`
- `PrecompileMIRK` -- Precompile the MIRK2 - MIRK6 algorithms (default: `true`).
- `PrecompileShooting` -- Precompile the single shooting algorithms (default: `true`).
- `PrecompileMultipleShooting` -- Precompile the multiple shooting algorithms (default: `true`).
- `PrecompileMIRKNLLS` -- Precompile the MIRK2 - MIRK6 algorithms for under-determined and over-determined BVPs (default: `true`).
- `PrecompileMIRKNLLS` -- Precompile the MIRK2 - MIRK6 algorithms for under-determined and over-determined BVPs (default: `false`).
- `PrecompileShootingNLLS` -- Precompile the single shooting algorithms for under-determined and over-determined BVPs (default: `true`).
- `PrecompileMultipleShootingNLLS` -- Precompile the multiple shooting algorithms for under-determined and over-determined BVPs (default: `true` ).

Expand Down
173 changes: 0 additions & 173 deletions src/BoundaryValueDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,179 +95,6 @@ end
Threads.@spawn solve(prob, alg; dt = 0.2)
end
end

f1_nlls! = (du, u, p, t) -> begin
du[1] = u[2]
du[2] = -u[1]
end

f1_nlls = (u, p, t) -> [u[2], -u[1]]

bc1_nlls! = (resid, sol, p, t) -> begin
solₜ₁ = sol[:, 1]
solₜ₂ = sol[:, end]
resid[1] = solₜ₁[1]
resid[2] = solₜ₂[1] - 1
resid[3] = solₜ₂[2] + 1.729109
return nothing
end
bc1_nlls = (sol, p, t) -> [sol[:, 1][1], sol[:, end][1] - 1, sol[:, end][2] + 1.729109]

bc1_nlls_a! = (resid, ua, p) -> (resid[1] = ua[1])
bc1_nlls_b! = (resid, ub, p) -> (resid[1] = ub[1] - 1;
resid[2] = ub[2] + 1.729109)

bc1_nlls_a = (ua, p) -> [ua[1]]
bc1_nlls_b = (ub, p) -> [ub[1] - 1, ub[2] + 1.729109]

tspan = (0.0, 100.0)
u0 = [0.0, 1.0]
bcresid_prototype1 = Array{Float64}(undef, 3)
bcresid_prototype2 = (Array{Float64}(undef, 1), Array{Float64}(undef, 2))

probs = [
BVProblem(BVPFunction(f1_nlls!, bc1_nlls!; bcresid_prototype = bcresid_prototype1),
u0, tspan, nlls = Val(true)),
BVProblem(BVPFunction(f1_nlls, bc1_nlls; bcresid_prototype = bcresid_prototype1),
u0, tspan, nlls = Val(true)),
TwoPointBVProblem(f1_nlls!, (bc1_nlls_a!, bc1_nlls_b!), u0, tspan;
bcresid_prototype = bcresid_prototype2, nlls = Val(true)),
TwoPointBVProblem(f1_nlls, (bc1_nlls_a, bc1_nlls_b), u0, tspan;
bcresid_prototype = bcresid_prototype2, nlls = Val(true))]

jac_alg = BVPJacobianAlgorithm(AutoForwardDiff(; chunksize = 2))

nlsolvers = [LevenbergMarquardt(; disable_geodesic = Val(true)), GaussNewton()]

algs = []

if Preferences.@load_preference("PrecompileMIRKNLLS", false)
for nlsolve in nlsolvers
append!(algs, [MIRK2(; jac_alg, nlsolve), MIRK6(; jac_alg, nlsolve)])
end
end

@compile_workload begin
@sync for prob in probs, alg in algs
Threads.@spawn solve(prob, alg; dt = 0.2, abstol = 1e-2)
end
end

function bc2!(residual, u, p, t)
residual[1] = u(0.0)[1] - 5
residual[2] = u(5.0)[1]
end
bc2 = (u, p, t) -> [u(0.0)[1] - 5, u(5.0)[1]]

tspan = (0.0, 5.0)
u0 = [5.0, -3.5]
bcresid_prototype = (Array{Float64}(undef, 1), Array{Float64}(undef, 1))

probs = [BVProblem(BVPFunction{true}(f1!, bc2!), u0, tspan; nlls = Val(false)),
BVProblem(BVPFunction{false}(f1, bc2), u0, tspan; nlls = Val(false)),
BVProblem(
BVPFunction{true}(
f1!, (bc1_a!, bc1_b!); bcresid_prototype, twopoint = Val(true)),
u0,
tspan;
nlls = Val(false)),
BVProblem(
BVPFunction{false}(f1, (bc1_a, bc1_b); bcresid_prototype, twopoint = Val(true)),
u0, tspan; nlls = Val(false))]

algs = []

if @load_preference("PrecompileShooting", true)
push!(algs,
Shooting(Tsit5(); nlsolve = NewtonRaphson(),
jac_alg = BVPJacobianAlgorithm(AutoForwardDiff(; chunksize = 2))))
end

if @load_preference("PrecompileMultipleShooting", true)
push!(algs,
MultipleShooting(10,
Tsit5();
nlsolve = NewtonRaphson(),
jac_alg = BVPJacobianAlgorithm(;
bc_diffmode = AutoForwardDiff(; chunksize = 2),
nonbc_diffmode = AutoSparse(AutoForwardDiff(; chunksize = 2)))))
end

@compile_workload begin
@sync for prob in probs, alg in algs
Threads.@spawn solve(prob, alg)
end
end

bc1_nlls! = (resid, sol, p, t) -> begin
solₜ₁ = sol(0.0)
solₜ₂ = sol(100.0)
resid[1] = solₜ₁[1]
resid[2] = solₜ₂[1] - 1
resid[3] = solₜ₂[2] + 1.729109
return nothing
end
bc1_nlls = (sol, p, t) -> [sol(0.0)[1], sol(100.0)[1] - 1, sol(1.0)[2] + 1.729109]

tspan = (0.0, 100.0)
u0 = [0.0, 1.0]
bcresid_prototype1 = Array{Float64}(undef, 3)
bcresid_prototype2 = (Array{Float64}(undef, 1), Array{Float64}(undef, 2))

probs = [
BVProblem(
BVPFunction{true}(f1_nlls!, bc1_nlls!; bcresid_prototype = bcresid_prototype1),
u0, tspan; nlls = Val(true)),
BVProblem(
BVPFunction{false}(f1_nlls, bc1_nlls; bcresid_prototype = bcresid_prototype1),
u0, tspan; nlls = Val(true)),
BVProblem(
BVPFunction{true}(f1_nlls!, (bc1_nlls_a!, bc1_nlls_b!);
bcresid_prototype = bcresid_prototype2, twopoint = Val(true)),
u0,
tspan;
nlls = Val(true)),
BVProblem(
BVPFunction{false}(f1_nlls, (bc1_nlls_a, bc1_nlls_b);
bcresid_prototype = bcresid_prototype2, twopoint = Val(true)),
u0,
tspan;
nlls = Val(true))]

algs = []

if @load_preference("PrecompileShootingNLLS", true)
append!(algs,
[
Shooting(
Tsit5(); nlsolve = LevenbergMarquardt(; disable_geodesic = Val(true)),
jac_alg = BVPJacobianAlgorithm(AutoForwardDiff(; chunksize = 2))),
Shooting(Tsit5(); nlsolve = GaussNewton(),
jac_alg = BVPJacobianAlgorithm(AutoForwardDiff(; chunksize = 2)))])
end

if @load_preference("PrecompileMultipleShootingNLLS", true)
append!(algs,
[
MultipleShooting(10,
Tsit5();
nlsolve = LevenbergMarquardt(; disable_geodesic = Val(true)),
jac_alg = BVPJacobianAlgorithm(;
bc_diffmode = AutoForwardDiff(; chunksize = 2),
nonbc_diffmode = AutoSparse(AutoForwardDiff(; chunksize = 2)))),
MultipleShooting(10,
Tsit5();
nlsolve = GaussNewton(),
jac_alg = BVPJacobianAlgorithm(;
bc_diffmode = AutoForwardDiff(; chunksize = 2),
nonbc_diffmode = AutoSparse(AutoForwardDiff(; chunksize = 2))))])
end

@compile_workload begin
@sync for prob in probs, alg in algs
Threads.@spawn solve(prob, alg; nlsolve_kwargs = (; abstol = 1e-2))
end
end
end

export Shooting, MultipleShooting
Expand Down
52 changes: 41 additions & 11 deletions src/adaptivity.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""
interp_eval!(y::AbstractArray, cache::MIRKCache, t)
interp_eval!(y::AbstractArray, cache::MIRKCache, t, mesh, mesh_dt)
interp_eval!(y::AbstractArray, cache::FIRKCacheExpand, t, mesh, mesh_dt)
interp_eval!(y::AbstractArray, cache::FIRKCacheNested, t, mesh, mesh_dt)
After we construct an interpolant, we use interp_eval to evaluate it.
"""
Expand All @@ -14,13 +16,6 @@ end

@views function interp_eval!(
y::AbstractArray, cache::FIRKCacheExpand{iip}, t, mesh, mesh_dt) where {iip}
i = findfirst(x -> x == y, cache.y₀.u)
interp_eval!(cache.y₀.u, i, cache::FIRKCacheExpand{iip}, t, mesh, mesh_dt)
return y
end

@views function interp_eval!(
y::AbstractArray, i::Int, cache::FIRKCacheExpand{iip}, t, mesh, mesh_dt) where {iip}
j = interval(mesh, t)
h = mesh_dt[j]
lf = (length(cache.y₀) - 1) / (length(cache.y) - 1) # Cache length factor. We use a h corresponding to cache.y. Note that this assumes equidistributed mesh
Expand All @@ -29,12 +24,11 @@ end
end
τ = (t - mesh[j])

(; f, M, p, ITU) = cache
(; q_coeff, stage) = ITU
(; f, M, stage, p, ITU) = cache
(; q_coeff) = ITU

K = __similar(cache.y[1].du, M, stage)

ctr_y0 = (i - 1) * (stage + 1) + 1
ctr_y = (j - 1) * (stage + 1) + 1

yᵢ = cache.y[ctr_y].du
Expand Down Expand Up @@ -130,6 +124,34 @@ function dS_interpolate!(dy::AbstractArray, t, S_coeffs)
dy .= S_coeffs * ts
end

"""
s_constraints(M, h)
Form the quartic interpolation constraint matrix, see bvp5c paper.
"""
function s_constraints(M, h)
t = vec(repeat([0.0, 1.0 * h, 0.5 * h, 0.0, 1.0 * h, 0.5 * h], 1, M))
A = zeros(6 * M, 6 * M)
for i in 1:6
row_start = (i - 1) * M + 1
for k in 0:(M - 1)
for j in 1:6
A[row_start + k, j + k * 6] = t[i + k * 6]^(j - 1)
end
end
end
for i in 4:6
row_start = (i - 1) * M + 1
for k in 0:(M - 1)
for j in 1:6
A[row_start + k, j + k * 6] = j == 1.0 ? 0.0 :
(j - 1) * t[i + k * 6]^(j - 2)
end
end
end
return A
end

"""
interval(mesh, t)
Expand All @@ -141,6 +163,8 @@ end

"""
mesh_selector!(cache::MIRKCache)
mesh_selector!(cache::FIRKCacheExpand)
mesh_selector!(cache::FIRKCacheNested)
Generate new mesh based on the defect.
"""
Expand Down Expand Up @@ -199,6 +223,8 @@ end

"""
redistribute!(cache::MIRKCache, Nsub_star, ŝ, mesh, mesh_dt)
redistribute!(cache::FIRKCacheExpand, Nsub_star, ŝ, mesh, mesh_dt)
redistribute!(cache::FIRKCacheNested, Nsub_star, ŝ, mesh, mesh_dt)
Generate a new mesh based on the `ŝ`.
"""
Expand Down Expand Up @@ -235,6 +261,8 @@ end
"""
half_mesh!(mesh, mesh_dt)
half_mesh!(cache::MIRKCache)
half_mesh!(cache::FIRKCacheExpand)
half_mesh!(cache::FIRKCacheNested)
The input mesh has length of `n + 1`. Divide the original subinterval into two equal length
subinterval. The `mesh` and `mesh_dt` are modified in place.
Expand All @@ -260,6 +288,8 @@ end

"""
defect_estimate!(cache::MIRKCache)
defect_estimate!(cache::FIRKCacheExpand)
defect_estimate!(cache::FIRKCacheNested)
defect_estimate use the discrete solution approximation Y, plus stages of
the RK method in 'k_discrete', plus some new stages in 'k_interp' to construct
Expand Down
Loading

0 comments on commit 0a2b4ac

Please sign in to comment.