module Happstack.Server.Internal.Types
    (Request(..), Response(..), RqBody(..), Input(..), HeaderPair(..),
     takeRequestBody, readInputsBody,
     rqURL, mkHeaders,
     getHeader, getHeaderBS, getHeaderUnsafe,
     hasHeader, hasHeaderBS, hasHeaderUnsafe,
     setHeader, setHeaderBS, setHeaderUnsafe,
     addHeader, addHeaderBS, addHeaderUnsafe,
     setRsCode, 
     LogAccess, logMAccess, Conf(..), nullConf, result, resultBS,
     redirect, 
     isHTTP1_0, isHTTP1_1,
     RsFlags(..), nullRsFlags, contentLength, chunked, noContentLength,
     HttpVersion(..), Length(..), Method(..), canHaveBody, Headers, continueHTTP,
     Host, ContentType(..),
     readDec', fromReadS, readM, FromReqURI(..),
     showRsValidator, EscapeHTTP(..)
    ) where
import Control.Exception (Exception, SomeException)
import Control.Monad.Error (Error(strMsg))
import Control.Monad.Trans (MonadIO(liftIO))
import qualified Control.Concurrent.Thread.Group as TG
import Control.Concurrent.MVar
import qualified Data.Map as M
import Data.Data (Data)
import Data.String (fromString)
import Data.Time.Format (FormatTime(..))
import Data.Typeable(Typeable)
import qualified Data.ByteString.Char8 as P
import Data.ByteString.Char8 (ByteString,pack)
import qualified Data.ByteString.Lazy.Char8 as L
import qualified Data.ByteString.Lazy.UTF8  as LU (fromString)
import Data.Int   (Int8, Int16, Int32, Int64)
import Data.Maybe
import Data.List
import Data.Word  (Word, Word8, Word16, Word32, Word64)
import qualified Data.Text as Text
import qualified Data.Text.Lazy as Lazy
import Happstack.Server.SURI
import Data.Char (toLower)
import Happstack.Server.Internal.RFC822Headers ( ContentType(..) )
import Happstack.Server.Internal.Cookie
import Happstack.Server.Internal.LogFormat (formatRequestCombined)
import Happstack.Server.Internal.TimeoutIO (TimeoutIO)
import Numeric (readDec, readSigned)
import System.Log.Logger (Priority(..), logM)
-- | HTTP protocol version, stored as @HttpVersion major minor@
-- (e.g. @HttpVersion 1 1@ for HTTP\/1.1).
data HttpVersion = HttpVersion Int Int
             deriving(Read,Eq)
-- | Renders the version as @major.minor@, e.g. @1.1@.
instance Show HttpVersion where
  showsPrec _ (HttpVersion major minor) = shows major . showChar '.' . shows minor
-- | Does the request declare protocol version HTTP\/1.1 (exactly)?
isHTTP1_1 :: Request -> Bool
isHTTP1_1 = (== HttpVersion 1 1) . rqVersion
-- | Does the request declare protocol version HTTP\/1.0 (exactly)?
isHTTP1_0 :: Request -> Bool
isHTTP1_0 = (== HttpVersion 1 0) . rqVersion
-- | Decide whether the connection may stay open after sending @rs@ in
-- reply to @rq@.
--
-- * HTTP\/1.0: only if the client sent @Connection: Keep-Alive@ and the
--   response length is 'ContentLength' (or the response carries no body).
--
-- * HTTP\/1.1: unless the client sent @Connection: close@, as long as the
--   response length is not 'NoContentLength' (or it carries no body).
--
-- Any other protocol version never continues. The @Connection@ header value
-- is compared case-insensitively against the whole value.
continueHTTP :: Request -> Response -> Bool
continueHTTP rq rs =
    (isHTTP1_0 rq && checkHeaderBS connectionC keepaliveC rq   &&
       (rsfLength (rsFlags rs) == ContentLength || isNoMessageBodyResponse rs)) ||
    (isHTTP1_1 rq && not (checkHeaderBS connectionC closeC rq) &&
       (rsfLength (rsFlags rs) /= NoContentLength || isNoMessageBodyResponse rs))
  where
    -- Status codes that never carry a message body: 1xx, 204 and 304.
    isNoMessageBodyCode code = (code >= 100 && code <= 199) || code == 204 || code == 304
    -- Match on the constructor instead of using the partial 'rsBody'
    -- selector, which would crash for a 'SendFile' response. A 'SendFile'
    -- is treated as bodiless only when it sends zero bytes.
    isNoMessageBodyResponse Response{rsCode = code, rsBody = body} =
        isNoMessageBodyCode code && L.null body
    isNoMessageBodyResponse SendFile{rsCode = code, sfCount = count} =
        isNoMessageBodyCode code && count == 0
