Skip to content

Commit

Permalink
Adding code from Dimitris
Browse files Browse the repository at this point in the history
  • Loading branch information
mleprovost committed Feb 22, 2024
1 parent a8aca8e commit 6d1ba3b
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 11 deletions.
55 changes: 49 additions & 6 deletions src/hermitemap/expandedfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ export ExpandedFunction,
getbasis,
evaluate_basis!,
evaluate_basis,
repeated_evaluate_basis!,
repeated_evaluate_basis,
grad_xk_basis!,
grad_xk_basis,
Expand Down Expand Up @@ -266,21 +267,63 @@ evaluate_basis(f::ExpandedFunction, X) =
evaluate_basis!(zeros(size(X,2),size(f.idx,1)), f, X, f.dim, f.idx)
# evaluate_basis!(zeros(size(X,2),size(f.idx,1)), f, X, 1:f.Nx, f.idx)

# """
# $(TYPEDSIGNATURES)

# Evaluates the basis of `ExpandedFunction` `f` for the last component
# """
# function repeated_evaluate_basis(f::ExpandedFunction, x, idx::Array{Int64,2})
# # Compute the last component
# midxj = idx[:,f.Nx]
# maxj = maximum(midxj)
# ψj = vander(f.MB.B, maxj, 0, x)
# return ψj[:, midxj .+ 1]
# end

# repeated_evaluate_basis(f::ExpandedFunction, x) = repeated_evaluate_basis(f, x, f.idx)

"""
$(TYPEDSIGNATURES)
Evaluates the basis of `ExpandedFunction` `f` for the last component
Computes in-place the gradient with respect to the last state component of the basis of the last univariate function of each feature with multi-indices `idx` at `x`.
"""
function repeated_evaluate_basis(f::ExpandedFunction, x, idx::Array{Int64,2})
# Compute the last component
midxj = idx[:,f.Nx]
function repeated_evaluate_basis!(out, cache, f::ExpandedFunction, x, idx::Array{Int64,2})
# Compute the derivative of an expanded function along the last state component.
Ne = size(x, 1)
Nx = f.Nx

# @assert size(out,1) = (N, size(idx, 1)) "Wrong dimension of the output vector"
# ∂ᵏf/∂x_{grad_dim} = ψ
k = 0
grad_dim = Nx
dims = Nx

midxj = idx[:, Nx]
maxj = maximum(midxj)
ψj = vander(f.MB.B, maxj, 0, x)
return ψj[:, midxj .+ 1]
# dkψj = zeros(Ne, maxj+1)
vander!(cache, f.MB.B, maxj, k, x)
Nψreduced = size(idx, 1)
@avx for l = 1:Nψreduced
for k=1:Ne
out[k, l] = cache[k, midxj[l] + 1]
end
end

return out#dkψj[:, midxj .+ 1]
end

repeated_evaluate_basis!(out, cache, f::ExpandedFunction, x) = repeated_evaluate_basis!(out, cache, f, x, f.idx)

"""
$(TYPEDSIGNATURES)
Computes the gradient with respect to the last state component of the basis of the last univariate function of each feature with multi-indices `idx` at `x`.
"""
repeated_evaluate_basis(f::ExpandedFunction, x, idx::Array{Int64,2}) =
repeated_evaluate_basis!(zeros(size(x,1),size(idx,1)), zeros(size(x,1), maximum(idx[:,f.Nx])+1), f, x, idx)
repeated_evaluate_basis(f::ExpandedFunction, x) = repeated_evaluate_basis(f, x, f.idx)


