Skip to content

Commit

Permalink
Made sample type agnostic and CUDA compatible
Browse files Browse the repository at this point in the history
  • Loading branch information
tipfom authored Nov 13, 2024
1 parent 8896bb1 commit 58cfe83
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions src/mps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,8 @@ function sample(rng::AbstractRNG, m::MPS)
error("sample: MPS is not normalized, norm=$(norm(m[1]))")
end

ElT = promote_itensor_eltype(m)

result = zeros(Int, N)
A = m[1]

Expand All @@ -664,16 +666,16 @@ function sample(rng::AbstractRNG, m::MPS)
# Compute the probability of each state
# one-by-one and stop when the random
# number r is below the total prob so far
pdisc = 0.0
pdisc = zero(real(ElT))
r = rand(rng)
# Will need n,An, and pn below
n = 1
An = ITensor()
pn = 0.0
pn = zero(real(ElT))
while n <= d
projn = ITensor(s)
projn[s => n] = 1.0
An = A * dag(projn)
projn[s => n] = one(ElT)
An = A * dag(adapt(datatype(A), projn))
pn = real(scalar(dag(An) * An))
pdisc += pn
(r < pdisc) && break
Expand All @@ -682,7 +684,7 @@ function sample(rng::AbstractRNG, m::MPS)
result[j] = n
if j < N
A = m[j + 1] * An
A *= (1.0 / sqrt(pn))
A *= (one(ElT) / sqrt(pn))
end
end
return result
Expand Down

0 comments on commit 58cfe83

Please sign in to comment.