module Network.CoAP.Messaging
( createMessagingState
, MessagingState
, startMessaging
, stopMessaging
, recvRequest
, sendResponse
, sendRequest
, recvResponse
) where
import Network.CoAP.Types
import Network.CoAP.MessageCodec
import Data.List (deleteBy, find, partition, filter, minimumBy, delete, sortBy)
import Control.Monad
import Control.Monad.State.Lazy
import Data.Ord
import Data.ByteString (empty)
import Control.Concurrent.STM
import Control.Concurrent
import Control.Exception
import Control.Concurrent.Async
import Data.Maybe
import Data.Time
import System.Random

-- | Point in time used to stamp queued messages.
type TimeStamp = UTCTime

-- | Bookkeeping record for a message while it sits in one of the queues.
--
-- * 'messageContext': the message together with its source and destination
--   endpoints.
-- * 'insertionTime': when this state was created or last rescheduled; the
--   queues pick the oldest entry by this value.
-- * 'replyTimeout': number of seconds (it is converted with 'realToFrac' for
--   'addUTCTime') after 'insertionTime' at which the message becomes due for
--   retransmission. Received messages and generated ACKs use 0.
-- * 'retransmitCount': how many times the message has been rescheduled for
--   retransmission.
data MessageState = MessageState { messageContext  :: MessageContext
                                 , insertionTime   :: TimeStamp
                                 , replyTimeout    :: Double 
                                 , retransmitCount :: Integer } deriving (Show)


-- | Two message states compare equal when they carry the same message id and
-- share the same source endpoint; every other field is ignored.
instance Eq MessageState where
  lhs == rhs = sameMessageId && sameSource
    where
      ctxL = messageContext lhs
      ctxR = messageContext rhs
      sameMessageId = messageId (message ctxL) == messageId (message ctxR)
      sameSource = srcEndpoint ctxL == srcEndpoint ctxR

-- | Plain list of message states; ordering is not significant because the
-- consumers select entries by 'insertionTime'.
type MessageList = [MessageState]
-- | The four shared queues of the messaging layer.
--
-- * 'incomingMessages': received messages with a non-empty code, waiting to be
--   picked up by 'recvRequest' / 'recvResponse'.
-- * 'outgoingMessages': messages waiting to be transmitted by the send loop.
-- * 'unackedMessages': received confirmable messages that were handed to the
--   application but have not been acknowledged yet.
-- * 'unconfirmedMessages': transmitted confirmable messages for which no ACK
--   has been received yet.
data MessagingStore = MessagingStore { incomingMessages    :: TVar MessageList
                                     , outgoingMessages    :: TVar MessageList
                                     , unackedMessages     :: TVar MessageList 
                                     , unconfirmedMessages :: TVar MessageList }
-- | Handle passed to every messaging function: the transport plus its queues.
data MessagingState = MessagingState Transport MessagingStore

-- | Build a messaging state for the given transport with all four queues
-- (incoming, outgoing, unacked, unconfirmed) initially empty.
createMessagingState :: Transport -> IO MessagingState
createMessagingState transport = MessagingState transport <$> newStore
  where
    emptyQueue = newTVarIO []
    newStore = MessagingStore <$> emptyQueue <*> emptyQueue <*> emptyQueue <*> emptyQueue

-- | Prepend the given messages to the list stored in the variable.
queueMessages :: [MessageState] -> TVar MessageList -> STM ()
queueMessages messages msgListVar =
  readTVar msgListVar >>= writeTVar msgListVar . (messages ++)

-- | Prepend a single message to the list stored in the variable.
queueMessage :: MessageState -> TVar MessageList -> STM ()
queueMessage msgState msgListVar = queueMessages [msgState] msgListVar

-- | Remove and return the message with the earliest insertion time, blocking
-- (via 'retry') for as long as the list is empty.
takeMessageRetry :: TVar MessageList -> STM MessageState
takeMessageRetry msgListVar = readTVar msgListVar >>= pickOldest
  where
    pickOldest [] = retry
    pickOldest pending = do
      let oldest = minimumBy (comparing insertionTime) pending
      -- 'delete' goes through the Eq instance of MessageState and drops the
      -- first entry that compares equal to the selected one.
      writeTVar msgListVar (delete oldest pending)
      return oldest

