diff --git a/src/solvers/common.jl b/src/solvers/common.jl index e3480bf..b4d6eac 100644 --- a/src/solvers/common.jl +++ b/src/solvers/common.jl @@ -1,4 +1,5 @@ # Define common functions for all solvers +# fixme: Dispatch different functions for Flaot64 and Float32 to improve performance, maybe exists a better way for Float32 # Copy matrices from host to device (Float64) function copy_to_device!(du::PtrArray{Float64}, u::PtrArray{Float64}) @@ -9,9 +10,10 @@ function copy_to_device!(du::PtrArray{Float64}, u::PtrArray{Float64}) end # Copy matrices from device to host (Float64) -function copy_to_host!(du::PtrArray{Float64}, u::PtrArray{Float64}) - du = Array(du) - u = Array(u) +function copy_to_host!(du::CuArray{Float64}, u::CuArray{Float64}) + # fixme: maybe direct PtrArray to CuArray conversion is possible + du = PtrArray(Array(du)) + u = PtrArray(Array(u)) return (du, u) end diff --git a/src/solvers/solvers.jl b/src/solvers/solvers.jl new file mode 100644 index 0000000..f7ad372 --- /dev/null +++ b/src/solvers/solvers.jl @@ -0,0 +1 @@ +include("common.jl")