{-# LANGUAGE TemplateHaskell #-}

module MXNet.NN.Types where



import Control.Lens (makeLenses)

import qualified Data.HashMap.Strict as M

import qualified Control.Monad.State.Strict as ST

import MXNet.Core.Base hiding (bind, context, (^.))



-- | A parameter is a pair of 'NDArray's that back a 'Symbol'.
--
-- A 'Show' instance is derived for it.
--
-- NOTE(review): the fields carry the underscore prefix used for lens
-- generation, but no @makeLenses ''Parameter@ is visible in this module;
-- confirm whether that is intentional.
data Parameter a = Parameter
    { _param_in   :: NDArray a  -- ^ first backing array (presumably the parameter's value; confirm)
    , _param_grad :: NDArray a  -- ^ second backing array (presumably its gradient; confirm)
    }
    deriving Show



-- type Session a = (M.HashMap String (Parameter a), Context)

-- | A 'Session' holds all the 'Parameter's together with a 'Context'.
--
-- Both fields are strict. The 'makeLenses' splice that follows the
-- declaration generates the lenses @sess_param@ and @sess_context@.
data Session a = Session
    { _sess_param   :: !(M.HashMap String (Parameter a))  -- ^ parameters, keyed by a 'String'
    , _sess_context :: !Context                           -- ^ the 'Context' of the session
    }

makeLenses ''Session

-- | 'TrainM' is the 'ST.StateT' monad transformer (the variant from
-- "Control.Monad.State.Strict") with a 'Session' as its state, stacked on
-- top of a base monad @m@.
type TrainM a m
    = ST.StateT (Session a) m



-- | An 'Initializer' describes how to create an 'NDArray' from a given shape.
--
-- It receives a 'Context' and the shape as a list of 'Int', and produces the
-- array in 'IO'. Usually it is a wrapper around MXNet operators such as
-- @random_uniform@, @random_normal@, @random_gamma@, etc.
type Initializer a
    =  Context          -- context handed to the initializer
    -> [Int]            -- shape of the array to create
    -> IO (NDArray a)   -- the freshly created array