{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# OPTIONS_HADDOCK hide #-}
module Data.Array.Accelerate.LLVM.CodeGen.Monad (
CodeGen,
runLLVM,
fresh, freshName,
declare,
intrinsic,
Block,
newBlock, setBlock, beginBlock, createBlocks,
instr, instr', do_, return_, retval_, br, cbr, phi, phi',
instr_,
addMetadata,
) where
import Control.Applicative
import Control.Monad.State
import Data.ByteString.Short ( ShortByteString )
import Data.Function
import Data.HashMap.Strict ( HashMap )
import Data.Map ( Map )
import Data.Sequence ( Seq )
import Data.String
import Data.Word
import Prelude
import Text.Printf
import qualified Data.Foldable as F
import qualified Data.HashMap.Strict as HashMap
import qualified Data.Map as Map
import qualified Data.Sequence as Seq
import qualified Data.ByteString.Short as B
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Array.Sugar ( Elt, eltType )
import qualified Data.Array.Accelerate.Debug as Debug
import LLVM.AST.Type.Instruction
import LLVM.AST.Type.Metadata
import LLVM.AST.Type.Name
import LLVM.AST.Type.Operand
import LLVM.AST.Type.Representation
import LLVM.AST.Type.Terminator
import Data.Array.Accelerate.LLVM.Target
import Data.Array.Accelerate.LLVM.CodeGen.Downcast
import Data.Array.Accelerate.LLVM.CodeGen.IR
import Data.Array.Accelerate.LLVM.CodeGen.Intrinsic
import Data.Array.Accelerate.LLVM.CodeGen.Module
import Data.Array.Accelerate.LLVM.CodeGen.Type
import Data.Array.Accelerate.LLVM.CodeGen.Sugar ( IROpenAcc(..) )
import qualified LLVM.AST as LLVM
import qualified LLVM.AST.Global as LLVM
-- | The state threaded through the 'CodeGen' monad while a module is being
-- generated.
--
data CodeGenState = CodeGenState
  { blockChain :: Seq Block
    -- ^ basic blocks generated so far; new instructions and terminators are
    -- applied to the last block of the sequence
  , symbolTable :: Map Label LLVM.Global
    -- ^ global definitions registered via 'declare', keyed by label, which
    -- 'runLLVM' emits after the kernels
  , metadataTable :: HashMap ShortByteString (Seq [Maybe Metadata])
    -- ^ metadata nodes accumulated per key by 'addMetadata'
  , intrinsicTable :: HashMap ShortByteString Label
    -- ^ table consulted by 'intrinsic'; initialised from 'intrinsicForTarget'
  , next :: {-# UNPACK #-} !Word
    -- ^ the number that the next call to 'freshName' will hand out
  }
-- | A basic block under construction.
--
data Block = Block
  { blockLabel :: {-# UNPACK #-} !Label
    -- ^ the label identifying this block
  , instructions :: Seq (LLVM.Named LLVM.Instruction)
    -- ^ the instructions of the block, in order
  , terminator :: LLVM.Terminator
    -- ^ the block terminator. This field is lazy: blocks are created with an
    -- error placeholder here, which is only replaced once the block is
    -- terminated.
  }
-- | The code generation monad: a state monad over 'CodeGenState'.
--
newtype CodeGen a = CodeGen { runCodeGen :: State CodeGenState a }
  deriving (Functor, Applicative, Monad, MonadState CodeGenState)
{-# INLINEABLE runLLVM #-}
-- | Run the code generator from a fresh state and package the result as an
-- LLVM module for the target architecture.
--
-- The module definitions are, in order: every generated kernel, every global
-- held in the symbol table, and finally the accumulated metadata. The module
-- is named after the first kernel when that kernel is a function carrying a
-- textual name, and @\<undefined\>@ otherwise.
--
runLLVM
    :: forall arch aenv a. (Target arch, Intrinsic arch)
    => CodeGen (IROpenAcc arch aenv a)
    -> Module arch aenv a
runLLVM ll =
  Module
    { moduleMetadata = kernelInfo
    , unModule       = LLVM.Module
        { LLVM.moduleName           = moduleLabel
        , LLVM.moduleSourceFileName = B.empty
        , LLVM.moduleDataLayout     = targetDataLayout (undefined::arch)
        , LLVM.moduleTargetTriple   = targetTriple (undefined::arch)
        , LLVM.moduleDefinitions    = globals ++ createMetadata (metadataTable finalState)
        }
    }
  where
    -- The state every code generation run starts from
    startState = CodeGenState
      { blockChain     = initBlockChain
      , symbolTable    = Map.empty
      , metadataTable  = HashMap.empty
      , intrinsicTable = intrinsicForTarget (undefined::arch)
      , next           = 0
      }

    -- Kernel functions, their metadata keyed by function name, and the state
    -- left behind by the generator
    (kernelFuns, kernelInfo, finalState) =
      case runState (runCodeGen ll) startState of
        (IROpenAcc ks, s) ->
          ( [ f | Kernel f _ <- ks ]
          , Map.fromList [ (LLVM.name f, info) | Kernel f info <- ks ]
          , s
          )

    globals = map LLVM.GlobalDefinition (kernelFuns ++ Map.elems (symbolTable finalState))

    moduleLabel =
      case kernelFuns of
        f@LLVM.Function{} : _
          | LLVM.Name s <- LLVM.name f -> s
        _                              -> "<undefined>"
-- | The initial block chain: a single, empty block labelled @entry@ whose
-- terminator is an error placeholder until one is set.
--
initBlockChain :: Seq Block
initBlockChain = Seq.singleton entryBlock
  where
    entryBlock = Block
      { blockLabel   = "entry"
      , instructions = Seq.empty
      , terminator   = $internalError "entry" "block has no terminator"
      }
-- | Create a new, empty basic block. The block is /not/ added to the block
-- chain, and the state is left unchanged.
--
-- The label is derived from the given name by inserting the current length of
-- the block chain before the first @.@ of the name (or at the end, if the name
-- contains no @.@).
--
newBlock :: String -> CodeGen Block
newBlock nm = state $ \s -> ( blockFor (Seq.length (blockChain s)), s )
  where
    (prefix, suffix) = break (== '.') nm

    blockFor idx =
      let label = prefix ++ shows idx suffix
      in  Block (fromString label) Seq.empty ($internalError label "Block has no terminator")
-- | Append the given block to the end of the block chain, making it the block
-- that subsequent instructions are added to.
--
setBlock :: Block -> CodeGen ()
setBlock blk = modify pushBlock
  where
    pushBlock s = s { blockChain = blockChain s Seq.|> blk }
-- | Create a new block, terminate the current block with an unconditional
-- branch to it, and then append the new block to the chain. The new block is
-- returned.
--
beginBlock :: String -> CodeGen Block
beginBlock nm =
  newBlock nm >>= \blk ->
    br blk >> setBlock blk >> return blk
-- | Extract the accumulated block chain as a list of LLVM basic blocks, and
-- reset the generator: the chain goes back to the single entry block and the
-- fresh name counter restarts from zero.
--
-- A debug message reports the number of blocks and instructions generated;
-- the instruction count includes one terminator per block.
--
createBlocks :: CodeGen [LLVM.BasicBlock]
createBlocks = state $ \s ->
  let chain     = blockChain s
      numBlocks = Seq.length chain
      -- seed the fold with the block count to account for the terminators
      numInstrs = F.foldl' (\acc b -> acc + Seq.length (instructions b)) numBlocks chain
      reset     = s { blockChain = initBlockChain, next = 0 }
  in
  trace (printf "generated %d instructions in %d blocks" numInstrs numBlocks)
        ( F.toList (fmap toBasicBlock chain), reset )
  where
    toBasicBlock b =
      LLVM.BasicBlock (downcast (blockLabel b)) (F.toList (instructions b)) (LLVM.Do (terminator b))
-- | Generate fresh local references for every scalar component of the element
-- type @a@. Names are allocated left-to-right over the tuple structure.
--
fresh :: forall a. Elt a => CodeGen (IR a)
fresh = fmap IR (freshOperands (eltType (undefined::a)))
  where
    freshOperands :: TupleType t -> CodeGen (Operands t)
    freshOperands tt =
      case tt of
        UnitTuple       -> return OP_Unit
        PairTuple tx ty -> do
          x <- freshOperands tx
          y <- freshOperands ty
          return (OP_Pair x y)
        SingleTuple t   -> do
          v <- freshName
          return (ir' t (LocalReference (PrimType (ScalarPrimType t)) v))
-- | Hand out the next unused unnamed (numbered) identifier and advance the
-- counter.
--
freshName :: CodeGen (Name a)
freshName = state allocate
  where
    allocate s@CodeGenState{ next = n } = ( UnName n, s { next = n + 1 } )
-- | Add an instruction to the current block and return its result wrapped as
-- an 'IR' value.
--
instr :: Instruction a -> CodeGen (IR a)
instr ins = do
  result <- instr' ins
  return (ir (typeOf ins) result)
-- | Add an instruction to the current block and return the operand that
-- refers to its result.
--
-- Instructions of void type are emitted without binding a name; the operand
-- returned for them is a void-typed reference with an empty name. All other
-- instructions are bound to a fresh name.
--
instr' :: Instruction a -> CodeGen (Operand a)
instr' ins =
  case typeOf ins of
    VoidType ->
      do_ ins >> return (LocalReference VoidType (Name B.empty))
    ty ->
      freshName >>= \v ->
        instr_ (downcast (v := ins)) >> return (LocalReference ty v)
-- | Execute an instruction that produces no value, without binding its result
-- to a name.
--
do_ :: Instruction () -> CodeGen ()
do_ = instr_ . downcast . Do
-- | Append a raw (already lowered) instruction to the end of the current
-- block, i.e. the last block of the chain. It is an internal error for the
-- block chain to be empty.
--
instr_ :: LLVM.Named LLVM.Instruction -> CodeGen ()
instr_ ins = modify appendToCurrent
  where
    appendToCurrent s
      | rest Seq.:> current <- Seq.viewr (blockChain s)
      = s { blockChain = rest Seq.|> current { instructions = instructions current Seq.|> ins } }
      | otherwise
      = $internalError "instr_" "empty block chain"
-- | Terminate the current block with a return that carries no value.
--
return_ :: CodeGen ()
return_ = terminate Ret >> return ()
-- | Terminate the current block by returning the given operand.
--
retval_ :: Operand a -> CodeGen ()
retval_ x = terminate (RetVal x) >> return ()
-- | Terminate the current block with an unconditional branch to the target
-- block. Returns the block that was terminated.
--
br :: Block -> CodeGen Block
br = terminate . Br . blockLabel
-- | Terminate the current block with a conditional branch: control moves to
-- the first block when the condition holds and to the second otherwise.
-- Returns the block that was terminated.
--
cbr :: IR Bool -> Block -> Block -> CodeGen Block
cbr cond whenTrue whenFalse =
  let test = op scalarType cond
  in  terminate (CondBr test (blockLabel whenTrue) (blockLabel whenFalse))
-- | Add a phi node to the /current/ block (the last block of the chain),
-- merging the given incoming values. Fresh names are allocated for the result.
--
phi :: forall a. Elt a => [(IR a, Block)] -> CodeGen (IR a)
phi incoming = do
  crit  <- fresh
  block <- state currentBlock
  phi' block crit incoming
  where
    -- read the last block of the chain, leaving the state untouched
    currentBlock s
      | _ Seq.:> b <- Seq.viewr (blockChain s) = ( b, s )
      | otherwise                              = $internalError "phi" "empty block chain"
-- | Add a phi node to the given block, binding the result to the supplied
-- critical variable. One phi instruction is emitted per scalar component of
-- the element type. It is an internal error if a component of the critical
-- variable is not a local reference.
--
phi' :: forall a. Elt a => Block -> IR a -> [(IR a, Block)] -> CodeGen (IR a)
phi' target (IR crit) incoming =
  fmap IR (merge (eltType (undefined::a)) crit [ (o, blk) | (IR o, blk) <- incoming ])
  where
    merge :: TupleType t -> Operands t -> [(Operands t, Block)] -> CodeGen (Operands t)
    merge UnitTuple OP_Unit _ =
      return OP_Unit
    merge (PairTuple tx ty) (OP_Pair cx cy) inc = do
      rx <- merge tx cx [ (x, blk) | (OP_Pair x _, blk) <- inc ]
      ry <- merge ty cy [ (y, blk) | (OP_Pair _ y, blk) <- inc ]
      return (OP_Pair rx ry)
    merge (SingleTuple t) c inc =
      case op' t c of
        LocalReference _ v -> ir' t <$> phi1 target v [ (op' t x, blk) | (x, blk) <- inc ]
        _                  -> $internalError "phi" "expected critical variable to be local reference"
-- | Insert a single phi instruction at the /front/ of the target block,
-- binding its result to the given name, and return a reference to that result.
--
-- The type of the node is taken from the first incoming operand. It is an
-- internal error if there are no incoming values, if that operand has void
-- type, or if no block in the chain carries the target's label. When several
-- blocks share the label, the right-most one is updated.
--
phi1 :: Block -> Name a -> [(Operand a, Block)] -> CodeGen (Operand a)
phi1 target crit incoming = state insertPhi
  where
    insertPhi s =
      case Seq.findIndexR sameLabel (blockChain s) of
        Nothing -> $internalError "phi" "unknown basic block"
        Just i  -> ( LocalReference (PrimType ty) crit
                   , s { blockChain = Seq.adjust prepend i (blockChain s) } )

    sameLabel blk = blockLabel target == blockLabel blk

    prepend blk = blk { instructions = node Seq.<| instructions blk }

    node :: LLVM.Named LLVM.Instruction
    node = downcast (crit := Phi ty [ (p, blockLabel from) | (p, from) <- incoming ])

    ty = case incoming of
           []         -> $internalError "phi" "no incoming values specified"
           (o, _) : _ -> case typeOf o of
                           VoidType   -> $internalError "phi" "operand has void type"
                           PrimType x -> x
-- | Set the terminator of the current block (the last block of the chain).
--
-- The block returned is the current block as it was /before/ the terminator
-- was installed. It is an internal error for the block chain to be empty.
--
terminate :: Terminator a -> CodeGen Block
terminate term = state seal
  where
    seal s
      | rest Seq.:> current <- Seq.viewr (blockChain s)
      = ( current, s { blockChain = rest Seq.|> current { terminator = downcast term } } )
      | otherwise
      = $internalError "terminate" "empty block chain"
-- | Register a global definition in the symbol table, keyed by its name.
--
-- Registering the same definition again is harmless; registering a
-- /different/ definition under an existing name is an internal error.
--
declare :: LLVM.Global -> CodeGen ()
declare g = modify $ \s -> s { symbolTable = Map.alter register key (symbolTable s) }
  where
    -- numbered globals are keyed by the textual form of their number
    key = case LLVM.name g of
            LLVM.Name n   -> Label n
            LLVM.UnName n -> Label (fromString (show n))

    register existing =
      case existing of
        Just q | g /= q -> $internalError "global" "duplicate symbol"
        _               -> Just g
-- | Look up the label registered for the given name in the intrinsic table.
-- When the table has no entry, the name itself is used as the label.
--
intrinsic :: ShortByteString -> CodeGen Label
intrinsic key = state $ \s -> ( resolve (intrinsicTable s), s )
  where
    resolve = HashMap.lookupDefault (Label key) key
-- | Record a metadata node under the given key. Nodes recorded under the same
-- key are kept in the order in which they were added.
--
addMetadata :: ShortByteString -> [Maybe Metadata] -> CodeGen ()
addMetadata key val = modify record
  where
    record s = s { metadataTable = HashMap.insertWith append key (Seq.singleton val) (metadataTable s) }

    -- 'HashMap.insertWith' passes the new value first; keep existing entries in front
    append new old = old Seq.>< new
-- | Convert the metadata table into LLVM definitions.
--
-- Each key yields one named metadata definition, which refers to one node
-- definition per value stored under that key. Node identifiers are numbered
-- consecutively across all keys. All named definitions come first in the
-- result, followed by all node definitions.
--
createMetadata :: HashMap ShortByteString (Seq [Maybe Metadata]) -> [LLVM.Definition]
createMetadata md =
  let (named, nodes) = F.foldl' step (Seq.empty, Seq.empty) (HashMap.toList md)
  in  F.toList (named Seq.>< nodes)
  where
    step :: (Seq LLVM.Definition, Seq LLVM.Definition)
         -> (ShortByteString, Seq [Maybe Metadata])
         -> (Seq LLVM.Definition, Seq LLVM.Definition)
    step (named, nodes) (key, vals) =
      let -- identifiers continue on from the nodes emitted for earlier keys
          offset   = Seq.length nodes
          nodeID i = LLVM.MetadataNodeID (fromIntegral (i + offset))
          entry    = LLVM.NamedMetadataDefinition key [ nodeID i | i <- [0 .. Seq.length vals - 1] ]
          defs     = Seq.mapWithIndex (\i x -> LLVM.MetadataNodeDefinition (nodeID i) (downcast (F.toList x))) vals
      in
      ( named Seq.|> entry, nodes Seq.>< defs )
{-# INLINE trace #-}
-- | Emit a debug message, prefixed with @llvm: @, under the @dump_cc@ debug
-- flag.
--
trace :: String -> a -> a
trace msg = Debug.trace Debug.dump_cc (showString "llvm: " msg)