Skip to content

Commit

Permalink
Merge pull request #203 from ErikQQY/qqy/interp_deriv
Browse files Browse the repository at this point in the history
Fix interpolation for solution derivative
  • Loading branch information
ChrisRackauckas authored Jul 31, 2024
2 parents 4a767e1 + 703fcfc commit 1268560
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 7 deletions.
25 changes: 22 additions & 3 deletions src/interpolation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ end

for j in idx
z = similar(cache.fᵢ₂_cache)
interp_eval!(z, id.cache, tvals[j], id.cache.mesh, id.cache.mesh_dt)
interpolant!(z, id.cache, tvals[j], id.cache.mesh, id.cache.mesh_dt, deriv)
vals[j] = idxs !== nothing ? z[idxs] : z
end
return DiffEqArray(vals, tvals)
Expand All @@ -49,14 +49,33 @@ end

for j in idx
z = similar(cache.fᵢ₂_cache)
interp_eval!(z, id.cache, tvals[j], id.cache.mesh, id.cache.mesh_dt)
interpolant!(z, id.cache, tvals[j], id.cache.mesh, id.cache.mesh_dt, deriv)
vals[j] = z
end
end

@inline function interpolation(
tval::Number, id::I, idxs, deriv::D, p, continuity::Symbol = :left) where {I, D}
z = similar(id.cache.fᵢ₂_cache)
interp_eval!(z, id.cache, tval, id.cache.mesh, id.cache.mesh_dt)
interpolant!(z, id.cache, tval, id.cache.mesh, id.cache.mesh_dt, deriv)
return idxs !== nothing ? z[idxs] : z
end

@inline function interpolant!(
z::AbstractArray, cache::MIRKCache, t, mesh, mesh_dt, T::Type{Val{0}})
i = interval(mesh, t)
dt = mesh_dt[i]
τ = (t - mesh[i]) / dt
w, w′ = interp_weights(τ, cache.alg)
sum_stages!(z, cache, w, i)
end

@inline function interpolant!(
dz::AbstractArray, cache::MIRKCache, t, mesh, mesh_dt, T::Type{Val{1}})
i = interval(mesh, t)
dt = mesh_dt[i]
τ = (t - mesh[i]) / dt
w, w′ = interp_weights(τ, cache.alg)
z = similar(dz)
sum_stages!(z, dz, cache, w, w′, i)
end
26 changes: 22 additions & 4 deletions test/mirk/mirk_basic_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,12 @@ end
(-a * exp(-t * a) - a * exp((t - 2) * a)) / (1 - exp(-2 * a))]
end

function prob_bvp_linear_analytic_derivative(u, λ, t)
a = 1 / sqrt(λ)
return [(-a * exp(-t * a) - a * exp((t - 2) * a)) / (1 - exp(-2 * a)),
(exp(-a * t) - exp((t - 2) * a)) / (1 - exp(-2 * a))]
end

function prob_bvp_linear_f!(du, u, p, t)
du[1] = u[2]
du[2] = 1 / p * u[1]
Expand All @@ -177,19 +183,31 @@ end

@testset "Interpolation for adaptive MIRK$order" for order in (2, 3, 4, 5, 6)
sol = solve(prob_bvp_linear, mirk_solver(Val(order)); dt = 0.001)
@test sol(0.001)[0.998687464, -1.312035941] atol=testTol
@test sol(0.001; idxs = [1, 2])[0.998687464, -1.312035941] atol=testTol
@test sol(0.001; idxs = 1)0.998687464 atol=testTol
@test sol(0.001; idxs = 2)-1.312035941 atol=testTol
sol_analytic = prob_bvp_linear_analytic(nothing, λ, 0.001)

@test sol(0.001)sol_analytic atol=testTol
@test sol(0.001; idxs = [1, 2])sol_analytic atol=testTol
@test sol(0.001; idxs = 1)sol_analytic[1] atol=testTol
@test sol(0.001; idxs = 2)sol_analytic[2] atol=testTol
end

@testset "Interpolation for non-adaptive MIRK$order" for order in (2, 3, 4, 5, 6)
sol = solve(prob_bvp_linear, mirk_solver(Val(order)); dt = 0.001, adaptive = false)

@test_nowarn sol(0.01)
@test_nowarn sol(0.01; idxs = [1, 2])
@test_nowarn sol(0.01; idxs = 1)
@test_nowarn sol(0.01; idxs = 2)
end

@testset "Interpolation for solution derivative" for order in (2, 3, 4, 5, 6)
sol = solve(prob_bvp_linear, mirk_solver(Val(order)); dt = 0.001)
sol_analytic = prob_bvp_linear_analytic(nothing, λ, 0.04)
dsol_analytic = prob_bvp_linear_analytic_derivative(nothing, λ, 0.04)

@test sol(0.04, Val{0})sol_analytic atol=testTol
@test sol(0.04, Val{1})dsol_analytic atol=testTol
end
end

@testitem "Swirling Flow III" begin
Expand Down

0 comments on commit 1268560

Please sign in to comment.