-- | C back end: renders Mathista IL instructions as C source that calls
-- into the @mt_*@ runtime functions.
module Mathista.Generator.C where

import Mathista.IL
import Data.List
import Data.List.Utils
import Data.String.Utils
import Text.Regex

-- | Map an IL variable name to the C identifier used for it.
--
-- Names starting with @$@ become @mt_tmp_…@ (with every @$@ in the name
-- removed); all other names become @mt_val_…@.
vstr :: String -> String
vstr s
  | startswith "$" s = "mt_tmp_" ++ replace "$" "" s
  | otherwise = "mt_val_" ++ s

-- | Length of a list rendered as a decimal literal; emitted as the element
-- count in front of each list passed to the runtime.
len :: [a] -> String
len = show . length

-- | Translate one IL instruction into one line of C.
--
-- Control-flow instructions ('ILIf', 'ILElseIf', 'ILElse', 'ILWhile',
-- 'ILEnd') emit only the opening/closing brace lines; the statements in
-- between come from the neighbouring instructions. 'ILFuncDecl' emits an
-- empty line and 'ILReturn' is not implemented (calls 'error').
generate :: IL -> String
generate instr =
  case instr of
    ILLAssign v _indexes dims elems ->
      -- NOTE(review): a bare brace list such as {1, 2} is not a valid C
      -- function argument, and its commas would be split apart if
      -- mt_lassign is a macro; confirm how the runtime consumes these
      -- (C99 compound literals may be required).
      "mt_lassign(instance, &" ++ vstr v ++ ", "
        ++ len dims ++ ", " ++ braced dims ++ ", "
        ++ len elems ++ ", " ++ braced elems ++ ");"
    ILAssign to indexes dims from ->
      "mt_assign(instance, &" ++ vstr to ++ ", "
        ++ len indexes ++ ", " ++ braced indexes ++ ", "
        ++ len dims ++ ", " ++ braced dims ++ ", &" ++ vstr from ++ ");"
    ILCall func args rets ->
      -- Each group (arguments, then results) is passed as its count
      -- followed by one pointer per variable. Building the list before
      -- joining means an empty group contributes only its count ("0")
      -- instead of a dangling "&", which is not valid C.
      "mt_func_" ++ func ++ "("
        ++ intercalate ", " ("instance" : countedRefs args ++ countedRefs rets)
        ++ ");"
      where
        countedRefs vs = len vs : map (("&" ++) . vstr) vs
    ILFuncDecl _ _ _ -> ""
    ILReturn _ -> error "unimplemented yet"
    ILIf v -> "if (mt_cond(" ++ vstr v ++ ")) {"
    ILElseIf v -> "} else if (mt_cond(" ++ vstr v ++ ")) {"
    ILElse -> "} else {"
    ILWhile v -> "while (mt_cond(" ++ vstr v ++ ")) {"
    ILBreak -> "break;"
    ILContinue -> "continue;"
    ILEnd -> "}"
  where
    -- Brace-enclosed, comma-separated list of the 'show'n elements.
    braced :: Show a => [a] -> String
    braced xs = "{" ++ intercalate ", " (map show xs) ++ "}"

-- | Collect the distinct C variable identifiers (@mt_tmp_<digits>@ and
-- @mt_val_<name>@) mentioned in generated C text, in order of first
-- appearance.
--
-- NOTE(review): the pattern accepts @'@ in @mt_val_@ names, but @'@ is not
-- legal in a C identifier; confirm whether IL names can contain primes.
extract_vars :: String -> [String]
extract_vars = uniq . matchAll
  where
    -- Compiled once rather than on every recursive step.
    varRegex = mkRegex "(mt_tmp_[0-9]+|mt_val_[a-zA-Z0-9_']+)"
    -- The single capture group holds the identifier; continue scanning
    -- from the text after the match.
    matchAll s =
      case matchRegexAll varRegex s of
        Just (_, _, rest, groups) -> groups ++ matchAll rest
        Nothing -> []

-- | Render a complete C translation unit for the program @name@.
--
-- Every variable identifier referenced by the generated statements is
-- declared as a file-level @static mt_value@ and initialised with
-- @mt_initval@ at the top of @mt_main_<name>@, followed by the statements,
-- one per line.
generate_c :: String -> [IL] -> String
generate_c name ils = header ++ mainFn
  where
    -- One C line per IL instruction. 'unlines' builds this in linear time,
    -- unlike a left fold over (++), which re-copies the prefix each step.
    stmts = unlines (map generate ils)
    vars = extract_vars stmts
    eachVar pre suf xs = intercalate "\n" [pre ++ x ++ suf | x <- xs]
    -- NOTE(review): the #include target is empty, which a C preprocessor
    -- rejects; the header name appears to have been lost — confirm it.
    header = "#include \n" ++ eachVar "static mt_value " ";" vars ++ "\n"
    mainFn =
      "void mt_main_" ++ name ++ "(mt_instance *instance) {\n"
        ++ eachVar "mt_initval(instance, &" ");" vars ++ "\n"
        ++ stmts
        ++ "}\n"