{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
module Numeric.EMD (
emd
, emdTrace
, emd'
, EMD(..)
, EMDOpts(..), defaultEO, SiftCondition(..), defaultSC, SplineEnd(..)
, sift, SiftResult(..)
, envelopes
) where
import Control.Monad.IO.Class
import Data.Finite
import Data.Functor.Identity
import GHC.TypeNats
import Numeric.EMD.Internal.Extrema
import Numeric.EMD.Internal.Spline
import Text.Printf
import qualified Data.Map as M
import qualified Data.Vector.Generic as VG
import qualified Data.Vector.Generic.Sized as SVG
-- | Options that control how an empirical mode decomposition is run.
data EMDOpts a = EO
    { eoSiftCondition :: SiftCondition a
      -- ^ criterion deciding when a candidate IMF has been sifted enough
    , eoSplineEnd     :: SplineEnd
      -- ^ end condition used when building the envelope splines
    , eoClampEnvelope :: Bool
      -- ^ when 'True', the signal's first and last samples are added as
      -- knots to both the lower and the upper envelope
    }
  deriving (Show, Eq, Ord)
-- | Default decomposition options: 'defaultSC' as the sift condition,
-- 'SENotAKnot' spline ends, and envelope clamping switched on.
defaultEO :: Fractional a => EMDOpts a
defaultEO = EO defaultSC SENotAKnot True
-- | Stopping criterion for the sifting loop. Conditions can be combined
-- with 'SCOr' and 'SCAnd'.
data SiftCondition a
    = SCStdDev a
      -- ^ stop once the sum over samples of
      -- @(old - new)^2 / (old^2 + eps)@ between successive sifts is at
      -- most the given threshold
    | SCTimes Int
      -- ^ stop once the (1-based) iteration count reaches the given number
    | SCOr (SiftCondition a) (SiftCondition a)
      -- ^ stop when either sub-condition holds
    | SCAnd (SiftCondition a) (SiftCondition a)
      -- ^ stop only when both sub-conditions hold
  deriving (Show, Eq, Ord)
-- | Default sift condition: the 'SCStdDev' criterion with a threshold of
-- 0.3.
defaultSC :: Fractional a => SiftCondition a
defaultSC = SCStdDev threshold
  where
    threshold = 0.3
-- | Decide whether sifting should stop.
--
-- Arguments are:
--
-- * the condition to evaluate;
-- * the current iteration count (the caller starts counting at 1);
-- * the signal before this sifting step;
-- * the signal after it.
testCondition
    :: (VG.Vector v a, Fractional a, Ord a)
    => SiftCondition a
    -> Int
    -> SVG.Vector v n a
    -> SVG.Vector v n a
    -> Bool
testCondition cond iter old new = case cond of
    SCStdDev t -> score <= t
    SCTimes l  -> iter >= l
    SCOr f g   -> testCondition f iter old new || testCondition g iter old new
    SCAnd f g  -> testCondition f iter old new && testCondition g iter old new
  where
    -- Sum over all samples of (old - new)^2 / (old^2 + eps). The eps term
    -- keeps the division finite where the old sample is exactly zero.
    score = SVG.sum (SVG.zipWith term old new)
    term x x' = (x - x') ^ (2 :: Int) / (x ^ (2 :: Int) + eps)
    eps = 0.0000001
-- | Result of an empirical mode decomposition.
data EMD v n a = EMD
    { emdIMFs     :: ![SVG.Vector v n a]
      -- ^ the extracted IMFs, in the order they were found
    , emdResidual :: !(SVG.Vector v n a)
      -- ^ what is left of the signal once no further IMF can be extracted
    }
  deriving Show
-- | Decompose a signal into IMFs and a residual.
--
-- This is 'emd'' run purely, in 'Identity', with a callback that ignores
-- every intermediate sifting result.
emd :: (VG.Vector v a, KnownNat n, Fractional a, Ord a)
    => EMDOpts a
    -> SVG.Vector v (n + 1) a
    -> EMD v (n + 1) a
