module Agda.TypeChecking.Polarity where
import Control.Monad.State
import Data.Maybe
import Agda.Syntax.Abstract.Name
import Agda.Syntax.Common
import Agda.Syntax.Internal
import Agda.TypeChecking.Monad
import Agda.TypeChecking.Datatypes (getNumberOfParameters)
import Agda.TypeChecking.Pretty
import Agda.TypeChecking.SizedTypes
import Agda.TypeChecking.Substitute
import Agda.TypeChecking.Telescope
import Agda.TypeChecking.Reduce
import Agda.TypeChecking.Free
import Agda.TypeChecking.Positivity.Occurrence
import Agda.Utils.List
import Agda.Utils.Maybe ( whenNothingM )
import Agda.Utils.Monad
import Agda.Utils.Pretty ( prettyShow )
import Agda.Utils.Size
import Agda.Utils.Impossible
-- | Meet of two polarities.
--
--   'Nonvariant' is the neutral element.  Two equal polarities meet in
--   themselves; two different polarities, neither of which is
--   'Nonvariant', meet in 'Invariant'.
(/\) :: Polarity -> Polarity -> Polarity
p /\ q = case (p, q) of
  (Nonvariant, _) -> q
  (_, Nonvariant) -> p
  _ | p == q      -> p
    | otherwise   -> Invariant
-- | Negate a polarity: 'Covariant' and 'Contravariant' are swapped,
--   'Invariant' and 'Nonvariant' are left as they are.
neg :: Polarity -> Polarity
neg pol = case pol of
  Nonvariant    -> Nonvariant
  Invariant     -> Invariant
  Contravariant -> Covariant
  Covariant     -> Contravariant
-- | Compose two polarities, e.g. the polarity of an argument position
--   with the polarity of an occurrence inside that argument.
--
--   'Nonvariant' on either side wins, then an 'Invariant' first
--   component; 'Covariant' keeps the second polarity and
--   'Contravariant' negates it.
composePol :: Polarity -> Polarity -> Polarity
composePol outer inner = case (outer, inner) of
  (Nonvariant   , _         ) -> Nonvariant
  (_            , Nonvariant) -> Nonvariant
  (Invariant    , _         ) -> Invariant
  (Covariant    , _         ) -> inner
  (Contravariant, _         ) -> neg inner
-- | Translate the result of positivity analysis into a polarity.
--   All positive occurrences become 'Covariant', a negative one
--   'Contravariant', a mixed one 'Invariant' and an unused argument
--   'Nonvariant'.
polFromOcc :: Occurrence -> Polarity
polFromOcc GuardPos  = Covariant
polFromOcc StrictPos = Covariant
polFromOcc JustPos   = Covariant
polFromOcc JustNeg   = Contravariant
polFromOcc Mixed     = Invariant
polFromOcc Unused    = Nonvariant
-- | Split off the first polarity of a list.  An exhausted list yields
--   'Invariant' (and stays exhausted).
nextPolarity :: [Polarity] -> (Polarity, [Polarity])
nextPolarity pols = case pols of
  p : rest -> (p, rest)
  []       -> (Invariant, [])
-- | Replace every 'Nonvariant' entry by 'Covariant'; all other entries
--   are kept.
purgeNonvariant :: [Polarity] -> [Polarity]
purgeNonvariant = map purge
  where
    purge p
      | p == Nonvariant = Covariant
      | otherwise       = p
-- | Record, for the given name, the polarities obtained directly from
--   its argument occurrences (see 'polFromOcc'), without any further
--   refinement.
polarityFromPositivity :: QName -> TCM ()
polarityFromPositivity x = inConcreteOrAbstractMode x $ \ def -> do
  let dropped  = droppedPars def
      fromOccs = map polFromOcc $ defArgOccurrences def
      -- The dropped parameters are padded with 'Nonvariant' for the
      -- debug output only; the padding is removed again before storing.
      pol0     = replicate dropped Nonvariant ++ fromOccs
  reportSLn "tc.polarity.set" 15 $ concat
    [ "Polarity of ", prettyShow x, " from positivity: ", prettyShow pol0 ]
  setPolarity x $ drop dropped pol0
