{-# OPTIONS_HADDOCK show-extensions #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE DataKinds #-}

{-|
Module      : Numeric.Neural.Normalization
Description : normalizing data
Copyright   : (c) Lars Brünjes, 2016
License     : MIT
Maintainer  : brunjlar@gmail.com
Stability   : experimental
Portability : portable

This module provides utilities for data normalization.
-}

module Numeric.Neural.Normalization
    ( encode1ofN
    , decode1ofN
    , encodeEquiDist
    , decodeEquiDist
    , crossEntropyError
    , white
    , whiten
    ) where

import Control.Arrow
import Data.Proxy
import GHC.TypeLits
import GHC.TypeLits.Witnesses
import Data.MyPrelude
import Data.Utils.Analytic
import Data.Utils.Statistics
import Data.Utils.Traversable
import Data.Utils.Vector
import Numeric.Neural.Model

-- | Provides "1 of @n@" encoding for enumerable types.
--
-- >>> :set -XDataKinds
-- >>> encode1ofN LT :: Vector 3 Int
-- [1,0,0]
--
-- >>> encode1ofN EQ :: Vector 3 Int
-- [0,1,0]
--
-- >>> encode1ofN GT :: Vector 3 Int
-- [0,0,1]
--
encode1ofN :: (Enum a, Num b, KnownNat n) => a -> Vector n b
encode1ofN x = generate $ \i -> if i == fromEnum x then 1 else 0

-- | Provides "1 of @n@" decoding for enumerable types.
--
-- >>> decode1ofN [0.9, 0.3, 0.1 :: Double] :: Ordering
-- LT
--
-- >>> decode1ofN [0.7, 0.8, 0.6 :: Double] :: Ordering
-- EQ
--
-- >>> decode1ofN [0.2, 0.3, 0.8 :: Double] :: Ordering
-- GT
--
decode1ofN :: (Enum a, Num b, Ord b, Foldable f) => f b -> a
decode1ofN = toEnum . fst . maximumBy (compare `on` snd) . zip [0..] . toList

-- Computes the @n@ vertices of a regular simplex in @n - 1@ dimensions:
-- all vertices lie on the unit sphere, their centroid is the origin, and all
-- pairwise distances are equal.
polyhedron :: Floating a => Int -> [[a]]
polyhedron = fst . p

  where

    -- @p n@ returns the vertices together with their common pairwise distance.
    -- The step case adds a new vertex "above" the simplex for @n - 1@, shifts
    -- the centroid back to the origin and rescales onto the unit sphere.
    p 2 = ([[-1], [1]], 2)
    p n = let (xs, d) = p (n - 1)
              y       = sqrt (d * d - 1)
              v       = y : replicate (n - 2) 0
              xs'     = v : ((0 :) <$> xs)
              shift   = y / fromIntegral n
              shifted = (\(z : zs) -> ((z - shift) : zs)) <$> xs'
              scale   = 1 / (y - shift)
              scaled  = ((scale *) <$>) <$> shifted
          in  (scaled, d * scale)

-- Like 'polyhedron', but takes the dimension as a type-level natural @n@;
-- the resulting simplex has @n + 1@ vertices in @n@ dimensions.
polyhedron' :: forall a n. (Floating a, KnownNat n) => Proxy n -> [[a]]
polyhedron' p = withNatOp (%+) p (Proxy :: Proxy 1) $
    polyhedron (fromIntegral $ natVal (Proxy :: Proxy (n + 1)))
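
-- For illustration (values rounded): @polyhedron 3@ yields the vertices of an
-- equilateral triangle inscribed in the unit circle,
--
--     [[1.0, 0.0], [-0.5, -0.866], [-0.5, 0.866]]
--
-- which is exactly the encoding that 'encodeEquiDist' uses for the three
-- constructors of 'Ordering' (see the doctests below).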

-- | Provides equidistant encoding for enumerable types: the values are mapped
--   to the vertices of a regular simplex inscribed in the unit sphere of
--   @n@-dimensional space, so all encodings have norm one and equal pairwise
--   distances.
--
-- >>> :set -XDataKinds
-- >>> encodeEquiDist LT :: Vector 2 Float
-- [1.0,0.0]
--
-- >>> encodeEquiDist EQ :: Vector 2 Float
-- [-0.5,-0.86602545]
--
-- >>> encodeEquiDist GT :: Vector 2 Float
-- [-0.5,0.86602545]
--
encodeEquiDist :: forall a b n. (Enum a, Floating b, KnownNat n) => a -> Vector n b
encodeEquiDist x = let ys = polyhedron' (Proxy :: Proxy n)
                       y  = ys !! fromEnum x
                   in  fromJust (fromList y)

-- | Provides equidistant decoding for enumerable types: the result is the value
--   whose 'encodeEquiDist' encoding is closest (in squared Euclidean distance)
--   to the given vector.
--
-- >>> :set -XDataKinds
-- >>> let u = fromJust (fromList [0.9, 0.2]) :: Vector 2 Double
-- >>> decodeEquiDist u :: Ordering
-- LT
--
-- >>> :set -XDataKinds
-- >>> let v = fromJust (fromList [-0.4, -0.5]) :: Vector 2 Double
-- >>> decodeEquiDist v :: Ordering
-- EQ
--
-- >>> :set -XDataKinds
-- >>> let w = fromJust (fromList [0.1, 0.8]) :: Vector 2 Double
-- >>> decodeEquiDist w :: Ordering
-- GT
--
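-- For the second example, the squared distances from @v@ to the code vectors
-- @[1, 0]@, @[-0.5, -0.866]@ and @[-0.5, 0.866]@ are roughly @2.21@, @0.14@
-- and @1.88@, so the second code vector (and hence 'EQ') is closest.
--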
decodeEquiDist :: forall a b n. (Enum a, Ord b, Floating b, KnownNat n) => Vector n b -> a
decodeEquiDist y = let xs  = polyhedron' (Proxy :: Proxy n)
                       xs' = (fromJust . fromList) <$> xs
                       ds  = [(j, sqDiff x y) | (j, x) <- zip [0..] xs']
                       i   = fst $ minimumBy (compare `on` snd) ds
                   in  toEnum i

-- | Computes the cross entropy error (assuming "1 of n" encoding).
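-- Because @'encode1ofN' a@ is a standard basis vector, the scalar product
-- simply picks the component of @ys@ at index @fromEnum a@, so the result is
-- the negated logarithm of that component; in the first example below this is
-- @negate (log 0.8)@, roughly @0.2231@.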
--
-- >>> crossEntropyError LT (cons 0.8 (cons 0.1 (cons 0.1 nil))) :: Float
-- 0.22314353
--
-- >>> crossEntropyError EQ (cons 0.8 (cons 0.1 (cons 0.1 nil))) :: Float
-- 2.3025851
--
crossEntropyError :: (Enum a, Floating b, KnownNat n) => a -> Vector n b -> b
crossEntropyError a ys = negate $ log $ encode1ofN a <%> ys

-- | Function 'white' takes a batch of values (of a specific shape)
--   and computes a normalization function which whitens values of that shape,
--   so that each component has zero mean and unit variance. Components that
--   are constant over the batch (zero variance) are only centred.
--
-- >>> :set -XDataKinds
-- >>> let xss = [cons 1 (cons 1 nil), cons 1 (cons 2 nil), cons 1 (cons 3 nil)] :: [Vector 2 Float]
-- >>> let f   = white xss
-- >>> f <$> xss
-- [[0.0,-1.224745],[0.0,0.0],[0.0,1.224745]]
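--
-- (The first component is constant over the batch, so it is only centred; the
-- second has mean 2 and standard deviation @sqrt (2 / 3)@, roughly @0.8165@,
-- whence @(1 - 2) / 0.8165@ is roughly @-1.2247@.)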
white :: (Applicative f, Traversable t, Eq a, Floating a) => t (f a) -> f a -> f a
white xss = ((w <$> sequenceA xss) <*>) where

    -- 'sequenceA' transposes the batch, so 'w' receives all observed values of
    -- one individual component and computes the affine map @\x -> (x - m) * s@,
    -- where @m@ is the mean and @s@ is @1 / sqrt v@ for variance @v@
    -- (or 1 if the variance is zero).
    w xs = case toList xs of
        []  -> id
        xs' -> let (_, m, v) = countMeanVar xs'
                   s         = if v == 0 then 1 else 1 / sqrt v
               in  \x -> (x - m) * s

-- | Modifies a 'Model' by whitening the input before feeding it into the embedded component.
--
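-- A hypothetical usage sketch (neither the model @m@ nor the batch @xs@ of
-- training inputs is defined in this library):
--
-- @
-- m' = whiten m xs -- like m, but inputs are whitened with the statistics of xs
-- @
--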
whiten :: (Applicative f, Traversable t)
          => Model f g a b c             -- ^ original model
          -> t b                         -- ^ batch of input data
          -> Model f g a b c             -- ^ model with whitened input
whiten (Model c e i o) xss = Model c' e i o where

    -- precompose the component with the whitening function computed from the batch
    c' = white xss' ^>> c

    -- the batch elements, converted by @i@ and lifted with 'fromDouble'
    xss' = (fmap fromDouble . i) <$> xss