module Network.Socket.Types(
         Socket
       , newSocket
       , connectIP4
       , bindIP4
       , listen
       , acceptIP4
       , SocketStatus(..)
       , getSocketStatus
       , Direction(..)
       , directionOpen
       , withTcpSocket
       , withUdpSocket
       , socketToHandle
       , ShutdownCmd(..)
       , shutdown
       , close
       , toIP4, fromIP4
       , setNetworkStack
       )
 where

import Control.Concurrent.MVar(MVar, newMVar, withMVar,
                               readMVar, swapMVar,
                               modifyMVar, modifyMVar_)
import Control.Monad(void)
import Data.Bits(shiftL, shiftR, (.|.), (.&.))
import Data.Maybe(fromMaybe)
import Data.Typeable(Typeable)
import Data.Word(Word64, Word32)
import Hans.Addr(wildcardAddr)
import Hans.IP4(IP4, packIP4, unpackIP4)
import Hans.Lens(view)
import Hans.Socket(TcpSocket, TcpListenSocket, UdpSocket, SockPort,
                   tcpRemoteAddr, tcpRemotePort, defaultSocketConfig,
                   newUdpSocket, sConnect, sListen, sAccept, sClose)
import Hans.Socket.Handle(makeHansHandle)
import Hans.Types(NetworkStack)
import System.IO(Handle, IOMode(..))
import System.IO.Unsafe(unsafePerformIO)

-- | An IPv4 socket handle.
--
-- The 'MVar' holds the socket's current 'SocketState'; the operations in
-- this module read or update the state through it. The 'Word64' is an
-- identifier handed out by 'nextSocketIdent' and is the only thing the 'Eq'
-- instance compares.
data Socket = Socket (MVar SocketState) Word64

-- | Two sockets are equal exactly when their identifiers are equal; the
-- state 'MVar' plays no part in the comparison.
instance Eq Socket where
  (==) (Socket _ lhsIdent) (Socket _ rhsIdent) = lhsIdent == rhsIdent

-- | Create a new, unbound socket: TCP when the flag is 'True', UDP
-- otherwise. A fresh identifier is allocated for it.
newSocket :: Bool -> IO Socket
newSocket isTcp =
  do stateVar <- newMVar initialState
     sockId   <- nextSocketIdent
     return (Socket stateVar sockId)
 where
  initialState | isTcp     = CreatedTcp
               | otherwise = CreatedUdp

-- | Internal lifecycle state of a 'Socket'.
--
-- * 'CreatedTcp' / 'CreatedUdp': made by 'newSocket'; nothing opened yet.
-- * 'BoundUdp': a UDP socket opened by 'bindIP4'.
-- * 'BoundTcp': a TCP socket whose optional local address/port were
--   recorded by 'bindIP4'. No hans socket exists yet; the values are used
--   later by 'connectIP4' or 'listen'.
-- * 'ListeningTcp': the listening socket created by 'listen'.
-- * 'ConnectedTcp': a connected (or accepted) TCP socket together with the
--   directions still open on it.
-- * 'Converted': handed over to a 'Handle' by 'socketToHandle'.
-- * 'ClosedSocket': after 'close' or a 'shutdown' that closed the socket.
data SocketState = CreatedTcp | CreatedUdp
                 | BoundUdp     (UdpSocket IP4)
                 | BoundTcp     (Maybe IP4)     (Maybe SockPort)
                 | ListeningTcp (TcpListenSocket IP4)
                 | ConnectedTcp (TcpSocket IP4) Direction
                 | Converted
                 | ClosedSocket

