Skip to content

Commit

Permalink
Merge pull request #188 from SciML/ap/adtypes
Browse files Browse the repository at this point in the history
[WIP] Migrating to new ADTypes
  • Loading branch information
avik-pal authored May 26, 2024
2 parents e220325 + f1ab970 commit 0557612
Show file tree
Hide file tree
Showing 13 changed files with 73 additions and 85 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "BoundaryValueDiffEq"
uuid = "764a87c0-6b3e-53db-9096-fe964310641d"
version = "5.7.1"
version = "5.8.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -34,7 +34,7 @@ ODEInterface = "54ca160b-1b9f-5127-a996-1867f4bc2a2c"
BoundaryValueDiffEqODEInterfaceExt = "ODEInterface"

[compat]
ADTypes = "0.2.6"
ADTypes = "1.2"
Adapt = "4"
Aqua = "0.8"
ArrayInterface = "7.7"
Expand Down
6 changes: 4 additions & 2 deletions benchmark/simple_pendulum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ function create_simple_pendulum_benchmark()
end
for alg in (MIRK2, MIRK3, MIRK4, MIRK5, MIRK6)
if @isdefined(alg)
iip_suite["$alg()"] = @benchmarkable solve($SimplePendulumBenchmark.prob_iip, $alg(), dt=0.05)
iip_suite["$alg()"] = @benchmarkable solve(
$SimplePendulumBenchmark.prob_iip, $alg(), dt = 0.05)
end
end

Expand All @@ -82,7 +83,8 @@ function create_simple_pendulum_benchmark()
end
for alg in (MIRK2, MIRK3, MIRK4, MIRK5, MIRK6)
if @isdefined(alg)
oop_suite["$alg()"] = @benchmarkable solve($SimplePendulumBenchmark.prob_oop, $alg(), dt=0.05)
oop_suite["$alg()"] = @benchmarkable solve(
$SimplePendulumBenchmark.prob_oop, $alg(), dt = 0.05)
end
end

Expand Down
23 changes: 11 additions & 12 deletions src/BoundaryValueDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidat
import Logging
import RecursiveArrayTools: ArrayPartition, DiffEqArray
import SciMLBase: AbstractDiffEqInterpolation, StandardBVProblem, __solve, _unwrap_val
import SparseDiffTools: AbstractSparseADType
end

@reexport using ADTypes, DiffEqBase, NonlinearSolve, OrdinaryDiffEq, SparseDiffTools,
Expand Down Expand Up @@ -92,8 +91,8 @@ end
end

@compile_workload begin
for prob in probs, alg in algs
solve(prob, alg; dt = 0.2)
@sync for prob in probs, alg in algs
Threads.@spawn solve(prob, alg; dt = 0.2)
end
end

Expand Down Expand Up @@ -150,8 +149,8 @@ end
end

@compile_workload begin
for prob in probs, alg in algs
solve(prob, alg; dt = 0.2, abstol = 1e-2)
@sync for prob in probs, alg in algs
Threads.@spawn solve(prob, alg; dt = 0.2, abstol = 1e-2)
end
end

Expand Down Expand Up @@ -192,12 +191,12 @@ end
nlsolve = NewtonRaphson(),
jac_alg = BVPJacobianAlgorithm(;
bc_diffmode = AutoForwardDiff(; chunksize = 2),
nonbc_diffmode = AutoSparseForwardDiff(; chunksize = 2))))
nonbc_diffmode = AutoSparse(AutoForwardDiff(; chunksize = 2)))))
end

@compile_workload begin
for prob in probs, alg in algs
solve(prob, alg)
@sync for prob in probs, alg in algs
Threads.@spawn solve(prob, alg)
end
end

