{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Language.Passage.Graph2C where

import Language.Passage.Utils hiding (double,int)
import Language.Passage.Term hiding (bin)
import Language.Passage.AST
import Language.Passage.Lang.C
import Language.Passage.GraphColor(groupByColor)


import qualified Data.Map as M
import qualified Data.IntMap as IM
import qualified Data.IntSet as IS
import Data.Maybe(maybeToList,fromJust)
import Data.List(sortBy, transpose)
import Data.Function(on)
import MonadLib (ReaderT, StateT, Id
                , runId, runStateT, runReaderT
                , get, set, asks, mapReader
                , forM
                , zipWithM
                )

import Data.Graph(SCC(..))
import Data.Graph.SCC


--------------------------------------------------------------------------------
-- Compilation of expressions
--------------------------------------------------------------------------------

cnameVar :: (NodeIdx, StoVar) -> M CExpr
cnameVar (ix,sv) =
  case stoVarName sv of
    InArray x is ->
      do ai <- lookupArray x
         let fixIndex i (from,_) = int_lit (i - from)
             name = arrName (x,ai)
         return $ foldl arr_ix (var name)
                $ zipWith fixIndex is $ arrayDimensions ai
    _ -> return $ var $ simpleCName ix


simpleCName :: NodeIdx -> CIdent
simpleCName x = ident ("v_" ++ show x)

variable :: NodeIdx -> M CExpr
variable x =

  -- Is this an observed variable?
  do isObs <- isObserved x
     case isObs of
       Just v -> return (double_lit v)      -- Yes, just use the known value.
       Nothing ->

         {- Are we compiling within the LL_FUN for this variable?
            When we generate the LL_FUN for a stoachastic variable,
            we always use it's simple name, even if the variable is stored
            in an array in the long run.   The reason for this is that in
            the LL_FUN, the variable is passed as an argument. -}
         do samp <- isSampled x
            if samp
              then return $ var $ simpleCName x
              else

                -- OK, perhaps we have an ordinary stochastic variable?
                do mbsv <- lookupVarMb x
                   case mbsv of
                     Just sv -> cnameVar (x,sv)

                     -- Hmm, we don't know about this variable.
                     -- The variable must refer to a deterministic node
                     -- generated to factor repeated compution out
                     -- of an LL_FUN.
                     Nothing -> return $ var $ simpleCName x

term :: Term NodeIdx -> M CExpr
term t =
  case t of
    TVar x -> variable x
    TArr x -> do ai <- lookupArray x
                 return $ var $ arrName (x,ai)
    TConst x -> return (double_lit x)
    TApp op ts ->
      do ds@(a : bs) <- mapM term ts
         let b : _ = bs
             bin x = parens a <+> text x <+> parens b
         case op of
           TExp      -> return $ call (ident "exp") ds
           TLog      -> return $ call (ident "log") ds
           TNeg      -> return $ char '-' <> parens (head ds)
           TAdd      -> return $ bin "+"
           TMul      -> return $ bin "*"
           TSub      -> return $ bin "-"
           TDiv      -> return $ bin "/"
           TPow      -> case ts of
             [_ , TConst 2.0] -> (return $ call (ident "square") [a])
             _                -> return $ call (ident "pow") ds
           TLogGamma -> return $ call (ident "lgamma") ds

           TCase ->
             do i <- newDetVar -- Just used as a new name
                let name = ident ("case_fun_" ++ show i)

                -- if we are in the LL function for some variable,
                -- we have to pass the sampled variable to the "case" function.
                args <- (map simpleCName . maybeToList) `fmap` isSampling

                newLocalFunDecl
                  (fun_decl double name [ (double,x) | x <- args ]) [
                    switch (cast int a)
                      (zip [ 0 .. ] (map (return . creturn) bs))
                      (callS (var (ident "crash_out_of_bounds"))
                                                [ text "__LINE__" ])
                  ]
                return $ call (var name) (map var args)

           TIx       ->
             case ts of
               [ arr, ix ] ->
                  do dims <- getArrDimensions arr
                     case dims of
                       (from,_) : _ ->
                          do expr <- term (ix - fromIntegral from)
                             return (arr_ix a (cast int expr))
                       _ -> error $ "Type error: attempt ro index non an array."
               _ -> error $ "TIx: Unexpected args: " ++ show ts

getArrDimensions :: Term NodeIdx -> M [(Int,Int)]
getArrDimensions t =
  case t of
    TArr x -> arrayDimensions `fmap` lookupArray x
    TApp TIx [ a, _ ] ->
      do ds <- getArrDimensions a
         case ds of
           _ : ds1 -> return ds1
           [] -> error $ "Type error: attempt to index a non-array."
    _ -> error $ "Type error: not an array"

--------------------------------------------------------------------------------

newtype M a = M (ReaderT R (StateT S Id) a) deriving (Functor, Monad)

data R  = R { config   :: SamplerConf
            , sampling :: Maybe NodeIdx
            }

data CModule =
  CModule
    { cpp_stuff :: Doc          -- ^ Includes, #define, etc.
    , var_decls :: [Doc]        -- ^ Variable declarations
    , cpp_funs  :: Doc          -- ^ #included templtes
    , fun_decls :: [(Doc,Doc)]  -- ^ Function declarations: decl, body
    }

blankMod :: CModule
blankMod = CModule { cpp_stuff = empty
                   , var_decls = []
                   , cpp_funs = empty
                   , fun_decls = []
                   }

-- XXX: Watch out with the ++ing here...
mergeCModules :: CModule -> CModule -> CModule
mergeCModules m1 m2 =
  CModule { cpp_stuff = cpp_stuff m1 $$ cpp_stuff m2
          , var_decls = var_decls m1 ++ var_decls m2
          , cpp_funs  = cpp_funs m1 $$ cpp_funs m2
          , fun_decls = fun_decls m1 ++ fun_decls m2
          }



data S  = S { main_mod    :: CModule
            , cur_mod     :: Maybe CModule
            , helper_mods :: [(NodeIdx, CModule)]
            , cnames      :: !Int         -- ^ Name supply
            }

noHelpers :: S -> S
noHelpers s = s { main_mod = foldr mergeCModules (main_mod s)
                           $ map snd (helper_mods s)
                , helper_mods = []
                }

getGraph :: M BayesianGraph
getGraph = M (asks (graph . config))

lookupArray :: NodeIdx -> M ArrayInfo
lookupArray x =
  do mb <- lookupArrayMb x
     case mb of
       Just a  -> return a
       Nothing -> error ("Unknown array variable: " ++ show x)


lookupArrayMb :: NodeIdx -> M (Maybe ArrayInfo)
lookupArrayMb x =
  do g <- getGraph
     return (IM.lookup x (stoArryas g))

lookupVarMb :: NodeIdx -> M (Maybe StoVar)
lookupVarMb x =
  do g <- getGraph
     return (IM.lookup x (stoNodes g))

lookupVar :: NodeIdx -> M StoVar
lookupVar x =
  do mb <- lookupVarMb x
     case mb of
       Just sv -> return sv
       Nothing -> error ("Unknown stochastic variable: " ++ show x)

nowSampling :: NodeIdx -> M a -> M a
nowSampling x (M a) = M (mapReader (\i -> i { sampling = Just x }) a)

isSampling :: M (Maybe NodeIdx)
isSampling = M (asks sampling)

isSampled :: NodeIdx -> M Bool
isSampled x = (Just x ==) `fmap` isSampling

isObserved :: NodeIdx -> M (Maybe Double)
isObserved ix = IM.lookup ix `fmap` M (asks (observe . config))

isInitialized :: NodeIdx -> M (Maybe Double)
isInitialized ix = IM.lookup ix `fmap` M (asks (initialize . config))

newDetVar :: M NodeIdx
newDetVar = M $
  do s <- get
     let i = cnames s
     set s { cnames = i + 1 }
     return i

newHelper :: NodeIdx -> M a -> M a
newHelper i (M m) = M $
  do s <- get
     set s { cur_mod = Just blankMod }
     a <- m
     s1 <- get
     set s1 { cur_mod = Nothing
            , helper_mods = (i, fromJust (cur_mod s1)) : helper_mods s1
            }
     return a

updHelper :: (CModule -> CModule) -> M ()
updHelper f = M $
  do s <- get
     case cur_mod s of
       Nothing -> error "BUG: updHelper called without a module"
       Just m ->
         set s { cur_mod = Just (f m) }

updMain :: (CModule -> CModule) -> M ()
updMain f = M $
  do s <- get
     set s { main_mod = f (main_mod s) }


-- add a new function to the main module.
newFunDecl :: CFunDecl -> [CStmt] -> M ()
newFunDecl d body =
  updMain $ \m -> m { fun_decls = (d, block body) : fun_decls m }

-- add a new declaration to the main module.
newDecl :: CDecl -> M ()
newDecl d = updMain $ \m -> m { var_decls = d : var_decls m }


-- Add "cpp" includes to the main module
cpp :: String -> M ()
cpp t = updMain $ \m -> m { cpp_stuff = cpp_stuff m $$ text t }




-- add a new function to the current helper module
newLocalFunDecl :: CFunDecl -> [CStmt] -> M ()
newLocalFunDecl d body = updHelper $ \m ->
  m { fun_decls = (d, block body) : fun_decls m }

-- add a new static variable to the current helper module
newLocalDecl :: CDecl -> M ()
newLocalDecl d = updHelper $ \m ->
  m { var_decls = static d : var_decls m }


cppFun :: String -> M ()
cppFun t = cppFun' (text t)

-- Add "#include function" to the current helper module
cppFun' :: Doc -> M ()
cppFun' t = updHelper $ \m ->
  m { cpp_funs = cpp_funs m $$ t }



runM :: SamplerConf -> M a -> (a, S)
runM conf (M m) = runId $ runStateT start $ runReaderT info m
  where start = S { main_mod    = blankMod
                  , cur_mod     = Nothing
                  , helper_mods = []
                  , cnames      = maxNode + 1
                  }
        info  = R { config = conf, sampling = Nothing }
        (maxNode,_) = IM.findMax $ stoNodes $ graph conf


renderMod :: CModule -> Doc
renderMod m =
  cpp_stuff m $$
  char ' '    $$ text "/* Variable declarations */" $$
  decls (var_decls m) $$
  char ' '    $$ text "/* Function types */" $$
  decls (map fst (fun_decls m)) $$
  char ' '    $$ text "/* Included templates */" $$
  cpp_funs m $$
  char ' '    $$ text "/* Function definitions */" $$
  vcat [ d $$ b | (d,b) <- fun_decls m ]

  where decls = vcat . map (\d -> d <> semi)


renderState :: SamplerConf -> S -> [(FilePath, Doc)]
renderState conf s0 = ("sampler.h", hdr)
                    : ("sampler.c", main)
                    : map helper (helper_mods s)
  where
  s       = if split_files conf then s0 else noHelpers s0
  mm      = main_mod s
  hdr     = decls (map extern (var_decls mm))
  main    = renderMod mm {
              var_decls = concatMap extern_helper (helper_mods s) ++

                                                              var_decls mm }
  helper (i,m) =
    ( "slice_" ++ show i ++ ".c"
    , renderMod $ m { cpp_stuff = text "#include \"passage.h\"" $$
                                  text "#include \"sampler.h\"" }
    )

  extern_helper (h,m)
    | special_slicers conf
       = [ text ("extern double SLICE(" ++ show h ++ ")(double)")
         , text ("extern double SLICE_TUNE(" ++ show h ++ ")(double)")
         ]
    | otherwise = map (extern . fst) (fun_decls m)

  decls   = vcat . map (\d -> d <> semi)



--------------------------------------------------------------------------------

call_slicer :: (NodeIdx, StoVar) -> M ([CStmt], CExpr,CExpr)
call_slicer x =
  do special <- M (asks (special_slicers . config))
     if special then call_special_slicer x else call_generic_slicer x



call_generic_slicer :: (NodeIdx, StoVar) -> M ([CStmt], CExpr,CExpr)
call_generic_slicer (ix,sv) =
  do v <- cnameVar (ix,sv)
     case priSupport (stoVarPrior sv) of
       Real ->
         do newDecl $ var_decl double wid
            let slice = var (ident "slice_real")
                tune  = var (ident "tune_slice_real")
            return
              ( [ assign (var wid) (double_lit 1) ]
              , call slice [ llfun, var wid,     v ]
              , call tune  [ llfun, addr_of wid, v ]
              )

       PosReal ->
         do newDecl $ var_decl double wid
            let z = int_lit 0
                slice = var (ident "slice_pos_real")
                tune  = var (ident "tune_slice_pos_real")
            return
              ( [ assign (var wid) (double_lit 1) ]
              , call slice [ llfun, var wid,     z, v ]
              , call tune  [ llfun, addr_of wid, z, v ]
              )

       Interval lo hi ->
         do e1 <- term lo
            e2 <- term hi
            let slice = var (ident "slice_real_left_right")
                expr = call slice [ llfun, e1, e2, v ]
            return ([], expr, expr)

       Discrete (Just t) ->
         do e <- term t
            let slice = var (ident "slice_discrete_right")
                expr = call slice [ llfun, e, v ]
            return ([], expr, expr)

       Discrete Nothing ->
         do let slice = var (ident "slice_discrete")
                expr  = call slice [ llfun, v ]
            return ([], expr, expr)

  where llfun = var $ ident $ "LL_FUN(" ++ show ix ++ ")"
        wid   = ident $ "WIDTH(" ++ show ix ++ ")"


call_special_slicer ::  (NodeIdx, StoVar) -> M ([CStmt], CExpr,CExpr)
call_special_slicer (ix,sv) =
  do v <- cnameVar (ix,sv)
     cppFun ("#define VAR " ++ show ix)

     let fun      = var $ ident $ "SLICE(" ++ show ix ++ ")"
         the_tune_fun = var $ ident $ "SLICE_TUNE(" ++ show ix ++ ")"

     tune_fun <- case priSupport (stoVarPrior sv) of
       Real ->
         do cppFun "#include \"templates/slice.c\""
            return the_tune_fun

       PosReal ->
         do cppFun "#define LEFT 0"
            cppFun "#include \"templates/slice.c\""
            cppFun "#undef LEFT"
            return the_tune_fun

       Interval lo hi ->
         do e1 <- term lo
            e2 <- term hi
            cppFun' (text "#define LEFT " <+> parens e1)
            cppFun' (text "#define RIGHT " <+> parens e2)
            cppFun "#include \"templates/slice.c\""
            cppFun "#undef LEFT"
            cppFun "#undef RIGHT"
            return fun

       Discrete (Just t) ->
         do e <- term t
            cppFun' (text "#define RIGHT" <+> parens e)
            cppFun "#include \"templates/finiteMetropolis.c\""
            cppFun "#undef RIGHT"
            return fun

       Discrete Nothing ->
         do cppFun "#include \"templates/metropolis_posreal.c\""
            return fun

     cppFun "#undef VAR"
     return ( []
            , call fun [v]
            , call tune_fun [v]
            )



initOrder :: [(NodeIdx, StoVar)] -> M [(NodeIdx, StoVar)]
initOrder ns =
  do bg <- getGraph
     return $ map check $ stronglyConnComp [ (n,ix,uses bg v) | n@(ix,v) <- ns ]
  where
  uses bg = IS.toList . fvsSupport (fvsArray bg) . priSupport . stoVarPrior

  check (AcyclicSCC d)  = d
  check (CyclicSCC _)   = error "Cannot initialize: recursive support!"

init_code :: (NodeIdx, StoVar) -> M CStmt
init_code (x,sv) =
  do v <- cnameVar (x,sv)
     i <- case priSupport (stoVarPrior sv) of
           Real       -> return $ double_lit 0
           Discrete _ -> return $ double_lit 0
           PosReal    -> return $ double_lit 1
           Interval lo hi -> term (lo + (hi - lo) / 2)
            -- duplicates lo but, hopefully, this does not matter too much

     return (assign v i)

{- If an observed variable is stored in an array, then we need to initialize
   the corresponding entry in the array with the observed value.  The reason
   for this is that there may be expressions of the form: a[i], with "a"
   begin an observed array, and "i" which is not statically known.
-}
init_code_initialized :: (NodeIdx, Double) -> M [CStmt]
init_code_initialized (x,d) =
  do sv <- lookupVar x
     case stoVarName sv of
       InArray {} ->
         do v <- cnameVar (x,sv)
            return [assign v (double_lit d)]
       _ -> return []

--------------------------------------------------------------------------------

ll_summand :: (Term NodeIdx, Term NodeIdx)
           -> M ([CStmt], Term NodeIdx, IS.IntSet)
ll_summand (x,c) =
  do bg <- getGraph
     if isSimpleTerm c
       then return ([], x * c, varsOf bg (x * c))
       else do c1 <- newDetVar
               let c2 = simpleCName c1
               newLocalDecl $ var_decl double c2
               expr <- term c
               return ( [assign (var c2) expr]
                      , x * tvar c1
                      , varsOf bg c `IS.union` varsOf bg x
                      )
  where
  varsOf bg = leavesOfTerm (fvsArray bg)

data StoVarCode =
  StoVarCode
    { tuneCode  :: [CStmt]
    , sliceCode :: [CStmt]
    , locality  :: (Int,[Int])  -- array number, indexes.
                                -- clobals would be a 1-dim array
                                -- call "-1".  (XXX)
    }

sto_var :: (NodeIdx, StoVar) -> M ( [CStmt] -- init code
                                  , (NodeIdx, StoVarCode, IS.IntSet)
                                  )
sto_var (ix,sv) = newHelper ix $
  do let xParam = simpleCName ix
         iname  = ident ("INIT_DET_VARS(" ++ show ix ++ ")")
         llname = ident $ "LL_FUN(" ++ show ix ++ ")"

     -- Here we compute a "locality" for the variable.
     -- This is useful when we group work by thread because
     -- we prefer to put close updates together.
     loc <- case stoVarName sv of
              InArray aix ixes -> return (aix,ixes)
              _ -> do newDecl $ var_decl double xParam
                      return (-1,[0])   -- XXX: count which vars are close

     (is,ts,vs) <- unzip3 `fmap` mapM ll_summand (M.toList (stoPostDistLL sv))
     expr <- nowSampling ix (term (sum ts))
     init_dets <- case concat is of
                    [] -> return []
                    have_dets ->
                      do newLocalFunDecl (fun_decl void iname []) have_dets
                         return [ callS (var iname) [] ]
     newLocalFunDecl (fun_decl double llname [(double,xParam)]) [ creturn expr ]
     x    <- cnameVar (ix,sv)
     (initW, sliceExpr,sliceTuneExpr) <- call_slicer (ix,sv)
     ic   <- init_code (ix,sv)
     return ( initW ++ [ic]  -- initializaztion code
            , ( ix
              , StoVarCode
                  { tuneCode = init_dets ++  [ assign x sliceTuneExpr ]
                  , sliceCode = init_dets ++ [ assign x sliceExpr ]
                  , locality  = loc
                  }
              , IS.unions vs                            -- sto. var. deps.
              )
            )


--------------------------------------------------------------------------------

data SamplerConf = SamplerConf
  { graph         :: BayesianGraph
  , sampleNum     :: Int
  , itsPerSample  :: Int
  , warmup        :: Int
  , thread_num    :: Int
  , seed          :: [Int]
  , monitor       :: [(String, Term NodeIdx)]
  , observe       :: IM.IntMap Double  -- Map node indexes to observed values.
  , initialize    :: IM.IntMap Double  -- Map node indexes to initialized values.
  , special_slicers :: Bool     -- Generate a custom slicer per variable?
  , split_files     :: Bool     -- Make one file per variable?
  }

declareArr :: (NodeIdx, ArrayInfo) -> M ()
declareArr (ix,i) =
  newDecl $ array_decl double (arrName (ix,i)) (map size (arrayDimensions i))
    where size (x,y) = y - x + 1

arrName :: (NodeIdx,ArrayInfo) -> CIdent
arrName (x,_) = ident ("a_" ++ show x)

{-
genParGroups :: Int -> (StoVarCode -> [CStmt]) -> [[StoVarCode]] -> [CStmt]
genParGroups cpus which xs = concatMap makeSections xs
  where
  entries_per_thread len  = (len + cpus - 1) `div` cpus

  makeSections vs = [ pragma "omp sections"
                    , block $ concatMap makeSection
                            $ chunks (entries_per_thread len)
                            $ sortBy (compare `on` locality) vs
                    ]
    where len = length vs

  makeSection vs  = [ pragma "omp section"
                    , block (concatMap which vs)
                    ]

-}

genParGroups :: Int -> (StoVarCode -> [CStmt]) -> [[StoVarCode]] -> [[CStmt]]
genParGroups cpus which = map concat . transpose . map threadBlocks

  where
  entries_per_thread len  = (len + cpus - 1) `div` cpus

  -- Allocate a list of indipendent statements to different threads.
  -- Each blocks start with a barrier
  threadBlocks vs = map makeBlock
                  $ addBlanks cpus
                  $ chunks (entries_per_thread len)
                  $ sortBy (compare `on` locality) vs
    where len = length vs

  makeBlock vs  = pragma "omp barrier" : concatMap which vs

  -- If there is not enough work for all threads, we insert
  -- an empty list, so that we still get a barrier, otherwise
  -- things get our of sync.
  addBlanks n []        = replicate n []
  addBlanks n (x : xs)  = x : seq m (addBlanks m xs)
    where m = n - 1



-- Split a list into chunks of the given length.
-- If we run out of elements we make empty lists.
chunks :: Int -> [a] -> [[a]]
chunks n xs = case splitAt n xs of
                (as,bs) -> as : case bs of
                                  [] -> []
                                  _  -> chunks n bs



genThread :: [(CStmt,CStmt)] -> Int -> ([CStmt], [CStmt]) -> M [CStmt]
genThread monitor_code n (tune_code,sample_code) =
  do newFunDecl (fun_decl void name []) $
       [ var_decl unsigned_long (ident "i") <> semi
       , var_decl unsigned_long (ident "j") <> semi
       ]

       ++
       ifMaster ( map fst monitor_code ++
                  [ nl, toStdErr [ string_lit "Tuning width parameters.\n" ] ]
                )

       ++
       [ text ("for (i = 0; i < warm_up_steps; ++i)")
           $$ nest 2 (block tune_code)
       , barrier
       ]

       ++
       ifMaster [ toStdErr [string_lit "Sampling.\n"] ]

       ++
       [ text ("for (i = 0; i < number_of_samples; ++i)") $$
           nest 2 (block $
               ifMaster [ ppProg ]

               ++
               [ text ("for (j = 0; j < steps_per_sample; ++j)")
                   $$ nest 2 (block sample_code)
               , barrier
               ]

               ++
               ifMaster (printRowLabel : map snd monitor_code ++ [ nl ])
               )
       ]
     return [ pragma "omp section"
            , callS name []
            ]

  where
  name        = ident ("thread_" ++ show n)
  ifMaster xs = if n == 0 then xs else []
  barrier     = pragma "omp barrier"
  toStdErr xs = callS (var (ident "fprintf")) (var (ident "stderr") : xs)
  toStdOut xs = callS (var (ident "printf")) xs
  ppProg      = callS (ident "progress") [ var (ident "i") ]

  printRowLabel         = toStdOut [ string_lit "%lu", ident "i" ]
  nl                    = toStdOut [ string_lit "\n" ]







gen_c :: SamplerConf -> [(FilePath, Doc)]
gen_c conf = renderState conf $ snd $ runM conf $
  do let bg = graph conf
     cpp "#include <math.h>"
     cpp "#include <stdio.h>"
     cpp "#include <omp.h>"
     cpp "#include \"passage.h\""

     mapM_ declareArr $ IM.toList $ stoArryas bg

     -- We generate sampling code only for stochastic variables that
     -- are not observed:
     let observedVars = IM.keysSet (observe conf)
         sampledNodes = filter (not . (`IS.member` observedVars) . fst)
                      $ IM.toList $ stoNodes bg
     (ins,deps) <- unzip `fmap` (mapM sto_var =<< initOrder sampledNodes)
     let dropObserved (x,y,zs) =
            (x,y, IS.filter (not . (`IS.member` observedVars)) zs)
         par_groups = map (map snd) $ groupByColor $ map dropObserved deps

     -- Observed stochastic variables just get initialized once.
     -- Variables that are not in an array don't even need to be initialized
     -- but it is important the we initialize arrays, because of expressions
     -- a[i], where "a" is observed but "i" is not.
     obs_ins  <- mapM init_code_initialized $ IM.toList $ observe conf

     init_ins <- mapM init_code_initialized $ IM.toList $ initialize conf

     let cpus = thread_num conf

     newFunDecl (fun_decl void (ident "set_defaults") []) $
        [ assign (var (ident "number_of_samples")) $ int_lit $ sampleNum conf
        , assign (var (ident "steps_per_sample"))  $ int_lit $ itsPerSample conf
        , assign (var (ident "warm_up_steps"))     $ int_lit $ warmup conf
        , assign (var (ident "num_threads"))       $ int_lit cpus
        , assign (var (ident "have_seed"))        $ int_lit $ length $ seed conf
        ] ++ [ assign (arr_ix (var (ident "seeds")) (int_lit n)) (int_lit v)
                  | (n,v) <- zip [ 0 .. ] (reverse (seed conf)) ]

     newFunDecl (fun_decl void (ident "init_vars") [])
        $ concat $ obs_ins ++ ins

     -- Code to print the values of monitored expressions.
     monitor_code <- forM (monitor conf) $ \(lab,x) ->
        do expr <- term x
           return ( callS (ident "printf") [ string_lit ("\t" ++ lab) ]
                  , callS (ident "printf") [ string_lit ("\t%f"), expr ]
                  )

     let tune_codes  = genParGroups cpus tuneCode  par_groups
         slice_codes = genParGroups cpus sliceCode par_groups

     threads <- zipWithM (genThread monitor_code)
                               [ 0 .. ] (zip tune_codes slice_codes)


     -- The main sampling function.
     newFunDecl (fun_decl void (ident "sampler") []) $
        [ pragma "omp sections"
        , block (concat threads)
        ]