Skip to content

Commit

Permalink
Improve DiffRules integration and tests (#209)
Browse files Browse the repository at this point in the history
* Improve DiffRules integration and tests

* Bump version

* Try to remove suspicious line
  • Loading branch information
devmotion authored Oct 16, 2022
1 parent 8ac1f7d commit f06b776
Show file tree
Hide file tree
Showing 7 changed files with 168 additions and 59 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ReverseDiff"
uuid = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
version = "1.14.3"
version = "1.14.4"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
11 changes: 11 additions & 0 deletions src/ReverseDiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,17 @@ const REAL_TYPES = (:Bool, :Integer, :(Irrational{:ℯ}), :(Irrational{:π}), :R
const SKIPPED_UNARY_SCALAR_FUNCS = Symbol[:isinf, :isnan, :isfinite, :iseven, :isodd, :isreal, :isinteger]
const SKIPPED_BINARY_SCALAR_FUNCS = Symbol[:isequal, :isless, :<, :>, :(==), :(!=), :(<=), :(>=)]

# Some functions with derivatives in DiffRules are not supported.
# For instance, ReverseDiff does not support functions with complex
# results and derivatives — which rules out all of the Hankel/Bessel
# variants listed here (all live in SpecialFunctions).
const SKIPPED_DIFFRULES = Tuple{Symbol,Symbol}[
    (:SpecialFunctions, f) for f in (
        :hankelh1, :hankelh1x, :hankelh2, :hankelh2x, :besselh, :besselhx,
    )
]

include("tape.jl")
include("tracked.jl")
include("macros.jl")
Expand Down
128 changes: 106 additions & 22 deletions src/derivatives/elementwise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ for g! in (:map!, :broadcast!), (M, f, arity) in DiffRules.diffrules(; filter_mo
@warn "$M.$f is not available and hence rule for it can not be defined"
continue # Skip rules for methods not defined in the current scope
end
(M, f) in SKIPPED_DIFFRULES && continue
if arity == 1
@eval @inline Base.$(g!)(f::typeof($M.$f), out::TrackedArray, t::TrackedArray) = $(g!)(ForwardOptimize(f), out, t)
elseif arity == 2
Expand Down Expand Up @@ -122,23 +123,53 @@ for (g!, g) in ((:map!, :map), (:broadcast!, :broadcast))
return out
end
end
for A in ARRAY_TYPES, T in (:TrackedArray, :TrackedReal)
@eval function Base.$(g!)(f::ForwardOptimize{F}, out::TrackedArray{S}, x::$(T){X}, y::$A) where {F,S,X}
result = DiffResults.GradientResult(SVector(zero(S), zero(S)))
df = (vx, vy) -> ForwardDiff.gradient!(result, s -> f.f(s[1], s[2]), SVector(vx, vy))
for A in ARRAY_TYPES
@eval function Base.$(g!)(f::ForwardOptimize{F}, out::TrackedArray, x::TrackedReal{X,D}, y::$A) where {F,X,D}
result = DiffResults.DiffResult(zero(X), zero(D))
df = let result=result
(vx, vy) -> let vy=vy
ForwardDiff.derivative!(result, s -> f.f(s, vy), vx)
end
end
results = $(g)(df, value(x), value(y))
map!(DiffResult.value, value(out), results)
cache = (results, df, index_bound(x, out), index_bound(y, out))
record!(tape(x, y), SpecialInstruction, $(g), (x, y), out, cache)
record!(tape(x), SpecialInstruction, $(g), (x, y), out, cache)
return out
end
@eval function Base.$(g!)(f::ForwardOptimize{F}, out::TrackedArray, x::$A, y::$(T){Y}) where {F,Y}
result = DiffResults.GradientResult(SVector(zero(S), zero(S)))
df = (vx, vy) -> ForwardDiff.gradient!(result, s -> f.f(s[1], s[2]), SVector(vx, vy))
@eval function Base.$(g!)(f::ForwardOptimize{F}, out::TrackedArray, x::$A, y::TrackedReal{Y,D}) where {F,Y,D}
result = DiffResults.DiffResult(zero(Y), zero(D))
df = let result=result
(vx, vy) -> let vx=vx
ForwardDiff.derivative!(result, s -> f.f(vx, s), vy)
end
end
results = $(g)(df, value(x), value(y))
map!(DiffResult.value, value(out), results)
cache = (results, df, index_bound(x, out), index_bound(y, out))
record!(tape(x, y), SpecialInstruction, $(g), (x, y), out, cache)
record!(tape(y), SpecialInstruction, $(g), (x, y), out, cache)
return out
end
# In-place `map!`/`broadcast!` of a binary `f` where only `x` is tracked and
# `y` is a plain array: differentiate `f` w.r.t. its first argument only,
# using a 1-element gradient so only ∂f/∂x is computed and recorded.
@eval function Base.$(g!)(f::ForwardOptimize{F}, out::TrackedArray, x::TrackedArray{X}, y::$A) where {F,X}
# Preallocated single-entry gradient buffer, reused across element pairs.
result = DiffResults.GradientResult(SVector(zero(X)))
# `let vy=vy` rebinds the closure argument locally — presumably to help the
# compiler avoid capture boxing/specialization issues; TODO confirm intent.
df = (vx, vy) -> let vy=vy
ForwardDiff.gradient!(result, s -> f.f(s[1], vy), SVector(vx))
end
results = $(g)(df, value(x), value(y))
# NOTE(review): `DiffResult.value` accesses a property of the *type*; the
# out-of-place methods below use the module accessor `DiffResults.value` —
# verify this is not a typo.
map!(DiffResult.value, value(out), results)
cache = (results, df, index_bound(x, out), index_bound(y, out))
# Only `x` is tracked, so the instruction is recorded on `x`'s tape.
record!(tape(x), SpecialInstruction, $(g), (x, y), out, cache)
return out
end
# In-place `map!`/`broadcast!` of a binary `f` where only `y` is tracked and
# `x` is a plain array: differentiate `f` w.r.t. its second argument only,
# using a 1-element gradient so only ∂f/∂y is computed and recorded.
@eval function Base.$(g!)(f::ForwardOptimize{F}, out::TrackedArray, x::$A, y::TrackedArray{Y}) where {F,Y}
# Preallocated single-entry gradient buffer, reused across element pairs.
result = DiffResults.GradientResult(SVector(zero(Y)))
# BUG FIX: the original wrote `df = let vx=vx ...` at function scope, where
# no `vx` exists — an UndefVarError on first use. The `let` must live inside
# the closure, rebinding the closure's own `vx` argument, mirroring the
# tracked-`x` method above and the out-of-place variants below.
df = (vx, vy) -> let vx=vx
ForwardDiff.gradient!(result, s -> f.f(vx, s[1]), SVector(vy))
end
results = $(g)(df, value(x), value(y))
# FIX: use the `DiffResults` module accessor (the original `DiffResult.value`
# reads a property of the *type* and errors at runtime; cf. the out-of-place
# methods' `DiffResults.value.(results)`).
map!(DiffResults.value, value(out), results)
cache = (results, df, index_bound(x, out), index_bound(y, out))
# Only `y` is tracked, so the instruction is recorded on `y`'s tape.
record!(tape(y), SpecialInstruction, $(g), (x, y), out, cache)
return out
end
end
Expand Down Expand Up @@ -166,6 +197,7 @@ for g in (:map, :broadcast), (M, f, arity) in DiffRules.diffrules(; filter_modul
if arity == 1
@eval @inline Base.$(g)(f::typeof($M.$f), t::TrackedArray) = $(g)(ForwardOptimize(f), t)
elseif arity == 2
(M, f) in SKIPPED_DIFFRULES && continue
# skip these definitions if `f` is one of the functions
# that will get a manually defined broadcast definition
# later (see "built-in infix operations" below)
Expand Down Expand Up @@ -207,20 +239,52 @@ for g in (:map, :broadcast)
record!(tp, SpecialInstruction, $(g), x, out, cache)
return out
end
for A in ARRAY_TYPES, T in (:TrackedArray, :TrackedReal)
@eval function Base.$(g)(f::ForwardOptimize{F}, x::$(T){X,D}, y::$A) where {F,X,D}
result = DiffResults.GradientResult(SVector(zero(X), zero(D)))
df = (vx, vy) -> ForwardDiff.gradient!(result, s -> f.f(s[1], s[2]), SVector(vx, vy))
for A in ARRAY_TYPES
# Out-of-place `map`/`broadcast` of a binary `f` where `x` is a tracked
# scalar and `y` is a plain array: a scalar derivative of `f` w.r.t. `x`
# suffices for every element of `y`.
@eval function Base.$(g)(f::ForwardOptimize{F}, x::TrackedReal{X,D}, y::$A) where {F,X,D}
# Scalar value+derivative buffer, reused across elements of `y`.
result = DiffResults.DiffResult(zero(X), zero(D))
# `let result=result` / `let vy=vy` rebind captured values locally —
# presumably to avoid closure capture boxing; TODO confirm intent.
df = let result=result
(vx, vy) -> let vy=vy
ForwardDiff.derivative!(result, s -> f.f(s, vy), vx)
end
end
results = $(g)(df, value(x), value(y))
# Only `x` is tracked, so record on `x`'s tape.
tp = tape(x)
out = track(DiffResults.value.(results), D, tp)
cache = (results, df, index_bound(x, out), index_bound(y, out))
record!(tp, SpecialInstruction, $(g), (x, y), out, cache)
return out
end
@eval function Base.$(g)(f::ForwardOptimize{F}, x::$A, y::$(T){Y,D}) where {F,Y,D}
result = DiffResults.GradientResult(SVector(zero(Y), zero(D)))
df = (vx, vy) -> ForwardDiff.gradient!(result, s -> f.f(s[1], s[2]), SVector(vx, vy))
# Out-of-place `map`/`broadcast` of a binary `f` where `x` is a plain array
# and `y` is a tracked scalar: mirror of the method above, differentiating
# w.r.t. the second argument only.
@eval function Base.$(g)(f::ForwardOptimize{F}, x::$A, y::TrackedReal{Y,D}) where {F,Y,D}
# Scalar value+derivative buffer, reused across elements of `x`.
result = DiffResults.DiffResult(zero(Y), zero(D))
# `let result=result` / `let vx=vx` rebind captured values locally —
# presumably to avoid closure capture boxing; TODO confirm intent.
df = let result=result
(vx, vy) -> let vx=vx
ForwardDiff.derivative!(result, s -> f.f(vx, s), vy)
end
end
results = $(g)(df, value(x), value(y))
# Only `y` is tracked, so record on `y`'s tape.
tp = tape(y)
out = track(DiffResults.value.(results), D, tp)
cache = (results, df, index_bound(x, out), index_bound(y, out))
record!(tp, SpecialInstruction, $(g), (x, y), out, cache)
return out
end
# Out-of-place `map`/`broadcast` of a binary `f` where `x` is a tracked
# array and `y` is a plain array: a 1-element gradient computes only ∂f/∂x
# per element pair.
@eval function Base.$(g)(f::ForwardOptimize{F}, x::TrackedArray{X,D}, y::$A) where {F,X,D}
# Preallocated single-entry gradient buffer, reused across element pairs.
result = DiffResults.GradientResult(SVector(zero(X)))
# `let vy=vy` rebinds the closure argument locally — presumably to avoid
# capture boxing; TODO confirm intent.
df = (vx, vy) -> let vy=vy
ForwardDiff.gradient!(result, s -> f.f(s[1], vy), SVector(vx))
end
results = $(g)(df, value(x), value(y))
# Only `x` is tracked, so record on `x`'s tape.
tp = tape(x)
out = track(DiffResults.value.(results), D, tp)
cache = (results, df, index_bound(x, out), index_bound(y, out))
record!(tp, SpecialInstruction, $(g), (x, y), out, cache)
return out
end
@eval function Base.$(g)(f::ForwardOptimize{F}, x::$A, y::TrackedArray{Y,D}) where {F,Y,D}
result = DiffResults.GradientResult(SVector(zero(Y)))
df = (vx, vy) -> let vx=vx
ForwardDiff.gradient!(result, s -> f.f(vx, s[1]), SVector(vy))
end
results = $(g)(df, value(x), value(y))
tp = tape(y)
out = track(DiffResults.value.(results), D, tp)
Expand Down Expand Up @@ -291,8 +355,15 @@ end
diffresult_increment_deriv!(input, output_deriv, results, 1)
else
a, b = input
istracked(a) && diffresult_increment_deriv!(a, output_deriv, results, 1)
istracked(b) && diffresult_increment_deriv!(b, output_deriv, results, 2)
p = 0
if istracked(a)
p += 1
diffresult_increment_deriv!(a, output_deriv, results, p)
end
if istracked(b)
p += 1
diffresult_increment_deriv!(b, output_deriv, results, p)
end
end
unseed!(output)
return nothing
Expand All @@ -311,12 +382,25 @@ end
end
else
a, b = input
p = 0
if size(a) == size(b)
istracked(a) && diffresult_increment_deriv!(a, output_deriv, results, 1)
istracked(b) && diffresult_increment_deriv!(b, output_deriv, results, 2)
if istracked(a)
p += 1
diffresult_increment_deriv!(a, output_deriv, results, p)
end
if istracked(b)
p += 1
diffresult_increment_deriv!(b, output_deriv, results, p)
end
else
istracked(a) && diffresult_increment_deriv!(a, output_deriv, results, 1, a_bound)
istracked(b) && diffresult_increment_deriv!(b, output_deriv, results, 2, b_bound)
if istracked(a)
p += 1
diffresult_increment_deriv!(a, output_deriv, results, p, a_bound)
end
if istracked(b)
p += 1
diffresult_increment_deriv!(b, output_deriv, results, p, b_bound)
end
end
end
unseed!(output)
Expand Down
1 change: 1 addition & 0 deletions src/derivatives/scalars.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing)
@warn "$M.$f is not available and hence rule for it can not be defined"
continue # Skip rules for methods not defined in the current scope
end
(M, f) in SKIPPED_DIFFRULES && continue
if arity == 1
@eval @inline $M.$(f)(t::TrackedReal) = ForwardOptimize($M.$(f))(t)
elseif arity == 2
Expand Down
63 changes: 39 additions & 24 deletions test/derivatives/ElementwiseTests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ function test_elementwise(f, fopt, x, tp)
# reverse
out = similar(y, (length(x), length(x)))
ReverseDiff.seeded_reverse_pass!(out, yt, xt, tp)
test_approx(out, ForwardDiff.jacobian(z -> map(f, z), x))
test_approx(out, ForwardDiff.jacobian(z -> map(f, z), x); nans=true)

# forward
x2 = x .- offset
Expand All @@ -57,7 +57,7 @@ function test_elementwise(f, fopt, x, tp)
# reverse
out = similar(y, (length(x), length(x)))
ReverseDiff.seeded_reverse_pass!(out, yt, xt, tp)
test_approx(out, ForwardDiff.jacobian(z -> broadcast(f, z), x))
test_approx(out, ForwardDiff.jacobian(z -> broadcast(f, z), x); nans=true)

# forward
x2 = x .- offset
Expand All @@ -81,9 +81,9 @@ function test_map(f, fopt, a, b, tp)
@test length(tp) == 1

# reverse
out = similar(c, (length(a), length(a)))
out = similar(c, (length(c), length(a)))
ReverseDiff.seeded_reverse_pass!(out, ct, at, tp)
test_approx(out, ForwardDiff.jacobian(x -> map(f, x, b), a))
test_approx(out, ForwardDiff.jacobian(x -> map(f, x, b), a); nans=true)

# forward
a2 = a .- offset
Expand All @@ -102,9 +102,9 @@ function test_map(f, fopt, a, b, tp)
@test length(tp) == 1

# reverse
out = similar(c, (length(a), length(a)))
out = similar(c, (length(c), length(b)))
ReverseDiff.seeded_reverse_pass!(out, ct, bt, tp)
test_approx(out, ForwardDiff.jacobian(x -> map(f, a, x), b))
test_approx(out, ForwardDiff.jacobian(x -> map(f, a, x), b); nans=true)

# forward
b2 = b .- offset
Expand All @@ -123,13 +123,17 @@ function test_map(f, fopt, a, b, tp)
@test length(tp) == 1

# reverse
out_a = similar(c, (length(a), length(a)))
out_b = similar(c, (length(a), length(a)))
out_a = similar(c, (length(c), length(a)))
out_b = similar(c, (length(c), length(b)))
ReverseDiff.seeded_reverse_pass!(out_a, ct, at, tp)
ReverseDiff.seeded_reverse_pass!(out_b, ct, bt, tp)
test_approx(out_a, ForwardDiff.jacobian(x -> map(f, x, b), a))
test_approx(out_b, ForwardDiff.jacobian(x -> map(f, a, x), b))

jac = let a=a, b=b, f=f
ForwardDiff.jacobian(vcat(vec(a), vec(b))) do x
map(f, reshape(x[1:length(a)], size(a)), reshape(x[(length(a) + 1):end], size(b)))
end
end
test_approx(out_a, jac[:, 1:length(a)]; nans=true)
test_approx(out_b, jac[:, (length(a) + 1):end]; nans=true)
# forward
a2, b2 = a .- offset, b .- offset
ReverseDiff.value!(at, a2)
Expand Down Expand Up @@ -163,7 +167,7 @@ function test_broadcast(f, fopt, a::AbstractArray, b::AbstractArray, tp, builtin
# reverse
out = similar(c, (length(c), length(a)))
ReverseDiff.seeded_reverse_pass!(out, ct, at, tp)
test_approx(out, ForwardDiff.jacobian(x -> g(x, b), a))
test_approx(out, ForwardDiff.jacobian(x -> g(x, b), a); nans=true)

# forward
a2 = a .- offset
Expand All @@ -184,7 +188,7 @@ function test_broadcast(f, fopt, a::AbstractArray, b::AbstractArray, tp, builtin
# reverse
out = similar(c, (length(c), length(b)))
ReverseDiff.seeded_reverse_pass!(out, ct, bt, tp)
test_approx(out, ForwardDiff.jacobian(x -> g(a, x), b))
test_approx(out, ForwardDiff.jacobian(x -> g(a, x), b); nans=true)

# forward
b2 = b .- offset
Expand All @@ -207,8 +211,13 @@ function test_broadcast(f, fopt, a::AbstractArray, b::AbstractArray, tp, builtin
out_b = similar(c, (length(c), length(b)))
ReverseDiff.seeded_reverse_pass!(out_a, ct, at, tp)
ReverseDiff.seeded_reverse_pass!(out_b, ct, bt, tp)
test_approx(out_a, ForwardDiff.jacobian(x -> g(x, b), a))
test_approx(out_b, ForwardDiff.jacobian(x -> g(a, x), b))
jac = let a=a, b=b, g=g
ForwardDiff.jacobian(vcat(vec(a), vec(b))) do x
g(reshape(x[1:length(a)], size(a)), reshape(x[(length(a) + 1):end], size(b)))
end
end
test_approx(out_a, jac[:, 1:length(a)]; nans=true)
test_approx(out_b, jac[:, (length(a) + 1):end]; nans=true)

# forward
a2, b2 = a .- offset, b .- offset
Expand Down Expand Up @@ -243,7 +252,7 @@ function test_broadcast(f, fopt, n::Number, x::AbstractArray, tp, builtin::Bool
# reverse
out = similar(y)
ReverseDiff.seeded_reverse_pass!(out, yt, nt, tp)
test_approx(out, ForwardDiff.derivative(z -> g(z, x), n))
test_approx(out, ForwardDiff.derivative(z -> g(z, x), n); nans=true)

# forward
n2 = n + offset
Expand All @@ -264,7 +273,7 @@ function test_broadcast(f, fopt, n::Number, x::AbstractArray, tp, builtin::Bool
# reverse
out = similar(y, (length(y), length(x)))
ReverseDiff.seeded_reverse_pass!(out, yt, xt, tp)
test_approx(out, ForwardDiff.jacobian(z -> g(n, z), x))
test_approx(out, ForwardDiff.jacobian(z -> g(n, z), x); nans=true)

# forward
x2 = x .- offset
Expand All @@ -287,8 +296,11 @@ function test_broadcast(f, fopt, n::Number, x::AbstractArray, tp, builtin::Bool
out_x = similar(y, (length(y), length(x)))
ReverseDiff.seeded_reverse_pass!(out_n, yt, nt, tp)
ReverseDiff.seeded_reverse_pass!(out_x, yt, xt, tp)
test_approx(out_n, ForwardDiff.derivative(z -> g(z, x), n))
test_approx(out_x, ForwardDiff.jacobian(z -> g(n, z), x))
jac = let x=x, g=g
ForwardDiff.jacobian(z -> g(z[1], reshape(z[2:end], size(x))), vcat(n, vec(x)))
end
test_approx(out_n, reshape(jac[:, 1], size(y)); nans=true)
test_approx(out_x, jac[:, 2:end]; nans=true)

# forward
n2, x2 = n + offset , x .- offset
Expand Down Expand Up @@ -323,7 +335,7 @@ function test_broadcast(f, fopt, x::AbstractArray, n::Number, tp, builtin::Bool
# reverse
out = similar(y)
ReverseDiff.seeded_reverse_pass!(out, yt, nt, tp)
test_approx(out, ForwardDiff.derivative(z -> g(x, z), n))
test_approx(out, ForwardDiff.derivative(z -> g(x, z), n); nans=true)

# forward
n2 = n + offset
Expand All @@ -344,7 +356,7 @@ function test_broadcast(f, fopt, x::AbstractArray, n::Number, tp, builtin::Bool
# reverse
out = similar(y, (length(y), length(x)))
ReverseDiff.seeded_reverse_pass!(out, yt, xt, tp)
test_approx(out, ForwardDiff.jacobian(z -> g(z, n), x))
test_approx(out, ForwardDiff.jacobian(z -> g(z, n), x); nans=true)

# forward
x2 = x .- offset
Expand All @@ -367,8 +379,11 @@ function test_broadcast(f, fopt, x::AbstractArray, n::Number, tp, builtin::Bool
out_x = similar(y, (length(y), length(x)))
ReverseDiff.seeded_reverse_pass!(out_n, yt, nt, tp)
ReverseDiff.seeded_reverse_pass!(out_x, yt, xt, tp)
test_approx(out_n, ForwardDiff.derivative(z -> g(x, z), n))
test_approx(out_x, ForwardDiff.jacobian(z -> g(z, n), x))
jac = let x=x, g=g
ForwardDiff.jacobian(z -> g(reshape(z[1:(end - 1)], size(x)), z[end]), vcat(vec(x), n))
end
test_approx(out_x, jac[:, 1:(end - 1)]; nans=true)
test_approx(out_n, reshape(jac[:, end], size(y)); nans=true)

# forward
x2, n2 = x .- offset, n + offset
Expand All @@ -393,7 +408,7 @@ for (M, fsym, arity) in DiffRules.diffrules(; filter_modules=nothing)
if !(isdefined(@__MODULE__, M) && isdefined(getfield(@__MODULE__, M), fsym))
error("$M.$fsym is not available")
end
fsym === :rem2pi && continue
(M, fsym) in ReverseDiff.SKIPPED_DIFFRULES && continue
if arity == 1
f = eval(:($M.$fsym))
test_println("forward-mode unary scalar functions", f)
Expand Down
Loading

2 comments on commit f06b776

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/70383

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.14.4 -m "<description of version>" f06b776333dfd3b5259a27bdfb789d8f017c1ee1
git push origin v1.14.4

Please sign in to comment.