module Network.KRPC.Manager
(
MonadKRPC (..)
, Options (..)
, Manager
, newManager
, closeManager
, withManager
, isActive
, listen
, QueryFailure (..)
, query
, getQueryCount
, HandlerFailure (..)
, Handler
, handler
) where
import Control.Applicative
import Control.Concurrent
import Control.Concurrent.Lifted (fork)
import Control.Exception hiding (Handler)
import qualified Control.Exception.Lifted as E (Handler (..))
import Control.Exception.Lifted as Lifted (catches, finally)
import Control.Monad
import Control.Monad.Logger
import Control.Monad.Reader
import Control.Monad.Trans.Control
import Data.BEncode as BE
import Data.ByteString as BS
import Data.ByteString.Char8 as BC
import Data.ByteString.Lazy as BL
import Data.Default.Class
import Data.IORef
import Data.List as L
import Data.Map as M
import Data.Monoid
import Data.Text as T
import Data.Text.Encoding as T
import Data.Text.Encoding.Error (lenientDecode)
import Data.Tuple
import Data.Typeable
import Network.KRPC.Message
import Network.KRPC.Method
import Network.Socket hiding (listen)
import Network.Socket.ByteString as BS
import System.IO.Error
import System.Timeout
-- | Configuration of a KRPC 'Manager'.
data Options = Options
  { -- | Initial value of the transaction counter: the first query issued
    -- by the manager uses this number as its transaction id.
    optSeedTransaction :: !Int

    -- | How long 'query' waits for a response, in seconds.
  , optQueryTimeout    :: !Int

    -- | Maximum number of bytes read from the socket by a single
    -- @recvFrom@ call in the listener.
  , optMaxMsgSize      :: !Int
  } deriving (Show, Eq)
-- | Transaction ids start counting from zero.
defaultSeedTransaction :: Int
defaultSeedTransaction = 0

-- | Two minutes (the query timeout is expressed in seconds).
defaultQueryTimeout :: Int
defaultQueryTimeout = 2 * 60

-- | 64 KiB receive buffer.
defaultMaxMsgSize :: Int
defaultMaxMsgSize = 65536
-- | Default configuration, built from the @default*@ constants
-- (seed transaction, query timeout, maximum message size, in field order).
instance Default Options where
  def = Options defaultSeedTransaction defaultQueryTimeout defaultMaxMsgSize
-- | Reject nonsensical configurations before any resource is acquired.
--
-- Throws an 'IOError' (via 'userError') when the query timeout or the
-- maximum message size is less than one.
validateOptions :: Options -> IO ()
validateOptions Options {..} = do
  when (optQueryTimeout < 1) $
    throwIO (userError "krpc: non-positive query timeout")
  when (optMaxMsgSize < 1) $
    throwIO (userError "krpc: non-positive buffer size")
-- | Outcome of a call: an error reply or a successful response.
type KResult = Either KError KResponse

-- | Mutable counter used to generate transaction ids.
type TransactionCounter = IORef Int

-- | Key identifying an in-flight query: its transaction id and the
-- address it was sent to.
type CallId = (TransactionId, SockAddr)

-- | Slot that receives the result of an in-flight query.
type CallRes = MVar KResult

-- | Table of in-flight queries awaiting a response.
type PendingCalls = IORef (Map CallId CallRes)

-- | Untyped handler: takes the caller's address and the raw query
-- arguments, and returns either a decoding error or the raw result.
type HandlerBody h = SockAddr -> BValue -> h (BE.Result BValue)

-- | A handler body paired with the method name it serves.
type Handler h = (MethodName, HandlerBody h)
-- | Runtime state of a KRPC node.
data Manager h = Manager
  { sock               :: !Socket
    -- ^ Datagram socket the node sends and receives on.
  , options            :: !Options
    -- ^ Configuration the manager was created with.
  , listenerThread     :: !(MVar ThreadId)
    -- ^ Holds the listener thread id while 'listen' is running.
  , transactionCounter :: !TransactionCounter
    -- ^ Source of fresh transaction ids for outgoing queries.
  , pendingCalls       :: !PendingCalls
    -- ^ Outgoing queries still waiting for a response.
  , handlers           :: [Handler h]
    -- ^ Method handlers used to answer incoming queries.
  }
