module Network.OAuth.Http.Request 
       ( 
         Request(..)
       , FieldList()
       , Version(..)
       , Method(..)
         
       , fromList
       , singleton
       , empty
       , toList
       , parseQString
       , find
       , findWithDefault
       , ifindWithDefault
       , change
       , insert
       , replace
       , replaces
       , union
       , unionAll
         
       , showURL
       , showQString
       , showProtocol
       , showAuthority
       , showPath
       , parseURL
       ) where
import Control.Monad.State
import Network.OAuth.Http.PercentEncoding
import Network.OAuth.Http.Util
import Data.List (intercalate,isPrefixOf)
import Data.Monoid
import Data.Char (toLower)
import qualified Data.ByteString.Lazy as B
import qualified Data.Binary as Bi
-- | HTTP request methods known to this module.
data Method
  = GET
  | POST
  | PUT
  | DELETE
  | TRACE
  | CONNECT
  | HEAD
  deriving (Eq)
-- | HTTP protocol versions known to this module.
data Version
  = Http10
  | Http11
  deriving (Eq)
-- | An ordered list of key\/value pairs, used for both headers and
-- query-string parameters. It is a plain association list, so the same
-- key may appear more than once.
newtype FieldList = FieldList
  { unFieldList :: [(String, String)] -- ^ The underlying pairs, in order.
  }
  deriving (Eq, Ord)
-- | An HTTP request, broken down into the pieces this module renders
-- and parses.
data Request = ReqHttp
  { version    :: Version      -- ^ HTTP protocol version.
  , ssl        :: Bool         -- ^ 'True' selects @https@, 'False' selects @http@ (see 'showProtocol').
  , host       :: String       -- ^ Host name, as rendered by 'showAuthority'.
  , port       :: Int          -- ^ Port; omitted by 'showAuthority' when it is the scheme default (443 or 80).
  , method     :: Method       -- ^ Request method.
  , reqHeaders :: FieldList    -- ^ Request headers.
  , pathComps  :: [String]     -- ^ Path components; 'showPath' encodes each one and joins them with @\/@.
  , qString    :: FieldList    -- ^ Query-string parameters.
  , reqPayload :: B.ByteString -- ^ Request body as a lazy 'B.ByteString'.
  }
  deriving (Eq, Show)
-- | The URL scheme of a request: @"https"@ when 'ssl' is set,
-- @"http"@ otherwise.
showProtocol :: Request -> String
showProtocol req = if ssl req then "https" else "http"
-- | The authority part of the URL: the host, followed by @:port@ unless
-- the port is the default for the scheme (443 with 'ssl', 80 without).
showAuthority :: Request -> String
showAuthority req
  | port req == defaultPort = host req
  | otherwise               = host req ++ ":" ++ show (port req)
  where
    -- The port that may be left implicit for the current scheme.
    defaultPort = if ssl req then 443 else 80
-- | The path of the request: every component of 'pathComps' is passed
-- through 'encode' and the results are joined with @\/@.
--
-- No leading slash is added here; 'parseURL' stores an empty first
-- component, which is what produces the leading @\/@ on output.
showPath :: Request -> String
showPath req = intercalate "/" [encode comp | comp <- pathComps req]
-- | The query string of the request, rendered through the 'Show'
-- instance of 'FieldList' (no leading @?@).
showQString :: Request -> String
showQString req = show (qString req)
-- | The full URL of the request: scheme, @:\/\/@, authority, path and,
-- when the query string is non-empty, @?@ followed by the query string.
showURL :: Request -> String
showURL req =
  showProtocol req ++ "://" ++ showAuthority req ++ showPath req ++ query
  where
    -- Only emit the '?' separator when there is something after it.
    query = case unFieldList (qString req) of
      [] -> ""
      _  -> '?' : showQString req
