diff --git a/src/mps.jl b/src/mps.jl index 567319d..32a16f2 100644 --- a/src/mps.jl +++ b/src/mps.jl @@ -655,6 +655,8 @@ function sample(rng::AbstractRNG, m::MPS) error("sample: MPS is not normalized, norm=$(norm(m[1]))") end + ElT = promote_itensor_eltype(m) + result = zeros(Int, N) A = m[1] @@ -664,16 +666,16 @@ function sample(rng::AbstractRNG, m::MPS) # Compute the probability of each state # one-by-one and stop when the random # number r is below the total prob so far - pdisc = 0.0 + pdisc = zero(real(ElT)) r = rand(rng) # Will need n,An, and pn below n = 1 An = ITensor() - pn = 0.0 + pn = zero(real(ElT)) while n <= d projn = ITensor(s) - projn[s => n] = 1.0 - An = A * dag(projn) + projn[s => n] = one(ElT) + An = A * dag(adapt(datatype(A), projn)) pn = real(scalar(dag(An) * An)) pdisc += pn (r < pdisc) && break @@ -682,7 +684,7 @@ function sample(rng::AbstractRNG, m::MPS) result[j] = n if j < N A = m[j + 1] * An - A *= (1.0 / sqrt(pn)) + A *= (one(ElT) / sqrt(pn)) end end return result