module Agda.Compiler.ToTreeless
( toTreeless
, closedTermToTreeless
) where
import Control.Arrow (first)
import Control.Monad.Reader
import Data.Maybe
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Traversable (traverse)
import Agda.Syntax.Common
import Agda.Syntax.Internal as I
import Agda.Syntax.Literal
import qualified Agda.Syntax.Treeless as C
import Agda.Syntax.Treeless (TTerm, EvaluationStrategy)
import Agda.TypeChecking.CompiledClause as CC
import qualified Agda.TypeChecking.CompiledClause.Compile as CC
import Agda.TypeChecking.EtaContract (binAppView, BinAppView(..))
import Agda.TypeChecking.Monad as TCM
import Agda.TypeChecking.Pretty
import Agda.TypeChecking.Records (getRecordConstructor)
import Agda.TypeChecking.Reduce
import Agda.TypeChecking.Substitute
import Agda.Compiler.Treeless.AsPatterns
import Agda.Compiler.Treeless.Builtin
import Agda.Compiler.Treeless.Erase
import Agda.Compiler.Treeless.Identity
import Agda.Compiler.Treeless.Simplify
import Agda.Compiler.Treeless.Uncase
import Agda.Compiler.Treeless.Unused
import Agda.Utils.Function
import Agda.Utils.Functor
import Agda.Utils.Lens
import Agda.Utils.List
import Agda.Utils.Maybe
import Agda.Utils.Monad
import Agda.Utils.Pretty (prettyShow)
import qualified Agda.Utils.Pretty as P
import qualified Agda.Utils.SmallSet as SmallSet
import Agda.Utils.Impossible
-- | Render a value with the pure pretty-printer from "Agda.Utils.Pretty"
--   and return the document in 'TCM'. This lets it be combined with the
--   monadic pretty-printing combinators used for debug output.
prettyPure :: P.Pretty a => a -> TCM Doc
prettyPure x = pure (P.pretty x)
-- | Run the clause compiler on the clauses of a definition.
--
--   Record pattern translation is skipped for proper projections, i.e.
--   functions whose 'funProjection' has a 'projProper' name. It is run for
--   everything else. The definition's optional split tree
--   ('funSplitTree') is handed to the clause compiler.
--
--   NOTE(review): 'funSplitTree' is read from 'theDef' without checking
--   the constructor; this presumably assumes @q@ names a function.
getCompiledClauses :: QName -> TCM CC.CompiledClauses
getCompiledClauses q = do
  def <- getConstInfo q
  let clauses   = defClauses def
      splitTree = funSplitTree $ theDef def
      translate = case theDef def of
        Function{ funProjection = proj }
          | isJust (projProper =<< proj) -> CC.DontRunRecordPatternTranslation
        _ -> CC.RunRecordPatternTranslation
  reportSDoc "treeless.convert" 40 $ "-- before clause compiler" $$ (pretty q <+> "=") <?> vcat (map pretty clauses)
  reportSDoc "treeless.convert" 70 $
    maybe "-- not using split tree" (\st -> "-- using split tree" $$ pretty st) splitTree
  CC.compileClauses' translate clauses splitTree
-- | Compile a definition to treeless form.
--
--   Returns 'Nothing' for definitions that are always inlined (see
--   'alwaysInline'). For all others, returns the treeless term obtained
--   from 'toTreeless'', which reuses a cached term when there is one.
toTreeless :: EvaluationStrategy -> QName -> TCM (Maybe C.TTerm)
toTreeless eval q = do
  inlined <- alwaysInline q
  if inlined
    then pure Nothing
    else Just <$> toTreeless' eval q
-- | Return the treeless term stored for a definition, compiling it when
--   nothing is stored yet.
--
--   Before compiling, a placeholder @'C.TDef' q@ is stored unless the
--   definition is always inlined. This is presumably so that recursive
--   references met during compilation point back at the definition.
--   'ccToTreeless' later stores the real term.
toTreeless' :: EvaluationStrategy -> QName -> TCM C.TTerm
toTreeless' eval q = do
  cached <- getTreeless q
  case cached of
    Just t  -> pure t
    Nothing -> verboseBracket "treeless.convert" 20 ("compiling " ++ prettyShow q) $ do
      cc <- getCompiledClauses q
      unlessM (alwaysInline q) $ setTreeless q (C.TDef q)
      ccToTreeless eval q cc
