{-# LANGUAGE CPP #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeFamilies #-}

-- | Parallel Gibbs sweep driver.
--
-- One update worker per capability pulls units from a shared work queue,
-- resamples each unit against the current shared model state and, when the
-- setting changed, enqueues a diff (a state transformer).  A single diff
-- worker applies those diffs to the shared state in batches.  Workers read
-- the shared state without waiting for pending diffs, so a unit may be
-- resampled against a state that does not yet include every earlier update
-- of the same sweep.
module BayesStack.Core.Gibbs
  ( UpdateUnit(..)
  , WrappedUpdateUnit(..)
  , gibbsUpdate
  ) where

import Control.Concurrent
import Control.Concurrent.STM
import Control.DeepSeq
import Control.Exception (SomeException, throwIO, try)
import Control.Monad (forever, replicateM, replicateM_, unless, when)
import Data.IORef
import GHC.Conc.Sync (labelThread)

import Control.Monad.State hiding (lift)
import Data.Random
import Data.Random.Lift
import System.Random.MWC (withSystemRandom)

-- | A variable (or block of variables) of a model that can be resampled.
class UpdateUnit uu where
    -- | The model state the unit is read from and written back into.
    type ModelState uu
    -- | The value the unit takes.
    type Setting uu
    -- | Read the current setting of the unit from the model state.
    fetchSetting  :: uu -> ModelState uu -> Setting uu
    -- | Draw a new setting for the unit given the model state.
    evolveSetting :: ModelState uu -> uu -> RVar (Setting uu)
    -- | @updateSetting unit old new@ is the state transformer applied to the
    -- shared state after the unit moved from setting @old@ to @new@.
    updateSetting :: uu -> Setting uu -> Setting uu
                  -> ModelState uu -> ModelState uu

-- | An 'UpdateUnit' with its concrete type hidden, so that units of
-- different types sharing the model state type @ms@ fit in one list.
data WrappedUpdateUnit ms =
    forall uu. ( UpdateUnit uu, ModelState uu ~ ms
               , NFData (Setting uu), Eq (Setting uu) )
            => WrappedUU uu

-- | Queue of pending state diffs.  @Just f@ carries a diff; 'Nothing' tells
-- the diff worker that no more diffs will follow.
type DiffQueue ms = TBQueue (Maybe (ms -> ms))

-- | Resample one unit against a snapshot of the shared state and, if its
-- setting changed, enqueue the corresponding diff.
updateUnit :: WrappedUpdateUnit ms -> IORef ms -> DiffQueue ms -> RVarT IO ()
updateUnit (WrappedUU unit) stateRef diffQueue = do
    modelState <- lift $ readIORef stateRef
    let s = fetchSetting unit modelState
    s' <- lift $ evolveSetting modelState unit
    -- Fully evaluate both settings on this worker thread; the diff closure
    -- captures them, and otherwise the diff worker would end up evaluating
    -- them when it applies the diff.
    (s, s') `deepseq` return ()
    when (s /= s') $
        lift $ atomically $ writeTBQueue diffQueue (Just $ updateSetting unit s s')

-- | Keep taking units from the work queue and updating them until the queue
-- is empty.
updateWorker :: TQueue (WrappedUpdateUnit ms) -> IORef ms -> DiffQueue ms
             -> RVarT IO ()
updateWorker unitsQueue stateRef diffQueue = do
    next <- lift $ atomically $ tryReadTQueue unitsQueue
    case next of
      Just unit -> do
          updateUnit unit stateRef diffQueue
          updateWorker unitsQueue stateRef diffQueue
      Nothing   -> return ()

#if __GLASGOW_HASKELL__ < 706
-- | Compatibility shim for base versions lacking the strict variant of
-- 'atomicModifyIORef'.  Forces both the new contents and the result, as the
-- base implementation does.
atomicModifyIORef' :: IORef a -> (a -> (a, b)) -> IO b
atomicModifyIORef' ref f = do
    b <- atomicModifyIORef ref $ \a ->
            case f a of
              v@(a', _) -> a' `seq` v
    b `seq` return b
#endif

-- | Apply queued diffs to the shared state, composing up to @updateBlock@
-- diffs (at least one) into each atomic state update.  Diffs are applied in
-- the order they were enqueued.  Returns after reading the 'Nothing'
-- end-of-stream marker, once every diff enqueued before it has been applied.
diffWorker :: IORef ms -> DiffQueue ms -> Int -> IO ()
diffWorker stateRef diffQueue updateBlock = go
  where
    go = do
        (batch, finished) <- collect (max 1 updateBlock) id
        atomicModifyIORef' stateRef $ \st -> (batch st, ())
        unless finished go

    -- Read up to k diffs, composing them so the earliest one runs first.
    -- The flag reports whether the end-of-stream marker was reached.
    collect 0 acc = return (acc, False)
    collect k acc = do
        next <- atomically $ readTBQueue diffQueue
        case next of
          Nothing   -> return (acc, True)
          Just diff -> collect (k - 1) (diff . acc)

-- | Attach a label to the calling thread (used for debugging).
labelMyThread :: String -> IO ()
labelMyThread label = do
    tid <- myThreadId
    labelThread tid label

-- | Run an action on a new labelled thread.  The returned variable is filled
-- with the outcome of the action (including any exception it threw) when
-- the thread finishes.
forkLabelled :: String -> IO () -> IO (MVar (Either SomeException ()))
forkLabelled label action = do
    outcome <- newEmptyMVar
    _ <- forkIO $ try (labelMyThread label >> action) >>= putMVar outcome
    return outcome

-- | Run one parallel sweep over @units@ starting from @modelState@ and
-- return the resulting model state.
--
-- One update worker is started per capability.  Diffs are applied by a
-- single diff worker in batches of up to @updateBlock@ diffs (values below 1
-- are treated as 1).  The returned state includes every diff produced
-- during the sweep.  If an update worker or the diff worker throws, the
-- first such exception is rethrown here after the sweep has finished.
gibbsUpdate :: Int -> ms -> [WrappedUpdateUnit ms] -> IO ms
gibbsUpdate updateBlock modelState units = do
    nWorkers <- getNumCapabilities
    let blockSize = max 1 updateBlock

    unitsQueue <- atomically $ do
        q <- newTQueue
        mapM_ (writeTQueue q) units
        return q
    -- FIXME: the bound is a heuristic; it only limits how far the update
    -- workers may run ahead of the diff worker.  fromIntegral keeps this
    -- compiling whether newTBQueue takes an Int or a Natural.
    diffQueue <- atomically $ newTBQueue $ fromIntegral (2 * blockSize)
    stateRef <- newIORef modelState

    diffOutcome <- forkLabelled "diff worker" $
        diffWorker stateRef diffQueue blockSize

    workerOutcomes <- replicateM nWorkers $
        forkLabelled "update worker" $
            withSystemRandom $ \mwc ->
                runRVarT (updateWorker unitsQueue stateRef diffQueue) mwc

    -- Wait for every update worker: after this no more diffs are produced.
    workerResults <- mapM takeMVar workerOutcomes
    -- Tell the diff worker to flush the remaining diffs and stop, then wait
    -- for it so the state read below contains every diff of this sweep.
    atomically $ writeTBQueue diffQueue Nothing
    diffResult <- takeMVar diffOutcome

    mapM_ (either throwIO return) (workerResults ++ [diffResult])
    readIORef stateRef