{-# LANGUAGE RecordWildCards #-}
module Numeric.SGD.Sparse.Momentum
( SgdArgs (..)
, sgdArgsDefault
, Para
, sgd
, module Numeric.SGD.Sparse.Grad
, module Numeric.SGD.DataSet
) where
import Control.Monad (forM_)
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM
import qualified Control.Monad.Primitive as Prim
import Numeric.SGD.Sparse.Grad
import Numeric.SGD.DataSet
-- | Tunable parameters of the momentum SGD procedure.
data SgdArgs = SgdArgs
  { batchSize :: Int
    -- ^ Number of dataset elements sampled for each step
  , regVar :: Double
    -- ^ Regularization variance; its reciprocal (@1 / regVar@) enters the
    -- regularization coefficient
  , iterNum :: Double
    -- ^ Number of passes over the dataset after which iteration stops
  , gain0 :: Double
    -- ^ Initial gain (step size)
  , tau :: Double
    -- ^ Gain decay parameter: the gain drops to @gain0 / 2@ after @tau@
    -- passes over the dataset
  }
-- | Default parameters: ten passes over the dataset in batches of 50
-- elements, an initial gain of 0.25 that halves after 5 passes, and a
-- regularization variance of 10.
sgdArgsDefault :: SgdArgs
sgdArgsDefault = SgdArgs
  { iterNum = 10
  , batchSize = 50
  , gain0 = 0.25
  , tau = 5
  , regVar = 10
  }
-- | Momentum decay factor: the fraction of the previous momentum that is
-- kept at every step (@momentum <- gamma * momentum + step@).
gamma :: Double
gamma = 9 / 10
-- | Immutable vector of model parameters.
type Para = U.Vector Double
-- | Mutable unboxed vector of doubles in 'IO'; the working representation
-- of the parameters, the momentum and the per-step scratch buffer.
type MVect = UM.MVector (Prim.PrimState IO) Double
-- | Mini-batch stochastic gradient method with classical momentum.
--
-- Every step samples @batchSize@ dataset elements, sums the sparse gradients
-- returned by @mkGrad@ for them, subtracts the regularization term, scales
-- the result by the current gain, folds it into the momentum
-- (@momentum <- gamma * momentum + step@) and adds the momentum to the
-- parameters.  The parameters therefore move /along/ the summed gradient.
--
-- Iteration stops once more than @iterNum@ passes over the dataset have been
-- processed.  @notify@ is run before each step with the current parameters
-- and the step counter, and once more with the final parameters.
--
-- NOTE: the vector handed to @notify@ is a frozen view ('U.unsafeFreeze') of
-- the internal mutable buffer, which is thawed and overwritten by the next
-- step; @notify@ must not retain it beyond its own call.
sgd
  :: SgdArgs
  -- ^ SGD parameter values
  -> (Para -> Int -> IO ())
  -- ^ Notification run with the current parameters and the step counter
  -> (Para -> x -> Grad)
  -- ^ Gradient for a single dataset element at the given parameters
  -> DataSet x
  -- ^ Dataset to sample from
  -> Para
  -- ^ Starting point (copied with 'U.thaw', so it is not modified)
  -> IO Para
  -- ^ Parameters after the final step
sgd SgdArgs{..} notify mkGrad dataset x0 = do
  putStrLn "Running momentum!"
  -- 'updateMomentum' reads the momentum before writing it, so it must start
  -- at zero; 'UM.new' does not promise initialized contents.
  momentum <- UM.replicate (U.length x0) 0
  -- Scratch buffer for the per-step update; 'addUp' clears it every step.
  u <- UM.new (U.length x0)
  doIt momentum u 0 =<< U.thaw x0
  where
    -- Gain (step size) for step k: starts at gain0 and decays with the
    -- number of completed passes, halving after tau passes.
    gain k = (gain0 * tau) / (tau + done k)

    -- Number of passes over the dataset completed after k steps.
    done :: Int -> Double
    done k =
      fromIntegral (k * batchSize) / fromIntegral (size dataset)

    -- Coefficient of the regularization term subtracted from each step.
    -- NOTE(review): raising 1/regVar to the number of batches per pass makes
    -- this vanish quickly as the dataset grows (e.g. 0.1 ** 20 = 1e-20 for
    -- regVar = 10, 1000 elements, batches of 50); it looks adapted from a
    -- multiplicative-decay formula -- confirm it is intended.
    regularizationParam = iVar ** coef
      where
        iVar = 1.0 / regVar
        coef = fromIntegral (size dataset) / fromIntegral batchSize

    doIt momentum u k x
      | done k > iterNum = do
          frozen <- U.unsafeFreeze x
          notify frozen k
          return frozen
      | otherwise = do
          batch <- randomSample batchSize dataset
          -- Freeze (without copying) so the user's mkGrad and notify can see
          -- the current parameters as an immutable vector.
          frozen <- U.unsafeFreeze x
          notify frozen k
          let grad = parUnions (map (mkGrad frozen) batch)
          -- u <- gain k * (sum of batch gradients - regParam * x)
          addUp grad u
          applyRegularization regularizationParam x u
          scale (gain k) u
          -- momentum <- gamma * momentum + u
          updateMomentum gamma momentum u
          -- x <- x + momentum, reusing the same buffer.
          x' <- U.unsafeThaw frozen
          momentum `addTo` x'
          doIt momentum u (k + 1) x'
-- | Subtract the regularization term from a gradient, in place:
-- @grad[i] <- grad[i] - regParam * params[i]@ for every index of @grad@.
--
-- @params@ is read at the same indices without bounds checks, so it must be
-- at least as long as @grad@.
applyRegularization
  :: Double -- ^ Regularization coefficient
  -> MVect  -- ^ Parameters (read only)
  -> MVect  -- ^ Gradient (updated in place)
  -> IO ()
applyRegularization regParam params grad =
  forM_ indices penalize
  where
    indices = [0 .. UM.length grad - 1]
    penalize i = do
      g <- UM.unsafeRead grad i
      p <- UM.unsafeRead params i
      UM.unsafeWrite grad i (g - regParam * p)
-- | Fold a step into the momentum, in place:
-- @momentum[i] <- gammaCoef * momentum[i] + grad[i]@ for every index of
-- @momentum@.
--
-- @grad@ is read without bounds checks, so it must be at least as long as
-- @momentum@.
updateMomentum
  :: Double -- ^ Decay factor applied to the previous momentum
  -> MVect  -- ^ Momentum (updated in place)
  -> MVect  -- ^ Step to add (read only)
  -> IO ()
updateMomentum gammaCoef momentum grad = go 0
  where
    n = UM.length momentum
    go i
      | i >= n = return ()
      | otherwise = do
          m <- UM.unsafeRead momentum i
          g <- UM.unsafeRead grad i
          UM.unsafeWrite momentum i (gammaCoef * m + g)
          go (i + 1)
-- | Write a sparse gradient into a dense buffer.
--
-- The buffer is cleared first.  Each @(index, value)@ pair produced by
-- 'toList' is then added into the buffer at that index, so repeated indices
-- accumulate.
--
-- The indices originate from a user-supplied gradient function.  They are
-- therefore bounds-checked ('UM.read' / 'UM.write'): an out-of-range index
-- raises an exception instead of silently corrupting memory, as the
-- unchecked variants would.
addUp :: Grad -> MVect -> IO ()
addUp grad v = do
  UM.set v 0
  forM_ (toList grad) $ \(i, x) -> do
    y <- UM.read v i
    UM.write v i (x + y)
-- | Multiply every element of a mutable vector by a constant, in place.
scale :: Double -> MVect -> IO ()
scale c v = mapM_ rescale [0 .. UM.length v - 1]
  where
    rescale i = UM.unsafeRead v i >>= UM.unsafeWrite v i . (c *)
-- | Add the first vector into the second, in place: @v[i] <- v[i] + w[i]@
-- for every index of @v@.
--
-- @w@ is read without bounds checks, so it must be at least as long as @v@.
addTo :: MVect -> MVect -> IO ()
addTo w v = forM_ [0 .. UM.length v - 1] bump
  where
    bump ix = do
      total <- (+) <$> UM.unsafeRead v ix <*> UM.unsafeRead w ix
      UM.unsafeWrite v ix total