-- |
-- Module:     Control.Wire.Prefab.Analyze
-- Copyright:  (c) 2011 Ertugrul Soeylemez
-- License:    BSD3
-- Maintainer: Ertugrul Soeylemez <es@ertes.de>
--
-- Various signal analysis tools

module Control.Wire.Prefab.Analyze
    ( -- * Statistics
      -- ** Average
      avg,
      avgAll,
      avgFps,
      -- ** Peak
      highPeak,
      lowPeak,
      peakBy,

      -- * Monitoring
      collect,
      diff,
      firstSeen,
      lastSeen
    )
    where

import qualified Data.Map as M
import qualified Data.Set as S
import qualified Data.Vector.Unboxed as Vu
import qualified Data.Vector.Unboxed.Mutable as Vum
import Control.Arrow
import Control.Monad.Fix
import Control.Monad.ST
import Control.Wire.Classes
import Control.Wire.Prefab.Clock
import Control.Wire.Types
import Data.Map (Map)
import Data.Monoid
import Data.Set (Set)


-- | Calculate the average of the signal over the given number of last
-- samples.  If you need an average over all samples ever produced,
-- consider using 'avgAll' instead.
--
-- At the first instant the whole window is filled with the first
-- sample, so the wire outputs its input unchanged there.
--
-- The number of samples is clamped to a minimum of 1.  A zero or
-- negative argument therefore yields a window of a single sample
-- instead of failing with an out-of-bounds index at the second instant.
--
-- * Complexity: O(n) space, O(1) time wrt number of samples.
--
-- * Depends: current instant.

avg :: forall e v (>~). (Arrow (>~), Fractional v, Vu.Unbox v) => Int -> Wire e (>~) v v
avg requested = mkPure $ \x -> (Right x, avg' (Vu.replicate n (x/d)) x 0)
    where
    -- Effective window size.  With fewer than one slot the ring buffer
    -- would be empty and both the index lookup and the division by 'd'
    -- below would be meaningless, hence the clamp.
    n :: Int
    n = max 1 requested

    -- State of the running wire: the ring buffer of the last 'n'
    -- samples (each one already divided by 'd'), the sum of the buffer
    -- (which is the current average) and the index of the slot written
    -- most recently.
    avg' :: Vu.Vector v -> v -> Int -> Wire e (>~) v v
    avg' samples' s' cur' =
        mkPure $ \((/d) -> x) ->
            let cur = let next = succ cur' in if next >= n then 0 else next
                -- Value about to be evicted from the window.
                x' = samples' Vu.! cur
                -- The buffer is overwritten in place.  Forcing x' first
                -- makes sure the old value is read before the slot is
                -- replaced.
                -- NOTE(review): the in-place update presumably relies
                -- on every wire state being stepped at most once --
                -- confirm against callers.
                samples =
                    x' `seq` runST $ do
                        s <- Vu.unsafeThaw samples'
                        Vum.write s cur x
                        Vu.unsafeFreeze s
                -- Slide the window: drop the evicted share, add the new.
                s = s' - x' + x
            in cur `seq` s' `seq` (Right s, avg' samples s cur)

    -- Window size as a value of the sample type.
    d :: v
    d = realToFrac n


-- | Running average of every sample received so far.
--
-- Only the sample count and the current mean are carried from one
-- instant to the next, so this wire runs in constant space and is
-- generally faster than 'avg'.  Most applications will nevertheless
-- benefit from averages over only the last few samples.
--
-- * Depends: current instant.

