module Control.TimeWarp.Rpc.Transfer
(
Transfer (..)
, TransferException (..)
, ConnectionPool
, runTransfer
, runTransferS
, runTransferRaw
, getConnPool
, FailsInRow
, Settings (..)
) where
import qualified Control.Concurrent as C
import Control.Concurrent.STM (STM, atomically, check)
import qualified Control.Concurrent.STM.TBMChan as TBM
import qualified Control.Concurrent.STM.TChan as TC
import qualified Control.Concurrent.STM.TVar as TV
import Control.Lens (at, at, each, makeLenses, use, view,
                     (.=), (?=), (^..))
import Control.Monad (forM_, forever, guard, unless, when)
import Control.Monad.Base (MonadBase)
import Control.Monad.Catch (Exception, MonadCatch,
                            MonadMask (mask), MonadThrow (..),
                            bracket, bracketOnError, catchAll,
                            finally, handleAll, onException,
                            throwM)
import Control.Monad.Morph (hoist)
import Control.Monad.Reader (MonadReader (ask), ReaderT (..))
import Control.Monad.State (StateT (..))
import Control.Monad.Trans (MonadIO (..), lift)
import Control.Monad.Trans.Control (MonadBaseControl (..))
import Control.Monad.Extra (whenM)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BL
import Data.Conduit (Sink, Source, ($$))
import Data.Conduit.Binary (sinkLbs, sourceLbs)
import Data.Conduit.Network (sinkSocket, sourceSocket)
import Data.Conduit.TMChan (sinkTBMChan, sourceTBMChan)
import Data.Default (Default (..))
import Data.HashMap.Strict (HashMap)
import qualified Data.IORef as IR
import Data.List (intersperse)
import Data.Streaming.Network (acceptSafe, bindPortTCP,
                               getSocketFamilyTCP)
import Data.Text (Text)
import Data.Text.Buildable (Buildable (build), build)
import Data.Text.Encoding (decodeUtf8, decodeUtf8With)
import Data.Text.Encoding.Error (lenientDecode)
import Data.Typeable (Typeable)
import Formatting (bprint, builder, hex, int, sformat, shown,
                   stext, string, (%))
import qualified Network.Socket as NS
import Serokell.Util.Base (inCurrentContext)
import Serokell.Util.Concurrent (modifyTVarS)
import System.Wlog (CanLog, HasLoggerName, LoggerNameBox,
                    Severity (..), WithLogger, logDebug,
                    logInfo, logMessage, logWarning)
import Control.TimeWarp.Manager (InterruptType (..), JobCurator (..),
                                 addManagerAsJob, addSafeThreadJob,
                                 addThreadJob, interruptAllJobs,
                                 isInterrupted, jcIsClosed,
                                 mkJobCurator, stopAllJobs,
                                 unlessInterrupted)
import Control.TimeWarp.Rpc.MonadTransfer (Binding (..), MonadTransfer (..),
                                           NetworkAddress, Port,
                                           ResponseContext (..), ResponseT,
                                           commLog, runResponseT, runResponseT,
                                           sendRaw)
import Control.TimeWarp.Timed (Microsecond, MonadTimed, ThreadId,
                               TimedIO, for, fork, fork_, interval,
                               killThread, sec, wait)
-- | Log a message at the given severity while the 'JobCurator' is still
-- running; once the curator has been interrupted, demote the message to
-- 'Debug'.
--
-- Intended for errors that may merely be a side effect of shutting the
-- curator down: they are reported loudly only if no shutdown is in progress.
logSeverityUnlessClosed :: (WithLogger m, MonadIO m)
                        => Severity    -- ^ severity to use while not closed
                        -> JobCurator  -- ^ curator whose state is checked
                        -> Text        -- ^ message
                        -> m ()
