module Language.HERMIT.Primitive.FixPoint where
import GhcPlugins as GHC hiding (varName)
import Control.Arrow
import Language.HERMIT.Core
import Language.HERMIT.Monad
import Language.HERMIT.Kure
import Language.HERMIT.External
import Language.HERMIT.GHC
import Language.HERMIT.Primitive.GHC
import Language.HERMIT.Primitive.Common
import Language.HERMIT.Primitive.Local
import Language.HERMIT.Primitive.AlphaConversion
import Language.HERMIT.Primitive.New
import qualified Language.Haskell.TH as TH
-- | The commands exported by this module.
--
--   Every entry is additionally tagged with 'TODO' and then 'Experiment'.
externals :: [External]
externals = map markExperimental
    [ external "fix-intro" (promoteDefR fixIntro :: RewriteH Core)
        [ "rewrite a recursive binding into a non-recursive binding using fix" ]
    , external "fix-spec" (promoteExprR fixSpecialization :: RewriteH Core)
        [ "specialize a fix with a given argument"] .+ Shallow
    , underConstruction (external "ww-fac-test" wwFac
        [ "Under construction "
        ])
    , underConstruction (external "ww-split-test" wwSplit
        [ "Under construction "
        ])
    ]
  where
    -- Tags applied to every command in the list: 'TODO' first, then 'Experiment'.
    markExperimental :: External -> External
    markExperimental ext = (ext .+ TODO) .+ Experiment

    -- Tag chain shared by the two worker/wrapper test commands.
    underConstruction :: External -> External
    underConstruction ext = ext .+ Introduce .+ Context .+ Experiment .+ PreCondition

    -- The worker/wrapper factorisation test, promoted from expressions to 'Core'.
    wwFac :: TH.Name -> TH.Name -> RewriteH Core
    wwFac wrap unwrap = promoteExprR (workerWrapperFacTest wrap unwrap)

    -- The worker/wrapper split test, promoted from definitions to 'Core'.
    wwSplit :: TH.Name -> TH.Name -> RewriteH Core
    wwSplit wrap unwrap = promoteDefR (workerWrapperSplitTest wrap unwrap)
-- | The qualified name under which the fixed-point combinator is looked up
--   (see 'findFixId'); also quoted in the failure message of 'guardIsFixId'.
fixLocation :: String
fixLocation = "Data.Function.fix"
-- | Look up the 'Id' named by 'fixLocation' using 'findIdT'.
--   The value being translated is not inspected.
findFixId :: TranslateH a Id
findFixId = findIdT fixName
  where
    fixName = TH.mkName fixLocation
-- | Succeed only if the given identifier is equal to the one returned by
--   'findFixId'; otherwise fail with a message naming the identifier and
--   'fixLocation'.
guardIsFixId :: Id -> TranslateH a ()
guardIsFixId v =
    findFixId >>= \ fixId ->
    guardMsg (v == fixId) mismatch
  where
    mismatch = var2String v ++ " does not match " ++ fixLocation