-- | Signature of an access-log callback, parameterised over the timestamp
-- type @time@. Arguments are passed in the order
-- @host user time requestLine responseCode size referer userAgent@
-- (the parameter names used by 'logMAccess').
type LogAccess time =
    (   String   -- host
     -> String   -- user
     -> time     -- time
     -> String   -- request line
     -> Int      -- response code
     -> Integer  -- size
     -> String   -- referer
     -> String   -- user agent
     -> IO ())
-- | Server configuration. See 'nullConf' for the default values.
data Conf = Conf
    { port        :: Int             -- ^ port to listen on ('nullConf' uses 8000)
    , validator   :: Maybe (Response -> IO Response) -- ^ optional response validator ('nullConf' uses 'Nothing')
    , logAccess   :: forall t. FormatTime t => Maybe (LogAccess t) -- ^ access-log callback; presumably 'Nothing' disables access logging ('nullConf' uses 'logMAccess')
    , timeout     :: Int             -- ^ timeout value ('nullConf' uses 30; units not visible here — presumably seconds)
    , threadGroup :: Maybe TG.ThreadGroup -- ^ optional thread group ('nullConf' uses 'Nothing'); presumably used to track handler threads — TODO confirm
    }
-- | Default configuration:
-- port 8000, timeout 30, access logging via 'logMAccess', no response
-- validator and no thread group.
nullConf :: Conf
nullConf = Conf
    { port        = 8000
    , timeout     = 30
    , logAccess   = Just logMAccess
    , validator   = Nothing
    , threadGroup = Nothing
    }
-- | Default 'LogAccess' implementation. It formats the entry with
-- 'formatRequestCombined' and emits it at 'INFO' priority to the
-- @Happstack.Server.AccessLog.Combined@ logger.
logMAccess :: forall t. FormatTime t => LogAccess t
logMAccess host user time requestLine responseCode size referer userAgent =
    logM accessLogger INFO entry
  where
    accessLogger = "Happstack.Server.AccessLog.Combined"
    entry = formatRequestCombined host user time requestLine responseCode size referer userAgent
-- | HTTP request methods. 'EXTENSION' carries the name of any method not
-- listed here (presumably the raw token from the request line).
data Method = GET | HEAD | POST | PUT | DELETE | TRACE | OPTIONS | CONNECT | PATCH | EXTENSION ByteString
    deriving (Show,Read,Eq,Ord,Typeable,Data)
-- | Whether a request with this method may carry a body.
-- 'POST', 'PUT', 'PATCH', 'DELETE' and every 'EXTENSION' method may;
-- all other methods may not.
canHaveBody :: Method
            -> Bool
canHaveBody method =
    case method of
      POST        -> True
      PUT         -> True
      PATCH       -> True
      DELETE      -> True
      EXTENSION _ -> True
      _           -> False
-- | One header entry: the name as supplied by the caller plus all values
-- recorded for it.
data HeaderPair = HeaderPair
    { hName :: ByteString     -- ^ header name with its original capitalisation
    , hValue :: [ByteString]  -- ^ header values, oldest first (new values are appended by 'mkHeaders' / 'addHeader')
    }
    deriving (Read,Show)
