{-# LANGUAGE CPP #-}
module Agda.Termination.RecCheck
    ( recursive
    , anyDefs
    )
 where
import Control.Monad (forM, forM_)
import Data.Graph
import Data.IntMap (IntMap)
import qualified Data.IntMap as IntMap
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe
import Data.Set (Set)
import qualified Data.Set as Set
import Agda.Syntax.Internal
import Agda.Syntax.Internal.Defs
import Agda.TypeChecking.Monad
import Agda.Utils.Functor ((<.>))
import Agda.Utils.List    (hasElem)
import Agda.Utils.Pretty  (prettyShow)
import Agda.Utils.Impossible
-- | For one definition: maps each clause index (counted from 0, in clause
--   order) to the set of names of interest mentioned in that clause.
type NamesPerClause = IntMap (Set QName)
-- | Analyse which of the given names are (mutually) recursive.
--
--   Builds a reference graph restricted to @names@, splits it into strongly
--   connected components, marks every acyclic component as non-recursive,
--   and for every cyclic component marks each clause of each member
--   according to whether that clause mentions a member of the component.
--
--   Returns the cyclic components, as classified by 'stronglyConnComp'.
recursive :: [QName] -> TCM [[QName]]
recursive names = do
  -- For every name: the names mentioned per clause, and in total.
  analysed <- mapM (recDef (names `hasElem`)) names
  let (perClauses, mentioned) = unzip analysed
      -- Graph in the shape 'stronglyConnComp' expects; a node is its own key.
      graph  = [ (x, x, Set.toList ns) | (x, ns) <- zip names mentioned ]
      sccs   = stronglyConnComp graph
      nonRec = [ x  | AcyclicSCC x  <- sccs ]
      recs   = [ xs | CyclicSCC  xs <- sccs ]
  reportSLn "rec.graph" 20 $ show graph
  -- Everything outside a cycle is non-recursive, clauses included.
  mapM_ markNonRecursive nonRec
  -- Lookup table: name |-> clause index |-> names mentioned in that clause.
  let clMap = Map.fromList $ zip names perClauses
  forM_ recs $ \ scc ->
    forM_ scc $ \ x -> do
      -- Every member of an SCC is one of @names@, hence present in @clMap@.
      let clauseNames = Map.findWithDefault __IMPOSSIBLE__ x clMap
          -- Clause @i@ is recursive iff it mentions some member of the SCC.
          mentionsScc i =
            any (`Set.member` IntMap.findWithDefault __IMPOSSIBLE__ i clauseNames) scc
      markRecursive mentionsScc x
  return recs
-- | Update the signature entry of @q@: if it is a 'Function', set
--   'funTerminates' to @Just True@ and flag every clause with
--   @clauseRecursive = Just False@.  Other kinds of definitions are
--   left unchanged.
markNonRecursive :: QName -> TCM ()
markNonRecursive q = modifySignature $ updateDefinition q $ updateTheDef setNonRec
  where
    setNonRec def = case def of
      Function{} -> def
        { funTerminates = Just True
        , funClauses    = map nonRecClause (funClauses def)
        }
      _ -> def

    nonRecClause cl = cl { clauseRecursive = Just False }
-- | Update the signature entry of the given name: if it is a 'Function',
--   set 'clauseRecursive' of its @i@-th clause (counting from 0) to
--   @Just (f i)@.  Other kinds of definitions are left unchanged.
markRecursive
  :: (Int -> Bool)  -- ^ Predicate on the clause index.
  -> QName -> TCM ()
markRecursive f q = modifySignature $ updateDefinition q $ updateTheDef tagClauses
  where
    tagClauses def = case def of
      Function{} -> def
        { funClauses =
            [ cl { clauseRecursive = Just (f i) }
            | (i, cl) <- zip [0..] (funClauses def)
            ]
        }
      _ -> def
-- | Collect the names satisfying @include@ that the definition of @name@
--   mentions.
--
--   The first component lists them per clause (empty unless the definition
--   is a 'Function'); the second is the union over the type and all clauses.
recDef :: (QName -> Bool) -> QName -> TCM (NamesPerClause, Set QName)
recDef include name = do
  def <- getConstInfo name
  -- Names occurring in the type signature.
  typeNames <- anyDefs include (defType def)
  -- Names occurring in each clause, in clause order; only functions have clauses.
  clauseNames <- case theDef def of
    Function{ funClauses = cls } -> mapM (anyDefs include) cls
    _                            -> return []
  let perClause = IntMap.fromList $ zip [0..] clauseNames
      bodyNames = mconcat clauseNames
  reportS "rec.graph" 20
    [ "recDef " ++ prettyShow name
    , "  names in the type: " ++ show typeNames
    , "  names in the def:  " ++ show bodyNames
    ]
  return (perClause, typeNames `mappend` bodyNames)
-- | Run @getDefs'@ on @a@, keeping only the names that satisfy @include@.
--
--   Metas are resolved against the current meta store: a meta instantiated
--   with @InstV@ yields its value; any other (or unknown) meta yields nothing.
anyDefs :: GetDefs a => (QName -> Bool) -> a -> TCM (Set QName)
anyDefs include a = do
  metas <- getMetaStore
  let -- Value of an instantiated meta, if there is one.
      instOf (MetaId x) = do
        mv <- IntMap.lookup x metas
        case mvInstantiation mv of
          InstV _ v                      -> Just v
          Open                           -> Nothing
          OpenInstance                   -> Nothing
          BlockedConst{}                 -> Nothing
          PostponedTypeCheckingProblem{} -> Nothing
      -- Singleton for names of interest, empty otherwise.
      keep d
        | include d = Set.singleton d
        | otherwise = Set.empty
  return $ getDefs' instOf keep a