module Data.Array.Accelerate.CUDA.Array.Nursery (
Nursery(..), NRS, new, malloc, stash, flush,
) where
import Data.Array.Accelerate.CUDA.FullList ( FullList(..) )
import qualified Data.Array.Accelerate.CUDA.FullList as FL
import qualified Data.Array.Accelerate.CUDA.Debug as D
import Prelude
import Data.Hashable
import Control.Exception ( bracket_ )
import Control.Concurrent.MVar ( MVar, newMVar, withMVar, mkWeakMVar )
import System.Mem.Weak ( Weak )
import Foreign.Ptr ( ptrToIntPtr )
import Foreign.CUDA.Ptr ( DevicePtr )
import qualified Foreign.CUDA.Driver as CUDA
import qualified Data.HashTable.IO as HT
-- | Mutable hash table from the @hashtables@ package, using its
-- 'HT.BasicHashTable' variant in 'IO'.
type HashTable key val = HT.BasicHashTable key val
-- | The nursery proper: an 'MVar' guarding a table keyed by a CUDA context
-- paired with an 'Int'. The 'Int' is presumably the allocation size in
-- bytes; confirm against callers. Each value holds the device pointers
-- currently stashed under that key. An entry always holds at least one
-- pointer, because 'malloc' deletes the key rather than leave an empty list
-- behind.
type NRS = MVar ( HashTable (CUDA.Context, Int) (FullList () (DevicePtr ())) )
-- | A nursery: the lock-protected table ('NRS') together with a weak pointer
-- to that same 'MVar'. 'new' registers a finalizer on the weak pointer that
-- flushes the underlying table once the 'MVar' has been garbage collected.
-- GHC gives no guarantee about when finalizers run.
data Nursery = Nursery !NRS
                       !(Weak NRS)
-- | Hash a CUDA context by the address of its underlying raw pointer.
--
-- NOTE(review): this is an orphan instance. Neither 'Hashable' nor
-- 'CUDA.Context' is defined in this module, so GHC warns under @-Wall@.
instance Hashable CUDA.Context where
  hashWithSalt salt (CUDA.Context ctx) = hashWithSalt salt ctxAddr
    where
      -- Integer value of the context pointer; equal contexts share an address.
      ctxAddr :: Int
      ctxAddr = fromIntegral (ptrToIntPtr ctx)
-- | Create an empty nursery.
--
-- The finalizer attached to the weak 'MVar' closes over the raw table, not
-- over the 'MVar'. It therefore does not keep the 'MVar' alive, and it can
-- flush the table without taking the lock once the 'MVar' is unreachable.
new :: IO Nursery
new = do
  table <- HT.new
  lock  <- newMVar table
  mkWeakMVar lock (flush table) >>= \weakLock ->
    return $! Nursery lock weakLock
-- | Try to reuse a device pointer previously stashed under key @(ctx, n)@.
-- @n@ is presumably the allocation size in bytes; confirm against callers.
--
-- Returns 'Nothing' when nothing is stashed for that key. Otherwise it
-- removes the first stashed pointer from the table and returns it. If that
-- pointer was the last one, the key is deleted entirely. The table lock is
-- held for the whole lookup-and-update.
malloc :: Int -> CUDA.Context -> Nursery -> IO (Maybe (DevicePtr ()))
malloc !n !ctx (Nursery !ref _) =
  withMVar ref $ \tbl -> do
    let !key = (ctx, n)
    -- Pop the head pointer, shrinking or removing the entry.
    let takeHead (FL () ptr FL.Nil) = do
          HT.delete tbl key
          return (Just ptr)
        takeHead (FL () ptr (FL.Cons () next more)) = do
          HT.insert tbl key (FL () next more)
          return (Just ptr)
    maybe (return Nothing) takeHead =<< HT.lookup tbl key
-- | Hand a device pointer back to the nursery under key @(ctx, n)@, so that a
-- later 'malloc' with the same key can reuse it.
--
-- The pointer is cast to @DevicePtr ()@ and prepended to the pointers
-- already stored under that key. If the key is absent, a one-element entry
-- is created. The table lock is held for the whole update.
stash :: Int -> CUDA.Context -> NRS -> DevicePtr a -> IO ()
stash !n !ctx !ref (CUDA.castDevPtr -> !ptr) =
  withMVar ref $ \tbl -> do
    let !key = (ctx, n)
    let push Nothing       = HT.insert tbl key (FL.singleton () ptr)
        push (Just others) = HT.insert tbl key (FL.cons () ptr others)
    push =<< HT.lookup tbl key
-- | Release every device pointer held in the nursery table and remove all
-- entries.
--
-- For each @(context, n)@ entry, the owning CUDA context is pushed while that
-- entry's pointers are freed. It is popped again afterwards, also on
-- exception, via 'bracket_'.
--
-- The entry is deleted from the table /before/ its pointers are freed.
-- 'CUDA.free' can throw. Deleting first means that if it throws part-way
-- through an entry, the entry's remaining pointers are leaked. They are not
-- left in the table, where a later 'malloc' could hand out memory that has
-- already been released, or a later 'flush' could free it a second time.
flush :: HashTable (CUDA.Context,Int) (FullList () (CUDA.DevicePtr ())) -> IO ()
flush !tbl = do
  message "flush nursery"
  HT.mapM_ release tbl
  where
    release (!key@(ctx, _), !ptrs) = do
      HT.delete tbl key
      bracket_ (CUDA.push ctx) CUDA.pop (FL.mapM_ (const CUDA.free) ptrs)
-- | Emit a debug message, prefixed with @"gc: "@, on the 'D.dump_gc' channel,
-- then run the given action. Whether anything is printed depends on
-- 'D.message'; presumably it is gated by a debug flag.
trace :: String -> IO a -> IO a
trace msg next = do
  D.message D.dump_gc ("gc: " ++ msg)
  next
-- | Emit a @gc:@ debug message with no follow-up action.
message :: String -> IO ()
message s = trace s (return ())