module Network.AWS.Authentication (
runAction, isAmzHeader, preSignedURI,
S3Action(..),
mimeEncodeQP, mimeDecode
) where
import Control.Exception (bracket)
import Data.Bits ((.&.))
import Data.Char (intToDigit, digitToInt, ord, chr, toLower)
import Data.List (sortBy, groupBy, intersperse)
import Data.Maybe
import System.Locale
import System.Time

import Codec.Binary.Base64 (encode, decode)
import qualified Codec.Binary.UTF8.String as US
import Codec.Utils (Octet)
import Data.ByteString.Char8 (pack, unpack)
import qualified Data.ByteString.Lazy.Char8 as L
import Data.HMAC
import Network.HTTP as HTTP hiding (simpleHTTP_)
import Network.HTTP.HandleStream (simpleHTTP_)
import Network.Stream (Result)
import Network.URI as URI
import Text.Regex
import Text.XML.HXT.Arrow

import Network.AWS.AWSConnection
import Network.AWS.AWSResult
import Network.AWS.ArrowUtils
-- | Description of one S3 REST operation: the connection to use, the
-- bucket/object it targets, an extra query string, extra request headers,
-- the request body and the HTTP method.
data S3Action = S3Action
  { s3conn      :: AWSConnection       -- ^ credentials, host and port used for the request
  , s3bucket    :: String              -- ^ bucket name; may be empty (then no bucket is used in host or resource)
  , s3object    :: String              -- ^ object key; a leading @/@ is added when building paths
  , s3query     :: String              -- ^ query part of the URI (presumably including the leading @?@ — confirm with callers)
  , s3metadata  :: [(String, String)]  -- ^ extra headers as name/value pairs
  , s3body      :: L.ByteString        -- ^ request body
  , s3operation :: RequestMethod       -- ^ HTTP method
  } deriving (Show)
