{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DoAndIfThenElse #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Data.Comp.Dag
( Dag
, termTree
, reifyDag
, unravel
, bisim
, iso
, strongIso
) where
import Control.Exception.Base
import Control.Monad.State
import Data.Comp.Dag.Internal
import Data.Comp.Equality
import Data.Comp.Term
import Data.Foldable (Foldable)
import qualified Data.HashMap.Lazy as HashMap
import Data.IntMap
import qualified Data.IntMap as IntMap
import Data.IORef
import Data.Traversable (Traversable)
import qualified Data.Traversable as Traversable
import Data.Typeable
import System.Mem.StableName
import Control.Monad.ST
import Data.Comp.Show
import Data.List
import Data.STRef
import qualified Data.Vector as Vec
import qualified Data.Vector.Generic.Mutable as MVec
-- | Renders a DAG as the text @mkDag \<root\> [(node,edge),...]@, where the
-- root layer and each edge layer are shown as one-layer 'Term' contexts and
-- the edge list is printed in ascending node order.
instance (ShowF f, Functor f) => Show (Dag f) where
  show (Dag rootLayer edgeMap _) =
    unwords ["mkDag", show (Term rootLayer), renderList edgePairs]
    where
      -- One "(node,layer)" entry per edge, in IntMap key order.
      edgePairs =
        [ concat ["(", show node, ",", show (Term layer), ")"]
        | (node, layer) <- IntMap.toList edgeMap
        ]
      renderList items = "[" ++ intercalate "," items ++ "]"
-- | Wraps a plain term as a DAG with no shared nodes.
--
-- The top layer of the term becomes the root (its subterms converted to
-- contexts with 'toCxt'), the edge map is empty and the node count is 0.
termTree :: Functor f => Term f -> Dag f
termTree (Term topLayer) =
  Dag { root = fmap toCxt topLayer
      , edges = IntMap.empty
      , nodeCount = 0
      }
-- | Exception that 'reifyDag' raises to report a cyclic input term.
--
-- NOTE(review): this type is not in the module's export list, so callers
-- cannot catch it by name (only via 'SomeException') — confirm intent.
data CyclicException
  = CyclicException
  deriving (Typeable, Show)

instance Exception CyclicException
-- | Converts a term into a DAG by observing the sharing present in its heap
-- representation via stable names.
--
-- A layer that is reached more than once during the traversal (the same
-- heap object, as identified by 'makeStableName') becomes an explicit node
-- of the DAG and is referenced through 'Hole'. A layer reached exactly once
-- stays inline in the context of its parent. Nodes are numbered from 0 in
-- the order they are first reached in the second pass.
--
-- Throws 'CyclicException' if the term is cyclic, i.e. a layer is reached
-- again while its own children are still being traversed.
reifyDag :: Traversable f => Term f -> IO (Dag f)
reifyDag m = do
  -- Pass 1: map the stable name of every layer to
  -- (reached exactly once so far?, its children as stable names).
  tabRef <- newIORef HashMap.empty
  -- Stable names of layers whose children are currently being traversed.
  -- Meeting one of them again means the term refers back to itself; without
  -- this set such a term would make 'findNodes' recurse forever, because the
  -- layer only enters 'tabRef' after its children are done.
  pendingRef <- newIORef HashMap.empty
  let findNodes (Term !j) = do
        st <- makeStableName j
        tab <- readIORef tabRef
        case HashMap.lookup st tab of
          Just (single, f)
            | single -> do
                -- Second visit: the layer is shared.
                writeIORef tabRef (HashMap.insert st (False, f) tab)
                return st
            | otherwise -> return st
          Nothing -> do
            pending <- readIORef pendingRef
            if HashMap.member st pending
              then throwIO CyclicException
              else do
                writeIORef pendingRef (HashMap.insert st () pending)
                res <- Traversable.mapM findNodes j
                modifyIORef' pendingRef (HashMap.delete st)
                modifyIORef' tabRef (HashMap.insert st (True, res))
                return st
  st <- findNodes m
  tab <- readIORef tabRef
  -- Pass 2: rebuild the term, inlining layers seen once and allocating a
  -- numbered node (with an entry in the edge map) for every shared layer.
  counterRef <- newIORef 0
  edgesRef <- newIORef IntMap.empty
  nodesRef <- newIORef HashMap.empty
  let run st = do
        let (single, f) = tab HashMap.! st
        if single
          then Term <$> Traversable.mapM run f
          else do
            nodes <- readIORef nodesRef
            case HashMap.lookup st nodes of
              Just n -> return (Hole n)
              Nothing -> do
                n <- readIORef counterRef
                writeIORef counterRef $! (n + 1)
                -- Register the node before descending so later references
                -- to the same layer reuse it.
                writeIORef nodesRef (HashMap.insert st n nodes)
                f' <- Traversable.mapM run f
                modifyIORef' edgesRef (IntMap.insert n f')
                return (Hole n)
  -- The root is reached only once (a second visit would be a cycle), so it
  -- is always rebuilt inline as a 'Term'.
  Term root <- run st
  edges <- readIORef edgesRef
  count <- readIORef counterRef
  return (Dag root edges count)
