module Hyena.Server
    ( serve,
      serveWithConfig
    ) where
import Control.Concurrent (ThreadId, forkIO)
import Control.Exception.Extensible
import Control.Monad (unless, when)
import Control.Monad.Reader (MonadIO, MonadReader, ReaderT, ask, asks,
                             liftIO, runReaderT)
import qualified Data.ByteString as S
import qualified Data.ByteString.Char8 as C (elemIndex, map, pack, split,
                                             words)
import Data.Char (toLower)
import Network.BSD (getProtocolNumber)
import Network.Socket (Family(..), HostAddress, SockAddr(..), Socket,
                       SocketOption(..), SocketType(..), accept, bindSocket,
                       listen, inet_addr, maxListenQueue, sClose,
                       setSocketOption, socket, withSocketsDo)
import Network.Wai
import Prelude hiding (catch, log)
import System.Exit (exitFailure, ExitCode(..))
import System.IO (Handle, stderr, hPutStrLn)
#ifndef mingw32_HOST_OS
import System.Posix.Signals (Handler(..), installHandler, sigPIPE)
#endif
import Hyena.Config
import Hyena.Http
import Hyena.Logging
-- | The server monad: a reader over 'ServerConfig' layered on 'IO'.
-- Its 'Monad', 'MonadIO' and 'MonadReader' instances are all derived from
-- the wrapped 'ReaderT' through newtype deriving.
newtype Server a = Server (ReaderT ServerConfig IO a)
    deriving (Monad, MonadIO, MonadReader ServerConfig)
-- | Environment of the 'Server' monad. The accept loop and every
-- connection thread read it (connection threads receive it via
-- 'forkServer').
data ServerConfig = ServerConfig
    { config       :: Config        -- ^ Settings the server was started with.
    , accessLogger :: AccessLogger  -- ^ Receives one entry per answered request.
    , errorLogger  :: ErrorLogger   -- ^ Receives application errors and
                                    --   exceptions from connection threads.
    }
-- | Execute a 'Server' computation in 'IO' against the given configuration.
runServer :: ServerConfig -> Server a -> IO a
runServer conf srv =
    case srv of
      Server reader -> runReaderT reader conf
-- | Run @m@, handing any exception it raises to the handler @k@.
--
-- 'ExitCode' exceptions are re-raised untouched, so a deliberate process
-- exit is not swallowed. Every other exception (caught as 'SomeException')
-- is passed to @k@, which runs with the same 'ServerConfig' as @m@.
catchServer :: Server a -> (forall e. (Exception e) => e -> Server a) -> Server a
catchServer m k = do
  conf <- ask
  io $ runServer conf m `catches` handlers conf
    where
      -- 'catches' uses the first matching handler, so ExitCode must be
      -- listed before the SomeException catch-all. 'throwIO' re-raises at
      -- a well-defined point in IO (unlike the pure 'throw').
      handlers c =
          [ Handler $ \(e :: ExitCode)      -> throwIO e
          , Handler $ \(e :: SomeException) -> runServer c $ k e
          ]
-- | Run @m@, then always run the finalizer @k@, even if @m@ throws.
-- Both run with the current 'ServerConfig'; the result of @m@ is returned.
finallyServer :: Server a -> Server b -> Server a
finallyServer m k =
    ask >>= \conf ->
      io (finally (runServer conf m) (runServer conf k))
-- | Start @m@ on a new thread that shares the current 'ServerConfig', and
-- return that thread's id.
forkServer :: Server () -> Server ThreadId
forkServer m =
    ask >>= \conf ->
      io (forkIO (runServer conf m))
-- | Serve @application@ using the configuration from 'configFromFlags'.
serve :: Application -> IO ()
serve application =
    flip serveWithConfig application =<< configFromFlags
-- | Serve @application@ with an explicit 'Config'.
--
-- On non-Windows hosts SIGPIPE is ignored first. Daemonized mode is
-- rejected with a message on stderr and 'exitFailure'. Otherwise the access
-- and error loggers are started for the lifetime of the server.
serveWithConfig :: Config -> Application -> IO ()
serveWithConfig conf application = do
#ifndef mingw32_HOST_OS
  _ <- installHandler sigPIPE Ignore Nothing
