From 0bd7f55f2c3a2d757b10f63551cde4d738e74b8f Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Sat, 30 Dec 2023 23:20:39 -0500 Subject: [PATCH 1/4] Swap PyCall to PythonCall The former was giving us setup-related headaches on CI, and we can use any Python FFI for the one test which needs it. --- Project.toml | 4 ++-- test/features.jl | 10 +++------- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/Project.toml b/Project.toml index b9a7f703b..d2357cb26 100644 --- a/Project.toml +++ b/Project.toml @@ -68,9 +68,9 @@ Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" -PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" +PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [targets] -test = ["ChainRulesTestUtils", "Conda", "CUDA", "Distances", "FFTW", "FiniteDifferences", "PyCall", "Test"] +test = ["ChainRulesTestUtils", "Conda", "CUDA", "Distances", "FFTW", "FiniteDifferences", "PythonCall", "Test"] diff --git a/test/features.jl b/test/features.jl index 908ae5815..84f875c8f 100644 --- a/test/features.jl +++ b/test/features.jl @@ -685,13 +685,9 @@ end end == ([8 112; 36 2004],) end -@testset "PyCall custom @adjoint" begin - # Trigger Python install if required. Required for Buildkite CI! - import Conda - Conda.list() - - import PyCall - math = PyCall.pyimport("math") +@testset "PythonCall custom @adjoint" begin + using PythonCall: pyimport + math = pyimport("math") pysin(x) = math.sin(x) Zygote.@adjoint pysin(x) = math.sin(x), (δ) -> (δ * math.cos(x), ) @test Zygote.gradient(pysin, 1.5) == Zygote.gradient(sin, 1.5) From cbcb7a9f8b4cf7c5a9b388da92847edeb03e741e Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Sat, 30 Dec 2023 23:20:59 -0500 Subject: [PATCH 2/4] Don't load CUDA.jl on GHA --- test/runtests.jl | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 672960944..9fea7b2d8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,17 +1,17 @@ using Zygote, Test, LinearAlgebra using Zygote: gradient, ZygoteRuleConfig -using CUDA -using CUDA: has_cuda @testset "all" begin # Overall testset ensures it keeps running after failure - - if has_cuda() - @testset "CUDA tests" begin - include("cuda.jl") + if !haskey(ENV, "GITHUB_ACTION") + using CUDA + if CUDA.has_cuda() + @testset "CUDA tests" begin + include("cuda.jl") + end + @info "CUDA tests have run" + else + @warn "CUDA not found - Skipping CUDA Tests" end - @info "CUDA tests have run" - else - @warn "CUDA not found - Skipping CUDA Tests" end @testset "deprecated.jl" begin From 7252a93de3aa1b88777996fcb3ada5f540b3641e Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Sat, 30 Dec 2023 21:24:22 -0800 Subject: [PATCH 3/4] add needed type conversion This was implicit for PyCall but needs to be explicit for PythonCall. --- test/features.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/features.jl b/test/features.jl index 84f875c8f..3c6766e44 100644 --- a/test/features.jl +++ b/test/features.jl @@ -686,10 +686,10 @@ end end @testset "PythonCall custom @adjoint" begin - using PythonCall: pyimport + using PythonCall: pyimport, pyconvert math = pyimport("math") pysin(x) = math.sin(x) - Zygote.@adjoint pysin(x) = math.sin(x), (δ) -> (δ * math.cos(x), ) + Zygote.@adjoint pysin(x) = pyconvert(Float64, math.sin(x)), δ -> (δ * math.cos(x),) @test Zygote.gradient(pysin, 1.5) == Zygote.gradient(sin, 1.5) end From 0f5d958149b94daa687ac59c86b0091d29024144 Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Sat, 30 Dec 2023 22:05:11 -0800 Subject: [PATCH 4/4] One more `pyconvert` --- test/features.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/features.jl b/test/features.jl index 3c6766e44..c8640b098 100644 --- a/test/features.jl +++ b/test/features.jl @@ -689,7 +689,7 @@ end using PythonCall: pyimport, pyconvert math = pyimport("math") pysin(x) = math.sin(x) - Zygote.@adjoint pysin(x) = pyconvert(Float64, math.sin(x)), δ -> (δ * math.cos(x),) + Zygote.@adjoint pysin(x) = pyconvert(Float64, math.sin(x)), δ -> (pyconvert(Float64, δ * math.cos(x)),) @test Zygote.gradient(pysin, 1.5) == Zygote.gradient(sin, 1.5) end