{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE ConstraintKinds     #-}
{-# LANGUAGE DoAndIfThenElse     #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE LambdaCase          #-}
{-# LANGUAGE PatternGuards       #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell     #-}
{-# OPTIONS_HADDOCK hide #-}
-- |
-- Module      : Data.Array.Accelerate.Array.Remote.LRU
-- Copyright   : [2015..2017] Manuel M T Chakravarty, Gabriele Keller, Robert Clifton-Everest
--               [2016..2017] Trevor L. McDonell
-- License     : BSD3
--
-- Maintainer  : Robert Clifton-Everest
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--
-- This module extends the memory tables provided by
-- 'Data.Array.Accelerate.Array.Remote.Table' with an LRU caching policy that
-- evicts old arrays from the remote memory space once it runs out of memory.
-- Consequently, use of this module requires the backend client to keep track of
-- which remote arrays are currently being used, so that they will not be
-- evicted. See: 'withRemote' for more details on this requirement.
--

module Data.Array.Accelerate.Array.Remote.LRU (

  -- Tables for host/device memory associations
  MemoryTable, new, withRemote, malloc, free, insertUnmanaged, reclaim,

  -- Asynchronous tasks
  Task(..)

) where

import Control.Concurrent.MVar                                  ( MVar, newMVar, withMVar, takeMVar, putMVar, mkWeakMVar )
import Control.Monad                                            ( filterM )
import Control.Monad.Catch
import Control.Monad.IO.Class                                   ( MonadIO, liftIO )
import Data.Functor
import Data.Int                                                 ( Int64 )
import Data.Maybe                                               ( isNothing )
import Data.Proxy
import Foreign.Storable                                         ( sizeOf )
import System.CPUTime
import System.Mem.Weak                                          ( Weak, deRefWeak, finalize )
import Prelude                                                  hiding ( lookup )

import qualified Data.HashTable.IO                              as HT

import Data.Array.Accelerate.Array.Data                         ( ArrayData, touchArrayData )
import Data.Array.Accelerate.Array.Remote.Class
import Data.Array.Accelerate.Array.Remote.Table                 ( StableArray, makeWeakArrayData )
import Data.Array.Accelerate.Error                              ( internalError )
import qualified Data.Array.Accelerate.Array.Remote.Table       as Basic
import qualified Data.Array.Accelerate.Debug                    as D


-- We build cached memory tables on top of a basic memory table.
--
-- A key invariant is that the arrays in the MemoryTable are a subset of the
-- arrays in the UseTable. The UseTable reflects all arrays that have ever been
-- in the cache.
--
-- The three fields are, in order:
--
--   1. the underlying basic memory table holding the host/remote associations;
--   2. the use table, guarded by an 'MVar' that doubles as the cache lock;
--   3. a weak reference to that same 'MVar', which is the handle given to the
--      per-array finalisers so that they do not keep the table alive.
--
data MemoryTable p task = MemoryTable {-# UNPACK #-} !(Basic.MemoryTable p)
                                      {-# UNPACK #-} !(UseTable task)
                                      {-# UNPACK #-} !(Weak (UseTable task))

-- The raw (unlocked) usage table, keyed by the stable name of the host array.
type UT task            = HT.CuckooHashTable StableArray (Used task)

-- The usage table behind its lock.
type UseTable task      = MVar (UT task)

-- The state of the remote copy of an array relative to the host copy.
--
data Status = Clean     -- Array in remote memory matches array in host memory.
            | Dirty     -- Array in remote memory has been modified.
            | Unmanaged -- Array in remote memory was injected by FFI, so we
                        -- cannot remove it under any circumstance.
            | Evicted   -- Array has been evicted from remote memory
            deriving Eq

-- Time of last use. Every timestamp stored in this module is obtained from
-- 'getCPUTime'; smaller values are treated as older when choosing what to evict.
--
type Timestamp = Integer

-- Book-keeping record stored in the use table for every array that has ever
-- been placed in the cache. The element type of the array is existentially
-- hidden, which is why the 'PrimElt' dictionary is captured by the constructor.
--
data Used task where
  Used :: PrimElt e a
       => !Timestamp                              -- Time the entry was last touched
       -> !Status                                 -- State of the remote copy
       -> {-# UNPACK #-} !Int                     -- Use count
       -> ![task]                                 -- Asynchronous tasks using the array
       -> {-# UNPACK #-} !Int                     -- Array size
       -> {-# UNPACK #-} !(Weak (ArrayData e))    -- Weak reference to the host array
       -> Used task

-- |A Task represents a process executing asynchronously that can be polled for
-- its status. This is necessary for backends that work asynchronously (i.e.
-- the CUDA backend). If a backend is synchronous, the () instance can be used.
--
class Task task where
  -- |Returns true when the task has finished.
  completed :: task -> IO Bool

-- The trivial task for synchronous backends: it is always already finished.
instance Task () where
  completed () = return True

-- |Create a new memory cache from host to remote arrays.
--
-- The function supplied should be the `free` for the remote pointers being
-- stored. This function will be called by the GC, which typically runs on a
-- different thread. Unlike the `free` in `RemoteMemory`, this function cannot
-- depend on any state.
--
-- The use table starts out empty. A finaliser ('cache_finalizer') is attached
-- to the 'MVar' holding the use table, and the resulting weak reference is
-- stored as the third field of the returned 'MemoryTable'.
--
new :: (forall a. ptr a -> IO ()) -> IO (MemoryTable ptr task)
new release = do
  mt        <- Basic.new release
  utbl      <- HT.new
  ref       <- newMVar utbl
  weak_utbl <- mkWeakMVar ref (cache_finalizer utbl)
  return $! MemoryTable mt ref weak_utbl

-- |Perform some action that requires the remote pointer corresponding to
-- the given array. Returns `Nothing` if the array has NEVER been in the
-- cache. If the array was previously in the cache, but was evicted due to its
-- age, then the array will be copied back from host memory.
--
-- The continuation passed as the third argument needs to obey some precise
-- properties. As with all bracketed functions, the supplied remote pointer must
-- not leak out of the function, as it is only guaranteed to be valid within it.
-- If it is required that it does leak (e.g.
-- the backend uses concurrency to
-- interleave execution of different parts of the program), then `completed` on
-- the returned task should not return true until it is guaranteed there are no
-- more accesses of the remote pointer.
--
-- The work is done in two separate critical sections on the use table:
--
--   1. Under the lock, the entry's use count is incremented and the remote
--      pointer is looked up (re-uploading the array first if it was evicted).
--
--   2. The continuation runs WITHOUT the lock held. Afterwards the lock is
--      taken again to decrement the use count, refresh the timestamp and
--      record the returned task against the entry.
--
-- NOTE(review): if the continuation throws, step 2 never runs and the use
-- count stays incremented; since eviction only considers entries whose count
-- is zero, such an entry can no longer be evicted. Confirm this conservative
-- behaviour is intended.
--
withRemote
    :: forall task m a b c. (PrimElt a b, Task task, RemoteMemory m, MonadIO m, Functor m)
    => MemoryTable (RemotePtr m) task
    -> ArrayData a
    -> (RemotePtr m b -> m (task, c))
    -> m (Maybe c)
withRemote (MemoryTable !mt !ref _) !arr run = do
  key <- Basic.makeStableArray arr
  mp  <- withMVar' ref $ \utbl -> do
    -- Bump the use count stored in the table, but hand back the entry as it
    -- was BEFORE the bump (hence the extra 'incCount' applied below when the
    -- entry has to be re-inserted by 'copyBack').
    mu  <- liftIO . HT.mutate utbl key $ \case
      Nothing -> (Nothing,           Nothing)
      Just u  -> (Just (incCount u), Just u)
    --
    case mu of
      Nothing -> do
        message ("withRemote/array has never been malloc'd: " ++ show key)
        return Nothing -- The array was never in the table

      Just u  -> do
        mp  <- liftIO $ Basic.lookup mt arr
        ptr <- case mp of
                 Just p          -> return p
                 Nothing
                   | isEvicted u -> copyBack utbl (incCount u)
                   | otherwise   -> do message ("lost array " ++ show key)
                                       $internalError "withRemote" "non-evicted array has been lost"
        return (Just ptr)
  --
  case mp of
    Nothing  -> return Nothing
    Just ptr -> Just <$> go key ptr
  where
    -- Record that one use has finished: refresh the timestamp, drop any tasks
    -- that have already completed, add the new task and decrement the count.
    updateTask :: Used task -> task -> IO (Used task)
    updateTask (Used _ status count tasks n weak_arr) task = do
      ts     <- getCPUTime
      tasks' <- cleanUses tasks
      return (Used ts status (count - 1) (task : tasks') n weak_arr)

    -- Re-allocate remote memory for an evicted array and upload the host
    -- contents again. The entry is re-inserted with status 'Clean'.
    copyBack :: UT task -> Used task -> m (RemotePtr m b)
    copyBack utbl (Used ts _ count tasks n weak_arr) = do
      message "withRemote/reuploading-evicted-array"
      p <- mallocWithUsage mt utbl arr (Used ts Clean count tasks n weak_arr)
      pokeRemote n p arr
      return p

    -- We can't combine the use of `withMVar ref` above with the one here
    -- because the `permute` operation from the PTX backend requires nested
    -- calls to `withRemote` in order to copy the defaults array.
    --
    go :: StableArray -> RemotePtr m b -> m c
    go key ptr = do
      message ("withRemote/using: " ++ show key)
      (task, c) <- run ptr
      liftIO . withMVar ref $ \utbl -> do
        HT.mutateIO utbl key $ \case
          Nothing -> $internalError "withRemote" "invariant violated"
          Just u  -> do
            u' <- updateTask u task
            return (Just u', ())
        --
        -- Touch the host array only after the table has been updated.
        -- NOTE(review): presumably this keeps the host array (and with it the
        -- finaliser attached to it) alive until this point -- confirm.
        touchArrayData arr
        return c


-- | Allocate a new device array to be associated with the given host-side array.
-- This has similar behaviour to the malloc of the basic memory table in
-- "Data.Array.Accelerate.Array.Remote.Table", but also will copy remote arrays
-- back to main memory in order to make space.
--
-- The third argument indicates that the array should be considered frozen. That
-- is to say that the array contents will never change. In the event that the
-- array has to be evicted from the remote memory, the copy already residing in
-- host memory should be considered valid.
--
-- If this function is called on an array that is already contained within the
-- cache, this is a no-op.
--
-- On return, 'True' indicates that we allocated some remote memory, and 'False'
-- indicates that we did not need to.
--
malloc :: forall a e m task. (PrimElt e a, RemoteMemory m, MonadIO m, Task task)
       => MemoryTable (RemotePtr m) task
       -> ArrayData e
       -> Bool                -- ^ True if host array is frozen.
       -> Int
       -> m Bool              -- ^ Was the array allocated successfully?
malloc (MemoryTable mt ref weak_utbl) !ad !frozen !n = do
  ts  <- liftIO $ getCPUTime
  key <- Basic.makeStableArray ad
  --
  -- A frozen host array can simply be dropped on eviction ('Clean'); anything
  -- else is assumed to need copying back first ('Dirty').
  let status = if frozen
                 then Clean
                 else Dirty
  --
  withMVar' ref $ \utbl -> do
    mu <- liftIO $ HT.lookup utbl key
    if isNothing mu
      then do
        -- First time this array is seen: attach a finaliser that removes the
        -- entry from the use table, then allocate with a zero use count.
        weak_arr <- liftIO $ makeWeakArrayData ad ad (Just $ finalizer key weak_utbl)
        _        <- mallocWithUsage mt utbl ad (Used ts status 0 [] n weak_arr)
        return True
      else
        return False

-- Allocate remote memory for the array through the basic table and, once that
-- succeeds, store the supplied usage record in the use table. Whenever the
-- basic allocation fails, one array is evicted and the allocation is retried;
-- if nothing can be evicted this is an internal error.
--
-- The caller is expected to already hold the use-table lock, as the table is
-- passed in unwrapped.
--
mallocWithUsage
    :: forall a e m task. (PrimElt e a, RemoteMemory m, MonadIO m, Task task)
    => Basic.MemoryTable (RemotePtr m)
    -> UT task
    -> ArrayData e
    -> Used task
    -> m (RemotePtr m a)
mallocWithUsage !mt !utbl !ad !usage@(Used _ _ _ _ n _) = malloc'
  where
    malloc' = do
      mp <- Basic.malloc mt ad n :: m (Maybe (RemotePtr m a))
      case mp of
        Nothing -> do
          success <- evictLRU utbl mt
          if success then malloc'
                     else $internalError "malloc" "Remote memory exhausted"
        Just p -> liftIO $ do
          key <- Basic.makeStableArray ad
          HT.insert utbl key usage
          return p

-- Evict the least recently used array that is eligible for eviction. Returns
-- 'True' if an entry was selected (its remote memory has been released via
-- the basic table) and 'False' if no entry is currently eligible.
--
-- An entry is eligible when its use count is zero, its status is 'Clean' or
-- 'Dirty', and none of its recorded tasks are still running.
--
evictLRU
    :: forall m task. (RemoteMemory m, MonadIO m, Task task)
    => UT task
    -> Basic.MemoryTable (RemotePtr m)
    -> m Bool
evictLRU !utbl !mt = trace "evictLRU/evicting-eldest-array" $ do
  mused <- liftIO $ HT.foldM eldest Nothing utbl
  case mused of
    Just (sa, Used ts status count tasks n weak_arr) -> do
      mad <- liftIO $ deRefWeak weak_arr
      case mad of
        Nothing  -> liftIO $ do
          -- This can only happen if our eviction process was interrupted by
          -- garbage collection. In which case, even though we didn't actually
          -- evict anything, we should return true, as we know some remote
          -- memory is now free.
          --
          -- Small caveat: Due to finalisers being delayed, it's a good idea
          -- to free the array here.
          Basic.freeStable (Proxy :: Proxy m) mt sa
          delete utbl sa
          message "evictLRU/Accelerate GC interrupted by GHC GC"

        Just arr -> do
          message ("evictLRU/evicting " ++ show sa)
          copyIfNecessary status n arr
          liftIO $ D.didEvictBytes (remoteBytes n weak_arr)
          liftIO $ Basic.freeStable (Proxy :: Proxy m) mt sa
          -- Keep the entry, but mark it so 'withRemote' knows to re-upload.
          liftIO $ HT.insert utbl sa (Used ts Evicted count tasks n weak_arr)
      return True
    _ -> trace "evictLRU/All arrays in use, unable to evict" $ return False
  where
    -- Find the eldest, not currently in use, array.
    --
    -- As a side effect, every candidate visited has its completed tasks
    -- pruned and is written back to the table. The candidate that is returned
    -- is the record as it was before pruning.
    --
    -- NOTE(review): this writes to the table that 'HT.foldM' is traversing;
    -- only values of existing keys are replaced, but confirm that this is
    -- safe for the hash table implementation in use.
    eldest :: (Maybe (StableArray, Used task)) -> (StableArray, Used task) -> IO (Maybe (StableArray, Used task))
    eldest prev (sa, used@(Used ts status count tasks n weak_arr))
      | count == 0
      , evictable status
      = do
          tasks' <- cleanUses tasks
          HT.insert utbl sa (Used ts status count tasks' n weak_arr)
          case tasks' of
            [] | Just (_, Used ts' _ _ _ _ _) <- prev
               , ts < ts'        -> return (Just (sa, used))
               | Nothing <- prev -> return (Just (sa, used))
            _  -> return prev
    eldest prev _ = return prev

    -- Size in bytes of n elements of the array's primitive element type. The
    -- weak pointer is used only to fix the element type.
    remoteBytes :: forall e a. PrimElt e a => Int -> Weak (ArrayData e) -> Int64
    remoteBytes n _ = fromIntegral n * fromIntegral (sizeOf (undefined::a))

    -- Only arrays that are resident and owned by the cache may be evicted.
    evictable :: Status -> Bool
    evictable Clean     = True
    evictable Dirty     = True
    evictable Unmanaged = False
    evictable Evicted   = False

    -- Copy the remote contents back to the host array, but only when the
    -- remote copy is marked 'Dirty'.
    copyIfNecessary :: PrimElt e a => Status -> Int -> ArrayData e -> m ()
    copyIfNecessary Clean     _ _  = return ()
    copyIfNecessary Unmanaged _ _  = return ()
    copyIfNecessary Evicted   _ _  = $internalError "evictLRU" "Attempting to evict already evicted array"
    copyIfNecessary Dirty     n ad = do
      mp <- liftIO $ Basic.lookup mt ad
      case mp of
        Nothing -> return () -- RCE: I think this branch is actually impossible.
        Just p  -> peekRemote n p ad

-- | Deallocate the device array associated with the given host-side array.
-- Typically this should only be called in very specific circumstances. This
-- operation is not thread-safe.
--
-- The entry is removed from the use table first, and the remote memory is
-- then released through the basic table, both while holding the table lock.
--
free :: (RemoteMemory m, PrimElt a b)
     => proxy m
     -> MemoryTable (RemotePtr m) task
     -> ArrayData a
     -> IO ()
free proxy (MemoryTable !mt !ref _) !arr
  = withMVar' ref
  $ \utbl -> do
      key <- Basic.makeStableArray arr
      delete utbl key
      Basic.freeStable proxy mt key

-- |Record an association between a host-side array and a remote memory area
-- that was not allocated by accelerate. The remote memory will NOT be re-used
-- once the host-side array is garbage collected.
--
-- This typically only has use for backends that provide an FFI.
--
insertUnmanaged
    :: (PrimElt e a, MonadIO m)
    => MemoryTable p task
    -> ArrayData e
    -> p a
    -> m ()
insertUnmanaged (MemoryTable mt ref weak_utbl) !arr !ptr =
  liftIO (withMVar ref register)
  where
    -- Runs with the use table locked: record the pointer in the basic table
    -- and add an 'Unmanaged' entry (zero uses, no tasks, size zero).
    register table = do
      key  <- Basic.makeStableArray arr
      ()   <- Basic.insertUnmanaged mt arr ptr
      now  <- getCPUTime
      weak <- makeWeakArrayData arr arr (Just (finalizer key weak_utbl))
      HT.insert table key (Used now Unmanaged 0 [] 0 weak)


-- Removing entries
-- ----------------

-- Finaliser attached to host arrays: drop the array's entry from the use
-- table, provided the table itself is still alive.
--
finalizer :: StableArray -> Weak (UseTable task) -> IO ()
finalizer !key !weak_utbl =
  deRefWeak weak_utbl >>= \case
    Nothing  -> message "finalize cache/dead table"
    Just ref -> do
      message ("finalize cache: " ++ show key)
      withMVar' ref $ \table -> delete table key

-- Remove an entry from the (already locked) use table.
--
delete :: UT task -> StableArray -> IO ()
delete table key = HT.delete table key

-- |Ask the underlying basic memory table to reclaim: initiate garbage
-- collection and `free` any remote arrays whose host-side equivalents are
-- gone.
--
reclaim :: forall m task. (RemoteMemory m, MonadIO m) => MemoryTable (RemotePtr m) task -> m ()
reclaim cache =
  case cache of
    MemoryTable !mt _ _ -> Basic.reclaim mt

-- Finaliser for the whole cache: run the finaliser of every weak array
-- pointer that is still recorded in the use table.
--
cache_finalizer :: UT task -> IO ()
cache_finalizer !tbl = do
  message "cache finaliser"
  HT.mapM_ (finalizeEntry . snd) tbl
  where
    finalizeEntry :: Used task -> IO ()
    finalizeEntry (Used _ _ _ _ _ weak) = finalize weak

-- Miscellaneous
-- -------------

-- Keep only those tasks which have not yet completed.
--
cleanUses :: Task task => [task] -> IO [task]
cleanUses = filterM stillRunning
  where
    stillRunning t = not <$> completed t

-- Register one more active use of the array.
--
incCount :: Used task -> Used task
incCount (Used ts status count uses n weak_arr) =
  let count' = count + 1
  in  Used ts status count' uses n weak_arr

-- Has the remote copy of this array been evicted?
--
isEvicted :: Used task -> Bool
isEvicted (Used _ Evicted _ _ _ _) = True
isEvicted _                        = False

-- Like 'withMVar', but for any monad that supports masking. The contents are
-- put back whether the action returns normally or raises an exception.
--
{-# INLINE withMVar' #-}
withMVar' :: (MonadIO m, MonadMask m) => MVar a -> (a -> m b) -> m b
withMVar' var action =
  mask $ \unmask -> do
    x <- takeMVar' var
    r <- unmask (action x) `onException` putMVar' var x
    putMVar' var x
    return r

{-# INLINE putMVar' #-}
putMVar' :: (MonadIO m, MonadMask m) => MVar a -> a -> m ()
putMVar' var x = liftIO $ putMVar var x

{-# INLINE takeMVar' #-}
takeMVar' :: (MonadIO m, MonadMask m) => MVar a -> m a
takeMVar' var = liftIO $ takeMVar var


-- Debug
-- -----

-- Emit a debug message, then continue with the given action.
--
{-# INLINE trace #-}
trace :: MonadIO m => String -> m a -> m a
trace msg next = do
  message msg
  next

-- Emit a debug message, prefixed with "gc: ", under the dump_gc flag.
--
{-# INLINE message #-}
message :: MonadIO m => String -> m ()
message msg = liftIO (D.traceIO D.dump_gc ("gc: " ++ msg))