From f9fe841d5119cd46ad0c6392f15d5cb28529cc6b Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 17 May 2024 09:46:20 -0400 Subject: [PATCH] Add missing constructor, fix typo --- src/specialitensornetworks.jl | 18 +++++++++++++++--- test/test_itensornetwork.jl | 2 +- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/src/specialitensornetworks.jl b/src/specialitensornetworks.jl index 92cd434b..f0c4f26f 100644 --- a/src/specialitensornetworks.jl +++ b/src/specialitensornetworks.jl @@ -70,14 +70,26 @@ Build an ITensor network on a graph specified by the inds network s. Bond_dim is given by link_space and entries are randomized. The random distribution is based on the input argument `distribution`. """ -function random_tensornetwork(distribution::Distribution, s::IndsNetwork; kwargs...) +function random_tensornetwork( + rng::AbstractRNG, distribution::Distribution, s::IndsNetwork; kwargs... +) return ITensorNetwork(s; kwargs...) do v - return inds -> itensor(rand(distribution, dim.(inds)...), inds) + return inds -> itensor(rand(rng, distribution, dim.(inds)...), inds) end end +function random_tensornetwork(distribution::Distribution, s::IndsNetwork; kwargs...) + return random_tensornetwork(Random.default_rng(), distribution, s; kwargs...) +end + +@traitfn function random_tensornetwork( + rng::AbstractRNG, distribution::Distribution, g::::IsUnderlyingGraph; kwargs... +) + return random_tensornetwork(rng, distribution, IndsNetwork(g); kwargs...) +end + @traitfn function random_tensornetwork( distribution::Distribution, g::::IsUnderlyingGraph; kwargs... ) - return random_tensornetwork(distribution, IndsNetwork(g); kwargs...) + return random_tensornetwork(Random.default_rng(), distribution, g; kwargs...) end diff --git a/test/test_itensornetwork.jl b/test/test_itensornetwork.jl index 6fda876d..7e97c6a3 100644 --- a/test/test_itensornetwork.jl +++ b/test/test_itensornetwork.jl @@ -349,7 +349,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) dims = (2, 2) g = named_grid(dims) s = siteinds("S=1/2", g) - rng = StableRN(1234) + rng = StableRNG(1234) ψ = random_tensornetwork(rng, s; link_space=2) @test scalartype(ψ) == Float64 ϕ = NDTensors.convert_scalartype(new_eltype, ψ)