-- | Compute and record the polarities of the given names.
--
--   If at least two names are given, each of them first gets the crude
--   polarity from 'polarityFromPositivity'.  Afterwards, for every name
--   the polarity from positivity is refined by 'sizePolarity' and
--   'dependentPolarity' and the result is stored with 'setPolarity'.
computePolarity :: [QName] -> TCM ()
computePolarity xs = do
  when (length xs >= 2) $ mapM_ polarityFromPositivity xs
  mapM_ refine xs
  where
    -- Refine and store the polarity of a single name.
    refine :: QName -> TCM ()
    refine x = inConcreteOrAbstractMode x $ \ def -> do
      reportSLn "tc.polarity.set" 25 $ "Refining polarity of " ++ prettyShow x

      -- Starting point: polarity from positivity, padded with
      -- 'Nonvariant' for the dropped parameters.
      let dropped = droppedPars def
          pol0    = replicate dropped Nonvariant
                      ++ map polFromOcc (defArgOccurrences def)
      reportSLn "tc.polarity.set" 15 $
        "Polarity of " ++ prettyShow x ++ " from positivity: " ++ prettyShow pol0

      -- Refinement for data types with a size index.
      pol1 <- sizePolarity x pol0

      -- Refinement by dependencies in the type of the definition.
      let ty = defType def
      reportSDoc "tc.polarity.set" 15 $
        "Refining polarity with type " <+> prettyTCM ty
      reportSDoc "tc.polarity.set" 60 $
        "Refining polarity with type (raw): " <+> text (show ty)
      pol <- dependentPolarity ty (enablePhantomTypes (theDef def) pol1) pol1
      reportSLn "tc.polarity.set" 10 $
        "Polarity of " ++ prettyShow x ++ ": " ++ prettyShow pol

      -- Store the result, without the padding for dropped parameters.
      setPolarity x $ drop dropped pol
-- | For data and record types, turn 'Nonvariant' into 'Covariant' in
--   the polarities of the parameters (the first 'dataPars' resp.
--   'recPars' entries).  For any other definition the polarities are
--   returned unchanged.
enablePhantomTypes :: Defn -> [Polarity] -> [Polarity]
enablePhantomTypes def pol = maybe pol purgeParams numPars
  where
    numPars = case def of
      Datatype{ dataPars = n } -> Just n
      Record  { recPars  = n } -> Just n
      _                        -> Nothing
    purgeParams n = purgeNonvariant (take n pol) ++ drop n pol
-- | @dependentPolarity t qs ps@ walks along the function type @t@ and
--   refines the polarities @ps@, one per argument.
--
--   An argument whose bound variable is relevant in the rest of the
--   type (see 'relevantInIgnoringNonvariant', which is given the
--   already refined polarities of the later arguments) becomes
--   'Invariant'.  Otherwise the argument keeps its polarity from @ps@
--   if its type is a size type, and takes the one from @qs@ if not.
--   Once the type is no longer a function type, the remaining
--   polarities are returned as they are.
--
--   @qs@ must be at least as long as @ps@.
dependentPolarity :: Type -> [Polarity] -> [Polarity] -> TCM [Polarity]
dependentPolarity _  _        []            = return []
dependentPolarity _  []       (_ : _)       = __IMPOSSIBLE__
dependentPolarity ty (q : qs) pols@(p : ps) = do
  tm <- reduce $ unEl ty
  reportSDoc "tc.polarity.dep" 20 $ "dependentPolarity t = " <+> prettyTCM tm
  reportSDoc "tc.polarity.dep" 70 $ "dependentPolarity t = " <+> text (show tm)
  case tm of
    Pi dom b -> do
      -- Handle the later arguments first; their refined polarities are
      -- needed to decide about the current one.
      rest <- underAbstraction dom b $ \ c -> dependentPolarity c qs ps
      let fallback = do
            isSize <- isJust <$> isSizeType (unDom dom)
            return $ if isSize then p else q
      here <- case b of
        Abs{} | p /= Invariant -> do
          dependent <- relevantInIgnoringNonvariant 0 (absBody b) rest
          if dependent then return Invariant else fallback
        _ -> fallback
      return $ here : rest
    _ -> return pols
