Skip to content

Commit

Permalink
Allow custom struct args to grad_from_chainrules macro (#232)
Browse files Browse the repository at this point in the history
* allow custom struct args to grad_from_chainrules

* fix test

* bump version

* fix test

* Update test/ChainRulesTests.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/macros.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/macros.jl

Co-authored-by: David Widmann <[email protected]>

---------

Co-authored-by: David Widmann <[email protected]>
  • Loading branch information
mohamed82008 and devmotion authored Jul 19, 2023
1 parent 65cd309 commit 105290f
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 11 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ReverseDiff"
uuid = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
version = "1.14.6"
version = "1.15.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
14 changes: 12 additions & 2 deletions src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -317,9 +317,19 @@ macro grad_from_chainrules(fcall)
Meta.isexpr(fcall, :call) && length(fcall.args) >= 2 ||
error("`@grad_from_chainrules` has to be applied to a function signature")
f = esc(fcall.args[1])
xs = fcall.args[2:end]
xs = map(fcall.args[2:end]) do x
if Meta.isexpr(x, :(::))
if length(x.args) == 1 # ::T without var name
return :($(gensym())::$(esc(x.args[1])))
else # x::T
@assert length(x.args) == 2
return :($(x.args[1])::$(esc(x.args[2])))
end
else
return x
end
end
args_l, args_r, args_track, args_fixed, arg_types, kwargs = _make_fwd_args(f, xs)

return quote
$f($(args_l...)) = ReverseDiff.track($(args_r...))
function ReverseDiff.track($(args_track...))
Expand Down
48 changes: 40 additions & 8 deletions test/ChainRulesTests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ using DiffResults
using ReverseDiff
using Test

struct MyStruct end
f(::MyStruct, x) = sum(4x .+ 1)
f(x, y::MyStruct) = sum(4x .+ 1)
f(x) = sum(4x .+ 1)

function ChainRulesCore.rrule(::typeof(f), x)
Expand All @@ -20,21 +23,37 @@ function ChainRulesCore.rrule(::typeof(f), x)
rather than 4 when we compute the derivative of `f`, it means
the importing mechanism works.
=#
return ChainRulesCore.NoTangent(), fill(3 * d, size(x))
return NoTangent(), fill(3 * d, size(x))
end
return r, back
end
function ChainRulesCore.rrule(::typeof(f), ::MyStruct, x)
r = f(MyStruct(), x)
function back(d)
return NoTangent(), NoTangent(), fill(3 * d, size(x))
end
return r, back
end
function ChainRulesCore.rrule(::typeof(f), x, ::MyStruct)
r = f(x, MyStruct())
function back(d)
return NoTangent(), fill(3 * d, size(x)), NoTangent()
end
return r, back
end

ReverseDiff.@grad_from_chainrules f(x::ReverseDiff.TrackedArray)

# test arg type hygiene
ReverseDiff.@grad_from_chainrules f(::MyStruct, x::ReverseDiff.TrackedArray)
ReverseDiff.@grad_from_chainrules f(x::ReverseDiff.TrackedArray, y::MyStruct)

g(x, y) = sum(4x .+ 4y)

function ChainRulesCore.rrule(::typeof(g), x, y)
r = g(x, y)
function back(d)
# same as above, use 3 and 5 as the derivatives
return ChainRulesCore.NoTangent(), fill(3 * d, size(x)), fill(5 * d, size(x))
return NoTangent(), fill(3 * d, size(x)), fill(5 * d, size(x))
end
return r, back
end
Expand Down Expand Up @@ -93,6 +112,19 @@ ReverseDiff.@grad_from_chainrules g(x::ReverseDiff.TrackedArray, y::ReverseDiff.

end

@testset "custom struct input" begin
input = rand(3, 3)
output, back = ChainRulesCore.rrule(f, MyStruct(), input);
_, _, d = back(1)
@test output == f(MyStruct(), input)
@test d == fill(3, size(input))

output, back = ChainRulesCore.rrule(f, input, MyStruct());
_, d, _ = back(1)
@test output == f(input, MyStruct())
@test d == fill(3, size(input))
end

### Tape test
@testset "Tape test: Ensure ordinary call is not tracked" begin
tp = ReverseDiff.InstructionTape()
Expand All @@ -112,7 +144,7 @@ f_vararg(x, args...) = sum(4x .+ sum(args))
function ChainRulesCore.rrule(::typeof(f_vararg), x, args...)
r = f_vararg(x, args...)
function back(d)
return ChainRulesCore.NoTangent(), fill(3 * d, size(x))
return NoTangent(), fill(3 * d, size(x))
end
return r, back
end
Expand All @@ -136,7 +168,7 @@ f_kw(x, args...; k=1, kwargs...) = sum(4x .+ sum(args) .+ (k + kwargs[:j]))
function ChainRulesCore.rrule(::typeof(f_kw), x, args...; k=1, kwargs...)
r = f_kw(x, args...; k=k, kwargs...)
function back(d)
return ChainRulesCore.NoTangent(), fill(3 * d, size(x))
return NoTangent(), fill(3 * d, size(x))
end
return r, back
end
Expand Down Expand Up @@ -175,20 +207,20 @@ end
### Isolated Scope
module IsolatedModuleForTestingScoping
using ChainRulesCore
using ReverseDiff: @grad_from_chainrules
using ReverseDiff: ReverseDiff, @grad_from_chainrules

f(x) = sum(4x .+ 1)

function ChainRulesCore.rrule(::typeof(f), x)
r = f(x)
function back(d)
# return a distinguishable but improper grad
return ChainRulesCore.NoTangent(), fill(3 * d, size(x))
return NoTangent(), fill(3 * d, size(x))
end
return r, back
end

@grad_from_chainrules f(x::TrackedArray)
@grad_from_chainrules f(x::ReverseDiff.TrackedArray)

module SubModule
using Test
Expand Down

2 comments on commit 105290f

@mohamed82008
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/87806

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.15.0 -m "<description of version>" 105290ffff86a5eccc55ca64feee0a1c52fffad0
git push origin v1.15.0

Please sign in to comment.