{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- | Combinators for describing HTTP requests and responses as specs that can
-- be used both for parsing incoming messages and for generating outgoing
-- ones, together with the error types reported when parsing fails.
module Data.HttpSpec
    (ReqSpec, ResSpec, HttpSpec
    ,WebComm(..), WebIn(..), WebOut(..), WebExc(..), WebErr(..)
    ,HasReqSpec(..), HasResSpec(..), TextEncoding
    ,rsHeader, rsHeaderFixed, rsParam, rsMeth, rsStatus
    ,rsXmlString, rsXml, rsValidXml
    ,rsPath, rsPathFixed, rsWithBody, rsBody, rsContentType
    ,rsPathSegment, rsXmlEncoding, rsTextEncoding, rsEncodingFixed
    ,genReqOut, genResOut, parseReqIn, parseResIn, rsXmlBody
    ,webExcSetReqIn, webExcSetResIn, webExcSetReqOut, webExcSetResOut
    ) where

import Prelude hiding (exp)

import Control.Monad (liftM)
import Control.Monad.Reader (asks,local)
import Control.Monad.Error (MonadError(..), Error(..))
import Data.List (isPrefixOf)

import qualified Network.HTTP as Http
import qualified Network.URI as Uri

import qualified Data.ByteString.Lazy as BSL

import Data.Encoding (Encoding, DynEncoding
                     ,encodeLazyByteString, decodeLazyByteStringExplicit
                     ,encodingFromStringExplicit)

import Text.XML.HXT.Arrow (PU)

import Data.BidiSpec

import Data.HttpSpec.EncodingHelper (encodingName)
-- NOTE(review): the first lines of this import list were missing and have been
-- reconstructed from the names this module uses; confirm against the export
-- list of Data.HttpSpec.HttpTypes.
import Data.HttpSpec.HttpTypes
    (HttpData(..), HttpHeaderName, HttpHeaderValue, HttpMethod
    ,ReqIn(..), ReqOut(..), ResIn(..), ResOut(..)
    ,reqAddUrlParam, reqAppendUrlPath, reqSetUrlPath
    ,HttpUrl, HttpParamName, HttpParamValue
    ,IsHttp(..), IsReq(..), IsRes(..)
    ,urlParams, urlMatchPrefix, urlSplit)
import Text.XML.HXT.Helper
    (XmlValidator,XmlEncoding, pickleStr, pickleWithEnc, unpickle, unpickleStr
    ,validateAndUnpickle
    ,xmlEncodingFromString, xmlEncodingToString, _UTF8_)
import Data.HttpSpec.Pretty (Pretty(..))

-- ----------------------------------------------------------------------------
--  Spec types for request and response
-- ----------------------------------------------------------------------------

-- | Types that come with a canonical request specification.
class HasReqSpec a where
    -- | The 'ReqSpec' associated with values of type @a@.
    reqSpec :: ReqSpec a

-- | Types that come with a canonical response specification.
class HasResSpec a where
    -- | The 'ResSpec' associated with values of type @a@.
    resSpec :: ResSpec a

-- | Alias for 'DynEncoding' from "Data.Encoding"; the value type produced by
-- 'rsTextEncoding'.
type TextEncoding = DynEncoding

-- | The HTTP exchange attached to a 'WebExc': either the 'WebIn' view
-- (a 'ReqIn' paired with a 'ResOut') or the 'WebOut' view (a 'ReqOut' paired
-- with a 'ResIn').
data WebComm
    = WebCommIn WebIn
    | WebCommOut WebOut
      deriving (Show)

-- | An exchange consisting of a received request and the response sent back.
-- Either part may be unknown ('Nothing').
data WebIn
    = WebIn
      { webIn_req :: Maybe ReqIn   -- ^ set by 'webExcSetReqIn'
      , webIn_res :: Maybe ResOut  -- ^ set by 'webExcSetResOut'
      } deriving (Show)

-- | An exchange consisting of a sent request and the response received for
-- it. Either part may be unknown ('Nothing').
data WebOut
    = WebOut
      { webOut_req :: Maybe ReqOut  -- ^ set by 'webExcSetReqOut'
      , webOut_res :: Maybe ResIn   -- ^ set by 'webExcSetResIn'
      } deriving (Show)

-- | The error type of all specs in this module: a 'WebErr' plus, optionally,
-- the HTTP exchange during which it occurred.
data WebExc
    = WebExc
      { webExc_comm :: Maybe WebComm  -- ^ 'Nothing' until one of the @webExcSet*@ functions fills it in
      , webExc_err :: WebErr          -- ^ the actual error
      } deriving (Show)

-- | The individual failures reported by the combinators of this module.
-- Trailing 'String' fields carry an explanatory message.
data WebErr
    = WebErrMissingParam String
      -- ^ a URL parameter with the given name was not found (see 'rsParam')
    | WebErrMissingHeader HttpHeaderName
      -- ^ the given header was not found
    | WebErrInvalidHeaderValue HttpHeaderName HttpHeaderValue String
      -- ^ header name, the offending value, and an explanation
    | WebErrInvalidMethod HttpMethod String
      -- ^ the offending method and an explanation (see 'rsMeth')
    | WebErrInvalidStatus Int String
      -- ^ the offending status code and an explanation (see 'rsStatus')
    | WebErrInvalidUrl {- expected: -} String {- actual: -} String
    | WebErrMissingContentType
    | WebErrUnexpectedContentType {- expected: -} String {- actual: -} String
    | WebErrEmptyContent
    | WebErrNoMatch ReqIn
      -- ^ no handler matched the URL and method of the given request
    | WebErrNotImplemented String
    | WebErrCustomMsg String
      -- ^ free-form message; also used by the 'Error' instances
      deriving (Show)

-- | Messages without further structure become 'WebErrCustomMsg'.
instance Error WebErr where
    strMsg msg = WebErrCustomMsg msg
    noMsg = strMsg "HttpSpec: unknown error."

-- | Messages without further structure become a 'WebErrCustomMsg' that is not
-- attached to any HTTP exchange.
instance Error WebExc where
    strMsg = mkErr . WebErrCustomMsg
    noMsg = strMsg "HttpSpec: unknown error."

-- | Only the wrapped 'WebErr' is printed; the attached exchange is omitted.
instance Pretty WebExc where
    ppr = ppr . webExc_err

-- | Human-readable rendering of every 'WebErr' constructor.
instance Pretty WebErr where
    pprString err =
        case err of
          WebErrMissingParam p ->
              "Missing parameter `" ++ p ++ "'"
          WebErrMissingHeader h ->
              -- This constructor is about headers, not URL parameters.
              "Missing header `" ++ show h ++ "'"
          WebErrInvalidHeaderValue n v s ->
              "Invalid value " ++ show v ++ " for header `" ++ show n ++
              "': " ++ s
          WebErrInvalidMethod method s ->
              "Invalid HTTP method " ++ show method ++ ": " ++ s
          WebErrInvalidStatus stat s ->
              "Invalid HTTP status " ++ show stat ++ ": " ++ s
          WebErrInvalidUrl expected actual ->
              "Invalid URL, expected " ++ expected ++ ", given " ++ actual
          WebErrMissingContentType ->
              "Content type missing"
          WebErrUnexpectedContentType expected actual ->
              "Unexpected content type, expected " ++ expected ++
              ", given " ++ actual
          WebErrEmptyContent ->
              "No content given"
          WebErrNoMatch req ->
              "No matching URL for " ++ show (reqIn_fullUrl req) ++
              ", method " ++ show (reqIn_method req)
          WebErrNotImplemented s ->
              "Functionality not yet implemented: " ++ s
          WebErrCustomMsg s ->
              -- A custom message is already meant for display.
              s

-- Request, response and generic HTTP specs all report errors as 'WebExc';
-- the three names only document the context an error belongs to.
type ReqErr = WebExc
type ResErr = WebExc
type HttpErr = WebExc

-- | A 'SpecParser' over input @i@ that fails with 'HttpErr'.
type HttpSpecParser i a = SpecParser i HttpErr a

-- | Spec of a request: parsed from a 'ReqIn', generated into a 'ReqOut'.
type ReqSpec = Spec ReqErr ReqIn ReqOut
-- | Spec of a response: parsed from a 'ResIn', generated into a 'ResOut'.
type ResSpec = Spec ResErr ResIn ResOut
-- | Spec over arbitrary input and output types with 'HttpErr' as error type.
type HttpSpec = Spec HttpErr

-- ----------------------------------------------------------------------------
--  helper functions
-- ----------------------------------------------------------------------------
-- | Look up the header @name@ in the message being parsed. A missing header
-- is turned into 'WebErrMissingHeader' by way of 'spFromMaybe'.
spGetHeader :: IsHttp h => HttpHeaderName -> HttpSpecParser h HttpHeaderValue
spGetHeader name = do
    found <- asks (httpGetHeader name)
    spFromMaybe (mkErr (WebErrMissingHeader name)) found

-- | Wrap a 'WebErr' into a 'WebExc' that has no exchange attached yet.
mkErr :: WebErr -> WebExc
mkErr = WebExc Nothing

-- ----------------------------------------------------------------------------
--  HttpSpec combinators
-- ----------------------------------------------------------------------------

-- | Build a spec from a function that expects the body spec, supplying
-- 'rsBody' for it through 'rsWith'.
rsWithBody :: (IsHttp i, IsHttp o) =>
              (HttpSpec i o BSL.ByteString -> HttpSpec i o a)
           -> HttpSpec i o a
rsWithBody useBody = useBody `rsWith` rsBody

-- | Spec for the raw message body: read with 'httpBody', written with
-- 'httpSetBody'.
rsBody :: (IsHttp i, IsHttp o) => HttpSpec i o BSL.ByteString
rsBody = rsGetSet httpBody (\msg body -> httpSetBody body msg)

-- | Spec for the value of the header @n@.
--
-- Parsing looks the header up with 'spGetHeader' (a missing header yields
-- 'WebErrMissingHeader'); generating stores the value with 'httpSetHeader'.
rsHeader :: (IsHttp i, IsHttp o) =>
            HttpHeaderName
         -> HttpSpec i o HttpHeaderValue
rsHeader n = mkSpec (spGetHeader n) (flip $ httpSetHeader n)

-- | Require the header @n@ to have exactly the value @v@.
--
-- The check reads the header with 'spGetHeader' and compares it to @v@; any
-- other value is reported as 'WebErrInvalidHeaderValue'. The setter passed to
-- 'rsCheckSet' writes @n@/@v@ with 'httpSetHeader'.
rsHeaderFixed :: (IsHttp i, IsHttp o) =>
                 (HttpHeaderName, HttpHeaderValue)
              -> HttpSpec i o a
              -> HttpSpec i o a
rsHeaderFixed (n,v) = rsCheckSet check (httpSetHeader n v)
    where check = spGetHeader n >>= spCheck (==v) err
          err v' = mkErr $ WebErrInvalidHeaderValue n v' ("Expected `"++v++"'.")

-- | Require the @Content-Type@ header to start with @v@.
--
-- Only a prefix match is performed, so anything following @v@ in the actual
-- header value is accepted. A non-matching value is reported as
-- 'WebErrInvalidHeaderValue'. The setter passed to 'rsCheckSet' writes @v@ as
-- the content type with 'httpSetHeader'.
rsContentType :: (IsHttp i, IsHttp o) =>
                 HttpHeaderValue
              -> HttpSpec i o a
              -> HttpSpec i o a
rsContentType v = rsCheckSet check (httpSetHeader n v)
    where check = spGetHeader n >>= spCheck checkfun err
          checkfun v' = v `isPrefixOf` v'
          err v' = mkErr $ WebErrInvalidHeaderValue n v' ("Expected `"++v++"'.")
          n = Http.HdrContentType

-- ----------------------------------------------------------------------------
--  ReqSpec combinators
-- ----------------------------------------------------------------------------

-- | Spec for the URL parameter @name@.
--
-- Parsing looks @name@ up in the parameters of the request's full URL and
-- reports 'WebErrMissingParam' when it is absent; generating adds the
-- parameter to the request with 'reqAddUrlParam'.
rsParam :: HttpParamName -> ReqSpec HttpParamValue
rsParam name = mkSpec rsParseDef rsGenDef
    where
      rsGenDef req val = reqAddUrlParam name val req
      rsParseDef = spGets (urlParams . reqIn_fullUrl)
                >>= spFromMaybe err . lookup name
      err = mkErr $ WebErrMissingParam name

-- | Require the request method to be @meth@. The check compares the method
-- of the parsed request against @meth@ and reports 'WebErrInvalidMethod' on a
-- mismatch; the setter passed to 'rsCheckSet' is @reqSetMethod meth@.
rsMeth :: HttpMethod -> ReqSpec a -> ReqSpec a
rsMeth meth = rsCheckSet verify (reqSetMethod meth)
    where
      verify = do
          actual <- spGets reqMethod
          spCheck (== meth) wrongMethod actual
      wrongMethod actual =
          mkErr (WebErrInvalidMethod actual ("Expected method `" ++ show meth ++ "'."))

-- | Pair the result of @rs@ with one piece of the URL path.
--
-- Parsing splits the request URL with 'urlSplit' (presumably into its leading
-- path segment and the rest - confirm against 'urlSplit'), runs @rs@ on the
-- request with the remaining URL and returns the split-off piece together
-- with the result of @rs@. If 'urlSplit' yields 'Nothing' the parse fails
-- with 'WebErrInvalidUrl'.
--
-- Generating appends the path piece to the request with 'reqAppendUrlPath'
-- and then runs the generator of @rs@.
rsPathSegment :: ReqSpec a -> ReqSpec (String, a)
rsPathSegment rs = mkSpec rsParseDef rsGenDef
    where
      rsParseDef =
          do req <- spGet
             let msg = "URL too short."
                 url = reqUrl req
             case urlSplit url of
               Just (hd,tl) ->
                   do a <- local (reqSetUrl tl) (rsParse rs)
                      return (hd, a)
               Nothing -> throwError $ mkErr $ WebErrInvalidUrl msg (show url)
      rsGenDef r (path, a) = rsGen rs (reqAppendUrlPath path r) a

-- | Spec for the path of the request URL: parsing returns the 'Uri.uriPath'
-- of the request's URL, generating stores the path with 'reqSetUrlPath'.
rsPath :: ReqSpec String
rsPath = mkSpec parsePath genPath
    where
      parsePath = liftM pathOf spGet
      pathOf req = Uri.uriPath (reqUrl req)
      genPath req path = reqSetUrlPath path req

-- | Require the request URL to start with @path@ and run @rs@ on the rest.
--
-- Parsing matches @path@ against the request URL with 'urlMatchPrefix'; on
-- success @rs@ is run on the request carrying the URL returned by the match,
-- otherwise the parse fails with 'WebErrInvalidUrl'.
--
-- Generating appends @path@ to the request with 'reqAppendUrlPath' and then
-- runs the generator of @rs@.
rsPathFixed :: String -> ReqSpec a -> ReqSpec a
rsPathFixed path rs = mkSpec rsParseDef rsGenDef
    where
      rsParseDef =
          do req <- spGet
             let msg = "Expected URL prefix: `"++path++"'"
                 url = reqUrl req
             case urlMatchPrefix path url of
               Just url' -> local (reqSetUrl url') (rsParse rs)
               Nothing -> throwError $ mkErr $ WebErrInvalidUrl msg (show url)
      rsGenDef r = rsGen rs (reqAppendUrlPath path r)

