```{-# LANGUAGE ParallelListComp #-}
module Math.ContinuedFraction
( CF
, cf, gcf
, asCF, asGCF

, truncateCF

, equiv
, setNumerators
, setDenominators

, partitionCF
, evenCF
, oddCF

, convergents
, steed
, lentz
, modifiedLentz

, sumPartialProducts
) where

import Control.Arrow ((***))
import Data.List (tails, mapAccumL)

-- * The 'CF' type and basic operations

-- I think I would like to try refactoring this stuff at some point to use
-- an "Inductive" CF type, something like:
--
-- > data CF a
-- >     = CFZero               -- eval CFZero          = 0
-- >     | CFAdd    a (CF a)    -- eval (CFAdd    b x) =      b + eval x
-- >     | CFCont a a (CF a)    -- eval (CFCont a b x) = a / (b + eval x)

-- |A continued fraction.  Constructed by 'cf' or 'gcf'.
data CF a
= CF a [a]
-- ^ Not exported. See 'cf', the public constructor.
| GCF a [(a,a)]
-- ^ Not exported. See 'gcf', the public constructor.

-- |Construct a continued fraction from its first term and the
-- partial denominators in its canonical form, which is the form
-- where all the partial numerators are 1.
--
-- @cf a [b,c,d]@ corresponds to @a + (b \/ (1 + (c \/ (1 + d))))@,
-- or to @GCF a [(1,b),(1,c),(1,d)]@.
cf :: a -> [a] -> CF a
cf = CF

-- |Construct a continued fraction from its first term, its partial
-- numerators and its partial denominators.
--
-- @gcf b0 [(a1,b1), (a2,b2), (a3,b3)]@ corresponds to
-- @b0 + (a1 \/ (b1 + (a2 \/ (b2 + (a3 \/ b3)))))@
gcf :: a -> [(a,a)] -> CF a
gcf = GCF

instance Show a => Show (CF a) where
showsPrec p (CF b0 ab) = showParen (p>10)
( showString "cf "
. showsPrec 11 b0
. showChar ' '
. showsPrec 11 ab
)
showsPrec p (GCF b0 ab) = showParen (p>10)
( showString "gcf "
. showsPrec 11 b0
. showChar ' '
. showsPrec 11 ab
)

instance Functor CF where
fmap f (CF  b0 cf)  = CF  (f b0) (map f cf)
fmap f (GCF b0 gcf) = GCF (f b0) (map (f *** f) gcf)

-- |Extract the partial denominators of a 'CF', normalizing it if necessary so
-- that all the partial numerators are 1.
asCF  :: Fractional a => CF a -> (a, [a])
asCF (CF  b0 cf) = (b0, cf)
asCF (GCF b0 []) = (b0, [])
asCF (GCF b0 cf) = (b0, zipWith (*) bs cs)
where
(a:as, bs) = unzip cf
cs = recip a : [recip (a*c) | c <- cs | a <- as]

-- |Extract all the partial numerators and partial denominators of a 'CF'.
asGCF :: Num a => CF a -> (a,[(a,a)])
asGCF (CF  b0  cf) = (b0, [(1, b) | b <- cf])
asGCF (GCF b0 gcf) = (b0, takeWhile ((/=0).fst) gcf)

-- |Truncate a 'CF' to the specified number of partial numerators and denominators.
truncateCF :: Int -> CF a -> CF a
truncateCF n (CF  b0 ab) = CF  b0 (take n ab)
truncateCF n (GCF b0 ab) = GCF b0 (take n ab)

-- |Apply an equivalence transformation, multiplying each partial denominator
-- with the corresponding element of the supplied list and transforming
-- subsequent partial numerators and denominators as necessary.  If the list
-- is too short, the rest of the 'CF' will be unscaled.
equiv :: Num a => [a] -> CF a -> CF a
equiv cs orig
= gcf b0 (zip as' bs')
where
(b0, terms) = asGCF orig
(as,bs) = unzip terms

as' = zipWith (*) (zipWith (*) cs' (1:cs')) as
bs' = zipWith (*) cs' bs
cs' = cs ++ repeat 1

-- |Apply an equivalence transformation that sets the partial denominators
-- of a 'CF' to the specfied values.  If the input list is too short, the
-- rest of the 'CF' will be unscaled.
setDenominators :: Fractional a => [a] -> CF a -> CF a
setDenominators denoms orig
= gcf b0 (zip as' bs')
where
(b0, terms) = asGCF orig
(as,bs) = unzip terms

as' = zipWith (*) as (zipWith (*) cs (1:cs))
bs' = zipWith (\$) (map const denoms ++ repeat id) bs
cs = zipWith (/) bs' bs

-- |Apply an equivalence transformation that sets the partial numerators
-- of a 'CF' to the specfied values.  If the input list is too short, the
-- rest of the 'CF' will be unscaled.
setNumerators :: Fractional a => [a] -> CF a -> CF a
setNumerators numers orig
= gcf b0 (zip as' bs')
where
(b0, terms) = asGCF orig
(as,bs) = unzip terms

as' = zipWith (\$) (map const numers ++ repeat id) as
bs' = zipWith (*) bs cs
cs = zipWith (/) as' (zipWith (*) as (1:cs))

-- |Computes the even and odd parts, respectively, of a 'CF'.  These are new
-- 'CF's that have the even-indexed and odd-indexed convergents of the
-- original, respectively.
partitionCF :: Fractional a => CF a -> (CF a, CF a)
partitionCF orig = case terms of
[]          -> (orig, orig)
[(a1,b1)]   ->
let final = cf (b0 + a1/b1) []
in (final, final)
_           -> (evenPart, oddPart)
where
(b0, terms) = asGCF orig
(as, bs)    = unzip terms

pairs (a:b:rest) = (a,b) : pairs rest
pairs [a] = [(a,0)]
pairs _ = []

alphas@(alpha1:alpha2:_) = zipWith (/) as (zipWith (*) bs (1:bs))

evenPart = gcf b0 (zip cs ds)
where
cs =     alpha1 : [(-aOdd) * aEven  | (aEven, aOdd)  <- pairs (tail alphas)]
ds = 1 + alpha2 : [1 + aOdd + aEven | (aOdd,  aEven) <- tail (pairs alphas)]

oddPart = gcf (b0+alpha1) (zip cs ds)
where
cs = [(-aOdd) * aEven  | (aOdd, aEven) <- pairs alphas]
ds = [1 + aOdd + aEven | (aEven, aOdd) <- pairs (tail alphas)]

-- |Computes the even part of a 'CF' (that is, a new 'CF' whose convergents are
-- the even-indexed convergents of the original).
evenCF :: Fractional a => CF a -> CF a
evenCF = fst . partitionCF

-- |Computes the odd part of a 'CF' (that is, a new 'CF' whose convergents are
-- the odd-indexed convergents of the original).
oddCF :: Fractional a => CF a -> CF a
oddCF = snd . partitionCF

-- * Evaluating continued fractions

-- |Evaluate the convergents of a continued fraction using the fundamental
-- recurrence formula:
--
-- A0 = b0, B0 = 1
--
-- A1 = b1b0 + a1,  B1 = b1
--
-- A{n+1} = b{n+1}An + a{n+1}A{n-1}
--
-- B{n+1} = b{n+1}Bn + a{n+1}B{n-1}
--
-- The convergents are then Xn = An/Bn
convergents :: Fractional a => CF a -> [a]
convergents orig = drop 1 (zipWith (/) nums denoms)
where
(b0, terms) = asGCF orig
nums   = 1:b0:[b * x1 + a * x0 | x0:x1:_ <- tails nums   | (a,b) <- terms]
denoms = 0:1 :[b * x1 + a * x0 | x0:x1:_ <- tails denoms | (a,b) <- terms]

-- |Evaluate the convergents of a continued fraction using Steed's method.
-- Only valid if the denominator in the following recurrence for D_i never
-- goes to zero.  If this method blows up, try 'modifiedLentz'.
--
-- D1 = 1/b1
--
-- D{i} = 1 / (b{i} + a{i} * D{i-1})
--
-- dx1 = a1 / b1
--
-- dx{i} = (b{i} * D{i} - 1) * dx{i-1}
--
-- x0 = b0
--
-- x{i} = x{i-1} + dx{i}
--
-- The convergents are given by @scanl (+) b0 dxs@
steed :: Fractional a => CF a -> [a]
steed (CF  b0 []) = [b0]
steed (GCF b0 []) = [b0]
steed (CF  0 (  a  :rest)) = map (1 /) (steed (CF  a rest))
steed (GCF 0 ((a,b):rest)) = map (a /) (steed (GCF b rest))
steed orig
= scanl (+) b0 dxs
where
(b0, (a1,b1):gcf) = asGCF orig

dxs = a1/b1 : [(b * d - 1) * dx  | (a,b) <- gcf | d <- ds | dx <- dxs]
ds  =  1/b1 : [recip (b + a * d) | (a,b) <- gcf | d <- ds]

-- |Evaluate the convergents of a continued fraction using Lentz's method.
-- Only valid if the denominators in the following recurrence never go to
-- zero.  If this method blows up, try 'modifiedLentz'.
--
-- C1 = b1 + a1 / b0
--
-- D1 = 1/b1
--
-- C{n} = b{n} + a{n} / C{n-1}
--
-- D{n} = 1 / (b{n} + a{n} * D{n-1})
--
-- The convergents are given by @scanl (*) b0 (zipWith (*) cs ds)@
lentz :: Fractional a => CF a -> [a]
lentz (CF  b0 []) = [b0]
lentz (GCF b0 []) = [b0]
lentz (CF  0 (  a  :rest)) = map (1 /) (lentz (CF  a rest))
lentz (GCF 0 ((a,b):rest)) = map (a /) (lentz (GCF b rest))
lentz orig
= scanl (*) b0 (zipWith (*) cs ds)
where
(b0, gcf) = asGCF orig

cs = [   b + a/c  | (a,b) <- gcf | c <- b0 : cs]
ds = [1/(b + a*d) | (a,b) <- gcf | d <- 0  : ds]

-- |Evaluate the convergents of a continued fraction using Lentz's method,
-- (see 'lentz') with the additional rule that if a denominator ever goes
-- to zero, it will be replaced by a (very small) number of your choosing,
-- typically 1e-30 or so (this modification was proposed by Thompson and
-- Barnett).
--
-- Additionally splits the resulting list of convergents into sublists,
-- starting a new list every time the \'modification\' is invoked.
modifiedLentz :: Fractional a => a -> CF a -> [[a]]
modifiedLentz z (CF  b0 []) = [[b0]]
modifiedLentz z (GCF b0 []) = [[b0]]
modifiedLentz z (CF  0 (  a  :rest)) = map (map (1 /)) (modifiedLentz z (CF  a rest))
modifiedLentz z (GCF 0 ((a,b):rest)) = map (map (a /)) (modifiedLentz z (GCF b rest))
modifiedLentz z orig
= snd (mapAccumL multSublist b0 (separate cds))
where
(b0, gcf) = asGCF orig
multSublist b0 cds = let xs = scanl (*) b0 cds in (last xs, xs)

cds = zipWith (\(xa,xb) (ya,yb) -> (xa || ya, xb * yb)) cs ds
cs = [reset (b + a/c)    id | (a,b) <- gcf | c <- b0 : map snd cs]
ds = [reset (b + a*d) recip | (a,b) <- gcf | d <- 0  : map snd ds]

-- The sublist breaking is computed secondarily - initially,
-- 'cs' and 'ds' are constructed with this helper function that
-- adds a marker to the list whenever a term of interest goes to 0,
-- while also resetting that term to a small nonzero amount.
-- Then later, 'separate' breaks the list every time it sees one
-- of these markers.
reset x f
| x == 0    = (True,  f z)
| otherwise = (False, f x)

-- |Takes a list of (Bool,a) and breaks it into sublists, starting
-- a new one every time it encounters (True,_).
separate :: [(Bool,a)] -> [[a]]
separate [] = []
separate xs = case break fst xs of
([], x:xs)  -> case separate xs of
[]          -> [[snd x]]
(xs:rest)   -> (snd x:xs):rest
(xs, ys)            -> map snd xs : separate ys

-- |Euler's formula for computing @sum (map product (tail (inits xs)))@.
-- Successive convergents of the resulting 'CF' are successive partial sums
-- in the series.
sumPartialProducts :: Num a => [a] -> CF a
sumPartialProducts [] = cf 0 []
sumPartialProducts (x:xs) = gcf 0 ((x,1):[(negate x, 1 + x) | x <- xs])
```