module Data.Array.Accelerate.CUDA.Array.Table (
MemoryTable, new, lookup, malloc, insert, insertRemote, reclaim
) where
import Prelude hiding ( lookup )
import Data.Maybe ( isJust )
import Data.Hashable ( Hashable(..) )
import Data.Typeable ( Typeable, gcast )
import Control.Monad ( unless )
import Control.Concurrent ( yield )
import Control.Concurrent.MVar ( MVar, newMVar, withMVar, mkWeakMVar )
import Control.Exception ( bracket_, catch, throwIO )
import Control.Applicative ( (<$>) )
import System.Mem ( performGC )
import System.Mem.Weak ( Weak, mkWeak, deRefWeak, finalize )
import System.Mem.StableName ( StableName, makeStableName, hashStableName )
import Foreign.Ptr ( ptrToIntPtr )
import Foreign.Storable ( Storable, sizeOf )
import Foreign.CUDA.Ptr ( DevicePtr )
import Foreign.CUDA.Driver.Error
import qualified Foreign.CUDA.Driver as CUDA
import qualified Data.HashTable.IO as HT
import Data.Array.Accelerate.Error ( internalError )
import Data.Array.Accelerate.Array.Data ( ArrayData )
import Data.Array.Accelerate.CUDA.Context ( Context, weakContext, deviceContext )
import Data.Array.Accelerate.CUDA.Array.Nursery ( Nursery(..), NRS )
import qualified Data.Array.Accelerate.CUDA.Array.Nursery as N
import qualified Data.Array.Accelerate.CUDA.Debug as D
-- | Alias for the mutable hash table implementation used by this module:
-- the @BasicHashTable@ exported by "Data.HashTable.IO".
type HashTable key val = HT.BasicHashTable key val
-- | The table proper: a mutable hash table from 'HostArray' keys to
-- 'DeviceArray' entries, held inside an 'MVar'. The accesses in this module
-- all go through 'withMVar'.
type MT = MVar ( HashTable HostArray DeviceArray )
-- | Handle to the device memory table. It bundles three things, in order:
--
--   1. the table itself ('MT');
--
--   2. a weak reference to that same table, which is what gets handed to the
--      per-array finalizers (they 'deRefWeak' it and skip the table update if
--      it is already gone);
--
--   3. the 'Nursery' that 'malloc' consults before allocating fresh memory.
--
data MemoryTable
  = MemoryTable
      !MT           -- the table
      !(Weak MT)    -- weak reference to the table, for finalizers
      !Nursery      -- cache of previously allocated device memory
-- | Identifier for a device context. 'makeStableArray' derives it from the
-- address of the raw context pointer.
type ContextId = Int
-- | Key of the memory table: the context an array was registered under,
-- together with a stable name for the host-side array data. The element type
-- is hidden; only its 'Typeable' evidence is kept so that keys of different
-- element types can still be compared.
--
data HostArray where
  HostArray :: Typeable e => !ContextId -> !(StableName (ArrayData e)) -> HostArray
-- | Value stored in the memory table: a weak pointer to the device memory.
-- The element type is hidden behind a 'Typeable' constraint and is recovered
-- with 'gcast' in 'lookup'.
--
data DeviceArray where
  DeviceArray :: Typeable e => !(Weak (DevicePtr e)) -> DeviceArray
-- | Two keys are equal when they belong to the same context and their stable
-- names agree. The stable names may be at different element types, so the
-- left one is first cast to the type of the right one; a failed cast means
-- the keys are different.
--
-- The context id must take part in the comparison because the 'Hashable'
-- instance mixes it into the hash: equal keys are required to hash equally,
-- and ignoring the context here would let entries registered under different
-- contexts compare equal whenever they happen to share a bucket.
--
instance Eq HostArray where
  HostArray c1 a1 == HostArray c2 a2
    = c1 == c2 && maybe False (== a2) (gcast a1)
-- | Hash a key by folding first the context id and then the stable name into
-- the salt.
--
instance Hashable HostArray where
  hashWithSalt salt (HostArray cid sn) =
    hashWithSalt (hashWithSalt salt cid) sn
-- | Render a key for debug messages as @Array #<hash of the stable name>@.
-- The context id is not shown.
--
instance Show HostArray where
  show (HostArray _ sn) = "Array #" ++ ident
    where
      ident = show (hashStableName sn)
-- | Create a fresh, empty memory table together with a new nursery.
--
-- A weak reference to the table's 'MVar' is created alongside it, with
-- 'table_finalizer' attached as its finalizer.
--
new :: IO MemoryTable
new = do
  message "initialise memory table"
  table     <- HT.new
  tableVar  <- newMVar table
  nursery   <- N.new
  weakTable <- mkWeakMVar tableVar (table_finalizer table)
  return $! MemoryTable tableVar weakTable nursery
