{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ViewPatterns #-}
module Math.BernsteinPoly
  ( BernsteinPoly(..), bernsteinSubsegment, listToBernstein, zeroPoly
  , (~*), (*~), (~+), (~-), degreeElevate, bernsteinSplit, bernsteinEval
  , bernsteinEvalDeriv, binCoeff, convolve, bernsteinEvalDerivs, bernsteinDeriv
  ) where

import Data.Vector.Unboxed as V
import Data.Vector.Unboxed.Mutable as M
import qualified Data.Vector as B
-- | A polynomial in Bernstein form. A coefficient vector @[c0, .., cn]@
-- denotes the degree-@n@ polynomial
-- @sum_i ci * C(n,i) * t^i * (1-t)^(n-i)@ (the sum 'bernsteinEval' computes).
data BernsteinPoly a = BernsteinPoly
  { bernsteinCoeffs :: V.Vector a  -- ^ Bernstein coefficients, index 0 first
  }
  deriving Show
-- | Internal representation in which coefficient @i@ of a degree-@n@
-- polynomial has been pre-multiplied by the binomial coefficient @C(n,i)@
-- (see 'toScaled' and 'fromScaled'). In this form, multiplication and
-- degree elevation are plain convolutions.
data ScaledPoly a = ScaledPoly
  { scaledCoeffs :: V.Vector a  -- ^ binomially scaled coefficients
  }
  deriving Show
-- The multiplicative operators share the precedence of '*',
-- and the additive ones share the precedence of '+'.
infixl 7 ~*
infixl 7 *~
infixl 6 ~+
infixl 6 ~-
-- | Convert to the scaled representation: coefficient @i@ of a degree-@n@
-- polynomial (@n = length - 1@) is multiplied by @C(n,i)@.
-- An empty polynomial stays empty.
toScaled :: (Unbox a, Num a) => BernsteinPoly a -> ScaledPoly a
toScaled (BernsteinPoly v) =
  ScaledPoly (V.zipWith (*) v (binCoeff (V.length v - 1)))
-- | Inverse of 'toScaled': divide coefficient @i@ of a degree-@n@ scaled
-- polynomial (@n = length - 1@) by @C(n,i)@.
fromScaled :: (Unbox a, Fractional a) => ScaledPoly a -> BernsteinPoly a
fromScaled (ScaledPoly v) =
  BernsteinPoly (V.zipWith (/) v (binCoeff (V.length v - 1)))
-- | Build a polynomial from a list of Bernstein coefficients.
-- An empty list produces 'zeroPoly', so the result always has at least
-- one coefficient.
listToBernstein :: (Unbox a, Num a) => [a] -> BernsteinPoly a
listToBernstein coeffs = case coeffs of
  [] -> zeroPoly
  _  -> BernsteinPoly (V.fromList coeffs)
-- | The constant zero polynomial: a single coefficient equal to 0.
zeroPoly :: (Num a, Unbox a) => BernsteinPoly a
zeroPoly = BernsteinPoly (V.singleton 0)
-- | Restrict a polynomial to the parameter range between @t1@ and @t2@
-- (given in either order). After ordering so that @t1 <= t2@, the result
-- @r@ satisfies @r(s) = b(t1 + s * (t2 - t1))@.
--
-- The usual route splits at @t2@ and then splits the left piece at
-- @t1 / t2@. That ratio is undefined when @t2 == 0@, so in that case the
-- cut is made at @t1@ first and the right piece is then split at
-- @(t2 - t1) / (1 - t1)@.
bernsteinSubsegment :: (Unbox a, Ord a, Fractional a) =>
                       BernsteinPoly a -> a -> a -> BernsteinPoly a
bernsteinSubsegment b t1 t2
  | t1 > t2 = bernsteinSubsegment b t2 t1
  | t2 == 0 =
      -- Here t1 <= t2 == 0, so 1 - t1 >= 1 and the division is safe.
      fst $ flip bernsteinSplit ((t2 - t1) / (1 - t1)) $
        snd $ bernsteinSplit b t1
  | otherwise =
      snd $ flip bernsteinSplit (t1 / t2) $
        fst $ bernsteinSplit b t2
-- | Full linear convolution of two coefficient vectors. Element @k@ of the
-- result is the sum of @x[i] * h[j]@ over all @i + j == k@, and the result
-- has length @length x + length h - 1@. A negative length (both inputs
-- empty) is clamped to 0 by 'M.replicate'.
--
-- The inputs are read through 'V.unsafeIndex' rather than being
-- 'V.unsafeThaw'-ed: the caller still owns them, and 'V.unsafeThaw'
-- forbids further use of the immutable vector.
convolve :: (Unbox a, Num a) => V.Vector a -> V.Vector a -> V.Vector a
convolve x h = V.create $ do
  y <- M.replicate (xN + hN - 1) 0
  V.forM_ (V.enumFromN 0 xN) $ \i -> do
    let a = V.unsafeIndex x i
    V.forM_ (V.enumFromN 0 hN) $ \j ->
      M.unsafeModify y (+ a * V.unsafeIndex h j) (i + j)
  return y
  where
    xN = V.length x
    hN = V.length h
-- | Product of two Bernstein polynomials. Both operands are converted to
-- scaled form, where multiplication is a convolution of coefficients.
(~*) :: (Unbox a, Fractional a) =>
        BernsteinPoly a -> BernsteinPoly a -> BernsteinPoly a
p ~* q = fromScaled (mulScaled (toScaled p) (toScaled q))
-- | Multiply two scaled polynomials. The scaled coefficients of the
-- product are the convolution of the operands' scaled coefficients.
mulScaled :: (Unbox a, Num a) => ScaledPoly a -> ScaledPoly a -> ScaledPoly a
mulScaled (ScaledPoly xs) (ScaledPoly ys) = ScaledPoly (convolve xs ys)
-- | Row @n@ of Pascal's triangle, @[C(n,0), C(n,1), .., C(n,n)]@, converted
-- to @a@. A non-positive @n@ yields the single-element row @[1]@.
--
-- Each entry comes from the previous one via
-- @C(n,k+1) = C(n,k) * (n-k) / (k+1)@, which is an exact division.
-- The running value is kept as an 'Integer' because the intermediate
-- product overflows 'Int' once @n@ is in the low sixties.
binCoeff :: (Num a, Unbox a) => Int -> V.Vector a
binCoeff n = V.unfoldrN (max 0 n + 1) step (1, 0)
  where
    nI = toInteger n
    step (c, k) = Just (fromInteger c, (c * (nI - k) `quot` (k + 1), k + 1))
-- | Raise the degree of a scaled polynomial by @times@ without changing the
-- polynomial it denotes. In scaled form this is convolution with the
-- binomial row @C(times, j)@, i.e. multiplication by @(t + (1-t))^times@.
-- A non-positive @times@ returns the input unchanged.
degreeElevateScaled :: (Unbox a, Num a)
                    => ScaledPoly a -> Int -> ScaledPoly a
degreeElevateScaled sp@(ScaledPoly coeffs) times =
  if times > 0
    then ScaledPoly (convolve (binCoeff times) coeffs)
    else sp
-- | Raise the degree of a Bernstein polynomial by @times@, keeping the same
-- polynomial. The work is done in scaled form, via 'degreeElevateScaled'.
degreeElevate :: (Unbox a, Fractional a)
              => BernsteinPoly a -> Int -> BernsteinPoly a
degreeElevate poly times =
  fromScaled (degreeElevateScaled (toScaled poly) times)
-- | Evaluate a Bernstein polynomial at @t@, i.e. compute
-- @sum_i c_i * C(n,i) * t^i * (1-t)^(n-i)@ with @n = length - 1@.
-- An empty polynomial evaluates to 0.
--
-- This is a single Horner-like pass: @t^i@ and @C(n,i)@ are updated
-- incrementally, and the running sum is multiplied by @(1-t)@ at each
-- step, so no power or binomial tables are built.
bernsteinEval :: (Unbox a, Fractional a)
              => BernsteinPoly a -> a -> a
bernsteinEval (BernsteinPoly v) t
  | V.null v  = 0
  | n == 0    = V.unsafeHead v
  | otherwise = go t (fromIntegral n) (V.unsafeIndex v 0 * u) 1
  where
    u = 1 - t
    n = V.length v - 1  -- degree
    -- Loop invariant at index i:
    --   tn  = t^i
    --   bc  = C(n,i)
    --   acc = sum_{k<i} c_k * C(n,k) * t^k * (1-t)^(i-k)
    go !tn !bc !acc !i
      | i == n    = acc + tn * V.unsafeIndex v n
      | otherwise =
          go (tn * t)
             (bc * fromIntegral (n - i) / (fromIntegral i + 1))
             ((acc + tn * bc * V.unsafeIndex v i) * u)
             (i + 1)
-- | Evaluate a polynomial and its first derivative at @t@, returning
-- @(value, derivative)@. A constant has derivative 0. An empty polynomial
-- gives @(0, 0)@ instead of reading the head of an empty vector.
bernsteinEvalDeriv :: (Unbox t, Fractional t) => BernsteinPoly t -> t -> (t,t)
bernsteinEvalDeriv b@(BernsteinPoly v) t
  | V.null v        = (0, 0)
  | V.length v == 1 = (V.unsafeHead v, 0)
  | otherwise       = (bernsteinEval b t, bernsteinEval (bernsteinDeriv b) t)
-- | The value and all successive derivatives at @t@. For a polynomial of
-- degree @n@ the result is @[p(t), p'(t), .., p^(n)(t), 0]@. An empty
-- polynomial gives @[0, 0]@ instead of reading the head of an empty vector.
bernsteinEvalDerivs :: (Unbox t, Fractional t) => BernsteinPoly t -> t -> [t]
bernsteinEvalDerivs b@(BernsteinPoly v) t
  | V.null v        = [0, 0]
  | V.length v == 1 = [V.unsafeHead v, 0]
  | otherwise       =
      bernsteinEval b t : bernsteinEvalDerivs (bernsteinDeriv b) t
-- | Derivative of a Bernstein polynomial. For degree @n >= 1@ the result has
-- degree @n-1@ and coefficients @(c_{i+1} - c_i) * n@. Constants and the
-- empty polynomial differentiate to 'zeroPoly', so the result always has
-- at least one coefficient.
bernsteinDeriv :: (Unbox a, Num a) => BernsteinPoly a -> BernsteinPoly a
bernsteinDeriv (BernsteinPoly v)
  | V.length v <= 1 = zeroPoly
  | otherwise =
      BernsteinPoly $
        V.map (* fromIntegral (V.length v - 1)) $
          V.zipWith (-) (V.tail v) v
-- | Split a polynomial at parameter @t@ using de Casteljau's algorithm.
-- Returns @(left, right)@: @left@ covers parameters 0..t and @right@
-- covers t..1, each reparametrised over its own unit interval.
bernsteinSplit :: (Unbox a, Num a) =>
                  BernsteinPoly a -> a -> (BernsteinPoly a, BernsteinPoly a)
bernsteinSplit (BernsteinPoly v) t =
  ( BernsteinPoly (V.convert (B.map V.head rows))
  , BernsteinPoly (V.reverse (V.convert (B.map V.last rows)))
  )
  where
    interp a b = (1 - t) * a + t * b
    -- Row k is the k-th de Casteljau level and has (length v - k) points.
    -- The first points of the rows are the left control points; the last
    -- points, reversed, are the right control points.
    rows = B.iterateN (V.length v) step v
    step w = V.zipWith interp w (V.tail w)
-- | Add two scaled polynomials. Each operand is first degree-elevated to
-- the larger of the two lengths, then coefficients are added pointwise.
addScaled :: (Unbox a, Num a) => ScaledPoly a -> ScaledPoly a -> ScaledPoly a
addScaled pa@(ScaledPoly a) pb@(ScaledPoly b) =
  ScaledPoly (V.zipWith (+) (elevateTo pa la) (elevateTo pb lb))
  where
    la = V.length a
    lb = V.length b
    target = max la lb
    -- Elevating by k grows the length by k; degreeElevateScaled returns its
    -- input unchanged for k <= 0, so the longer operand is left alone.
    elevateTo p len = scaledCoeffs (degreeElevateScaled p (target - len))
-- | Sum of two Bernstein polynomials of possibly different degrees. The
-- addition is done in scaled form via 'addScaled'.
(~+) :: (Unbox a, Fractional a) =>
        BernsteinPoly a -> BernsteinPoly a -> BernsteinPoly a
p ~+ q = fromScaled (addScaled (toScaled p) (toScaled q))
-- | Subtract the second scaled polynomial from the first. Each operand is
-- first degree-elevated to the larger of the two lengths, then coefficients
-- are subtracted pointwise.
subScaled :: (Unbox a, Num a) => ScaledPoly a -> ScaledPoly a -> ScaledPoly a
subScaled pa@(ScaledPoly a) pb@(ScaledPoly b) =
  ScaledPoly (V.zipWith (-) (elevateTo pa la) (elevateTo pb lb))
  where
    la = V.length a
    lb = V.length b
    target = max la lb
    -- Elevating by k grows the length by k; degreeElevateScaled returns its
    -- input unchanged for k <= 0, so the longer operand is left alone.
    elevateTo p len = scaledCoeffs (degreeElevateScaled p (target - len))
-- | Difference of two Bernstein polynomials of possibly different degrees.
-- The subtraction is done in scaled form via 'subScaled'.
(~-) :: (Unbox a, Fractional a) =>
        BernsteinPoly a -> BernsteinPoly a -> BernsteinPoly a
p ~- q = fromScaled (subScaled (toScaled p) (toScaled q))
-- | Multiply every Bernstein coefficient by a scalar. Each coefficient is
-- the left operand of the product.
(*~) :: (Unbox a, Num a) => a -> BernsteinPoly a -> BernsteinPoly a
(*~) s (BernsteinPoly coeffs) = BernsteinPoly (V.map (* s) coeffs)