Expand Down Expand Up @@ -257,18 +256,18 @@ end
nlsolve = LevenbergMarquardt(; disable_geodesic = Val(true)),
jac_alg = BVPJacobianAlgorithm(;
bc_diffmode = AutoForwardDiff(; chunksize = 2),
nonbc_diffmode = AutoSparseForwardDiff(; chunksize = 2))),
nonbc_diffmode = AutoSparse(AutoForwardDiff(; chunksize = 2)))),
MultipleShooting(10,
Tsit5();
nlsolve = GaussNewton(),
jac_alg = BVPJacobianAlgorithm(;
bc_diffmode = AutoForwardDiff(; chunksize = 2),
nonbc_diffmode = AutoSparseForwardDiff(; chunksize = 2)))])
nonbc_diffmode = AutoSparse(AutoForwardDiff(; chunksize = 2))))])
end

@compile_workload begin
for prob in probs, alg in algs
solve(prob, alg; nlsolve_kwargs = (; abstol = 1e-2))
@sync for prob in probs, alg in algs
Threads.@spawn solve(prob, alg; nlsolve_kwargs = (; abstol = 1e-2))
end
end
end
Expand Down
13 changes: 7 additions & 6 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,11 @@ Significantly more stable than Single Shooting.
on the input types and problem type.
+ For `TwoPointBVProblem`, only `diffmode` is used (defaults to
`AutoSparseForwardDiff` if possible else `AutoSparseFiniteDiff`).
`AutoSparse(AutoForwardDiff())` if possible else `AutoSparse(AutoFiniteDiff())`).
+ For `BVProblem`, `bc_diffmode` and `nonbc_diffmode` are used. For `nonbc_diffmode`
we default to `AutoSparseForwardDiff` if possible else `AutoSparseFiniteDiff`. For
`bc_diffmode`, we default to `AutoForwardDiff` if possible else `AutoFiniteDiff`.
we default to `AutoSparse(AutoForwardDiff())` if possible else
`AutoSparse(AutoFiniteDiff())`. For `bc_diffmode`, we default to `AutoForwardDiff`
if possible else `AutoFiniteDiff`.
- `grid_coarsening`: Coarsening the multiple-shooting grid to generate a stable IVP
solution. Possible Choices:
Expand Down Expand Up @@ -160,10 +161,10 @@ for order in (2, 3, 4, 5, 6)
`BVPJacobianAlgorithm()`, which automatically decides the best algorithm to
use based on the input types and problem type.
- For `TwoPointBVProblem`, only `diffmode` is used (defaults to
`AutoSparseForwardDiff` if possible else `AutoSparseFiniteDiff`).
`AutoSparse(AutoForwardDiff())` if possible else `AutoSparse(AutoFiniteDiff())`).
- For `BVProblem`, `bc_diffmode` and `nonbc_diffmode` are used. For
`nonbc_diffmode` defaults to `AutoSparseForwardDiff` if possible else
`AutoSparseFiniteDiff`. For `bc_diffmode`, defaults to `AutoForwardDiff` if
`nonbc_diffmode` defaults to `AutoSparse(AutoForwardDiff())` if possible else
`AutoSparse(AutoFiniteDiff())`. For `bc_diffmode`, defaults to `AutoForwardDiff` if
possible else `AutoFiniteDiff`.
- `defect_threshold`: Threshold for defect control.
- `max_num_subintervals`: Maximum number of subintervals; defaults to 3000.
Expand Down
6 changes: 3 additions & 3 deletions src/solve/mirk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -311,12 +311,12 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo
loss_bcₚ = (iip ? __Fix3 : Base.Fix2)(loss_bc, cache.p)
loss_collocationₚ = (iip ? __Fix3 : Base.Fix2)(loss_collocation, cache.p)

sd_bc = jac_alg.bc_diffmode isa AbstractSparseADType ? SymbolicsSparsityDetection() :
sd_bc = jac_alg.bc_diffmode isa AutoSparse ? SymbolicsSparsityDetection() :
NoSparsityDetection()
cache_bc = __sparse_jacobian_cache(
Val(iip), jac_alg.bc_diffmode, sd_bc, loss_bcₚ, resid_bc, y)

