{-# language OverloadedStrings #-} {-# language DataKinds #-} {-# language BangPatterns #-} module OptimizeTailRecursion where import Control.Applicative ((<|>)) import Control.Lens.Cons (_last, _init) import Control.Lens.Fold ((^..), (^?), (^?!), allOf, anyOf, folded, foldrOf) import Control.Lens.Getter ((^.), to) import Control.Lens.Plated (cosmos, transform, transformOn) import Control.Lens.Prism (_Just) import Control.Lens.Review ((#)) import Control.Lens.Setter ((%~), (.~)) import Control.Lens.Tuple (_2, _3) import Data.Foldable (toList) import Data.Function ((&)) import Data.Semigroup ((<>)) import Language.Python.Optics import Language.Python.DSL import Language.Python.Syntax.Expr (Expr (..), _Exprs, argExpr, paramName) import Language.Python.Syntax.Statement (CompoundStatement (..), Statement (..), SmallStatement (..), SimpleStatement (..), _Statements) optimizeTailRecursion :: Raw Statement -> Maybe (Raw Statement) optimizeTailRecursion st = do function <- st ^? _Fundef let functionBody = function ^. body_ bodyLast <- lastStatement functionBody let functionName = function ^. fdName.identValue bodyInit = functionBody ^?! _init paramNames = function ^.. fdParameters.folded.paramName.identValue if not $ hasTC functionName bodyLast then Nothing else Just $ _Fundef # (function & body_ .~ (zipWith (\a b -> line_ (var_ (a <> "__tr") .= var_ b)) paramNames paramNames <> [ line_ ("__res__tr" .= none_) , line_ . while_ true_ . transformOn (traverse._Exprs) (renameIn paramNames "__tr") $ bodyInit <> looped functionName paramNames bodyLast , line_ $ return_ "__res__tr" ])) where lastStatement :: [Raw Line] -> Maybe (Raw Statement) lastStatement = go Nothing where go !res [] = res go !res (a:as) = go (a ^? _Statements <|> res) as isTailCall :: String -> Raw Expr -> Bool isTailCall name e | anyOf (cosmos._Call.callFunction._Ident.identValue) (== name) e = (e ^? _Call.callFunction._Ident.identValue) == Just name | otherwise = False hasTC :: String -> Raw Statement -> Bool hasTC name st = case st of CompoundStatement (If _ _ _ _ sts [] sts') -> allOf _last (hasTC name) (sts ^.. _Statements) || allOf _last (hasTC name) (sts' ^.. _Just._3._Statements) SmallStatement _ (MkSmallStatement s ss _ _ _) -> case last (s : fmap (^. _2) ss) of Return _ _ (Just e) -> isTailCall name e -- Return _ _ Nothing -> True Expr _ e -> isTailCall name e _ -> False _ -> False renameIn :: [String] -> String -> Raw Expr -> Raw Expr renameIn params suffix = transform (_Ident.identValue %~ (\a -> if a `elem` params then a <> suffix else a)) looped :: String -> [String] -> Raw Statement -> [Raw Line] looped name params st | Just ifSt <- st ^? _If , hasTC name st = let ifBodyLines = toList $ ifSt ^. body_ in case ifSt ^? to getElse._Just.body_ of Nothing -> [ line_ $ if_ (ifSt ^. ifCond) ((ifBodyLines ^?! _init) <> looped name params (ifBodyLines ^?! _last._Statements)) ] Just sts'' -> [ line_ $ if_ (ifSt ^. ifCond) ((ifSt ^?! body_.to toList._init) <> looped name params (ifBodyLines ^?! _last._Statements)) & else_ ((toList sts'' ^?! _init) <> looped name params (toList sts'' ^?! _last._Statements)) ] | otherwise = case st of CompoundStatement{} -> [line_ st] SmallStatement idnts (MkSmallStatement s ss sc cmt nl) -> let initExps = foldr (\_ _ -> init ss) [] ss lastExp = foldrOf (folded._2) (\_ _ -> last ss ^. _2) s ss newSts = case initExps of [] -> [] first : rest -> [ line_ $ SmallStatement idnts (MkSmallStatement (first ^. _2) rest sc cmt nl) ] in case lastExp of Return _ _ e -> case e ^? _Just._Call of Just call | Just name' <- call ^? callFunction._Ident.identValue , name' == name -> newSts <> fmap (\a -> line_ (var_ (a <> "__tr__old") .= var_ (a <> "__tr"))) params <> zipWith (\a b -> line_ (var_ (a <> "__tr") .= b)) params (transformOn traverse (renameIn params "__tr__old") (call ^.. callArguments.folded.folded.argExpr)) _ -> newSts <> maybe [] (\e' -> [ line_ ("__res__tr" .= e') ]) e <> [ line_ break_ ] Expr _ e | isTailCall name e -> newSts <> [line_ pass_] _ -> [line_ st]