-- | A monadic interface for rewriting algebra DAGs.
--
-- Exports the 'Rewrite' monad and its runner, logging helpers, read-only
-- queries on the DAG held in the rewrite state, and primitives that
-- insert, replace, and collect nodes.
module Database.Algebra.Rewrite.DagRewrite
(
Rewrite
, runRewrite
, initRewriteState
, Log
, logGeneral
, logRewrite
, reachableNodesFrom
, parents
, topsort
, operator
, operatorSafe
, rootNodes
, exposeDag
, getExtras
, updateExtras
, condRewrite
, insert
, insertNoShare
, replaceChild
, replace
, replaceWithNew
, replaceRoot
, infer
, collect
) where
import Control.Monad.State
import Control.Monad.Writer
import qualified Data.IntMap as IM
import qualified Data.Sequence as Seq
import qualified Data.Set as S
import Debug.Trace
import qualified Database.Algebra.Dag as Dag
import Database.Algebra.Dag.Common
-- | Memoised results derived from the current DAG. Holds the topological
-- ordering once 'topsort' has computed it, or 'Nothing' otherwise.
data Cache = Cache { cachedTopOrdering :: Maybe [AlgNode] }
-- | A cache in which nothing has been memoised yet.
emptyCache :: Cache
emptyCache = Cache { cachedTopOrdering = Nothing }
-- | The state threaded through a 'Rewrite' computation.
data RewriteState o e = RewriteState
{ dag :: Dag.AlgebraDag o -- ^ The DAG being rewritten.
, cache :: Cache -- ^ Memoised results derived from 'dag'.
, extras :: e -- ^ Caller-supplied extra state, read and written via 'getExtras' / 'updateExtras'.
, debugFlag :: Bool -- ^ When set, 'logGeneral' also emits each message through 'trace'.
, collectNodes :: S.Set AlgNode -- ^ Nodes registered by 'replace' that 'collect' hands to 'Dag.collect'.
}
-- | The rewrite monad: a state monad over 'RewriteState' with a 'Log'
-- writer layered on top. @o@ is the operator type of the DAG, @e@ the type
-- of the extra state, and @a@ the result type.
newtype Rewrite o e a = R (WriterT Log (State (RewriteState o e)) a) deriving (Monad, Functor, Applicative)
-- | Build the starting state for a rewrite run from a DAG, the initial
-- extra state, and the debug flag. The cache starts empty and no nodes are
-- registered for collection.
initRewriteState :: (Ord o, Dag.Operator o) => Dag.AlgebraDag o -> e -> Bool -> RewriteState o e
initRewriteState d e debug =
  -- Positional construction; argument order follows the field order of
  -- 'RewriteState': dag, cache, extras, debugFlag, collectNodes.
  RewriteState d emptyCache e debug S.empty
-- | Run a rewrite against a DAG with the given initial extra state and
-- debug flag. Returns the final DAG, the final extra state, the result of
-- the rewrite, and the accumulated log.
runRewrite :: Dag.Operator o => Rewrite o e r -> Dag.AlgebraDag o -> e -> Bool -> (Dag.AlgebraDag o, e, r, Log)
runRewrite (R m) d e debug =
  let start                            = initRewriteState d e debug
      ((result, messages), finalState) = runState (runWriterT m) start
  in (dag finalState, extras finalState, result, messages)
-- | The rewrite log: a sequence of messages, appended to by 'logGeneral'.
type Log = Seq.Seq String
-- | Strip the 'R' constructor, exposing the underlying writer/state stack.
unwrapR :: Rewrite o e a -> WriterT Log (State (RewriteState o e)) a
unwrapR rewrite =
  case rewrite of
    R inner -> inner
