{-# LANGUAGE CPP #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_HADDOCK hide #-}
module Data.Array.Accelerate.LLVM.Execute (
Execute(..), Gamma,
executeAcc, executeAfun1,
) where
import Data.Array.Accelerate.AST
import Data.Array.Accelerate.Analysis.Match
import Data.Array.Accelerate.Array.Representation ( SliceIndex(..) )
import Data.Array.Accelerate.Array.Sugar hiding ( Foreign )
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Product
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Interpreter ( evalPrim, evalPrimConst, evalPrj )
import qualified Data.Array.Accelerate.Array.Sugar as S
import qualified Data.Array.Accelerate.Array.Representation as R
import Data.Array.Accelerate.LLVM.Array.Data
import Data.Array.Accelerate.LLVM.Compile
import Data.Array.Accelerate.LLVM.Foreign
import Data.Array.Accelerate.LLVM.State
import Data.Array.Accelerate.LLVM.CodeGen.Environment ( Gamma )
import Data.Array.Accelerate.LLVM.Execute.Async hiding ( join )
import Data.Array.Accelerate.LLVM.Execute.Environment
import Control.Monad
import Control.Applicative hiding ( Const )
import Prelude hiding ( exp, map, unzip, scanl, scanr, scanl1, scanr1 )
-- | The array operations that a backend @arch@ has to supply so that the
-- functions in this module can run computations on it.
--
-- Every method receives the same four leading arguments, in this order: the
-- 'ExecutableR' for the operation, the 'Gamma' environment, the array
-- environment ('AvalR'), and a 'StreamR'. What follows are the shapes and/or
-- arrays particular to each operation; the result is produced in the
-- 'LLVM' monad of the backend.
class (Remote arch, Foreign arch) => Execute arch where
  -- | Given a shape, produce an array indexed by that shape type.
  map
    :: (Shape sh, Elt b)
    => ExecutableR arch -> Gamma aenv -> AvalR arch aenv -> StreamR arch
    -> sh
    -> LLVM arch (Array sh b)

  -- | Given a shape, produce an array indexed by that shape type.
  generate, transform, backpermute
    :: (Shape sh, Elt e)
    => ExecutableR arch -> Gamma aenv -> AvalR arch aenv -> StreamR arch
    -> sh
    -> LLVM arch (Array sh e)

  -- | Given a shape of the form @sh :. Int@, produce an array indexed by @sh@.
  fold, fold1
    :: (Shape sh, Elt e)
    => ExecutableR arch -> Gamma aenv -> AvalR arch aenv -> StreamR arch
    -> sh :. Int
    -> LLVM arch (Array sh e)

  -- | Given a shape of the form @sh :. Int@ together with a 'DIM1' shape,
  -- produce an array indexed by @sh :. Int@.
  foldSeg, fold1Seg
    :: (Shape sh, Elt e)
    => ExecutableR arch -> Gamma aenv -> AvalR arch aenv -> StreamR arch
    -> sh :. Int
    -> DIM1
    -> LLVM arch (Array (sh:.Int) e)

  -- | Given a shape of the form @sh :. Int@, produce an array indexed by
  -- @sh :. Int@.
  scanl, scanl1
    :: (Shape sh, Elt e)
    => ExecutableR arch -> Gamma aenv -> AvalR arch aenv -> StreamR arch
    -> sh :. Int
    -> LLVM arch (Array (sh:.Int) e)

  -- | Given a shape of the form @sh :. Int@, produce a pair of arrays: one
  -- indexed by @sh :. Int@ and one indexed by @sh@.
  scanl'
    :: (Shape sh, Elt e)
    => ExecutableR arch -> Gamma aenv -> AvalR arch aenv -> StreamR arch
    -> sh :. Int
    -> LLVM arch (Array (sh:.Int) e, Array sh e)

  -- | Given a shape of the form @sh :. Int@, produce an array indexed by
  -- @sh :. Int@.
  scanr, scanr1
    :: (Shape sh, Elt e)
    => ExecutableR arch -> Gamma aenv -> AvalR arch aenv -> StreamR arch
    -> sh :. Int
    -> LLVM arch (Array (sh:.Int) e)

  -- | Given a shape of the form @sh :. Int@, produce a pair of arrays: one
  -- indexed by @sh :. Int@ and one indexed by @sh@.
  scanr'
    :: (Shape sh, Elt e)
    => ExecutableR arch -> Gamma aenv -> AvalR arch aenv -> StreamR arch
    -> sh :. Int
    -> LLVM arch (Array (sh:.Int) e, Array sh e)

  -- | Given a 'Bool' flag, a shape @sh@, and an array indexed by @sh'@,
  -- produce an array of the same type as the one passed in.
  permute
    :: (Shape sh, Shape sh', Elt e)
    => ExecutableR arch -> Gamma aenv -> AvalR arch aenv -> StreamR arch
    -> Bool
    -> sh
    -> Array sh' e
    -> LLVM arch (Array sh' e)

  -- | Given one array, produce an array with the same shape type.
  stencil1
    :: (Shape sh, Elt a, Elt b)
    => ExecutableR arch -> Gamma aenv -> AvalR arch aenv -> StreamR arch
    -> Array sh a
    -> LLVM arch (Array sh b)

  -- | Given two arrays sharing a shape type, produce an array with that same
  -- shape type.
  stencil2
    :: (Shape sh, Elt a, Elt b, Elt c)
    => ExecutableR arch -> Gamma aenv -> AvalR arch aenv -> StreamR arch
    -> Array sh a
    -> Array sh b
    -> LLVM arch (Array sh c)
{-# INLINEABLE executeAcc #-}
-- | Run a closed array computation.
--
-- The term is evaluated with 'executeOpenAcc' in the empty array environment
-- ('Aempty') under 'async', and the final value is then retrieved from the
-- asynchronous result with 'get'.
executeAcc
    :: forall arch a. Execute arch
    => ExecAcc arch a
    -> LLVM arch a
executeAcc acc = do
  pending <- async (executeOpenAcc acc Aempty)
  get pending
{-# INLINEABLE executeAfun1 #-}
-- | Run a closed array function of one argument.
--
-- The argument is first passed through 'useRemoteAsync' (itself wrapped in
-- 'async'); the value carried by the outer asynchronous result is what gets
-- handed to 'executeOpenAfun1', together with the empty array environment.
executeAfun1
    :: forall arch a b. (Execute arch, Arrays a)
    => ExecAfun arch (a -> b)
    -> a
    -> LLVM arch b
executeAfun1 afun arrs =
  async (useRemoteAsync arrs) >>= \(AsyncR _ remote) ->
    executeOpenAfun1 afun Aempty remote
{-# INLINEABLE executeOpenAfun1 #-}
-- | Apply an open array function of one argument.
--
-- The argument is pushed onto the array environment and the function body is
-- evaluated with 'executeOpenAcc' under 'async'; the final value is retrieved
-- from the asynchronous result with 'get'.
--
-- Only terms of the form @Alam (Abody f)@ are supported. Any other term
-- raises an internal error naming this function, in the same style as the
-- other internal errors in this module.
executeOpenAfun1
    :: Execute arch
    => ExecOpenAfun arch aenv (a -> b)
    -> AvalR arch aenv
    -> AsyncR arch a
    -> LLVM arch b
executeOpenAfun1 (Alam (Abody f)) aenv a = get =<< async (executeOpenAcc f (aenv `Apush` a))
executeOpenAfun1 _                _    _ =
  $internalError "executeOpenAfun1" "expected an array function of exactly one argument"
{-# INLINEABLE executeOpenAcc #-}
-- | Execute an open array computation in the given array environment, using
-- the supplied stream.
--
-- The three forms of 'ExecOpenAcc' are treated as follows:
--
--   * 'EmbedAcc' is rejected with an internal error.
--
--   * 'ExecAcc' is dispatched on the array operation it wraps. Operations
--     that correspond to a method of 'Execute' evaluate their shape or array
--     arguments and then call that method with the kernel, 'Gamma', array
--     environment and stream.
--
--   * 'UnzipAcc' looks up an array in the environment and extracts one
--     component of its tuple representation.
executeOpenAcc
    :: forall arch aenv arrs. Execute arch
    => ExecOpenAcc arch aenv arrs
    -> AvalR arch aenv
    -> StreamR arch
    -> LLVM arch arrs
executeOpenAcc EmbedAcc{} _ _ =
  $internalError "execute" "unexpected delayed array"

executeOpenAcc (ExecAcc kernel gamma pacc) aenv stream =
  case pacc of
    -- Array introduction
    Use arr              -> return (toArr arr)
    Unit x               -> travE x >>= newRemote Z . const

    -- Environment manipulation and control flow
    Avar ix              -> do
      -- Make this stream wait on the event stored alongside the array.
      let AsyncR ready arr = aprj ix aenv
      after stream ready
      return arr
    Alet bnd body        -> do
      bound <- async (executeOpenAcc bnd aenv)
      executeOpenAcc body (aenv `Apush` bound) stream
    Apply f a            -> async (executeOpenAcc a aenv) >>= executeOpenAfun1 f aenv
    Atuple tup           -> fmap toAtuple (travT tup)
    Aprj ix tup          -> fmap (evalPrj ix . fromAtuple) (travA tup)
    Acond p t e          -> travE p >>= acond t e
    Awhile p f a         -> travA a >>= awhile p f
    Aforeign asm _ a     -> travA a >>= foreignA asm

    -- Producers
    Map _ a              -> extent a >>= map kernel gamma aenv stream
    Generate sh _        -> travE sh >>= generate kernel gamma aenv stream
    Transform sh _ _ _   -> travE sh >>= transform kernel gamma aenv stream
    Backpermute sh _ _   -> travE sh >>= backpermute kernel gamma aenv stream
    Reshape sh a         -> do
      newShape <- travE sh
      arr      <- travA a
      return (reshape newShape arr)

    -- Consumers
    Fold _ _ a           -> extent a >>= fold kernel gamma aenv stream
    Fold1 _ a            -> extent a >>= fold1 kernel gamma aenv stream
    FoldSeg _ _ a s      -> do
      shA <- extent a
      shS <- extent s
      foldSeg kernel gamma aenv stream shA shS
    Fold1Seg _ a s       -> do
      shA <- extent a
      shS <- extent s
      fold1Seg kernel gamma aenv stream shA shS
    Scanl _ _ a          -> extent a >>= scanl kernel gamma aenv stream
    Scanr _ _ a          -> extent a >>= scanr kernel gamma aenv stream
    Scanl1 _ a           -> extent a >>= scanl1 kernel gamma aenv stream
    Scanr1 _ a           -> extent a >>= scanr1 kernel gamma aenv stream
    Scanl' _ _ a         -> extent a >>= scanl' kernel gamma aenv stream
    Scanr' _ _ a         -> extent a >>= scanr' kernel gamma aenv stream
    Permute _ d _ a      -> do
      shA      <- extent a
      defaults <- travA d
      permute kernel gamma aenv stream (inplace d) shA defaults
    Stencil _ _ a        -> travA a >>= stencil1 kernel gamma aenv stream
    Stencil2 _ _ a _ b   -> do
      xs <- travA a
      ys <- travA b
      stencil2 kernel gamma aenv stream xs ys

    -- Operations that are not handled here and are reported as errors
    Replicate{}          -> fusionError
    Slice{}              -> fusionError
    ZipWith{}            -> fusionError

  where
    -- Internal error raised for the operations listed last in the dispatch.
    fusionError :: error
    fusionError =
      $internalError "execute" ("unexpected fusible material: " ++ showPreAccOp pacc)

    -- Evaluate a sub-computation in the current environment and stream.
    travA :: ExecOpenAcc arch aenv a -> LLVM arch a
    travA sub = executeOpenAcc sub aenv stream

    -- Evaluate a scalar expression in the current environment and stream.
    travE :: ExecExp arch aenv t -> LLVM arch t
    travE e = executeExp e aenv stream

    -- Evaluate each component of an array tuple, building nested pairs.
    travT :: Atuple (ExecOpenAcc arch aenv) t -> LLVM arch t
    travT atup = case atup of
      NilAtup         -> return ()
      SnocAtup rest a -> do
        rest' <- travT rest
        a'    <- travA a
        return (rest', a')

    -- Obtain the shape of an argument array. The argument must be either an
    -- embedded shape expression or a projection of an environment array; a
    -- full 'ExecAcc' node is an internal error.
    extent :: Shape sh => ExecOpenAcc arch aenv (Array sh e) -> LLVM arch sh
    extent arg = case arg of
      ExecAcc{}     -> $internalError "executeOpenAcc" "expected delayed array"
      EmbedAcc sh   -> travE sh
      UnzipAcc _ ix ->
        let AsyncR _ arr = aprj ix aenv
        in  return (shape arr)

    -- The flag handed to 'permute': False exactly when the argument is a bare
    -- array variable, True for everything else.
    inplace :: ExecOpenAcc arch aenv a -> Bool
    inplace arg = case arg of
      ExecAcc _ _ Avar{} -> False
      _                  -> True

    -- Re-wrap the array data under a new shape, after checking that the old
    -- and new shapes have the same size.
    reshape :: Shape sh => sh -> Array sh' e -> Array sh e
    reshape new (Array old payload) =
      $boundsCheck "reshape" "shape mismatch" (size new == R.size old)
        (Array (fromElt new) payload)

    -- Evaluate whichever branch the already-computed predicate selects.
    acond :: ExecOpenAcc arch aenv a -> ExecOpenAcc arch aenv a -> Bool -> LLVM arch a
    acond yes no p
      | p         = travA yes
      | otherwise = travA no

    -- Array-level loop: apply the step function for as long as element 0 of
    -- the predicate's result is True.
    awhile :: ExecOpenAfun arch aenv (a -> Scalar Bool)
           -> ExecOpenAfun arch aenv (a -> a)
           -> a
           -> LLVM arch a
    awhile cond step cur = do
      evt  <- checkpoint stream
      flag <- executeOpenAfun1 cond aenv (AsyncR evt cur) >>= \r -> indexRemote r 0
      if flag
        then executeOpenAfun1 step aenv (AsyncR evt cur) >>= awhile cond step
        else return cur

    -- Look up the backend implementation of a foreign function and run it.
    foreignA :: (Arrays a, Arrays b, Foreign arch, S.Foreign asm)
             => asm (a -> b)
             -> a
             -> LLVM arch b
    foreignA asm a =
      case foreignAcc (undefined :: arch) asm of
        Nothing -> $internalError "foreignA" "failed to recover foreign function the second time"
        Just f  -> f stream a

executeOpenAcc (UnzipAcc tup v) aenv stream = do
  -- Make this stream wait on the event stored alongside the array.
  let AsyncR ready arr = aprj v aenv
  after stream ready
  return (unzip tup arr)
  where
    -- Select one component from the array's tuple representation. The shape
    -- is reused unchanged.
    unzip :: forall t sh e. (Elt t, Elt e) => TupleIdx (TupleRepr t) e -> Array sh t -> Array sh e
    unzip tix (Array sh adata) = Array sh (select tix (eltType (undefined::t)) adata)
      where
        -- Walk the tuple index and the paired array data in lock-step; the
        -- final component is accepted only if its type matches the expected
        -- element type.
        select :: TupleIdx v e -> TupleType t' -> ArrayData t' -> ArrayData (EltRepr e)
        select (SuccTupIdx ix) (PairTuple t _) (AD_Pair x _) = select ix t x
        select ZeroTupIdx      (PairTuple _ t) (AD_Pair _ x)
          | Just Refl <- matchTupleType t (eltType (undefined::e)) = x
        select _ _ _ = $internalError "unzip" "inconsistent valuation"
{-# INLINEABLE executeExp #-}
-- | Evaluate a scalar expression that has no free scalar variables.
--
-- This is 'executeOpenExp' started from the empty scalar environment
-- ('Empty'); the array environment and stream are passed through unchanged.
executeExp
    :: Execute arch
    => ExecExp arch aenv t
    -> AvalR arch aenv
    -> StreamR arch
    -> LLVM arch t
executeExp closed arrays strm =
  executeOpenExp closed Empty arrays strm
{-# INLINEABLE executeOpenExp #-}
-- | Evaluate an open scalar expression.
--
-- The expression is interpreted directly: scalar variables are looked up in
-- the scalar environment, primitive operations are evaluated with
-- 'evalPrim' / 'evalPrimConst', and array-valued sub-terms are run with
-- 'executeOpenAcc' in the given array environment and stream.
--
-- Both 'Index' and 'LinearIndex' read their element through the local
-- 'linearIndex' helper, so that both wait on a checkpoint of the stream
-- before calling 'indexRemote'.
executeOpenExp
    :: forall arch env aenv exp. Execute arch
    => ExecOpenExp arch env aenv exp
    -> Val env
    -> AvalR arch aenv
    -> StreamR arch
    -> LLVM arch exp
executeOpenExp rootExp env aenv stream = travE rootExp
  where
    -- Evaluate a sub-expression in the current scalar environment.
    travE :: ExecOpenExp arch env aenv t -> LLVM arch t
    travE exp = case exp of
      Var ix                    -> return (prj ix env)
      Let bnd body              -> travE bnd >>= \x -> executeOpenExp body (env `Push` x) aenv stream
      Const c                   -> return (toElt c)
      PrimConst c               -> return (evalPrimConst c)
      PrimApp f x               -> evalPrim f <$> travE x
      Tuple t                   -> toTuple <$> travT t
      Prj ix e                  -> evalPrj ix . fromTuple <$> travE e
      Cond p t e                -> travE p >>= \x -> if x then travE t else travE e
      While p f x               -> while p f =<< travE x
      IndexAny                  -> return Any
      IndexNil                  -> return Z
      IndexCons sh sz           -> (:.) <$> travE sh <*> travE sz
      IndexHead sh              -> (\(_  :. ix) -> ix) <$> travE sh
      IndexTail sh              -> (\(ix :.  _) -> ix) <$> travE sh
      IndexSlice ix slix sh     -> indexSlice ix <$> travE slix <*> travE sh
      IndexFull ix slix sl      -> indexFull  ix <$> travE slix <*> travE sl
      ToIndex sh ix             -> toIndex   <$> travE sh  <*> travE ix
      FromIndex sh ix           -> fromIndex <$> travE sh  <*> travE ix
      Intersect sh1 sh2         -> intersect <$> travE sh1 <*> travE sh2
      Union sh1 sh2             -> union     <$> travE sh1 <*> travE sh2
      ShapeSize sh              -> size  <$> travE sh
      Shape acc                 -> shape <$> travA acc
      Index acc ix              -> join $ index       <$> travA acc <*> travE ix
      -- Use the same synchronising helper as 'Index', rather than calling
      -- 'indexRemote' directly on an array that was just produced on 'stream'.
      LinearIndex acc ix        -> join $ linearIndex <$> travA acc <*> travE ix
      Foreign _ f x             -> foreignE f x

    -- Evaluate each component of a scalar tuple, building nested pairs.
    travT :: Tuple (ExecOpenExp arch env aenv) t -> LLVM arch t
    travT tup = case tup of
      NilTup            -> return ()
      SnocTup t e       -> (,) <$> travT t <*> travE e

    -- Run an array computation in the current array environment and stream.
    travA :: ExecOpenAcc arch aenv a -> LLVM arch a
    travA acc = executeOpenAcc acc aenv stream

    -- Evaluate the argument, then evaluate the body of the given closed
    -- function with that value as its only scalar variable and with an empty
    -- array environment.
    foreignE :: ExecFun arch () (a -> b) -> ExecOpenExp arch env aenv a -> LLVM arch b
    foreignE (Lam (Body f)) x = travE x >>= \e -> executeOpenExp f (Empty `Push` e) Aempty stream
    foreignE _              _ = $internalError "foreignE" "expected a scalar function of exactly one argument"

    -- Apply a unary scalar function by pushing the argument onto the scalar
    -- environment and evaluating the body.
    travF1 :: ExecOpenFun arch env aenv (a -> b) -> a -> LLVM arch b
    travF1 (Lam (Body f)) x = executeOpenExp f (env `Push` x) aenv stream
    travF1 _              _ = $internalError "travF1" "expected a scalar function of exactly one argument"

    -- Scalar loop: keep applying 'f' while the predicate 'p' holds.
    while :: ExecOpenFun arch env aenv (a -> Bool) -> ExecOpenFun arch env aenv (a -> a) -> a -> LLVM arch a
    while p f x = do
      ok <- travF1 p x
      if ok then while p f =<< travF1 f x
            else return x

    -- Compute the slice shape from a slice index and a full shape: 'SliceAll'
    -- dimensions keep their extent, 'SliceFixed' dimensions are dropped.
    indexSlice :: (Elt slix, Elt sh, Elt sl)
               => SliceIndex (EltRepr slix) (EltRepr sl) co (EltRepr sh)
               -> slix
               -> sh
               -> sl
    indexSlice ix slix sh = toElt $ restrict ix (fromElt slix) (fromElt sh)
      where
        restrict :: SliceIndex slix sl co sh -> slix -> sh -> sl
        restrict SliceNil              ()        ()       = ()
        restrict (SliceAll   sliceIdx) (slx, ()) (sl, sz) = (restrict sliceIdx slx sl, sz)
        restrict (SliceFixed sliceIdx) (slx,  _) (sl,  _) = restrict sliceIdx slx sl

    -- Compute the full shape from a slice index and a slice shape: 'SliceAll'
    -- dimensions take their extent from the slice shape, 'SliceFixed'
    -- dimensions take the value given in the slice index.
    indexFull :: (Elt slix, Elt sh, Elt sl)
              => SliceIndex (EltRepr slix) (EltRepr sl) co (EltRepr sh)
              -> slix
              -> sl
              -> sh
    indexFull ix slix sl = toElt $ extend ix (fromElt slix) (fromElt sl)
      where
        extend :: SliceIndex slix sl co sh -> slix -> sl -> sh
        extend SliceNil              ()        ()       = ()
        extend (SliceAll sliceIdx)   (slx, ()) (sh, sz) = (extend sliceIdx slx sh, sz)
        extend (SliceFixed sliceIdx) (slx, sz) sh       = (extend sliceIdx slx sh, sz)

    -- Read one element at a multidimensional index, by converting it to a
    -- linear index with 'toIndex' first.
    index :: Shape sh => Array sh e -> sh -> LLVM arch e
    index arr ix = linearIndex arr (toIndex (shape arr) ix)

    -- Read one element at a linear index. Blocks on a checkpoint of the
    -- stream before reading (presumably so that work already issued to the
    -- stream has finished -- confirm against the backend's 'checkpoint' and
    -- 'block' semantics).
    linearIndex :: Array sh e -> Int -> LLVM arch e
    linearIndex arr ix = do
      block =<< checkpoint stream
      indexRemote arr ix