{-# LANGUAGE GeneralizedNewtypeDeriving #-}

-- | This module provides a monadic interface to rewrites on algebra DAGs.
module Database.Algebra.Rewrite.DagRewrite
       ( -- ** The Rewrite monad
         Rewrite
       , runRewrite
       , initRewriteState
       , Log
       , logGeneral
       , logRewrite
       , reachableNodesFrom
       , parents
       , topsort
       , operator
       , operatorSafe
       , rootNodes
       , exposeDag
       , getExtras
       , updateExtras
       , condRewrite
       , insert
       , insertNoShare
       , replaceChild
       , replace
       , replaceWithNew
       , replaceRoot
       , infer
       , collect
       ) where

import           Control.Applicative
import           Control.Monad.State
import           Control.Monad.Writer
import qualified Data.IntMap                 as IM
import qualified Data.Sequence               as Seq
import qualified Data.Set                    as S
import           Debug.Trace

import qualified Database.Algebra.Dag        as Dag
import           Database.Algebra.Dag.Common

-- | Cache some topological information about the DAG.
data Cache = Cache
    { cachedTopOrdering :: Maybe [AlgNode] -- ^ A previously computed topological ordering, if any
    }

-- | A cache that holds no topological ordering.
emptyCache :: Cache
emptyCache = Cache Nothing

-- | The state threaded through every rewrite action. @o@ is the operator type
-- of the DAG, @e@ the type of the user-supplied extras.
data RewriteState o e = RewriteState
    { dag          :: Dag.AlgebraDag o  -- ^ The DAG itself
    , cache        :: Cache             -- ^ Cache of some topological information
    , extras       :: e                 -- ^ Polymorphic container for whatever needs to be provided additionally.
    , debugFlag    :: Bool              -- ^ Whether to output log messages via 'Debug.Trace.trace'
    , collectNodes :: S.Set AlgNode     -- ^ Set of nodes which must be checked during garbage collection
    }

-- | A Monad for DAG rewrites, parameterized over the type of algebra operators
-- (@o@) and the type of the extras container (@e@). It is a log writer layered
-- on top of a state monad carrying the 'RewriteState'.
newtype Rewrite o e a = R (WriterT Log (State (RewriteState o e)) a)
    deriving (Monad, Functor, Applicative)

-- FIXME Map.findMax might call error
-- NOTE(review): no use of Map.findMax is visible in this part of the file;
-- confirm whether this FIXME is still relevant.

-- | Build the initial state for a rewrite run from the DAG, the extras value
-- and the debug flag. The cache starts out empty and no nodes are scheduled
-- for collection.
initRewriteState :: (Ord o, Dag.Operator o) => Dag.AlgebraDag o -> e -> Bool -> RewriteState o e
initRewriteState d e debug =
    RewriteState { dag          = d
                 , cache        = emptyCache
                 , extras       = e
                 , debugFlag    = debug
                 , collectNodes = S.empty
                 }

-- | Run a rewrite action on the supplied graph. Returns the rewritten DAG, the
-- (potentially modified) extras, the result of the rewrite and the rewrite log.
runRewrite :: Dag.Operator o => Rewrite o e r -> Dag.AlgebraDag o -> e -> Bool -> (Dag.AlgebraDag o, e, r, Log)
runRewrite rewrite d e debug =
    -- Bind lazily so that nothing is forced before the caller inspects
    -- a component of the returned tuple.
    let startState                    = initRewriteState d e debug
        ((result, messages), endState) = runState (runWriterT (unwrapR rewrite)) startState
    in (dag endState, extras endState, result, messages)

-- | The messages accumulated by a sequence of rewrite actions, in order.
type Log = Seq.Seq String

-- FIXME unwrapR should not be necessary: just provide a type alias for the monad stack
-- | Strip the 'Rewrite' newtype and expose the underlying monad stack.
unwrapR :: Rewrite o e a -> WriterT Log (State (RewriteState o e)) a
unwrapR (R m) = m

-- | Drop all cached topological information.
invalidateCacheM :: Rewrite o e ()
invalidateCacheM = R $ modify $ \st -> st { cache = emptyCache }

-- | Internal helper: install a new DAG in the rewrite state.
putDag :: Dag.AlgebraDag o -> Rewrite o e ()
putDag newDag = R $ modify $ \st -> st { dag = newDag }

-- | Internal helper: install a new cache in the rewrite state.
putCache :: Cache -> Rewrite o e ()
putCache newCache = R $ modify $ \st -> st { cache = newCache }

-- | Append a general message to the log. If the debug flag is set, the
-- message is additionally emitted via 'trace'.
logGeneral :: String -> Rewrite o e ()
logGeneral msg = do
    debugging <- R $ gets debugFlag
    let record = R $ tell $ Seq.singleton msg
    if debugging
        then trace msg record
        else record

-- | Log that the named rewrite fires at the given node.
logRewrite :: String -> AlgNode -> Rewrite o e ()
logRewrite rewrite node =
    logGeneral $ concat ["Triggering rewrite ", rewrite, " at node ", show node]

-- | The set of nodes reachable from the given node in the current DAG.
reachableNodesFrom :: AlgNode -> Rewrite o e (S.Set AlgNode)
reachableNodesFrom n = R $ gets (Dag.reachableNodesFrom n . dag)

-- | The parents of a node in the current DAG.
parents :: AlgNode -> Rewrite o e [AlgNode]
parents n = R $ do
    currentDag <- gets dag
    return $ Dag.parents n currentDag

-- | A topological ordering of the DAG as computed by 'Dag.topsort'. The
-- ordering is served from the cache when present; otherwise it is computed
-- and stored in the cache before being returned.
topsort :: Dag.Operator o => Rewrite o e [AlgNode]
topsort = R $ do
    currentCache <- gets cache
    case cachedTopOrdering currentCache of
        Just known -> return known
        Nothing    -> do
            fresh <- gets (Dag.topsort . dag)
            unwrapR $ putCache $ currentCache { cachedTopOrdering = Just fresh }
            return fresh

-- | Return the operator for a node id.
operator :: Dag.Operator o => AlgNode -> Rewrite o e o
operator n = R $ do
    d <- gets dag
    return $ Dag.operator n d

-- | Look up the operator for a node id directly in the DAG's node map.
-- Returns 'Nothing' if the node map has no entry for the id.
operatorSafe :: AlgNode -> Rewrite o e (Maybe o)
operatorSafe n = R $ do
    d <- gets dag
    return $ IM.lookup n (Dag.nodeMap d)

-- | Returns the root nodes of the DAG.
rootNodes :: Rewrite o e [AlgNode]
rootNodes = R $ gets (Dag.rootNodes . dag)

-- | Exposes the current state of the DAG
exposeDag :: Rewrite o e (Dag.AlgebraDag o)
exposeDag = R $ gets dag

-- | Return the current extras value.
getExtras :: Rewrite o e e
getExtras = R $ gets extras

-- | Preserve the effects of a rewrite only if the rewrite signals
-- success by returning True. Otherwise, the state before the rewrite
-- is put in place again.
--
-- Only the rewrite state (DAG, cache, extras, collect set) is restored.
-- Log messages written by the failed rewrite remain in the log, because
-- the log lives in the writer layer and not in the saved state.
--
-- A rollback is reported via 'trace' only when the debug flag is set,
-- consistent with 'logGeneral'.
condRewrite :: Rewrite o e Bool -> Rewrite o e Bool
condRewrite r = R $ do
    saved   <- get
    success <- unwrapR r
    if success
        then return success
        else do
            -- The debug flag is read from the saved state, i.e. the state
            -- that is about to be reinstated.
            if debugFlag saved
                then trace "Rollback" $ put saved
                else put saved
            return success

-- | Replace the extras value.
updateExtras :: e -> Rewrite o e ()
updateExtras e = R $ do
    s <- get
    put $ s { extras = e }

-- | Insert an operator into the DAG and return its node id. If the operator is already
-- present (same op, same children), reuse it.
insert :: (Dag.Operator o, Show o) => o -> Rewrite o e AlgNode
insert op = R $ do
    d <- gets dag
    -- The DAG is about to change, so cached topological information is dropped.
    unwrapR invalidateCacheM
    let (n, d') = Dag.insert op d
    unwrapR $ putDag d'
    return n

-- | Insert an operator into the DAG and return its node id WITHOUT reusing an
-- operator if it is already present.
insertNoShare :: Dag.Operator o => o -> Rewrite o e AlgNode
insertNoShare op = R $ do
    d <- gets dag
    -- The DAG is about to change, so cached topological information is dropped.
    unwrapR invalidateCacheM
    let (n, d') = Dag.insertNoShare op d
    unwrapR $ putDag d'
    return n

-- | @replaceChild n old new@ replaces all links from node @n@ to node @old@
-- with links to node @new@.
replaceChild :: Dag.Operator o => AlgNode -> AlgNode -> AlgNode -> Rewrite o e ()
replaceChild n old new = R $ do
    s <- get
    unwrapR invalidateCacheM
    unwrapR $ putDag $ Dag.replaceChild n old new $ dag s

-- | @replace old new@ replaces _all_ links to @old@ with links to @new@:
-- every parent of @old@ is rewired, @old@ is scheduled for garbage
-- collection, and 'Dag.replaceRoot' is applied for the pair.
replace :: Dag.Operator o => AlgNode -> AlgNode -> Rewrite o e ()
replace old new = do
    ps <- parents old
    forM_ ps $ \p -> replaceChild p old new
    addCollectNode old
    R $ do
        -- The root update changes the DAG as well. If 'old' has no parents,
        -- no 'replaceChild' call has invalidated the cache, so do it here.
        unwrapR invalidateCacheM
        d <- gets dag
        unwrapR $ putDag $ Dag.replaceRoot d old new

-- | Creates a new node from the operator and replaces the old node with it
-- by rewiring all links to the old node. Returns the id of the new node.
replaceWithNew :: (Dag.Operator o, Show o) => AlgNode -> o -> Rewrite o e AlgNode
replaceWithNew oldNode newOp = do
    newNode <- insert newOp
    replace oldNode newNode
    return newNode

-- | Apply a pure function to the DAG.
infer :: (Dag.AlgebraDag o -> b) -> Rewrite o e b
infer f = R $ gets (f . dag)

-- | Schedule a node to be checked during the next 'collect'.
addCollectNode :: AlgNode -> Rewrite o e ()
addCollectNode n = R $ do
    s <- get
    put $ s { collectNodes = S.insert n $ collectNodes s }

-- | Run 'Dag.collect' on the nodes scheduled via 'addCollectNode' and clear
-- the schedule. The cached topological information is dropped because the
-- DAG has been replaced.
collect :: (Show o, Dag.Operator o) => Rewrite o e ()
collect = R $ do
    s <- get
    let d' = Dag.collect (collectNodes s) (dag s)
    put s { dag = d', collectNodes = S.empty, cache = emptyCache }

-- | Apply 'Dag.replaceRoot' to the DAG for the given old and new root.
-- Calls 'error' if the new root is not present in the DAG's node map.
-- The cached topological information is dropped because the DAG changes.
replaceRoot :: Dag.Operator o => AlgNode -> AlgNode -> Rewrite o e ()
replaceRoot oldRoot newRoot = R $ do
    s <- get
    if not $ IM.member newRoot $ Dag.nodeMap $ dag s
        then error "replaceRootM: new root node is not present in the DAG"
        else do
            unwrapR invalidateCacheM
            unwrapR $ putDag $ Dag.replaceRoot (dag s) oldRoot newRoot