module Aws.SignatureV4
(
  GeneralVersion(..)
, generalVersionToText
, parseGeneralVersion
, signatureVersion
, SignatureV4Credentials(..)
, newCredentials
, signPostRequest
, signGetRequest
, signPostRequestIO
, signGetRequestIO
, AuthorizationInfo(..)
, authorizationInfo
, authorizationInfoQuery
, authorizationInfoHeader
, dateNormalizationEnabled
, signingAlgorithm
, UriPath
, UriQuery
, normalizeUriPath
, normalizeUriQuery
, CanonicalUri(..)
, canonicalUri
, CanonicalHeaders(..)
, canonicalHeaders
, SignedHeaders
, signedHeaders
, CanonicalRequest(..)
, canonicalRequest
, HashedCanonicalRequest
, hashedCanonicalRequest
, CredentialScope(..)
, credentialScopeToText
, StringToSign(..)
, stringToSign
, SigningKey(..)
, signingKey
, Signature(..)
, requestSignature
) where
import Aws.General
import Control.Applicative
import Control.Arrow hiding (left)
import Control.Monad.IO.Class
import Crypto.Hash
import Data.Byteable
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as B8
import qualified Blaze.ByteString.Builder as BB
import qualified Blaze.ByteString.Builder.Char8 as BB8
import qualified Data.ByteString.Base16 as B16
import Data.Char
import qualified Data.CaseInsensitive as CI
import Data.IORef
import qualified Data.List as L
import Data.Monoid
import Data.String
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import Data.Time.Clock (UTCTime, getCurrentTime, utctDay)
import Data.Time.Format (formatTime, parseTime)
import Data.Typeable
import qualified Test.QuickCheck as Q
import Test.QuickCheck.Instances ()
import qualified Text.Parser.Char as P
import qualified Text.Parser.Combinators as P
#if MIN_VERSION_time(1,5,0)
import Data.Time.Format
#else
import System.Locale
#endif
import qualified Network.HTTP.Types as HTTP
-- | Identifier of the signing algorithm. It is the first line of the string
-- to sign and is also placed into the authorization information and the
-- @X-Amz-Algorithm@ query parameter.
signingAlgorithm :: IsString a => a
signingAlgorithm = "AWS4-HMAC-SHA256"
-- | Raw (binary, not hex encoded) SHA-256 digest of the given bytes.
signingHash :: B.ByteString -> B.ByteString
signingHash = toBytes . sha256
  where
    -- Pins the digest algorithm of the polymorphic 'hash' function.
    sha256 :: B.ByteString -> Digest SHA256
    sha256 = hash
-- | Base16 (hex) encoding of the SHA-256 digest of the given bytes.
signingHash16 :: B.ByteString -> B8.ByteString
signingHash16 bytes = B16.encode (signingHash bytes)
-- | Raw (binary) HMAC-SHA256 of a message.
--
-- The first argument is the key, the second argument is the message.
signingHmac :: B.ByteString -> B.ByteString -> B.ByteString
signingHmac key message = toBytes mac
  where
    mac :: HMAC SHA256
    mac = hmac key message
-- | The version number of the signature scheme implemented by this module,
-- as a string.
signatureVersion :: IsString a => a
signatureVersion = "4"
-- | Entry of the signing key cache. The layout, as used by 'getSigningKey',
-- is @((region, service), (date, key))@: the first pair is the lookup index
-- and the second pair holds the formatted date for which the key was derived
-- together with the raw key bytes.
type SigV4Key = ((B.ByteString,B.ByteString),(B.ByteString,B.ByteString))
-- | Credentials used for signing requests with Signature Version 4.
--
-- Use 'newCredentials' to create a value with an empty signing key cache.
data SignatureV4Credentials = SignatureV4Credentials
    { sigV4AccessKeyId :: B.ByteString
    -- ^ access key id; it becomes part of the credential string
    , sigV4SecretAccessKey :: B.ByteString
    -- ^ secret access key from which signing keys are derived
    , sigV4SigningKeys :: IORef [SigV4Key]
    -- ^ mutable cache of derived signing keys, used by the @IO@ variants
    -- of the signing functions
    , sigV4SecurityToken :: Maybe B.ByteString
    -- ^ optional security token that is attached to signed requests
    }
    deriving (Typeable)
