{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DoAndIfThenElse #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# OPTIONS_HADDOCK hide #-}
module Data.Array.Accelerate.Array.Remote.LRU (
MemoryTable, new, withRemote, malloc, free, insertUnmanaged, reclaim,
Task(..)
) where
import Control.Concurrent.MVar ( MVar, newMVar, takeMVar, putMVar, mkWeakMVar )
import Control.Monad ( filterM )
import Control.Monad.Catch
import Control.Monad.IO.Class ( MonadIO, liftIO )
import Data.Functor
import Data.Int ( Int64 )
import Data.Maybe ( isNothing )
import Data.Proxy
import Foreign.Storable ( sizeOf )
import System.CPUTime
import System.Mem.Weak ( Weak, deRefWeak, finalize )
import Prelude hiding ( lookup )
import qualified Data.HashTable.IO as HT
import Data.Array.Accelerate.Array.Data ( ArrayData, touchArrayData )
import Data.Array.Accelerate.Array.Remote.Class
import Data.Array.Accelerate.Array.Remote.Table ( StableArray, makeWeakArrayData )
import Data.Array.Accelerate.Error ( internalError )
import qualified Data.Array.Accelerate.Array.Remote.Table as Basic
import qualified Data.Array.Accelerate.Debug as D
-- | A remote-memory table with least-recently-used eviction.
--
-- The fields are:
--
--   * the basic table from "Data.Array.Accelerate.Array.Remote.Table", which
--     holds the actual remote allocations;
--   * the usage table, guarded by an 'MVar', with one 'Used' entry per array;
--   * a weak pointer to that same 'MVar', which is handed to the finaliser of
--     every tracked host array.
data MemoryTable p task
  = MemoryTable {-# UNPACK #-} !(Basic.MemoryTable p)
                {-# UNPACK #-} !(UseTable task)
                {-# UNPACK #-} !(Weak (UseTable task))

-- | The usage table together with the lock that serialises access to it.
type UseTable task = MVar (UT task)

-- | Mutable hash table from an array's 'StableArray' key to its usage entry.
type UT task = HT.BasicHashTable StableArray (Used task)
-- | Where an array tracked by the LRU table currently stands.
data Status
  = Clean      -- ^ Frozen at allocation, or just re-uploaded; eviction needs no copy back.
  | Dirty      -- ^ Not frozen; copied back to the host before it is evicted.
  | Unmanaged  -- ^ Registered with 'insertUnmanaged'; never chosen for eviction.
  | Evicted    -- ^ Remote memory released by 'evictLRU'; re-uploaded on next 'withRemote'.
  deriving (Eq)
-- | CPU time, as returned by 'getCPUTime', used to order entries by recency.
type Timestamp = Integer

-- | Usage information kept for a single array. The constructor fields are,
-- in order:
--
--   * the time the entry was created, refreshed whenever a use finishes;
--   * the entry's 'Status';
--   * the number of uses currently in progress; 'evictLRU' only considers
--     entries where this is zero;
--   * tasks from earlier uses that may not have 'completed' yet;
--   * the number of elements allocated on the remote side;
--   * a weak pointer to the host array.
data Used task where
  Used :: PrimElt e a
       => !Timestamp                               -- last-use time
       -> !Status                                  -- lifecycle state
       -> {-# UNPACK #-} !Int                      -- in-progress uses
       -> ![task]                                  -- possibly-outstanding tasks
       -> {-# UNPACK #-} !Int                      -- element count
       -> {-# UNPACK #-} !(Weak (ArrayData e))     -- host array
       -> Used task
-- | A unit of work recorded against an array (presumably an asynchronous
-- remote operation — the class itself only exposes completion). 'evictLRU'
-- never chooses an entry that still has an incomplete task.
class Task task where
  -- | Whether this piece of work has finished.
  completed :: task -> IO Bool

-- | The trivial task, which has always finished.
instance Task () where
  completed () = pure True
-- | Create an empty LRU memory table.
--
-- The release function is passed straight to 'Basic.new'. The usage table
-- starts empty. Its 'MVar' gets 'cache_finalizer' attached as a finaliser,
-- and the resulting weak pointer is stored in the table.
new :: (forall a. ptr a -> IO ()) -> IO (MemoryTable ptr task)
new release = do
  basic    <- Basic.new release
  usage    <- HT.new
  usageVar <- newMVar usage
  weakVar  <- mkWeakMVar usageVar (cache_finalizer usage)
  return $! MemoryTable basic usageVar weakVar
-- | Run an action that needs the remote copy of an array.
--
-- The array must already have an entry in the usage table (see 'malloc' and
-- 'insertUnmanaged'). Otherwise a debug message is emitted, the action is not
-- run, and 'Nothing' is returned.
--
-- Before the action runs, the entry's in-progress use count is incremented
-- while the table lock is held. This keeps 'evictLRU' from choosing it. If the
-- array had been 'Evicted', it is re-allocated and its data is written back
-- with 'pokeRemote' first.
--
-- The action returns a @task@ alongside its result. Afterwards the entry is
-- updated under the lock:
--
--   * its timestamp is refreshed with 'getCPUTime';
--   * its use count is decremented;
--   * the new task is prepended to the list of tasks that have not yet
--     'completed'.
--
-- NOTE(review): if the action throws, the use count is never decremented, so
-- the array can no longer be evicted. Confirm this is intended.
withRemote
    :: forall task m a b c. (PrimElt a b, Task task, RemoteMemory m, MonadIO m, Functor m)
    => MemoryTable (RemotePtr m) task
    -> ArrayData a
    -> (RemotePtr m b -> m (task, c))
    -> m (Maybe c)
withRemote (MemoryTable !mt !ref _) !arr run = do
  key <- Basic.makeStableArray arr
  -- Find (or re-create) the remote pointer while holding the table lock.
  mp <- withMVar' ref $ \utbl -> do
    mu <- liftIO $ HT.lookup utbl key
    case mu of
      Nothing -> do
        message ("withRemote/array has never been malloc'd: " ++ show key)
        return Nothing
      Just u -> do
        -- Bump the use count first, so the entry is not an eviction
        -- candidate while it is being (re-)materialised and used.
        mp <- liftIO $ do
          HT.insert utbl key (incCount u)
          Basic.lookup mt arr
        case mp of
          -- No remote allocation because it was evicted: allocate and upload
          -- again. The entry passed on already carries the bumped count.
          Nothing | isEvicted u -> Just <$> copy utbl (incCount u)
          Just p -> return (Just p)
          -- Missing from the basic table although not marked as evicted.
          _ -> do
            message ("lost array " ++ show key)
            $internalError "withRemote" "non-evicted array has been lost"
  case mp of
    Just p -> Just <$> run' p
    Nothing -> return Nothing
  where
    -- Record that one use has finished. Refresh the timestamp, decrement the
    -- use count, and prepend the new task to the pruned task list. The entry
    -- is expected to still be present.
    updateTask :: Maybe (Used task) -> task -> IO (Used task)
    updateTask mu task = do
      ts <- getCPUTime
      case mu of
        Nothing -> $internalError "withRemote" "Invariant violated"
        Just (Used _ status count tasks n weak_arr) -> do
          tasks' <- cleanUses tasks
          return (Used ts status (count - 1) (task : tasks') n weak_arr)

    -- Re-allocate an evicted array (status becomes 'Clean') and write the
    -- host data to it with 'pokeRemote' (presumably a host-to-remote copy).
    copy :: UT task -> Used task -> m (RemotePtr m b)
    copy utbl (Used ts _ count tasks n weak_arr) = do
      message "withRemote/reuploading-evicted-array"
      p <- mallocWithUsage mt utbl arr (Used ts Clean count tasks n weak_arr)
      pokeRemote n p arr
      return p

    -- Run the user action, then record its task under the table lock.
    run' :: RemotePtr m b -> m c
    run' p = do
      key <- Basic.makeStableArray arr
      message ("withRemote/using: " ++ show key)
      (task, c) <- run p
      withMVar' ref $ \utbl -> liftIO $ do
        mu <- HT.lookup utbl key
        u <- updateTask mu task
        HT.insert utbl key u
      -- Presumably keeps 'arr' reachable until this point, so its
      -- finaliser cannot remove the entry mid-use. TODO confirm.
      liftIO $ touchArrayData arr
      return c
-- | Allocate remote memory for a host array and start tracking its usage.
--
-- If the array already has an entry in the usage table, nothing is allocated
-- and the result is 'False'. Otherwise the memory is allocated, evicting older
-- arrays if needed (see 'mallocWithUsage'), and the result is 'True'.
--
-- A @frozen@ array starts out 'Clean'; any other array starts out 'Dirty'.
-- The 'Int' argument is the number of elements to allocate.
malloc :: forall a e m task. (PrimElt e a, RemoteMemory m, MonadIO m, Task task)
       => MemoryTable (RemotePtr m) task
       -> ArrayData e
       -> Bool
       -> Int
       -> m Bool
malloc (MemoryTable mt ref weak_utbl) !ad !frozen !n = do
  now <- liftIO getCPUTime
  sa  <- Basic.makeStableArray ad
  let initialStatus
        | frozen    = Clean
        | otherwise = Dirty
  withMVar' ref $ \utbl -> do
    existing <- liftIO $ HT.lookup utbl sa
    if not (isNothing existing)
      then return False
      else do
        -- The finaliser removes this entry once the host array is collected.
        weak_arr <- liftIO $ makeWeakArrayData ad ad (Just $ finalizer sa weak_utbl)
        _ <- mallocWithUsage mt utbl ad (Used now initialStatus 0 [] n weak_arr)
        return True
-- | Allocate remote memory for an array and store the given usage entry.
--
-- The element count comes from the usage entry. When the basic table cannot
-- satisfy the request, 'evictLRU' is asked to free something and the
-- allocation is retried. If nothing can be evicted, an internal error is
-- raised. On success the entry is stored under the array's 'StableArray' key
-- and the new pointer is returned.
mallocWithUsage
    :: forall a e m task. (PrimElt e a, RemoteMemory m, MonadIO m, Task task)
    => Basic.MemoryTable (RemotePtr m)
    -> UT task
    -> ArrayData e
    -> Used task
    -> m (RemotePtr m a)
mallocWithUsage !mt utbl !ad !usage@(Used _ _ _ _ n _) = attempt
  where
    -- Try the allocation once.
    attempt :: m (RemotePtr m a)
    attempt = do
      result <- Basic.malloc mt ad n :: m (Maybe (RemotePtr m a))
      maybe evictAndRetry record result

    -- Out of memory: evict the least-recently-used array, then try again.
    evictAndRetry :: m (RemotePtr m a)
    evictAndRetry = do
      freed <- evictLRU utbl mt
      if freed
        then attempt
        else $internalError "malloc" "Remote memory exhausted"

    -- Allocation succeeded: remember how the array is being used.
    record :: RemotePtr m a -> m (RemotePtr m a)
    record p = liftIO $ do
      sa <- Basic.makeStableArray ad
      HT.insert utbl sa usage
      return p
-- | Evict the least-recently-used array to make room for a new allocation.
--
-- Candidates must meet all of the following:
--
--   * no uses in progress (count of zero);
--   * an 'evictable' status, i.e. 'Clean' or 'Dirty';
--   * no outstanding tasks once completed ones have been pruned.
--
-- Among the candidates, the one with the smallest 'Timestamp' is chosen. A
-- 'Dirty' array is copied back to the host first. Its remote memory is then
-- released and the entry is marked 'Evicted'.
--
-- Returns 'True' if remote memory was released. This includes the case where
-- the chosen host array had already been garbage collected. Returns 'False'
-- when no entry is eligible.
evictLRU :: forall m task. (RemoteMemory m, MonadIO m, Task task)
         => UT task
         -> Basic.MemoryTable (RemotePtr m)
         -> m Bool
evictLRU utbl mt = trace "evictLRU/evicting-eldest-array" $ do
  mused <- liftIO $ HT.foldM eldest Nothing utbl
  case mused of
    Just (sa, Used ts status count tasks n weak_arr) -> do
      mad <- liftIO $ deRefWeak weak_arr
      case mad of
        -- The host array is already gone, so there is nothing to copy back.
        -- Free the remote memory and drop the entry now, rather than waiting
        -- for its finaliser; memory has still been released.
        Nothing -> liftIO $ do
          Basic.freeStable (Proxy :: Proxy m) mt sa
          delete utbl sa
          message "evictLRU/Accelerate GC interrupted by GHC GC"
        Just arr -> do
          message ("evictLRU/evicting " ++ show sa)
          copyIfNecessary status n arr
          liftIO $ D.didEvictBytes (remoteBytes n weak_arr)
          liftIO $ Basic.freeStable (Proxy :: Proxy m) mt sa
          -- 'tasks' is the pruned list returned by 'eldest' (empty for any
          -- selected entry), so the pruning is not undone here.
          liftIO $ HT.insert utbl sa (Used ts Evicted count tasks n weak_arr)
      return True
    _ -> trace "evictLRU/All arrays in use, unable to evict" $ return False
  where
    -- Fold step that picks the eviction candidate with the oldest timestamp.
    -- For each entry that is not in use and has an evictable status, drop the
    -- completed tasks and write the pruned entry back. Only entries left with
    -- no outstanding tasks remain candidates. The *pruned* entry is returned,
    -- so the caller's later write keeps the pruned task list.
    eldest :: (Maybe (StableArray, Used task)) -> (StableArray, Used task) -> IO (Maybe (StableArray, Used task))
    eldest prev (sa, Used ts status count tasks n weak_arr)
      | count == 0
      , evictable status = do
          tasks' <- cleanUses tasks
          let used' = Used ts status count tasks' n weak_arr
          HT.insert utbl sa used'
          case tasks' of
            [] | Just (_, Used ts' _ _ _ _ _) <- prev
               , ts < ts'        -> return (Just (sa, used'))
               | Nothing <- prev -> return (Just (sa, used'))
            _ -> return prev
    eldest prev _ = return prev

    -- Size in bytes of @n@ elements of the array's primitive type. The weak
    -- pointer argument only serves to fix the element type.
    remoteBytes :: forall e a. PrimElt e a => Int -> Weak (ArrayData e) -> Int64
    remoteBytes n _ = fromIntegral n * fromIntegral (sizeOf (undefined::a))

    -- Only 'Clean' and 'Dirty' entries may be evicted.
    evictable :: Status -> Bool
    evictable Clean     = True
    evictable Dirty     = True
    evictable Unmanaged = False
    evictable Evicted   = False

    -- Before eviction, copy a 'Dirty' array back to the host with
    -- 'peekRemote' (presumably a remote-to-host copy). No other status needs
    -- a copy, and an already-'Evicted' array should never be selected.
    copyIfNecessary :: PrimElt e a => Status -> Int -> ArrayData e -> m ()
    copyIfNecessary Clean     _ _ = return ()
    copyIfNecessary Unmanaged _ _ = return ()
    copyIfNecessary Evicted   _ _ = $internalError "evictLRU" "Attempting to evict already evicted array"
    copyIfNecessary Dirty     n ad = do
      mp <- liftIO $ Basic.lookup mt ad
      case mp of
        Nothing -> return ()
        Just p  -> peekRemote n p ad
-- | Forget an array's usage entry and release its remote memory, with the
-- table lock held for both steps.
free :: (RemoteMemory m, PrimElt a b)
     => proxy m
     -> MemoryTable (RemotePtr m) task
     -> ArrayData a
     -> IO ()
free proxy (MemoryTable !mt !ref _) !arr =
  withMVar' ref $ \utbl ->
    Basic.makeStableArray arr >>= \sa -> do
      delete utbl sa
      Basic.freeStable proxy mt sa
-- | Track an array whose remote pointer is supplied by the caller.
--
-- The pointer is registered with the basic table, and a usage entry with
-- status 'Unmanaged' is recorded. Such entries are never chosen by
-- 'evictLRU'. The entry records no uses, no tasks and an element count of 0.
-- A finaliser on the host array removes the entry once the array is
-- collected.
insertUnmanaged
    :: (PrimElt e a, MonadIO m)
    => MemoryTable p task
    -> ArrayData e
    -> p a
    -> m ()
insertUnmanaged (MemoryTable mt ref weak_utbl) !arr !ptr =
  liftIO $ withMVar' ref $ \utbl -> do
    sa <- Basic.makeStableArray arr
    Basic.insertUnmanaged mt arr ptr
    now <- getCPUTime
    weak_arr <- makeWeakArrayData arr arr (Just $ finalizer sa weak_utbl)
    HT.insert utbl sa (Used now Unmanaged 0 [] 0 weak_arr)
-- | Finaliser attached to each tracked host array.
--
-- If the usage table is still alive, it removes the array's entry from the
-- table while holding the table lock. Otherwise it only emits a debug message.
finalizer :: StableArray -> Weak (UseTable task) -> IO ()
finalizer !key !weak_utbl =
  deRefWeak weak_utbl >>= maybe tableGone dropEntry
  where
    tableGone = message "finalize cache/dead table"
    dropEntry ref =
      trace ("finalize cache: " ++ show key) $ withMVar' ref (`delete` key)
-- | Remove a key's entry from the usage table.
--
-- Callers may pass keys that are absent (for example, a finaliser running
-- after 'free'). This relies on 'HT.delete' leaving the table unchanged in
-- that case, which makes a prior 'HT.lookup' unnecessary and saves a second
-- hash probe.
delete :: UT task -> StableArray -> IO ()
delete utbl key = HT.delete utbl key
-- | Delegate to 'Basic.reclaim' on the underlying basic table. The usage
-- table is not touched.
reclaim
    :: forall m task. (RemoteMemory m, MonadIO m)
    => MemoryTable (RemotePtr m) task
    -> m ()
reclaim table = case table of
  MemoryTable !mt _ _ -> Basic.reclaim mt
-- | Finaliser for the usage table's 'MVar'. It runs the finaliser of every
-- host-array weak pointer still recorded in the table.
cache_finalizer :: UT task -> IO ()
cache_finalizer !tbl =
  trace "cache finaliser" $ HT.mapM_ (finalizeEntry . snd) tbl
  where
    finalizeEntry :: Used task -> IO ()
    finalizeEntry (Used _ _ _ _ _ weak_arr) = finalize weak_arr
-- | Keep only the tasks that have not yet 'completed', preserving their order.
cleanUses :: Task task => [task] -> IO [task]
cleanUses tasks = filterM (\t -> not <$> completed t) tasks
-- | Record one more in-progress use of an entry; other fields are unchanged.
incCount :: Used task -> Used task
incCount (Used ts status count uses n weak_arr) =
  let count' = count + 1
  in  Used ts status count' uses n weak_arr
-- | Whether an entry's status is 'Evicted'.
isEvicted :: Used task -> Bool
isEvicted (Used _ Evicted _ _ _ _) = True
isEvicted _                        = False
-- | A 'withMVar' for any 'MonadMask' monad.
--
-- It takes the 'MVar' contents with asynchronous exceptions masked, then runs
-- the action with the caller's masking state restored. The original contents
-- are put back whether the action returns normally or throws.
withMVar' :: (MonadIO m, MonadMask m) => MVar a -> (a -> m b) -> m b
withMVar' var action = mask $ \unmask -> do
  contents <- liftIO (takeMVar var)
  let release = liftIO (putMVar var contents)
  result <- unmask (action contents) `onException` release
  release
  return result
-- | Emit a GC debug 'message', then continue with the given action.
trace :: MonadIO m => String -> m a -> m a
trace msg next = do
  message msg
  next
{-# INLINE message #-}
-- | Write a message, prefixed with @gc: @, through 'D.traceIO' under the
-- 'D.dump_gc' debug flag.
message :: MonadIO m => String -> m ()
message msg = liftIO . D.traceIO D.dump_gc $ "gc: " ++ msg