module Data.Array.Accelerate.CUDA.Foreign.Export (
accelerateCreate, accelerateDestroy, freeOutput, freeProgram,
exportAfun, buildExported,
InputArray, OutputArray, ShapeBuffer, DevicePtrBuffer,
) where
import Prelude as P
import Data.Functor
import Control.Applicative
import Foreign.StablePtr
import Foreign.C.Types
import Foreign.Ptr
import Foreign.Storable ( Storable(..) )
import Foreign.Marshal.Array ( peekArray, pokeArray, mallocArray )
import Foreign.Marshal.Alloc ( free )
import Control.Monad.State ( liftIO )
import qualified Foreign.CUDA.Driver as CUDA
import Language.Haskell.TH hiding ( ppr )
import Data.Array.Accelerate.Smart ( Acc )
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Array.Data
import Data.Array.Accelerate.CUDA ( run1In )
import Data.Array.Accelerate.CUDA.Array.Sugar hiding ( shape, size )
import Data.Array.Accelerate.CUDA.Array.Data hiding ( pokeArray, peekArray, mallocArray )
import Data.Array.Accelerate.CUDA.State
import Data.Array.Accelerate.CUDA.Context
-- | Handle passed across the FFI boundary: a stable pointer to a 'Context'.
type AccHandle = StablePtr Context
-- | Pointer to a buffer of 'CInt's holding an array's shape, one entry per
-- dimension (read with @peekArray (dim ..)@ in 'arrayFromForeignData').
type ShapeBuffer = Ptr CInt
-- | Pointer to a buffer of device pointers encoded as 'WordPtr's, one entry
-- per scalar component of the element type (see 'eltSize').
type DevicePtrBuffer = Ptr WordPtr
-- | Description of an input array: its shape buffer and device pointer buffer.
type InputArray = (ShapeBuffer, DevicePtrBuffer)
-- | Description of an output array: shape buffer, device pointer buffer, and a
-- stable pointer to the array itself (wrapped in 'EArray').
type OutputArray = (ShapeBuffer, DevicePtrBuffer, StablePtr EArray)
-- | Device identifier as received from the foreign caller; wrapped into a
-- 'CUDA.Device' by 'accelerateCreate'.
type Device = Int32
-- | Raw context pointer as received from the foreign caller; wrapped into a
-- 'CUDA.Context' by 'accelerateCreate'.
type ForeignContext = Ptr ()
-- | An array function with its argument and result types hidden, so that it
-- can be stored behind a single @StablePtr Afun@.
--
-- The second and third fields only carry the types @a@ and @b@: in this module
-- they are filled with 'undefined' on construction and matched with wildcard
-- patterns on use, so they must never be forced.
data Afun where
  Afun :: (Arrays a, Arrays b)
       => (a -> b)
       -> a
       -> b
       -> Afun
-- | An array with its shape and element types hidden, so that arrays of any
-- type can be referenced through a @StablePtr EArray@ (the third component of
-- an 'OutputArray').
data EArray where
  EArray :: (Shape sh, Elt e) => Array sh e -> EArray
-- Entry points made callable from C via the @ccall@ convention. Note that
-- 'runProgram' is exported to C here even though it does not appear in the
-- module's Haskell export list.
foreign export ccall accelerateCreate :: Device -> ForeignContext -> IO AccHandle
foreign export ccall accelerateDestroy :: AccHandle -> IO ()
foreign export ccall runProgram :: AccHandle -> StablePtr Afun -> Ptr InputArray -> Ptr OutputArray -> IO ()
foreign export ccall freeOutput :: Ptr OutputArray -> IO ()
foreign export ccall freeProgram :: StablePtr a -> IO ()
-- | Marshals an 'InputArray' as two adjacent fields: the shape buffer pointer
-- at offset 0, immediately followed by the device pointer buffer pointer.
instance Storable InputArray where
  -- 'sizeOf' and 'alignment' must not inspect their argument: callers are
  -- entitled to pass 'undefined' (as 'mallocArray', 'alloca', etc. do), so the
  -- sizes are computed from the field types rather than by matching the tuple.
  sizeOf _ = sizeOf (undefined :: ShapeBuffer)
           + sizeOf (undefined :: DevicePtrBuffer)

  -- Zero is not a usable alignment; take the strictest alignment of the fields.
  alignment _ = max (alignment (undefined :: ShapeBuffer))
                    (alignment (undefined :: DevicePtrBuffer))

  -- Read the shape pointer, then the device pointer buffer stored right after it.
  peek ptr = do
    sh   <- peekByteOff ptr 0
    ptrs <- peekByteOff ptr (sizeOf sh)
    return (sh, ptrs)

  -- Write the two fields at the same offsets that 'peek' reads them from.
  poke ptr (sh, ptrs) = do
    pokeByteOff ptr 0 sh
    pokeByteOff ptr (sizeOf sh) ptrs
