module Language.Syntactic.Sharing.Graph where
import Control.Arrow ((***))
import Control.Monad.Reader
import Data.Array
import Data.Function
import Data.List
import Data.Typeable
import Data.Hash
import Data.Proxy
import Language.Syntactic
import Language.Syntactic.Features.Binding
import Language.Syntactic.Sharing.Utils
-- | Identifier of a node in an abstract syntax graph ('ASG').
--
-- The arithmetic and 'Ix' instances are derived through the wrapped
-- 'Integer', so ids can be used directly as array indices and in bound
-- computations such as @numNodes - 1@.
newtype NodeId = NodeId
    { nodeInteger :: Integer
    }
  deriving (Eq, Ord, Num, Real, Integral, Enum, Ix)
-- | A reference to a shared sub-expression, identified by its 'NodeId'.
--
-- A reference always has a 'Full' result type; the @'Sat' ctx a@
-- constraint is captured when the reference is constructed.
data Node ctx a where
    Node :: Sat ctx a => NodeId -> Node ctx (Full a)
-- | A node id is shown as its bare integer value.
instance Show NodeId where
    show = show . nodeInteger
-- | Render a node id with a @node:@ prefix.
showNode :: NodeId -> String
showNode = ("node:" ++) . show
-- | Matching on 'Node' refines the type index to @'Full' a@, which lets the
-- witness be produced without inspecting the node id.
instance WitnessCons (Node ctx) where
    witnessCons Node{} = ConsWit
-- | A node reference renders as its id with a @node:@ prefix ('showNode').
instance Render (Node ctx) where
    render (Node nodeId) = showNode nodeId
-- | No methods are overridden; the class defaults are used (presumably
-- built on the 'Render' instance — confirm in "Language.Syntactic").
instance ToTree (Node ctx) where
-- | Try to view a symbol of a larger domain as a 'Node'. The 'Proxy'
-- argument carries no data; it only fixes the @ctx@ type parameter.
prjNode :: (Node ctx :<: sup) => Proxy ctx -> sup a -> Maybe (Node ctx a)
prjNode _ expr = project expr
-- | A fully applied expression whose 'Typeable' result type is hidden, so
-- that expressions of different types can be stored in one list.
data SomeAST dom where
    SomeAST :: Typeable a => ASTF dom a -> SomeAST dom
-- | Abstract syntax graph: a syntax tree whose shared sub-expressions are
-- stored once in a node table and referenced through 'Node' symbols.
data ASG ctx dom a = ASG
    { topExpression :: ASTF (Node ctx :+: dom) a
      -- ^ Top-level expression; may contain node references
    , graphNodes    :: [(NodeId, SomeAST (Node ctx :+: dom))]
      -- ^ Body of each node, keyed by its id
    , numNodes      :: NodeId
      -- ^ Node count; functions in this module index nodes with the
      -- bounds @(0, numNodes - 1)@
    }
-- | Render an 'ASG' as text.
--
-- The output has a @top@ section holding the top-level expression,
-- followed by one @node:\<id\>@ section per entry of 'graphNodes'. Each
-- section starts with a dashed header line.
showASG :: ToTree dom => ASG ctx dom a -> String
showASG (ASG top nodes _) =
    unlines ((header "top" ++ showAST top) : map section nodes)
  where
    -- Header such as "---- top -----...": the trailing dashes pad the label
    -- so that label and padding together span 40 characters. Labels longer
    -- than that get no padding ('replicate' with a negative count is []).
    header str =
        "---- " ++ str ++ " " ++ replicate (40 - length str) '-' ++ "\n"

    -- Named 'section' rather than 'showNode' to avoid shadowing the
    -- top-level 'showNode'.
    section (n, SomeAST expr) = header ("node:" ++ show n) ++ showAST expr
-- | Print the textual rendering of a graph ('showASG') to standard output.
drawASG :: ToTree dom => ASG ctx dom a -> IO ()
drawASG graph = putStrLn (showASG graph)
-- | Apply a renaming function to every node reference in an expression.
-- Domain symbols are returned unchanged.
reindexNodesAST ::
    (NodeId -> NodeId) -> AST (Node ctx :+: dom) a -> AST (Node ctx :+: dom) a
