{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}

-- | Helpers for building ASAP JWT bearer tokens: minting time-limited claims,
-- caching a claim until it reaches a maximum age, and signing claims with an
-- RSA private key (optionally one read through 'asapLookupEnv').
module Web.JWT.ASAP
  ( module Web.JWT.ASAP.Error
  , module Web.JWT.ASAP.Env
  , Expiry (..)
  , MaxAge (..)
  , defaultTokenExpiry
  , defaultTokenMaxAge
  , timedClaim
  , expiringClaim
  , maxAgeClaimGenerator'
  , maxAgeClaimGenerator
  , asapReadRsaSecret
  , asapAuthHeader
  , asapAuthHeaderFromEnv
  , laterThanMaxAge
  ) where

import Control.Applicative (liftA2)
import Data.IORef (newIORef, readIORef, writeIORef)
import Data.String (fromString)

import Control.Lens (view, ( # ))
import Control.Monad.Except (MonadError (..))
import Data.ByteString (ByteString)
import Data.ByteString.Base64 (decodeLenient)
import Data.ByteString.Lens (packedChars)
import qualified Data.Text as T
import Data.Time (NominalDiffTime)
import Data.Time.Clock.POSIX (getPOSIXTime)
import Data.UUID (UUID)
import qualified Data.UUID as UUID
import qualified Data.UUID.V4 as UUID
import qualified Web.JWT as JWT

import Web.JWT.ASAP.Env (MonadEnv (..), asapLookupEnv)
import Web.JWT.ASAP.Error (HasAsapError (..))

-- | How long a freshly minted token remains valid, measured from its @iat@.
newtype Expiry = Expiry NominalDiffTime
  deriving (Show, Eq, Ord)

-- | How old (relative to its @iat@) a cached claim may become before
-- 'maxAgeClaimGenerator'' replaces it.
newtype MaxAge = MaxAge NominalDiffTime
  deriving (Show, Eq, Ord)

-- | Tokens expire ten minutes after they are issued.
defaultTokenExpiry :: Expiry
defaultTokenExpiry = Expiry (10 * 60)

-- | Cached claims are replaced after nine minutes, i.e. one minute before
-- they would expire under 'defaultTokenExpiry'.
defaultTokenMaxAge :: MaxAge
defaultTokenMaxAge = MaxAge (9 * 60)

-- | Build a claims set issued at @time@, expiring @expiryTime@ later, and
-- carrying @uuid@ as its @jti@. All other fields keep their 'mempty' values.
timedClaim :: Expiry -> NominalDiffTime -> UUID -> JWT.JWTClaimsSet
timedClaim (Expiry expiryTime) time uuid =
  mempty
    { JWT.iat = JWT.numericDate time
    , JWT.exp = JWT.numericDate (time + expiryTime)
    , JWT.jti = JWT.stringOrURI (UUID.toText uuid)
    }

-- | Mint a claim issued at the current POSIX time, with a fresh random (v4)
-- UUID as its @jti@.
expiringClaim :: Expiry -> IO JWT.JWTClaimsSet
expiringClaim expiry = liftA2 (timedClaim expiry) getPOSIXTime UUID.nextRandom

-- | Return the stored claim, first replacing it when it is too old.
--
-- The stored claim is considered stale when its @iat@ is at least @maxAge@
-- before the value produced by the time action. In that case a new claim is
-- minted, stored and returned. The time action only runs when the stored
-- claim has an @iat@.
--
-- NOTE(review): a stored claim with no @iat@ is never treated as stale, so
-- it would be returned forever; confirm this is intended.
maxAgeClaimGenerator'
  :: Monad m
  => MaxAge
  -> m NominalDiffTime          -- ^ current time, comparable with @iat@
  -> m JWT.JWTClaimsSet         -- ^ mint a replacement claim
  -> (JWT.JWTClaimsSet -> m ()) -- ^ store a claim
  -> m JWT.JWTClaimsSet         -- ^ load the stored claim
  -> m JWT.JWTClaimsSet
maxAgeClaimGenerator' maxAge time newClaim = regenerateWhen isStale newClaim
  where
    isStale claim =
      maybe (pure False) (\iat -> laterThanMaxAge maxAge iat <$> time) (JWT.iat claim)

-- | Create an action that returns a cached claim, minting a replacement with
-- 'expiringClaim' once the cached one reaches @maxAge@. The first claim is
-- minted eagerly, when this function runs.
--
-- The cache is an 'Data.IORef.IORef' that is read and then written in
-- separate steps. Concurrent callers may therefore each mint a replacement;
-- the last write is the one kept.
maxAgeClaimGenerator :: MaxAge -> Expiry -> IO (IO JWT.JWTClaimsSet)
maxAgeClaimGenerator maxAge expiry = do
  ref <- newIORef =<< newClaim
  pure (maxAgeClaimGenerator' maxAge getPOSIXTime newClaim (writeIORef ref) (readIORef ref))
  where
    newClaim = expiringClaim expiry

-- | Parse a private key with 'JWT.readRsaSecret' and wrap it as an RSA
-- signer. Throws 'asapInvalidSecret' when the key cannot be parsed.
asapReadRsaSecret :: (HasAsapError e, MonadError e m) => ByteString -> m JWT.Signer
asapReadRsaSecret =
  maybe (throwError (asapInvalidSecret # ())) (pure . JWT.RSAPrivateKey)
    . JWT.readRsaSecret

-- | Sign @claim@ under @header@ and prefix the token with @"Bearer "@, for use
-- as an authorization header value.
asapAuthHeader :: JWT.Signer -> JWT.JOSEHeader -> JWT.JWTClaimsSet -> T.Text
asapAuthHeader signer header claim = "Bearer " <> JWT.encodeSigned signer header claim

-- | Like 'asapAuthHeader', but takes the issuer, key id and private key from
-- 'asapLookupEnv' under @ASAP_ISSUER@, @ASAP_KEY_ID@ and @ASAP_PRIVATE_KEY@.
-- The key id replaces the @kid@ of @header@, and the issuer replaces the
-- @iss@ of @claim@.
--
-- @ASAP_PRIVATE_KEY@ is expected to be a data URI. The text after its first
-- comma is base64-decoded leniently and then parsed with 'asapReadRsaSecret'.
asapAuthHeaderFromEnv
  :: (HasAsapError e, MonadError e m, MonadEnv m)
  => JWT.JOSEHeader
  -> JWT.JWTClaimsSet
  -> m T.Text
asapAuthHeaderFromEnv header claim = do
  issuer <- asapLookupEnv "ASAP_ISSUER"
  keyId <- asapLookupEnv "ASAP_KEY_ID"
  dataUri <- asapLookupEnv "ASAP_PRIVATE_KEY"
  let pem = decodeLenient . view packedChars $ dataUriData dataUri
      header' = header {JWT.kid = Just $ fromString keyId}
      claim' = claim {JWT.iss = JWT.stringOrURI $ fromString issuer}
  signer <- asapReadRsaSecret pem
  pure (asapAuthHeader signer header' claim')

-- | Extract the payload of a data URI: everything after the first comma, or
-- the empty string when there is no comma.
--
-- The comma itself is excluded. The previous @snd . break (== ',')@ kept it
-- and relied on 'decodeLenient' to skip it.
dataUriData :: String -> String
dataUriData = drop 1 . dropWhile (/= ',')

-- | 'True' when at least @maxAgeTime@ has elapsed between @iat@ and @time@.
-- The comparison is inclusive, so a claim exactly @maxAgeTime@ old counts
-- as too old.
laterThanMaxAge :: MaxAge -> JWT.NumericDate -> NominalDiffTime -> Bool
laterThanMaxAge (MaxAge maxAgeTime) iat time =
  time - JWT.secondsSinceEpoch iat >= maxAgeTime

-- | Load the current value. If @predicate@ holds for it, run @regenerate@,
-- store the result and return it; otherwise return the current value.
regenerateWhen :: Monad m => (a -> m Bool) -> m a -> (a -> m ()) -> m a -> m a
regenerateWhen predicate regenerate store load = do
  current <- load
  stale <- predicate current
  if stale
    then do
      fresh <- regenerate
      store fresh
      pure fresh
    else pure current