From c68fd0f75def6c3ab64b00a5fb66d2e380c4a9d4 Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Fri, 5 Jul 2024 01:24:24 +0530 Subject: [PATCH] Update transform.jl --- src/GNNGraphs/transform.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/GNNGraphs/transform.jl b/src/GNNGraphs/transform.jl index 0b66a29d3..6c26961a9 100644 --- a/src/GNNGraphs/transform.jl +++ b/src/GNNGraphs/transform.jl @@ -458,8 +458,8 @@ function perturb_edges(g::GNNGraph{<:COO_T}, perturb_ratio; seed::Int = Random.d num_nodes = g.num_nodes @assert num_nodes > 1 "Graph must contain at least 2 nodes to add edges" - snew = ceil.(Int, rand(Float32, num_edges_to_add) .* num_nodes) - tnew = ceil.(Int, rand(Float32, num_edges_to_add) .* num_nodes) + snew = ceil.(eltype(s), rand_like(rng, s, Float32, n) .* num_nodes) + tnew = ceil.(eltype(s), rand_like(rng, s, Float32, n) .* num_nodes) mask_loops = snew .!= tnew snew = snew[mask_loops]