{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RecordWildCards #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Data.Array.Accelerate.LLVM.PTX.Array.Data (
module Data.Array.Accelerate.LLVM.Array.Data,
module Data.Array.Accelerate.LLVM.PTX.Array.Data,
) where
import Data.Array.Accelerate.Array.Sugar
import Data.Array.Accelerate.Array.Unique ( UniqueArray(..) )
import Data.Array.Accelerate.Lifetime ( Lifetime(..) )
import qualified Data.Array.Accelerate.Array.Representation as R
import Data.Array.Accelerate.LLVM.Array.Data
import Data.Array.Accelerate.LLVM.State
import Data.Array.Accelerate.LLVM.PTX.State
import Data.Array.Accelerate.LLVM.PTX.Target
import Data.Array.Accelerate.LLVM.PTX.Execute.Async
import qualified Data.Array.Accelerate.LLVM.PTX.Array.Prim as Prim
import Control.Applicative
import Control.Monad.State ( liftIO, gets )
import Data.Typeable
import Foreign.Ptr
import Foreign.Storable
import System.IO.Unsafe
import Prelude
-- | 'Remote' memory operations for the PTX target.
--
-- Every transfer method receives an optional stream: with 'Nothing' it calls
-- the plain primitive from the @Prim@ module, and with @'Just' st@ it calls
-- the matching @Async@ primitive, passing @st@ along.
--
instance Remote PTX where
  -- Build the array with 'allocateArray', then run 'Prim.mallocArray' over
  -- every payload that 'runArray' hands to the callback, returning each
  -- payload unchanged.
  {-# INLINEABLE allocateRemote #-}
  allocateRemote !sh = do
    let !total = size sh
    host <- liftIO (allocateArray sh)
    -- 'width' is supplied by 'runArray'; presumably the number of scalars
    -- stored per array element for this payload -- TODO confirm.
    runArray host $ \width payload -> do
      Prim.mallocArray (total * width) payload
      return payload

  -- Dispatches to 'Prim.useArray' / 'Prim.useArrayAsync' for the first @n@
  -- elements of the payload.
  {-# INLINEABLE useRemoteR #-}
  useRemoteR !n Nothing   !ad = Prim.useArray n ad
  useRemoteR !n (Just st) !ad = Prim.useArrayAsync st n ad

  -- Dispatches to 'Prim.pokeArrayR' / 'Prim.pokeArrayAsyncR' with the
  -- arguments @from@ and @n@ passed through unchanged.
  {-# INLINEABLE copyToRemoteR #-}
  copyToRemoteR !from !n Nothing   !ad = Prim.pokeArrayR from n ad
  copyToRemoteR !from !n (Just st) !ad = Prim.pokeArrayAsyncR st from n ad

  -- Dispatches to 'Prim.peekArrayR' / 'Prim.peekArrayAsyncR' with the
  -- arguments @from@ and @n@ passed through unchanged.
  {-# INLINEABLE copyToHostR #-}
  copyToHostR !from !n Nothing   !ad = Prim.peekArrayR from n ad
  copyToHostR !from !n (Just st) !ad = Prim.peekArrayAsyncR st from n ad

  -- Dispatches to 'Prim.copyArrayPeerR' / 'Prim.copyArrayPeerAsyncR'. The
  -- context and memory table handed to the primitive are those of the
  -- destination target @dst@.
  {-# INLINEABLE copyToPeerR #-}
  copyToPeerR !from !n !dst Nothing !ad =
    Prim.copyArrayPeerR (ptxContext dst) (ptxMemoryTable dst) from n ad
  copyToPeerR !from !n !dst (Just st) !ad =
    Prim.copyArrayPeerAsyncR (ptxContext dst) (ptxMemoryTable dst) st from n ad

  -- Read a single element by running 'Prim.indexArray' through
  -- 'runIndexArray'.
  {-# INLINEABLE indexRemote #-}
  indexRemote array ix =
    runIndexArray Prim.indexArray array ix
-- | Return the given arrays with every primitive payload set up so that its
-- device-to-host copy is deferred with 'unsafeInterleaveIO'.
--
-- For each primitive payload the resulting 'UniqueArray' keeps the original
-- unique id, lifetime reference and weak pointer; only the foreign-pointer
-- field is replaced by a deferred action. When that field is first demanded,
-- the action runs under 'evalPTX' with the target captured at call time:
-- it obtains a stream with 'fork', issues 'copyToHostR' for elements
-- @0 .. n@ on that stream, waits on the event returned by 'checkpoint',
-- hands the stream to 'join', and finally yields the original pointer.
--
-- Nothing is copied by this function itself; an array whose payload pointer
-- is never demanded is never transferred.
--
copyToHostLazy
    :: Arrays arrs
    => arrs
    -> LLVM PTX arrs
copyToHostLazy arrs = do
  target <- gets llvmTarget
  liftIO $ runArrays arrs $ \(Array sh adata) ->
    let
        -- Wrap one primitive payload: rebuild the 'UniqueArray' around a
        -- lazily-evaluated foreign pointer whose evaluation triggers the
        -- transfer of @n@ elements.
        deferPeek
            :: (ArrayElt e, ArrayPtrs e ~ Ptr a, Storable a, Typeable a, Typeable e)
            => ArrayData e
            -> UniqueArray a
            -> Int
            -> IO (UniqueArray a)
        deferPeek ad (UniqueArray uid (Lifetime ref weak fp)) n =
          UniqueArray uid . Lifetime ref weak <$> unsafeInterleaveIO transfer
          where
            -- The statement order matters: the copy is issued on the new
            -- stream and waited for before the pointer is handed out.
            transfer = evalPTX target $ do
              stream <- fork
              copyToHostR 0 n (Just stream) ad
              done <- checkpoint stream
              block done
              join stream
              return fp

        -- Walk the element representation, rebuilding the same 'ArrayData'
        -- structure with every primitive leaf passed through 'deferPeek'.
        fetch :: ArrayEltR e -> ArrayData e -> Int -> IO (ArrayData e)
        fetch ArrayEltRunit           AD_Unit           _ = return AD_Unit
        fetch (ArrayEltRpair eltL eltR) (AD_Pair adL adR) n =
          liftA2 AD_Pair (fetch eltL adL n) (fetch eltR adR n)
        -- Vector payloads: the element count is scaled by the vector width.
        fetch (ArrayEltRvec2 elt)  (AD_V2 inner)  n = AD_V2  <$> fetch elt inner (n*2)
        fetch (ArrayEltRvec3 elt)  (AD_V3 inner)  n = AD_V3  <$> fetch elt inner (n*3)
        fetch (ArrayEltRvec4 elt)  (AD_V4 inner)  n = AD_V4  <$> fetch elt inner (n*4)
        fetch (ArrayEltRvec8 elt)  (AD_V8 inner)  n = AD_V8  <$> fetch elt inner (n*8)
        fetch (ArrayEltRvec16 elt) (AD_V16 inner) n = AD_V16 <$> fetch elt inner (n*16)
        -- Primitive leaves.
        fetch ArrayEltRint      leaf@(AD_Int u)      n = AD_Int      <$> deferPeek leaf u n
        fetch ArrayEltRint8     leaf@(AD_Int8 u)     n = AD_Int8     <$> deferPeek leaf u n
        fetch ArrayEltRint16    leaf@(AD_Int16 u)    n = AD_Int16    <$> deferPeek leaf u n
        fetch ArrayEltRint32    leaf@(AD_Int32 u)    n = AD_Int32    <$> deferPeek leaf u n
        fetch ArrayEltRint64    leaf@(AD_Int64 u)    n = AD_Int64    <$> deferPeek leaf u n
        fetch ArrayEltRword     leaf@(AD_Word u)     n = AD_Word     <$> deferPeek leaf u n
        fetch ArrayEltRword8    leaf@(AD_Word8 u)    n = AD_Word8    <$> deferPeek leaf u n
        fetch ArrayEltRword16   leaf@(AD_Word16 u)   n = AD_Word16   <$> deferPeek leaf u n
        fetch ArrayEltRword32   leaf@(AD_Word32 u)   n = AD_Word32   <$> deferPeek leaf u n
        fetch ArrayEltRword64   leaf@(AD_Word64 u)   n = AD_Word64   <$> deferPeek leaf u n
        fetch ArrayEltRcshort   leaf@(AD_CShort u)   n = AD_CShort   <$> deferPeek leaf u n
        fetch ArrayEltRcushort  leaf@(AD_CUShort u)  n = AD_CUShort  <$> deferPeek leaf u n
        fetch ArrayEltRcint     leaf@(AD_CInt u)     n = AD_CInt     <$> deferPeek leaf u n
        fetch ArrayEltRcuint    leaf@(AD_CUInt u)    n = AD_CUInt    <$> deferPeek leaf u n
        fetch ArrayEltRclong    leaf@(AD_CLong u)    n = AD_CLong    <$> deferPeek leaf u n
        fetch ArrayEltRculong   leaf@(AD_CULong u)   n = AD_CULong   <$> deferPeek leaf u n
        fetch ArrayEltRcllong   leaf@(AD_CLLong u)   n = AD_CLLong   <$> deferPeek leaf u n
        fetch ArrayEltRcullong  leaf@(AD_CULLong u)  n = AD_CULLong  <$> deferPeek leaf u n
        fetch ArrayEltRhalf     leaf@(AD_Half u)     n = AD_Half     <$> deferPeek leaf u n
        fetch ArrayEltRfloat    leaf@(AD_Float u)    n = AD_Float    <$> deferPeek leaf u n
        fetch ArrayEltRdouble   leaf@(AD_Double u)   n = AD_Double   <$> deferPeek leaf u n
        fetch ArrayEltRcfloat   leaf@(AD_CFloat u)   n = AD_CFloat   <$> deferPeek leaf u n
        fetch ArrayEltRcdouble  leaf@(AD_CDouble u)  n = AD_CDouble  <$> deferPeek leaf u n
        fetch ArrayEltRbool     leaf@(AD_Bool u)     n = AD_Bool     <$> deferPeek leaf u n
        fetch ArrayEltRchar     leaf@(AD_Char u)     n = AD_Char     <$> deferPeek leaf u n
        fetch ArrayEltRcchar    leaf@(AD_CChar u)    n = AD_CChar    <$> deferPeek leaf u n
        fetch ArrayEltRcschar   leaf@(AD_CSChar u)   n = AD_CSChar   <$> deferPeek leaf u n
        fetch ArrayEltRcuchar   leaf@(AD_CUChar u)   n = AD_CUChar   <$> deferPeek leaf u n
    in
    Array sh <$> fetch arrayElt adata (R.size sh)
-- | Copy an array into a freshly allocated array of the same shape.
--
-- The destination is obtained from 'allocateRemote'; every primitive payload
-- of the source is then copied to the corresponding destination payload with
-- 'Prim.copyArrayAsync' on the supplied @stream@. This function does not
-- itself wait on the stream before returning the new array.
--
cloneArrayAsync
    :: (Shape sh, Elt e)
    => Stream
    -> Array sh e
    -> LLVM PTX (Array sh e)
cloneArrayAsync stream arr@(Array _ src) = do
  out@(Array _ dst) <- allocateRemote extent
  go arrayElt src dst (size extent)
  return out
  where
    extent = shape arr

    -- Walk the element representation of source and destination in
    -- lock-step, issuing one primitive copy per leaf payload.
    go :: ArrayEltR e -> ArrayData e -> ArrayData e -> Int -> LLVM PTX ()
    go ArrayEltRunit             _    _    _ = return ()
    go (ArrayEltRpair eltA eltB) from into n = do
      go eltA (fstArrayData from) (fstArrayData into) n
      go eltB (sndArrayData from) (sndArrayData into) n
    -- Vector payloads: the element count is scaled by the vector width.
    go (ArrayEltRvec2 elt)  (AD_V2 from)  (AD_V2 into)  n = go elt from into (n*2)
    go (ArrayEltRvec3 elt)  (AD_V3 from)  (AD_V3 into)  n = go elt from into (n*3)
    go (ArrayEltRvec4 elt)  (AD_V4 from)  (AD_V4 into)  n = go elt from into (n*4)
    go (ArrayEltRvec8 elt)  (AD_V8 from)  (AD_V8 into)  n = go elt from into (n*8)
    go (ArrayEltRvec16 elt) (AD_V16 from) (AD_V16 into) n = go elt from into (n*16)
    -- Primitive leaves.
    go ArrayEltRint      from into n = leafCopy from into n
    go ArrayEltRint8     from into n = leafCopy from into n
    go ArrayEltRint16    from into n = leafCopy from into n
    go ArrayEltRint32    from into n = leafCopy from into n
    go ArrayEltRint64    from into n = leafCopy from into n
    go ArrayEltRword     from into n = leafCopy from into n
    go ArrayEltRword8    from into n = leafCopy from into n
    go ArrayEltRword16   from into n = leafCopy from into n
    go ArrayEltRword32   from into n = leafCopy from into n
    go ArrayEltRword64   from into n = leafCopy from into n
    go ArrayEltRhalf     from into n = leafCopy from into n
    go ArrayEltRfloat    from into n = leafCopy from into n
    go ArrayEltRdouble   from into n = leafCopy from into n
    go ArrayEltRbool     from into n = leafCopy from into n
    go ArrayEltRchar     from into n = leafCopy from into n
    go ArrayEltRcshort   from into n = leafCopy from into n
    go ArrayEltRcushort  from into n = leafCopy from into n
    go ArrayEltRcint     from into n = leafCopy from into n
    go ArrayEltRcuint    from into n = leafCopy from into n
    go ArrayEltRclong    from into n = leafCopy from into n
    go ArrayEltRculong   from into n = leafCopy from into n
    go ArrayEltRcllong   from into n = leafCopy from into n
    go ArrayEltRcullong  from into n = leafCopy from into n
    go ArrayEltRcfloat   from into n = leafCopy from into n
    go ArrayEltRcdouble  from into n = leafCopy from into n
    go ArrayEltRcchar    from into n = leafCopy from into n
    go ArrayEltRcschar   from into n = leafCopy from into n
    go ArrayEltRcuchar   from into n = leafCopy from into n

    -- Copy @count@ elements of one primitive payload on the caller's stream.
    -- All three arguments are forced before the primitive is invoked.
    leafCopy
        :: (ArrayElt e, ArrayPtrs e ~ Ptr a, Typeable e, Storable a, Typeable a)
        => ArrayData e
        -> ArrayData e
        -> Int
        -> LLVM PTX ()
    leafCopy !from !into !count = Prim.copyArrayAsync stream count from into