module Language.HERMIT.Primitive.New where
import GhcPlugins as GHC hiding (varName)
import Control.Applicative
import Control.Arrow
import Control.Monad
import Data.List(intercalate,intersect)
import Language.HERMIT.Context
import Language.HERMIT.Monad
import Language.HERMIT.Kure
import Language.HERMIT.External
import Language.HERMIT.GHC
import Language.HERMIT.Primitive.GHC
import Language.HERMIT.Primitive.Utils
import Language.HERMIT.Primitive.Local
import Language.HERMIT.Primitive.Local.Case
import Language.HERMIT.Primitive.Local.Let
import Language.HERMIT.Primitive.Inline
import qualified Language.Haskell.TH as TH
import MonadUtils (MonadIO)
-- | The shell commands defined in this module.
--
-- Every entry is additionally tagged 'Experiment' and 'TODO' by the outer
-- 'map'. Individual entries add further tags ('Unimplemented', 'Shallow',
-- 'Deep', 'Predicate', ...) where relevant.
externals :: [External]
externals = map ((.+ Experiment) . (.+ TODO))
    [ external "info" (info :: TranslateH Core String)
        [ "tell me what you know about this expression or binding" ] .+ Unimplemented
    , external "expr-type" (promoteExprT exprTypeT :: TranslateH Core String)
        [ "display the type of this expression"]
    , external "test" (testQuery :: RewriteH Core -> TranslateH Core String)
        [ "determines if a rewrite could be successfully applied" ]
    , external "fix-intro" (promoteDefR fixIntro :: RewriteH Core)
        [ "rewrite a recursive binding into a non-recursive binding using fix" ]
    , external "fix-spec" (promoteExprR fixSpecialization :: RewriteH Core)
        [ "specialize a fix with a given argument"] .+ Shallow
    , external "cleanup-unfold" (promoteExprR cleanupUnfold :: RewriteH Core)
        [ "clean up immediate nested fully-applied lambdas, from the bottom up"]
    , external "unfold" (promoteExprR . unfold :: TH.Name -> RewriteH Core)
        [ "inline a definition, and apply the arguments; traditional unfold"]
    , external "push" (promoteExprR . push :: TH.Name -> RewriteH Core)
        [ "push a function <f> into argument."
        , "Unsafe if f is not strict." ] .+ PreCondition
    , external "unfold-rule" ((\ nm -> promoteExprR (rules nm >>> cleanupUnfold)) :: String -> RewriteH Core)
        [ "apply a named GHC rule" ]
    , external "var" (promoteExprT . isVar :: TH.Name -> TranslateH Core ())
        [ "var '<v> returns successfully for variable v, and fails otherwise."
        , "Useful in combination with \"when\", as in: when (var v) r" ] .+ Predicate
    , external "simplify" (simplifyR :: RewriteH Core)
        [ "innermost (unfold '. <+ beta-reduce-plus <+ safe-let-subst <+ case-reduce <+ dead-code-elimination)" ]
    , external "let-tuple" (promoteExprR . letTupleR :: TH.Name -> RewriteH Core)
        [ "let x = e1 in (let y = e2 in e) ==> let t = (e1,e2) in (let x = fst t in (let y = snd t in e))" ]
    , external "any-call" (withUnfold :: RewriteH Core -> RewriteH Core)
        [ "any-call (.. unfold command ..) applies an unfold command to all applications"
        , "preference is given to applications with more arguments"
        ] .+ Deep
    , external "abstract" (promoteExprR . abstract :: TH.Name -> RewriteH Core)
        [ "Abstract over a variable using a lambda."
        , "e ==> (\\ x -> e) x"
        ] .+ Shallow .+ Introduce .+ Context
    ]
-- | Succeeds (returning @()@) exactly when the focused expression is a
-- variable that 'cmpTHName2Id' considers equal to the given Template
-- Haskell name; fails otherwise.
isVar :: TH.Name -> TranslateH CoreExpr ()
isVar nm = do
    matches <- varT (cmpTHName2Id nm)
    guardM matches
