module Language.Passage.Graph2C where
import Language.Passage.Utils hiding (double,int)
import Language.Passage.Term hiding (bin)
import Language.Passage.AST
import Language.Passage.Lang.C
import Language.Passage.GraphColor(groupByColor)
import qualified Data.Map as M
import qualified Data.IntMap as IM
import qualified Data.IntSet as IS
import Data.Maybe(maybeToList,fromJust)
import Data.List(sortBy, transpose)
import Data.Function(on)
import MonadLib (ReaderT, StateT, Id
, runId, runStateT, runReaderT
, get, set, asks, mapReader
, forM
, zipWithM
)
import Data.Graph(SCC(..))
import Data.Graph.SCC
-- | C expression that names the storage of a stochastic variable.
--
-- Members of an array become indexed accesses into that array, with every
-- index rebased against the first component of the matching entry of
-- 'arrayDimensions'.  All other variables use their 'simpleCName'.
cnameVar :: (NodeIdx, StoVar) -> M CExpr
cnameVar (ix,sv) =
  case stoVarName sv of
    InArray x is ->
      do ai <- lookupArray x
         -- Shift each model-level index so that the first valid index of
         -- the dimension maps to element 0 of the generated array.
         let fixIndex i (from,_) = int_lit (i - from)
             name                = arrName (x,ai)
         return $ foldl arr_ix (var name)
                $ zipWith fixIndex is $ arrayDimensions ai
    _ -> return $ var $ simpleCName ix
-- | C identifier of a node: the prefix @v_@ followed by the shown node index.
simpleCName :: NodeIdx -> CIdent
simpleCName nodeIx = ident $ "v_" ++ show nodeIx
-- | C expression for a reference to node @x@.
--
-- * An observed node is replaced by its observed value as a literal.
-- * The node currently being sampled is referred to by its simple name.
-- * A known stochastic variable is referred to through 'cnameVar'.
-- * Anything else falls back to the simple name.
variable :: NodeIdx -> M CExpr
variable x = isObserved x >>= maybe unobserved (return . double_lit)
  where
    byName = return (var (simpleCName x))

    unobserved =
      do samp <- isSampled x
         if samp
           then byName
           else lookupVarMb x >>= maybe byName (\sv -> cnameVar (x,sv))
-- | Translate a model term into a C expression.
--
-- Variables go through 'variable', arrays are named with 'arrName', and
-- constants become double literals.  For applications, all arguments are
-- translated first and then combined according to the operator.
--
-- 'TCase' additionally emits a local helper function (in the current helper
-- module) that switches on its first argument; 'TIx' rebases the index by
-- the first bound of the outermost array dimension.
--
-- Calls 'error' on ill-formed indexing terms.
term :: Term NodeIdx -> M CExpr
term t =
  case t of
    TVar x    -> variable x
    TArr x    -> do ai <- lookupArray x
                    return $ var $ arrName (x,ai)
    TConst x  -> return (double_lit x)
    TApp op ts ->
      do ds@(a : bs) <- mapM term ts
         -- 'b' is lazy: it is only demanded by the binary operators.
         let b : _ = bs
             bin x = parens a <+> text x <+> parens b
         case op of
           TExp -> return $ call (ident "exp") ds
           TLog -> return $ call (ident "log") ds
           TNeg -> return $ char '-' <> parens (head ds)
           TAdd -> return $ bin "+"
           TMul -> return $ bin "*"
           TSub -> return $ bin "-"
           TDiv -> return $ bin "/"
           TPow -> case ts of
                     -- Squaring gets a dedicated helper instead of pow().
                     [_ , TConst 2.0] -> (return $ call (ident "square") [a])
                     _ -> return $ call (ident "pow") ds
           TLogGamma -> return $ call (ident "lgamma") ds
           TCase ->
             do i <- newDetVar
                let name = ident ("case_fun_" ++ show i)
                -- The helper takes the variable being sampled (if any) as
                -- its only parameter.
                args <- (map simpleCName . maybeToList) `fmap` isSampling
                newLocalFunDecl
                  (fun_decl double name [ (double,x) | x <- args ]) [
                    switch (cast int a)
                      (zip [ 0 .. ] (map (return . creturn) bs))
                      (callS (var (ident "crash_out_of_bounds"))
                                                        [ text "__LINE__" ])
                  ]
                return $ call (var name) (map var args)
           TIx ->
             case ts of
               [ arr, ix ] ->
                 do dims <- getArrDimensions arr
                    case dims of
                      (from,_) : _ ->
                        -- Rebase the index so the first valid index maps
                        -- to element 0 of the generated array.
                        do expr <- term (ix - fromIntegral from)
                           return (arr_ix a (cast int expr))
                      _ -> error $ "Type error: attempt to index a non-array."
               _ -> error $ "TIx: Unexpected args: " ++ show ts
