{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeSynonymInstances #-}
module Data.Array.Accelerate.IO.Data.Vector.Generic.Mutable
where
import Data.Array.Accelerate.Array.Data as A
import Data.Array.Accelerate.Array.Sugar as A
import Data.Array.Accelerate.Array.Unique as A
import Data.Array.Accelerate.Lifetime
import qualified Data.Vector.Generic.Mutable as V
import Control.Monad.Primitive
import Data.Typeable
import Foreign.Ptr
import Foreign.Marshal.Utils
import Foreign.Storable
import Prelude hiding ( length )
import GHC.Base
import GHC.ForeignPtr
data MArray sh s e where
MArray :: (Shape sh, Elt e)
=> EltRepr sh
-> MutableArrayData (EltRepr e)
-> MArray sh s e
deriving instance Typeable MArray
type MVector = MArray DIM1
instance Elt e => V.MVector MVector e where
{-# INLINE basicLength #-}
{-# INLINE basicUnsafeSlice #-}
{-# INLINE basicOverlaps #-}
{-# INLINE basicUnsafeNew #-}
{-# INLINE basicInitialize #-}
{-# INLINE basicUnsafeRead #-}
{-# INLINE basicUnsafeWrite #-}
{-# INLINE basicUnsafeCopy #-}
basicLength (MArray ((), n) _) = n
basicUnsafeSlice j m (MArray _ mad) = MArray ((),m) (go arrayElt mad 1)
where
go :: ArrayEltR a -> MutableArrayData a -> Int -> MutableArrayData a
go ArrayEltRunit AD_Unit !_ = AD_Unit
go ArrayEltRint (AD_Int v) !s = AD_Int (slice v s)
go ArrayEltRint8 (AD_Int8 v) !s = AD_Int8 (slice v s)
go ArrayEltRint16 (AD_Int16 v) !s = AD_Int16 (slice v s)
go ArrayEltRint32 (AD_Int32 v) !s = AD_Int32 (slice v s)
go ArrayEltRint64 (AD_Int64 v) !s = AD_Int64 (slice v s)
go ArrayEltRword (AD_Word v) !s = AD_Word (slice v s)
go ArrayEltRword8 (AD_Word8 v) !s = AD_Word8 (slice v s)
go ArrayEltRword16 (AD_Word16 v) !s = AD_Word16 (slice v s)
go ArrayEltRword32 (AD_Word32 v) !s = AD_Word32 (slice v s)
go ArrayEltRword64 (AD_Word64 v) !s = AD_Word64 (slice v s)
go ArrayEltRcshort (AD_CShort v) !s = AD_CShort (slice v s)
go ArrayEltRcushort (AD_CUShort v) !s = AD_CUShort (slice v s)
go ArrayEltRcint (AD_CInt v) !s = AD_CInt (slice v s)
go ArrayEltRcuint (AD_CUInt v) !s = AD_CUInt (slice v s)
go ArrayEltRclong (AD_CLong v) !s = AD_CLong (slice v s)
go ArrayEltRculong (AD_CULong v) !s = AD_CULong (slice v s)
go ArrayEltRcllong (AD_CLLong v) !s = AD_CLLong (slice v s)
go ArrayEltRcullong (AD_CULLong v) !s = AD_CULLong (slice v s)
go ArrayEltRhalf (AD_Half v) !s = AD_Half (slice v s)
go ArrayEltRfloat (AD_Float v) !s = AD_Float (slice v s)
go ArrayEltRdouble (AD_Double v) !s = AD_Double (slice v s)
go ArrayEltRcfloat (AD_CFloat v) !s = AD_CFloat (slice v s)
go ArrayEltRcdouble (AD_CDouble v) !s = AD_CDouble (slice v s)
go ArrayEltRbool (AD_Bool v) !s = AD_Bool (slice v s)
go ArrayEltRchar (AD_Char v) !s = AD_Char (slice v s)
go ArrayEltRcchar (AD_CChar v) !s = AD_CChar (slice v s)
go ArrayEltRcschar (AD_CSChar v) !s = AD_CSChar (slice v s)
go ArrayEltRcuchar (AD_CUChar v) !s = AD_CUChar (slice v s)
go (ArrayEltRvec2 ae) (AD_V2 v) !s = AD_V2 (go ae v (s*2))
go (ArrayEltRvec3 ae) (AD_V3 v) !s = AD_V3 (go ae v (s*3))
go (ArrayEltRvec4 ae) (AD_V4 v) !s = AD_V4 (go ae v (s*4))
go (ArrayEltRvec8 ae) (AD_V8 v) !s = AD_V8 (go ae v (s*8))
go (ArrayEltRvec16 ae) (AD_V16 v) !s = AD_V16 (go ae v (s*16))
go (ArrayEltRpair ae1 ae2) (AD_Pair v1 v2) !s = AD_Pair (go ae1 v1 s) (go ae2 v2 s)
slice :: forall a. Storable a => UniqueArray a -> Int -> UniqueArray a
slice (UniqueArray uid (Lifetime lft w fp)) s =
UniqueArray uid (Lifetime lft w (plusForeignPtr fp (j * s * sizeOf (undefined::a))))
basicOverlaps (MArray ((),m) mad1) (MArray ((),n) mad2) = go arrayElt mad1 mad2 1
where
go :: ArrayEltR a -> MutableArrayData a -> MutableArrayData a -> Int -> Bool
go ArrayEltRunit AD_Unit AD_Unit !_ = False
go ArrayEltRint (AD_Int v1) (AD_Int v2) !s = overlaps v1 v2 s
go ArrayEltRint8 (AD_Int8 v1) (AD_Int8 v2) !s = overlaps v1 v2 s
go ArrayEltRint16 (AD_Int16 v1) (AD_Int16 v2) !s = overlaps v1 v2 s
go ArrayEltRint32 (AD_Int32 v1) (AD_Int32 v2) !s = overlaps v1 v2 s
go ArrayEltRint64 (AD_Int64 v1) (AD_Int64 v2) !s = overlaps v1 v2 s
go ArrayEltRword (AD_Word v1) (AD_Word v2) !s = overlaps v1 v2 s
go ArrayEltRword8 (AD_Word8 v1) (AD_Word8 v2) !s = overlaps v1 v2 s
go ArrayEltRword16 (AD_Word16 v1) (AD_Word16 v2) !s = overlaps v1 v2 s
go ArrayEltRword32 (AD_Word32 v1) (AD_Word32 v2) !s = overlaps v1 v2 s
go ArrayEltRword64 (AD_Word64 v1) (AD_Word64 v2) !s = overlaps v1 v2 s
go ArrayEltRcshort (AD_CShort v1) (AD_CShort v2) !s = overlaps v1 v2 s
go ArrayEltRcushort (AD_CUShort v1) (AD_CUShort v2) !s = overlaps v1 v2 s
go ArrayEltRcint (AD_CInt v1) (AD_CInt v2) !s = overlaps v1 v2 s
go ArrayEltRcuint (AD_CUInt v1) (AD_CUInt v2) !s = overlaps v1 v2 s
go ArrayEltRclong (AD_CLong v1) (AD_CLong v2) !s = overlaps v1 v2 s
go ArrayEltRculong (AD_CULong v1) (AD_CULong v2) !s = overlaps v1 v2 s
go ArrayEltRcllong (AD_CLLong v1) (AD_CLLong v2) !s = overlaps v1 v2 s
go ArrayEltRcullong (AD_CULLong v1) (AD_CULLong v2) !s = overlaps v1 v2 s
go ArrayEltRhalf (AD_Half v1) (AD_Half v2) !s = overlaps v1 v2 s
go ArrayEltRfloat (AD_Float v1) (AD_Float v2) !s = overlaps v1 v2 s
go ArrayEltRdouble (AD_Double v1) (AD_Double v2) !s = overlaps v1 v2 s
go ArrayEltRcfloat (AD_CFloat v1) (AD_CFloat v2) !s = overlaps v1 v2 s
go ArrayEltRcdouble (AD_CDouble v1) (AD_CDouble v2) !s = overlaps v1 v2 s
go ArrayEltRbool (AD_Bool v1) (AD_Bool v2) !s = overlaps v1 v2 s
go ArrayEltRchar (AD_Char v1) (AD_Char v2) !s = overlaps v1 v2 s
go ArrayEltRcchar (AD_CChar v1) (AD_CChar v2) !s = overlaps v1 v2 s
go ArrayEltRcschar (AD_CSChar v1) (AD_CSChar v2) !s = overlaps v1 v2 s
go ArrayEltRcuchar (AD_CUChar v1) (AD_CUChar v2) !s = overlaps v1 v2 s
go (ArrayEltRvec2 ae) (AD_V2 v1) (AD_V2 v2) !s = go ae v1 v2 (s*2)
go (ArrayEltRvec3 ae) (AD_V3 v1) (AD_V3 v3) !s = go ae v1 v3 (s*3)
go (ArrayEltRvec4 ae) (AD_V4 v1) (AD_V4 v4) !s = go ae v1 v4 (s*4)
go (ArrayEltRvec8 ae) (AD_V8 v1) (AD_V8 v8) !s = go ae v1 v8 (s*8)
go (ArrayEltRvec16 ae) (AD_V16 v1) (AD_V16 v16) !s = go ae v1 v16 (s*16)
go (ArrayEltRpair ae1 ae2) (AD_Pair v11 v12) (AD_Pair v21 v22) !s = go ae1 v11 v21 s || go ae2 v12 v22 s
overlaps :: forall a. Storable a => UniqueArray a -> UniqueArray a -> Int -> Bool
overlaps (UniqueArray _ (Lifetime _ _ (ForeignPtr addr1# c1))) (UniqueArray _ (Lifetime _ _ (ForeignPtr addr2# c2))) s =
let i = I# (addr2Int# addr1#)
j = I# (addr2Int# addr2#)
k = s * sizeOf (undefined::a)
in
same c1 c2 && (between i j (j + n*k) || between j i (i + m*k))
same :: ForeignPtrContents -> ForeignPtrContents -> Bool
same (PlainPtr mba1#) (PlainPtr mba2#) = isTrue# (sameMutableByteArray# mba1# mba2#)
same (MallocPtr mba1# _) (MallocPtr mba2# _) = isTrue# (sameMutableByteArray# mba1# mba2#)
same _ _ = False
between :: Int -> Int -> Int -> Bool
between x y z = x >= y && x < z
basicUnsafeNew n = unsafePrimToPrim $ MArray ((),n) <$> newArrayData n
basicInitialize (MArray ((),n) mad) = unsafePrimToPrim $ go (arrayElt :: ArrayEltR (EltRepr e)) (ptrsOfArrayData mad) 1
where
go :: ArrayEltR a -> ArrayPtrs a -> Int -> IO ()
go ArrayEltRunit () !_ = return ()
go ArrayEltRint p !s = initialise p s
go ArrayEltRint8 p !s = initialise p s
go ArrayEltRint16 p !s = initialise p s
go ArrayEltRint32 p !s = initialise p s
go ArrayEltRint64 p !s = initialise p s
go ArrayEltRword p !s = initialise p s
go ArrayEltRword8 p !s = initialise p s
go ArrayEltRword16 p !s = initialise p s
go ArrayEltRword32 p !s = initialise p s
go ArrayEltRword64 p !s = initialise p s
go ArrayEltRcshort p !s = initialise p s
go ArrayEltRcushort p !s = initialise p s
go ArrayEltRcint p !s = initialise p s
go ArrayEltRcuint p !s = initialise p s
go ArrayEltRclong p !s = initialise p s
go ArrayEltRculong p !s = initialise p s
go ArrayEltRcllong p !s = initialise p s
go ArrayEltRcullong p !s = initialise p s
go ArrayEltRhalf p !s = initialise p s
go ArrayEltRfloat p !s = initialise p s
go ArrayEltRdouble p !s = initialise p s
go ArrayEltRcfloat p !s = initialise p s
go ArrayEltRcdouble p !s = initialise p s
go ArrayEltRbool p !s = initialise p s
go ArrayEltRchar p !s = initialise p s
go ArrayEltRcchar p !s = initialise p s
go ArrayEltRcschar p !s = initialise p s
go ArrayEltRcuchar p !s = initialise p s
go (ArrayEltRvec2 ae) p !s = go ae p (s*2)
go (ArrayEltRvec3 ae) p !s = go ae p (s*3)
go (ArrayEltRvec4 ae) p !s = go ae p (s*4)
go (ArrayEltRvec8 ae) p !s = go ae p (s*8)
go (ArrayEltRvec16 ae) p !s = go ae p (s*16)
go (ArrayEltRpair ae1 ae2) (p1,p2) s = go ae1 p1 s >> go ae2 p2 s
initialise :: forall a. Storable a => Ptr a -> Int -> IO ()
initialise p s = fillBytes p 0 (n * s * sizeOf (undefined::a))
basicUnsafeRead (MArray _ mad) i = unsafePrimToPrim $ toElt <$> unsafeReadArrayData mad i
basicUnsafeWrite (MArray _ mad) i v = unsafePrimToPrim $ unsafeWriteArrayData mad i (fromElt v)
basicUnsafeCopy (MArray _ dst) (MArray ((),n) src) = unsafePrimToPrim $ go (arrayElt :: ArrayEltR (EltRepr e)) (ptrsOfArrayData dst) (ptrsOfArrayData src) 1
where
go :: ArrayEltR a -> ArrayPtrs a -> ArrayPtrs a -> Int -> IO ()
go ArrayEltRunit () () !_ = return ()
go ArrayEltRint u v !s = copy u v s
go ArrayEltRint8 u v !s = copy u v s
go ArrayEltRint16 u v !s = copy u v s
go ArrayEltRint32 u v !s = copy u v s
go ArrayEltRint64 u v !s = copy u v s
go ArrayEltRword u v !s = copy u v s
go ArrayEltRword8 u v !s = copy u v s
go ArrayEltRword16 u v !s = copy u v s
go ArrayEltRword32 u v !s = copy u v s
go ArrayEltRword64 u v !s = copy u v s
go ArrayEltRcshort u v !s = copy u v s
go ArrayEltRcushort u v !s = copy u v s
go ArrayEltRcint u v !s = copy u v s
go ArrayEltRcuint u v !s = copy u v s
go ArrayEltRclong u v !s = copy u v s
go ArrayEltRculong u v !s = copy u v s
go ArrayEltRcllong u v !s = copy u v s
go ArrayEltRcullong u v !s = copy u v s
go ArrayEltRhalf u v !s = copy u v s
go ArrayEltRfloat u v !s = copy u v s
go ArrayEltRdouble u v !s = copy u v s
go ArrayEltRcfloat u v !s = copy u v s
go ArrayEltRcdouble u v !s = copy u v s
go ArrayEltRbool u v !s = copy u v s
go ArrayEltRchar u v !s = copy u v s
go ArrayEltRcchar u v !s = copy u v s
go ArrayEltRcschar u v !s = copy u v s
go ArrayEltRcuchar u v !s = copy u v s
go (ArrayEltRvec2 ae) u v !s = go ae u v (s*2)
go (ArrayEltRvec3 ae) u v !s = go ae u v (s*3)
go (ArrayEltRvec4 ae) u v !s = go ae u v (s*4)
go (ArrayEltRvec8 ae) u v !s = go ae u v (s*8)
go (ArrayEltRvec16 ae) u v !s = go ae u v (s*16)
go (ArrayEltRpair ae1 ae2) (u1,u2) (v1,v2) s = go ae1 u1 v1 s >> go ae2 u2 v2 s
copy :: forall a. Storable a => Ptr a -> Ptr a -> Int -> IO ()
copy u v s = copyBytes u v (n * s * sizeOf (undefined::a))
#if !MIN_VERSION_base(4,10,0)
plusForeignPtr :: ForeignPtr a -> Int -> ForeignPtr b
plusForeignPtr (ForeignPtr addr# c) (I# i#) = ForeignPtr (plusAddr# addr# i#) c
#endif