{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveFunctor      #-}
{-# LANGUAGE DeriveGeneric      #-}
{-# LANGUAGE OverloadedStrings  #-}

module Servant.Server.Internal.BasicAuth where

import           Control.Monad
                 (guard)
import           Control.Monad.Trans
                 (liftIO)
import qualified Data.ByteString                     as BS
import           Data.ByteString.Base64
                 (decodeLenient)
import           Data.Typeable
                 (Typeable)
import           Data.Word8
                 (isSpace, toLower, _colon)
import           GHC.Generics
import           Network.HTTP.Types
                 (Header)
import           Network.Wai
                 (Request, requestHeaders)

import           Servant.API.BasicAuth
                 (BasicAuthData (BasicAuthData))
import           Servant.Server.Internal.DelayedIO
import           Servant.Server.Internal.ServerError

-- * Basic Auth

-- NOTE: servant-server's current implementation of basic authentication is
-- not immune to certain kinds of timing attacks: decoding payloads does not
-- take a fixed amount of time.

-- | The result of authentication/authorization, as reported by the function
-- wrapped in a 'BasicAuthCheck'.
--
-- 'runBasicAuth' turns each outcome into a response: 'BadPassword' and
-- 'NoSuchUser' yield a 401 with a @WWW-Authenticate@ challenge,
-- 'Unauthorized' yields a 403, and 'Authorized' lets the request through
-- with the given user.
data BasicAuthResult usr
  = Unauthorized
    -- ^ Access refused; 'runBasicAuth' fails with 403 (no challenge header).
  | BadPassword
    -- ^ Password rejected; 'runBasicAuth' fails with 401 and a challenge.
  | NoSuchUser
    -- ^ User not found; 'runBasicAuth' fails with 401 and a challenge.
  | Authorized usr
    -- ^ Authentication succeeded, carrying the authenticated user.
  deriving (Eq, Show, Read, Generic, Typeable, Functor)

-- | Datatype wrapping a function used to check authentication.
--
-- The wrapped function receives the credentials decoded from the request's
-- @Authorization@ header and reports the outcome as a 'BasicAuthResult'.
-- 'runBasicAuth' then converts that outcome into an HTTP response.
newtype BasicAuthCheck usr = BasicAuthCheck
  { unBasicAuthCheck :: BasicAuthData
                     -> IO (BasicAuthResult usr)
  }
  deriving (Generic, Typeable, Functor)

-- | Internal method to make a basic-auth challenge.
--
-- Produces the header @WWW-Authenticate: Basic realm="<realm>"@.
--
-- NOTE(review): the realm is spliced in verbatim, so a realm containing a
-- double quote would yield a malformed header; callers are expected to pass
-- a plain, server-chosen realm.
mkBAChallengerHdr :: BS.ByteString -> Header
mkBAChallengerHdr realm =
  ("WWW-Authenticate", "Basic realm=\"" <> realm <> "\"")

-- | Find and decode an @Authorization@ header from the request as Basic Auth.
--
-- Returns 'Nothing' when:
--
-- * the request has no @Authorization@ header,
-- * the scheme (the text before the first space) is not @basic@, compared
--   case-insensitively, or
-- * the decoded payload contains no @:@ separating username and password.
--
-- The username is everything before the first @:@ and the password is
-- everything after it, so the password may itself contain colons.
--
-- NOTE: 'decodeLenient' returns a plain 'BS.ByteString' and cannot signal
-- failure, so malformed base64 is not rejected here; it is decoded on a
-- best-effort basis.
decodeBAHdr :: Request -> Maybe BasicAuthData
decodeBAHdr req = do
    authValue <- lookup "Authorization" (requestHeaders req)
    let (scheme, rest) = BS.break isSpace authValue
    guard (BS.map toLower scheme == "basic")
    let decoded                         = decodeLenient (BS.dropWhile isSpace rest)
        (username, passWithColonAtHead) = BS.break (== _colon) decoded
    -- Strip the ':' itself; an empty remainder means no colon was present.
    (_, password) <- BS.uncons passWithColonAtHead
    pure (BasicAuthData username password)

-- | Run and check basic authentication, returning the appropriate HTTP error
-- per the spec.
--
-- * No decodable Basic @Authorization@ header, 'BadPassword' or
--   'NoSuchUser': fail fatally with 401, including a @WWW-Authenticate@
--   challenge for @realm@.
-- * 'Unauthorized': fail fatally with 403 (no challenge header).
-- * 'Authorized': succeed with the authenticated user.
runBasicAuth :: Request -> BS.ByteString -> BasicAuthCheck usr -> DelayedIO usr
runBasicAuth req realm (BasicAuthCheck ba) =
  case decodeBAHdr req of
    Nothing    -> plzAuthenticate
    Just creds -> do
      res <- liftIO (ba creds)
      case res of
        BadPassword    -> plzAuthenticate
        NoSuchUser     -> plzAuthenticate
        Unauthorized   -> delayedFailFatal err403
        Authorized usr -> pure usr
  where
    -- A 401 carrying the challenge header, prompting the client to (re)send
    -- credentials for this realm.
    plzAuthenticate :: DelayedIO a
    plzAuthenticate =
      delayedFailFatal (err401 { errHeaders = [mkBAChallengerHdr realm] })