{-# LANGUAGE TypeSynonymInstances #-}
---------------------------------------------------------
-- |
-- Module        : Web.Encodings
-- Copyright     : Michael Snoyman
-- License       : BSD3
--
-- Maintainer    : Michael Snoyman <michael@snoyman.com>
-- Stability     : Unstable
-- Portability   : portable
--
-- Various web encodings.
--
---------------------------------------------------------
module Web.Encodings
    (
      -- * Simple encodings.
      -- ** URL (percentage encoding)
      encodeUrl
    , decodeUrl
      -- ** HTML (entity encoding)
    , encodeHtml
-- FIXME    , decodeHtml
      -- ** JSON
    , encodeJson
      -- * HTTP level encoding.
      -- ** Query strings (key/value pairs in percent-encoding)
    , encodeUrlPairs
    , decodeUrlPairs
      -- ** Post parameters
    , FileInfo (..)
    , parseMultipart
    , parsePost
      -- ** Cookies
    , decodeCookies
      -- * Date/time encoding
    , formatW3
    ) where

import Data.Either (partitionEithers)
import Data.List (isPrefixOf)
import Data.Maybe (fromMaybe, mapMaybe)
import Data.Word (Word8)
import Numeric (showHex)
import Text.Printf (printf)

import Data.ByteString.Class
import qualified Data.ByteString.Lazy as BS
import Data.ByteString.Lazy.Util hiding (ord)
import Data.Mime.Header

import Data.Time.Clock
import Data.Time.Format
import System.Locale

-- | Percent-encode every byte that is not an RFC 3986 unreserved character.
--
-- The input is first converted to a lazy 'BS.ByteString' and each byte is
-- then encoded on its own. Multi-byte characters are therefore escaped byte
-- by byte. Assumes use of UTF-8 character encoding.
encodeUrl :: (LazyByteString x, LazyByteString y) => x -> y
encodeUrl input = fromLazyByteString (BS.concatMap encodeUrlByte bytes)
  where
    bytes = toLazyByteString input

-- | Code point of a character, converted to any integral type.
--
-- In this module the target is usually 'Word8', for comparisons against
-- raw bytes. No range check is done: 'fromIntegral' truncates code points
-- that do not fit the target type.
ord :: Integral i => Char -> i
ord c = fromIntegral (fromEnum c)

-- | Percent-encode a single byte.
--
-- Bytes in the RFC 3986 unreserved set (@A-Z@, @a-z@, @0-9@, @-@, @_@, @.@,
-- @~@) are passed through unchanged. Every other byte becomes @%XX@.
-- Uppercase hex digits are used, as RFC 3986 section 2.1 recommends for URI
-- producers. Decoders accept either case, so this stays interoperable.
encodeUrlByte :: Word8 -> BS.ByteString
encodeUrlByte w
    | unreserved = BS.singleton w
    | otherwise = toLazyByteString (printf "%%%02X" w :: String)
  where
    -- Unreserved characters per RFC 3986
    -- (see also http://en.wikipedia.org/wiki/Percent-encoding).
    within lo hi = ord lo <= w && w <= ord hi
    unreserved = within 'A' 'Z'
              || within 'a' 'z'
              || within '0' '9'
              || w `elem` map ord "-_.~"

-- | Decode percent-encoded data, also turning @+@ into a space.
--
-- Decoding happens on the raw bytes (see 'decodeUrlList'). The result is
-- then converted to the requested 'LazyByteString' type. Assumes use of
-- UTF-8 character encoding.
decodeUrl :: (LazyByteString x, LazyByteString y) => x -> y
decodeUrl encoded =
    let raw = BS.unpack (toLazyByteString encoded)
     in fromLazyByteString (BS.pack (decodeUrlList raw))

-- | Decode percent-escapes in a list of raw bytes, turning @+@ into a space.
--
-- A @%@ counts as an escape only when it is followed by two hexadecimal
-- digits (either case). Otherwise the @%@ is kept literally and decoding
-- continues with the next byte. Previously, malformed sequences such as
-- @%zz@ were silently turned into arbitrary bytes (e.g. NUL).
decodeUrlList :: [Word8] -> [Word8]
decodeUrlList = go
  where
    -- 37 is '%', 43 is '+', 32 is ' '.
    go (37 : hi : lo : rest)
        | Just h <- hexDigit hi
        , Just l <- hexDigit lo = h * 16 + l : go rest
    go (43 : rest) = 32 : go rest
    go (x : rest) = x : go rest
    go [] = []

    -- Value of an ASCII hex digit byte, or Nothing for any other byte.
    hexDigit w
        | 48 <= w && w <= 57 = Just (w - 48)   -- '0'..'9'
        | 65 <= w && w <= 70 = Just (w - 55)   -- 'A'..'F'
        | 97 <= w && w <= 102 = Just (w - 87)  -- 'a'..'f'
        | otherwise = Nothing

-- | Numeric value of one ASCII hexadecimal digit byte (@0-9@, @A-F@,
-- @a-f@).
--
-- Any other byte maps to 0, which cannot be told apart from the digit @0@.
-- FIXME: callers get no signal for invalid input.
fromHex :: Word8 -> Word8
fromHex w
    | w >= 48, w <= 57 = w - 48    -- '0'..'9'
    | w >= 65, w <= 70 = w - 55    -- 'A'..'F' (65 - 10)
    | w >= 97, w <= 102 = w - 87   -- 'a'..'f' (97 - 10)
    | otherwise = 0

-- | Escape special HTML characters.
--
-- Each character is replaced by its escaped form from 'encodeHtmlChar':
-- less-than, greater-than, ampersand, double quote and single quote become
-- entities, and everything else is copied unchanged.
encodeHtml :: String -> String
encodeHtml s = s >>= encodeHtmlChar

-- | The HTML-escaped form of a single character.
--
-- The five special characters map to their named entities (numeric @&#39;@
-- for the apostrophe). Any other character is returned as a one-element
-- string.
encodeHtmlChar :: Char -> String
encodeHtmlChar c = case c of
    '<'  -> "&lt;"
    '>'  -> "&gt;"
    '&'  -> "&amp;"
    '"'  -> "&quot;"
    '\'' -> "&#39;"
    _    -> [c]

-- | Convert a query string into key-value pairs. Leading @?@ characters are
-- stripped if necessary.
--
-- The input is split on @&@ and each segment is decoded by 'decodeUrlPair'.
-- Empty segments, such as those from @a=1&&b=2@ or a trailing @&@, are
-- skipped. Before, they produced a spurious empty key/value pair.
decodeUrlPairs :: (LazyByteString x, LazyByteString y, LazyByteString z)
               => x
               -> [(y, z)]
decodeUrlPairs = map decodeUrlPair
               . filter (not . BS.null)
               . BS.split (ord '&')
               . BS.dropWhile (== ord '?')
               . toLazyByteString

-- | Split one @key=value@ segment at its first @=@ and percent-decode both
-- halves with 'decodeUrl'.
--
-- A segment without @=@ yields an empty value. Only the separator itself is
-- removed, so any further @=@ characters belong to the value: @a==b@
-- decodes to key @a@ and value @=b@. Previously all leading @=@ were dropped
-- from the value.
decodeUrlPair :: (LazyByteString a, LazyByteString b)
              => BS.ByteString
              -> (a, b)
decodeUrlPair b =
    let (key, rest) = BS.break (== ord '=') b
        -- rest is either empty or starts with the separating '='.
        value = BS.drop 1 rest
     in (decodeUrl key, decodeUrl value)

-- | Build a query string from key/value pairs.
--
-- Each pair is rendered by 'encodeUrlPair' and the results are joined with
-- @&@. No leading question mark is added.
encodeUrlPairs :: (LazyByteString x, LazyByteString y, LazyByteString z)
               => [(x, y)]
               -> z
encodeUrlPairs pairs =
    fromLazyByteString (BS.intercalate ampersand (map encodeUrlPair pairs))
  where
    ampersand = BS.singleton (ord '&')

-- | Render one pair as @key=value@, percent-encoding both sides with
-- 'encodeUrl'.
encodeUrlPair :: (LazyByteString x, LazyByteString y)
               => (x, y)
               -> BS.ByteString
encodeUrlPair (key, value) =
    encodeUrl key `BS.append` (ord '=' `BS.cons` encodeUrl value)

-- | JSON-escape a string (see 'encJSString') and convert the result to any
-- 'LazyByteString' type. Does not wrap the output in quotation marks.
encodeJson :: LazyByteString x => String -> x
encodeJson str = fromLazyByteString (toLazyByteString (encJSString str))

-- | Escape a string for use inside a JSON string literal. No surrounding
-- quotes are added. Taken from json package by Sigbjorn Finne.
--
-- Double quotes and backslashes are backslash-escaped. Control characters
-- below U+0020 use a short escape (@\\b@, @\\f@, @\\n@, @\\r@, @\\t@) where
-- one exists, and a zero-padded four-digit @\\uXXXX@ escape (lowercase hex)
-- otherwise. All other characters are copied unchanged.
encJSString :: String -> String
encJSString = concatMap escape
  where
    escape '"' = "\\\""
    escape '\\' = "\\\\"
    escape c
        | c < '\x20' = '\\' : control c
        | otherwise = [c]

    control '\b' = "b"
    control '\f' = "f"
    control '\n' = "n"
    control '\r' = "r"
    control '\t' = "t"
    control c = 'u' : replicate (4 - length digits) '0' ++ digits
      where
        digits = showHex (fromEnum c) ""

-- | Information on an uploaded file from a multipart form.
data FileInfo = FileInfo
    { fileName :: String -- ^ the part's @filename@ attribute
    , fileContentType :: String -- ^ the part's @Content-Type@ ("" if absent)
    , fileContent :: BS.ByteString -- ^ raw bytes of the uploaded file
    }

-- | Shows the file name and content type only; the file content is not
-- included.
instance Show FileInfo where
    show info = concat
        [ "FileInfo: ", fileName info
        , " (", fileContentType info, ")"
        ]

-- | Parse a multipart form into parameters and files.
--
-- The content is split with 'getPieces', which prepends @--@ to the
-- boundary. Each part is parsed with 'parsePiece' in the 'Maybe' monad, and
-- parts that fail to parse are dropped (presumably those without a @name@
-- attribute). Parts with a @filename@ attribute become 'FileInfo' uploads;
-- the rest are plain fields. Each list keeps the order of the parts in the
-- body.
parseMultipart :: LazyByteString lbs
               => lbs -- ^ boundary
               -> BS.ByteString -- ^ content
               -> ([(String, String)], [(String, FileInfo)])
parseMultipart boundary' content =
    let boundary :: String
        boundary = fromLazyByteString $ toLazyByteString boundary'
     in partitionEithers $ mapMaybe parsePiece $ getPieces boundary content

-- | Parse a single part of a multipart/form-data POST.
--
-- The part is split into header lines and body by 'takeUntilBlank', and the
-- header lines are parsed with 'parseHeader'. The @name@ attribute of
-- @Content-Disposition@ is required: the result is whatever
-- 'lookupHeaderAttr' yields in @m@ when it is missing (presumably a
-- failure). If a @filename@ attribute is present, the part becomes a
-- 'Right' upload. That upload holds the part's @Content-Type@ (empty if
-- absent) and the unmodified body. Otherwise the part is a 'Left' field
-- whose value is the body after 'chompBS'. A charset parameter on
-- @Content-Type@ is currently ignored.
parsePiece :: Monad m
           => BS.ByteString
           -> m (Either (String, String) (String, FileInfo))
parsePiece b = do
    let (rawHeaders, body) = takeUntilBlank b
        headers = map parseHeader rawHeaders
        contentType = fromMaybe "" $ lookupHeader "Content-Type" headers
        uploadName = lookupHeaderAttr "Content-Disposition" "filename" headers
    fieldName <- lookupHeaderAttr "Content-Disposition" "name" headers
    let asField = Left (fieldName, fromLazyByteString (chompBS body))
        asFile f = Right (fieldName, FileInfo f contentType body)
    return (maybe asField asFile uploadName)

-- | Split up a multipart body along the given boundary.
--
-- The delimiter searched for is @--@ followed by the boundary. After each
-- delimiter, the bytes that follow decide what happens next (see
-- @checkRest@):
--
-- * a @\\n@ or @\\r\\n@ is skipped and splitting continues;
-- * @--@ (the closing delimiter) ends the split;
-- * fewer than two remaining bytes, or anything else, also ends the split.
--
-- Empty pieces are dropped. Every other piece is passed through 'chompBS'
-- (presumably to strip the line break that precedes the next delimiter).
--
-- NOTE(review): assumes 'breakAtString' returns the bytes before the first
-- occurrence of the delimiter and the bytes after it, with the delimiter
-- removed. TODO confirm. Under that assumption, a non-empty preamble before
-- the first delimiter is returned as a piece too.
getPieces :: String -- ^ boundary
          -> BS.ByteString -- ^ content
          -> [BS.ByteString]
-- Earlier split-based sketch, kept for reference. It does not typecheck as
-- written: 'BS.split' splits on a single byte rather than a multi-byte
-- delimiter, and @b@ is a String rather than a ByteString.
{- FIXME this would be nice...
getPieces b c =
    let fullBound = ord '-' `BS.cons'` (ord '-' `BS.cons'` b)
        pieces = fullBound `BS.split` c
     in filter (/= toLazyByteString "--") $
        filter (not . BS.null) $
        map chompBS pieces
-}
getPieces b c
    | BS.null c = []
    | otherwise =
        -- Take everything up to the next delimiter, then decide from the
        -- bytes after it whether more pieces follow.
        let fullBound = toLazyByteString ('-':'-':b)
            (next, rest) = breakAtString fullBound c
            rest' = checkRest rest
            rest'' = getPieces b rest'
         in if BS.null next then rest'' else chompBS next : rest''
    where
        br = ord '\r'
        bn = ord '\n'
        dash = ord '-'
        -- Inspect the bytes right after a delimiter. Return the remaining
        -- input to keep splitting, or empty to stop.
        -- NOTE(review): BS.length walks the whole remaining lazy input on
        -- every call; a bounded check such as BS.length (BS.take 2 bs) < 2
        -- would avoid that. The '--' and otherwise branches currently
        -- return the same result.
        checkRest bs
            | BS.length bs < 2 = BS.empty
            | BS.head bs == bn = BS.tail bs
            | BS.head bs == br && BS.head (BS.tail bs) == bn =
                BS.tail $ BS.tail bs
            | BS.head bs == dash && BS.head (BS.tail bs) == dash = BS.empty
            | otherwise = BS.empty -- FIXME

-- | Parse a post request. This function determines the correct decoding
-- function to use from the content type.
--
-- * @application/x-www-form-urlencoded@ bodies are decoded with
--   'decodeUrlPairs' and never yield files.
-- * @multipart/form-data; boundary=...@ bodies are handed to
--   'parseMultipart'. The boundary may be bare or double-quoted (RFC 2046
--   allows @boundary=\"...\"@). Quotes and any trailing @;@-separated
--   parameters are stripped; previously they stayed in the boundary and no
--   part could match.
-- * Any other content type yields no parameters and no files.
--
-- Only the first @content length@ bytes of the body are used. If no
-- leading integer can be read from the length, it is treated as 0.
parsePost :: String -- ^ content type
          -> String -- ^ content length
          -> BS.ByteString -- ^ body of the post
          -> ([(String, String)], [(String, FileInfo)])
parsePost ctype clength body
    | urlenc `isPrefixOf` ctype = (decodeUrlPairs content, [])
    | formBound `isPrefixOf` ctype = parseMultipart boundProcessed content
    | otherwise = ([], [])
    where
        len = case reads clength of
            ((x, _):_) -> x
            [] -> 0
        content = BS.take len body
        urlenc = "application/x-www-form-urlencoded"
        formBound = "multipart/form-data; boundary="
        boundProcessed = extractBoundary $ drop (length formBound) ctype
        -- RFC 2046 boundary characters never include '"' or ';', so these
        -- safely terminate a quoted or a bare boundary respectively.
        extractBoundary ('"':quoted) = takeWhile (/= '"') quoted
        extractBoundary bare = takeWhile (/= ';') bare

-- | Decode the value of an HTTP_COOKIE header into key/value pairs.
--
-- Segments are separated by runs of @;@ and each segment is parsed by
-- 'decodeCookie'. An empty input yields no pairs. An empty segment (e.g.
-- from a leading @;@) yields an empty key and value.
decodeCookies :: String -> [(String, String)]
decodeCookies header
    | null header = []
    | otherwise = decodeCookie cookie : decodeCookies remaining
  where
    (cookie, separators) = break (== ';') header
    remaining = dropWhile (== ';') separators

-- | Split one @name=value@ cookie segment at its first @=@.
--
-- Leading spaces are removed from the name (trailing ones are kept). The
-- value is everything after the first @=@, or empty if there is none.
decodeCookie :: String -> (String, String)
decodeCookie segment = (dropWhile (== ' ') name, drop 1 rest)
  where
    -- rest is either empty or begins with the separating '='.
    (name, rest) = break (== '=') segment

-- | Format a 'UTCTime' in W3 date-time format, e.g.
-- @2009-01-02T01:02:03-00:00@.
--
-- The date is rendered with @%F@ and the time of day with @%X@ under
-- 'defaultTimeLocale'. A fixed @-00:00@ offset is appended.
--
-- NOTE(review): the original doc calls this useful for setting cookies.
-- Cookie @Expires@ attributes conventionally use RFC 1123 dates instead, so
-- confirm the intended use.
formatW3 :: UTCTime -> String
formatW3 t = formatTime defaultTimeLocale w3Format t
  where
    w3Format = "%FT%X-00:00"