{-# LANGUAGE RecordWildCards #-}
module Numeric.SGD.Momentum
( SgdArgs (..)
, sgdArgsDefault
, Para
, sgd
, module Numeric.SGD.Grad
, module Numeric.SGD.Dataset
) where
import Control.Monad (forM_, when)
import qualified System.Random as R
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM
import qualified Control.Monad.Primitive as Prim
import Numeric.SGD.Grad
import Numeric.SGD.Dataset
-- | Parameters controlling a run of 'sgd'.
data SgdArgs = SgdArgs
{
-- | Number of dataset elements sampled for every update step.
batchSize :: Int
-- | Regularization variance; its inverse is the base of the
-- regularization coefficient computed inside 'sgd'.
, regVar :: Double
-- | Number of passes over the dataset after which 'sgd' stops.
-- Fractional values are allowed, since progress is measured as
-- (steps * batchSize) / (dataset size).
, iterNum :: Double
-- | Initial gain (step size).
, gain0 :: Double
-- | Number of passes over the dataset after which the gain drops to
-- half of 'gain0' (the gain is @gain0 * tau / (tau + passes)@).
, tau :: Double }
-- | Default 'SgdArgs': batches of 50 elements, 10 passes over the
-- dataset, regularization variance 10, and a gain starting at 0.25 that
-- is halved after 5 passes.
sgdArgsDefault :: SgdArgs
sgdArgsDefault =
  SgdArgs
    { -- Sampling and stopping criterion.
      batchSize = 50
    , iterNum = 10
      -- Regularization.
    , regVar = 10
      -- Gain schedule.
    , gain0 = 0.25
    , tau = 5
    }
-- | Momentum coefficient: the weight given to the previous momentum
-- vector when 'sgd' folds a new (scaled) gradient into it.
gamma :: Double
gamma = 0.9
-- | Vector of parameters: an immutable unboxed vector of 'Double's.
type Para = U.Vector Double
-- | Mutable unboxed vector of 'Double's living in 'IO'; used for the
-- parameters, gradient and momentum buffers that are updated in place.
type MVect = UM.MVector (Prim.PrimState IO) Double
-- | Stochastic gradient descent with momentum.
--
-- Every step samples a batch, builds the batch gradient with the supplied
-- function, subtracts the regularization term, scales the result by the
-- current gain, folds it into the momentum vector (weighted by 'gamma')
-- and finally adds the momentum to the parameters.
--
-- The notification action is run once per step, before the update, and
-- once more with the final parameters.  The run stops as soon as the
-- number of completed passes over the dataset exceeds 'iterNum'.
--
-- Batches are drawn with a generator that is always seeded with @0@.
--
-- Note that 'batchSize' has to be positive: the progress measure is
-- @k * batchSize / size dataset@, so with a zero batch size it never
-- advances.
sgd
  :: SgdArgs                 -- ^ SGD parameters
  -> (Para -> Int -> IO ())  -- ^ Notification run on every step, given
                             --   the current parameters and step index
  -> (Para -> x -> Grad)     -- ^ Gradient for a single dataset element
  -> Dataset x               -- ^ Dataset
  -> Para                    -- ^ Starting point
  -> IO Para                 -- ^ Final parameters
sgd SgdArgs{..} notify mkGrad dataset x0 = do
  putStrLn "Running momentum!"
  -- The momentum buffer is read before it is ever written (by
  -- 'updateMomentum'), so it is zeroed explicitly instead of relying on
  -- what a freshly allocated vector happens to contain.
  momentum <- UM.replicate (U.length x0) 0
  -- Scratch buffer for the gradient; 'addUp' resets it on every step.
  u <- UM.new (U.length x0)
  doIt momentum u 0 (R.mkStdGen 0) =<< U.thaw x0
  where
    -- Gain for step @k@: starts at 'gain0' and is halved once 'tau'
    -- passes over the dataset are done.
    gain k = (gain0 * tau) / (tau + done k)

    -- Number of (possibly fractional) passes over the dataset completed
    -- after @k@ steps.
    done :: Int -> Double
    done k
      = fromIntegral (k * batchSize)
      / fromIntegral (size dataset)

    -- Regularization coefficient: @(1 / regVar) ** (size dataset / batchSize)@.
    -- NOTE(review): the exponent is dataset size over batch size; confirm
    -- this is the intended scaling.
    regularizationParam = iVar ** coef
      where
        iVar = 1.0 / regVar
        coef = fromIntegral (size dataset)
             / fromIntegral batchSize

    doIt momentum u k stdGen x
      | done k > iterNum = do
          frozen <- U.unsafeFreeze x
          notify frozen k
          return frozen
      | otherwise = do
          (batch, stdGen') <- sample stdGen batchSize dataset

          -- Freeze the parameters so that they can be handed to the
          -- user-supplied functions as an immutable vector.
          frozen <- U.unsafeFreeze x
          notify frozen k

          let grad = parUnions (map (mkGrad frozen) batch)
          addUp grad u

          -- 'x' must not be touched after 'U.unsafeFreeze', so get a
          -- mutable handle back before the parameters are read again.
          -- Nothing writes to it until 'addTo' below.
          x' <- U.unsafeThaw frozen
          applyRegularization regularizationParam x' u
          scale (gain k) u
          updateMomentum gamma momentum u
          momentum `addTo` x'
          doIt momentum u (k + 1) stdGen' x'
-- | Subtract the regularization term from the gradient, in place: every
-- gradient component @g@ becomes @g - regParam * p@, where @p@ is the
-- parameter at the same index.  The parameter vector is only read.
applyRegularization
  :: Double  -- ^ Regularization coefficient
  -> MVect   -- ^ Parameters (read only)
  -> MVect   -- ^ Gradient (updated in place)
  -> IO ()
applyRegularization regParam params grad =
  mapM_ penalize [0 .. UM.length grad - 1]
  where
    -- Accesses are unchecked: @params@ must be at least as long as @grad@.
    penalize ix = do
      g <- UM.unsafeRead grad ix
      p <- UM.unsafeRead params ix
      UM.unsafeWrite grad ix (g - regParam * p)
-- | Fold a gradient into the momentum vector, in place: every momentum
-- component @m@ becomes @gammaCoef * m + g@, where @g@ is the gradient
-- component at the same index.  The gradient vector is only read.
updateMomentum
  :: Double  -- ^ Weight of the previous momentum
  -> MVect   -- ^ Momentum (updated in place)
  -> MVect   -- ^ Gradient (read only)
  -> IO ()
updateMomentum gammaCoef momentum grad =
  mapM_ step [0 .. UM.length momentum - 1]
  where
    -- Accesses are unchecked: @grad@ must be at least as long as
    -- @momentum@.
    step ix = do
      previous <- UM.unsafeRead momentum ix
      g <- UM.unsafeRead grad ix
      UM.unsafeWrite momentum ix (gammaCoef * previous + g)
-- | Overwrite the vector with the contents of the gradient: the vector
-- is first reset to zeros, then every @(index, value)@ pair of the
-- gradient is added at its index.  Indices are not bounds-checked.
addUp :: Grad -> MVect -> IO ()
addUp grad v = do
  UM.set v 0
  mapM_ accumulate (toList grad)
  where
    accumulate (ix, delta) = do
      current <- UM.unsafeRead v ix
      UM.unsafeWrite v ix (delta + current)
-- | Multiply every component of the vector by the given constant,
-- in place.
scale :: Double -> MVect -> IO ()
scale c v = mapM_ rescale [0 .. UM.length v - 1]
  where
    rescale ix = do
      component <- UM.unsafeRead v ix
      UM.unsafeWrite v ix (c * component)
-- | Add the first vector to the second one, component-wise and in place:
-- @w `addTo` v@ leaves @w@ untouched and stores @v + w@ in @v@.
-- Accesses are unchecked, so @w@ must be at least as long as @v@.
addTo :: MVect -> MVect -> IO ()
addTo w v = mapM_ bump [0 .. UM.length v - 1]
  where
    bump ix = do
      target <- UM.unsafeRead v ix
      increment <- UM.unsafeRead w ix
      UM.unsafeWrite v ix (target + increment)