{-# LANGUAGE OverloadedStrings, ForeignFunctionInterface #-}
-- | This module contains a structure for representing web URLs. We don't try
--   to be a fully general URI parser (so no @mailto:@ etc), but it's a lot
--   better than "Network.URI" for dealing with HTTP(S)
module Network.MiniHTTP.URL
  ( URL(..)
  , RelativeURL(..)
  , Scheme(..)
  , Host(..)
  , toRelative
  , parse
  , parseRelative
  , parseArguments
  , serialise
  , serialiseRelative
  , serialiseArguments
  ) where

import           Control.Applicative(Alternative(..))
import           Control.Exception (handle)
import           Control.Monad (when, liftM)

import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import           Data.ByteString.Char8 ()
import           Data.ByteString.Internal (c2w, w2c)
import qualified Data.Binary.Put as P
import qualified Data.Binary.Strict.Class as C
import qualified Data.Binary.Strict.Get as G
import qualified Data.Binary.Strict.ByteSet as BSet
import           Data.Char (toLower)
import           Data.List (intersperse)
import qualified Data.Map as Map
import           Data.Maybe (fromMaybe, isJust, catMaybes)
import           Data.String (fromString)

import           Foreign
import           Foreign.C.Types

import           System.IO.Unsafe (unsafePerformIO)

import           Text.Printf (printf)

import qualified Network.Socket as S


-- | An absolute web URL, split into its components.
data URL = URL
  { urlScheme :: Scheme
    -- ^ transport scheme of the URL
  , urlUser :: Maybe B.ByteString
    -- ^ user name from the authority part, if one was given
  , urlPassword :: Maybe B.ByteString
    -- ^ password from the authority part, if one was given
  , urlHost :: Host
    -- ^ host the resource lives on
  , urlPort :: Int
    -- ^ defaults based on the scheme
  , urlPath :: B.ByteString
    -- ^ does not include leading '/'
  , urlArguments :: Map.Map B.ByteString B.ByteString
    -- ^ query arguments, keyed by argument name
  , urlFragment :: Maybe B.ByteString
    -- ^ doesn't include '#'
  } deriving (Eq)

-- | A URL is shown in its serialised form (see 'serialise'), with each byte
--   turned into the 'Char' of the same code.
instance Show URL where
  show url = toString (serialise url)

-- | A relative URL: the path, arguments and fragment of a 'URL' under new
--   field names. Keeping it a separate type stops absolute and relative
--   URLs from being mixed up.
data RelativeURL = RelativeURL
  { rurlPath :: B.ByteString
    -- ^ does not include leading '/'
  , rurlArguments :: Map.Map B.ByteString B.ByteString
    -- ^ query arguments, keyed by argument name
  , rurlFragment :: Maybe B.ByteString
    -- ^ fragment, if any
  } deriving (Show, Eq)

-- | The transport a URL is fetched over.
data Scheme
  = HTTP
  | HTTPS
  deriving (Show, Eq)

-- | Where the resource can be found: either a literal IP address or a
--   host name kept as the raw bytes.
data Host
  = IPv4Literal S.HostAddress
  | IPv6Literal S.HostAddress6
  | Hostname B.ByteString
  deriving (Show, Eq)

-- | Run a parser over the given bytes, turning a parse failure into
--   'Nothing'. The failure message and any input the parser left
--   unconsumed are both discarded.
wrapGet :: G.Get a -> B.ByteString -> Maybe a
wrapGet parser input =
  either (const Nothing) Just $ fst $ G.runGet parser input

-- | Parse an absolute URL, giving 'Nothing' if the input is not accepted
--   by the URL parser.
parse :: B.ByteString -> Maybe URL
parse input = wrapGet parseURL input

-- | Parse just an arguments map. Can be useful for POST requests. Warning: it
--   occurs to me that the arguments in a POST request might include unescaped
--   '#' symbols. In a URL that would be illegal, but the parser may need to be
--   reworked for that. As written, the key/value parser stops at the first
--   '#' and whatever follows it is ignored.
parseArguments :: B.ByteString -> Maybe (Map.Map B.ByteString B.ByteString)
parseArguments input = fmap Map.fromList (wrapGet parseKVs input)

-- | Parse a relative URL (path, arguments and fragment), giving 'Nothing'
--   if the input is not accepted by the relative URL parser.
parseRelative :: B.ByteString -> Maybe RelativeURL
parseRelative input = wrapGet parseRelativeURL input

-- | Drop the scheme, credentials, host and port of a URL, keeping only its
--   path, arguments and fragment.
toRelative :: URL -> RelativeURL
toRelative url@(URL {}) =
  RelativeURL { rurlPath = urlPath url
              , rurlArguments = urlArguments url
              , rurlFragment = urlFragment url
              }

-- | Binding to the C library's @inet_ntop@. In this module it is called with
--   an address family, a pointer to the binary address, an output buffer and
--   that buffer's size, and the text form is then read back from the buffer.
--   NOTE(review): POSIX declares @inet_ntop@ as returning @const char *@
--   (NULL on failure); this import types the result as @IO ()@, so a failure
--   cannot be detected by callers. Confirm that this is intended.
foreign import ccall unsafe "inet_ntop"
  inet_ntop :: CInt           -> Ptr Word32 -> Ptr CChar -> CSize -> IO ()

-- | Binding to the C library's @htonl@. 'serialiseIPv6' applies it to each
--   32-bit word of an address before handing the words to 'inet_ntop'.
foreign import ccall unsafe "htonl"
  htonl :: Word32 -> IO Word32

-- | Render an IPv4 address as text by poking it into temporary memory and
--   calling the C @inet_ntop@ on it. The address is passed through
--   unchanged (no byte swapping is done here).
serialiseIPv4 :: S.HostAddress -> B.ByteString
serialiseIPv4 v4 = unsafePerformIO $
  alloca $ \addrPtr -> do
    poke addrPtr v4
    allocaBytes bufLen $ \buf -> do
      inet_ntop (S.packFamily S.AF_INET) addrPtr buf (fromIntegral bufLen)
      B.packCString buf
  where
    -- Size in bytes of the text buffer handed to inet_ntop.
    bufLen :: Int
    bufLen = 17

-- | Render an IPv6 address as text. Each of the four 32-bit words is passed
--   through @htonl@, the results are written to a temporary array, and the C
--   @inet_ntop@ formats that array into a temporary buffer.
serialiseIPv6 :: S.HostAddress6 -> B.ByteString
serialiseIPv6 (a, b, c, d) = unsafePerformIO $
  allocaArray 4 $ \addrPtr -> do
    netOrder <- mapM htonl [a, b, c, d]
    pokeArray addrPtr netOrder
    allocaBytes bufLen $ \buf -> do
      inet_ntop (S.packFamily S.AF_INET6) addrPtr buf (fromIntegral bufLen)
      B.packCString buf
  where
    -- Size in bytes of the text buffer handed to inet_ntop.
    bufLen :: Int
    bufLen = 47

-- | The port assumed for a scheme when a URL does not name one; also the
--   port that is left out when a URL is serialised.
defaultPort :: Scheme -> Int
defaultPort scheme =
  case scheme of
    HTTP -> 80
    HTTPS -> 443

-- | Run a 'P.Put' action built from the given value and collapse the lazy
--   output into a single strict ByteString.
doPut :: (a -> P.Put) -> a -> B.ByteString
doPut putter value = B.concat (BL.toChunks (P.runPut (putter value)))

-- | Convert a URL to a ByteString. It's the same as 'show', except for the
--   type of the return.
serialise :: URL -> B.ByteString
serialise url = doPut putURL url

-- | Convert a relative URL to a ByteString. The output always starts with
--   a '/'.
serialiseRelative :: RelativeURL -> B.ByteString
serialiseRelative rurl = doPut putRelativeURL rurl

-- | Serialise just an arguments map as '&'-separated, percent-encoded
--   @key=value@ pairs. Can be useful for POST requests.
serialiseArguments :: Map.Map B.ByteString B.ByteString -> B.ByteString
serialiseArguments args = doPut putArguments args

-- | Write out an absolute URL: scheme, optional credentials, host, port
--   (only when it differs from the scheme's default) and then the relative
--   part. The user name and password are written as stored, without
--   percent-encoding.
putURL :: URL -> P.Put
putURL url = do
  P.putByteString $ case scheme of
    HTTP -> "http"
    HTTPS -> "https"
  P.putByteString "://"

  -- Credentials: "user", then ":password", then a closing '@' if either
  -- of the two was present.
  case urlUser url of
    Just user -> P.putByteString user
    Nothing -> return ()
  case urlPassword url of
    Just pw -> P.putWord8 (c2w ':') >> P.putByteString pw
    Nothing -> return ()
  when (any isJust [urlUser url, urlPassword url]) $ P.putWord8 (c2w '@')

  -- IPv6 literals are the only host form wrapped in square brackets.
  case urlHost url of
    Hostname name -> P.putByteString name
    IPv4Literal v4 -> P.putByteString (serialiseIPv4 v4)
    IPv6Literal v6 -> do
      P.putWord8 (c2w '[')
      P.putByteString (serialiseIPv6 v6)
      P.putWord8 (c2w ']')

  when (port /= defaultPort scheme) $ do
    P.putWord8 (c2w ':')
    P.putByteString (fromString (show port))

  putRelativeURL (toRelative url)
  where
    scheme = urlScheme url
    port = urlPort url

-- | Write out a relative URL: a leading '/', the percent-encoded path, the
--   arguments (preceded by '?') when the map is non-empty, and the fragment
--   (preceded by '#') when there is one.
--
--   The fragment is stored without its '#', so the delimiter is added here.
--   It is written whether or not there are any arguments, and its bytes are
--   emitted as stored (no percent-encoding).
putRelativeURL :: RelativeURL -> P.Put
putRelativeURL rurl = do
  P.putWord8 $ c2w '/'
  P.putByteString $ encodeString pathSafeChars $ rurlPath rurl
  when (not $ Map.null $ rurlArguments rurl) $ do
    P.putWord8 $ c2w '?'
    putArguments $ rurlArguments rurl
  -- Deliberately outside the 'when' above: a URL may have a fragment
  -- without having any arguments.
  case rurlFragment rurl of
    Nothing -> return ()
    Just frag -> do
      P.putWord8 $ c2w '#'
      P.putByteString frag

-- | Write an arguments map as '&'-separated pairs, in the map's key order.
--   A pair whose key and value are both empty is skipped; a pair with an
--   empty value is written as the bare key; anything else is written as
--   @key=value@. Keys and values are percent-encoded.
putArguments :: Map.Map B.ByteString B.ByteString -> P.Put
putArguments args =
  sequence_ $ intersperse separator $ catMaybes $ map putPair $ Map.toList args
  where
    separator = P.putWord8 (c2w '&')

    putEncoded = P.putByteString . encodeString safeChars

    putPair (key, value)
      | B.null key && B.null value = Nothing
      | B.null value = Just (putEncoded key)
      | otherwise = Just $ do
          putEncoded key
          P.putWord8 (c2w '=')
          putEncoded value

-- | The ASCII decimal digits, '0' through '9'.
digits :: BSet.ByteSet
digits = BSet.fromList $ map c2w ['0' .. '9']

-- | The ASCII hexadecimal digits, in both lower and upper case.
hexChars :: BSet.ByteSet
hexChars = BSet.fromList $ map c2w $ ['0' .. '9'] ++ ['a' .. 'f'] ++ ['A' .. 'F']

-- | Is the byte an ASCII hexadecimal digit (a member of 'hexChars')?
isHexChar :: Word8 -> Bool
isHexChar byte = hexChars `BSet.member` byte

-- | The list of unsafe characters in a URL, from RFC 1738: the bytes
--   0x00-0x1f, the bytes 0x7f-0xff, and the punctuation listed below.
unsafeChars :: BSet.ByteSet
unsafeChars = foldr1 BSet.union
  [ BSet.range 0 0x1f
  , BSet.range 0x7f 0xff
  , BSet.fromList (map c2w "$&+,/:;=?@ \"<>#%{}[]|\\^~`")
  ]

-- | Every byte that is not in 'unsafeChars'. These are written out as-is
--   when encoding argument keys and values.
safeChars :: BSet.ByteSet
safeChars = BSet.complement $ unsafeChars

-- | These are the characters which are safe in a path: everything in
--   'safeChars', plus '/'.
pathSafeChars :: BSet.ByteSet
pathSafeChars = BSet.union safeChars (BSet.singleton (c2w '/'))

-- | Parse a single key, value pair. The key runs up to the first '=', '&'
--   or '#'; if an '=' follows, the value runs up to the next '&' or '#'.
--   A missing value is treated as empty. Both parts are URL-decoded.
parseKV :: G.Get (B.ByteString, B.ByteString)
parseKV = do
  rawKey <- C.spanOf isKeyByte
  rawValue <- C.optional (C.word8 (c2w '=') >> C.spanOf isValueByte)
  return (decodeString rawKey, decodeString (fromMaybe "" rawValue))
  where
    isValueByte b = b /= c2w '#' && b /= c2w '&'
    isKeyByte b = b /= c2w '#' && b /= c2w '=' && b /= c2w '&'

-- | Parse a set of URL encoded key, value pairs: one pair, followed by any
--   number of further pairs each introduced by '&'.
parseKVs :: G.Get [(B.ByteString, B.ByteString)]
parseKVs =
  parseKV >>= \firstPair ->
    liftM (firstPair :) (C.many (C.word8 (c2w '&') >> parseKV))

-- | Parse a bracketed host such as @[::1]@ and return the non-empty bytes
--   between the brackets. The contents are not validated here.
parseIPv6 :: G.Get B.ByteString
parseIPv6 =
  C.word8 (c2w '[') >>
  C.spanOf1 (/= c2w ']') >>= \inner ->
  C.word8 (c2w ']') >>
  return inner

-- | Turn each byte of the ByteString into the 'Char' with the same code.
toString :: B.ByteString -> String
toString bytes = [w2c b | b <- B.unpack bytes]

-- | Parse the relative part of a URL. Empty input gives an empty path with
--   no arguments and no fragment. Otherwise the input must start with '/',
--   followed by the path (up to '?' or '#'), optional arguments introduced
--   by '?', and an optional fragment introduced by '#', which takes the
--   rest of the input. Anything left over is a parse failure.
--
--   The path is returned exactly as it appears in the input; unlike the
--   argument keys and values it is not percent-decoded here.
parseRelativeURL :: G.Get RelativeURL
parseRelativeURL = do
  noInput <- C.isEmpty
  if noInput
     then return $ RelativeURL "" Map.empty Nothing
     else parseNonEmpty
  where
    parseNonEmpty = do
      C.word8 (c2w '/')
      path <- C.spanOf (\b -> not (b == c2w '?' || b == c2w '#'))
      margs <- C.optional (C.word8 (c2w '?') >> liftM Map.fromList parseKVs)
      mfrag <- C.optional (C.word8 (c2w '#') >> C.remaining >>= C.getByteString)

      finished <- C.isEmpty
      when (not finished) $ fail "Trailing garbage"

      return $ RelativeURL path (fromMaybe Map.empty margs) mfrag

-- | Parse an absolute URL of the form
--   @scheme:\/\/[user[:password]\@]host[:port][\/path[?args][#fragment]]@.
--
--   * The scheme is matched case-insensitively and must be @http@ or
--     @https@.
--
--   * The user name and password may not contain '/', '?' or '#'. Without
--     that restriction an '@' inside the path or arguments would be taken
--     as the end of the credentials, and the text after it as the host.
--
--   * The host text is handed to @getAddrInfo@ with @AI_NUMERICHOST@; when
--     that yields an IPv4 or IPv6 address the host becomes a literal,
--     otherwise (including when the call throws) it is kept as a
--     'Hostname'.
--
--   * The port must lie in 1..65535 and defaults to 'defaultPort'. It is
--     read as an 'Integer' so that an over-long digit string cannot wrap
--     around into the valid range.
parseURL :: G.Get URL
parseURL = do
  scheme' <- C.spanOf1 (/= c2w ':')
  scheme <- case map (toLower . w2c) $ B.unpack scheme' of
                 "http" -> return HTTP
                 "https" -> return HTTPS
                 _ -> fail "Unknown scheme"
  C.string "://"

  -- Bytes which end the authority part and so can never be credentials.
  let endsAuthority x = x == c2w '/' || x == c2w '?' || x == c2w '#'

  muserpw <- C.optional $ do
    user <- C.spanOf (\x -> x /= c2w '@' && x /= c2w ':' && not (endsAuthority x))
    pw <- C.optional $ do
      C.word8 $ c2w ':'
      C.spanOf (\x -> x /= c2w '@' && not (endsAuthority x))
    -- No '@' here means this was not a credentials section after all
    -- (e.g. we just read "host:port"); the enclosing optional then
    -- yields Nothing.
    C.word8 $ c2w '@'
    return (user, pw)

  host' <- parseIPv6 <|> C.spanOf1 (\x -> x /= c2w ':' && x /= c2w '/')

  -- The address lookup runs in IO from inside the parser; any exception
  -- from it is turned into Nothing.
  mhost <- return $ unsafePerformIO $ handle (const $ return Nothing) $ do
    ai <- S.getAddrInfo (Just (S.defaultHints { S.addrFlags = [S.AI_NUMERICHOST] }))
                        (Just $ toString host') Nothing
    case ai of
         [] -> return Nothing
         info:_ ->
           case S.addrFamily info of
                S.AF_INET ->
                  case S.addrAddress info of
                       (S.SockAddrInet _ host) -> return $ Just $ IPv4Literal host
                       _ -> return Nothing
                S.AF_INET6 ->
                  case S.addrAddress info of
                       (S.SockAddrInet6 _ _ host _) -> return $ Just $ IPv6Literal host
                       _ -> return Nothing
                _ -> return Nothing

  mport <- C.optional $ do
    C.word8 $ c2w ':'
    s <- C.spanOf1 $ BSet.member digits
    case (reads (toString s) :: [(Integer, String)]) of
         [(x, "")] ->
           if x > 0 && x < 65536
              then return (fromIntegral x)
              else fail "Port number out of range"
         _ -> fail "Invalid port number"

  rurl <- parseRelativeURL

  return $ URL { urlScheme = scheme
               , urlUser = fmap fst muserpw
               , urlPassword = muserpw >>= snd
               , urlHost = fromMaybe (Hostname host') mhost
               , urlPort = fromMaybe (defaultPort scheme) mport
               , urlPath = rurlPath rurl
               , urlArguments = rurlArguments rurl
               , urlFragment = rurlFragment rurl
               }

-- | The numeric value (0-15) of an ASCII hexadecimal digit, accepting both
--   cases of the letters. Calls 'error' for any other byte, so callers must
--   check the byte first.
toHexNibble :: Word8 -> Word8
toHexNibble w
  | w >= 0x30 && w <= 0x39 = w - 0x30        -- '0'..'9'
  | w >= 0x41 && w <= 0x46 = w - 0x41 + 10   -- 'A'..'F'
  | w >= 0x61 && w <= 0x66 = w - 0x61 + 10   -- 'a'..'f'
  | otherwise = error "toHexNibble passed non-hex char"

-- | Combine the first two bytes of the string, read as hexadecimal digits,
--   into one byte (first digit is the high nibble). The string must hold at
--   least two bytes and both must be hex digits.
toHexByte :: B.ByteString -> Word8
toHexByte bs = (high `shiftL` 4) .|. low
  where
    high = toHexNibble (B.index bs 0)
    low = toHexNibble (B.index bs 1)

-- | Replace any % escaped bytes in the given string with their Word8 values.
--
--   A '%' that is not followed by two hexadecimal digits is kept as a
--   literal '%' and decoding carries on after it, so one malformed escape
--   does not prevent later, well-formed escapes from being decoded.
--
--   Only percent escapes are touched: the recursion stays within this
--   function, so '+' is never rewritten here (that is 'decodeString''s job).
decodePercents :: B.ByteString -> B.ByteString
decodePercents bs
  | B.null rest = bs
  | B.length rest >= 3 &&
    isHexChar (rest `B.index` 1) &&
    isHexChar (rest `B.index` 2) =
      B.concat [ plain
               , B.singleton (toHexByte (B.drop 1 rest))
               , decodePercents (B.drop 3 rest)
               ]
  | otherwise =
      B.concat [ plain
               , B.singleton (c2w '%')
               , decodePercents (B.drop 1 rest)
               ]
  where
    -- plain: bytes before the first '%'; rest: starts at that '%' (or is
    -- empty when the string has none).
    (plain, rest) = B.span (/= c2w '%') bs

-- | URL-decode a string: every '+' becomes a space first, and the result is
--   then percent-decoded.
decodeString :: B.ByteString -> B.ByteString
decodeString bs = decodePercents (B.map plusToSpace bs)
  where
    plusToSpace w
      | w == c2w '+' = c2w ' '
      | otherwise = w

-- | Percent encode any unsafe bytes in the given string. The first argument
--   is the set of bytes that may be copied through unchanged; every other
--   byte is written as '%' followed by exactly two upper-case hexadecimal
--   digits.
--
--   The zero padding matters: bytes below 0x10 would otherwise come out as
--   a single digit (e.g. @%A@ for a newline), which 'decodePercents' does
--   not accept as an escape.
encodeString :: BSet.ByteSet -> B.ByteString -> B.ByteString
encodeString safe bs
  | B.null rest = clean
  | otherwise =
      B.concat [clean, escape (B.head rest), encodeString safe (B.tail rest)]
  where
    -- clean: leading run of bytes needing no escape; rest: starts at the
    -- first byte that does (or is empty).
    (clean, rest) = B.span (BSet.member safe) bs
    escape w = B.pack $ map c2w $ printf "%%%02X" (fromIntegral w :: Int)