{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE BangPatterns #-}
module MXNet.NN (
    Parameter(..),
    Config(..),
    Session(..),
    Exc(..),
    Initializer,
    TrainM,
    train,
    inferShape,
    initialize,
    fit, fitAndEval,
    forwardOnly,
    getContext,
    sess_param,
    sess_context,
    module MXNet.NN.Optimizer
) where

import MXNet.Core.Base hiding (bind, context, (^.))
import MXNet.Core.Base.Internal
import qualified MXNet.Core.Base.NDArray as A
import qualified MXNet.Core.Base.Symbol as S
import qualified MXNet.Core.Base.Executor as E
import qualified MXNet.Core.Types.Internal as MXI
import qualified MXNet.Core.Base.Internal.TH.NDArray as MXI
import qualified Data.HashMap.Strict as M
import Data.Typeable
import qualified Control.Monad.State.Strict as ST
import Control.Monad (when, zipWithM_)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Trans.Resource (MonadThrow(..))
import Control.Exception.Base (Exception)
import Control.Lens (traverseOf, use, (^.))

import MXNet.NN.Types
import MXNet.NN.Optimizer
import MXNet.NN.EvalMetric

-- | Execute the 'TrainM' monad
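--
-- A typical usage sketch (@net@, @config@, @sgd@ and @batch@ are
-- hypothetical; see 'initialize' and 'fit' below):
--
-- > sess <- initialize net config
-- > train sess $ fit sgd net batch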
train :: (DType a, Monad m) => Session a -> TrainM a m r -> m r
train = flip ST.evalStateT

-- | infer the shapes of all the inputs of a symbolic neural network,
-- given concrete arrays for some of them
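--
-- For example, providing a concrete array for the single placeholder
-- @"x"@ is usually enough to derive the shapes of every other input
-- (a sketch; @net@ is hypothetical and @contextCPU@ is assumed to come
-- from "MXNet.Core.Base"):
--
-- > x      <- makeEmptyNDArray [32,784] contextCPU False
-- > shapes <- inferShape net (M.singleton "x" x)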
inferShape :: DType a => Symbol a -> M.HashMap String (NDArray a) -> IO (M.HashMap String [Int])
inferShape sym known = do
    let (names, vals) = unzip $ M.toList known
    shapes <- mapM ndshape vals
    let arg_ind = scanl (+) 0 $ map fst shapes
        arg_shp = concatMap snd shapes
    (inp_shp, _, _) <- mxSymbolInferShape (S.getHandle sym) names arg_ind arg_shp
    inps <- listInputs sym
    return $ M.fromList $ zip inps inp_shp

-- | Every symbol in the neural network is either a placeholder or a
-- variable. A 'Config' therefore specifies the shapes of the placeholders
-- and the methods to initialize the variables.
-- 
-- Note that it is wrong to specify a symbol as both a placeholder and an
-- initializer; this is tolerated, however, and such a symbol is treated
-- as a variable.
-- 
-- Note that any symbol not specified either way will be initialized with
-- the '_cfg_default_initializer'.
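--
-- For example, a configuration for a MNIST-like network (a sketch;
-- @contextCPU@ is assumed to come from "MXNet.Core.Base", and the
-- empty-array default initializer is merely a stand-in for a proper
-- random one):
--
-- > config = Config {
-- >     _cfg_placeholders = M.fromList [("x", [32,1,28,28]), ("y", [32])],
-- >     _cfg_initializers = M.empty,
-- >     _cfg_default_initializer = \cxt shp -> makeEmptyNDArray shp cxt False,
-- >     _cfg_context = contextCPU }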
data Config a = Config {
    _cfg_placeholders :: M.HashMap String [Int],
    _cfg_initializers :: M.HashMap String (Initializer a),
    _cfg_default_initializer :: Initializer a,
    _cfg_context :: Context
}

-- | initialize all parameters
initialize :: DType a => Symbol a -> Config a -> IO (Session a)
initialize sym config = do
    let spec1 = M.difference (_cfg_placeholders config) (_cfg_initializers config)
        spec2 = _cfg_initializers config
        dinit = _cfg_default_initializer config
        cxt   = _cfg_context config
    placeholder  <- mapM (\shp -> makeEmptyNDArray shp cxt False) spec1
    inp_with_shp <- inferShape sym placeholder
    args <- M.traverseWithKey (init_param placeholder spec2 dinit) inp_with_shp
    return $ Session args cxt
  where
    -- dispatch to the per-symbol initializer if one is given, otherwise to
    -- the default initializer; placeholders receive a null gradient handle
    -- because no gradient is ever computed for them.
    init_param placeholder spec2 dinit inp shp = do
        case M.lookup inp placeholder of
            Just in_arg -> do
                nullarg <- MXI.nullNDArrayHandle
                return $ Parameter in_arg (A.NDArray nullarg)
            Nothing -> do
                arg_in <- case M.lookup inp spec2 of
                    Just cinit -> cinit (_cfg_context config) shp
                    Nothing    -> dinit (_cfg_context config) shp
                arg_gr <- makeEmptyNDArray shp (_cfg_context config) False
                return $ Parameter arg_in arg_gr

-- | bind the symbolic network with actual parameters
bind :: (DType a, MonadIO m, MonadThrow m) => Symbol a -> M.HashMap String (Maybe (NDArray a)) -> Bool -> TrainM a m (Executor a)
bind net dat train_ = do
    Context{..} <- use sess_context

    shps <- liftIO $ inferShape net (M.mapMaybe id dat)
    modifyT . traverseOf sess_param $ M.traverseWithKey $ \k p -> do
        let ishp = shps M.! k
        case M.lookup k dat of
            Just a  -> liftIO $ update_param (maybe (Right ishp) Left a) p
            Nothing -> do
                (_, pshp1) <- liftIO $ ndshape (_param_in p)
                when (ishp /= pshp1 ) (throwM $ MismatchedShape k)
                when train_ $ do
                    (_, pshp2) <- liftIO $ ndshape (_param_grad p)
                    when (ishp /= pshp2) (throwM $ MismatchedShape k)
                return p

    args <- use sess_param
    exec_handle <- liftIO $ do
        names <- listInputs net
        nullarg <- MXI.nullNDArrayHandle
        -- the parameters to bind should be arranged in the same order as the names
        let arg_num = fromIntegral (M.size args)
            arg_in  = map (A.getHandle . _param_in) $ map (args M.!) names
            arg_gr  = if train_ 
                        then map (A.getHandle . _param_grad) $ map (args M.!) names
                        else replicate (M.size args) nullarg
            arg_gr_req = replicate (M.size args) 1

        checked $ mxExecutorBind (S.getHandle net) deviceType deviceId
                                            arg_num arg_in arg_gr arg_gr_req 
                                            0 []
    return $ E.Executor exec_handle
  where
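    -- 'Left arr'  : the caller supplied a concrete array for this input;
    --               reuse it if context and shape match, otherwise copy.
    -- 'Right shp' : only a shape was supplied; reallocate the parameter
    --               if the requested shape differs from the current one.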
    update_param :: DType a => Either (NDArray a) [Int] -> Parameter a -> IO (Parameter a)
    update_param (Left a) p = do
        src_cxt <- A.context a
        src_shp <- snd <$> A.ndshape a
        dst_cxt <- A.context (_param_in p)
        dst_shp <- snd <$> A.ndshape (_param_in p)
        case (src_cxt == dst_cxt, src_shp == dst_shp) of
            (True , True) -> return $ p {_param_in = a}
            (False, True) -> do
                MXI._copyto' (A.getHandle a) [A.getHandle (_param_in p)] :: IO ()
                return p
            _ -> do
                a_copy <- makeEmptyNDArray src_shp dst_cxt False
                MXI._copyto' (A.getHandle a) [A.getHandle a_copy] :: IO ()
                return $! p {_param_in = a_copy}    
    update_param (Right src_shp) p = do
        dst_cxt <- A.context (_param_in p)
        dst_shp <- snd <$> A.ndshape (_param_in p)
        if src_shp == dst_shp 
            then return p
            else do
                dummy <- makeEmptyNDArray src_shp dst_cxt False
                return $! p {_param_in = dummy}

-- | single step of training. All placeholders must be provided.
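--
-- A sketch of one epoch driving 'fit' over a list of batches (@sgd@,
-- @net@ and @batches@ are hypothetical):
--
-- > train sess $ mapM_
-- >     (\(x, y) -> fit sgd net (M.fromList [("x", x), ("y", y)]))
-- >     batches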
fit :: (DType a, MonadIO m, MonadThrow m, Optimizer opt, OptArgsCst opt g) 
    => opt a g -> Symbol a -> M.HashMap String (NDArray a) -> TrainM a m ()
fit opt net datAndLbl = do
    exec <- bind net (M.map Just datAndLbl) True
    liftIO $ do 
        checked $ mxExecutorForward (E.getHandle exec) 1
        checked $ mxExecutorBackward (E.getHandle exec) 0 []
        -- forward/backward are asynchronous operations in mxnet: only
        -- opcodes are pushed onto an internal execution stack, while an
        -- executor in a separate thread consumes them. If 'fit' is called
        -- so fast that too many opcodes and data pile up on the stack, an
        -- OOM of CPU memory can occur, as described in issue #1.
        checked $ mxNDArrayWaitAll
    updateParameters opt datAndLbl

-- | single step of training. All placeholders must be provided.
--   After fitting, the evaluation metric is also updated.
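--
--   For example, with a metric whose label names are @["y"]@ (a sketch;
--   @sgd@ and @acc@ are hypothetical):
--
-- > fitAndEval sgd net (M.fromList [("x", x), ("y", y)]) acc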
fitAndEval :: (DType a, MonadIO m, MonadThrow m, Optimizer opt, OptArgsCst opt g, EvalMetricMethod mth)
           => opt a g -> Symbol a -> M.HashMap String (NDArray a) -> Metric a mth -> TrainM a m ()
fitAndEval opt net datAndLbl metric = do
    exec  <- bind net (M.map Just datAndLbl) True
    preds <- liftIO $ do
        checked $ mxExecutorForward (E.getHandle exec) 1
        checked $ mxExecutorBackward (E.getHandle exec) 0 []
        checked $ mxNDArrayWaitAll
        getOutputs exec
    updateParameters opt datAndLbl
    let labels = map (datAndLbl M.!) (metric ^. metric_labelname)
    liftIO $ zipWithM_ (evaluate metric) preds labels

-- | apply the optimizer to every parameter, excluding those in the given map
updateParameters :: (MonadIO m, Optimizer opt, OptArgsCst opt args)
                 => opt dtype args -> M.HashMap String any -> TrainM dtype m ()
updateParameters opt blacklist = do
    modifyT . traverseOf sess_param $ M.traverseWithKey $ \k v -> do
        if M.member k blacklist
            then return v
            else do new_in <- liftIO $ optimize opt k (_param_in v) (_param_grad v)
                    -- force the new parameter to WHNF; otherwise the old
                    -- _param_in is retained, and if the context is GPU, an
                    -- OOM will soon occur, as described in issue #2.
                    return $! v {_param_in = new_in}

-- | forward only. All placeholders must be provided, with data set to @Just xx@ and labels set to @Nothing@.
-- 
-- Note that the batch size here can be different from that in the training phase.
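--
-- For example, predicting on a test batch (a sketch; the names are
-- illustrative):
--
-- > train sess $ do
-- >     preds <- forwardOnly net $ M.fromList
-- >         [("x", Just testX), ("y", Nothing)]
-- >     ...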
forwardOnly :: (DType a, MonadIO m, MonadThrow m) => Symbol a -> M.HashMap String (Maybe (NDArray a)) -> TrainM a m [NDArray a]
forwardOnly net dat = do
    exec <- bind net dat False
    liftIO $ do
        checked $ mxExecutorForward (E.getHandle exec) 0
        -- wait for completion, for the same reason as in 'fit'.
        checked $ mxNDArrayWaitAll
        getOutputs exec

-- | get the 'Context' of the current training session
getContext :: Monad m => TrainM a m Context
getContext = use sess_context

-- | Possible exception in 'TrainM'
data Exc = MismatchedShape String
    deriving (Show, Typeable)
instance Exception Exc

-- | modify the state within the inner monad
-- 
-- thanks to lens, we can modify the first field of the state with the
-- following combinator:
-- 
-- modifyT . traverseOf _1
--  :: (Field1 s s a b, Monad m) => (a -> m b) -> StateT s m ()
modifyT :: Monad m => (s -> m s) -> ST.StateT s m ()
modifyT func = do
    s0 <- ST.get
    s1 <- ST.lift $ func s0
    ST.put s1