module Harpy.CodeGenMonad(
    
          CodeGen,
          ErrMsg,
          RelocKind(..),
          Reloc,
          Label,
          FixupKind(..),
          CodeGenConfig(..),
          defaultCodeGenConfig,
    
    
          failCodeGen,
    
          getEntryPoint,
          getCodeOffset,
          getBasePtr,
          getCodeBufferList,
    
          setState,
          getState,
          getEnv,
          withEnv,
    
          newLabel,
          newNamedLabel,
          setLabel,
          defineLabel,
          (@@),
          emitFixup,
          labelAddress,
          emitRelocInfo,
    
          emit8,
          emit8At,
          peek8At,
          emit32,
          emit32At,
          checkBufferSize,
          ensureBufferSize,
    
          runCodeGen,
          runCodeGenWithConfig,
    
          callDecl,
    
          disassemble
    ) where
import qualified Harpy.X86Disassembler as Dis
import Control.Applicative
import Control.Monad
import Text.PrettyPrint.HughesPJ
import Numeric
import Data.List
import qualified Data.Map as Map
import Foreign
import Foreign.C.Types
import System.IO
import Control.Monad.Trans
import Language.Haskell.TH.Syntax
-- | Error messages produced by code generation, as pretty-printer documents.
type ErrMsg = Doc

-- | The code generation monad.
--
-- A computation reads a pair of the user environment @e@ and the internal
-- @CodeGenEnv@, and threads a pair of the user state @s@ and the internal
-- @CodeGenState@ through @IO@.  It yields either an error message or a
-- result; the state pair reached so far is returned in both cases.
newtype CodeGen e s a =
    CodeGen (   (e, CodeGenEnv)
             -> (s, CodeGenState)
             -> IO ((s, CodeGenState), Either ErrMsg a) )
-- | Settings for running a code generator with @runCodeGenWithConfig@.
data CodeGenConfig = CodeGenConfig
    { codeBufferSize   :: Int
      -- ^ Size in bytes of the code buffer allocated by the library, and the
      -- minimum size of each additional buffer allocated when code overflows.
    , customCodeBuffer :: Maybe (Ptr Word8, Int)
      -- ^ Optional caller-supplied buffer and its size in bytes.  When set, no
      -- buffer is allocated and running out of space is reported as an error
      -- instead of continuing in a new buffer.
    }
-- | Internal state threaded through every @CodeGen@ computation.
data CodeGenState = CodeGenState
    { buffer        :: Ptr Word8
      -- ^ The buffer code is currently emitted into.
    , bufferList    :: [(Ptr Word8, Int)]
      -- ^ Earlier, already filled buffers with the number of bytes used in
      -- each, oldest first.
    , firstBuffer   :: Ptr Word8
      -- ^ The first buffer; its start is the entry point of the code.
    , bufferOfs     :: Int
      -- ^ Offset in @buffer@ where the next byte will be emitted.
    , bufferSize    :: Int
      -- ^ Buffer capacity in bytes that the overflow checks compare
      -- @bufferOfs@ against.
    , relocEntries  :: [Reloc]
      -- ^ Relocations recorded so far, most recent first.
    , nextLabel     :: Int
      -- ^ Number that the next fresh label will receive.
    , definedLabels :: Map.Map Int (Ptr Word8, Int, String)
      -- ^ Defined labels: label number to (buffer, offset, name).
    , pendingFixups :: Map.Map Int [FixupEntry]
      -- ^ Fixups waiting for the definition of a label, by label number.
    , config        :: CodeGenConfig
      -- ^ The configuration this generator was started with.
    }
-- | A location whose contents must be patched once a label is defined.
data FixupEntry = FixupEntry
    { fueBuffer :: Ptr Word8  -- ^ Buffer containing the field to patch.
    , fueOfs    :: Int        -- ^ Offset of the field within @fueBuffer@.
    , fueKind   :: FixupKind  -- ^ Width and addressing mode of the field.
    }

