module Data.Array.Accelerate.Array.Remote.Table (
MemoryTable, new, lookup, malloc, free, freeStable, insertUnmanaged, reclaim,
StableArray, makeStableArray,
makeWeakArrayData
) where
import Control.Concurrent ( yield )
import Control.Concurrent.MVar ( MVar, newMVar, withMVar, mkWeakMVar )
import Control.Concurrent.Unique ( Unique )
import Control.Monad.IO.Class ( MonadIO, liftIO )
import Data.Functor
import Data.Hashable ( hash, Hashable )
import Data.Maybe ( isJust )
import Data.Proxy
import Data.Typeable ( Typeable, gcast )
import Data.Word
import Foreign.Storable ( sizeOf )
import System.Mem ( performGC )
import System.Mem.Weak ( Weak, deRefWeak )
import Prelude hiding ( lookup, id )
import qualified Data.HashTable.IO as HT
import GHC.Exts ( Ptr(..) )
import Data.Array.Accelerate.Error ( internalError )
import Data.Array.Accelerate.Array.Unique ( UniqueArray(..) )
import Data.Array.Accelerate.Array.Data ( ArrayData, GArrayData(..),
ArrayPtrs, ArrayElt, arrayElt, ArrayEltR(..) )
import Data.Array.Accelerate.Array.Remote.Class
import Data.Array.Accelerate.Array.Remote.Nursery ( Nursery(..) )
import Data.Array.Accelerate.Lifetime
import qualified Data.Array.Accelerate.Array.Remote.Nursery as N
import qualified Data.Array.Accelerate.Debug as D
-- | Hash table implementation used for the memory table.
type HashTable key val = HT.BasicHashTable key val

-- | The table proper: a mutable hash table from host-array keys to remote
-- allocations. Every access in this module goes through 'withMVar' on this
-- 'MVar'.
type MT p = MVar ( HashTable StableArray (RemoteArray p) )

-- | Tracks the remote copies of host arrays, together with a nursery of
-- previously freed remote allocations available for reuse.
data MemoryTable p = MemoryTable !(MT p)
                                 -- The table itself.
                                 !(Weak (MT p))
                                 -- Weak pointer to the same 'MVar' (created in
                                 -- 'new' with 'mkWeakMVar'). 'clean' and
                                 -- 'remoteFinalizer' dereference this and do
                                 -- nothing further if it is dead; presumably so
                                 -- finalizers do not keep the table alive.
                                 !(Nursery p)
                                 -- Freed remote allocations: filled by
                                 -- 'freeStable', consulted by 'malloc', emptied
                                 -- by 'purge'.
                                 (p Word8 -> IO ())
                                 -- Release function supplied to 'new'; 'purge'
                                 -- passes it to 'N.cleanup'. Stored lazily (no
                                 -- strictness annotation).
-- | A remote allocation associated with a host array.
--
-- The element type @e@ is existential. 'lookup' recovers it with 'gcast',
-- which is why the 'Typeable' dictionary is captured by the constructor.
data RemoteArray p where
  RemoteArray :: Typeable e
              => !(Weak ())
              -- Weak pointer created by 'makeWeakArrayData' on the host
              -- array's data. 'lookup' treats a dead one as an internal error;
              -- 'clean' treats it as a candidate for eviction.
              -> !(p e)
              -- The remote pointer.
              -> !Int
              -- Size of the allocation in bytes ('insert' records the size
              -- computed by 'malloc'; 'insertUnmanaged' records 0).
              -> RemoteArray p
-- | Hashable key identifying a host array. It is built by 'makeStableArray'
-- from the 'Unique' of the array's underlying 'UniqueArray'.
newtype StableArray = StableArray Unique
  deriving (Eq, Hashable)

-- | Shows the hash of the wrapped 'Unique'; used in trace messages.
instance Show StableArray where
  show (StableArray u) = show (hash u)
-- | Create an empty memory table.
--
-- @release@ is handed to the nursery ('N.new') and also stored in the table,
-- where 'purge' passes it to 'N.cleanup'. A weak pointer to the table's
-- 'MVar' is created with a no-op finalizer.
new :: (forall a. ptr a -> IO ()) -> IO (MemoryTable ptr)
new release = do
  message "initialise memory table"
  ref  <- newMVar =<< HT.new
  nrs  <- N.new release
  weak <- mkWeakMVar ref (return ())
  return $! MemoryTable ref weak nrs release
