Skip to content

Commit

Permalink
Implement vcat(::AbstractBandedMatrix...) (#448)
Browse files Browse the repository at this point in the history
* implement bandwidths for OneElement

* make improvements

* fix sparse(::SparseMatrixCSC)

* fix bandwidths for SparseMatrixCSC, add for SparseVector

* add bandwidths(::Zeros) behaviour for empty sparse structures

* add unit tests

* overload vcat(::AbstractBandedMatrix...)

* style

* include tests in runtests.jl

* fix issue involving LazyBandedMatrices

* fixed mistake

* make improvements

* add vcat between BandedMatrices and OneElements

* fix issue involving calculation of bandwidths. Add unit tests for OneElement

* fix issue involving bandwidths larger than dimensions

* restore vcat

* v1.7.4

---------

Co-authored-by: Sheehan Olver <[email protected]>
  • Loading branch information
max-vassili3v and dlfivefifty authored Sep 4, 2024
1 parent 57a70a5 commit 4ff8ff0
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/BandedMatrices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import ArrayLayouts: AbstractTridiagonalLayout, BidiagonalLayout, BlasMatLdivVec
symmetricuplo, transposelayout, triangulardata, triangularlayout, zero!,
QRPackedQLayout, AdjQRPackedQLayout

import FillArrays: AbstractFill, getindex_value, _broadcasted_zeros, unique_value, OneElement, RectDiagonal, OneElementMatrix, OneElementVector
import FillArrays: AbstractFill, getindex_value, _broadcasted_zeros, unique_value, OneElement, RectDiagonal, OneElementMatrix, OneElementVector, ZerosMatrix, ZerosVector

const libblas = LinearAlgebra.BLAS.libblas
const liblapack = LinearAlgebra.BLAS.liblapack
Expand Down
54 changes: 54 additions & 0 deletions src/generic/AbstractBandedMatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -401,3 +401,57 @@ function sum(A::AbstractBandedMatrix; dims=:)
throw(ArgumentError("dimension must be ≥ 1, got $dims"))
end
end

###
# vcat
###

function LinearAlgebra.vcat(x::AbstractBandedMatrix...)
#avoid unnecessary steps for singleton
if length(x) == 1
return x[1]
end

#instantiate the returned banded matrix with zeros and required bandwidths/dimensions
m = size(x[1], 2)
l,u = -m, typemin(Int64)
n = 0
isempty = true

#Check for dimension error and calculate bandwidths
for A in x
if size(A, 2) != m
sizes = Tuple(size(b, 2) for b in x)
throw(DimensionMismatch("number of columns of each matrix must match (got $sizes)"))
end

l_A, u_A = bandwidths(A)
if l_A + u_A >= 0
isempty = false
u = max(u, min(m - 1, u_A) - n)
l = max(l, min(size(A, 1) - 1, l_A) + n)
end

n += size(A, 1)
end

type = promote_type(eltype.(x)...)
if isempty
return BandedMatrix{type}(undef, (n, m), bandwidths(Zeros(1)))
end
ret = BandedMatrix(Zeros{type}(n, m), (l, u))

#Populate the banded matrix
row_offset = 0
for A in x
n_A = size(A, 1)

for i = 1:n_A, j = rowrange(A, i)
ret[row_offset + i, j] = A[i, j]
end

row_offset += n_A
end

ret
end
3 changes: 3 additions & 0 deletions src/interfaceimpl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,6 @@ function getindex(D::Bidiagonal{T,V}, b::Band) where {T,V}
D.uplo == 'U' && b.i == 1 && return copy(D.ev)
convert(V, Zeros{T}(size(D,1)-abs(b.i)))
end


Base.vcat(x::Union{OneElement, ZerosMatrix, AdjOrTrans{<:Any,<:ZerosVector}, AbstractBandedMatrix}...) = vcat(BandedMatrix.(x)...)
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@ include("test_tribanded.jl")
include("test_interface.jl")
include("test_miscs.jl")
include("test_sum.jl")
include("test_cat.jl")
37 changes: 37 additions & 0 deletions test/test_cat.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
module TestCat

using BandedMatrices, LinearAlgebra, Test, Random, FillArrays, SparseArrays

@testset "vcat" begin
@testset "banded matrices" begin
a = BandedMatrix(0 => 1:2)
@test vcat(a) == a

b = BandedMatrix(0 => 1:3,-1 => 1:2, -2 => 1:1)
@test_throws DimensionMismatch vcat(a,b)

c = BandedMatrix(0 => [1.0, 2.0, 3.0], 1 => [1.0, 2.0], 2 => [1.0])
@test eltype(vcat(b, c)) == Float64
@test vcat(b, c) == vcat(Matrix(b), Matrix(c))

for i in ((1,2), (-3,4), (0,-1))
a = BandedMatrix(ones(Float64, rand(1:10), 5), i)
b = BandedMatrix(ones(Int64, rand(1:10), 5), i)
c = BandedMatrix(ones(Int32, rand(1:10), 5), i)
d = vcat(a, b, c)
sd = vcat(sparse(a), sparse(b), sparse(c))
@test eltype(d) == Float64
@test d == sd
@test bandwidths(d) == bandwidths(sd)
end
end

@testset "one element" begin
n = rand(3:20)
x,y = OneElement(1, (1,1), (1,n)), OneElement(1, (1,n), (1,n))
b = BandedMatrix((0 => ones(n-2), 1 => -2ones(n - 2), 2 => ones(n - 2)), (n-2, n))
@test vcat(x,b,y) == Tridiagonal([ones(n - 2); 0], [1 ; -2ones(n - 2); 1], [0; ones(n - 2)])
end
end

end

0 comments on commit 4ff8ff0

Please sign in to comment.