{-# LANGUAGE CPP #-}
{-# LANGUAGE PatternSynonyms #-}
module Agda.Compiler.Treeless.Erase (eraseTerms, computeErasedConstructorArgs) where
import Control.Arrow ((&&&), (***), first, second)
import Control.Monad
import Control.Monad.State
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Semigroup
import Agda.Syntax.Common
import Agda.Syntax.Internal as I
import Agda.Syntax.Abstract.Name (QName)
import Agda.Syntax.Position
import Agda.Syntax.Treeless
import Agda.Syntax.Literal
import Agda.TypeChecking.Substitute
import Agda.TypeChecking.Monad as I
import Agda.TypeChecking.Monad.Builtin
import Agda.TypeChecking.Telescope
import Agda.TypeChecking.Reduce
import Agda.TypeChecking.Datatypes
import Agda.TypeChecking.Pretty hiding ((<>))
import Agda.TypeChecking.Primitive
import Agda.Compiler.Treeless.Subst
import Agda.Compiler.Treeless.Pretty
import Agda.Compiler.Treeless.Unused
import Agda.Utils.Functor
import Agda.Utils.Lens
import Agda.Utils.Maybe
import Agda.Utils.Memo
import Agda.Utils.Monad
import Agda.Utils.Pretty (prettyShow)
import qualified Agda.Utils.Pretty as P
import Agda.Utils.IntSet.Infinite (IntSet)
import qualified Agda.Utils.IntSet.Infinite as IntSet
#include "undefined.h"
import Agda.Utils.Impossible
-- | State carried by the erasure monad 'E': two caches keyed by qualified
--   name.
--
--   * '_funMap' stores the 'FunInfo' computed for a name (filled in by
--     'getFunInfo' through the 'funMap' lens).
--   * '_typeMap' stores the 'TypeInfo' computed for a type-level name
--     (read and written by 'getTypeInfo' through the 'typeMap' lens).
data ESt = ESt { _funMap :: Map QName FunInfo
               , _typeMap :: Map QName TypeInfo }
-- | Lens onto the 'FunInfo' cache stored in the erasure state.
funMap :: Lens' (Map QName FunInfo) ESt
funMap inject st = store <$> inject (_funMap st)
  where
    -- Put an updated cache back into the surrounding state.
    store table = st { _funMap = table }
-- | Lens onto the 'TypeInfo' cache stored in the erasure state.
typeMap :: Lens' (Map QName TypeInfo) ESt
typeMap inject st = store <$> inject (_typeMap st)
  where
    -- Put an updated cache back into the surrounding state.
    store table = st { _typeMap = table }
-- | The erasure monad: 'TCM' extended with the 'ESt' caches as state.
type E = StateT ESt TCM
-- | Run an erasure computation starting from empty caches; the final state
--   is discarded and only the result is returned.
runE :: E a -> TCM a
runE action = evalStateT action emptyCaches
  where
    emptyCaches = ESt { _funMap = Map.empty, _typeMap = Map.empty }
-- | Run 'getFunInfo' on every constructor that 'getConstructors' returns for
--   @d@, purely for its effects ('getFunInfo' records the erased constructor
--   arguments via 'setErasedConArgs').  The computed infos are discarded.
computeErasedConstructorArgs :: QName -> TCM ()
computeErasedConstructorArgs d =
  getConstructors d >>= runE . mapM_ getFunInfo
-- | Erase parts of the treeless body @t@ of the definition @q@.
--
--   'usedArguments' is first run on @q@ and @t@ for its effects only; the
--   erasure itself then runs in 'E' with fresh caches (see 'runE').  The
--   evaluation strategy @eval@ is consulted only by the local 'tLam' helper.
eraseTerms :: QName -> EvaluationStrategy -> TTerm -> TCM TTerm
eraseTerms q eval t = usedArguments q t *> runE (eraseTop q t)
  where
    -- Entry point: if the result info of @q@ itself is 'Erasable' or 'Empty'
    -- the whole body becomes 'TErased'; otherwise erase inside the body.
    eraseTop q t = do
      (_, h) <- getFunInfo q
      case h of
        Erasable -> pure TErased
        Empty -> pure TErased
        _ -> erase t
    -- Main traversal.  Applications headed by a constructor or a definition
    -- are handled first, using the 'FunInfo' of the head; everything else is
    -- handled structurally.
    erase t = case tAppView t of
      -- Constructor application: having more arguments than argument infos
      -- is treated as impossible.
      TCon c : vs -> do
        (rs, h) <- getFunInfo c
        when (length rs < length vs) __IMPOSSIBLE__
        case h of
          Erasable -> pure TErased
          Empty -> pure TErased
          _ -> tApp (TCon c) <$> zipWithM eraseRel rs vs
      -- Definition application: arguments beyond the known argument infos
      -- are treated as 'NotErasable'.
      TDef f : vs -> do
        (rs, h) <- getFunInfo f
        case h of
          Erasable -> pure TErased
          Empty -> pure TErased
          _ -> tApp (TDef f) <$> zipWithM eraseRel (rs ++ repeat NotErasable) vs
      _ -> case t of
        TVar{} -> pure t
        TDef{} -> pure t
        TPrim{} -> pure t
        TLit{} -> pure t
        TCon{} -> pure t
        TApp f es -> tApp <$> erase f <*> mapM erase es
        TLam b -> tLam <$> erase b
        TLet e b -> do
          e <- erase e
          -- When the bound expression is erased: if the body is a case on
          -- variable 0 the (erased) binding is kept via 'tLet'; otherwise
          -- 'TErased' is substituted for variable 0 before erasing the body.
          if isErased e
            then case b of
              TCase 0 _ _ _ -> tLet TErased <$> erase b
              _ -> erase $ subst 0 TErased b
            else tLet e <$> erase b
        TCase x t d bs -> do
          -- Prune first, then erase the default and every alternative, then
          -- rebuild through the 'tCase' smart constructor.
          (d, bs) <- pruneUnreachable x (caseType t) d bs
          d <- erase d
          bs <- mapM eraseAlt bs
          tCase x t d bs
        TUnit -> pure t
        TSort -> pure t
        TErased -> pure t
        TError{} -> pure t
        TCoerce e -> TCoerce <$> erase e
    -- Lambda smart constructor: a lambda with an erased body is itself
    -- erased, but only when @eval@ is 'LazyEvaluation'.
    tLam TErased | eval == LazyEvaluation = TErased
    tLam t = TLam t
    -- Let smart constructor: keep the binding only if @freeIn 0 b@ holds;
    -- otherwise return the body strengthened past the binder.
    -- NOTE(review): the __IMPOSSIBLE__ argument is presumably only reached if
    -- the bound variable does occur after all -- confirm against 'strengthen'.
    tLet e b
      | freeIn 0 b = TLet e b
      | otherwise = strengthen __IMPOSSIBLE__ b
    -- Application smart constructor: no arguments yields the head; an erased
    -- head yields 'TErased'; an unreachable head yields 'tUnreachable'.
    tApp f [] = f
    tApp TErased _ = TErased
    tApp f _ | isUnreachable f = tUnreachable
    tApp f es = TApp f es
    -- Case smart constructor (monadic because it may consult 'getFunInfo'
    -- and call 'erase' again).
    tCase x t d bs
      -- Default and all alternative bodies erased/unreachable: erase the case.
      | isErased d && all (isErased . aBody) bs = pure TErased
      | otherwise = case bs of
        -- Exactly one constructor alternative: decide by the result info of
        -- that constructor.  In the 'Erasable' case the case expression is
        -- replaced by the alternative's body with its @a@ bound variables
        -- replaced by 'TErased' (and erased again when @a@ is non-zero); the
        -- default @d@ is not used in that result.
        [TACon c a b] -> do
          h <- snd <$> getFunInfo c
          case h of
            NotErasable -> noerase
            Empty -> pure TErased
            Erasable -> (if a == 0 then pure else erase) $ applySubst (replicate a TErased ++# idS) b
        _ -> noerase
      where
        -- Rebuild the case expression unchanged.
        noerase = pure $ TCase x t d bs
    -- A term counts as erased if it is 'TErased' or unreachable.
    isErased t = t == TErased || isUnreachable t
    -- Erase an argument according to its argument info: erasable arguments
    -- become 'TErased', the rest are traversed with 'erase'.
    eraseRel r t | erasable r = pure TErased
                 | otherwise = erase t
    -- Erase inside a case alternative.
    eraseAlt a = case a of
      TALit l b -> TALit l <$> erase b
      TACon c a b -> do
        -- One flag per constructor argument: True if its info is erasable.
        rs <- map erasable . fst <$> getFunInfo c
        -- Substitution built from the reversed flags: an erasable position
        -- maps its variable to 'TErased', any other position is lifted over.
        let sub = foldr (\ e -> if e then (TErased :#) . wkS 1 else liftS 1) idS $ reverse rs
        TACon c a <$> erase (applySubst sub b)
      TAGuard g b -> TAGuard <$> erase g <*> erase b
-- | Prune a case expression on variable @x@ given its case type, default
--   branch @d@ and alternatives @bs@.
--
--   * Data types: if there are as many constructor alternatives as the type
--     has constructors, the default is replaced by 'tUnreachable'.  The
--     alternatives are returned unchanged.
--   * Naturals and integers: delegated to 'pruneIntCase', with everything
--     below 0 counted as already covered for naturals and nothing covered
--     for integers.
--   * Any other case type: returned unchanged.
pruneUnreachable :: Int -> CaseType -> TTerm -> [TAlt] -> E (TTerm, [TAlt])
pruneUnreachable x ct d bs = case ct of
  CTData q -> do
    cs <- lift (getConstructors q)
    let conAlts = [ alt | alt@TACon{} <- bs ]
        def
          | length cs == length conAlts = tUnreachable
          | otherwise = d
    pure (def, bs)
  CTNat -> pure (pruneIntCase x d bs (IntSet.below 0))
  CTInt -> pure (pruneIntCase x d bs IntSet.empty)
  _ -> pure (d, bs)
-- | @Below r x n@ is the treeless term applying the primitive 'PLt' to
--   variable @x@ and the natural-number literal @n@ (with range @r@).
--   'pruneIntCase' treats a guard of this shape as covering
--   @IntSet.below n@.
pattern Below :: Range -> Int -> Integer -> TTerm
pattern Below r x n = TApp (TPrim PLt) [TVar x, TLit (LitNat r n)]
-- | @Above r x n@ is the treeless term applying the primitive 'PGeq' to
--   variable @x@ and the natural-number literal @n@ (with range @r@).
--   'pruneIntCase' treats a guard of this shape as covering
--   @IntSet.above n@.
pattern Above :: Range -> Int -> Integer -> TTerm
pattern Above r x n = TApp (TPrim PGeq) [TVar x, TLit (LitNat r n)]
-- | Prune the alternatives of a case on the integer-valued variable @x@.
--
--   @cover@ is the set of values already accounted for.  Alternatives are
--   visited in order; a 'Below' / 'Above' guard on @x@ or a 'LitNat' literal
--   contributes its range, from which the values covered so far are removed:
--
--   * nothing new is contributed: the alternative is dropped;
--   * exactly one new value: the alternative is replaced by a literal
--     alternative for that value (with 'noRange') and the same body;
--   * otherwise the alternative is kept as is.
--
--   Any other alternative is kept and does not change the covered set.  If
--   the covered set is full once all alternatives are processed, the default
--   @d@ is replaced by 'tUnreachable'.
pruneIntCase :: Int -> TTerm -> [TAlt] -> IntSet -> (TTerm, [TAlt])
pruneIntCase x d bs cover = go bs cover
  where
    go [] seen
      | seen == IntSet.full = (tUnreachable, [])
      | otherwise = (d, [])
    go (alt : rest) seen = case alt of
      TAGuard (Below _ y n) _ | x == y -> refine (IntSet.below n)
      TAGuard (Above _ y n) _ | x == y -> refine (IntSet.above n)
      TALit (LitNat _ n) _ -> refine (IntSet.singleton n)
      _ -> second (alt :) (go rest seen)
      where
        -- Process the remaining alternatives with the enlarged cover, then
        -- decide how (and whether) to emit the current alternative.
        refine range = second emit (go rest (fresh <> seen))
          where
            -- Values this alternative handles that were not handled before.
            fresh = IntSet.difference range seen
            emit = case IntSet.toFiniteList fresh of
              Just [] -> id
              Just [k] -> (TALit (LitNat noRange k) (aBody alt) :)
              _ -> (alt :)
-- | Classification used by the erasure pass.
data TypeInfo
  = Empty
  | Erasable
  | NotErasable
  deriving (Eq, Show)

-- | Combine a list of infos.  'Empty' is the neutral element, 'Erasable'
--   yields to anything but 'Empty', and 'NotErasable' wins over everything;
--   the empty list gives 'Empty'.
sumTypeInfo :: [TypeInfo] -> TypeInfo
sumTypeInfo = foldr combine Empty
  where
    combine Empty other = other
    combine other Empty = other
    combine Erasable other = other
    -- Here the left operand is 'NotErasable' and the right one is not 'Empty'.
    combine other _ = other

-- | True for everything except 'NotErasable'.
erasable :: TypeInfo -> Bool
erasable info = case info of
  NotErasable -> False
  Erasable -> True
  Empty -> True
-- | Erasure information for a name as computed by 'getFunInfo': one
--   'TypeInfo' per argument, paired with the 'TypeInfo' of the result type.
type FunInfo = ([TypeInfo], TypeInfo)
-- | Memoised (through 'funMap') computation of the 'FunInfo' of @q@.
--
--   The argument infos come from the telescope returned by
--   'typeWithoutParams', adjusted by 'mkR'; the result info is 'Erasable' for
--   absurd-lambda names and otherwise the 'getTypeInfo' of the target type.
--   As side effects the result is logged and the erasable argument positions
--   are recorded with 'setErasedConArgs'.
getFunInfo :: QName -> E FunInfo
getFunInfo q = memo (funMap . key q) compute
  where
    compute = do
      (tel, target) <- lift $ typeWithoutParams q
      argInfos <- mapM (getTypeInfo . snd . dget) tel
      -- Both lists are padded so that they never run out before @tel@ does.
      used <- lift $ (++ repeat True) <$> getCompiledArgUse q
      forced <- lift $ (++ repeat NotForced) <$> getForcedArgs q
      let rs = zipWith3 (\ dom (f, u) i -> mkR (getModality dom) f u i) tel (zip forced used) argInfos
      h <- if isAbsurdLambdaName q then pure Erasable else getTypeInfo target
      lift $ reportSLn "treeless.opt.erase.info" 50 $ "type info for " ++ prettyShow q ++ ": " ++ show rs ++ " -> " ++ show h
      lift $ setErasedConArgs q $ map erasable rs
      pure (rs, h)

    -- An argument is 'Erasable' when its modality is unusable, when it is
    -- flagged as unused, or when it is forced; otherwise its own info is kept.
    mkR :: Modality -> IsForced -> Bool -> TypeInfo -> TypeInfo
    mkR m f b i
      | unusable || unused || forced = Erasable
      | otherwise = i
      where
        unusable = not (usableModality m)
        unused = not b
        forced = case f of
          Forced -> True
          _ -> False
-- | Like 'telView', but with the telescope converted to a list by
--   'telToList'.
telListView :: Type -> TCM (ListTel, Type)
telListView t = split <$> telView t
  where
    split (TelV tel core) = (telToList tel, core)
-- | The type of @q@ as a telescope list and target, with a leading prefix of
--   the telescope dropped:
--
--   * for a function carrying projection data: @projIndex - 1@ entries;
--   * for a constructor: @conPars@ entries;
--   * for anything else: nothing.
typeWithoutParams :: QName -> TCM (ListTel, Type)
typeWithoutParams q = do
  def <- getConstInfo q
  first (drop (leadingParams (I.theDef def))) <$> telListView (defType def)
  where
    -- Number of leading telescope entries to discard.
    leadingParams defn = case defn of
      Function{ funProjection = Just Projection{ projIndex = i } } -> i - 1
      Constructor{ conPars = n } -> n
      _ -> 0
-- | Compute the 'TypeInfo' of a type.
--
--   The type is split into a telescope and a target with 'telListView'.  The
--   target is classified first (a 'I.Def' through 'typeInfo', a sort as
--   'Erasable', anything else as 'NotErasable'), then every telescope entry
--   is classified recursively.  The final answer is:
--
--   * 'Erasable' if some telescope entry is 'Empty';
--   * the target's info if the telescope is empty;
--   * 'Erasable' if the target is 'Empty' (and the telescope is not empty);
--   * the target's info otherwise.
getTypeInfo :: Type -> E TypeInfo
getTypeInfo t0 = do
  (tel, target) <- lift $ telListView t0
  targetInfo <- case I.unEl target of
    I.Def d _ -> do
      -- Whatever 'typeInfo' adds to the cache is thrown away again: the
      -- cache is reset to its previous contents plus the entry for @d@.
      saved <- use typeMap
      info <- typeInfo d
      typeMap .= Map.insert d info saved
      pure info
    Sort{} -> pure Erasable
    _ -> pure NotErasable
  argInfos <- mapM (getTypeInfo . snd . dget) tel
  let verdict
        | Empty `elem` argInfos = Erasable
        | null argInfos = targetInfo
        | targetInfo == Empty = Erasable
        | otherwise = targetInfo
  lift $ reportSDoc "treeless.opt.erase.type" 50 $ prettyTCM t0 <+> text ("is " ++ show verdict)
  pure verdict
  where
    -- Info for a named type, memoised through 'typeMap'; 'Erasable' is the
    -- value handed to 'memoRec' for use while the computation is running.
    typeInfo :: QName -> E TypeInfo
    typeInfo q = memoRec (typeMap . key q) Erasable $ do
      sizeNames <- lift $ mapM getBuiltinName [builtinSize, builtinSizeLt]
      def <- lift $ getConstInfo q
      let defn = I.theDef def
          constructors = case defn of
            I.Datatype{ dataCons = cs } -> Just cs
            I.Record{ recConHead = c } -> Just [conName c]
            _ -> Nothing
      if Just q `elem` sizeNames
        then pure Erasable
        else case constructors of
          -- No constructors at all.
          Just [] -> pure Empty
          -- Exactly one constructor: erasable iff every field is erasable
          -- or has an unusable modality.
          Just [c] -> do
            (fields, _) <- lift $ typeWithoutParams c
            fieldInfos <- mapM (getTypeInfo . snd . dget) fields
            let droppable = and [ erasable i || not (usableModality m) | (i, m) <- zip fieldInfos (map getModality fields) ]
            pure $ if droppable then Erasable else NotErasable
          -- Two or more constructors.
          Just _ -> pure NotErasable
          -- Neither a data type nor a record: functions are classified by
          -- combining the infos of their clause bodies.
          Nothing -> case defn of
            I.Function{ funClauses = cls } -> sumTypeInfo <$> mapM clauseInfo cls
            _ -> pure NotErasable

    -- A clause without a body counts as 'Empty'; otherwise its body is
    -- classified as a type.
    clauseInfo cl = maybe (pure Empty) (getTypeInfo . El __DUMMY_SORT__) (clauseBody cl)