-- | Marshals an 'OutputArray' as three adjacent fields: the shape buffer
-- pointer at offset 0, then the device pointer buffer pointer, then the stable
-- pointer to the array.
instance Storable OutputArray where
  -- 'sizeOf' and 'alignment' must not inspect their argument: callers are
  -- entitled to pass 'undefined' (as 'mallocArray', 'alloca', etc. do), so the
  -- sizes are computed from the field types rather than by matching the tuple.
  sizeOf _ = sizeOf (undefined :: ShapeBuffer)
           + sizeOf (undefined :: DevicePtrBuffer)
           + sizeOf (undefined :: StablePtr EArray)

  -- Zero is not a usable alignment; take the strictest alignment of the fields.
  alignment _ = maximum
    [ alignment (undefined :: ShapeBuffer)
    , alignment (undefined :: DevicePtrBuffer)
    , alignment (undefined :: StablePtr EArray)
    ]

  -- Read the fields back to back; each offset is the sum of the sizes of the
  -- fields that precede it.
  peek ptr = do
    sh   <- peekByteOff ptr 0
    ptrs <- peekByteOff ptr (sizeOf sh)
    sa   <- peekByteOff ptr (sizeOf sh + sizeOf ptrs)
    return (sh, ptrs, sa)

  -- Write the three fields at the same offsets that 'peek' reads them from.
  poke ptr (sh, ptrs, sa) = do
    pokeByteOff ptr 0 sh
    pokeByteOff ptr (sizeOf sh) ptrs
    pokeByteOff ptr (sizeOf sh + sizeOf ptrs) sa
-- | Build an execution context from a foreign device identifier and raw
-- context pointer, and hand it back as a stable pointer.
--
-- The arguments are wrapped in 'CUDA.Device' and 'CUDA.Context' and passed to
-- 'fromDeviceContext'; the resulting value is pinned with 'newStablePtr'.
accelerateCreate :: Device -> ForeignContext -> IO AccHandle
accelerateCreate dev ctx = do
  context <- fromDeviceContext (CUDA.Device (CInt dev)) (CUDA.Context ctx)
  newStablePtr context
-- | Release the stable pointer behind a handle obtained from
-- 'accelerateCreate'. The handle must not be used afterwards.
accelerateDestroy :: AccHandle -> IO ()
accelerateDestroy hndl = freeStablePtr hndl
-- | Release the resources referenced by one 'OutputArray' record.
--
-- Reads the record at the given address, frees its shape buffer and its
-- device pointer buffer with 'free', and releases the stable pointer to the
-- array. The record's own storage is not freed.
freeOutput :: Ptr OutputArray -> IO ()
freeOutput o =
  peek o >>= \(shapeBuf, devPtrBuf, arrayRef) -> do
    free shapeBuf
    free devPtrBuf
    freeStablePtr arrayRef
-- | Release a stable pointer, such as one returned by a function generated
-- with 'exportAfun'. The pointer must not be used afterwards.
freeProgram :: StablePtr a -> IO ()
freeProgram prog = freeStablePtr prog
-- | Execute an exported program.
--
-- Dereferences the context handle and the program, reads the program's
-- arguments from the records starting at @input@, applies the function, and
-- writes one 'OutputArray' record per result array starting at @output@.
runProgram :: AccHandle -> StablePtr Afun -> Ptr InputArray -> Ptr OutputArray -> IO ()
runProgram hndl fun input output = do
  context <- deRefStablePtr hndl
  program <- deRefStablePtr fun
  execute context program
  where
    -- Unpack the existential function and run it inside the CUDA monad. The
    -- two wildcard fields are only used to name the hidden types.
    execute :: Context -> Afun -> IO ()
    execute context (Afun f (_ :: a) (_ :: b)) = do
      _ <- evalCUDA context $ do
        (args, _) <- marshalIn (arrays (undefined :: a)) input
        -- Force the result to WHNF before marshalling it out.
        let !result = f (toArr args)
        marshalOut (arrays (undefined :: b)) (fromArr result) output
      return ()

    -- Walk the structure of the argument type, consuming one 'InputArray'
    -- record per array, and return the value together with the address just
    -- past the last record consumed.
    marshalIn :: ArraysR a -> Ptr InputArray -> CIO (a, Ptr InputArray)
    marshalIn ArraysRunit cursor = return ((), cursor)
    marshalIn ArraysRarray cursor = do
      record@(dims, devPtrs) <- liftIO (peek cursor)
      arr <- arrayFromForeignData devPtrs dims
      return (arr, cursor `plusPtr` sizeOf record)
    marshalIn (ArraysRpair leftR rightR) cursor = do
      (l, afterLeft)  <- marshalIn leftR cursor
      (r, afterRight) <- marshalIn rightR afterLeft
      return ((l, r), afterRight)

    -- Walk the structure of the result type, writing one 'OutputArray' record
    -- per array, and return the address just past the last record written.
    marshalOut :: ArraysR b -> b -> Ptr OutputArray -> CIO (Ptr OutputArray)
    marshalOut ArraysRunit () cursor = return cursor
    marshalOut ArraysRarray arr cursor = do
      record <- describe arr
      liftIO (poke cursor record)
      return (cursor `plusPtr` sizeOf record)
      where
        -- Allocate and fill the shape and device pointer buffers for one
        -- array, and take a stable pointer to the array itself.
        describe :: forall sh e. Shape sh => Array sh e -> CIO OutputArray
        describe (Array shR adata) = do
          let extents = shapeToList (toElt shR :: sh)
          shapeBuf <- liftIO (mallocArray (P.length extents))
          liftIO (pokeArray shapeBuf (map fromIntegral extents))
          rawPtrs <- devicePtrsOfArrayData adata
          let wordPtrs = devicePtrsToWordPtrs adata rawPtrs
          ptrBuf <- liftIO (mallocArray (P.length wordPtrs))
          liftIO (pokeArray ptrBuf wordPtrs)
          arrayRef <- liftIO (newStablePtr (EArray arr))
          return (shapeBuf, ptrBuf, arrayRef)
    marshalOut (ArraysRpair leftR rightR) (l, r) cursor = do
      afterLeft <- marshalOut leftR l cursor
      marshalOut rightR r afterLeft