logSeverityUnlessClosed severityIfNotClosed jm msg = do
    closed <- isInterrupted jm
    -- As the parameter name says: the caller's severity applies only while
    -- the curator is NOT closed; after interruption we fall back to Debug.
    let severity = if closed then Debug else severityIfNotClosed
    logMessage severity msg
-- | Errors the 'Transfer' implementation reports to its users.
data TransferException
    = AlreadyListeningOutbound Text
      -- ^ A listener was requested for an outbound connection which already
      -- has one attached; carries the peer address.
    deriving (Show, Typeable)

instance Exception TransferException

instance Buildable TransferException where
    build err = case err of
        AlreadyListeningOutbound peer ->
            bprint ("Already listening at outbound connection to "%stext) peer

-- | Thrown when the stream of data coming from a socket ends while the
-- connection's job curator has not been interrupted.
data PeerClosedConnection = PeerClosedConnection
    deriving (Show, Typeable)

instance Exception PeerClosedConnection

instance Buildable PeerClosedConnection where
    build _ = "Peer closed connection"
-- | Textual address of a remote peer; used in log messages and exposed to
-- listeners as 'respPeerAddr'.
type PeerAddr = Text

-- | Handle to an outbound connection, as stored in the 'ConnectionPool'.
-- Values are built from a 'SocketFrame' by 'sfMkOutputConn'.
data OutputConnection s = OutputConnection
    { outConnSend :: forall m . (MonadIO m, MonadMask m, WithLogger m)
                  => Source m BS.ByteString -> m ()
      -- ^ Enqueue data for sending to the peer (implemented by 'sfSend').
    , outConnRec :: forall m . (MonadIO m, MonadMask m, MonadTimed m,
                                MonadBaseControl IO m, WithLogger m)
                 => Sink BS.ByteString (ResponseT s m) () -> m ()
      -- ^ Attach a listener to data received from the peer (implemented by
      -- 'sfReceive').
    , outConnJobCurator :: JobCurator
      -- ^ Curator owning the jobs that serve this connection.
    , outConnAddr :: PeerAddr
      -- ^ Address of the peer.
    , outConnUserState :: s
      -- ^ User state created for this connection.
    }

-- | Number of consecutive failed attempts to work with a connection; reset
-- to zero after a successful connect and passed to 'reconnectPolicy'.
type FailsInRow = Int
-- | Tunables of the 'Transfer' implementation.
data Settings = Settings
    { queueSize       :: Int
      -- ^ Capacity of each per-connection incoming and outgoing queue.
    , reconnectPolicy :: forall m . (HasLoggerName m, MonadIO m)
                      => FailsInRow -> m (Maybe Microsecond)
      -- ^ Given the number of failures in a row, return the delay before
      -- the next attempt, or 'Nothing' to give the connection up.
    }

-- | Queues hold 100 messages; a failed connection is retried every
-- 3 seconds until three attempts in a row have failed.
instance Default Settings where
    def = Settings
        { queueSize       = 100
        , reconnectPolicy = \fails ->
              return $ if fails < 3 then Just (interval 3 sec) else Nothing
        }
-- | Registry of currently open outbound connections, keyed by the address
-- they connect to.
newtype ConnectionPool s = ConnectionPool
    { _outputConn :: HashMap NetworkAddress (OutputConnection s)
    }

makeLenses ''ConnectionPool

-- | A pool containing no connections.
initConnectionPool :: ConnectionPool s
initConnectionPool = ConnectionPool mempty
-- | State of a single socket, shared between the threads that move data over
-- the socket ('sfProcessSocket') and the user-facing operations ('sfSend',
-- 'sfReceive', 'sfClose').
data SocketFrame s = SocketFrame
    { sfPeerAddr   :: PeerAddr
      -- ^ Address of the remote side, used for logging.
    , sfInBusy     :: TV.TVar Bool
      -- ^ Set when a listener is attached by 'sfReceive'; a second attempt
      -- throws 'AlreadyListeningOutbound'.
    , sfInChan     :: TBM.TBMChan BS.ByteString
      -- ^ Chunks read from the socket, consumed by the listener.
    , sfOutChan    :: TBM.TBMChan (BL.ByteString, IO ())
      -- ^ Messages waiting to be written to the socket, each paired with an
      -- action run after the message has been written.
    , sfJobCurator :: JobCurator
      -- ^ Curator of the jobs serving this socket; 'sfProcessSocket' stops
      -- its I/O threads once this curator is closed.
    , sfUserState  :: s
      -- ^ User state created for this socket.
    }