-- | Make sure a function definition has been compiled to treeless form
--   via 'toTreeless''. The resulting term is discarded. Definitions that
--   are not functions are ignored.
cacheTreeless :: EvaluationStrategy -> QName -> TCM ()
cacheTreeless eval q = do
  info <- getConstInfo q
  case theDef info of
    Function{} -> toTreeless' eval q >> return ()
    _          -> return ()
-- | Translate the compiled clauses of a definition into an optimised
--   treeless term.
--
--   The steps are:
--
--   * convert the case tree with 'casetreeTop';
--   * run the result through 'compilerPipeline';
--   * compute which arguments are used.
--
--   The final term and the argument-usage list are both stored for @q@
--   before the term is returned. For always-inlined definitions, all debug
--   output is moved 20 verbosity levels higher.
ccToTreeless :: EvaluationStrategy -> QName -> CC.CompiledClauses -> TCM C.TTerm
ccToTreeless eval q cc = do
  let pbody b = pbody' "" b
      pbody' suf b = sep [ text (prettyShow q ++ suf) <+> "=", nest 2 $ prettyPure b ]
  inlined <- alwaysInline q
  let v = if inlined then 20 else 0
  reportSDoc "treeless.convert" (30 + v) $ "-- compiled clauses of" <+> prettyTCM q $$ nest 2 (prettyPure cc)
  converted <- casetreeTop eval cc
  reportSDoc "treeless.opt.converted" (30 + v) $ "-- converted" $$ pbody converted
  body <- runPipeline eval q (compilerPipeline v q) converted
  used <- usedArguments q body
  -- Only report when at least one argument is unused.
  when (not (and used)) $
    reportSDoc "treeless.opt.unused" (30 + v) $
      "-- used args:" <+> hsep [ if u then text [x] else "_" | (x, u) <- zip ['a'..] used ] $$
      pbody' "[stripped]" (stripUnusedArguments used body)
  reportSDoc "treeless.opt.final" (20 + v) $ pbody body
  setTreeless q body
  setCompiledArgUse q used
  return body
-- | A compiler pipeline: a tree of optimisation passes over treeless terms.
data Pipeline = FixedPoint Int Pipeline
                -- ^ Repeat the sub-pipeline until the term stops changing,
                --   at most the given number of times (see 'runFixedPoint').
              | Sequential [Pipeline]
                -- ^ Run the sub-pipelines one after another, left to right.
              | SinglePass CompilerPass
                -- ^ Run a single pass.
-- | A single named optimisation pass over treeless terms.
data CompilerPass = CompilerPass
  { passTag       :: String
    -- ^ Suffix of the debug key @treeless.opt.\<tag\>@ this pass reports under.
  , passVerbosity :: Int
    -- ^ Verbosity level at which this pass's debug output is printed.
  , passName      :: String
    -- ^ Human-readable name printed in the debug output.
  , passCode      :: EvaluationStrategy -> TTerm -> TCM TTerm
    -- ^ The transformation itself.
  }
-- | Build a one-pass 'Pipeline' from a debug tag, a verbosity level, a
--   display name and the transformation.
compilerPass :: String -> Int -> String -> (EvaluationStrategy -> TTerm -> TCM TTerm) -> Pipeline
compilerPass tag v name code = SinglePass CompilerPass
  { passTag       = tag
  , passVerbosity = v
  , passName      = name
  , passCode      = code
  }
