module Text.HPaco.Optimizer
            ( optimize
            )
where

import Text.HPaco.AST.AST
import Text.HPaco.AST.Expression
import Text.HPaco.AST.Statement
import qualified Text.HPaco.Writers.Run as Run
import Data.Variant
import Control.Monad.State
import System.IO.Unsafe (unsafePerformIO)

-- | Optimize a whole template AST.
--
-- The root statement and the body of every definition in 'astDefs' are
-- passed through 'optimizeStatement'; definition identifiers and all other
-- AST fields are kept as they are.
optimize :: AST -> AST
optimize ast =
    ast { astRootStatement = optimizeStatement rootStmt
        , astDefs = [ (ident, optimizeStatement body) | (ident, body) <- defs ]
        }
    where
        rootStmt = astRootStatement ast
        defs = astDefs ast

-- | Simplify a statement tree.
--
-- * Printed expressions are constant-folded; printed integer and float
--   literals are rendered to string literals, and printing an empty string
--   literal becomes 'NullStatement'.
-- * Source position markers become 'NullStatement'.
-- * Statement sequences are flattened (nested sequences at any position),
--   cleared of 'NullStatement's, have adjacent string-literal prints fused,
--   and collapse to their only element (or to 'NullStatement' when empty).
-- * An 'IfStatement' whose condition folds to a boolean literal is replaced
--   by the branch it selects.
-- * Let, for and switch statements are optimized recursively.
--
-- Any other statement is returned unchanged.
optimizeStatement :: Statement -> Statement

-- Reduce constants
optimizeStatement (PrintStatement (IntLiteral i)) = PrintStatement . StringLiteral . show $ i
optimizeStatement (PrintStatement (FloatLiteral f)) = PrintStatement . StringLiteral . show $ f
optimizeStatement (PrintStatement (StringLiteral [])) = NullStatement
optimizeStatement (PrintStatement e) =
    -- Fold once and hand literal results back to the equations above. The
    -- folded expression is deliberately not re-folded until it stops
    -- changing: a value that compares unequal to itself (e.g. a list holding
    -- a NaN float) would make such a loop run forever.
    case optimizeExpression e of
        e'@(IntLiteral _) -> optimizeStatement (PrintStatement e')
        e'@(FloatLiteral _) -> optimizeStatement (PrintStatement e')
        StringLiteral [] -> NullStatement
        e' -> PrintStatement e'

-- Turn SourcePositionStatements into Null statements so that they are removed
optimizeStatement (SourcePositionStatement _ _) = NullStatement

optimizeStatement (StatementSequence xs) =
    let xs' = fusePrints . filter (/= NullStatement) $ concatMap (flatten . optimizeStatement) xs
    in case xs' of
        -- remove empty statement sequences
        [] -> NullStatement
        -- reduce statement sequences with only one statement in them down
        -- to that single statement
        [s] -> s
        _ -> StatementSequence xs'
    where
        -- Splice nested statement sequences in wherever they appear (not
        -- only at the head), so prints on both sides of the seam can be
        -- fused. Nested sequences have already been optimized by the time
        -- they get here, so a single level of splicing leaves the result flat.
        flatten (StatementSequence ss) = ss
        flatten s = [s]

optimizeStatement (IfStatement cond true false) =
    let true' = optimizeStatement true
        false' = optimizeStatement false
        cond' = optimizeExpression cond
    in case cond' of
            BooleanLiteral True -> true'
            BooleanLiteral False -> false'
            _ -> IfStatement cond' true' false'

optimizeStatement (LetStatement ident e stmt) =
    LetStatement ident (optimizeExpression e) (optimizeStatement stmt)

optimizeStatement (ForStatement iter ident e stmt) =
    ForStatement iter ident (optimizeExpression e) (optimizeStatement stmt)

optimizeStatement (SwitchStatement e branches) =
    SwitchStatement (optimizeExpression e)
        [ (optimizeExpression k, optimizeStatement v) | (k, v) <- branches ]

optimizeStatement s = s