-- | Header map. The key is the header name lowercased (as done by
-- 'mkHeaders', 'setHeaderBS' and 'addHeaderBS'); the value keeps the
-- original spelling in 'hName'.
type Headers = M.Map ByteString HeaderPair 
-- | How the length of a response body is conveyed to the client.
data Length
    = ContentLength             -- ^ length announced up front (Content-Length)
    | TransferEncodingChunked   -- ^ body sent with chunked transfer encoding
    | NoContentLength           -- ^ no length information; 'continueHTTP' will not keep such a connection open unless the response has no body
      deriving (Eq, Ord, Read, Show, Enum)
-- | Per-response transmission flags.
data RsFlags = RsFlags
    { rsfLength :: Length -- ^ how the body length is conveyed; see 'Length'
    } deriving (Show,Read,Typeable)
-- | Default flags: responses are sent with chunked transfer encoding.
nullRsFlags :: RsFlags
nullRsFlags = RsFlags TransferEncodingChunked
-- | Mark the response as carrying no length information ('NoContentLength').
noContentLength :: Response -> Response
noContentLength = withLength NoContentLength

-- | Mark the response to be sent with chunked transfer encoding.
chunked :: Response -> Response
chunked = withLength TransferEncodingChunked

-- | Mark the response to be sent with a content length ('ContentLength').
contentLength :: Response -> Response
contentLength = withLength ContentLength

-- | Replace only the 'rsfLength' flag of a response; all other fields,
-- including the remaining flags, are kept.
withLength :: Length -> Response -> Response
withLength len res = res { rsFlags = (rsFlags res) { rsfLength = len } }
-- | A single form or query input.
data Input = Input
    { inputValue       :: Either FilePath L.ByteString -- ^ the value: presumably 'Left' is a path to a file holding it (e.g. an upload) and 'Right' is the value in memory — TODO confirm
    , inputFilename    :: Maybe FilePath               -- ^ file name supplied with the input, if any
    , inputContentType :: ContentType                  -- ^ content type of the input
    } deriving (Show, Read, Typeable)
-- | A network endpoint; presumably (host name or address, port).
type Host = (String, Int) 
-- | An HTTP response. 'Response' carries its body in memory; 'SendFile'
-- instead names a file region to send (presumably streamed from disk).
-- 'rsBody' exists only on 'Response' and the @sf*@ fields only on
-- 'SendFile', so those selectors are partial.
data Response
    = Response  { rsCode      :: Int                             -- ^ HTTP status code
                , rsHeaders   :: Headers                         -- ^ response headers
                , rsFlags     :: RsFlags                         -- ^ transmission flags (how the length is conveyed)
                , rsBody      :: L.ByteString                    -- ^ response body
                , rsValidator :: Maybe (Response -> IO Response) -- ^ optional validator for this response
                }
    | SendFile  { rsCode      :: Int                             -- ^ HTTP status code
                , rsHeaders   :: Headers                         -- ^ response headers
                , rsFlags     :: RsFlags                         -- ^ transmission flags (how the length is conveyed)
                , rsValidator :: Maybe (Response -> IO Response) -- ^ optional validator for this response
                , sfFilePath  :: FilePath  -- ^ file to send
                , sfOffset    :: Integer   -- ^ presumably the byte offset in the file where sending starts
                , sfCount     :: Integer    -- ^ presumably the number of bytes to send
                }
      deriving (Typeable)
-- | Multi-line, human-readable dump of a 'Response' for debugging.
-- The precedence argument is ignored. A validator function cannot be shown,
-- so 'showRsValidator' only reports whether one is present.
instance Show Response where
    -- In-memory response: includes the body.
    showsPrec _ res@Response{}  =
        showString   "================== Response ================"                    .
        showString "\nrsCode      = " . shows      (rsCode res)                        .
        showString "\nrsHeaders   = " . shows      (rsHeaders res)                     .
        showString "\nrsFlags     = " . shows      (rsFlags res)                       .
        showString "\nrsBody      = " . shows      (rsBody res)                        .
        showString "\nrsValidator = " . shows      (showRsValidator (rsValidator res))
    -- File response: no body; shows the file path, offset and count instead.
    showsPrec _ res@SendFile{}  =
        showString   "================== Response ================"                    .
        showString "\nrsCode      = " . shows      (rsCode res)                        .
        showString "\nrsHeaders   = " . shows      (rsHeaders res)                     .
        showString "\nrsFlags     = " . shows      (rsFlags res)                       .
        showString "\nrsValidator = " . shows      (showRsValidator (rsValidator res)) .
        showString "\nsfFilePath  = " . shows      (sfFilePath res)                    .
        showString "\nsfOffset    = " . shows      (sfOffset res)                      .
        showString "\nsfCount     = " . shows      (sfCount res)