-- | Connect a TCP socket to a remote IPv4 address and port.
--
-- A freshly created TCP socket connects from the wildcard local address with
-- no fixed local port. A socket previously bound with 'bindIP4' connects
-- from whatever local address/port was recorded there. On success the socket
-- becomes connected with both directions open.
--
-- Any other state fails with an 'IOError' raised via 'fail'. Because the
-- update runs under 'modifyMVar_', the socket's state is left unchanged in
-- that case (and also if 'sConnect' throws).
connectIP4 :: Socket -> IP4 -> SockPort -> IO ()
connectIP4 (Socket mvs _) addr port =
  modifyMVar_ mvs $ \ curstate ->
    case curstate of
      -- An unbound TCP socket behaves exactly like one bound to nothing.
      CreatedTcp ->
        connectFrom Nothing Nothing
      BoundTcp maddr mport ->
        connectFrom maddr mport
      CreatedUdp ->
        fail "Cannot connect UDP socket."
      BoundUdp _ ->
        fail "Cannot connect bound UDP socket."
      ListeningTcp _ ->
        fail "Cannot connect listening TCP socket."
      ConnectedTcp _ _ ->
        fail "Cannot connect connected TCP socket."
      Converted ->
        fail "Cannot connect converted socket."
      ClosedSocket ->
        fail "Cannot connect closed socket."
 where
  -- Open the hans TCP connection from an optional local address/port,
  -- falling back to the wildcard address when no local address was bound.
  connectFrom maddr mport =
    do let src = fromMaybe (wildcardAddr undefined) maddr
       ns   <- readMVar evilNetworkStackMVar
       sock <- sConnect ns defaultSocketConfig Nothing src mport addr port
       return (ConnectedTcp sock ForBoth)

-- | Bind a socket to an optional local address and port.
--
-- For a fresh TCP socket this only records the requested address/port. No
-- hans socket is created until 'connectIP4' or 'listen' uses them. A fresh
-- UDP socket is opened straight away with 'newUdpSocket', on the given
-- address (or 'wildcardAddr' when none is given) and port.
--
-- Binding a socket in any other state fails with an 'IOError' raised via
-- 'fail'. Because the update runs under 'modifyMVar_', the state is left
-- unchanged in that case.
bindIP4 :: Socket -> Maybe IP4 -> Maybe SockPort -> IO ()
bindIP4 (Socket mvs _) maddr mport =
  modifyMVar_ mvs $ \ curstate ->
    case curstate of
      CreatedTcp ->
        return (BoundTcp maddr mport)
      CreatedUdp ->
        do let addr    = fromMaybe (wildcardAddr undefined) maddr
               conf    = defaultSocketConfig
           ns <- readMVar evilNetworkStackMVar
           sock <- newUdpSocket ns conf Nothing addr mport
           return (BoundUdp sock)
      BoundUdp _ ->
        fail "Cannot re-bind bound UDP port."
      BoundTcp _ _ ->
        fail "Cannot re-bind bound TCP port."
      ListeningTcp _ ->
        fail "Cannot bind listening TCP socket."
      ConnectedTcp _ _ ->
        fail "Cannot bind connected TCP socket."
      Converted ->
        fail "Cannot bind converted socket."
      ClosedSocket ->
        fail "Cannot bind closed socket."

-- | Turn a bound TCP socket into a listening socket.
--
-- The socket must have been bound with 'bindIP4' and given a concrete local
-- port. The local address falls back to 'wildcardAddr' when none was bound.
-- The backlog is passed straight through to 'sListen'. Any other state fails
-- via 'fail' and leaves the socket's state unchanged.
listen :: Socket -> Int -> IO ()
listen (Socket stateVar _) backlog =
  modifyMVar_ stateVar startListening
 where
  startListening (BoundTcp mLocal (Just localPort)) =
    do stack <- readMVar evilNetworkStackMVar
       let localAddr = fromMaybe (wildcardAddr undefined) mLocal
       fmap ListeningTcp
            (sListen stack defaultSocketConfig localAddr localPort backlog)
  startListening (BoundTcp _ Nothing) =
    fail "Cannot listen on socket with unbound port."
  startListening _ =
    fail "Cannot listen on unbound TCP port."

