{-# LANGUAGE TypeSynonymInstances #-}
---------------------------------------------------------
-- |
-- Module        : Web.Encodings
-- Copyright     : Michael Snoyman
-- License       : BSD3
--
-- Maintainer    : Michael Snoyman <michael@snoyman.com>
-- Stability     : Unstable
-- Portability   : portable
--
-- Various web encodings.
--
---------------------------------------------------------
module Web.Encodings
    (
      -- * Simple encodings.
      -- ** URL (percentage encoding)
      encodeUrl
    , decodeUrl
      -- ** HTML (entity encoding)
    , encodeHtml
-- FIXME    , decodeHtml
      -- ** JSON
    , encodeJson
      -- * HTTP level encoding.
      -- ** Query strings (percent-encoded key-value pairs)
    , encodeUrlPairs
    , decodeUrlPairs
      -- ** Post parameters
    , FileInfo (..)
    , parseMultipart
    , parsePost
      -- ** Cookies
    , decodeCookies
    ) where

import Data.ByteString.Class
import qualified Data.ByteString.Lazy as BS
import Text.Printf (printf)
import Data.Word (Word8)
import Numeric (showHex)
import Data.List (isPrefixOf)
import Data.ByteString.Lazy.Util hiding (ord)
import Data.Mime.Header
import Data.Maybe (fromMaybe)

-- | Percent-encode every byte of the input except the RFC 3986 unreserved
-- characters (ASCII letters, digits, @-@, @_@, @.@ and @~@).
--
-- The input is first converted to bytes with 'toLazyByteString'; the result
-- is only correct for text if that conversion produces UTF-8, which this
-- function assumes.
encodeUrl :: (LazyByteString x, LazyByteString y) => x -> y
encodeUrl input =
    let bytes = toLazyByteString input
        escaped = BS.concatMap encodeUrlByte bytes
     in fromLazyByteString escaped

-- | Code point of a character, converted to any integral type with
-- 'fromIntegral' (so it wraps for narrow types such as 'Word8'). Defined
-- here because the 'ord' from "Data.ByteString.Lazy.Util" is hidden on import.
ord :: Integral i => Char -> i
ord c = fromIntegral (fromEnum c)

-- | Percent-encode a single byte.
--
-- Bytes that are RFC 3986 unreserved characters (ASCII letters, digits,
-- @-@, @_@, @.@ and @~@) are passed through unchanged. Every other byte
-- becomes @%XX@ with two uppercase hexadecimal digits, as RFC 3986
-- section 2.1 recommends for URI producers.
encodeUrlByte :: Word8 -> BS.ByteString
encodeUrlByte w
    | isUnreserved = BS.singleton w
    | otherwise = BS.pack [37, hexDigit (w `div` 16), hexDigit (w `mod` 16)] -- 37 is '%'
  where
    c = toEnum (fromIntegral w) :: Char
    -- List of unreserved characters per RFC 3986, section 2.3.
    isUnreserved = ('A' <= c && c <= 'Z')
                || ('a' <= c && c <= 'z')
                || ('0' <= c && c <= '9')
                || c `elem` "-_.~"
    -- ASCII byte for a nibble value 0..15, using uppercase letters.
    hexDigit n
        | n < 10 = 48 + n      -- '0' .. '9'
        | otherwise = 55 + n   -- 'A' .. 'F' (55 == fromEnum 'A' - 10)

-- | Undo percent-encoding (and the @+@-for-space convention) on the input's
-- bytes, as performed by 'decodeUrlList'.
--
-- Decoding works byte by byte; turning the bytes back into text is left to
-- 'fromLazyByteString', which is assumed to interpret them as UTF-8.
decodeUrl :: (LazyByteString x, LazyByteString y) => x -> y
decodeUrl input =
    let rawBytes = BS.unpack (toLazyByteString input)
        decoded = decodeUrlList rawBytes
     in fromLazyByteString (BS.pack decoded)

-- | Decode a percent-encoded byte sequence.
--
-- A @%@ followed by two hexadecimal digits becomes the byte they denote, and
-- @+@ becomes a space. A @%@ that is not followed by two hex digits (for
-- example @%zz@, or a @%@ near the end of the input) is kept literally
-- instead of being turned into a NUL byte.
decodeUrlList :: [Word8] -> [Word8]
-- Byte values used below: 37 is '%', 43 is '+', 32 is ' '.
decodeUrlList (37:x:y:rest)
    | isHexByte x && isHexByte y = fromHex x * 16 + fromHex y : decodeUrlList rest
    -- If the guard fails, matching falls through to the equations below,
    -- which keep the '%' byte as-is.
decodeUrlList (43:rest) = 32 : decodeUrlList rest -- convert plus to space
decodeUrlList (x:rest) = x : decodeUrlList rest
decodeUrlList [] = []

-- | Is the byte an ASCII hexadecimal digit (@0-9@, @A-F@ or @a-f@)?
isHexByte :: Word8 -> Bool
isHexByte x = (48 <= x && x <= 57) || (65 <= x && x <= 70) || (97 <= x && x <= 102)

-- | Numeric value of an ASCII hexadecimal digit.
--
-- Returns 0 for any byte that is not a hex digit; 'decodeUrlList' checks
-- 'isHexByte' first, so that fallback is never used there.
fromHex :: Word8 -> Word8
fromHex x
    | 48 <= x && x <= 57 = x - 48 -- 0 - 9
    | 65 <= x && x <= 70 = x - 65 + 10 -- A - F
    | 97 <= x && x <= 102 = x - 97 + 10 -- a - f
    | otherwise = 0

-- | Escape the characters that are special in HTML: less-than, greater-than,
-- ampersand, double quote and single quote are replaced by entity
-- references. Every other character is copied through unchanged.
encodeHtml :: String -> String
encodeHtml input = [out | c <- input, out <- encodeHtmlChar c]

-- | The escaped form of one character: its entity reference if it is
-- special in HTML, otherwise a one-character string containing it.
encodeHtmlChar :: Char -> String
encodeHtmlChar c = case c of
    '<'  -> "&lt;"
    '>'  -> "&gt;"
    '&'  -> "&amp;"
    '"'  -> "&quot;"
    '\'' -> "&#39;"
    _    -> [c]

-- | Convert a query string into key-value pairs.
--
-- Leading question marks are stripped if present. The rest is split on
-- @&@, and each non-empty piece is decoded with 'decodeUrlPair'. Empty
-- pieces (from @&&@ or a trailing @&@) are skipped instead of producing a
-- pair with an empty key and value.
decodeUrlPairs :: (LazyByteString x, LazyByteString y, LazyByteString z)
               => x
               -> [(y, z)]
decodeUrlPairs = map decodeUrlPair
               . filter (not . BS.null)
               . BS.split (ord '&')
               . BS.dropWhile (== ord '?')
               . toLazyByteString

-- | Split one @key=value@ segment at its first @=@ and percent-decode both
-- halves with 'decodeUrl'.
--
-- Only that first @=@ is removed, so @a==b@ yields the value @=b@. A segment
-- without any @=@ yields an empty value.
decodeUrlPair :: (LazyByteString a, LazyByteString b)
              => BS.ByteString
              -> (a, b)
decodeUrlPair b =
    let (key, rest) = BS.break (== ord '=') b
        -- rest is either empty or starts with the separating '='.
        value = BS.drop 1 rest
     in (decodeUrl key, decodeUrl value)

-- | Render key/value pairs as a query string.
--
-- Each pair is encoded with 'encodeUrlPair', and the encoded pairs are
-- joined with @&@. No leading @?@ is added.
encodeUrlPairs :: (LazyByteString x, LazyByteString y, LazyByteString z)
               => [(x, y)]
               -> z
encodeUrlPairs pairs =
    fromLazyByteString $ BS.intercalate ampersand $ map encodeUrlPair pairs
  where
    ampersand = BS.singleton (ord '&')

-- | Encode a single pair as @key=value@, percent-encoding both the key and
-- the value with 'encodeUrl'.
encodeUrlPair :: (LazyByteString x, LazyByteString y)
               => (x, y)
               -> BS.ByteString
encodeUrlPair (key, value) =
    encodeUrl key `BS.append` (ord '=' `BS.cons` encodeUrl value)

-- | Escape a string so it can be placed inside a JSON string literal (see
-- 'encJSString'). The surrounding quotation marks are not added.
encodeJson :: LazyByteString x => String -> x
encodeJson str =
    let escaped = encJSString str
     in fromLazyByteString (toLazyByteString escaped)

-- | Escape a string for use inside a JSON string literal. Adapted from the
-- json package by Sigbjorn Finne.
--
-- Double quotes and backslashes are backslash-escaped. Control characters
-- below U+0020 use the short escapes @\\b@, @\\f@, @\\n@, @\\r@ and @\\t@
-- where JSON has them, and a four-digit @\\u@ escape otherwise. All other
-- characters are copied unchanged.
encJSString :: String -> String
encJSString = concatMap escape
  where
    escape '"' = "\\\""
    escape '\\' = "\\\\"
    escape c
        | c < '\x20' = '\\' : control c
        | otherwise = [c]

    -- The part that follows the backslash for a control character.
    control '\b' = "b"
    control '\f' = "f"
    control '\n' = "n"
    control '\r' = "r"
    control '\t' = "t"
    control c = 'u' : padTo4 (showHex (fromEnum c) "")

    -- Left-pad hex digits with zeros to at least four characters.
    padTo4 digits = replicate (4 - length digits) '0' ++ digits

-- | Information on an uploaded file, as produced for multipart parts that
-- carry a @filename@ attribute.
data FileInfo = FileInfo
    { fileName :: String
      -- ^ The @filename@ attribute of the part's @Content-Disposition@ header.
    , fileContentType :: String
      -- ^ The part's @Content-Type@ header value, or the empty string when
      -- the header is absent.
    , fileContent :: BS.ByteString
      -- ^ The part's body as split off from its headers. Unlike plain field
      -- values, it is not passed through @chompBS@.
    }

-- | Parse a multipart form into parameters and files.
--
-- The content is split on the boundary with 'getPieces', and each piece is
-- parsed with 'parsePiece' at the 'Maybe' monad. Pieces for which
-- 'parsePiece' gives 'Nothing' are dropped. Plain form fields are returned
-- in the first list and uploaded files in the second, each in body order.
parseMultipart :: LazyByteString lbs
               => lbs -- ^ boundary
               -> BS.ByteString -- ^ content
               -> ([(String, String)], [(String, FileInfo)])
parseMultipart boundary' content =
    let boundary :: String
        boundary = fromLazyByteString $ toLazyByteString boundary'
        -- Matching on 'Just' fixes parsePiece's monad to Maybe and discards
        -- the pieces it rejects.
        parsed = [p | Just p <- map parsePiece (getPieces boundary content)]
     in ([field | Left field <- parsed], [file | Right file <- parsed])

-- | Parse one part of a @multipart/form-data@ body.
--
-- The part is split into header lines and body with 'takeUntilBlank'. The
-- @name@ attribute of the @Content-Disposition@ header is bound in @m@ via
-- 'lookupHeaderAttr'. What happens when it is missing depends on that
-- function (presumably failure, i.e. 'Nothing' when 'parseMultipart' uses
-- this at 'Maybe').
--
-- If a @filename@ attribute is also present, the part is an uploaded file
-- ('Right'). It carries the unmodified body and the @Content-Type@ header
-- (empty if absent). Otherwise it is a plain field ('Left') whose value is
-- the body after 'chompBS'.
parsePiece :: Monad m
           => BS.ByteString
           -> m (Either (String, String) (String, FileInfo))
parsePiece piece = do
    let (rawHeaders, body) = takeUntilBlank piece
        hdrs = map parseHeader rawHeaders
        contentType = fromMaybe "" (lookupHeader "Content-Type" hdrs)
        mFileName = lookupHeaderAttr "Content-Disposition" "filename" hdrs
    fieldName <- lookupHeaderAttr "Content-Disposition" "name" hdrs
    return $ maybe
        (Left (fieldName, fromLazyByteString (chompBS body)))
        (\f -> Right (fieldName, FileInfo f contentType body))
        mFileName

-- | Split up a bytestring along the given boundary.
--
-- The delimiter searched for is the boundary prefixed with @--@. After each
-- delimiter, a following LF or CRLF is skipped and splitting continues.
-- Anything else after a delimiter (such as the closing @--@, or fewer than
-- two remaining bytes) ends the split. This assumes 'breakAtString' returns
-- the text after the delimiter as its second component. Empty pieces are
-- dropped, and the others are passed through 'chompBS'.
getPieces :: String -- ^ boundary
          -> BS.ByteString -- ^ content
          -> [BS.ByteString]
{- FIXME this would be nice...
getPieces b c =
    let fullBound = ord '-' `BS.cons'` (ord '-' `BS.cons'` b)
        pieces = fullBound `BS.split` c
     in filter (/= toLazyByteString "--") $
        filter (not . BS.null) $
        map chompBS pieces
-}
getPieces b c
    | BS.null c = []
    | otherwise =
        let fullBound = toLazyByteString ('-':'-':b)
            (next, rest) = breakAtString fullBound c
            remaining = getPieces b (skipLineBreak rest)
         in if BS.null next then remaining else chompBS next : remaining
    where
        -- Look only at the first two bytes. Calling BS.length here would
        -- force and walk the whole remaining lazy body at every boundary.
        skipLineBreak bs = case BS.unpack (BS.take 2 bs) of
            [x, _] | x == ord '\n' -> BS.drop 1 bs
            [x, y] | x == ord '\r' && y == ord '\n' -> BS.drop 2 bs
            _ -> BS.empty -- closing "--", too short, or unexpected (FIXME)

-- | Parse a post request. This function chooses the decoder from the
-- content type:
--
-- * @application/x-www-form-urlencoded@ bodies are decoded with
--   'decodeUrlPairs' and yield no files;
--
-- * @multipart/form-data; boundary=...@ bodies are parsed with
--   'parseMultipart';
--
-- * anything else yields no parameters and no files.
--
-- Only the first /content length/ bytes of the body are used. A length
-- whose leading part cannot be read as a number is treated as 0. The
-- multipart boundary may be a quoted string, and any parameters after it
-- (following a @;@) are ignored.
parsePost :: String -- ^ content type
          -> String -- ^ content length
          -> BS.ByteString -- ^ body of the post
          -> ([(String, String)], [(String, FileInfo)])
parsePost ctype clength body
    | urlenc `isPrefixOf` ctype = (decodeUrlPairs content, [])
    | formBound `isPrefixOf` ctype = parseMultipart boundProcessed content
    | otherwise = ([], [])
    where
        len = case reads clength of
            ((x, _):_) -> x
            [] -> 0
        content = BS.take len body
        urlenc = "application/x-www-form-urlencoded"
        formBound = "multipart/form-data; boundary="
        -- RFC 2046 allows the boundary value to be quoted, and more
        -- parameters may follow it. Neither the quotes nor those
        -- parameters are part of the delimiter.
        boundProcessed = case drop (length formBound) ctype of
            ('"':quoted) -> takeWhile (/= '"') quoted
            unquoted -> takeWhile (\ch -> ch /= ';' && ch /= ' ' && ch /= '\t') unquoted

-- | Decode the value of an HTTP_COOKIE header into key/value pairs.
--
-- Pairs are separated by @;@. Segments that are empty or contain only
-- spaces or tabs (for example from a trailing @; @) are skipped. Each
-- remaining segment is decoded with 'decodeCookie'.
decodeCookies :: String -> [(String, String)]
decodeCookies [] = []
decodeCookies s =
    let (first, rest) = break (== ';') s
        others = decodeCookies (dropWhile (== ';') rest)
     in if all isCookieSpace first
            then others
            else decodeCookie first : others

-- | Split a single @name=value@ cookie segment at its first @=@.
--
-- Surrounding spaces and tabs are trimmed from both the name and the value.
-- A segment without @=@ yields an empty value. The value is not unquoted
-- or percent-decoded.
decodeCookie :: String -> (String, String)
decodeCookie s =
    let (key, value) = break (== '=') s
        -- value is either empty or starts with the separating '='.
        value' = drop 1 value
     in (trimCookieSpace key, trimCookieSpace value')

-- | Whitespace allowed around cookie names and values (space or tab).
isCookieSpace :: Char -> Bool
isCookieSpace c = c == ' ' || c == '\t'

-- | Remove leading and trailing spaces and tabs.
trimCookieSpace :: String -> String
trimCookieSpace = reverse . dropWhile isCookieSpace . reverse . dropWhile isCookieSpace