-- | Template Haskell generator that exports an Accelerate function to C.
--
-- @exportAfun fname ename@ looks up @fname@ and, via 'genCompileFun', produces
-- the declarations for a foreign-exported function called @ename@.
--
-- Fails at compile time, naming the offending identifier, when @fname@ does
-- not refer to a variable binding.
exportAfun :: Name -> String -> Q [Dec]
exportAfun fname ename = do
  info <- reify fname
  case info of
    VarI n ty _ _ -> genCompileFun n ename ty
    -- An explicit branch gives a usable message instead of the generic
    -- pattern-match failure a partial bind would produce.
    _             -> fail $ "exportAfun: " ++ show fname
                         ++ " does not refer to a variable binding"
-- | Generate the declarations that export one Accelerate function.
--
-- For a function whose reified type has the shape @t a -> u b@ (a single
-- argument and a result, each an applied type constructor) this yields three
-- declarations for the name @ename@: a type signature, a definition
-- @\\hndl -> buildExported hndl fname@, and a @foreign export ccall@, all at
-- type @AccHandle -> IO (StablePtr Afun)@.
--
-- Any other type is rejected with a compile-time error raised through 'fail',
-- so that it is reported as an ordinary Template Haskell error rather than as
-- an exception thrown while running the splice.
genCompileFun :: Name -> String -> Type -> Q [Dec]
genCompileFun fname ename (AppT (AppT ArrowT (AppT _ _)) (AppT _ _))
  = sequence [sig, dec, expt]
  where
    initName = mkName ename
    body     = [| \hndl -> buildExported hndl $(varE fname) |]
    dec      = FunD initName . (:[]) <$> cls
    cls      = Clause [] <$> (NormalB <$> body) <*> return []
    sig      = SigD initName <$> ety
    expt     = ForeignD <$> (ExportF cCall (nameBase initName) initName <$> ety)
    ety      = [t| AccHandle -> IO (StablePtr Afun) |]
genCompileFun fname _ ty
  = fail $ "Invalid accelerate function: " ++ show fname
        ++ " has type `" ++ pprint ty
        ++ "', but a type of the form `Acc a -> Acc b' is required"
-- | Package an Accelerate function for later execution through 'runProgram'.
--
-- The function is prepared with 'run1In' in the context behind the handle,
-- wrapped in an 'Afun' (whose two type-witness fields are left 'undefined'),
-- and returned as a stable pointer.
buildExported :: forall a b. (Arrays a, Arrays b) => AccHandle -> (Acc a -> Acc b) -> IO (StablePtr Afun)
buildExported hndl f =
  deRefStablePtr hndl >>= \ctx ->
    newStablePtr (Afun (run1In ctx f) (undefined :: a) (undefined :: b))
-- | Build an array from foreign-supplied buffers.
--
-- Reads one shape entry per dimension of @sh@ from the shape buffer and one
-- device pointer per scalar component of @e@ (as counted by 'eltSize') from
-- the pointer buffer, then hands both to 'useDevicePtrs'.
arrayFromForeignData :: forall sh e. (Shape sh, Elt e) => DevicePtrBuffer -> ShapeBuffer -> CIO (Array sh e)
arrayFromForeignData ptrs shape = do
  extents  <- liftIO (peekArray rank shape)
  wordPtrs <- liftIO (peekArray components ptrs)
  useDevicePtrs
    (fromElt (listToShape (map fromIntegral extents) :: sh))
    (devicePtrsFromList (arrayElt :: ArrayEltR (EltRepr e)) wordPtrs)
  where
    -- Number of entries to read from the shape buffer.
    rank       = dim (ignore :: sh)
    -- Number of entries to read from the device pointer buffer.
    components = eltSize (eltType (undefined :: e))
-- | Count the scalar components of a tuple type: unit contributes none, a
-- single scalar contributes one, and a pair contributes the sum of its halves.
eltSize :: TupleType e -> Int
eltSize tup = case tup of
  UnitTuple           -> 0
  SingleTuple _       -> 1
  PairTuple left right -> eltSize left + eltSize right