{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Agda.Termination.Monad where
import Prelude hiding (null)
import Control.Applicative hiding (empty)
import qualified Control.Monad.Fail as Fail
import Control.Monad.Reader
import Data.Foldable (Foldable)
import Data.Traversable (Traversable)
import Data.Monoid ( Monoid(..) )
import Data.Semigroup ( Semigroup(..) )
import qualified Data.Set as Set
import Agda.Interaction.Options
import Agda.Syntax.Abstract (AllNames)
import Agda.Syntax.Common
import Agda.Syntax.Internal
import Agda.Syntax.Internal.Pattern
import Agda.Syntax.Literal
import Agda.Syntax.Position (noRange)
import Agda.Termination.CutOff
import Agda.Termination.Order (Order,le,unknown)
import Agda.Termination.RecCheck (anyDefs)
import Agda.TypeChecking.Monad
import Agda.TypeChecking.Monad.Benchmark
import Agda.TypeChecking.Monad.Builtin
import Agda.TypeChecking.Pretty
import Agda.TypeChecking.Records
import Agda.TypeChecking.Reduce
import Agda.TypeChecking.Substitute
import Agda.Utils.Except ( MonadError )
import Agda.Utils.Function
import Agda.Utils.Functor
import Agda.Utils.Lens
import Agda.Utils.List   ( hasElem )
import Agda.Utils.Maybe
import Agda.Utils.Monad
import Agda.Utils.Monoid
import Agda.Utils.Null
import Agda.Utils.Pretty (Pretty, prettyShow)
import qualified Agda.Utils.Pretty as P
import Agda.Utils.VarSet (VarSet)
import qualified Agda.Utils.VarSet as VarSet
import Agda.Utils.Impossible
-- | A list of qualified names; the type of the 'terMutual' field of 'TerEnv'.
type MutualNames = [QName]
-- | A qualified name; the type carried (optionally) by the 'terTarget' field.
type Target = QName
-- | Synonym for 'Order'; the type of the 'terGuarded' field of 'TerEnv'.
type Guarded = Order
-- | The environment carried by the termination monad 'TerM'.
--
--   Fields are read through 'terAsks' and overridden locally through
--   'terLocal'; see the @terGet*@ / @terSet*@ helpers in this module.
data TerEnv = TerEnv
  -- Fields that 'runTerDefault' or 'defaultTerEnv' give a usable value.
  { terUseDotPatterns :: Bool
    -- ^ Read by 'terGetUseDotPatterns'; 'False' in 'defaultTerEnv'.
  , terSizeSuc :: Maybe QName
    -- ^ Filled from 'sizeSucName' by 'runTerDefault'; 'Nothing' by default.
  , terSharp   :: Maybe QName
    -- ^ Filled from 'nameOfSharp' of the 'coinductionKit' by 'runTerDefault'.
  , terCutOff  :: CutOff
    -- ^ Filled from the 'optTerminationDepth' pragma option by 'runTerDefault'.
  -- Fields that are '__IMPOSSIBLE__' in 'defaultTerEnv' must be set
  -- before they are read.
  , terCurrent :: QName
    -- ^ Read by 'terGetCurrent'; must be set with 'terSetCurrent' first.
    --   NOTE(review): presumably the definition currently being checked.
  , terMutual  :: MutualNames
    -- ^ Read by 'terGetMutual'; '__IMPOSSIBLE__' in 'defaultTerEnv'.
    --   NOTE(review): presumably the names of the mutual block being
    --   checked -- confirm against the code that builds the environment.
  , terUserNames :: [QName]
    -- ^ Read by 'terGetUserNames'; '__IMPOSSIBLE__' in 'defaultTerEnv'.
    --   NOTE(review): relation to 'terMutual' is not visible here.
  , terHaveInlinedWith :: Bool
    -- ^ 'False' by default; switched to 'True' (only) by
    --   'terSetHaveInlinedWith'.
  , terTarget  :: Maybe Target
    -- ^ Optional target name, 'Nothing' by default; see 'terSetTarget'.
    --   NOTE(review): exact role is not visible in this module.
  , terDelayed :: Delayed
    -- ^ 'NotDelayed' by default; see 'terSetDelayed'.
  , terMaskArgs :: [Bool]
    -- ^ Defaults to the infinite list @repeat False@.
    --   NOTE(review): presumably one mask flag per argument -- confirm.
  , terMaskResult :: Bool
    -- ^ 'False' by default; see 'terSetMaskResult'.
  , _terSizeDepth :: Int  -- no bang, unlike the strict fields below
    -- ^ Accessed through the lens 'terSizeDepth'.  'terSetSizeDepth'
    --   stores here the number of telescope entries whose reduced type
    --   is a size type or a meta variable.  Starts as '__IMPOSSIBLE__'.
  , terPatterns :: MaskedDeBruijnPatterns
    -- ^ Stored patterns; 'terGetPatterns' returns them raised by
    --   'terPatternsRaise'.  Starts as '__IMPOSSIBLE__'.
  , terPatternsRaise :: !Int
    -- ^ Amount by which 'terGetPatterns' raises 'terPatterns'.
    --   Starts at 0 and is incremented by 'terRaise'.  Strict field.
  , terGuarded :: !Guarded
    -- ^ Current guardedness 'Order'; starts at 'le', is reset to
    --   'unknown' by 'terUnguarded'.  Strict field.
  , terUseSizeLt :: Bool
    -- ^ Consulted by the 'UsableSizeVars' instances: a pattern variable
    --   counts as usable only while this flag is 'True'.  It is adjusted
    --   by 'conUseSizeLt' and 'projUseSizeLt'.
    --   'False' by default.
  , terUsableVars :: VarSet
    -- ^ Variable set installed by 'withUsableVars'; empty by default.
  }
-- | Baseline environment, with the fields listed in declaration order.
--
--   Fields initialised with @__IMPOSSIBLE__@ have no sensible default and
--   must be overridden before they are read.
defaultTerEnv :: TerEnv
defaultTerEnv = TerEnv
  { terUseDotPatterns  = False
  , terSizeSuc         = Nothing
  , terSharp           = Nothing
  , terCutOff          = defaultCutOff
  , terCurrent         = __IMPOSSIBLE__  -- has to be set by the caller
  , terMutual          = __IMPOSSIBLE__  -- has to be set by the caller
  , terUserNames       = __IMPOSSIBLE__  -- has to be set by the caller
  , terHaveInlinedWith = False
  , terTarget          = Nothing
  , terDelayed         = NotDelayed
  , terMaskArgs        = repeat False    -- infinite list of 'False'
  , terMaskResult      = False
  , _terSizeDepth      = __IMPOSSIBLE__  -- see 'terSetSizeDepth'
  , terPatterns        = __IMPOSSIBLE__  -- see 'terSetPatterns'
  , terPatternsRaise   = 0
  , terGuarded         = le
  , terUseSizeLt       = False
  , terUsableVars      = VarSet.empty
  }
-- | Monads that give reader-style access to a 'TerEnv'.
class (Functor m, Monad m) => MonadTer m where
  -- | Fetch the whole environment.
  terAsk   :: m TerEnv
  -- | Run a computation in a locally modified environment.
  terLocal :: (TerEnv -> TerEnv) -> m a -> m a
  -- | Fetch a projection of the environment.
  terAsks :: (TerEnv -> a) -> m a
  terAsks project = fmap project terAsk
-- | The termination monad: a 'ReaderT' over 'TerEnv' on top of 'TCM'.
--
--   'terM' unwraps the underlying reader computation.  The instances below
--   are all derived (the module enables @GeneralizedNewtypeDeriving@).
newtype TerM a = TerM { terM :: ReaderT TerEnv TCM a }
  deriving ( Functor
           , Applicative
           , Monad
           , Fail.MonadFail
           , MonadError TCErr
           , MonadBench Phase
           , MonadStatistics
           , HasOptions
           , HasBuiltins
           , MonadDebug
           , HasConstInfo
           , MonadIO
           , MonadTCEnv
           , MonadTCState
           , MonadTCM
           , ReadTCState
           , MonadReduce
           , MonadAddContext
           )
-- | 'TerM' reads its environment straight from the underlying 'ReaderT'.
instance MonadTer TerM where
  terAsk               = TerM ask
  terLocal f (TerM m)  = TerM (local f m)
-- | Run a 'TerM' computation in 'TCM' with the given environment.
runTer :: TerEnv -> TerM a -> TCM a
runTer tenv = flip runReaderT tenv . terM
-- | Run a 'TerM' computation in 'defaultTerEnv', after filling in the
--   cut-off (from the pragma options), the size successor name and the
--   sharp name (from the coinduction kit, if there is one).
runTerDefault :: TerM a -> TCM a
runTerDefault cont = do
  -- Same query order as before: options, size successor, coinduction kit.
  opts    <- pragmaOptions
  sizeSuc <- sizeSucName
  kit     <- coinductionKit
  let configured = defaultTerEnv
        { terCutOff  = optTerminationDepth opts
        , terSizeSuc = sizeSuc
        , terSharp   = nameOfSharp <$> kit
        }
  runTer configured cont
-- | Combine the results of two computations, running the left one first.
instance Semigroup m => Semigroup (TerM m) where
  ma <> mb = (<>) <$> ma <*> mb
-- | Pointwise monoid: effects are sequenced, results are combined.
instance (Semigroup m, Monoid m) => Monoid (TerM m) where
  mempty      = return mempty
  mappend     = (<>)
  mconcat ms  = mconcat <$> sequence ms
-- | Read the 'terUseDotPatterns' flag.
terGetUseDotPatterns :: TerM Bool
terGetUseDotPatterns = terUseDotPatterns <$> terAsk
-- | Run a computation with 'terUseDotPatterns' set to the given value.
terSetUseDotPatterns :: Bool -> TerM a -> TerM a
terSetUseDotPatterns b = terLocal update
  where update env = env { terUseDotPatterns = b }
-- | Read the 'terSizeSuc' field.
terGetSizeSuc :: TerM (Maybe QName)
terGetSizeSuc = terSizeSuc <$> terAsk
-- | Read the 'terCurrent' field.
terGetCurrent :: TerM QName
terGetCurrent = terCurrent <$> terAsk
-- | Run a computation with 'terCurrent' set to the given name.
terSetCurrent :: QName -> TerM a -> TerM a
terSetCurrent q = terLocal update
  where update env = env { terCurrent = q }
-- | Read the 'terSharp' field.
terGetSharp :: TerM (Maybe QName)
terGetSharp = terSharp <$> terAsk
-- | Read the 'terCutOff' field.
terGetCutOff :: TerM CutOff
terGetCutOff = terCutOff <$> terAsk
-- | Read the 'terMutual' field.
terGetMutual :: TerM MutualNames
terGetMutual = terMutual <$> terAsk
-- | Read the 'terUserNames' field.
terGetUserNames :: TerM [QName]
terGetUserNames = terUserNames <$> terAsk
-- | Read the 'terTarget' field.
terGetTarget :: TerM (Maybe Target)
terGetTarget = terTarget <$> terAsk
-- | Run a computation with 'terTarget' set to the given value.
terSetTarget :: Maybe Target -> TerM a -> TerM a
terSetTarget t = terLocal update
  where update env = env { terTarget = t }
-- | Read the 'terHaveInlinedWith' flag.
terGetHaveInlinedWith :: TerM Bool
terGetHaveInlinedWith = terHaveInlinedWith <$> terAsk
-- | Run a computation with the 'terHaveInlinedWith' flag switched on.
terSetHaveInlinedWith :: TerM a -> TerM a
terSetHaveInlinedWith = terLocal switchOn
  where switchOn env = env { terHaveInlinedWith = True }
-- | Read the 'terDelayed' field.
terGetDelayed :: TerM Delayed
terGetDelayed = terDelayed <$> terAsk
-- | Run a computation with 'terDelayed' set to the given value.
terSetDelayed :: Delayed -> TerM a -> TerM a
terSetDelayed b = terLocal update
  where update env = env { terDelayed = b }
-- | Read the 'terMaskArgs' field.
terGetMaskArgs :: TerM [Bool]
terGetMaskArgs = terMaskArgs <$> terAsk
-- | Run a computation with 'terMaskArgs' set to the given list.
terSetMaskArgs :: [Bool] -> TerM a -> TerM a
terSetMaskArgs b = terLocal update
  where update env = env { terMaskArgs = b }
-- | Read the 'terMaskResult' flag.
terGetMaskResult :: TerM Bool
terGetMaskResult = terMaskResult <$> terAsk
-- | Run a computation with 'terMaskResult' set to the given value.
terSetMaskResult :: Bool -> TerM a -> TerM a
terSetMaskResult b = terLocal update
  where update env = env { terMaskResult = b }
-- | The stored patterns, raised by the current 'terPatternsRaise' amount.
--   When that amount is zero the patterns are returned untouched.
terGetPatterns :: TerM (MaskedDeBruijnPatterns)
terGetPatterns = do
  shift <- terAsks terPatternsRaise
  pats  <- terAsks terPatterns
  return $ applyUnless (shift == 0) (map (fmap (raise shift))) pats
-- | Run a computation with 'terPatterns' set to the given patterns.
terSetPatterns :: MaskedDeBruijnPatterns -> TerM a -> TerM a
terSetPatterns ps = terLocal update
  where update env = env { terPatterns = ps }
-- | Run a computation with 'terPatternsRaise' incremented by one, so that
--   'terGetPatterns' raises the stored patterns one step further.
terRaise :: TerM a -> TerM a
terRaise = terLocal oneMore
  where oneMore env = env { terPatternsRaise = terPatternsRaise env + 1 }
-- | Read the 'terGuarded' field.
terGetGuarded :: TerM Guarded
terGetGuarded = terGuarded <$> terAsk
-- | Run a computation with 'terGuarded' transformed by the given function.
terModifyGuarded :: (Order -> Order) -> TerM a -> TerM a
terModifyGuarded f = terLocal update
  where update env = env { terGuarded = f (terGuarded env) }
-- | Run a computation with 'terGuarded' replaced by the given order.
terSetGuarded :: Order -> TerM a -> TerM a
terSetGuarded o = terModifyGuarded (const o)
-- | Run a computation with 'terGuarded' reset to 'unknown'.
terUnguarded :: TerM a -> TerM a
terUnguarded = terModifyGuarded (const unknown)
-- | Lens onto the '_terSizeDepth' field of 'TerEnv'.
terSizeDepth :: Lens' Int TerEnv
terSizeDepth f e = store <$> f (_terSizeDepth e)
  where store depth = e { _terSizeDepth = depth }
-- | Read the 'terUsableVars' field.
terGetUsableVars :: TerM VarSet
terGetUsableVars = terUsableVars <$> terAsk
-- | Run a computation with 'terUsableVars' transformed by the given function.
terModifyUsableVars :: (VarSet -> VarSet) -> TerM a -> TerM a
terModifyUsableVars f = terLocal update
  where update env = env { terUsableVars = f (terUsableVars env) }
-- | Run a computation with 'terUsableVars' replaced by the given set.
terSetUsableVars :: VarSet -> TerM a -> TerM a
terSetUsableVars vs = terModifyUsableVars (const vs)
-- | Read the 'terUseSizeLt' flag.
terGetUseSizeLt :: TerM Bool
terGetUseSizeLt = terUseSizeLt <$> terAsk
-- | Run a computation with 'terUseSizeLt' transformed by the given function.
terModifyUseSizeLt :: (Bool -> Bool) -> TerM a -> TerM a
terModifyUseSizeLt f = terLocal update
  where update env = env { terUseSizeLt = f (terUseSizeLt env) }
-- | Run a computation with 'terUseSizeLt' replaced by the given value.
terSetUseSizeLt :: Bool -> TerM a -> TerM a
terSetUseSizeLt b = terModifyUseSizeLt (const b)
-- | Compute the usable size variables of the given patterns, report them
--   on the @term.size@ debug channel, and run the continuation with
--   'terUsableVars' set to exactly that set.
withUsableVars :: UsableSizeVars a => a -> TerM b -> TerM b
withUsableVars pats cont = do
  usable <- usableSizeVars pats
  reportSLn "term.size" 70 $ "usableSizeVars = " ++ show usable
  reportSDoc "term.size" 20 $
    if null usable
      then "no usuable size vars"
      else "the size variables amoung these variables are usable: " <+>
             sep [ prettyTCM (var i) | i <- VarSet.toList usable ]
  terSetUsableVars usable cont
-- | Run the continuation with 'terUseSizeLt' set according to the given
--   constructor: 'False' if it is an eta or coinductive record
--   constructor, 'True' otherwise.
conUseSizeLt :: QName -> TerM a -> TerM a
conUseSizeLt c m = do
  etaOrCoinductive <- liftTCM $ isEtaOrCoinductiveRecordConstructor c
  terSetUseSizeLt (not etaOrCoinductive) m
-- | Run the continuation with 'terUseSizeLt' set to whether the given
--   name is a coinductive projection (in the non-recursive sense of
--   'isCoinductiveProjection'), logging the decision on @term.size@.
projUseSizeLt :: QName -> TerM a -> TerM a
projUseSizeLt q m = do
  coinductive <- isCoinductiveProjection False q
  let negation = if coinductive then "" else "not "
  reportSLn "term.size" 20 $
    negation ++ "using SIZELT vars after projection " ++ prettyShow q
  terSetUseSizeLt coinductive m
-- | Is the name a proper projection out of an inductive record?
--
--   The flat builtin (when a coinduction kit is present) and anything that
--   is not a proper projection yield 'False'.  The verdict is also
--   reported on the @term.proj@ debug channel.
isProjectionButNotCoinductive :: MonadTCM tcm => QName -> tcm Bool
isProjectionButNotCoinductive qn = liftTCM $ do
  inductive <- inductiveProjection qn
  let verdict :: String
      verdict
        | inductive = "is an inductive projection"
        | otherwise = "is either not a projection or coinductive"
  reportSDoc "term.proj" 60 $
    "identifier" <+> prettyTCM qn <+> text verdict
  return inductive
  where
    -- The actual check, without the debug output.
    inductiveProjection q = do
      flat <- fmap nameOfFlat <$> coinductionKit
      if Just q == flat then return False else
        isProjection q >>= \case
          Just Projection{ projProper = Just{}, projFromType = t }
            -> isInductiveRecord (unArg t)
          _ -> return False
-- | Check whether a name is a projection of a coinductive record.
--
--   The flat builtin (from the coinduction kit) always counts.  Otherwise
--   the name must be a proper projection whose record type has
--   @recInduction == Just CoInductive@.
--
--   If the first argument (@mustBeRecursive@) is 'True', the result is
--   additionally 'True' only if the record carries a non-empty 'recMutual'
--   list and some name of that list occurs in the normalised field types
--   following the first @projIndex@ telescope entries, or in the
--   normalised target type of the projection.
--
--   Progress is logged on the @term.guardedness@ debug channel.
isCoinductiveProjection :: MonadTCM tcm => Bool -> QName -> tcm Bool
isCoinductiveProjection mustBeRecursive q = liftTCM $ do
  reportSLn "term.guardedness" 40 $ "checking isCoinductiveProjection " ++ prettyShow q
  flat <- fmap nameOfFlat <$> coinductionKit
  -- The flat builtin is treated as a coinductive projection outright.
  if Just q == flat then return True else do
    pdef <- getConstInfo q
    case isProjection_ (theDef pdef) of
      Just Projection{ projProper = Just{}, projFromType = Arg _ r, projIndex = n } ->
        caseMaybeM (isRecord r) __IMPOSSIBLE__ $ \ rdef -> do
          -- Anything but an explicitly coinductive record is rejected here.
          if recInduction rdef /= Just CoInductive then return False else do
            reportSLn "term.guardedness" 40 $ prettyShow q ++ " is coinductive; record type is " ++ prettyShow r
            if not mustBeRecursive then return True else do
              reportSLn "term.guardedness" 40 $ prettyShow q ++ " must be recursive"
              if not (safeRecRecursive rdef) then return False else do
                reportSLn "term.guardedness" 40 $ prettyShow q ++ " has been declared recursive, doing actual check now..."
                -- Split the telescope of the projection's type after its
                -- first @n@ entries: @pars@ are those first entries, @tel'@
                -- is the remainder, @core@ is the target type.
                -- NOTE(review): presumably @pars@ are the record parameters
                -- -- confirm against the definition of 'projIndex'.
                let TelV tel core = telView' (defType pdef)
                    (pars, tel') = splitAt n $ telToList tel
                    mut = fromMaybe __IMPOSSIBLE__ $ recMutual rdef
                -- @mut@ cannot be 'Nothing' or empty at this point:
                -- 'safeRecRecursive' has already returned 'True'.
                --
                -- Debug output only; it does not influence the result.
                reportSDoc "term.guardedness" 40 $ inTopContext $ sep
                  [ "looking for recursive occurrences of"
                  , sep (map prettyTCM mut)
                  , "in"
                  , addContext pars $ prettyTCM (telFromList tel')
                  , "and"
                  , addContext tel $ prettyTCM core
                  ]
                when (null mut) __IMPOSSIBLE__
                names <- anyDefs (mut `hasElem`) =<< normalise (map (snd . unDom) tel', core)
                reportSDoc "term.guardedness" 40 $
                  "found" <+> if null names then "none" else sep (map prettyTCM $ Set.toList names)
                return $ not $ null names
      _ -> do
        reportSLn "term.guardedness" 40 $ prettyShow q ++ " is not a proper projection"
        return False
  where
  -- 'True' exactly for a 'Record' definition whose 'recMutual' is
  -- @Just@ a non-empty list; 'False' for every other definition,
  -- including records whose 'recMutual' is still 'Nothing'.
  safeRecRecursive :: Defn -> Bool
  safeRecRecursive (Record { recMutual = Just qs }) = not $ null qs
  safeRecRecursive _ = False
-- | Depth of a pattern, computed by folding over it with 'foldrPattern':
--   every constructor pattern adds one to the depth accumulated below it,
--   every other pattern leaves that depth unchanged.
patternDepth :: forall a. Pattern' a -> Int
patternDepth pat = getMaxNat (foldrPattern bump pat)
  where
    bump :: Pattern' a -> MaxNat -> MaxNat
    bump p below = case p of
      ConP{} -> succ below
      _      -> below
-- | A literal pattern carrying a fixed marker string.
--   NOTE(review): going by the name, it stands in for an unused pattern
--   variable -- confirm at the use sites.
unusedVar :: DeBruijnPattern
unusedVar = litP marker
  where marker = LitString noRange "term.unused.pat.var"
-- | Things from which a set of usable size variables can be extracted.
--   Instances exist for single patterns and pattern lists, both plain and
--   'Masked'.
class UsableSizeVars a where
  -- | Collect, in 'TerM', the de Bruijn indices of the usable variables.
  --   The instances in this module consult 'terGetUseSizeLt' at each
  --   variable.
  usableSizeVars :: a -> TerM VarSet
-- | A variable is usable iff 'terGetUseSizeLt' holds where it is visited.
--   Constructor patterns adjust that flag for their subpatterns via
--   'conUseSizeLt'; all other patterns contribute nothing.
instance UsableSizeVars DeBruijnPattern where
  usableSizeVars = foldrPattern step
    where
      step pat = case pat of
        VarP _ x   -> \ _ -> do
          useLt <- terGetUseSizeLt
          return $ if useLt then VarSet.singleton (dbPatVarIndex x) else mempty
        ConP c _ _ -> conUseSizeLt (conName c)
        LitP{}     -> nothing
        DotP{}     -> nothing
        ProjP{}    -> nothing
        IApplyP{}  -> nothing
        DefP{}     -> nothing
      -- Discard whatever was accumulated below and yield the empty set.
      nothing _ = return mempty
-- | Usable variables of a pattern list.  A projection pattern resets the
--   'terUseSizeLt' flag (via 'projUseSizeLt') for the patterns after it;
--   otherwise the results of head and tail are combined.
instance UsableSizeVars [DeBruijnPattern] where
  usableSizeVars []                 = return mempty
  usableSizeVars (ProjP _ q : rest) = projUseSizeLt q $ usableSizeVars rest
  usableSizeVars (p : rest)         =
    mappend <$> usableSizeVars p <*> usableSizeVars rest
-- | Like the instance for plain patterns, except that under a mask
--   constructor patterns contribute no usable variables at all.
instance UsableSizeVars (Masked DeBruijnPattern) where
  usableSizeVars (Masked isMasked pat) = foldrPattern step pat
    where
      step p = case p of
        VarP _ x   -> \ _ -> do
          useLt <- terGetUseSizeLt
          return $ if useLt then VarSet.singleton (dbPatVarIndex x) else mempty
        ConP c _ _
          | isMasked  -> nothing
          | otherwise -> conUseSizeLt (conName c)
        LitP{}     -> nothing
        DotP{}     -> nothing
        ProjP{}    -> nothing
        IApplyP{}  -> nothing
        DefP{}     -> nothing
      -- Discard whatever was accumulated below and yield the empty set.
      nothing _ = return mempty
-- | Usable variables of a list of masked patterns.  A (masked or
--   unmasked) projection pattern resets the 'terUseSizeLt' flag for the
--   patterns after it; otherwise head and tail results are combined.
instance UsableSizeVars MaskedDeBruijnPatterns where
  usableSizeVars []                            = return mempty
  usableSizeVars (Masked _ (ProjP _ q) : rest) = projUseSizeLt q $ usableSizeVars rest
  usableSizeVars (p : rest)                    =
    mappend <$> usableSizeVars p <*> usableSizeVars rest
-- | A list of de Bruijn patterns, each paired with a mask flag.
type MaskedDeBruijnPatterns = [Masked DeBruijnPattern]
-- | A value decorated with a boolean mask flag.
data Masked a = Masked
  { getMask   :: Bool  -- ^ 'True' for values built with 'masked'.
  , getMasked :: a     -- ^ The decorated value itself.
  } deriving (Eq, Ord, Show, Functor, Foldable, Traversable)
-- | Wrap a value with the mask flag set.
masked :: a -> Masked a
masked x = Masked { getMask = True, getMasked = x }
-- | Wrap a value with the mask flag cleared.
notMasked :: a -> Masked a
notMasked x = Masked { getMask = False, getMasked = x }
-- | Traverse the payload and keep the mask flag as it is.
instance Decoration Masked where
  traverseF f (Masked m a) = fmap (Masked m) (f a)
-- | Masked values are printed inside double parentheses, unmasked ones
--   are printed plainly.
instance PrettyTCM a => PrettyTCM (Masked a) where
  prettyTCM (Masked isMasked x) =
    let doc = prettyTCM x
    in  if isMasked then parens (parens doc) else doc
-- | A list of 'CallInfo's.  The 'Semigroup' and 'Monoid' instances are
--   derived from the underlying list, so paths combine by concatenation
--   and 'mempty' is the empty path.
newtype CallPath = CallPath { callInfos :: [CallInfo] }
  deriving (Show, Semigroup, Monoid, AllNames)
-- | Only the intermediate calls are shown: the last 'CallInfo' is dropped,
--   each remaining one is preceded by an arrow, and a closing arrow is
--   appended.  A path with fewer than two entries prints as 'empty'.
--
--   This includes the empty path ('mempty'), which previously raised an
--   exception because of the partial 'init'.
instance Pretty CallPath where
  pretty (CallPath cis0)
    | null cis  = empty
    | otherwise = P.hsep (map (\ ci -> arrow P.<+> P.pretty ci) cis) P.<+> arrow
    where
      -- Total replacement for @init cis0@: all elements but the last, and
      -- @[]@ for an empty list.
      cis   = zipWith const cis0 (drop 1 cis0)
      arrow = "-->"
-- | Count the entries of the telescope whose reduced type is either a size
--   type or a meta variable, and run the continuation with that count
--   stored in the size depth (lens 'terSizeDepth').
terSetSizeDepth :: Telescope -> TerM a -> TerM a
terSetSizeDepth tel cont = do
  depth <- liftTCM $ sum <$> mapM weight (telToList tel)
  terLocal (set terSizeDepth depth) cont
  where
    -- 1 for a size-typed or meta-typed entry, 0 for anything else.
    weight dom = do
      a      <- reduce $ snd $ unDom dom
      isSize <- isJust <$> isSizeType a
      return $ case unEl a of
        _ | isSize -> 1
        MetaV{}    -> 1
        _          -> 0