-- | Merge every run of consecutive @PrintStatement (StringLiteral _)@
-- statements into a single print of their concatenated text. All other
-- statements, and the relative order of everything, are preserved.
--
-- Each run is collected in one pass and concatenated once, so the input list
-- is traversed a single time (re-fusing an already-fused tail after every
-- merge made long sequences quadratic).
fusePrints :: [Statement] -> [Statement]
fusePrints (PrintStatement (StringLiteral first) : rest) =
    PrintStatement (StringLiteral (concat (first : more))) : fusePrints rest'
    where
        (more, rest') = literalRun rest
        -- Split off the texts of the leading string-literal prints.
        literalRun (PrintStatement (StringLiteral s) : ss) =
            let (texts, remaining) = literalRun ss
            in (s : texts, remaining)
        literalRun ss = ([], ss)
fusePrints (s : rest) = s : fusePrints rest
fusePrints [] = []

-- | Fold an expression to a literal when it is built entirely from constants
-- (see 'isConst'); any other expression is returned as-is. Only whole
-- expressions are folded: constant sub-expressions of a non-constant
-- expression are left untouched.
optimizeExpression :: Expression -> Expression
optimizeExpression expr
    | isConst expr = evaluateConstExpression expr
    | otherwise = expr

-- | True when the expression consists only of string, int, float and boolean
-- literals, possibly combined by list, association-list, escape, unary and
-- binary expressions. Such expressions are folded ahead of time by
-- 'optimizeExpression'; every other expression form counts as non-constant.
isConst :: Expression -> Bool
isConst expr = case expr of
    StringLiteral _ -> True
    IntLiteral _ -> True
    FloatLiteral _ -> True
    BooleanLiteral _ -> True
    ListExpression items -> all isConst items
    AListExpression pairs -> and [ isConst k && isConst v | (k, v) <- pairs ]
    EscapeExpression _ inner -> isConst inner
    UnaryExpression _ operand -> isConst operand
    BinaryExpression _ lhs rhs -> isConst lhs && isConst rhs
    _ -> False

-- | Evaluate a constant expression at optimization time and turn the result
-- back into a literal expression.
--
-- The expression is run through 'Run.runExpression' with a 'Null' scope,
-- 'defAST' and 'Run.defaultOptions'. If the resulting value has no literal
-- 'Expression' form (for instance 'Null'), the original expression is
-- returned unchanged so it is left for run time instead of crashing the
-- optimizer in 'fromVariant'.
--
-- 'unsafePerformIO' is needed because the interpreter runs in a 'StateT'
-- over 'IO'. NOTE(review): this assumes evaluating a constant expression
-- performs no observable I/O — confirm against Text.HPaco.Writers.Run.
-- If the interpreter throws, the exception surfaces wherever this pure
-- result is forced.
evaluateConstExpression :: Expression -> Expression
evaluateConstExpression e
    | hasLiteralForm value = fromVariant value
    | otherwise = e
    where
        rs = Run.RunState
                { Run.rsScope = Null
                , Run.rsAST = defAST
                , Run.rsOptions = Run.defaultOptions
                }
        value = unsafePerformIO $ evalStateT (Run.runExpression e) rs
        -- Mirrors exactly the constructors 'fromVariant' can convert.
        hasLiteralForm (String _) = True
        hasLiteralForm (Integer _) = True
        hasLiteralForm (Double _) = True
        hasLiteralForm (Bool _) = True
        hasLiteralForm (List xs) = all hasLiteralForm xs
        hasLiteralForm (AList xs) = all (\(k, v) -> hasLiteralForm k && hasLiteralForm v) xs
        hasLiteralForm _ = False

-- | Convert an interpreter value back into the equivalent literal
-- expression; lists and association lists are converted element-wise.
--
-- Only string, integer, double, boolean, list and association-list values
-- have a literal form. Any other value (e.g. 'Null') is rejected with a
-- descriptive error instead of an anonymous pattern-match failure.
fromVariant :: Variant -> Expression
fromVariant (String s) = StringLiteral s
fromVariant (Integer i) = IntLiteral i
fromVariant (Double d) = FloatLiteral d
fromVariant (Bool b) = BooleanLiteral b
fromVariant (List xs) = ListExpression $ map fromVariant xs
fromVariant (AList xs) = AListExpression [ (fromVariant k, fromVariant v) | (k, v) <- xs ]
fromVariant _ =
    error "Text.HPaco.Optimizer.fromVariant: value has no literal Expression form"