-- | Remove and return every message that satisfies the predicate.
--
-- Blocks (via 'retry') only while the list is empty; a non-empty list without
-- any matching entry yields an empty result and is left unchanged in content.
takeMessagesRetryMatching :: (MessageState -> Bool) -> TVar MessageList -> STM [MessageState]
takeMessagesRetryMatching msgFilter msgListVar = do
  pending <- readTVar msgListVar
  when (null pending) retry
  let (taken, kept) = partition msgFilter pending
  writeTVar msgListVar kept
  return taken

-- | Remove and return all messages inserted strictly before the given time.
-- Blocking behaviour is that of 'takeMessagesRetryMatching'.
takeMessagesOlderThan :: TimeStamp -> TVar MessageList -> STM [MessageState]
takeMessagesOlderThan timeStamp = takeMessagesRetryMatching insertedBefore
  where
    insertedBefore msgState = insertionTime msgState < timeStamp

-- | True when the message's reply timeout (in seconds, counted from its
-- insertion time) has elapsed strictly before @now@.
checkRetransmit :: TimeStamp -> MessageState -> Bool
checkRetransmit now msgState = now > deadline
  where
    waitFor = realToFrac (replyTimeout msgState)
    deadline = addUTCTime waitFor (insertionTime msgState)

-- | Remove and return all messages whose reply timeout has expired at @now@.
-- Blocking behaviour is that of 'takeMessagesRetryMatching'.
takeMessagesToRetransmit :: TimeStamp -> TVar MessageList -> STM [MessageState]
takeMessagesToRetransmit now msgListVar = takeMessagesRetryMatching isDue msgListVar
  where
    isDue = checkRetransmit now

-- | Remove and return the oldest message (by insertion time) that satisfies
-- the predicate, blocking (via 'retry') until such a message exists.
takeMessageRetryMatching :: (MessageState -> Bool) -> TVar MessageList -> STM MessageState
takeMessageRetryMatching matchFilter msgListVar = do
  pending <- readTVar msgListVar
  let oldestFirst = sortBy (comparing insertionTime) pending
  -- An empty list has no match either, so it also ends up in 'retry'.
  case find matchFilter oldestFirst of
    Nothing -> retry
    Just found -> do
      -- Removal goes through the Eq instance of MessageState.
      writeTVar msgListVar (delete found pending)
      return found

-- | Remove and return the oldest message (by insertion time) that satisfies
-- the predicate, or 'Nothing' when no message matches. Never blocks.
--
-- Only the returned message is removed. Other messages that also satisfy the
-- predicate stay queued, so they are not silently lost when several entries
-- match (for example several pending messages sharing the same token).
takeMessageMatching :: (MessageState -> Bool) -> TVar MessageList -> STM (Maybe MessageState)
takeMessageMatching matchFilter msgListVar = do
  msgList <- readTVar msgListVar
  let sortedMsgList = sortBy (comparing insertionTime) msgList
  -- 'break' splits at the first match, which is the oldest one because the
  -- list was just sorted by insertion time.
  case break matchFilter sortedMsgList of
    (_, []) -> return Nothing
    (older, found : newer) -> do
      writeTVar msgListVar (older ++ newer)
      return (Just found)

-- | Take a message whose token equals the given one, if any is queued.
takeMessageByToken :: Token -> TVar MessageList -> STM (Maybe MessageState)
takeMessageByToken token = takeMessageMatching carriesToken
  where
    carriesToken msgState = token == messageToken (message (messageContext msgState))

-- | Take a message with the given message id whose destination endpoint is
-- @origin@, if any is queued.
takeMessageByIdAndOrigin :: MessageId -> Endpoint -> TVar MessageList -> STM (Maybe MessageState)
takeMessageByIdAndOrigin msgId origin = takeMessageMatching addressedToOrigin
  where
    addressedToOrigin msgState =
      let ctx = messageContext msgState
       in origin == dstEndpoint ctx && msgId == messageId (message ctx)

