Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsmierz committed Nov 12, 2024
1 parent 2ce3b13 commit a601932
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
8 changes: 4 additions & 4 deletions src/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,10 @@ const MatOrCentral{T,N} = Union{AbstractMatrix{T},CentralTensor{T,N}}
function dense_central(ten::CentralTensor)
# @cast V[(u1, u2), (d1, d2)] :=
# ten.e11[u1, d1] * ten.e21[u2, d1] * ten.e12[u1, d2] * ten.e22[u2, d2]
a11 = reshape(CuArray(ten.e11), size(ten.e11, 1), :, size(ten.e11, 2))
a21 = reshape(CuArray(ten.e21), :, size(ten.e21, 1), size(ten.e21, 2))
a12 = reshape(CuArray(ten.e12), size(ten.e12, 1), 1, 1, size(ten.e12, 2))
a22 = reshape(CuArray(ten.e22), 1, size(ten.e22, 1), 1, size(ten.e22, 2))
a11 = reshape(ten.e11, size(ten.e11, 1), :, size(ten.e11, 2))
a21 = reshape(ten.e21, :, size(ten.e21, 1), size(ten.e21, 2))
a12 = reshape(ten.e12, size(ten.e12, 1), 1, 1, size(ten.e12, 2))
a22 = reshape(ten.e22, 1, size(ten.e22, 1), 1, size(ten.e22, 2))
V = @__dot__(a11 * a21 * a12 * a22)
V = reshape(V, size(V, 1) * size(V, 2), size(V, 3) * size(V, 4))
V ./ maximum(V)
Expand Down
3 changes: 2 additions & 1 deletion src/contractions/central.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,9 @@ function batched_mul!(
LE::Tensor{R,3},
M::AbstractArray{R,2},
) where {R<:Real}
onGPU = typeof(newLE) <: CuArray ? true : false
N1, N2 = size(M)
new_M = CUDA.CuArray(M) # TODO: this is a hack to solve problem with types;
new_M = ArrayorCuArray(M, onGPU) # TODO: this is a hack to solve problem with types;
new_M = reshape(new_M, (N1, N2, 1))
NNlib.batched_mul!(newLE, LE, new_M)
end
Expand Down

0 comments on commit a601932

Please sign in to comment.