-- | Build the unsigned HTTP request for an action.
--
-- The URI is relative (@/object@ plus the action's query string). The
-- header list starts with a Host header naming 's3Hostname', followed by
-- the headers derived from the action's metadata.
requestFromAction :: S3Action -> HTTP.HTTPRequest L.ByteString
requestFromAction a =
  Request { rqURI     = relativeUri
          , rqMethod  = s3operation a
          , rqHeaders = hostHeader : headersFromAction a
          , rqBody    = s3body a
          }
  where
    hostHeader  = Header HdrHost (s3Hostname a)
    relativeUri = URI { uriScheme    = ""
                      , uriAuthority = Nothing
                      , uriPath      = '/' : s3object a
                      , uriQuery     = s3query a
                      , uriFragment  = ""
                      }
-- | Convert the action's metadata pairs into HTTP headers.
--
-- The keys @Content-Type@ and @Content-Length@ become the standard headers
-- of the same name. Every other key becomes a custom header whose value is
-- passed through 'mimeEncodeQP', which encodes values that contain
-- characters outside printable ASCII.
headersFromAction :: S3Action -> [Header]
headersFromAction = map toHeader . s3metadata
  where
    toHeader ("Content-Type", v)   = Header HdrContentType v
    toHeader ("Content-Length", v) = Header HdrContentLength v
    toHeader (k, v)                = Header (HdrCustom k) (mimeEncodeQP v)
-- | Add a Content-Length header equal to the byte length of the request body.
addContentLengthHeader :: HTTP.HTTPRequest L.ByteString -> HTTP.HTTPRequest L.ByteString
addContentLengthHeader req = insertHeader HdrContentLength bodyLength req
  where
    bodyLength = show (L.length (rqBody req))
-- | Sign the request and attach the result as an
-- @Authorization: AWS <access key>:<signature>@ header.
--
-- The signature is 'makeSignature' applied to 'stringToSign' for this
-- action and request, keyed by the action's connection credentials.
addAuthenticationHeader :: S3Action
                        -> HTTP.HTTPRequest L.ByteString
                        -> HTTP.HTTPRequest L.ByteString
addAuthenticationHeader act req = insertHeader HdrAuthorization authValue req
  where
    creds     = s3conn act
    authValue = concat [ "AWS "
                       , awsAccessKey creds
                       , ":"
                       , makeSignature creds (stringToSign act req)
                       ]
-- | Base64-encoded HMAC-SHA1 of a message, keyed by the connection's secret
-- key. Both key and message are converted to octets with 'string2words'.
makeSignature :: AWSConnection -> String -> String
makeSignature conn msg =
  let key    = string2words (awsSecretKey conn)
      digest = hmac_sha1 key (string2words msg)
  in encode digest
-- | The full string that is signed for a request. It is the concatenation
-- of the canonical standard headers, the canonical @x-amz-@ headers and the
-- canonical resource.
stringToSign :: S3Action -> HTTP.HTTPRequest L.ByteString -> String
stringToSign a r =
  concat [ canonicalizeHeaders r
         , canonicalizeAmzHeaders r
         , canonicalizeResource a
         ]
-- | The first four newline-terminated lines of the string to sign:
--
-- 1. the HTTP method,
-- 2. Content-MD5,
-- 3. Content-Type,
-- 4. the Expires header if present, otherwise Date.
--
-- Missing headers contribute an empty line.
canonicalizeHeaders :: HTTP.HTTPRequest L.ByteString -> String
canonicalizeHeaders r =
  concatMap (++ "\n") [verb, header HdrContentMD5, header HdrContentType, timestamp]
  where
    verb        = show (rqMethod r)
    header name = fromMaybe "" (findHeader name r)
    -- Pre-signed requests carry Expires; it replaces Date in the signature.
    timestamp   = fromMaybe (header HdrDate) (findHeader HdrExpires r)
-- | Canonical form of the request's @x-amz-@ headers, one per line.
--
-- The headers are selected with 'isAmzHeader' and their names lowercased.
-- They are then sorted by name (stably), and values of equal names are
-- joined with commas. Each header is rendered with 'showHeader' and
-- terminated by a newline.
canonicalizeAmzHeaders :: HTTP.HTTPRequest L.ByteString -> String
canonicalizeAmzHeaders r = concatMap ((++ "\n") . showHeader) merged
  where
    amzPairs = [headerToLCKeyValue h | h <- rqHeaders r, isAmzHeader h]
    merged   = combineHeaders (groupHeaders (sortHeaders amzPairs))
-- | Render a @(name, value)@ pair as @name:value@. The value is first
-- unfolded with 'fold_whitespace' and then trimmed at both ends.
showHeader :: (String, String) -> String
showHeader (name, value) = name ++ ":" ++ cleaned
  where
    cleaned = removeLeadingTrailingWhitespace (fold_whitespace value)
-- | Unfold a header value that spans several lines.
--
-- A line break (CRLF, or a bare LF) followed by one or more spaces or tabs
-- is replaced by a single space (RFC 2616 section 4.2 line folding). A
-- line break that is not followed by a blank is left untouched.
--
-- The previous regex looked for "\n\r", which is the wrong order for an
-- HTTP line break, so folded values were never unfolded.
fold_whitespace :: String -> String
fold_whitespace s =
  case s of
    '\r' : '\n' : rest@(c : _)
      | isFoldBlank c -> ' ' : fold_whitespace (dropWhile isFoldBlank rest)
    '\n' : rest@(c : _)
      | isFoldBlank c -> ' ' : fold_whitespace (dropWhile isFoldBlank rest)
    ch : rest -> ch : fold_whitespace rest
    [] -> []
  where
    isFoldBlank ch = ch == ' ' || ch == '\t'
-- | Strip leading and trailing whitespace (space, tab, CR, LF, form feed,
-- vertical tab) from a header value. Interior whitespace is kept.
--
-- This replaces the earlier Text.Regex version, for two reasons. @\\s@ is
-- not part of POSIX extended regular expressions, so its meaning depended
-- on the C library. The regex also ran in multiline mode and was
-- recompiled on every call.
removeLeadingTrailingWhitespace :: String -> String
removeLeadingTrailingWhitespace =
  reverse . dropWhile isBlank . reverse . dropWhile isBlank
  where
    isBlank c = c `elem` " \t\n\r\f\v"
-- | Collapse each group of same-named headers into one pair, using
-- 'mergeSameHeaders'.
combineHeaders :: [[(String, String)]] -> [(String, String)]
combineHeaders groups = [mergeSameHeaders g | g <- groups]
-- | Merge a non-empty group of headers into one pair. The result takes the
-- first pair's name and joins all values with commas, in order.
-- Partial: an empty list is a pattern-match failure. 'groupHeaders' never
-- produces empty groups.
mergeSameHeaders :: [(String, String)] -> (String, String)
mergeSameHeaders hs@((name, _) : _) = (name, joined)
  where
    joined = concat (intersperse "," (map snd hs))
-- | Group adjacent pairs that share the same name. Callers sort the list
-- first so that all equal names end up adjacent.
groupHeaders :: [(String, String)] -> [[(String, String)]]
groupHeaders = groupBy sameName
  where
    sameName (k1, _) (k2, _) = k1 == k2
-- | Sort pairs by name. 'sortBy' is stable, so values of the same name keep
-- their original relative order.
sortHeaders :: [(String, String)] -> [(String, String)]
sortHeaders = sortBy byName
  where
    byName (k1, _) (k2, _) = compare k1 k2
-- | Turn a header into a @(lowercased name, value)@ pair. The name is
-- rendered with the header-name 'show' instance.
headerToLCKeyValue :: Header -> (String, String)
headerToLCKeyValue (Header name value) = (lowered, value)
  where
    lowered = [toLower c | c <- show name]
-- | True for custom headers whose name begins with @x-amz-@, compared
-- case-insensitively. Standard (non-custom) header names are never amz
-- headers.
--
-- HTTP header names are case-insensitive, and 'canonicalizeAmzHeaders'
-- lowercases the name anyway. A case-sensitive test would send a header
-- such as @X-Amz-Meta-Foo@ without signing it, so the signature would not
-- match.
isAmzHeader :: Header -> Bool
isAmzHeader (Header (HdrCustom k) _) = isPrefix amzHeader (map toLower k)
isAmzHeader _                        = False
-- | @isPrefix p xs@ is True when @xs@ starts with @p@. The empty list is a
-- prefix of everything.
isPrefix :: Eq a => [a] -> [a] -> Bool
isPrefix []       _        = True
isPrefix _        []       = False
isPrefix (p : ps) (x : xs) = p == x && isPrefix ps xs
-- | Lowercase name prefix that marks Amazon-specific request headers.
amzHeader :: String
amzHeader = "x-amz-"
-- | Resource part of the string to sign, built as @/bucket/object@.
-- The bucket name is lowercased. An empty bucket contributes nothing,
-- giving just @/object@.
canonicalizeResource :: S3Action -> String
canonicalizeResource a = bucketPart (s3bucket a) ++ "/" ++ s3object a
  where
    bucketPart "" = ""
    bucketPart b  = '/' : map toLower b
-- | Prepend a Date header with the given value to the request's headers.
addDateToReq :: HTTP.HTTPRequest L.ByteString -> String -> HTTP.HTTPRequest L.ByteString
addDateToReq r d = r { HTTP.rqHeaders = dateHeader : HTTP.rqHeaders r }
  where
    dateHeader = HTTP.Header HTTP.HdrDate d
-- | Prepend an Expires header with the given value to the request's headers.
addExpirationToReq :: HTTP.HTTPRequest L.ByteString -> String -> HTTP.HTTPRequest L.ByteString
addExpirationToReq r expires = addHeaderToReq r (HTTP.Header HTTP.HdrExpires expires)
-- | Prepend a header to the request's header list. Existing headers are kept.
addHeaderToReq :: HTTP.HTTPRequest L.ByteString -> Header -> HTTP.HTTPRequest L.ByteString
addHeaderToReq r h =
  let existing = HTTP.rqHeaders r
  in r { HTTP.rqHeaders = h : existing }
-- | Host name for an action. It is @bucket.host@ when a bucket is set,
-- otherwise just the connection's host.
s3Hostname :: S3Action -> String
s3Hostname a
  | null bucket = host
  | otherwise   = bucket ++ "." ++ host
  where
    bucket = s3bucket a
    host   = awsHost (s3conn a)
-- | Current time in UTC, formatted with 'rfc822DateFormat' and with the
-- time-zone name forced to @GMT@. Used as the request's Date header.
httpCurrentDate :: IO String
httpCurrentDate = fmap render getClockTime
  where
    render now = formatCalendarTime defaultTimeLocale rfc822DateFormat
                   ((toUTCTime now) { ctTZName = "GMT" })
-- | UTF-8 encode a string into octets, for use as HMAC input.
string2words :: String -> [Octet]
string2words s = US.encode s
-- | Execute an action against its own host name (see 's3Hostname').
runAction :: S3Action -> IO (AWSResult (HTTPResponse L.ByteString))
runAction act = runAction' act $ s3Hostname act
-- | Send a signed request for an action to @hostname@ and classify the reply.
--
-- The port comes from the action's connection. The request is built by
-- 'requestFromAction', stamped with the current date, given a
-- Content-Length, and then signed. The response is turned into an
-- 'AWSResult' by 'createAWSResult'.
runAction' :: S3Action -> String -> IO (AWSResult (HTTPResponse L.ByteString))
runAction' a hostname = do
  cd <- httpCurrentDate
  let aReq = addAuthenticationHeader a $
             addContentLengthHeader $
             addDateToReq (requestFromAction a) cd
  -- 'bracket' closes the connection even if sending the request throws an
  -- exception; closing it only on the success path would leak the socket.
  -- The connection is released before 'createAWSResult' runs, because that
  -- step may open a new connection to follow a redirect.
  result <- bracket (openTCPConnection hostname (awsPort (s3conn a)))
                    close
                    (\c -> simpleHTTP_ c aReq)
  createAWSResult a result
-- | Build a pre-signed, path-style @http:@ GET URI for an action.
--
-- The URI expires at the given time (passed through 'show' into the
-- @Expires@ parameter). The action's query string, if any, is extended
-- with the @AWSAccessKeyId@, @Expires@ and URL-encoded @Signature@
-- parameters. The signed string is @GET@, two empty lines, the expiry,
-- and then @/bucket/object@.
preSignedURI :: S3Action -> Integer -> URI
preSignedURI a e =
  URI { uriScheme    = "http:"
      , uriAuthority = Just (URIAuth "" (awsHost conn) (':' : show (awsPort conn)))
      , uriPath      = resource
      , uriQuery     = queryPrefix ++ accessKeyParam ++ "&" ++ expiresParam ++ "&" ++ signatureParam
      , uriFragment  = ""
      }
  where
    conn     = s3conn a
    expires  = show e
    resource = "/" ++ s3bucket a ++ "/" ++ s3object a
    -- Start a fresh query, or append to the action's existing one.
    queryPrefix
      | null (s3query a) = "?"
      | otherwise        = s3query a ++ "&"
    accessKeyParam = "AWSAccessKeyId=" ++ awsAccessKey conn
    expiresParam   = "Expires=" ++ expires
    toSign         = "GET\n\n\n" ++ expires ++ "\n" ++ resource
    signatureParam = "Signature=" ++ urlEncode (makeSignature conn toSign)
-- | Classify the outcome of an HTTP exchange.
--
-- * A transport failure becomes a 'NetworkError'.
-- * A 2xx response is a success.
-- * A 307 with a Location header re-runs the action against the host named
--   in that header. A 307 without one is an error.
-- * A 404 becomes an @AWSError "NotFound"@.
-- * Anything else is parsed as an S3 XML error body.
--
-- NOTE(review): redirects are followed recursively with no limit, so a
-- redirect loop would never terminate.
createAWSResult :: S3Action -> Result (HTTPResponse L.ByteString) -> IO (AWSResult (HTTPResponse L.ByteString))
createAWSResult a = either (return . Left . NetworkError) classify
  where
    classify rsp =
      case rspCode rsp of
        (2, _, _) -> return (Right rsp)
        (3, 0, 7) -> maybe missingLocation
                           (runAction' a . getHostname)
                           (findHeader HdrLocation rsp)
        (4, 0, 4) -> return (Left $ AWSError "NotFound" "404 Not Found")
        _         -> fmap Left (parseRestErrorXML (L.unpack (rspBody rsp)))
    missingLocation =
      return (Left $ AWSError "Temporary Redirect" "Redirect without location header")
-- | Extract the host (registered name) from an absolute URI string. Returns
-- the empty string if the URI does not parse or has no authority part.
getHostname :: String -> String
getHostname h = maybe "" uriRegName (parseURI h >>= uriAuthority)
-- | Parse an S3 REST error body into a 'ReqError'.
--
-- The string is read with HXT using the @(a_validate, v_0)@ option
-- (in HXT this disables validation) and passed through 'processRestError'.
-- The first error found is returned. If none is found, an @AWSError@ with
-- code @NoErrorInMsg@ is returned instead.
parseRestErrorXML :: String -> IO ReqError
parseRestErrorXML x =
  do e <- runX (readString [(a_validate, v_0)] x >>> processRestError)
     case e of
       [] -> return (AWSError "NoErrorInMsg"
                              ("HTTP Error condition, but message body "
                               ++ "did not contain error code."))
       err : _ -> return err
-- | HXT arrow that extracts 'AWSError' values from an S3 REST error document.
--
-- It selects every @Error@ element anywhere in the tree ('deep').
-- @text <<< atTag "Code"@ is applied to the first component and
-- @text <<< atTag "Message"@ to the second. The two results are combined
-- into @AWSError code message@.
--
-- NOTE(review): assumes 'split' duplicates the input into a pair and
-- 'unsplit' combines the pair with the given function. 'atTag' and 'text'
-- presumably come from Network.AWS.ArrowUtils — confirm their exact
-- semantics there.
processRestError = deep (isElem >>> hasName "Error") >>>
                   split >>> first (text <<< atTag "Code") >>>
                   second (text <<< atTag "Message") >>>
                   unsplit (\x y -> AWSError x y)
mimeEncodeQP, mimeDecode :: String -> String
-- | Decode a value that was stored as a single RFC 2047 UTF-8 encoded-word
-- (@=?UTF-8?Q?...?=@ or @=?UTF-8?B?...?=@).
--
-- The charset and encoding tokens are matched case-insensitively, as
-- RFC 2047 requires. A value that does not both start with one of these
-- prefixes and end with @?=@ is returned unchanged. Previously the last
-- two characters were dropped blindly, even when the suffix was missing.
mimeDecode a
  | hasEncodedWord utf8qp  = mimeDecodeQP (encodedPayload utf8qp)
  | hasEncodedWord utf8b64 = mimeDecodeB64 (encodedPayload utf8b64)
  | otherwise              = a
  where
    utf8qp  = "=?UTF-8?Q?"
    utf8b64 = "=?UTF-8?B?"
    suffix  = "?="
    lower   = map toLower
    -- The length check keeps the prefix and suffix from overlapping.
    hasEncodedWord prefix =
      isPrefix (lower prefix) (lower a)
        && isPrefix (reverse suffix) (reverse a)
        && length a >= length prefix + length suffix
    encodedPayload prefix =
      take (length a - length prefix - length suffix) (drop (length prefix) a)
-- | Decode a quoted-printable payload. Escapes are turned back into bytes,
-- and the resulting byte string is then decoded as UTF-8.
mimeDecodeQP :: String -> String
mimeDecodeQP s = US.decodeString (mimeDecodeQP' s)
-- | Undo @=XX@ hex escapes. Each escape becomes one character holding that
-- byte value; the output is still a byte string, not decoded text.
--
-- Hex digits may be upper or lower case. An @=@ that is not followed by
-- two hex digits is kept literally instead of crashing in 'digitToInt'.
-- Such values were produced when a literal @=@ was left unescaped.
mimeDecodeQP' :: String -> String
mimeDecodeQP' s =
  case s of
    '=' : hi : lo : rest
      | isHex hi && isHex lo -> chr (16 * digitToInt hi + digitToInt lo) : mimeDecodeQP' rest
    c : rest -> c : mimeDecodeQP' rest
    [] -> []
  where
    isHex c = c `elem` "0123456789abcdefABCDEF"
-- | Decode a base64 payload and interpret the bytes as UTF-8. Returns the
-- empty string if the payload is not valid base64.
mimeDecodeB64 :: String -> String
mimeDecodeB64 = maybe "" US.decode . decode
-- | Encode a value as an RFC 2047 UTF-8 Q encoded-word, but only if it
-- contains a character outside printable ASCII (see 'reservedChar').
-- Otherwise the value is returned unchanged.
mimeEncodeQP s
  | any reservedChar s = concat ["=?UTF-8?Q?", mimeEncodeQP' (US.encodeString s), "?="]
  | otherwise          = s
-- | Q-encode a byte string (one byte per character) for use inside an
-- RFC 2047 encoded-word.
--
-- A byte is escaped as @=XX@ (uppercase hex, per RFC 2045) when it is
-- outside printable ASCII ('reservedChar'). Space, @=@, @?@ and @_@ are
-- escaped as well. Those four have special meaning inside an encoded-word:
-- if left raw they corrupt it, and the decoder could mistake them for
-- escapes. 'mimeDecodeQP'' accepts both upper- and lowercase hex, so values
-- stored by earlier versions still decode.
mimeEncodeQP' :: String -> String
mimeEncodeQP' = concatMap encodeChar
  where
    encodeChar c
      | reservedChar c || c `elem` " =?_" = escape c
      | otherwise                         = [c]
    escape c =
      let n = ord c
      in ['=', hexDigit ((n `div` 16) .&. 0xf), hexDigit (n .&. 0xf)]
    -- Safe indexing: the argument is always masked to 0..15.
    hexDigit d = "0123456789ABCDEF" !! d
-- | True for characters outside printable ASCII (code points below 0x20 or
-- above 0x7e). These are the characters that force Q-encoding.
reservedChar :: Char -> Bool
reservedChar x = code < 0x20 || code > 0x7e
  where
    code = ord x