module Data.AlgebraicNumber.Real
  (
  
    AReal
  
  , realRoots
  , realRootsEx
  
  , minimalPolynomial
  , deg
  , isRational
  , isAlgebraicInteger
  , height
  , rootIndex
  
  , nthRoot
  
  , approx
  , approxInterval
  
  , simpARealPoly
  , goldenRatio
  ) where
import Control.Exception (assert)
import Control.Monad
import Data.List
import Data.Ratio
import qualified Data.Set as Set
import qualified Text.PrettyPrint.HughesPJClass as PP
import Text.PrettyPrint.HughesPJClass (Doc, PrettyLevel, Pretty (..), prettyParen)
import Data.Polynomial
import qualified Data.Polynomial as P
import qualified Data.Polynomial.Factorization.Rational as FactorQ
import qualified Data.Polynomial.RootSeparation.Sturm as Sturm
import Data.Interval (Interval, EndPoint (..), (<=..<), (<..<=), (<..<), (<!), (>!))
import qualified Data.Interval as Interval
import Data.AlgebraicNumber.Root
-- | A real algebraic number: a polynomial with rational coefficients paired
-- with an interval that is meant to isolate exactly one of its real roots.
--
-- The polynomial is read back through 'minimalPolynomial' and the interval
-- through 'interval'; values built with 'realRoot'' have their polynomial
-- passed through 'normalizePoly'.
data AReal
  = RealRoot (UPolynomial Rational) (Interval Rational)
  deriving (Show)
