{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Data.Array.Accelerate.LLVM.PTX.Array.Remote (
withRemote, malloc,
) where
import Data.Array.Accelerate.LLVM.State
import Data.Array.Accelerate.LLVM.PTX.Target
import {-# SOURCE #-} Data.Array.Accelerate.LLVM.PTX.Execute.Event
import {-# SOURCE #-} Data.Array.Accelerate.LLVM.PTX.Execute.Stream
import Data.Array.Accelerate.Array.Data
import Data.Array.Accelerate.Array.Unique
import Data.Array.Accelerate.Lifetime
import Data.Array.Accelerate.Representation.Elt
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Type
import qualified Data.Array.Accelerate.Array.Remote as Remote
import qualified Data.Array.Accelerate.LLVM.PTX.Debug as Debug
import Foreign.CUDA.Driver.Error
import qualified Foreign.CUDA.Ptr as CUDA
import qualified Foreign.CUDA.Driver as CUDA
import qualified Foreign.CUDA.Driver.Stream as CUDA
import Control.Exception
import Control.Monad.State
import Text.Printf
import GHC.Base
import GHC.Int
-- | An optional 'Event' serves as the completion token for remote-memory
-- tasks: with no event there is nothing to wait on, so the task counts as
-- finished; otherwise the event itself is queried.
instance Remote.Task (Maybe Event) where
  completed = maybe (return True) query
-- | Remote (device) memory operations for the PTX target, implemented on top
-- of the CUDA driver bindings.
instance Remote.RemoteMemory (LLVM PTX) where
  type RemotePtr (LLVM PTX) = CUDA.DevicePtr

  -- Allocate @n@ 'Word8' elements on the device.
  --
  --  * A non-positive request returns the null device pointer without
  --    calling into CUDA.
  --  * An out-of-memory failure is reported as 'Nothing'.
  --  * Any other 'CUDAException' is traced and then re-thrown.
  mallocRemote n
    | n <= 0    = return (Just CUDA.nullDevPtr)
    | otherwise = liftIO (try (CUDA.mallocArray n) >>= allocated)
    where
      allocated (Right ptr) = do
        Debug.didAllocateBytesRemote (i64 n)
        return (Just ptr)
      allocated (Left (ExitCode OutOfMemory)) =
        return Nothing
      allocated (Left err) = do
        message ("malloc failed with error: " ++ show err)
        throwIO err

  -- Copy @n@ elements from the device array @src@ into the host array @ad@.
  -- The copy is issued asynchronously on a stream, and 'blocking' waits for
  -- it before returning.
  --
  -- The local bindings must stay inside the guarded right-hand side: they
  -- rely on the type evidence brought into scope by the two pattern guards.
  peekRemote t n src ad
    | SingleArrayDict <- singleArrayDict t
    , SingleDict      <- singleDict t
    = blocking $ \stream ->
        withLifetime stream $ \st -> do
          let nbytes  = n * bytesElt (TupRsingle (SingleScalarType t))
              hostDst = CUDA.HostPtr (unsafeUniqueArrayPtr ad)
              copy    = CUDA.peekArrayAsync n src hostDst (Just st)
          Debug.didCopyBytesFromRemote (i64 nbytes)
          transfer "peekRemote" nbytes (Just st) copy

  -- Copy @n@ elements from the host array @ad@ into the device array @dst@.
  -- Mirrors 'peekRemote' with the transfer direction reversed.
  pokeRemote t n dst ad
    | SingleArrayDict <- singleArrayDict t
    , SingleDict      <- singleDict t
    = blocking $ \stream ->
        withLifetime stream $ \st -> do
          let nbytes  = n * bytesElt (TupRsingle (SingleScalarType t))
              hostSrc = CUDA.HostPtr (unsafeUniqueArrayPtr ad)
              copy    = CUDA.pokeArrayAsync n hostSrc dst (Just st)
          Debug.didCopyBytesToRemote (i64 nbytes)
          transfer "pokeRemote" nbytes (Just st) copy

  -- Reinterpret the element type of a device pointer.
  castRemotePtr = CUDA.castDevPtr

  -- 'CUDA.getMemInfo' yields a pair; its first component is reported as the
  -- available device memory and its second as the total.
  availableRemoteMem = liftIO (fst <$> CUDA.getMemInfo)
  totalRemoteMem     = liftIO (snd <$> CUDA.getMemInfo)

  -- Fixed value of 4096 for this backend.
  remoteAllocationSize = return 4096
{-# INLINEABLE malloc #-}
-- | Allocate device memory for the given host array by delegating to
-- 'Remote.malloc' with the memory table of the current PTX target.
--
-- Note the argument order: this wrapper takes the element count before the
-- @frozen@ flag, whereas 'Remote.malloc' takes the flag first. The 'Bool'
-- result is whatever 'Remote.malloc' returns.
malloc
    :: SingleType e
    -> ArrayData e
    -> Int
    -> Bool
    -> LLVM PTX Bool
malloc !tp !ad !n !frozen = do
  PTX{ptxMemoryTable = memTable} <- gets llvmTarget
  Remote.malloc memTable tp ad frozen n
{-# INLINEABLE withRemote #-}
-- | Run an action that needs the device pointer associated with a host
-- array, by delegating to 'Remote.withRemote' with the memory table of the
-- current PTX target.
--
-- The continuation receives the device pointer and returns an optional
-- 'Event' together with its result. The 'Maybe' wrapping the final result
-- comes straight from 'Remote.withRemote'.
withRemote
    :: SingleType e
    -> ArrayData e
    -> (CUDA.DevicePtr (ScalarArrayDataR e) -> LLVM PTX (Maybe Event, r))
    -> LLVM PTX (Maybe r)
withRemote !tp !ad !f = do
  PTX{ptxMemoryTable = memTable} <- gets llvmTarget
  Remote.withRemote memTable tp ad f
{-# INLINE blocking #-}
-- | Run an 'IO' operation on a stream provided by 'streaming', then block on
-- the event handed to the continuation before returning the operation's
-- result.
blocking :: (Stream -> IO a) -> LLVM PTX a
blocking !fun =
  streaming
    (\stream -> liftIO (fun stream))
    (\done result -> liftIO (block done) >> return result)
{-# INLINE i64 #-}
-- | Convert an 'Int' to an 'Int64'.
--
-- This goes through 'fromIntegral' rather than unwrapping with @I#@ and
-- rewrapping with @I64#@: the field type of the @I64#@ constructor differs
-- between GHC releases (@Int#@ before 9.4, @Int64#@ from 9.4 on), so the
-- manual re-boxing does not type-check on newer compilers.
i64 :: Int -> Int64
i64 = fromIntegral
{-# INLINE double #-}
-- | Convert an 'Int' to the corresponding 'Double'.
double :: Int -> Double
double = fromIntegral
{-# INLINE showBytes #-}
-- | Render a byte count as a string via 'Debug.showFFloatSIBase', using the
-- suffix @"B"@.
showBytes :: Int -> String
showBytes nbytes = Debug.showFFloatSIBase digits base (double nbytes) "B"
  where
    -- NOTE(review): presumably the number of decimal digits and the scaling
    -- base for the unit prefix; confirm against 'Debug.showFFloatSIBase'.
    digits = Just 0
    base   = 1024
{-# INLINE trace #-}
-- | Emit a message prefixed with @"gc: "@ through 'Debug.traceIO' under the
-- 'Debug.dump_gc' flag, then run the given action and return its result.
trace :: String -> IO a -> IO a
trace msg next = do
  Debug.traceIO Debug.dump_gc ("gc: " ++ msg)
  next
{-# INLINE message #-}
-- | Emit a trace message (see 'trace') without running any further action.
message :: String -> IO ()
message s = trace s (return ())
{-# INLINE transfer #-}
-- | Run a memory-transfer action under 'Debug.timed' with the
-- 'Debug.dump_gc' flag, supplying a report that names the operation, the
-- amount of data moved, the transfer rate, and the elapsed times.
--
-- The rate is computed from the byte count and the first (wall) timing
-- passed to the report function.
transfer :: String -> Int -> Maybe CUDA.Stream -> IO () -> IO ()
transfer name bytes stream action =
  Debug.timed Debug.dump_gc report stream action
  where
    rate :: Double -> String
    rate seconds =
      Debug.showFFloatSIBase (Just 3) 1024 (double bytes / seconds) "B/s"

    report :: Double -> Double -> Double -> String
    report wall cpu gpu =
      printf "gc: %s: %s bytes @ %s, %s"
        name
        (showBytes bytes)
        (rate wall)
        (Debug.elapsed wall cpu gpu)