{-# OPTIONS_HADDOCK hide #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
module Network.Xmpp.Sasl.Mechanisms.Scram
where
import Control.Applicative ((<$>))
import Control.Monad
import Control.Monad.Except
import Control.Monad.State.Strict
import qualified Crypto.Classes as Crypto
import qualified Crypto.HMAC as Crypto
import qualified Crypto.Hash.CryptoAPI as Crypto
import qualified Data.ByteString as BS
import qualified Data.ByteString.Base64 as B64
import Data.ByteString.Char8 as BS8 (unpack)
import Data.List (foldl1', genericTake)
import qualified Data.Text as Text
import qualified Data.Text.Encoding as Text
import Network.Xmpp.Sasl.Common
import Network.Xmpp.Sasl.Types
import Network.Xmpp.Types
-- | Dummy value used solely to select a hash algorithm through its type,
-- e.g. @hashToken :: Crypto.SHA1@.
--
-- The value carries no data and must never be forced; callers are expected
-- to use it only as a type-level proxy argument.
hashToken :: (Crypto.Hash ctx hash) => hash
hashToken = error "Network.Xmpp.Sasl.Mechanisms.Scram.hashToken: proxy value must not be evaluated"
-- | Perform a SCRAM authentication exchange (the mechanism named
-- @SCRAM-SHA-1@ is defined in RFC 5802), polymorphic in the hash function.
--
-- The credentials are first passed through 'prepCredentials'; the exchange
-- itself then sends the client-first message, reads the server challenge,
-- answers with the client-final message (including the client proof) and
-- finally checks the server signature returned in the @v@ attribute.
--
-- Fails with 'AuthOtherFailure' when the challenge is malformed (missing or
-- undecodable attributes, a server nonce that does not extend the client
-- nonce, a non-positive iteration count) or when the server signature does
-- not match. Other failures come from the helpers it calls.
--
-- NOTE(review): the mechanism name passed to 'saslInit' is hard-coded to
-- @"SCRAM-SHA-1"@ even though the hash is a parameter; callers selecting a
-- different hash would still announce SHA-1 -- confirm this is intended.
scram :: (Crypto.Hash ctx hash)
      => hash            -- ^ Proxy selecting the hash function; only its
                         --   type is used, the value is never evaluated
      -> Text.Text       -- ^ Authentication ID (user name)
      -> Maybe Text.Text -- ^ Optional authorization ID
      -> Text.Text       -- ^ Password
      -> ExceptT AuthFailure (StateT StreamState IO) ()
scram hToken authcid authzid password = do
    (ac, az, pw) <- prepCredentials authcid authzid password
    scramhelper ac az pw
  where
    -- Runs the message exchange with already prepared credentials.
    scramhelper authcid' authzid' pwd = do
        cnonce <- liftIO makeNonce
        saslInit "SCRAM-SHA-1" (Just $ cFirstMessage cnonce)
        sFirstMessage <- saslFromJust =<< pullChallenge
        prs <- toPairs sFirstMessage
        (nonce, salt, ic) <- fromPairs prs cnonce
        let (cfm, v) = cFinalMessageAndVerifier nonce salt ic sFirstMessage cnonce
        respond $ Just cfm
        finalPairs <- toPairs =<< saslFromJust =<< pullFinalMessage
        -- The server proves knowledge of the password by sending the
        -- signature we computed ourselves; anything else is a failure.
        unless (lookup "v" finalPairs == Just v) $ throwError AuthOtherFailure
      where
        -- Serialise a digest. The first argument is ignored at run time; it
        -- only pins the digest type to the one selected by 'hToken'.
        encode :: Crypto.Hash ctx hash => hash -> hash -> BS.ByteString
        encode _hashtoken = Crypto.encode

        -- Plain hash of a byte string with the selected hash function.
        hash :: BS.ByteString -> BS.ByteString
        hash str = encode hToken $ Crypto.hash' str

        -- HMAC of @str@ under @key@ with the selected hash function.
        hmac :: BS.ByteString -> BS.ByteString -> BS.ByteString
        hmac key str = encode hToken $ Crypto.hmac' (Crypto.MacKey key) str

        -- The @a=...@ attribute, present only when an authzid was given.
        authzid'' :: Maybe BS.ByteString
        authzid'' = (\z -> "a=" +++ Text.encodeUtf8 z) <$> authzid'

        -- Channel-binding flag; always "n" here.
        gs2CbindFlag :: BS.ByteString
        gs2CbindFlag = "n"

        -- GS2 header: flag, optional authzid attribute, and a trailing
        -- empty field so that the header ends with a separator.
        gs2Header :: BS.ByteString
        gs2Header = merge [ gs2CbindFlag
                          , maybe "" id authzid''
                          , ""
                          ]

        -- Client-first message without the GS2 header.
        --
        -- NOTE(review): the user name is inserted verbatim; if
        -- 'prepCredentials' does not already escape ',' and '=' this
        -- produces a malformed message for such names -- TODO confirm.
        cFirstMessageBare :: BS.ByteString -> BS.ByteString
        cFirstMessageBare cnonce = merge [ "n=" +++ Text.encodeUtf8 authcid'
                                         , "r=" +++ cnonce
                                         ]

        -- Complete client-first message.
        cFirstMessage :: BS.ByteString -> BS.ByteString
        cFirstMessage cnonce = gs2Header +++ cFirstMessageBare cnonce

        -- Extract (server nonce, decoded salt, iteration count) from the
        -- server challenge. The server nonce must start with our client
        -- nonce, the salt must be valid Base64, and the iteration count
        -- must parse completely as an integer that is at least 1 (a count
        -- below 1 would leave 'hi' with nothing to fold over).
        fromPairs :: Pairs
                  -> BS.ByteString
                  -> ExceptT AuthFailure (StateT StreamState IO)
                             (BS.ByteString, BS.ByteString, Integer)
        fromPairs prs cnonce
            | Just nonce <- lookup "r" prs
            , cnonce `BS.isPrefixOf` nonce
            , Just salt' <- lookup "s" prs
            , Right salt <- B64.decode salt'
            , Just ic <- lookup "i" prs
            , [(i, "")] <- reads $ BS8.unpack ic
            , i >= 1
            = return (nonce, salt, i)
        fromPairs _ _ = throwError AuthOtherFailure

        -- Build the client-final message and the Base64-encoded server
        -- signature we expect to receive back.
        --
        -- Arguments: server nonce, salt, iteration count, the raw
        -- server-first message and the client nonce.
        cFinalMessageAndVerifier :: BS.ByteString
                                 -> BS.ByteString
                                 -> Integer
                                 -> BS.ByteString
                                 -> BS.ByteString
                                 -> (BS.ByteString, BS.ByteString)
        cFinalMessageAndVerifier nonce salt ic sfm cnonce =
            ( merge [ cFinalMessageWOProof
                    , "p=" +++ B64.encode clientProof
                    ]
            , B64.encode serverSignature
            )
          where
            -- Client-final message without the proof attribute.
            cFinalMessageWOProof :: BS.ByteString
            cFinalMessageWOProof = merge [ "c=" +++ B64.encode gs2Header
                                         , "r=" +++ nonce
                                         ]

            saltedPassword :: BS.ByteString
            saltedPassword = hi (Text.encodeUtf8 pwd) salt ic

            clientKey :: BS.ByteString
            clientKey = hmac saltedPassword "Client Key"

            storedKey :: BS.ByteString
            storedKey = hash clientKey

            -- Everything both sides have exchanged so far, minus the proof.
            authMessage :: BS.ByteString
            authMessage = merge [ cFirstMessageBare cnonce
                                , sfm
                                , cFinalMessageWOProof
                                ]

            clientSignature :: BS.ByteString
            clientSignature = hmac storedKey authMessage

            clientProof :: BS.ByteString
            clientProof = clientKey `xorBS` clientSignature

            serverKey :: BS.ByteString
            serverKey = hmac saltedPassword "Server Key"

            serverSignature :: BS.ByteString
            serverSignature = hmac serverKey authMessage

            -- Iterated HMAC: U1 = HMAC(str, slt ++ 0x00000001),
            -- U(n+1) = HMAC(str, Un); the result is the XOR of the first
            -- @ic'@ values. 'foldl1'' requires at least one element, which
            -- 'fromPairs' guarantees by rejecting counts below 1.
            hi :: BS.ByteString -> BS.ByteString -> Integer -> BS.ByteString
            hi str slt ic' = foldl1' xorBS (genericTake ic' us)
              where
                u1 = hmac str (slt +++ BS.pack [0, 0, 0, 1])
                us = iterate (hmac str) u1
-- | SASL handler for the @SCRAM-SHA-1@ mechanism.
--
-- Pairs the mechanism name with an action that runs 'scram' using SHA-1 and
-- maps its outcome onto the handler's result:
--
-- * a stream-level failure ('AuthStreamFailure') becomes @Left@ with the
--   wrapped 'XmppFailure';
-- * any other 'AuthFailure' becomes @Right (Just failure)@;
-- * success becomes @Right Nothing@.
scramSha1 :: Username       -- ^ Authentication ID (user name)
          -> Maybe AuthZID  -- ^ Optional authorization ID
          -> Password       -- ^ Password
          -> SaslHandler
scramSha1 authcid authzid passwd =
    ( "SCRAM-SHA-1"
    , do
          r <- runExceptT $ scram (hashToken :: Crypto.SHA1) authcid authzid passwd
          return $ case r of
              Left (AuthStreamFailure e) -> Left e
              Left e                     -> Right (Just e)
              Right ()                   -> Right Nothing
    )