module Mathista.Generator.C where
import Mathista.IL
import Data.List
import Data.List.Utils
import Data.String.Utils
import Text.Regex
-- | Map an IL variable name to the C identifier that stores it.
--
-- A name beginning with @$@ is treated as a compiler temporary and becomes
-- @mt_tmp_@ followed by the name with every @$@ removed; any other name
-- (including the empty string) becomes @mt_val_@ followed by the name.
--
-- NOTE(review): the name is not otherwise escaped, so a source name that is
-- not a valid C identifier fragment (e.g. one containing @'@, which the
-- variable regex in 'extract_vars' accepts) produces invalid C — confirm.
vstr :: String -> String
vstr ('$' : rest) = "mt_tmp_" ++ filter (/= '$') rest
vstr name = "mt_val_" ++ name
-- | Render the length of a list as a decimal integer literal for the
-- generated C code.
--
-- The explicit signature pins the type; without it, this point-free binding
-- relied on the monomorphism restriction and on list-typed call sites
-- elsewhere in the module to resolve the 'Foldable' constraint of 'length'.
len :: [a] -> String
len xs = show (length xs)
-- | Translate one IL instruction into a single line of C source.
--
-- Variable operands are turned into C identifiers with 'vstr', and list
-- lengths are rendered with 'len'. 'ILIf' and 'ILWhile' open a C block,
-- 'ILElseIf' and 'ILElse' close the previous branch and open the next one,
-- and 'ILEnd' emits the closing brace. 'ILFuncDecl' emits no code.
--
-- 'ILReturn' is not implemented yet and calls 'error'.
generate :: IL -> String
generate (ILLAssign v _indexes dims elems) =
  -- NOTE(review): unlike ILAssign, the index list is not emitted here;
  -- confirm against mt_lassign's C declaration that this is intended.
  "mt_lassign(instance, &" ++ vstr v ++ ", "
    ++ len dims ++ ", " ++ cBraceList dims ++ ", "
    ++ len elems ++ ", " ++ cBraceList elems
    ++ ");"
generate (ILAssign to indexes dims from) =
  "mt_assign(instance, &" ++ vstr to ++ ", "
    ++ len indexes ++ ", " ++ cBraceList indexes ++ ", "
    ++ len dims ++ ", " ++ cBraceList dims ++ ", &"
    ++ vstr from ++ ");"
generate (ILCall func args rets) =
  "mt_func_" ++ func ++ "(" ++ intercalate ", " callArgs ++ ");"
  where
    -- Layout: instance, argument count, &arg..., result count, &result...
    -- (presumably read variadically on the C side — confirm). Building the
    -- list element by element means an empty 'args' or 'rets' contributes
    -- only its count, instead of a dangling "&" such as "0, &, ", which is
    -- not valid C.
    callArgs =
      ["instance", len args] ++ map ref args
        ++ [len rets] ++ map ref rets
    ref x = '&' : vstr x
generate (ILFuncDecl _ _ _) = ""
generate (ILReturn _) = error "unimplemented yet"
generate (ILIf v) = "if (mt_cond(" ++ vstr v ++ ")) {"
generate (ILElseIf v) = "} else if (mt_cond(" ++ vstr v ++ ")) {"
generate ILElse = "} else {"
generate (ILWhile v) = "while (mt_cond(" ++ vstr v ++ ")) {"
generate ILBreak = "break;"
generate ILContinue = "continue;"
generate ILEnd = "}"

-- | Render values with 'show' as a brace-enclosed, comma-separated list,
-- e.g. @{1, 2, 3}@.
--
-- NOTE(review): a bare brace list is not a valid argument to a C function
-- call (a compound literal such as @(int[]){1, 2}@ would be), and an empty
-- @{}@ is not valid before C23; presumably mt_lassign / mt_assign are
-- macros in mathista.h that accept this form — confirm.
cBraceList :: Show a => [a] -> String
cBraceList xs = "{" ++ intercalate ", " (map show xs) ++ "}"
-- | Collect the C variable identifiers referenced in generated code.
--
-- Scans the text left to right for every match of
-- @mt_tmp_<digits>@ or @mt_val_<letters, digits, _ or '>@ and removes
-- duplicates with 'uniq'.
--
-- NOTE(review): a temporary produced by 'vstr' whose suffix is not purely
-- numeric would not match the @mt_tmp_@ pattern; presumably temporaries are
-- always numbered — confirm.
extract_vars :: String -> [String]
extract_vars = uniq . go
  where
    -- Compiled once and reused for every match, rather than rebuilt with
    -- 'mkRegex' on each recursive step.
    varRegex = mkRegex "(mt_tmp_[0-9]+|mt_val_[a-zA-Z0-9_']+)"
    -- Emit each match, then continue scanning the text after it.
    go input = case matchRegexAll varRegex input of
      Just (_, matched, rest, _) -> matched : go rest
      Nothing -> []
-- | Produce a complete C translation unit for a list of IL instructions.
--
-- The output is:
--
--   * @#include <mathista.h>@, followed by one file-scope
--     @static mt_value <var>;@ declaration per variable that 'extract_vars'
--     finds in the generated statements;
--   * @void mt_main_<name>(mt_instance *instance)@, whose body first calls
--     @mt_initval(instance, &<var>);@ for each of those variables and then
--     contains the output of 'generate' for each instruction, one per line.
generate_c :: String -> [IL] -> String
generate_c name ils = header ++ mainFn
  where
    -- 'unlines' terminates each statement with "\n" in linear time; a left
    -- fold of (++) would re-copy the accumulated text for every instruction.
    statements = unlines (map generate ils)
    vars = extract_vars statements
    -- Wrap every item in pre/suf and put each on its own line, with no
    -- trailing newline after the last one.
    eachLine pre suf = intercalate "\n" . map (\x -> pre ++ x ++ suf)
    header =
      "#include <mathista.h>\n"
        ++ eachLine "static mt_value " ";" vars
        ++ "\n"
    mainFn =
      "void mt_main_" ++ name ++ "(mt_instance *instance) {\n"
        ++ eachLine "mt_initval(instance, &" ");" vars
        ++ "\n"
        ++ statements
        ++ "}\n"