{-# OPTIONS_HADDOCK show-extensions #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE UndecidableInstances #-}
module Libjwt.Jwt
( Jwt(..)
, Encoded
, getToken
, sign
, sign'
, Decoded
, getDecoded
, decodeString
, decodeByteString
, Validated
, getValid
, validateJwt
, jwtFromString
, jwtFromByteString
)
where
import Libjwt.Algorithms
import Libjwt.Encoding
import Libjwt.Exceptions ( SomeDecodeException
, AlgorithmMismatch(..)
, DecodeException(..)
)
import Libjwt.Decoding
import Libjwt.FFI.Jwt
import Libjwt.FFI.Libjwt
import Libjwt.Header
import Libjwt.JwtValidation
import Libjwt.Keys
import Libjwt.Payload
import Libjwt.PrivateClaims
import Control.Monad.Catch
import Control.Monad.Extra ( unlessM )
import Control.Monad.Time
import Control.Monad ( (<=<)
, when
)
import Data.ByteString ( ByteString )
import qualified Data.ByteString.Char8 as C8
import GHC.IO.Exception ( IOErrorType(InvalidArgument) )
import System.IO.Error ( ioeGetErrorType )
-- | A JSON Web Token: its header together with its payload.
--
-- @pc@ is the type-level list of private claims and @ns@ their namespace,
-- exactly as used by 'PrivateClaims' and 'Payload'.
data Jwt pc ns = Jwt
  { header  :: Header         -- ^ token header
  , payload :: Payload pc ns  -- ^ claims carried by the token
  }

-- Instances are available whenever the private claims themselves support
-- them (the context is not a plain type variable, hence
-- UndecidableInstances).
deriving stock instance Show (PrivateClaims pc ns) => Show (Jwt pc ns)
deriving stock instance Eq (PrivateClaims pc ns) => Eq (Jwt pc ns)
-- | A serialized token. The phantom parameter @t@ records what was encoded,
-- e.g. @'Encoded' ('Jwt' pc ns)@ as produced by 'sign'.
--
-- Only the raw bytes returned by the encoder are stored; use 'getToken' to
-- obtain them. The constructor is not exported from this module.
newtype Encoded t = MkEncoded { getToken :: ByteString }
  deriving stock (Show, Eq)
-- | Sign a payload with the given algorithm and key, using 'JWT' as the
-- token type.
--
-- This is @'sign'' 'JWT'@.
sign
  :: (Encode (PrivateClaims pc ns), SigningKey k)
  => Algorithm k
  -> Payload pc ns
  -> Encoded (Jwt pc ns)
sign = sign' JWT
-- | Like 'sign', but with an explicit token type.
--
-- The 'Algorithm' value is split into the libjwt algorithm tag and its key.
-- The key is turned into raw bytes with 'getSigningKey' before signing.
sign'
  :: (Encode (PrivateClaims pc ns), SigningKey k)
  => Typ            -- ^ token type, written into the header
  -> Algorithm k    -- ^ signing algorithm together with its key
  -> Payload pc ns  -- ^ claims to sign
  -> Encoded (Jwt pc ns)
sign' ty algorithm = signJwt jwtAlg (getSigningKey key) ty
  where (jwtAlg, key) = jwtAlgWithKey algorithm
-- | Build and sign a token through libjwt.
--
-- The steps are:
--
-- 1. Create a fresh libjwt handle.
-- 2. Encode the payload and the token type into it.
-- 3. Set the algorithm and key.
-- 4. Serialize the token.
--
-- The libjwt calls are run with 'unsafePerformJwtIO' so that signing is
-- exposed as a pure function. The NOINLINE pragma keeps GHC from inlining
-- the binding and thereby duplicating or floating that effect.
signJwt
  :: Encode (PrivateClaims pc ns)
  => JwtAlgT
  -> ByteString     -- ^ raw key bytes, as produced by 'getSigningKey'
  -> Typ
  -> Payload pc ns
  -> Encoded (Jwt pc ns)
signJwt jwtAlg key ty it = MkEncoded $ unsafePerformJwtIO signTokenJwtIo
 where
  signTokenJwtIo :: JwtIO ByteString
  signTokenJwtIo = do
    jwt <- mkJwtT
    encode it jwt
    encode ty jwt
    jwtSetAlg jwtAlg key jwt
    -- The "typ": "JWT" header is added by hand only for unsigned ("none")
    -- tokens of type JWT; presumably libjwt does not emit it itself in that
    -- case. NOTE(review): confirm against libjwt's encoder.
    when (jwtAlg == jwtAlgNone && ty == JWT) $ addHeader "typ" "JWT" jwt
    jwtEncode jwt
{-# NOINLINE signJwt #-}
-- | A value that has been successfully decoded but not yet validated.
--
-- The constructor is not exported from this module. Callers therefore
-- obtain a 'Decoded' value only from the decoding functions here, and read
-- it back with 'getDecoded'.
newtype Decoded t = MkDecoded { getDecoded :: t }
  deriving stock (Show, Eq)
-- | 'String' variant of 'decodeByteString'.
--
-- The input is converted with 'C8.pack', which keeps only the low 8 bits of
-- each character. Compact-serialized JWTs are plain ASCII, so this is
-- lossless for well-formed tokens.
decodeString
  :: (MonadThrow m, Decode (PrivateClaims pc ns), DecodingKey k)
  => Algorithm k
  -> String
  -> m (Decoded (Jwt pc ns))
decodeString algorithm = decodeByteString algorithm . C8.pack
-- | Decode a token and check that it was produced with the expected
-- algorithm.
--
-- The steps are:
--
-- 1. Hand the token to libjwt together with the key taken from
--    @algorithm@.
-- 2. Compare the algorithm recorded in the token with the one in
--    @algorithm@; if they differ, throw 'AlgorithmMismatch'.
-- 3. Decode the header and the payload.
--
-- Every exception caught as 'SomeDecodeException' along the way is rethrown
-- in @m@ with 'throwM'. This presumably covers both 'AlgorithmMismatch' and
-- the 'DecodeException' raised by 'safeJwtDecode'; confirm the hierarchy in
-- "Libjwt.Exceptions". Other exceptions are not caught here.
--
-- The libjwt work runs under 'unsafePerformJwtIO', hence the NOINLINE
-- pragma.
decodeByteString
  :: forall ns pc m k
   . (MonadThrow m, Decode (PrivateClaims pc ns), DecodingKey k)
  => Algorithm k
  -> ByteString
  -> m (Decoded (Jwt pc ns))
decodeByteString algorithm token =
  either throwM (pure . MkDecoded) $ unsafePerformJwtIO decodeTokenJwtIo
 where
  (jwtAlg, key) = jwtAlgWithKey algorithm

  -- @pc@ and @ns@ here are the ones bound by the outer 'forall'
  -- (ScopedTypeVariables).
  decodeTokenJwtIo :: JwtIO (Either SomeDecodeException (Jwt pc ns))
  decodeTokenJwtIo = try $ do
    jwt <- safeJwtDecode (getDecodingKey key) token
    -- Reject tokens whose alg differs from the one the caller expects.
    unlessM ((== jwtAlg) <$> jwtGetAlg jwt) $ throwM AlgorithmMismatch
    Jwt <$> decode jwt <*> decode jwt
{-# NOINLINE decodeByteString #-}
-- | Wrapper around 'jwtDecode'.
--
-- An 'IOError' whose type is 'InvalidArgument' is turned into a
-- 'DecodeException' carrying the offending token. Any other 'IOError', and
-- any other exception, propagates unchanged.
--
-- NOTE(review): the full token text is embedded in the exception, so it
-- appears wherever that exception is shown or logged.
safeJwtDecode :: ByteString -> ByteString -> JwtIO JwtT
safeJwtDecode key token =
  catchIf isInvalidArgument (jwtDecode key token) $ \_ ->
    throwM $ DecodeException $ C8.unpack token
 where
  isInvalidArgument :: IOError -> Bool
  isInvalidArgument e = ioeGetErrorType e == InvalidArgument
-- | A value that has passed validation.
--
-- The constructor is not exported from this module. Within this module it
-- is applied only by 'validateJwt', on the success side of the validation
-- result. Read the value back with 'getValid'.
newtype Validated t = MkValid { getValid :: t }
  deriving stock (Show, Eq)
-- | Run the checks in @v@, under @settings@, against the payload of a
-- decoded token.
--
-- On success the result holds the whole token wrapped in 'Validated'. On
-- failure it holds the non-empty collection of 'ValidationFailure's
-- reported by 'runValidation'.
--
-- @m@ only needs 'MonadTime', which 'runValidation' requires (presumably
-- to obtain the current time for time-based claims).
validateJwt
  :: MonadTime m
  => ValidationSettings
  -> JwtValidation pc ns
  -> Decoded (Jwt pc ns)
  -> m (ValidationNEL ValidationFailure (Validated (Jwt pc ns)))
validateJwt settings v (MkDecoded jwt) =
  -- Replace the success value with the validated token; failures pass
  -- through untouched.
  (MkValid jwt <$) <$> runValidation settings v (payload jwt)
-- | Decode a token given as a 'String' (see 'decodeString') and then
-- validate it (see 'validateJwt').
--
-- Decoding failures are thrown in @m@. Validation failures are returned in
-- the result.
jwtFromString
  :: (Decode (PrivateClaims pc ns), MonadTime m, MonadThrow m, DecodingKey k)
  => ValidationSettings
  -> JwtValidation pc ns
  -> Algorithm k
  -> String
  -> m (ValidationNEL ValidationFailure (Validated (Jwt pc ns)))
jwtFromString settings v algorithm =
  validateJwt settings v <=< decodeString algorithm
-- | Decode a token given as a 'ByteString' (see 'decodeByteString') and then
-- validate it (see 'validateJwt').
--
-- Decoding failures are thrown in @m@. Validation failures are returned in
-- the result.
jwtFromByteString
  :: (Decode (PrivateClaims pc ns), MonadTime m, MonadThrow m, DecodingKey k)
  => ValidationSettings
  -> JwtValidation pc ns
  -> Algorithm k
  -> ByteString
  -> m (ValidationNEL ValidationFailure (Validated (Jwt pc ns)))
jwtFromByteString settings v algorithm =
  validateJwt settings v <=< decodeByteString algorithm