{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeSynonymInstances #-}
module Data.Array.Accelerate.LLVM.PTX (
Acc, Arrays,
Afunction, AfunctionR,
run, runWith,
run1, run1With,
runN, runNWith,
stream, streamWith,
Async,
wait, poll, cancel,
runAsync, runAsyncWith,
run1Async, run1AsyncWith,
runNAsync, runNAsyncWith,
runQ, runQWith,
runQAsync, runQAsyncWith,
PTX, createTargetForDevice, createTargetFromContext,
registerPinnedAllocatorWith,
) where
import Data.Array.Accelerate.AST ( PreOpenAfun(..), arraysR, liftALeftHandSide )
import Data.Array.Accelerate.AST.LeftHandSide ( lhsToTupR )
import Data.Array.Accelerate.Array.Data
import Data.Array.Accelerate.Async ( Async, asyncBound, wait, poll, cancel )
import Data.Array.Accelerate.Debug
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Representation.Array ( liftArraysR )
import Data.Array.Accelerate.Smart ( Acc )
import Data.Array.Accelerate.Sugar.Array ( Arrays, toArr, fromArr, ArraysR )
import Data.Array.Accelerate.Trafo
import Data.Array.Accelerate.Trafo.Delayed
import Data.Array.Accelerate.Trafo.Sharing ( Afunction(..), AfunctionRepr(..), afunctionRepr )
import qualified Data.Array.Accelerate.Sugar.Array as Sugar
import Data.Array.Accelerate.LLVM.PTX.Array.Data
import Data.Array.Accelerate.LLVM.PTX.Compile
import Data.Array.Accelerate.LLVM.PTX.Context
import Data.Array.Accelerate.LLVM.PTX.Embed
import Data.Array.Accelerate.LLVM.PTX.Execute
import Data.Array.Accelerate.LLVM.PTX.Execute.Async ( Par, evalPar, getArrays )
import Data.Array.Accelerate.LLVM.PTX.Execute.Environment
import Data.Array.Accelerate.LLVM.PTX.Link
import Data.Array.Accelerate.LLVM.PTX.State
import Data.Array.Accelerate.LLVM.PTX.Target
import Foreign.CUDA.Driver as CUDA ( CUDAException, mallocHostForeignPtr )
import Control.Exception
import Control.Monad.Trans
import Data.Maybe
import System.IO.Unsafe
import Text.Printf
import qualified Language.Haskell.TH as TH
import qualified Language.Haskell.TH.Syntax as TH
-- | Compile and run a complete embedded array program, returning the result
-- as a pure value.
--
-- This is 'runIO' wrapped in 'unsafePerformIO'; 'runIO' obtains its
-- execution target from 'defaultTargetPool'. Use 'runWith' to choose the
-- target explicitly.
--
-- NOTE(review): the use of 'unsafePerformIO' presumes the result depends only
-- on the 'Acc' term and not on when the action is run -- confirm.
--
run :: Arrays a => Acc a -> a
run a = unsafePerformIO $ runIO a
-- | As 'run', but execute on the given 'PTX' target instead of one taken
-- from the default pool.
--
-- Implemented as 'runWithIO' wrapped in 'unsafePerformIO'.
--
runWith :: Arrays a => PTX -> Acc a -> a
runWith target a = unsafePerformIO $ runWithIO target a
-- | As 'run', but hand the computation to 'asyncBound' and return the
-- resulting 'Async' handle instead of the value itself. The handle can be
-- inspected with 'wait', 'poll' and 'cancel'.
--
-- The target is taken from 'defaultTargetPool' (via 'runIO'); use
-- 'runAsyncWith' to supply one explicitly.
--
runAsync :: Arrays a => Acc a -> IO (Async a)
runAsync a = asyncBound $ runIO a
-- | As 'runAsync', but execute on the given 'PTX' target.
--
-- Implemented as 'runWithIO' handed to 'asyncBound'.
--
runAsyncWith :: Arrays a => PTX -> Acc a -> IO (Async a)
runAsyncWith target a = asyncBound $ runWithIO target a
-- | Run an array program in 'IO' on a target obtained from
-- 'defaultTargetPool'.
--
-- 'withPool' supplies the target to the continuation, which simply delegates
-- to 'runWithIO'.
--
runIO :: Arrays a => Acc a -> IO a
runIO a = withPool defaultTargetPool $ \target -> runWithIO target a
-- | Run an array program in 'IO' on the given 'PTX' target.
--
-- The program is converted with 'convertAcc', then inside 'evalPTX' it goes
-- through three timed phases (see 'phase'):
--
--   1. @compile@ -- 'compileAcc', followed by 'dumpStats'
--   2. @link@    -- 'linkAcc'
--   3. @execute@ -- 'executeAcc' inside 'evalPar', with the result passed
--                   through 'copyToHostLazy'
--
-- The representation-level result is finally converted back to the surface
-- type with 'toArr'.
--
runWithIO :: forall a. Arrays a => PTX -> Acc a -> IO a
runWithIO target a = execute
  where
    -- The bang makes the conversion happen as soon as the result of
    -- 'runWithIO' is demanded, rather than on first use inside 'execute'.
    !acc    = convertAcc a

    execute = do
      dumpGraph acc
      evalPTX target $ do
        build <- phase "compile" (compileAcc acc) >>= dumpStats
        exec  <- phase "link"    (linkAcc build)
        -- '@a' selects which 'Arrays' instance provides the array
        -- representation; it cannot be inferred (AllowAmbiguousTypes).
        res   <- phase "execute" (evalPar (executeAcc exec >>= copyToHostLazy (Sugar.arraysR @a)))
        return $ toArr res
-- | Prepare and run an embedded array function of a single argument.
--
-- This is 'runN' specialised to functions of type @Acc a -> Acc b@.
--
run1 :: (Arrays a, Arrays b) => (Acc a -> Acc b) -> a -> b
run1 = runN
-- | As 'run1', but execute on the given 'PTX' target.
--
-- This is 'runNWith' specialised to functions of type @Acc a -> Acc b@.
--
run1With :: (Arrays a, Arrays b) => PTX -> (Acc a -> Acc b) -> a -> b
run1With = runNWith
-- | Prepare and run an embedded array function of any arity, using targets
-- from 'defaultTargetPool'.
--
-- All preparation lives in strict @where@ bindings, so it is performed when
-- the partial application @runN f@ is evaluated and is shared by every
-- subsequent application of the resulting function. To benefit from this,
-- bind @runN f@ once and apply it many times.
--
runN :: forall f. Afunction f => f -> AfunctionR f
runN f = exec
  where
    !acc  = convertAfun f

    -- On each use, take a target from the pool and select the function that
    -- was prepared for that target's context.
    !exec = unsafeWithPool defaultTargetPool $ \target ->
      fromMaybe
        -- Every context in the pool has an entry in 'afun', so this should
        -- be unreachable; fail with a descriptive message rather than the
        -- anonymous error of 'fromJust'.
        (internalError "runN: no compiled function cached for the selected execution context")
        (lookup (ptxContext target) afun)

    -- One entry per target listed by 'unmanaged defaultTargetPool', keyed by
    -- its context. The list elements are built lazily, so 'runNWith'' is
    -- only evaluated for contexts that are actually looked up.
    !afun = flip map (unmanaged defaultTargetPool) $ \target ->
      (ptxContext target, runNWith' @f target acc)
-- | As 'runN', but execute on the given 'PTX' target.
--
-- The function is converted with 'convertAfun' and handed to 'runNWith''.
-- Both bindings are strict, so this happens when @runNWith target f@ is
-- evaluated, not on each application of the result.
--
runNWith :: forall f. Afunction f => PTX -> f -> AfunctionR f
runNWith target f = exec
  where
    !acc  = convertAfun f
    !exec = runNWith' @f target acc
-- | Worker for 'runN' and 'runNWith': compile and link an already-converted
-- array function for the given target, and expose it as an ordinary Haskell
-- function of type @AfunctionR f@.
--
-- @f@ must be supplied by type application (@runNWith' \@f@); it is only
-- mentioned under type families in the signature.
--
runNWith' :: forall f. Afunction f => PTX -> DelayedAfun (ArraysFunctionR f) -> AfunctionR f
runNWith' target acc = go (afunctionRepr @f) afun (return Empty)
  where
    -- Compile and link once. The binding is strict and sits outside the
    -- lambdas produced by 'go', so it is not repeated per application.
    -- NOTE(review): 'unsafePerformIO' here presumes compilation/linking is
    -- safe to run at an arbitrary time and at most matters once -- confirm.
    !afun = unsafePerformIO $ do
      dumpGraph acc
      evalPTX target $ do
        build <- phase "compile" (compileAfun acc) >>= dumpStats
        phase "link" (linkAfun build)

    -- Walk the function representation and the linked function in lock-step,
    -- producing one Haskell lambda per 'Alam'. The third argument is the
    -- computation that builds the environment of arguments seen so far.
    go :: forall aenv t r trepr.
          AfunctionRepr t r trepr
       -> ExecOpenAfun PTX aenv trepr
       -> Par PTX (Val aenv)
       -> r
    go (AfunctionReprLam repr) (Alam lhs l) k = \ !arrs ->
      let k' = do
            aenv <- k
            a    <- useRemoteAsync (lhsToTupR lhs) $ fromArr arrs
            return (aenv `push` (lhs, a))
      in go repr l k'
    go AfunctionReprBody (Abody b) k =
      unsafePerformIO . phase "execute" . evalPTX target . evalPar $ do
        aenv <- k
        fut  <- executeOpenAcc b aenv
        toArr <$> copyToHostLazy (Sugar.arraysR @r) fut
    -- The representation and the linked function disagree on lambda/body.
    go _ _ _ = error "But that's not right, oh, no, what's the story?"
-- | As 'run1', but the prepared function returns an 'Async' handle in 'IO'
-- instead of the result itself.
--
-- This is 'runNAsync' specialised to functions of type @Acc a -> Acc b@.
--
run1Async :: (Arrays a, Arrays b) => (Acc a -> Acc b) -> a -> IO (Async b)
run1Async = runNAsync
-- | As 'run1Async', but execute on the given 'PTX' target.
--
-- This is 'runNAsyncWith' specialised to functions of type
-- @Acc a -> Acc b@.
--
run1AsyncWith :: (Arrays a, Arrays b) => PTX -> (Acc a -> Acc b) -> a -> IO (Async b)
run1AsyncWith = runNAsyncWith
-- | As 'runN', but the prepared function's final result type @r@ ends in
-- @IO (Async _)@ (see the 'RunAsync' instances), so the computation can be
-- waited on, polled or cancelled.
--
-- Targets are taken from 'defaultTargetPool'. All preparation lives in
-- strict @where@ bindings, so it happens when @runNAsync f@ is evaluated and
-- is shared by later applications of the result.
--
runNAsync :: (Afunction f, RunAsync r, ArraysFunctionR f ~ RunAsyncR r) => f -> r
runNAsync f = exec
  where
    !acc  = convertAfun f

    -- On each use, take a target from the pool and select the function that
    -- was prepared for that target's context.
    !exec = unsafeWithPool defaultTargetPool $ \target ->
      fromMaybe
        -- Every context in the pool has an entry in 'afun', so this should
        -- be unreachable; fail with a descriptive message rather than the
        -- anonymous error of 'fromJust'.
        (internalError "runNAsync: no compiled function cached for the selected execution context")
        (lookup (ptxContext target) afun)

    -- One entry per target listed by 'unmanaged defaultTargetPool', keyed by
    -- its context; elements are built lazily.
    !afun = flip map (unmanaged defaultTargetPool) $ \target ->
      (ptxContext target, runNAsyncWith' target acc)
-- | As 'runNAsync', but execute on the given 'PTX' target.
--
-- The function is converted with 'convertAfun' and handed to
-- 'runNAsyncWith''; both bindings are strict.
--
runNAsyncWith :: (Afunction f, RunAsync r, ArraysFunctionR f ~ RunAsyncR r) => PTX -> f -> r
runNAsyncWith target f = exec
  where
    !acc  = convertAfun f
    !exec = runNAsyncWith' target acc
-- | Worker for 'runNAsync' and 'runNAsyncWith': compile and link an
-- already-converted array function for the given target, then let
-- 'runAsync'' turn it into a Haskell function of type @f@, starting from the
-- empty environment.
--
runNAsyncWith' :: RunAsync f => PTX -> DelayedAfun (RunAsyncR f) -> f
runNAsyncWith' target acc = exec
  where
    -- Compile and link once; the strict binding keeps this out of the
    -- per-application path.
    -- NOTE(review): 'unsafePerformIO' here presumes compilation/linking is
    -- safe to run at an arbitrary time -- confirm.
    !afun = unsafePerformIO $ do
      dumpGraph acc
      evalPTX target $ do
        build <- phase "compile" (compileAfun acc) >>= dumpStats
        phase "link" (linkAfun build)

    !exec = runAsync' target afun (return Empty)
-- | Haskell function types that a linked array function can be exposed as
-- for asynchronous execution.
--
-- The instances in this module cover @a -> b@ (consume one argument) and
-- @IO (Async b)@ (run the body).
--
class RunAsync f where
  -- | The representation-level function type corresponding to @f@.
  type RunAsyncR f

  -- | Turn a linked (open) array function into a value of type @f@, given
  -- the computation that produces its environment.
  runAsync' :: PTX -> ExecOpenAfun PTX aenv (RunAsyncR f) -> Par PTX (Val aenv) -> f
-- | Function case: accept one more argument, extend the environment with
-- it, and continue with the remainder of the function.
--
instance (Arrays a, RunAsync b) => RunAsync (a -> b) where
  type RunAsyncR (a -> b) = ArraysR a -> RunAsyncR b

  -- A body where a lambda was expected: the Haskell type takes more
  -- arguments than the array function has.
  runAsync' _      Abody{}      _ _     = error "runAsync: function oversaturated"
  runAsync' target (Alam lhs l) k !arrs =
    let k' = do
          aenv <- k
          -- '@a' refers to the instance head's type variable
          -- (ScopedTypeVariables).
          a    <- useRemoteAsync (Sugar.arraysR @a) $ fromArr arrs
          return (aenv `push` (lhs, a))
    in runAsync' target l k'
-- | Base case: all arguments have been supplied, so run the body via
-- 'asyncBound' and return its 'Async' handle.
--
instance Arrays b => RunAsync (IO (Async b)) where
  type RunAsyncR (IO (Async b)) = ArraysR b

  -- A lambda where a body was expected: the Haskell type takes fewer
  -- arguments than the array function has.
  runAsync' _      Alam{}       _ = error "runAsync: function not fully applied"
  runAsync' target (Abody body) k =
    asyncBound . phase "execute" . evalPTX target . evalPar $ do
      aenv <- k
      ans  <- executeOpenAcc body aenv
      -- Unlike the synchronous runners, which use 'copyToHostLazy', the
      -- result is retrieved with 'getArrays'.
      arrs <- getArrays (arraysR body) ans
      return $ toArr arrs
-- | Apply an array function to every element of a list of inputs.
--
-- The function is prepared once with 'run1' (strict binding) and the
-- resulting Haskell function is then mapped over the list.
--
stream :: (Arrays a, Arrays b) => (Acc a -> Acc b) -> [a] -> [b]
stream f arrs = map go arrs
  where
    !go = run1 f
-- | As 'stream', but execute on the given 'PTX' target.
--
-- The function is prepared once with 'run1With' (strict binding) and then
-- mapped over the list.
--
streamWith :: (Arrays a, Arrays b) => PTX -> (Acc a -> Acc b) -> [a] -> [b]
streamWith target f arrs = map go arrs
  where
    !go = run1With target f
-- | Template Haskell counterpart of 'runN': produce an expression for a
-- Haskell function that runs the given array function.
--
-- The array function is compiled while the splice is being generated (see
-- 'runQ'_'). The generated code runs the computation under
-- 'unsafePerformIO', on a target taken from 'defaultTargetPool' (see
-- 'runQ'').
--
runQ :: Afunction f => f -> TH.ExpQ
runQ = runQ' [| unsafePerformIO |]
-- | As 'runQ', but the generated expression is a lambda whose first
-- argument is the target to execute on; the array function's own arguments
-- follow.
--
runQWith :: Afunction f => f -> TH.ExpQ
runQWith f = do
  -- Fresh name for the lambda-bound target.
  target <- TH.newName "target"
  TH.lamE [TH.varP target] $
    runQWith' [| unsafePerformIO |] (TH.varE target) f
-- | As 'runQ', but the generated code runs the computation with
-- 'asyncBound' instead of 'unsafePerformIO'.
--
runQAsync :: Afunction f => f -> TH.ExpQ
runQAsync = runQ' [| asyncBound |]
-- | As 'runQWith', but the generated code runs the computation with
-- 'asyncBound' instead of 'unsafePerformIO'.
--
runQAsyncWith :: Afunction f => f -> TH.ExpQ
runQAsyncWith f = do
  -- Fresh name for the lambda-bound target.
  target <- TH.newName "target"
  TH.lamE [TH.varP target] $
    runQWith' [| asyncBound |] (TH.varE target) f
-- | Shared worker for 'runQ' and 'runQAsync'.
--
-- @using@ is the expression that finally runs the generated 'IO' action.
-- The generated 'Par' computation is evaluated with 'evalPar' and 'evalPTX'
-- on a target that 'withPool' takes from 'defaultTargetPool' at run time.
--
runQ' :: Afunction f => TH.ExpQ -> f -> TH.ExpQ
runQ' using = runQ'_ using runOnPool
  where
    runOnPool go =
      [| withPool defaultTargetPool (\target -> evalPTX target (evalPar $go)) |]
-- | Shared worker for 'runQWith' and 'runQAsyncWith'.
--
-- @using@ is the expression that finally runs the generated 'IO' action and
-- @target@ is an expression denoting the target to execute on. The
-- generated 'Par' computation is evaluated with 'evalPar' and 'evalPTX' on
-- that target.
--
runQWith' :: Afunction f => TH.ExpQ -> TH.ExpQ -> f -> TH.ExpQ
runQWith' using target = runQ'_ using runOnTarget
  where
    runOnTarget go = [| evalPTX $target (evalPar $go) |]
-- | Core of the Template Haskell runners.
--
-- The array function is converted and compiled for 'defaultTarget' while the
-- splice is generated ('TH.runIO'). The result is an expression for a
-- lambda taking one strict argument per array-function argument, whose body
-- is
--
-- > using (phase "execute" (k <do-block>))
--
-- where the do-block uploads each argument with 'useRemoteAsync', executes
-- the embedded body with 'executeOpenAcc', retrieves the result with
-- 'copyToHostLazy' and converts it with 'toArr'.
--
--   * @using@ -- expression that runs the final 'IO' action
--   * @k@     -- wraps the generated do-block so that it is evaluated on some
--                target
--
runQ'_ :: Afunction f => TH.ExpQ -> (TH.ExpQ -> TH.ExpQ) -> f -> TH.ExpQ
runQ'_ using k f = do
  afun <- let acc = convertAfun f
          in  TH.runIO $ do
                dumpGraph acc
                evalPTX defaultTarget $
                  phase "compile" (compileAfun acc) >>= dumpStats
  let
      -- Accumulators are built in reverse (innermost argument first):
      --   xs    -- lambda patterns for the arguments
      --   as    -- environment entries, each a (left-hand-side, array) pair
      --   stmts -- do-statements that upload the arguments
      go :: CompiledOpenAfun PTX aenv t -> [TH.PatQ] -> [TH.ExpQ] -> [TH.StmtQ] -> TH.ExpQ
      go (Alam lhs l) xs as stmts = do
        x <- TH.newName "x"   -- lambda-bound argument
        a <- TH.newName "a"   -- the argument after 'useRemoteAsync'
        s <- TH.bindS (TH.varP a) [| useRemoteAsync $(TH.unTypeQ $ liftArraysR (lhsToTupR lhs)) (fromArr $(TH.varE x)) |]
        go l
           (TH.bangP (TH.varP x) : xs)
           ([| ($(TH.unTypeQ $ liftALeftHandSide lhs), $(TH.varE a)) |] : as)
           (return s : stmts)

      go (Abody b) xs as stmts = do
        r <- TH.newName "r"   -- result of 'executeOpenAcc'
        s <- TH.newName "s"   -- result of 'copyToHostLazy'
        let
            -- Rebuild the environment by pushing the collected entries onto
            -- 'Empty'.
            aenv = foldr (\a gamma -> [| $gamma `push` $a |] ) [| Empty |] as
            body = embedOpenAcc defaultTarget b
        TH.lamE (reverse xs)
                [| $using (phase "execute" $(k (
                     TH.doE ( reverse stmts ++
                            [ TH.bindS (TH.varP r) [| executeOpenAcc $(TH.unTypeQ body) $aenv |]
                            , TH.bindS (TH.varP s) [| copyToHostLazy $(TH.unTypeQ (liftArraysR (arraysR b))) $(TH.varE r) |]
                            , TH.noBindS [| return $ toArr $(TH.varE s) |]
                            ]))))
                 |]
  go afun [] [] []
-- | Register, via 'registerForeignPtrAllocator', an allocator that obtains
-- host memory from the CUDA driver ('CUDA.mallocHostForeignPtr' with no
-- allocation flags) inside the context of the given target.
--
-- NOTE(review): presumably this makes subsequently allocated host arrays
-- page-locked ("pinned"), as the name suggests -- confirm against the
-- allocator hook's documentation.
--
-- A 'CUDAException' raised by the allocation is reported through
-- 'internalError' using its 'show' text; other exceptions are not caught.
--
registerPinnedAllocatorWith :: HasCallStack => PTX -> IO ()
registerPinnedAllocatorWith target =
  registerForeignPtrAllocator $ \bytes ->
    withContext (ptxContext target) (CUDA.mallocHostForeignPtr [] bytes)
      `catch`
      \e -> internalError (show (e :: CUDAException))
-- | Run 'dumpSimplStats' and pass the argument through unchanged.
--
-- Shaped to be chained after a phase with '>>='.
--
dumpStats :: MonadIO m => a -> m a
dumpStats x = x <$ liftIO dumpSimplStats
-- | Run an action under 'timed', controlled by the 'dump_phases' flag.
--
-- The report line has the form @phase \<name\>: \<elapsed\>@, where the
-- elapsed text is produced by 'elapsed' from the wall-clock and CPU times
-- that 'timed' supplies.
--
-- NOTE(review): presumably nothing is printed unless 'dump_phases' is
-- enabled -- confirm against 'timed'.
--
phase :: MonadIO m => String -> m a -> m a
phase n go = timed dump_phases report go
  where
    report wall cpu = printf "phase %s: %s" n (elapsed wall cpu)