-- | Monads in which KRPC manager operations can run.
--
-- @m@ is the monad the manager-level code runs in; @h@ is the monad the
-- user's query handlers run in.  The functional dependency @m -> h@ means
-- each @m@ fixes a single handler monad.
class (MonadBaseControl IO m, MonadLogger m, MonadIO m)
   => MonadKRPC h m | m -> h where
  -- | Obtain the manager this computation operates on.
  getManager :: m (Manager h)

  -- Default: read the manager from a 'MonadReader' environment.
  default getManager :: MonadReader (Manager h) m => m (Manager h)
  getManager = ask

  -- | Run a handler action inside the manager monad.
  liftHandler :: h a -> m a

  -- Default: the handler monad is the manager monad itself, so lifting
  -- is the identity.
  -- NOTE(review): this default signature (m a -> m a) differs from the
  -- method type (h a -> m a) in more than its context; recent GHCs'
  -- DefaultSignatures rules reject that, and @(h ~ m) => h a -> m a@ may
  -- be needed — TODO confirm against the compiler in use.
  default liftHandler :: m a -> m a
  liftHandler = id

-- | Standard instance: the manager is kept in a 'ReaderT' environment
-- (served by the default 'getManager'), and handlers run in the
-- underlying monad @h@, which is lifted with 'lift'.
instance (MonadBaseControl IO h, MonadLogger h, MonadIO h)
      => MonadKRPC h (ReaderT (Manager h) h) where
  liftHandler = lift
-- | Address family matching the constructor of a socket address; used to
-- create a socket suitable for binding to that address.
sockAddrFamily :: SockAddr -> Family
sockAddrFamily addr = case addr of
  SockAddrInet  {} -> AF_INET
  SockAddrInet6 {} -> AF_INET6
  SockAddrUnix  {} -> AF_UNIX
-- | Validate the options, bind a datagram socket to the given address and
-- build a manager serving the given handlers.  The listener is not
-- started; call 'listen' for that.
--
-- Throws an 'IOError' if the options are invalid (see 'validateOptions')
-- or if the socket cannot be configured or bound.  In the latter case the
-- freshly created socket is closed before the exception propagates.
newManager :: Options      -- ^ manager configuration
           -> SockAddr     -- ^ address to bind the socket to
           -> [Handler h]  -- ^ handlers for incoming queries
           -> IO (Manager h)
newManager opts@Options {..} servAddr handlers = do
  validateOptions opts
  sock  <- bindServ
  tref  <- newEmptyMVar
  tran  <- newIORef optSeedTransaction
  calls <- newIORef M.empty
  return $ Manager sock opts tref tran calls handlers
  where
    family = sockAddrFamily servAddr

    -- Create the socket; if configuring or binding it fails, close it
    -- again so the file descriptor is not leaked.
    bindServ = do
      s <- socket family Datagram defaultProtocol
      configure s `onException` close s
      return s

    configure s = do
      -- Turn IPV6_V6ONLY off so an IPv6 socket also accepts IPv4 traffic.
      when (family == AF_INET6) $
        setSocketOption s IPv6Only 0
      bindSocket s servAddr
-- | Stop the listener thread, if one is registered, and close the socket.
closeManager :: Manager m -> IO ()
closeManager Manager {..} = do
  registered <- tryTakeMVar listenerThread
  case registered of
    Nothing  -> return ()
    Just tid -> killThread tid
  close sock
-- | Whether the manager's socket is still bound, as reported by 'isBound'.
isActive :: Manager m -> IO Bool
isActive = isBound . sock
-- | Run an action with a fresh manager, closing the manager afterwards
-- even if the action throws.
withManager :: Options -> SockAddr -> [Handler h]
            -> (Manager h -> IO a) -> IO a
withManager opts addr hs action = bracket acquire closeManager action
  where
    acquire = newManager opts addr hs
-- | Render a short description of a call for log messages: @"&"@ followed
-- by the method name, @" #"@ followed by the transaction id, and @" \@"@
-- followed by the shown address.
--
-- The method name and transaction id can come from remote peers (see
-- 'runHandler') and are not guaranteed to be valid UTF-8.  They are
-- therefore decoded leniently: invalid bytes become U+FFFD instead of
-- raising a decoding exception while a log line is being built.
querySignature :: MethodName -> TransactionId -> SockAddr -> Text
querySignature name transaction addr = T.concat
  [ "&", toText name
  , " #", toText transaction
  , " @", T.pack (show addr)
  ]
  where
    toText = T.decodeUtf8With lenientDecode
