module AERN2.QA.Strategy.Parallel
(
QAParA
, executeQAParA, executeQAParAwithLog
)
where
#ifdef DEBUG
import Debug.Trace (trace)
#define maybeTrace trace
#define maybeTraceIO putStrLn
#else
#define maybeTrace (\ (_ :: String) t -> t)
#define maybeTraceIO (\ (_ :: String)-> return ())
#endif
import MixedTypesNumPrelude
import Text.Printf
import Control.Arrow
import qualified Data.IntMap as IntMap
import Control.Concurrent
import Control.Concurrent.STM
import Control.Monad.IO.Class
import AERN2.QA.Protocol
import AERN2.QA.NetLog
-- | Bookkeeping for one execution of a QA network.  A value of this type is
-- kept in a 'TVar' and shared by every computation of the network.
data QANetState = QANetState
  { net_nextId :: ValueId
    -- ^ the identifier that will be handed out to the next registered object
  , net_log :: QANetLog
    -- ^ the network log; new entries are appended at the end
  }
-- | The state at the start of an execution: an empty log, and identifiers
-- starting from 1 (identifier 0 is what registration uses when no state is
-- available).
initQANetState :: QANetState
initQANetState = QANetState { net_nextId = ValueId 1, net_log = [] }
-- | Allocate a fresh identifier for a newly registered QA object.
--
-- Returns the updated state together with the allocated identifier.  The
-- updated state has the identifier counter advanced, and a creation entry
-- (identifier, source identifiers, name) appended to the end of the log.
getValueId :: QANetState -> [ValueId] -> String -> (QANetState, ValueId)
getValueId ns sources name =
  let fresh = net_nextId ns
      entry = QANetLogCreate fresh sources name
      updated = ns { net_nextId = succ fresh, net_log = net_log ns ++ [entry] }
  in (updated, fresh)
-- | Append a query entry to the end of the network log.
--
-- The arguments are the querying node (if known), the queried node, and
-- the query rendered as a string.
logQuery ::
  QANetState -> Maybe ValueId -> ValueId -> String -> QANetState
logQuery ns src valueId qS =
  ns { net_log = net_log ns ++ [QANetLogQuery src valueId qS] }
-- | Append an answer entry to the end of the network log.
--
-- The pair holds the rendered answer and the cache-usage message.  Note that
-- the log constructor takes these two strings in the opposite order.
logAnswer ::
  QANetState -> Maybe ValueId -> ValueId -> (String, String) -> QANetState
logAnswer ns src valueId (aS, usedCacheS) =
  ns { net_log = net_log ns ++ [QANetLogAnswer src valueId usedCacheS aS] }
-- | The arrow used for parallel QA networks: Kleisli arrows over 'QAParM'.
type QAParA = Kleisli QAParM

-- | An 'IO' computation with read access to an optional shared network state.
--
-- In this module, when the environment is @Just tv@, registrations, queries
-- and answers are recorded in the log held in @tv@.  When it is @Nothing@,
-- no logging happens.
--
-- Declared as a @newtype@ (not @data@): it wraps exactly one field, so the
-- wrapper costs nothing at runtime, and the constructor and field accessor
-- are unchanged for all users.
newtype QAParM a = QAParM { unQAParM :: Maybe (TVar QANetState) -> IO a }
-- | Map over the result of the underlying 'IO' action, passing the shared
-- state through unchanged.
instance Functor QAParM where
  fmap f (QAParM run) = QAParM $ \nsTV -> fmap f (run nsTV)
-- | Both computations see the same shared state.  The function computation
-- runs before the argument computation.
instance Applicative QAParM where
  pure a = QAParM $ \_ -> pure a
  QAParM runF <*> QAParM runA = QAParM $ \nsTV -> runF nsTV <*> runA nsTV
-- | Sequencing threads the same shared state to the continuation.
instance Monad QAParM where
  QAParM run >>= k = QAParM $ \nsTV -> do
    a <- run nsTV
    unQAParM (k a) nsTV
-- | Lift an 'IO' action; the shared state is ignored.
instance MonadIO QAParM where
  liftIO action = QAParM (\_ -> action)
