Skip to content

Commit

Permalink
Add missing constructor, fix typo
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed May 17, 2024
1 parent 3406bc5 commit f9fe841
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
18 changes: 15 additions & 3 deletions src/specialitensornetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,26 @@ Build an ITensor network on a graph specified by the inds network s.
Bond_dim is given by link_space and entries are randomized.
The random distribution is based on the input argument `distribution`.
"""
function random_tensornetwork(distribution::Distribution, s::IndsNetwork; kwargs...)
function random_tensornetwork(
rng::AbstractRNG, distribution::Distribution, s::IndsNetwork; kwargs...
)
return ITensorNetwork(s; kwargs...) do v
return inds -> itensor(rand(distribution, dim.(inds)...), inds)
return inds -> itensor(rand(rng, distribution, dim.(inds)...), inds)
end
end

function random_tensornetwork(distribution::Distribution, s::IndsNetwork; kwargs...)
return random_tensornetwork(Random.default_rng(), distribution, s; kwargs...)
end

@traitfn function random_tensornetwork(
rng::AbstractRNG, distribution::Distribution, g::::IsUnderlyingGraph; kwargs...
)
return random_tensornetwork(rng, distribution, IndsNetwork(g); kwargs...)
end

@traitfn function random_tensornetwork(
distribution::Distribution, g::::IsUnderlyingGraph; kwargs...
)
return random_tensornetwork(distribution, IndsNetwork(g); kwargs...)
return random_tensornetwork(Random.default_rng(), distribution, g; kwargs...)
end
2 changes: 1 addition & 1 deletion test/test_itensornetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
dims = (2, 2)
g = named_grid(dims)
s = siteinds("S=1/2", g)
rng = StableRN(1234)
rng = StableRNG(1234)
ψ = random_tensornetwork(rng, s; link_space=2)
@test scalartype(ψ) == Float64
ϕ = NDTensors.convert_scalartype(new_eltype, ψ)
Expand Down

0 comments on commit f9fe841

Please sign in to comment.