-- | Discard everything memoised in the cache, leaving the rest of the
-- state untouched.
invalidateCacheM :: Rewrite o e ()
invalidateCacheM = R $ modify $ \s -> s { cache = emptyCache }
-- | Store a new DAG in the rewrite state. The cache is not touched here.
putDag :: Dag.AlgebraDag o -> Rewrite o e ()
putDag d = R $ modify $ \s -> s { dag = d }
-- | Store a new cache in the rewrite state.
putCache :: Cache -> Rewrite o e ()
putCache c = R (get >>= \s -> put s { cache = c })
-- | Append a message to the rewrite log. When the debug flag is set, the
-- message is additionally emitted through 'trace'.
logGeneral :: String -> Rewrite o e ()
logGeneral msg = do
  debugging <- R $ gets debugFlag
  let record = R $ tell $ Seq.singleton msg
  -- 'trace' wraps the logging action itself, so the message is printed
  -- when that action is evaluated.
  if debugging then trace msg record else record
-- | Log that the named rewrite fired at the given node.
logRewrite :: String -> AlgNode -> Rewrite o e ()
logRewrite rewrite node =
  logGeneral (concat ["Triggering rewrite ", rewrite, " at node ", show node])
-- | The result of 'Dag.reachableNodesFrom' for the given node on the
-- current DAG.
reachableNodesFrom :: AlgNode -> Rewrite o e (S.Set AlgNode)
reachableNodesFrom n = R $ gets (Dag.reachableNodesFrom n . dag)
-- | The result of 'Dag.parents' for the given node on the current DAG.
parents :: AlgNode -> Rewrite o e [AlgNode]
parents n =
  R $ do
    d <- gets dag
    return $ Dag.parents n d
-- | The topological ordering of the current DAG as computed by
-- 'Dag.topsort'. The ordering is memoised in the cache: a cached ordering
-- is returned as is, otherwise it is computed, stored, and returned.
topsort :: Dag.Operator o => Rewrite o e [AlgNode]
topsort =
  R $ do
    st <- get
    let memo = cache st
    case cachedTopOrdering memo of
      Nothing -> do
        let fresh = Dag.topsort (dag st)
        -- Remember the ordering so later calls can skip the computation.
        put st { cache = memo { cachedTopOrdering = Just fresh } }
        return fresh
      Just known -> return known
-- | Look up the operator at a node via 'Dag.operator'. See 'operatorSafe'
-- for a variant that returns a 'Maybe'.
operator :: Dag.Operator o => AlgNode -> Rewrite o e o
operator n = R $ gets (Dag.operator n . dag)
-- | Look up the operator at a node in the DAG's node map, returning
-- 'Nothing' if the node has no entry.
operatorSafe :: AlgNode -> Rewrite o e (Maybe o)
operatorSafe n = R $ gets (IM.lookup n . Dag.nodeMap . dag)
-- | The root nodes of the current DAG.
rootNodes :: Rewrite o e [AlgNode]
rootNodes = R $ gets (Dag.rootNodes . dag)
-- | Return the current DAG as a plain value.
exposeDag :: Rewrite o e (Dag.AlgebraDag o)
exposeDag =
  R $ do
    s <- get
    return (dag s)
-- | Return the current extra state.
getExtras :: Rewrite o e e
getExtras = R $ liftM extras get
-- | Run a rewrite transactionally.
--
-- The complete rewrite state is snapshotted before running the given
-- rewrite. If the rewrite returns 'False', the snapshot is restored, which
-- undoes every change it made to the DAG, the cache, the extra state, and
-- the set of nodes registered for collection. Log entries written by the
-- rewrite are kept, because the log is not part of the state.
--
-- The \"Rollback\" trace message is emitted only when the debug flag is
-- set, matching the behaviour of 'logGeneral'.
--
-- Returns the result of the wrapped rewrite unchanged.
condRewrite :: Rewrite o e Bool -> Rewrite o e Bool
condRewrite r =
  R $ do
    snapshot <- get
    success <- unwrapR r
    if success
      then return success
      else
        let restore = put snapshot >> return success
        in if debugFlag snapshot
             then trace "Rollback" restore
             else restore
