{-# OPTIONS -cpp #-}

--------------------------------------------------------------------------
-- |
-- Module      :  CodeGenMonad
-- Copyright   :  (c) 2006-2007 Martin Grabmueller and Dirk Kleeblatt
-- License     :  GPL
-- 
-- Maintainer  :  {magr,klee}@cs.tu-berlin.de
-- Stability   :  provisional
-- Portability :  portable (but generated code non-portable)
--
-- Monad for generating x86 machine code at runtime.
--
-- This is a combined reader-state-exception monad which takes care of
-- all the details of managing code buffers, emitting binary data,
-- relocation, etc.
--
-- All the code generation functions in module "Harpy.X86CodeGen"
-- live in this monad and use its error reporting facilities as well
-- as the internal state maintained by the monad.  The user state is
-- independent from the internal state and may be used by
-- higher-level code generation libraries to maintain their own
-- state across code generation operations.
--
--------------------------------------------------------------------------

module Harpy.CodeGenMonad(
    -- * Types
          CodeGen,
          RelocKind(..),
          ErrMsg,
          Reloc,
          Label,
          FixupKind(..),
          CodeGenConfig(..),
          defaultCodeGenConfig,
    -- * Functions
    -- ** General code generator monad operations
          failCodeGen,
    -- ** Accessing code generation internals
          getEntryPoint,
          getCodeOffset,
          getBasePtr,
          getCodeBufferList,
    -- ** Access to user state and environment
          setState,
          getState,
          getEnv,
          withEnv,
    -- ** Label management
          newLabel,
          setLabel,
          defineLabel,
          (@@),
          emitFixup,
          labelAddress,
          emitRelocInfo,
    -- ** Code emission
          emit8,
          emit8At,
          peek8At,
          emit32,
          emit32At,
          checkBufferSize,
          ensureBufferSize,
    -- ** Executing code generation
          runCodeGen,
          runCodeGenWithConfig,
    -- ** Calling generated functions
          callDecl,
    -- ** Interface to disassembler
          disassemble
    ) where

import qualified Harpy.X86Disassembler as Dis

import Control.Monad

import Text.PrettyPrint.HughesPJ
import Text.Printf

import Data.Word
import qualified Data.Map as Map
import Foreign
import System.Cmd
import System.IO

import Control.Monad.Trans

import Language.Haskell.TH.Syntax


-- | Error message of a failed code generation operation, kept as a
-- pretty-printing document.
type ErrMsg = Doc

-- | The code generation monad: a combined reader-state-exception monad
-- over 'IO'.
--
-- A computation receives the user environment @e@ paired with the
-- internal 'CodeGenEnv', and the user state @s@ paired with the internal
-- 'CodeGenState'.  It yields the updated state pair (also on failure)
-- together with either an error message or a result.
newtype CodeGen e s a =
    CodeGen ((e, CodeGenEnv)
             -> (s, CodeGenState)
             -> IO ((s, CodeGenState), Either ErrMsg a))

-- | Configuration of the code generator.
data CodeGenConfig = CodeGenConfig
    { codeBufferSize :: Int
      -- ^ Size in bytes of individual code buffer blocks.
    }

-- | Internal state of the code generator.
data CodeGenState = CodeGenState
    { buffer        :: Ptr Word8
      -- ^ code buffer currently being filled
    , bufferList    :: [(Ptr Word8, Int)]
      -- ^ earlier, completed buffers with the number of bytes used in
      -- each, oldest first
    , firstBuffer   :: Ptr Word8
      -- ^ the first buffer that was allocated
    , bufferOfs     :: Int
      -- ^ offset in 'buffer' at which the next byte will be emitted
    , bufferSize    :: Int
      -- ^ size used by the overflow checks ('checkBufferSize',
      -- 'ensureBufferSize'); initialised to the first buffer's size
    , relocEntries  :: [Reloc]
      -- ^ relocation entries, most recently emitted first
    , nextLabel     :: Int
      -- ^ number to be used for the next fresh label
    , definedLabels :: Map.Map Int (Ptr Word8, Int)
      -- ^ buffer and offset of every defined label, keyed by label number
    , pendingFixups :: Map.Map Int [FixupEntry]
      -- ^ fixups waiting for their label to be defined, keyed by label
      -- number
    , config        :: CodeGenConfig
      -- ^ configuration the code generator was started with
    }

-- | How the address of a label is patched into the code buffer once
-- the label is defined.
data FixupKind
    = Fixup8           -- ^ 8-bit relative displacement
    | Fixup16          -- ^ 16-bit relative displacement
    | Fixup32          -- ^ 32-bit relative displacement
    | Fixup32Absolute  -- ^ 32-bit absolute address
    deriving (Show)

-- | A pending patch: a field in some code buffer that must receive the
-- address of a label, and how that address is encoded.
data FixupEntry = FixupEntry
    { fueBuffer :: Ptr Word8  -- ^ buffer containing the field
    , fueOfs    :: Int        -- ^ byte offset of the field in 'fueBuffer'
    , fueKind   :: FixupKind  -- ^ encoding of the patched value
    }

-- | Internal reader environment of the code generator.
data CodeGenEnv = CodeGenEnv
    { tailContext :: Bool
      -- ^ set to 'True' when code generation starts; not read anywhere
      -- else in this module
    }
    deriving (Show)

-- | Kind of a relocation, e.g. PC-relative.
data RelocKind
    = RelocPCRel     -- ^ PC-relative relocation
    | RelocAbsolute  -- ^ Absolute address
    deriving (Show)

-- | A relocation entry, as recorded by 'emitRelocInfo'.
data Reloc = Reloc
    { offset  :: Int
      -- ^ offset in code block which needs relocation
    , kind    :: RelocKind
      -- ^ kind of relocation
    , address :: FunPtr ()
      -- ^ target address
    }
    deriving (Show)

-- | A symbolic code position.  Labels are created with 'newLabel' and
-- bound to a position with 'defineLabel'.
data Label = Label Int

-- | Unwrap a 'CodeGen' computation into its underlying state-passing
-- function.
unCg :: CodeGen e s a -> ((e, CodeGenEnv) -> (s, CodeGenState) -> IO ((s, CodeGenState), Either ErrMsg a))
unCg cg = case cg of
            CodeGen run -> run

-- | Monad operations delegate to 'cgReturn', 'cgFail' and 'cgBind'.
instance Monad (CodeGen e s) where
    return = cgReturn
    fail   = cgFail
    (>>=)  = cgBind

-- {-# INLINE cgReturn #-}
-- | Succeed with the given value; both states are passed through.
cgReturn :: a -> CodeGen e s a
cgReturn x = CodeGen (\_env state -> return (state, Right x))

-- {-# INLINE cgFail #-}
-- | Fail with the given message; the current states are kept.
cgFail :: String -> CodeGen e s a
cgFail err = CodeGen (\_env state -> return (state, Left (text err)))

-- {-# INLINE cgBind #-}
-- | Sequence two computations.  If the first fails, its error and the
-- state it returned are propagated and the continuation is not run.
cgBind :: CodeGen e s a -> (a -> CodeGen e s b) -> CodeGen e s b
cgBind m k = CodeGen (\env state ->
               do r1 <- unCg m env state
                  case r1 of
                    (state', Left err) -> return (state', Left err)
                    (state', Right v)  -> unCg (k v) env state')

-- | Abort code generation with the given error message.  The user and
-- internal states at the point of failure are kept.
failCodeGen :: Doc -> CodeGen e s a
failCodeGen msg = CodeGen $ \ _ st -> return (st, Left msg)

-- | Run an 'IO' action inside code generation; states pass through
-- unchanged.
instance MonadIO (CodeGen e s) where
  liftIO act = CodeGen $ \ _ st -> do
    r <- act
    return (st, Right r)

-- | Initial internal state.  'buffer' and 'firstBuffer' have no
-- meaningful default; 'runCodeGenWithConfig' overwrites them (together
-- with 'bufferSize' and 'config') before running a generator.  Forcing
-- one of the placeholders therefore reports a descriptive programming
-- error instead of a bare 'undefined'.
emptyCodeGenState :: CodeGenState
emptyCodeGenState = CodeGenState { buffer = noBuffer "buffer",
                                   bufferList = [],
                                   firstBuffer = noBuffer "firstBuffer",
                                   bufferOfs = 0,
                                   bufferSize = 0,
                                   relocEntries = [],
                                   nextLabel = 0,
                                   definedLabels = Map.empty,
                                   pendingFixups = Map.empty,
                                   config = defaultCodeGenConfig}
  where
    noBuffer field =
        error ("Harpy.CodeGenMonad.emptyCodeGenState: " ++ field
               ++ " used before a code buffer was allocated")

-- | Default configuration: code buffers of 'defaultCodeBufferSize' bytes.
defaultCodeGenConfig :: CodeGenConfig
defaultCodeGenConfig = CodeGenConfig defaultCodeBufferSize

-- | Default size, in bytes, of each code buffer block.
defaultCodeBufferSize :: Int
defaultCodeBufferSize = 128

-- | Run a code generator with the given user environment and initial
-- user state, using 'defaultCodeGenConfig'.  The result pairs the final
-- user state with either the error message of a failed run or the
-- generator's result.
runCodeGen :: CodeGen e s a -> e -> s -> IO (s, Either ErrMsg a)
runCodeGen gen userEnv userState =
    runCodeGenWithConfig gen userEnv userState defaultCodeGenConfig

-- | Like 'runCodeGen', but with an explicit 'CodeGenConfig'.
--
-- The first code buffer is allocated with 'mallocBytes'.  Neither it
-- nor any later buffer is freed by this module.
runCodeGenWithConfig :: CodeGen e s a -> e -> s -> CodeGenConfig -> IO (s, Either ErrMsg a)
runCodeGenWithConfig (CodeGen cg) uenv ustate conf =
    do arr <- mallocBytes initSize
       let env = CodeGenEnv {tailContext = True}
           state = emptyCodeGenState{buffer = arr,
                                     firstBuffer = arr,
                                     bufferSize = initSize,
                                     config = conf}
       ((ustate', _), res) <- cg (uenv, env) (ustate, state)
       return (ustate', res)
  where
    -- 'ensureBufferSize' may write a 5-byte jump into the current buffer
    -- before switching to a fresh one, so the first buffer must hold at
    -- least that much.  This also keeps zero or negative configured
    -- sizes away from 'mallocBytes'.
    initSize = max 5 (codeBufferSize conf)

-- | Check that the current code buffer can take at least the given
-- number of further bytes, and fail with a descriptive message if it
-- cannot.  Code generators should call this whenever they cannot be sure
-- the buffer is large enough for the code they are about to emit.
checkBufferSize :: Int -> CodeGen e s ()
checkBufferSize needed =
    do st <- getInternalState
       let ofs  = bufferOfs st
           size = bufferSize st
       when (ofs + needed > size) $
         failCodeGen (text "code generation buffer overflow: needed additional" <+>
                      int needed <+> text "bytes (offset =" <+>
                      int ofs <>
                      text ", buffer size =" <+>
                      int size <> text ")")

-- | Make sure that the current code buffer has room for at least the
-- given number of bytes.  Code generators should call this whenever it
-- cannot be guaranteed that the buffer is large enough for the code
-- they are about to emit.
--
-- If there is not enough room (keeping 5 spare bytes for a jump), a new
-- buffer of at least @needed + 16@ bytes, and at least the configured
-- 'codeBufferSize', is allocated.  A 32-bit relative @jmp@ (opcode 0xe9)
-- to it is emitted into the current buffer, the current buffer is
-- recorded with its used size, and emission continues at offset 0 of
-- the new buffer.
ensureBufferSize :: Int -> CodeGen e s ()
ensureBufferSize needed =
    do state <- getInternalState
       unless (bufferOfs state + needed + jmpSize <= bufferSize state) $
         do let incrSize = max (needed + 16)
                               (codeBufferSize (config state))
            arr <- liftIO $ mallocBytes incrSize
            let here :: Ptr Word8
                here = buffer state `plusPtr` bufferOfs state
                -- rel32 is measured from the end of the jump instruction
                disp :: Int
                disp = arr `minusPtr` here - jmpSize
            emit8 0xe9
            emit32 (fromIntegral disp)
            st <- getInternalState
            -- Switch to the new buffer.  'bufferSize' must follow the
            -- switch; otherwise later checks keep using the old size.
            setInternalState st{buffer = arr,
                                bufferList = bufferList st ++ [(buffer st, bufferOfs st)],
                                bufferOfs = 0,
                                bufferSize = incrSize}
  where
    -- size of "jmp rel32": one opcode byte plus a 32-bit displacement
    jmpSize :: Int
    jmpSize = 5

-- | Pointer to the start of the first code buffer, normally the entry
-- point of the generated code.
getEntryPoint :: CodeGen e s (Ptr Word8)
getEntryPoint =
    CodeGen $ \ _ st@(_, internal) -> return (st, Right (firstBuffer internal))

-- | Offset in the current code buffer at which the next instruction will
-- be emitted.
getCodeOffset :: CodeGen e s Int
getCodeOffset =
    CodeGen $ \ _ st@(_, internal) -> return (st, Right (bufferOfs internal))

-- | Replace the user state; the internal state is kept.
setState :: s -> CodeGen e s ()
setState st =
    CodeGen $ \ _ (_, internal) -> return ((st, internal), Right ())

-- | The current user state.
getState :: CodeGen e s s
getState =
    CodeGen $ \ _ st@(user, _) -> return (st, Right user)

-- | The current user environment.
getEnv :: CodeGen e s e
getEnv =
    CodeGen $ \ (userEnv, _) st -> return (st, Right userEnv)

-- | Run a code generator with the user environment replaced by the given
-- value; the internal environment and both states are passed through.
withEnv :: e -> CodeGen e s r -> CodeGen e s r
withEnv e gen =
    CodeGen $ \ (_, internalEnv) st -> unCg gen (e, internalEnv) st

-- | Replace the internal code generator state; the user state is kept.
setInternalState :: CodeGenState -> CodeGen e s ()
setInternalState st =
    CodeGen (\ _env (ustate, _) ->
      return ((ustate, st), Right ()))

-- | Return the current internal code generator state.
getInternalState :: CodeGen e s CodeGenState
getInternalState =
    CodeGen (\ _env (ustate, state) ->
      return ((ustate, state), Right state))

-- | Pointer to the start of the current code buffer (which differs from
-- 'getEntryPoint' once 'ensureBufferSize' has switched buffers).
-- {-# INLINE getBasePtr #-}
getBasePtr :: CodeGen e s (Ptr Word8)
getBasePtr =
    CodeGen $ \ _ st@(_, internal) -> return (st, Right (buffer internal))

-- | All code buffers, oldest first, each paired with the number of bytes
-- of code actually emitted into it (not its allocated size).  The
-- current buffer comes last.
getCodeBufferList :: CodeGen e s [(Ptr Word8, Int)]
getCodeBufferList =
    do st <- getInternalState
       let current = (buffer st, bufferOfs st)
       return (bufferList st ++ [current])

-- | Create a fresh, not yet defined label for use with 'emitFixup',
-- 'defineLabel' and related operations.
newLabel :: CodeGen e s Label
newLabel =
    do st <- getInternalState
       let n = nextLabel st
       setInternalState st{nextLabel = n + 1}
       return (Label n)

-- | Create a fresh label and immediately define it at the current
-- position.
setLabel :: CodeGen e s Label
setLabel = newLabel >>= \ l -> defineLabel l >> return l

-- | Record a relocation entry for the given offset, relocation kind and
-- target address.  Entries are kept most-recent-first.
emitRelocInfo :: Int -> RelocKind -> FunPtr a -> CodeGen e s ()
emitRelocInfo ofs rk addr =
    do st <- getInternalState
       let entry = Reloc { offset = ofs, kind = rk, address = castFunPtr addr }
       setInternalState st{relocEntries = entry : relocEntries st}

-- | Append one byte to the current code buffer.  No bounds check is
-- done; see 'checkBufferSize' and 'ensureBufferSize'.
-- {-# INLINE emit8 #-}
emit8 :: Word8 -> CodeGen e s ()
emit8 op =
    CodeGen $ \ _ (ustate, state) -> do
      let pos = bufferOfs state
      pokeByteOff (buffer state) pos op
      return ((ustate, state{bufferOfs = pos + 1}), Right ())

-- | Overwrite the byte at the given offset of the current code buffer;
-- the emission offset is unchanged.
-- {-# INLINE emit8At #-}
emit8At :: Int -> Word8 -> CodeGen e s ()
emit8At pos op =
    CodeGen $ \ _ st@(_, state) -> do
      pokeByteOff (buffer state) pos op
      return (st, Right ())

-- | Read the byte at the given offset of the current code buffer.
-- {-# INLINE peek8At #-}
peek8At :: Int -> CodeGen e s Word8
peek8At pos =
    CodeGen $ \ _ st@(_, state) -> do
      byte <- peekByteOff (buffer state) pos
      return (st, Right byte)

-- | Append a 32-bit value, in host byte order, to the current code
-- buffer.  No bounds check is done.
-- {-# INLINE emit32 #-}
emit32 :: Word32 -> CodeGen e s ()
emit32 op =
    CodeGen $ \ _ (ustate, state) -> do
      let pos = bufferOfs state
      pokeByteOff (buffer state) pos op
      return ((ustate, state{bufferOfs = pos + 4}), Right ())

-- | Overwrite the 32-bit value (host byte order) at the given offset of
-- the current code buffer; the emission offset is unchanged.
-- {-# INLINE emit32At #-}
emit32At :: Int -> Word32 -> CodeGen e s ()
emit32At pos op =
    CodeGen $ \ _ st@(_, state) -> do
      pokeByteOff (buffer state) pos op
      return (st, Right ())

-- | Bind a label to the current position in the current code buffer.
-- Fixups emitted for the label before this point are patched now and
-- dropped from the pending set.  Defining a label again overwrites its
-- recorded position.
defineLabel :: Label -> CodeGen e s ()
defineLabel (Label lab) =
    do state <- getInternalState
       let here    = (buffer state, bufferOfs state)
           waiting = Map.findWithDefault [] lab (pendingFixups state)
       mapM_ (uncurry performFixup here) waiting
       setInternalState state{ pendingFixups = Map.delete lab (pendingFixups state)
                             , definedLabels = Map.insert lab here (definedLabels state) }

-- | Patch the field described by a 'FixupEntry' so that it refers to the
-- code position @labBuf + labOfs@.
--
-- Relative kinds store the displacement from the end of the patched
-- field, so the base is the field offset plus the field width (1, 2 or
-- 4 bytes).  This assumes the field is the last part of its instruction,
-- as with x86 @jmp@/@jcc@/@call@ relative operands.  'Fixup32Absolute'
-- stores the absolute target address truncated to 32 bits.
--
-- A relative displacement that does not fit its field makes code
-- generation fail instead of being silently truncated.
performFixup :: Ptr Word8 -> Int -> FixupEntry -> CodeGen e s ()
performFixup labBuf labOfs (FixupEntry{fueBuffer = buf, fueOfs = ofs, fueKind = fixKind}) =
    case fixKind of
      Fixup8  -> patchRel 1 (\ d -> pokeByteOff buf ofs (fromIntegral d :: Word8))
      Fixup16 -> patchRel 2 (\ d -> pokeByteOff buf ofs (fromIntegral d :: Word16))
      Fixup32 -> patchRel 4 (\ d -> pokeByteOff buf ofs (fromIntegral d :: Word32))
      Fixup32Absolute ->
        liftIO $ pokeByteOff buf ofs (fromIntegral (ptrToWordPtr target) :: Word32)
  where
    target :: Ptr Word8
    target = labBuf `plusPtr` labOfs

    -- Displacement relative to the end of a @width@-byte field; checked
    -- against the signed range of that width, then stored.
    patchRel :: Int -> (Int -> IO ()) -> CodeGen e s ()
    patchRel width store =
        let disp  = target `minusPtr` (buf `plusPtr` (ofs + width))
            limit = 2 ^ (8 * width - 1) :: Integer
        in if toInteger disp < negate limit || toInteger disp >= limit
             then failCodeGen (text "fixup displacement out of range:" <+>
                               int disp <+> text "does not fit in" <+>
                               text (show fixKind))
             else liftIO (store disp)


-- | Define the label at the current position, then run the given code
-- generator and return its result.
(@@) :: Label -> CodeGen e s a -> CodeGen e s a
lab @@ gen = defineLabel lab >> gen

-- | Request that the field at @ofs@ bytes past the current emission
-- offset be patched with the address of the given label, encoded
-- according to the 'FixupKind'.  If the label is already defined the
-- field is patched right away; otherwise the request is queued until
-- 'defineLabel' binds the label.
emitFixup :: Label -> Int -> FixupKind -> CodeGen e s ()
emitFixup (Label lab) ofs kind =
    do state <- getInternalState
       let entry = FixupEntry { fueBuffer = buffer state,
                                fueOfs = bufferOfs state + ofs,
                                fueKind = kind }
           defer = setInternalState state{pendingFixups =
                     Map.insertWith (++) lab [entry] (pendingFixups state)}
       maybe defer
             (\ (labBuf, labOfs) -> performFixup labBuf labOfs entry)
             (Map.lookup lab (definedLabels state))

-- | Absolute address of a defined label; code generation fails if the
-- label has not been defined yet.
labelAddress :: Label -> CodeGen e s (Ptr a)
labelAddress (Label lab) =
    do state <- getInternalState
       maybe (fail ("Label " ++ show lab ++ " not yet defined"))
             (\ (labBuf, labOfs) -> return (labBuf `plusPtr` labOfs))
             (Map.lookup lab (definedLabels state))


-- | Disassemble the used part of every code buffer, oldest first, and
-- concatenate the instructions.  The first disassembler error aborts
-- code generation with that error's 'show'n text.
disassemble :: CodeGen e s [Dis.Instruction]
disassemble = do
  blocks <- getCodeBufferList
  perBlock <- forM blocks $ \ (ptr, len) -> do
    res <- liftIO (Dis.disassembleBlock ptr len)
    either (cgFail . show) return res
  return (concat perBlock)

#ifndef __HADDOCK__

-- Template Haskell generator for a typed call into the generated code.
-- Given a binding name @ns@ and a quoted type @t@ (optionally with an
-- outer forall/context), it produces two declarations:
--
--   * @foreign import ccall safe "dynamic"@ for a fresh name @conv@ with
--     type @FunPtr t -> t'@, where @t'@ is @t@ with its final result
--     wrapped in @IO@ (see 'addIO'); the outer forall/context is kept;
--
--   * a binding @ns = \ v1 .. vn -> CodeGen ...@ with one lambda argument
--     per arrow in @t@ (see 'mkArgs').  Its action casts 'firstBuffer'
--     to a 'FunPtr', calls it through @conv@ with those arguments, and
--     returns the result without changing either state.
--
-- NOTE(review): only an outermost 'ForallT' is stripped, and 'mkArgs' /
-- 'addIO' only walk 'ArrowT' applications, so argument counting
-- presumably assumes a plain function type underneath — confirm.
callDecl :: String -> Q Type -> Q [Dec]
callDecl ns qt =  do
    t0 <- qt
    -- Split off an outer forall so its variables and context can be
    -- re-attached to the foreign import's type below.
    let (tvars, cxt, t) = case t0 of
                         ForallT vs c t -> (vs, c, t)
                         _ -> ([], [], t0)
    let name = mkName ns
    let funptr = AppT (ConT $ mkName "FunPtr") t
    let ioresult = addIO t
    -- type of the dynamic import: FunPtr t -> (t with IO result)
    let ty = AppT (AppT ArrowT funptr) ioresult
    dynName <- newName "conv"
    let dyn = ForeignD $ ImportF CCall Safe "dynamic" dynName $ ForallT tvars cxt ty
    -- one fresh variable per argument of t
    vs <- mkArgs t
    -- Body: a CodeGen action that calls the code at the start of the first
    -- buffer with the lambda's arguments and passes both states through.
    cbody <- [| CodeGen (\env (ustate, state) ->
                        do let code = firstBuffer state
                           res <- liftIO $ $(do
                                             c <- newName "c"
                                             cast <- [|castPtrToFunPtr|]
                                             let f = AppE (VarE dynName)
                                                          (AppE cast
                                                                (VarE c))
                                             return $ LamE [VarP c] $ foldl AppE f $ map VarE vs
                                            ) code
                           return $ ((ustate, state), Right res))|]
    let call = ValD (VarP name) (NormalB $ LamE (map VarP vs) cbody) []
    return [ dyn, call ]

-- | One fresh variable name (prefix @v@) for each argument of a function
-- type, i.e. for each @ArrowT@ application along the result spine.  A
-- type that is not an arrow yields no names.
mkArgs :: Type -> Q [Name]
mkArgs (AppT (AppT ArrowT _from) to) = do
  v  <- newName "v"
  vs <- mkArgs to
  return $ v : vs
mkArgs _ = return []

-- | Wrap the final result of a function type in @IO@, e.g.
-- @a -> b -> c@ becomes @a -> b -> IO c@, and a non-function type @c@
-- becomes @IO c@.
addIO :: Type -> Type
addIO (AppT t@(AppT ArrowT _from) to) = AppT t $ addIO to
addIO t = AppT (ConT $ mkName "IO") t

#else

-- | Declare a stub function to call the code buffer. Arguments are the name
-- of the generated function, and the type the code buffer is supposed to have.
-- The type argument can be given using the [t| ... |] notation of Template Haskell.
-- Allowed types are the legal types for FFI functions.
--
-- The generated function takes one argument per argument of the given
-- type.  Run as a 'CodeGen' action, it calls the code at the start of the
-- first code buffer (the address returned by 'getEntryPoint') with those
-- arguments and returns the call's result.
callDecl :: String -> Q Type -> Q [Dec]

#endif