{-# OPTIONS_GHC -fglasgow-exts #-}

{--
Copyright (c) 2006, Peng Li
              2006, Stephan A. Zdancewic
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:

* Redistributions of source code must retain the above copyright
  notice, this list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright
  notice, this list of conditions and the following disclaimer in the
  documentation and/or other materials provided with the distribution.

* Neither the name of the copyright owners nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--}

module Network.TCP.Type.Base where

import Foreign
import Foreign.C
import System.Time 
import System.IO.Unsafe
import Control.Exception

-- Numeric cast shorthands. Each one is 'fromIntegral' with its result type
-- pinned, so call sites read like explicit casts. Narrowing casts wrap modulo
-- the width of the target type, exactly as 'fromIntegral' does.
to_Int :: Integral a => a -> Int
to_Int = fromIntegral

to_Int8 :: Integral a => a -> Int8
to_Int8 = fromIntegral

to_Int16 :: Integral a => a -> Int16
to_Int16 = fromIntegral

to_Int32 :: Integral a => a -> Int32
to_Int32 = fromIntegral

to_Int64 :: Integral a => a -> Int64
to_Int64 = fromIntegral

to_Word :: Integral a => a -> Word
to_Word = fromIntegral

to_Word8 :: Integral a => a -> Word8
to_Word8 = fromIntegral

to_Word16 :: Integral a => a -> Word16
to_Word16 = fromIntegral

to_Word32 :: Integral a => a -> Word32
to_Word32 = fromIntegral

to_Word64 :: Integral a => a -> Word64
to_Word64 = fromIntegral

{-# INLINE to_Int    #-}
{-# INLINE to_Int8   #-}
{-# INLINE to_Int16  #-}
{-# INLINE to_Int32  #-}
{-# INLINE to_Int64  #-}
{-# INLINE to_Word   #-}
{-# INLINE to_Word8  #-}
{-# INLINE to_Word16 #-}
{-# INLINE to_Word32 #-}
{-# INLINE to_Word64 #-}

-- Port numbers, IP addresses

-- | A port number.
type Port = Word16

-- | An IPv4 address packed into a 'Word32'. The 'Show' instance prints the
-- least-significant byte as the first dotted octet.
-- NOTE(review): presumably the value is stored in network byte order and
-- read on a little-endian host; confirm against the code that builds these.
newtype IPAddr = IPAddr Word32
  deriving (Eq, Ord)

-- | An (address, port) endpoint.
newtype TCPAddr = TCPAddr (IPAddr, Port)
  deriving (Eq, Ord)

-- | A port paired with an endpoint.
newtype SocketID = SocketID (Port, TCPAddr)
  deriving (Eq, Ord, Show)

instance Show IPAddr where
  show (IPAddr w) =
    foldr1 (\a b -> a ++ "." ++ b)
           [show ((w `shiftR` k) .&. 255) | k <- [0, 8, 16, 24]]

instance Show TCPAddr where
  show (TCPAddr (ip, pt)) = concat [show ip, ":", show pt]


-- | Address component of an endpoint.
get_ip :: TCPAddr -> IPAddr
get_ip (TCPAddr (addr, _)) = addr

-- | Port component of an endpoint.
get_port :: TCPAddr -> Port
get_port (TCPAddr (_, port)) = port

-- | Endpoint component of a socket id.
get_remote_addr :: SocketID -> TCPAddr
get_remote_addr (SocketID (_, remote)) = remote

-- | Port component of a socket id.
get_local_port :: SocketID -> Port
get_local_port (SocketID (local, _)) = local

{-# INLINE get_ip #-}
{-# INLINE get_port #-}
{-# INLINE get_remote_addr #-}
{-# INLINE get_local_port #-}

-- TCP Sequence numbers

-- | 32-bit sequence numbers with wrap-around arithmetic. For the 'Word32'
-- instance, ordering is decided by the sign of the 32-bit difference, so
-- values are compared relative to one another rather than by magnitude
-- (e.g. @0xFFFFFFFF `seq_lt` 0@ holds).
class (Eq a) => Seq32 a where
  -- underlying 32-bit value
  seq_val :: a -> Word32
  -- wrap-around comparisons
  seq_lt  :: a -> a -> Bool
  seq_leq :: a -> a -> Bool
  seq_gt  :: a -> a -> Bool
  seq_geq :: a -> a -> Bool
  -- advance / retreat by an amount asserted to be non-negative (mod 2^32)
  seq_plus  :: (Integral n) => a -> n -> a
  seq_minus :: (Integral n) => a -> n -> a
  -- @seq_diff s t@: signed 32-bit distance @s - t@, asserted non-negative
  seq_diff  :: (Integral n) => a -> a -> n

instance Seq32 Word32 where
  seq_val = id
  seq_lt  a b = (fromIntegral (a - b) :: Int32) <  0
  seq_leq a b = (fromIntegral (a - b) :: Int32) <= 0
  seq_gt  a b = (fromIntegral (a - b) :: Int32) >  0
  seq_geq a b = (fromIntegral (a - b) :: Int32) >= 0
  seq_plus  s n = assert (n >= 0) (s + fromIntegral n)
  seq_minus s n = assert (n >= 0) (s - fromIntegral n)
  seq_diff  s t =
    let d = fromIntegral (fromIntegral (s - t) :: Int32)
    in assert (d >= 0) d
  {-# INLINE seq_val #-}
  {-# INLINE seq_lt  #-}
  {-# INLINE seq_leq #-}
  {-# INLINE seq_gt  #-}
  {-# INLINE seq_geq #-}
  {-# INLINE seq_plus #-}
  {-# INLINE seq_minus  #-}
  {-# INLINE seq_diff #-}

-- Distinct wrappers around a 32-bit sequence value, so that values of
-- different roles cannot be mixed up. Seq32 is derived from the Word32
-- instance through the newtype.
-- NOTE(review): "Local"/"Foreign" presumably mean this endpoint's vs the
-- peer's sequence space; confirm with the users of these types.
newtype SeqLocal = SeqLocal Word32
  deriving (Eq, Show, Seq32)

newtype SeqForeign = SeqForeign Word32
  deriving (Eq, Show, Seq32)

newtype Timestamp = Timestamp Word32
  deriving (Eq, Show, Seq32)

-- Ord is the wrap-around ordering from Seq32. It is not transitive across
-- the whole 32-bit ring (0 < 2^30 < 2^31+1 < 0 all hold), so these values
-- are only safely comparable within half the ring of each other.
-- NOTE(review): avoid using them as keys in ordered containers.
instance Ord SeqLocal where
  a <  b = seq_lt  a b
  a >  b = seq_gt  a b
  a <= b = seq_leq a b
  a >= b = seq_geq a b
  {-# INLINE (<) #-}
  {-# INLINE (>) #-}
  {-# INLINE (<=) #-}
  {-# INLINE (>=) #-}

instance Ord SeqForeign where
  a <  b = seq_lt  a b
  a >  b = seq_gt  a b
  a <= b = seq_leq a b
  a >= b = seq_geq a b
  {-# INLINE (<) #-}
  {-# INLINE (>) #-}
  {-# INLINE (<=) #-}
  {-# INLINE (>=) #-}

instance Ord Timestamp where
  a <  b = seq_lt  a b
  a >  b = seq_gt  a b
  a <= b = seq_leq a b
  a >= b = seq_geq a b
  {-# INLINE (<) #-}
  {-# INLINE (>) #-}
  {-# INLINE (<=) #-}
  {-# INLINE (>=) #-}

-- | Re-tag a 'SeqLocal' value as a 'SeqForeign' without changing its bits.
seq_flip_ltof :: SeqLocal -> SeqForeign
seq_flip_ltof (SeqLocal n) = SeqForeign n

-- | Re-tag a 'SeqForeign' value as a 'SeqLocal' without changing its bits.
seq_flip_ftol :: SeqForeign -> SeqLocal
seq_flip_ftol (SeqForeign n) = SeqLocal n

{-# INLINE seq_flip_ltof  #-}
{-# INLINE seq_flip_ftol  #-}


-- | Time values in microseconds (see 'seconds_to_time').
type Time = Int64

-- | Convert a duration in seconds to a 'Time', i.e. to microseconds.
--
-- The scaling is done in 'Double'. A 'Float' product has only 24 bits of
-- mantissa, so above roughly 16.7 s (2^24 microseconds) it was already
-- rounded to a coarse multiple before 'round' saw it. For example,
-- 3600.5 s came out as 3600499968 us.
seconds_to_time :: Float -> Time
seconds_to_time f = round (realToFrac f * 1000000 :: Double)

{-# INLINE seconds_to_time  #-}

-- get_current_time :: IO Time
-- get_current_time =
--  do (TOD a b) <- getClockTime
--     return $ to_Int64 (a*1000000 + (b `div` 1000000))
-- 
---------------------------------------------------------------------------

-- | A view onto a C byte array held by a 'ForeignPtr'. The live data is the
-- buf_len bytes starting buf_offset bytes into an allocation of buf_size
-- bytes. Several Buffers may share one pointer with different windows
-- (buffer_split produces such views).
data Buffer =
  Buffer
  { buf_ptr    :: {-# UNPACK #-} !(ForeignPtr CChar)  -- ^ pointer to the underlying allocation
  , buf_size   :: {-# UNPACK #-} !Int                 -- ^ size of the allocation, in bytes
  , buf_offset :: {-# UNPACK #-} !Int                 -- ^ start of the data within the allocation
  , buf_len    :: {-# UNPACK #-} !Int                 -- ^ length of the data, in bytes
  }

instance Show Buffer where
  show b = "Buffer:" ++ show (buf_ptr b, buf_size b, buf_offset b, buf_len b)

-- | Sanity check used in assertions: the data window must be well-formed and
-- lie inside the allocation.
--
-- A negative length is rejected explicitly, because the offset and end
-- checks alone would accept one (@off + len <= size@ trivially holds for
-- @len < 0@).
buffer_ok :: Buffer -> Bool
buffer_ok (Buffer _ size off len) =
  off >= 0 && len >= 0 && off + len <= size

-- | Allocate an n-byte buffer whose data window covers the whole allocation.
-- The memory comes from 'mallocArray' and is released by the 'finalizerFree'
-- finalizer once the 'ForeignPtr' becomes unreachable.
--
-- A request for 0 bytes allocates one byte and returns it with an empty
-- window (offset 0, length 0, size 1).
-- NOTE(review): presumably this avoids a zero-size malloc; confirm.
new_buffer :: Int -> IO Buffer
new_buffer n
  | n == 0 = do
      one <- new_buffer 1
      return one { buf_offset = 0, buf_len = 0 }
  | otherwise = do
      raw   <- mallocArray n
      owned <- newForeignPtr finalizerFree raw
      return (Buffer owned n 0 n)

-- | A shared zero-length buffer (built with @new_buffer 0@).
--
-- The allocation is performed with 'unsafePerformIO'. It is only used to
-- obtain a valid pointer for an empty window that nothing reads or writes.
-- NOINLINE keeps GHC from inlining the CAF and allocating a fresh buffer at
-- every use site.
buffer_empty :: Buffer
buffer_empty = unsafePerformIO (new_buffer 0)
{-# NOINLINE buffer_empty #-}


-- | Read a buffer's data window as a 'String', one 'Char' per byte
-- (via 'castCCharToChar').
buffer_to_string :: Buffer -> IO String
buffer_to_string buf =
  assert (buffer_ok buf) $
    withForeignPtr (buf_ptr buf) $ \base ->
      fmap (map castCCharToChar)
           (peekArray (buf_len buf) (base `plusPtr` buf_offset buf))
-- | Copy a 'String' into a freshly allocated buffer, one byte per 'Char'
-- (via 'castCharToCChar', which keeps only the low 8 bits of each Char).
string_to_buffer :: String -> IO Buffer
string_to_buffer s = do
  buf <- new_buffer (length s)
  withForeignPtr (buf_ptr buf) $ \dst ->
    pokeArray dst (map castCharToCChar s)
  return buf

-- | Split a buffer's window after x bytes. x is clamped to the window length
-- (and to 0 if negative). Both halves share the original pointer; no data is
-- copied.
buffer_split :: Int -> Buffer -> (Buffer,Buffer)
buffer_split x (Buffer fptr size off len) =
  (Buffer fptr size off cut, Buffer fptr size (off + cut) (len - cut))
  where
    cut | x > len   = len
        | x < 0     = 0
        | otherwise = x

-- | First x bytes of the window (see 'buffer_split').
buffer_take :: Int -> Buffer -> Buffer
buffer_take x = fst . buffer_split x

-- | Window with the first x bytes removed (see 'buffer_split').
buffer_drop :: Int -> Buffer -> Buffer
buffer_drop x = snd . buffer_split x

-- | Join two adjacent buffers where possible. If b2's window starts exactly
-- where b1's ends in the same allocation, the result is a single widened
-- view. Otherwise empty buffers are dropped, and the rest are returned in
-- order.
buffer_merge :: Buffer -> Buffer -> [Buffer]
buffer_merge b1 b2
  | adjacent        = [b1 { buf_len = buf_len b1 + buf_len b2 }]
  | buf_len b1 == 0 = [b2]
  | buf_len b2 == 0 = [b1]
  | otherwise       = [b1, b2]
  where
    adjacent = buf_ptr b1 == buf_ptr b2
            && buf_offset b1 + buf_len b1 == buf_offset b2
  
-- | A sequence of buffer views treated as one logical byte string, together
-- with its total length (expected to equal the sum of the views' lengths;
-- see bufferchain_ok).
data BufferChain = BufferChain
    { bufc_list   :: ![Buffer]
    , bufc_length :: {-# UNPACK #-} !Int
    }

instance Show BufferChain where
  show bc = concat ["BufferChain: ", show (bufc_list bc), " len=", show (bufc_length bc)]


-- | The chain with no buffers and length 0.
bufferchain_empty :: BufferChain
bufferchain_empty = BufferChain [] 0

-- | A chain holding just one buffer. An empty buffer gives the empty chain.
bufferchain_singleton :: Buffer -> BufferChain
bufferchain_singleton b
  | n == 0    = bufferchain_empty
  | otherwise = BufferChain [b] n
  where n = buf_len b

-- | Prepend a buffer to a chain. The new buffer is merged with the chain's
-- first buffer when they are adjacent (see 'buffer_merge'). Empty buffers
-- leave the chain unchanged.
bufferchain_add :: Buffer -> BufferChain -> BufferChain
bufferchain_add buf bc
  | buf_len buf == 0 = bc
  | otherwise        = assert (bufferchain_ok bc) (prepend bc)
  where
    prepend (BufferChain chain total) =
      BufferChain (merged chain) (total + buf_len buf)
    merged []             = [buf]
    merged (first : rest) = buffer_merge buf first ++ rest

-- | A one-byte view of the byte at position index (0-based) in the chain.
--
-- The view's offset is the containing buffer's own buf_offset plus the
-- position within that buffer's window. Previously the offset was taken to
-- be the position alone, which addressed the wrong byte whenever the buffer
-- did not start at offset 0 (e.g. after a split).
--
-- A negative index, or one past the end of the buffers, raises an explicit
-- error instead of an opaque pattern-match failure.
bufferchain_get :: BufferChain -> Int -> Buffer
bufferchain_get bc@(BufferChain chain _) index
  | index < 0 = outOfRange
  | otherwise = assert (bufferchain_ok bc) $ pick chain index
  where
    outOfRange = error ("bufferchain_get: index out of range: " ++ show index)
    pick [] _ = outOfRange
    pick (Buffer fptr size off len : rest) i
      | i < len   = Buffer fptr size (off + i) 1
      | otherwise = pick rest (i - len)

-- | Append a buffer to the end of a chain. It is merged with the chain's last
-- buffer when they are adjacent (see 'buffer_merge'). Empty buffers leave the
-- chain unchanged.
bufferchain_append :: BufferChain -> Buffer -> BufferChain
bufferchain_append bc buf
  | buf_len buf == 0 = bc
  | otherwise = case bc of
      BufferChain [] total ->
        BufferChain [buf] (total + buf_len buf)
      BufferChain chain total ->
        BufferChain (init chain ++ buffer_merge (last chain) buf)
                    (total + buf_len buf)
  
-- | Concatenate two chains. The last buffer of the first chain is merged with
-- the first buffer of the second when they are adjacent. A chain with no
-- buffers acts as the identity.
bufferchain_concat :: BufferChain -> BufferChain -> BufferChain
bufferchain_concat front back =
  case (bufc_list front, bufc_list back) of
    (_, [])      -> front
    ([], _)      -> back
    (xs, y : ys) ->
      BufferChain (init xs ++ buffer_merge (last xs) y ++ ys)
                  (bufc_length front + bufc_length back)

-- | First buffer of the chain. Partial: fails on a chain with no buffers.
bufferchain_head :: BufferChain -> Buffer
bufferchain_head = head . bufc_list

-- | The chain without its first buffer, with the length reduced to match.
-- Partial: fails on a chain with no buffers.
bufferchain_tail :: BufferChain -> BufferChain
bufferchain_tail (BufferChain chain total) =
  BufferChain (tail chain) (total - buf_len (head chain))


-- | The first x bytes of the chain (see 'bufferchain_split_at').
bufferchain_take :: Int -> BufferChain -> BufferChain
bufferchain_take x = fst . bufferchain_split_at x

-- | The chain with its first x bytes removed (see 'bufferchain_split_at').
bufferchain_drop :: Int -> BufferChain -> BufferChain
bufferchain_drop x = snd . bufferchain_split_at x

-- | Split a chain after z bytes, with z clamped to [0, length]. A buffer that
-- straddles the cut is divided into two views of the same allocation; no
-- bytes are copied. Errors if the buffers hold fewer bytes than the chain's
-- recorded length.
bufferchain_split_at :: Int -> BufferChain -> (BufferChain,BufferChain)
bufferchain_split_at z (BufferChain chain total) =
  (BufferChain front cut, BufferChain back (total - cut))
  where
    cut | z > total = total
        | z < 0     = 0
        | otherwise = z
    (front, back) = carve cut chain
    -- carve n bs: detach exactly n bytes from the front of bs
    carve 0 rest = ([], rest)
    carve n []   = error $ "buf_split, x=" ++ show n
    carve n (b@(Buffer p sz o l) : rest)
      | n < l     = ([Buffer p sz o n], Buffer p sz (o + n) (l - n) : rest)
      | n == l    = ([b], rest)
      | otherwise = let (more, rest') = carve (n - l) rest
                    in (b : more, rest')

-- | Produce a single buffer holding the chain's bytes. An empty chain yields a
-- new empty buffer. A one-buffer chain returns that buffer as-is, with no
-- copy. Otherwise a new buffer of the chain's length is allocated and filled.
bufferchain_collapse :: BufferChain -> IO Buffer
bufferchain_collapse bc = case bc of
  BufferChain [] 0     -> new_buffer 0
  BufferChain [only] _ -> return only
  BufferChain _ total  -> do
    out <- new_buffer total
    withForeignPtr (buf_ptr out) (bufferchain_output bc)
    return out

-- | Copy every buffer's data window, in order, into consecutive bytes starting
-- at dest. The number of bytes written is the sum of the buffers' lengths.
-- The caller must provide at least that much space.
bufferchain_output :: BufferChain -> Ptr CChar -> IO ()
bufferchain_output (BufferChain chain _) dest =
  mapM_ copyOne (zip starts chain)
  where
    -- destination offset at which each buffer's bytes begin
    starts = scanl (+) 0 (map buf_len chain)
    copyOne (at, b) =
      withForeignPtr (buf_ptr b) $ \src ->
        copyArray (dest `plusPtr` at :: Ptr CChar)
                  (src `plusPtr` buf_offset b)
                  (buf_len b)

-- | Consistency check used in assertions. The recorded length must equal the
-- sum of the buffers' lengths, and every buffer must satisfy 'buffer_ok'.
-- Uses 'sum' and 'all' rather than lazy 'foldl' accumulations, which build
-- thunks over the whole list.
bufferchain_ok :: BufferChain -> Bool
bufferchain_ok (BufferChain chain size) =
  size == sum (map buf_len chain) && all buffer_ok chain