-- | Expands a DAG back into a tree-shaped term.
--
-- Every node reference ('Hole') is replaced by the layer stored for that
-- node in the edge map, recursively. A shared node is therefore copied at
-- each place it is referenced.
unravel :: forall f. Functor f => Dag f -> Term f
unravel Dag {edges, root} = Term (fmap expand root)
  where
    expand :: Context f Node -> Term f
    expand (Hole n) = expand (Term (edges IntMap.! n))
    expand (Term layer) = Term (fmap expand layer)
-- | Decides whether two DAGs are bisimilar, i.e. describe the same term once
-- node references are followed.
--
-- Corresponding layers are compared with 'eqMod'. On a match, every pair of
-- corresponding children is compared in turn; a child that is a node
-- reference is first resolved through its DAG's edge map.
bisim :: forall f . (EqF f, Functor f, Foldable f) => Dag f -> Dag f -> Bool
bisim Dag {root=r1,edges=e1} Dag {root=r2,edges=e2} = sameLayer r1 r2
  where
    sameLayer :: f (Context f Node) -> f (Context f Node) -> Bool
    sameLayer l1 l2 = maybe False (all sameChild) (eqMod l1 l2)

    sameChild :: (Context f Node, Context f Node) -> Bool
    sameChild (c1, c2) = sameLayer (layerOf e1 c1) (layerOf e2 c2)

    -- The layer a child stands for: looked up if it is a node reference,
    -- taken directly if it is inline.
    layerOf :: Edges f -> Context f Node -> f (Context f Node)
    layerOf edgeMap (Hole n) = edgeMap IntMap.! n
    layerOf _ (Term layer) = layer
-- | Decides whether two DAGs are isomorphic.
--
-- Both DAGs are first passed through 'flatten', which gives every layer a
-- node of its own. 'checkIso' then looks for a bijection between the nodes
-- reachable from the two roots under which corresponding layers agree
-- according to 'eqMod'.
iso :: (Traversable f, Foldable f, EqF f) => Dag f -> Dag f -> Bool
iso dag1 dag2 = checkIso eqMod flat1 flat2
  where
    flat1 = flatten dag1
    flat2 = flatten dag2
-- | Decides whether two DAGs are isomorphic without flattening them first.
--
-- The root and edge layers are compared as whole contexts with 'eqMod'
-- (wrapped in 'Term'). Only the explicit nodes ('Hole's) inside them are
-- related by the bijection that 'checkIso' builds.
strongIso :: (Functor f, Foldable f, EqF f) => Dag f -> Dag f -> Bool
strongIso dag1 dag2 = checkIso sameShape (asTriple dag1) (asTriple dag2)
  where
    asTriple Dag {root, edges, nodeCount} = (root, edges, nodeCount)
    sameShape l1 l2 = eqMod (Term l1) (Term l2)
