module HERMIT.Dictionary.WorkerWrapper.Fix
(
HERMIT.Dictionary.WorkerWrapper.Fix.externals
, wwFacBR
, wwSplitR
, wwSplitStaticArg
, wwGenerateFusionT
, wwFusionBR
, wwAssA
, wwAssB
, wwAssC
) where
import Control.Arrow
import Data.String (fromString)
import HERMIT.Core
import HERMIT.External
import HERMIT.GHC
import HERMIT.Kure hiding ((<$>))
import HERMIT.Lemma
import HERMIT.Monad
import HERMIT.Name
import HERMIT.ParserCore
import HERMIT.Utilities
import HERMIT.Dictionary.AlphaConversion
import HERMIT.Dictionary.Common
import HERMIT.Dictionary.FixPoint
import HERMIT.Dictionary.Function
import HERMIT.Dictionary.Local
import HERMIT.Dictionary.Navigation
import HERMIT.Dictionary.Reasoning
import HERMIT.Dictionary.Unfold
import HERMIT.Dictionary.WorkerWrapper.Common
import Prelude.Compat
-- | The worker/wrapper (fixed-point formulation) commands made available as 'External's.
--
-- Most commands come in two flavours: a checked one, which takes a rewrite that
-- serves as the proof of the relevant assumption, and an @-unsafe@ one, which
-- supplies 'Nothing' in place of that proof.
externals :: [External]
externals =
    [ external "ww-factorisation"
        ((\ wrapS unwrapS proofC -> promoteExprBiR (wwFac (assumptionC proofC) wrapS unwrapS))
            :: CoreString -> CoreString -> RewriteH LCore -> BiRewriteH LCore)
        [ "Worker/Wrapper Factorisation"
        , "For any \"f :: A -> A\", and given \"wrap :: B -> A\" and \"unwrap :: A -> B\" as arguments,"
        , "and a proof of Assumption C (fix A (\\ a -> wrap (unwrap (f a))) ==> fix A f), then"
        , "fix A f ==> wrap (fix B (\\ b -> unwrap (f (wrap b))))"
        ] .+ Introduce .+ Context
    , external "ww-factorisation-unsafe"
        ((\ wrapS unwrapS -> promoteExprBiR (wwFac Nothing wrapS unwrapS))
            :: CoreString -> CoreString -> BiRewriteH LCore)
        [ "Unsafe Worker/Wrapper Factorisation"
        , "For any \"f :: A -> A\", and given \"wrap :: B -> A\" and \"unwrap :: A -> B\" as arguments, then"
        , "fix A f <==> wrap (fix B (\\ b -> unwrap (f (wrap b))))"
        , "Note: the pre-condition \"fix A (\\ a -> wrap (unwrap (f a))) == fix A f\" is expected to hold."
        ] .+ Introduce .+ Context .+ PreCondition
    , external "ww-split"
        ((\ wrapS unwrapS proofC -> promoteDefR (wwSplit (assumptionC proofC) wrapS unwrapS))
            :: CoreString -> CoreString -> RewriteH LCore -> RewriteH LCore)
        [ "Worker/Wrapper Split"
        , "For any \"prog :: A\", and given \"wrap :: B -> A\" and \"unwrap :: A -> B\" as arguments,"
        , "and a proof of Assumption C (fix A (\\ a -> wrap (unwrap (f a))) ==> fix A f), then"
        , "prog = expr ==> prog = let f = \\ prog -> expr"
        , " in let work = unwrap (f (wrap work))"
        , " in wrap work"
        ] .+ Introduce .+ Context
    , external "ww-split-unsafe"
        ((\ wrapS unwrapS -> promoteDefR (wwSplit Nothing wrapS unwrapS))
            :: CoreString -> CoreString -> RewriteH LCore)
        [ "Unsafe Worker/Wrapper Split"
        , "For any \"prog :: A\", and given \"wrap :: B -> A\" and \"unwrap :: A -> B\" as arguments, then"
        , "prog = expr ==> prog = let f = \\ prog -> expr"
        , " in let work = unwrap (f (wrap work))"
        , " in wrap work"
        , "Note: the pre-condition \"fix A (wrap . unwrap . f) == fix A f\" is expected to hold."
        ] .+ Introduce .+ Context .+ PreCondition
    , external "ww-split-static-arg"
        ((\ count positions wrapS unwrapS proofC ->
              promoteDefR (wwSplitStaticArg count positions (assumptionC proofC) wrapS unwrapS))
            :: Int -> [Int] -> CoreString -> CoreString -> RewriteH LCore -> RewriteH LCore)
        [ "Worker/Wrapper Split - Static Argument Variant"
        , "Perform the static argument transformation on the first n arguments, then perform the worker/wrapper split,"
        , "applying the given wrap and unwrap functions to the specified (by index) static arguments before use."
        ] .+ Introduce .+ Context
    , external "ww-split-static-arg-unsafe"
        ((\ count positions wrapS unwrapS ->
              promoteDefR (wwSplitStaticArg count positions Nothing wrapS unwrapS))
            :: Int -> [Int] -> CoreString -> CoreString -> RewriteH LCore)
        [ "Unsafe Worker/Wrapper Split - Static Argument Variant"
        , "Perform the static argument transformation on the first n arguments, then perform the (unsafe) worker/wrapper split,"
        , "applying the given wrap and unwrap functions to the specified (by index) static arguments before use."
        ] .+ Introduce .+ Context .+ PreCondition
    , external "ww-assumption-A"
        ((\ wrapS unwrapS proofA -> promoteExprBiR (wwA (Just (extractR proofA)) wrapS unwrapS))
            :: CoreString -> CoreString -> RewriteH LCore -> BiRewriteH LCore)
        [ "Worker/Wrapper Assumption A"
        , "For a \"wrap :: B -> A\" and an \"unwrap :: A -> B\","
        , "and given a proof of \"wrap (unwrap a) ==> a\", then"
        , "wrap (unwrap a) <==> a"
        ] .+ Introduce .+ Context
    , external "ww-assumption-B"
        ((\ wrapS unwrapS bodyS proofB -> promoteExprBiR (wwB (Just (extractR proofB)) wrapS unwrapS bodyS))
            :: CoreString -> CoreString -> CoreString -> RewriteH LCore -> BiRewriteH LCore)
        [ "Worker/Wrapper Assumption B"
        , "For a \"wrap :: B -> A\", an \"unwrap :: A -> B\", and an \"f :: A -> A\","
        , "and given a proof of \"wrap (unwrap (f a)) ==> f a\", then"
        , "wrap (unwrap (f a)) <==> f a"
        ] .+ Introduce .+ Context
    , external "ww-assumption-C"
        ((\ wrapS unwrapS bodyS proofC -> promoteExprBiR (wwC (Just (extractR proofC)) wrapS unwrapS bodyS))
            :: CoreString -> CoreString -> CoreString -> RewriteH LCore -> BiRewriteH LCore)
        [ "Worker/Wrapper Assumption C"
        , "For a \"wrap :: B -> A\", an \"unwrap :: A -> B\", and an \"f :: A -> A\","
        , "and given a proof of \"fix A (\\ a -> wrap (unwrap (f a))) ==> fix A f\", then"
        , "fix A (\\ a -> wrap (unwrap (f a))) <==> fix A f"
        ] .+ Introduce .+ Context
    , external "ww-assumption-A-unsafe"
        ((\ wrapS unwrapS -> promoteExprBiR (wwA Nothing wrapS unwrapS))
            :: CoreString -> CoreString -> BiRewriteH LCore)
        [ "Unsafe Worker/Wrapper Assumption A"
        , "For a \"wrap :: B -> A\" and an \"unwrap :: A -> B\", then"
        , "wrap (unwrap a) <==> a"
        , "Note: only use this if it's true!"
        ] .+ Introduce .+ Context .+ PreCondition
    , external "ww-assumption-B-unsafe"
        ((\ wrapS unwrapS bodyS -> promoteExprBiR (wwB Nothing wrapS unwrapS bodyS))
            :: CoreString -> CoreString -> CoreString -> BiRewriteH LCore)
        [ "Unsafe Worker/Wrapper Assumption B"
        , "For a \"wrap :: B -> A\", an \"unwrap :: A -> B\", and an \"f :: A -> A\", then"
        , "wrap (unwrap (f a)) <==> f a"
        , "Note: only use this if it's true!"
        ] .+ Introduce .+ Context .+ PreCondition
    , external "ww-assumption-C-unsafe"
        ((\ wrapS unwrapS bodyS -> promoteExprBiR (wwC Nothing wrapS unwrapS bodyS))
            :: CoreString -> CoreString -> CoreString -> BiRewriteH LCore)
        [ "Unsafe Worker/Wrapper Assumption C"
        , "For a \"wrap :: B -> A\", an \"unwrap :: A -> B\", and an \"f :: A -> A\", then"
        , "fix A (\\ a -> wrap (unwrap (f a))) <==> fix A f"
        , "Note: only use this if it's true!"
        ] .+ Introduce .+ Context .+ PreCondition
    , external "ww-AssA-to-AssB"
        ((\ proofA -> promoteExprR (wwAssAimpliesAssB (extractR proofA)))
            :: RewriteH LCore -> RewriteH LCore)
        [ "Convert a proof of worker/wrapper Assumption A into a proof of worker/wrapper Assumption B."
        ]
    , external "ww-AssB-to-AssC"
        ((\ proofB -> promoteExprR (wwAssBimpliesAssC (extractR proofB)))
            :: RewriteH LCore -> RewriteH LCore)
        [ "Convert a proof of worker/wrapper Assumption B into a proof of worker/wrapper Assumption C."
        ]
    , external "ww-AssA-to-AssC"
        ((\ proofA -> promoteExprR (wwAssAimpliesAssC (extractR proofA)))
            :: RewriteH LCore -> RewriteH LCore)
        [ "Convert a proof of worker/wrapper Assumption A into a proof of worker/wrapper Assumption C."
        ]
    , external "ww-generate-fusion"
        ((\ proofC -> wwGenerateFusionT (assumptionC proofC))
            :: RewriteH LCore -> TransformH LCore ())
        [ "Given a proof of Assumption C (fix A (\\ a -> wrap (unwrap (f a))) ==> fix A f), then"
        , "execute this command on \"work = unwrap (f (wrap work))\" to enable the \"ww-fusion\" rule thereafter."
        , "Note that this is performed automatically as part of \"ww-split\"."
        ] .+ Experiment .+ TODO
    , external "ww-generate-fusion-unsafe"
        (wwGenerateFusionT Nothing :: TransformH LCore ())
        [ "Execute this command on \"work = unwrap (f (wrap work))\" to enable the \"ww-fusion\" rule thereafter."
        , "The precondition \"fix A (wrap . unwrap . f) == fix A f\" is expected to hold."
        , "Note that this is performed automatically as part of \"ww-split\"."
        ] .+ Experiment .+ TODO
    , external "ww-fusion"
        (promoteExprBiR wwFusion :: BiRewriteH LCore)
        [ "Worker/Wrapper Fusion"
        , "unwrap (wrap work) <==> work"
        , "Note: you are required to have previously executed the command \"ww-generate-fusion\" on the definition"
        , " work = unwrap (f (wrap work))"
        ] .+ Introduce .+ Context .+ PreCondition .+ TODO
    ]
  where
    -- Package a user-supplied rewrite as a proof of Assumption C.
    assumptionC :: RewriteH LCore -> Maybe WWAssumption
    assumptionC = Just . WWAssumption C . extractR
