{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_HADDOCK prune #-}

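-- | Parses encoded JWTs (the compact, dot-separated serialization) into
-- 'DecodableJwt' values whose parts can then be handled by signature
-- verification or decryption code.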

module Jose.Internal.Parser
    ( parseJwt
    , DecodableJwt (..)
    , EncryptedCEK (..)
    , Payload (..)
    , IV (..)
    , Tag (..)
    , AAD (..)
    , Sig (..)
    , SigTarget (..)
    )
where

import           Control.Applicative
import           Data.Bifunctor (first)
import Data.Aeson (eitherDecodeStrict')
import           Data.Attoparsec.ByteString (Parser)
import qualified Data.Attoparsec.ByteString as P
import qualified Data.Attoparsec.ByteString.Char8 as PC
import           Data.ByteArray.Encoding (convertFromBase, Base(..))
import           Data.ByteString (ByteString)
import qualified Data.ByteString as B

import           Jose.Jwa
import           Jose.Types (JwtError(..), JwtHeader(..), JwsHeader(..), JweHeader(..))


-- | An encoded JWT broken into its components. Which variant is produced
-- depends on the kind of header found in the token's first segment.
data DecodableJwt
    = Unsecured ByteString
      -- ^ Unsecured token: the base64url-decoded payload.
    | DecodableJws JwsHeader Payload Sig SigTarget
      -- ^ Signed token: header, decoded payload, decoded signature, and the
      -- raw (still encoded) @header.payload@ text used as the signing input.
    | DecodableJwe JweHeader EncryptedCEK IV Payload Tag AAD
      -- ^ Encrypted token: header, encrypted key, IV, ciphertext,
      -- authentication tag and additional authenticated data.


-- | A decoded JWE authentication tag, labelled with its length in bytes.
data Tag = Tag16 ByteString  -- ^ 16-byte tag
         | Tag24 ByteString  -- ^ 24-byte tag
         | Tag32 ByteString  -- ^ 32-byte tag


-- | A decoded JWE initialization vector, labelled with its length in bytes.
data IV = IV12 ByteString  -- ^ 12-byte IV
        | IV16 ByteString  -- ^ 16-byte IV


-- | The base64url-decoded JWS signature bytes.
newtype Sig          = Sig ByteString

-- | The raw, still-encoded @header.payload@ text of a JWS.
newtype SigTarget    = SigTarget ByteString

-- | JWE additional authenticated data: the raw, still-encoded header segment.
newtype AAD          = AAD ByteString

-- | A base64url-decoded payload: the JWS payload or the JWE ciphertext.
newtype Payload      = Payload ByteString

-- | The base64url-decoded JWE encrypted content-encryption key.
newtype EncryptedCEK = EncryptedCEK ByteString


-- | Split an encoded JWT into its decoded components.
--
-- Every failure of the 'jwt' parser is reported as 'BadCrypto'; the
-- parser's own error message is discarded.
--
-- NOTE: 'P.parseOnly' does not require the parser to consume all of its
-- input, so any unconsumed trailing input is silently ignored.
parseJwt :: ByteString -> Either JwtError DecodableJwt
parseJwt input =
    case P.parseOnly jwt input of
        Left _       -> Left BadCrypto
        Right parsed -> Right parsed


-- | Parser for a whole compact-serialized JWT.
--
-- The first segment is the header. Its type decides how the remaining
-- dot-separated segments are read:
--
-- * unsecured: one more '.'-terminated segment, the payload;
-- * JWS: the payload segment, then the rest of the input as the signature;
-- * JWE: encrypted key, IV and ciphertext segments, then the rest of the
--   input as the authentication tag.
jwt :: Parser DecodableJwt
jwt = jwtHeader >>= \(hdr, rawHdr) ->
    case hdr of
        -- NOTE(review): only the payload segment (and its terminating '.')
        -- is consumed here; nothing after it is examined.
        UnsecuredH -> Unsecured <$> base64Chunk
        JwsH h     -> signed h rawHdr
        JweH h     -> encrypted h rawHdr
  where
    -- The payload's encoded form is kept because it is part of the
    -- signing input, together with the raw header.
    signed :: JwsHeader -> ByteString -> Parser DecodableJwt
    signed h rawHdr = do
        encodedPayload <- PC.takeWhile (/= '.') <* PC.char '.'
        decodedPayload <- b64Decode encodedPayload
        sigBytes       <- sig (jwsAlg h)
        let target = SigTarget (B.concat [rawHdr, ".", encodedPayload])
        pure (DecodableJws h (Payload decodedPayload) sigBytes target)

    -- The raw (still encoded) header doubles as the JWE additional
    -- authenticated data.
    encrypted :: JweHeader -> ByteString -> Parser DecodableJwt
    encrypted h rawHdr =
        let contentEnc = jweEnc h
         in DecodableJwe h
                <$> encryptedCEK
                <*> iv contentEnc
                <*> encryptedPayload
                <*> authTag contentEnc
                <*> pure (AAD rawHdr)


-- | Parse the final JWS segment. Everything left in the input is treated as
-- the base64url-encoded signature and decoded.
--
-- The algorithm argument is currently ignored, so the signature's length is
-- not checked here.
sig :: JwsAlg -> Parser Sig
sig _ = Sig <$> (b64Decode =<< P.takeByteString)


-- | Parse the final JWE segment as the authentication tag.
--
-- The rest of the input is base64url-decoded first. The decoded tag must then
-- have the length this parser associates with the content encryption
-- algorithm:
--
-- * 16 bytes for the GCM modes and A128CBC-HS256;
-- * 24 bytes for A192CBC-HS384;
-- * 32 bytes for A256CBC-HS512.
--
-- Any other length fails with @"invalid auth tag"@.
authTag :: Enc -> Parser Tag
authTag e = P.takeByteString >>= b64Decode >>= checked
  where
    (expectedLen, wrap) = case e of
        A128GCM       -> (16, Tag16)
        A192GCM       -> (16, Tag16)
        A256GCM       -> (16, Tag16)
        A128CBC_HS256 -> (16, Tag16)
        A192CBC_HS384 -> (24, Tag24)
        A256CBC_HS512 -> (32, Tag32)

    checked t
        | B.length t == expectedLen = pure (wrap t)
        | otherwise                 = fail "invalid auth tag"


-- | Parse the JWE initialization-vector segment.
--
-- The decoded IV must be 12 bytes for the GCM modes and 16 bytes for every
-- other content encryption algorithm. Any other length fails with
-- @"invalid iv"@.
iv :: Enc -> Parser IV
iv e = base64Chunk >>= sized
  where
    (wanted, wrap) = case e of
        A128GCM -> (12, IV12)
        A192GCM -> (12, IV12)
        A256GCM -> (12, IV12)
        _       -> (16, IV16)

    sized bytes
        | B.length bytes /= wanted = fail "invalid iv"
        | otherwise                = pure (wrap bytes)


-- | Parse the JWE encrypted-key segment, returned base64url-decoded.
encryptedCEK :: Parser EncryptedCEK
encryptedCEK = fmap EncryptedCEK base64Chunk


-- | Parse the JWE ciphertext segment, returned base64url-decoded.
encryptedPayload :: Parser Payload
encryptedPayload = fmap Payload base64Chunk


-- | Parse the leading header segment, consuming it and its terminating '.'.
--
-- Returns the header decoded from base64url JSON, together with the raw,
-- still-encoded segment. 'jwt' uses the raw segment for the JWS signing
-- input and for the JWE AAD.
--
-- Fails with @"Invalid Base64"@ if the segment is not valid base64url, or
-- with aeson's error message if the JSON does not decode as a 'JwtHeader'.
jwtHeader :: Parser (JwtHeader, ByteString)
jwtHeader = do
    encoded <- PC.takeWhile (/= '.') <* PC.char '.'
    hdrJson <- b64Decode encoded
    case eitherDecodeStrict' hdrJson of
        Left err  -> fail err
        Right hdr -> pure (hdr, encoded)


-- | Consume one '.'-terminated segment and base64url-decode it.
--
-- The '.' is consumed but is not part of the decoded result. Fails if no '.'
-- follows the segment or if the segment does not decode.
base64Chunk :: Parser ByteString
base64Chunk = b64Decode =<< (PC.takeWhile (/= '.') <* PC.char '.')


-- | Decode unpadded base64url (the encoding JWT segments use) inside the
-- parser.
--
-- Malformed input fails with @"Invalid Base64"@; the underlying decoder's own
-- message is discarded.
b64Decode :: ByteString -> Parser ByteString
b64Decode encoded =
    case convertFromBase Base64URLUnpadded encoded of
        Left _        -> fail "Invalid Base64"
        Right decoded -> pure decoded