sd_collocation = if jac_alg.nonbc_diffmode isa AbstractSparseADType
sd_collocation = if jac_alg.nonbc_diffmode isa AutoSparse
if L < cache.M
# For underdetermined problems we use sparse since we don't have banded qr
colored_matrix = __generate_sparse_jacobian_prototype(
Expand Down Expand Up @@ -416,7 +416,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo
@view(cache.bcresid_prototype[(prod(cache.resid_size[1]) + 1):end]))
L = length(cache.bcresid_prototype)

sd = if jac_alg.diffmode isa AbstractSparseADType
sd = if jac_alg.diffmode isa AutoSparse
__sparsity_detection_alg(__generate_sparse_jacobian_prototype(
cache, cache.problem_type,
@view(cache.bcresid_prototype[1:prod(cache.resid_size[1])]),
Expand Down
18 changes: 9 additions & 9 deletions src/solve/multiple_shooting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ function __solve_nlproblem!(
du, u, p, cur_nshoot, nodes, prob, solve_internal_odes!,
resida_len, residb_len, N, bca, bcb, ode_cache_loss_fn)

sd_bvp = alg.jac_alg.diffmode isa AbstractSparseADType ?
sd_bvp = alg.jac_alg.diffmode isa AutoSparse ?
__sparsity_detection_alg(J_proto) : NoSparsityDetection()

resid_prototype_cached = similar(resid_prototype)
Expand All @@ -113,7 +113,7 @@ function __solve_nlproblem!(
jac_prototype = init_jacobian(jac_cache)

ode_cache_jac_fn = __multiple_shooting_init_jacobian_odecache(
ensemblealg, prob, jac_cache, alg.jac_alg.diffmode,
ensemblealg, prob, jac_cache, __cache_trait(alg.jac_alg.diffmode),
alg.ode_alg, cur_nshoot, u0; internal_ode_kwargs...)

loss_fnₚ = @closure (du, u) -> __multiple_shooting_2point_loss!(
Expand Down Expand Up @@ -153,21 +153,21 @@ function __solve_nlproblem!(::StandardBVProblem, alg::MultipleShooting, bcresid_
N, f, bc, u0_size, prob.tspan, alg.ode_alg, u0, ode_cache_loss_fn)

# ODE Part
sd_ode = alg.jac_alg.nonbc_diffmode isa AbstractSparseADType ?
sd_ode = alg.jac_alg.nonbc_diffmode isa AutoSparse ?
__sparsity_detection_alg(J_proto) : NoSparsityDetection()
ode_jac_cache = sparse_jacobian_cache(alg.jac_alg.nonbc_diffmode, sd_ode, nothing,
similar(u_at_nodes, cur_nshoot * N), u_at_nodes)
ode_cache_ode_jac_fn = __multiple_shooting_init_jacobian_odecache(
ensemblealg, prob, ode_jac_cache, alg.jac_alg.nonbc_diffmode,
ensemblealg, prob, ode_jac_cache, __cache_trait(alg.jac_alg.nonbc_diffmode),
alg.ode_alg, cur_nshoot, u0; internal_ode_kwargs...)

# BC Part
sd_bc = alg.jac_alg.bc_diffmode isa AbstractSparseADType ?
sd_bc = alg.jac_alg.bc_diffmode isa AutoSparse ?
SymbolicsSparsityDetection() : NoSparsityDetection()
bc_jac_cache = sparse_jacobian_cache(
alg.jac_alg.bc_diffmode, sd_bc, nothing, similar(bcresid_prototype), u_at_nodes)
ode_cache_bc_jac_fn = __multiple_shooting_init_jacobian_odecache(
ensemblealg, prob, bc_jac_cache, alg.jac_alg.bc_diffmode,
ensemblealg, prob, bc_jac_cache, __cache_trait(alg.jac_alg.bc_diffmode),
alg.ode_alg, cur_nshoot, u0; internal_ode_kwargs...)

jac_prototype = vcat(init_jacobian(bc_jac_cache), init_jacobian(ode_jac_cache))
Expand Down Expand Up @@ -208,12 +208,12 @@ function __multiple_shooting_init_odecache(
end

function __multiple_shooting_init_jacobian_odecache(
ensemblealg, prob, jac_cache, ad, alg, nshoots, u; kwargs...)
ensemblealg, prob, jac_cache, ::NoDiffCacheNeeded, alg, nshoots, u; kwargs...)
return __multiple_shooting_init_odecache(ensemblealg, prob, alg, u, nshoots; kwargs...)
end

function __multiple_shooting_init_jacobian_odecache(ensemblealg, prob, jac_cache,
::Union{AutoForwardDiff, AutoSparseForwardDiff}, alg, nshoots, u; kwargs...)
function __multiple_shooting_init_jacobian_odecache(
ensemblealg, prob, jac_cache, ::DiffCacheNeeded, alg, nshoots, u; kwargs...)
cache = jac_cache.cache
if cache isa ForwardDiff.JacobianConfig
xduals = reshape(cache.duals[2][1:length(u)], size(u))
Expand Down
11 changes: 6 additions & 5 deletions src/solve/single_shooting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ function __solve(prob::BVProblem, alg_::Shooting; odesolve_kwargs = (;),
u, p, ode_cache_loss_fn, bc, u0_size, prob.problem_type)
end

sd = alg.jac_alg.diffmode isa AbstractSparseADType ? SymbolicsSparsityDetection() :
sd = alg.jac_alg.diffmode isa AutoSparse ? SymbolicsSparsityDetection() :
NoSparsityDetection()
y_ = similar(resid_prototype)

Expand All @@ -49,7 +49,8 @@ function __solve(prob::BVProblem, alg_::Shooting; odesolve_kwargs = (;),
end

ode_cache_jac_fn = __single_shooting_jacobian_ode_cache(
internal_prob, jac_cache, alg.jac_alg.diffmode, u0, alg.ode_alg; ode_kwargs...)
internal_prob, jac_cache, __cache_trait(alg.jac_alg.diffmode),
u0, alg.ode_alg; ode_kwargs...)

jac_prototype = init_jacobian(jac_cache)

Expand Down Expand Up @@ -126,13 +127,13 @@ function __single_shooting_jacobian(J, u, jac_cache, diffmode, loss_fn::L) where
return J
end

function __single_shooting_jacobian_ode_cache(prob, jac_cache, alg, u0, ode_alg; kwargs...)
function __single_shooting_jacobian_ode_cache(
prob, jac_cache, ::NoDiffCacheNeeded, u0, ode_alg; kwargs...)
return SciMLBase.__init(remake(prob; u0), ode_alg; kwargs...)
end

function __single_shooting_jacobian_ode_cache(
prob, jac_cache, ::Union{AutoForwardDiff, AutoSparseForwardDiff},
u0, ode_alg; kwargs...)
prob, jac_cache, ::DiffCacheNeeded, u0, ode_alg; kwargs...)
cache = jac_cache.cache
if cache isa ForwardDiff.JacobianConfig
xduals = cache.duals isa Tuple ? cache.duals[2] : cache.duals
Expand Down
39 changes: 20 additions & 19 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,13 @@ function Base.show(io::IO, alg::BVPJacobianAlgorithm)
print(io, ")")
end

__any_sparse_ad(ad) = ad isa AbstractSparseADType
function __any_sparse_ad(jac_alg::BVPJacobianAlgorithm)
@inline __any_sparse_ad(::AutoSparse) = true
@inline function __any_sparse_ad(jac_alg::BVPJacobianAlgorithm)
__any_sparse_ad(jac_alg.bc_diffmode) ||
__any_sparse_ad(jac_alg.nonbc_diffmode) ||
__any_sparse_ad(jac_alg.diffmode)
end
@inline __any_sparse_ad(_) = false

function BVPJacobianAlgorithm(
diffmode = missing; nonbc_diffmode = missing, bc_diffmode = missing)
Expand All @@ -89,8 +90,8 @@ If user provided all the required fields, then return the user provided algorith
Otherwise, based on the problem type and the algorithm, decide the missing fields.
For example, for `TwoPointBVProblem`, the `bc_diffmode` is set to
`AutoSparseForwardDiff` while for `StandardBVProblem`, the `bc_diffmode` is set to
`AutoForwardDiff`.
`AutoSparse(AutoForwardDiff())` while for `StandardBVProblem`, the `bc_diffmode` is set to
`AutoForwardDiff()`.
"""
function concrete_jacobian_algorithm(jac_alg::BVPJacobianAlgorithm, prob::BVProblem, alg)
return concrete_jacobian_algorithm(jac_alg, prob.problem_type, prob, alg)
Expand All @@ -109,21 +110,13 @@ function concrete_jacobian_algorithm(
return BVPJacobianAlgorithm(bc_diffmode, nonbc_diffmode, diffmode)
end

struct BoundaryValueDiffEqTag end

function ForwardDiff.checktag(::Type{<:ForwardDiff.Tag{<:BoundaryValueDiffEqTag, <:T}},
f::F, x::AbstractArray{T}) where {T, F}
return true
end

@inline function __default_sparse_ad(x::AbstractArray{T}) where {T}
return isbitstype(T) ? __default_sparse_ad(T) : __default_sparse_ad(first(x))
end
@inline __default_sparse_ad(x::T) where {T} = __default_sparse_ad(T)
@inline __default_sparse_ad(::Type{<:Complex}) = AutoSparseFiniteDiff()
@inline __default_sparse_ad(::Type{<:Complex}) = AutoSparse(AutoFiniteDiff())
@inline function __default_sparse_ad(::Type{T}) where {T}
return ForwardDiff.can_dual(T) ?
AutoSparseForwardDiff(; tag = BoundaryValueDiffEqTag()) : AutoSparseFiniteDiff()
return AutoSparse(ifelse(ForwardDiff.can_dual(T), AutoForwardDiff(), AutoFiniteDiff()))
end

@inline function __default_nonsparse_ad(x::AbstractArray{T}) where {T}
Expand All @@ -132,8 +125,7 @@ end
@inline __default_nonsparse_ad(x::T) where {T} = __default_nonsparse_ad(T)
@inline __default_nonsparse_ad(::Type{<:Complex}) = AutoFiniteDiff()
@inline function __default_nonsparse_ad(::Type{T}) where {T}
return ForwardDiff.can_dual(T) ? AutoForwardDiff(; tag = BoundaryValueDiffEqTag()) :
AutoFiniteDiff()
return ifelse(ForwardDiff.can_dual(T), AutoForwardDiff(), AutoFiniteDiff())
end

# This can cause Type Instability
Expand All @@ -146,9 +138,10 @@ Base.@deprecate MIRKJacobianComputationAlgorithm(
diffmode = missing; collocation_diffmode = missing, bc_diffmode = missing) BVPJacobianAlgorithm(
diffmode; nonbc_diffmode = collocation_diffmode, bc_diffmode)

__needs_diffcache(::Union{AutoForwardDiff, AutoSparseForwardDiff}) = true
__needs_diffcache(_) = false
function __needs_diffcache(jac_alg::BVPJacobianAlgorithm)
@inline __needs_diffcache(::AutoForwardDiff) = true
@inline __needs_diffcache(ad::AutoSparse) = __needs_diffcache(ADTypes.dense_ad(ad))
@inline __needs_diffcache(_) = false
@inline function __needs_diffcache(jac_alg::BVPJacobianAlgorithm)
return __needs_diffcache(jac_alg.diffmode) ||
__needs_diffcache(jac_alg.bc_diffmode) ||
__needs_diffcache(jac_alg.nonbc_diffmode)
Expand Down Expand Up @@ -176,3 +169,11 @@ const MaybeDiffCache = Union{DiffCache, FakeDiffCache}
PreallocationTools.get_tmp(dc, u)
end
end

# DiffCache traits: singleton types returned by `__cache_trait` so that callers can pick a
# method by dispatch instead of matching on concrete AD backend types.
struct DiffCacheNeeded end
struct NoDiffCacheNeeded end

# Fallback: any argument without a more specific method below maps to `NoDiffCacheNeeded`.
@inline function __cache_trait(_)
    return NoDiffCacheNeeded()
end

# Dense ForwardDiff is the only backend mapped directly to `DiffCacheNeeded`.
@inline function __cache_trait(::AutoForwardDiff)
    return DiffCacheNeeded()
end

# A sparse wrapper takes the trait of the dense backend it wraps.
@inline function __cache_trait(ad::AutoSparse)
    dense_backend = ADTypes.dense_ad(ad)
    return __cache_trait(dense_backend)
end
20 changes: 2 additions & 18 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -220,24 +220,8 @@ end
__vec_bc(sol, p, t, bc, u_size) = vec(bc(__restructure_sol(sol, u_size), p, t))
__vec_bc(sol, p, bc, u_size) = vec(bc(reshape(sol, u_size), p))

__get_non_sparse_ad(ad::AbstractADType) = ad
function __get_non_sparse_ad(ad::AbstractSparseADType)
if ad isa AutoSparseForwardDiff
return AutoForwardDiff{__get_chunksize(ad), typeof(ad.tag)}(ad.tag)
elseif ad isa AutoSparseEnzyme
return AutoEnzyme()
elseif ad isa AutoSparseFiniteDiff
return AutoFiniteDiff()
elseif ad isa AutoSparseReverseDiff
return AutoReverseDiff(ad.compile)
elseif ad isa AutoSparseZygote
return AutoZygote()
else
throw(ArgumentError("Unknown AD Type"))
end
end

__get_chunksize(::AutoSparseForwardDiff{CK}) where {CK} = CK
# Strip the sparsity wrapper from an AD choice: an `AutoSparse` yields the dense backend it
# wraps, and every other `AbstractADType` is returned unchanged.
@inline function __get_non_sparse_ad(ad::AutoSparse)
    return ADTypes.dense_ad(ad)
end
@inline function __get_non_sparse_ad(ad::AbstractADType)
    return ad
end

# Restructure Solution
function __restructure_sol(sol::Vector{<:AbstractArray}, u_size)
Expand Down
4 changes: 2 additions & 2 deletions test/mirk/ensemble_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

@testset "$(solver)" for solver in (MIRK2, MIRK3, MIRK4, MIRK5, MIRK6)
jac_algs = [BVPJacobianAlgorithm(),
BVPJacobianAlgorithm(;
bc_diffmode = AutoFiniteDiff(), nonbc_diffmode = AutoSparseFiniteDiff())]
BVPJacobianAlgorithm(; bc_diffmode = AutoFiniteDiff(),
nonbc_diffmode = AutoSparse(AutoFiniteDiff()))]
for jac_alg in jac_algs
sol = solve(ensemble_prob, solver(; jac_alg); trajectories = 10, dt = 0.1)
@test sol.converged
Expand Down
2 changes: 1 addition & 1 deletion test/mirk/mirk_basic_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ end
bvp1 = BVProblem(simplependulum!, bc_pendulum!, u0, tspan)

jac_alg = BVPJacobianAlgorithm(;
bc_diffmode = AutoFiniteDiff(), nonbc_diffmode = AutoSparseFiniteDiff())
bc_diffmode = AutoFiniteDiff(), nonbc_diffmode = AutoSparse(AutoFiniteDiff()))

# Using ForwardDiff might lead to Cache expansion warnings
@test_nowarn solve(bvp1, MIRK2(; jac_alg); dt = 0.005)
Expand Down
2 changes: 1 addition & 1 deletion test/shooting/basic_problems_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ end
grid_coarsening = true,
nlsolve = TrustRegion(),
jac_alg = BVPJacobianAlgorithm(; bc_diffmode = AutoForwardDiff(; chunksize = 8),
nonbc_diffmode = AutoSparseForwardDiff(; chunksize = 8)))
nonbc_diffmode = AutoSparse(AutoForwardDiff(; chunksize = 8))))
alg_dense = MultipleShooting(10,
AutoVern7(Rodas4P());
grid_coarsening = true,
Expand Down
Loading

0 comments on commit 0557612

Please sign in to comment.