reindexNodesAST rename expr = case expr of
    Symbol (InjectL (Node nodeId)) -> Symbol (InjectL (Node (rename nodeId)))
    fun :$: arg -> reindexNodesAST rename fun :$: reindexNodesAST rename arg
    _ -> expr
-- | Rename every node in a graph. The renaming is applied both to node
-- references inside the expressions and to the ids of the node table. The
-- node count is left untouched.
reindexNodes :: (NodeId -> NodeId) -> ASG ctx dom a -> ASG ctx dom a
reindexNodes reix (ASG top nodes total) =
    ASG (reindexNodesAST reix top) (map renameEntry nodes) total
  where
    renameEntry (nodeId, SomeAST body) =
        (reix nodeId, SomeAST (reindexNodesAST reix body))
-- | Renumber the nodes with the mapping that 'reindex' builds from the ids
-- currently in the node table, in table order. As the name suggests, this
-- presumably yields ids counting up from 0 — see
-- "Language.Syntactic.Sharing.Utils" to confirm.
reindexNodesFrom0 :: ASG ctx dom a -> ASG ctx dom a
reindexNodesFrom0 graph = reindexNodes (reindex ids) graph
  where
    ids = [nodeId | (nodeId, _) <- graphNodes graph]
-- | Drop duplicate entries from the node table, keeping the first entry
-- for each id. The node count is set to the number of entries that remain.
-- Uses 'nubBy', so this is quadratic in the size of the table.
nubNodes :: ASG ctx dom a -> ASG ctx dom a
nubNodes (ASG top nodes _) = ASG top unique (genericLength unique)
  where
    unique = nubBy sameId nodes
    sameId (i, _) (j, _) = i == j
-- | Apply a function that is polymorphic in both expressions' result types
-- to two existentially wrapped expressions.
liftSome2
    :: (forall a b . ASTF (Node ctx :+: dom) a -> ASTF (Node ctx :+: dom) b -> c)
    -> SomeAST (Node ctx :+: dom)
    -> SomeAST (Node ctx :+: dom)
    -> c
liftSome2 f x y = case x of
    SomeAST a -> case y of
        SomeAST b -> f a b
-- | One layer of graph syntax, with the recursive positions abstracted as
-- @a@. An algebra over this type is what 'foldGraph' consumes.
data SyntaxPF dom a where
    -- | Application of a function to an argument (both recursive positions).
    AppPF  :: a -> a -> SyntaxPF dom a
    -- | Reference to a node, paired with the result computed for that node.
    NodePF :: NodeId -> a -> SyntaxPF dom a
    -- | A domain symbol; its own type index is hidden.
    DomPF  :: dom b -> SyntaxPF dom a
-- | Maps over the recursive positions only; node ids and domain symbols are
-- kept as they are.
instance Functor (SyntaxPF dom) where
    fmap f pf = case pf of
        AppPF g a  -> AppPF (f g) (f a)
        NodePF n a -> NodePF n (f a)
        DomPF sym  -> DomPF sym
-- | Fold an 'ASG' with an algebra over 'SyntaxPF'.
--
-- Returns three things:
--
-- * the result for the top-level expression;
-- * the per-node results as an array indexed by 'NodeId';
-- * the per-node results as an association list, in the order of
--   'graphNodes'.
--
-- A node reference is not re-traversed. The algebra receives the node id
-- together with the result already computed for that node, which is looked
-- up in the lazily built result array.
foldGraph :: forall ctx dom a b
    .  (SyntaxPF dom b -> b)
    -> ASG ctx dom a
    -> (b, (Array NodeId b, [(NodeId,b)]))
