{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}

-- | Helpers for building, caching and signing ASAP JWT bearer tokens.
--
-- Claims are stamped with an issue time (@iat@), an expiry (@exp@) and a
-- random @jti@. The signing key, issuer and key id can be taken from the
-- @ASAP_PRIVATE_KEY@, @ASAP_ISSUER@ and @ASAP_KEY_ID@ settings looked up
-- through 'asapLookupEnv'.
module Web.JWT.ASAP
  ( module Web.JWT.ASAP.Error
  , module Web.JWT.ASAP.Env
  , Expiry(..)
  , MaxAge(..)
  , defaultTokenExpiry
  , defaultTokenMaxAge
  , timedClaim
  , expiringClaim
  , maxAgeClaimGenerator'
  , maxAgeClaimGenerator
  , asapReadRsaSecret
  , asapAuthHeader
  , asapAuthHeaderFromEnv
  , asapSignerFromEnv
  , laterThanMaxAge
  ) where

import Control.Applicative (liftA2)
import Control.Lens (view, ( # ), _tail)
import Control.Monad.Except (MonadError (..))
import Data.ByteString (ByteString)
import qualified Data.ByteString.Char8 as BS (unlines)
import Data.ByteString.Lens (packedChars)
import Data.IORef (newIORef, readIORef, writeIORef)
import Data.String (fromString)
import qualified Data.Text as T
import Data.Time (NominalDiffTime)
import Data.Time.Clock.POSIX (getPOSIXTime)
import Data.UUID (UUID)
import qualified Data.UUID as UUID
import qualified Data.UUID.V4 as UUID
import qualified Web.JWT as JWT

import Web.JWT.ASAP.Env
import Web.JWT.ASAP.Error

-- | Lifetime of a freshly generated claim: the gap between its @iat@ and
-- its @exp@ fields (see 'timedClaim').
newtype Expiry = Expiry NominalDiffTime
  deriving (Show, Eq, Ord)

-- | Age, measured from a claim's @iat@, at which a cached claim is replaced
-- by 'maxAgeClaimGenerator''.
newtype MaxAge = MaxAge NominalDiffTime
  deriving (Show, Eq, Ord)

-- | Ten minutes (600 seconds).
defaultTokenExpiry :: Expiry
defaultTokenExpiry = Expiry $ 10 * 60

-- | Nine minutes (540 seconds): one minute shorter than
-- 'defaultTokenExpiry', so cached claims are replaced before they expire.
defaultTokenMaxAge :: MaxAge
defaultTokenMaxAge = MaxAge $ 9 * 60

-- | Build a claim set issued at @time@ (seconds since the POSIX epoch).
--
-- * @iat@ is @time@.
-- * @exp@ is @time@ plus the 'Expiry'.
-- * @jti@ is the textual form of @uuid@.
--
-- All other fields are left at their 'mempty' defaults.
timedClaim :: Expiry -> NominalDiffTime -> UUID -> JWT.JWTClaimsSet
timedClaim (Expiry expiryTime) time uuid =
  mempty
    { JWT.iat = JWT.numericDate time
    , JWT.exp = JWT.numericDate $ time + expiryTime
    , JWT.jti = JWT.stringOrURI $ UUID.toText uuid
    }

-- | A 'timedClaim' stamped with the current POSIX time and a fresh random
-- (version 4) UUID.
expiringClaim :: Expiry -> IO JWT.JWTClaimsSet
expiringClaim expiry = liftA2 (timedClaim expiry) getPOSIXTime UUID.nextRandom

-- | Read a cached claim, first replacing it with a freshly generated one
-- (which is stored and returned) when the cached claim is stale.
--
-- A claim is stale when either of these holds:
--
-- * its @iat@ is at least the 'MaxAge' old (see 'laterThanMaxAge');
-- * its @exp@ has already been reached.
--
-- The expiry check matters when the 'MaxAge' is not shorter than the
-- 'Expiry' the claims were created with. It also matters when a claim has
-- no @iat@. In both cases, checking @iat@ alone would keep handing out
-- expired claims.
--
-- A claim with neither @iat@ nor @exp@ is returned as is, without reading
-- the clock.
maxAgeClaimGenerator'
  :: Monad m
  => MaxAge
  -> m NominalDiffTime            -- ^ current POSIX time
  -> m JWT.JWTClaimsSet           -- ^ generate a fresh claim
  -> (JWT.JWTClaimsSet -> m ())   -- ^ store a claim in the cache
  -> m JWT.JWTClaimsSet           -- ^ read the cached claim
  -> m JWT.JWTClaimsSet
maxAgeClaimGenerator' maxAge time = regenerateWhen isStale
  where
    isStale claim = case (JWT.iat claim, JWT.exp claim) of
      (Nothing, Nothing) -> pure False
      (issuedAt, expiresAt) -> staleAt issuedAt expiresAt <$> time

    staleAt issuedAt expiresAt now =
      maybe False (\iat -> laterThanMaxAge maxAge iat now) issuedAt
        -- RFC 7519: the token is only valid strictly before @exp@.
        || maybe False (\e -> now >= JWT.secondsSinceEpoch e) expiresAt
-- | Create a claim cache primed with one 'expiringClaim'. Return an action
-- that yields the cached claim, first replacing it via
-- 'maxAgeClaimGenerator'' when it has gone stale.
--
-- Reading and writing the 'Data.IORef.IORef' are separate steps. Concurrent
-- callers that each see a stale claim may therefore each generate and store
-- their own replacement.
maxAgeClaimGenerator :: MaxAge -> Expiry -> IO (IO JWT.JWTClaimsSet)
maxAgeClaimGenerator maxAge expiry = cachedClaim <$> (newIORef =<< fresh)
  where
    fresh = expiringClaim expiry
    cachedClaim cache =
      maxAgeClaimGenerator'
        maxAge
        getPOSIXTime
        fresh
        (writeIORef cache)
        (readIORef cache)

-- | Parse a PEM-encoded RSA private key into an RSA signer.
--
-- Throws 'asapInvalidSecret' when 'JWT.readRsaSecret' cannot read the key.
asapReadRsaSecret :: (HasAsapError e, MonadError e m) => ByteString -> m JWT.Signer
asapReadRsaSecret pemBytes =
  case JWT.readRsaSecret pemBytes of
    Just key -> pure (JWT.RSAPrivateKey key)
    Nothing  -> throwError (asapInvalidSecret # ())

-- | Sign the claim set with the given header. The result is prefixed with
-- @"Bearer "@, ready to be used as an @Authorization@ header value.
asapAuthHeader :: JWT.Signer -> JWT.JOSEHeader -> JWT.JWTClaimsSet -> T.Text
asapAuthHeader signer header claim =
  T.append "Bearer " (JWT.encodeSigned signer header claim)

-- | Like 'asapAuthHeader', with the inputs filled in from the environment:
--
-- * the header's @kid@ is set from @ASAP_KEY_ID@;
-- * the claim's @iss@ is set from @ASAP_ISSUER@;
-- * the token is signed with the key from 'asapSignerFromEnv'.
--
-- Lookups happen in the order @ASAP_ISSUER@, @ASAP_KEY_ID@, then
-- @ASAP_PRIVATE_KEY@. If 'JWT.stringOrURI' returns 'Nothing' for the
-- issuer, the claim is signed without an @iss@.
asapAuthHeaderFromEnv
  :: (HasAsapError e, MonadError e m, MonadEnv m)
  => JWT.JOSEHeader
  -> JWT.JWTClaimsSet
  -> m T.Text
asapAuthHeaderFromEnv header claim = do
  issuer <- asapLookupEnv "ASAP_ISSUER"
  keyId  <- asapLookupEnv "ASAP_KEY_ID"
  signer <- asapSignerFromEnv
  pure (asapAuthHeader signer (withKeyId keyId) (withIssuer issuer))
  where
    withKeyId k = header { JWT.kid = Just (fromString k) }
    withIssuer i = claim { JWT.iss = JWT.stringOrURI (fromString i) }

-- | Build an RSA signer from @ASAP_PRIVATE_KEY@.
--
-- The value is treated as a data URI. The text after its first comma
-- (presumably base64 key material; confirm against the deployment format)
-- is wrapped in @RSA PRIVATE KEY@ PEM armour and parsed with
-- 'asapReadRsaSecret'.
asapSignerFromEnv :: (HasAsapError e, MonadError e m, MonadEnv m) => m JWT.Signer
asapSignerFromEnv =
  asapLookupEnv "ASAP_PRIVATE_KEY" >>= asapReadRsaSecret . pemFromDataUri
  where
    pemFromDataUri uri =
      BS.unlines
        [ "-----BEGIN RSA PRIVATE KEY-----"
        , view packedChars (dataUriData uri)
        , "-----END RSA PRIVATE KEY-----"
        ]

-- | Everything after the first comma of a data URI, or the empty string
-- when there is no comma.
dataUriData :: String -> String
dataUriData uri = view _tail fromComma
  where
    (_, fromComma) = break (== ',') uri

-- | 'True' once the time elapsed since @iat@ has reached the maximum age
-- (the boundary itself counts as too old).
laterThanMaxAge :: MaxAge -> JWT.NumericDate -> NominalDiffTime -> Bool
laterThanMaxAge (MaxAge limit) issuedAt now = age >= limit
  where
    age = now - JWT.secondsSinceEpoch issuedAt

-- | Load a value. If @isStale@ flags it, produce a replacement, store it
-- and return it. Otherwise return the loaded value unchanged.
regenerateWhen :: Monad m => (a -> m Bool) -> m a -> (a -> m ()) -> m a -> m a
regenerateWhen isStale regenerate store load = load >>= refresh
  where
    refresh current = do
      stale <- isStale current
      if stale
        then regenerate >>= \replacement -> replacement <$ store replacement
        else pure current