{-# LANGUAGE GeneralizedNewtypeDeriving #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Network.Etherbunny.Tcp
-- Copyright   :  (c) Nicholas Burlett 2007
-- License     :  GPL (see the file LICENSE)
--
-- Maintainer  :  nickburlett@mac.com
-- Stability   :  experimental
-- Portability :  ghc
--
-- TCP Packet access for Etherbunny.
--
-----------------------------------------------------------------------------

module Network.Etherbunny.Tcp (
    -- * Types
    TCPPkt,

    -- * Functions
    getTCPPacket,
) where

-- import Network.Etherbunny.Packet

import Data.Word
import Numeric
import Bits
-- import Network (PortNumber)
import Data.Binary.Get
import qualified Data.ByteString as B


-- |
--   The single 16-bit word of the TCP header that packs the Data Offset,
--   the ECN bits and the Control Bits together.  It is kept as the raw
--   word; individual fields are obtained by shifting and masking.
--
newtype TCPDOECNCB = TCPDOECNCB Word16
    deriving ( Show, Eq, Ord
             , Enum, Num, Real, Integral
             , Bits
             )

-- |
--   The Data Offset: the top four bits of the word, right-aligned.
--   This is the TCP header length counted in 32-bit words.
--
tcpDataOffset :: TCPDOECNCB -> Word16
tcpDataOffset field = case field of
    TCPDOECNCB raw -> shiftR raw 12

-- |
--   The ECN bits, right-aligned: bits 6-8 of the word, i.e. ECE (bit 6),
--   CWR (bit 7) and NS (bit 8), counting from the least significant bit.
--
--   The mask is three bits wide so that the reserved bits above NS
--   (bits 9-11) are not reported as part of the ECN state.
--
tcpECN :: TCPDOECNCB -> Word16
tcpECN (TCPDOECNCB v) = (v `shiftR` 6) .&. 0x07


-- |
--   The six Control Bits in the low end of the word:
--   URG (0x20), ACK (0x10), PSH (0x08), RST (0x04), SYN (0x02), FIN (0x01).
--
--   All six bits must be kept; a narrower mask would silently hide
--   flags such as ACK.
--
tcpControlBits :: TCPDOECNCB -> Word16
tcpControlBits (TCPDOECNCB v) = v .&. 0x3f


-- |
--   Compute the 16-bit ones'-complement sum used by the TCP checksum,
--   starting from an initial partial sum.
--
--   The bytes are read as big-endian 16-bit words.  A trailing odd byte is
--   treated as the high byte of a word whose low byte is zero.  The carry
--   out of bit 15 is added back in after every addition (end-around
--   carry), so the result is exact for input of any length.
--
--   The result is the ones'-complement sum itself, not its complement:
--   summing a segment whose checksum field is correct, starting from its
--   pseudo-header sum, yields 0xFFFF.
--
checksum :: Word16        -- ^ initial partial sum (e.g. a pseudo-header sum)
         -> B.ByteString  -- ^ bytes to add to the sum
         -> Word16
checksum c b = fromIntegral (go (fromIntegral c) 0)
  where
    n = B.length b

    byteAt :: Int -> Word32
    byteAt i = fromIntegral (B.index b i)

    -- Ones'-complement addition.  For inputs <= 0xFFFF the result is
    -- again <= 0xFFFF, so the final narrowing to Word16 loses nothing.
    addOC :: Word32 -> Word32 -> Word32
    addOC x y = let s = x + y in (s .&. 0xFFFF) + (s `shiftR` 16)

    -- Walk the words left to right; `seq` keeps the accumulator
    -- evaluated so no chain of thunks builds up on long inputs.
    go :: Word32 -> Int -> Word32
    go acc i
      | i + 1 < n =
          let acc' = acc `addOC` ((byteAt i `shiftL` 8) .|. byteAt (i + 1))
          in acc' `seq` go acc' (i + 2)
      | i < n     = acc `addOC` (byteAt i `shiftL` 8)
      | otherwise = acc


-- |
--   Dummy placeholder for the TCP payload type.  No payload decoding is
--   performed here; the alias only gives the payload field of 'TCPPkt'
--   a type.  NOTE(review): replace once payload decoding exists.
--
type TCPPayload = Int

-- |
--   A decoded TCP segment header, together with the result of verifying
--   its checksum.
--
data TCPPkt = TCPPkt
    { tcpSourcePort      :: !Word16             -- ^ source port
    , tcpDestinationPort :: !Word16             -- ^ destination port
    , tcpSequenceNumber  :: !Word32             -- ^ sequence number
    , tcpAcknowledgement :: !Word32             -- ^ acknowledgement number
    , tcpDOECNCB         :: !TCPDOECNCB         -- ^ packed Data Offset, ECN and Control Bits word
    , tcpWindow          :: !Word16             -- ^ window field
    , tcpChecksum        :: !Word16             -- ^ checksum field as it appears in the packet
    , tcpChecksumCorrect :: !Bool               -- ^ True if the checksum verified
    , tcpUrgentPointer   :: !Word16             -- ^ urgent pointer field
    , tcpOptions         :: ![Word8]            -- ^ raw option bytes (empty when there are none)
    , tcpPayload         :: !(Maybe TCPPayload) -- ^ decoded payload, if any
    }


-- Multi-line, human-readable dump of a TCP header: one labelled line per
-- field, with the checksum printed in hexadecimal.
instance Show TCPPkt where
  showsPrec p pkt = foldr (.) id
      [ field "\n  TCP: Source Port "        (tcpSourcePort pkt)
      , field "\n       Destination Port "   (tcpDestinationPort pkt)
      , field "\n       Sequence Number: "   (tcpSequenceNumber pkt)
      , field "\n       Ack Number: "        (tcpAcknowledgement pkt)
      , field "\n       Data Offset: "       (tcpDataOffset flags)
      , field "\n       ECN: "               (tcpECN flags)
      , field "\n       Control Bits: "      (tcpControlBits flags)
      , field "\n       Window: "            (tcpWindow pkt)
      , showString "\n       Checksum: " . showHex (tcpChecksum pkt)
      , field " correct? "                   (tcpChecksumCorrect pkt)
      , field "\n       Urgent Pointer: "    (tcpUrgentPointer pkt)
      , field "\n       Options: "           (tcpOptions pkt)
      , field "\n       Payload: "           (tcpPayload pkt)
      ]
    where
      flags = tcpDOECNCB pkt

      -- A label followed by a value shown at the caller's precedence.
      field :: Show a => String -> a -> ShowS
      field label x = showString label . showsPrec p x


-- |
--   Parse a TCP segment header and verify its checksum.
--
--   * @len@   - length in bytes of the whole TCP segment (header plus
--               data); presumably taken from the enclosing IP header.
--   * @srcip@ - IPv4 source address.
--   * @dstip@ - IPv4 destination address.
--
--   The two addresses and @len@ form the pseudo-header that the TCP
--   checksum also covers.  Exactly @len@ bytes are consumed: the header
--   and options are decoded, the rest of the segment is skipped, and the
--   payload field is left as 'Nothing'.  The parser fails if the Data
--   Offset is below the 5-word minimum header or describes a header
--   longer than @len@.
--
getTCPPacket :: Int -> Word32 -> Word32 -> Get TCPPkt
getTCPPacket len srcip dstip = do
    -- Checksum exactly the segment.  Bytes that follow it in the input
    -- (e.g. link-layer padding or trailers) are not part of it.
    segment <- lookAhead $ getByteString len
    let computedChecksum = checksum pseudoSum segment
    sp      <- getWord16be
    dp      <- getWord16be
    seqnum  <- getWord32be
    ack     <- getWord32be
    doecncb <- getWord16be
    window  <- getWord16be
    cksum   <- getWord16be
    urg     <- getWord16be
    let doffset   = fromIntegral (tcpDataOffset (TCPDOECNCB doecncb)) :: Int
        headerLen = doffset * 4
    if doffset < 5 || headerLen > len
        then fail $ "getTCPPacket: invalid data offset " ++ show doffset
                    ++ " for a segment of " ++ show len ++ " bytes"
        else do
            -- Options fill the header beyond the fixed 20 bytes.
            opt <- getByteString (headerLen - 20)
            skip (len - headerLen)
            return $ TCPPkt
                sp
                dp
                seqnum
                ack
                (TCPDOECNCB doecncb)
                window
                cksum
                -- A segment with a correct checksum sums to all ones.
                (computedChecksum == 0xffff)
                urg
                (B.unpack opt)
                Nothing
  where
    -- Ones'-complement sum of the IPv4 pseudo-header as 16-bit words:
    -- both address halves, zero/protocol (6 = TCP) and the TCP length.
    -- Each address is split into its halves before adding, so the
    -- Word32 accumulator cannot overflow and drop a carry.
    pseudoSum :: Word16
    pseudoSum = fromIntegral $ fold16 $
        halves srcip + halves dstip + 6 + fromIntegral len

    halves :: Word32 -> Word32
    halves w = (w `shiftR` 16) + (w .&. 0xFFFF)

    -- Add carries back in until the value fits in 16 bits.
    fold16 :: Word32 -> Word32
    fold16 x
      | x > 0xFFFF = fold16 ((x .&. 0xFFFF) + (x `shiftR` 16))
      | otherwise  = x