{-# LANGUAGE OverloadedStrings #-}
-- | This module contains functions for writing webservers. These servers
--   process requests in a state monad pipeline and several useful actions are
--   provided herein.
--
--   See @examples/test.hs@ for an example of how to use this module.
module Network.MiniHTTP.Server
  ( -- * The processing monad
    WebMonad
  , WebState(..)
  , getRequest
  , getPayload
  , getPOST
  , getGET
  , getReply
  , setReply
  , setHeader
  , setCookie
  , errorPage

  -- * WebMonad actions
  , handleConditionalRequest
  , handleHandleToSource
  , handleRangeRequests
  , handleDecoration
  , handleFromFilesystem

  -- * Running the server
  , serveHTTP
  , serveHTTPS
  , DispatchMatch(..)
  , dispatchOnURL
  ) where

import           Prelude hiding (foldl, catch)

import           Control.Concurrent.STM
import           Control.Exception (catch)
import           Control.Monad.State.Strict

import qualified Data.Binary.Put as P
import qualified Data.ByteString as B
import           Data.ByteString.Char8 ()
import qualified Data.ByteString.Lazy as BL
import           Data.ByteString.Internal (c2w, w2c)
import           Data.Char (chr)
import           Data.Int (Int64)
import qualified Data.Map as Map
import           Data.Maybe (isNothing, isJust, fromJust, catMaybes, maybe)
import           Data.String (fromString)
import           Data.Time.Clock.POSIX

import           System.FilePath (combine, splitDirectories, joinPath, takeExtension)
import           System.IO
import           System.IO.Unsafe (unsafePerformIO)
import           System.Posix
import qualified System.Posix.Signals as Signal

import qualified OpenSSL.Session as SSL

import qualified Network.Connection as C
import           Network.Socket hiding (send, sendTo, recv, recvFrom)
import           Network.MiniHTTP.Marshal
import           Network.MiniHTTP.MimeTypesParse
import           Network.MiniHTTP.HTTPConnection
import qualified Network.MiniHTTP.URL as URL

-- | Processing a request involves running a number of actions in a 'StateT'
--   monad whose state is this record. It holds both a 'Source' slot and a
--   'Handle' slot. Often one action fills in the 'Handle' and expects later
--   processing (e.g. 'handleHandleToSource' or 'handleRangeRequests') to
--   convert it into a 'Source'. One way or another you have to end up with a
--   'Source'; if none is present after the user's action, an empty source is
--   installed before the reply is sent.
data WebState =
  WebState { wsRequest :: Request  -- ^ the original request
           , wsBody :: Maybe Source  -- ^ the client's payload
           , wsMimeTypes :: Map.Map B.ByteString MediaType
             -- ^ the system mime types db, mapping file extensions
           , wsReply :: Reply   -- ^ the current reply
           , wsSource :: Maybe Source  -- ^ the current source
           , wsHandle :: Maybe Handle  -- ^ the current handle
           , wsAction :: Maybe (IO ())
             -- ^ an action to be performed before sending the reply
             --   (NOTE(review): nothing in this module runs it; 'setReply'
             --   only clears it)
           }

-- | The processing monad: a strict 'StateT' over 'IO' whose state is the
--   per-request 'WebState'. Request handlers are actions in this monad.
type WebMonad = StateT WebState IO

-- | The request currently being processed.
getRequest :: WebMonad Request
getRequest = gets wsRequest

-- | The reply as built so far by the processing pipeline.
getReply :: WebMonad Reply
getReply = gets wsReply

-- | The body sent by the client, if the request carried one.
getPayload :: WebMonad (Maybe Source)
getPayload = gets wsBody

-- | Decode the arguments of a POST request from the request body. An empty
--   map is returned when there is no body, when 'sourceToBS' yields nothing
--   within the byte limit, or when the body cannot be parsed as arguments.
getPOST :: Int  -- ^ max number of bytes to read
        -> WebMonad (Map.Map B.ByteString B.ByteString)
getPOST maxBytes = do
  payload <- getPayload
  case payload of
    Nothing -> return Map.empty
    Just source -> do
      raw <- liftIO $ sourceToBS maxBytes source
      return $ case raw >>= URL.parseArguments of
        Just args -> args
        Nothing -> Map.empty

-- | The query-string arguments of the request URL.
getGET :: WebMonad (Map.Map B.ByteString B.ByteString)
getGET = fmap (URL.rurlArguments . reqUrl) getRequest

-- | Replace the current reply with a fresh HTTP/1.1 reply carrying the given
--   status code and its default message. The new reply has only a
--   Content-Length of 0. Any pending source, handle and action are cleared.
setReply :: Int -> WebMonad ()
setReply code = modify $ \st ->
  st { wsAction = Nothing
     , wsSource = Nothing
     , wsHandle = Nothing
     , wsReply = Reply 1 1 code (statusToMessage code) freshHeaders }
  where
    freshHeaders = emptyHeaders {httpContentLength = Just 0}

-- | Apply a transformation to the headers of the current reply. Because of
--   the way records work, you use this function like this:
--
--   > setHeader $ \h -> h { httpSomeHeader = Just value }
setHeader :: (Headers -> Headers) -> WebMonad ()
setHeader f = modify $ \st ->
  let current = wsReply st
  in st { wsReply = current { replyHeaders = f (replyHeaders current) } }

-- | Add a Set-Cookie entry to the current reply. If the reply already sets a
--   cookie with the same name, every such entry is replaced in place by the
--   new cookie. Otherwise the new cookie is put at the front of the list.
setCookie :: Cookie -> WebMonad ()
setCookie newcookie@(Cookie { cookieName = name }) =
  setHeader $ \h -> h { httpSetCookie = upsert (httpSetCookie h) }
  where
    sameName c = cookieName c == name
    upsert existing
      | any sameName existing =
          map (\c -> if sameName c then newcookie else c) existing
      | otherwise = newcookie : existing

-- | Evaluate the conditional request headers against the ETag and
--   Last-Modified headers of the current reply. When a condition fails,
--   replace the reply with 412 (Precondition Failed) or 304 (Not Modified).
--
--   The headers are evaluated in the order given by RFC 7232, section 6.
--   Only the first failing condition decides the reply, so a 412 is never
--   overwritten by a later 304:
--
--   1. If-Match fails with 412. @*@ fails when the reply has no ETag. (For
--      ETag matching, a reply without an ETag header is treated as if the
--      resource did not exist.) A tag list needs a strong match, so a weak
--      or missing ETag fails.
--   2. If-Unmodified-Since, only when If-Match is absent, fails with 412 when
--      the resource was modified strictly after the given date.
--   3. If-None-Match matches when it is @*@ and an ETag exists, or when the
--      ETag is listed. A match gives 304 for GET and HEAD and 412 otherwise.
--   4. If-Modified-Since, only for GET and HEAD and only when If-None-Match
--      is absent, gives 304 when the resource was not modified after the
--      given date.
--
--   A 304 reply keeps the ETag of the current reply. 'setReply' drops any
--   'Handle' in the state, so that handle is closed first.
handleConditionalRequest :: WebMonad ()
handleConditionalRequest = do
  req <- getRequest
  reply <- getReply
  let reqH = reqHeaders req
      metag = httpETag $ replyHeaders reply
      mmtime = httpLastModified $ replyHeaders reply
      safeMethod = reqMethod req `elem` [GET, HEAD]

      ifMatchFails = case httpIfMatch reqH of
        Nothing -> False
        Just (Left ()) -> isNothing metag
        Just (Right tags) -> case metag of
          Just (False, etag) -> etag `notElem` tags
          -- no ETag at all, or a weak one (strong comparison required)
          _ -> True

      -- equal dates mean "not modified since", so only a strictly later
      -- modification time fails the precondition
      ifUnmodifiedSinceFails = case (httpIfUnmodifiedSince reqH, mmtime) of
        (Just rmtime, Just mtime) -> rmtime < mtime
        _ -> False

      ifNoneMatchMatches = case httpIfNoneMatch reqH of
        Nothing -> False
        Just (Left ()) -> isJust metag
        Just (Right tags) -> maybe False (`elem` tags) metag

      notModifiedSince = case (httpIfModifiedSince reqH, mmtime) of
        (Just rmtime, Just mtime) -> mtime <= rmtime
        _ -> False

      outcome
        | ifMatchFails = Just 412
        | isNothing (httpIfMatch reqH) && ifUnmodifiedSinceFails = Just 412
        | ifNoneMatchMatches = Just (if safeMethod then 304 else 412)
        | isNothing (httpIfNoneMatch reqH) && safeMethod && notModifiedSince
            = Just 304
        | otherwise = Nothing

  case outcome of
    Nothing -> return ()
    Just code -> do
      -- setReply forgets the handle; close it rather than leak the fd
      gets wsHandle >>= maybe (return ()) (liftIO . hClose)
      setReply code
      when (code == 304) $ setHeader (\h -> h { httpETag = metag })

-- | If the current state holds a 'Handle', turn it into a 'Source' built with
--   'hSource' over the byte range @(0, n - 1)@. Here @n@ is the reply's
--   Content-Length. When no Content-Length has been set, the handle's size
--   from 'hFileSize' is used and recorded as the Content-Length, instead of
--   crashing on a missing header. 'hFileSize' throws for handles without a
--   size.
handleHandleToSource :: WebMonad ()
handleHandleToSource = do
  mhandle <- gets wsHandle
  case mhandle of
    Nothing -> return ()
    Just handle -> do
      reply <- getReply
      len <- case httpContentLength $ replyHeaders reply of
        Just n -> return n
        Nothing -> do
          size <- liftIO $ hFileSize handle
          let n = fromIntegral size
          setHeader (\h -> h { httpContentLength = Just n })
          return n
      source <- liftIO $ hSource (0, len - 1) handle
      modify (\s -> s { wsHandle = Nothing, wsSource = Just source })

-- | Given the length of the resource, drop any unsatisfiable ranges. The
--   rest are converted to 'RangeOf' form with both ends inside
--   @[0, contentLength - 1]@ (RFC 2616, section 14.35.1):
--
--   * @RangeFrom a@ becomes @RangeOf a (contentLength - 1)@.
--   * @RangeOf a b@ has its end clamped to the last byte. It is dropped
--     when @b < a@.
--   * @RangeSuffix a@ selects the last @a@ bytes, or the whole resource when
--     it is shorter than @a@.
--
--   A range that starts at or beyond the end of the resource is dropped.
satisfiableRanges :: Int64 -> [Range] -> [Range]
satisfiableRanges contentLength = catMaybes . map f where
  lastByte = contentLength - 1
  f (RangeFrom a)
    | a < contentLength = Just $ RangeOf a lastByte
    | otherwise = Nothing
  f (RangeOf a b)
    -- the end must be clamped to the last byte, not to the length itself
    | a < contentLength && a <= b = Just $ RangeOf a $ min b lastByte
    | otherwise = Nothing
  f (RangeSuffix a)
    -- a suffix longer than the entity selects the whole entity
    | a > 0 && contentLength > 0 =
        Just $ RangeOf (max 0 (contentLength - a)) lastByte
    | otherwise = Nothing

-- | Handle Range requests. This also turns a 'Handle' in the state into a
--   'Source'. With a 'Handle' present, sources can be built from any
--   subrange of the file. Content-Length is assumed to be set correctly;
--   without it the whole handle is converted by 'handleHandleToSource'.
--
--   * Accept-Ranges is advertised whenever the length is known.
--   * No satisfiable range gives 416 with a Content-Range of @*/length@.
--   * A single range gives 206 (Partial Content) for that slice.
--   * Several ranges (multipart/byteranges) are not supported. A server may
--     ignore Range, so the whole entity is sent.
--
--   See RFC 2616, section 14.35
handleRangeRequests :: WebMonad ()
handleRangeRequests = do
  mhandle <- gets wsHandle
  req <- getRequest
  reply <- getReply
  case mhandle of
    Nothing -> return ()
    Just handle ->
      case httpContentLength $ replyHeaders reply of
        Nothing -> handleHandleToSource
        Just contentLength -> do
          setHeader (\h -> h { httpAcceptRanges = True })
          case httpRange $ reqHeaders req of
            Nothing -> handleHandleToSource
            Just ranges ->
              case satisfiableRanges contentLength ranges of
                [] -> do
                  -- setReply drops the handle, so close it rather than leak it
                  lift $ hClose handle
                  setReply 416
                  setHeader (\h -> h { httpContentRange = Just (Nothing, Just contentLength) })
                [RangeOf a b] -> do
                  source <- lift $ hSource (a, b) handle
                  modify $ \s ->
                    s { wsReply = (wsReply s) { replyStatus = 206
                                              , replyMessage = "Partial Content" }
                      , wsHandle = Nothing
                      , wsSource = Just source }
                  setHeader (\h -> h { httpContentRange = Just (Just (a, b), Just contentLength)})
                  setHeader (\h -> h { httpContentLength = Just ((b - a) + 1)})
                -- Multiple ranges are not supported: fall back to the whole
                -- entity instead of leaving the handle unconverted, which
                -- would end up as an empty body.
                _ -> handleHandleToSource

-- | Decorate the current reply with server-wide headers. For now this only
--   sets @Server: Network.MiniHTTP@.
handleDecoration :: WebMonad ()
handleDecoration = setHeader addServerHeader
  where
    addServerHeader h = h { httpServer = Just "Network.MiniHTTP" }

-- | Finalise the state after the user's action, before the reply is sent:
--
--   * If no source has been installed, install 'nullSource' and set
--     Content-Length to 0.
--   * For a HEAD request, replace the source with 'nullSource' so no body is
--     sent, and clear Transfer-Encoding. A Content-Length that is already
--     set is kept, because a HEAD reply should carry the headers a GET would
--     (RFC 2616, section 9.4). An unset one becomes 0, so the empty body is
--     not framed with chunked encoding.
handleFinal :: StateT WebState IO ()
handleFinal = do
  msource <- gets wsSource
  when (isNothing msource) $ do
    setHeader (\h -> h { httpContentLength = Just 0 })
    modify (\s -> s { wsSource = Just nullSource })

  req <- getRequest
  when (reqMethod req == HEAD) $ do
    setHeader $ \h -> h { httpContentLength = Just (maybe 0 id (httpContentLength h))
                        , httpTransferEncoding = [] }
    -- modify (not a put of a state read earlier) so the header changes just
    -- made are not discarded
    modify (\s -> s { wsSource = Just nullSource })

-- | A very simple handler that answers GET and HEAD requests with the
--   requested file from the filesystem. On success it sets a 'Handle' for the
--   file in the state and leaves the 'Source' empty. The reply becomes a 200
--   with Content-Type (looked up by file extension in the mime types
--   database), Content-Length and Last-Modified headers.
--
--   If the file cannot be opened or inspected, or is not a regular file
--   (e.g. a directory), the reply becomes a 404 error page. Other request
--   methods make the action 'fail'.
handleFromFilesystem :: FilePath -- ^ the root of the filesystem to serve from
                     -> WebMonad ()
handleFromFilesystem docroot = do
  req <- getRequest
  when (not $ reqMethod req `elem` [GET, HEAD]) $
    fail "Can only handle GET and HEAD from the filesystem"

  -- Stopping directory traversal needs to be done a little carefully.
  let path = map w2c $ B.unpack $ URL.rurlPath $ reqUrl req
      -- Cut the path at the first NUL so none reaches the filesystem call
      path' = takeWhile (/= chr 0) path
      -- Remove any '..' and the root '/', leaving a path relative to docroot
      elems' = filter (\x -> x /= ".." && x /= "/") $ splitDirectories path'
      ext = takeExtension path'
      filepath = combine docroot $ joinPath elems'
  s <- get
  r <- lift $ catch
    (do fd <- openFd filepath ReadOnly Nothing (OpenFileFlags False False True False False)
        stat <- getFdStatus fd
        if not (isRegularFile stat)
           then do
             -- Only regular files are served; close the descriptor instead
             -- of leaking it when fdToHandle would reject it.
             closeFd fd
             return Nothing
           else do
             handle <- fdToHandle fd
             let size = fromIntegral $ fileSize stat
                 mtime = posixSecondsToUTCTime $ fromRational $ toRational $ modificationTime stat
             return $ Just $
               s { wsHandle = Just handle
                 , wsSource = Nothing
                 , wsReply = Reply 1 1 200 "Ok" $ emptyHeaders
                    { httpLastModified = Just mtime
                    , httpContentLength = Just size
                    , httpContentType = Map.lookup (B.pack $ map c2w ext) (wsMimeTypes s) } })
    (const $ return Nothing)
  case r of
    Just x -> put x
    Nothing -> do
      -- report Not Found rather than the reply's previous (default 500) status
      setReply 404
      errorPage "File not found"

-- | Run the user's action, then 'handleFinal', on a fresh state for one
--   request. The reply starts as 500 "Server error". If an exception escapes,
--   processing restarts from that initial state and renders the exception
--   with 'errorPage'. Returns the final reply and its source.
pipeline :: Map.Map B.ByteString MediaType
         -> WebMonad ()
         -> Request
         -> Maybe Source
         -> IO (Reply, Source)
pipeline mimetypes action req msource = do
  (_, final) <- runStateT (action >> handleFinal) initState
                  `catch` (\e -> runStateT (errorPage (show e) >> handleFinal) initState)
  return (wsReply final, fromJust $ wsSource final)
  where
    initState = WebState
      { wsRequest = req
      , wsBody = msource
      , wsMimeTypes = mimetypes
      , wsReply = Reply 1 1 500 "Server error" emptyHeaders
      , wsSource = Nothing
      , wsHandle = Nothing
      , wsAction = Nothing
      }

-- | Read and parse a single request from a connection. If 'readIG' yields no
--   request, an 'IOError' with a descriptive message is thrown, instead of
--   the opaque failure of 'fromJust'.
readRequest :: C.Connection
            -> IO Request
readRequest conn = do
  mreq <- readIG conn 256 4096 parseRequest
  case mreq of
    Just req -> return req
    Nothing -> ioError $ userError "readRequest: could not read a request from the connection"

-- | Serve requests on one connection, one after another:
--
--   1. Read a request and, when it has a Content-Length, a body source.
--   2. Run the handler.
--   3. Write the serialised reply headers.
--   4. Stream the reply source, chunked when the reply has no
--      Content-Length.
--
--   If streaming reports failure, the connection is closed. Otherwise the
--   request body source, if any, is passed to 'sourceDrain' and the next
--   request is read.
readRequests :: (Request -> Maybe Source -> IO (Reply, IO SourceResult))
             -> C.Connection
             -> IO ()
readRequests handler conn = do
  request <- readRequest conn
  payload <- maybe (return Nothing)
                   (\len -> liftM Just $ connSource len B.empty conn)
                   (httpContentLength $ reqHeaders request)
  (reply, replySource) <- handler request payload
  let lowWater = 32 * 1024
      headerBytes = B.concat $ BL.toChunks $ P.runPut $ putReply reply
  atomically $ C.writeAtLowWater lowWater conn headerBytes
  ok <- if isJust (httpContentLength $ replyHeaders reply)
          then streamSource lowWater conn replySource
          else streamSourceChunked lowWater conn replySource
  if ok
    then do
      maybe (return ()) sourceDrain payload
      readRequests handler conn
    else C.close conn

-- | Perform the server side of the TLS handshake on @ssl@ with 'SSL.accept',
--   then continue with @k@.
sslHandshake :: SSL.SSL -> IO () -> IO ()
sslHandshake ssl k = do
  SSL.accept ssl
  k

-- | Accept plain-TCP connections forever. Each new socket gets TCP_NODELAY,
--   is logged to stdout, is wrapped in a connection, and is served by
--   'readRequests' in its own connection thread.
acceptLoop :: (Request -> Maybe Source -> IO (Reply, Source)) -> Socket -> IO ()
acceptLoop handler acceptingSocket = forever $ do
  (clientSock, peer) <- accept acceptingSocket
  setSocketOption clientSock NoDelay 1
  putStrLn $ "Connection from " ++ show peer
  conn <- C.new (return ()) $ C.baseConnectionFromSocket clientSock
  C.forkInConnection conn $ readRequests handler conn

-- | Accept TLS connections forever. Each new socket gets TCP_NODELAY, is
--   logged to stdout, and is wrapped in an SSL session from @ctx@. The
--   handshake and then 'readRequests' run in the connection's own thread.
acceptLoopHTTPS :: SSL.SSLContext
                -> (Request -> Maybe Source -> IO (Reply, Source))
                -> Socket
                -> IO ()
acceptLoopHTTPS ctx handler acceptingSocket = forever $ do
  (clientSock, peer) <- accept acceptingSocket
  setSocketOption clientSock NoDelay 1
  putStrLn $ "Connection from " ++ show peer
  session <- SSL.connection ctx clientSock
  conn <- C.new (return ()) $ sslToBaseConnection session
  C.forkInConnection conn $ sslHandshake session $ readRequests handler conn

-- | Replace the body of the current reply with an HTML page that reports the
--   given message, with '<', '>' and '&' escaped. Content-Length is set to
--   the size of the page, and a Server header is added via
--   'handleDecoration'. The status code is left unchanged, so call
--   'setReply' first to choose one.
--
--   Any 'Handle' still in the state is closed and removed. Otherwise a later
--   'handleHandleToSource' or 'handleRangeRequests' would replace the page
--   with the handle's contents. The Content-Type is looked up under ".html"
--   in the mime types database. NOTE(review): this assumes the keys carry
--   the leading dot, as the 'takeExtension' lookup in
--   'handleFromFilesystem' does.
errorPage :: String -> WebMonad ()
errorPage err = do
  s <- get
  maybe (return ()) (liftIO . hClose) (wsHandle s)
  source <- liftIO $ bsSource message
  put $ s { wsSource = Just source, wsHandle = Nothing }
  setHeader $ \h -> h { httpContentLength = Just $ fromIntegral $ B.length message
                      , httpContentType = Map.lookup ".html" (wsMimeTypes s) }
  handleDecoration
  where
    message = B.concat [pageHead, fromString (concatMap escape err), pageTail]
    pageHead = "<html> <head> <title>Network.MiniHTTP error page</title> <style language=\"text/css\"> #top { height: 1.5em; width: 100%; background-color: #BFD9FF; border-bottom: 3px solid #004FBF; margin-bottom: 2em; padding-left: 1em; font-variant: small-caps; font-size: 2em; padding-top: 0.5em; } body { margin: 0 0 0 0; } #main { margin-left: 4px; } .enbox { padding-left: 2em; background-color: \"#003786\" } h4 { color: #004FBF; } </style> </head> <body> <div id=\"top\">Network.MiniHTTP</div> <div id=\"main\"> <h4>An error occurred while processing your request:</h4> <pre class=\"enbox\">"
    pageTail = "</pre> </div> </body> </html>"
    escape '<' = "&lt;"
    escape '&' = "&amp;"
    escape '>' = "&gt;"
    escape c = [c]

-- | How an entry given to 'dispatchOnURL' is compared with the request path.
data DispatchMatch = Exact B.ByteString   -- ^ the path must equal this exactly
                   | Prefix B.ByteString  -- ^ the path must start with this
                   deriving (Show, Eq)

-- | Does the request path satisfy the given 'DispatchMatch'?
dispatchMatch :: B.ByteString -> DispatchMatch -> Bool
dispatchMatch path rule = case rule of
  Exact expected -> path == expected
  Prefix prefix -> prefix `B.isPrefixOf` path

-- | An optional helper you might find useful. The serving functions both
--   expect a 'WebMonad' action that is run for every request, and that
--   action usually has to dispatch on the client's request.
--
--   This tries each element of the list in turn and runs the action of the
--   first one whose 'DispatchMatch' accepts the request path. When nothing
--   matches, the reply becomes a 404 error page.
dispatchOnURL :: [(DispatchMatch, WebMonad ())]
                 -- ^ the list of URL prefixes (with '/'!) and their actions
              -> WebMonad ()
dispatchOnURL paths = do
  req <- getRequest
  let path = URL.rurlPath $ reqUrl req

  case map snd $ filter (dispatchMatch path . fst) paths of
       [] -> do
         -- Nothing is mapped at this URL: report Not Found rather than
         -- keeping the reply's previous (by default 500) status.
         setReply 404
         errorPage "No dispatchers matched requested URL"
       x:_ -> x

-- | The system MIME type database, parsed from @/etc/mime.types@ when first
--   needed. An empty map is used if 'parseMimeTypesTotal' yields 'Nothing'.
--
--   'unsafePerformIO' loads the database once and shares it across all
--   requests. The NOINLINE pragma stops GHC from inlining this CAF, which
--   could otherwise repeat the parse at each use site.
{-# NOINLINE globalMimeTypes #-}
globalMimeTypes :: Map.Map B.ByteString MediaType
globalMimeTypes = unsafePerformIO $
  liftM (maybe Map.empty id) (parseMimeTypesTotal "/etc/mime.types")

-- | Bind a listening IPv4 TCP socket on the given port on all interfaces,
--   ignore SIGPIPE, and hand the socket to the accept loop. If the accept
--   loop throws, the listening socket is closed and 'serve' returns.
serve :: Int -- ^ port number
      -> (Socket -> IO ())  -- ^ accept loop
      -> IO ()
serve portno loop = do
  --  Switch these two lines to use IPv6 (which works for IPv4 clients too). Not
  --  all systems support this
  --acceptingSocket <- socket AF_INET6 Stream 0
  --let sockaddr = SockAddrInet6 (fromIntegral portno) 0 iN6ADDR_ANY 0
  acceptingSocket <- socket AF_INET Stream 0
  let sockaddr = SockAddrInet (fromIntegral portno) iNADDR_ANY

  setSocketOption acceptingSocket ReuseAddr 1
  bindSocket acceptingSocket sockaddr
  -- Use the system's maximum backlog. With a backlog of 1, a burst of
  -- simultaneous connections can be refused or dropped while the accept
  -- loop is busy.
  listen acceptingSocket maxListenQueue

  -- Ignore SIGPIPE
  Signal.installHandler Signal.sigPIPE Signal.Ignore Nothing

  catch (loop acceptingSocket)
        (const $ sClose acceptingSocket)


-- | Start an IPv4 HTTP server. Every request is processed by running the
--   given action in the request pipeline.
serveHTTP :: Int  -- ^ the port number to listen on
          -> WebMonad ()  -- ^ the processing action
          -> IO ()
serveHTTP portno = serve portno . acceptLoop . pipeline globalMimeTypes

-- | Start an IPv4 HTTPS server. Please remember to wrap your main function in
--   'OpenSSL.withOpenSSL', otherwise you'll probably crash the process.
--
--   The action 'fail's if the private key does not match the certificate.
serveHTTPS :: Int  -- ^ the port number to listen on
           -> FilePath  -- ^ path to public key (certificate)
           -> FilePath  -- ^ path to private key
           -> WebMonad ()  -- ^ the processing action
           -> IO ()
serveHTTPS portno public private action = do
  ctx <- SSL.context
  SSL.contextSetPrivateKeyFile ctx private
  SSL.contextSetCertificateFile ctx public
  SSL.contextSetDefaultCiphers ctx
  keysMatch <- SSL.contextCheckPrivateKey ctx
  unless keysMatch $ fail "Public/private key mismatch"
  let handler = pipeline globalMimeTypes action
  serve portno (acceptLoopHTTPS ctx handler)