{-#LANGUAGE ScopedTypeVariables #-}
module Text.HPaco.Writers.Run
        ( run, runWith
        , RunState (..)
        , RunOptions (..)
        , defaultOptions
        , runAST
        , runStatement
        , runExpression
        )
where

import Prelude hiding (toInteger)
import Data.Variant
import Data.Variant.ToFrom
import qualified Data.Variant as V
import Data.Maybe
import Data.Monoid
import qualified Data.List as List
import qualified Data.List.Split as Split
import Control.Monad.State
import Safe
import Text.HPaco.Writers.Run.Encode
import Text.HPaco.Writers.Run.Library
import Text.HPaco.AST
import Text.HPaco.AST.AST
import Text.HPaco.AST.Statement
import Text.HPaco.AST.Expression
import System.IO

-- | Caller-supplied settings for a template run.
data RunOptions =
    RunOptions
        { roTemplateName :: String -- ^ Name identifying the template.
        , roOutput :: Handle       -- ^ Handle that printed template output is written to.
        }

-- | Default run options: template name @"unnamed"@, output written to
-- 'stdout'. Override individual fields with record update syntax, e.g.
-- @defaultOptions { roOutput = h }@.
defaultOptions :: RunOptions
defaultOptions = RunOptions
    { roTemplateName = "unnamed"
    , roOutput = stdout
    }

-- | Interpreter state threaded through the 'Run' monad.
data RunState =
    RunState
        { rsScope :: Variant      -- ^ Variables currently visible to the template.
        , rsOptions :: RunOptions -- ^ Options the run was started with.
        , rsAST :: AST            -- ^ The AST being run; its definitions resolve 'CallStatement' targets.
        }

-- | The interpreter monad: 'RunState' threaded over 'IO'.
type Run = StateT RunState IO

-- | Run a template AST starting from an empty scope (an empty 'AList').
run :: RunOptions -> AST -> IO ()
run opts ast = runWith (AList []) opts ast

-- | Run a template AST with a caller-supplied initial scope.
--
-- The scope value is converted with 'toVariant' and becomes the initial
-- 'rsScope'. Output goes to the handle in 'roOutput'. The interpreter's final
-- state is not needed, so 'evalStateT' is used rather than computing and
-- discarding it with 'execStateT'.
runWith :: ToVariant a => a -> RunOptions -> AST -> IO ()
runWith scope opts ast =
    evalStateT (runAST ast) initialState
    where
        initialState = RunState
            { rsScope = toVariant scope
            , rsOptions = opts
            , rsAST = ast
            }

-- | Resolve a template variable against the current scope.
--
-- The special name @"."@ denotes the entire current scope. Any other name is
-- looked up as a 'String' key in the scope with 'V.lookup'.
getVar :: String -> Run Variant
getVar name
    | name == "." = gets rsScope
    | otherwise = gets (V.lookup (String name) . rsScope)

-- | Execute a whole template by running its root statement.
runAST :: AST -> Run ()
runAST = runStatement . astRootStatement

-- Statements

-- | Execute one template statement in the current scope.
--
-- * 'PrintStatement' flattens the expression's value to a string and writes
--   it with 'hPutStr' to the 'roOutput' handle of the run options.
-- * 'CallStatement' runs the named definition from the AST; an unknown name
--   behaves like 'NullStatement'.
-- * 'SwitchStatement' evaluates every branch test first, then runs only the
--   first branch whose test loosely equals ('~==') the subject. If no test
--   matches, nothing runs.
runStatement :: Statement -> Run ()
runStatement (PrintStatement expr) = do
    value <- runExpression expr
    out <- gets (roOutput . rsOptions)
    liftIO (hPutStr out (flatten value))
runStatement (StatementSequence stmts) =
    mapM_ runStatement stmts
runStatement (IfStatement cond thenBranch elseBranch) = do
    truthy <- fmap toBool (runExpression cond)
    runStatement (if truthy then thenBranch else elseBranch)
runStatement (LetStatement ident expr body) = do
    value <- runExpression expr
    withIdentifiedScope ident value (runStatement body)
runStatement (ForStatement Nothing ident expr body) = do
    collection <- runExpression expr
    let visit item = withIdentifiedScope ident item (runStatement body)
    sequence_ (vmap visit collection)
runStatement (ForStatement (Just iter) ident expr body) = do
    collection <- runExpression expr
    -- 'vamap' hands over pairs: the first component is bound under 'iter',
    -- the second under 'ident'.
    let visit (first, second) =
            withIdentifiedScope iter first (withIdentifiedScope ident second (runStatement body))
    sequence_ (vamap visit collection)
runStatement (SwitchStatement expr branches) = do
    subject <- runExpression expr
    tests <- mapM (runExpression . fst) branches
    let matches = [ body | (test, (_, body)) <- zip tests branches, subject ~== test ]
    maybe (return ()) runStatement (headMay matches)
runStatement NullStatement = return ()
runStatement (CallStatement identifier) = do
    defs <- gets (astDefs . rsAST)
    runStatement (fromMaybe NullStatement (List.lookup identifier defs))
runStatement SourcePositionStatement {} = return ()

-- Scope helpers

-- | Run an action with the scope replaced entirely by the given value; the
-- parent scope is not inherited. Once the action finishes, the previous
-- scope is put back. Only 'rsScope' is saved and restored; other state
-- fields are left as the action leaves them.
withScope :: Variant -> Run a -> Run a
withScope scope inner = do
    saved <- gets rsScope
    setScope scope
    result <- inner
    setScope saved
    return result
    where
        setScope v = modify (\s -> s { rsScope = v })

-- | Run an action in a scope formed by merging the given local scope with
-- the current (parent) scope via 'V.scopeMerge'. The local scope is intended
-- to take precedence over the inherited one.
withInheritingScope :: Variant -> Run a -> Run a
withInheritingScope scope inner =
    gets rsScope >>= \parent -> withScope (V.scopeMerge scope parent) inner

-- | Run an action with a single extra variable layered over the current
-- scope, as a one-entry 'AList' merged in by 'withInheritingScope'.
withLocalVar :: Variant -> Variant -> Run a -> Run a
withLocalVar key val = withInheritingScope (AList [(key, val)])

-- | Bind a value under a template identifier while running an action.
--
-- The identifier @"."@ is special: the value itself is merged into the
-- current scope with 'withInheritingScope' rather than bound under a name.
-- Any other identifier becomes a 'String'-keyed local variable.
withIdentifiedScope :: String -> Variant -> Run a -> Run a
withIdentifiedScope "." val inner = withInheritingScope val inner
withIdentifiedScope key val inner = withLocalVar (String key) val inner

-- Expressions

-- | Evaluate a template expression against the current scope.
--
-- Literals map directly onto the corresponding 'Variant' constructors.
-- Association-list keys are all evaluated before the values. A call whose
-- callee is the bare variable @library@ and that has at least one argument
-- is intercepted: the first argument is flattened to a string and passed to
-- 'loadLibrary', and any further arguments are ignored. Every other call
-- evaluates the callee, then the arguments, and applies them with 'V.call'.
runExpression :: Expression -> Run Variant
runExpression (StringLiteral s) = return (String s)
runExpression (BooleanLiteral b) = return (Bool b)
runExpression (IntLiteral i) = return (Integer i)
runExpression (FloatLiteral d) = return (Double d)
runExpression (ListExpression items) = fmap List (mapM runExpression items)
runExpression (AListExpression items) = do
    ks <- mapM (runExpression . fst) items
    vs <- mapM (runExpression . snd) items
    return (AList (zip ks vs))
runExpression (EscapeExpression EscapeHTML e) =
    fmap (String . htmlEncode . flatten) (runExpression e)
runExpression (EscapeExpression EscapeURL e) =
    fmap (String . urlEncode . flatten) (runExpression e)
runExpression (TernaryExpression cond yes no) = do
    c <- runExpression cond
    runExpression (if toBool c then yes else no)
runExpression (BinaryExpression op left right) =
    liftM2 (applyBinaryOperation op) (runExpression left) (runExpression right)
runExpression (UnaryExpression op arg) =
    fmap (applyUnaryOperation op) (runExpression arg)
runExpression (VariableReference name) = getVar name
-- Must precede the generic call clause below.
runExpression (FunctionCallExpression (VariableReference "library") (nameExpr:_)) =
    fmap (loadLibrary . V.flatten) (runExpression nameExpr)
runExpression (FunctionCallExpression fn argExprs) =
    liftM2 V.call (runExpression fn) (mapM runExpression argExprs)

-- | Apply a binary operator to two already-evaluated operands.
--
-- * 'OpPlus', 'OpMinus' and 'OpMul' use the 'Num' instance of 'Variant'.
-- * 'OpDiv' and 'OpMod' return 'Null' when the right operand converts to
--   zero instead of failing. 'OpDiv' produces a 'Double'; 'OpMod' produces
--   an 'Integer'.
-- * Ordering comparisons compare the operands' 'toDouble' conversions.
-- * 'OpCoalesce' yields the left operand unless it is 'Null'.
-- * 'OpMember' looks up the right operand as a key in the left operand.
-- * 'Flipped' applies the wrapped operator with the operands swapped.
-- * Boolean operators coerce both operands with 'toBool'. 'runExpression'
--   has already evaluated both operands before calling this, so
--   'OpBooleanAnd' / 'OpBooleanOr' do not short-circuit template evaluation.
applyBinaryOperation :: BinaryOperator -> Variant -> Variant -> Variant
applyBinaryOperation OpPlus = (+)
applyBinaryOperation OpMinus = (-)
applyBinaryOperation OpMul = (*)
applyBinaryOperation OpDiv = \l r ->
    if toDouble r == 0.0
        then Null
        else Double $ toDouble l / toDouble r
applyBinaryOperation OpMod = \l r ->
    if toInteger r == 0
        then Null
        else Integer $ toInteger l `mod` toInteger r
applyBinaryOperation OpEquals = \l r -> Bool $ l == r
applyBinaryOperation OpNotEquals = \l r -> Bool $ l /= r
applyBinaryOperation OpLooseEquals = \l r -> Bool $ l ~== r
applyBinaryOperation OpLooseNotEquals = \l r -> Bool $ l ~/= r
applyBinaryOperation OpLess = \l r -> Bool $ toDouble l < toDouble r
applyBinaryOperation OpNotLess = \l r -> Bool $ toDouble l >= toDouble r
applyBinaryOperation OpGreater = \l r -> Bool $ toDouble l > toDouble r
applyBinaryOperation OpNotGreater = \l r -> Bool $ toDouble l <= toDouble r
applyBinaryOperation (Flipped op) = flip (applyBinaryOperation op)
applyBinaryOperation OpCoalesce = \l r ->
    case l of
        Null -> r
        -- Wildcard, not a pattern named 'otherwise' (which would shadow
        -- Prelude's 'otherwise' and trip -Wname-shadowing).
        _ -> l
applyBinaryOperation OpMember = flip V.lookup
applyBinaryOperation OpInList = V.elem
applyBinaryOperation OpConcat = mappend
applyBinaryOperation OpBooleanAnd = \l r -> Bool $ toBool l && toBool r
applyBinaryOperation OpBooleanOr = \l r -> Bool $ toBool l || toBool r
-- Exclusive or: true exactly when one operand is truthy and the other is not.
applyBinaryOperation OpBooleanXor = \l r -> Bool $ toBool l /= toBool r

-- | Apply a unary operator to an already-evaluated operand.
--
-- 'OpNot' coerces the operand with 'V.toBool' and negates the result.
applyUnaryOperation :: UnaryOperator -> Variant -> Variant
applyUnaryOperation OpNot arg = Bool . not . V.toBool $ arg