foldGraph alg (ASG top ns nn) = (g top, (arr,nodes))
  where
    nodes = [(n, g expr) | (n, SomeAST expr) <- ns]

    -- Node ids are expected to lie in [0, nn - 1].
    arr = array (0, nn - 1) nodes

    g :: ConsType c => AST (Node ctx :+: dom) c -> b
    g (h :$: a)                   = alg $ AppPF (g h) (g a)
    g (Symbol (InjectL (Node n))) = alg $ NodePF n (arr!n)
    g (Symbol (InjectR a))        = alg $ DomPF a
-- | Replace every node reference with (a recursively inlined copy of) the
-- node's body, producing a plain expression without sharing.
--
-- Raises @error "inlineAll: type mismatch"@ when a node's body cannot be
-- cast ('gcast') to the type expected at the reference.
inlineAll :: forall ctx dom a . Typeable a => ASG ctx dom a -> ASTF dom a
inlineAll (ASG top nodes total) = inline top
  where
    -- Node ids are expected to lie in [0, total - 1].
    nodeMap = array (0, total - 1) nodes

    inline :: forall b. (Typeable b, ConsType b) =>
        AST (Node ctx :+: dom) b -> AST dom b
    inline (f :$: a) = inline f :$: inline a
    inline (Symbol (InjectL (Node n))) = case nodeMap ! n of
        SomeAST body -> case gcast body of
            Nothing    -> error "inlineAll: type mismatch"
            Just body' -> inline body'
    inline (Symbol (InjectR a)) = Symbol a
-- | For every entry of the node table, the ids of the nodes its body
-- references directly. The bodies of referenced nodes are not followed,
-- and references made by the top-level expression are not reported.
nodeChildren :: ASG ctx dom a -> [(NodeId, [NodeId])]
nodeChildren graph =
    [ (nodeId, fromDList refs) | (nodeId, refs) <- perNode ]
  where
    (_, (_, perNode)) = foldGraph children graph

    children :: SyntaxPF dom (DList NodeId) -> DList (NodeId)
    children pf = case pf of
        AppPF ns1 ns2 -> ns1 . ns2
        NodePF n _    -> single n
        _             -> empty
-- | Tally, with 'count' over the bounds @(0, numNodes - 1)@, the node ids
-- referenced directly from the bodies of the graph's nodes ('nodeChildren').
-- References from the top-level expression are not included in the tally.
occurrences :: ASG ctx dom a -> Array NodeId Int
occurrences graph
    = count (0, numNodes graph - 1)
    $ concatMap snd
    $ nodeChildren graph
-- | Inline every node that is referenced at most once from the bodies of
-- other nodes, as tallied by 'occurrences'. References from the top-level
-- expression do not contribute to that tally.
--
-- Nodes referenced more than once remain in the node table, with their own
-- bodies processed the same way. Raises
-- @error "inlineSingle: type mismatch"@ when an inlined body cannot be cast
-- ('gcast') to the type expected at the reference.
inlineSingle :: forall ctx dom a . Typeable a => ASG ctx dom a -> ASG ctx dom a
inlineSingle graph@(ASG top nodes total) = ASG top' nodes' total'
  where
    -- Node ids are expected to lie in [0, total - 1].
    nodeTab = array (0, total - 1) nodes
    occs    = occurrences graph

    top'   = inline top
    nodes' = [ (nodeId, SomeAST (inline body))
             | (nodeId, SomeAST body) <- nodes
             , occs ! nodeId > 1
             ]

    -- NOTE(review): the kept nodes keep their original ids, while the count
    -- shrinks to the number of kept nodes. Functions in this module that
    -- index nodes with bounds (0, numNodes - 1) may then fail on the result
    -- unless it is renumbered first (e.g. 'reindexNodesFrom0'); confirm this
    -- is intended.
    total' = genericLength nodes'

    inline :: forall b. (Typeable b, ConsType b) =>
        AST (Node ctx :+: dom) b -> AST (Node ctx :+: dom) b
    inline (f :$: a) = inline f :$: inline a
    inline (Symbol (InjectL (Node n)))
        | occs ! n > 1 = Symbol (InjectL (Node n))
        | otherwise    = case nodeTab ! n of
            SomeAST body -> case gcast body of
                Nothing    -> error "inlineSingle: type mismatch"
                Just body' -> inline body'
    inline (Symbol (InjectR a)) = Symbol (InjectR a)
