{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wno-orphans #-}

-- | Server implementation of the 'JWTAuth'' trait.
module WebGear.Server.Trait.Auth.JWT where

import Control.Arrow (arr, returnA, (>>>))
import Control.Monad.Except (MonadError (throwError), lift, runExceptT, withExceptT)
import Control.Monad.Time (MonadTime)
import qualified Crypto.JWT as JWT
import Data.ByteString.Lazy (fromStrict)
import Data.Void (Void)
import WebGear.Core.Handler (arrM)
import WebGear.Core.Modifiers
import WebGear.Core.Request (Request)
import WebGear.Core.Trait (Get (..), Linked)
import WebGear.Core.Trait.Auth.Common (
  AuthToken (..),
  AuthorizationHeader,
  getAuthorizationHeaderTrait,
 )
import WebGear.Core.Trait.Auth.JWT (JWTAuth' (..), JWTAuthError (..))
import WebGear.Server.Handler (ServerHandler)

-- | Mandatory JWT authentication.
--
-- The handler performs these steps:
--
--   1. Read the @Authorization@ header via 'getAuthorizationHeaderTrait' at
--      @scheme@.
--   2. Decode the token bytes with 'JWT.decodeCompact' (compact
--      serialization).
--   3. Verify the claims with 'JWT.verifyClaims' using
--      'jwtValidationSettings' and 'jwkSet'.
--   4. Convert the verified claims into the attribute with
--      'toJWTAttribute'.
--
-- Failures are reported as follows:
--
--   * 'JWTAuthHeaderMissing': the header trait returned 'Nothing'.
--   * 'JWTAuthSchemeMismatch': the header trait returned a 'Left'. Its
--     'Text' payload is discarded.
--   * 'JWTAuthTokenBadFormat': the token failed to decode, /or/ the decoded
--     token failed claim verification.
--   * 'JWTAuthAttributeError': 'toJWTAttribute' returned a 'Left'.
--
-- The 'MonadTime' constraint is needed by 'JWT.verifyClaims'. Presumably this
-- is for time-based claim checks such as expiry; see the jose documentation.
instance (MonadTime m, Get (ServerHandler m) (AuthorizationHeader scheme) Request) => Get (ServerHandler m) (JWTAuth' Required scheme m e a) Request where
  {-# INLINEABLE getTrait #-}
  getTrait ::
    JWTAuth' Required scheme m e a ->
    ServerHandler m (Linked ts Request) (Either (JWTAuthError e) a)
  getTrait JWTAuth'{..} = proc request -> do
    result <- getAuthorizationHeaderTrait @scheme -< request
    case result of
      Nothing -> returnA -< Left JWTAuthHeaderMissing
      Just (Left _) -> returnA -< Left JWTAuthSchemeMismatch
      Just (Right token) ->
        case parseJWT token of
          Left e -> returnA -< Left (JWTAuthTokenBadFormat e)
          Right jwt -> validateJWT -< jwt
    where
      -- Decode the raw (strict) token bytes as a compact-serialized signed JWT.
      parseJWT :: AuthToken scheme -> Either JWT.JWTError JWT.SignedJWT
      parseJWT AuthToken{..} = JWT.decodeCompact $ fromStrict authToken

      -- Verify the JWT's claims, then run the user-supplied conversion.
      -- Verification errors are reported with the same 'JWTAuthTokenBadFormat'
      -- constructor as decode errors. Errors from 'toJWTAttribute' are
      -- wrapped in 'JWTAuthAttributeError'.
      validateJWT :: ServerHandler m JWT.SignedJWT (Either (JWTAuthError e) a)
      validateJWT = arrM $ \jwt -> runExceptT $ do
        claims <-
          withExceptT JWTAuthTokenBadFormat $
            JWT.verifyClaims jwtValidationSettings jwkSet jwt
        lift (toJWTAttribute claims) >>= either (throwError . JWTAuthAttributeError) pure

-- | Optional JWT authentication.
--
-- This instance runs the same extraction and verification as the 'Required'
-- instance, using the same settings, and wraps that result in 'Right'.
-- Consequences:
--
--   * The trait never reports absence; its absence type is 'Void'.
--   * Authentication failures are not hidden. They are still visible to the
--     handler as @'Left' ('JWTAuthError' e)@ inside the attribute.
instance (MonadTime m, Get (ServerHandler m) (AuthorizationHeader scheme) Request) => Get (ServerHandler m) (JWTAuth' Optional scheme m e a) Request where
  {-# INLINEABLE getTrait #-}
  getTrait ::
    JWTAuth' Optional scheme m e a ->
    ServerHandler m (Linked ts Request) (Either Void (Either (JWTAuthError e) a))
  getTrait JWTAuth'{..} =
    -- Rebuild the record at the 'Required' index from the same fields. The
    -- existence index is phantom: none of the fields mention it.
    getTrait (JWTAuth'{..} :: JWTAuth' Required scheme m e a) >>> arr Right