module HERMIT.Dictionary.FixPoint
(
HERMIT.Dictionary.FixPoint.externals
, fixIntroR
, fixIntroNonRecR
, fixIntroRecR
, fixComputationRuleBR
, fixRollingRuleBR
, fixFusionRuleBR
, isFixExprT
) where
import Control.Arrow
import Control.Monad
import Control.Monad.IO.Class
import Data.String (fromString)
import HERMIT.Context
import HERMIT.Core
import HERMIT.Monad
import HERMIT.Kure hiding ((<$>))
import HERMIT.External
import HERMIT.GHC
import HERMIT.Name
import HERMIT.ParserCore
import HERMIT.Utilities
import HERMIT.Dictionary.Common
import HERMIT.Dictionary.Function
import HERMIT.Dictionary.Kure
import HERMIT.Dictionary.Reasoning
import HERMIT.Dictionary.Undefined
import HERMIT.Dictionary.Unfold
import Prelude.Compat
-- | Shell externals exposing fix-introduction and the fixed-point
-- computation, rolling and fusion rules defined in this module.
externals :: [External]
externals =
    [ external "fix-intro" (promoteCoreR fixIntroR :: RewriteH LCore)
        [ "rewrite a function binding into a non-recursive binding using fix" ] .+ Introduce .+ Context
    , external "fix-computation-rule" (promoteExprBiR fixComputationRuleBR :: BiRewriteH LCore)
        [ "Fixed-Point Computation Rule",
          "fix t f <==> f (fix t f)"
        ] .+ Context
    , external "fix-rolling-rule" (promoteExprBiR fixRollingRuleBR :: BiRewriteH LCore)
        [ "Rolling Rule",
          "fix tyA (\\ a -> f (g a)) <==> f (fix tyB (\\ b -> g (f b)))"
        ] .+ Context
    -- Fusion with both proof obligations supplied: the equality proof pair
    -- (r1, r2) and the strictness proof for f.
    , external "fix-fusion-rule" ((\ f g h r1 r2 strictf -> promoteExprBiR
                                      (fixFusionRule (Just (r1,r2)) (Just strictf) f g h))
                                    :: CoreString -> CoreString -> CoreString
                                    -> RewriteH LCore -> RewriteH LCore
                                    -> RewriteH LCore -> BiRewriteH LCore)
        [ "Fixed-point Fusion Rule"
        , "Given f :: A -> B, g :: A -> A, h :: B -> B, and"
        , "proofs that, for some x, (f (g a) ==> x) and (h (f a) ==> x) and that f is strict, then"
        , "f (fix g) <==> fix h"
        ] .+ Context
    -- Fusion with only the equality proof; strictness is left as a precondition.
    , external "fix-fusion-rule-unsafe" ((\ f g h r1 r2 -> promoteExprBiR (fixFusionRule (Just (r1,r2)) Nothing f g h))
                                           :: CoreString -> CoreString -> CoreString
                                           -> RewriteH LCore -> RewriteH LCore -> BiRewriteH LCore)
        [ "(Unsafe) Fixed-point Fusion Rule"
        , "Given f :: A -> B, g :: A -> A, h :: B -> B, and"
        , "a proof that, for some x, (f (g a) ==> x) and (h (f a) ==> x), then"
        , "f (fix g) <==> fix h"
        , "Note that the precondition that f is strict is required to hold."
        ] .+ Context .+ PreCondition
    -- Fusion with no proofs at all. This reuses the name "fix-fusion-rule-unsafe"
    -- with a shorter argument list. NOTE(review): this relies on the shell
    -- resolving same-named externals by their argument types; confirm.
    , external "fix-fusion-rule-unsafe" ((\ f g h -> promoteExprBiR (fixFusionRule Nothing Nothing f g h))
                                           :: CoreString -> CoreString -> CoreString -> BiRewriteH LCore)
        [ "(Very Unsafe) Fixed-point Fusion Rule"
        , "Given f :: A -> B, g :: A -> A, h :: B -> B, then"
        , "f (fix g) <==> fix h"
        , "Note that the preconditions that f (g a) == h (f a) and that f is strict are required to hold."
        ] .+ Context .+ PreCondition
    ]
