module Data.Comp.Dag.PAG
( runPAG
, module I
) where
import Control.Monad.ST
import Data.Comp.Dag
import Data.Comp.Dag.Internal
import Data.Comp.Mapping as I
import Data.Comp.Multi.Projection as I
import Data.Comp.PAG.Internal
import qualified Data.Comp.PAG.Internal as I hiding (explicit)
import Data.Comp.Term
import qualified Data.IntMap as IntMap
import Data.IntMap (IntMap)
import Data.Vector (MVector)
import Data.Maybe
import Data.STRef
import qualified Data.Traversable as Traversable
import qualified Data.Vector as Vec
import qualified Data.Vector.Generic.Mutable as MVec
import Control.Monad.State
-- | Runs a parametric attribute grammar over a 'Dag' and returns the
-- synthesised attribute of the root; every node it contains is extracted
-- into a standalone 'Dag' in the target signature @g@.
--
-- Every edge of the input DAG is evaluated once. A node that is referenced
-- by several parents receives the inherited values of all of them, combined
-- with the resolution function (stored value first, new value second). The
-- inherited attribute of the root is computed by the initialisation function
-- from the root's final synthesised attribute.
runPAG :: forall f d u g . (Traversable f, Traversable d, Traversable g, Traversable u)
       => (forall a . d a -> d a -> d a)
          -- ^ resolution: combines the inherited value already stored for a
          -- node with a further value passed down by another parent
       -> Syn' f (u :*: d) u g
          -- ^ semantic function of the synthesised attribute
       -> Inh' f (u :*: d) d g
          -- ^ semantic function of the inherited attribute
       -> (forall a . u a -> d (Context g a))
          -- ^ initial inherited attribute of the root, computed from the
          -- root's final synthesised attribute
       -> Dag f
          -- ^ input DAG
       -> u (Dag g)
runPAG res syn inh dinit Dag {edges, root, nodeCount} = result
  where
    (uFin, result) = runST runM

    runM :: forall s . ST s (u Node, u (Dag g))
    runM = mdo
      -- inherited attribute of each node; 'Nothing' until a parent reaches it
      dmap <- MVec.new nodeCount
      MVec.set dmap Nothing
      -- synthesised attribute of each node, written by 'iter'
      umap <- MVec.new nodeCount
      -- allocator state for the nodes and edges of the output graph
      nextNode <- newSTRef 0
      newEdges <- newSTRef (IntMap.empty :: IntMap (g (Context g Node)))
      let
        -- Evaluates one edge of the input DAG with the node's final,
        -- resolved inherited attribute and records its synthesised one.
        iter (node, edge) = do
          let missing = error ("Data.Comp.Dag.PAG.runPAG: node " ++ show node
                               ++ " is not referenced by the root or by any edge")
              d = fromMaybe missing (dmapFin Vec.! node)
          u <- run d edge
          MVec.unsafeWrite umap node u

        -- Applies the semantic functions to one layer of the input and
        -- recurses into its children.
        run :: d Node -> f (Context f Node) -> ST s (u Node)
        run d t = mdo
          e <- readSTRef newEdges
          n <- readSTRef nextNode
          let mkFresh =
                liftM2 (,)
                  (Traversable.mapM freshNode
                     (explicit syn (u :*: d) unNumbered result))
                  (Traversable.mapM (Traversable.mapM freshNode)
                     (explicit inh (u :*: d) unNumbered result))
              ((u, m), Fresh n' e') = runState mkFresh (Fresh n e)
          writeSTRef newEdges e'
          writeSTRef nextNode n'
          -- Number the immediate children 0, 1, 2, ... locally to this call.
          -- A counter shared between calls would be reset by the nested
          -- 'run' of an inlined child term, so later siblings could receive
          -- colliding numbers and thereby each other's inherited attributes.
          let numbered = snd (Traversable.mapAccumL (\i c -> (i + 1, (i, c))) 0 t)
              run' :: (Int, Context f Node) -> ST s (Numbered ((u :*: d) Node))
              run' (i, child) = do
                -- a child without an explicit entry inherits the parent's value
                let d' = fromMaybe d (lookupNumMap' i m)
                u' <- runF d' child
                return (Numbered i (u' :*: d'))
          result <- Traversable.mapM run' numbered
          return u

        -- An inlined subterm is evaluated in place. For a reference to a
        -- node, the inherited value is merged into 'dmap' and the node's
        -- final synthesised value is taken (lazily) from 'umapFin'.
        runF d (Term t) = run d t
        runF d (Hole x) = do
          old <- MVec.unsafeRead dmap x
          let new = maybe d (\o -> res o d) old
          MVec.unsafeWrite dmap x (Just new)
          return (umapFin Vec.! x)

      -- Initialise the root's inherited attribute from its final
      -- synthesised attribute, allocating any output nodes this creates.
      e <- readSTRef newEdges
      n <- readSTRef nextNode
      let (dFin, Fresh n' e') = runState (Traversable.mapM freshNode (dinit uFin)) (Fresh n e)
      writeSTRef newEdges e'
      writeSTRef nextNode n'
      u <- run dFin root
      mapM_ iter (IntMap.toList edges)
      -- Frozen only after every edge has been processed; 'iter' and 'runF'
      -- refer to these final tables through the recursive 'mdo' bindings.
      dmapFin <- Vec.unsafeFreeze dmap
      umapFin <- Vec.unsafeFreeze umap
      newEdgesFin <- readSTRef newEdges
      newEdgesCount <- readSTRef nextNode
      -- extract each node of the result into its own compact DAG
      let relabel node = relabelNodes node newEdgesFin newEdgesCount
      return (u, fmap relabel u)
