{-#LANGUAGE TupleSections #-}
module Text.HPaco.Optimizer
--            ( optimize
--            )
where

import Text.HPaco.AST.AST
import Text.HPaco.AST.Expression
import Text.HPaco.AST.Statement
import qualified Text.HPaco.Writers.Run as Run
import Data.Variant hiding (lookup, elem)
import Control.Monad.State
import Control.Applicative
import System.IO.Unsafe (unsafePerformIO)
import Text.HPaco.AST.Identifier (Identifier)
import Data.Maybe
import Control.Arrow ( (***) )
import qualified Control.Arrow as Arrow

-- | Run all optimizer passes over an AST.
--
-- Pass order: def bodies are optimized first ('optimizeASTDefs'). Then the
-- root statement and the def bodies are optimized again
-- ('optimizeASTStatements', which also maps over the defs). Finally, calls to
-- trivial defs are inlined ('expandDefs').
optimize :: AST -> AST
optimize ast =
    let withDefsOptimized = optimizeASTDefs ast
        withStatementsOptimized = optimizeASTStatements withDefsOptimized
    in expandDefs withStatementsOptimized

-- Statically expand macros
--

-- | Inline calls to defs whose bodies are trivial.
--
-- Expansion is considered only for the root statement and for each def body,
-- and only when that statement is itself a 'CallStatement'. Calls nested
-- inside sequences, conditionals, etc. are left untouched.
--
-- A call is replaced when following the chain of called defs ends in a
-- 'PrintStatement', 'NullStatement' or 'SourcePositionStatement'. Calls to
-- unknown defs, and chains ending in any other statement, leave the original
-- statement unchanged. If the chain loops back to a def already visited, the
-- result is a call to that def, which stops recursive defs from expanding
-- forever.
expandDefs :: AST -> AST
expandDefs ast =
    ast { astRootStatement = expand [] (astRootStatement ast)
        , astDefs = map expandDef defs
        }
    where
        defs = astDefs ast

        -- A def body starts with its own name on the path, so a def that
        -- (indirectly) calls itself is detected.
        expandDef (name, body) = (name, expand [name] body)

        -- Keep the statement as-is unless it can be inlined.
        expand :: [Identifier] -> Statement -> Statement
        expand seen stmt = fromMaybe stmt (inline seen stmt)

        inline :: [Identifier] -> Statement -> Maybe Statement
        inline seen stmt@(CallStatement name)
            | name `elem` seen = Just stmt
            | otherwise = lookup name defs >>= inline (name : seen)
        inline _ stmt@(PrintStatement _) = Just stmt
        inline _ stmt@NullStatement = Just stmt
        inline _ stmt@(SourcePositionStatement _ _) = Just stmt
        inline _ _ = Nothing

-- | Optimize the body of every def. The root statement, 'astDeps', and any
-- other field of the 'AST' are carried over unchanged.
--
-- Record update is used instead of building a fresh 'AST' value. Constructing
-- a new record with only some fields named would leave every other field
-- undefined.
optimizeASTDefs :: AST -> AST
optimizeASTDefs ast =
    ast { astDefs = [ (i, optimizeStatement s) | (i, s) <- astDefs ast ] }

-- | Optimize the root statement and every def body with 'optimizeStatement'.
-- All other AST fields are kept as they are.
optimizeASTStatements :: AST -> AST
optimizeASTStatements ast =
    ast { astRootStatement = optimizeStatement (astRootStatement ast)
        , astDefs = [ (name, optimizeStatement body) | (name, body) <- astDefs ast ]
        }

-- | Simplify a statement tree: fold constant expressions, drop no-ops,
-- flatten and fuse statement sequences, and resolve conditionals whose
-- condition folds to a boolean literal.
optimizeStatement :: Statement -> Statement

-- Reduce constants: numeric literals are printed via their 'show' form, and
-- printing the empty string does nothing.
optimizeStatement (PrintStatement (IntLiteral i)) = PrintStatement . StringLiteral . show $ i
optimizeStatement (PrintStatement (FloatLiteral i)) = PrintStatement . StringLiteral . show $ i
optimizeStatement (PrintStatement (StringLiteral [])) = NullStatement
-- Fold the printed expression. If folding changed it, optimize again so the
-- literal rules above can apply to the folded result.
optimizeStatement (PrintStatement e) =
    let e' = optimizeExpression e
    in if e == e'
            then PrintStatement e
            else optimizeStatement $ PrintStatement e'

-- Turn SourcePositionStatements into Null statements so that they are removed
optimizeStatement (SourcePositionStatement _ _) = NullStatement

-- Optimize each child, then splice nested sequences in place (at any
-- position, not only the first) and drop no-ops. Finally, merge adjacent
-- literal prints. A sequence that ends up empty becomes a NullStatement, and
-- a sequence holding a single statement is replaced by that statement.
optimizeStatement (StatementSequence xs) =
    case fusePrints (concatMap (splice . optimizeStatement) xs) of
        [] -> NullStatement
        [s] -> s
        xs' -> StatementSequence xs'
    where
        -- Children are already optimized. Any sequence among them was
        -- produced by this clause, so it is already flat and free of
        -- NullStatements. One level of splicing is therefore enough.
        splice (StatementSequence ss) = ss
        splice NullStatement = []
        splice s = [s]

-- Pick the branch statically when the condition folds to a boolean literal.
optimizeStatement (IfStatement cond true false) =
    let true' = optimizeStatement true
        false' = optimizeStatement false
        cond' = optimizeExpression cond
    in case cond' of
            BooleanLiteral True -> true'
            BooleanLiteral False -> false'
            _ -> IfStatement cond' true' false'

optimizeStatement (LetStatement ident e stmt) =
    let e' = optimizeExpression e
        stmt' = optimizeStatement stmt
    in LetStatement ident e' stmt'

optimizeStatement (ForStatement iter ident e stmt) =
    let e' = optimizeExpression e
        stmt' = optimizeStatement stmt
    in ForStatement iter ident e' stmt'

optimizeStatement (SwitchStatement e branches) =
    let e' = optimizeExpression e
        branches' = map (optimizeExpression *** optimizeStatement) branches
    in SwitchStatement e' branches'

-- Any other statement is returned unchanged.
optimizeStatement s = s


-- | Merge every run of adjacent @PrintStatement (StringLiteral ...)@ nodes
-- into one print of the concatenated string. All other statements pass
-- through unchanged and in their original order.
--
-- This is a single right-to-left pass. The tail is fused first; if it then
-- starts with a literal print, the head's text is prepended to it. Each
-- element is visited once, so the work is linear in the list length. The
-- previous formulation re-wrapped the tail in another fusePrints call on
-- every merge, which could make traversal quadratic.
fusePrints :: [Statement] -> [Statement]
fusePrints (PrintStatement (StringLiteral lhs) : rest) =
    case fusePrints rest of
        PrintStatement (StringLiteral rhs) : rest' ->
            PrintStatement (StringLiteral (lhs ++ rhs)) : rest'
        rest' -> PrintStatement (StringLiteral lhs) : rest'
fusePrints (s : rest) = s : fusePrints rest
fusePrints [] = []

-- | Fold an expression to a literal when the whole expression is constant
-- (see 'isConst'). Non-constant expressions are returned unchanged; their
-- constant sub-expressions are not folded.
optimizeExpression :: Expression -> Expression
optimizeExpression e
    | isConst e = evaluateConstExpression e
    | otherwise = e

-- | True when an expression is built only from literals and from
-- list / association-list / escape / unary / binary nodes over literals.
-- Any other expression form is treated as non-constant.
isConst :: Expression -> Bool
isConst expr = case expr of
    StringLiteral _ -> True
    IntLiteral _ -> True
    FloatLiteral _ -> True
    BooleanLiteral _ -> True
    ListExpression items -> all isConst items
    AListExpression pairs -> all (\(k, v) -> isConst k && isConst v) pairs
    EscapeExpression _ inner -> isConst inner
    UnaryExpression _ operand -> isConst operand
    BinaryExpression _ lhs rhs -> isConst lhs && isConst rhs
    _ -> False

-- | Evaluate an expression at compile time and return the result as a
-- literal. Evaluation runs the interpreter ('Run.runExpression') with a 'Null'
-- scope, 'defAST', and 'Run.defaultOptions'.
--
-- NOTE(review): 'unsafePerformIO' is used on the assumption that evaluating
-- an expression accepted by 'isConst' performs no observable I/O. Confirm
-- this against Run.runExpression.
--
-- 'fromVariant' can only rebuild strings, integers, doubles, booleans, and
-- (association) lists of those. If the result is any other value (e.g.
-- 'Null'), the original expression is returned unchanged and left for
-- runtime evaluation, rather than aborting compilation with a pattern-match
-- failure.
evaluateConstExpression :: Expression -> Expression
evaluateConstExpression e =
    if hasLiteralForm v then fromVariant v else e
    where
        rs = Run.RunState { Run.rsScope = Null, Run.rsAST = defAST, Run.rsOptions = Run.defaultOptions }
        v = unsafePerformIO $ evalStateT (Run.runExpression e) rs

        -- Mirrors the constructors 'fromVariant' converts; keep the two in sync.
        hasLiteralForm (String _) = True
        hasLiteralForm (Integer _) = True
        hasLiteralForm (Double _) = True
        hasLiteralForm (Bool _) = True
        hasLiteralForm (List xs) = all hasLiteralForm xs
        hasLiteralForm (AList xs) = all (\(k, w) -> hasLiteralForm k && hasLiteralForm w) xs
        hasLiteralForm _ = False

-- | Convert an interpreter value back into a literal 'Expression'.
--
-- Only strings, integers, doubles, booleans, and (association) lists of
-- those have a literal form. Any other variant (presumably including 'Null')
-- has no 'Expression' counterpart here. Such values are reported with a
-- descriptive error instead of a bare non-exhaustive-pattern failure.
fromVariant :: Variant -> Expression
fromVariant (String s) = StringLiteral s
fromVariant (Integer i) = IntLiteral i
fromVariant (Double d) = FloatLiteral d
fromVariant (Bool b) = BooleanLiteral b
fromVariant (List xs) = ListExpression $ map fromVariant xs
fromVariant (AList xs) = AListExpression $ map (fromVariant *** fromVariant) xs
fromVariant _ = error "Text.HPaco.Optimizer.fromVariant: value has no literal Expression form"