-- ----------------------------------------------------------------------------
--  ResSpec combinators
-- ----------------------------------------------------------------------------
-- | Require the response status code to be @c@. The check compares the code
-- of the parsed response against @c@ and reports 'WebErrInvalidStatus' on a
-- mismatch; the setter passed to 'rsCheckSet' is @resSetStatus c Nothing@.
rsStatus :: Int -> ResSpec a -> ResSpec a
rsStatus c = rsCheckSet verify (resSetStatus c Nothing)
    where
      verify = do
          actual <- spGets resCode
          spCheck (== c) wrongStatus actual
      wrongStatus actual =
          mkErr (WebErrInvalidStatus actual ("Expected status code `" ++ show c ++ "'."))

-- ----------------------------------------------------------------------------
--  other specific combinators
-- ----------------------------------------------------------------------------

-- | Lift a spec over XML text to a spec over @a@, using the pickler @xp@ in
-- both directions ('unpickleStr' / 'pickleStr'). The message handed to
-- 'rsWrapMaybe' is @\"Failed to unpickle XML.\"@.
rsXmlString :: Error e => PU a -> Spec e i o String -> Spec e i o a
rsXmlString xp =
    rsWrapMaybe "Failed to unpickle XML." (unpickleStr xp, pickleStr xp)

