module Data.SBV.Plugin.Env (buildFunEnv, buildTCEnv, buildSpecialEnv) where
import GhcPlugins
import GHC.Types
import qualified Data.Map as M
import qualified Language.Haskell.TH as TH
import Data.Int
import Data.Word
import Data.Bits
import Data.Maybe (fromMaybe)
import Data.Ratio
import qualified Data.SBV as S hiding (proveWith, proveWithAny)
import qualified Data.SBV.Dynamic as S
import Data.SBV.Plugin.Common
-- | Build the environment mapping GHC type constructors, together with the
-- type constructors they are applied to, to the SBV kind used to model them.
--
-- The 'Int' argument is the bit-width used for the 'Int' type; every other
-- bounded type listed here has a fixed width.
--
-- If one of the listed Template Haskell names cannot be resolved to a GHC
-- name, this fails with an error that names the offending type.
buildTCEnv :: Int -> CoreM (M.Map (TyCon, [TyCon]) S.Kind)
buildTCEnv isz = do
  xs <- mapM grabTyCon basics
  ys <- mapM grabTyApp apps
  return $ M.fromList $ xs ++ ys
  where
    -- Resolve a Template Haskell type name to its GHC type constructor.
    -- An explicit case (rather than a failable `Just fn <-` bind) keeps this
    -- independent of CoreM's MonadFail instance and gives a useful message.
    grab x = do
      mbName <- thNameToGhcName x
      case mbName of
        Just fn -> lookupTyCon fn
        Nothing -> error $ "sbvPlugin: buildTCEnv: cannot resolve type name " ++ show x

    -- A plain type constructor is an application to no argument types.
    grabTyCon (x, k) = grabTyApp (x, [], k)

    grabTyApp (x, argNames, k) = do
      fn   <- grab x
      args <- mapM grab argNames
      return ((fn, args), k)

    -- Unapplied types with a direct SBV counterpart.
    basics = [ (''Bool,    S.KBool)
             , (''Integer, S.KUnbounded)
             , (''Float,   S.KFloat)
             , (''Double,  S.KDouble)
             , (''Int,     S.KBounded True  isz)
             , (''Int8,    S.KBounded True  8)
             , (''Int16,   S.KBounded True  16)
             , (''Int32,   S.KBounded True  32)
             , (''Int64,   S.KBounded True  64)
             , (''Word8,   S.KBounded False 8)
             , (''Word16,  S.KBounded False 16)
             , (''Word32,  S.KBounded False 32)
             , (''Word64,  S.KBounded False 64)
             ]

    -- Applied types: 'Ratio' 'Integer' is modelled as an SBV real.
    apps = [ (''Ratio, [''Integer], S.KReal) ]
-- | Build the environment mapping each (GHC identifier, SBV kind) pair listed
-- in 'symFuncs' to the symbolic implementation of that identifier at that
-- kind.
--
-- If a listed Template Haskell name cannot be resolved to a GHC name, this
-- fails with an error that names the offending function.
buildFunEnv :: CoreM (M.Map (Id, S.Kind) Val)
buildFunEnv = M.fromList `fmap` mapM grabVar symFuncs
  where
    -- An explicit case (rather than a failable `Just fn <-` bind) keeps this
    -- independent of CoreM's MonadFail instance and gives a useful message.
    grabVar (n, k, sfn) = do
      mbName <- thNameToGhcName n
      case mbName of
        Just fn -> do
          f <- lookupId fn
          return ((f, k), sfn)
        Nothing -> error $ "sbvPlugin: buildFunEnv: cannot resolve function name " ++ show n
