{-# LANGUAGE CPP                        #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}

-- | The monad for the termination checker.
--
--   The termination monad @TerM@ is an extension of
--   the type checking monad 'TCM' by an environment
--   with information needed by the termination checker.

module Agda.Termination.Monad where

import Prelude hiding (null)

import Control.Applicative hiding (empty)
import Control.Monad.Reader
import Control.Monad.Writer
import Control.Monad.State

import Data.Foldable (Foldable)
import Data.Traversable (Traversable)
import Data.Semigroup (Semigroup)

import Agda.Interaction.Options

import Agda.Syntax.Abstract (IsProjP(..), AllNames)
import Agda.Syntax.Common
import Agda.Syntax.Internal
import Agda.Syntax.Literal
import Agda.Syntax.Position (noRange)

import Agda.Termination.CutOff
import Agda.Termination.Order (Order,le,unknown)
import Agda.Termination.RecCheck (anyDefs)

import Agda.TypeChecking.Monad
import Agda.TypeChecking.Monad.Benchmark
import Agda.TypeChecking.Monad.Builtin
import Agda.TypeChecking.Pretty as TCP
import Agda.TypeChecking.Records
import Agda.TypeChecking.Reduce
import Agda.TypeChecking.Substitute

import Agda.Utils.Except ( MonadError(catchError, throwError) )
import Agda.Utils.Function
import Agda.Utils.Functor
import Agda.Utils.Lens
import Agda.Utils.Maybe
import Agda.Utils.Monad
import Agda.Utils.Null
import Agda.Utils.Pretty (Pretty, prettyShow)
import qualified Agda.Utils.Pretty as P
import Agda.Utils.VarSet (VarSet)
import qualified Agda.Utils.VarSet as VarSet

#include "undefined.h"
import Agda.Utils.Impossible

-- | The mutual block we are checking.
--
--   The functions are numbered according to their order of appearance
--   in this list.

type MutualNames = [QName]

-- | The target of the function we are checking.
--   Only constructors of the target are considered guarding
--   (see the 'terTarget' field of 'TerEnv').

type Target = QName

-- | The current guardedness level.
--   Starts out as 'le' in 'defaultTerEnv' and is set to 'unknown'
--   by 'terUnguarded'.

type Guarded = Order

-- | The termination environment.
--
--   Configuration options together with context information that is
--   changed locally (via 'terLocal') while descending into definitions
--   and terms.

data TerEnv = TerEnv

  -- First part: options, configuration.

  { terUseDotPatterns :: Bool
    -- ^ Are we mining dot patterns to find evidence of structural descent?
  , terGuardingTypeConstructors :: Bool
    -- ^ Do we assume that record and data type constructors
    --   preserve guardedness?
  , terInlineWithFunctions :: Bool
    -- ^ Do we inline with functions to enhance termination checking of with?
  , terSizeSuc :: Maybe QName
    -- ^ The name of size successor, if any.
  , terSharp   :: Maybe QName
    -- ^ The name of the delay constructor (sharp), if any.
  , terCutOff  :: CutOff
    -- ^ Depth at which to cut off the structural order.

  -- Second part: accumulated info during descent into declarations/terms.

  , terCurrent :: QName
    -- ^ The name of the function we are currently checking.
  , terMutual  :: MutualNames
    -- ^ The names of the functions in the mutual block we are checking.
    --   This includes the internally generated functions
    --   (with, extendedlambda, coinduction).
  , terUserNames :: [QName]
    -- ^ The list of names actually appearing in the file (abstract syntax).
    --   Excludes the internally generated functions.
  , terTarget  :: Maybe Target
    -- ^ Target type of the function we are currently termination checking.
    --   Only the constructors of 'Target' are considered guarding.
  , terDelayed :: Delayed
    -- ^ Are we checking a delayed definition?
  , terMaskArgs :: [Bool]
    -- ^ Only consider the 'notMasked' 'False' arguments for establishing termination.
    --   See issue #1023.
    --   'defaultTerEnv' uses an infinite list of 'False' (mask nothing).
  , terMaskResult :: Bool
    -- ^ Only consider guardedness if 'False' (not masked).
  , _terSizeDepth :: Int  -- lazy by intention!
    -- ^ How many @SIZELT@ relations do we have in the context
    --   (= clause telescope).  Used to approximate termination
    --   for metas in call args.
    --   Kept lazy because 'defaultTerEnv' leaves it undefined until it
    --   is set (e.g. by 'terSetSizeDepth').
  , terPatterns :: MaskedDeBruijnPats
    -- ^ The patterns of the clause we are checking.
    --   Read them through 'terGetPatterns', which raises them by
    --   'terPatternsRaise'.
  , terPatternsRaise :: !Int
    -- ^ Number of additional binders we have gone under
    --   (and consequently need to raise the patterns to compare to terms).
    --   Updated during call graph extraction, hence strict.
  , terGuarded :: !Guarded
    -- ^ The current guardedness status.  Changes as we go deeper into the term.
    --   Updated during call graph extraction, hence strict.
  , terUseSizeLt :: Bool
    -- ^ When extracting usable size variables during construction of the call
    --   matrix, can we take the variable for use with SIZELT constraints from the context?
    --   Yes, if we are under an inductive constructor.
    --   No, if we are under a record constructor.
    --   (See issue #1015).
  , terUsableVars :: VarSet
    -- ^ Pattern variables that can be compared to argument variables using SIZELT.
  }

-- | An empty termination environment.
--
--   Every field that has a safe default gets one: starting from these
--   values the termination checker never misses a termination error that
--   it would have reported with better settings.
--
--   Fields without a safe default are set to @IMPOSSIBLE@ and have to be
--   overwritten before anyone looks at them.

--   Note: Do not write the name of the impossible macro in the haddock
--   comment above, since CPP expands it and haddock then fails to parse.

defaultTerEnv :: TerEnv
defaultTerEnv = TerEnv
  { -- Options and configuration.
    terUseDotPatterns           = False          -- must be False initially!
  , terGuardingTypeConstructors = False
  , terInlineWithFunctions      = True
  , terSizeSuc                  = Nothing
  , terSharp                    = Nothing
  , terCutOff                   = defaultCutOff
    -- Context information; the unset ones must be provided by the caller.
  , terCurrent                  = __IMPOSSIBLE__ -- needs to be set!
  , terMutual                   = __IMPOSSIBLE__ -- needs to be set!
  , terUserNames                = __IMPOSSIBLE__ -- needs to be set!
  , terTarget                   = Nothing
  , terDelayed                  = NotDelayed
  , terMaskArgs                 = repeat False   -- use all arguments (mask none)
  , terMaskResult               = False          -- use result (do not mask)
  , _terSizeDepth               = __IMPOSSIBLE__ -- needs to be set!
  , terPatterns                 = __IMPOSSIBLE__ -- needs to be set!
  , terPatternsRaise            = 0              -- not under any extra binder yet
  , terGuarded                  = le             -- not initially guarded
  , terUseSizeLt                = False          -- initially, not under data constructor
  , terUsableVars               = VarSet.empty
  }

-- | Monads that give access to a termination environment 'TerEnv'.

class (Functor m, Monad m) => MonadTer m where
  -- | Read the whole termination environment.
  terAsk   :: m TerEnv
  -- | Run a subcomputation in a locally modified termination environment.
  terLocal :: (TerEnv -> TerEnv) -> m a -> m a

  -- | Read one component of the termination environment.
  terAsks :: (TerEnv -> a) -> m a
  terAsks f = fmap f terAsk

-- | Termination monad.
--
--   A reader over the termination environment 'TerEnv' on top of 'TCM'.
--   Run it with 'runTer' or 'runTerDefault'.

newtype TerM a = TerM { terM :: ReaderT TerEnv TCM a }
  deriving (Functor, Applicative, Monad, MonadBench Phase)

-- The termination environment is the reader layer of 'TerM'.
instance MonadTer TerM where
  terAsk              = TerM ask
  terLocal f (TerM m) = TerM (local f m)

-- | Run a termination computation in the given environment.
runTer :: TerEnv -> TerM a -> TCM a
runTer tenv m = runReaderT (terM m) tenv

-- | Run a 'TerM' computation in the default environment, configured
--   from the current pragma options and the available builtins.

runTerDefault :: TerM a -> TCM a
runTerDefault cont = do

  -- The pragma options determine the cutoff depth, whether type
  -- constructors are guarding, and whether with functions are inlined.
  opts <- pragmaOptions

  -- The name of size suc (if sized types are enabled).
  suc <- sizeSucName

  -- The name of sharp (if available).
  sharp <- fmap nameOfSharp <$> coinductionKit

  -- Andreas, 2014-08-28
  -- We do not inline with functions if --without-K.
  let tenv = defaultTerEnv
        { terGuardingTypeConstructors = optGuardingTypeConstructors opts
        , terInlineWithFunctions      = not (optWithoutK opts)
        , terSizeSuc                  = suc
        , terSharp                    = sharp
        , terCutOff                   = optTerminationDepth opts
        }

  runTer tenv cont

-- * Termination monad is a 'MonadTCM'.

-- The 'TCEnv' lives in the underlying 'TCM'; 'local' modifies it
-- while keeping the termination environment unchanged.
instance MonadReader TCEnv TerM where
  ask              = TerM (lift ask)
  local f (TerM m) = TerM $ ReaderT $ \ tenv -> local f (runReaderT m tenv)

-- The type checking state is that of the underlying 'TCM'.
instance MonadState TCState TerM where
  get   = TerM (lift get)
  put s = TerM (lift (put s))

-- IO actions are lifted through the reader layer.
instance MonadIO TerM where
  liftIO io = TerM (liftIO io)

-- Type checking actions run unchanged beneath the reader layer.
instance MonadTCM TerM where
  liftTCM m = TerM (lift m)

-- Errors are those of 'TCM'; the handler runs in the same termination
-- environment as the failing computation.
instance MonadError TCErr TerM where
  throwError err = liftTCM (throwError err)
  catchError m handler = TerM $ ReaderT $ \ tenv ->
    runTer tenv m `catchError` \ err -> runTer tenv (handler err)

-- Signature lookups are delegated to 'TCM'.
instance HasConstInfo TerM where
  getConstInfo q       = liftTCM (getConstInfo q)
  getRewriteRulesFor q = liftTCM (getRewriteRulesFor q)

-- * Modifiers and accessors for the termination environment in the monad.

-- | Do record and data type constructors preserve guardedness?
terGetGuardingTypeConstructors :: TerM Bool
terGetGuardingTypeConstructors = terGuardingTypeConstructors <$> terAsk

-- | Are with functions inlined?
terGetInlineWithFunctions :: TerM Bool
terGetInlineWithFunctions = terInlineWithFunctions <$> terAsk

-- | Are dot patterns mined for structural descent?
terGetUseDotPatterns :: TerM Bool
terGetUseDotPatterns = terUseDotPatterns <$> terAsk

-- | Run a subcomputation with dot-pattern mining switched on or off.
terSetUseDotPatterns :: Bool -> TerM a -> TerM a
terSetUseDotPatterns b = terLocal (\ env -> env { terUseDotPatterns = b })

-- | The name of the size successor, if any.
terGetSizeSuc :: TerM (Maybe QName)
terGetSizeSuc = terSizeSuc <$> terAsk

-- | The function currently being checked.
terGetCurrent :: TerM QName
terGetCurrent = terCurrent <$> terAsk

-- | Run a subcomputation with a different current function.
terSetCurrent :: QName -> TerM a -> TerM a
terSetCurrent q = terLocal (\ env -> env { terCurrent = q })

-- | The name of the delay constructor (sharp), if any.
terGetSharp :: TerM (Maybe QName)
terGetSharp = terSharp <$> terAsk

-- | Depth at which the structural order is cut off.
terGetCutOff :: TerM CutOff
terGetCutOff = terCutOff <$> terAsk

-- | All functions of the mutual block, including generated ones.
terGetMutual :: TerM MutualNames
terGetMutual = terMutual <$> terAsk

-- | The functions of the mutual block as written by the user.
terGetUserNames :: TerM [QName]
terGetUserNames = terUserNames <$> terAsk

-- | The target type whose constructors count as guarding, if any.
terGetTarget :: TerM (Maybe Target)
terGetTarget = terTarget <$> terAsk

-- | Run a subcomputation with a different target.
terSetTarget :: Maybe Target -> TerM a -> TerM a
terSetTarget t = terLocal (\ env -> env { terTarget = t })

-- | Are we checking a delayed definition?
terGetDelayed :: TerM Delayed
terGetDelayed = terDelayed <$> terAsk

-- | Run a subcomputation with a different delayedness.
terSetDelayed :: Delayed -> TerM a -> TerM a
terSetDelayed b = terLocal (\ env -> env { terDelayed = b })

-- | Which arguments are masked (excluded from structural descent)?
terGetMaskArgs :: TerM [Bool]
terGetMaskArgs = terMaskArgs <$> terAsk

-- | Run a subcomputation with a different argument mask.
terSetMaskArgs :: [Bool] -> TerM a -> TerM a
terSetMaskArgs b = terLocal (\ env -> env { terMaskArgs = b })

-- | Is the result masked (guardedness not considered)?
terGetMaskResult :: TerM Bool
terGetMaskResult = terMaskResult <$> terAsk

-- | Run a subcomputation with a different result mask.
terSetMaskResult :: Bool -> TerM a -> TerM a
terSetMaskResult b = terLocal (\ env -> env { terMaskResult = b })

-- | The patterns of the current clause.  Each de Bruijn index is raised
--   by the number of binders entered via 'terRaise' since they were set.
terGetPatterns :: TerM MaskedDeBruijnPats
terGetPatterns = do
  env <- terAsk
  let n = terPatternsRaise env
  -- Skip the traversal when nothing needs raising.
  return $ applyUnless (n == 0) (map (fmap (fmap (n +)))) (terPatterns env)

-- | Run a subcomputation with the given clause patterns.
terSetPatterns :: MaskedDeBruijnPats -> TerM a -> TerM a
terSetPatterns ps = terLocal (\ env -> env { terPatterns = ps })

-- | Run a subcomputation under one more binder.
terRaise :: TerM a -> TerM a
terRaise = terLocal (\ env -> env { terPatternsRaise = terPatternsRaise env + 1 })

-- | The current guardedness status.
terGetGuarded :: TerM Guarded
terGetGuarded = terGuarded <$> terAsk

-- | Run a subcomputation with transformed guardedness.
terModifyGuarded :: (Order -> Order) -> TerM a -> TerM a
terModifyGuarded f = terLocal (\ env -> env { terGuarded = f (terGuarded env) })

-- | Run a subcomputation with the given guardedness.
terSetGuarded :: Order -> TerM a -> TerM a
terSetGuarded o = terModifyGuarded (const o)

-- | Run a subcomputation with guardedness set to 'unknown'.
terUnguarded :: TerM a -> TerM a
terUnguarded = terSetGuarded unknown

-- | Should the codomain part of a function type preserve guardedness?
--   Only if type constructors are guarding; otherwise the codomain is
--   checked unguarded.
terPiGuarded :: TerM a -> TerM a
terPiGuarded m = do
  preserve <- terGetGuardingTypeConstructors
  if preserve then m else terUnguarded m

-- | Lens focusing on the '_terSizeDepth' field of a 'TerEnv'.

terSizeDepth :: Lens' Int TerEnv
terSizeDepth f env = fmap (\ d -> env { _terSizeDepth = d }) (f (_terSizeDepth env))

-- | Accessors for 'terUsableVars'.

-- | Pattern variables usable in SIZELT comparisons.
terGetUsableVars :: TerM VarSet
terGetUsableVars = terUsableVars <$> terAsk

-- | Run a subcomputation with transformed usable variables.
terModifyUsableVars :: (VarSet -> VarSet) -> TerM a -> TerM a
terModifyUsableVars f = terLocal (\ env -> env { terUsableVars = f (terUsableVars env) })

-- | Run a subcomputation with the given usable variables.
terSetUsableVars :: VarSet -> TerM a -> TerM a
terSetUsableVars vs = terModifyUsableVars (const vs)

-- | Accessors for 'terUseSizeLt'.

-- | May variables from the context be used for SIZELT constraints?
terGetUseSizeLt :: TerM Bool
terGetUseSizeLt = terUseSizeLt <$> terAsk

-- | Run a subcomputation with transformed 'terUseSizeLt'.
terModifyUseSizeLt :: (Bool -> Bool) -> TerM a -> TerM a
terModifyUseSizeLt f = terLocal (\ env -> env { terUseSizeLt = f (terUseSizeLt env) })

-- | Run a subcomputation with the given 'terUseSizeLt'.
terSetUseSizeLt :: Bool -> TerM a -> TerM a
terSetUseSizeLt b = terModifyUseSizeLt (const b)

-- | Compute the usable size variables of the given patterns,
--   log them, and run the subcomputation with exactly these variables.
withUsableVars :: UsableSizeVars a => a -> TerM b -> TerM b
withUsableVars pats m = usableSizeVars pats >>= \ vars -> do
  reportSLn "term.size" 20 $ "usableSizeVars = " ++ show vars
  terSetUsableVars vars m

-- | Set 'terUseSizeLt' when going under constructor @c@:
--   'False' for a record constructor, 'True' for any other constructor.
conUseSizeLt :: QName -> TerM a -> TerM a
conUseSizeLt c m = do
  isRecCon <- liftTCM $ isJust <$> isRecordConstructor c
  terSetUseSizeLt (not isRecCon) m

-- | Set 'terUseSizeLt' for arguments following projection @q@.
--   SIZELT variables (j < i) are only usable after a coinductive projection.
--   The projection does not need to be recursive (Issue 1470).
projUseSizeLt :: QName -> TerM a -> TerM a
projUseSizeLt q m = do
  coinductive <- isCoinductiveProjection False q
  let msg = "using SIZELT vars after projection " ++ prettyShow q
  reportSLn "term.size" 20 $ if coinductive then msg else "not " ++ msg
  terSetUseSizeLt coinductive m

-- | Is the name a proper projection of an inductive record?
--
--   For termination checking purposes flat should not be considered a
--   projection: flat preserves neither the structural order nor
--   guardedness, unlike other projections.  So flat yields 'False'.
--   Andreas, 2012-06-09: the same applies to projections of recursive records.
isProjectionButNotCoinductive :: MonadTCM tcm => QName -> tcm Bool
isProjectionButNotCoinductive qn = liftTCM $ do
  b <- inductiveProj
  reportSDoc "term.proj" 60 $
    text "identifier" <+> prettyTCM qn <+>
      text (if b then "is an inductive projection"
                 else "is either not a projection or coinductive")
  return b
  where
    -- The actual check, without logging.
    inductiveProj = do
      flat <- fmap nameOfFlat <$> coinductionKit
      if Just qn == flat then return False else do
        mp <- isProjection qn
        case mp of
          Just Projection{ projProper = True, projFromType = t } ->
            isInductiveRecord (unArg t)
          _ -> return False

-- | Check whether a projection belongs to a coinductive record
--   and, if @mustBeRecursive@ is set, whether it is actually recursive.
--
--   * The flat function of the coinduction kit always yields 'True'.
--
--   * Names that are not proper projections yield 'False'.
--
--   * With @mustBeRecursive = False@, every proper projection of a
--     coinductive record yields 'True'.
--
--   * With @mustBeRecursive = True@, the record must moreover be declared
--     recursive, and the field type must mention the record itself or
--     one of its mutually defined names.
--
--   E.g. with @mustBeRecursive = True@:
--   @
--      isCoinductiveProjection (Stream.head) = return False
--
--      isCoinductiveProjection (Stream.tail) = return True
--   @
isCoinductiveProjection :: MonadTCM tcm => Bool -> QName -> tcm Bool
isCoinductiveProjection mustBeRecursive q = liftTCM $ do
  reportSLn "term.guardedness" 40 $ "checking isCoinductiveProjection " ++ prettyShow q
  flat <- fmap nameOfFlat <$> coinductionKit
  -- yes for ♭
  if Just q == flat then return True else do
    pdef <- getConstInfo q
    case isProjection_ (theDef pdef) of
      Just Projection{ projProper = True, projFromType = Arg _ r, projIndex = n } ->
        -- The type a proper projection projects from must be a record.
        caseMaybeM (isRecord r) __IMPOSSIBLE__ $ \ rdef -> do
          -- no for inductive records (recursiveness is checked further down)
          if recInduction rdef /= Just CoInductive then return False else do
            reportSLn "term.guardedness" 40 $ prettyShow q ++ " is coinductive"
            if not mustBeRecursive then return True else do
              reportSLn "term.guardedness" 40 $ prettyShow q ++ " must be recursive"
              -- no for records not declared recursive
              if not (recRecursive rdef) then return False else do
                reportSLn "term.guardedness" 40 $ prettyShow q ++ " has been declared recursive, doing actual check now..."
                -- TODO: the following test for recursiveness of a projection should be cached.
                -- E.g., it could be stored in the @Projection@ component.
                -- Now check if type of field mentions mutually recursive symbol.
                -- Get the type of the field by dropping record parameters and record argument.
                -- (The first projIndex entries of the projection's telescope are dropped.)
                let TelV tel core = telView' (defType pdef)
                    tel' = drop n $ telToList tel
                -- Check if any recursive symbols appear in the record type.
                -- Q (2014-07-01): Should we normalize the type?
                reportSDoc "term.guardedness" 40 $ sep
                  [ text "looking for recursive occurrences in"
                  , prettyTCM (telFromList tel')
                  , text "and"
                  , prettyTCM core
                  ]
                -- Occurrences of the record or its mutual block in the
                -- remaining telescope or the result type.
                names <- anyDefs (r : recMutual rdef) (map (snd . unDom) tel', core)
                reportSDoc "term.guardedness" 40 $
                  text "found" <+> sep (map prettyTCM names)
                return $ not $ null names
      _ -> do
        reportSLn "term.guardedness" 40 $ prettyShow q ++ " is not a proper projection"
        return False


-- * De Bruijn patterns.

-- | A list of patterns with de Bruijn indices as variables.
type DeBruijnPats = [DeBruijnPat]

-- | Patterns with variables as de Bruijn indices.
type DeBruijnPat = DeBruijnPat' Int

-- | Patterns whose variables are of type @a@; the 'Functor' instance
--   maps over the variables (used e.g. by 'raiseDBP').
data DeBruijnPat' a
  = VarDBP a
    -- ^ De Bruijn Index.
  | ConDBP QName [DeBruijnPat' a]
    -- ^ The name refers to either an ordinary
    --   constructor or the successor function on sized types.
  | LitDBP Literal
    -- ^ Literal.  Also abused to censor part of a pattern.
  | TermDBP Term
    -- ^ Part of dot pattern that cannot be converted into a pattern.
  | ProjDBP ProjOrigin QName
    -- ^ Projection pattern.
  deriving (Functor, Show)

-- Only projection patterns are projections.
instance IsProjP (DeBruijnPat' a) where
  isProjP p = case p of
    ProjDBP o d -> Just (o, AmbQ [d])
    _           -> Nothing

-- Variables print as the corresponding term variable; constructor
-- applications and dot terms are parenthesised; projections get a dot.
instance PrettyTCM DeBruijnPat where
  prettyTCM p = case p of
    VarDBP i    -> prettyTCM (var i)
    ConDBP c ps -> parens (prettyTCM c <+> hsep (map prettyTCM ps))
    LitDBP l    -> prettyTCM l
    TermDBP v   -> parens (prettyTCM v)
    ProjDBP _ d -> text "." TCP.<> prettyTCM d

-- | How long is the path to the deepest variable?
--   Each constructor adds one level; all other patterns have depth 0.
patternDepth :: DeBruijnPat' a -> Int
patternDepth (ConDBP _ ps) = succ $ foldr (max . patternDepth) 0 ps
patternDepth VarDBP{}      = 0
patternDepth LitDBP{}      = 0
patternDepth TermDBP{}     = 0
patternDepth ProjDBP{}     = 0


-- | Placeholder pattern that censors a pattern which cannot be used
--   for structural descent (a string literal abused as a marker).

unusedVar :: DeBruijnPat
unusedVar = LitDBP $ LitString noRange "term.unused.pat.var"

-- | @raiseDBP n ps@ adds @n@ to every de Bruijn index in @ps@.
--   Needed when going under a binder while analysing a term.
--   Raising by 0 returns the patterns unchanged.

raiseDBP :: Int -> DeBruijnPats -> DeBruijnPats
raiseDBP n
  | n == 0    = id
  | otherwise = map (fmap (n +))

-- | Extract variables from 'DeBruijnPat's that could witness a decrease
--   via a SIZELT constraint.
--
--   These variables must be under an inductive constructor (with no record
--   constructor in the way), or after a coinductive projection (with no
--   inductive one in the way).

class UsableSizeVars a where
  -- | The usable size variables, depending on 'terUseSizeLt' as it is
  --   adjusted by constructors and projections along the way.
  usableSizeVars :: a -> TerM VarSet

-- A variable counts only if SIZELT use is currently enabled; constructors
-- adjust that flag for their arguments; other patterns contribute nothing.
instance UsableSizeVars DeBruijnPat where
  usableSizeVars (VarDBP i)    = ifM terGetUseSizeLt (return $ VarSet.singleton i) (return mempty)
  usableSizeVars (ConDBP c ps) = conUseSizeLt c $ usableSizeVars ps
  usableSizeVars LitDBP{}      = return mempty
  usableSizeVars TermDBP{}     = return mempty
  usableSizeVars ProjDBP{}     = return mempty

-- A projection determines SIZELT use for all patterns after it;
-- other patterns are processed one by one and their results united.
instance UsableSizeVars DeBruijnPats where
  usableSizeVars []                    = return mempty
  usableSizeVars (ProjDBP _ q : rest)  = projUseSizeLt q $ usableSizeVars rest
  usableSizeVars (pat         : rest)  = liftA2 mappend (usableSizeVars pat) (usableSizeVars rest)

-- Like the unmasked instance, except that a masked constructor pattern
-- contributes nothing.  A masked variable is still usable.
instance UsableSizeVars (Masked DeBruijnPat) where
  usableSizeVars (Masked m p) = case p of
    VarDBP i -> do
      useLt <- terGetUseSizeLt
      return $ if useLt then VarSet.singleton i else mempty
    ConDBP c ps
      | m         -> return mempty
      | otherwise -> conUseSizeLt c $ usableSizeVars ps
    LitDBP{}  -> return mempty
    TermDBP{} -> return mempty
    ProjDBP{} -> return mempty

-- A (masked or not) projection determines SIZELT use for all patterns
-- after it; other patterns are processed one by one and their results united.
instance UsableSizeVars MaskedDeBruijnPats where
  usableSizeVars []                              = return mempty
  usableSizeVars (Masked _ (ProjDBP _ q) : rest) = projUseSizeLt q $ usableSizeVars rest
  usableSizeVars (pat                    : rest) = liftA2 mappend (usableSizeVars pat) (usableSizeVars rest)

-- * Masked patterns (which are not eligible for structural descent, only for size descent)
--   See issue #1023.

-- | Clause patterns, each tagged with whether it is masked.
type MaskedDeBruijnPats = [Masked DeBruijnPat]

-- | A value tagged with a mask flag.  'Functor', 'Foldable' and
--   'Traversable' act on the wrapped value and keep the flag.
data Masked a = Masked
  { getMask   :: Bool  -- ^ True if thing not eligible for structural descent.
  , getMasked :: a     -- ^ Thing.
  } deriving (Eq, Ord, Show, Functor, Foldable, Traversable)

-- | Mark a thing as not eligible for structural descent.
masked :: a -> Masked a
masked x = Masked { getMask = True, getMasked = x }

-- | Mark a thing as eligible for structural descent.
notMasked :: a -> Masked a
notMasked x = Masked { getMask = False, getMasked = x }

-- The mask flag is the decoration; traversal acts on the wrapped thing.
instance Decoration Masked where
  traverseF f (Masked flag x) = fmap (Masked flag) (f x)

-- | Masked things are printed inside double parentheses,
--   unmasked things as they are.
instance PrettyTCM a => PrettyTCM (Masked a) where
  prettyTCM (Masked m a)
    | m         = parens $ parens $ prettyTCM a
    | otherwise = prettyTCM a

-- * Call paths

-- | The call information is stored as a free monoid
--   over 'CallInfo'.  As long as we never look at it,
--   only accumulate it, it does not matter whether we use
--   'Set', (nub) list, or 'Tree'.
--   Internally, due to laziness, it is anyway a binary tree of
--   'mappend' nodes and singleton leaves.
--   Since we define no order on 'CallInfo' (expensive),
--   we cannot use a 'Set' or nub list.
--   Performance-wise, I could not see a difference between Set and list.

newtype CallPath = CallPath { callInfos :: [CallInfo] }
  deriving (Show, Semigroup, Monoid, AllNames)

-- | Only show intermediate nodes.  (Drop last 'CallInfo').
--   Each remaining node is preceded by an arrow, and one final arrow
--   follows.  A path with at most one 'CallInfo' prints as empty;
--   in particular the empty path no longer crashes, as it did with 'init'.
instance Pretty CallPath where
  pretty (CallPath cis0) = if null cis then empty else
    P.hsep (map (\ ci -> arrow P.<+> P.pretty ci) cis) P.<+> arrow
    where
      -- All but the last entry.  Unlike 'init', this is total:
      -- pairing each element with its successor drops the last one,
      -- and the empty path yields the empty list.
      cis   = zipWith const cis0 (drop 1 cis0)
      arrow = P.text "-->"

-- * Size depth estimation

-- | A very crude way of estimating the @SIZELT@ chains
--   @i > j > k@ in context.  Returns 3 in this case.
--   Overapproximates: every telescope entry whose type reduces to a
--   size type, or to a meta variable, counts as one link.

-- TODO: more precise analysis, constructing a tree
-- of relations between size variables.
terSetSizeDepth :: Telescope -> TerM a -> TerM a
terSetSizeDepth tel cont = do
  depth <- liftTCM $ sum <$> mapM sizeLink (telToList tel)
  terLocal (set terSizeDepth depth) cont
  where
    -- 1 if the binder's reduced type is a size type or a meta, else 0.
    sizeLink dom = do
      a <- reduce $ snd $ unDom dom
      isSize <- isJust <$> isSizeType a
      return $ if isSize then 1 else case ignoreSharing $ unEl a of
        MetaV{} -> 1
        _       -> 0