{-# LANGUAGE NondecreasingIndentation #-}
module Agda.TypeChecking.With where
import Control.Monad
import Control.Monad.Writer (WriterT, runWriterT, tell)
import Data.Either
import qualified Data.List as List
import Data.List.NonEmpty (NonEmpty(..))
import qualified Data.List.NonEmpty as NonEmpty
import Data.Maybe
import Data.Foldable ( foldrM )
import Data.Traversable ( traverse )
import Agda.Syntax.Common
import Agda.Syntax.Internal as I
import Agda.Syntax.Internal.Pattern
import qualified Agda.Syntax.Abstract as A
import Agda.Syntax.Abstract.Pattern as A
import Agda.Syntax.Abstract.Views
import Agda.Syntax.Info
import Agda.Syntax.Position
import Agda.TypeChecking.Monad
import Agda.TypeChecking.Reduce
import Agda.TypeChecking.Datatypes
import Agda.TypeChecking.EtaContract
import Agda.TypeChecking.Free
import Agda.TypeChecking.Patterns.Abstract
import Agda.TypeChecking.Pretty
import Agda.TypeChecking.Records
import Agda.TypeChecking.Substitute
import Agda.TypeChecking.Telescope
import Agda.TypeChecking.Abstract
import Agda.TypeChecking.Rules.LHS.Implicit
import Agda.TypeChecking.Rules.LHS.Problem (ProblemEq(..))
import Agda.Utils.Functor
import Agda.Utils.List
import Agda.Utils.Maybe
import Agda.Utils.Monad
import Agda.Utils.Null (empty)
import Agda.Utils.Permutation
import Agda.Utils.Pretty (prettyShow)
import qualified Agda.Utils.Pretty as P
import Agda.Utils.Size
import Agda.Utils.Impossible
-- | Split the context @Δ@ of a parent clause for with-abstraction.
--
-- @Δ@ is split by 'splitTelescope' according to the free variables of the
-- with-expressions into @Δ₁@ (the part they need) and @Δ₂@ (the rest).
-- The type @T@ is moved along the resulting permutation, and the
-- with-expressions are moved into @Δ₁@ alone: @Δ₂@ is strengthened away,
-- which hits @__IMPOSSIBLE__@ if they were to mention it.
splitTelForWith
  :: Telescope                         -- ^ @Δ@: context of the parent clause.
  -> Type                              -- ^ @T@: a type living in @Δ@.
  -> [WithHiding (Term, EqualityView)] -- ^ With-expressions and their types, living in @Δ@.
  -> ( Telescope                         -- Δ₁: the part of Δ the with-expressions need
     , Telescope                         -- Δ₂: the remainder of Δ
     , Permutation                       -- permutation computed by splitTelescope
     , Type                              -- T moved along that permutation
     , [WithHiding (Term, EqualityView)] -- with-expressions moved into Δ₁
     )
splitTelForWith delta t vtys = (delta1, delta2, perm, tPermuted, vtysInDelta1)
  where
    -- Split Δ according to the variables the with-expressions mention.
    SplitTel delta1 delta2 perm = splitTelescope (allFreeVars vtys) delta
    -- Renaming derived from the reversed permutation; moves terms from Δ
    -- into the reordered context Δ₁Δ₂.
    unpermute = renaming __IMPOSSIBLE__ (reverseP perm)
    -- Drops the Δ₂ part of the context.
    dropDelta2 = strengthenS __IMPOSSIBLE__ $ size delta2
    -- The type, expressed in the reordered context.
    tPermuted = applySubst unpermute t
    -- The with-expressions, first reordered, then strengthened to Δ₁.
    vtysInDelta1 = applySubst (dropDelta2 `composeS` unpermute) vtys
-- | Compute the type of a with-function and the number of with-arguments.
--
-- Working in context @Δ₁@, the type @Δ₂ → B@ and the with-expressions are
-- normalised and eta-contracted.  Each with-expression is then abstracted
-- with 'piAbstract' (folding from the right, so the first with-expression
-- ends up outermost), giving @wΓ → Δ₂ → B@.  Finally @Δ₁@ is prefixed.
withFunctionType
  :: Telescope                          -- ^ @Δ₁@: context the with-expressions live in.
  -> [WithHiding (Term, EqualityView)]  -- ^ With-expressions and their types, in @Δ₁@.
  -> Telescope                          -- ^ @Δ₂@: telescope following @Δ₁@.
  -> Type                               -- ^ @B@: result type, in context @Δ₁Δ₂@.
  -> TCM (Type, Nat)
     -- ^ The type @Δ₁ → wΓ → Δ₂ → B@ and the number of with-arguments
     --   as computed by 'countWithArgs'.
withFunctionType delta1 vtys delta2 b = addContext delta1 $ do
  reportSLn "tc.with.abstract" 20 $ "preparing for with-abstraction"
  -- Bring Δ₂ → B into normal, eta-contracted form before abstracting.
  let raw = telePi_ delta2 b
  logType 30 "Δ₂ → B" raw
  normal <- normalise raw
  logType 30 "normal Δ₂ → B" normal
  contracted <- etaContract normal
  logType 30 "eta-contracted Δ₂ → B" contracted
  -- Put the with-expressions (and their types) into the same form.
  vtys' <- etaContract =<< normalise vtys
  -- Abstract over the with-expressions: wΓ → Δ₂ → B.
  abstracted <- foldrM piAbstract contracted vtys'
  logType 30 "wΓ → Δ₂ → B" abstracted
  let withArgCount = countWithArgs [ snd (whThing vty) | vty <- vtys' ]
  return (telePi_ delta1 abstracted, withArgCount)
  where
    -- Debug-print a labelled value under the "tc.with.abstract" key.
    logType n label x =
      reportSDoc "tc.with.abstract" n $ nest 2 $ text (label ++ " =") <+> prettyTCM x
-- | Number of with-function arguments generated by the given with-expressions.
--
-- An ordinary with-expression contributes one argument.  An equality-typed
-- one ('EqualityType') contributes two, matching the two terms that
-- 'withArguments' emits for it.
countWithArgs :: [EqualityView] -> Nat
countWithArgs = foldr (\ ev total -> argsFor ev + total) 0
  where
    argsFor EqualityType{} = 2
    argsFor OtherType{}    = 1
-- | The terms actually passed to the with-function, each keeping the hiding
-- of its with-expression.
--
-- An ordinary with-expression @v@ yields @[v]@.  An equality-typed one
-- yields the fifth field of its 'EqualityType' view (presumably the
-- left-hand side of the equation), followed by the expression itself
-- (presumably the proof).
withArguments :: [WithHiding (Term, EqualityView)] -> [WithHiding Term]
withArguments = concatMap (traverse argsOf)
  where
    argsOf (v, OtherType{})                  = [v]
    argsOf (prf, EqualityType _ _ _ _ lhs _) = [unArg lhs, prf]
-- | Turn the clauses of a with-block into clauses of the with-function @aux@.
--
-- For each clause, the trailing with-patterns are split off.  The first @n@
-- of them (unwrapped from 'A.WithP') become arguments of @aux@.  The parent
-- patterns are stripped with 'stripWithClausePatterns', and the new
-- left-hand side is
-- @aux (first n1 stripped) (with-patterns) (rest stripped) (remaining with-patterns)@.
-- Nested with- and rewrite-right-hand sides are processed recursively.
buildWithFunction
  :: [Name]                     -- ^ Names passed to 'stripWithClausePatterns' for the parameter patterns.
  -> QName                      -- ^ Name of the parent function.
  -> QName                      -- ^ Name of the with-function; head of every rebuilt clause.
  -> Type                       -- ^ Type of the parent function.
  -> Telescope                  -- ^ Context of the parent clause.
  -> [NamedArg DeBruijnPattern] -- ^ Patterns of the parent clause.
  -> Nat                        -- ^ Number of module parameters.
  -> Substitution               -- ^ Applied to the stripped pattern equations of nested clauses.
  -> Permutation                -- ^ Permutation passed to 'stripWithClausePatterns'.
  -> Nat                        -- ^ Number of stripped patterns placed before the with-patterns.
  -> Nat                        -- ^ Number of with-patterns consumed by this with-function.
  -> [A.SpineClause]            -- ^ Clauses of the with-block.
  -> TCM [A.SpineClause]        -- ^ Clauses of the with-function.
buildWithFunction cxtNames f aux t delta qs npars withSub perm n1 n cs =
  traverse buildWithClause cs
  where
    -- Rebuild a single with-block clause as a clause of the with-function.
    buildWithClause (A.Clause (A.SpineLHS info _ allPats) inheritedPats rhs wh catchall) = do
      let (parentPats, withPats)       = splitOffTrailingWithPatterns allPats
          (ourWithPats, laterWithPats) = splitAt n withPats
          withArgPats                  = map (updateNamedArg unWithP) ourWithPats
      reportSDoc "tc.with" 50 $ "inheritedPats:" <+> vcat
        [ prettyA p <+> "=" <+> prettyTCM v <+> ":" <+> prettyTCM a
        | A.ProblemEq p v a <- inheritedPats ]
      (strippedPats, stripped) <-
        stripWithClausePatterns cxtNames f aux t delta qs npars perm parentPats
      reportSDoc "tc.with" 50 $ hang "strippedPats:" 2 $ vcat
        [ prettyA p <+> "==" <+> prettyTCM v <+> (":" <+> prettyTCM ty)
        | A.ProblemEq p v ty <- strippedPats ]
      rhs' <- buildRHS strippedPats rhs
      let (before, after) = splitAt n1 stripped
          newLhs = A.SpineLHS info aux $ before ++ withArgPats ++ after ++ laterWithPats
          result = A.Clause newLhs (inheritedPats ++ strippedPats) rhs' wh catchall
      reportSDoc "tc.with" 20 $ vcat [ "buildWithClause returns" <+> prettyA result ]
      return result

    -- Peel the 'A.WithP' wrapper off a with-pattern.
    unWithP (A.WithP _ p) = p
    unWithP _             = __IMPOSSIBLE__

    -- Rebuild the right-hand side.  Stripped equations are pushed into
    -- rewrite right-hand sides (after applying @withSub@).
    buildRHS strippedPats1 rhs = case rhs of
      A.RHS{}            -> return rhs
      A.AbsurdRHS        -> return rhs
      A.WithRHS q es cs' -> A.WithRHS q es <$> traverse rebuildNested cs'
      A.RewriteRHS qes strippedPats2 inner wh -> do
        inner' <- buildRHS [] inner
        return $ A.RewriteRHS qes (applySubst withSub $ strippedPats1 ++ strippedPats2) inner' wh

    -- Process a nested with-clause and convert it back to a plain clause.
    rebuildNested c = A.spineToLhs . permuteNamedDots <$> buildWithClause (A.lhsToSpine c)

    -- Apply @withSub@ to the stripped pattern equations of a nested clause.
    permuteNamedDots :: A.SpineClause -> A.SpineClause
    permuteNamedDots (A.Clause lhs eqs body whereDecls isCatchall) =
      A.Clause lhs (applySubst withSub eqs) body whereDecls isCatchall
-- | Strip the parent clause's patterns @qs@ off the patterns @ps@ of a
-- with-clause, producing the patterns for the with-function.
--
-- The with-clause patterns (after pattern-synonym expansion, with variable
-- patterns prepended for the first @npars@ parent patterns, and with
-- implicit patterns inserted) are walked alongside the parent patterns:
--
-- * a parent variable keeps the corresponding with-clause pattern;
-- * a parent dot pattern becomes a wildcard;
-- * parent constructor, literal and projection patterns are consumed,
--   recursing into constructor arguments.
--
-- Equations between with-clause patterns and parent-pattern terms (from
-- as-patterns, dot patterns, and variable/dot patterns standing for parent
-- constructors) are collected as 'A.ProblemEq's.  The resulting patterns
-- are permuted by @perm@.  A type error is raised when a with-clause
-- pattern is not an instance of its parent pattern, or when there are too
-- few or too many with-clause patterns.
stripWithClausePatterns
  :: [Name]                   -- ^ Names for the variable patterns replacing the first @npars@ parent patterns.
  -> QName                    -- ^ Name of the parent function (initial "self" of the traversal).
  -> QName                    -- ^ Name of the with-function (not used in this body).
  -> Type                     -- ^ Type of the parent function.
  -> Telescope                -- ^ Context of the parent clause; source of the variables' arg infos.
  -> [NamedArg DeBruijnPattern] -- ^ Patterns of the parent clause.
  -> Nat                      -- ^ Number of module parameters.
  -> Permutation              -- ^ Permutation applied to the stripped patterns.
  -> [NamedArg A.Pattern]     -- ^ Patterns of the with-clause.
  -> TCM ([A.ProblemEq], [NamedArg A.Pattern]) -- ^ Collected equations and the stripped, permuted patterns.
stripWithClausePatterns cxtNames parent f t delta qs npars perm ps = do
  -- Pattern synonyms are expanded up front; the traversal below treats
  -- any remaining synonym as impossible.
  ps <- expandPatternSynonyms ps
  -- Prepend a variable pattern (named from cxtNames) for each of the first
  -- npars parent patterns, so that stripping starts from the full type @t@
  -- of the parent function.
  let paramPat i _ = A.VarP $ A.mkBindName $ indexWithDefault __IMPOSSIBLE__ cxtNames i
      ps' = zipWith (fmap . fmap . paramPat) [0..] (take npars qs) ++ ps
  psi <- insertImplicitPatternsT ExpandLast ps' t
  reportSDoc "tc.with.strip" 10 $ vcat
    [ "stripping patterns"
    , nest 2 $ "t   = " <+> prettyTCM t
    , nest 2 $ "ps  = " <+> fsep (punctuate comma $ map prettyA ps)
    , nest 2 $ "ps' = " <+> fsep (punctuate comma $ map prettyA ps')
    , nest 2 $ "psi = " <+> fsep (punctuate comma $ map prettyA psi)
    , nest 2 $ "qs  = " <+> fsep (punctuate comma $ map (prettyTCM . namedArg) qs)
    , nest 2 $ "perm= " <+> text (show perm)
    ]
  -- The traversal's "self" starts as the parent function (not the
  -- with-function); equations are collected in the Writer.
  (ps', strippedPats) <- runWriterT $ strip (Def parent []) t psi qs
  reportSDoc "tc.with.strip" 50 $ nest 2 $
    "strippedPats:" <+> vcat [ prettyA p <+> "=" <+> prettyTCM v <+> ":" <+> prettyTCM a | A.ProblemEq p v a <- strippedPats ]
  let psp = permute perm ps'
  reportSDoc "tc.with.strip" 10 $ vcat
    [ nest 2 $ "ps' = " <+> fsep (punctuate comma $ map prettyA ps')
    , nest 2 $ "psp = " <+> fsep (punctuate comma $ map prettyA $ psp)
    ]
  return (strippedPats, psp)
  where
    -- The arg info of a parent pattern variable is looked up in @delta@
    -- (reversed, so that it is indexed by de Bruijn index) rather than
    -- taken from the parent pattern itself.
    varArgInfo = \ x -> let n = dbPatVarIndex x in
                        if n < length infos then infos !! n else __IMPOSSIBLE__
      where infos = reverse $ map getArgInfo $ telToList delta
    -- Give @p@ the arg info of variable @x@ while keeping @p@'s own origin.
    setVarArgInfo x p = setOrigin (getOrigin p) $ setArgInfo (varArgInfo x) p
    strip
      :: Term                         -- ^ "Self": the parent function applied to what has been consumed so far.
      -> Type                         -- ^ Type of self, eliminated by the remaining patterns.
      -> [NamedArg A.Pattern]       -- ^ Remaining with-clause patterns.
      -> [NamedArg DeBruijnPattern] -- ^ Remaining parent-clause patterns.
      -> WriterT [ProblemEq] TCM [NamedArg A.Pattern]
            -- ^ Stripped with-clause patterns; equations are emitted via
            --   the Writer.
    -- Out of with-clause patterns but parent patterns remain: the type may
    -- still call for implicit patterns, so insert them and retry; if none
    -- can be inserted, the with-clause has too few patterns.
    strip self t [] qs@(_ : _) = do
      reportSDoc "tc.with.strip" 15 $ vcat
        [ "strip (out of A.Patterns)"
        , nest 2 $ "qs  =" <+> fsep (punctuate comma $ map (prettyTCM . namedArg) qs)
        , nest 2 $ "self=" <+> prettyTCM self
        , nest 2 $ "t   =" <+> prettyTCM t
        ]
      -- Only implicit patterns can be inserted here (no explicit patterns
      -- are supplied).
      ps <- liftTCM $ insertImplicitPatternsT ExpandLast [] t
      if null ps then
        typeError $ GenericError $ "Too few arguments given in with-clause"
       else strip self t ps qs
    -- Out of parent patterns: any leftover with-clause patterns must be
    -- wildcards or system-inserted constructor patterns.
    strip _ _ ps      []      = do
      let implicit (A.WildP{})     = True
          implicit (A.ConP ci _ _) = conPatOrigin ci == ConOSystem
          implicit _               = False
      unless (all (implicit . namedArg) ps) $
        typeError $ GenericError $ "Too many arguments given in with-clause"
      return []
    -- As-pattern in the with-clause: record that the bound name equals the
    -- parent pattern's term, then continue with the inner pattern.
    strip self t (p0 : ps) qs@(q : _)
      | A.AsP _ x p <- namedArg p0 = do
        (a, _) <- mustBePi t
        let v = patternToTerm (namedArg q)
        tell [ProblemEq (A.VarP x) v a]
        strip self t (fmap (p <$) p0 : ps) qs
    -- General case: both a with-clause pattern and a parent pattern.
    strip self t ps0@(p0 : ps) qs0@(q : qs) = do
      p <- liftTCM $ (traverse . traverse) expandLitPattern p0
      reportSDoc "tc.with.strip" 15 $ vcat
        [ "strip"
        , nest 2 $ "ps0 =" <+> fsep (punctuate comma $ map prettyA ps0)
        , nest 2 $ "exp =" <+> prettyA p
        , nest 2 $ "qs0 =" <+> fsep (punctuate comma $ map (prettyTCM . namedArg) qs0)
        , nest 2 $ "self=" <+> prettyTCM self
        , nest 2 $ "t   =" <+> prettyTCM t
        ]
      case namedArg q of
        -- Parent projection: the with-clause must use a projection with the
        -- same origin (prefix/postfix) and the same original projection.
        ProjP o d -> case A.isProjP p of
          Just (o', AmbQ ds) -> do
            -- Changing the projection style is rejected.
            if o /= o' then liftTCM $ mismatchOrigin o o' else do
            -- Compare modulo original projection names.  (The do-block
            -- continues at this indentation via NondecreasingIndentation.)
            d  <- liftTCM $ getOriginalProjection d
            found <- anyM ds $ \ d' -> liftTCM $ (Just d ==) . fmap projOrig <$> isProjection d'
            if not found then mismatch else do
              (self1, t1, ps) <- liftTCM $ do
                t <- reduce t
                (_, self1, t1) <- fromMaybe __IMPOSSIBLE__ <$> projectTyped self t o d
                -- The field type may start with hidden arguments, so
                -- implicit patterns may need to be inserted again.
                ps <- insertImplicitPatternsT ExpandLast ps t1
                return (self1, t1, ps)
              strip self1 t1 ps qs
          Nothing -> mismatch
        -- Parent variable against a with-clause dot pattern @.y@ where @y@
        -- is a variable: keep it as the variable pattern @y@.
        VarP _ x | A.DotP _ u <- namedArg p
                 , A.Var y <- unScope u -> do
          (setVarArgInfo x (setNamedArg p $ A.VarP $ A.mkBindName y) :) <$>
            recurse (var (dbPatVarIndex x))
        -- Parent variable: keep the with-clause pattern, with the variable's
        -- arg info.
        VarP _ x  ->
          (setVarArgInfo x p :) <$> recurse (var (dbPatVarIndex x))
        IApplyP{}  -> typeError $ GenericError $ "with clauses not supported in the presence of Path patterns" -- path patterns are rejected
        DefP{}  -> typeError $ GenericError $ "with clauses not supported in the presence of hcomp patterns" -- hcomp patterns are rejected
        -- Parent dot pattern: record the with-clause pattern against the
        -- dotted term and replace the pattern by a wildcard.
        DotP o v  -> do
          (a, _) <- mustBePi t
          tell [ProblemEq (namedArg p) v a]
          (makeImplicitP p :) <$> recurse v
        -- Parent constructor pattern.
        q'@(ConP c ci qs') -> do
         reportSDoc "tc.with.strip" 60 $
           "parent pattern is constructor " <+> prettyTCM c
         (a, b) <- mustBePi t
         -- Head @d@ and arguments @us@ of the (normalised) domain type.
         Def d es <- liftTCM $ normalise (unEl $ unDom a)
         let us = fromMaybe __IMPOSSIBLE__ $ allApplyElims es
         -- Resolve to the constructor's canonical form, keeping c's range.
         c <- either __IMPOSSIBLE__ (`withRangeOf` c) <$> do liftTCM $ getConForm $ conName c
         case namedArg p of
          -- Dot pattern standing for the constructor: record the equation.
          -- If the dotted expression is itself an application of the same
          -- constructor, its arguments become dot patterns for the
          -- sub-patterns; otherwise every sub-pattern is a wildcard.
          A.DotP r e -> do
            tell [ProblemEq (A.DotP r e) (patternToTerm q') a]
            ps' <-
              case appView e of
                Application (A.Con (A.AmbQ cs')) es -> do
                  cs' <- liftTCM $ snd . partitionEithers <$> mapM getConForm (NonEmpty.toList cs')
                  unless (elem c cs') mismatch
                  return $ (map . fmap . fmap) (A.DotP r) es
                _  -> return $ map (unnamed (A.WildP empty) <$) qs'
            stripConP d us b c ConOCon qs' ps'
          -- A wildcard stands for the whole parent pattern: expand it to
          -- one wildcard per constructor argument.
          A.WildP{} -> do
            let ps' = map (unnamed (A.WildP empty) <$) qs'
            stripConP d us b c ConOCon qs' ps'
          -- A variable pattern for the constructor: record the equation
          -- and treat the sub-patterns as wildcards.
          A.VarP x -> do
            tell [ProblemEq (A.VarP x) (patternToTerm q') a]
            let ps' = map (unnamed (A.WildP empty) <$) qs'
            stripConP d us b c ConOCon qs' ps'
          -- Constructor pattern: one of its (possibly ambiguous)
          -- constructors must be the parent's constructor.
          A.ConP _ (A.AmbQ cs') ps' -> do
            cs' <- liftTCM $ snd . partitionEithers <$> mapM getConForm (NonEmpty.toList cs')
            unless (elem c cs') mismatch
            stripConP d us b c ConOCon qs' ps'
          -- Record pattern: only valid for a record type; missing fields
          -- are filled in with wildcards.
          A.RecP _ fs -> caseMaybeM (liftTCM $ isRecord d) mismatch $ \ def -> do
            ps' <- liftTCM $ insertMissingFields d (const $ A.WildP empty) fs
                                                 (map argFromDom $ recordFieldNames def)
            stripConP d us b c ConORec qs' ps'
          -- Pattern synonyms were expanded at the start.
          p@(A.PatternSynP pi' c' ps') -> do
             reportSDoc "impossible" 10 $
               "stripWithClausePatterns: encountered pattern synonym " <+> prettyA p
             __IMPOSSIBLE__
          p -> do
           reportSDoc "tc.with.strip" 60 $
             text $ "with clause pattern is  " ++ show p
           mismatch
        -- Parent literal: the with-clause must repeat it or use a wildcard.
        LitP _ lit -> case namedArg p of
          A.LitP lit' | lit == lit' -> recurse $ Lit lit
          A.WildP{}                 -> recurse $ Lit lit
          -- NOTE(review): this only matches a synonym applied to exactly
          -- one argument; other synonyms fall through to 'mismatch'.
          p@(A.PatternSynP pi' c' [ps']) -> do
             reportSDoc "impossible" 10 $
               "stripWithClausePatterns: encountered pattern synonym " <+> prettyA p
             __IMPOSSIBLE__
          _ -> mismatch
      where
        -- Continue past the current pattern, whose term is @v@: apply self
        -- and its type to @v@.  Rejected under Path types.
        recurse v = do
          caseMaybeM (liftTCM $ isPath t) (return ()) $ \ _ ->
            typeError $ GenericError $
              "With-clauses currently not supported under Path abstraction."
          t' <- piApplyM t v
          strip (self `apply1` v) t' ps qs
        -- The with-clause pattern is not an instance of the parent pattern.
        mismatch = addContext delta $ typeError $
          WithClausePatternMismatch (namedArg p0) q
        -- Projection-style mismatch between parent and with-clause.
        mismatchOrigin o o' = addContext delta . typeError . GenericDocError =<< fsep
          [ "With clause pattern"
          , prettyA p0
          , "is not an instance of its parent pattern"
          , P.fsep <$> prettyTCMPatterns [q]
          , text $ "since the parent pattern is " ++ prettyProjOrigin o ++
                   " and the with clause pattern is " ++ prettyProjOrigin o'
          ]
        prettyProjOrigin ProjPrefix  = "a prefix projection"
        prettyProjOrigin ProjPostfix = "a postfix projection"
        prettyProjOrigin ProjSystem  = __IMPOSSIBLE__
        -- Replace a pattern by a wildcard, keeping its arg info and name.
        makeImplicitP :: NamedArg A.Pattern -> NamedArg A.Pattern
        makeImplicitP = updateNamedArg $ const $ A.WildP patNoRange
        -- Strip a parent constructor pattern: instantiate the constructor's
        -- type with the data parameters, match the with-clause sub-patterns
        -- @ps'@ against the parent sub-patterns @qs'@, then continue with
        -- the remaining patterns.
        stripConP
          :: QName
             -- ^ Head of the (normalised) type of this argument.
          -> [Arg Term]
             -- ^ Arguments of that type; the first @np@ instantiate the
             --   constructor's parameters.
          -> Abs Type
             -- ^ Codomain of the current type, to be instantiated with the
             --   constructor application.
          -> ConHead
             -- ^ The parent's constructor.
          -> ConInfo
             -- ^ Constructor info used to build the constructor term.
          -> [NamedArg DeBruijnPattern]
             -- ^ Parent-clause sub-patterns.
          -> [NamedArg A.Pattern]
             -- ^ With-clause sub-patterns.
          -> WriterT [ProblemEq] TCM [NamedArg A.Pattern]
             -- ^ Stripped patterns.
        stripConP d us b c ci qs' ps' = do
          -- Constructor type and number of parameters.
          Defn {defType = ct, theDef = Constructor{conPars = np}}  <- getConInfo c
          -- Instantiate the parameters to get the argument telescope.
          let ct' = ct `piApply` take np us
          TelV tel' _ <- liftTCM $ telViewPath ct'
          reportSDoc "tc.with.strip" 20 $
            vcat [ "ct  = " <+> prettyTCM ct
                 , "ct' = " <+> prettyTCM ct'
                 , "np  = " <+> text (show np)
                 , "us  = " <+> prettyList (map prettyTCM us)
                 , "us' = " <+> prettyList (map prettyTCM $ take np us)
                 ]
          -- The constructor applied to fresh variables for its arguments;
          -- the new type and self abstract over the argument telescope.
          let v  = Con c ci [ Apply $ Arg info (var i) | (i, Arg info _) <- zip (downFrom $ size qs') qs' ]
              t' = tel' `abstract` absApp (raise (size tel') b) v
              self' = tel' `abstract` apply1 (raise (size tel') self) v  -- self applied to the constructor
          reportSDoc "tc.with.strip" 15 $ sep
            [ "inserting implicit"
            , nest 2 $ prettyList $ map prettyA (ps' ++ ps)
            , nest 2 $ ":" <+> prettyTCM t'
            ]
          -- Insert implicit patterns for the constructor arguments only,
          -- and check the argument count.
          psi' <- liftTCM $ insertImplicitPatterns ExpandLast ps' tel'
          unless (size psi' == size tel') $ typeError $
            WrongNumberOfConstructorArguments (conName c) (size tel') (size psi')
          -- Insert implicit patterns again, this time over the constructor
          -- arguments together with the remaining patterns.
          psi <- liftTCM $ insertImplicitPatternsT ExpandLast (psi' ++ ps) t'
          -- Continue with the sub-patterns followed by the remaining ones.
          strip self' t' psi (qs' ++ qs)
-- | Build a display form whose right-hand side is the with-application
-- @f es | w₁ … wₙ@ of the parent function @f@.  Presumably the caller
-- installs it for the with-function @aux@, which is only used for debug
-- output here.
--
-- The display form's arity is @n + |Δ₁| + |Δ₂|@ plus the number of
-- variables currently in context.  @es@ consists of the current context
-- arguments followed by the parent patterns @qs@ (as display terms),
-- renamed from the parent's variables to the with-function's.
withDisplayForm
  :: QName
       -- ^ Name of the parent function.
  -> QName
       -- ^ Name of the with-function (only used in debug output here).
  -> Telescope
       -- ^ @Δ₁@: context the with-expressions live in.
  -> Telescope
       -- ^ @Δ₂@: remaining part of the parent context.
  -> Nat
       -- ^ @n@: number of with-arguments.
  -> [NamedArg DeBruijnPattern]
      -- ^ Patterns of the parent clause.
  -> Permutation
      -- ^ Permutation of the parent's pattern variables; @m@ is its size.
  -> Permutation
      -- ^ Not used in this body.
  -> TCM DisplayForm
withDisplayForm f aux delta1 delta2 n qs perm@(Perm m _) lhsPerm = do
  -- Arity contributed by the with-arguments and the split parent context.
  let arity0 = n + size delta1 + size delta2
  -- The variables currently in context become extra leading arguments.
  topArgs <- raise arity0 <$> getContextArgs
  let top    = length topArgs
      arity  = arity0 + top
  -- A fresh unnamed definition, used for parent variables that have no
  -- counterpart in the with-function's context.
  wild <- freshNoName_ <&> \ x -> Def (qualify_ x) []
  let -- Parent patterns as display eliminations.
      tqs0       = patsToElims qs
      -- Entry j of ys names the parent variable that the with-function's
      -- variable with de Bruijn index j stands for; Nothing marks the n
      -- with-argument positions.  Context variables (m + k) come last.
      (ys0, ys1) = splitAt (size delta1) $ permute perm $ downFrom m
      ys         = reverse (map Just ys0 ++ replicate n Nothing ++ map Just ys1)
                   ++ map (Just . (m +)) [0..top-1]
      rho        = sub top ys wild
      tqs        = applySubst rho tqs0
      -- Arguments of f: context arguments followed by renamed parent patterns.
      es         = map (Apply . fmap DTerm) topArgs ++ tqs
      -- The with-arguments sit just outside Δ₂: indices |Δ₂|+n-1 down to |Δ₂|.
      withArgs   = map var $ take n $ downFrom $ size delta2 + n
      dt         = DWithApp (DDef f es) (map DTerm withArgs) []
  -- Left-hand side: arity copies of @var 0@ (presumably the pattern-hole
  -- convention of display forms).
  let display = Display arity (replicate arity $ Apply $ defaultArg $ var 0) dt
  -- Context Δ₁, w1…wn, Δ₂; used only for debug printing.
  let addFullCtx = addContext delta1
                 . flip (foldr addContext) (for [1..n] $ \ i -> "w" ++ show i)
                 . addContext delta2
  reportSDoc "tc.with.display" 20 $ vcat
    [ "withDisplayForm"
    , nest 2 $ vcat
      [ "f      =" <+> text (prettyShow f)
      , "aux    =" <+> text (prettyShow aux)
      , "delta1 =" <+> prettyTCM delta1
      , "delta2 =" <+> do addContext delta1 $ prettyTCM delta2
      , "n      =" <+> text (show n)
      , "perm   =" <+> text (show perm)
      , "top    =" <+> do addFullCtx $ prettyTCM topArgs
      , "qs     =" <+> prettyList (map pretty qs)
      , "qsToTm =" <+> prettyTCM tqs0 -- parent patterns before renaming
      , "ys     =" <+> text (show ys)
      , "rho    =" <+> text (prettyShow rho)
      , "qs[rho]=" <+> do addFullCtx $ prettyTCM tqs
      , "dt     =" <+> do addFullCtx $ prettyTCM dt
      ]
    ]
  reportSDoc "tc.with.display" 70 $ nest 2 $ vcat
      [ "raw    =" <+> text (show display)
      ]
  return display
  where
    -- Parallel substitution on indices 0 .. m+top-1: index i goes to the
    -- variable at the position of @Just i@ in @ys@, or to @wild@ if i does
    -- not occur.  (Linear search per index; the lists are small.)
    sub top ys wild = parallelS $ map term [0 .. m + top - 1]
      where
        term i = maybe wild var $ List.findIndex (Just i ==) ys
-- | Convert parent-clause patterns to display eliminations.
--
-- A projection pattern becomes a projection elimination.  Any other pattern
-- becomes an application to its display term:
--
-- * variables (and dot patterns of a bare variable originating from a
--   variable) are shown as terms;
-- * variables from dot patterns, and other dot patterns, are shown dotted.
patsToElims :: [NamedArg DeBruijnPattern] -> [I.Elim' DisplayTerm]
patsToElims nps = [ patToElim (namedThing <$> np) | np <- nps ]
  where
    -- One pattern argument to one display elimination.
    patToElim :: Arg DeBruijnPattern -> I.Elim' DisplayTerm
    patToElim (Arg _ (ProjP o d)) = I.Proj o d
    patToElim (Arg ai pat)        = I.Apply $ Arg ai $ patToDTerm pat

    -- Sub-patterns of constructor/definition patterns to display arguments.
    argsToDTerms :: [NamedArg DeBruijnPattern] -> [Arg DisplayTerm]
    argsToDTerms = map (fmap (patToDTerm . namedThing))

    -- A single pattern as a display term.
    patToDTerm :: DeBruijnPattern -> DisplayTerm
    patToDTerm (IApplyP _ _ _ x) = DTerm $ var $ dbPatVarIndex x
    patToDTerm (ProjP _ d)       = DDef d []
    patToDTerm (VarP info x)
      | PatODot <- patOrigin info = DDot  $ var $ dbPatVarIndex x
      | otherwise                 = DTerm $ var $ dbPatVarIndex x
    patToDTerm (DotP info tm)
      | PatOVar{} <- patOrigin info, Var _ [] <- tm = DTerm tm
      | otherwise                                   = DDot tm
    patToDTerm (ConP c cpi subPats) = DCon c (fromConPatternInfo cpi) $ argsToDTerms subPats
    patToDTerm (LitP _ l)           = DTerm $ Lit l
    patToDTerm (DefP _ q subPats)   = DDef q $ map Apply $ argsToDTerms subPats