{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE OverloadedStrings #-}
---------------------------------------------------------
-- |
-- Module        : Web.Encodings
-- Copyright     : Michael Snoyman
-- License       : BSD3
--
-- Maintainer    : Michael Snoyman <michael@snoyman.com>
-- Stability     : Stable
-- Portability   : portable
--
-- Various web encodings.
--
---------------------------------------------------------
module Web.Encodings
    (
      -- * Simple encodings.
      -- ** URL (percentage encoding)
      encodeUrl
    , decodeUrl
    , decodeUrlFailure
    , DecodeUrlException (..)
      -- ** HTML (entity encoding)
    , encodeHtml
    , decodeHtml
      -- ** JSON
    , encodeJson
    , decodeJson
      -- * HTTP level encoding.
      -- ** Query string- pairs of percentage encoding
    , encodeUrlPairs
    , decodeUrlPairs
    , decodeUrlPairsFailure
      -- ** Post parameters
    , FileInfo (..)
    , parseMultipart
    , parsePost
      -- ** Specific HTTP headers
    , decodeCookies
    , parseCookies
    , parseHttpAccept
      -- * Date/time encoding
    , formatW3
      -- * WAI-specific decodings
    , Sink (..)
    , lbsSink
    , tempFileSink
    ) where

import Numeric (showHex)
import Data.List (isPrefixOf, sortBy)
import Web.Encodings.MimeHeader
import Data.Maybe (fromMaybe)

import Data.Time.Clock
import System.Locale
import Data.Time.Format

import Web.Encodings.StringLike (StringLike)
import qualified Web.Encodings.StringLike as SL
import Control.Failure
import Data.Char (ord, isControl)
import Data.Function (on)
import Data.Typeable (Typeable)
import Control.Exception (Exception)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Char8 as B8
import qualified Data.ByteString.Lazy as BL
import Data.Maybe (catMaybes)
import Data.Either (partitionEithers)
import System.Directory (getTemporaryDirectory, removeFile)
import System.IO

-- | Percent-encode a value, leaving only unreserved characters untouched.
--
-- The given value is implicitly converted to a UTF-8 bytestream; the
-- per-character work is delegated to 'encodeUrlChar'.
encodeUrl :: StringLike a => a -> a
encodeUrl input = SL.concatMapUtf8 encodeUrlChar input

-- | Percent-encode a single character.
--
-- Unreserved characters (RFC 3986: ASCII letters, digits and @-_.~@) are
-- returned unchanged, a space becomes @+@, and anything else becomes @%XY@
-- with upper-case hexadecimal digits. Only the low eight bits of the
-- character code contribute to the two hex digits.
encodeUrlChar :: Char -> String
encodeUrlChar ch
    | unreserved = [ch]
    | ch == ' ' = "+"
    | otherwise = ['%', hexDigit (upper `mod` 16), hexDigit lower]
  where
    unreserved = or
        [ 'A' <= ch && ch <= 'Z'
        , 'a' <= ch && ch <= 'z'
        , '0' <= ch && ch <= '9'
        , ch `elem` ['-', '_', '.', '~']
        ]
    (upper, lower) = fromEnum ch `divMod` 16
    -- Both call sites pass a value in 0..15, so the error branch is only a
    -- safety net.
    hexDigit n
        | n < 10 = toEnum (n + fromEnum '0')
        | n < 16 = toEnum (n - 10 + fromEnum 'A')
        | otherwise = error $ "Invalid argument to showHex: " ++ show n

-- | Decode percentage encoding, assuming the UTF-8 character encoding.
--
-- This is the lenient variant: when 'decodeUrlFailure' reports a parse
-- error, the original input is returned unchanged. Use 'decodeUrlFailure'
-- directly to be told about errors.
decodeUrl :: StringLike s => s -> s
decodeUrl s =
    case decodeUrlFailure s of
        Just decoded -> decoded
        Nothing -> s

-- | Reasons URL decoding can fail.
--
-- 'InvalidPercentEncoding' is raised by the percent-decoding step when a
-- @%@ is not followed by two hexadecimal digits; 'InvalidUtf8Encoding' is
-- raised by 'decodeUrlFailure' when the decoded bytes cannot be unpacked as
-- UTF-8.
data DecodeUrlException = InvalidPercentEncoding | InvalidUtf8Encoding
    deriving (Show, Typeable)
instance Exception DecodeUrlException

-- | Strict counterpart of 'decodeUrl': instead of falling back to the input,
-- it calls 'failure' with 'InvalidPercentEncoding' for malformed percent
-- escapes and with 'InvalidUtf8Encoding' when the decoded bytes are not
-- valid UTF-8.
decodeUrlFailure :: (Failure DecodeUrlException m, StringLike s,
                     Monad m)
                 => s -> m s
decodeUrlFailure s =
    decodeUrlFailure' s
        >>= maybe (failure InvalidUtf8Encoding) return . SL.unpackUtf8

-- | Percent-decode the input into a 'BS.ByteString'.
--
-- @%XY@ is replaced by the character 'getHex' produces for the two digits,
-- @+@ is replaced by a space, and every other character is copied through.
-- A @%@ that is not followed by two hexadecimal digits results in
-- 'InvalidPercentEncoding'.
decodeUrlFailure' :: (Failure DecodeUrlException m, StringLike s,
                      Monad m)
                  => s -> m BS.ByteString