-- | Concurrent execution strategy for QA networks.
--
-- Registering an object allocates its identifier (logged when a network
-- state is present) and creates two private 'TVar's:
--
--   * a cache of answers, and
--   * a table of queries whose answers are currently being computed.
--
-- A query is handled in one of three ways:
--
--   * it is answered from the cache when possible;
--   * otherwise, if an active query is at least as strong ('!>=!'), the
--     resulting promise waits until the cache can answer it;
--   * otherwise a new computation is started, in a new thread unless
--     'QARegPreferSerial' was among the registration options.
instance QAArrow QAParA where
  type QAId QAParA = ValueId

  qaRegister options = Kleisli qaRegisterM
    where
      -- Serial execution runs computations in the calling thread.
      isParallel = not (QARegPreferSerial `elem` options)

      qaRegisterM qa@(QA__ name Nothing sourceIds (p :: p) sampleQ _) =
        QAParM $ \m_nsTV -> do
          vId <- case m_nsTV of
            Nothing -> pure (ValueId 0)
            Just nsTV ->
              atomically $ do
                ns <- readTVar nsTV
                let (ns2, i) = getValueId ns sourceIds name
                writeTVar nsTV ns2
                pure i
          activeQsTV <- atomically $ newTVar initActiveQs
          cacheTV <- atomically $ newTVar $ newQACache p
          return $
            QA__ name (Just vId) [] p sampleQ
              (\me_src -> Kleisli $ makeQPar vId activeQsTV cacheTV me_src)
        where
          initActiveQs = IntMap.empty :: IntMap.IntMap (Q p)

          -- Key for a newly started computation: one more than the largest
          -- key currently in use (or 1 when nothing is active).
          nextActiveQId activeQs
            | IntMap.null activeQs = int 1
            | otherwise = int $ 1 + (fst $ IntMap.findMax activeQs)

          -- Handle one query @q@.  The first component of the id pair is
          -- ignored; @src@ is the querying node, used only for logging.
          makeQPar vId activeQsTV cacheTV (_, src) q =
            QAParM $ \m_nsTV -> do
              maybeTraceIO $ printf "[%s]: q = %s" name (show q)
              case m_nsTV of
                Nothing -> pure ()
                Just nsTV ->
                  atomically $ do
                    ns <- readTVar nsTV
                    writeTVar nsTV $ logQuery ns src vId (show q)
              -- In one transaction: answer from the cache, or join an active
              -- computation that covers q, or claim a new computation id.
              (maybeAnswer, mLogMsg, maybeComputeId) <- atomically $ do
                cache <- readTVar cacheTV
                case lookupQACache p cache q of
                  (Just a, mLogMsg) -> return (Just a, mLogMsg, Nothing)
                  (_, mLogMsg) -> do
                    activeQs <- readTVar activeQsTV
                    let alreadyActive = or $ map (!>=! q) $ IntMap.elems activeQs
                    if alreadyActive
                      then return (Nothing, mLogMsg, Nothing)
                      else do
                        let computeId = nextActiveQId activeQs
                        writeTVar activeQsTV $ IntMap.insert computeId q activeQs
                        return (Nothing, mLogMsg, Just computeId)
              case (maybeAnswer, maybeComputeId) of
                (Just a, _) ->
                  pure $ promise mLogMsg (pure a)
                (_, Just computeId) -> do
                  forkComputation m_nsTV computeId
                  pure $ promise mLogMsg waitForAnswer
                _ ->
                  pure $ promise mLogMsg waitForAnswer
            where
              -- A promise runs the given answer action and, when a network
              -- state is present, logs the answer together with the cache
              -- message.
              promise mLogMsg answerIO =
                Kleisli $ const $ QAParM $ \m_nsTV ->
                  case m_nsTV of
                    Nothing -> answerIO
                    Just nsTV -> do
                      a <- answerIO
                      atomically $ do
                        ns <- readTVar nsTV
                        writeTVar nsTV $ logAnswer ns src vId (show a, logMsg)
                      pure a
                where
                  logMsg = case mLogMsg of Just m -> m; _ -> ""

              -- Block (STM retry) until the cache can answer q.
              -- NOTE(review): only changes to the cache wake this up.  If the
              -- covering computation throws, its entry stays in the active
              -- table and waiters block indefinitely.  Confirm whether that
              -- case needs handling.
              waitForAnswer = atomically $ do
                cache <- readTVar cacheTV
                case lookupQACache p cache q of
                  (Just a, _) -> return a
                  _ -> retry

              forkComputation m_nsTV computeId
                | isParallel = do { _ <- forkIO computation; return () }
                | otherwise = computation
                where
                  computation = do
                    a <- (unQAParM $ runKleisli (qaMakeQuery qa src) q) m_nsTV
                    -- Publish the answer and retire the active query in a
                    -- single transaction, so no other thread observes one
                    -- change without the other.  The strict modify avoids
                    -- accumulating thunks in these long-lived TVars.
                    atomically $ do
                      modifyTVar' cacheTV (updateQACache p q a)
                      modifyTVar' activeQsTV (IntMap.delete computeId)

      qaRegisterM _ =
        error "internal error in AERN2.QA.Strategy.Parallel: qaRegister called with an existing id"

  qaFulfilPromiseA = Kleisli qaFulfilPromiseM
    where
      qaFulfilPromiseM promiseA =
        runKleisli promiseA ()

  qaMakeQueryGetPromiseA src = Kleisli qaMakeQueryGetPromiseM
    where
      qaMakeQueryGetPromiseM (qa, q) =
        runKleisli (qaMakeQueryGetPromise qa (me, src)) q
        where
          -- the identifier of the queried object, falling back to src
          me = case qaId qa of Nothing -> src; me2 -> me2
-- | Run a QA network computation without any network state, so nothing is
-- logged.
executeQAParA :: (QAParA () a) -> IO a
executeQAParA code = unQAParM (runKleisli code ()) Nothing
-- | Run a QA network computation with a fresh network state.  Returns the
-- network log, as it stands after the computation returns, together with the
-- result.
executeQAParAwithLog :: (QAParA () a) -> IO (QANetLog, a)
executeQAParAwithLog code = do
  stateVar <- newTVarIO initQANetState
  result <- unQAParM (runKleisli code ()) (Just stateVar)
  finalState <- readTVarIO stateVar
  pure (net_log finalState, result)