-- | Reasons an outgoing 'query' can fail; thrown as exceptions.
data QueryFailure
    -- | The query datagram could not be sent.
  = SendFailed
    -- | The remote node replied with an error, or its reply could not be
    -- decoded (reported as 'ProtocolError').
  | QueryFailed ErrorCode Text
    -- | No reply arrived within 'optQueryTimeout' seconds.
  | TimeoutExpired
  deriving (Show, Eq, Typeable)

instance Exception QueryFailure
-- | Bencode a value and send it to the given address with a single
-- 'sendManyTo' call over the encoded chunks.
sendMessage :: (MonadIO m, BEncode a) => Socket -> SockAddr -> a -> m ()
sendMessage sock addr payload = liftIO (sendManyTo sock chunks addr)
  where
    chunks = BL.toChunks (BE.encode payload)
-- | Atomically take the current counter value and advance the counter.
-- The transaction id is the decimal rendering of the value taken.
genTransactionId :: TransactionCounter -> IO TransactionId
genTransactionId ref =
  liftM (BC.pack . show) $ atomicModifyIORef' ref (\n -> (succ n, n))
-- | Number of transaction ids generated so far, i.e. how far the
-- transaction counter has advanced past its seed value.
getQueryCount :: MonadKRPC h m => m Int
getQueryCount = do
  Manager {..} <- getManager
  curTrans <- liftIO $ readIORef transactionCounter
  -- The counter starts at the seed, so subtract it to get the count.
  return $ curTrans - optSeedTransaction options
-- | Allocate an empty result slot for a call and record it in the
-- pending-calls table (replacing any existing entry for the same key).
registerQuery :: CallId -> PendingCalls -> IO CallRes
registerQuery cid ref =
  newEmptyMVar >>= \slot ->
    atomicModifyIORef' ref (\calls -> (M.insert cid slot calls, ()))
      >> return slot
-- | Atomically remove a call from the pending-calls table, returning its
-- result slot if it was still registered.
unregisterQuery :: CallId -> PendingCalls -> IO (Maybe CallRes)
unregisterQuery cid ref = atomicModifyIORef' ref $ \calls ->
  (M.delete cid calls, M.lookup cid calls)
-- | Block until the call's result slot is filled, then decode the result.
--
-- Throws 'QueryFailed' with the remote error code when the peer replied
-- with an error, or with 'ProtocolError' when the response values cannot
-- be decoded as the expected type.  The peer's error message is decoded
-- leniently because it arrives from the network and may not be valid
-- UTF-8; a strict decode would raise an unrelated exception later, when
-- the failure is shown.
queryResponse :: BEncode a => CallRes -> IO a
queryResponse ares = do
  res <- readMVar ares
  case res of
    Left (KError c m _) ->
      throwIO $ QueryFailed c (T.decodeUtf8With lenientDecode m)
    Right (KResponse {..}) ->
      case fromBEncode respVals of
        Right r -> pure r
        Left e  -> throwIO $ QueryFailed ProtocolError (T.pack e)
-- | Send a query datagram, translating any socket 'IOError' into
-- 'SendFailed'.
sendQuery :: BEncode a => Socket -> SockAddr -> a -> IO ()
sendQuery sock addr q =
  sendMessage sock addr q `catch` \ (_ :: IOError) -> throwIO SendFailed
