-- | Bernstein basis polynomials, plus evaluation, fitting and splitting of
-- polynomials expressed as coefficient lists over that basis (using a
-- de Casteljau tableau).
module Math.Polynomial.Bernstein
( bernstein
, evalBernstein
, bernsteinFit
, evalBernsteinSeries
, deCasteljau
, splitBernsteinSeries
) where
import Math.Polynomial
import Data.List
-- | The Bernstein basis polynomials, as an infinite list of rows.
--
-- Row @n@ (counting from 0) has @n + 1@ entries; entry @v@ of that row is
-- built as @C(n, v) * x^v * (1 - x)^(n - v)@, where the binomial
-- coefficients @C(n, v)@ are taken from Pascal's triangle.
--
-- Written with 'zipWith3' rather than a parallel list comprehension so the
-- definition does not depend on the @ParallelListComp@ extension.
bernstein :: [[Poly Integer]]
bernstein = zipWith3 basisRow xPowerPrefixes oneMinusXPowerPrefixes pascal
  where
    -- One row of the basis: pair ascending powers of x with descending
    -- powers of (1 - x) and scale by the matching binomial coefficient.
    basisRow ps qs coeffs = zipWith3 term (reverse qs) ps coeffs

    term q p nCv = scalePoly nCv p `multPoly` q

    -- Non-empty prefixes of [1, x, x^2, ...]; each power is given by its
    -- big-endian coefficients: a leading 1 followed by zeros.
    xPowerPrefixes =
      drop 1 $ inits [poly BE (1 : zs) | zs <- inits (repeat 0)]

    -- Non-empty prefixes of [1, (1 - x), (1 - x)^2, ...].  The little-endian
    -- coefficients of the factor are [1, -1], i.e. 1 - x.
    oneMinusXPowerPrefixes =
      drop 1 $ inits (iterate (multPoly (poly LE [1, -1])) one)

    -- Pascal's triangle: each row is derived from the previous one by
    -- summing adjacent entries and bracketing the result with 1s.
    pascal = [1] : [1 : zipWith (+) row (drop 1 row) ++ [1] | row <- pascal]
-- | @evalBernstein n v t@ evaluates the @v@'th Bernstein polynomial of
-- order @n@ at the point @t@, i.e. @C(n, v) * t^v * (1 - t)^(n - v)@.
--
-- The result is 0 whenever the index is out of range (@n < 0@, @v < 0@ or
-- @v > n@).
evalBernstein :: (Integral a, Num b) => a -> a -> b -> b
evalBernstein n v t
  -- The v < 0 check keeps a negative index from reaching @t ^ v@, which
  -- would raise a "negative exponent" error instead of yielding 0.
  | n < 0 || v < 0 || v > n = 0
  | otherwise = fromInteger nCv * t ^ v * (1 - t) ^ (n - v)
  where
    -- Work in Integer so the binomial coefficient cannot overflow a
    -- fixed-width index type.
    n' = toInteger n
    v' = toInteger v
    -- C(n, v) = (n - v + 1) * ... * n / v!; the division is exact.
    nCv = product [n' - v' + 1 .. n'] `div` product [1 .. v']
-- | @bernsteinFit n f@ samples @f@ at the @n + 1@ evenly spaced points
-- @0/n, 1/n, ..., n/n@ and returns the samples in that order, for use as
-- coefficients over the order-@n@ Bernstein basis.
--
-- A negative @n@ yields an empty list.  Note that for @n == 0@ the single
-- sample point is computed as @0 / 0@.
bernsteinFit :: (Fractional b, Integral a) => a -> (b -> b) -> [b]
bernsteinFit n f = map (f . samplePoint) [0 .. n]
  where
    samplePoint v = fromIntegral v / fromIntegral n
-- | Evaluate, at a given point, a polynomial supplied as a list of
-- coefficients over a Bernstein basis.
--
-- An empty coefficient list evaluates to 0 everywhere.  Otherwise the value
-- is read off the bottom of the 'deCasteljau' tableau.
evalBernsteinSeries :: Num a => [a] -> a -> a
evalBernsteinSeries cs = case cs of
  [] -> const 0
  _ -> \t -> apex (deCasteljau cs t)
  where
    -- For non-empty input the final tableau row holds exactly one value.
    apex = head . last
-- | de Casteljau's algorithm, returning the whole tableau.
--
-- The first row is the coefficient list @cs@ itself.  Every following row
-- is one element shorter and is obtained by blending each pair of adjacent
-- entries @b0@, @b1@ of the previous row as @b0 * (1 - t) + b1 * t@.  The
-- tableau ends with a single-element row; an empty @cs@ gives an empty
-- tableau.
deCasteljau :: Num a => [a] -> a -> [[a]]
deCasteljau cs t = takeWhile (not . null) (iterate nextRow cs)
  where
    -- Blend neighbouring entries; a one-element (or empty) row maps to [],
    -- which is what terminates the tableau.
    nextRow row = zipWith blend row (drop 1 row)
    blend b0 b1 = b0 * (1 - t) + b1 * t
-- | Split a polynomial in Bernstein form at the parameter value @t@.
--
-- The first component collects the first entry of every 'deCasteljau'
-- tableau row, top to bottom; the second collects the last entry of every
-- row, bottom to top.  In de Casteljau subdivision these are the coefficient
-- lists of the pieces over @[0, t]@ and @[t, 1]@, each reparameterised onto
-- @[0, 1]@.
splitBernsteinSeries :: Num a => [a] -> a -> ([a], [a])
splitBernsteinSeries cs t = (leftEdge, rightEdge)
  where
    tableau = deCasteljau cs t
    leftEdge = [head row | row <- tableau]
    rightEdge = [last row | row <- reverse tableau]