{-# LANGUAGE ParallelListComp #-}
module Math.Polynomial.Bernstein
    ( bernstein
    , evalBernstein
    , bernsteinFit
    , evalBernsteinSeries
    , deCasteljau
    , splitBernsteinSeries
    ) where

import Math.Polynomial
import Data.List

-- |The Bernstein basis polynomials.  The @n@th inner list is a basis for 
-- the polynomials of order @n@ or lower.  The @n@th basis consists of @n@
-- polynomials of order @n@ which sum to @1@, and have roots of varying 
-- multiplicities at @0@ and @1@.
bernstein :: [[Poly Integer]]
bernstein = 
    [ [ scalePoly nCv p `multPoly` q
      | q <- reverse qs
      | p <- ps
      | nCv  <- bico
      ]
    | ps <- tail $ inits [poly BE (1 : zs) | zs <- inits (repeat 0)]
    | qs <- tail $ inits (iterate (multPoly (poly LE [1,-1])) one)
    | bico <- ptri
    ]
    where
        -- pascal's triangle
        ptri = [1] : [ 1 : zipWith (+) row (tail row) ++ [1] | row <- ptri]

-- |@evalBernstein n v x@ evaluates the @v@'th Bernstein polynomial of order @n@
-- at the point @x@.
evalBernstein :: (Integral a, Num b) => a -> a -> b -> b
evalBernstein n v t
    | n < 0 || v > n    = 0
    | otherwise         = fromInteger nCv * t^v * (1-t)^(n-v)
    where
        n' = toInteger n
        v' = toInteger v
        nCv = product [1..n'] `div` (product [1..v'] * product [1..n'-v'])

-- |@bernsteinFit n f@: Approximate a function @f@ as a linear combination of
-- Bernstein polynomials of order @n@.  This approximation converges slowly
-- but uniformly to @f@ on the interval [0,1].
bernsteinFit :: (Fractional b, Integral a) => a -> (b -> b) -> [b]
bernsteinFit n f = [f (fromIntegral v / fromIntegral n) | v <- [0..n]]

-- |Evaluate a polynomial given as a list of @n@ coefficients for the @n@th
-- Bernstein basis.  Roughly:
-- 
-- > evalBernsteinSeries cs = sum (zipWith scalePoly cs (bernstein !! (length cs - 1)))
evalBernsteinSeries :: Num a => [a] -> a -> a
evalBernsteinSeries [] = const 0
evalBernsteinSeries cs = head . last . deCasteljau cs

-- |de Casteljau's algorithm, returning the whole tableau.  Used both for
-- evaluating and splitting polynomials in Bernstein form.
deCasteljau :: Num a => [a] -> a -> [[a]]
deCasteljau cs t = takeWhile (not.null) table
    where
        table = cs : 
            [ [ b_i * (1-t) + b_ip1 * t
              | b_i:b_ip1:_ <- tails row
              ]
            | row <- table
            ]

-- |Given a polynomial in Bernstein form (that is, a list of coefficients
-- for a basis set from 'bernstein', such as is returned by 'bernsteinFit')
-- and a parameter value @x@, split the polynomial into two halves, mapping
-- @[0,x]@ and @[x,1]@ respectively onto @[0,1]@.
--
-- A typical use for this operation would be to split a Bezier curve 
-- (inserting a new knot at @x@).
splitBernsteinSeries :: Num a => [a] -> a -> ([a], [a])
splitBernsteinSeries cs t = (map head betas, map last (reverse betas))
    where
        betas = deCasteljau cs t