-- | Allocate a fresh 'SocketFrame' for the given peer.
--
-- Both queues are bounded by 'queueSize'. The input flag starts cleared, the
-- frame gets a brand-new job curator, and @mkUserState@ is run once to
-- create the frame's user state.
mkSocketFrame :: MonadIO m
              => Settings -> IO s -> PeerAddr -> m (SocketFrame s)
mkSocketFrame settings mkUserState peerAddr = liftIO $
    SocketFrame peerAddr
        <$> TV.newTVarIO False
        <*> newQueue
        <*> newQueue
        <*> mkJobCurator
        <*> mkUserState
  where
    -- Used at two element types, hence the explicit polymorphic signature.
    newQueue :: IO (TBM.TBMChan a)
    newQueue = TBM.newTBMChanIO (queueSize settings)
-- | Send the whole content of @src@ through the socket frame.
--
-- The source is collected into a lazy bytestring and enqueued on the
-- outgoing queue together with a notifier. The call then blocks until the
-- sending thread reports the message written to the socket, or until the
-- frame's job curator is closed, whichever comes first.
sfSend :: (MonadIO m, WithLogger m)
       => SocketFrame s -> Source m BS.ByteString -> m ()
sfSend SocketFrame{..} src = do
    payload <- src $$ sinkLbs
    warnAboutQueue
    sentFlag <- liftIO $ TV.newTVarIO False
    let markSent  = TV.writeTVar sentFlag True
        awaitSent = TV.readTVar sentFlag >>= check
    liftIO $ atomically $
        TBM.writeTBMChan sfOutChan (payload, atomically markSent)
    -- The curator state is read in the same transaction, so closing the
    -- curator wakes this transaction up even if the message is never sent.
    liftIO $ atomically $ do
        isClosed <- view jcIsClosed <$> TV.readTVar (getJobCurator sfJobCurator)
        unless isClosed awaitSent
  where
    warnAboutQueue = do
        full <- liftIO $ atomically $ TBM.isFullTBMChan sfOutChan
        when full $
            commLog $ logWarning $
                sformat ("Send channel for "%shown%" is full") sfPeerAddr
        chanClosed <- liftIO $ atomically $ TBM.isClosedTBMChan sfOutChan
        when chanClosed $
            commLog $ logWarning $
                sformat ("Send channel for "%shown%" is closed, message wouldn't be sent")
                    sfPeerAddr
-- | Attach @sink@ as the listener of data received on the socket frame.
--
-- Only one listener per frame is allowed; a second call throws
-- 'AlreadyListeningOutbound'. The listener feeds 'sfInChan' into @sink@ under
-- a 'ResponseContext' built from the frame. It runs as a thread job of a new
-- job curator, and that curator is registered as a job of the frame's curator
-- with a 3-second interrupt timeout.
-- NOTE(review): 'addThreadJob' presumably forks, so this likely returns
-- without waiting for the listener to finish — confirm.
--
-- If the listener throws while the frame's curator is not interrupted, the
-- error is logged as a warning and all of the frame's jobs are interrupted.
sfReceive :: (MonadIO m, MonadMask m, MonadTimed m, WithLogger m,
              MonadBaseControl IO m)
          => SocketFrame s -> Sink BS.ByteString (ResponseT s m) () -> m ()
