{-# LANGUAGE CPP               #-}
{-# LANGUAGE OverloadedStrings #-}
module Snap.AtlassianConnect.QueryStringHash
    ( createQueryStringHash
    , module Network.HTTP.Types
    ) where

import           Control.Applicative
import           Control.Monad         (guard)
import qualified Crypto.Hash           as SHA
import qualified Data.ByteString.Char8 as B
import           Data.Function
import           Data.List
import           Data.List.Split
import           Data.Maybe
import           Data.Monoid
import           Data.Ord
import qualified Data.Text             as T
import qualified Data.Text.Encoding    as TE
import           Network.HTTP.Types
import           Network.URI


-- $setup
-- >>> import Data.Maybe
-- >>> import Data.Text

-- | Compute the query string hash (@qsh@) that an Atlassian Connect
-- application has to place in its JWT claim set.
--
-- The request, described by its HTTP method and full URL, is first turned
-- into a canonical request relative to the application's base URL using the
-- rules at <https://developer.atlassian.com/static/connect/docs/concepts/understanding-jwt.html#qsh>.
-- The result is the hex-encoded SHA-256 digest of that canonical request.
-- 'Nothing' is returned when no canonical request can be built (for example
-- when the URL does not share the base URL's scheme and authority).
--
-- >>> :{
--    let
--        baseUrl = fromMaybe nullURI $ parseURI "http://localhost:2990"
--        input   = "http://localhost:2990/path/to/service?a=1&A=2&b=3&B=4" :: Text
--    in createQueryStringHash GET baseUrl input
-- :}
-- Just "70282c7cf82834bd5a3d6dacda1b4ccd5cf5860a63a1fa2fb86b64d576e6a1d5"
createQueryStringHash :: StdMethod -> URI -> T.Text -> Maybe T.Text
createQueryStringHash method baseUrl fullUrl =
   fmap hexDigest (toCanonicalUrl method baseUrl fullUrl)
   where
      -- Hash the canonical request and render the digest as hex text.
      hexDigest canonical = TE.decodeUtf8 (SHA.digestToHexByteString (hsh canonical))

-- | SHA-256 digest of the UTF-8 encoding of the given text.
hsh :: T.Text -> SHA.Digest SHA.SHA256
hsh input = SHA.hash (TE.encodeUtf8 input)

-- | Build the canonical request that the query string hash is computed over:
-- @METHOD&canonical-path&canonical-query@.
--
-- Returns 'Nothing' when the URL cannot be parsed, when its scheme or
-- authority differ from the base URL's, or when its path is not below the
-- base URL's path.
--
-- TODO we ask for the method just so that we can run show method on it...I think we should give it
-- here in a different way and not mandate StdMethod.
toCanonicalUrl :: StdMethod -> URI -> T.Text -> Maybe T.Text
toCanonicalUrl method baseUrl' rawFullUrl = do
   fullUrl <- parseURI (T.unpack rawFullUrl)
   -- Only URLs served from the same place as the base URL can be made relative to it.
   guard (uriScheme baseUrl' == uriScheme fullUrl)
   guard (uriAuthority baseUrl' == uriAuthority fullUrl)
   path' <- canonicalizePath . uriPath <$> stripBaseUrl baseUrl' fullUrl
   let sqs = sortedQueryString fullUrl
   return . T.pack $ intercalate "&" [show method, path', sqs]

-- | Canonicalize a path that is already relative to the base URL, following
-- the canonical-URI rules of the Atlassian JWT documentation: drop a single
-- trailing @/@, use @/@ for an empty path, make sure the path starts with
-- @/@, and encode @&@ as @%26@ so it cannot be confused with the @&@
-- separators of the canonical request.
canonicalizePath :: String -> String
canonicalizePath = concatMap encodeAmpersand . withLeadingSlash . withoutTrailingSlash
   where
      withoutTrailingSlash p = case reverse p of
         '/' : rest -> reverse rest
         _          -> p
      -- An empty path also ends up here and becomes "/".
      withLeadingSlash p@('/' : _) = p
      withLeadingSlash p           = '/' : p
      encodeAmpersand '&' = "%26"
      encodeAmpersand c   = [c]

-- | Canonical query string of a parsed URI: parse its raw query component
-- into decoded parameters and render them with 'toCanonicalQueryString'.
sortedQueryString :: URI -> String
sortedQueryString uri = toCanonicalQueryString (parseQueryText rawQuery)
   where
      -- NOTE(review): 'B.pack' keeps only the low 8 bits of each Char; this
      -- assumes the query of a parsed URI is ASCII (percent-encoded) — confirm.
      rawQuery = B.pack (uriQuery uri)

-- | Make the full URL relative to the base URL. The base URL's path is
-- removed from the front of the full URL's path, and the scheme and
-- authority are cleared.
--
-- Only whole path segments are stripped: with a base path of @/jira@, the
-- path @/jira2/x@ is rejected rather than turned into @2/x@. Returns
-- 'Nothing' when the full URL's path is not below the base path.
stripBaseUrl :: URI -> URI -> Maybe URI
stripBaseUrl baseUrl' fullUrl = do
    strippedPath <- stripPrefix basePath (uriPath fullUrl)
    guard (onSegmentBoundary strippedPath)
    return fullUrl
        { uriScheme = ""
        , uriAuthority = Nothing
        , uriPath = strippedPath
        }
  where
    basePath = uriPath baseUrl'
    -- The prefix match ends between two segments when the base path is empty
    -- or ends in '/', or when the remainder is empty or starts with '/'.
    onSegmentBoundary rest =
        null basePath || null rest || "/" `isSuffixOf` basePath || "/" `isPrefixOf` rest

-- | Render a parsed query string in the canonical form used for the query
-- string hash.
--
-- See step 5 of "Creating a query string hash" at
-- <https://developer.atlassian.com/static/connect/docs/concepts/understanding-jwt.html>.
-- The rules are:
--
-- * Sort the query parameters primarily by their percent-encoded names and
--   secondarily by their percent-encoded values.
--
-- * Sorting is by codepoint, so upper-case letters sort before lower-case
--   ones (see the example below).
--
-- * For each parameter, append its percent-encoded name, the @=@ character
--   and then its percent-encoded value.
--
-- * For repeated parameters, append the @,@ character and the subsequent
--   percent-encoded values.
--
-- * Ignore the @jwt@ parameter, if present.
--
-- * Some particular values to be aware of: @+@ is encoded as @%20@, @*@ as
--   @%2A@ and @~@ as @~@. These encodings are used for consistency with
--   OAuth1.
--
-- > sort(["a", "A", "b", "B"]) => ["A", "B", "a", "b"]
--
-- This function needs to be functionally equivalent to
-- @com.atlassian.jwt.core.HttpRequestCanonicalizer#canonicalizeQueryParameters@
-- from atlassian-jwt.
toCanonicalQueryString :: QueryText -> String
toCanonicalQueryString params =
   T.unpack (render (joinQueryParams (groupAndSortQueryParams (ignoreJWTParam params))))

-- | A query parameter as produced by 'parseQueryText': its name and, if one
-- was given, its value.
-- NOTE(review): presumably the value is 'Nothing' for a bare key such as
-- @?flag@ and @Just ""@ for @?flag=@; confirm against the http-types docs.
type QueryParam = (T.Text, Maybe T.Text)

-- | Remove every parameter named exactly @jwt@ (case-sensitive). The JWT
-- carries the query string hash in its claim set, so the token itself is
-- never part of the hashed query.
ignoreJWTParam :: [(T.Text, a)] -> [(T.Text, a)]
ignoreJWTParam params = [param | param@(paramName, _) <- params, paramName /= "jwt"]

-- | Sort pairs by their first component only. The sort is stable, so pairs
-- with equal keys keep their original relative order.
sortParamKeys :: Ord a => [(a, b)] -> [(a, b)]
sortParamKeys = sortBy (compare `on` fst)

-- | Sort pairs by their second component only. The sort is stable, so pairs
-- whose values compare equal keep their original relative order.
sortParamValues :: Ord b => [(a, b)] -> [(a, b)]
sortParamValues = sortBy (\(_, lhs) (_, rhs) -> compare lhs rhs)

-- | Sort parameters by their percent-encoded names, gather parameters that
-- share a name into one group, and sort each group by value.
--
-- Names are compared in percent-encoded form, as the canonical query string
-- rules require. The decoded and encoded orders can differ: decoded, @Z@
-- sorts before @[@, but encoded, @%5B@ sorts before @Z@.
--
-- Values within a group are compared in decoded form, with a missing value
-- first. NOTE(review): the rules mention percent-encoded values as well;
-- confirm against atlassian-jwt before changing the value order.
groupAndSortQueryParams :: [QueryParam] -> [[QueryParam]]
groupAndSortQueryParams =
   fmap sortParamValues . groupBy ((==) `on` fst) . sortByEncodedName
   where
      -- Decorate each parameter with its encoded name, sort on that, then drop
      -- the decoration. Percent-encoding is injective, so parameters with equal
      -- names still end up adjacent for 'groupBy'.
      sortByEncodedName = fmap snd . sortParamKeys . fmap (\param -> (encode (fst param), param))

-- | Collapse each group of same-named parameters (as produced by
-- 'groupAndSortQueryParams') into one percent-encoded @(name, values)@ pair.
-- Empty groups are skipped.
joinQueryParams :: [[QueryParam]] -> [(T.Text, T.Text)]
joinQueryParams = mapMaybe joinQueryParam

-- | Merge a group of parameters that share a name into a single pair of a
-- percent-encoded name and percent-encoded values.
--
-- Each value is encoded on its own, and the results are joined with a
-- literal @,@. The separator itself must not be encoded; otherwise
-- @?a=1&a=2@ would canonicalize to @a=1%2C2@ instead of @a=1,2@.
--
-- A parameter without a value contributes an empty string instead of being
-- dropped, so @?a=1&a@ yields the values @,1@ once sorted.
-- NOTE(review): this mirrors how servlet parameter maps report bare keys;
-- confirm against atlassian-jwt.
--
-- Returns 'Nothing' only for an empty group.
joinQueryParam :: [QueryParam] -> Maybe (T.Text, T.Text)
joinQueryParam [] = Nothing
joinQueryParam xs@((paramName, _) : _) =
   Just (encode paramName, T.intercalate sep (fmap (encode . fromMaybe T.empty . snd) xs))
   where
      sep = T.singleton ','

-- | Render one parameter as @name=value@. Both parts must already be
-- percent-encoded, as 'joinQueryParam' produces them.
queryParamToString :: (T.Text, T.Text) -> T.Text
queryParamToString (key, value) = key <> T.singleton '=' <> value

-- | Render every pair with 'queryParamToString' and join the results with
-- @&@ to form the canonical query string.
render :: [(T.Text, T.Text)] -> T.Text
render params = T.intercalate "&" [queryParamToString param | param <- params]

-- | Percent-encode text for the canonical query string. The text is UTF-8
-- encoded, passed through http-types' 'urlEncode' in query-string mode (the
-- 'True' flag), and then turned back into 'T.Text'.
encode :: T.Text -> T.Text
encode txt = TE.decodeUtf8 (urlEncode True (TE.encodeUtf8 txt))