-- | Build the environment of specially-handled identifiers, each mapped to
-- its symbolic interpretation:
--
--   * the boxing constructors @F#@, @D#@, @I#@ and @W#@, interpreted as the
--     identity on a symbolic value of the matching kind;
--   * the 'Bool' constructors 'True' and 'False';
--   * the boolean connectives '&&', '||' and 'not'.
--
-- The 'Int' argument is the bit-width used for the @I#@ and @W#@ kinds.
--
-- If a listed Template Haskell name cannot be resolved to a GHC name, this
-- fails with an error that names the offending identifier.
buildSpecialEnv :: Int -> CoreM (M.Map Id Val)
buildSpecialEnv wsz = M.fromList `fmap` mapM grabVar basics
  where
    -- An explicit case (rather than a failable `Just fn <-` bind) keeps this
    -- independent of CoreM's MonadFail instance and gives a useful message.
    grabVar (n, sfn) = do
      mbName <- thNameToGhcName n
      case mbName of
        Just fn -> do
          f <- lookupId fn
          return (f, sfn)
        Nothing -> error $ "sbvPlugin: buildSpecialEnv: cannot resolve name " ++ show n

    basics = [ ('F#,    Func (S.KFloat,  Nothing) (return . Base))
             , ('D#,    Func (S.KDouble, Nothing) (return . Base))
             , ('I#,    Func (S.KBounded True  wsz, Nothing) (return . Base))
             , ('W#,    Func (S.KBounded False wsz, Nothing) (return . Base))
             , ('True,  Base S.svTrue)
             , ('False, Base S.svFalse)
             , ('(&&),  lift2 S.KBool S.svAnd)
             , ('(||),  lift2 S.KBool S.svOr)
             , ('not,   lift1 S.KBool S.svNot)
             ]
-- | Symbolic interpretations of overloaded Haskell functions, keyed by the
-- function's Template Haskell name and the SBV kind it is used at.
--
-- Each operator is instantiated once per kind in the relevant kind family:
--
--   * '==' and '/=' at every supported kind, including 'Bool';
--   * unary and binary arithmetic, and ordering comparisons, at every
--     numeric kind (floats, doubles, reals, unbounded integers, bit-vectors);
--   * 'fromInteger' at the non-floating numeric kinds;
--   * 'div', 'quot' and 'rem' at the integral kinds;
--   * bitwise '.&.', '.|.' and 'xor' at the fixed-width bit-vector kinds.
symFuncs :: [(TH.Name, S.Kind, Val)]
symFuncs =
     -- equality
     [(op, k, lift2 k sOp)  | k <- allKinds,      (op, sOp) <- eqOps]
     -- arithmetic
  ++ [(op, k, lift1 k sOp)  | k <- arithKinds,    (op, sOp) <- unaryOps]
  ++ [(op, k, lift2 k sOp)  | k <- arithKinds,    (op, sOp) <- binaryOps]
     -- literal conversion; the argument of fromInteger is an unbounded integer
  ++ [(op, k, lift1Int sOp) | k <- integerKinds,  (op, sOp) <- [('fromInteger, S.svInteger k)]]
     -- ordering comparisons
  ++ [(op, k, lift2 k sOp)  | k <- arithKinds,    (op, sOp) <- compOps]
     -- integral division
  ++ [(op, k, lift2 k sOp)  | k <- integralKinds, (op, sOp) <- divOps]
     -- bitwise operations
  ++ [(op, k, lift2 k sOp)  | k <- bvKinds,       (op, sOp) <- bvBinOps]
  where
    -- Fixed-width bit-vectors: unsigned and signed, 8 to 64 bits.
    bvKinds       = [S.KBounded s sz | s <- [False, True], sz <- [8, 16, 32, 64]]
    integralKinds = S.KUnbounded : bvKinds
    -- Despite the name, this also contains reals: it is the set of kinds at
    -- which 'fromInteger' is interpreted.
    integerKinds  = S.KReal : integralKinds
    floatKinds    = [S.KFloat, S.KDouble]
    arithKinds    = floatKinds ++ integerKinds
    allKinds      = S.KBool : arithKinds

    eqOps     = [ ('(==), S.svEqual)
                , ('(/=), S.svNotEqual)
                ]

    unaryOps  = [ ('abs,    S.svAbs)
                , ('negate, S.svUNeg)
                ]

    -- Must quote the subtraction operator '(-)'; '() would name the unit
    -- data constructor instead.
    binaryOps = [ ('(+), S.svPlus)
                , ('(-), S.svMinus)
                , ('(*), S.svTimes)
                , ('(/), S.svDivide)
                ]

    compOps   = [ ('(<),  S.svLessThan)
                , ('(>),  S.svGreaterThan)
                , ('(<=), S.svLessEq)
                , ('(>=), S.svGreaterEq)
                ]

    -- NOTE(review): 'div' is mapped to svDivide rather than a dedicated
    -- floor-division primitive; confirm it agrees with Haskell's 'div' on
    -- negative operands at the signed kinds.
    divOps    = [ ('div,  S.svDivide)
                , ('quot, S.svQuot)
                , ('rem,  S.svRem)
                ]

    bvBinOps  = [ ('(.&.), S.svAnd)
                , ('(.|.), S.svOr)
                , ('xor,   S.svXOr)
                ]
-- | Wrap a unary symbolic operation as a 'Val' function whose argument is
-- declared at kind @k@; the operation's result is returned as a 'Base' value.
lift1 :: S.Kind -> (S.SVal -> S.SVal) -> Val
lift1 k f = Func (k, Nothing) (\x -> return (Base (f x)))
-- | Wrap a function of a concrete 'Integer' as a 'Val' taking an unbounded
-- integer argument. The symbolic argument must be a constant: if
-- 'S.svAsInteger' cannot extract an 'Integer' from it, evaluation fails
-- with an error that shows the offending value.
lift1Int :: (Integer -> S.SVal) -> Val
lift1Int f = Func (S.KUnbounded, Nothing) apply
  where
    apply sv = return (Base (f (asInteger sv)))

    asInteger sv = case S.svAsInteger sv of
                     Just n  -> n
                     Nothing -> error ("Cannot extract an integer from value: " ++ show sv)
-- | Wrap a binary symbolic operation as a curried 'Val': a function of the
-- first argument (at kind @k@) returning a function of the second argument
-- (also at kind @k@), which yields the operation's result as a 'Base' value.
lift2 :: S.Kind -> (S.SVal -> S.SVal -> S.SVal) -> Val
lift2 k f = Func (k, Nothing) outer
  where
    outer a   = return (Func (k, Nothing) (inner a))
    inner a b = return (Base (f a b))