module Foreign.CUDA.Runtime.Marshal (
AllocFlag(..),
mallocHostArray, freeHost,
mallocArray, allocaArray, free,
AttachFlag(..),
mallocManagedArray,
peekArray, peekArrayAsync, peekListArray,
pokeArray, pokeArrayAsync, pokeListArray,
copyArray, copyArrayAsync,
newListArray, newListArrayLen,
withListArray, withListArrayLen,
memset
) where
import Foreign.CUDA.Ptr
import Foreign.CUDA.Runtime.Error
import Foreign.CUDA.Runtime.Stream
import Foreign.CUDA.Internal.C2HS
import Data.Int
import Control.Exception
import Foreign.C
import Foreign.Ptr
import Foreign.Storable
import qualified Foreign.Marshal as F
-- | Options for 'mallocHostArray'. The 'Enum' instance assigns each
-- constructor a distinct bit (1, 2, 4), and 'cudaHostAlloc' folds a list of
-- them into one mask with 'combineBitMasks'.
data AllocFlag = DeviceMapped
               | Portable
               | WriteCombined
  deriving (Eq,Show)

instance Enum AllocFlag where
  fromEnum flag = case flag of
    Portable      -> 1
    DeviceMapped  -> 2
    WriteCombined -> 4
  -- Only the three bit values above decode; anything else is an error.
  toEnum code = case code of
    1 -> Portable
    2 -> DeviceMapped
    4 -> WriteCombined
    _ -> error ("AllocFlag.toEnum: Cannot match " ++ show code)
-- | Allocate host memory for @n@ elements of type @a@ by calling
-- 'cudaHostAlloc' with the given flags. The byte count is
-- @n * sizeOf (element)@.
--
-- The returned status goes to 'resultIfOk', which presumably yields the
-- pointer on success and raises an error otherwise.
mallocHostArray :: Storable a => [AllocFlag] -> Int -> IO (HostPtr a)
mallocHostArray !flags = allocFor undefined
  where
    -- The first argument only fixes the element type for 'sizeOf'.
    allocFor :: Storable e => e -> Int -> IO (HostPtr e)
    allocFor witness !count =
      let bytes = fromIntegral count * fromIntegral (sizeOf witness)
      in  cudaHostAlloc bytes flags >>= resultIfOk
-- | Marshalling wrapper over the raw @cudaHostAlloc'_@ import.
--
-- It does the following:
--
-- * allocates a pointer-sized out-slot and pre-clears it to null;
-- * passes the byte count and the combined flag mask to the C call;
-- * returns the converted 'Status' together with whatever pointer the slot
--   holds afterwards, wrapped as a 'HostPtr'.
cudaHostAlloc :: Int64 -> [AllocFlag] -> IO (Status, HostPtr a)
cudaHostAlloc bytes flags =
  F.alloca $ \ !slot -> do
    -- Pre-clear the out-parameter so it reads as null unless the call
    -- writes to it.
    poke slot nullPtr
    rc  <- cudaHostAlloc'_ slot (cIntConv bytes) (combineBitMasks flags)
    raw <- peek slot
    return (cToEnum rc, HostPtr (castPtr raw))