-- | Create 'SignatureV4Credentials' with a fresh, empty signing key cache.
newCredentials
    :: (Functor m, MonadIO m)
    => B.ByteString -- ^ access key id
    -> B.ByteString -- ^ secret access key
    -> Maybe B.ByteString -- ^ optional security token
    -> m SignatureV4Credentials
newCredentials accessKeyId secretAccessKey securityToken =
    withCache <$> liftIO (newIORef [])
  where
    withCache cacheRef = SignatureV4Credentials
        { sigV4AccessKeyId = accessKeyId
        , sigV4SecretAccessKey = secretAccessKey
        , sigV4SigningKeys = cacheRef
        , sigV4SecurityToken = securityToken
        }
-- | A URI path represented as the list of its (unencoded) segments.
type UriPath = [T.Text]
-- | A URI query represented as a list of key and optional value pairs.
type UriQuery = HTTP.QueryText
-- | The canonical path and the canonical query string of a request,
-- separated by a newline character. Values are built by 'canonicalUri'.
newtype CanonicalUri = CanonicalUri B8.ByteString
    deriving (Show, Read, Eq, Ord, Typeable)
-- | Compute the canonical URI part of a canonical request.
--
-- The result consists of the encoded, normalized path, a newline, and the
-- encoded query string. Query parameters are normalized, sorted, and
-- parameters without a value are rendered with an empty value.
canonicalUri
    :: UriPath
    -> UriQuery
    -> CanonicalUri
canonicalUri path query = CanonicalUri (BB.toByteString rendered)
  where
    rendered = mconcat
        [ HTTP.encodePathSegments pathSegments
        , BB.copyByteString "\n"
        , HTTP.renderQueryText False sortedQuery
        ]

    -- An empty path is represented by a single empty segment.
    pathSegments
        | null normalized = [""]
        | otherwise = normalized
      where
        normalized = normalizeUriPath path

    -- Every parameter gets a (possibly empty) value before sorting.
    sortedQuery = L.sort
        [ (name, Just (maybe "" id value))
        | (name, value) <- normalizeUriQuery query
        ]
