module Data.Polynomial.RootSeparation.Sturm
( SturmChain
, sturmChain
, numRoots
, numRoots'
, separate
, separate'
, halve
, halve'
, narrow
, narrow'
, approx
, approx'
) where
import Data.Maybe
import Data.Polynomial
import qualified Data.Interval as Interval
import Data.Interval (Interval, EndPoint (..), (<..<=), (<=..<=))
-- | A Sturm chain as built by 'sturmChain': the polynomial itself, then its
-- derivative, then successive polynomial remainders. Every function in this
-- module that takes a 'SturmChain' treats the head of the list as the
-- polynomial whose roots are being examined.
type SturmChain = [UPolynomial Rational]
-- | Build the Sturm chain of a polynomial:
--
-- * @p0 = p@
-- * @p1 = p'@ (the derivative)
-- * @p(i+1) = -(p(i-1) \`polyMod\` p(i))@
--
-- The chain stops just before the first zero remainder. The first two
-- elements are always present, even when the derivative is zero.
sturmChain :: UPolynomial Rational -> SturmChain
sturmChain p = p0 : p1 : go p0 p1
  where
    p0 = p
    p1 = deriv p X
    go a b
      -- Only reachable when p is constant (p' == 0). Stop here rather than
      -- take a remainder modulo the zero polynomial.
      | b == 0    = []
      | r == 0    = []
      | otherwise = r : go b r
      where
        -- The remainder must be NEGATED. Without the sign flip the sequence
        -- is not a Sturm sequence, and sign-change counts give wrong answers.
        r = negate (a `polyMod` b)
-- | Count the distinct real roots of a polynomial that lie inside an
-- interval. The Sturm chain is built with 'sturmChain' and the counting is
-- delegated to 'numRoots''.
numRoots :: UPolynomial Rational -> Interval Rational -> Int
numRoots = numRoots' . sturmChain
-- | Count the distinct real roots of the head polynomial of a Sturm chain
-- that lie inside the given interval.
--
-- The interval is first clipped with 'boundInterval', so the endpoints
-- examined below are finite. The difference of sign-change counts
-- @V(lb) - V(ub)@ is used as the number of roots in the half-open interval
-- @(lb, ub]@. The two endpoints are then corrected according to whether they
-- actually belong to the interval.
numRoots'
  :: SturmChain
  -> Interval Rational
  -> Int
numRoots' chain@(p:_) ival
  | Interval.null ival2 = 0
  | otherwise =
      case (Interval.lowerBound ival2, Interval.upperBound ival2) of
        (Finite lb, Finite ub) ->
          -- Roots in (lb, ub] by Sturm sign-change counting.
          (if lb == ub then 0 else n lb - n ub) +
          -- (lb, ub] excludes lb, so add it back when it is included and a root.
          (if lb `Interval.member` ival2 && lb `isRootOf` p then 1 else 0) +
          -- (lb, ub] includes ub, so remove it when it is excluded and a root.
          (if ub `Interval.notMember` ival2 && ub `isRootOf` p then -1 else 0)
        _ -> error "numRoots': should not happen"
  where
    ival2 = boundInterval p ival
    -- Number of sign changes (zeros ignored) of the chain evaluated at x.
    n x = countSignChanges [eval (\X -> x) q | q <- chain]
-- | Number of sign changes in a sequence of rationals. Zeros are dropped
-- before the comparison, so they never contribute a change.
countSignChanges :: [Rational] -> Int
countSignChanges = countChanges . map (> 0) . filter (/= 0)
-- | Number of positions where an element differs from the most recently
-- kept element. On an equal element the earlier one is kept; on a different
-- element the new one becomes the reference. Empty and singleton lists have
-- no changes.
countChanges :: Eq a => [a] -> Int
countChanges ys =
  case ys of
    []       -> 0
    y : rest -> walk y rest
  where
    walk _ [] = 0
    walk prev (cur : more)
      | prev == cur = walk prev more
      | otherwise   = 1 + walk cur more
-- | A symmetric pair @(-m, m)@ such that every real root of the polynomial
-- lies in the closed interval @[-m, m]@. Here
--
-- @m = max 1 (sum of |a_i / a_n| over the non-leading coefficients a_i)@
--
-- and @a_n@ is the leading coefficient. The zero polynomial yields @(0, 0)@.
bounds :: UPolynomial Rational -> (Rational, Rational)
bounds p = (negate m, m)
  where
    m | p == 0    = 0
      -- The sum runs over all terms, including the leading one. Its ratio
      -- |s / s| is exactly 1, so it is subtracted back out.
      | otherwise = max 1 (sum [abs (c / s) | (c, _) <- terms p] - 1)
    -- Leading coefficient used to normalise every coefficient.
    (s, _) = leadingTerm grlex p
-- | Intersect an interval with the closed interval spanned by the pair
-- returned from 'bounds' for the polynomial.
boundInterval :: UPolynomial Rational -> Interval Rational -> Interval Rational
boundInterval p ival =
  let (lo, hi) = bounds p
      rootBox  = Finite lo <=..<= Finite hi
  in Interval.intersection ival rootBox
-- | Isolate the real roots of a polynomial into disjoint intervals by
-- building its Sturm chain and delegating to 'separate''.
separate :: UPolynomial Rational -> [Interval Rational]
separate = separate' . sturmChain
-- | Isolate the real roots of the head polynomial of a Sturm chain.
--
-- The search starts from the pair returned by 'bounds'. The result lists,
-- from left to right, a singleton interval if the lower search bound is
-- itself a root, followed by half-open intervals @(lb, ub]@. The bisection
-- produces each such interval only when the sign-change difference over it
-- is exactly 1.
separate' :: SturmChain -> [Interval Rational]
separate' chain@(p:_) = f (bounds p)
  where
    -- Number of sign changes (zeros ignored) of the chain evaluated at x.
    n x = countSignChanges [eval (\X -> x) q | q <- chain]
    -- g only reports half-open intervals (lb, ub]. A root sitting exactly on
    -- the lower search bound would be missed, so it is emitted here.
    f (lb,ub) =
      if lb `isRootOf` p
        then Interval.singleton lb : g (lb,ub)
        else g (lb,ub)
    -- Bisect (lb, ub] until each piece holds at most one root. Pieces with
    -- no root are dropped.
    g (lb,ub) =
      case n lb - n ub of
        0 -> []
        1 -> [Finite lb <..<= Finite ub]
        _ -> g (lb, mid) ++ g (mid, ub)
      where
        mid = (lb + ub) / 2
-- | Bisect an interval around a root of a polynomial. Builds the Sturm chain
-- and delegates to 'halve''.
halve :: UPolynomial Rational -> Interval Rational -> Interval Rational
halve = halve' . sturmChain
-- | Split an interval at its midpoint and keep one half.
--
-- A zero-width interval is returned unchanged. Otherwise the candidate halves
-- are built as follows:
--
-- * the left half keeps the original lower endpoint and is closed at the
--   midpoint;
-- * the right half is open at the midpoint and keeps the original upper
--   endpoint.
--
-- The left half is returned if 'numRoots'' finds a root in it. Otherwise the
-- right half is returned without being checked. Both endpoints of the input
-- are expected to be finite.
halve' :: SturmChain -> Interval Rational -> Interval Rational
halve' chain@(_:_) ival
  | Interval.width ival == 0 = ival
  | otherwise =
      let Finite lo = Interval.lowerBound ival
          Finite hi = Interval.upperBound ival
          midpoint  = (lo + hi) / 2
          leftHalf  = Interval.interval (Interval.lowerBound' ival) (Finite midpoint, True)
          rightHalf = Interval.interval (Finite midpoint, False) (Interval.upperBound' ival)
      in if numRoots' chain leftHalf > 0 then leftHalf else rightHalf
-- | Shrink an interval around a root of a polynomial until it is narrower
-- than the requested size. Builds the Sturm chain and delegates to 'narrow''.
narrow :: UPolynomial Rational -> Interval Rational -> Rational -> Interval Rational
narrow = narrow' . sturmChain
-- | Shrink an interval around a root until its width is below @size@.
--
-- The interval is first clipped with 'boundInterval' and then repeatedly
-- bisected with 'halve''. @size@ must be positive. Halving a positive width
-- never reaches zero, and 'halve'' leaves a zero-width interval unchanged,
-- so a non-positive @size@ would never terminate. Such a value is rejected
-- with an explicit error instead.
narrow' :: SturmChain -> Interval Rational -> Rational -> Interval Rational
narrow' chain@(p:_) ival size
  | size <= 0 = error "narrow': size must be positive"
  | otherwise = go (boundInterval p ival)
  where
    go cur
      | Interval.width cur < size = cur
      | otherwise                 = go (halve' chain cur)
-- | Approximate a root of a polynomial inside an interval to the given
-- precision. Builds the Sturm chain and delegates to 'approx''.
approx :: UPolynomial Rational -> Interval Rational -> Rational -> Rational
approx = approx' . sturmChain
-- | Approximate a root of the head polynomial of a Sturm chain.
--
-- The interval is narrowed with 'narrow'' until its width is below
-- @epsilon@, and 'Interval.pickup' chooses a point from the result. If the
-- interval contains a root, the returned point is within @epsilon@ of it.
-- Fails with a descriptive error, instead of an opaque 'fromJust' failure,
-- when the narrowed interval is empty.
approx' :: SturmChain -> Interval Rational -> Rational -> Rational
approx' chain ival epsilon =
  fromMaybe (error "approx': the narrowed interval is empty") $
    Interval.pickup (narrow' chain ival epsilon)