-- Copyright (C) 2009 Diego Souza <dsouza at bitforest dot org>
-- 
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Lesser General Public License as published by
-- the Free Software Foundation, either version 3 of the License, or
-- (at your option) any later version.
-- 
-- This program is distributed in the hope that it will be useful,
-- but WITHOUT ANY WARRANTY; without even the implied warranty of
-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-- GNU Lesser General Public License for more details.
-- You should have received a copy of the GNU Lesser General Public License
-- along with this program.  If not, see <http://www.gnu.org/licenses/>.

module Network.Protocol.OAuth.Request (Request(..),HTTPMethod(..),Parameter,PercentEncoding,encode,encodes,decode,decodes,append_param,apply,show_url,show_oauthurl,show_oauthheader,show_urlencoded,read_urlencoded,(>>+),(>>|)) where

import Data.Bits as B
import qualified Data.ByteString.Lazy as B1
import qualified Data.ByteString.Lazy.UTF8 as B2
import qualified Data.ByteString.Lazy.Char8 as B3
import qualified Data.Word as W
import qualified Data.Char as C
import qualified Data.List as L

-- | A single request parameter: its name paired with an optional value.
type Parameter = (String, Maybe String)

-- | HTTP request methods supported by this module.
data HTTPMethod = GET | POST | DELETE | PUT
  deriving (Show, Read, Eq)

-- | Conversion of values to and from percent-encoded bytes. Refer to
-- <http://en.wikipedia.org/wiki/Percent-encoding> for more information.
class PercentEncoding a where
  -- | Encodes one value as a percent-encoded bytestring.
  encode :: a -> B1.ByteString

  -- | Encodes each element of a list and concatenates the results.
  encodes :: [a] -> B1.ByteString
  encodes xs = B1.concat [encode x | x <- xs]

  -- | Decodes one value from the front of an encoded bytestring, returning
  -- it together with the input that was not consumed.
  decode :: B1.ByteString -> (a,B1.ByteString)

  -- | Decodes a whole bytestring by applying 'decode' repeatedly until no
  -- input is left.
  decodes :: B1.ByteString -> [a]
  decodes = go
    where
      go input
        | B1.null input = []
        | otherwise     = case decode input of
                            (x, rest) -> x : go rest

-- | The part of an HTTP request that matters for oauth authentication. This
-- is not a full HTTP request, only the data oauth needs to see.
data Request = HTTP
  { ssl    :: Bool        -- ^ 'True' means /HTTPS/, 'False' means /HTTP/
  , method :: HTTPMethod  -- ^ The HTTP method
  , host   :: String      -- ^ The hostname or ip address (e.g. bitforest.org)
  , port   :: Int         -- ^ The tcp port (e.g. 80)
  , path   :: String      -- ^ The request path (e.g. \/foo\/bar\/)
  , params :: [Parameter] -- ^ The request parameters (both GET and POST)
  }
  deriving (Show, Read, Eq)

-- | Adds a parameter to the request. The new @(key, value)@ pair is placed
-- at the head of the request's 'params' list.
append_param :: Request -> String -> Maybe String -> Request
append_param r k v = r { params = (k, v) : params r }

-- | Parses a urlencoded string of the form @k1=v1&k2=v2@ into parameters.
-- Each pair is split at its first @=@ and both sides are percent-decoded.
-- A pair with no @=@, or with an empty value, gets 'Nothing' as its value.
-- Empty segments (from @a=1&&b=2@ or a trailing @&@) are skipped rather than
-- producing a spurious parameter with an empty name.
read_urlencoded :: B1.ByteString -> [Parameter]
read_urlencoded u
  | B1.null u = []
  | otherwise = (map (param' . keyval') . filter (not . B1.null) . B1.split 0x26) u
  where
    -- split only at the first '=', so any later '=' stays part of the value
    keyval' s = let (k,v) = B1.break (==0x3d) s
                in (k, B1.drop 1 v)

    param' (k,v) | B1.null v = (decodes k,Nothing)
                 | otherwise = (decodes k,(Just . decodes) v)