avgAll :: forall e v (>~). (Arrow (>~), Fractional v) => Wire e (>~) v v
avgAll = mkPure start
    where
    -- First instant: the mean of a single sample is the sample itself.
    start x0 = (Right x0, running 1 x0)

    -- Wire carrying the number of samples seen and their mean.
    running :: v -> v -> Wire e (>~) v v
    running count mean = mkPure step
        where
        -- Fold the new sample into the mean incrementally.  The
        -- previous mean is forced first so that unevaluated updates do
        -- not pile up.
        step x = mean `seq` (Right mean', running count' mean')
            where
            count' = count + 1
            mean'  = mean - mean/count' + x/count'


-- | Average number of frames per virtual second over the last given
-- number of frames: the reciprocal of the averaged 'dtime' signal.
--
-- Please note that this wire uses the clock from the 'ArrowClock'
-- instance for the underlying arrow.  If this clock doesn't represent
-- real time, then the output of this wire won't either.

avgFps ::
    (ArrowChoice (>~), ArrowClock (>~), Fractional t, Time (>~) ~ t, Vu.Unbox t)
    => Int
    -> Wire e (>~) a t
avgFps n = (dtime >>> avg n) >>^ recip


-- | Outputs the set of all distinct inputs received so far, including
-- the input of the current instant.
--
-- * Complexity: O(n) space, O(log n) time wrt collected inputs so far.
--
-- * Depends: current instant.

collect :: forall b e (>~). Ord b => Wire e (>~) b (Set b)
collect = gather S.empty
    where
    -- Wire carrying the inputs collected up to the previous instant.
    gather :: Set b -> Wire e (>~) b (Set b)
    gather seen = mkPure step
        where
        step x = (Right seen', gather seen')
            where
            seen' = S.insert x seen


-- | Detects changes of the input signal.  Whenever the input differs
-- from the input of the previous instant, the previous input value is
-- output.  Acts like the identity wire at the first instant.
--
-- * Depends: current instant.
--
-- * Inhibits: on no change after the first instant.

diff :: forall b e (>~). (Eq b, Monoid e) => Wire e (>~) b b
diff = mkPure $ \x0 -> (Right x0, watch x0)
    where
    -- Wire remembering the most recent input.
    watch :: b -> Wire e (>~) b b
    watch prev = mkPure step
        where
        step cur
            | prev == cur = (Left mempty, watch prev)
            | otherwise   = (Right prev, watch cur)


-- | Outputs the time at which the current input value was seen for the
-- first time.  For a value not seen before, that is the current time.
--
-- * Complexity: O(n) space, O(log n) time wrt collected inputs so far.
--
-- * Depends: Current instant.

firstSeen ::
    forall a e t (>~). (ArrowChoice (>~), ArrowClock (>~), Monoid e, Ord a, Time (>~) ~ t)
    => Wire e (>~) a t
firstSeen = track M.empty
    where
    -- Wire carrying the first-seen time of every input so far.
    track :: Map a t -> Wire e (>~) a t
    track known =
        fix $ \same ->
            mkGen $ proc key ->
                case M.lookup key known of
                  -- Known input: the map is unchanged, so the very
                  -- same wire is reused for the next instant.
                  Just seenAt -> returnA -< (Right seenAt, same)
                  -- New input: record the current time for it.
                  Nothing     -> do
                      now <- arrTime -< ()
                      returnA -< (Right now, track (M.insert key now known))


-- | Outputs the largest input value received so far, as determined by
-- 'compare'.
--
-- * Depends: Current instant.

highPeak :: Ord b => Wire e (>~) b b
highPeak = peakBy $ \candidate best -> compare candidate best


-- | Outputs the time at which the current input value was seen the
-- previous time, then records the current time for it.  Inhibits when
-- the value has not been seen before.
--
-- * Complexity: O(n) space, O(log n) time wrt collected inputs so far.
--
-- * Depends: Current instant.
--
-- * Inhibits: On first sight of a signal.

lastSeen ::
    forall a e t (>~). (ArrowClock (>~), Monoid e, Ord a, Time (>~) ~ t)
    => Wire e (>~) a t
lastSeen = remember M.empty
    where
    -- Wire carrying the most recent sighting time of every input.
    remember :: Map a t -> Wire e (>~) a t
    remember seen =
        mkGen $ proc key -> do
            now <- arrTime -< ()
            -- The answer comes from the map as it was before this
            -- instant; the current sighting only affects later ones.
            let result = case M.lookup key seen of
                           Just before -> Right before
                           Nothing     -> Left mempty
                seen' = M.insert key now seen
            returnA -< (result, remember seen')


-- | Outputs the smallest input value received so far, as determined by
-- 'compare' with its arguments swapped.
--
-- * Depends: Current instant.

lowPeak :: Ord b => Wire e (>~) b b
lowPeak = peakBy $ \candidate best -> compare best candidate


-- | Outputs the high peak of the input signal with respect to the given
-- comparison function.  The comparison receives the new input first and
-- the current peak second; the input replaces the peak only when the
-- result is 'GT'.  The first input is always the initial peak.
--
-- * Depends: Current instant.

peakBy :: forall b e (>~). (b -> b -> Ordering) -> Wire e (>~) b b
peakBy comp = mkPure $ \x0 -> (Right x0, track x0)
    where
    -- Wire carrying the peak found so far.
    track :: b -> Wire e (>~) b b
    track best =
        mkPure $ \x ->
            let best' = case comp x best of
                          GT -> x
                          _  -> best
            in (Right best', track best')