-- | How a label reference is encoded in the instruction stream.
data FixupKind = Fixup8           -- ^ 8-bit displacement relative to the end of the field.
               | Fixup16          -- ^ 16-bit displacement relative to the end of the field.
               | Fixup32          -- ^ 32-bit displacement relative to the end of the field.
               | Fixup32Absolute  -- ^ 32-bit absolute address of the label.
               deriving (Show)
-- | Internal read-only environment of a @CodeGen@ computation.
data CodeGenEnv = CodeGenEnv
    { tailContext :: Bool
      -- ^ Set to @True@ by @runCodeGenWithConfig@; not read elsewhere in the
      -- visible code (presumably reserved for tail-call handling, TODO confirm).
    }
    deriving (Show)
-- | Kind of a relocation recorded with @emitRelocInfo@.
data RelocKind = RelocPCRel     -- ^ PC-relative reference.
               | RelocAbsolute  -- ^ Absolute address.
               deriving (Show)

-- | A relocation entry recorded by @emitRelocInfo@.
data Reloc = Reloc
    { offset  :: Int        -- ^ Offset as supplied to @emitRelocInfo@.
    , kind    :: RelocKind  -- ^ How the address is referenced.
    , address :: FunPtr ()  -- ^ The referenced address.
    }
    deriving (Show)
-- | A code label: a unique number plus a name (empty when unnamed) that is
-- shown in disassembly listings and error messages.  Equality and ordering
-- compare both parts.
data Label = Label Int String
    deriving (Eq, Ord)
-- | Unwrap a @CodeGen@ computation to its underlying state-passing function.
unCg :: CodeGen e s a
     -> ((e, CodeGenEnv) -> (s, CodeGenState) -> IO ((s, CodeGenState), Either ErrMsg a))
unCg (CodeGen run) = run
-- | Mapping over a computation, defined through the monad operations, so an
-- error result is propagated unchanged together with the state.
instance Functor (CodeGen e s) where
  fmap = liftM

-- | Applicative structure derived from the monad: in @f <*> x@, @f@ runs
-- first and @x@ only runs if @f@ succeeded.
instance Applicative (CodeGen e s) where
  pure = cgReturn
  (<*>) = ap

-- | Sequencing of code generation steps; the first error aborts the rest.
--
-- NOTE(review): @fail@ is no longer a method of @Monad@ from GHC 8.8 on
-- (it moved to @MonadFail@), so this instance only compiles on older
-- compilers.
instance Monad (CodeGen e s) where
    return = cgReturn
    (>>=) = cgBind
    fail = cgFail
-- | Succeed with the given value, leaving the state untouched.
cgReturn :: a -> CodeGen e s a
cgReturn x = CodeGen $ \_ st -> return (st, Right x)

-- | Abort with a plain-text error message (see @failCodeGen@).
cgFail :: String -> CodeGen e s a
cgFail msg = failCodeGen (text msg)