-- | All distinct real roots of a rational polynomial, in ascending order.
--
-- Each factor returned by 'FactorQ.factor' has its roots isolated separately
-- by 'realRoots''; a 'Set' then removes duplicates and sorts the result.
realRoots :: UPolynomial Rational -> [AReal]
realRoots p = Set.toAscList (Set.fromList (concatMap (realRoots' . fst) (FactorQ.factor p)))
-- | Real roots of a polynomial whose coefficients are real algebraic numbers.
--
-- When every coefficient is rational, the polynomial is converted and handed
-- to 'realRoots'. Otherwise, candidate roots are taken from the rational
-- polynomial produced by 'simpARealPoly', and only those that are actual
-- roots of @p@ are kept.
realRootsEx :: UPolynomial AReal -> [AReal]
realRootsEx p
  | all (isRational . fst) (terms p) = realRoots (mapCoeff toRational p)
  | otherwise                        = filter (`isRootOf` p) (realRoots (simpARealPoly p))
-- | Real roots of a single rational polynomial, one 'AReal' per isolating
-- interval returned by 'Sturm.separate'. Constant polynomials yield no roots.
realRoots' :: UPolynomial Rational -> [AReal]
realRoots' p
  | deg p > 0 = map (realRoot' p) (Sturm.separate p)
  | otherwise = []
-- | Build an 'AReal' from a polynomial and an interval containing one of its
-- roots.
--
-- The polynomial is factored, and the first non-constant factor with exactly
-- one root inside the interval is used. Calls 'error' if no factor qualifies.
realRoot :: UPolynomial Rational -> Interval Rational -> AReal
realRoot p i =
  case filter hasUniqueRoot (map fst (FactorQ.factor p)) of
    []    -> error "Data.AlgebraicNumber.Real.realRoot: invalid interval"
    q : _ -> realRoot' q i
  where
    hasUniqueRoot q = deg q > 0 && Sturm.numRoots q i == 1
-- | Unchecked constructor: normalizes the polynomial with 'normalizePoly' and
-- pairs it with the interval as given. Nothing is verified here; callers are
-- expected to pass an interval that isolates a single root.
realRoot' :: UPolynomial Rational -> Interval Rational -> AReal
realRoot' p = RealRoot (normalizePoly p)
-- | Exact zero test. Zero must lie in the isolating interval and must also be
-- a root of the minimal polynomial.
isZero :: AReal -> Bool
isZero a
  | Interval.member 0 (interval a) = isRootOf 0 (minimalPolynomial a)
  | otherwise                      = False
-- | Multiply an algebraic number by a rational constant.
--
-- 'rootScale' gives the polynomial for the scaled value; interval arithmetic
-- scales the isolating interval.
scaleAReal :: Rational -> AReal -> AReal
scaleAReal r a =
  let scaledPoly     = rootScale r (minimalPolynomial a)
      scaledInterval = Interval.singleton r * interval a
  in  realRoot' scaledPoly scaledInterval
-- | Add a rational constant to an algebraic number.
--
-- 'rootShift' gives the polynomial for the shifted value; interval
-- arithmetic shifts the isolating interval.
shiftAReal :: Rational -> AReal -> AReal
shiftAReal r a =
  let shiftedPoly     = rootShift r (minimalPolynomial a)
      shiftedInterval = Interval.singleton r + interval a
  in  realRoot' shiftedPoly shiftedInterval
-- | Two numbers are equal when they have the same minimal polynomial and the
-- overlap of their isolating intervals holds exactly one root of it.
instance Eq AReal where
  x == y
    | minimalPolynomial x == minimalPolynomial y =
        Sturm.numRoots' (sturmChain x) (Interval.intersection (interval x) (interval y)) == 1
    | otherwise = False
-- | Total order on real algebraic numbers.
--
-- 1. If the isolating intervals are already disjoint, their order decides the
--    result.
-- 2. Otherwise the numbers are tested for equality.
-- 3. If they differ, the wider interval is halved repeatedly (keeping the half
--    that still holds its root) until the intervals separate.
instance Ord AReal where
  compare a b =
    case disjointOrder ia ib of
      Just o -> o
      Nothing
        | a == b    -> EQ
        | otherwise -> refine ia ib
    where
      ia = interval a
      ib = interval b
      ca = sturmChain a
      cb = sturmChain b
      -- 'Just' the ordering once the two intervals no longer overlap.
      disjointOrder u v
        | u >! v    = Just GT
        | u <! v    = Just LT
        | otherwise = Nothing
      refine u v =
        case disjointOrder u v of
          Just o -> o
          Nothing
            | Interval.width u > Interval.width v -> refine (Sturm.halve' ca u) v
            | otherwise                          -> refine u (Sturm.halve' cb v)
-- | Ring operations on real algebraic numbers.
--
-- If either operand is rational, the cheaper 'shiftAReal' / 'scaleAReal'
-- path is used. Otherwise 'rootAdd' / 'rootMul' give a polynomial that has
-- the result among its roots. The operand intervals are then refined until
-- exactly one isolating interval of that polynomial still overlaps the
-- interval-arithmetic image of the operands.
instance Num AReal where
  a + b
    | isRational a = shiftAReal (toRational a) b
    | isRational b = shiftAReal (toRational b) a
    | otherwise    = realRoot p3 i3
    where
      p3 = rootAdd (minimalPolynomial a) (minimalPolynomial b)
      c1 = sturmChain a
      c2 = sturmChain b
      c3 = Sturm.sturmChain p3
      i3 = go (interval a) (interval b) (Sturm.separate' c3)
      -- Keep candidate root intervals of p3 that still meet i1 + i2; halve
      -- everything until a single candidate (the one holding a + b) remains.
      go i1 i2 is3 =
        case [i5 | j <- is3, let i5 = Interval.intersection j i4, Sturm.numRoots' c3 i5 > 0] of
          []   -> error "AReal.+: should not happen"
          [i5] -> i5
          is5  -> go (Sturm.halve' c1 i1) (Sturm.halve' c2 i2) [Sturm.halve' c3 i5 | i5 <- is5]
        where
          i4 = i1 + i2
  a * b
    | isRational a = scaleAReal (toRational a) b
    | isRational b = scaleAReal (toRational b) a
    | otherwise    = realRoot p3 i3
    where
      p3 = rootMul (minimalPolynomial a) (minimalPolynomial b)
      c1 = sturmChain a
      c2 = sturmChain b
      c3 = Sturm.sturmChain p3
      i3 = go (interval a) (interval b) (Sturm.separate' c3)
      -- Same refinement as for (+), using the interval product i1 * i2.
      go i1 i2 is3 =
        case [i5 | j <- is3, let i5 = Interval.intersection j i4, Sturm.numRoots' c3 i5 > 0] of
          []   -> error "AReal.*: should not happen"
          [i5] -> i5
          is5  -> go (Sturm.halve' c1 i1) (Sturm.halve' c2 i2) [Sturm.halve' c3 i5 | i5 <- is5]
        where
          i4 = i1 * i2
  -- Negation is scaling by -1 (scaling by 1 would be the identity).
  negate a = scaleAReal (-1) a
  abs a =
    case compare 0 a of
      EQ -> fromInteger 0
      LT -> a
      GT -> negate a
  -- -1, 0 or 1 according to the sign of the number.
  signum a = fromInteger $
    case compare 0 a of
      EQ -> 0
      LT -> 1
      GT -> -1
  fromInteger = fromRational . toRational
-- | Division on real algebraic numbers.
instance Fractional AReal where
  -- A rational r is the only root of X - r, isolated by the point interval
  -- [r, r].
  fromRational r = realRoot' (x - constant r) (Interval.singleton r)
    where
      x = var X
  -- The reciprocal's polynomial comes from 'rootRecip'. The isolating
  -- interval is refined until it excludes zero before being inverted.
  -- Inverting an interval that contains zero gives an unbounded interval,
  -- which may also enclose the reciprocals of other roots.
  -- Calls 'error' on zero.
  recip a
    | isZero a  = error "AReal.recip: zero division"
    | otherwise = realRoot' p2 i2
      where
        c  = sturmChain a
        p2 = rootRecip (minimalPolynomial a)
        i2 = recip (avoidZero (interval a))
        -- Terminates because a /= 0: halving shrinks the interval around a,
        -- so zero eventually falls outside it.
        avoidZero i
          | 0 `Interval.member` i = avoidZero (Sturm.halve' c i)
          | otherwise             = i
-- | Conversion back to 'Rational'. Only defined for rational numbers; calls
-- 'error' otherwise.
instance Real AReal where
  toRational x
    | isRational x =
        -- A rational number has a linear minimal polynomial a*X + b, whose
        -- only root is -b/a.
        let p = minimalPolynomial x
            a = P.coeff (P.var X) p
            b = P.coeff P.mmOne p
        in  negate b / a
    | otherwise  = error "toRational: proper algebraic number"
-- | Rounding operations, delegated to the exact implementations defined in
-- this module.
instance RealFrac AReal where
  floor          = floor'
  ceiling        = ceiling'
  round          = round'
  truncate       = truncate'
  properFraction = properFraction'
-- | Rational approximation of an algebraic number.
--
-- Rational inputs are returned exactly. Otherwise 'Sturm.approx'' is called
-- with the number's Sturm chain, its isolating interval and the tolerance.
approx
  :: AReal    -- ^ number to approximate
  -> Rational -- ^ tolerance (epsilon) passed to 'Sturm.approx''
  -> Rational
approx a epsilon
  | isRational a = toRational a
  | otherwise    = Sturm.approx' (sturmChain a) (interval a) epsilon
-- | An interval around an algebraic number.
--
-- Rational inputs give the exact point interval. Otherwise the isolating
-- interval is narrowed with 'Sturm.narrow'' using the given tolerance.
approxInterval
  :: AReal    -- ^ number to enclose
  -> Rational -- ^ tolerance (epsilon) passed to 'Sturm.narrow''
  -> Interval Rational
approxInterval a epsilon
  | isRational a = Interval.singleton (toRational a)
  | otherwise    = Sturm.narrow' (sturmChain a) (interval a) epsilon
-- | Split a number into an integral part, truncated towards zero, and the
-- fractional remainder (@x - integral part@).
--
-- Positive numbers use 'floor'' and negative numbers use 'ceiling''.
properFraction' :: Integral b => AReal -> (b, AReal)
properFraction' x =
  case compare x 0 of
    EQ -> (0, 0)
    GT -> (fromInteger floor_x, x - fromInteger floor_x)
    LT -> (fromInteger ceiling_x, x - fromInteger ceiling_x)
  where
    floor_x   = floor' x
    ceiling_x = ceiling' x
-- | Integral part of a number, truncated towards zero (the first component of
-- 'properFraction'').
truncate' :: Integral b => AReal -> b
truncate' x = fst (properFraction' x)
-- | Round to the nearest integer. Exact halves go to the even neighbour,
-- matching the Prelude's 'round'.
--
-- Compares the remainder's magnitude against 1/2 directly instead of using
-- the sign of @abs r - 0.5@.
round' :: Integral b => AReal -> b
round' x =
  case compare (abs r) 0.5 of
    LT -> n
    EQ -> if even n then n else m
    GT -> m
  where
    (n, r) = properFraction' x
    -- The integer next to n on the side where x lies.
    m = if r < 0 then n - 1 else n + 1
-- | Exact ceiling of an algebraic number.
--
-- The isolating interval is narrowed with tolerance 1/2, which presumably
-- leaves it narrower than 1. The ceiling is then either the ceiling of the
-- lower bound or the ceiling of the upper bound. A Sturm root count on
-- (-inf, ceiling lb] picks between them.
ceiling' :: Integral b => AReal -> b
ceiling' a
  | Sturm.numRoots' chain (Interval.intersection narrowed atMostCLB) >= 1 = fromInteger cLB
  | otherwise = fromInteger cUB
  where
    chain     = sturmChain a
    narrowed  = Sturm.narrow' chain (interval a) (1/2)
    (Finite lb, _) = Interval.lowerBound' narrowed
    (Finite ub, _) = Interval.upperBound' narrowed
    cLB       = ceiling lb
    cUB       = ceiling ub
    atMostCLB = NegInf <..<= Finite (fromInteger cLB)
-- | Exact floor of an algebraic number.
--
-- The isolating interval is narrowed with tolerance 1/2, so the floor is
-- either the floor of the lower bound or the floor of the upper bound. A
-- Sturm root count on [floor ub, +inf) picks between them.
floor' :: Integral b => AReal -> b
floor' a
  | Sturm.numRoots' chain (Interval.intersection narrowed atLeastFUB) >= 1 = fromInteger fUB
  | otherwise = fromInteger fLB
  where
    chain      = sturmChain a
    narrowed   = Sturm.narrow' chain (interval a) (1/2)
    (Finite lb, _) = Interval.lowerBound' narrowed
    (Finite ub, _) = Interval.upperBound' narrowed
    fLB        = floor lb
    fUB        = floor ub
    atLeastFUB = Finite (fromInteger fUB) <=..< PosInf
-- | The real principal @n@-th root of an algebraic number.
--
-- For even @n@, returns the larger (non-negative) of the two real roots.
-- For odd @n@, returns the unique real root. Zero maps to zero; without this
-- case, even @n@ would trip the two-roots assertion, because 0 has a single
-- n-th root.
--
-- Calls 'error' when @n <= 0@, or when @n@ is even and the input is negative.
nthRoot :: Integer -> AReal -> AReal
nthRoot n a
  | n <= 0 = error "Data.AlgebraicNumber.Real.nthRoot"
  | isZero a = 0
  | even n =
      if a < 0
      then error "Data.AlgebraicNumber.Real.nthRoot: no real roots"
      else assert (length bs == 2) (maximum bs)
  | otherwise =
      case bs of
        [b] -> b
        _   -> error "Data.AlgebraicNumber.Real.nthRoot: expected exactly one real root"
  where
    bs = nthRoots n a
-- | All real @n@-th roots of the algebraic number @a@.
--
-- Returns @[]@ when @n <= 0@, or when @n@ is even and @a@ is negative.
--
-- Candidates are the real roots of @rootNthRoot n p1@, where @p1@ is the
-- minimal polynomial of @a@. Presumably these can also include n-th roots of
-- the other real roots of @p1@, hence the filter. A candidate @b@ is kept once
-- the n-th power of its (repeatedly halved) isolating interval still meets
-- the interval of @a@ and excludes every other real root of @p1@. It is
-- rejected as soon as that power no longer contains @a@.
nthRoots :: Integer -> AReal -> [AReal]
nthRoots n _ | n <= 0 = []
nthRoots n a | even n && a < 0 = []
nthRoots n a = filter check (realRoots p2)
  where
    p1 = minimalPolynomial a
    p2 = rootNthRoot n p1
    c1 = sturmChain a
    -- Isolating interval of a itself.
    ok0 = interval a
    -- Isolating intervals of the other real roots of p1.
    ng0 = map interval $ delete a $ realRoots p1
    check :: AReal -> Bool
    check b = loop ok0 ng0 (interval b)
      where
        c2 = sturmChain b
        -- ok: interval around a; ng: intervals around other roots of p1 that
        -- i^n has not yet ruled out; i: interval around the candidate b.
        loop ok ng i
          -- i^n no longer meets a root of p1 inside ok: b^n /= a.
          | Sturm.numRoots' c1 ok' == 0 = False
          -- Every other root of p1 has been excluded: b^n == a.
          | null ng'  = True
          -- Still ambiguous: halve all intervals and retry.
          | otherwise =
              loop (Sturm.halve' c1 ok')
                   (map (\i3 -> Sturm.halve' c1 i3) ng')
                   (Sturm.halve' c2 i)
          where
            -- Interval-arithmetic enclosure of b^n.
            i2  = i ^ n
            ok' = Interval.intersection i2 ok
            ng' = filter (\i3 -> Sturm.numRoots' c1 i3 /= 0) $
                    map (Interval.intersection i2) ng
-- | The polynomial stored in an 'AReal'. Values built via 'realRoot'' have
-- it normalized by 'normalizePoly'.
minimalPolynomial :: AReal -> UPolynomial Rational
minimalPolynomial a = case a of
  RealRoot p _ -> p
-- | Sturm chain of the number's minimal polynomial, used for root counting
-- and interval refinement.
sturmChain :: AReal -> Sturm.SturmChain
sturmChain = Sturm.sturmChain . minimalPolynomial
-- | The isolating interval stored in an 'AReal'.
interval :: AReal -> Interval Rational
interval a = case a of
  RealRoot _ i -> i
-- | The degree of an algebraic number is the degree of its minimal
-- polynomial.
instance Degree AReal where
  deg = deg . minimalPolynomial
-- | A number is rational exactly when its minimal polynomial is linear.
isRational :: AReal -> Bool
isRational = (== 1) . deg
-- | Algebraic-integer test.
--
-- Multiplies the leading coefficient (grlex order) of the minimal polynomial
-- by the lcm of all coefficient denominators and checks that the product
-- is 1.
isAlgebraicInteger :: AReal -> Bool
isAlgebraicInteger x = leadCoeff * fromInteger denomLcm == 1
  where
    p         = minimalPolynomial x
    denomLcm  = foldl' lcm 1 (map (denominator . fst) (terms p))
    leadCoeff = fst (leadingTerm grlex p)
-- | Height of an algebraic number.
--
-- The minimal polynomial is scaled by the lcm of its coefficient
-- denominators, and the largest absolute value of the resulting integer
-- coefficients is returned.
height :: AReal -> Integer
height x = maximum (map scaledAbs (terms p))
  where
    p = minimalPolynomial x
    d = foldl' lcm 1 (map (denominator . fst) (terms p))
    scaledAbs (c, _) =
      let c' = c * fromInteger d
      in  assert (denominator c' == 1) (abs (numerator c'))
-- | Position of the number within the list of real roots that 'realRoots''
-- produces for its own minimal polynomial.
--
-- The lookup should always succeed. The 'Nothing' branch replaces the old
-- partial pattern binding with an explicit, descriptive error.
rootIndex :: AReal -> Int
rootIndex a =
  case elemIndex a (realRoots' (minimalPolynomial a)) of
    Just idx -> idx
    Nothing  -> error "Data.AlgebraicNumber.Real.rootIndex: number not found among the roots of its minimal polynomial"
-- | Renders as @RealRoot <polynomial> <root index>@. The output is
-- parenthesised when the context precedence exceeds function application.
instance Pretty AReal where
  pPrintPrec lv prec r = prettyParen (prec > funAppPrec) body
    where
      funAppPrec = 10
      body = PP.hsep
        [ PP.text "RealRoot"
        , pPrintPrec lv (funAppPrec + 1) (minimalPolynomial r)
        , PP.int (rootIndex r)
        ]
-- | Lets 'AReal' appear as a polynomial coefficient in pretty-printed output;
-- negative coefficients are detected by comparing against zero.
instance PrettyCoeff AReal where
  isNegativeCoeff a = 0 > a
  pPrintCoeff = pPrintPrec
-- | Turn a polynomial with algebraic coefficients into one with rational
-- coefficients via 'rootSimpPoly', using each coefficient's minimal
-- polynomial.
simpARealPoly :: UPolynomial AReal -> UPolynomial Rational
simpARealPoly = rootSimpPoly minimalPolynomial
-- | The golden ratio, (1 + sqrt 5) / 2.
goldenRatio :: AReal
goldenRatio = (1 + root5) / 2
  where
    -- X^2 - 5 has the two real roots -sqrt 5 and sqrt 5; take the larger.
    root5 = maximum (realRoots' (var X ^ 2 - 5))