module HERMIT.Dictionary.WorkerWrapper.Fix ( -- * The Worker/Wrapper Transformation -- | Note that many of these operations require 'Data.Function.fix' to be in scope. HERMIT.Dictionary.WorkerWrapper.Fix.externals , wwFacBR , wwSplitR , wwSplitStaticArg , wwGenerateFusionT , wwFusionBR , wwAssA , wwAssB , wwAssC ) where import Control.Arrow import Data.String (fromString) import HERMIT.Core import HERMIT.External import HERMIT.GHC import HERMIT.Kure import HERMIT.Lemma import HERMIT.Monad import HERMIT.Name import HERMIT.ParserCore import HERMIT.Utilities import HERMIT.Dictionary.AlphaConversion import HERMIT.Dictionary.Common import HERMIT.Dictionary.FixPoint import HERMIT.Dictionary.Function import HERMIT.Dictionary.Local import HERMIT.Dictionary.Navigation import HERMIT.Dictionary.Reasoning import HERMIT.Dictionary.Unfold import HERMIT.Dictionary.WorkerWrapper.Common -------------------------------------------------------------------------------------------------- -- | Externals for manipulating fixed points, and for the worker/wrapper transformation. 
externals :: [External]
externals =
    [ external "ww-factorisation" ((\ wrap unwrap assC -> promoteExprBiR $ wwFac (mkWWAssC assC) wrap unwrap)
                                      :: CoreString -> CoreString -> RewriteH LCore -> BiRewriteH LCore)
        [ "Worker/Wrapper Factorisation"
        , "For any \"f :: A -> A\", and given \"wrap :: B -> A\" and \"unwrap :: A -> B\" as arguments,"
        , "and a proof of Assumption C (fix A (\\ a -> wrap (unwrap (f a))) ==> fix A f), then"
        , "fix A f ==> wrap (fix B (\\ b -> unwrap (f (wrap b))))"
        ] .+ Introduce .+ Context
    , external "ww-factorisation-unsafe" ((\ wrap unwrap -> promoteExprBiR $ wwFac Nothing wrap unwrap)
                                             :: CoreString -> CoreString -> BiRewriteH LCore)
        [ "Unsafe Worker/Wrapper Factorisation"
        , "For any \"f :: A -> A\", and given \"wrap :: B -> A\" and \"unwrap :: A -> B\" as arguments, then"
        , "fix A f <==> wrap (fix B (\\ b -> unwrap (f (wrap b))))"
        , "Note: the pre-condition \"fix A (\\ a -> wrap (unwrap (f a))) == fix A f\" is expected to hold."
        ] .+ Introduce .+ Context .+ PreCondition
    , external "ww-split" ((\ wrap unwrap assC -> promoteDefR $ wwSplit (mkWWAssC assC) wrap unwrap)
                              :: CoreString -> CoreString -> RewriteH LCore -> RewriteH LCore)
        [ "Worker/Wrapper Split"
        , "For any \"prog :: A\", and given \"wrap :: B -> A\" and \"unwrap :: A -> B\" as arguments,"
        , "and a proof of Assumption C (fix A (\\ a -> wrap (unwrap (f a))) ==> fix A f), then"
        , "prog = expr ==> prog = let f = \\ prog -> expr"
        , " in let work = unwrap (f (wrap work))"
        , " in wrap work"
        ] .+ Introduce .+ Context
    , external "ww-split-unsafe" ((\ wrap unwrap -> promoteDefR $ wwSplit Nothing wrap unwrap)
                                     :: CoreString -> CoreString -> RewriteH LCore)
        [ "Unsafe Worker/Wrapper Split"
        , "For any \"prog :: A\", and given \"wrap :: B -> A\" and \"unwrap :: A -> B\" as arguments, then"
        , "prog = expr ==> prog = let f = \\ prog -> expr"
        , " in let work = unwrap (f (wrap work))"
        , " in wrap work"
        , "Note: the pre-condition \"fix A (wrap . unwrap . f) == fix A f\" is expected to hold."
        ] .+ Introduce .+ Context .+ PreCondition
    , external "ww-split-static-arg" ((\ n is wrap unwrap assC -> promoteDefR $ wwSplitStaticArg n is (mkWWAssC assC) wrap unwrap)
                                         :: Int -> [Int] -> CoreString -> CoreString -> RewriteH LCore -> RewriteH LCore)
        [ "Worker/Wrapper Split - Static Argument Variant"
        , "Perform the static argument transformation on the first n arguments, then perform the worker/wrapper split,"
        , "applying the given wrap and unwrap functions to the specified (by index) static arguments before use."
        ] .+ Introduce .+ Context
    , external "ww-split-static-arg-unsafe" ((\ n is wrap unwrap -> promoteDefR $ wwSplitStaticArg n is Nothing wrap unwrap)
                                                :: Int -> [Int] -> CoreString -> CoreString -> RewriteH LCore)
        [ "Unsafe Worker/Wrapper Split - Static Argument Variant"
        , "Perform the static argument transformation on the first n arguments, then perform the (unsafe) worker/wrapper split,"
        , "applying the given wrap and unwrap functions to the specified (by index) static arguments before use."
        ] .+ Introduce .+ Context .+ PreCondition
    , external "ww-assumption-A" ((\ wrap unwrap assA -> promoteExprBiR $ wwA (Just $ extractR assA) wrap unwrap)
                                     :: CoreString -> CoreString -> RewriteH LCore -> BiRewriteH LCore)
        [ "Worker/Wrapper Assumption A"
        , "For a \"wrap :: B -> A\" and an \"unwrap :: A -> B\","
        , "and given a proof of \"wrap (unwrap a) ==> a\", then"
        , "wrap (unwrap a) <==> a"
        ] .+ Introduce .+ Context
    , external "ww-assumption-B" ((\ wrap unwrap f assB -> promoteExprBiR $ wwB (Just $ extractR assB) wrap unwrap f)
                                     :: CoreString -> CoreString -> CoreString -> RewriteH LCore -> BiRewriteH LCore)
        [ "Worker/Wrapper Assumption B"
        , "For a \"wrap :: B -> A\", an \"unwrap :: A -> B\", and an \"f :: A -> A\","
        , "and given a proof of \"wrap (unwrap (f a)) ==> f a\", then"
        , "wrap (unwrap (f a)) <==> f a"
        ] .+ Introduce .+ Context
    , external "ww-assumption-C" ((\ wrap unwrap f assC -> promoteExprBiR $ wwC (Just $ extractR assC) wrap unwrap f)
                                     :: CoreString -> CoreString -> CoreString -> RewriteH LCore -> BiRewriteH LCore)
        [ "Worker/Wrapper Assumption C"
        , "For a \"wrap :: B -> A\", an \"unwrap :: A -> B\", and an \"f :: A -> A\","
        , "and given a proof of \"fix A (\\ a -> wrap (unwrap (f a))) ==> fix A f\", then"
        , "fix A (\\ a -> wrap (unwrap (f a))) <==> fix A f"
        ] .+ Introduce .+ Context
    , external "ww-assumption-A-unsafe" ((\ wrap unwrap -> promoteExprBiR $ wwA Nothing wrap unwrap)
                                            :: CoreString -> CoreString -> BiRewriteH LCore)
        [ "Unsafe Worker/Wrapper Assumption A"
        , "For a \"wrap :: B -> A\" and an \"unwrap :: A -> B\", then"
        , "wrap (unwrap a) <==> a"
        , "Note: only use this if it's true!"
        ] .+ Introduce .+ Context .+ PreCondition
    , external "ww-assumption-B-unsafe" ((\ wrap unwrap f -> promoteExprBiR $ wwB Nothing wrap unwrap f)
                                            :: CoreString -> CoreString -> CoreString -> BiRewriteH LCore)
        [ "Unsafe Worker/Wrapper Assumption B"
        , "For a \"wrap :: B -> A\", an \"unwrap :: A -> B\", and an \"f :: A -> A\", then"
        , "wrap (unwrap (f a)) <==> f a"
        , "Note: only use this if it's true!"
        ] .+ Introduce .+ Context .+ PreCondition
    , external "ww-assumption-C-unsafe" ((\ wrap unwrap f -> promoteExprBiR $ wwC Nothing wrap unwrap f)
                                            :: CoreString -> CoreString -> CoreString -> BiRewriteH LCore)
        [ "Unsafe Worker/Wrapper Assumption C"
        , "For a \"wrap :: B -> A\", an \"unwrap :: A -> B\", and an \"f :: A -> A\", then"
        , "fix A (\\ a -> wrap (unwrap (f a))) <==> fix A f"
        , "Note: only use this if it's true!"
        ] .+ Introduce .+ Context .+ PreCondition
    , external "ww-AssA-to-AssB" (promoteExprR . wwAssAimpliesAssB . extractR :: RewriteH LCore -> RewriteH LCore)
        [ "Convert a proof of worker/wrapper Assumption A into a proof of worker/wrapper Assumption B."
        ]
    , external "ww-AssB-to-AssC" (promoteExprR . wwAssBimpliesAssC . extractR :: RewriteH LCore -> RewriteH LCore)
        [ "Convert a proof of worker/wrapper Assumption B into a proof of worker/wrapper Assumption C."
        ]
    , external "ww-AssA-to-AssC" (promoteExprR . wwAssAimpliesAssC . extractR :: RewriteH LCore -> RewriteH LCore)
        [ "Convert a proof of worker/wrapper Assumption A into a proof of worker/wrapper Assumption C."
        ]
    , external "ww-generate-fusion" (wwGenerateFusionT . mkWWAssC :: RewriteH LCore -> TransformH LCore ())
        [ "Given a proof of Assumption C (fix A (\\ a -> wrap (unwrap (f a))) ==> fix A f), then"
        , "execute this command on \"work = unwrap (f (wrap work))\" to enable the \"ww-fusion\" rule thereafter."
        , "Note that this is performed automatically as part of \"ww-split\"."
        ] .+ Experiment .+ TODO
    , external "ww-generate-fusion-unsafe" (wwGenerateFusionT Nothing :: TransformH LCore ())
        [ "Execute this command on \"work = unwrap (f (wrap work))\" to enable the \"ww-fusion\" rule thereafter."
        , "The precondition \"fix A (wrap . unwrap . f) == fix A f\" is expected to hold."
        , "Note that this is performed automatically as part of \"ww-split\"."
        ] .+ Experiment .+ TODO
    , external "ww-fusion" (promoteExprBiR wwFusion :: BiRewriteH LCore)
        [ "Worker/Wrapper Fusion"
        , "unwrap (wrap work) <==> work"
        , "Note: you are required to have previously executed the command \"ww-generate-fusion\" on the definition"
        , " work = unwrap (f (wrap work))"
        ] .+ Introduce .+ Context .+ PreCondition .+ TODO
    ]
  where
    -- Package a user-supplied rewrite as a proof of Assumption C.
    mkWWAssC :: RewriteH LCore -> Maybe WWAssumption
    mkWWAssC r = Just (WWAssumption C (extractR r))

