{-# LANGUAGE CPP
           , BangPatterns
           , DataKinds
           , EmptyCase
           , ExistentialQuantification
           , FlexibleContexts
           , FlexibleInstances
           , GADTs
           , GeneralizedNewtypeDeriving
           , KindSignatures
           , MultiParamTypeClasses
           , OverloadedStrings
           , PolyKinds
           , ScopedTypeVariables
           , StandaloneDeriving
           , TupleSections
           , TypeFamilies
           , TypeOperators
           , UndecidableInstances
           #-}

{-# OPTIONS_GHC -Wall -fwarn-tabs #-}
----------------------------------------------------------------
--                                                    2017.02.01
-- |
-- Module      :  Language.Hakaru.Syntax.Hoist
-- Copyright   :  Copyright (c) 2016 the Hakaru team
-- License     :  BSD3
-- Maintainer  :
-- Stability   :  experimental
-- Portability :  GHC-only
--
-- Hoist expressions to the point where their data dependencies are met.
-- This pass duplicates *a lot* of work and relies on the CSE and pruning
-- passes to clean up the junk (most of which is trivial to do, but we don't
-- know what is junk until after CSE has occurred).
--
-- NOTE: This pass assumes globally unique variable ids, as two subterms may
-- otherwise bind the same variable. Those variables would potentially shadow
-- each other if hoisted upward to a common scope.
--
----------------------------------------------------------------
module Language.Hakaru.Syntax.Hoist (hoist) where

import           Control.Applicative             (liftA2)
import           Control.Monad.RWS
import qualified Data.Foldable                   as F
import qualified Data.Graph                      as G
import qualified Data.IntMap.Strict              as IM
import qualified Data.List                       as L
import           Data.Maybe                      (mapMaybe)
import           Data.Number.Nat
import           Data.Proxy                      (KProxy (..))
import qualified Data.Vector                     as V

import           Language.Hakaru.Syntax.ABT
import           Language.Hakaru.Syntax.ANF      (isValue)
import           Language.Hakaru.Syntax.AST
import           Language.Hakaru.Syntax.AST.Eq   (alphaEq)
import           Language.Hakaru.Syntax.Gensym
import           Language.Hakaru.Syntax.IClasses
import           Language.Hakaru.Types.DataKind
import           Language.Hakaru.Types.Sing      (Sing)

#if __GLASGOW_HASKELL__ < 710
import           Control.Applicative
#endif

-- | One hoistable expression, together with the information needed to decide
-- where it may be re-introduced. The expression's type index @a@ is
-- existentially hidden so that entries of different types can share a list.
data Entry (abt :: Hakaru -> *)
  = forall (a :: Hakaru) . Entry
  { varDependencies :: !(VarSet (KindOf a))
    -- ^ The variables the expression depends on; the expression may only be
    -- placed where all of them are bound.
  , expression      :: !(abt a)
    -- ^ The hoisted expression itself.
  , sing            :: !(Sing a)
    -- ^ The expression's type, cached to allow cheap type comparisons:
    -- recomputing it with @typeOf@ is O(n) in the size of the expression
    -- and may be needed many times.
  , bindings        :: ![Variable a]
    -- ^ Every variable that should be bound to this expression. There is more
    -- than one when alpha-equivalent definitions have been merged.
  }

-- Debug rendering: only the dependencies and the bindings are shown; the
-- expression and its type are omitted.
instance Show (Entry abt) where
  show (Entry deps _ _ vars) =
    concat ["Entry (", show deps, ") (", show vars, ")"]

-- Shorthands for the kind proxy of Hakaru types, and for the variable sets
-- and existentially-typed variables indexed by it.
type HakaruProxy = ('KProxy :: KProxy Hakaru)
type LiveSet     = VarSet HakaruProxy
type HakaruVar   = SomeVariable HakaruProxy

-- The @HoistM@ monad makes use of three monadic layers to propagate
-- information both downwards to the leaves and upwards to the root node of
-- the AST.
--
-- The Writer layer propagates the live expressions which may be hoisted (i.e.
-- all their data dependencies are currently satisfied) from each
-- subexpression to its parent.
--
-- The Reader layer propagates the currently bound variables, which are used
-- to decide when to introduce new bindings.
--
-- The State layer is just to provide a counter in order to gensym new
-- variables, since the process of adding new bindings is a little tricky.
-- What we want is to fully duplicate bindings without altering the original
-- variable identifiers. To do so, all original variable names are preserved
-- and new variables are added outside the range of existing variables.
newtype HoistM (abt :: [Hakaru] -> Hakaru -> *) a
  = HoistM { runHoistM :: RWS LiveSet (ExpressionSet abt) Nat a }

deriving instance                 Functor (HoistM abt)
deriving instance (ABT Term abt) => Applicative (HoistM abt)
deriving instance (ABT Term abt) => Monad (HoistM abt)
deriving instance (ABT Term abt) => MonadState Nat (HoistM abt)
deriving instance (ABT Term abt) => MonadWriter (ExpressionSet abt) (HoistM abt)
deriving instance (ABT Term abt) => MonadReader LiveSet (HoistM abt)

-- | A collection of hoistable expressions, as accumulated in the Writer layer
-- of 'HoistM'.
newtype ExpressionSet (abt :: [Hakaru] -> Hakaru -> *)
  = ExpressionSet [Entry (abt '[])]

-- | Merge two entries that 'entryEqual' considers equal. The first entry's
-- dependencies and expression are kept, and the bindings of both entries are
-- concatenated with duplicates removed.
--
-- Calls 'error' if the two entries have different types.
mergeEntry
  :: (ABT Term abt)
  => Entry (abt '[])
  -> Entry (abt '[])
  -> Entry (abt '[])
mergeEntry (Entry d e s1 b1) (Entry _ _ s2 b2) =
  case jmEq1 s1 s2 of
    Just Refl -> Entry d e s1 $ L.nub (b1 ++ b2)
    Nothing   -> error "cannot union mismatched entries"

-- | Two entries are equal when they have the same dependency set, the same
-- type, and alpha-equivalent expressions.
entryEqual :: (ABT Term abt) => Entry (abt '[]) -> Entry (abt '[]) -> Bool
entryEqual Entry{varDependencies=d1,expression=e1,sing=s1}
           Entry{varDependencies=d2,expression=e2,sing=s2} =
  case (d1 == d2, jmEq1 s1 s2) of
    (True, Just Refl) -> alphaEq e1 e2
    _                 -> False

-- | Union two expression sets. Every group of entries that 'entryEqual'
-- identifies is collapsed into one entry with 'mergeEntry', regardless of
-- where the group's members occur in the inputs. Groups appear in the order
-- of their first occurrence in @xs ++ ys@.
--
-- 'L.groupBy' is deliberately not used here: it only groups *adjacent*
-- elements, so equal entries separated by other entries would survive as
-- duplicates. The fold below is quadratic in the number of distinct entries.
unionEntrySet
  :: forall abt . (ABT Term abt)
  => ExpressionSet abt
  -> ExpressionSet abt
  -> ExpressionSet abt
unionEntrySet (ExpressionSet xs) (ExpressionSet ys) =
    ExpressionSet . reverse $ L.foldl' insertEntry [] (xs ++ ys)
  where
    -- The accumulator holds one entry per group, newest group first, so a
    -- previously unseen entry can simply be consed on.
    insertEntry :: [Entry (abt '[])] -> Entry (abt '[]) -> [Entry (abt '[])]
    insertEntry acc new =
      case L.break (\e -> entryEqual e new) acc of
        (before, hit : after) -> before ++ mergeEntry hit new : after
        (_, [])               -> new : acc

-- | Pairwise intersection: every pair of equal entries from the two sets is
-- merged with 'mergeEntry'.
--
-- NOTE(review): not referenced by the other definitions in this module.
intersectEntrySet
  :: forall abt . (ABT Term abt)
  => ExpressionSet abt
  -> ExpressionSet abt
  -> ExpressionSet abt
intersectEntrySet (ExpressionSet xs) (ExpressionSet ys) = ExpressionSet merged
  where
    merged :: [Entry (abt '[])]
    merged = map (uncurry mergeEntry)
           . filter (uncurry entryEqual)
           $ liftA2 (,) xs ys

-- The general case for generating the entry set for a term is to simply union
-- the sets for all the subterms, so we choose union as our monoidal operation
-- for the Writer monad. Since base-4.11 (GHC 8.4), 'Semigroup' is a
-- superclass of 'Monoid' and must be given an instance.
#if __GLASGOW_HASKELL__ >= 804
instance (ABT Term abt) => Semigroup (ExpressionSet abt) where
  (<>) = unionEntrySet
#endif

instance (ABT Term abt) => Monoid (ExpressionSet abt) where
  mempty  = ExpressionSet []
#if __GLASGOW_HASKELL__ >= 804
  mappend = (<>)
#else
  mappend = unionEntrySet
#endif

-- | Given a list of entries to introduce, order them so that their data
-- dependencies are satisfied: an entry comes after every entry that binds a
-- variable it depends on.
topSortEntries :: forall abt . [Entry (abt '[])] -> [Entry (abt '[])]
topSortEntries entryList = map (entries V.!) $ G.topSort graph
  where
    entries :: V.Vector (Entry (abt '[]))
    !entries = V.fromList entryList

    -- The graph is represented as dependencies between entries, where an
    -- entry (a) depends on entry (b) if (b) introduces a variable which (a)
    -- depends on.
    getVIDs :: Entry (abt '[]) -> [Int]
    getVIDs Entry{bindings=b} = map (fromNat . varID) b

    -- Associates all variables introduced by an entry to the entry itself.
    -- A given entry may introduce multiple bindings, since an entry stores
    -- all alpha-equivalent variable definitions.
    assocBindingsTo :: IM.IntMap Int -> Int -> Entry (abt '[]) -> IM.IntMap Int
    assocBindingsTo m n = L.foldl' (\acc v -> IM.insert v n acc) m . getVIDs

    -- Mapping from variable IDs to the index of the entry binding them.
    varMap :: IM.IntMap Int
    !varMap = V.ifoldl' assocBindingsTo IM.empty entries

    -- Create an edge from the entry binding each dependency to the entry at
    -- @idx@. Dependencies bound by no entry in this list contribute no edge.
    makeEdges :: Int -> Entry (abt '[]) -> [G.Edge]
    makeEdges idx Entry{varDependencies=d} =
      map (, idx) . mapMaybe (flip IM.lookup varMap) $ varSetKeys d

    -- Collect all the edges to build the full graph.
    edges :: [G.Edge]
    !edges = V.foldr (++) [] $ V.imap makeEdges entries

    -- The full graph structure to be topologically sorted.
    graph :: G.Graph
    !graph = G.buildG (0, V.length entries - 1) edges

-- | Emit a single entry binding @v@ to @abt@, whose dependencies are the
-- free variables of @abt@.
recordEntry :: (ABT Term abt) => Variable a -> abt '[] a -> HoistM abt ()
recordEntry v abt =
  tell $ ExpressionSet [Entry (freeVars abt) abt (varType v) [v]]

-- | Run a 'HoistM' computation with no bound variables and the given gensym
-- counter, discarding the final counter and any entries left in the Writer.
execHoistM :: Nat -> HoistM abt a -> a
execHoistM counter act = a
  where
    hoisted   = runHoistM act
    (a, _, _) = runRWS hoisted emptyVarSet counter

-- | An expression is considered "toplevel" if it can be hoisted outside all
-- binders. This means that the expression has no data dependencies.
toplevelEntry :: Entry abt -> Bool
toplevelEntry Entry{varDependencies=d} = sizeVarSet d == 0

-- | Run an action and return the entries it wrote, instead of passing them on
-- to the enclosing Writer.
captureEntries
  :: (ABT Term abt)
  => HoistM abt a
  -> HoistM abt (a, ExpressionSet abt)
captureEntries = censor (const mempty) . listen

-- | Hoist every subexpression of a term to the point where its data
-- dependencies are met. Fresh variables are generated with the counter
-- started at @nextFreeOrBind abt@ (presumably past every variable id already
-- used in @abt@).
hoist :: (ABT Term abt) => abt '[] a -> abt '[] a
hoist abt =
  execHoistM (nextFreeOrBind abt) $
    captureEntries (hoist' abt) >>= uncurry (introduceToplevel emptyVarSet)

-- | Split an expression set into the entries satisfying the predicate and
-- those that do not.
partitionEntrySet
  :: (Entry (abt '[]) -> Bool)
  -> ExpressionSet abt
  -> (ExpressionSet abt, ExpressionSet abt)
partitionEntrySet p (ExpressionSet xs) = (ExpressionSet true, ExpressionSet false)
  where
    (true, false) = L.partition p xs

-- | Wrap @abt@ in let-bindings for the given entries: first those that depend
-- on toplevel definitions or on the variables in @avail@, then the toplevel
-- definitions themselves outermost. Entries left over by 'introduceBindings'
-- are written to the Writer layer.
introduceToplevel
  :: (ABT Term abt)
  => LiveSet
  -> abt '[] a
  -> ExpressionSet abt
  -> HoistM abt (abt '[] a)
introduceToplevel avail abt entries = do
  -- After transforming the given AST, we need to introduce all the toplevel
  -- bindings (i.e. bindings with no data dependencies), most of which should
  -- be eliminated by constant propagation.
  let (ExpressionSet toplevel, rest) = partitionEntrySet toplevelEntry entries
      intro = concatMap getBoundVars toplevel ++ fromVarSet avail
  -- First we wrap the new AST in all the terms which depend on toplevel
  -- definitions (or on the variables in @avail@).
  wrapped <- introduceBindings intro abt rest
  -- Then wrap the result in the toplevel definitions.
  wrapExpr wrapped toplevel

-- | Run an action with @v@ added to the set of bound variables.
bindVar
  :: (ABT Term abt)
  => Variable (a :: Hakaru)
  -> HoistM abt b
  -> HoistM abt b
bindVar = local . insertVarSet

-- | Run an action under the binder @v@, capturing (rather than propagating)
-- the entries it produces.
isolateBinder
  :: (ABT Term abt)
  => Variable (a :: Hakaru)
  -> HoistM abt b
  -> HoistM abt (b, ExpressionSet abt)
isolateBinder v = captureEntries . bindVar v

-- | Hoist an arbitrary (possibly binding) term.
hoist'
  :: forall abt xs a . (ABT Term abt)
  => abt xs a
  -> HoistM abt (abt xs a)
hoist' = start
  where
    insertMany :: [HakaruVar] -> LiveSet -> LiveSet
    insertMany = flip $ L.foldl' (\ acc (SomeVariable v) -> insertVarSet v acc)

    start :: forall ys b . abt ys b -> HoistM abt (abt ys b)
    start = loop [] . viewABT

    -- Run an action under the given binders, capturing its entries.
    isolateBinders :: [HakaruVar] -> HoistM abt c -> HoistM abt (c, ExpressionSet abt)
    isolateBinders xs = captureEntries . local (insertMany xs)

    -- @loop@ takes 2 parameters.
    --
    -- 1. The list of variables bound so far
    -- 2. The current term we are recurring over
    --
    -- We add a value to the first every time we hit a @Bind@ term, and when
    -- a @Syn@ term is finally reached, we introduce any hoisted values whose
    -- data dependencies are satisfied by these new variables.
    loop :: forall ys b . [HakaruVar] -> View (Term abt) ys b -> HoistM abt (abt ys b)
    loop _  (Var v)    = return (var v)
    -- This case is not needed, but it avoids the expensive call to
    -- introduceBindings when there are no new variables and it would do no
    -- work.
    loop [] (Syn s)    = hoistTerm s
    loop xs (Syn s)    = do
      (term, entries) <- isolateBinders xs (hoistTerm s)
      introduceBindings xs term entries
    loop xs (Bind v b) = bind v <$> loop (SomeVariable v : xs) b

-- | All variables an entry binds, with their types hidden.
getBoundVars :: Entry x -> [HakaruVar]
getBoundVars Entry{bindings=b} = fmap SomeVariable b

-- | Wrap a body in let-bindings for the given entries; the first entry of
-- the list ends up outermost.
wrapExpr
  :: forall abt b . (ABT Term abt)
  => abt '[] b
  -> [Entry (abt '[])]
  -> HoistM abt (abt '[] b)
wrapExpr = F.foldrM wrap
  where
    -- @let v = e in body@, except that a body consisting solely of @v@ is
    -- replaced by @e@ directly.
    mklet :: abt '[] a -> Variable a -> abt '[] b -> abt '[] b
    mklet e v body =
      case viewABT body of
        Var v' | Just Refl <- varEq v v' -> e
        _                                -> syn (Let_ :$ e :* bind v body :* End)

    -- Binds the Entry's expression to its first variable (or to a fresh
    -- variable if it has none) and binds every other variable to that first
    -- variable.
    wrap :: Entry (abt '[]) -> abt '[] b -> HoistM abt (abt '[] b)
    wrap Entry{expression=e,bindings=[]} acc = do
      tmp <- varForExpr e
      return $ mklet e tmp acc
    wrap Entry{expression=e,bindings=(x:xs)} acc = do
      let rhs  = var x
          body = foldr (mklet rhs) acc xs
      return $ mklet e x body

-- | This will introduce all binders which must be introduced by binding the
-- @newVars@ set: every entry depending on one of those variables, and
-- (transitively) every entry depending on a variable such an entry binds.
-- As a side effect, the remaining entries are written into the Writer layer
-- of the stack.
introduceBindings
  :: forall (a :: Hakaru) abt . (ABT Term abt)
  => [HakaruVar]
  -> abt '[] a
  -> ExpressionSet abt
  -> HoistM abt (abt '[] a)
introduceBindings newVars body (ExpressionSet entries) = do
    tell (ExpressionSet leftOver)
    wrapExpr body (topSortEntries resultEntries)
  where
    resultEntries, leftOver :: [Entry (abt '[])]
    (resultEntries, leftOver) = loop entries newVars

    -- Whether the entry depends on @v@.
    introducedBy :: forall (b :: Hakaru) . Variable b -> Entry (abt '[]) -> Bool
    introducedBy v Entry{varDependencies=deps} = memberVarSet v deps

    -- Worklist over variables: pull out the entries depending on each
    -- variable and enqueue the variables those entries bind.
    loop :: [Entry (abt '[])] -> [HakaruVar] -> ([Entry (abt '[])], [Entry (abt '[])])
    loop exprs [] = ([], exprs)
    loop exprs (SomeVariable v : xs) = (introduced ++ intro, acc)
      where
        ~(intro, acc)      = loop rest (xs ++ vars)
        vars               = concatMap getBoundVars introduced
        (introduced, rest) = L.partition (introducedBy v) exprs

-- Contrary to the other binding forms, let expressions are killed by the
-- hoisting pass. Their RHSs are floated upward in the AST and re-introduced
-- where their data dependencies are fulfilled. Thus, the result of hoisting
-- a let expression is just the hoisted body.
--
-- A lambda captures the entries of its body and introduces all of them
-- inside the lambda. Any other term has its subterms hoisted; if the result
-- is not a value, it is recorded as an entry and replaced by a fresh variable.
hoistTerm
  :: forall (a :: Hakaru) (abt :: [Hakaru] -> Hakaru -> *) . (ABT Term abt)
  => Term abt a
  -> HoistM abt (abt '[] a)
hoistTerm (Let_ :$ rhs :* body :* End) =
  caseBind body $ \ v body' -> do
    rhs' <- hoist' rhs
    recordEntry v rhs'
    bindVar v (hoist' body')

hoistTerm (Lam_ :$ body :* End) =
  caseBind body $ \ v body' -> do
    available         <- fmap (insertVarSet v) ask
    (body'', entries) <- isolateBinder v (hoist' body')
    finalized         <- introduceToplevel available body'' entries
    return $ syn (Lam_ :$ bind v finalized :* End)

hoistTerm term = do
  result <- syn <$> traverse21 hoist' term
  if isValue result
    then return result
    else do
      fresh <- varForExpr result
      recordEntry fresh result
      return (var fresh)