module Network.OAuth2.Experiment.Pkce (
  mkPkceParam,
  CodeChallenge (..),
  CodeVerifier (..),
  CodeChallengeMethod (..),
  PkceRequestParam (..),
) where

import Control.Monad.IO.Class
import Crypto.Hash qualified as H
import Crypto.Random qualified as Crypto
import Data.ByteArray qualified as ByteArray
import Data.ByteString qualified as BS
import Data.ByteString.Base64.URL qualified as B64
import Data.Text (Text)
import Data.Text.Encoding qualified as T
import Data.Word

-- | The PKCE code challenge, i.e. the value derived from a 'CodeVerifier'
-- and wrapped as 'Text'.
newtype CodeChallenge = CodeChallenge
  { unCodeChallenge :: Text
  -- ^ the raw challenge text
  }

-- | The randomly generated PKCE code verifier, wrapped as 'Text'.
newtype CodeVerifier = CodeVerifier
  { unCodeVerifier :: Text
  -- ^ the raw verifier text
  }

-- | How a code challenge is derived from its verifier. Only @S256@
-- (SHA-256, then unpadded base64url) is provided; the derived 'Show'
-- instance renders it as @\"S256\"@.
data CodeChallengeMethod
  = S256
  deriving Show

-- | The three PKCE values produced together for a single authorization
-- request: the verifier, the challenge computed from it, and the method
-- used for that computation.
data PkceRequestParam = PkceRequestParam
  { -- | The generated code verifier.
    codeVerifier :: CodeVerifier
  , -- | The challenge derived from 'codeVerifier'.
    codeChallenge :: CodeChallenge
  , -- | The spec says this is optional, but in practice it is S256.
    -- https://datatracker.ietf.org/doc/html/rfc7636#section-4.3
    codeChallengeMethod :: CodeChallengeMethod
  }

-- | Generate a fresh code verifier and bundle it with its S256 code challenge.
mkPkceParam :: MonadIO m => m PkceRequestParam
mkPkceParam = fromVerifier <$> genCodeVerifier
  where
    -- The verifier bytes come from an ASCII-only byte set (see 'isUnreversed'),
    -- so UTF-8 decoding them cannot fail.
    fromVerifier verifierBytes =
      PkceRequestParam
        { codeVerifier = CodeVerifier (T.decodeUtf8 verifierBytes)
        , codeChallenge = CodeChallenge (encodeCodeVerifier verifierBytes)
        , codeChallengeMethod = S256
        }

-- | Compute the S256 code challenge for a code verifier: the SHA-256 digest of
-- the verifier bytes, base64url-encoded without @=@ padding
-- (RFC 7636, section 4.2).
encodeCodeVerifier :: BS.ByteString -> Text
encodeCodeVerifier =
  -- 'ByteArray.convert' copies the digest straight into a ByteString, avoiding
  -- the intermediate [Word8] list that pack . unpack would build.
  B64.encodeBase64Unpadded . ByteArray.convert . hashSHA256

-- | Produce a fresh random code verifier of exactly 'cvMaxLen' bytes, each one
-- accepted by 'isUnreversed'. Usable from any 'MonadIO'.
genCodeVerifier :: MonadIO m => m BS.ByteString
genCodeVerifier = liftIO (getBytesInternal BS.empty)

-- | Length, in bytes, of every generated code verifier. The code-verifier
-- grammar allows 43 to 128 characters; this uses the upper bound.
cvMaxLen :: Int
cvMaxLen = 128

-- | Build a code verifier by rejection sampling. 'Crypto.getRandomBytes'
-- returns arbitrary bytes, many of which are not valid code-verifier
-- characters, so each random chunk is filtered through 'isUnreversed' and
-- the survivors are appended until at least 'cvMaxLen' bytes have
-- accumulated. The result is then truncated to exactly 'cvMaxLen' bytes.
--
-- Code-verifier grammar (RFC 7636, section 4.1):
--
-- > code-verifier = 43*128unreserved
-- >   unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~"
-- >   ALPHA = %x41-5A / %x61-7A
-- >   DIGIT = %x30-39
--
-- The argument is the prefix collected so far; 'genCodeVerifier' starts
-- from an empty string.
getBytesInternal :: BS.ByteString -> IO BS.ByteString
getBytesInternal acc
  | BS.length acc < cvMaxLen = do
      -- Draw another full chunk; only the accepted bytes are kept.
      chunk <- Crypto.getRandomBytes cvMaxLen
      getBytesInternal (acc <> BS.filter isUnreversed chunk)
  | otherwise = pure (BS.take cvMaxLen acc)

-- | SHA-256 digest of a strict ByteString. This is a monomorphic helper that
-- fixes the hash algorithm, so callers need no type annotation.
hashSHA256 :: BS.ByteString -> H.Digest H.SHA256
hashSHA256 = H.hashWith H.SHA256

-- | Whether a byte is in 'unreverseBS', the byte set that 'getBytesInternal'
-- keeps when building a code verifier.
isUnreversed :: Word8 -> Bool
isUnreversed = (`BS.elem` unreverseBS)

-- | The bytes permitted in a code verifier: the RFC 7636 (section 4.1)
-- @unreserved@ set, @ALPHA / DIGIT / \"-\" / \".\" / \"_\" / \"~\"@.
--
-- Byte values:
--
-- > a-z: 97-122
-- > A-Z: 65-90
-- > 0-9: 48-57
-- > -: 45
-- > .: 46
-- > _: 95
-- > ~: 126
unreverseBS :: BS.ByteString
unreverseBS =
  BS.pack $
    [97 .. 122] -- a-z
      ++ [65 .. 90] -- A-Z
      ++ [48 .. 57] -- 0-9: DIGIT is part of the unreserved set
      ++ [45, 46, 95, 126] -- '-', '.', '_', '~'