-- | The optimisation pipeline run on every compiled definition @q@.
--
--   The stages are:
--
--   * an initial simplification and a builtin translation;
--   * up to five rounds of simplify, erase, uncase and as-pattern
--     recovery, stopping early once a round leaves the term unchanged
--     (see 'runFixedPoint');
--   * identity-function detection.
--
--   @v@ is added to each pass's verbosity level.
compilerPipeline :: Int -> QName -> Pipeline
compilerPipeline v q =
  Sequential
    [ compilerPass "simpl"   (35 + v) "simplification"      $ const simplifyTTerm
    , compilerPass "builtin" (30 + v) "builtin translation" $ const translateBuiltins
    , FixedPoint 5 $ Sequential
        [ compilerPass "simpl"  (30 + v) "simplification"     $ const simplifyTTerm
        , compilerPass "erase"  (30 + v) "erasure"            $ eraseTerms q
        , compilerPass "uncase" (30 + v) "uncase"             $ const caseToSeq
        , compilerPass "aspat"  (30 + v) "@-pattern recovery" $ const recoverAsPatterns
        ]
    , compilerPass "id" (30 + v) "identity function detection" $ const (detectIdentityFunctions q)
    ]
-- | Run a 'Pipeline' on a term.
--
--   * A single pass is applied directly.
--   * Sequential stages are threaded left to right, each receiving the
--     previous stage's output.
--   * Fixed-point pipelines are handed to 'runFixedPoint'.
runPipeline :: EvaluationStrategy -> QName -> Pipeline -> TTerm -> TCM TTerm
runPipeline eval q (SinglePass pass) term = runCompilerPass eval q pass term
runPipeline eval q (Sequential stages) term = foldM step term stages
  where
    step acc stage = runPipeline eval q stage acc
runPipeline eval q (FixedPoint limit inner) term = runFixedPoint limit eval q inner term
-- | Run one compiler pass and report its effect under the pass's debug key.
--
--   The report either says the term was unchanged or shows the resulting
--   term. The pass's output is returned.
runCompilerPass :: EvaluationStrategy -> QName -> CompilerPass -> TTerm -> TCM TTerm
runCompilerPass eval q p t = do
  result <- passCode p eval t
  let header = text ("-- " ++ passName p)
      shown  = sep [ text (prettyShow q) <+> "=", nest 2 $ prettyPure result ]
      report
        | t == result = header <+> "(No effect)"
        | otherwise   = header $$ shown
  reportSDoc ("treeless.opt." ++ passTag p) (passVerbosity p) report
  return result
-- | Run a pipeline repeatedly until the term stops changing, but at most
--   @n@ times, and return the last term produced.
runFixedPoint :: Int -> EvaluationStrategy -> QName -> Pipeline -> TTerm -> TCM TTerm
runFixedPoint n eval q pipeline = loop 1
  where
    loop iter term
      | iter > n = do
          reportSLn "treeless.opt.loop" 20 $ "++ Optimisation loop reached maximum iterations (" ++ show n ++ ")"
          return term
      | otherwise = do
          reportSLn "treeless.opt.loop" 30 $ "++ Optimisation loop iteration " ++ show iter
          term' <- runPipeline eval q pipeline term
          if term == term'
            then do
              reportSLn "treeless.opt.loop" 30 $ "++ Optimisation loop terminating after " ++ show iter ++ " iterations"
              return term'
            else loop (iter + 1) term'
-- | Translate an internal term to treeless form in an empty variable
--   context. The term is therefore expected to be closed.
closedTermToTreeless :: EvaluationStrategy -> I.Term -> TCM C.TTerm
closedTermToTreeless eval t = runReaderT (substTerm t) (initCCEnv eval)
-- | Whether a definition is always inlined at its use sites.
--
--   This holds for functions that are extended lambdas ('funExtLam') or
--   with-functions ('funWith'). It is false for every other definition.
alwaysInline :: QName -> TCM Bool
alwaysInline q = do
  info <- getConstInfo q
  return $ case theDef info of
    fun@Function{} -> isJust (funExtLam fun) || isJust (funWith fun)
    _              -> False
