{-# LANGUAGE BangPatterns    #-}
{-# LANGUAGE MagicHash       #-}
{-# LANGUAGE RecordWildCards #-}
-- |
-- Module      : Data.Array.Accelerate.LLVM.PTX.Execute.Stream
-- Copyright   : [2014..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.LLVM.PTX.Execute.Stream (

  Reservoir, new,
  Stream, create, destroy, streaming,

) where

-- accelerate
import Data.Array.Accelerate.Lifetime
import qualified Data.Array.Accelerate.Array.Remote.LRU             as Remote

import Data.Array.Accelerate.LLVM.PTX.Array.Remote                  ( )
import Data.Array.Accelerate.LLVM.PTX.Execute.Event                 ( Event )
import Data.Array.Accelerate.LLVM.PTX.Target                        ( PTX(..) )
import Data.Array.Accelerate.LLVM.State
import qualified Data.Array.Accelerate.LLVM.PTX.Debug               as Debug
import qualified Data.Array.Accelerate.LLVM.PTX.Execute.Event       as Event
import Data.Array.Accelerate.LLVM.PTX.Execute.Stream.Reservoir      as RSV

-- cuda
import Foreign.CUDA.Driver.Error
import qualified Foreign.CUDA.Driver.Stream                         as Stream

-- standard library
import Control.Exception
import Control.Monad.State


-- | A 'Stream' represents an independent sequence of computations executed on
-- the GPU. Operations in different streams may be executed concurrently with
-- each other, but operations in the same stream can never overlap.
-- 'Data.Array.Accelerate.LLVM.PTX.Execute.Event.Event's can be used for
-- efficient cross-stream synchronisation.
--
type Stream = Lifetime Stream.Stream


-- Executing operations in streams
-- -------------------------------

-- | Execute an operation in a unique execution stream. The (asynchronous)
-- result is passed to a second operation together with an event that will be
-- signalled once the operation is complete. The stream and event are released
-- after the second operation completes.
--
{-# INLINEABLE streaming #-}
streaming
    :: (Stream -> LLVM PTX a)
    -> (Event -> a -> LLVM PTX b)
    -> LLVM PTX b
streaming !action !after = do
  -- Acquire a stream for the exclusive use of this operation.
  stream  <- create
  first   <- action stream
  -- Obtain the event handed to the continuation; it is requested only after
  -- 'action' has been issued into the stream.
  end     <- Event.waypoint stream
  final   <- after end first
  -- Explicitly release both resources once the continuation has returned.
  -- NOTE(review): no bracket/finally is used, so this cleanup is skipped if
  -- 'action' or 'after' throws; the stream then relies solely on the
  -- finaliser attached in 'create'.
  liftIO $ do
    destroy stream
    Event.destroy end
  return final


-- Primitive operations
-- --------------------

{--
-- | Delete all execution streams from the reservoir
--
{-# INLINEABLE flush #-}
flush :: Context -> Reservoir -> IO ()
flush !Context{..} !ref = do
  mc <- deRefWeak weakContext
  case mc of
    Nothing     -> message "delete reservoir/dead context"
    Just ctx    -> do
      message "flush reservoir"
      old <- swapMVar ref Seq.empty
      bracket_ (CUDA.push ctx) CUDA.pop $ Seq.mapM_ Stream.destroy old
--}


-- | Create a CUDA execution stream. If an inactive stream is available for use,
-- use that, otherwise generate a fresh stream.
--
-- Note: [Finalising execution streams]
--
-- We don't actually ensure that the stream has executed all of its operations
-- to completion before attempting to return it to the reservoir for reuse.
-- Doing so increases overhead of the LLVM RTS due to 'forkIO', and consumes CPU
-- time as 'Stream.block' busy-waits for the stream to complete. It is quicker
-- to optimistically return the streams to the end of the reservoir immediately,
-- and just check whether the stream is done before reusing it.
--
-- > void . forkIO $ do
-- >   Stream.block stream
-- >   modifyMVar_ ref $ \rsv -> return (rsv Seq.|> stream)
--
{-# INLINEABLE create #-}
create :: LLVM PTX Stream
create = do
  PTX{..} <- gets llvmTarget
  -- Obtain the raw CUDA stream (recycled or freshly allocated).
  s       <- create'
  -- Wrap it in a 'Lifetime' so that its release can be tracked.
  stream  <- liftIO $ newLifetime s
  -- The finaliser does not destroy the stream: it inserts the raw stream back
  -- into the reservoir so that a later 'create' can reuse it.
  liftIO $ addFinalizer stream (RSV.insert ptxStreamReservoir s)
  return stream

-- Obtain a raw CUDA stream, trying progressively more expensive strategies:
--
--  1. take a stream from the reservoir;
--  2. allocate a new stream;
--  3. reclaim entries of the remote memory table, then allocate again.
--
-- If all three fail, a CUDA 'OutOfMemory' exception is thrown. Each successful
-- step (1) or (2) emits a debug message naming the strategy that succeeded.
--
create' :: LLVM PTX Stream.Stream
create' = do
  PTX{..} <- gets llvmTarget
  ms      <- attempt "create/reservoir" (liftIO $ RSV.malloc ptxStreamReservoir)
             `orElse`
             attempt "create/new"       (liftIO . catchOOM $ Stream.create [])
             `orElse` do
               -- Last resort: release what the memory table can give back and
               -- retry the allocation once.
               Remote.reclaim ptxMemoryTable
               liftIO $ do
                 message "create/new: failed (purging)"
                 catchOOM $ Stream.create []
  case ms of
    Just s  -> return s
    Nothing -> liftIO $ do
      message "create/new: failed (non-recoverable)"
      throwIO (ExitCode OutOfMemory)

  where
    -- Run an action, mapping a CUDA out-of-memory failure to 'Nothing'. Any
    -- other 'CUDAException' is re-thrown unchanged.
    catchOOM :: IO a -> IO (Maybe a)
    catchOOM it =
      liftM Just it `catch` \e -> case e of
                                    ExitCode OutOfMemory -> return Nothing
                                    _                    -> throwIO e

    -- Run an action and, only if it produced a result, log the given message.
    attempt :: MonadIO m => String -> m (Maybe a) -> m (Maybe a)
    attempt msg ea = do
      ma <- ea
      case ma of
        Nothing -> return Nothing
        Just a  -> do liftIO (message msg)
                      return (Just a)

    -- Left-biased choice: the second action runs only if the first yields
    -- 'Nothing'.
    orElse :: MonadIO m => m (Maybe a) -> m (Maybe a) -> m (Maybe a)
    orElse ea eb = do
      ma <- ea
      case ma of
        Just a  -> return (Just a)
        Nothing -> eb


-- | Merge a stream back into the reservoir. This must only be done once all
-- pending operations in the stream have completed.
--
{-# INLINEABLE destroy #-}
destroy :: Stream -> IO ()
-- Finalising the lifetime runs the finaliser registered in 'create', which
-- inserts the underlying stream back into the reservoir.
destroy = finalize


-- Debug
-- -----

-- Emit a scheduler debug message, prefixed with @"stream: "@, and then run the
-- given action.
--
{-# INLINE trace #-}
trace :: String -> IO a -> IO a
trace msg next = do
  Debug.traceIO Debug.dump_sched ("stream: " ++ msg)
  next

-- Emit a scheduler debug message via 'trace', with no follow-up action.
--
{-# INLINE message #-}
message :: String -> IO ()
message s = s `trace` return ()