module HERMIT.Dictionary.WorkerWrapper.Fix
(
HERMIT.Dictionary.WorkerWrapper.Fix.externals
, wwFacBR
, wwSplitR
, wwSplitStaticArg
, wwGenerateFusionR
, wwFusionBR
, wwAssA
, wwAssB
, wwAssC
)
where
import Control.Applicative
import Control.Arrow
import HERMIT.Core
import HERMIT.Monad
import HERMIT.Kure
import HERMIT.External
import HERMIT.GHC
import HERMIT.Utilities
import HERMIT.ParserCore
import HERMIT.Dictionary.AlphaConversion
import HERMIT.Dictionary.Common
import HERMIT.Dictionary.FixPoint
import HERMIT.Dictionary.Function
import HERMIT.Dictionary.Local
import HERMIT.Dictionary.Navigation
import HERMIT.Dictionary.Reasoning
import HERMIT.Dictionary.Unfold
import HERMIT.Dictionary.WorkerWrapper.Common
import qualified Language.Haskell.TH as TH
-- | Shell commands exposing the fixed-point worker/wrapper rewrites.
--   Each \"safe\" command takes a rewrite acting as a proof of the relevant
--   assumption; each \"-unsafe\" variant passes 'Nothing' and skips verification.
externals :: [External]
externals =
  [ external "ww-factorisation"
      ((\ wrapS unwrapS proofC -> promoteExprBiR (wwFac (mkWWAssC proofC) wrapS unwrapS))
         :: CoreString -> CoreString -> RewriteH Core -> BiRewriteH Core)
      [ "Worker/Wrapper Factorisation",
        "For any \"f :: A -> A\", and given \"wrap :: B -> A\" and \"unwrap :: A -> B\" as arguments,",
        "and a proof of Assumption C (fix A (\\ a -> wrap (unwrap (f a))) ==> fix A f), then",
        "fix A f ==> wrap (fix B (\\ b -> unwrap (f (wrap b))))"
      ] .+ Introduce .+ Context
  , external "ww-factorisation-unsafe"
      ((\ wrapS unwrapS -> promoteExprBiR (wwFac Nothing wrapS unwrapS))
         :: CoreString -> CoreString -> BiRewriteH Core)
      [ "Unsafe Worker/Wrapper Factorisation",
        "For any \"f :: A -> A\", and given \"wrap :: B -> A\" and \"unwrap :: A -> B\" as arguments, then",
        "fix A f <==> wrap (fix B (\\ b -> unwrap (f (wrap b))))",
        "Note: the pre-condition \"fix A (\\ a -> wrap (unwrap (f a))) == fix A f\" is expected to hold."
      ] .+ Introduce .+ Context .+ PreCondition
  , external "ww-split"
      ((\ wrapS unwrapS proofC -> promoteDefR (wwSplit (mkWWAssC proofC) wrapS unwrapS))
         :: CoreString -> CoreString -> RewriteH Core -> RewriteH Core)
      [ "Worker/Wrapper Split",
        "For any \"prog :: A\", and given \"wrap :: B -> A\" and \"unwrap :: A -> B\" as arguments,",
        "and a proof of Assumption C (fix A (\\ a -> wrap (unwrap (f a))) ==> fix A f), then",
        "prog = expr ==> prog = let f = \\ prog -> expr",
        " in let work = unwrap (f (wrap work))",
        " in wrap work"
      ] .+ Introduce .+ Context
  , external "ww-split-unsafe"
      ((\ wrapS unwrapS -> promoteDefR (wwSplit Nothing wrapS unwrapS))
         :: CoreString -> CoreString -> RewriteH Core)
      [ "Unsafe Worker/Wrapper Split",
        "For any \"prog :: A\", and given \"wrap :: B -> A\" and \"unwrap :: A -> B\" as arguments, then",
        "prog = expr ==> prog = let f = \\ prog -> expr",
        " in let work = unwrap (f (wrap work))",
        " in wrap work",
        "Note: the pre-condition \"fix A (wrap . unwrap . f) == fix A f\" is expected to hold."
      ] .+ Introduce .+ Context .+ PreCondition
  , external "ww-split-static-arg"
      ((\ count idxs wrapS unwrapS proofC ->
          promoteDefR (wwSplitStaticArg count idxs (mkWWAssC proofC) wrapS unwrapS))
         :: Int -> [Int] -> CoreString -> CoreString -> RewriteH Core -> RewriteH Core)
      [ "Worker/Wrapper Split - Static Argument Variant",
        "Perform the static argument transformation on the first n arguments, then perform the worker/wrapper split,",
        "applying the given wrap and unwrap functions to the specified (by index) static arguments before use."
      ] .+ Introduce .+ Context
  , external "ww-split-static-arg-unsafe"
      ((\ count idxs wrapS unwrapS ->
          promoteDefR (wwSplitStaticArg count idxs Nothing wrapS unwrapS))
         :: Int -> [Int] -> CoreString -> CoreString -> RewriteH Core)
      [ "Unsafe Worker/Wrapper Split - Static Argument Variant",
        "Perform the static argument transformation on the first n arguments, then perform the (unsafe) worker/wrapper split,",
        "applying the given wrap and unwrap functions to the specified (by index) static arguments before use."
      ] .+ Introduce .+ Context .+ PreCondition
  , external "ww-assumption-A"
      ((\ wrapS unwrapS proofA -> promoteExprBiR (wwA (Just (extractR proofA)) wrapS unwrapS))
         :: CoreString -> CoreString -> RewriteH Core -> BiRewriteH Core)
      [ "Worker/Wrapper Assumption A",
        "For a \"wrap :: B -> A\" and an \"unwrap :: A -> B\",",
        "and given a proof of \"wrap (unwrap a) ==> a\", then",
        "wrap (unwrap a) <==> a"
      ] .+ Introduce .+ Context
  , external "ww-assumption-B"
      ((\ wrapS unwrapS fS proofB -> promoteExprBiR (wwB (Just (extractR proofB)) wrapS unwrapS fS))
         :: CoreString -> CoreString -> CoreString -> RewriteH Core -> BiRewriteH Core)
      [ "Worker/Wrapper Assumption B",
        "For a \"wrap :: B -> A\", an \"unwrap :: A -> B\", and an \"f :: A -> A\",",
        "and given a proof of \"wrap (unwrap (f a)) ==> f a\", then",
        "wrap (unwrap (f a)) <==> f a"
      ] .+ Introduce .+ Context
  , external "ww-assumption-C"
      ((\ wrapS unwrapS fS proofC -> promoteExprBiR (wwC (Just (extractR proofC)) wrapS unwrapS fS))
         :: CoreString -> CoreString -> CoreString -> RewriteH Core -> BiRewriteH Core)
      [ "Worker/Wrapper Assumption C",
        "For a \"wrap :: B -> A\", an \"unwrap :: A -> B\", and an \"f :: A -> A\",",
        "and given a proof of \"fix A (\\ a -> wrap (unwrap (f a))) ==> fix A f\", then",
        "fix A (\\ a -> wrap (unwrap (f a))) <==> fix A f"
      ] .+ Introduce .+ Context
  , external "ww-assumption-A-unsafe"
      ((\ wrapS unwrapS -> promoteExprBiR (wwA Nothing wrapS unwrapS))
         :: CoreString -> CoreString -> BiRewriteH Core)
      [ "Unsafe Worker/Wrapper Assumption A",
        "For a \"wrap :: B -> A\" and an \"unwrap :: A -> B\", then",
        "wrap (unwrap a) <==> a",
        "Note: only use this if it's true!"
      ] .+ Introduce .+ Context .+ PreCondition
  , external "ww-assumption-B-unsafe"
      ((\ wrapS unwrapS fS -> promoteExprBiR (wwB Nothing wrapS unwrapS fS))
         :: CoreString -> CoreString -> CoreString -> BiRewriteH Core)
      [ "Unsafe Worker/Wrapper Assumption B",
        "For a \"wrap :: B -> A\", an \"unwrap :: A -> B\", and an \"f :: A -> A\", then",
        "wrap (unwrap (f a)) <==> f a",
        "Note: only use this if it's true!"
      ] .+ Introduce .+ Context .+ PreCondition
  , external "ww-assumption-C-unsafe"
      ((\ wrapS unwrapS fS -> promoteExprBiR (wwC Nothing wrapS unwrapS fS))
         :: CoreString -> CoreString -> CoreString -> BiRewriteH Core)
      [ "Unsafe Worker/Wrapper Assumption C",
        "For a \"wrap :: B -> A\", an \"unwrap :: A -> B\", and an \"f :: A -> A\", then",
        "fix A (\\ a -> wrap (unwrap (f a))) <==> fix A f",
        "Note: only use this if it's true!"
      ] .+ Introduce .+ Context .+ PreCondition
  , external "ww-AssA-to-AssB"
      ((\ proofA -> promoteExprR (wwAssAimpliesAssB (extractR proofA)))
         :: RewriteH Core -> RewriteH Core)
      [ "Convert a proof of worker/wrapper Assumption A into a proof of worker/wrapper Assumption B."
      ]
  , external "ww-AssB-to-AssC"
      ((\ proofB -> promoteExprR (wwAssBimpliesAssC (extractR proofB)))
         :: RewriteH Core -> RewriteH Core)
      [ "Convert a proof of worker/wrapper Assumption B into a proof of worker/wrapper Assumption C."
      ]
  , external "ww-AssA-to-AssC"
      ((\ proofA -> promoteExprR (wwAssAimpliesAssC (extractR proofA)))
         :: RewriteH Core -> RewriteH Core)
      [ "Convert a proof of worker/wrapper Assumption A into a proof of worker/wrapper Assumption C."
      ]
  , external "ww-generate-fusion"
      ((\ proofC -> wwGenerateFusionR (mkWWAssC proofC))
         :: RewriteH Core -> RewriteH Core)
      [ "Given a proof of Assumption C (fix A (\\ a -> wrap (unwrap (f a))) ==> fix A f), then",
        "execute this command on \"work = unwrap (f (wrap work))\" to enable the \"ww-fusion\" rule thereafter.",
        "Note that this is performed automatically as part of \"ww-split\"."
      ] .+ Experiment .+ TODO
  , external "ww-generate-fusion-unsafe"
      (wwGenerateFusionR Nothing :: RewriteH Core)
      [ "Execute this command on \"work = unwrap (f (wrap work))\" to enable the \"ww-fusion\" rule thereafter.",
        "The precondition \"fix A (wrap . unwrap . f) == fix A f\" is expected to hold.",
        "Note that this is performed automatically as part of \"ww-split\"."
      ] .+ Experiment .+ TODO
  , external "ww-fusion"
      (promoteExprBiR wwFusion :: BiRewriteH Core)
      [ "Worker/Wrapper Fusion",
        "unwrap (wrap work) <==> work",
        "Note: you are required to have previously executed the command \"ww-generate-fusion\" on the definition",
        " work = unwrap (f (wrap work))"
      ] .+ Introduce .+ Context .+ PreCondition .+ TODO
  ]
  where
    -- Package a user-supplied Core rewrite as a proof of Assumption C.
    mkWWAssC :: RewriteH Core -> Maybe WWAssumption
    mkWWAssC = Just . WWAssumption C . extractR
