{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}

module Hans.Message.Tcp where

import Hans.Address.IP4 (IP4)
import Hans.Message.Ip4 (mkIP4PseudoHeader,IP4Protocol(..))
import Hans.Utils (chunk)
import Hans.Utils.Checksum

import Control.Monad (unless,ap,replicateM_,replicateM)
import Data.Bits ((.&.),setBit,testBit,shiftL,shiftR)
import Data.List (find)
import Data.Monoid (Monoid(..))
import Data.Serialize
    (Get,Put,Putter,getWord16be,putWord16be,getWord32be,putWord32be,getWord8
    ,putWord8,putByteString,getBytes,remaining,label,isolate,skip,runGet,runPut
    ,putLazyByteString)
import Data.Word (Word8,Word16,Word32)
import System.IO.Unsafe (unsafePerformIO)
import qualified Data.ByteString.Lazy as L
import qualified Data.ByteString as S
import qualified Data.Foldable as F


-- Tcp Support Types -----------------------------------------------------------

-- | The 'IP4Protocol' value used for TCP (protocol number 6).
tcpProtocol :: IP4Protocol
tcpProtocol = IP4Protocol 0x6

-- | A TCP port, stored as a 16-bit word.
newtype TcpPort = TcpPort
  { getPort :: Word16
  } deriving (Eq,Ord,Read,Show,Num,Enum,Bounded)

-- | Render a port as a big-endian 16-bit word.
putTcpPort :: Putter TcpPort
putTcpPort = putWord16be . getPort

-- | Parse a port from a big-endian 16-bit word.
getTcpPort :: Get TcpPort
getTcpPort = fmap TcpPort getWord16be

-- | A TCP sequence number, stored as a 32-bit word.
newtype TcpSeqNum = TcpSeqNum
  { getSeqNum :: Word32
  } deriving (Eq,Ord,Show,Num,Bounded,Enum,Real,Integral)

-- | Sequence numbers combine by addition, with zero as the identity.
instance Monoid TcpSeqNum where
  mempty                              = TcpSeqNum 0
  mappend (TcpSeqNum a) (TcpSeqNum b) = TcpSeqNum (a + b)

-- | Render a sequence number as a big-endian 32-bit word.
putTcpSeqNum :: Putter TcpSeqNum
putTcpSeqNum = putWord32be . getSeqNum

-- | Parse a sequence number from a big-endian 32-bit word.
getTcpSeqNum :: Get TcpSeqNum
getTcpSeqNum = fmap TcpSeqNum getWord32be

-- | An alias to TcpSeqNum, as these two are used in the same role.
type TcpAckNum = TcpSeqNum

-- | Render an acknowledgment number; this is 'putTcpSeqNum' under another name.
putTcpAckNum :: Putter TcpAckNum
putTcpAckNum = putTcpSeqNum

-- | Parse an acknowledgment number; this is 'getTcpSeqNum' under another name.
getTcpAckNum :: Get TcpAckNum
getTcpAckNum = getTcpSeqNum


-- Tcp Header ------------------------------------------------------------------

--  0                   1                   2                   3
--  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
-- +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
-- |          Source Port          |       Destination Port        |
-- +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
-- |                        Sequence Number                        |
-- +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
-- |                     Acknowledgment Number                     |
-- +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
-- |  Data |       |C|E|U|A|P|R|S|F|                               |
-- | Offset|  Res. |W|C|R|C|S|S|Y|I|            Window             |
-- |       |       |R|E|G|K|H|T|N|N|                               |
-- +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
-- |           Checksum            |        Urgent Pointer         |
-- +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
-- |                    Options                    |    Padding    |
-- +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
-- |                             data                              |
-- +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+