-- | Repeatedly apply a set of local simplifications innermost-first until
-- none applies.
--
-- At each expression the rewrites are tried left to right via '<+':
-- unfold of the name @.@, 'betaReducePlus', 'safeLetSubstR', 'caseReduce',
-- and 'dce'.
simplifyR :: RewriteH Core
simplifyR = innermostR (promoteExprR oneStep)
  where
    oneStep :: RewriteH CoreExpr
    oneStep = unfold (TH.mkName ".") <+ betaReducePlus <+ safeLetSubstR <+ caseReduce <+ dce
-- | Tuple two directly nested non-recursive lets:
--
-- > let x = e1 in let y = e2 in e
-- >   ==> let t = (e1,e2) in let x = fst t in let y = snd t in e
--
-- The 'TH.Name' argument is shown and used as the name of the new tuple
-- variable @t@. Fails if the expression does not have this shape.
-- Also fails, with "'x' is used in 'e2'", if @x@ occurs among the free
-- variables of @e2@.
letPairR :: TH.Name -> RewriteH CoreExpr
letPairR nm = do
    Let (NonRec x e1) (Let (NonRec y e2) e) <- idR
    let -- Pair the outer binder with the free variables of the inner rhs,
        -- then test membership.
        outerUsedInInnerRhs =
            letT (nonRecT (pure ()) const)
                 (letT (nonRecT freeVarsT (flip const)) (pure ()) const)
                 elem

        tupleUp = translate $ \ c _ -> do
            pairCon <- findId c "(,)"
            fstVar  <- findId c "Data.Tuple.fst"
            sndVar  <- findId c "Data.Tuple.snd"
            let tyArgs  = [Type (exprType e1), Type (exprType e2)]
                pairE   = mkCoreApps (Var pairCon) (tyArgs ++ [e1, e2])
            t <- newVarH (show nm) (exprType pairE)
            -- Apply a projection (fst/snd) at the pair's types to the tuple var.
            let project sel = mkCoreApps (Var sel) (tyArgs ++ [Var t])
            return $ Let (NonRec t pairE)
                   $ Let (NonRec x (project fstVar))
                   $ Let (NonRec y (project sndVar)) e

    ifM outerUsedInInnerRhs (fail "'x' is used in 'e2'") tupleUp
-- | Tuple a chain of directly nested non-recursive lets.
--
-- > let x1 = e1 in ... let xn = en in body
-- >   ==> let t = (e1,...,en)
-- >       in let x1 = case t of (a,...) -> a in ... body
--
-- The 'TH.Name' argument is shown and used as the name of the tuple
-- variable @t@.
--
-- * Exactly two bindings are delegated to 'letPairR', which uses
--   @fst@/@snd@.
-- * Fails if the chain has fewer than two bindings.
-- * Fails if any right-hand side mentions a variable bound earlier in the
--   chain, because tupling would move that use out of the binder's scope.
letTupleR :: TH.Name -> RewriteH CoreExpr
letTupleR nm = translate $ \ c e -> do
    let -- Peel off the maximal prefix of nested non-recursive lets.
        collectLets :: CoreExpr -> ([(Id, CoreExpr)], CoreExpr)
        collectLets (Let (NonRec x e1) e2) = let (bs, expr) = collectLets e2 in ((x, e1) : bs, expr)
        collectLets expr = ([], expr)

        (bnds, body) = collectLets e

    guardMsg (length bnds > 1) "cannot tuple: need at least two nonrec lets"

    if length bnds == 2
      then apply (letPairR nm) c e
      else do
        let (ids, rhss) = unzip bnds
        -- Free variables of the 2nd..nth right-hand sides.
        frees <- mapM (apply freeVarsT c) (drop 1 rhss)
        -- The k-th rhs (k >= 2) must not mention any of x1..x(k-1).
        let used = concat $ zipWith intersect (map (flip take ids) [1..]) frees
        if null used
          then do
            -- An n-tuple constructor is written with n-1 commas, e.g. "(,,)"
            -- for n = 3. The original passed two arguments to 'length' here,
            -- which does not type-check.
            tupleConId <- findId c $ "(" ++ replicate (length bnds - 1) ',' ++ ")"
            let rhs     = mkCoreApps (Var tupleConId) $ map (Type . exprType) rhss ++ rhss
                -- Field-variable names: "a".."z", then "a0".."z0", "a00".., ...
                varList = concat $ iterate (zipWith (flip (++)) $ repeat "0") $ map (:[]) ['a'..'z']
            dc <- maybe (fail "cannot find tuple datacon") return $ isDataConId_maybe tupleConId
            -- One fresh variable per tuple field, at the instantiated field types.
            vs <- zipWithM newVarH varList $ dataConInstOrigArgTys dc $ map exprType rhss
            letId <- newVarH (show nm) (exprType rhs)
            -- Rebind an original variable v to field w of the tuple. oe is v's
            -- original rhs, whose type is the type of the case expression.
            -- NOTE(review): the case binder reuses letId, as the original did.
            let project v oe w = Let (NonRec v (Case (Var letId) letId (exprType oe) [(DataAlt dc, vs, Var w)]))
            return $ Let (NonRec letId rhs)
                   $ foldr (\ (v, oe, w) -> project v oe w) body (zip3 ids rhss vs)
          else fail "cannot tuple: some bindings are used in the rhs of others"
