From a869b5d1daa3b7f398b2de79d1a20d9c514a2875 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 8 May 2024 13:01:09 -0400 Subject: [PATCH] Try fixing tests --- src/ITensorGPU.jl | 21 ++++++++++++--------- test/test_cudense.jl | 2 +- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/src/ITensorGPU.jl b/src/ITensorGPU.jl index 5c693a9..d9063e5 100644 --- a/src/ITensorGPU.jl +++ b/src/ITensorGPU.jl @@ -1,14 +1,16 @@ module ITensorGPU -using CUDA: CUDA -using ITensors: cpu, cu -export cpu, cu +using Adapt: adapt +using CUDA: CUDA, cu +export cu +using ITensors: cpu +export cpu using ITensors: ITensor, cpu, cu, randomITensor function cuITensor(args...; kwargs...) - return cu(ITensor(args...; kwargs...)) + return adapt(CuArray, ITensor(args...; kwargs...)) end function randomCuITensor(args...; kwargs...) - return cu(randomITensor(args...; kwargs...)) + return adapt(CuArray, randomITensor(args...; kwargs...)) end export cuITensor, randomCuITensor @@ -16,16 +18,17 @@ export cuITensor, randomCuITensor # once it is registered. using ITensors.ITensorMPS: MPO, MPS, randomMPS function cuMPS(args...; kwargs...) - return cu(MPS(args...; kwargs...)) + return adapt(CuArray, MPS(args...; kwargs...)) end function productCuMPS(args...; kwargs...) - return cu(MPS(args...; kwargs...)) + return adapt(CuArray, MPS(args...; kwargs...)) end function randomCuMPS(args...; kwargs...) - return cu(randomMPS(args...; kwargs...)) + return adapt(CuArray, randomMPS(args...; kwargs...)) end function cuMPO(args...; kwargs...) - return cu(MPO(args...; kwargs...)) + return adapt(CuArray, MPO(args...; kwargs...)) end +cuMPO(tn::MPO) = cu(tn) export cuMPO, cuMPS, productCuMPS, randomCuMPO, randomCuMPS end diff --git a/test/test_cudense.jl b/test/test_cudense.jl index 5e88ab3..af6713c 100644 --- a/test/test_cudense.jl +++ b/test/test_cudense.jl @@ -1,4 +1,4 @@ -using CUDA: CUDA, CuArray, CuVector +using CUDA using Combinatorics: permutations using ITensors using ITensorGPU