module Hans.Layer.IP4.Fragmentation where import Hans.Address import Hans.Address.IP4 import Hans.Message.Ip4 import Hans.Utils (chunk) import Data.Ord (comparing) import Data.Time.Clock.POSIX (POSIXTime) import qualified Data.ByteString.Lazy as L import qualified Data.Map.Strict as Map import qualified Data.ByteString as S type FragmentationTable addr = Map.Map (Ident,addr,addr) Fragments emptyFragmentationTable :: FragmentationTable IP4 emptyFragmentationTable = Map.empty data Fragments = Fragments { startTime :: !POSIXTime , totalSize :: {-# UNPACK #-} !Int , fragments :: ![Fragment] } deriving Show data Fragment = Fragment { fragmentOffset :: {-# UNPACK #-} !Int , fragmentLength :: {-# UNPACK #-} !Int , fragmentPayload :: !L.ByteString } deriving (Eq,Show) instance Ord Fragment where compare = comparing fragmentOffset -- | The end of a fragment. fragmentEnd :: Fragment -> Int fragmentEnd f = fragmentOffset f + fragmentLength f -- | Check the ordering of two fragments. comesBefore :: Fragment -> Fragment -> Bool comesBefore f g = fragmentEnd f == fragmentOffset g -- | Check the ordering of two fragments. comesAfter :: Fragment -> Fragment -> Bool comesAfter = flip comesBefore -- | Merge two fragments. -- -- Note: This doesn't do a validity check to make sure that they're actually -- adjacent. combineFragments :: Fragment -> Fragment -> Fragment combineFragments f g = Fragment (fragmentOffset f) len pay where len = fragmentLength f + fragmentLength g pay = fragmentPayload f `L.append` fragmentPayload g -- | Given a group of fragments, a new fragment, and a possible total size, -- create a new group of fragments that incorporates the new fragment. expandGroup :: Fragments -> Fragment -> Int -> Fragments expandGroup fs newfrag x = case totalSize fs of -1 | x >= 0 -> expandGroup fs{ totalSize = x } newfrag x _ -> fs { fragments = addFragment newfrag (fragments fs) } -- | Add a fragment to a list of fragments, in a position that is relative to -- its offset and length. addFragment :: Fragment -> [Fragment] -> [Fragment] addFragment f fs = case fs of [] -> [f] g:rest | f `comesBefore` g -> addFragment (combineFragments f g) rest | f `comesAfter` g -> addFragment (combineFragments g f) rest | f < g -> f:fs | otherwise -> g:(addFragment f rest) -- | Process a packet fragment through the system, potentially returning a -- fully-processed packet if this fragment completes an existing packet or -- is itself a fully-complete packet. processFragment :: Address addr => POSIXTime -> FragmentationTable addr -> Bool -> Int -> addr -> addr -> Ident -> S.ByteString -> (FragmentationTable addr, Maybe L.ByteString) processFragment _ table False 0 _ _ _ bs = (table, Just (chunk bs)) processFragment now table areMore off src dest ident bs = case group of Fragments _ x [Fragment 0 y bs'] | x == y -> (Map.delete entry table, Just bs') _ -> (Map.insert entry group table, Nothing) where entry = (ident,src,dest) group = case Map.lookup (ident,src,dest) table of Nothing -> Fragments now newTotalLen [cur] Just g -> expandGroup g cur newTotalLen curlen = S.length bs cur = Fragment off curlen (chunk bs) newTotalLen | areMore = -1 | otherwise = off + curlen processIP4Packet :: POSIXTime -> FragmentationTable IP4 -> IP4Header -> S.ByteString -> (FragmentationTable IP4, Maybe L.ByteString) processIP4Packet now table hdr bs = processFragment now table areMore off src dest ident bs where off = fromIntegral (ip4FragmentOffset hdr) ident = fromIntegral (ip4Ident hdr) areMore = ip4MoreFragments hdr src = ip4SourceAddr hdr dest = ip4DestAddr hdr