module Mathista.Compiler (compile) where
import Mathista.AST
import Mathista.IL


--
--  Helper functions
--

-- | Name of the compiler temporary for the given counter: a @$@ followed
-- by the decimal counter, e.g. @allocTmpId 3 == "$3"@.
allocTmpId :: Int -> String
allocTmpId counter = '$' : show counter

--
--  Expressions
--
-- | Compile an expression into IL instructions plus the 'Id' that holds
-- its value.
--
-- The 'Int' is the temporary counter for this expression: a compound
-- expression stores its result in @allocTmpId c@ and compiles its
-- sub-expressions with counters greater than @c@. A sibling compiled
-- later (with a larger counter) therefore never overwrites an earlier
-- sibling's result.
compileExpr :: Int -> Expr -> ([IL], Id)

-- Number literal: a one-element literal assignment into a fresh temporary.
compileExpr c (Number n) =
    let val = allocTmpId c
    in ([ILLAssign val [] [1] [n]], val)

-- Matrix literal: either a vector of numbers, or a list of rows where each
-- row is itself a matrix literal of numbers. Elements are flattened row by
-- row, and the dimensions are [rows] or [rows, cols].
compileExpr c (Matrix rows) =
    let
        -- Length converted to whatever numeric type ILLAssign's dimension
        -- list uses (inferred from the use site below).
        count xs = fromIntegral (length xs)
        number x = case x of
                       Number n -> n
                       _        -> error "invalid element in a matrix literal"
        rowElems x = case x of
                         Matrix cols -> map number cols
                         _           -> error "invalid element in a matrix literal"
        val = allocTmpId c
        (dims, elems) = case rows of
            -- NOTE(review): an empty literal keeps its previous output
            -- (no dims, no elements); confirm the intended encoding.
            [] -> ([], [])
            -- 2d matrix: every row must have as many columns as the first.
            (Matrix cols : _) ->
                let rowVals = map rowElems rows
                in if any ((/= length cols) . length) rowVals
                       then error "rows of a matrix literal must have the same length"
                       else ([count rows, count cols], concat rowVals)
            -- vector
            (Number _ : _) -> ([count rows], map number rows)
            _ -> error "invalid element in a matrix literal"
    in
        ([ILLAssign val [] dims elems], val)

-- Variable reference: no code; the variable itself holds the value.
compileExpr _ (VarRef (name, Nothing)) = ([], name)
compileExpr _ (VarRef (_, Just _)) = error "submatrix is not supported yet"

-- Function call used as an expression: only a single return value is
-- kept, stored in temporary c. Arguments use counters from c + 1 upward.
compileExpr c (FuncCall name argExprs) =
    let
        ret = allocTmpId c
        (ils, args) = compileExprs (c + 1) argExprs
    in
        (ils ++ [ILCall name args [ret]], ret)

-- Unary operators, lowered to calls of the named operator function.
compileExpr c (Not   expr) = compileUnaryExpr c "not"   expr
compileExpr c (Plus  expr) = compileUnaryExpr c "plus"  expr
compileExpr c (Minus expr) = compileUnaryExpr c "minus" expr

-- Binary operators, lowered to calls of the named operator function.
compileExpr c (Add lhs rhs) = compileBinaryExpr c "add" lhs rhs
compileExpr c (Sub lhs rhs) = compileBinaryExpr c "sub" lhs rhs
compileExpr c (Mul lhs rhs) = compileBinaryExpr c "mul" lhs rhs
compileExpr c (Div lhs rhs) = compileBinaryExpr c "div" lhs rhs
compileExpr c (Eq  lhs rhs) = compileBinaryExpr c "eq"  lhs rhs
compileExpr c (Neq lhs rhs) = compileBinaryExpr c "neq" lhs rhs
compileExpr c (Gt  lhs rhs) = compileBinaryExpr c "gt"  lhs rhs
compileExpr c (Gte lhs rhs) = compileBinaryExpr c "gte" lhs rhs
compileExpr c (Lt  lhs rhs) = compileBinaryExpr c "lt"  lhs rhs
compileExpr c (Lte lhs rhs) = compileBinaryExpr c "lte" lhs rhs
compileExpr c (And lhs rhs) = compileBinaryExpr c "and" lhs rhs
compileExpr c (Or  lhs rhs) = compileBinaryExpr c "or"  lhs rhs


-- | Compile a unary operator as an 'ILCall' of the operator function
-- @name@ (e.g. @"not"@, @"minus"@) with a single argument.
--
-- The result goes into temporary @c@. The operand is compiled with
-- counter @c + 1@.
compileUnaryExpr :: Int -> Id -> Expr -> ([IL], Id)
compileUnaryExpr c name expr =
    let
        val = allocTmpId c
        (ils, ret) = compileExpr (c + 1) expr
    in
        (ils ++ [ILCall name [ret] [val]], val)

