module Language.HERMIT.Primitive.New where
import GhcPlugins as GHC hiding (varName)
import Control.Applicative
import Control.Arrow
import Data.List(intersect,transpose)
import Language.HERMIT.Context
import Language.HERMIT.Core
import Language.HERMIT.Monad
import Language.HERMIT.Kure
import Language.HERMIT.External
import Language.HERMIT.GHC
import Language.HERMIT.ParserCore
import Language.HERMIT.Primitive.Common
import Language.HERMIT.Primitive.GHC
import Language.HERMIT.Primitive.Local
import Language.HERMIT.Primitive.Inline
import Language.HERMIT.Primitive.Unfold
import qualified Language.Haskell.TH as TH
-- | Shell commands exported by this module.
--
-- The outer 'map' tags every entry with both 'Experiment' and 'TODO';
-- some entries carry an additional tag applied inline with '.+'.
externals :: [External]
externals = map ((.+ Experiment) . (.+ TODO))
    [ external "test" (testQuery :: RewriteH Core -> TranslateH Core String)
        [ "determines if a rewrite could be successfully applied" ]
    , external "push" (promoteExprR . push :: TH.Name -> RewriteH Core)
        [ "push a function <f> into argument."
        , "Unsafe if f is not strict." ] .+ PreCondition
    , external "var" (promoteExprT . isVar :: TH.Name -> TranslateH Core ())
        [ "var '<v> returns successfully for variable v, and fails otherwise.",
          "Useful in combination with \"when\", as in: when (var v) r" ] .+ Predicate
      -- NOTE(review): this help lists the unfolds as 'id, '$, '. but
      -- simplifyR tries '$, '., 'id in that order; harmless since the names
      -- are distinct, but the text does not mirror the code.
    , external "simplify" (simplifyR :: RewriteH Core)
        [ "innermost (unfold 'id <+ unfold '$ <+ unfold '. <+ beta-reduce-plus <+ safe-let-subst <+ case-reduce <+ dead-let-elimination)" ] .+ Bash
      -- NOTE(review): letTupleR actually builds a case on a tuple
      -- (mkSmallTupleCase), not the fst/snd lets described in this help.
    , external "let-tuple" (promoteExprR . letTupleR :: TH.Name -> RewriteH Core)
        [ "let x = e1 in (let y = e2 in e) ==> let t = (e1,e2) in (let x = fst t in (let y = snd t in e))" ]
    , external "any-call" (anyCallR :: RewriteH Core -> RewriteH Core)
        [ "any-call (.. unfold command ..) applies an unfold commands to all applications"
        , "preference is given to applications with more arguments"
        ] .+ Deep
    , external "static-arg" (promoteDefR staticArg :: RewriteH Core)
        [ "perform the static argument transformation on a recursive function" ]
      -- "unsafe-replace" is registered twice with different argument types
      -- (CoreString vs String); presumably the shell picks the entry by the
      -- argument's type -- confirm against the External dispatcher.
    , external "unsafe-replace" (promoteExprR . unsafeReplace :: CoreString -> RewriteH Core)
        [ "replace the currently focused expression with a new expression" ] .+ Unsafe
      -- NOTE(review): unsafeReplaceStash does compare types with eqType, so
      -- the "DOES NOT ensure ... same type" wording below looks stale.
    , external "unsafe-replace" (promoteExprR . unsafeReplaceStash :: String -> RewriteH Core)
        [ "replace the currently focused expression with an expression from the stash"
        , "DOES NOT ensure expressions have the same type, or that free variables in the replacement expression are in scope" ] .+ Unsafe
    , external "inline-all" (inlineAll :: [TH.Name] -> RewriteH Core)
        [ "inline all named functions in a bottom-up manner" ]
    ]
-- | Succeed with @()@ when the focused expression is a variable matching the
-- Template Haskell name @nm@ (compared with 'cmpTHName2Var'); fail otherwise.
isVar :: TH.Name -> TranslateH CoreExpr ()
isVar nm = do
    matches <- varT (cmpTHName2Var nm)
    guardM matches
-- | Innermost-first cleanup: unfold @$@, @.@ and @id@, then try beta
-- reduction, safe let substitution, case reduction and dead-let elimination.
-- The alternatives are tried left to right at each node; the whole rewrite
-- fails with a single summary message if nothing applies anywhere.
simplifyR :: RewriteH Core
simplifyR = setFailMsg "Simplify failed: nothing to simplify." $
    innermostR (promoteExprR simplifyStep)
  where
    -- foldl1 reproduces the left-nested  a <+ b <+ c <+ ...  chain.
    simplifyStep :: RewriteH CoreExpr
    simplifyStep = foldl1 (<+)
        (  map (unfoldNameR . TH.mkName) ["$", ".", "id"]
        ++ [betaReducePlus, safeLetSubstR, caseReduce, letElim]
        )