-- | The starting translation environment: the given evaluation strategy,
--   no variables in scope, and no catch-all.
initCCEnv :: EvaluationStrategy -> CCEnv
initCCEnv eval = CCEnv
  { ccEvaluation = eval
  , ccCxt        = []
  , ccCatchAll   = Nothing
  }
-- | Reader environment used while translating case trees and terms.
data CCEnv = CCEnv
  { ccCxt        :: CCContext
    -- ^ Maps internal de Bruijn indices to treeless de Bruijn indices
    --   (see 'lookupIndex' and 'lookupLevel').
  , ccCatchAll   :: Maybe Int
    -- ^ Treeless variable holding the current catch-all term, if any
    --   (set by 'updateCatchAll').
  , ccEvaluation :: EvaluationStrategy
    -- ^ Evaluation strategy, passed on when other definitions are
    --   compiled or inlined (see 'maybeInlineDef').
  }
-- | Position @i@ holds the treeless de Bruijn index of the internal
--   variable whose de Bruijn index is @i@.
type CCContext = [Int]
-- | The translation monad: 'TCM' with a read-only 'CCEnv'.
type CC = ReaderT CCEnv TCM
-- | Add @n@ to every treeless index in a context. This is used when
--   @n@ treeless binders are introduced outside the existing ones.
shift :: Int -> CCContext -> CCContext
shift n cxt = [ i + n | i <- cxt ]
-- | Treeless index of the internal variable with de Bruijn index @i@.
--   An index outside the context is an internal error.
lookupIndex :: Int -> CCContext -> Int
lookupIndex i xs = case xs !!! i of
  Just j  -> j
  Nothing -> __IMPOSSIBLE__
-- | Treeless index of the variable at de Bruijn level @l@, counted from
--   the outermost end of the context. A level outside the context is an
--   internal error.
lookupLevel :: Int -> CCContext -> Int
lookupLevel l xs = case xs !!! (length xs - 1 - l) of
  Just j  -> j
  Nothing -> __IMPOSSIBLE__
-- | Translate a whole case tree starting from an empty context.
--
--   The result first gets as many lambdas as all branches share
--   ('commonArity'), and the tree is then translated with 'casetree'.
casetreeTop :: EvaluationStrategy -> CC.CompiledClauses -> TCM C.TTerm
casetreeTop eval cc = runReaderT translate (initCCEnv eval)
  where
    arity = commonArity cc
    translate = do
      lift $ reportSLn "treeless.convert.arity" 40 $ "-- common arity: " ++ show arity
      lambdasUpTo arity $ casetree cc
