{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE CPP                 #-}
{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE ConstraintKinds     #-}
{-# LANGUAGE DoAndIfThenElse     #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE LambdaCase          #-}
{-# LANGUAGE PatternGuards       #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications    #-}
{-# OPTIONS_HADDOCK hide #-}
-- |
-- Module      : Data.Array.Accelerate.Array.Remote.LRU
-- Copyright   : [2015..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--
-- This module extends the memory tables provided by
-- 'Data.Array.Accelerate.Array.Remote.Table' with an LRU caching policy that
-- evicts old arrays from the remote memory space once it runs out of memory.
-- Consequently, use of this module requires the backend client to keep track of
-- which remote arrays are currently being used, so that they will not be
-- evicted. See: 'withRemote' for more details on this requirement.
--
module Data.Array.Accelerate.Array.Remote.LRU (

  -- Tables for host/device memory associations
  MemoryTable, new, withRemote, malloc, free, insertUnmanaged, reclaim,

  -- Asynchronous tasks
  Task(..)

) where

import Data.Array.Accelerate.Analysis.Match                     ( matchSingleType, (:~:)(..) )
import Data.Array.Accelerate.Array.Data
import Data.Array.Accelerate.Array.Remote.Class
import Data.Array.Accelerate.Array.Remote.Table                 ( StableArray, makeWeakArrayData )
import Data.Array.Accelerate.Array.Unique                       ( touchUniqueArray )
import Data.Array.Accelerate.Error                              ( internalError )
import Data.Array.Accelerate.Representation.Elt
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Type
import qualified Data.Array.Accelerate.Array.Remote.Table       as Basic
import qualified Data.Array.Accelerate.Debug                    as D

import Control.Concurrent.MVar                                  ( MVar, newMVar, withMVar, takeMVar, putMVar, mkWeakMVar )
import Control.Monad                                            ( filterM )
import Control.Monad.Catch
import Control.Monad.IO.Class                                   ( MonadIO, liftIO )
import Data.Functor
#if __GLASGOW_HASKELL__ < 808
import Data.Int                                                 ( Int64 )
#endif
import Data.Maybe                                               ( isNothing )
import System.CPUTime
import System.Mem.Weak                                          ( Weak, deRefWeak, finalize )
import Prelude                                                  hiding ( lookup )
import qualified Data.HashTable.IO                              as HT

import GHC.Stack


-- The LRU memory table is layered on top of a basic memory table
-- ('Basic.MemoryTable'), adding a table of usage information per array.
--
-- Invariant: every array held in the basic MemoryTable also has an entry in
-- the UseTable, but not the other way around: the UseTable records every
-- array that has ever been in the cache, including ones since evicted.
--
type UT task       = HT.CuckooHashTable StableArray (Used task)
type UseTable task = MVar (UT task)

data MemoryTable p task
  = MemoryTable {-# UNPACK #-} !(Basic.MemoryTable p)    -- host/remote pointer associations
                {-# UNPACK #-} !(UseTable task)           -- usage information, guarded by an MVar
                {-# UNPACK #-} !(Weak (UseTable task))    -- weak reference to the use table (given to array finalizers)

-- | The state of an array's copy in remote memory, relative to its host copy.
--
data Status = Clean     -- Array in remote memory matches array in host memory.
            | Dirty     -- Array in remote memory has been modified.
            | Unmanaged -- Array in remote memory was injected by FFI, so we
                        -- cannot remove it under any circumstance.
            | Evicted   -- Array has been evicted from remote memory
            deriving Eq

-- | Time of an array's last use, as obtained from 'getCPUTime'.
type Timestamp = Integer

-- | Bookkeeping for a single array that is (or was) in the cache. The element
-- type @e@ is existential; matching on 'Used' recovers it along with the
-- @ArrayData e ~ ScalarArrayData e@ equality.
--
data Used task where
  Used :: ArrayData e ~ ScalarArrayData e
       => !Timestamp                                  -- time of last use ('getCPUTime')
       -> !Status                                     -- state of the remote copy
       -> {-# UNPACK #-} !Int                         -- use count
       -> ![task]                                     -- asynchronous tasks using the array
       -> {-# UNPACK #-} !Int                         -- number of elements
       -> !(SingleType e)                             -- element type
       -> {-# UNPACK #-} !(Weak (ScalarArrayData e))  -- weak pointer to the host array data
       -> Used task

-- | A Task represents a process executing asynchronously that can be polled for
-- its status. This is necessary for backends that work asynchronously (e.g.
-- the CUDA backend). If a backend is synchronous, the () instance can be used.
--
class Task task where
  -- | Returns 'True' when the task has finished.
  completed :: task -> IO Bool

-- | A synchronous backend has no outstanding work, so the unit task always
-- reports that it has completed.
instance Task () where
  completed () = return True

-- | Create a new memory cache from host to remote arrays.
--
-- The function supplied should be the `free` for the remote pointers being
-- stored. This function will be called by the GC, which typically runs on a
-- different thread. Unlike the `free` in `RemoteMemory`, this function cannot
-- depend on any state.
--
new :: (forall a. ptr a -> IO ()) -> IO (MemoryTable ptr task)
new release = do
  mt        <- Basic.new release
  utbl      <- HT.new
  ref       <- newMVar utbl
  -- 'cache_finalizer' is attached to the MVar guarding the use table, so the
  -- GC runs it once that MVar becomes unreachable (see 'mkWeakMVar').
  weak_utbl <- mkWeakMVar ref (cache_finalizer utbl)
  return    $! MemoryTable mt ref weak_utbl

-- | Perform some action that requires the remote pointer corresponding to
-- the given array. Returns `Nothing` if the array has NEVER been in the
-- cache. If the array was previously in the cache, but was evicted due to its
-- age, then the array will be copied back from host memory.
--
-- The continuation passed as the fourth argument needs to obey some precise
-- properties. As with all bracketed functions, the supplied remote pointer must
-- not leak out of the function, as it is only guaranteed to be valid within it.
-- If it is required that it does leak (e.g. the backend uses concurrency to
-- interleave execution of different parts of the program), then `completed` on
-- the returned task should not return true until it is guaranteed there are no
-- more accesses of the remote pointer.
--
withRemote
    :: forall task m a c. (HasCallStack, Task task, RemoteMemory m, MonadIO m, Functor m)
    => MemoryTable (RemotePtr m) task
    -> SingleType a
    -> ArrayData a
    -> (RemotePtr m (ScalarArrayDataR a) -> m (task, c))
    -> m (Maybe c)
withRemote (MemoryTable !mt !ref _) !tp !arr run | SingleArrayDict <- singleArrayDict tp = do
  key <- Basic.makeStableArray tp arr
  -- Phase 1 (use table locked): mark the array as in use and make sure it is
  -- resident in remote memory, re-uploading it if it was evicted.
  mp  <- withMVar' ref $ \utbl -> do
    -- Bump the use count (presumably what marks the array as "in use" for
    -- eviction purposes). 'mu' is the entry as it was *before* the update.
    mu  <- liftIO . HT.mutate utbl key $ \case
      Nothing -> (Nothing,           Nothing)
      Just u  -> (Just (incCount u), Just u)
    --
    case mu of
      Nothing -> do
        message ("withRemote/array has never been malloc'd: " ++ show key)
        return Nothing -- The array was never in the table

      Just u  -> do
        mp  <- liftIO $ Basic.lookup @m mt tp arr
        ptr <- case mp of
                Just p          -> return p
                Nothing
                  | isEvicted u -> copyBack utbl (incCount u)
                  | otherwise   -> do message ("lost array " ++ show key)
                                      internalError "non-evicted array has been lost"
        return (Just ptr)
  --
  -- Phase 2 (use table unlocked): run the continuation.
  case mp of
    Nothing  -> return Nothing
    Just ptr -> Just <$> go key ptr
  where
    -- Record that one use of the array has finished: refresh the timestamp,
    -- decrement the use count, and prepend @task@ to the result of
    -- 'cleanUses' on the existing tasks (presumably dropping those that have
    -- completed — confirm against 'cleanUses').
    updateTask :: Used task -> task -> IO (Used task)
    updateTask (Used _ status count tasks n tp' weak_arr) task = do
      ts      <- getCPUTime
      tasks'  <- cleanUses tasks
      return (Used ts status (count - 1) (task : tasks') n tp' weak_arr)

    -- Re-upload an array that was evicted from remote memory. This allocates
    -- space for it again (which may in turn evict other arrays), records it as
    -- 'Clean', and copies the host data across with 'pokeRemote'.
    copyBack :: HasCallStack => UT task -> Used task -> m (RemotePtr m (ScalarArrayDataR a))
    copyBack utbl (Used ts _ count tasks n tp' weak_arr)
      | Just Refl <- matchSingleType tp tp' = do
        message "withRemote/reuploading-evicted-array"
        p <- mallocWithUsage mt utbl tp arr (Used ts Clean count tasks n tp weak_arr)
        pokeRemote tp n p arr
        return p
      | otherwise = internalError "Type mismatch"

    -- Run the continuation. Afterwards, with the use table locked again,
    -- record the finished use and the task it returned.
    --
    -- We can't combine the use of `withMVar ref` above with the one here
    -- because the `permute` operation from the PTX backend requires nested
    -- calls to `withRemote` in order to copy the defaults array.
    --
    -- NOTE(review): the use count raised in phase 1 is only lowered by
    -- 'updateTask' after 'run' returns normally. If 'run' throws, the count
    -- stays raised, which presumably keeps the array from ever being evicted.
    -- Confirm whether callers rely on this.
    --
    go :: (HasCallStack, ArrayData a ~ ScalarArrayData a)
       => StableArray
       -> RemotePtr m (ScalarArrayDataR a)
       -> m c
    go key ptr = do
      message ("withRemote/using: " ++ show key)
      (task, c) <- run ptr
      liftIO . withMVar ref $ \utbl -> do
        HT.mutateIO utbl key $ \case
          Nothing -> internalError "invariant violated"
          Just u  -> do
            u' <- updateTask u task
            return (Just u', ())
        --
        -- Presumably keeps the host array alive (and its weak pointer
        -- valid) until its entry has been updated.
        touchUniqueArray arr
      return c


-- | Allocate a new device array to be associated with the given host-side array.
-- This has similar behaviour to 'Basic.malloc' from
-- "Data.Array.Accelerate.Array.Remote.Table", but will also evict arrays from
-- remote memory (copying them back to host memory if necessary) in order to
-- make space.
--
-- The fourth argument indicates that the array should be considered frozen. That
-- is to say that the array contents will never change. In the event that the
-- array has to be evicted from the remote memory, the copy already residing in
-- host memory should be considered valid.
--
-- If this function is called on an array that is already contained within the
-- cache, this is a no-op.
--
-- On return, 'True' indicates that we allocated some remote memory, and 'False'
-- indicates that we did not need to.
--
malloc :: forall e m task. (HasCallStack, RemoteMemory m, MonadIO m, Task task)
       => MemoryTable (RemotePtr m) task
       -> SingleType e
       -> ArrayData e
       -> Bool            -- ^ True if host array is frozen.
       -> Int             -- ^ Number of elements
       -> m Bool          -- ^ 'True' if remote memory was allocated, 'False' if the array was already in the cache
malloc (MemoryTable mt ref weak_utbl) !tp !ad !frozen !n | SingleArrayDict <- singleArrayDict tp = do -- Required for ArrayData e ~ ScalarArrayData e
  ts  <- liftIO getCPUTime
  key <- Basic.makeStableArray tp ad
  --
  -- A frozen host array never changes, so the new remote copy starts out
  -- 'Clean'; otherwise it is treated as 'Dirty'.
  let status = if frozen then Clean else Dirty
  --
  withMVar' ref $ \utbl -> do
    mu <- liftIO $ HT.lookup utbl key
    if isNothing mu
      then do
        -- Weak pointer to the host array data. Its finalizer is given the key
        -- and only a weak reference to the use table.
        weak_arr <- liftIO $ makeWeakArrayData tp ad ad (Just $ finalizer key weak_utbl)
        _        <- mallocWithUsage mt utbl tp ad (Used ts status 0 [] n tp weak_arr)
        return True
      else
        return False

-- | Allocate remote memory for the given host array and record @usage@ for it
-- in the use table. The number of elements to allocate is taken from the
-- 'Used' record.
--
-- If the underlying table cannot satisfy the allocation, the least recently
-- used array is evicted with 'evictLRU' and the allocation is retried. If
-- nothing can be evicted, this fails with 'internalError'.
--
mallocWithUsage
    :: forall e m task. (HasCallStack, RemoteMemory m, MonadIO m, Task task, ArrayData e ~ ScalarArrayData e)
    => Basic.MemoryTable (RemotePtr m)
    -> UT task
    -> SingleType e
    -> ArrayData e
    -> Used task
    -> m (RemotePtr m (ScalarArrayDataR e))
mallocWithUsage !mt !utbl !tp !ad !usage@(Used _ _ _ _ n _ _) = malloc'
  where
    malloc' :: HasCallStack => m (RemotePtr m (ScalarArrayDataR e))
    malloc' = do
      mp <- Basic.malloc @e @m mt tp ad n :: m (Maybe (RemotePtr m (ScalarArrayDataR e)))
      case mp of
        Nothing -> do
          -- Out of remote memory: free up some space, then try again.
          success <- evictLRU utbl mt
          if success then malloc'
                     else internalError "Remote memory exhausted"
        Just p -> liftIO $ do
          key <- Basic.makeStableArray tp ad
          HT.insert utbl key usage
          return p

evictLRU
    :: forall m task. (HasCallStack, RemoteMemory m, MonadIO m, Task task)
    => UT task
    -> Basic.MemoryTable (RemotePtr m)
    -> m Bool
evictLRU :: UT task -> MemoryTable (RemotePtr m) -> m Bool
evictLRU !UT task
utbl !MemoryTable (RemotePtr m)
mt = String -> m Bool -> m Bool
forall (m :: * -> *) a. MonadIO m => String -> m a -> m a
trace String
"evictLRU/evicting-eldest-array" (m Bool -> m Bool) -> m Bool -> m Bool
forall a b. (a -> b) -> a -> b
$ do
  Maybe (StableArray, Used task)
mused <- IO (Maybe (StableArray, Used task))
-> m (Maybe (StableArray, Used task))
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO (Maybe (StableArray, Used task))
 -> m (Maybe (StableArray, Used task)))
-> IO (Maybe (StableArray, Used task))
-> m (Maybe (StableArray, Used task))
forall a b. (a -> b) -> a -> b
$ (Maybe (StableArray, Used task)
 -> (StableArray, Used task) -> IO (Maybe (StableArray, Used task)))
-> Maybe (StableArray, Used task)
-> UT task
-> IO (Maybe (StableArray, Used task))
forall (h :: * -> * -> * -> *) a k v.
HashTable h =>
(a -> (k, v) -> IO a) -> a -> IOHashTable h k v -> IO a
HT.foldM Maybe (StableArray, Used task)
-> (StableArray, Used task) -> IO (Maybe (StableArray, Used task))
eldest Maybe (StableArray, Used task)
forall a. Maybe a
Nothing UT task
utbl
  case Maybe (StableArray, Used task)
mused of
    Just (StableArray
sa, Used Timestamp
ts Status
status Int
count [task]
tasks Int
n SingleType e
tp Weak (ScalarArrayData e)
weak_arr) -> do
      Maybe (ScalarArrayData e)
mad <- IO (Maybe (ScalarArrayData e)) -> m (Maybe (ScalarArrayData e))
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO (Maybe (ScalarArrayData e)) -> m (Maybe (ScalarArrayData e)))
-> IO (Maybe (ScalarArrayData e)) -> m (Maybe (ScalarArrayData e))
forall a b. (a -> b) -> a -> b
$ Weak (ScalarArrayData e) -> IO (Maybe (ScalarArrayData e))
forall v. Weak v -> IO (Maybe v)
deRefWeak Weak (ScalarArrayData e)
weak_arr
      case Maybe (ScalarArrayData e)
mad of
        Maybe (ScalarArrayData e)
Nothing -> IO () -> m ()
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO () -> m ()) -> IO () -> m ()
forall a b. (a -> b) -> a -> b
$ do
          -- This can only happen if our eviction process was interrupted by
          -- garbage collection. In which case, even though we didn't actually
          -- evict anything, we should return true, as we know some remote
          -- memory is now free.
          --
          -- Small caveat: Due to finalisers being delayed, it's a good idea
          -- to free the array here.
          MemoryTable (RemotePtr m) -> StableArray -> IO ()
forall (m :: * -> *).
RemoteMemory m =>
MemoryTable (RemotePtr m) -> StableArray -> IO ()
Basic.freeStable @m MemoryTable (RemotePtr m)
mt StableArray
sa
          UT task -> StableArray -> IO ()
forall task. UT task -> StableArray -> IO ()
delete UT task
utbl StableArray
sa
          String -> IO ()
forall (m :: * -> *). MonadIO m => String -> m ()
message String
"evictLRU/Accelerate GC interrupted by GHC GC"

        Just ScalarArrayData e
arr -> do
          String -> m ()
forall (m :: * -> *). MonadIO m => String -> m ()
message (String
"evictLRU/evicting " String -> String -> String
forall a. [a] -> [a] -> [a]
++ StableArray -> String
forall a. Show a => a -> String
show StableArray
sa)
          Status -> Int -> SingleType e -> ArrayData e -> m ()
forall e. Status -> Int -> SingleType e -> ArrayData e -> m ()
copyIfNecessary Status
status Int
n SingleType e
tp ScalarArrayData e
ArrayData e
arr
          IO () -> m ()
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO () -> m ()) -> IO () -> m ()
forall a b. (a -> b) -> a -> b
$ Int64 -> IO ()
D.didEvictBytes (SingleType e -> Int -> Int64
forall e. SingleType e -> Int -> Int64
remoteBytes SingleType e
tp Int
n)
          IO () -> m ()
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO () -> m ()) -> IO () -> m ()
forall a b. (a -> b) -> a -> b
$ MemoryTable (RemotePtr m) -> StableArray -> IO ()
forall (m :: * -> *).
RemoteMemory m =>
MemoryTable (RemotePtr m) -> StableArray -> IO ()
Basic.freeStable @m MemoryTable (RemotePtr m)
mt StableArray
sa
          IO () -> m ()
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO () -> m ()) -> IO () -> m ()
forall a b. (a -> b) -> a -> b
$ UT task -> StableArray -> Used task -> IO ()
forall (h :: * -> * -> * -> *) k v.
(HashTable h, Eq k, Hashable k) =>
IOHashTable h k v -> k -> v -> IO ()
HT.insert UT task
utbl StableArray
sa (Timestamp
-> Status
-> Int
-> [task]
-> Int
-> SingleType e
-> Weak (ScalarArrayData e)
-> Used task
forall e task.
(ArrayData e ~ ScalarArrayData e) =>
Timestamp
-> Status
-> Int
-> [task]
-> Int
-> SingleType e
-> Weak (ScalarArrayData e)
-> Used task
Used Timestamp
ts Status
Evicted Int
count [task]
tasks Int
n SingleType e
tp Weak (ScalarArrayData e)
weak_arr)
      Bool -> m Bool
forall (m :: * -> *) a. Monad m => a -> m a
return Bool
True
    Maybe (StableArray, Used task)
_ -> String -> m Bool -> m Bool
forall (m :: * -> *) a. MonadIO m => String -> m a -> m a
trace String
"evictLRU/All arrays in use, unable to evict" (m Bool -> m Bool) -> m Bool -> m Bool
forall a b. (a -> b) -> a -> b
$ Bool -> m Bool
forall (m :: * -> *) a. Monad m => a -> m a
return Bool
False
  where
    -- Find the eldest, not currently in use, array.
    eldest :: (Maybe (StableArray, Used task)) -> (StableArray, Used task) -> IO (Maybe (StableArray, Used task))
    eldest :: Maybe (StableArray, Used task)
-> (StableArray, Used task) -> IO (Maybe (StableArray, Used task))
eldest Maybe (StableArray, Used task)
prev (StableArray
sa, used :: Used task
used@(Used Timestamp
ts Status
status Int
count [task]
tasks Int
n SingleType e
tp Weak (ScalarArrayData e)
weak_arr))
      | Int
count Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
0
      , Status -> Bool
evictable Status
status
      = do
          [task]
tasks' <- [task] -> IO [task]
forall task. Task task => [task] -> IO [task]
cleanUses [task]
tasks
          UT task -> StableArray -> Used task -> IO ()
forall (h :: * -> * -> * -> *) k v.
(HashTable h, Eq k, Hashable k) =>
IOHashTable h k v -> k -> v -> IO ()
HT.insert UT task
utbl StableArray
sa (Timestamp
-> Status
-> Int
-> [task]
-> Int
-> SingleType e
-> Weak (ScalarArrayData e)
-> Used task
forall e task.
(ArrayData e ~ ScalarArrayData e) =>
Timestamp
-> Status
-> Int
-> [task]
-> Int
-> SingleType e
-> Weak (ScalarArrayData e)
-> Used task
Used Timestamp
ts Status
status Int
count [task]
tasks' Int
n SingleType e
tp Weak (ScalarArrayData e)
weak_arr)
          case [task]
tasks' of
            [] | Just (StableArray
_, Used Timestamp
ts' Status
_ Int
_ [task]
_ Int
_ SingleType e
_ Weak (ScalarArrayData e)
_) <- Maybe (StableArray, Used task)
prev
               , Timestamp
ts Timestamp -> Timestamp -> Bool
forall a. Ord a => a -> a -> Bool
< Timestamp
ts'        -> Maybe (StableArray, Used task)
-> IO (Maybe (StableArray, Used task))
forall (m :: * -> *) a. Monad m => a -> m a
return ((StableArray, Used task) -> Maybe (StableArray, Used task)
forall a. a -> Maybe a
Just (StableArray
sa, Used task
used))
               | Maybe (StableArray, Used task)
Nothing <- Maybe (StableArray, Used task)
prev -> Maybe (StableArray, Used task)
-> IO (Maybe (StableArray, Used task))
forall (m :: * -> *) a. Monad m => a -> m a
return ((StableArray, Used task) -> Maybe (StableArray, Used task)
forall a. a -> Maybe a
Just (StableArray
sa, Used task
used))
            [task]
_  -> Maybe (StableArray, Used task)
-> IO (Maybe (StableArray, Used task))
forall (m :: * -> *) a. Monad m => a -> m a
return Maybe (StableArray, Used task)
prev
    eldest Maybe (StableArray, Used task)
prev (StableArray, Used task)
_ = Maybe (StableArray, Used task)
-> IO (Maybe (StableArray, Used task))
forall (m :: * -> *) a. Monad m => a -> m a
return Maybe (StableArray, Used task)
prev

    remoteBytes :: SingleType e -> Int -> Int64
    remoteBytes :: SingleType e -> Int -> Int64
remoteBytes SingleType e
tp Int
n = Int -> Int64
forall a b. (Integral a, Num b) => a -> b
fromIntegral (TypeR e -> Int
forall e. TypeR e -> Int
bytesElt (ScalarType e -> TypeR e
forall (s :: * -> *) a. s a -> TupR s a
TupRsingle (SingleType e -> ScalarType e
forall a. SingleType a -> ScalarType a
SingleScalarType SingleType e
tp))) Int64 -> Int64 -> Int64
forall a. Num a => a -> a -> a
* Int -> Int64
forall a b. (Integral a, Num b) => a -> b
fromIntegral Int
n

    evictable :: Status -> Bool
    evictable :: Status -> Bool
evictable Status
Clean     = Bool
True
    evictable Status
Dirty     = Bool
True
    evictable Status
Unmanaged = Bool
False
    evictable Status
Evicted   = Bool
False

    copyIfNecessary :: Status -> Int -> SingleType e -> ArrayData e -> m ()
    copyIfNecessary :: Status -> Int -> SingleType e -> ArrayData e -> m ()
copyIfNecessary Status
Clean     Int
_ SingleType e
_  ArrayData e
_  = () -> m ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
    copyIfNecessary Status
Unmanaged Int
_ SingleType e
_  ArrayData e
_  = () -> m ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
    copyIfNecessary Status
Evicted   Int
_ SingleType e
_  ArrayData e
_  = String -> m ()
forall a. HasCallStack => String -> a
internalError String
"Attempting to evict already evicted array"
    copyIfNecessary Status
Dirty     Int
n SingleType e
tp ArrayData e
ad = do
      Maybe (RemotePtr m (ScalarArrayDataR e))
mp <- IO (Maybe (RemotePtr m (ScalarArrayDataR e)))
-> m (Maybe (RemotePtr m (ScalarArrayDataR e)))
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO (Maybe (RemotePtr m (ScalarArrayDataR e)))
 -> m (Maybe (RemotePtr m (ScalarArrayDataR e))))
-> IO (Maybe (RemotePtr m (ScalarArrayDataR e)))
-> m (Maybe (RemotePtr m (ScalarArrayDataR e)))
forall a b. (a -> b) -> a -> b
$ MemoryTable (RemotePtr m)
-> SingleType e
-> ArrayData e
-> IO (Maybe (RemotePtr m (ScalarArrayDataR e)))
forall (m :: * -> *) a.
(HasCallStack, RemoteMemory m) =>
MemoryTable (RemotePtr m)
-> SingleType a
-> ArrayData a
-> IO (Maybe (RemotePtr m (ScalarArrayDataR a)))
Basic.lookup @m MemoryTable (RemotePtr m)
mt SingleType e
tp ArrayData e
ad
      case Maybe (RemotePtr m (ScalarArrayDataR e))
mp of
        Maybe (RemotePtr m (ScalarArrayDataR e))
Nothing -> () -> m ()
forall (m :: * -> *) a. Monad m => a -> m a
return () -- RCE: I think this branch is actually impossible.
        Just RemotePtr m (ScalarArrayDataR e)
p  -> SingleType e
-> Int -> RemotePtr m (ScalarArrayDataR e) -> ArrayData e -> m ()
forall (m :: * -> *) e.
RemoteMemory m =>
SingleType e
-> Int -> RemotePtr m (ScalarArrayDataR e) -> ArrayData e -> m ()
peekRemote SingleType e
tp Int
n RemotePtr m (ScalarArrayDataR e)
p ArrayData e
ad

-- | Release the remote (device) memory backing the given host-side array and
-- forget its entry in the LRU use table.
--
-- The use-table 'MVar' is held while the entry is removed and the remote
-- memory is freed. The entry's use count and pending tasks are NOT inspected
-- first, so this should only be called in very specific circumstances. This
-- operation is not thread-safe. NOTE(review): presumably this refers to
-- in-flight tasks that may still be using the remote array; confirm.
--
free :: forall m a task. (HasCallStack, RemoteMemory m)
     => MemoryTable (RemotePtr m) task
     -> SingleType a
     -> ArrayData a
     -> IO ()
free (MemoryTable !basicTable !useTableRef _) !eltType !hostArr =
  withMVar' useTableRef releaseEntry
  where
    -- Runs with the use-table lock held.
    releaseEntry useTable = do
      stable <- Basic.makeStableArray eltType hostArr
      delete useTable stable
      Basic.freeStable @m basicTable stable

-- | Associate a host-side array with a remote memory area that was NOT
-- allocated by accelerate. The remote memory will NOT be re-used once the
-- host-side array is garbage collected.
--
-- The association is registered in the underlying basic table. It is also
-- recorded in the LRU use table with status 'Unmanaged', which the eviction
-- policy never considers evictable, together with a weak pointer whose
-- finaliser removes the use-table entry when the host array dies.
--
-- This typically only has use for backends that provide an FFI.
--
insertUnmanaged
    :: (HasCallStack, MonadIO m, RemoteMemory m)
    => MemoryTable (RemotePtr m) task
    -> SingleType e
    -> ArrayData e
    -> RemotePtr m (ScalarArrayDataR e)
    -> m ()
insertUnmanaged (MemoryTable basicTable useTableRef weakUseTable) !eltType !hostArr !remotePtr
  -- Matching on the dictionary brings @ArrayData e ~ ScalarArrayData e@ into
  -- scope, which the 'Used' constructor demands.
  | SingleArrayDict <- singleArrayDict eltType
  = do
      stable <- Basic.makeStableArray eltType hostArr
      ()     <- Basic.insertUnmanaged basicTable eltType hostArr remotePtr
      liftIO . withMVar useTableRef $ \useTable -> do
        stamp <- getCPUTime
        weak  <- makeWeakArrayData eltType hostArr hostArr
                   (Just (finalizer stable weakUseTable))
        -- Use count 0, no pending tasks, element count 0.
        HT.insert useTable stable (Used stamp Unmanaged 0 [] 0 eltType weak)


-- Removing entries
-- ----------------

-- | Finaliser attached to a host array's weak pointer. When the use table is
-- still alive, the array's entry is removed from it under the table lock;
-- otherwise only a debug message is emitted.
finalizer :: StableArray -> Weak (UseTable task) -> IO ()
finalizer !key !weak_utbl =
  deRefWeak weak_utbl >>= maybe tableGone dropEntry
  where
    tableGone = message "finalize cache/dead table"

    dropEntry ref =
      trace ("finalize cache: " ++ show key) $
        withMVar' ref (\utbl -> delete utbl key)

-- | Drop the entry for the given array from the use table. Only the LRU
-- bookkeeping is touched; no remote memory is released here.
delete :: UT task -> StableArray -> IO ()
delete table key = HT.delete table key


-- | Initiate garbage collection and `free` any remote arrays that no longer
-- have matching host-side equivalents.
--
-- This delegates entirely to 'Basic.reclaim' on the underlying table. The LRU
-- use table is not consulted.
--
reclaim
    :: forall m task. (HasCallStack, RemoteMemory m, MonadIO m)
    => MemoryTable (RemotePtr m) task
    -> m ()
reclaim (MemoryTable !basicTable _ _) = Basic.reclaim basicTable

-- | Immediately run the finaliser of every weak array pointer stored in the
-- use table. A debug trace message is emitted first.
cache_finalizer :: UT task -> IO ()
cache_finalizer !tbl =
  trace "cache finaliser" (HT.mapM_ finalizeEntry tbl)
  where
    finalizeEntry :: (StableArray, Used task) -> IO ()
    finalizeEntry (_, Used _ _ _ _ _ _ weak) = finalize weak

-- Miscellaneous
-- -------------

-- | Keep only those tasks for which 'completed' reports 'False'.
cleanUses :: Task task => [task] -> IO [task]
cleanUses = filterM pending
  where
    pending t = not <$> completed t

-- | Bump the use count of a use-table entry by one. Every other field is
-- carried over unchanged.
incCount :: Used task -> Used task
incCount (Used ts status count uses n tp weak_arr) =
  let bumped = count + 1
  in  Used ts status bumped uses n tp weak_arr

-- | Is the entry's status 'Evicted'?
isEvicted :: Used task -> Bool
isEvicted (Used _ Evicted _ _ _ _ _) = True
isEvicted _                          = False

{-# INLINE withMVar' #-}
-- | A 'withMVar' generalised to any 'MonadIO' / 'MonadMask' monad. The
-- 'MVar' contents are taken and passed to the action, which runs with async
-- exceptions restored. The original contents are always put back, including
-- when the action throws.
withMVar' :: (MonadIO m, MonadMask m) => MVar a -> (a -> m b) -> m b
withMVar' var action = mask $ \unmask -> do
  contents <- takeMVar' var
  let release = putMVar' var contents
  result <- unmask (action contents) `onException` release
  release
  pure result

{-# INLINE putMVar' #-}
-- | 'putMVar' lifted into any 'MonadIO'.
putMVar' :: (MonadIO m, MonadMask m) => MVar a -> a -> m ()
putMVar' var = liftIO . putMVar var

{-# INLINE takeMVar' #-}
-- | 'takeMVar' lifted into any 'MonadIO'.
takeMVar' :: (MonadIO m, MonadMask m) => MVar a -> m a
takeMVar' = liftIO . takeMVar


-- Debug
-- -----

{-# INLINE trace #-}
-- | Emit a GC debug message, then run the given action.
trace :: MonadIO m => String -> m a -> m a
trace msg next = do
  message msg
  next

{-# INLINE message #-}
-- | Write a message, prefixed with @"gc: "@, through 'D.traceIO' under the
-- 'D.dump_gc' debug flag.
message :: MonadIO m => String -> m ()
message = liftIO . D.traceIO D.dump_gc . ("gc: " ++)