-- | Peel a chain of directly nested non-recursive @let@s off an expression.
-- Returns the bindings outermost-first, paired with the first sub-expression
-- that is not a non-recursive @let@.
collectLets :: CoreExpr -> ([(Var, CoreExpr)],CoreExpr)
collectLets expr = case expr of
    Let (NonRec x e1) e2 ->
        let (rest, inner) = collectLets e2
        in ((x, e1) : rest, inner)
    _ -> ([], expr)
-- | Merge a chain of two or more nested non-recursive @let@s into a single
-- @case@ on a tuple of their right-hand sides (built with 'mkSmallTupleCase').
-- The scrutinee binder is a fresh id named after @nm@.
--
-- Fails if there are fewer than two bindings, if any binder is not an 'Id',
-- or if a right-hand side mentions a variable bound by an earlier @let@ in
-- the chain (tupling would put that variable out of scope).
letTupleR :: TH.Name -> RewriteH CoreExpr
letTupleR nm = prefixFailMsg "Let-tuple failed: " $ do
    (bnds, body) <- arr collectLets
    guardMsg (length bnds > 1) "at least two non-recursive let bindings required."
    let (vs, rhss) = unzip bnds
    guardMsg (all isId vs) "cannot tuple type variables."
    -- For the i-th later right-hand side (1-based), the binders in scope
    -- are the first i variables of the chain.
    let used = [ v | (i, fvs) <- zip [1 ..] (map coreExprFreeVars (drop 1 rhss))
                   , v <- take i vs
                   , v `elem` fvs
                   ]
    guardMsg (null used) $ "the following bound variables are used in subsequent bindings: " ++ showVars used
    let tupled = mkCoreTup rhss
    constT $ do
        scrutBinder <- newIdH (show nm) (exprType tupled)
        return (mkSmallTupleCase vs body scrutBinder tupled)
