Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Distribute tests across workers #1559

Merged
merged 3 commits into from
Mar 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ jobs:
test:
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
runs-on: ${{ matrix.os }}
continue-on-error: ${{ matrix.version == 'nightly' }} # allow nightly to fail
strategy:
fail-fast: false
matrix:
Expand Down Expand Up @@ -41,12 +42,9 @@ jobs:
- uses: julia-actions/julia-buildpkg@v1
env:
JULIA_PKG_SERVER: ""
# `allow-failure` not available yet https://github.com/actions/toolkit/issues/399
#continue-on-error: ${{ matrix.version == 'nightly' }} # comment out to report nightly failures
- uses: julia-actions/julia-runtest@v1
env:
JULIA_PKG_SERVER: ""
#continue-on-error: ${{ matrix.version == 'nightly' }} # comment out to report nightly failures
- uses: julia-actions/julia-processcoverage@v1
if: matrix.version == '1' && matrix.os == 'ubuntu-latest'
- uses: codecov/codecov-action@v3
Expand Down
16 changes: 0 additions & 16 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ ZygoteTrackerExt = "Tracker"
AbstractFFTs = "1.3.1"
ChainRules = "1.72.2"
ChainRulesCore = "1.25.1"
ChainRulesTestUtils = "1"
Colors = "0.12, 0.13"
DiffRules = "1.4"
Distances = "0.10"
Expand All @@ -59,18 +58,3 @@ Statistics = "1"
Tracker = "0.2"
ZygoteRules = "0.2.7"
julia = "1.10"

[extras]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

[targets]
test = ["ChainRulesTestUtils", "Conda", "CUDA", "Distances", "FFTW", "FiniteDifferences", "PythonCall", "Test"]
22 changes: 22 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
[deps]
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
22 changes: 14 additions & 8 deletions test/chainrules.jl → test/chainrules_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
using ChainRulesCore, ChainRulesTestUtils, Zygote
@testitem "chainrules" begin

using ChainRulesCore
using ChainRulesTestUtils
using Zygote: ZygoteRuleConfig
using LinearAlgebra

@testset "ChainRules integration" begin
@testset "ChainRules basics" begin
Expand Down Expand Up @@ -64,7 +68,7 @@ using Zygote: ZygoteRuleConfig
end
return simo(x), simo_pullback
end

simo_outer(x) = sum(simo(x))

simo_rrule_hitcount[] = 0
Expand All @@ -86,7 +90,7 @@ using Zygote: ZygoteRuleConfig
end
return miso(a, b), miso_pullback
end


miso_outer(x) = miso(100x, 10x)

Expand Down Expand Up @@ -182,7 +186,7 @@ using Zygote: ZygoteRuleConfig
end
return kwfoo(x; k=k), kwfoo_pullback
end


kwfoo_outer_unused(x) = kwfoo(x)
kwfoo_outer_used(x) = kwfoo(x; k=-15)
Expand All @@ -207,7 +211,7 @@ using Zygote: ZygoteRuleConfig
end
return not_diff_kw_eg(x, i; kwargs...), not_diff_kw_eg_pullback
end


@test (nothing,) == Zygote.gradient(x->not_diff_kw_eg(x, 2), 10.4)
@test (nothing,) == Zygote.gradient(x->not_diff_kw_eg(x, 2; kw=2.0), 10.4)
Expand All @@ -218,7 +222,7 @@ using Zygote: ZygoteRuleConfig
x::T
end
StructForTestingTypeOnlyRRules() = StructForTestingTypeOnlyRRules(1.0)

function ChainRulesCore.rrule(P::Type{<:StructForTestingTypeOnlyRRules})
# notice here we mess with the primal doing 2.0 rather than 1.0, this is for testing purposes
# and also because apparently people actually want to do this. Weird, but 🤷
Expand Down Expand Up @@ -253,7 +257,7 @@ using Zygote: ZygoteRuleConfig
@test ([1.0],) == Zygote.gradient(oout_id_outer, [π])
@test oout_id_rrule_hitcount[] == 0

# Now try opting out After we have already used it
# Now try opting out After we have already used it
@opt_out ChainRulesCore.rrule(::typeof(oout_id), x::Real)
oout_id_rrule_hitcount[] = 0
@test (1.0,) == Zygote.gradient(oout_id_outer, π)
Expand Down Expand Up @@ -399,7 +403,7 @@ end
@test @inferred(Zygote.z2d((nothing,), (1,))) === NoTangent()
@test @inferred(Zygote.z2d((nothing, nothing), (1,2))) === NoTangent()

# To test the generic case, we need a struct within a struct.
# To test the generic case, we need a struct within a struct.
nested = Tangent{Base.RefValue{ComplexF64}}(; x=Tangent{ComplexF64}(; re=1, im=NoTangent()),)
if VERSION > v"1.7-"
@test @inferred(Zygote.z2d((; x=(; re=1)), Ref(3.0+im))) == nested
Expand Down Expand Up @@ -456,3 +460,5 @@ end
@test g[1] isa NamedTuple
@test g[1].w isa Array
end

