{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wno-orphans #-}

-- | Server implementation of the 'JWTAuth'' trait.
module WebGear.Server.Trait.Auth.JWT where

import Control.Arrow (arr, returnA, (>>>))
import Control.Monad.Except (MonadError (throwError), lift, runExceptT, withExceptT)
import Control.Monad.Time (MonadTime)
import qualified Crypto.JWT as JWT
import Data.ByteString.Lazy (fromStrict)
import Data.Void (Void)
import WebGear.Core.Handler (arrM)
import WebGear.Core.Modifiers
import WebGear.Core.Request (Request)
import WebGear.Core.Trait (Get (..), Linked)
import WebGear.Core.Trait.Auth.Common (
  AuthToken (..),
  AuthorizationHeader,
  getAuthorizationHeaderTrait,
 )
import WebGear.Core.Trait.Auth.JWT (JWTAuth' (..), JWTAuthError (..))
import WebGear.Server.Handler (ServerHandler)

-- | Extract and validate a JWT carried in the @Authorization@ header.
--
-- The header is obtained with 'getAuthorizationHeaderTrait' for the
-- type-level @scheme@, and the outcome is mapped as follows:
--
--   * no header                                  -> 'JWTAuthHeaderMissing'
--   * header trait yields @Left _@               -> 'JWTAuthSchemeMismatch'
--   * token bytes fail 'JWT.decodeCompact'       -> 'JWTAuthTokenBadFormat'
--   * 'JWT.verifyClaims' fails                   -> 'JWTAuthTokenBadFormat'
--   * 'toJWTAttribute' returns @Left e@          -> @'JWTAuthAttributeError' e@
--   * otherwise                                  -> @Right@ the attribute
--
-- The 'MonadTime' constraint is needed by 'JWT.verifyClaims'.
instance (MonadTime m, Get (ServerHandler m) (AuthorizationHeader scheme) Request) => Get (ServerHandler m) (JWTAuth' Required scheme m e a) Request where
  {-# INLINEABLE getTrait #-}
  getTrait ::
    JWTAuth' Required scheme m e a ->
    ServerHandler m (Linked ts Request) (Either (JWTAuthError e) a)
  getTrait JWTAuth'{..} = proc request -> do
    result <- getAuthorizationHeaderTrait @scheme -< request
    case result of
      Nothing -> returnA -< Left JWTAuthHeaderMissing
      -- The payload of the header trait's Left is discarded; callers only
      -- learn that the header did not match the expected scheme.
      Just (Left _) -> returnA -< Left JWTAuthSchemeMismatch
      Just (Right token) ->
        case parseJWT token of
          Left err -> returnA -< Left (JWTAuthTokenBadFormat err)
          Right jwt -> validateJWT -< jwt
    where
      -- Decode the raw token bytes as a compact-serialised signed JWT.
      parseJWT :: AuthToken scheme -> Either JWT.JWTError JWT.SignedJWT
      parseJWT AuthToken{..} = JWT.decodeCompact $ fromStrict authToken

      -- Check the JWT with 'JWT.verifyClaims' using the configured
      -- validation settings and key set, then turn the resulting claims
      -- into the attribute via the user-supplied 'toJWTAttribute'.
      -- Runs in 'ExceptT' so either failure short-circuits.
      validateJWT :: ServerHandler m JWT.SignedJWT (Either (JWTAuthError e) a)
      validateJWT = arrM $ \jwt -> runExceptT $ do
        claims <-
          withExceptT JWTAuthTokenBadFormat $
            JWT.verifyClaims jwtValidationSettings jwkSet jwt
        lift (toJWTAttribute claims)
          >>= either (throwError . JWTAuthAttributeError) pure

-- | Optional variant of JWT authentication.
--
-- Runs the 'Required' extraction and wraps its whole result, whether an
-- authentication error or the attribute, in @Right@. The absence type is
-- 'Void', so this trait never fails. Handlers receive the
-- @'Either' ('JWTAuthError' e) a@ and decide for themselves.
instance (MonadTime m, Get (ServerHandler m) (AuthorizationHeader scheme) Request) => Get (ServerHandler m) (JWTAuth' Optional scheme m e a) Request where
  {-# INLINEABLE getTrait #-}
  getTrait ::
    JWTAuth' Optional scheme m e a ->
    ServerHandler m (Linked ts Request) (Either Void (Either (JWTAuthError e) a))
  getTrait JWTAuth'{..} =
    -- Rebuild the same settings at the 'Required' phantom type so the
    -- Required instance above can be reused unchanged.
    getTrait (JWTAuth'{..} :: JWTAuth' Required scheme m e a) >>> arr Right