-----------------------------------------------------------
-- |
-- module:                      MXNet.Core.Base.Symbol
-- copyright:                   (c) 2016-2017 Tao He
-- license:                     MIT
-- maintainer:                  sighingnow@gmail.com
--
-- Symbol module.
--
{-# OPTIONS_GHC -Wno-missing-methods #-}
{-# OPTIONS_GHC -Wno-redundant-constraints #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeApplications #-}

module MXNet.Core.Base.Symbol where

import Control.Exception (assert, throw)
import Control.Monad
import Data.HashMap.Strict (HashMap)
import qualified Data.HashMap.Strict as HM
import Data.IORef
import Data.Monoid
import Foreign.Ptr (nullPtr)
import System.IO.Unsafe
import Unsafe.Coerce (unsafeCoerce)

import MXNet.Core.Base.DType
import MXNet.Core.Base.Internal
import qualified MXNet.Core.Base.Internal.TH.Symbol as I
import MXNet.Core.Base.HMap
import MXNet.Core.Base.Executor hiding (getHandle)
import MXNet.Core.Base.NDArray hiding (getHandle)
import qualified MXNet.Core.Base.NDArray as NDArray (getHandle)

-- | Symbolic variable: a wrapper around a raw 'SymbolHandle'. The type
-- parameter @a@ is phantom, it does not occur in the wrapped handle.
newtype Symbol a = Symbol { getHandle :: SymbolHandle }

-- | Rendering goes through @mxSymbolPrint@, hence the 'unsafePerformIO'.
instance DType a => Show (Symbol a) where
    show = unsafePerformIO . fmap snd . mxSymbolPrint . getHandle

-- | Make a new symbolic variable with given name.
variable :: DType a
         => String          -- ^ Name.
         -> IO (Symbol a)   -- ^ Result variable.
variable name = Symbol . snd <$> mxSymbolCreateVariable name

-- | Get the name of a given variable.
getName :: DType a => Symbol a -> IO String
getName sym = do
    (_, nm, _) <- mxSymbolGetName (getHandle sym)
    return nm

-- | Get specified attribute of symbol.
--
-- Returns 'Nothing' when the symbol carries no attribute under the given key.
getAttr :: DType a => Symbol a -> String -> IO (Maybe String)
getAttr sym key = do
    (_, s, found) <- mxSymbolGetAttr (getHandle sym) key
    -- The MXNet C API sets the @success@ out-parameter of @MXSymbolGetAttr@
    -- to 1 when the attribute exists and to 0 otherwise. The "0 on success,
    -- -1 on failure" convention applies to the function's return code, which
    -- is assumed to be the first (ignored) tuple component here.
    -- NOTE(review): confirm the tuple layout against the raw binding.
    return $ if found /= 0
                then Just s
                else Nothing

-- | Set specified attribute of symbol.
setAttr :: DType a => Symbol a -> String -> String -> IO ()
setAttr sym key value = void $ mxSymbolSetAttr (getHandle sym) key value

-- | Infer the shape of the given symbol, return the in, out and auxiliary shape size.
infershape :: DType a => Symbol a -> [String] -> IO ([[Int]], [[Int]], [[Int]])
infershape sym args = do
    (_, arg, out, aux) <- mxSymbolInferShape (getHandle sym) args [0] []
    return (arg, out, aux)

-- | Get the autodiff of current symbol.
-- This function can only be used if current symbol is a loss function.
grad :: DType a => Symbol a -> [String] -> IO (Symbol a)
grad sym args = do
    let nargs = fromIntegral (length args)
    (_, handle) <- mxSymbolGrad (getHandle sym) nargs args
    return $ Symbol handle

-- | Bind with explicit argument mapping (name -- value mapping).
--
-- The arrays are handed to @mxExecutorBind@ in the order reported by
-- 'listInputs'. They are deliberately kept in a list: the element order of a
-- 'HashMap' follows key hashes, not the argument order of the symbol.
--
-- Raises an 'IOError' (built with 'userError') if an argument listed by
-- 'listInputs' is missing from the given mapping.
bind :: DType a => Symbol a -> Context -> HashMap String (NDArray a) -> IO (Executor a)
bind sym Context{..} args = do
    names <- listInputs sym
    arrays <- mapM lookupArg names
    let nargs = length arrays
        -- req_map = {'null': 0, 'write': 1, 'add': 3}
        req_types = replicate nargs 1       -- use default value.
    (_, exec) <- mxExecutorBind (getHandle sym)
                                deviceType
                                deviceId
                                (fromIntegral nargs)                        -- length of input arguments.
                                (NDArray.getHandle <$> arrays)
                                (replicate nargs (unsafeCoerce nullPtr))
                                req_types
                                0                                           -- length of auxiliary states.
                                []                                          -- no auxiliary states.
    return $ Executor exec
  where
    -- Look up the ndarray bound to one argument name.
    lookupArg nm = case HM.lookup nm args of
        Just v -> return v
        Nothing -> throw . userError $ "getNDArrayInputs: no argument " <> nm

-- | Bind without explicit argument mapping (name -- value mapping).
--
-- The arrays are matched positionally against 'listInputs' and passed on in
-- exactly the order given by the caller.
--
-- Raises an 'IOError' (built with 'userError') when the number of arrays
-- differs from the number of arguments of the symbol. The check is explicit
-- because an @assert@ is compiled out when assertions are disabled.
bind' :: DType a => Symbol a -> Context -> [NDArray a] -> IO (Executor a)
bind' sym Context{..} args = do
    names <- listInputs sym
    let nargs = length args
    when (nargs /= length names) $
        throw . userError $ "bind': expected " <> show (length names)
                         <> " arguments, but got " <> show nargs
    -- req_map = {'null': 0, 'write': 1, 'add': 3}
    let req_types = replicate nargs 1       -- use default value.
    (_, exec) <- mxExecutorBind (getHandle sym)
                                deviceType
                                deviceId
                                (fromIntegral nargs)                        -- length of input arguments.
                                (NDArray.getHandle <$> args)
                                (replicate nargs (unsafeCoerce nullPtr))
                                req_types
                                0                                           -- length of auxiliary states.
                                []                                          -- no auxiliary states.
    return $ Executor exec

-- | List all input arguments.
listInputs :: DType a => Symbol a -> IO [String]
listInputs sym = snd <$> mxSymbolListArguments (getHandle sym)

-- | List all output results.
listOutputs :: DType a => Symbol a -> IO [String]
listOutputs sym = snd <$> mxSymbolListOutputs (getHandle sym)

-- | List all auxiliary states.
listAuxiliaries :: DType a => Symbol a -> IO [String]
listAuxiliaries sym = snd <$> mxSymbolListAuxiliaryStates (getHandle sym)

-- | Arithmetic on symbols. Every operator creates a new symbol whose name is
-- derived from the operand names. 'signum' and 'fromInteger' are unsupported.
instance DType a => Num (Symbol a) where
    (+) sym1 sym2 = Symbol . unsafePerformIO $ do
        name1 <- getName sym1
        name2 <- getName sym2
        I._Plus (name1 <> "+" <> name2) (getHandle sym1) (getHandle sym2)
    (-) sym1 sym2 = Symbol . unsafePerformIO $ do
        name1 <- getName sym1
        name2 <- getName sym2
        I._Minus (name1 <> "-" <> name2) (getHandle sym1) (getHandle sym2)
    (*) sym1 sym2 = Symbol . unsafePerformIO $ do
        name1 <- getName sym1
        name2 <- getName sym2
        I._Mul (name1 <> "*" <> name2) (getHandle sym1) (getHandle sym2)
    abs sym1 = Symbol . unsafePerformIO $ do
        name1 <- getName sym1
        I.abs ("|" <> name1 <> "|") (getHandle sym1)
    negate sym1 = Symbol . unsafePerformIO $ do
        name1 <- getName sym1
        I.negative ("(-" <> name1 <> ")") (getHandle sym1)
    signum = error "Unsupported operator: signum(Symbol)"
    fromInteger = error "Unsupported operator: fromInteger(Symbol)"

-- | Division on symbols. 'fromRational' is unsupported.
instance DType a => Fractional (Symbol a) where
    (/) sym1 sym2 = Symbol . unsafePerformIO $ do
        name1 <- getName sym1
        name2 <- getName sym2
        I._Div (name1 <> "/" <> name2) (getHandle sym1) (getHandle sym2)
    fromRational = error "Unsupported operator: fromRational(Symbol)"

-- | Elementary functions on symbols; only the methods below are defined.
instance DType a => Floating (Symbol a) where
    exp sym1 = Symbol . unsafePerformIO $ do
        name1 <- getName sym1
        I.exp ("exp(" <> name1 <> ")") (getHandle sym1)
    log sym1 = Symbol . unsafePerformIO $ do
        name1 <- getName sym1
        I.log ("log(" <> name1 <> ")") (getHandle sym1)
    sqrt sym1 = Symbol . unsafePerformIO $ do
        name1 <- getName sym1
        I.sqrt ("sqrt(" <> name1 <> ")") (getHandle sym1)
    sin sym1 = Symbol . unsafePerformIO $ do
        name1 <- getName sym1
        I.sin ("sin(" <> name1 <> ")") (getHandle sym1)
    cos sym1 = Symbol . unsafePerformIO $ do
        name1 <- getName sym1
        I.cos ("cos(" <> name1 <> ")") (getHandle sym1)
    tan sym1 = Symbol . unsafePerformIO $ do
        name1 <- getName sym1
        I.tan ("tan(" <> name1 <> ")") (getHandle sym1)
    sinh sym1 = Symbol . unsafePerformIO $ do
        name1 <- getName sym1
        I.sinh ("sinh(" <> name1 <> ")") (getHandle sym1)
    cosh sym1 = Symbol . unsafePerformIO $ do
        name1 <- getName sym1
        I.cosh ("cosh(" <> name1 <> ")") (getHandle sym1)
    tanh sym1 = Symbol . unsafePerformIO $ do
        name1 <- getName sym1
        I.tanh ("tanh(" <> name1 <> ")") (getHandle sym1)

-- | Tensor operations on symbols. Each method builds a new symbol named after
-- the operation and the names of its operands.
instance Tensor Symbol where
    dot sym1 sym2 = Symbol <$> do
        name1 <- getName sym1
        name2 <- getName sym2
        I.dot ("dot(" <> name1 <> "," <> name2 <> ")") (getHandle sym1) (getHandle sym2) nil
    reshape sym sh = Symbol <$> do
        -- Drop the first and last character of @show sh@ and wrap the rest in
        -- parentheses (assumes @show sh@ is bracketed, e.g. a list).
        let sh' = "(" <> (init . tail . show $ sh) <> ")"
        name1 <- getName sym
        I.reshape ("reshape(" <> name1 <> "," <> sh' <> ")") (getHandle sym) (add @"shape" sh' nil)
    transpose sym = Symbol <$> do
        name1 <- getName sym
        I.transpose ("transpose(" <> name1 <> ")") (getHandle sym) nil
    (+.) sym1 sym2 = Symbol <$> do
        name1 <- getName sym1
        name2 <- getName sym2
        I._Plus (name1 <> "+" <> name2) (getHandle sym1) (getHandle sym2)
    (-.) sym1 sym2 = Symbol <$> do
        name1 <- getName sym1
        name2 <- getName sym2
        I._Minus (name1 <> "-" <> name2) (getHandle sym1) (getHandle sym2)
    (*.) sym1 sym2 = Symbol <$> do
        name1 <- getName sym1
        name2 <- getName sym2
        I._Mul (name1 <> "*" <> name2) (getHandle sym1) (getHandle sym2)
    (/.) sym1 sym2 = Symbol <$> do
        name1 <- getName sym1
        name2 <- getName sym2
        -- The name uses "/" so that it differs from the name built by '(*.)'.
        I._Div (name1 <> "/" <> name2) (getHandle sym1) (getHandle sym2)
    (^.) sym1 sym2 = Symbol <$> do
        name1 <- getName sym1
        name2 <- getName sym2
        -- The name uses "^" so that it differs from the name built by '(*.)'.
        I._Power (name1 <> "^" <> name2) (getHandle sym1) (getHandle sym2)
    (.+) sym value = Symbol <$> do
        name1 <- getName sym
        I._PlusScalar (name1 <> "+" <> show value) (getHandle sym) (realToFrac value)
    {-# INLINE (.+) #-}
    (.-) sym value = Symbol <$> do
        name1 <- getName sym
        I._MinusScalar (name1 <> "-" <> show value) (getHandle sym) (realToFrac value)
    {-# INLINE (.-) #-}
    (.*) sym value = Symbol <$> do
        name1 <- getName sym
        I._MulScalar (name1 <> "*" <> show value) (getHandle sym) (realToFrac value)
    {-# INLINE (.*) #-}
    (./) sym value = Symbol <$> do
        name1 <- getName sym
        I._DivScalar (name1 <> "/" <> show value) (getHandle sym) (realToFrac value)
    {-# INLINE (./) #-}
    (.^) sym value = Symbol <$> do
        name1 <- getName sym
        I._PowerScalar (name1 <> "^" <> show value) (getHandle sym) (realToFrac value)
    {-# INLINE (.^) #-}
    (..-) value sym = Symbol <$> do
        name1 <- getName sym
        I._RMinusScalar (show value <> "-" <> name1) (getHandle sym) (realToFrac value)
    {-# INLINE (..-) #-}
    (../) value sym = Symbol <$> do
        name1 <- getName sym
        I._RDivScalar (show value <> "/" <> name1) (getHandle sym) (realToFrac value)
    {-# INLINE (../) #-}
    (..^) value sym = Symbol <$> do
        name1 <- getName sym
        I._RPowerScalar (show value <> "^" <> name1) (getHandle sym) (realToFrac value)
    {-# INLINE (..^) #-}

    _Maximum sym1 sym2 = Symbol <$> do
        name1 <- getName sym1
        name2 <- getName sym2
        I._Maximum ("_Maximum(" <> name1 <> "," <> name2 <> ")") (getHandle sym1) (getHandle sym2)
    {-# INLINE _Maximum #-}
    _Maximum' sym scalar = Symbol <$> do
        name1 <- getName sym
        I._MaximumScalar ("_Maximum'(" <> name1 <> "," <> show scalar <> ")") (getHandle sym) (realToFrac scalar)
    {-# INLINE _Maximum' #-}
    _Minimum sym1 sym2 = Symbol <$> do
        name1 <- getName sym1
        name2 <- getName sym2
        I._Minimum ("_Minimum(" <> name1 <> "," <> name2 <> ")") (getHandle sym1) (getHandle sym2)
    {-# INLINE _Minimum #-}
    _Minimum' sym scalar = Symbol <$> do
        name1 <- getName sym
        I._MinimumScalar ("_Minimum'(" <> name1 <> "," <> show scalar <> ")") (getHandle sym) (realToFrac scalar)
    {-# INLINE _Minimum' #-}
    equal sym1 sym2 = Symbol <$> do
        name1 <- getName sym1
        name2 <- getName sym2
        I.broadcast_equal (name1 <> "==" <> name2) (getHandle sym1) (getHandle sym2)
    {-# INLINE equal #-}
    equal' sym scalar = Symbol <$> do
        name1 <- getName sym
        I._equal_scalar (name1 <> "==" <> show scalar) (getHandle sym) (realToFrac scalar)
    {-# INLINE equal' #-}
    notEqual sym1 sym2 = Symbol <$> do
        name1 <- getName sym1
        name2 <- getName sym2
        I.broadcast_not_equal (name1 <> "/=" <> name2) (getHandle sym1) (getHandle sym2)
    {-# INLINE notEqual #-}
    notEqual' sym scalar = Symbol <$> do
        name1 <- getName sym
        I._not_equal_scalar (name1 <> "/=" <> show scalar) (getHandle sym) (realToFrac scalar)
    {-# INLINE notEqual' #-}
    greater sym1 sym2 = Symbol <$> do
        name1 <- getName sym1
        name2 <- getName sym2
        I.broadcast_greater (name1 <> ">" <> name2) (getHandle sym1) (getHandle sym2)
    {-# INLINE greater #-}
    greater' sym scalar = Symbol <$> do
        name1 <- getName sym
        I._greater_scalar (name1 <> ">" <> show scalar) (getHandle sym) (realToFrac scalar)
    {-# INLINE greater' #-}
    greaterEqual sym1 sym2 = Symbol <$> do
        name1 <- getName sym1
        name2 <- getName sym2
        I.broadcast_greater_equal (name1 <> ">=" <> name2) (getHandle sym1) (getHandle sym2)
    {-# INLINE greaterEqual #-}
    greaterEqual' sym scalar = Symbol <$> do
        name1 <- getName sym
        I._greater_equal_scalar (name1 <> ">=" <> show scalar) (getHandle sym) (realToFrac scalar)
    {-# INLINE greaterEqual' #-}
    lesser sym1 sym2 = Symbol <$> do
        name1 <- getName sym1
        name2 <- getName sym2
        I.broadcast_lesser (name1 <> "<" <> name2) (getHandle sym1) (getHandle sym2)
    {-# INLINE lesser #-}
    lesser' sym scalar = Symbol <$> do
        name1 <- getName sym
        I._lesser_scalar (name1 <> "<" <> show scalar) (getHandle sym) (realToFrac scalar)
    {-# INLINE lesser' #-}
    lesserEqual sym1 sym2 = Symbol <$> do
        name1 <- getName sym1
        name2 <- getName sym2
        I.broadcast_lesser_equal (name1 <> "<=" <> name2) (getHandle sym1) (getHandle sym2)
    {-# INLINE lesserEqual #-}
    lesserEqual' sym scalar = Symbol <$> do
        name1 <- getName sym
        I._lesser_equal_scalar (name1 <> "<=" <> show scalar) (getHandle sym) (realToFrac scalar)
    {-# INLINE lesserEqual' #-}

-- | Provide a globally unique serial ID for each symbol.
--
-- The NOINLINE pragma keeps this a single shared reference: without it GHC
-- may inline the 'unsafePerformIO' call and create several counters.
symid :: IORef Int
symid = unsafePerformIO (newIORef 0)
{-# NOINLINE symid #-}

-- | Generate a globally unique name for each symbol, thread safely.
--
-- The result is the prefix followed by the counter value before the
-- increment. The strict modify avoids piling up @(+1)@ thunks in the counter.
naming :: String -> IO String
naming prefix = ((prefix <>) . show) <$> atomicModifyIORef' symid (\a -> (a+1, a))

-- | Neural network layers on symbols. Each layer symbol gets a fresh name
-- from 'naming', prefixed with the layer kind.
instance Neural Symbol where
    fullyConnected input weight bias n = Symbol <$> do
        name <- naming "FullyConnected"
        I.fullyconnected name (getHandle input) (getHandle weight) (getHandle bias) n nil
    correlation input1 input2 = Symbol <$> do
        name <- naming "Correlation"
        I.correlation name (getHandle input1) (getHandle input2) nil
    activation input act = Symbol <$> do
        name <- naming "Activation"
        I.activation name (getHandle input) act
    leakyReLU input act = Symbol <$> do
        name <- naming "LeakyReLU"
        I.leakyrelu name (getHandle input) (add @"act_type" act nil)
    softmaxActivation input = Symbol <$> do
        name <- naming "SoftmaxActivation"
        I.softmaxactivation name (getHandle input) nil
    dropout input p = Symbol <$> do
        name <- naming "Dropout"
        I.dropout name (getHandle input) (add @"p" p nil)
    batchNorm input gm bt mm mv = Symbol <$> do
        name <- naming "BatchNorm"
        I.batchnorm name (getHandle input) (getHandle gm) (getHandle bt) (getHandle mm) (getHandle mv) nil
    instanceNorm input gamma beta eps = Symbol <$> do
        name <- naming "InstanceNorm"
        I.instancenorm name (getHandle input) (getHandle gamma) (getHandle beta) (add @"eps" eps nil)
    l2Normalization input eps mode = Symbol <$> do
        name <- naming "L2Normalization"
        I.l2normalization name (getHandle input) (add @"eps" eps $ add @"mode" mode nil)
    convolution input weight bias kernel n = Symbol <$> do
        name <- naming "Convolution"
        I.convolution name (getHandle input) (getHandle weight) (getHandle bias) kernel n nil
    lrn input alpha beta knorm nsize = Symbol <$> do
        name <- naming "LRN"
        I.lrn name (getHandle input) nsize (add @"alpha" alpha $ add @"beta" beta $ add @"knorm" knorm nil)
    deconvolution input weight bias kernel nfilter = Symbol <$> do
        name <- naming "Deconvolution"
        I.deconvolution name (getHandle input) (getHandle weight) (getHandle bias) kernel nfilter nil
    pooling input kernel pooltype = Symbol <$> do
        name <- naming "Pooling"
        I.pooling name (getHandle input) kernel pooltype nil
    softmaxOutput input label = Symbol <$> do
        name <- naming "SoftmaxOutput"
        I.softmaxoutput name (getHandle input) (getHandle label) nil
    makeLoss input grad_scale valid_thresh normalization = Symbol <$> do
        name <- naming "MakeLoss"
        I.makeloss name (getHandle input) (add @"grad_scale" grad_scale $ add @"valid_thresh" valid_thresh $ add @"normalization" normalization nil)
    blockGrad input = Symbol <$> do
        name <- naming "BlockGrad"
        I.blockgrad name (getHandle input)
    custom input op = Symbol <$> do
        name <- naming "Custom"
        I.custom name (map getHandle input) op