{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
module Numeric.EMD (
emd
, emdTrace
, emd'
, EMD(..)
, EMDOpts(..), defaultEO, BoundaryHandler(..), SiftCondition(..), defaultSC, SplineEnd(..)
, sift, SiftResult(..)
, envelopes
) where
import Control.Monad
import Control.Monad.IO.Class
import Data.Default.Class
import Data.Finite
import Data.Functor.Identity
import GHC.Generics (Generic)
import GHC.TypeNats
import Numeric.EMD.Internal.Extrema
import Numeric.EMD.Internal.Spline
import Text.Printf
import qualified Data.Binary as Bi
import qualified Data.Map as M
import qualified Data.Vector.Generic as VG
import qualified Data.Vector.Generic.Sized as SVG
-- | Options controlling a decomposition.
--
-- The fields are consumed by 'sift': the sift condition decides when a
-- sifting loop stops, the spline end is forwarded to the envelope splines,
-- and the boundary handler (if any) selects how extra extrema are added at
-- the ends of the signal.
data EMDOpts a = EO
    { eoSiftCondition   :: SiftCondition a
      -- ^ stop condition checked after every sifting step
    , eoSplineEnd       :: SplineEnd a
      -- ^ end condition handed to the spline construction
    , eoBoundaryHandler :: Maybe BoundaryHandler
      -- ^ how to extend the extrema at the signal boundaries;
      -- 'Nothing' adds no extra points
    }
  deriving (Show, Eq, Ord, Generic)
-- | Strategy for adding extra envelope control points at the two ends of
-- the signal (see 'extendExtrema').
data BoundaryHandler
    = BHClamp
      -- ^ use the first and last samples as extra points in both the
      -- minimum and the maximum envelope
    | BHSymmetric
      -- ^ mirror the outermost extremum around each end of the signal
  deriving (Show, Eq, Ord, Generic)
-- | Serialisation via the 'Generic'-derived default methods.
instance Bi.Binary BoundaryHandler
-- | Serialisation via the 'Generic'-derived default methods; requires the
-- element type to be serialisable as well.
instance Bi.Binary a => Bi.Binary (EMDOpts a)
-- | Default options: 'defaultSC' as the sift condition, 'SENatural' as the
-- spline end, and symmetric boundary handling.
defaultEO :: Fractional a => EMDOpts a
defaultEO = EO defaultSC SENatural (Just BHSymmetric)
-- | 'def' is 'defaultEO'.
instance Fractional a => Default (EMDOpts a) where
    def = defaultEO
-- | Stop conditions for the sifting loop; evaluated by 'testCondition'.
data SiftCondition a
    = SCStdDev !a
      -- ^ stop once the normalised squared difference between two
      -- successive sifting results is at most this threshold
    | SCTimes !Int
      -- ^ stop once at least this many sifting steps have been taken
    | SCOr (SiftCondition a) (SiftCondition a)
      -- ^ stop when either condition holds
    | SCAnd (SiftCondition a) (SiftCondition a)
      -- ^ stop only when both conditions hold
  deriving (Show, Eq, Ord, Generic)
-- | Serialisation via the 'Generic'-derived default methods.
instance Bi.Binary a => Bi.Binary (SiftCondition a)
-- | 'def' is 'defaultSC'.
instance Fractional a => Default (SiftCondition a) where
    def = defaultSC
-- | Default sift condition: stop when the 'SCStdDev' measure is at most
-- 0.3, or after 50 sifting steps, whichever happens first.
defaultSC :: Fractional a => SiftCondition a
defaultSC = SCOr (SCStdDev 0.3) (SCTimes 50)
-- | Decide whether sifting should stop.
--
-- Arguments, in order: the condition to evaluate, the number of the
-- current sifting step, the signal before the step, and the signal after
-- the step.  Returns 'True' when the condition is satisfied.
testCondition
    :: (VG.Vector v a, Fractional a, Ord a)
    => SiftCondition a
    -> Int
    -> SVG.Vector v n a
    -> SVG.Vector v n a
    -> Bool
testCondition cond count before after = satisfied cond
  where
    satisfied (SCStdDev threshold) = deviation <= threshold
    satisfied (SCTimes limit)      = count >= limit
    satisfied (SCOr lhs rhs)       = satisfied lhs || satisfied rhs
    satisfied (SCAnd lhs rhs)      = satisfied lhs && satisfied rhs
    -- Sum over all samples of the squared change relative to the squared
    -- previous value.
    deviation = SVG.sum (SVG.zipWith relativeSquare before after)
    relativeSquare x x' = (x-x')^(2::Int) / (x^(2::Int) + epsilon)
    -- Keeps the denominator away from zero where the previous sample is 0.
    epsilon = 0.0000001
