-- | Compile-time simplification of HPaco ASTs: constant folding of
-- literal-only expressions, elimination of statically decided @if@ branches,
-- fusion of adjacent string prints and removal of no-op statements.
module Text.HPaco.Optimizer
  ( optimize
  ) where

import Text.HPaco.AST.AST
import Text.HPaco.AST.Expression
import Text.HPaco.AST.Statement
import qualified Text.HPaco.Writers.Run as Run
import Data.Variant
import Control.Monad.State
import System.IO.Unsafe (unsafePerformIO)

-- | Optimize the root statement and every definition body of an AST.
-- All other AST fields are left untouched.
optimize :: AST -> AST
optimize ast =
  ast { astRootStatement = optimizeStatement (astRootStatement ast)
      , astDefs = map (\(i, s) -> (i, optimizeStatement s)) (astDefs ast)
      }

-- | Simplify a single statement, recursing into nested statements and
-- folding constant expressions they contain.
optimizeStatement :: Statement -> Statement
-- Numeric prints become string prints (rendered with 'show') so that they can
-- be fused with neighbouring string prints.
optimizeStatement (PrintStatement (IntLiteral i)) =
  PrintStatement . StringLiteral . show $ i
optimizeStatement (PrintStatement (FloatLiteral f)) =
  PrintStatement . StringLiteral . show $ f
-- Printing the empty string does nothing.
optimizeStatement (PrintStatement (StringLiteral [])) = NullStatement
-- Fold the printed expression; if folding changed it, re-apply the
-- print-specific rules above to the folded result.
optimizeStatement (PrintStatement e) =
  let e' = optimizeExpression e
  in if e == e'
       then PrintStatement e
       else optimizeStatement (PrintStatement e')
-- Source positions become no-ops so that enclosing sequences drop them.
optimizeStatement (SourcePositionStatement _ _) = NullStatement
optimizeStatement (StatementSequence xs) =
  case fusePrints (concatMap (splice . optimizeStatement) xs) of
    []  -> NullStatement          -- an empty sequence does nothing
    [s] -> s                      -- a one-element sequence is that element
    ss  -> StatementSequence ss
  where
    -- Children are already optimized, so any child sequence is itself flat.
    -- Inline it wherever it occurs (not only at the head) and drop no-ops, so
    -- that string prints on either side of a sequence boundary get fused.
    splice NullStatement          = []
    splice (StatementSequence ss) = ss
    splice s                      = [s]
-- A condition that folds to a boolean literal selects its branch statically.
optimizeStatement (IfStatement cond true false) =
  let cond'  = optimizeExpression cond
      true'  = optimizeStatement true
      false' = optimizeStatement false
  in case cond' of
       BooleanLiteral True  -> true'
       BooleanLiteral False -> false'
       _                    -> IfStatement cond' true' false'
optimizeStatement (LetStatement ident e stmt) =
  LetStatement ident (optimizeExpression e) (optimizeStatement stmt)
optimizeStatement (ForStatement iter ident e stmt) =
  ForStatement iter ident (optimizeExpression e) (optimizeStatement stmt)
optimizeStatement (SwitchStatement e branches) =
  SwitchStatement (optimizeExpression e)
    [ (optimizeExpression k, optimizeStatement v) | (k, v) <- branches ]
optimizeStatement s = s

-- | Merge every run of adjacent string-literal prints into a single print.
--
-- A right fold: the already-fused tail never starts with two string prints,
-- so each element only has to inspect the tail's head. This gives one linear
-- pass, and the string concatenations associate to the right.
fusePrints :: [Statement] -> [Statement]
fusePrints = foldr step []
  where
    step (PrintStatement (StringLiteral lhs))
         (PrintStatement (StringLiteral rhs) : rest) =
      PrintStatement (StringLiteral (lhs ++ rhs)) : rest
    step s rest = s : rest

-- | Replace an expression by its value when it is built purely from literals
-- and operators ('isConst'); any other expression is returned unchanged.
-- Sub-expressions of non-constant expressions are not folded individually.
optimizeExpression :: Expression -> Expression
optimizeExpression e
  | isConst e = evaluateConstExpression e
  | otherwise = e

-- | True when the expression consists only of literals, list/alist literals,
-- escapes and unary/binary operators over such expressions.
isConst :: Expression -> Bool
isConst (StringLiteral _)       = True
isConst (IntLiteral _)          = True
isConst (FloatLiteral _)        = True
isConst (BooleanLiteral _)      = True
isConst (ListExpression xs)     = all isConst xs
isConst (AListExpression xs)    = all (\(k, v) -> isConst k && isConst v) xs
isConst (EscapeExpression _ e)  = isConst e
isConst (UnaryExpression _ e)   = isConst e
isConst (BinaryExpression _ a b) = isConst a && isConst b
isConst _                       = False

-- | Evaluate a constant expression with the runtime interpreter and convert
-- the resulting value back into a literal expression. If the value has no
-- literal representation (see 'fromVariant'), the original expression is kept,
-- so it will be evaluated when the template runs instead.
--
-- 'unsafePerformIO' is needed because 'Run.runExpression' runs over IO. The
-- expression is evaluated with an empty scope ('Null') and 'defAST'.
-- NOTE(review): this assumes the interpreter performs no observable side
-- effects for expressions accepted by 'isConst'. Any exception it throws is
-- raised when the optimized AST is forced, not at template run time.
evaluateConstExpression :: Expression -> Expression
evaluateConstExpression e =
  case fromVariant result of
    Just folded -> folded
    Nothing     -> e
  where
    initialState = Run.RunState { Run.rsScope = Null
                                , Run.rsAST = defAST
                                , Run.rsOptions = Run.defaultOptions
                                }
    result = unsafePerformIO $ evalStateT (Run.runExpression e) initialState

-- | Convert a runtime value back into a literal expression. Returns 'Nothing'
-- for any value (such as 'Null') that has no literal form, including lists
-- and alists that contain such a value anywhere inside them.
fromVariant :: Variant -> Maybe Expression
fromVariant (String s)  = Just (StringLiteral s)
fromVariant (Integer i) = Just (IntLiteral i)
fromVariant (Double d)  = Just (FloatLiteral d)
fromVariant (Bool b)    = Just (BooleanLiteral b)
fromVariant (List xs)   = fmap ListExpression (mapM fromVariant xs)
fromVariant (AList xs)  = fmap AListExpression (mapM pair xs)
  where
    pair (k, v) = do
      k' <- fromVariant k
      v' <- fromVariant v
      return (k', v')
fromVariant _ = Nothing