module Grin.SSimplify(simplify,explicitRecurse) where
import Control.Monad.Identity
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Data.Maybe
import qualified Data.IntMap as IM
import qualified Data.IntSet as IS
import qualified Data.Map as Map
import qualified Data.Set as Set
import Data.DeriveTH
import Grin.Grin
import Grin.Noodle
import Stats(mtick)
import StringTable.Atom
import Support.CanType
import Support.FreeVars
import Support.Tickle
import Util.GMap
import Util.Gen
import Util.HasSize
import Util.RWS
import Util.SetLike
import qualified Stats
data SEnv = SEnv {
envSubst :: IM.IntMap Val,
envCSE :: Map.Map Exp (Atom,Exp),
envPapp :: IM.IntMap (Atom,[Val])
}
newtype SState = SState { usedVars :: IS.IntSet }
data SCol = SCol {
colStats :: Stats.Stat,
colFreeVars :: GSet Var
}
$(derive makeMonoid ''SCol)
newtype S a = S (RWS SEnv SCol SState a)
deriving(Monad,Functor,MonadWriter SCol, MonadReader SEnv,MonadState SState)
instance Stats.MonadStats S where
mtickStat s = S (tell mempty { colStats = s })
mticks' n a = S (tell mempty { colStats = Stats.singleStat n a })
tellFV v = tell mempty { colFreeVars = freeVars v }
simplify :: Grin -> IO Grin
simplify grin = do
let (fs,_,SCol { colStats = stats}) = runRWS fun mempty SState { usedVars = mempty }
S fun = simpFuncs (grinFunctions grin)
return grin { grinFunctions = fs, grinStats = grinStats grin `mappend` stats }
simpFuncs :: [FuncDef] -> S [FuncDef]
simpFuncs fd = do
let f fd@FuncDef { funcDefBody = body } = do
body' <- simpLam body
return $ updateFuncDefProps fd { funcDefBody = body' }
mapM f fd
simpLam :: Lam -> S Lam
simpLam (ps :-> e) = do
(ps,env') <- renamePattern ps
let f col = col { colFreeVars = colFreeVars col \\ freeVars ps }
(e,col) <- censor f $ listen $ local (env' `mappend`) $ simpExp e
ps <- mapM (zeroVars (`member` colFreeVars col)) ps
return (ps :-> e)
dstore x = BaseOp (StoreNode True) [x]
simpDone :: Exp -> S Exp
simpDone e = do
pmap <- asks envPapp
case e of
(BaseOp (Apply ty) (Var (V vn) _:fs)) | Just (tl,gs) <- IM.lookup vn pmap -> do
(cl,fn) <- tagUnfunction tl
mtick $ if cl == 1 then "Simplify.Apply.Papp.{" ++ show tl else ("Simplify.Apply.App.{" ++ show fn)
return $ if cl == 1
then App fn (gs ++ fs) ty
else dstore (NodeC (partialTag fn (cl 1)) (gs ++ fs))
(Case v ls) | isJust utypes -> ans where
utypes@(~(Just ts)) = unboxTypes ur
ur = foldr1 combineUnboxing [ getUnboxing e | _ :-> e <- ls ]
ans = do
mtick "Grin.Simplify.Unbox.case-return"
let vs = zipWith Var [v1 ..] ts
return (unboxModify ur (Case v ls) :>>= vs :-> (unboxRet ur vs))
(Case v1 ls) | [v1'] :-> Case v2 ls' <- last ls, v1' == v2 || v1 == v2 -> do
let f (p :-> e) = p :-> Return [v1] :>>= [v1'] :-> e
mtick "Grin.Simplify.case-merge"
return $ Case v1 (init ls ++ map f ls')
_ -> do
cmap <- asks envCSE
case Map.lookup e cmap of
Just (n,e') -> do mtick n; tellFV e'; return e'
Nothing -> return e
simpBind :: [Val] -> Exp -> S Exp -> S Exp
simpBind p e cont = f p e where
cse name xs = do
(z,col) <- listen $ local (\s -> s { envCSE = Map.fromList [ (x,(toAtom name,y)) | (x,y) <- xs] `Map.union` envCSE s }) cont
e <- simpDone e
if isOmittable e && isEmpty (freeVars p `intersection` colFreeVars col) then do
mtick "Simplify.Omit.Bind"
return z
else return $ e :>>= (p :-> z)
cse' name xs = cse name ((e,Return p):xs)
f p app@(BaseOp Eval [v]) = cse' "Simplify.CSE.eval" [(BaseOp Promote [v],Return p)]
f p (BaseOp Promote [v@Var {}]) = cse' "Simplify.CSE.promote" [(gEval v,Return p)]
f [p] (BaseOp Demote [v@Var {}]) = cse' "Simplify.CSE.demote" [(BaseOp Promote [p],Return [v]),(gEval p,Return [v])]
f [p@(Var (V vn) _)] (BaseOp (StoreNode isD) [v@(NodeC t vs)]) | not (isHoly v) = case (isD,tagUnfunction t,tagIsWHNF t) of
(True,Nothing,_) -> cse' "Simplify.CSE.return-node" []
(True,Just (n,fn),_) -> local (\s -> s { envPapp = IM.insert vn (t,vs) (envPapp s) }) $ cse' "Simplify.CSE.return-node-func" []
_ -> cse' "Simplify.CSE.store" []
f _ _ = cse "Simplify.CSE.NOT" []
extEnv :: Var -> Val -> SEnv -> SEnv
extEnv (V vn) v s = s { envSubst = IM.insert vn v (envSubst s) }
simpExp :: Exp -> S Exp
simpExp e = f e [] where
f (e :>>= p :-> Return p') rs | p == p' = do
mtick "Grin.Simplify.tail-return-omit"
f e rs
f (a :>>= (v :-> b)) xs = do
env <- ask
f a ((env,v,b):xs)
f (BaseOp Promote [Const x]) rs = do
mtick "Grin.Simplify.fetch-const"
f (Return [x]) rs
f (BaseOp Eval [Const n]) rs = do
mtick "Grin.Simplify.eval-const"
f (Return [n]) rs
f (Error s t) rs@(_:_) = do
mtick "Grin.Simplify.error-discard"
let (_,_,b) = last rs
f (Error s (getType b)) []
f (Return [v]) ((senv,[Var vn _],b):rs) | valIsConstant v = do
mtick "Grin.Simplify.Subst.const"
fbind vn v senv b rs
f (Return [v@ValUnknown {}]) ((senv,[Var vn _],b):rs) = do
mtick "Grin.Simplify.Subst.unknown"
fbind vn v senv b rs
f (Return [v@Var {}]) ((senv,[Var vn _],b):rs) = do
mtick "Grin.Simplify.Subst.var"
fbind vn v senv b rs
f (Return []) ((senv,[],b):rs) = do
mtick "Grin.Simplify.Assign.unit-unit"
dtup [] [] senv b rs
f a@(Return (xs@(_:_:_))) ((senv,ys,b):rs) = do
mtick "Grin.Simplify.Assign.tuple-tuple"
dtup xs ys senv b rs
f (Case v@Var {} [l]) rs = do
f (Return [v] :>>= l) rs
f a ((senv,p,b):xs) = do
a <- g a
(p,env') <- renamePattern p
let env'' = env' `mappend` senv
local (const env'') $ simpBind p a (f b xs)
f x [] = do
e <- g x
simpDone e
fbind vn v senv b rs = do
v' <- applySubst v
local (\_ -> extEnv vn v' senv) $ f b rs
dtup xs ys senv b rs | sameLength xs ys = do
xs <- mapM applySubst xs
(ys,env') <- renamePattern ys
let env'' = env' `mappend` senv
z <- local (const env'') $ f b rs
ts <- mapM (return . Just) [([y],Return [x]) | x <- xs | y <- ys ]
let h [] = z
h ((p,v):rs) = v :>>= p :-> h rs
return $ h [ (p,v) | Just (p,v) <- ts]
dtup _ _ _ _ _ = error "dtup: attempt to bind unequal lists"
g (Case v as) = do
v <- applySubst v
as <- mapM simpLam as
return $ Case v as
g lt@Let { expDefs = defs, expBody = body } = do
body <- f body []
defs <- simpFuncs defs
let dnames = fromList $ map funcDefName defs :: GSet Atom
isInvalid e = isEmpty (freeVars e `intersection` dnames)
case body of
e :>>= l :-> r | isInvalid e -> do
mtick "Simplify.simplify.let-shrink-head"
return $ e :>>= l :-> updateLetProps lt { expBody = r, expDefs = defs }
e :>>= l :-> r | isInvalid r -> do
mtick "Simplify.simplify.let-shrink-tail"
return (updateLetProps lt { expBody = e, expDefs = defs } :>>= l :-> r)
App f as ts | f `elem` map funcDefName defs, f `Set.notMember` freeVars (map funcDefBody defs) -> do
mtick "Simplify.simplify.let-inline-body"
let [fbody] = [ funcDefBody fd | fd <- defs, funcDefName fd == f]
return $ updateLetProps lt { expDefs = defs, expBody = Return as :>>= fbody }
_ -> return $ updateLetProps lt { expBody = body, expDefs = defs }
g x = applySubstE x
applySubstE :: Exp -> S Exp
applySubstE x = mapExpVal applySubst x
applySubst x = f x where
f var@(Var (V v) _) = do
env <- asks envSubst
case IM.lookup v env of
Just n -> tellFV n >> return n
Nothing -> tellFV var >> return var
f x = mapValVal f x
zeroVars fn x = f x where
f (Var v ty) | fn v || v == v0 = return (Var v ty)
| otherwise = do mtick $ "Simplify.ZeroVar.{" ++ show (Var v ty); return (Var v0 ty)
f x = mapValVal f x
renamePattern :: [Val] -> S ([Val],SEnv)
renamePattern x = runWriterT (mapM f x) where
f :: Val -> WriterT SEnv S Val
f (Var v@(V vn) t) = do
v' <- lift $ newVarName v
let nv = Var v' t
tell (mempty { envSubst = IM.singleton vn nv })
return nv
f x = mapValVal f x
newVarName :: Var -> S Var
newVarName (V 0) = return (V 0)
newVarName (V sv) = do
s <- gets usedVars
let nv = v sv
v n | n `IS.member` s = v (1 + n + IS.size s)
| otherwise = n
modify (\e -> e { usedVars = IS.insert nv s })
return (V nv)
isHoly (NodeC _ as) | any isValUnknown as = True
isHoly n = False
data UnboxingResult
= UnErr [Ty]
| UnStore !Bool !Atom [Unbox]
| UnDemote Unbox
| UnReturn [Unbox]
| UnTail (Set.Set Atom) [Ty] [Ty]
data Unbox = UnConst Val | UnUnknown Ty | UnBaseOp BaseOp [Unbox]
deriving(Eq,Ord)
isUnUnknown UnUnknown {} = True
isUnUnknown _ = False
instance CanType UnboxingResult where
type TypeOf UnboxingResult = [Ty]
getType (UnErr tys) = tys
getType (UnReturn us) = map getType us
getType (UnStore b _ _) = [bool b tyDNode tyINode]
getType (UnDemote _) = [tyINode]
getType (UnTail _ tys _) = tys
instance CanType Unbox where
type TypeOf Unbox = Ty
getType (UnConst v) = getType v
getType (UnUnknown t) = t
getType _ = error "getType: bad."
unboxRet :: UnboxingResult -> [Val] -> Exp
unboxRet ur vs = f ur vs where
f (UnReturn xs) vs = Return $ let (r,[]) = g xs vs in r
f (UnStore b c xs) vs = let (xs',[]) = g xs vs in BaseOp (StoreNode b) [NodeC c xs']
f (UnDemote u) vs = let ([u'],[]) = g [u] vs in BaseOp Demote [u']
f (UnTail a _ tys) vs | [f] <- Set.toList a = App f vs tys
f UnErr {} _ = Return []
f _ vs = Return vs
g [] vs = ([],vs)
g (UnUnknown _:xs) (v:vs) = let (r,y) = g xs vs in (v:r,y)
g (UnConst v:xs) vs = let (r,y) = g xs vs in (v:r,y)
g _ _ = error "SSimplify.unboxRet: bad."
unboxTypes :: UnboxingResult -> Maybe [Ty]
unboxTypes ur = f ur where
f (UnTail ts tys _) | Set.size ts == 1 = Just tys
f (UnTail {}) = Nothing
f (UnErr []) = Nothing
f (UnErr (_:_)) = Just []
f (UnReturn us) | all isUnUnknown us = Nothing
f (UnReturn xs) = Just $ concatMap h xs
f (UnStore _ _ ts) = Just $ concatMap h ts
f (UnDemote _) = Just [tyDNode]
h (UnUnknown t) = [t]
h (UnConst {}) = []
h _ = error "SSimplify.unboxTypes: bad."
unboxModify :: UnboxingResult -> Exp -> Exp
unboxModify ur = f ur where
Just nty = unboxTypes ur
f UnErr {} = id
f (UnTail a tys _) | [f] <- Set.toList a = runIdentity . editTail tys (mApp f)
f (UnReturn us) | all isUnUnknown us = id
f (UnReturn xs) = runIdentity . editTail nty (g xs)
f (UnStore _ _ us) =runIdentity . editTail nty (z us)
f (UnDemote _) =runIdentity . editTail nty y
f _ = error "SSimplify.unboxModify: bad1."
g xs (Return ys) = return $ Return (concat $ zipWith h xs ys)
g _ _ = error "SSimplify.unboxModify: bad2."
h (UnUnknown _) y = [y]
h (UnConst {}) _ = []
h _ _ = error "SSimplify.unboxModify: bad3."
z xs (BaseOp (StoreNode _) [NodeC _ ts]) = return . Return . concat $ zipWith h xs ts
z _ _ = error "SSimplify.unboxModify: bad4."
y (BaseOp Demote [x]) = return $ Return [x]
y (Return [Const v]) = return $ Return [v]
y _ = error "SSimplify.unboxModify: bad5."
mApp f (App f' as tys) | f == f' = return $ Return as
mApp f e = error $ "mApp: " ++ show (f,e)
combineUnboxing :: UnboxingResult -> UnboxingResult -> UnboxingResult
combineUnboxing ub1 ub2 = f ub1 ub2 where
f UnErr {} x = x
f x UnErr {} = x
f (UnTail t1 a1 u1) (UnTail t2 a2 u2) | u1 == u2, a1 == a2 = UnTail (t1 `union` t2) a1 u1
f (UnReturn xs) (UnReturn ys) = UnReturn (zipWith g xs ys)
f (UnStore b1 a1 xs1) (UnStore b2 a2 xs2) | a1 == a2 = UnStore b1 a1 (zipWith g xs1 xs2)
| otherwise = UnReturn [UnUnknown (bool b1 tyDNode tyINode)]
f (UnDemote u1) (UnDemote u2) = UnDemote (g u1 u2)
f (UnDemote u1) (UnReturn [UnConst (Const v)]) = UnDemote (UnUnknown tyDNode)
f (UnReturn [UnConst (Const v)]) (UnDemote u1) = UnDemote (UnUnknown tyDNode)
f x _ = UnReturn (map UnUnknown (getType x))
g (UnConst v1) (UnConst v2) | v1 == v2 = UnConst v1
| otherwise = UnUnknown (getType v1)
g x _ = UnUnknown (getType x)
getUnboxing :: Exp -> UnboxingResult
getUnboxing e = f e where
f (Return rs) = UnReturn (map g rs)
f (BaseOp (StoreNode b) [NodeC c xs]) = UnStore b c (map g xs)
f (BaseOp Demote [v]) = UnDemote (g v)
f (Error _ tys) = UnErr tys
f (App f vs ts) = UnTail (singleton f) (getType vs) ts
f (Case _ ls) = foldr1 combineUnboxing [ f e | _ :-> e <- ls ]
f Let { expDefs = defs, expBody = body, expIsNormal = False } = case f body of
UnTail fs _ ntys | not $ Set.null (fs `Set.intersection` (Set.fromList $ map funcDefName defs)) -> UnReturn (map UnUnknown ntys)
e -> e
f (_ :>>= _ :-> e) = f e
f e = UnReturn (map UnUnknown $ getType e)
g v | valIsConstant v = UnConst v
g v = UnUnknown (getType v)
editTail :: Monad m => [Ty] -> (Exp -> m Exp) -> Exp -> m Exp
editTail nty mt te = f (sempty :: GSet Atom) te where
f _ (Error s ty) = return $ Error s nty
f lf (Case x ls) = return (Case x) `ap` mapM (g lf) ls
f lf lt@Let {expIsNormal = False, expBody = body } = do
body <- f lf body
return $ updateLetProps lt { expBody = body }
f lf lt@Let {expDefs = defs, expIsNormal = True } = do
let nlf = lf `union` fromList (map funcDefName defs)
mapExpExp (f nlf) lt
f lf lt@MkCont {expLam = lam, expCont = cont } = do
a <- g lf lam
b <- g lf cont
return $ lt { expLam = a, expCont = b }
f lf (e1 :>>= p :-> e2) = do
e2 <- f lf e2
return $ e1 :>>= p :-> e2
f lf e@(App a as t) | a `member` lf = return $ App a as nty
f lf e = mt e
g lf (p :-> e) = do e <- f lf e; return $ p :-> e
bool b x y = if b then x else y
explicitRecurse
:: Grin
-> IO Grin
explicitRecurse grin = mapGrinFuncsM f grin where
f name lam | name `notMember` (freeVars lam :: GSet Atom) = return lam
f name (as :-> e) = do
let nname = toAtom $ "bR" ++ fromAtom name
g (App n rs t) | n == name = App nname rs t
g e = tickle g e
return $ as :-> grinLet [createFuncDef True nname (as :-> g e) ] (App nname as (getType e))
$(derive makeMonoid ''SEnv)