From 884422b4700a40c0c087da1670c3c19d9324574e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 4 Oct 2024 17:01:37 -0400 Subject: [PATCH] refactor: move SparseArrays into an extension --- Project.toml | 9 +++--- ext/FiniteDiffSparseArraysExt.jl | 55 ++++++++++++++++++++++++++++++++ src/FiniteDiff.jl | 2 +- src/iteration_utils.jl | 26 --------------- src/jacobians.jl | 46 +++++++------------------- 5 files changed, 72 insertions(+), 66 deletions(-) create mode 100644 ext/FiniteDiffSparseArraysExt.jl diff --git a/Project.toml b/Project.toml index 37300ef..7947fad 100644 --- a/Project.toml +++ b/Project.toml @@ -1,22 +1,22 @@ name = "FiniteDiff" uuid = "6a86dc24-6348-571c-b903-95158fe2bd41" -version = "2.24.0" +version = "2.25.0" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" -SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [weakdeps] BandedMatrices = "aae01518-5342-5314-be14-df237901396f" BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [extensions] FiniteDiffBandedMatricesExt = "BandedMatrices" FiniteDiffBlockBandedMatricesExt = "BlockBandedMatrices" +FiniteDiffSparseArraysExt = "SparseArrays" FiniteDiffStaticArraysExt = "StaticArrays" [compat] @@ -32,8 +32,9 @@ BandedMatrices = "aae01518-5342-5314-be14-df237901396f" BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "BlockBandedMatrices", "BandedMatrices", "Pkg", "SafeTestsets", "StaticArrays"] +test = ["Test", "BlockBandedMatrices", "BandedMatrices", "Pkg", "SafeTestsets", "SparseArrays", "StaticArrays"] diff --git a/ext/FiniteDiffSparseArraysExt.jl b/ext/FiniteDiffSparseArraysExt.jl new file mode 100644 index 0000000..99913a7 --- /dev/null +++ b/ext/FiniteDiffSparseArraysExt.jl @@ -0,0 +1,55 @@ +module FiniteDiffSparseArraysExt + +using SparseArrays +using FiniteDiff + +# jacobians.jl +function FiniteDiff._make_Ji(::SparseMatrixCSC, rows_index, cols_index, dx, colorvec, color_i, nrows, ncols) + pick_inds = [i for i in 1:length(rows_index) if colorvec[cols_index[i]] == color_i] + rows_index_c = rows_index[pick_inds] + cols_index_c = cols_index[pick_inds] + Ji = sparse(rows_index_c, cols_index_c, dx[rows_index_c], nrows, ncols) + Ji +end + +function FiniteDiff._make_Ji(::SparseMatrixCSC, xtype, dx, color_i, nrows, ncols) + Ji = sparse(1:nrows, fill(color_i, nrows), dx, nrows, ncols) + Ji +end + +@inline function FiniteDiff._colorediteration!(J, sparsity::SparseMatrixCSC, rows_index, cols_index, vfx, colorvec, color_i, ncols) + @inbounds for col_index in 1:ncols + if colorvec[col_index] == color_i + @inbounds for row_index in view(sparsity.rowval, sparsity.colptr[col_index]:sparsity.colptr[col_index+1]-1) + J[row_index, col_index] = vfx[row_index] + end + end + end +end + +@inline FiniteDiff.fill_matrix!(J::AbstractSparseMatrix, v) = fill!(nonzeros(J), v) + +@inline function FiniteDiff.fast_jacobian_setindex!(J::AbstractSparseMatrix, rows_index, cols_index, _color, color_i, vfx) + @. FiniteDiff.void_setindex!((J.nzval,), getindex((J.nzval,), rows_index) + (getindex((_color,), cols_index) == color_i) * getindex((vfx,), rows_index), rows_index) +end + +# iteration_utils.jl +## fast version for the case where J and sparsity have the same sparsity pattern +@inline function FiniteDiff._colorediteration!(Jsparsity::SparseMatrixCSC, vfx, colorvec, color_i, ncols) + @inbounds for col_index in 1:ncols + if colorvec[col_index] == color_i + @inbounds for spidx in nzrange(Jsparsity, col_index) + row_index = Jsparsity.rowval[spidx] + Jsparsity.nzval[spidx] = vfx[row_index] + end + end + end +end + +FiniteDiff._use_findstructralnz(::SparseMatrixCSC) = false + +FiniteDiff._use_sparseCSC_common_sparsity(J::SparseMatrixCSC, sparsity::SparseMatrixCSC) = + ((J.colptr == sparsity.colptr) && (J.rowval == sparsity.rowval)) + + +end diff --git a/src/FiniteDiff.jl b/src/FiniteDiff.jl index d873881..bc05843 100644 --- a/src/FiniteDiff.jl +++ b/src/FiniteDiff.jl @@ -5,7 +5,7 @@ Fast non-allocating calculations of gradients, Jacobians, and Hessians with spar """ module FiniteDiff -using LinearAlgebra, SparseArrays, ArrayInterface +using LinearAlgebra, ArrayInterface import Base: resize! diff --git a/src/iteration_utils.jl b/src/iteration_utils.jl index 4d21050..55ab7e3 100644 --- a/src/iteration_utils.jl +++ b/src/iteration_utils.jl @@ -6,34 +6,8 @@ end end -@inline function _colorediteration!(J,sparsity::SparseMatrixCSC,rows_index,cols_index,vfx,colorvec,color_i,ncols) - @inbounds for col_index in 1:ncols - if colorvec[col_index] == color_i - @inbounds for row_index in view(sparsity.rowval,sparsity.colptr[col_index]:sparsity.colptr[col_index+1]-1) - J[row_index,col_index]=vfx[row_index] - end - end - end -end - -# fast version for the case where J and sparsity have the same sparsity pattern -@inline function _colorediteration!(Jsparsity::SparseMatrixCSC,vfx,colorvec,color_i,ncols) - @inbounds for col_index in 1:ncols - if colorvec[col_index] == color_i - @inbounds for spidx in nzrange(Jsparsity, col_index) - row_index = Jsparsity.rowval[spidx] - Jsparsity.nzval[spidx]=vfx[row_index] - end - end - end -end - #override default setting of using findstructralnz _use_findstructralnz(sparsity) = ArrayInterface.has_sparsestruct(sparsity) -_use_findstructralnz(::SparseMatrixCSC) = false # test if J, sparsity are both SparseMatrixCSC and have the same sparsity pattern of stored values _use_sparseCSC_common_sparsity(J, sparsity) = false -_use_sparseCSC_common_sparsity(J::SparseMatrixCSC, sparsity::SparseMatrixCSC) = - ((J.colptr == sparsity.colptr) && (J.rowval == sparsity.rowval)) - diff --git a/src/jacobians.jl b/src/jacobians.jl index 95085c4..4e59559 100644 --- a/src/jacobians.jl +++ b/src/jacobians.jl @@ -125,14 +125,6 @@ function JacobianCache( JacobianCache{typeof(_x1),typeof(_x2),typeof(_fx),typeof(fx1),typeof(colorvec),typeof(sparsity),fdtype,returntype}(_x1,_x2,_fx,fx1,colorvec,sparsity) end -function _make_Ji(::SparseMatrixCSC, rows_index,cols_index,dx,colorvec,color_i,nrows,ncols) - pick_inds = [i for i in 1:length(rows_index) if colorvec[cols_index[i]] == color_i] - rows_index_c = rows_index[pick_inds] - cols_index_c = cols_index[pick_inds] - Ji = sparse(rows_index_c, cols_index_c, dx[rows_index_c],nrows,ncols) - Ji -end - function _make_Ji(::AbstractArray, rows_index,cols_index,dx,colorvec,color_i,nrows,ncols) pick_inds = [i for i in 1:length(rows_index) if colorvec[cols_index[i]] == color_i] rows_index_c = rows_index[pick_inds] @@ -145,12 +137,6 @@ function _make_Ji(::AbstractArray, rows_index,cols_index,dx,colorvec,color_i,nro Ji end -function _make_Ji(::SparseMatrixCSC, xtype, dx, color_i, nrows, ncols) - Ji = sparse(1:nrows,fill(color_i,nrows),dx,nrows,ncols) - Ji -end - - function _make_Ji(::AbstractArray, xtype, dx, color_i, nrows, ncols) Ji = mapreduce(i -> i==color_i ? dx : zero(dx), hcat, 1:ncols) size(Ji) != (nrows, ncols) ? reshape(Ji, (nrows, ncols)) : Ji #branch when size(dx) == (1,) => size(Ji) == (1,) while size(J) == (1,1) @@ -445,11 +431,7 @@ function finite_difference_jacobian!( end if sparsity !== nothing - if J isa AbstractSparseMatrix - fill!(nonzeros(J),false) - else - fill!(J,false) - end + fill_matrix!(J, false) end # fast path if J and sparsity are both AbstractSparseMatrix and have the same sparsity pattern @@ -497,11 +479,7 @@ function finite_difference_jacobian!( J[rows_index, cols_index] .+= (colorvec[cols_index] .== color_i) .* vfx1[rows_index] += means requires a zero'd out start =# - if J isa AbstractSparseMatrix - @. void_setindex!((J.nzval,), getindex((J.nzval,), rows_index) + (getindex((_color,), cols_index) == color_i) * getindex((vfx1,), rows_index), rows_index) - else - @. void_setindex!((J,), getindex((J,), rows_index, cols_index) + (getindex((_color,), cols_index) == color_i) * getindex((vfx1,), rows_index), rows_index, cols_index) - end + fast_jacobian_setindex!(J, rows_index, cols_index, _color, color_i, vfx1) end # Now return x1 back to its original value @. x1 = x1 - epsilon * (_color == color_i) @@ -535,11 +513,7 @@ function finite_difference_jacobian!( _colorediteration!(J,sparsity,rows_index,cols_index,vfx1,colorvec,color_i,n) end else - if J isa AbstractSparseMatrix - @. void_setindex!((J.nzval,), getindex((J.nzval,), rows_index) + (getindex((_color,), cols_index) == color_i) * getindex((vfx1,), rows_index), rows_index) - else - @. void_setindex!((J,), getindex((J,), rows_index, cols_index) + (getindex((_color,), cols_index) == color_i) * getindex((vfx1,), rows_index), rows_index, cols_index) - end + fast_jacobian_setindex!(J, rows_index, cols_index, _color, color_i, vfx1) end @. x1 = x1 - epsilon * (_color == color_i) @. x = x + epsilon * (_color == color_i) @@ -565,11 +539,7 @@ function finite_difference_jacobian!( _colorediteration!(J,sparsity,rows_index,cols_index,vfx,colorvec,color_i,n) end else - if J isa AbstractSparseMatrix - @. void_setindex!((J.nzval,), getindex((J.nzval,), rows_index) + (getindex((_color,), cols_index) == color_i) * getindex((vfx,),rows_index), rows_index) - else - @. void_setindex!((J,), getindex((J,), rows_index, cols_index) + (getindex((_color,), cols_index) == color_i) * getindex((vfx,), rows_index), rows_index, cols_index) - end + fast_jacobian_setindex!(J, rows_index, cols_index, _color, color_i, vfx) end @. x1 = x1 - im * epsilon * (_color == color_i) end @@ -583,7 +553,13 @@ end function resize!(cache::JacobianCache, i::Int) resize!(cache.x1, i) resize!(cache.fx, i) - cache.fx1 != nothing && resize!(cache.fx1, i) + cache.fx1 !== nothing && resize!(cache.fx1, i) cache.colorvec = 1:i nothing end + +@inline fill_matrix!(J, v) = fill!(J, v) + +@inline function fast_jacobian_setindex!(J, rows_index, cols_index, _color, color_i, vfx) + @. void_setindex!((J,), getindex((J,), rows_index, cols_index) + (getindex((_color,), cols_index) == color_i) * getindex((vfx,), rows_index), rows_index, cols_index) +end \ No newline at end of file