Skip to content

Commit

Permalink
Make the refactor work for basic examples in 1D and 2D
Browse files Browse the repository at this point in the history
  • Loading branch information
Kolaru committed Jan 17, 2024
1 parent 8087feb commit 203e2c1
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 74 deletions.
6 changes: 3 additions & 3 deletions src/IntervalRootFinding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,16 @@ const D = derivative

const where_bisect = 0.49609375 # 127//256

const Region = Union{Interval, SVector{N, <:Interval} where N}

IntervalBox(x::Interval, N::Integer) = SVector{N}(fill(x, N))
IntervalBox(xx::Vararg{Interval, N}) where N = SVector{N}(xx...)


include("region.jl")
include("root_object.jl")
include("roots.jl")

include("complex.jl")
include("contractors.jl")
include("roots.jl")
include("newton1d.jl")
include("quadratic.jl")
include("linear_eq.jl")
Expand Down
52 changes: 16 additions & 36 deletions src/contractors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,26 +38,17 @@ function contract(::Type{Krawczyk}, f, derivative, X::AbstractVector, where_mid)
return m - Y*f(mm) + (I - Y*J) * (X.v - m)
end

function contract(root_problem::RootProblem{C}, X::region) where C
CX = contract(C, root_problem.f, root_problem.derivative, region(X), root_problem.where_bisect)
return Region(CX)
function contract(root_problem::RootProblem{C}, X) where C
return contract(C, root_problem.f, root_problem.derivative, X, root_problem.where_bisect)
end

"""
safe_isempty(X)
Similar to `isempty` function for `IntervalBox`, but also works for `SVector`
of `Interval`.
"""
safe_isempty(X) = any(isempty_interval.(X))

function image_contains_zero(f, R::Root)
X = interval(R)
X = root_region(R)
R.status == :empty && return Root(X, :empty)

imX = f(X)

if !(all(in_interval.(0, imX))) || safe_isempty(imX)
if !(all(in_interval.(0, imX))) || isempty_region(imX)
return Root(X, :empty)
end

Expand All @@ -71,17 +62,16 @@ function contract_root(root_problem::RootProblem{C}, R::Root) where C
C == Bisection && return R2
R2.status == :empty && return R2

X = interval(R)
X = root_region(R)
contracted_X = contract(root_problem, X)

# Only happens if X is partially out of the domain of f
safe_isempty(contracted_X) && return Root(X, :unknown) # force bisection
isempty_region(contracted_X) && return Root(X, :unknown) # force bisection

NX = intersect_interval(bareinterval(contracted_X), bareinterval(X))
NX = IntervalArithmetic._unsafe_interval(NX, min(decoration(contracted_X), decoration(X)), isguaranteed(contracted_X))
NX = intersect_region(contracted_X, X)

!isbounded(X) && return Root(NX, :unknown) # force bisection
safe_isempty(NX) && return Root(X, :empty)
!isbounded_region(X) && return Root(NX, :unknown) # force bisection
isempty_region(NX) && return Root(X, :empty)

if R.status == :unique || NX X # isstrictsubset_interval, we know there's a unique root inside
return Root(NX, :unique)
Expand All @@ -91,30 +81,20 @@ function contract_root(root_problem::RootProblem{C}, R::Root) where C
end

"""
refine(op, X::Root, tol)
refine(root_problem::RootProblem, X::Root)
Wrap the refine method to leave unchanged intervals that are not guaranteed to
contain an unique solution.
Refine a root.
"""
function refine(root_problem::RootProblem, R::Root)
root_status(R) != :unique && return R
return Root(refine_root(root_problem, region(R)))
end

"""
refine(C, X::Region, tol)

Refine a interval known to contain a solution.
X = root_region(R)

This function assumes that it is already known that `X` contains a unique root.
"""
function refine_root(root_problem::RootProblem, X::Region)
while diam(X) > root_problem.abstol
NX = intersect_interval(bareinterval(C(X)), bareinterval(X))
NX = IntervalArithmetic._unsafe_interval(NX, min(decoration(C(X)), decoration(X)), isguaranteed(C(X)))
isequal_interval(NX, X) && break # reached limit of precision
while diam_region(X) > root_problem.abstol
NX = intersect_region(contract(root_problem, X), X)
isequal_region(NX, X) && break # reached limit of precision
X = NX
end

