{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_HADDOCK hide #-}
module Data.Array.Accelerate.LLVM.CodeGen.Base (
Name(..),
local, global,
irArray,
mutableArray,
call,
scalarParameter, ptrParameter,
envParam,
arrayParam,
) where
import LLVM.AST.Type.AddrSpace
import LLVM.AST.Type.Constant
import LLVM.AST.Type.Global
import LLVM.AST.Type.Instruction
import LLVM.AST.Type.Instruction.Volatile
import LLVM.AST.Type.Name
import LLVM.AST.Type.Operand
import LLVM.AST.Type.Representation
import Data.Array.Accelerate.AST
import Data.Array.Accelerate.Array.Sugar
import Data.Array.Accelerate.LLVM.CodeGen.Downcast
import Data.Array.Accelerate.LLVM.CodeGen.Environment
import Data.Array.Accelerate.LLVM.CodeGen.IR
import Data.Array.Accelerate.LLVM.CodeGen.Monad
import Data.Array.Accelerate.LLVM.CodeGen.Sugar
import qualified LLVM.AST.Global as LLVM
import qualified Data.IntMap as IM
-- | Refer to a named LLVM local value of the given scalar type, wrapped as an
-- 'IR' term.
local :: ScalarType a -> Name a -> IR a
local st nm = ir st ref
  where
    ref = LocalReference (PrimType (ScalarPrimType st)) nm
-- | Refer to a named global symbol of the given scalar type as a constant
-- operand, wrapped as an 'IR' term.
global :: ScalarType a -> Name a -> IR a
global st nm = ir st ref
  where
    ref = ConstantOperand (GlobalReference (PrimType (ScalarPrimType st)) nm)
-- | Name of the @i@-th data component of an array: the array's name followed
-- by @\".ad\"@ and the decimal index. An 'UnName' is rendered with 'show'
-- before the suffix is appended.
arrayName :: Name (Array sh e) -> Int -> Name e'
arrayName nm i =
  case nm of
    Name base -> Name (base ++ suffix)
    UnName k  -> Name (show k ++ suffix)
  where
    suffix = ".ad" ++ show i
-- | Name of the @i@-th shape component of an array: the array's name followed
-- by @\".sh\"@ and the decimal index. An 'UnName' is rendered with 'show'
-- before the suffix is appended.
shapeName :: Name (Array sh e) -> Int -> Name sh'
shapeName nm i =
  case nm of
    Name base -> Name (base ++ suffix)
    UnName k  -> Name (show k ++ suffix)
  where
    suffix = ".sh" ++ show i
{-# INLINEABLE irArray #-}
-- | The 'IRArray' view of the array named @n@.
--
-- Each scalar leaf of the shape is a local reference named by 'shapeName', and
-- each scalar leaf of the element type is a local reference named by
-- 'arrayName'. Leaves are numbered by 'travTypeToIR'. These are the names
-- 'arrayParam' gives the corresponding kernel parameters. The array is placed
-- in 'defaultAddrSpace' and marked 'NonVolatile'.
--
-- NOTE(review): the data leaves are typed with the element's scalar type, even
-- though 'arrayParam' declares those parameters as pointers. Presumably the
-- code that reads the array reinterprets them; confirm before relying on
-- these operand types.
irArray
    :: forall sh e. (Shape sh, Elt e)
    => Name (Array sh e)
    -> IRArray (Array sh e)
irArray n = IRArray shapeOps dataOps defaultAddrSpace NonVolatile
  where
    shapeOps = travTypeToIR (undefined :: sh) (\t i -> localOperand t (shapeName n i))
    dataOps  = travTypeToIR (undefined :: e)  (\t i -> localOperand t (arrayName n i))

    localOperand :: ScalarType s -> Name s -> Operand s
    localOperand t = LocalReference (PrimType (ScalarPrimType t))
{-# INLINEABLE mutableArray #-}
-- | The 'IRArray' view of the array named @name@ (from 'irArray'), paired with
-- the kernel parameters that bind it (from 'arrayParam').
mutableArray
    :: forall sh e. (Shape sh, Elt e)
    => Name (Array sh e)
    -> (IRArray (Array sh e), [LLVM.Parameter])
mutableArray name = (arr, params)
  where
    arr    = irArray name
    params = arrayParam name
{-# INLINEABLE travTypeToList #-}
-- | Apply @f@ to every scalar leaf of the element type's tuple representation,
-- together with a leaf index, and collect the results in a list.
--
-- Indices start at 0. In every 'PairTuple' the leaves of the second field are
-- numbered before those of the first field, but the output list places the
-- first field's results before the second field's. 'travTypeToIR' uses the
-- same numbering, so the names derived from these indices agree with it.
--
-- The witness @t@ is only used for its type. Callers in this module pass
-- @undefined@, so it must not be forced.
--
-- The list is built with an accumulator rather than @(++)@. Tuple
-- representations nest to the left, and appending at each 'PairTuple' would
-- cost time quadratic in the number of components.
travTypeToList
    :: forall t a. Elt t
    => t
    -> (forall s. ScalarType s -> Int -> a)
    -> [a]
travTypeToList t f = snd (go (eltType t) 0 [])
  where
    -- @go ty i acc@ returns the next unused index and the results for the
    -- leaves of @ty@ prepended to @acc@.
    go :: TupleType s -> Int -> [a] -> (Int, [a])
    go UnitTuple         i acc = (i, acc)
    go (SingleTuple t')  i acc = (i + 1, f t' i : acc)
    go (PairTuple t2 t1) i acc =
      let (i1, acc1) = go t1 i  acc   -- second field: lower indices, later in the list
          (i2, acc2) = go t2 i1 acc1  -- first field: prepended in front of it
      in  (i2, acc2)
-- | Build an 'IR' value for the element type of @witness@. @mkOperand@ is
-- applied to every scalar leaf of its tuple representation, together with a
-- leaf index.
--
-- Indices start at 0. In every 'PairTuple' the second field's leaves are
-- numbered before the first field's, matching 'travTypeToList'. The witness is
-- only used for its type; 'irArray' passes @undefined@.
travTypeToIR
    :: Elt t
    => t
    -> (forall s. ScalarType s -> Int -> Operand s)
    -> IR t
travTypeToIR witness mkOperand = IR (snd (walk (eltType witness) 0))
  where
    -- Returns the next unused index alongside the operands built for @ty@.
    walk :: TupleType s -> Int -> (Int, Operands s)
    walk ty next =
      case ty of
        UnitTuple         -> (next, OP_Unit)
        SingleTuple st    -> (next + 1, ir' st (mkOperand st next))
        PairTuple lhs rhs ->
          let (afterRhs, rhsOps) = walk rhs next
              (afterLhs, lhsOps) = walk lhs afterRhs
          in  (afterLhs, OP_Pair lhsOps rhsOps)
-- | Emit a call to the global function @f@ with the given function attributes.
--
-- The function is first registered with 'declare'. The attributes, each
-- wrapped in 'Right', are attached both to that declaration and to the 'Call'
-- instruction, whose result is returned.
call :: GlobalFunction args t -> [FunctionAttribute] -> CodeGen (IR t)
call f attrs = declare declaration >> instr (Call f attributes)
  where
    attributes  = map Right attrs
    declaration = (downcast f) { LLVM.functionAttributes = downcast attributes }
-- | A named function parameter of the given scalar type, converted with
-- 'downcast' to an untyped 'LLVM.Parameter'.
scalarParameter :: ScalarType t -> Name t -> LLVM.Parameter
scalarParameter t = downcast . Parameter (ScalarPrimType t)
-- | A named function parameter whose type is a pointer, in 'defaultAddrSpace',
-- to the given scalar type. It is converted with 'downcast' to an untyped
-- 'LLVM.Parameter'.
ptrParameter :: ScalarType t -> Name (Ptr t) -> LLVM.Parameter
ptrParameter t x = downcast (Parameter ptrTy x)
  where
    ptrTy = PtrPrimType (ScalarPrimType t) defaultAddrSpace
-- | Kernel parameters for every array in the environment map, taken in
-- ascending key order. Each entry contributes the parameters from
-- 'arrayParam', using the string carried by its 'Label' as the array name.
envParam :: forall aenv. Gamma aenv -> [LLVM.Parameter]
envParam aenv = IM.foldr (\(Label n, Idx' v) rest -> toParam v (Name n) ++ rest) [] aenv
  where
    -- The index argument is unused. It only fixes the @sh@ and @e@ types that
    -- 'arrayParam' needs.
    toParam :: forall sh e. (Shape sh, Elt e) => Idx aenv (Array sh e) -> Name (Array sh e) -> [LLVM.Parameter]
    toParam _ = arrayParam
{-# INLINEABLE arrayParam #-}
-- | Kernel parameters for the array named @name@. The list starts with one
-- pointer parameter per scalar leaf of the element type (named by
-- 'arrayName'), followed by one scalar parameter per scalar leaf of the shape
-- (named by 'shapeName'). Leaves are numbered by 'travTypeToList'.
arrayParam
    :: forall sh e. (Shape sh, Elt e)
    => Name (Array sh e)
    -> [LLVM.Parameter]
arrayParam name = dataParams ++ shapeParams
  where
    dataParams  = travTypeToList (undefined :: e)  (\st k -> ptrParameter    st (arrayName name k))
    shapeParams = travTypeToList (undefined :: sh) (\st k -> scalarParameter st (shapeName name k))