Skip to content

Commit

Permalink
Merge pull request #177 from JuliaDiff/ox/typo
Browse files Browse the repository at this point in the history
fix mistake in non-tuple output message
  • Loading branch information
oxinabox authored Jun 16, 2021
2 parents e21bf34 + 68fc041 commit dffcfff
Show file tree
Hide file tree
Showing 8 changed files with 72 additions and 25 deletions.
29 changes: 29 additions & 0 deletions .github/workflows/fix_doctests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
name: fix_doctests
on:
pull_request:
jobs:
doctests:
name: Fix doctests (Julia ${{ matrix.julia-version }} - ${{ github.event_name }})
runs-on: ubuntu-latest
strategy:
matrix:
julia-version: [1.6]
steps:
- uses: julia-actions/setup-julia@latest
with:
version: ${{ matrix.julia-version }}
- uses: actions/checkout@v1
- name: Fix doctests
shell: julia --project=docs/ {0}
run: |
using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()
using Documenter
using ChainRulesTestUtils
doctest(ChainRulesTestUtils, fix=true)
# don't push changes to Manifest in suggestions, as it removes `path=..`
run(`git restore docs/Manifest.toml`)
- uses: reviewdog/action-suggester@v1
if: github.event_name == 'pull_request'
with:
tool_name: Documenter (fix doctests)
fail_on_error: true
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesTestUtils"
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
version = "0.7.11"
version = "0.7.12"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
16 changes: 8 additions & 8 deletions docs/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

[[ChainRulesCore]]
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
git-tree-sha1 = "d659e42240c2162300b321f05173cab5cc40a5ba"
git-tree-sha1 = "dbc9aae1227cfddaa9d2552f3ecba5b641f6cce9"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.10.4"
version = "0.10.5"

[[ChainRulesTestUtils]]
deps = ["ChainRulesCore", "Compat", "FiniteDifferences", "LinearAlgebra", "Random", "Test"]
path = ".."
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
version = "0.7.9"
version = "0.7.12"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
Expand Down Expand Up @@ -47,9 +47,9 @@ version = "0.8.5"

[[Documenter]]
deps = ["Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"]
git-tree-sha1 = "3ebb967819b284dc1e3c0422229b58a40a255649"
git-tree-sha1 = "5acbebf1be22db43589bc5aa1bb5fcc378b17780"
uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
version = "0.26.3"
version = "0.27.0"

[[Downloads]]
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
Expand All @@ -62,10 +62,10 @@ uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
version = "0.12.13"

[[IOCapture]]
deps = ["Logging"]
git-tree-sha1 = "377252859f740c217b936cebcd918a44f9b53b59"
deps = ["Logging", "Random"]
git-tree-sha1 = "f7be53659ab06ddc986428d3a9dcc95f6fa6705a"
uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89"
version = "0.1.1"
version = "0.2.2"

[[InteractiveUtils]]
deps = ["Markdown"]
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"

[compat]
Documenter = "0.26"
Documenter = "0.27"
julia = "1"
2 changes: 1 addition & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ makedocs(;
],
strict=true,
checkdocs=:exports,
)
)

const repo = "github.com/JuliaDiff/ChainRulesTestUtils.jl.git"
deploydocs(; repo=repo, push_preview=true)
26 changes: 15 additions & 11 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ For information about ChainRules, including how to write rules, refer to the gen
## Canonical example

Let's suppose a custom transformation has been defined
```jldoctest ex; output = false
```jldoctest ex
function two2three(x1::Float64, x2::Float64)
return 1.0, 2.0*x1, 3.0*x2
end
Expand All @@ -21,7 +21,7 @@ end
two2three (generic function with 1 method)
```
along with the `frule`
```jldoctest ex; output = false
```jldoctest ex
using ChainRulesCore
function ChainRulesCore.frule((Δf, Δx1, Δx2), ::typeof(two2three), x1, x2)
Expand All @@ -33,7 +33,7 @@ end
```
and `rrule`
```jldoctest ex; output = false
```jldoctest ex
function ChainRulesCore.rrule(::typeof(two2three), x1, x2)
y = two2three(x1, x2)
function two2three_pullback(Ȳ)
Expand All @@ -55,29 +55,31 @@ They can be used for any type and number of inputs and outputs.
The call will test the `frule` for function `f` at the point `x` in the domain.
Keep this in mind when testing discontinuous rules for functions like [ReLU](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)), which should ideally be tested at both `x` being above and below zero.