-- | Describe the focused node as newline-separated lines.
--
-- Always reported: its path, node kind ('coreNode'), constructor
-- ('coreConstructor'), and the unqualified names of the bindings in scope.
--
-- For expressions, it also reports the type and the free variables. If the
-- expression is a variable, the variable's 'IdInfo' is reported as well.
info :: TranslateH Core String
info = translate $ \ ctx core -> do
    dflags <- getDynFlags
    let header =
          [ "Path: " ++ show (contextPath ctx)
          , "Node: " ++ coreNode core
          , "Constructor: " ++ coreConstructor core
          , "Bindings in Scope: " ++ show (map unqualifiedIdName (listBindings ctx))
          ]
        exprLines = case core of
          ExprCore e -> ("Type: " ++ showExprType dflags e)
                          : ("Free Variables: " ++ showVars dflags (coreExprFreeVars e))
                          : idInfoLine e
          _          -> []
        idInfoLine (Var v) = ["Identifier Info: " ++ showIdInfo dflags v]
        idInfoLine _       = []
    return (intercalate "\n" (header ++ exprLines))
-- | Pretty-print the type of the focused expression, using the current
-- 'DynFlags'. The context is not consulted.
exprTypeT :: TranslateH CoreExpr String
exprTypeT = contextfreeT render
  where
    render expr = do
      dflags <- getDynFlags
      return (showExprType dflags expr)
-- | Render the type of a Core expression with 'showPpr'.
showExprType :: DynFlags -> CoreExpr -> String
showExprType dflags expr = showPpr dflags (exprType expr)
-- | Render an identifier's 'IdInfo' (via 'ppIdInfo') as a string.
showIdInfo :: DynFlags -> Id -> String
showIdInfo dflags ident = showSDoc dflags (ppIdInfo ident (idInfo ident))
-- | A human-readable description of which kind of 'Core' node this is.
coreNode :: Core -> String
coreNode core = case core of
    ModGutsCore _ -> "Module"
    ProgramCore _ -> "Program"
    BindCore    _ -> "Binding Group"
    DefCore     _ -> "Recursive Definition"
    ExprCore    _ -> "Expression"
    AltCore     _ -> "Case Alternative"
-- | The name of the outermost data constructor of the wrapped value.
--
-- Programs report the list constructor ("[]" or "(:)"). Alternatives report
-- the triple constructor "(,,)".
coreConstructor :: Core -> String
coreConstructor (ModGutsCore _)        = "ModGuts"
coreConstructor (ProgramCore [])       = "[]"
coreConstructor (ProgramCore (_:_))    = "(:)"
coreConstructor (BindCore (Rec {}))    = "Rec"
coreConstructor (BindCore (NonRec {})) = "NonRec"
coreConstructor (DefCore _)            = "Def"
coreConstructor (AltCore _)            = "(,,)"
coreConstructor (ExprCore expr)        = exprCon expr
  where
    exprCon (Var {})      = "Var"
    exprCon (Type {})     = "Type"
    exprCon (Lit {})      = "Lit"
    exprCon (App {})      = "App"
    exprCon (Lam {})      = "Lam"
    exprCon (Let {})      = "Let"
    exprCon (Case {})     = "Case"
    exprCon (Cast {})     = "Cast"
    exprCon (Tick {})     = "Tick"
    exprCon (Coercion {}) = "Coercion"
