{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE LambdaCase   #-}
-- |
-- Module      : Data.Array.Accelerate.Array.Remote.Nursery
-- Copyright   : [2008..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.Array.Remote.Nursery (

  Nursery(..), NRS, new, lookup, insert, cleanup, size

) where

-- friends
import Data.Array.Accelerate.Error
import qualified Data.Array.Accelerate.Debug                    as Debug

-- libraries
import Control.Concurrent.MVar
import Data.Int
import Data.Sequence                                            ( Seq )
import Data.Word
import System.Mem.Weak                                          ( Weak )
import Prelude                                                  hiding ( lookup )
import qualified Data.HashTable.IO                              as HT
import qualified Data.Sequence                                  as Seq
import qualified Data.Traversable                               as Seq


-- The nursery is a place to store remote memory arrays that are no longer
-- needed. Often it is quicker to reuse an existing array, rather than call out
-- to the external API to allocate fresh memory.
--
-- The nursery is wrapped in an MVar so that several threads may safely access
-- it concurrently.
--
type HashTable k v = HT.CuckooHashTable k v

-- | The nursery table behind a lock: allocation size in bytes -> the queue of
-- currently unused buffers of exactly that size.
type NRS ptr = MVar (HashTable Int (Seq (ptr Word8)))

-- | A nursery handle: the locked table, together with a weak reference to the
-- same 'MVar' (created by 'new').
data Nursery ptr
  = Nursery {-# UNPACK #-} !(NRS ptr)
            {-# UNPACK #-} !(Weak (NRS ptr))


-- | Create a fresh, empty nursery.
--
-- The argument is the function used to free a single retained buffer. It is
-- registered, via 'cleanup', as the finaliser of a weak pointer to the
-- nursery's 'MVar'. When that 'MVar' is garbage collected, every buffer still
-- held by the nursery is passed to it.
--
{-# INLINEABLE new #-}
new :: (ptr Word8 -> IO ()) -> IO (Nursery ptr)
new release = do
  message "initialise nursery"
  ref  <- newMVar =<< HT.new
  weak <- mkWeakMVar ref (cleanup release ref)
  pure $! Nursery ref weak


-- | Try to take a buffer of exactly @key@ bytes out of the nursery.
--
-- Returns 'Nothing' when no buffer of that size is held. Otherwise the buffer
-- at the front of that size's queue is removed and returned, and the nursery
-- byte counter is decreased by @key@. The size's entry is dropped from the
-- table once its queue becomes empty, so stored queues are never empty.
--
{-# INLINEABLE lookup #-}
lookup :: HasCallStack => Int -> Nursery ptr -> IO (Maybe (ptr Word8))
lookup !key (Nursery !ref !_) =
  withMVar ref $ \table ->
    HT.mutateIO table key $ \bucket ->
      case Seq.viewl <$> bucket of
        Nothing               -> return (Nothing, Nothing)
        Just Seq.EmptyL       -> internalError "expected non-empty sequence"
        Just (p Seq.:< rest)  -> do
          Debug.decreaseCurrentBytesNursery (fromIntegral key)
          let remaining | Seq.null rest = Nothing    -- remove the now-empty entry
                        | otherwise     = Just rest  -- keep the rest for later
          return (remaining, Just p)


-- | Return a buffer of @key@ bytes to the nursery, so that a later 'lookup'
-- of the same size can reuse it.
--
-- Buffers of the same size are queued in insertion order: they are appended
-- at the back here, and 'lookup' takes from the front. The nursery byte
-- counter is increased by @key@, mirroring the decrease done in 'lookup'.
--
{-# INLINEABLE insert #-}
insert :: Int -> ptr Word8 -> Nursery ptr -> IO ()
insert !key !val (Nursery !ref _) =
  withMVar ref $ \nrs -> do
    -- The buffer is moving *into* the nursery. The counter to grow is
    -- therefore the nursery one, which 'lookup' decrements and 'cleanup'
    -- resets, not the remote-bytes counter.
    Debug.increaseCurrentBytesNursery (fromIntegral key)
    HT.mutate nrs key $ \case
      Nothing -> (Just (Seq.singleton val), ())
      Just vs -> (Just (vs Seq.|> val),     ())


-- | Release every buffer held by the nursery, and replace its table with a
-- fresh, empty one.
--
-- Each retained buffer is passed to the supplied function, and then the
-- nursery byte counter is reset to zero. 'new' registers this as the
-- finaliser of the nursery's weak pointer; it is also exported for direct use.
--
{-# INLINEABLE cleanup #-}
cleanup :: (ptr Word8 -> IO ()) -> NRS ptr -> IO ()
cleanup delete !ref = do
  message "nursery cleanup"
  modifyMVar_ ref $ \nrs -> do
    -- The results of 'delete' are discarded. Use 'mapM_' rather than building
    -- a throw-away @Seq ()@ for every bucket.
    HT.mapM_ (mapM_ delete . snd) nrs
    Debug.setCurrentBytesNursery 0
    HT.new


-- | The total number of bytes retained by the nursery. For every entry, this
-- adds the allocation size multiplied by the number of buffers of that size.
--
{-# INLINEABLE size #-}
size :: Nursery ptr -> IO Int64
size (Nursery ref _)
  = withMVar ref
  $ HT.foldM step 0
  where
    -- Force the running total at every step so the fold does not build a
    -- chain of thunks. Widen to Int64 before multiplying so the product
    -- cannot overflow a (possibly 32-bit) 'Int'.
    step :: Int64 -> (Int, Seq (ptr Word8)) -> IO Int64
    step !acc (bytes, bufs) =
      return $! acc + fromIntegral bytes * fromIntegral (Seq.length bufs)


-- Debug
-- -----

-- | Trace a nursery event under the 'Debug.dump_gc' flag, prefixing the text
-- with @"gc: "@.
{-# INLINE message #-}
message :: String -> IO ()
message = Debug.traceIO Debug.dump_gc . ("gc: " ++)