Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Output truncation error from truncate! #99

Open
wants to merge 9 commits into
base: main
Choose a base branch
from

Conversation

NuclearPowerNerd
Copy link

The purpose of this commit is to allow the user to access the truncation error that is internally calculated in a call to truncate!. This PR implements #96.

I have implemented this by allowing the user to pass a Ref object to the call to truncate!. The result of each SVD performed during a call to truncate! is then accumulated in that Ref.

I also added a new test for this functionality. I then re-ran all tests. There were 75165 passing and 33 broken for a total of 75198 tests. I also updated the docstring for truncate! to reflect this new keyword argument.

Here is an example of the new functionality.

A few differences with what was suggested in #96:

  • I accumulate the error rather than storing the truncation error for each bond individually.
  • I did not use the ! in my keyword argument. Was that just for stylistic reasons or is there some other convention for doing that (I'm aware of the convention in function names but not in variable names)?
truncation_error = Ref{Float64}()
truncation_error[] = 0.0
truncate!(someMPS, maxdim=4, cutoff=1E-5, truncation_error=truncation_error)
truncation_error[]

@NuclearPowerNerd
Copy link
Author

One thing I noticed while playing around with this new feature is that sometimes the truncation error exceeds the cutoff. But I thought it should be that truncation_error[] <= cutoff.

The following MWE illustrates this. I'd be curious what you think. If you want I can open a separate issue about this if/once this has been merged.

d = 2
N = 7
n = d^N
sinds = siteinds(d, N)
x = collect(LinRange(0, 1, n))
y = @. 2 * x + 3 + sin(2π * x)
χ = 8

truncation_error = Ref{Float64}()

cutoff_ = 1E-4
Y = MPS(y, sinds, maxdim=χ, cutoff=0.0)
truncation_error[] = 0.0
truncate!(Y, maxdim=χ, cutoff=cutoff_, truncation_error=truncation_error)

truncation_error[]
truncation_error[] > cutoff_  # returns true!

cutoff_ = 1E-5
Y = MPS(y, sinds, maxdim=χ, cutoff=0.0)
truncation_error[] = 0.0
truncate!(Y, maxdim=χ, cutoff=cutoff_, truncation_error=truncation_error)

truncation_error[]
truncation_error[] > cutoff_ # returns false, as expected

@mtfishman
Copy link
Member

The cutoff refers to the truncation of each SVD performed, if multiple SVDs are performed the total error could add up to a value larger than the cutoff. It could make sense to store a truncation error for each bond of the MPS. That's one reason why I'm a bit hesitant about this PR, since I'd prefer to think about this more generally in terms of what kinds of other information we might want to output and how it should get output.

@NuclearPowerNerd
Copy link
Author

The cutoff refers to the truncation of each SVD performed, if multiple SVDs are performed the total error could add up to a value larger than the cutoff.

Ah, duh. Yeah so I will push an update here shortly and make it store per bond truncation error which is the more sensible thing to do. Thanks for the feedback.

This commit refactors the original implemenation. `truncate!` now
expects the user to pass a pointer to a vector of floats with as
many elements as there are bonds in the MPS. It will then store the
truncation error of each bond in the vector.

The corresponding test was updated. The package was re-tested
with the same results described in the PR (75165 passing and
33 broken for total of 75198 tests).

The docstring of `truncate!` was updated to reflect the new behavior.
@NuclearPowerNerd
Copy link
Author

Any other comments or changes you'd like me to make? Or to discuss? I don't mind if a different system is opted for but I do think it is useful to be able to inspect the truncation error.

NuclearPowerNerd and others added 2 commits January 13, 2025 14:08
Change new kwarg name to `truncation_errors!` from `truncation_errors`
Remove usage of `enumerate`
Update docstring

Co-authored-by: Matt Fishman <[email protected]>
@NuclearPowerNerd
Copy link
Author

I updated the PR. Could you please take a look and let me know what you think?

@NuclearPowerNerd
Copy link
Author

Any update here? If I need to rewrite I don't mind. Just looking for some feedback. Or if there are no plans to expose the truncation information then I can just close this PR. I still think it would be useful but I can also get it via the local branch I've modified for this PR.

@mtfishman
Copy link
Member

mtfishman commented Feb 17, 2025

Sorry for the slow response. I want to make sure we use a design that is as "future proof" as possible, by which I mean it:

  1. uses a code pattern we want to use elsewhere for similar purposes in code throughout the package, so there is some consistency and we don't have to keep reinventing ways of doing the same thing,
  2. is easily extendable, so that if someone asks us to output something new from a function, it is easy to do so,
  3. if the function design changes, what we decided to output previously isn't affected too much, and
  4. it is customizable, so that users can have the flexibility to save data however they want to.

I would say getting all of these properties is almost a "code research" question, from the perspective that we are still evolving in the way we are thinking about how to do this, and also I haven't seen many approaches I am happy with that ticks all of these boxes. So, given that broader context and design considerations, here is my latest proposal for how to design this feature:

function truncate!(
  ::Algorithm"frobenius",
  M::AbstractMPS;
  site_range=1:length(M),
  (callback!)=Returns(nothing),
  kwargs...,
)
  N = length(M)
  # Left-orthogonalize all tensors to make
  # truncations controlled
  orthogonalize!(M, last(site_range))

  # Perform truncations in a right-to-left sweep
  for j in reverse((first(site_range) + 1):last(site_range))
    rinds = uniqueinds(M[j], M[j - 1])
    ltags = tags(commonind(M[j], M[j - 1]))
    U, S, V, spec = svd(M[j], rinds; lefttags=ltags, kwargs...)
    M[j] = U
    M[j - 1] *= (S * V)
    setrightlim!(M, j)
    callback!(; link=(j => j - 1), truncation_error=spec.truncerr)
  end
  return M
end

Then, a user can save the truncation error from each link/bond like this:

using ITensorMPS: maxlinkdim, random_mps, siteinds, truncate!
s = siteinds("S=1/2", 10)
ψ = random_mps(s; linkdims=6)
@show maxlinkdim(ψ)

truncation_errors = Dict{Pair{Int,Int},Float64}()
function callback!(; link, truncation_error, kwargs...)
  truncation_errors[link] = truncation_error
end
ψ′ = truncate!(copy(ψ); maxdim=3, callback!)
@show maxlinkdim(ψ′)

which when run outputs something like this:

julia> truncation_errors
Dict{Pair{Int64, Int64}, Float64} with 9 entries:
  5=>4  => 0.104644
  8=>7  => 0.122511
  2=>1  => 0.0
  9=>8  => 0.153664
  10=>9 => 0.0
  7=>6  => 0.096218
  6=>5  => 0.0617276
  4=>3  => 0.0620809
  3=>2  => 0.0378573

You can see this design would allow for a lot of customizability, since in the future we could decide to pass more data from within truncate! into the callback! function (for example, we could pass the singular values, if someone wanted those), and users have the freedom to save what they want in whatever format they want. Also this same code pattern is easily applicable to other functions in the library where we want to allow users to access data from within the function, or even do other things like print something from within the function.

@NuclearPowerNerd
Copy link
Author

That makes good sense and I agree those points are all important and balancing them is basically a research question as you say. I like this pattern though and I will update the PR later today to reflect this (including updated tests). Thank you!

@NuclearPowerNerd
Copy link
Author

I've updated the PR. Please let me know if you have any comments or questions.

@mtfishman mtfishman changed the title implement truncation_error keyword arg for truncate! Output truncation error from truncate! Feb 21, 2025
Copy link

codecov bot commented Feb 21, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Please upload report for BASE (main@d7f8ca3). Learn more about missing BASE report.

Additional details and impacted files
@@           Coverage Diff           @@
##             main      #99   +/-   ##
=======================================
  Coverage        ?   90.26%           
=======================================
  Files           ?       54           
  Lines           ?     3575           
  Branches        ?        0           
=======================================
  Hits            ?     3227           
  Misses          ?      348           
  Partials        ?        0           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants