{-# LANGUAGE CPP              #-}
{-# LANGUAGE ParallelListComp #-}

#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 702
{-# LANGUAGE Safe #-}
#endif

module Math.ContinuedFraction
( CF
, cf, gcf
, asCF, asGCF

, truncateCF

, equiv
, setNumerators
, setDenominators

, partitionCF
, evenCF
, oddCF

, convergents
, steed
, lentz, lentzWith
, modifiedLentz, modifiedLentzWith

, sumPartialProducts
) where

import Control.Arrow ((***))
import Data.List     (tails)

-- * The 'CF' type and basic operations

-- I think I would like to try refactoring this stuff at some point to use
-- an "Inductive" CF type, something like:
--
-- > data CF a
-- >     = CFZero               -- eval CFZero          = 0
-- >     | CFAdd    a (CF a)    -- eval (CFAdd    b x) =      b + eval x
-- >     | CFCont a a (CF a)    -- eval (CFCont a b x) = a / (b + eval x)
--
-- Or perhaps Bill Gosper's "∞-centered" representation:
--
-- > data CF a
-- >     = CFInfinity           -- eval CFInfinity     = ∞
-- >     | CFCont a a (CF a)    -- eval (CFCont p q x) = p + q / eval x
--

-- |A continued fraction, stored either in canonical form (partial
-- denominators only) or in generalized form (pairs of partial numerators
-- and denominators).  Build values with 'cf' or 'gcf'.
data CF a
    = CF a [a]
    -- ^ First term and partial denominators.
    -- Not exported; 'cf' is the public constructor.
    | GCF a [(a,a)]
    -- ^ First term and (partial numerator, partial denominator) pairs.
    -- Not exported; 'gcf' is the public constructor.

-- |Construct a continued fraction from its first term and the
-- partial denominators in its canonical form, which is the form
-- where all the partial numerators are 1.
--
-- @cf a [b,c,d]@ corresponds to @a + (1 \/ (b + (1 \/ (c + (1 \/ d)))))@,
-- or to @gcf a [(1,b),(1,c),(1,d)]@.
cf :: a -> [a] -> CF a
cf = CF

-- |Build a generalized continued fraction from its first term and a list
-- of (partial numerator, partial denominator) pairs.
--
-- @gcf b0 [(a1,b1), (a2,b2), (a3,b3)]@ stands for
-- @b0 + (a1 \/ (b1 + (a2 \/ (b2 + (a3 \/ b3)))))@
gcf :: a -> [(a,a)] -> CF a
gcf b0 terms = GCF b0 terms

-- |Shows a 'CF' as an application of the matching public constructor
-- (@cf b0 [...]@ or @gcf b0 [...]@), wrapped in parentheses whenever the
-- surrounding precedence is higher than that of function application.
instance Show a => Show (CF a) where
    showsPrec d (CF b0 denoms) =
        showParen (d > 10) $
            showString "cf " . showsPrec 11 b0 . showChar ' ' . showsPrec 11 denoms
    showsPrec d (GCF b0 terms) =
        showParen (d > 10) $
            showString "gcf " . showsPrec 11 b0 . showChar ' ' . showsPrec 11 terms

-- |Applies a function to the first term and to every partial numerator and
-- denominator, keeping the representation (canonical or generalized) as is.
instance Functor CF where
    fmap f cfrac = case cfrac of
        CF  b0 denoms -> CF  (f b0) (fmap f denoms)
        GCF b0 terms  -> GCF (f b0) (fmap (f *** f) terms)

-- |Return the first term and the partial denominators of a 'CF' in
-- canonical form.  A generalized fraction is first normalized by an
-- equivalence transformation that turns every partial numerator into 1.
asCF  :: Fractional a => CF a -> (a, [a])
asCF (CF b0 denoms) = (b0, denoms)
asCF (GCF b0 terms) = case terms of
    [] -> (b0, [])
    _  -> (b0, zipWith (*) partialDenoms scales)
    where
        -- Only demanded in the non-empty branch, so the pattern cannot fail.
        (firstNumer:restNumers, partialDenoms) = unzip terms
        -- c1 = 1/a1 and c{n} = 1/(a{n} * c{n-1}); b{n} is scaled by c{n}.
        scales = recip firstNumer
               : zipWith (\c a -> recip (a * c)) scales restNumers

-- |Return the first term together with every (partial numerator, partial
-- denominator) pair of a 'CF'.  A canonical fraction gets numerators of 1;
-- a generalized one is cut off just before its first zero partial numerator.
asGCF :: (Num a, Eq a) => CF a -> (a,[(a,a)])
asGCF cfrac = case cfrac of
    CF  b0 denoms -> (b0, map ((,) 1) denoms)
    GCF b0 terms  -> (b0, takeWhile (\(numer, _) -> numer /= 0) terms)

-- |Keep only the first @n@ terms after the leading one (partial denominators
-- for a canonical 'CF', numerator/denominator pairs for a generalized one).
truncateCF :: Int -> CF a -> CF a
truncateCF n cfrac = case cfrac of
    CF  b0 denoms -> CF  b0 (keep denoms)
    GCF b0 terms  -> GCF b0 (keep terms)
    where
        -- Signature keeps the helper polymorphic; it is used at two types.
        keep :: [b] -> [b]
        keep = take n

-- |Apply an equivalence transformation: the n-th partial denominator is
-- multiplied by the n-th element of the supplied list, and each partial
-- numerator is rescaled to match.  Terms beyond the end of the list are
-- left unscaled (apart from the first numerator after it, which still
-- picks up the last supplied factor).
equiv :: (Num a, Eq a) => [a] -> CF a -> CF a
equiv cs orig = gcf b0 (zip numers' denoms')
    where
        (b0, terms)      = asGCF orig
        (numers, denoms) = unzip terms

        -- Pad with 1s so that a short list leaves later terms untouched.
        scales  = cs ++ repeat 1
        -- a{n} is scaled by c{n} * c{n-1}, taking c0 = 1.
        numers' = zipWith3 (\c cPrev a -> c * cPrev * a) scales (1 : scales) numers
        denoms' = zipWith (*) scales denoms

-- |Apply an equivalence transformation that sets the partial denominators
-- of a 'CF' to the specified values.  If the input list is too short, the
-- rest of the 'CF' will be unscaled.
--
-- Each scale factor is the new denominator divided by the old one, so the
-- partial denominators being replaced must be nonzero.
setDenominators :: (Fractional a, Eq a) => [a] -> CF a -> CF a
setDenominators denoms orig
    = gcf b0 (zip as' bs')
    where
        (b0, terms) = asGCF orig
        (as,bs) = unzip terms

        -- New denominators: the supplied values first, then the originals.
        bs' = zipWith ($) (map const denoms ++ repeat id) bs
        -- Scale factor of each term: c{n} = b{n}' / b{n}.
        cs  = zipWith (/) bs' bs
        -- a{n}' = a{n} * c{n} * c{n-1}, taking c0 = 1.
        as' = zipWith (*) as (zipWith (*) cs (1:cs))

-- |Apply an equivalence transformation that sets the partial numerators
-- of a 'CF' to the specified values.  If the input list is too short, the
-- rest of the 'CF' will be unscaled.
--
-- The scale factors are obtained by dividing by the original partial
-- numerators, so the numerators being replaced must be nonzero.
setNumerators :: (Fractional a, Eq a) => [a] -> CF a -> CF a
setNumerators numers orig
    = gcf b0 (zip as' bs')
    where
        (b0, terms) = asGCF orig
        (as,bs) = unzip terms

        -- New numerators: the supplied values first, then the originals.
        as' = zipWith ($) (map const numers ++ repeat id) as
        -- Scale factor of each term: c{n} = a{n}' / (a{n} * c{n-1}), c0 = 1.
        cs  = zipWith (/) as' (zipWith (*) as (1:cs))
        -- b{n}' = b{n} * c{n}
        bs' = zipWith (*) bs cs

-- |Split a 'CF' into its even part and its odd part: two new 'CF's whose
-- convergents are, respectively, the even-indexed and the odd-indexed
-- convergents of the original.
--
-- A 'CF' with no terms is returned unchanged in both positions, and one
-- with a single term yields the constant @b0 + a1\/b1@ in both positions.
partitionCF :: (Fractional a, Eq a) => CF a -> (CF a, CF a)
partitionCF orig = case terms of
    []        -> (orig, orig)
    [(a1,b1)] -> let whole = cf (b0 + a1/b1) [] in (whole, whole)
    _         -> (evenPart, oddPart)
    where
        (b0, terms)      = asGCF orig
        (numers, denoms) = unzip terms

        -- Group a list into consecutive pairs; a leftover element is
        -- paired with 0.
        pairUp (x:y:zs) = (x,y) : pairUp zs
        pairUp [x]      = [(x,0)]
        pairUp []       = []

        -- Partial numerators of the equivalent fraction whose partial
        -- denominators are all 1: alpha{n} = a{n} / (b{n} * b{n-1}), with
        -- the factor before b1 taken to be 1.  Only demanded when there are
        -- at least two terms, so the pattern cannot fail.
        alphas@(alpha1:alphasFrom2@(alpha2:_))
            = zipWith (/) numers (zipWith (*) denoms (1:denoms))

        evenPart = gcf b0 (zip evenNumers evenDenoms)
            where
                evenNumers =     alpha1 : [negate q * p | (p, q) <- pairUp alphasFrom2]
                evenDenoms = 1 + alpha2 : [1 + p + q    | (p, q) <- drop 1 (pairUp alphas)]

        oddPart = gcf (b0+alpha1) (zip oddNumers oddDenoms)
            where
                oddNumers = [negate p * q | (p, q) <- pairUp alphas]
                oddDenoms = [1 + q + p    | (p, q) <- pairUp alphasFrom2]

-- |The even part of a 'CF': a new 'CF' whose convergents are the
-- even-indexed convergents of the original.  This is the first component
-- of 'partitionCF'.
evenCF :: (Fractional a, Eq a) => CF a -> CF a
evenCF orig = evens
    where
        (evens, _) = partitionCF orig

-- |The odd part of a 'CF': a new 'CF' whose convergents are the
-- odd-indexed convergents of the original.  This is the second component
-- of 'partitionCF'.
oddCF :: (Fractional a, Eq a) => CF a -> CF a
oddCF orig = odds
    where
        (_, odds) = partitionCF orig

-- * Evaluating continued fractions

-- |List the convergents of a continued fraction, computed with the
-- fundamental recurrence formulas:
--
-- @A0 = b0, B0 = 1@
--
-- @A1 = b1b0 + a1,  B1 = b1@
--
-- @A{n+1} = b{n+1}An + a{n+1}A{n-1}@
--
-- @B{n+1} = b{n+1}Bn + a{n+1}B{n-1}@
--
-- The n-th convergent is @Xn = An/Bn@.
convergents :: (Fractional a, Eq a) => CF a -> [a]
convergents orig = drop 1 (zipWith (/) nums denoms)
    where
        (b0, terms) = asGCF orig

        -- One step of the recurrence, shared by numerators and denominators.
        step prev cur (a,b) = b * cur + a * prev
        recurrence xs = zipWith3 step xs (drop 1 xs) terms

        -- Both sequences are seeded with their index -1 and index 0 values;
        -- the resulting leading 1/0 quotient is what 'drop 1' discards.
        nums   = 1 : b0 : recurrence nums
        denoms = 0 : 1  : recurrence denoms

-- |Evaluate the convergents of a continued fraction using Steed's method.
-- Only valid if the denominator in the following recurrence for D_i never
-- goes to zero.  If this method blows up, try 'modifiedLentz'.
--
-- @D1 = 1/b1@
--
-- @D{i} = 1 / (b{i} + a{i} * D{i-1})@
--
-- @dx1 = a1 / b1@
--
-- @dx{i} = (b{i} * D{i} - 1) * dx{i-1}@
--
-- @x0 = b0@
--
-- @x{i} = x{i-1} + dx{i}@
--
-- The convergents are given by @scanl (+) b0 dxs@
--
-- A fraction with no terms (including a generalized one whose first partial
-- numerator is zero) has the single convergent @b0@.
steed :: (Fractional a, Eq a) => CF a -> [a]
steed (CF  b0 []) = [b0]
steed (GCF b0 []) = [b0]
steed (CF  0 (  a  :rest)) = map (1 /) (steed (CF  a rest))
steed (GCF 0 ((a,b):rest)) = map (a /) (steed (GCF b rest))
steed orig = case asGCF orig of
    -- 'asGCF' may leave no terms even when the raw list was non-empty.
    (b0, [])            -> [b0]
    (b0, (a1,b1):gcf')  -> scanl (+) b0 dxs
        where
            ds  =  1/b1 : [recip (b + a * d) | (a,b) <- gcf' | d <- ds]
            -- dx{i} needs the freshly computed D{i}, not D{i-1}; hence
            -- the terms are paired with 'ds' shifted by one.
            dxs = a1/b1 : [(b * d - 1) * dx  | (_,b) <- gcf' | d <- drop 1 ds | dx <- dxs]

-- |List the convergents of a continued fraction, computed with Lentz's
-- method.  The method is only valid while none of the denominators in the
-- recurrence below becomes zero; if it blows up, try 'modifiedLentz'.
--
-- @C1 = b1 + a1 / b0@
--
-- @D1 = 1/b1@
--
-- @C{n} = b{n} + a{n} / C{n-1}@
--
-- @D{n} = 1 / (b{n} + a{n} * D{n-1})@
--
-- The convergents are @scanl (*) b0 (zipWith (*) cs ds)@.
lentz :: (Fractional a, Eq a) => CF a -> [a]
lentz orig = lentzWith id (*) recip orig

-- |Lentz's method, with the terms of the final product mapped into another
-- group before they are combined.  Logarithms under addition are one useful
-- choice of group.  In @lentzWith f op inv@:
--
-- * @f@ is a group homomorphism (eg, 'log') from {@a@,(*),'recip'} to the
--   group in which the product is to be accumulated.
--
-- * @op@ is that group's operation (eg., (+)).
--
-- * @inv@ is that group's inverse (eg., 'negate').
--
-- 'lentz' is the special case of the identity homomorphism:
-- @lentz@ = @lentzWith id (*) recip@.
--
-- This function was originally motivated by the need to compute the natural
-- log of very large numbers that would overflow in the naive 'lentz'
-- implementation; for that purpose the arguments are 'log', (+) and
-- 'negate'.
--
-- When terms of the product can be negative (i.e., some convergents are
-- negative), these definitions could be used instead:
--
-- > signLog x = (signum x, log (abs x))
-- > addSignLog (xS,xL) (yS,yL) = (xS*yS, xL+yL)
-- > negateSignLog (s,l) = (s, negate l)
{-# INLINE lentzWith #-}
lentzWith :: (Fractional a, Eq a) => (a -> b) -> (b -> b -> b) -> (b -> b) -> CF a -> [b]
-- A zero first term would break the recurrence (it divides by b0), so the
-- fraction is rewritten as a numerator over the remaining fraction.
lentzWith f op inv (CF  0 (b1 : denoms)) =
    map inv (lentzWith f op inv (CF b1 denoms))
lentzWith f op inv (GCF 0 ((a1,b1) : terms)) =
    map (op (f a1) . inv) (lentzWith f op inv (GCF b1 terms))
lentzWith f op _ orig = scanl step (f b0) ratios
    where
        (b0, cs, ds) = lentzRecurrence orig
        -- Ratio of each convergent to the previous one.
        ratios       = zipWith (*) cs ds
        step acc r   = op acc (f r)

-- Computes the first term and the C and D sequences of Lentz's recurrence,
-- seeded with C0 = b0 and D0 = 0.
-- precondition: b0 /= 0
lentzRecurrence :: (Fractional a, Eq a) => CF a -> (a,[a],[a])
lentzRecurrence orig = case terms of
    [] -> (b0, [], [])
    _  -> (b0, cs, ds)
    where
        (b0, terms) = asGCF orig

        cs = zipWith (\(a,b) c -> b + a/c)       terms (b0 : cs)
        ds = zipWith (\(a,b) d -> 1/(b + a*d))   terms (0  : ds)

-- |List the convergents of a continued fraction, computed with Lentz's
-- method (see 'lentz') plus one extra rule: whenever a denominator becomes
-- zero it is replaced by a (very small) number of your choosing, typically
-- 1e-30 or so (a modification proposed by Thompson and Barnett).
--
-- The convergents are also split into sublists, with a new sublist
-- starting each time the \'modification\' is invoked.
modifiedLentz :: (Fractional a, Eq a) => a -> CF a -> [[a]]
modifiedLentz z orig = modifiedLentzWith id (*) recip z orig

-- |'modifiedLentz' with a group homomorphism (see 'lentzWith', it bears the
-- same relationship to 'lentz' as this function does to 'modifiedLentz',
-- and solves the same problems).  Alternatively, 'lentzWith' with the same
-- modification to the recurrence as 'modifiedLentz'.
--
-- The result is split into sublists; a new sublist begins at each
-- convergent whose computation needed the zero-replacement value @z@.
{-# INLINE modifiedLentzWith #-}
modifiedLentzWith :: (Fractional a, Eq a) => (a -> b) -> (b -> b -> b) -> (b -> b) -> a -> CF a -> [[b]]
modifiedLentzWith f op inv z (CF  0 (  a  :rest)) = map (map             inv ) (modifiedLentzWith f op inv z (CF  a rest))
modifiedLentzWith f op inv z (GCF 0 ((a,b):rest)) = map (map (op (f a) . inv)) (modifiedLentzWith f op inv z (GCF b rest))
modifiedLentzWith f op _   z orig = separate (scanl opF (False, f b0) cds)
    where
        (b0, cs, ds) = modifiedLentzRecurrence z orig
        cds = zipWith mult cs ds

        -- A product term is flagged when either of its factors was reset.
        mult (xa,xb) (ya,yb) = (xa || ya, xb * yb)
        -- The running value carries only the flag of the term just folded
        -- in.  Propagating the accumulator's flag would mark every
        -- convergent after the first reset and make 'separate' emit
        -- singleton sublists from then on.
        opF  (_,xb)  (ya,yb) = (ya, op xb (f yb))

        -- Takes a list of (Bool,a) and breaks it into sublists, starting
        -- a new one every time it encounters (True,_).
        separate [] = []
        separate ((_,x):xs) = case break fst xs of
            (xs', ys) -> (x:map snd xs') : separate ys

-- Computes the first term and the C and D sequences of the modified Lentz
-- recurrence.  Every element is tagged with a flag telling whether it had
-- to be replaced by the substitute value @z@.
-- precondition: b0 /= 0
modifiedLentzRecurrence :: (Fractional a, Eq a) => a -> CF a -> (a,[(Bool, a)],[(Bool, a)])
modifiedLentzRecurrence z orig = case terms of
    [] -> (b0, [], [])
    _  -> (b0, cs, ds)
    where
        (b0, terms) = asGCF orig

        cs = zipWith (\(a,b) c -> guardZero id    (b + a/c)) terms (b0 : map snd cs)
        ds = zipWith (\(a,b) d -> guardZero recip (b + a*d)) terms (0  : map snd ds)

        -- Sublist breaking happens in a second pass.  Here each term is
        -- only marked: a term that comes out as exactly 0 is swapped for
        -- the small nonzero @z@ and flagged True; any other term is kept
        -- and flagged False.  The caller later splits the convergents at
        -- the flagged positions.
        guardZero finish x
            | x == 0    = (True,  finish z)
            | otherwise = (False, finish x)

-- |Euler's formula for computing @sum (scanl1 (*) xs)@.
-- Successive convergents of the resulting 'CF' are successive partial sums
-- in the series.
--
-- For example, @x1 + x1*x2 + x1*x2*x3@ is represented as
-- @x1 \/ (1 - x2 \/ (1 + x2 - x3 \/ (1 + x3)))@.  The empty list gives the
-- constant fraction 0.
sumPartialProducts :: Num a => [a] -> CF a
sumPartialProducts [] = cf 0 []
sumPartialProducts (x:xs) = gcf 0 ((x, 1) : [(negate x', 1 + x') | x' <- xs])