{-# LANGUAGE ScopedTypeVariables #-} module HERMIT.Dictionary.FixPoint ( -- * Operations on the Fixed Point Operator (fix) -- | Note that many of these operations require 'Data.Function.fix' to be explicitly imported, if it is not used in the source file. HERMIT.Dictionary.FixPoint.externals -- ** Rewrites and BiRewrites on Fixed Points , fixIntroR , fixComputationRuleBR , fixRollingRuleBR , fixFusionRuleBR -- ** Utilities , mkFixT , isFixExprT ) where import Control.Applicative import Control.Arrow import Data.Monoid (mempty) import HERMIT.Context import HERMIT.Core import HERMIT.Monad import HERMIT.Kure import HERMIT.External import HERMIT.GHC import HERMIT.ParserCore import HERMIT.Utilities import HERMIT.Dictionary.Common import HERMIT.Dictionary.GHC import HERMIT.Dictionary.Reasoning import HERMIT.Dictionary.Undefined import qualified Language.Haskell.TH as TH -------------------------------------------------------------------------------------------------- -- | Externals for manipulating fixed points. 
externals :: [External]
externals =
    [ external "fix-intro" (promoteDefR fixIntroR :: RewriteH Core)
        [ "rewrite a recursive binding into a non-recursive binding using fix"
        ] .+ Introduce .+ Context
    , external "fix-computation-rule" (promoteExprBiR fixComputationRuleBR :: BiRewriteH Core)
        [ "Fixed-Point Computation Rule"
        , "fix t f <==> f (fix t f)"
        ] .+ Context
    , external "fix-rolling-rule" (promoteExprBiR fixRollingRuleBR :: BiRewriteH Core)
        [ "Rolling Rule"
        , "fix tyA (\\ a -> f (g a)) <==> f (fix tyB (\\ b -> g (f b)))"
        ] .+ Context
    , external "fix-fusion-rule"
        ((\ f g h lhsR rhsR strictf -> promoteExprBiR (fixFusionRule (Just (lhsR,rhsR)) (Just strictf) f g h))
            :: CoreString -> CoreString -> CoreString -> RewriteH Core -> RewriteH Core -> RewriteH Core -> BiRewriteH Core)
        [ "Fixed-point Fusion Rule"
        , "Given f :: A -> B, g :: A -> A, h :: B -> B, and"
        , "proofs that, for some x, (f (g a) ==> x) and (h (f a) ==> x) and that f is strict, then"
        , "f (fix g) <==> fix h"
        ] .+ Context
    -- The next two entries deliberately share the name "fix-fusion-rule-unsafe";
    -- they differ only in whether the equality proofs are supplied.
    -- NOTE(review): this presumably relies on HERMIT resolving same-named
    -- externals by argument types -- confirm before renaming either one.
    , external "fix-fusion-rule-unsafe"
        ((\ f g h lhsR rhsR -> promoteExprBiR (fixFusionRule (Just (lhsR,rhsR)) Nothing f g h))
            :: CoreString -> CoreString -> CoreString -> RewriteH Core -> RewriteH Core -> BiRewriteH Core)
        [ "(Unsafe) Fixed-point Fusion Rule"
        , "Given f :: A -> B, g :: A -> A, h :: B -> B, and"
        , "a proof that, for some x, (f (g a) ==> x) and (h (f a) ==> x), then"
        , "f (fix g) <==> fix h"
        , "Note that the precondition that f is strict is required to hold."
        ] .+ Context .+ PreCondition
    , external "fix-fusion-rule-unsafe"
        ((\ f g h -> promoteExprBiR (fixFusionRule Nothing Nothing f g h))
            :: CoreString -> CoreString -> CoreString -> BiRewriteH Core)
        [ "(Very Unsafe) Fixed-point Fusion Rule"
        , "Given f :: A -> B, g :: A -> A, h :: B -> B, then"
        , "f (fix g) <==> fix h"
        , "Note that the preconditions that f (g a) == h (f a) and that f is strict are required to hold."
        ] .+ Context .+ PreCondition
    ]

--------------------------------------------------------------------------------------------------

-- | @f = e@ ==\> @f = fix (\\ f -> e)@
--
-- The binder @f@ is cloned to a fresh @f'@, every occurrence of @f@ in the body
-- is replaced by @f'@, and the result is wrapped as @fix (\\ f' -> e')@.
fixIntroR :: RewriteH CoreDef
fixIntroR = prefixFailMsg "fix introduction failed: " $
    do Def f _ <- idR
       f' <- constT $ cloneVarH id f
       -- Rewrite the right-hand side only (the Id part is left untouched),
       -- abstracting over the fresh copy of the recursive binder.
       Def f <$> (mkFixT =<< defT mempty
                                   (extractR $ substR f $ varToCoreExpr f')
                                   (\ () e' -> Lam f' e'))

--------------------------------------------------------------------------------------------------

-- | @fix ty f@ \<==\> @f (fix ty f)@
fixComputationRuleBR :: BiRewriteH CoreExpr
fixComputationRuleBR = bidirectional computationL computationR
  where
    -- Unfold one step: fix t f  ==>  f (fix t f)
    computationL :: RewriteH CoreExpr
    computationL = prefixFailMsg "fix computation rule failed: " $
                   do (_,f) <- isFixExprT
                      fixf  <- idR
                      return (App f fixf)

    -- Fold one step: f (fix t f)  ==>  fix t f, provided both f's are alpha-equivalent.
    computationR :: RewriteH CoreExpr
    computationR = prefixFailMsg "(backwards) fix computation rule failed: " $
                   withPatFailMsg (wrongExprForm "App f (fix t f)") $
                   do App f fixf <- idR
                      (_,f') <- isFixExprT <<< constant fixf
                      guardMsg (exprAlphaEq f f') "external function does not match internal expression"
                      return fixf

-- | @fix tyA (\\ a -> f (g a))@ \<==\> @f (fix tyB (\\ b -> g (f b)))@
--
-- Both directions move @f@ and @g@ out from under the binder of the fixed-point
-- lambda, so they fail if that binder occurs free in either function (the
-- result would otherwise mention a variable that is no longer in scope).
fixRollingRuleBR :: BiRewriteH CoreExpr
fixRollingRuleBR = bidirectional rollingRuleL rollingRuleR
  where
    rollingRuleL :: RewriteH CoreExpr
    rollingRuleL = prefixFailMsg "rolling rule failed: " $
                   withPatFailMsg wrongFixBody $
                   do (tyA, Lam a (App f (App g (Var a')))) <- isFixExprT
                      guardMsg (a == a') wrongFixBody
                      guardMsg (notFreeIn a [f, g]) (boundVarEscapes a)
                      (tyA',tyB) <- funsWithInverseTypes g f
                      guardMsg (eqType tyA tyA') "Type mismatch: this shouldn't have happened, report this as a bug."
                      res <- rollingRuleResult tyB g f
                      return (App f res)

    rollingRuleR :: RewriteH CoreExpr
    rollingRuleR = prefixFailMsg "(reversed) rolling rule failed: " $
                   withPatFailMsg "not an application." $
                   do App f fx <- idR
                      withPatFailMsg wrongFixBody $
                        do (tyB, Lam b (App g (App f' (Var b')))) <- isFixExprT <<< constant fx
                           guardMsg (b == b') wrongFixBody
                           guardMsg (notFreeIn b [g, f']) (boundVarEscapes b)
                           guardMsg (exprAlphaEq f f') "external function does not match internal expression"
                           (tyA,tyB') <- funsWithInverseTypes g f
                           guardMsg (eqType tyB tyB') "Type mismatch: this shouldn't have happened, report this as a bug."
                           rollingRuleResult tyA f g

    -- Build @fix ty (\\ x -> f (g x))@ with a fresh binder @x :: ty@.
    rollingRuleResult :: Type -> CoreExpr -> CoreExpr -> TranslateH z CoreExpr
    rollingRuleResult ty f g = do x <- constT (newIdH "x" ty)
                                  mkFixT (Lam x (App f (App g (Var x))))

    -- True when the variable occurs free in none of the given expressions.
    notFreeIn :: Id -> [CoreExpr] -> Bool
    notFreeIn v = not . any (elemVarSet v . exprFreeVars)

    boundVarEscapes :: Id -> String
    boundVarEscapes v = "the fixed-point binder " ++ var2String v ++ " occurs free in a function that would be moved out of its scope."

    wrongFixBody :: String
    wrongFixBody = "body of fix does not have the form: Lam v (App f (App g (Var v)))"

--------------------------------------------------------------------------------------------------

-- f :: A -> B
-- g :: A -> A
-- h :: B -> B

-- | If @f@ is strict, then (@f (g a)@ == @h (f a)@) ==\> (@f (fix g)@ == @fix h@)
--
-- The optional arguments supply the proof obligations: a pair of rewrites that
-- must bring @f (g a)@ and @h (f a)@ to a common target, and a rewrite used to
-- verify that @f@ is strict. A 'Nothing' skips the corresponding check.
fixFusionRuleBR :: Maybe (RewriteH CoreExpr, RewriteH CoreExpr) -> Maybe (RewriteH CoreExpr) -> CoreExpr -> CoreExpr -> CoreExpr -> BiRewriteH CoreExpr
fixFusionRuleBR meq mfstrict f g h =
    beforeBiR
      (prefixFailMsg "fixed-point fusion failed: " $
         do (tyA,tyB) <- funArgResTypes f
            tyA' <- endoFunType g
            tyB' <- endoFunType h
            guardMsg (typeAlphaEq tyA tyA' && typeAlphaEq tyB tyB') "given functions do not have compatible types."
            whenJust (verifyStrictT f) mfstrict
            -- Check f (g a) == h (f a) for a fresh variable a :: A.
            whenJust (\ (lhsR,rhsR) ->
                        do a <- constT (newGlobalIdH "a" tyA)
                           let lhs = App f (App g (Var a))
                               rhs = App h (App f (Var a))
                           verifyEqualityCommonTargetT lhs rhs lhsR rhsR
                     ) meq
      )
      (\ () -> bidirectional fixFusionL fixFusionR)
  where
    -- f (fix g)  ==>  fix h
    fixFusionL :: RewriteH CoreExpr
    fixFusionL = prefixFailMsg "fixed-point fusion failed: " $
                 withPatFailMsg (wrongExprForm "App f (fix g)") $
                 do App f' fixg <- idR
                    guardMsg (exprAlphaEq f f') "first argument function does not match."
                    (_,g') <- isFixExprT <<< constant fixg
                    guardMsg (exprAlphaEq g g') "second argument function does not match."
                    mkFixT h

    -- fix h  ==>  f (fix g)
    fixFusionR :: RewriteH CoreExpr
    fixFusionR = prefixFailMsg "(reversed) fixed-point fusion failed: " $
                 do (_,h') <- isFixExprT
                    guardMsg (exprAlphaEq h h') "third argument function does not match."
                    App f <$> mkFixT g

-- | If @f@ is strict, then (@f (g a)@ == @h (f a)@) ==\> (@f (fix g)@ == @fix h@)
--
-- Front end to 'fixFusionRuleBR' that takes @f@, @g@ and @h@ as 'CoreString's
-- and the proof rewrites at type 'Core', narrowing them to 'CoreExpr'.
fixFusionRule :: Maybe (RewriteH Core, RewriteH Core) -> Maybe (RewriteH Core) -> CoreString -> CoreString -> CoreString -> BiRewriteH CoreExpr
fixFusionRule meq mfstrict = parse3beforeBiR $
    fixFusionRuleBR ((extractR *** extractR) <$> meq) (extractR <$> mfstrict)

--------------------------------------------------------------------------------------------------

-- | Check that the expression has the form "fix t (f :: t -> t)", returning "t" and "f".
isFixExprT :: TranslateH CoreExpr (Type,CoreExpr)
isFixExprT = withPatFailMsg (wrongExprForm "fix t f") $
    -- fix :: forall a. (a -> a) -> a
    do App (App (Var fixId) (Type ty)) f <- idR
       fixId' <- findFixId
       guardMsg (fixId == fixId') (var2String fixId ++ " does not match " ++ fixLocation)
       return (ty,f)

--------------------------------------------------------------------------------------------------

-- | f ==\> fix f
--
-- Builds @fix t f@, where @t@ is taken from the endofunction type of @f@.
mkFixT :: (BoundVars c, HasGlobalRdrEnv c, MonadCatch m, HasDynFlags m, MonadThings m) => CoreExpr -> Translate c m z CoreExpr
mkFixT f = do t <- endoFunType f
              fixId <- findFixId
              return $ mkCoreApps (varToCoreExpr fixId) [Type t, f]

-- | Fully qualified name used to look up the fixed-point operator.
fixLocation :: String
fixLocation = "Data.Function.fix"

-- TODO: this lookup will not find 'fix' unless it is used (or explicitly imported) in the source file.
findFixId :: (BoundVars c, HasGlobalRdrEnv c, MonadCatch m, HasDynFlags m, MonadThings m) => Translate c m a Id
findFixId = findIdT (TH.mkName fixLocation)

--------------------------------------------------------------------------------------------------