-- | Report, as a sentence, whether the given rewrite would succeed on the
-- focused node (as determined by 'testM').
testQuery :: RewriteH Core -> TranslateH Core String
testQuery r = verdict `fmap` testM r
  where
    verdict succeeded
      | succeeded = "Rewrite would succeed."
      | otherwise = "Rewrite would fail."
-- | Look up a value identifier by (possibly qualified) name in the module
-- held by the context. See 'findIdMG'.
findId :: (MonadUnique m, MonadIO m, MonadThings m, HasDynFlags m) => Context -> String -> m Id
findId = findIdMG . hermitModGuts
-- | Resolve a name string against the module's reader environment
-- ('mg_rdr_env'), keeping only value names ('isValName'), and look up the
-- resulting 'Id'.
--
-- Fails with "cannot find ..." when nothing matches. Fails with
-- "too many ... found" (listing the candidates) when the name is ambiguous.
findIdMG :: (MonadUnique m, MonadIO m, MonadThings m, HasDynFlags m) => ModGuts -> String -> m Id
findIdMG modguts nm =
    case candidates of
      [n] -> lookupId n
      []  -> fail ("cannot find " ++ nm)
      ns  -> do
        dflags <- getDynFlags
        fail $ "too many " ++ nm ++ " found:\n" ++ intercalate ", " (map (showPpr dflags) ns)
  where
    candidates = filter isValName (findNameFromTH (mg_rdr_env modguts) (TH.mkName nm))
