{-# LANGUAGE BangPatterns #-}
--------------------------------------------------------------------------------
-- |
-- Module    : Foreign.CUDA.Ptr
-- Copyright : [2009..2023] Trevor L. McDonell
-- License   : BSD
--
-- Data pointers on the host and device. These can be shared freely between the
-- CUDA runtime and Driver APIs.
--
--------------------------------------------------------------------------------

module Foreign.CUDA.Ptr (

  -- * Device pointers
  DevicePtr(..),
  withDevicePtr,
  devPtrToWordPtr,
  wordPtrToDevPtr,
  nullDevPtr,
  castDevPtr,
  plusDevPtr,
  alignDevPtr,
  minusDevPtr,
  advanceDevPtr,

  -- * Host pointers
  HostPtr(..),
  withHostPtr,
  nullHostPtr,
  castHostPtr,
  plusHostPtr,
  alignHostPtr,
  minusHostPtr,
  advanceHostPtr,

) where

import Foreign.Ptr
import Foreign.Storable


--------------------------------------------------------------------------------
-- Device Pointer
--------------------------------------------------------------------------------

-- |
-- A reference to data stored on the device.
--
-- This is a @newtype@ wrapper around a plain 'Ptr'; the wrapped address is
-- recovered with 'useDevicePtr'. The 'Eq' and 'Ord' instances are derived, so
-- two device pointers compare exactly as the addresses they wrap.
--
newtype DevicePtr a = DevicePtr { useDevicePtr :: Ptr a }
  deriving (Eq, Ord)

-- |
-- A device pointer is shown exactly like the 'Ptr' it wraps; the constructor
-- name does not appear in the output.
--
instance Show (DevicePtr a) where
  showsPrec n (DevicePtr p) = showsPrec n p

-- |
-- A device pointer is marshalled exactly like the 'Ptr' it wraps: it has the
-- same size and alignment, and reading or writing one simply reads or writes
-- the underlying address.
--
instance Storable (DevicePtr a) where
  -- 'undefined' is only a type witness selecting the 'Ptr' instance; 'sizeOf'
  -- and 'alignment' are specified not to use the value of their argument.
  sizeOf _    = sizeOf    (undefined :: Ptr a)
  alignment _ = alignment (undefined :: Ptr a)
  -- Reinterpret the location as holding a bare 'Ptr', then (un)wrap.
  peek p      = DevicePtr `fmap` peek (castPtr p)
  poke p v    = poke (castPtr p) (useDevicePtr v)


-- |
-- Look at the contents of device memory. The supplied IO action is applied to
-- the raw 'Ptr' wrapped by the device pointer, and the result of that action
-- is returned. It would be silly to return the pointer itself from the action.
--
-- Both arguments are evaluated to weak head normal form before the action is
-- run.
--
{-# INLINEABLE withDevicePtr #-}
withDevicePtr :: DevicePtr a -> (Ptr a -> IO b) -> IO b
withDevicePtr !p !f = f (useDevicePtr p)

-- |
-- Return a unique handle associated with the given device pointer. The handle
-- is the wrapped address converted to a 'WordPtr' with 'ptrToWordPtr'.
--
{-# INLINEABLE devPtrToWordPtr #-}
devPtrToWordPtr :: DevicePtr a -> WordPtr
devPtrToWordPtr = ptrToWordPtr . useDevicePtr

-- |
-- Return a device pointer from the given handle. The handle is converted back
-- to an address with 'wordPtrToPtr' and wrapped as a 'DevicePtr'; this is the
-- inverse of 'devPtrToWordPtr'.
--
{-# INLINEABLE wordPtrToDevPtr #-}
wordPtrToDevPtr :: WordPtr -> DevicePtr a
wordPtrToDevPtr = DevicePtr . wordPtrToPtr

-- |
-- The constant 'nullDevPtr' is the distinguished device pointer that is not
-- associated with any valid memory location. It wraps 'nullPtr'.
--
{-# INLINEABLE nullDevPtr #-}
nullDevPtr :: DevicePtr a
nullDevPtr = DevicePtr nullPtr

-- |
-- Cast a device pointer from one type to another. Only the phantom element
-- type changes; the wrapped address is passed through 'castPtr' unchanged.
--
{-# INLINEABLE castDevPtr #-}
castDevPtr :: DevicePtr a -> DevicePtr b
castDevPtr (DevicePtr !p) = DevicePtr (castPtr p)

-- |
-- Advance the pointer address by the given offset in bytes. The offset is
-- applied with 'plusPtr'; to step by a number of elements instead, use
-- 'advanceDevPtr'.
--
{-# INLINEABLE plusDevPtr #-}
plusDevPtr :: DevicePtr a -> Int -> DevicePtr a
plusDevPtr (DevicePtr !p) !d = DevicePtr (p `plusPtr` d)

-- |
-- Given an alignment constraint, align the device pointer to the next highest
-- address satisfying the constraint. The second argument is the alignment
-- constraint, which is passed straight to 'alignPtr'.
--
{-# INLINEABLE alignDevPtr #-}
alignDevPtr :: DevicePtr a -> Int -> DevicePtr a
alignDevPtr (DevicePtr !p) !i = DevicePtr (p `alignPtr` i)

-- |
-- Compute the offset, in bytes, required to get from the second argument to
-- the first; that is, the first address minus the second. This fulfils the
-- relation
--
-- > p2 == p1 `plusDevPtr` (p2 `minusDevPtr` p1)
--
{-# INLINEABLE minusDevPtr #-}
minusDevPtr :: DevicePtr a -> DevicePtr a -> Int
minusDevPtr (DevicePtr !a) (DevicePtr !b) = a `minusPtr` b

-- |
-- Advance a pointer into a device array by the given number of elements. The
-- byte offset is the element count multiplied by the 'sizeOf' the element
-- type, and is applied with 'plusDevPtr'.
--
{-# INLINEABLE advanceDevPtr #-}
advanceDevPtr :: Storable a => DevicePtr a -> Int -> DevicePtr a
advanceDevPtr = doAdvance undefined
  where
    -- The first argument is only a type witness tying 'sizeOf' to the element
    -- type of the pointer; it is handed to 'sizeOf' and nothing else.
    doAdvance :: Storable a' => a' -> DevicePtr a' -> Int -> DevicePtr a'
    doAdvance x !p !i = p `plusDevPtr` (i * sizeOf x)


--------------------------------------------------------------------------------
-- Host Pointer
--------------------------------------------------------------------------------

-- |
-- A reference to page-locked host memory.
--
-- A 'HostPtr' is just a plain 'Ptr', but the memory has been allocated by CUDA
-- into page-locked memory. This means that the data can be copied to the GPU
-- via DMA (direct memory access). Note that the use of the system function
-- @mlock@ is not sufficient here --- the CUDA version ensures that the
-- /physical/ address stays the same, not just the virtual address.
--
-- To copy data into a 'HostPtr' array, you may use for example 'withHostPtr'
-- together with 'Foreign.Marshal.Array.copyArray' or
-- 'Foreign.Marshal.Array.moveArray'.
--
-- The 'Eq' and 'Ord' instances are derived, so two host pointers compare
-- exactly as the addresses they wrap.
--
newtype HostPtr a = HostPtr { useHostPtr :: Ptr a }
  deriving (Eq, Ord)

-- |
-- A host pointer is shown exactly like the 'Ptr' it wraps; the constructor
-- name does not appear in the output.
--
instance Show (HostPtr a) where
  showsPrec n (HostPtr p) = showsPrec n p

-- |
-- A host pointer is marshalled exactly like the 'Ptr' it wraps: it has the
-- same size and alignment, and reading or writing one simply reads or writes
-- the underlying address.
--
instance Storable (HostPtr a) where
  -- 'undefined' is only a type witness selecting the 'Ptr' instance; 'sizeOf'
  -- and 'alignment' are specified not to use the value of their argument.
  sizeOf _    = sizeOf    (undefined :: Ptr a)
  alignment _ = alignment (undefined :: Ptr a)
  -- Reinterpret the location as holding a bare 'Ptr', then (un)wrap.
  peek p      = HostPtr `fmap` peek (castPtr p)
  poke p v    = poke (castPtr p) (useHostPtr v)


-- |
-- Apply an IO action to the memory reference living inside the host pointer
-- object, returning the result of that action. All uses of the pointer should
-- be inside the 'withHostPtr' bracket.
--
-- Both arguments are evaluated to weak head normal form before the action is
-- run.
--
{-# INLINEABLE withHostPtr #-}
withHostPtr :: HostPtr a -> (Ptr a -> IO b) -> IO b
withHostPtr !p !f = f (useHostPtr p)


-- |
-- The constant 'nullHostPtr' is the distinguished host pointer that is not
-- associated with any valid memory location. It wraps 'nullPtr'.
--
{-# INLINEABLE nullHostPtr #-}
nullHostPtr :: HostPtr a
nullHostPtr = HostPtr nullPtr

-- |
-- Cast a host pointer from one type to another. Only the phantom element type
-- changes; the wrapped address is passed through 'castPtr' unchanged.
--
{-# INLINEABLE castHostPtr #-}
castHostPtr :: HostPtr a -> HostPtr b
castHostPtr (HostPtr !p) = HostPtr (castPtr p)

-- |
-- Advance the pointer address by the given offset in bytes. The offset is
-- applied with 'plusPtr'; to step by a number of elements instead, use
-- 'advanceHostPtr'.
--
{-# INLINEABLE plusHostPtr #-}
plusHostPtr :: HostPtr a -> Int -> HostPtr a
plusHostPtr (HostPtr !p) !d = HostPtr (p `plusPtr` d)

-- |
-- Given an alignment constraint, align the host pointer to the next highest
-- address satisfying the constraint. The second argument is the alignment
-- constraint, which is passed straight to 'alignPtr'.
--
{-# INLINEABLE alignHostPtr #-}
alignHostPtr :: HostPtr a -> Int -> HostPtr a
alignHostPtr (HostPtr !p) !i = HostPtr (p `alignPtr` i)

-- |
-- Compute the offset, in bytes, required to get from the second argument to
-- the first; that is, the first address minus the second. This fulfils the
-- relation
--
-- > p2 == p1 `plusHostPtr` (p2 `minusHostPtr` p1)
--
{-# INLINEABLE minusHostPtr #-}
minusHostPtr :: HostPtr a -> HostPtr a -> Int
minusHostPtr (HostPtr !a) (HostPtr !b) = a `minusPtr` b

-- |
-- Advance a pointer into a host array by a given number of elements. The byte
-- offset is the element count multiplied by the 'sizeOf' the element type, and
-- is applied with 'plusHostPtr'.
--
{-# INLINEABLE advanceHostPtr #-}
advanceHostPtr :: Storable a => HostPtr a -> Int -> HostPtr a
advanceHostPtr = doAdvance undefined
  where
    -- The first argument is only a type witness tying 'sizeOf' to the element
    -- type of the pointer; it is handed to 'sizeOf' and nothing else.
    doAdvance :: Storable a' => a' -> HostPtr a' -> Int -> HostPtr a'
    doAdvance x !p !i = p `plusHostPtr` (i * sizeOf x)