-- | Translate a compiled-clause case tree to a treeless term in the
--   current 'CC' environment.
--
--   * 'CC.Fail' becomes 'C.tUnreachable'.
--   * A 'CC.Done' leaf is first brought to its own context size with
--     'withContextSize'. Its body is then normalised, with only projection
--     and copattern reductions allowed, and translated with 'substTerm'.
--   * A copattern case (a 'CC.Branches' whose first, projPatterns, field is
--     'True') that has a catch-all is an internal error. Without a
--     catch-all, each projection branch is translated and 'mkRecord'
--     builds the record from the results.
--   * An ordinary case becomes a 'C.TCase' on the scrutinised variable.
--     Eta branches are merged with the constructor branches, and any
--     catch-all is bound via 'updateCatchAll'.
casetree :: CC.CompiledClauses -> CC C.TTerm
casetree cc = do
  case cc of
    CC.Fail -> return C.tUnreachable
    CC.Done xs v -> withContextSize (length xs) $ do
      -- The leaf may bind more or fewer variables than are currently in
      -- scope; withContextSize reconciles the two.
      v <- lift (putAllowedReductions (SmallSet.fromList [ProjectionReductions, CopatternReductions]) $ normalise v)
      substTerm v
    CC.Case _ (CC.Branches True _ _ _ Just{} _ _) -> __IMPOSSIBLE__
    CC.Case (Arg _ n) (CC.Branches True conBrs _ _ Nothing _ _) -> lambdasUpTo n $ do
      mkRecord =<< traverse casetree (CC.content <$> conBrs)
    CC.Case (Arg _ n) (CC.Branches False conBrs etaBr litBrs catchAll _ lazy) -> lambdasUpTo (n + 1) $ do
      -- Eta branches are treated like ordinary constructor branches.
      let conBrs' = Map.union conBrs $ Map.fromList $ map (first conName) $ maybeToList etaBr
      if Map.null conBrs' && Map.null litBrs then do
        -- No branches at all: the result is just the default.
        updateCatchAll catchAll fromCatchAll
      else do
        -- The case type comes from the datatype of the first constructor,
        -- or from the kind of the first literal.
        caseTy <- case (Map.keys conBrs', Map.keys litBrs) of
          ((c:_), []) -> do
            c' <- lift (canonicalName c)
            dtNm <- conData . theDef <$> lift (getConstInfo c')
            return $ C.CTData dtNm
          ([], (LitChar _ _):_) -> return C.CTChar
          ([], (LitString _ _):_) -> return C.CTString
          ([], (LitFloat _ _):_) -> return C.CTFloat
          ([], (LitQName _ _):_) -> return C.CTQName
          _ -> __IMPOSSIBLE__
        updateCatchAll catchAll $ do
          -- In the case tree, the scrutinee is identified by its level n.
          x <- lookupLevel n <$> asks ccCxt
          def <- fromCatchAll
          let caseInfo = C.CaseInfo { caseType = caseTy, caseLazy = lazy }
          C.TCase x caseInfo def <$> do
            br1 <- conAlts n conBrs'
            br2 <- litAlts n litBrs
            return (br1 ++ br2)
  where
    -- Default branch of a case: the catch-all variable if one is bound,
    -- otherwise unreachable.
    fromCatchAll :: CC C.TTerm
    fromCatchAll = maybe C.tUnreachable C.TVar <$> asks ccCatchAll
-- | Number of leading arguments that every path through the case tree
--   shares. 'casetreeTop' uses it as the number of top-level lambdas.
--
--   The result is the minimum over candidate arities gathered from
--   'Done' leaves and copattern ('projPatterns') case nodes. For a leaf,
--   the candidate is the larger of its bound-variable count and the
--   arguments required by enclosing splits. If there are no candidates
--   (e.g. the tree is 'Fail'), the result is 0.
commonArity :: CC.CompiledClauses -> Int
commonArity cc =
  case arities 0 cc of
    [] -> 0
    as -> minimum as
  where
    -- arities cxt t: candidate arities for subtree t, where cxt is the
    -- number of arguments already known to be required.
    arities cxt (Case (Arg _ x) (Branches False cons eta lits def _ _)) =
      concatMap (wArities cxt') (Map.elems cons) ++
      concatMap (wArities cxt') (map snd $ maybeToList eta) ++
      concatMap (wArities cxt' . WithArity 0) (Map.elems lits) ++
      concat [ arities cxt' c | Just c <- [def] ]
      -- Splitting on argument x needs at least x + 1 arguments.
      where cxt' = max (x + 1) cxt
    arities cxt (Case _ Branches{projPatterns = True}) = [cxt]
    arities cxt (Done xs _) = [max cxt (length xs)]
    arities _ Fail = []
    -- A branch whose pattern has k fields replaces the scrutinee by k
    -- variables. The context is therefore adjusted going in (cxt - 1 + k),
    -- and the results are mapped back coming out (x - k + 1).
    wArities cxt (WithArity k c) = map (\ x -> x - k + 1) $ arities (cxt - 1 + k) c
-- | Bind an optional catch-all case tree around a continuation.
--
--   Without a catch-all, the continuation is run unchanged. Otherwise:
--
--   * the catch-all is translated in the current environment;
--   * the continuation is run with the catch-all as treeless variable 0
--     and every existing treeless index shifted up by one;
--   * the translated catch-all is bound around the result with 'C.mkLet'.
updateCatchAll :: Maybe CC.CompiledClauses -> (CC C.TTerm -> CC C.TTerm)
updateCatchAll Nothing cont = cont
updateCatchAll (Just cc) cont = do
  fallback <- casetree cc
  C.mkLet fallback <$> local underLet cont
  where
    underLet e = e { ccCatchAll = Just 0, ccCxt = shift 1 (ccCxt e) }
-- | Run the continuation in a context of exactly @n@ variables. Let @k@
--   be the current context size.
--
--   * If @k >= n@, the innermost @k - n@ entries are dropped and the
--     remaining treeless indices are shifted by @n - k@. The
--     continuation's result is applied to the dropped variables,
--     @TVar (k - n - 1) ... TVar 0@ (via 'downFrom').
--   * If @k < n@, @n - k@ new innermost variables are added, mapped to
--     treeless indices @0 .. n - k - 1@, and the old entries are shifted up
--     by @n - k@. The result is wrapped in @n - k@ lambdas.
--
--   'ccCatchAll' is not adjusted here, so the continuation presumably must
--   not refer to the catch-all. NOTE(review): confirm against callers.
withContextSize :: Int -> CC C.TTerm -> CC C.TTerm
withContextSize n cont = do
  -- diff = requested size minus current size
  diff <- (n -) . length <$> asks ccCxt
  if diff <= 0
  then do
    -- Shrink: drop the innermost -diff entries, then apply the result to
    -- those variables.
    let diff' = -diff
    local (\e -> e { ccCxt = shift diff . drop diff' $ ccCxt e }) $
      cont <&> (`C.mkTApp` map C.TVar (downFrom diff'))
  else do
    -- Grow: bind diff new innermost variables with lambdas.
    local (\e -> e { ccCxt = [0..(diff - 1)] ++ shift diff (ccCxt e)}) $ do
      createLambdas diff <$> do
        cont
  where createLambdas :: Int -> C.TTerm -> C.TTerm
        createLambdas 0 cont' = cont'
        createLambdas i cont' | i > 0 = C.TLam (createLambdas (i - 1) cont')
        createLambdas _ _ = __IMPOSSIBLE__
-- | Make sure at least @n@ variables are in scope, adding lambdas via
--   'withContextSize' if needed, then run the continuation.
--
--   When lambdas are added and a catch-all is in scope, the catch-all does
--   not yet take the new arguments. It is therefore applied to them, and
--   the application is bound with 'C.mkLet'. That binding becomes the new
--   catch-all (variable 0) for the continuation.
lambdasUpTo :: Int -> CC C.TTerm -> CC C.TTerm
lambdasUpTo n cont = do
  diff <- (n -) . length <$> asks ccCxt
  if diff <= 0 then cont -- enough variables already in scope
  else do
    catchAll <- asks ccCatchAll
    withContextSize n $ do
      case catchAll of
        Just catchAll' -> do
          -- The old catch-all index is shifted past the diff new lambdas,
          -- applied to the new arguments, and let-bound as variable 0.
          local (\e -> e { ccCatchAll = Just 0
                         , ccCxt = shift 1 (ccCxt e)}) $ do
            let catchAllArgs = map C.TVar $ downFrom diff
            C.mkLet (C.mkTApp (C.TVar $ catchAll' + diff) catchAllArgs)
              <$> cont
        Nothing -> cont
-- | Translate constructor branches into treeless alternatives for the
--   scrutinee at level @x@.
--
--   In each branch, the constructor name is canonicalised and the
--   scrutinee is replaced by the constructor's @n@ fields.
conAlts :: Int -> Map QName (CC.WithArity CC.CompiledClauses) -> CC [C.TAlt]
conAlts x br = mapM alt (Map.toList br)
  where
    alt (c, CC.WithArity n cc) = do
      c' <- lift $ canonicalName c
      replaceVar x n $ branch (C.TACon c' n) cc
-- | Translate literal branches into treeless alternatives for the
--   scrutinee at level @x@. Each branch removes the scrutinee from the
--   context by replacing it with zero variables.
litAlts :: Int -> Map Literal CC.CompiledClauses -> CC [C.TAlt]
litAlts x br = traverse litAlt (Map.toList br)
  where
    litAlt (l, cc) = replaceVar x 0 (branch (C.TALit l) cc)
-- | Translate the case tree of a branch and wrap the result in the given
--   alternative constructor.
branch :: (C.TTerm -> C.TAlt) -> CC.CompiledClauses -> CC C.TAlt
branch alt = fmap alt . casetree
-- | Run the continuation with the case-tree variable at level @x@ replaced
--   by @n@ fresh variables. These are the fields of a matched constructor,
--   or none (@n = 0@) for a literal.
--
--   The fresh variables get treeless indices @0 .. n - 1@ and take the
--   scrutinee's place in the context. All other context entries, and the
--   catch-all index, are shifted up by @n@. A level outside the current
--   context is an internal error.
replaceVar :: Int -> Int -> CC a -> CC a
replaceVar x n cont =
  local (\e -> e { ccCxt = upd (ccCxt e), ccCatchAll = (+n) <$> ccCatchAll e }) cont
  where
    upd cxt =
      -- List position of level x (list position = de Bruijn index).
      case splitAt (length cxt - 1 - x) cxt of
        (ys, _ : zs) -> shift n ys ++ [0 .. n - 1] ++ shift n zs
        -- Level x is not in scope: report it the same way lookupIndex and
        -- lookupLevel do, rather than with an irrefutable-pattern failure.
        (_, []) -> __IMPOSSIBLE__
-- | Build a record value from a map from projection names to field terms.
--
--   The record constructor is found through the first projection in the
--   map, which must be non-empty. The field terms are then supplied in
--   the constructor's field order; a field missing from the map is an
--   internal error.
mkRecord :: Map QName C.TTerm -> CC C.TTerm
mkRecord fs = lift $ do
  let firstProj = fst $ headWithDefault __IMPOSSIBLE__ $ Map.toList fs
  recCon <- I.conName <$> recConFromProj firstProj
  I.ConHead c _ind xs <- conSrcCon . theDef <$> (getConstInfo =<< canonicalName recCon)
  reportSDoc "treeless.convert.mkRecord" 60 $ vcat
    [ text "record constructor fields: xs = " <+> (text . show) xs
    , text "to be filled with content: keys fs = " <+> (text . show) (Map.keys fs)
    ]
  let fieldValue x = Map.findWithDefault __IMPOSSIBLE__ (unArg x) fs
  return $ C.mkTApp (C.TCon c) (map fieldValue xs)
-- | Constructor of the record type that a projection projects from.
--   Calling it on a name that is not a projection is an internal error.
recConFromProj :: QName -> TCM I.ConHead
recConFromProj q = do
  mproj <- isProjection q
  case mproj of
    Nothing   -> __IMPOSSIBLE__
    Just proj -> getRecordConstructor (unArg (projFromType proj))
-- | Translate an internal term to a treeless term. 'ccCxt' maps internal
--   de Bruijn indices to treeless ones.
--
--   Before translation, applications of static functions are normalised
--   ('normaliseStatic') and erased lambdas are eta-contracted
--   ('etaContractErased'). The cases are:
--
--   * Pi types and levels become 'C.TUnit', and sorts become 'C.TSort'.
--   * 'I.DontCare' becomes 'C.TErased'.
--   * Metas and dummies are internal errors.
--   * Eliminations that 'I.allApplyElims' rejects are internal errors.
substTerm :: I.Term -> CC C.TTerm
substTerm term = normaliseStatic term >>= \ term ->
  case I.unSpine $ etaContractErased term of
    I.Var ind es -> do
      ind' <- lookupIndex ind <$> asks ccCxt
      let args = fromMaybe __IMPOSSIBLE__ $ I.allApplyElims es
      C.mkTApp (C.TVar ind') <$> substArgs args
    I.Lam _ ab ->
      C.TLam <$>
        -- The bound variable becomes treeless index 0; all others shift up.
        local (\e -> e { ccCxt = 0 : (shift 1 $ ccCxt e) })
          (substTerm $ I.unAbs ab)
    I.Lit l -> return $ C.TLit l
    I.Level _ -> return C.TUnit
    I.Def q es -> do
      let args = fromMaybe __IMPOSSIBLE__ $ I.allApplyElims es
      maybeInlineDef q args
    I.Con c ci es -> do
      let args = fromMaybe __IMPOSSIBLE__ $ I.allApplyElims es
      c' <- lift $ canonicalName $ I.conName c
      C.mkTApp (C.TCon c') <$> substArgs args
    I.Pi _ _ -> return C.TUnit
    I.Sort _ -> return C.TSort
    I.MetaV _ _ -> __IMPOSSIBLE__
    I.DontCare _ -> return C.TErased
    I.Dummy{} -> __IMPOSSIBLE__
-- | Repeatedly eta-contract lambdas whose body is an application ending in
--   an erased (not 'usableModality') argument.
--
--   * A lambda that does not use its variable ('NoAbs') is contracted
--     whenever the body's last argument is erased.
--   * A lambda that does bind its variable is contracted only if the
--     lambda itself is erased. The function part is then strengthened by
--     substituting a 'DontCare' dummy for variable 0.
etaContractErased :: I.Term -> I.Term
etaContractErased = trampoline etaErasedOnce
  where
    -- Left = finished, Right = contracted once; try again on the result.
    etaErasedOnce :: I.Term -> Either I.Term I.Term
    etaErasedOnce t =
      case t of
        I.Lam _ (NoAbs _ v) ->
          case binAppView v of
            App u arg | not (usableModality arg) -> Right u
            _ -> done
        I.Lam ai (Abs _ v) | not (usableModality ai) ->
          case binAppView v of
            App u arg | not (usableModality arg) -> Right $ subst 0 (DontCare __DUMMY_TERM__) u
            _ -> done
        _ -> done
      where
        done = Left t
-- | Normalise an application of a static function (one whose definition
--   satisfies 'isStaticFun'). Every other term is returned unchanged.
normaliseStatic :: I.Term -> CC I.Term
normaliseStatic v = case v of
  I.Def f _ -> lift $ ifM (isStaticFun . theDef <$> getConstInfo f) (normalise v) (pure v)
  _         -> pure v
-- | Translate a defined name applied to arguments.
--
--   The definition is inlined (its treeless term applied to the translated
--   arguments) when it is always inlined ('alwaysInline') or is a function
--   marked with 'funInline'.
--
--   Otherwise the definition is first compiled ('cacheTreeless'), which
--   for functions stores their argument-usage list. Arguments the function
--   does not use are then passed as 'C.TErased'. Arguments beyond the
--   length of the usage list are treated as used.
maybeInlineDef :: I.QName -> I.Args -> CC C.TTerm
maybeInlineDef q vs = do
  eval <- asks ccEvaluation
  ifM (lift $ alwaysInline q) (doinline eval) $ do
    lift $ cacheTreeless eval q
    def <- lift $ getConstInfo q
    case theDef def of
      fun@Function{}
        | fun ^. funInline -> doinline eval
        | otherwise -> do
            used <- lift $ getCompiledArgUse q
            let substUsed False _ = pure C.TErased
                substUsed True arg = substArg arg
            C.mkTApp (C.TDef q) <$> sequence [ substUsed u arg | (arg, u) <- zip vs $ used ++ repeat True ]
      _ -> C.mkTApp (C.TDef q) <$> substArgs vs
  where
    doinline eval = C.mkTApp <$> inline eval q <*> substArgs vs
    inline eval q = lift $ toTreeless' eval q
-- | Translate a list of arguments with 'substArg', left to right.
substArgs :: [Arg I.Term] -> CC [C.TTerm]
substArgs args = mapM substArg args
-- | Translate one argument. Arguments whose modality is not usable are
--   replaced by 'C.TErased'.
substArg :: Arg I.Term -> CC C.TTerm
substArg x =
  if usableModality x
    then substTerm (unArg x)
    else return C.TErased