{-# OPTIONS_GHC -Wno-redundant-constraints #-}
module MXNet.Core.Base.Executor where
import Control.Monad
import MXNet.Core.Base.Internal
import MXNet.Core.Base.DType
import MXNet.Core.Base.NDArray (NDArray(NDArray))
-- | A typed wrapper around a raw MXNet executor handle.
--
-- The type parameter @a@ is phantom: it does not appear in the runtime
-- representation, which is only the underlying 'ExecutorHandle'.
-- NOTE(review): presumably @a@ tags the element dtype of the executor's
-- arrays — confirm.
newtype Executor a = Executor
  { getHandle :: ExecutorHandle -- ^ The underlying raw executor handle.
  }
-- | Wrap a raw 'ExecutorHandle' as a typed 'Executor'.
--
-- No foreign call is made; the handle is stored unchanged.
makeExecutor :: DType a
             => ExecutorHandle  -- ^ Raw executor handle to wrap.
             -> IO (Executor a)
makeExecutor handle = pure (Executor handle)
-- | Run a forward pass on the executor.
--
-- The 'Bool' is passed to @mxExecutorForward@ as @1@ for 'True' and @0@ for
-- 'False'. Presumably this is MXNet's @is_train@ flag — confirm. The result
-- of the underlying call is discarded.
forward :: DType a
        => Executor a -- ^ Executor to run.
        -> Bool       -- ^ Training-mode flag.
        -> IO ()
forward exec train = void (mxExecutorForward handle isTrain)
  where
    handle = getHandle exec
    -- The binding takes a numeric flag rather than a 'Bool'.
    isTrain = case train of
      True  -> 1
      False -> 0
-- | Run a backward pass on the executor.
--
-- Calls @mxExecutorBackward@ with @0@ and an empty list as its extra
-- arguments. NOTE(review): presumably this means no explicit head gradients
-- are supplied — confirm against the binding. The result of the underlying
-- call is discarded.
backward :: DType a
         => Executor a -- ^ Executor to run.
         -> IO ()
backward exec = () <$ mxExecutorBackward (getHandle exec) 0 []
-- | Fetch the executor's output arrays.
--
-- Calls @mxExecutorOutputs@ and ignores the first component of the returned
-- pair (presumably the output count — confirm). Each remaining handle is
-- wrapped in an 'NDArray'.
getOutputs :: DType a
           => Executor a     -- ^ Executor to query.
           -> IO [NDArray a]
getOutputs exec =
  mxExecutorOutputs (getHandle exec) >>= \(_, handles) ->
    return (map NDArray handles)