{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# OPTIONS_GHC -fno-warn-name-shadowing #-}
module Ivory.Eval
( runEval
, runEvalStartingFrom
, Error
, Eval
, EvalState(EvalState)
, Value(..)
, initState
, openModule
, evalAssert
, evalBlock
, evalCond
, evalDeref
, evalExpr
, evalInit
, evalLit
, evalOp
, evalRequires
, evalStmt
) where
import Prelude ()
import Prelude.Compat hiding (and, div, mod,
negate, not, or)
import qualified Prelude.Compat as Prelude
import Control.Monad (foldM, unless, void)
#if MIN_VERSION_base_compat(0,10,0)
import qualified Control.Monad.Fail.Compat as Fail
#endif
import Data.Int
import qualified Data.Map as Map
import Data.Maybe
import qualified Data.Sequence as Seq
import Data.Word
import qualified Ivory.Language.Array as I
import qualified Ivory.Language.Syntax as I
import Ivory.Language.Syntax.Concrete.Location
import MonadLib (ExceptionM (..),
ExceptionT, Id,
StateM (..), StateT,
runExceptionT, runId,
runStateT, sets_)
-- | Evaluation errors are reported as plain message strings.
type Error = String
-- | The evaluator monad: an 'EvalState' threaded over a computation that can
-- abort with an 'Error'.
type Eval a = StateT EvalState (ExceptionT Error Id) a
#if MIN_VERSION_base_compat(0,10,0)
-- Route pattern-match failures in 'Eval' do-blocks through 'raise', so they
-- become ordinary evaluation errors.
instance {-# OVERLAPS #-} Fail.MonadFail (StateT EvalState (ExceptionT Error Id)) where
  fail msg = raise msg
#endif
-- | Run an evaluation starting from an empty store and return only its
-- result, discarding the final state.
runEval :: Eval a -> Either Error a
runEval action = fst <$> runEvalStartingFrom (initState Map.empty) action
-- | Run an evaluation from the given initial state, returning either the
-- error that aborted it or the result paired with the final state.
runEvalStartingFrom :: EvalState -> Eval a -> Either Error (a, EvalState)
runEvalStartingFrom st doThis = runId . runExceptionT $ runStateT st doThis
-- | Mutable state carried through an evaluation.
data EvalState = EvalState
  { store   :: Map.Map I.Sym Value
    -- ^ Values currently bound to symbols.
  , loc     :: SrcLoc
    -- ^ Source location most recently recorded via a source-position comment.
  , structs :: Map.Map I.Sym I.Struct
    -- ^ Struct definitions currently in scope, keyed by struct name.
  }
  deriving (Show)
-- | Build an initial state from a store: no source location recorded and no
-- struct definitions in scope.
initState :: Map.Map I.Sym Value -> EvalState
initState st = EvalState
  { store   = st
  , loc     = NoLoc
  , structs = Map.empty
  }
-- | Run an action with the module's public and private struct definitions
-- added to the struct table, restoring the previous table afterwards.
--
-- Definitions already in scope win over the module's own when names clash
-- ('Map.union' is left-biased).
openModule :: I.Module -> Eval a -> Eval a
openModule (I.Module { I.modStructs = declared }) doThis = do
  saved <- structs <$> get
  sets_ (\s -> s { structs = structs s `Map.union` moduleStructs })
  result <- doThis
  sets_ (\s -> s { structs = saved })
  return result
  where
    -- Only definitions matching the 'I.Struct' constructor are registered.
    moduleStructs = Map.fromList
      [ (name, def)
      | def@(I.Struct name _) <- I.public declared ++ I.private declared
      ]
-- | Runtime values manipulated by the evaluator.
--
-- The constructor order is significant: the derived 'Ord' instance depends
-- on it.
data Value
  = Sint8  Int8
  | Sint16 Int16
  | Sint32 Int32
  | Sint64 Int64
  | Uint8  Word8
  | Uint16 Word16
  | Uint32 Word32
  | Uint64 Word64
  | Float  Float
  | Double Double
  | Char   Char
  | String String
  | Bool   Bool
  | Array  (Seq.Seq Value)
    -- ^ Array contents in index order.
  | Struct (Map.Map I.Sym Value)
    -- ^ Field values keyed by field name.
  | Ref    I.Sym
    -- ^ Reference to the store entry with the given symbol.
  deriving (Eq, Ord, Show)
-- | Structural equality of two values, wrapped as a 'Bool' value.
eq :: Value -> Value -> Value
eq lhs rhs = Bool $ lhs == rhs
-- | Structural inequality of two values, wrapped as a 'Bool' value.
neq :: Value -> Value -> Value
neq lhs rhs = Bool $ lhs /= rhs
-- | Boolean negation on 'Bool' values; any other value is a fatal error.
not :: Value -> Value
not v = case v of
  Bool b -> Bool (Prelude.not b)
  _      -> error $ "invalid operands to `not`: " ++ show v
-- | Conjunction of two 'Bool' values; any other operands are a fatal error.
and :: Value -> Value -> Value
and lhs rhs = case (lhs, rhs) of
  (Bool a, Bool b) -> Bool (a && b)
  _                -> error $ "invalid operands to `and`: " ++ show (lhs, rhs)
-- | Disjunction of two 'Bool' values; any other operands are a fatal error.
or :: Value -> Value -> Value
or lhs rhs = case (lhs, rhs) of
  (Bool a, Bool b) -> Bool (a || b)
  _                -> error $ "invalid operands to `or`: " ++ show (lhs, rhs)
-- | Lift a polymorphic comparison to numeric and character values.
--
-- Both operands must use the same constructor; mismatched or unsupported
-- operands are a fatal error.
ordOp :: (forall a. Ord a => a -> a -> Bool) -> Value -> Value -> Value
ordOp op lhs rhs = case (lhs, rhs) of
  (Sint8 x,  Sint8 y)  -> Bool (op x y)
  (Sint16 x, Sint16 y) -> Bool (op x y)
  (Sint32 x, Sint32 y) -> Bool (op x y)
  (Sint64 x, Sint64 y) -> Bool (op x y)
  (Uint8 x,  Uint8 y)  -> Bool (op x y)
  (Uint16 x, Uint16 y) -> Bool (op x y)
  (Uint32 x, Uint32 y) -> Bool (op x y)
  (Uint64 x, Uint64 y) -> Bool (op x y)
  (Float x,  Float y)  -> Bool (op x y)
  (Double x, Double y) -> Bool (op x y)
  (Char x,   Char y)   -> Bool (op x y)
  _ -> error $ "invalid operands to `ordOp`: " ++ show (lhs, rhs)
-- | Strictly-greater-than comparison, via 'ordOp'.
gt :: Value -> Value -> Value
gt lhs rhs = ordOp (>) lhs rhs
-- | Greater-than-or-equal comparison, via 'ordOp'.
gte :: Value -> Value -> Value
gte lhs rhs = ordOp (>=) lhs rhs
-- | Strictly-less-than comparison, via 'ordOp'.
lt :: Value -> Value -> Value
lt lhs rhs = ordOp (<) lhs rhs
-- | Less-than-or-equal comparison, via 'ordOp'.
lte :: Value -> Value -> Value
lte lhs rhs = ordOp (<=) lhs rhs
-- | Lift a polymorphic unary numeric operation to numeric values, keeping
-- the operand's constructor. Non-numeric values are a fatal error.
numUnOp :: (forall a. Num a => a -> a) -> Value -> Value
numUnOp op v = case v of
  Sint8 x  -> Sint8 (op x)
  Sint16 x -> Sint16 (op x)
  Sint32 x -> Sint32 (op x)
  Sint64 x -> Sint64 (op x)
  Uint8 x  -> Uint8 (op x)
  Uint16 x -> Uint16 (op x)
  Uint32 x -> Uint32 (op x)
  Uint64 x -> Uint64 (op x)
  Float x  -> Float (op x)
  Double x -> Double (op x)
  _        -> error $ "invalid operands to `numUnOp`: " ++ show v
-- | Arithmetic negation of a numeric value.
negate :: Value -> Value
negate v = numUnOp Prelude.negate v
-- | Lift a polymorphic binary numeric operation to numeric values.
--
-- Both operands must use the same constructor, which is also used for the
-- result; anything else is a fatal error.
numBinOp :: (forall a. Num a => a -> a -> a) -> Value -> Value -> Value
numBinOp op lhs rhs = case (lhs, rhs) of
  (Sint8 x,  Sint8 y)  -> Sint8 (op x y)
  (Sint16 x, Sint16 y) -> Sint16 (op x y)
  (Sint32 x, Sint32 y) -> Sint32 (op x y)
  (Sint64 x, Sint64 y) -> Sint64 (op x y)
  (Uint8 x,  Uint8 y)  -> Uint8 (op x y)
  (Uint16 x, Uint16 y) -> Uint16 (op x y)
  (Uint32 x, Uint32 y) -> Uint32 (op x y)
  (Uint64 x, Uint64 y) -> Uint64 (op x y)
  (Float x,  Float y)  -> Float (op x y)
  (Double x, Double y) -> Double (op x y)
  _ -> error $ "invalid operands to `numBinOp`: " ++ show (lhs, rhs)
-- | Addition of two numeric values with the same constructor.
add :: Value -> Value -> Value
add lhs rhs = numBinOp (+) lhs rhs
-- | Subtraction of two numeric values with the same constructor.
sub :: Value -> Value -> Value
sub lhs rhs = numBinOp (-) lhs rhs
-- | Multiplication of two numeric values with the same constructor.
mul :: Value -> Value -> Value
mul lhs rhs = numBinOp (*) lhs rhs
-- | Division of two numeric values with the same constructor.
--
-- Integer operands use 'Prelude.div'; floating-point operands use '/'.
-- Mismatched or non-numeric operands are a fatal error.
--
-- NOTE(review): 'Prelude.div' rounds toward negative infinity; confirm this
-- is the rounding intended for the evaluated language's signed division.
div :: Value -> Value -> Value
div lhs rhs = case (lhs, rhs) of
  (Sint8 x,  Sint8 y)  -> Sint8 (Prelude.div x y)
  (Sint16 x, Sint16 y) -> Sint16 (Prelude.div x y)
  (Sint32 x, Sint32 y) -> Sint32 (Prelude.div x y)
  (Sint64 x, Sint64 y) -> Sint64 (Prelude.div x y)
  (Uint8 x,  Uint8 y)  -> Uint8 (Prelude.div x y)
  (Uint16 x, Uint16 y) -> Uint16 (Prelude.div x y)
  (Uint32 x, Uint32 y) -> Uint32 (Prelude.div x y)
  (Uint64 x, Uint64 y) -> Uint64 (Prelude.div x y)
  (Float x,  Float y)  -> Float (x / y)
  (Double x, Double y) -> Double (x / y)
  _ -> error $ "invalid operands to `div`: " ++ show (lhs, rhs)
-- | Modulus of two integer values with the same constructor, using
-- 'Prelude.mod'. Floating-point, mismatched or non-numeric operands are a
-- fatal error.
--
-- NOTE(review): 'Prelude.mod' takes the sign of the divisor; confirm this is
-- the behaviour intended for the evaluated language's signed remainder.
mod :: Value -> Value -> Value
mod lhs rhs = case (lhs, rhs) of
  (Sint8 x,  Sint8 y)  -> Sint8 (Prelude.mod x y)
  (Sint16 x, Sint16 y) -> Sint16 (Prelude.mod x y)
  (Sint32 x, Sint32 y) -> Sint32 (Prelude.mod x y)
  (Sint64 x, Sint64 y) -> Sint64 (Prelude.mod x y)
  (Uint8 x,  Uint8 y)  -> Uint8 (Prelude.mod x y)
  (Uint16 x, Uint16 y) -> Uint16 (Prelude.mod x y)
  (Uint32 x, Uint32 y) -> Uint32 (Prelude.mod x y)
  (Uint64 x, Uint64 y) -> Uint64 (Prelude.mod x y)
  _ -> error $ "invalid operands to `mod`: " ++ show (lhs, rhs)
-- | Look up the value bound to a symbol, raising an evaluation error when
-- the symbol is unbound.
readStore :: I.Sym -> Eval Value
readStore sym = do
  env <- store <$> get
  maybe (raise $ "Unbound variable: `" ++ sym ++ "'!") return (Map.lookup sym env)
-- | Bind a symbol to a value in the store, replacing any existing binding.
writeStore :: I.Sym -> Value -> Eval ()
writeStore sym val = sets_ $ \s ->
  let bound = Map.insert sym val (store s)
  in s { store = bound }
-- | Apply a function to the value bound to a symbol. The store is left
-- unchanged when the symbol is unbound.
modifyStore :: I.Sym -> (Value -> Value) -> Eval ()
modifyStore sym f = sets_ $ \s -> s { store = Map.adjust f sym (store s) }
-- | Record the given source location in the evaluator state.
updateLoc :: SrcLoc -> Eval ()
updateLoc here = sets_ $ \s -> s { loc = here }
-- | Find a struct definition by name among those currently in scope,
-- raising an evaluation error when it is missing.
lookupStruct :: String -> Eval I.Struct
lookupStruct name = do
  known <- structs <$> get
  maybe (raise $ "Couldn't find struct: " ++ name) return (Map.lookup name known)
-- | Execute the statements of a block in order.
evalBlock :: I.Block -> Eval ()
evalBlock stmts = mapM_ evalStmt stmts
-- | Evaluate every precondition and report whether all of them hold.
-- Every condition is evaluated, even after one has come out false.
evalRequires :: [I.Require] -> Eval Bool
evalRequires reqs = do
  outcomes <- mapM (evalCond . I.getRequire) reqs
  return (Prelude.and outcomes)
-- | Evaluate a condition to a Haskell 'Bool'.
--
-- A 'I.CondDeref' first performs the dereference as a statement, binding the
-- variable, and then evaluates the nested condition.
evalCond :: I.Cond -> Eval Bool
evalCond (I.CondBool expr) = do
  val <- evalExpr I.TyBool expr
  case val of
    Bool b -> return b
    _      -> raise $ "evalCond: expected boolean, got: " ++ show val
evalCond (I.CondDeref ty expr var rest) =
  evalStmt (I.Deref ty var expr) >> evalCond rest
-- | Execute a single statement, updating the evaluator state.
--
-- Source-position comments update the recorded location; every other
-- comment is a no-op. Statement forms that are not implemented, and stores
-- through a variable that does not hold a 'Ref', are reported with 'raise'
-- so they surface as a @Left@ from 'runEval' instead of a pure exception.
evalStmt :: I.Stmt -> Eval ()
evalStmt stmt = case stmt of
  I.Comment (I.SourcePos srcLoc)
    -> updateLoc srcLoc
  -- Any other kind of comment has no effect on evaluation.
  I.Comment _
    -> return ()
  I.Assert expr
    -> evalAssert expr
  I.CompilerAssert expr
    -> evalAssert expr
  I.IfTE cond true false
    -> do b <- evalExpr I.TyBool cond
          case b of
            Bool True  -> evalBlock true
            Bool False -> evalBlock false
            _ -> raise $ "evalStmt: IfTE: expected true or false, got: " ++ show b
  I.Deref ty var expr
    -> do val <- evalDeref ty expr
          -- When the dereferenced value is itself a reference, follow it
          -- once more and bind the variable to the referenced value.
          case val of
            Ref ref -> readStore ref >>= writeStore (varSym var)
            _       -> writeStore (varSym var) val
  I.Assign ty var expr
    -> do val <- evalExpr ty expr
          writeStore (varSym var) val
  I.Local ty var init
    -> do val <- evalInit ty init
          writeStore (varSym var) val
  I.AllocRef _ty var ref
    -> writeStore (varSym var) (Ref $ nameSym ref)
  I.Loop _ var start incr body
    -> do startVal <- evalExpr I.ixRep start
          writeStore (varSym var) startVal
          -- The loop ends as soon as the counter reaches or passes the bound
          -- in the direction of travel; the bound is re-evaluated before
          -- every iteration.
          let (checkDone, stepFn, boundExpr) = case incr of
                I.IncrTo bound -> ((>=), (`add` Sint32 1), bound)
                I.DecrTo bound -> ((<=), (`sub` Sint32 1), bound)
          let step = modifyStore (varSym var) stepFn
          let done = do curVal   <- readStore (varSym var)
                        boundVal <- evalExpr I.ixRep boundExpr
                        return (checkDone curVal boundVal)
          untilM done (evalBlock body >> step)
  I.Return (I.Typed ty expr)
    -> void $ evalExpr ty expr
  I.Store ty (I.ExpAddrOfGlobal dst) expr
    -> do val <- evalExpr ty expr
          writeStore dst val
  I.Store ty (I.ExpVar dst) expr
    -> do val <- evalExpr ty expr
          target <- readStore (varSym dst)
          case target of
            Ref ref -> writeStore ref val
            _ -> raise $ "evalStmt: Store: expected Ref, got: " ++ show target
  I.RefCopy ty (I.ExpAddrOfGlobal dst) expr
    -> do val <- evalExpr ty expr
          writeStore dst val
  _ -> raise $ "evalStmt: can't handle: " ++ show stmt
-- | Evaluate the target of a dereference.
--
-- Symbols and variables are read through one level of 'Ref'; globals are
-- read directly from the store. Indexing and field selection recurse into
-- the container expression.
--
-- Out-of-range indices, missing fields, ill-typed operands and unsupported
-- expressions are reported with 'raise' rather than crashing the evaluator.
evalDeref :: I.Type -> I.Expr -> Eval Value
evalDeref _ty expr = case expr of
  I.ExpSym sym          -> readRef sym
  I.ExpVar var          -> readRef (varSym var)
  I.ExpAddrOfGlobal sym -> readStore sym
  I.ExpIndex tarr arr tidx idx -> do
    arrVal <- evalDeref tarr arr
    elems <- case arrVal of
      Array es -> return es
      _ -> raise $ "evalDeref: ExpIndex: expected Array, got: " ++ show arrVal
    idxVal <- evalExpr tidx idx
    i <- case idxVal of
      Sint32 n -> return (fromIntegral n :: Int)
      _ -> raise $ "evalDeref: ExpIndex: expected Sint32 index, got: " ++ show idxVal
    -- Check bounds explicitly: 'Seq.index' is partial.
    if i >= 0 && i < Seq.length elems
      then return (elems `Seq.index` i)
      else raise $ "evalDeref: ExpIndex: index out of bounds: "
                   ++ show (i, Seq.length elems)
  I.ExpLabel tstr str lab -> do
    strVal <- evalDeref tstr str
    case strVal of
      Struct fields ->
        maybe (raise $ "evalDeref: ExpLabel: no field `" ++ lab ++ "' in: " ++ show strVal)
              return
              (Map.lookup lab fields)
      _ -> raise $ "evalDeref: ExpLabel: expected Struct, got: " ++ show strVal
  _ -> raise $ "evalDeref: can't handle: " ++ show expr
-- | Read the value a reference points at: the symbol must be bound to a
-- 'Ref', whose target is then looked up in the store.
readRef :: I.Sym -> Eval Value
readRef sym = readStore sym >>= follow
  where
    follow (Ref target) = readStore target
    follow other        = raise $ "Expected Ref, got: " ++ show other
-- | Evaluate an assertion, raising an evaluation error unless it evaluates
-- to @Bool True@.
evalAssert :: I.Expr -> Eval ()
evalAssert asrt = evalExpr I.TyBool asrt >>= check
  where
    check (Bool True) = return ()
    check outcome     = raise $ "Assertion failed: " ++ show (outcome, asrt)
-- | Evaluate an expression at the given expected type.
--
-- For operators that carry their own operand type (comparisons, NaN/Inf
-- tests) the operands are evaluated at that type; all other operators
-- evaluate their operands at the expected type.
--
-- Unsupported expressions and a zero bound in 'I.ExpToIx' are reported with
-- 'raise' rather than crashing the evaluator.
evalExpr :: I.Type -> I.Expr -> Eval Value
evalExpr ty expr = case expr of
  I.ExpSym sym -> readStore sym
  I.ExpVar var -> readStore (varSym var)
  I.ExpLit lit -> evalLit ty lit
  I.ExpOp op args -> do
    let opTy = case op of
          I.ExpEq t    -> t
          I.ExpNeq t   -> t
          I.ExpGt _ t  -> t
          I.ExpLt _ t  -> t
          I.ExpIsNan t -> t
          I.ExpIsInf t -> t
          _            -> ty
    vals <- mapM (evalExpr opTy) args
    evalOp op vals
  I.ExpSafeCast fromTy inner -> do
    val <- evalExpr fromTy inner
    return (cast ty fromTy val)
  I.ExpToIx inner bound
    -- A zero bound would otherwise fail lazily with a divide-by-zero
    -- exception wherever the result is first forced.
    | bound == 0 -> raise $ "evalExpr: ExpToIx: zero bound in: " ++ show expr
    | otherwise  -> fmap (`mod` Sint32 (fromInteger bound)) (evalExpr I.ixRep inner)
  _ -> raise $ "evalExpr: can't handle: " ++ show expr
-- | Convert a value to the target type.
--
-- Integer values go through 'Integer' and 'mkVal'. 'Bool' values are
-- treated as 0 or 1. Floating-point values are accepted only when the target
-- is also floating point ('Float' to float or double, 'Double' to double).
-- Any other source value is a fatal error. The source type argument is not
-- consulted.
cast :: I.Type -> I.Type -> Value -> Value
cast toTy _fromTy val = case (toTy, val) of
  (I.TyFloat,  Float f)  -> Float f
  (I.TyDouble, Float f)  -> Double (realToFrac f)
  (I.TyDouble, Double d) -> Double d
  _                      -> mkVal toTy integer
  where
    integer = case val of
      Sint8 i  -> toInteger i
      Sint16 i -> toInteger i
      Sint32 i -> toInteger i
      Sint64 i -> toInteger i
      Uint8 i  -> toInteger i
      Uint16 i -> toInteger i
      Uint32 i -> toInteger i
      Uint64 i -> toInteger i
      Bool b   -> toInteger (fromEnum b)
      _        -> error $ "Expected number, got: " ++ show val
-- | Turn a literal into a value. Integer literals are built at the expected
-- type via 'mkVal'; the other supported literals ignore it. Unsupported
-- literals raise an evaluation error.
evalLit :: I.Type -> I.Literal -> Eval Value
evalLit ty (I.LitInteger i) = return (mkVal ty i)
evalLit _  (I.LitFloat c)   = return (Float c)
evalLit _  (I.LitDouble c)  = return (Double c)
evalLit _  (I.LitChar c)    = return (Char c)
evalLit _  (I.LitBool b)    = return (Bool b)
evalLit _  (I.LitString s)  = return (String s)
evalLit _  lit              = raise $ "evalLit: can't handle: " ++ show lit
-- | Build a value of the given type from an 'Integer'.
--
-- Index types are represented as 'Sint32'. Booleans and characters are
-- produced with 'toEnum'. Types not listed here are a fatal error.
mkVal :: I.Type -> Integer -> Value
mkVal (I.TyInt I.Int8)      = Sint8 . fromInteger
mkVal (I.TyInt I.Int16)     = Sint16 . fromInteger
mkVal (I.TyInt I.Int32)     = Sint32 . fromInteger
mkVal (I.TyInt I.Int64)     = Sint64 . fromInteger
mkVal (I.TyWord I.Word8)    = Uint8 . fromInteger
mkVal (I.TyWord I.Word16)   = Uint16 . fromInteger
mkVal (I.TyWord I.Word32)   = Uint32 . fromInteger
mkVal (I.TyWord I.Word64)   = Uint64 . fromInteger
mkVal (I.TyIndex _)         = Sint32 . fromInteger
mkVal I.TyFloat             = Float . fromInteger
mkVal I.TyDouble            = Double . fromInteger
mkVal I.TyBool              = Bool . toEnum . fromInteger
mkVal I.TyChar              = Char . toEnum . fromInteger
mkVal ty                    = error $ "mkVal: " ++ show ty
-- | Apply an operator to its already-evaluated operands.
--
-- The 'Bool' flag on 'I.ExpGt' / 'I.ExpLt' selects the or-equal variant of
-- the comparison. Operators that are not listed, or that are applied to the
-- wrong number of operands, raise an evaluation error.
evalOp :: I.ExpOp -> [Value] -> Eval Value
evalOp op args = case (op, args) of
  (I.ExpEq _,       [x, y]) -> return (eq x y)
  (I.ExpNeq _,      [x, y]) -> return (neq x y)
  (I.ExpGt True _,  [x, y]) -> return (gte x y)
  (I.ExpGt False _, [x, y]) -> return (gt x y)
  (I.ExpLt True _,  [x, y]) -> return (lte x y)
  (I.ExpLt False _, [x, y]) -> return (lt x y)
  (I.ExpAdd,        [x, y]) -> return (add x y)
  (I.ExpSub,        [x, y]) -> return (sub x y)
  (I.ExpMul,        [x, y]) -> return (mul x y)
  (I.ExpDiv,        [x, y]) -> return (div x y)
  (I.ExpMod,        [x, y]) -> return (mod x y)
  (I.ExpNegate,     [x])    -> return (negate x)
  (I.ExpNot,        [x])    -> return (not x)
  (I.ExpAnd,        [x, y]) -> return (and x y)
  (I.ExpOr,         [x, y]) -> return (or x y)
  (I.ExpCond, [cond, true, false]) -> case cond of
    Bool True  -> return true
    Bool False -> return false
    _ -> raise $ "evalOp: ExpCond: expected true or false, got: " ++ show cond
  _ -> raise $ "evalOp: can't handle: " ++ show (op, args)
-- | Evaluate an initializer for a value of the given type.
--
-- Zero-initialization of arrays and structs is delegated to the array and
-- struct cases with no explicit initializers. Arrays are padded with zero
-- initializers, or truncated, to the declared length. Structs start with
-- every field zeroed and then apply the explicit field initializers in
-- order.
evalInit :: I.Type -> I.Init -> Eval Value
evalInit ty ini = case ini of
  I.InitZero -> case ty of
    I.TyArr _ _  -> evalInit ty (I.InitArray [] True)
    I.TyStruct _ -> evalInit ty (I.InitStruct [])
    _            -> return (mkVal ty 0)
  -- An expression initializer carries its own type, which is the one used.
  I.InitExpr exprTy expr -> evalExpr exprTy expr
  I.InitArray elemInits _ -> case ty of
    I.TyArr len elemTy -> do
      elems <- mapM (evalInit elemTy) (take len (elemInits ++ repeat I.InitZero))
      return (Array (Seq.fromList elems))
    _ -> raise $ "evalInit: InitArray: unexpected type: " ++ show ty
  I.InitStruct fieldInits -> case ty of
    I.TyStruct name -> do
      I.Struct _ fields <- lookupStruct name
      zeroed <- foldM zeroField Map.empty fields
      filled <- foldM (setField fields) zeroed fieldInits
      return (Struct filled)
    _ -> raise $ "evalInit: InitStruct: unexpected type: " ++ show ty
  where
    -- Add a zero-initialized entry for one declared field.
    zeroField acc (I.Typed fieldTy fld) = do
      val <- evalInit fieldTy I.InitZero
      return (Map.insert fld val acc)
    -- Overwrite one field with its explicit initializer, evaluated at the
    -- field's declared type.
    setField fields acc (fld, fieldInit) = do
      val <- evalInit (lookupTyped fld fields) fieldInit
      return (Map.insert fld val acc)
-- | Return the type attached to the first entry whose payload equals the
-- key. A missing key is a fatal error.
lookupTyped :: (Show a, Eq a) => a -> [I.Typed a] -> I.Type
lookupTyped key entries =
  case [ t | I.Typed t x <- entries, key == x ] of
    (t : _) -> t
    []      -> error $ "lookupTyped: couldn't find: " ++ show key
-- | Extract the symbol from a variable, whatever kind of variable it is.
varSym :: I.Var -> I.Sym
varSym var = case var of
  I.VarName sym     -> sym
  I.VarInternal sym -> sym
  I.VarLitName sym  -> sym
-- | Extract the symbol from a name, going through 'varSym' for variables.
nameSym :: I.Name -> I.Sym
nameSym name = case name of
  I.NameSym sym -> sym
  I.NameVar var -> varSym var
-- | Repeat an action until the condition reports 'True'. The condition is
-- checked before every iteration, so the action may never run.
untilM :: Monad m => m Bool -> m () -> m ()
untilM done doThis = loop
  where
    loop = do
      finished <- done
      if finished
        then return ()
        else doThis >> loop