-- | The fields of a TCP header.  The data offset is not stored; it is derived
-- from the options when rendering, and returned separately when parsing.
data TcpHeader = TcpHeader
  { tcpSourcePort    :: !TcpPort
  , tcpDestPort      :: !TcpPort
  , tcpSeqNum        :: !TcpSeqNum
  , tcpAckNum        :: !TcpAckNum
  , tcpCwr           :: !Bool
  , tcpEce           :: !Bool
  , tcpUrg           :: !Bool
  , tcpAck           :: !Bool
  , tcpPsh           :: !Bool
  , tcpRst           :: !Bool
  , tcpSyn           :: !Bool
  , tcpFin           :: !Bool
  , tcpWindow        :: !Word16
  , tcpChecksum      :: !Word16
  , tcpUrgentPointer :: !Word16
  , tcpOptions       :: [TcpOption]
  } deriving (Eq,Show)

-- | Option lookup and replacement on a header delegate to its option list.
instance HasTcpOptions TcpHeader where
  findTcpOption tag hdr = findTcpOption tag (tcpOptions hdr)
  setTcpOption opt hdr  = hdr { tcpOptions = setTcpOption opt (tcpOptions hdr) }

-- | A header with every numeric field zero, every flag clear, and no options.
emptyTcpHeader :: TcpHeader
emptyTcpHeader  = TcpHeader
  { tcpSourcePort    = TcpPort 0
  , tcpDestPort      = TcpPort 0
  , tcpSeqNum        = 0
  , tcpAckNum        = 0
  , tcpCwr           = False
  , tcpEce           = False
  , tcpUrg           = False
  , tcpAck           = False
  , tcpPsh           = False
  , tcpRst           = False
  , tcpSyn           = False
  , tcpFin           = False
  , tcpWindow        = 0
  , tcpChecksum      = 0
  , tcpUrgentPointer = 0
  , tcpOptions       = []
  }

-- | The length of the fixed part of the TcpHeader, in 4-byte words.
tcpFixedHeaderLength :: Int
tcpFixedHeaderLength  = 5

-- | Render a TcpHeader.  The checksum value is never rendered, as it is
-- expected to be calculated and poked in afterwards; a zero is written in its
-- place.  The options are padded out to a 4-byte boundary with end-of-options
-- tags.
putTcpHeader :: Putter TcpHeader
putTcpHeader hdr = do
  putTcpPort (tcpSourcePort hdr)
  putTcpPort (tcpDestPort hdr)
  putTcpSeqNum (tcpSeqNum hdr)
  putTcpAckNum (tcpAckNum hdr)
  let (optLen,padding) = tcpOptionsLength (tcpOptions hdr)
  -- The data offset lives in the upper four bits of this octet, so a header
  -- longer than 15 words cannot be represented; the conversion to Word8
  -- silently drops the excess bits.
  putWord8 (fromIntegral ((tcpFixedHeaderLength + optLen) `shiftL` 4))
  putTcpControl hdr
  putWord16be (tcpWindow hdr)
  putWord16be 0 -- checksum placeholder
  putWord16be (tcpUrgentPointer hdr)
  mapM_ putTcpOption (tcpOptions hdr)
  replicateM_ padding (putTcpOptionTag OptTagEndOfOptions)

-- | Parse out a TcpHeader, and its length.  The resulting length is in bytes,
-- and is derived from the data offset.
--
-- Parsing fails when the data offset is smaller than the fixed header length,
-- as such a header cannot contain the mandatory fields.  End-of-options
-- entries are dropped from the parsed option list.
getTcpHeader :: Get (TcpHeader,Int)
getTcpHeader  = label "TcpHeader" $ do
  src    <- getTcpPort
  dst    <- getTcpPort
  seqNum <- getTcpSeqNum
  ackNum <- getTcpAckNum
  b      <- getWord8
  let len = fromIntegral ((b `shiftR` 4) .&. 0xf)
  -- A data offset below the fixed header length would make the options length
  -- negative; reject it here rather than depending on how 'isolate' reacts to
  -- a negative size.
  unless (len >= tcpFixedHeaderLength)
         (fail ("Invalid data offset: " ++ show len))
  cont   <- getWord8
  win    <- getWord16be
  cs     <- getWord16be
  urgent <- getWord16be
  let optsLen = len - tcpFixedHeaderLength
  opts   <- label "options" (isolate (optsLen `shiftL` 2) getTcpOptions)
  let hdr = setTcpControl cont emptyTcpHeader
        { tcpSourcePort    = src
        , tcpDestPort      = dst
        , tcpSeqNum        = seqNum
        , tcpAckNum        = ackNum
        , tcpWindow        = win
        , tcpChecksum      = cs
        , tcpUrgentPointer = urgent
        , tcpOptions       = filter (/= OptEndOfOptions) opts
        }
  return (hdr,len * 4)

