{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE DeriveGeneric #-}

-----------------------------------------------------------------------------
-- |
-- Module      :  Network.Transport.TCP
-- Copyright   :  (c) Phil Hargett 2013
-- License     :  MIT (see LICENSE file)
-- 
-- Maintainer  :  phil@haphazardhouse.net
-- Stability   :  experimental
-- Portability :  non-portable (uses STM)
--
-- TCP transports deliver messages to other 'Network.Endpoints.Endpoint's using TCP/IP.
--
-- Each TCP transport manages both socket bindings and connections on behalf of
-- 'Endpoint's, dynamically opening / closing new sockets as needed to deliver
-- messages to other 'Endpoint's using TCP transports.
-----------------------------------------------------------------------------

module Network.Transport.TCP (
  newTCPTransport
  ) where

-- local imports

import Network.Transport

-- external imports

import Control.Concurrent.Async
import Control.Concurrent.STM
import Control.Exception

import qualified Data.ByteString as B
import qualified Data.Map as M
import Data.Serialize
import qualified Data.Set as S
import qualified Data.Text as T

import GHC.Generics

import Network.Socket (HostName,ServiceName,Socket,sClose,accept)
import Network.Simple.TCP hiding (accept)

import System.Log.Logger

--------------------------------------------------------------------------------
--------------------------------------------------------------------------------

-- | Name of the logger that every log statement in this module writes to.
_log :: String
_log = "transport.tcp"

-- | Internal state shared by all operations of a single TCP transport.
data TCPTransport = TCPTransport {
  -- | Listening sockets keyed by service name; closed on shutdown.
  tcpListeners :: TVar (M.Map ServiceName Socket),
  -- | Open connections, keyed by the address of the peer they talk to.
  tcpMessengers :: TVar (M.Map Address Messenger),  
  -- | Mailboxes of the names bound on this transport.
  tcpBindings :: TVar (M.Map Name Mailbox),
  -- | Mailbox that connection receivers fill and the dispatcher drains.
  tcpInbound :: Mailbox,
  -- | Background dispatcher tasks; cancelled on shutdown.
  tcpDispatchers :: S.Set (Async ()),
  -- | Resolver used to translate names into addresses.
  tcpResolver :: Resolver
  }
                    
-- | Message announcing the address of a connecting transport. The sending
-- side queues one per bound address on a new connection; the accepting side
-- decodes the first message it receives on a connection as one of these.
data IdentifyMessage = IdentifyMessage Address deriving (Generic)

-- No methods are given, so the class defaults apply (presumably the
-- Generic-based ones, given the 'Generic' derivation above).
instance Serialize IdentifyMessage

{-|
Build a fresh 'Transport' that delivers messages over TCP/IP.  Any number of
such 'Transport's may coexist: 'Network.Endpoints.Endpoint's attached to
different instances can still talk to each other as long as they use correct
TCP/IP addresses (or hostnames).

The supplied 'Resolver' is used for every name-to-address lookup, and a
background dispatcher is started immediately to route inbound messages to
bound mailboxes.
-}
newTCPTransport :: Resolver -> IO Transport
newTCPTransport resolver = do
  listenerMap <- newTVarIO M.empty
  messengerMap <- newTVarIO M.empty
  bindingMap <- newTVarIO M.empty
  inboundBox <- newMailbox
  -- the dispatcher shares the binding map so later binds are visible to it
  dispatchTask <- async (dispatcher bindingMap inboundBox)
  let state = TCPTransport {
        tcpListeners = listenerMap,
        tcpMessengers = messengerMap,
        tcpBindings = bindingMap,
        tcpInbound = inboundBox,
        tcpDispatchers = S.singleton dispatchTask,
        tcpResolver = resolver
        }
  return Transport {
      scheme = tcpScheme,
      handles = tcpHandles state,
      bind = tcpBind state,
      sendTo = tcpSendTo state,
      shutdown = tcpShutdown state
      }

--------------------------------------------------------------------------------
                        
{-|
Parse a TCP 'Address' into its respective 'HostName' and 'ServiceName' (port)
components, on the assumption the 'Address' has an identifier in the format
@host:port@.

* If the port is missing from the supplied address (either no @:@ at all, or
  nothing after it, as in @host:@), it defaults to @0@.

* If the hostname component is missing from the identifier (e.g., just
  @:port@), then the hostname is assumed to be @localhost@.

Anything after a second @:@ is ignored.
-}
parseTCPAddress :: Address -> (HostName,ServiceName)
parseTCPAddress address =
  case T.splitOn ":" (T.pack address) of
    (h:p:_) -> (host $ T.unpack h, port $ T.unpack p)
    (h:_) -> (host $ T.unpack h, "0")
    -- splitOn never yields an empty list; kept so the match is total
    [] -> ("localhost", "0")
  where
    host h
      | null h = "localhost"
      | otherwise = h
    -- an empty port (as in "host:") falls back to 0, as documented above
    port p
      | null p = "0"
      | otherwise = p

-- | The scheme this transport reports through the 'scheme' field of 'Transport'.
tcpScheme :: Scheme
tcpScheme = "tcp"

-- | Report whether this transport can handle a name, i.e. whether the
-- transport's resolver yields an address for it.
tcpHandles :: TCPTransport -> Name -> IO Bool
tcpHandles transport name =
  fmap (maybe False (const True)) (resolve (tcpResolver transport) name)

-- | Bind a name to a mailbox and start listening on the port of the address
-- the name resolves to.
--
-- Returns 'Left' with a description when the name cannot be resolved; in that
-- case nothing is registered and no listener is started.  Otherwise returns a
-- 'Binding' whose 'unbind' action cancels the listener task.
--
-- Each accepted connection must first send an 'IdentifyMessage'; connections
-- that fail to do so are closed.  An identified connection becomes a
-- messenger unless one is already registered for that address.
tcpBind :: TCPTransport -> Mailbox -> Name -> IO (Either String Binding)
tcpBind transport inc name = do
  -- resolve before registering so a failed bind leaves no stale binding behind
  resolved <- resolve (tcpResolver transport) name
  case resolved of
    Nothing ->
      return $ Left "tcpBind: unable to resolve the requested name to a TCP address"
    Just address -> do
      atomically $ modifyTVar (tcpBindings transport) $ \bindings ->
        M.insert name inc bindings
      let (_,port) = parseTCPAddress address
      listener <- async $ do
        infoM _log $ "Binding to address " ++ (show address)
        tcpListen address port
      return $ Right Binding {
        bindingName = name,
        unbind = tcpUnbind listener address
        }
  where
    -- listen on every interface and hand the listening socket to the accept loop
    tcpListen address port = listen HostAny port $ \(socket,_) -> do
      tcpAccept address socket
    -- accept forever, handling each client on its own task
    tcpAccept address socket = do
      (client,clientAddress) <- accept socket
      _ <- async $ tcpDispatch address client clientAddress
      tcpAccept address socket
    tcpDispatch address client socketAddress = do
      infoM _log $ "Accepted connection on " ++ (show address)
      identity <- tcpIdentify client socketAddress
      case identity of
        -- no usable identity: drop the connection
        Nothing -> sClose client
        Just (IdentifyMessage clientAddress) -> do
          infoM _log $ "Identified " ++ (show clientAddress)
          msngr <- newMessenger client clientAddress (tcpInbound transport)
          found <- atomically $ do
            msngrs <- readTVar $ tcpMessengers transport
            return $ M.lookup clientAddress msngrs
          case found of
            -- a messenger for this peer already exists; discard the new one
            Just _ -> closeMessenger msngr
            Nothing -> do
              addMessenger transport clientAddress msngr
    -- read the first message on the connection and try to decode it
    tcpIdentify client clientAddress = do
      infoM _log $ "Awaiting identity from " ++ (show clientAddress)
      maybeMsg <- receiveMessage client
      case maybeMsg of
        Nothing -> return Nothing
        Just bytes -> do
          let msg = decode bytes
          case msg of
            Left _ -> return Nothing
            Right message -> return $ Just message
    tcpUnbind listener address = do
      infoM _log $ "Unbinding from port " ++ (show address)
      cancel listener

-- | Send a message to the endpoint bound to the given name.
--
-- The name is resolved to an address, and the message is wrapped in an
-- 'Envelope' and queued on the messenger for that address.  When no messenger
-- exists yet, a connection is opened first and one 'IdentifyMessage' per
-- distinct locally bound address is queued ahead of the envelope.
--
-- Throws an 'IOException' (a 'userError') when the name cannot be resolved.
tcpSendTo :: TCPTransport -> Name -> Message -> IO ()
tcpSendTo transport name msg = do
  resolved <- resolve (tcpResolver transport) name
  -- Same exception type as a failed pattern match in IO, but with a message
  -- that says what actually went wrong.
  address <- case resolved of
    Just resolvedAddress -> return resolvedAddress
    Nothing -> throwIO $ userError
      "tcpSendTo: unable to resolve the destination name to a TCP address"
  let env = encode $ Envelope {
        envelopeDestination = name,
        envelopeContents = msg
        }
  amsngr <- atomically $ do
    msngrs <- readTVar $ tcpMessengers transport
    return $ M.lookup address msngrs
  case amsngr of
    Nothing -> do
      let (host,port) = parseTCPAddress address
      infoM _log $ "Connecting to " ++ (show address)
      (socket,_) <- connectSock host port
      infoM _log $ "Connected to " ++ (show address)
      msngr <- newMessenger socket address (tcpInbound transport)
      -- Queue the identity messages BEFORE publishing the messenger: the peer
      -- decodes the first message it reads as an IdentifyMessage, so no other
      -- thread may enqueue an envelope ahead of them.
      identifyAll msngr
      addMessenger transport address msngr
      deliver msngr env
    Just msngr -> deliver msngr env
  where
    deliver msngr message = atomically $ writeTQueue (messengerOut msngr) message
    -- announce every distinct address that the local bindings resolve to
    identifyAll msngr = do
      bindings <- atomically $ readTVar $ tcpBindings transport
      boundAddresses <- mapM (resolve $ tcpResolver transport) (M.keys bindings)
      let uniqueAddresses = S.toList $ S.fromList [bound | Just bound <- boundAddresses]
      mapM_ (deliver msngr . encode . IdentifyMessage) uniqueAddresses

-- | Shut the transport down: close every messenger, close every registered
-- listening socket, then cancel the dispatcher task(s), in that order.
tcpShutdown :: TCPTransport -> IO ()
tcpShutdown transport = do
  infoM _log $ "Closing messengers"
  readTVarIO (tcpMessengers transport) >>= mapM_ closeMessenger . M.elems
  infoM _log $ "Closing listeners"
  readTVarIO (tcpListeners transport) >>= mapM_ sClose . M.elems
  infoM _log $ "Closing dispatcher"
  mapM_ cancel (S.toList (tcpDispatchers transport))

-- | One open connection to a peer, together with the tasks servicing it.
data Messenger = Messenger {
  -- | Queue of outgoing messages, drained by the sender task.
  messengerOut :: Mailbox,
  -- | Address of the peer on the other end of the connection.
  messengerAddress :: Address,
  -- | Task writing queued messages to the socket.
  messengerSender :: Async (),
  -- | Task reading messages from the socket.
  messengerReceiver :: Async (),
  -- | The underlying connection socket.
  messengerSocket :: Socket
  }
                 
-- | Renders a messenger as @Messenger(address,socket)@.
instance Show Messenger where
  show msngr = concat
    [ "Messenger("
    , show (messengerAddress msngr)
    , ","
    , show (messengerSocket msngr)
    , ")"
    ]
                 
-- | Wrap a connected socket in a 'Messenger': create its outgoing mailbox and
-- start the sender task (draining that mailbox) and the receiver task
-- (feeding the supplied inbound mailbox).
newMessenger :: Socket -> Address -> Mailbox -> IO Messenger
newMessenger socket address inc = do
  outbox <- newMailbox
  sendTask <- async (sender socket address outbox)
  receiveTask <- async (receiver socket address inc)
  return Messenger {
    messengerOut = outbox,
    messengerAddress = address,
    messengerSender = sendTask,
    messengerReceiver = receiveTask,
    messengerSocket = socket
    }
                 
-- | Register a messenger under the given address (replacing any previous
-- entry for that address) and log the resulting set of messengers.
addMessenger :: TCPTransport -> Address -> Messenger -> IO ()
addMessenger transport address msngr = do
  let registry = tcpMessengers transport
  updated <- atomically $ do
    current <- readTVar registry
    let extended = M.insert address msngr current
    writeTVar registry extended
    return extended
  infoM _log $ "Added messenger to " ++ (show address) ++ "; messengers are " ++ (show updated)

-- | Tear a messenger down: cancel its sender task, then its receiver task,
-- and finally close the socket.
closeMessenger :: Messenger -> IO ()
closeMessenger msngr = do
  mapM_ cancel [messengerSender msngr, messengerReceiver msngr]
  sClose (messengerSocket msngr)

-- | Endless loop that takes messages from the mailbox and writes each one to
-- the socket as an encoded length followed by the message bytes.  Any
-- exception raised while waiting or sending is logged and then rethrown,
-- which ends the loop.
sender :: Socket -> Address -> Mailbox -> IO ()
sender socket address mailbox = loop
  where
    peer = show address

    loop = handle logAndRethrow sendNext >> loop

    -- wait for one message and write it out, length prefix first
    sendNext = do
      infoM _log $ "Waiting to send to " ++ peer
      payload <- atomically $ readTQueue mailbox
      infoM _log $ "Sending message to " ++ peer
      send socket $ encode (B.length payload)
      infoM _log $ "Length sent"
      send socket payload
      infoM _log $ "Message sent to" ++ peer

    logAndRethrow :: SomeException -> IO ()
    logAndRethrow failure = do
      errorM _log $ "Send error " ++ show failure
      throw failure

-- | Endless loop routing inbound messages: each item taken from the mailbox
-- is decoded as an 'Envelope' and its contents are written to the mailbox
-- bound to the envelope's destination.  Undecodable items are logged;
-- envelopes for destinations without a binding are silently dropped.
dispatcher :: TVar (M.Map Name Mailbox) -> Mailbox -> IO ()
dispatcher bindings mbox = loop
  where
    loop = do
      infoM _log $ "Dispatching messages"
      atomically (readTQueue mbox) >>= route
      loop

    route raw = do
      infoM _log $ "Dispatching message"
      case decode raw of
        Left err ->
          errorM _log $ "Error decoding message for dispatch: " ++ err
        Right (Envelope destination msg) -> atomically $ do
          targets <- readTVar bindings
          -- lookup and delivery share one transaction
          maybe (return ()) (\box -> writeTQueue box msg) (M.lookup destination targets)

-- | Loop reading messages from the socket and writing each one to the given
-- mailbox.  The loop ends when 'receiveMessage' yields 'Nothing', i.e. when
-- no further message could be read from the connection.
receiver :: Socket -> Address -> Mailbox -> IO ()
receiver socket address mailbox = receiveMessages
  where
    receiveMessages = do
      infoM _log $ "Waiting to receive from " ++ (show address)
      maybeMsg <- receiveMessage socket
      case maybeMsg of
        -- nothing more to read: stop instead of spinning on a dead socket
        Nothing -> infoM _log $ "No more messages from " ++ (show address)
        Just msg -> do
          infoM _log $ "Received message from " ++ (show address)
          atomically $ writeTQueue mailbox msg
          -- keep servicing the connection; one message is not the end of it
          receiveMessages
        
-- | Read one length-prefixed message from the socket.
--
-- The wire format mirrors what 'sender' writes: an encoded 'Int' length
-- followed by that many bytes of message.  Returns 'Nothing' when the
-- connection ends before a complete message arrives, or when the length
-- prefix cannot be decoded or is negative.
receiveMessage :: Socket -> IO (Maybe Message)
receiveMessage socket = do
  maybeLen <- recvExactly lengthPrefixSize
  case maybeLen of
    Nothing -> do
      errorM _log $ "No length received"
      return Nothing
    Just len -> case decode len of
      Left err -> do
        errorM _log $ "Could not decode message length: " ++ err
        return Nothing
      Right size
        | size < 0 -> do
            errorM _log $ "Received negative message length " ++ show size
            return Nothing
        | otherwise -> do
            maybeMsg <- recvExactly size
            infoM _log $ "Received message"
            return maybeMsg
  where
    -- Size in bytes of the encoded length written by 'sender'; presumably
    -- cereal serializes an Int as a fixed 64-bit value -- TODO confirm.
    lengthPrefixSize :: Int
    lengthPrefixSize = 8

    -- A single recv may return fewer bytes than requested, so keep reading
    -- until the requested count has arrived or the connection ends.
    recvExactly :: Int -> IO (Maybe B.ByteString)
    recvExactly wanted = collect wanted []

    collect remaining chunks
      | remaining <= 0 = return $ Just $ B.concat $ reverse chunks
      | otherwise = do
          maybeChunk <- recv socket remaining
          case maybeChunk of
            Nothing -> return Nothing
            Just chunk
              -- guard against an empty chunk so the loop always makes progress
              | B.null chunk -> return Nothing
              | otherwise -> collect (remaining - B.length chunk) (chunk : chunks)