-- | Rewrite a definition @f = e@ into @f = fix \@ty (\\ f' -> e[f'/f])@,
--   where @ty@ is the type of @f@ and @f'@ is a clone of @f@ obtained from
--   'cloneVarH'.
--
--   Failures are prefixed with @"Fix introduction failed: "@.
fixIntro :: RewriteH CoreDef
fixIntro = prefixFailMsg "Fix introduction failed: " $ do
    Def f body <- idR
    fixId      <- findFixId
    constT $ do
        fresh <- cloneVarH id f
        let -- Substitution replacing occurrences of f by the fresh variable;
            -- its in-scope set is seeded with the free variables of the body.
            inScope  = mkInScopeSet (exprFreeVars body)
            renaming = extendSubst (mkEmptySubst inScope) f (Var fresh)
            body'    = substExpr (text "fixIntro") renaming body
        return $ Def f (App (App (Var fixId) (Type (idType f))) (Lam fresh body'))
-- | Specialise an application of the form @fix \@ty fun arg@.
--
--   The rewrite first checks that the expression is a three-argument
--   application whose head is the identifier found by 'findFixId' and whose
--   first argument is a type.  It then applies 'multiEtaExpand' (with the
--   names @f@ and @a@) at path @[0,1]@ and hands the result to
--   'fixSpecialization''.
--
--   Failures are prefixed with @"Fix specialisation failed: "@, and a
--   shape mismatch is reported via 'wrongExprForm' instead of a raw
--   pattern-match failure, matching the error-reporting style of
--   'fixIntro' and 'monomorphicWorkerWrapperFac'.
fixSpecialization :: RewriteH CoreExpr
fixSpecialization =
    prefixFailMsg "Fix specialisation failed: " $
    checkShape >> (extractR etaExpandFun >>> fixSpecialization')
  where
    -- Only the shape check is wrapped in 'withPatFailMsg', so pattern
    -- failures raised later inside fixSpecialization' are not relabelled.
    checkShape :: TranslateH CoreExpr ()
    checkShape = withPatFailMsg (wrongExprForm "fix type fun arg") $ do
        App (App (App (Var fixId) (Type _)) _) _ <- idR
        guardIsFixId fixId

    -- Eta-expand the sub-expression at path [0,1].
    etaExpandFun :: RewriteH Core
    etaExpandFun = pathR [0,1] (promoteR etaExpandTwice)

    etaExpandTwice :: RewriteH CoreExpr
    etaExpandTwice = multiEtaExpand [TH.mkName "f",TH.mkName "a"]
-- | Second stage of 'fixSpecialization'.
--
--   Expects an expression of the shape
--   @fx tyE (\\ _ -> \\ v -> e _ _) a@, where both @tyE@ and @a@ are type
--   expressions (as judged by 'typeExprToType').  With @t@ the type in
--   @tyE@ and @t' = applyTy t a@, the result is
--
--   > fx @t' (\ f -> e (\ b -> f |> co) a)
--
--   where @f@ is a new identifier of type @t'@, @b@ is a new type variable
--   with the kind of @v@, and @co@ is an unsafe coercion from @t'@ to
--   @applyTy t b@.
--
--   NOTE(review): the head @fx@ is not checked here; the only visible caller,
--   'fixSpecialization', checks it first.  The message
--   @"Not a type variable."@ is raised whenever @a@ is not a type
--   expression, which looks broader than the wording suggests — confirm.
fixSpecialization' :: RewriteH CoreExpr
fixSpecialization' = do
    App (App (App (Var fx) fixTyE)
             (Lam _ (Lam tyBinder (App (App e _) _a2))))
        tyArg <- idR
    fixTy  <- maybe (fail "first argument to fix is not a type, this shouldn't have happened.")
                    return
                    (typeExprToType fixTyE)
    specTy <- maybe (fail "Not a type variable.")
                    (return . applyTy fixTy)
                    (typeExprToType tyArg)
    newFun   <- constT (newIdH "f" specTy)
    newTyVar <- constT (newTyVarH "a" (tyVarKind tyBinder))
    let -- Coerce the specialised function back to the instantiated type.
        coercion = mkUnsafeCo specTy (applyTy fixTy (mkTyVarTy newTyVar))
        generalised = Lam newTyVar (Cast (Var newFun) coercion)
        specialisedBody = Lam newFun (App (App e generalised) tyArg)
    return (App (App (Var fx) (Type specTy)) specialisedBody)
-- | Test entry point for 'monomorphicWorkerWrapperFac': the wrapper and
--   unwrapper are given by name, resolved with 'findBoundVarT', and passed
--   on as variable expressions.
workerWrapperFacTest :: TH.Name -> TH.Name -> RewriteH CoreExpr
workerWrapperFacTest wrapNm unwrapNm =
    findBoundVarT wrapNm   >>= \ wrapId   ->
    findBoundVarT unwrapNm >>= \ unwrapId ->
    monomorphicWorkerWrapperFac (Var wrapId) (Var unwrapId)
-- | Test entry point for 'monomorphicWorkerWrapperSplit': the wrapper and
--   unwrapper are given by name, resolved with 'findBoundVarT', and passed
--   on as variable expressions.
workerWrapperSplitTest :: TH.Name -> TH.Name -> RewriteH CoreDef
workerWrapperSplitTest wrapNm unwrapNm =
    findBoundVarT wrapNm   >>= \ wrapId   ->
    findBoundVarT unwrapNm >>= \ unwrapId ->
    monomorphicWorkerWrapperSplit (Var wrapId) (Var unwrapId)
-- | Worker/wrapper factorisation for a monomorphic @fix@.
--
--   Given @wrap :: B -> A@ and @unwrap :: A -> B@, rewrites
--
--   > fix @A f
--
--   into
--
--   > wrap (fix @B (\ x -> unwrap (f (wrap x))))
--
--   The rewrite fails if the expression is not an application of the
--   identifier found by 'findFixId' to a type and a function, if either
--   @wrap@ or @unwrap@ does not have a function type, or if the three
--   type-equality checks below do not hold.  All failures are prefixed with
--   @"Worker/wrapper Factorisation failed: "@.
monomorphicWorkerWrapperFac :: CoreExpr -> CoreExpr -> RewriteH CoreExpr
monomorphicWorkerWrapperFac wrapE unwrapE =
    prefixFailMsg "Worker/wrapper Factorisation failed: " $
    withPatFailMsg (wrongExprForm "fix type fun") $ do
        App (App (Var fixId) fixTyE) f <- idR
        guardIsFixId fixId
        tyA <- maybe (fail "first argument to fix is not a type, this shouldn't have happened.")
                     return
                     (typeExprToType fixTyE)
        (tyB, wrapTyA) <- maybe (fail "type of wrapper is not a function.")
                                return
                                (splitFunTy_maybe (exprType wrapE))
        (unwrapTyA, unwrapTyB) <- maybe (fail "type of unwrapper is not a function.")
                                        return
                                        (splitFunTy_maybe (exprType unwrapE))
        -- wrap and unwrap must be mutually inverse in their types, and the
        -- wrapper's result type must be the type fix is instantiated at.
        guardMsg (eqType wrapTyA unwrapTyA) "argument type of unwrapper does not match result type of wrapper."
        guardMsg (eqType unwrapTyB tyB) "argument type of wrapper does not match result type of unwrapper."
        guardMsg (eqType wrapTyA tyA) "wrapper/unwrapper types do not match expression type."
        x <- constT (newIdH "x" tyB)
        let workerFun = Lam x (App unwrapE (App f (App wrapE (Var x))))
            worker    = App (App (Var fixId) (Type tyB)) workerFun
        return (App wrapE worker)
-- | Worker/wrapper split on a definition, built from existing rewrites.
--
--   The definition is first rewritten with 'fixIntro'.  Its right-hand side
--   is then processed by the pipeline in @splitRhs@, which introduces a
--   let named @f@ on the argument, floats it with 'letFloatArg', and runs
--   @splitBody@ on the let body.  @splitBody@ applies
--   'monomorphicWorkerWrapperFac' with the given wrapper and unwrapper and
--   then the sequence of let-introduction, floating, unfolding, renaming
--   and substitution steps listed below.
--
--   NOTE(review): the unfolded name is the unqualified @fix@, whereas
--   'fixLocation' is qualified — presumably both resolve to the same
--   identifier; confirm.
monomorphicWorkerWrapperSplit :: CoreExpr -> CoreExpr -> RewriteH CoreDef
monomorphicWorkerWrapperSplit wrap unwrap =
    fixIntro >>> defR splitRhs
  where
    name = TH.mkName

    -- Applied to the right-hand side produced by fixIntro.
    splitRhs =
            appAllR idR (letIntro (name "f"))
        >>> letFloatArg
        >>> letAllR idR splitBody

    -- Applied to the body of the let floated out by splitRhs.
    splitBody =
            monomorphicWorkerWrapperFac wrap unwrap
        >>> appAllR idR (letIntro (name "w"))
        >>> letFloatArg
        >>> letNonRecAllR nameWorker idR
        >>> letSubstR
        >>> letFloatArg

    -- Applied to the right-hand side of the non-recursive let.
    nameWorker =
            unfold (name "fix")
        >>> alphaLetOne (Just (name "work"))
        >>> extractR simplifyR