{-# LANGUAGE DataKinds         #-}
{-# LANGUAGE FlexibleContexts  #-}
{-# LANGUAGE NamedFieldPuns    #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies      #-}
module Auth.Biscuit.Servant
  (
  -- Servant Auth Handler
    RequireBiscuit
  , authHandler
  , genBiscuitCtx
  , checkBiscuit
  -- Decorate regular handlers with composable verifiers
  , WithVerifier (..)
  , handleBiscuit
  , withVerifier
  , withVerifier_
  , noVerifier
  , noVerifier_
  , withFallbackVerifier
  , withPriorityVerifier
  ) where

import           Auth.Biscuit                     (Biscuit, PublicKey, Verifier,
                                                   checkBiscuitSignature,
                                                   parseB64, verifyBiscuit)
import           Control.Monad.Except             (MonadError, throwError)
import           Control.Monad.IO.Class           (MonadIO, liftIO)
import           Control.Monad.Reader             (ReaderT, lift, runReaderT)
import           Data.Bifunctor                   (first)
import qualified Data.ByteString                  as BS
import qualified Data.ByteString.Char8            as C8
import qualified Data.ByteString.Lazy             as LBS
import           Network.Wai
import           Servant                          (AuthProtect)
import           Servant.Server
import           Servant.Server.Experimental.Auth

-- | Type used to protect an API tree, requiring a biscuit token
-- to be attached to requests. The associated auth handler will
-- only check the biscuit signature. Checking the datalog part
-- usually requires endpoint-specific information, and has to
-- be performed separately with either 'checkBiscuit' (for simple
-- use-cases) or 'handleBiscuit' (for more complex use-cases).
type RequireBiscuit = AuthProtect "biscuit"

-- Handlers behind 'RequireBiscuit' receive a 'CheckedBiscuit', i.e. the value
-- produced by the auth handler built with 'authHandler'.
type instance AuthServerData RequireBiscuit = CheckedBiscuit

-- | A biscuit whose signature has already been verified.
-- Since the biscuit lib checks the signature while verifying the datalog
-- part, the public key is needed again at that point. 'CheckedBiscuit'
-- carries the public key used for verifying the signature so that the
-- datalog verification part can use it.
data CheckedBiscuit = CheckedBiscuit PublicKey Biscuit

-- | Wrapper for a servant handler, equipped with a biscuit 'Verifier'
-- that will be used to authorize the request. If the authorization
-- succeeds, the handler is run.
-- The handler itself is given access to the verified biscuit through
-- a @'ReaderT' 'Biscuit'@.
data WithVerifier m a
  = WithVerifier
  { handler_  :: ReaderT Biscuit m a
  -- ^ the wrapped handler, in a 'ReaderT' to give easy access to the biscuit
  , verifier_ :: Verifier
  -- ^ the 'Verifier' associated to the handler
  }

-- | Combines the provided 'Verifier' with the 'Verifier' attached to the wrapped
-- handler. /facts/, /rules/ and /checks/ are unordered, but /policies/ have a
-- specific order. 'withFallbackVerifier' puts the provided policies at the /bottom/
-- of the list (i.e. as /fallback/ policies).
-- If you want the policies to be tried before the ones of the wrapped handler, you
-- can use 'withPriorityVerifier'.
withFallbackVerifier :: Verifier
                     -> WithVerifier m a
                     -> WithVerifier m a
withFallbackVerifier newV h =
  -- the handler's own verifier comes first, so newV's policies are appended last
  h { verifier_ = verifier_ h <> newV }

-- | Combines the provided 'Verifier' with the 'Verifier' attached to the wrapped
-- handler. /facts/, /rules/ and /checks/ are unordered, but /policies/ have a
-- specific order. 'withPriorityVerifier' puts the provided policies at the /top/
-- of the list (i.e. as /priority/ policies).
-- If you want the policies to be tried after the ones of the wrapped handler, you
-- can use 'withFallbackVerifier'.
withPriorityVerifier :: Verifier
                     -> WithVerifier m a
                     -> WithVerifier m a
withPriorityVerifier newV h =
  -- newV comes first, so its policies are placed before the handler's own
  h { verifier_ = newV <> verifier_ h }

-- | Attaches a 'Verifier' to a handler block running in @'ReaderT' 'Biscuit'@,
-- which lets the block read the token. When the block does not need the token,
-- 'withVerifier_' accepts a plain action in the base monad instead.
withVerifier :: Monad m => Verifier -> ReaderT Biscuit m a -> WithVerifier m a
withVerifier verifier handler =
  WithVerifier
    { verifier_ = verifier
    , handler_  = handler
    }

-- | Attaches a 'Verifier' to a handler block written in any monad @m@. The block
-- is lifted into @'ReaderT' 'Biscuit' m@ and therefore cannot see the token; use
-- 'withVerifier' when the block needs to read the biscuit.
withVerifier_ :: Monad m => Verifier -> m a -> WithVerifier m a
withVerifier_ verifier action = withVerifier verifier (lift action)

-- | Wraps a @'ReaderT' 'Biscuit'@ handler block together with an empty
-- ('mempty') 'Verifier'. The block can read the token; if it does not need to,
-- see 'noVerifier_'.
--
-- Pair this with 'withFallbackVerifier' or 'withPriorityVerifier' to apply
-- policies to several handlers at once (with 'hoistServer' for instance).
noVerifier :: Monad m => ReaderT Biscuit m a -> WithVerifier m a
noVerifier handler =
  WithVerifier
    { handler_  = handler
    , verifier_ = mempty
    }

-- | Wraps a handler block written in any monad @m@ together with an empty
-- ('mempty') 'Verifier'. The block is lifted and cannot see the 'Biscuit'; use
-- 'noVerifier' when it needs to read the token.
--
-- Pair this with 'withFallbackVerifier' or 'withPriorityVerifier' to apply
-- policies to several handlers at once (with 'hoistServer' for instance).
noVerifier_ :: Monad m => m a -> WithVerifier m a
noVerifier_ = withVerifier_ mempty

-- | Extracts a biscuit from an http request, assuming:
--
-- - the biscuit is b64-encoded
-- - prefixed with the @Bearer @ string (matched case-sensitively)
-- - in the @Authorization@ header
--
-- Each failing step yields a 'Left' carrying a short message, which
-- 'authHandler' turns into a 401 response body.
extractBiscuit :: Request -> Either String Biscuit
extractBiscuit req =
  case lookup "Authorization" (requestHeaders req) of
    Nothing -> Left "Missing Authorization header"
    Just authHeader ->
      case BS.stripPrefix "Bearer " authHeader of
        Nothing -> Left "Not a Bearer token"
        Just b64Token ->
          -- the parser's own error is discarded in favour of a fixed message
          first (const "Not a B64-encoded biscuit") (parseB64 b64Token)

-- | Servant authorization handler. It pulls the biscuit out of the request
-- (see 'extractBiscuit'), verifies its signature against the given public key
-- (the datalog part is NOT evaluated here) and yields a 'CheckedBiscuit'.
-- Any failure aborts the request with a 401 whose body is the error message.
authHandler :: PublicKey -> AuthHandler Request CheckedBiscuit
authHandler publicKey = mkAuthHandler handler
  where
    authError :: String -> ServerError
    authError msg = err401 { errBody = LBS.fromStrict (C8.pack msg) }

    handler :: Request -> Handler CheckedBiscuit
    handler req = do
      biscuit <- either (throwError . authError) pure (extractBiscuit req)
      signatureOk <- liftIO (checkBiscuitSignature biscuit publicKey)
      if signatureOk
        -- keep the key around: datalog verification needs it again later
        then pure (CheckedBiscuit publicKey biscuit)
        else throwError (authError "Invalid signature")

-- | Builds the servant 'Context' holding the biscuit authorization handler
-- (see 'authHandler') for the given public key, and nothing else.
genBiscuitCtx :: PublicKey -> Context '[AuthHandler Request CheckedBiscuit]
genBiscuitCtx = (:. EmptyContext) . authHandler

-- | Given a 'CheckedBiscuit' (provided by the servant authorization mechanism),
-- runs the datalog verification with the provided 'Verifier' and, on success,
-- runs the given action. On failure the request is aborted with a 401
-- ("Biscuit failed checks").
--
-- If you don't want to pass the biscuit manually to every endpoint, or want to
-- apply verifiers to whole API trees, consider 'withVerifier' (on endpoints),
-- 'withFallbackVerifier' and 'withPriorityVerifier' (on API sub-trees) and
-- 'handleBiscuit' (on the whole API).
checkBiscuit :: (MonadIO m, MonadError ServerError m)
             => CheckedBiscuit
             -> Verifier
             -> m a
             -> m a
checkBiscuit (CheckedBiscuit publicKey biscuit) verifier action =
  liftIO (verifyBiscuit biscuit verifier publicKey)
    >>= either rejected (const action)
  where
    -- NOTE(review): the verification error is written to stdout with 'print'
    -- before the 401 is thrown; consider routing it through proper logging.
    rejected err = do
      liftIO (print err)
      throwError (err401 { errBody = "Biscuit failed checks" })

-- | Runs a handler wrapped in a 'WithVerifier': the attached 'Verifier' is
-- checked against the provided biscuit (via 'checkBiscuit'), and only on
-- success is the wrapped handler run, with the biscuit supplied to its
-- 'ReaderT' environment.
--
-- For simpler use cases, consider using 'checkBiscuit' instead, which works on
-- regular servant handlers.
handleBiscuit :: (MonadIO m, MonadError ServerError m)
              => CheckedBiscuit
              -> WithVerifier m a
              -> m a
handleBiscuit checked@(CheckedBiscuit _ biscuit) wrapped =
  checkBiscuit checked (verifier_ wrapped) (runReaderT (handler_ wrapped) biscuit)