-- | Lift a spec over XML bytes to a spec over @a@.
--
-- Reading uses 'unpickle' with the pickler @xp@; writing uses
-- 'pickleWithEnc' with the encoding @enc@. The message handed to
-- 'rsWrapMaybe' is @\"Failed to unpickle XML.\"@.
rsXml :: Error e =>
         XmlEncoding
      -> PU a
      -> Spec e i o BSL.ByteString
      -> Spec e i o a
rsXml enc xp rs = rsWrapMaybe msg (unpickle xp, pickleWithEnc enc xp) rs
    where msg = "Failed to unpickle XML."

-- | Like 'rsXml', but reading goes through 'validateAndUnpickle' with the
-- validator @val@. A 'Left' message from that step is converted to the
-- spec's error type with 'strMsg'. Writing uses 'pickleWithEnc' with the
-- encoding @enc@.
rsValidXml :: Error e =>
              XmlEncoding
           -> XmlValidator
           -> PU a
           -> Spec e i o BSL.ByteString
           -> Spec e i o a
rsValidXml enc val xp rs =
    flip rsWrapEither rs ( mapLeft strMsg . validateAndUnpickle val xp
                         , pickleWithEnc enc xp)
    where mapLeft f = either (Left . f) Right

-- | Spec for a message body holding XML for the pickler @xp@; combines
-- 'rsWithBody' with 'rsXml' using '_UTF8_' as the encoding for writing.
rsXmlBody :: (IsHttp i, IsHttp o) => PU a -> HttpSpec i o a
rsXmlBody = rsWithBody . rsXml _UTF8_

