module Numeric.Tools.Interpolation (
Interpolation(..)
, tabulate
, tabulateFun
, LinearInterp
, linearInterp
, CubicSpline
, cubicSpline
, module Numeric.Tools.Mesh
, defaultInterpSize
, defaultInterpIndex
) where
import Control.Monad.ST (runST)
import Data.Data (Data,Typeable)
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as M
import Control.Monad.Numeric
import Numeric.Classes.Indexing
import Numeric.Tools.Mesh
-- | Interpolation of a function tabulated on a mesh of type @mesh@.
--
-- An interpolation couples a mesh with a table of function values, one per
-- mesh node. It is itself 'Indexable', and indexing it yields a pair of
-- 'Double's (the instances in this module return @(node, value)@).
class ( Mesh mesh
      , IndexVal mesh ~ Double
      , Indexable (interp mesh)
      , IndexVal (interp mesh) ~ (Double,Double)
      ) => Interpolation interp mesh where
  -- | Evaluate the interpolated function at the given point.
  at :: interp mesh -> Double -> Double
  -- | Build an interpolation from a mesh and a table of values without
  --   checking that the table length matches the mesh size ('tabulate' is
  --   the checked variant).
  unsafeTabulate :: (G.Vector v Double) => mesh -> v Double -> interp mesh
  -- | The mesh the interpolation was built on.
  interpolationMesh :: interp mesh -> mesh
  -- | The tabulated function values, one per mesh node.
  interpolationTable :: interp mesh -> U.Vector Double
-- | Sample a function at every node of the mesh and build an interpolation
--   from the sampled values. The table length equals the mesh size by
--   construction, so the unchecked 'unsafeTabulate' is used.
tabulateFun :: (Interpolation i m) => m -> (Double -> Double) -> i m
tabulateFun mesh f = unsafeTabulate mesh samples
  where
    samples = U.generate (size mesh) (\k -> f (unsafeIndex mesh k))
-- | Build an interpolation from a mesh and a table of values, one value per
--   mesh node. Calls 'error' when the table length differs from the mesh
--   size.
tabulate :: (Interpolation i m, G.Vector v Double) => m -> v Double -> i m
tabulate mesh tbl
  | sizesAgree = unsafeTabulate mesh tbl
  | otherwise  = error "Numeric.Tools.Interpolation.tabulate: size of vector and mesh do not match"
  where
    sizesAgree = G.length tbl == size mesh
-- | Piecewise-linear interpolation: a mesh together with the function values
--   at its nodes.
data LinearInterp mesh = LinearInterp
  { linearInterpMesh  :: mesh            -- ^ Interpolation nodes.
  , linearInterpTable :: U.Vector Double -- ^ Function values at the nodes.
  }
  deriving (Show, Eq, Data, Typeable)
-- | Identity restricted to 'LinearInterp'. It only pins down the type, for
--   example @linearInterp (tabulateFun mesh f)@ selects linear interpolation.
linearInterp :: LinearInterp mesh -> LinearInterp mesh
linearInterp interp = interp
-- | Indexing a 'LinearInterp' yields @(node, value)@ pairs; both methods
--   delegate to the shared interpolation defaults.
instance (Mesh mesh, IndexVal mesh ~ Double) => Indexable (LinearInterp mesh) where
  type IndexVal (LinearInterp mesh) = (IndexVal mesh, Double)
  size interp          = defaultInterpSize interp
  unsafeIndex interp k = defaultInterpIndex interp k
-- | Linear interpolation on the mesh segment chosen by 'safeFindIndex'.
instance (Mesh mesh, IndexVal mesh ~ Double) => Interpolation LinearInterp mesh where
  at = linearInterpolation
  -- The table is converted to an unboxed vector; no size check is made.
  unsafeTabulate mesh vals =
    LinearInterp { linearInterpMesh  = mesh
                 , linearInterpTable = G.convert vals
                 }
  interpolationMesh  interp = linearInterpMesh interp
  interpolationTable interp = linearInterpTable interp
