Skip to content

Commit

Permalink
Changing B to MB for multibasis
Browse files Browse the repository at this point in the history
  • Loading branch information
mleprovost committed Nov 21, 2023
1 parent 9f5a032 commit dd5cc62
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 40 deletions.
36 changes: 18 additions & 18 deletions src/hermitemap/expandedfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,19 @@ struct ExpandedFunction
m::Int64
::Int64
Nx::Int64
B::MultiBasis
MB::MultiBasis
idx::Array{Int64,2}
dim::Array{Int64, 1} # contains the active dimensions, i.e. columns of idx not equal to zeros
coeff::Array{Float64,1}
function ExpandedFunction(B::MultiBasis, idx::Array{Int64,2}, coeff::Array{Float64,1})
function ExpandedFunction(MB::MultiBasis, idx::Array{Int64,2}, coeff::Array{Float64,1})
= size(idx,1)
Nx = B.Nx
Nx = MB.Nx
@assert== size(coeff, 1) "The dimension of the basis functions don't
match the number of coefficients"


@assert size(idx,2) == Nx "Size of the array of multi-indices idx is wrong"
return new(B.B.m, Nψ, Nx, B, idx, active_dim(idx, B), coeff)
return new(MB.B.m, Nψ, Nx, MB, idx, active_dim(idx, MB), coeff)
end
end

