Skip to content

Commit

Permalink
broken
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Nov 30, 2024
1 parent 503b800 commit 4b7892e
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions GNNlib/test/msgpass.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,9 @@ end
dev = gpu_device(force=true)
broken = get_graph_type(g) == :sparse && dev isa AMDGPUDevice
f(g, x) = propagate(copy_xj, g, +, xj = x)
test_gradients(f, g, g.x; test_gpu=true, test_grad_f=false, compare_finite_diff=false)
@test test_gradients(
f, g, g.x; test_gpu=true, test_grad_f=false, compare_finite_diff=false
) broken=broken
end
end

Expand All @@ -180,7 +182,9 @@ end
dev = gpu_device(force=true)
broken = get_graph_type(g) == :sparse && dev isa AMDGPUDevice
f(g, x) = propagate(copy_xj, g, mean, xj = x)
test_gradients(f, g, g.x; test_gpu=true, test_grad_f=false, compare_finite_diff=false)
@test test_gradients(
f, g, g.x; test_gpu=true, test_grad_f=false, compare_finite_diff=false
) broken=broken
end
end

Expand All @@ -190,7 +194,9 @@ end
broken = get_graph_type(g) == :sparse && dev isa AMDGPUDevice
e = rand(Float32, size(g.x, 1), g.num_edges)
f(g, x, e) = propagate(e_mul_xj, g, +; xj = x, e)
test_gradients(f, g, g.x, e; test_gpu=true, test_grad_f=false, compare_finite_diff=false)
@test test_gradients(
f, g, g.x, e; test_gpu=true, test_grad_f=false, compare_finite_diff=false
) broken=broken
end
end

Expand Down

0 comments on commit 4b7892e

Please sign in to comment.