-- | Describe an optional validator without showing the function itself:
-- @\"Nothing\"@ or @\"Just <function>\"@.
showRsValidator :: Maybe (Response -> IO Response) -> String
showRsValidator mValidator =
    case mValidator of
      Nothing -> "Nothing"
      Just _  -> "Just <function>"
-- | Turn an error message into a plain-text (UTF-8) 500 response whose
-- body is the message.
instance Error Response where
  strMsg = setHeader "Content-Type" "text/plain; charset=UTF-8" . result 500
-- | An incoming HTTP request.
data Request = Request
    { rqSecure      :: Bool                  -- ^ presumably 'True' when the request arrived over TLS — TODO confirm
    , rqMethod      :: Method                -- ^ request method
    , rqPaths       :: [String]              -- ^ path segments; 'rqURL' joins them with @\/@
    , rqUri         :: String                -- ^ request URI as a string
    , rqQuery       :: String                -- ^ query string, appended verbatim by 'rqURL' (presumably including the leading @?@)
    , rqInputsQuery :: [(String,Input)]      -- ^ inputs taken from the query string
    , rqInputsBody  :: MVar [(String,Input)] -- ^ inputs decoded from the body; empty until filled (see 'readInputsBody')
    , rqCookies     :: [(String,Cookie)]     -- ^ request cookies by name
    , rqVersion     :: HttpVersion           -- ^ protocol version (see 'isHTTP1_0', 'isHTTP1_1')
    , rqHeaders     :: Headers               -- ^ request headers, keyed by lowercased name
    , rqBody        :: MVar RqBody           -- ^ raw body; 'takeRequestBody' empties this MVar
    , rqPeer        :: Host                  -- ^ remote peer
    } deriving (Typeable)
-- | Multi-line, human-readable dump of a 'Request' for debugging.
-- The precedence argument is ignored. The two 'MVar' fields cannot be read
-- purely, so they are printed as the placeholder @<<mvar>>@.
instance Show Request where
    showsPrec _ rq =
        showString   "================== Request =================" .
        showString "\nrqSecure      = " . shows      (rqSecure rq) .
        showString "\nrqMethod      = " . shows      (rqMethod rq) .
        showString "\nrqPaths       = " . shows      (rqPaths rq) .
        -- URI and query are written raw (no surrounding quotes).
        showString "\nrqUri         = " . showString (rqUri rq) .
        showString "\nrqQuery       = " . showString (rqQuery rq) .
        showString "\nrqInputsQuery = " . shows      (rqInputsQuery rq) .
        showString "\nrqInputsBody  = " . showString "<<mvar>>" .
        showString "\nrqCookies     = " . shows      (rqCookies rq) .
        showString "\nrqVersion     = " . shows      (rqVersion rq) .
        showString "\nrqHeaders     = " . shows      (rqHeaders rq) .
        showString "\nrqBody        = " . showString "<<mvar>>" .
        showString "\nrqPeer        = " . shows      (rqPeer rq)