-- | Evaluate a linear interpolation at @x@.
--
-- The segment @[xa, xb]@ (nodes @i@ and @i+1@) is chosen by 'safeFindIndex',
-- which clamps the index to a valid segment. For points reported outside the
-- mesh, the end segment's line is therefore extended (assuming
-- 'meshFindIndex' reports such points as out of range — TODO confirm).
linearInterpolation :: (Mesh a, IndexVal a ~ Double) => LinearInterp a -> Double -> Double
linearInterpolation tbl@(LinearInterp mesh _) x = a + (x - xa) / (xb - xa) * (b - a)
  where
    i       = safeFindIndex mesh x
    (xa, a) = unsafeIndex tbl i
    (xb, b) = unsafeIndex tbl (i + 1)
-- | Cubic spline interpolation: the mesh, the function values at its nodes,
--   and the spline's second derivatives at the nodes (computed by
--   'makeCubicSpline').
data CubicSpline a = CubicSpline
  { cubicSplineMesh  :: a               -- ^ Interpolation nodes.
  , cubicSplineTable :: U.Vector Double -- ^ Function values at the nodes.
  , _cubicSplineY2   :: U.Vector Double -- ^ Second derivatives at the nodes.
  }
  deriving (Eq, Show, Data, Typeable)
-- | Identity restricted to 'CubicSpline'. It only pins down the type, for
--   example @cubicSpline (tabulateFun mesh f)@ selects spline interpolation.
cubicSpline :: CubicSpline a -> CubicSpline a
cubicSpline spline = spline
-- | Indexing a 'CubicSpline' yields @(node, value)@ pairs; both methods
--   delegate to the shared interpolation defaults.
instance (Mesh mesh, IndexVal mesh ~ Double) => Indexable (CubicSpline mesh) where
  type IndexVal (CubicSpline mesh) = (IndexVal mesh, Double)
  size spline          = defaultInterpSize spline
  unsafeIndex spline k = defaultInterpIndex spline k
-- | Cubic spline evaluation on the mesh segment chosen by 'safeFindIndex'.
instance (Mesh mesh, IndexVal mesh ~ Double) => Interpolation CubicSpline mesh where
  -- Standard spline evaluation on the segment [xa, xb]:
  --   y = A*ya + B*yb + ((A^3 - A)*y2a + (B^3 - B)*y2b) * h^2 / 6
  -- where h = xb - xa, A = (xb - x)/h and B = (x - xa)/h.
  at (CubicSpline mesh ys y2) x = y
    where
      i  = safeFindIndex mesh x
      xa = unsafeIndex mesh i
      xb = unsafeIndex mesh (i + 1)
      ya = unsafeIndex ys i
      yb = unsafeIndex ys (i + 1)
      da = unsafeIndex y2 i
      db = unsafeIndex y2 (i + 1)
      h  = xb - xa
      a  = (xb - x ) / h
      b  = (x  - xa) / h
      y  = a * ya + b * yb
         + ((a*a*a - a) * da + (b*b*b - b) * db) * (h * h) / 6
  -- Second derivatives are computed once, when the table is built.
  unsafeTabulate mesh tbl = makeCubicSpline mesh (G.convert tbl)
  interpolationMesh  = cubicSplineMesh
  interpolationTable = cubicSplineTable
