{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE LambdaCase #-}
module Data.Array.Accelerate.Array.Remote.Nursery (
Nursery(..), NRS, new, lookup, insert, cleanup, size
) where
import Data.Array.Accelerate.Error
import qualified Data.Array.Accelerate.Debug as Debug
import Control.Concurrent.MVar
import Data.Int
import Data.Sequence ( Seq )
import Data.Word
import System.Mem.Weak ( Weak )
import Prelude hiding ( lookup )
import qualified Data.HashTable.IO as HT
import qualified Data.Sequence as Seq
import qualified Data.Traversable as Seq
-- | Mutable hash table used as the nursery's backing store.
type HashTable key val = HT.CuckooHashTable key val

-- | The nursery state: a lock-protected table mapping an 'Int' key to the
-- queue of pointers currently stored under that key. The key is also the
-- amount passed to the nursery byte counters and is multiplied by the queue
-- length in 'size', so it is used as a per-pointer size in bytes.
type NRS ptr = MVar ( HashTable Int (Seq (ptr Word8)) )

-- | A nursery: the shared table together with a weak pointer to it. The weak
-- pointer is the one created in 'new', which carries the finalizer.
data Nursery ptr = Nursery {-# UNPACK #-} !(NRS ptr)
                           {-# UNPACK #-} !(Weak (NRS ptr))
-- | Create a fresh, empty nursery.
--
-- A finalizer is attached to the nursery's 'MVar' with 'mkWeakMVar'. When it
-- runs, it calls 'cleanup', which applies the supplied @delete@ action to
-- every pointer still stored in the nursery.
--
{-# INLINEABLE new #-}
new :: (ptr Word8 -> IO ()) -> IO (Nursery ptr)
new delete = do
  message "initialise nursery"
  nrs  <- HT.new
  ref  <- newMVar nrs
  weak <- mkWeakMVar ref (cleanup delete ref)
  return $! Nursery ref weak
-- | Take one pointer stored under the given key, if there is one.
--
-- The pointer at the head of the key's queue is removed and returned. Since
-- 'insert' appends to the tail, this is the earliest-inserted pointer still
-- present. When it was the last pointer for the key, the key is deleted from
-- the table. On a hit, the nursery byte counter is decreased by @key@.
--
{-# INLINEABLE lookup #-}
lookup :: HasCallStack => Int -> Nursery ptr -> IO (Maybe (ptr Word8))
lookup !key (Nursery !ref !_) =
  withMVar ref $ \nrs ->
    HT.mutateIO nrs key $ \case
      Nothing -> return (Nothing, Nothing)
      Just r  ->
        case Seq.viewl r of
          v Seq.:< vs -> do
            Debug.decreaseCurrentBytesNursery (fromIntegral key)
            if Seq.null vs
              then return (Nothing, Just v)   -- last entry: drop the key
              else return (Just vs, Just v)   -- keep the remaining entries
          -- Unreachable: insertion only ever stores non-empty sequences, and
          -- the branch above deletes the key rather than storing an empty one.
          Seq.EmptyL -> internalError "expected non-empty sequence"
-- | Add a pointer to the nursery under the given key.
--
-- The pointer is appended to the tail of the key's queue, so 'lookup' hands
-- pointers back in insertion order. The nursery byte counter is increased by
-- @key@. This mirrors the decrease done by 'lookup' and the reset to zero done
-- by 'cleanup'. The previous code bumped the /remote/ counter here instead, so
-- the nursery counter drifted negative.
--
{-# INLINEABLE insert #-}
insert :: Int -> ptr Word8 -> Nursery ptr -> IO ()
insert !key !val (Nursery !ref _) =
  withMVar ref $ \nrs -> do
    Debug.increaseCurrentBytesNursery (fromIntegral key)
    HT.mutate nrs key $ \case
      Nothing -> (Just (Seq.singleton val), ())
      Just vs -> (Just (vs Seq.|> val),     ())
-- | Free every pointer held by the nursery and replace its table with a
-- fresh, empty one.
--
-- 'new' installs this as the finalizer of the nursery's 'MVar'. It is also
-- exported, so it may be called directly. The nursery byte counter is reset
-- to zero.
--
{-# INLINEABLE cleanup #-}
cleanup :: (ptr Word8 -> IO ()) -> NRS ptr -> IO ()
cleanup delete !ref = do
  message "nursery cleanup"
  modifyMVar_ ref $ \nrs -> do
    -- apply 'delete' to every pointer of every key
    HT.mapM_ (Seq.mapM delete . snd) nrs
    Debug.setCurrentBytesNursery 0
    HT.new
-- | The total size of all entries in the nursery. This is the sum, over all
-- keys, of @key * (number of pointers stored under that key)@.
--
{-# INLINEABLE size #-}
size :: Nursery ptr -> IO Int64
size (Nursery ref _)
  = withMVar ref
  $ HT.foldM step 0
  where
    -- Widen to 'Int64' before multiplying, so the product cannot overflow a
    -- narrower 'Int'. '$!' keeps the running total evaluated instead of
    -- building a chain of thunks across the fold.
    step :: Int64 -> (Int, Seq (ptr Word8)) -> IO Int64
    step !acc (k, v) =
      return $! acc + fromIntegral k * fromIntegral (Seq.length v)
-- | Trace a nursery debug message, prefixed with @"gc: "@, through
-- 'Debug.traceIO' using the 'Debug.dump_gc' flag.
--
{-# INLINE message #-}
message :: String -> IO ()
message msg = Debug.traceIO Debug.dump_gc ("gc: " ++ msg)