{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE FlexibleInstances   #-}
{-# LANGUAGE MagicHash           #-}
{-# LANGUAGE RecordWildCards     #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications    #-}
{-# LANGUAGE TypeFamilies        #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
-- |
-- Module      : Data.Array.Accelerate.LLVM.PTX.Array.Remote
-- Copyright   : [2014..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.LLVM.PTX.Array.Remote (

  withRemote, malloc,

) where

import Data.Array.Accelerate.LLVM.State
import Data.Array.Accelerate.LLVM.PTX.Target
import {-# SOURCE #-} Data.Array.Accelerate.LLVM.PTX.Execute.Event
import {-# SOURCE #-} Data.Array.Accelerate.LLVM.PTX.Execute.Stream

import Data.Array.Accelerate.Array.Data
import Data.Array.Accelerate.Array.Unique
import Data.Array.Accelerate.Lifetime
import Data.Array.Accelerate.Representation.Elt
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Type
import qualified Data.Array.Accelerate.Array.Remote                     as Remote
import qualified Data.Array.Accelerate.LLVM.PTX.Debug                   as Debug

import Foreign.CUDA.Driver.Error
import qualified Foreign.CUDA.Ptr                                       as CUDA
import qualified Foreign.CUDA.Driver                                    as CUDA
import qualified Foreign.CUDA.Driver.Stream                             as CUDA

import Control.Exception
import Control.Monad.State
import Text.Printf

import GHC.Base
import GHC.Int


-- Events signal once a computation has completed.
--
-- A task with no associated event ('Nothing') is reported as already
-- complete; a task carrying an event is complete when 'query' on that event
-- says so.
--
instance Remote.Task (Maybe Event) where
  completed Nothing  = return True
  completed (Just e) = query e

-- Remote memory management for the PTX target. Remote pointers are CUDA
-- device pointers, and all operations are forwarded to the CUDA driver
-- bindings.
--
instance Remote.RemoteMemory (LLVM PTX) where
  type RemotePtr (LLVM PTX) = CUDA.DevicePtr

  -- Allocate 'n' elements of 'Word8' on the device.
  --
  --  * A non-positive request returns the null device pointer without
  --    calling into CUDA.
  --  * An out-of-memory failure is reported as 'Nothing' instead of an
  --    exception.
  --  * Any other CUDA exception is logged via 'message' and re-thrown.
  --
  mallocRemote n
    | n <= 0    = return (Just CUDA.nullDevPtr)
    | otherwise = liftIO $ do
        ep <- try (CUDA.mallocArray n)
        case ep of
          Right p                     -> do Debug.didAllocateBytesRemote (i64 n)
                                            return (Just p)
          Left (ExitCode OutOfMemory) -> return Nothing
          Left e                      -> do message ("malloc failed with error: " ++ show e)
                                            throwIO e

  -- Copy 'n' elements from the device array 'src' into the host array 'ad'.
  -- The copy is issued asynchronously on a stream, but 'blocking' waits for
  -- it to finish before returning.
  --
  peekRemote t n src ad
    | SingleArrayDict <- singleArrayDict t
    , SingleDict      <- singleDict t
    = let -- Size of the transfer in bytes; only used for statistics/logging.
          -- The copy itself is expressed in elements.
          bytes = n * bytesElt (TupRsingle (SingleScalarType t))
          dst   = CUDA.HostPtr (unsafeUniqueArrayPtr ad)
      in
      blocking            $ \stream ->
      withLifetime stream $ \st     -> do
        Debug.didCopyBytesFromRemote (i64 bytes)
        transfer "peekRemote" bytes (Just st) $ CUDA.peekArrayAsync n src dst (Just st)

  -- Copy 'n' elements from the host array 'ad' into the device array 'dst'.
  -- Mirrors 'peekRemote': asynchronous copy on a stream, waited on by
  -- 'blocking'.
  --
  pokeRemote t n dst ad
    | SingleArrayDict <- singleArrayDict t
    , SingleDict      <- singleDict t
    = let -- Size of the transfer in bytes; only used for statistics/logging.
          bytes = n * bytesElt (TupRsingle (SingleScalarType t))
          src   = CUDA.HostPtr (unsafeUniqueArrayPtr ad)
      in
      blocking            $ \stream ->
      withLifetime stream $ \st     -> do
        Debug.didCopyBytesToRemote (i64 bytes)
        transfer "pokeRemote" bytes (Just st) $ CUDA.pokeArrayAsync n src dst (Just st)

  -- Reinterpret the element type of a device pointer.
  castRemotePtr        = CUDA.castDevPtr

  -- Both memory queries come from the same CUDA call, which returns a pair:
  -- the first component is used as the available amount and the second as
  -- the total.
  availableRemoteMem   = liftIO (fst <$> CUDA.getMemInfo)
  totalRemoteMem       = liftIO (snd <$> CUDA.getMemInfo)

  -- Fixed allocation granularity (presumably in bytes -- confirm against the
  -- 'Remote.RemoteMemory' class documentation).
  remoteAllocationSize = return 4096



-- | Allocate an array in the remote memory space sufficient to hold the given
-- number of elements, and associated with the given host side array. Space will
-- be freed from the remote device if necessary.
--
-- This is a thin wrapper around 'Remote.malloc' that supplies the memory
-- table of the current PTX target. Note that the last two arguments are
-- passed on in the opposite order: this function takes the element count
-- before the \"frozen\" flag, whereas 'Remote.malloc' takes the flag first.
--
{-# INLINEABLE malloc #-}
malloc
    :: SingleType e
    -> ArrayData e
    -> Int
    -> Bool
    -> LLVM PTX Bool
malloc !tp !ad !n !frozen = do
  PTX{..} <- gets llvmTarget
  Remote.malloc ptxMemoryTable tp ad frozen n


-- | Look up the remote array pointer for the given host-side array and run
-- the supplied continuation on it.
--
-- The continuation returns an optional 'Event' together with its result; the
-- event is handed to 'Remote.withRemote' as the task associated with this use
-- of the array. The overall result is wrapped in 'Maybe' exactly as returned
-- by 'Remote.withRemote' (presumably 'Nothing' when the host array has no
-- remote counterpart -- confirm against that function's documentation).
--
{-# INLINEABLE withRemote #-}
withRemote
    :: SingleType e
    -> ArrayData e
    -> (CUDA.DevicePtr (ScalarArrayDataR e) -> LLVM PTX (Maybe Event, r))
    -> LLVM PTX (Maybe r)
withRemote !tp !ad !f = do
  PTX{..} <- gets llvmTarget
  Remote.withRemote ptxMemoryTable tp ad f


-- Auxiliary
-- ---------

-- | Execute the given operation in a new stream, and wait for the operation to
-- complete before returning.
--
-- The operation is run through 'streaming'; its continuation receives the
-- event associated with the work together with the operation's result, blocks
-- on that event, and only then yields the result.
--
{-# INLINE blocking #-}
blocking :: (Stream -> IO a) -> LLVM PTX a
blocking !fun =
  streaming (liftIO . fun) $ \e r -> do
    liftIO $ block e
    return r

-- Convert an 'Int' to an 'Int64'.
--
-- Written with 'fromIntegral' rather than by rewrapping the unboxed payload
-- with the @I64#@ constructor: the latter depends on the internal
-- representation of 'Int64', which is not the same across GHC releases.
--
{-# INLINE i64 #-}
i64 :: Int -> Int64
i64 = fromIntegral

-- Convert an 'Int' to a 'Double' by unboxing the argument and applying the
-- 'int2Double#' primitive directly.
--
{-# INLINE double #-}
double :: Int -> Double
double (I# i#) = D# (int2Double# i#)


-- Debugging
-- ---------

-- Render a byte count for debug output. Formatting is delegated to
-- 'Debug.showFFloatSIBase' with a precision argument of @Just 0@, a base of
-- 1024 and the unit suffix @"B"@.
--
{-# INLINE showBytes #-}
showBytes :: Int -> String
showBytes x = Debug.showFFloatSIBase (Just 0) 1024 (double x) "B"

-- Emit a debug message under the 'Debug.dump_gc' flag, prefixed with
-- @"gc: "@, and then run the given action, returning its result.
--
{-# INLINE trace #-}
trace :: String -> IO a -> IO a
trace msg next = Debug.traceIO Debug.dump_gc ("gc: " ++ msg) >> next

-- Emit a debug message via 'trace' without running any further action.
--
{-# INLINE message #-}
message :: String -> IO ()
message s = s `trace` return ()

-- Run a memory transfer under 'Debug.timed' with the 'Debug.dump_gc' flag,
-- supplying a function that formats the log line from the three timing
-- values it is given.
--
--  * @name@   - label for the operation, printed at the start of the line
--  * @bytes@  - size of the transfer, used for the size and rate fields
--  * @stream@ - stream passed through to 'Debug.timed'
--  * @action@ - the transfer itself
--
{-# INLINE transfer #-}
transfer :: String -> Int -> Maybe CUDA.Stream -> IO () -> IO ()
transfer name bytes stream action
  = let -- Transfer rate: size divided by the elapsed time it is given.
        showRate x t      = Debug.showFFloatSIBase (Just 3) 1024 (double x / t) "B/s"
        msg wall cpu gpu  = printf "gc: %s: %s bytes @ %s, %s"
                              name
                              (showBytes bytes)
                              (showRate bytes wall)
                              (Debug.elapsed wall cpu gpu)
    in
    Debug.timed Debug.dump_gc msg stream action