{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE DeriveGeneric #-}
module Numeric.SGD.Momentum
( Config(..)
, scaleTau
, momentum
) where
import GHC.Generics (Generic)
import Data.Default
import qualified Pipes as P
import Numeric.SGD.Type
import Numeric.SGD.ParamSet
-- | Hyper-parameters of momentum-based stochastic gradient descent.
data Config = Config
  { alpha0 :: Double
    -- ^ Initial step size: the step size used at iteration 0.
  , tau :: Double
    -- ^ Step-size decay horizon: at iteration @k@ (counting from 0) the
    -- step size is @alpha0 * tau / (tau + k)@.
  , gamma :: Double
    -- ^ Momentum coefficient: the factor by which the previous update
    -- (velocity) is scaled before the new scaled gradient is added to it.
  } deriving (Generic, Show, Eq, Ord)
-- | Default hyper-parameters: initial step size @0.01@, decay horizon
-- @tau = 1000@ and momentum coefficient @gamma = 0.9@.
instance Default Config where
  def = Config {alpha0 = 0.01, tau = 1000, gamma = 0.9}
-- | Multiply the 'tau' field of a configuration by the given factor,
-- leaving 'alpha0' and 'gamma' untouched.
--
-- NOTE(review): a factor of @0@ yields @tau == 0@, for which the first
-- step size of 'momentum' is @0 / 0@ (NaN).
scaleTau :: Double -> Config -> Config
scaleTau coef Config{..} = Config {tau = coef * tau, ..}
-- | Stochastic gradient descent with momentum.
--
-- Given the initial parameters, the resulting pipe repeatedly 'P.await's a
-- training element, computes new parameters and 'P.yield's them.
-- The recursion is unconditional: the pipe never terminates on its own, and
-- the downstream consumer decides how many steps to take.
--
-- At iteration @k@ (counting from 0), with training element @e@:
--
-- > alpha_k = alpha0 * tau / (tau + k)
-- > v_k     = scale gamma v_{k-1} `add` scale alpha_k (gradient e p_{k-1})
-- > p_k     = p_{k-1} `sub` v_k
--
-- where the velocity @v@ starts as 'zero' of the initial parameters.
--
-- NOTE(review): with @tau == 0@ the first step size is @0 / 0@ (NaN), which
-- then propagates into every later parameter set.  Callers presumably keep
-- @tau > 0@; confirm.
momentum
  :: (Monad m, ParamSet p)
  => Config
     -- ^ Momentum hyper-parameters
  -> (e -> p -> p)
     -- ^ Gradient function: for a training element and the current
     -- parameters, the gradient to descend along
  -> SGD m e p
momentum Config{..} gradient params0 =
  loop 0 (zero params0) params0
  where
    -- Step size used at the given (0-based) iteration.
    stepSize :: Integer -> Double
    stepSize i = (alpha0 * tau) / (tau + fromIntegral i)

    -- One update per awaited element, threading the iteration counter, the
    -- velocity (accumulated update) and the current parameters.
    loop i velocity params =
      P.await >>= \example ->
        let step = scale (stepSize i) (gradient example params)
            velocity' = scale gamma velocity `add` step
            params' = params `sub` velocity'
         in P.yield params' >> loop (i + 1) velocity' params'
-- | Multiply every component of a parameter set by a constant factor.
scale :: ParamSet p => Double -> p -> p
scale factor = pmap (\v -> v * factor)
{-# INLINE scale #-}