{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE DeriveGeneric #-}


-- | Provides the `adaDelta` function which implements the AdaDelta algorithm
-- as described in the following paper:
--
--     * https://arxiv.org/pdf/1212.5701.pdf


module Numeric.SGD.AdaDelta
  ( Config(..)
  , adaDelta
  ) where


import           GHC.Generics (Generic)

import           Prelude hiding (div)
-- import           Control.Monad (when)

import           Data.Default

import qualified Pipes as P

import           Numeric.SGD.Type
import           Numeric.SGD.ParamSet
-- import           Numeric.SGD.Args


-- | Hyper-parameters of the AdaDelta optimiser.
data Config = Config
  { -- | Decay rate of the running averages: the previous average is weighted
    -- by @decay@ and the newest squared value by @1 - decay@.
    decay :: Double
  , -- | Constant added to a running average before its square root is taken.
    eps :: Double
  }
  deriving (Show, Eq, Ord, Generic)

-- | Default settings: a 'decay' of @0.9@ and an 'eps' of @1.0e-6@.
instance Default Config where
  def = Config { decay = 0.9, eps = 1.0e-6 }


-- | Perform gradient descent using the AdaDelta algorithm.
-- See "Numeric.SGD.AdaDelta" for more information.
--
-- Each training element awaited from upstream produces one update step,
-- after which the updated parameter set is yielded downstream.  The step is
-- the gradient rescaled by the ratio of two running root-mean-square
-- estimates: that of the previous steps over that of the gradients.  The
-- only tunables are the 'decay' and 'eps' fields of the 'Config'.
adaDelta
  :: (Monad m, ParamSet p)
  => Config
    -- ^ AdaDelta configuration
  -> (e -> p -> p)
    -- ^ Gradient on a training element
  -> SGD m e p
adaDelta Config{..} gradient net0 =

  -- Both running averages and the "previous step" start out as `zero net0`.
  let zr = zero net0
   in go zr zr zr net0

  where

    -- Loop state, in argument order: running average of squared gradients,
    -- running average of squared steps, the step taken on the previous
    -- iteration, and the current parameters.
    go expSqGradPrev expSqDeltaPrev deltaPrev net = do
      x <- P.await
      let grad = gradient x net
          -- Running average of squared gradients, including the current one
          expSqGrad = scale decay expSqGradPrev
                `add` scale (1-decay) (square grad)
          rmsGrad = squareRoot (pmap (+eps) expSqGrad)
          -- Running average of squared steps: the previous step is folded
          -- in here, so this average does not yet cover the step computed
          -- below
          expSqDelta = scale decay expSqDeltaPrev
                 `add` scale (1-decay) (square deltaPrev)
          rmsDelta = squareRoot (pmap (+eps) expSqDelta)
          -- The step: gradient rescaled by the ratio of the two RMS values
          delta = (rmsDelta `mul` grad) `div` rmsGrad
          newNet = net `sub` delta
      P.yield newNet
      go expSqGrad expSqDelta delta newNet


-------------------------------
-- Utils
-------------------------------


-- | Multiply a parameter set by a constant factor, by way of 'pmap'.
scale :: ParamSet p => Double -> p -> p
scale factor = pmap (\v -> v * factor)
{-# INLINE scale #-}


-- | Square root: run 'sqrt' over a parameter set by way of 'pmap'.
squareRoot :: ParamSet p => p -> p
squareRoot params = pmap sqrt params
{-# INLINE squareRoot #-}


-- | Square a parameter set, i.e. combine it with itself using 'mul'.
square :: ParamSet p => p -> p
square params = mul params params
{-# INLINE square #-}