module Language.HERMIT.Primitive.Unfold
( externals
, cleanupUnfoldR
, rememberR
, showStashT
, unfoldR
, unfoldPredR
, unfoldNameR
, unfoldAnyR
, unfoldSaturatedR
, unfoldStashR
, specializeR
) where
import GhcPlugins hiding (empty)
import Control.Applicative
import Control.Arrow
import Control.Monad
import qualified Data.Map as Map
import qualified Language.Haskell.TH as TH
import Language.HERMIT.PrettyPrinter.Common (DocH, PrettyH, TranslateDocH(..))
import Language.HERMIT.Primitive.Common
import Language.HERMIT.Primitive.GHC hiding (externals)
import Language.HERMIT.Primitive.Inline hiding (externals)
import Language.HERMIT.Core
import Language.HERMIT.Kure
import Language.HERMIT.Monad
import Language.HERMIT.External
import Language.HERMIT.GHC
import Prelude hiding (exp)
import qualified Text.PrettyPrint.MarkedHughesPJ as PP
-- | Shell commands exported by this module.
--
-- Three commands share the name "unfold"; they take different argument types
-- (a remembered-definition label, no argument, or a Template Haskell name).
-- NOTE(review): presumably the shell picks among them by argument type.
externals :: [External]
externals =
    [ external "cleanup-unfold" (promoteExprR cleanupUnfoldR :: RewriteH Core)
        [ "Clean up immediately nested fully-applied lambdas, from the bottom up" ] .+ Deep
    , external "remember" rememberR
        [ "Remember the current binding, allowing it to be folded/unfolded in the future." ] .+ Context
      -- Annotated like the other entries so the Core-level type is explicit.
    , external "unfold" (promoteExprR . unfoldStashR :: String -> RewriteH Core)
        [ "Unfold a remembered definition." ] .+ Deep .+ Context
    , external "unfold" (promoteExprR unfoldR :: RewriteH Core)
        [ "In application f x y z, unfold f." ] .+ Deep .+ Context
    , external "unfold" (promoteExprR . unfoldNameR :: TH.Name -> RewriteH Core)
        [ "Inline a definition, and apply the arguments; traditional unfold" ] .+ Deep .+ Context
    , external "unfold-saturated" (promoteExprR unfoldSaturatedR :: RewriteH Core)
        [ "Unfold a definition only if the function is fully applied." ] .+ Deep .+ Context
    , external "specialize" (promoteExprR specializeR :: RewriteH Core)
        [ "Specialize an application to its type and coercion arguments." ] .+ Deep .+ Context
    , external "unfold-rule" ((\ nm -> promoteExprR (rule nm >>> cleanupUnfoldR)) :: String -> RewriteH Core)
        [ "Apply a named GHC rule" ] .+ Deep .+ Context
    , external "show-remembered" (TranslateDocH showStashT :: TranslateDocH Core)
        [ "Display all remembered definitions." ]
    ]
