module Text.HPaco.Optimizer
( optimize
)
where
import Text.HPaco.AST.AST
import Text.HPaco.AST.Expression
import Text.HPaco.AST.Statement
import qualified Text.HPaco.Writers.Run as Run
import Data.Variant
import Control.Monad.State
import System.IO.Unsafe (unsafePerformIO)
-- | Optimize a whole template AST.
--
-- The statement optimizer is applied to the root statement and to the body
-- of every definition in 'astDefs'; definition identifiers and every other
-- AST field are carried over unchanged.
optimize :: AST -> AST
optimize ast =
    ast { astRootStatement = optimizeStatement root
        , astDefs = [ (name, optimizeStatement body) | (name, body) <- defs ]
        }
    where
        root = astRootStatement ast
        defs = astDefs ast
-- | Simplify a single statement tree.
--
-- * Printing an integer or float literal becomes printing its 'show'n
--   string; printing the empty string literal becomes 'NullStatement'.
-- * A print's argument is constant-folded with 'optimizeExpression'; if
--   folding changed it, the resulting print is optimized again.
-- * 'SourcePositionStatement's are dropped (turned into 'NullStatement').
-- * A sequence has its children optimized, 'NullStatement's removed, nested
--   sequences spliced in, and adjacent string prints fused. An empty result
--   becomes 'NullStatement' and a one-element result is unwrapped.
-- * An 'IfStatement' whose condition folds to a boolean literal is replaced
--   by the branch it selects.
-- * Let, for and switch statements get their expressions and bodies
--   optimized; any other statement is returned unchanged.
optimizeStatement :: Statement -> Statement
optimizeStatement (PrintStatement (IntLiteral i)) = PrintStatement . StringLiteral . show $ i
optimizeStatement (PrintStatement (FloatLiteral i)) = PrintStatement . StringLiteral . show $ i
optimizeStatement (PrintStatement (StringLiteral [])) = NullStatement
optimizeStatement (PrintStatement e) =
    let e' = optimizeExpression e
    in if e == e'
        then PrintStatement e
        -- Folding may have produced a literal that one of the clauses
        -- above can simplify further.
        else optimizeStatement $ PrintStatement e'
optimizeStatement (SourcePositionStatement _ _) = NullStatement
optimizeStatement (StatementSequence xs) =
    case fusePrints . concatMap (flatten . optimizeStatement) $ xs of
        [] -> NullStatement
        [s] -> s
        ss -> StatementSequence ss
    where
        -- Splice nested sequences (at any position, not just the head) into
        -- the parent and drop no-ops, so that string prints separated only
        -- by a sequence boundary become adjacent and can be fused.
        flatten NullStatement = []
        flatten (StatementSequence ss) = concatMap flatten ss
        flatten s = [s]
optimizeStatement (IfStatement cond true false) =
    let true' = optimizeStatement true
        false' = optimizeStatement false
        cond' = optimizeExpression cond
    in case cond' of
        BooleanLiteral True -> true'
        BooleanLiteral False -> false'
        _ -> IfStatement cond' true' false'
optimizeStatement (LetStatement ident e stmt) =
    let e' = optimizeExpression e
        stmt' = optimizeStatement stmt
    in LetStatement ident e' stmt'
optimizeStatement (ForStatement iter ident e stmt) =
    let e' = optimizeExpression e
        stmt' = optimizeStatement stmt
    in ForStatement iter ident e' stmt'
optimizeStatement (SwitchStatement e branches) =
    let e' = optimizeExpression e
        branches' = map (\(k, v) -> (optimizeExpression k, optimizeStatement v)) branches
    in SwitchStatement e' branches'
optimizeStatement s = s
-- | Merge every run of adjacent @PrintStatement (StringLiteral _)@ into a
-- single print of the concatenated text. All other statements are kept in
-- their original order.
--
-- Implemented as a single right fold, so each statement is visited once.
fusePrints :: [Statement] -> [Statement]
fusePrints = foldr step []
    where
        -- The accumulated right-hand list is already fused, so only its
        -- head can still be a string print adjacent to the current one.
        step (PrintStatement (StringLiteral lhs)) (PrintStatement (StringLiteral rhs) : rest) =
            PrintStatement (StringLiteral (lhs ++ rhs)) : rest
        step s rest = s : rest
-- | Constant-fold an expression.
--
-- An expression built only from literals (see 'isConst') is evaluated
-- ahead of time via 'evaluateConstExpression'; anything else is returned
-- as-is (sub-expressions are not folded individually).
optimizeExpression :: Expression -> Expression
optimizeExpression expr
    | isConst expr = evaluateConstExpression expr
    | otherwise = expr
-- | True when an expression consists solely of literals: string, integer,
-- float and boolean literals, list / association-list literals of constant
-- elements, and escape, unary or binary expressions over constant operands.
-- Every other expression form is treated as non-constant.
isConst :: Expression -> Bool
isConst expr = case expr of
    StringLiteral _ -> True
    IntLiteral _ -> True
    FloatLiteral _ -> True
    BooleanLiteral _ -> True
    ListExpression items -> and (map isConst items)
    AListExpression pairs -> and [ isConst key && isConst val | (key, val) <- pairs ]
    EscapeExpression _ inner -> isConst inner
    UnaryExpression _ inner -> isConst inner
    BinaryExpression _ lhs rhs -> isConst lhs && isConst rhs
    _ -> False
-- | Evaluate a constant expression at optimisation time and turn the result
-- back into a literal expression.
--
-- The expression is run with the runtime interpreter ('Run.runExpression')
-- using a 'Null' scope, 'defAST' and 'Run.defaultOptions'. If the resulting
-- value has no literal form (anything other than strings, integers, doubles,
-- booleans and lists / association lists of those), the original
-- expression is returned unchanged instead of aborting the optimizer.
--
-- NOTE(review): 'unsafePerformIO' is only sound if 'Run.runExpression'
-- performs no observable IO for literal-only expressions — confirm in
-- "Text.HPaco.Writers.Run".
evaluateConstExpression :: Expression -> Expression
evaluateConstExpression e =
    if representable v then fromVariant v else e
    where
        rs = Run.RunState { Run.rsScope = Null, Run.rsAST = defAST, Run.rsOptions = Run.defaultOptions }
        v = unsafePerformIO $ evalStateT (Run.runExpression e) rs
        -- Mirrors the constructors handled by 'fromVariant', so that
        -- 'fromVariant' is only ever applied to values it can convert.
        representable (String _) = True
        representable (Integer _) = True
        representable (Double _) = True
        representable (Bool _) = True
        representable (List xs) = all representable xs
        representable (AList xs) = all (\(k, x) -> representable k && representable x) xs
        representable _ = False
-- | Convert a runtime 'Variant' back into a literal 'Expression'.
--
-- Strings, integers, doubles and booleans map to the corresponding
-- literals; lists and association lists are converted element-wise. Any
-- other variant constructor has no literal form and is rejected with an
-- explicit error rather than an anonymous pattern-match failure, so
-- callers should only pass values they know are representable.
fromVariant :: Variant -> Expression
fromVariant (String s) = StringLiteral s
fromVariant (Integer i) = IntLiteral i
fromVariant (Double d) = FloatLiteral d
fromVariant (Bool b) = BooleanLiteral b
fromVariant (List xs) = ListExpression $ map fromVariant xs
fromVariant (AList xs) = AListExpression $ map (\(k, v) -> (fromVariant k, fromVariant v)) xs
fromVariant _ = error "Text.HPaco.Optimizer.fromVariant: variant has no literal Expression form"