# function grad_xk_basis!(dkψ, f::ExpandedFunction, X::Array{Float64,2}, k::Int64, grad_dim::Union{Int64, Array{Int64,1}}, dims::Union{Int64, UnitRange{Int64}, Array{Int64,1}}, idx::Array{Int64,2})
"""
$(TYPEDSIGNATURES)
Expand Down
57 changes: 52 additions & 5 deletions src/hermitemap/rectifier.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,25 +38,43 @@ d2softplus(x) = log(2.0)/(2.0*(1.0 + cosh(log(2.0)*x)))
invsoftplus(x) = min(log(exp(log(2.0)*x) - 1.0)/log(2.0), x)

# Logistic tools
# Sigmoid implementation from NNlib.jl to avoid underflow errors
# Sigmoid implementation from NNlib.jl to avoid underflow errors.

function sigmoid(x)
t = exp(-abs(x))
ifelse(x 0, inv(1 + t), t / (1 + t))
end

function dsigmoid(x)
σ = sigmoid(x)
return σ*(1-σ)
end

function d2sigmoid(x)
σ = sigmoid(x)
# from dσ*(1-σ) - σ*dσ
return σ*(1-σ)*(1-2*σ)
end
invsigmoid(x) = ifelse(x > 0, log(x) - log(1-x), "Not defined for x ≤ 0 ")

function sigmoid_(x, K_min, K_max)
return K_min + (K_max-K_min) * sigmoid(x)
end

function dsigmoid_(x, K_min, K_max)
σ = sigmoid(x)
return (K_max-K_min)*σ*(1-σ)
end

function d2sigmoid_(x, K_min, K_max)
σ = sigmoid(x)
return (K_max-K_min) * σ*(1-σ)*(1-2*σ)
end

invsigmoid_(x, K_min, K_max) = ifelse(x > 0, log(x) - log(1-x), "Not defined for x ≤ 0 ")
if x > K_min && x < K_max
return log(x-K_min) - log(K_max-x)
else
return "Not defined for x outside [K_min, K_max]"

explinearunit(x) = x < 0.0 ? exp(x) : x + 1.0
dexplinearunit(x) = x < 0.0 ? exp(x) : 1.0
Expand All @@ -76,6 +94,8 @@ function (g::Rectifier)(x)
return exp(x)
elseif g.T=="sigmoid"
return sigmoid(x)
elseif g.T=="sigmoid_"
return sigmoid_(x, -1.0, 1.0)
elseif g.T=="softplus"
return softplus(x)
elseif g.T=="explinearunit"
Expand All @@ -94,6 +114,9 @@ function evaluate!(result, g::Rectifier, x)
elseif g.T=="sigmoid"
vmap!(sigmoid, result, x)
return result
elseif g.T=="sigmoid_"
vmap!(y->sigmoid_(y, -1.0, 1.0), result, x)
return result
elseif g.T=="softplus"
vmap!(softplus, result, x)
return result
Expand All @@ -103,7 +126,7 @@ function evaluate!(result, g::Rectifier, x)
end
end

vevaluate(g::Rectifier, x) = evaluate!(zero(x), g, x)
evaluate(g::Rectifier, x) = evaluate!(zero(x), g, x)

function inverse(g::Rectifier, x)
@assert x>=0 "Input to rectifier is negative"
Expand All @@ -113,6 +136,8 @@ function inverse(g::Rectifier, x)
return log(x)
elseif g.T=="sigmoid"
return invsigmoid(x)
elseif g.T=="sigmoid_"
return invsigmoid_(x, -1.0, 1.0)
elseif g.T=="softplus"
return invsoftplus(x)
elseif g.T=="explinearunit"
Expand All @@ -131,6 +156,8 @@ function inverse!(result, g::Rectifier, x)
elseif g.T=="sigmoid"
vmap!(invsigmoid, result, x)
return result
elseif g.T=="sigmoid_"
vmap!(y->invsigmoid(y, -1.0, 1.0), result, x)
elseif g.T=="softplus"
vmap!(invsoftplus, result, x)
return result
Expand All @@ -150,6 +177,8 @@ function grad_x(g::Rectifier, x)
return exp(x)
elseif g.T=="sigmoid"
return dsigmoid(x)
elseif g.T=="sigmoid_"
return dsigmoid_(x, -1.0, 1.0)
elseif g.T=="softplus"
return dsoftplus(x)
elseif g.T=="explinearunit"
Expand All @@ -159,7 +188,7 @@ end


function grad_x!(result, g::Rectifier, x)
@assert size(result,1) == size(x,1) "Dimension of result and x don't match"
@assert size(result,1) == size(x, 1) "Dimension of result and x don't match"
if g.T=="squared"
vmap!(dsquare, result, x)
return result
Expand All @@ -169,6 +198,9 @@ function grad_x!(result, g::Rectifier, x)
elseif g.T=="sigmoid"
vmap!(dsigmoid, result, x)
return result
elseif g.T=="sigmoid_"
vmap!(y->dsigmoid_(y, -1.0, 1.0), result, x)
return result
elseif g.T=="softplus"
vmap!(dsoftplus, result, x)
return result
Expand All @@ -187,7 +219,9 @@ function grad_x_logeval(g::Rectifier, x::T) where {T <: Real}
elseif g.T=="exponential"
return 1.0
elseif g.T=="sigmoid"
return dsigmoid(x)/sigmoid(x)
return dsigmoid(x)/sigmoid(x)
elseif g.T=="sigmoid_"
return d2sigmoid_(x, -1.0, 1.0) / sigmoid_(x, -1.0, 1.0)
elseif g.T=="softplus"
return dsoftplus(x)/softplus(x)
elseif g.T=="explinearunit"
Expand All @@ -206,6 +240,9 @@ function grad_x_logeval!(result, g::Rectifier, x)
elseif g.T=="sigmoid"
vmap!(xi->dsigmoid(xi)/sigmoid(xi), result, x)
return result
elseif g.T=="sigmoid_"
vmap!(xi->dsigmoid_(xi, -1,0, 1.0)/sigmoid_(xi, -1.0, 1.0), result, x)
return result
elseif g.T=="softplus"
vmap!(xi->dsoftplus(xi)/softplus(xi), result, x)
return result
Expand All @@ -226,6 +263,8 @@ function hess_x_logeval(g::Rectifier, x::T) where {T <: Real}
return 0.0
elseif g.T=="sigmoid"
return (d2sigmoid(x)*sigmoid(x) - dsigmoid(x)^2)/sigmoid(x)^2
elseif g.T=="sigmoid_"
return (d2sigmoid_(x, -1.0, 1.0)*sigmoid_(x, -1.0, 1.0) - dsigmoid_(x, -1.0, 1.0)^2) / (sigmoid_x, -1.0, 1.0)^2
elseif g.T=="softplus"
return (d2softplus(x)*softplus(x) - dsoftplus(x)^2)/softplus(x)^2
elseif g.T=="explinearunit"
Expand All @@ -244,6 +283,9 @@ function hess_x_logeval!(result, g::Rectifier, x)
elseif g.T=="sigmoid"
vmap!(xi->(d2sigmoid(xi)*sigmoid(xi) - dsigmoid(xi)^2)/sigmoid(xi)^2, result, x)
return result
elseif g.T=="sigmoid_"
vmap!(xi->(d2sigmoid_(xi, -1.0, 1.0)*sigmoid_(xi, -1.0, 1.0) - dsigmoid_(xi, -1.0, 1.0)^2)/sigmoid_(xi, -1.0, 1.0)^2, result, x)
return result
elseif g.T=="softplus"
vmap!(xi->(d2softplus(xi)*softplus(xi) - dsoftplus(xi)^2)/softplus(xi)^2, result, x)
return result
Expand All @@ -262,6 +304,8 @@ function hess_x(g::Rectifier, x::T) where {T <: Real}
return exp(x)
elseif g.T=="sigmoid"
return d2sigmoid(x)
elseif g.T=="sigmoid_"
return d2sigmoid_(x, -1.0, 1.0)
elseif g.T=="softplus"
return d2softplus(x)
elseif g.T=="explinearunit"
Expand All @@ -280,6 +324,9 @@ function hess_x!(result, g::Rectifier, x)
elseif g.T=="sigmoid"
vmap!(d2softplus, result, x)
return result
elseif g.T=="sigmoid_"
vmap!(y->d2sigmoid_(y, -1.0, 1.0), result, x)
return result
elseif g.T=="softplus"
vmap!(d2softplus, result, x)
return result
Expand Down

0 comments on commit 6d1ba3b

Please sign in to comment.