{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_HADDOCK prune #-}

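-- | Parses encoded (compact-serialized) JWTs — unsecured, JWS or JWE — into
-- 'DecodableJwt' structures whose individual parts can then be verified or
-- decrypted.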

module Jose.Internal.Parser
    ( parseJwt
    , DecodableJwt (..)
    , EncryptedCEK (..)
    , Payload (..)
    , IV (..)
    , Tag (..)
    , AAD (..)
    , Sig (..)
    , SigTarget (..)
    )
where

import           Data.Bifunctor (first)
import Data.Aeson (eitherDecodeStrict')
import           Data.Attoparsec.ByteString (Parser)
import qualified Data.Attoparsec.ByteString as P
import qualified Data.Attoparsec.ByteString.Char8 as PC
import           Data.ByteArray.Encoding (convertFromBase, Base(..))
import           Data.ByteString (ByteString)
import qualified Data.ByteString as B

import           Jose.Jwa
import           Jose.Types (JwtError(..), JwtHeader(..), JwsHeader(..), JweHeader(..))


-- | A JWT split into its segments, with each binary segment already
-- base64url-decoded. The constructor is chosen by the kind of token the
-- header describes.
data DecodableJwt
     = Unsecured ByteString
       -- ^ Decoded payload of an unsecured token.
     | DecodableJws JwsHeader Payload Sig SigTarget
       -- ^ Header, decoded payload, decoded signature and the signing input
       -- (encoded header and encoded payload joined by @.@).
     | DecodableJwe JweHeader EncryptedCEK IV Payload Tag AAD
       -- ^ Header, encrypted key, IV, encrypted payload, authentication tag
       -- and additional authenticated data.


-- | A decoded JWE authentication tag. The constructor records the tag length
-- in bytes, which the parser checks against the content encryption algorithm.
data Tag
    = Tag16 ByteString
    | Tag24 ByteString
    | Tag32 ByteString


-- | A decoded JWE initialization vector. The constructor records the IV
-- length in bytes (12 for the GCM modes, 16 otherwise), as checked by the
-- parser.
data IV
    = IV12 ByteString
    | IV16 ByteString


-- | Decoded JWS signature bytes.
newtype Sig = Sig ByteString
-- | JWS signing input: the still-encoded header and payload joined by @.@.
newtype SigTarget = SigTarget ByteString
-- | JWE additional authenticated data: the still-encoded protected header.
newtype AAD = AAD ByteString
-- | Decoded payload segment (for a JWE this is the encrypted payload).
newtype Payload = Payload ByteString
-- | Decoded JWE encrypted content-encryption key (may be empty).
newtype EncryptedCEK = EncryptedCEK ByteString


-- | Parse a compact-serialized JWT into its decodable parts.
--
-- Every parse failure (invalid base64url, malformed header JSON, missing
-- segments, or an IV / authentication tag of the wrong length) is reported
-- as 'BadCrypto'; the underlying parser message is discarded.
parseJwt :: ByteString -> Either JwtError DecodableJwt
parseJwt = first (const BadCrypto) . P.parseOnly jwt


-- | Top-level JWT parser.
--
-- Reads and decodes the header segment, then parses the remaining
-- dot-separated segments according to the kind of token the header
-- describes:
--
-- * unsecured: @header.payload.@ with an empty signature segment;
-- * JWS: @header.payload.signature@;
-- * JWE: @header.encryptedKey.iv.ciphertext.tag@.
jwt :: Parser DecodableJwt
jwt = do
    (hdr, raw) <- jwtHeader
    case hdr of
        -- RFC 7518 section 3.6: the signature of an unsecured ("alg": "none")
        -- token must be empty, so nothing may follow the payload's trailing
        -- '.'. Without this check a token carrying a bogus signature segment
        -- would be silently accepted as unsecured.
        UnsecuredH -> Unsecured <$> base64Chunk <* P.endOfInput
        JwsH h -> do
            payloadB64 <- PC.takeWhile ('.' /=) <* PC.char '.'
            payload <- b64Decode payloadB64
            s <- sig (jwsAlg h)
            -- The signing input is built from the still-encoded header and
            -- payload, so keep the raw base64url text rather than the bytes.
            pure $ DecodableJws h (Payload payload) s
                (SigTarget (B.concat [raw, ".", payloadB64]))
        JweH h -> do
            let enc = jweEnc h
            -- Segments are consumed left to right in wire order; the encoded
            -- header doubles as the additional authenticated data.
            DecodableJwe h
                <$> encryptedCEK
                <*> iv enc
                <*> encryptedPayload
                <*> authTag enc
                <*> pure (AAD raw)


-- | Parse the JWS signature: all remaining input, base64url-decoded.
-- The 'JwsAlg' argument is currently ignored.
sig :: JwsAlg -> Parser Sig
sig _ = Sig <$> (P.takeByteString >>= b64Decode)


-- | Parse the JWE authentication tag: all remaining input, base64url-decoded.
--
-- The decoded length must match the content encryption algorithm: 16 bytes
-- for the GCM modes and A128CBC-HS256, 24 for A192CBC-HS384 and 32 for
-- A256CBC-HS512. Otherwise the parser fails with @"invalid auth tag"@.
authTag :: Enc -> Parser Tag
authTag e = P.takeByteString >>= b64Decode >>= checkTag
  where
    -- Expected tag size and the matching constructor for each algorithm.
    (expectedLen, mkTag) = case e of
        A128GCM       -> (16, Tag16)
        A192GCM       -> (16, Tag16)
        A256GCM       -> (16, Tag16)
        A128CBC_HS256 -> (16, Tag16)
        A192CBC_HS384 -> (24, Tag24)
        A256CBC_HS512 -> (32, Tag32)

    checkTag bytes
        | B.length bytes == expectedLen = pure (mkTag bytes)
        | otherwise                     = fail "invalid auth tag"


-- | Parse the '.'-terminated JWE initialization vector segment.
--
-- The GCM modes require a 12-byte IV; every other content encryption
-- algorithm requires 16 bytes. A length mismatch fails with @"invalid iv"@.
iv :: Enc -> Parser IV
iv e = base64Chunk >>= checkIv
  where
    -- Expected IV size and the matching constructor for each algorithm.
    (expectedLen, mkIv) = case e of
        A128GCM -> (12, IV12)
        A192GCM -> (12, IV12)
        A256GCM -> (12, IV12)
        _       -> (16, IV16)

    checkIv bytes
        | B.length bytes == expectedLen = pure (mkIv bytes)
        | otherwise                     = fail "invalid iv"


-- | Parse the '.'-terminated, base64url-encoded JWE encrypted key segment.
encryptedCEK :: Parser EncryptedCEK
encryptedCEK = fmap EncryptedCEK base64Chunk


-- | Parse the '.'-terminated, base64url-encoded JWE ciphertext segment.
encryptedPayload :: Parser Payload
encryptedPayload = fmap Payload base64Chunk


-- | Parse the header segment.
--
-- Reads up to the first '.', base64url-decodes it and parses the bytes as a
-- JSON 'JwtHeader'; a JSON error fails the parser with aeson's message.
-- Returns the header together with the still-encoded segment, which 'jwt'
-- uses for the JWS signing input and the JWE additional authenticated data.
jwtHeader :: Parser (JwtHeader, ByteString)
jwtHeader = do
    encoded <- PC.takeWhile (/= '.') <* PC.char '.'
    decoded <- b64Decode encoded
    case eitherDecodeStrict' decoded of
        Left err  -> fail err
        Right hdr -> pure (hdr, encoded)


-- | Parse one '.'-terminated segment and base64url-decode it. The
-- terminating '.' is consumed but not part of the decoded input.
base64Chunk :: Parser ByteString
base64Chunk = (PC.takeWhile (/= '.') <* PC.char '.') >>= b64Decode


-- | Decode unpadded base64url data, failing the parser with
-- @"Invalid Base64"@ when the input cannot be decoded.
b64Decode :: ByteString -> Parser ByteString
b64Decode bs =
    case convertFromBase Base64URLUnpadded bs of
        Left _        -> fail "Invalid Base64"
        Right decoded -> pure decoded