sfReceive sf@SocketFrame{..} sink = do
    -- Claim the single listener slot atomically; the flag is never reset.
    busy <- liftIO . atomically $ TV.swapTVar sfInBusy True
    when busy $ throwM $ AlreadyListeningOutbound sfPeerAddr
    liManager <- mkJobCurator
    onTimeout <- inCurrentContext logOnInterruptTimeout
    let interruptType = WithTimeout (interval 3 sec) onTimeout
    -- Registration runs with async exceptions masked; only the listener
    -- body itself is unmasked.
    mask $ \unmask -> do
        addManagerAsJob sfJobCurator interruptType liManager
        addThreadJob liManager $ unmask $ logOnErr $ do
            (sourceTBMChan sfInChan $$ sink) `runResponseT` sfMkResponseCtx sf
            logListeningHappilyStopped
  where
    -- Errors are only reported (and the frame torn down) when they do not
    -- stem from an already-interrupted frame.
    logOnErr = handleAll $ \e ->
        unlessInterrupted sfJobCurator $ do
            commLog . logWarning $ sformat ("Server error: "%shown) e
            interruptAllJobs sfJobCurator Plain
    logOnInterruptTimeout = commLog . logDebug $
        sformat ("While closing socket to "%stext%" listener "%
                 "worked for too long, closing with no regard to it") sfPeerAddr
    logListeningHappilyStopped =
        commLog . logDebug $
            sformat ("Listening on socket to "%stext%" happily stopped") sfPeerAddr
-- | Shut a socket frame down.
--
-- All of the frame's jobs are interrupted first. Then, in one transaction,
-- both queues are closed and any data still pending in the incoming queue
-- is discarded.
sfClose :: SocketFrame s -> IO ()
sfClose sf = do
    interruptAllJobs (sfJobCurator sf) Plain
    atomically $ closeAndDrain (sfInChan sf) (sfOutChan sf)
  where
    closeAndDrain inChan outChan = do
        TBM.closeTBMChan inChan
        TBM.closeTBMChan outChan
        drain inChan
    -- Reads until the (now closed) channel reports it is empty.
    drain chan = do
        next <- TBM.tryReadTBMChan chan
        case next of
            Nothing -> return ()
            Just _  -> drain chan
-- | Expose a socket frame as an 'OutputConnection' to be kept in the pool.
sfMkOutputConn :: SocketFrame s -> OutputConnection s
sfMkOutputConn sf@SocketFrame{..} = OutputConnection
    { outConnAddr       = sfPeerAddr
    , outConnUserState  = sfUserState
    , outConnJobCurator = sfJobCurator
    , outConnSend       = sfSend sf
    , outConnRec        = sfReceive sf
    }

-- | Build the 'ResponseContext' handed to listeners of a socket frame, so
-- that they can reply to, close, or inspect the peer.
sfMkResponseCtx :: SocketFrame s -> ResponseContext s
sfMkResponseCtx sf@SocketFrame{..} = ResponseContext
    { respPeerAddr  = sfPeerAddr
    , respUserState = sfUserState
    , respSend      = sfSend sf
    , respClose     = sfClose sf
    }
-- | Move data between a connected socket and the frame's queues until the
-- frame's job curator is closed or an I/O thread fails.
--
-- Three threads are forked:
--
-- * a sender, writing messages from 'sfOutChan' to the socket;
-- * a receiver, copying socket input into 'sfInChan';
-- * a watcher, waiting for the job curator to be closed.
--
-- The first event on an internal channel decides the outcome. If the curator
-- closes, the watcher kills the I/O threads and this function returns
-- normally. If either I/O thread throws, all three threads are killed and
-- the exception is rethrown.
sfProcessSocket :: (MonadIO m, MonadMask m, MonadTimed m, WithLogger m)
                => SocketFrame s -> NS.Socket -> m ()