-- | Replace the extra state with the given value.
updateExtras :: e -> Rewrite o e ()
updateExtras e = R $ modify $ \s -> s { extras = e }
-- | Insert an operator into the DAG via 'Dag.insert' and return the node
-- it yields. The cache is invalidated.
insert :: (Dag.Operator o, Show o) => o -> Rewrite o e AlgNode
insert op =
  R $ do
    unwrapR invalidateCacheM
    current <- gets dag
    -- Keep the pair lazy; its components are taken apart with fst/snd.
    let result = Dag.insert op current
    unwrapR $ putDag (snd result)
    return (fst result)
-- | Insert an operator into the DAG via 'Dag.insertNoShare' and return the
-- node it yields. The cache is invalidated.
insertNoShare :: Dag.Operator o => o -> Rewrite o e AlgNode
insertNoShare op =
  R $ do
    before <- gets dag
    let (fresh, after) = Dag.insertNoShare op before
    unwrapR (invalidateCacheM >> putDag after)
    return fresh
-- | @replaceChild n old new@ applies 'Dag.replaceChild' with those
-- arguments to the current DAG and stores the result. The cache is
-- invalidated.
replaceChild :: Dag.Operator o => AlgNode -> AlgNode -> AlgNode -> Rewrite o e ()
replaceChild n old new = do
  invalidateCacheM
  current <- R (gets dag)
  putDag (Dag.replaceChild n old new current)
-- | @replace old new@ redirects every parent of @old@ to @new@, registers
-- @old@ for a later 'collect', and finally applies 'Dag.replaceRoot' so
-- that the root list is updated as well.
replace :: Dag.Operator o => AlgNode -> AlgNode -> Rewrite o e ()
replace old new = do
  oldParents <- parents old
  mapM_ (\p -> replaceChild p old new) oldParents
  addCollectNode old
  current <- R (gets dag)
  putDag (Dag.replaceRoot current old new)
-- | Insert a new operator, substitute the resulting node for @oldNode@ via
-- 'replace', and return the new node.
replaceWithNew :: (Dag.Operator o, Show o) => AlgNode -> o -> Rewrite o e AlgNode
replaceWithNew oldNode newOp =
  insert newOp >>= \newNode ->
    replace oldNode newNode >> return newNode
-- | Apply a pure function to the current DAG and return its result.
infer :: (Dag.AlgebraDag o -> b) -> Rewrite o e b
infer f = R $ gets (f . dag)
-- | Register a node in the set that 'collect' later hands to
-- 'Dag.collect'.
addCollectNode :: AlgNode -> Rewrite o e ()
addCollectNode n =
  R $ modify $ \s -> s { collectNodes = S.insert n (collectNodes s) }
-- | Hand all nodes registered for collection to 'Dag.collect', store the
-- resulting DAG, and clear the set of registered nodes.
--
-- The cache is reset in the same step. 'Dag.collect' produces a new DAG,
-- so a topological ordering memoised beforehand may no longer describe it.
-- The other DAG-changing primitives ('insert', 'insertNoShare',
-- 'replaceChild') reset the cache for the same reason.
collect :: (Show o, Dag.Operator o) => Rewrite o e ()
collect =
  R $ do
    s <- get
    let d' = Dag.collect (collectNodes s) (dag s)
    put s { dag = d', cache = emptyCache, collectNodes = S.empty }
-- | @replaceRoot oldRoot newRoot@ applies 'Dag.replaceRoot' to the current
-- DAG and stores the result.
--
-- Calls 'error' if @newRoot@ has no entry in the DAG's node map.
replaceRoot :: Dag.Operator o => AlgNode -> AlgNode -> Rewrite o e ()
replaceRoot oldRoot newRoot =
  R $ do
    current <- gets dag
    if IM.member newRoot (Dag.nodeMap current)
      then unwrapR $ putDag $ Dag.replaceRoot current oldRoot newRoot
      else error "replaceRootM: new root node is not present in the DAG"