-- | Render out the @Word8@ that contains the Control field of the TcpHeader.
-- CWR is the most significant bit and FIN the least significant.
putTcpControl :: Putter TcpHeader
putTcpControl c =
  putWord8 $ putBit 7 tcpCwr
           $ putBit 6 tcpEce
           $ putBit 5 tcpUrg
           $ putBit 4 tcpAck
           $ putBit 3 tcpPsh
           $ putBit 2 tcpRst
           $ putBit 1 tcpSyn
           $ putBit 0 tcpFin
             0
  where
  -- set bit n of the accumulated octet when the projected flag is raised
  putBit n prj w | prj c     = setBit w n
                 | otherwise = w

-- | Parse out the control flags from the octet that contains them.
setTcpControl :: Word8 -> TcpHeader -> TcpHeader
setTcpControl w hdr = hdr
  { tcpCwr = flag 7
  , tcpEce = flag 6
  , tcpUrg = flag 5
  , tcpAck = flag 4
  , tcpPsh = flag 3
  , tcpRst = flag 2
  , tcpSyn = flag 1
  , tcpFin = flag 0
  }
  where
  -- state of a single bit of the control octet
  flag = testBit w


-- Tcp Options -----------------------------------------------------------------

-- | Things that carry a collection of TCP options.
class HasTcpOptions a where
  -- | Look up the option with the given tag, if one is present.
  findTcpOption :: TcpOptionTag -> a -> Maybe TcpOption
  -- | Install an option, replacing an existing option with the same tag.
  setTcpOption  :: TcpOption -> a -> a

-- | Install a list of options, applying 'setTcpOption' from the right end of
-- the list towards the left.
setTcpOptions :: HasTcpOptions a => [TcpOption] -> a -> a
setTcpOptions = flip (foldr setTcpOption)

-- | The kind of a TCP option, without its payload.
data TcpOptionTag
  = OptTagEndOfOptions
  | OptTagNoOption
  | OptTagMaxSegmentSize
  | OptTagWindowScaling
  | OptTagSackPermitted
  | OptTagSack
  | OptTagTimestamp
  | OptTagUnknown !Word8
    deriving (Eq,Show)

-- | Parse a single option-kind octet.  Kinds that are not recognised are kept
-- as 'OptTagUnknown'.
getTcpOptionTag :: Get TcpOptionTag
getTcpOptionTag = do
  ty <- getWord8
  return $! toTag ty
  where
  toTag 0  = OptTagEndOfOptions
  toTag 1  = OptTagNoOption
  toTag 2  = OptTagMaxSegmentSize
  toTag 3  = OptTagWindowScaling
  toTag 4  = OptTagSackPermitted
  toTag 5  = OptTagSack
  toTag 8  = OptTagTimestamp
  toTag ty = OptTagUnknown ty

-- | Render an option tag as its option-kind octet.
putTcpOptionTag :: Putter TcpOptionTag
putTcpOptionTag = putWord8 . tagCode
  where
  tagCode OptTagEndOfOptions   = 0
  tagCode OptTagNoOption       = 1
  tagCode OptTagMaxSegmentSize = 2
  tagCode OptTagWindowScaling  = 3
  tagCode OptTagSackPermitted  = 4
  tagCode OptTagSack           = 5
  tagCode OptTagTimestamp      = 8
  tagCode (OptTagUnknown ty)   = ty

-- | A plain list of options: lookup returns the first option with a matching
-- tag; setting replaces the first match in place, or appends when there is
-- none.
instance HasTcpOptions [TcpOption] where
  findTcpOption tag = find ((tag ==) . tcpOptionTag)

  setTcpOption opt = go
    where
    wanted = tcpOptionTag opt

    go []         = [opt]
    go (cur:rest)
      | tcpOptionTag cur == wanted = opt : rest
      | otherwise                  = cur : go rest