sfProcessSocket SocketFrame{..} sock = do
    -- Carries either an I/O thread's exception (Left) or the watcher's
    -- "curator closed" signal (Right).
    eventChan <- liftIO TC.newTChanIO
    stid <- fork $ reportErrors eventChan foreverSend $
        sformat ("foreverSend on "%stext) sfPeerAddr
    rtid <- fork $ reportErrors eventChan foreverRec $
        sformat ("foreverRec on "%stext) sfPeerAddr
    commLog . logDebug $ sformat ("Start processing of socket to "%stext) sfPeerAddr
    -- Watcher: blocks until the curator is marked closed.
    ctid <- fork $ do
        let jm = getJobCurator sfJobCurator
        liftIO . atomically $ check . view jcIsClosed =<< TV.readTVar jm
        liftIO . atomically $
            TC.writeTChan eventChan $ Right ()
        mapM_ killThread [stid, rtid]
    let onError e = do
            mapM_ killThread [stid, rtid, ctid]
            throwM e
    event <- liftIO . atomically $ TC.readTChan eventChan
    commLog . logDebug $ sformat ("Stop processing socket to "%stext) sfPeerAddr
    either onError return event
  where
    -- Each message is taken from the queue with async exceptions masked.
    -- If writing it fails, it is pushed back to the front of the queue so it
    -- is not lost. After a successful write its notifier is run. The loop
    -- ends once reading the queue yields 'Nothing'.
    foreverSend =
        mask $ \unmask -> do
            datm <- liftIO . atomically $ TBM.readTBMChan sfOutChan
            forM_ datm $
                \dat@(bs, notif) -> do
                    let pushback = liftIO . atomically $ TBM.unGetTBMChan sfOutChan dat
                    unmask (sourceLbs bs $$ sinkSocket sock) `onException` pushback
                    liftIO notif
                    unmask foreverSend
    -- When the socket stream ends while the curator is still running, the
    -- peer is considered to have closed the connection.
    foreverRec = do
        hoist liftIO (sourceSocket sock) $$ sinkTBMChan sfInChan False
        unlessInterrupted sfJobCurator $
            throwM PeerClosedConnection
    -- Forward any exception from @action@ to the event channel, logging it
    -- at Debug level with the given description.
    reportErrors eventChan action desc =
        action `catchAll` \e -> do
            commLog . logDebug $ sformat ("Caught error on "%stext%": " % shown) desc e
            liftIO . atomically . TC.writeTChan eventChan . Left $ e
-- | Socket-based implementation of 'MonadTransfer'.
--
-- The stack of readers carries, from the outside in:
--
-- * the 'Settings';
-- * the shared 'ConnectionPool' of outbound connections;
-- * the action creating user state for every new connection.
--
-- All of this runs over a logger and 'TimedIO'.
newtype Transfer s a = Transfer
    { getTransfer :: ReaderT Settings
                        (ReaderT (TV.TVar (ConnectionPool s))
                            (ReaderT (IO s)
                                (LoggerNameBox
                                    TimedIO
                                )
                            )
                        ) a
    } deriving (Functor, Applicative, Monad, MonadIO, MonadBase IO,
                MonadThrow, MonadCatch, MonadMask, MonadTimed, CanLog, HasLoggerName)

-- | Thread identifiers in 'Transfer' are "Control.Concurrent" thread ids.
type instance ThreadId (Transfer s) = C.ThreadId
-- | Run a 'Transfer' computation with explicitly supplied settings,
-- connection pool and user-state constructor.
--
-- Passing the same pool to several runs makes them share one registry of
-- outbound connections.
runTransferRaw
    :: Settings
    -> TV.TVar (ConnectionPool s)
    -> IO s
    -> Transfer s a
    -> LoggerNameBox TimedIO a
runTransferRaw s m us t =
    runReaderT (runReaderT (runReaderT (getTransfer t) s) m) us
-- | Run a 'Transfer' computation with the given settings and a fresh, empty
-- connection pool.
runTransferS :: Settings -> IO s -> Transfer s a -> LoggerNameBox TimedIO a
runTransferS s us t =
    liftIO (TV.newTVarIO initConnectionPool) >>= \pool -> runTransferRaw s pool us t
