module HERMIT.Dictionary.FixPoint
(
HERMIT.Dictionary.FixPoint.externals
, fixIntroR
, fixIntroNonRecR
, fixIntroRecR
, fixComputationRuleBR
, fixRollingRuleBR
, fixFusionRuleBR
, isFixExprT
) where
import Control.Arrow
import Control.Monad
import Control.Monad.IO.Class
import Data.String (fromString)
import HERMIT.Context
import HERMIT.Core
import HERMIT.Monad
import HERMIT.Kure
import HERMIT.External
import HERMIT.GHC
import HERMIT.Name
import HERMIT.ParserCore
import HERMIT.Utilities
import HERMIT.Dictionary.Common
import HERMIT.Dictionary.Function
import HERMIT.Dictionary.Kure
import HERMIT.Dictionary.Reasoning
import HERMIT.Dictionary.Undefined
import HERMIT.Dictionary.Unfold
-- | Shell commands exported by this module.
--
--   NOTE(review): the two "fix-fusion-rule-unsafe" entries share a name and differ
--   only in their argument lists (with and without the two equality-proof rewrites);
--   presumably the shell resolves between them by argument count/type — confirm.
externals :: [External]
externals =
    [ external "fix-intro" (promoteCoreR fixIntroR :: RewriteH LCore)
        [ "rewrite a function binding into a non-recursive binding using fix" ] .+ Introduce .+ Context
    , external "fix-computation-rule" (promoteExprBiR fixComputationRuleBR :: BiRewriteH LCore)
        [ "Fixed-Point Computation Rule",
          "fix t f <==> f (fix t f)"
        ] .+ Context
    , external "fix-rolling-rule" (promoteExprBiR fixRollingRuleBR :: BiRewriteH LCore)
        [ "Rolling Rule",
          -- Help text previously dropped the final ')' closing "f (fix tyB ...".
          "fix tyA (\\ a -> f (g a)) <==> f (fix tyB (\\ b -> g (f b)))"
        ] .+ Context
    , external "fix-fusion-rule"
        ((\ f g h r1 r2 strictf -> promoteExprBiR
                (fixFusionRule (Just (r1,r2)) (Just strictf) f g h))
            :: CoreString -> CoreString -> CoreString
            -> RewriteH LCore -> RewriteH LCore
            -> RewriteH LCore -> BiRewriteH LCore)
        [ "Fixed-point Fusion Rule"
        , "Given f :: A -> B, g :: A -> A, h :: B -> B, and"
        , "proofs that, for some x, (f (g a) ==> x) and (h (f a) ==> x) and that f is strict, then"
        , "f (fix g) <==> fix h"
        ] .+ Context
    , external "fix-fusion-rule-unsafe"
        ((\ f g h r1 r2 -> promoteExprBiR (fixFusionRule (Just (r1,r2)) Nothing f g h))
            :: CoreString -> CoreString -> CoreString
            -> RewriteH LCore -> RewriteH LCore -> BiRewriteH LCore)
        [ "(Unsafe) Fixed-point Fusion Rule"
        , "Given f :: A -> B, g :: A -> A, h :: B -> B, and"
        , "a proof that, for some x, (f (g a) ==> x) and (h (f a) ==> x), then"
        , "f (fix g) <==> fix h"
        , "Note that the precondition that f is strict is required to hold."
        ] .+ Context .+ PreCondition
    , external "fix-fusion-rule-unsafe"
        ((\ f g h -> promoteExprBiR (fixFusionRule Nothing Nothing f g h))
            :: CoreString -> CoreString -> CoreString -> BiRewriteH LCore)
        [ "(Very Unsafe) Fixed-point Fusion Rule"
        , "Given f :: A -> B, g :: A -> A, h :: B -> B, then"
        , "f (fix g) <==> fix h"
        , "Note that the preconditions that f (g a) == h (f a) and that f is strict are required to hold."
        ] .+ Context .+ PreCondition
    ]
