{-# LANGUAGE LambdaCase, OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE TypeFamilies #-}
module Database.Franz.Network
( defaultPort
, Connection
, withConnection
, connect
, disconnect
, Query(..)
, ItemRef(..)
, RequestType(..)
, defQuery
, Response
, awaitResponse
, SomeIndexMap
, Contents
, fetch
, fetchSimple
, atomicallyWithin
, FranzException(..)) where
import Control.Arrow ((&&&))
import Control.Concurrent
import Control.Concurrent.STM
import Control.Concurrent.STM.Delay (cancelDelay, newDelay, waitDelay)
import Control.Exception
import Control.Monad
import qualified Data.ByteString.Char8 as B
import Data.ConcurrentResourceMap
import qualified Data.HashMap.Strict as HM
import Data.IORef
import Data.IORef.Unboxed
import Data.Int (Int64)
import qualified Data.IntMap.Strict as IM
import Data.Serialize hiding (getInt64le)
import qualified Data.Vector as V
import qualified Data.Vector.Generic.Mutable as VGM
import Database.Franz.Internal
import Database.Franz.Protocol
import qualified Network.Socket as S
import qualified Network.Socket.ByteString as SB
-- | 'IM.IntMap'-backed storage for the 'ConcurrentResourceMap' that holds
-- per-request state of a 'Connection'; keys are request ids.
newtype ConnStateMap v = ConnStateMap (IM.IntMap v)

instance ResourceMap ConnStateMap where
  type Key ConnStateMap = Int
  empty = ConnStateMap mempty
  lookup key (ConnStateMap im) = IM.lookup key im
  insert key val (ConnStateMap im) = ConnStateMap $ IM.insert key val im
  delete key (ConnStateMap im) = ConnStateMap $ IM.delete key im
-- | A client connection to a franz server. Create one with 'connect' (or
-- 'withConnection') and release it with 'disconnect'.
data Connection = Connection
  { connSocket :: MVar S.Socket
    -- ^ Socket to the server. Requests and clean-up messages are sent while
    -- holding this 'MVar' (via 'withMVar'); 'disconnect' closes the socket
    -- through it as well.
  , connReqId :: !Counter
    -- ^ Counter from which each 'fetch' draws a fresh request id.
  , connStates :: !(ConcurrentResourceMap
      ConnStateMap
      (TVar (ResponseStatus Contents)))
    -- ^ State of each outstanding request, keyed by request id. 'fetch'
    -- registers entries; the reader thread updates them as replies arrive.
  , connThread :: !ThreadId
    -- ^ Background thread reading replies from the socket; 'disconnect'
    -- kills it.
  }
-- | Client-side state of a single request, kept in one 'TVar' per request.
data ResponseStatus a = WaitingInstant
    -- ^ Request sent; no reply seen yet.
  | WaitingDelayed
    -- ^ The server replied with 'ResponseWait'; contents are expected later.
  | Errored !FranzException
    -- ^ The server reported an error for this request.
  | Available !a
    -- ^ The response contents have arrived.
  | RequestFinished
    -- ^ The 'fetch' caller is done with the request; replies arriving
    -- afterwards no longer update this state.
  deriving (Show, Functor)
-- | Run an action with a fresh 'Connection' to @host:port@ for directory
-- @dir@. The connection is closed with 'disconnect' when the action
-- returns or throws.
withConnection :: String -> S.PortNumber -> B.ByteString -> (Connection -> IO r) -> IO r
withConnection host port dir action = bracket open disconnect action
  where
    open = connect host port dir
-- | Open a TCP connection to a franz server and perform the handshake.
--
-- The directory @dir@ is sent first. The server must answer with
-- 'apiVersion'. A decodable 'ResponseError' reply is rethrown as the server's
-- exception; any other reply is thrown as a 'ClientError'.
--
-- On success a background thread is forked that decodes replies from the
-- socket and updates the state of the matching request in 'connStates'.
--
-- If anything fails before the 'Connection' is returned (connect, handshake,
-- version mismatch), the socket is closed instead of being leaked.
connect :: String -> S.PortNumber -> B.ByteString -> IO Connection
connect host port dir = do
  let hints = S.defaultHints { S.addrFlags = [S.AI_NUMERICSERV], S.addrSocketType = S.Stream }
  addr:_ <- S.getAddrInfo (Just hints) (Just host) (Just $ show port)
  -- Close the socket on any failure during setup; on success ownership
  -- passes to the returned Connection (released by 'disconnect').
  bracketOnError (S.socket (S.addrFamily addr) S.Stream (S.addrProtocol addr)) S.close $ \sock -> do
    S.setSocketOption sock S.NoDelay 1
    S.connect sock $ S.addrAddress addr
    SB.sendAll sock $ encode dir
    readyMsg <- SB.recv sock 4096
    unless (readyMsg == apiVersion) $ case decode readyMsg of
      Right (ResponseError _ e) -> throwIO e
      e -> throwIO $ ClientError $ "Database.Franz.Network.connect: Unexpected response: " ++ show e
    connSocket <- newMVar sock
    connReqId <- newCounter 0
    connStates <- newResourceMap
    buf <- newIORef B.empty
    let
      -- Run @f@ on the state of request @i@ and store the result, if any.
      -- Unknown or already-finished requests see 'Nothing' and nothing is
      -- written; an exception thrown by @f@ escapes to the reader thread.
      withRequest i f = withInitialisedResource connStates i (\_ -> pure ()) $ \case
        Nothing ->
          void $ atomically (f Nothing)
        Just reqVar -> atomically $ readTVar reqVar >>= \case
          RequestFinished -> void $ f Nothing
          s -> f (Just s) >>= mapM_ (writeTVar reqVar)
      runGetThrow :: Get a -> IO a
      runGetThrow g = runGetRecv buf sock g
        >>= either (throwIO . ClientError) pure
    -- Reader loop: one reply header at a time, forever, until killed or a
    -- decoding/request error is raised.
    connThread <- flip forkFinally (either throwIO pure) $ forever $ runGetThrow get >>= \case
      Response i -> do
        resp <- runGetThrow getResponse
        withRequest i . traverse $ \case
          WaitingInstant -> pure (Available resp)
          WaitingDelayed -> pure (Available resp)
          e -> throwSTM $ ClientError $ "Unexpected state on ResponseInstant " ++ show i ++ ": " ++ show e
      ResponseWait i -> withRequest i . traverse $ \case
        WaitingInstant -> pure WaitingDelayed
        e -> throwSTM $ ClientError $ "Unexpected state on ResponseWait " ++ show i ++ ": " ++ show e
      ResponseError i e -> withRequest i $ \case
        Nothing -> throwSTM e
        Just{} -> pure $ Just (Errored e)
    return Connection{..}
-- | Stop the connection's reader thread, then close its socket.
disconnect :: Connection -> IO ()
disconnect conn = do
  killThread (connThread conn)
  withMVar (connSocket conn) S.close
-- | A query on stream @name@ whose range starts and ends at sequence
-- number 0, requesting 'AllItems'. Override fields with record update syntax.
defQuery :: B.ByteString -> Query
defQuery name = Query
  { reqStream = name
  , reqFrom = origin
  , reqTo = origin
  , reqType = AllItems
  }
  where
    origin = BySeqNum 0
-- | Index values attached to a single item, by index name.
type SomeIndexMap = HM.HashMap IndexName Int64
-- | Items returned for a query: (sequence number, index values, payload).
type Contents = V.Vector (Int, SomeIndexMap, B.ByteString)
-- | What 'fetch' yields: 'Left' when the contents came back immediately, or
-- 'Right' with a transaction that waits for contents the server deferred.
type Response = Either Contents (STM Contents)
-- | Collapse a 'Response'-shaped transaction into its final value: an
-- immediate ('Left') result is returned as is, while a deferred ('Right')
-- transaction is run to wait for the value.
awaitResponse :: STM (Either a (STM a)) -> STM a
awaitResponse response = do
  outcome <- response
  case outcome of
    Left ready -> pure ready
    Right later -> later
-- | Decode the item payload that follows a 'Response' reply header.
--
-- The 'PayloadHeader' carries sequence numbers @s0@ and @s1@ (items
-- @s0 + 1 .. s1@ are returned), the byte offset @p0@ corresponding to the
-- first payload byte, and the index names. It is followed by one entry per
-- item (the item's end offset, then one value per index name) and finally by
-- the concatenated item bodies. Each item's body spans from the previous
-- item's end offset (or @p0@ for the first) to its own end offset.
getResponse :: Get Contents
getResponse = do
  PayloadHeader s0 s1 p0 names <- get
  let count = s1 - s0
  if count <= 0
    then pure mempty
    else do
      ixs <- V.replicateM count $ (,) <$> getInt64le <*> mapM (const getInt64le) names
      payload <- getByteString $ fst (V.unsafeLast ixs) - p0
      pure $ V.create $ do
        out <- VGM.unsafeNew count
        -- Write item i and hand its end offset on as the next item's start.
        -- Bangs keep the stored fields evaluated rather than as thunks.
        let fill start i = do
              let (end, indices) = V.unsafeIndex ixs i
                  !indexMap = HM.fromList $ zip names indices
                  !body = B.take (end - start) $ B.drop (start - p0) payload
                  !seqNum = s0 + i + 1
              VGM.unsafeWrite out i (seqNum, indexMap, body)
              pure end
        foldM_ fill p0 [0 .. count - 1]
        pure out
-- | Send a query to the server and give @cont@ an STM action observing the
-- reply.
--
-- The action retries until the first reply for this request arrives. It
-- then yields 'Left' with the contents if they came back at once, or 'Right'
-- with a further transaction that waits for contents the server deferred
-- ('ResponseWait'). Errors reported by the server are rethrown with
-- 'throwSTM'.
--
-- When the request's resource is released (after @cont@ finishes), its
-- state is set to 'RequestFinished'. If it was still waiting for contents,
-- a 'RawClean' message is sent to the server. Running either STM action
-- after that throws a 'ClientError' ("request already finished").
fetch
  :: Connection
  -> Query
  -> (STM Response -> IO r)
  -> IO r
fetch Connection{..} req cont = do
  reqId <- atomicAddCounter connReqId 1
  withSharedResource connStates reqId (newTVarIO WaitingInstant) (release reqId) $ \reqVar -> do
    sendRaw $ encode $ RawRequest reqId req
    cont $ firstReply reqVar
  where
    -- All writes to the socket happen while holding its MVar.
    sendRaw bytes = withMVar connSocket $ \sock -> SB.sendAll sock bytes

    finished = ClientError "request already finished"

    pending WaitingInstant = True
    pending WaitingDelayed = True
    pending _ = False

    -- Mark the request finished; tell the server only if it was in flight.
    release reqId reqVar = do
      wasPending <- atomically $
        stateTVar reqVar $ pending &&& const RequestFinished
      when wasPending $ sendRaw $ encode $ RawClean reqId

    firstReply reqVar = readTVar reqVar >>= \case
      RequestFinished -> throwSTM finished
      Errored e -> throwSTM e
      WaitingInstant -> retry
      Available xs -> pure $ Left xs
      WaitingDelayed -> pure $ Right $ delayedReply reqVar

    delayedReply reqVar = readTVar reqVar >>= \case
      RequestFinished -> throwSTM finished
      WaitingDelayed -> retry
      Available xs -> return xs
      Errored e -> throwSTM e
      WaitingInstant -> throwSTM $ ClientError $
        "fetch/WaitingDelayed: unexpected state WaitingInstant"
-- | Run a query and wait up to @timeout@ microseconds for its contents,
-- whether the server returns them immediately or defers them. Returns an
-- empty result if the timeout expires first.
fetchSimple :: Connection
  -> Int
  -> Query
  -> IO Contents
fetchSimple conn timeout req = fetch conn req $ \response -> do
  result <- atomicallyWithin timeout (awaitResponse response)
  pure $ case result of
    Just contents -> contents
    Nothing -> mempty
-- | Run an STM transaction, giving up after @timeout@ microseconds.
--
-- Returns 'Just' the result if the transaction commits before the delay
-- rings, or 'Nothing' on timeout. The delay is cancelled as soon as this
-- returns (normally or by exception). A transaction that finished early
-- therefore does not leave a pending timer (or, in stm-delay's
-- non-threaded-RTS fallback, a sleeping thread) alive until the deadline.
atomicallyWithin :: Int
  -> STM a
  -> IO (Maybe a)
atomicallyWithin timeout m =
  bracket (newDelay timeout) cancelDelay $ \d ->
    atomically $ (Just <$> m) `orElse` (Nothing <$ waitDelay d)