{-|
Module      : IRTS.LangOpts
Description : Transformations to apply to Idris' IR.

License     : BSD3
Maintainer  : The Idris Community.
-}

{-# LANGUAGE DeriveFunctor, PatternGuards #-}

module IRTS.LangOpts(inlineAll) where

import Idris.Core.CaseTree
import Idris.Core.TT
import IRTS.Lang

import Control.Monad.State hiding (lift)
import Data.List
import Debug.Trace

-- | Run the inliner over every declaration of a program.
--
-- All declarations are first gathered into a context so that functions
-- carrying the 'Inline' option can be looked up and expanded at their
-- call sites. Each declaration is then rewritten on its own by
-- 'doInline'. The result has the same names, in the same order, as the
-- input.
inlineAll :: [(Name, LDecl)] -> [(Name, LDecl)]
inlineAll lds = map (\ (n, def) -> (n, doInline defs def)) lds
  where
    defs = addAlist lds emptyContext

-- | Allocate a fresh machine-generated name, @sMN i "in"@, where @i@ is
-- the current counter value, and advance the counter.
nextN :: State Int Name
nextN = do i <- get
           put (i + 1)
           return $ sMN i "in"

-- | Inline inside a declaration.
--
-- Variables are still 'Name's at this stage. Uniqueness of variable
-- names must be preserved in the resulting definition, so a new name is
-- invented for every variable encountered:
--
-- * the parameters are renamed up front (@initNames@);
-- * every other binder gets a name from 'nextN'.
--
-- The 'nextN' counter starts at the number of parameters, so fresh names
-- never reuse a parameter's index. The function's own name is placed in
-- the recursion list, so it is not inlined into itself.
--
-- After inlining, case floating runs for at most 10 rounds, stopping
-- early at a fixed point, followed by eta contraction. If the result is
-- a lambda, its binders are appended to the function's parameter list.
--
-- Constructor declarations are returned unchanged.
doInline :: LDefs -> LDecl -> LDecl
doInline defs d@(LConstructor _ _ _) = d
doInline defs (LFun opts topn args fnBody)
    = case res of
           LLam args' body -> LFun opts topn (map snd initNames ++ args') body
           _ -> LFun opts topn (map snd initNames) res
  where
    inl = evalState (eval [] initEnv [topn] defs fnBody) (length args)
    -- Inlining can expose new case-of-case opportunities; float them,
    -- then eta contract.
    res = eta $ caseFloats 10 inl

    -- Repeat 'caseFloat' until nothing changes or the budget runs out.
    caseFloats :: Int -> LExp -> LExp
    caseFloats 0 tm = tm
    caseFloats budget tm
        | next == tm = next
        | otherwise = caseFloats (budget - 1) next
      where
        next = caseFloat tm

    -- Each parameter paired with its renamed version.
    initNames = zipWith (\n i -> (n, newn n i)) args [0..]
    -- Initial environment: every parameter refers to its renamed version.
    initEnv = [ (n, LV n') | (n, n') <- initNames ]

    -- User-named parameters keep their text under a machine name;
    -- any other parameter becomes "arg".
    newn (UN n) i = MN i n
    newn _ i = sMN i "arg"

-- | Apply an expression to the arguments pending on the stack.
--
-- An existing application is extended rather than nested, so @f a@ with
-- stack @[b]@ becomes @f a b@, and its tail-call flag is kept.
unload :: [LExp] -> LExp -> LExp
unload [] e = e
unload stk (LApp tc e args) = LApp tc e (args ++ stk)
unload stk e = LApp False e stk

-- | Bind lambda parameters to stack arguments pairwise, until either list
-- runs out.
--
-- Returns the extended environment, the parameters left unbound, and the
-- arguments left over. New bindings are prepended, so they shadow
-- earlier entries with the same name.
takeStk :: [(Name, LExp)] -> [Name] -> [LExp]
        -> ([(Name, LExp)], [Name], [LExp])