-- | Dimensions of the array denoted by a term.
--
-- A 'TArr' yields the dimensions recorded for that array; each enclosing
-- 'TIx' strips the outermost dimension.  Any other term, or indexing past
-- the last dimension, is reported with 'error'.
getArrDimensions :: Term NodeIdx -> M [(Int,Int)]
getArrDimensions (TArr x) = arrayDimensions `fmap` lookupArray x
getArrDimensions (TApp TIx [ a, _ ]) = getArrDimensions a >>= dropOuter
  where
    dropOuter (_ : inner) = return inner
    dropOuter []          = error $ "Type error: attempt to index a non-array."
getArrDimensions _ = error $ "Type error: not an array"
-- | The code-generation monad: a reader over the environment 'R' on top of
-- a state monad carrying the generator state 'S'.
newtype M a = M (ReaderT R (StateT S Id) a) deriving (Functor, Monad)
-- | Read-only environment of the generator.
data R = R { config :: SamplerConf
             -- ^ the sampler configuration being compiled
           , sampling :: Maybe NodeIdx
             -- ^ node set by 'nowSampling' for the enclosed computation;
             -- 'Nothing' initially (see 'runM')
           }
-- | The pieces of one generated C file, in the order 'renderMod' prints them.
data CModule =
  CModule
    { cpp_stuff :: Doc
      -- ^ text printed first (e.g. @#include@ lines added by 'cpp')
    , var_decls :: [Doc]
      -- ^ variable declarations; rendered with a trailing semicolon
    , cpp_funs :: Doc
      -- ^ text printed under the \"Included templates\" heading
    , fun_decls :: [(Doc,Doc)]
      -- ^ functions as (declaration, body) pairs
    }
-- | A module with no preprocessor text, declarations, or functions.
blankMod :: CModule
blankMod = CModule empty [] empty []
-- | Combine two modules field by field, with the first module's content
-- placed before the second's.
mergeCModules :: CModule -> CModule -> CModule
mergeCModules a b =
  CModule (cpp_stuff a $$ cpp_stuff b)
          (var_decls a ++ var_decls b)
          (cpp_funs a $$ cpp_funs b)
          (fun_decls a ++ fun_decls b)
-- | Mutable state of the generator.
data S = S { main_mod :: CModule
             -- ^ the main module, updated through 'updMain'
           , cur_mod :: Maybe CModule
             -- ^ helper module under construction, if inside 'newHelper'
           , helper_mods :: [(NodeIdx, CModule)]
             -- ^ finished helper modules, keyed by the node given to
             -- 'newHelper'
           , cnames :: !Int
             -- ^ next fresh index handed out by 'newDetVar'
           }
-- | Fold every helper module into the main module and clear the helper list.
-- Helper content ends up in front of the existing main-module content.
noHelpers :: S -> S
noHelpers s = s { main_mod = merged, helper_mods = [] }
  where
    merged = foldr (mergeCModules . snd) (main_mod s) (helper_mods s)
-- | The Bayesian graph stored in the configuration.
getGraph :: M BayesianGraph
getGraph = M $ asks (\env -> graph (config env))
-- | Information about array @x@; calls 'error' if the graph has no such array.
lookupArray :: NodeIdx -> M ArrayInfo
lookupArray x = lookupArrayMb x >>= maybe missing return
  where
    missing = error ("Unknown array variable: " ++ show x)
-- | Information about array @x@, or 'Nothing' if the graph has no such array.
lookupArrayMb :: NodeIdx -> M (Maybe ArrayInfo)
lookupArrayMb x = fmap (IM.lookup x . stoArryas) getGraph
-- | The stochastic variable at node @x@, or 'Nothing' if there is none.
lookupVarMb :: NodeIdx -> M (Maybe StoVar)
lookupVarMb x = fmap (IM.lookup x . stoNodes) getGraph
-- | The stochastic variable at node @x@; calls 'error' if there is none.
lookupVar :: NodeIdx -> M StoVar
lookupVar x = lookupVarMb x >>= maybe missing return
  where
    missing = error ("Unknown stochastic variable: " ++ show x)
-- | Run a computation with @x@ recorded as the node being sampled.
nowSampling :: NodeIdx -> M a -> M a
nowSampling x (M body) = M (mapReader markSampled body)
  where
    markSampled env = env { sampling = Just x }
-- | The node currently being sampled, if any.
isSampling :: M (Maybe NodeIdx)
isSampling = M $ asks sampling
-- | Is @x@ the node currently being sampled?
isSampled :: NodeIdx -> M Bool
isSampled x =
  do current <- isSampling
     return (Just x == current)
-- | The observed value configured for a node, if there is one.
isObserved :: NodeIdx -> M (Maybe Double)
isObserved ix = fmap (IM.lookup ix) observed
  where
    observed = M (asks (observe . config))
-- | The initial value configured for a node, if there is one.
isInitialized :: NodeIdx -> M (Maybe Double)
isInitialized ix = fmap (IM.lookup ix) initial
  where
    initial = M (asks (initialize . config))
-- | Hand out a fresh index and advance the counter in the state.
newDetVar :: M NodeIdx
newDetVar = M $
  get >>= \st ->
    let fresh = cnames st
    in set (st { cnames = fresh + 1 }) >> return fresh
-- | Run a computation with a fresh, empty helper module as the current
-- module, then file the finished module under node @i@.
--
-- The current module is reset to 'Nothing' afterwards.  If the computation
-- somehow left no current module behind, the recorded helper is an 'error'
-- naming this function (instead of an anonymous 'fromJust' failure).
newHelper :: NodeIdx -> M a -> M a
newHelper i (M m) = M $
  do s <- get
     set s { cur_mod = Just blankMod }
     a  <- m
     s1 <- get
     let finished = case cur_mod s1 of
                      Just cm -> cm
                      Nothing -> error "BUG: newHelper finished without a module"
     set s1 { cur_mod = Nothing
            , helper_mods = (i, finished) : helper_mods s1
            }
     return a
-- | Apply a change to the helper module under construction.
-- Calls 'error' when there is no current helper module.
updHelper :: (CModule -> CModule) -> M ()
updHelper f = M $ get >>= \st ->
  maybe (error "BUG: updHelper called without a module")
        (\m -> set (st { cur_mod = Just (f m) }))
        (cur_mod st)
-- | Apply a change to the main module.
updMain :: (CModule -> CModule) -> M ()
updMain f = M $ get >>= \st -> set (st { main_mod = f (main_mod st) })
-- | Add a function (declaration plus body statements) to the main module.
newFunDecl :: CFunDecl -> [CStmt] -> M ()
newFunDecl d body = updMain addFun
  where
    addFun m = m { fun_decls = (d, block body) : fun_decls m }
-- | Add a variable declaration to the main module.
newDecl :: CDecl -> M ()
newDecl d = updMain addDecl
  where
    addDecl m = m { var_decls = d : var_decls m }
-- | Append a line of text to the preprocessor section of the main module.
cpp :: String -> M ()
cpp t = updMain appendLine
  where
    appendLine m = m { cpp_stuff = cpp_stuff m $$ text t }
-- | Add a function (declaration plus body statements) to the current
-- helper module.
newLocalFunDecl :: CFunDecl -> [CStmt] -> M ()
newLocalFunDecl d body = updHelper addFun
  where
    addFun m = m { fun_decls = (d, block body) : fun_decls m }
-- | Add a variable declaration, wrapped with 'static', to the current
-- helper module.
newLocalDecl :: CDecl -> M ()
newLocalDecl d = updHelper addDecl
  where
    addDecl m = m { var_decls = static d : var_decls m }
-- | Like @cppFun'@, but takes the line as a plain string.
cppFun :: String -> M ()
cppFun = cppFun' . text
-- | Append a document to the template section of the current helper module.
cppFun' :: Doc -> M ()
cppFun' t = updHelper appendDoc
  where
    appendDoc m = m { cpp_funs = cpp_funs m $$ t }
-- | Run a generator computation against a configuration, returning its
-- result together with the final generator state.
--
-- Fresh names start just above the largest stochastic node index, so they
-- cannot clash with names derived from existing nodes.  A graph with no
-- stochastic nodes starts the counter at 0 instead of failing in
-- 'IM.findMax'.
runM :: SamplerConf -> M a -> (a, S)
runM conf (M m) = runId $ runStateT start $ runReaderT info m
  where
    start = S { main_mod    = blankMod
              , cur_mod     = Nothing
              , helper_mods = []
              , cnames      = firstFree
              }
    info  = R { config = conf, sampling = Nothing }
    nodes = stoNodes $ graph conf
    -- 'IM.findMax' is partial, so the empty graph is handled separately.
    firstFree
      | IM.null nodes = 0
      | otherwise     = fst (IM.findMax nodes) + 1
-- | Lay out a module as one document: preprocessor text, then the variable
-- declarations, function declarations, included templates, and function
-- definitions, each under its own heading comment.
renderMod :: CModule -> Doc
renderMod m =
  cpp_stuff m $$
  section "/* Variable declarations */" (terminated (var_decls m)) $$
  section "/* Function types */" (terminated (map fst (fun_decls m))) $$
  section "/* Included templates */" (cpp_funs m) $$
  section "/* Function definitions */"
          (vcat [ proto $$ body | (proto,body) <- fun_decls m ])
  where
    -- A separator line, the heading, then the section's content.
    section title content = char ' ' $$ text title $$ content
    -- One declaration per line, each followed by a semicolon.
    terminated = vcat . map (\d -> d <> semi)
-- | Turn the final generator state into named output files.
--
-- Always produces @sampler.h@ (extern declarations of the main module's
-- variables) and @sampler.c@ (the main module, preceded by extern
-- declarations for what the helpers define).  When 'split_files' is set,
-- each helper module becomes its own @slice_<node>.c@; otherwise the
-- helpers are folded into the main module first.
renderState :: SamplerConf -> S -> [(FilePath, Doc)]
renderState conf s0 =
  [ ("sampler.h", header), ("sampler.c", mainFile) ] ++ map helperFile helpers
  where
    st | split_files conf = s0
       | otherwise        = noHelpers s0
    helpers = helper_mods st
    mainMod = main_mod st

    terminated = vcat . map (\d -> d <> semi)

    header   = terminated (map extern (var_decls mainMod))
    mainFile = renderMod
                 (mainMod { var_decls = concatMap helperExterns helpers
                                          ++ var_decls mainMod })

    helperFile (i,m) =
      ( "slice_" ++ show i ++ ".c"
      , renderMod (m { cpp_stuff = text "#include \"passage.h\"" $$
                                   text "#include \"sampler.h\"" })
      )

    -- What the main file must declare to call into a helper module.
    helperExterns (h,m) =
      if special_slicers conf
        then [ text ("extern double SLICE(" ++ show h ++ ")(double)")
             , text ("extern double SLICE_TUNE(" ++ show h ++ ")(double)")
             ]
        else map (extern . fst) (fun_decls m)
-- | Build the slicer calls for a variable, using the specialised slicers
-- when the configuration asks for them and the generic ones otherwise.
call_slicer :: (NodeIdx, StoVar) -> M ([CStmt], CExpr,CExpr)
call_slicer x =
  M (asks (special_slicers . config)) >>= \special ->
    (if special then call_special_slicer else call_generic_slicer) x
-- | Build calls to the generic slicing routines for one variable.
--
-- Returns (initialisation statements, sampling expression, tuning
-- expression).  The routine is chosen by the support of the variable's
-- prior.  For 'Real' and 'PosReal' a width variable is declared in the main
-- module and initialised to 1; for the remaining supports the same
-- expression is used for both sampling and tuning.
call_generic_slicer :: (NodeIdx, StoVar) -> M ([CStmt], CExpr,CExpr)
call_generic_slicer (ix,sv) =
  do v <- cnameVar (ix,sv)
     case priSupport (stoVarPrior sv) of
       Real ->
         do newDecl $ var_decl double wid
            let slice = var (ident "slice_real")
                tune = var (ident "tune_slice_real")
            return
              ( [ assign (var wid) (double_lit 1) ]
              , call slice [ llfun, var wid, v ]
              -- tuning receives the address of the width variable
              , call tune [ llfun, addr_of wid, v ]
              )
       PosReal ->
         do newDecl $ var_decl double wid
            let z = int_lit 0
                slice = var (ident "slice_pos_real")
                tune = var (ident "tune_slice_pos_real")
            return
              ( [ assign (var wid) (double_lit 1) ]
              , call slice [ llfun, var wid, z, v ]
              , call tune [ llfun, addr_of wid, z, v ]
              )
       Interval lo hi ->
         do e1 <- term lo
            e2 <- term hi
            let slice = var (ident "slice_real_left_right")
                expr = call slice [ llfun, e1, e2, v ]
            return ([], expr, expr)
       Discrete (Just t) ->
         do e <- term t
            let slice = var (ident "slice_discrete_right")
                expr = call slice [ llfun, e, v ]
            return ([], expr, expr)
       Discrete Nothing ->
         do let slice = var (ident "slice_discrete")
                expr = call slice [ llfun, v ]
            return ([], expr, expr)
  -- Per-node names, written as macro invocations in the generated C.
  where llfun = var $ ident $ "LL_FUN(" ++ show ix ++ ")"
        wid = ident $ "WIDTH(" ++ show ix ++ ")"
-- | Build calls to a slicer specialised for one variable.
--
-- Appends preprocessor lines to the current helper module: @VAR@ is
-- defined as the node index, a template chosen by the support of the prior
-- is included (with @LEFT@ / @RIGHT@ defined around it where applicable),
-- and @VAR@ is undefined again.  Returns (no initialisation statements,
-- sampling call, tuning call); only 'Real' and 'PosReal' use the separate
-- @SLICE_TUNE@ function for tuning.
call_special_slicer :: (NodeIdx, StoVar) -> M ([CStmt], CExpr,CExpr)
call_special_slicer (ix,sv) =
  do v <- cnameVar (ix,sv)
     cppFun ("#define VAR " ++ show ix)
     let fun = var $ ident $ "SLICE(" ++ show ix ++ ")"
         the_tune_fun = var $ ident $ "SLICE_TUNE(" ++ show ix ++ ")"
     -- Each alternative emits its template and picks the tuning function.
     tune_fun <- case priSupport (stoVarPrior sv) of
                   Real ->
                     do cppFun "#include \"templates/slice.c\""
                        return the_tune_fun
                   PosReal ->
                     do cppFun "#define LEFT 0"
                        cppFun "#include \"templates/slice.c\""
                        cppFun "#undef LEFT"
                        return the_tune_fun
                   Interval lo hi ->
                     do e1 <- term lo
                        e2 <- term hi
                        cppFun' (text "#define LEFT " <+> parens e1)
                        cppFun' (text "#define RIGHT " <+> parens e2)
                        cppFun "#include \"templates/slice.c\""
                        cppFun "#undef LEFT"
                        cppFun "#undef RIGHT"
                        return fun
                   Discrete (Just t) ->
                     do e <- term t
                        cppFun' (text "#define RIGHT" <+> parens e)
                        cppFun "#include \"templates/finiteMetropolis.c\""
                        cppFun "#undef RIGHT"
                        return fun
                   Discrete Nothing ->
                     do cppFun "#include \"templates/metropolis_posreal.c\""
                        return fun
     cppFun "#undef VAR"
     return ( []
            , call fun [v]
            , call tune_fun [v]
            )
-- | Order variables by the dependencies of their prior supports, as
-- computed by 'stronglyConnComp'.  Calls 'error' if the supports depend on
-- each other cyclically.
initOrder :: [(NodeIdx, StoVar)] -> M [(NodeIdx, StoVar)]
initOrder ns = fmap ordered getGraph
  where
    ordered bg = map check (stronglyConnComp (map (vertex bg) ns))

    -- A graph vertex: the payload, its key, and the keys it depends on.
    vertex bg n@(ix,v) = (n, ix, uses bg v)

    uses bg = IS.toList . fvsSupport (fvsArray bg) . priSupport . stoVarPrior

    check component =
      case component of
        AcyclicSCC d -> d
        CyclicSCC _  -> error "Cannot initialize: recursive support!"
-- | Statement assigning a default starting value to a variable, chosen by
-- the support of its prior: 0 for 'Real' and 'Discrete', 1 for 'PosReal',
-- and the midpoint of the bounds for 'Interval'.
init_code :: (NodeIdx, StoVar) -> M CStmt
init_code (x,sv) =
  do v <- cnameVar (x,sv)
     i <- case priSupport (stoVarPrior sv) of
            Real           -> return $ double_lit 0
            Discrete _     -> return $ double_lit 0
            PosReal        -> return $ double_lit 1
            -- Midpoint of the interval: lo + (hi - lo) / 2.
            Interval lo hi -> term (lo + (hi - lo) / 2)
     return (assign v i)
-- | Statements storing a given value into a variable's array slot.
-- Variables that are not array members produce no statements.
-- Calls 'error' (via 'lookupVar') if the node is not a stochastic variable.
init_code_initialized :: (NodeIdx, Double) -> M [CStmt]
init_code_initialized (x,d) = lookupVar x >>= store
  where
    store sv =
      case stoVarName sv of
        InArray {} -> fmap (\v -> [assign v (double_lit d)]) (cnameVar (x,sv))
        _          -> return []
-- | Prepare one summand @x * c@ of a log-likelihood.
--
-- When @c@ is a simple term the product is used directly.  Otherwise a
-- fresh local variable is declared in the current helper module, the
-- returned statements assign @c@ to it, and the product refers to that
-- variable instead.  The third component is the set of leaves of both
-- factors.
ll_summand :: (Term NodeIdx, Term NodeIdx)
           -> M ([CStmt], Term NodeIdx, IS.IntSet)
ll_summand (x,c)
  | isSimpleTerm c =
      do bg <- getGraph
         let direct = x * c
         return ([], direct, leaves bg direct)
  | otherwise =
      do bg <- getGraph
         tmp <- newDetVar
         let tmpName = simpleCName tmp
         newLocalDecl $ var_decl double tmpName
         rhs <- term c
         return ( [assign (var tmpName) rhs]
                , x * tvar tmp
                , leaves bg c `IS.union` leaves bg x
                )
  where
    leaves bg = leavesOfTerm (fvsArray bg)
-- | Generated statements for one stochastic variable.
data StoVarCode =
  StoVarCode
    { tuneCode :: [CStmt]
      -- ^ statements run during tuning
    , sliceCode :: [CStmt]
      -- ^ statements run during sampling
    , locality :: (Int,[Int])
      -- ^ sort key used by 'genParGroups'; for array members it is the
      -- pair taken from 'InArray' (see 'sto_var')
    }
-- | Generate everything needed to sample one stochastic variable, inside a
-- new helper module keyed by the variable's node.
--
-- Returns the variable's initialisation statements, and a triple of the
-- node, its tuning/sampling code, and the union of the leaves reported by
-- 'll_summand' for its log-likelihood summands.
sto_var :: (NodeIdx, StoVar) -> M ( [CStmt]
                                  , (NodeIdx, StoVarCode, IS.IntSet)
                                  )
sto_var (ix,sv) = newHelper ix $
  do let xParam = simpleCName ix
         iname = ident ("INIT_DET_VARS(" ++ show ix ++ ")")
         llname = ident $ "LL_FUN(" ++ show ix ++ ")"
     -- Non-array variables get their own global declaration.
     loc <- case stoVarName sv of
              InArray aix ixes -> return (aix,ixes)
              _ -> do newDecl $ var_decl double xParam
                      return (1,[0])
     (is,ts,vs) <- unzip3 `fmap` mapM ll_summand (M.toList (stoPostDistLL sv))
     -- While translating the sum, references to this variable use the
     -- simple name, which is also the parameter name of the LL function.
     expr <- nowSampling ix (term (sum ts))
     -- Only emit an init function when some summand needed a temporary.
     init_dets <- case concat is of
                    [] -> return []
                    have_dets ->
                      do newLocalFunDecl (fun_decl void iname []) have_dets
                         return [ callS (var iname) [] ]
     newLocalFunDecl (fun_decl double llname [(double,xParam)]) [ creturn expr ]
     x <- cnameVar (ix,sv)
     (initW, sliceExpr,sliceTuneExpr) <- call_slicer (ix,sv)
     ic <- init_code (ix,sv)
     return ( initW ++ [ic]
            , ( ix
              , StoVarCode
                  { tuneCode = init_dets ++ [ assign x sliceTuneExpr ]
                  , sliceCode = init_dets ++ [ assign x sliceExpr ]
                  , locality = loc
                  }
              , IS.unions vs
              )
            )
-- | Everything the generator needs to know to emit a sampler.
data SamplerConf = SamplerConf
  { graph :: BayesianGraph
    -- ^ the model being compiled
  , sampleNum :: Int
    -- ^ default for @number_of_samples@ in the generated code
  , itsPerSample :: Int
    -- ^ default for @steps_per_sample@ in the generated code
  , warmup :: Int
    -- ^ default for @warm_up_steps@ in the generated code
  , thread_num :: Int
    -- ^ default for @num_threads@; also the number of thread functions
  , seed :: [Int]
    -- ^ seeds; written to @seeds[]@ in reverse order, with @have_seed@
    -- set to their count
  , monitor :: [(String, Term NodeIdx)]
    -- ^ labelled terms printed by the generated sampler
  , observe :: IM.IntMap Double
    -- ^ observed values, keyed by node
  , initialize :: IM.IntMap Double
    -- ^ initial values, keyed by node
  , special_slicers :: Bool
    -- ^ use 'call_special_slicer' instead of 'call_generic_slicer'
  , split_files :: Bool
    -- ^ render helper modules as separate files (see 'renderState')
  }
-- | Declare the C array backing a model array in the main module.
-- Each dimension given as a pair @(x,y)@ is declared with @y - x + 1@
-- elements, i.e. both bounds are inclusive.
declareArr :: (NodeIdx, ArrayInfo) -> M ()
declareArr (ix,i) =
  newDecl $ array_decl double (arrName (ix,i)) (map size (arrayDimensions i))
  where
    size (x,y) = y - x + 1
-- | C identifier of an array: the prefix @a_@ followed by the shown node
-- index.  The array information itself is not consulted.
arrName :: (NodeIdx,ArrayInfo) -> CIdent
arrName = ident . ("a_" ++) . show . fst
-- | Distribute groups of variables over threads.
--
-- Each input group is sorted by 'locality' and cut into at most @cpus@
-- chunks of equal size (rounded up), padded with empty chunks to exactly
-- @cpus@ entries.  Every chunk becomes a block that starts with an
-- @omp barrier@ pragma followed by the statements selected by @which@.
-- The result has one entry per thread: the concatenation of that thread's
-- block from every group.
genParGroups :: Int -> (StoVarCode -> [CStmt]) -> [[StoVarCode]] -> [[CStmt]]
genParGroups cpus which = map concat . transpose . map threadBlocks
  where
    -- Ceiling of len / cpus.
    entries_per_thread len = (len + cpus - 1) `div` cpus

    threadBlocks vs = map makeBlock
                    $ addBlanks cpus
                    $ chunks (entries_per_thread len)
                    $ sortBy (compare `on` locality) vs
      where len = length vs

    makeBlock vs = pragma "omp barrier" : concatMap which vs

    -- Pad with empty chunks so that n chunks come out in total; the
    -- remaining count is forced as we go.
    addBlanks n []       = replicate n []
    addBlanks n (x : xs) = x : seq m (addBlanks m xs)
      where m = n - 1
-- | Cut a list into consecutive pieces of length @n@; the last piece may be
-- shorter.  The empty list yields a single empty piece.
chunks :: Int -> [a] -> [[a]]
chunks n xs =
  case splitAt n xs of
    (piece, rest) -> piece : (if null rest then [] else chunks n rest)
-- | Emit the function run by thread @n@ and return the statements that
-- invoke it as an OpenMP section.
--
-- The function runs a tuning loop followed by a sampling loop, with an
-- @omp barrier@ pragma after each inner loop.  All output (headers,
-- progress, monitored values) is produced only by thread 0.
--
-- @monitor_code@ pairs, for each monitored term, the statement printing its
-- header with the statement printing its current value.
genThread :: [(CStmt,CStmt)] -> Int -> ([CStmt], [CStmt]) -> M [CStmt]
genThread monitor_code n (tune_code,sample_code) =
  do newFunDecl (fun_decl void name []) $
       -- loop counters
       [ var_decl unsigned_long (ident "i") <> semi
       , var_decl unsigned_long (ident "j") <> semi
       ]
       ++
       -- header row, then a note on stderr
       ifMaster ( map fst monitor_code ++
                  [ nl, toStdErr [ string_lit "Tuning width parameters.\n" ] ]
                )
       ++
       [ text ("for (i = 0; i < warm_up_steps; ++i)")
           $$ nest 2 (block tune_code)
       , barrier
       ]
       ++
       ifMaster [ toStdErr [string_lit "Sampling.\n"] ]
       ++
       [ text ("for (i = 0; i < number_of_samples; ++i)") $$
           nest 2 (block $
                     ifMaster [ ppProg ]
                     ++
                     [ text ("for (j = 0; j < steps_per_sample; ++j)")
                         $$ nest 2 (block sample_code)
                     , barrier
                     ]
                     ++
                     -- one output row per sample
                     ifMaster (printRowLabel : map snd monitor_code ++ [ nl ])
                  )
       ]
     return [ pragma "omp section"
            , callS name []
            ]
  where
    name = ident ("thread_" ++ show n)
    -- Keep the statements only for thread 0.
    ifMaster xs = if n == 0 then xs else []
    barrier = pragma "omp barrier"
    toStdErr xs = callS (var (ident "fprintf")) (var (ident "stderr") : xs)
    toStdOut xs = callS (var (ident "printf")) xs
    ppProg = callS (ident "progress") [ var (ident "i") ]
    printRowLabel = toStdOut [ string_lit "%lu", ident "i" ]
    nl = toStdOut [ string_lit "\n" ]
-- | Generate the C sources of a sampler for the given configuration, as a
-- list of (file name, contents) pairs.
--
-- The main module receives the includes, the array declarations, and the
-- functions @set_defaults@, @init_vars@, one @thread_<n>@ per thread, and
-- @sampler@.  Observed nodes are not sampled.
--
-- @init_vars@ first stores observed values, then the default starting
-- values, and finally the explicitly configured initial values
-- ('initialize'), so that the latter override the defaults.  As with
-- observed values, only array members produce such assignments (see
-- 'init_code_initialized').
gen_c :: SamplerConf -> [(FilePath, Doc)]
gen_c conf = renderState conf $ snd $ runM conf $
  do let bg = graph conf
     cpp "#include <math.h>"
     cpp "#include <stdio.h>"
     cpp "#include <omp.h>"
     cpp "#include \"passage.h\""
     mapM_ declareArr $ IM.toList $ stoArryas bg

     let observedVars = IM.keysSet (observe conf)
         sampledNodes = filter (not . (`IS.member` observedVars) . fst)
                      $ IM.toList $ stoNodes bg
     (ins,deps) <- unzip `fmap` (mapM sto_var =<< initOrder sampledNodes)

     -- Observed nodes are constants, so they do not constrain which
     -- variables may be updated together.
     let dropObserved (x,y,zs) =
           (x,y, IS.filter (not . (`IS.member` observedVars)) zs)
         par_groups = map (map snd) $ groupByColor $ map dropObserved deps

     obs_ins  <- mapM init_code_initialized $ IM.toList $ observe conf
     init_ins <- mapM init_code_initialized $ IM.toList $ initialize conf

     let cpus = thread_num conf
     newFunDecl (fun_decl void (ident "set_defaults") []) $
       [ assign (var (ident "number_of_samples")) $ int_lit $ sampleNum conf
       , assign (var (ident "steps_per_sample")) $ int_lit $ itsPerSample conf
       , assign (var (ident "warm_up_steps")) $ int_lit $ warmup conf
       , assign (var (ident "num_threads")) $ int_lit cpus
       , assign (var (ident "have_seed")) $ int_lit $ length $ seed conf
       ] ++ [ assign (arr_ix (var (ident "seeds")) (int_lit n)) (int_lit v)
            | (n,v) <- zip [ 0 .. ] (reverse (seed conf)) ]

     -- Explicit initial values go last so they win over the defaults.
     newFunDecl (fun_decl void (ident "init_vars") [])
       $ concat $ obs_ins ++ ins ++ init_ins

     monitor_code <- forM (monitor conf) $ \(lab,x) ->
       do expr <- term x
          return ( callS (ident "printf") [ string_lit ("\t" ++ lab) ]
                 , callS (ident "printf") [ string_lit ("\t%f"), expr ]
                 )

     let tune_codes  = genParGroups cpus tuneCode par_groups
         slice_codes = genParGroups cpus sliceCode par_groups
     threads <- zipWithM (genThread monitor_code)
                         [ 0 .. ] (zip tune_codes slice_codes)

     newFunDecl (fun_decl void (ident "sampler") []) $
       [ pragma "omp sections"
       , block (concat threads)
       ]