-- | Worker/wrapper factorisation as a bidirectional rewrite.
--
-- Forwards:
--
-- > fix A f  ==>  wrap (fix B (\ x -> unwrap (f (wrap x))))
--
-- Backwards it recognises the right-hand form and rebuilds @fix A f@.
-- The wrap/unwrap types are computed up front with 'wrapUnwrapTypes'.  If an
-- assumption is supplied, each direction verifies it (via 'verifyWWAss')
-- before producing its result; with 'Nothing' no verification is performed.
wwFacBR :: Maybe WWAssumption -> CoreExpr -> CoreExpr -> BiRewriteH CoreExpr
wwFacBR mAss wrap unwrap =
    beforeBiR (wrapUnwrapTypes wrap unwrap) $ \ (tyA, tyB) ->
        bidirectional (factorise tyA tyB) unfactorise
  where
    -- Left-to-right: the focus must be a fix expression at type A.
    factorise :: Type -> Type -> RewriteH CoreExpr
    factorise tyA tyB =
        prefixFailMsg "worker/wrapper factorisation failed: " $ do
            (fixTy, body) <- isFixExprT
            guardMsg (eqType tyA fixTy) "wrapper/unwrapper types do not match fix body type."
            whenJust (verifyWWAss wrap unwrap body) mAss
            x <- constT (newIdH "x" tyB)
            fmap (App wrap) (buildFixT (Lam x (App unwrap (App body (App wrap (Var x))))))

    -- Right-to-left: the focus must be @wrap (fix B (\ b -> unwrap (f (wrap b))))@,
    -- with both wrappers and the unwrapper alpha-equivalent to the given ones.
    unfactorise :: RewriteH CoreExpr
    unfactorise =
        prefixFailMsg "(reverse) worker/wrapper factorisation failed: " $
        withPatFailMsg "not an application." $ do
            App wrapOuter fixExpr <- idR
            withPatFailMsg badFixBody $ do
                (_, Lam b (App unwrapInner (App body (App wrapInner (Var b'))))) <- constant fixExpr >>> isFixExprT
                guardMsg (b == b') badFixBody
                guardMsg (equivalentBy exprAlphaEq [wrap, wrapInner, wrapOuter]) "wrappers do not match."
                guardMsg (exprAlphaEq unwrap unwrapInner) "unwrappers do not match."
                whenJust (verifyWWAss wrap unwrap body) mAss
                buildFixT body

    badFixBody :: String
    badFixBody = "body of fix does not have the form Lam b (App unwrap (App f (App wrap (Var b))))"
-- | 'wwFacBR' taking the wrap and unwrap functions as 'CoreString's, which
-- are handed to 'parse2beforeBiR' to obtain the expressions 'wwFacBR' expects.
wwFac :: Maybe WWAssumption -> CoreString -> CoreString -> BiRewriteH CoreExpr
wwFac = parse2beforeBiR . wwFacBR
-- | Worker/wrapper fusion as a bidirectional rewrite:
--
-- > unwrap (wrap work)  <==>  work
--
-- The @wrap@, @unwrap@ and @work@ expressions are recovered from the lemma
-- stored under 'workLabel', which must have the shape
-- @work === unwrap (f (wrap work))@ (see 'wwGenerateFusionT', which inserts it).
wwFusionBR :: BiRewriteH CoreExpr
wwFusionBR =
    beforeBiR fusionParts $ \ (wrap, unwrap, work) ->
        bidirectional (fuse wrap unwrap work) (unfuse wrap unwrap work)
  where
    -- Look up the stored fusion lemma and take it apart.
    fusionParts :: TransformH CoreExpr (CoreExpr, CoreExpr, CoreExpr)
    fusionParts =
        prefixFailMsg "worker/wrapper fusion failed: " $
        withPatFailMsg "malformed WW Fusion rule." $ do
            Equiv lhsWork (App unwrap (App _f (App wrap rhsWork))) <- constT (fmap lemmaC (findLemma workLabel))
            guardMsg (exprSyntaxEq lhsWork rhsWork) "malformed WW Fusion rule."
            return (wrap, unwrap, lhsWork)

    -- unwrap (wrap work)  ==>  work
    fuse :: CoreExpr -> CoreExpr -> CoreExpr -> RewriteH CoreExpr
    fuse wrap unwrap work =
        prefixFailMsg "worker/wrapper fusion failed: " $
        withPatFailMsg (wrongExprForm "unwrap (wrap work)") $ do
            App foundUnwrap (App foundWrap foundWork) <- idR
            guardMsg (exprAlphaEq wrap foundWrap) "wrapper does not match."
            guardMsg (exprAlphaEq unwrap foundUnwrap) "unwrapper does not match."
            guardMsg (exprAlphaEq work foundWork) "worker does not match."
            return work

    -- work  ==>  unwrap (wrap work)
    unfuse :: CoreExpr -> CoreExpr -> CoreExpr -> RewriteH CoreExpr
    unfuse wrap unwrap work =
        prefixFailMsg "(reverse) worker/wrapper fusion failed: " $ do
            focus <- idR
            guardMsg (exprAlphaEq work focus) "worker does not match."
            return (App unwrap (App wrap work))
-- | Worker/wrapper fusion; a synonym for 'wwFusionBR', used to implement the
-- @ww-fusion@ external.
wwFusion :: BiRewriteH CoreExpr
wwFusion = wwFusionBR
-- | Record the worker/wrapper fusion lemma for a worker definition.
--
-- The focus must be a definition of the form @work = unwrap (f (wrap work))@.
-- If an assumption is supplied it is verified first; then the lemma
-- @work === unwrap (f (wrap work))@ is inserted under 'workLabel', marked
-- 'Proven' and 'NotUsed'.
wwGenerateFusionT :: Maybe WWAssumption -> TransformH LCore ()
wwGenerateFusionT mAss =
    prefixFailMsg "generate WW fusion failed: " . withPatFailMsg badShape $ do
        Def worker rhs@(App unwrap (App f (App wrap (Var worker')))) <- projectT
        -- the recursive occurrence must be the very binder being defined
        guardMsg (worker == worker') badShape
        whenJust (verifyWWAss wrap unwrap f) mAss
        insertLemmaT workLabel (Lemma (Equiv (varToCoreExpr worker) rhs) Proven NotUsed)
  where
    badShape :: String
    badShape = "definition does not have the form: work = unwrap (f (wrap work))"
-- | Worker/wrapper split on a definition, given the wrap and unwrap expressions.
--
-- The definition is first put into fix form ('fixIntroRecR'), the fix body is
-- let-bound as @f@, worker/wrapper factorisation ('wwFacBR', forwards) is
-- applied, and the resulting @fix@ is unfolded into a recursive let whose
-- binder is renamed to @work@.  As a side effect the fusion lemma is generated
-- for the worker definition via 'wwGenerateFusionT'.
wwSplitR :: Maybe WWAssumption -> CoreExpr -> CoreExpr -> RewriteH CoreDef
wwSplitR mAss wrap unwrap = fixIntroRecR >>> defAllR idR splitRhs
  where
    -- Name the fix body @f@, float that let outwards, then work under it.
    splitRhs :: RewriteH CoreExpr
    splitRhs =
            appAllR idR (letIntroR "f")
        >>> letFloatArgR
        >>> letAllR idR factoriseUnderLet

    -- Factorise, unfold the new fix in the argument of wrap, float the let out.
    factoriseUnderLet :: RewriteH CoreExpr
    factoriseUnderLet =
            forwardT (wwFacBR mAss wrap unwrap)
        >>> appAllR idR unfoldFix
        >>> letFloatArgR

    -- Turn @fix B g@ into a recursive let binding @work@.
    unfoldFix :: RewriteH CoreExpr
    unfoldFix =
            unfoldNameR (fromString "Data.Function.fix")
        >>> alphaLetWithR ["work"]
        >>> letRecAllR (\ _ -> tidyWorker) idR

    -- Simplify the worker's right-hand side, then record the fusion lemma
    -- (the definition itself is returned unchanged by that last step).
    tidyWorker :: RewriteH CoreDef
    tidyWorker =
            defAllR idR (betaReduceR >>> letNonRecSubstR)
        >>> (extractT (wwGenerateFusionT mAss) >> idR)
-- | 'wwSplitR' taking the wrap and unwrap functions as 'CoreString's; both
-- are run through 'parseCoreExprT' (wrap first) before the split is applied.
wwSplit :: Maybe WWAssumption -> CoreString -> CoreString -> RewriteH CoreDef
wwSplit mAss wrapS unwrapS = do
    (wrap, unwrap) <- parseCoreExprT wrapS &&& parseCoreExprT unwrapS
    wwSplitR mAss wrap unwrap
-- | Worker/wrapper split, static-argument variant.
--
-- @wwSplitStaticArg n is mAss wrapS unwrapS@ performs the static argument
-- transformation on argument positions @0 .. n-1@ of the definition, then
-- performs the worker/wrapper split on the resulting local recursive
-- definition.  Before use, the parsed wrap and unwrap expressions are applied
-- to those of the first @n@ binders whose (zero-based) index occurs in @is@.
--
-- * When @n@ is @0@ this is exactly 'wwSplit' (the index list is ignored).
-- * Fails if any index in @is@ is not smaller than @n@.
wwSplitStaticArg :: Int -> [Int] -> Maybe WWAssumption -> CoreString -> CoreString -> RewriteH CoreDef
wwSplitStaticArg 0 _  = wwSplit
wwSplitStaticArg n is = \ mAss wrapS unwrapS ->
    prefixFailMsg "worker/wrapper split (static argument variant) failed: " $
    do guardMsg (all (< n) is) "arguments for wrap and unwrap must be chosen from the statically transformed arguments."
       -- the first n lambda binders of the definition's right-hand side
       bs <- defT successT (arr collectBinders) (\ () -> take n . fst)
       let args = varsToCoreExprs [ b | (i,b) <- zip [0..] bs, i `elem` is ]

           -- the ordinary split, with wrap/unwrap pre-applied to the chosen static arguments
           wwSplitArgsR :: RewriteH CoreDef
           wwSplitArgsR = do wrap   <- parseCoreExprT wrapS
                             unwrap <- parseCoreExprT unwrapS
                             wwSplitR mAss (mkCoreApps wrap args) (mkCoreApps unwrap args)

       -- Static positions are 0 .. n-1, matching 'take n' and the guard above.
       staticArgPosR [0 .. n - 1] >>> defAllR idR
           (extractR $ do p <- considerConstructT LetExpr
                          -- split the local recursive definition, then inline the let
                          localPathR p $ promoteExprR (letRecAllR (const wwSplitArgsR) idR >>> letSubstR)
           )
-- | Convert a proof of Assumption A into a proof of Assumption B.
-- The rewrite is returned unchanged.
wwAssAimpliesAssB :: RewriteH CoreExpr -> RewriteH CoreExpr
wwAssAimpliesAssB assA = assA
-- | Convert a proof of Assumption B into a proof of Assumption C.
--
-- The resulting rewrite works on the argument of an application: it applies
-- the Assumption B proof to the body of a lambda there and then eta-reduces.
wwAssBimpliesAssC :: RewriteH CoreExpr -> RewriteH CoreExpr
wwAssBimpliesAssC assB = appAllR idR onArgument
  where
    onArgument :: RewriteH CoreExpr
    onArgument = etaReduceR <<< lamAllR idR assB
-- | Convert a proof of Assumption A into a proof of Assumption C, by way of
-- 'wwAssAimpliesAssB' followed by 'wwAssBimpliesAssC'.
wwAssAimpliesAssC :: RewriteH CoreExpr -> RewriteH CoreExpr
wwAssAimpliesAssC assA = wwAssBimpliesAssC (wwAssAimpliesAssB assA)
-- | Worker/wrapper Assumption A as a bidirectional rewrite:
--
-- > wrap (unwrap a)  <==>  a
--
-- If a proof is supplied it is checked with 'verifyAssA' before anything
-- else; the wrap/unwrap types are then computed with 'wrapUnwrapTypes'.
wwAssA :: Maybe (RewriteH CoreExpr)  -- ^ optional proof of Assumption A
       -> CoreExpr                   -- ^ wrap
       -> CoreExpr                   -- ^ unwrap
       -> BiRewriteH CoreExpr
wwAssA mr wrap unwrap =
    beforeBiR verifyThenTypes $ \ (tyA, _) ->
        bidirectional stripRoundTrip (addRoundTrip tyA)
  where
    verifyThenTypes :: TransformH CoreExpr (Type, Type)
    verifyThenTypes = whenJust (verifyAssA wrap unwrap) mr >> wrapUnwrapTypes wrap unwrap

    -- wrap (unwrap x)  ==>  x
    stripRoundTrip :: RewriteH CoreExpr
    stripRoundTrip =
        withPatFailMsg (wrongExprForm "App wrap (App unwrap x)") $ do
            App foundWrap (App foundUnwrap inner) <- idR
            guardMsg (exprAlphaEq wrap foundWrap) "given wrapper does not match wrapper in expression."
            guardMsg (exprAlphaEq unwrap foundUnwrap) "given unwrapper does not match unwrapper in expression."
            return inner

    -- x  ==>  wrap (unwrap x), provided x has type A
    addRoundTrip :: Type -> RewriteH CoreExpr
    addRoundTrip tyA = do
        focus <- idR
        guardMsg (exprKindOrType focus `eqType` tyA) "type of expression does not match types of wrap/unwrap."
        return (App wrap (App unwrap focus))
-- | 'wwAssA' taking the wrap and unwrap functions as 'CoreString's, which
-- are handed to 'parse2beforeBiR'.
wwA :: Maybe (RewriteH CoreExpr)  -- ^ optional proof of Assumption A
    -> CoreString                 -- ^ wrap
    -> CoreString                 -- ^ unwrap
    -> BiRewriteH CoreExpr
wwA = parse2beforeBiR . wwAssA
-- | Worker/wrapper Assumption B as a bidirectional rewrite:
--
-- > wrap (unwrap (f a))  <==>  f a
--
-- If a proof is supplied it is checked with 'verifyAssB' first.  Each
-- direction checks that the body function in the focus is alpha-equivalent
-- to the given @f@ and then delegates to the (unverified) Assumption A rewrite.
wwAssB :: Maybe (RewriteH CoreExpr)  -- ^ optional proof of Assumption B
       -> CoreExpr                   -- ^ wrap
       -> CoreExpr                   -- ^ unwrap
       -> CoreExpr                   -- ^ f
       -> BiRewriteH CoreExpr
wwAssB mr wrap unwrap f =
    beforeBiR (whenJust (verifyAssB wrap unwrap f) mr) $ \ () ->
        bidirectional stripRoundTrip addRoundTrip
  where
    viaAssA :: BiRewriteH CoreExpr
    viaAssA = wwAssA Nothing wrap unwrap

    -- wrap (unwrap (f a))  ==>  f a
    stripRoundTrip :: RewriteH CoreExpr
    stripRoundTrip =
        withPatFailMsg (wrongExprForm "App wrap (App unwrap (App f a))") $ do
            App _ (App _ (App foundF _)) <- idR
            guardMsg (exprAlphaEq f foundF) "given body function does not match expression."
            forwardT viaAssA

    -- f a  ==>  wrap (unwrap (f a))
    addRoundTrip :: RewriteH CoreExpr
    addRoundTrip =
        withPatFailMsg (wrongExprForm "App f a") $ do
            App foundF _ <- idR
            guardMsg (exprAlphaEq f foundF) "given body function does not match expression."
            backwardT viaAssA
-- | 'wwAssB' taking wrap, unwrap and the body function as 'CoreString's,
-- which are handed to 'parse3beforeBiR'.
wwB :: Maybe (RewriteH CoreExpr)  -- ^ optional proof of Assumption B
    -> CoreString                 -- ^ wrap
    -> CoreString                 -- ^ unwrap
    -> CoreString                 -- ^ f
    -> BiRewriteH CoreExpr
wwB = parse3beforeBiR . wwAssB
-- | Worker/wrapper Assumption C as a bidirectional rewrite:
--
-- > fix A (\ a -> wrap (unwrap (f a)))  <==>  fix A f
--
-- Before either direction runs, 'isFixExprT' must succeed on the focus (its
-- result is discarded) and, if a proof is supplied, it is checked with
-- 'verifyAssC'.  Both directions are built from the (unverified)
-- Assumption B rewrite.
wwAssC :: Maybe (RewriteH CoreExpr)  -- ^ optional proof of Assumption C
       -> CoreExpr                   -- ^ wrap
       -> CoreExpr                   -- ^ unwrap
       -> CoreExpr                   -- ^ f
       -> BiRewriteH CoreExpr
wwAssC mr wrap unwrap f =
    beforeBiR precheck $ \ () ->
        bidirectional leftToRight rightToLeft
  where
    precheck :: TransformH CoreExpr ()
    precheck = do
        _ <- isFixExprT
        whenJust (verifyAssC wrap unwrap f) mr

    viaAssB :: BiRewriteH CoreExpr
    viaAssB = wwAssB Nothing wrap unwrap f

    -- Apply Assumption B under the lambda, then eta-reduce.
    leftToRight :: RewriteH CoreExpr
    leftToRight = wwAssBimpliesAssC (forwardT viaAssB)

    -- Eta-expand the fix argument, then apply Assumption B backwards under the lambda.
    rightToLeft :: RewriteH CoreExpr
    rightToLeft = appAllR idR (etaExpandR "a" >>> lamAllR idR (backwardT viaAssB))
-- | 'wwAssC' taking wrap, unwrap and the body function as 'CoreString's,
-- which are handed to 'parse3beforeBiR'.
wwC :: Maybe (RewriteH CoreExpr)  -- ^ optional proof of Assumption C
    -> CoreString                 -- ^ wrap
    -> CoreString                 -- ^ unwrap
    -> CoreString                 -- ^ f
    -> BiRewriteH CoreExpr
wwC = parse3beforeBiR . wwAssC
-- | Verify a worker/wrapper assumption, dispatching on its tag to
-- 'verifyAssA', 'verifyAssB' or 'verifyAssC'.  The body function is not
-- needed for Assumption A.
verifyWWAss :: CoreExpr        -- ^ wrap
            -> CoreExpr        -- ^ unwrap
            -> CoreExpr        -- ^ f
            -> WWAssumption
            -> TransformH x ()
verifyWWAss wrap unwrap _ (WWAssumption A proof) = verifyAssA wrap unwrap proof
verifyWWAss wrap unwrap f (WWAssumption B proof) = verifyAssB wrap unwrap f proof
verifyWWAss wrap unwrap f (WWAssumption C proof) = verifyAssC wrap unwrap f proof
-- | Verify a proof of Assumption A.
--
-- First checks that wrap and unwrap have compatible types (the types
-- themselves are discarded), then hands the proof to 'verifyRetractionT'.
verifyAssA :: CoreExpr           -- ^ wrap
           -> CoreExpr           -- ^ unwrap
           -> RewriteH CoreExpr  -- ^ proof of Assumption A
           -> TransformH x ()
verifyAssA wrap unwrap assA =
    prefixFailMsg "verification of worker/wrapper Assumption A failed: " $
        wrapUnwrapTypes wrap unwrap >> verifyRetractionT wrap unwrap assA
-- | Verify a proof of Assumption B.
--
-- Builds both sides of @wrap (unwrap (f a)) ==> f a@ for a freshly created
-- identifier @a@ of type A, and checks with 'verifyEqualityLeftToRightT' that
-- the proof takes the left-hand side to the right-hand side.
verifyAssB :: CoreExpr           -- ^ wrap
           -> CoreExpr           -- ^ unwrap
           -> CoreExpr           -- ^ f
           -> RewriteH CoreExpr  -- ^ proof of Assumption B
           -> TransformH x ()
verifyAssB wrap unwrap f assB =
    prefixFailMsg "verification of worker/wrapper assumption B failed: " $ do
        (tyA, _) <- wrapUnwrapTypes wrap unwrap
        a <- constT (newIdH "a" tyA)
        let fa        = App f (Var a)
            roundTrip = App wrap (App unwrap fa)
        verifyEqualityLeftToRightT roundTrip fa assB
-- | Verify a proof of Assumption C.
--
-- Builds both sides of @fix A (\ a -> wrap (unwrap (f a))) ==> fix A f@ using
-- 'buildFixT' (right-hand side first), for a freshly created identifier @a@
-- of type A, and checks with 'verifyEqualityLeftToRightT' that the proof
-- takes the left-hand side to the right-hand side.
verifyAssC :: CoreExpr           -- ^ wrap
           -> CoreExpr           -- ^ unwrap
           -> CoreExpr           -- ^ f
           -> RewriteH CoreExpr  -- ^ proof of Assumption C
           -> TransformH a ()
verifyAssC wrap unwrap f assC =
    prefixFailMsg "verification of worker/wrapper assumption C failed: " $ do
        (tyA, _) <- wrapUnwrapTypes wrap unwrap
        a <- constT (newIdH "a" tyA)
        let roundTripBody = Lam a (App wrap (App unwrap (App f (Var a))))
        fixF         <- buildFixT f
        fixRoundTrip <- buildFixT roundTripBody
        verifyEqualityLeftToRightT fixRoundTrip fixF assC
-- | Compute the pair of types associated with a wrap/unwrap pair, failing
-- with a fixed message if the two expressions do not have inverse function
-- types.  Note that the arguments are passed to 'funExprsWithInverseTypes'
-- in the order unwrap, wrap; callers treat the result as @(A, B)@.
wrapUnwrapTypes :: MonadCatch m => CoreExpr -> CoreExpr -> m (Type,Type)
wrapUnwrapTypes wrap unwrap = setFailMsg badPair (funExprsWithInverseTypes unwrap wrap)
  where
    badPair :: String
    badPair = "given expressions have the wrong types to form a valid wrap/unwrap pair."