module Data.BTree.Cache.STM where
import Data.Hashable (Hashable)
import qualified Data.List as L
import qualified Data.Map as M
import qualified Data.Graph as G
import qualified Data.BTree.Cache.Class as C
import qualified Data.BTree.HashTable.STM as H
import qualified Data.Serialize as S
import qualified Data.BTree.KVBackend.Class as KV
import Control.Concurrent.STM
import Control.Concurrent
import Control.Monad
import Control.Monad.Trans
import Control.Monad.Reader
import qualified Data.ByteString as B
import Data.Either
import Data.Maybe
import Data.Time.Clock
import Debug.Trace
import System.IO
import Control.Monad.Error
data State k v = Read !(Maybe v)
| Write !(Maybe k) !(Maybe v)
deriving Eq
instance Show (State k v) where
show (Read a) = "Read: " ++ mtoS a
show (Write k a) = "Write: " ++ mtoS a
mtoS a = if isJust a then "Just" else "Nothing"
type AccessTime = UTCTime
data Exist = Exist | NoExist
deriving (Eq)
data Ref k v = Ref {
refST :: TVar (Either (State k v)
(State k v, Int, State k v)),
refExt :: TVar Exist
}
data Param m k v = Param {
cacheSize :: Int
, table :: H.HashTableSTM k (Ref k (Either B.ByteString v))
, toIO :: forall a. m a -> IO a
, flushQ :: TVar [(k, Ref k v)]
, timestamp :: UTCTime
, genId :: TVar Int
, genActive :: TVar Int
}
trace = id
newtype CacheSTM m k v a = CacheSTM { runCacheSTM :: ReaderT (Param m k v) (ErrorT (IO ()) STM) a }
deriving (Monad, MonadReader (Param m k v), MonadError (IO ()), Functor)
instance Error (IO ())
stm m = lift $ lift m
evalCacheSTM :: Param m k v -> CacheSTM m k v a -> IO a
evalCacheSTM p m = do
mer <- atomically $ runErrorT $ runReaderT (runCacheSTM m) p
case mer of
Left io -> io >> evalCacheSTM p m
Right a -> return a
sizedParam :: Int -> (forall a. m a -> IO a) -> IO (Param m k v)
sizedParam s f = do
ts <- getCurrentTime
ht <- H.newSized s
mv <- atomically $ newTVar []
gn <- atomically $ newTVar 0
ga <- atomically $ newTVar 0
return $ Param s ht f mv ts gn ga
getRef k = do
p <- ask
ht <- asks table
ev <- asks toIO
mr <- stm $ H.lookup ht k
case mr of
Nothing -> fetchRef p ev ht
Just r -> return r
where
fetchRef p ev ht = throwError $ (ev $ KV.fetch k) >>= createRef p ht
createRef p ht bytes = do
atomically $ do
l <- H.lookup ht k
case l of
Just _ -> return ()
Nothing -> do
tvst <- newTVar $ Left $ Read $ Left `fmap` bytes
tvex <- newTVar Exist
let ref = Ref tvst tvex
H.insert ht k ref
newRef t k v = do
ht <- asks table
r@(Ref tv _) <- getRef k
update tv $! Write t v
where
update tv x = do
gt <- asks genId
ga <- asks genActive
stm $ do
gid <- readTVar gt
act <- readTVar ga
s <- readTVar tv
case s of
Left o | act == 0 -> writeTVar tv $! Left $! x
| otherwise -> writeTVar tv $! Right $! (x, gid, o)
Right (x', n, o)
| act == 0 -> writeTVar tv $! Left $! x
| gid /= n -> writeTVar tv $! Right $! (x, gid, x')
| otherwise -> writeTVar tv $! Right $! (x, gid, o )
maybeQueue force t x =
if force then update
else do
s <- stm $ readTVar t
case s of
Left (Write _ _) -> return ()
Right (Write _ _, _, _) -> return ()
_ -> update
where
update = do
qt <- asks flushQ
stm $ do q <- readTVar qt
writeTVar qt $! x : q
store t k v = CacheSTM $ newRef t k $! Just $! Right v
fetch k = CacheSTM $ do
r@(Ref tv _) <- getRef k
s <- stm $ readTVar tv
case s of
Left x -> return $! value x
Right (x, _, _) -> return $! value x
fetchGen n k = CacheSTM $ do
r@(Ref tv _) <- getRef k
s <- stm $ readTVar tv
return $ value $ getGen n s
remove t k = CacheSTM $ newRef t k Nothing
updateTag t k = CacheSTM $ do
ht <- asks table
x <- stm $ H.lookup ht k
case x of
Nothing -> return ()
Just (Ref tv _) -> do
s <- stm $ readTVar tv
case s of
Left (Write _ v) -> stm $ writeTVar tv $! Left $! Write t v
Right (Write _ v, n, o) -> stm $ writeTVar tv $! Right $! (Write t v, n, o)
_ -> return ()
keys = CacheSTM $ do
ht <- asks table
stm $ H.keys ht
debug a = liftIO $ do print a
hFlush stdout
getGen n (Left s) = s
getGen n (Right (news, m, olds))
| n == m = olds
| otherwise = news
flipWrite x (Left (Write _ x')) | x == x' = Left $! (Read x)
flipWrite x (Right (Write _ x', n, o)) | x == x' = Right $! (Read x', n, o)
flipWrite x (Right (c, n, Write _ x')) | x == x' = Right $! (c, n, Read x')
flipWrite _ s = s
equals a b = value a == value b
value v = case v of
Read x -> either decode id `fmap` x
Write _ x -> either decode id `fmap` x
where
decode = either error id . S.decode
withGeneration p f = do
n <- liftIO $ atomically $ do
a <- readTVar $ genActive p
n <- readTVar $ genId p
writeTVar (genActive p) $! a + 1
if a > 0 then return n
else do writeTVar (genId p) $! n + 1
return $! n + 1
x <- f n
liftIO $ atomically $ do
a <- readTVar $ genActive p
writeTVar (genActive p) $! a 1
return x
flush p = do
nowSize <- atomically $ H.size ht
gen <- atomically $ readTVar $ genId p
when (nowSize > maxSize) $ do
flush =<< (atomically $ H.toList ht)
where
ht = table p
maxSize = cacheSize p
flush ks = do
mapM_ (evalCacheSTM p . flushKey ht) ks
flushKey ht (k, r@(Ref tvst tvex)) = CacheSTM $ do
tvga <- asks genActive
tvgi <- asks genId
stm $ do
act <- readTVar tvga
gen <- readTVar tvgi
s <- readTVar tvst
case s of
Left (Read _) ->
H.delete ht k >> return Nothing
Right (Read s, n, _) | act == 0 || n /= gen ->
H.delete ht k >> return Nothing
Right (Write t s, n, _) | act == 0 || n /= gen ->
(writeTVar tvst $! Left $! Write t s) >> return Nothing
Left (Write t (Just (Right v))) ->
(writeTVar tvst $! Left $! Write t $! Just $! Left $! S.encode v)
>> return Nothing
Right (Write t (Just (Right v)), n, o) ->
(writeTVar tvst $! Right $! (Write t $! Just $! Left $! S.encode v, n, o))
>> return Nothing
_ -> return $ Just s
sync p = do
withGeneration p $ \gen -> do
ks <- atomically $ H.toList $ table p
ls <- forM ks $ \(k, r@(Ref tv _)) -> do
s <- atomically $ readTVar tv
case getGen gen s of
Write (Just t) _ -> return $! Just $! Left (t, k, r)
Write Nothing _ -> return $! Just $! Right (k, r)
_ -> return $! Nothing
let (lefts, rights) = partitionEithers $ catMaybes ls
mapM_ (evalCacheSTM p . go gen) $ sortByTag lefts
mapM_ (evalCacheSTM p . go gen) $ rights
where
sortByTag :: Ord k => [(k, k, Ref k v)] -> [(k, Ref k v)]
sortByTag ls =
let m = M.fromList $ zip (map (\(_, k, _) -> k) ls) [0..]
(g, f, _) = G.graphFromEdges
[((k, r), i, maybe [] return $ M.lookup t m)
| ((t, k, r), i) <- zip ls [0..]]
in map (\(p, _, _) -> p) $ map f $ G.topSort g
go gen (k, (r@(Ref tv tvex))) = CacheSTM $ do
s <- stm $ readTVar tv
ex <- stm $ readTVar tvex
case getGen gen s of
Write _ Nothing | ex == Exist -> throwError write
| otherwise -> stm $ update Nothing
Write _ (Just v) -> throwError write
_ -> return ()
where
write = do
s <- atomically $ readTVar tv
case getGen gen s of
Write _ Nothing -> (toIO p $ KV.remove k ) >> (atomically $ update Nothing)
Write _ (Just v) -> (toIO p $ KV.store k $ either id S.encode $ v)
>> (atomically $ update (Just v))
_ -> return ()
update v = do
s <- readTVar tv
writeTVar tv $! flipWrite v s
liftSTM = CacheSTM . stm
fail = CacheSTM . throwError
instance ( Show k, S.Serialize k, S.Serialize v, Ord k, Eq k, Eq v
, Hashable k, KV.KVBackend m k B.ByteString) =>
C.Cache (CacheSTM m k v) (Param m k v) k v where
store = store
fetch = fetch
remove = remove
sync = sync
eval = evalCacheSTM