-- | A TCP option together with its payload.  'OptUnknown' keeps the kind
-- octet, the length octet, and the raw payload bytes.
data TcpOption
  = OptEndOfOptions
  | OptNoOption
  | OptMaxSegmentSize !Word16
  | OptWindowScaling !Word8
  | OptSackPermitted
  | OptSack [SackBlock]
  | OptTimestamp !Word32 !Word32
  | OptUnknown !Word8 !Word8 !S.ByteString
    deriving (Show,Eq)

-- | One block of a SACK option: a pair of sequence numbers.
data SackBlock = SackBlock
  { sbLeft  :: !TcpSeqNum
  , sbRight :: !TcpSeqNum
  } deriving (Show,Eq)

-- | The tag that corresponds to an option.
tcpOptionTag :: TcpOption -> TcpOptionTag
tcpOptionTag OptEndOfOptions{}   = OptTagEndOfOptions
tcpOptionTag OptNoOption{}       = OptTagNoOption
tcpOptionTag OptMaxSegmentSize{} = OptTagMaxSegmentSize
tcpOptionTag OptSackPermitted{}  = OptTagSackPermitted
tcpOptionTag OptSack{}           = OptTagSack
tcpOptionTag OptWindowScaling{}  = OptTagWindowScaling
tcpOptionTag OptTimestamp{}      = OptTagTimestamp
tcpOptionTag (OptUnknown ty _ _) = OptTagUnknown ty

-- | Get the rendered length of a list of TcpOptions, in 4-byte words, and the
-- number of padding bytes required.  This rounds up to the nearest 4-byte word.
tcpOptionsLength :: [TcpOption] -> (Int,Int)
tcpOptionsLength opts =
  case totalBytes `quotRem` 4 of
    (wholeWords,0)     -> (wholeWords,0)
    (wholeWords,extra) -> (wholeWords + 1,4 - extra)
  where
  totalBytes = F.sum (map tcpOptionLength opts)

-- | The rendered length of a single option, in bytes.  For 'OptUnknown' this
-- is the stored length octet, not the size of the payload.
tcpOptionLength :: TcpOption -> Int
tcpOptionLength OptEndOfOptions{}    = 1
tcpOptionLength OptNoOption{}        = 1
tcpOptionLength OptMaxSegmentSize{}  = 4
tcpOptionLength OptWindowScaling{}   = 3
tcpOptionLength OptSackPermitted{}   = 2
tcpOptionLength (OptSack bs)         = sackLength bs
tcpOptionLength OptTimestamp{}       = 10
tcpOptionLength (OptUnknown _ len _) = fromIntegral len

-- | Render an option: its kind octet, followed by whatever length and payload
-- that kind carries.
putTcpOption :: Putter TcpOption
putTcpOption opt = putTcpOptionTag (tcpOptionTag opt) >> payload opt
  where
  payload OptEndOfOptions         = return ()
  payload OptNoOption             = return ()
  payload (OptMaxSegmentSize mss) = putMaxSegmentSize mss
  payload (OptWindowScaling w)    = putWindowScaling w
  payload OptSackPermitted        = putSackPermitted
  payload (OptSack bs)            = putSack bs
  payload (OptTimestamp v r)      = putTimestamp v r
  payload (OptUnknown _ len bs)   = putUnknown len bs

-- | Parse in known tcp options.
getTcpOptions :: Get [TcpOption]
getTcpOptions  = label "Tcp Options" loop
  where
  -- keep parsing options for as long as input remains
  loop = do
    left <- remaining
    if left > 0 then body else return []

  body = do
    opt <- getTcpOption
    case opt of

      -- everything after an end-of-options marker is discarded, and the
      -- marker itself is not returned
      OptEndOfOptions -> do
        skip =<< remaining
        return []

      _ -> do
        rest <- loop
        return (opt:rest)