-- | Normalize a URI path.
--
-- Relative components (@.@ and @..@) are resolved first. The result is then
-- encoded and decoded again through the http-types path functions in order
-- to normalize the segments.
normalizeUriPath :: UriPath -> UriPath
normalizeUriPath segments = reencode (reverse (L.foldl' step [] segments))
  where
    reencode = HTTP.decodePathSegments . BB.toByteString . HTTP.encodePathSegments

    -- The accumulator holds the resolved segments in reverse order. The
    -- order of the alternatives matters.
    step acc segment = case (acc, segment) of
        -- a leading ".." is kept
        ([], "..") -> [".."]
        -- ".." drops the most recent segment
        (_ : rest, "..") -> rest
        -- "." is dropped
        (_, ".") -> acc
        -- an empty most recent segment is replaced by the next segment
        ("" : rest, _) -> segment : rest
        _ -> segment : acc
-- | Normalize a URI query by rendering it (without a leading question mark)
-- and parsing the rendered bytes again with the http-types query functions.
normalizeUriQuery :: UriQuery -> UriQuery
normalizeUriQuery query = HTTP.parseQueryText rendered
  where
    rendered = BB.toByteString (HTTP.renderQueryText False query)
-- | The canonical representation of the headers of a request. Values are
-- built by 'canonicalHeaders'.
newtype CanonicalHeaders = CanonicalHeaders B8.ByteString
    deriving (Show, Read, Eq, Ord, Typeable)
-- | Compute the canonical headers of a request.
--
-- The headers are sorted, header names are case-folded, the values of
-- consecutive headers with the same name are joined with a comma, and
-- whitespace in values is trimmed (see @trimWs@ below). Each header is
-- rendered as @name:value@ followed by a newline. An empty header list
-- yields the empty string.
--
-- If the module is compiled with @SIGN_V4_NORMALIZE_DATE@, the value of the
-- @date@ header is reformatted before sorting when it can be parsed.
canonicalHeaders :: HTTP.RequestHeaders -> CanonicalHeaders
canonicalHeaders = CanonicalHeaders
    . foldHeaders
    . L.sort -- sorting places headers with equal names next to each other
#ifdef SIGN_V4_NORMALIZE_DATE
    . map canonicalDate
#endif
  where
#ifdef SIGN_V4_NORMALIZE_DATE
    -- Reformat the value of a date header with 'canonicalDateHeaderFormat'.
    -- Values that can not be parsed are left unchanged.
    canonicalDate :: HTTP.Header -> HTTP.Header
    canonicalDate ("date", d) = ("date", formatDate)
      where
        formatDate = case parseHttpDate (B8.unpack d) of
            Nothing -> d
            Just utc -> B8.pack $ fTime canonicalDateHeaderFormat utc
    canonicalDate a = a
#endif
    -- Render the (sorted) headers. The fold state is the name of the
    -- previously rendered header along with the output built so far. A
    -- header with the same name as its predecessor only contributes
    -- ",value"; any other header starts a new "name:value" line.
    foldHeaders :: HTTP.RequestHeaders -> B8.ByteString
    foldHeaders [] = ""
    foldHeaders ((h0,v0):t) = BB.toByteString $ snd run <> bChar '\n'
      where
        run = L.foldl' f (h0, bBS (CI.foldedCase h0) <> bChar ':' <> trimWs v0) t
        f (ch, a) (h,v) = if ch == h
            then (h, a <> bChar ',' <> trimWs v)
            else (h, a <> bChar '\n' <> bBS (CI.foldedCase h) <> bChar ':' <> trimWs v)
    -- Trim whitespace of a header value. The fold state is
    -- (inside double quotes, previous character, output). The initial
    -- previous character is a space, so leading whitespace is dropped.
    trimWs = (\(_,_,c) -> c)
        -- collapse whitespace outside of double quotes
        . B8.foldl' f (False, ' ', bBS "")
        -- drop trailing whitespace
        . fst . B8.spanEnd isSpace
      where
        -- a character that follows a backslash is copied verbatim
        -- NOTE(review): the copied character becomes the new "previous"
        -- character, so after an escaped backslash the next character is
        -- treated as escaped, too -- confirm that this is intended.
        f (s, '\\', b) x = (s, x, b <> bChar x)
        -- a double quote toggles the quoting state
        f (s, _, b) '"' = (not s, '"', b <> bChar '"')
        -- outside of quotes: whitespace after whitespace is dropped
        f (False, ' ', b) x
            | isSpace x = (False, ' ', b)
        -- outside of quotes: any other whitespace becomes a single space
        f (False, _, b) x
            | isSpace x = (False, ' ', b <> bChar ' ')
        -- everything else is copied verbatim
        f (s, _, b) x = (s, x, b <> bChar x)
    bChar = BB8.fromChar
    bBS = BB.copyByteString
#ifdef SIGN_V4_NORMALIZE_DATE
-- | Format string that is used for rendering the normalized value of the
-- @date@ header. Only available when date normalization is compiled in.
canonicalDateHeaderFormat :: String
canonicalDateHeaderFormat = "%a, %d %b %Y %H:%M:%S GMT"
-- | Parse the value of a date header.
--
-- The supported formats are tried in order and the first successful parse
-- is returned. 'Nothing' is returned if no format matches.
parseHttpDate :: String -> Maybe UTCTime
parseHttpDate s = foldr tryFormat Nothing formats
  where
    tryFormat format rest = parseTime defaultTimeLocale format s <|> rest

    formats :: [String]
    formats =
        [ "%a, %d %b %Y %H:%M:%S GMT" -- e.g. "Sun, 06 Nov 1994 08:49:37 GMT"
        , "%A, %d-%b-%y %H:%M:%S GMT" -- e.g. "Sunday, 06-Nov-94 08:49:37 GMT"
        , "%a %b %_d %H:%M:%S %Y" -- e.g. "Sun Nov  6 08:49:37 1994"
        , "%Y-%m-%dT%H:%M:%S%QZ" -- ISO 8601 style with a literal "Z"
        , "%Y-%m-%dT%H:%M:%S%Q%Z" -- ISO 8601 style with a time zone
        ]
#endif
-- | Whether this module was compiled with normalization of the @date@
-- header, i.e. with the CPP macro @SIGN_V4_NORMALIZE_DATE@ defined.
dateNormalizationEnabled :: Bool
#ifdef SIGN_V4_NORMALIZE_DATE
dateNormalizationEnabled = True
#else
dateNormalizationEnabled = False
#endif
-- | The semicolon separated list of the names of the headers that are
-- included in a signature. Values are built by 'signedHeaders'.
newtype SignedHeaders = SignedHeaders B8.ByteString
    deriving (Show, Read, Eq, Ord, Typeable)
-- | Compute the list of signed headers: the case-folded header names,
-- sorted, without duplicates, and joined with semicolons.
signedHeaders :: HTTP.RequestHeaders -> SignedHeaders
signedHeaders headers = SignedHeaders (B8.intercalate ";" uniqueNames)
  where
    names = map (CI.foldedCase . fst) headers

    -- After sorting, equal names are adjacent, so taking one representative
    -- of each group removes all duplicates without the quadratic cost of
    -- 'L.nub'.
    uniqueNames = [ name | (name : _) <- L.group (L.sort names) ]
-- | The canonical representation of a request. Values are built by
-- 'canonicalRequest'.
newtype CanonicalRequest = CanonicalRequest B8.ByteString
    deriving (Show, Read, Eq, Ord, Typeable)
-- | Build the canonical request.
--
-- The result consists of the following newline separated parts: the HTTP
-- method, the canonical URI (path and query), the canonical headers, the
-- signed headers, and the hex encoded hash of the payload.
canonicalRequest
    :: HTTP.Method -- ^ HTTP method of the request
    -> UriPath -- ^ path of the request
    -> UriQuery -- ^ query of the request
    -> HTTP.RequestHeaders -- ^ headers of the request
    -> B.ByteString -- ^ payload of the request
    -> CanonicalRequest
canonicalRequest method path query headers payload =
    case (canonicalUri path query, canonicalHeaders headers, signedHeaders headers) of
        (CanonicalUri uri, CanonicalHeaders headerLines, SignedHeaders headerNames) ->
            CanonicalRequest . B8.intercalate "\n" $
                [ method
                , uri
                , headerLines
                , headerNames
                , signingHash16 payload
                ]
-- | The hex encoded hash of a 'CanonicalRequest'. Values are built by
-- 'hashedCanonicalRequest'.
newtype HashedCanonicalRequest = HashedCanonicalRequest B8.ByteString
    deriving (Show, Read, Eq, Ord, Typeable)
-- | Compute the hex encoded hash of a canonical request.
hashedCanonicalRequest :: CanonicalRequest -> HashedCanonicalRequest
hashedCanonicalRequest request = HashedCanonicalRequest digest
  where
    CanonicalRequest requestBytes = request
    digest = signingHash16 requestBytes
-- | The scope of a signing key: a date, a region, and a service.
--
-- The text representation contains only the day of the date (see
-- 'credentialScopeDateFormat'), and the 'Eq' instance of this type
-- likewise compares only the day.
data CredentialScope = CredentialScope
    { credentialScopeDate :: !UTCTime
    , credentialScopeRegion :: !Region
    , credentialScopeService :: !ServiceNamespace
    }
    deriving (Show, Read, Typeable)
-- | Two scopes are equal if they agree on the day of the date, the region,
-- and the service. The time of day is ignored.
instance Eq CredentialScope where
    left == right = components left == components right
      where
        components scope =
            ( utctDay (credentialScopeDate scope)
            , credentialScopeRegion scope
            , credentialScopeService scope
            )
-- | Render a credential scope as @date/region/service/aws4_request@.
credentialScopeToText :: (IsString a, Monoid a) => CredentialScope -> a
credentialScopeToText s =
    day <> "/" <> region <> "/" <> service <> "/" <> terminationString
  where
    day = credentialScopeDateText s
    region = toText (credentialScopeRegion s)
    service = toText (credentialScopeService s)
-- | Parser for the text representation of a credential scope, that is
-- @date/region/service/aws4_request@, where the date consists of exactly
-- eight digits.
parseCredentialScope :: (Monad m, P.CharParsing m) => m CredentialScope
parseCredentialScope = do
    date <- scopeDate
    region <- P.char '/' *> parseRegion
    service <- P.char '/' *> parseServiceNamespace
    _ <- P.char '/' *> P.text terminationString
    return (CredentialScope date region service)
  where
    -- The date is read as eight digits and then parsed with
    -- 'credentialScopeDateFormat'.
    scopeDate = P.count 8 P.digit >>= \digits ->
        maybe
            (fail $ "failed to parse credential scope date: " <> digits)
            return
            (parseTime defaultTimeLocale credentialScopeDateFormat digits)
-- | The constant last component of a credential scope. It is also the last
-- input of the signing key derivation.
terminationString :: IsString a => a
terminationString = "aws4_request"
-- | Date format of the date component of a credential scope: four digit
-- year, two digit month, and two digit day, without separators.
credentialScopeDateFormat :: IsString a => a
credentialScopeDateFormat = "%Y%m%d"
-- | The date of a credential scope formatted with
-- 'credentialScopeDateFormat'.
credentialScopeDateText :: IsString a => CredentialScope -> a
credentialScopeDateText = fTime credentialScopeDateFormat . credentialScopeDate
-- | Text rendering and parsing of credential scopes, delegated to
-- 'credentialScopeToText' and 'parseCredentialScope'.
instance AwsType CredentialScope where
    toText = credentialScopeToText
    parse = parseCredentialScope
-- | Generates a scope from independently generated date, region, and
-- service values.
instance Q.Arbitrary CredentialScope where
    arbitrary = liftA3 CredentialScope Q.arbitrary Q.arbitrary Q.arbitrary
-- | The string over which the request signature is computed. Values are
-- built by 'stringToSign'.
newtype StringToSign = StringToSign B8.ByteString
    deriving (Show, Read, Eq, Ord, Typeable)
-- | Build the string to sign.
--
-- The result consists of the following newline separated parts: the name
-- of the signing algorithm, the request date, the credential scope, and
-- the hex encoded hash of the canonical request.
stringToSign
    :: UTCTime -- ^ request date
    -> CredentialScope -- ^ credential scope of the signing key
    -> CanonicalRequest -- ^ canonical request that is to be signed
    -> StringToSign
stringToSign date credentialScope request =
    StringToSign (B8.intercalate "\n" parts)
  where
    parts = [signingAlgorithm, timestamp, scope, requestDigest]
    timestamp = fTime signingStringDateFormat date
    scope = T.encodeUtf8 (credentialScopeToText credentialScope)
    HashedCanonicalRequest requestDigest = hashedCanonicalRequest request
-- | Format of request timestamps: date and time without separators, with a
-- literal @T@ in between and a literal @Z@ at the end. It is used for the
-- string to sign and for the @X-Amz-Date@ header and query parameter.
signingStringDateFormat :: IsString a => a
signingStringDateFormat = "%Y%m%dT%H%M%SZ"
-- | The raw bytes of a derived signing key. Values are built by
-- 'signingKey'.
newtype SigningKey = SigningKey B.ByteString
    deriving (Show, Read, Eq, Ord, Typeable)
-- | Derive the signing key for a credential scope.
--
-- Starting with the prefixed secret access key, the date, the region, the
-- service, and the termination string are applied one after another; the
-- HMAC of each step is the key of the next step.
signingKey :: SignatureV4Credentials -> CredentialScope -> SigningKey
signingKey credentials s = SigningKey (L.foldl' signingHmac initialKey scopeParts)
  where
    initialKey = signingKeyPrefix <> sigV4SecretAccessKey credentials
    scopeParts =
        [ T.encodeUtf8 (credentialScopeDateText s)
        , T.encodeUtf8 (toText (credentialScopeRegion s))
        , T.encodeUtf8 (toText (credentialScopeService s))
        , terminationString
        ]
-- | Prefix that is prepended to the secret access key at the start of the
-- signing key derivation.
signingKeyPrefix :: IsString a => a
signingKeyPrefix = "AWS4"
-- | A hex encoded request signature. Values are built by
-- 'requestSignature'.
newtype Signature = Signature B8.ByteString
    deriving (Show, Read, Eq, Ord, Typeable)
-- | Compute the signature of a request: the hex encoded HMAC of the string
-- to sign under the signing key.
requestSignature
    :: SigningKey
    -> StringToSign
    -> Signature
requestSignature (SigningKey key) (StringToSign str) = Signature (B16.encode mac)
  where
    mac = signingHmac key str
-- | The credential string of a signed request: the access key id and the
-- text representation of the credential scope, separated by a slash.
authorizationCredential
    :: (IsString a, Monoid a)
    => SignatureV4Credentials
    -> CredentialScope
    -> a
authorizationCredential creds credScope = accessKeyId <> "/" <> scope
  where
    accessKeyId = fromString . B8.unpack $ sigV4AccessKeyId creds
    scope = toText credScope
-- | The authorization information of a signed request. It can be rendered
-- as a header with 'authorizationInfoHeader' or as query parameters with
-- 'authorizationInfoQuery'.
data AuthorizationInfo = AuthorizationInfo
    { authzInfoAlgorithm :: !B8.ByteString
    -- ^ name of the signing algorithm
    , authzInfoCredential :: !B8.ByteString
    -- ^ credential string (access key id and credential scope)
    , authzInfoSignedHeaders :: !B8.ByteString
    -- ^ list of the names of the signed headers
    , authzInfoDate :: !UTCTime
    -- ^ request date
    , authzInfoSignature :: !B8.ByteString
    -- ^ hex encoded signature
    }
-- | Assemble the authorization information of a signed request from the
-- credentials, the credential scope, the signed headers, the request date,
-- and the signature.
authorizationInfo
    :: SignatureV4Credentials
    -> CredentialScope
    -> SignedHeaders
    -> UTCTime
    -> Signature
    -> AuthorizationInfo
authorizationInfo creds credScope (SignedHeaders hdrs) date (Signature sig) =
    AuthorizationInfo algorithm credential hdrs date sig
  where
    algorithm = signingAlgorithm
    credential = authorizationCredential creds credScope
-- | The query parameters that carry the authorization information. Only
-- the @X-Amz-Signature@ parameter is produced by this function.
authorizationInfoQuery :: AuthorizationInfo -> UriQuery
authorizationInfoQuery authz = [("X-Amz-Signature", Just signatureText)]
  where
    signatureText = T.decodeUtf8 (authzInfoSignature authz)
-- | The @Authorization@ header for the given authorization information.
-- The value has the form
-- @algorithm Credential=..., SignedHeaders=..., Signature=...@.
authorizationInfoHeader
    :: AuthorizationInfo
    -> HTTP.RequestHeaders
authorizationInfoHeader authz = [ ("Authorization", B.concat pieces) ]
  where
    pieces =
        [ authzInfoAlgorithm authz
        , " Credential="
        , authzInfoCredential authz
        , ", SignedHeaders="
        , authzInfoSignedHeaders authz
        , ", Signature="
        , authzInfoSignature authz
        ]
-- | Sign a request by means of query parameters.
--
-- The signing key is derived from the given credentials, region, service,
-- and date. See 'signGetRequestIO' for a variant that caches signing keys.
--
-- The result is the query of the request extended with the signature
-- parameters, or an error message if a precondition is not met.
signGetRequest
    :: SignatureV4Credentials -- ^ credentials
    -> Region -- ^ region of the credential scope
    -> ServiceNamespace -- ^ service of the credential scope
    -> UTCTime -- ^ request date
    -> HTTP.Method -- ^ HTTP method of the request
    -> UriPath -- ^ path of the request
    -> UriQuery -- ^ query of the request
    -> HTTP.RequestHeaders -- ^ headers of the request
    -> B.ByteString -- ^ payload of the request
    -> Either String UriQuery
signGetRequest credentials region service date =
    signGetRequest_ derivedKey credentials region service date
  where
    derivedKey = signingKey credentials (CredentialScope date region service)
-- | Sign a request by means of query parameters with a given signing key.
--
-- The request must have a @host@ header and an @Action@ query parameter,
-- otherwise an error message is returned.
--
-- The result is the query of the request extended with the parameters
-- @X-Amz-Algorithm@, @X-Amz-Credential@, @X-Amz-Date@,
-- @X-Amz-SignedHeaders@, @X-Amz-Security-Token@ (only if the credentials
-- include a security token), and @X-Amz-Signature@. All parameters except
-- the signature itself are included in the signed canonical request.
--
-- The signing key is expected to match the given credentials, region,
-- service, and date; this is not checked.
signGetRequest_
    :: SigningKey -- ^ signing key
    -> SignatureV4Credentials -- ^ credentials
    -> Region -- ^ region of the credential scope
    -> ServiceNamespace -- ^ service of the credential scope
    -> UTCTime -- ^ request date
    -> HTTP.Method -- ^ HTTP method of the request
    -> UriPath -- ^ path of the request
    -> UriQuery -- ^ query of the request
    -> HTTP.RequestHeaders -- ^ headers of the request
    -> B.ByteString -- ^ payload of the request
    -> Either String UriQuery
signGetRequest_ key credentials region service date method path query headers payload = do
    requirePresent
        "Failed to sign request with Signature V4: host header is missing"
        (lookup "host" headers)
    requirePresent
        "Failed to sign request with Signature V4: Action parameter is missing"
        (lookup "Action" query)
    return $ queryToSign <> authorizationInfoQuery authz
  where
    requirePresent :: String -> Maybe b -> Either String ()
    requirePresent message = maybe (Left message) (const $ Right ())

    queryToSign = query <>
        [ ("X-Amz-Algorithm", Just signingAlgorithm)
        , ("X-Amz-Credential", Just $ authorizationCredential credentials credentialScope)
        , ("X-Amz-Date", Just $ fTime signingStringDateFormat date)
        , ("X-Amz-SignedHeaders", let SignedHeaders h = shdrs in Just (T.decodeUtf8 h))
        ]
        <> securityTokenParam

    -- The security token parameter is added only if there is a token.
    -- Adding it unconditionally would put an empty token into the signed
    -- query of requests that are made with permanent credentials.
    securityTokenParam = case sigV4SecurityToken credentials of
        Nothing -> []
        Just token -> [("X-Amz-Security-Token", Just (T.decodeUtf8 token))]

    authz = authorizationInfo credentials credentialScope shdrs date sig
    sig = requestSignature key str
    shdrs = signedHeaders headers
    request = canonicalRequest method path queryToSign headers payload
    str = stringToSign date credentialScope request
    credentialScope = CredentialScope
        { credentialScopeDate = date
        , credentialScopeRegion = region
        , credentialScopeService = service
        }
-- | Sign a request by means of an @Authorization@ header.
--
-- The signing key is derived from the given credentials, region, service,
-- and date. See 'signPostRequestIO' for a variant that caches signing keys.
--
-- The result is the list of headers of the signed request, or an error
-- message if a precondition is not met.
signPostRequest
    :: SignatureV4Credentials -- ^ credentials
    -> Region -- ^ region of the credential scope
    -> ServiceNamespace -- ^ service of the credential scope
    -> UTCTime -- ^ request date
    -> HTTP.Method -- ^ HTTP method of the request
    -> UriPath -- ^ path of the request
    -> UriQuery -- ^ query of the request
    -> HTTP.RequestHeaders -- ^ headers of the request
    -> B.ByteString -- ^ payload of the request
    -> Either String HTTP.RequestHeaders
signPostRequest credentials region service date =
    signPostRequest_ derivedKey credentials region service date
  where
    derivedKey = signingKey credentials (CredentialScope date region service)
-- | Sign a request by means of an @Authorization@ header with a given
-- signing key.
--
-- The request must have a @host@ header, otherwise an error message is
-- returned.
--
-- The resulting headers consist of an @x-amz-date@ header with the request
-- date, the given headers without any @date@ and @x-amz-date@ headers, an
-- @X-Amz-Security-Token@ header if the credentials include a security
-- token, and the @Authorization@ header. All headers except the
-- @Authorization@ header are signed.
signPostRequest_
    :: SigningKey -- ^ signing key
    -> SignatureV4Credentials -- ^ credentials
    -> Region -- ^ region of the credential scope
    -> ServiceNamespace -- ^ service of the credential scope
    -> UTCTime -- ^ request date
    -> HTTP.Method -- ^ HTTP method of the request
    -> UriPath -- ^ path of the request
    -> UriQuery -- ^ query of the request
    -> HTTP.RequestHeaders -- ^ headers of the request
    -> B.ByteString -- ^ payload of the request
    -> Either String HTTP.RequestHeaders -- ^ headers of the signed request
signPostRequest_ key credentials region service date method path query headers payload =
    maybe
        (Left "Failed to sign request with Signature V4: host header is missing")
        (const . Right $ headersToSign <> authorizationInfoHeader authz)
        (lookup "host" headers)
  where
    headersToSign = dateHeader ++ passedHeaders ++ tokenHeader
    dateHeader = [("x-amz-date", fTime signingStringDateFormat date)]
    -- date headers of the caller are superseded by the x-amz-date header
    passedHeaders =
        [ header
        | header@(name, _) <- headers
        , name /= "date"
        , name /= "x-amz-date"
        ]
    tokenHeader =
        [ ("X-Amz-Security-Token", token)
        | Just token <- [sigV4SecurityToken credentials]
        ]

    scope = CredentialScope date region service
    canonical = canonicalRequest method path query headersToSign payload
    toSign = stringToSign date scope canonical
    signature = requestSignature key toSign
    authz = authorizationInfo credentials scope (signedHeaders headersToSign) date signature
-- | Variant of 'signGetRequest' that takes the signing key from the key
-- cache of the credentials. A key is derived only if the cache has no key
-- for the region, the service, and the day of the request date.
signGetRequestIO
    :: SignatureV4Credentials -- ^ credentials, including the key cache
    -> Region
    -> ServiceNamespace
    -> UTCTime -- ^ request date
    -> HTTP.Method -- ^ HTTP method of the request
    -> UriPath -- ^ path of the request
    -> UriQuery -- ^ query of the request
    -> HTTP.RequestHeaders -- ^ headers of the request
    -> B.ByteString -- ^ payload of the request
    -> IO (Either String UriQuery)
signGetRequestIO credentials region service date method path query headers payload = do
    key <- getSigningKey credentials region service date
    return $ signGetRequest_ key credentials region service date method path query headers payload

-- | Variant of 'signPostRequest' that takes the signing key from the key
-- cache of the credentials. A key is derived only if the cache has no key
-- for the region, the service, and the day of the request date.
signPostRequestIO
    :: SignatureV4Credentials -- ^ credentials, including the key cache
    -> Region
    -> ServiceNamespace
    -> UTCTime -- ^ request date
    -> HTTP.Method -- ^ HTTP method of the request
    -> UriPath -- ^ path of the request
    -> UriQuery -- ^ query of the request
    -> HTTP.RequestHeaders -- ^ headers of the request
    -> B.ByteString -- ^ payload of the request
    -> IO (Either String HTTP.RequestHeaders)
signPostRequestIO credentials region service date method path query headers payload = do
    key <- getSigningKey credentials region service date
    return $ signPostRequest_ key credentials region service date method path query headers payload

-- | Get the signing key for a region, a service, and a request date from
-- the key cache of the credentials. If the cache has no key for the day of
-- the date, a new key is derived and stored in the cache.
--
-- The key is derived for the date of the request and not for the current
-- time. The credential scope of a signed request is computed from the
-- request date, and a key that was derived for a different day would
-- result in an invalid signature.
getSigningKey
    :: SignatureV4Credentials -- ^ credentials, including the key cache
    -> Region
    -> ServiceNamespace
    -> UTCTime -- ^ request date
    -> IO SigningKey
getSigningKey credentials region service date = do
    k <- atomicModifyIORef' (sigV4SigningKeys credentials) lookupOrDerive
    return $ SigningKey k
  where
    idx = ((T.encodeUtf8 . toText) region, (T.encodeUtf8 . toText) service)

    dateStr :: B.ByteString
    dateStr = fTime credentialScopeDateFormat date

    -- Returns the updated cache along with the key.
    lookupOrDerive cache = case L.lookup idx cache of
        Just (cachedDate, cachedKey)
            | cachedDate == dateStr -> (cache, cachedKey)
        _ -> deriveKey cache

    -- The new entry replaces any entry for the same region and service, so
    -- that the cache holds at most one key per region and service. The key
    -- and the spine of the new cache are forced, so that no unevaluated
    -- computations are stored in the cache.
    deriveKey cache =
        let SigningKey key = signingKey credentials CredentialScope
                { credentialScopeDate = date
                , credentialScopeRegion = region
                , credentialScopeService = service
                }
            others = filter ((/= idx) . fst) cache
            cache_ = (idx, (dateStr, key)) : others
        in key `seq` length others `seq` (cache_, key)
-- | Format a time with the given format string in the default time locale.
-- The result can be of any string-like type.
fTime :: IsString a => String -> UTCTime -> a
fTime format = fromString . formatTime defaultTimeLocale format