{-#LANGUAGE ScopedTypeVariables #-}
module Text.HPaco.Writers.Run
        ( run
        , RunState (..)
        , RunOptions (..)
        , defaultOptions
        , runAST
        , runStatement
        , runExpression
        )
where

import Prelude hiding (toInteger)
import Data.Variant
import qualified Data.Variant as V
import Data.Maybe
import qualified Data.List as List
import Control.Monad.State
import Safe
import Text.HPaco.Writers.Run.Encode
import Text.HPaco.AST
import Text.HPaco.AST.AST
import Text.HPaco.AST.Statement
import Text.HPaco.AST.Expression

-- | Options controlling a template run.
data RunOptions = RunOptions
                    { roTemplateName :: String -- ^ Name of the template; not read by any code visible in this module.
                    }

-- | Baseline options: the template name is set to @"unnamed"@.
defaultOptions :: RunOptions
defaultOptions = RunOptions
                    { roTemplateName = "unnamed"
                    }

-- | Interpreter state threaded through the 'Run' monad.
data RunState = RunState
                    { rsScope :: Variant -- ^ Current scope; variable references are looked up here.
                    , rsOptions :: RunOptions -- ^ Options supplied to 'run'.
                    , rsAST :: AST -- ^ AST being executed; call statements resolve their target in its 'astDefs'.
                    }

-- | The interpreter monad: a 'RunState' threaded over 'IO'.
type Run a = StateT RunState IO a

-- | Execute a template AST in 'IO', starting from an empty scope.
--
-- The final interpreter state is discarded; only the effects performed while
-- running the AST are observable.
run :: RunOptions -> AST -> IO ()
run opts ast = evalStateT (runAST ast) initialState
  where
    initialState = RunState
        { rsScope = AList []
        , rsOptions = opts
        , rsAST = ast
        }

-- | Resolve a variable name against the current scope.
--
-- The name @"."@ denotes the whole scope; any other name is looked up in the
-- scope under the key @String name@.
getVar :: String -> Run Variant
getVar "." = gets rsScope
getVar key = gets (V.lookup (String key) . rsScope)

-- | Run a whole AST by executing its root statement.
runAST :: AST -> Run ()
runAST = runStatement . astRootStatement

-- Statements

-- | Execute a single statement, performing its output and scope effects.
runStatement :: Statement -> Run ()
-- Print: evaluate, flatten to a string, write to stdout.
runStatement (PrintStatement e) =
    runExpression e >>= liftIO . putStr . flatten
runStatement (StatementSequence ss) = mapM_ runStatement ss
runStatement (IfStatement cond true false) = do
    verdict <- runExpression cond
    if toBool verdict
        then runStatement true
        else runStatement false
runStatement (LetStatement ident expr stmt) = do
    value <- runExpression expr
    withIdentifiedScope ident value (runStatement stmt)
-- For: run the body once per element produced by 'vmap', each time with the
-- element bound under the loop identifier.
runStatement (ForStatement ident expr stmt) = do
    collection <- runExpression expr
    sequence_ (vmap iteration collection)
  where
    iteration item = withIdentifiedScope ident item (runStatement stmt)
-- Switch: all branch tests are evaluated up front; the first branch whose
-- test loosely equals the subject runs. No match means nothing happens.
runStatement (SwitchStatement expr branches) = do
    subject <- runExpression expr
    let (testExprs, bodies) = unzip branches
    tests <- mapM runExpression testExprs
    case [ body | (test, body) <- zip tests bodies, subject ~== test ] of
        (body:_) -> runStatement body
        [] -> return ()
runStatement NullStatement = return ()
-- Call: look the identifier up among the AST's definitions; an unknown
-- identifier falls back to the null statement.
runStatement (CallStatement identifier) = do
    defs <- gets (astDefs . rsAST)
    runStatement $ fromMaybe NullStatement (List.lookup identifier defs)

-- Scope helpers

-- Run an action under the given scope only, without inheriting from the
-- parent scope. The previous scope is put back once the action has finished;
-- other parts of the state are left as the action set them.
withScope :: Variant -> Run a -> Run a
withScope scope inner = do
    saved <- gets rsScope
    setScope scope
    result <- inner
    setScope saved
    return result
  where
    setScope sc = modify $ \st -> st { rsScope = sc }

-- Run an action under the local scope merged (via V.scopeMerge) with the
-- current parent scope. Entries of the local scope take precedence over
-- inherited ones.
withInheritingScope :: Variant -> Run a -> Run a
withInheritingScope scope inner =
    gets rsScope >>= \parent ->
        withScope (V.scopeMerge scope parent) inner

-- Run an action with a single extra key/value binding layered on top of the
-- current scope.
withLocalVar :: Variant -> Variant -> Run a -> Run a
withLocalVar key val = withInheritingScope binding
  where
    binding = AList [(key, val)]

-- Bind a value for the duration of an action. The identifier "." merges the
-- value itself into the scope; any other identifier binds the value under
-- that name.
withIdentifiedScope :: String -> Variant -> Run a -> Run a
withIdentifiedScope "." val inner = withInheritingScope val inner
withIdentifiedScope key val inner = withLocalVar (String key) val inner

-- Expressions

-- | Evaluate an expression to a 'Variant' in the current scope.
runExpression :: Expression -> Run Variant
-- Literals map directly onto the matching Variant constructor.
runExpression (StringLiteral str) = return (String str)
runExpression (BooleanLiteral str) = return (Bool str)
runExpression (IntLiteral str) = return (Integer str)
runExpression (FloatLiteral str) = return (Double str)
runExpression (ListExpression items) =
    fmap List (mapM runExpression items)
-- All keys are evaluated first, then all values, then they are paired up.
runExpression (AListExpression items) = do
    keys <- mapM (runExpression . fst) items
    values <- mapM (runExpression . snd) items
    return (AList (zip keys values))
-- Escaping flattens the value to a string before encoding it.
runExpression (EscapeExpression EscapeHTML e) =
    fmap (String . htmlEncode . flatten) (runExpression e)
runExpression (EscapeExpression EscapeURL e) =
    fmap (String . urlEncode . flatten) (runExpression e)
-- Left operand is evaluated before the right one.
runExpression (BinaryExpression op left right) =
    liftM2 (applyBinaryOperation op) (runExpression left) (runExpression right)
runExpression (UnaryExpression op arg) =
    fmap (applyUnaryOperation op) (runExpression arg)
runExpression (VariableReference varname) = getVar varname
-- A call to the name "library" with at least one argument loads a library
-- by its flattened name. This clause must precede the general call clause.
runExpression (FunctionCallExpression (VariableReference "library") (libnameExpr:_)) =
    runExpression libnameExpr >>= loadLibrary . V.flatten
-- General call: evaluate the callee, then the arguments, then apply.
runExpression (FunctionCallExpression fn argExprs) =
    liftM2 V.call (runExpression fn) (mapM runExpression argExprs)

-- | Resolve a library name to a 'Variant' holding that library's members.
--
-- Only @"std"@ is defined here; it is an association list with two entries:
--
-- * @"count"@ - the number of entries in @V.toAList@ of its argument.
-- * @"join"@ - takes a separator and a collection; flattens every value of
--   the collection to a string and joins them with the flattened separator.
--   Called with fewer than two arguments it yields 'Null' instead of
--   failing on an incomplete pattern match.
--
-- Any other library name yields 'Null'.
loadLibrary :: String -> Run Variant
loadLibrary "std" =
    return $
        AList [ ( String "count", wrapf1 (Integer . fromIntegral . List.length . V.toAList) )
              , ( String "join", wrapf joinItems )
              ]
  where
    -- Arguments beyond the first two are ignored.
    joinItems (sep:lst:_) =
        String . List.intercalate (V.flatten sep) . map V.flatten . V.values $ lst
    -- Too few arguments: return Null, like the other failure paths here.
    joinItems _ = Null
loadLibrary _ = return Null

-- | Apply a binary operator to two already-evaluated operands.
--
-- Division and modulo by zero yield 'Null'. Division always produces a
-- 'Double' and modulo an 'Integer'. Ordering comparisons are made on the
-- operands converted with 'toDouble'; boolean operators use 'toBool'.
applyBinaryOperation :: BinaryOperator -> Variant -> Variant -> Variant
-- Arithmetic is delegated to the numeric operators defined for Variant.
applyBinaryOperation OpPlus l r = l + r
applyBinaryOperation OpMinus l r = l - r
applyBinaryOperation OpMul l r = l * r
applyBinaryOperation OpDiv l r
    | divisor == 0.0 = Null
    | otherwise = Double (toDouble l / divisor)
  where
    divisor = toDouble r
applyBinaryOperation OpMod l r
    | divisor == 0 = Null
    | otherwise = Integer (toInteger l `mod` divisor)
  where
    divisor = toInteger r

-- Strict and loose (in)equality.
applyBinaryOperation OpEquals l r = Bool (l == r)
applyBinaryOperation OpNotEquals l r = Bool (l /= r)
applyBinaryOperation OpLooseEquals l r = Bool (l ~== r)
applyBinaryOperation OpLooseNotEquals l r = Bool (l ~/= r)

-- Ordering.
applyBinaryOperation OpLess l r = Bool (toDouble l < toDouble r)
applyBinaryOperation OpNotLess l r = Bool (toDouble l >= toDouble r)
applyBinaryOperation OpGreater l r = Bool (toDouble l > toDouble r)
applyBinaryOperation OpNotGreater l r = Bool (toDouble l <= toDouble r)

-- A flipped operator is the wrapped operator with its operands swapped.
applyBinaryOperation (Flipped op) l r = applyBinaryOperation op r l

-- Member access: the left operand is the container, the right one the key.
applyBinaryOperation OpMember l r = V.lookup r l
applyBinaryOperation OpInList l r = V.elem l r

-- Boolean connectives on the truthiness of both operands.
applyBinaryOperation OpBooleanAnd l r = Bool (toBool l && toBool r)
applyBinaryOperation OpBooleanOr l r = Bool (toBool l || toBool r)
applyBinaryOperation OpBooleanXor l r = Bool (toBool l /= toBool r)

-- | Apply a unary operator to an already-evaluated operand. Only 'OpNot' is
-- handled here: it negates the operand's truthiness.
-- NOTE(review): no type signature; the operator's type is not visible in this
-- module, so add one once its name is confirmed.
applyUnaryOperation OpNot arg = Bool (not (V.toBool arg))