#endif
  when (daemonize conf) refuseDaemonize
  bracketLoggers (logHandle conf) $ \accessLog errorLog ->
      runServer (mkServerConf accessLog errorLog) (serve' application)
  where
    -- Bundle the configuration with the freshly started loggers.
    mkServerConf accessLog errorLog =
        ServerConfig
        { config       = conf
        , accessLogger = accessLog
        , errorLogger  = errorLog
        }
    -- Daemonizing is unsupported: explain why and terminate.
    refuseDaemonize = do
      mapM_ (hPutStrLn stderr)
        [ "Daemonized mode not supported at the moment."
        , "If you need this feature please say so in " ++
          "GHC ticket #1185."
        ]
      exitFailure
-- | Start the access logger (writing to @h@) and the error logger (writing
-- to 'stderr'), run the action with both, then stop them, even if the
-- action throws.
--
-- Each logger gets its own 'bracket'. If starting the error logger fails,
-- or stopping it throws, the already-started access logger is still
-- stopped. Shutdown order is the error logger first, then the access
-- logger.
bracketLoggers :: Handle -> (AccessLogger -> ErrorLogger -> IO ()) -> IO ()
bracketLoggers h act =
    bracket (startAccessLogger h) stopAccessLogger $ \accessLog ->
    bracket (startErrorLogger stderr) stopErrorLogger $ \errorLog ->
    act accessLog errorLog
-- | Open a TCP socket on the configured address and port, with SO_REUSEADDR
-- set, and run the accept loop on it. The socket is closed when the loop
-- ends or throws.
serve' :: Application -> Server ()
serve' application = do
  conf <- ask
  let cfg = config conf
  io $ withSocketsDo $ do
    proto    <- getProtocolNumber "tcp"
    hostAddr <- inet_addr (address cfg)
    bracket (socket AF_INET Stream proto) sClose $ \listenSock -> do
      setSocketOption listenSock ReuseAddr 1
      bindSocket listenSock (SockAddrInet (fromIntegral (port cfg)) hostAddr)
      listen listenSock maxListenQueue
      runServer conf (acceptConnections application listenSock)
-- | Accept connections on @serverSock@ forever, serving each client on its
-- own thread with 'talk'.
--
-- A client's socket is closed once 'talk' finishes or fails. Any exception
-- from 'talk' is written to the error logger instead of reaching the accept
-- loop; 'ExitCode' is the exception, since 'catchServer' re-raises it.
--
-- The peer address is matched with a @case@ instead of a refutable pattern
-- in the bind. An unexpected (non-IPv4) peer is just closed. It no longer
-- triggers a pattern-match failure that would abort the whole accept loop
-- and leak the accepted socket.
acceptConnections :: Application -> Socket -> Server ()
acceptConnections application serverSock = do
  (sock, peer) <- io $ accept serverSock
  case peer of
    SockAddrInet _ haddr -> do
        _ <- forkServer (serveConnection sock haddr)
        return ()
    _ -> io $ sClose sock
  acceptConnections application serverSock
  where
    -- Serve one client, always closing its socket and logging (rather than
    -- propagating) any exception.
    serveConnection sock haddr =
        (talk sock haddr application `finallyServer` io (sClose sock))
        `catchServer`
        (\e -> do logger <- asks errorLogger
                  io $ logError logger $ show e)
-- | Handle the requests arriving on one client socket.
--
-- An unparseable request is answered with a 400 and ends the exchange.
-- Otherwise the application runs, the exchange is access-logged, and the
-- response is sent. The function then loops for the next request unless
-- 'closeConnection' says the connection must be closed.
talk :: Socket -> HostAddress -> Application -> Server ()
talk sock haddr application =
    io (receiveRequest sock) >>= maybe badRequest serveRequest
  where
    badRequest = io $ sendResponse sock $ errorResponse 400

    serveRequest req = do
      errLog <- asks errorLogger
      resp   <- run (requestToEnvironment (logError errLog) req) application
      accLog <- asks accessLogger
      io $ do logAccess accLog req resp haddr
              sendResponse sock resp
      unless (closeConnection req resp) $
        talk sock haddr application
-- | Invoke the application on @environ@ and package its
-- (status, reason, headers, body) result as a 'Response'.
run :: Environment -> Application -> Server Response
run environ application =
    io $ application environ >>= \(status, reason, hdrs, body) ->
      return Response
             { statusCode      = status
             , reasonPhrase    = reason
             , responseHeaders = hdrs
             , responseBody    = body
             }
-- | Decide whether the server must close the connection after sending
-- @resp@ in reply to @req@.
--
-- The connection is closed in either of two cases:
--
-- * the request or the response lists the @close@ connection option;
-- * the request predates HTTP/1.1 and did not ask for @keep-alive@.
--
-- @Connection@ header names and option tokens are compared
-- case-insensitively. A header value may carry a comma-separated list of
-- options (RFC 2616 sections 4.2 and 14.10), so values such as
-- @Keep-Alive@ or @close, TE@ are recognised. Every @Connection@ header
-- present is taken into account.
closeConnection :: Request -> Response -> Bool
closeConnection req resp =
    closeSet || (httpVersion req < (1,1) && not keepAliveSet)
    where
      reqOptions   = connectionOptions (requestHeaders req)
      respOptions  = connectionOptions (responseHeaders resp)
      closeSet     = closeVal `elem` reqOptions || closeVal `elem` respOptions
      keepAliveSet = keepAliveVal `elem` reqOptions
      -- Lower-cased option tokens from every Connection header in the list.
      connectionOptions hdrs =
          [ lower token
          | (name, value) <- hdrs
          , lower name == hdrName
          , item <- C.split ',' value
          , token <- C.words item
          ]
      lower        = C.map toLower
      hdrName      = C.pack "connection"
      closeVal     = C.pack "close"
      keepAliveVal = C.pack "keep-alive"
-- | Shorthand for 'liftIO': run an 'IO' action inside any 'MonadIO' monad
-- (in this module, mostly 'Server').
io :: MonadIO m => IO a -> m a
io action = liftIO action
-- | Build the application 'Environment' for a parsed request.
--
-- The request URI is cut at the first @?@. The part before it becomes
-- 'pathInfo' and the part after it becomes 'queryString'; the query is
-- 'Nothing' when the URI has no @?@. The script name is always empty, and
-- @err@ is exposed to the application as its error sink.
requestToEnvironment :: (String -> IO ()) -> Request -> Environment
requestToEnvironment err req =
    Environment
    { requestMethod   = method req
    , scriptName      = S.empty
    , pathInfo        = pathPart
    , queryString     = queryPart
    , requestProtocol = httpVersion req
    , headers         = requestHeaders req
    , input           = requestBody req
    , errors          = err
    }
  where
    uri = requestUri req
    (pathPart, queryPart) = maybe (uri, Nothing) cutAt (C.elemIndex '?' uri)
    -- Split around the '?' at index i, dropping the '?' itself.
    cutAt i = (S.take i uri, Just (S.drop (i + 1) uri))