--   This Source Code Form is subject to the terms of the Mozilla Public
--   License, v. 2.0. If a copy of the MPL was not distributed with this
--   file, You can obtain one at http://mozilla.org/MPL/2.0/.

{-# OPTIONS_HADDOCK show-extensions #-}

{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE OverloadedStrings #-}

-- | JWT header representation
module Libjwt.Header
  ( Alg(..)
  , Typ(..)
  , Header(..)
  )
where

import           Libjwt.Encoding
import           Libjwt.Keys
import           Libjwt.FFI.Libjwt
import           Libjwt.FFI.Jwt

import           Control.Monad                  ( when )

import           Data.ByteString                ( ByteString )
import qualified Data.ByteString               as ByteString

-- | @"alg"@ header parameter: the JWS algorithm used to secure the token,
-- together with the key material used to sign it.
data Alg
  = -- | Unsecured token: no digital signature or MAC
    None
  | -- | HMAC SHA-256 (secret key must be __at least 256 bits in size__)
    HS256 Secret
  | -- | HMAC SHA-384 (secret key must be __at least 384 bits in size__)
    HS384 Secret
  | -- | HMAC SHA-512 (secret key must be __at least 512 bits in size__)
    HS512 Secret
  | -- | RSASSA-PKCS1-v1_5 SHA-256 (a key of size __2048 bits or larger__ must be used with this algorithm)
    RS256 RsaKeyPair
  | -- | RSASSA-PKCS1-v1_5 SHA-384 (a key of size __2048 bits or larger__ must be used with this algorithm)
    RS384 RsaKeyPair
  | -- | RSASSA-PKCS1-v1_5 SHA-512 (a key of size __2048 bits or larger__ must be used with this algorithm)
    RS512 RsaKeyPair
  | -- | ECDSA with P-256 curve and SHA-256
    ES256 EcKeyPair
  | -- | ECDSA with P-384 curve and SHA-384
    ES384 EcKeyPair
  | -- | ECDSA with P-521 curve and SHA-512
    ES512 EcKeyPair
  deriving stock (Show, Eq)

-- | @"typ"@ header parameter
data Typ
  = -- | The @"typ": "JWT"@ value
    JWT
  | -- | @'Just' s@ sets @"typ"@ to @s@; with 'Nothing' the encoder adds no
    -- @"typ"@ parameter itself
    Typ (Maybe ByteString)
  deriving stock (Show, Eq)

-- | JWT header representation
data Header = Header
  { alg :: Alg -- ^ @"alg"@: signing algorithm together with its key
  , typ :: Typ -- ^ @"typ"@: token type
  }
  deriving stock (Show, Eq)

-- | Writes a 'Header' into a libjwt handle.
--
-- The algorithm and its key are registered with 'jwtSetAlg' first; the
-- @"typ"@ parameter is handled afterwards:
--
-- * @'Typ' ('Just' s)@ adds a @"typ"@ header with the value @s@;
--
-- * 'JWT' combined with 'None' adds @"typ": "JWT"@ explicitly;
--
-- * every other @"typ"@ case is delegated to 'nullEncode'.
instance Encode Header where
  encode header jwt = encodeAlg (alg header) >> encodeTyp (typ header)
   where
    -- Every helper closes over @jwt@ and yields an 'EncodeResult' directly.
    -- Chaining two @JwtT -> EncodeResult@ functions with '>>' would run in
    -- the function (reader) monad, where @(f >> g) j = g j@, silently
    -- dropping the left-hand action (e.g. the 'jwtSetAlg' call for 'None').
    encodeAlg None           = setAlgorithm jwtAlgNone ByteString.empty >> forceTyp
    encodeAlg (HS256 secret) = setAlgorithm jwtAlgHs256 $ reveal secret
    encodeAlg (HS384 secret) = setAlgorithm jwtAlgHs384 $ reveal secret
    encodeAlg (HS512 secret) = setAlgorithm jwtAlgHs512 $ reveal secret
    encodeAlg (RS256 pem   ) = setAlgorithm jwtAlgRs256 $ privKey pem
    encodeAlg (RS384 pem   ) = setAlgorithm jwtAlgRs384 $ privKey pem
    encodeAlg (RS512 pem   ) = setAlgorithm jwtAlgRs512 $ privKey pem
    encodeAlg (ES256 pem   ) = setAlgorithm jwtAlgEs256 $ ecPrivKey pem
    encodeAlg (ES384 pem   ) = setAlgorithm jwtAlgEs384 $ ecPrivKey pem
    encodeAlg (ES512 pem   ) = setAlgorithm jwtAlgEs512 $ ecPrivKey pem

    -- Register an algorithm code and its key material on this handle.
    setAlgorithm algCode key = jwtSetAlg algCode key jwt

    encodeTyp (Typ (Just s)) = addHeader "typ" s jwt
    encodeTyp _              = nullEncode jwt

    -- NOTE(review): "typ": "JWT" is only added explicitly for unsigned
    -- tokens, presumably because libjwt supplies that default itself once a
    -- signing algorithm is set — confirm against the libjwt version in use.
    forceTyp = when (typ header == JWT) $ addHeader "typ" "JWT" jwt