-- | Send a query to a remote node and wait for its response.
--
-- The call is registered in the pending-calls table under its transaction
-- id and the remote address before the datagram is sent.  The listener's
-- response handling fills the registered slot when a reply arrives.
--
-- Throws 'SendFailed' if the datagram cannot be sent, 'QueryFailed' if the
-- remote node answered with an error or an undecodable result, and
-- 'TimeoutExpired' if no answer arrived within 'optQueryTimeout' seconds.
query :: forall h m a b. (MonadKRPC h m, KRPC a b) => SockAddr -> a -> m b
query addr params = do
  Manager {..} <- getManager
  tid <- liftIO $ genTransactionId transactionCounter
  let queryMethod = method :: Method a b
  let signature = querySignature (methodName queryMethod) tid addr
  let callId = (tid, addr)
  $(logDebugS) "query.sending" signature
  mres <- liftIO $ do
    ares <- registerQuery callId pendingCalls
    let q = KQuery (toBEncode params) (methodName queryMethod) tid
    let awaitResponse = do
          sendQuery sock addr q
          timeout (optQueryTimeout options * 10 ^ (6 :: Int)) $
            queryResponse ares
    -- Remove the pending entry if sending or waiting is interrupted by
    -- any exception, including an asynchronous one (e.g. the caller being
    -- killed).  Otherwise the entry would stay in the table forever.  When
    -- a reply was already delivered, the entry is gone and this is a no-op.
    awaitResponse `onException` unregisterQuery callId pendingCalls
  case mres of
    Just res -> do
      $(logDebugS) "query.responded" $ signature
      return res
    Nothing -> do
      _ <- liftIO $ unregisterQuery callId pendingCalls
      $(logWarnS) "query.not_responding" $ signature <> " for " <>
        T.pack (show (optQueryTimeout options)) <> " seconds"
      -- throwIO (not the pure 'throw') so the exception is raised exactly
      -- when this action runs.
      liftIO $ throwIO TimeoutExpired
-- | Exceptions a query handler may throw to reject a query.  The handler
-- runner answers them with a 'ProtocolError' reply.
data HandlerFailure
    -- | Reported to the peer as "bad address".
  = BadAddress
    -- | A query argument was rejected; carries a human-readable reason.
  | InvalidParameter Text
  deriving (Show, Eq, Typeable)

instance Exception HandlerFailure
-- | UTF-8 encoded error message sent to the peer for a 'HandlerFailure'.
prettyHF :: HandlerFailure -> BS.ByteString
prettyHF failure = T.encodeUtf8 (describe failure)
  where
    describe BadAddress                = "bad address"
    describe (InvalidParameter reason) = "invalid parameter: " <> reason
-- | UTF-8 encoded error message sent to the peer when a handler's own
-- outgoing query failed.
prettyQF :: QueryFailure -> BS.ByteString
prettyQF =
  T.encodeUtf8 . T.append "handler fail while performing query: " . T.pack . show
-- | Build a 'Handler' from a typed handler function.
--
-- The method name comes from the 'KRPC' instance for @a@ and @b@.  If the
-- raw arguments cannot be decoded as @a@, the decoder's error is returned
-- as @Left@.  Otherwise the handler runs and its result is encoded with
-- 'toBEncode'.
handler :: forall h a b. (KRPC a b, Monad h)
        => (SockAddr -> a -> h b) -> Handler h
handler body = (name, wrapper)
  where
    Method name = method :: Method a b

    wrapper addr args = either (return . Left) run (fromBEncode args)
      where
        run decoded = liftM (Right . toBEncode) (body addr decoded)
-- | Run a handler body for an incoming query and turn the outcome into a
-- reply.
--
-- * Arguments the handler cannot decode yield a 'ProtocolError' reply.
-- * 'HandlerFailure' exceptions yield a 'ProtocolError' reply.
-- * 'QueryFailure' exceptions (from queries the handler made itself)
--   yield a 'ServerError' reply.
-- * Any other synchronous exception yields a 'GenericError' reply and is
--   logged.
--
-- Asynchronous exceptions (e.g. 'ThreadKilled') are rethrown rather than
-- converted into a reply, so the handling thread can still be cancelled.
runHandler :: MonadKRPC h m
           => HandlerBody h -> SockAddr -> KQuery -> m KResult
runHandler h addr KQuery {..} = Lifted.catches wrapper failbacks
  where
    signature = querySignature queryMethod queryId addr

    wrapper = do
      $(logDebugS) "handler.quered" signature
      result <- liftHandler (h addr queryArgs)
      case result of
        Left msg -> do
          $(logDebugS) "handler.bad_query" $ signature <> " !" <> T.pack msg
          return $ Left $ KError ProtocolError (BC.pack msg) queryId
        Right a -> do
          $(logDebugS) "handler.success" signature
          return $ Right $ KResponse a queryId

    failbacks =
      [ E.Handler $ \ (e :: HandlerFailure) -> do
          $(logDebugS) "handler.failed" signature
          return $ Left $ KError ProtocolError (prettyHF e) queryId
      , E.Handler $ \ (e :: QueryFailure) -> do
          $(logDebugS) "handler.failed" $ signature <> " !" <> T.pack (show e)
          return $ Left $ KError ServerError (prettyQF e) queryId
      , E.Handler $ \ (e :: SomeException) ->
          case fromException e of
            -- Do not swallow cancellation of the handler thread.
            Just (ae :: AsyncException) -> liftIO $ throwIO ae
            Nothing -> do
              $(logWarnS) "handler.failed" $
                signature <> " !" <> T.pack (show e)
              return $ Left $ KError GenericError (BC.pack (show e)) queryId
      ]