-- | Handle one successfully decoded message received from @sender@.
--
-- An ACK removes the matching entry (same message id, sent to @sender@) from
-- the unconfirmed queue. Any message with a non-empty code is then queued as
-- incoming, stamped with the current time.
recvLoopSuccess :: MessagingState -> Endpoint -> Message -> IO ()
recvLoopSuccess (MessagingState transport store) sender received = do
  receiver <- localEndpoint transport
  arrival <- getCurrentTime
  let ctx = MessageContext { message = received
                           , srcEndpoint = sender
                           , dstEndpoint = receiver }
      entry = MessageState { messageContext = ctx
                           , insertionTime = arrival
                           , replyTimeout = 0
                           , retransmitCount = 0 }
  atomically $ when (messageType received == ACK) $
    void (takeMessageByIdAndOrigin (messageId received) sender (unconfirmedMessages store))
  atomically $ when (messageCode received /= CodeEmpty) $
    queueMessage entry (incomingMessages store)

-- | Report a message that could not be decoded; the message is skipped.
recvLoopError :: String -> IO ()
recvLoopError err =
  putStrLn (concat ["Error parsing message: ", show err, ", skipping"])

-- | Receive loop: read a datagram from the transport, decode it and dispatch
-- to 'recvLoopSuccess' or 'recvLoopError'. Runs forever.
recvLoop :: MessagingState -> IO ()
recvLoop state@(MessagingState transport _) = forever $ do
  (msgData, sender) <- recvFrom transport
  case decode msgData of
    Left err -> recvLoopError err
    Right msg -> recvLoopSuccess state sender msg
  
-- | Send loop: take the oldest outgoing message, encode it and transmit it to
-- its destination endpoint. Confirmable messages are moved to the unconfirmed
-- queue in the same transaction that dequeues them. Runs forever.
sendLoop :: MessagingState -> IO ()
sendLoop state@(MessagingState transport store) = do
  next <- atomically takeNext
  let ctx = messageContext next
  _ <- sendTo transport (encode (message ctx)) (dstEndpoint ctx)
  sendLoop state
  where
    takeNext = do
      pending <- takeMessageRetry (outgoingMessages store)
      when (messageType (message (messageContext pending)) == CON) $
        queueMessage pending (unconfirmedMessages store)
      return pending


-- | Build the empty ACK for a received message.
--
-- The ACK keeps the version and message id of the original, has an empty
-- code, token, option list and no payload, and travels in the opposite
-- direction (source and destination endpoints swapped). It is stamped with
-- @now@ and carries no retransmission state.
createAckMessage :: UTCTime -> MessageState -> MessageState
createAckMessage now origMessageState =
  MessageState { messageContext = ackCtx
               , insertionTime = now
               , replyTimeout = 0
               , retransmitCount = 0 }
  where
    requestCtx = messageContext origMessageState
    ackMessage = (message requestCtx) { messageType = ACK
                                      , messageCode = CodeEmpty
                                      , messageToken = empty
                                      , messageOptions = []
                                      , messagePayload = Nothing }
    ackCtx = MessageContext { message = ackMessage
                            , srcEndpoint = dstEndpoint requestCtx
                            , dstEndpoint = srcEndpoint requestCtx }

    

-- | Base acknowledgement timeout, treated as seconds where it is converted for
-- 'addUTCTime'. 'ackLoop' uses it as the age after which a pending
-- confirmable message is answered with an empty ACK, and 'sendMessage' uses it
-- as the lower bound of the randomised initial reply timeout.
ackTimeout :: Double
ackTimeout = 2