-- | Renders the complete URL of the request, including any oauth parameters
-- it carries.
--
-- * The port is left out when it is the scheme's default (443 for HTTPS,
--   80 for HTTP).
-- * Path segments are percent-encoded.
-- * A query string is appended only to GET requests that have parameters.
show_url :: Request -> B1.ByteString
show_url (HTTP s m h p0 p1 ps) = B1.concat [authority', path', query']
  where
    scheme'  = if s then "https://" else "http://"
    stdport' = if s then 443 else 80

    authority'
      | p0 == stdport' = B3.pack (scheme' ++ h)
      | otherwise      = B3.pack (scheme' ++ h ++ ':' : show p0)

    segments' = map encodes (_path_comp p1)
    path'     = B1.cons 0x2f (B1.intercalate (B1.singleton 0x2f) segments')

    query'
      | m == GET && not (null ps) = B1.cons 0x3f (show_urlencoded ps)
      | otherwise                 = B1.empty

-- | The URL used to perform the oauth request. It is 'show_url' applied to
-- the request after removing every parameter whose name starts with
-- @oauth_@.
show_oauthurl :: Request -> B1.ByteString
show_oauthurl req = show_url (req { params = filter (not . isOAuth') (params req) })
  where
    isOAuth' = L.isPrefixOf "oauth_" . fst

-- | Renders the value of an oauth Authorization or WWW-Authenticate header.
-- The output is @OAuth realm="..."@, followed (when there are any) by the
-- request's parameters whose name starts with @oauth_@, as comma-separated
-- @name="value"@ pairs with percent-encoded names and values. Other
-- parameters are left out. Backslashes and double quotes in the realm are
-- escaped so that it remains a well-formed quoted-string.
show_oauthheader :: String           -- ^ The realm
                    -> Request
                    -> B1.ByteString -- ^ The Authorization\/WWW-Authenticate header
show_oauthheader realm (HTTP _ _ _ _ _ p)
  | B1.null params' = realm'
  | otherwise       = B1.concat [realm', B1.singleton 0x2c, params']
  where
    -- percent-encode a value and wrap it in double quotes
    encodes' s = B1.concat [B1.singleton 0x22, encodes s, B1.singleton 0x22]

    params' = (_urlencode encodes' 0x2c . filter (L.isPrefixOf "oauth_" . fst)) p

    realm'  = B3.pack ("OAuth realm=\"" ++ concatMap quote' realm ++ "\"")

    -- quoted-pair escaping (RFC 2616, section 2.2): an unescaped '"' would
    -- terminate the quoted-string early and corrupt the header
    quote' c | c == '"' || c == '\\' = ['\\', c]
             | otherwise             = [c]

-- | Renders parameters as a urlencoded string (@k1=v1&k2=v2@). As the oauth
-- protocol requires, the parameters are sorted first.
show_urlencoded :: [Parameter] -> B1.ByteString
show_urlencoded ps = _urlencode encodes 0x26 ps

-- | Runs a request transformation on the request: @apply r f@ is @f r@.
apply :: Request -> (Request -> Request) -> Request
apply = flip ($)

-- | Operator form of 'append_param': @r >>+ (k, v)@ adds the parameter
-- @(k, v)@ to @r@.
(>>+) :: Request -> (String,Maybe String) -> Request
r >>+ kv = append_param r (fst kv) (snd kv)

-- | Operator form of 'apply': @r >>| f@ is @f r@.
(>>|) :: Request -> (Request -> Request) -> Request
r >>| f = apply r f

instance PercentEncoding Char where
  -- UTF-8 encodes the character. Every byte outside the unreserved set
  -- (ALPHA / DIGIT / '-' / '.' / '_' / '~') becomes an uppercase %XY escape.
  encode = B1.pack . concatMap enc' . B1.unpack . B2.fromString . (:[])
    where
      enc' b | elem b whitelist' = [b]
             | otherwise          = let b0 = b .&. 0x0F
                                        b1 = B.shiftR (b .&. 0xF0) 4
                                    in ((37:) . map (fromIntegral . C.ord . C.toUpper . C.intToDigit . fromIntegral)) [b1,b0]
      whitelist' = [0x61..0x7a] ++ [0x41..0x5a] ++ [0x30..0x39] ++ [0x2d,0x2e,0x5f,0x7e]

  -- Decodes the first character and returns it with the remaining input.
  -- The remainder is found by summing the raw bytes behind each decoded byte
  -- of that character. Re-encoding the character and taking its length
  -- would be wrong for non-canonical input such as "%61" or raw non-ASCII
  -- bytes.
  decode bytes = case B2.decode (B1.pack (map fst units)) of
      Nothing     -> error "Network.Protocol.OAuth.Request.decode: empty input"
      Just (c, n) -> let consumed = sum (map snd (take (fromIntegral n) units))
                     in (c, B1.drop (fromIntegral consumed) bytes)
    where
      units = _pct_units bytes

  -- Decodes every %XY escape (hex digits in either case), then UTF-8 decodes
  -- the resulting bytes. A '%' not followed by two hex digits is kept as a
  -- literal '%' instead of raising an error.
  decodes = B2.toString . B1.pack . map fst . _pct_units

-- Splits percent-encoded input into decoded bytes. Each decoded byte is
-- paired with the number of raw input bytes it came from: 3 for a valid %XY
-- escape, 1 for any other byte. A '%' that is not followed by two hex digits
-- passes through as a literal byte.
_pct_units :: B1.ByteString -> [(W.Word8, Int)]
_pct_units = go . B1.unpack
  where
    go (37:hi:lo:rest) | hex' hi && hex' lo = (B.shiftL (val' hi) 4 .|. val' lo, 3) : go rest
    go (b:rest)                             = (b, 1) : go rest
    go []                                   = []

    hex' :: W.Word8 -> Bool
    hex' = C.isHexDigit . C.chr . fromIntegral

    val' :: W.Word8 -> W.Word8
    val' = fromIntegral . C.digitToInt . C.chr . fromIntegral

-- Renders parameters as name=value pairs joined by the separator byte @s@.
-- Names are encoded with 'encodes'. Values are rendered with @ve@, which lets
-- the header renderer quote them. A parameter with no value renders as
-- "name=".
--
-- Pairs are ordered by byte value of the percent-encoded name, then the
-- percent-encoded value, as oauth requires (RFC 5849, 3.4.1.3.2). Sorting
-- the decoded Strings is not equivalent: ':' sorts after '9', but "%3A"
-- sorts before "9".
_urlencode :: (String -> B1.ByteString) -> W.Word8 -> [Parameter] -> B1.ByteString
_urlencode ve s p
  | null p    = B1.empty
  | otherwise = (B1.intercalate (B1.singleton s) . map (render' . snd) . L.sort . map keyed') p
  where
    -- decorate each parameter with its encoded (name, value) sort key; the
    -- original parameter breaks ties (e.g. Nothing vs Just "")
    keyed' kv@(k,v) = ((encodes k, maybe B1.empty encodes v), kv)

    render' (k,Nothing) = B1.append (encodes k) (B1.singleton 0x3d)
    render' (k,Just v)  = B1.concat [encodes k, B1.singleton 0x3d, ve v]

-- Splits a request path into its non-empty '/'-separated segments. A
-- trailing '/' adds a final empty segment, so joining the segments with '/'
-- keeps it. An empty path yields no segments. Using 'last' here would crash
-- on "".
_path_comp :: String -> [String]
_path_comp p = filter (not . null) (L.unfoldr unfold' p) ++ trailing'
  where
    unfold' []   = Nothing
    unfold' rest = let (seg, more) = break (=='/') rest
                   in Just (seg, drop 1 more)

    trailing' | "/" `L.isSuffixOf` p = [[]]
              | otherwise            = []