Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Output truncation error from truncate! #99

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
23 changes: 19 additions & 4 deletions src/abstractmps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1672,16 +1672,30 @@ provided as keyword arguments.

Keyword arguments:
* `site_range`=1:N - only truncate the MPS bonds between these sites
* `(callback!)=Returns(nothing)` - callback function that allows the user to save the per-bond truncation error. The API of `callback!` expects to take two kwargs called `link` and `truncation_error` where `link` is of type `Pair{Int64, Int64}` and `truncation_error` is `Float64`. Consider the following example that illustrates one possible use case.

```julia
nbonds = 9
truncation_errors = zeros(nbonds)
function callback!(; link, truncation_error)
bond_no = last(link)
truncation_errors[bond_no] = truncation_error
return nothing
end
truncate!(ψ; maxdim=5, cutoff=1E-7, callback!)
```
"""
function truncate!(M::AbstractMPS; alg="frobenius", kwargs...)
return truncate!(Algorithm(alg), M; kwargs...)
end

function truncate!(
::Algorithm"frobenius", M::AbstractMPS; site_range=1:length(M), kwargs...
::Algorithm"frobenius",
M::AbstractMPS;
site_range=1:length(M),
(callback!)=Returns(nothing),
kwargs...,
)
N = length(M)

# Left-orthogonalize all tensors to make
# truncations controlled
orthogonalize!(M, last(site_range))
Expand All @@ -1690,10 +1704,11 @@ function truncate!(
for j in reverse((first(site_range) + 1):last(site_range))
rinds = uniqueinds(M[j], M[j - 1])
ltags = tags(commonind(M[j], M[j - 1]))
U, S, V = svd(M[j], rinds; lefttags=ltags, kwargs...)
U, S, V, spec = svd(M[j], rinds; lefttags=ltags, kwargs...)
M[j] = U
M[j - 1] *= (S * V)
setrightlim!(M, j)
callback!(; link=(j => j - 1), truncation_error=spec.truncerr)
end
return M
end
Expand Down
15 changes: 15 additions & 0 deletions test/base/test_mps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,21 @@ end
truncate!(M; site_range=3:7, maxdim=2)
@test linkdims(M) == [2, 4, 2, 2, 2, 2, 8, 4, 2]
end

@testset "truncate! with callback!" begin
nsites = 10
nbonds = nsites - 1
s = siteinds("S=1/2", nsites)
mps_ = random_mps(s; linkdims=10)
truncation_errors = ones(nbonds) * -1.0
function _callback!(; link, truncation_error)
bond_no = last(link)
truncation_errors[bond_no] = truncation_error
return nothing
end
truncate!(mps_; maxdim=3, cutoff=1E-3, (callback!)=_callback!)
@test all(truncation_errors .>= 0.0)
end
end

@testset "Other MPS methods" begin
Expand Down
Loading