{-# LANGUAGE GADTs                      #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings          #-}
{-# LANGUAGE RecordWildCards            #-}
{-# LANGUAGE ScopedTypeVariables        #-}
{-# LANGUAGE TemplateHaskell            #-}
{-# OPTIONS_HADDOCK hide #-}
-- |
-- Module      : Data.Array.Accelerate.LLVM.CodeGen.Monad
-- Copyright   : [2015..2017] Trevor L. McDonell
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <tmcdonell@cse.unsw.edu.au>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.LLVM.CodeGen.Monad (

  CodeGen,
  runLLVM,

  -- declarations
  fresh, freshName,
  declare,
  intrinsic,

  -- basic blocks
  Block,
  newBlock, setBlock, beginBlock, createBlocks,

  -- instructions
  instr, instr', do_, return_, retval_, br, cbr, phi, phi',
  instr_,

  -- metadata
  addMetadata,

) where

-- standard library
import Control.Applicative
import Control.Monad.State
import Data.ByteString.Short                                        ( ShortByteString )
import Data.Function
import Data.HashMap.Strict                                          ( HashMap )
import Data.Map                                                     ( Map )
import Data.Sequence                                                ( Seq )
import Data.String
import Data.Word
import Prelude
import Text.Printf
import qualified Data.Foldable                                      as F
import qualified Data.HashMap.Strict                                as HashMap
import qualified Data.Map                                           as Map
import qualified Data.Sequence                                      as Seq
import qualified Data.ByteString.Short                              as B

-- accelerate
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Array.Sugar                            ( Elt, eltType )
import qualified Data.Array.Accelerate.Debug                        as Debug

-- accelerate-llvm
import LLVM.AST.Type.Instruction
import LLVM.AST.Type.Metadata
import LLVM.AST.Type.Name
import LLVM.AST.Type.Operand
import LLVM.AST.Type.Representation
import LLVM.AST.Type.Terminator

import Data.Array.Accelerate.LLVM.Target
import Data.Array.Accelerate.LLVM.CodeGen.Downcast
import Data.Array.Accelerate.LLVM.CodeGen.IR
import Data.Array.Accelerate.LLVM.CodeGen.Intrinsic
import Data.Array.Accelerate.LLVM.CodeGen.Module
import Data.Array.Accelerate.LLVM.CodeGen.Type

import Data.Array.Accelerate.LLVM.CodeGen.Sugar                     ( IROpenAcc(..) )

-- llvm-hs
import qualified LLVM.AST                                           as LLVM
import qualified LLVM.AST.Global                                    as LLVM


-- Code generation
-- ===============

-- | The code generation state for scalar functions and expressions.
--
-- We use two records: one to hold all the code generation state as it walks the
-- AST, and one for each of the basic blocks that are generated during the walk.
--
-- The block chain and the name supply are per-function state: 'createBlocks'
-- resets both of them once the blocks of a function body have been extracted.
-- The remaining tables are left untouched by 'createBlocks'.
--
data CodeGenState = CodeGenState
  { blockChain          :: Seq Block                                      -- blocks for this function; the last block is the one instructions are added to
  , symbolTable         :: Map Label LLVM.Global                          -- global (external) function declarations, added via 'declare'
  , metadataTable       :: HashMap ShortByteString (Seq [Maybe Metadata]) -- module metadata to be collected, added via 'addMetadata'
  , intrinsicTable      :: HashMap ShortByteString Label                  -- standard math intrinsic functions, queried via 'intrinsic'
  , next                :: {-# UNPACK #-} !Word                           -- a name supply: the number used by the next 'freshName'
  }

-- | A single basic block while it is under construction.
--
data Block = Block
  { blockLabel          :: {-# UNPACK #-} !Label                          -- block label
  , instructions        :: Seq (LLVM.Named LLVM.Instruction)              -- instructions in order: 'instr_' appends on the right, 'phi1' prepends phi nodes on the left
  , terminator          :: LLVM.Terminator                                -- block terminator; an internal-error placeholder until set by 'terminate'
  }

-- | The code generation monad: a 'State' computation threading the
-- 'CodeGenState' through the walk. Run it with 'runLLVM'.
--
newtype CodeGen a = CodeGen { runCodeGen :: State CodeGenState a }
  deriving (Functor, Applicative, Monad, MonadState CodeGenState)


-- | Run the code generator and package the result as a module for the given
-- target architecture.
--
-- The module definitions consist of the generated kernel functions, followed
-- by every global collected in the symbol table, followed by the metadata
-- definitions. The module takes the name of the first kernel, provided that
-- kernel is a function with a textual name; otherwise it is called
-- @\<undefined\>@.
--
{-# INLINEABLE runLLVM #-}
runLLVM
    :: forall arch aenv a. (Target arch, Intrinsic arch)
    => CodeGen (IROpenAcc arch aenv a)
    -> Module arch aenv a
runLLVM ll =
  Module { moduleMetadata = kernelMD
         , unModule       = LLVM.Module
                          { LLVM.moduleName           = moduleLabel
                          , LLVM.moduleSourceFileName = B.empty
                          , LLVM.moduleDataLayout     = targetDataLayout (undefined::arch)
                          , LLVM.moduleTargetTriple   = targetTriple (undefined::arch)
                          , LLVM.moduleDefinitions    = defs
                          }
         }
  where
    -- Start from a single empty entry block and no declarations.
    state0 = CodeGenState
               { blockChain     = initBlockChain
               , symbolTable    = Map.empty
               , metadataTable  = HashMap.empty
               , intrinsicTable = intrinsicForTarget (undefined::arch)
               , next           = 0
               }

    -- Separate each kernel into its function definition and its metadata,
    -- the latter keyed by the name of the function.
    (funs, kernelMD, finalState) =
      case runState (runCodeGen ll) state0 of
        (IROpenAcc ks, s) ->
          let (fs, mds) = unzip [ (f , (LLVM.name f, m)) | Kernel f m <- ks ]
          in  (fs, Map.fromList mds, s)

    defs = map LLVM.GlobalDefinition (funs ++ Map.elems (symbolTable finalState))
        ++ createMetadata (metadataTable finalState)

    moduleLabel =
      case funs of
        f@LLVM.Function{} : _
          | LLVM.Name s <- LLVM.name f -> s
        _                              -> "<undefined>"


-- Basic Blocks
-- ============

-- | The block chain every function starts with: a single, empty block
-- labelled @entry@ whose terminator has not been set yet.
--
initBlockChain :: Seq Block
initBlockChain = Seq.singleton entry
  where
    entry = Block
      { blockLabel   = "entry"
      , instructions = Seq.empty
      , terminator   = $internalError "entry" "block has no terminator"
      }


-- | Construct a fresh basic block without inserting it into the block chain.
-- The block only becomes the target of subsequent instructions once it has
-- been appended to the chain with 'setBlock'.
--
-- Note: [Basic blocks]
--
-- The label of a block is derived from the base name given to 'newBlock'
-- together with the number of blocks currently in the chain: that number is
-- spliced in before the first @.@ of the name (or at the end, if there is
-- none). Because new blocks are not added to the chain straight away:
--
--   1. Instructions are generated "in order" and are always appended to the
--      end of the stream; the stream never has to be searched for a block
--      with the right name.
--
--   2. Blocks created together are named as a group, which helps
--      readability. The then and else branches of a conditional, for
--      example, are named alike: 'if4.then' and 'if4.else'.
--
-- The price is a slight awkwardness when walking the AST. A new naming group
-- only starts *after* a call to 'setBlock', so encountering (say) nested
-- conditionals during the walk produces logically distinct blocks that share
-- a name. Instructions might then be added to the wrong block, or the first
-- set of blocks might be emitted empty and/or without a terminator.
--
newBlock :: String -> CodeGen Block
newBlock nm = gets (mkBlock . Seq.length . blockChain)
  where
    -- Split the requested name at the first dot, where the index is spliced in.
    (base, suffix) = break (== '.') nm

    mkBlock idx =
      let label = base ++ shows idx suffix
      in  Block (fromString label) Seq.empty ($internalError label "Block has no terminator")


-- | Append the given block to the end of the block chain, making it the block
-- that 'instr' and friends add their instructions to from now on.
--
setBlock :: Block -> CodeGen ()
setBlock blk = modify append
  where
    append s = s { blockChain = blockChain s Seq.|> blk }


-- | Create a new block, terminate the current block with an unconditional
-- branch to it, and make the new block the current one. The new block is
-- returned.
--
beginBlock :: String -> CodeGen Block
beginBlock nm = newBlock nm >>= enter
  where
    enter blk = do
      void (br blk)
      setBlock blk
      return blk


-- | Convert the accumulated block chain into the list of basic blocks that
-- make up a function body. Afterwards the block chain is reset to the initial
-- chain and the name supply restarts from zero; module-level state such as the
-- global symbol table is carried over unchanged.
--
-- When the relevant debug flag is set, the number of blocks and the number of
-- instructions (counting one terminator per block) is reported.
--
createBlocks :: CodeGen [LLVM.BasicBlock]
createBlocks = state finalise
  where
    finalise s =
      let chain    = blockChain s
          nBlocks  = Seq.length chain
          nInstrs  = F.foldl' (\acc blk -> acc + Seq.length (instructions blk)) 0 chain
          reset    = s { blockChain = initBlockChain, next = 0 }
      in
      trace (printf "generated %d instructions in %d blocks" (nInstrs + nBlocks) nBlocks)
            ( F.toList (fmap toBasicBlock chain), reset )

    toBasicBlock blk =
      LLVM.BasicBlock (downcast (blockLabel blk))
                      (F.toList (instructions blk))
                      (LLVM.Do (terminator blk))


-- Instructions
-- ------------

-- | Generate fresh local references for a value of the given element type:
-- one freshly named reference for every scalar component of its
-- representation.
--
fresh :: forall a. Elt a => CodeGen (IR a)
fresh = fmap IR (travT (eltType (undefined::a)))
  where
    travT :: TupleType t -> CodeGen (Operands t)
    travT tup =
      case tup of
        UnitTuple       -> return OP_Unit
        PairTuple t2 t1 -> do
          x <- travT t2
          y <- travT t1
          return (OP_Pair x y)
        SingleTuple t   -> do
          nm <- freshName
          return (ir' t (LocalReference (PrimType (ScalarPrimType t)) nm))

-- | Draw the next number from the name supply and return it as an unnamed
-- (numbered) name, advancing the supply by one.
--
freshName :: CodeGen (Name a)
freshName = state supply
  where
    supply s@CodeGenState{ next = n } = ( UnName n, s { next = n + 1 } )


-- | Append an instruction to the currently active block so that it gets
-- computed, and return the operand (a local reference) through which its
-- result can be referred to later on.
--
instr :: Instruction a -> CodeGen (IR a)
instr ins = fmap (ir (typeOf ins)) (instr' ins)

-- | As 'instr', but return the bare operand rather than wrapping it up as an
-- 'IR' value.
--
instr' :: Instruction a -> CodeGen (Operand a)
instr' ins =
  case typeOf ins of
    -- LLVM-5 does not allow instructions of type void to have a name, so
    -- these are emitted unnamed and a reference with an empty name is
    -- handed back.
    VoidType -> do_ ins >> return (LocalReference VoidType (Name B.empty))
    --
    ty       -> freshName >>= \v -> do
      instr_ (downcast (v := ins))
      return (LocalReference ty v)

-- | Append an instruction to the current block without binding its result to
-- a name.
--
do_ :: Instruction () -> CodeGen ()
do_ = instr_ . downcast . Do

-- | Append an already lowered (llvm-hs) instruction to the end of the last
-- block in the block chain. It is an internal error for the chain to be
-- empty.
--
instr_ :: LLVM.Named LLVM.Instruction -> CodeGen ()
instr_ ins = modify push
  where
    push s =
      case Seq.viewr (blockChain s) of
        rest Seq.:> cur -> s { blockChain = rest Seq.|> cur { instructions = instructions cur Seq.|> ins } }
        Seq.EmptyR      -> $internalError "instr_" "empty block chain"


-- | Terminate the current basic block by returning void.
--
return_ :: CodeGen ()
return_ = () <$ terminate Ret

-- | Terminate the current basic block by returning the given value.
--
retval_ :: Operand a -> CodeGen ()
retval_ = void . terminate . RetVal


-- | Terminate the current block with an unconditional branch to the given
-- block. Returns the block that was branched from, i.e. the one that was just
-- terminated.
--
br :: Block -> CodeGen Block
br = terminate . Br . blockLabel


-- | Terminate the current block with a conditional branch: control passes to
-- the first block if the condition holds and to the second block otherwise.
-- Returns the block that was branched from, i.e. the one that was just
-- terminated.
--
cbr :: IR Bool -> Block -> Block -> CodeGen Block
cbr cond t f = terminate (CondBr test (blockLabel t) (blockLabel f))
  where
    test = op scalarType cond


-- | Insert a phi node at the top of the current block (the last block of the
-- chain), binding the result to freshly generated names. Each incoming value
-- is paired with the block that control arrives from.
--
phi :: forall a. Elt a => [(IR a, Block)] -> CodeGen (IR a)
phi incoming = do
  crit <- fresh
  cur  <- state currentBlock
  phi' cur crit incoming
  where
    currentBlock s =
      case Seq.viewr (blockChain s) of
        _ Seq.:> b -> ( b, s )
        Seq.EmptyR -> $internalError "phi" "empty block chain"

-- | Insert phi nodes at the top of the given target block, binding the results
-- to the supplied critical variables rather than to fresh names. One phi node
-- is created per scalar component of the value; every critical variable must
-- be a local reference.
--
phi' :: forall a. Elt a => Block -> IR a -> [(IR a, Block)] -> CodeGen (IR a)
phi' target (IR crit) incoming = IR <$> travT (eltType (undefined::a)) crit [ (o,b) | (IR o, b) <- incoming ]
  where
    travT :: TupleType t -> Operands t -> [(Operands t, Block)] -> CodeGen (Operands t)
    travT UnitTuple OP_Unit _
      = return OP_Unit
    -- Split every incoming pair so each component gets its own phi node(s).
    travT (PairTuple t2 t1) (OP_Pair c2 c1) inc = do
      r2 <- travT t2 c2 [ (x, blk) | (OP_Pair x _, blk) <- inc ]
      r1 <- travT t1 c1 [ (y, blk) | (OP_Pair _ y, blk) <- inc ]
      return (OP_Pair r2 r1)
    travT (SingleTuple ty) c inc =
      case op' ty c of
        LocalReference _ v -> ir' ty <$> phi1 target v [ (op' ty x, blk) | (x, blk) <- inc ]
        _                  -> $internalError "phi" "expected critical variable to be local reference"


-- | Insert a single phi node, bound to the given name, at the front of the
-- instructions of the target block. The target is located by label, searching
-- the block chain from the end. The type of the node is taken from the first
-- incoming operand.
--
-- It is an internal error if no incoming values are given, if the operands
-- have void type, or if no block in the chain carries the target's label.
--
phi1 :: Block -> Name a -> [(Operand a, Block)] -> CodeGen (Operand a)
phi1 target crit incoming = state insertPhi
  where
    insertPhi s =
      case Seq.findIndexR isTarget (blockChain s) of
        Just i  -> ( LocalReference (PrimType ty) crit
                   , s { blockChain = Seq.adjust prepend i (blockChain s) } )
        Nothing -> $internalError "phi" "unknown basic block"

    isTarget blk = blockLabel target == blockLabel blk

    -- Phi nodes go in front of everything already in the block.
    prepend blk  = blk { instructions = downcast (crit := Phi ty edges) Seq.<| instructions blk }

    edges        = [ (p, lbl) | (p, Block{ blockLabel = lbl }) <- incoming ]

    ty           = case incoming of
                     []      -> $internalError "phi" "no incoming values specified"
                     (o,_):_ -> case typeOf o of
                                  VoidType   -> $internalError "phi" "operand has void type"
                                  PrimType x -> x


-- | Set the terminator of the current block (the last block of the chain).
-- The block is returned as it was just before the terminator was set. It is
-- an internal error for the chain to be empty.
--
terminate :: Terminator a -> CodeGen Block
terminate term = state close
  where
    close s =
      case Seq.viewr (blockChain s) of
        rest Seq.:> cur -> ( cur, s { blockChain = rest Seq.|> cur { terminator = downcast term } } )
        Seq.EmptyR      -> $internalError "terminate" "empty block chain"


-- | Record a global declaration in the symbol table, keyed by its name
-- (unnamed globals are keyed by the textual form of their number).
-- Declaring the same global again is harmless, but it is an internal error
-- to declare a different global under a name that is already taken.
--
declare :: LLVM.Global -> CodeGen ()
declare g = modify $ \s -> s { symbolTable = Map.alter check key (symbolTable s) }
  where
    key = case LLVM.name g of
            LLVM.Name n   -> Label n
            LLVM.UnName n -> Label (fromString (show n))

    check existing =
      case existing of
        Just q | g /= q -> $internalError "global" "duplicate symbol"
        _               -> Just g


-- | Look up the intrinsic function that implements the given C function in
-- the intrinsic table. If the table has no entry for it, the name of the C
-- function itself is used as the label.
--
intrinsic :: ShortByteString -> CodeGen Label
intrinsic key = gets (HashMap.lookupDefault (Label key) key . intrinsicTable)



-- Metadata
-- ========

-- | Attach a metadata node to the named metadata of the current module.
-- Nodes added under a key that already exists are placed after the nodes
-- recorded for it earlier.
--
addMetadata :: ShortByteString -> [Maybe Metadata] -> CodeGen ()
addMetadata key val = modify insert
  where
    insert s       = s { metadataTable = HashMap.insertWith append key (Seq.singleton val) (metadataTable s) }

    -- 'HashMap.insertWith' passes the new value first; keep the old nodes in front.
    append new old = old Seq.>< new


-- | Produce the metadata definitions of the module. Each key of the map yields
-- one named metadata definition, and each value stored under that key yields
-- one metadata node definition that the named definition refers to.
--
-- Node identifiers are numbered consecutively across all keys. The result
-- lists every named definition first, followed by every node definition.
--
createMetadata :: HashMap ShortByteString (Seq [Maybe Metadata]) -> [LLVM.Definition]
createMetadata md = F.toList (names Seq.>< nodes)
  where
    (names, nodes) = F.foldl' step (Seq.empty, Seq.empty) (HashMap.toList md)

    step :: (Seq LLVM.Definition, Seq LLVM.Definition)    -- (names, metadata nodes) so far
         -> (ShortByteString, Seq [Maybe Metadata])       -- current assoc of the metadata map
         -> (Seq LLVM.Definition, Seq LLVM.Definition)
    step (ks, ds) (key, vals) =
      let -- identifiers continue on from the nodes that were defined so far
          offset   = Seq.length ds
          nodeID i = LLVM.MetadataNodeID (fromIntegral (i + offset))
          defs     = Seq.mapWithIndex (\i x -> LLVM.MetadataNodeDefinition (nodeID i) (downcast (F.toList x))) vals
          named    = LLVM.NamedMetadataDefinition key [ nodeID i | i <- [0 .. Seq.length vals - 1] ]
      in
      (ks Seq.|> named, ds Seq.>< defs)


-- Debug
-- =====

-- | Emit a debug message, prefixed with @llvm: @, under control of the
-- 'Debug.dump_cc' flag, and return the second argument.
--
{-# INLINE trace #-}
trace :: String -> a -> a
trace msg = Debug.trace Debug.dump_cc (showString "llvm: " msg)