{-# LANGUAGE BangPatterns             #-}
{-# LANGUAGE CPP                      #-}
{-# LANGUAGE DeriveDataTypeable       #-}
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE OverloadedStrings        #-}
{-# LANGUAGE PackageImports           #-}
{-# LANGUAGE RankNTypes               #-}
{-# LANGUAGE ScopedTypeVariables      #-}

module Snap.Internal.Http.Server.SimpleBackend
  ( Backend
  , BackendTerminatedException(..)
  , Connection
  , TimeoutException(..)
  , name
  , debug
  , bindIt
  , new
  , stop
  , withConnection
  , sendFile
  , tickleTimeout
  , getReadEnd
  , getWriteEnd
  , getRemoteAddr
  , getRemotePort
  , getLocalAddr
  , getLocalPort
  ) where

------------------------------------------------------------------------------
import "monads-fd" Control.Monad.Trans

import           Control.Concurrent
import           Control.Exception
import           Control.Monad
import           Data.ByteString (ByteString)
import           Data.ByteString.Internal (c2w, w2c)
import qualified Data.ByteString as B
import           Data.Iteratee.WrappedByteString
import           Data.Typeable
import           Data.Word
import           Foreign hiding (new)
import           GHC.Conc (labelThread, forkOnIO)
import           Network.Socket
import qualified Network.Socket.ByteString as SB
import           Prelude hiding (catch)
------------------------------------------------------------------------------
import           Data.Concurrent.HashMap (hashString)
import           Snap.Internal.Debug
import           Snap.Internal.Http.Server.Date
import qualified Snap.Internal.Http.Server.TimeoutTable as TT
import           Snap.Internal.Http.Server.TimeoutTable (TimeoutTable)
import           Snap.Iteratee hiding (foldl')

#if defined(HAS_SENDFILE)
import qualified System.SendFile as SF
import           System.Posix.IO
import           System.Posix.Types (Fd(..))
#endif


-- | Raised by 'withConnection' once the accept thread has shut down and
-- will deliver no further connections.
data BackendTerminatedException = BackendTerminatedException
   deriving (Typeable)

instance Show BackendTerminatedException where
    show _ = "Backend terminated"

instance Exception BackendTerminatedException


------------------------------------------------------------------------------
-- | An element of the connection queue: an accepted (socket, remote address)
-- pair, or 'Nothing' as a marker that the accept thread has shut down.
type QueueElem = Maybe (Socket,SockAddr)

-- | Handle for a running backend, created by 'new'.
data Backend = Backend
    { _acceptSocket    :: !Socket            -- ^ listening socket (from 'bindIt')
    , _acceptThread    :: !ThreadId          -- ^ thread running 'acceptThread'
    , _timeoutTable    :: TimeoutTable       -- ^ per-connection activity timestamps
    , _timeoutThread   :: !(MVar ThreadId)   -- ^ filled by 'new' after forking 'timeoutThread'
    , _connectionQueue :: !(Chan QueueElem)  -- ^ accepted connections, consumed by 'withConnection'
    }

-- | A single accepted client connection, as handed to the handler passed to
-- 'withConnection'.
data Connection = Connection
    { _backend     :: Backend           -- ^ owning backend
    , _socket      :: Socket            -- ^ the accepted socket
    , _remoteAddr  :: ByteString        -- ^ peer address text (from 'inet_ntoa')
    , _remotePort  :: Int               -- ^ peer port
    , _localAddr   :: ByteString        -- ^ local address text (from 'inet_ntoa')
    , _localPort   :: Int               -- ^ local port
    , _connTid     :: MVar ThreadId     -- ^ handler thread id; filled after forking
    , _threadHash  :: MVar Word         -- ^ hash of @show@ of the handler thread id;
                                        --   passed to the timeout table with the id
    }


-- | Short identifier of this backend implementation.
name :: ByteString
name = "simple"
{-# INLINE name #-}


-- | Send @sz@ bytes of the file @fp@, starting at byte offset @start@,
-- over the connection's socket.
--
-- When built with @HAS_SENDFILE@ this goes through "System.SendFile",
-- looping until the whole range has been written and tickling the
-- connection's timeout after every chunk. Otherwise the file range is
-- enumerated into the connection's write end.
sendFile :: Connection -> FilePath -> Int64 -> Int64 -> IO ()
#if defined(HAS_SENDFILE)
sendFile c fp start sz =
    bracket (openFd fp ReadOnly Nothing defaultFileFlags)
            closeFd
            (go start sz)
  where
    go off bytes fd
      | bytes <= 0 = return ()
      | otherwise  = do
            sent <- SF.sendFile sfd fd off bytes
            -- NOTE(review): a 0 return presumably means the file ended
            -- before the requested range did (e.g. it was truncated after
            -- its size was taken). Retrying with the same arguments would
            -- spin forever, so fail instead.
            when (sent <= 0) $ ioError $ userError $
                "sendFile: " ++ fp ++ ": could not send remaining "
                  ++ show bytes ++ " bytes"
            -- Data is still flowing to the client, so keep the connection
            -- alive.
            tickleTimeout c
            go (off + sent) (bytes - sent) fd

    sfd = Fd . fdSocket $ _socket c
#else
sendFile c fp start sz = do
    -- no need to count bytes
    enumFilePartial fp (start,start+sz) (getWriteEnd c) >>= run
    return ()
#endif


-- | Create an IPv4 TCP socket bound to the given address and port, with
-- @ReuseAddr@ set, and put it into listening mode with a backlog of 150.
--
-- The address is resolved before the socket is created. If any setup step
-- fails, the socket is closed so its descriptor does not leak, and the
-- exception propagates.
bindIt :: ByteString         -- ^ bind address, or \"*\" for all
       -> Int                -- ^ port to bind to
       -> IO Socket
bindIt bindAddress bindPort = do
    addr <- getHostAddr bindPort bindAddress
    sock <- socket AF_INET Stream 0
    (do setSocketOption sock ReuseAddr 1
        bindSocket sock addr
        listen sock 150)
      `onException` sClose sock
    return sock


-- | Accept connections on @sock@ forever, pushing each accepted
-- (socket, remote address) pair onto @connq@.
--
-- However the loop ends (for example when 'stop' kills this thread), the
-- listening socket is closed. 'Nothing' is then written to the queue
-- 'terminationMarkers' times, so that threads blocked in 'withConnection'
-- wake up and throw 'BackendTerminatedException'. The markers are written
-- even if closing the socket fails.
acceptThread :: Socket -> Chan QueueElem -> IO ()
acceptThread sock connq = loop `finally` cleanup
  where
    loop = do
        debug $ "acceptThread: calling accept()"
        s@(_,addr) <- accept sock
        debug $ "acceptThread: accepted connection from remote: " ++ show addr
        debug $ "acceptThread: queueing"
        writeChan connq $ Just s
        loop

    cleanup = block $ do
        debug $ "acceptThread: cleanup, closing socket and notifying "
                  ++ "chan listeners"
        sClose sock
          `finally` replicateM_ terminationMarkers (writeChan connq Nothing)

    -- NOTE(review): a reader beyond this count never receives a
    -- termination marker; confirm it covers every 'withConnection' caller.
    terminationMarkers :: Int
    terminationMarkers = 10


-- | Start a backend on an already-listening socket.
--
-- Forks 'acceptThread' on capability @cpu@ and a 'timeoutThread'. The
-- timeout thread's id is stored in the returned handle so that 'stop' can
-- kill it.
new :: Socket   -- ^ value you got from bindIt
    -> Int      -- ^ capability the accept thread is pinned to (via 'forkOnIO')
    -> IO Backend
new sock cpu = do
    debug $ "Backend.new: listening"

    table     <- TT.new
    reaperVar <- newEmptyMVar
    queue     <- newChan
    acceptTid <- forkOnIO cpu (acceptThread sock queue)

    let backend = Backend { _acceptSocket    = sock
                          , _acceptThread    = acceptTid
                          , _timeoutTable    = table
                          , _timeoutThread   = reaperVar
                          , _connectionQueue = queue
                          }

    forkIO (timeoutThread backend) >>= putMVar reaperVar
    return backend


-- | Body of the timeout thread forked by 'new'.
--
-- Repeatedly waits for activity on the backend's timeout table, then asks
-- the table to kill connections last tickled more than 30 seconds ago. Any
-- exception that reaches this thread (e.g. the 'killThread' from 'stop')
-- ends the loop, and 'TT.killAll' is run on the table.
timeoutThread :: Backend -> IO ()
timeoutThread backend =
    handle onShutdown $ forever $ do
        debug "timeoutThread: waiting for activity on thread table"
        TT.waitForActivity table
        debug "timeoutThread: woke up, killing old connections"
        now <- getCurrentDateTime
        TT.killOlderThan (now - maxIdle) table
  where
    table = _timeoutTable backend

    -- timeout = 30 seconds
    maxIdle = 30

    onShutdown :: SomeException -> IO ()
    onShutdown _ = do
        debug "Backend.timeoutThread: shutdown, killing all connections"
        TT.killAll table


-- | Shut the backend down.
--
-- Kills the accept thread first, which makes it close the listening socket
-- and post termination markers. It then kills the timeout thread, whose
-- handler calls 'TT.killAll' on the timeout table.
stop :: Backend -> IO ()
stop backend = do
    debug $ "Backend.stop: killing accept thread"
    killThread $ _acceptThread backend

    debug $ "Backend.stop: killing timeout thread"

    -- kill timeout thread; timeout thread handler will stop all of the running
    -- connection threads
    reaper <- readMVar $ _timeoutThread backend
    killThread reaper

    debug $ "Backend.stop: exiting.."


-- | Raised by 'withConnection' when a socket endpoint is not an IPv4
-- ('SockAddrInet') address. The payload is the 'show'n address.
data AddressNotSupportedException = AddressNotSupportedException String
   deriving (Typeable)

instance Show AddressNotSupportedException where
    showsPrec _ (AddressNotSupportedException addr) =
        showString "Address not supported: " . showString addr

instance Exception AddressNotSupportedException


-- | Take the next accepted connection off the backend's queue and run
-- @proc@ on it in a new thread pinned to capability @cpu@.
--
-- Throws 'BackendTerminatedException' when the queue yields 'Nothing',
-- i.e. after the accept thread has shut down. Throws
-- 'AddressNotSupportedException' if either endpoint of the socket is not an
-- IPv4 address. In that case, and if 'getSocketName' fails, the accepted
-- socket is closed before the exception propagates.
--
-- However @proc@ exits, the handler thread's cleanup removes the thread
-- from the timeout table, then shuts down and closes the socket.
withConnection :: Backend -> Int -> (Connection -> IO ()) -> IO ()
withConnection backend cpu proc = do
    debug $ "Backend.withConnection: reading from chan"

    qelem <- readChan $ _connectionQueue backend
    case qelem of
      Nothing -> do
          debug $ "Backend.withConnection: channel terminated, throwing "
                    ++ "BackendTerminatedException"
          throwIO BackendTerminatedException

      Just (sock, addr) -> serve sock addr

  where
    -- Build the 'Connection' for an accepted socket and fork its handler.
    serve sock addr = do
        let fd = fdSocket sock

        debug $ "Backend.withConnection: dequeued connection from remote: "
                  ++ show addr

        -- If either address lookup fails, close the accepted socket rather
        -- than leaking its descriptor; the original exception propagates.
        ((port, host), (lport, lhost)) <-
            (do remote <- inetEndpoint addr
                local  <- getSocketName sock >>= inetEndpoint
                return (remote, local))
              `onException` eatException (sClose sock)

        tmvar   <- newEmptyMVar
        thrhash <- newEmptyMVar

        let c = Connection backend sock host port lhost lport tmvar thrhash

        tid <- forkOnIO cpu $ do
            labelMe $ "connHndl " ++ show fd
            bracket (return c)
                    (\_ -> block $ do
                         debug "thread killed, closing socket"
                         -- these block until the parent has filled them in
                         thr   <- readMVar tmvar
                         thash <- readMVar thrhash

                         -- remove thread from timeout table
                         TT.delete thash thr $ _timeoutTable backend

                         eatException $ shutdown sock ShutdownBoth
                         eatException $ sClose sock
                    )
                    proc

        putMVar tmvar tid
        putMVar thrhash $ hashString $ show tid
        -- NOTE(review): if the handler finishes before this tickle, its
        -- cleanup has already deleted the entry, and this re-inserts one
        -- for a thread that has exited.
        tickleTimeout c

    -- Split an IPv4 socket address into (port, textual host) using
    -- 'inet_ntoa'; any other address family is rejected.
    inetEndpoint :: SockAddr -> IO (Int, ByteString)
    inetEndpoint (SockAddrInet p h) = do
        h' <- inet_ntoa h
        return (fromIntegral p, B.pack $ map c2w h')
    inetEndpoint x = throwIO $ AddressNotSupportedException $ show x


-- | Attach a debugging label to the calling thread.
labelMe :: String -> IO ()
labelMe s = myThreadId >>= flip labelThread s


-- | Run an action for its effects only, discarding its result and silently
-- ignoring any exception it throws. Intended for best-effort cleanup such
-- as closing a socket that may already be dead.
eatException :: IO a -> IO ()
eatException act = handle ignore (act >> return ())
  where
    ignore :: SomeException -> IO ()
    ignore _ = return ()

-- | Enumerator that feeds the bytes read from the connection's socket into
-- an iteratee.
getReadEnd :: Connection -> Enumerator IO a
getReadEnd conn = enumerate conn


-- | Iteratee that writes every chunk it receives to the connection's
-- socket.
getWriteEnd :: Connection -> Iteratee IO ()
getWriteEnd conn = writeOut conn


-- | Textual address of the remote peer, as rendered by 'inet_ntoa'.
getRemoteAddr :: Connection -> ByteString
getRemoteAddr Connection { _remoteAddr = addr } = addr

-- | Port number of the remote peer.
getRemotePort :: Connection -> Int
getRemotePort Connection { _remotePort = port } = port

-- | Textual local address of the socket, as rendered by 'inet_ntoa'.
getLocalAddr :: Connection -> ByteString
getLocalAddr Connection { _localAddr = addr } = addr

-- | Local port number of the socket.
getLocalPort :: Connection -> Int
getLocalPort Connection { _localPort = port } = port

------------------------------------------------------------------------------
-- | Build an IPv4 socket address for port @p@ on host @s@. Here @s@ is
-- either @\"*\"@, meaning 'iNADDR_ANY', or an address string accepted by
-- 'inet_addr'. The bytes of @s@ are mapped to 'Char's one-to-one via 'w2c'.
getHostAddr :: Int
            -> ByteString
            -> IO SockAddr
getHostAddr p s = fmap (SockAddrInet (fromIntegral p)) resolveHost
  where
    resolveHost
      | s == "*"  = return iNADDR_ANY
      | otherwise = inet_addr $ map w2c $ B.unpack s



-- | Exception type exported for connection timeouts.
-- NOTE(review): not raised by any code visible in this module; presumably
-- thrown by the timeout-table machinery. Confirm.
data TimeoutException = TimeoutException
   deriving (Typeable)

instance Show TimeoutException where
    showsPrec _ TimeoutException = showString "timeout"

instance Exception TimeoutException


-- | Record "now" as the connection's most recent activity in the backend's
-- timeout table, keyed by the handler thread's id and hash.
--
-- Blocks until 'withConnection' has filled in the handler thread's id and
-- hash.
tickleTimeout :: Connection -> IO ()
tickleTimeout conn = do
    debug "Backend.tickleTimeout"
    stamp  <- getCurrentDateTime
    tid    <- readMVar (_connTid conn)
    hashed <- readMVar (_threadHash conn)
    TT.insert hashed tid stamp (_timeoutTable (_backend conn))


-- | Remove the connection's handler thread from the backend's timeout
-- table, so it is no longer subject to idle killing.
_cancelTimeout :: Connection -> IO ()
_cancelTimeout conn = do
    debug "Backend.cancelTimeout"

    tid    <- readMVar (_connTid conn)
    hashed <- readMVar (_threadHash conn)

    TT.delete hashed tid (_timeoutTable (_backend conn))


-- | Receive up to @n@ bytes from the connection's socket.
--
-- Despite the name, no timeout is applied here and the connection's
-- timeout is not tickled.
timeoutRecv :: Connection -> Int -> IO ByteString
timeoutRecv conn n = SB.recv (_socket conn) n


-- | Write the whole of @s@ to the connection's socket, then refresh the
-- connection's timeout via 'tickleTimeout'.
timeoutSend :: Connection -> ByteString -> IO ()
timeoutSend conn s = do
    debug $ "Backend.timeoutSend: entered w/ " ++ show (B.length s) ++ " bytes"
    SB.sendAll (_socket conn) s
    debug $ "Backend.timeoutSend: sent all"
    tickleTimeout conn


-- | Maximum number of bytes requested from the socket per read (8 KiB).
bLOCKSIZE :: Int
bLOCKSIZE = 8 * 1024


-- | Enumerator that reads the connection's socket in blocks of at most
-- 'bLOCKSIZE' bytes and feeds them to an iteratee.
--
-- An empty read is passed on as @EOF Nothing@. Feeding stops when the
-- iteratee finishes ('Done') or signals an error (which is rethrown via
-- 'throwErr'). Otherwise another block is read.
enumerate :: (MonadIO m) => Connection -> Enumerator m a
enumerate conn = go
  where
    go iter = do
        debug $ "Backend.enumerate: reading from socket"
        bytes <- liftIO $ timeoutRecv conn bLOCKSIZE
        debug $ "Backend.enumerate: got " ++ Prelude.show (B.length bytes)
                ++ " bytes from read end"
        let stream | B.null bytes = EOF Nothing
                   | otherwise    = Chunk $ WrapBS bytes
        step <- runIter iter stream
        case step of
          done@(Done _ _)   -> return $ liftI done
          Cont next Nothing -> go next
          Cont _ (Just err) -> return $ throwErr err


-- | Iteratee that writes each chunk it receives to the connection's socket
-- via 'timeoutSend'.
--
-- On EOF it finishes at once, passing the EOF along. If a write fails with
-- an 'IOException', it finishes with an EOF carrying the error text.
--
-- Only 'IOException' is caught. Catching 'SomeException' would also trap
-- asynchronous exceptions, such as 'killThread' used to time out or shut
-- down a connection thread (presumably what the timeout table does).
-- That would turn a kill into an ordinary write error, and the thread
-- would keep running.
writeOut :: (MonadIO m) => Connection -> Iteratee m ()
writeOut conn = IterateeG out
  where
    out c@(EOF _)   = return $ Done () c

    out (Chunk s) = do
        let x = unWrap s

        ee <- liftIO $ ((try $ timeoutSend conn x)
                            :: IO (Either IOException ()))

        case ee of
          (Left e)  -> return $ Done () (EOF $ Just $ Err $ show e)
          (Right _) -> return $ Cont (writeOut conn) Nothing