```jldoctest ex; output = false
```jldoctest ex
julia> using ChainRulesTestUtils;
julia> test_frule(two2three, 3.33, -7.77);
Test Summary: | Pass Total
test_frule: two2three on Float64,Float64 | 6 6
```

### Testing the `rrule`

[`test_rrule`](@ref) takes in the function `f`, and primal inputsr `x`.
The call will test the `rrule` for function `f` at the point `x`, and similarly to `frule` some rules should be tested at multiple points in the domain.

```jldoctest ex; output = false
```jldoctest ex
julia> test_rrule(two2three, 3.33, -7.77);
Test Summary: | Pass Total
test_rrule: two2three on Float64,Float64 | 7 7
test_rrule: two2three on Float64,Float64 | 8 8
```

## Scalar example

For functions with a single argument and a single output, such as e.g. ReLU,
```jldoctest ex; output = false
```jldoctest ex
function relu(x::Real)
return max(0, x)
end
Expand All @@ -86,7 +88,7 @@ end
relu (generic function with 1 method)
```
with the `frule` and `rrule` defined with the help of `@scalar_rule` macro
```jldoctest ex; output = false
```jldoctest ex
@scalar_rule relu(x::Real) x <= 0 ? zero(x) : one(x)
# output
Expand All @@ -95,14 +97,16 @@ with the `frule` and `rrule` defined with the help of `@scalar_rule` macro

`test_scalar` function is provided to test both the `frule` and the `rrule` with a single
call.
```jldoctest ex; output = false
```jldoctest ex
julia> test_scalar(relu, 0.5);
Test Summary: | Pass Total
test_scalar: relu at 0.5 | 9 9
test_scalar: relu at 0.5 | 10 10
julia> test_scalar(relu, -0.5);
Test Summary: | Pass Total
test_scalar: relu at -0.5 | 9 9
test_scalar: relu at -0.5 | 10 10
```

## Testing constructors and functors (callable objects)
Expand Down
11 changes: 8 additions & 3 deletions src/testers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -244,9 +244,14 @@ function test_rrule(

check_inferred && _test_inferred(pullback, ȳ)
ad_cotangents = pullback(ȳ)
ad_cotangents isa Tuple || error("The pullback must return (∂self, ∂args...), not $∂s.")
msg = "The pullback should return 1 cotangent for the primal and each primal input."
@test_msg msg length(ad_cotangents) == 1 + length(args)
@test_msg(
"The pullback must return a Tuple (∂self, ∂args...)",
ad_cotangents isa Tuple
)
@test_msg(
"The pullback should return 1 cotangent for the primal and each primal input.",
length(ad_cotangents) == length(primals)
)

# Correctness testing via finite differencing.
# TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
Expand Down
9 changes: 9 additions & 0 deletions test/testers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,15 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
@test fails(() -> test_frule(foo, 2.1, 2.1))
@test fails(() -> test_rrule(foo, 21.0, 32.0))
end

@testset "rrule not returning a tuple" begin
bar(x, y) = x + 3y
function ChainRulesCore.rrule(::typeof(bar), x, y)
bar_pullback(dy) = dy
return bar(x,y), bar_pullback
end
@test fails(() -> test_rrule(bar, 21.0, 32.0))
end
end

@testset "structs" begin
Expand Down

2 comments on commit dffcfff

@oxinabox
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/38980

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.7.12 -m "<description of version>" dffcfffdc4be08447b607d3255094bbd858618e3
git push origin v0.7.12

Please sign in to comment.