module HERMIT.Dictionary.WorkerWrapper.Fix
       ( -- * The Worker/Wrapper Transformation
         -- | Note that many of these operations require 'Data.Function.fix' to be in scope.
         HERMIT.Dictionary.WorkerWrapper.Fix.externals
       , wwFacBR
       , wwSplitR
       , wwSplitStaticArg
       , wwGenerateFusionR
       , wwFusionBR
       , wwAssA
       , wwAssB
       , wwAssC
       )
where

import Control.Applicative
import Control.Arrow

import HERMIT.Core
import HERMIT.Monad
import HERMIT.Kure
import HERMIT.External
import HERMIT.GHC
import HERMIT.Utilities
import HERMIT.ParserCore

import HERMIT.Dictionary.AlphaConversion
import HERMIT.Dictionary.Common
import HERMIT.Dictionary.FixPoint
import HERMIT.Dictionary.Function
import HERMIT.Dictionary.Local
import HERMIT.Dictionary.Navigation
import HERMIT.Dictionary.Reasoning
import HERMIT.Dictionary.Unfold

import HERMIT.Dictionary.WorkerWrapper.Common

import qualified Language.Haskell.TH as TH

--------------------------------------------------------------------------------------------------

-- | Externals for manipulating fixed points, and for the worker/wrapper transformation.
--
--   Each \"safe\" command takes a rewrite that serves as a proof of the relevant
--   worker/wrapper assumption; the corresponding @-unsafe@ command passes 'Nothing'
--   instead, so no proof is requested from the user, and is tagged 'PreCondition'.
externals ::  [External]
externals =
         [
           -- Factorisation.  The command is a bidirectional rewrite, so its help text uses "<==>".
           external "ww-factorisation" ((\ wrap unwrap assC -> promoteExprBiR $ wwFac (mkWWAssC assC) wrap unwrap)
                                          :: CoreString -> CoreString -> RewriteH Core -> BiRewriteH Core)
                [ "Worker/Wrapper Factorisation",
                  "For any \"f :: A -> A\", and given \"wrap :: B -> A\" and \"unwrap :: A -> B\" as arguments,",
                  "and a proof of Assumption C (fix A (\\ a -> wrap (unwrap (f a))) ==> fix A f), then",
                  "fix A f  <==>  wrap (fix B (\\ b -> unwrap (f (wrap b))))"
                ] .+ Introduce .+ Context
         , external "ww-factorisation-unsafe" ((\ wrap unwrap -> promoteExprBiR $ wwFac Nothing wrap unwrap)
                                               :: CoreString -> CoreString -> BiRewriteH Core)
                [ "Unsafe Worker/Wrapper Factorisation",
                  "For any \"f :: A -> A\", and given \"wrap :: B -> A\" and \"unwrap :: A -> B\" as arguments, then",
                  "fix A f  <==>  wrap (fix B (\\ b -> unwrap (f (wrap b))))",
                  "Note: the pre-condition \"fix A (\\ a -> wrap (unwrap (f a))) == fix A f\" is expected to hold."
                ] .+ Introduce .+ Context .+ PreCondition
           -- Split.
         , external "ww-split" ((\ wrap unwrap assC -> promoteDefR $ wwSplit (mkWWAssC assC) wrap unwrap)
                                  :: CoreString -> CoreString -> RewriteH Core -> RewriteH Core)
                [ "Worker/Wrapper Split",
                  "For any \"prog :: A\", and given \"wrap :: B -> A\" and \"unwrap :: A -> B\" as arguments,",
                  "and a proof of Assumption C (fix A (\\ a -> wrap (unwrap (f a))) ==> fix A f), then",
                  "prog = expr  ==>  prog = let f = \\ prog -> expr",
                  "                          in let work = unwrap (f (wrap work))",
                  "                              in wrap work"
                ] .+ Introduce .+ Context
         , external "ww-split-unsafe" ((\ wrap unwrap -> promoteDefR $ wwSplit Nothing wrap unwrap)
                                       :: CoreString -> CoreString -> RewriteH Core)
                [ "Unsafe Worker/Wrapper Split",
                  "For any \"prog :: A\", and given \"wrap :: B -> A\" and \"unwrap :: A -> B\" as arguments, then",
                  "prog = expr  ==>  prog = let f = \\ prog -> expr",
                  "                          in let work = unwrap (f (wrap work))",
                  "                              in wrap work",
                  "Note: the pre-condition \"fix A (wrap . unwrap . f) == fix A f\" is expected to hold."
                ] .+ Introduce .+ Context .+ PreCondition
         , external "ww-split-static-arg" ((\ n is wrap unwrap assC -> promoteDefR $ wwSplitStaticArg n is (mkWWAssC assC) wrap unwrap)
                                      :: Int -> [Int] -> CoreString -> CoreString -> RewriteH Core -> RewriteH Core)
                [ "Worker/Wrapper Split - Static Argument Variant",
                  "Perform the static argument transformation on the first n arguments, then perform the worker/wrapper split,",
                  "applying the given wrap and unwrap functions to the specified (by index) static arguments before use."
                ] .+ Introduce .+ Context
         , external "ww-split-static-arg-unsafe" ((\ n is wrap unwrap -> promoteDefR $ wwSplitStaticArg n is Nothing wrap unwrap)
                                      :: Int -> [Int] -> CoreString -> CoreString -> RewriteH Core)
                [ "Unsafe Worker/Wrapper Split - Static Argument Variant",
                  "Perform the static argument transformation on the first n arguments, then perform the (unsafe) worker/wrapper split,",
                  "applying the given wrap and unwrap functions to the specified (by index) static arguments before use."
                ] .+ Introduce .+ Context .+ PreCondition
           -- Assumptions A, B and C (verified variants).
         , external "ww-assumption-A" ((\ wrap unwrap assA -> promoteExprBiR $ wwA (Just $ extractR assA) wrap unwrap)
                                       :: CoreString -> CoreString -> RewriteH Core -> BiRewriteH Core)
                [ "Worker/Wrapper Assumption A",
                  "For a \"wrap :: B -> A\" and an \"unwrap :: A -> B\",",
                  "and given a proof of \"wrap (unwrap a) ==> a\", then",
                  "wrap (unwrap a)  <==>  a"
                ] .+ Introduce .+ Context
         , external "ww-assumption-B" ((\ wrap unwrap f assB -> promoteExprBiR $ wwB (Just $ extractR assB) wrap unwrap f)
                                       :: CoreString -> CoreString -> CoreString -> RewriteH Core -> BiRewriteH Core)
                [ "Worker/Wrapper Assumption B",
                  "For a \"wrap :: B -> A\", an \"unwrap :: A -> B\", and an \"f :: A -> A\",",
                  "and given a proof of \"wrap (unwrap (f a)) ==> f a\", then",
                  "wrap (unwrap (f a))  <==>  f a"
                ] .+ Introduce .+ Context
         , external "ww-assumption-C" ((\ wrap unwrap f assC -> promoteExprBiR $ wwC (Just $ extractR assC) wrap unwrap f)
                                       :: CoreString -> CoreString -> CoreString -> RewriteH Core -> BiRewriteH Core)
                [ "Worker/Wrapper Assumption C",
                  "For a \"wrap :: B -> A\", an \"unwrap :: A -> B\", and an \"f :: A -> A\",",
                  "and given a proof of \"fix A (\\ a -> wrap (unwrap (f a))) ==> fix A f\", then",
                  "fix A (\\ a -> wrap (unwrap (f a)))  <==>  fix A f"
                ] .+ Introduce .+ Context
           -- Assumptions A, B and C (unverified variants).
         , external "ww-assumption-A-unsafe" ((\ wrap unwrap -> promoteExprBiR $ wwA Nothing wrap unwrap)
                                              :: CoreString -> CoreString -> BiRewriteH Core)
                [ "Unsafe Worker/Wrapper Assumption A",
                  "For a \"wrap :: B -> A\" and an \"unwrap :: A -> B\", then",
                  "wrap (unwrap a)  <==>  a",
                  "Note: only use this if it's true!"
                ] .+ Introduce .+ Context .+ PreCondition
         , external "ww-assumption-B-unsafe" ((\ wrap unwrap f -> promoteExprBiR $ wwB Nothing wrap unwrap f)
                                              :: CoreString -> CoreString -> CoreString -> BiRewriteH Core)
                [ "Unsafe Worker/Wrapper Assumption B",
                  "For a \"wrap :: B -> A\", an \"unwrap :: A -> B\", and an \"f :: A -> A\", then",
                  "wrap (unwrap (f a))  <==>  f a",
                  "Note: only use this if it's true!"
                ] .+ Introduce .+ Context .+ PreCondition
         , external "ww-assumption-C-unsafe" ((\ wrap unwrap f -> promoteExprBiR $ wwC Nothing wrap unwrap f)
                                              :: CoreString -> CoreString -> CoreString -> BiRewriteH Core)
                [ "Unsafe Worker/Wrapper Assumption C",
                  "For a \"wrap :: B -> A\", an \"unwrap :: A -> B\", and an \"f :: A -> A\", then",
                  "fix A (\\ a -> wrap (unwrap (f a)))  <==>  fix A f",
                  "Note: only use this if it's true!"
                ] .+ Introduce .+ Context .+ PreCondition
           -- Conversions between proofs of the assumptions.
         , external "ww-AssA-to-AssB" (promoteExprR . wwAssAimpliesAssB . extractR :: RewriteH Core -> RewriteH Core)
                   [ "Convert a proof of worker/wrapper Assumption A into a proof of worker/wrapper Assumption B."
                   ]
         , external "ww-AssB-to-AssC" (promoteExprR . wwAssBimpliesAssC . extractR :: RewriteH Core -> RewriteH Core)
                   [ "Convert a proof of worker/wrapper Assumption B into a proof of worker/wrapper Assumption C."
                   ]
         , external "ww-AssA-to-AssC" (promoteExprR . wwAssAimpliesAssC . extractR :: RewriteH Core -> RewriteH Core)
                   [ "Convert a proof of worker/wrapper Assumption A into a proof of worker/wrapper Assumption C."
                   ]
           -- Fusion.
         , external "ww-generate-fusion" (wwGenerateFusionR . mkWWAssC :: RewriteH Core -> RewriteH Core)
                   [ "Given a proof of Assumption C (fix A (\\ a -> wrap (unwrap (f a))) ==> fix A f), then",
                     "execute this command on \"work = unwrap (f (wrap work))\" to enable the \"ww-fusion\" rule thereafter.",
                     "Note that this is performed automatically as part of \"ww-split\"."
                   ] .+ Experiment .+ TODO
           -- Tagged PreCondition like every other "-unsafe" command: its help text states a precondition.
         , external "ww-generate-fusion-unsafe" (wwGenerateFusionR Nothing :: RewriteH Core)
                   [ "Execute this command on \"work = unwrap (f (wrap work))\" to enable the \"ww-fusion\" rule thereafter.",
                     "The precondition \"fix A (wrap . unwrap . f) == fix A f\" is expected to hold.",
                     "Note that this is performed automatically as part of \"ww-split\"."
                   ] .+ PreCondition .+ Experiment .+ TODO
         , external "ww-fusion" (promoteExprBiR wwFusion :: BiRewriteH Core)
                [ "Worker/Wrapper Fusion",
                  "unwrap (wrap work)  <==>  work",
                  "Note: you are required to have previously executed the command \"ww-generate-fusion\" on the definition",
                  "      work = unwrap (f (wrap work))"
                ] .+ Introduce .+ Context .+ PreCondition .+ TODO
         ]
  where
    -- Package a user-supplied rewrite on Core as a proof of worker/wrapper Assumption C.
    mkWWAssC :: RewriteH Core -> Maybe WWAssumption
    mkWWAssC r = Just (WWAssumption C (extractR r))

