{-# LANGUAGE TypeSynonymInstances, DeriveDataTypeable, FlexibleInstances, RankNTypes #-}

module Happstack.Server.Internal.Types
    (Request(..), Response(..), RqBody(..), Input(..), HeaderPair(..),
     takeRequestBody, readInputsBody,
     rqURL, mkHeaders,
     getHeader, getHeaderBS, getHeaderUnsafe,
     hasHeader, hasHeaderBS, hasHeaderUnsafe,
     setHeader, setHeaderBS, setHeaderUnsafe,
     addHeader, addHeaderBS, addHeaderUnsafe,
     setRsCode, -- setCookie, setCookies,
     Conf(..), nullConf, TLSConf(..), HTTPS(..), result, resultBS,
     redirect, -- redirect_, redirect', redirect'_,
     isHTTP1_0, isHTTP1_1,
     RsFlags(..), nullRsFlags, contentLength, chunked, noContentLength,
     HttpVersion(..), Length(..), Method(..), Headers, continueHTTP,
     Host, ContentType(..),
     readDec', fromReadS, readM, FromReqURI(..)
    ) where

import Control.Monad.Error (Error(strMsg))
import Control.Monad.Trans (MonadIO(liftIO))
import Control.Concurrent.MVar
import qualified Data.Map as M
import Data.Data (Data)
import Data.Time.Format (FormatTime(..))
import Data.Typeable(Typeable)
import qualified Data.ByteString.Char8 as P
import Data.ByteString.Char8 (ByteString,pack)
import qualified Data.ByteString.Lazy.Char8 as L
import qualified Data.ByteString.Lazy.UTF8  as LU (fromString)
import Data.Int   (Int8, Int16, Int32, Int64)
import Data.Maybe
import Data.List
import Data.Word  (Word, Word8, Word16, Word32, Word64)
import Happstack.Server.SURI
import Data.Char (toLower)
import Happstack.Server.Internal.RFC822Headers ( ContentType(..) )
import Happstack.Server.Internal.Cookie
import Happstack.Server.Internal.LogFormat (formatRequestCombined)
import Happstack.Server.Internal.TLS
import Numeric (readDec, readSigned)
import System.Log.Logger (Priority(..), logM)
import Text.Show.Functions ()

-- | HTTP protocol version, stored as two 'Int's (shown as @x.y@).
data HttpVersion = HttpVersion Int Int

-- | Renders the two numbers separated by a dot, e.g. @1.1@.
instance Show HttpVersion where
  show (HttpVersion major minor) = concat [show major, ".", show minor]

-- | 'True' exactly when the 'Request' carries HTTP version @1.1@.
isHTTP1_1 :: Request -> Bool
isHTTP1_1 rq
    | HttpVersion 1 1 <- rqVersion rq = True
    | otherwise                       = False

-- | 'True' exactly when the 'Request' carries HTTP version @1.0@.
isHTTP1_0 :: Request -> Bool
isHTTP1_0 rq
    | HttpVersion 1 0 <- rqVersion rq = True
    | otherwise                       = False

-- | Should the connection be used for further messages after this one?
--
-- * HTTP\/1.0: only when the request has @Connection: Keep-Alive@ and the
--   response uses the 'ContentLength' flag.
--
-- * HTTP\/1.1: unless the request has @Connection: close@ or the response
--   uses the 'NoContentLength' flag.
continueHTTP :: Request -> Response -> Bool
continueHTTP rq res = keepAlive10 || persistent11
  where
    lengthMode   = rsfLength (rsFlags res)
    keepAlive10  = isHTTP1_0 rq
                && checkHeaderBS connectionC keepaliveC rq
                && lengthMode == ContentLength
    persistent11 = isHTTP1_1 rq
                && not (checkHeaderBS connectionC closeC rq)
                && lengthMode /= NoContentLength

-- | HTTP configuration
--
-- See 'nullConf' for the default values.
data Conf = Conf
    { port       :: Int -- ^ Port for the server to listen on.
    , tls        :: Maybe TLSConf -- ^ optional TLS settings (NOTE(review): presumably 'Nothing' means no HTTPS listener -- confirm against the TLS module)
    , validator  :: Maybe (Response -> IO Response) -- ^ a function to validate the output on-the-fly
    , logAccess  :: forall t. FormatTime t => Maybe (String -> String -> t -> String -> Int -> Integer -> String -> String -> IO ()) -- ^ function to log access requests (see also: 'logMAccess')
    , timeout    :: Int -- ^ number of seconds to wait before killing an inactive thread
    }

-- | Default configuration: listen on port 8000, no TLS, no validator,
-- access logging through 'logMAccess', and a 30 second timeout.
nullConf :: Conf
nullConf =
    Conf { port      = 8000
         , tls       = Nothing
         , validator = Nothing
         , logAccess = Just logMAccess
         , timeout   = 30
         }

-- | Write one access-log entry through hslogger.
--
-- The arguments are rendered with 'formatRequestCombined' and logged at
-- 'INFO' priority under the logger name
-- @Happstack.Server.AccessLog.Combined@.
--
-- see also: 'Conf'
logMAccess :: forall t. FormatTime t =>
              String  -- ^ host
           -> String  -- ^ user
           -> t       -- ^ time
           -> String  -- ^ requestLine
           -> Int     -- ^ responseCode
           -> Integer -- ^ size
           -> String  -- ^ referer
           -> String  -- ^ userAgent
           -> IO ()
logMAccess host user time requestLine responseCode size referer userAgent =
    logM loggerName INFO entry
  where
    loggerName = "Happstack.Server.AccessLog.Combined"
    entry      = formatRequestCombined host user time requestLine responseCode size referer userAgent

-- | HTTP request method

-- | an HTTP header: the name as originally spelled, plus all of its values
data HeaderPair = HeaderPair
    { hName :: ByteString     -- ^ header name
    , hValue :: [ByteString]  -- ^ header value (or values if multiple occurrences of the header are present)
    }
    deriving (Read,Show)

-- | a Map of HTTP headers
--
-- The Map key is the header name converted to lowercase; the 'HeaderPair'
-- stored under it keeps the name as originally spelled together with its
-- value(s).
type Headers = M.Map ByteString HeaderPair -- lowercased name -> (realname, value)

-- | A flag value set in the 'Response' which controls how the
-- @Content-Length@ header is set, and whether /chunked/ output
-- encoding is used.
--
-- see also: 'nullRsFlags', 'noContentLength', 'contentLength', and 'chunked'
data Length
    = ContentLength             -- ^ automatically add a @Content-Length@ header to the 'Response'
    | TransferEncodingChunked   -- ^ do not add a @Content-Length@ header. Do use @chunked@ output encoding
    | NoContentLength           -- ^ do not set @Content-Length@ or @chunked@ output encoding.
      deriving (Eq, Ord, Read, Show, Enum)

-- | Result flags attached to a 'Response'.
--
-- see also: 'nullRsFlags'
data RsFlags = RsFlags
    { rsfLength :: Length -- ^ how @Content-Length@ \/ chunked encoding is handled (see 'Length')
    } deriving (Show,Read,Typeable)

-- | The default 'RsFlags': the length flag is 'TransferEncodingChunked',
-- i.e. @Transfer-Encoding: chunked@ is used automatically.
nullRsFlags :: RsFlags
nullRsFlags = RsFlags TransferEncodingChunked

-- | Switch the 'Response' to 'NoContentLength': no @Content-Length@ header
-- is added automatically.
noContentLength :: Response -> Response
noContentLength res =
    res { rsFlags = (rsFlags res) { rsfLength = NoContentLength } }

-- | Switch the 'Response' to 'TransferEncodingChunked': no @Content-Length@
-- header is added; @Transfer-Encoding: chunked@ is used instead.
chunked :: Response -> Response
chunked res =
    res { rsFlags = (rsFlags res) { rsfLength = TransferEncodingChunked } }

-- | Switch the 'Response' to 'ContentLength': a @Content-Length@ header is
-- added automatically and chunked encoding is not used.
contentLength :: Response -> Response
contentLength res =
    res { rsFlags = (rsFlags res) { rsfLength = ContentLength } }

-- | a value extracted from the @QUERY_STRING@ or 'Request' body
--
-- If the input value was a file, then it will be saved to a temporary file on disk and 'inputValue' will contain @Left pathToTempFile@.
data Input = Input
    { inputValue       :: Either FilePath L.ByteString -- ^ @Left path@ for a value stored in a file, @Right bytes@ for an in-memory value
    , inputFilename    :: Maybe FilePath               -- ^ file name associated with the input, if any
    , inputContentType :: ContentType                  -- ^ content type of the input
    } deriving (Show, Read, Typeable)

-- | hostname and port, as a pair @(hostname, port)@
type Host = (String, Int) -- (hostname, port)

-- | an HTTP Response
--
-- 'Response' carries its body in memory as a lazy 'L.ByteString';
-- 'SendFile' instead names a file plus an offset and byte count.
data Response
    = Response  { rsCode      :: Int                              -- ^ status code
                , rsHeaders   :: Headers                          -- ^ response headers
                , rsFlags     :: RsFlags                          -- ^ length \/ chunked-encoding flags
                , rsBody      :: L.ByteString                     -- ^ response body
                , rsValidator :: Maybe (Response -> IO Response)  -- ^ optional validation function for this response
                }
    | SendFile  { rsCode      :: Int
                , rsHeaders   :: Headers
                , rsFlags     :: RsFlags
                , rsValidator :: Maybe (Response -> IO Response)
                , sfFilePath  :: FilePath  -- ^ path of the file to send from
                , sfOffset    :: Integer   -- ^ offset to start at
                , sfCount     :: Integer    -- ^ number of bytes to send
                }
      deriving (Typeable)

-- | Multi-line debugging dump of a 'Response'. The status code, headers and
-- flags are printed for both constructors, followed by the
-- constructor-specific fields.
instance Show Response where
    showsPrec _ res =
        case res of
          Response{} -> banner . common
                      . field "\nrsBody      = " (rsBody res)
                      . field "\nrsValidator = " (rsValidator res)
          SendFile{} -> banner . common
                      . field "\nrsValidator = " (rsValidator res)
                      . field "\nsfFilePath  = " (sfFilePath res)
                      . field "\nsfOffset    = " (sfOffset res)
                      . field "\nsfCount     = " (sfCount res)
      where
        -- one "label = value" line
        field :: Show a => String -> a -> ShowS
        field label value = showString label . shows value

        banner :: ShowS
        banner = showString "================== Response ================"

        -- the fields shared by both constructors
        common :: ShowS
        common = field "\nrsCode      = " (rsCode res)
               . field "\nrsHeaders   = " (rsHeaders res)
               . field "\nrsFlags     = " (rsFlags res)

-- An error message becomes a @500@ response with a plain-text UTF-8 body.
-- (Open question from the original author: what should the status code be?)
instance Error Response where
  strMsg = setHeader "Content-Type" "text/plain; charset=UTF-8" . result 500

-- | an HTTP request
--
-- The body ('rqBody') and the decoded body inputs ('rqInputsBody') live in
-- 'MVar's; see 'takeRequestBody' and 'readInputsBody' for how they are
-- accessed.
data Request = Request
    { rqSecure        :: Bool                  -- ^ request uses https:\/\/
      , rqMethod      :: Method                -- ^ request method
      , rqPaths       :: [String]              -- ^ the uri, split on /, and then decoded
      , rqUri         :: String                -- ^ the raw rqUri
      , rqQuery       :: String                -- ^ the QUERY_STRING
      , rqInputsQuery :: [(String,Input)]      -- ^ the QUERY_STRING decoded as key/value pairs
      , rqInputsBody  :: MVar [(String,Input)] -- ^ the request body decoded as key/value pairs (when appropriate)
      , rqCookies     :: [(String,Cookie)]     -- ^ cookies
      , rqVersion     :: HttpVersion           -- ^ HTTP version
      , rqHeaders     :: Headers               -- ^ the HTTP request headers
      , rqBody        :: MVar RqBody           -- ^ the raw, undecoded request body
      , rqPeer        :: Host                  -- ^ (hostname, port) of the client making the request
    } deriving (Typeable)

-- | Multi-line debugging dump of a 'Request'.
--
-- The two 'MVar' fields ('rqInputsBody' and 'rqBody') are not read; they
-- are printed as the placeholder @<<mvar>>@.
instance Show Request where
    showsPrec _ rq =
        showString   "================== Request =================" .
        -- print the rqSecure flag itself (this line used to print rqMethod)
        showString "\nrqSecure      = " . shows      (rqSecure rq) .
        showString "\nrqMethod      = " . shows      (rqMethod rq) .
        showString "\nrqPaths       = " . shows      (rqPaths rq) .
        showString "\nrqUri         = " . showString (rqUri rq) .
        showString "\nrqQuery       = " . showString (rqQuery rq) .
        showString "\nrqInputsQuery = " . shows      (rqInputsQuery rq) .
        showString "\nrqInputsBody  = " . showString "<<mvar>>" .
        showString "\nrqCookies     = " . shows      (rqCookies rq) .
        showString "\nrqVersion     = " . shows      (rqVersion rq) .
        showString "\nrqHeaders     = " . shows      (rqHeaders rq) .
        showString "\nrqBody        = " . showString "<<mvar>>" .
        showString "\nrqPeer        = " . shows      (rqPeer rq)

-- | Take the request body out of the 'Request', leaving its 'MVar' empty.
--
-- IMPORTANT: You can really only call this function once. Subsequent
-- calls will return 'Nothing'.
takeRequestBody :: (MonadIO m) => Request -> m (Maybe RqBody)
takeRequestBody = liftIO . tryTakeMVar . rqBody

-- | Read the decoded request body inputs without consuming them.
--
-- This only works if the body inputs have already been decoded (i.e. the
-- 'MVar' is full); otherwise the result is 'Nothing'. The value is taken
-- out of the 'MVar' and immediately put back.
readInputsBody :: Request -> IO (Maybe [(String, Input)])
readInputsBody req = tryTakeMVar var >>= maybe (return Nothing) restore
  where
    var = rqInputsBody req
    -- put the inputs back so later readers still see them
    restore inputs = putMVar var inputs >> return (Just inputs)

-- | Render a 'Request' as a URL string: a leading @/@, the path segments
-- joined with @/@, then the contents of 'rqQuery' appended unchanged.
rqURL :: Request -> String
rqURL rq = concat ["/", intercalate "/" (rqPaths rq), rqQuery rq]

-- | a class for working with types that contain HTTP headers
--
-- Instances in this module: 'Response', 'Request', and 'Headers' itself.
class HasHeaders a where
    updateHeaders :: (Headers->Headers) -> a -> a -- ^ modify the headers
    headers       :: a -> Headers                 -- ^ extract the headers

-- | Headers of a 'Response' live in 'rsHeaders'.
instance HasHeaders Response where
    updateHeaders modify rs = rs { rsHeaders = modify (rsHeaders rs) }
    headers rs              = rsHeaders rs

-- | Headers of a 'Request' live in 'rqHeaders'.
instance HasHeaders Request where
    updateHeaders modify rq = rq { rqHeaders = modify (rqHeaders rq) }
    headers rq              = rqHeaders rq

-- | A bare 'Headers' map is trivially its own header container.
instance HasHeaders Headers where
    updateHeaders modify hs = modify hs
    headers hs              = hs

-- | The body of an HTTP 'Request': a newtype around a lazy 'L.ByteString',
-- unwrapped with 'unBody'.
newtype RqBody = Body { unBody :: L.ByteString } deriving (Read,Show,Typeable)

-- | Replace the status code of the 'Response' with the given 'Int' and
-- 'return' the updated 'Response' in the caller's monad.
setRsCode :: (Monad m) => Int -> Response -> m Response
setRsCode code rs = return updated
  where
    updated = rs { rsCode = code }

-- | Build 'Headers' from a list of (key,val) pairs. Map keys are the
-- lowercased header names; each 'HeaderPair' keeps the name as given.
--
-- When a name occurs more than once (ignoring case) the values are merged
-- into one 'HeaderPair': values from later list entries come first, and the
-- stored name is the spelling of the last entry.
mkHeaders :: [(String,String)] -> Headers
mkHeaders = M.fromListWith merge . map toEntry
  where
    toEntry (name, value) =
        (P.pack (map toLower name), HeaderPair (P.pack name) [P.pack value])
    -- 'M.fromListWith' passes the newer entry first
    merge (HeaderPair name newer) (HeaderPair _ older) = HeaderPair name (newer ++ older)

-- Retrieving header information

-- | Look up the value of a header by name; the name is matched
-- case-insensitively.
getHeader :: HasHeaders r => String -> r -> Maybe ByteString
getHeader key = getHeaderBS (pack key)

-- | Like 'getHeader', but the (case-insensitive) name is a 'ByteString'.
-- The name is lowercased before the lookup.
getHeaderBS :: HasHeaders r => ByteString -> r -> Maybe ByteString
getHeaderBS key = getHeaderUnsafe (P.map toLower key)

-- | Look up a header value with a case-sensitive key, which must already
-- be lowercase. Returns the first value stored for the header, or
-- 'Nothing' if the header is absent or has no values.
getHeaderUnsafe :: HasHeaders r => ByteString -> r -> Maybe ByteString
getHeaderUnsafe key var = getHeaderUnsafe' key var >>= listToMaybe . hValue

-- | Look up the whole 'HeaderPair' with a case-sensitive key, which must
-- already be lowercase.
getHeaderUnsafe' :: HasHeaders r => ByteString -> r -> Maybe HeaderPair
getHeaderUnsafe' key var = M.lookup key (headers var)

-- Querying header status

-- | 'True' when 'getHeader' finds a value for the key. The lookup is
-- case insensitive.
hasHeader :: HasHeaders r => String -> r -> Bool
hasHeader key = isJust . getHeader key

-- | The 'ByteString' counterpart of 'hasHeader': 'True' when 'getHeaderBS'
-- finds a value for the key.
hasHeaderBS :: HasHeaders r => ByteString -> r -> Bool
hasHeaderBS key = isJust . getHeaderBS key

-- | Like 'hasHeaderBS', but the key is case sensitive and should already
-- be in lowercase. Only checks that an entry exists under the key.
hasHeaderUnsafe :: HasHeaders r => ByteString -> r -> Bool
hasHeaderUnsafe key = isJust . getHeaderUnsafe' key

-- | Case-insensitive test that a header has a given value: both the key
-- and the expected value are lowercased, then 'checkHeaderUnsafe' is used.
checkHeaderBS :: HasHeaders r => ByteString -> ByteString -> r -> Bool
checkHeaderBS key val = checkHeaderUnsafe (lower key) (lower val)
  where
    lower = P.map toLower

-- | Test whether the value returned by 'getHeaderUnsafe' for the key, once
-- lowercased, equals the expected value. Both the key and the expected
-- value must already be lowercase. 'False' if the header is absent.
checkHeaderUnsafe :: HasHeaders r => ByteString -> ByteString -> r -> Bool
checkHeaderUnsafe key val r =
    maybe False matches (getHeaderUnsafe key r)
  where
    matches found = P.map toLower found == val

-- Setting header status

-- | Set a header to a single value, replacing anything stored under the
-- same (lowercased) key. See 'setHeaderBS'.
setHeader :: HasHeaders r => String -> String -> r -> r
setHeader key = setHeaderBS (pack key) . pack

-- | The 'ByteString' counterpart of 'setHeader'. The map key is the
-- lowercased name; the stored 'HeaderPair' keeps the name as given and
-- holds exactly the one value.
setHeaderBS :: HasHeaders r => ByteString -> ByteString -> r -> r
setHeaderBS key val = setHeaderUnsafe mapKey entry
  where
    mapKey = P.map toLower key
    entry  = HeaderPair key [val]

-- | Store the 'HeaderPair' under the key, replacing any existing entry.
-- This is the only way to associate a key with multiple values via the
-- setHeader* functions. The key is not lowercased, and nothing checks that
-- it matches the name inside the 'HeaderPair'.
setHeaderUnsafe :: HasHeaders r => ByteString -> HeaderPair -> r -> r
setHeaderUnsafe key = updateHeaders . M.insert key

-- Adding headers

-- | Add a value for a header. If the (lowercased) key already has values,
-- the new value is combined with them rather than replacing them; see
-- 'addHeaderUnsafe' for the resulting order.
addHeader :: HasHeaders r => String -> String -> r -> r
addHeader key = addHeaderBS (pack key) . pack

-- | The 'ByteString' counterpart of 'addHeader'. The map key is the
-- lowercased name; the 'HeaderPair' passed on keeps the name as given.
addHeaderBS :: HasHeaders r => ByteString -> ByteString -> r -> r
addHeaderBS key val = addHeaderUnsafe mapKey entry
  where
    mapKey = P.map toLower key
    entry  = HeaderPair key [val]

-- | Add a 'HeaderPair' under the key. If an entry already exists, the two
-- are merged: the result keeps the name of the pair being added, and its
-- values are placed before the existing ones. The key is not lowercased,
-- and nothing checks that it matches the name inside the 'HeaderPair'.
addHeaderUnsafe :: HasHeaders r => ByteString -> HeaderPair -> r -> r
addHeaderUnsafe key val = updateHeaders (M.insertWith merge key val)
  where
    -- 'M.insertWith' passes the pair being added first
    merge added existing = HeaderPair (hName added) (hValue added ++ hValue existing)

-- | Build a 'Response' with the given status code whose body is the given
-- 'String', converted with 'LU.fromString'. See 'resultBS'.
result :: Int -> String -> Response
result code body = resultBS code (LU.fromString body)

-- | Like 'result', but takes the body as a lazy 'L.ByteString' directly.
-- The response starts with no headers, no validator, and 'nullRsFlags'
-- (so Transfer-Encoding: chunked is used by default).
resultBS :: Int -> L.ByteString -> Response
resultBS code s =
    Response { rsCode      = code
             , rsHeaders   = M.empty
             , rsFlags     = nullRsFlags
             , rsBody      = s
             , rsValidator = Nothing
             }

-- | Turn a 'Response' into a redirect: the status code becomes the given
-- 'Int' and the @Location@ header is set to the rendered URI.
redirect :: (ToSURI s) => Int -> s -> Response -> Response
redirect c s resp = setHeaderBS locationC target recoded
  where
    target  = pack (render (toSURI s))
    recoded = resp { rsCode = c }

-- constants here

-- | The header name @Location@.
locationC :: ByteString
locationC = pack "Location"

-- | The header value @close@.
closeC :: ByteString
closeC = pack "close"

-- | The header name @Connection@.
connectionC :: ByteString
connectionC = pack "Connection"

-- | The header value @Keep-Alive@.
keepaliveC :: ByteString
keepaliveC = pack "Keep-Alive"

-- | Parse an unsigned decimal number that must span the entire string.
-- Partial: calls 'error' when the string is not exactly one decimal number.
readDec' :: (Num a, Eq a) => String -> a
readDec' s
  | [(n, [])] <- readDec s = n
  | otherwise              = error "readDec' failed."

-- | 'reads' lifted into any monad: succeeds with the value when the whole
-- string parses unambiguously, otherwise calls 'fail'.
readM :: (Monad m, Read t) => String -> m t
readM s
  | [(v, "")] <- reads s = return v
  | otherwise            = fail "readM: parse error"

-- | Convert the result list of a 'ReadS' parser to a 'Maybe': 'Just' only
-- for a single parse that consumed all of the input.
fromReadS :: [(a, String)] -> Maybe a
fromReadS parses =
    case parses of
      [(n, "")] -> Just n
      _         -> Nothing

-- | This class is used by 'path' to parse a path component into a
-- value.
--
-- The instances for the integral types ('Int', 'Word', etc) parse a
-- decimal number with 'readDec'; the 'Float' and 'Double' instances use
-- 'readM'.
--
-- The instance for 'String', on the other hand, returns the
-- unmodified path component.
--
-- See the following section of the Happstack Crash Course for
-- detailed instructions using and extending 'FromReqURI':
--
--  <http://www.happstack.com/docs/crashcourse/RouteFilters.html#FromReqURI>

class FromReqURI a where
    fromReqURI :: String -> Maybe a

-- A path component is already a 'String'; it is returned unmodified.
instance FromReqURI String  where fromReqURI s = Just s
-- Only a component of exactly one character parses as a 'Char'.
instance FromReqURI Char    where
    fromReqURI [c] = Just c
    fromReqURI _   = Nothing
-- Signed integral types: decimal digits with an optional sign.
instance FromReqURI Int     where fromReqURI s = fromReadS (readSigned readDec s)
instance FromReqURI Int8    where fromReqURI s = fromReadS (readSigned readDec s)
instance FromReqURI Int16   where fromReqURI s = fromReadS (readSigned readDec s)
instance FromReqURI Int32   where fromReqURI s = fromReadS (readSigned readDec s)
instance FromReqURI Int64   where fromReqURI s = fromReadS (readSigned readDec s)
instance FromReqURI Integer where fromReqURI s = fromReadS (readSigned readDec s)
-- Unsigned integral types: decimal digits only.
instance FromReqURI Word    where fromReqURI s = fromReadS (readDec s)
instance FromReqURI Word8   where fromReqURI s = fromReadS (readDec s)
instance FromReqURI Word16  where fromReqURI s = fromReadS (readDec s)
instance FromReqURI Word32  where fromReqURI s = fromReadS (readDec s)
instance FromReqURI Word64  where fromReqURI s = fromReadS (readDec s)
-- Floating point types go through their 'Read' instances via 'readM'.
instance FromReqURI Float   where fromReqURI s = readM s
instance FromReqURI Double  where fromReqURI s = readM s
-- | Parses @0@ \/ @false@ as 'False' and @1@ \/ @true@ as 'True'. The
-- words are matched case-insensitively; anything else is 'Nothing'.
instance FromReqURI Bool    where
  fromReqURI s =
    -- The input is lowercased first, so every pattern below must be
    -- lowercase too (the old "True" pattern could never match).
    case map toLower s of
      "0"     -> Just False
      "false" -> Just False
      "1"     -> Just True
      "true"  -> Just True
      _       -> Nothing