-- | 'runTransferS' using the default 'Settings'.
runTransfer :: IO s -> Transfer s a -> LoggerNameBox TimedIO a
runTransfer mkUserState = runTransferS def mkUserState
-- | Atomically apply a state transition to the shared 'ConnectionPool' and
-- return its result.
modifyManager :: StateT (ConnectionPool s) STM a -> Transfer s a
modifyManager how = Transfer $ lift $ do
    poolVar <- ask
    liftIO $ atomically $ modifyTVarS poolVar how

-- | The 'TV.TVar' holding the connection pool used by this 'Transfer'.
getConnPool :: Transfer s (TV.TVar (ConnectionPool s))
getConnPool = Transfer (lift ask)
-- | Render a socket address as human-readable text, used as the peer address
-- of inbound connections.
--
-- * IPv4: @a.b.c.d:port@.
-- * IPv6: the eight 16-bit groups in hexadecimal, separated by colons and
--   enclosed in brackets so the port suffix stays unambiguous:
--   @[h:h:h:h:h:h:h:h]:port@. Dotted decimal is not an IPv6 notation and
--   would look like a malformed IPv4 address.
-- * Unix sockets: the socket path.
-- * CAN sockets: @can:<addr>@.
buildSockAddr :: NS.SockAddr -> PeerAddr
buildSockAddr (NS.SockAddrInet port host) =
    sformat (builder%":"%int) (buildHost host) port
  where
    buildHost = mconcat . intersperse "."
              . map build . (^.. each) . NS.hostAddressToTuple
buildSockAddr (NS.SockAddrInet6 port _ host _) =
    sformat ("["%builder%"]:"%int) (buildHost6 host) port
  where
    buildHost6 = mconcat . intersperse ":"
               . map (bprint hex) . (^.. each) . NS.hostAddress6ToTuple
buildSockAddr (NS.SockAddrUnix addr) = sformat string addr
buildSockAddr (NS.SockAddrCan addr) = sformat ("can:"%int) addr
-- | Render a @(host, port)@ network address as @host:port@.
--
-- The host bytes are decoded leniently: invalid UTF-8 sequences become
-- U+FFFD instead of throwing. The result is only used as a peer label in log
-- messages and socket frames, so a malformed host must not crash the caller.
buildNetworkAddress :: NetworkAddress -> PeerAddr
buildNetworkAddress (host, port) =
    sformat (stext%":"%int) (decodeUtf8With lenientDecode host) port
-- | Start a TCP server on the given port and serve every accepted connection
-- with @sink@.
--
-- The server runs as a thread job of its own job curator. Every accepted
-- connection gets a 'SocketFrame' whose curator is registered under the
-- server's curator. The returned action stops the server by calling
-- 'stopAllJobs' on that curator, logging before and after.
listenInbound :: Port
              -> Sink BS.ByteString (ResponseT s (Transfer s)) ()
              -> Transfer s (Transfer s ())
