{-# LANGUAGE BangPatterns             #-}
{-# LANGUAGE CPP                      #-}
{-# LANGUAGE DeriveDataTypeable       #-}
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE OverloadedStrings        #-}
{-# LANGUAGE PackageImports           #-}
{-# LANGUAGE Rank2Types               #-}
{-# LANGUAGE ScopedTypeVariables      #-}

module Snap.Internal.Http.Server.SimpleBackend
  ( simpleEventLoop
  ) where

import           Control.Monad.Trans

import           Control.Concurrent hiding (yield)
import           Control.Exception
import           Control.Monad
import           Data.ByteString (ByteString)
import qualified Data.ByteString as S
import           Data.ByteString.Internal (c2w)
import           Data.Maybe
import           Foreign hiding (new)
import           Foreign.C
import           GHC.Conc (labelThread, forkOnIO)
import           Network.Socket
import           Prelude hiding (catch)
import           Snap.Internal.Debug
import           Snap.Internal.Http.Server.Date
import qualified Snap.Internal.Http.Server.TimeoutManager as TM
import           Snap.Internal.Http.Server.TimeoutManager (TimeoutManager)
import           Snap.Internal.Http.Server.Backend
import           Snap.Internal.Http.Server.Address
import qualified Snap.Internal.Http.Server.ListenHelpers as Listen
import           Snap.Iteratee hiding (map)

#if defined(HAS_SENDFILE)
import qualified System.SendFile as SF
import           System.Posix.IO
import           System.Posix.Types (Fd(..))
#endif

-- | Per-cpu state of the simple event loop. For each cpu, we store:
--
--    * the cpu number its accept threads were forked on,
--
--    * a list of accept threads, one per listening socket,
--
--    * the 'TimeoutManager' handed to every session accepted on this cpu,
--
--    * an 'MVar' that is filled when one of the accept threads exits.
data EventLoopCpu = EventLoopCpu
    { _boundCpu        :: Int
    , _acceptThreads   :: [ThreadId]
    , _timeoutManager  :: TimeoutManager
    , _exitMVar        :: !(MVar ())
    }

-- | Run the simple event loop.
--
-- One loop is started per cpu @0 .. cap - 1@. Each loop runs one accept
-- thread per listening socket. The call then blocks until, on every loop, an
-- accept thread has exited. On the way out, normally or via an exception,
-- every loop is stopped and every listening socket is closed.
simpleEventLoop :: EventLoop
simpleEventLoop defaultTimeout sockets cap elog handler = do
    -- one loop per cpu; each one is given its cpu number
    loops <- Prelude.mapM (newLoop defaultTimeout sockets handler elog)
                          [0 .. cap - 1]

    debug "simpleEventLoop: waiting for mvars"

    -- wait for all threads to exit
    Prelude.mapM_ (takeMVar . _exitMVar) loops `finally` do
        debug "simpleEventLoop: killing all threads"
        Prelude.mapM_ stopLoop loops
        Prelude.mapM_ Listen.closeSocket sockets

-- | Build the event loop for one cpu.
--
-- The loop gets a fresh 'TimeoutManager' and an empty exit 'MVar'. It also
-- gets one accept thread per listening socket, each forked on @cpu@; all of
-- them share the manager and the exit 'MVar'.
newLoop :: Int
        -> [ListenSocket]
        -> SessionHandler
        -> (S.ByteString -> IO ())
        -> Int
        -> IO EventLoopCpu
newLoop defaultTimeout sockets handler elog cpu = do
    timeouts <- TM.initialize defaultTimeout getCurrentDateTime
    done     <- newEmptyMVar
    let spawn lsock = forkOnIO cpu $
            acceptThread defaultTimeout handler timeouts elog cpu lsock done
    threads  <- Prelude.mapM spawn sockets
    return $! EventLoopCpu cpu threads timeouts done

-- | Shut down one cpu's loop.
--
-- Its 'TimeoutManager' is stopped first, then each of its accept threads is
-- killed. Both steps run inside 'block', i.e. with asynchronous exceptions
-- blocked.
stopLoop :: EventLoopCpu -> IO ()
stopLoop loop = block $ do
    let EventLoopCpu { _timeoutManager = tmgr, _acceptThreads = threads } = loop
    TM.stop tmgr
    Prelude.mapM_ killThread threads

-- | Accept connections on one listening socket until killed.
--
-- Every accepted connection is served by its own thread, forked on @cpu@,
-- which runs 'runSession'. When this accept thread exits, for whatever
-- reason, it fills @exitMVar@ (if it is still empty) so the event loop can
-- notice.
acceptThread :: Int
             -> SessionHandler
             -> TimeoutManager
             -> (S.ByteString -> IO ())
             -> Int
             -> ListenSocket
             -> MVar ()
             -> IO ()
acceptThread defaultTimeout handler tmgr elog cpu sock exitMVar =
    loop `finally` (tryPutMVar exitMVar () >> return ())
  where
    -- Accept one connection, then go round again. Asynchronous exceptions
    -- (e.g. 'killThread' from 'stopLoop') are rethrown by 'acceptHandler'
    -- and end the loop; any other error is logged and accepting continues.
    loop = do
        acceptAndFork `catches` acceptHandler
        loop

    acceptAndFork = do
        debug $ "acceptThread: calling accept() on socket " ++ show sock
        (s,addr) <- accept $ Listen.listenSocket sock
        setSocketOption s NoDelay 1
        debug $ "acceptThread: accepted connection from remote: " ++ show addr
        _ <- forkOnIO cpu (go s addr `catches` cleanup)
        return ()

    go = runSession defaultTimeout handler tmgr sock

    acceptHandler =
        [ Handler $ \(e :: AsyncException) -> throwIO e
        , Handler $ \(e :: SomeException) -> do
              elog $ S.concat [ "SimpleBackend.acceptThread: accept threw: "
                              , S.pack . map c2w $ show e ]
              -- most likely we're out of file descriptors, and it isn't
              -- likely to get better immediately; sleep for 10ms to avoid
              -- spamming the error log.
              threadDelay $ 10000
        ]

    -- Exceptions escaping a session thread: asynchronous ones (e.g. the
    -- thread being killed) are dropped silently, anything else is logged.
    cleanup =
        [ Handler $ \(_ :: AsyncException) -> return ()
        , Handler $ \(e :: SomeException) -> elog
                  $ S.concat [ "SimpleBackend.acceptThread: "
                             , S.pack . map c2w $ show e]
        ]

-- | Serve one accepted connection on the current thread.
--
-- The thread first labels itself and resolves the local and remote
-- addresses. It then registers @killThread@ on itself with the timeout
-- manager and opens a network session over the socket. The session handler
-- receives the session info, the read end ('enumerate'), the write end
-- ('writeOut'), a sendfile function and the timeout modifier.
--
-- The release action of 'bracket' cancels the timeout, ends the session, and
-- shuts down and closes the socket; each step ignores its own exceptions.
runSession :: Int
           -> SessionHandler
           -> TimeoutManager
           -> ListenSocket
           -> Socket
           -> SockAddr -> IO ()
runSession defaultTimeout handler tmgr lsock sock addr = do
    let fd = fdSocket sock
    curId <- myThreadId

    debug $ "Backend.withConnection: running session: " ++ show addr
    labelThread curId $ "connHndl " ++ show fd

    (rport,rhost) <- getAddress addr
    (lport,lhost) <- getSocketName sock >>= getAddress

    let sinfo = SessionInfo lhost lport rhost rport $ Listen.isSecure lsock

    -- the registered action kills this session thread; presumably the
    -- manager runs it when the connection times out
    timeoutHandle <- TM.register (killThread curId) tmgr
    let modifyTimeout = TM.modify timeoutHandle
    -- NOTE(review): assumes TM.modify applies its function to the current
    -- timeout, so this raises the timeout to at least the given value
    let tickleTimeout = modifyTimeout . max

    bracket (Listen.createSession lsock 8192 fd
              (threadWaitRead $ fromIntegral fd))
            (\session -> block $ do
                 debug "thread killed, closing socket"

                 -- cancel thread timeout
                 TM.cancel timeoutHandle

                 eatException $ Listen.endSession lsock session
                 eatException $ shutdown sock ShutdownBoth
                 eatException $ sClose sock
            )
            (\s -> let writeEnd = writeOut lsock s sock
                                           (tickleTimeout defaultTimeout)
                   in handler sinfo
                              (enumerate lsock s sock)
                              writeEnd
                              (sendFile lsock (tickleTimeout defaultTimeout)
                                        fd writeEnd)
                              modifyTimeout
            )

-- | Run an action for its effects only. Its result is discarded, and any
-- exception it throws (synchronous or asynchronous) is swallowed.
eatException :: IO a -> IO ()
eatException act = do
    outcome <- try act
    case outcome of
      Left (_ :: SomeException) -> return ()
      Right _                   -> return ()

-- | Send @sz@ bytes of the file @fp@, starting at byte offset @start@.
--
-- When built with @HAS_SENDFILE@ and serving a plain HTTP listener, the
-- bytes go through the sendfile system call straight onto socket descriptor
-- @sock@. @tickle@ is called after every partial write, and the file
-- descriptor is closed afterwards. If the file ends before @sz@ bytes have
-- been sent, an 'IOError' is raised instead of spinning forever.
--
-- In every other case (HTTPS, or no sendfile support) the byte range is
-- streamed through the @writeEnd@ iteratee.
sendFile :: ListenSocket
         -> IO ()
         -> CInt
         -> Iteratee ByteString IO ()
         -> FilePath
         -> Int64
         -> Int64
         -> IO ()
#if defined(HAS_SENDFILE)
sendFile lsock tickle sock writeEnd fp start sz =
    case lsock of
        ListenHttp _ -> bracket (openFd fp ReadOnly Nothing defaultFileFlags)
                                closeFd
                                (go start sz)
        _            -> do
                   step <- runIteratee writeEnd
                   run_ $ enumFilePartial fp (start,start+sz) step
  where
    -- keep calling sendfile until every requested byte has been written
    go off bytes fd
      | bytes == 0 = return ()
      | otherwise  = do
            sent <- SF.sendFile (threadWaitWrite $ fromIntegral sock)
                                sfd fd off bytes
            if sent == 0
              -- sendfile makes no progress at end of input: the file is
              -- shorter than requested, so retrying would loop forever
              then ioError $ userError $
                       "SimpleBackend.sendFile: unexpected end of file " ++ fp
              else if sent < bytes
                then tickle >> go (off+sent) (bytes-sent) fd
                else return ()

    sfd = Fd sock
#else
sendFile _ _ _ writeEnd fp start sz = do
    -- no need to count bytes
    step <- runIteratee writeEnd
    run_ $ enumFilePartial fp (start,start+sz) step
    return ()
#endif

-- | Enumerator for the read end of a connection.
--
-- Each chunk read from the session is fed to the iteratee for as long as it
-- asks for more input. A read that returns nothing (the peer closed the
-- connection), or an empty chunk, is passed on as EOF. The @PORTABLE@ build
-- of 'Listen.recv' also takes the 'Socket'.
enumerate :: (MonadIO m)
          => ListenSocket
          -> NetworkSession
          -> Socket
          -> Enumerator ByteString m a
enumerate port session sock = loop
  where
    dbg s = debug $ "SimpleBackend.enumerate(" ++ show (_socket session)
            ++ "): " ++ s

    loop (Continue k) = do
        dbg "reading from socket"
        s <- liftIO $ timeoutRecv
        case s of
            Nothing -> do
                   dbg "got EOF from socket"
                   sendOne k ""
            Just s' -> do
                   dbg $ "got " ++ Prelude.show (S.length s')
                           ++ " bytes from read end"
                   sendOne k s'

    loop x = returnI x

    sendOne k s
      | S.null s  = do
          dbg "sending EOF to continuation"
          enumEOF $ Continue k
      | otherwise = do
          dbg $ "sending " ++ show s ++ " to continuation"
          step <- lift $ runIteratee $ k $ Chunks [s]
          case step of
            (Yield x st)   -> do
                dbg $ "got yield, remainder is " ++ show st
                yield x st
            r@(Continue _) -> do
                dbg $ "got continue"
                loop r
            (Error e)      -> throwError e

    fd = fdSocket sock
#ifdef PORTABLE
    timeoutRecv = Listen.recv port sock (threadWaitRead $
                  fromIntegral fd) session
#else
    timeoutRecv = Listen.recv port (threadWaitRead $
                  fromIntegral fd) session
#endif

-- | Iteratee for the write end of a connection.
--
-- Every chunk it receives is concatenated and sent through the session, with
-- @tickle@ passed on to 'Listen.send'. It then waits for the next chunk, and
-- yields @()@ on EOF. A failed send is rethrown inside the iteratee. The
-- @PORTABLE@ build of 'Listen.send' also takes the 'Socket'.
writeOut :: (MonadIO m)
         => ListenSocket
         -> NetworkSession
         -> Socket
         -> (IO ())
         -> Iteratee ByteString m ()
writeOut port session sock tickle = loop
  where
    dbg s = debug $ "SimpleBackend.writeOut(" ++ show (_socket session)
            ++ "): " ++ s

    loop = continue k

    k EOF = yield () EOF
    k (Chunks xs) = do
        let s = S.concat xs
        let n = S.length s
        dbg $ "got chunk with " ++ show n ++ " bytes"
        ee <- liftIO $ try $ timeoutSend s
        case ee of
          (Left (e::SomeException)) -> do
              dbg $ "timeoutSend got error " ++ show e
              throwError e
          (Right _) -> do
              let last10 = S.drop (n-10) s
              dbg $ "wrote " ++ show n ++ " bytes, last 10=" ++ show last10
              -- keep consuming: without this the iteratee would finish after
              -- the first chunk and silently drop the rest of the output
              loop

    fd = fdSocket sock
#ifdef PORTABLE
    timeoutSend = Listen.send port sock tickle
                              (threadWaitWrite $ fromIntegral fd) session
#else
    timeoutSend = Listen.send port tickle
                              (threadWaitWrite $ fromIntegral fd) session
#endif