Skip to content

Commit

Permalink
Formatting
Browse files Browse the repository at this point in the history
Use JuliaFormatter
  • Loading branch information
max-de-rooij committed Jul 17, 2024
1 parent dfca2bf commit b45b446
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 16 deletions.
20 changes: 12 additions & 8 deletions src/ks_rank_sensitivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,13 @@ struct KSRank <: GSAMethod
acceptance_threshold::Union{Function, Real}
end

KSRank(;n_dummy_parameters = 50, acceptance_threshold=mean) = KSRank(n_dummy_parameters, acceptance_threshold)
function KSRank(; n_dummy_parameters = 50, acceptance_threshold = mean)
KSRank(n_dummy_parameters, acceptance_threshold)
end

struct KSRankResult{T}
S::AbstractVector{T}
Sd::Tuple{T, T}
Sd::Tuple{T, T}
end

function ks_rank_sensitivity(Xi, flag)
Expand All @@ -79,26 +81,28 @@ function _compute_ksrank(X::AbstractArray, Y::AbstractArray, method::KSRank)
K = size(X, 1)
#samples = size(X, 2)
sensitivities = zeros(K)

if method.acceptance_threshold isa Function
acceptance_threshold = method.acceptance_threshold(Y)
else
acceptance_threshold = method.acceptance_threshold
end
end
flag = Int.(Y .> acceptance_threshold)

# Cumulative distributions (for model parameters and dummies)
@inbounds for i = 1:K
@inbounds for i in 1:K
Xi = @view X[i, :]

# calculate KS score
sensitivities[i] = ks_rank_sensitivity(Xi, flag)
end

# collect dummy sensitivities (mean and std)
dummy_sensitivities = (mean(sensitivities[K-method.n_dummy_parameters+1:end]), std(sensitivities[K-method.n_dummy_parameters+1+1:end]))
dummy_sensitivities = (mean(sensitivities[(K - method.n_dummy_parameters + 1):end]),
std(sensitivities[(K - method.n_dummy_parameters + 1 + 1):end]))

return KSRankResult(sensitivities[1:K-method.n_dummy_parameters], dummy_sensitivities)
return KSRankResult(
sensitivities[1:(K - method.n_dummy_parameters)], dummy_sensitivities)
end

function gsa(f, method::KSRank, p_range; samples, batch = false)
Expand Down
16 changes: 8 additions & 8 deletions test/ks_rank_method.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using GlobalSensitivity, Test, QuasiMonteCarlo


function ishi_batch(X)
A = 7
B = 0.1
Expand All @@ -26,16 +25,17 @@ end
lb = -ones(4) * π
ub = ones(4) * π

res1 = gsa(ishi, KSRank(n_dummy_parameters=50), [[lb[i], ub[i]] for i in 1:4], samples = 100_000)
res2 = gsa(ishi_batch, KSRank(), [[lb[i], ub[i]] for i in 1:4], samples = 100_000, batch = true)
res1 = gsa(
ishi, KSRank(n_dummy_parameters = 50), [[lb[i], ub[i]] for i in 1:4], samples = 100_000)
res2 = gsa(
ishi_batch, KSRank(), [[lb[i], ub[i]] for i in 1:4], samples = 100_000, batch = true)

@test (4*res1.Sd[1] .> res1.S) == [0,0,1,1]
@test (4*res2.Sd[1] .> res2.S) == [0,0,1,1]
@test (4 * res1.Sd[1] .> res1.S) == [0, 0, 1, 1]
@test (4 * res2.Sd[1] .> res2.S) == [0, 0, 1, 1]

res1 = gsa(linear, KSRank(), [[lb[i], ub[i]] for i in 1:4], samples = 100_000)
res2 = gsa(linear_batch, KSRank(), [[lb[i], ub[i]] for i in 1:4], batch = true,
samples = 100_000)

@test (4*res1.Sd[1] .> res1.S) == [0,1,1,1]
@test (4*res2.Sd[1] .> res2.S) == [0,1,1,1]

@test (4 * res1.Sd[1] .> res1.S) == [0, 1, 1, 1]
@test (4 * res2.Sd[1] .> res2.S) == [0, 1, 1, 1]

0 comments on commit b45b446

Please sign in to comment.