-- | Compute a structural hash for every node, both as an array indexed by
-- 'NodeId' and as an association list in the order of 'graphNodes'.
--
-- Applications and domain symbols are tagged with different seeds. A node
-- reference contributes the hash of the referenced node, not its id.
hashNodes :: ExprEq dom =>
    ASG ctx dom a -> (Array NodeId Hash, [(NodeId, Hash)])
hashNodes graph = snd (foldGraph step graph)
  where
    step pf = case pf of
        AppPF h1 h2 -> hashInt 0 `combine` h1 `combine` h2
        NodePF _ h  -> h
        DomPF sym   -> hashInt 1 `combine` exprHash sym
-- | Partition the graph's nodes into classes of equivalent nodes.
--
-- Nodes are first grouped by equal hash ('hashNodes'). Each group is then
-- split with 'fullPartition', using alpha-equivalence ('alphaEqM') as the
-- exact test. During the exact test, two node references are equal when
-- they refer to the same id, or when their hashes agree and their bodies
-- are alpha-equivalent.
partitionNodes :: forall ctx dom a
    .  (Lambda ctx :<: dom, Variable ctx :<: dom, ExprEq dom)
    => ASG ctx dom a -> [[NodeId]]
partitionNodes graph = concatMap (fullPartition nodeEq) approxPartitioning
  where
    -- Node ids are expected to lie in [0, numNodes graph - 1].
    nTab          = array (0, numNodes graph - 1) (graphNodes graph)
    (hTab,hashes) = hashNodes graph

    -- Cheap first pass: nodes with different hashes can never be equal.
    approxPartitioning
        = map (map fst)
        $ groupBy ((==) `on` snd)
        $ sortBy (compare `on` snd)
        $ hashes

    eqNode :: forall a b . ExprEq dom
        => AST (Node ctx :+: dom) a
        -> AST (Node ctx :+: dom) b
        -> Reader [(VarId,VarId)] Bool
    eqNode (Symbol (InjectL (Node n1))) (Symbol (InjectL (Node n2)))
        | n1 == n2           = return True
        | hTab!n1 /= hTab!n2 = return False
        | otherwise          = case (nTab!n1, nTab!n2) of
            (SomeAST a, SomeAST b) -> eqNodeAlpha a b
    eqNode (Symbol (InjectR a)) (Symbol (InjectR b)) = return (exprEq a b)
    eqNode _ _ = return False

    eqNodeAlpha :: forall a b
        .  AST (Node ctx :+: dom) a
        -> AST (Node ctx :+: dom) b
        -> Reader [(VarId,VarId)] Bool
    eqNodeAlpha a b = alphaEqM (Proxy::Proxy ctx) eqNode a b

    -- Exact comparison of two nodes' bodies, starting with an empty
    -- variable correspondence.
    nodeEq :: NodeId -> NodeId -> Bool
    nodeEq n1 n2 = runReader (liftSome2 eqNodeAlpha (nTab!n1) (nTab!n2)) []
-- | Common sub-expression elimination: merge the nodes that
-- 'partitionNodes' puts in the same class.
--
-- Every node, and every reference to it, is renamed to the index of its
-- class. The duplicate table entries this creates are then removed with
-- 'nubNodes', which keeps the first entry of each class.
cse :: (Lambda ctx :<: dom, Variable ctx :<: dom, ExprEq dom) =>
    ASG ctx dom a -> ASG ctx dom a
cse graph = nubNodes $ reindexNodes (reixTab !) graph
  where
    parts   = partitionNodes graph
    -- Maps each node id (expected in [0, numNodes - 1]) to its class index.
    reixTab = array (0, numNodes graph - 1)
        [ (nodeId, classId)
        | (part, classId) <- parts `zip` [0 ..]
        , nodeId <- part
        ]