-- Copyright © 2009 Jamey Sharp and Josh Triplett
-- See the file COPYING for license details.
--
-- HTTP Digest implementation based on RFC2617, with modifications to handle
-- the ways that various browsers violate that specification.
module Network.HTTP.Digest (
    authenticateDigest, DigestResult(..), makeHA1, setDigestRequired
) where

import Control.Monad
import Control.Monad.Fail (MonadFail)
import Data.Bits
import Data.Char
import qualified Data.Digest.MD5 as MD5
import Data.Maybe
import Data.Time.Clock.POSIX
import Network.CGI
import Network.URI
import System.IO.Unsafe
import System.Random
import Text.ParserCombinators.Parsec hiding (token)

-- | Compute a user's HA1: the hex-encoded MD5 digest of
-- @username ":" realm ":" password@.  This is the value the lookup function
-- given to 'authenticateDigest' is expected to return alongside the user.
makeHA1 :: String -> String -> String -> String
makeHA1 username realm password =
    md5Hash (concat [username, ":", realm, ":", password])

-- | Per-process random secret mixed into every nonce hash (see 'hashTime').
--
-- The NOINLINE pragma is required for this top-level unsafePerformIO idiom:
-- without it GHC is free to inline the definition at each use site, so the
-- code that mints nonces and the code that verifies them could observe
-- different secrets and every nonce would be rejected.
--
-- NOTE(review): the secret only lives as long as the process, so nonces
-- minted by an earlier process fail verification and are reported as stale.
-- randomIO draws from System.Random's global generator, which is not a
-- cryptographic source; consider a stronger source of randomness.
digestSecret :: String
digestSecret = show (unsafePerformIO randomIO :: Int)
{-# NOINLINE digestSecret #-}

-- | Keyed hash of a nonce timestamp: the hex MD5 of the decimal time, a
-- colon, and 'digestSecret'.  Used both when minting nonces and when
-- checking them.
hashTime :: Integer -> String
hashTime time = md5Hash (shows time (':' : digestSecret))

-- | Build the nonce for a point in time: @secs ":" hashTime secs@, where
-- @secs@ is the time truncated to whole seconds.
getNonceFor :: POSIXTime -> String
getNonceFor time = shows secs (':' : hashTime secs)
    where
    secs = truncate time

-- | Recover the issue time from a nonce built by 'getNonceFor'.  Yields
-- 'Nothing' when the nonce is malformed or its hash part does not match
-- 'hashTime' of its time part (i.e. it was not minted with this process's
-- 'digestSecret').
getTimeOf :: String -> Maybe POSIXTime
getTimeOf nonce
    | [(secs, ':' : mac)] <- reads nonce, mac == hashTime secs = Just (fromInteger secs)
    | otherwise = Nothing

-- | Mint a fresh nonce stamped with the current POSIX time.
makeNonce :: IO String
makeNonce = do
    now <- getPOSIXTime
    return (getNonceFor now)

-- | Look up attribute @name@ in @attrs@ and return its value.  Fails (via
-- 'fail') with a descriptive message if the attribute is absent or its
-- value does not satisfy @check@.
--
-- The constraint is 'MonadFail' rather than 'Monad': since GHC 8.8 'fail'
-- is a 'MonadFail' method, so a plain 'Monad' constraint no longer compiles.
requiredAttribute :: MonadFail m => String -> [(String, String)] -> (String -> Bool) -> m String
requiredAttribute name attrs check = case lookup name attrs of
    Nothing -> fail $ "Digest authorization missing \"" ++ name ++ "\" attribute"
    Just value
        | check value -> return value
        | otherwise -> fail $ "Digest authorization has invalid \"" ++ name ++ "\" attribute, \"" ++ value ++ "\""

-- | Parse a Digest @Authorization@ header value and validate the attributes
-- that do not need the user's HA1: @realm@ must equal the given realm and
-- @uri@ must refer to the requested URI.  Returns all parsed attributes
-- together with the @username@ and @nonce@ values.
checkAttributes :: String -> URI -> CharParser st ([(String, String)], String, String)
checkAttributes realm uri = do
    attrs <- parseDigestAuthorization
    _ <- requiredAttribute "realm" attrs (== realm)
    -- RFC2617 digest-uri is any HTTP/1.1 Request-URI, which may be an
    -- absolute URI, so accept absolute URIs as well as relative references.
    -- Inputs that parse as relative references parse identically here.
    _ <- requiredAttribute "uri" attrs (urimatch . parseURIReference)
    username <- requiredAttribute "username" attrs (const True)
    nonce <- requiredAttribute "nonce" attrs (const True)
    return (attrs, username, nonce)
    where
    dom u = (uriScheme u, uriAuthority u)
    urimatch Nothing = False
    urimatch (Just authuri) = null (uriFragment authuri)
        && uriPath uri == uriPath authuri
        -- A relative reference names no scheme or authority; otherwise they
        -- must be the same as the request's.
        && (dom authuri == ("", Nothing) || dom uri == dom authuri)
        && (uriQuery uri == uriQuery authuri || null (uriQuery authuri)) -- work around IE bug

-- | Outcome of 'authenticateDigest'.
data DigestResult u
    = DigestSuccess u          -- ^ Correct response and fresh nonce; carries the value returned by the user lookup.
    | DigestBadRequest String  -- ^ The header failed to parse or an attribute check failed; carries the error text.
    | DigestStale              -- ^ Correct response, but the nonce was not issued by this process or is too old.
    | DigestIncorrect          -- ^ Unknown user, or the response did not match.
    | DigestMissing            -- ^ The request carried no @Authorization@ header.

-- | Authenticate the current request using HTTP Digest (RFC2617).
--
-- @userfunc@ receives the username from the @Authorization@ header and
-- should return the application's user value together with that user's HA1
-- (see 'makeHA1'), or 'Nothing' for an unknown user.  On success an
-- @Authentication-Info@ header with a fresh @nextnonce@ is added.  A correct
-- response with a nonce this process did not issue, or one older than 300
-- seconds, gives 'DigestStale'.
--
-- Note that this implementation does not check for reuse of nonce-count
-- values, as that requires persistent state.
authenticateDigest :: (MonadCGI m, MonadIO m) => String -> (String -> m (Maybe (u, String))) -> m (DigestResult u)
authenticateDigest realm userfunc = do
    requestUri <- requestURI
    httpMethod <- requestMethod
    credentials <- requestHeader headerName
    case credentials of
        Nothing -> return DigestMissing
        Just auth -> either (return . DigestBadRequest . show) (verifyUser httpMethod)
            (parse (checkAttributes realm requestUri) headerName auth)
    where
    headerName = "Authorization"
    -- Check the response against the user's HA1 before looking at the nonce.
    verifyUser httpMethod (attrs, username, nonce) = do
        found <- userfunc username
        case found of
            Just (user, ha1) | checkDigestAuthorization httpMethod attrs ha1 -> checkNonceAge user nonce
            _ -> return DigestIncorrect
    -- Only nonces minted by this process within the last 300 seconds pass.
    checkNonceAge user nonce = do
        now <- liftIO getPOSIXTime
        if maybe False (>= now - 300) (getTimeOf nonce)
            then setAuthenticationInfo >> return (DigestSuccess user)
            else return DigestStale

-- | Send a Digest challenge: set status 401 and add a @WWW-Authenticate@
-- header for @realm@ with a freshly minted nonce and @qop="auth"@.  When
-- @stale@ is 'True' the header also carries @stale=TRUE@, which per RFC2617
-- lets the client retry with the new nonce without re-prompting the user.
--
-- XXX: setDigestRequired does not set the domain parameter, which according
-- to RFC2617 means "the protection space consists of all URIs on the
-- responding server."
setDigestRequired :: (MonadCGI m, MonadIO m) => String -> Bool -> m ()
setDigestRequired realm stale = do
    setStatus 401 "Unauthorized"
    nonce <- liftIO makeNonce
    setHeader "WWW-Authenticate" $
        "Digest realm=" ++ quote realm ++ ", nonce=" ++ quote nonce ++ ", qop=\"auth\"" ++
        if stale then ", stale=TRUE" else ""
    where
    -- Render a quoted-string.  A realm containing '"' or '\' would otherwise
    -- end the string early or start an escape, producing a malformed header.
    -- The quoted-pair unescaping in this module's Authorization parser turns
    -- an escaped realm echoed by the client back into the original.
    quote s = '"' : concatMap escape s ++ "\""
    escape c
        | c == '"' || c == '\\' = ['\\', c]
        | otherwise = [c]

-- | Add an @Authentication-Info@ header carrying a freshly minted
-- @nextnonce@ for the client's next request.
setAuthenticationInfo :: (MonadCGI m, MonadIO m) => m ()
setAuthenticationInfo =
    liftIO makeNonce >>= \fresh ->
        setHeader "Authentication-Info" (concat ["nextnonce=\"", fresh, "\""])

-- | Parser for the value of a Digest @Authorization@ header (RFC2617
-- section 3.2.2).  It produces the attributes as @(name, value)@ pairs in
-- input order.
--
-- * The scheme name @Digest@ and attribute names are matched
--   case-insensitively, and names are returned lowercased.
-- * Quoted values are returned without their quotes.  Quoted-pairs are
--   unescaped, and each LWS run inside them becomes a single space.
-- * Empty list elements (e.g. @a=1,,b=2@) are skipped.  Whitespace after
--   the last attribute is allowed.
-- * The whole input must be consumed.
parseDigestAuthorization :: CharParser st [(String, String)]
parseDigestAuthorization = do
    let stringLower s = (mapM_ (\x-> satisfy (\y-> x == toLower y)) s >> return s) <?> show s
    let lws = (optional (string "\r\n") >> skipMany1 (satisfy (\x-> x == ' ' || x == '\t')) >> return ' ') <?> "LWS"
    let isHTTPTokenChar c = c > ' ' && c < '\DEL' && c `notElem` "()<>@,;:\\\"/[]?={}"
    let token = many1 (satisfy isHTTPTokenChar) <?> "token"
    let tokenLower = fmap (map toLower) token
    let quoted = between (char '"') (char '"')
    let quotedPair = (char '\\' >> satisfy (<= '\DEL')) <?> "quoted-pair"
    -- RFC2616 qdtext is TEXT minus '"'.  TEXT excludes control characters,
    -- and DEL is one of them.  Characters above DEL are rejected as well.
    let qdtext = (lws <|> satisfy (\x-> x > ' ' && x < '\DEL' && x /= '"')) <?> "qdtext"
    let quotedString = (quoted $ many $ quotedPair <|> qdtext) <?> "quoted-string"
    let lhex = satisfy (`elem` "0123456789abcdef") <?> "LHEX"
    -- RFC2617 requires exactly 8 characters for nc, but some browsers don't
    -- provide the leading 0s.
    let ncValue = many1 lhex <?> "nc-value"
    -- RFC2616 #rule list: comma-separated, with empty elements allowed.
    -- The separator is wrapped in 'try', so whitespace after the final
    -- element is not a consumed-input error when no comma follows it.
    let list e = fmap msum $ (skipMany lws >> option mzero (fmap return e)) `sepBy` try (skipMany lws >> char ',')
    let nameValue kinds = do
            name <- tokenLower
            skipMany lws >> char '=' >> skipMany lws
            -- Attributes without a dedicated parser take a quoted string or a token.
            value <- fromMaybe (quotedString <|> token) $ lookup name kinds
            return (name, value)
    -- RFC2617 does not allow quoting algorithm, qop, or nc; allow it anyway
    -- because some clients quote everything.
    between (stringLower "digest" >> lws) (skipMany lws >> eof) $ list $ nameValue [
            ("username", quotedString),
            ("realm", quotedString),
            ("nonce", quotedString),
            ("uri", quotedString <|> string "*"),
            ("response", quoted (count 32 lhex) <?> "request-digest"),
            -- ("algorithm", token),
            ("cnonce", quotedString),
            ("opaque", quotedString),
            -- ("qop", token),
            ("nc", ncValue <|> quoted ncValue)
        ]

-- | Lowercase hex encoding of the MD5 digest of a string.  Each character
-- is turned into one octet with @toEnum . fromEnum@, so characters up to
-- '\255' hash as their Latin-1 byte.
-- NOTE(review): assumes MD5.hash consumes Word8 octets, in which case
-- toEnum errors for characters above '\255'.  Confirm before hashing
-- arbitrary Unicode input.
md5Hash :: String -> String
md5Hash input = concatMap hexOctet (MD5.hash octets)
    where
    octets = map (toEnum . fromEnum) input
    hexOctet o = [hexDigit (o `shiftR` 4), hexDigit (o .&. 0xF)]
    hexDigit = intToDigit . fromEnum

-- | Check the client's @response@ attribute against the digest RFC2617
-- (section 3.2.2.1) expects for request method @method@ and the user's HA1.
--
-- Only the MD5 algorithm is supported (the @algorithm@ attribute must be
-- absent or exactly "MD5").  Likewise, only qop=auth or an absent qop is
-- supported.  Anything else, or a missing required attribute, yields
-- 'False'.
checkDigestAuthorization :: String -> [(String, String)] -> String -> Bool
checkDigestAuthorization method attrs ha1 = isJust $ do
    nonce <- lookup "nonce" attrs
    digestURI <- lookup "uri" attrs
    response <- lookup "response" attrs
    let algorithm = lookup "algorithm" attrs
    guard $ algorithm == Nothing || algorithm == Just "MD5" -- "MD5-sess" not implemented yet
    let kd secret contents = md5Hash (secret ++ ':' : contents)
    let ha2 = md5Hash (method ++ ':' : digestURI)
    expected <- case lookup "qop" attrs of
        Just qop -> do
            guard $ qop == "auth" -- "auth-int" not implemented yet
            cnonce <- lookup "cnonce" attrs
            nc <- lookup "nc" attrs
            return $ kd ha1 $ nonce ++ ':' : nc ++ ':' : cnonce ++ ':' : qop ++ ':' : ha2
        Nothing -> return $ kd ha1 $ nonce ++ ':' : ha2
    guard $ response `sameDigest` expected
    where
    -- (==) stops at the first differing character, so its running time
    -- reveals how much of a guessed response was right.  Nonce counts are
    -- not tracked, so a client could replay one nonce many times to measure
    -- this.  This compares every character instead; GHC gives no hard
    -- constant-time guarantee, but the early exit is gone.
    sameDigest a b = length a == length b
        && foldr (.|.) 0 (zipWith (\x y -> ord x `xor` ord y) a b) == 0