-- | Build a natural cubic spline (second derivative zero at both end nodes)
--   through the values @ys@ given at the nodes of mesh @xs@.
--
-- For each interior node @i@, the second derivatives @y2@ satisfy
--
-- > sig_i * y2_{i-1} + 2 * y2_i + (1 - sig_i) * y2_{i+1}
-- >   = 6 * (slope_{i+1} - slope_i) / (x_{i+1} - x_{i-1})
--
-- where @sig_i = (x_i - x_{i-1}) / (x_{i+1} - x_{i-1})@ and
-- @slope_i = (y_i - y_{i-1}) / (x_i - x_{i-1})@. The tridiagonal system is
-- solved in two passes. A forward elimination sweep stores, for each node,
-- the coefficient in @y2@ and the right-hand side in @u@. Back substitution
-- then follows.
--
-- Assumes @size xs == U.length ys@ (checked by 'tabulate') and at least two
-- nodes.
makeCubicSpline :: (IndexVal a ~ Double, Mesh a) => a -> U.Vector Double -> CubicSpline a
makeCubicSpline xs ys = runST $ do
  let n = size ys
  y2 <- M.new n
  u  <- M.new n
  -- Natural boundary condition at the left end.
  M.write y2 0 0.0
  M.write u  0 0.0
  -- Forward sweep over the interior nodes. NOTE(review): assumes `for a b`
  -- visits a .. b-1 (i.e. i = 1 .. n-2 here); confirm in Control.Monad.Numeric.
  for 1 (n - 1) $ \i -> do
    yVal <- M.read y2 (i - 1)
    uVal <- M.read u  (i - 1)
    let hPrev = delta xs i                  -- x_i     - x_{i-1}
        hNext = delta xs (i + 1)            -- x_{i+1} - x_i
        width = xs ! (i + 1) - xs ! (i - 1) -- x_{i+1} - x_{i-1}
        -- Relative position of x_i inside [x_{i-1}, x_{i+1}]. Dividing by
        -- hNext instead of width would drop the super-diagonal on uniform
        -- meshes and produce wrong second derivatives.
        sig   = hPrev / width
        p     = sig * yVal + 2
        u'    = delta ys (i + 1) / hNext - delta ys i / hPrev
    M.write y2 i $ (sig - 1) / p
    M.write u  i $ (6 * u' / width - sig * uVal) / p
  -- Natural boundary condition at the right end.
  M.write y2 (n - 1) 0.0
  -- Back substitution from node n-2 down to node 0.
  forGen (n - 2) (>= 0) pred $ \i -> do
    uVal  <- M.read u  i
    yVal  <- M.read y2 i
    yVal1 <- M.read y2 (i + 1)
    M.write y2 i $ yVal * yVal1 + uVal
  y2' <- G.unsafeFreeze y2
  return (CubicSpline xs ys y2')
-- | Backward difference @tbl ! i - tbl ! (i - 1)@. The index @i@ must be at
--   least 1 so that @i - 1@ is a valid index.
delta :: (Num (IndexVal a), Indexable a) => a -> Int -> IndexVal a
delta tbl i = (tbl ! i) - (tbl ! (i - 1))
-- | Index @i@ of the mesh segment @[node i, node (i+1)]@ used to interpolate
--   at @x@.
--
-- The result of 'meshFindIndex' is clamped to @[0, size mesh - 2]@, so that
-- both @i@ and @i + 1@ are valid node indices when the mesh has at least two
-- nodes.
safeFindIndex :: Mesh a => a -> Double -> Int
safeFindIndex mesh x =
  case meshFindIndex mesh x of
    i | i < 0     -> 0
      | i > n     -> n
      | otherwise -> i
  where
    -- Index of the last segment, whose right end is node size mesh - 1.
    n = size mesh - 2
-- | Default 'size' for interpolations: the number of tabulated values.
defaultInterpSize :: Interpolation i m => i m -> Int
defaultInterpSize interp = U.length (interpolationTable interp)
-- | Default 'unsafeIndex' for interpolations: the @i@-th mesh node paired
--   with the @i@-th tabulated value. Both lookups use 'unsafeIndex', so the
--   caller is responsible for @i@ being in range.
defaultInterpIndex :: Interpolation i m => i m -> Int -> (Double, Double)
defaultInterpIndex interp i = (node, value)
  where
    node  = unsafeIndex (interpolationMesh interp) i
    value = unsafeIndex (interpolationTable interp) i