Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Faster Matrix{BlasFloat} * or \ VecOrMatrix{Dual} #589

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[weakdeps]
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Expand All @@ -26,7 +26,7 @@ Calculus = "0.5"
CommonSubexpressions = "0.3"
DiffResults = "1.1"
DiffRules = "1.4"
DiffTests = "0.1"
DiffTests = "0.1.3"
LogExpFunctions = "0.3"
NaNMath = "1"
Preferences = "1"
Expand Down
2 changes: 2 additions & 0 deletions src/ForwardDiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ using DiffResults: DiffResult, MutableDiffResult
using Preferences
using Random
using LinearAlgebra
using SparseArrays
using Base: require_one_based_indexing

import Printf
import NaNMath
import SpecialFunctions
Expand Down
85 changes: 85 additions & 0 deletions src/dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,91 @@ function SpecialFunctions.gamma_inc(a::Real, d::Dual{T,<:Real}, ind::Integer) wh
return (Dual{T}(p, ∂p), Dual{T}(q, -∂p))
end

# Efficient left multiplication/division of #
# Dual array by a constant matrix #
#-------------------------------------------#
# Writes into `y` the result of applying `fvalue!(values(y), values(x))` to the
# value components of the Dual array `x`, and
# `fpartial!(partials(y, i), partials(x, i), i)` to each i-th partials component;
# returns `y`. `fvalue!`/`fpartial!` receive plain `AbstractArray{T}` buffers, so
# one linear-operator kernel call handles each component instead of per-element
# Dual arithmetic (valid only for operations that are linear in `x`).
function _map_dual_components!(fvalue!, fpartial!, y::AbstractArray{DT}, x::AbstractArray{DT}) where DT <: Dual{<:Any, T} where T
    N = npartials(DT)
    # temporary value-typed buffers: tx holds one component of x, ty the result
    tx = similar(x, T)
    ty = similar(y, T) # temporary Array{T} for fvalue!/fpartial! application
    # yarr views y as a plain array of T; component c of element j lives at
    # yarr[c + (j - 1) * ystride]
    yarr = reinterpret(reshape, T, y)
    @assert size(yarr) == (N + 1, size(y)...)
    ystride = size(yarr, 1) # stride between the same component of consecutive elements

    # calculate the value components of the result
    @inbounds for (j, v) in enumerate(x)
        tx[j] = value(v)
    end
    fvalue!(ty, tx)
    k = 1 # the value component is stored first within each element
    @inbounds for tt in ty
        yarr[k] = tt
        k += ystride
    end

    # calculate each partials component of the result
    for i in 1:N
        @inbounds for (j, v) in enumerate(x)
            tx[j] = partials(v, i)
        end
        fpartial!(ty, tx, i)
        k = i + 1 # the i-th partial is stored at offset i within each element
        @inbounds for tt in ty
            yarr[k] = tt
            k += ystride
        end
    end

    return y
end

# use ldiv!() for matrices of normal numbers to
# implement ldiv!() of dual vector by a matrix:
# the Dual vectors are reinterpreted as (npartials + 1) × length(x) matrices of
# the value type and adjointed so that each column of the right-hand side is one
# value/partials component; a single BlasFloat solve then handles all components.
# This relies on m \ x being linear in x, so the solve commutes with taking
# value/partials components.
LinearAlgebra.ldiv!(y::StridedVector{T},
                    m::Union{LowerTriangular{<:LinearAlgebra.BlasFloat},
                             UpperTriangular{<:LinearAlgebra.BlasFloat},
                             SparseMatrixCSC{<:LinearAlgebra.BlasFloat}},
                    x::StridedVector{T}) where T <: Dual =
    (ldiv!(reinterpret(reshape, valtype(T), y)', m, reinterpret(reshape, valtype(T), x)'); y)

# Out-of-place left division of a Dual vector by a constant triangular/sparse
# BlasFloat matrix: allocate the result vector and delegate to the specialized
# in-place ldiv! defined for these matrix types.
function Base.:\(m::Union{LowerTriangular{<:LinearAlgebra.BlasFloat},
                          UpperTriangular{<:LinearAlgebra.BlasFloat},
                          SparseMatrixCSC{<:LinearAlgebra.BlasFloat}},
                 x::StridedVector{<:Dual})
    result = similar(x)
    return ldiv!(result, m, x)
