{-|
Module      : IRTS.Defunctionalise
Description : Defunctionalise Idris' IR.
License     : BSD3
Maintainer  : The Idris Community.

To defunctionalise:

1. Create a data constructor for each function
2. Create a data constructor for each underapplication of a function
3. Convert underapplications to their corresponding constructors
4. Create an EVAL function which calls the appropriate function for data
   constructors created as part of step 1
5. Create an APPLY function which adds an argument to each underapplication
   (or calls APPLY again for an exact application)
6. Wrap overapplications in chains of APPLY
7. Wrap unknown applications (i.e. applications of local variables) in
   chains of APPLY
8. Add explicit EVAL to case, primitives, and foreign calls
-}

{-# LANGUAGE FlexibleContexts, PatternGuards #-}

module IRTS.Defunctionalise(module IRTS.Defunctionalise
                          , module IRTS.Lang
                          ) where

import Idris.Core.CaseTree
import Idris.Core.TT
import IRTS.Lang

import Control.Monad
import Control.Monad.State
import Data.List
import Data.Maybe

-- | Expressions after defunctionalisation.
data DExp = DV Name
          | DApp Bool Name [DExp] -- True = tail call
          | DLet Name DExp DExp -- name just for pretty printing
          | DUpdate Name DExp -- eval expression, then update var with it
          | DProj DExp Int
          | DC (Maybe Name) Int Name [DExp]
          | DCase CaseType DExp [DAlt]
          | DChkCase DExp [DAlt] -- a case where the type is unknown (for EVAL/APPLY)
          | DConst Const
          | DForeign FDesc FDesc [(FDesc, DExp)]
          | DOp PrimFn [DExp]
          | DNothing -- erased value, can be compiled to anything since it'll never
                     -- be inspected
          | DError String
  deriving Eq

-- | Case alternatives of defunctionalised expressions.
data DAlt = DConCase Int Name [Name] DExp
          | DConstCase Const DExp
          | DDefaultCase DExp
  deriving (Show, Eq)

-- | Top-level declarations after defunctionalisation.
data DDecl = DFun Name [Name] DExp -- name, arg names, definition
           | DConstructor Name Int Int -- constructor name, tag, arity
  deriving (Show, Eq)

type DDefs = Ctxt DDecl

-- | Defunctionalise a whole program.
--
-- Tags for the generated EVAL/APPLY constructors are allocated consecutively
-- starting at @nexttag@. The result contains the generated EVAL, APPLY and
-- APPLY2 functions, the generated constructor declarations, and every input
-- declaration translated by 'addApps'.
defunctionalise :: Int -> LDefs -> DDefs
defunctionalise nexttag defs
    = let allDefs = toAlist defs
          (allD, (enames, anames)) = runState (mapM (addApps defs) allDefs) ([], [])
          anames' = sort (nub anames)
          enames' = nub enames
          fns = getFn allDefs
          -- sort newcons so that EVAL and APPLY cons get sequential tags
          newecons = sortBy conord $ concatMap (toCons enames') fns
          newacons = sortBy conord $ concatMap (toConsA anames') fns
          eval = mkEval newecons
          app = mkApply newacons
          app2 = mkApply2 newacons
          condecls = declare nexttag (newecons ++ newacons)
      in addAlist (eval : app : app2 : condecls ++ allD) emptyContext
  where
    conord (n, _, _) (n', _, _) = compare n n'

-- | Name and arity of every function declaration (constructors are skipped).
getFn :: [(Name, LDecl)] -> [(Name, Int)]
getFn xs = [ (n, length args) | (n, LFun _ _ args _) <- xs ]

-- | Translate one declaration into its defunctionalised form.
--
-- The state threads two accumulators that 'defunctionalise' later turns into
-- EVAL/APPLY constructors:
--
--   * names of functions lazily applied to exactly their arity
--     (recorded by 'fixLazyApply'), and
--   * @(function, argument count)@ pairs for every underapplication
--     (recorded for each count from the supplied one up to the arity).
addApps :: LDefs -> (Name, LDecl) -> State ([Name], [(Name, Int)]) (Name, DDecl)
addApps _ (n, LConstructor _ t a) = return (n, DConstructor n t a)
addApps defs (n, LFun _ _ args e)
    = do e' <- aa args e
         return (n, DFun n args e')
  where
    -- Translate an expression; @env@ lists the locally bound names.
    aa :: [Name] -> LExp -> State ([Name], [(Name, Int)]) DExp
    aa env (LV n)
        | n `elem` env = return $ DV n
        | otherwise = aa env (LApp False (LV n) [])
    aa env (LApp tc (LV n) args)
        = do args' <- mapM (aa env) args
             case lookupCtxtExact n defs of
               Just (LConstructor _ _ _) -> return $ DApp tc n args'
               Just (LFun _ _ params _) -> fixApply tc n args' (length params)
               Nothing -> return $ chainAPPLY (DV n) args'
    -- The head is a computed expression rather than a name, so its arity is
    -- unknown: apply each argument through APPLY (step 7 above). Previously
    -- this shape fell through every equation and crashed.
    aa env (LApp _ f args)
        = do f' <- aa env f
             args' <- mapM (aa env) args
             return $ chainAPPLY f' args'
    aa env (LLazyApp n args)
        = do args' <- mapM (aa env) args
             case lookupCtxtExact n defs of
               Just (LConstructor _ _ _) -> return $ DApp False n args'
               Just (LFun _ _ params _) -> fixLazyApply n args' (length params)
               Nothing -> return $ chainAPPLY (DV n) args'
    aa env (LForce (LLazyApp n args)) = aa env (LApp False (LV n) args)
    aa env (LForce e) = liftM eEVAL (aa env e)
    aa env (LLet n v sc) = liftM2 (DLet n) (aa env v) (aa (n : env) sc)
    aa env (LCon loc i n args) = liftM (DC loc i n) (mapM (aa env) args)
    -- Projection from a local variable: update the variable with its value
    -- before projecting.
    aa env (LProj t@(LV n) i)
        | n `elem` env
            = do t' <- aa env t
                 return $ DProj (DUpdate n t') i
    aa env (LProj t i) = do t' <- aa env t
                            return $ DProj t' i
    aa env (LCase up e alts) = do e' <- aa env e
                                  alts' <- mapM (aaAlt env) alts
                                  return $ DCase up e' alts'
    aa _ (LConst c) = return $ DConst c
    aa env (LForeign t n args)
        = do args' <- mapM (aaF env) args
             return $ DForeign t n args'
    aa env (LOp f args) = liftM (DOp f) (mapM (aa env) args)
    aa _ LNothing = return DNothing
    aa _ (LError e) = return $ DError e

    aaF env (t, e) = do e' <- aa env e
                        return (t, e')

    -- Constructor-pattern variables are in scope in the alternative body.
    aaAlt env (LConCase i n args e) = liftM (DConCase i n args) (aa (args ++ env) e)
    aaAlt env (LConstCase c e) = liftM (DConstCase c) (aa env e)
    aaAlt env (LDefaultCase e) = liftM DDefaultCase (aa env e)

    -- Record that @fn@ is underapplied, for every argument count from
    -- @nargs@ up to its arity @ar@.
    recordUnder :: Name -> Int -> Int -> State ([Name], [(Name, Int)]) ()
    recordUnder fn nargs ar
        = modify $ \(ens, ans) -> (ens, [ (fn, k) | k <- [nargs .. ar] ] ++ ans)

    -- Strict call of a known function @fn@ of arity @ar@.
    fixApply tc fn args ar = case compare nargs ar of
        EQ -> return $ DApp tc fn args
        LT -> do
            recordUnder fn nargs ar
            return $ DApp tc (mkUnderCon fn (ar - nargs)) args
        GT -> return $ chainAPPLY (DApp tc fn (take ar args)) (drop ar args)
      where nargs = length args

    -- Lazy call of a known function @fn@ of arity @ar@. An exact application
    -- builds a P_ constructor, so @fn@ needs an EVAL case.
    fixLazyApply fn args ar = case compare nargs ar of
        EQ -> do
            modify $ \(ens, ans) -> (fn : ens, ans)
            return $ DApp False (mkFnCon fn) args
        LT -> do
            recordUnder fn nargs ar
            return $ DApp False (mkUnderCon fn (ar - nargs)) args
        GT -> return $ chainAPPLY (DApp False fn (take ar args)) (drop ar args)
      where nargs = length args

    -- Apply @f@ to the arguments one at a time through APPLY.
    -- A variant using APPLY2 for pairs of arguments is currently disabled:
    -- chainAPPLY f (a : b : as)
    --      = chainAPPLY (DApp False (sMN 0 "APPLY2") [f, a, b]) as
    chainAPPLY f [] = f
    chainAPPLY f (a : as) = chainAPPLY (DApp False (sMN 0 "APPLY") [f, a]) as

-- | Wrap an expression in a call to the generated EVAL function.
eEVAL :: DExp -> DExp
eEVAL x = DApp False (sMN 0 "EVAL") [x]

-- | Which generated function (EVAL, APPLY or APPLY2) handles a constructor.
-- An EVAL alternative is parameterised by the name of EVAL's argument.
data EvalApply a = EvalCase (Name -> a)
                 | ApplyCase a
                 | Apply2Case a

-- | For a function name, generate the data constructor (and EVAL alternative)
-- for an exact lazy application, if that function was recorded in @ns@.
toCons :: [Name] -> (Name, Int) -> [(Name, Int, EvalApply DAlt)]
toCons ns (n, i)
    | n `elem` ns = [(mkFnCon n, i, EvalCase evalAlt)]
    | otherwise = []
  where
    args = take i (genArgs 0)
    evalAlt tlarg = DConCase (-1) (mkFnCon n) args
                        (DUpdate tlarg (DApp False n (map DV args)))

-- | For a function name of arity @i@, generate the APPLY/APPLY2 constructors
-- for its underapplications, starting from the smallest recorded argument
-- count (the first match in @ns@).
toConsA :: [(Name, Int)] -> (Name, Int) -> [(Name, Int, EvalApply DAlt)]
toConsA ns (n, i)
    | Just ar <- lookup n ns = mkApplyCase n ar i
    | otherwise = []

-- | APPLY (and, where two or more arguments are missing, APPLY2) alternatives
-- for @fname@ applied to @supplied@, @supplied + 1@, ... arguments, up to but
-- not including @arity@.
mkApplyCase :: Name -> Int -> Int -> [(Name, Int, EvalApply DAlt)]
mkApplyCase fname supplied arity
    | supplied >= arity = []
    | otherwise = applyAlt : apply2Alts ++ mkApplyCase fname (supplied + 1) arity
  where
    missing = arity - supplied
    con = mkUnderCon fname missing
    held = take supplied (genArgs 0)
    -- Match the partial application and extend it by @extra@ arguments.
    extend extra newArgs
        = DConCase (-1) con held
            (DApp False (mkUnderCon fname (missing - extra)) (map DV (held ++ newArgs)))
    applyAlt = (con, supplied, ApplyCase (extend 1 [sMN 0 "arg"]))
    apply2Alts
        | missing >= 2 = [(con, supplied, Apply2Case (extend 2 [sMN 0 "arg0", sMN 0 "arg1"]))]
        | otherwise = []

-- | The EVAL function: dispatch on its argument's constructor, returning the
-- argument unchanged when no EVAL alternative matches.
mkEval :: [(Name, Int, EvalApply DAlt)] -> (Name, DDecl)
mkEval xs = (evalName, DFun evalName [arg] body)
  where
    evalName = sMN 0 "EVAL"
    arg = sMN 0 "arg"
    body = mkBigCase evalName 256 (DV arg)
               (mapMaybe evalCase xs ++ [DDefaultCase (DV arg)])
    evalCase (_, _, EvalCase mk) = Just (mk arg)
    evalCase _ = Nothing

-- | The APPLY function: add one argument to a partial application.
mkApply :: [(Name, Int, EvalApply DAlt)] -> (Name, DDecl)
mkApply xs = (applyName, DFun applyName [fn, arg] body)
  where
    applyName = sMN 0 "APPLY"
    fn = sMN 0 "fn"
    arg = sMN 0 "arg"
    body = case mapMaybe applyCase xs of
             [] -> DNothing
             cases -> mkBigCase applyName 256 (DV fn) (cases ++ [DDefaultCase DNothing])
    applyCase (_, _, ApplyCase x) = Just x
    applyCase _ = Nothing

-- | The APPLY2 function: add two arguments to a partial application, falling
-- back to two nested APPLY calls when no APPLY2 alternative matches.
mkApply2 :: [(Name, Int, EvalApply DAlt)] -> (Name, DDecl)
mkApply2 xs = (apply2Name, DFun apply2Name [fn, arg0, arg1] body)
  where
    apply2Name = sMN 0 "APPLY2"
    applyName = sMN 0 "APPLY"
    fn = sMN 0 "fn"
    arg0 = sMN 0 "arg0"
    arg1 = sMN 0 "arg1"
    fallback = DApp False applyName
                   [DApp False applyName [DV fn, DV arg0], DV arg1]
    body = case mapMaybe apply2Case xs of
             [] -> DNothing
             cases -> mkBigCase apply2Name 256 (DV fn)
                          (cases ++ [DDefaultCase fallback])
    apply2Case (_, _, Apply2Case x) = Just x
    apply2Case _ = Nothing

-- | Declare the generated constructors, assigning consecutive tags from @t@.
declare :: Int -> [(Name, Int, EvalApply DAlt)] -> [(Name, DDecl)]
declare t xs = zipWith mkDecl [t ..] xs
  where mkDecl tag (n, ar, _) = (n, DConstructor n tag ar)

-- | Infinite supply of argument names P_c starting at index @i@.
genArgs :: Int -> [Name]
genArgs i = map (\k -> sMN k "P_c") [i ..]

-- | Constructor name for an exact (lazy) application of a function.
mkFnCon :: Name -> Name
mkFnCon n = sMN 0 ("P_" ++ show n)

-- | Constructor name for a function missing @missing@ arguments; with
-- nothing missing this is the function name itself.
mkUnderCon :: Name -> Int -> Name
mkUnderCon n 0 = n
mkUnderCon n missing = sMN missing ("U_" ++ show n)

instance Show DExp where
    show = showExp
      where
        showExp (DV n) = show n
        showExp (DApp _ f args) = show f ++ "(" ++ showArgs args ++ ")"
        showExp (DLet n v e) = "let " ++ show n ++ " = " ++ showExp v ++ " in " ++ showExp e
        showExp (DUpdate n e) = "!update " ++ show n ++ "(" ++ showExp e ++ ")"
        showExp (DC loc _ n args) = atloc loc ++ "CON " ++ show n ++ "(" ++ showArgs args ++ ")"
        showExp (DProj t i) = showExp t ++ "!" ++ show i
        showExp (DCase up e alts) = "case" ++ update up ++ showExp e ++ " of {\n\t" ++ showAlts alts
        showExp (DChkCase e alts) = "case' " ++ showExp e ++ " of {\n\t" ++ showAlts alts
        showExp (DConst c) = show c
        showExp (DForeign _ n args) = "foreign " ++ show n ++ "(" ++ showArgs (map snd args) ++ ")"
        showExp (DOp f args) = show f ++ "(" ++ showArgs args ++ ")"
        showExp (DError str) = "error " ++ show str
        showExp DNothing = "____"

        showArgs = showSep ", " . map showExp
        showAlts = showSep "\n\t| " . map showAlt

        atloc Nothing = ""
        atloc (Just l) = "@" ++ show (LV l) ++ ":"

        update Shared = " "
        update Updatable = "! "

        showAlt (DConCase _ n args e) = show n ++ "(" ++ showSep ", " (map show args) ++ ") => " ++ showExp e
        showAlt (DConstCase c e) = show c ++ " => " ++ showExp e
        showAlt (DDefaultCase e) = "_ => " ++ showExp e

-- | Build the dispatch case used by EVAL, APPLY and APPLY2.
--
-- NOTE: this was meant to divide a large case expression into nested cases
-- of at most @maxBranches@ branches each (see 'groupsOf'), but splitting is
-- not implemented: all branches are emitted in a single 'DChkCase'. The
-- case-name argument is currently unused.
mkBigCase :: a -> Int -> DExp -> [DAlt] -> DExp
mkBigCase _ _maxBranches arg branches = DChkCase arg branches

-- | Split case alternatives into consecutive groups, each covering tags in
-- @[h, h + x)@ where @h@ is the tag of the group's first alternative.
-- Assumes the alternatives are sorted by ascending tag. An alternative with
-- no integer tag (a default case, or a constant other than an 'I' literal)
-- ends the current group. When such an alternative heads a group, that group
-- takes all remaining alternatives. Every group is non-empty, so this always
-- terminates.
groupsOf :: Int -> [DAlt] -> [[DAlt]]
groupsOf _ [] = []
groupsOf x (a : as) = case altTag a of
    Nothing -> [a : as]
    Just h -> let (batch, rest) = span (inWindow h) as
              in (a : batch) : groupsOf x rest
  where
    inWindow h alt = maybe False (\t -> t < h + x) (altTag alt)
    altTag (DConstCase (I i) _) = Just i
    altTag (DConCase t _ _ _) = Just t
    altTag _ = Nothing

-- | Pretty-print all defunctionalised declarations, one per line group.
dumpDefuns :: DDefs -> String
dumpDefuns ds = showSep "\n" $ map showDef (toAlist ds)
  where
    showDef (_, DFun fn args body) = show fn ++ "(" ++ showSep ", " (map show args) ++ ") = \n\t" ++ show body ++ "\n"
    showDef (_, DConstructor n t _) = "Constructor " ++ show n ++ " " ++ show t