{-# LANGUAGE RecordWildCards #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE Rank2Types #-} {-# LANGUAGE MultiWayIf #-} module Hans.Layer.Tcp.Monad where import Hans.Address.IP4 import Hans.Channel import Hans.Layer import Hans.Layer.IP4 import Hans.Layer.Tcp.Types import Hans.Layer.Tcp.Window import Hans.Message.Ip4 import Hans.Message.Tcp import Control.Applicative(Applicative(..)) import Control.Monad (MonadPlus(..),guard,when) import Data.Time.Clock.POSIX (POSIXTime) import MonadLib (get,set) import qualified Data.ByteString.Lazy as L import qualified Data.Map.Strict as Map import qualified Data.Sequence as Seq -- TCP Monad ------------------------------------------------------------------- type TcpHandle = Channel (Tcp ()) type Tcp = Layer TcpState data TcpState = TcpState { tcpSelf :: {-# UNPACK #-} !TcpHandle , tcpIP4 :: {-# UNPACK #-} !IP4Handle , tcpHost :: !Host } emptyTcpState :: TcpHandle -> IP4Handle -> POSIXTime -> TcpState emptyTcpState tcp ip4 start = TcpState { tcpSelf = tcp , tcpIP4 = ip4 , tcpHost = emptyHost start } -- | The handle to this layer. self :: Tcp TcpHandle self = tcpSelf `fmap` get -- | Get the handle to the IP4 layer. ip4Handle :: Tcp IP4Handle ip4Handle = tcpIP4 `fmap` get -- Host Operations ------------------------------------------------------------- getHost :: Tcp Host getHost = tcpHost `fmap` get getLastUpdate :: Tcp POSIXTime getLastUpdate = hostLastUpdate `fmap` getHost setHost :: Host -> Tcp () setHost host = do rw <- get host `seq` set rw { tcpHost = host } modifyHost_ :: (Host -> Host) -> Tcp () modifyHost_ f = do host <- getHost setHost (f host) modifyHost :: (Host -> (a,Host)) -> Tcp a modifyHost f = do host <- getHost let (a,host') = f host setHost host' return a -- | Reset the 2MSL timer on the socket in TimeWait. resetTimeWait2MSL :: SocketId -> Tcp () resetTimeWait2MSL sid = modifyHost_ $ \ host -> host { hostTimeWaits = Map.adjust twReset2MSL sid (hostTimeWaits host) } getTimeWait :: IP4 -> TcpHeader -> Tcp (Maybe (SocketId,TimeWaitSock)) getTimeWait remote hdr = do host <- getHost let sid = incomingSocketId remote hdr return $ do tw <- Map.lookup sid (hostTimeWaits host) return (sid,tw) removeTimeWait :: SocketId -> Tcp () removeTimeWait sid = modifyHost_ $ \ host -> host { hostTimeWaits = Map.delete sid (hostTimeWaits host) } getConnections :: Tcp Connections getConnections = hostConnections `fmap` getHost takeConnections :: Tcp Connections takeConnections = modifyHost $ \ Host { .. } -> (hostConnections, Host { hostConnections = Map.empty, .. }) setConnections :: Connections -> Tcp () setConnections cons = modifyHost_ (\host -> host { hostConnections = cons }) -- | Lookup a connection, returning @Nothing@ if the connection doesn't exist. lookupConnection :: SocketId -> Tcp (Maybe TcpSocket) lookupConnection sid = do cons <- getConnections return (Map.lookup sid cons) -- | Retrieve a connection from the host. The computation fails if the -- connection doesn't exist. getConnection :: SocketId -> Tcp TcpSocket getConnection sid = do cs <- getConnections case Map.lookup sid cs of Just tcp -> return tcp Nothing -> mzero -- | Assign a connection to a socket id. If the TcpSocket is in TimeWait, this -- will do two things: -- -- 1. Remove the corresponding key from the connections map -- 2. Add the socket to the TimeWait map, using the current value of its 2MSL -- timer (which should be set when the TimeWait state is entered) -- -- The purpose of this is to clean up the memory associated with the connection -- as soon as possible, and once it's in TimeWait, no data will flow on the -- socket. setConnection :: SocketId -> TcpSocket -> Tcp () setConnection ident con | tcpState con == TimeWait = modifyHost_ $ \ host -> host { hostTimeWaits = addTimeWait con (hostTimeWaits host) , hostConnections = Map.delete ident (hostConnections host) } | otherwise = do cons <- getConnections setConnections (Map.insert ident con cons) -- | Add a new connection to the host. addConnection :: SocketId -> TcpSocket -> Tcp () addConnection = setConnection -- | Modify an existing connection in the host. modifyConnection :: SocketId -> (TcpSocket -> TcpSocket) -> Tcp () modifyConnection sid k = do cons <- getConnections setConnections (Map.adjust k sid cons) -- | Remove a connection from the host. remConnection :: SocketId -> Tcp () remConnection sid = do cons <- getConnections setConnections (Map.delete sid cons) -- | Send out a tcp segment via the IP layer. sendSegment :: IP4 -> TcpHeader -> L.ByteString -> Tcp () sendSegment dst hdr body = do ip4 <- ip4Handle output $ withIP4Source ip4 dst $ \ src -> do let ip4Hdr = emptyIP4Header { ip4DestAddr = dst , ip4Protocol = tcpProtocol , ip4DontFragment = False } pkt = renderWithTcpChecksumIP4 src dst hdr body sendIP4Packet ip4 ip4Hdr pkt -- | Get the initial sequence number. initialSeqNum :: Tcp TcpSeqNum initialSeqNum = hostInitialSeqNum `fmap` getHost -- | Increment the initial sequence number by a value. addInitialSeqNum :: TcpSeqNum -> Tcp () addInitialSeqNum sn = modifyHost_ (\host -> host { hostInitialSeqNum = hostInitialSeqNum host + sn }) -- | Allocate a new port for use. allocatePort :: Tcp TcpPort allocatePort = do host <- getHost case takePort host of Just (p,host') -> do setHost host' return p Nothing -> mzero -- | Release a used port. closePort :: TcpPort -> Tcp () closePort port = modifyHost_ (releasePort port) -- Socket Monad ---------------------------------------------------------------- -- | Tcp operations in the context of a socket. -- -- This implementation is a bit ridiculous, and when the eventual rewrite comes -- this should be one of the first things to be reconsidered. The basic problem -- is that if you rely on the `finished` implementation for the Layer monad, you -- exit from the socket context as well, losing any changes that have been made -- locally. This gives the ability to simulate `finished`, with the benefit of -- only yielding from the Sock context, not the whole Tcp context. newtype Sock a = Sock { unSock :: forall r. TcpSocket -> Escape r -> Next a r -> Tcp r } type Escape r = TcpSocket -> Tcp r type Next a r = TcpSocket -> a -> Tcp r instance Functor Sock where fmap f m = Sock $ \s x k -> unSock m s x $ \s' a -> k s' (f a) {-# INLINE fmap #-} instance Applicative Sock where {-# INLINE pure #-} pure x = Sock $ \ s _ k -> k s x {-# INLINE (<*>) #-} f <*> a = Sock $ \ s x k -> unSock f s x $ \ s' g -> unSock a s' x $ \ s'' b -> k s'' (g b) instance Monad Sock where {-# INLINE return #-} return = pure m >>= f = Sock $ \ s x k -> unSock m s x $ \ s' a -> unSock (f a) s' x k {-# INLINE (>>=) #-} m >> n = Sock $ \ s x k -> unSock m s x $ \ s' _ -> unSock n s' x k {-# INLINE (>>) #-} inTcp :: Tcp a -> Sock a inTcp m = Sock $ \ s _ k -> do a <- m k s a {-# INLINE inTcp #-} -- | Finish early, with no result. escape :: Sock a escape = Sock $ \ s x _ -> x s runSock_ :: TcpSocket -> Sock a -> Tcp () runSock_ tcp sm = do _ <- runSock tcp sm return () runSock :: TcpSocket -> Sock a -> Tcp (Maybe a) runSock tcp sm = do (tcp',mb) <- runSock' tcp sm setConnection (tcpSocketId tcp') $! tcp' return mb seqMaybe :: Maybe a -> () seqMaybe (Just a) = a `seq` () seqMaybe Nothing = () -- | Run the socket action, and increment its internal timestamp value. Do not -- add it back to the connections map. runSock' :: TcpSocket -> Sock a -> Tcp (TcpSocket,Maybe a) runSock' tcp sm = do now <- time let steppedTcp = tcp { tcpTimestamp = let ts' = stepTimestamp now `fmap` tcpTimestamp tcp in seqMaybe ts' `seq` ts' } res@(tcp',_) <- (unSock sm $! steppedTcp) escapeK nextK tcp' `seq` return res where escapeK s = return (s,Nothing) nextK s a = return (s,Just a) -- | Iterate for each connection, rolling back to its previous state if the -- computation fails. eachConnection :: Sock () -> Tcp () eachConnection m = do cs <- takeConnections (cs',ws) <- sandbox [] [] (Map.elems cs) modifyHost_ $ \ Host { .. } -> Host { hostConnections = cs' , hostTimeWaits = Map.union ws hostTimeWaits , .. } where -- Prevent failure in the socket action from leaking out of this scope. When -- failure is detected, just return the old TCB sandbox active timeWait (tcp:rest) = do tcp' <- fmap fst (runSock' tcp m) `mplus` return tcp if | tcpState tcp' == TimeWait -> sandbox active (tcp':timeWait) rest | tcpState tcp' == Closed && tcpUserClosed tcp' -> sandbox active timeWait rest | otherwise -> sandbox (tcp':active) timeWait rest sandbox active timeWait [] = return ( Map.fromList [ (tcpSocketId tcp, tcp) | tcp <- active ] , Map.fromList [ (tcpSocketId tcp, mkTimeWait tcp) | tcp <- timeWait ]) withConnection :: IP4 -> TcpHeader -> Sock a -> Tcp () withConnection remote hdr m = withConnection' remote hdr m mzero withConnection' :: IP4 -> TcpHeader -> Sock a -> Tcp () -> Tcp () withConnection' remote hdr m noConn = do cs <- getConnections case Map.lookup estId cs `mplus` Map.lookup listenId cs of Just con -> runSock_ con m Nothing -> noConn where estId = incomingSocketId remote hdr listenId = listenSocketId (tcpDestPort hdr) listeningConnection :: SocketId -> Sock a -> Tcp (Maybe a) listeningConnection sid m = do tcp <- getConnection sid guard (tcpState tcp == Listen && isAccepting tcp) mb <- runSock tcp m return mb -- | Run a socket operation in the context of the socket identified by the -- socket id. -- -- XXX this should really be renamed, as it's not guarding on the state of the -- socket establishedConnection :: SocketId -> Sock a -> Tcp () establishedConnection sid m = do tcp <- getConnection sid runSock_ tcp m -- | Get the parent id of the current socket, and fail if it doesn't exist. getParent :: Sock (Maybe SocketId) getParent = tcpParent `fmap` getTcpSocket -- | Run an action in the context of the socket's parent. Returns `Nothing` if -- the connection has no parent. inParent :: Sock a -> Sock (Maybe a) inParent m = do mbPid <- getParent case mbPid of Just pid -> inTcp $ do p <- getConnection pid mb <- runSock p m return mb Nothing -> return Nothing withChild :: TcpSocket -> Sock a -> Sock (Maybe a) withChild tcp m = inTcp $ do mb <- runSock tcp m return mb getTcpSocket :: Sock TcpSocket getTcpSocket = Sock (\s _ k -> k s $! s) {-# INLINE getTcpSocket #-} setTcpSocket :: TcpSocket -> Sock () setTcpSocket tcp = Sock (\ _ _ k -> (k $! tcp) ()) {-# INLINE setTcpSocket #-} getTcpTimers :: Sock TcpTimers getTcpTimers = tcpTimers `fmap` getTcpSocket modifyTcpSocket :: (TcpSocket -> (a,TcpSocket)) -> Sock a modifyTcpSocket f = Sock $ \ s _ k -> let (a,s') = f s in (k $! s') a modifyTcpSocket_ :: (TcpSocket -> TcpSocket) -> Sock () modifyTcpSocket_ k = modifyTcpSocket (\tcp -> ((), k tcp)) modifyTcpTimers :: (TcpTimers -> (a,TcpTimers)) -> Sock a modifyTcpTimers k = modifyTcpSocket $ \ tcp -> let (a,t') = k (tcpTimers tcp) in (a,tcp { tcpTimers = t' }) modifyTcpTimers_ :: (TcpTimers -> TcpTimers) -> Sock () modifyTcpTimers_ k = modifyTcpTimers (\t -> ((), k t)) -- | Set the state of the current connection. setState :: ConnState -> Sock () setState state = modifyTcpSocket_ (\tcp -> tcp { tcpState = state }) -- | Get the state of the current connection. getState :: Sock ConnState getState = tcpState `fmap` getTcpSocket whenState :: ConnState -> Sock () -> Sock () whenState state body = do curState <- getState when (state == curState) body whenStates :: [ConnState] -> Sock () -> Sock () whenStates states body = do curState <- getState when (curState `elem` states) body pushAcceptor :: Acceptor -> Sock () pushAcceptor k = modifyTcpSocket_ $ \ tcp -> tcp { tcpAcceptors = tcpAcceptors tcp Seq.|> k } -- | Pop off an acceptor. popAcceptor :: Sock (Maybe Acceptor) popAcceptor = do tcp <- getTcpSocket case Seq.viewl (tcpAcceptors tcp) of a Seq.:< as -> do setTcpSocket $! tcp { tcpAcceptors = as } return (Just a) Seq.EmptyL -> return Nothing -- | Send a notification back to a waiting process that the socket has been -- established, or that it has failed. It's assumed that this will only be -- called from the context of a user socket, so when the parameter is @False@, -- the user close field will be set to true. notify :: Bool -> Sock () notify success = do mbNotify <- modifyTcpSocket $ \ tcp -> let tcp' = tcp { tcpNotify = Nothing } in if success then (tcpNotify tcp, tcp') else (tcpNotify tcp, tcp' { tcpUserClosed = True }) case mbNotify of Just f -> outputS (f success) Nothing -> return () -- | Output some IO to the Tcp layer. outputS :: IO () -> Sock () outputS = inTcp . output advanceRcvNxt :: TcpSeqNum -> Sock () advanceRcvNxt n = modifyTcpSocket_ (\tcp -> tcp { tcpIn = addRcvNxt n (tcpIn tcp) }) advanceSndNxt :: TcpSeqNum -> Sock () advanceSndNxt n = modifyTcpSocket_ (\tcp -> tcp { tcpSndNxt = tcpSndNxt tcp + n }) remoteHost :: Sock IP4 remoteHost = (sidRemoteHost . tcpSocketId) `fmap` getTcpSocket -- | Send a TCP segment in the context of a socket. tcpOutput :: TcpHeader -> L.ByteString -> Sock () tcpOutput hdr body = do dst <- remoteHost inTcp (sendSegment dst hdr body) -- | Unblock any waiting processes, in preparation to close. shutdown :: Sock () shutdown = do finalize <- modifyTcpSocket $ \ tcp -> let (wOut,bufOut) = flushWaiting (tcpOutBuffer tcp) (wIn,bufIn) = flushWaiting (tcpInBuffer tcp) in (wOut >> wIn,tcp { tcpOut = clearRetransmit (tcpOut tcp) , tcpOutBuffer = bufOut , tcpInBuffer = bufIn }) outputS finalize -- | Set the socket state to closed, and unblock any waiting processes. closeSocket :: Sock () closeSocket = do shutdown setState Closed userClose :: Sock () userClose = modifyTcpSocket_ (\tcp -> tcp { tcpUserClosed = True })