{-# LANGUAGE CPP               #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE RankNTypes        #-}
{-# LANGUAGE RecordWildCards   #-}
{-# OPTIONS_GHC -fno-warn-name-shadowing #-}

-- | A simple interpreter for a subset of Ivory.
module Ivory.Eval
  ( runEval
  , runEvalStartingFrom
  , Error
  , Eval
  , EvalState(EvalState)
  , Value(..)
  , initState
  , openModule
  , evalAssert
  , evalBlock
  , evalCond
  , evalDeref
  , evalExpr
  , evalInit
  , evalLit
  , evalOp
  , evalRequires
  , evalStmt
  ) where

import           Prelude                                 ()
import           Prelude.Compat                          hiding (and, div, mod,
                                                          negate, not, or)
import qualified Prelude.Compat                          as Prelude

import           Control.Monad                           (foldM, unless, void)
#if MIN_VERSION_base_compat(0,10,0)
import qualified Control.Monad.Fail.Compat as Fail
#endif
import           Data.Int
import qualified Data.Map                                as Map
import           Data.Maybe
import qualified Data.Sequence                           as Seq
import           Data.Word
import qualified Ivory.Language.Array                    as I
import qualified Ivory.Language.Syntax                   as I
import           Ivory.Language.Syntax.Concrete.Location
import           MonadLib                                (ExceptionM (..),
                                                          ExceptionT, Id,
                                                          StateM (..), StateT,
                                                          runExceptionT, runId,
                                                          runStateT, sets_)

-- XXX: DEBUG
-- import Debug.Trace

-- | Error messages produced by the evaluator (see 'raise' uses below).
type Error  = String
-- | The evaluator monad: an 'EvalState' threaded over an exception layer
-- carrying 'Error's. Because the state sits above the exception layer, a
-- raised error yields no final state (see 'runEvalStartingFrom').
type Eval a = StateT EvalState (ExceptionT Error Id) a

-- Route 'fail' (which refutable pattern bindings in 'Eval' do-blocks invoke)
-- to 'raise', so a failed pattern match is reported through the 'Eval'
-- error channel as an 'Error'.
-- NOTE(review): this is an orphan instance on a MonadLib type; OVERLAPS
-- presumably guards against a more general instance elsewhere -- confirm.
#if MIN_VERSION_base_compat(0,10,0)
instance {-# OVERLAPS #-} Fail.MonadFail (StateT EvalState (ExceptionT Error Id)) where
  fail = raise
#endif

-- | Run an evaluation starting from an empty store and no structs,
-- keeping only the result (the final state is dropped).
runEval :: Eval a -> Either Error a
runEval = fmap fst . runEvalStartingFrom (initState Map.empty)

-- | Run an evaluation from the given initial state, returning the result
-- together with the final state, or the first 'Error' raised.
runEvalStartingFrom :: EvalState -> Eval a -> Either Error (a, EvalState)
runEvalStartingFrom st = runId . runExceptionT . runStateT st

-- | Mutable interpreter state threaded through 'Eval'.
data EvalState = EvalState
  { store   :: Map.Map I.Sym Value    -- ^ values bound to symbols (variables, globals, and the targets of 'Ref's)
  , loc     :: SrcLoc                 -- ^ last source location recorded by 'updateLoc'
  , structs :: Map.Map I.Sym I.Struct -- ^ struct definitions in scope, keyed by struct name
  } deriving Show

-- | Build a starting state from an initial store, with no known source
-- location and no struct definitions.
initState :: Map.Map I.Sym Value -> EvalState
initState st = EvalState { store = st, loc = NoLoc, structs = Map.empty }

-- | Run an action with a module's struct definitions in scope.
--
-- The module's public and private 'I.Struct' definitions are merged into the
-- current struct table while the action runs. Entries that are already
-- present win, because 'Map.union' is left-biased. The previous table is
-- restored once the action returns normally.
openModule :: I.Module -> Eval a -> Eval a
openModule (I.Module {..}) doThis = do
  saved <- structs <$> get
  sets_ $ \st -> st { structs = structs st `Map.union` moduleStructs }
  result <- doThis
  sets_ $ \st -> st { structs = saved }
  return result
  where
  -- Only definitions matching 'I.Struct' are collected.
  moduleStructs = Map.fromList
    [ (name, def)
    | def@(I.Struct name _) <- I.public modStructs ++ I.private modStructs
    ]

-- | Runtime values manipulated by the interpreter.
--
-- The derived 'Ord' instance is used directly on 'Value's (the loop
-- termination test in 'evalStmt' compares indices with '>=' / '<=').
-- For two different constructors it orders by constructor position, so
-- reordering the constructors changes that ordering.
data Value
  = Sint8  Int8
  | Sint16 Int16
  | Sint32 Int32
  | Sint64 Int64
  | Uint8  Word8
  | Uint16 Word16
  | Uint32 Word32
  | Uint64 Word64
  | Float  Float
  | Double Double
  | Char   Char
  | String String
  | Bool   Bool
  | Array  (Seq.Seq Value)           -- ^ array elements, addressed by position
  | Struct (Map.Map I.Sym Value)     -- ^ struct fields keyed by field name
  | Ref    I.Sym                     -- ^ reference: the symbol of another entry in the store
  deriving (Show, Eq, Ord)

-- | Structural equality of two values, wrapped as a boolean 'Value'.
eq :: Value -> Value -> Value
eq lhs rhs = Bool $ lhs == rhs

-- | Structural inequality of two values, wrapped as a boolean 'Value'.
neq :: Value -> Value -> Value
neq lhs rhs = Bool $ lhs /= rhs

-- | Boolean negation; any non-'Bool' operand is a fatal error.
not :: Value -> Value
not v = case v of
  Bool b -> Bool (Prelude.not b)
  _      -> error $ "invalid operands to `not`: " ++ show v

-- | Boolean conjunction; both operands must be 'Bool' values.
and :: Value -> Value -> Value
and lhs rhs = case (lhs, rhs) of
  (Bool p, Bool q) -> Bool (p && q)
  _                -> error $ "invalid operands to `and`: " ++ show (lhs, rhs)

-- | Boolean disjunction; both operands must be 'Bool' values.
or :: Value -> Value -> Value
or lhs rhs = case (lhs, rhs) of
  (Bool p, Bool q) -> Bool (p || q)
  _                -> error $ "invalid operands to `or`: " ++ show (lhs, rhs)

-- | Lift a polymorphic comparison to 'Value's.
--
-- Both operands must use the same numeric or 'Char' constructor; the
-- comparison is applied to the unwrapped payloads. Any other combination is
-- a fatal error.
ordOp :: (forall a. Ord a => a -> a -> Bool) -> Value -> Value -> Value
ordOp cmp lhs rhs = case (lhs, rhs) of
  (Sint8  a, Sint8  b) -> Bool (cmp a b)
  (Sint16 a, Sint16 b) -> Bool (cmp a b)
  (Sint32 a, Sint32 b) -> Bool (cmp a b)
  (Sint64 a, Sint64 b) -> Bool (cmp a b)
  (Uint8  a, Uint8  b) -> Bool (cmp a b)
  (Uint16 a, Uint16 b) -> Bool (cmp a b)
  (Uint32 a, Uint32 b) -> Bool (cmp a b)
  (Uint64 a, Uint64 b) -> Bool (cmp a b)
  (Float  a, Float  b) -> Bool (cmp a b)
  (Double a, Double b) -> Bool (cmp a b)
  (Char   a, Char   b) -> Bool (cmp a b)
  _                    -> error $ "invalid operands to `ordOp`: " ++ show (lhs, rhs)

-- | Strictly greater than, on same-typed scalars (see 'ordOp').
gt :: Value -> Value -> Value
gt x y = ordOp (>) x y

-- | Greater than or equal, on same-typed scalars (see 'ordOp').
gte :: Value -> Value -> Value
gte x y = ordOp (>=) x y

-- | Strictly less than, on same-typed scalars (see 'ordOp').
lt :: Value -> Value -> Value
lt x y = ordOp (<) x y

-- | Less than or equal, on same-typed scalars (see 'ordOp').
lte :: Value -> Value -> Value
lte x y = ordOp (<=) x y

-- | Apply a polymorphic unary numeric function to a numeric 'Value',
-- keeping its constructor. Non-numeric operands are a fatal error.
numUnOp :: (forall a. Num a => a -> a) -> Value -> Value
numUnOp f v = case v of
  Sint8  n -> Sint8  (f n)
  Sint16 n -> Sint16 (f n)
  Sint32 n -> Sint32 (f n)
  Sint64 n -> Sint64 (f n)
  Uint8  n -> Uint8  (f n)
  Uint16 n -> Uint16 (f n)
  Uint32 n -> Uint32 (f n)
  Uint64 n -> Uint64 (f n)
  Float  n -> Float  (f n)
  Double n -> Double (f n)
  _        -> error $ "invalid operands to `numUnOp`: " ++ show v

-- | Arithmetic negation of a numeric 'Value'.
negate :: Value -> Value
negate v = numUnOp Prelude.negate v

-- | Apply a polymorphic binary numeric function to two 'Value's that share
-- a numeric constructor. Mismatched or non-numeric operands are a fatal
-- error.
numBinOp :: (forall a. Num a => a -> a -> a) -> Value -> Value -> Value
numBinOp f lhs rhs = case (lhs, rhs) of
  (Sint8  a, Sint8  b) -> Sint8  (f a b)
  (Sint16 a, Sint16 b) -> Sint16 (f a b)
  (Sint32 a, Sint32 b) -> Sint32 (f a b)
  (Sint64 a, Sint64 b) -> Sint64 (f a b)
  (Uint8  a, Uint8  b) -> Uint8  (f a b)
  (Uint16 a, Uint16 b) -> Uint16 (f a b)
  (Uint32 a, Uint32 b) -> Uint32 (f a b)
  (Uint64 a, Uint64 b) -> Uint64 (f a b)
  (Float  a, Float  b) -> Float  (f a b)
  (Double a, Double b) -> Double (f a b)
  _                    -> error $ "invalid operands to `numBinOp`: " ++ show (lhs, rhs)

-- | Addition of same-typed numeric values (see 'numBinOp').
add :: Value -> Value -> Value
add x y = numBinOp (+) x y

-- | Subtraction of same-typed numeric values (see 'numBinOp').
sub :: Value -> Value -> Value
sub x y = numBinOp (-) x y

-- | Multiplication of same-typed numeric values (see 'numBinOp').
mul :: Value -> Value -> Value
mul x y = numBinOp (*) x y

-- | Division of same-typed numeric values.
--
-- Integral payloads use Haskell's 'Prelude.div', which rounds toward
-- negative infinity. Floating payloads use '/'.
-- NOTE(review): C's @/@ truncates toward zero for signed operands; confirm
-- which rounding 'I.ExpDiv' is meant to have.
div :: Value -> Value -> Value
div lhs rhs = case (lhs, rhs) of
  (Sint8  a, Sint8  b) -> Sint8  (Prelude.div a b)
  (Sint16 a, Sint16 b) -> Sint16 (Prelude.div a b)
  (Sint32 a, Sint32 b) -> Sint32 (Prelude.div a b)
  (Sint64 a, Sint64 b) -> Sint64 (Prelude.div a b)
  (Uint8  a, Uint8  b) -> Uint8  (Prelude.div a b)
  (Uint16 a, Uint16 b) -> Uint16 (Prelude.div a b)
  (Uint32 a, Uint32 b) -> Uint32 (Prelude.div a b)
  (Uint64 a, Uint64 b) -> Uint64 (Prelude.div a b)
  (Float  a, Float  b) -> Float  (a / b)
  (Double a, Double b) -> Double (a / b)
  _                    -> error $ "invalid operands to `div`: " ++ show (lhs, rhs)

-- | Remainder of same-typed integral values, using Haskell's 'Prelude.mod'.
-- The result has the sign of the divisor.
--
-- 'Float' and 'Double' are not supported, because 'Prelude.mod' requires
-- an 'Integral' type.
mod :: Value -> Value -> Value
mod lhs rhs = case (lhs, rhs) of
  (Sint8  a, Sint8  b) -> Sint8  (Prelude.mod a b)
  (Sint16 a, Sint16 b) -> Sint16 (Prelude.mod a b)
  (Sint32 a, Sint32 b) -> Sint32 (Prelude.mod a b)
  (Sint64 a, Sint64 b) -> Sint64 (Prelude.mod a b)
  (Uint8  a, Uint8  b) -> Uint8  (Prelude.mod a b)
  (Uint16 a, Uint16 b) -> Uint16 (Prelude.mod a b)
  (Uint32 a, Uint32 b) -> Uint32 (Prelude.mod a b)
  (Uint64 a, Uint64 b) -> Uint64 (Prelude.mod a b)
  _                    -> error $ "invalid operands to `mod`: " ++ show (lhs, rhs)

-- | Look up a symbol in the store, raising an 'Error' if it is unbound.
readStore :: I.Sym -> Eval Value
readStore sym = do
  st <- get
  maybe unbound return (Map.lookup sym (store st))
  where
  unbound = raise $ "Unbound variable: `" ++ sym ++ "'!"

-- | Bind (or rebind) a symbol in the store.
writeStore :: I.Sym -> Value -> Eval ()
writeStore sym val = sets_ bind
  where
  bind st = st { store = Map.insert sym val (store st) }

-- | Apply a function to the value bound to a symbol. If the symbol is
-- unbound, the store is left unchanged (no error is raised).
modifyStore :: I.Sym -> (Value -> Value) -> Eval ()
modifyStore sym f = sets_ $ \st -> st { store = Map.adjust f sym (store st) }

-- | Record the current source location in the evaluator state.
updateLoc :: SrcLoc -> Eval ()
updateLoc newLoc = sets_ setLoc
  where
  setLoc st = st { loc = newLoc }

-- | Find a struct definition by name among the structs currently in scope
-- (see 'openModule'), raising an 'Error' if none is found.
lookupStruct :: String -> Eval I.Struct
lookupStruct name = do
  table <- structs <$> get
  maybe (raise $ "Couldn't find struct: " ++ name) return (Map.lookup name table)

----------------------------------------------------------------------
-- Main evaluator
----------------------------------------------------------------------
-- | Execute the statements of a block in order.
evalBlock :: I.Block -> Eval ()
evalBlock = foldr (\s rest -> evalStmt s >> rest) (return ())

-- | Evaluate every precondition and report whether all of them hold.
-- Every condition is evaluated, even after one has come out 'False'.
evalRequires :: [I.Require] -> Eval Bool
evalRequires reqs = do
  results <- mapM (evalCond . I.getRequire) reqs
  return (Prelude.and results)

-- | Evaluate a condition to a Haskell 'Bool'.
--
-- 'I.CondDeref' first performs the dereference (binding the variable in the
-- store), then evaluates the nested condition.
evalCond :: I.Cond -> Eval Bool
evalCond c = case c of
  I.CondDeref ty ref var rest ->
    evalStmt (I.Deref ty var ref) >> evalCond rest
  I.CondBool e -> do
    v <- evalExpr I.TyBool e
    case v of
      Bool b -> return b
      _      -> raise $ "evalCond: expected boolean, got: " ++ show v

-- | Execute one statement, updating the store. A 'I.SourcePos' comment
-- also updates the current source location.
--
-- A statement form the interpreter does not support is reported with
-- 'raise'. It therefore surfaces as a 'Left' from 'runEval' rather than
-- aborting the program through 'error'.
evalStmt :: I.Stmt -> Eval ()
evalStmt stmt = case stmt of
  I.Comment (I.SourcePos loc)
    -> updateLoc loc
  -- Any other comment has no runtime effect.
  I.Comment _
    -> return ()
  I.Assert expr
    -> evalAssert expr
  I.CompilerAssert expr
    -> evalAssert expr
  I.IfTE cond true false
    -> do b <- evalExpr I.TyBool cond
          case b of
            Bool True  -> evalBlock true
            Bool False -> evalBlock false
            _          -> raise $ "evalStmt: IfTE: expected true or false, got: " ++ show b
  I.Deref _ty var expr
    -> do val <- evalDeref _ty expr
          -- If the dereferenced value is itself a reference, follow it one
          -- more step before binding.
          case val of
            Ref ref -> do
              val <- readStore ref
              writeStore (varSym var) val
            _ -> writeStore (varSym var) val
  I.Assign _ty var expr
    -> do val <- evalExpr _ty expr
          writeStore (varSym var) val
  I.Local ty var init
    -> do val <- evalInit ty init
          writeStore (varSym var) val
  I.AllocRef _ty var ref
    -- The variable is bound to a reference to the named location.
    -> writeStore (varSym var) (Ref $ nameSym ref)
  I.Loop _ var expr incr body
    -> do val <- evalExpr I.ixRep expr
          writeStore (varSym var) val
          -- Build the step constant in the same representation that
          -- 'evalExpr I.ixRep' produces for the index, so that 'add' and
          -- 'sub' see matching constructors.
          let one = mkVal I.ixRep 1
              (checkDone, stepFn, doneExpr) = case incr of
                I.IncrTo expr -> ((>=), (`add` one), expr)
                I.DecrTo expr -> ((<=), (`sub` one), expr)
          let step = modifyStore (varSym var) stepFn
          -- The bound is re-evaluated before every iteration. The body does
          -- not run once the index has reached the bound.
          -- NOTE(review): confirm this exclusive bound matches the C backend.
          let done = do curVal  <- readStore (varSym var)
                        doneVal <- evalExpr I.ixRep doneExpr
                        return (checkDone curVal doneVal)
          untilM done (evalBlock body >> step)
  I.Return (I.Typed  _ty expr)
    -> void $ evalExpr _ty expr
  I.Store _ty (I.ExpAddrOfGlobal dst) expr
    -> do val <- evalExpr _ty expr
          writeStore dst val
  I.Store _ty (I.ExpVar dst) expr
    -> do val <- evalExpr _ty expr
          -- The variable must hold a reference; the store goes to its target.
          target <- readStore (varSym dst)
          case target of
            Ref ref -> writeStore ref val
            _       -> raise $ "evalStmt: Store: expected Ref, got: " ++ show target
  I.RefCopy _ty (I.ExpAddrOfGlobal dst) expr
    -> do val <- evalExpr _ty expr
          writeStore dst val
  _ -> raise $ "evalStmt: can't handle: " ++ show stmt

-- | Read the value that a reference-valued expression points to.
--
-- Variables and symbols must hold a 'Ref', which is followed (see 'readRef').
-- A global whose address is taken is read directly. Array indexing and
-- struct field selection recurse on the container expression.
--
-- All failures are reported with 'raise' instead of 'error' or a partial
-- function:
--
-- * the container has the wrong shape,
-- * the index is out of range,
-- * the field is missing,
-- * the expression form is unsupported.
evalDeref :: I.Type -> I.Expr -> Eval Value
evalDeref _ty expr = case expr of
  I.ExpSym sym -> readRef sym
  I.ExpVar var -> readRef (varSym var)
  I.ExpAddrOfGlobal sym -> readStore sym
  I.ExpIndex tarr arr tidx idx
    -> do arrVal <- evalDeref tarr arr
          elems <- case arrVal of
            Array elems -> return elems
            _           -> raise $ "evalDeref: expected Array, got: " ++ show arrVal
          idxVal <- evalExpr tidx idx
          -- Index values are represented as Sint32 (see 'mkVal' for TyIndex).
          i <- case idxVal of
            Sint32 n -> return (fromIntegral n :: Int)
            _        -> raise $ "evalDeref: expected Sint32 index, got: " ++ show idxVal
          -- Seq.index is partial, so check the bounds first.
          if i < 0 || i >= Seq.length elems
            then raise $ "evalDeref: index out of bounds: " ++ show i
            else return (elems `Seq.index` i)
  I.ExpLabel tstr str lab
    -> do strVal <- evalDeref tstr str
          case strVal of
            Struct fields ->
              maybe (raise $ "evalDeref: no such field: " ++ lab) return
                    (Map.lookup lab fields)
            _ -> raise $ "evalDeref: expected Struct, got: " ++ show strVal
  _ -> raise $ "evalDeref: can't handle: " ++ show expr

-- | Read a symbol that must hold a 'Ref', and return the value stored at
-- the referenced symbol.
readRef :: I.Sym -> Eval Value
readRef sym = readStore sym >>= follow
  where
  follow (Ref target) = readStore target
  follow other        = raise $ "Expected Ref, got: " ++ show other

-- | Evaluate an assertion. Any result other than @Bool True@ raises an
-- 'Error' that includes both the evaluated value and the expression.
evalAssert :: I.Expr -> Eval ()
evalAssert asrt = do
  result <- evalExpr I.TyBool asrt
  unless (result == Bool True) $
    raise $ "Assertion failed: " ++ show (result, asrt)

-- | Evaluate an expression at the given type.
--
-- The type tells 'evalLit' / 'mkVal' how to represent integer literals.
-- Comparison and NaN/Inf predicate operators carry their own operand type,
-- and that type is used for their arguments instead. Unsupported expression
-- forms are reported with 'raise', like 'evalLit' and 'evalOp' do, rather
-- than through 'error'.
evalExpr :: I.Type -> I.Expr -> Eval Value
evalExpr ty expr = case expr of
  I.ExpSym sym -> readStore sym
  I.ExpVar var -> readStore (varSym var)
  I.ExpLit lit -> evalLit ty lit
  I.ExpOp op exprs
    -> do let opTy = case op of
                I.ExpEq t -> t
                I.ExpNeq t -> t
                I.ExpGt _ t -> t
                I.ExpLt _ t -> t
                I.ExpIsNan t -> t
                I.ExpIsInf t -> t
                _ -> ty
          vals <- mapM (evalExpr opTy) exprs
          evalOp op vals
  I.ExpSafeCast fromTy expr
    -> do val <- evalExpr fromTy expr
          return (cast ty fromTy val)
  -- Reduce the index modulo the bound, using the index representation.
  I.ExpToIx expr max
    -> fmap (`mod` Sint32 (fromInteger max)) (evalExpr I.ixRep expr)
  _ -> raise $ "evalExpr: can't handle: " ++ show expr

-- | Convert an integer 'Value' to the target type.
--
-- The source type is ignored. The payload is widened to 'Integer' and
-- rebuilt with 'mkVal', so narrowing conversions wrap according to
-- 'fromInteger'. Non-integer values are a fatal error.
cast :: I.Type -> I.Type -> Value -> Value
cast toTy _fromTy val = mkVal toTy (asInteger val)
  where
  asInteger v = case v of
    Sint8  n -> toInteger n
    Sint16 n -> toInteger n
    Sint32 n -> toInteger n
    Sint64 n -> toInteger n
    Uint8  n -> toInteger n
    Uint16 n -> toInteger n
    Uint32 n -> toInteger n
    Uint64 n -> toInteger n
    _        -> error $ "Expected number, got: " ++ show v

-- | Turn a literal into a 'Value'. Integer literals take the representation
-- dictated by the expected type; other literals map directly to the
-- matching constructor. Anything else raises an 'Error'.
evalLit :: I.Type -> I.Literal -> Eval Value
evalLit ty lit = maybe unsupported return (literalValue lit)
  where
  unsupported = raise $ "evalLit: can't handle: " ++ show lit
  literalValue l = case l of
    I.LitInteger i -> Just (mkVal ty i)
    I.LitFloat c   -> Just (Float c)
    I.LitDouble c  -> Just (Double c)
    I.LitChar c    -> Just (Char c)
    I.LitBool b    -> Just (Bool b)
    I.LitString s  -> Just (String s)
    _              -> Nothing

-- | Build a 'Value' of the given scalar type from an 'Integer'.
--
-- Conversion uses 'fromInteger', so out-of-range integers wrap for
-- fixed-width types. Index types are represented as 'Sint32'; this is
-- hard-coded here rather than derived from the index representation. Any
-- non-scalar type is a fatal error.
mkVal :: I.Type -> Integer -> Value
mkVal (I.TyInt  I.Int8)   = Sint8  . fromInteger
mkVal (I.TyInt  I.Int16)  = Sint16 . fromInteger
mkVal (I.TyInt  I.Int32)  = Sint32 . fromInteger
mkVal (I.TyInt  I.Int64)  = Sint64 . fromInteger
mkVal (I.TyWord I.Word8)  = Uint8  . fromInteger
mkVal (I.TyWord I.Word16) = Uint16 . fromInteger
mkVal (I.TyWord I.Word32) = Uint32 . fromInteger
mkVal (I.TyWord I.Word64) = Uint64 . fromInteger
mkVal (I.TyIndex _)       = Sint32 . fromInteger
mkVal I.TyFloat           = Float  . fromInteger
mkVal I.TyDouble          = Double . fromInteger
mkVal I.TyBool            = Bool   . toEnum . fromInteger
mkVal I.TyChar            = Char   . toEnum . fromInteger
mkVal ty                  = error $ "mkVal: " ++ show ty

-- | Apply an operator to already-evaluated operands.
--
-- The @Bool@ flag on 'I.ExpGt' / 'I.ExpLt' selects the inclusive variant.
-- 'I.ExpCond' picks its second or third operand according to the first.
-- Any unsupported operator or wrong number of operands raises an 'Error'.
evalOp :: I.ExpOp -> [Value] -> Eval Value
evalOp op args = case (op, args) of
  (I.ExpEq _,       [x, y]) -> pure (eq x y)
  (I.ExpNeq _,      [x, y]) -> pure (neq x y)
  (I.ExpGt True _,  [x, y]) -> pure (gte x y)
  (I.ExpGt False _, [x, y]) -> pure (gt x y)
  (I.ExpLt True _,  [x, y]) -> pure (lte x y)
  (I.ExpLt False _, [x, y]) -> pure (lt x y)
  (I.ExpAdd,        [x, y]) -> pure (add x y)
  (I.ExpSub,        [x, y]) -> pure (sub x y)
  (I.ExpMul,        [x, y]) -> pure (mul x y)
  (I.ExpDiv,        [x, y]) -> pure (div x y)
  (I.ExpMod,        [x, y]) -> pure (mod x y)
  (I.ExpNegate,     [x])    -> pure (negate x)
  (I.ExpNot,        [x])    -> pure (not x)
  (I.ExpAnd,        [x, y]) -> pure (and x y)
  (I.ExpOr,         [x, y]) -> pure (or x y)
  (I.ExpCond, [c, whenTrue, whenFalse]) ->
    case c of
      Bool True  -> pure whenTrue
      Bool False -> pure whenFalse
      _          -> raise $ "evalOp: ExpCond: expected true or false, got: " ++ show c
  _ -> raise $ "evalOp: can't handle: " ++ show (op, args)

-- | Compute the initial value of a local of type @ty@ from its initialiser.
--
-- * Zero initialisation recurses into arrays and structs.
-- * Array initialisers are padded with zero-initialised elements, or
--   truncated, to the declared length.
-- * A struct initialiser starts from an all-zero struct and overrides the
--   listed fields.
evalInit :: I.Type -> I.Init -> Eval Value
evalInit ty init = case init of
  I.InitZero          -> zeroOf ty
  I.InitExpr ety expr -> evalExpr ety expr
  I.InitArray inits _ -> arrayOf ty inits
  I.InitStruct inits  -> structOf ty inits
  where
  -- Zero value: composite types go through their (empty) initialisers.
  zeroOf t = case t of
    I.TyArr _ _  -> evalInit t (I.InitArray [] True)
    I.TyStruct _ -> evalInit t (I.InitStruct [])
    _            -> return (mkVal t 0)

  -- Array value of exactly the declared length.
  arrayOf (I.TyArr len elemTy) inits =
    Array . Seq.fromList
      <$> mapM (evalInit elemTy) (take len (inits ++ repeat I.InitZero))
  arrayOf t _ = raise $ "evalInit: InitArray: unexpected type: " ++ show t

  -- Struct value: zero every declared field, then apply the initialisers.
  structOf (I.TyStruct name) inits = do
    I.Struct _ fields <- lookupStruct name
    zeroed <- foldM zeroField Map.empty fields
    Struct <$> foldM (setField fields) zeroed inits
  structOf t _ = raise $ "evalInit: InitStruct: unexpected type: " ++ show t

  -- Insert the zero value of one declared field.
  zeroField acc (I.Typed fty fld) = do
    v <- evalInit fty I.InitZero
    return (Map.insert fld v acc)

  -- Insert one explicitly initialised field, typed by its declaration.
  setField fields acc (fld, fieldInit) = do
    v <- evalInit (lookupTyped fld fields) fieldInit
    return (Map.insert fld v acc)

-- | Return the type paired with the first entry whose value equals the key.
-- A missing key is a fatal error.
lookupTyped :: (Show a, Eq a) => a -> [I.Typed a] -> I.Type
lookupTyped a xs =
  fromMaybe (error $ "lookupTyped: couldn't find: " ++ show a)
            (lookup a [ (x, t) | I.Typed t x <- xs ])

-- | The underlying symbol of a variable, whatever kind of variable it is.
varSym :: I.Var -> I.Sym
varSym v = case v of
  I.VarName s     -> s
  I.VarInternal s -> s
  I.VarLitName s  -> s

-- | The underlying symbol of a name, which is either a symbol or a variable.
nameSym :: I.Name -> I.Sym
nameSym n = case n of
  I.NameSym s -> s
  I.NameVar v -> varSym v

----------------------------------------------------------------------
-- Utilities
----------------------------------------------------------------------
-- | Repeatedly run an action until the test returns 'True'. The test runs
-- before each iteration, so the action may run zero times.
untilM :: Monad m => m Bool -> m () -> m ()
untilM done doThis = loop
  where
  loop = done >>= \finished -> unless finished (doThis >> loop)