-----------------------------------------------------------
-- |
-- module:                      MXNet.Core.Base.Executor
-- copyright:                   (c) 2016 Tao He
-- license:                     MIT
-- maintainer:                  sighingnow@gmail.com
--
-- Executor module.
--
{-# OPTIONS_GHC -Wno-redundant-constraints #-}

module MXNet.Core.Base.Executor where

import           Control.Monad
import           MXNet.Core.Base.Internal
import           MXNet.Core.Base.DType
import           MXNet.Core.Base.NDArray (NDArray(NDArray))

-- | Executor: a newtype wrapper around a raw 'ExecutorHandle'.
--
-- The type parameter @a@ is phantom (it does not occur in the wrapped field).
-- NOTE(review): presumably @a@ is the element type of the 'NDArray's the
-- executor produces, as suggested by 'getOutputs' -- confirm.
newtype Executor a = Executor { getHandle :: ExecutorHandle }

-- | Wrap a raw executor handle into an 'Executor'.
--
-- This only applies the 'Executor' constructor and lifts the result into
-- 'IO'; no other action is performed here.
makeExecutor :: DType a
             => ExecutorHandle      -- ^ The raw handle to wrap.
             -> IO (Executor a)
makeExecutor handle = pure (Executor handle)

-- | Run the forward pass of an executor.
--
-- Whatever 'mxExecutorForward' returns (presumably a status code -- confirm)
-- is discarded.
forward :: DType a
        => Executor a   -- ^ The executor to run.
        -> Bool         -- ^ Training flag: 'True' is passed on as @1@, 'False' as @0@.
        -> IO ()
forward exec train = void (mxExecutorForward (getHandle exec) trainFlag)
  where
    -- Numeric encoding of the boolean flag expected by the underlying call.
    trainFlag
        | train     = 1
        | otherwise = 0

-- | Run the backward pass of an executor.
--
-- 'mxExecutorBackward' is invoked with @0@ and an empty list as its extra
-- arguments (presumably meaning that no head gradients are supplied --
-- confirm against the binding). Its result is discarded.
backward :: DType a
         => Executor a  -- ^ The executor to run.
         -> IO ()
backward exec = do
    _ <- mxExecutorBackward (getHandle exec) 0 []
    return ()

-- | Fetch the outputs of an executor.
--
-- The first component of the pair returned by 'mxExecutorOutputs' is ignored;
-- every element of the second component is wrapped with the 'NDArray'
-- constructor.
getOutputs :: DType a
           => Executor a        -- ^ The executor to query.
           -> IO [NDArray a]
getOutputs exec =
    mxExecutorOutputs (getHandle exec) >>= \(_, handles) ->
        pure (map NDArray handles)