-- | Run a computation and, if it succeeds, feed its result to the
-- continuation; an error stops the chain and keeps the state reached so far.
cgBind :: CodeGen e s a -> (a -> CodeGen e s a1) -> CodeGen e s a1
cgBind m k = CodeGen $ \env st ->
    unCg m env st >>= \(st', outcome) ->
      case outcome of
        Right v  -> unCg (k v) env st'
        Left err -> return (st', Left err)

-- | Abort code generation with the given error document.  The state at the
-- point of failure is kept.
failCodeGen :: Doc -> CodeGen e s a
failCodeGen err = CodeGen $ \_ st -> return (st, Left err)

-- | Lift an arbitrary @IO@ action into @CodeGen@; the state is not changed.
instance MonadIO (CodeGen e s) where
  liftIO act = CodeGen $ \_ st -> fmap (\r -> (st, Right r)) act
-- | Internal state before any buffer has been set up.  The buffer pointers
-- are placeholders that @runCodeGenWithConfig@ always overwrites.
emptyCodeGenState :: CodeGenState
emptyCodeGenState = CodeGenState { buffer = unset "buffer",
                                   bufferList = [],
                                   firstBuffer = unset "firstBuffer",
                                   bufferOfs = 0,
                                   bufferSize = 0,
                                   relocEntries = [],
                                   nextLabel = 0,
                                   definedLabels = Map.empty,
                                   pendingFixups = Map.empty,
                                   config = defaultCodeGenConfig}
  where
    -- Placeholder that fails with a descriptive message, instead of a bare
    -- undefined, if it is ever evaluated.
    unset field = error ("Harpy.CodeGenMonad.emptyCodeGenState: " ++ field ++ " used before initialisation")

-- | The default configuration: library-allocated buffers of
-- @defaultCodeBufferSize@ bytes and no custom buffer.
defaultCodeGenConfig :: CodeGenConfig
defaultCodeGenConfig = CodeGenConfig { codeBufferSize = defaultCodeBufferSize,
                                       customCodeBuffer = Nothing }

-- | Default size in bytes (4096) of a library-allocated code buffer.
defaultCodeBufferSize :: Int
defaultCodeBufferSize = 4096
-- | Run a code generator with a user environment and an initial user state,
-- using @defaultCodeGenConfig@.  Returns the final user state together with
-- either the result or an error message.
runCodeGen :: CodeGen e s a -> e -> s -> IO (s, Either ErrMsg a)
runCodeGen gen env st = runCodeGenWithConfig gen env st defaultCodeGenConfig
-- | C @memalign(alignment, size)@: allocate @size@ bytes aligned to
-- @alignment@; used here to obtain page-aligned code buffers.
--
-- NOTE(review): memalign is declared in malloc.h on glibc and takes size_t
-- arguments; CUInt only matches size_t on 32-bit targets, so confirm the
-- intended platforms.  A null result is not checked by this binding.
foreign import ccall "static stdlib.h"
  memalign :: CUInt -> CUInt -> IO (Ptr a)
-- | C @mprotect(addr, len, prot)@: change the protection of a memory range;
-- called with prot 0x7 (read, write, execute) on code buffers.
--
-- NOTE(review): the address is passed as a CUInt, which truncates pointers
-- on 64-bit targets; a signature such as @Ptr a -> CSize -> CInt -> IO CInt@
-- would match the C prototype.  As written this is only correct on 32-bit
-- targets.
foreign import ccall "static sys/mman.h"
  mprotect :: CUInt -> CUInt -> Int -> IO Int
-- | Run a code generator with an explicit configuration.
--
-- Unless the configuration supplies a custom buffer, a page-aligned buffer of
-- @codeBufferSize@ bytes is allocated and made readable, writable and
-- executable.  If no usable buffer is available (allocation returned a null
-- pointer), the generator is not run and an error is returned together with
-- the initial user state.
runCodeGenWithConfig :: CodeGen e s a -> e -> s -> CodeGenConfig -> IO (s, Either ErrMsg a)
runCodeGenWithConfig (CodeGen cg) uenv ustate conf =
    do (buf, sze) <- case customCodeBuffer conf of
                       Nothing -> do let initSize = codeBufferSize conf
                                         size = fromIntegral initSize
                                     arr <- memalign 0x1000 size
                                     -- 0x7 = PROT_READ | PROT_WRITE | PROT_EXEC, so the
                                     -- generated code can be executed in place.
                                     when (arr /= nullPtr) $
                                       void $ mprotect (fromIntegral $ ptrToIntPtr arr) size 0x7
                                     return (arr, initSize)
                       Just custom -> return custom
       if buf == nullPtr
         then return (ustate, Left (text "could not allocate code buffer of" <+>
                                    int sze <+> text "bytes"))
         else do let env = CodeGenEnv {tailContext = True}
                     state = emptyCodeGenState{buffer = buf,
                                               bufferList = [],
                                               firstBuffer = buf,
                                               bufferSize = sze,
                                               config = conf}
                 ((ustate', _), res) <- cg (uenv, env) (ustate, state)
                 return (ustate', res)
-- | Fail with a buffer overflow error unless @needed@ more bytes fit into
-- the active buffer.  Never allocates; see @ensureBufferSize@ for that.
checkBufferSize :: Int -> CodeGen e s ()
checkBufferSize needed =
    do st <- getInternalState
       let ofs  = bufferOfs st
           size = bufferSize st
       when (ofs + needed > size) $
         failCodeGen (text "code generation buffer overflow: needed additional" <+>
                      int needed <+> text "bytes (offset =" <+>
                      int ofs <>
                      text ", buffer size =" <+>
                      int size <> text ")")
-- | Make sure that at least @needed@ more bytes can be emitted.
--
-- With library-allocated buffers, if the active buffer cannot hold @needed@
-- bytes plus a 5-byte @jmp rel32@, a new buffer of at least @needed + 16@
-- bytes (and at least @codeBufferSize@) is allocated and made executable.
-- A jump to it is emitted at the end of the active buffer, and emission
-- continues at the start of the new buffer.  Allocation failure is reported
-- as an error.  With a custom buffer this is the same as @checkBufferSize@.
ensureBufferSize :: Int -> CodeGen e s ()
ensureBufferSize needed =
    do state <- getInternalState
       case customCodeBuffer (config state) of
         Just _ -> checkBufferSize needed
         Nothing ->
             -- Keep 5 bytes in reserve for the jmp that links to the next buffer.
             unless (bufferOfs state + needed + 5 <= bufferSize state) $
               do let incrSize = max (needed + 16) (codeBufferSize (config state))
                      cSize = fromIntegral incrSize
                  -- Allocate like the initial buffer (page aligned, then
                  -- read/write/execute), so the code jumped into can run.
                  arr <- liftIO $ memalign 0x1000 cSize
                  when (arr == nullPtr) $
                    failCodeGen (text "could not allocate additional code buffer of" <+>
                                 int incrSize <+> text "bytes")
                  _ <- liftIO $ mprotect (fromIntegral $ ptrToIntPtr arr) cSize 0x7
                  ofs <- getCodeOffset
                  let buf = buffer state
                      -- rel32 is measured from the end of the 5-byte jmp.
                      disp :: Int
                      disp = arr `minusPtr` (buf `plusPtr` ofs) - 5
                  emit8 0xe9    -- jmp rel32
                  emit32 (fromIntegral disp)
                  st <- getInternalState
                  setInternalState st{ buffer     = arr
                                     , bufferList = bufferList st ++ [(buffer st, bufferOfs st)]
                                     , bufferOfs  = 0
                                     , bufferSize = incrSize }
-- | Start address of the first code buffer, where execution of the generated
-- code begins.
getEntryPoint :: CodeGen e s (Ptr Word8)
getEntryPoint = fmap firstBuffer getInternalState

-- | Offset in the active buffer where the next byte will be emitted.
getCodeOffset :: CodeGen e s Int
getCodeOffset = fmap bufferOfs getInternalState
-- | Replace the user state.
setState :: s -> CodeGen e s ()
setState new = CodeGen $ \_ (_, internal) -> return ((new, internal), Right ())

-- | Read the user state.
getState :: CodeGen e s s
getState = CodeGen $ \_ both@(ustate, _) -> return (both, Right ustate)

-- | Read the user environment.
getEnv :: CodeGen e s e
getEnv = CodeGen $ \(uenv, _) both -> return (both, Right uenv)

-- | Run a computation with a different user environment; the internal
-- environment is passed through unchanged.
withEnv :: e -> CodeGen e s r -> CodeGen e s r
withEnv e (CodeGen cg) = CodeGen $ \(_, internal) -> cg (e, internal)
-- | Replace the internal state, keeping the user state.
setInternalState :: CodeGenState -> CodeGen e s ()
setInternalState new = CodeGen $ \_ (ustate, _) -> return ((ustate, new), Right ())

-- | Read the internal state.
getInternalState :: CodeGen e s CodeGenState
getInternalState = CodeGen $ \_ both@(_, internal) -> return (both, Right internal)

-- | Start of the buffer currently being filled.
getBasePtr :: CodeGen e s (Ptr Word8)
getBasePtr = fmap buffer getInternalState

-- | All code buffers in order, each with the number of bytes emitted into
-- it; the active buffer comes last.
getCodeBufferList :: CodeGen e s [(Ptr Word8, Int)]
getCodeBufferList = fmap (\st -> bufferList st ++ [(buffer st, bufferOfs st)]) getInternalState
-- | Create a fresh, not yet defined, unnamed label.
newLabel :: CodeGen e s Label
newLabel = newNamedLabel ""

-- | Create a fresh, not yet defined label carrying a name for listings.
newNamedLabel :: String -> CodeGen e s Label
newNamedLabel name =
    do st <- getInternalState
       let n = nextLabel st
       setInternalState st { nextLabel = n + 1 }
       return (Label n name)

-- | Create a fresh unnamed label and define it at the current position.
setLabel :: CodeGen e s Label
setLabel = newLabel >>= \l -> defineLabel l >> return l
-- | Record a relocation entry; it is prepended to the list of relocations.
emitRelocInfo :: Int -> RelocKind -> FunPtr a -> CodeGen e s ()
emitRelocInfo ofs knd addr =
    do st <- getInternalState
       let entry = Reloc { offset = ofs, kind = knd, address = castFunPtr addr }
       setInternalState st { relocEntries = entry : relocEntries st }
-- | Store one byte at the current offset and advance past it.  No bounds
-- check is done here; see @ensureBufferSize@ and @checkBufferSize@.
emit8 :: Word8 -> CodeGen e s ()
emit8 op =
    CodeGen $ \_ (ustate, st) ->
      let ofs = bufferOfs st
      in pokeByteOff (buffer st) ofs op
           >> return ((ustate, st { bufferOfs = ofs + 1 }), Right ())

-- | Overwrite the byte at the given offset of the active buffer.
emit8At :: Int -> Word8 -> CodeGen e s ()
emit8At pos op = getBasePtr >>= \base -> liftIO (pokeByteOff base pos op)

-- | Read the byte at the given offset of the active buffer.
peek8At :: Int -> CodeGen e s Word8
peek8At pos = getBasePtr >>= \base -> liftIO (peekByteOff base pos)

-- | Store a 32-bit word (host byte order) at the current offset and advance
-- past it.  No bounds check is done here.
emit32 :: Word32 -> CodeGen e s ()
emit32 op =
    CodeGen $ \_ (ustate, st) ->
      let ofs = bufferOfs st
      in pokeByteOff (buffer st) ofs op
           >> return ((ustate, st { bufferOfs = ofs + 4 }), Right ())

-- | Overwrite the 32-bit word (host byte order) at the given offset of the
-- active buffer.
emit32At :: Int -> Word32 -> CodeGen e s ()
emit32At pos op = getBasePtr >>= \base -> liftIO (pokeByteOff base pos op)
-- | Define a label at the current code position.
--
-- Fails if the label is already defined.  Fixups queued for the label are
-- patched now and removed from the pending set, and the label is recorded
-- with its buffer, offset and name.
defineLabel :: Label -> CodeGen e s ()
defineLabel (Label lab name) =
    do st <- getInternalState
       when (Map.member lab (definedLabels st)) $
         failCodeGen $ text "duplicate definition of label" <+> int lab
       let here = (buffer st, bufferOfs st)
       maybe (return ())
             (\fixups -> do mapM_ (uncurry performFixup here) fixups
                            setInternalState st{pendingFixups = Map.delete lab (pendingFixups st)})
             (Map.lookup lab (pendingFixups st))
       now <- getInternalState
       setInternalState now{definedLabels = Map.insert lab (buffer now, bufferOfs now, name) (definedLabels now)}
-- | Patch the field described by a fixup entry so that it refers to the label
-- position @labBuf + labOfs@.
--
-- Relative kinds store the displacement from the end of the patched field
-- (1, 2 or 4 bytes wide) to the label, truncated to the field width;
-- @Fixup32Absolute@ stores the label address itself, truncated to 32 bits.
performFixup :: Ptr Word8 -> Int -> FixupEntry -> CodeGen e s ()
performFixup labBuf labOfs (FixupEntry{fueBuffer = buf, fueOfs = ofs, fueKind = knd}) =
    liftIO $ case knd of
      Fixup8          -> pokeByteOff buf ofs (fromIntegral (diff - 1) :: Word8)
      Fixup16         -> pokeByteOff buf ofs (fromIntegral (diff - 2) :: Word16)
      Fixup32         -> pokeByteOff buf ofs (fromIntegral (diff - 4) :: Word32)
      Fixup32Absolute -> pokeByteOff buf ofs (fromIntegral (ptrToWordPtr target) :: Word32)
  where
    target = labBuf `plusPtr` labOfs :: Ptr Word8
    -- Distance from the start of the patched field to the label; subtracting
    -- the field width makes it relative to the end of the field.
    diff = target `minusPtr` (buf `plusPtr` ofs)
-- | Define the label at the current position, then run the given generator
-- and return its result.
(@@) :: Label -> CodeGen e s a -> CodeGen e s a
lab @@ gen = defineLabel lab >> gen
-- | Record a reference to a label in a field @ofs@ bytes past the current
-- position.  If the label is already defined the field is patched at once;
-- otherwise the fixup is queued until the label gets defined.
emitFixup :: Label -> Int -> FixupKind -> CodeGen e s ()
emitFixup (Label lab _) ofs knd =
    do st <- getInternalState
       let entry = FixupEntry { fueBuffer = buffer st
                              , fueOfs    = bufferOfs st + ofs
                              , fueKind   = knd }
       maybe (setInternalState st { pendingFixups = Map.insertWith (++) lab [entry] (pendingFixups st) })
             (\(labBuf, labOfs, _) -> performFixup labBuf labOfs entry)
             (Map.lookup lab (definedLabels st))
-- | Address of an already defined label.  Fails with an error if the label
-- has not been defined yet.
labelAddress :: Label -> CodeGen e s (Ptr a)
labelAddress (Label lab name) = do
  state <- getInternalState
  case Map.lookup lab (definedLabels state) of
    Just (labBuf, labOfs, _) -> return $ plusPtr labBuf labOfs
    -- cgFail rather than fail: from GHC 8.8 on, fail belongs to MonadFail,
    -- for which CodeGen has no instance.
    Nothing -> cgFail $ "Label " ++ show lab ++ "(" ++ name ++ ") not yet defined"
-- | Disassemble all code generated so far, across every buffer in order.
--
-- Each defined label is inserted as a pseudo-instruction in front of the
-- instruction at its address.  A disassembler error aborts code generation
-- with the error shown as text.
disassemble :: CodeGen e s [Dis.Instruction]
disassemble = do
  s <- getInternalState
  blocks <- mapM disBlock (bufferList s ++ [(buffer s, bufferOfs s)])
  -- Disassembling does not modify the internal state, so the label table
  -- read above is still current; build the label list once instead of once
  -- per instruction.
  let labs = Map.toList (definedLabels s)
  return $ concatMap (withLabels labs) (concat blocks)
 where
  -- Disassemble one buffer given its start and used length.
  disBlock :: (Ptr Word8, Int) -> CodeGen e s [Dis.Instruction]
  disBlock (buff, len) = do
    r <- liftIO $ Dis.disassembleBlock buff len
    case r of
      Left err -> cgFail $ show err
      Right instrs -> return instrs

  -- Prefix a real instruction with pseudo-instructions for the labels
  -- defined at its address; other entries are passed through unchanged.
  withLabels :: [(Int, (Ptr Word8, Int, String))] -> Dis.Instruction -> [Dis.Instruction]
  withLabels labs i@(Dis.Instruction{Dis.address = addr}) =
      [ Dis.PseudoInstruction addr (labelText l buf ofs name)
      | (l, (buf, ofs, name)) <- labs
      , fromIntegral (ptrToWordPtr (buf `plusPtr` ofs)) == addr ]
      ++ [i]
  withLabels _ i = [i]

  -- Listing text for a label: its name, or label N when unnamed, followed
  -- by its address in brackets.
  labelText :: Int -> Ptr Word8 -> Int -> String -> String
  labelText l buf ofs name =
      let loc = " [" ++ hex32 (fromIntegral (ptrToWordPtr (buf `plusPtr` ofs))) ++ "]"
      in case name of
           "" -> "label " ++ show l ++ loc
           _  -> name ++ ":" ++ loc

  -- Eight-digit, zero-padded hexadecimal rendering of the low 32 bits.
  hex32 :: Int -> String
  hex32 i =
      let digits = showHex (fromIntegral i :: Word32) ""
      in replicate (8 - length digits) '0' ++ digits
#ifndef __HADDOCK__
-- Template Haskell helper that declares a Haskell function calling the
-- generated code.
--
-- callDecl name qtype produces two declarations:
--
--   * a fresh foreign "dynamic" ccall import (named from conv) of type
--     FunPtr t -> t, keeping any outer forall binders and context of qtype;
--
--   * a top-level binding for name: a lambda with one argument per
--     top-level arrow of t, returning a CodeGen computation.  When run, that
--     computation converts the start of the first code buffer (firstBuffer)
--     to a FunPtr, applies the converter and the arguments, and returns the
--     result without changing any state.
--
-- Because the applied call is bound with <- in IO, the final result type of
-- t must be an IO type.
callDecl :: String -> Q Type -> Q [Dec]
callDecl ns qt =  do
    t0 <- qt
    -- Split off an outer forall, if any, so that the arrows of t can be
    -- counted and the binders reused on the foreign import.
    let (tvars, cxt, t) = case t0 of
                         ForallT vs c t' -> (vs, c, t')
                         _ -> ([], [], t0)
    let name = mkName ns
    let funptr = AppT (ConT $ mkName "FunPtr") t
    -- The converter returns t unchanged (addIO is not applied here).
    let ioresult = t 
    let ty = AppT (AppT ArrowT funptr) ioresult
    dynName <- newName "conv"
    let dyn = ForeignD $ ImportF CCall Safe "dynamic" dynName $ ForallT tvars cxt ty
    -- One fresh variable per argument of t.
    vs <- mkArgs t
    cbody <- [| CodeGen (\env (ustate, state) ->
                        do let code = firstBuffer state
                           res <- liftIO $ $(do
                                             c <- newName "c"
                                             cast <- [|castPtrToFunPtr|]
                                             let f = AppE (VarE dynName)
                                                          (AppE cast
                                                                (VarE c))
                                             return $ LamE [VarP c] $ foldl AppE f $ map VarE vs
                                            ) code
                           return $ ((ustate, state), Right res))|]
    -- name = \v1 ... vn -> cbody
    let call = ValD (VarP name) (NormalB $ LamE (map VarP vs) cbody) []
    return [ dyn, call ]
-- One fresh variable name per top-level arrow of a type, i.e. one per
-- argument of the function type.
mkArgs (AppT (AppT ArrowT _from) to) = do
  v  <- newName "v"
  vs <- mkArgs to
  return $ v : vs
mkArgs _ = return []
-- Wrap the final result type of a function type in IO.
-- NOTE(review): not referenced in the visible code; callDecl uses the type
-- unchanged (see the binding of ioresult).  Candidate for removal.
addIO (AppT t@(AppT ArrowT _from) to) = AppT t $ addIO to
addIO t = AppT (ConT $ mkName "IO") t
#else
-- | Template Haskell helper: @callDecl name qtype@ declares a dynamic foreign
-- import for the function type @qtype@ and a function @name@ that calls the
-- generated code (at its entry point) with that type from @CodeGen@.  The
-- implementation is hidden from Haddock, so only the signature appears here.
callDecl :: String -> Q Type -> Q [Dec]
#endif