Skip to content
Open
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"

[weakdeps]
BaseType = "7fbed51b-1ef5-4d67-9085-a4a9b26f478c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
Expand All @@ -18,6 +19,7 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[extensions]
MeasurementsBaseTypeExt = "BaseType"
MeasurementsForwardDiffExt = "ForwardDiff"
MeasurementsJunoExt = "Juno"
MeasurementsMakieExt = "Makie"
MeasurementsRecipesBaseExt = "RecipesBase"
Expand All @@ -28,6 +30,7 @@ MeasurementsUnitfulExt = "Unitful"
Aqua = "0.8"
BaseType = "0.2"
Calculus = "0.4.1, 0.5"
ForwardDiff = "0.10.36, 1"
Juno = "0.8"
LinearAlgebra = "<0.0.1, 1"
Makie = "0.21, 0.22"
Expand All @@ -43,6 +46,7 @@ julia = "1.10"
[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BaseType = "7fbed51b-1ef5-4d67-9085-a4a9b26f478c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Expand All @@ -53,4 +57,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[targets]
test = ["Aqua", "Makie", "BaseType", "QuadGK", "RecipesBase", "SpecialFunctions", "Statistics", "Test", "Unitful"]
test = ["Aqua", "Makie", "BaseType", "QuadGK", "RecipesBase", "SpecialFunctions", "Statistics", "Test", "Unitful", "ForwardDiff"]
123 changes: 123 additions & 0 deletions ext/MeasurementsForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
module MeasurementsForwardDiffExt

using ForwardDiff: Dual, DiffRules, NaNMath, LogExpFunctions, SpecialFunctions,≺
using Measurements: Measurement
import Base: +,-,/,*,promote_rule
using ForwardDiff: AMBIGUOUS_TYPES, partials, values, Partials, value
using ForwardDiff: ForwardDiff

#patch this until is fixed in ForwardDiff

@generated function ForwardDiff.construct_seeds(::Type{Partials{N,V}}) where {N,V<:Measurement}
return Expr(:tuple, [:(single_seed(Partials{N,V}, Val{$i}())) for i in 1:N]...)

Check warning on line 12 in ext/MeasurementsForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MeasurementsForwardDiffExt.jl#L11-L12

Added lines #L11 - L12 were not covered by tests
end

#needs redefinition here, because generated functions don't allow extra definitions
@generated function single_seed(::Type{Partials{N,V}}, ::Val{i}) where {N,V,i}
ex = Expr(:tuple, [ifelse(i === j, :(oneunit(V)), :(zero(V))) for j in 1:N]...)
return :(Partials($(ex)))

Check warning on line 18 in ext/MeasurementsForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MeasurementsForwardDiffExt.jl#L16-L18

Added lines #L16 - L18 were not covered by tests
end

function promote_rule(::Type{Measurement{V}}, ::Type{Dual{T, V, N}}) where {T,V,N}
Dual{Measurement{T}, V, N}

Check warning on line 22 in ext/MeasurementsForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MeasurementsForwardDiffExt.jl#L21-L22

Added lines #L21 - L22 were not covered by tests
end

function promote_rule(::Type{Measurement{V1}}, ::Type{Dual{T, V2, N}}) where {V1<:AbstractFloat, T, V2, N}
Vx = promote_rule(Measurement{V1},V2)
return Dual{T , Vx, N}

Check warning on line 27 in ext/MeasurementsForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MeasurementsForwardDiffExt.jl#L25-L27

Added lines #L25 - L27 were not covered by tests
end

function overload_ambiguous_binary(M,f)
Mf = :($M.$f)
return quote
@inline function $Mf(x::Dual{Tx}, y::Measurement) where {Tx}
∂y = Dual{Tx}(y)
$Mf(x,∂y)
end

@inline function $Mf(x::Measurement,y::Dual{Ty}) where {Ty}
∂x = Dual{Ty}(x)
$Mf(∂x,y)
end
end
end

macro define_ternary_dual_op2(f, xyz_body, xy_body, xz_body, yz_body, x_body, y_body, z_body)
FD = ForwardDiff
R = Measurement
defs = quote
Comment thread
longemen3000 marked this conversation as resolved.
@inline $(f)(x::$FD.Dual{Txy}, y::$FD.Dual{Txy}, z::$R) where {Txy} = $xy_body
@inline $(f)(x::$FD.Dual{Tx}, y::$FD.Dual{Ty}, z::$R) where {Tx, Ty} = Ty ≺ Tx ? $x_body : $y_body
@inline $(f)(x::$FD.Dual{Txz}, y::$R, z::$FD.Dual{Txz}) where {Txz} = $xz_body
@inline $(f)(x::$FD.Dual{Tx}, y::$R, z::$FD.Dual{Tz}) where {Tx,Tz} = Tz ≺ Tx ? $x_body : $z_body
@inline $(f)(x::$R, y::$FD.Dual{Tyz}, z::$FD.Dual{Tyz}) where {Tyz} = $yz_body
@inline $(f)(x::$R, y::$FD.Dual{Ty}, z::$FD.Dual{Tz}) where {Ty,Tz} = Tz ≺ Ty ? $y_body : $z_body

Check warning on line 54 in ext/MeasurementsForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MeasurementsForwardDiffExt.jl#L49-L54

Added lines #L49 - L54 were not covered by tests
end
for Q in AMBIGUOUS_TYPES
expr = quote
@inline $(f)(x::$FD.Dual{Tx}, y::$R, z::$Q) where {Tx} = $x_body
@inline $(f)(x::$R, y::$FD.Dual{Ty}, z::$Q) where {Ty} = $y_body
@inline $(f)(x::$R, y::$Q, z::$FD.Dual{Tz}) where {Tz} = $z_body

Check warning on line 60 in ext/MeasurementsForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MeasurementsForwardDiffExt.jl#L60

Added line #L60 was not covered by tests
end
append!(defs.args, expr.args)
end
expr = quote
@inline $(f)(x::$FD.Dual{Tx}, y::$R, z::$R) where {Tx} = $x_body
@inline $(f)(x::$R, y::$FD.Dual{Ty}, z::$R) where {Ty} = $y_body
@inline $(f)(x::$R, y::$R, z::$FD.Dual{Tz}) where {Tz} = $z_body

Check warning on line 67 in ext/MeasurementsForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MeasurementsForwardDiffExt.jl#L65-L67

Added lines #L65 - L67 were not covered by tests
end
append!(defs.args, expr.args)
return esc(defs)
end

#use DiffRules.jl rules

for (M, f, arity) in DiffRules.diffrules(filter_modules = nothing)
if (M, f) in ((:Base, :^), (:NaNMath, :pow))
continue # Skip methods which we define elsewhere.
elseif !(isdefined(@__MODULE__, M) && isdefined(getfield(@__MODULE__, M), f))
continue # Skip rules for methods not defined in the current scope

Check warning on line 79 in ext/MeasurementsForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MeasurementsForwardDiffExt.jl#L79

Added line #L79 was not covered by tests
end
if arity == 2
eval(overload_ambiguous_binary(M,f))
else

Check warning on line 83 in ext/MeasurementsForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MeasurementsForwardDiffExt.jl#L83

Added line #L83 was not covered by tests
# error("ForwardDiff currently only knows how to autogenerate Dual definitions for unary and binary functions.")
# However, the presence of N-ary rules need not cause any problems here, they can simply be ignored.
end
end

#ternary overloads
@define_ternary_dual_op2(
Base.hypot,
ForwardDiff.calc_hypot(x, y, z, Txyz),
ForwardDiff.calc_hypot(x, y, z, Txy),
ForwardDiff.calc_hypot(x, y, z, Txz),
ForwardDiff.calc_hypot(x, y, z, Tyz),
ForwardDiff.calc_hypot(x, y, z, Tx),
ForwardDiff.calc_hypot(x, y, z, Ty),
ForwardDiff.calc_hypot(x, y, z, Tz),
)

@define_ternary_dual_op2(
Base.fma,
ForwardDiff.calc_fma_xyz(x, y, z), # xyz_body
ForwardDiff.calc_fma_xy(x, y, z), # xy_body
ForwardDiff.calc_fma_xz(x, y, z), # xz_body
Base.fma(y, x, z), # yz_body
Dual{Tx}(Base.fma(value(x), y, z), partials(x) * y), # x_body
Base.fma(y, x, z), # y_body
Dual{Tz}(Base.fma(x, y, value(z)), partials(z)) # z_body
)

@define_ternary_dual_op2(
Base.muladd,
ForwardDiff.calc_muladd_xyz(x, y, z), # xyz_body
ForwardDiff.calc_muladd_xy(x, y, z), # xy_body
ForwardDiff.calc_muladd_xz(x, y, z), # xz_body
Base.muladd(y, x, z), # yz_body
Dual{Tx}(Base.muladd(value(x), y, z), partials(x) * y), # x_body
Base.muladd(y, x, z), # y_body
Dual{Tz}(Base.muladd(x, y, value(z)), partials(z)) # z_body
)

end #module
19 changes: 19 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Measurements, SpecialFunctions, QuadGK, Calculus, BaseType, Makie
using Test, LinearAlgebra, Statistics, Unitful, Printf, Aqua
using ForwardDiff

Aqua.test_all(Measurements)

Expand Down Expand Up @@ -1067,3 +1068,21 @@ end
@test base_numeric_type(typeof(x)) == T
end
end


fd_f1(x,y) = x+y
fd_f2(x,y) = x-y
fd_f3(x,y) = x*y
fd_f4(x,y) = x/y
fd_f5(x,y) = muladd(x,y,1)

@testset "ForwardDiff" begin
x1 = 1.0 ± 0.1
y1 = 2.0 ± 0.001
for op in (:fd_f1,:fd_f2,:fd_f3,:fd_f4,:fd_f5)
@eval begin
@test ForwardDiff.derivative(x->$(op)(x,$y1),$x1) isa Measurement
@test ForwardDiff.derivative(y->$(op)($x1,y),$y1) isa Measurement
end
Comment thread
longemen3000 marked this conversation as resolved.
Outdated
end
end
Loading