{-# LANGUAGE FlexibleContexts   #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE TypeFamilies       #-}
-- |
-- Module    : Numeric.Tools.Interpolation
-- Copyright : (c) 2011 Aleksey Khudyakov
-- License   : BSD3
--
-- Maintainer  : Aleksey Khudyakov
-- Stability   : experimental
-- Portability : portable
--
-- Function interpolation.
--
-- Sine interpolation using cubic splines (the interpolated value agrees
-- with the exact one to well within the tolerance below):
--
-- >>> let tbl = cubicSpline $ tabulateFun (uniformMesh (0,10) 100) sin
-- >>> abs (tbl `at` 1.786 - sin 1.786) < 1e-4
-- True
module Numeric.Tools.Interpolation (
    -- * Type class
    Interpolation(..)
  , tabulate
    -- * Linear interpolation
  , LinearInterp
  , linearInterp
    -- * Cubic splines
  , CubicSpline
  , cubicSpline
    -- * Meshes
  , module Numeric.Tools.Mesh
  ) where

import Control.Monad.ST (runST)
import Data.Data        (Data,Typeable)

import qualified Data.Vector.Generic         as G
import qualified Data.Vector.Unboxed         as U
import qualified Data.Vector.Unboxed.Mutable as M

import Control.Monad.Numeric
import Numeric.Classes.Indexing
import Numeric.Tools.Mesh


----------------------------------------------------------------

-- | Interpolation over arbitrary meshes.
class Interpolation a where
  -- | Evaluate the interpolated function at a point. Evaluation should
  --   not fail outside of the mesh; however, the result there is an
  --   extrapolation and will most likely be nonsensical.
  at :: (IndexVal m ~ Double, Mesh m) => a m -> Double -> Double
  -- | Evaluate a function at every node of the mesh and build the
  --   interpolation from the resulting table.
  tabulateFun :: (IndexVal m ~ Double, Mesh m) => m -> (Double -> Double) -> a m
  -- | Build the interpolation from a mesh and a table of already
  --   evaluated function values. The sizes of the mesh and the table
  --   must coincide, but this is not checked. Prefer 'tabulate', which
  --   performs the check.
  unsafeTabulate :: (IndexVal m ~ Double, Mesh m, G.Vector v Double) => m -> v Double -> a m
  -- | Get the mesh.
  interpolationMesh :: a m -> m
  -- | Get the table of function values at the mesh nodes.
  interpolationTable :: a m -> U.Vector Double

-- | Build the interpolation from a mesh and a table of already evaluated
--   function values. The sizes of the mesh and the table must coincide;
--   a mismatch is reported with 'error'.
tabulate :: (Interpolation a, IndexVal m ~ Double, Mesh m, G.Vector v Double)
         => m -> v Double -> a m
tabulate mesh tbl
  | size mesh /= G.length tbl = error "Numeric.Tools.Interpolation.tabulate: size of vector and mesh do not match"
  | otherwise                 = unsafeTabulate mesh tbl
{-# INLINE tabulate #-}


----------------------------------------------------------------
-- Linear interpolation
----------------------------------------------------------------

-- | Data for linear interpolation: the mesh together with the function
--   values at its nodes.
data LinearInterp a = LinearInterp
  { linearInterpMesh  :: a
  , linearInterpTable :: U.Vector Double
  }
  deriving (Show,Eq,Data,Typeable)

-- | Identity function used to fix the interpolation type, e.g.
--   @linearInterp $ tabulateFun mesh f@.
linearInterp :: LinearInterp a -> LinearInterp a
linearInterp = id

-- Indexing a linear interpolation yields pairs of
-- (mesh node, function value at that node).
instance Mesh a => Indexable (LinearInterp a) where
  type IndexVal (LinearInterp a) = (IndexVal a, Double)
  size (LinearInterp _ vec) = size vec
  unsafeIndex (LinearInterp mesh vec) i = ( unsafeIndex mesh i
                                          , unsafeIndex vec  i
                                          )
  {-# INLINE size        #-}
  {-# INLINE unsafeIndex #-}

instance Interpolation LinearInterp where
  at                      = linearInterpolation
  tabulateFun mesh f      = LinearInterp mesh (U.generate (size mesh) (f . unsafeIndex mesh))
  unsafeTabulate mesh tbl = LinearInterp mesh (G.convert tbl)
  interpolationMesh       = linearInterpMesh
  interpolationTable      = linearInterpTable

-- Straight-line interpolation between the two nodes bracketing @x@.
-- 'safeFindIndex' clamps the interval index, so points outside the mesh
-- are extrapolated from the first or last interval.
linearInterpolation :: (Mesh a, IndexVal a ~ Double) => LinearInterp a -> Double -> Double
linearInterpolation tbl@(LinearInterp mesh _) x
  = a + (x - xa) / (xb - xa) * (b - a)
  where
    i      = safeFindIndex mesh x
    (xa,a) = unsafeIndex tbl i
    (xb,b) = unsafeIndex tbl (i+1)


----------------------------------------------------------------
-- Cubic splines
----------------------------------------------------------------

-- | Natural cubic splines: the mesh, the function values at its nodes and
--   the spline's second derivatives at those nodes.
data CubicSpline a = CubicSpline
  { cubicSplineMesh  :: a
  , cubicSplineTable :: U.Vector Double
  , cubicSplineY2    :: U.Vector Double
  }
  deriving (Eq,Show,Data,Typeable)

-- | Identity function used to fix the interpolation type, e.g.
--   @cubicSpline $ tabulateFun mesh f@.
cubicSpline :: CubicSpline a -> CubicSpline a
cubicSpline = id

instance Interpolation CubicSpline where
  -- Evaluate the cubic on the interval [xa, xb] chosen by 'safeFindIndex'
  -- from the end-point values and second derivatives.
  at (CubicSpline mesh ys y2) x = y
    where
      i  = safeFindIndex mesh x
      -- Table lookup
      xa = unsafeIndex mesh i
      xb = unsafeIndex mesh (i+1)
      ya = unsafeIndex ys i
      yb = unsafeIndex ys (i+1)
      da = unsafeIndex y2 i
      db = unsafeIndex y2 (i+1)
      -- Relative position of x inside the interval
      h  = xb - xa
      a  = (xb - x ) / h
      b  = (x  - xa) / h
      y  = a * ya + b * yb
         + ((a*a*a - a) * da + (b*b*b - b) * db) * (h * h) / 6
  tabulateFun mesh f      = makeCubicSpline mesh (U.generate (size mesh) (f . unsafeIndex mesh))
  unsafeTabulate mesh tbl = makeCubicSpline mesh (G.convert tbl)
  interpolationMesh       = cubicSplineMesh
  interpolationTable      = cubicSplineTable

-- Build a natural cubic spline (second derivative zero at both end
-- points) from a mesh and the function values at its nodes.
--
-- The nodal second derivatives are obtained by solving the spline's
-- tridiagonal system with a forward-elimination / back-substitution sweep
-- (the scheme of Numerical Recipes' @spline@ routine). The mesh and the
-- table are expected to have the same size; this is not checked here.
-- NOTE(review): 'at' reads nodes i and i+1, so the mesh needs at least
-- two points for the result to be usable.
makeCubicSpline :: (IndexVal a ~ Double, Mesh a) => a -> U.Vector Double -> CubicSpline a
makeCubicSpline xs ys = runST $ do
  let n = size ys
  y2 <- M.new n
  u  <- M.new n
  -- Natural boundary condition at the left end.
  M.write y2 0 0.0
  M.write u  0 0.0
  -- Forward elimination over the interior nodes.
  for 1 (n-1) $ \i -> do
    yVal <- M.read y2 (i-1)
    uVal <- M.read u  (i-1)
    let xspan = xs ! (i+1) - xs ! (i-1)
        -- Relative position of node i inside [x_{i-1}, x_{i+1}]; the
        -- denominator is the full two-interval span, not x_{i+1} - x_i.
        sig   = delta xs i / xspan
        p     = sig * yVal + 2
        u'    = delta ys (i+1) / delta xs (i+1) - delta ys i / delta xs i
    M.write y2 i $ (sig - 1) / p
    M.write u  i $ (6 * u' / xspan - sig * uVal) / p
  -- Natural boundary condition at the right end, then back-substitution.
  M.write y2 (n-1) 0.0
  forGen (n-2) (>= 0) pred $ \i -> do
    uVal  <- M.read u  i
    yVal  <- M.read y2 i
    yVal1 <- M.read y2 (i+1)
    M.write y2 i $ yVal * yVal1 + uVal
  -- Done
  y2' <- G.unsafeFreeze y2
  return (CubicSpline xs ys y2')


----------------------------------------------------------------
-- Helpers

-- Difference between element @i@ and element @i-1@ of an indexable
-- container.
delta :: (Num (IndexVal a), Indexable a) => a -> Int -> IndexVal a
delta tbl i = (tbl ! i) - (tbl ! (i - 1))
{-# INLINE delta #-}

-- Index of the left node of the mesh interval used to interpolate at @x@.
-- The result of 'meshFindIndex' is clamped to [0, size mesh - 2], so
-- points outside the mesh use the first or last interval.
safeFindIndex :: Mesh a => a -> Double -> Int
safeFindIndex mesh x =
  case meshFindIndex mesh x of
    i | i < 0     -> 0
      | i > n     -> n
      | otherwise -> i
  where
    n = size mesh - 2
{-# INLINE safeFindIndex #-}