{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE DeriveGeneric #-}
module Numeric.SGD.Adam
( Config(..)
, adam
) where
import GHC.Generics (Generic)
import Prelude hiding (div)
import Data.Default
import qualified Pipes as P
import Numeric.SGD.Type
import Numeric.SGD.ParamSet
-- | Hyper-parameters of the Adam optimiser, as consumed by 'adam'.
--
-- Fields are strict: they are plain numbers read on every update step,
-- so there is no benefit in storing them as thunks.
data Config = Config
  { alpha :: !Double
    -- ^ Step size; 'adam' multiplies every bias-corrected update by it
  , beta1 :: !Double
    -- ^ Decay rate of the running average of the gradient
    -- (first-moment estimate)
  , beta2 :: !Double
    -- ^ Decay rate of the running average of the squared gradient
    -- (second-moment estimate)
  , eps :: !Double
    -- ^ Constant added to the square root of the second-moment estimate
    -- before dividing, keeping the denominator away from zero
  } deriving (Show, Eq, Ord, Generic)
-- | Default Adam hyper-parameters.
--
-- These values (step size 0.001, decay rates 0.9 and 0.999, and
-- epsilon 1e-8) are the settings suggested in the paper that
-- introduced Adam (Kingma & Ba).
instance Default Config where
  def =
    Config
      { alpha = 1.0e-3 -- step size
      , beta1 = 0.9    -- decay of the gradient average
      , beta2 = 0.999  -- decay of the squared-gradient average
      , eps = 1.0e-8   -- offset added to the denominator
      }
-- | Adam stochastic optimisation.
--
-- Repeatedly 'P.await's a training element and, after each one,
-- 'P.yield's the updated parameter set. For the @t@-th element
-- (counting from 1) with gradient @g@, the update is:
--
-- > m_t  = beta1 * m_{t-1} + (1 - beta1) * g
-- > v_t  = beta2 * v_{t-1} + (1 - beta2) * (g * g)
-- > mb_t = m_t / (1 - beta1^t)
-- > vb_t = v_t / (1 - beta2^t)
-- > p_t  = p_{t-1} - (alpha * mb_t) / (sqrt vb_t + eps)
--
-- Both moment estimates start at @'zero' net0@, where @net0@ is the
-- initial parameter set. The arithmetic is carried out with the
-- 'ParamSet' operations 'pmap', 'add', 'sub', 'mul' and 'div'.
adam
  :: (Monad m, ParamSet p)
  => Config
  -- ^ Adam hyper-parameters
  -> (e -> p -> p)
  -- ^ Gradient function, called as @gradient x net@ for each awaited
  -- element @x@ at the current parameters @net@
  -> SGD m e p
adam Config{..} gradient net0 =
  let zr = zero net0
   in go (1 :: Integer) zr zr net0
  where
    go t m v net = do
      x <- P.await
      let g = gradient x net
          -- running averages of the gradient and the squared gradient
          m' = pmap (*beta1) m `add` pmap (*(1-beta1)) g
          v' = pmap (*beta2) v `add` pmap (*(1-beta2)) (g `mul` g)
          -- bias-corrected moment estimates (both averages start at zero)
          mb = pmap (/(1-beta1^t)) m'
          vb = pmap (/(1-beta2^t)) v'
          newNet = net `sub`
            ( pmap (*alpha) mb `div`
              (pmap (+eps) (pmap sqrt vb))
            )
          t' = t + 1
      -- Force the next loop state to WHNF before yielding. Otherwise, when
      -- the consumer does not inspect every yielded value, m, v, t and the
      -- parameters accumulate a chain of unevaluated thunks across
      -- iterations. NOTE: only WHNF is forced; how deep this reaches
      -- depends on the strictness of the ParamSet instance.
      m' `seq` v' `seq` t' `seq` newNet `seq` P.yield newNet
      go t' m' v' newNet