{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE DeriveGeneric #-}
module Numeric.SGD.AdaDelta
( Config(..)
, adaDelta
) where
import GHC.Generics (Generic)
import Prelude hiding (div)
import Data.Default
import qualified Pipes as P
import Numeric.SGD.Type
import Numeric.SGD.ParamSet
-- | Hyper-parameters of the AdaDelta method.
data Config = Config
  { decay :: Double
    -- ^ Decay rate of the running averages of squared gradients and of
    -- squared updates: each step keeps @decay@ times the old average and
    -- mixes in @1 - decay@ times the new squared value.
  , eps :: Double
    -- ^ Small constant added to the running averages before taking their
    -- square roots, keeping those roots (and the division by the
    -- gradient RMS) away from zero.
  }
  deriving (Eq, Ord, Show, Generic)
-- | Default hyper-parameters: @decay = 0.9@ and @eps = 1.0e-6@.
instance Default Config where
  def = Config { decay = defaultDecay, eps = defaultEps }
    where
      defaultDecay, defaultEps :: Double
      defaultDecay = 0.9
      defaultEps = 1.0e-6
-- | AdaDelta optimizer.
--
-- Maintains exponentially decaying averages of the squared gradients and
-- of the squared parameter updates. It scales each gradient by the ratio
-- of their (eps-smoothed) square roots, so no separate learning rate is
-- used.
--
-- Starting from the initial parameters, the resulting pipe repeatedly
-- 'P.await's an element, computes its gradient at the current parameters,
-- subtracts the AdaDelta update, and 'P.yield's the new parameters. The
-- loop has no exit of its own. It only stops when upstream stops
-- providing elements.
adaDelta
  :: (Monad m, ParamSet p)
  => Config
     -- ^ decay rate and smoothing constant
  -> (e -> p -> p)
     -- ^ gradient for a single element at the given parameters
  -> SGD m e p
adaDelta Config{..} gradient net0 =
  let zr = zero net0
   in go zr zr zr net0
  where
    -- Loop state:
    --   expSqGradPrev  - running average of squared gradients
    --   expSqDeltaPrev - running average of squared updates, not yet
    --                    including 'deltaPrev'
    --   deltaPrev      - update applied at the previous step
    --   net            - current parameters
    --
    -- Deliberately no iteration counter. An unused counter is never forced,
    -- so it would grow an unbounded chain of (+1) thunks over the lifetime
    -- of the pipe.
    go expSqGradPrev expSqDeltaPrev deltaPrev net = do
      x <- P.await
      let grad = gradient x net
          expSqGrad = scale decay expSqGradPrev
            `add` scale (1 - decay) (square grad)
          rmsGrad = squareRoot (pmap (+ eps) expSqGrad)
          -- The previous update is folded into its running average here,
          -- one step late. This is equivalent to updating the average at
          -- the end of the previous step, and it yields exactly the value
          -- needed for this step's numerator.
          expSqDelta = scale decay expSqDeltaPrev
            `add` scale (1 - decay) (square deltaPrev)
          rmsDelta = squareRoot (pmap (+ eps) expSqDelta)
          delta = (rmsDelta `mul` grad) `div` rmsGrad
          newNet = net `sub` delta
      P.yield newNet
      go expSqGrad expSqDelta delta newNet
-- | Scale a parameter set by a constant: each value is multiplied by
-- @factor@ through 'pmap'.
scale :: ParamSet p => Double -> p -> p
scale factor = pmap (\v -> v * factor)
{-# INLINE scale #-}
-- | Apply 'sqrt' to the values of a parameter set through 'pmap'.
squareRoot :: ParamSet p => p -> p
squareRoot params = pmap sqrt params
{-# INLINE squareRoot #-}
-- | Product of a parameter set with itself under 'mul'. This is presumably
-- an element-wise square; it depends on the 'ParamSet' instance.
square :: ParamSet p => p -> p
square params = mul params params
{-# INLINE square #-}