-- | Worker/wrapper factorisation as a bidirectional rewrite, given @wrap@ and @unwrap@.
--
--   Forwards:  @fix f@  becomes  @wrap (fix (\\ x -> unwrap (f (wrap x))))@.
--   Backwards: an expression of the latter shape is rewritten to @fix f@.
--
--   If an assumption is supplied it is verified (in both directions) before rewriting;
--   with 'Nothing' no verification takes place.
wwFacBR :: Maybe WWAssumption -> CoreExpr -> CoreExpr -> BiRewriteH CoreExpr
wwFacBR mAss wrap unwrap =
    beforeBiR (wrapUnwrapTypes wrap unwrap) $ \ (tyA, tyB) ->
      bidirectional (factoriseR tyA tyB) unfactoriseR
  where
    -- Left-to-right: introduce wrap/unwrap around the body of the fixed point.
    factoriseR :: Type -> Type -> RewriteH CoreExpr
    factoriseR tyA tyB =
      prefixFailMsg "worker/wrapper factorisation failed: " $ do
        (fixTy, f) <- isFixExprT
        guardMsg (eqType tyA fixTy) "wrapper/unwrapper types do not match fix body type."
        whenJust (verifyWWAss wrap unwrap f) mAss
        b <- constT (newIdH "x" tyB)
        workFix <- mkFixT (Lam b (App unwrap (App f (App wrap (Var b)))))
        return (App wrap workFix)

    -- Right-to-left: recognise the factorised form and recover @fix f@.
    unfactoriseR :: RewriteH CoreExpr
    unfactoriseR =
      prefixFailMsg "(reverse) worker/wrapper factorisation failed: " $
      withPatFailMsg "not an application." $ do
        App outerWrap fixExpr <- idR
        withPatFailMsg badFixBody $ do
          (_, Lam b (App innerUnwrap (App f (App innerWrap (Var b'))))) <- constant fixExpr >>> isFixExprT
          guardMsg (b == b') badFixBody
          -- the supplied wrapper must agree with both occurrences in the expression
          guardMsg (equivalentBy exprAlphaEq [wrap, innerWrap, outerWrap]) "wrappers do not match."
          guardMsg (exprAlphaEq unwrap innerUnwrap) "unwrappers do not match."
          whenJust (verifyWWAss wrap unwrap f) mAss
          mkFixT f

    badFixBody :: String
    badFixBody = "body of fix does not have the form Lam b (App unwrap (App f (App wrap (Var b))))"
-- | As 'wwFacBR', but the wrap and unwrap expressions are given as 'CoreString's
--   and parsed before the bidirectional rewrite is built.
wwFac :: Maybe WWAssumption -> CoreString -> CoreString -> BiRewriteH CoreExpr
wwFac = parse2beforeBiR . wwFacBR
-- | Worker/wrapper fusion as a bidirectional rewrite:  @unwrap (wrap work)@  \<==\>  @work@.
--
--   The wrapper, unwrapper and worker are recovered from the definition stored
--   under 'workLabel', which must have the form @work = unwrap (f (wrap work))@.
wwFusionBR :: BiRewriteH CoreExpr
wwFusionBR = beforeBiR storedRuleT fusionBR
  where
    -- Look up the remembered worker definition and pull out its components.
    storedRuleT =
      prefixFailMsg "worker/wrapper fusion failed: " $
      withPatFailMsg "malformed WW Fusion rule." $ do
        Def worker (App unwrap (App _f (App wrap (Var worker')))) <- constT (lookupDef workLabel)
        guardMsg (worker == worker') "malformed WW Fusion rule."
        return (wrap, unwrap, Var worker)

    fusionBR (wrap, unwrap, work) =
      bidirectional (fuseR wrap unwrap work) (unfuseR wrap unwrap work)

    -- Left-to-right: collapse @unwrap (wrap work)@ to @work@.
    fuseR :: CoreExpr -> CoreExpr -> CoreExpr -> RewriteH CoreExpr
    fuseR wrap unwrap work =
      prefixFailMsg "worker/wrapper fusion failed: " $
      withPatFailMsg (wrongExprForm "unwrap (wrap work)") $ do
        App foundUnwrap (App foundWrap foundWork) <- idR
        guardMsg (exprAlphaEq wrap foundWrap) "wrapper does not match."
        guardMsg (exprAlphaEq unwrap foundUnwrap) "unwrapper does not match."
        guardMsg (exprAlphaEq work foundWork) "worker does not match."
        return work

    -- Right-to-left: expand @work@ to @unwrap (wrap work)@.
    unfuseR :: CoreExpr -> CoreExpr -> CoreExpr -> RewriteH CoreExpr
    unfuseR wrap unwrap work =
      prefixFailMsg "(reverse) worker/wrapper fusion failed: " $ do
        foundWork <- idR
        guardMsg (exprAlphaEq work foundWork) "worker does not match."
        return (App unwrap (App wrap work))
-- | Worker/wrapper fusion; an alias for 'wwFusionBR'.
wwFusion :: BiRewriteH CoreExpr
wwFusion = wwFusionBR
-- | Record a worker definition under 'workLabel' (via 'rememberR').
--   The definition must have the form @work = unwrap (f (wrap work))@;
--   if an assumption is supplied, it is verified first.
wwGenerateFusionR :: Maybe WWAssumption -> RewriteH Core
wwGenerateFusionR mAss =
    prefixFailMsg "generate WW fusion failed: " . withPatFailMsg badShape $ do
      Def worker (App unwrap (App f (App wrap (Var worker')))) <- projectT
      guardMsg (worker == worker') badShape
      whenJust (verifyWWAss wrap unwrap f) mAss
      rememberR workLabel
  where
    badShape :: String
    badShape = "definition does not have the form: work = unwrap (f (wrap work))"
-- | Worker/wrapper split on a definition, given @wrap@ and @unwrap@:
--
--   > prog = expr  ==>  prog = let f = \ prog -> expr
--   >                          in let work = unwrap (f (wrap work))
--   >                             in wrap work
--
--   The fusion rule for the generated worker is recorded as part of the split
--   (see 'wwGenerateFusionR').
wwSplitR :: Maybe WWAssumption -> CoreExpr -> CoreExpr -> RewriteH CoreDef
wwSplitR mAss wrap unwrap =
    fixIntroR >>> defAllR idR splitRhsR
  where
    workName, fixName :: TH.Name
    workName = TH.mkName "work"
    fixName  = TH.mkName "fix"

    -- Name the argument of fix as "f", float that let outwards, then work on its body.
    splitRhsR :: RewriteH CoreExpr
    splitRhsR =
          appAllR idR (letIntroR "f")
      >>> letFloatArgR
      >>> letAllR idR factoriseR

    -- Apply factorisation, expose the worker's let-binding, and float it outwards.
    factoriseR :: RewriteH CoreExpr
    factoriseR =
          forwardT (wwFacBR mAss wrap unwrap)
      >>> appAllR idR exposeWorkerR
      >>> letFloatArgR

    -- Unfold fix, rename the recursive binder to "work", then tidy each definition.
    exposeWorkerR :: RewriteH CoreExpr
    exposeWorkerR =
          unfoldNameR fixName
      >>> alphaLetWithR [workName]
      >>> letRecAllR (const tidyWorkerR) idR

    -- Simplify the worker's right-hand side and remember it for later fusion.
    tidyWorkerR :: RewriteH CoreDef
    tidyWorkerR =
          defAllR idR (betaReduceR >>> letNonRecSubstR)
      >>> extractR (wwGenerateFusionR mAss)
-- | As 'wwSplitR', but the wrap and unwrap expressions are given as 'CoreString's
--   and parsed (wrap first, then unwrap) before the split is performed.
wwSplit :: Maybe WWAssumption -> CoreString -> CoreString -> RewriteH CoreDef
wwSplit mAss wrapS unwrapS = do
  wrap   <- parseCoreExprT wrapS
  unwrap <- parseCoreExprT unwrapS
  wwSplitR mAss wrap unwrap
-- | Static-argument variant of the worker/wrapper split.
--
--   @wwSplitStaticArg n is mAss wrapS unwrapS@ applies the static argument
--   transformation to the first @n@ argument positions (@0 .. n - 1@), then performs
--   the worker/wrapper split on the resulting local recursive definition.  The
--   parsed wrap and unwrap expressions are first applied to those static binders
--   whose (zero-based) positions appear in @is@.
--
--   When @n@ is 0 this is exactly 'wwSplit'.  Fails if any index in @is@ is not
--   less than @n@.
wwSplitStaticArg :: Int -> [Int] -> Maybe WWAssumption -> CoreString -> CoreString -> RewriteH CoreDef
wwSplitStaticArg 0 _  = wwSplit
wwSplitStaticArg n is = \ mAss wrapS unwrapS ->
  prefixFailMsg "worker/wrapper split (static argument variant) failed: " $
  do guardMsg (all (< n) is) "arguments for wrap and unwrap must be chosen from the statically transformed arguments."
     -- the first n lambda binders of the definition's right-hand side
     bs <- defT successT (arr collectBinders) (\ () -> take n . fst)
     let args = varsToCoreExprs [ b | (i,b) <- zip [0..] bs, i `elem` is ]

     -- static positions are 0 .. n - 1 (previously an out-of-scope name @n1@ was used here)
     staticArgPosR [0 .. n - 1] >>> defAllR idR
       (let wwSplitArgsR :: RewriteH CoreDef
            wwSplitArgsR = do wrap   <- parseCoreExprT wrapS
                              unwrap <- parseCoreExprT unwrapS
                              wwSplitR mAss (mkCoreApps wrap args) (mkCoreApps unwrap args)
         in
            -- locate the let-expression, split the definition(s) in its recursive group,
            -- then substitute the let away
            extractR $ do p <- considerConstructT LetExpr
                          localPathR p $ promoteExprR (letRecAllR (const wwSplitArgsR) idR >>> letSubstR)
       )
-- | Convert a proof of worker/wrapper Assumption A into a proof of Assumption B.
--   The proof rewrite is returned unchanged.
wwAssAimpliesAssB :: RewriteH CoreExpr -> RewriteH CoreExpr
wwAssAimpliesAssB assA = assA
-- | Convert a proof of worker/wrapper Assumption B into a proof of Assumption C:
--   the Assumption B proof is applied to the body of the lambda in argument
--   position of an application, and the resulting lambda is then eta-reduced.
wwAssBimpliesAssC :: RewriteH CoreExpr -> RewriteH CoreExpr
wwAssBimpliesAssC assB = appAllR idR rewriteArgR
  where
    rewriteArgR :: RewriteH CoreExpr
    rewriteArgR = lamAllR idR assB >>> etaReduceR
-- | Convert a proof of worker/wrapper Assumption A into a proof of Assumption C,
--   by going via Assumption B.
wwAssAimpliesAssC :: RewriteH CoreExpr -> RewriteH CoreExpr
wwAssAimpliesAssC assA = wwAssBimpliesAssC (wwAssAimpliesAssB assA)
-- | Worker/wrapper Assumption A as a bidirectional rewrite:
--   @wrap (unwrap a)@  \<==\>  @a@.
--
--   If a proof rewrite is supplied it is verified first; with 'Nothing' the
--   assumption is taken on trust.
wwAssA :: Maybe (RewriteH CoreExpr)  -- ^ optional proof of Assumption A
       -> CoreExpr                   -- ^ wrap
       -> CoreExpr                   -- ^ unwrap
       -> BiRewriteH CoreExpr
wwAssA mr wrap unwrap =
    beforeBiR (whenJust (verifyAssA wrap unwrap) mr >> wrapUnwrapTypes wrap unwrap) $ \ (tyA, _) ->
      bidirectional stripR (insertR tyA)
  where
    -- Left-to-right: remove a matching @wrap (unwrap _)@ pair.
    stripR :: RewriteH CoreExpr
    stripR =
      withPatFailMsg (wrongExprForm "App wrap (App unwrap x)") $ do
        App foundWrap (App foundUnwrap inner) <- idR
        guardMsg (exprAlphaEq wrap foundWrap) "given wrapper does not match wrapper in expression."
        guardMsg (exprAlphaEq unwrap foundUnwrap) "given unwrapper does not match unwrapper in expression."
        return inner

    -- Right-to-left: surround an expression of the expected type with @wrap (unwrap _)@.
    insertR :: Type -> RewriteH CoreExpr
    insertR tyA = do
      inner <- idR
      guardMsg (exprKindOrType inner `eqType` tyA) "type of expression does not match types of wrap/unwrap."
      return (App wrap (App unwrap inner))
-- | As 'wwAssA', but wrap and unwrap are given as 'CoreString's to be parsed.
wwA :: Maybe (RewriteH CoreExpr)  -- ^ optional proof of Assumption A
    -> CoreString                 -- ^ wrap
    -> CoreString                 -- ^ unwrap
    -> BiRewriteH CoreExpr
wwA = parse2beforeBiR . wwAssA
-- | Worker/wrapper Assumption B as a bidirectional rewrite:
--   @wrap (unwrap (f a))@  \<==\>  @f a@.
--
--   If a proof rewrite is supplied it is verified first.  The rewriting itself
--   is delegated to the (unverified) Assumption A rewrite, after checking that
--   the body function in the expression matches @f@.
wwAssB :: Maybe (RewriteH CoreExpr)  -- ^ optional proof of Assumption B
       -> CoreExpr                   -- ^ wrap
       -> CoreExpr                   -- ^ unwrap
       -> CoreExpr                   -- ^ f
       -> BiRewriteH CoreExpr
wwAssB mr wrap unwrap f =
    beforeBiR (whenJust (verifyAssB wrap unwrap f) mr) $ \ () ->
      bidirectional stripR insertR
  where
    retractBR :: BiRewriteH CoreExpr
    retractBR = wwAssA Nothing wrap unwrap

    -- Left-to-right: @wrap (unwrap (f a))@ to @f a@.
    stripR :: RewriteH CoreExpr
    stripR =
      withPatFailMsg (wrongExprForm "App wrap (App unwrap (App f a))") $ do
        App _ (App _ (App g _)) <- idR
        guardMsg (exprAlphaEq f g) "given body function does not match expression."
        forwardT retractBR

    -- Right-to-left: @f a@ to @wrap (unwrap (f a))@.
    insertR :: RewriteH CoreExpr
    insertR =
      withPatFailMsg (wrongExprForm "App f a") $ do
        App g _ <- idR
        guardMsg (exprAlphaEq f g) "given body function does not match expression."
        backwardT retractBR
-- | As 'wwAssB', but wrap, unwrap and f are given as 'CoreString's to be parsed.
wwB :: Maybe (RewriteH CoreExpr)  -- ^ optional proof of Assumption B
    -> CoreString                 -- ^ wrap
    -> CoreString                 -- ^ unwrap
    -> CoreString                 -- ^ f
    -> BiRewriteH CoreExpr
wwB = parse3beforeBiR . wwAssB
-- | Worker/wrapper Assumption C as a bidirectional rewrite:
--   @fix (\\ a -> wrap (unwrap (f a)))@  \<==\>  @fix f@.
--
--   The expression must be a fix application, and if a proof rewrite is
--   supplied it is verified first.  Both directions are built from the
--   (unverified) Assumption B rewrite.
wwAssC :: Maybe (RewriteH CoreExpr)  -- ^ optional proof of Assumption C
       -> CoreExpr                   -- ^ wrap
       -> CoreExpr                   -- ^ unwrap
       -> CoreExpr                   -- ^ f
       -> BiRewriteH CoreExpr
wwAssC mr wrap unwrap f =
    beforeBiR (isFixExprT >> whenJust (verifyAssC wrap unwrap f) mr) $ \ () ->
      bidirectional leftToRightR rightToLeftR
  where
    viaAssB :: BiRewriteH CoreExpr
    viaAssB = wwAssB Nothing wrap unwrap f

    leftToRightR :: RewriteH CoreExpr
    leftToRightR = wwAssBimpliesAssC (forwardT viaAssB)

    -- eta-expand the argument of fix, then apply Assumption B backwards under the lambda
    rightToLeftR :: RewriteH CoreExpr
    rightToLeftR = appAllR idR (etaExpandR "a" >>> lamAllR idR (backwardT viaAssB))
-- | As 'wwAssC', but wrap, unwrap and f are given as 'CoreString's to be parsed.
wwC :: Maybe (RewriteH CoreExpr)  -- ^ optional proof of Assumption C
    -> CoreString                 -- ^ wrap
    -> CoreString                 -- ^ unwrap
    -> CoreString                 -- ^ f
    -> BiRewriteH CoreExpr
wwC = parse3beforeBiR . wwAssC
-- | Verify a tagged worker/wrapper assumption by dispatching on its tag to
--   'verifyAssA', 'verifyAssB' or 'verifyAssC'.  The body function is not
--   needed for Assumption A.
verifyWWAss :: CoreExpr        -- ^ wrap
            -> CoreExpr        -- ^ unwrap
            -> CoreExpr        -- ^ f
            -> WWAssumption
            -> TranslateH x ()
verifyWWAss wrap unwrap _ (WWAssumption A proof) = verifyAssA wrap unwrap proof
verifyWWAss wrap unwrap f (WWAssumption B proof) = verifyAssB wrap unwrap f proof
verifyWWAss wrap unwrap f (WWAssumption C proof) = verifyAssC wrap unwrap f proof
-- | Verify Assumption A using the supplied proof rewrite.
--   The wrap/unwrap types are checked first, then 'verifyRetractionT' is run.
verifyAssA :: CoreExpr            -- ^ wrap
           -> CoreExpr            -- ^ unwrap
           -> RewriteH CoreExpr   -- ^ proof of Assumption A
           -> TranslateH x ()
verifyAssA wrap unwrap assA =
  prefixFailMsg "verification of worker/wrapper Assumption A failed: " $
    wrapUnwrapTypes wrap unwrap >> verifyRetractionT wrap unwrap assA
-- | Verify Assumption B: for a fresh variable @a@, the supplied proof rewrite
--   must take @wrap (unwrap (f a))@ to @f a@.
verifyAssB :: CoreExpr            -- ^ wrap
           -> CoreExpr            -- ^ unwrap
           -> CoreExpr            -- ^ f
           -> RewriteH CoreExpr   -- ^ proof of Assumption B
           -> TranslateH x ()
verifyAssB wrap unwrap f assB =
  prefixFailMsg "verification of worker/wrapper assumption B failed: " $ do
    (tyA, _) <- wrapUnwrapTypes wrap unwrap
    arg      <- Var <$> constT (newIdH "a" tyA)
    let target = App f arg
    verifyEqualityLeftToRightT (App wrap (App unwrap target)) target assB
-- | Verify Assumption C: for a fresh variable @a@, the supplied proof rewrite
--   must take @fix (\\ a -> wrap (unwrap (f a)))@ to @fix f@.
verifyAssC :: CoreExpr            -- ^ wrap
           -> CoreExpr            -- ^ unwrap
           -> CoreExpr            -- ^ f
           -> RewriteH CoreExpr   -- ^ proof of Assumption C
           -> TranslateH a ()
verifyAssC wrap unwrap f assC =
  prefixFailMsg "verification of worker/wrapper assumption C failed: " $ do
    (tyA, _) <- wrapUnwrapTypes wrap unwrap
    v        <- constT (newIdH "a" tyA)
    fixF     <- mkFixT f
    fixWUF   <- mkFixT (Lam v (App wrap (App unwrap (App f (Var v)))))
    verifyEqualityLeftToRightT fixWUF fixF assC
-- | Check, via 'funsWithInverseTypes', that the given expressions form a valid
--   wrap/unwrap pair, returning the resulting pair of types.  Any failure is
--   reported with a single fixed message.
wrapUnwrapTypes :: MonadCatch m => CoreExpr -> CoreExpr -> m (Type,Type)
wrapUnwrapTypes wrap unwrap = setFailMsg badPair (funsWithInverseTypes unwrap wrap)
  where
    badPair :: String
    badPair = "given expressions have the wrong types to form a valid wrap/unwrap pair."