-- | Tidy up an application of a lambda after its head has been inlined.
--
-- The current expression is split into a head and arguments with 'callT'.
-- If 'callT' fails, the whole expression is treated as a head with no
-- arguments. The head's leading binders are collected with 'collectBinders'.
--
-- * Binders that have a matching argument are let-bound to that argument,
--   and 'letSubstR' is applied once per such let.
-- * Binders left over (fewer arguments than binders) stay as lambdas.
-- * Arguments left over (more arguments than binders) are re-applied to the
--   reduced result.
cleanupUnfoldR :: RewriteH CoreExpr
cleanupUnfoldR = do
    (fun, actuals) <- callT <+ arr (\ e -> (e, []))
    let (formals, inner) = collectBinders fun
        nActuals         = length actuals
        nFormals         = length formals
        ordering         = compare nActuals nFormals
        -- Binders without a matching argument remain as lambdas.
        lambdaBody       = case ordering of
                             LT -> mkCoreLams (drop nActuals formals) inner
                             _  -> inner
        -- zipWith stops at the shorter list, so only matched binders are bound.
        letBinds         = zipWith NonRec formals actuals
        nested           = mkCoreLets letBinds lambdaBody
        -- One letSubstR per introduced let (assumed to eliminate the
        -- outermost let each time -- NOTE(review): confirm letSubstR's contract).
        substituteAll    = andR (map (const letSubstR) letBinds)
    reduced <- contextonlyT (\ c -> apply substituteAll c nested)
    return $ case ordering of
               GT -> mkCoreApps reduced (drop nFormals actuals)
               _  -> reduced
-- | Unfold the function at the head of an application, then tidy the result
-- with 'cleanupUnfoldR'.
--
-- The head is found by trying 'inline' on the current expression. If that
-- fails, the search descends into the function side of an application,
-- leaving the argument untouched.
unfoldR :: RewriteH CoreExpr
unfoldR = inlineHead >>> cleanupUnfoldR
  where
    inlineHead :: RewriteH CoreExpr
    inlineHead = inline <+ appAllR inlineHead idR
-- | 'unfoldR', guarded by @callPredT p@.
--
-- The guard runs first on the same expression. Its result is ignored; only
-- its success or failure matters.
unfoldPredR :: (Id -> [CoreExpr] -> Bool) -> RewriteH CoreExpr
unfoldPredR p = callPredT p >> unfoldR
-- | 'unfoldR', guarded by @callNameT nm@.
--
-- The guard's result is discarded; it only decides whether the unfold is
-- attempted.
unfoldNameR :: TH.Name -> RewriteH CoreExpr
unfoldNameR nm = callNameT nm >> unfoldR
-- | Build one 'unfoldNameR' rewrite per given name and combine them with
-- 'orR'.
unfoldAnyR :: [TH.Name] -> RewriteH CoreExpr
unfoldAnyR names = orR [ unfoldNameR name | name <- names ]
-- | 'unfoldR', attempted only when 'callSaturatedT' succeeds on the current
-- expression. The guard's result is discarded.
unfoldSaturatedR :: RewriteH CoreExpr
unfoldSaturatedR = callSaturatedT >> unfoldR
-- | Unfold a call only when every argument satisfies 'isTyCoArg'. The called
-- function itself is not inspected.
specializeR :: RewriteH CoreExpr
specializeR = unfoldPredR (\ _ args -> all isTyCoArg args)
-- | Store the current definition in the stash under @label@, via 'saveDef'.
--
-- Accepts a 'DefCore' or a non-recursive 'BindCore'. Any other node fails
-- with "remember: not a binding". The rewrite is built with 'sideEffectR',
-- so the node itself is passed through unchanged.
rememberR :: Label -> RewriteH Core
rememberR label = sideEffectR (\ _ -> stash)
  where
    stash (DefCore d)                  = saveDef label d
    stash (BindCore (NonRec var rhs))  = saveDef label (Def var rhs)
    stash _                            = fail "remember: not a binding"
-- | Replace a variable occurrence with the right-hand side of the definition
-- stored under @label@ (see 'rememberR').
--
-- Fails when:
--
-- * the current expression is not a 'Var';
-- * 'lookupDef' fails (presumably when nothing is stored under @label@);
-- * the stored definition is for a different identifier;
-- * some free variable of the stored right-hand side is not in scope here.
--
-- The specific reason is preserved and prefixed with a common header.
-- 'setFailMsg' would discard it and leave only the header.
unfoldStashR :: String -> RewriteH CoreExpr
unfoldStashR label =
    prefixFailMsg "Inlining stashed definition failed: " $
    withPatFailMsg (wrongExprForm "Var v") $
    do (c, Var v) <- exposeT
       constT $ do
           Def i rhs <- lookupDef label
           if idName i == idName v
               then ifM (all (inScope c) <$> apply freeVarsT c rhs)
                        (return rhs)
                        (fail "some free variables in stashed definition are no longer in scope.")
               else fail $ "stashed definition applies to " ++ var2String i ++ " not " ++ var2String v
-- | Pretty-print every stored definition using the supplied printer.
--
-- Entries appear in ascending label order (the order of 'Map.toList'). Each
-- entry is laid out as a "[ label ]" header line, the printed definition, and
-- a trailing line containing a single space.
showStashT :: Injection CoreDef a => PrettyH a -> TranslateH a DocH
showStashT pp = do
    stash <- constT getStash
    let render c (lbl, def) = do
            body <- apply (extractT pp) c def
            return (PP.text ("[ " ++ lbl ++ " ]") PP.$+$ body PP.$+$ PP.space)
    PP.vcat <$> contextonlyT (\ c -> forM (Map.toList stash) (render c))