{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Ivory.Opts.CSE (cseFold) where
import Prelude ()
import Prelude.Compat
import Control.Applicative (liftA2)
import qualified Data.DList as D
import Data.IntMap.Strict (IntMap)
import qualified Data.IntMap.Strict as IntMap
import Data.IntSet (IntSet)
import qualified Data.IntSet as IntSet
import Data.List (sort)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import Data.Reify
import Ivory.Language.Array (ixRep)
import qualified Ivory.Language.Syntax as AST
import MonadLib (Id, StateT, WriterT, collect, get, lift,
put, runM, set, sets, sets_)
import System.IO.Unsafe (unsafePerformIO)
-- | Rewrite the body of a procedure with common sub-expressions shared.
--
-- The body is turned into an explicit graph by 'reifyGraph' and rebuilt
-- into a statement list by 'reconstruct'. Every other field of the
-- procedure is returned untouched.
--
-- NOTE(review): 'reifyGraph' is an 'IO' action run here through
-- 'unsafePerformIO'; this presumably relies on the observed sharing only
-- influencing which expressions get merged, not the meaning of the
-- result -- confirm.
cseFold :: AST.Proc -> AST.Proc
cseFold def = def { AST.procBody = sharedBody }
  where
    sharedBody = reconstruct (unsafePerformIO (reifyGraph (AST.procBody def)))
-- | State threaded through statement generation (see 'BlockM').
data Bindings = Bindings
  { availableBindings :: (Map (Unique, AST.Type) Int)
    -- ^ For a graph node generated at a given type, the number of the
    -- variable that already holds its value. 'genBlock' resets this map
    -- to its previous contents when a nested block ends.
  , unusedBindings :: IntSet
    -- ^ Variable numbers that have been allocated but not looked up
    -- again since. Its final value is what 'reconstruct' passes around
    -- as @usedOnce@.
  , totalBindings :: Int
    -- ^ Number of variable numbers handed out so far; also the next
    -- number to hand out.
  }
-- | Monad in which statements are generated: 'Bindings' is carried as
-- state, and generated statements are accumulated as writer output in a
-- difference list.
type BlockM a = StateT Bindings (WriterT (D.DList AST.Stmt) Id) a
-- | One node of the reified graph. Expressions, statements and statement
-- lists all use this as their 'DeRef' type, so a single graph holds all
-- of them; @t@ is the type used to refer to child nodes.
data CSE t
  = CSEExpr (ExprF t)
    -- ^ An expression node.
  | CSEBlock (BlockF t)
    -- ^ A statement node, or a list of statements.
  deriving (Show, Eq, Ord, Functor)
-- | One layer of an expression, with each sub-expression replaced by a
-- reference of type @t@. 'toExpr' converts a layer back to 'AST.Expr'.
data ExprF t
  = ExpSimpleF AST.Expr
    -- ^ An expression kept whole, with no children exposed to the graph
    -- (symbols, externs, variables, literals, global addresses, max/min
    -- and sizeof).
  | ExpLabelF AST.Type t String
    -- ^ Counterpart of 'AST.ExpLabel'.
  | ExpIndexF AST.Type t AST.Type t
    -- ^ Counterpart of 'AST.ExpIndex'.
  | ExpToIxF t Integer
    -- ^ Counterpart of 'AST.ExpToIx'.
  | ExpSafeCastF AST.Type t
    -- ^ Counterpart of 'AST.ExpSafeCast'.
  | ExpOpF AST.ExpOp [t]
    -- ^ Counterpart of 'AST.ExpOp': an operator and its operands.
  deriving (Show, Eq, Ord, Foldable, Functor, Traversable)
-- | Tell data-reify how to walk an expression: constructors with
-- sub-expressions hand each one to @child@ (left to right), everything
-- else is stored whole as an 'ExpSimpleF' leaf.
instance MuRef AST.Expr where
  type DeRef AST.Expr = CSE
  mapDeRef child e = fmap CSEExpr $ case e of
      AST.ExpSym{} -> leaf
      AST.ExpExtern{} -> leaf
      AST.ExpVar{} -> leaf
      AST.ExpLit{} -> leaf
      AST.ExpLabel ty ex nm -> (\ ex' -> ExpLabelF ty ex' nm) <$> child ex
      AST.ExpIndex ty1 ex1 ty2 ex2 ->
        (\ arr ix -> ExpIndexF ty1 arr ty2 ix) <$> child ex1 <*> child ex2
      AST.ExpToIx ex bound -> (\ ex' -> ExpToIxF ex' bound) <$> child ex
      AST.ExpSafeCast ty ex -> fmap (ExpSafeCastF ty) (child ex)
      AST.ExpOp op args -> fmap (ExpOpF op) (traverse child args)
      AST.ExpAddrOfGlobal{} -> leaf
      AST.ExpMaxMin{} -> leaf
      AST.ExpSizeOf{} -> leaf
    where
      -- The expression itself, with nothing exposed as a child node.
      leaf = pure (ExpSimpleF e)
-- | Turn one expression layer whose children are already 'AST.Expr'
-- values back into the corresponding 'AST.Expr'.
toExpr :: ExprF AST.Expr -> AST.Expr
toExpr layer = case layer of
  ExpSimpleF whole -> whole
  ExpLabelF structTy base field -> AST.ExpLabel structTy base field
  ExpIndexF arrTy arr ixTy ix -> AST.ExpIndex arrTy arr ixTy ix
  ExpToIxF inner bound -> AST.ExpToIx inner bound
  ExpSafeCastF fromTy inner -> AST.ExpSafeCast fromTy inner
  ExpOpF op operands -> AST.ExpOp op operands
-- | Build a reference to the second type that has the same kind of
-- reference ('AST.TyRef' or 'AST.TyConstRef') as the first type.
-- Calls 'error' if the first type is not one of those two.
copyConst :: AST.Type -> AST.Type -> AST.Type
copyConst refTy pointee = case refTy of
  AST.TyRef _ -> AST.TyRef pointee
  AST.TyConstRef _ -> AST.TyConstRef pointee
  _ -> error ("Ivory.Opts.CSE.copyConst: expected a Ref type but got " ++ show refTy)
-- | Pair every child of an expression layer with the type at which that
-- child has to be generated. The first argument is the type of the
-- layer as a whole.
labelTypes :: AST.Type -> ExprF k -> ExprF (k, AST.Type)
labelTypes resultTy layer = case layer of
    ExpSimpleF e -> ExpSimpleF e
    -- The struct/array operand is a reference of the same kind as the
    -- result reference.
    ExpLabelF structTy base field ->
      ExpLabelF structTy (base, copyConst resultTy structTy) field
    ExpIndexF arrTy arr ixTy ix ->
      ExpIndexF arrTy (arr, copyConst resultTy arrTy) ixTy (ix, ixTy)
    ExpToIxF inner bound -> ExpToIxF (inner, ixRep) bound
    ExpSafeCastF fromTy inner -> ExpSafeCastF fromTy (inner, fromTy)
    ExpOpF op operands -> ExpOpF op (operandTypes op operands)
  where
    -- Give every operand the same type.
    allAt t = map (\ operand -> (operand, t))

    -- Operators that carry their operand type use it; the conditional's
    -- first operand is boolean; everything else takes operands of the
    -- result type.
    operandTypes op operands = case op of
      AST.ExpEq t -> allAt t operands
      AST.ExpNeq t -> allAt t operands
      AST.ExpCond -> case operands of
        [] -> []
        cond : branches -> (cond, AST.TyBool) : allAt resultTy branches
      AST.ExpGt _ t -> allAt t operands
      AST.ExpLt _ t -> allAt t operands
      AST.ExpIsNan t -> allAt t operands
      AST.ExpIsInf t -> allAt t operands
      _ -> allAt resultTy operands
-- | One layer of a statement (or of a statement list), with each child
-- expression and nested block replaced by a reference of type @t@.
-- 'toBlock' turns a layer back into generated statements.
data BlockF t
  = StmtSimple AST.Stmt
    -- ^ A statement kept whole, with no children exposed to the graph
    -- (void return, reference allocation, break and comments).
  | StmtIfTE t t t
    -- ^ Condition, true branch, false branch.
  | StmtAssert t
  | StmtCompilerAssert t
  | StmtAssume t
  | StmtReturn (AST.Typed t)
  | StmtDeref AST.Type AST.Var t
  | StmtStore AST.Type t t
    -- ^ Stored type, destination, value.
  | StmtAssign AST.Type AST.Var t
  | StmtCall AST.Type (Maybe AST.Var) AST.Name [AST.Typed t]
  | StmtLocal AST.Type AST.Var (InitF t)
  | StmtRefCopy AST.Type t t
    -- ^ Copied type, destination, source.
  | StmtRefZero AST.Type t
  | StmtLoop Integer AST.Var t (LoopIncrF t) t
    -- ^ The fields of 'AST.Loop' in the same order; the last one is the
    -- loop body.
  | StmtForever t
  | Block [t]
    -- ^ A list of statements.
  deriving (Show, Eq, Ord, Functor)
-- | Counterpart of the loop increment of 'AST.Loop' ('AST.IncrTo' /
-- 'AST.DecrTo'), with the expression replaced by a reference of type @t@.
data LoopIncrF t
  = IncrTo t
  | DecrTo t
  deriving (Show, Eq, Ord, Functor)
-- | Counterpart of the initializer of 'AST.Local', with every initializing
-- expression replaced by a reference of type @t@.
data InitF t
  = InitZero
  | InitExpr AST.Type t
  | InitStruct [(String, InitF t)]
    -- ^ Field name paired with that field's initializer.
  | InitArray [InitF t] Bool
    -- ^ Element initializers, plus the flag carried over unchanged from
    -- 'AST.InitArray'.
  deriving (Show, Eq, Ord, Functor)
-- | Tell data-reify how to walk a statement: every expression and nested
-- block is handed to @child@ (left to right); statements without any are
-- stored whole as 'StmtSimple'.
instance MuRef AST.Stmt where
  type DeRef AST.Stmt = CSE
  mapDeRef child stmt = fmap CSEBlock $ case stmt of
      AST.IfTE cond tb fb -> liftA2 StmtIfTE (child cond) (child tb) <*> child fb
      AST.Assert cond -> fmap StmtAssert (child cond)
      AST.CompilerAssert cond -> fmap StmtCompilerAssert (child cond)
      AST.Assume cond -> fmap StmtAssume (child cond)
      AST.Return (AST.Typed ty ex) -> fmap (StmtReturn . AST.Typed ty) (child ex)
      AST.Deref ty var ex -> fmap (StmtDeref ty var) (child ex)
      AST.Store ty lhs rhs -> liftA2 (StmtStore ty) (child lhs) (child rhs)
      AST.Assign ty var ex -> fmap (StmtAssign ty var) (child ex)
      AST.Call ty mv nm args -> fmap (StmtCall ty mv nm) (traverse typedArg args)
      AST.Local ty var initex -> fmap (StmtLocal ty var) (reifyInit initex)
      AST.RefCopy ty dst src -> liftA2 (StmtRefCopy ty) (child dst) (child src)
      AST.RefZero ty dst -> fmap (StmtRefZero ty) (child dst)
      AST.Loop m var ex incr lb ->
        liftA2 (StmtLoop m var) (child ex) (reifyIncr incr) <*> child lb
      AST.Forever lb -> fmap StmtForever (child lb)
      AST.ReturnVoid -> opaque
      AST.AllocRef{} -> opaque
      AST.Break -> opaque
      AST.Comment{} -> opaque
    where
      -- The statement itself, with nothing exposed as a child node.
      opaque = pure (StmtSimple stmt)

      -- A call argument keeps its type annotation; only the expression
      -- becomes a child.
      typedArg (AST.Typed argTy argEx) = fmap (AST.Typed argTy) (child argEx)

      -- Initializers are walked recursively down to their expressions.
      reifyInit i = case i of
        AST.InitZero -> pure InitZero
        AST.InitExpr ty ex -> fmap (InitExpr ty) (child ex)
        AST.InitStruct fields ->
          fmap InitStruct (traverse (\ (nm, fieldInit) -> fmap ((,) nm) (reifyInit fieldInit)) fields)
        AST.InitArray elements flag ->
          fmap (\ elements' -> InitArray elements' flag) (traverse reifyInit elements)

      reifyIncr incr = case incr of
        AST.IncrTo ex -> fmap IncrTo (child ex)
        AST.DecrTo ex -> fmap DecrTo (child ex)
-- | A list of reifiable things (in this module, a list of statements) is
-- a 'Block' node whose children are the list elements, in order.
instance (MuRef a, DeRef [a] ~ DeRef a) => MuRef [a] where
  type DeRef [a] = CSE
  mapDeRef child xs = fmap (CSEBlock . Block) (traverse child xs)
-- | Generate the statement(s) for one 'BlockF' layer.
--
-- The first argument generates the expression for a child at a requested
-- type, the second generates the statements of a child block. Each
-- statement constructor emits exactly one 'AST.Stmt' into the writer
-- output; a 'Block' simply runs its children in order.
toBlock :: (k -> AST.Type -> BlockM AST.Expr) -> (k -> BlockM ()) -> BlockF k -> BlockM ()
toBlock expr block b = case b of
    Block stmts -> mapM_ block stmts
    StmtSimple s -> emit (return s)
    StmtIfTE ex tb fb -> emit (AST.IfTE <$> boolExpr ex <*> nested tb <*> nested fb)
    StmtAssert cond -> emit (AST.Assert <$> boolExpr cond)
    StmtCompilerAssert cond -> emit (AST.CompilerAssert <$> boolExpr cond)
    StmtAssume cond -> emit (AST.Assume <$> boolExpr cond)
    StmtReturn (AST.Typed ty ex) -> emit (AST.Return . AST.Typed ty <$> expr ex ty)
    StmtDeref ty var ex -> emit (AST.Deref ty var <$> constRefTo ty ex)
    StmtStore ty lhs rhs -> emit (AST.Store ty <$> refTo ty lhs <*> expr rhs ty)
    StmtAssign ty var ex -> emit (AST.Assign ty var <$> expr ex ty)
    StmtCall ty mv nm args -> emit (AST.Call ty mv nm <$> mapM typedArg args)
    StmtLocal ty var initex -> emit (AST.Local ty var <$> initOf initex)
    StmtRefCopy ty dst src -> emit (AST.RefCopy ty <$> refTo ty dst <*> constRefTo ty src)
    StmtRefZero ty dst -> emit (AST.RefZero ty <$> refTo ty dst)
    StmtLoop m var ex incr lb -> emit (AST.Loop m var <$> ixExpr ex <*> incrOf incr <*> nested lb)
    StmtForever lb -> emit (AST.Forever <$> nested lb)
  where
    -- Run a statement generator and append its statement to the output.
    emit gen = gen >>= put . D.singleton

    -- Children requested at the types their position demands.
    boolExpr e = expr e AST.TyBool
    ixExpr e = expr e ixRep
    refTo ty e = expr e (AST.TyRef ty)
    constRefTo ty e = expr e (AST.TyConstRef ty)

    -- A nested block is generated in its own scope (see 'genBlock').
    nested = genBlock . block

    -- Call arguments are generated at their annotated type.
    typedArg (AST.Typed argTy argEx) = AST.Typed argTy <$> expr argEx argTy

    initOf i = case i of
      InitZero -> pure AST.InitZero
      InitExpr ty ex -> AST.InitExpr ty <$> expr ex ty
      InitStruct fields ->
        AST.InitStruct <$> traverse (\ (nm, fieldInit) -> (,) nm <$> initOf fieldInit) fields
      InitArray elements flag ->
        (\ elements' -> AST.InitArray elements' flag) <$> traverse initOf elements

    incrOf incr = case incr of
      IncrTo ex -> AST.IncrTo <$> ixExpr ex
      DecrTo ex -> AST.DecrTo <$> ixExpr ex
-- | Run a block generator in a nested scope and return the statements it
-- produced as an 'AST.Block' instead of emitting them.
--
-- Bindings made available inside the block are dropped afterwards: the
-- set of available bindings is reset to what it was on entry. The other
-- 'Bindings' fields keep the values they had at the end of the block.
genBlock :: BlockM () -> BlockM AST.Block
genBlock gen = do
  onEntry <- get
  ((), emitted) <- collect gen
  let leaveScope current = current { availableBindings = availableBindings onEntry }
  sets_ leaveScope
  return (D.toList emitted)
-- | Generators for the graph nodes processed so far, keyed by node id:
-- first the expression generators (which take the type the expression is
-- wanted at), then the block generators.
type Facts = (IntMap (AST.Type -> BlockM AST.Expr), IntMap (BlockM ()))
-- | Look up the generator recorded for a node id. Calls 'error' when the
-- id has no entry in the map.
getFact :: IntMap v -> Unique -> v
getFact m k =
  maybe (error "IvoryCSE: cycle detected in expression graph") id (IntMap.lookup k m)
-- | Record the generator for one graph node in the 'Facts'.
--
-- The 'IntSet' is the set of variable numbers that end up never being
-- looked up again (the final 'unusedBindings', supplied by
-- 'reconstruct'); expressions bound to those are generated inline
-- instead of being assigned to a variable.
updateFacts :: IntSet -> (Unique, CSE Unique) -> Facts -> Facts
-- A block node's generator is 'toBlock', with children resolved through
-- the facts recorded before this node.
updateFacts _ (ident, CSEBlock block) (exprFacts, blockFacts) = (exprFacts, IntMap.insert ident (toBlock (getFact exprFacts) (getFact blockFacts) block) blockFacts)
updateFacts usedOnce (ident, CSEExpr expr) (exprFacts, blockFacts) = (IntMap.insert ident fact exprFacts, blockFacts)
  where
    -- Name of the variable with the given number.
    nameOf var = AST.VarName $ "cse" ++ show var
    fact = case expr of
      -- Leaf expressions are returned as they are, whatever the type.
      ExpSimpleF e -> const $ return e
      ex -> \ ty -> do
        bindings <- get
        case Map.lookup (ident, ty) $ availableBindings bindings of
          -- Already generated at this type and still in scope: note that
          -- the variable is used again and refer to it.
          Just var -> do
            set $ bindings { unusedBindings = IntSet.delete var $ unusedBindings bindings }
            return $ AST.ExpVar $ nameOf var
          Nothing -> do
            -- Generate the children first, each at the type 'labelTypes'
            -- assigns to it, and rebuild the expression from them.
            ex' <- fmap toExpr $ mapM (uncurry $ getFact exprFacts) $ labelTypes ty ex
            -- Allocate the next variable number; it starts out both
            -- available and unused.
            var <- sets $ \ (Bindings { availableBindings = avail, unusedBindings = unused, totalBindings = maxId}) ->
              (maxId, Bindings
                { availableBindings = Map.insert (ident, ty) maxId avail
                , unusedBindings = IntSet.insert maxId unused
                , totalBindings = maxId + 1
                })
            -- 'usedOnce' comes from the final state of the whole run, so
            -- this choice must not affect the state: it only decides the
            -- result expression and the emitted statements.
            lift $ if var `IntSet.member` usedOnce
              then return ex'
              else do
                put $ D.singleton $ AST.Assign ty (nameOf var) ex'
                return $ AST.ExpVar $ nameOf var
-- | Literal expressions that 'dedup' can rewrite other expressions to.
-- 'reconstruct' adds one graph node per constructor, and 'constUnique'
-- derives each node id from the constructor's position via 'fromEnum',
-- so the order of the constructors determines those ids.
data Constant
  = ConstFalse
  | ConstTrue
  | ConstZero
  | ConstTwo
  deriving (Bounded, Enum)
-- | The graph node for a constant: a leaf holding the matching literal.
constExpr :: Constant -> CSE Unique
constExpr c = CSEExpr (ExpSimpleF (AST.ExpLit literal))
  where
    literal = case c of
      ConstFalse -> AST.LitBool False
      ConstTrue -> AST.LitBool True
      ConstZero -> AST.LitInteger 0
      ConstTwo -> AST.LitInteger 2
-- | The node id of a constant: -1 for the first constructor, -2 for the
-- second, and so on.
--
-- NOTE(review): presumably negative so that these ids cannot collide
-- with the ids 'reifyGraph' assigns -- confirm.
constUnique :: Constant -> Unique
constUnique c = negate (fromEnum c + 1)
-- | Accumulator of 'dedup': the distinct nodes seen so far with the id
-- that represents each, the redirections from dropped node ids to the id
-- that replaces them, and the generators recorded for the kept nodes.
type Dupes = (Map (CSE Unique) Unique, IntMap Unique, Facts)
-- | Fold one graph node into the 'Dupes' accumulator.
--
-- The node's children are first redirected through the remap table and
-- its operands normalized. The node is then either redirected to an
-- existing node (a constant, one of its own children, or an identical
-- node seen earlier) or kept and given a generator via 'updateFacts'.
-- The 'IntSet' is only passed on to 'updateFacts'.
dedup :: IntSet -> (Unique, CSE Unique) -> Dupes -> Dupes
dedup usedOnce (ident, expr) (seen, remap, facts) = case canonical of
    -- Comparisons of a non-floating-point value with itself have a known
    -- result. (Floating-point types are excluded, presumably because NaN
    -- does not compare equal to itself.)
    CSEExpr (ExpOpF (AST.ExpEq ty) [x, y])
      | sameExact ty x y -> toConst ConstTrue
    CSEExpr (ExpOpF (AST.ExpNeq ty) [x, y])
      | sameExact ty x y -> toConst ConstFalse
    CSEExpr (ExpOpF (AST.ExpGt orEqual ty) [x, y])
      | sameExact ty x y -> toConst (boolConst orEqual)
    CSEExpr (ExpOpF (AST.ExpLt orEqual ty) [x, y])
      | sameExact ty x y -> toConst (boolConst orEqual)
    CSEExpr (ExpOpF AST.ExpBitXor [x, y])
      | x == y -> toConst ConstZero
    -- An idempotent operator applied to equal operands is that operand.
    CSEExpr (ExpOpF op [x, y])
      | isIdempotent op && x == y -> remapTo x
    -- A conditional whose branches are the same node is that branch; the
    -- condition is dropped.
    CSEExpr (ExpOpF AST.ExpCond [_, onTrue, onFalse])
      | onTrue == onFalse -> remapTo onTrue
    CSEBlock (StmtIfTE _ onTrue onFalse)
      | onTrue == onFalse -> remapTo onTrue
    -- A one-statement block is that statement.
    CSEBlock (Block [only]) -> remapTo only
    -- Otherwise reuse an identical node if one was seen, or keep this one.
    _ -> case Map.lookup canonical seen of
      Just earlier -> remapTo earlier
      Nothing ->
        ( Map.insert canonical ident seen
        , remap
        , updateFacts usedOnce (ident, canonical) facts
        )
  where
    -- Drop this node and redirect its id to another node.
    remapTo target = (seen, IntMap.insert ident target remap, facts)
    toConst = remapTo . constUnique
    boolConst flag = if flag then ConstTrue else ConstFalse

    sameExact ty x y = not (isFloat ty) && x == y

    -- Follow the redirection recorded for a child, if there is one.
    resolve k = IntMap.findWithDefault k k remap

    -- The node with redirected children and operands in a fixed order,
    -- so that equivalent nodes compare equal.
    canonical = case fmap resolve expr of
      -- x + x becomes 2 * x.
      CSEExpr (ExpOpF AST.ExpAdd [x, y])
        | x == y -> CSEExpr (ExpOpF AST.ExpMul (sort [constUnique ConstTwo, x]))
      CSEExpr (ExpOpF op operands)
        | isCommutative op -> CSEExpr (ExpOpF op (sort operands))
      other -> other

    isFloat ty = case ty of
      AST.TyFloat -> True
      AST.TyDouble -> True
      _ -> False

    isIdempotent op = case op of
      AST.ExpAnd -> True
      AST.ExpOr -> True
      AST.ExpBitAnd -> True
      AST.ExpBitOr -> True
      _ -> False

    isCommutative op = case op of
      AST.ExpEq _ -> True
      AST.ExpNeq _ -> True
      AST.ExpMul -> True
      AST.ExpAdd -> True
      AST.ExpBitAnd -> True
      AST.ExpBitOr -> True
      AST.ExpBitXor -> True
      _ -> False
-- | Rebuild a statement list from the reified graph of a procedure body.
--
-- Every node is folded through 'dedup', which yields a generator per
-- surviving block node. The generator of the root is then run from an
-- empty 'Bindings', and the statements it emits are the result.
--
-- Calls 'error' if, after deduplication, the root has no block generator.
reconstruct :: Graph CSE -> AST.Block
reconstruct (Graph subexprs root) = D.toList rootBlock
  where
    -- 'foldr' handles the rightmost element first, so the constants are
    -- appended after @subexprs@ to be processed before any node that
    -- 'dedup' may redirect to them.
    (_, remap, (_, blockFacts)) =
      foldr (dedup usedOnce) mempty
        $ subexprs ++ [ (constUnique c, constExpr c) | c <- [minBound..maxBound] ]

    -- The root may itself have been redirected to another node. An
    -- explicit error replaces the former partial @Just@ pattern so that a
    -- missing generator is reported with a meaningful message.
    rootGen = case IntMap.lookup (IntMap.findWithDefault root root remap) blockFacts of
      Just gen -> gen
      Nothing -> error "Ivory.Opts.CSE.reconstruct: root of the graph has no block generator"

    -- @usedOnce@ is the final set of unused bindings, yet it is also an
    -- input to 'dedup' above. This circular definition relies on lazy
    -- evaluation, so these bindings must stay lazy pattern bindings.
    (((), Bindings { unusedBindings = usedOnce }), rootBlock) =
      runM rootGen $ Bindings Map.empty IntSet.empty 0