{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecursiveDo #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Data.Comp.Dag.AG
( runAG
, runRewrite
, module I
) where
import Control.Monad.ST
import Control.Monad.State
import Data.Comp.AG.Internal
import qualified Data.Comp.AG.Internal as I hiding (explicit)
import Data.Comp.Dag
import Data.Comp.Dag.Internal
import Data.Comp.Mapping as I
import Data.Comp.Projection as I
import Data.Comp.Term
import qualified Data.IntMap as IntMap
import Data.Maybe
import Data.STRef
import qualified Data.Traversable as Traversable
import Data.Vector (Vector,MVector)
import qualified Data.Vector as Vec
import qualified Data.Vector.Generic.Mutable as MVec
-- | Runs an attribute grammar on a DAG and returns the synthesised
-- attribute value computed at the root.
--
-- The root context is processed first, then each entry of the DAG's
-- edge map exactly once. When a node is reached through several
-- 'Hole's, the inherited attribute values arriving there are merged
-- with the combination function, as @res old new@. The root's inherited
-- attribute is obtained by applying the initialisation function to the
-- final synthesised attribute of the root. This circularity, and the
-- lazy use of the frozen attribute vectors, are resolved by @mdo@.
runAG :: forall f d u .Traversable f
      => (d -> d -> d)  -- ^ combination function for inherited attributes
      -> Syn' f (u,d) u -- ^ semantic function of synthesised attributes
      -> Inh' f (u,d) d -- ^ semantic function of inherited attributes
      -> (u -> d)       -- ^ computes the root's inherited attribute from the root's final synthesised attribute
      -> Dag f          -- ^ input DAG
      -> u
runAG res syn inh dinit Dag {edges,root,nodeCount} = uFin where
    -- The final synthesised attribute of the root. It is fed back via
    -- 'dinit' as the root's inherited attribute.
    uFin = runST runM
    dFin = dinit uFin
    runM :: forall s . ST s u
    runM = mdo
      -- node -> combined inherited attribute value; Nothing until the
      -- node is first reached through a Hole
      dmap <- MVec.new nodeCount
      MVec.set dmap Nothing
      -- node -> synthesised attribute value (written by iter)
      umap <- MVec.new nodeCount
      -- counter for numbering child positions; reset only at the start
      -- of each edge (see iter), not per layer, so numbers are unique
      -- within an edge rather than starting at 0 for every layer
      count <- newSTRef 0
      let
        -- Runs the AG on one layer of a context, given the layer's
        -- inherited attribute value, and returns its synthesised value.
        run :: d -> f (Context f Node) -> ST s u
        run d t = mdo
          -- u and m are defined in terms of 'result', which is only
          -- produced further below; the recursive do ties this knot
          -- lazily
          let u = explicit syn (u,d) unNumbered result
              m = explicit inh (u,d) unNumbered result
              -- numbers a child, looks up its inherited attribute in m
              -- (d presumably acts as the default when m has no entry),
              -- and recurses into it
              run' :: Context f Node -> ST s (Numbered (u,d))
              run' s = do i <- readSTRef count
                          writeSTRef count $! (i+1)
                          let d' = lookupNumMap d i m
                          u' <- runF d' s
                          return (Numbered i (u',d'))
          result <- Traversable.mapM run' t
          return u
        -- Recurses through a context. At a Hole (a reference to another
        -- node) it merges the incoming inherited attribute into dmap and
        -- returns that node's synthesised attribute from the *final*
        -- frozen vector umapFin. This lookup must stay lazy, because
        -- umapFin is only created at the end of runM.
        runF :: d -> Context f Node -> ST s u
        runF d (Hole x) = do
          old <- MVec.unsafeRead dmap x
          let new = case old of
                  Just o -> res o d
                  _ -> d
          MVec.unsafeWrite dmap x (Just new)
          return (umapFin Vec.! x)
        runF d (Term t) = run d t
        -- Processes one edge (node n with its context t). The node's
        -- inherited attribute is taken from the final frozen dmapFin,
        -- because edges processed later may still contribute to it.
        -- NOTE(review): fromJust assumes every node in 'edges' is
        -- referenced by some Hole; this is presumably a Dag invariant.
        iter (n, t) = do
          writeSTRef count 0
          u <- run (fromJust $ dmapFin Vec.! n) t
          MVec.unsafeWrite umap n u
      -- first the root, then every edge
      u <- run dFin root
      mapM_ iter (IntMap.toList edges)
      -- freeze only after all writes; the lazy references above read
      -- these final vectors
      dmapFin <- Vec.unsafeFreeze dmap
      umapFin <- Vec.unsafeFreeze umap
      return u
-- | Runs an attribute grammar together with a parallel rewrite on a
-- DAG. Returns the synthesised attribute value at the root and the
-- rewritten DAG.
--
-- Attributes are evaluated as in 'runAG'. In addition, the rewrite
-- function is applied to every term layer. The rewritten edges are
-- collected per node and then compacted by 'relabelNodes', which drops
-- nodes that are unreachable from the rewritten root.
runRewrite :: forall f g d u .(Traversable f, Traversable g)
           => (d -> d -> d)      -- ^ combination function for inherited attributes
           -> Syn' f (u,d) u     -- ^ semantic function of synthesised attributes
           -> Inh' f (u,d) d     -- ^ semantic function of inherited attributes
           -> Rewrite f (u, d) g -- ^ semantic function of the parallel rewrite
           -> (u -> d)           -- ^ computes the root's inherited attribute from the root's final synthesised attribute
           -> Dag f              -- ^ input DAG
           -> (u, Dag g)
runRewrite res syn inh rewr dinit Dag {edges,root,nodeCount} = result where
    -- The root's final synthesised attribute is fed back via 'dinit'.
    result@(uFin,_) = runST runM
    dFin = dinit uFin
    runM :: forall s . ST s (u, Dag g)
    runM = mdo
      -- node -> combined inherited attribute value; Nothing until the
      -- node is first reached through a Hole
      dmap <- MVec.new nodeCount
      MVec.set dmap Nothing
      -- node -> synthesised attribute value (written by iter)
      umap <- MVec.new nodeCount
      -- counter for numbering child positions; reset per edge in iter
      count <- newSTRef 0
      -- node -> rewritten edge of the target DAG (written by iter)
      allEdges <- MVec.new nodeCount
      let
        -- Processes one edge: evaluates attributes and rewrites the
        -- node's context. The inherited attribute comes from the final
        -- frozen dmapFin, since later edges may still contribute to it.
        -- NOTE(review): fromJust assumes every node is referenced by
        -- some Hole; this is presumably a Dag invariant.
        iter (node,s) = do
          let d = fromJust $ dmapFin Vec.! node
          writeSTRef count 0
          (u,t) <- run d s
          MVec.unsafeWrite umap node u
          MVec.unsafeWrite allEdges node t
        -- Runs the AG on one layer, given its inherited attribute, and
        -- returns the synthesised attribute together with the rewritten
        -- layer.
        run :: d -> f (Context f Node) -> ST s (u, Context g Node)
        run d t = mdo
          -- u and m depend on 'result', produced below; the recursive
          -- do ties this knot lazily
          let u = explicit syn (u,d) (fst . unNumbered) result
              m = explicit inh (u,d) (fst . unNumbered) result
              -- numbers a child, looks up its inherited attribute in m
              -- (d presumably acts as the default), recurses into it and
              -- keeps the child's rewritten context
              run' :: Context f Node -> ST s (Numbered ((u,d), Context g Node))
              run' s = do i <- readSTRef count
                          writeSTRef count $! (i+1)
                          let d' = lookupNumMap d i m
                          (u',t) <- runF d' s
                          return (Numbered i ((u',d'), t))
          result <- Traversable.mapM run' t
          -- apply the rewrite, replace each child placeholder by that
          -- child's rewritten context, and flatten the result with join
          let t' = join $ fmap (snd . unNumbered) $ explicit rewr (u,d) (fst . unNumbered) result
          return (u, t')
        -- Recurses through a context. A Hole stays a Hole in the target;
        -- its synthesised attribute is read lazily from the final frozen
        -- umapFin, and the incoming inherited attribute is merged into
        -- dmap with res.
        runF d (Term t) = run d t
        runF d (Hole x) = do
          old <- MVec.unsafeRead dmap x
          let new = case old of
                  Just o -> res o d
                  _ -> d
          MVec.unsafeWrite dmap x (Just new)
          return (umapFin Vec.! x, Hole x)
      -- first the root, then every edge
      (u,interRoot) <- run dFin root
      mapM_ iter $ IntMap.toList edges
      -- freeze only after all writes; the lazy references above read
      -- these final vectors
      dmapFin <- Vec.unsafeFreeze dmap
      umapFin <- Vec.unsafeFreeze umap
      allEdgesFin <- Vec.unsafeFreeze allEdges
      return (u, relabelNodes interRoot allEdgesFin nodeCount)
-- | Relabels the nodes of a rewritten DAG so that the new nodes are
-- numbered consecutively from 0, keeping only the nodes reachable from
-- the root.
--
-- The input edges are given as a 'Vector' indexed by old node. An edge
-- that is just a 'Hole' (a node rewritten to another node) is
-- short-cut: the old node is mapped to the new node of its target, and
-- no new node is created for it.
relabelNodes :: forall f . Traversable f
             => Context f Node          -- ^ root context of the DAG
             -> Vector (Cxt Hole f Int) -- ^ edges, indexed by old node
             -> Int                     -- ^ number of old nodes (size of the old-to-new memo table)
             -> Dag f
relabelNodes root edges nodeCount = runST run where
    -- An explicit forall brings s into scope for the local signatures
    -- below, consistent with runAG and runRewrite.
    run :: forall s . ST s (Dag f)
    run = do
      -- next fresh node label in the new DAG
      curNode <- newSTRef 0
      -- edges of the new DAG
      newEdges <- newSTRef IntMap.empty
      -- memo table: old node -> new node, once the old node is visited
      newNodes :: MVector s (Maybe Int) <- MVec.new nodeCount
      MVec.set newNodes Nothing
      let
        -- Maps an old node to its new node, building the new nodes for
        -- everything reachable from it on first visit.
        build :: Node -> ST s Node
        build node = do
          mnewNode <- MVec.unsafeRead newNodes node
          case mnewNode of
            Just newNode -> return newNode
            Nothing ->
              case edges Vec.! node of
                -- an edge that only points to another node: short-cut it
                Hole n -> do
                  newNode <- build n
                  MVec.unsafeWrite newNodes node (Just newNode)
                  return newNode
                -- a real term layer: allocate a fresh node, then
                -- relabel the nodes referenced by its holes
                Term f -> do
                  newNode <- readSTRef curNode
                  writeSTRef curNode $! (newNode+1)
                  MVec.unsafeWrite newNodes node (Just newNode)
                  f' <- Traversable.mapM (Traversable.mapM build) f
                  -- strict update: avoids a chain of pending inserts
                  modifySTRef' newEdges (IntMap.insert newNode f')
                  return newNode
        -- Relabels the root. If the root is just a node, the root
        -- becomes that node's (new) edge.
        build' :: Context f Node -> ST s (f (Context f Node))
        build' (Hole n) = do
          n' <- build n
          e <- readSTRef newEdges
          return (e IntMap.! n')
        build' (Term f) = Traversable.mapM (Traversable.mapM build) f
      root' <- build' root
      edges' <- readSTRef newEdges
      nodeCount' <- readSTRef curNode
      return Dag {edges = edges', root = root', nodeCount = nodeCount'}