-- | Release host memory through 'cudaFreeHost'. The resulting status is
-- checked with 'nothingIfOk'.
freeHost :: HostPtr a -> IO ()
freeHost !p = cudaFreeHost p >>= nothingIfOk
-- | Marshalling wrapper over the raw @cudaFreeHost'_@ import.
--
-- It unwraps the 'HostPtr', erases its element type, and converts the C
-- return code to a 'Status'.
cudaFreeHost :: HostPtr a -> IO Status
cudaFreeHost hp = fmap cToEnum (cudaFreeHost'_ rawPtr)
  where
    rawPtr = castPtr (useHostPtr hp)
-- | Allocate device memory for @n@ elements of type @a@ via 'cudaMalloc'.
-- The byte count is @n * sizeOf (element)@, and the status is checked with
-- 'resultIfOk'.
mallocArray :: Storable a => Int -> IO (DevicePtr a)
mallocArray = allocFor undefined
  where
    -- The first argument only fixes the element type for 'sizeOf'.
    allocFor :: Storable e => e -> Int -> IO (DevicePtr e)
    allocFor witness !count =
      let bytes = fromIntegral count * fromIntegral (sizeOf witness)
      in  cudaMalloc bytes >>= resultIfOk
-- | Marshalling wrapper over the raw @cudaMalloc'_@ import.
--
-- It passes a null-initialised out-slot and the byte count to the C call,
-- then returns the converted 'Status' and the slot's final contents as a
-- 'DevicePtr'.
cudaMalloc :: Int64 -> IO (Status, DevicePtr a)
cudaMalloc bytes =
  F.alloca $ \ !slot -> do
    -- Pre-clear the out-parameter so it reads as null unless the call
    -- writes to it.
    poke slot nullPtr
    rc  <- cudaMalloc'_ slot (cIntConv bytes)
    raw <- peek slot
    return (cToEnum rc, castDevPtr (DevicePtr raw))
-- | Run an action with a freshly allocated device array of @n@ elements.
-- The array is released with 'free' when the action returns or throws,
-- because allocation and release are paired with 'bracket'.
allocaArray :: Storable a => Int -> (DevicePtr a -> IO b) -> IO b
allocaArray n = bracket acquire free
  where
    acquire = mallocArray n
-- | Release device memory through 'cudaFree'. The resulting status is
-- checked with 'nothingIfOk'.
free :: DevicePtr a -> IO ()
free !p = cudaFree p >>= nothingIfOk
-- | Marshalling wrapper over the raw @cudaFree'_@ import.
--
-- It erases the pointer's element type, unwraps it, and converts the C
-- return code to a 'Status'.
cudaFree :: DevicePtr a -> IO Status
cudaFree dp = fmap cToEnum (cudaFree'_ (useDevicePtr (castDevPtr dp)))
-- | Options for 'mallocManagedArray'. The 'Enum' instance assigns each
-- constructor a distinct bit (1, 2, 4), and 'cudaMallocManaged' folds a list
-- of them into one mask with 'combineBitMasks'.
data AttachFlag = Global
                | Host
                | Single
  deriving (Eq,Show)

instance Enum AttachFlag where
  fromEnum flag = case flag of
    Global -> 1
    Host   -> 2
    Single -> 4
  -- Only the three bit values above decode; anything else is an error.
  toEnum code = case code of
    1 -> Global
    2 -> Host
    4 -> Single
    _ -> error ("AttachFlag.toEnum: Cannot match " ++ show code)
-- | Allocate memory for @n@ elements of type @a@ through 'cudaMallocManaged',
-- passing the given 'AttachFlag's. The byte count is @n * sizeOf (element)@,
-- and the status is checked with 'resultIfOk'.
mallocManagedArray :: Storable a => [AttachFlag] -> Int -> IO (DevicePtr a)
mallocManagedArray !flags = allocFor undefined
  where
    -- The first argument only fixes the element type for 'sizeOf'.
    allocFor :: Storable e => e -> Int -> IO (DevicePtr e)
    allocFor witness !count =
      let bytes = fromIntegral count * fromIntegral (sizeOf witness)
      in  cudaMallocManaged bytes flags >>= resultIfOk
-- | Marshalling wrapper over the raw @cudaMallocManaged'_@ import.
--
-- It passes a null-initialised out-slot, the byte count and the combined
-- flag mask to the C call. It then returns the converted 'Status' and the
-- slot's final contents as a 'DevicePtr'.
cudaMallocManaged :: Int64 -> [AttachFlag] -> IO (Status, DevicePtr a)
cudaMallocManaged bytes flags =
  F.alloca $ \ !slot -> do
    -- Pre-clear the out-parameter so it reads as null unless the call
    -- writes to it.
    poke slot nullPtr
    rc  <- cudaMallocManaged'_ slot (cIntConv bytes) (combineBitMasks flags)
    raw <- peek slot
    return (cToEnum rc, castDevPtr (DevicePtr raw))
-- | Copy @n@ elements from device memory into a host buffer. This uses
-- 'memcpy' with direction 'DeviceToHost'.
peekArray :: Storable a => Int -> DevicePtr a -> Ptr a -> IO ()
peekArray !count !src !dst =
  memcpy dst (useDevicePtr src) count DeviceToHost
-- | Asynchronous variant of 'peekArray'. It copies @n@ elements from device
-- to host with 'memcpyAsync' on the given stream; 'memcpyAsync' substitutes
-- 'defaultStream' for 'Nothing'.
peekArrayAsync :: Storable a => Int -> DevicePtr a -> HostPtr a -> Maybe Stream -> IO ()
peekArrayAsync !count !src !dst !stream =
  memcpyAsync hostDst devSrc count DeviceToHost stream
  where
    hostDst = useHostPtr dst
    devSrc  = useDevicePtr src
-- | Read @n@ elements from device memory into a Haskell list.
--
-- The data is first copied into a temporary host buffer (from
-- 'F.allocaArray') and then read from that buffer.
peekListArray :: Storable a => Int -> DevicePtr a -> IO [a]
peekListArray !n !dptr = F.allocaArray n download
  where
    download buf = peekArray n dptr buf >> F.peekArray n buf
-- | Copy @n@ elements from a host buffer into device memory. This uses
-- 'memcpy' with direction 'HostToDevice'.
pokeArray :: Storable a => Int -> Ptr a -> DevicePtr a -> IO ()
pokeArray !count !src !dst =
  memcpy (useDevicePtr dst) src count HostToDevice
-- | Asynchronous variant of 'pokeArray'. It copies @n@ elements from host to
-- device with 'memcpyAsync' on the given stream; 'memcpyAsync' substitutes
-- 'defaultStream' for 'Nothing'.
pokeArrayAsync :: Storable a => Int -> HostPtr a -> DevicePtr a -> Maybe Stream -> IO ()
pokeArrayAsync !count !src !dst !stream =
  memcpyAsync devDst hostSrc count HostToDevice stream
  where
    devDst  = useDevicePtr dst
    hostSrc = useHostPtr src
-- | Write a Haskell list into device memory at @dptr@.
--
-- The list is marshalled into a temporary host array (with its length) and
-- then copied with 'pokeArray'.
pokeListArray :: Storable a => [a] -> DevicePtr a -> IO ()
pokeListArray !xs !dptr = F.withArrayLen xs upload
  where
    upload count staging = pokeArray count staging dptr
-- | Copy @n@ elements from one device array (@src@) to another (@dst@). This
-- uses 'memcpy' with direction 'DeviceToDevice'.
copyArray :: Storable a => Int -> DevicePtr a -> DevicePtr a -> IO ()
copyArray !count !src !dst =
  memcpy to from count DeviceToDevice
  where
    to   = useDevicePtr dst
    from = useDevicePtr src
-- | Asynchronous variant of 'copyArray'. It copies @n@ elements between
-- device arrays with 'memcpyAsync' on the given stream; 'memcpyAsync'
-- substitutes 'defaultStream' for 'Nothing'.
copyArrayAsync :: Storable a => Int -> DevicePtr a -> DevicePtr a -> Maybe Stream -> IO ()
copyArrayAsync !count !src !dst !stream =
  memcpyAsync to from count DeviceToDevice stream
  where
    to   = useDevicePtr dst
    from = useDevicePtr src
-- | Direction of a copy, passed to the C runtime as the integer produced by
-- 'fromEnum' (0 through 4, in constructor order).
data CopyDirection = HostToHost
                   | HostToDevice
                   | DeviceToHost
                   | DeviceToDevice
                   | Default
  deriving (Eq,Show)

instance Enum CopyDirection where
  fromEnum dir = case dir of
    HostToHost     -> 0
    HostToDevice   -> 1
    DeviceToHost   -> 2
    DeviceToDevice -> 3
    Default        -> 4
  -- Codes outside 0..4 have no constructor and are rejected.
  toEnum code = case code of
    0 -> HostToHost
    1 -> HostToDevice
    2 -> DeviceToHost
    3 -> DeviceToDevice
    4 -> Default
    _ -> error ("CopyDirection.toEnum: Cannot match " ++ show code)
-- | Copy @n@ elements of type @a@ from @src@ to @dst@ in the given direction
-- via 'cudaMemcpy'.
--
-- The byte count passed to the runtime is @n * sizeOf (element)@. The
-- resulting status is checked with 'nothingIfOk'.
memcpy :: Storable a
       => Ptr a          -- ^ destination
       -> Ptr a          -- ^ source
       -> Int            -- ^ number of elements
       -> CopyDirection  -- ^ direction of the copy
       -> IO ()
memcpy !dst !src !n !dir =
  cudaMemcpy dst src byteCount dir >>= nothingIfOk
  where
    byteCount = fromIntegral n * fromIntegral (sizeOf (pointee dst))
    -- Type-level helper: names the element type of a pointer for 'sizeOf'.
    pointee :: Ptr e -> e
    pointee _ = undefined
-- | Marshalling wrapper over the raw @cudaMemcpy'_@ import.
--
-- It does the following:
--
-- * erases both pointer types;
-- * converts the byte count and direction to their C representations;
-- * maps the C return code to a 'Status'.
cudaMemcpy :: Ptr a -> Ptr a -> Int64 -> CopyDirection -> IO Status
cudaMemcpy dst src bytes dir =
  fmap cToEnum $
    cudaMemcpy'_ (castPtr dst) (castPtr src) (cIntConv bytes) (cFromEnum dir)
-- | Asynchronous counterpart of 'memcpy'. It copies @n@ elements from @src@
-- to @dst@ via 'cudaMemcpyAsync'.
--
-- The copy is issued on the given stream, or on 'defaultStream' when
-- 'Nothing' is supplied. The status is checked with 'nothingIfOk'.
memcpyAsync :: Storable a
            => Ptr a          -- ^ destination
            -> Ptr a          -- ^ source
            -> Int            -- ^ number of elements
            -> CopyDirection  -- ^ direction of the copy
            -> Maybe Stream   -- ^ stream, or 'Nothing' for 'defaultStream'
            -> IO ()
memcpyAsync !dst !src !n !kind !mst =
  cudaMemcpyAsync dst src byteCount kind stream >>= nothingIfOk
  where
    byteCount = fromIntegral n * fromIntegral (sizeOf (pointee dst))
    stream    = case mst of
                  Just st -> st
                  Nothing -> defaultStream
    -- Type-level helper: names the element type of a pointer for 'sizeOf'.
    pointee :: Ptr e -> e
    pointee _ = undefined
-- | Marshalling wrapper over the raw @cudaMemcpyAsync'_@ import.
--
-- It does the following:
--
-- * erases both pointer types;
-- * converts the byte count and direction to their C representations;
-- * unwraps the 'Stream' with 'useStream';
-- * maps the C return code to a 'Status'.
cudaMemcpyAsync :: Ptr a -> Ptr a -> Int64 -> CopyDirection -> Stream -> IO Status
cudaMemcpyAsync dst src bytes kind st =
  fmap cToEnum $
    cudaMemcpyAsync'_ (castPtr dst) (castPtr src) (cIntConv bytes)
                      (cFromEnum kind) (useStream st)
-- | Allocate a device array sized to the list, copy the list into it, and
-- return the device pointer together with the element count.
--
-- If the copy throws, the allocation is released with 'free'
-- ('bracketOnError'). On success, the caller owns the memory.
newListArrayLen :: Storable a => [a] -> IO (DevicePtr a, Int)
newListArrayLen !xs = F.withArrayLen xs upload
  where
    upload count staging =
      bracketOnError (mallocArray count) free $ \dev ->
        pokeArray count staging dev >> return (dev, count)
-- | Like 'newListArrayLen', but returns only the device pointer.
newListArray :: Storable a => [a] -> IO (DevicePtr a)
newListArray !xs = fmap fst (newListArrayLen xs)
-- | Run an action on a temporary device copy of the list. Like
-- 'withListArrayLen', but the action does not receive the length.
withListArray :: Storable a => [a] -> (DevicePtr a -> IO b) -> IO b
withListArray !xs = \action -> withListArrayLen xs (\_len dptr -> action dptr)
-- | Copy the list to a fresh device array, run the action with the element
-- count and device pointer, then 'free' the array.
--
-- The array is freed even if the action throws ('bracket').
withListArrayLen :: Storable a => [a] -> (Int -> DevicePtr a -> IO b) -> IO b
withListArrayLen !xs !f =
  bracket (newListArrayLen xs) release use
  where
    release (dptr, _)   = free dptr
    use     (dptr, len) = f len dptr
-- | Set @bytes@ bytes starting at @dptr@ to the value @symbol@ via
-- 'cudaMemset'. The resulting status is checked with 'nothingIfOk'.
memset :: DevicePtr a  -- ^ start of the region
       -> Int64        -- ^ number of bytes to set
       -> Int8         -- ^ value to store
       -> IO ()
memset !dptr !bytes !symbol = cudaMemset dptr symbol bytes >>= nothingIfOk
-- | Marshalling wrapper over the raw @cudaMemset'_@ import.
--
-- It unwraps the untyped pointer, widens the value and byte count to their
-- C types, and maps the C return code to a 'Status'.
cudaMemset :: DevicePtr a -> Int8 -> Int64 -> IO Status
cudaMemset dp value count =
  fmap cToEnum $
    cudaMemset'_ (useDevicePtr (castDevPtr dp)) (cIntConv value) (cIntConv count)
-- Raw C entry points used by the marshalling wrappers in this module.
-- Every call returns a C status code that the wrappers convert with
-- 'cToEnum'.
--
-- NOTE(review): all imports are @unsafe@. In GHC an unsafe foreign call
-- cannot be interrupted and blocks other Haskell threads on the same
-- capability until it returns, which may matter for potentially long calls
-- such as @cudaMemcpy@. Confirm this is intended.
--
-- NOTE(review): byte counts are typed 'CULong', presumably matching the C
-- @size_t@ of the platform this file was generated on. Verify on targets
-- where the two differ.

-- Out-slot for the pointer, byte count, flag mask.
foreign import ccall unsafe "Foreign/CUDA/Runtime/Marshal.chs.h cudaHostAlloc"
  cudaHostAlloc'_ :: Ptr (Ptr ()) -> CULong -> CUInt -> IO CInt

-- Pointer to release.
foreign import ccall unsafe "Foreign/CUDA/Runtime/Marshal.chs.h cudaFreeHost"
  cudaFreeHost'_ :: Ptr () -> IO CInt

-- Out-slot for the pointer, byte count.
foreign import ccall unsafe "Foreign/CUDA/Runtime/Marshal.chs.h cudaMalloc"
  cudaMalloc'_ :: Ptr (Ptr ()) -> CULong -> IO CInt

-- Pointer to release.
foreign import ccall unsafe "Foreign/CUDA/Runtime/Marshal.chs.h cudaFree"
  cudaFree'_ :: Ptr () -> IO CInt

-- Out-slot for the pointer, byte count, flag mask.
foreign import ccall unsafe "Foreign/CUDA/Runtime/Marshal.chs.h cudaMallocManaged"
  cudaMallocManaged'_ :: Ptr (Ptr ()) -> CULong -> CUInt -> IO CInt

-- Destination, source, byte count, copy-direction code.
foreign import ccall unsafe "Foreign/CUDA/Runtime/Marshal.chs.h cudaMemcpy"
  cudaMemcpy'_ :: Ptr () -> Ptr () -> CULong -> CInt -> IO CInt

-- Destination, source, byte count, copy-direction code, stream handle.
foreign import ccall unsafe "Foreign/CUDA/Runtime/Marshal.chs.h cudaMemcpyAsync"
  cudaMemcpyAsync'_ :: Ptr () -> Ptr () -> CULong -> CInt -> Ptr () -> IO CInt

-- Destination, value, byte count.
foreign import ccall unsafe "Foreign/CUDA/Runtime/Marshal.chs.h cudaMemset"
  cudaMemset'_ :: Ptr () -> CInt -> CULong -> IO CInt