module Data.Array.Accelerate.CUDA.Execute.Stream (
Stream, Reservoir, new, streaming,
) where
import Data.Array.Accelerate.CUDA.Array.Nursery ( )
import Data.Array.Accelerate.CUDA.Context ( Context(..) )
import Data.Array.Accelerate.CUDA.FullList ( FullList(..) )
import Data.Array.Accelerate.CUDA.Execute.Event ( Event )
import qualified Data.Array.Accelerate.CUDA.Execute.Event as Event
import qualified Data.Array.Accelerate.CUDA.FullList as FL
#ifdef ACCELERATE_DEBUG
import qualified Data.Array.Accelerate.CUDA.Debug as D
#endif
import Control.Monad.Trans ( MonadIO, liftIO )
import Control.Exception ( bracket_ )
import Control.Concurrent.MVar ( MVar, newMVar, withMVar, mkWeakMVar )
import System.Mem.Weak ( Weak, deRefWeak )
import Foreign.CUDA.Driver.Stream ( Stream(..) )
import qualified Foreign.CUDA.Driver as CUDA
import qualified Foreign.CUDA.Driver.Stream as Stream
import qualified Data.HashTable.IO as HT
-- | Shorthand for the mutable 'HT.BasicHashTable' implementation from the
-- hashtables package, used to index cached streams by context.
type HashTable key val = HT.BasicHashTable key val
-- | Shared state of a reservoir: a lock-protected table from each CUDA driver
-- context to the non-empty list of streams that have been handed back (by
-- 'destroy') for reuse in that context.
type RSV = MVar ( HashTable CUDA.Context (FullList () Stream) )
-- | A pool of reusable streams.
--
-- The first field is a strong reference to the table 'MVar'. The second is a
-- weak pointer to that same 'MVar', created in 'new' with 'flush' attached as
-- its finaliser. 'destroy' only holds the weak pointer, so a stream that is
-- released after the reservoir has died is freed rather than cached.
data Reservoir = Reservoir !RSV
                           !(Weak RSV)
-- | Run a computation on a stream borrowed from the reservoir.
--
-- The steps are:
--
-- 1. A stream is obtained with 'create' (reused from the reservoir, or freshly
--    allocated).
-- 2. @action@ is run on it.
-- 3. An 'Event' is obtained from 'Event.waypoint' on the stream.
-- 4. @after@ receives that event together with @action@'s result.
-- 5. The stream is handed back through 'destroy' (cached or freed), and the
--    event is destroyed.
-- 6. The result of @after@ is returned.
--
-- NOTE(review): this is not exception-safe. If @action@ or @after@ throws, the
-- stream is never handed back and the event is never destroyed. Fixing that
-- needs a stronger constraint than 'MonadIO' (e.g. a masking/bracket class).
streaming :: MonadIO m => Context -> Reservoir -> (Stream -> m a) -> (Event -> a -> m b) -> m b
streaming !ctx !rsv@(Reservoir !_ !weak_rsv) !action !after = do
  s      <- liftIO (create ctx rsv)
  result <- action s
  marker <- liftIO (Event.waypoint s)
  answer <- after marker result
  -- Release in the same order as before: stream first, then its event.
  liftIO $! destroy (weakContext ctx) weak_rsv s
  liftIO $! Event.destroy marker
  return answer
-- | Allocate an empty stream reservoir.
--
-- A weak pointer to the table's 'MVar' is made with 'flush' as its finaliser.
-- Once the 'MVar' becomes unreachable, the finaliser (whenever the runtime
-- runs it) destroys every stream still cached in the table.
new :: IO Reservoir
new = do
  table <- HT.new
  var   <- newMVar table
  wk    <- mkWeakMVar var (flush table)
  -- Force the reservoir (and its strict fields) before handing it out.
  let !reservoir = Reservoir var wk
  return reservoir