takeStk env (a : args) (v : stk) = takeStk ((a, v) : env) args stk
takeStk env args stk = (env, args, stk)

-- | Evaluate an expression: inline marked functions, beta-reduce, and
-- rename every binder to a fresh name.
--
-- * @stk@: arguments the expression is being applied to, in application
--   order.
-- * @env@: maps variable names to the expressions that replace them.
-- * @rec@: names currently being expanded; these are not expanded again.
-- * @defs@: global definitions, consulted for inlinable functions and
--   constructors.
eval :: [LExp] -> [(Name, LExp)] -> [Name] -> LDefs -> LExp -> State Int LExp
eval stk env rec defs (LLazyApp n es)
    = unload stk . LLazyApp n <$> mapM (eval [] env rec defs) es
eval stk env rec defs (LForce e)
    = do e' <- eval [] env rec defs e
         -- Forcing a value that is visibly delayed just exposes it.
         return $ unload stk $
             case e' of
                  LLazyExp forced -> forced
                  LLazyApp n es -> LApp False (LV n) es
                  _ -> LForce e'
eval stk env rec defs (LLazyExp e)
    = unload stk . LLazyExp <$> eval [] env rec defs e
-- Special case for io_bind, because it needs to keep executing the first
-- action, and is worth inlining to avoid the thunk.
-- With no world argument pending, abstract over a fresh world variable.
eval [] env rec defs (LApp _ (LV n) [_, _, _, act, LLam [arg] k])
    | n == sUN "io_bind"
    = do w <- nextN
         let env' = (w, LV w) : env
         act' <- eval [] env' rec defs (LApp False act [LV w])
         argn <- nextN
         k' <- eval [] ((arg, LV argn) : env') rec defs (LApp False k [LV w])
         return $ LLam [w] (LLet argn act' k')
-- With a world argument pending, thread it straight through.
eval (world : stk) env rec defs (LApp _ (LV n) [_, _, _, act, LLam [arg] k])
    | n == sUN "io_bind"
    = do act' <- eval [] env rec defs (LApp False act [world])
         argn <- nextN
         k' <- eval stk ((arg, LV argn) : env) rec defs (LApp False k [world])
         return $ LLet argn act' k'