-- | Result of a decomposition, as assembled by @emd'@: the components
-- extracted by successive calls to 'sift' (in the order they were found)
-- together with what was left over when sifting could no longer continue.
data EMD v n a = EMD
    { emdIMFs     :: ![SVG.Vector v n a]
      -- ^ extracted components, first found first
    , emdResidual :: !(SVG.Vector v n a)
      -- ^ remaining signal
    }
  deriving (Show, Generic, Eq, Ord)
-- | Serialises the components and the residual as unsized vectors.
-- Decoding fails (through the pattern-match failure of the decoder monad)
-- when any decoded vector does not have length @n@.
instance (VG.Vector v a, KnownNat n, Bi.Binary (v a)) => Bi.Binary (EMD v n a) where
    put (EMD imfs residual) = do
      Bi.put (map SVG.fromSized imfs)
      Bi.put (SVG.fromSized residual)
    get = do
      -- The components are read and size-checked before the residual.
      Just imfs     <- fmap (traverse SVG.toSized) Bi.get
      Just residual <- fmap SVG.toSized Bi.get
      pure (EMD imfs residual)
-- | Pure decomposition: runs @emd'@ in 'Identity' with a callback that
-- does nothing.
emd :: (VG.Vector v a, KnownNat n, Fractional a, Ord a)
    => EMDOpts a
    -> SVG.Vector v (n + 1) a
    -> EMD v (n + 1) a
