module Agda.TypeChecking.Injectivity where
import Prelude hiding (mapM)
import Control.Applicative
import Control.Monad.Fail
import Control.Monad.State hiding (mapM, forM)
import Control.Monad.Reader hiding (mapM, forM)
import Control.Monad.Trans.Maybe
import qualified Data.Map as Map
import qualified Data.Set as Set
import Data.Maybe
import Data.Traversable hiding (for)
import Data.Semigroup ((<>))
import qualified Agda.Syntax.Abstract.Name as A
import Agda.Syntax.Common
import Agda.Syntax.Internal
import Agda.Syntax.Internal.Pattern
import Agda.TypeChecking.Irrelevance (isIrrelevantOrPropM)
import Agda.TypeChecking.Monad
import Agda.TypeChecking.Monad.Builtin
import Agda.TypeChecking.Substitute
import Agda.TypeChecking.Reduce
import {-# SOURCE #-} Agda.TypeChecking.MetaVars
import {-# SOURCE #-} Agda.TypeChecking.Conversion
import Agda.TypeChecking.Pretty
import Agda.TypeChecking.Polarity
import Agda.TypeChecking.Warnings
import Agda.Utils.Except ( MonadError )
import Agda.Utils.Functor
import Agda.Utils.List
import Agda.Utils.Maybe
import Agda.Utils.Monad
import Agda.Utils.Permutation
import Agda.Utils.Pretty ( prettyShow )
import Agda.Utils.Impossible
-- | Head-reduce a term, bring it into constructor form and report its head
--   symbol, if it has one that is accepted here.
--
--   * Data types, records and data\/record signatures yield their name.
--   * Axioms yield their name, unless the environment records a current
--     mutual block and the axiom is one of that block's names.
--   * Constructors yield their canonical name.
--   * Sorts, pi types and variables without eliminations yield the
--     corresponding head.
--   * Functions, primitives, literals, lambdas, applied variables, levels,
--     metas and 'DontCare' terms yield 'Nothing'.
headSymbol :: Term -> TCM (Maybe TermHead)
headSymbol v = do
  whnf <- reduceHead v >>= constructorForm . ignoreBlocking
  case whnf of
    Def f _ -> do
      -- The definition is looked up with abstract mode ignored.
      defn <- theDef <$> ignoreAbstractMode (getConstInfo f)
      let isHead  = return (Just (ConsHead f))
          notHead = return Nothing
      case defn of
        Datatype{}         -> isHead
        Record{}           -> isHead
        DataOrRecSig{}     -> isHead
        Axiom{}            -> do
          reportSLn "tc.inj.axiom" 50 $ "headSymbol: " ++ prettyShow f ++ " is an Axiom."
          -- Outside of any mutual block the axiom counts as a head; inside
          -- one, only if it is not among the names of that block.
          asksTC envMutualBlock >>= \case
            Nothing -> isHead
            Just mb -> do
              block <- lookupMutualBlock mb
              if f `Set.member` mutualNames block then notHead else isHead
        Function{}         -> notHead
        Primitive{}        -> notHead
        GeneralizableVar{} -> __IMPOSSIBLE__
        Constructor{}      -> __IMPOSSIBLE__
        AbstractDefn{}     -> __IMPOSSIBLE__
    Con c _ _ -> ignoreAbstractMode $ do
      canon <- canonicalName (conName c)
      return (Just (ConsHead canon))
    Sort _     -> return (Just SortHead)
    Pi _ _     -> return (Just PiHead)
    -- Only a variable without eliminations counts; see the 'Var{}' case below.
    Var i []   -> return (Just (VarHead i))
    Lit _      -> return Nothing
    Lam{}      -> return Nothing
    Var{}      -> return Nothing
    Level{}    -> return Nothing
    MetaV{}    -> return Nothing
    DontCare{} -> return Nothing
    Dummy s _  -> __IMPOSSIBLE_VERBOSE__ s
-- | Variant of 'headSymbol' that reduces with 'reduceB' and answers in an
--   arbitrary suitable monad.
--
--   A blocked term has no head ('Nothing').  For a term that is not blocked,
--   definitions and constructors yield their name as a 'ConsHead' (without
--   inspecting the kind of definition), and variables yield a 'VarHead'
--   regardless of their eliminations.
headSymbol'
  :: (MonadReduce m, MonadError TCErr m, MonadDebug m, HasBuiltins m)
  => Term -> m (Maybe TermHead)
headSymbol' v = do
  reduced <- reduceB v >>= traverse constructorForm
  case reduced of
    Blocked{}       -> none
    NotBlocked _ tm -> case tm of
      Def g _    -> found (ConsHead g)
      Con c _ _  -> found (ConsHead (conName c))
      Var i _    -> found (VarHead i)
      Sort _     -> found SortHead
      Pi _ _     -> found PiHead
      Lit _      -> none
      Lam{}      -> none
      Level{}    -> none
      MetaV{}    -> none
      DontCare{} -> none
      Dummy s _  -> __IMPOSSIBLE_VERBOSE__ s
  where
    none    = return Nothing
    found h = return (Just h)
-- | Look for de Bruijn variable @i@ among the top-level patterns of a clause.
--
--   If exactly one top-level pattern is a variable pattern with index @i@,
--   the result is a 'VarHead' carrying the position of that pattern (counted
--   from 0, from the left).  If there is none the result is 'Nothing'; more
--   than one match is treated as impossible.
topLevelArg :: Clause -> Int -> Maybe TermHead
topLevelArg cl i =
  case positions of
    []        -> Nothing
    [pos]     -> Just (VarHead pos)
    _ : _ : _ -> __IMPOSSIBLE__
  where
    topPats   = map namedArg (namedClausePats cl)
    positions =
      [ pos
      | (pos, VarP _ (DBPatVar _ j)) <- zip [0 ..] topPats
      , j == i
      ]
-- | Merge several inversion maps into one.  When the same head occurs in more
--   than one map, the associated clause lists are concatenated.
joinHeadMaps :: [InversionMap c] -> InversionMap c
joinHeadMaps maps = Map.unionsWith (++) maps
-- | Recompute every head of an inversion map with a monadic function that
--   sees the old head together with its clauses.  The clauses are kept; if
--   several old heads are mapped to the same new head their clause lists are
--   merged via 'joinHeadMaps'.
updateHeads :: Monad m => (TermHead -> [c] -> m TermHead) -> InversionMap c -> m (InversionMap c)
updateHeads f m = do
  pieces <- mapM rehead (Map.toList m)
  return (joinHeadMaps pieces)
  where
    rehead (oldHead, clauses) = do
      newHead <- f oldHead clauses
      return (Map.singleton newHead clauses)
-- | Entry point of the injectivity check for the function @f@ with the given
--   clauses.
--
--   Clauses without a body are dropped first.  If none of the remaining
--   clauses has a properly matching pattern (as decided by
--   @properlyMatching' False False@), the check is skipped and the function
--   is reported as 'NotInjective'; otherwise the work is delegated to
--   'checkInjectivity''.
checkInjectivity :: QName -> [Clause] -> TCM FunctionInverse
checkInjectivity f cs0
  | any hasProperMatch cs = checkInjectivity' f cs
  | otherwise = do
      reportSLn "tc.inj.check.pointless" 35 $
        "Injectivity of " ++ prettyShow (A.qnameToConcrete f) ++ " would be pointless."
      return NotInjective
  where
    -- Only clauses that have a body take part in the check.
    cs = [ cl | cl <- cs0, isJust (clauseBody cl) ]

    -- Does at least one top-level pattern of the clause match properly?
    hasProperMatch cl =
      any (properlyMatching' False False . namedArg) (namedClausePats cl)
-- | Build the inversion map of @f@ from its clauses, or give up with
--   'NotInjective'.
--
--   Every clause that has both a body and a type contributes one entry, keyed
--   by the head of its body:
--
--   * 'UnknownHead' if 'isIrrelevantOrPropM' holds for the clause type, or if
--     'headSymbol' finds no head for the body;
--   * otherwise the head found by 'headSymbol' (computed in the context of
--     the clause telescope), where a 'VarHead' is translated to the position
--     of the matching top-level argument via 'topLevelArg'.
--
--   The check is abandoned (yielding 'NotInjective') if a variable head does
--   not correspond to a top-level argument, or if two or more clauses end up
--   under 'UnknownHead'.
checkInjectivity' :: QName -> [Clause] -> TCM FunctionInverse
checkInjectivity' f cs = fmap (fromMaybe NotInjective) . runMaybeT $ do
  reportSLn "tc.inj.check" 40 $ "Checking injectivity of " ++ prettyShow f

  let varToArg :: Clause -> TermHead -> MaybeT TCM TermHead
      varToArg c (VarHead i) = MaybeT $ return $ topLevelArg c i
      varToArg _ h           = return h

  -- Singleton head map for one clause; clauses lacking a body or a type
  -- contribute nothing.
  let headMapOf cl = case (clauseBody cl, clauseType cl) of
        (Just body, Just tbody) -> do
          irrelevant <- isIrrelevantOrPropM tbody
          hd <- if irrelevant
            then return UnknownHead
            else do
              found <- lift $ addContext (clauseTel cl) $ headSymbol body
              varToArg cl (fromMaybe UnknownHead found)
          return [Map.singleton hd [cl]]
        _ -> return []

  perClause <- mapM headMapOf cs
  let hdMap = joinHeadMaps (concat perClause)

  -- Two or more clauses with an unknown head: give up.
  case Map.findWithDefault [] UnknownHead hdMap of
    _ : _ : _ -> empty
    _         -> return ()

  reportSLn  "tc.inj.check" 20 $ prettyShow f ++ " is potentially injective."
  reportSDoc "tc.inj.check" 30 $ nest 2 $ vcat
    [ text (prettyShow h) <+> "-->" <+>
        case uc of
          [c] -> prettyTCM $ map namedArg $ namedClausePats c
          _   -> "(multiple clauses)"
    | (h, uc) <- Map.toList hdMap
    ]

  return $ Inverse hdMap
-- | Adjust an inversion map for the eliminations @es@ the function is
--   actually applied to.
--
--   A head is kept unchanged when none of its clauses has fewer top-level
--   patterns than there are eliminations.  Otherwise it is kept only if it is
--   \"super rigid\" (see @isSuperRigid@ below) and is replaced by
--   'UnknownHead' if not.
checkOverapplication
  :: forall m. (HasConstInfo m)
  => Elims -> InversionMap Clause -> m (InversionMap Clause)
checkOverapplication es = updateHeads demote
  where
    demote :: TermHead -> [Clause] -> m TermHead
    demote hd cls
      | any overapplied cls = do
          rigid <- isSuperRigid hd
          return $ if rigid then hd else UnknownHead
      | otherwise = return hd

    -- More eliminations than the clause has top-level patterns.
    overapplied cl = length (namedClausePats cl) < length es

    -- Which heads survive over-application.  For a defined name this depends
    -- on the kind of definition; a constructor qualifies only when its
    -- 'conFields' list is empty.
    -- NOTE(review): presumably "super rigid" means the head cannot be
    -- eliminated further (e.g. record constructors by projections) - confirm.
    isSuperRigid (ConsHead q) = do
      info <- getConstInfo q
      return $ case theDef info of
        Datatype{}         -> True
        Record{}           -> True
        DataOrRecSig{}     -> True
        Axiom{}            -> True
        AbstractDefn{}     -> True
        Constructor{conSrcCon = ConHead{ conFields = fs }}
                           -> null fs
        Function{}         -> False
        Primitive{}        -> False
        GeneralizableVar{} -> __IMPOSSIBLE__
    isSuperRigid SortHead    = return True
    isSuperRigid PiHead      = return True
    isSuperRigid VarHead{}   = return True
    isSuperRigid UnknownHead = return True
-- | Replace every 'VarHead' of an inversion map by the head of the
--   corresponding elimination in @es@.
--
--   @VarHead i@ is looked up as the @i@-th elimination.  If that is an
--   applied argument, its head is computed with 'headSymbol''.  If the
--   elimination is missing or is not an application, or if the argument has
--   no head, the overall result is 'Nothing'.  All other heads are left as
--   they are.  The 'QName' argument is not used.
instantiateVarHeads
  :: forall m c. (MonadReduce m, MonadError TCErr m, MonadDebug m, HasBuiltins m)
  => QName -> Elims -> InversionMap c -> m (Maybe (InversionMap c))
instantiateVarHeads _f es m = runMaybeT (updateHeads (\ hd _ -> instHead hd) m)
  where
    instHead :: TermHead -> MaybeT m TermHead
    instHead (VarHead i) = case es !!! i of
      Just (Apply arg) -> MaybeT (headSymbol' (unArg arg))
      _                -> empty
    instHead other = return other
-- | Compute the inversion view of a term.
--
--   Only an application @Def f es@ whose definition carries an 'Inverse' map
--   can be inverted.  The stored map is first specialised to the actual
--   eliminations with 'instantiateVarHeads' and then passed through
--   'checkOverapplication'.  If the specialisation fails, or in any other
--   case, the result is 'NoInv'.
functionInverse
  :: (MonadReduce m, MonadError TCErr m, HasBuiltins m, HasConstInfo m)
  => Term -> m InvView
functionInverse (Def f es) =
  defInverse <$> getConstInfo f >>= \case
    NotInjective -> return NoInv
    Inverse m    -> do
      instantiated <- instantiateVarHeads f es m
      checked      <- traverse (checkOverapplication es) instantiated
      return $ maybe NoInv (Inv f es) checked
functionInverse _ = return NoInv
-- | Result of 'functionInverse'.
data InvView
  = Inv QName [Elim] (InversionMap Clause)
    -- ^ The function name, the eliminations it is applied to, and its
    --   inversion map.
  | NoInv
    -- ^ The term cannot be inverted.
-- | Try to make progress on the comparison of @blk@ and @neu@ (at @ty@, in
--   direction @dir@) by using the inversion information of the function at
--   the head of @blk@.
--
--   * If @blk@ is not invertible, or the inversion depth limit is exceeded
--     (in which case an 'InversionDepthReached' warning is issued), the
--     comparison is postponed as a 'ValueCmp' constraint.
--   * If @neu@ is an application of the same function @f@ and @f@ cannot
--     reduce to itself, the two spines are compared directly.
--   * Otherwise the head of @neu@ is computed with 'headSymbol'' and
--     'invertFunction' is run for that head; without a head, or when the head
--     is @f@ itself while @f@ can reduce to itself, the comparison is
--     postponed.
--
--   The whole body runs under @locallyTC eInjectivityDepth succ@.
useInjectivity :: MonadConversion m => CompareDirection -> CompareAs -> Term -> Term -> m ()
useInjectivity dir ty blk neu = locallyTC eInjectivityDepth succ $ do
  inv <- functionInverse blk
  -- The depth used for the cut-off below is the larger of the number of
  -- active problems and the current injectivity depth.
  -- NOTE(review): presumably a guard against runaway inversion - confirm.
  nProblems <- Set.size <$> viewTC eActiveProblems
  injDepth  <- viewTC eInjectivityDepth
  let depth = max nProblems injDepth
  maxDepth  <- maxInversionDepth
  case inv of
    NoInv            -> fallback  -- not invertible
    Inv f blkArgs hdMap
      | depth > maxDepth -> warning (InversionDepthReached f) >> fallback
      | otherwise -> do
      reportSDoc "tc.inj.use" 30 $ fsep $
        pwords "useInjectivity on" ++
        [ prettyTCM blk, prettyTCM cmp, prettyTCM neu, prettyTCM ty]
      -- @f@ "can reduce to self" if some clause has @f@ itself as its head,
      -- or a head that is unknown.
      let canReduceToSelf = Map.member (ConsHead f) hdMap || Map.member UnknownHead hdMap
      case neu of
        -- Both sides are applications of the same function @f@, and @f@
        -- cannot reduce to itself: compare the eliminations of the two
        -- sides against the type of @f@.
        Def f' neuArgs | f == f', not canReduceToSelf -> do
          fTy <- defType <$> getConstInfo f
          reportSDoc "tc.inj.use" 20 $ vcat
            [ fsep (pwords "comparing application of injective function" ++ [prettyTCM f] ++
                  pwords "at")
            , nest 2 $ fsep $ punctuate comma $ map prettyTCM blkArgs
            , nest 2 $ fsep $ punctuate comma $ map prettyTCM neuArgs
            , nest 2 $ "and type" <+> prettyTCM fTy
            ]
          fs  <- getForcedArgs f
          pol <- getPolarity' cmp f
          reportSDoc "tc.inj.invert.success" 20 $ hsep ["Successful spine comparison of", prettyTCM f]
          app (compareElims pol fs fTy (Def f [])) blkArgs neuArgs
        -- Any other right-hand side: determine its head and invert @f@ for
        -- that head.  'invertFunction' raises an 'UnequalTerms' error (via
        -- @err@) when no clause can produce the head.
        _ -> headSymbol' neu >>= \ case
          Nothing -> do
            reportSDoc "tc.inj.use" 20 $ fsep $
              pwords "no head symbol found for" ++ [prettyTCM neu] ++ pwords ", so not inverting"
            fallback
          Just (ConsHead f') | f == f', canReduceToSelf -> do
            reportSDoc "tc.inj.use" 20 $ fsep $
              pwords "head symbol" ++ [prettyTCM f'] ++ pwords "can reduce to self, so not inverting"
            fallback
          Just hd -> invertFunction cmp blk inv hd fallback err success
            where err = typeError $ app (\ u v -> UnequalTerms cmp u v ty) blk neu
  where
    -- Postpone the comparison as a constraint.
    fallback     = addConstraint $ app (ValueCmp cmp ty) blk neu
    -- Continue the comparison with the unfolded left-hand side.
    success blk' = app (compareAs cmp ty) blk' neu
    -- A @>=@ comparison is carried out as @<=@ with the arguments swapped;
    -- @app@ applies a binary operation in the appropriate argument order.
    (cmp, app) = case dir of
      DirEq -> (CmpEq, id)
      DirLeq -> (CmpLeq, id)
      DirGeq -> (CmpLeq, flip)
-- | Try to invert the injective function application @blk@ for the head
--   @hd@.
--
--   Arguments, in order: the comparison, the blocked term, its inversion
--   view, the head to invert for, and three continuations:
--
--   * @fallback@ is run when the view is 'NoInv' or when more than one clause
--     is a candidate; it is also handed to 'speculateMetas'.
--   * @err@ is run when no clause is a candidate.
--   * @success@ is called with the unfolded term when, after unifying the
--     arguments with the patterns of the unique candidate clause, the
--     application takes a reduction step.
invertFunction
  :: MonadConversion m
  => Comparison -> Term -> InvView -> TermHead -> m () -> m () -> (Term -> m ()) -> m ()
invertFunction _ _ NoInv _ fallback _ _ = fallback
invertFunction cmp blk (Inv f blkArgs hdMap) hd fallback err success = do
    fTy <- defType <$> getConstInfo f
    reportSDoc "tc.inj.use" 20 $ vcat
      [ "inverting injective function" <?> hsep [prettyTCM f, ":", prettyTCM fTy]
      , "for" <?> pretty hd
      , nest 2 $ "args =" <+> prettyList (map prettyTCM blkArgs)
      ]
    -- Candidate clauses are those registered under @hd@ together with those
    -- registered under 'UnknownHead'.
    case fromMaybe [] $ Map.lookup hd hdMap <> Map.lookup UnknownHead hdMap of
      [] -> err
      _:_:_ -> fallback
      [cl@Clause{ clauseTel  = tel }] -> speculateMetas fallback $ do
          let ps   = clausePats cl
              perm = fromMaybe __IMPOSSIBLE__ $ clausePerm cl
          -- Fresh metas created from the clause telescope.
          ms <- map unArg <$> newTelMeta tel
          reportSDoc "tc.inj.invert" 20 $ vcat
            [ "meta patterns" <+> prettyList (map prettyTCM ms)
            , "  perm =" <+> text (show perm)
            , "  tel  =" <+> prettyTCM tel
            , "  ps   =" <+> prettyList (map (text . show) ps)
            ]
          -- @msAux@ is the supply consumed by 'nextMeta' while traversing the
          -- patterns; @sub@ is the substitution applied to dot patterns.
          -- NOTE(review): presumably @msAux@ lists the metas in the order
          -- the variables occur in the patterns - confirm.
          let msAux = permute (invertP __IMPOSSIBLE__ $ compactP perm) ms
          let sub   = parallelS (reverse ms)
          margs <- runReaderT (evalStateT (mapM metaElim ps) msAux) sub
          reportSDoc "tc.inj.invert" 20 $ vcat
            [ "inversion"
            , nest 2 $ vcat
              [ "lhs  =" <+> prettyTCM margs
              , "rhs  =" <+> prettyTCM blkArgs
              , "type =" <+> prettyTCM fTy
              ]
            ]
          -- NOTE(review): 'purgeNonvariant' presumably turns non-variant
          -- polarities into invariant ones so the fresh metas get solved -
          -- confirm.
          pol <- purgeNonvariant <$> getPolarity' cmp f
          fs  <- getForcedArgs f
          -- Compare only as many arguments as the clause produced
          -- eliminations; there may be more arguments than patterns.
          let blkArgs' = take (length margs) blkArgs
          compareElims pol fs fTy (Def f []) margs blkArgs'
          -- Check for progress: does @f blkArgs@ take a reduction step now?
          r <- liftReduce $ unfoldDefinitionStep False (Def f []) f blkArgs
          case r of
            YesReduction _ blk' -> do
              reportSDoc "tc.inj.invert.success" 20 $ hsep ["Successful inversion of", prettyTCM f, "at", pretty hd]
              KeepMetas <$ success blk'
            NoReduction{}       -> do
              reportSDoc "tc.inj.invert" 30 $ vcat
                [ "aborting inversion;" <+> prettyTCM blk
                , "does not reduce"
                ]
              return RollBackMetas
  where
    -- Take the next term from the state; fails (via 'MonadFail') if the
    -- supply is empty.
    nextMeta :: (MonadState [Term] m, MonadFail m) => m Term
    nextMeta = do
      m : ms <- get
      put ms
      return m
    -- Instantiate a dot pattern term with the substitution from the reader.
    dotP :: MonadReader Substitution m => Term -> m Term
    dotP v = do
      sub <- ask
      return $ applySubst sub v
    -- Turn a top-level pattern into an elimination: projection patterns
    -- become 'Proj' (of the original projection), all others become 'Apply'
    -- of the term built by 'metaPat'.
    metaElim
      :: (MonadState [Term] m, MonadReader Substitution m, HasConstInfo m, MonadFail m)
      => Arg DeBruijnPattern -> m Elim
    metaElim (Arg _ (ProjP o p))  = Proj o <$> getOriginalProjection p
    metaElim (Arg info p)         = Apply . Arg info <$> metaPat p
    -- Convert the argument patterns of a constructor or definition pattern.
    metaArgs
      :: (MonadState [Term] m, MonadReader Substitution m, MonadFail m)
      => [NamedArg DeBruijnPattern] -> m Args
    metaArgs args = mapM (traverse $ metaPat . namedThing) args
    -- Turn a pattern into a term: variable and 'IApplyP' patterns consume
    -- the next meta, dot patterns are substituted, constructor, definition
    -- and literal patterns are rebuilt structurally.  Projection patterns
    -- cannot occur here.
    metaPat
      :: (MonadState [Term] m, MonadReader Substitution m, MonadFail m)
      => DeBruijnPattern -> m Term
    metaPat (DotP _ v)       = dotP v
    metaPat (VarP _ _)       = nextMeta
    metaPat (IApplyP{})      = nextMeta
    metaPat (ConP c mt args) = Con c (fromConPatternInfo mt) . map Apply <$> metaArgs args
    metaPat (DefP o q args)  = Def q . map Apply <$> metaArgs args
    metaPat (LitP _ l)       = return $ Lit l
    metaPat ProjP{}          = __IMPOSSIBLE__
-- | Reduce a type; if the result is blocked, run 'invertFunction' on it for
--   the head 'PiHead'.
--
--   A type that is not blocked is returned in reduced form.  For a blocked
--   type the inversion is run for its effects only and the blocked type
--   itself is returned; if no clause can produce a pi type, a 'ShouldBePi'
--   error (mentioning the original argument) is raised.
forcePiUsingInjectivity :: Type -> TCM Type
forcePiUsingInjectivity t = do
  reduced <- reduceB t
  case reduced of
    NotBlocked _ whnf   -> return whnf
    Blocked _ blockedTy -> do
      let blockedTm = unEl blockedTy
      inv <- functionInverse blockedTm
      invertFunction CmpEq blockedTm inv PiHead doNothing notPi (const doNothing)
      return blockedTy
  where
    -- Used both as the fallback and (ignoring its argument) as the success
    -- continuation.
    doNothing = return ()
    notPi     = typeError (ShouldBePi t)