{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE CPP #-}
{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
module Online.AveragesB
( foldB,
foldB',
maB,
absmaB,
sqmaB,
stdB,
std'
) where
import Prelude
import Numeric.Backprop as B
import qualified Prelude.Backprop as PB
import Control.Foldl (Fold(..))
foldB' :: (Backprop a, Backprop b, Reifies s W, Fractional b) => (BVar s a -> BVar s b) -> BVar s b -> BVar s [a] -> BVar s b
foldB' f r xs = divide (PB.foldl' (step' f r) (T2 0 0) xs) where
step' f' r' (T2 s c) a = uncurry T2 ((r'*) $ s + f' a, (r'*) $ c + 1)
divide (T2 s c) = s / c
online :: (Reifies s W, Fractional b) => (BVar s a -> BVar s b) -> (BVar s b -> BVar s b) -> Fold (BVar s a) (BVar s b)
online f g = Fold step begin extract
where
begin = (0, 0)
step (s, c) a = (g $ s + f a, g $ c + 1)
extract (s, c) = s / c
ma' :: (Reifies s W, Fractional b) => BVar s b -> Fold (BVar s b) (BVar s b)
ma' r = online id (*r)
sqma' :: (Reifies s W, Fractional b) => BVar s b -> Fold (BVar s b) (BVar s b)
sqma' r = online (\x -> x * x) (*r)
std' :: (Reifies s W, Floating b) => BVar s b -> Fold (BVar s b) (BVar s b)
std' r = (\s ss -> sqrt (ss - s ** 2)) <$> ma' r <*> sqma' r
foldB :: (Reifies s W) => (BVar s Double -> BVar s Double) -> BVar s Double -> BVar s [Double] -> BVar s Double
foldB f r xs = divide (PB.foldl' (step' f r) (T2 0 0) xs) where
step' f' r' (T2 s c) a = uncurry T2 ((r*) $ s + f' a, (r'*) $ c + 1)
divide (T2 s c) = s / c
stdB :: Reifies s W => BVar s Double -> BVar s [Double] -> BVar s Double
stdB r xs = (\ss s -> sqrt (ss - s ** 2)) (foldB id r xs) (foldB (\x -> x * x) r xs)
maB :: Reifies s W => BVar s Double -> BVar s [Double] -> BVar s Double
maB r = foldB id r
absmaB :: Reifies s W => BVar s Double -> BVar s [Double] -> BVar s Double
absmaB r = foldB abs r
sqmaB :: Reifies s W => BVar s Double -> BVar s [Double] -> BVar s Double
sqmaB r = foldB (\x -> x * x) r