module Control.CP.FD.Gecode.Translate (
generate_gecode
) where
import Maybe (fromJust,isNothing,isJust)
import List (findIndex)
import Data.Map (elems,Map,lookup)
import Control.CP.FD.Gecode.CodegenSolver
import Control.CP.FD.Gecode.Common
import Control.CP.Solver
-- | Produce the text of a C++ Gecode program.
--
-- The argument is passed to 'compile'; the store it yields is rendered by
-- 'stateToProg'.
generate_gecode model = stateToProg (compile model)
-- | Count the variables of type @t@ among those that 'varsUsed' reports for
-- the root of the store's constraint tree with an empty path.
--
-- A negative @u@ counts every such variable; otherwise only variables whose
-- index is strictly smaller than @u@ are counted.
countTypeVars :: Store -> GType -> Int -> Int
countTypeVars s t u =
  length
    [ v
    | v <- varsUsed (ctree s) []
    , u < 0 || v < u
    , t == getVarType s v
    ]
-- | Number of nodes on the longest root-to-leaf path of a store tree.
--
-- A leaf has depth 1; an internal node is one deeper than its deeper child.
maxDepth :: StoreNode -> Int
maxDepth node = case dis node of
  SNLeaf -> 1
  SNIntl l r -> 1 + max (maxDepth l) (maxDepth r)
-- | The variable types for which code is generated.
--
-- The order of this list is the order in which the per-type arrays are
-- declared, initialised, updated and branched on in the generated program.
typeList :: [GType]
typeList = [TypeBool, TypeInt]
-- | Collect the variables introduced (@nvars@) by the nodes along a path
-- through the store tree, starting at @node@.
--
-- Each 'Bool' in the path picks a child of an internal node: 'False' takes
-- the left child, 'True' the right one.  Once the path is exhausted the walk
-- keeps following left children until it reaches a leaf.
varsUsed :: StoreNode -> [ Bool ] -> [ Int ]
varsUsed node path = nvars node ++ descend (dis node) path
  where
    descend SNLeaf _ = []
    descend (SNIntl l _) [] = varsUsed l []
    descend (SNIntl l _) (False : rest) = varsUsed l rest
    descend (SNIntl _ r) (True : rest) = varsUsed r rest
-- | Name of the C++ variable class used for a variable type.  The same name
-- is also used as a suffix for the generated array members.
typeToString :: GType -> String
typeToString ty = case ty of
  TypeInt -> "IntVar"
  TypeBool -> "BoolVar"
-- | Default @(lower, upper)@ domain bounds, as C++ literals, used for a
-- variable of the given type when no explicit bound is known.
typeToDefArgs :: GType -> (String,String)
typeToDefArgs ty = case ty of
  TypeInt -> ("-1000000000", "1000000000")
  TypeBool -> ("0", "1")
-- | C++ expression that refers to variable @i@ in the generated program.
--
-- The result has the form @\<pre\>bb\<Type\>[\<k\>]@, where @\<Type\>@ comes
-- from 'typeToString' and @k@ is @countTypeVars s ty i@ for the variable's
-- type @ty@.
getVarName :: Store -> String -> Int -> String
getVarName s pre i = concat [pre, "bb", typeToString ty, "[", show slot, "]"]
  where
    ty = getVarType s i
    slot = countTypeVars s ty i
-- | Render a term as a C++ expression.
--
-- A term with a variable id becomes the variable reference produced by
-- 'getVarName'; otherwise its integer value is shown.  A term with neither
-- aborts with 'error'.
getName :: GTerm t => Store -> String -> t -> String
getName s pre v = case getVarId v of
  Just i -> getVarName s pre i
  Nothing -> maybe (error "oei") show (getIntValue v)
-- | Describe every non-implicit variable of the store.
--
-- Each entry is @(C++ type name, label, C++ expression)@, where the label is
-- @"v"@ followed by the variable index and the expression is the variable's
-- name without prefix.
--
-- Variable indices are taken to run from @0@ to @vars s - 1@ (this assumes
-- @vars s@ is the number of variables in the store).
stateToExplList :: Store -> [ (String,String,String) ]
stateToExplList s =
  [ (typeToString (getVarType s i), "v" ++ show i, getVarName s "" i)
  | i <- [0 .. vars s - 1]
  , not (isVarImplicit s i)
  ]
