Skip to content

Commit

Permalink
add onGPU
Browse files Browse the repository at this point in the history
  • Loading branch information
annamariadziubyna committed Oct 8, 2024
1 parent 35a6b77 commit c481658
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions test/bp_2site.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,10 @@ end
E = get_prop(new_potts_h1, src(e), dst(e), :en)
# @cast E[(l1, l2), (r1, r2)] :=
# E.e11[l1, r1] + E.e21[l2, r1] + E.e12[l1, r2] + E.e22[l2, r2]
a11 = reshape(CuArray(E.e11), size(E.e11, 1), :, size(E.e11, 2))
a21 = reshape(CuArray(E.e21), :, size(E.e21, 1), size(E.e21, 2))
a12 = reshape(CuArray(E.e12), size(E.e12, 1), 1, 1, size(E.e12, 2))
a22 = reshape(CuArray(E.e22), 1, size(E.e22, 1), 1, size(E.e22, 2))
a11 = reshape(onGPU ? CuArray(E.e11) : E.e11, size(E.e11, 1), :, size(E.e11, 2))
a21 = reshape(onGPU ? CuArray(E.e21) : E.e21, :, size(E.e21, 1), size(E.e21, 2))
a12 = reshape(onGPU ? CuArray(E.e12) : E.e12, size(E.e12, 1), 1, 1, size(E.e12, 2))
a22 = reshape(onGPU ? CuArray(E.e22) : E.e22, 1, size(E.e22, 1), 1, size(E.e22, 2))
E = @__dot__(a11 + a21 + a12 + a22)
E = reshape(E, size(E, 1) * size(E, 2), size(E, 3) * size(E, 4))
@test Array(E) == get_prop(potts_h1, src(e), dst(e), :en)
Expand Down

0 comments on commit c481658

Please sign in to comment.