-- | Acknowledgement loop: any unacked message older than 'ackTimeout' seconds
-- gets an empty ACK queued for sending. When nothing has expired the loop
-- sleeps for 100 ms before checking again. Runs forever.
ackLoop :: MessagingState -> IO ()
ackLoop state@(MessagingState _ store) = do
  takeStamp <- getCurrentTime
  let cutoff = addUTCTime (realToFrac (negate ackTimeout)) takeStamp
  expired <- atomically (takeMessagesOlderThan cutoff (unackedMessages store))
  case expired of
    [] -> threadDelay 100000
    _ -> do
      putStrLn "Timeout! Queueing ack messages"
      now <- getCurrentTime
      atomically (queueMessages (map (createAckMessage now) expired) (outgoingMessages store))
  ackLoop state

-- | Multiplier for the upper bound of the randomised initial reply timeout:
-- 'sendMessage' draws the timeout from
-- @(ackTimeout, ackTimeout * ackRandomFactor)@.
ackRandomFactor :: Double
ackRandomFactor = 1.5

-- | Highest retransmit count for which 'retransmitLoop' still re-queues a
-- message; messages whose adjusted count exceeds it are dropped.
maxRetransmitCount :: Integer
maxRetransmitCount = 4

-- | Prepare a message for its next retransmission: restamp it with @now@,
-- double its reply timeout and increment its retransmit count.
adjustRetransmissionState :: TimeStamp -> MessageState -> MessageState
adjustRetransmissionState now msgState =
  msgState { insertionTime = now
           , replyTimeout = doubledTimeout
           , retransmitCount = nextCount }
  where
    doubledTimeout = replyTimeout msgState * 2
    nextCount = retransmitCount msgState + 1

-- | Retransmission loop: unconfirmed messages whose reply timeout has expired
-- are rescheduled (see 'adjustRetransmissionState') and put back on the
-- outgoing queue, unless their retransmit count would exceed
-- 'maxRetransmitCount', in which case they are dropped. When nothing is due
-- the loop sleeps for 100 ms. Runs forever.
retransmitLoop :: MessagingState -> IO ()
retransmitLoop state@(MessagingState _ store) = do
  now <- getCurrentTime
  due <- atomically (takeMessagesToRetransmit now (unconfirmedMessages store))
  case due of
    [] -> threadDelay 100000
    _ -> do
      putStrLn ("Attempting to retransmit messages " ++ show due)
      let retries = [ adjusted
                    | adjusted <- map (adjustRetransmissionState now) due
                    , retransmitCount adjusted <= maxRetransmitCount ]
      atomically (queueMessages retries (outgoingMessages store))
  retransmitLoop state
  

-- | Run a worker loop against the state, returning normally if the loop is
-- ended by an 'AsyncException'. Other exceptions propagate.
runLoop :: MessagingState -> (MessagingState -> IO ()) -> IO ()
runLoop state fn = void (try (fn state) :: IO (Either AsyncException ()))

-- | Start the receive, send, ack and retransmit loops, each in its own
-- 'Async', and return their handles in that order.
startMessaging :: MessagingState -> IO [Async ()]
startMessaging state = mapM spawn [recvLoop, sendLoop, ackLoop, retransmitLoop]
  where
    spawn loop = async (runLoop state loop)

-- | Cancel every worker returned by 'startMessaging'. The state argument is
-- not used.
stopMessaging :: MessagingState -> [Async ()] -> IO ()
stopMessaging _ workers = mapM_ cancel workers

-- | Queue a message for transmission to the given endpoint.
--
-- The message is stamped with the current time and given an initial reply
-- timeout drawn at random from @(ackTimeout, ackTimeout * ackRandomFactor)@.
sendMessage :: Message -> Endpoint -> MessagingState -> IO ()
sendMessage msg destination (MessagingState transport store) = do
  source <- localEndpoint transport
  now <- getCurrentTime
  initialTimeout <- randomRIO (ackTimeout, ackTimeout * ackRandomFactor)
  let outgoing = MessageState
        { messageContext = MessageContext { message = msg
                                          , srcEndpoint = source
                                          , dstEndpoint = destination }
        , insertionTime = now
        , replyTimeout = initialTimeout
        , retransmitCount = 0 }
  atomically (queueMessage outgoing (outgoingMessages store))