listenInbound (fromIntegral -> port) sink = do
    serverJobCurator <- mkJobCurator
    -- The listening socket is closed by 'bracketOnError' if registering the
    -- job fails, and by 'finally' once the server job itself ends.
    bracketOnError (liftIO $ bindPortTCP port "*") (liftIO . NS.close) $
        \lsocket -> mask $
            \unmask -> addThreadJob serverJobCurator $
                flip finally (liftIO $ NS.close lsocket) . unmask $
                    handleAll (logOnServerError serverJobCurator) $
                        serve lsocket serverJobCurator
    inCurrentContext $ do
        commLog . logDebug $ sformat ("Stopping server at "%int) port
        stopAllJobs serverJobCurator
        commLog . logDebug $ sformat ("Server at "%int%" fully stopped") port
  where
    -- Accept loop: each connection is handled in its own forked thread, which
    -- starts masked and always closes the accepted socket.
    serve lsocket serverJobCurator = forever $
        bracketOnError (liftIO $ acceptSafe lsocket) (liftIO . NS.close . fst) $
            \(sock, addr) -> mask $
                \unmask -> fork_ $ do
                    settings <- Transfer ask
                    us <- Transfer . lift . lift $ ask
                    sf@SocketFrame{..} <- mkSocketFrame settings us $ buildSockAddr addr
                    addManagerAsJob serverJobCurator Plain sfJobCurator
                    logNewInputConnection sfPeerAddr
                    unmask (processSocket sock sf serverJobCurator)
                        `finally` liftIO (NS.close sock)
    -- Attach the user's listener, then pump data through the socket unless
    -- the server (@jc@) has already been interrupted.
    -- NOTE(review): ReuseAddr affects binding; on an accepted socket it
    -- appears to have no effect — confirm whether it is needed.
    processSocket sock sf@SocketFrame{..} jc = do
        liftIO $ NS.setSocketOption sock NS.ReuseAddr 1
        sfReceive sf sink
        unlessInterrupted jc $
            handleAll (logErrorOnServerSocketProcessing jc sfPeerAddr) $ do
                sfProcessSocket sf sock
                logInputConnHappilyClosed sfPeerAddr
    logNewInputConnection addr =
        commLog . logDebug $
            sformat ("New input connection: "%int%" <- "%stext)
                port addr
    logErrorOnServerSocketProcessing jm addr e =
        logSeverityUnlessClosed Warning jm $
            sformat ("Error in server socket "%int%" connected with "%stext%": "%shown)
                port addr e
    logOnServerError jm e =
        logSeverityUnlessClosed Error jm $
            sformat ("Server at port "%int%" stopped with error "%shown) port e
    logInputConnHappilyClosed addr =
        commLog . logInfo $
            sformat ("Happily closing input connection "%int%" <- "%stext)
                port addr
-- | Attach @sink@ as the listener of the outbound connection to @addr@,
-- opening the connection if necessary.
--
-- Returns an action that stops all jobs of that connection.
listenOutbound :: NetworkAddress
               -> Sink BS.ByteString (ResponseT s (Transfer s)) ()
               -> Transfer s (Transfer s ())
listenOutbound addr sink = do
    conn <- getOutConnOrOpen addr
    let stopper = stopAllJobs (outConnJobCurator conn)
    outConnRec conn sink
    pure stopper