-- | Remove and return the request body. Returns 'Nothing' if the body MVar
-- is already empty (e.g. a previous call took it); never blocks.
takeRequestBody :: (MonadIO m) => Request -> m (Maybe RqBody)
takeRequestBody = liftIO . tryTakeMVar . rqBody
-- | Read the body inputs stored in 'rqInputsBody' without consuming them.
--
-- Returns 'Nothing' if the MVar is currently empty (presumably because the
-- body has not been decoded yet), otherwise @'Just' inputs@. It never
-- blocks and leaves the MVar's contents unchanged.
--
-- 'tryReadMVar' is atomic. The previous take-then-put sequence briefly
-- emptied the MVar, so a concurrent reader could spuriously see 'Nothing'.
-- An asynchronous exception in between could also leave it permanently
-- empty. Requires base >= 4.7.
readInputsBody :: Request -> IO (Maybe [(String, Input)])
readInputsBody req = tryReadMVar (rqInputsBody req)
-- | Rebuild the request URL from its parts: a leading @\/@, the path
-- segments joined with @\/@, then 'rqQuery' appended verbatim.
-- Path segments are not re-escaped.
rqURL :: Request -> String
rqURL rq = concat ["/", intercalate "/" (rqPaths rq), rqQuery rq]
-- | Values that carry a 'Headers' map, so the header helpers work
-- uniformly on requests, responses and bare header maps.
class HasHeaders a where
    -- | Apply a transformation to the contained headers.
    updateHeaders :: (Headers->Headers) -> a -> a
    -- | Extract the contained headers.
    headers       :: a -> Headers
instance HasHeaders Response where
    updateHeaders f rs = rs { rsHeaders = f (rsHeaders rs) }
    headers rs         = rsHeaders rs
instance HasHeaders Request where
    updateHeaders f rq = rq { rqHeaders = f (rqHeaders rq) }
    headers rq         = rqHeaders rq
instance HasHeaders Headers where
    updateHeaders f hs = f hs
    headers hs         = hs
-- | The raw, lazily read request body.
newtype RqBody = Body { unBody :: L.ByteString } deriving (Read,Show,Typeable)
-- | Replace the status code of a response, returning the result in any monad.
setRsCode :: (Monad m) => Int -> Response -> m Response
setRsCode code rs =
    let updated = rs { rsCode = code }
    in return updated
-- | Build a header map from name\/value pairs. Names are keyed
-- case-insensitively (lowercased). Repeated names are merged into one entry
-- whose values keep their input order and whose 'hName' uses the spelling
-- of the last occurrence.
mkHeaders :: [(String,String)] -> Headers
mkHeaders hdrs = M.fromListWith merge (map toEntry hdrs)
    where
      toEntry (name, value) =
          (P.pack (map toLower name), HeaderPair (P.pack name) [P.pack value])
      -- 'M.fromListWith' passes the newer entry first; append its values
      -- after the ones already collected.
      merge (HeaderPair name newer) (HeaderPair _ older) = HeaderPair name (older ++ newer)
-- | First value of the named header, matching the name case-insensitively.
getHeader :: HasHeaders r => String -> r -> Maybe ByteString
getHeader key = getHeaderBS (pack key)

-- | 'getHeader' for a 'ByteString' name; the name is lowercased before lookup.
getHeaderBS :: HasHeaders r => ByteString -> r -> Maybe ByteString
getHeaderBS key = getHeaderUnsafe (P.map toLower key)

-- | First value of the header stored under @key@, which must already be
-- lowercased (no normalisation is done here).
getHeaderUnsafe :: HasHeaders r => ByteString -> r -> Maybe ByteString
getHeaderUnsafe key var = getHeaderUnsafe' key var >>= listToMaybe . hValue

-- | Raw 'HeaderPair' stored under an already-lowercased @key@.
getHeaderUnsafe' :: HasHeaders r => ByteString -> r -> Maybe HeaderPair
getHeaderUnsafe' key r = M.lookup key (headers r)
-- | Is the named header present with at least one value?
-- The name is matched case-insensitively.
hasHeader :: HasHeaders r => String -> r -> Bool
hasHeader key = isJust . getHeader key

-- | 'hasHeader' for a 'ByteString' name.
hasHeaderBS :: HasHeaders r => ByteString -> r -> Bool
hasHeaderBS key = isJust . getHeaderBS key