-- | Express a binding's right-hand side in terms of 'fix'.
--
-- The recursive-definition form ('CoreDef') is attempted first. If it does
-- not apply, the non-recursive binding form ('CoreBind') is attempted.
fixIntroR :: ( AddBindings c, BoundVars c, ExtendPath c Crumb, HasEmptyContext c, ReadBindings c, ReadPath c Crumb
             , HasHermitMEnv m, LiftCoreM m, MonadCatch m, MonadIO m, MonadThings m, MonadUnique m )
          => Rewrite c m Core
fixIntroR = viaDef <+ viaBind
  where
    -- Lift each specialised rewrite to operate on the 'Core' sum type.
    viaDef  = promoteR fixIntroRecR
    viaBind = promoteR fixIntroNonRecR
-- | For a non-recursive binding @f = rhs@, keep the binder @f@ and replace
-- @rhs@ with the 'fix'-based expression that 'polyFixT' builds from it.
-- Fails, with a "fix introduction failed: " prefix, on anything other than 'NonRec'.
fixIntroNonRecR :: ( AddBindings c, BoundVars c, ExtendPath c Crumb, HasEmptyContext c, ReadBindings c, ReadPath c Crumb
                   , HasHermitMEnv m, LiftCoreM m, MonadCatch m, MonadIO m, MonadThings m, MonadUnique m )
                => Rewrite c m CoreBind
fixIntroNonRecR = prefixFailMsg "fix introduction failed: " $
    do NonRec f rhs <- idR
       -- Run polyFixT on the right-hand side in the current context.
       NonRec f <$> (return rhs >>> polyFixT f)
-- | For a recursive definition @f = rhs@, keep the binder @f@ and replace
-- @rhs@ with the 'fix'-based expression that 'polyFixT' builds from it.
-- Failures are reported with a "fix introduction failed: " prefix.
fixIntroRecR :: ( AddBindings c, BoundVars c, ExtendPath c Crumb, HasEmptyContext c, ReadBindings c, ReadPath c Crumb
                , HasHermitMEnv m, LiftCoreM m, MonadCatch m, MonadIO m, MonadThings m, MonadUnique m )
             => Rewrite c m CoreDef
fixIntroRecR = prefixFailMsg "fix introduction failed: " $
    do Def f rhs <- idR
       -- Run polyFixT on the right-hand side in the current context.
       Def f <$> (return rhs >>> polyFixT f)
-- | Rewrite the right-hand side of a binding for @f@ into a 'fix' form.
--
-- The leading type binders @tvs@ are split off the input expression. Calls
-- to @f@ inside the remaining @body@ are unfolded to a fresh identifier
-- @f'@. The result is rebuilt as @tvs@ abstracted over the 'fix' of
-- @\\ f' -> body'@ (constructed by 'buildFixT').
polyFixT :: forall c m.
    ( AddBindings c, BoundVars c, ExtendPath c Crumb, HasEmptyContext c, ReadBindings c, ReadPath c Crumb
    , HasHermitMEnv m, LiftCoreM m, MonadCatch m, MonadIO m, MonadThings m, MonadUnique m )
    => Id -> Rewrite c m CoreExpr
