{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE DeriveGeneric #-}


-- | Provides the `adam` function which implements the Adam algorithm based on
-- the paper:
--
--     * Diederik P. Kingma, Jimmy Ba. /Adam: A Method for Stochastic
--       Optimization/. <https://arxiv.org/pdf/1412.6980>


module Numeric.SGD.Adam
  ( Config(..)
  , adam
  ) where


import           GHC.Generics (Generic)

import           Prelude hiding (div)

import           Data.Default

import qualified Pipes as P

import           Numeric.SGD.Type
import           Numeric.SGD.ParamSet


-- | Adam configuration
data Config = Config
  { alpha :: Double
    -- ^ Step size
  , beta1 :: Double
    -- ^ 1st exponential moment decay
  , beta2 :: Double
    -- ^ 2nd exponential moment decay
  , eps :: Double
    -- ^ Epsilon, added to the denominator for numerical stability
  } deriving (Show, Eq, Ord, Generic)

instance Default Config where
  def = Config
    { alpha = 0.001
    , beta1 = 0.9
    , beta2 = 0.999
    , eps = 1.0e-8
    }


-- | Perform gradient descent using the Adam algorithm.  The update rule
-- follows the paper referenced in the module header:
--
-- > m_t   = beta1 * m_{t-1} + (1 - beta1) * g_t
-- > v_t   = beta2 * v_{t-1} + (1 - beta2) * g_t^2
-- > mb_t  = m_t / (1 - beta1^t)         -- bias-corrected 1st moment
-- > vb_t  = v_t / (1 - beta2^t)         -- bias-corrected 2nd moment
-- > net_t = net_{t-1} - alpha * mb_t / (sqrt vb_t + eps)
adam
  :: (Monad m, ParamSet p)
  => Config
    -- ^ Adam configuration
  -> (e -> p -> p)
    -- ^ Gradient on a training element
  -> SGD m e p
adam Config{..} gradient net0 =
  let zr = zero net0
   in go (1 :: Integer) zr zr net0
  where
    go t m v net = do
      x <- P.await
      let g = gradient x net
          -- exponential moving averages of the gradient and its square
          m' = pmap (*beta1) m `add` pmap (*(1-beta1)) g
          v' = pmap (*beta2) v `add` pmap (*(1-beta2)) (g `mul` g)
          -- bias-corrected moment estimates
          mb = pmap (/(1-beta1^t)) m'
          vb = pmap (/(1-beta2^t)) v'
          newNet = net `sub`
            ( pmap (*alpha) mb
                `div`
              pmap (+eps) (pmap sqrt vb)
            )
      P.yield newNet
      go (t+1) m' v' newNet


-------------------------------
-- Utils
-------------------------------


-- -- | Scaling
-- scale :: ParamSet p => Double -> p -> p
-- scale x = pmap (*x)
-- {-# INLINE scale #-}
--
--
-- -- | Square root
-- squareRoot :: ParamSet p => p -> p
-- squareRoot = pmap sqrt
-- {-# INLINE squareRoot #-}
--
--
-- -- | Square
-- square :: ParamSet p => p -> p
-- square x = x `mul` x
-- {-# INLINE square #-}
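
-------------------------------
-- Usage sketch
-------------------------------


-- Below is a minimal, illustrative way to drive 'adam' with the pipes
-- machinery it is built on.  It is a sketch, not part of the library: it
-- assumes a 'ParamSet' instance for 'Double' (a toy 1-dimensional
-- parameter set) is in scope; if the surrounding library does not provide
-- one, it has to be defined first.
--
-- @
-- import Pipes (each, (>->))
-- import qualified Pipes.Prelude as Pipes
--
-- -- Gradient of the toy objective f(x) = (x - 4)^2; the training
-- -- element is ignored in this contrived setting.
-- grad :: () -> Double -> Double
-- grad _ x = 2 * (x - 4)
--
-- -- Perform 1000 Adam steps (one per dummy training element) and
-- -- print the final parameter, which should be close to 4.
-- main :: IO ()
-- main = do
--   r <- Pipes.last (each (replicate 1000 ()) >-> adam (def {alpha = 0.1}) grad 0)
--   print r
-- @
--
-- Overriding 'alpha' via record update shows how 'def' is meant to be
-- customised; with the default step size of 0.001 the toy example would
-- need far more iterations to approach the optimum.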