-- | Look up the remote pointer recorded for the given host array.
--
-- Returns 'Nothing' if the table has no entry for the array. Fails with an
-- internal error if an entry exists but its pointer has a different element
-- type than requested, or if the entry's weak pointer is already dead.
lookup
  :: (PrimElt a b)
  => MemoryTable p
  -> ArrayData a
  -> IO (Maybe (p b))
lookup (MemoryTable !ref _ _ _) !arr = do
  sa <- makeStableArray arr
  -- Only the hash-table read is done under the lock.
  mw <- withMVar ref (`HT.lookup` sa)
  case mw of
    Nothing -> trace ("lookup/not found: " ++ show sa) $ return Nothing
    Just (RemoteArray w p _) -> do
      mv <- deRefWeak w
      case mv of
        -- 'gcast' checks that the stored existential element type is @b@.
        Just _ | Just p' <- gcast p -> trace ("lookup/found: " ++ show sa) $ return (Just p')
               | otherwise -> $internalError "lookup" $ "type mismatch"
        -- NOTE(review): this re-derives the key from 'arr' instead of reusing
        -- 'sa'. As a side effect it keeps 'arr' reachable until after the
        -- 'deRefWeak' above. That may be intentional; confirm before
        -- simplifying.
        Nothing ->
          makeStableArray arr >>= \x -> $internalError "lookup" $ "dead weak pair: " ++ show x
-- | Allocate remote memory for @n@ elements and associate it with the given
-- host array.
--
-- The request is rounded up to a multiple of 'remoteAllocationSize'. The
-- allocation is attempted, in order:
--
--   1. from the nursery of previously freed allocations;
--   2. with a fresh 'mallocRemote';
--   3. from the nursery again, after 'clean' has evicted entries whose host
--      arrays are no longer alive;
--   4. with 'mallocRemote' again, after 'purge' has run 'N.cleanup' on the
--      nursery.
--
-- On success, the pointer is recorded in the table via 'insert' and returned.
-- If every attempt fails, the result is 'Nothing'.
malloc :: forall a b m. (PrimElt a b, RemoteMemory m, MonadIO m)
       => MemoryTable (RemotePtr m)
       -> ArrayData a
       -> Int
       -> m (Maybe (RemotePtr m b))
