{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE ScopedTypeVariables #-}

module HERMIT.Dictionary.FixPoint
    ( -- * Operations on the Fixed Point Operator (fix)
      HERMIT.Dictionary.FixPoint.externals
      -- ** Rewrites and BiRewrites on Fixed Points
    , fixIntroR
    , fixIntroNonRecR
    , fixIntroRecR
    , fixComputationRuleBR
    , fixRollingRuleBR
    , fixFusionRuleBR
      -- ** Utilities
    , isFixExprT
    ) where

import Control.Arrow
import Control.Monad
import Control.Monad.IO.Class

import Data.String (fromString)

import HERMIT.Context
import HERMIT.Core
import HERMIT.Monad
import HERMIT.Kure hiding ((<$>))
import HERMIT.External
import HERMIT.GHC
import HERMIT.Name
import HERMIT.ParserCore
import HERMIT.Utilities

import HERMIT.Dictionary.Common
import HERMIT.Dictionary.Function
import HERMIT.Dictionary.Kure
import HERMIT.Dictionary.Reasoning
import HERMIT.Dictionary.Undefined
import HERMIT.Dictionary.Unfold

import Prelude.Compat

--------------------------------------------------------------------------------------------------

-- | Externals for manipulating fixed points.
--
-- Each entry registers a command name, the (bi)rewrite it runs (promoted to
-- work on 'LCore'), and the lines of help text shown for it.
externals ::  [External]
externals =
    [ external "fix-intro" (promoteCoreR fixIntroR :: RewriteH LCore)
        [ "rewrite a function binding into a non-recursive binding using fix" ] .+ Introduce .+ Context
    , external "fix-computation-rule" (promoteExprBiR fixComputationRuleBR :: BiRewriteH LCore)
        [ "Fixed-Point Computation Rule",
          "fix t f  <==>  f (fix t f)"
        ] .+ Context
    , external "fix-rolling-rule" (promoteExprBiR fixRollingRuleBR :: BiRewriteH LCore)
        [ "Rolling Rule",
          "fix tyA (\\ a -> f (g a))  <==>  f (fix tyB (\\ b -> g (f b)))"
        ] .+ Context
      -- Fully checked fusion: takes proofs for both the commuting condition and strictness of f.
    , external "fix-fusion-rule" ((\ f g h r1 r2 strictf -> promoteExprBiR
                                                                (fixFusionRule (Just (r1,r2)) (Just strictf) f g h))
                                                                :: CoreString -> CoreString -> CoreString
                                                                    -> RewriteH LCore -> RewriteH LCore
                                                                    -> RewriteH LCore -> BiRewriteH LCore)
        [ "Fixed-point Fusion Rule"
        , "Given f :: A -> B, g :: A -> A, h :: B -> B, and"
        , "proofs that, for some x, (f (g a) ==> x) and (h (f a) ==> x) and that f is strict, then"
        , "f (fix g) <==> fix h"
        ] .+ Context
      -- Strictness of f is NOT checked here (Nothing is passed for the strictness proof).
    , external "fix-fusion-rule-unsafe" ((\ f g h r1 r2 -> promoteExprBiR (fixFusionRule (Just (r1,r2)) Nothing f g h))
                                                            :: CoreString -> CoreString -> CoreString
                                                                -> RewriteH LCore -> RewriteH LCore -> BiRewriteH LCore)
        [ "(Unsafe) Fixed-point Fusion Rule"
        , "Given f :: A -> B, g :: A -> A, h :: B -> B, and"
        , "a proof that, for some x, (f (g a) ==> x) and (h (f a) ==> x), then"
        , "f (fix g) <==> fix h"
        , "Note that the precondition that f is strict is required to hold."
        ] .+ Context .+ PreCondition
      -- Neither precondition is checked here (Nothing is passed for both proofs).
      -- NOTE(review): this entry shares its name with the one above but has a different
      -- type; presumably the two are told apart by their arguments -- confirm.
    , external "fix-fusion-rule-unsafe" ((\ f g h -> promoteExprBiR (fixFusionRule Nothing Nothing f g h))
                                                        :: CoreString -> CoreString -> CoreString -> BiRewriteH LCore)
        [ "(Very Unsafe) Fixed-point Fusion Rule"
        , "Given f :: A -> B, g :: A -> A, h :: B -> B, then"
        , "f (fix g) <==> fix h"
        , "Note that the preconditions that f (g a) == h (f a) and that f is strict are required to hold."
        ] .+ Context .+ PreCondition
    ]

--------------------------------------------------------------------------------------------------

-- | Introduce @fix@ at a binding node of 'Core'.
--
-- The recursive-definition form ('fixIntroRecR') is attempted first; if it
-- fails, the non-recursive binding form ('fixIntroNonRecR') is attempted.
fixIntroR :: ( AddBindings c, BoundVars c, ExtendPath c Crumb, HasEmptyContext c, ReadBindings c, ReadPath c Crumb
             , HasHermitMEnv m, LiftCoreM m, MonadCatch m, MonadIO m, MonadThings m, MonadUnique m )
          => Rewrite c m Core
fixIntroR = onRecDef <+ onNonRecBind
  where
    onRecDef     = promoteR fixIntroRecR
    onNonRecBind = promoteR fixIntroNonRecR

-- |  @f = e@   ==\>   @f = fix (\\ f -> e)@, for a 'NonRec' binding.
--
-- Only the right-hand side is rewritten (by 'polyFixT'); the binder is kept.
fixIntroNonRecR :: ( AddBindings c, BoundVars c, ExtendPath c Crumb, HasEmptyContext c, ReadBindings c, ReadPath c Crumb
                   , HasHermitMEnv m, LiftCoreM m, MonadCatch m, MonadIO m, MonadThings m, MonadUnique m )
                => Rewrite c m CoreBind
fixIntroNonRecR = prefixFailMsg "fix introduction failed: " $ do
    NonRec fun body <- idR
    -- feed the RHS to polyFixT, then rebuild the binding around the result
    liftM (NonRec fun) $ return body >>> polyFixT fun

-- |  @f = e@   ==\>   @f = fix (\\ f -> e)@
--
-- Operates on a single definition ('Def'); only the right-hand side is
-- rewritten (by 'polyFixT'), the binder is kept.
fixIntroRecR :: ( AddBindings c, BoundVars c, ExtendPath c Crumb, HasEmptyContext c, ReadBindings c, ReadPath c Crumb
                , HasHermitMEnv m, LiftCoreM m, MonadCatch m, MonadIO m, MonadThings m, MonadUnique m )
             => Rewrite c m CoreDef
fixIntroRecR = prefixFailMsg "fix introduction failed: " $ do
    Def fun body <- idR
    -- feed the RHS to polyFixT, then rebuild the definition around the result
    liftM (Def fun) $ return body >>> polyFixT fun

-- | Helper for fixIntroNonRecR and fixIntroRecR. Argument is function name.
-- Meant to be applied to RHS of function.
--
-- Steps, in the order they appear below:
--
--   1. Split the RHS into its leading binders and a body (via 'collectTyBinders').
--
--   2. Create a fresh identifier @f'@ that reuses the unqualified name of @f@
--      and has the type of the body.
--
--   3. Under an extended context, try to unfold every call to @f@ in the body.
--      The attempt is wrapped in 'tryR', so a failure to unfold is not an error.
--
--   4. Build @fix (\\ f' -> body')@ with 'buildFixT' and re-abstract the
--      result over the binders collected in step 1.
polyFixT :: forall c m.
            ( AddBindings c, BoundVars c, ExtendPath c Crumb, HasEmptyContext c, ReadBindings c, ReadPath c Crumb
            , HasHermitMEnv m, LiftCoreM m, MonadCatch m, MonadIO m, MonadThings m, MonadUnique m )
         => Id -> Rewrite c m CoreExpr
polyFixT f = do
    -- peel the leading binders off the RHS; they are put back around the final result
    (tvs, body) <- arr collectTyBinders
    -- fresh binder for the lambda handed to fix: same unqualified name as f, typed as the body
    f' <- constT $ newIdH (unqualifiedName f) (exprType body)
    -- Rewrite the body under a context c' in which f is bound to a lambda over
    -- the collected binders that simply returns f', and f' is bound to the body.
    body' <- contextonlyT $ \ c -> do
                let constLam = mkCoreLams tvs $ varToCoreExpr f'
                    c' = addBindingGroup (NonRec f constLam)   -- we want to unfold f such as to throw away TyArgs
                       $ addBindingGroup (NonRec f' body)    c -- add f' to context so its in-scope after unfolding
                -- unfold only identifiers equal to f; the annotation pins the traversal to Core nodes in this c and m
                applyT (tryR (extractR (anyCallR (promoteR (unfoldPredR (const . (==f))) :: Rewrite c m Core)))) c' body
    -- fix (\ f' -> body'), abstracted again over the binders from the first step
    liftM (mkCoreLams tvs) $ buildFixT $ Lam f' body'

--------------------------------------------------------------------------------------------------

-- | @fix ty f@  \<==\>  @f (fix ty f)@
--
-- Forwards, one application of @f@ is unrolled around the fixed point.
-- Backwards, the outer application is dropped, provided the applied function
-- is alpha-equal to the function inside the @fix@.
fixComputationRuleBR :: BiRewriteH CoreExpr
fixComputationRuleBR = bidirectional unrollOnce rollBack
  where
    -- @fix ty f@  ==>  @f (fix ty f)@
    unrollOnce :: RewriteH CoreExpr
    unrollOnce = prefixFailMsg "fix computation rule failed: " $ do
        (_, fun) <- isFixExprT
        -- the whole current expression becomes the argument of fun
        liftM (App fun) idR

    -- @f (fix ty f)@  ==>  @fix ty f@
    rollBack :: RewriteH CoreExpr
    rollBack = prefixFailMsg "(backwards) fix computation rule failed: " $ do
        App outerFun inner <- idR
        (_, innerFun) <- constant inner >>> isFixExprT
        guardMsg (exprAlphaEq outerFun innerFun) "external function does not match internal expression"
        return inner


-- | @fix tyA (\\ a -> f (g a))@  \<==\>  @f (fix tyB (\\ b -> g (f b)))@
--
-- NOTE(review): neither direction checks whether the lambda-bound variable
-- occurs free in the matched @f@ or @g@ -- confirm this cannot arise in practice.
fixRollingRuleBR :: BiRewriteH CoreExpr
fixRollingRuleBR = bidirectional rollForwards rollBackwards
  where
    -- @fix tyA (\\ a -> f (g a))@  ==>  @f (fix tyB (\\ b -> g (f b)))@
    rollForwards :: RewriteH CoreExpr
    rollForwards =
        prefixFailMsg "rolling rule failed: " . withPatFailMsg badBodyShape $ do
            (tyA, Lam v (App f (App g (Var v')))) <- isFixExprT
            -- the innermost argument must be the lambda's own binder
            guardMsg (v == v') badBodyShape
            (tyA', tyB) <- funExprsWithInverseTypes g f
            guardMsg (eqType tyA tyA') "Type mismatch: this shouldn't have happened, report this as a bug."
            liftM (App f) (fixOfComposite tyB g f)

    -- @f (fix tyB (\\ b -> g (f b)))@  ==>  @fix tyA (\\ a -> f (g a))@
    rollBackwards :: RewriteH CoreExpr
    rollBackwards =
        prefixFailMsg "(reversed) rolling rule failed: " $
        withPatFailMsg "not an application." $ do
            App f fixArg <- idR
            withPatFailMsg badBodyShape $ do
                (tyB, Lam v (App g (App f' (Var v')))) <- constant fixArg >>> isFixExprT
                guardMsg (v == v') badBodyShape
                -- the function applied outside must be the one used inside the fix body
                guardMsg (exprAlphaEq f f') "external function does not match internal expression"
                (tyA, tyB') <- funExprsWithInverseTypes g f
                guardMsg (eqType tyB tyB') "Type mismatch: this shouldn't have happened, report this as a bug."
                fixOfComposite tyA f g

    -- Builds @fix (\\ x -> outer (inner x))@ for a fresh @x@ of the given type.
    fixOfComposite :: Type -> CoreExpr -> CoreExpr -> TransformH z CoreExpr
    fixOfComposite ty outer inner =
        constT (newIdH "x" ty) >>= \ x ->
            buildFixT (Lam x (App outer (App inner (Var x))))

    -- Failure message used when the body of the fix has an unexpected shape.
    badBodyShape :: String
    badBodyShape = "body of fix does not have the form: Lam v (App f (App g (Var v)))"

--------------------------------------------------------------------------------------------------

-- f :: A -> B
-- g :: A -> A
-- h :: B -> B

-- | If @f@ is strict, then (@f (g a)@ == @h (f a)@)  ==\>  (@f (fix g)@ == @fix h@)
--
-- Before either direction is attempted, a precondition check runs:
--
--   * the argument\/result types of @f@ must be alpha-equal to the types of @g@ and @h@ respectively;
--
--   * if a strictness proof is supplied, it is passed to 'verifyStrictT' for @f@;
--
--   * if an equality proof is supplied, it is used to verify that
--     @f (g a)@ and @h (f a)@ reach a common target, for a fresh @a@.
--
-- Passing 'Nothing' for a proof skips the corresponding check.
fixFusionRuleBR :: Maybe (EqualityProof HermitC HermitM) -> Maybe (RewriteH CoreExpr) -> CoreExpr -> CoreExpr -> CoreExpr -> BiRewriteH CoreExpr
fixFusionRuleBR meq mfstrict f g h = beforeBiR preconditionsT (\ () -> bidirectional fuseL fuseR)
  where
    preconditionsT =
        prefixFailMsg "fixed-point fusion failed: " $ do
            (_, tyA, tyB) <- funExprArgResTypesM f -- TODO: don't throw away TyVars
            (_, tyA')     <- endoFunExprTypeM g
            (_, tyB')     <- endoFunExprTypeM h
            guardMsg (typeAlphaEq tyA tyA' && typeAlphaEq tyB tyB') "given functions do not have compatible types."
            whenJust (verifyStrictT f) mfstrict
            whenJust (\ proof ->
                        do a <- constT (newGlobalIdH "a" tyA)
                           -- left side: f (g a); right side: h (f a)
                           verifyEqualityCommonTargetT (App f (App g (Var a))) (App h (App f (Var a))) proof
                     )
                     meq

    -- @f (fix g)@  ==>  @fix h@
    fuseL :: RewriteH CoreExpr
    fuseL =
        prefixFailMsg "fixed-point fusion failed: " $
        withPatFailMsg (wrongExprForm "App f (fix g)") $ do
            App f' fixg <- idR
            guardMsg (exprAlphaEq f f') "first argument function does not match."
            (_, g') <- return fixg >>> isFixExprT
            guardMsg (exprAlphaEq g g') "second argument function does not match."
            buildFixT h

    -- @fix h@  ==>  @f (fix g)@
    fuseR :: RewriteH CoreExpr
    fuseR = prefixFailMsg "(reversed) fixed-point fusion failed: " $ do
        (_, h') <- isFixExprT
        guardMsg (exprAlphaEq h h') "third argument function does not match."
        liftM (App f) (buildFixT g)

-- | If @f@ is strict, then (@f (g a)@ == @h (f a)@)  ==\>  (@f (fix g)@ == @fix h@)
--
-- Variant of 'fixFusionRuleBR' whose three functions are given as
-- 'CoreString's and whose optional proofs are rewrites on 'LCore'; the proofs
-- are narrowed with 'extractR' before being handed on.
fixFusionRule :: Maybe (RewriteH LCore, RewriteH LCore) -> Maybe (RewriteH LCore) -> CoreString -> CoreString -> CoreString -> BiRewriteH CoreExpr
fixFusionRule meq mfstrict = parse3beforeBiR (fixFusionRuleBR equalityProof strictnessProof)
  where
    equalityProof   = fmap (extractR *** extractR) meq
    strictnessProof = fmap extractR mfstrict

--------------------------------------------------------------------------------------------------

-- | Check that the expression has the form "fix t (f :: t -> t)", returning "t" and "f".
--
-- Fails if the expression is not a call with exactly a type argument followed
-- by one value argument, or if the called identifier is not the one found
-- for 'fixLocation'.
isFixExprT :: TransformH CoreExpr (Type,CoreExpr)
isFixExprT =
    -- fix :: forall a. (a -> a) -> a
    withPatFailMsg (wrongExprForm "fix t f") $ do
        (Var headId, [Type ty, fun]) <- callT
        expected <- findIdT fixLocation
        guardMsg (headId == expected) (mismatchMsg headId)
        return (ty, fun)
  where
    -- message reported when the head of the call is some other identifier
    mismatchMsg i = unqualifiedName i ++ " does not match " ++ show fixLocation

--------------------------------------------------------------------------------------------------

-- | Qualified name that 'isFixExprT' passes to 'findIdT' to locate the @fix@
-- identifier it compares against.
fixLocation :: HermitName
fixLocation = fromString "Data.Function.fix"

--------------------------------------------------------------------------------------------------