{-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE PatternGuards #-} {-# LANGUAGE ScopedTypeVariables #-} -- | -- Module : Data.Array.Accelerate.CUDA.Array.Table -- Copyright : [2008..2010] Manuel M T Chakravarty, Gabriele Keller, Sean Lee -- [2009..2012] Manuel M T Chakravarty, Gabriele Keller, Trevor L. McDonell -- License : BSD3 -- -- Maintainer : Trevor L. McDonell -- Stability : experimental -- Portability : non-portable (GHC extensions) -- module Data.Array.Accelerate.CUDA.Array.Table ( -- Tables for host/device memory associations MemoryTable, new, lookup, malloc, insert, reclaim ) where #if !MIN_VERSION_base(4,6,0) import Prelude hiding ( lookup, catch ) #else import Prelude hiding ( lookup ) #endif import Data.IORef ( IORef, newIORef, readIORef, mkWeakIORef ) import Data.Maybe ( isJust ) import Data.Hashable ( Hashable(..) ) import Data.Typeable ( Typeable, gcast ) import Control.Monad ( unless ) import Control.Concurrent ( yield ) import Control.Exception ( bracket_, catch, throwIO ) import Control.Applicative ( (<$>) ) import System.Mem ( performGC ) import System.Mem.Weak ( Weak, mkWeak, deRefWeak, finalize ) import System.Mem.StableName ( StableName, makeStableName, hashStableName ) import Foreign.Ptr ( ptrToIntPtr ) import Foreign.Storable ( Storable, sizeOf ) import Foreign.CUDA.Ptr ( DevicePtr ) import Foreign.CUDA.Driver.Error import qualified Foreign.CUDA.Driver as CUDA import qualified Data.HashTable.IO as HT import Data.Array.Accelerate.Array.Data ( ArrayData ) import Data.Array.Accelerate.CUDA.Context ( Context, weakContext, deviceContext ) import Data.Array.Accelerate.CUDA.Array.Nursery ( Nursery(..), NRS ) import qualified Data.Array.Accelerate.CUDA.Array.Nursery as N import qualified Data.Array.Accelerate.CUDA.Debug as D #include "accelerate.h" -- We use an MVar to the hash table, so that several threads may safely access -- it concurrently. This includes the finalisation threads that remove entries -- from the table. -- -- It is important that we can garbage collect old entries from the table when -- the key is no longer reachable in the heap. Hence the value part of each -- table entry is a (Weak val), where the stable name 'key' is the key for the -- memo table, and the 'val' is the value of this table entry. When the key -- becomes unreachable, a finaliser will fire and remove this entry from the -- hash buckets, and further attempts to dereference the weak pointer will -- return Nothing. References from 'val' to the key are ignored (see the -- semantics of weak pointers in the documentation). -- type HashTable key val = HT.BasicHashTable key val type MT = IORef ( HashTable HostArray DeviceArray ) data MemoryTable = MemoryTable {-# UNPACK #-} !MT {-# UNPACK #-} !(Weak MT) {-# UNPACK #-} !Nursery -- Arrays on the host and device -- data HostArray where HostArray :: Typeable e => {-# UNPACK #-} !Int -- unique ID relating to the parent context -> {-# UNPACK #-} !(StableName (ArrayData e)) -> HostArray data DeviceArray where DeviceArray :: Typeable e => {-# UNPACK #-} !(Weak (DevicePtr e)) -> DeviceArray instance Eq HostArray where HostArray _ a1 == HostArray _ a2 = maybe False (== a2) (gcast a1) instance Hashable HostArray where hashWithSalt salt (HostArray cid sn) = salt `hashWithSalt` cid `hashWithSalt` sn instance Show HostArray where show (HostArray _ sn) = "Array #" ++ show (hashStableName sn) -- Referencing arrays -- ------------------ -- Create a new hash table from host to device arrays. When the structure is -- collected it will finalise all entries in the table. -- new :: IO MemoryTable new = do message "initialise memory table" tbl <- HT.new ref <- newIORef tbl nrs <- N.new weak <- mkWeakIORef ref (table_finalizer tbl) return $! MemoryTable ref weak nrs -- Look for the device memory corresponding to a given host-side array. -- lookup :: (Typeable a, Typeable b) => Context -> MemoryTable -> ArrayData a -> IO (Maybe (DevicePtr b)) lookup ctx (MemoryTable !ref _ _) !arr = do sa <- makeStableArray ctx arr mw <- withIORef ref (`HT.lookup` sa) case mw of Nothing -> trace ("lookup/not found: " ++ show sa) $ return Nothing Just (DeviceArray w) -> do mv <- deRefWeak w case mv of Just v | Just p <- gcast v -> trace ("lookup/found: " ++ show sa) $ return (Just p) | otherwise -> INTERNAL_ERROR(error) "lookup" $ "type mismatch" Nothing -> makeStableArray ctx arr >>= \x -> INTERNAL_ERROR(error) "lookup" $ "dead weak pair: " ++ show x -- Allocate a new device array to be associated with the given host-side array. -- This will attempt to use an old array from the nursery, but will otherwise -- allocate fresh data. -- -- Instead of allocating the exact number of elements requested, we round up to -- a fixed chunk size; currently set at 128 elements. This means there is a -- greater chance the nursery will get a hit, and moreover that we can search -- the nursery for an exact size. TLM: I believe the CUDA API allocates in -- chunks, of size 4MB. -- malloc :: forall a b. (Typeable a, Typeable b, Storable b) => Context -> MemoryTable -> ArrayData a -> Int -> IO (DevicePtr b) malloc !ctx mt@(MemoryTable _ _ !nursery) !ad !n = do let -- next highest multiple of f from x multiple x f = floor ((x + (f-1)) / f :: Double) chunk = 128 !n' = chunk * multiple (fromIntegral n) (fromIntegral chunk) !bytes = n' * sizeOf (undefined :: b) -- mp <- N.lookup bytes (deviceContext ctx) nursery ptr <- case mp of Just p -> trace "malloc/nursery" $ return (CUDA.castDevPtr p) Nothing -> trace "malloc/new" $ CUDA.mallocArray n' `catch` \(e :: CUDAException) -> case e of ExitCode OutOfMemory -> reclaim mt >> CUDA.mallocArray n' _ -> throwIO e insert ctx mt ad ptr bytes return ptr -- Record an association between a host-side array and a new device memory area. -- The device memory will be freed when the host array is garbage collected. -- insert :: (Typeable a, Typeable b) => Context -> MemoryTable -> ArrayData a -> DevicePtr b -> Int -> IO () insert !ctx (MemoryTable !ref !weak_ref (Nursery _ !weak_nrs)) !arr !ptr !bytes = do key <- makeStableArray ctx arr dev <- DeviceArray `fmap` mkWeak arr ptr (Just $ finalizer (weakContext ctx) weak_ref weak_nrs key ptr bytes) tbl <- readIORef ref message $ "insert: " ++ show key HT.insert tbl key dev -- Removing entries -- ---------------- -- Initiate garbage collection and finalise any arrays that have been marked as -- unreachable. -- reclaim :: MemoryTable -> IO () reclaim (MemoryTable _ weak_ref (Nursery nrs _)) = do (free, total) <- CUDA.getMemInfo performGC yield withIORef nrs N.flush mr <- deRefWeak weak_ref case mr of Nothing -> return () Just ref -> withIORef ref $ \tbl -> flip HT.mapM_ tbl $ \(_,DeviceArray w) -> do alive <- isJust `fmap` deRefWeak w unless alive $ finalize w -- D.when D.dump_gc $ do (free', _) <- CUDA.getMemInfo message $ "reclaim: freed " ++ showBytes (fromIntegral (free - free')) ++ ", " ++ showBytes (fromIntegral free') ++ " of " ++ showBytes (fromIntegral total) ++ " remaining" -- Because a finaliser might run at any time, we must reinstate the context in -- which the array was allocated before attempting to release it. -- -- Note also that finaliser threads will silently terminate if an exception is -- raised. If the context, and thereby all allocated memory, was destroyed -- externally before the thread had a chance to run, all we need do is update -- the hash tables --- but we must do this first before failing to use a dead -- context. -- finalizer :: Weak CUDA.Context -> Weak MT -> Weak NRS -> HostArray -> DevicePtr b -> Int -> IO () finalizer !weak_ctx !weak_ref !weak_nrs !key !ptr !bytes = do mr <- deRefWeak weak_ref case mr of Nothing -> message ("finalise/dead table: " ++ show key) Just ref -> withIORef ref (`HT.delete` key) -- mc <- deRefWeak weak_ctx case mc of Nothing -> message ("finalise/dead context: " ++ show key) Just ctx -> do -- mn <- deRefWeak weak_nrs case mn of Nothing -> trace ("finalise/free: " ++ show key) $ bracket_ (CUDA.push ctx) CUDA.pop (CUDA.free ptr) Just nrs -> trace ("finalise/nursery: " ++ show key) $ N.insert bytes ctx nrs ptr table_finalizer :: HashTable HostArray DeviceArray -> IO () table_finalizer !tbl = trace "table finaliser" $ HT.mapM_ (\(_,DeviceArray w) -> finalize w) tbl -- Miscellaneous -- ------------- {-# INLINE makeStableArray #-} makeStableArray :: Typeable a => Context -> ArrayData a -> IO HostArray makeStableArray !ctx !arr = let CUDA.Context !p = deviceContext ctx !cid = fromIntegral (ptrToIntPtr p) in HostArray cid <$> makeStableName arr {-# INLINE withIORef #-} withIORef :: IORef a -> (a -> IO b) -> IO b withIORef ref f = readIORef ref >>= f -- Debug -- ----- {-# INLINE showBytes #-} showBytes :: Int -> String showBytes x = D.showFFloatSIBase (Just 0) 1024 (fromIntegral x :: Double) "B" {-# INLINE trace #-} trace :: String -> IO a -> IO a trace msg next = D.message D.dump_gc ("gc: " ++ msg) >> next {-# INLINE message #-} message :: String -> IO () message s = s `trace` return ()