Expand All @@ -68,7 +68,7 @@ $(TYPEDSIGNATURES)
Returns the kind of basis of the `ExpandedFunction` `f`.
"""
getbasis(f::ExpandedFunction) = getbasis(f.B)
getbasis(f::ExpandedFunction) = getbasis(f.MB)


"""
Expand Down Expand Up @@ -127,7 +127,7 @@ Evaluates the `ExpandedFunction` `f` at `x`.
function (f::ExpandedFunction)(x::Array{T,1}) where {T<:Real}
out = 0.0
@inbounds for i=1:f.
fi = MultiFunction(f.B, f.idx[i,:])
fi = MultiFunction(f.MB, f.idx[i,:])
out += f.coeff[i]*fi(x)
end
return out
Expand Down Expand Up @@ -161,7 +161,7 @@ end
$(TYPEDSIGNATURES)
Returns the active dimensions of the set of multi-indices `idx` for the `MultiBasis` `B`.
"""
active_dim(idx::Array{Int64,2}, B::MultiBasis) = active_dim(idx, B.B)
active_dim(idx::Array{Int64,2}, MB::MultiBasis) = active_dim(idx, MB.B)


# alleval computes the evaluation, gradient and hessian of the function
Expand All @@ -182,7 +182,7 @@ function alleval(f::ExpandedFunction, X)
result = DiffResults.HessianResult(zeros(Nx))

for i=1:
fi = MultiFunction(f.B, f.idx[i,:])
fi = MultiFunction(f.MB, f.idx[i,:])
for j=1:Ne
result = ForwardDiff.hessian!(result, fi, X[:,j])
ψ[j,i] = DiffResults.value(result)
Expand Down Expand Up @@ -221,7 +221,7 @@ function evaluate_basis!(ψ, f::ExpandedFunction, X, dims::Union{Array{Int64,1},
Xj = view(X,j,:)
ψj = ψtmp[:,1:maxj+1]

vander!(ψj, f.B.B, maxj, 0, Xj)
vander!(ψj, f.MB.B, maxj, 0, Xj)

@avx for l = 1:Nψreduced
for k=1:Ne
Expand Down Expand Up @@ -275,7 +275,7 @@ function repeated_evaluate_basis(f::ExpandedFunction, x, idx::Array{Int64,2})
# Compute the last component
midxj = idx[:,f.Nx]
maxj = maximum(midxj)
ψj = vander(f.B.B, maxj, 0, x)
ψj = vander(f.MB.B, maxj, 0, x)
return ψj[:, midxj .+ 1]
end

Expand Down Expand Up @@ -321,10 +321,10 @@ function grad_xk_basis!(dkψ, f::ExpandedFunction, X, k::Int64, grad_dim::Union{
maxj = maximum(midxj)
Xj = view(X,j,:)
if j in grad_dim # Compute the kth derivative along grad_dim
dkψj = vander(f.B.B, maxj, k, Xj)
dkψj = vander(f.MB.B, maxj, k, Xj)

else # Simple evaluation
dkψj = vander(f.B.B, maxj, 0, Xj)
dkψj = vander(f.MB.B, maxj, 0, Xj)
end
dkψ .*= dkψj[:, midxj .+ 1]
end
Expand Down Expand Up @@ -550,7 +550,7 @@ function repeated_grad_xk_basis!(out, cache, f::ExpandedFunction, x, idx::Array{
midxj = idx[:, Nx]
maxj = maximum(midxj)
# dkψj = zeros(Ne, maxj+1)
vander!(cache, f.B.B, maxj, k, x)
vander!(cache, f.MB.B, maxj, k, x)
Nψreduced = size(idx, 1)
@avx for l = 1:Nψreduced
for k=1:Ne
Expand Down Expand Up @@ -594,7 +594,7 @@ function repeated_hess_xk_basis!(out, cache, f::ExpandedFunction, x, idx::Array{
maxj = maximum(midxj)
# Compute the kth derivative along grad_dim
# dkψj = zeros(Ne, maxj+1)
vander!(cache, f.B.B, maxj, k, x)
vander!(cache, f.MB.B, maxj, k, x)
Nψreduced = size(idx, 1)
@avx for l = 1:Nψreduced
for k=1:Ne
Expand Down Expand Up @@ -650,7 +650,7 @@ function grad_x_grad_xd(f::ExpandedFunction, X, idx::Array{Int64,2})
@inbounds for i f.dim[f.dim .< f.Nx]
# Reduce further the computation, we have a non-zero output only if
# there is a feature such that idx[:,i]*idx[:,Nx]>0
if any([line[i]*line[f.Nx] for line in eachslice(idx; dims = 1)] .> 0) || iszerofeatureactive(f.B.B)
if any([line[i]*line[f.Nx] for line in eachslice(idx; dims = 1)] .> 0) || iszerofeatureactive(f.MB.B)
fill!(dxdxkψ_basis, 0.0)
grad_xk_basis!(dxdxkψ_basis, f, X, 1, [i;Nx], idx)
dxidxkψ = view(dxdxkψ,:,i)
Expand Down Expand Up @@ -705,7 +705,7 @@ function reduced_grad_x_grad_xd!(dxdxkψ, f::ExpandedFunction, X, idx::Array{Int
@inbounds for i=1:length(dimoff)
# Reduce further the computation, we have a non-zero output only if
# there is a feature such that idx[:,i]*idx[:,Nx]>0
if any([line[dim[i]]*line[f.Nx] for line in eachslice(idx; dims = 1)] .> 0) || iszerofeatureactive(f.B.B)
if any([line[dim[i]]*line[f.Nx] for line in eachslice(idx; dims = 1)] .> 0) || iszerofeatureactive(f.MB.B)
fill!(dxdxkψ_basis, 0.0)
grad_xk_basis!(dxdxkψ_basis, f, X, 1, [dim[i]; Nx], idx)
dxidxkψ = view(dxdxkψ,:,i)
Expand Down Expand Up @@ -761,7 +761,7 @@ function hess_x_grad_xd(f::ExpandedFunction, X, idx::Array{Int64,2})
for j f.dim[f.dim .>= i]
# Reduce further the computation, we have a non-zero output only if
# there is a feature such that idx[:,i]*idx[:,j]*idx[:,Nx]>0 or if the first feature is active
if any([line[i]*line[j]*line[f.Nx] for line in eachslice(f.idx; dims = 1)] .> 0) || iszerofeatureactive(f.B.B)
if any([line[i]*line[j]*line[f.Nx] for line in eachslice(f.idx; dims = 1)] .> 0) || iszerofeatureactive(f.MB.B)
fill!(d2xdxkψ_basis, 0.0)
dxidxjdxkψ = view(d2xdxkψ,:,i,j)
# Case i = j = k
Expand Down Expand Up @@ -834,7 +834,7 @@ function reduced_hess_x_grad_xd!(d2xdxkψ, f::ExpandedFunction, X, idx::Array{In
for j = i:length(dim)
# Reduce further the computation, we have a non-zero output only if
# there is a feature such that idx[:,i]*idx[:,j]*idx[:,Nx]>0
if any([line[dim[i]]*line[dim[j]]*line[f.Nx] for line in eachslice(f.idx; dims = 1)] .> 0) || iszerofeatureactive(f.B.B)
if any([line[dim[i]]*line[dim[j]]*line[f.Nx] for line in eachslice(f.idx; dims = 1)] .> 0) || iszerofeatureactive(f.MB.B)
fill!(d2xdxkψ_basis, 0.0)
dxidxjdxkψ = view(d2xdxkψ,:,i,j)
# Case i = j = k
Expand Down
10 changes: 5 additions & 5 deletions src/hermitemap/hermitemapcomponent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,12 @@ function HermiteMapComponent(m::Int64, Nx::Int64, idx::Array{Int64,2}, coeff::Ar
if b ["ProHermiteBasis"; "PhyHermiteBasis";
"CstProHermiteBasis"; "CstPhyHermiteBasis";
"CstLinProHermiteBasis"; "CstLinPhyHermiteBasis"]
B = MultiBasis(eval(Symbol(b))(m), Nx)
MB = MultiBasis(eval(Symbol(b))(m), Nx)
else
error("The basis "*b*" is not defined.")
end

return HermiteMapComponent(m, Nψ, Nx, IntegratedFunction(ExpandedFunction(B, idx, coeff)), α)
return HermiteMapComponent(m, Nψ, Nx, IntegratedFunction(ExpandedFunction(MB, idx, coeff)), α)
end

function HermiteMapComponent(f::ExpandedFunction; α::Float64 = αreg)
Expand All @@ -74,22 +74,22 @@ function HermiteMapComponent(m::Int64, Nx::Int64; α::Float64 = αreg, b::String
if b ["ProHermiteBasis"; "PhyHermiteBasis";
"CstProHermiteBasis"; "CstPhyHermiteBasis";
"CstLinProHermiteBasis"; "CstLinPhyHermiteBasis"]
B = MultiBasis(eval(Symbol(b))(m), Nx)
MB = MultiBasis(eval(Symbol(b))(m), Nx)
else
error("The basis "*b*" is not defined.")
end

idx = zeros(Int64, Nψ, Nx)
coeff = zeros(Nψ)

f = ExpandedFunction(B, idx, coeff)
f = ExpandedFunction(MB, idx, coeff)
I = IntegratedFunction(f)
return HermiteMapComponent(I; α = α)
end

function Base.show(io::IO, C::HermiteMapComponent)
println(io,"Hermite map component of dimension "*string(C.Nx)*" with Nψ = "*string(C.Nψ)*" active features")
# for i=1:B.m
# for i=1:MB.m
# println(io, B[i])
# end
end
Expand Down
8 changes: 4 additions & 4 deletions src/hermitemap/multibasis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ end
"""
$(TYPEDSIGNATURES)
Returns the size of the `MultiBasis` `B`.
Returns the size of the `MultiBasis` `MB`.
"""
size(B::MultiBasis) = (B.B.m, B.Nx)
size(MB::MultiBasis) = (MB.B.m, MB.Nx)

"""
$(TYPEDSIGNATURES)
Returns the kind of the underlying basis of the `MultiBasis` `B`.
Returns the kind of the underlying basis of the `MultiBasis` `MB`.
"""
getbasis(B::MultiBasis) = string(typeof(B.B))
getbasis(MB::MultiBasis) = string(typeof(MB.B))
20 changes: 10 additions & 10 deletions src/hermitemap/multifunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,26 @@ $(TYPEDFIELDS)
struct MultiFunction
m::Int64
Nx::Int64
B::MultiBasis
MB::MultiBasis
α::Array{Int64,1}
function MultiFunction(B::MultiBasis, α::Array{Int64,1})
m = B.B.m
Nx = B.Nx
function MultiFunction(MB::MultiBasis, α::Array{Int64,1})
m = MB.B.m
Nx = MB.Nx
@assert Nx == size(α,1) "Dimension of the space doesn't match the size of α"
for i=1:Nx
@assert α[i]<=m "multi index α can't be greater than the size of the univariate basis "
end
return new(m, Nx, B, α)
return new(m, Nx, MB, α)
end
end


function MultiFunction(B::MultiBasis)
return MultiFunction(B.B.m, B.Nx, B, ones(Int64, B.Nx))
function MultiFunction(MB::MultiBasis)
return MultiFunction(MB.B.m, MB.Nx, MB, ones(Int64, MB.Nx))
end

function MultiFunction(B::Basis, Nx::Int64; scaled::Bool = true)
return MultiFunction(B.B.m, Nx, MultiBasis(k, B), ones(Int64, Nx))
function MultiFunction(MB::Basis, Nx::Int64; scaled::Bool = true)
return MultiFunction(MB.B.m, Nx, MultiBasis(k, MB), ones(Int64, Nx))
end

"""
Expand All @@ -47,7 +47,7 @@ Evaluates the `MultiFunction` `F` at `x`
function (F::MultiFunction)(x::Array{T,1}) where {T <: Real}
out = 1.0
@inbounds for i=1:F.Nx
out *= F.B.B[F.α[i]+1](x[i])
out *= F.MB.B[F.α[i]+1](x[i])
end
return out
end
6 changes: 3 additions & 3 deletions src/hermitemap/totalordermap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ The features of the created maps are all the tensorial products of the basis ele
function totalordermapcomponent(Nx::Int64, order::Int64; withconstant::Bool = false, b::String = "CstProHermiteBasis")
@assert order >= 0 "Order should be positive"
if b ["CstProHermiteBasis"; "CstPhyHermiteBasis"]
B = MultiBasis(eval(Symbol(b))(order+2), Nx)
MB = MultiBasis(eval(Symbol(b))(order+2), Nx)
elseif b ["CstLinProHermiteBasis"; "CstLinPhyHermiteBasis"]
B = MultiBasis(eval(Symbol(b))(order+3), Nx)
MB = MultiBasis(eval(Symbol(b))(order+3), Nx)
else
error("Undefined basis")
end
Expand All @@ -24,7 +24,7 @@ function totalordermapcomponent(Nx::Int64, order::Int64; withconstant::Bool = fa

= size(idx, 1)

f = ExpandedFunction(B, idx, zeros(Nψ))
f = ExpandedFunction(MB, idx, zeros(Nψ))
return HermiteMapComponent(IntegratedFunction(f))
end

Expand Down

0 comments on commit dd5cc62

Please sign in to comment.