-- | State threaded through 'freshNode': the identifier to hand out next and
-- the edges allocated so far.
data Fresh f = Fresh
  { nextFreshNode :: Int
    -- ^ identifier given to the next allocated node
  , freshEdges :: IntMap (f (Context f Node))
    -- ^ allocated edges, keyed by node identifier
  }
-- | Converts a context of the target signature into a node. A hole already
-- is a node and is returned unchanged. For a term, its top layer (any nested
-- terms stay inside it) is stored as a new edge under the next unused
-- identifier, and that identifier is returned.
freshNode :: Context g Node -> State (Fresh g) Node
freshNode ctx = case ctx of
  Hole node  -> return node
  Term layer -> state (allocate layer)
  where
    allocate layer st =
      let node = nextFreshNode st
          st'  = st { nextFreshNode = node + 1
                    , freshEdges    = IntMap.insert node layer (freshEdges st)
                    }
      in (node, st')
-- | Extracts the part of a graph reachable from a given node as a compact
-- 'Dag'. The edge stored for the given node becomes the 'root' of the
-- result. Nodes are renumbered consecutively from 0, in the order in which
-- a depth-first traversal first reaches them; nodes that are never reached
-- are dropped.
relabelNodes :: forall f . Traversable f
             => Node
                -- ^ node whose edge becomes the root of the result
             -> IntMap (f (Context f Node))
                -- ^ edges of the source graph
             -> Int
                -- ^ size of the lookup table: every node occurring in the
                -- edges must be smaller (no bounds check is performed)
             -> Dag f
relabelNodes root edges nodeCount = runST run
  where
    run :: forall s . ST s (Dag f)
    run = do
      curNode <- newSTRef 0
      newEdges <- newSTRef IntMap.empty
      -- old node -> new node, for the nodes visited so far
      newNodes :: MVector s (Maybe Int) <- MVec.new nodeCount
      MVec.set newNodes Nothing
      let
        -- Returns the new identifier of an old node, copying its edge (with
        -- relabelled children) the first time the node is visited.
        build :: Node -> ST s Node
        build node = do
          mnewNode <- MVec.unsafeRead newNodes node
          case mnewNode of
            Just newNode -> return newNode
            Nothing -> do
              newNode <- readSTRef curNode
              writeSTRef curNode $! (newNode + 1)
              -- record the mapping before recursing so later references reuse it
              MVec.unsafeWrite newNodes node (Just newNode)
              f' <- Traversable.mapM (Traversable.mapM build) (edges IntMap.! node)
              -- strict update, so no chain of unevaluated inserts builds up
              modifySTRef' newEdges (IntMap.insert newNode f')
              return newNode
      root' <- Traversable.mapM (Traversable.mapM build) (edges IntMap.! root)
      edges' <- readSTRef newEdges
      nodeCount' <- readSTRef curNode
      return Dag {edges = edges', root = root', nodeCount = nodeCount'}