{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE DeriveGeneric #-}


-- | Pipe-based interface
--
-- NOTE: apart from the module header, everything in this part of the file
-- (export list, imports, data types) is commented out, so none of it is
-- compiled.  The text is kept for reference only.
module Numeric.SGD.Pipe where
--   ( Config(..)
--   , Method(..)
--   , sgd
--   , result
--   , every
--   , pipeSeq
--   , pipeRan
--   ) where

-- import GHC.Generics (Generic)
-- import Numeric.Natural (Natural)
--
-- import Control.Monad (when, forM_)
--
-- import qualified System.Random as R
--
-- import qualified Data.IORef as IO
--
-- import qualified Pipes as P
-- import qualified Pipes.Prelude as P
-- import Pipes ((>->))
--
-- import Numeric.SGD.DataSet
-- import Numeric.SGD.ParamSet
-- import qualified Numeric.SGD.AdaDelta as Ada
-- import qualified Numeric.SGD.Momentum as Mom
--
--
-- -------------------------------
-- -- Data
-- -------------------------------
--
--
-- -- | Top-level SGD configuration
-- data Config = Config
--   { iterNum :: Natural
--   -- ^ Number of iterations over the entire training dataset
--   , batchRandom :: Bool
--   -- ^ Should the mini-batch be selected at random? If not, the subsequent
--   -- training elements will be picked sequentially. Random selection gives
--   -- no guarantee of seeing each training sample in every epoch. Use `False`
--   -- if unsure.
--   , method :: Method
--   -- ^ Selected SGD method
--   , reportPeriod :: Double
--   -- ^ How often the quality should be reported (with `1` meaning once per
--   -- pass over the training data)
--   } deriving (Show, Eq, Ord, Generic)
--
--
-- -- | Different SGD methods, together with the corresponding configurations
-- data Method
--   = AdaDelta {adaDeltaCfg :: Ada.Config}
--   | Momentum {momentumCfg :: Mom.Config}
--   deriving (Show, Eq, Ord, Generic)
--
--
-- -------------------------------
-- -- SGD
-- -------------------------------
--
--
-- -- | Pipe-based SGD.
-- NOTE: `sgd` and `pipeSeq` below are commented out (disabled); the text is
-- kept for reference only.
--
-- `sgd` prints the quality of the initial parameters, then streams the
-- dataset through the selected optimiser pipe (Momentum or AdaDelta), keeps
-- `realIterNum` parameter sets, reports every `realReportPeriod`-th one, and
-- returns the last parameter set (or the initial one if the stream is empty).
--
-- NOTE(review): the `batchRandom` field of the configuration is not consulted
-- here; the input stream is always produced by `pipeSeq`.
--
-- NOTE(review): `realReportPeriod` is handed to `every`, which takes `mod` by
-- it.  It evaluates to 0 when `reportPeriod` is 0 (and presumably when the
-- dataset is empty -- confirm what `size` returns), and `mod` by zero raises
-- an exception.
--
-- sgd
--   :: (ParamSet p)
--   => Config
--   -> DataSet e
--   -> (e -> p -> p)
--     -- ^ Network gradient on a sample element
--   -> (e -> p -> Double)
--     -- ^ Value of the objective function on a sample element
--   -> p
--     -- ^ Initial parameter values
--   -> IO p
-- sgd Config{..} dataSet grad0 quality0 net0 = do
--   let sgdPipe =
--         case method of
--           Momentum cfg -> Mom.momentum cfg
--             -- (cfg {Mom.tau = iterScale (Mom.tau cfg)})
--             grad0 net0
--           AdaDelta cfg -> Ada.adaDelta cfg grad0 net0
--   report net0
--   result net0 $ pipeSeq dataSet
--     >-> sgdPipe
--     >-> P.take realIterNum
--     >-> every realReportPeriod report
--   where
--     -- Iteration scaling
--     iterScale x = fromIntegral (size dataSet) * x
--     -- Number of iterations and reporting period
--     realIterNum = ceiling $ iterScale (fromIntegral iterNum :: Double)
--     realReportPeriod = ceiling $ iterScale reportPeriod
--     -- Network quality over the entire training dataset
--     report net = do
--       putStr . show =<< quality net
--       putStrLn $ " (norm_2 = " ++ show (norm_2 net) ++ ")"
--     quality net = do
--       res <- IO.newIORef 0.0
--       forM_ [0 .. size dataSet - 1] $ \ix -> do
--         x <- elemAt dataSet ix
--         IO.modifyIORef' res (+ quality0 x net)
--       IO.readIORef res
--
--
-- -------------------------------
-- -- Dataset producers
-- -------------------------------
--
--
-- -- | Pipe the dataset sequentially in a loop.
-- --
-- -- The producer wraps around to index 0 after the last element, so it never
-- -- terminates on its own.
-- --
-- -- NOTE(review): if `size dataSet` is 0, the first guard always holds and
-- -- `go` recurses forever without yielding anything.
-- pipeSeq :: DataSet e -> P.Producer e IO ()
-- pipeSeq dataSet = do
--   go (0 :: Int)
--   where
--     go k
--       | k >= size dataSet = go 0
--       | otherwise = do
--           x <- P.lift $ elemAt dataSet k
--           P.yield x
--           go (k+1)
--
--
-- -- | Pipe the dataset randomly in a loop.
-- NOTE: `pipeRan`, `result` and `every` below are commented out (disabled);
-- the text is kept for reference only.
--
-- `pipeRan` draws an index uniformly from `[0, size dataSet - 1]`, yields the
-- element at that index, and recurses unconditionally, so it never terminates
-- on its own.
--
-- NOTE(review): with an empty dataset the range passed to `R.randomRIO` is
-- `(0, -1)`; confirm how `randomRIO` and `elemAt` behave in that case.
--
-- pipeRan :: DataSet e -> P.Producer e IO ()
-- pipeRan dataSet = do
--   x <- P.lift $ do
--     ix <- R.randomRIO (0, size dataSet - 1)
--     elemAt dataSet ix
--   P.yield x
--   pipeRan dataSet
--
--
-- -------------------------------
-- -- Utils
-- -------------------------------
--
--
-- -- | Extract the result of the SGD calculation (the last parameter
-- -- set flowing downstream).
-- --
-- -- Falls back to the supplied default when `P.last` finds no element.
-- result
--   :: (Monad m)
--   => p
--     -- ^ Default value (in case the stream is empty)
--   -> P.Producer p m ()
--     -- ^ Stream of parameter sets
--   -> m p
-- result pDef = fmap (maybe pDef id) . P.last
--
--
-- -- | Report every `k`-th parameter set flowing downstream.
-- --
-- -- Every awaited parameter set is passed downstream unchanged; `report` is
-- -- additionally run on the k-th, 2k-th, ... element (the counter starts at
-- -- `1 mod k` and is advanced modulo `k` after each element).
-- --
-- -- NOTE(review): `k` must be non-zero; `mod` by zero raises an exception.
-- every :: (Monad m) => Int -> (p -> m ()) -> P.Pipe p p m x
-- every k report = do
--   go (1 `mod` k)
--   where
--     go i = do
--       paramSet <- P.await
--       when (i == 0) $ do
--         P.lift $ report paramSet
--       P.yield paramSet
--       go $ (i+1) `mod` k