{-# LANGUAGE CPP #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# OPTIONS_HADDOCK hide #-}
module Data.Array.Accelerate.LLVM.Embed (
Embed(..),
embedAfun, embedOpenAfun,
embedOpenAcc,
) where
import LLVM.AST.Type.Name
import Data.Array.Accelerate.AST ( liftIdx, liftTupleIdx, liftArrays, liftConst, liftSliceIndex, liftPrimConst, liftPrimFun )
import Data.Array.Accelerate.Array.Sugar
import Data.Array.Accelerate.LLVM.AST
import Data.Array.Accelerate.LLVM.CodeGen.Environment
import Data.Array.Accelerate.LLVM.Compile
import Data.Array.Accelerate.LLVM.Link
import Data.Typeable
import Data.ByteString.Short ( ShortByteString )
import GHC.Ptr ( Ptr(..) )
import Language.Haskell.TH ( Q, TExp )
import System.IO.Unsafe
import qualified Data.ByteString.Short.Internal as BS
import qualified Language.Haskell.TH as TH
import qualified Language.Haskell.TH.Syntax as TH
#if MIN_VERSION_containers(0,5,9)
import qualified Data.IntMap.Internal as IM
#elif MIN_VERSION_containers(0,5,8)
import qualified Data.IntMap.Base as IM
#else
import qualified Data.IntMap as IM
#endif
-- | Targets for which an @ObjectR arch@ can be turned into a typed Template
-- Haskell expression that yields the matching @ExecutableR arch@.
class Embed arch where
  -- | Given the target and one of its objects, produce a typed Template
  -- Haskell expression of type @ExecutableR arch@.
  --
  -- NOTE(review): presumably the object is the compiled code for a kernel and
  -- the expression reconstructs its linked form in the splicing program;
  -- confirm against the instances.
  embedForTarget :: arch -> ObjectR arch -> Q (TExp (ExecutableR arch))
{-# INLINEABLE embedAfun #-}
-- | Embed a compiled array function as a typed Template Haskell expression.
--
-- All the work is delegated to 'embedOpenAfun'; this is the same function used
-- at the 'CompiledAfun' type.
embedAfun
    :: (Embed arch, Typeable arch)
    => arch
    -> CompiledAfun arch f
    -> Q (TExp (ExecAfun arch f))
embedAfun target afun = embedOpenAfun target afun
{-# INLINEABLE embedOpenAfun #-}
-- | Embed a compiled array function, one binder at a time.
--
-- Each 'Alam' is reproduced around the embedding of the function underneath
-- it, and the 'Abody' at the bottom is reproduced around the result of
-- 'embedOpenAcc' on the body.
embedOpenAfun
    :: (Embed arch, Typeable arch, Typeable aenv)
    => arch
    -> CompiledOpenAfun arch aenv f
    -> Q (TExp (ExecOpenAfun arch aenv f))
embedOpenAfun arch afun =
  case afun of
    Alam inner -> [|| Alam $$(embedOpenAfun arch inner) ||]
    Abody body -> [|| Abody $$(embedOpenAcc arch body) ||]
{-# INLINEABLE embedOpenAcc #-}
-- | Embed a compiled array computation as a typed Template Haskell expression
-- that rebuilds the corresponding 'ExecOpenAcc' term.
--
-- * A 'PlainAcc' node becomes an 'EvalAcc' around its lifted command.
--
-- * A 'BuildAcc' node becomes an 'ExecAcc' holding the lifted 'Gamma', the
--   result of 'embedForTarget' on the node's object, and the lifted skeleton.
--
-- Every generated node is wrapped in an explicit type signature ('withSigE').
embedOpenAcc
    :: forall arch aenv arrs. (Embed arch, Typeable arch, Typeable aenv, Typeable arrs)
    => arch
    -> CompiledOpenAcc arch aenv arrs
    -> Q (TExp (ExecOpenAcc arch aenv arrs))
embedOpenAcc arch = embedNode
  where
    -- Polymorphic in the environment and result type so that it is not tied to
    -- the types of the outermost node.
    embedNode :: (Typeable aenv', Typeable arrs') => CompiledOpenAcc arch aenv' arrs' -> Q (TExp (ExecOpenAcc arch aenv' arrs'))
    embedNode node =
      case node of
        PlainAcc command ->
          withSigE [|| EvalAcc $$(liftPreOpenAccCommand arch command) ||]
        BuildAcc gamma object skeleton ->
          withSigE [|| ExecAcc $$(embedGamma gamma) $$(embedForTarget arch object) $$(liftPreOpenAccSkeleton arch skeleton) ||]

    -- Lift the IntMap behind a 'Gamma'.
    embedGamma :: Gamma aenv' -> Q (TExp (Gamma aenv'))
#if MIN_VERSION_containers(0,5,8)
    -- The internal constructors are available, so the generated expression
    -- spells out the tree node by node.
    embedGamma gamma =
      case gamma of
        IM.Nil               -> [|| IM.Nil ||]
        IM.Bin pfx msk lt rt -> [|| IM.Bin pfx msk $$(embedGamma lt) $$(embedGamma rt) ||]
        IM.Tip key entry     -> [|| IM.Tip key $$(embedEntry entry) ||]
#else
    -- Without the internal constructors the generated expression is an
    -- association list, which IM.fromAscList turns back into a map when the
    -- generated code is evaluated.
    embedGamma gamma = [|| IM.fromAscList $$(embedAssocs (IM.toAscList gamma)) ||]
      where
        embedAssocs :: [(Int, (Label, Idx' aenv'))] -> Q (TExp [(Int, (Label, Idx' aenv'))])
        embedAssocs assocs =
          TH.unsafeTExpCoerce $
            fmap TH.ListE (mapM (\(key, entry) -> TH.unTypeQ [|| (key, $$(embedEntry entry)) ||]) assocs)
#endif

    -- Lift one (label, index) entry of the environment map.
    embedEntry :: (Label, Idx' aenv') -> Q (TExp (Label, Idx' aenv'))
    embedEntry (Label name, Idx' ix) = [|| (Label $$(embedSBS name), Idx' $$(liftIdx ix)) ||]

    -- Lift a ShortByteString: its bytes are emitted as a primitive string
    -- literal, and the generated code rebuilds the value from a pointer to that
    -- literal with BS.createFromPtr, given the original length.
    --
    -- NOTE(review): the unsafePerformIO in the generated code is presumably
    -- safe because the literal is immutable static data; confirm.
    embedSBS :: ShortByteString -> Q (TExp ShortByteString)
    embedSBS sbs =
      [|| unsafePerformIO $ BS.createFromPtr $$( TH.unsafeTExpCoerce [| Ptr $(TH.litE (TH.StringPrimL payload)) |]) nbytes ||]
      where
        payload = BS.unpack sbs
        nbytes  = BS.length sbs
{-# INLINEABLE liftPreOpenAfun #-}
-- | Lift an array function over compiled computations into a typed Template
-- Haskell expression for the same function over executable computations.
--
-- 'Alam' binders are reproduced recursively; the 'Abody' is handed to
-- 'embedOpenAcc'.
liftPreOpenAfun
    :: (Embed arch, Typeable arch, Typeable aenv)
    => arch
    -> PreOpenAfun (CompiledOpenAcc arch) aenv t
    -> Q (TExp (PreOpenAfun (ExecOpenAcc arch) aenv t))
liftPreOpenAfun arch afun =
  case afun of
    Alam inner -> [|| Alam $$(liftPreOpenAfun arch inner) ||]
    Abody body -> [|| Abody $$(embedOpenAcc arch body) ||]
{-# INLINEABLE liftPreOpenAccCommand #-}
-- | Lift a 'PreOpenAccCommand' over compiled computations into a typed
-- Template Haskell expression for the same command over executable
-- computations.
--
-- Each constructor is rebuilt as itself, with every field lifted by the helper
-- matching its type: nested computations through 'embedOpenAcc', scalar
-- expressions through 'liftPreOpenExp', array functions through
-- 'liftPreOpenAfun', and indices, tuple indices, arrays and foreign functions
-- through the corresponding imported @lift*@ functions.
liftPreOpenAccCommand
    :: forall arch aenv a. (Embed arch, Typeable arch, Typeable aenv)
    => arch
    -> PreOpenAccCommand (CompiledOpenAcc arch) aenv a
    -> Q (TExp (PreOpenAccCommand (ExecOpenAcc arch) aenv a))
liftPreOpenAccCommand arch command =
  case command of
    Avar ix        -> [|| Avar $$(liftIdx ix) ||]
    Alet bnd body  -> [|| Alet $$(embedAcc bnd) $$(embedAcc body) ||]
    -- The array representation is recovered from the result type alone; the
    -- undefined value is only a type witness.
    Use arrs       -> [|| Use $$(liftArrays (arrays (undefined::a)) arrs) ||]
    Unit e         -> [|| Unit $$(embedExp e) ||]
    Atuple tup     -> [|| Atuple $$(embedAtup tup) ||]
    Aprj tix arr   -> [|| Aprj $$(liftTupleIdx tix) $$(embedAcc arr) ||]
    Apply f arr    -> [|| Apply $$(embedAfn f) $$(embedAcc arr) ||]
    Aforeign ff ar -> [|| Aforeign $$(liftForeign ff) $$(embedAcc ar) ||]
    Acond p t e    -> [|| Acond $$(embedExp p) $$(embedAcc t) $$(embedAcc e) ||]
    Awhile p f arr -> [|| Awhile $$(embedAfn p) $$(embedAfn f) $$(embedAcc arr) ||]
    Reshape sh ix  -> [|| Reshape $$(embedExp sh) $$(liftIdx ix) ||]
    Unzip tix ix   -> [|| Unzip $$(liftTupleIdx tix) $$(liftIdx ix) ||]
  where
    -- Nested computations; polymorphic in the array environment, unlike the
    -- other helpers, which are fixed to this command's environment.
    embedAcc :: (Typeable aenv', Typeable arrs) => CompiledOpenAcc arch aenv' arrs -> Q (TExp (ExecOpenAcc arch aenv' arrs))
    embedAcc = embedOpenAcc arch

    -- Scalar expressions.
    embedExp :: PreExp (CompiledOpenAcc arch) aenv t -> Q (TExp (PreExp (ExecOpenAcc arch) aenv t))
    embedExp = liftPreOpenExp arch

    -- Array functions.
    embedAfn :: PreOpenAfun (CompiledOpenAcc arch) aenv f -> Q (TExp (PreOpenAfun (ExecOpenAcc arch) aenv f))
    embedAfn = liftPreOpenAfun arch

    -- Tuples of computations, rebuilt element by element.
    embedAtup :: Atuple (CompiledOpenAcc arch aenv) t -> Q (TExp (Atuple (ExecOpenAcc arch aenv) t))
    embedAtup atup =
      case atup of
        NilAtup           -> [|| NilAtup ||]
        SnocAtup rest arr -> [|| SnocAtup $$(embedAtup rest) $$(embedAcc arr) ||]
{-# INLINEABLE liftPreOpenAccSkeleton #-}
-- | Lift a 'PreOpenAccSkeleton' over compiled computations into a typed
-- Template Haskell expression for the same skeleton over executable
-- computations.
--
-- Each constructor is rebuilt as itself: scalar-expression fields go through
-- 'liftPreOpenExp', the computation carried by 'Permute' goes through
-- 'embedOpenAcc', and the index fields of 'Stencil' and 'Stencil2' go through
-- 'liftIdx'.
liftPreOpenAccSkeleton
    :: forall arch aenv a. (Embed arch, Typeable arch, Typeable aenv)
    => arch
    -> PreOpenAccSkeleton (CompiledOpenAcc arch) aenv a
    -> Q (TExp (PreOpenAccSkeleton (ExecOpenAcc arch) aenv a))
liftPreOpenAccSkeleton arch skeleton =
  case skeleton of
    Map sh           -> [|| Map $$(embedExp sh) ||]
    Generate sh      -> [|| Generate $$(embedExp sh) ||]
    Transform sh     -> [|| Transform $$(embedExp sh) ||]
    Backpermute sh   -> [|| Backpermute $$(embedExp sh) ||]
    Fold sh          -> [|| Fold $$(embedExp sh) ||]
    Fold1 sh         -> [|| Fold1 $$(embedExp sh) ||]
    FoldSeg sa ss    -> [|| FoldSeg $$(embedExp sa) $$(embedExp ss) ||]
    Fold1Seg sa ss   -> [|| Fold1Seg $$(embedExp sa) $$(embedExp ss) ||]
    Scanl sh         -> [|| Scanl $$(embedExp sh) ||]
    Scanl1 sh        -> [|| Scanl1 $$(embedExp sh) ||]
    Scanl' sh        -> [|| Scanl' $$(embedExp sh) ||]
    Scanr sh         -> [|| Scanr $$(embedExp sh) ||]
    Scanr1 sh        -> [|| Scanr1 $$(embedExp sh) ||]
    Scanr' sh        -> [|| Scanr' $$(embedExp sh) ||]
    Permute sh arr   -> [|| Permute $$(embedExp sh) $$(embedAcc arr) ||]
    Stencil ix       -> [|| Stencil $$(liftIdx ix) ||]
    Stencil2 ix1 ix2 -> [|| Stencil2 $$(liftIdx ix1) $$(liftIdx ix2) ||]
  where
    -- Nested computations in this skeleton's array environment.
    embedAcc :: Typeable arrs => CompiledOpenAcc arch aenv arrs -> Q (TExp (ExecOpenAcc arch aenv arrs))
    embedAcc = embedOpenAcc arch

    -- Scalar expressions in this skeleton's array environment.
    embedExp :: PreExp (CompiledOpenAcc arch) aenv t -> Q (TExp (PreExp (ExecOpenAcc arch) aenv t))
    embedExp = liftPreOpenExp arch
{-# INLINEABLE liftPreOpenFun #-}
-- | Lift a scalar function over compiled computations into a typed Template
-- Haskell expression for the same function over executable computations.
--
-- 'Lam' binders are reproduced recursively; the 'Body' is handed to
-- 'liftPreOpenExp'.
liftPreOpenFun
    :: (Embed arch, Typeable arch, Typeable env, Typeable aenv)
    => arch
    -> PreOpenFun (CompiledOpenAcc arch) env aenv t
    -> Q (TExp (PreOpenFun (ExecOpenAcc arch) env aenv t))
liftPreOpenFun arch fun =
  case fun of
    Lam inner -> [|| Lam $$(liftPreOpenFun arch inner) ||]
    Body body -> [|| Body $$(liftPreOpenExp arch body) ||]
{-# INLINEABLE liftPreOpenExp #-}
-- | Lift a scalar expression over compiled computations into a typed Template
-- Haskell expression for the same expression over executable computations.
--
-- Each constructor is rebuilt as itself, with sub-expressions, scalar
-- functions, tuples and embedded array computations lifted by the local
-- helpers, and the remaining fields lifted by the imported @lift*@ functions.
liftPreOpenExp
    :: forall arch env aenv t. (Embed arch, Typeable arch, Typeable env, Typeable aenv)
    => arch
    -> PreOpenExp (CompiledOpenAcc arch) env aenv t
    -> Q (TExp (PreOpenExp (ExecOpenAcc arch) env aenv t))
liftPreOpenExp arch pexp =
  case pexp of
    -- 'Let' and the function of 'Foreign' recurse through the top-level
    -- functions rather than the local helpers, which are fixed to this
    -- expression's scalar environment.
    -- NOTE(review): presumably because the let-body and the foreign fallback
    -- function live in a different scalar environment; confirm against the AST.
    Let bnd body             -> [|| Let $$(liftPreOpenExp arch bnd) $$(liftPreOpenExp arch body) ||]
    Var ix                   -> [|| Var $$(liftIdx ix) ||]
    Foreign ff f x           -> [|| Foreign $$(liftForeign ff) $$(liftPreOpenFun arch f) $$(embedExp x) ||]
    -- The element type is recovered from the result type alone; the undefined
    -- value is only a type witness.
    Const c                  -> [|| Const $$(liftConst (eltType (undefined::t)) c) ||]
    Tuple tup                -> [|| Tuple $$(embedTup tup) ||]
    Prj tix e                -> [|| Prj $$(liftTupleIdx tix) $$(embedExp e) ||]
    IndexNil                 -> [|| IndexNil ||]
    IndexCons sh sz          -> [|| IndexCons $$(embedExp sh) $$(embedExp sz) ||]
    IndexHead sh             -> [|| IndexHead $$(embedExp sh) ||]
    IndexTail sh             -> [|| IndexTail $$(embedExp sh) ||]
    IndexAny                 -> [|| IndexAny ||]
    IndexSlice slice slix sh -> [|| IndexSlice $$(liftSliceIndex slice) $$(embedExp slix) $$(embedExp sh) ||]
    IndexFull slice slix sl  -> [|| IndexFull $$(liftSliceIndex slice) $$(embedExp slix) $$(embedExp sl) ||]
    ToIndex sh ix            -> [|| ToIndex $$(embedExp sh) $$(embedExp ix) ||]
    FromIndex sh ix          -> [|| FromIndex $$(embedExp sh) $$(embedExp ix) ||]
    Cond p t e               -> [|| Cond $$(embedExp p) $$(embedExp t) $$(embedExp e) ||]
    While p f x              -> [|| While $$(embedFun p) $$(embedFun f) $$(embedExp x) ||]
    PrimConst c              -> [|| PrimConst $$(liftPrimConst c) ||]
    PrimApp f x              -> [|| PrimApp $$(liftPrimFun f) $$(embedExp x) ||]
    Index arr ix             -> [|| Index $$(embedAcc arr) $$(embedExp ix) ||]
    LinearIndex arr ix       -> [|| LinearIndex $$(embedAcc arr) $$(embedExp ix) ||]
    Shape arr                -> [|| Shape $$(embedAcc arr) ||]
    ShapeSize sh             -> [|| ShapeSize $$(embedExp sh) ||]
    Intersect sh1 sh2        -> [|| Intersect $$(embedExp sh1) $$(embedExp sh2) ||]
    Union sh1 sh2            -> [|| Union $$(embedExp sh1) $$(embedExp sh2) ||]
  where
    -- Array computations embedded in the expression.
    embedAcc :: Typeable arrs => CompiledOpenAcc arch aenv arrs -> Q (TExp (ExecOpenAcc arch aenv arrs))
    embedAcc = embedOpenAcc arch

    -- Sub-expressions in the same scalar and array environments.
    embedExp :: PreOpenExp (CompiledOpenAcc arch) env aenv e -> Q (TExp (PreOpenExp (ExecOpenAcc arch) env aenv e))
    embedExp = liftPreOpenExp arch

    -- Scalar functions in the same scalar and array environments.
    embedFun :: PreOpenFun (CompiledOpenAcc arch) env aenv f -> Q (TExp (PreOpenFun (ExecOpenAcc arch) env aenv f))
    embedFun = liftPreOpenFun arch

    -- Tuples of expressions, rebuilt element by element.
    embedTup :: Tuple (PreOpenExp (CompiledOpenAcc arch) env aenv) e -> Q (TExp (Tuple (PreOpenExp (ExecOpenAcc arch) env aenv) e))
    embedTup tp =
      case tp of
        NilTup         -> [|| NilTup ||]
        SnocTup rest e -> [|| SnocTup $$(embedTup rest) $$(embedExp e) ||]
-- | Annotate a typed expression with an explicit type signature.
--
-- The signature is the expression's own type @e@, obtained from its 'Typeable'
-- instance and converted to a Template Haskell type with 'typeRepToType'.
withSigE :: forall e. Typeable e => Q (TExp e) -> Q (TExp e)
withSigE expr = sigE expr (typeRepToType (typeRep (Proxy :: Proxy e)))
-- | Wrap a typed expression in a type annotation: the generated code is
-- @(expression :: type)@.
--
-- The annotation is not checked against @t@; the result is coerced back to a
-- typed expression with 'TH.unsafeTExpCoerce'.
sigE :: Q (TExp t) -> Q TH.Type -> Q (TExp t)
sigE typedExpr annotation =
  TH.unsafeTExpCoerce $ do
    -- Run the expression first and the type second, as TH.sigE does.
    body <- TH.unTypeQ typedExpr
    ty   <- annotation
    return (TH.SigE body ty)
-- | Convert a runtime 'TypeRep' into the Template Haskell 'TH.Type' that names
-- the same type.
--
-- The representation is split into its type constructor and arguments. The
-- constructor becomes a 'TH.ConT' whose global name is built from the
-- constructor's package, module and name; each argument is converted
-- recursively and applied to it from left to right with 'TH.AppT'.
typeRepToType :: TypeRep -> Q TH.Type
typeRepToType rep =
  -- A lazy left fold is fine here: the list is a type's argument list.
  fmap (foldl TH.AppT (TH.ConT conName)) (mapM typeRepToType argReps)
  where
    (con, argReps) = splitTyConApp rep
    conName        = TH.mkNameG_tc (tyConPackage con) (tyConModule con) (tyConName con)