-- | Accept an incoming connection on a listening TCP socket.
--
-- Returns a new socket for the connection, together with the peer's address
-- and port. The new socket is connected, has both directions open and gets a
-- fresh identifier. The listening socket's state is only read ('readMVar'),
-- so its 'MVar' is not held during the 'sAccept' call. A socket in any other
-- state fails via 'fail'.
acceptIP4 :: Socket -> IO (Socket, IP4, SockPort)
acceptIP4 (Socket listenVar _) =
  readMVar listenVar >>= \ st ->
    case st of
      ListeningTcp listener ->
        do conn    <- sAccept listener
           connVar <- newMVar (ConnectedTcp conn ForBoth)
           connId  <- nextSocketIdent
           let peerAddr = view tcpRemoteAddr conn
               peerPort = view tcpRemotePort conn
           return (Socket connVar connId, peerAddr, peerPort)
      _ ->
        fail "Illegal state for accept socket."

-- | Which half of a connected TCP socket 'shutdown' should close: the
-- receiving side, the sending side, or both.
data ShutdownCmd = ShutdownReceive | ShutdownSend | ShutdownBoth
 deriving (Typeable, Eq)

-- | Shut down one or both directions of a connected TCP socket.
--
-- The requested directions are removed from those still open on the socket.
-- Once no direction remains open, the underlying hans socket is closed with
-- 'sClose' and the socket becomes closed. Otherwise only the recorded
-- direction changes. Shutting down a direction that is already closed has no
-- effect.
--
-- Calling this on anything but a connected TCP socket fails with an
-- 'IOError' raised via 'fail'.
shutdown :: Socket -> ShutdownCmd -> IO ()
shutdown (Socket mvs _) cmd =
  modifyMVar_ mvs $ \ curstate ->
    case curstate of
      ConnectedTcp sock dir ->
        case remaining dir of
          ForNeither -> sClose sock >> return ClosedSocket
          dir'       -> return (ConnectedTcp sock dir')
      _ ->
        fail "Shutdown called on un-connected socket."
 where
  -- Directions left open after applying 'cmd'. The result never contains a
  -- direction that was not already open, so a shutdown can never re-open
  -- the other half of the connection.
  remaining dir =
    case (dir, cmd) of
      (_,          ShutdownBoth)    -> ForNeither
      (ForBoth,    ShutdownReceive) -> ForWrite
      (ForBoth,    ShutdownSend)    -> ForRead
      (ForRead,    ShutdownReceive) -> ForNeither
      (ForRead,    ShutdownSend)    -> ForRead
      (ForWrite,   ShutdownReceive) -> ForWrite
      (ForWrite,   ShutdownSend)    -> ForNeither
      (ForNeither, _)               -> ForNeither

-- | Close a socket.
--
-- Any hans socket opened for it (a connected or listening TCP socket, or a
-- bound UDP socket) is released with 'sClose' before the state becomes
-- 'ClosedSocket'.
--
-- A socket converted with 'socketToHandle' is not closed here: its TCP
-- socket was passed to 'makeHansHandle' and is presumably released through
-- that 'Handle' (NOTE(review): confirm in Hans.Socket.Handle). Closing an
-- already closed socket only rewrites the state. If 'sClose' throws,
-- 'modifyMVar_' keeps the previous state.
close :: Socket -> IO ()
close (Socket mvs _) =
  modifyMVar_ mvs $ \ curstate ->
    do releaseUnderlying curstate
       return ClosedSocket
 where
  -- Close whichever hans socket the current state holds, if any.
  releaseUnderlying (ConnectedTcp sock _) = sClose sock
  releaseUnderlying (ListeningTcp lsock)  = sClose lsock
  releaseUnderlying (BoundUdp     usock)  = sClose usock
  releaseUnderlying _                     = return ()

-- | Hand a connected TCP socket over to a 'Handle' built by
-- 'makeHansHandle'.
--
-- The requested 'IOMode' must fit the directions still open on the socket:
--
-- * any mode when both directions are open;
-- * 'ReadMode' when only reading is open;
-- * 'WriteMode' or 'AppendMode' when only writing is open.
--
-- On success the socket is marked 'Converted'. Incompatible modes and
-- unconnected sockets fail with an 'IOError' raised via 'fail'.
socketToHandle :: Socket -> IOMode -> IO Handle
socketToHandle (Socket mvs _) mode =
  modifyMVar mvs $ \ curstate ->
    case curstate of
      ConnectedTcp sock allowed
        | modeAllowed allowed ->
            do hndl <- makeHansHandle sock mode
               return (Converted, hndl)
        | otherwise ->
            fail ("Access error converting socket to handle. Socket is in " ++
                  show allowed ++ " mode, but IOMode was " ++ show mode)
      _ ->
        fail "Cannot convert unconnected socket to a handle."
 where
  -- Can a socket with these open directions serve the requested 'mode'?
  modeAllowed ForBoth    = True
  modeAllowed ForRead    = mode == ReadMode
  modeAllowed ForWrite   = mode `elem` [AppendMode, WriteMode]
  modeAllowed ForNeither = False

