{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RecordWildCards #-}
module Data.Array.Accelerate.LLVM.PTX.Execute.Stream (
Reservoir, new,
Stream, create, destroy, streaming,
) where
import Data.Array.Accelerate.Lifetime
import qualified Data.Array.Accelerate.Array.Remote.LRU as Remote
import Data.Array.Accelerate.LLVM.PTX.Array.Remote ( )
import Data.Array.Accelerate.LLVM.PTX.Execute.Event ( Event )
import Data.Array.Accelerate.LLVM.PTX.Target ( PTX(..) )
import Data.Array.Accelerate.LLVM.State
import qualified Data.Array.Accelerate.LLVM.PTX.Debug as Debug
import qualified Data.Array.Accelerate.LLVM.PTX.Execute.Event as Event
import Data.Array.Accelerate.LLVM.PTX.Execute.Stream.Reservoir as RSV
import Foreign.CUDA.Driver.Error
import qualified Foreign.CUDA.Driver.Stream as Stream
import Control.Exception
import Control.Monad.State
-- | An execution stream handle: the raw CUDA driver stream ('Stream.Stream')
-- wrapped in a 'Lifetime'. The wrapper is what allows 'create' to attach a
-- finalizer to the stream and 'destroy' to run it.
--
type Stream = Lifetime Stream.Stream
-- | Run an operation in a freshly acquired execution stream.
--
-- The sequence is:
--
--   1. acquire a stream with 'create';
--   2. run the first argument with that stream;
--   3. record an event on the stream with 'Event.waypoint';
--   4. run the second argument with that event and the result of step 2;
--   5. release the stream ('destroy') and the event ('Event.destroy');
--   6. return the result of step 4.
--
-- Both function arguments are forced to WHNF by the bang patterns before any
-- work is started.
--
-- NOTE(review): the explicit cleanup in step 5 is skipped if either argument
-- throws. Presumably the finalizer attached by 'create' still recycles the
-- stream once the 'Lifetime' is collected -- confirm before relying on it.
--
{-# INLINEABLE streaming #-}
streaming
    :: (Stream -> LLVM PTX a)       -- ^ operation to run in the new stream
    -> (Event -> a -> LLVM PTX b)   -- ^ continuation, given the waypoint event and the operation's result
    -> LLVM PTX b
streaming !action !after = do
  stream <- create
  first  <- action stream
  end    <- Event.waypoint stream
  final  <- after end first
  -- The stream and the event are only released after the continuation has
  -- finished with them.
  liftIO $ do
    destroy stream
    Event.destroy end
  return final
-- | Acquire an execution stream wrapped in a 'Lifetime'.
--
-- The raw stream is obtained from @create'@. A finalizer is then registered on
-- the wrapper that inserts the raw stream into the target's stream reservoir
-- ('ptxStreamReservoir'), so that finalising the wrapper (see 'destroy') hands
-- the raw stream back for reuse rather than discarding it.
--
{-# INLINEABLE create #-}
create :: LLVM PTX Stream
create = do
  -- RecordWildCards brings 'ptxStreamReservoir' into scope.
  PTX{..} <- gets llvmTarget
  s       <- create'
  stream  <- liftIO $ newLifetime s
  -- The finalizer closes over the raw stream 's', not the 'Lifetime' wrapper.
  liftIO $ addFinalizer stream (RSV.insert ptxStreamReservoir s)
  return stream
-- | Obtain a raw CUDA stream, trying progressively more expensive strategies
-- and stopping at the first one that yields a stream:
--
--   1. take one from the stream reservoir ('RSV.malloc');
--   2. create a new one with 'Stream.create';
--   3. run 'Remote.reclaim' on the target's memory table, then try
--      'Stream.create' once more.
--
-- In steps 2 and 3 only an out-of-memory failure counts as "no stream"; any
-- other 'CUDAException' propagates immediately. If all three steps come up
-- empty, @ExitCode OutOfMemory@ is thrown.
--
-- Debug messages: a successful step 1 or 2 logs its label; step 3 logs that
-- it is purging before it retries; total failure logs a non-recoverable
-- message just before throwing.
--
create' :: LLVM PTX Stream.Stream
create' = do
  -- RecordWildCards brings 'ptxStreamReservoir' and 'ptxMemoryTable' into scope.
  PTX{..} <- gets llvmTarget
  ms      <- attempt "create/reservoir" (liftIO $ RSV.malloc ptxStreamReservoir)
             `orElse`
             attempt "create/new"       (liftIO . catchOOM $ Stream.create [])
             `orElse` do
               Remote.reclaim ptxMemoryTable
               liftIO $ do
                 message "create/new: failed (purging)"
                 catchOOM $ Stream.create []
  case ms of
    Just s  -> return s
    Nothing -> liftIO $ do
      message "create/new: failed (non-recoverable)"
      throwIO (ExitCode OutOfMemory)

  where
    -- Run an action, mapping an out-of-memory CUDA failure to 'Nothing'.
    -- Every other 'CUDAException' is rethrown unchanged.
    catchOOM :: IO a -> IO (Maybe a)
    catchOOM it =
      liftM Just it `catch` \e -> case e of
                                    ExitCode OutOfMemory -> return Nothing
                                    _                    -> throwIO e

    -- Run an action and, only if it produced a value, emit the given debug
    -- message. The result is passed through unchanged.
    attempt :: MonadIO m => String -> m (Maybe a) -> m (Maybe a)
    attempt msg ea = do
      ma <- ea
      case ma of
        Nothing -> return Nothing
        Just a  -> do liftIO (message msg)
                      return (Just a)

    -- Left-biased choice: the second action runs only when the first one
    -- returned 'Nothing'.
    orElse :: MonadIO m => m (Maybe a) -> m (Maybe a) -> m (Maybe a)
    orElse ea eb = do
      ma <- ea
      case ma of
        Just a  -> return (Just a)
        Nothing -> eb
-- | Release a stream by calling 'finalize' on its 'Lifetime' wrapper.
--
-- Streams built by 'create' carry a finalizer that inserts the underlying raw
-- stream into the stream reservoir, so for those streams this recycles the
-- stream instead of discarding it.
--
-- NOTE(review): presumably 'finalize' runs the attached finalizers
-- immediately -- confirm against "Data.Array.Accelerate.Lifetime".
--
{-# INLINEABLE destroy #-}
destroy :: Stream -> IO ()
destroy = finalize
-- | Emit a debug message and then run the given action.
--
-- The message is prefixed with @"stream: "@ and passed to 'Debug.traceIO'
-- under the 'Debug.dump_sched' flag.
--
{-# INLINE trace #-}
trace :: String -> IO a -> IO a
trace msg next = do
  Debug.traceIO Debug.dump_sched ("stream: " ++ msg)
  next
-- | Emit a debug message via 'trace' with no follow-up action.
--
{-# INLINE message #-}
message :: String -> IO ()
message s = s `trace` return ()