{-# LANGUAGE CPP
, BangPatterns
, DataKinds
, EmptyCase
, ExistentialQuantification
, FlexibleContexts
, FlexibleInstances
, GADTs
, GeneralizedNewtypeDeriving
, KindSignatures
, MultiParamTypeClasses
, OverloadedStrings
, PolyKinds
, ScopedTypeVariables
, StandaloneDeriving
, TupleSections
, TypeFamilies
, TypeOperators
, UndecidableInstances
#-}
{-# OPTIONS_GHC -Wall -fwarn-tabs #-}
module Language.Hakaru.Syntax.Hoist (hoist) where
import Control.Applicative (liftA2)
import Control.Monad.RWS hiding ((<>))
import qualified Data.Foldable as F
import qualified Data.Graph as G
import qualified Data.IntMap.Strict as IM
import qualified Data.List as L
import Data.Maybe (mapMaybe)
import Data.Number.Nat
import Data.Proxy (KProxy (..))
import qualified Data.Vector as V
import Language.Hakaru.Syntax.ABT
import Language.Hakaru.Syntax.ANF (isValue)
import Language.Hakaru.Syntax.AST
import Language.Hakaru.Syntax.AST.Eq (alphaEq)
import Language.Hakaru.Syntax.Gensym
import Language.Hakaru.Syntax.IClasses
import Language.Hakaru.Types.DataKind
import Language.Hakaru.Types.Sing (Sing)
#if __GLASGOW_HASKELL__ < 710
import Control.Applicative
#endif
#if !(MIN_VERSION_base(4,11,0))
import Data.Semigroup
#endif
-- | A candidate binding collected during the hoisting traversal. The Hakaru
-- type @a@ of the stored expression is existentially quantified, so the
-- 'sing' field is kept to recover type equality between entries later.
data Entry (abt :: Hakaru -> *)
  = forall (a :: Hakaru) . Entry
  { varDependencies :: !(VarSet (KindOf a))
    -- ^ Variables the expression depends on ('recordEntry' stores the
    -- expression's free variables here).
  , expression :: !(abt a)
    -- ^ The expression to be bound.
  , sing :: !(Sing a)
    -- ^ Type singleton of the expression.
  , bindings :: ![Variable a]
    -- ^ Variables that are bound to this expression.
  }
-- | Debug rendering: only the dependency set and the bound variables are
-- printed; the expression and its type singleton are left out.
instance Show (Entry abt) where
  show (Entry deps _ _ vars) =
    concat ["Entry (", show deps, ") (", show vars, ")"]
-- | The kind proxy selecting the 'Hakaru' kind.
type HakaruProxy = ('KProxy :: KProxy Hakaru)

-- | A set of Hakaru-kinded variables; this is the reader environment of
-- 'HoistM'.
type LiveSet = VarSet HakaruProxy

-- | A Hakaru-kinded variable whose type index is hidden.
type HakaruVar = SomeVariable HakaruProxy
-- | The monad in which the hoisting pass runs, a newtype over 'RWS' with:
--
-- * reader: the 'LiveSet' of variables ('bindVar' extends it locally);
-- * writer: the 'ExpressionSet' of entries recorded so far;
-- * state: a 'Nat' counter ('hoist' seeds it with @nextFreeOrBind@ of the
--   input term).
newtype HoistM (abt :: [Hakaru] -> Hakaru -> *) a
  = HoistM { runHoistM :: RWS LiveSet (ExpressionSet abt) Nat a }

-- The instances below are lifted from the underlying 'RWS'. All but 'Functor'
-- need @ABT Term abt@ because the 'Monoid' instance of the writer type
-- 'ExpressionSet' requires it.
deriving instance Functor (HoistM abt)
deriving instance (ABT Term abt) => Applicative (HoistM abt)
deriving instance (ABT Term abt) => Monad (HoistM abt)
deriving instance (ABT Term abt) => MonadState Nat (HoistM abt)
deriving instance (ABT Term abt) => MonadWriter (ExpressionSet abt) (HoistM abt)
deriving instance (ABT Term abt) => MonadReader LiveSet (HoistM abt)
-- | A collection of entries over closed terms (@abt '[]@), stored as a plain
-- list. This is the writer output type of 'HoistM'.
newtype ExpressionSet (abt :: [Hakaru] -> Hakaru -> *)
  = ExpressionSet [Entry (abt '[])]
-- | Combine two entries that stand for the same expression.
--
-- The result keeps the dependencies, expression and type singleton of the
-- first entry; its binding list is the concatenation of both binding lists
-- with duplicates removed. Calls 'error' when the two type singletons are
-- not equal.
mergeEntry :: (ABT Term abt) => Entry (abt '[]) -> Entry (abt '[]) -> Entry (abt '[])
mergeEntry (Entry deps expr ty vars) (Entry _ _ ty' vars')
  | Just Refl <- jmEq1 ty ty' = Entry deps expr ty (L.nub (vars ++ vars'))
  | otherwise                 = error "cannot union mismatched entries"
-- | Decide whether two entries describe the same expression: their
-- dependency sets must be equal, their type singletons must match, and the
-- expressions must be alpha-equivalent (checked in that order).
entryEqual :: (ABT Term abt) => Entry (abt '[]) -> Entry (abt '[]) -> Bool
entryEqual (Entry deps1 expr1 ty1 _) (Entry deps2 expr2 ty2 _)
  | deps1 == deps2
  , Just Refl <- jmEq1 ty1 ty2 = alphaEq expr1 expr2
  | otherwise                  = False
-- | Union of two expression sets, merging entries that are 'entryEqual'.
--
-- Every group of mutually equal entries is collapsed into a single entry
-- with 'mergeEntry', so the merged entry keeps the expression of the first
-- occurrence and accumulates the bindings of all of them. Entries appear in
-- the result in order of their first occurrence in @xs ++ ys@.
--
-- Grouping is done over the whole list rather than with 'L.groupBy', which
-- only groups /adjacent/ elements and would therefore leave equal entries
-- unmerged whenever an unrelated entry sits between them.
unionEntrySet
  :: forall abt
  .  (ABT Term abt)
  => ExpressionSet abt
  -> ExpressionSet abt
  -> ExpressionSet abt
unionEntrySet (ExpressionSet xs) (ExpressionSet ys) =
  ExpressionSet (classes (xs ++ ys))
  where
    -- Repeatedly take the head entry, pull every later entry equal to it
    -- out of the remainder, and merge them into the head.
    classes :: [Entry (abt '[])] -> [Entry (abt '[])]
    classes []       = []
    classes (z : zs) = L.foldl' mergeEntry z same : classes rest
      where
        (same, rest) = L.partition (entryEqual z) zs
-- | Intersection of two expression sets: for every pair drawn from the two
-- sets (all combinations, left set varying slowest) that is 'entryEqual',
-- the result contains the 'mergeEntry' of that pair.
intersectEntrySet
  :: forall abt
  .  (ABT Term abt)
  => ExpressionSet abt
  -> ExpressionSet abt
  -> ExpressionSet abt
intersectEntrySet (ExpressionSet xs) (ExpressionSet ys) =
  ExpressionSet
    [ mergeEntry x y
    | (x, y) <- liftA2 (,) xs ys
    , entryEqual x y
    ]
-- | Expression sets are combined with 'unionEntrySet'.
instance (ABT Term abt) => Semigroup (ExpressionSet abt) where
  (<>) = unionEntrySet

-- | The empty expression set is the identity element.
instance (ABT Term abt) => Monoid (ExpressionSet abt) where
  mempty = ExpressionSet []
#if !(MIN_VERSION_base(4,11,0))
  -- On base older than 4.11 'mappend' is defined explicitly in terms of
  -- the 'Semigroup' operation.
  mappend = (<>)
#endif
-- | Order entries so that an entry binding a variable comes before every
-- entry that depends on that variable.
--
-- A graph is built with one vertex per entry (its position in the input
-- list) and an edge from the binding entry to each dependent entry; the
-- result follows 'G.topSort' of that graph. Dependencies on variables that
-- no entry in the list binds contribute no edge.
topSortEntries
  :: forall abt
  .  [Entry (abt '[])]
  -> [Entry (abt '[])]
topSortEntries entryList = map (entries V.!) (G.topSort graph)
  where
    entries :: V.Vector (Entry (abt '[]))
    !entries = V.fromList entryList

    -- Record, for every variable bound by the entry at position @idx@,
    -- that its variable id is provided by @idx@.
    register :: IM.IntMap Int -> Int -> Entry (abt '[]) -> IM.IntMap Int
    register acc idx Entry{bindings=vars} =
      L.foldl' (\m v -> IM.insert (fromNat (varID v)) idx m) acc vars

    -- Variable id -> position of the entry that binds it.
    binderIndex :: IM.IntMap Int
    !binderIndex = V.ifoldl' register IM.empty entries

    -- Edges pointing at the entry at position @idx@, one per dependency
    -- that is bound by some entry in the list.
    edgesInto :: Int -> Entry (abt '[]) -> [G.Edge]
    edgesInto idx Entry{varDependencies=deps} =
      [ (src, idx)
      | src <- mapMaybe (`IM.lookup` binderIndex) (varSetKeys deps)
      ]

    edges :: [G.Edge]
    !edges = concat (V.toList (V.imap edgesInto entries))

    graph :: G.Graph
    !graph = G.buildG (0, V.length entries - 1) edges
-- | Emit a single entry binding the given variable to the given expression.
-- The entry's dependencies are the expression's free variables and its type
-- singleton is taken from the variable.
recordEntry
  :: (ABT Term abt)
  => Variable a
  -> abt '[] a
  -> HoistM abt ()
recordEntry v abt = tell (ExpressionSet [entry])
  where
    entry = Entry (freeVars abt) abt (varType v) [v]
-- | Run a hoisting action with an empty variable set as environment and the
-- given initial counter, returning only the action's result. The final
-- counter and the accumulated writer output are discarded.
execHoistM :: Nat -> HoistM abt a -> a
execHoistM counter act =
  fst (evalRWS (runHoistM act) emptyVarSet counter)
-- | An entry is top-level when its dependency set is empty.
toplevelEntry
  :: Entry abt
  -> Bool
toplevelEntry (Entry deps _ _ _) = sizeVarSet deps == 0
-- | Run an action and return the entries it emitted alongside its result,
-- replacing the action's writer output with 'mempty' so the captured
-- entries are not passed on to the enclosing computation.
captureEntries
  :: (ABT Term abt)
  => HoistM abt a
  -> HoistM abt (a, ExpressionSet abt)
captureEntries act = censor (const mempty) (listen act)
-- | Entry point of the pass. Traverses the term with @hoist'@, captures the
-- entries it produced and hands both to 'introduceToplevel' with an empty
-- set of available variables. The counter starts at @nextFreeOrBind@ of the
-- input term.
--
-- NOTE(review): 'execHoistM' discards the writer output, so any entry that
-- 'introduceToplevel' re-emits as left over is dropped here; presumably this
-- cannot happen for closed input terms, but that has not been confirmed.
hoist
  :: (ABT Term abt)
  => abt '[] a
  -> abt '[] a
hoist abt = execHoistM (nextFreeOrBind abt) $ do
  captured <- captureEntries (hoist' abt)
  uncurry (introduceToplevel emptyVarSet) captured
-- | Split an expression set by a predicate: the first component holds the
-- entries satisfying it, the second the remaining ones. Relative order is
-- preserved in both halves.
partitionEntrySet
  :: (Entry (abt '[]) -> Bool)
  -> ExpressionSet abt
  -> (ExpressionSet abt, ExpressionSet abt)
partitionEntrySet p (ExpressionSet xs) =
  let (accepted, rejected) = L.partition p xs
  in  (ExpressionSet accepted, ExpressionSet rejected)
-- | Wrap a term with the bindings from an entry set, treating entries
-- without dependencies specially.
--
-- The dependency-free ("top-level") entries are separated out. The remaining
-- entries are introduced with 'introduceBindings', seeded with the variables
-- bound by the top-level entries followed by the variables in @avail@.
-- Finally the top-level entries themselves are wrapped around the result
-- with 'wrapExpr'.
introduceToplevel
  :: (ABT Term abt)
  => LiveSet
  -> abt '[] a
  -> ExpressionSet abt
  -> HoistM abt (abt '[] a)
introduceToplevel avail abt entries =
  introduceBindings seeds abt rest >>= (`wrapExpr` toplevel)
  where
    (ExpressionSet toplevel, rest) = partitionEntrySet toplevelEntry entries
    seeds = concatMap getBoundVars toplevel ++ fromVarSet avail
-- | Run an action with the given variable added to the reader environment.
bindVar
  :: (ABT Term abt)
  => Variable (a :: Hakaru)
  -> HoistM abt b
  -> HoistM abt b
bindVar v = local (insertVarSet v)
-- | Run an action with the given variable added to the environment
-- ('bindVar') and capture the entries it emits ('captureEntries') instead of
-- letting them reach the enclosing computation.
isolateBinder
  :: (ABT Term abt)
  => Variable (a :: Hakaru)
  -> HoistM abt b
  -> HoistM abt (b, ExpressionSet abt)
isolateBinder v act = captureEntries (bindVar v act)
-- | Traverse an ABT node, hoisting inside it.
--
-- Binders are peeled off one by one and remembered. A bare variable is
-- returned unchanged. When the syntax node underneath is reached:
--
-- * with no binders peeled, the node is passed straight to 'hoistTerm';
-- * otherwise 'hoistTerm' runs with the peeled variables added to the
--   environment, its entries are captured, and 'introduceBindings' wraps
--   the result with the entries reachable from those variables.
--
-- The peeled binders are re-applied around the result.
hoist'
  :: forall abt xs a . (ABT Term abt)
  => abt xs a
  -> HoistM abt (abt xs a)
hoist' = go [] . viewABT
  where
    -- Add every variable of the list to a live set.
    extend :: [HakaruVar] -> LiveSet -> LiveSet
    extend vars live =
      L.foldl' (\acc (SomeVariable v) -> insertVarSet v acc) live vars

    -- The first argument accumulates the binders seen so far, innermost
    -- first.
    go :: forall ys b
       .  [HakaruVar]
       -> View (Term abt) ys b
       -> HoistM abt (abt ys b)
    go _     (Var v)    = return (var v)
    go bound (Bind v b) = bind v <$> go (SomeVariable v : bound) b
    go []    (Syn s)    = hoistTerm s
    go bound (Syn s)    = do
      (term, entries) <- captureEntries (local (extend bound) (hoistTerm s))
      introduceBindings bound term entries
-- | The variables bound by an entry, with their type index hidden.
getBoundVars :: Entry x -> [HakaruVar]
getBoundVars (Entry _ _ _ vars) = map SomeVariable vars
-- | Wrap a term in let-bindings for a list of entries. The entries are
-- folded from the right, so the first entry of the list ends up as the
-- outermost binding.
--
-- For an entry with several bound variables, the first variable is bound to
-- the entry's expression and each remaining variable is bound to that first
-- variable. An entry without bound variables gets a variable from
-- @varForExpr@.
wrapExpr
  :: forall abt b . (ABT Term abt)
  => abt '[] b
  -> [Entry (abt '[])]
  -> HoistM abt (abt '[] b)
wrapExpr body entries = F.foldrM wrapOne body entries
  where
    -- Bind @v@ to @rhs@ around @rest@; when @rest@ is exactly the variable
    -- @v@, return @rhs@ itself instead of building a let.
    letOrInline :: abt '[] a -> Variable a -> abt '[] b -> abt '[] b
    letOrInline rhs v rest
      | Var v' <- viewABT rest
      , Just Refl <- varEq v v' = rhs
      | otherwise               = syn (Let_ :$ rhs :* bind v rest :* End)

    wrapOne :: Entry (abt '[]) -> abt '[] b -> HoistM abt (abt '[] b)
    wrapOne (Entry _ e _ vars) acc =
      case vars of
        []       -> (\tmp -> letOrInline e tmp acc) <$> varForExpr e
        (v : vs) ->
          -- The extra variables alias the first one rather than
          -- duplicating the expression.
          return (letOrInline e v (foldr (letOrInline (var v)) acc vs))
-- | Wrap a term with every entry that becomes expressible once the given
-- variables are in scope.
--
-- Starting from @newVars@, entries whose dependency set contains one of the
-- variables are selected; the variables bound by the selected entries are
-- then queued as well, so selection is transitive. The selected entries are
-- topologically sorted and wrapped around the body with 'wrapExpr'. Entries
-- that were never selected are re-emitted to the writer.
introduceBindings
  :: forall (a :: Hakaru) abt
  .  (ABT Term abt)
  => [HakaruVar]
  -> abt '[] a
  -> ExpressionSet abt
  -> HoistM abt (abt '[] a)
introduceBindings newVars body (ExpressionSet entries) =
  tell (ExpressionSet leftOver)
    >> wrapExpr body (topSortEntries resultEntries)
  where
    resultEntries, leftOver :: [Entry (abt '[])]
    (resultEntries, leftOver) = collect entries newVars

    -- Does the entry depend on the given variable?
    dependsOn
      :: forall (b :: Hakaru)
      .  Variable b
      -> Entry (abt '[])
      -> Bool
    dependsOn v (Entry deps _ _ _) = memberVarSet v deps

    -- Work through the queue of variables, moving entries that depend on
    -- the current variable from the pending list to the selected list.
    collect
      :: [Entry (abt '[])]
      -> [HakaruVar]
      -> ([Entry (abt '[])], [Entry (abt '[])])
    collect pending []                       = ([], pending)
    collect pending (SomeVariable v : queue) =
      let (hits, misses)    = L.partition (dependsOn v) pending
          (more, unclaimed) =
            collect misses (queue ++ concatMap getBoundVars hits)
      in  (hits ++ more, unclaimed)
-- | Hoist inside a single syntax node.
--
-- * @Let_@: the right-hand side is processed and recorded as an entry for
--   the let-bound variable; the result is the processed body alone, run
--   with that variable added to the environment.
-- * @Lam_@: the body is processed with the parameter in the environment and
--   its entries are captured; 'introduceToplevel' then wraps the body, with
--   the current environment plus the parameter as available variables,
--   before the lambda is rebuilt.
-- * Anything else: the children are processed; if the rebuilt term
--   satisfies @isValue@ it is returned as is, otherwise it is recorded as
--   an entry under a new variable and that variable is returned instead.
hoistTerm
  :: forall (a :: Hakaru) (abt :: [Hakaru] -> Hakaru -> *)
  .  (ABT Term abt)
  => Term abt a
  -> HoistM abt (abt '[] a)
hoistTerm (Let_ :$ rhs :* body :* End) =
  caseBind body $ \v body' ->
    hoist' rhs >>= recordEntry v >> bindVar v (hoist' body')
hoistTerm (Lam_ :$ body :* End) =
  caseBind body $ \v body' -> do
    available        <- asks (insertVarSet v)
    (inner, entries) <- isolateBinder v (hoist' body')
    lamBody          <- introduceToplevel available inner entries
    return (syn (Lam_ :$ bind v lamBody :* End))
hoistTerm term = traverse21 hoist' term >>= abstract . syn
  where
    -- Replace a non-value term by a new variable bound to it.
    abstract :: abt '[] a -> HoistM abt (abt '[] a)
    abstract result
      | isValue result = return result
      | otherwise      = do
          fresh <- varForExpr result
          recordEntry fresh result
          return (var fresh)