malloc mt@(MemoryTable _ _ !nursery _) !ad !n = do
  chunk <- remoteAllocationSize
  let
      -- Number of chunks of size @f@ needed to cover @x@ (ceiling division).
      multiple x f = (x + (f - 1)) `div` f
      bytes        = chunk * multiple (n * sizeOf (undefined::b)) chunk
  message ("malloc: " ++ showBytes bytes)
  -- 'orElse' (default fixity infixl 9) binds tighter than '<$>' (infixl 4),
  -- so the cast is applied to the result of the whole fallback chain.
  mp <-
    fmap (castRemotePtr (Proxy :: Proxy m))
      <$> attempt "malloc/nursery" (liftIO $ N.lookup bytes nursery)
      `orElse`
      attempt "malloc/new" (mallocRemote bytes)
      `orElse` do
        message "malloc/remote-malloc-failed (cleaning)"
        clean mt
        liftIO $ N.lookup bytes nursery
      `orElse` do
        message "malloc/remote-malloc-failed (purging)"
        purge mt
        mallocRemote bytes
      `orElse` do
        message "malloc/remote-malloc-failed (non-recoverable)"
        return Nothing
  case mp of
    Nothing -> return Nothing
    Just p' -> do
      insert mt ad p' bytes
      return (Just p')
  where
    -- Run the first action; run the second only if the first yields 'Nothing'.
    orElse :: m (Maybe x) -> m (Maybe x) -> m (Maybe x)
    orElse ra rb = do
      ma <- ra
      case ma of
        Nothing -> rb
        Just a  -> return (Just a)

    -- Run an action and emit the trace message only when it succeeds.
    attempt :: String -> m (Maybe x) -> m (Maybe x)
    attempt msg next = do
      ma <- next
      case ma of
        Nothing -> return Nothing
        Just a  -> trace msg (return (Just a))
-- | Remove the table entry for a host array, returning its remote allocation
-- to the nursery. This is 'freeStable' applied to the array's key.
free :: (RemoteMemory m, PrimElt a b)
     => proxy m
     -> MemoryTable (RemotePtr m)
     -> ArrayData a
     -> IO ()
free proxy mt !arr = makeStableArray arr >>= freeStable proxy mt
-- | Remove the entry for the given key from the table, while holding the
-- table lock.
--
-- If the entry is present, its remote pointer is inserted into the nursery
-- (where 'malloc' can reuse it) and the remote-bytes counter is decreased by
-- the entry's size. If it is absent, only a trace message is emitted.
freeStable
  :: RemoteMemory m
  => proxy m
  -> MemoryTable (RemotePtr m)
  -> StableArray
  -> IO ()
freeStable proxy (MemoryTable !ref _ !nrs _) !sa =
  withMVar ref $ \tbl ->
    HT.lookup tbl sa >>= \entry ->
      case entry of
        Nothing ->
          message ("free/already-removed: " ++ show sa)
        Just (RemoteArray _ !remote !size) -> do
          message ("free/evict: " ++ show sa ++ " of " ++ showBytes size)
          N.insert size (castRemotePtr proxy remote) nrs
          D.decreaseCurrentBytesRemote (fromIntegral size)
          HT.delete tbl sa
-- | Record a freshly allocated remote pointer of @bytes@ bytes for a host
-- array, and increase the remote-bytes counter accordingly.
--
-- A finalizer that runs 'freeStable' on this entry is attached to the host
-- array's data. When that data is finalised, the allocation therefore goes
-- back to the nursery.
insert :: forall m a b. (PrimElt a b, RemoteMemory m, MonadIO m)
       => MemoryTable (RemotePtr m)
       -> ArrayData a
       -> RemotePtr m b
       -> Int
       -> m ()
insert mt@(MemoryTable !ref _ _ _) !arr !ptr !bytes = do
  key <- makeStableArray arr
  liftIO $ do
    weak <- makeWeakArrayData arr () (Just (freeStable (Proxy :: Proxy m) mt key))
    message ("insert: " ++ show key)
    D.increaseCurrentBytesRemote (fromIntegral bytes)
    withMVar ref $ \tbl -> HT.insert tbl key (RemoteArray weak ptr bytes)
-- | Record an externally managed remote pointer for a host array.
--
-- The entry is stored with size 0, and the remote-bytes counter is not
-- touched. When the host array's data is finalised, 'remoteFinalizer' drops
-- the entry from the table; the pointer is never placed in the nursery.
insertUnmanaged
  :: (PrimElt a b, MonadIO m)
  => MemoryTable p
  -> ArrayData a
  -> p b
  -> m ()
insertUnmanaged (MemoryTable !ref !weak_ref _ _) !arr !ptr = do
  key <- makeStableArray arr
  liftIO $ do
    weak <- makeWeakArrayData arr () (Just (remoteFinalizer weak_ref key))
    message ("insertUnmanaged: " ++ show key)
    withMVar ref $ \tbl -> HT.insert tbl key (RemoteArray weak ptr 0)
-- | Evict every entry whose host array is no longer alive.
--
-- The sequence is:
--
--   * record a remote GC in the debug counters;
--   * force a garbage collection, then 'yield' (presumably so pending
--     finalizers get a chance to run);
--   * if the table itself is still alive, collect the keys whose weak
--     pointers have died and free each with 'freeStable'.
--
-- The keys are freed after the table lock is released, because 'freeStable'
-- takes the same lock itself. The whole action is wrapped in 'management'
-- for logging.
clean :: forall m. (RemoteMemory m, MonadIO m) => MemoryTable (RemotePtr m) -> m ()
clean mt@(MemoryTable _ weak_ref nrs _) = management "clean" nrs . liftIO $ do
  D.didRemoteGC
  performGC
  yield
  maybe (return ()) sweep =<< deRefWeak weak_ref
  where
    sweep ref = do
      dead <- withMVar ref (HT.foldM removable [])
      mapM_ (freeStable (Proxy :: Proxy m) mt) dead

    removable acc (sa, RemoteArray w _ _) = do
      alive <- isJust <$> deRefWeak w
      return (if alive then acc else sa : acc)
-- | Pass the nursery's contents and the table's release function to
-- 'N.cleanup'. The action is wrapped in 'management' for logging.
purge :: (RemoteMemory m, MonadIO m) => MemoryTable (RemotePtr m) -> m ()
purge (MemoryTable _ _ nursery@(Nursery nrs _) release) =
  management "purge" nursery (liftIO (N.cleanup release nrs))
-- | Reclaim remote memory: first 'clean' the table, then 'purge' the nursery.
reclaim :: forall m. (RemoteMemory m, MonadIO m) => MemoryTable (RemotePtr m) -> m ()
reclaim mt = do
  clean mt
  purge mt
-- | Finalizer used for unmanaged entries.
--
-- If the table is still alive, delete @key@ from it. Otherwise, only emit a
-- trace message.
remoteFinalizer :: Weak (MT p) -> StableArray -> IO ()
remoteFinalizer !weak_ref !key = deRefWeak weak_ref >>= maybe whenDead whenAlive
  where
    whenDead      = message ("finalise/dead table: " ++ show key)
    whenAlive ref = trace ("finalise: " ++ show key) $ withMVar ref (`HT.delete` key)
-- | Build the table key for a host array: the 'Unique' of its underlying
-- 'UniqueArray', wrapped as a 'StableArray'.
--
-- Only the scalar 'ArrayEltR' representations listed below are handled. Any
-- other representation reaches the catch-all and raises an 'error'.
makeStableArray
  :: (MonadIO m, Typeable a, Typeable e, ArrayPtrs a ~ Ptr e, ArrayElt a)
  => ArrayData a
  -> m StableArray
makeStableArray !ad = return $! StableArray (id arrayElt ad)
  where
    -- Local helper; the Prelude's 'id' is hidden in the import list so this
    -- name is free. Its type variable @e@ is fresh (not tied to the outer
    -- constraints), so the catch-all below is still needed for the remaining
    -- 'ArrayEltR' constructors.
    id :: ArrayEltR e -> ArrayData e -> Unique
    id ArrayEltRint (AD_Int ua) = uniqueArrayId ua
    id ArrayEltRint8 (AD_Int8 ua) = uniqueArrayId ua
    id ArrayEltRint16 (AD_Int16 ua) = uniqueArrayId ua
    id ArrayEltRint32 (AD_Int32 ua) = uniqueArrayId ua
    id ArrayEltRint64 (AD_Int64 ua) = uniqueArrayId ua
    id ArrayEltRword (AD_Word ua) = uniqueArrayId ua
    id ArrayEltRword8 (AD_Word8 ua) = uniqueArrayId ua
    id ArrayEltRword16 (AD_Word16 ua) = uniqueArrayId ua
    id ArrayEltRword32 (AD_Word32 ua) = uniqueArrayId ua
    id ArrayEltRword64 (AD_Word64 ua) = uniqueArrayId ua
    id ArrayEltRcshort (AD_CShort ua) = uniqueArrayId ua
    id ArrayEltRcushort (AD_CUShort ua) = uniqueArrayId ua
    id ArrayEltRcint (AD_CInt ua) = uniqueArrayId ua
    id ArrayEltRcuint (AD_CUInt ua) = uniqueArrayId ua
    id ArrayEltRclong (AD_CLong ua) = uniqueArrayId ua
    id ArrayEltRculong (AD_CULong ua) = uniqueArrayId ua
    id ArrayEltRcllong (AD_CLLong ua) = uniqueArrayId ua
    id ArrayEltRcullong (AD_CULLong ua) = uniqueArrayId ua
    id ArrayEltRfloat (AD_Float ua) = uniqueArrayId ua
    id ArrayEltRdouble (AD_Double ua) = uniqueArrayId ua
    id ArrayEltRcfloat (AD_CFloat ua) = uniqueArrayId ua
    id ArrayEltRcdouble (AD_CDouble ua) = uniqueArrayId ua
    id ArrayEltRbool (AD_Bool ua) = uniqueArrayId ua
    id ArrayEltRchar (AD_Char ua) = uniqueArrayId ua
    id ArrayEltRcchar (AD_CChar ua) = uniqueArrayId ua
    id ArrayEltRcschar (AD_CSChar ua) = uniqueArrayId ua
    id ArrayEltRcuchar (AD_CUChar ua) = uniqueArrayId ua
    id _ _ = error "I do have a cause, though. It is obscenity. I'm for it."
-- | Create a weak pointer with value @c@, keyed on the host array's
-- underlying data ('uniqueArrayData' of its 'UniqueArray').
--
-- If a finalizer is supplied, it is attached to that same data with
-- 'addFinalizer' before the weak pointer is created with 'mkWeak'.
makeWeakArrayData
  :: forall a e c. (ArrayElt e, ArrayPtrs e ~ Ptr a)
  => ArrayData e
  -> c
  -> Maybe (IO ())
  -> IO (Weak c)
makeWeakArrayData !ad !c !mf = mw arrayElt ad
  where
    -- Dispatch on the element representation to reach the 'UniqueArray'.
    mw :: ArrayEltR e -> ArrayData e -> IO (Weak c)
    mw ArrayEltRint (AD_Int ua) = mkWeak' ua
    mw ArrayEltRint8 (AD_Int8 ua) = mkWeak' ua
    mw ArrayEltRint16 (AD_Int16 ua) = mkWeak' ua
    mw ArrayEltRint32 (AD_Int32 ua) = mkWeak' ua
    mw ArrayEltRint64 (AD_Int64 ua) = mkWeak' ua
    mw ArrayEltRword (AD_Word ua) = mkWeak' ua
    mw ArrayEltRword8 (AD_Word8 ua) = mkWeak' ua
    mw ArrayEltRword16 (AD_Word16 ua) = mkWeak' ua
    mw ArrayEltRword32 (AD_Word32 ua) = mkWeak' ua
    mw ArrayEltRword64 (AD_Word64 ua) = mkWeak' ua
    mw ArrayEltRcshort (AD_CShort ua) = mkWeak' ua
    mw ArrayEltRcushort (AD_CUShort ua) = mkWeak' ua
    mw ArrayEltRcint (AD_CInt ua) = mkWeak' ua
    mw ArrayEltRcuint (AD_CUInt ua) = mkWeak' ua
    mw ArrayEltRclong (AD_CLong ua) = mkWeak' ua
    mw ArrayEltRculong (AD_CULong ua) = mkWeak' ua
    mw ArrayEltRcllong (AD_CLLong ua) = mkWeak' ua
    mw ArrayEltRcullong (AD_CULLong ua) = mkWeak' ua
    mw ArrayEltRfloat (AD_Float ua) = mkWeak' ua
    mw ArrayEltRdouble (AD_Double ua) = mkWeak' ua
    mw ArrayEltRcfloat (AD_CFloat ua) = mkWeak' ua
    mw ArrayEltRcdouble (AD_CDouble ua) = mkWeak' ua
    mw ArrayEltRbool (AD_Bool ua) = mkWeak' ua
    mw ArrayEltRchar (AD_Char ua) = mkWeak' ua
    mw ArrayEltRcchar (AD_CChar ua) = mkWeak' ua
    mw ArrayEltRcschar (AD_CSChar ua) = mkWeak' ua
    mw ArrayEltRcuchar (AD_CUChar ua) = mkWeak' ua
    -- The catch-all is compiled only for GHC < 8.0. Presumably newer GHCs can
    -- see from @ArrayPtrs e ~ Ptr a@ that no other constructor can occur, and
    -- would flag it as redundant -- TODO confirm.
#if __GLASGOW_HASKELL__ < 800
    mw _ _ = error "Base eight is just like base ten really - if you're missing two fingers."
#endif

    -- Attach the optional finalizer, then create the weak pointer; both are
    -- keyed on the same underlying data.
    mkWeak' :: UniqueArray a -> IO (Weak c)
    mkWeak' !ua = do
      let !uad = uniqueArrayData ua
      case mf of
        Nothing -> return ()
        Just f -> addFinalizer uad f
      mkWeak uad c
-- | Render a byte count for trace output using base-1024 SI prefixes
-- (precision @Just 0@) and the unit suffix "B".
showBytes :: Integral n => n -> String
showBytes x = D.showFFloatSIBase (Just 0) 1024 asDouble "B"
  where
    asDouble = fromIntegral x :: Double
-- | Emit a GC trace message, then run the given action and return its result.
trace :: MonadIO m => String -> m a -> m a
trace msg next = do
  message msg
  next
-- | Write a message, prefixed with @"gc: "@, through 'D.traceIO' under the
-- 'D.dump_gc' debug flag.
message :: MonadIO m => String -> m ()
message = liftIO . D.traceIO D.dump_gc . ("gc: " ++)
-- | Run a memory-management action and return its result unchanged. When
-- the 'D.dump_gc' flag is set, also log the action's effect.
--
-- The logged figures are:
--
--   * "freed": @after - before@ of the available remote memory;
--   * "stashed": @before - after@ of the nursery size, so it is positive when
--     the action shrank the nursery;
--   * the remaining available memory out of the total.
--
-- NOTE(review): 'clean' moves evicted allocations into the nursery, so
-- "stashed" may be negative for it. Confirm the intended sign convention.
management :: (RemoteMemory m, MonadIO m) => String -> Nursery p -> m a -> m a
management msg nrs next = do
  before     <- availableRemoteMem
  before_nrs <- liftIO $ N.size nrs
  total      <- totalRemoteMem
  r          <- next
  D.when D.dump_gc $ do
    after     <- availableRemoteMem
    after_nrs <- liftIO $ N.size nrs
    message $ msg ++ " (freed: "     ++ showBytes (after - before)
                  ++ ", stashed: "   ++ showBytes (before_nrs - after_nrs)
                  ++ ", remaining: " ++ showBytes after
                  ++ " of "          ++ showBytes total ++ ")"
  return r