polyFixT f = do
    -- Separate the leading type abstractions from the body that will be fixed.
    (tvs, body) <- arr collectTyBinders
    -- A fresh identifier named like f, with the type of the stripped body.
    f' <- constT $ newIdH (unqualifiedName f) (exprType body)
    body' <- contextonlyT $ \ c -> do
        -- Extend the context so that f is bound to (/\ tvs -> f'). Unfolding a
        -- call to f (presumably applied to the type arguments tvs) then yields f'.
        -- NOTE(review): f' is also bound to the original body here; confirm intent.
        let constLam = mkCoreLams tvs $ varToCoreExpr f'
            c' = addBindingGroup (NonRec f constLam)
               $ addBindingGroup (NonRec f' body) c
        -- Unfold calls to f within body under the extended context. tryR leaves
        -- body unchanged if no call is unfolded.
        applyT (tryR (extractR (anyCallR (promoteR (unfoldPredR (const . (==f))) :: Rewrite c m Core)))) c' body
    -- Abstract over f', wrap in fix, then re-introduce the type binders.
    liftM (mkCoreLams tvs) $ buildFixT $ Lam f' body'
-- | Fixed-point computation rule: @fix t f@ <==> @f (fix t f)@.
--
-- Forwards, one step of the fixed point is unrolled. Backwards, an
-- application @f (fix t f')@ is folded back to @fix t f'@, provided @f@
-- and @f'@ are alpha-equivalent.
fixComputationRuleBR :: BiRewriteH CoreExpr
fixComputationRuleBR = bidirectional unrollR rollR
  where
    -- fix t f  ==>  f (fix t f)
    unrollR :: RewriteH CoreExpr
    unrollR = prefixFailMsg "fix computation rule failed: " $ do
        (_, fn) <- isFixExprT
        App fn <$> idR

    -- f (fix t f)  ==>  fix t f
    rollR :: RewriteH CoreExpr
    rollR = prefixFailMsg "(backwards) fix computation rule failed: " $ do
        App outerFn fixExpr <- idR
        (_, innerFn) <- constant fixExpr >>> isFixExprT
        guardMsg (exprAlphaEq outerFn innerFn) "external function does not match internal expression"
        return fixExpr
-- | Rolling rule:
--
-- > fix tyA (\ a -> f (g a))  <==>  f (fix tyB (\ b -> g (f b)))
--
-- where @g :: tyA -> tyB@ and @f :: tyB -> tyA@.
--
-- Both directions move @f@ and @g@ out from under the lambda bound by
-- 'fix'. Each direction therefore checks that the lambda-bound variable
-- does not occur free in @f@ or @g@; otherwise the result would mention a
-- variable outside its scope.
fixRollingRuleBR :: BiRewriteH CoreExpr
fixRollingRuleBR = bidirectional rollingRuleL rollingRuleR
  where
    -- fix tyA (\ a -> f (g a))  ==>  f (fix tyB (\ b -> g (f b)))
    rollingRuleL :: RewriteH CoreExpr
    rollingRuleL = prefixFailMsg "rolling rule failed: " $
                   withPatFailMsg wrongFixBody $
      do (tyA, Lam a (App f (App g (Var a')))) <- isFixExprT
         guardMsg (a == a') wrongFixBody
         guardMsg (notFreeIn a [f, g]) boundVarEscapes
         (tyA',tyB) <- funExprsWithInverseTypes g f
         guardMsg (eqType tyA tyA') "Type mismatch: this shouldn't have happened, report this as a bug."
         res <- rollingRuleResult tyB g f
         return (App f res)

    -- f (fix tyB (\ b -> g (f b)))  ==>  fix tyA (\ a -> f (g a))
    rollingRuleR :: RewriteH CoreExpr
    rollingRuleR = prefixFailMsg "(reversed) rolling rule failed: " $
                   withPatFailMsg "not an application." $
      do App f fx <- idR
         withPatFailMsg wrongFixBody $
           do (tyB, Lam b (App g (App f' (Var b')))) <- isFixExprT <<< constant fx
              guardMsg (b == b') wrongFixBody
              guardMsg (notFreeIn b [g, f']) boundVarEscapes
              guardMsg (exprAlphaEq f f') "external function does not match internal expression"
              (tyA,tyB') <- funExprsWithInverseTypes g f
              guardMsg (eqType tyB tyB') "Type mismatch: this shouldn't have happened, report this as a bug."
              rollingRuleResult tyA f g

    -- Build fix ty (\ x -> f (g x)) for a fresh x :: ty.
    rollingRuleResult :: Type -> CoreExpr -> CoreExpr -> TransformH z CoreExpr
    rollingRuleResult ty f g = do x <- constT (newIdH "x" ty)
                                  buildFixT (Lam x (App f (App g (Var x))))

    -- True when v occurs free in none of the given expressions.
    notFreeIn :: Id -> [CoreExpr] -> Bool
    notFreeIn v = not . any (elemVarSet v . freeVarsExpr)

    wrongFixBody :: String
    wrongFixBody = "body of fix does not have the form: Lam v (App f (App g (Var v)))"

    boundVarEscapes :: String
    boundVarEscapes = "the variable bound under fix occurs free in f or g."
-- | Fixed-point fusion rule: @f (fix g)@ <==> @fix h@.
--
-- Before either direction runs, three checks are made:
--
-- * @f :: A -> B@, @g :: A -> A@ and @h :: B -> B@ must have compatible types.
-- * If a strictness rewrite is supplied, 'verifyStrictT' is run on @f@.
-- * If an equality proof is supplied, it is checked against @f (g a)@ and
--   @h (f a)@ for a fresh @a :: A@.
fixFusionRuleBR :: Maybe (EqualityProof HermitC HermitM) -> Maybe (RewriteH CoreExpr) -> CoreExpr -> CoreExpr -> CoreExpr -> BiRewriteH CoreExpr
fixFusionRuleBR meq mfstrict f g h =
    beforeBiR preconditionsT (\ () -> bidirectional fixFusionL fixFusionR)
  where
    -- Type compatibility first, then the optional strictness and equality obligations.
    preconditionsT :: TransformH CoreExpr ()
    preconditionsT = prefixFailMsg "fixed-point fusion failed: " $ do
        (_, argTy, resTy) <- funExprArgResTypesM f
        (_, gTy)          <- endoFunExprTypeM g
        (_, hTy)          <- endoFunExprTypeM h
        guardMsg (typeAlphaEq argTy gTy && typeAlphaEq resTy hTy) "given functions do not have compatible types."
        whenJust (verifyStrictT f) mfstrict
        whenJust (commutesT argTy) meq

    -- Check the proof against f (g a) and h (f a) for a fresh a :: ty.
    commutesT :: Type -> EqualityProof HermitC HermitM -> TransformH CoreExpr ()
    commutesT ty eq = do
        a <- constT (newGlobalIdH "a" ty)
        verifyEqualityCommonTargetT (App f (App g (Var a))) (App h (App f (Var a))) eq

    -- f (fix g)  ==>  fix h
    fixFusionL :: RewriteH CoreExpr
    fixFusionL = prefixFailMsg "fixed-point fusion failed: " $
                 withPatFailMsg (wrongExprForm "App f (fix g)") $ do
        App outer fixExpr <- idR
        guardMsg (exprAlphaEq f outer) "first argument function does not match."
        (_, inner) <- constant fixExpr >>> isFixExprT
        guardMsg (exprAlphaEq g inner) "second argument function does not match."
        buildFixT h

    -- fix h  ==>  f (fix g)
    fixFusionR :: RewriteH CoreExpr
    fixFusionR = prefixFailMsg "(reversed) fixed-point fusion failed: " $ do
        (_, inner) <- isFixExprT
        guardMsg (exprAlphaEq h inner) "third argument function does not match."
        fmap (App f) (buildFixT g)
-- | Shell-facing wrapper around 'fixFusionRuleBR'.
--
-- The optional 'LCore' proof rewrites are projected down to 'CoreExpr'
-- rewrites with 'extractR'. 'parse3beforeBiR' turns the three 'CoreString'
-- arguments into the @f@, @g@ and @h@ expressions.
fixFusionRule :: Maybe (RewriteH LCore, RewriteH LCore) -> Maybe (RewriteH LCore) -> CoreString -> CoreString -> CoreString -> BiRewriteH CoreExpr
fixFusionRule meq mfstrict = parse3beforeBiR (fixFusionRuleBR eqProof strictness)
  where
    eqProof    = fmap (\ (lhsR, rhsR) -> (extractR lhsR, extractR rhsR)) meq
    strictness = fmap extractR mfstrict
-- | Recognise an application of 'fix' to exactly one type argument and one
-- value argument, returning that type and the function argument.
-- The head must be the identifier found for 'fixLocation'.
isFixExprT :: TransformH CoreExpr (Type,CoreExpr)
isFixExprT = withPatFailMsg (wrongExprForm "fix t f") $ do
    (Var headId, [Type ty, fn]) <- callT
    expected <- findIdT fixLocation
    (ty, fn) <$ guardMsg (headId == expected)
                         (concat [unqualifiedName headId, " does not match ", show fixLocation])
-- | Fully-qualified name of the 'fix' combinator that 'isFixExprT' matches against.
fixLocation :: HermitName
fixLocation = fromString qualified
  where
    qualified :: String
    qualified = "Data.Function.fix"