-- | Parse a single option: the kind octet selects the payload parser.
getTcpOption :: Get TcpOption
getTcpOption  = do
  tag <- getTcpOptionTag
  case tag of
    OptTagEndOfOptions   -> return OptEndOfOptions
    OptTagNoOption       -> return OptNoOption
    OptTagMaxSegmentSize -> getMaxSegmentSize
    OptTagWindowScaling  -> getWindowScaling
    OptTagSackPermitted  -> getSackPermitted
    OptTagSack           -> getSack
    OptTagTimestamp      -> getTimestamp
    OptTagUnknown ty     -> getUnknown ty

-- | Parse the remainder of a maximum-segment-size option (after its kind
-- octet).  Fails unless the length octet is 4.
getMaxSegmentSize :: Get TcpOption
getMaxSegmentSize  = label "Max Segment Size" $ isolate 3 $ do
  len <- getWord8
  unless (len == 4) (fail ("Unexpected length: " ++ show len))
  OptMaxSegmentSize `fmap` getWord16be

-- | Render the length octet and value of a maximum-segment-size option.
putMaxSegmentSize :: Putter Word16
putMaxSegmentSize w16 = do
  putWord8 4
  putWord16be w16

-- | Parse the remainder of a sack-permitted option (after its kind octet).
-- Fails unless the length octet is 2.
getSackPermitted :: Get TcpOption
getSackPermitted  = label "Sack Permitted" $ isolate 1 $ do
  len <- getWord8
  unless (len == 2) (fail ("Unexpected length: " ++ show len))
  return OptSackPermitted

-- | Render the length octet of a sack-permitted option.
putSackPermitted :: Put
putSackPermitted  = do
  putWord8 2

-- | Parse the remainder of a sack option (after its kind octet).
--
-- The length octet counts the kind and length octets themselves, so it must be
-- at least 2, and what is left over must be a whole number of 8-byte blocks.
-- Anything else is rejected up front, instead of handing 'isolate' a negative
-- size or leaving unparsed bytes inside the isolated region.
getSack :: Get TcpOption
getSack  = label "Sack" $ do
  len <- getWord8
  let edgeLen = fromIntegral len - 2
  unless (edgeLen >= 0 && edgeLen `rem` 8 == 0)
         (fail ("Unexpected length: " ++ show len))
  OptSack `fmap` isolate edgeLen (replicateM (edgeLen `shiftR` 3) getSackBlock)

-- | Render the length octet and blocks of a sack option.
putSack :: Putter [SackBlock]
putSack bs = do
  putWord8 (fromIntegral (sackLength bs))
  mapM_ putSackBlock bs

-- | Parse one sack block: two consecutive sequence numbers.
getSackBlock :: Get SackBlock
getSackBlock  = do
  l <- getTcpSeqNum
  r <- getTcpSeqNum
  return $! SackBlock
    { sbLeft  = l
    , sbRight = r
    }

-- | Render one sack block: its two sequence numbers, left first.
putSackBlock :: Putter SackBlock
putSackBlock sb = do
  putTcpSeqNum (sbLeft sb)
  putTcpSeqNum (sbRight sb)

-- | The rendered length of a sack option in bytes: 8 per block, plus the kind
-- and length octets.
sackLength :: [SackBlock] -> Int
sackLength bs = length bs * 8 + 2

-- | Parse the remainder of a window-scaling option (after its kind octet).
-- Fails unless the length octet is 3.
getWindowScaling :: Get TcpOption
getWindowScaling  = label "Window Scaling" $ isolate 2 $ do
  len <- getWord8
  unless (len == 3) (fail ("Unexpected length: " ++ show len))
  OptWindowScaling `fmap` getWord8

-- | Render the length octet and value of a window-scaling option.
putWindowScaling :: Putter Word8
putWindowScaling w = do
  putWord8 3
  putWord8 w

-- | Parse the remainder of a timestamp option (after its kind octet).
-- Fails unless the length octet is 10.
getTimestamp :: Get TcpOption
getTimestamp  = label "Timestamp" $ isolate 9 $ do
  len <- getWord8
  unless (len == 10) (fail ("Unexpected length: " ++ show len))
  OptTimestamp `fmap` getWord32be `ap` getWord32be

