module Network.MiniHTTP.Connection
  ( Connection
  , new
  , forkThreads
  , close
  , connoutq
  , connsocket
  ) where

import Control.Concurrent
import Control.Concurrent.STM
import Control.Exception
import Control.Monad

import Data.Maybe (fromJust)
import qualified Data.ByteString as BS
import qualified Data.Sequence as Seq

import Network.Socket hiding (send, sendTo, recv, recvFrom)
import Network.Socket.ByteString

-- | State shared between a connection's reader thread, its writer thread
--   and the code that created it.
data Connection = Connection
  { connsocket :: Socket
    -- ^ the underlying socket
  , connoutq :: TVar (Seq.Seq BS.ByteString)
    -- ^ outgoing data; the writer thread removes chunks from the right-hand
    --   end (presumably producers add at the left end — confirm with callers)
  , connreaderthread :: TVar (Maybe ThreadId)
    -- ^ the reader thread's id; Nothing until 'forkThreads' records it
  , connwriterthread :: TVar (Maybe ThreadId)
    -- ^ the writer thread's id; Nothing until 'forkThreads' records it
  , conndeath :: IO ()
    -- ^ action run after the connection has failed and its socket has
    --   been closed
  , conndead :: TVar Bool
    -- ^ True until 'forkThreads' releases the threads, and True again once
    --   either thread has died
  }

-- Threading: each connection has two threads, one pumping data out and one
-- reading it in. The failure (by throwing an exception) of either is enough
-- to close the socket and kill the connection. Both could fail at the same
-- time, in which case they race to set conndead to True; only the winner
-- performs the cleanup.
-- However, we also need to record their ThreadIds in the Connection record
-- so that they can kill each other. It's possible that they could try to
-- kill the connection right away, before the creating thread has recorded
-- the correct ThreadIds in the Connection. Thus conndead starts out True,
-- and at startup both threads wait for the controlling thread to set it to
-- False before doing anything else.

-- | Allocate a fresh 'Connection' for a socket. No threads are started here.
--   The connection begins with an empty output queue and no recorded thread
--   ids. Its conndead flag is True, so threads started by 'forkThreads' wait
--   until that flag is cleared.
new :: Socket  -- ^ the socket to make a connection from
    -> IO ()  -- ^ the action run when the connection fails
    -> STM Connection
new socket deathaction = do
  deadFlag   <- newTVar True
  queue      <- newTVar Seq.empty
  readerSlot <- newTVar Nothing
  writerSlot <- newTVar Nothing
  return Connection { connsocket       = socket
                    , connoutq         = queue
                    , connreaderthread = readerSlot
                    , connwriterthread = writerSlot
                    , conndeath        = deathaction
                    , conndead         = deadFlag
                    }

-- | Start the reader and writer threads for a connection.
--
--   Both threads are forked immediately, but each blocks until both thread
--   ids have been stored in the 'Connection' and conndead has been cleared.
--   The writer drains 'connoutq' into the socket. The reader runs the
--   supplied action.
forkThreads :: Connection  -- ^ the connection to fork the threads for
            -> IO ()  -- ^ the action which reads from the socket
            -> IO ()
forkThreads conn readeraction = do
  -- When a thread dies it must kill its *partner*, so the reader is given
  -- the writer's id slot and the writer is given the reader's.
  readerId <- spawn connwriterthread readeraction
  writerId <- spawn connreaderthread (seqToSocket (connoutq conn) (connsocket conn))
  -- Publish the ids first, then release both threads in the same transaction.
  atomically $ do
    writeTVar (connreaderthread conn) (Just readerId)
    writeTVar (connwriterthread conn) (Just writerId)
    writeTVar (conndead conn) False
  where
    spawn partnerSlot body =
      forkIO (waitForReadySignal conn (connectionThreadWrapper conn partnerSlot body))

-- | Block until the connection's conndead flag is False (the "ready" signal
--   given by 'forkThreads'), then run the given action and return its result.
waitForReadySignal :: Connection -> IO a -> IO a
waitForReadySignal conn action =
  atomically blockWhileDead >> action
  where
    -- Retrying makes STM re-run this transaction only after conndead changes.
    blockWhileDead = readTVar (conndead conn) >>= \dead -> when dead retry

-- | Wrap a connection thread so that, when the thread finishes (normally or
--   by an exception), it races to set the dead flag. The winner of the race
--   kills the partner thread, closes the socket and runs the connection's
--   death action. The loser does nothing.
connectionThreadWrapper :: Connection
                        -> (Connection -> TVar (Maybe ThreadId))  -- ^ selects the /partner/ thread's id slot
                        -> IO a  -- ^ the thread body
                        -> IO a
connectionThreadWrapper conn otherthread action =
  action `finally` cleanup
  where
    -- Atomically claim the right to tear the connection down. Yields True
    -- only for the thread that flips conndead from False to True.
    claimDeath = atomically $ do
      dead <- readTVar (conndead conn)
      unless dead $ writeTVar (conndead conn) True
      return (not dead)

    cleanup = do
      won <- claimDeath
      when won $ do
        partner <- atomically (readTVar (otherthread conn))
        -- In this module the id slots are filled before conndead becomes
        -- False, so Nothing is not expected. Tolerate it rather than
        -- crashing with a partial fromJust.
        maybe (return ()) killThread partner
        -- Run the death action even if closing the socket throws (e.g. when
        -- the socket was already closed via 'close').
        sClose (connsocket conn) `finally` conndeath conn

-- | Close a connection by closing its underlying socket.
--
--   This does not itself mark the connection dead, kill its threads or run
--   its death action. Presumably those follow once a thread fails on the
--   closed socket (TODO confirm).
--   NOTE(review): the failing thread's cleanup calls 'sClose' on the same
--   socket a second time. With some versions of the network package that
--   raises an error.
close :: Connection -> IO ()
close conn = sClose (connsocket conn)

-- | Repeatedly dequeue the chunk at the right-hand end of the queue and write
--   it to the socket.
--
--   While the queue is empty the thread blocks (via STM 'retry'). This never
--   returns normally: it stops only when an exception escapes, e.g. a failed
--   write.
seqToSocket :: TVar (Seq.Seq BS.ByteString)  -- ^ data is removed from the end
            -> Socket  -- ^ the socket to write to
            -> IO ()
seqToSocket q sock = forever $ do
    chunk <- atomically takeLast
    writea sock chunk
  where
    -- Pop the rightmost element, retrying until one is available.
    takeLast = do
      queue <- readTVar q
      case Seq.viewr queue of
        Seq.EmptyR -> retry
        remaining Seq.:> final -> do
          writeTVar q remaining
          return final

-- | Write the whole ByteString to a socket.
--
--   'send' is called repeatedly until every byte has been accepted. An empty
--   ByteString performs no send at all. Errors raised by 'send' propagate to
--   the caller.
writea :: Socket -> BS.ByteString -> IO ()
writea sock = go
  where
    go chunk = unless (BS.null chunk) $ do
      sent <- send sock chunk
      -- A short write leaves a suffix that still has to go out.
      when (sent /= BS.length chunk) $ go (BS.drop sent chunk)