-- | Lift a spec over bytes to a spec over text using the fixed encoding
-- @enc@: reading decodes with 'decodeLazyByteStringExplicit', writing encodes
-- with 'encodeLazyByteString'.
rsEncodingFixed :: (Error e, Encoding enc) =>
                   enc
                -> Spec e i o BSL.ByteString
                -> Spec e i o String
rsEncodingFixed enc = rsWrapEither' (decode, encode)
    where decode = decodeLazyByteStringExplicit enc
          encode = encodeLazyByteString enc

-- | Lift a spec over an encoding name to a spec over 'XmlEncoding', converting
-- with 'xmlEncodingFromString' when reading and 'xmlEncodingToString' when
-- writing.
rsXmlEncoding :: Error e => Spec e i o String -> Spec e i o XmlEncoding
rsXmlEncoding rs =
    rsWrapEither' (xmlEncodingFromString, xmlEncodingToString) rs

-- | Lift a spec over an encoding name to a spec over 'TextEncoding',
-- converting with 'encodingFromStringExplicit' when reading and
-- 'encodingName' when writing. The message handed to 'rsWrapMaybe' is
-- @\"rsTextEncoding: unknown encoding\"@.
rsTextEncoding :: Error e => Spec e i o String -> Spec e i o TextEncoding
rsTextEncoding rs =
    rsWrapMaybe "rsTextEncoding: unknown encoding"
                (encodingFromStringExplicit, encodingName)
                rs

