{-# LANGUAGE TemplateHaskell #-}
module MXNet.NN.Types where
import Control.Lens (makeLenses)
import qualified Data.HashMap.Strict as M
import qualified Control.Monad.State.Strict as ST
import MXNet.Core.Base hiding (bind, context, (^.))
-- | A named model parameter as tracked by the training session: the array
-- holding its value together with the array paired with it for its gradient.
--
-- Unlike the fields of 'Session', these fields are lazy (no bang patterns).
data Parameter a = Parameter
  { _param_in   :: NDArray a -- ^ the parameter's value array
  , _param_grad :: NDArray a -- ^ the gradient array paired with '_param_in'
  }
  -- The deriving clause must be indented under the declaration: at column 0
  -- it would be parsed as a separate (invalid) top-level declaration.
  deriving Show
-- | The state threaded through 'TrainM': every 'Parameter', looked up by a
-- 'String' key, plus the 'Context' stored alongside them.
--
-- Both fields are strict. The 'makeLenses' splice that follows derives the
-- lenses @sess_param@ and @sess_context@ from the underscore-prefixed names.
data Session a = Session
  { _sess_param   :: !(M.HashMap String (Parameter a)) -- ^ parameters by key (presumably parameter names)
  , _sess_context :: !Context                           -- ^ the context held by this session
  }

makeLenses ''Session
-- | The training monad transformer: a strict state transformer
-- (from "Control.Monad.State.Strict") whose state is a 'Session', stacked
-- on top of the base monad @m@.
type TrainM a m =
  ST.StateT (Session a) m
-- | An 'IO' action that produces an 'NDArray' given a 'Context' and a list
-- of 'Int's.
type Initializer a
  =  Context        -- the context to create the array for
  -> [Int]          -- presumably the array shape; TODO confirm against callers
  -> IO (NDArray a)