-- | Get the pooled outbound connection to @addr@, opening a new one if the
-- pool has none.
--
-- A new connection is put into the 'ConnectionPool' right away. A worker job
-- attached to its job curator then connects the socket, retrying according
-- to 'reconnectPolicy', and pumps data through it with 'sfProcessSocket'.
-- When the worker ends, for whatever reason, the connection's jobs are
-- interrupted and it is removed from the pool.
getOutConnOrOpen :: NetworkAddress -> Transfer s (OutputConnection s)
getOutConnOrOpen addr@(host, fromIntegral -> port) =
    mask $
        \unmask -> do
            (conn, sfm) <- ensureConnExist
            -- Only the caller that actually inserted the connection starts
            -- its worker; 'releaseConn' runs even if the worker dies.
            forM_ sfm $
                \sf -> addSafeThreadJob (sfJobCurator sf) $
                    unmask (startWorker sf) `finally` releaseConn sf
            return conn
  where
    addrName = buildNetworkAddress addr

    -- Returns the pooled connection, plus the fresh 'SocketFrame' when this
    -- call created it. The pool is checked twice: the first lookup avoids
    -- allocating a frame in the common case; the second, in the same STM
    -- transaction as the insertion, resolves races with concurrent callers.
    ensureConnExist = do
        settings <- Transfer ask
        let getOr m act = maybe act (return . (, Nothing)) m
        mconn <- modifyManager $ use $ outputConn . at addr
        getOr mconn $ do
            us <- Transfer . lift . lift $ ask
            sf <- mkSocketFrame settings us addrName
            let conn = sfMkOutputConn sf
            modifyManager $ do
                mres <- use $ outputConn . at addr
                getOr mres $ do
                    outputConn . at addr ?= conn
                    return (conn, Just sf)

    startWorker sf = do
        failsInRow <- liftIO $ IR.newIORef 0
        commLog . logDebug $ sformat ("Lively socket to "%stext%" created, processing")
            (sfPeerAddr sf)
        withRecovery sf failsInRow $
            bracket (liftIO $ fst <$> getSocketFamilyTCP host port NS.AF_UNSPEC)
                    (liftIO . NS.close) $
                \sock -> do
                    -- A successful connect resets the failure counter.
                    liftIO $ IR.writeIORef failsInRow 0
                    commLog . logDebug $
                        sformat ("Established connection to "%stext) (sfPeerAddr sf)
                    sfProcessSocket sf sock

    -- Run @action@, retrying after failures according to 'reconnectPolicy'
    -- until it completes normally, the policy gives up, or the connection's
    -- job curator is interrupted.
    --
    -- The handler only captures the exception; the retry happens after the
    -- handler has returned. Exception handlers run with asynchronous
    -- exceptions masked. Retrying from inside the handler would therefore run
    -- every reconnected worker masked, and would nest one more handler frame
    -- per reconnect.
    withRecovery sf failsInRow action = do
        outcome <- (action >> return Nothing) `catchAll` (return . Just)
        forM_ outcome $ \e -> do
            closed <- isInterrupted (sfJobCurator sf)
            unless closed $ do
                commLog . logWarning $
                    sformat ("Error while working with socket to "%stext%": "%shown)
                        addrName e
                reconnect <- reconnectPolicy <$> Transfer ask
                fails <- liftIO $ succ <$> IR.readIORef failsInRow
                liftIO $ IR.writeIORef failsInRow fails
                maybeReconnect <- reconnect fails
                case maybeReconnect of
                    Nothing ->
                        commLog . logWarning $
                            sformat ("Can't connect to "%shown%", closing connection") addr
                    Just delay -> do
                        commLog . logWarning $
                            sformat ("Reconnect in "%shown) delay
                        wait (for delay)
                        withRecovery sf failsInRow action

    -- Tear the connection down and drop it from the pool.
    releaseConn sf = do
        interruptAllJobs (sfJobCurator sf) Plain
        modifyManager $ outputConn . at addr .= Nothing
        commLog . logDebug $
            sformat ("Socket to "%stext%" closed") addrName
-- | Socket-based 'MonadTransfer': outbound traffic goes through pooled
-- connections (opened on demand), while listening either starts a TCP
-- server on a port or attaches to an outbound connection.
instance MonadTransfer s (Transfer s) where
    sendRaw addr src =
        getOutConnOrOpen addr >>= \conn -> outConnSend conn src

    listenRaw binding = case binding of
        AtPort port   -> listenInbound port
        AtConnTo addr -> listenOutbound addr

    -- Interrupting the jobs makes the connection's worker finish, which in
    -- turn removes it from the pool.
    close addr =
        modifyManager (use (outputConn . at addr))
            >>= mapM_ (\conn -> interruptAllJobs (outConnJobCurator conn) Plain)

    userState = fmap outConnUserState . getOutConnOrOpen
-- | Lift control operations from 'IO' through the reader stack of
-- 'Transfer'. Its monadic state is that of the underlying
-- @LoggerNameBox TimedIO@.
instance MonadBaseControl IO (Transfer s) where
    type StM (Transfer s) a = StM (LoggerNameBox TimedIO) a
    liftBaseWith f =
        Transfer . liftBaseWith $ \runInner -> f (runInner . getTransfer)
    restoreM st = Transfer (restoreM st)