{-# LANGUAGE CPP #-}
-- #define DEBUG
{-|
    Module      :  AERN2.QA.Strategy.Parallel
    Description :  QA net parallel evaluation
    Copyright   :  (c) Michal Konecny
    License     :  BSD3

    Maintainer  :  mikkonecny@gmail.com
    Stability   :  experimental
    Portability :  portable

    QA net parallel evaluation
-}
module AERN2.QA.Strategy.Parallel
(
  QAParA
, executeQAParA, executeQAParAwithLog
)
where

#ifdef DEBUG
import Debug.Trace (trace)
#define maybeTrace trace
#define maybeTraceIO putStrLn
#else
#define maybeTrace (\ (_ :: String) t -> t)
#define maybeTraceIO (\ (_ :: String)-> return ())
#endif

import MixedTypesNumPrelude
-- import qualified Prelude as P
import Text.Printf

import Control.Arrow

import qualified Data.IntMap as IntMap

import Control.Concurrent
import Control.Concurrent.STM
import Control.Exception (onException)
import Control.Monad.IO.Class

import AERN2.QA.Protocol
import AERN2.QA.NetLog

{-|
    Mutable bookkeeping of a QA network evaluation: the next identifier to
    hand out and the log of events recorded so far.  It is only consulted
    when the computation is given a state variable (see 'QAParM').
-}
data QANetState =
  QANetState
  {
    net_nextId :: ValueId -- ^ identifier for the next registered QA object
  , net_log :: QANetLog   -- ^ log items, oldest first (new items are appended)
  }

{-| Initial state: no log entries; identifiers start at 1.  (Identifier 0 is
    used for objects registered while no state is being tracked.) -}
initQANetState :: QANetState
initQANetState =
  QANetState { net_nextId = ValueId 1, net_log = [] }

{-|
    Allocate the next value identifier and log the creation of a value with
    the given source identifiers and name.  Returns the updated state together
    with the allocated identifier.
-}
getValueId :: QANetState -> [ValueId] -> String -> (QANetState, ValueId)
getValueId ns sources name =
  (ns2, vId)
  where
  ns2 = ns { net_nextId = succ vId, net_log = (net_log ns) ++ [logItem] }
  vId = net_nextId ns
  logItem = QANetLogCreate vId sources name

{-| Append a log item recording that a query (shown as @qS@) was sent from
    @src@ (if any) to the value @valueId@. -}
logQuery :: QANetState -> Maybe ValueId -> ValueId -> String -> QANetState
logQuery ns src valueId qS =
  ns { net_log = (net_log ns) ++ [logItem] }
  where
  logItem = QANetLogQuery src valueId qS

{-| Append a log item recording that the value @valueId@ answered @src@;
    the pair holds the shown answer and the cache-usage message. -}
logAnswer :: QANetState -> Maybe ValueId -> ValueId -> (String, String) -> QANetState
logAnswer ns src valueId (aS, usedCacheS) =
  ns { net_log = (net_log ns) ++ [logItem] }
  where
  logItem = QANetLogAnswer src valueId usedCacheS aS

-- | QA arrow of the parallel strategy: Kleisli arrows over 'QAParM'.
type QAParA = Kleisli QAParM

{-|
    An IO computation with optional access to the shared network state.
    When run with 'Nothing', no identifiers are allocated and nothing is
    logged; with @Just tv@, events are recorded in @tv@.
-}
data QAParM a = QAParM { unQAParM :: Maybe (TVar QANetState) -> IO a }

instance Functor QAParM where
  -- equivalent to: QAParM (\nsTV -> fmap f (tv2ma nsTV))
  fmap f (QAParM tv2ma) = QAParM (fmap (fmap f) tv2ma)

instance Applicative QAParM where
  -- ignore the state and just return the value
  pure a = QAParM (const (pure a))
  (QAParM tv2f) <*> (QAParM tv2a) = QAParM (\nsTV -> (tv2f nsTV <*> tv2a nsTV))

instance Monad QAParM where
  -- thread the same (optional) state variable through both computations
  (QAParM tv2ma) >>= f =
    QAParM $ \nsTV -> tv2ma nsTV >>= ($ nsTV) . unQAParM . f

instance MonadIO QAParM where
  liftIO = QAParM . const

instance QAArrow QAParA where
  type QAId QAParA = ValueId

  -- Register a QA object.  Each registered object gets an identifier (a
  -- fresh one when a state is tracked, ValueId 0 otherwise), its own answer
  -- cache and a table of queries whose computation is currently in progress.
  qaRegister options = Kleisli qaRegisterM
    where
      -- computations are forked into new threads unless serial evaluation
      -- was requested:
      isParallel = not (QARegPreferSerial `elem` options)

      qaRegisterM qa@(QA__ name Nothing sourceIds (p :: p) sampleQ _) =
        QAParM $ \m_nsTV -> do
          vId <- case m_nsTV of
            Nothing -> pure (ValueId 0)
            Just nsTV -> atomically $ do
              ns <- readTVar nsTV
              let (ns2,i) = getValueId ns sourceIds name
              writeTVar nsTV ns2
              pure i
          activeQsTV <- atomically $ newTVar initActiveQs
          cacheTV <- atomically $ newTVar $ newQACache p
          return $
            QA__ name (Just vId) [] p sampleQ
              (\me_src -> Kleisli $ makeQPar vId activeQsTV cacheTV me_src)
        where
          -- queries currently being computed, indexed by a computation id:
          initActiveQs = IntMap.empty :: IntMap.IntMap (Q p)

          -- a computation id larger than any currently in use:
          nextActiveQId activeQs
            | IntMap.null activeQs = int 1
            | otherwise = int $ 1 + (fst $ IntMap.findMax activeQs)

          -- Handle a query q sent from src: answer from the cache if possible,
          -- otherwise wait for a pending computation whose query q' satisfies
          -- q' !>=! q (presumably one whose answer also answers q), or start
          -- a new computation.  Returns a promise of the answer.
          makeQPar vId activeQsTV cacheTV (_, src) q =
            QAParM $ \m_nsTV -> do
              maybeTraceIO $ printf "[%s]: q = %s" name (show q)
              case m_nsTV of
                Nothing -> pure ()
                Just nsTV ->
                  atomically $ do
                    -- log query
                    ns <- readTVar nsTV
                    writeTVar nsTV $ logQuery ns src vId (show q)
              -- consult the cache and the index of active queries in one
              -- atomic transaction:
              (maybeAnswer, mLogMsg, maybeComputeId) <- atomically $ do
                cache <- readTVar cacheTV
                case lookupQACache p cache q of
                  (Just a, mLogMsg) -> return (Just a, mLogMsg, Nothing)
                  (_, mLogMsg) -> do
                    activeQs <- readTVar activeQsTV
                    let alreadyActive = or $ map (!>=! q) $ IntMap.elems activeQs
                    if alreadyActive
                      then return (Nothing, mLogMsg, Nothing)
                      else do
                        let computeId = nextActiveQId activeQs
                        writeTVar activeQsTV $ IntMap.insert computeId q activeQs
                        return (Nothing, mLogMsg, Just computeId)
              -- act based on the cache and activity consultation:
              case (maybeAnswer, maybeComputeId) of
                (Just a, _) ->
                  -- got cached answer, just return it:
                  pure $ promise mLogMsg (pure a)
                (_, Just computeId) -> do
                  -- no cached answer, no pending computation:
                  _ <- forkComputation m_nsTV computeId -- start a new computation
                  pure $ promise mLogMsg waitForAnswer  -- and wait for the answer
                _ ->
                  -- no cached answer but there is a pending computation:
                  pure $ promise mLogMsg waitForAnswer
            where
              -- Wrap an answer-producing IO action as a promise; when a state
              -- is tracked, the answer is logged once it is obtained.
              promise mLogMsg answerIO =
                Kleisli $ const $ QAParM $ \m_nsTV ->
                  case m_nsTV of
                    Nothing -> answerIO
                    Just nsTV -> do
                      a <- answerIO
                      atomically $ do
                        ns <- readTVar nsTV
                        writeTVar nsTV $ logAnswer ns src vId (show a, logMsg)
                      pure a
                where
                  logMsg = case mLogMsg of Just m -> m; _ -> ""

              -- Block (STM retry) until the cache can answer q.
              -- NOTE(review): if a forked computation fails, the threads waiting
              -- here are not notified and keep waiting.
              waitForAnswer =
                atomically $ do
                  cache <- readTVar cacheTV
                  case lookupQACache p cache q of
                    (Just a, _mLogMsg) -> return a
                    (_, _mLogMsg) -> retry

              -- Compute the answer to q, either in a new thread or in the
              -- current one, depending on isParallel.
              forkComputation m_nsTV computeId
                | isParallel = do { _ <- forkIO computation; return () }
                | otherwise = do { computation; return () }
                where
                  -- If the computation fails, remove its active-query entry,
                  -- so that later queries start a fresh computation instead of
                  -- waiting for this one forever.
                  computation =
                    runQuery `onException` forgetActiveQuery
                  runQuery = do
                    -- compute an answer:
                    a <- (unQAParM $ runKleisli (qaMakeQuery qa src) q) m_nsTV
                    -- store the answer and retire the active query atomically:
                    atomically $ do
                      modifyTVar' cacheTV (updateQACache p q a)
                      modifyTVar' activeQsTV (IntMap.delete computeId)
                  forgetActiveQuery =
                    atomically $ modifyTVar' activeQsTV (IntMap.delete computeId)
      qaRegisterM _ =
        error "internal error in AERN2.QA.Strategy.Parallel: qaRegister called with an existing id"

  -- Obtain the answer from a promise by running it.
  qaFulfilPromiseA = Kleisli qaFulfilPromiseM
    where
      qaFulfilPromiseM promiseA = runKleisli promiseA ()

  -- Send query q to qa on behalf of src, getting back a promise.  The query
  -- is tagged with qa's own id when it has one, otherwise with src.
  qaMakeQueryGetPromiseA src = Kleisli qaMakeQueryGetPromiseM
    where
      qaMakeQueryGetPromiseM (qa, q) =
        runKleisli (qaMakeQueryGetPromise qa (me, src)) q
        where
          me = case qaId qa of Nothing -> src; me2 -> me2
{-| Run a QA network computation with the parallel strategy, without
    tracking network state (no logging). -}
executeQAParA :: (QAParA () a) -> IO a
executeQAParA code =
  unQAParM (runKleisli code ()) Nothing

{-| Run a QA network computation with the parallel strategy, tracking the
    network state from 'initQANetState'.  Returns the accumulated log
    together with the result. -}
executeQAParAwithLog :: (QAParA () a) -> IO (QANetLog, a)
executeQAParAwithLog code =
  do
    stateVar <- newTVarIO initQANetState
    answer <- unQAParM (runKleisli code ()) (Just stateVar)
    finalState <- readTVarIO stateVar
    pure (net_log finalState, answer)