-- | Proof Key for Code Exchange (PKCE, RFC 7636) helpers: generate a code
-- verifier together with its @S256@ code challenge.
module Network.OAuth2.Experiment.Pkce
  ( mkPkceParam
  , CodeChallenge (..)
  , CodeVerifier (..)
  , CodeChallengeMethod (..)
  , PkceRequestParam (..)
  ) where

import Control.Monad.IO.Class
import Crypto.Hash qualified as H
import Crypto.Random qualified as Crypto
import Data.ByteArray qualified as ByteArray
import Data.ByteString qualified as BS
import Data.ByteString.Base64.URL qualified as B64
import Data.Text (Text)
import Data.Text.Encoding qualified as T
import Data.Word

-- | Value of the @code_challenge@ request parameter.
newtype CodeChallenge = CodeChallenge {unCodeChallenge :: Text}

-- | Value of the @code_verifier@ request parameter.
newtype CodeVerifier = CodeVerifier {unCodeVerifier :: Text}

-- | How the challenge is derived from the verifier. Only @S256@
-- (SHA-256, then unpadded base64url) is supported.
data CodeChallengeMethod = S256
  deriving (Show)

-- | Everything a client needs for one PKCE exchange.
data PkceRequestParam = PkceRequestParam
  { codeVerifier :: CodeVerifier
  , codeChallenge :: CodeChallenge
  -- ^ Derived from 'codeVerifier' using 'codeChallengeMethod'.
  , codeChallengeMethod :: CodeChallengeMethod
  -- ^ The spec makes this parameter optional, but in practice @S256@ is
  -- what is used.
  -- https://datatracker.ietf.org/doc/html/rfc7636#section-4.3
  }

-- | Generate a fresh random code verifier and its @S256@ code challenge.
mkPkceParam :: MonadIO m => m PkceRequestParam
mkPkceParam = do
  codeV <- genCodeVerifier
  pure
    PkceRequestParam
      { -- The verifier only ever contains ASCII bytes (see the filter in
        -- 'getBytesInternal'), so UTF-8 decoding cannot fail here.
        codeVerifier = CodeVerifier (T.decodeUtf8 codeV)
      , codeChallenge = CodeChallenge (encodeCodeVerifier codeV)
      , codeChallengeMethod = S256
      }

-- | @BASE64URL(SHA256(verifier))@ without padding, i.e. the @S256@
-- transformation.
--
-- 'ByteArray.convert' copies the digest straight into a 'BS.ByteString'
-- instead of round-tripping through an intermediate @[Word8]@ list.
encodeCodeVerifier :: BS.ByteString -> Text
encodeCodeVerifier = B64.encodeBase64Unpadded . ByteArray.convert . hashSHA256

-- | Produce 'cvMaxLen' random bytes, each an RFC 7636 unreserved character.
genCodeVerifier :: MonadIO m => m BS.ByteString
genCodeVerifier = liftIO (getBytesInternal BS.empty)

-- | Length of every generated code verifier. 128 is the upper bound that
-- RFC 7636 allows (the lower bound is 43).
cvMaxLen :: Int
cvMaxLen = 128

-- 'Crypto.getRandomBytes' returns arbitrary bytes, most of which are not
-- allowed in a code verifier, so they are filtered down to the unreserved
-- characters. Grammar from
-- https://datatracker.ietf.org/doc/html/rfc7636#section-4.1:
--
--   code-verifier = 43*128unreserved
--   unreserved    = ALPHA / DIGIT / "-" / "."
--                 / "_" / "~"
--   ALPHA         = %x41-5A / %x61-7A
--   DIGIT         = %x30-39

-- | Grow @acc@ with random unreserved bytes until it holds at least
-- 'cvMaxLen' of them, then return exactly the first 'cvMaxLen'.
--
-- Each round draws 'cvMaxLen' bytes from 'Crypto.getRandomBytes' and keeps
-- only those accepted by 'isUnreserved'. Rejected bytes are discarded rather
-- than remapped, so (assuming the RNG output is uniform) every allowed
-- character is equally likely.
getBytesInternal :: BS.ByteString -> IO BS.ByteString
getBytesInternal acc
  | BS.length acc >= cvMaxLen = pure (BS.take cvMaxLen acc)
  | otherwise = do
      fresh <- Crypto.getRandomBytes cvMaxLen
      getBytesInternal (acc `BS.append` BS.filter isUnreserved fresh)

-- | SHA-256 digest of the given bytes.
hashSHA256 :: BS.ByteString -> H.Digest H.SHA256
hashSHA256 = H.hash

-- | Whether a byte is an RFC 7636 @unreserved@ character: an ASCII letter,
-- an ASCII digit, or one of @-@, @.@, @_@, @~@.
--
-- Digits were previously missing from this set even though the grammar
-- above allows them; including them is spec-compliant and raises the
-- per-character entropy of the verifier (66 symbols instead of 56).
isUnreserved :: Word8 -> Bool
isUnreserved w =
  between 0x41 0x5A -- A-Z
    || between 0x61 0x7A -- a-z
    || between 0x30 0x39 -- 0-9
    || w == 0x2D -- '-'
    || w == 0x2E -- '.'
    || w == 0x5F -- '_'
    || w == 0x7E -- '~'
  where
    between lo hi = w >= lo && w <= hi