-- | NDArray handles and device contexts for the MXNet bindings.
--
-- Exports the 'NDArray' handle type, the 'Context' device descriptor with
-- a few predefined contexts, a constructor that allocates an array of a
-- given shape, and a query for an array's shape.
module MXNet.Core.NDArray (
NDArray
, Context
, makeNDArray
, getNDArrayShape
, defaultContext
, contextCPU
, contextGPU
) where
import MXNet.Core.Base
-- | Handle to an MXNet NDArray. This is a plain alias for 'NDArrayHandle'
-- (brought into scope by the "MXNet.Core.Base" import), so values flow to
-- and from the low-level API without any wrapping or conversion.
type NDArray = NDArrayHandle
-- | A device descriptor: which kind of device, and which one of that kind.
--
-- Both fields are raw integers handed straight to the low-level API
-- (see 'makeNDArray'); the predefined 'contextCPU' and 'contextGPU' use
-- type codes 1 and 2 respectively.
data Context = Context
  { deviceType :: Int  -- ^ numeric device-type code
  , deviceId   :: Int  -- ^ index of the device within that type
  } deriving (Eq, Show)
-- | The default device context: the CPU device with index 0.
--
-- Defined in terms of 'contextCPU' so the two always agree. A CPU context
-- with device index 1 (which this used to be) does not match the CPU
-- context used elsewhere in this module.
defaultContext :: Context
defaultContext = contextCPU
-- | The CPU device: device-type code 1, device index 0.
contextCPU :: Context
contextCPU = Context { deviceType = 1, deviceId = 0 }
-- | The first GPU device: device-type code 2, device index 0.
contextGPU :: Context
contextGPU = Context { deviceType = 2, deviceId = 0 }
-- | Allocate a new NDArray of the given shape on 'contextCPU'.
--
-- The shape list and its length are converted with 'fromIntegral' to
-- whatever integral types 'mxNDArrayCreate' expects.
--
-- NOTE(review): the first component of the 'mxNDArrayCreate' result is
-- discarded. Presumably it is the C API status code; if so, a failed
-- allocation goes unnoticed here. Confirm and consider checking it. The
-- literal @0@ argument presumably is MXNet's @delay_alloc@ flag. TODO
-- confirm against "MXNet.Core.Base".
makeNDArray :: [Int]       -- ^ size of each dimension
            -> IO NDArray
makeNDArray dims = do
    (_, arr) <- mxNDArrayCreate (map fromIntegral dims) ndim devType devId 0
    return arr
  where
    ndim    = fromIntegral (length dims)
    devType = deviceType contextCPU
    devId   = deviceId contextCPU
-- | Query the shape of an NDArray.
--
-- Returns the count reported by 'mxNDArrayGetShape' (presumably the number
-- of dimensions) together with the size of each dimension, both converted
-- to 'Int'.
--
-- NOTE(review): the first component of the 'mxNDArrayGetShape' result is
-- discarded. Presumably it is the C API status code. Confirm.
getNDArrayShape :: NDArray           -- ^ array to inspect
                -> IO (Int, [Int])   -- ^ (dimension count, dimension sizes)
getNDArrayShape array = do
    (_, rank, dims) <- mxNDArrayGetShape array
    let rank' = fromIntegral rank
        dims' = map fromIntegral dims
    return (rank', dims')