module Language.SequentCore.Syntax (
Term(..), Cont(..), Command(..), Bind(..), Alt(..), AltCon(..), Expr(..), Program, ContId,
SeqCoreTerm, SeqCoreCont, SeqCoreCommand, SeqCoreBind, SeqCoreBndr,
SeqCoreAlt, SeqCoreExpr, SeqCoreProgram,
mkCommand, mkCompute, addLets, addNonRec,
lambdas, collectArgs, collectTypeArgs, collectTypeAndOtherArgs, collectArgsUpTo,
partitionTypes, isLambda,
isValueArg, isTypeTerm, isCoTerm, isErasedTerm, isRuntimeTerm, isProperTerm,
isTrivial, isTrivialTerm, isTrivialCont, isReturnCont,
commandAsSaturatedCall, asSaturatedCall, asValueCommand,
flattenBind, flattenBinds, bindersOf, bindersOfBinds,
termArity, termType, contType,
termIsBottom, commandIsBottom,
needsCaseBinding,
termOkForSpeculation, commandOkForSpeculation, contOkForSpeculation,
termOkForSideEffects, commandOkForSideEffects, contOkForSideEffects,
termIsCheap, contIsCheap, commandIsCheap,
termIsExpandable, contIsExpandable, commandIsExpandable,
isContId, asContId, Language.SequentCore.WiredIn.mkContTy, contTyArg,
(=~=), AlphaEq(..), AlphaEnv, HasId(..)
) where
import Language.SequentCore.Pretty ()
import Language.SequentCore.WiredIn
import Coercion ( Coercion, coercionType, coercionKind )
import CoreSyn ( AltCon(..), Tickish, tickishCounts, isRuntimeVar
, isEvaldUnfolding )
import DataCon ( DataCon, dataConRepType, dataConTyCon )
import Id ( Id, isDataConWorkId, isDataConWorkId_maybe, isConLikeId
, idArity, idType, idDetails, idUnfolding
, setIdType, isBottomingId )
import IdInfo ( IdDetails(..) )
import Literal ( Literal, isZeroLit, litIsTrivial, literalType )
import Maybes ( orElse )
import Outputable
import Pair
import PrimOp ( PrimOp(..), primOpOkForSpeculation, primOpOkForSideEffects
, primOpIsCheap )
import TyCon
import Type ( Type, KindOrType )
import qualified Type
import TysPrim
import Var ( Var, isId )
import VarEnv
import Control.Exception ( assert )
import Data.Maybe
-- | A Sequent Core term: something that produces a value.
data Term b = Lit Literal                -- ^ A literal.
            | Var Id                     -- ^ A variable occurrence.
            | Lam [b] b (Command b)      -- ^ A lambda: its binders, the binder of
                                         --   the continuation its result is
                                         --   returned to, and its body.
            | Cons DataCon [Term b]      -- ^ A data constructor applied to its
                                         --   arguments (leading type arguments
                                         --   included).
            | Compute b (Command b)      -- ^ Binds a continuation variable and
                                         --   runs a command; the term's type is
                                         --   that continuation's argument type.
            | Type Type                  -- ^ A type, used as an argument.
            | Coercion Coercion          -- ^ A coercion, used as an argument.
            | Cont (Cont b)              -- ^ A continuation used as a term.
-- | A continuation: describes what is done with the value a term produces.
data Cont b = App (Term b) (Cont b)      -- ^ Apply the value to an argument,
                                         --   then continue with the rest.
            | Case b [Alt b]             -- ^ Bind the value to the case binder
                                         --   and select an alternative.
            | Cast Coercion (Cont b)     -- ^ Cast the value, then continue.
            | Tick (Tickish Id) (Cont b) -- ^ A tick annotation on the rest of
                                         --   the continuation.
            | Return ContId              -- ^ Return the value to a
                                         --   continuation variable.

-- | Continuation variables are ordinary 'Id's (see 'isContId' / 'asContId').
type ContId = Id
-- | A command: let bindings scoping over a term that is cut against (run
-- with) a continuation.
data Command b = Command {
    cmdLet  :: [Bind b]  -- ^ Bindings in scope for both the term and the
                         --   continuation.
  , cmdTerm :: Term b    -- ^ The term producing a value.
  , cmdCont :: Cont b    -- ^ The continuation consuming that value.
  }
-- | A binding group: one non-recursive binding or a recursive group.
data Bind b = NonRec b (Term b)
            | Rec [(b, Term b)]

-- | A case alternative: the pattern's constructor, the variables it binds,
-- and the command to run when it matches.
data Alt b = Alt AltCon [b] (Command b)
-- | Any one of the three syntactic sorts, tagged with its sort.
data Expr a = T { unT :: Term a }      -- ^ A term.
            | C { unC :: Command a }   -- ^ A command.
            | K { unK :: Cont a }      -- ^ A continuation.