-- | Introduce 'fix' for a binding. The recursive-definition ('CoreDef') form is
--   attempted first; if it fails, the non-recursive ('CoreBind') form is tried.
fixIntroR :: ( AddBindings c, BoundVars c, ExtendPath c Crumb, HasEmptyContext c, ReadBindings c, ReadPath c Crumb
             , HasHermitMEnv m, HasHscEnv m, MonadCatch m, MonadIO m, MonadThings m, MonadUnique m )
          => Rewrite c m Core
fixIntroR = viaDefinition <+ viaNonRecBinding
  where
    viaDefinition    = promoteR fixIntroRecR
    viaNonRecBinding = promoteR fixIntroNonRecR
-- | Apply fix introduction ('polyFixT') to the right-hand side of a non-recursive
--   binding @NonRec f rhs@, keeping the binder unchanged. Fails (with a
--   "fix introduction failed: " prefix) on anything that is not a 'NonRec'.
fixIntroNonRecR :: ( AddBindings c, BoundVars c, ExtendPath c Crumb, HasEmptyContext c, ReadBindings c, ReadPath c Crumb
                   , HasHermitMEnv m, HasHscEnv m, MonadCatch m, MonadIO m, MonadThings m, MonadUnique m )
                => Rewrite c m CoreBind
fixIntroNonRecR = prefixFailMsg "fix introduction failed: " $ do
    NonRec binder body <- idR
    NonRec binder <$> (constant body >>> polyFixT binder)
-- | Apply fix introduction ('polyFixT') to the right-hand side of a definition
--   @Def f rhs@ from a recursive group, keeping the binder unchanged. Fails (with a
--   "fix introduction failed: " prefix) on anything that is not a 'Def'.
fixIntroRecR :: ( AddBindings c, BoundVars c, ExtendPath c Crumb, HasEmptyContext c, ReadBindings c, ReadPath c Crumb
                , HasHermitMEnv m, HasHscEnv m, MonadCatch m, MonadIO m, MonadThings m, MonadUnique m )
             => Rewrite c m CoreDef
fixIntroRecR = prefixFailMsg "fix introduction failed: " $ do
    Def binder body <- idR
    Def binder <$> (constant body >>> polyFixT binder)
-- | Rewrite the right-hand side of a binding of @f@ into an application of 'fix'.
--
--   Leading type binders are stripped off with 'collectTyBinders' and re-applied
--   outside the 'fix', giving @\\ tvs -> fix (\\ f' -> body')@. Here @f'@ is a fresh Id
--   whose type is that of the stripped body. @body'@ is the stripped body with calls
--   to @f@ unfolded, where possible, in a context that temporarily binds
--   @f = \\ tvs -> f'@ (and @f' = body@). If no call can be unfolded, the body is kept
--   unchanged ('tryR').
polyFixT :: forall c m.
            ( AddBindings c, BoundVars c, ExtendPath c Crumb, HasEmptyContext c, ReadBindings c, ReadPath c Crumb
            , HasHermitMEnv m, HasHscEnv m, MonadCatch m, MonadIO m, MonadThings m, MonadUnique m )
         => Id -> Rewrite c m CoreExpr
polyFixT f = do
    (tyVars, monoBody) <- arr collectTyBinders
    -- Fresh stand-in for f; it becomes the variable bound by the fix lambda.
    fMono <- constT $ newIdH (unqualifiedName f) (exprType monoBody)
    let unfoldCallsOfF :: Rewrite c m Core
        unfoldCallsOfF = promoteR (unfoldPredR (const . (==f)))
    recBody <- contextonlyT $ \ ctx ->
        let stub    = mkCoreLams tyVars $ varToCoreExpr fMono
            -- f' is bound first, then f (as a lambda over the type vars returning f').
            ctxBound = addBindingGroup (NonRec f stub)
                     $ addBindingGroup (NonRec fMono monoBody) ctx
        in applyT (tryR (extractR (anyCallR unfoldCallsOfF))) ctxBound monoBody
    mkCoreLams tyVars <$> buildFixT (Lam fMono recBody)
