module AERN2.QA.Strategy.Cached.NetState
(
QANetState(..), initQANetState
, AnyQAComputation(..), QAComputation(..)
, insertNode, logQuery, logAnswerUpdateCache, getAnswerPromise
)
where
import MixedTypesNumPrelude
import Control.Arrow
import Unsafe.Coerce
import qualified Data.Map as Map
import AERN2.QA.Protocol
import AERN2.QA.NetLog
-- | State of a cached QA network: the registered computation nodes,
-- the accumulated event log, and the caching switch.
data QANetState m =
  QANetState
    { net_id2value :: Map.Map ValueId (AnyQAComputation m)
      -- ^ the computation node registered under each 'ValueId'
    , net_log :: QANetLog
      -- ^ creation/query/answer events; new entries are appended at the end
    , net_should_cache :: Bool
      -- ^ when 'False', 'getAnswerPromise' hands back the node's cache unchanged
    }
-- | A 'QAComputation' whose protocol type is hidden, so that nodes speaking
-- different protocols can live in the same 'Map.Map'.  Code in this module
-- recovers the protocol type with 'unsafeCoerce', trusting the caller to
-- supply the protocol the node was created with.
data AnyQAComputation m =
  forall p. QAProtocolCacheable p =>
  AnyQAComputation (QAComputation m p)
-- | A single network node for protocol @p@.
data QAComputation m p =
  QAComputation
    p
      -- ^ the node's protocol value
    (QACache p)
      -- ^ the node's current answer cache
    ((Maybe ValueId, Maybe ValueId) -> Q p -> m (QAPromiseA (Kleisli m) (A p)))
      -- ^ produces an answer promise for a query; 'getAnswerPromise' calls it
      -- with @(Just thisNodeId, querySource)@
-- | A fresh network: no nodes, an empty log, and caching switched on or off
-- according to the argument.
initQANetState :: Bool -> QANetState m
initQANetState cachingEnabled =
  QANetState
    { net_should_cache = cachingEnabled
    , net_log = []
    , net_id2value = Map.empty
    }
-- | Register a new computation node and log its creation.
--
-- The node is stored under the successor of the largest id currently in the
-- network (or @ValueId 1@ when the network is empty) and starts with the
-- empty cache 'newQACache'.  A 'QANetLogCreate' entry naming the node and
-- its source ids is appended to the log.
--
-- Returns the new node's id together with the updated state.
insertNode ::
  (QAProtocolCacheable p) =>
  p ->
  String ->
  [ValueId] ->
  ((Maybe ValueId, Maybe ValueId) -> Q p -> m (QAPromiseA (Kleisli m) (A p))) ->
  QANetState m ->
  (ValueId, QANetState m)
insertNode p name sourceIds q2pa ns =
  (freshId, extended)
  where
    nodes = net_id2value ns
    freshId =
      case Map.lookupMax nodes of
        Nothing -> ValueId 1
        Just (largestId, _) -> succ largestId
    freshNode = AnyQAComputation (QAComputation p (newQACache p) q2pa)
    extended =
      ns
        { net_id2value = Map.insert freshId freshNode nodes
        , net_log = net_log ns ++ [QANetLogCreate freshId sourceIds name]
        }
-- | Append a 'QANetLogQuery' entry recording that @src@ queried node
-- @valueId@ with query text @qS@.  The node map is left untouched.
logQuery ::
  QANetState m -> Maybe ValueId -> ValueId -> String -> QANetState m
logQuery ns src valueId qS = ns { net_log = extendedLog }
  where
    extendedLog = net_log ns ++ [QANetLogQuery src valueId qS]
-- | Record that node @valueId@ answered a query from @src@, and replace that
-- node's cache with @cache'@.
--
-- The triple carries the answer text @aS@ and the cache-usage text
-- @usedCacheS@ (both copied into a 'QANetLogAnswer' entry appended to the
-- log), plus the node's new cache @cache'@.
--
-- The node's query function is pulled out of the existing entry with a
-- @case@, so the replacement entry refers only to that function.  Previously
-- it held a lazy thunk over the whole previous node map.  Every answer logged
-- for a node then kept the previous map, and its caches, alive until the
-- function was next forced, which happens only on a cache miss.
--
-- Calls 'error' (when the resulting node map is evaluated) if @valueId@ is
-- not a registered node, instead of inserting a broken entry.
logAnswerUpdateCache ::
  (QAProtocolCacheable p)
  =>
  QANetState m -> p -> Maybe ValueId -> ValueId -> (String, String, QACache p) -> QANetState m
logAnswerUpdateCache ns (p :: p) src valueId (aS, usedCacheS, cache') =
  ns
    { net_id2value = id2value'
    , net_log = net_log ns ++ [logItem]
    }
  where
    logItem = QANetLogAnswer src valueId usedCacheS aS
    id2value = net_id2value ns
    id2value' =
      case Map.lookup valueId id2value of
        Just (AnyQAComputation (QAComputation _ _ q2pa)) ->
          -- unsafeCoerce only retypes the stored function to protocol @p@;
          -- assumes the caller passes the protocol this node was created
          -- with (nothing here checks that).
          Map.insert valueId
            (AnyQAComputation (QAComputation p cache' (unsafeCoerce q2pa)))
            id2value
        Nothing -> error $ "unknown valueId " ++ show valueId
-- | Answer query @q@ at node @valueId@ (asked by @src@), returned as a promise.
--
-- * Cache hit: the promise yields the cached answer and the node's current
--   cache, with log text @"used cache"@.
-- * Cache miss: the node's query function is called with
--   @(Just valueId, src)@ and its promise is run right away, in the outer
--   monad.  With caching on, the answer is added to the cache.  The promise
--   then yields the updated cache's answer for @q@ (falling back to the fresh
--   answer) together with the updated cache.  With caching off, it yields the
--   fresh answer and the unchanged cache.  Either way the log text is
--   @"not used cache"@.
--
-- Any note reported by 'lookupQACache' is appended to the log text in
-- parentheses.  Calls 'error' if @valueId@ is not a registered node.
getAnswerPromise ::
  (QAProtocolCacheable p, Monad m)
  =>
  QANetState m -> p -> Maybe ValueId -> ValueId -> Q p -> m (() -> m (A p, String, QACache p))
getAnswerPromise ns (p :: p) src valueId q =
  case lookupQACache p cache q of
    (Just hit, note) ->
      return (promiseOf hit ("used cache" ++ suffix note) cache)
    (_, note) -> do
      promise <- q2pa (Just valueId, src) q
      fresh <- runKleisli promise ()
      let missMsg = "not used cache" ++ suffix note
      if net_should_cache ns
        then do
          let grownCache = updateQACache p q fresh cache
              best =
                case lookupQACache p grownCache q of
                  (Just improved, _) -> improved
                  _ -> fresh
          return (promiseOf best missMsg grownCache)
        else return (promiseOf fresh missMsg cache)
  where
    -- wrap an already-computed result as a trivial promise
    promiseOf answer msg c = \() -> return (answer, msg, c)
    -- render an optional note from 'lookupQACache' as " (note)"
    suffix Nothing = ""
    suffix (Just txt) = " (" ++ txt ++ ")"
    QAComputation _ cache q2pa = node
    -- unsafeCoerce recovers the protocol type hidden by 'AnyQAComputation';
    -- assumes the caller passes the protocol this node was created with.
    node :: (QAComputation m p)
    node = case Map.lookup valueId (net_id2value ns) of
      Just (AnyQAComputation comp) -> unsafeCoerce comp
      Nothing -> error $ "unknown valueId " ++ show valueId