Skip to content

Commit

Permalink
improve latency of matrix-exp
Browse files Browse the repository at this point in the history
  • Loading branch information
thchr committed Feb 28, 2022
1 parent df49828 commit 76f9157
Showing 1 changed file with 47 additions and 46 deletions.
93 changes: 47 additions & 46 deletions src/expm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,53 +77,54 @@ function _exp(::Size, _A::StaticMatrix{<:Any,<:Any,T}) where T
A = S.(_A)
# omitted: matrix balancing, i.e., LAPACK.gebal!
nA = maximum(sum(abs.(A); dims=Val(1))) # marginally more performant than norm(A, 1)
## For sufficiently small nA, use lower order Padé-Approximations
if (nA <= 2.1)
A2 = A*A
if nA > 0.95
U = @evalpoly(A2, S(8821612800)*I, S(302702400)*I, S(2162160)*I, S(3960)*I, S(1)*I)
U = A*U
V = @evalpoly(A2, S(17643225600)*I, S(2075673600)*I, S(30270240)*I, S(110880)*I, S(90)*I)
elseif nA > 0.25
U = @evalpoly(A2, S(8648640)*I, S(277200)*I, S(1512)*I, S(1)*I)
U = A*U
V = @evalpoly(A2, S(17297280)*I, S(1995840)*I, S(25200)*I, S(56)*I)
elseif nA > 0.015
U = @evalpoly(A2, S(15120)*I, S(420)*I, S(1)*I)
U = A*U
V = @evalpoly(A2, S(30240)*I, S(3360)*I, S(30)*I)
else
U = @evalpoly(A2, S(60)*I, S(1)*I)
U = A*U
V = @evalpoly(A2, S(120)*I, S(12)*I)
end
expA = (V - U) \ (V + U)

if (nA 2.1) # for sufficiently small nA, use lower order Padé-Approximations
return _pade_exp(S, A, nA)
else
s = log2(nA/5.4) # power of 2 later reversed by squaring
if s > 0
si = ceil(Int,s)
A = A / S(2^si)
end

A2 = A*A
A4 = A2*A2
A6 = A2*A4

U = A6*(S(1)*A6 + S(16380)*A4 + S(40840800)*A2) +
(S(33522128640)*A6 + S(10559470521600)*A4 + S(1187353796428800)*A2) +
S(32382376266240000)*I
U = A*U
V = A6*(S(182)*A6 + S(960960)*A4 + S(1323241920)*A2) +
(S(670442572800)*A6 + S(129060195264000)*A4 + S(7771770303897600)*A2) +
S(64764752532480000)*I
expA = (V - U) \ (V + U)

if s > 0 # squaring to reverse dividing by power of 2
for t=1:si
expA = expA*expA
end
end
return _rescaled_exp(S, A, nA)
end
end

expA
function _pade_exp(S, A, nA)
A2 = A*A
U, V = if nA > 0.95
@evalpoly(A2, S(8821612800)*I, S(302702400)*I, S(2162160)*I, S(3960)*I, S(1)*I),
@evalpoly(A2, S(17643225600)*I, S(2075673600)*I, S(30270240)*I, S(110880)*I, S(90)*I)
elseif nA > 0.25
@evalpoly(A2, S(8648640)*I, S(277200)*I, S(1512)*I, S(1)*I),
@evalpoly(A2, S(17297280)*I, S(1995840)*I, S(25200)*I, S(56)*I)
elseif nA > 0.015
@evalpoly(A2, S(15120)*I, S(420)*I, S(1)*I),
@evalpoly(A2, S(30240)*I, S(3360)*I, S(30)*I)
else
@evalpoly(A2, S(60)*I, S(1)*I),
@evalpoly(A2, S(120)*I, S(12)*I)
end
U = A*U
return (V - U) \ (V + U)
end

function _rescaled_exp(S, A, nA)
si = ceil(Int, log2(nA/5.4)) # power of 2 later reversed by squaring
if si > 0
A /= S(2^si)
end

A2 = A*A
A4 = A2*A2
A6 = A2*A4

U = A6*(S(1)*A6 + S(16380)*A4 + S(40840800)*A2) +
(S(33522128640)*A6 + S(10559470521600)*A4 + S(1187353796428800)*A2) +
S(32382376266240000)*I
U = A*U
V = A6*(S(182)*A6 + S(960960)*A4 + S(1323241920)*A2) +
(S(670442572800)*A6 + S(129060195264000)*A4 + S(7771770303897600)*A2) +
S(64764752532480000)*I
expA = (V - U) \ (V + U)

for _ in 1:si # squaring to reverse dividing by power of 2
expA *= expA
end
return expA
end

0 comments on commit 76f9157

Please sign in to comment.