-- | Block until an incoming message satisfies the predicate, then return its
-- context. A confirmable message is additionally recorded in the unacked
-- queue (in a separate transaction) so that it gets acknowledged later.
recvMessageMatching :: (MessageState -> Bool) -> MessagingState -> IO MessageContext
recvMessageMatching matchFilter (MessagingState _ store) = do
  received <- atomically (takeMessageRetryMatching matchFilter (incomingMessages store))
  let ctx = messageContext received
      needsAck = messageType (message ctx) == CON
  when needsAck $
    atomically (queueMessage received (unackedMessages store))
  return ctx


-- | True when both codes belong to the same class (request, response or
-- empty); the detail carried by a request or response code is ignored.
cmpMessageCode :: MessageCode -> MessageCode -> Bool
cmpMessageCode lhs rhs =
  case (lhs, rhs) of
    (CodeRequest _, CodeRequest _) -> True
    (CodeResponse _, CodeResponse _) -> True
    (CodeEmpty, CodeEmpty) -> True
    _ -> False

-- | Block until a message carrying any request code arrives and return its
-- context.
recvRequest :: MessagingState -> IO MessageContext
recvRequest = recvMessageMatching isRequest
  where
    -- GET only serves as a representative; 'cmpMessageCode' compares classes.
    isRequest msgState =
      cmpMessageCode (CodeRequest GET) (messageCode (message (messageContext msgState)))

-- | Copy of the response with its message type set to ACK; all other fields
-- are kept as they are.
createAckResponse :: Message -> Message
createAckResponse response = response { messageType = ACK }

-- | Copy of the message with its message id replaced; all other fields are
-- kept as they are.
setMessageId :: MessageId -> Message -> Message
setMessageId msgId response = response { messageId = msgId }

-- | Pick a message id at random. No check is made against ids that are
-- already in use.
allocateMessageId :: IO MessageId
allocateMessageId = randomIO
  

-- | Send a response to a previously received request.
--
-- If the request is still waiting for its acknowledgement, the response is
-- piggybacked on the ACK: its type is set to ACK and it carries the message id
-- of the request, so the peer can match it to the message being acknowledged.
-- Otherwise (not confirmable, or already acknowledged) the response is sent as
-- a separate message with a freshly allocated message id.
--
-- The pending request is looked up by message id and source endpoint rather
-- than by token alone, so a request from another endpoint that happens to use
-- the same token is never taken by mistake.
sendResponse :: MessageContext -> Message -> MessagingState -> IO ()
sendResponse requestCtx response state@(MessagingState _ store) = do
  let origin = srcEndpoint requestCtx
      reqId = messageId (message requestCtx)
      isThisRequest pending =
        let ctx = messageContext pending
         in srcEndpoint ctx == origin && messageId (message ctx) == reqId
  unackedMsg <- atomically (takeMessageMatching isThisRequest (unackedMessages store))

  outgoingMessage <- case unackedMsg of
    Nothing -> do
      msgId <- allocateMessageId
      return (setMessageId msgId response)
    -- The piggybacked ACK must echo the request's message id.
    Just _ -> return (setMessageId reqId (createAckResponse response))

  sendMessage outgoingMessage origin state

-- | Queue a request for the given endpoint. The message id of the supplied
-- message is ignored and replaced by a freshly allocated one.
sendRequest :: Message -> Endpoint -> MessagingState -> IO ()
sendRequest request msgdest state = do
  msgId <- allocateMessageId
  sendMessage (request { messageId = msgId }) msgdest state

-- | Block until a message with a response code and the same token as the
-- request arrives, and return its context. The endpoint argument is not
-- consulted when matching.
recvResponse :: Message -> Endpoint -> MessagingState -> IO MessageContext
recvResponse reqMessage _ = recvMessageMatching answersRequest
  where
    -- Created only serves as a representative response code for the class
    -- comparison done by 'cmpMessageCode'.
    answersRequest msgState =
      let candidate = message (messageContext msgState)
       in cmpMessageCode (CodeResponse Created) (messageCode candidate)
            && messageToken reqMessage == messageToken candidate