-- | Parse a URL into a 'Request'.
--
-- The pieces are consumed in order: scheme, the literal @:\/\/@, host,
-- an optional @:@ and port, a mandatory @\/@, the path, an optional @?@
-- and the query string. The result is 'Nothing' as soon as the scheme
-- or one of the mandatory symbols is missing. Whatever input is left
-- after the query string (from the first @#@ on) is ignored.
--
-- Fields that the URL does not determine keep the values of the seed
-- request below (HTTP\/1.1, @GET@, no headers, empty payload).
parseURL :: String -> Maybe Request
parseURL tape = evalState parser (tape, Just blank)
  where
    parser = do
      _parseProtocol
      mapM_ _parseSymbol [(':', True), ('/', True), ('/', True)]
      _parseHost
      _parseSymbol (':', False)
      _parsePort
      _parseSymbol ('/', True)
      _parsePath
      _parseSymbol ('?', False)
      _parseQString
      gets snd
    -- Seed request that the individual parsers update step by step.
    blank = ReqHttp
      { version    = Http11
      , ssl        = False
      , method     = GET
      , host       = "127.0.0.1"
      , port       = 80
      , reqHeaders = FieldList []
      , pathComps  = []
      , qString    = FieldList []
      , reqPayload = B.empty
      }
-- | Parse a query string into a 'FieldList'.
--
-- The input is handed to the query-string parser as-is (a leading @?@
-- is not stripped); everything from the first @#@ on is ignored.
parseQString :: String -> FieldList
parseQString tape = maybe (FieldList []) qString parsed
  where
    -- Run only the query-string step and keep the resulting request.
    parsed = snd (execState _parseQString (tape, Just blank))
    -- Throw-away request that merely carries the parsed query string.
    blank = ReqHttp
      { version    = Http11
      , ssl        = False
      , method     = GET
      , host       = "127.0.0.1"
      , port       = 80
      , reqHeaders = FieldList []
      , pathComps  = []
      , qString    = FieldList []
      , reqPayload = B.empty
      }
-- | Wrap a list of key\/value pairs as a 'FieldList', keeping order and
-- duplicates.
fromList :: [(String,String)] -> FieldList
fromList pairs = FieldList pairs
-- | The key\/value pairs of a 'FieldList', in order.
toList :: FieldList -> [(String,String)]
toList (FieldList pairs) = pairs
-- | A 'FieldList' holding exactly one key\/value pair.
singleton :: (String,String) -> FieldList
singleton kv = FieldList [kv]
-- | The 'FieldList' with no fields.
empty :: FieldList
empty = FieldList []
-- | Overwrite the value of every field whose key equals the given key.
-- Fields with other keys are left untouched, and nothing is added when
-- the key is absent.
change :: (String,String) -> FieldList -> FieldList
change (key, newVal) (FieldList pairs) = FieldList (map update pairs)
  where
    update pair@(k, _)
      | k == key  = (k, newVal)
      | otherwise = pair
-- | Put a key\/value pair at the front of the list, without checking
-- whether the key is already present.
insert :: (String,String) -> FieldList -> FieldList
insert kv (FieldList pairs) = FieldList (kv : pairs)
-- | Set a key to a value: when the key is absent the pair is inserted
-- at the front, otherwise every existing field with that key gets the
-- new value.
replace :: (String,String) -> FieldList -> FieldList
replace pair@(k, _) fs =
  case find (== k) fs of
    [] -> insert pair fs
    _  -> change pair fs
-- | Apply 'replace' for every pair of the list. The pairs are applied
-- from right to left, i.e. the last pair of the list is applied first.
replaces :: [(String,String)] -> FieldList -> FieldList
replaces []         field = field
replaces (kv : kvs) field = replace kv (replaces kvs field)
-- | All values whose key satisfies the predicate, in list order.
find :: (String -> Bool) -> FieldList -> [String]
find p (FieldList list) = [snd kv | kv <- list, p (fst kv)]
-- | Merge the first list into the second using 'replace': keys already
-- present in the second list get the value from the first, new keys are
-- inserted at the front.
union :: FieldList -> FieldList -> FieldList
union overrides base = replaces (toList overrides) base
-- | Concatenate two lists, keeping every field of both (duplicates
-- included); the fields of the first list come first.
unionAll :: FieldList -> FieldList -> FieldList
unionAll front back = front `mappend` back
-- | The value of the first field with the given key, or the supplied
-- default when no field has that key. The argument is @(key, default)@.
findWithDefault :: (String,String) -> FieldList -> String
findWithDefault (key, def) fields =
  case find (== key) fields of
    (firstHit : _) -> firstHit
    []             -> def
-- | Like 'findWithDefault', but keys are compared after lower-casing
-- both sides with 'toLower'. The argument is @(key, default)@.
ifindWithDefault :: (String,String) -> FieldList -> String
ifindWithDefault (key, def) fields =
  case find sameKey fields of
    (firstHit : _) -> firstHit
    []             -> def
  where
    wanted    = map toLower key
    sameKey k = map toLower k == wanted
-- | Consume the URL scheme from the front of the tape.
--
-- @https@ sets 'ssl' and the default port 443; @http@ clears 'ssl' and
-- sets the default port 80. Any other input fails the parse by clearing
-- the tape and setting the request to 'Nothing'.
--
-- The scheme is matched case-insensitively (RFC 3986, section 3.1), so
-- @HTTP:\/\/...@ and @Https:\/\/...@ are accepted as well.
_parseProtocol :: State (String,Maybe Request) ()
_parseProtocol = modify step
  where
    -- "https" must be tried first because "http" is a prefix of it.
    step (tape, req)
      | hasScheme "https" tape = (drop 5 tape, fmap (\r -> r { ssl = True,  port = 443 }) req)
      | hasScheme "http"  tape = (drop 4 tape, fmap (\r -> r { ssl = False, port = 80  }) req)
      | otherwise              = ("", Nothing)
    -- Does the input start with the (lower-case) scheme name, ignoring case?
    hasScheme name input = map toLower (take (length name) input) == name
-- | Consume the host: everything up to (but not including) the first
-- @:@ or @\/@, or the whole tape when neither occurs.
_parseHost :: State (String,Maybe Request) ()
_parseHost = modify $ \(tape, req) ->
  let (name, rest) = span (`notElem` ":/") tape
  in (rest, fmap (\r -> r { host = name }) req)
-- | Consume the port: everything up to (but not including) the first
-- @\/@.
--
-- * If that text reads completely as a number, it becomes the 'port'.
-- * If it is empty (no explicit port), the port already stored in the
--   request is kept.
-- * Otherwise the port is malformed and the parse fails: the tape is
--   cleared and the request set to 'Nothing', as the other parsers do
--   on bad input. Previously such text was silently dropped and the
--   existing port kept.
_parsePort :: State (String,Maybe Request) ()
_parsePort = modify step
  where
    step (tape, req) =
      let (digits, rest) = break (== '/') tape
      in case reads digits of
           [(n, "")] -> (rest, fmap (\r -> r { port = n }) req)
           _ | null digits -> (rest, req)
             | otherwise   -> ("", Nothing)
-- | Consume the path: everything up to (but not including) the first
-- @?@. The text is split on @\/@ with 'splitBy', each piece is decoded
-- with 'decodeWithDefault' (falling back to @""@), and an empty
-- component is put in front of the result.
_parsePath :: State (String,Maybe Request) ()
_parsePath = modify $ \(tape, req) ->
  let (rawPath, rest) = span (/= '?') tape
      comps           = "" : [decodeWithDefault "" c | c <- splitBy (== '/') rawPath]
  in (rest, fmap (\r -> r { pathComps = comps }) req)
-- | Consume the query string: everything up to (but not including) the
-- first @#@. The text is split on @&@ with 'splitBy'; each piece is cut
-- at its first @=@ into key and value, both decoded with
-- 'decodeWithDefault' (falling back to @""@). A piece without @=@ gets
-- an empty value, and pairs whose key and value are both empty are
-- dropped.
_parseQString :: State (String,Maybe Request) ()
_parseQString = modify $ \(tape, req) ->
  let (rawQuery, rest) = span (/= '#') tape
      pairs            = [kv | kv <- map parseField (splitBy (== '&') rawQuery), kv /= ("", "")]
  in (rest, fmap (\r -> r { qString = FieldList pairs }) req)
  where
    parseField chunk =
      case break (== '=') chunk of
        (k, '=' : v) -> (decodeWithDefault "" k, decodeWithDefault "" v)
        (k, _)       -> (decodeWithDefault "" k, "")
-- | Consume one literal character from the front of the tape.
--
-- The argument is @(character, required)@. When the tape starts with
-- the character it is dropped. When it does not, a required symbol
-- fails the parse (tape cleared, request set to 'Nothing'), while an
-- optional one leaves the state unchanged.
_parseSymbol :: (Char,Bool) -> State (String,Maybe Request) ()
_parseSymbol (c, required) = modify step
  where
    step (tape, req) =
      case tape of
        (x : rest) | x == c -> (rest, req)
        _ | required        -> ("", Nothing)
          | otherwise       -> (tape, req)
-- | Renders a method as its upper-case HTTP token, e.g. @GET@.
instance Show Method where
  showsPrec _ m = showString (token m)
    where
      token GET     = "GET"
      token POST    = "POST"
      token DELETE  = "DELETE"
      token CONNECT = "CONNECT"
      token HEAD    = "HEAD"
      token TRACE   = "TRACE"
      token PUT     = "PUT"
-- | Reads a method from its upper-case HTTP token. The whole input must
-- be exactly the token: surrounding whitespace or trailing characters
-- yield no parse.
instance Read Method where
  readsPrec _ input = [(m, "") | (token, m) <- table, token == input]
    where
      table =
        [ ("GET",     GET)
        , ("POST",    POST)
        , ("DELETE",  DELETE)
        , ("CONNECT", CONNECT)
        , ("HEAD",    HEAD)
        , ("TRACE",   TRACE)
        , ("PUT",     PUT)
        ]
-- | Reads a version from exactly @HTTP\/1.0@ or @HTTP\/1.1@; any other
-- input (including extra whitespace or trailing text) yields no parse.
instance Read Version where
  readsPrec _ input
    | input == "HTTP/1.0" = [(Http10, "")]
    | input == "HTTP/1.1" = [(Http11, "")]
    | otherwise           = []
-- | Renders a version as @HTTP\/1.0@ or @HTTP\/1.1@.
instance Show Version where
  showsPrec _ Http10 = showString "HTTP/1.0"
  showsPrec _ Http11 = showString "HTTP/1.1"
-- | Renders the fields as @key=value@ pairs joined with @&@, with keys
-- and values passed through 'encode'. Note that this is query-string
-- syntax, not Haskell syntax.
instance Show FieldList where
  showsPrec _ (FieldList pairs) =
    showString (intercalate "&" [encode k ++ "=" ++ encode v | (k, v) <- pairs])
-- | Field lists form a monoid under concatenation: 'mempty' is the
-- list without fields and 'mappend' appends the second list to the
-- first.
instance Monoid FieldList where
  mempty = FieldList []
  mappend left right = FieldList (unFieldList left ++ unFieldList right)
-- | Serialised exactly like the underlying list of key\/value pairs.
instance Bi.Binary FieldList where
  put (FieldList pairs) = Bi.put pairs
  get = Bi.get >>= return . FieldList