module Mathista.Generator.Matlab (generate_matlab) where
import Mathista.IL
import Data.List.Split
import Data.String.Utils


-- | Map an IL variable name to the name used in the generated code.
--
-- A name whose first character is @$@ gets that leading @$@ turned into
-- @__@, and any further @$@ characters in it are dropped. Every other name
-- is passed through untouched.
--
-- NOTE(review): MATLAB requires identifiers to start with a letter, so the
-- @__@-prefixed names produced here are only accepted by Octave — confirm
-- which interpreter is the intended target.
vstr :: String -> String
vstr ('$' : rest) = "__" ++ filter (/= '$') rest
vstr name         = name

-- | Names of the output variables for a function with @n@ results:
-- @__r0@, @__r1@, ..., @__r(n-1)@. Yields an empty list when @n <= 0@.
genRetVars :: Int -> [String]
genRetVars n = ["__r" ++ show i | i <- [0 .. n - 1]]


-- | Translate one IL instruction into MATLAB source text, without a
-- trailing newline. Most instructions yield a single line; 'ILReturn' may
-- yield several lines joined by @\\n@.
generate :: IL -> String

-- Matrix literal. Elements are rendered with 'show'. When at least two
-- dimensions are given, the elements are split into rows of length
-- @dims !! 1@ (rows separated by ";", entries by " "). Otherwise they form a
-- single row. Assumes dims is [rows, cols, ...] with elems in row-major
-- order — TODO confirm against the IL producer. Dimensions past the second
-- are not represented, and the index list is currently ignored.
generate (ILLAssign v _indexes dims elems) = vstr v ++ " = [" ++ body ++ "];"
    where
        cells = map show elems
        body =
            case dims of
                -- Only split when the column count is positive: chunksOf
                -- with a size <= 0 never terminates on a non-empty list.
                (_ : cols : _)
                    | cols > 0 ->
                        join ";" (map (join " ") (chunksOf (fromInteger cols) cells))
                _ -> join " " cells

-- Plain variable copy. TODO: support indexes and dims (currently ignored).
generate (ILAssign to _indexes _dims from) = vstr to ++ " = " ++ vstr from ++ ";"

-- Function call. The callee name is prefixed with "mt_" (presumably the
-- runtime library's naming convention — confirm). Results, if any, are
-- bound with the "[a, b] = f(...)" form; with no results the call is a
-- bare statement.
generate (ILCall func args rets) =
    lhs rets ++ "mt_" ++ func ++ "(" ++ join ", " (map vstr args) ++ ");"
  where
    lhs [] = ""
    lhs rs = "[" ++ join ", " (map vstr rs) ++ "] = "

-- Function header. Parameter names are the first component of each arg
-- pair. Outputs are always named __r0, __r1, ... (the same names ILReturn
-- assigns). With no outputs the "=" must be omitted entirely:
-- "function f(x)", not "function  = f(x)", which is a syntax error.
generate (ILFuncDecl name args rets) =
    "function " ++ outputs rets ++ name ++ "(" ++ join ", " (map fst args) ++ ")"
  where
    outputs []  = ""
    outputs [_] = "__r0 = "
    outputs xs  = "[" ++ join ", " (genRetVars (length xs)) ++ "] = "

-- Return: assign each value to the matching __rN output variable, one per
-- line. The trailing ";" suppresses the interpreter's echo of the
-- assignment, matching every other assignment this generator emits.
-- NOTE(review): no "return" statement is emitted, so an ILReturn that is
-- not in tail position does not leave the function — confirm the IL
-- producer only emits it last.
generate (ILReturn vs) = join "\n" (zipWith assignOut (genRetVars (length vs)) vs)
  where
    assignOut r v = r ++ " = " ++ vstr v ++ ";"

-- Control flow. Conditions are tested as "variable is non-zero". Use "~=",
-- which MATLAB and Octave both accept; "!=" is only accepted by Octave.
generate (ILIf v)     = "if " ++ vstr v ++ " ~= 0"
generate (ILElseIf v) = "elseif " ++ vstr v ++ " ~= 0"
generate ILElse       = "else"
generate (ILWhile v)  = "while " ++ vstr v ++ " ~= 0"
generate ILBreak      = "break"
generate ILContinue   = "continue"
generate ILEnd        = "end"

-- | Render a whole IL program as MATLAB source, one 'generate'd instruction
-- per line. Every line, including the last, is terminated by a newline.
--
-- The first argument (a name) is currently unused.
--
-- 'unlines' builds the output in a single pass. The previous left fold
-- re-copied the whole accumulated string on every append, which is
-- quadratic in output length.
generate_matlab :: String -> [IL] -> String
generate_matlab _name ils = unlines (map generate ils)