-----------------------------------------------------------------------------
-- |
-- Module      : Network.Connection
-- Copyright   : Adam Langley
-- License     : BSD3-style (see LICENSE)
--
-- Maintainer  : Adam Langley <agl@imperialviolet.org>
-- Stability   : experimental
--
-- Helpful functions to deal with stream-like connections
-----------------------------------------------------------------------------
module Network.Connection
  ( -- * Base connections
    BaseConnection(..)
  , baseConnectionFromSocket

    -- * Connection functions
  , Connection
  , new
  , newSTM
  , forkWriterThread
  , forkInConnection
  , close
  , write
  , writeAtLowWater
  , read
  , reada
  , pushBack
  ) where

import Prelude hiding (foldl, read, catch)

import Control.Concurrent
import Control.Concurrent.STM
import Control.Exception
import Control.Monad

import Data.Foldable (foldl, foldl')

import qualified Data.ByteString as B
import qualified Data.Sequence as Seq

import Network.Socket hiding (send, sendTo, recv, recvFrom)
import Network.Socket.ByteString

-- | The primitive operations of a stream-like connection, bundled as a
--   record of functions so that any transport providing them can be wrapped
--   in a 'Connection'.
data BaseConnection = BaseConnection
  { baseRead :: Int -> IO B.ByteString
    -- ^ Read at most the given number of bytes and return them. EOF is
    --   signalled by an exception; an empty ByteString is never a valid
    --   result.
  , baseWrite :: B.ByteString -> IO Int
    -- ^ Write the given ByteString and return how many bytes were written.
    --   This may be fewer than requested, but must always be at least one.
  , baseClose :: IO ()
    -- ^ Close the connection.
  }

-- | Wrap a socket as a BaseConnection. Reads go through 'recv', writes
--   through 'send', and closing uses 'sClose' on the same socket.
baseConnectionFromSocket :: Socket -> BaseConnection
baseConnectionFromSocket sock = BaseConnection
  { baseRead  = recv sock
  , baseWrite = send sock
  , baseClose = sClose sock
  }

-- | A Connection wraps the functions of a 'BaseConnection' with a number of
--   commonly needed behaviours.
--
--   Firstly, a write queue is introduced so that writes can be non-blocking.
--
--   Secondly, the Connection can manage a number of threads. Almost always
--   there will be a writer thread, which takes items from the write queue and
--   writes them to the BaseConnection. In addition, there can be zero or more
--   other threads managed by the Connection. If a managed thread dies, by
--   throwing an exception or otherwise, it will close the connection and all
--   other managed threads will be killed.
--
--   Finally, data can be pushed back into the Connection. This is useful in a
--   chain of reader functions where, for efficiency, you want to read large
--   blocks at a time, but the data is self-delimiting, so you would otherwise
--   end up having read too much. See 'pushBack' for details.
data Connection = Connection
  { connbase     :: BaseConnection
    -- ^ the underlying transport
  , connoutq     :: TVar (Seq.Seq B.ByteString)
    -- ^ pending writes; new data is added at the front and the writer
    --   thread takes it from the back
  , connthreads  :: TVar [ThreadId]
    -- ^ the threads managed by this connection
  , connpushback :: TVar (Seq.Seq B.ByteString)
    -- ^ data pushed back with 'pushBack', handed out by 'read' before the
    --   base connection is consulted
  , conndeath    :: IO ()
    -- ^ action run once, when the connection is shut down
  , conndead     :: TVar Bool
    -- ^ set to True once shutdown has begun
  }

-- | Replace the contents of a TVar with the result of applying a function to
--   them, as part of the current STM transaction.
updateTVar :: TVar a -> (a -> a) -> STM ()
updateTVar tvar f = readTVar tvar >>= writeTVar tvar . f

-- | Create a new Connection from a BaseConnection and start its writer
--   thread, so that queued writes are sent immediately.
new :: IO ()  -- ^ the action to run when the connection closes
    -> BaseConnection  -- ^ the socket-like object to make a connection from
    -> IO Connection
new deathaction baseconn =
  atomically (newSTM deathaction baseconn) >>= \conn ->
    forkWriterThread conn >> return conn

-- | Build most of a Connection purely in the STM monad. The write queue,
--   the push-back queue and the thread list start empty, and the dead flag
--   starts False. No thread is started here, so the result must be passed to
--   'forkWriterThread', or nothing will ever be written.
newSTM :: IO ()  -- ^ the action run when the connection closes
       -> BaseConnection  -- ^ the socket-like object to make a connection from
       -> STM Connection
newSTM deathaction baseconn = do
  outq     <- newTVar Seq.empty
  threads  <- newTVar []
  pushback <- newTVar Seq.empty
  dead     <- newTVar False
  return Connection { connbase     = baseconn
                    , connoutq     = outq
                    , connthreads  = threads
                    , connpushback = pushback
                    , conndeath    = deathaction
                    , conndead     = dead
                    }

-- | If you created the Connection in the STM monad using 'newSTM', you need to
--   call this on it to start the managed thread that drains the outgoing
--   queue into the 'BaseConnection'.
--
--   The writer is an ordinary managed thread, so this delegates to
--   'forkInConnection' rather than duplicating its fork-then-register
--   handshake.
forkWriterThread :: Connection  -- ^ the connection to fork the writer thread for
                 -> IO ()
forkWriterThread conn =
  forkInConnection conn $ seqToSocket (connoutq conn) (baseWrite $ connbase conn)

-- | Fork the given action, as if by 'forkIO', as a thread managed by the
--   connection. If the action completes or throws an exception, the
--   connection will be closed and all other managed threads will be killed.
forkInConnection :: Connection  -- ^ the connection to close on death
                 -> IO ()   -- ^ the action to run
                 -> IO ()
forkInConnection conn action = do
  ready <- atomically $ newTVar False
  -- The child blocks on 'ready' until its ThreadId has been recorded in
  -- 'connthreads'. Both happen in the transaction below.
  tid <- forkIO . waitForReadySignal ready $ connectionThreadWrapper conn action
  atomically $ do
    updateTVar (connthreads conn) (tid :)
    writeTVar ready True

-- | Block (by retrying the STM transaction) until the given TVar holds True,
--   then run the given action and return its result.
waitForReadySignal :: TVar Bool -> IO a -> IO a
waitForReadySignal sync action =
  atomically (readTVar sync >>= \ready -> unless ready retry) >> action

-- | Shut the connection down exactly once. The first caller atomically flips
--   the dead flag; any later caller finds it already set and does nothing.
--   The winner kills every registered managed thread except itself, closes
--   the 'BaseConnection' and then runs the connection's death action.
killThreads :: Connection -> IO ()
killThreads conn = do
  alreadyDead <- atomically $ do
    dead <- readTVar (conndead conn)
    unless dead $ writeTVar (conndead conn) True
    return dead
  unless alreadyDead $ do
    threads <- atomically (readTVar $ connthreads conn)
    me <- myThreadId
    mapM_ killThread $ filter (/= me) threads
    -- Run the death action even if closing the base connection throws.
    -- Otherwise it would be skipped for good, because the dead flag is
    -- already set and no later call will get this far.
    baseClose (connbase conn) `finally` conndeath conn

-- | Select the exceptions that are safe to catch. Not all are, because of the
--   way the GC works. If a thread is killed because it is waiting on a TVar
--   that is now garbage (e.g. our writer thread when the Connection goes out
--   of scope), all ForeignPtrs held by that thread become garbage /at the
--   same time/. Catching such an exception and trying to clean up could then
--   touch invalid ForeignPtrs.
--
--   Asynchronous exceptions, 'BlockedOnDeadMVar' and 'BlockedIndefinitely'
--   are rejected ('Nothing'); everything else is passed through ('Just').
safeException :: Exception -> Maybe Exception
safeException e =
  case e of
    AsyncException _    -> Nothing
    BlockedOnDeadMVar   -> Nothing
    BlockedIndefinitely -> Nothing
    _                   -> Just e

-- | Wrap a managed connection thread so that, when it finishes, it races to
--   set the dead flag; the winner closes the connection and kills the other
--   managed threads (see 'killThreads').
--
--   This happens both when the action returns normally and when it throws
--   an exception accepted by 'safeException'. In the latter case the
--   exception is re-thrown afterwards. Exceptions rejected by
--   'safeException' (such as the asynchronous exception sent by
--   'killThread') pass through without triggering a shutdown.
connectionThreadWrapper :: Connection -> IO a -> IO a
connectionThreadWrapper conn action = do
  result <- handleJust safeException (\e -> killThreads conn >> throwIO e) action
  -- A normal return is also a thread death. Close the connection, as
  -- documented for managed threads.
  killThreads conn
  return result

-- | Close a connection. This kills its managed threads (other than the
--   caller), closes the 'BaseConnection' and runs the death action. Only the
--   first shutdown of a connection has any effect.
close :: Connection -> IO ()
close conn = killThreads conn

-- | Enqueue a ByteString on the connection's write queue. This never blocks
--   (it never retries); the writer thread takes the data from the other end
--   of the queue.
write :: Connection -> B.ByteString -> STM ()
write conn bs = readTVar queue >>= writeTVar queue . (bs Seq.<|)
  where
    queue = connoutq conn

-- | Block (by retrying the STM transaction) until the write queue holds at
--   most the given number of bytes, then enqueue a new ByteString.
writeAtLowWater :: Int  -- ^ the max number of bytes in the queue before we enqueue anything
                -> Connection  -- ^ the connection to write to
                -> B.ByteString  -- ^ the data to enqueue
                -> STM ()
writeAtLowWater lw conn bs = do
  q <- readTVar $ connoutq conn
  -- foldl' keeps the running total evaluated. A lazy foldl would build a
  -- chain of (+) thunks as long as the queue.
  let queued = foldl' (\total chunk -> total + B.length chunk) 0 q
  if queued > lw
     then retry
     else writeTVar (connoutq conn) $ bs Seq.<| q

-- | Read some bytes from a connection. The size is only a hint and an upper
--   bound: the returned data may be shorter. Pushed-back data is returned
--   first (at most @sz@ bytes of its first chunk, with any excess left
--   pushed back). Only when nothing is pushed back is the base connection
--   read. A zero-length result is EOF.
read :: Connection -> Int -> IO B.ByteString
read conn sz = atomically takePushedBack >>= maybe (baseRead (connbase conn) sz) return
  where
    pbVar = connpushback conn
    takePushedBack = do
      pending <- readTVar pbVar
      case Seq.viewl pending of
        Seq.EmptyL -> return Nothing
        chunk Seq.:< others
          | B.length chunk <= sz -> do
              writeTVar pbVar others
              return (Just chunk)
          | otherwise -> do
              let (front, back) = B.splitAt sz chunk
              writeTVar pbVar (back Seq.<| others)
              return (Just front)

-- | Read exactly the given number of bytes, calling 'read' repeatedly until
--   enough data has arrived. A non-positive count returns an empty
--   ByteString without touching the connection. Fails with "EOF in reada"
--   if any read returns no data before the count is satisfied.
reada :: Connection -> Int -> IO B.ByteString
reada conn n = go n []
  where
    -- Chunks are collected in reverse and joined once at the end, so the
    -- total copying is linear rather than quadratic in the number of reads.
    go remaining acc
      | remaining <= 0 = return (B.concat (reverse acc))
      | otherwise = do
          bytes <- read conn remaining
          when (B.null bytes) $ fail "EOF in reada"
          go (remaining - B.length bytes) (bytes : acc)

-- | Unread some data; it will be returned by the next call to 'read'.
--   Pushing back an empty ByteString does nothing.
--
--   Data is pushed onto the /front/ of the push-back queue. Thus everything
--   that should be unread together must be pushed back in a single call, or
--   the order of future reads will be wrong.
--
--   This might look like a mistake, but consider two actions. The first
--   reads 20 bytes and pushes back the last 10 of them. The second then
--   reads 5 bytes and pushes back the last 4. If pushed-back data were
--   appended to the queue instead, the second action's 4 bytes would end up
--   after the remaining 5 bytes from the first action.
pushBack :: Connection -> B.ByteString -> STM ()
pushBack conn bs =
  unless (B.null bs) $
    readTVar slot >>= writeTVar slot . (bs Seq.<|)
  where
    slot = connpushback conn

-- | Forever take the element at the end of the given sequence (retrying
--   while it is empty) and write all of it with the given write function.
--   Each element is removed atomically before it is written. Any exception
--   thrown by the write function ends the loop.
seqToSocket :: TVar (Seq.Seq B.ByteString)  -- ^ data is removed from the end
            -> (B.ByteString -> IO Int)  -- ^ the write function
            -> IO ()
seqToSocket q sink = forever $ do
    oldest <- atomically takeOldest
    writea sink oldest
  where
    takeOldest = do
      pending <- readTVar q
      case Seq.viewr pending of
        Seq.EmptyR -> retry
        rest Seq.:> lastChunk -> do
          writeTVar q rest
          return lastChunk

-- | Write the whole of a ByteString using a write function that may accept
--   fewer bytes than offered. The function is called repeatedly on whatever
--   remains unwritten; an empty ByteString causes no calls at all.
writea :: (B.ByteString -> IO Int)  -- ^ the write function
       -> B.ByteString  -- ^ the data to write
       -> IO ()
writea sink bytes =
  unless (B.null bytes) $ do
    written <- sink bytes
    -- Once everything has been written, the remainder is empty and the
    -- recursive call stops immediately.
    writea sink (B.drop written bytes)