-- | A program is a list of binding groups.
type Program a = [Bind a]
-- | Sequent Core syntax specialised to plain 'Var' binders.
type SeqCoreBndr    = Var
type SeqCoreTerm    = Term Var
type SeqCoreCont    = Cont Var
type SeqCoreCommand = Command Var
type SeqCoreBind    = Bind Var
type SeqCoreAlt     = Alt Var
type SeqCoreExpr    = Expr Var
type SeqCoreProgram = Program Var
-- | Smart constructor for 'Command'.
--
-- Two local normalisations are applied before building the record:
--
--   * A data-constructor worker cut against a continuation supplying exactly
--     as many arguments as the worker's type has foralls plus arrows becomes
--     a saturated 'Cons' term cut against the remaining continuation.
--
--   * A 'Compute' whose body immediately returns to its own continuation
--     binder is inlined: its lets are appended after the outer lets and its
--     term is cut against the outer continuation.
--
-- NOTE(review): the second rewrite widens the scope of the inner lets to
-- cover the outer continuation; this presumably relies on binders being
-- unique so that nothing free in that continuation is captured.
mkCommand :: HasId b => [Bind b] -> Term b -> Cont b -> Command b
mkCommand binds term cont
  | Var f <- term
  , Just ctor <- isDataConWorkId_maybe f
  , Just (args, cont') <- saturatedCtorArgs f
  = mkCommand binds (Cons ctor args) cont'
  | Compute kbndr inner <- term
  , Return kid <- cmdCont inner
  , identifier kbndr == kid
  = mkCommand (binds ++ cmdLet inner) (cmdTerm inner) cont
  | otherwise
  = Command { cmdLet = binds, cmdTerm = term, cmdCont = cont }
  where
    -- Peel off exactly as many arguments as the worker's type demands
    -- (type arguments plus value arguments); fail if fewer are available.
    saturatedCtorArgs f
      | length collected == needed = Just (collected, rest)
      | otherwise                  = Nothing
      where
        (tyVars, monoTy)  = Type.splitForAllTys (idType f)
        (argTys, _)       = Type.splitFunTys monoTy
        needed            = length tyVars + length argTys
        (collected, rest) = collectArgsUpTo needed cont
-- | Smart constructor for 'Compute'. If the command does nothing but return a
-- term to the continuation bound here (see 'asValueCommand'), that term is
-- returned directly instead of being wrapped.
mkCompute :: HasId b => b -> Command b -> Term b
mkCompute k comm =
  case asValueCommand (identifier k) comm of
    Just val -> val
    Nothing  -> Compute k comm
-- | Prepends bindings in front of a command's existing lets. Adding no
-- bindings returns the command unchanged.
addLets :: [Bind b] -> Command b -> Command b
addLets binds comm
  | null binds = comm
  | otherwise  = comm { cmdLet = binds ++ cmdLet comm }
-- | Binds @bndr@ to @rhs@ around a command. When 'needsCaseBinding' reports
-- that the binder's type and right-hand side require it, the binding is a
-- case with a single DEFAULT alternative; otherwise it is a non-recursive let.
addNonRec :: HasId b => b -> Term b -> Command b -> Command b
addNonRec bndr rhs comm =
  if mustCase
    then mkCommand [] rhs (Case bndr [Alt DEFAULT [] comm])
    else addLets [NonRec bndr rhs] comm
  where
    mustCase = needsCaseBinding (idType (identifier bndr)) rhs
-- | Splits a lambda into its binders and (continuation binder, body). A
-- non-lambda term yields no binders and 'Nothing'.
lambdas :: Term b -> ([b], Maybe (b, Command b))
lambdas term =
  case term of
    Lam xs k body -> (xs, Just (k, body))
    _             -> ([], Nothing)
-- | Collects the arguments of the leading chain of 'App' frames, in order,
-- together with the first continuation that is not an 'App'.
collectArgs :: Cont b -> ([Term b], Cont b)
collectArgs = go
  where
    go (App arg rest) = let (args, final) = go rest in (arg : args, final)
    go other          = ([], other)
-- | Like 'collectArgs', but only takes the leading 'Type' arguments, stopping
-- at the first frame that is not a type application.
collectTypeArgs :: Cont b -> ([KindOrType], Cont b)
collectTypeArgs cont =
  case cont of
    App (Type ty) rest ->
      let (tys, final) = collectTypeArgs rest
      in  (ty : tys, final)
    _ -> ([], cont)
-- | Splits the leading application frames into the initial run of type
-- arguments, all remaining arguments (which may still contain types), and
-- the continuation after them.
collectTypeAndOtherArgs :: Cont b -> ([KindOrType], [Term b], Cont b)
collectTypeAndOtherArgs cont = (tys, others, final)
  where
    (tys, afterTys) = collectTypeArgs cont
    (others, final) = collectArgs afterTys
-- | Like 'collectArgs', but collects at most @n@ arguments. Returns the
-- arguments taken (fewer than @n@ if the 'App' chain is shorter) and the
-- remaining continuation. A non-positive @n@ collects nothing.
collectArgsUpTo :: Int -> Cont b -> ([Term b], Cont b)
collectArgsUpTo n k
  | n <= 0
  = ([], k)
collectArgsUpTo n (App v k)
  = (v : vs, k')
  where (vs, k') = collectArgsUpTo (n - 1) k
collectArgsUpTo _ k
  = ([], k)
-- | Splits off the leading run of 'Type' terms (unwrapped) from a list of
-- terms. Only the prefix is taken: a 'Type' that appears after a non-type
-- term stays in the second component.
partitionTypes :: [Term b] -> ([KindOrType], [Term b])
partitionTypes terms =
  case terms of
    Type ty : rest ->
      let (tys, others) = partitionTypes rest
      in  (ty : tys, others)
    _ -> ([], terms)
-- | True for a command with no lets whose continuation is a 'Return' and whose
-- term is a lambda.
isLambda :: Command b -> Bool
isLambda comm
  | []        <- cmdLet comm
  , Return {} <- cmdCont comm
  , Lam {}    <- cmdTerm comm
  = True
  | otherwise
  = False
-- | True unless the term is a 'Type' argument (coercions count as values).
isValueArg :: Term b -> Bool
isValueArg term =
  case term of
    Type _ -> False
    _      -> True
-- | True exactly for 'Type' terms.
isTypeTerm :: Term b -> Bool
isTypeTerm term =
  case term of
    Type _ -> True
    _      -> False
-- | True exactly for 'Coercion' terms.
isCoTerm :: Term b -> Bool
isCoTerm term =
  case term of
    Coercion _ -> True
    _          -> False
-- | True for the terms treated as erased: types and coercions.
isErasedTerm :: Term b -> Bool
isErasedTerm term =
  case term of
    Type _     -> True
    Coercion _ -> True
    _          -> False
-- | The complement of 'isErasedTerm'.
isRuntimeTerm :: Term b -> Bool
isRuntimeTerm = not . isErasedTerm
-- | True for terms other than types, coercions and embedded continuations.
isProperTerm :: Term b -> Bool
isProperTerm term =
  case term of
    Type _     -> False
    Coercion _ -> False
    Cont _     -> False
    _          -> True
-- | A command is trivial when it has no lets and both its continuation and
-- its term are trivial.
isTrivial :: HasId b => Command b -> Bool
isTrivial c = and [ null (cmdLet c)
                  , isTrivialCont (cmdCont c)
                  , isTrivialTerm (cmdTerm c)
                  ]
isTrivialTerm :: HasId b => Term b -> Bool
isTrivialTerm (Lit l) = litIsTrivial l
isTrivialTerm (Lam xs _ c)= not (any (isRuntimeVar . identifier) xs) && isTrivial c
isTrivialTerm (Cons _ as) = all isErasedTerm as
isTrivialTerm (Compute _ c) = isTrivial c
isTrivialTerm (Cont cont) = isTrivialCont cont
isTrivialTerm _ = True
-- | A continuation is trivial if it is a 'Return', possibly preceded by casts
-- and applications to erased arguments.
isTrivialCont :: Cont b -> Bool
isTrivialCont cont =
  case cont of
    Return _  -> True
    Cast _ k  -> isTrivialCont k
    App arg k -> isErasedTerm arg && isTrivialCont k
    _         -> False
-- | True exactly for 'Return' continuations.
isReturnCont :: Cont b -> Bool
isReturnCont cont =
  case cont of
    Return _ -> True
    _        -> False
-- | Views a command as a saturated call: its term, plus the arguments and
-- remaining continuation found by 'asSaturatedCall'. Lets are ignored.
commandAsSaturatedCall :: HasId b =>
                          Command b -> Maybe (Term b, [Term b], Cont b)
commandAsSaturatedCall c =
  fmap withHead (asSaturatedCall fun (cmdCont c))
  where
    fun = cmdTerm c
    withHead (args, rest) = (fun, args, rest)
-- | Succeeds when the term has positive 'termArity' and the continuation's
-- leading 'App' chain supplies at least that many arguments. Returns /all/
-- arguments of the chain (not only the first @arity@) and the continuation
-- after them.
asSaturatedCall :: HasId b => Term b -> Cont b -> Maybe ([Term b], Cont b)
asSaturatedCall term cont =
  if arity > 0 && length args >= arity
    then Just (args, rest)
    else Nothing
  where
    arity        = termArity term
    (args, rest) = collectArgs cont
-- | If the command has no lets and simply returns its term to the given
-- continuation variable, yields that term.
asValueCommand :: ContId -> Command b -> Maybe (Term b)
asValueCommand k comm
  | []        <- cmdLet comm
  , Return k' <- cmdCont comm
  , k == k'
  = Just (cmdTerm comm)
  | otherwise
  = Nothing
-- | The (binder, right-hand side) pairs of a binding group.
flattenBind :: Bind b -> [(b, Term b)]
flattenBind bind =
  case bind of
    NonRec bndr rhs -> [(bndr, rhs)]
    Rec pairs       -> pairs
-- | All (binder, right-hand side) pairs of a list of binding groups, in order.
flattenBinds :: [Bind b] -> [(b, Term b)]
flattenBinds binds = [ pair | bind <- binds, pair <- flattenBind bind ]
-- | The binders of a binding group, in order.
bindersOf :: Bind b -> [b]
bindersOf bind =
  case bind of
    NonRec bndr _ -> [bndr]
    Rec pairs     -> [ bndr | (bndr, _) <- pairs ]
-- | The binders of a list of binding groups, in order.
bindersOfBinds :: [Bind b] -> [b]
bindersOfBinds binds = binds >>= bindersOf
-- | The type of a term.
--
-- A lambda's type is built from its binders and the argument type of its
-- continuation binder; a 'Compute' has its continuation binder's argument
-- type; a coercion has the coercion's own type.
--
-- NOTE(review): 'Type' and 'Cont' terms fall back to the 'alphaTy'
-- placeholder; confirm callers never rely on that result.
termType :: HasId b => Term b -> Type
termType (Lit l)       = literalType l
termType (Var x)       = idType x
termType (Lam xs k _)  = Type.mkPiTypes (map identifier xs)
                                        (contTyArg (idType (identifier k)))
termType (Cons con as) = res_ty
  where
    -- Instantiate the constructor's representation type with the leading
    -- type arguments, then take the result after all value arrows.
    (tys, _)    = partitionTypes as
    (_, res_ty) = Type.splitFunTys (dataConRepType con `Type.applyTys` tys)
termType (Compute k _) = contTyArg (idType (identifier k))
-- A coercion argument has the type of the coercion itself (as in GHC's
-- 'exprType').
termType (Coercion co) = coercionType co
termType _other        = alphaTy
-- | The type of the value a continuation accepts.
--
-- NOTE(review): for an 'App' whose argument is a 'Type', this builds an
-- ordinary function type from 'termType' of that argument rather than a
-- forall type; confirm callers avoid that case.
contType :: HasId b => Cont b -> Type
contType cont =
  case cont of
    Return k  -> contTyArg (idType k)
    App arg k -> Type.mkFunTy (termType arg) (contType k)
    Cast co k ->
      -- The cast maps fromTy to toTy; the rest must accept toTy.
      let Pair fromTy toTy = coercionKind co
      in  assert (toTy `Type.eqType` contType k) fromTy
    Case b _  -> idType (identifier b)
    Tick _ k  -> contType k
-- | Arity of a term: the 'idArity' of an 'Id' variable, the binder count of
-- a lambda, and 0 for everything else.
termArity :: HasId b => Term b -> Int
termArity term =
  case term of
    Var x | isId x -> idArity x
    Lam bndrs _ _  -> length bndrs
    _              -> 0
-- | True for a bottoming variable of arity 0, or for a 'Compute' whose
-- command satisfies 'commandIsBottom'.
termIsBottom :: Term b -> Bool
termIsBottom term =
  case term of
    Var x       -> isBottomingId x && idArity x == 0
    Compute _ c -> commandIsBottom c
    _           -> False
-- | True when the command's term is a bottoming variable applied to at least
-- as many value arguments as its arity. Type arguments are not counted,
-- ticks and casts between arguments are skipped, and counting stops at the
-- first other continuation frame. Lets are ignored.
commandIsBottom :: Command b -> Bool
commandIsBottom comm =
    case comm of
      Command { cmdTerm = Var x, cmdCont = cont }
        | isBottomingId x -> countValArgs 0 cont >= idArity x
      _ -> False
  where
    countValArgs :: Int -> Cont b -> Int
    countValArgs n k = n `seq` case k of
      App arg k'
        | isTypeTerm arg -> countValArgs n k'
        | otherwise      -> countValArgs (n + 1) k'
      Tick _ k' -> countValArgs n k'
      Cast _ k' -> countValArgs n k'
      _         -> n
-- | A binding of the given type to @rhs@ must be a case (not a let) when the
-- type is unlifted and @rhs@ is not ok for speculation.
needsCaseBinding :: Type -> Term b -> Bool
needsCaseBinding ty rhs
  | Type.isUnLiftedType ty = not (termOkForSpeculation rhs)
  | otherwise              = False
-- | Speculation / side-effect safety checks for each syntactic sort. Each
-- pair instantiates the shared checker with GHC's corresponding primop
-- predicate.
termOkForSpeculation, termOkForSideEffects :: Term b -> Bool
commandOkForSpeculation, commandOkForSideEffects :: Command b -> Bool
contOkForSpeculation, contOkForSideEffects :: Cont b -> Bool
termOkForSpeculation term    = termOk primOpOkForSpeculation term
termOkForSideEffects term    = termOk primOpOkForSideEffects term
commandOkForSpeculation comm = commOk primOpOkForSpeculation comm
commandOkForSideEffects comm = commOk primOpOkForSideEffects comm
contOkForSpeculation cont    = contOk primOpOkForSpeculation cont
contOkForSideEffects cont    = contOk primOpOkForSideEffects cont
-- | A variable is checked as an unapplied call ('appOk'), a 'Compute' by its
-- command; every other term is ok.
termOk :: (PrimOp -> Bool) -> Term b -> Bool
termOk primOpOk term =
  case term of
    Var fid        -> appOk primOpOk fid []
    Compute _ comm -> commOk primOpOk comm
    _              -> True
-- | A command is ok only if it has no lets and its term/continuation cut is ok.
commOk :: (PrimOp -> Bool) -> Command b -> Bool
commOk primOpOk comm =
  null (cmdLet comm) && cutOk primOpOk (cmdTerm comm) (cmdCont comm)
-- | A variable cut against an argument chain is checked as a call ('appOk')
-- followed by the rest of the continuation; any other cut requires both the
-- term and the continuation to be ok.
cutOk :: (PrimOp -> Bool) -> Term b -> Cont b -> Bool
cutOk primOpOk term cont =
  case term of
    Var fid ->
      let (args, rest) = collectArgs cont
      in  appOk primOpOk fid args && contOk primOpOk rest
    _ -> termOk primOpOk term && contOk primOpOk cont
-- | Whether running a continuation is ok under the given primop predicate.
--
--   * 'Return' and 'App' frames are rejected.
--   * A 'Case' is ok when every alternative's command is ok and the
--     alternatives are exhaustive.
--   * A counting 'Tick' is rejected; other ticks and all casts defer to the
--     rest of the continuation.
--
-- NOTE(review): rejecting 'Return' means any command ending in a 'Return'
-- (via 'commOk'/'cutOk') is never ok; confirm this conservatism is intended.
contOk :: (PrimOp -> Bool) -> Cont b -> Bool
contOk _ (Return _)
  = False
contOk _ (App _ _)
  = False
contOk primOpOk (Case _bndr alts)
  =  all (\(Alt _ _ rhs) -> commOk primOpOk rhs) alts
  && altsAreExhaustive
  where
    -- True only when the alternatives are known to be exhaustive: a leading
    -- DEFAULT covers everything, literal alternatives never count, and
    -- constructor alternatives must number exactly the constructors of the
    -- type. 'alts' is the whole list here (head included), so it is compared
    -- directly with the family size, with no extra "+ 1" for the head.
    altsAreExhaustive =
      case alts of
        []               -> False
        Alt con1 _ _ : _ ->
          case con1 of
            DEFAULT    -> True
            LitAlt {}  -> False
            DataAlt dc -> length alts == tyConFamilySize (dataConTyCon dc)
contOk primOpOk (Tick ti cont)
  = not (tickishCounts ti) && contOk primOpOk cont
contOk primOpOk (Cast _ cont)
  = contOk primOpOk cont
-- | Whether applying @fid@ to @args@ is ok under the given primop predicate.
--
--   * A 'DFunId' is ok unless its @newType@ flag is set.
--   * Data constructor workers are always ok.
--   * A division-like primop ('isDivOp') applied to exactly two arguments,
--     the second a literal, is ok when that literal is non-zero and the first
--     argument is ok.
--   * 'DataToTagOp' is always ok.
--   * Any other primop must satisfy the predicate, with all arguments ok.
--   * Any other function is ok if its type is unlifted, if it has fewer value
--     arguments than its arity, or if it has no value arguments and an
--     evaluated unfolding.
appOk :: (PrimOp -> Bool) -> Id -> [Term b] -> Bool
appOk primOpOk fid args
  = case idDetails fid of
      DFunId _ newType -> not newType
      DataConWorkId {} -> True
      PrimOpId op | isDivOp op
                  , [arg1, Lit lit] <- args
                  -> not (isZeroLit lit) && termOk primOpOk arg1
                  | DataToTagOp <- op
                  -> True
                  | otherwise
                  -> primOpOk op && all (termOk primOpOk) args
      _ -> Type.isUnLiftedType (idType fid)
        || idArity fid > nValArgs
        || nValArgs == 0 && isEvaldUnfolding (idUnfolding fid)
        where
          -- Type arguments are not counted (see 'isValueArg').
          nValArgs = length (filter isValueArg args)
  where
    -- Division-like primops, special-cased above when the divisor is a
    -- literal.
    isDivOp IntQuotOp = True
    isDivOp IntRemOp = True
    isDivOp WordQuotOp = True
    isDivOp WordRemOp = True
    isDivOp FloatDivOp = True
    isDivOp DoubleDivOp = True
    isDivOp _ = False
-- | Cheapness / expandability checks for each syntactic sort. Each pair
-- instantiates the shared checker with 'isCheapApp' or 'isExpandableApp'.
termIsCheap, termIsExpandable :: HasId b => Term b -> Bool
termIsCheap term      = termCheap isCheapApp term
termIsExpandable term = termCheap isExpandableApp term
contIsCheap, contIsExpandable :: HasId b => Cont b -> Bool
contIsCheap cont      = contCheap isCheapApp cont
contIsExpandable cont = contCheap isExpandableApp cont
commandIsCheap, commandIsExpandable :: HasId b => Command b -> Bool
commandIsCheap comm      = commCheap isCheapApp comm
commandIsExpandable comm = commCheap isExpandableApp comm
-- | Decides whether a function applied to the given number of arguments
-- counts as cheap (see 'isCheapApp' and 'isExpandableApp').
type CheapMeasure = Id -> Int -> Bool
-- | Cheapness of a term. Literals, variables, types, coercions and
-- constructor applications are always cheap. A lambda is cheap if it binds a
-- runtime variable or its body is cheap. 'Compute' and 'Cont' defer to their
-- command / continuation.
termCheap :: HasId b => CheapMeasure -> Term b -> Bool
termCheap appCheap term =
  case term of
    Lit _         -> True
    Var _         -> True
    Type _        -> True
    Coercion _    -> True
    Cons _ _      -> True
    Lam xs _ body -> any (isRuntimeVar . identifier) xs
                       || commCheap appCheap body
    Compute _ c   -> commCheap appCheap c
    Cont k        -> contCheap appCheap k
-- | Cheapness of a continuation. A 'Return' is cheap; a 'Case' is cheap if all
-- alternatives are; casts are transparent; a counting tick is not cheap; an
-- 'App' is cheap only for an erased argument.
contCheap :: HasId b => CheapMeasure -> Cont b -> Bool
contCheap appCheap cont =
  case cont of
    Return _    -> True
    Case _ alts -> and [ commCheap appCheap rhs | Alt _ _ rhs <- alts ]
    Cast _ k    -> contCheap appCheap k
    Tick ti k   -> not (tickishCounts ti) && contCheap appCheap k
    App arg k   -> isErasedTerm arg && contCheap appCheap k
-- | A command is cheap when every let right-hand side is cheap and the cut of
-- its term against its continuation is cheap.
commCheap :: HasId b => CheapMeasure -> Command b -> Bool
commCheap appCheap comm =
  bindsCheap && cutCheap appCheap (cmdTerm comm) (cmdCont comm)
  where
    bindsCheap =
      and [ termCheap appCheap rhs | (_, rhs) <- flattenBinds (cmdLet comm) ]
-- | Cheapness of cutting a term against a continuation.
--
-- Casts are transparent. A variable applied to arguments is cheap when:
--
--   * the measure accepts the application and all arguments are cheap;
--   * it is a record selector or class op applied to one cheap argument;
--   * it is a cheap primop applied to cheap arguments; or
--   * it is a bottoming function.
--
-- In every non-bottoming case, the continuation after the arguments must be
-- cheap as well. Any other cut is cheap when both its term and its
-- continuation are cheap (mirroring the fallback of 'cutOk').
cutCheap :: HasId b => CheapMeasure -> Term b -> Cont b -> Bool
cutCheap appCheap term (Cast _ cont)
  = cutCheap appCheap term cont
cutCheap appCheap (Var fid) cont@(App {})
  = case collectTypeAndOtherArgs cont of
      (_, [], cont') -> contCheap appCheap cont'
      (_, args, cont')
        | appCheap fid (length args)
        -> papCheap args && contCheap appCheap cont'
        | otherwise
        -> case idDetails fid of
             RecSelId {}  -> selCheap args && contCheap appCheap cont'
             ClassOpId {} -> selCheap args && contCheap appCheap cont'
             PrimOpId op  -> primOpCheap op args && contCheap appCheap cont'
             -- A bottoming call never returns, so what follows is irrelevant.
             _ | isBottomingId fid -> True
               | otherwise         -> False
  where
    papCheap args = all (termCheap appCheap) args
    selCheap [arg] = termCheap appCheap arg
    selCheap _ = False
    primOpCheap op args = primOpIsCheap op && all (termCheap appCheap) args
cutCheap appCheap term cont
  = termCheap appCheap term && contCheap appCheap cont
-- | Cheapness measures for a function applied to @valArgCount@ arguments.
--
-- 'isCheapApp': data constructor workers, unapplied functions and
-- applications with fewer arguments than the function's arity are cheap.
--
-- 'isExpandableApp': constructor-like functions and under-applied functions
-- are expandable, as is an application whose arguments all fill predicate
-- ('Type.isPredTy') parameters of the function's type.
isCheapApp, isExpandableApp :: CheapMeasure
isCheapApp fid valArgCount =  isDataConWorkId fid
                           || valArgCount == 0
                           || valArgCount < idArity fid
isExpandableApp fid valArgCount =  isConLikeId fid
                                || valArgCount < idArity fid
                                || allPreds valArgCount (idType fid)
  where
    -- Are the first @n@ value parameters of the type all predicates?
    -- Foralls are skipped; each predicate arrow consumes one argument.
    allPreds :: Int -> Type -> Bool
    allPreds 0 _ = True
    allPreds n ty
      | Just (_, ty') <- Type.splitForAllTy_maybe ty = allPreds n ty'
      | Just (argTy, ty') <- Type.splitFunTy_maybe ty
      , Type.isPredTy argTy = allPreds (n - 1) ty'
      | otherwise = False
-- | True when the 'Id''s type is a continuation type ('isContTy').
isContId :: Id -> Bool
isContId = isContTy . idType
-- | Returns the 'Id' unchanged if it already has a continuation type;
-- otherwise its type @t@ is replaced by @mkContTy t@.
asContId :: Id -> ContId
asContId x =
  if isContId x
    then x
    else setIdType x (mkContTy (idType x))
-- | The argument type of a continuation type; panics (via 'pprPanic') if the
-- type is not a continuation type.
contTyArg :: Type -> Type
contTyArg ty =
  case isContTy_maybe ty of
    Just argTy -> argTy
    Nothing    -> pprPanic "contTyArg" (ppr ty)
-- | Binder types from which an underlying 'Id' can be extracted.
class HasId a where
  -- | The 'Id' carried by a binder.
  identifier :: a -> Id

-- | A plain variable is its own identifier.
instance HasId Var where
  identifier = id
-- | Renaming environment used while comparing terms up to bound-variable names.
type AlphaEnv = RnEnv2

infix 4 =~=, `aeq`

-- | Equality up to renaming of bound variables.
class AlphaEq a where
  -- | Alpha-equivalence starting from an empty renaming environment.
  aeq :: a -> a -> Bool
  -- | Alpha-equivalence under the given renaming environment.
  aeqIn :: AlphaEnv -> a -> a -> Bool

  aeq x y = aeqIn emptyAlphaEnv x y

-- | A renaming environment with nothing in scope.
emptyAlphaEnv :: AlphaEnv
emptyAlphaEnv = mkRnEnv2 emptyInScopeSet

-- | Operator form of 'aeq'.
(=~=) :: AlphaEq a => a -> a -> Bool
x =~= y = aeq x y
-- | Alpha-equivalence of terms. Binders (lambda binders, continuation
-- binders) are renamed in the environment before comparing bodies; terms
-- with different head constructors are never equal.
instance HasId b => AlphaEq (Term b) where
  aeqIn _ (Lit l1) (Lit l2)
    = l1 == l2
  aeqIn env (Lam bs1 k1 c1) (Lam bs2 k2 c2)
    -- 'rnBndrs2' requires binder lists of equal length, so the binder
    -- counts are compared first.
    = length bs1 == length bs2
      && aeqIn (rnBndrs2 env' (map identifier bs1) (map identifier bs2)) c1 c2
    where env' = rnBndr2 env (identifier k1) (identifier k2)
  aeqIn env (Cons dc1 as1) (Cons dc2 as2)
    -- Same constructor, same number of arguments, pairwise-equal arguments.
    = dc1 == dc2 && length as1 == length as2 && aeqIn env as1 as2
  aeqIn env (Type t1) (Type t2)
    = aeqIn env t1 t2
  aeqIn env (Coercion co1) (Coercion co2)
    = aeqIn env co1 co2
  aeqIn env (Var x1) (Var x2)
    = env `rnOccL` x1 == env `rnOccR` x2
  aeqIn env (Compute k1 c1) (Compute k2 c2)
    = aeqIn (rnBndr2 env (identifier k1) (identifier k2)) c1 c2
  aeqIn env (Cont k1) (Cont k2)
    = aeqIn env k1 k2
  aeqIn _ _ _
    = False
-- | Alpha-equivalence of continuations. A 'Case' renames its binder before
-- comparing alternatives; continuations with different head constructors
-- are never equal.
instance HasId b => AlphaEq (Cont b) where
  aeqIn env (App c1 k1) (App c2 k2)
    = aeqIn env c1 c2 && aeqIn env k1 k2
  aeqIn env (Case x1 as1) (Case x2 as2)
    -- The alternative lists must have the same length, not merely agree on
    -- a common prefix.
    = length as1 == length as2 && aeqIn env' as1 as2
    where env' = rnBndr2 env (identifier x1) (identifier x2)
  aeqIn env (Cast co1 k1) (Cast co2 k2)
    = aeqIn env co1 co2 && aeqIn env k1 k2
  aeqIn env (Tick ti1 k1) (Tick ti2 k2)
    = ti1 == ti2 && aeqIn env k1 k2
  aeqIn env (Return x1) (Return x2)
    = env `rnOccL` x1 == env `rnOccR` x2
  aeqIn _ _ _
    = False
-- | Alpha-equivalence of commands: the let lists must match (see
-- 'aeqBindsIn'), and the terms and continuations are then compared in the
-- environment extended with the let binders.
instance HasId b => AlphaEq (Command b) where
  aeqIn env comm1 comm2 =
    case aeqBindsIn env (cmdLet comm1) (cmdLet comm2) of
      Nothing   -> False
      Just env' -> aeqIn env' (cmdTerm comm1) (cmdTerm comm2)
                && aeqIn env' (cmdCont comm1) (cmdCont comm2)
-- | Compares two lists of binding groups pairwise, threading the renaming
-- environment through. Returns the final environment on success, or
-- 'Nothing' if any pair differs or the list lengths differ.
aeqBindsIn :: HasId b => AlphaEnv -> [Bind b] -> [Bind b] -> Maybe AlphaEnv
aeqBindsIn env binds1 binds2 =
  case (binds1, binds2) of
    ([], []) -> Just env
    (b1 : rest1, b2 : rest2) -> do
      env' <- aeqBindIn env b1 b2
      aeqBindsIn env' rest1 rest2
    _ -> Nothing
-- | Compares two binding groups, returning the environment extended with
-- their binders on success.
--
--   * Non-recursive: the right-hand sides are compared in the /outer/
--     environment, since a non-recursive binder is not in scope in its own
--     right-hand side (as in GHC's 'eqExpr').
--   * Recursive: the groups must have the same size; all binders are
--     renamed first and the right-hand sides compared pairwise in the
--     extended environment.
aeqBindIn :: HasId b => AlphaEnv -> Bind b -> Bind b -> Maybe AlphaEnv
aeqBindIn env (NonRec x1 c1) (NonRec x2 c2)
  = if aeqIn env c1 c2 then Just env' else Nothing
  where env' = rnBndr2 env (identifier x1) (identifier x2)
aeqBindIn env (Rec bs1) (Rec bs2)
  -- 'rnBndrs2' requires binder lists of equal length.
  | length bs1 == length bs2
  , and $ zipWith alpha bs1 bs2
  = Just env'
  where
    alpha :: HasId b => (b, Term b) -> (b, Term b) -> Bool
    alpha (_, c1) (_, c2)
      = aeqIn env' c1 c2
    env'
      = rnBndrs2 env (map (identifier . fst) bs1) (map (identifier . fst) bs2)
aeqBindIn _ _ _
  = Nothing
-- | Alternatives are equal when their constructors match and their
-- right-hand sides agree after renaming the pattern variables.
instance HasId b => AlphaEq (Alt b) where
  aeqIn env (Alt con1 bndrs1 rhs1) (Alt con2 bndrs2 rhs2)
    | con1 /= con2 = False
    | otherwise    = aeqIn renamed rhs1 rhs2
    where
      renamed = rnBndrs2 env (map identifier bndrs1) (map identifier bndrs2)
-- | Structural alpha-equivalence of types. Guards are tried in order, and
-- the first decomposition that succeeds on /both/ types decides the result:
-- type variables go through the renaming environment; applications,
-- function types and type-constructor applications are compared
-- component-wise; numeric and string literals by value; a forall extends
-- the environment with its two binders. Anything else is unequal.
instance AlphaEq Type where
  aeqIn env t1 t2
    | Just x1 <- Type.getTyVar_maybe t1
    , Just x2 <- Type.getTyVar_maybe t2
    = env `rnOccL` x1 == env `rnOccR` x2
    | Just (f1, a1) <- Type.splitAppTy_maybe t1
    , Just (f2, a2) <- Type.splitAppTy_maybe t2
    = f1 `alpha` f2 && a1 `alpha` a2
    | Just n1 <- Type.isNumLitTy t1
    , Just n2 <- Type.isNumLitTy t2
    = n1 == n2
    | Just s1 <- Type.isStrLitTy t1
    , Just s2 <- Type.isStrLitTy t2
    = s1 == s2
    | Just (a1, r1) <- Type.splitFunTy_maybe t1
    , Just (a2, r2) <- Type.splitFunTy_maybe t2
    = a1 `alpha` a2 && r1 `alpha` r2
    | Just (c1, as1) <- Type.splitTyConApp_maybe t1
    , Just (c2, as2) <- Type.splitTyConApp_maybe t2
    = c1 == c2 && as1 `alpha` as2
    | Just (x1, t1') <- Type.splitForAllTy_maybe t1
    , Just (x2, t2') <- Type.splitForAllTy_maybe t2
    = aeqIn (rnBndr2 env x1 x2) t1' t2'
    | otherwise
    = False
    where
      -- Used both on single types and on lists of type arguments.
      alpha a1 a2 = aeqIn env a1 a2
-- | Coercions are compared by their types ('coercionType').
instance AlphaEq Coercion where
  aeqIn env co1 co2 = aeqIn env ty1 ty2
    where
      ty1 = coercionType co1
      ty2 = coercionType co2
-- | Lists are alpha-equivalent when they have the same length and are
-- pairwise alpha-equivalent under the same environment. A list is never
-- equal to a proper prefix or extension of itself.
instance AlphaEq a => AlphaEq [a] where
  aeqIn _   []       []       = True
  aeqIn env (x : xs) (y : ys) = aeqIn env x y && aeqIn env xs ys
  aeqIn _   _        _        = False
-- | Binding groups are equal when 'aeqBindIn' succeeds; the extended
-- environment it produces is discarded.
instance HasId b => AlphaEq (Bind b) where
  aeqIn env b1 b2 =
    case aeqBindIn env b1 b2 of
      Just _  -> True
      Nothing -> False