-----------------------------------------------------------
-- |
-- module:      MXNet.Core.NDArray
-- copyright:   (c) 2016 Tao He
-- license:     MIT
-- maintainer:  sighingnow@gmail.com
--
-- NDArray module: creating NDArrays, querying their shapes, and the
-- device contexts they can live on.
--
module MXNet.Core.NDArray
    ( -- * Data type definitions
      NDArray
    , Context
      -- * Functions about NDArray
    , makeNDArray
    , getNDArrayShape
      -- * Default contexts
    , defaultContext
    , contextCPU
    , contextGPU
    ) where

import MXNet.Core.Base

-- | NDArray type alias for the raw handle exposed by "MXNet.Core.Base".
type NDArray = NDArrayHandle

-- | Context definition: a device type together with a device index.
--
-- * DeviceType
--
--     1. cpu
--     2. gpu
--     3. cpu_pinned
data Context = Context
    { deviceType :: Int -- ^ Device type code (see the list above).
    , deviceId   :: Int -- ^ Index of the device of that type.
    } deriving (Eq, Show)

-- | Default context: CPU device 0, i.e. the same value as 'contextCPU'.
--
-- (Previously this set @deviceId = 1@, which contradicted its own
-- documentation of "CPU 0" and the value of 'contextCPU'.)
defaultContext :: Context
defaultContext = contextCPU

-- | Context for CPU 0.
contextCPU :: Context
contextCPU = Context 1 0

-- | Context for GPU 0.
contextGPU :: Context
contextGPU = Context 2 0

-- | Make a new NDArray with the given shape, allocated on 'contextCPU'.
makeNDArray :: [Int]        -- ^ Size of each dimension.
            -> IO NDArray
makeNDArray shape = do
    let shape' = fromIntegral <$> shape
        nlen   = fromIntegral . length $ shape
    -- NOTE(review): the final argument 0 presumably maps to MXNet's
    -- @delay_alloc@ flag of MXNDArrayCreate — confirm against the binding.
    -- NOTE(review): the first result component (presumably the C API status
    -- code) is discarded, so a failed creation is not reported here.
    (_, handle) <- mxNDArrayCreate shape' nlen
                                   (deviceType contextCPU)
                                   (deviceId contextCPU)
                                   0
    return handle

-- | Get the shape of the given NDArray.
getNDArrayShape :: NDArray
                -> IO (Int, [Int]) -- ^ Number of dimensions, and the size
                                   --   of each dimension.
getNDArrayShape array = do
    -- NOTE(review): as in 'makeNDArray', the first result component
    -- (presumably a status code) is ignored.
    (_, nlen, shape) <- mxNDArrayGetShape array
    return (fromIntegral nlen, fromIntegral <$> shape)