-- | Is there an entry under the already-lowercased @key@? Unlike the other
-- variants this is 'True' even when the entry holds no values.
hasHeaderUnsafe :: HasHeaders r => ByteString -> r -> Bool
hasHeaderUnsafe key = isJust . getHeaderUnsafe' key
-- | Does the named header's first value equal @val@? Both the name and the
-- value are compared case-insensitively.
checkHeaderBS :: HasHeaders r => ByteString -> ByteString -> r -> Bool
checkHeaderBS key val = checkHeaderUnsafe (P.map toLower key) (P.map toLower val)

-- | Like 'checkHeaderBS', but @key@ and @val@ must already be lowercased;
-- only the stored value is lowercased before comparing.
checkHeaderUnsafe :: HasHeaders r => ByteString -> ByteString -> r -> Bool
checkHeaderUnsafe key val r =
    maybe False (\stored -> P.map toLower stored == val) (getHeaderUnsafe key r)
-- | Set a header to a single value, replacing any existing values.
-- The name is stored lowercased as the key, with its spelling kept in 'hName'.
setHeader :: HasHeaders r => String -> String -> r -> r
setHeader key val = setHeaderBS (pack key) (pack val)

-- | 'setHeader' for 'ByteString' name and value.
setHeaderBS :: HasHeaders r => ByteString -> ByteString -> r -> r
setHeaderBS key val =
    setHeaderUnsafe (P.map toLower key) (HeaderPair { hName = key, hValue = [val] })

-- | Store @val@ under the already-lowercased @key@, overwriting any entry.
setHeaderUnsafe :: HasHeaders r => ByteString -> HeaderPair -> r -> r
setHeaderUnsafe key val = updateHeaders (\hs -> M.insert key val hs)
-- | Append a value to a header, creating it if absent. Existing values stay
-- first; the header name is matched case-insensitively.
addHeader :: HasHeaders r => String -> String -> r -> r
addHeader key val = addHeaderBS (pack key) (pack val)

-- | 'addHeader' for 'ByteString' name and value.
addHeaderBS :: HasHeaders r => ByteString -> ByteString -> r -> r
addHeaderBS key val =
    addHeaderUnsafe (P.map toLower key) (HeaderPair { hName = key, hValue = [val] })

-- | Merge @val@ into the entry under the already-lowercased @key@.
-- The new pair's name is kept and its values go after the existing ones.
addHeaderUnsafe :: HasHeaders r => ByteString -> HeaderPair -> r -> r
addHeaderUnsafe key val = updateHeaders (M.insertWith merge key val)
    where
      -- 'M.insertWith' passes the new pair first and the existing one second.
      merge (HeaderPair name newer) (HeaderPair _ older) = HeaderPair name (older ++ newer)
-- | Build a response with the given status code whose body is the string
-- encoded as UTF-8.
result :: Int -> String -> Response
result code str = resultBS code (LU.fromString str)

-- | Build a response with the given status code and body, no headers,
-- default flags ('nullRsFlags') and no validator.
resultBS :: Int -> L.ByteString -> Response
resultBS code body =
    Response { rsCode      = code
             , rsHeaders   = M.empty
             , rsFlags     = nullRsFlags
             , rsBody      = body
             , rsValidator = Nothing
             }
-- | Turn a response into a redirect: set its status code and point the
-- @Location@ header at the rendered target URI.
redirect :: (ToSURI s) => Int -> s -> Response -> Response
redirect code target resp =
    let location = pack (render (toSURI target))
    in setHeaderBS locationC location resp { rsCode = code }
-- | Header names and header values used when building and inspecting
-- requests and responses.
locationC, closeC, connectionC, keepaliveC :: ByteString
locationC   = P.pack "Location"
closeC      = P.pack "close"
connectionC = P.pack "Connection"
keepaliveC  = P.pack "Keep-Alive"
-- | Parse an unsigned decimal number that must span the whole string;
-- calls 'error' otherwise.
readDec' :: (Num a, Eq a) => String -> a
readDec' input
  | [(value, "")] <- readDec input = value
  | otherwise                      = error "readDec' failed."
