{-# OPTIONS_GHC -Wno-missing-methods #-}
{-# OPTIONS_GHC -Wno-redundant-constraints #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module MXNet.Core.Base.NDArray where
import Control.Monad
import Data.Int
import Data.Monoid
import Data.Vector.Storable (Vector)
import qualified Data.Vector.Storable as V
import Foreign.Marshal.Alloc (alloca)
import Foreign.Marshal.Array (allocaArray, peekArray)
import Foreign.Ptr
import GHC.Exts (IsList(..))
import Text.PrettyPrint.Annotated.HughesPJClass (Pretty(..), prettyShow)
import System.IO.Unsafe (unsafePerformIO)
import MXNet.Core.Base.DType
import MXNet.Core.Base.Internal
import qualified MXNet.Core.Base.Internal.TH.NDArray as I
import MXNet.Core.Base.HMap
-- | An MXNet n-dimensional array, represented only by its raw
-- 'NDArrayHandle'.
--
-- The type parameter @a@ does not appear in the representation; it is a
-- phantom used to select 'DType'-dependent behaviour.
newtype NDArray a = NDArray
    { getHandle :: NDArrayHandle
    }
-- | Call 'mxNDArrayWaitAll' and discard its status result.
waitAll :: IO ()
waitAll = () <$ mxNDArrayWaitAll
-- | Allocate an uninitialised array of the given shape on a context.
--
-- The dtype passed to 'mxNDArrayCreateEx' is 'typeid' of the element type
-- @a@. The 'Bool' is forwarded as the delay flag (1 when 'True', 0 when
-- 'False'). The status code returned by the C call is ignored.
makeEmptyNDArray :: forall a. DType a
                 => [Int]        -- ^ shape
                 -> Context      -- ^ device context
                 -> Bool         -- ^ delayed allocation flag
                 -> IO (NDArray a)
makeEmptyNDArray sh ctx delayed = do
    (_, created) <- mxNDArrayCreateEx dims ndim (deviceType ctx) (deviceId ctx) delayFlag (typeid (undefined :: a))
    return (NDArray created)
  where
    dims = map fromIntegral sh
    ndim = fromIntegral (length sh)
    delayFlag = if delayed then 1 else 0
-- | Create an array of the given shape on a context and fill it with the
-- contents of a host vector.
--
-- Allocation goes through 'makeEmptyNDArray' (not delayed), which passes
-- 'typeid' of the element type @a@ to 'mxNDArrayCreateEx'. The previous
-- 'mxNDArrayCreate' call took no dtype argument, so the created array's
-- element type could disagree with the vector being copied into it.
--
-- NOTE(review): the vector length is passed as the copy size unchecked
-- against the product of the shape, and the status code of
-- 'mxNDArraySyncCopyFromCPU' is discarded.
makeNDArray :: DType a
            => [Int]      -- ^ shape
            -> Context    -- ^ device context
            -> Vector a   -- ^ element data to copy in
            -> IO (NDArray a)
makeNDArray sh ctx ds = do
    arr <- makeEmptyNDArray sh ctx False
    V.unsafeWith ds $ \p ->
        void $ mxNDArraySyncCopyFromCPU (getHandle arr) (castPtr p) (fromIntegral (V.length ds))
    return arr
-- | Query an array's shape.
--
-- Returns the number of dimensions reported by 'mxNDArrayGetShape'
-- together with the list of dimension sizes, both converted to 'Int'.
ndshape :: DType a
        => NDArray a
        -> IO (Int, [Int])
ndshape arr = do
    (_, ndim, dims) <- mxNDArrayGetShape (getHandle arr)
    return (fromIntegral ndim, map fromIntegral dims)
-- | Total number of elements: the product of the dimensions returned by
-- 'ndshape' (1 for an empty dimension list).
ndsize :: DType a
       => NDArray a
       -> IO Int
ndsize = fmap (product . snd) . ndshape
-- | Build the 'Context' (device type and id) reported by
-- 'mxNDArrayGetContext' for an array.
context :: DType a => NDArray a -> IO Context
context arr = do
    (_, devType, devId) <- mxNDArrayGetContext (getHandle arr)
    return (Context devType devId)
-- | Run the @_copy@ operator on an array and wrap the resulting handle.
copy :: DType a => NDArray a -> IO (NDArray a)
copy = fmap NDArray . I._copy . getHandle
-- | Copy every element of an array into a host-side storable 'Vector'.
--
-- The element count is taken from 'ndsize'. A temporary buffer of that many
-- elements of type @a@ is allocated, filled by 'mxNDArraySyncCopyToCPU', and
-- then read back. The previous version used 'alloca', which reserves room
-- for a single value only, so copying @nlen@ elements overran it.
--
-- NOTE(review): the status code of 'mxNDArraySyncCopyToCPU' is discarded.
items :: DType a => NDArray a -> IO (Vector a)
items arr = do
    nlen <- ndsize arr
    -- The buffer's element type is fixed to @a@ by 'peekArray'/'V.fromList'.
    allocaArray nlen $ \buf -> do
        _ <- mxNDArraySyncCopyToCPU (getHandle arr) (castPtr buf) (fromIntegral nlen)
        V.fromList <$> peekArray nlen buf
-- | Take the sub-array between @start@ and @end@ via 'mxNDArraySlice'.
--
-- Exposed as a pure function by running the C call in 'unsafePerformIO';
-- the status code is ignored.
slice :: DType a
      => NDArray a
      -> Int          -- ^ start
      -> Int          -- ^ end
      -> NDArray a
slice arr start end = unsafePerformIO $ do
    (_, sliced) <- mxNDArraySlice (getHandle arr) (fromIntegral start) (fromIntegral end)
    return (NDArray sliced)
-- | Index an array with 'mxNDArrayAt' and wrap the returned handle.
--
-- Exposed as a pure function via 'unsafePerformIO'; the status code is
-- ignored.
at :: DType a
   => NDArray a
   -> Int          -- ^ index
   -> NDArray a
at arr idx = unsafePerformIO $ do
    (_, element) <- mxNDArrayAt (getHandle arr) (fromIntegral idx)
    return (NDArray element)
-- | Call 'mxNDArrayWaitToRead' on the array's handle, discarding the status.
waitToRead :: DType a => NDArray a -> IO ()
waitToRead arr = () <$ mxNDArrayWaitToRead (getHandle arr)
-- | Run the primed @_onehot_encode'@ binding with @indices@ and @out@ as
-- inputs and @out@'s own handle as the output list (presumably writing the
-- encoding into @out@ — confirm against the generated binding).
onehotEncode :: DType a
             => NDArray a   -- ^ indices
             -> NDArray a   -- ^ output array
             -> IO (NDArray a)
onehotEncode indices out = NDArray <$> I._onehot_encode' idxHandle outHandle [outHandle]
  where
    idxHandle = getHandle indices
    outHandle = getHandle out
-- | A CPU array of the given shape filled with 0 (see 'full').
zeros :: DType a
      => [Int]
      -> IO (NDArray a)
zeros = flip full 0
-- | A CPU array of the given shape filled with 1 (see 'full').
ones :: DType a
     => [Int]
     -> IO (NDArray a)
ones = flip full 1
-- | A CPU array of the given shape with every element set to @value@.
full :: DType a
     => [Int]   -- ^ shape
     -> a       -- ^ fill value
     -> IO (NDArray a)
full sh value = makeNDArray sh contextCPU filled
  where
    filled = V.replicate (product sh) value
-- | Build a CPU array of the given shape from a host vector
-- (a 'makeNDArray' call with 'contextCPU').
array :: DType a
      => [Int]      -- ^ shape
      -> Vector a   -- ^ element data
      -> IO (NDArray a)
array sh ds = makeNDArray sh contextCPU ds
-- | Approximate equality for floating-point arrays.
--
-- Arrays with different shapes are unequal. Otherwise they are equal when
-- the mask produced by @'lesser' (abs (arr1 - arr2)) tolerance@, with
-- tolerance an array of 0.0001 of the same shape, is 1 everywhere.
-- Evaluated through 'unsafePerformIO'.
instance {-# OVERLAPPABLE #-} (DType a, Floating a) => Eq (NDArray a) where
    arr1 == arr2 = unsafePerformIO $ do
        (_, shape1) <- ndshape arr1
        (_, shape2) <- ndshape arr2
        if shape1 /= shape2
            then return False
            else do
                tolerance <- full shape1 0.0001
                mask <- abs (arr1 - arr2) `lesser` tolerance
                V.all (== fromIntegral (1 :: Int)) <$> items mask
-- | Exact equality for 'Int8' arrays.
--
-- Arrays with different shapes are unequal. This check is done first
-- because @broadcast_equal@ broadcasts its operands, so e.g. a @[3]@ array
-- and a @[1]@ array holding the same value would otherwise compare equal.
-- Otherwise the arrays are equal when the @broadcast_equal@ mask is 1
-- everywhere. Evaluated through 'unsafePerformIO'.
instance (DType a, a ~ Int8) => Eq (NDArray Int8) where
    (==) arr1 arr2 = unsafePerformIO $ do
        (_, sh1) <- ndshape arr1
        (_, sh2) <- ndshape arr2
        if sh1 /= sh2
            then return False
            else do
                let cmp = V.all (== fromIntegral (1 :: Int)) :: Vector a -> Bool
                mask <- I.broadcast_equal (getHandle arr1) (getHandle arr2)
                cmp <$> items (NDArray mask)
-- | Exact equality for 'Int32' arrays.
--
-- Arrays with different shapes are unequal. This check is done first
-- because @broadcast_equal@ broadcasts its operands, so broadcast-compatible
-- but differently-shaped arrays could otherwise compare equal. Otherwise
-- the arrays are equal when the @broadcast_equal@ mask is 1 everywhere.
-- Evaluated through 'unsafePerformIO'.
instance (DType a, a ~ Int32) => Eq (NDArray Int32) where
    (==) arr1 arr2 = unsafePerformIO $ do
        (_, sh1) <- ndshape arr1
        (_, sh2) <- ndshape arr2
        if sh1 /= sh2
            then return False
            else do
                let cmp = V.all (== fromIntegral (1 :: Int)) :: Vector a -> Bool
                mask <- I.broadcast_equal (getHandle arr1) (getHandle arr2)
                cmp <$> items (NDArray mask)
-- | Existential box around any 'Pretty' value, allowing nested structures
-- of different depths to be rendered uniformly.
data PrettyWrapper = forall a. Pretty a => MkPretty { runPretty :: a }

-- | Pretty-print the boxed value with its own 'pPrint'.
instance Pretty PrettyWrapper where
    pPrint wrapped = case wrapped of
        MkPretty inner -> pPrint inner
-- | Render an array as @NDArray <shape>@, a newline, then its elements
-- pretty-printed as nested lists (one nesting level per dimension).
--
-- Shape and elements are fetched inside 'unsafePerformIO'. The flat vector
-- from 'items' is split so that the stride of a leading dimension is the
-- product of the remaining dimensions (last dimension varies fastest).
-- An empty dimension list reaches the 'error' branch.
instance (DType a, Pretty a) => Show (NDArray a) where
    show arr = unsafePerformIO $ do
        (_, shape) <- ndshape arr
        flat <- items arr
        let header = show shape
            rendered = prettyShow (nest flat shape 0)
        return (mconcat ["NDArray ", header, "\n", rendered])
      where
        -- Nested view of the sub-array whose first element is at @offset@.
        nest :: Vector a -> [Int] -> Int -> PrettyWrapper
        nest _ [] _ = error "Impossible: never match an empty list."
        nest flat [len] offset = MkPretty (toList (V.unsafeSlice offset len flat))
        nest flat (len:rest) offset =
            let stride = product rest
            in MkPretty [ nest flat rest (offset + stride * i) | i <- [0 .. len - 1] ]
-- | Element-wise arithmetic on arrays.
--
-- Each method runs the corresponding MXNet operator inside
-- 'unsafePerformIO' so it can satisfy the pure 'Num' interface. Binary
-- operators use the @broadcast_*@ variants. 'signum' and 'fromInteger'
-- are unsupported and call 'error'.
-- NOTE(review): no NOINLINE pragmas guard these unsafePerformIO calls.
instance DType a => Num (NDArray a) where
    arr1 + arr2 = NDArray (unsafePerformIO (I.broadcast_add (getHandle arr1) (getHandle arr2)))
    arr1 - arr2 = NDArray (unsafePerformIO (I.broadcast_sub (getHandle arr1) (getHandle arr2)))
    arr1 * arr2 = NDArray (unsafePerformIO (I.broadcast_mul (getHandle arr1) (getHandle arr2)))
    abs arr = NDArray (unsafePerformIO (I.abs (getHandle arr)))
    negate arr = NDArray (unsafePerformIO (I.negative (getHandle arr)))
    signum = error "Unsupported operator: signum(NDArray)"
    fromInteger = error "Unsupported operator: fromInteger(NDArray)"
-- | Element-wise division via @broadcast_div@, run in 'unsafePerformIO'.
-- 'fromRational' is unsupported and calls 'error'.
instance DType a => Fractional (NDArray a) where
    arr1 / arr2 = NDArray $ unsafePerformIO $ I.broadcast_div (getHandle arr1) (getHandle arr2)
    fromRational = error "Unsupported operator: fromRational(NDArray)"
-- | Element-wise transcendental functions, each delegating to the MXNet
-- operator of the same name inside 'unsafePerformIO'.
--
-- Only the methods below are defined; the rest (e.g. 'pi', 'asin') are
-- missing (the file disables the missing-methods warning).
instance DType a => Floating (NDArray a) where
    exp = NDArray . unsafePerformIO . I.exp . getHandle
    log = NDArray . unsafePerformIO . I.log . getHandle
    sqrt = NDArray . unsafePerformIO . I.sqrt . getHandle
    sin = NDArray . unsafePerformIO . I.sin . getHandle
    cos = NDArray . unsafePerformIO . I.cos . getHandle
    tan = NDArray . unsafePerformIO . I.tan . getHandle
    sinh = NDArray . unsafePerformIO . I.sinh . getHandle
    cosh = NDArray . unsafePerformIO . I.cosh . getHandle
    tanh = NDArray . unsafePerformIO . I.tanh . getHandle
-- | 'Tensor' operations for 'NDArray'.
--
-- Every method unwraps handles with 'getHandle', calls the matching
-- generated binding from the qualified @I@ module ('reshape' calls
-- 'mxNDArrayReshape' instead), and wraps the result back in 'NDArray'.
-- Array/array arithmetic and comparisons use the @broadcast_*@ operators.
-- Array/scalar forms convert the scalar with 'realToFrac' and use the
-- @_*_scalar@ operators. The compound forms ('.+=' etc.) call the primed
-- bindings with the input's own handle as the output list (presumably an
-- in-place update — confirm against the generated bindings).
instance Tensor NDArray where
    dot arr1 arr2 = NDArray <$> I.dot (getHandle arr1) (getHandle arr2) nil
    reshape arr sh = do
        (_, reshaped) <- mxNDArrayReshape (getHandle arr) (length sh) sh
        return (NDArray reshaped)
    transpose arr = NDArray <$> I.transpose (getHandle arr) nil

    -- Array/array arithmetic.
    arr1 +. arr2 = NDArray <$> I.broadcast_add (getHandle arr1) (getHandle arr2)
    arr1 -. arr2 = NDArray <$> I.broadcast_sub (getHandle arr1) (getHandle arr2)
    arr1 *. arr2 = NDArray <$> I.broadcast_mul (getHandle arr1) (getHandle arr2)
    arr1 /. arr2 = NDArray <$> I.broadcast_div (getHandle arr1) (getHandle arr2)
    arr1 ^. arr2 = NDArray <$> I.broadcast_power (getHandle arr1) (getHandle arr2)

    -- Array/scalar arithmetic (array on the left).
    arr .+ value = NDArray <$> I._plus_scalar (getHandle arr) (realToFrac value)
    {-# INLINE (.+) #-}
    arr .- value = NDArray <$> I._minus_scalar (getHandle arr) (realToFrac value)
    {-# INLINE (.-) #-}
    arr .* value = NDArray <$> I._mul_scalar (getHandle arr) (realToFrac value)
    {-# INLINE (.*) #-}
    arr ./ value = NDArray <$> I._div_scalar (getHandle arr) (realToFrac value)
    {-# INLINE (./) #-}
    arr .^ value = NDArray <$> I._power_scalar (getHandle arr) (realToFrac value)
    {-# INLINE (.^) #-}

    -- Scalar/array arithmetic (scalar on the left, "reversed" operators).
    value ..- arr = NDArray <$> I._rminus_scalar (getHandle arr) (realToFrac value)
    {-# INLINE (..-) #-}
    value ../ arr = NDArray <$> I._rdiv_scalar (getHandle arr) (realToFrac value)
    {-# INLINE (../) #-}
    value ..^ arr = NDArray <$> I._rpower_scalar (getHandle arr) (realToFrac value)
    {-# INLINE (..^) #-}

    -- Compound array/scalar operators: output list is the input handle.
    arr .+= value = I._plus_scalar' (getHandle arr) (realToFrac value) [getHandle arr]
    arr .-= value = I._minus_scalar' (getHandle arr) (realToFrac value) [getHandle arr]
    arr .*= value = I._mul_scalar' (getHandle arr) (realToFrac value) [getHandle arr]
    arr ./= value = I._div_scalar' (getHandle arr) (realToFrac value) [getHandle arr]
    arr .^= value = I._power_scalar' (getHandle arr) (realToFrac value) [getHandle arr]

    -- Element-wise maximum / minimum.
    _Maximum arr1 arr2 = NDArray <$> I.broadcast_maximum (getHandle arr1) (getHandle arr2)
    {-# INLINE _Maximum #-}
    _Maximum' arr scalar = NDArray <$> I._maximum_scalar (getHandle arr) (realToFrac scalar)
    {-# INLINE _Maximum' #-}
    _Minimum arr1 arr2 = NDArray <$> I.broadcast_minimum (getHandle arr1) (getHandle arr2)
    {-# INLINE _Minimum #-}
    _Minimum' arr scalar = NDArray <$> I._minimum_scalar (getHandle arr) (realToFrac scalar)
    {-# INLINE _Minimum' #-}

    -- Element-wise comparisons.
    equal arr1 arr2 = NDArray <$> I.broadcast_equal (getHandle arr1) (getHandle arr2)
    {-# INLINE equal #-}
    equal' arr scalar = NDArray <$> I._equal_scalar (getHandle arr) (realToFrac scalar)
    {-# INLINE equal' #-}
    notEqual arr1 arr2 = NDArray <$> I.broadcast_not_equal (getHandle arr1) (getHandle arr2)
    {-# INLINE notEqual #-}
    notEqual' arr scalar = NDArray <$> I._not_equal_scalar (getHandle arr) (realToFrac scalar)
    {-# INLINE notEqual' #-}
    greater arr1 arr2 = NDArray <$> I.broadcast_greater (getHandle arr1) (getHandle arr2)
    {-# INLINE greater #-}
    greater' arr scalar = NDArray <$> I._greater_scalar (getHandle arr) (realToFrac scalar)
    {-# INLINE greater' #-}
    greaterEqual arr1 arr2 = NDArray <$> I.broadcast_greater_equal (getHandle arr1) (getHandle arr2)
    {-# INLINE greaterEqual #-}
    greaterEqual' arr scalar = NDArray <$> I._greater_equal_scalar (getHandle arr) (realToFrac scalar)
    {-# INLINE greaterEqual' #-}
    lesser arr1 arr2 = NDArray <$> I.broadcast_lesser (getHandle arr1) (getHandle arr2)
    {-# INLINE lesser #-}
    lesser' arr scalar = NDArray <$> I._lesser_scalar (getHandle arr) (realToFrac scalar)
    {-# INLINE lesser' #-}
    lesserEqual arr1 arr2 = NDArray <$> I.broadcast_lesser_equal (getHandle arr1) (getHandle arr2)
    {-# INLINE lesserEqual #-}
    lesserEqual' arr scalar = NDArray <$> I._lesser_equal_scalar (getHandle arr) (realToFrac scalar)
    {-# INLINE lesserEqual' #-}
-- | Neural-network operators for 'NDArray'.
--
-- Each method unwraps its array arguments with 'getHandle', calls the
-- matching generated binding from the qualified @I@ module, and wraps the
-- resulting handle in 'NDArray'. Optional operator arguments are passed as
-- an 'HMap' built with @add \@"key" value@ onto 'nil'.
instance Neural NDArray where
    fullyConnected input weight bias n =
        NDArray <$> I.fullyconnected (getHandle input) (getHandle weight) (getHandle bias) n nil
    correlation input1 input2 =
        NDArray <$> I.correlation (getHandle input1) (getHandle input2) nil
    activation input act =
        NDArray <$> I.activation (getHandle input) act
    leakyReLU input act =
        NDArray <$> I.leakyrelu (getHandle input) (add @"act_type" act nil)
    softmaxActivation input =
        NDArray <$> I.softmaxactivation (getHandle input) nil
    dropout input p =
        NDArray <$> I.dropout (getHandle input) (add @"p" p nil)
    batchNorm input gm bt mm mv =
        NDArray <$> I.batchnorm (getHandle input) (getHandle gm) (getHandle bt) (getHandle mm) (getHandle mv) nil
    instanceNorm input gamma beta eps =
        NDArray <$> I.instancenorm (getHandle input) (getHandle gamma) (getHandle beta) (add @"eps" eps nil)
    l2Normalization input eps mode =
        NDArray <$> I.l2normalization (getHandle input) (add @"eps" eps (add @"mode" mode nil))
    convolution input weight bias kernel n =
        NDArray <$> I.convolution (getHandle input) (getHandle weight) (getHandle bias) kernel n nil
    lrn input alpha beta knorm nsize =
        NDArray <$> I.lrn (getHandle input) nsize
            (add @"alpha" alpha (add @"beta" beta (add @"knorm" knorm nil)))
    deconvolution input weight bias kernel nfilter =
        NDArray <$> I.deconvolution (getHandle input) (getHandle weight) (getHandle bias) kernel nfilter nil
    pooling input kernel pooltype =
        NDArray <$> I.pooling (getHandle input) kernel pooltype nil
    softmaxOutput input label =
        NDArray <$> I.softmaxoutput (getHandle input) (getHandle label) nil
    makeLoss input grad_scale valid_thresh normalization =
        NDArray <$> I.makeloss (getHandle input)
            (add @"grad_scale" grad_scale (add @"valid_thresh" valid_thresh (add @"normalization" normalization nil)))
    blockGrad input =
        NDArray <$> I.blockgrad (getHandle input)
    custom input op =
        NDArray <$> I.custom (map getHandle input) op