{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase                 #-}
{-# LANGUAGE MultiParamTypeClasses      #-}
{-# LANGUAGE TypeFamilies               #-}
{-# LANGUAGE TypeSynonymInstances       #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
-- |
-- Module      : Data.Array.Accelerate.LLVM.PTX.Execute.Async
-- Copyright   : [2014..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.LLVM.PTX.Execute.Async (

  module Data.Array.Accelerate.LLVM.Execute.Async,
  module Data.Array.Accelerate.LLVM.PTX.Execute.Async,

) where

-- accelerate
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Lifetime

import Data.Array.Accelerate.LLVM.Execute.Async
import Data.Array.Accelerate.LLVM.State

import Data.Array.Accelerate.LLVM.PTX.Target
import Data.Array.Accelerate.LLVM.PTX.Execute.Event                 ( Event )
import Data.Array.Accelerate.LLVM.PTX.Execute.Stream                ( Stream )
import Data.Array.Accelerate.LLVM.PTX.Link.Object                   ( FunctionTable )
import qualified Data.Array.Accelerate.LLVM.PTX.Execute.Event       as Event
import qualified Data.Array.Accelerate.LLVM.PTX.Execute.Stream      as Stream

-- standard library
import Control.Monad.State
import Control.Monad.Reader
import Data.IORef


-- | Evaluate a parallel computation
--
{-# INLINE evalPar #-}
evalPar :: Par PTX a -> LLVM PTX a
evalPar :: Par PTX a -> LLVM PTX a
evalPar Par PTX a
p = do
  Stream
s <- LLVM PTX Stream
Stream.create
  a
r <- ReaderT ParState (LLVM PTX) a -> ParState -> LLVM PTX a
forall r (m :: * -> *) a. ReaderT r m a -> r -> m a
runReaderT (Par PTX a -> ReaderT ParState (LLVM PTX) a
forall a. Par PTX a -> ReaderT ParState (LLVM PTX) a
runPar Par PTX a
p) (Stream
s, Maybe (Lifetime FunctionTable)
forall a. Maybe a
Nothing)
  a -> LLVM PTX a
forall (m :: * -> *) a. Monad m => a -> m a
return a
r


type ParState = (Stream, Maybe (Lifetime FunctionTable))

ptxStream :: ParState -> Stream
ptxStream :: ParState -> Stream
ptxStream = ParState -> Stream
forall a b. (a, b) -> a
fst

ptxKernel :: ParState -> Maybe (Lifetime FunctionTable)
ptxKernel :: ParState -> Maybe (Lifetime FunctionTable)
ptxKernel = ParState -> Maybe (Lifetime FunctionTable)
forall a b. (a, b) -> b
snd


-- Implementation
-- --------------

data Future a = Future {-# UNPACK #-} !(IORef (IVar a))

data IVar a
    = Full !a
    | Pending {-# UNPACK #-} !Event !(Maybe (Lifetime FunctionTable)) !a
    | Empty


instance Async PTX where
  type FutureR PTX = Future

  newtype Par PTX a = Par { Par PTX a -> ReaderT ParState (LLVM PTX) a
runPar :: ReaderT ParState (LLVM PTX) a }
    deriving ( a -> Par PTX b -> Par PTX a
(a -> b) -> Par PTX a -> Par PTX b
(forall a b. (a -> b) -> Par PTX a -> Par PTX b)
-> (forall a b. a -> Par PTX b -> Par PTX a) -> Functor (Par PTX)
forall a b. a -> Par PTX b -> Par PTX a
forall a b. (a -> b) -> Par PTX a -> Par PTX b
forall (f :: * -> *).
(forall a b. (a -> b) -> f a -> f b)
-> (forall a b. a -> f b -> f a) -> Functor f
<$ :: a -> Par PTX b -> Par PTX a
$c<$ :: forall a b. a -> Par PTX b -> Par PTX a
fmap :: (a -> b) -> Par PTX a -> Par PTX b
$cfmap :: forall a b. (a -> b) -> Par PTX a -> Par PTX b
Functor, Functor (Par PTX)
a -> Par PTX a
Functor (Par PTX)
-> (forall a. a -> Par PTX a)
-> (forall a b. Par PTX (a -> b) -> Par PTX a -> Par PTX b)
-> (forall a b c.
    (a -> b -> c) -> Par PTX a -> Par PTX b -> Par PTX c)
-> (forall a b. Par PTX a -> Par PTX b -> Par PTX b)
-> (forall a b. Par PTX a -> Par PTX b -> Par PTX a)
-> Applicative (Par PTX)
Par PTX a -> Par PTX b -> Par PTX b
Par PTX a -> Par PTX b -> Par PTX a
Par PTX (a -> b) -> Par PTX a -> Par PTX b
(a -> b -> c) -> Par PTX a -> Par PTX b -> Par PTX c
forall a. a -> Par PTX a
forall a b. Par PTX a -> Par PTX b -> Par PTX a
forall a b. Par PTX a -> Par PTX b -> Par PTX b
forall a b. Par PTX (a -> b) -> Par PTX a -> Par PTX b
forall a b c. (a -> b -> c) -> Par PTX a -> Par PTX b -> Par PTX c
forall (f :: * -> *).
Functor f
-> (forall a. a -> f a)
-> (forall a b. f (a -> b) -> f a -> f b)
-> (forall a b c. (a -> b -> c) -> f a -> f b -> f c)
-> (forall a b. f a -> f b -> f b)
-> (forall a b. f a -> f b -> f a)
-> Applicative f
<* :: Par PTX a -> Par PTX b -> Par PTX a
$c<* :: forall a b. Par PTX a -> Par PTX b -> Par PTX a
*> :: Par PTX a -> Par PTX b -> Par PTX b
$c*> :: forall a b. Par PTX a -> Par PTX b -> Par PTX b
liftA2 :: (a -> b -> c) -> Par PTX a -> Par PTX b -> Par PTX c
$cliftA2 :: forall a b c. (a -> b -> c) -> Par PTX a -> Par PTX b -> Par PTX c
<*> :: Par PTX (a -> b) -> Par PTX a -> Par PTX b
$c<*> :: forall a b. Par PTX (a -> b) -> Par PTX a -> Par PTX b
pure :: a -> Par PTX a
$cpure :: forall a. a -> Par PTX a
$cp1Applicative :: Functor (Par PTX)
Applicative, Applicative (Par PTX)
a -> Par PTX a
Applicative (Par PTX)
-> (forall a b. Par PTX a -> (a -> Par PTX b) -> Par PTX b)
-> (forall a b. Par PTX a -> Par PTX b -> Par PTX b)
-> (forall a. a -> Par PTX a)
-> Monad (Par PTX)
Par PTX a -> (a -> Par PTX b) -> Par PTX b
Par PTX a -> Par PTX b -> Par PTX b
forall a. a -> Par PTX a
forall a b. Par PTX a -> Par PTX b -> Par PTX b
forall a b. Par PTX a -> (a -> Par PTX b) -> Par PTX b
forall (m :: * -> *).
Applicative m
-> (forall a b. m a -> (a -> m b) -> m b)
-> (forall a b. m a -> m b -> m b)
-> (forall a. a -> m a)
-> Monad m
return :: a -> Par PTX a
$creturn :: forall a. a -> Par PTX a
>> :: Par PTX a -> Par PTX b -> Par PTX b
$c>> :: forall a b. Par PTX a -> Par PTX b -> Par PTX b
>>= :: Par PTX a -> (a -> Par PTX b) -> Par PTX b
$c>>= :: forall a b. Par PTX a -> (a -> Par PTX b) -> Par PTX b
$cp1Monad :: Applicative (Par PTX)
Monad, Monad (Par PTX)
Monad (Par PTX)
-> (forall a. IO a -> Par PTX a) -> MonadIO (Par PTX)
IO a -> Par PTX a
forall a. IO a -> Par PTX a
forall (m :: * -> *).
Monad m -> (forall a. IO a -> m a) -> MonadIO m
liftIO :: IO a -> Par PTX a
$cliftIO :: forall a. IO a -> Par PTX a
$cp1MonadIO :: Monad (Par PTX)
MonadIO, MonadReader ParState, MonadState PTX )

  {-# INLINEABLE new     #-}
  {-# INLINEABLE newFull #-}
  new :: Par PTX (FutureR PTX a)
new       = IORef (IVar a) -> Future a
forall a. IORef (IVar a) -> Future a
Future (IORef (IVar a) -> Future a)
-> Par PTX (IORef (IVar a)) -> Par PTX (Future a)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> IO (IORef (IVar a)) -> Par PTX (IORef (IVar a))
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IVar a -> IO (IORef (IVar a))
forall a. a -> IO (IORef a)
newIORef IVar a
forall a. IVar a
Empty)
  newFull :: a -> Par PTX (FutureR PTX a)
newFull a
v = IORef (IVar a) -> Future a
forall a. IORef (IVar a) -> Future a
Future (IORef (IVar a) -> Future a)
-> Par PTX (IORef (IVar a)) -> Par PTX (Future a)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> IO (IORef (IVar a)) -> Par PTX (IORef (IVar a))
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IVar a -> IO (IORef (IVar a))
forall a. a -> IO (IORef a)
newIORef (a -> IVar a
forall a. a -> IVar a
Full a
v))

  {-# INLINEABLE spawn #-}
  spawn :: Par PTX a -> Par PTX a
spawn Par PTX a
m = do
    Stream
s' <- LLVM PTX Stream -> Par PTX Stream
forall arch a.
(Async arch, HasCallStack) =>
LLVM arch a -> Par arch a
liftPar LLVM PTX Stream
Stream.create
    a
r  <- (ParState -> ParState) -> Par PTX a -> Par PTX a
forall r (m :: * -> *) a. MonadReader r m => (r -> r) -> m a -> m a
local (ParState -> ParState -> ParState
forall a b. a -> b -> a
const (Stream
s', Maybe (Lifetime FunctionTable)
forall a. Maybe a
Nothing)) Par PTX a
m
    IO () -> Par PTX ()
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (Stream -> IO ()
Stream.destroy Stream
s')
    a -> Par PTX a
forall (m :: * -> *) a. Monad m => a -> m a
return a
r

  {-# INLINEABLE fork #-}
  fork :: Par PTX () -> Par PTX ()
fork Par PTX ()
m = do
    Stream
s' <- LLVM PTX Stream -> Par PTX Stream
forall arch a.
(Async arch, HasCallStack) =>
LLVM arch a -> Par arch a
liftPar (LLVM PTX Stream
Stream.create)
    () <- (ParState -> ParState) -> Par PTX () -> Par PTX ()
forall r (m :: * -> *) a. MonadReader r m => (r -> r) -> m a -> m a
local (ParState -> ParState -> ParState
forall a b. a -> b -> a
const (Stream
s', Maybe (Lifetime FunctionTable)
forall a. Maybe a
Nothing)) Par PTX ()
m
    IO () -> Par PTX ()
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (Stream -> IO ()
Stream.destroy Stream
s')

  -- When we call 'put' the actual work may not have been evaluated yet; get
  -- a new event in the current execution stream and once that is filled we can
  -- transition the IVar to Full.
  --
  {-# INLINEABLE put #-}
  put :: FutureR PTX a -> a -> Par PTX ()
put (Future ref) a
v = do
    Stream
stream <- (ParState -> Stream) -> Par PTX Stream
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks ParState -> Stream
ptxStream
    Maybe (Lifetime FunctionTable)
kernel <- (ParState -> Maybe (Lifetime FunctionTable))
-> Par PTX (Maybe (Lifetime FunctionTable))
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks ParState -> Maybe (Lifetime FunctionTable)
ptxKernel
    Event
event  <- LLVM PTX Event -> Par PTX Event
forall arch a.
(Async arch, HasCallStack) =>
LLVM arch a -> Par arch a
liftPar (Stream -> LLVM PTX Event
Event.waypoint Stream
stream)
    Bool
ready  <- IO Bool -> Par PTX Bool
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO  (Event -> IO Bool
Event.query Event
event)
    IO () -> Par PTX ()
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO () -> Par PTX ())
-> ((IVar a -> IVar a) -> IO ())
-> (IVar a -> IVar a)
-> Par PTX ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. IORef (IVar a) -> (IVar a -> IVar a) -> IO ()
forall a. IORef a -> (a -> a) -> IO ()
modifyIORef' IORef (IVar a)
ref ((IVar a -> IVar a) -> Par PTX ())
-> (IVar a -> IVar a) -> Par PTX ()
forall a b. (a -> b) -> a -> b
$ \case
      IVar a
Empty -> if Bool
ready then a -> IVar a
forall a. a -> IVar a
Full a
v
                        else Event -> Maybe (Lifetime FunctionTable) -> a -> IVar a
forall a. Event -> Maybe (Lifetime FunctionTable) -> a -> IVar a
Pending Event
event Maybe (Lifetime FunctionTable)
kernel a
v
      IVar a
_     -> String -> IVar a
forall a. HasCallStack => String -> a
internalError String
"multiple put"

  -- Get the value of Future. Since the actual cross-stream synchronisation
  -- happens on the device, we should never have to block/reschedule the main
  -- thread waiting on a value; if we get an empty IVar at this point, something
  -- has gone wrong.
  --
  {-# INLINEABLE get #-}
  get :: FutureR PTX a -> Par PTX a
get (Future ref) = do
    Stream
stream <- (ParState -> Stream) -> Par PTX Stream
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks ParState -> Stream
ptxStream
    IO a -> Par PTX a
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO  (IO a -> Par PTX a) -> IO a -> Par PTX a
forall a b. (a -> b) -> a -> b
$ do
      IVar a
ivar <- IORef (IVar a) -> IO (IVar a)
forall a. IORef a -> IO a
readIORef IORef (IVar a)
ref
      case IVar a
ivar of
        Full a
v            -> a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return a
v
        Pending Event
event Maybe (Lifetime FunctionTable)
k a
v -> do
          Bool
ready <- Event -> IO Bool
Event.query Event
event
          if Bool
ready
            then do
              IORef (IVar a) -> IVar a -> IO ()
forall a. IORef a -> a -> IO ()
writeIORef IORef (IVar a)
ref (a -> IVar a
forall a. a -> IVar a
Full a
v)
              case Maybe (Lifetime FunctionTable)
k of
                Just Lifetime FunctionTable
f  -> Lifetime FunctionTable -> IO ()
forall a. Lifetime a -> IO ()
touchLifetime Lifetime FunctionTable
f
                Maybe (Lifetime FunctionTable)
Nothing -> () -> IO ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
            else
              Event -> Stream -> IO ()
Event.after Event
event Stream
stream
          a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return a
v
        IVar a
Empty           -> String -> IO a
forall a. HasCallStack => String -> a
internalError String
"blocked on an IVar"

  {-# INLINEABLE block #-}
  block :: FutureR PTX a -> Par PTX a
block = IO a -> Par PTX a
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO a -> Par PTX a) -> (Future a -> IO a) -> Future a -> Par PTX a
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Future a -> IO a
forall a. Future a -> IO a
wait

  {-# INLINE liftPar #-}
  liftPar :: LLVM PTX a -> Par PTX a
liftPar = ReaderT ParState (LLVM PTX) a -> Par PTX a
forall a. ReaderT ParState (LLVM PTX) a -> Par PTX a
Par (ReaderT ParState (LLVM PTX) a -> Par PTX a)
-> (LLVM PTX a -> ReaderT ParState (LLVM PTX) a)
-> LLVM PTX a
-> Par PTX a
forall b c a. (b -> c) -> (a -> b) -> a -> c
. LLVM PTX a -> ReaderT ParState (LLVM PTX) a
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift


-- | Block the calling _host_ thread until the value offered by the future is
-- available.
--
{-# INLINEABLE wait #-}
wait :: Future a -> IO a
wait :: Future a -> IO a
wait (Future IORef (IVar a)
ref) = do
  IVar a
ivar <- IORef (IVar a) -> IO (IVar a)
forall a. IORef a -> IO a
readIORef IORef (IVar a)
ref
  case IVar a
ivar of
    Full a
v            -> a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return a
v
    Pending Event
event Maybe (Lifetime FunctionTable)
k a
v -> do
      Event -> IO ()
Event.block Event
event
      IORef (IVar a) -> IVar a -> IO ()
forall a. IORef a -> a -> IO ()
writeIORef IORef (IVar a)
ref (a -> IVar a
forall a. a -> IVar a
Full a
v)
      case Maybe (Lifetime FunctionTable)
k of
        Just Lifetime FunctionTable
f  -> Lifetime FunctionTable -> IO ()
forall a. Lifetime a -> IO ()
touchLifetime Lifetime FunctionTable
f
        Maybe (Lifetime FunctionTable)
Nothing -> () -> IO ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
      a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return a
v
    IVar a
Empty           -> String -> IO a
forall a. HasCallStack => String -> a
internalError String
"blocked on an IVar"