--------------------------------------------------------------------------------------------------

-- | For any @f :: A -> A@, and given @wrap :: B -> A@ and @unwrap :: A -> B@ as arguments, then
--   @fix A f@  \<==\>  @wrap (fix B (\\ b -> unwrap (f (wrap b))))@
--
--   If a 'WWAssumption' is supplied it is verified (against the @f@ found in the
--   expression) before the rewrite is performed, in either direction.
wwFacBR :: Maybe WWAssumption -> CoreExpr -> CoreExpr -> BiRewriteH CoreExpr
wwFacBR mAss wrap unwrap =
    beforeBiR (wrapUnwrapTypes wrap unwrap)
              (\ (tyA,tyB) -> bidirectional (wwL tyA tyB) wwR)
  where
    -- Left-to-right: @fix A f@ ==> @wrap (fix B (\\ x -> unwrap (f (wrap x))))@.
    wwL :: Type -> Type -> RewriteH CoreExpr
    wwL tyA tyB =
        prefixFailMsg "worker/wrapper factorisation failed: " $
        do (tA,f) <- isFixExprT
           guardMsg (eqType tyA tA) "wrapper/unwrapper types do not match fix body type."
           whenJust (verifyWWAss wrap unwrap f) mAss
           b <- constT (newIdH "x" tyB)
           App wrap <$> buildFixT (Lam b (App unwrap (App f (App wrap (Var b)))))

    -- Right-to-left: recover @fix A f@, checking that the wrap/unwrap found in the
    -- expression agree with the ones supplied by the caller.
    wwR :: RewriteH CoreExpr
    wwR =
        prefixFailMsg "(reverse) worker/wrapper factorisation failed: " $
        withPatFailMsg "not an application." $
        do App wrap2 fx <- idR
           withPatFailMsg wrongFixBody $
             do (_, Lam b (App unwrap1 (App f (App wrap1 (Var b'))))) <- isFixExprT <<< constant fx
                guardMsg (b == b') wrongFixBody
                guardMsg (equivalentBy exprAlphaEq [wrap, wrap1, wrap2]) "wrappers do not match."
                guardMsg (exprAlphaEq unwrap unwrap1) "unwrappers do not match."
                whenJust (verifyWWAss wrap unwrap f) mAss
                buildFixT f

    wrongFixBody :: String
    wrongFixBody = "body of fix does not have the form Lam b (App unwrap (App f (App wrap (Var b))))"

-- | As 'wwFacBR', but @wrap@ and @unwrap@ are given as 'CoreString's and parsed first.
wwFac :: Maybe WWAssumption -> CoreString -> CoreString -> BiRewriteH CoreExpr
wwFac mAss = parse2beforeBiR (wwFacBR mAss)

--------------------------------------------------------------------------------------------------

-- | Worker/wrapper fusion: @unwrap (wrap work)@  \<==\>  @work@
--
--   The @wrap@, @unwrap@ and @work@ expressions are not arguments: they are recovered
--   from the lemma stored under 'workLabel' (which 'wwGenerateFusionT' inserts), which
--   must have the form @work = unwrap (f (wrap work))@.
wwFusionBR :: BiRewriteH CoreExpr
wwFusionBR =
    beforeBiR (prefixFailMsg "worker/wrapper fusion failed: " $
               withPatFailMsg "malformed WW Fusion rule." $
               do Quantified _ (Equiv w (App unwrap (App _f (App wrap w')))) <- constT (lemmaQ <$> findLemma workLabel)
                  guardMsg (exprSyntaxEq w w') "malformed WW Fusion rule."
                  return (wrap,unwrap,w)
              )
              (\ (wrap,unwrap,work) -> bidirectional (fusL wrap unwrap work) (fusR wrap unwrap work))
  where
    -- Left-to-right: @unwrap (wrap work)@ ==> @work@.
    fusL :: CoreExpr -> CoreExpr -> CoreExpr -> RewriteH CoreExpr
    fusL wrap unwrap work =
        prefixFailMsg "worker/wrapper fusion failed: " $
        withPatFailMsg (wrongExprForm "unwrap (wrap work)") $
        do App unwrap' (App wrap' work') <- idR
           guardMsg (exprAlphaEq wrap wrap') "wrapper does not match."
           guardMsg (exprAlphaEq unwrap unwrap') "unwrapper does not match."
           guardMsg (exprAlphaEq work work') "worker does not match."
           return work

    -- Right-to-left: @work@ ==> @unwrap (wrap work)@.
    fusR :: CoreExpr -> CoreExpr -> CoreExpr -> RewriteH CoreExpr
    fusR wrap unwrap work =
        prefixFailMsg "(reverse) worker/wrapper fusion failed: " $
        do work' <- idR
           guardMsg (exprAlphaEq work work') "worker does not match."
           return $ App unwrap (App wrap work)

-- | Synonym for 'wwFusionBR': @unwrap (wrap work)@  \<==\>  @work@
wwFusion :: BiRewriteH CoreExpr
wwFusion = wwFusionBR

--------------------------------------------------------------------------------------------------

-- | Record the recursive definition of work as a lemma (under 'workLabel'), so that
--   'wwFusionBR' can later look it up.
--   Must be applied to a definition of the form: @work = unwrap (f (wrap work))@
--   If a 'WWAssumption' is supplied it is verified before the lemma is inserted.
--   Note that this is performed automatically as part of 'wwSplitR'.
wwGenerateFusionT :: Maybe WWAssumption -> TransformH LCore ()
wwGenerateFusionT mAss =
    prefixFailMsg "generate WW fusion failed: " $
    withPatFailMsg wrongForm $
    do Def w e@(App unwrap (App f (App wrap (Var w')))) <- projectT
       guardMsg (w == w') wrongForm
       whenJust (verifyWWAss wrap unwrap f) mAss
       insertLemmaT workLabel $ Lemma (Quantified [] (Equiv (varToCoreExpr w) e)) Proven NotUsed False
  where
    wrongForm = "definition does not have the form: work = unwrap (f (wrap work))"

--------------------------------------------------------------------------------------------------

-- | \\ wrap unwrap ->  (@prog = expr@  ==>  @prog = let f = \\ prog -> expr in let work = unwrap (f (wrap work)) in wrap work@)
wwSplitR :: Maybe WWAssumption -> CoreExpr -> CoreExpr -> RewriteH CoreDef
wwSplitR mAss wrap unwrap =
        fixIntroRecR
    >>> defAllR idR ( appAllR idR (letIntroR "f")
                  >>> letFloatArgR
                  >>> letAllR idR ( forwardT (wwFacBR mAss wrap unwrap)
                                >>> appAllR idR ( unfoldNameR (fromString "Data.Function.fix")
                                              >>> alphaLetWithR ["work"]
                                              >>> letRecAllR (\ _ -> defAllR idR (betaReduceR >>> letNonRecSubstR)
                                                                 -- run the fusion generator for its effect only
                                                                 >>> (extractT (wwGenerateFusionT mAss) >> idR)
                                                             )
                                                             idR
                                                )
                                >>> letFloatArgR
                                  )
                    )

-- | \\ wrap unwrap ->  (@prog = expr@  ==>  @prog = let f = \\ prog -> expr in let work = unwrap (f (wrap work)) in wrap work@)
--
--   As 'wwSplitR', but @wrap@ and @unwrap@ are given as 'CoreString's and parsed first.
wwSplit :: Maybe WWAssumption -> CoreString -> CoreString -> RewriteH CoreDef
wwSplit mAss wrapS unwrapS = (parseCoreExprT wrapS &&& parseCoreExprT unwrapS) >>= uncurry (wwSplitR mAss)

-- | As 'wwSplit' but performs the static-argument transformation for @n@ static arguments first,
--   and optionally provides some of those arguments (specified by index) to all calls of wrap and unwrap.
--   This is useful if, for example, the expression, and wrap and unwrap, all have a @forall@ type.
--
--   When @n@ is @0@ this is exactly 'wwSplit' and the index list is ignored.
--   Otherwise every index must lie in the range @0 .. n-1@; any index outside that
--   range (including a negative one) causes the rewrite to fail.
wwSplitStaticArg :: Int -> [Int] -> Maybe WWAssumption -> CoreString -> CoreString -> RewriteH CoreDef
wwSplitStaticArg 0 _  = wwSplit
wwSplitStaticArg n is = \ mAss wrapS unwrapS ->
    prefixFailMsg "worker/wrapper split (static argument variant) failed: " $
    do -- Check both bounds: an index below zero would otherwise pass the check and then
       -- be silently dropped when selecting binders below.
       guardMsg (all validIndex is) "arguments for wrap and unwrap must be chosen from the statically transformed arguments."
       bs <- defT successT (arr collectBinders) (\ () -> take n . fst)
       let args = varsToCoreExprs [ b | (i,b) <- zip [0..] bs, i `elem` is ]
       staticArgPosR [0..(n-1)] >>> defAllR idR
           (let wwSplitArgsR :: RewriteH CoreDef
                wwSplitArgsR = do wrap   <- parseCoreExprT wrapS
                                  unwrap <- parseCoreExprT unwrapS
                                  wwSplitR mAss (mkCoreApps wrap args) (mkCoreApps unwrap args)
             in
                extractR $ do p <- considerConstructT LetExpr
                              localPathR p $ promoteExprR (letRecAllR (const wwSplitArgsR) idR >>> letSubstR)
           )
  where
    -- An index is valid if it names one of the first @n@ binders.
    validIndex :: Int -> Bool
    validIndex i = 0 <= i && i < n

--------------------------------------------------------------------------------------------------

-- | Convert a proof of WW Assumption A into a proof of WW Assumption B.
wwAssAimpliesAssB :: RewriteH CoreExpr -> RewriteH CoreExpr
wwAssAimpliesAssB = id

-- | Convert a proof of WW Assumption B into a proof of WW Assumption C.
wwAssBimpliesAssC :: RewriteH CoreExpr -> RewriteH CoreExpr
wwAssBimpliesAssC assB = appAllR idR (lamAllR idR assB >>> etaReduceR)

-- | Convert a proof of WW Assumption A into a proof of WW Assumption C.
wwAssAimpliesAssC :: RewriteH CoreExpr -> RewriteH CoreExpr
wwAssAimpliesAssC assA = wwAssBimpliesAssC (wwAssAimpliesAssB assA)

--------------------------------------------------------------------------------------------------

-- | Assumption A as a bidirectional rewrite: @wrap (unwrap a)@  \<==\>  @a@
--
--   When a proof is supplied it is verified first; the types of @wrap@ and @unwrap@
--   are always checked.
wwAssA :: Maybe (RewriteH CoreExpr) -- ^ WW Assumption A
       -> CoreExpr                  -- ^ wrap
       -> CoreExpr                  -- ^ unwrap
       -> BiRewriteH CoreExpr
wwAssA mr wrap unwrap = beforeBiR precondition mkRewrites
  where
    -- Verify the optional proof, then compute the (A,B) types of the wrap/unwrap pair.
    precondition = whenJust (verifyAssA wrap unwrap) mr >> wrapUnwrapTypes wrap unwrap

    mkRewrites (tyA, _) = bidirectional stripWrapUnwrap (addWrapUnwrap tyA)

    -- @wrap (unwrap x)@ ==> @x@
    stripWrapUnwrap :: RewriteH CoreExpr
    stripWrapUnwrap =
        withPatFailMsg (wrongExprForm "App wrap (App unwrap x)") $ do
            App w (App u inner) <- idR
            guardMsg (exprAlphaEq wrap w) "given wrapper does not match wrapper in expression."
            guardMsg (exprAlphaEq unwrap u) "given unwrapper does not match unwrapper in expression."
            return inner

    -- @x@ ==> @wrap (unwrap x)@, provided @x@ has type A
    addWrapUnwrap :: Type -> RewriteH CoreExpr
    addWrapUnwrap tyA = do
        e <- idR
        guardMsg (eqType (exprKindOrType e) tyA) "type of expression does not match types of wrap/unwrap."
        return (App wrap (App unwrap e))

-- | As 'wwAssA', with @wrap@ and @unwrap@ supplied as 'CoreString's.
--   @wrap (unwrap a)@  \<==\>  @a@
wwA :: Maybe (RewriteH CoreExpr) -- ^ WW Assumption A
    -> CoreString                -- ^ wrap
    -> CoreString                -- ^ unwrap
    -> BiRewriteH CoreExpr
wwA = parse2beforeBiR . wwAssA

-- | Assumption B as a bidirectional rewrite: @wrap (unwrap (f a))@  \<==\>  @f a@
--
--   When a proof is supplied it is verified first. The rewrite itself is the
--   (unverified) Assumption A rewrite, restricted to expressions headed by @f@.
wwAssB :: Maybe (RewriteH CoreExpr) -- ^ WW Assumption B
       -> CoreExpr                  -- ^ wrap
       -> CoreExpr                  -- ^ unwrap
       -> CoreExpr                  -- ^ f
       -> BiRewriteH CoreExpr
wwAssB mr wrap unwrap f = beforeBiR (whenJust (verifyAssB wrap unwrap f) mr) mkRewrites
  where
    mkRewrites () = bidirectional leftToRight rightToLeft

    unverifiedA :: BiRewriteH CoreExpr
    unverifiedA = wwAssA Nothing wrap unwrap

    -- @wrap (unwrap (f a))@ ==> @f a@
    leftToRight :: RewriteH CoreExpr
    leftToRight =
        withPatFailMsg (wrongExprForm "App wrap (App unwrap (App f a))") $ do
            App _ (App _ (App g _)) <- idR
            guardMsg (exprAlphaEq f g) "given body function does not match expression."
            forwardT unverifiedA

    -- @f a@ ==> @wrap (unwrap (f a))@
    rightToLeft :: RewriteH CoreExpr
    rightToLeft =
        withPatFailMsg (wrongExprForm "App f a") $ do
            App g _ <- idR
            guardMsg (exprAlphaEq f g) "given body function does not match expression."
            backwardT unverifiedA

-- | As 'wwAssB', with @wrap@, @unwrap@ and @f@ supplied as 'CoreString's.
--   @wrap (unwrap (f a))@  \<==\>  @f a@
wwB :: Maybe (RewriteH CoreExpr) -- ^ WW Assumption B
    -> CoreString                -- ^ wrap
    -> CoreString                -- ^ unwrap
    -> CoreString                -- ^ f
    -> BiRewriteH CoreExpr
wwB = parse3beforeBiR . wwAssB

-- | Assumption C as a bidirectional rewrite:
--   @fix A (\\ a -> wrap (unwrap (f a)))@  \<==\>  @fix A f@
--
--   The focus must be a @fix@ expression; when a proof is supplied it is verified first.
wwAssC :: Maybe (RewriteH CoreExpr) -- ^ WW Assumption C
       -> CoreExpr                  -- ^ wrap
       -> CoreExpr                  -- ^ unwrap
       -> CoreExpr                  -- ^ f
       -> BiRewriteH CoreExpr
wwAssC mr wrap unwrap f = beforeBiR precondition mkRewrites
  where
    -- Require a fix expression (its components are not needed), then verify the optional proof.
    precondition = isFixExprT >> whenJust (verifyAssC wrap unwrap f) mr

    mkRewrites () = bidirectional leftToRight rightToLeft

    unverifiedB :: BiRewriteH CoreExpr
    unverifiedB = wwAssB Nothing wrap unwrap f

    leftToRight :: RewriteH CoreExpr
    leftToRight = wwAssBimpliesAssC (forwardT unverifiedB)

    -- Eta-expand the fix body, then apply Assumption B backwards under the lambda.
    rightToLeft :: RewriteH CoreExpr
    rightToLeft = appAllR idR (etaExpandR "a" >>> lamAllR idR (backwardT unverifiedB))

-- | As 'wwAssC', with @wrap@, @unwrap@ and @f@ supplied as 'CoreString's.
--   @fix A (\\ a -> wrap (unwrap (f a)))@  \<==\>  @fix A f@
wwC :: Maybe (RewriteH CoreExpr) -- ^ WW Assumption C
    -> CoreString                -- ^ wrap
    -> CoreString                -- ^ unwrap
    -> CoreString                -- ^ f
    -> BiRewriteH CoreExpr
wwC = parse3beforeBiR . wwAssC

--------------------------------------------------------------------------------------------------

-- | Verify a tagged worker/wrapper assumption by dispatching on its tag
--   to the matching verifier.
verifyWWAss :: CoreExpr        -- ^ wrap
            -> CoreExpr        -- ^ unwrap
            -> CoreExpr        -- ^ f
            -> WWAssumption
            -> TransformH x ()
verifyWWAss wrap unwrap _ (WWAssumption A ass) = verifyAssA wrap unwrap ass
verifyWWAss wrap unwrap f (WWAssumption B ass) = verifyAssB wrap unwrap f ass
verifyWWAss wrap unwrap f (WWAssumption C ass) = verifyAssC wrap unwrap f ass

-- | Verify a proof of Assumption A for the given wrap/unwrap pair.
verifyAssA :: CoreExpr          -- ^ wrap
           -> CoreExpr          -- ^ unwrap
           -> RewriteH CoreExpr -- ^ WW Assumption A
           -> TransformH x ()
verifyAssA wrap unwrap assA =
    prefixFailMsg "verification of worker/wrapper Assumption A failed: " $
        -- The type check is redundant, but yields a better error message.
        wrapUnwrapTypes wrap unwrap >> verifyRetractionT wrap unwrap assA

-- | Verify a proof of Assumption B: the rewrite must take
--   @wrap (unwrap (f a))@ to @f a@ for a fresh @a :: A@.
verifyAssB :: CoreExpr          -- ^ wrap
           -> CoreExpr          -- ^ unwrap
           -> CoreExpr          -- ^ f
           -> RewriteH CoreExpr -- ^ WW Assumption B
           -> TransformH x ()
verifyAssB wrap unwrap f assB =
    prefixFailMsg "verification of worker/wrapper assumption B failed: " $ do
        (tyA, _) <- wrapUnwrapTypes wrap unwrap
        a        <- constT (newIdH "a" tyA)
        let fa = App f (Var a)
        verifyEqualityLeftToRightT (App wrap (App unwrap fa)) fa assB

-- | Verify a proof of Assumption C: the rewrite must take
--   @fix A (\\ a -> wrap (unwrap (f a)))@ to @fix A f@.
verifyAssC :: CoreExpr          -- ^ wrap
           -> CoreExpr          -- ^ unwrap
           -> CoreExpr          -- ^ f
           -> RewriteH CoreExpr -- ^ WW Assumption C
           -> TransformH a ()
verifyAssC wrap unwrap f assC =
    prefixFailMsg "verification of worker/wrapper assumption C failed: " $ do
        (tyA, _) <- wrapUnwrapTypes wrap unwrap
        a        <- constT (newIdH "a" tyA)
        fixF     <- buildFixT f
        fixWUF   <- buildFixT (Lam a (App wrap (App unwrap (App f (Var a)))))
        verifyEqualityLeftToRightT fixWUF fixF assC

--------------------------------------------------------------------------------------------------

-- | Compute the pair of types @(A,B)@ for @wrap :: B -> A@ and @unwrap :: A -> B@,
--   failing with a fixed message if the two expressions do not have inverse function types.
wrapUnwrapTypes :: MonadCatch m => CoreExpr -> CoreExpr -> m (Type,Type)
wrapUnwrapTypes wrap unwrap = setFailMsg badTypes (funExprsWithInverseTypes unwrap wrap)
  where
    badTypes = "given expressions have the wrong types to form a valid wrap/unwrap pair."

--------------------------------------------------------------------------------------------------