{-# Language OverloadedStrings, ImportQualifiedPost #-}
module Client.Authentication.Ecdh
  (
    -- * Phase type
    Phase1,
    -- * Mechanism details
    mechanismName,
    -- * Transition functions
    clientFirst,
    clientResponse,
  ) where

import Control.Monad (guard)
import Crypto.Curve25519.Pure qualified as Curve
import Data.Bits (xor)
import Data.ByteString (ByteString)
import Data.ByteString qualified as B
import Data.ByteString.Base64 qualified as B64
import Data.Text (Text)
import Data.Text.Encoding qualified as Text
import Irc.Commands (AuthenticatePayload (AuthenticatePayload))
import OpenSSL.EVP.Digest ( digestBS, getDigestByName, hmacBS, Digest )
import System.IO.Unsafe ( unsafePerformIO )

-- | State carried from 'clientFirst' to 'clientResponse': the client's
-- private key, which is needed again to derive the answer to the server's
-- challenge.
newtype Phase1 = Phase1 Curve.PrivateKey

-- | Name of the authentication mechanism implemented by this module.
mechanismName :: Text
mechanismName =
  "ECDH-X25519-CHALLENGE"

-- | Build the first client message and the state needed for the next step.
--
-- The private key text is base64-decoded (after UTF-8 encoding) and handed
-- to 'Curve.importPrivate'. The result is 'Nothing' when the text is not
-- valid base64 or when 'Curve.importPrivate' rejects the decoded bytes.
--
-- The payload is the UTF-8 authentication identity. When an authorization
-- identity is supplied, it is followed by a NUL byte and the UTF-8
-- authorization identity.
clientFirst :: Maybe Text -> Text -> Text -> Maybe (AuthenticatePayload, Phase1)
clientFirst mbAuthz authc privateKeyText =
  do key <- either (const Nothing) Curve.importPrivate
              (B64.decode (Text.encodeUtf8 privateKeyText))
     Just (AuthenticatePayload payload, Phase1 key)
  where
    -- authc, optionally followed by "\0" and authz
    payload =
      Text.encodeUtf8 authc
        <> maybe mempty (\authz -> "\0" <> Text.encodeUtf8 authz) mbAuthz

-- | Answer the server's challenge (the second and final client step).
--
-- The server message is split into three consecutive 32-byte fields:
-- server public key, session salt and masked challenge. The length guard
-- makes any message that is not exactly 96 bytes yield 'Nothing'. A server
-- key rejected by 'Curve.importPublic' also yields 'Nothing'.
--
-- The unmasking key is derived in an HKDF-SHA256-like shape. The IKM is the
-- SHA-256 of the shared secret, the client public key and the server public
-- key. The PRK is an HMAC keyed by the salt. The output is an HMAC keyed by
-- the PRK over the info string plus counter byte 1. The challenge is
-- recovered by XOR-ing the mask with that output.
clientResponse ::
  Phase1 ->
  ByteString                {- ^ server response  -} ->
  Maybe AuthenticatePayload {- ^ client response  -}
clientResponse (Phase1 privateKey) serverMessage =
  do guard (B.length masked == 32)
     serverPublic <- Curve.importPublic serverKeyBytes
     let shared = Curve.makeShared privateKey serverPublic
         clientKeyBytes = Curve.exportPublic (Curve.generatePublic privateKey)
         -- input keying material: DH output followed by both public keys
         ikm = digestBS sha256 (shared <> clientKeyBytes <> serverKeyBytes)
         -- extract step keyed by the server-chosen session salt
         prk = hmacBS sha256 salt ikm
         -- single expand block: info string followed by the counter byte 1
         okm = hmacBS sha256 prk "ECDH-X25519-CHALLENGE\1"
     Just $! AuthenticatePayload (B.pack (B.zipWith xor masked okm))
  where
    (serverKeyBytes, afterKey) = B.splitAt 32 serverMessage
    (salt, masked)             = B.splitAt 32 afterKey

-- | The SHA-256 digest, looked up from OpenSSL by name once per program run.
--
-- Running 'getDigestByName' through 'unsafePerformIO' at top level relies
-- on the lookup being side-effect free. NOTE(review): confirm against
-- HsOpenSSL's documentation. @NOINLINE@ keeps this a single shared CAF, so
-- GHC cannot duplicate the IO action at use sites.
--
-- If OpenSSL has no @SHA256@ digest, forcing this value raises a
-- descriptive 'error' rather than an anonymous pattern-match failure.
sha256 :: Digest
sha256 =
  case unsafePerformIO (getDigestByName "SHA256") of
    Just digest -> digest
    Nothing     ->
      error "Client.Authentication.Ecdh.sha256: OpenSSL has no SHA256 digest"
{-# NOINLINE sha256 #-}