Skip to content

Commit

Permalink
Fix and test promote_rule definitions (#207)
Browse files Browse the repository at this point in the history
* Fix and test `promote_rule` definitions

* Update Project.toml
  • Loading branch information
devmotion authored Oct 3, 2022
1 parent ac44511 commit 8ac1f7d
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ReverseDiff"
uuid = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
version = "1.14.2"
version = "1.14.3"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
2 changes: 1 addition & 1 deletion src/ReverseDiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ using ChainRulesCore
# Not all operations will be valid over all of these types, but that's okay; such cases
# will simply error when they hit the original operation in the overloaded definition.
const ARRAY_TYPES = (:AbstractArray, :AbstractVector, :AbstractMatrix, :Array, :Vector, :Matrix)
const REAL_TYPES = (:Bool, :Integer, :(Irrational{:e}), :(Irrational{}), :Rational, :BigFloat, :BigInt, :AbstractFloat, :Real, :Dual)
const REAL_TYPES = (:Bool, :Integer, :(Irrational{:}), :(Irrational{}), :Rational, :BigFloat, :BigInt, :AbstractFloat, :Real, :Dual)

const SKIPPED_UNARY_SCALAR_FUNCS = Symbol[:isinf, :isnan, :isfinite, :iseven, :isodd, :isreal, :isinteger]
const SKIPPED_BINARY_SCALAR_FUNCS = Symbol[:isequal, :isless, :<, :>, :(==), :(!=), :(<=), :(>=)]
Expand Down
16 changes: 14 additions & 2 deletions src/tracked.jl
Original file line number Diff line number Diff line change
Expand Up @@ -270,10 +270,22 @@ Base.convert(::Type{T}, t::T) where {T<:TrackedReal} = t
Base.convert(::Type{T}, t::T) where {T<:TrackedArray} = t

for R in REAL_TYPES
@eval Base.promote_rule(::Type{$R}, ::Type{TrackedReal{V,D,O}}) where {V,D,O} = TrackedReal{promote_type($R,V),D,O}
R === :Dual && continue # ForwardDiff.Dual is handled below
@eval begin
if isconcretetype($R) # issue ForwardDiff#322
Base.promote_rule(::Type{TrackedReal{V,D,O}}, ::Type{$R}) where {V,D,O} = TrackedReal{promote_type(V,$R),D,O}
Base.promote_rule(::Type{$R}, ::Type{TrackedReal{V,D,O}}) where {V,D,O} = TrackedReal{promote_type($R,V),D,O}
else
Base.promote_rule(::Type{TrackedReal{V,D,O}}, ::Type{R}) where {V,D,O,R<:$R} = TrackedReal{promote_type(V,R),D,O}
Base.promote_rule(::Type{R}, ::Type{TrackedReal{V,D,O}}) where {R<:$R,V,D,O,} = TrackedReal{promote_type(R,V),D,O}
end
end
end

Base.promote_rule(::Type{R}, ::Type{TrackedReal{V,D,O}}) where {R<:Real,V,D,O} = TrackedReal{promote_type(R,V),D,O}
# Avoid method ambiguities for ForwardDiff.Dual
Base.promote_rule(::Type{TrackedReal{V1,D,O}}, ::Type{Dual{T,V2,N}}) where {V1,D,O,T,V2,N} = TrackedReal{promote_type(V1,Dual{T,V2,N}),D,O}
Base.promote_rule(::Type{Dual{T,V1,N}}, ::Type{TrackedReal{V2,D,O}}) where {T,V1,N,V2,D,O} = TrackedReal{promote_type(Dual{T,V1,N},V2),D,O}

Base.promote_rule(::Type{TrackedReal{V1,D1,O1}}, ::Type{TrackedReal{V2,D2,O2}}) where {V1,V2,D1,D2,O1,O2} = TrackedReal{promote_type(V1,V2),promote_type(D1,D2),Nothing}

###########################
Expand Down
13 changes: 13 additions & 0 deletions test/TrackedTests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ module TrackedTests
using ReverseDiff, Test
using ReverseDiff: TrackedReal, TrackedArray

import ForwardDiff

include(joinpath(dirname(@__FILE__), "utils.jl"))

samefields(a, b) = a === b
Expand Down Expand Up @@ -601,8 +603,19 @@ empty!(tp)
@test convert(typeof(ta), ta) === ta
@test convert(typeof(ta1), ta1) === ta1

@test promote_type(T, Bool) === T
@test promote_type(T, Int32) === T
@test promote_type(T, Int64) === T
@test promote_type(T, Integer) === TrackedReal{BigInt,Float64,A}
@test promote_type(T, typeof(ℯ)) === TrackedReal{BigFloat,Float64,A}
@test promote_type(T, typeof(π)) === TrackedReal{BigFloat,Float64,A}
@test promote_type(T, Rational{Int}) === TrackedReal{Rational{BigInt},Float64,A}
@test promote_type(T, BigFloat) === TrackedReal{BigFloat,Float64,A}
@test promote_type(T, BigInt) === T
@test promote_type(T, Float64) === TrackedReal{BigFloat,Float64,A}
@test promote_type(T, AbstractFloat) === TrackedReal{BigFloat,Float64,A}
@test promote_type(T, Real) === TrackedReal{Real,Float64,A}
@test promote_type(T, ForwardDiff.Dual{:tag,Float64,1}) === TrackedReal{ForwardDiff.Dual{:tag,BigFloat,1},Float64,A}
@test promote_type(T, TrackedReal{BigFloat,BigFloat,Nothing}) === TrackedReal{BigFloat,BigFloat,Nothing}
@test promote_type(T, T) === T

Expand Down

2 comments on commit 8ac1f7d

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/69416

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.14.3 -m "<description of version>" 8ac1f7dae3bb3c20956ecb8f14a3240b4752f318
git push origin v1.14.3

Please sign in to comment.