-- | Turn a recursive definition into one that uses @Data.Function.fix@:
--
-- > f = e   ==>   f = fix @(typeof f) (\ f' -> e[f'/f])
--
-- @f'@ is a clone of @f@ ('cloneIdH'). Occurrences of @f@ in @e@ are
-- replaced by @f'@ with a GHC 'Subst' whose in-scope set is built from the
-- free variables of @e@. Failures are prefixed with
-- "Fix introduction failed: ".
fixIntro :: RewriteH CoreDef
fixIntro = prefixFailMsg "Fix introduction failed: " $ do
    (c, Def f e) <- exposeT
    constT $ do
      fixId <- findId c "Data.Function.fix"
      f'    <- cloneIdH id f
      let inScope  = mkInScopeSet (exprFreeVars e)
          renameF  = extendSubst (mkEmptySubst inScope) f (Var f')
          functional = Lam f' (substExpr (text "fixIntro") renameF e)
      return $ Def f (App (App (Var fixId) (Type (idType f))) functional)
-- | Specialise an application @fix \@t g a@ of @Data.Function.fix@.
--
-- It first eta-expands the sub-expression at path [0,1] (presumably the
-- functional @g@, per KURE's App child numbering — confirm) with binder
-- names @f@ and @a@. It then hands the result to 'fixSpecialization''.
-- Fails if the head variable is not @Data.Function.fix@ or the expression
-- is not an application of that shape.
fixSpecialization :: RewriteH CoreExpr
fixSpecialization = do
    fixId <- translate (\ ctx _ -> findId ctx "Data.Function.fix")
    App (App (App (Var fx) (Type _)) _) _ <- idR
    guardMsg (fx == fixId) "fixSpecialization only works on fix"
    let etaFA :: RewriteH CoreExpr
        etaFA = multiEtaExpand [TH.mkName "f", TH.mkName "a"]

        atFunctional :: RewriteH Core
        atFunctional = pathR [0,1] (promoteR etaFA)
    extractR atFunctional >>> fixSpecialization'
-- | Second stage of 'fixSpecialization'.
--
-- Expects @fix \@t (\\ _ v2 -> e _ _) a@, where @a@ is either a type
-- argument or a variable that is a type variable.
--
-- It builds a new fixed point at type @t' = applyTy t a@. The new recursive
-- variable @v3 :: t'@ is re-generalised through
-- @\\ v4 -> v3 |> unsafeCo@, which is then fed to @e@ together with @a@.
--
-- Fails (rather than crashing) if the argument is neither a 'Type' nor a
-- type variable; the original used a non-exhaustive @case@ here.
fixSpecialization' :: RewriteH CoreExpr
fixSpecialization' = do
    App (App (App (Var fx) (Type t))
             (Lam _ (Lam v2 (App (App e _) _a2))))
        a <- idR

    -- Instantiate the fixed point's type at the supplied type argument.
    t' <- case a of
            Type t2           -> return (applyTy t t2)
            Var x | isTyVar x -> return (applyTy t (mkTyVarTy x))
            _                 -> fail "fixSpecialization: argument is neither a type nor a type variable."

    v3 <- constT $ newVarH "f" t'
    v4 <- constT $ newTypeVarH "a" (tyVarKind v2)
    -- NOTE(review): mkUnsafeCo asserts t' ~ (t applied to v4) without proof.
    let f' = Lam v4 (Cast (Var v3) (mkUnsafeCo t' (applyTy t (mkTyVarTy v4))))
        e' = Lam v3 (App (App e f') a)
    return $ App (App (Var fx) (Type t')) e'
-- | Tidy up after an unfold: run 'betaReducePlus', then
-- 'safeLetSubstPlusR' on its result.
cleanupUnfold :: RewriteH CoreExpr
cleanupUnfold = safeLetSubstPlusR <<< betaReducePlus
-- | Inline the named definition at the head of an application and, if
-- anything was applied to it, clean up with 'cleanupUnfold'.
--
-- The head is reached by following child 0 'appCount' times. This assumes
-- 'appCount' counts the applications on the spine — confirm against its
-- definition.
unfold :: TH.Name -> RewriteH CoreExpr
unfold nm = translate $ \ ctx expr -> do
    let depth = appCount expr

        atHead :: RewriteH Core
        atHead = pathR (replicate depth 0) (promoteR (inlineName nm))
    inlined <- apply (extractR atHead :: RewriteH CoreExpr) ctx expr
    if depth > 0
      then apply cleanupUnfold ctx inlined
      else return inlined
-- | Traverse the tree applying @rr@ at call sites.
--
-- * At an application: recurse into child 1. Then, combined via '>+>',
--   either apply @rr@ to the whole application or recurse into child 0.
-- * At a variable: apply @rr@ directly.
-- * Elsewhere: recurse with 'anyR'.
--
-- Every recursive step goes back through 'withUnfold', so nested failures
-- carry the "any-call failed: " prefix once per level, as before.
withUnfold :: RewriteH Core -> RewriteH Core
withUnfold rr = prefixFailMsg "any-call failed: " $ readerT dispatch
  where
    again :: RewriteH Core
    again = withUnfold rr

    dispatch (ExprCore (App {})) = childR 1 again >+> (rr <+ childR 0 again)
    dispatch (ExprCore (Var {})) = rr
    dispatch _                   = anyR again
-- | Push the named function into its (last) argument by floating that
-- argument's @case@ or @let@ outward ('caseFloatArg' / 'letFloatArg').
--
-- Requires the expression to be an application of the named variable, with
-- at least one argument, where every argument except the last is a type
-- argument. Failures are prefixed with "push failed: ".
push :: TH.Name -> RewriteH CoreExpr
push nm = prefixFailMsg "push failed: " $ do
    expr <- idR
    case collectArgs expr of
      (Var fn, args) -> do
        guardMsg (nm `cmpTHName2Id` fn) ("could not find name " ++ show nm)
        guardMsg (not (null args)) ("no argument for " ++ show nm)
        -- Safe: args is known to be non-empty past the guard above.
        let leading = init args
            final   = last args
        guardMsg (all isTypeArg leading) ("initial arguments are not type arguments for " ++ show nm)
        case final of
          Case {} -> caseFloatArg
          Let {}  -> letFloatArg
          _       -> fail "argument is not a Case or Let."
      _ -> fail "no function to match."
-- | Abstract the focused expression over a variable that is in scope:
--
-- > e  ==>  (\ v -> e) v
--
-- The variable is found among the context's bindings by name. Fails if no
-- binding, or more than one binding, has that name. Failures are prefixed
-- with "abstraction failed: ".
abstract :: TH.Name -> RewriteH CoreExpr
abstract nm = prefixFailMsg "abstraction failed: " $ do
    (c, e) <- exposeT
    let name = TH.nameBase nm
    case filter (cmpTHName2Id nm) (listBindings c) of
      [v] -> return (App (Lam v e) (Var v))
      []  -> fail (name ++ " is not in scope.")
      _   -> fail ("multiple variables named " ++ name ++ " in scope.")