end
41 changes: 23 additions & 18 deletions test/compiler.jl → test/compiler_tests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using Zygote, Test
@testitem "compiler" begin

using LinearAlgebra
using Zygote: pullback, @adjoint, Context

macro test_inferred(ex)
Expand All @@ -11,9 +13,10 @@ macro test_inferred(ex)
end) |> esc
end

trace_contains(st, func, file, line) = any(st) do fr
func in (nothing, fr.func) && endswith(String(fr.file), file) &&
fr.line == line
function trace_contains(st, func, file, line)
any(st) do fr
func in (nothing, fr.func) && endswith(String(fr.file), file) && fr.line == line
end
end

bad(x) = x
Expand All @@ -32,8 +35,8 @@ y, back = pullback(badly, 2)
@test_throws Exception back(1)
bt = try back(1) catch e stacktrace(catch_backtrace()) end

@test trace_contains(bt, nothing, "compiler.jl", bad_def_line)
@test trace_contains(bt, :badly, "compiler.jl", bad_call_line)
@test trace_contains(bt, nothing, "compiler_tests.jl", bad_def_line)
@test trace_contains(bt, :badly, "compiler_tests.jl", bad_call_line)

# Type inference checks

Expand Down Expand Up @@ -277,20 +280,20 @@ function try_catch_finally(cond, x)
x
end

function try_catch_else(cond, x)
x = 2x
function try_catch_else(cond, x)
x = 2x

try
x = 2x
cond && throw(nothing)
catch
x = 3x
else
x = 2x
end
try
x = 2x
cond && throw(nothing)
catch
x = 3x
else
x = 2x
end

x
end
x
end

@testset "try/catch" begin
@testset "happy path (nothrow)" begin
Expand Down Expand Up @@ -337,3 +340,5 @@ end
err = try pull(1.) catch ex; ex end
@test occursin("Can't differentiate function execution in catch block", string(err))
end

end
48 changes: 25 additions & 23 deletions test/complex.jl → test/complex_tests.jl
Original file line number Diff line number Diff line change
@@ -1,31 +1,31 @@
using Zygote, Test, LinearAlgebra
@testitem "complex" begin

@testset "basic" begin

@test gradient(x -> real(abs(x)*exp(im*angle(x))), 10+20im)[1] ≈ 1
@test gradient(x -> imag(real(x)+0.3im), 0.3)[1] ≈ 0
@test gradient(x -> imag(conj(x)+0.3im), 0.3 + 0im)[1] ≈ -1im
@test gradient(x -> imag(conj(x)+0.3im), 0.3)[1] ≈ 0 # projected to zero
@test gradient(x -> abs((imag(x)+0.3)), 0.3 + 0im)[1] ≈ 1im
@test gradient(x -> abs((imag(x)+0.3)), 0.3)[1] ≈ 0
using LinearAlgebra

@test gradient(a -> real((a*conj(a))), 0.3im)[1] == 0.6im
@test gradient(a -> real((a.*conj(a))), 0.3im)[1] == 0.6im
@test gradient(a -> real(([a].*conj([a])))[], 0.3im)[1] == 0.6im
@test gradient(a -> real(([a].*conj.([a])))[], 0.3im)[1] == 0.6im
@test gradient(a -> real.(([a].*conj.([a])))[], 0.3im)[1] == 0.6im
@testset "basic" begin
@test gradient(x -> real(abs(x)*exp(im*angle(x))), 10+20im)[1] ≈ 1
@test gradient(x -> imag(real(x)+0.3im), 0.3)[1] ≈ 0
@test gradient(x -> imag(conj(x)+0.3im), 0.3 + 0im)[1] ≈ -1im
@test gradient(x -> imag(conj(x)+0.3im), 0.3)[1] ≈ 0 # projected to zero
@test gradient(x -> abs((imag(x)+0.3)), 0.3 + 0im)[1] ≈ 1im
@test gradient(x -> abs((imag(x)+0.3)), 0.3)[1] ≈ 0

@test gradient(x -> norm((im*x) ./ (im)), 2)[1] == 1
@test gradient(x -> norm((im) ./ (im*x)), 2)[1] == -1/4
@test gradient(x -> real(det(x)), [1 2im; 3im 4])[1] ≈ [4 3im; 2im 1]
@test gradient(x -> real(logdet(x)), [1 2im; 3im 4])[1] ≈ [4 3im; 2im 1]/10
@test gradient(x -> real(logabsdet(x)[1]), [1 2im; 3im 4])[1] ≈ [4 3im; 2im 1]/10
@test gradient(a -> real((a*conj(a))), 0.3im)[1] == 0.6im
@test gradient(a -> real((a.*conj(a))), 0.3im)[1] == 0.6im
@test gradient(a -> real(([a].*conj([a])))[], 0.3im)[1] == 0.6im
@test gradient(a -> real(([a].*conj.([a])))[], 0.3im)[1] == 0.6im
@test gradient(a -> real.(([a].*conj.([a])))[], 0.3im)[1] == 0.6im