emd opts signal = runIdentity (emd' ignore opts signal)
  where
    ignore _ = Identity ()
-- | Decomposition that prints a line to standard output for every sift
-- result: one line per extracted component (with the number of sifting
-- steps it took) and one when the residual is reached.
emdTrace
    :: (VG.Vector v a, KnownNat n, Fractional a, Ord a, MonadIO m)
    => EMDOpts a
    -> SVG.Vector v (n + 1) a
    -> m (EMD v (n + 1) a)
emdTrace = emd' report
  where
    report (SRResidual _)   = liftIO (putStrLn "Residual found.")
    report (SRIMF _ nSifts) = liftIO (printf "IMF found (%d sifts)\n" nSifts)
-- | Decomposition with a callback.
--
-- Repeatedly calls 'sift' on the current signal.  The callback is run on
-- every 'SiftResult' before it is acted upon.  An 'SRIMF' result is
-- appended to the list of components and subtracted from the signal; an
-- 'SRResidual' result ends the loop and becomes the residual.
emd'
    :: (VG.Vector v a, KnownNat n, Fractional a, Ord a, Applicative m)
    => (SiftResult v (n + 1) a -> m r)
    -> EMDOpts a
    -> SVG.Vector v (n + 1) a
    -> m (EMD v (n + 1) a)
emd' cb eo = loop id
  where
    -- @found@ is a difference list of the components extracted so far.
    loop !found !signal = cb outcome *> continue outcome
      where
        outcome = sift eo signal
        continue (SRResidual rest) = pure (EMD (found []) rest)
        continue (SRIMF imf _)     = loop (found . (imf :)) (signal - imf)
-- | Outcome of one call to 'sift'.
data SiftResult v n a
    = SRResidual !(SVG.Vector v n a)
      -- ^ no further sifting step was possible; carries the signal as it
      -- stood at that point
    | SRIMF !(SVG.Vector v n a) !Int
      -- ^ the sift condition was met; carries the sifted signal and the
      -- number of sifting steps taken
-- | Run the sifting loop on a signal.
--
-- Each iteration applies one sifting step.  If the step cannot be
-- performed, the current signal is returned as 'SRResidual'.  Otherwise
-- the sift condition is tested on the signal before and after the step;
-- when it holds the new signal is returned as 'SRIMF' together with the
-- step count (starting from 1), and when it does not the loop continues
-- from the new signal.
sift
    :: (VG.Vector v a, KnownNat n, Fractional a, Ord a)
    => EMDOpts a
    -> SVG.Vector v (n + 1) a
    -> SiftResult v (n + 1) a
sift (EO cond splineEnd boundary) = iter 1
  where
    iter !count !current =
      case sift' splineEnd boundary current of
        Just !next
          | testCondition cond count current next -> SRIMF next count
          | otherwise                             -> iter (count + 1) next
        Nothing -> SRResidual current
-- | A single sifting step: subtract the pointwise mean of the minimum and
-- maximum envelopes from the signal.  Yields 'Nothing' exactly when
-- 'envelopes' does.
sift'
    :: (VG.Vector v a, KnownNat n, Fractional a, Ord a)
    => SplineEnd a
    -> Maybe BoundaryHandler
    -> SVG.Vector v (n + 1) a
    -> Maybe (SVG.Vector v (n + 1) a)
sift' se bh v = fmap removeMean (envelopes se bh v)
  where
    removeMean (lower, upper) = SVG.zipWith3 centre v lower upper
    centre x lo hi = x - (lo + hi)/2
-- | Compute the minimum and maximum envelopes of a signal, returned as
-- @(minimum envelope, maximum envelope)@.
--
-- The local extrema of the signal are extended according to the boundary
-- handler and a spline is fitted through each set.  The result is
-- 'Nothing' when either spline cannot be built, and also, under
-- 'BHClamp', when there are fewer than two local minima or fewer than two
-- local maxima.
envelopes
    :: (VG.Vector v a, KnownNat n, Fractional a, Ord a)
    => SplineEnd a
    -> Maybe BoundaryHandler
    -> SVG.Vector v (n + 1) a
    -> Maybe (SVG.Vector v (n + 1) a, SVG.Vector v (n + 1) a)
envelopes se bh xs
    | clamping && not enoughExtrema = Nothing
    | otherwise = do
        lower <- splineAgainst se extraMins mins
        upper <- splineAgainst se extraMaxs maxs
        pure (lower, upper)
  where
    clamping      = bh == Just BHClamp
    enoughExtrema = M.size mins > 1 && M.size maxs > 1
    (mins, maxs)  = extrema xs
    -- Without a boundary handler no extra points are added.
    (extraMins, extraMaxs) =
      maybe mempty (\handler -> extendExtrema xs handler (mins, maxs)) bh
-- | Build the extra control points that pin the envelopes at the two ends
-- of the signal.
--
-- Takes the signal, the boundary handler and the pair
-- @(local minima, local maxima)@; returns the pair
-- @(extra minima, extra maxima)@.  The result maps are keyed by 'Int'
-- rather than 'Finite' because mirrored points lie outside the index range
-- of the signal.
--
-- * 'BHClamp' ignores the given extrema and adds the first and last
--   samples to both result maps.
--
-- * 'BHSymmetric' looks, at each end, at the outermost minimum and
--   maximum.  The one closer to the end is mirrored around that end, and
--   the end sample itself is used as the extremum of the opposite kind.
--   If the signal has no extrema at all, nothing is added.
extendExtrema
    :: forall v n a. (VG.Vector v a, KnownNat n)
    => SVG.Vector v (n + 1) a
    -> BoundaryHandler
    -> (M.Map (Finite (n + 1)) a, M.Map (Finite (n + 1)) a)
    -> (M.Map Int a, M.Map Int a)
extendExtrema xs = \case
    BHClamp     -> const (firstLast, firstLast)
    BHSymmetric -> \(mins, maxs) ->
      let addFirst = case (flippedMin, flippedMax) of
              (Nothing      , Nothing      ) -> mempty
              -- only minima exist: the first sample acts as a maximum
              (Just (_,mn)  , Nothing      ) -> (mn        , firstPoint)
              -- only maxima exist: the first sample acts as a minimum
              (Nothing      , Just (_,mx)  ) -> (firstPoint, mx        )
              (Just (mni,mn), Just (mxi,mx))
                -- a minimum comes first: mirror it, first sample is a maximum
                | mni < mxi                  -> (mn        , firstPoint)
                -- a maximum comes first: mirror it, first sample is a minimum
                | otherwise                  -> (firstPoint, mx        )
            where
              -- Mirror around index 0: position i maps to -i.
              flippedMin = flip fmap (M.lookupMin mins) $ \(minIx, minVal) ->
                (minIx, M.singleton (negate (fromIntegral minIx)) minVal)
              flippedMax = flip fmap (M.lookupMin maxs) $ \(maxIx, maxVal) ->
                (maxIx, M.singleton (negate (fromIntegral maxIx)) maxVal)
          addLast = case (flippedMin, flippedMax) of
              (Nothing      , Nothing      ) -> mempty
              -- only minima exist: the last sample acts as a maximum
              (Just (_,mn)  , Nothing      ) -> (mn        , lastPoint )
              -- only maxima exist: the last sample acts as a minimum
              (Nothing      , Just (_,mx)  ) -> (lastPoint , mx        )
              (Just (mni,mn), Just (mxi,mx))
                -- a minimum comes last: mirror it, last sample is a maximum
                | mni > mxi                  -> (mn        , lastPoint )
                -- a maximum comes last: mirror it, last sample is a minimum
                | otherwise                  -> (lastPoint , mx        )
            where
              -- Mirror around the last index (see 'extendSym').
              flippedMin = flip fmap (M.lookupMax mins) $ \(minIx, minVal) ->
                (minIx, M.singleton (extendSym (fromIntegral minIx)) minVal)
              flippedMax = flip fmap (M.lookupMax maxs) $ \(maxIx, maxVal) ->
                (maxIx, M.singleton (extendSym (fromIntegral maxIx)) maxVal)
      in  addFirst `mappend` addLast
  where
    -- Index of the final sample.  The signal has n + 1 samples, so its
    -- indices are the inhabitants of @Finite (n + 1)@ and the last one is
    -- n.  (Using @Finite n@ here would give n - 1, and would be an error
    -- for a one-sample signal.)
    lastIx      = fromIntegral $ maxBound @(Finite (n + 1))
    firstPoint  = M.singleton 0 (SVG.head xs)
    lastPoint   = M.singleton lastIx (SVG.last xs)
    firstLast   = firstPoint `mappend` lastPoint
    -- Mirror position i around the last index.
    extendSym i = 2 * lastIx - i
-- | Fit a spline through a set of extrema plus extra boundary points and
-- sample it at every index of the output vector.
--
-- Both key sets are converted to the spline's coordinate type with
-- 'fromIntegral'.  Where an extra point and an extremum share a position,
-- the extra point wins.  Returns 'Nothing' when 'makeSpline' does.
splineAgainst
    :: (VG.Vector v a, KnownNat n, Fractional a, Ord a)
    => SplineEnd a
    -> M.Map Int a
    -> M.Map (Finite n) a
    -> Maybe (SVG.Vector v n a)
splineAgainst se ext pts = resample <$> makeSpline se knots
  where
    knots = M.mapKeysMonotonic fromIntegral ext
              `mappend` M.mapKeysMonotonic fromIntegral pts
    resample spline = SVG.generate (\ix -> sampleSpline spline (fromIntegral ix))