{-# LANGUAGE RecordWildCards #-}
-- | Stochastic gradient optimisation over unboxed vectors of 'Double'
-- parameters.
--
-- The entry points are 'sgd' (pure, runs in 'ST') and 'sgdM' (any
-- 'Prim.PrimMonad', with a per-step notification callback).  Both repeatedly
-- draw a random batch from the dataset, combine the per-element gradients and
-- add the gain-scaled result to the parameter vector.  The gradient
-- representation is re-exported from "Numeric.SGD.Grad".
module Numeric.SGD
( SgdArgs (..)
, sgdArgsDefault
, Dataset
, Para
, sgd
, sgdM
, module Numeric.SGD.Grad
) where
import Control.Monad (forM_)
import Control.Monad.ST (ST, runST)
import qualified System.Random as R
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM
import qualified Control.Monad.Primitive as Prim
import Numeric.SGD.Grad
-- | Settings controlling the optimisation loop in 'sgdM'.
--
-- Every field is strict: these are small numeric settings, so there is
-- nothing to gain from keeping them as thunks, and an ill-defined setting is
-- reported when the record is built rather than somewhere inside the
-- optimisation loop.
data SgdArgs = SgdArgs
    { batchSize :: !Int
      -- ^ Number of dataset elements drawn (with replacement) for each
      -- update step.
    , regVar :: !Double
      -- ^ Regularization variance.
      -- NOTE(review): this field is not read by any code visible in this
      -- module; confirm where (or whether) it is meant to be applied.
    , iterNum :: !Double
      -- ^ Stopping threshold: the loop ends once
      -- @k * batchSize / length dataset@ exceeds this value, where @k@ is
      -- the number of update steps performed so far.
    , gain0 :: !Double
      -- ^ Gain (step size) used for the very first update.
    , tau :: !Double
      -- ^ Decay constant of the gain: at step @k@ the gain is
      -- @gain0 * tau / (tau + k * batchSize / length dataset)@.
    }
-- | Baseline configuration: batches of 30 elements, regularization variance
-- 10, a stopping threshold of 10, an initial gain of 1 and a gain decay
-- constant of 5.  Override individual fields with record-update syntax.
sgdArgsDefault :: SgdArgs
sgdArgsDefault =
    SgdArgs
        { batchSize = 30
        , regVar    = 10
        , iterNum   = 10
        , gain0     = 1
        , tau       = 5
        }
-- | A dataset: a boxed vector ("Data.Vector") of elements of type @x@.
type Dataset x = V.Vector x
-- | A parameter vector: an unboxed vector of 'Double's.
type Para = U.Vector Double
-- Mutable unboxed vector of 'Double's tied to the state token of the
-- primitive monad @m@.  Internal helper alias; not exported.
type MVect m = UM.MVector (Prim.PrimState m) Double
-- | Pure front end to 'sgdM'.
--
-- Runs the optimisation inside 'ST' with a notification callback that does
-- nothing, and returns the final parameter vector.
sgd
    :: SgdArgs              -- ^ Optimisation settings
    -> (Para -> x -> Grad)  -- ^ Gradient for one dataset element at the
                            --   given parameters
    -> Dataset x            -- ^ Dataset to sample batches from
    -> Para                 -- ^ Starting parameters
    -> Para                 -- ^ Parameters after the final step
sgd sgdArgs mkGrad dataset x0 =
    runST (sgdM sgdArgs silent mkGrad dataset x0)
  where
    -- Callback that ignores both the parameters and the step number.
    silent _ _ = return ()
{-# SPECIALIZE sgdM :: SgdArgs
  -> (Para -> Int -> IO ())
  -> (Para -> x -> Grad)
  -> Dataset x -> Para -> IO Para #-}
{-# SPECIALIZE sgdM :: SgdArgs
  -> (Para -> Int -> ST s ())
  -> (Para -> x -> Grad)
  -> Dataset x -> Para -> ST s Para #-}
-- | Monadic optimisation loop with a notification callback.
--
-- Starting from a copy of the initial parameters, each step @k@ (counted
-- from 0):
--
--   1. draws 'batchSize' elements from the dataset (with replacement, using
--      a generator seeded with the fixed seed 0, so the sampling sequence
--      is the same on every run);
--   2. calls the callback with the current parameters and @k@;
--   3. combines the per-element gradients, scales them by the gain
--      @gain0 * tau / (tau + k * batchSize / length dataset)@ and adds the
--      result to the parameters.
--
-- The loop stops, after one last callback invocation, as soon as
-- @k * batchSize / length dataset@ exceeds 'iterNum'.  If the dataset is
-- empty or 'batchSize' is not positive, no step can make progress, so the
-- callback is invoked once with step 0 and the initial parameters are
-- returned unchanged.
--
-- The vector handed to the callback shares its memory with the buffer that
-- is updated in place right afterwards, so the callback must not keep a
-- reference to it beyond the call (copy it if it has to be retained).
sgdM
    :: (Prim.PrimMonad m)
    => SgdArgs                -- ^ Optimisation settings
    -> (Para -> Int -> m ())  -- ^ Callback: current parameters and step number
    -> (Para -> x -> Grad)    -- ^ Gradient for one dataset element at the
                              --   given parameters
    -> Dataset x              -- ^ Dataset to sample batches from
    -> Para                   -- ^ Starting parameters
    -> m Para                 -- ^ Parameters after the final step
sgdM SgdArgs{..} notify mkGrad dataset x0 = do
    -- Scratch buffer that holds the (scaled) batch gradient of each step.
    u <- UM.new (U.length x0)
    -- 'U.thaw' copies, so the caller's 'x0' is never mutated.
    doIt u 0 (R.mkStdGen 0) =<< U.thaw x0
  where
    -- With nothing to sample from, or a non-positive batch size, 'done'
    -- would be NaN or would never grow, so the loop could crash (modulo by
    -- zero while sampling) or never terminate.  Stop immediately instead.
    degenerate = V.null dataset || batchSize <= 0

    -- Gain (step size) used in the k-th step.
    gain k = (gain0 * tau) / (tau + done k)

    -- Number of elements sampled before step k, as a fraction of the
    -- dataset size.
    done k
        = fromIntegral (k * batchSize)
        / fromIntegral (V.length dataset)

    doIt u k stdGen x
      | degenerate || done k > iterNum = do
          frozen <- U.unsafeFreeze x
          notify frozen k
          return frozen
      | otherwise = do
          let (batch, stdGen') = sample stdGen batchSize dataset

          -- Zero-copy immutable view of the parameters, needed by both the
          -- callback and the user-supplied gradient function.
          frozen <- U.unsafeFreeze x
          notify frozen k

          let grad = parUnions (map (mkGrad frozen) batch)
          -- 'addUp' forces the gradient while the parameters still hold
          -- their pre-update values; only afterwards are they modified.
          addUp grad u
          scale (gain k) u

          -- Same memory as 'x' again: update the parameters in place.
          x' <- U.unsafeThaw frozen
          apply u x'
          doIt u (k+1) stdGen' x'
{-# SPECIALIZE addUp :: Grad -> MVect IO -> IO () #-}
{-# SPECIALIZE addUp :: Grad -> MVect (ST s) -> ST s () #-}
-- | Overwrite the vector with the gradient: first reset every cell to 0,
-- then add each @(index, value)@ entry of the gradient into its cell.
--
-- Indices are not bounds-checked; every index produced by the gradient must
-- be a valid position in the vector.
addUp :: Prim.PrimMonad m => Grad -> MVect m -> m ()
addUp grad v = do
    UM.set v 0
    mapM_ accumulate (toList grad)
  where
    -- Add one gradient component into its cell.
    accumulate (ix, delta) = do
        current <- UM.unsafeRead v ix
        UM.unsafeWrite v ix (delta + current)
{-# SPECIALIZE scale :: Double -> MVect IO -> IO () #-}
{-# SPECIALIZE scale :: Double -> MVect (ST s) -> ST s () #-}
-- | Multiply every element of the vector by the given constant, in place.
scale :: Prim.PrimMonad m => Double -> MVect m -> m ()
scale c v = go 0
  where
    size = UM.length v

    -- Walk the indices 0 .. size - 1 in increasing order.
    go i
      | i >= size = return ()
      | otherwise = do
          old <- UM.unsafeRead v i
          UM.unsafeWrite v i (c * old)
          go (i + 1)
{-# SPECIALIZE apply :: MVect IO -> MVect IO -> IO () #-}
{-# SPECIALIZE apply :: MVect (ST s) -> MVect (ST s) -> ST s () #-}
-- | Add the first vector element-wise into the second one, in place.
--
-- The loop runs over the indices of the second vector and reads the first
-- one without bounds checks, so the first vector must be at least as long
-- as the second.
apply :: Prim.PrimMonad m => MVect m -> MVect m -> m ()
apply w v = mapM_ bump [0 .. UM.length v - 1]
  where
    -- Add the i-th element of 'w' to the i-th element of 'v'.
    bump i = do
        target <- UM.unsafeRead v i
        increment <- UM.unsafeRead w i
        UM.unsafeWrite v i (target + increment)
-- | Draw @n@ elements from the dataset, independently and with replacement.
--
-- Returns the drawn elements together with the advanced generator.  The
-- element drawn last ends up at the head of the list.
--
-- When @n@ is not positive or the dataset is empty there is nothing to
-- draw: the result is the empty list and the generator is returned
-- untouched.  (Without this guard a negative @n@ would recurse forever and
-- an empty dataset would fail with a division by zero.)
sample :: R.RandomGen g => g -> Int -> Dataset x -> ([x], g)
sample g n dataset
    | n <= 0 || V.null dataset = ([], g)
    | otherwise =
        let (xs, g') = sample g (n - 1) dataset
            -- NOTE(review): 'R.next' is deprecated in random >= 1.2.  It is
            -- kept so the sampled sequence is unchanged for existing
            -- callers; confirm before switching to 'R.randomR'.
            (i, g'') = R.next g'
            -- 'mod' with a positive divisor is never negative, so the index
            -- is always within bounds.
            x = dataset V.! (i `mod` V.length dataset)
        in  (x : xs, g'')