-- | 'read' lifted into a monad. Succeeds only if there is exactly one parse
-- and it consumes the whole string; otherwise calls 'fail' with
-- @\"readM: parse error\"@.
readM :: (Monad m, Read t) => String -> m t
readM s =
    case reads s of
      [(parsed, rest)] | null rest -> return parsed
      _                            -> fail "readM: parse error"
-- | Extract the value from a 'ReadS' result only when there is exactly one
-- parse and it consumed the whole input.
fromReadS :: [(a, String)] -> Maybe a
fromReadS parses =
    case parses of
      [(value, rest)] | null rest -> Just value
      _                           -> Nothing
-- | Types that can be parsed from a single URI component (a path segment or
-- query value). 'Nothing' means the string is not a valid representation.
class FromReqURI a where
    fromReqURI :: String -> Maybe a

-- | Parse a decimal number as an 'Integer' and convert it to a bounded
-- integral type. Returns 'Nothing' when the value is out of range instead
-- of silently wrapping around. Parsing directly in the target type would
-- turn @\"300\"@ into @44 :: Int8@.
--
-- When @signed@ is 'True' a leading minus sign is accepted ('readSigned');
-- otherwise only unsigned digits are ('readDec'). The whole string must be
-- consumed.
readBoundedDec :: (Integral a, Bounded a) => Bool -> String -> Maybe a
readBoundedDec signed str = fromReadS (parser str) >>= fitsIn
  where
    parser | signed    = readSigned readDec
           | otherwise = readDec
    fitsIn n =
        let v = fromInteger n
        in if n >= toInteger (minBound `asTypeOf` v) && n <= toInteger (maxBound `asTypeOf` v)
             then Just v
             else Nothing

instance FromReqURI String  where fromReqURI = Just
instance FromReqURI Text.Text where fromReqURI = fmap fromString . fromReqURI
instance FromReqURI Lazy.Text where fromReqURI = fmap fromString . fromReqURI
instance FromReqURI Char    where fromReqURI s = case s of [c] -> Just c ; _ -> Nothing
-- Bounded signed integers: reject values that do not fit.
instance FromReqURI Int     where fromReqURI = readBoundedDec True
instance FromReqURI Int8    where fromReqURI = readBoundedDec True
instance FromReqURI Int16   where fromReqURI = readBoundedDec True
instance FromReqURI Int32   where fromReqURI = readBoundedDec True
instance FromReqURI Int64   where fromReqURI = readBoundedDec True
-- Unbounded: no range check needed.
instance FromReqURI Integer where fromReqURI = fromReadS . readSigned readDec
-- Bounded unsigned integers: no minus sign accepted; reject values that do not fit.
instance FromReqURI Word    where fromReqURI = readBoundedDec False
instance FromReqURI Word8   where fromReqURI = readBoundedDec False
instance FromReqURI Word16  where fromReqURI = readBoundedDec False
instance FromReqURI Word32  where fromReqURI = readBoundedDec False
instance FromReqURI Word64  where fromReqURI = readBoundedDec False
instance FromReqURI Float   where fromReqURI = readM
instance FromReqURI Double  where fromReqURI = readM
-- | Accepts @0@\/@false@ and @1@\/@true@, case-insensitively.
instance FromReqURI Bool    where
  fromReqURI s =
    let s' = map toLower s in
    case s' of
      "0"     -> Just False
      "false" -> Just False
      "1"     -> Just True
      "true"  -> Just True
      _       -> Nothing
-- | Exception carrying an action that receives the connection's 'TimeoutIO'.
-- The action cannot be shown, so 'show' prints the placeholder
-- @<EscapeHTTP _>@.
data EscapeHTTP
  = EscapeHTTP (TimeoutIO -> IO ())
    deriving (Typeable)
instance Exception EscapeHTTP
instance Show EscapeHTTP where
  showsPrec _ (EscapeHTTP _) = showString "<EscapeHTTP _>"