-- | Coarse, externally visible summary of a socket's state, as returned by
-- 'getSocketStatus'. TCP and UDP sockets map onto the same constructors:
--
-- * 'NotConnected': freshly created by 'newSocket';
-- * 'Bound': after 'bindIP4';
-- * 'Listening': after 'listen';
-- * 'Connected': a TCP socket from 'connectIP4' or 'acceptIP4';
-- * 'ConvertedToHandle': after 'socketToHandle';
-- * 'Closed': after 'close' or a 'shutdown' that closed the socket.
data SocketStatus = NotConnected
                  | Bound
                  | Listening
                  | Connected
                  | ConvertedToHandle
                  | Closed
 deriving (Eq, Show)

-- | Report a coarse 'SocketStatus' for the socket's current state. The
-- state 'MVar' is held (via 'withMVar') while the status is computed.
getSocketStatus :: Socket -> IO SocketStatus
getSocketStatus (Socket stateVar _) =
  withMVar stateVar $ \ st -> return $! classify st
 where
  classify CreatedTcp         = NotConnected
  classify CreatedUdp         = NotConnected
  classify (BoundUdp _)       = Bound
  classify (BoundTcp _ _)     = Bound
  classify (ListeningTcp _)   = Listening
  classify (ConnectedTcp _ _) = Connected
  classify Converted          = ConvertedToHandle
  classify ClosedSocket       = Closed

-- | Which directions of a connected TCP socket are usable.
--
-- Connected sockets start out as 'ForBoth', and 'shutdown' narrows this to
-- 'ForRead' or 'ForWrite'. Callers pass the direction they need to
-- 'withTcpSocket' / 'directionOpen', where 'modesMatch' checks it against
-- the socket's open directions. A 'ForNeither' request is always satisfied.
-- No code in this module stores 'ForNeither' as a socket's state.
data Direction = ForWrite | ForRead | ForBoth | ForNeither
 deriving (Show)

-- | Is the requested direction usable on this socket? Only connected TCP
-- sockets can answer 'True'; the check itself is 'modesMatch'.
directionOpen :: Socket -> Direction -> IO Bool
directionOpen (Socket stateVar _) wanted =
  withMVar stateVar $ \ st -> st `seq` return (isOpen st)
 where
  isOpen (ConnectedTcp _ open) = modesMatch wanted open
  isOpen _                     = False

-- | Run an action on the underlying hans TCP socket.
--
-- The socket must be connected and the requested direction must still be
-- open (as judged by 'modesMatch'); otherwise this fails via 'fail'.
--
-- The state 'MVar' is held with 'withMVar' for the whole run of the action.
-- Other operations on the same 'Socket' (including 'close' and 'shutdown')
-- therefore wait until the action finishes. The action must not call back
-- into this module on the same 'Socket', because the 'MVar' is empty while
-- it runs.
withTcpSocket :: Socket -> Direction -> (TcpSocket IP4 -> IO a) -> IO a
withTcpSocket (Socket stateVar _) wanted action =
  withMVar stateVar dispatch
 where
  dispatch (ConnectedTcp sock open)
    | modesMatch wanted open = action sock
    | otherwise              =
        fail ("Mismatch between requested direction (" ++ show wanted ++
              ") and allowed (" ++ show open ++ ")")
  dispatch _ =
    fail "TCP operation on non-TCP socket."

