{-# OPTIONS_HADDOCK hide #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE FlexibleContexts #-}

module Network.Xmpp.Sasl.Common where

import           Control.Applicative ((<$>))
import           Control.Monad.Except
import qualified Data.Attoparsec.ByteString.Char8 as AP
import           Data.Bits
import qualified Data.ByteString as BS
import qualified Data.ByteString.Base64 as B64
import           Data.Maybe (maybeToList)
import qualified Data.Text as Text
import qualified Data.Text.Encoding as Text
import           Data.Word (Word8)
import           Data.XML.Pickle
import           Data.XML.Types
import           Network.Xmpp.Marshal
import           Network.Xmpp.Sasl.StringPrep
import           Network.Xmpp.Sasl.Types
import           Network.Xmpp.Stream
import           Network.Xmpp.Types

import qualified System.Random as Random

import           Control.Monad.State.Strict

--makeNonce :: ExceptT AuthFailure (StateT StreamState IO) BS.ByteString
-- | Generate a nonce: 15 pseudo-random bytes, base64-encoded.
--
-- NOTE(review): the bytes are drawn from 'Random.newStdGen'; that generator
-- is presumably not cryptographically secure -- confirm whether a stronger
-- source is required for SASL nonces.
makeNonce :: IO BS.ByteString
makeNonce = do
    g <- Random.newStdGen
    return $ B64.encode . BS.pack . map toWord8 . take 15 $ Random.randoms g
  where
    -- Keep only the low eight bits of each random 'Int'.
    toWord8 :: Int -> Word8
    toWord8 x = fromIntegral x :: Word8

-- The <auth xmlns='urn:ietf:params:xml:ns:xmpp-sasl'/> element for the given
-- mechanism (carried in the "mechanism" attribute), with an optional
-- round-trip value as its text content.
saslInitE :: Text.Text -> Maybe Text.Text -> Element
saslInitE mechanism rt =
    Element "{urn:ietf:params:xml:ns:xmpp-sasl}auth"
        [("mechanism", [ContentText mechanism])]
        (maybeToList $ NodeContent . ContentText <$> rt)

-- The <response xmlns='urn:ietf:params:xml:ns:xmpp-sasl'/> element, with an
-- optional text payload as its content and no attributes.
saslResponseE :: Maybe Text.Text -> Element
saslResponseE resp =
    Element "{urn:ietf:params:xml:ns:xmpp-sasl}response"
    []
    (maybeToList $ NodeContent . ContentText <$> resp)

-- Pickler for the <success xmlns='urn:ietf:params:xml:ns:xmpp-sasl'/>
-- element; its text content is optional.
xpSuccess :: PU [Node] (Maybe Text.Text)
xpSuccess = xpElemNodes "{urn:ietf:params:xml:ns:xmpp-sasl}success"
    (xpOption $ xpContent xpId)

-- Parses the incoming SASL data into a list of (name, value) pairs.
--
-- The input is a comma-separated sequence of @name=value@ items. Leading
-- white space before each name is skipped. A value is either a quoted string
-- (everything up to the closing double quote, which may therefore contain
-- commas, e.g. @qop="auth,auth-int"@) or a bare token that ends at the next
-- comma or double quote. Values must be non-empty.
--
-- Parsing stops at the first item that does not match; input after that
-- point is not consumed and does not cause a failure.
pairs :: BS.ByteString -> Either String Pairs
pairs = AP.parseOnly (pair `AP.sepBy1` AP.char ',')
  where
    pair :: AP.Parser (BS.ByteString, BS.ByteString)
    pair = do
        AP.skipSpace
        name <- AP.takeWhile1 (/= '=')
        _ <- AP.char '='
        content <- quoted `mplus` bare
        return (name, content)

    -- A double-quoted value. Only the closing quote terminates it, so commas
    -- inside the quotes belong to the value instead of ending the pair.
    quoted :: AP.Parser BS.ByteString
    quoted = do
        _ <- AP.char '"'
        s <- AP.takeWhile1 (/= '"')
        _ <- AP.char '"'
        return s

    -- An unquoted value, terminated by a comma or a double quote.
    bare :: AP.Parser BS.ByteString
    bare = AP.takeWhile1 (AP.notInClass [',', '"'])

-- Pickler for the <failure xmlns='urn:ietf:params:xml:ns:xmpp-sasl'/>
-- element. Its children are an optional <text/> element (language tag plus
-- text content) and one element in the SASL namespace whose local name is
-- the error condition (see 'xpSaslError').
xpFailure :: PU [Node] SaslFailure
xpFailure = xpWrap
    (\(txt, (failure, _, _)) -> SaslFailure failure txt)
    (\(SaslFailure failure txt) -> (txt, (failure, (), ())))
    (xpElemNodes
        "{urn:ietf:params:xml:ns:xmpp-sasl}failure"
        (xp2Tuple
             (xpOption $ xpElem
                  "{urn:ietf:params:xml:ns:xmpp-sasl}text"
                  xpLangTag
                  (xpContent xpId))
        (xpElemByNamespace
             "urn:ietf:params:xml:ns:xmpp-sasl"
             xpSaslError
             xpUnit
             xpUnit)))

-- | Pickler mapping the textual name of a SASL error condition to
-- 'SaslError' and back. Any unrecognized name is read as
-- 'SaslNotAuthorized'.
xpSaslError :: PU Text.Text SaslError
xpSaslError = ("xpSaslError", "") <?>
        xpIso saslErrorFromText saslErrorToText
  where
    saslErrorToText :: SaslError -> Text.Text
    saslErrorToText SaslAborted              = "aborted"
    saslErrorToText SaslAccountDisabled      = "account-disabled"
    saslErrorToText SaslCredentialsExpired   = "credentials-expired"
    saslErrorToText SaslEncryptionRequired   = "encryption-required"
    saslErrorToText SaslIncorrectEncoding    = "incorrect-encoding"
    saslErrorToText SaslInvalidAuthzid       = "invalid-authzid"
    saslErrorToText SaslInvalidMechanism     = "invalid-mechanism"
    saslErrorToText SaslMalformedRequest     = "malformed-request"
    saslErrorToText SaslMechanismTooWeak     = "mechanism-too-weak"
    saslErrorToText SaslNotAuthorized        = "not-authorized"
    saslErrorToText SaslTemporaryAuthFailure = "temporary-auth-failure"

    saslErrorFromText :: Text.Text -> SaslError
    saslErrorFromText "aborted" = SaslAborted
    saslErrorFromText "account-disabled" = SaslAccountDisabled
    saslErrorFromText "credentials-expired" = SaslCredentialsExpired
    saslErrorFromText "encryption-required" = SaslEncryptionRequired
    saslErrorFromText "incorrect-encoding" = SaslIncorrectEncoding
    saslErrorFromText "invalid-authzid" = SaslInvalidAuthzid
    saslErrorFromText "invalid-mechanism" = SaslInvalidMechanism
    saslErrorFromText "malformed-request" = SaslMalformedRequest
    saslErrorFromText "mechanism-too-weak" = SaslMechanismTooWeak
    saslErrorFromText "not-authorized" = SaslNotAuthorized
    saslErrorFromText "temporary-auth-failure" = SaslTemporaryAuthFailure
    -- Fallback for condition names not listed above.
    saslErrorFromText _ = SaslNotAuthorized

-- Pickler for the <challenge xmlns='urn:ietf:params:xml:ns:xmpp-sasl'/>
-- element; its text content is optional.
xpChallenge :: PU [Node] (Maybe Text.Text)
xpChallenge = xpElemNodes "{urn:ietf:params:xml:ns:xmpp-sasl}challenge"
                      (xpOption $ xpContent xpId)

-- | Pickler for 'SaslElement': either a success or a challenge element.
xpSaslElement :: PU [Node] SaslElement
xpSaslElement = xpAlt saslSel
                [ xpWrap SaslSuccess   (\(SaslSuccess x)   -> x) xpSuccess
                , xpWrap SaslChallenge (\(SaslChallenge c) -> c) xpChallenge
                ]
  where
    -- Index into the list of alternatives above; the order must match.
    saslSel :: SaslElement -> Int
    saslSel (SaslSuccess   _) = 0
    saslSel (SaslChallenge _) = 1

-- | Add double quotation marks around a byte string.
quote :: BS.ByteString -> BS.ByteString
quote x = BS.concat ["\"", x, "\""]

-- | Send the SASL @auth@ element for the given mechanism.
--
-- The optional payload is base64-encoded before it is sent; a payload that
-- encodes to the empty string is sent as a single @=@ instead. A failure to
-- push the element is rethrown as 'AuthStreamFailure'.
saslInit :: Text.Text -> Maybe BS.ByteString -> ExceptT AuthFailure (StateT StreamState IO) ()
saslInit mechanism payload = do
    r <- lift . pushElement . saslInitE mechanism $
        Text.decodeUtf8 . encodeEmpty . B64.encode <$> payload
    case r of
        Right () -> return ()
        Left e  -> throwError $ AuthStreamFailure e
  where
    -- §6.4.2
    -- NOTE(review): presumably RFC 6120, which requires a zero-length
    -- initial response to be transmitted as "=" -- confirm.
    encodeEmpty :: BS.ByteString -> BS.ByteString
    encodeEmpty "" = "="
    encodeEmpty x = x

-- | Pull the next element and unpickle it as a SASL element.
--
-- A stream-level failure is thrown as 'AuthStreamFailure'; a SASL
-- @failure@ element is thrown as 'AuthSaslFailure'. Otherwise the success
-- or challenge element is returned.
pullSaslElement :: ExceptT AuthFailure (StateT StreamState IO) SaslElement
pullSaslElement = do
    mbse <- lift $ pullUnpickle (xpEither xpFailure xpSaslElement)
    case mbse of
        Left e -> throwError $ AuthStreamFailure e
        Right (Left e) -> throwError $ AuthSaslFailure e
        Right (Right r) -> return r

-- | Pull the next element, checking that it is a challenge.
--
-- Returns 'Nothing' for a challenge without content, or the base64-decoded
-- content otherwise. Throws 'AuthOtherFailure' when the element is not a
-- challenge or its content is not valid base64.
pullChallenge :: ExceptT AuthFailure (StateT StreamState IO) (Maybe BS.ByteString)
pullChallenge = do
  e <- pullSaslElement
  case e of
      SaslChallenge Nothing -> return Nothing
      SaslChallenge (Just scb64)
          | Right sc <- B64.decode . Text.encodeUtf8 $ scb64
             -> return $ Just sc
      -- Also reached when the pattern guard above fails to decode.
      _ -> throwError AuthOtherFailure -- TODO: Log

-- | Extract the value from a 'Just', failing with 'AuthOtherFailure' on
-- 'Nothing'.
saslFromJust :: Maybe a -> ExceptT AuthFailure (StateT StreamState IO) a
saslFromJust Nothing = throwError $ AuthOtherFailure -- TODO: Log
saslFromJust (Just d) = return d

-- | Pull the next element and check that it is success, returning its
-- (still encoded) optional text content. Any other element results in
-- 'AuthOtherFailure'.
pullSuccess :: ExceptT AuthFailure (StateT StreamState IO) (Maybe Text.Text)
pullSuccess = do
    e <- pullSaslElement
    case e of
        SaslSuccess x -> return x
        _ -> throwError $ AuthOtherFailure -- TODO: Log

-- | Pull the next element. When it's success, return its base64-decoded
-- payload. If it's a challenge, send an empty response, pull success, and
-- return the base64-decoded payload of the challenge (the payload of the
-- subsequent success element is discarded).
--
-- Throws 'AuthOtherFailure' when the payload is not valid base64.
pullFinalMessage :: ExceptT AuthFailure (StateT StreamState IO) (Maybe BS.ByteString)
pullFinalMessage = do
    challenge2 <- pullSaslElement
    case challenge2 of
        SaslSuccess   x -> decode x
        SaslChallenge x -> do
            _b <- respond Nothing
            _s <- pullSuccess
            decode x
  where
    -- Base64-decode an optional payload; 'Nothing' passes through unchanged.
    decode Nothing  = return Nothing
    decode (Just d) = case B64.decode $ Text.encodeUtf8 d of
        Left _e -> throwError $ AuthOtherFailure -- TODO: Log
        Right x -> return $ Just x

-- | Extract p=q pairs from a challenge, failing with 'AuthOtherFailure'
-- when the challenge text cannot be parsed by 'pairs'.
toPairs :: BS.ByteString -> ExceptT AuthFailure (StateT StreamState IO) Pairs
toPairs ctext = case pairs ctext of
    Left _e -> throwError AuthOtherFailure -- TODO: Log
    Right r -> return r

-- | Send a SASL response element. The content, if any, will be
-- base64-encoded. A failure to push the element is rethrown as
-- 'AuthStreamFailure'.
respond :: Maybe BS.ByteString -> ExceptT AuthFailure (StateT StreamState IO) ()
respond m = do
    r <- lift . pushElement . saslResponseE . fmap (Text.decodeUtf8 . B64.encode) $ m
    case r of
        Left e -> throwError $ AuthStreamFailure e
        Right () -> return ()

-- | Run the appropriate stringprep profiles on the credentials:
-- 'normalizeUsername' on the authentication identity and (when present) the
-- authorization identity, 'normalizePassword' on the password.
--
-- Fails with 'AuthIllegalCredentials' when any of them cannot be normalized.
prepCredentials :: Text.Text -> Maybe Text.Text -> Text.Text
                -> ExceptT AuthFailure (StateT StreamState IO) (Text.Text, Maybe Text.Text, Text.Text)
prepCredentials authcid authzid password = case credentials of
    Nothing -> throwError $ AuthIllegalCredentials
    Just creds -> return creds
  where
    credentials :: Maybe (Text.Text, Maybe Text.Text, Text.Text)
    credentials = do
        ac <- normalizeUsername authcid
        -- An absent authzid is fine; a present one must normalize.
        az <- case authzid of
          Nothing -> Just Nothing
          Just az' -> Just <$> normalizeUsername az'
        pw <- normalizePassword password
        return (ac, az, pw)

-- | Bit-wise xor of byte strings. The result is as long as the shorter of
-- the two inputs; surplus bytes of the longer one are ignored.
xorBS :: BS.ByteString -> BS.ByteString -> BS.ByteString
xorBS x y = BS.pack $ BS.zipWith xor x y

-- | Join byte strings with "," as the separator.
merge :: [BS.ByteString] -> BS.ByteString
merge = BS.intercalate ","

-- | Infix concatenation of byte strings.
(+++) :: BS.ByteString -> BS.ByteString -> BS.ByteString
(+++) = BS.append