diff --git a/src/base.jl b/src/base.jl
index 27c2629..1cd402c 100644
--- a/src/base.jl
+++ b/src/base.jl
@@ -68,10 +68,10 @@ const MatOrCentral{T,N} = Union{AbstractMatrix{T},CentralTensor{T,N}}
 function dense_central(ten::CentralTensor)
     # @cast V[(u1, u2), (d1, d2)] :=
     #     ten.e11[u1, d1] * ten.e21[u2, d1] * ten.e12[u1, d2] * ten.e22[u2, d2]
-    a11 = reshape(CuArray(ten.e11), size(ten.e11, 1), :, size(ten.e11, 2))
-    a21 = reshape(CuArray(ten.e21), :, size(ten.e21, 1), size(ten.e21, 2))
-    a12 = reshape(CuArray(ten.e12), size(ten.e12, 1), 1, 1, size(ten.e12, 2))
-    a22 = reshape(CuArray(ten.e22), 1, size(ten.e22, 1), 1, size(ten.e22, 2))
+    a11 = reshape(ten.e11, size(ten.e11, 1), :, size(ten.e11, 2))
+    a21 = reshape(ten.e21, :, size(ten.e21, 1), size(ten.e21, 2))
+    a12 = reshape(ten.e12, size(ten.e12, 1), 1, 1, size(ten.e12, 2))
+    a22 = reshape(ten.e22, 1, size(ten.e22, 1), 1, size(ten.e22, 2))
     V = @__dot__(a11 * a21 * a12 * a22)
     V = reshape(V, size(V, 1) * size(V, 2), size(V, 3) * size(V, 4))
     V ./ maximum(V)
diff --git a/src/contractions/central.jl b/src/contractions/central.jl
index a8e87df..a3ed82d 100644
--- a/src/contractions/central.jl
+++ b/src/contractions/central.jl
@@ -69,8 +69,9 @@ function batched_mul!(
     LE::Tensor{R,3},
     M::AbstractArray{R,2},
 ) where {R<:Real}
+    onGPU = typeof(newLE) <: CuArray ? true : false
     N1, N2 = size(M)
-    new_M = CUDA.CuArray(M) # TODO: this is a hack to solve problem with types;
+    new_M = ArrayorCuArray(M, onGPU) # TODO: this is a hack to solve problem with types;
     new_M = reshape(new_M, (N1, N2, 1))
     NNlib.batched_mul!(newLE, LE, new_M)
 end
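
Note: the second hunk calls ArrayorCuArray(M, onGPU), a helper whose definition is not part of this diff. The sketch below is only an assumption about its intended behaviour (move M to the device on which newLE already lives), kept here to make the change readable on its own; the actual definition in the package may differ.

# Hypothetical sketch, not the package's definition: convert M to a CuArray
# when the destination lives on the GPU, otherwise keep it as a plain Array.
using CUDA

ArrayorCuArray(M::AbstractArray, onGPU::Bool) = onGPU ? CuArray(M) : Array(M)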