decodeUrlFailure' input =
    case SL.uncons input of
        Nothing -> return SL.empty
        Just ('%', afterPercent) ->
            case hexPair afterPercent of
                Nothing -> failure InvalidPercentEncoding
                Just (byte, remainder) -> prepend byte remainder
        Just ('+', remainder) -> prepend ' ' remainder
        Just (ch, remainder) -> prepend ch remainder
  where
    -- Decode the remainder, then put the given character in front of it.
    prepend ch remainder = do
        decoded <- decodeUrlFailure' remainder
        return $ ch `SL.cons` decoded
    -- Read two hex digits; yields the decoded character and what follows.
    hexPair str = do
        (hi, str') <- SL.uncons str
        (lo, str'') <- SL.uncons str'
        byte <- getHex hi lo
        return (byte, str'')

-- | Combine two hexadecimal digit characters (high digit first) into the
-- character whose code is @hi * 16 + lo@. Yields 'Nothing' if either
-- character is not a hex digit according to 'hexVal'.
getHex :: Char -> Char -> Maybe Char
getHex x y =
    case (hexVal x, hexVal y) of
        (Just hi, Just lo) -> Just (toEnum (hi * 16 + lo))
        _ -> Nothing

-- | Escape the HTML-special characters of a value by running every
-- character through 'encodeHtmlChar'.
encodeHtml :: StringLike s => s -> s
encodeHtml input = SL.concatMap encodeHtmlChar input

-- | Entity-encode one character. Only @<@, @>@, @&@, the double quote and
-- the single quote are replaced; every other character is returned as is.
encodeHtmlChar :: Char -> String
encodeHtmlChar ch = case ch of
    '<' -> "&lt;"
    '>' -> "&gt;"
    '&' -> "&amp;"
    '"' -> "&quot;"
    '\'' -> "&#39;"
    _ -> [ch]

-- | Turn HTML-encoded content back into plain content.
--
-- Recognised entities are @&lt;@, @&gt;@, @&amp;@, @&quot;@ and numeric
-- references (@&#NNN;@, @&#xHHH;@, @&#XHHH;@). This is not the full set of
-- HTML entities. Failures are swallowed: an @&@ that does not start a
-- recognised entity is emitted literally and decoding carries on after it.
decodeHtml :: StringLike s => s -> s
decodeHtml s =
    case SL.uncons s of
        Nothing -> SL.empty
        Just ('&', rest) ->
            fromMaybe ('&' `SL.cons` decodeHtml rest) (entity rest)
        Just (ch, rest) -> ch `SL.cons` decodeHtml rest
  where
    -- Decode one entity (the text up to the next ';') plus everything
    -- following it.
    entity rest = do
        (name, after) <- SL.breakCharMaybe ';' rest
        -- entity names are short, so unpacking to a String is cheap
        ch <- lookupEntity (SL.unpack name)
        return $ ch `SL.cons` decodeHtml after
    lookupEntity name = case name of
        "lt" -> Just '<'
        "gt" -> Just '>'
        "amp" -> Just '&'
        "quot" -> Just '"'
        '#' : 'x' : hex -> readHexChar hex
        '#' : 'X' : hex -> readHexChar hex
        '#' : dec -> readDecChar dec
        _ -> Nothing

-- | Parse a non-empty string of hexadecimal digits into the character with
-- that code point.
--
-- Returns 'Nothing' when the string is empty, contains a character that is
-- not a hex digit, or denotes a value greater than @maxBound :: Char@.
-- The range check happens after every digit, so the accumulator can never
-- overflow and 'toEnum' is only applied to a valid code point (an unchecked
-- 'toEnum' raises an error for out-of-range values).
readHexChar :: String -> Maybe Char
readHexChar [] = Nothing
readHexChar s = helper 0 s where
    limit = fromEnum (maxBound :: Char)
    helper i [] = Just $ toEnum i
    helper i (c:cs) = do
        c' <- hexVal c
        let i' = i * 16 + c'
        -- stop as soon as the value leaves the valid Char range
        if i' > limit then Nothing else helper i' cs

-- | Numeric value (0-15) of a single hexadecimal digit, accepting both
-- upper- and lower-case letters; 'Nothing' for any other character.
hexVal :: Char -> Maybe Int
hexVal c
    | '0' <= c && c <= '9' = Just $ ord c - ord '0'
    | 'A' <= c && c <= 'F' = Just $ ord c - ord 'A' + 10
    | 'a' <= c && c <= 'f' = Just $ ord c - ord 'a' + 10
    | otherwise = Nothing

-- | Parse a leading decimal number from the string and convert it to the
-- character with that code point.
--
-- Returns 'Nothing' when no number can be read or when the number lies
-- outside @0 .. maxBound :: Char@. The number is read as an 'Integer' so
-- that very long digit strings cannot wrap around, and the range is checked
-- before 'toEnum' is applied (an unchecked 'toEnum' raises an error for
-- negative or too-large values). As before, text following the number is
-- ignored.
readDecChar :: String -> Maybe Char
readDecChar s =
    case (reads s :: [(Integer, String)]) of
        (i, _):_
            | 0 <= i && i <= toInteger (fromEnum (maxBound :: Char)) ->
                Just $ toEnum (fromInteger i)
        _ -> Nothing

-- | Split a query string into decoded key-value pairs.
--
-- Any leading @?@ characters are dropped, the remainder is split on @&@ and
-- each piece is decoded with the lenient 'decodeUrlPair'.
decodeUrlPairs :: StringLike s
               => s
               -> [(s, s)]
decodeUrlPairs query = map decodeUrlPair pieces
  where
    pieces = SL.split '&' (SL.dropWhile (== '?') query)

-- | Split a query string into decoded key-value pairs, reporting invalid
-- encodings through 'failure'.
--
-- Any leading @?@ characters are dropped, the remainder is split on @&@ and
-- each piece is decoded with 'decodeUrlPairFailure'.
decodeUrlPairsFailure :: (StringLike s, Failure DecodeUrlException m,
                          Monad m)
                      => s
                      -> m [(s, s)]
decodeUrlPairsFailure query =
    mapM decodeUrlPairFailure (SL.split '&' (SL.dropWhile (== '?') query))

-- | Break one @key=value@ piece at the @=@ and URL-decode both halves with
-- 'decodeUrlFailure'; the key is decoded first, so its failure wins.
decodeUrlPairFailure :: (StringLike s, Failure DecodeUrlException m,
                         Monad m)
                     => s
                     -> m (s, s)
decodeUrlPairFailure b =
    decodeUrlFailure rawKey >>= \key ->
    decodeUrlFailure rawValue >>= \value ->
    return (key, value)
  where
    (rawKey, rawValue) = SL.breakChar '=' b

-- | Break one @key=value@ piece at the @=@ and URL-decode both halves with
-- the lenient 'decodeUrl'.
decodeUrlPair :: StringLike s
              => s
              -> (s, s)
decodeUrlPair b = (decodeUrl key, decodeUrl value)
  where
    (key, value) = SL.breakChar '=' b

-- | Build a query string from key-value pairs: every pair is rendered with
-- 'encodeUrlPair' and the results are joined with @&@.
-- The leading question mark is not included.
encodeUrlPairs :: StringLike s
               => [(s, s)]
               -> s
encodeUrlPairs pairs =
    SL.intercalate (SL.pack "&") (map encodeUrlPair pairs)

-- | Render a single pair as @key=value@ with both sides URL-encoded.
encodeUrlPair :: StringLike s
               => (s, s)
               -> s
encodeUrlPair (key, value) =
    SL.append (encodeUrl key) (SL.cons '=' (encodeUrl value))

-- | JSON-encode the contents of a string by escaping each character with
-- 'encodeJsonChar'. The result is not wrapped in quotation marks.
-- Taken from json package by Sigbjorn Finne.
encodeJson :: StringLike s => s -> s
encodeJson input = SL.concatMap encodeJsonChar input

-- | JSON-escape one character.
--
-- Backspace, form feed, newline, carriage return, tab, the double quote and
-- the backslash get their short escapes. Any other control character is
-- written as a @\\u@ escape zero-padded to four hex digits; everything else
-- is passed through unchanged.
encodeJsonChar :: Char -> String
encodeJsonChar ch = case ch of
    '\b' -> "\\b"
    '\f' -> "\\f"
    '\n' -> "\\n"
    '\r' -> "\\r"
    '\t' -> "\\t"
    '"' -> "\\\""
    '\\' -> "\\\\"
    _   | not (isControl ch) -> [ch]
        | ch < '\x10' -> "\\u000" ++ digits
        | ch < '\x100' -> "\\u00" ++ digits
        | ch < '\x1000' -> "\\u0" ++ digits
        | otherwise -> [ch]
  where
    -- unpadded hexadecimal rendering of the character code
    digits = showHex (fromEnum ch) ""

-- | Undo JSON string escaping (the inverse of 'encodeJson').
--
-- Handles the single-character escapes @\\b@, @\\f@, @\\n@, @\\r@, @\\t@,
-- @\\\"@, @\\\/@ and @\\\\@, the non-standard @\\'@, and @\\uXXXX@ with
-- exactly four hexadecimal digits (each such escape produces one character;
-- surrogate pairs are not combined). A backslash that does not start a
-- recognised escape is kept literally and decoding continues after it.
decodeJson :: StringLike s => s -> s
decodeJson s = case SL.uncons s of
    Nothing -> SL.empty
    Just ('\\', xs) -> fromMaybe ('\\' `SL.cons` decodeJson xs) $ do
        (x, xs') <- SL.uncons xs
        if x == 'u'
            then do
                -- read exactly four hex digits following the 'u'
                (a, e) <- SL.uncons xs'
                (b, f) <- SL.uncons e
                (c, g) <- SL.uncons f
                (d, h) <- SL.uncons g
                res <- readHexChar [a, b, c, d]
                return $ res `SL.cons` decodeJson h
            else do
                c <- case x of
                        'b' -> return '\b'
                        'f' -> return '\f'
                        'n' -> return '\n'
                        'r' -> return '\r'
                        't' -> return '\t'
                        '"' -> return '"'
                        '\'' -> return '\''
                        -- "\/" is a valid JSON escape for the solidus
                        '/' -> return '/'
                        '\\' -> return '\\'
                        _ -> Nothing
                return $ c `SL.cons` decodeJson xs'
    Just (x, xs) -> x `SL.cons` decodeJson xs

-- | Information on an uploaded file.
--
-- @s@ is the string type used for the metadata and @c@ the type holding the
-- file contents.
data FileInfo s c = FileInfo
    { fileName :: s -- ^ the @filename@ attribute of the part's Content-Disposition header
    , fileContentType :: s -- ^ the part's Content-Type header (empty when the header is absent)
    , fileContent :: c -- ^ the body of the part
    }
    deriving (Eq, Show)

-- | Parse a multipart form into parameters and files.
--
-- The content is cut into pieces with 'getPieces' and every piece is parsed
-- with 'parsePiece'; pieces that fail to parse are silently dropped. The
-- first component of the result holds the plain parameters, the second the
-- uploaded files.
parseMultipart :: StringLike s
               => String -- ^ boundary
               -> s -- ^ content
               -> ([(s, s)], [(s, FileInfo s s)])
parseMultipart boundary content =
    partitionEithers
        [ piece | Just piece <- map parsePiece (getPieces boundary content) ]

-- | Parse a single segment of a multipart/form-data POST.
--
-- The segment is split into headers and content at the first blank line.
-- The @name@ attribute of the Content-Disposition header is required; its
-- absence fails in the surrounding monad. When a @filename@ attribute is
-- present the segment is a file ('Right', content untouched), otherwise it
-- is a plain parameter ('Left', content passed through 'SL.chomp').
parsePiece :: (StringLike s, Failure (AttributeNotFound s) m,
               Monad m)
           => s
           -> m (Either (s, s) (s, FileInfo s s))
parsePiece b = do
    name <- lookupHeaderAttr (SL.pack "Content-Disposition")
                             (SL.pack "name")
                             headers
    return $ maybe (Left (name, SL.chomp content))
                   (\f -> Right (name, FileInfo f ctype content))
                   filename
  where
    (rawHeaders, content) = SL.takeUntilBlank b
    headers = map parseHeader rawHeaders
    filename = lookupHeaderAttr (SL.pack "Content-Disposition")
                                (SL.pack "filename")
                                headers
    -- a missing Content-Type header yields an empty content type
    ctype = fromMaybe SL.empty
          $ lookupHeader (SL.pack "Content-Type") headers

-- | Split content into its pieces along the given boundary.
--
-- The delimiter searched for is the boundary prefixed with two dashes.
-- Empty pieces are skipped; non-empty pieces are passed through 'SL.chomp'.
-- After each delimiter a line ending (@\\n@ or @\\r\\n@) is skipped; anything
-- else following the delimiter, including the closing @--@, ends the
-- splitting.
getPieces :: StringLike s
          => String -- ^ boundary
          -> s -- ^ content
          -> [s]
getPieces boundary content
    | SL.null content = []
    | SL.null piece = remaining
    | otherwise = SL.chomp piece : remaining
  where
    delimiter = SL.pack $ '-' `SL.cons` ('-' `SL.cons` boundary)
    (piece, afterDelimiter) = SL.breakString delimiter content
    remaining = getPieces boundary (skipLineEnd afterDelimiter)
    -- Decide what to continue with after a delimiter.
    -- NOTE(review): assumes 'SL.lengthLT' 2 means "shorter than two
    -- characters", which is what makes the head/tail calls below safe.
    skipLineEnd bs
        | SL.lengthLT 2 bs = SL.empty
        | SL.head bs == '\n' = SL.tail bs
        | SL.head bs == '\r' && SL.head (SL.tail bs) == '\n' =
            SL.tail (SL.tail bs)
        | SL.head bs == '-' && SL.head (SL.tail bs) == '-' = SL.empty
        | otherwise = SL.empty -- FIXME

-- | Parse the body of a POST request, choosing the decoder from the content
-- type.
--
-- Only the first @content length@ characters of the body are considered; a
-- content length that cannot be read counts as 0. URL-encoded forms yield
-- parameters only, multipart forms yield parameters and files, and any other
-- content type yields two empty lists.
parsePost :: StringLike s
          => String -- ^ content type
          -> String -- ^ content length
          -> s -- ^ body of the post
          -> ([(s, s)], [(s, FileInfo s s)])
parsePost ctype clength body
    | urlenc `SL.isPrefixOf` ctype = (decodeUrlPairs payload, [])
    | formBound `isPrefixOf` ctype =
        -- everything after the "boundary=" prefix is taken as the boundary
        parseMultipart (drop (length formBound) ctype) payload
    | otherwise = ([], [])
  where
    payload = SL.take declaredLength body
    declaredLength = case reads clength of
        (n, _) : _ -> n
        [] -> 0

-- | Content-type prefix that makes 'parsePost' treat the body as a
-- URL-encoded form.
urlenc :: String
urlenc = "application/x-www-form-urlencoded"

-- | Content-type prefix that makes 'parsePost' treat the body as a multipart
-- form; the text following this prefix is used as the boundary.
formBound :: String
formBound = "multipart/form-data; boundary="

{-# DEPRECATED decodeCookies "Please use parseCookies instead" #-}
-- | Deprecated alias for 'parseCookies'.
decodeCookies :: StringLike s => s -> [(s, s)]
decodeCookies header = parseCookies header

-- | Decode the value of an HTTP_COOKIE header into key/value pairs.
--
-- The header is cut at @;@ characters (runs of them count as one separator)
-- and each piece is handed to 'parseCookie'.
parseCookies :: StringLike s => s -> [(s, s)]
parseCookies s
    | SL.null s = []
    | otherwise =
        parseCookie cookie : parseCookies (SL.dropWhile (== ';') remainder)
  where
    (cookie, remainder) = SL.break (== ';') s

-- | Split a single cookie at its first @=@. Leading spaces are removed from
-- the key; the value is whatever follows the @=@ (or the empty remainder
-- when there is no @=@).
parseCookie :: StringLike s => s -> (s, s)
parseCookie s = (SL.dropWhile (== ' ') rawKey, stripEquals rawValue)
  where
    (rawKey, rawValue) = SL.break (== '=') s
    -- 'SL.break' leaves the separator on the second half; drop it if present
    stripEquals v = case SL.uncons v of
        Just ('=', rest) -> rest
        _ -> v

-- | Parse the HTTP accept string to determine supported content types.
--
-- The header is split on commas, each entry is paired with its quality value
-- by 'grabQ', and the content types are returned ordered from highest to
-- lowest quality.
parseHttpAccept :: StringLike s => s -> [s]
parseHttpAccept header = map fst ranked
  where
    ranked = sortBy (rcompare `on` snd) (map grabQ (SL.split ',' header))

-- | Reversed comparison: the result of 'compare' with 'LT' and 'GT'
-- swapped, used to sort in descending order.
rcompare :: Ord a => a -> a -> Ordering
rcompare x y = invert (compare x y)
  where
    invert LT = GT
    invert GT = LT
    invert EQ = EQ

-- | Pair one entry of an accept header with a quality value.
--
-- The entry is broken at the first @;@; the first half (leading spaces
-- removed) is returned as is. The second half is broken at its first @=@ and
-- what follows is read as the quality by 'readQ'.
grabQ :: StringLike s => s -> (s, Double)
grabQ s = (trimWhite entry, readQ (trimWhite quality))
  where
    (entry, params) = SL.breakChar ';' s
    quality = snd (SL.breakChar '=' params)

-- | Read a quality value, defaulting to 1.0 when no number can be read from
-- the start of the input.
readQ :: StringLike s => s -> Double
readQ s = pick (reads (SL.unpack s))
  where
    pick ((q, _) : _) = q
    pick _ = 1.0

-- | Remove leading space characters (only @' '@, and only at the front).
trimWhite :: StringLike s => s -> s
trimWhite str = SL.dropWhile (== ' ') str

-- | Render a 'UTCTime' in W3 format using the default time locale; useful
-- for setting cookies. The offset suffix is always the literal @-00:00@.
formatW3 :: UTCTime -> String
formatW3 time = formatTime defaultTimeLocale "%FT%X-00:00" time

-- | A destination for data, the opposite of a 'Source'.
--
-- @x@ is the state threaded through the sink while data is being appended
-- and @y@ is the result produced once the sink is closed.
data Sink x y = Sink
    { sinkInit :: IO x -- ^ create the initial state
    , sinkAppend :: x -> BS.ByteString -> IO x -- ^ add a chunk of data, yielding the updated state
    , sinkClose :: x -> IO y -- ^ turn the final state into the result
    , sinkFinalize :: y -> IO () -- ^ clean up after the result (e.g. 'tempFileSink' removes its file here)
    }

-- | A 'Sink' that collects all chunks in memory and yields them as a lazy
-- bytestring. The state is a difference list, so chunks keep their arrival
-- order. Finalizing does nothing.
lbsSink :: Sink ([BS.ByteString] -> [BS.ByteString]) BL.ByteString
lbsSink = Sink
    { sinkInit = return id
    , sinkAppend = \chunks chunk -> return (chunks . (chunk :))
    , sinkClose = return . BL.fromChunks . ($ [])
    , sinkFinalize = const (return ())
    }

-- | A 'Sink' that writes all chunks to a freshly created binary file in the
-- system temporary directory. Closing the sink closes the handle and yields
-- the file path; finalizing removes the file.
tempFileSink :: Sink (FilePath, Handle) FilePath
tempFileSink = Sink
    { sinkInit =
        getTemporaryDirectory >>= \dir -> openBinaryTempFile dir "webenc.buf"
    , sinkAppend = \state@(_, h) bs -> BS.hPut h bs >> return state
    , sinkClose = \(fp, h) -> hClose h >> return fp
    , sinkFinalize = removeFile
    }

-- | State of a parser that is parameterised over a @seed@ type.
-- NOTE(review): the code that drives these states is not visible here; the
-- meaning of the individual constructor fields should be confirmed against
-- it.
data ParseState seed
    = PSBegin ([BS.ByteString] -> [BS.ByteString]) -- chunks held as a difference list
    | PSParam BL.ByteString BL.ByteString
    | PSFile BL.ByteString BL.ByteString BL.ByteString seed BS.ByteString
    | PSNothing
-- | Debugging output: each state is shown as a tuple of its constructor name
-- and its bytestring fields. For 'PSBegin' the chunks are concatenated; for
-- 'PSFile' the last two fields (the seed and the strict bytestring) are
-- left out.
instance Show (ParseState x) where
    show (PSBegin x) = show ("PSBegin" :: String, B8.unpack $ B8.concat $ x [])
    show (PSParam x y) = show ("PSParam" :: String, x, y)
    show (PSFile x y z _ _) = show ("PSFile" :: String, x, y, z)
    show PSNothing = "PSNothing"