{-# LANGUAGE BangPatterns    #-}
{-# LANGUAGE MagicHash       #-}
{-# LANGUAGE RecordWildCards #-}
module Data.Array.Accelerate.LLVM.PTX.Execute.Stream (
  Reservoir, new,
  Stream, create, destroy, streaming,
) where
import Data.Array.Accelerate.Lifetime
import qualified Data.Array.Accelerate.Array.Remote.LRU             as Remote
import Data.Array.Accelerate.LLVM.PTX.Array.Remote                  ( )
import Data.Array.Accelerate.LLVM.PTX.Execute.Event                 ( Event )
import Data.Array.Accelerate.LLVM.PTX.Target                        ( PTX(..) )
import Data.Array.Accelerate.LLVM.State
import qualified Data.Array.Accelerate.LLVM.PTX.Debug               as Debug
import qualified Data.Array.Accelerate.LLVM.PTX.Execute.Event       as Event
import Data.Array.Accelerate.LLVM.PTX.Execute.Stream.Reservoir      as RSV
import Foreign.CUDA.Driver.Error
import qualified Foreign.CUDA.Driver.Stream                         as Stream
import Control.Exception
import Control.Monad.State
-- | A CUDA driver stream wrapped in a 'Lifetime', so that finalizers can be
-- attached to it ('create' uses this to hand the underlying stream back to
-- the reservoir).
type Stream = Lifetime Stream.Stream
-- | Run @action@ on a freshly acquired stream, then call 'Event.waypoint' on
-- that stream and pass the resulting event together with the action's result
-- to @after@. Once @after@ returns, both the stream and the event are
-- explicitly destroyed and @after@'s result is returned.
--
-- NOTE(review): if @action@ or @after@ throws, the explicit destroy calls
-- below are skipped. The stream is then presumably released only when its
-- 'Lifetime' finalizer runs (attached in 'create'); confirm whether the event
-- needs the same treatment.
{-# INLINEABLE streaming #-}
streaming
    :: (Stream -> LLVM PTX a)
    -> (Event -> a -> LLVM PTX b)
    -> LLVM PTX b
streaming !action !after = do
  -- No fields of the target are needed here; 'create' reads what it needs.
  stream  <- create
  first   <- action stream
  end     <- Event.waypoint stream
  final   <- after end first
  liftIO $ do
    destroy stream
    Event.destroy end
  return final
-- | Acquire a stream (see 'create'') and wrap it in a 'Lifetime'. The attached
-- finalizer inserts the raw stream back into the target's stream reservoir,
-- so it can be reused rather than discarded.
{-# INLINEABLE create #-}
create :: LLVM PTX Stream
create = do
  PTX{ ptxStreamReservoir = reservoir } <- gets llvmTarget
  raw <- create'
  liftIO $ do
    wrapped <- newLifetime raw
    addFinalizer wrapped (RSV.insert reservoir raw)
    return wrapped
-- | Acquire a raw driver stream. The strategies below are tried in order,
-- stopping at the first that succeeds:
--
--   1. take a stream from the reservoir ('RSV.malloc'; presumably yields
--      'Nothing' when the reservoir is empty);
--   2. ask the driver for a new stream;
--   3. call 'Remote.reclaim' on the memory table (presumably releasing cached
--      device memory), then ask the driver once more.
--
-- An out-of-memory error from the driver counts as "this strategy failed";
-- any other driver exception is rethrown. If every strategy fails, an
-- @ExitCode OutOfMemory@ exception is thrown.
create' :: LLVM PTX Stream.Stream
create' = do
  PTX{..} <- gets llvmTarget
  let fromReservoir, fromDriver, afterPurge :: LLVM PTX (Maybe Stream.Stream)
      fromReservoir =
        attempt "create/reservoir" (liftIO $ RSV.malloc ptxStreamReservoir)
      fromDriver =
        attempt "create/new" (liftIO freshStream)
      afterPurge = do
        Remote.reclaim ptxMemoryTable
        liftIO $ do
          message "create/new: failed (purging)"
          freshStream
  found <- firstSuccess [fromReservoir, fromDriver, afterPurge]
  maybe giveUp return found
  where
    -- Every strategy failed: log it and raise out-of-memory.
    giveUp :: LLVM PTX Stream.Stream
    giveUp = liftIO $ do
      message "create/new: failed (non-recoverable)"
      throwIO (ExitCode OutOfMemory)

    -- Create a brand-new driver stream, mapping only out-of-memory to Nothing.
    freshStream :: IO (Maybe Stream.Stream)
    freshStream =
      fmap Just (Stream.create []) `catch` \err -> case err of
        ExitCode OutOfMemory -> return Nothing
        _                    -> throwIO err

    -- Run a strategy and log @msg@ only when it produced a stream.
    attempt :: MonadIO m => String -> m (Maybe a) -> m (Maybe a)
    attempt msg strategy = do
      outcome <- strategy
      case outcome of
        Just _  -> liftIO (message msg)
        Nothing -> return ()
      return outcome

    -- Run the actions left to right, stopping at the first 'Just'.
    firstSuccess :: Monad m => [m (Maybe a)] -> m (Maybe a)
    firstSuccess []       = return Nothing
    firstSuccess (m : ms) = m >>= maybe (firstSuccess ms) (return . Just)
-- | Explicitly finalize a stream's 'Lifetime' via 'finalize'. For streams
-- obtained from 'create', the finalizer attached there inserts the underlying
-- stream back into the reservoir (presumably run immediately here; confirm
-- against 'finalize').
{-# INLINEABLE destroy #-}
destroy :: Stream -> IO ()
destroy stream = finalize stream
-- | Emit @msg@, prefixed with @"stream: "@, on the 'Debug.dump_sched' debug
-- channel, then continue with @next@.
{-# INLINE trace #-}
trace :: String -> IO a -> IO a
trace msg next =
  Debug.traceIO Debug.dump_sched ("stream: " ++ msg) >> next
-- | Log a stream-scheduling message via 'trace' and do nothing else.
{-# INLINE message #-}
message :: String -> IO ()
message s = trace s (pure ())