{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE RecordWildCards #-}
module MXNet.NN (
Parameter(..),
Config(..),
Exc(..),
Initializer,
Optimizer,
TrainM,
train,
inferShape,
initialize,
fit,
forwardOnly
) where
import MXNet.Core.Base hiding (bind, context)
import MXNet.Core.Base.Internal
import qualified MXNet.Core.Base.NDArray as A
import qualified MXNet.Core.Base.Symbol as S
import qualified MXNet.Core.Base.Executor as E
import qualified MXNet.Core.Types.Internal as MXI
import qualified Data.HashMap.Strict as M
import Data.Typeable
import qualified Control.Monad.State as ST
import Data.Maybe (isJust, fromJust)
import Control.Monad (when)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Trans.Resource (MonadThrow(..))
import Control.Exception.Base (Exception)
import Control.Lens (traverseOf, _1)
-- | One input of a network: its current value together with the array its
-- gradient is written into.
--
-- Inputs that 'initialize' treats as placeholders carry a null gradient
-- handle instead of a real array.
data Parameter a = Parameter
  { _param_in   :: NDArray a
    -- ^ current value of the input
  , _param_grad :: NDArray a
    -- ^ gradient array handed to the executor when binding for training
  }
  deriving Show
-- | The training monad: a state transformer over @n@ whose state is the
-- parameter table (keyed by input name) paired with the 'Context' used
-- whenever an executor is bound.
type TrainM e n =
  ST.StateT (M.HashMap String (Parameter e), Context) n
-- | Creates the initial value array for an input, given that input's
-- inferred shape.
type Initializer e =
  [Int] -> IO (NDArray e)
-- | An update rule. Given a parameter's current value and its gradient, it
-- produces the parameter's new value.
type Optimizer e =
  NDArray e -> NDArray e -> IO (NDArray e)
-- | Run a training computation against an initial parameter table and a
-- device 'Context', discarding the final state and returning only the
-- result.
train :: (DType a, Monad m) => M.HashMap String (Parameter a) -> Context -> TrainM a m r -> m r
train param context action = ST.evalStateT action (param, context)
-- | Infer the shape of every input of a symbol from a subset of input arrays
-- whose shapes are already known.
--
-- The known shapes are handed to 'mxSymbolInferShape' in packed form:
--
-- * an index-pointer list of running totals of the arrays' ranks, starting
--   at 0;
-- * one flat list holding all of their dimensions.
--
-- The first component of the result is zipped with the names returned by
-- 'listInputs'. NOTE(review): this assumes both lists are in the same order.
inferShape :: DType a => Symbol a -> M.HashMap String (NDArray a) -> IO (M.HashMap String [Int])
inferShape sym known = do
  let entries = M.toList known
      keys    = map fst entries
  dims <- traverse (ndshape . snd) entries
  let offsets = scanl (+) 0 (fst <$> dims)
      flat    = concatMap snd dims
  (inputShapes, _, _) <- mxSymbolInferShape (S.getHandle sym) keys offsets flat
  inputNames <- listInputs sym
  pure (M.fromList (zip inputNames inputShapes))
-- | Describes how 'initialize' builds the initial parameter table.
data Config a = Config
  { _cfg_placeholders        :: M.HashMap String [Int]
    -- ^ shapes of inputs that are fed with data. Those without an entry in
    -- '_cfg_initializers' are allocated as zero arrays and drive shape
    -- inference for the other inputs.
  , _cfg_initializers        :: M.HashMap String (Initializer a)
    -- ^ per-input initializers, keyed by input name
  , _cfg_default_initializer :: Initializer a
    -- ^ initializer used for any input without its own entry
  }
-- | Build the initial parameter table for a symbol.
--
-- Placeholders with no initializer of their own:
--
-- * are allocated as zero-filled arrays of their configured shape;
-- * are used to infer the shapes of all the symbol's inputs;
-- * keep that array as their value and get a null gradient handle.
--
-- Every other inferred input gets:
--
-- * a value from its own initializer, or from the default initializer if it
--   has none;
-- * a zero-filled gradient of the same inferred shape.
initialize :: DType a => Symbol a -> Config a -> IO (M.HashMap String (Parameter a))
initialize sym config = do
  let customInits  = _cfg_initializers config
      fallbackInit = _cfg_default_initializer config
      -- A placeholder that also has an initializer is treated as an
      -- ordinary parameter, so it is not fed here.
      fedShapes    = _cfg_placeholders config `M.difference` customInits
  fedInputs   <- traverse zeros fedShapes
  inputShapes <- inferShape sym fedInputs
  M.traverseWithKey (makeParam fedInputs customInits fallbackInit) inputShapes
  where
    -- Build the parameter for a single inferred input.
    makeParam fed inits fallback name shape =
      case M.lookup name fed of
        Just arr -> return (Parameter arr (A.NDArray MXI.nullNDArrayHandle))
        Nothing  ->
          let initFn = maybe fallback id (M.lookup name inits)
          in  Parameter <$> initFn shape <*> zeros shape
-- | Bind a symbol to the arrays in a parameter table and create an executor
-- on the device described by the 'Context'.
--
-- Arrays are passed in the order returned by 'listInputs', so every input of
-- @net@ must have an entry in @args@; a missing name fails with the
-- 'M.!' error.
--
-- * When @train_@ is 'True', each input's '_param_grad' handle is passed as
--   its gradient array.
-- * Otherwise a null handle is passed for every gradient.
--
-- The grad-req code @1@ is passed for every input, and the trailing @0 []@
-- presumably means "no auxiliary states" (TODO confirm against the
-- 'mxExecutorBind' binding).
bind :: DType a => Symbol a -> M.HashMap String (Parameter a) -> Context -> Bool -> IO (Executor a)
bind net args Context{..} train_ = do
  names <- listInputs net
  -- Every per-argument list and the count come from this one list, so they
  -- always agree. The previous @M.size args@ count could exceed the number of
  -- arrays actually passed whenever @args@ held names the net does not use.
  let params = map (args M.!) names
      count  = length params
      grads
        | train_    = map (A.getHandle . _param_grad) params
        | otherwise = replicate count MXI.nullNDArrayHandle
  exec_handle <- checked $
    mxExecutorBind (S.getHandle net) deviceType deviceId
      (fromIntegral count)
      (map (A.getHandle . _param_in) params)
      grads
      (replicate count 1)
      0 []
  makeExecutor exec_handle
-- | Run one training step on a batch.
--
-- Steps:
--
-- 1. Infer input shapes from @datAndLbl@.
-- 2. Put each array of @datAndLbl@ into the parameter of the same name. Every
--    other parameter must have value and gradient shapes equal to its
--    inferred shape, otherwise 'MismatchedShape' is thrown.
-- 3. Bind for training, run forward (training flag 1) and then backward.
-- 4. Replace every parameter not in @datAndLbl@ with the result of @opt@
--    applied to its value and gradient.
--
-- A parameter name that is neither in @datAndLbl@ nor an inferred input
-- fails with the 'M.!' error. NOTE(review): the executor is not explicitly
-- released here; whether 'makeExecutor' attaches a finalizer is not visible.
fit :: (DType a, MonadIO m, MonadThrow m) => Optimizer a -> Symbol a -> M.HashMap String (NDArray a) -> TrainM a m ()
fit opt net datAndLbl = do
  inferred <- liftIO (inferShape net datAndLbl)
  modifyT (traverseOf _1 (M.traverseWithKey (feedOrCheck inferred)))
  (params, context) <- ST.get
  liftIO $ do
    exec <- bind net params context True
    checked (mxExecutorForward (E.getHandle exec) 1)
    backward exec
  modifyT (traverseOf _1 (M.traverseWithKey step))
  where
    -- Use the supplied batch array, or verify the stored shapes.
    feedOrCheck inferred name param =
      case M.lookup name datAndLbl of
        Just batch -> return param { _param_in = batch }
        Nothing -> do
          (_, inShape)   <- liftIO (ndshape (_param_in param))
          (_, gradShape) <- liftIO (ndshape (_param_grad param))
          let expected = inferred M.! name
          when (any (/= expected) [inShape, gradShape]) $
            throwM (MismatchedShape name)
          return param
    -- Apply the optimizer to every parameter that was not fed as data.
    step name param
      | name `M.member` datAndLbl = return param
      | otherwise = do
          updated <- liftIO (opt (_param_in param) (_param_grad param))
          return param { _param_in = updated }
-- | Run a forward pass without gradients and return the executor's outputs.
--
-- For each parameter, the entry in @dat@ decides what is fed:
--
-- * @Just (Just arr)@: @arr@ is fed as the parameter's value.
-- * @Just Nothing@: a zero-filled dummy of the inferred shape is fed.
-- * no entry: the stored value must already have the inferred shape,
--   otherwise 'MismatchedShape' is thrown.
--
-- The fed arrays and dummies stay in the training state afterwards. The
-- executor is bound without gradient arrays and run forward with training
-- flag 0.
forwardOnly :: (DType a, MonadIO m, MonadThrow m) => Symbol a -> M.HashMap String (Maybe (NDArray a)) -> TrainM a m [NDArray a]
forwardOnly net dat = do
  -- Only entries that actually carry an array take part in shape inference.
  -- 'M.mapMaybe' keeps the 'Just' values in one pass, with no partial
  -- 'fromJust'.
  shps <- liftIO $ inferShape net (M.mapMaybe id dat)
  modifyT . traverseOf _1 $ M.traverseWithKey $ \k p -> do
    let ishp = shps M.! k
    case M.lookup k dat of
      Just (Just a) ->
        return $ p {_param_in = a}
      Just Nothing -> do
        dummy <- liftIO $ zeros ishp
        return $ p {_param_in = dummy}
      Nothing -> do
        (_, pshp) <- liftIO $ ndshape (_param_in p)
        when (ishp /= pshp) (throwM $ MismatchedShape k)
        return p
  (params, context) <- ST.get
  liftIO $ do
    exec <- bind net params context False
    checked $ mxExecutorForward (E.getHandle exec) 0
    getOutputs exec
-- | Exceptions thrown by this module.
data Exc
  = MismatchedShape String
    -- ^ the named parameter's stored array shape differs from the shape
    -- inferred for the inputs currently being fed
  deriving (Typeable, Show)

instance Exception Exc
-- | Replace the state with the result of running a monadic function on it in
-- the underlying monad.
--
-- This is the effectful counterpart of 'ST.modify': the current state is
-- read, @func@ is run on it through 'ST.lift', and its result is written back.
modifyT :: Monad m => (s -> m s) -> ST.StateT s m ()
modifyT func = ST.get >>= ST.lift . func >>= ST.put