-- | Compile a binary operator as an 'ILCall' of the operator function
-- @name@ (e.g. @"add"@, @"lt"@) with two arguments.
--
-- The result goes into temporary @c@. The left operand is compiled with
-- counter @c + 1@ and the right operand with @c + 2@. The right operand
-- only writes temporaries numbered @c + 2@ and above, so the left
-- operand's result (temporary @c + 1@ at most) is still intact when the
-- call is made.
compileBinaryExpr :: Int -> Id -> Expr -> Expr -> ([IL], Id)
compileBinaryExpr c name lhsExpr rhsExpr =
    let
        val = allocTmpId c
        (lhs_ils, lhs) = compileExpr (c + 1) lhsExpr
        (rhs_ils, rhs) = compileExpr (c + 2) rhsExpr
    in
        (lhs_ils ++ rhs_ils ++ [ILCall name [lhs, rhs] [val]], val)


-- multiple expressions
-- | Compile a list of expressions, in order, into concatenated IL plus the
-- 'Id' holding each expression's value.
--
-- The i-th expression (0-based) is compiled with counter @c_start + i@.
-- Each value therefore lives in a distinct temporary that later
-- expressions do not overwrite.
compileExprs :: Int -> [Expr] -> ([IL], [Id])
compileExprs c_start exprs =
    -- zipWith/unzip/concat replaces the foldl with left-nested (++), which
    -- was quadratic in the number of expressions; the output is unchanged.
    let (ilss, ids) = unzip (zipWith compileExpr [c_start ..] exprs)
    in (concat ilss, ids)


-- function call in an assign statement
-- | Compile @rets = name(args)@: evaluate the arguments, then emit one
-- 'ILCall' that writes every return value directly into the target
-- variables @rets@.
compileAssignCall :: Id -> [Expr] -> [Id] -> [IL]
compileAssignCall name argExprs rets =
    let
        -- Arguments start at counter 1, as in the original code. Counter 0
        -- is unused here because the results go to @rets@, not a temporary.
        (ils, args) = compileExprs 1 argExprs
    in
        ils ++ [ILCall name args rets]


--
--  Statements
--
-- | Compile one statement into IL.
--
-- Every statement compiles its expressions starting at temporary counter
-- 0, so temporaries are reused across statements.
compileStmt :: Stmt -> [IL]

-- Function declaration: header, compiled body, closing 'ILEnd'.
compileStmt (FuncDecl name args rets stmts) =
    [ILFuncDecl name args rets] ++ compile stmts ++ [ILEnd]

compileStmt (For _ _ _ _) =
    error "unimplemented" -- TODO

-- While loop: the condition's instructions run before 'ILWhile' and are
-- repeated at the end of the body, presumably so the condition value is
-- recomputed before the next test.
-- NOTE(review): if ILContinue jumps straight back to the ILWhile test, the
-- recomputation at the end of the body is skipped; confirm IL semantics.
compileStmt (While expr stmts) =
    let (cond_ils, cond_v) = compileExpr 0 expr
    in cond_ils ++ [ILWhile cond_v] ++ compile stmts ++ cond_ils ++ [ILEnd]

-- Assignment: a function call on the right-hand side may fill several
-- variables at once; any other expression must have exactly one target.
-- NOTE(review): only `fst` of each target is used, so any second component
-- (e.g. submatrix indexes, if that is what it holds) is ignored here.
compileStmt (Assign vars expr) =
    case (expr, vars) of
        (FuncCall name argExprs, _) ->
            compileAssignCall name argExprs (map fst vars)
        (_, [v]) ->
            let (ils, ret) = compileExpr 0 expr
            in ils ++ [ILAssign (fst v) [] [] ret]
        _ -> error "rhs of a multiple assignment must be a function call"

-- If statement. `elseif` branches are lowered to an `if` nested in the
-- `else` branch, which the compiler already supports:
--     if a {A} elseif b {B} else {C}  ==  if a {A} else { if b {B} else {C} }
compileStmt (If [] _) = error "if statement without a condition"
compileStmt (If ((cond, body) : elseifs) elseblock) =
    let
        (cond_ils, cond_v) = compileExpr 0 cond
        else_stmts = case elseifs of
            [] -> elseblock
            _  -> Just [If elseifs elseblock]
        else_ils = maybe [] (\stmts -> ILElse : compile stmts) else_stmts
    in
        cond_ils ++ [ILIf cond_v] ++ compile body ++ else_ils ++ [ILEnd]

-- Expression evaluated only for its effects; its value Id is discarded.
compileStmt (ExprStmt expr) = fst (compileExpr 0 expr)

compileStmt Continue = [ILContinue]

compileStmt Break = [ILBreak]

-- Return: evaluate every returned expression, then return their Ids.
compileStmt (Return exprs) =
    let (ils, ids) = compileExprs 0 exprs
    in ils ++ [ILReturn ids]

compileStmt DoNothing = []

-- | Compile a sequence of statements into IL by concatenating each
-- statement's instructions in order.
compile :: [Stmt] -> [IL]
-- concatMap replaces foldl with a growing left operand to (++), which was
-- quadratic in program size; the output is identical (and produced lazily).
compile = concatMap compileStmt