end

# Generate the "constant BlasFloat matrix times / left-divides Dual array"
# specializations for each supported matrix type. The generated methods reuse
# plain-number mul!/ldiv! kernels on the value and partials components.
for MT in (StridedMatrix{<:LinearAlgebra.BlasFloat},
           LowerTriangular{<:LinearAlgebra.BlasFloat},
           UpperTriangular{<:LinearAlgebra.BlasFloat},
           SparseMatrixCSC{<:LinearAlgebra.BlasFloat},
           )
    @eval begin

        # in-place m \ x for Dual matrices: solve the values and each partials
        # component separately (the closures capture the method argument `m`)
        LinearAlgebra.ldiv!(y::StridedMatrix{T}, m::$MT, x::StridedMatrix{T}) where T <: Dual =
            _map_dual_components!((y, x) -> ldiv!(y, m, x), (y, x, _) -> ldiv!(y, m, x), y, x)

        Base.:\(m::$MT, x::StridedMatrix{<:Dual}) = ldiv!(similar(x), m, x)

        # in-place m * x for a Dual vector: with X = reinterpret(x) of size
        # (npartials + 1) × length(x), computing X * m' multiplies every
        # component row by m at once and writes it into reinterpret(y)
        LinearAlgebra.mul!(y::StridedVector{T}, m::$MT, x::StridedVector{T}) where T <: Dual =
            (mul!(reinterpret(reshape, valtype(T), y), reinterpret(reshape, valtype(T), x), m'); y)

        # 5-argument mul! (y = α * m * x + β * y), same reinterpretation trick
        LinearAlgebra.mul!(y::StridedVector{T}, m::$MT, x::StridedVector{T},
                           α::Union{LinearAlgebra.BlasFloat, Integer},
                           β::Union{LinearAlgebra.BlasFloat, Integer}) where T <: Dual =
            (mul!(reinterpret(reshape, valtype(T), y), reinterpret(reshape, valtype(T), x), m', α, β); y)

        Base.:*(m::$MT, x::StridedVector{<:Dual}) = mul!(similar(x, (size(m, 1),)), m, x)

        # in-place m * x for Dual matrices: multiply the values and each
        # partials component separately
        LinearAlgebra.mul!(y::StridedMatrix{T}, m::$MT, x::StridedMatrix{T}) where T <: Dual =
            _map_dual_components!((y, x) -> mul!(y, m, x), (y, x, _) -> mul!(y, m, x), y, x)

        Base.:*(m::$MT, x::StridedMatrix{<:Dual}) = mul!(similar(x, (size(m, 1), size(x, 2))), m, x)

    end
end

###################
# Pretty Printing #
###################
Expand Down
13 changes: 7 additions & 6 deletions test/DerivativeTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ Random.seed!(1)

const x = 1

for f in DiffTests.NUMBER_TO_NUMBER_FUNCS
println(" ...testing $f")
@testset "Derivative test vs Calculus.jl" begin

@testset "$f(x::Number)::Number" for f in DiffTests.NUMBER_TO_NUMBER_FUNCS
v = f(x)
d = ForwardDiff.derivative(f, x)
@test isapprox(d, Calculus.derivative(f, x), atol=FINITEDIFF_ERROR)
Expand All @@ -29,8 +30,7 @@ for f in DiffTests.NUMBER_TO_NUMBER_FUNCS
@test isapprox(DiffResults.derivative(out), d)
end

for f in DiffTests.NUMBER_TO_ARRAY_FUNCS
println(" ...testing $f")
@testset "$f(x::Number)::Array" for f in DiffTests.NUMBER_TO_ARRAY_FUNCS
v = f(x)
d = ForwardDiff.derivative(f, x)

Expand All @@ -47,8 +47,7 @@ for f in DiffTests.NUMBER_TO_ARRAY_FUNCS
@test isapprox(DiffResults.derivative(out), d)
end

for f! in DiffTests.INPLACE_NUMBER_TO_ARRAY_FUNCS
println(" ...testing $f!")
@testset "$f!(y::Vector, x::Number)" for f! in DiffTests.INPLACE_NUMBER_TO_ARRAY_FUNCS
m, n = 3, 2
y = fill(0.0, m, n)
f = x -> (tmp = similar(y, promote_type(eltype(y), typeof(x)), m, n); f!(tmp, x); tmp)
Expand Down Expand Up @@ -89,6 +88,8 @@ for f! in DiffTests.INPLACE_NUMBER_TO_ARRAY_FUNCS
@test isapprox(DiffResults.derivative(out), d)
end

end

@testset "exponential function at base zero" begin
@test (x -> ForwardDiff.derivative(y -> x^y, -0.5))(0.0) === -Inf
@test (x -> ForwardDiff.derivative(y -> x^y, 0.0))(0.0) === -Inf
Expand Down
4 changes: 1 addition & 3 deletions test/DualTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ ForwardDiff.:≺(::Type{TestTag}, ::Type{OuterTestTag}) = true
ForwardDiff.:≺(::Type{OuterTestTag}, ::Type{TestTag}) = false

@testset "Dual{Z,$V,$N} and Dual{Z,Dual{Z,$V,$M},$N}" for N in (0,3), M in (0,4), V in (Int, Float32)
println(" ...testing Dual{TestTag(),$V,$N} and Dual{TestTag(),Dual{TestTag(),$V,$M},$N}")

PARTIALS = Partials{N,V}(ntuple(n -> intrand(V), N))
PRIMAL = intrand(V)
Expand Down Expand Up @@ -466,13 +465,12 @@ ForwardDiff.:≺(::Type{OuterTestTag}, ::Type{TestTag}) = false
@test abs(NESTED_FDNUM) === NESTED_FDNUM

if V != Int
@testset "$f" for (M, f, arity) in DiffRules.diffrules(filter_modules = nothing)
@testset "auto-testing $(M).$(f) with $arity arguments" for (M, f, arity) in DiffRules.diffrules(filter_modules = nothing)
if f in (:/, :rem2pi)
continue # Skip these rules
elseif !(isdefined(@__MODULE__, M) && isdefined(getfield(@__MODULE__, M), f))
continue # Skip rules for methods not defined in the current scope
end
println(" ...auto-testing $(M).$(f) with $arity arguments")
if arity == 1
deriv = DiffRules.diffrule(M, f, :x)
modifier = if in(f, (:asec, :acsc, :asecd, :acscd, :acosh, :acoth))
Expand Down
19 changes: 11 additions & 8 deletions test/GradientTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@ include(joinpath(dirname(@__FILE__), "utils.jl"))
# hardcoded test #
##################

@testset "hardcoded tests" begin

f = DiffTests.rosenbrock_1
x = [0.1, 0.2, 0.3]
v = f(x)
g = [-9.4, 15.6, 52.0]

@testset "Rosenbrock, chunk size = $c and tag = $(repr(tag))" for c in (1, 2, 3), tag in (nothing, Tag(f, eltype(x)))
println(" ...running hardcoded test with chunk size = $c and tag = $(repr(tag))")
cfg = ForwardDiff.GradientConfig(f, x, ForwardDiff.Chunk{c}(), tag)

@test eltype(cfg) == Dual{typeof(tag), eltype(x), c}
Expand Down Expand Up @@ -51,17 +52,19 @@ cfgx = ForwardDiff.GradientConfig(sin, x)
@test_throws ForwardDiff.InvalidTagException ForwardDiff.gradient(f, x, cfgx)
@test ForwardDiff.gradient(f, x, cfgx, Val{false}()) == ForwardDiff.gradient(f,x)

end

########################
# test vs. Calculus.jl #
########################
@testset "Comparison vs Calculus.jl" begin

@testset "$f" for f in DiffTests.VECTOR_TO_NUMBER_FUNCS
@testset "$f(x::Vector)::Number" for f in DiffTests.VECTOR_TO_NUMBER_FUNCS
v = f(X)
g = ForwardDiff.gradient(f, X)
@test isapprox(g, Calculus.gradient(f, X), atol=FINITEDIFF_ERROR)
for c in CHUNK_SIZES, tag in (nothing, Tag(f, eltype(x)))
println(" ...testing $f with chunk size = $c and tag = $(repr(tag))")

@testset "chunk size = $c and tag = $(repr(tag))" for c in CHUNK_SIZES, tag in (nothing, Tag(f, eltype(X)))
cfg = ForwardDiff.GradientConfig(f, X, ForwardDiff.Chunk{c}(), tag)

out = ForwardDiff.gradient(f, X, cfg)
Expand All @@ -78,15 +81,15 @@ cfgx = ForwardDiff.GradientConfig(sin, x)
end
end

end

##########################################
# test specialized StaticArray codepaths #
##########################################

println(" ...testing specialized StaticArray codepaths")

@testset "$T" for T in (StaticArrays.SArray, StaticArrays.MArray)
x = rand(3, 3)
x = rand(3, 3)

@testset "Specialized $T codepaths" for T in (StaticArrays.SArray, StaticArrays.MArray)
sx = T{Tuple{3,3}}(x)

cfg = ForwardDiff.GradientConfig(nothing, x)
Expand Down
21 changes: 13 additions & 8 deletions test/HessianTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ h = [-66.0 -40.0 0.0;
-40.0 130.0 -80.0;
0.0 -80.0 200.0]

for c in HESSIAN_CHUNK_SIZES, tag in (nothing, Tag((f,ForwardDiff.gradient), eltype(x)))
println(" ...running hardcoded test with chunk size = $c and tag = $(repr(tag))")
@testset "hardcoded" begin

@testset "chunk size = $c and tag = $(repr(tag))" for c in HESSIAN_CHUNK_SIZES, tag in (nothing, Tag((f,ForwardDiff.gradient), eltype(x)))
cfg = ForwardDiff.HessianConfig(f, x, ForwardDiff.Chunk{c}(), tag)
resultcfg = ForwardDiff.HessianConfig(f, DiffResults.HessianResult(x), x, ForwardDiff.Chunk{c}(), tag)

Expand Down Expand Up @@ -54,6 +55,8 @@ for c in HESSIAN_CHUNK_SIZES, tag in (nothing, Tag((f,ForwardDiff.gradient), elt
@test isapprox(DiffResults.hessian(out), h)
end

end

cfgx = ForwardDiff.HessianConfig(sin, x)
@test_throws ForwardDiff.InvalidTagException ForwardDiff.hessian(f, x, cfgx)
@test ForwardDiff.hessian(f, x, cfgx, Val{false}()) == ForwardDiff.hessian(f,x)
Expand All @@ -63,14 +66,16 @@ cfgx = ForwardDiff.HessianConfig(sin, x)
# test vs. Calculus.jl #
########################

for f in DiffTests.VECTOR_TO_NUMBER_FUNCS
@testset "Comparison vs Calculus.jl" begin

@testset "$f(x::Vector)::Number" for f in DiffTests.VECTOR_TO_NUMBER_FUNCS
v = f(X)
g = ForwardDiff.gradient(f, X)
h = ForwardDiff.hessian(f, X)
# finite difference approximation error is really bad for Hessians...
@test isapprox(h, Calculus.hessian(f, X), atol=0.02)
for c in HESSIAN_CHUNK_SIZES, tag in (nothing, Tag((f,ForwardDiff.gradient), eltype(x)))
println(" ...testing $f with chunk size = $c and tag = $(repr(tag))")

@testset "chunk size = $c and tag = $(repr(tag))" for c in HESSIAN_CHUNK_SIZES, tag in (nothing, Tag((f,ForwardDiff.gradient), eltype(x)))
cfg = ForwardDiff.HessianConfig(f, X, ForwardDiff.Chunk{c}(), tag)
resultcfg = ForwardDiff.HessianConfig(f, DiffResults.HessianResult(X), X, ForwardDiff.Chunk{c}(), tag)

Expand All @@ -89,14 +94,14 @@ for f in DiffTests.VECTOR_TO_NUMBER_FUNCS
end
end

end

##########################################
# test specialized StaticArray codepaths #
##########################################

println(" ...testing specialized StaticArray codepaths")

x = rand(3, 3)
for T in (StaticArrays.SArray, StaticArrays.MArray)
@testset "Specialized $T codepaths" for T in (StaticArrays.SArray, StaticArrays.MArray)
sx = T{Tuple{3,3}}(x)

cfg = ForwardDiff.HessianConfig(nothing, x)
Expand Down
Loading