{-# Language BlockArguments #-}
{-# Language ImportQualifiedPost #-}
{-# Language LambdaCase #-}
{-# Language OverloadedStrings #-}
{-# Language RecordWildCards #-}
{-# Language ViewPatterns #-}
module Client.Authentication.Scram (
  -- * Transaction state types
  Phase1,
  Phase2,
  -- * Transaction step functions
  initiateScram,
  addServerFirst,
  addServerFinal,
  -- * Digests
  ScramDigest(..),
  mechanismName,
  ) where

import Control.Monad (guard)
import Data.Bits (xor)
import Data.ByteString (ByteString)
import Data.ByteString qualified as B
import Data.ByteString.Base64 qualified as B64
import Data.ByteString.Char8 qualified as B8
import Data.List (foldl1')
import Data.Text (Text)
import Irc.Commands (AuthenticatePayload (AuthenticatePayload))
import OpenSSL.EVP.Digest (Digest, digestBS, hmacBS, getDigestByName)
import System.IO.Unsafe (unsafePerformIO)

-- | Hash function underlying a SCRAM mechanism.
data ScramDigest
  = ScramDigestSha1     -- ^ SHA-1   (mechanism @SCRAM-SHA-1@)
  | ScramDigestSha2_256 -- ^ SHA-256 (mechanism @SCRAM-SHA-256@)
  | ScramDigestSha2_512 -- ^ SHA-512 (mechanism @SCRAM-SHA-512@)
  deriving (Show)

-- | SASL mechanism name corresponding to a 'ScramDigest'.
mechanismName :: ScramDigest -> Text
mechanismName = \case
  ScramDigestSha1     -> "SCRAM-SHA-1"
  ScramDigestSha2_256 -> "SCRAM-SHA-256"
  ScramDigestSha2_512 -> "SCRAM-SHA-512"

-- | SCRAM transaction state while waiting for the server-first-message.
--
-- Built by 'initiateScram' and consumed by 'addServerFirst'.
data Phase1 = Phase1
  { phase1Digest          :: ScramDigest -- ^ underlying cryptographic hash function
  , phase1Password        :: ByteString  -- ^ password
  , phase1CbindInput      :: ByteString  -- ^ cbind-input (base64 of the gs2-header)
  , phase1Nonce           :: ByteString  -- ^ c-nonce
  , phase1ClientFirstBare :: ByteString  -- ^ client-first-message-bare
  }

-- | Construct the client-first-message and the extra parameters
-- needed later by 'addServerFirst'.
--
-- The gs2-header is @n,,@ when the authorization ID is empty and
-- @n,a=\<authzid\>,@ otherwise, as required by RFC 5802
-- (@authzid = "a=" saslname@).
initiateScram ::
  ScramDigest ->
  ByteString {- ^ authentication ID                  -} ->
  ByteString {- ^ authorization ID (empty for none)  -} ->
  ByteString {- ^ password                           -} ->
  ByteString {- ^ nonce                              -} ->
  (AuthenticatePayload, Phase1)
initiateScram digest user authzid pass nonce =
  (AuthenticatePayload clientFirstMessage, Phase1
    { phase1Digest          = digest
    , phase1Password        = pass
    , phase1CbindInput      = B64.encode gs2Header
    , phase1Nonce           = nonce
    , phase1ClientFirstBare = clientFirstMessageBare
    })
  where
    clientFirstMessage = gs2Header <> clientFirstMessageBare

    -- gs2-header = gs2-cbind-flag "," [ authzid ] ","
    -- with gs2-cbind-flag fixed to "n" (no channel binding).
    gs2Header = "n," <> authzidField <> ","

    -- The optional authzid attribute must carry the "a=" prefix; previously
    -- the bare encoded name was emitted, which is malformed when non-empty.
    authzidField
      | B.null authzid = ""
      | otherwise      = "a=" <> encodeUsername authzid

    clientFirstMessageBare = "n=" <> encodeUsername user <> ",r=" <> nonce

-- | SCRAM transaction state while waiting for the server-final-message.
--
-- Built by 'addServerFirst' and consumed by 'addServerFinal'.
newtype Phase2 = Phase2
  { phase2ServerSignature :: ByteString -- ^ expected ServerSignature, base64 encoded
  }

-- | Add the server-first-message to the current SCRAM transaction and
-- compute the client-final-message together with the next state for
-- 'addServerFinal'.
--
-- Returns 'Nothing' when:
--
-- * the message does not start with the @r@, @s@, @i@ attributes in that
--   order,
-- * the salt is not valid base64 or the iteration count is not an integer,
-- * the server nonce does not strictly extend the client nonce, or
-- * the iteration count is not positive.
addServerFirst ::
  Phase1     {- ^ output of 'initiateScram' -} ->
  ByteString {- ^ server first message      -} ->
  Maybe (AuthenticatePayload, Phase2)
addServerFirst Phase1{..} serverFirstMessage =

  do -- Parse server-first-message; any pattern mismatch yields 'Nothing'
     -- through 'MonadFail' for 'Maybe'.
     ("r", nonce) :
       ("s", B64.decode -> Right salt) :
       ("i", B8.readInt -> Just (iterations, "")) :
       _extensions
       <- Just (parseMessage serverFirstMessage)

     -- validate nonce given by server includes ours and isn't empty
     guard (B.isPrefixOf phase1Nonce nonce && phase1Nonce /= nonce)

     -- 'hi' folds the first @iterations@ HMAC blocks with 'foldl1'', which
     -- fails on an empty list. 'B8.readInt' accepts values such as "0", so
     -- reject non-positive counts here instead of crashing on server input.
     guard (iterations > 0)

     let clientFinalWithoutProof = "c=" <> phase1CbindInput <> ",r=" <> nonce

         authMessage =
           phase1ClientFirstBare <> "," <>
           serverFirstMessage    <> "," <>
           clientFinalWithoutProof

         (clientProof, serverSignature) =
           crypto phase1Digest phase1Password salt iterations authMessage

         proof              = "p=" <> B64.encode clientProof
         clientFinalMessage = clientFinalWithoutProof <> "," <> proof

         phase2 = Phase2 { phase2ServerSignature = B64.encode serverSignature }

     Just (AuthenticatePayload clientFinalMessage, phase2)

-- | Add the server-final-message to the transaction and report whether
-- the whole transaction is valid.
--
-- Succeeds only when the first attribute of the message is @v@ and its
-- value equals the base64 ServerSignature remembered in 'Phase2'; any
-- other message (including an @e@ error attribute) yields 'False'.
addServerFinal ::
  Phase2     {- ^ output of 'addServerFirst' -} ->
  ByteString {- ^ server-final-message       -} ->
  Bool       {- ^ transaction succeeded?     -}
addServerFinal Phase2{phase2ServerSignature = expected} serverFinalMessage =
  case parseMessage serverFinalMessage of
    (key, value) : _ | key == "v" -> value == expected
    _                             -> False

-- | The 32-bit integer 1 in big-endian byte order: three zero bytes
-- followed by a one byte.
int1 :: ByteString
int1 = B.snoc (B.replicate 3 0) 1

-- | Byte-wise exclusive-or of two byte strings. The result is as long as
-- the shorter argument; excess bytes of the longer one are ignored.
xorBS :: ByteString -> ByteString -> ByteString
xorBS lhs rhs = B.pack [ a `xor` b | (a, b) <- B.zip lhs rhs ]

-- | Iterated, password-based key-derivation function.
--
-- Computes @U1 xor U2 xor ... xor Un@ where
-- @U1 = HMAC(secret, salt <> INT(1))@ and @Ui = HMAC(secret, U(i-1))@.
--
-- Precondition: the iteration count must be at least 1, because the
-- blocks are combined with 'foldl1'', which fails on an empty list.
hi ::
  Digest     {- ^ underlying cryptographic hash function -} ->
  ByteString {- ^ secret -} ->
  ByteString {- ^ salt -} ->
  Int        {- ^ iterations (must be >= 1) -} ->
  ByteString {- ^ salted, iterated hash of secret -}
hi digest secret salt count = foldl1' xorBS (take count blocks)
  where
    -- HMAC keyed with the secret, applied to successive inputs.
    prf = hmacBS digest secret
    blocks = iterate prf (prf (salt <> int1))

-- | Break up a SCRAM message into its key-value association list.
--
-- Attributes are separated by @,@ and each one is split at its first @=@;
-- an attribute without @=@ yields an empty value.
parseMessage :: ByteString -> [(ByteString, ByteString)]
parseMessage = map splitAttribute . B8.split ','
  where
    splitAttribute attribute =
      let (key, rest) = B8.break (== '=') attribute
      in  (key, B.drop 1 rest)

-- | Transform all the SCRAM parameters into a @ClientProof@ and a
-- @ServerSignature@. Both results are raw digest bytes, not base64.
crypto ::
  ScramDigest {- ^ digest                               -} ->
  ByteString  {- ^ password                             -} ->
  ByteString  {- ^ salt                                 -} ->
  Int         {- ^ iterations (must be >= 1, see 'hi')  -} ->
  ByteString  {- ^ auth message                         -} ->
  (ByteString, ByteString) {- ^ client-proof, server-signature -}
crypto digest password salt iterations authMessage =
  (clientProof, serverSignature)
  where
    saltedPassword  = hi       d password salt iterations
    clientKey       = hmacBS   d saltedPassword "Client Key"
    storedKey       = digestBS d clientKey
    clientSignature = hmacBS   d storedKey authMessage
    clientProof     = xorBS clientKey clientSignature
    serverKey       = hmacBS   d saltedPassword "Server Key"
    serverSignature = hmacBS   d serverKey authMessage

    -- OpenSSL name of the selected hash function.
    digestName :: String
    digestName =
      case digest of
        ScramDigestSha1     -> "SHA1"
        ScramDigestSha2_256 -> "SHA256"
        ScramDigestSha2_512 -> "SHA512"

    -- NOTE(review): unsafePerformIO relies on the lookup of a digest by a
    -- fixed name being a repeatable query without observable side effects;
    -- confirm against the HsOpenSSL/OpenSSL documentation.
    -- A missing digest now fails with a descriptive message instead of an
    -- opaque irrefutable-pattern error.
    d :: Digest
    d =
      case unsafePerformIO (getDigestByName digestName) of
        Just found -> found
        Nothing    ->
          error ("Client.Authentication.Scram.crypto: digest not available: "
                 ++ digestName)

-- | Encode user names so they fit in the comma/equals-delimited SCRAM
-- message format: @,@ becomes @=2C@, @=@ becomes @=3D@, and every other
-- character is copied unchanged.
encodeUsername :: ByteString -> ByteString
encodeUsername = B8.concatMap escape
  where
    escape c
      | c == ','  = "=2C"
      | c == '='  = "=3D"
      | otherwise = B8.singleton c