{-# LANGUAGE CPP #-}
module Agda.TypeChecking.Polarity where
import Control.Monad.State
import Data.Maybe
import Data.Traversable (traverse)
import Agda.Syntax.Abstract.Name
import Agda.Syntax.Common
import Agda.Syntax.Internal
import Agda.Syntax.Internal.Pattern
import Agda.TypeChecking.Monad
import Agda.TypeChecking.Pretty
import Agda.TypeChecking.SizedTypes
import Agda.TypeChecking.Substitute
import Agda.TypeChecking.Telescope
import Agda.TypeChecking.Reduce
import Agda.TypeChecking.Free hiding (Occurrence(..))
import Agda.TypeChecking.Positivity.Occurrence
import Agda.Interaction.Options
import Agda.Utils.List
import Agda.Utils.Maybe ( whenNothingM )
import Agda.Utils.Monad
import Agda.Utils.Permutation
import Agda.Utils.Pretty ( prettyShow )
import Agda.Utils.Size
#include "undefined.h"
import Agda.Utils.Impossible
-- | Combine two polarities of the same variable.
--   'Nonvariant' is neutral and identical polarities are kept.
--   Any other combination yields 'Invariant', so 'Invariant' absorbs every
--   polarity except 'Nonvariant'.
(/\) :: Polarity -> Polarity -> Polarity
a /\ b
  | a == Nonvariant = b
  | b == Nonvariant = a
  | a == b          = a
  | otherwise       = Invariant
-- | Negate a polarity.
--   'Covariant' and 'Contravariant' are swapped; 'Invariant' and
--   'Nonvariant' are fixed points.
neg :: Polarity -> Polarity
neg p = case p of
  Covariant     -> Contravariant
  Contravariant -> Covariant
  Invariant     -> Invariant
  Nonvariant    -> Nonvariant
-- | @composePol outer inner@ is the polarity of an occurrence that has
--   polarity @inner@ within an argument position of polarity @outer@.
--   'Nonvariant' on either side wins, then 'Invariant' on the outside.
--   A 'Covariant' context keeps @inner@ and a 'Contravariant' context
--   negates it.
composePol :: Polarity -> Polarity -> Polarity
composePol outer inner = case (outer, inner) of
  (Nonvariant, _)    -> Nonvariant
  (_, Nonvariant)    -> Nonvariant
  (Invariant, _)     -> Invariant
  (Covariant, _)     -> inner
  (Contravariant, _) -> neg inner
-- | Translate an 'Occurrence' into a 'Polarity'.
--   All positive occurrence kinds become 'Covariant' and a negative one
--   becomes 'Contravariant'.  Mixed occurrences are 'Invariant' and unused
--   arguments are 'Nonvariant'.
polFromOcc :: Occurrence -> Polarity
polFromOcc GuardPos  = Covariant
polFromOcc StrictPos = Covariant
polFromOcc JustPos   = Covariant
polFromOcc JustNeg   = Contravariant
polFromOcc Mixed     = Invariant
polFromOcc Unused    = Nonvariant
-- | Split off the first polarity of a list together with the remainder.
--   An exhausted list yields 'Invariant' and stays empty.
nextPolarity :: [Polarity] -> (Polarity, [Polarity])
nextPolarity pols = case pols of
  []           -> (Invariant, [])
  next : later -> (next, later)
-- | Replace every 'Nonvariant' entry by 'Covariant', leaving all other
--   polarities untouched.
purgeNonvariant :: [Polarity] -> [Polarity]
purgeNonvariant pols = [ if q == Nonvariant then Covariant else q | q <- pols ]
-- | Store a preliminary polarity for @x@.
--   It is computed purely from the argument occurrences recorded for the
--   definition ('defArgOccurrences'), translated with 'polFromOcc'.
polarityFromPositivity :: QName -> TCM ()
polarityFromPositivity x = inConcreteOrAbstractMode x $ \ def -> do
  let npars  = droppedPars def
      occPol = map polFromOcc (defArgOccurrences def)
      -- The debug output pads the dropped parameters with 'Nonvariant';
      -- only the occurrence-derived part is stored.
      pol0   = replicate npars Nonvariant ++ occPol
  reportSLn "tc.polarity.set" 15 $
    "Polarity of " ++ prettyShow x ++ " from positivity: " ++ prettyShow pol0
  setPolarity x occPol
-- | Compute and store the polarity of each of the given definitions.
--
--   When two or more names are given (presumably a mutual block), each
--   first receives the crude polarity from 'polarityFromPositivity'.
--   Each definition is then refined in turn:
--
--   * the positivity-based polarity is adjusted for sized types
--     ('sizePolarity');
--   * it is adjusted for dependencies in the definition's type
--     ('dependentPolarity', with phantom parameters enabled via
--     'enablePhantomTypes');
--   * the result, minus the dropped parameters, is stored with
--     'setPolarity'.
computePolarity :: [QName] -> TCM ()
computePolarity xs = do
  when (length xs >= 2) $ mapM_ polarityFromPositivity xs
  mapM_ refine xs
  where
    refine x = inConcreteOrAbstractMode x $ \ def -> do
      reportSLn "tc.polarity.set" 25 $ "Refining polarity of " ++ prettyShow x
      let npars = droppedPars def
          t     = defType def
          -- Dropped parameters get placeholder 'Nonvariant' entries,
          -- presumably so the list lines up with the Pi-arguments of @t@.
          pol0  = replicate npars Nonvariant ++ map polFromOcc (defArgOccurrences def)
      reportSLn "tc.polarity.set" 15 $
        "Polarity of " ++ prettyShow x ++ " from positivity: " ++ prettyShow pol0
      pol1 <- sizePolarity x pol0
      reportSDoc "tc.polarity.set" 15 $
        text "Refining polarity with type " <+> prettyTCM t
      reportSDoc "tc.polarity.set" 60 $
        text "Refining polarity with type (raw): " <+> text (show t)
      pol <- dependentPolarity t (enablePhantomTypes (theDef def) pol1) pol1
      reportSLn "tc.polarity.set" 10 $ "Polarity of " ++ prettyShow x ++ ": " ++ prettyShow pol
      setPolarity x $ drop npars pol
-- | For data types and records, turn 'Nonvariant' into 'Covariant' among
--   the polarities of the parameters (the first @np@ entries).
--   Later entries, and the polarities of any other kind of definition, are
--   returned unchanged.
enablePhantomTypes :: Defn -> [Polarity] -> [Polarity]
enablePhantomTypes def pol = maybe pol purgeParams (parameterCount def)
  where
    purgeParams np = let (pars, rest) = splitAt np pol
                     in  purgeNonvariant pars ++ rest
    parameterCount Datatype{ dataPars = np } = Just np
    parameterCount Record{ recPars = np }    = Just np
    parameterCount _                         = Nothing
-- | Refine polarities using the type @t@ of the definition.
--
--   The second list holds candidate polarities (with phantom parameters
--   enabled); the third holds the polarities computed so far.  Both are
--   consumed one entry per Pi domain of @t@.  For each domain:
--
--   * If it binds a variable (an 'Abs'), its polarity is not already
--     'Invariant', and the variable occurs relevantly in the rest of the
--     type (ignoring arguments refined to 'Nonvariant'), it becomes
--     'Invariant'.
--
--   * Otherwise a size-typed domain keeps its computed polarity, and any
--     other domain takes the candidate polarity.
--
--   The result ends when the polarity list is exhausted.  If the type stops
--   being a Pi, the remaining polarities are returned unchanged.  Running
--   out of candidates while polarities remain is an internal error.
dependentPolarity :: Type -> [Polarity] -> [Polarity] -> TCM [Polarity]
dependentPolarity _ _ [] = return []
dependentPolarity _ [] (_ : _) = __IMPOSSIBLE__
dependentPolarity ty (cand : cands) pols@(cur : rest) = do
  v <- reduce $ unEl ty
  reportSDoc "tc.polarity.dep" 20 $ text "dependentPolarity t = " <+> prettyTCM v
  reportSDoc "tc.polarity.dep" 70 $ text "dependentPolarity t = " <+> (text . show) v
  case v of
    Pi dom b -> do
      -- Refine the tail first, under the binder.
      rest' <- underAbstraction dom b $ \ c -> dependentPolarity c cands rest
      let fallback = ifM (isJust <$> isSizeType (unDom dom)) (return cur) (return cand)
      hd <- case b of
        Abs{} | cur /= Invariant ->
          ifM (relevantInIgnoringNonvariant 0 (absBody b) rest')
            (return Invariant)
            fallback
        _ -> fallback
      return $ hd : rest'
    _ -> return pols
-- | @relevantInIgnoringNonvariant i t pols@ checks whether variable @i@
--   occurs relevantly (as decided by 'relevantInIgnoringSortAnn') in type
--   @t@.  Pi domains whose entry in @pols@ is 'Nonvariant' are skipped.
--   The index is incremented under each Pi binder.  Once @pols@ is
--   exhausted, or the type is no longer a Pi, the whole remaining type is
--   checked.
relevantInIgnoringNonvariant :: Nat -> Type -> [Polarity] -> TCM Bool
relevantInIgnoringNonvariant i ty [] = return $ i `relevantInIgnoringSortAnn` ty
relevantInIgnoringNonvariant i ty (p : ps) = do
  v <- reduce $ unEl ty
  case v of
    Pi a b
      | p /= Nonvariant && i `relevantInIgnoringSortAnn` a -> return True
      | otherwise -> relevantInIgnoringNonvariant (i + 1) (absBody b) ps
    _ -> return $ i `relevantInIgnoringSortAnn` v
-- | Try to establish that data type @d@ is covariant in its first index,
--   which must be of size type.
--
--   @pol0@ is returned unchanged in any of these cases:
--
--   * sized types are disabled;
--   * @d@ is not a data type;
--   * @d@ has no index;
--   * its first index is not a size type of the 'BoundedNo' kind
--     (presumably plain @Size@ rather than a bounded size; confirm against
--     'isSizeType').
--
--   Otherwise the polarity of @d@ is tentatively set to the parameter
--   polarities followed by 'Covariant', and every constructor is checked.
--   A constructor passes when all of the following hold:
--
--   * its first argument (after applying the parameters) is such a size;
--   * that size occurs only 'Covariant' / 'Nonvariant' in the types of the
--     remaining constructor arguments;
--   * the constructor target passes 'checkSizeIndex'.
--
--   If all constructors pass, the size index's argument occurrence is also
--   recorded as 'JustPos' and the result is parameters ++ ['Covariant'].
--   Otherwise the result is parameters ++ ['Invariant'].  Polarities of
--   arguments after the size index are not part of the result.
sizePolarity :: QName -> [Polarity] -> TCM [Polarity]
sizePolarity d pol0 = do
  let exit = return pol0
  ifM (not . optSizedTypes <$> pragmaOptions) exit $ do
    def <- getConstInfo d
    case theDef def of
      Datatype{ dataPars = np, dataCons = cons } -> do
        let TelV tel _ = telView' $ defType def
            (parTel, ixTel) = splitAt np $ telToList tel
        case ixTel of
          -- No index at all: nothing to do.
          [] -> exit
          -- Only proceed if the first index has a 'BoundedNo' size type.
          Dom _ (_, a) : _ -> ifM ((/= Just BoundedNo) <$> isSizeType a) exit $ do
            let pol = take np pol0
                polCo = pol ++ [Covariant]
                polIn = pol ++ [Invariant]
            -- Tentatively assume the size index is covariant while the
            -- constructors are inspected.
            setPolarity d $ polCo
            -- Check a single constructor @c@.
            let check c = do
                  t <- defType <$> getConstInfo c
                  addContext (telFromList parTel) $ do
                    -- Apply the constructor type to the data parameters,
                    -- which are now in scope as variables.
                    let pars = map (defaultArg . var) $ downFrom np
                    TelV conTel target <- telView =<< (t `piApplyM` pars)
                    case conTel of
                      -- A constructor without arguments has no size argument.
                      EmptyTel -> return False
                      -- Only the FIRST constructor argument is considered as
                      -- the size argument.
                      ExtendTel arg tel ->
                        ifM ((/= Just BoundedNo) <$> isSizeType (unDom arg)) (return False) $ do
                          -- Under the size binder the size variable has index 0
                          -- in the first remaining argument type, and index k
                          -- in the k-th one.
                          let isPos = underAbstraction arg tel $ \ tel -> do
                                pols <- zipWithM polarity [0..] $ map (snd . unDom) $ telToList tel
                                reportSDoc "tc.polarity.size" 25 $
                                  text $ "to pass size polarity check, the following polarities need all to be covariant: " ++ prettyShow pols
                                return $ all (`elem` [Nonvariant, Covariant]) pols
                          -- With the whole constructor telescope in scope, the
                          -- size argument is at de Bruijn index (size tel).
                          let sizeArg = size tel
                              isLin = addContext conTel $ checkSizeIndex d np sizeArg target
                          -- isLin only runs if isPos succeeded.
                          ok <- isPos `and2M` isLin
                          reportSDoc "tc.polarity.size" 15 $
                            text "constructor" <+> prettyTCM c <+>
                            text (if ok then "passes" else "fails") <+>
                            text "size polarity check"
                          return ok
            ifNotM (andM $ map check cons)
              (return polIn)
              $ do
                -- Record the size index as a positive occurrence as well,
                -- presumably so occurrence-based polarity stays consistent.
                modifyArgOccurrences d $ \ occ -> take np occ ++ [JustPos]
                return polCo
      _ -> exit
-- | @checkSizeIndex d np i a@ checks the constructor target type @a@.
--   It must be an application of data type @d@, and its first index (the
--   applied argument right after the @np@ parameters) must be the size
--   variable @i@, possibly under successors as seen by 'deepSizeView'.
--   In addition, @i@ must occur neither in the parameters nor in the
--   remaining indices.
--
--   If @a@ is not an application of @d@, or has no applied argument after
--   the parameters, this is an internal (impossible) error.
checkSizeIndex :: QName -> Nat -> Nat -> Type -> TCM Bool
checkSizeIndex d np i a = do
  reportSDoc "tc.polarity.size" 15 $ withShowAllArguments $ vcat
    [ text "checking that constructor target type " <+> prettyTCM a
    , text " is data type " <+> prettyTCM d
    , text " and has size index (successor(s) of) " <+> prettyTCM (var i)
    ]
  case unEl a of
    Def d0 es -> do
      whenNothingM (sameDef d d0) __IMPOSSIBLE__
      -- Match the eliminations explicitly rather than through a lazy
      -- pattern binding.  A malformed target then reports an internal
      -- error with its source location instead of an anonymous
      -- irrefutable-pattern failure.
      case splitAt np es of
        (pars, Apply ix : ixs) -> do
          s <- deepSizeView $ unArg ix
          case s of
            DSizeVar j _ | i == j
              -> return $ not $ freeIn i (pars ++ ixs)
            _ -> return False
        _ -> __IMPOSSIBLE__
    _ -> __IMPOSSIBLE__
-- | Things in which the occurrences of a de Bruijn variable can be assigned
--   polarities.
class HasPolarity a where
  -- | @polarities i x@ collects one polarity for each occurrence of
  --   variable @i@ in @x@ that is taken into account.  An empty list means
  --   no such occurrence was found.
  polarities :: Nat -> a -> TCM [Polarity]
-- | The overall polarity of variable @i@ in @x@.
--   All polarities reported by 'polarities' are merged with the polarity
--   combination operator of this module.  If there are no occurrences, the
--   result is 'Nonvariant', which is neutral for that operator.
polarity :: HasPolarity a => Nat -> a -> TCM Polarity
polarity i x = foldr (/\) Nonvariant <$> polarities i x
-- | Arguments are transparent: only the wrapped value is inspected.
instance HasPolarity a => HasPolarity (Arg a) where
  polarities i arg = polarities i (unArg arg)
-- | Domains are transparent: only the wrapped value is inspected.
instance HasPolarity a => HasPolarity (Dom a) where
  polarities i dom = polarities i (unDom dom)
-- | Under a real binder the variable index is shifted by one.
--   A 'NoAbs' does not bind anything, so the index is kept.
instance HasPolarity a => HasPolarity (Abs a) where
  polarities i b = case b of
    Abs _ body   -> polarities (i + 1) body
    NoAbs _ body -> polarities i body
-- | The polarities of a list are those of its elements, in order.
instance HasPolarity a => HasPolarity [a] where
  polarities i = fmap concat . mapM (polarities i)
-- | The polarities of a pair are those of the first component followed by
--   those of the second.
instance (HasPolarity a, HasPolarity b) => HasPolarity (a, b) where
  polarities i (x, y) = do
    fromFst <- polarities i x
    fromSnd <- polarities i y
    return (fromFst ++ fromSnd)
-- | Only the underlying term of a type is inspected; its sort is ignored.
instance HasPolarity Type where
  polarities i = polarities i . unEl
-- | Applications contribute the polarities of their argument; projections
--   contribute nothing.
instance HasPolarity a => HasPolarity (Elim' a) where
  polarities i e = case e of
    Apply a -> polarities i a
    Proj{}  -> return []
-- | Polarities of variable @i@ in a term, after instantiating the term:
--
--   * The variable itself counts as 'Covariant'.  Every occurrence inside
--     the arguments of any variable is 'Invariant'.
--
--   * Arguments of a defined symbol are composed with that symbol's stored
--     polarities.  Arguments beyond the stored list are treated as
--     'Invariant'.
--
--   * Pi domains are negated; codomains are taken as is.
--
--   * Constructor arguments, lambda bodies, levels and 'DontCare' contents
--     are traversed unchanged.  Occurrences in meta arguments are
--     'Invariant'.
--
--   * Literals and sorts contribute nothing.
instance HasPolarity Term where
  polarities i v = do
    let invariantIn es = map (const Invariant) <$> polarities i es
    w <- instantiate v
    case w of
      Var n es
        | n == i    -> (Covariant :) <$> invariantIn es
        | otherwise -> invariantIn es
      Lam _ body   -> polarities i body
      Lit{}        -> return []
      Level l      -> polarities i l
      Def x es     -> do
        pols    <- getPolarity x
        argPols <- mapM (polarities i) es
        return $ concat $ zipWith (map . composePol) (pols ++ repeat Invariant) argPols
      Con _ _ args -> polarities i args
      Pi a b       -> do
        domPols <- polarities i a
        codPols <- polarities i b
        return $ map neg domPols ++ codPols
      Sort{}       -> return []
      MetaV _ es   -> invariantIn es
      DontCare u   -> polarities i u
-- | A level maximum contributes the polarities of all its components.
instance HasPolarity Level where
  polarities i lvl = case lvl of
    Max ls -> polarities i ls
-- | Closed levels contribute nothing.  A @Plus@ level contributes the
--   polarities of its atom.
instance HasPolarity PlusLevel where
  polarities i pl = case pl of
    ClosedLevel{} -> return []
    Plus _ atom   -> polarities i atom
-- | Occurrences in meta arguments are 'Invariant'.  The other atoms are
--   inspected through the term they wrap.
instance HasPolarity LevelAtom where
  polarities i (MetaLevel _ es)   = map (const Invariant) <$> polarities i es
  polarities i (BlockedLevel _ t) = polarities i t
  polarities i (NeutralLevel _ t) = polarities i t
  polarities i (UnreducedLevel t) = polarities i t