From c481658fcf467799684cc436205b67deeb52e647 Mon Sep 17 00:00:00 2001 From: annamariadziubyna Date: Tue, 8 Oct 2024 09:36:48 +0200 Subject: [PATCH] add onGPU --- test/bp_2site.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/bp_2site.jl b/test/bp_2site.jl index d0111b1..4cb8786 100644 --- a/test/bp_2site.jl +++ b/test/bp_2site.jl @@ -170,10 +170,10 @@ end E = get_prop(new_potts_h1, src(e), dst(e), :en) # @cast E[(l1, l2), (r1, r2)] := # E.e11[l1, r1] + E.e21[l2, r1] + E.e12[l1, r2] + E.e22[l2, r2] - a11 = reshape(CuArray(E.e11), size(E.e11, 1), :, size(E.e11, 2)) - a21 = reshape(CuArray(E.e21), :, size(E.e21, 1), size(E.e21, 2)) - a12 = reshape(CuArray(E.e12), size(E.e12, 1), 1, 1, size(E.e12, 2)) - a22 = reshape(CuArray(E.e22), 1, size(E.e22, 1), 1, size(E.e22, 2)) + a11 = reshape(onGPU ? CuArray(E.e11) : E.e11, size(E.e11, 1), :, size(E.e11, 2)) + a21 = reshape(onGPU ? CuArray(E.e21) : E.e21, :, size(E.e21, 1), size(E.e21, 2)) + a12 = reshape(onGPU ? CuArray(E.e12) : E.e12, size(E.e12, 1), 1, 1, size(E.e12, 2)) + a22 = reshape(onGPU ? CuArray(E.e22) : E.e22, 1, size(E.e22, 1), 1, size(E.e22, 2)) E = @__dot__(a11 + a21 + a12 + a22) E = reshape(E, size(E, 1) * size(E, 2), size(E, 3) * size(E, 4)) @test Array(E) == get_prop(potts_h1, src(e), dst(e), :en)