-- | Renumbers a DAG so that every layer becomes a node of its own.
--
-- Each inline 'Term' layer (in the root and in every edge) is given a fresh
-- node. Each original node is mapped to a new node the first time it is
-- met.
--
-- Returns a triple of:
--
-- * the flattened root layer;
-- * the new edge map, with each layer's children being node numbers;
-- * the total number of nodes allocated.
--
-- Original node numbers are used as unchecked indices into a vector of size
-- 'nodeCount', so they are assumed to lie in @[0, nodeCount)@.
flatten :: forall f . Traversable f => Dag f -> (f Node, IntMap (f Node), Int)
flatten Dag {root, edges, nodeCount} = runST run
  where
    run :: forall s . ST s (f Node, IntMap (f Node), Int)
    run = do
      counter <- newSTRef 0
      -- Original node -> new node, filled in the first time a node is met.
      nMap :: Vec.MVector s (Maybe Node) <- MVec.replicate nodeCount Nothing
      newEdges <- newSTRef IntMap.empty
      let -- Allocate the next unused node number.
          fresh :: ST s Node
          fresh = do
            n' <- readSTRef counter
            writeSTRef counter $! (n' + 1)
            return n'

          build :: Context f Node -> ST s Node
          build (Hole n) = mkNode n
          build (Term t) = do
            n' <- fresh
            t' <- Traversable.mapM build t
            -- Strict modify: avoid a chain of pending IntMap.insert thunks.
            modifySTRef' newEdges (IntMap.insert n' t')
            return n'

          mkNode :: Node -> ST s Node
          mkNode n = do
            known <- MVec.unsafeRead nMap n
            case known of
              Just n' -> return n'
              Nothing -> do
                n' <- fresh
                MVec.unsafeWrite nMap n (Just n')
                return n'

          buildF (n, t) = do
            n' <- mkNode n
            t' <- Traversable.mapM build t
            modifySTRef' newEdges (IntMap.insert n' t')
      root' <- Traversable.mapM build root
      mapM_ buildF (IntMap.toList edges)
      edges' <- readSTRef newEdges
      nodeCount' <- readSTRef counter
      return (root', edges', nodeCount')
-- | Checks that two graphs are isomorphic. Each graph is given as a triple of
-- (root layer, edge map, node count).
--
-- Starting from the two roots, @checkEq@ compares a pair of layers. On
-- success it returns the pairs of child nodes that must correspond. The
-- node correspondence is built incrementally and must remain a bijection:
-- a node of the first graph maps to at most one node of the second, and a
-- node of the second graph is the image of at most one node of the first.
-- Checking stops at the first mismatch.
--
-- Node numbers are used as unchecked vector indices, so they are assumed to
-- lie in @[0, nx1)@ and @[0, nx2)@ respectively.
checkIso :: (e -> e -> Maybe [(Node,Node)])
         -> (e, IntMap e, Int)
         -> (e, IntMap e, Int) -> Bool
checkIso checkEq (r1,e1,nx1) (r2,e2,nx2) = runST run
  where
    run :: forall s . ST s Bool
    run = do
      -- Node of graph 1 -> node of graph 2 it has been matched with.
      nMap :: Vec.MVector s (Maybe Node) <- MVec.replicate nx1 Nothing
      -- Nodes of graph 2 that are already the image of some node.
      nSet :: Vec.MVector s Bool <- MVec.replicate nx2 False
      let checkT t1 t2 = case checkEq t1 t2 of
            Nothing -> return False
            Just l -> checkAll l
          -- Like 'and <$> mapM checkN', but stops at the first failure
          -- instead of checking (and recording) the remaining pairs.
          checkAll [] = return True
          checkAll (p : ps) = do
            ok <- checkN p
            if ok then checkAll ps else return False
          checkN (n1, n2) = do
            nm' <- MVec.unsafeRead nMap n1
            case nm' of
              -- Already matched: it must be matched with the same node.
              Just n' -> return (n2 == n')
              Nothing -> do
                taken <- MVec.unsafeRead nSet n2
                if taken
                  then return False
                  else do
                    MVec.unsafeWrite nMap n1 (Just n2)
                    MVec.unsafeWrite nSet n2 True
                    checkT (e1 IntMap.! n1) (e2 IntMap.! n2)
      checkT r1 r2