module Language.Syntactic.Sharing.CodeMotion2
( codeMotion2
, reifySmart2
) where
import Control.Monad.State
import Control.Monad.Reader
import qualified Data.Set as Set
import qualified Data.Map as Map
import Data.Array
import Data.List
import Data.Maybe (fromJust,fromMaybe)
import Data.Function
import Data.Hash
import Data.Typeable
import Language.Syntactic
import Language.Syntactic.Constructs.Binding
import Language.Syntactic.Constructs.Binding.HigherOrder
import Language.Syntactic.Sharing.SimpleCodeMotion
-- | Is the expression nothing but a variable symbol of the original domain,
-- as recognised by 'prjVariable'?  'Node' references and applications are
-- never considered variables.
isVariable :: PrjDict dom -> ASTF (NodeDomain dom) a -> Bool
isVariable pd expr = case expr of
    Sym (C' (InjR sym)) -> maybe False (const True) (prjVariable pd sym)
    _                   -> False
-- | Identifier of a node in the gathered expression graph.  In this module
-- node @0@ stands for the root expression and recorded sub-expressions are
-- numbered from @1@.
newtype NodeId = NodeId { nodeInteger :: Integer }
  deriving (Eq, Ord, Enum, Ix, Num, Real, Integral)

-- | A 'NodeId' is shown as its bare integer.
instance Show NodeId
  where
    show = show . nodeInteger

-- | Render a node id with a @node:@ prefix.
showNode :: NodeId -> String
showNode = ("node:" ++) . show
-- | A reference to a recorded sub-expression, identified by its 'NodeId'.
-- A node always stands for a fully applied expression ('Full').
data Node a
  where
    Node :: NodeId -> Node (Full a)

-- | The domain of gathered expressions: either a 'Node' reference or a symbol
-- of the original domain @dom@, with @dom@'s 'Sat' constraint attached.
type NodeDomain dom = (Node :+: dom) :|| Sat dom

-- | Two nodes are alpha-equivalent exactly when their ids coincide; the
-- variable-renaming environment is not consulted.
instance AlphaEq dom dom dom env => AlphaEq Node Node dom env
  where
    alphaEqSym (Node lhs) _ (Node rhs) _ = return $ lhs == rhs

-- | Nodes put no constraint on their result type.
instance Constrained Node
  where
    type Sat Node = Top
    exprDict = const Dict

-- | Nodes hash by their id.  'equal' is deliberately unsupported; the
-- comparisons in this module go through 'alphaEq' instead.
instance Equality Node
  where
    equal (Node _) (Node _) = error "can't compare nodes for equality"
    exprHash (Node nid) = hash $ nodeInteger nid
-- | One recorded sub-expression of the gather pass.
data Gathered dom = Gathered
    { geExpr :: ASTSAT (NodeDomain dom)
      -- ^ The recorded expression.  As produced by 'gather', its arguments
      --   (and lambda bodies) are 'Node' references.
    , geNodeId :: NodeId
      -- ^ Id assigned when the expression was first recorded.
    , geInfo :: [(NodeId, GatherInfo)]
      -- ^ Usage information, one entry per distinct inner-limit node under
      --   which the expression was encountered.
    }

-- | Usage statistics of a recorded expression under one inner limit.
data GatherInfo = GatherInfo
    { giCount :: Int
      -- ^ Number of occurrences seen.
    , giScopes :: Set.Set VarId
      -- ^ Union of the variables that were in scope at those occurrences.
    }
  deriving Show

-- | Recorded expressions bucketed by 'exprHash'.  When built only through
-- 'updateGS', no two entries of a bucket are alpha-equivalent.
newtype GatherSet dom = GatherSet {unGatherSet :: Map.Map Hash [Gathered dom]}
-- | Find a recorded expression that is alpha-equivalent to the given one.
-- Only the bucket for the expression's 'exprHash' is searched, and the first
-- alpha-equivalent entry of that bucket is returned.
lookupGS :: forall dom a
    .  ( AlphaEq dom dom (NodeDomain dom) [(VarId,VarId)]
       , Equality dom)
    => GatherSet dom
    -> ASTF (NodeDomain dom) a
    -> Maybe (Gathered dom)
lookupGS (GatherSet m) e = find sameExpr =<< Map.lookup (exprHash e) m
  where
    sameExpr :: Gathered dom -> Bool
    sameExpr g = case geExpr g of
        ASTB ge -> alphaEq ge e
-- | Insert a gathered expression into its 'exprHash' bucket.  The first entry
-- of that bucket that is alpha-equivalent to it is replaced; if there is
-- none, the new entry is appended at the end of the bucket.
updateGS :: forall dom
    .  ( AlphaEq dom dom (NodeDomain dom) [(VarId,VarId)]
       , Equality dom)
    => GatherSet dom
    -> Gathered dom
    -> GatherSet dom
updateGS (GatherSet m) g = case geExpr g of
    ASTB ge -> GatherSet (Map.alter (Just . maybe [g] replaceOrAppend) (exprHash ge) m)
  where
    -- Replace the first alpha-equivalent entry, or append @g@ at the end.
    replaceOrAppend :: [Gathered dom] -> [Gathered dom]
    replaceOrAppend [] = [g]
    replaceOrAppend (x:xs)
        | matchesNew x = g : xs
        | otherwise    = x : replaceOrAppend xs

    matchesNew :: Gathered dom -> Bool
    matchesNew x = case geExpr x of
        ASTB xe -> case geExpr g of
            ASTB ge -> alphaEq xe ge
-- | A 'GatherSet' containing no expressions.
emptyGS :: GatherSet dom
emptyGS = GatherSet Map.empty

-- | Every recorded expression, bucket by bucket in ascending hash order.
toListGS :: GatherSet dom -> [Gathered dom]
toListGS (GatherSet m) = concat (Map.elems m)
-- | Reader environment of the rebuild pass:
--
--   1. the expression (in the target domain) to use for each node that has
--      already been bound or inlined,
--   2. the variables bound by enclosing lambdas,
--   3. the stack of nodes entered so far, most recent first.
type RebuildEnv dom =
    ( Map.Map NodeId (ASTSAT dom)
    , Set.Set VarId
    , [NodeId]
    )

-- | The rebuild monad: reads a 'RebuildEnv' and threads a 'VarId' counter
-- used for fresh variables.
type RebuildMonad dom a = ReaderT (RebuildEnv dom) (State VarId) a

-- | Run a rebuild computation with no node substitutions, no bound variables
-- and only the root node @0@ on the seen-node stack.
runRebuild :: RebuildMonad dom a -> State VarId a
runRebuild = flip runReaderT (Map.empty, Set.empty, [0])

-- | Run an action with @v@ added to the bound variables.
addBoundVar :: VarId -> RebuildMonad dom a -> RebuildMonad dom a
addBoundVar v = local $ \(substs, bound, seen) -> (substs, Set.insert v bound, seen)

-- | The variables bound at the current position.
getBoundVars :: RebuildMonad dom (Set.Set VarId)
getBoundVars = asks $ \(_, bound, _) -> bound

-- | Run an action with node @n@ mapped to the given expression.
addNodeExpr :: NodeId -> ASTSAT dom -> RebuildMonad dom a -> RebuildMonad dom a
addNodeExpr n e = local $ \(substs, bound, seen) -> (Map.insert n e substs, bound, seen)

-- | The current node-to-expression substitution.
getNodeExprMap :: RebuildMonad dom (Map.Map NodeId (ASTSAT dom))
getNodeExprMap = asks $ \(substs, _, _) -> substs

-- | Run an action with node @n@ pushed on the seen-node stack.
addSeenNode :: NodeId -> RebuildMonad dom a -> RebuildMonad dom a
addSeenNode n = local $ \(substs, bound, seen) -> (substs, bound, n : seen)

-- | The seen-node stack, most recently entered node first.
getSeenNodes :: RebuildMonad dom [NodeId]
getSeenNodes = asks $ \(_, _, seen) -> seen
-- | Code motion by sharing common sub-expressions.
--
-- The expression is first turned into a graph of recorded sub-expressions by
-- 'gather', then turned back into an expression by 'rebuild', which inserts
-- @let@ bindings for the shared nodes.
--
-- * @hoistOver@ is consulted for every expression during 'gather'.  When it
--   returns False, each argument of that expression becomes an inner limit
--   for the sub-expressions it contains.
-- * The 'PrjDict' recognises variables and lambdas.
-- * The 'MkInjDict' supplies the symbols used to build @let@ bindings.
--
-- The surrounding 'State' supplies fresh variable ids.
codeMotion2 :: forall dom a
    .  ( ConstrainedBy dom Typeable
       , AlphaEq dom dom (NodeDomain dom) [(VarId,VarId)]
       , Equality dom
       )
    => (forall c. ASTF dom c -> Bool)
    -> PrjDict dom
    -> MkInjDict dom
    -> ASTF dom a
    -> State VarId (ASTF dom a)
codeMotion2 hoistOver pd mkId a = rebuild pd mkId (toListGS gathered) graph
  where
    (gathered, graph) = gather hoistOver pd a
-- | A sharing candidate: node id, recorded expression, and the 'GatherInfo'
-- entry selected for it by @nodesToConsider@ in 'rebuild'.
type ShareInfo dom = (NodeId, ASTSAT (NodeDomain dom), GatherInfo)

-- NOTE(review): 'ShareMaybe' (and its constructors 'Share' / 'Not') is not
-- referenced in the visible part of this module and is not exported; it looks
-- like dead code -- confirm and remove.
data ShareMaybe dom a
  where
    Share :: Sat dom b => VarId -> InjDict dom b a -> ASTF dom b -> ShareMaybe dom a
    Not :: Sat dom b => ASTF dom b -> ShareMaybe dom a
-- | Second pass of 'codeMotion2': turn the node graph produced by 'gather'
-- back into an ordinary expression, inserting @let@ bindings for the nodes
-- that are worth sharing.
--
-- The root expression is treated as node @0@.  At every node that is rebuilt,
-- the recorded nodes that
--
--   * the current node (transitively) refers to,
--   * have not been bound or inlined yet,
--   * have all their free variables bound at this position, and
--   * have a recorded inner limit on the stack of nodes entered so far
--
-- are processed in ascending 'NodeId' order.  Each one is bound with a @let@
-- when the 'MkInjDict' yields an injection dictionary and 'heuristic' approves
-- it; otherwise it is inlined.
--
-- Fails with 'error' if the root itself is a 'Node' reference.
rebuild :: forall dom a
    .  ( ConstrainedBy dom Typeable
       , AlphaEq dom dom (NodeDomain dom) [(VarId,VarId)]
       , Equality dom
       )
    => PrjDict dom
    -> MkInjDict dom
    -> [Gathered dom]
    -> ASTF (NodeDomain dom) a
    -> State VarId (ASTF dom a)
rebuild _ _ _ (Sym (C' (InjL _))) =
    error "rebuild: the root expression must not be a node"
rebuild pd mkId gs a = runRebuild $ rebuild' 0 a
  where
    -- All recorded nodes, indexed by their id (ids start at 1).
    nodes :: Array NodeId (Gathered dom)
    nodes = array
        (1, maximum (0:(Prelude.map geNodeId gs)))
        (zip (Prelude.map geNodeId gs) gs)

    nodeExpr :: NodeId -> ASTSAT (NodeDomain dom)
    nodeExpr n = geExpr (nodes ! n)

    freeVars :: Array NodeId (Set.Set VarId)
    freeVars = nodesFreeVars pd nodes

    -- Transitive node dependencies of every node, including the root (0).
    -- The array is defined in terms of itself, relying on lazy elements.
    nodeDeps :: Array NodeId (Set.Set NodeId)
    nodeDeps = nodeDepsArray
      where
        nodeDepsArray = array (0, snd (bounds nodes))
            [(n, nodeDepsNode n) | n <- 0 : indices nodes]
        nodeDepsNode :: NodeId -> Set.Set NodeId
        nodeDepsNode 0 = nodeDepsExp a
        nodeDepsNode n = liftASTB nodeDepsExp $ geExpr (nodes ! n)
        nodeDepsExp :: AST (NodeDomain dom) b -> Set.Set NodeId
        nodeDepsExp (Sym (C' (InjR _))) = Set.empty
        nodeDepsExp (Sym (C' (InjL (Node n)))) = Set.insert n (nodeDepsArray ! n)
        nodeDepsExp (s :$ b) = Set.union (nodeDepsExp s) (nodeDepsExp b)

    -- Sharing candidates.  Every recorded node that is accepted by @f@ and
    -- whose free variables are all in @bv@ is a candidate.  Each candidate
    -- is paired with the 'GatherInfo' of its inner limit that occurs deepest
    -- in @seenNodes@, i.e. the one entered earliest.  Nodes with no inner
    -- limit on the stack are skipped.
    nodesToConsider :: (NodeId -> Bool) -> Set.Set VarId -> [NodeId] -> [ShareInfo dom]
    nodesToConsider f bv seenNodes = concatMap mkShareInfo (assocs nodes)
      where
        maximumBy' _ [] = []
        maximumBy' cmp xs = [maximumBy cmp xs]
        mkShareInfo (n,g) = map snd $ maximumBy' (compare `on` fst)
            [ (i, (n, geExpr g, gi))
            | (il,gi) <- geInfo g
            , Just i <- [elemIndex il seenNodes]
            , Set.null (freeVars ! n `Set.difference` bv)
            , f n
            ]

    -- Rebuild the expression at node @n@.  A lambda whose body is a node is
    -- rebuilt with its variable bound and @n@ pushed on the seen-node stack.
    -- Plain symbols are copied.  Anything else goes through 'shareExprsIn'.
    rebuild' :: forall b . NodeId -> ASTF (NodeDomain dom) b -> RebuildMonad dom (ASTF dom b)
    rebuild' n (Sym (C' (InjR lam)) :$ ns@(Sym (C' (InjL (Node nb)))))
        | Just v <- prjLambda pd lam
        = case geExpr (nodes ! nb) of
            ASTB body
              | Dict <- exprDictSub pTypeable ns
              , Dict <- exprDictSub pTypeable body
              -> case gcast body of
                  Nothing -> error "rebuild: type mismatch"
                  Just body' -> do
                      body'' <- addBoundVar v $ addSeenNode n $ rebuild' nb body'
                      return (Sym lam :$ body'')
    rebuild' _ (Sym (C' (InjR s))) = return $ Sym s
    rebuild' n e = addSeenNode n $ shareExprsIn n e

    -- Bind or inline every eligible candidate for node @n@, then rebuild @e@.
    shareExprsIn :: forall b . NodeId -> ASTF (NodeDomain dom) b -> RebuildMonad dom (ASTF dom b)
    shareExprsIn n e = do
        bv <- getBoundVars
        seenNodes <- getSeenNodes
        nodeMap <- getNodeExprMap
        let eligible n' = n' /= n
                       && not (Map.member n' nodeMap)
                       && Set.member n' (nodeDeps ! n)
            considered = nodesToConsider eligible bv seenNodes
            sorted = sortBy (compare `on` (\(m,_,_) -> m)) considered
        shareEm sorted e

    -- Process the candidates in order.  A candidate that gets a @let@ is
    -- bound to a fresh variable (taken from the 'VarId' counter), and that
    -- variable replaces its node in the rest of the expression.  A candidate
    -- that is not bound is replaced by its rebuilt expression.
    shareEm :: [ShareInfo dom] -> ASTF (NodeDomain dom) b -> RebuildMonad dom (ASTF dom b)
    shareEm [] e = fixNodeExprSub e
    shareEm ((n, ASTB b, gi) : sis) e = do
        b' <- rebuild' n b
        bv <- getBoundVars
        case mkId b' (inlineAll nodeExpr e) of
            -- The heuristic must judge the candidate @b@.  The enclosing
            -- expression @e@ is never a lone variable here, so testing it
            -- would make the 'isVariable' check vacuous.
            Just dict | heuristic bv gi b -> do
                v <- get
                put (v+1)
                e' <- addNodeExpr n (ASTB (Sym (injVariable dict v))) $ shareEm sis e
                return $ Sym (injLet dict) :$ b' :$ (Sym (injLambda dict v) :$ e')
            _ -> addNodeExpr n (ASTB b') $ shareEm sis e

    -- Translate an application spine to the target domain, resolving each
    -- argument (a 'Node' reference) through the current substitution.
    fixNodeExprSub :: forall b
        .  ( ConstrainedBy dom Typeable
           , AlphaEq dom dom (NodeDomain dom) [(VarId,VarId)]
           , Equality dom
           )
        => AST (NodeDomain dom) b
        -> RebuildMonad dom (AST dom b)
    fixNodeExprSub (Sym (C' (InjR s))) = return (Sym s)
    fixNodeExprSub (s :$ b) = do
        b' <- fixNodeExpr b
        s' <- fixNodeExprSub s
        return (s' :$ b')

    -- Replace a 'Node' reference by the expression recorded for it.
    fixNodeExpr :: forall b . ASTF (NodeDomain dom) b -> RebuildMonad dom (ASTF dom b)
    fixNodeExpr (ns@(Sym (C' (InjL (Node n))))) = do
        nodeMap <- getNodeExprMap
        let a = lookNode nodeMap
        return a
      where
        lookNode nodeMap = case Map.lookup n nodeMap of
            Just (ASTB a)
              | Dict <- exprDictSub pTypeable ns
              , Dict <- exprDictSub pTypeable a
              -> case gcast a of
                  Nothing -> error "rebuild: type mismatch"
                  Just a -> a
            Nothing -> error "rebuild: lost node"

    -- Share a candidate only if it is not a plain variable and either it
    -- occurred more than once or some of its occurrences had lambda-bound
    -- variables in scope that are not bound at this position.
    heuristic :: Set.Set VarId -> GatherInfo -> ASTF (NodeDomain dom) b -> Bool
    heuristic bv gi b = not (isVariable pd b) && (giCount gi > 1 || not (Set.null (giScopes gi `Set.difference` bv)))
-- | Free variables of every recorded node, indexed like @nodes@.
--
-- A variable symbol contributes itself, and a lambda removes its bound
-- variable from its body's set.  A 'Node' reference contributes that node's
-- own entry of the result.  The result array is therefore defined in terms of
-- itself, relying on lazy array elements, so node references must be acyclic.
nodesFreeVars :: forall dom
    .  PrjDict dom
    -> Array NodeId (Gathered dom)
    -> Array NodeId (Set.Set VarId)
nodesFreeVars pd nodes = freeVars
  where
    freeVars = array (bounds nodes) [(n, freeVarsNode n) | n <- indices nodes]
    freeVarsNode :: NodeId -> Set.Set VarId
    freeVarsNode n = liftASTB freeVarsExp $ geExpr (nodes ! n)
    -- Clause order matters: variables and lambdas must be recognised before
    -- the catch-all symbol and application cases.
    freeVarsExp :: AST (NodeDomain dom) a -> Set.Set VarId
    freeVarsExp (Sym (C' (InjR var))) | Just v <- prjVariable pd var = Set.singleton v
    freeVarsExp (Sym (C' (InjR lam)) :$ b) | Just v <- prjLambda pd lam = Set.delete v (freeVarsExp b)
    freeVarsExp (Sym (C' (InjR _))) = Set.empty
    freeVarsExp (Sym (C' (InjL (Node n)))) = freeVars ! n
    freeVarsExp (s :$ b) = Set.union (freeVarsExp s) (freeVarsExp b)
-- | Expand every 'Node' reference, recursively, using @nodes@ to look up each
-- node's recorded expression.  The result lives in the original domain @dom@.
-- Fails with "inlineAll: type mismatch" if a recorded expression's type
-- differs from the type at the reference.
--
-- NOTE(review): nothing is memoised, so a node referenced k times is expanded
-- k times.
inlineAll :: forall dom a
    .  ConstrainedBy dom Typeable
    => (NodeId -> ASTSAT (NodeDomain dom))
    -> ASTF (NodeDomain dom) a
    -> ASTF dom a
inlineAll nodes a = go a
  where
    go :: AST (NodeDomain dom) sig -> AST dom sig
    go (s :$ a) = go s :$ go a
    go (Sym (C' (InjR s))) = Sym s
    -- The Typeable dictionaries of the reference and of the recorded
    -- expression allow 'gcast' to check that their types agree.
    go s@(Sym (C' (InjL (Node n)))) = case nodes n of
        ASTB a
          | Dict <- exprDictSub pTypeable s
          , Dict <- exprDictSub pTypeable a
          -> case gcast a of
              Nothing -> error "inlineAll: type mismatch"
              Just a -> go a
-- | Reader environment of the gather pass: a stack of inner-limit nodes (most
-- recently pushed first) and the lambda-bound variables in scope.
type GatherEnv =
    ( [NodeId]
    , Set.Set VarId
    )

-- | State of the gather pass: the expressions recorded so far and the next
-- 'NodeId' to hand out.
type GatherState dom =
    ( GatherSet dom
    , NodeId
    )

-- | The gather monad.
type GatherMonad dom a = ReaderT GatherEnv (State (GatherState dom)) a

-- | Run a gather computation starting from the given set.  Node @0@ is the
-- only inner limit, no variables are in scope, and new nodes are numbered
-- from @1@.  Returns the final set and the result; the final counter is
-- dropped.
runGather :: GatherSet dom -> GatherMonad dom a -> (GatherSet dom, a)
runGather s m =
    let (result, (recorded, _)) = runState (runReaderT m ([0], Set.empty)) (s, 1)
    in  (recorded, result)

-- | The most recently pushed inner limit.
getInnerLimit :: GatherMonad dom NodeId
getInnerLimit = asks (head . fst)

-- | The variables currently in scope.
getScope :: GatherMonad dom (Set.Set VarId)
getScope = asks snd

-- | Run an action with @n@ pushed on the inner-limit stack.
addInnerLimit :: NodeId -> GatherMonad dom a -> GatherMonad dom a
addInnerLimit n = local $ \(limits, scope) -> (n : limits, scope)

-- | Run an action with @v@ added to the variables in scope.
addScopeVar :: VarId -> GatherMonad dom a -> GatherMonad dom a
addScopeVar v = local $ \(limits, scope) -> (limits, Set.insert v scope)
-- | First pass of 'codeMotion2'.  Every argument of every application (and
-- every lambda body) is replaced by a 'Node' reference, and the replaced
-- expression is recorded in a 'GatherSet'.  Expressions that 'lookupGS'
-- finds alpha-equivalent are recorded once and share a 'NodeId'.  Each
-- record also keeps per-inner-limit usage counts and scopes ('GatherInfo').
--
-- A root that is a single symbol is returned unchanged with an empty set.
gather :: forall dom a
    .  ( ConstrainedBy dom Typeable
       , AlphaEq dom dom (NodeDomain dom) [(VarId,VarId)]
       , Equality dom
       )
    => (forall c. ASTF dom c -> Bool)
    -> PrjDict dom
    -> ASTF dom a
    -> (GatherSet dom, ASTF (NodeDomain dom) a)
gather hoistOver pd a@(Sym s) | Dict <- exprDict a = (emptyGS, Sym (C' (InjR s)))
gather hoistOver pd a | Dict <- exprDict a
    = runGather emptyGS (gatherRec (hoistOver a) a)
  where
    -- Record one argument expression and return a 'Node' pointing at it.
    -- @h@ is the 'hoistOver' verdict of the enclosing expression.  When it
    -- is False, this node's own id is pushed as the inner limit while its
    -- contents are gathered.  That id is only known after the contents have
    -- been recorded, hence the lazy 'mfix' knot.  It works only because the
    -- id is stored lazily and never inspected while the contents are
    -- gathered.
    gather' :: Bool -> ASTF dom b -> GatherMonad dom (ASTF (NodeDomain dom) b)
    gather' h a | Dict <- exprDict a = do
        (a',n) <-
            mfix (\(~(a',n)) -> do
                    a' <- addInnerLimitIf (not h) n $ gatherRec (hoistOver a) a
                    n <- recordExpr a'
                    return (a',n)
                 )
        return $ Sym $ C' $ InjL $ Node n

    addInnerLimitIf True n m = addInnerLimit n m
    addInnerLimitIf _ n m = m

    -- Translate an expression to 'NodeDomain', gathering each argument with
    -- 'gather''.  A lambda body is gathered with the lambda's variable
    -- added to the scope.  Arguments are gathered from last to first, which
    -- determines the order in which node ids are handed out.
    gatherRec
        :: (Sat dom (DenResult b))
        => Bool
        -> AST dom b
        -> GatherMonad dom (AST (NodeDomain dom) b)
    gatherRec h (Sym lam :$ b) | Just v <- prjLambda pd lam = do
        b' <- addScopeVar v $ gather' h b
        return ((Sym (C' (InjR lam))) :$ b')
    gatherRec h (Sym s) = return $ Sym $ C' $ InjR s
    gatherRec h (s :$ b) = do
        b' <- gather' h b
        s' <- gatherRec h s
        return (s' :$ b')

    -- Record an expression and return its node id.  If an alpha-equivalent
    -- expression is already recorded, its usage info is updated for the
    -- current inner limit and scope, and its id is reused.  Otherwise the
    -- expression is stored under the next free id, with a count of 1.
    recordExpr :: ASTF (NodeDomain dom) b -> GatherMonad dom NodeId
    recordExpr a | Dict <- exprDict a = do
        (s,n) <- get
        innerLimit <- getInnerLimit
        scope <- getScope
        case lookupGS s a of
            Just ge -> do
                let ge' = ge { geInfo = updateInfo scope (geInfo ge) innerLimit }
                put (updateGS s ge', n)
                return (geNodeId ge)
            Nothing -> do
                let ge = Gathered { geExpr = ASTB a , geNodeId = n , geInfo = [(innerLimit, GatherInfo { giCount = 1 , giScopes = scope })] }
                put (updateGS s ge, n+1)
                return n

    -- Count one more occurrence under inner limit @n'@: bump the count and
    -- widen the scope of the matching entry, or append a fresh entry.
    updateInfo :: Set.Set VarId -> [(NodeId, GatherInfo)] -> NodeId -> [(NodeId, GatherInfo)]
    updateInfo scope [] n = [(n, GatherInfo { giCount = 1 , giScopes = scope })]
    updateInfo scope ((n,gi):xs) n' | n == n' = (n, gi') : xs
      where
        gi' = gi { giCount = giCount gi + 1 , giScopes = Set.union (giScopes gi) scope }
    updateInfo scope (x:xs) n' = x : updateInfo scope xs n'
-- | Reify a higher-order expression to first-order form with 'reifyM', then
-- run 'codeMotion2' on the result.  Variable numbering starts at @0@.
reifySmart2 :: forall dom p pVar a
    .  ( AlphaEq dom dom (NodeDomain (FODomain dom p pVar)) [(VarId,VarId)]
       , Equality dom
       , Syntactic a
       , Domain a ~ HODomain dom p pVar
       , p :< Typeable
       )
    => (forall c. ASTF (FODomain dom p pVar) c -> Bool)
    -> MkInjDict (FODomain dom p pVar)
    -> a
    -> ASTF (FODomain dom p pVar) (Internal a)
reifySmart2 hoistOver mkId expr =
    evalState (reifyM (desugar expr) >>= codeMotion2 hoistOver prjDictFO mkId) 0