-- | Look up the device memory associated with a host array in the given
-- context.
--
-- Returns 'Nothing' when the array has no entry in the table. When an entry
-- exists, its weak pointer is dereferenced and the device pointer is cast to
-- the requested element type. A failed cast, or a weak pointer that is
-- already dead, is reported as an internal error.
--
lookup :: (Typeable a, Typeable b) => Context -> MemoryTable -> ArrayData a -> IO (Maybe (DevicePtr b))
lookup ctx (MemoryTable !ref _ _) !arr = do
  key   <- makeStableArray ctx arr
  entry <- withMVar ref $ \tbl -> HT.lookup tbl key
  case entry of
    Nothing -> do
      message ("lookup/not found: " ++ show key)
      return Nothing

    Just (DeviceArray weakPtr) -> do
      target <- deRefWeak weakPtr
      case target of
        Nothing -> do
          -- The entry is still in the table but its weak pointer is dead.
          key' <- makeStableArray ctx arr
          $internalError "lookup" $ "dead weak pair: " ++ show key'

        Just devPtr ->
          case gcast devPtr of
            Just p  -> do
              message ("lookup/found: " ++ show key)
              return (Just p)
            Nothing -> $internalError "lookup" "type mismatch"
-- | Allocate device memory for (at least) @n@ elements of type @b@ and
-- associate it with the given host array.
--
-- The requested element count is rounded up to the next multiple of 128
-- before computing the size in bytes. The nursery is asked for a block of
-- exactly that many bytes first; only if it has none is new memory requested
-- from the device. Should that allocation fail with an out-of-memory error,
-- 'reclaim' is run and the allocation is attempted once more; any other CUDA
-- exception is rethrown unchanged.
--
-- The resulting pointer is registered in the table via 'insert' (which also
-- attaches the finalizer) and then returned.
--
malloc :: forall a b. (Typeable a, Typeable b, Storable b) => Context -> MemoryTable -> ArrayData a -> Int -> IO (DevicePtr b)
malloc !ctx mt@(MemoryTable _ _ !nursery) !ad !n = do
  let
      -- Number of whole chunks of size 'f' needed to hold 'x' elements, i.e.
      -- ceiling (x / f) for non-negative x. Adding (f - 1) before dividing
      -- is what makes the floor round upwards.
      multiple x f = floor ((x + (f - 1)) / f :: Double)
      chunk        = 128
      !n'          = chunk * multiple (fromIntegral n) (fromIntegral chunk)
      !bytes       = n' * sizeOf (undefined :: b)
  --
  mp  <- N.malloc bytes (deviceContext ctx) nursery
  ptr <- case mp of
           Just p  -> trace "malloc/nursery" $ return (CUDA.castDevPtr p)
           Nothing -> trace "malloc/new" $
             CUDA.mallocArray n' `catch` \(e :: CUDAException) ->
               case e of
                 -- Out of device memory: release what we can, then retry once.
                 ExitCode OutOfMemory -> reclaim mt >> CUDA.mallocArray n'
                 _                    -> throwIO e
  insert ctx mt ad ptr bytes
  return ptr
-- | Record the association between a host array and a device pointer of the
-- given size in bytes.
--
-- A weak pointer keyed on the host array is created whose finalizer
-- ('finalizer') removes the table entry and disposes of the device memory.
-- The entry is then stored in the table under the array's stable key.
--
insert :: (Typeable a, Typeable b) => Context -> MemoryTable -> ArrayData a -> DevicePtr b -> Int -> IO ()
insert !ctx (MemoryTable !ref !weak_ref (Nursery _ !weak_nrs)) !arr !ptr !bytes = do
  key <- makeStableArray ctx arr
  let cleanup = finalizer (weakContext ctx) weak_ref weak_nrs key ptr bytes
  entry <- DeviceArray <$> mkWeak arr ptr (Just cleanup)
  message ("insert: " ++ show key)
  withMVar ref (\tbl -> HT.insert tbl key entry)
-- | Record the association between a host array and a device pointer, as
-- 'insert' does, but with a finalizer ('remoteFinalizer') that only removes
-- the table entry. The device memory itself is never freed or stashed by this
-- module for entries added this way.
--
insertRemote :: (Typeable a, Typeable b) => Context -> MemoryTable -> ArrayData a -> DevicePtr b -> IO ()
insertRemote !ctx (MemoryTable !ref !weak_ref _) !arr !ptr = do
  key   <- makeStableArray ctx arr
  entry <- DeviceArray <$> mkWeak arr ptr (Just (remoteFinalizer weak_ref key))
  message ("insert/remote: " ++ show key)
  withMVar ref (\tbl -> HT.insert tbl key entry)
