{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Data.Array.Accelerate.LLVM.PTX.Execute.Async (
module Data.Array.Accelerate.LLVM.Execute.Async,
module Data.Array.Accelerate.LLVM.PTX.Execute.Async,
) where
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Lifetime
import Data.Array.Accelerate.LLVM.Execute.Async
import Data.Array.Accelerate.LLVM.State
import Data.Array.Accelerate.LLVM.PTX.Target
import Data.Array.Accelerate.LLVM.PTX.Execute.Event ( Event )
import Data.Array.Accelerate.LLVM.PTX.Execute.Stream ( Stream )
import Data.Array.Accelerate.LLVM.PTX.Link.Object ( FunctionTable )
import qualified Data.Array.Accelerate.LLVM.PTX.Execute.Event as Event
import qualified Data.Array.Accelerate.LLVM.PTX.Execute.Stream as Stream
import Control.Monad.State
import Control.Monad.Reader
import Data.IORef
{-# INLINE evalPar #-}
evalPar :: Par PTX a -> LLVM PTX a
evalPar :: Par PTX a -> LLVM PTX a
evalPar Par PTX a
p = do
Stream
s <- LLVM PTX Stream
Stream.create
a
r <- ReaderT ParState (LLVM PTX) a -> ParState -> LLVM PTX a
forall r (m :: * -> *) a. ReaderT r m a -> r -> m a
runReaderT (Par PTX a -> ReaderT ParState (LLVM PTX) a
forall a. Par PTX a -> ReaderT ParState (LLVM PTX) a
runPar Par PTX a
p) (Stream
s, Maybe (Lifetime FunctionTable)
forall a. Maybe a
Nothing)
a -> LLVM PTX a
forall (m :: * -> *) a. Monad m => a -> m a
return a
r
type ParState = (Stream, Maybe (Lifetime FunctionTable))
ptxStream :: ParState -> Stream
ptxStream :: ParState -> Stream
ptxStream = ParState -> Stream
forall a b. (a, b) -> a
fst
ptxKernel :: ParState -> Maybe (Lifetime FunctionTable)
ptxKernel :: ParState -> Maybe (Lifetime FunctionTable)
ptxKernel = ParState -> Maybe (Lifetime FunctionTable)
forall a b. (a, b) -> b
snd
data Future a = Future {-# UNPACK #-} !(IORef (IVar a))
data IVar a
= Full !a
| Pending {-# UNPACK #-} !Event !(Maybe (Lifetime FunctionTable)) !a
| Empty
instance Async PTX where
type FutureR PTX = Future
newtype Par PTX a = Par { Par PTX a -> ReaderT ParState (LLVM PTX) a
runPar :: ReaderT ParState (LLVM PTX) a }
deriving ( a -> Par PTX b -> Par PTX a
(a -> b) -> Par PTX a -> Par PTX b
(forall a b. (a -> b) -> Par PTX a -> Par PTX b)
-> (forall a b. a -> Par PTX b -> Par PTX a) -> Functor (Par PTX)
forall a b. a -> Par PTX b -> Par PTX a
forall a b. (a -> b) -> Par PTX a -> Par PTX b
forall (f :: * -> *).
(forall a b. (a -> b) -> f a -> f b)
-> (forall a b. a -> f b -> f a) -> Functor f
<$ :: a -> Par PTX b -> Par PTX a
$c<$ :: forall a b. a -> Par PTX b -> Par PTX a
fmap :: (a -> b) -> Par PTX a -> Par PTX b
$cfmap :: forall a b. (a -> b) -> Par PTX a -> Par PTX b
Functor, Functor (Par PTX)
a -> Par PTX a
Functor (Par PTX)
-> (forall a. a -> Par PTX a)
-> (forall a b. Par PTX (a -> b) -> Par PTX a -> Par PTX b)
-> (forall a b c.
(a -> b -> c) -> Par PTX a -> Par PTX b -> Par PTX c)
-> (forall a b. Par PTX a -> Par PTX b -> Par PTX b)
-> (forall a b. Par PTX a -> Par PTX b -> Par PTX a)
-> Applicative (Par PTX)
Par PTX a -> Par PTX b -> Par PTX b
Par PTX a -> Par PTX b -> Par PTX a
Par PTX (a -> b) -> Par PTX a -> Par PTX b
(a -> b -> c) -> Par PTX a -> Par PTX b -> Par PTX c
forall a. a -> Par PTX a
forall a b. Par PTX a -> Par PTX b -> Par PTX a
forall a b. Par PTX a -> Par PTX b -> Par PTX b
forall a b. Par PTX (a -> b) -> Par PTX a -> Par PTX b
forall a b c. (a -> b -> c) -> Par PTX a -> Par PTX b -> Par PTX c
forall (f :: * -> *).
Functor f
-> (forall a. a -> f a)
-> (forall a b. f (a -> b) -> f a -> f b)
-> (forall a b c. (a -> b -> c) -> f a -> f b -> f c)
-> (forall a b. f a -> f b -> f b)
-> (forall a b. f a -> f b -> f a)
-> Applicative f
<* :: Par PTX a -> Par PTX b -> Par PTX a
$c<* :: forall a b. Par PTX a -> Par PTX b -> Par PTX a
*> :: Par PTX a -> Par PTX b -> Par PTX b
$c*> :: forall a b. Par PTX a -> Par PTX b -> Par PTX b
liftA2 :: (a -> b -> c) -> Par PTX a -> Par PTX b -> Par PTX c
$cliftA2 :: forall a b c. (a -> b -> c) -> Par PTX a -> Par PTX b -> Par PTX c
<*> :: Par PTX (a -> b) -> Par PTX a -> Par PTX b
$c<*> :: forall a b. Par PTX (a -> b) -> Par PTX a -> Par PTX b
pure :: a -> Par PTX a
$cpure :: forall a. a -> Par PTX a
$cp1Applicative :: Functor (Par PTX)
Applicative, Applicative (Par PTX)
a -> Par PTX a
Applicative (Par PTX)
-> (forall a b. Par PTX a -> (a -> Par PTX b) -> Par PTX b)
-> (forall a b. Par PTX a -> Par PTX b -> Par PTX b)
-> (forall a. a -> Par PTX a)
-> Monad (Par PTX)
Par PTX a -> (a -> Par PTX b) -> Par PTX b
Par PTX a -> Par PTX b -> Par PTX b
forall a. a -> Par PTX a
forall a b. Par PTX a -> Par PTX b -> Par PTX b
forall a b. Par PTX a -> (a -> Par PTX b) -> Par PTX b
forall (m :: * -> *).
Applicative m
-> (forall a b. m a -> (a -> m b) -> m b)
-> (forall a b. m a -> m b -> m b)
-> (forall a. a -> m a)
-> Monad m
return :: a -> Par PTX a
$creturn :: forall a. a -> Par PTX a
>> :: Par PTX a -> Par PTX b -> Par PTX b
$c>> :: forall a b. Par PTX a -> Par PTX b -> Par PTX b
>>= :: Par PTX a -> (a -> Par PTX b) -> Par PTX b
$c>>= :: forall a b. Par PTX a -> (a -> Par PTX b) -> Par PTX b
$cp1Monad :: Applicative (Par PTX)
Monad, Monad (Par PTX)
Monad (Par PTX)
-> (forall a. IO a -> Par PTX a) -> MonadIO (Par PTX)
IO a -> Par PTX a
forall a. IO a -> Par PTX a
forall (m :: * -> *).
Monad m -> (forall a. IO a -> m a) -> MonadIO m
liftIO :: IO a -> Par PTX a
$cliftIO :: forall a. IO a -> Par PTX a
$cp1MonadIO :: Monad (Par PTX)
MonadIO, MonadReader ParState, MonadState PTX )
{-# INLINEABLE new #-}
{-# INLINEABLE newFull #-}
new :: Par PTX (FutureR PTX a)
new = IORef (IVar a) -> Future a
forall a. IORef (IVar a) -> Future a
Future (IORef (IVar a) -> Future a)
-> Par PTX (IORef (IVar a)) -> Par PTX (Future a)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> IO (IORef (IVar a)) -> Par PTX (IORef (IVar a))
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IVar a -> IO (IORef (IVar a))
forall a. a -> IO (IORef a)
newIORef IVar a
forall a. IVar a
Empty)
newFull :: a -> Par PTX (FutureR PTX a)
newFull a
v = IORef (IVar a) -> Future a
forall a. IORef (IVar a) -> Future a
Future (IORef (IVar a) -> Future a)
-> Par PTX (IORef (IVar a)) -> Par PTX (Future a)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> IO (IORef (IVar a)) -> Par PTX (IORef (IVar a))
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IVar a -> IO (IORef (IVar a))
forall a. a -> IO (IORef a)
newIORef (a -> IVar a
forall a. a -> IVar a
Full a
v))
{-# INLINEABLE spawn #-}
spawn :: Par PTX a -> Par PTX a
spawn Par PTX a
m = do
Stream
s' <- LLVM PTX Stream -> Par PTX Stream
forall arch a.
(Async arch, HasCallStack) =>
LLVM arch a -> Par arch a
liftPar LLVM PTX Stream
Stream.create
a
r <- (ParState -> ParState) -> Par PTX a -> Par PTX a
forall r (m :: * -> *) a. MonadReader r m => (r -> r) -> m a -> m a
local (ParState -> ParState -> ParState
forall a b. a -> b -> a
const (Stream
s', Maybe (Lifetime FunctionTable)
forall a. Maybe a
Nothing)) Par PTX a
m
IO () -> Par PTX ()
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (Stream -> IO ()
Stream.destroy Stream
s')
a -> Par PTX a
forall (m :: * -> *) a. Monad m => a -> m a
return a
r
{-# INLINEABLE fork #-}
fork :: Par PTX () -> Par PTX ()
fork Par PTX ()
m = do
Stream
s' <- LLVM PTX Stream -> Par PTX Stream
forall arch a.
(Async arch, HasCallStack) =>
LLVM arch a -> Par arch a
liftPar (LLVM PTX Stream
Stream.create)
() <- (ParState -> ParState) -> Par PTX () -> Par PTX ()
forall r (m :: * -> *) a. MonadReader r m => (r -> r) -> m a -> m a
local (ParState -> ParState -> ParState
forall a b. a -> b -> a
const (Stream
s', Maybe (Lifetime FunctionTable)
forall a. Maybe a
Nothing)) Par PTX ()
m
IO () -> Par PTX ()
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (Stream -> IO ()
Stream.destroy Stream
s')
{-# INLINEABLE put #-}
put :: FutureR PTX a -> a -> Par PTX ()
put (Future ref) a
v = do
Stream
stream <- (ParState -> Stream) -> Par PTX Stream
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks ParState -> Stream
ptxStream
Maybe (Lifetime FunctionTable)
kernel <- (ParState -> Maybe (Lifetime FunctionTable))
-> Par PTX (Maybe (Lifetime FunctionTable))
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks ParState -> Maybe (Lifetime FunctionTable)
ptxKernel
Event
event <- LLVM PTX Event -> Par PTX Event
forall arch a.
(Async arch, HasCallStack) =>
LLVM arch a -> Par arch a
liftPar (Stream -> LLVM PTX Event
Event.waypoint Stream
stream)
Bool
ready <- IO Bool -> Par PTX Bool
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (Event -> IO Bool
Event.query Event
event)
IO () -> Par PTX ()
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO () -> Par PTX ())
-> ((IVar a -> IVar a) -> IO ())
-> (IVar a -> IVar a)
-> Par PTX ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. IORef (IVar a) -> (IVar a -> IVar a) -> IO ()
forall a. IORef a -> (a -> a) -> IO ()
modifyIORef' IORef (IVar a)
ref ((IVar a -> IVar a) -> Par PTX ())
-> (IVar a -> IVar a) -> Par PTX ()
forall a b. (a -> b) -> a -> b
$ \case
IVar a
Empty -> if Bool
ready then a -> IVar a
forall a. a -> IVar a
Full a
v
else Event -> Maybe (Lifetime FunctionTable) -> a -> IVar a
forall a. Event -> Maybe (Lifetime FunctionTable) -> a -> IVar a
Pending Event
event Maybe (Lifetime FunctionTable)
kernel a
v
IVar a
_ -> String -> IVar a
forall a. HasCallStack => String -> a
internalError String
"multiple put"
{-# INLINEABLE get #-}
get :: FutureR PTX a -> Par PTX a
get (Future ref) = do
Stream
stream <- (ParState -> Stream) -> Par PTX Stream
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks ParState -> Stream
ptxStream
IO a -> Par PTX a
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO a -> Par PTX a) -> IO a -> Par PTX a
forall a b. (a -> b) -> a -> b
$ do
IVar a
ivar <- IORef (IVar a) -> IO (IVar a)
forall a. IORef a -> IO a
readIORef IORef (IVar a)
ref
case IVar a
ivar of
Full a
v -> a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return a
v
Pending Event
event Maybe (Lifetime FunctionTable)
k a
v -> do
Bool
ready <- Event -> IO Bool
Event.query Event
event
if Bool
ready
then do
IORef (IVar a) -> IVar a -> IO ()
forall a. IORef a -> a -> IO ()
writeIORef IORef (IVar a)
ref (a -> IVar a
forall a. a -> IVar a
Full a
v)
case Maybe (Lifetime FunctionTable)
k of
Just Lifetime FunctionTable
f -> Lifetime FunctionTable -> IO ()
forall a. Lifetime a -> IO ()
touchLifetime Lifetime FunctionTable
f
Maybe (Lifetime FunctionTable)
Nothing -> () -> IO ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
else
Event -> Stream -> IO ()
Event.after Event
event Stream
stream
a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return a
v
IVar a
Empty -> String -> IO a
forall a. HasCallStack => String -> a
internalError String
"blocked on an IVar"
{-# INLINEABLE block #-}
block :: FutureR PTX a -> Par PTX a
block = IO a -> Par PTX a
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO a -> Par PTX a) -> (Future a -> IO a) -> Future a -> Par PTX a
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Future a -> IO a
forall a. Future a -> IO a
wait
{-# INLINE liftPar #-}
liftPar :: LLVM PTX a -> Par PTX a
liftPar = ReaderT ParState (LLVM PTX) a -> Par PTX a
forall a. ReaderT ParState (LLVM PTX) a -> Par PTX a
Par (ReaderT ParState (LLVM PTX) a -> Par PTX a)
-> (LLVM PTX a -> ReaderT ParState (LLVM PTX) a)
-> LLVM PTX a
-> Par PTX a
forall b c a. (b -> c) -> (a -> b) -> a -> c
. LLVM PTX a -> ReaderT ParState (LLVM PTX) a
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift
{-# INLINEABLE wait #-}
wait :: Future a -> IO a
wait :: Future a -> IO a
wait (Future IORef (IVar a)
ref) = do
IVar a
ivar <- IORef (IVar a) -> IO (IVar a)
forall a. IORef a -> IO a
readIORef IORef (IVar a)
ref
case IVar a
ivar of
Full a
v -> a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return a
v
Pending Event
event Maybe (Lifetime FunctionTable)
k a
v -> do
Event -> IO ()
Event.block Event
event
IORef (IVar a) -> IVar a -> IO ()
forall a. IORef a -> a -> IO ()
writeIORef IORef (IVar a)
ref (a -> IVar a
forall a. a -> IVar a
Full a
v)
case Maybe (Lifetime FunctionTable)
k of
Just Lifetime FunctionTable
f -> Lifetime FunctionTable -> IO ()
forall a. Lifetime a -> IO ()
touchLifetime Lifetime FunctionTable
f
Maybe (Lifetime FunctionTable)
Nothing -> () -> IO ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return a
v
IVar a
Empty -> String -> IO a
forall a. HasCallStack => String -> a
internalError String
"blocked on an IVar"