# https://github.com/FluxML/Zygote.jl/issues/705
@test gradient(x -> imag(sum(exp, x)), [1,2,3])[1] ≈ real(im .* exp.(1:3))
@test gradient(x -> imag(sum(exp, x)), [1+0im,2,3])[1] ≈ im .* exp.(1:3)
@test gradient(x -> norm((im*x) ./ (im)), 2)[1] == 1
@test gradient(x -> norm((im) ./ (im*x)), 2)[1] == -1/4
@test gradient(x -> real(det(x)), [1 2im; 3im 4])[1] ≈ [4 3im; 2im 1]
@test gradient(x -> real(logdet(x)), [1 2im; 3im 4])[1] ≈ [4 3im; 2im 1]/10
@test gradient(x -> real(logabsdet(x)[1]), [1 2im; 3im 4])[1] ≈ [4 3im; 2im 1]/10

end # @testset
# https://github.com/FluxML/Zygote.jl/issues/705
@test gradient(x -> imag(sum(exp, x)), [1,2,3])[1] ≈ real(im .* exp.(1:3))
@test gradient(x -> imag(sum(exp, x)), [1+0im,2,3])[1] ≈ im .* exp.(1:3)
end

fs_C_to_R = (real,
imag,
Expand Down Expand Up @@ -120,3 +120,5 @@ end
end
@test Zygote.hessian(fun, collect(1:9)) ≈ [14 0 0 0 0 0 2 0 0; 0 16 0 0 0 0 0 4 0; 0 0 18 0 0 0 0 0 6; 0 0 0 14 0 0 8 0 0; 0 0 0 0 16 0 0 10 0; 0 0 0 0 0 18 0 0 12; 2 0 0 8 0 0 0 0 0; 0 4 0 0 10 0 0 0 0; 0 0 6 0 0 12 0 0 0]
end

end
7 changes: 6 additions & 1 deletion test/cuda.jl → test/cuda_tests.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
@testitem "cuda" begin

using LinearAlgebra
using CUDA
using Zygote: Grads
using Random: randn!
import FiniteDifferences

CUDA.allowscalar(false)

function gradcheck_gpu(f, xs...)
Expand All @@ -11,7 +15,6 @@ function gradcheck_gpu(f, xs...)
return all(isapprox.(collect.(grad_zygote), grad_finite_difference))
end


# Test GPU movement inside the call to `gradient`
@testset "GPU movement" begin
r = rand(Float32, 3,3)
Expand Down Expand Up @@ -195,3 +198,5 @@ end


end

end
4 changes: 4 additions & 0 deletions test/deprecated.jl → test/deprecated_tests.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
@testitem "deprecated" begin

@test_deprecated dropgrad(1)
@test_deprecated ignore(1)
@test_deprecated Zygote.@ignore x=1
Expand All @@ -8,3 +10,5 @@
y = Zygote.@ignore x
x * y
end == (1,)

end
20 changes: 7 additions & 13 deletions test/features.jl → test/features_tests.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
@testitem "features" begin

using Zygote, Test, LinearAlgebra
using Zygote: Params, gradient, forwarddiff
using FillArrays: Fill
Expand Down Expand Up @@ -397,7 +399,7 @@ global_r = 1
@test back(1) == (nothing, 3)
ref = first(keys(Zygote.cache(cx)))
@test ref isa GlobalRef
@test ref.mod == Main
@test_broken ref.mod == Main # TODO module is now "features###"
@test ref.name == :global_param
@test Zygote.cache(cx)[ref] == 2

Expand All @@ -409,7 +411,7 @@ global_r = 1
end
return global_r
end

@test gradient(pow_global, 2, 3) == (12, nothing)
end

Expand Down Expand Up @@ -694,14 +696,6 @@ end
end == ([8 112; 36 2004],)
end

@testset "PythonCall custom @adjoint" begin
using PythonCall: pyimport, pyconvert
math = pyimport("math")
pysin(x) = math.sin(x)
Zygote.@adjoint pysin(x) = pyconvert(Float64, math.sin(x)), δ -> (pyconvert(Float64, δ * math.cos(x)),)
@test Zygote.gradient(pysin, 1.5) == Zygote.gradient(sin, 1.5)
end

# https://github.com/JuliaDiff/ChainRules.jl/issues/257
@testset "Keyword Argument Passing" begin
struct Type1{VJP}
Expand All @@ -717,9 +711,8 @@ end
end

i = 1
global x = Any[nothing,nothing]

Zygote.@nograd g(x,i,sensealg) = Main.x[i] = sensealg
x = Any[nothing,nothing]
Zygote.@nograd g(x,i,sensealg) = x[i] = sensealg
function f(;sensealg=nothing)
g(x,i,sensealg)
return rand(100)
Expand Down Expand Up @@ -889,3 +882,4 @@ end
@test g1[1] ≈ g2[1] ≈ g3[1]
end

end
6 changes: 5 additions & 1 deletion test/forward/forward.jl → test/forward_tests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using Zygote, Test
@testitem "forward" begin

using LinearAlgebra

D(f, x) = pushforward(f, x)(1)

Expand Down Expand Up @@ -46,3 +48,5 @@ end == 0
mul!(B, A, A)
sum(B)
end == 6

end
Loading
Loading