-- Evaluate the arguments, then push them onto the stack for the head.
eval stk env rec defs (LApp _ f es)
    = do es' <- mapM (eval [] env rec defs) es
         eval (es' ++ stk) env rec defs f
eval stk env rec defs (LLet n val sc)
    = do n' <- nextN
         LLet n' <$> eval [] env rec defs val
                 <*> eval stk ((n, LV n') : env) rec defs sc
eval stk env rec defs (LProj e i)
    = (\e' -> unload stk (LProj e' i)) <$> eval [] env rec defs e
eval stk env rec defs (LCon loc i n es)
    = unload stk . LCon loc i n <$> mapM (eval [] env rec defs) es
-- The pending arguments are pushed into every alternative.
eval stk env rec defs (LCase ty e alts)
    = do alts' <- mapM (evalAlt stk env rec defs) alts
         e' <- eval [] env rec defs e
         -- If every alternative is a lambda, bind the shared lambda
         -- prefix once, above the case.
         case getLams (map getRHS alts') of
              [] -> return $ conOpt $ LCase ty e' (replaceInAlts e' alts')
              args -> do alts_red <- mapM (dropArgs args) alts'
                         return $ LLam args
                             (conOpt (LCase ty e' (replaceInAlts e' alts_red)))
eval stk env rec defs (LOp f es)
    = unload stk . LOp f <$> mapM (eval [] env rec defs) es
eval stk env rec defs (LForeign t s args)
    = unload stk . LForeign t s <$>
        mapM (\(fd, e) -> (,) fd <$> eval [] env rec defs e) args
-- Save the interesting cases for the end: lambdas, and names to reduce.
eval stk env rec defs (LLam args sc)
    = case takeStk env args stk of
           -- Fully applied: continue with the body and any leftover
           -- arguments.
           (env', [], stk') -> eval stk' env' rec defs sc
           -- Partially applied: keep a lambda over the remaining
           -- parameters, renamed to fresh names.
           (env', args', stk') ->
               do ns' <- mapM (\v -> (,) v <$> nextN) args'
                  let env'' = [ (v, LV v') | (v, v') <- ns' ] ++ env'
                  unload stk' . LLam (map snd ns') <$> eval [] env'' rec defs sc
-- NOTE(review): a variable bound to an expression is re-evaluated at
-- every use, so the expression is duplicated rather than shared. A
-- constructor name is turned into 'LCon' with whatever is on the stack;
-- its declared arity is not compared with the stack length here.
eval stk env rec defs var@(LV n)
    = case lookup n env of
           Just t
               | t /= LV n && n `notElem` rec ->
                      eval stk env (n : rec) defs t
               | otherwise -> return (unload stk t)
           Nothing
               | n `notElem` rec,
                 Just (LFun opts _ args body) <- lookupCtxtExact n defs,
                 Inline `elem` opts ->
                      apply stk env (n : rec) defs var args body
               | Just (LConstructor cn tag _) <- lookupCtxtExact n defs ->
                      return (LCon Nothing tag cn stk)
               | otherwise -> return (unload stk var)
eval stk _ _ _ t = return (unload stk t)

-- | Evaluate one case alternative against the pending arguments.
--
-- Constructor pattern variables are renamed to fresh names first.
evalAlt :: [LExp] -> [(Name, LExp)] -> [Name] -> LDefs -> LAlt
        -> State Int LAlt
evalAlt stk env rec defs (LConCase i n es rhs)
    = do ns' <- mapM (\v -> (,) v <$> nextN) es
         let env' = [ (v, LV v') | (v, v') <- ns' ] ++ env
         LConCase i n (map snd ns') <$> eval stk env' rec defs rhs
evalAlt stk env rec defs (LConstCase c e)
    = LConstCase c <$> eval stk env rec defs e
evalAlt stk env rec defs (LDefaultCase e)
    = LDefaultCase <$> eval stk env rec defs e

-- | Inline a call: evaluate the callee's definition, as a lambda over its
-- parameters, against the pending arguments. @var@ is not used.
--
-- NOTE(review): the body is evaluated in the caller's environment, which
-- assumes the callee has no free local variables whose names are bound
-- in the caller. Confirm.
apply :: [LExp] -> [(Name, LExp)] -> [Name] -> LDefs -> LExp -> [Name] -> LExp
      -> State Int LExp
apply stk env rec defs var args body = eval stk env rec defs (LLam args body)

-- | Remove, from one case alternative, the lambda prefix that 'eval' has
-- hoisted above the case.
--
-- @as@ are the binders of the hoisted lambda. The alternative's leading
-- binders are renamed to @as@ by re-evaluating its body. Binders beyond
-- the prefix (present when this alternative's lambda is longer than the
-- shortest one) are kept as an inner lambda; dropping them would leave
-- those names unbound. A right-hand side that is not a lambda is applied
-- to the prefix variables instead.
dropArgs :: [Name] -> LAlt -> State Int LAlt
dropArgs as alt
    = case alt of
           LConCase i n es rhs -> LConCase i n es <$> unLam rhs
           LConstCase c rhs -> LConstCase c <$> unLam rhs
           LDefaultCase rhs -> LDefaultCase <$> unLam rhs
  where
    unLam (LLam args rhs)
        = do let (old, rest) = splitAt (length as) args
             rhs' <- eval [] (zipWith (\ o n -> (o, LV n)) old as)
                          [] emptyContext rhs
             return (if null rest then rhs' else LLam rest rhs')
    unLam rhs = return (LApp False rhs (map LV as))

-- | Float case expressions out of case scrutinees, recursing through the
-- term.
--
-- @case (case x of p -> e) of alts@ becomes
-- @case x of p -> case e of alts@. This only happens when every inner
-- alternative produces a constructor application (which 'conOpt' can
-- then resolve) or there is a single inner alternative. Each pushed-in
-- case is built with the 'Shared' case type.
--
-- Expression forms not listed below are returned unchanged.
caseFloat :: LExp -> LExp
caseFloat (LApp tc e es) = LApp tc (caseFloat e) (map caseFloat es)
caseFloat (LLazyExp e) = LLazyExp (caseFloat e)
caseFloat (LForce e) = LForce (caseFloat e)
caseFloat (LCon up i n es) = LCon up i n (map caseFloat es)
caseFloat (LOp f es) = LOp f (map caseFloat es)
caseFloat (LLam ns sc) = LLam ns (caseFloat sc)
caseFloat (LLet v val sc) = LLet v (caseFloat val) (caseFloat sc)
caseFloat (LCase _ (LCase ct e alts) alts')
    | all conRHS alts || length alts == 1
    = conOpt $ replaceInCase (LCase ct (caseFloat e) (map pushOuter alts))
  where
    conRHS alt = case getRHS alt of
                      LCon {} -> True
                      _ -> False
    -- Push the outer alternatives into an inner right-hand side.
    pushOuter = fmap (\rhs ->
                        caseFloat (conOpt (LCase Shared (caseFloat rhs) alts')))
caseFloat (LCase ct e alts)
    = conOpt $ replaceInCase (LCase ct (caseFloat e) (map (fmap caseFloat) alts))
caseFloat e = e

-- | Case of constructor: when the scrutinee is a constructor application,
-- select the matching alternative directly.
--
-- The first alternative that is either a constructor case for the same
-- constructor name, or a default case, is chosen. For a constructor case,
-- the pattern variables are substituted by the constructor's arguments.
-- If no alternative applies (e.g. in unreachable code), the case is
-- returned unchanged instead of aborting compilation.
conOpt :: LExp -> LExp
conOpt tm@(LCase _ (LCon _ _ n args) alts)
    = case pickAlt alts of
           Just rhs -> rhs
           Nothing -> tm
  where
    pickAlt (LConCase _ n' es rhs : _)
        | n == n' = Just (substAll (zip es args) rhs)
    pickAlt (LDefaultCase rhs : _) = Just rhs
    pickAlt (_ : rest) = pickAlt rest
    pickAlt [] = Nothing

    substAll ss rhs = foldr (uncurry lsubst) rhs ss
conOpt tm = tm

-- | Apply 'replaceInAlts' to the alternatives of a case, using its
-- scrutinee. Any other expression is returned unchanged.
replaceInCase :: LExp -> LExp
replaceInCase (LCase ty e alts) = LCase ty e (replaceInAlts e alts)
replaceInCase e = e

-- | Rewrite every alternative with 'replaceInAlt' against the scrutinee
-- @exp@, then remove duplicate constructor alternatives ('dropDups').
replaceInAlts :: LExp -> [LAlt] -> [LAlt]
replaceInAlts exp alts = dropDups $ concatMap (replaceInAlt exp) alts

-- | Drop overlapping cases (arising from case merging of overlapping
-- patterns).
--
-- After a constructor alternative, any later alternative for the same
-- constructor name is removed.
dropDups :: [LAlt] -> [LAlt]
dropDups (alt@(LConCase _ con _ _) : alts)
    = alt : dropDups (filter (not . sameCon) alts)
  where
    sameCon (LConCase _ con' _ _) = con == con'
    sameCon _ = False
dropDups (c : alts) = c : dropDups alts
dropDups [] = []

-- | Simplify one alternative of a case whose scrutinee is @exp@.
replaceInAlt :: LExp -> LAlt -> [LAlt]
-- In a constructor alternative, if the right-hand side rebuilds exactly
-- the matched constructor value, replace it with the scrutinee variable
-- to preserve sharing.
replaceInAlt exp@(LV _) (LConCase i con args rhs)
    = [LConCase i con args $
           replaceExp (LCon Nothing i con (map LV args)) exp rhs]
-- If a default case inspects the same variable as the case it is in,
-- remove the inspection and replace it with the inner alternatives
-- (i.e. merge the inner case block).
replaceInAlt (LV var) (LDefaultCase (LCase _ (LV var') alts))
    | var == var' = alts
replaceInAlt _ a = [a]

-- | @replaceExp old new tm@ is @new@ when @tm@ is the constructor
-- application @old@. That means either an 'LCon', or a call of the
-- constructor's name, with the same name and arguments. Otherwise the
-- result is @tm@. Only the top of @tm@ is examined.
replaceExp :: LExp -> LExp -> LExp -> LExp
replaceExp (LCon _ _ n args) new (LCon _ _ n' args')
    | n == n' && args == args' = new
replaceExp (LCon _ _ n args) new (LApp _ (LV n') args')
    | n == n' && args == args' = new
replaceExp _ _ tm = tm

-- | The right-hand side of a case alternative.
getRHS :: LAlt -> LExp
getRHS (LConCase _ _ _ rhs) = rhs
getRHS (LConstCase _ rhs) = rhs
getRHS (LDefaultCase rhs) = rhs

-- | If every expression is a lambda, return the binders of the one with
-- the fewest binders (the earliest such on ties). Otherwise return @[]@.
getLams :: [LExp] -> [Name]
getLams (LLam args _ : cs) = getLamPrefix args cs
getLams _ = []

-- | Helper for 'getLams': @as@ is the shortest binder list seen so far.
getLamPrefix :: [Name] -> [LExp] -> [Name]
getLamPrefix as [] = as
getLamPrefix as (LLam args _ : cs)
    | length args < length as = getLamPrefix args cs
    | otherwise = getLamPrefix as cs
getLamPrefix _ (_ : _) = []

-- | Eta contract: @\\x -> f x@ can just be compiled as @f@.
--
-- Contraction is only done when @f@ does not mention any of the lambda's
-- binders; otherwise removing the lambda would leave them unbound (e.g.
-- @\\x -> (force x) x@).
--
-- NOTE(review): @f@ is not required to be a variable. For a non-trivial
-- @f@, contraction evaluates it when the closure is built rather than
-- when it is applied. Confirm this is intended.
eta :: LExp -> LExp
eta (LApp tc a es) = LApp tc (eta a) (map eta es)
eta (LLazyApp n es) = LLazyApp n (map eta es)
eta (LLazyExp e) = LLazyExp (eta e)
eta (LForce e) = LForce (eta e)
eta (LLet n val sc) = LLet n (eta val) (eta sc)
eta (LLam args (LApp _ f args'))
    | args' == map LV args, not (mentions f) = eta f
  where
    -- Conservative occurrence check. Binders are ignored, so a shadowed
    -- occurrence still counts, and forms not listed are assumed to
    -- mention the names. A 'False' answer is therefore always safe.
    mentions (LV n) = n `elem` args
    mentions (LApp _ g es) = mentions g || any mentions es
    mentions (LLazyApp n es) = n `elem` args || any mentions es
    mentions (LLazyExp e) = mentions e
    mentions (LForce e) = mentions e
    mentions (LLet _ val sc) = mentions val || mentions sc
    mentions (LLam _ sc) = mentions sc
    mentions (LProj e _) = mentions e
    mentions (LCon Nothing _ _ es) = any mentions es
    mentions (LCase _ e alts) = mentions e || any (mentions . getRHS) alts
    mentions (LOp _ es) = any mentions es
    mentions (LForeign _ _ fargs) = any (mentions . snd) fargs
    mentions _ = True
eta (LLam args e) = LLam args (eta e)
eta (LProj e i) = LProj (eta e) i
eta (LCon a t n es) = LCon a t n (map eta es)
eta (LCase ct e alts) = LCase ct (eta e) (map (fmap eta) alts)
eta (LOp f es) = LOp f (map eta es)
eta tm = tm