{-#LANGUAGE TupleSections #-}
-- | Static optimizations over an HPaco 'AST': constant folding of
-- literal-only expressions, pre-rendering and fusion of literal prints,
-- selection of the taken branch of an 'IfStatement' with a constant
-- condition, flattening of nested statement sequences, and inlining of
-- calls to trivial definitions.
module Text.HPaco.Optimizer
-- ( optimize
-- )
where

import Text.HPaco.AST.AST
import Text.HPaco.AST.Expression
import Text.HPaco.AST.Statement
import qualified Text.HPaco.Writers.Run as Run
import Data.Variant hiding (lookup, elem)
import Control.Monad.State
import Control.Applicative
import System.IO.Unsafe (unsafePerformIO)
import Text.HPaco.AST.Identifier (Identifier)
import Data.Maybe
import Control.Arrow ( (***) )
import qualified Control.Arrow as Arrow

-- | Run the whole optimization pipeline over a template AST.
--
-- Definitions are optimized first, then the root statement (and the
-- definitions once more), and finally calls to trivial definitions are
-- inlined by 'expandDefs'.
optimize :: AST -> AST
optimize = expandDefs . optimizeASTStatements . optimizeASTDefs

-- | Statically expand macros.
--
-- A 'CallStatement' is replaced by the body of the definition it names when
-- that body, following chains of calls, is a 'PrintStatement',
-- 'NullStatement' or 'SourcePositionStatement'. Only the root statement and
-- the top-level statement of each definition are inspected; calls nested
-- inside compound statements are left as they are.
--
-- A call to a definition that is already being expanded is kept as a call,
-- which stops the expansion of (mutually) recursive definitions.
expandDefs :: AST -> AST
expandDefs ast =
    ast { astRootStatement = expand [] (astRootStatement ast)
        , astDefs = [ (name, expand [name] body) | (name, body) <- astDefs ast ]
        }
    where
        -- Keep the statement untouched when no expansion applies.
        expand :: [Identifier] -> Statement -> Statement
        expand visited stmt = fromMaybe stmt (expandRaw visited stmt)

        -- 'Nothing' means "not expandable"; 'visited' is the chain of
        -- definitions entered so far.
        expandRaw :: [Identifier] -> Statement -> Maybe Statement
        expandRaw visited stmt@(CallStatement ident)
            | ident `elem` visited = Just stmt
            | otherwise = lookup ident (astDefs ast) >>= expandRaw (ident : visited)
        expandRaw _ stmt@(PrintStatement _) = Just stmt
        expandRaw _ NullStatement = Just NullStatement
        expandRaw _ stmt@(SourcePositionStatement _ _) = Just stmt
        expandRaw _ _ = Nothing

-- | Optimize the body of every definition; all other fields of the AST are
-- preserved unchanged.
optimizeASTDefs :: AST -> AST
optimizeASTDefs ast =
    ast { astDefs = [ (name, optimizeStatement body) | (name, body) <- astDefs ast ] }

-- | Optimize the root statement and the body of every definition.
optimizeASTStatements :: AST -> AST
optimizeASTStatements ast =
    ast { astRootStatement = optimizeStatement (astRootStatement ast)
        , astDefs = map (Arrow.second optimizeStatement) (astDefs ast)
        }

-- | Optimize a single statement (recursively, for compound statements).
optimizeStatement :: Statement -> Statement
-- Fold the printed expression once, then reduce the literal result:
-- numbers are pre-rendered with 'show' and printing "" becomes a no-op.
-- (Folding only once also avoids looping forever on values that are not
-- equal to themselves, such as a NaN inside a list.)
-- NOTE(review): pre-rendering with 'show' assumes the runtime prints numbers
-- the same way -- this matches the previous behaviour; confirm in the writers.
optimizeStatement (PrintStatement e) = printOf (optimizeExpression e)
    where
        printOf (IntLiteral i) = PrintStatement . StringLiteral . show $ i
        printOf (FloatLiteral f) = PrintStatement . StringLiteral . show $ f
        printOf (StringLiteral []) = NullStatement
        printOf other = PrintStatement other
-- Source positions become null statements so that sequences drop them.
optimizeStatement (SourcePositionStatement _ _) = NullStatement
optimizeStatement (StatementSequence stmts) =
    let -- Children are already optimized, hence already flat, so splicing
        -- one level deep flattens nested sequences at any position.
        splice (StatementSequence inner) = inner
        splice other = [other]
        flattened = [ s | child <- map optimizeStatement stmts, s <- splice child ]
    in case fusePrints (filter (/= NullStatement) flattened) of
        -- An empty sequence does nothing.
        [] -> NullStatement
        -- A one-element sequence is just that element.
        [single] -> single
        many -> StatementSequence many
optimizeStatement (IfStatement cond true false) =
    case optimizeExpression cond of
        -- A constant condition selects its branch at compile time.
        BooleanLiteral True -> optimizeStatement true
        BooleanLiteral False -> optimizeStatement false
        cond' -> IfStatement cond' (optimizeStatement true) (optimizeStatement false)
optimizeStatement (LetStatement ident e body) =
    LetStatement ident (optimizeExpression e) (optimizeStatement body)
optimizeStatement (ForStatement iter ident e body) =
    ForStatement iter ident (optimizeExpression e) (optimizeStatement body)
optimizeStatement (SwitchStatement e branches) =
    SwitchStatement (optimizeExpression e) (map (optimizeExpression *** optimizeStatement) branches)
optimizeStatement s = s

-- | Merge every run of adjacent literal prints into a single print of the
-- concatenated text, in one linear pass. Other statements pass through
-- unchanged and in order.
fusePrints :: [Statement] -> [Statement]
fusePrints (PrintStatement (StringLiteral first) : rest) =
    let (more, rest') = literalRun rest
    in PrintStatement (StringLiteral (first ++ more)) : fusePrints rest'
    where
        -- Split off the leading literal prints, returning their concatenated
        -- text and the remaining statements.
        literalRun (PrintStatement (StringLiteral s) : ss) =
            let (text, remaining) = literalRun ss in (s ++ text, remaining)
        literalRun ss = ([], ss)
fusePrints (s : rest) = s : fusePrints rest
fusePrints [] = []

-- | Constant-fold an expression: if 'isConst' holds it is evaluated at
-- compile time, otherwise it is returned unchanged. Subexpressions of a
-- non-constant expression are not folded.
optimizeExpression :: Expression -> Expression
optimizeExpression e
    | isConst e = evaluateConstExpression e
    | otherwise = e

-- | True when an expression is built exclusively from literals, possibly
-- nested inside lists, association lists, escapes and unary/binary
-- operators.
isConst :: Expression -> Bool
isConst (StringLiteral _) = True
isConst (IntLiteral _) = True
isConst (FloatLiteral _) = True
isConst (BooleanLiteral _) = True
isConst (ListExpression xs) = all isConst xs
isConst (AListExpression xs) = all (\(k, v) -> isConst k && isConst v) xs
isConst (EscapeExpression _ e) = isConst e
isConst (UnaryExpression _ e) = isConst e
isConst (BinaryExpression _ a b) = isConst a && isConst b
isConst _ = False

-- | Evaluate an expression with the runtime interpreter ('Run.runExpression')
-- using a 'Null' scope and 'defAST', and convert the result back into a
-- literal expression.
--
-- If the result has no literal representation (see 'fromVariantMaybe'), the
-- original expression is returned unchanged so it is evaluated at runtime
-- instead of crashing the optimizer.
evaluateConstExpression :: Expression -> Expression
evaluateConstExpression e = fromMaybe e (fromVariantMaybe value)
    where
        interpState = Run.RunState
            { Run.rsScope = Null
            , Run.rsAST = defAST
            , Run.rsOptions = Run.defaultOptions
            }
        -- NOTE(review): unsafePerformIO runs the interpreter from pure code.
        -- This is only sound if 'Run.runExpression' has no observable side
        -- effects for the literal-only expressions passed in by
        -- 'optimizeExpression' -- confirm against Text.HPaco.Writers.Run.
        value = unsafePerformIO $ evalStateT (Run.runExpression e) interpState

-- | Convert a runtime value back into an equivalent literal expression, or
-- 'Nothing' when the value (or any element of it) has no literal form.
fromVariantMaybe :: Variant -> Maybe Expression
fromVariantMaybe (String s) = Just (StringLiteral s)
fromVariantMaybe (Integer i) = Just (IntLiteral i)
fromVariantMaybe (Double d) = Just (FloatLiteral d)
fromVariantMaybe (Bool b) = Just (BooleanLiteral b)
fromVariantMaybe (List xs) = ListExpression <$> mapM fromVariantMaybe xs
fromVariantMaybe (AList xs) =
    AListExpression <$> mapM (\(k, v) -> (,) <$> fromVariantMaybe k <*> fromVariantMaybe v) xs
fromVariantMaybe _ = Nothing

-- | Convert a runtime value back into an equivalent literal expression.
-- Fails with an explicit error for values without a literal form; prefer
-- 'fromVariantMaybe' when that can happen.
fromVariant :: Variant -> Expression
fromVariant v =
    case fromVariantMaybe v of
        Just expr -> expr
        Nothing -> error "Text.HPaco.Optimizer.fromVariant: value has no literal representation"