-- | Run an action on the underlying hans UDP socket of a bound UDP socket;
-- any other state fails via 'fail'. As with 'withTcpSocket', the state
-- 'MVar' is held with 'withMVar' for the whole run of the action.
withUdpSocket :: Socket -> (UdpSocket IP4 -> IO a) -> IO a
withUdpSocket (Socket stateVar _) action =
  withMVar stateVar runOn
 where
  runOn (BoundUdp udpSock) = action udpSock
  runOn _                  = fail "UDP operation on non-UDP socket."

-- | Build an 'IP4' from a 'Word32'. The most significant byte becomes the
-- first argument to 'packIP4' and the least significant byte the last.
toIP4 :: Word32 -> IP4
toIP4 w32 = packIP4 (octet 24) (octet 16) (octet 8) (octet 0)
 where
  -- The byte whose lowest bit sits at bit position 'n' of the input.
  octet n = fromIntegral ((w32 `shiftR` n) .&. 0xFF)

-- | Flatten an 'IP4' into a 'Word32'. The first component returned by
-- 'unpackIP4' lands in the most significant byte, mirroring 'toIP4'.
fromIP4 :: IP4 -> Word32
fromIP4 ipaddr =
  case unpackIP4 ipaddr of
    (o1, o2, o3, o4) ->
      (widen o1 `shiftL` 24) .|. (widen o2 `shiftL` 16) .|.
      (widen o3 `shiftL`  8) .|. (widen o4 `shiftL`  0)
 where
  widen o = fromIntegral o :: Word32
 
-- | Can a socket whose open directions are @allowed@ serve a request for
-- @requested@?
--
-- * 'ForBoth' needs both directions open.
-- * 'ForRead' / 'ForWrite' need that direction (or both) open.
-- * 'ForNeither' is always satisfied.
modesMatch :: Direction -> Direction -> Bool
modesMatch requested allowed =
  case requested of
    ForNeither -> True
    ForBoth    -> bothOpen
    ForRead    -> bothOpen || readOpen
    ForWrite   -> bothOpen || writeOpen
 where
  bothOpen  = case allowed of { ForBoth  -> True; _ -> False }
  readOpen  = case allowed of { ForRead  -> True; _ -> False }
  writeOpen = case allowed of { ForWrite -> True; _ -> False }

{-# NOINLINE evilNetworkStackMVar #-}
-- | Process-wide network stack, read by 'connectIP4', 'bindIP4' and 'listen'
-- and replaced by 'setNetworkStack'.
--
-- It starts out holding an 'error' thunk. Any use that forces the stack
-- before 'setNetworkStack' has been called therefore fails with
-- "Access before network stack set!".
--
-- The NOINLINE pragma is essential: it stops GHC from inlining this
-- definition, which could otherwise run the 'unsafePerformIO' more than once
-- and create several independent 'MVar's.
evilNetworkStackMVar :: MVar NetworkStack
evilNetworkStackMVar =
  unsafePerformIO (newMVar (error "Access before network stack set!"))

{-# NOINLINE evilSocketIDMVar #-}
-- | Process-wide counter for socket identifiers, starting at 1 and advanced
-- by 'nextSocketIdent'. As with the network-stack global, NOINLINE keeps the
-- 'unsafePerformIO' from being duplicated, so there is exactly one counter.
evilSocketIDMVar :: MVar Word64
evilSocketIDMVar =
  unsafePerformIO (newMVar 1)

-- | Install the network stack that later 'connectIP4', 'bindIP4' and
-- 'listen' calls will use, discarding whatever was stored before.
setNetworkStack :: NetworkStack -> IO ()
setNetworkStack stack = void (swapMVar evilNetworkStackMVar stack)

-- | Allocate a fresh socket identifier.
--
-- Returns the counter's current value and stores its successor. The
-- successor is forced before it is written back. Otherwise the counter would
-- accumulate an ever-growing chain of unevaluated @(+ 1)@ thunks whenever
-- identifiers are never inspected (they are only looked at by 'Eq').
nextSocketIdent :: IO Word64
nextSocketIdent =
  modifyMVar evilSocketIDMVar $ \ x ->
    let x' = x + 1
    in x' `seq` return (x', x)