emd eo xs = runIdentity (emd' ignore eo xs)
  where
    ignore _ = Identity ()
-- | Like 'emd', but runs in any 'MonadIO' and writes a line to stdout for
-- every sifting result.
--
-- A residual prints "Residual found."; each IMF prints how many sifting
-- iterations it took.
emdTrace
    :: (VG.Vector v a, KnownNat n, Fractional a, Ord a, MonadIO m)
    => EMDOpts a
    -> SVG.Vector v (n + 1) a
    -> m (EMD v (n + 1) a)
emdTrace = emd' report
  where
    report (SRResidual _) = liftIO $ putStrLn "Residual found."
    report (SRIMF _ i)    = liftIO $ printf "IMF found (%d iterations)\n" i
-- | Empirical mode decomposition with a callback.
--
-- The signal is sifted repeatedly. Every 'SiftResult' is first passed to
-- the callback, whose effect is sequenced with '*>' and whose result is
-- discarded. When the result is an IMF, it is recorded and subtracted
-- from the signal, and sifting continues on the remainder. When it is a
-- residual, the decomposition ends.
emd'
    :: (VG.Vector v a, KnownNat n, Fractional a, Ord a, Applicative m)
    => (SiftResult v (n + 1) a -> m r)
    -> EMDOpts a
    -> SVG.Vector v (n + 1) a
    -> m (EMD v (n + 1) a)
emd' cb eo = loop id
  where
    -- 'found' is a difference list, so IMFs come out in extraction order.
    loop !found !signal =
      let outcome = sift eo signal
          next    = case outcome of
            SRResidual r -> pure (EMD (found []) r)
            SRIMF imf _  -> loop (found . (imf :)) (signal - imf)
      in  cb outcome *> next
-- | Outcome of running 'sift' on a signal.
data SiftResult v n a
    = SRResidual !(SVG.Vector v n a)
      -- ^ the envelopes could not be built, so the input signal is
      -- returned unchanged as the residual
    | SRIMF !(SVG.Vector v n a) !Int
      -- ^ an extracted IMF, with the number of sifting iterations it took
-- | Sift a signal until the configured 'SiftCondition' is met.
--
-- Returns the resulting IMF together with the iteration count, starting
-- at 1. If a sifting step fails because the envelopes cannot be built,
-- the current signal is returned as a residual instead.
sift
    :: (VG.Vector v a, KnownNat n, Fractional a, Ord a)
    => EMDOpts a
    -> SVG.Vector v (n + 1) a
    -> SiftResult v (n + 1) a
sift opts = loop 1
  where
    step   = sift' (eoSplineEnd opts) (eoClampEnvelope opts)
    stopAt = testCondition (eoSiftCondition opts)
    loop !k !x = maybe (SRResidual x) (continue k x) (step x)
    continue k x !x'
      | stopAt k x x' = SRIMF x' k
      | otherwise     = loop (k + 1) x'
-- | A single sifting step.
--
-- Subtracts the pointwise mean of the lower and upper envelopes from the
-- signal. Returns 'Nothing' if 'envelopes' cannot build them.
sift'
    :: (VG.Vector v a, KnownNat n, Fractional a, Ord a)
    => SplineEnd
    -> Bool
    -> SVG.Vector v (n + 1) a
    -> Maybe (SVG.Vector v (n + 1) a)
sift' se cl v = do
    (lower, upper) <- envelopes se cl v
    pure $ SVG.zipWith3 detrend v lower upper
  where
    detrend x lo hi = x - (lo + hi) / 2
-- | Compute the lower and upper envelopes of a signal.
--
-- Each envelope is a spline through the minima or maxima (respectively)
-- reported by 'extrema'. When the clamp flag is set, the signal's first
-- and last samples are also added as knots to both envelopes. Because
-- 'M.union' is left-biased, an existing extremum at either end keeps its
-- entry. Returns 'Nothing' if either spline cannot be built.
envelopes
    :: (VG.Vector v a, KnownNat n, Fractional a, Ord a)
    => SplineEnd
    -> Bool
    -> SVG.Vector v (n + 1) a
    -> Maybe (SVG.Vector v (n + 1) a, SVG.Vector v (n + 1) a)
envelopes se cl xs = do
    lower <- splineAgainst se (clamp mins)
    upper <- splineAgainst se (clamp maxs)
    pure (lower, upper)
  where
    (mins, maxs) = extrema xs
    endpoints    = M.fromList [(minBound, SVG.head xs), (maxBound, SVG.last xs)]
    clamp knots
      | cl        = knots `M.union` endpoints
      | otherwise = knots
-- | Fit a spline through knots keyed by sample index, then sample it at
-- every index of a length-@n@ vector.
--
-- Returns 'Nothing' when 'makeSpline' fails.
splineAgainst
    :: (VG.Vector v a, KnownNat n, Fractional a, Ord a)
    => SplineEnd
    -> M.Map (Finite n) a
    -> Maybe (SVG.Vector v n a)
splineAgainst se knots =
    sampleAll <$> makeSpline se (M.mapKeysMonotonic fromIntegral knots)
  where
    sampleAll spline = SVG.generate (\k -> sampleSpline spline (fromIntegral k))