-- | Fixed-point computation rule, in both directions:
--
--   @fix t f  <==>  f (fix t f)@
fixComputationRuleBR :: BiRewriteH CoreExpr
fixComputationRuleBR = bidirectional computationL computationR
  where
    -- fix t f ==> f (fix t f): unroll once, re-using the matched fix expression.
    computationL :: RewriteH CoreExpr
    computationL = prefixFailMsg "fix computation rule failed: " $
      do (_,f) <- isFixExprT
         fixf <- idR
         return (App f fixf)

    -- f (fix t f) ==> fix t f: only when the outer function is alpha-equivalent
    -- to the one under fix. withPatFailMsg replaces the raw GHC pattern-match
    -- failure (non-application input) with a readable message, as the other
    -- backward rules in this module do.
    computationR :: RewriteH CoreExpr
    computationR = prefixFailMsg "(backwards) fix computation rule failed: " $
                   withPatFailMsg (wrongExprForm "App f (fix t f)") $
      do App f fixf <- idR
         (_,f') <- isFixExprT <<< constant fixf
         guardMsg (exprAlphaEq f f') "external function does not match internal expression"
         return fixf
-- | Rolling rule, in both directions:
--
--   @fix tyA (\\ a -> f (g a))  <==>  f (fix tyB (\\ b -> g (f b)))@
--
--   Both directions move @f@ and @g@ out from under the lambda that 'fix' is applied
--   to. The rule therefore requires that the lambda-bound variable does not occur
--   free in @f@ or @g@; otherwise the result would mention an out-of-scope variable.
fixRollingRuleBR :: BiRewriteH CoreExpr
fixRollingRuleBR = bidirectional rollingRuleL rollingRuleR
  where
    -- fix tyA (\ a -> f (g a))  ==>  f (fix tyB (\ b -> g (f b)))
    rollingRuleL :: RewriteH CoreExpr
    rollingRuleL = prefixFailMsg "rolling rule failed: " $
                   withPatFailMsg wrongFixBody $
      do (tyA, Lam a (App f (App g (Var a')))) <- isFixExprT
         guardMsg (a == a') wrongFixBody
         -- e.g. fix (\ a -> h a (g a)) matches with f = h a; hoisting f would unbind a.
         guardMsg (a `notFreeIn` [f, g]) boundVarEscapes
         (tyA',tyB) <- funExprsWithInverseTypes g f
         guardMsg (eqType tyA tyA') "Type mismatch: this shouldn't have happened, report this as a bug."
         res <- rollingRuleResult tyB g f
         return (App f res)

    -- f (fix tyB (\ b -> g (f b)))  ==>  fix tyA (\ a -> f (g a))
    rollingRuleR :: RewriteH CoreExpr
    rollingRuleR = prefixFailMsg "(reversed) rolling rule failed: " $
                   withPatFailMsg "not an application." $
      do App f fx <- idR
         withPatFailMsg wrongFixBody $
           do (tyB, Lam b (App g (App f' (Var b')))) <- isFixExprT <<< constant fx
              guardMsg (b == b') wrongFixBody
              guardMsg (exprAlphaEq f f') "external function does not match internal expression"
              -- g (and f') are hoisted out of the lambda binding b.
              guardMsg (b `notFreeIn` [f', g]) boundVarEscapes
              (tyA,tyB') <- funExprsWithInverseTypes g f
              guardMsg (eqType tyB tyB') "Type mismatch: this shouldn't have happened, report this as a bug."
              rollingRuleResult tyA f g

    -- Build fix (\ x -> f (g x)) for a fresh x of the given type.
    rollingRuleResult :: Type -> CoreExpr -> CoreExpr -> TransformH z CoreExpr
    rollingRuleResult ty f g = do x <- constT (newIdH "x" ty)
                                  buildFixT (Lam x (App f (App g (Var x))))

    -- True when the variable occurs free in none of the given expressions.
    notFreeIn :: Id -> [CoreExpr] -> Bool
    notFreeIn v = not . any (elemVarSet v . freeVarsExpr)

    wrongFixBody :: String
    wrongFixBody = "body of fix does not have the form: Lam v (App f (App g (Var v)))"

    boundVarEscapes :: String
    boundVarEscapes = "the variable bound in the body of fix occurs free in f or g."
-- | Fixed-point fusion rule, in both directions:
--
--   @f (fix g)  <==>  fix h@
--
--   Before either direction runs, the following are checked:
--
--   * @f :: A -> B@, @g :: A -> A@ and @h :: B -> B@ have compatible types.
--   * If a strictness proof is supplied, it is checked with 'verifyStrictT'.
--   * If an equality proof is supplied, @f (g a)@ and @h (f a)@ are checked to reach a
--     common target, for a fresh @a :: A@.
fixFusionRuleBR :: Maybe (EqualityProof HermitC HermitM) -> Maybe (RewriteH CoreExpr) -> CoreExpr -> CoreExpr -> CoreExpr -> BiRewriteH CoreExpr
fixFusionRuleBR meq mfstrict f g h =
    beforeBiR checkPreconditions (\ () -> bidirectional fuseForward fuseBackward)
  where
    checkPreconditions :: TransformH CoreExpr ()
    checkPreconditions = prefixFailMsg "fixed-point fusion failed: " $ do
        (_, argTy, resTy) <- funExprArgResTypesM f
        (_, gTy)          <- endoFunExprTypeM g
        (_, hTy)          <- endoFunExprTypeM h
        guardMsg (typeAlphaEq argTy gTy && typeAlphaEq resTy hTy) "given functions do not have compatible types."
        whenJust (verifyStrictT f) mfstrict
        whenJust (verifyCommutes argTy) meq

    -- Check f (g a) and h (f a) against the supplied proof, for a fresh a.
    verifyCommutes argTy proof = do
        freshA <- constT (newGlobalIdH "a" argTy)
        verifyEqualityCommonTargetT (App f (App g (Var freshA)))
                                    (App h (App f (Var freshA)))
                                    proof

    -- f (fix g) ==> fix h
    fuseForward :: RewriteH CoreExpr
    fuseForward = prefixFailMsg "fixed-point fusion failed: " $
                  withPatFailMsg (wrongExprForm "App f (fix g)") $
      do App outerFn fixArg <- idR
         guardMsg (exprAlphaEq f outerFn) "first argument function does not match."
         (_, innerG) <- constant fixArg >>> isFixExprT
         guardMsg (exprAlphaEq g innerG) "second argument function does not match."
         buildFixT h

    -- fix h ==> f (fix g)
    fuseBackward :: RewriteH CoreExpr
    fuseBackward = prefixFailMsg "(reversed) fixed-point fusion failed: " $ do
        (_, innerH) <- isFixExprT
        guardMsg (exprAlphaEq h innerH) "third argument function does not match."
        fmap (App f) (buildFixT g)
-- | Shell-facing wrapper around 'fixFusionRuleBR'. It narrows the optional
--   'LCore' proof rewrites to 'CoreExpr' and parses the three 'CoreString'
--   arguments (f, g, h) with 'parse3beforeBiR'.
fixFusionRule :: Maybe (RewriteH LCore, RewriteH LCore) -> Maybe (RewriteH LCore) -> CoreString -> CoreString -> CoreString -> BiRewriteH CoreExpr
fixFusionRule meq mfstrict = parse3beforeBiR (fixFusionRuleBR eqProofs strictProof)
  where
    eqProofs    = fmap (extractR *** extractR) meq
    strictProof = fmap extractR mfstrict
-- | Recognise @fix t f@, that is, a call of the Id found at 'fixLocation' with
--   exactly one type argument and one value argument. Returns @(t, f)@.
isFixExprT :: TransformH CoreExpr (Type,CoreExpr)
isFixExprT = withPatFailMsg (wrongExprForm "fix t f") $ do
    (Var callee, [Type argTy, fn]) <- callT
    realFix <- findIdT fixLocation
    guardMsg (callee == realFix) $ unqualifiedName callee ++ " does not match " ++ show fixLocation
    pure (argTy, fn)
-- | Fully-qualified name that 'isFixExprT' looks up to recognise 'fix'.
fixLocation :: HermitName
fixLocation = fromString qualifiedFix
  where
    qualifiedFix = "Data.Function.fix"