-- | Render one constructor-style initialiser, @name(*this,lo,hi)@, for every
-- bound in the map.
--
-- A missing lower or upper bound is replaced by the default for the
-- variable's type (see 'typeToDefArgs').
stateToConstList :: Store -> Map Int VarBound -> [ String ]
stateToConstList s b = map render (elems b)
  where
    render (VarBound i lo hi) =
      getVarName s "" i
        ++ "(*this,"
        ++ maybe defLo show lo
        ++ ","
        ++ maybe defHi show hi
        ++ ")"
      where
        (defLo, defHi) = typeToDefArgs (getVarType s i)
-- | Gecode relation-type constant for an operator, for a constraint written
-- as @lhs \<op\> rhs@.
gopToGCRel :: GOperator -> String
gopToGCRel op = case op of
  OEqual -> "IRT_EQ"
  ODiff -> "IRT_NQ"
  OLess -> "IRT_LE"
-- | Gecode relation-type constant for an operator after both sides have been
-- negated: equality and disequality are unchanged, "less" becomes "greater".
gopToInvGCRel :: GOperator -> String
gopToInvGCRel op = case op of
  OEqual -> "IRT_EQ"
  ODiff -> "IRT_NQ"
  OLess -> "IRT_GR"
-- | Render the constraints of a node as C++ statements (without trailing
-- semicolon) that post them on @home@.
--
-- The constraint list is reversed before rendering.  Variables are
-- referenced through the @p->@ prefix.
--
-- Linear constraints are simplified where possible:
--
-- * @a*x + (-a)*y \<op\> 0@ becomes a binary @rel@ between @x@ and @y@
--   (with the inverted relation when @a@ is negative);
-- * @a*x \<op\> c@ with @a@ dividing @c@ becomes a @rel@ against @c \`div\` a@
--   (again inverted for negative @a@);
-- * all-ones coefficients use the coefficient-free @linear@ overload;
-- * everything else uses @linear@ with an explicit @IntArgs@.
stateToPostList :: Store -> [ GConstraint ] -> [ String ]
stateToPostList s c = map fm (reverse c)
  where
    fm (CRel t1 o t2) = "rel(home," ++ gn t1 ++ "," ++ gopToGCRel o ++ "," ++ gn t2 ++ ")"
    fm (CMult t1 t2 t3) = "mult(home," ++ gn t1 ++ "," ++ gn t2 ++ "," ++ gn t3 ++ ")"
    fm (CDiv t1 t2 t3) = "div(home," ++ gn t1 ++ "," ++ gn t2 ++ "," ++ gn t3 ++ ")"
    fm (CMod t1 t2 t3) = "mod(home," ++ gn t1 ++ "," ++ gn t2 ++ "," ++ gn t3 ++ ")"
    fm (CAbs t1 t2) = "abs(home," ++ gn t1 ++ "," ++ gn t2 ++ ")"
    fm (CDom t l u) = "dom(home," ++ gn t ++ "," ++ show l ++ "," ++ show u ++ ")"
    fm (CValue t v) = "rel(home," ++ gn t ++ ",IRT_EQ," ++ show v ++ ")"
    fm (CLinear terms o k) = case (k, terms) of
      (0, [(v1, a), (v2, b)])
        | a + b == 0 && a > 0 ->
            "rel(home," ++ gn v1 ++ "," ++ gopToGCRel o ++ "," ++ gn v2 ++ ")"
        | a + b == 0 && a < 0 ->
            "rel(home," ++ gn v1 ++ "," ++ gopToInvGCRel o ++ "," ++ gn v2 ++ ")"
      (_, [(v1, a)])
        | a > 0 && (k `mod` a) == 0 ->
            "rel(home," ++ gn v1 ++ "," ++ gopToGCRel o ++ "," ++ show (k `div` a) ++ ")"
        -- Dividing by a negative coefficient flips the relation.
        | a < 0 && (k `mod` a) == 0 ->
            "rel(home," ++ gn v1 ++ "," ++ gopToInvGCRel o ++ "," ++ show (k `div` a) ++ ")"
      (_, ts)
        | all (\(_, a) -> a == 1) ts -> case unzip ts of
            (xs, _) ->
              "{ " ++ bl "iva" xs ++ " linear(home,iva," ++ gopToGCRel o ++ "," ++ show k ++ "); }"
      _ -> case unzip terms of
        (xs, coeffs) ->
          "{ IntArgs ia(" ++ show (length coeffs)
            ++ concatMap (\y -> "," ++ show y) coeffs
            ++ "); "
            ++ bl "iva" xs
            ++ " linear(home,ia,iva," ++ gopToGCRel o ++ "," ++ show k ++ "); }"
    fm (CAllDiff l) = "{ " ++ bl "ia" l ++ "; distinct(home,ia); }"
    fm (CSorted l e) = "{ " ++ bl "ia" l ++ "; rel(home,ia," ++ (if e then "IRT_LQ" else "IRT_LE") ++ "); }"

    -- C++ expression for a term, relative to the HaskellProg pointer @p@.
    gn t = getName s "p->" t

    -- Declare an IntVarArgs called @n@ and fill it, element by element, with
    -- the variables of @l@ (indices 0 .. length l - 1).
    bl n l =
      "IntVarArgs " ++ n ++ "(" ++ show (length l) ++ "); "
        ++ concat
          [ n ++ "[" ++ show i ++ "]=" ++ getVarName s "p->" (varIdOf v) ++ "; "
          | (i, v) <- zip [(0 :: Int) ..] l
          ]

    -- Array arguments must be variables; a term without a variable id cannot
    -- be stored in an IntVarArgs.
    varIdOf v =
      maybe
        (error "stateToPostList: array constraint argument has no variable id")
        id
        (getVarId v)
-- | Names (without prefix) of all non-implicit variables of type @t@, in
-- increasing index order.
--
-- Variable indices are taken to run from @0@ to @vars s - 1@ (this assumes
-- @vars s@ is the number of variables in the store).
stateToBranchList :: Store -> GType -> [ String ]
stateToBranchList s t =
  [ getVarName s "" i
  | i <- [0 .. vars s - 1]
  , not (isVarImplicit s i)
  , t == getVarType s i
  ]
-- | C++ statements that collect the branching variables of type @t@ into an
-- @\<Type\>Args@ array named @b\<Type\>@ and post a branching on it.
--
-- The array is filled at indices @0 .. n - 1@, in the order given by
-- 'stateToBranchList'.
stateToBranchCode :: Store -> GType -> String
stateToBranchCode s t =
  concat
    [ " " ++ tn ++ "Args b" ++ tn ++ "(" ++ show (length names) ++ ");\n"
    , concat
        [ " b" ++ tn ++ "[" ++ show idx ++ "]=" ++ name ++ ";\n"
        | (idx, name) <- zip [(0 :: Int) ..] names
        ]
    , " branch(*this, b" ++ tn ++ ", INT_VAR_SIZE_MIN, INT_VAL_SPLIT_MIN);\n"
    ]
  where
    tn = typeToString t
    names = stateToBranchList s t
-- | Branching code for every type in 'typeList', concatenated in list order.
stateToBranches :: Store -> String
stateToBranches s = concatMap (stateToBranchCode s) typeList
-- | Emit the C++ function for one node of the store tree, followed by the
-- functions for all of its descendants.
--
-- The function is named @node@ followed by one letter per step of @path@
-- (@L@ for 'False', @R@ for 'True').  Its body
--
-- 1. initialises the variables introduced by this node, using the bounds in
--    @bnds@ or the type defaults from 'typeToDefArgs';
-- 2. posts the node's constraints (see 'stateToPostList');
-- 3. at a leaf, fixes the decision booleans @i[length path ..]@ that are not
--    used on this path to 0, and fixes every variable not introduced on this
--    path to its lower bound; at an internal node, dispatches on
--    @i[length path]@ to the @R@ (true) or @L@ (false) child function.
--
-- The decision array @i@ is indexed @0 .. maxDepth - 2@, and variable indices
-- are taken to run from @0@ to @vars store - 1@.
nodeToProg :: Store -> Map Int VarBound -> StoreNode -> [ Bool ] -> String
nodeToProg store bnds node path =
  concat
    [ " static void node" ++ pathS ++ "(Space &home) {\n"
    , " /* varsused" ++ show vrs ++ "*/\n"
    , " HaskellProg *p = (HaskellProg *)(&home);\n"
    , concatMap initVar (nvars node)
    , concatMap statement (stateToPostList store (cons node))
    , body
    , " }\n"
    , children
    ]
  where
    pathS = map (\goRight -> if goRight then 'R' else 'L') path
    lenS = show (length path)
    -- Variables introduced on the way from the root to this node.
    vrs = varsUsed (ctree store) path

    statement y = " " ++ y ++ ";\n"
    initVar x =
      statement
        (getVarName store "p->" x ++ ".init(home," ++ lowest x ++ "," ++ highest x ++ ")")

    body = case dis node of
      SNLeaf ->
        concatMap fixDecision [length path .. maxDepth (ctree store) - 2]
          ++ concatMap fixUnused [x | x <- [0 .. vars store - 1], x `notElem` vrs]
      SNIntl _ _ ->
        " when(home,p->i[" ++ lenS ++ "],&node" ++ pathS ++ "R,&node" ++ pathS ++ "L);\n"

    fixDecision x = " rel(home,p->i[" ++ show x ++ "],IRT_EQ,0);\n"
    fixUnused x =
      " rel(home," ++ getVarName store "p->" x ++ ",IRT_EQ," ++ lowest x ++ ");\n"

    children = case dis node of
      SNLeaf -> ""
      SNIntl l r ->
        nodeToProg store bnds l (path ++ [False])
          ++ nodeToProg store bnds r (path ++ [True])

    -- Explicit bound if one is recorded, otherwise the type's default.
    lowest i = case Data.Map.lookup i bnds of
      Just (VarBound _ (Just l) _) -> show l
      _ -> fst (typeToDefArgs (getVarType store i))
    highest i = case Data.Map.lookup i bnds of
      Just (VarBound _ _ (Just u)) -> show u
      _ -> snd (typeToDefArgs (getVarType store i))
-- | Render a complete C++ program for the given store.
--
-- The program defines a @HaskellProg@ space holding
--
-- * a @BoolVarArray i@ with @maxDepth - 1@ decision booleans, one per level
--   of internal nodes in the store tree, and
-- * one variable array @bb\<Type\>@ per type in 'typeList', sized to the
--   number of variables of that type ('countTypeVars' with a negative limit
--   counts them all),
--
-- together with a constructor that runs the root @node@ function and posts
-- the branchings, a @print@ method listing the non-implicit variables, the
-- copy machinery, the per-node functions from 'nodeToProg', and a @main@
-- that prints the first solution found by depth-first search.
stateToProg :: Store -> String
stateToProg s =
  concat
    [ "#include \"gecode/kernel.hh\"\n"
    , "#include \"gecode/support.hh\"\n"
    , "#include \"gecode/int.hh\"\n"
    , "#include \"gecode/search.hh\"\n"
    , "#include \"gecode/minimodel.hh\"\n"
    , "\n"
    , "using namespace Gecode;\n"
    , "\n"
    , "class HaskellProg : public Space {\n"
    , "protected:\n"
    , " BoolVarArray i;\n\n"
    , concatMap arrayDecl typeList
    , "public:\n"
    , " HaskellProg() : "
    , commaSep (map arrayInit typeList)
    , ", i(*this," ++ show (maxDepth (ctree s) - 1) ++ ",0,1) {\n"
    , " node(*this);\n"
    , " branch(*this, i, INT_VAR_SIZE_MIN, INT_VAL_MIN);\n"
    , stateToBranches s
    , " };\n"
    , " virtual void print(std::ostream& os) {\n"
    , concatMap printVar (stateToExplList s)
    , " }\n"
    , " HaskellProg(bool share, HaskellProg &s) : Space(share,s) {\n"
    , " i.update(*this,share,s.i);\n"
    , concatMap arrayUpdate typeList
    , " }\n"
    , " virtual Space* copy(bool share) {\n"
    , " return new HaskellProg(share, *this);\n"
    , " }\n"
    , nodeToProg s bounds (ctree s) []
    , "};\n"
    , "\n"
    , "int main(void) {\n"
    , " HaskellProg *prog=new HaskellProg();\n"
    , " DFS<HaskellProg> srch(prog);\n"
    , " delete prog;\n"
    , " do {\n"
    , " HaskellProg *sol=srch.next();\n"
    , " if (sol==NULL) break;\n"
    , " sol->print(std::cout);\n"
    , " } while(0);\n"
    , " return 0;\n"
    , "}\n"
    ]
  where
    bounds = getAllBounds s

    -- Member declaration of the per-type variable array.
    arrayDecl ty = " " ++ typeToString ty ++ "Array bb" ++ typeToString ty ++ ";\n"

    -- Constructor initialiser sizing the per-type array to hold every
    -- variable of that type (negative limit = no index cut-off).
    arrayInit ty =
      "bb" ++ typeToString ty ++ "(*this," ++ show (countTypeVars s ty (-1)) ++ ")"

    -- Copy-constructor update of the per-type array.
    arrayUpdate ty =
      " " ++ "bb" ++ typeToString ty ++ ".update(*this,share,s.bb" ++ typeToString ty ++ ")" ++ ";\n"

    -- One output line per reported variable: label, then current value.
    printVar (_, vName, vExpr) =
      " os << \"" ++ vName ++ ": \" << " ++ vExpr ++ " << std::endl;\n"

    -- Join the (non-empty) pieces with ", ".
    commaSep [] = ""
    commaSep (x : xs) = x ++ concatMap (", " ++) xs