-- | Render the length octet and the two 32-bit fields of a timestamp option.
putTimestamp :: Word32 -> Word32 -> Put
putTimestamp v r = do
  putWord8 10
  putWord32be v
  putWord32be r

-- | Parse the remainder of an option of unrecognised kind (after its kind
-- octet), keeping the length octet and the raw payload.
--
-- The length octet counts the kind and length octets themselves; a value
-- below 2 is rejected, as it would otherwise produce a negative payload size.
getUnknown :: Word8 -> Get TcpOption
getUnknown ty = do
  len  <- getWord8
  unless (len >= 2) (fail ("Unexpected length: " ++ show len))
  body <- isolate (fromIntegral len - 2) (getBytes =<< remaining)
  return (OptUnknown ty len body)

-- | Render the stored length octet and payload of an unrecognised option.
putUnknown :: Word8 -> S.ByteString -> Put
putUnknown len body = do
  putWord8 len
  putByteString body


-- Tcp Packet ------------------------------------------------------------------

-- | Parse a strict bytestring as a TCP header followed by its payload.
{-# INLINE parseTcpPacket #-}
parseTcpPacket :: S.ByteString -> Either String (TcpHeader,S.ByteString)
parseTcpPacket bytes = runGet getTcpPacket bytes

-- | Parse a TcpPacket.  Whatever follows the header, as measured by its data
-- offset, is returned as the body.
getTcpPacket :: Get (TcpHeader,S.ByteString)
getTcpPacket  = do
  pktLen       <- remaining
  (hdr,hdrLen) <- getTcpHeader
  body         <- getBytes (pktLen - hdrLen)
  return (hdr,body)

-- | Render out a TcpPacket, without calculating its checksum.
putTcpPacket :: TcpHeader -> L.ByteString -> Put
putTcpPacket hdr body = do
  putTcpHeader hdr
  putLazyByteString body

-- | Calculate the checksum of a TcpHeader, and its body.
renderWithTcpChecksumIP4 :: IP4 -> IP4 -> TcpHeader -> L.ByteString
                         -> L.ByteString
renderWithTcpChecksumIP4 src dst hdr body = L.append (chunk rendered) body
  where
  -- only the rendered header is needed here; the checksum value is ignored
  rendered = fst (computeTcpChecksumIP4 src dst hdr body)

-- | Render the header with a zeroed checksum field, checksum the IP4
-- pseudo-header, the rendered header and the body, and return the rendered
-- header after 'pokeChecksum' has been applied to it, paired with the
-- checksum.
computeTcpChecksumIP4 :: IP4 -> IP4 -> TcpHeader -> L.ByteString
                      -> (S.ByteString,Word16)
computeTcpChecksumIP4 src dst hdr body = (patched,checksum)
  where
  -- The header is rendered once and shared by the checksum computation and
  -- the poke below.
  rendered = runPut (putTcpHeader hdr { tcpChecksum = 0 })

  -- length of header plus body, as fed to the pseudo-header
  totalLen = S.length rendered + fromIntegral (L.length body)

  pseudoSum =
    computePartialChecksum emptyPartialChecksum
      (mkIP4PseudoHeader src dst tcpProtocol totalLen)

  headerSum = computePartialChecksum pseudoSum rendered

  checksum = finalizeChecksum (computePartialChecksumLazy headerSum body)

  -- this is safe, as the header bytestring that gets modified is modified at
  -- its creation time.  The checksum is forced first, so that it has been
  -- computed from the rendered header before the poke runs.
  patched = checksum `seq` unsafePerformIO (pokeChecksum checksum rendered 16)

-- | Check a received, rendered TCP packet: the checksum over the IP4
-- pseudo-header and the packet bytes (which include the transmitted checksum)
-- must finalize to zero.
validateTcpChecksumIP4 :: IP4 -> IP4 -> S.ByteString -> Bool
validateTcpChecksumIP4 src dst bytes =
  finalizeChecksum (computePartialChecksum pseudoSum bytes) == 0
  where
  pseudoSum =
    computePartialChecksum emptyPartialChecksum
      (mkIP4PseudoHeader src dst tcpProtocol (S.length bytes))