module Data.SBV.Plugin.Env (buildTCEnv, buildFunEnv, buildDests, buildSpecials, uninterestingTypes) where
import GhcPlugins
import GHC.Prim
import GHC.Types
import qualified Data.Map as M
import qualified Language.Haskell.TH as TH
import Data.Int
import Data.Word
import Data.Bits
import Data.Maybe (fromMaybe)
import Data.Ratio
import qualified Data.SBV as S hiding (proveWith, proveWithAny)
import qualified Data.SBV.Dynamic as S
import Data.SBV.Plugin.Common
-- | Arities of the tuple data constructors handled by this module:
-- every size from 2 up to and including 15.
supportTupleSizes :: [Int]
supportTupleSizes = enumFromTo 2 15
-- | Build the map from a type constructor (together with the list of type
-- constructors it is applied to) to the SBV kind that represents it.
--
-- The argument @wsz@ is the size used for the kinds of 'Int', 'Word' and
-- their unboxed counterparts; all other sized types use a fixed size.
buildTCEnv :: Int -> CoreM (M.Map (TyCon, [TyCon]) S.Kind)
buildTCEnv wsz = do
    plain   <- mapM (\(nm, k) -> resolve (nm, [], k)) nullary
    applied <- mapM resolve saturated
    return (M.fromList (plain ++ applied))
  where
    -- Turn a Template Haskell type name into the GHC 'TyCon' it denotes.
    toTyCon = grabTH lookupTyCon

    -- Look up the head constructor first, then each argument in order.
    resolve (nm, argNms, k) = do
        hd   <- toTyCon nm
        args <- mapM toTyCon argNms
        return ((hd, args), k)

    signed   = S.KBounded True
    unsigned = S.KBounded False

    -- Types that take no arguments, grouped by the kind they are mapped to.
    nullary = concatMap (\(k, nms) -> [(nm, k) | nm <- nms])
        [ (S.KBool,      [''Bool])
        , (S.KUnbounded, [''Integer])
        , (S.KFloat,     [''Float, ''Float#])
        , (S.KDouble,    [''Double, ''Double#])
        , (signed wsz,   [''Int, ''Int#])
        , (signed 8,     [''Int8])
        , (signed 16,    [''Int16])
        , (signed 32,    [''Int32, ''Int32#])
        , (signed 64,    [''Int64, ''Int64#])
        , (unsigned wsz, [''Word, ''Word#])
        , (unsigned 8,   [''Word8])
        , (unsigned 16,  [''Word16])
        , (unsigned 32,  [''Word32, ''Word32#])
        , (unsigned 64,  [''Word64, ''Word64#])
        ]

    -- Type constructors that are only recognized at specific arguments.
    saturated = [ (''Ratio, [''Integer], S.KReal) ]
-- | Build the map from (identifier, kind) pairs to their symbolic values.
-- The entries are those of 'basicFuncs' followed by those of 'symFuncs',
-- each converted from a Template Haskell name to a GHC 'Id' by 'thToGHC'.
buildFunEnv :: Int -> CoreM (M.Map (Id, SKind) Val)
buildFunEnv wsz = do
    entries <- mapM thToGHC (basicFuncs wsz ++ symFuncs wsz)
    return (M.fromList entries)
-- | Entries that are registered at a single, fixed kind.
--
-- * The boxing constructors @F#@, @D#@, @I#@ and @W#@ become functions that
--   hand back their argument unchanged; @I#@ and @W#@ use the size @wsz@.
-- * @True@ and @False@ become the symbolic boolean constants.
-- * @(&&)@, @(||)@ and @not@ become the corresponding symbolic operations.
basicFuncs :: Int -> [(TH.Name, SKind, Val)]
basicFuncs wsz = boxers ++ boolConsts ++ boolOps
  where
    boolK = KBase S.KBool

    boxers =
        [ (con, tlift1 (KBase k), Func Nothing return)
        | (con, k) <- [ ('F#, S.KFloat)
                      , ('D#, S.KDouble)
                      , ('I#, S.KBounded True wsz)
                      , ('W#, S.KBounded False wsz)
                      ]
        ]

    boolConsts =
        [ ('True,  boolK, Base S.svTrue)
        , ('False, boolK, Base S.svFalse)
        ]

    boolOps =
        [ ('(&&), tlift2 boolK, lift2 S.svAnd)
        , ('(||), tlift2 boolK, lift2 S.svOr)
        , ('not,  tlift1 boolK, lift1 S.svNot)
        ]
-- | Entries for overloaded functions: each operator is registered once per
-- kind it is supported at, so the result contains one triple per
-- (operator, kind) combination.
--
-- The argument @wsz@ is only used for the shift\/rotate operators, whose
-- second argument is a signed bit-vector of that size (see 'tlift2ShRot').
symFuncs :: Int -> [(TH.Name, SKind, Val)]
symFuncs wsz =
       -- equality and inequality, at every kind
       [ (op, tlift2Bool (KBase k), lift2 sOp) | k <- allKinds, (op, sOp) <- [('(==), S.svEqual), ('(/=), S.svNotEqual)] ]

       -- arithmetic
    ++ [ (op, tlift1 (KBase k), lift1 sOp) | k <- arithKinds, (op, sOp) <- unaryOps ]
    ++ [ (op, tlift2 (KBase k), lift2 sOp) | k <- arithKinds, (op, sOp) <- binaryOps ]

       -- conversion from an Integer argument
    ++ [ (op, KFun (KBase S.KUnbounded) (KBase k), lift1Int sOp) | k <- integerKinds, (op, sOp) <- [('fromInteger, S.svInteger k)] ]

       -- comparisons
    ++ [ (op, tlift2Bool (KBase k), lift2 sOp) | k <- arithKinds, (op, sOp) <- compOps ]

       -- division-like operators on the integral kinds
       -- NOTE(review): 'quot and 'rem are also listed in binaryOps below, so they
       -- appear twice for the integral kinds, with the same implementation.
    ++ [ (op, tlift2 (KBase k), lift2 sOp) | k <- integralKinds, (op, sOp) <- [('div, S.svDivide), ('quot, S.svQuot), ('rem, S.svRem)] ]

       -- bit-vector operators
    ++ [ (op, tlift2 (KBase k), lift2 sOp) | k <- bvKinds, (op, sOp) <- bvBinOps ]
    ++ [ (op, tlift2ShRot wsz (KBase k), lift2 sOp) | k <- bvKinds, (op, sOp) <- bvShiftRots ]
  where
    -- Bounded kinds: unsigned and signed, at sizes 8, 16, 32 and 64
    bvKinds       = [S.KBounded s sz | s <- [False, True], sz <- [8, 16, 32, 64]]

    -- Unbounded integers plus the bounded kinds
    integralKinds = S.KUnbounded : bvKinds

    -- The integral kinds plus reals
    integerKinds  = S.KReal : integralKinds

    -- Floating-point kinds
    floatKinds    = [S.KFloat, S.KDouble]

    -- Every kind that supports arithmetic
    arithKinds    = floatKinds ++ integerKinds

    -- Every kind handled here
    allKinds      = S.KBool : arithKinds

    -- Unary operators
    unaryOps      = [ ('abs,        S.svAbs)
                    , ('negate,     S.svUNeg)
                    , ('complement, S.svNot)
                    ]

    -- Binary operators. Subtraction must be quoted as '(-): the quote '()
    -- would name the unit data constructor instead, leaving (-) unregistered.
    binaryOps     = [ ('(+),  S.svPlus)
                    , ('(-),  S.svMinus)
                    , ('(*),  S.svTimes)
                    , ('(/),  S.svDivide)
                    , ('quot, S.svQuot)
                    , ('rem,  S.svRem)
                    ]

    -- Comparison operators
    compOps       = [ ('(<),  S.svLessThan)
                    , ('(>),  S.svGreaterThan)
                    , ('(<=), S.svLessEq)
                    , ('(>=), S.svGreaterEq)
                    ]

    -- Binary bit-wise operators
    bvBinOps      = [ ('(.&.), S.svAnd)
                    , ('(.|.), S.svOr)
                    , ('xor,   S.svXOr)
                    ]

    -- Shifts and rotates
    bvShiftRots   = [ ('shiftL,  S.svShiftLeft)
                    , ('shiftR,  S.svShiftRight)
                    , ('rotateL, S.svRotateLeft)
                    , ('rotateR, S.svRotateRight)
                    ]
-- | Build the map of destructors, keyed by data constructor.
--
-- A destructor receives the value being scrutinized and the binders of the
-- case alternative. It returns a symbolic boolean ('S.svTrue' or 'S.svFalse'
-- for the constructors handled here) and the values to bind to the binders.
-- Covered constructors: @W#@, @I#@, @F#@, @D#@, the tuple constructors of
-- every size in 'supportTupleSizes', @[]@ and @(:)@.
buildDests :: CoreM (M.Map Var (Val -> [(Var, SKind)] -> (S.SVal, [((Var, SKind), Val)])))
buildDests = do
    boxed  <- mapM resolveBoxed ['W#, 'I#, 'F#, 'D#]
    tuples <- mapM resolveTuple supportTupleSizes
    nilD   <- resolveNil
    consD  <- resolveCons
    return (M.fromList (boxed ++ tuples ++ [nilD, consD]))
  where
    -- Single-field constructors: the lone binder is bound to the scrutinee itself.
    resolveBoxed name = do
        con <- grabTH lookupId name
        return (con, unwrap)

    unwrap scrut binders = case binders of
        [bk] -> (S.svTrue, [(bk, scrut)])
        _    -> error $ "Impossible happened: Mistmatched arity case-binder for: " ++ showSDocUnsafe (ppr scrut) ++ ". Expected 1, got: " ++ show (length binders) ++ " arguments."

    -- Tuples: both the value and the binder list must have exactly the given arity.
    resolveTuple arity = do
        con <- grabTH lookupId (TH.tupleDataName arity)
        let split scrut binders = case scrut of
                Tup fields
                  | length fields == arity, length binders == arity
                  -> (S.svTrue, zip binders fields)
                _ -> error $ "Impossible: Tuple-case mismatch: " ++ showSDocUnsafe (ppr (arity, scrut, binders))
        return (con, split)

    -- Empty list: matches only an empty list value with no binders.
    resolveNil = do
        con <- lookupId nilDataConName
        let split scrut binders = case (scrut, binders) of
                (Lst [], []) -> (S.svTrue, [])
                (Lst _, _)   -> (S.svFalse, [])
                _            -> error $ "Impossible: []-case mismatch: " ++ showSDocUnsafe (ppr (scrut, binders))
        return (con, split)

    -- Cons: binds the head and the remaining list to the two binders.
    resolveCons = do
        con <- lookupId consDataConName
        let split scrut binders = case (scrut, binders) of
                (Lst [], _)                -> (S.svFalse, [])
                (Lst (x : rest), [hd, tl]) -> (S.svTrue, [(hd, x), (tl, Lst rest)])
                _                          -> error $ "Impossible: (:)-case mismatch: " ++ showSDocUnsafe (ppr (scrut, binders))
        return (con, split)
-- | Types of the identifiers listed here; currently only @void#@.
uninterestingTypes :: CoreM [Type]
uninterestingTypes = do
    ids <- mapM (grabTH lookupId) ['void#]
    return (map varType ids)
-- | Build the lookup functions for identifiers that get special treatment:
-- equality\/inequality, tuple constructors and list constructors.
--
-- Each field of the resulting 'Specials' is a lookup over a small association
-- list, so it yields 'Nothing' for any identifier that is not in that list.
buildSpecials :: CoreM Specials
buildSpecials = do
    -- (==) and (/=) are handled by 'liftEq' on top of the base comparisons.
    isEq <- do
        eq  <- grabTH lookupId '(==)
        neq <- grabTH lookupId '(/=)
        let choose = [(eq, liftEq S.svEqual), (neq, liftEq S.svNotEqual)]
        return (`lookup` choose)

    -- An n-tuple constructor skips any type arguments, then collects exactly
    -- n values and returns them as a 'Tup' in the order they were applied.
    isTup <- do
        let mkTup n = Func Nothing g
              where
                g (Typ _) = return $ Func Nothing g
                -- First value seen: n - 1 more are still expected.
                g v       = h (n - 1) [v]
                -- Values are accumulated in reverse, hence the final 'reverse'.
                h 0 sofar = return $ Tup (reverse sofar)
                h i sofar = return $ Func Nothing $ \v -> h (i - 1) (v : sofar)
        ts <- mapM (grabTH lookupId . TH.tupleDataName) supportTupleSizes
        let choose = zip ts (map mkTup supportTupleSizes)
        return (`lookup` choose)

    -- [] is the empty 'Lst'; (:) skips any type arguments, then takes a value
    -- and a 'Lst' and prepends the value.
    isLst <- do
        nil  <- lookupId nilDataConName
        cons <- lookupId consDataConName
        let snil   = Lst []
            scons  = Func Nothing g
              where
                g (Typ _)    = return $ Func Nothing g
                g v          = return $ Func Nothing (k v)
                k v (Lst xs) = return (Lst (v : xs))
                k v a        = error $ "Impossible: (:) received incompatible arguments: " ++ showSDocUnsafe (ppr (v, a))
            choose = [(nil, snil), (cons, scons)]
        return (`lookup` choose)

    return Specials{ isEquality = isEq
                   , isTuple    = isTup
                   , isList     = isLst
                   }
-- | Kind of a binary predicate: two arguments of kind @k@, boolean result.
tlift2Bool :: SKind -> SKind
tlift2Bool k = KFun k (KFun k boolK)
  where
    boolK = KBase S.KBool
-- | Kind of a unary function whose argument and result both have kind @k@.
tlift1 :: SKind -> SKind
tlift1 k = k `KFun` k
-- | Kind of a binary function whose two arguments and result all have kind @k@.
tlift2 :: SKind -> SKind
tlift2 k = KFun k (KFun k k)
-- | Kind of a shift\/rotate function: the first argument and the result have
-- kind @k@, while the second argument is a signed bounded value of size @wsz@.
tlift2ShRot :: Int -> SKind -> SKind
tlift2ShRot wsz k = KFun k (KFun amountK k)
  where
    amountK = KBase (S.KBounded True wsz)
-- | Turn a function on concrete integers into a 'Val' function.
--
-- The argument must be a 'Base' value from which 'S.svAsInteger' can extract
-- an integer; otherwise an error is raised.
lift1Int :: (Integer -> S.SVal) -> Val
lift1Int f = Func Nothing apply
  where
    apply arg = case arg of
        Base i -> return (Base (f (concrete i)))
        _      -> error "Impossible happened: lift1Int received non-base argument!"

    -- Kept lazy: the extraction error only fires if the integer is demanded.
    concrete i = fromMaybe (error ("Cannot extract an integer from value: " ++ show i)) (S.svAsInteger i)
-- | Turn a unary function on symbolic values into a 'Val' function.
--
-- A single leading type argument ('Typ') is accepted and skipped; the value
-- argument itself must be a 'Base', otherwise an error is raised.
lift1 :: (S.SVal -> S.SVal) -> Val
lift1 f = Func Nothing skipType
  where
    skipType arg = case arg of
        Typ _ -> return (Func Nothing apply)
        _     -> apply arg

    apply arg = case arg of
        Base a -> return (Base (f a))
        _      -> error $ "Impossible happened: lift1 received non-base argument: " ++ showSDocUnsafe (ppr arg)
-- | Turn a binary function on symbolic values into a curried 'Val' function.
--
-- A single leading type argument ('Typ') is accepted and skipped; both value
-- arguments must be 'Base' values, otherwise an error is raised.
lift2 :: (S.SVal -> S.SVal -> S.SVal) -> Val
lift2 f = Func Nothing skipType
  where
    skipType arg = case arg of
        Typ _ -> return (Func Nothing takeFirst)
        _     -> takeFirst arg

    takeFirst arg = case arg of
        Base a -> return (Func Nothing (takeSecond a))
        _      -> error $ "Impossible happened: lift2 received non-base argument (h): " ++ showSDocUnsafe (ppr arg)

    takeSecond a arg = case arg of
        Base b -> return (Base (f a b))
        _      -> error $ "Impossible happened: lift2 received non-base argument (k): " ++ showSDocUnsafe (ppr arg)
-- | Turn a base equality test into a curried 'Val' function over arbitrary
-- values, delegating the comparison itself to 'liftEqVal'.
--
-- Any number of leading type arguments ('Typ') are skipped.
liftEq :: (S.SVal -> S.SVal -> S.SVal) -> Val
liftEq baseEq = Func Nothing takeLhs
  where
    takeLhs arg = case arg of
        Typ _ -> return (Func Nothing takeLhs)
        lhs   -> return (Func Nothing (\rhs -> return (Base (liftEqVal baseEq lhs rhs))))
-- | Replace the Template Haskell name in a triple by the GHC 'Id' it refers
-- to, regrouping the result as a key\/value pair.
thToGHC :: (TH.Name, a, b) -> CoreM ((Id, a), b)
thToGHC (n, k, sfn) = attach <$> grabTH lookupId n
  where
    attach f = ((f, k), sfn)
-- | Convert a Template Haskell name to a GHC name and pass it to the given
-- lookup function. Raises an error if no GHC name can be found.
grabTH :: (Name -> CoreM b) -> TH.Name -> CoreM b
grabTH f n = thNameToGhcName n >>= maybe notFound f
  where
    notFound = error $ "[SBV] Impossible happened, while trying to locate GHC name for: " ++ show n