-- | Try to release device memory that is no longer needed.
--
-- The steps are, in order: run a garbage collection and yield; flush the
-- nursery; and, if the table is still alive, walk every entry and explicitly
-- 'finalize' those whose weak pointer can no longer be dereferenced.
--
-- When the @dump_gc@ debug flag is set, a message reports how much device
-- memory was recovered and how much is free afterwards.
--
reclaim :: MemoryTable -> IO ()
reclaim (MemoryTable _ weak_ref (Nursery nrs _)) = do
  -- Free memory before doing anything, for the debug report below.
  (free, total) <- CUDA.getMemInfo
  performGC
  yield     -- NOTE(review): presumably lets pending finalizers run -- confirm
  withMVar nrs N.flush
  mr <- deRefWeak weak_ref
  case mr of
    Nothing  -> return ()
    Just ref -> withMVar ref $ \tbl ->
      flip HT.mapM_ tbl $ \(_,DeviceArray w) -> do
        alive <- isJust `fmap` deRefWeak w
        unless alive $ finalize w
  --
  D.when D.dump_gc $ do
    (free', _) <- CUDA.getMemInfo
    -- Amount recovered = free memory after the sweep minus free memory before.
    message $ "reclaim: freed " ++ showBytes (fromIntegral (free' - free))
           ++ ", "              ++ showBytes (fromIntegral free')
           ++ " of "            ++ showBytes (fromIntegral total) ++ " remaining"
-- | Finalizer attached by 'insert' to each locally allocated array.
--
-- First the entry is deleted from the table, provided the table is still
-- alive. Then the device memory is dealt with, provided the context is still
-- alive: if the nursery is reachable the block is stashed there for reuse,
-- otherwise it is freed with the context pushed for the duration of the call.
-- If the context is already gone, only a debug message is emitted.
--
finalizer :: Weak CUDA.Context -> Weak MT -> Weak NRS -> HostArray -> DevicePtr b -> Int -> IO ()
finalizer !weak_ctx !weak_ref !weak_nrs !key !ptr !bytes = do
  deRefWeak weak_ref >>= maybe
    (message ("finalise/dead table: " ++ show key))
    (\ref -> withMVar ref (\tbl -> HT.delete tbl key))
  deRefWeak weak_ctx >>= maybe
    (message ("finalise/dead context: " ++ show key))
    release
  where
    -- Hand the memory back, either to the nursery or to the device.
    release ctx = do
      nursery <- deRefWeak weak_nrs
      case nursery of
        Just nrs -> trace ("finalise/nursery: " ++ show key) $ N.stash bytes ctx nrs ptr
        Nothing  -> trace ("finalise/free: " ++ show key) $ bracket_ (CUDA.push ctx) CUDA.pop (CUDA.free ptr)
-- | Finalizer attached by 'insertRemote'. It removes the array's entry from
-- the table if the table is still alive, and otherwise just logs that the
-- table is gone. The device memory is not touched.
--
remoteFinalizer :: Weak MT -> HostArray -> IO ()
remoteFinalizer !weak_ref !key = deRefWeak weak_ref >>= maybe tableGone dropEntry
  where
    tableGone     = message ("finalise/dead table: " ++ show key)
    dropEntry ref = trace ("finalise: " ++ show key) $ withMVar ref (\tbl -> HT.delete tbl key)
-- | Finalizer for the table itself (installed by 'new'): explicitly runs the
-- finalizer of every entry still present in the table.
--
table_finalizer :: HashTable HostArray DeviceArray -> IO ()
table_finalizer !tbl = trace "table finaliser" (HT.mapM_ release tbl)
  where
    release (_, DeviceArray w) = finalize w
-- | Build the table key for a host array in the given context: the context
-- id is the address of the raw device context pointer converted to an 'Int',
-- paired with a stable name for the array data.
--
makeStableArray :: Typeable a => Context -> ArrayData a -> IO HostArray
makeStableArray !ctx !arr =
  case deviceContext ctx of
    CUDA.Context p ->
      -- Force the context id up front rather than leaving it as a thunk.
      let !cid = fromIntegral (ptrToIntPtr p)
      in  fmap (HostArray cid) (makeStableName arr)
-- | Format a byte count for debug output, using base-1024 scaling, zero
-- decimal places and a trailing @B@ unit.
--
showBytes :: Int -> String
showBytes nbytes = D.showFFloatSIBase (Just 0) 1024 amount "B"
  where
    amount = fromIntegral nbytes :: Double
-- | Emit a debug message under the @dump_gc@ flag, prefixed with @"gc: "@,
-- and then run the given action.
--
trace :: String -> IO a -> IO a
trace msg next = do
  D.message D.dump_gc ("gc: " ++ msg)
  next
-- | Emit a debug message via 'trace' without running any further action.
--
message :: String -> IO ()
message s = trace s (return ())