-- | Check whether variable @i@ is relevant in the given type, where
--   the domains of arguments with polarity 'Nonvariant' are skipped.
--
--   The polarities correspond to the leading arguments of the (reduced)
--   type.  When they are used up, or the type is not a function type,
--   the remainder is checked as a whole.
relevantInIgnoringNonvariant :: Nat -> Type -> [Polarity] -> TCM Bool
relevantInIgnoringNonvariant i t pols = case pols of
  []     -> return $ i `relevantInIgnoringSortAnn` t
  p : ps -> do
    tm <- reduce $ unEl t
    case tm of
      Pi a b
        | p /= Nonvariant && i `relevantInIgnoringSortAnn` a -> return True
          -- Going under the binder shifts the variable by one.
        | otherwise -> relevantInIgnoringNonvariant (i + 1) (absBody b) ps
      _ -> return $ i `relevantInIgnoringSortAnn` tm
-- | Polarity refinement for data types with a size index.
--
--   @sizePolarity d pol0@ returns @pol0@ unchanged if the sized types
--   option is off, if @d@ is not a data type, if @d@ has no index, or if
--   the type of its first index does not pass the 'isSizeType' test below.
--
--   Otherwise the result consists of the first @np@ entries of @pol0@
--   (@np@ being the number of parameters) followed by one entry for the
--   size index: 'Covariant' if all constructors pass the check below,
--   'Invariant' if not.
--
--   Side effects: the covariant variant is stored with 'setPolarity'
--   before the constructors are checked (and is not reset here if the
--   check fails), and on success the argument occurrences of @d@ are
--   updated such that the size index is 'JustPos'.
sizePolarity :: QName -> [Polarity] -> TCM [Polarity]
sizePolarity d pol0 = do
  -- The fallback result: the input polarities, unchanged.
  let exit = return pol0
  ifNotM sizedTypesOption exit $  do
    def <- getConstInfo d
    case theDef def of
      Datatype{ dataPars = np, dataCons = cons } -> do
        -- Split the telescope of the data type into parameters and indices.
        let TelV tel _      = telView' $ defType def
            (parTel, ixTel) = splitAt np $ telToList tel
        case ixTel of
          []                 -> exit  -- no index at all, hence no size index
          -- NOTE(review): 'BoundedNo' presumably stands for the plain,
          -- unbounded size type -- confirm in Agda.TypeChecking.SizedTypes.
          Dom{unDom = (_,a)} : _ -> ifM ((/= Just BoundedNo) <$> isSizeType a) exit $ do
            -- Tentatively take the size index to be 'Covariant' ...
            let pol   = take np pol0
                polCo = pol ++ [Covariant]
                polIn = pol ++ [Invariant]
            setPolarity d $ polCo
            -- ... and try to confirm this by inspecting the constructor types.
            let check c = do
                  t <- defType <$> getConstInfo c
                  -- Work in the context of the parameters and instantiate
                  -- the constructor type with the parameter variables.
                  addContext (telFromList parTel) $ do
                    let pars = map (defaultArg . var) $ downFrom np
                    TelV conTel target <- telView =<< (t `piApplyM` pars)
                    case conTel of
                      EmptyTel  -> return False  -- no arguments, hence no size argument
                      ExtendTel arg  tel ->
                        ifM ((/= Just BoundedNo) <$> isSizeType (unDom arg)) (return False) $ do -- first argument fails the size test
                          -- From here on the first constructor argument is a size.
                          -- isPos: in the types of all remaining arguments, the size
                          -- variable has polarity 'Nonvariant' or 'Covariant'.
                          -- The k-th remaining argument is queried with index k
                          -- (presumably the index of the size variable in the
                          -- context of that argument's type).
                          let isPos = underAbstraction arg tel $ \ tel -> do
                                pols <- zipWithM polarity [0..] $ map (snd . unDom) $ telToList tel
                                reportSDoc "tc.polarity.size" 25 $
                                  text $ "to pass size polarity check, the following polarities need all to be covariant: " ++ prettyShow pols
                                return $ all (`elem` [Nonvariant, Covariant]) pols
                          -- isLin: in the context of all constructor arguments, the
                          -- target type passes 'checkSizeIndex' for the size argument;
                          -- the number of arguments following it serves as its index.
                          let sizeArg = size tel
                              isLin = addContext conTel $ checkSizeIndex d sizeArg target
                          ok <- isPos `and2M` isLin
                          reportSDoc "tc.polarity.size" 15 $
                            "constructor" <+> prettyTCM c <+>
                            text (if ok then "passes" else "fails") <+>
                            "size polarity check"
                          return ok
            ifNotM (andM $ map check cons)
                (return polIn) -- some constructor failed the check
              $ do  -- all constructors passed the check
                -- Also record the size index as a positive occurrence, so
                -- that the stored argument occurrences agree with the
                -- covariant polarity returned here.
                modifyArgOccurrences d $ \ occ -> take np occ ++ [JustPos]
                return polCo
      _ -> exit
-- | @checkSizeIndex d i a@ checks that the constructor target type @a@
--   is an application of data type @d@ whose first index, as seen by
--   'deepSizeView', is based on the variable @i@, and that @i@ occurs
--   neither in the parameters nor in the remaining indices.
--
--   It is an internal error ('__IMPOSSIBLE__') if @a@ is not an
--   application of (a definition accepted by 'sameDef' as) @d@, if the
--   number of parameters of that definition is unknown, or if there is
--   no applied argument in the position of the first index.
checkSizeIndex :: QName -> Nat -> Type -> TCM Bool
checkSizeIndex d i a = do
  reportSDoc "tc.polarity.size" 15 $ withShowAllArguments $ vcat
    [ "checking that constructor target type " <+> prettyTCM a
    , "  is data type " <+> prettyTCM d
    , "  and has size index (successor(s) of) " <+> prettyTCM (var i)
    ]
  case unEl a of
    Def d0 es -> do
      whenNothingM (sameDef d d0) __IMPOSSIBLE__
      np <- fromMaybe __IMPOSSIBLE__ <$> getNumberOfParameters d0
      -- Split the eliminations into parameters, first index and other
      -- indices.  A total match: a malformed target is reported as an
      -- internal error rather than as an irrefutable-pattern failure.
      case splitAt np es of
        (pars, Apply ix : ixs) -> do
          s <- deepSizeView $ unArg ix
          case s of
            DSizeVar j _ | i == j
              -> return $ not $ freeIn i (pars ++ ixs)
            _ -> return False
        _ -> __IMPOSSIBLE__
    _ -> __IMPOSSIBLE__
-- | Syntactic objects in which the polarity of a variable can be
--   computed.
class HasPolarity a where
  -- | @polarities i x@ collects polarity information about the
  --   variable with index @i@ in @x@, as a list of contributions.
  --   'polarity' combines the contributions into a single polarity.
  polarities :: Nat -> a -> TCM [Polarity]
-- | @polarity i x@ is the polarity of variable @i@ in @x@: the meet
--   ('/\') of everything reported by 'polarities'.  If nothing is
--   reported the result is 'Nonvariant', the neutral element of the
--   meet.
polarity :: HasPolarity a => Nat -> a -> TCM Polarity
polarity i x = foldr (/\) Nonvariant <$> polarities i x
-- | The polarities of an argument are those of its content.
instance HasPolarity a => HasPolarity (Arg a) where
  polarities i arg = polarities i (unArg arg)
-- | The polarities of a domain are those of its content.
instance HasPolarity a => HasPolarity (Dom a) where
  polarities i dom = polarities i (unDom dom)
-- | Under a proper binder the variable index is incremented; a body
--   that does not bind is searched with the same index.
instance HasPolarity a => HasPolarity (Abs a) where
  polarities i abstr = case abstr of
    Abs   _ body -> polarities (i + 1) body
    NoAbs _ body -> polarities i body
-- | The polarities of a list are the polarities of its elements,
--   concatenated in order.
instance HasPolarity a => HasPolarity [a] where
  polarities i = fmap concat . mapM (polarities i)
-- | The polarities of a pair are those of the first component followed
--   by those of the second.
instance (HasPolarity a, HasPolarity b) => HasPolarity (a, b) where
  polarities i (x, y) = do
    fromFst <- polarities i x
    fromSnd <- polarities i y
    return $ fromFst ++ fromSnd
-- | Only the term of a type is inspected; its sort is ignored.
instance HasPolarity Type where
  polarities i = polarities i . unEl
-- | Applications contribute the polarities of their argument, path
--   applications those of all three components, and projections
--   contribute nothing.
instance HasPolarity a => HasPolarity (Elim' a) where
  polarities i e = case e of
    Apply a      -> polarities i a
    IApply x y a -> polarities i (x, (y, a))
    Proj{}       -> return []
-- | Polarities of variable @i@ in a term (which is instantiated first).
--
--   * A variable contributes 'Covariant' if it is @i@ itself; everything
--     found in its eliminations counts as 'Invariant'.
--   * For a definition, the contributions of each elimination are
--     composed ('composePol') with the polarity stored for that
--     position; positions beyond the stored polarities are treated as
--     'Invariant'.
--   * Constructor arguments, lambda bodies, levels and 'DontCare' are
--     searched without modification.
--   * In a function type, the contributions of the domain are negated.
--   * Everything found in the eliminations of a meta-variable counts as
--     'Invariant'.
--   * Literals, sorts and dummy terms contribute nothing.
instance HasPolarity Term where
  polarities i v0 = do
    v <- instantiate v0
    let invariantIn es = map (const Invariant) <$> polarities i es
    case v of
      Var n es
        | n == i    -> (Covariant :) <$> invariantIn es
        | otherwise -> invariantIn es
      Lam _ body  -> polarities i body
      Lit _       -> return []
      Level l     -> polarities i l
      Def f es    -> do
        pols     <- getPolarity f
        perElims <- mapM (polarities i) es
        return $ concat $
          zipWith (map . composePol) (pols ++ repeat Invariant) perElims
      Con _ _ es  -> polarities i es
      Pi a b      -> do
        fromDom <- map neg <$> polarities i a
        fromCod <- polarities i b
        return $ fromDom ++ fromCod
      Sort _      -> return []
      MetaV _ es  -> invariantIn es
      DontCare t  -> polarities i t
      Dummy{}     -> return []
-- | The polarities of a level are those of its list of summands; the
--   first field of 'Max' is ignored.
instance HasPolarity Level where
  polarities i lvl = case lvl of
    Max _ summands -> polarities i summands
-- | Only the level atom is inspected; the first field of 'Plus' is
--   ignored.
instance HasPolarity PlusLevel where
  polarities i plus = case plus of
    Plus _ atom -> polarities i atom
-- | Everything found in the eliminations of a meta-variable level
--   counts as 'Invariant'; for the other level atoms the underlying
--   term is searched.
instance HasPolarity LevelAtom where
  polarities i (MetaLevel _ es)   = map (const Invariant) <$> polarities i es
  polarities i (BlockedLevel _ v) = polarities i v
  polarities i (NeutralLevel _ v) = polarities i v
  polarities i (UnreducedLevel v) = polarities i v