--------------------------------------------------------------------------------------------------

-- | Worker/wrapper factorisation as a bidirectional rewrite.
--
--   Given @wrap :: B -> A@ and @unwrap :: A -> B@, then for any @f :: A -> A@:
--
--   @fix A f@  \<==\>  @wrap (fix B (\\ b -> unwrap (f (wrap b))))@
--
--   If a 'WWAssumption' is supplied, it is checked (via 'verifyWWAss', against the @f@
--   found in the expression) in both directions before the rewrite succeeds.
wwFacBR :: Maybe WWAssumption -> CoreExpr -> CoreExpr -> BiRewriteH CoreExpr
wwFacBR mAss wrap unwrap =
    beforeBiR (wrapUnwrapTypes wrap unwrap) $ \ (tyA,tyB) ->
        bidirectional (factoriseR tyA tyB) unfactoriseR
  where
    -- Check the optional worker/wrapper assumption for the given body function.
    checkAssT :: CoreExpr -> TranslateH CoreExpr ()
    checkAssT f = whenJust (verifyWWAss wrap unwrap f) mAss

    -- Forwards:  fix A f  ==>  wrap (fix B (\ b -> unwrap (f (wrap b))))
    factoriseR :: Type -> Type -> RewriteH CoreExpr
    factoriseR tyA tyB =
        prefixFailMsg "worker/wrapper factorisation failed: " $
        do (fixTy,f) <- isFixExprT
           guardMsg (eqType tyA fixTy) "wrapper/unwrapper types do not match fix body type."
           checkAssT f
           b       <- constT (newIdH "x" tyB)
           workFix <- mkFixT (Lam b (App unwrap (App f (App wrap (Var b)))))
           return (App wrap workFix)

    -- Backwards:  wrap (fix B (\ b -> unwrap (f (wrap b))))  ==>  fix A f
    unfactoriseR :: RewriteH CoreExpr
    unfactoriseR =
        prefixFailMsg "(reverse) worker/wrapper factorisation failed: " $
        withPatFailMsg "not an application." $
        do App wrap2 fx <- idR
           withPatFailMsg wrongFixBody $
             do (_, Lam b (App unwrap1 (App f (App wrap1 (Var b'))))) <- constant fx >>> isFixExprT
                -- the lambda-bound variable must be the one passed to the inner wrap
                guardMsg (b == b') wrongFixBody
                -- the given wrap, the outer wrap and the inner wrap must all agree
                guardMsg (equivalentBy exprAlphaEq [wrap, wrap1, wrap2]) "wrappers do not match."
                guardMsg (exprAlphaEq unwrap unwrap1) "unwrappers do not match."
                checkAssT f
                mkFixT f

    wrongFixBody :: String
    wrongFixBody = "body of fix does not have the form Lam b (App unwrap (App f (App wrap (Var b))))"

-- | As 'wwFacBR', but @wrap@ and @unwrap@ are supplied as 'CoreString's,
--   which are parsed before the bidirectional rewrite is built:
--
--   @fix A f@  \<==\>  @wrap (fix B (\\ b -> unwrap (f (wrap b))))@
wwFac :: Maybe WWAssumption -> CoreString -> CoreString -> BiRewriteH CoreExpr
wwFac = parse2beforeBiR . wwFacBR

--------------------------------------------------------------------------------------------------

-- | Worker/wrapper fusion:  @unwrap (wrap work)@  \<==\>  @work@
--
--   This takes no arguments: @wrap@, @unwrap@ and @work@ are recovered from the
--   definition looked up under 'workLabel', which must have the form
--   @work = unwrap (f (wrap work))@ (see 'wwGenerateFusionR').
wwFusionBR :: BiRewriteH CoreExpr
wwFusionBR =
    beforeBiR recallRuleT $ \ (wrap,unwrap,work) ->
        bidirectional (fuseR wrap unwrap work) (unfuseR wrap unwrap work)
  where
    -- Look up the saved definition of work and take it apart.
    recallRuleT :: TranslateH CoreExpr (CoreExpr,CoreExpr,CoreExpr)
    recallRuleT =
        prefixFailMsg "worker/wrapper fusion failed: " $
        withPatFailMsg malformedRule $
        do Def w (App unwrap (App _f (App wrap (Var w')))) <- constT (lookupDef workLabel)
           guardMsg (w == w') malformedRule
           return (wrap,unwrap,Var w)

    malformedRule :: String
    malformedRule = "malformed WW Fusion rule."

    -- Forwards:  unwrap (wrap work)  ==>  work
    fuseR :: CoreExpr -> CoreExpr -> CoreExpr -> RewriteH CoreExpr
    fuseR wrap unwrap work =
        prefixFailMsg "worker/wrapper fusion failed: " $
        withPatFailMsg (wrongExprForm "unwrap (wrap work)") $
        do App unwrapE (App wrapE workE) <- idR
           guardMsg (exprAlphaEq wrap wrapE) "wrapper does not match."
           guardMsg (exprAlphaEq unwrap unwrapE) "unwrapper does not match."
           guardMsg (exprAlphaEq work workE) "worker does not match."
           return work

    -- Backwards:  work  ==>  unwrap (wrap work)
    unfuseR :: CoreExpr -> CoreExpr -> CoreExpr -> RewriteH CoreExpr
    unfuseR wrap unwrap work =
        prefixFailMsg "(reverse) worker/wrapper fusion failed: " $
        idR >>= \ e ->
          guardMsg (exprAlphaEq work e) "worker does not match."
            >> return (App unwrap (App wrap work))


-- | Worker/wrapper fusion:  @unwrap (wrap work)@  \<==\>  @work@
--
--   This is simply 'wwFusionBR'; it takes no arguments.
wwFusion :: BiRewriteH CoreExpr
wwFusion = wwFusionBR

--------------------------------------------------------------------------------------------------

-- | Remember the recursive definition of @work@ (under 'workLabel'), so that later uses of
--   'wwFusionBR' can be checked against it.
--
--   The rewrite must be applied to a definition of the form @work = unwrap (f (wrap work))@.
--   If a 'WWAssumption' is supplied, it is verified for the @wrap@, @unwrap@ and @f@
--   found in that definition before the definition is remembered.
--
--   Note that this is performed automatically as part of 'wwSplitR'.
wwGenerateFusionR :: Maybe WWAssumption -> RewriteH Core
wwGenerateFusionR mAss =
    prefixFailMsg "generate WW fusion failed: " . withPatFailMsg wrongForm $
      do Def w (App unwrap (App f (App wrap (Var w')))) <- projectT
         -- the recursive occurrence must be the variable being defined
         guardMsg (w == w') wrongForm
         whenJust (verifyWWAss wrap unwrap f) mAss
         rememberR workLabel
  where
    wrongForm :: String
    wrongForm = "definition does not have the form: work = unwrap (f (wrap work))"

--------------------------------------------------------------------------------------------------

-- | Worker/wrapper split.  Given @wrap@ and @unwrap@:
--
--   @prog = expr@  ==>  @prog = let f = \\ prog -> expr in let work = unwrap (f (wrap work)) in wrap work@
--
--   The resulting definition of @work@ is also passed to 'wwGenerateFusionR'.
--   A binding named @fix@ is unfolded along the way (see the module export notes
--   regarding 'Data.Function.fix').
wwSplitR :: Maybe WWAssumption -> CoreExpr -> CoreExpr -> RewriteH CoreDef
wwSplitR mAss wrap unwrap = fixIntroR >>> defAllR idR splitRhsR
  where
    workName, fixName :: TH.Name
    workName = TH.mkName "work"
    fixName  = TH.mkName "fix"

    -- Name the argument of fix "f", float that let outwards, then work on the let body.
    splitRhsR :: RewriteH CoreExpr
    splitRhsR = appAllR idR (letIntroR "f")
                >>> letFloatArgR
                >>> letAllR idR factoriseR

    -- Apply worker/wrapper factorisation, turn the new fix into a recursive let,
    -- and float that let out from the argument of wrap.
    factoriseR :: RewriteH CoreExpr
    factoriseR = forwardT (wwFacBR mAss wrap unwrap)
                 >>> appAllR idR unfoldFixR
                 >>> letFloatArgR

    -- Unfold fix, rename the let-bound variable to "work", then tidy each recursive definition.
    unfoldFixR :: RewriteH CoreExpr
    unfoldFixR = unfoldNameR fixName
                 >>> alphaLetWithR [workName]
                 >>> letRecAllR (const workDefR) idR

    -- Simplify the right-hand side of a definition, then record it for later fusion.
    workDefR :: RewriteH CoreDef
    workDefR = defAllR idR (betaReduceR >>> letNonRecSubstR)
               >>> extractR (wwGenerateFusionR mAss)

-- | As 'wwSplitR', but @wrap@ and @unwrap@ are supplied as 'CoreString's and are parsed
--   (wrap first, then unwrap) before the split is performed:
--
--   @prog = expr@  ==>  @prog = let f = \\ prog -> expr in let work = unwrap (f (wrap work)) in wrap work@
wwSplit :: Maybe WWAssumption -> CoreString -> CoreString -> RewriteH CoreDef
wwSplit mAss wrapS unwrapS =
  do wrap   <- parseCoreExprT wrapS
     unwrap <- parseCoreExprT unwrapS
     wwSplitR mAss wrap unwrap

-- | As 'wwSplit' but performs the static-argument transformation for @n@ static arguments first,
--   and optionally provides some of those arguments (specified by index) to all calls of wrap and unwrap.
--   This is useful if, for example, the expression, and wrap and unwrap, all have a @forall@ type.
--
--   * When @n@ is @0@ this is exactly 'wwSplit' (the index list is ignored).
--
--   * Every index must lie in the range @0 .. n-1@, otherwise the rewrite fails.
--
--   * The selected binders are applied to wrap and unwrap in binder order, regardless of the
--     order (or repetition) of the indices in the list.
wwSplitStaticArg :: Int -> [Int] -> Maybe WWAssumption -> CoreString -> CoreString -> RewriteH CoreDef
wwSplitStaticArg 0 _  = wwSplit
wwSplitStaticArg n is = \ mAss wrapS unwrapS ->
                            prefixFailMsg "worker/wrapper split (static argument variant) failed: " $
                            do -- A negative index would pass a bare (< n) test and then be silently dropped,
                               -- because no binder position matches it; reject it here instead.
                               guardMsg (all (\ i -> 0 <= i && i < n) is) "arguments for wrap and unwrap must be chosen from the statically transformed arguments."
                               -- the first n binders of the definition's right-hand side
                               bs <- defT successT (arr collectBinders) (\ () -> take n . fst)
                               let args = varsToCoreExprs [ b | (i,b) <- zip [0..] bs, i `elem` is ]

                               staticArgPosR [0..(n-1)] >>> defAllR idR
                                                                    (let -- split a recursive definition, using wrap/unwrap applied to the chosen static arguments
                                                                         wwSplitArgsR :: RewriteH CoreDef
                                                                         wwSplitArgsR = do wrap   <- parseCoreExprT wrapS
                                                                                           unwrap <- parseCoreExprT unwrapS
                                                                                           wwSplitR mAss (mkCoreApps wrap args) (mkCoreApps unwrap args)
                                                                      in
                                                                         -- locate a let expression, split its recursive definitions, then substitute the let away
                                                                         extractR $ do p <- considerConstructT LetExpr
                                                                                       localPathR p $ promoteExprR (letRecAllR (const wwSplitArgsR) idR >>> letSubstR)
                                                                    )


--------------------------------------------------------------------------------------------------

-- | Convert a proof of WW Assumption A into a proof of WW Assumption B.
--
--   The rewrite is returned unchanged: the left-hand side of Assumption B,
--   @wrap (unwrap (f a))@, is an instance of the left-hand side of Assumption A,
--   @wrap (unwrap a)@.
wwAssAimpliesAssB :: RewriteH CoreExpr -> RewriteH CoreExpr
wwAssAimpliesAssB = id

-- | Convert a proof of WW Assumption B into a proof of WW Assumption C.
--
--   The resulting rewrite targets an application: it applies the given proof underneath
--   the lambda in the argument position, and then eta-reduces that argument.
wwAssBimpliesAssC :: RewriteH CoreExpr -> RewriteH CoreExpr
wwAssBimpliesAssC assB = appAllR idR underLambdaR
  where
    underLambdaR :: RewriteH CoreExpr
    underLambdaR = lamAllR idR assB >>> etaReduceR

-- | Convert a proof of WW Assumption A into a proof of WW Assumption C,
--   by going via Assumption B ('wwAssAimpliesAssB' followed by 'wwAssBimpliesAssC').
wwAssAimpliesAssC :: RewriteH CoreExpr -> RewriteH CoreExpr
wwAssAimpliesAssC assA = wwAssBimpliesAssC (wwAssAimpliesAssB assA)

--------------------------------------------------------------------------------------------------

-- | Worker/wrapper Assumption A as a bidirectional rewrite:
--
--   @wrap (unwrap a)@  \<==\>  @a@
--
--   If a proof is supplied it is verified first (see 'verifyAssA'); the types of
--   @wrap@ and @unwrap@ are checked in either case.
wwAssA :: Maybe (RewriteH CoreExpr) -- ^ WW Assumption A
       -> CoreExpr                  -- ^ wrap
       -> CoreExpr                  -- ^ unwrap
       -> BiRewriteH CoreExpr
wwAssA mr wrap unwrap =
    beforeBiR precheckT $ \ (tyA,_) -> bidirectional elimR (introR tyA)
  where
    -- Verify the optional proof, then compute the (A,B) types of the wrap/unwrap pair.
    precheckT :: TranslateH CoreExpr (Type,Type)
    precheckT = whenJust (verifyAssA wrap unwrap) mr >> wrapUnwrapTypes wrap unwrap

    -- Forwards:  wrap (unwrap x)  ==>  x
    elimR :: RewriteH CoreExpr
    elimR = withPatFailMsg (wrongExprForm "App wrap (App unwrap x)") $
            do App wrapE (App unwrapE x) <- idR
               guardMsg (exprAlphaEq wrap wrapE)     "given wrapper does not match wrapper in expression."
               guardMsg (exprAlphaEq unwrap unwrapE) "given unwrapper does not match unwrapper in expression."
               return x

    -- Backwards:  x  ==>  wrap (unwrap x),  provided x has type A
    introR :: Type -> RewriteH CoreExpr
    introR tyA =
        idR >>= \ x ->
          guardMsg (exprKindOrType x `eqType` tyA) "type of expression does not match types of wrap/unwrap."
            >> return (App wrap (App unwrap x))

-- | As 'wwAssA', but @wrap@ and @unwrap@ are supplied as 'CoreString's to be parsed:
--
--   @wrap (unwrap a)@  \<==\>  @a@
wwA :: Maybe (RewriteH CoreExpr) -- ^ WW Assumption A
    -> CoreString                -- ^ wrap
    -> CoreString                -- ^ unwrap
    -> BiRewriteH CoreExpr
wwA = parse2beforeBiR . wwAssA

-- | Worker/wrapper Assumption B as a bidirectional rewrite:
--
--   @wrap (unwrap (f a))@  \<==\>  @f a@
--
--   If a proof is supplied it is verified first (see 'verifyAssB').  Each direction checks
--   that the body function in the expression matches @f@, and then defers to the
--   (unverified) Assumption A rewrite.
wwAssB :: Maybe (RewriteH CoreExpr) -- ^ WW Assumption B
       -> CoreExpr                  -- ^ wrap
       -> CoreExpr                  -- ^ unwrap
       -> CoreExpr                  -- ^ f
       -> BiRewriteH CoreExpr
wwAssB mr wrap unwrap f =
    beforeBiR (whenJust (verifyAssB wrap unwrap f) mr) $ \ () ->
        bidirectional elimR introR
  where
    unverifiedA :: BiRewriteH CoreExpr
    unverifiedA = wwAssA Nothing wrap unwrap

    -- Fail unless the function found in the expression is (alpha-equivalent to) f.
    guardBodyFunT :: CoreExpr -> TranslateH CoreExpr ()
    guardBodyFunT g = guardMsg (exprAlphaEq f g) "given body function does not match expression."

    -- Forwards:  wrap (unwrap (f a))  ==>  f a
    elimR :: RewriteH CoreExpr
    elimR = withPatFailMsg (wrongExprForm "App wrap (App unwrap (App f a))") $
            do App _ (App _ (App g _)) <- idR
               guardBodyFunT g
               forwardT unverifiedA

    -- Backwards:  f a  ==>  wrap (unwrap (f a))
    introR :: RewriteH CoreExpr
    introR = withPatFailMsg (wrongExprForm "App f a") $
             do App g _ <- idR
                guardBodyFunT g
                backwardT unverifiedA

-- | As 'wwAssB', but @wrap@, @unwrap@ and @f@ are supplied as 'CoreString's to be parsed:
--
--   @wrap (unwrap (f a))@  \<==\>  @f a@
wwB :: Maybe (RewriteH CoreExpr) -- ^ WW Assumption B
    -> CoreString                -- ^ wrap
    -> CoreString                -- ^ unwrap
    -> CoreString                -- ^ f
    -> BiRewriteH CoreExpr
wwB = parse3beforeBiR . wwAssB

-- | Worker/wrapper Assumption C as a bidirectional rewrite:
--
--   @fix A (\\ a -> wrap (unwrap (f a)))@  \<==\>  @fix A f@
--
--   The target expression must be an application of fix (checked with 'isFixExprT').
--   If a proof is supplied it is verified first (see 'verifyAssC').  Both directions are
--   built from the (unverified) Assumption B rewrite.
wwAssC :: Maybe (RewriteH CoreExpr) -- ^ WW Assumption C
       -> CoreExpr                  -- ^ wrap
       -> CoreExpr                  -- ^ unwrap
       -> CoreExpr                  -- ^ f
       -> BiRewriteH CoreExpr
wwAssC mr wrap unwrap f =
    beforeBiR (isFixExprT >> whenJust (verifyAssC wrap unwrap f) mr) $ \ () ->
        bidirectional collapseR expandR
  where
    unverifiedB :: BiRewriteH CoreExpr
    unverifiedB = wwAssB Nothing wrap unwrap f

    -- Forwards: use Assumption B under the lambda, then eta-reduce (see 'wwAssBimpliesAssC').
    collapseR :: RewriteH CoreExpr
    collapseR = wwAssBimpliesAssC (forwardT unverifiedB)

    -- Backwards: eta-expand the argument of fix, then use Assumption B (reversed) under the lambda.
    expandR :: RewriteH CoreExpr
    expandR = appAllR idR (etaExpandR "a" >>> lamAllR idR (backwardT unverifiedB))

-- | @fix A (\ a -> wrap (unwrap (f a)))@  \<==\>  @fix A f@
wwC :: Maybe (RewriteH CoreExpr) -- ^ WW Assumption C
    -> CoreString                -- ^ wrap
    -> CoreString                -- ^ unwrap
    -> CoreString                -- ^ f
    -> BiRewriteH CoreExpr
wwC mr = parse3beforeBiR (wwAssC mr)

--------------------------------------------------------------------------------------------------

-- | Verify a worker/wrapper assumption, dispatching on its tag to
--   'verifyAssA', 'verifyAssB' or 'verifyAssC'.
--   The body function @f@ is not needed to verify Assumption A.
verifyWWAss :: CoreExpr        -- ^ wrap
            -> CoreExpr        -- ^ unwrap
            -> CoreExpr        -- ^ f
            -> WWAssumption
            -> TranslateH x ()
verifyWWAss wrap unwrap _ (WWAssumption A ass) = verifyAssA wrap unwrap ass
verifyWWAss wrap unwrap f (WWAssumption B ass) = verifyAssB wrap unwrap f ass
verifyWWAss wrap unwrap f (WWAssumption C ass) = verifyAssC wrap unwrap f ass

-- | Verify worker/wrapper Assumption A for the given @wrap@ and @unwrap@,
--   using the supplied rewrite as the proof (checked by 'verifyRetractionT').
verifyAssA :: CoreExpr          -- ^ wrap
           -> CoreExpr          -- ^ unwrap
           -> RewriteH CoreExpr -- ^ WW Assumption A
           -> TranslateH x ()
verifyAssA wrap unwrap assA =
    prefixFailMsg "verification of worker/wrapper Assumption A failed: " $
      -- The type check is redundant, but produces a better error message.
      wrapUnwrapTypes wrap unwrap >> verifyRetractionT wrap unwrap assA

-- | Verify worker/wrapper Assumption B: for a fresh variable @a :: A@, the supplied
--   rewrite must take @wrap (unwrap (f a))@ to @f a@ (checked left-to-right).
verifyAssB :: CoreExpr          -- ^ wrap
           -> CoreExpr          -- ^ unwrap
           -> CoreExpr          -- ^ f
           -> RewriteH CoreExpr -- ^ WW Assumption B
           -> TranslateH x ()
verifyAssB wrap unwrap f assB =
    prefixFailMsg "verification of worker/wrapper assumption B failed: " $
      do (tyA,_) <- wrapUnwrapTypes wrap unwrap
         a       <- constT (newIdH "a" tyA)
         let fa = App f (Var a)
         verifyEqualityLeftToRightT (App wrap (App unwrap fa)) fa assB

-- | Verify worker/wrapper Assumption C: for a fresh variable @a :: A@, the supplied rewrite
--   must take @fix (\\ a -> wrap (unwrap (f a)))@ to @fix f@ (checked left-to-right).
verifyAssC :: CoreExpr          -- ^ wrap
           -> CoreExpr          -- ^ unwrap
           -> CoreExpr          -- ^ f
           -> RewriteH CoreExpr -- ^ WW Assumption C
           -> TranslateH a ()
verifyAssC wrap unwrap f assC =
    prefixFailMsg "verification of worker/wrapper assumption C failed: " $
      do (tyA,_)   <- wrapUnwrapTypes wrap unwrap
         a         <- constT (newIdH "a" tyA)
         fixF      <- mkFixT f
         fixWrapUF <- mkFixT (Lam a (App wrap (App unwrap (App f (Var a)))))
         verifyEqualityLeftToRightT fixWrapUF fixF assC

--------------------------------------------------------------------------------------------------

-- | Check that the given expressions form a valid wrap/unwrap pair, and return the pair of
--   types computed by 'funsWithInverseTypes' (used by callers as @(A,B)@, for
--   @wrap :: B -> A@ and @unwrap :: A -> B@).
--   Note that @unwrap@ is passed to 'funsWithInverseTypes' first.
wrapUnwrapTypes :: MonadCatch m => CoreExpr -> CoreExpr -> m (Type,Type)
wrapUnwrapTypes wrap unwrap = setFailMsg badPairMsg (funsWithInverseTypes unwrap wrap)
  where
    badPairMsg :: String
    badPairMsg = "given expressions have the wrong types to form a valid wrap/unwrap pair."

--------------------------------------------------------------------------------------------------