{-# Language BlockArguments #-}
{-# Language ImportQualifiedPost #-}
{-# Language LambdaCase #-}
{-# Language OverloadedStrings #-}
{-# Language RecordWildCards #-}
{-# Language ViewPatterns #-}
module Client.Authentication.Scram (
  -- * Transaction state types
  Phase1,
  Phase2,
  -- * Transaction step functions
  initiateScram,
  addServerFirst,
  addServerFinal,
  -- * Digests
  ScramDigest(..),
  mechanismName,
  ) where

import Control.Monad (guard)
import Data.Bits (xor)
import Data.ByteString (ByteString)
import Data.ByteString qualified as B
import Data.ByteString.Base64 qualified as B64
import Data.ByteString.Char8 qualified as B8
import Data.List (foldl1')
import Data.Text (Text)
import OpenSSL.EVP.Digest ( Digest, digestBS, hmacBS, getDigestByName)
import Irc.Commands (AuthenticatePayload (AuthenticatePayload))
import System.IO.Unsafe (unsafePerformIO)

-- | Hash functions that a SCRAM exchange in this module can be built on.
data ScramDigest
  = ScramDigestSha1     -- ^ SHA-1
  | ScramDigestSha2_256 -- ^ SHA-2, 256-bit
  | ScramDigestSha2_512 -- ^ SHA-2, 512-bit
  deriving (Show)

-- | Name of the SCRAM mechanism that corresponds to the given digest.
mechanismName :: ScramDigest -> Text
mechanismName ScramDigestSha1     = "SCRAM-SHA-1"
mechanismName ScramDigestSha2_256 = "SCRAM-SHA-256"
mechanismName ScramDigestSha2_512 = "SCRAM-SHA-512"

-- | Client-side SCRAM state held while waiting for the
-- server-first-message.
data Phase1 = Phase1
  { phase1Digest          :: ScramDigest
    -- ^ underlying cryptographic hash function
  , phase1Password        :: ByteString
    -- ^ password
  , phase1CbindInput      :: ByteString
    -- ^ cbind-input, already base64 encoded
  , phase1Nonce           :: ByteString
    -- ^ c-nonce
  , phase1ClientFirstBare :: ByteString
    -- ^ client-first-bare
  }

-- | Construct client-first-message and extra parameters
-- needed for 'addServerFirst'.
--
-- The message is the GS2 header followed by client-first-message-bare.
-- The header always starts with @n@ (no channel binding). A non-empty
-- authorization ID is emitted as the header's @a=@ attribute; an empty
-- one leaves that slot empty, giving the header @n,,@.
initiateScram ::
  ScramDigest ->
  ByteString {- ^ authentication ID -} ->
  ByteString {- ^ authorization ID, empty for none -} ->
  ByteString {- ^ password          -} ->
  ByteString {- ^ nonce             -} ->
  (AuthenticatePayload, Phase1)
initiateScram digest user authzid pass nonce =
  (AuthenticatePayload clientFirstMessage, phase1)
  where
    phase1 :: Phase1
    phase1 = Phase1
      { phase1Digest          = digest
      , phase1Password        = pass
      , phase1CbindInput      = B64.encode gs2Header
      , phase1Nonce           = nonce
      , phase1ClientFirstBare = clientFirstMessageBare
      }

    -- RFC 5802 defines the authzid production as "a=" followed by the
    -- escaped name; the attribute is omitted entirely when there is no
    -- authorization ID.
    authzidField :: ByteString
    authzidField
      | B.null authzid = ""
      | otherwise      = "a=" <> encodeUsername authzid

    gs2Header :: ByteString
    gs2Header = "n," <> authzidField <> ","

    clientFirstMessageBare :: ByteString
    clientFirstMessageBare = "n=" <> encodeUsername user <> ",r=" <> nonce

    clientFirstMessage :: ByteString
    clientFirstMessage = gs2Header <> clientFirstMessageBare

-- | Client-side SCRAM state held while waiting for the
-- server-final-message.
newtype Phase2 = Phase2
  { phase2ServerSignature :: ByteString
    -- ^ base64 encoded expected value
  }

-- | Add server-first-message to current SCRAM transaction,
-- compute client-final-message and next state for 'addServerFinal'.
--
-- Returns 'Nothing' when the message does not start with the @r@, @s@
-- and @i@ attributes in that order, when the salt is not valid base64,
-- when the iteration count is not a positive integer with no trailing
-- input, or when the server nonce does not strictly extend the client
-- nonce.
addServerFirst ::
  Phase1     {- ^ output of 'initiateScram' -} ->
  ByteString {- ^ server first message -} ->
  Maybe (AuthenticatePayload, Phase2)
addServerFirst Phase1{..} serverFirstMessage =

  do -- Parse server-first-message; a failed pattern match yields Nothing
     ("r", nonce) :
       ("s", B64.decode -> Right salt) :
       ("i", B8.readInt -> Just (iterations, "")) :
       _extensions
       <- Just (parseMessage serverFirstMessage)

     -- validate nonce given by server includes ours and isn't empty
     guard (B.isPrefixOf phase1Nonce nonce && phase1Nonce /= nonce)

     -- The iteration count comes from the server. A count below one would
     -- leave the key derivation with nothing to fold over and make it
     -- throw from pure code, so reject it here.
     guard (iterations > 0)

     let clientFinalWithoutProof = "c=" <> phase1CbindInput <> ",r=" <> nonce

     let authMessage =
           phase1ClientFirstBare <> "," <>
           serverFirstMessage <> "," <>
           clientFinalWithoutProof

     let (clientProof, serverSignature) =
           crypto phase1Digest phase1Password salt iterations authMessage

     let proof = "p=" <> B64.encode clientProof
     let clientFinalMessage = clientFinalWithoutProof <> "," <> proof

     -- Stored base64 encoded so it can be compared directly with the
     -- value carried by server-final-message.
     let phase2 = Phase2 { phase2ServerSignature = B64.encode serverSignature }
     Just (AuthenticatePayload clientFinalMessage, phase2)

-- | Add server-final-message to the transaction and compute the validity
-- of the whole transaction.
--
-- The transaction succeeds only when the first attribute of the message
-- is @v@ and its value equals the expected server signature; any other
-- message shape is a failure.
addServerFinal ::
  Phase2     {- ^ output of 'addServerFirst' -} ->
  ByteString {- ^ server-final-message   -} ->
  Bool       {- ^ transaction succeeded? -}
addServerFinal (Phase2 expectedSignature) serverFinalMessage
  | ("v", verifier) : _extensions <- parseMessage serverFinalMessage
  = verifier == expectedSignature
  | otherwise
  = False

-- | The number 1 as a 32-bit big-endian integer: three zero bytes
-- followed by a one byte.
int1 :: ByteString
int1 = B.replicate 3 0 `B.snoc` 1

-- | Bytewise exclusive-or of two byte strings. The result is as long as
-- the shorter of the two arguments.
xorBS :: ByteString -> ByteString -> ByteString
xorBS lhs rhs = B.pack (zipWith xor (B.unpack lhs) (B.unpack rhs))

-- | Iterated, password-based, key-derivation function.
--
-- The first round is the HMAC, keyed by the secret, of the salt with
-- 'int1' appended; every later round is the HMAC of the previous round's
-- output. The result is the exclusive-or of the first @n@ rounds.
--
-- The iteration count must be at least one: with a smaller count there
-- are no rounds to combine and the fold fails.
hi ::
  Digest     {- ^ underlying cryptographic hash function -} ->
  ByteString {- ^ secret -} ->
  ByteString {- ^ salt -} ->
  Int        {- ^ iterations -} ->
  ByteString {- ^ salted, iterated hash of secret -}
hi digest str salt n = foldl1' xorBS (take n rounds)
  where
    -- HMAC keyed by the secret
    prf :: ByteString -> ByteString
    prf = hmacBS digest str

    -- infinite, lazily produced sequence of round outputs
    rounds :: [ByteString]
    rounds = iterate prf (prf (salt <> int1))

-- | Break up a SCRAM message into its underlying key-value association
-- list.
--
-- Entries are separated by commas. Within an entry the key is everything
-- before the first @=@ and the value is everything after it, so a value
-- may itself contain @=@. An entry without @=@ gets an empty value.
parseMessage :: ByteString -> [(ByteString, ByteString)]
parseMessage msg = map keyValue (B8.split ',' msg)
  where
    keyValue :: ByteString -> (ByteString, ByteString)
    keyValue entry =
      let (key, rest) = B8.span (/= '=') entry
      in (key, B.drop 1 rest) -- drop the '=' separator itself

-- | Transform all the SCRAM parameters into a @ClientProof@
-- and @ServerSignature@.
--
-- Calls 'error' with a message naming the digest if OpenSSL cannot
-- provide the requested hash function.
crypto ::
  ScramDigest {- ^ digest       -} ->
  ByteString  {- ^ password     -} ->
  ByteString  {- ^ salt         -} ->
  Int         {- ^ iterations   -} ->
  ByteString  {- ^ auth message -} ->
  (ByteString, ByteString) {- ^ client-proof, server-signature -}
crypto digest password salt iterations authMessage =
  (clientProof, serverSignature)
  where
    saltedPassword, clientKey, storedKey, clientSignature :: ByteString
    saltedPassword  = hi       d password salt iterations
    clientKey       = hmacBS   d saltedPassword "Client Key"
    storedKey       = digestBS d clientKey
    clientSignature = hmacBS   d storedKey authMessage

    clientProof :: ByteString
    clientProof     = xorBS clientKey clientSignature

    serverKey, serverSignature :: ByteString
    serverKey       = hmacBS   d saltedPassword "Server Key"
    serverSignature = hmacBS   d serverKey authMessage

    -- Name under which the hash function is looked up in OpenSSL.
    digestName :: String
    digestName =
      case digest of
        ScramDigestSha1     -> "SHA1"
        ScramDigestSha2_256 -> "SHA256"
        ScramDigestSha2_512 -> "SHA512"

    -- NOTE(review): unsafePerformIO is used on the assumption that the
    -- lookup is a side-effect-free query that returns the same answer for
    -- the same name -- confirm against the HsOpenSSL documentation.
    -- A missing digest is reported explicitly instead of surfacing as an
    -- anonymous pattern-match failure.
    d :: Digest
    d = case unsafePerformIO (getDigestByName digestName) of
          Just found -> found
          Nothing    ->
            error ("Client.Authentication.Scram.crypto: OpenSSL does not \
                   \provide digest " ++ digestName)

-- | Encode usernames so they fit in the comma/equals delimited
-- SCRAM message format: @,@ becomes @=2C@, @=@ becomes @=3D@ and every
-- other character is passed through unchanged.
encodeUsername :: ByteString -> ByteString
encodeUsername = B8.concatMap escape
  where
    escape :: Char -> ByteString
    escape ',' = "=2C"
    escape '=' = "=3D"
    escape c   = B8.singleton c