-- | Static argument transformation on a definition @f = \\ x1 .. xn -> body@.
--
-- A lambda binder is /static/ when every call of @f@ inside @body@ passes
-- exactly that binder (or, for a type binder, that type variable) in the same
-- position. Every other binder is /dynamic/. The definition becomes
--
-- > f = \ x1 .. xn -> letrec f' = \ <dynamic binders> -> body'
-- >                   in f' <dynamic binders>
--
-- where @body'@ replaces each call of @f@ with a call of the worker @f'@.
-- The worker call drops the static arguments and keeps the dynamic ones.
-- It also keeps any arguments beyond @f@'s @n@ lambda binders, which are
-- applied to the worker's result.
--
-- Fails if the right-hand side has no lambda binders.
staticArg :: RewriteH CoreDef
staticArg = prefixFailMsg "static-arg failed: " $ do
    Def f rhs <- idR
    let (bnds, body) = collectBinders rhs
    guardMsg (notNull bnds) "rhs is not a function"
    contextonlyT $ \ c -> do
        let bodyContext = foldl (flip addLambdaBinding) c bnds
            arity       = length bnds
        -- Argument lists of every call of f occurring in the body.
        callPats <- apply (callsT (var2THName f) (callT >>> arr snd)) bodyContext (ExprCore body)
        let argExprs = transpose callPats
            numCalls = length callPats
            -- A binder is dynamic if some call does not supply that position,
            -- or supplies something other than the binder itself.
            (ps, dbnds) = unzip [ (i, b) | (i, b, exprs) <- zip3 [0 ..] bnds (argExprs ++ repeat [])
                                         , length exprs /= numCalls || isDynamic b exprs
                                         ]
            isDynamic _ [] = False
            isDynamic b (Var b' : es) | b == b' = isDynamic b es
            isDynamic b (Type (TyVarTy v) : es) | b == v = isDynamic b es
            isDynamic _ _ = True
            -- Keep dynamic arguments and any arguments past f's lambda
            -- binders (over-saturated calls); drop the static ones. The
            -- original code dropped the extra arguments too, which lost them.
            keepArg :: Int -> Bool
            keepArg p = p `elem` ps || p >= arity
        wkr <- newIdH (var2String f ++ "'") (exprType (mkCoreLams dbnds body))
        let replaceCall :: RewriteH CoreExpr
            replaceCall = do
                (_, exprs) <- callT
                return $ mkApps (Var wkr) [ e | (p, e) <- zip [0 ..] exprs, keepArg p ]
        ExprCore body' <- apply (callsR (var2THName f) replaceCall) bodyContext (ExprCore body)
        return $ Def f $ mkCoreLams bnds $ Let (Rec [(wkr, mkCoreLams dbnds body')])
                                         $ mkApps (Var wkr) (varsToCoreExprs dbnds)
-- | Report, as a human-readable message, whether rewrite @r@ would succeed on
-- the focused node. Only the success or failure of @r@ is reported.
testQuery :: RewriteH Core -> TranslateH Core String
testQuery r = fmap describe (testM r)
  where
    describe succeeded
      | succeeded = "Rewrite would succeed."
      | otherwise = "Rewrite would fail."
-- | Apply @rr@ to calls (applications and bare variables) throughout the
-- focused node. At an application the outermost call is tried before its
-- function part, so calls with more arguments are preferred.
--
-- Failures are prefixed with "any-call failed: " exactly once.
anyCallR :: RewriteH Core -> RewriteH Core
anyCallR rr = prefixFailMsg "any-call failed: " go
  where
    -- Recurse on 'go' rather than on 'anyCallR'. Recursing on 'anyCallR'
    -- re-applied the failure prefix at every nesting level.
    go :: RewriteH Core
    go = readerT $ \ e -> case e of
        -- Presumably child 1 is the argument and child 0 the function:
        -- rewrite inside the argument, then try rr on the whole application
        -- and fall back to recursing into the function part.
        ExprCore (App {}) -> childR 1 go >+> (rr <+ childR 0 go)
        ExprCore (Var {}) -> rr
        _                 -> anyR go
-- | Push the function named @nm@ into its last argument when that argument is
-- a @case@ or @let@, by delegating to 'caseFloatArg' or 'letFloatArg'.
-- Every argument before the last must be a type argument. The command's
-- help text notes this is unsafe if the function is not strict.
push :: TH.Name -> RewriteH CoreExpr
push nm = prefixFailMsg "push failed: " $ do
    e <- idR
    case collectArgs e of
      (Var v, args) -> do
          guardMsg (nm `cmpTHName2Var` v) $ "cannot find name " ++ show nm
          -- Reversing exposes the last argument without partial init/last.
          case reverse args of
            [] -> fail $ "no argument for " ++ show nm
            lastArg : earlierArgs -> do
                guardMsg (all isTypeArg earlierArgs) $ "initial arguments are not type arguments for " ++ show nm
                case lastArg of
                  Case {} -> caseFloatArg
                  Let {}  -> letFloatArg
                  _       -> fail "argument is not a Case or Let."
      _ -> fail "no function to match."
-- | Parse a 'CoreString' into a Core expression using only the current
-- context. The focused node is ignored, since the input type @a@ is
-- unconstrained.
parseCoreExprT :: CoreString -> TranslateH a CoreExpr
parseCoreExprT source = contextonlyT (parseCore source)
-- | Replace the focused expression with one parsed from @core@ in the current
-- context.
--
-- Fails if parsing fails, or if the parsed expression's type is not 'eqType'
-- to the focused expression's type. No scope check is made here on the
-- replacement's free variables. Failures carry the same
-- "unsafe-replace failed: " prefix as the stash-based variant of this command.
unsafeReplace :: CoreString -> RewriteH CoreExpr
unsafeReplace core = prefixFailMsg "unsafe-replace failed: " $
    translate $ \ c e -> do
        e' <- parseCore core c
        guardMsg (eqType (exprType e) (exprType e')) "expression types differ."
        return e'
-- | Replace the focused expression with the right-hand side of the definition
-- stashed under @label@ (found with 'lookupDef').
-- Fails if the lookup fails or if the two expressions' types are not 'eqType'.
unsafeReplaceStash :: String -> RewriteH CoreExpr
unsafeReplaceStash label =
    prefixFailMsg "unsafe-replace failed: " (contextfreeT replaceWithStashed)
  where
    replaceWithStashed current = do
        Def _ stashed <- lookupDef label
        guardMsg (eqType (exprType current) (exprType stashed)) "expression types differ."
        return stashed
-- | Inline any of the named functions, driven by 'innermostR'.
-- At each node the names are tried in list order. A node where none of them
-- can be inlined fails with "inline-all: nothing to do".
inlineAll :: [TH.Name] -> RewriteH Core
inlineAll names = innermostR inlineAny
  where
    inlineAny :: RewriteH Core
    inlineAny = foldr orInline (fail "inline-all: nothing to do") names
    orInline nm rest = promoteExprR (inlineName nm) <+ rest