-- ----------------------------------------------------------------------------
--  Spec runners
-- ----------------------------------------------------------------------------

-- | Generate an outgoing request for a value by running 'genBySpec' on @rs@.
-- Generation starts from a @GET@ request to @base@ with no headers and an
-- empty body.
genReqOut :: Monad m => ReqSpec a -> HttpUrl -> a -> m ReqOut
genReqOut rs base = genBySpec rs blankRequest
    where
      blankRequest = ReqOut base Http.GET noContent
      noContent = HttpData [] BSL.empty

-- | Parse an incoming request with 'parseBySpec'. Any error thrown is
-- rethrown after passing it through 'webExcSetReqIn' with the request.
parseReqIn :: MonadError ReqErr m => ReqSpec a -> ReqIn -> m a
parseReqIn rs reqIn =
    parseBySpec rs reqIn
        `catchError` (\exc -> throwError (webExcSetReqIn reqIn exc))

-- | Generate an outgoing response for a value by running 'genBySpec' on @rs@.
-- Generation starts from a response with status 200, no headers and an empty
-- body.
genResOut :: Monad m => ResSpec a -> a -> m ResOut
genResOut rs = genBySpec rs blankResponse
    where
      blankResponse = ResOut 200 Nothing noContent
      noContent = HttpData [] BSL.empty