return X
return Root(X, :unique)
end
4 changes: 2 additions & 2 deletions src/linear_eq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,9 @@ function gauss_elimination_interval!(x::AbstractArray, A::AbstractMatrix, b::Abs
p .= 0

for i in 1:(n-1)
if 0 A[i, i] # diagonal matrix is not invertible
if in_interval(0, A[i, i]) # diagonal matrix is not invertible
p .= entireinterval(b[1])
return p .∩ x # return x?
return intersect_region(p, x) # return x?
end

for j in (i+1):n
Expand Down
52 changes: 33 additions & 19 deletions src/region.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,39 @@
struct Region{T, N}
intervals::T
function intersect_region(X::Interval, Y::Interval)
intersection = intersect_interval(bareinterval(X), bareinterval(Y))
dec = min(decoration(X), decoration(Y))
guarantee = isguaranteed(X) && isguaranteed(Y)
return IntervalArithmetic._unsafe_interval(intersection, dec, guarantee)
end

Region(X::T) where {T <: Interval} = Region{T, 1}(X)
intersect_region(X::AbstractVector, Y::AbstractVector) = intersect_region.(X, Y)

function Region(Xs::AbstractVector{T}) where {T <: Interval}
N = length(Xs)
N == 1 && return Region{T, 1}(only(Xs))
return Region{T, N}(Xs)
end
isempty_region(X::Interval) = isempty_interval(X)
isempty_region(X::AbstractVector) = any(isempty_region.(X))

function Base.intersect(X::Region{<:Any, 1}, Y::Region{<:Any, 1})
x = only(X.intervals)
y = only(Y.intervals)
intersection = intersect_interval(bareinterval(x), bareinterval(x))
dec = min(decoration(x), decoration(y))
guarantee = isguaranteed(x) && isguaranteed(y)
decorated = IntervalArithmetic._unsafe_interval(intersection, dec, guarantee)
return Region(decorated)
end
isequal_region(X::Interval, Y::Interval) = isequal_interval(X, Y)
isequal_region(X::AbstractVector, Y::AbstractVector) = all(isequal_region.(X, Y))

isbounded_region(X::Interval) = isbounded(X)
isbounded_region(X::AbstractVector) = all(isbounded.(X))

isnai_region(X::Interval) = isnai(X)
isnai_region(X::AbstractVector) = any(isnai.(X))

function Base.intersect(X::Region{<:Any, N}, Y::Region{<:Any, N}) where N
return Region(intersect.(X.intervals, Y.intervals))
diam_region(X::Interval) = diam(X)
diam_region(X::AbstractVector) = maximum(diam.(X))

function bisect(X::Interval, α)
m = mid(X, α)
return (interval(inf(X), m), interval(m, sup(X)))
end

function bisect(X::AbstractVector, α)
X1 = copy(X)
X2 = copy(X)

i = argmax(diam.(X))
x1, x2 = bisect(X[i], α)
X1[i] = x1
X2[i] = x2
return X1, X2
end
10 changes: 5 additions & 5 deletions src/root_object.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ root, however such `Root`s are discarded by default and thus never returned by
the `roots` function.
# Fields
- `interval`: a region (either `Interval` or `SVector` of interval
- `region`: a region (either `Interval` or `SVector` of interval
representing an interval box) searched for roots.
- `status`: the status of the region, valid values are `:empty`, `unknown`
and `:unique`.
"""
struct Root{T, N}
region::IntervalRegion{T, N}
struct Root{T}
region::T
status::Symbol
end

Expand Down Expand Up @@ -46,5 +46,5 @@ show(io::IO, rt::Root) = print(io, "Root($(rt.region), :$(rt.status))")
(a::Interval, b::Root) = a b.region
(a::Root, b::Root) = a.region b.region

diam(r::Root) = diam(interval(r))
isnai(r::Root) = isnai(interval(r))
IntervalArithmetic.diam(r::Root) = diam_region(root_region(r))
IntervalArithmetic.isnai(r::Root) = isnai_region(root_region(r))
20 changes: 11 additions & 9 deletions src/roots.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ struct RootProblem{C, F, G, R, S, T}
where_bisect::T
end

RootProblem(f, region ; kwargs...) = RootProblem(f, Root(region, :unkown) ; kwargs...)

function RootProblem(
f, region ;
f, root::Root ;
contractor = Newton,
derivative = nothing,
search_order = BreadthFirst,
Expand All @@ -25,7 +27,7 @@ function RootProblem(
max_iteration = 100_000,
where_bisect = 0.49609375) # 127//256

N = length(region)
N = length(root_region(root))
if isnothing(derivative)
if N == 1
derivative = x -> ForwardDiff.derivative(f, x)
Expand All @@ -38,7 +40,7 @@ function RootProblem(
contractor,
f,
derivative,
Region(region),
root,
search_order,
abstol,
reltol,
Expand All @@ -47,8 +49,8 @@ function RootProblem(
)
end

function bisect(r::Root)
Y1, Y2 = bisect(interval(r))
function bisect(r::Root, α)
Y1, Y2 = bisect(root_region(r), α)
return Root(Y1, :unknown), Root(Y2, :unknown)
end

Expand Down Expand Up @@ -80,7 +82,7 @@ function roots(f, region ; kwargs...)
search = BranchAndPruneSearch(
root_problem.search_order,
X -> process(root_problem, X),
bisect,
X -> bisect(X, root_problem.where_bisect),
root_problem.region
)
result = bpsearch(search)
Expand All @@ -90,7 +92,7 @@ end
# TODO Reinstaste support for that
# Acting on complex `Interval`
function _roots(f, Xc::Complex{Interval{T}}, contractor::Type{C},
search_order::Type{S}, tol::Float64) where {T, C <: AbstractContractor, S <: SearchOrder}
search_order::Type{S}, tol::Float64) where {T, C, S <: SearchOrder}

g = realify(f)
Y = IntervalBox(reim(Xc)...)
Expand All @@ -99,7 +101,7 @@ function _roots(f, Xc::Complex{Interval{T}}, contractor::Type{C},
return [Root(Complex(root.interval...), root.status) for root in rts]
end

function _roots(f, Xc::Complex{Interval{T}}, contractor::NewtonLike,
function _roots(f, Xc::Complex{Interval{T}}, contractor,
search_order::Type{S}, tol::Float64) where {T, S <: SearchOrder}

g = realify(f)
Expand All @@ -110,7 +112,7 @@ function _roots(f, Xc::Complex{Interval{T}}, contractor::NewtonLike,
return [Root(Complex(root.interval...), root.status) for root in rts]
end

function _roots(f, deriv, Xc::Complex{Interval{T}}, contractor::NewtonLike,
function _roots(f, deriv, Xc::Complex{Interval{T}}, contractor,
search_order::Type{S}, tol::Float64) where {T, S <: SearchOrder}

g = realify(f)
Expand Down

0 comments on commit 203e2c1

Please sign in to comment.