-- | Obtain a stream for the given context.
--
-- If the reservoir holds a cached stream under the context's 'deviceContext',
-- the head of that list is popped and returned, and the entry is removed once
-- its list is empty. Otherwise a new stream is made with 'Stream.create'.
--
-- Only the table lookup and update run while the reservoir lock is held.
-- Creating a fresh stream is a driver call, so it happens after the lock has
-- been released, and other threads are not blocked waiting on it.
--
-- NOTE(review): 'Stream.create' presumably creates the stream in whichever
-- CUDA context is current on the calling thread. This assumes callers have
-- made @ctx@ current; confirm at call sites.
create :: Context -> Reservoir -> IO Stream
create !ctx (Reservoir !ref !_) = do
  cached <- withMVar ref $ \tbl -> do
    let key = deviceContext ctx
    ms <- HT.lookup tbl key
    case ms of
      Nothing -> return Nothing
      Just (FL () stream rest) -> do
        -- Pop the head of the non-empty list; drop the entry once it is empty.
        case rest of
          FL.Nil          -> HT.delete tbl key
          FL.Cons () s ss -> HT.insert tbl key (FL () s ss)
        return (Just stream)
  case cached of
    Just stream -> return stream
    Nothing     -> do
      stream <- Stream.create []
      message ("new " ++ showStream stream)
      return stream
-- | Hand a stream back after use.
--
-- There are three cases:
--
-- * If the context's weak pointer is dead, only a debug message is emitted
--   and the stream is left alone (presumably it went away with its context).
-- * If the reservoir's weak pointer is dead, the stream is destroyed outright.
-- * Otherwise the stream is pushed onto the front of the reservoir's list for
--   that context, under the reservoir lock.
--
-- NOTE(review): the "free" path calls 'Stream.destroy' without pushing the
-- context, unlike 'flush'. This presumably relies on the caller's context
-- still being current; confirm.
destroy :: Weak CUDA.Context -> Weak RSV -> Stream -> IO ()
destroy !weak_ctx !weak_rsv !stream =
  deRefWeak weak_ctx >>= \alive -> case alive of
    Nothing  -> message ("finalise/dead context " ++ showStream stream)
    Just ctx -> deRefWeak weak_rsv >>= maybe release (stash ctx)
  where
    -- Reservoir is gone: free the stream immediately.
    release = trace ("destroy/free " ++ showStream stream) $ Stream.destroy stream

    -- Reservoir is alive: cache the stream for later reuse in this context.
    stash ctx ref = trace ("destroy/save " ++ showStream stream) $
      withMVar ref $ \tbl -> do
        existing <- HT.lookup tbl ctx
        HT.insert tbl ctx (maybe (FL.singleton () stream) (FL.cons () stream) existing)
-- | Destroy every stream cached in a reservoir table and remove its entries.
--
-- 'new' installs this as the finaliser of the reservoir's 'MVar', so it runs
-- without taking that lock. For each context key, the context is pushed while
-- its streams are destroyed and popped afterwards. 'bracket_' ensures the pop
-- happens even if destruction throws.
--
-- The entries are snapshotted with 'HT.toList' before any are deleted, so the
-- table is never mutated while it is being traversed.
--
-- NOTE(review): if pushing a context or destroying a stream throws, the
-- exception escapes, and the remaining contexts in the snapshot are not
-- cleaned.
flush :: HashTable CUDA.Context (FullList () Stream) -> IO ()
flush !tbl = do
  message "flush reservoir"
  entries <- HT.toList tbl
  mapM_ clean entries
  where
    clean (!ctx, !ss) = do
      bracket_ (CUDA.push ctx) CUDA.pop (FL.mapM_ (const Stream.destroy) ss)
      HT.delete tbl ctx
-- | Run @next@, first emitting a debug message prefixed with @"stream: "@.
--
-- The message is only emitted when the module is compiled with
-- @ACCELERATE_DEBUG@. In that case it goes through @D.message D.dump_exec@,
-- guarded by @D.when D.verbose@. Without that flag the message is ignored and
-- this is just @next@.
trace :: String -> IO a -> IO a
trace _msg next = do
#ifdef ACCELERATE_DEBUG
  D.when D.verbose $ D.message D.dump_exec ("stream: " ++ _msg)
#endif
  next
-- | Emit a stream-management debug message with no further action. It is
-- 'trace' applied to a trivial continuation, and so is a no-op unless built
-- with @ACCELERATE_DEBUG@.
message :: String -> IO ()
message msg = trace msg (return ())
-- | Render a stream for debug output by showing the raw handle wrapped by the
-- 'Stream' constructor.
showStream :: Stream -> String
showStream stream = case stream of
  Stream handle -> show handle