-- | Parse an incoming response with 'parseBySpec'. Any error thrown is
-- rethrown after passing it through 'webExcSetResIn' with the response.
--
-- The constraint names 'ResErr' because this is the response-side parser;
-- 'ResErr' and 'ReqErr' are both aliases of 'WebExc', so callers are
-- unaffected.
parseResIn :: MonadError ResErr m => ResSpec a -> ResIn -> m a
parseResIn rs resIn = catchError (parseBySpec rs resIn) handler
    where handler = throwError . webExcSetResIn resIn

-- | Record the incoming request in an exception. If no exchange is attached
-- yet, a fresh 'WebIn' is created; an existing 'WebIn' has its request
-- replaced. An exception carrying a 'WebOut' exchange is returned unchanged.
webExcSetReqIn :: ReqIn -> WebExc -> WebExc
webExcSetReqIn reqIn exc@(WebExc comm err) =
    case comm of
      Nothing -> attach (WebIn Nothing Nothing)
      Just (WebCommIn win) -> attach win
      Just (WebCommOut _) -> exc
    where
      attach win = WebExc (Just (WebCommIn (win { webIn_req = Just reqIn }))) err

-- | Record the outgoing response in an exception. If no exchange is attached
-- yet, a fresh 'WebIn' is created; an existing 'WebIn' has its response
-- replaced. An exception carrying a 'WebOut' exchange is returned unchanged.
webExcSetResOut :: ResOut -> WebExc -> WebExc
webExcSetResOut resOut exc@(WebExc comm err) =
    case comm of
      Nothing -> attach (WebIn Nothing Nothing)
      Just (WebCommIn win) -> attach win
      Just (WebCommOut _) -> exc
    where
      attach win = WebExc (Just (WebCommIn (win { webIn_res = Just resOut }))) err

-- | Record the incoming response in an exception. If no exchange is attached
-- yet, a fresh 'WebOut' is created; an existing 'WebOut' has its response
-- replaced. An exception carrying a 'WebIn' exchange is returned unchanged.
webExcSetResIn :: ResIn -> WebExc -> WebExc
webExcSetResIn resIn exc@(WebExc comm err) =
    case comm of
      Nothing -> attach (WebOut Nothing Nothing)
      Just (WebCommOut wout) -> attach wout
      Just (WebCommIn _) -> exc
    where
      attach wout = WebExc (Just (WebCommOut (wout { webOut_res = Just resIn }))) err

-- | Record the outgoing request in an exception. If no exchange is attached
-- yet, a fresh 'WebOut' is created; an existing 'WebOut' has its request
-- replaced. An exception carrying a 'WebIn' exchange is returned unchanged.
webExcSetReqOut :: ReqOut -> WebExc -> WebExc
webExcSetReqOut reqOut exc@(WebExc comm err) =
    case comm of
      Nothing -> attach (WebOut Nothing Nothing)
      Just (WebCommOut wout) -> attach wout
      Just (WebCommIn _) -> exc
    where
      attach wout = WebExc (Just (WebCommOut (wout { webOut_req = Just reqOut }))) err