{-# Language BlockArguments #-}
{-# Language ImportQualifiedPost #-}
{-# Language LambdaCase #-}
{-# Language OverloadedStrings #-}
{-# Language RecordWildCards #-}
{-# Language ViewPatterns #-}
module Client.Authentication.Scram (
  -- * Transaction state types
  Phase1,
  Phase2,
  -- * Transaction step functions
  initiateScram,
  addServerFirst,
  addServerFinal,
  -- * Digests
  ScramDigest(..),
  mechanismName,
  ) where

import Control.Monad (guard)
import Data.Bits (xor)
import Data.ByteString (ByteString)
import Data.ByteString qualified as B
import Data.ByteString.Base64 qualified as B64
import Data.ByteString.Char8 qualified as B8
import Data.List (foldl1')
import Data.Text (Text)
import Irc.Commands (AuthenticatePayload (AuthenticatePayload))
import OpenSSL.EVP.Digest (Digest, digestBS, hmacBS, getDigestByName)
import System.IO.Unsafe (unsafePerformIO)

-- | Hash function underlying a SCRAM mechanism.
data ScramDigest
  = ScramDigestSha1     -- ^ SCRAM-SHA-1
  | ScramDigestSha2_256 -- ^ SCRAM-SHA-256
  | ScramDigestSha2_512 -- ^ SCRAM-SHA-512
  deriving (Show)

-- | SASL mechanism name advertised for the given digest.
mechanismName :: ScramDigest -> Text
mechanismName ScramDigestSha1     = "SCRAM-SHA-1"
mechanismName ScramDigestSha2_256 = "SCRAM-SHA-256"
mechanismName ScramDigestSha2_512 = "SCRAM-SHA-512"

-- | State of a SCRAM exchange after the client-first-message has been
-- produced, waiting for the server-first-message.
data Phase1 = Phase1
  { phase1Digest          :: ScramDigest -- ^ hash function used for the whole exchange
  , phase1Password        :: ByteString  -- ^ password
  , phase1CbindInput      :: ByteString  -- ^ cbind-input: base64-encoded GS2 header, sent as @c=@
  , phase1Nonce           :: ByteString  -- ^ c-nonce: the client's nonce
  , phase1ClientFirstBare :: ByteString  -- ^ client-first-message-bare, reused in AuthMessage
  }

-- | Construct the client-first-message and the state needed by
-- 'addServerFirst'.
--
-- The authorization ID is optional; pass an empty string to omit it. When it
-- is present, it is sent in the GS2 header as @a=<saslname>@, as required by
-- the @authzid@ production of RFC 5802 §7.
initiateScram ::
  ScramDigest ->
  ByteString {- ^ authentication ID                -} ->
  ByteString {- ^ authorization ID (empty for none) -} ->
  ByteString {- ^ password                         -} ->
  ByteString {- ^ nonce                            -} ->
  (AuthenticatePayload, Phase1)
initiateScram digest user authzid pass nonce =
  ( AuthenticatePayload clientFirstMessage
  , Phase1
      { phase1Digest          = digest
      , phase1Password        = pass
      , phase1CbindInput      = B64.encode gs2Header
      , phase1Nonce           = nonce
      , phase1ClientFirstBare = clientFirstMessageBare
      }
  )
  where
    clientFirstMessage :: ByteString
    clientFirstMessage = gs2Header <> clientFirstMessageBare

    -- gs2-header = gs2-cbind-flag "," [ authzid ] ","
    -- authzid    = "a=" saslname
    -- The "n" flag means no channel binding is used. Previously a non-empty
    -- authzid was emitted without its mandatory "a=" prefix.
    gs2Header :: ByteString
    gs2Header
      | B.null authzid = "n,,"
      | otherwise      = "n,a=" <> encodeUsername authzid <> ","

    clientFirstMessageBare :: ByteString
    clientFirstMessageBare = "n=" <> encodeUsername user <> ",r=" <> nonce

-- | State of a SCRAM exchange after the client-final-message has been
-- produced, waiting for the server-final-message.
newtype Phase2 = Phase2
  { phase2ServerSignature :: ByteString
    -- ^ expected ServerSignature, base64-encoded
  }

-- | Add the server-first-message to the current SCRAM transaction and
-- compute the client-final-message, together with the state needed by
-- 'addServerFinal'.
--
-- Returns 'Nothing' in any of these cases:
--
-- * the message does not start with the @r=@, @s=@ and @i=@ attributes, in
--   that order;
-- * the salt is not valid base64;
-- * the iteration count is not a plain integer, or is less than 1;
-- * the server nonce does not strictly extend the client nonce.
addServerFirst ::
  Phase1     {- ^ output of 'initiateScram' -} ->
  ByteString {- ^ server-first-message      -} ->
  Maybe (AuthenticatePayload, Phase2)
addServerFirst Phase1{..} serverFirstMessage =
  do
    -- Parse server-first-message: r=<nonce>,s=<base64 salt>,i=<count>[,...]
    -- A failed pattern match makes the whole computation 'Nothing'.
    ("r", nonce)
      : ("s", B64.decode -> Right salt)
      : ("i", B8.readInt -> Just (iterations, ""))
      : _extensions
      <- Just (parseMessage serverFirstMessage)

    -- The count is server-controlled and ends up as the length of the list
    -- folded with foldl1' in 'hi'. Reject values that would leave that fold
    -- empty, so a hostile server cannot crash the client.
    guard (iterations > 0)

    -- The server nonce must start with ours and add at least one character.
    guard (phase1Nonce `B.isPrefixOf` nonce && phase1Nonce /= nonce)

    let clientFinalWithoutProof :: ByteString
        clientFinalWithoutProof = "c=" <> phase1CbindInput <> ",r=" <> nonce

        authMessage :: ByteString
        authMessage =
          phase1ClientFirstBare <> "," <>
          serverFirstMessage    <> "," <>
          clientFinalWithoutProof

        (clientProof, serverSignature) =
          crypto phase1Digest phase1Password salt iterations authMessage

        clientFinalMessage :: ByteString
        clientFinalMessage =
          clientFinalWithoutProof <> ",p=" <> B64.encode clientProof

    Just ( AuthenticatePayload clientFinalMessage
         , Phase2 { phase2ServerSignature = B64.encode serverSignature }
         )

-- | Add the server-final-message to the transaction and report whether the
-- whole transaction is valid.
--
-- The transaction is valid only when the first attribute of the message is
-- @v=@ and its value equals the base64 ServerSignature computed earlier.
-- Any other first attribute, or an empty message, counts as failure.
addServerFinal ::
  Phase2     {- ^ output of 'addServerFirst' -} ->
  ByteString {- ^ server-final-message      -} ->
  Bool       {- ^ transaction succeeded?    -}
addServerFinal phase2 serverFinalMessage =
  case parseMessage serverFinalMessage of
    (key, sig) : _ -> key == "v" && sig == phase2ServerSignature phase2
    []             -> False

-- | The 32-bit integer 1 in big-endian byte order. 'hi' appends it to the
-- salt to form the input of the first HMAC round.
int1 :: ByteString
int1 = B.snoc (B.replicate 3 0) 1

-- | Byte-wise XOR of two byte strings. The result is as long as the shorter
-- input; extra bytes of the longer input are ignored.
xorBS :: ByteString -> ByteString -> ByteString
xorBS lhs rhs = B.pack [ a `xor` b | (a, b) <- B.zip lhs rhs ]

-- | Iterated, password-based key-derivation function (@Hi@ in RFC 5802).
--
-- The chain starts from @HMAC(secret, salt <> 'int1')@, and each later link
-- is the HMAC of the previous one. The result is the XOR of the first
-- @iterations@ links.
--
-- Precondition: @iterations >= 1@. For smaller counts no links are taken and
-- 'foldl1'' throws on the empty list.
hi ::
  Digest     {- ^ underlying cryptographic hash function -} ->
  ByteString {- ^ secret -} ->
  ByteString {- ^ salt -} ->
  Int        {- ^ iterations -} ->
  ByteString {- ^ salted, iterated hash of secret -}
hi digest secret salt iterations =
  foldl1' xorBS (take iterations chain)
  where
    prf :: ByteString -> ByteString
    prf = hmacBS digest secret

    chain :: [ByteString]
    chain = iterate prf (prf (salt <> int1))

-- | Split a SCRAM message into its @key=value@ attributes, keeping their order.
--
-- Attributes are separated by commas. Each attribute is split at its first
-- @=@, so a value may itself contain @=@. An attribute without @=@ yields an
-- empty value.
parseMessage :: ByteString -> [(ByteString, ByteString)]
parseMessage = map splitAttribute . B8.split ','
  where
    splitAttribute :: ByteString -> (ByteString, ByteString)
    splitAttribute attribute =
      let (key, rest) = B8.break (== '=') attribute
      in (key, B.drop 1 rest)

-- | Transform all the SCRAM parameters into the @ClientProof@ and the
-- @ServerSignature@ (RFC 5802 §3). Both results are raw HMAC output bytes;
-- any base64 encoding is left to the caller.
crypto ::
  ScramDigest {- ^ digest       -} ->
  ByteString  {- ^ password     -} ->
  ByteString  {- ^ salt         -} ->
  Int         {- ^ iterations   -} ->
  ByteString  {- ^ auth message -} ->
  (ByteString, ByteString) {- ^ client-proof, server-signature -}
crypto digest password salt iterations authMessage =
  (clientProof, serverSignature)
  where
    saltedPassword, clientKey, storedKey, clientSignature,
      clientProof, serverKey, serverSignature :: ByteString
    saltedPassword  = hi       d password salt iterations
    clientKey       = hmacBS   d saltedPassword "Client Key"
    storedKey       = digestBS d clientKey
    clientSignature = hmacBS   d storedKey authMessage
    clientProof     = xorBS clientKey clientSignature
    serverKey       = hmacBS   d saltedPassword "Server Key"
    serverSignature = hmacBS   d serverKey authMessage

    -- OpenSSL name of the hash function selected by the mechanism.
    digestName :: String
    digestName =
      case digest of
        ScramDigestSha1     -> "SHA1"
        ScramDigestSha2_256 -> "SHA256"
        ScramDigestSha2_512 -> "SHA512"

    -- NOTE(review): unsafePerformIO assumes getDigestByName is effectively a
    -- pure name lookup with no observable side effects; confirm against the
    -- HsOpenSSL/OpenSSL versions in use.
    -- A missing digest used to fail with an anonymous irrefutable-pattern
    -- error; it now reports which digest OpenSSL could not provide.
    d :: Digest
    d =
      case unsafePerformIO (getDigestByName digestName) of
        Just found -> found
        Nothing ->
          error ("Client.Authentication.Scram.crypto: OpenSSL digest unavailable: "
                 ++ digestName)

-- | Escape a username so it fits in the comma/equals delimited SCRAM message
-- format: @,@ becomes @=2C@ and @=@ becomes @=3D@. All other bytes are
-- copied unchanged.
encodeUsername :: ByteString -> ByteString
encodeUsername = B8.concatMap escape
  where
    escape :: Char -> ByteString
    escape ',' = "=2C"
    escape '=' = "=3D"
    escape c   = B8.singleton c