-- | Find the handler registered for the query's method and run it.
-- Unknown methods are answered with a 'MethodUnknown' error carrying the
-- method name.
dispatchHandler :: MonadKRPC h m => KQuery -> SockAddr -> m KResult
dispatchHandler q@KQuery {..} addr = do
  Manager {..} <- getManager
  maybe unknown (\ body -> runHandler body addr q)
        (L.lookup queryMethod handlers)
  where
    unknown = return $ Left $ KError MethodUnknown queryMethod queryId
-- | Answer an incoming query in a separate thread: dispatch it to its
-- handler and send the resulting response or error back to the sender.
handleQuery :: MonadKRPC h m => KQuery -> SockAddr -> m ()
handleQuery q addr = void . fork $ do
  mgr   <- getManager
  reply <- dispatchHandler q addr
  sendMessage (sock mgr) addr (either toBEncode toBEncode reply)
-- | Deliver an incoming response or error to the query waiting for it.
-- Replies that match no pending call (by transaction id and sender
-- address) are ignored.
handleResponse :: MonadKRPC h m => KResult -> SockAddr -> m ()
handleResponse result addr = do
  Manager {..} <- getManager
  let callId = (either errorId respId result, addr)
  liftIO $
    unregisterQuery callId pendingCalls
      >>= maybe (return ()) (`putMVar` result)
-- | Route a decoded message: queries go to the handlers, while responses
-- and errors go to the pending call they answer.
handleMessage :: MonadKRPC h m => KMessage -> SockAddr -> m ()
handleMessage msg addr = case msg of
  Q q -> handleQuery q addr
  R r -> handleResponse (Right r) addr
  E e -> handleResponse (Left e) addr
-- | Receive loop: read datagrams of at most 'optMaxMsgSize' bytes, decode
-- them and dispatch them.  Undecodable datagrams get an 'unknownMessage'
-- reply.
--
-- An EOF error from @recvFrom@ is skipped.  There is no sender address in
-- that case, so no reply is attempted.  Failure to send the error reply
-- to a misbehaving peer is logged instead of propagated, so one bad
-- sender cannot terminate the loop.  Other receive errors still propagate.
listener :: MonadKRPC h m => m ()
listener = do
  Manager {..} <- getManager
  forever $ do
    received <- liftIO $ handle eofAsNothing $
      Just <$> BS.recvFrom sock (optMaxMsgSize options)
    case received of
      Nothing -> return ()
      Just (bs, addr) ->
        case BE.decode bs of
          Left e -> do
            sent <- liftIO $ try $ sendMessage sock addr $ unknownMessage e
            case sent of
              Left err -> $(logWarnS) "listener.reply_failed" $
                T.pack (show (err :: IOError))
              Right () -> return ()
          Right m -> handleMessage m addr
  where
    eofAsNothing :: IOError -> IO (Maybe (BS.ByteString, SockAddr))
    eofAsNothing e
      | isEOFError e = return Nothing
      | otherwise    = throwIO e
-- | Start the receive loop in a background thread and record its thread
-- id in the manager, so that 'closeManager' can stop it.
--
-- When the loop ends, the finalizer clears the recorded thread id
-- without blocking.  'closeManager' empties the slot itself before
-- killing the thread; a blocking 'takeMVar' here would leave the killed
-- thread stuck forever in its finalizer.
--
-- NOTE(review): if the loop dies before the parent records its id, the
-- slot keeps a stale (dead) thread id.  'closeManager' tolerates that
-- (killing a finished thread is harmless).
listen :: MonadKRPC h m => m ()
listen = do
  Manager {..} <- getManager
  tid <- fork $
    listener `Lifted.finally` liftIO (void (tryTakeMVar listenerThread))
  liftIO $ putMVar listenerThread tid