Skip to content

Commit

Permalink
fallback to vanilla rejection method
Browse files Browse the repository at this point in the history
  • Loading branch information
foldfelis committed Jul 29, 2021
1 parent 8d937b3 commit 86c988e
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 95 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "QuantumStateBase"
uuid = "73ce9c4f-35d1-4161-b9e6-26915895bfed"
authors = ["JingYu Ning <[email protected]> and contributors"]
version = "1.0.0"
version = "1.0.1"

[deps]
ClassicalOrthogonalPolynomials = "b30e2e7b-c4ee-47da-9d5f-2c5c27239acd"
Expand All @@ -12,8 +12,8 @@ Mmap = "a63ad114-7e13-5084-954f-fe012c677804"

[compat]
ClassicalOrthogonalPolynomials = "0.4"
KernelDensity = "0.6"
DataDeps = "0.7"
KernelDensity = "0.6"
julia = "1.6"

[extras]
Expand Down
2 changes: 1 addition & 1 deletion docs/src/get_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ julia> heatmap(θs, xs, ps')
## Quantum state sampler

Here, we can sample points from quadrature probability density function of the quantum state.
The sampler is implemented by special adaptive rejection method.
The sampler is implemented by rejection method.

```julia
julia> points = rand(state, 4096);
Expand Down
90 changes: 19 additions & 71 deletions src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,101 +64,49 @@ function ranged_rand(range::Tuple{T, T}) where {T <: Number}
return range[1] + (range[2]-range[1]) * rand(T)
end

function reject!(new_points, θ_range, gen_point, p, g, c)
Threads.@threads for i in 1:size(new_points, 2)
new_points[:, i] .= gen_point()
while (
!(θ_range[1] new_points[1, i] θ_range[2]) ||
p(Threads.threadid(), new_points[:, i]...) / g(new_points[:, i]...) < c
)
new_points[:, i] .= gen_point()
end
end

return new_points
end

"""
state_sampler(
state::AbstractState, n::Integer;
warm_up_n=128, batch_size=64, c=0.9, θ_range=(0., 2π), x_range=(-10., 10.),
show_log=false
c=1, θ_range=(0., 2π), x_range=(-10., 10.),
)
Random points sampled from quadrature probability density function of Gaussian `state`.
* `state`: Quantum state.
* `n`: N points.
* `warm_up_n`: N points sampled from uniform random and accepted by rejection method.
* `batch_size`: Adapt `g` for every `batch_size` points.
* `θ_range`: Sampling range of θ.
* `x_range`: Sampling range of x.
"""
function state_sampler(
state::AbstractState, n::Integer;
warm_up_n=128, batch_size=64, c=0.9, θ_range=(0., 2π), x_range=(-10., 10.),
show_log=false
c=1, θ_range=(0., 2π), x_range=(-10., 10.),
)
sampled_points = Matrix{Float64}(undef, 2, n)
𝛑̂_res_vec = [Matrix{complex(Float64)}(undef, state.dim, state.dim) for _ in 1:Threads.nthreads()]
𝛑̂_res_vec = [
Matrix{complex(Float64)}(undef, state.dim, state.dim)
for _ in 1:Threads.nthreads()
]

return state_sampler!(
sampled_points, 𝛑̂_res_vec,
state,
warm_up_n, batch_size, c, θ_range, x_range,
show_log
)
return state_sampler!(sampled_points, 𝛑̂_res_vec, state, c, θ_range, x_range)
end

# rejection method: \frac{f(x)}{c * g(x)} ≥ u
# here, say g(x) is a uniform distribution, and c = 1
function state_sampler!(
sampled_points::AbstractMatrix{T}, 𝛑̂_res_vec::Vector{Matrix{Complex{T}}},
state::StateMatrix,
warm_up_n::Integer, batch_size::Integer, c::Real, θ_range, x_range,
show_log::Bool
state::StateMatrix, c::Real, θ_range, x_range
) where {T}
n = size(sampled_points, 2)
p = (thread_id, θ, x) -> q_pdf!(𝛑̂_res_vec[thread_id], state, θ, x)

show_log && @info "Warm up"
warm_up_n = n < warm_up_n ? n : warm_up_n
warm_up_points = view(sampled_points, :, 1:warm_up_n)
gen_rand_point = () -> [ranged_rand(θ_range), ranged_rand(x_range)]
kde_result = kde((ranged_rand(n, θ_range), ranged_rand(n, x_range)))
g = (θ, x) -> pdf(kde_result, θ, x)
reject!(warm_up_points, θ_range, gen_rand_point, p, g, c)

show_log && @info "Start to generate data"
batch = div(n-warm_up_n, batch_size)
for b in 1:batch
ref_range = 1:(warm_up_n+(b-1)*batch_size)
ref_points = view(sampled_points, :, ref_range)
new_range = (warm_up_n+(b-1)*batch_size+1):(warm_up_n+b*batch_size)
new_points = view(sampled_points, :, new_range)

h = KernelDensity.default_bandwidth((ref_points[1, :], ref_points[2, :]))
gen_point_from_g = () -> ref_points[:, rand(ref_range)] + randn(2)./h
kde_result = kde((ref_points[1, :], ref_points[2, :]), bandwidth=h)
g = (θ, x) -> pdf(kde_result, θ, x)
reject!(new_points, θ_range, gen_point_from_g, p, g, c)

show_log && @info "progress: $b/$(batch+1)"
end
rem_n = rem(n-warm_up_n, batch_size)
if rem_n > 0
ref_range = 1:(n-rem_n)
ref_points = view(sampled_points, :, ref_range)
new_range = (n-rem_n+1):n
new_points = view(sampled_points, :, new_range)

h = KernelDensity.default_bandwidth((ref_points[1, :], ref_points[2, :]))
gen_point_from_g = () -> ref_points[:, rand(ref_range)] + randn(2)./h
kde_result = kde((ref_points[1, :], ref_points[2, :]), bandwidth=h)
g = (θ, x) -> pdf(kde_result, θ, x)
reject!(new_points, θ_range, gen_point_from_g, p, g, c)
end
show_log && @info "progress: $(batch+1)/$(batch+1)"
sampled_points[1, :] .= sort!(ranged_rand(n, θ_range))

Threads.@threads for i in 1:n
p = (thread_id, x) -> q_pdf!(𝛑̂_res_vec[thread_id], state, sampled_points[1, i], x)

sampled_points .= sampled_points[:, sortperm(sampled_points[1, :])]
sampled_points[2, i] = ranged_rand(x_range)
while (p(Threads.threadid(), sampled_points[2, i]) < c*rand())
sampled_points[2, i] = ranged_rand(x_range)
end
end

return sampled_points
end
Expand Down
26 changes: 5 additions & 21 deletions test/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,20 +48,10 @@ end
single_point_pdf = (θ, x) -> q_pdf(state, θ, x)
@test single_point_pdf.(θs, xs') ground_truth_pdf

n = 4096
data = state_sampler(state, n)
n = 10000
@time data = state_sampler(state, n)
sampled_pdf = pdf(kde((data[1, :], data[2, :])), θs, xs)
@test sum(abs.(sampled_pdf .- ground_truth_pdf)) / n < 5e-2

n = 4096-1
data = state_sampler(state, n)
sampled_pdf = pdf(kde((data[1, :], data[2, :])), θs, xs)
@test sum(abs.(sampled_pdf .- ground_truth_pdf)) / n < 5e-2

n = 4100
data = state_sampler(state, n, warm_up_n=100, batch_size=97)
sampled_pdf = pdf(kde((data[1, :], data[2, :])), θs, xs)
@test sum(abs.(sampled_pdf .- ground_truth_pdf)) / n < 5e-2
@test sum(abs.(sampled_pdf .- ground_truth_pdf)) / n < 5e-4
end

@testset "wrapping" begin
Expand All @@ -76,14 +66,8 @@ end

state = SinglePhotonState()
@test size(rand(state)) == (2, 1)
@test size(rand(state, 4096, show_log=false)) == (2, 4096)
@test size(rand(state, 4100, warm_up_n=97, show_log=false)) == (2, 4100)
@test size(rand(state, 4100, warm_up_n=97, batch_size=100, show_log=false)) == (2, 4100)
@test size(rand(state, 100, warm_up_n=200, batch_size=100, show_log=false)) == (2, 100)
@test size(rand(state, 4096)) == (2, 4096)
state = SinglePhotonState(rep=StateMatrix)
@test size(rand(state)) == (2, 1)
@test size(rand(state, 4096, show_log=false)) == (2, 4096)
@test size(rand(state, 4100, warm_up_n=97, show_log=false)) == (2, 4100)
@test size(rand(state, 4100, warm_up_n=97, batch_size=100, show_log=false)) == (2, 4100)
@test size(rand(state, 100, warm_up_n=200, batch_size=100, show_log=false)) == (2, 100)
@test size(rand(state, 4096)) == (2, 4096)
end

2 comments on commit 86c988e

@foldfelis
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/41743

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.0.1 -m "<description of version>" 86c988e87ad30e37f62a535301f72955e1185d78
git push origin v1.0.1

Please sign in to comment.