{-# LANGUAGE GeneralizedNewtypeDeriving #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Network.Etherbunny.Ip
-- Copyright   :  (c) Nicholas Burlett 2007
-- License     :  GPL (see the file LICENSE)
--
-- Maintainer  :  nickburlett@mac.com
-- Stability   :  experimental
-- Portability :  ghc
--
-- IP Packet access for Etherbunny.
--
-----------------------------------------------------------------------------

module Network.Etherbunny.Ip (
    -- * Types
    IPPkt,
    IPVerIHL,

    -- * Functions
    getIPPacket,
) where

-- import Network.Etherbunny.Packet (wordsToWord16, wordsToWord32)
import Network.Etherbunny.Tcp

import Data.Word
import Numeric
import Bits
import Network.Socket (HostAddress)
import Data.Binary.Get
import qualified Data.ByteString as B


-- |
--   The combined version/header-length byte of an IP header.  The IP
--   version lives in the upper four bits and the IP header length in the
--   lower four bits of the wrapped 'Word8'.
--
newtype IPVerIHL
    = IPVerIHL Word8
    deriving (Show, Eq, Ord, Enum, Num, Real, Integral, Bits)

-- |
--   Extract the IP version, i.e. the upper four bits of the combined
--   version/header-length byte.
--
ipVersion :: IPVerIHL -> Word8
ipVersion (IPVerIHL octet) = highNibble
  where
    highNibble = shiftR octet 4

-- |
--   Extract the IP header length, i.e. the lower four bits of the combined
--   version/header-length byte.  The value is returned exactly as stored;
--   no unit conversion is applied here.
--
ipHeaderLength :: IPVerIHL -> Word8
ipHeaderLength (IPVerIHL octet) = lowNibble
  where
    lowNibble = 0x0f .&. octet



-- |
--   Represents both the IP flags and the IP fragment offset as a single
--   'Word16'.
--
newtype IPFlagsFragment
    = IPFlagsFragment Word16
    deriving (Show, Eq, Ord, Enum, Num, Real, Integral, Bits)

-- |
--   Get the IP flags (the top three bits of the flags/fragment word)
--
-- ipFlags :: IPFlagsFragment -> Word16
-- ipFlags (IPFlagsFragment v) = v `shiftR` 13

-- |
--   Get the IP fragment offset (the low 13 bits of the flags/fragment word)
--

-- ipFragmentOffset :: IPFlagsFragment -> Word16
-- ipFragmentOffset (IPFlagsFragment v) = v .&. 0x1fff


-- |
--   The protocol number identifying the kind of data carried in an IP
--   packet, stored as a single 'Word8'.
--
newtype IPProtocol
    = IPProtocol Word8
    deriving (Show, Eq, Ord, Enum, Num, Real, Integral, Bits)

-- ipProtocolFromList :: [Word8] -> Word16
-- ipProtocolFromList = wordsToWord16


-- |
--   The Type of Service byte of an IP header, stored as a single 'Word8'.
--
newtype IPTOS
    = IPTOS Word8
    deriving (Show, Eq, Ord, Enum, Num, Real, Integral, Bits)

-- ipTosFromList = wordsToWord32

-- |
--  The IPPayload type is used to store each of the possible payloads
--  that etherbunny knows about
--

-- data IPPayload = Foo

-- |
--   A parsed Internet Protocol packet: the header fields in wire order,
--   followed by the raw option bytes and the decoded payload (if any).
--   Every field is strict.
--
data IPPkt = IPPkt
    { ipVerIHL         :: !IPVerIHL
      -- ^ version and header length
    , ipTOS            :: !IPTOS
      -- ^ type of service
    , ipTotalLength    :: !Word16
      -- ^ total length of the datagram
    , ipIdentification :: !Word16
      -- ^ identification for all fragments of this datagram
    , ipFlagsFragment  :: !IPFlagsFragment
      -- ^ both the flags and the fragment offset
    , ipTTL            :: !Word8
      -- ^ time to live
    , ipProtocol       :: !IPProtocol
      -- ^ protocol
    , ipHeaderChecksum :: !Word16
      -- ^ checksum of the header and options
    , ipSource         :: !HostAddress
      -- ^ source address
    , ipDestination    :: !HostAddress
      -- ^ destination address
    , ipOptions        :: ![Word8]
      -- ^ options
    , ipPayload        :: !(Maybe TCPPkt)
      -- ^ payload; 'Nothing' when no TCP packet was decoded
    }

-- |
--  Show an IP address in dotted-quad form, most significant byte first
--  (for example @0xC0A80001@ is rendered as @192.168.0.1@).  Only the low
--  32 bits of the argument are looked at.
--
--  The 'Num' and 'Show' constraints are spelled out because the body uses
--  a numeric literal and 'shows'; on current versions of base neither is
--  implied by 'Bits', so without them the function does not type-check.
--
showsIP :: (Bits a, Num a, Show a) => a -> String -> String
showsIP addr =
    foldr (\i rest -> shows (octet i) . showString "." . rest)
          (shows (octet 0))
          [3, 2, 1]
    where
        -- the i-th byte of the address, counting from the least significant
        octet i = (addr `shiftR` (i * 8)) .&. 0xff


-- Multi-line, human-readable dump of an IP packet: one labelled line per
-- header field, then the options and payload, ending with a newline.
instance Show IPPkt where
  showsPrec p pkt = foldr (.) (showString "\n") (map line fields)
    where
      -- a label immediately followed by the rendered field value
      line (label, render) = showString label . render

      verIHL = ipVerIHL pkt

      -- fields in the order they are printed
      fields =
        [ ("\n  IP: Ip Version ",               showsPrec p (ipVersion verIHL))
        , ("\n      Header length ",            showsPrec p (ipHeaderLength verIHL))
        , ("\n      TOS: ",                     showsPrec p (ipTOS pkt))
        , ("\n      totalLength: ",             showsPrec p (ipTotalLength pkt))
        , ("\n      Frag Ident: ",              showsPrec p (ipIdentification pkt))
        , ("\n      flags/fragment offset: ",   showsPrec p (ipFlagsFragment pkt))
        , ("\n      TTL: ",                     showsPrec p (ipTTL pkt))
        , ("\n      Protocol: ",                showsPrec p (ipProtocol pkt))
        , ("\n      Header Checksum: ",         showHex (ipHeaderChecksum pkt))
        , ("\n      Source: ",                  showsIP (ipSource pkt))
        , ("\n      Destination: ",             showsIP (ipDestination pkt))
        , ("\n      Options: ",                 showsPrec p (ipOptions pkt))
        , ("\n      Payload: ",                 showsPrec p (ipPayload pkt))
        ]


-- |
--   Parse an IP packet: the fixed 20-byte header, the options that follow
--   it, and the payload.
--
--   The header length field counts 32-bit words, so the options occupy
--   @(IHL - 5) * 4@ bytes and the payload occupies
--   @total length - IHL * 4@ bytes.  All length arithmetic is done in 'Int'
--   so that a malformed header cannot wrap around in 'Word8' / 'Word16'.
--
--   A payload with protocol number 6 (TCP) is decoded with 'getTCPPacket',
--   which is given the payload length and both addresses.  Any other
--   payload is skipped and reported as 'Nothing'.
--
--   The parser fails when the header length is below the 5-word minimum
--   or when the total length is smaller than the header length.
--
getIPPacket :: Get IPPkt
getIPPacket = do
    verihl  <- getWord8
    iptos   <- getWord8
    tlength <- getWord16be
    ident   <- getWord16be
    flgfrag <- getWord16be
    ttl     <- getWord8
    ipprot  <- getWord8
    hdrcksm <- getWord16be
    srcip   <- getWord32be
    dstip   <- getWord32be
    let hdrWords   = fromIntegral (ipHeaderLength (IPVerIHL verihl)) :: Int
        hdrBytes   = hdrWords * 4
        totalBytes = fromIntegral tlength :: Int
    -- Reject headers whose lengths are inconsistent instead of letting the
    -- subtractions below produce a bogus byte count.
    if hdrWords < 5
        then fail ("getIPPacket: IP header length " ++ show hdrWords
                   ++ " is below the minimum of 5 words")
        else return ()
    if totalBytes < hdrBytes
        then fail ("getIPPacket: IP total length " ++ show totalBytes
                   ++ " is smaller than the header length " ++ show hdrBytes)
        else return ()
    -- 20 bytes of fixed header have been consumed; the rest of the header
    -- is options.
    options <- getByteString (hdrBytes - 20)
    let payloadBytes = totalBytes - hdrBytes
    payload <- case ipprot of
        6 -> do
            tcp <- getTCPPacket (fromIntegral payloadBytes) srcip dstip
            return $ Just tcp
        _ -> do
            skip payloadBytes
            return Nothing
    return $ IPPkt
        (IPVerIHL verihl)
        (IPTOS iptos)
        tlength
        ident
        (IPFlagsFragment flgfrag)
        ttl
        (IPProtocol ipprot)
        hdrcksm
        srcip
        dstip
        (B.unpack options)
        payload