{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveFunctor      #-}
{-# LANGUAGE DeriveGeneric      #-}
{-# LANGUAGE OverloadedStrings  #-}

module Servant.Server.Internal.BasicAuth where

import           Control.Monad
                 (guard)
import           Control.Monad.Trans
                 (liftIO)
import qualified Data.ByteString                     as BS
import           Data.ByteString.Base64
                 (decodeLenient)
import           Data.Typeable
                 (Typeable)
import           Data.Word8
                 (isSpace, toLower, _colon)
import           GHC.Generics
import           Network.HTTP.Types
                 (Header)
import           Network.Wai
                 (Request, requestHeaders)

import           Servant.API.BasicAuth
                 (BasicAuthData (BasicAuthData))
import           Servant.Server.Internal.DelayedIO
import           Servant.Server.Internal.ServerError

-- * Basic Auth

-- NOTE: servant-server's current implementation of basic authentication is
-- not resistant to timing attacks: decoding the payload does not take a
-- constant amount of time.

-- | The result of authentication/authorization, as returned by the function
-- wrapped in a 'BasicAuthCheck'.
--
-- 'runBasicAuth' turns each constructor into a response; see the
-- per-constructor notes below.
data BasicAuthResult usr
  = Unauthorized    -- ^ Rejected; 'runBasicAuth' answers with @403@.
  | BadPassword     -- ^ Wrong password; 'runBasicAuth' re-challenges with @401@.
  | NoSuchUser      -- ^ Unknown user; 'runBasicAuth' re-challenges with @401@.
  | Authorized usr  -- ^ Accepted; the user value is handed to the handler.
  deriving (Eq, Show, Read, Generic, Typeable, Functor)

-- | Datatype wrapping a function used to check authentication.
--
-- The wrapped function receives the 'BasicAuthData' decoded from the
-- request's @Authorization@ header. It decides, in 'IO', which
-- 'BasicAuthResult' applies.
newtype BasicAuthCheck usr = BasicAuthCheck
  { unBasicAuthCheck :: BasicAuthData
                     -> IO (BasicAuthResult usr)
  }
  deriving (Generic, Typeable, Functor)

-- | Internal method to make a basic-auth challenge.
--
-- Builds the @WWW-Authenticate@ header whose value is
-- @Basic realm="\<realm\>"@.
--
-- NOTE(review): the realm is spliced in verbatim. A realm containing a
-- double-quote or backslash character would produce a malformed
-- quoted-string.
mkBAChallengerHdr :: BS.ByteString -> Header
mkBAChallengerHdr realm =
  ("WWW-Authenticate", "Basic realm=\"" <> realm <> "\"")

-- | Find and decode an 'Authorization' header from the request as Basic Auth.
--
-- Returns 'Nothing' in three cases:
--
-- * the request has no @Authorization@ header;
-- * the scheme (the text before the first whitespace byte), once
--   lower-cased, is not @basic@;
-- * the decoded credentials contain no @:@ separator.
--
-- The username is everything before the first @:@. The password is
-- everything after it, so the password may itself contain colons.
--
-- The credentials are decoded with 'decodeLenient', so invalid base64 input
-- is not reported as an error here.
decodeBAHdr :: Request -> Maybe BasicAuthData
decodeBAHdr req = do
    authHdr <- lookup "Authorization" (requestHeaders req)
    -- Split "<scheme> <credentials>" at the first whitespace byte.
    let (scheme, rest) = BS.break isSpace authHdr
    guard (BS.map toLower scheme == "basic")
    let decoded = decodeLenient (BS.dropWhile isSpace rest)
    -- Split "user:password" at the first colon. The colon stays at the head
    -- of the second half. 'BS.uncons' then drops it, and yields 'Nothing'
    -- when no colon was present at all.
    let (username, passWithColonAtHead) = BS.break (== _colon) decoded
    (_, password) <- BS.uncons passWithColonAtHead
    pure (BasicAuthData username password)

-- | Run and check basic authentication, returning the appropriate http error per
-- the spec.
--
-- * No decodable Basic @Authorization@ header, 'BadPassword' or
--   'NoSuchUser': fail with @401@ plus a @WWW-Authenticate@ challenge for
--   @realm@.
-- * 'Unauthorized': fail with @403@.
-- * 'Authorized': succeed with the wrapped user.
--
-- Both failure paths are raised with 'delayedFailFatal'.
runBasicAuth :: Request -> BS.ByteString -> BasicAuthCheck usr -> DelayedIO usr
runBasicAuth req realm (BasicAuthCheck ba) =
  case decodeBAHdr req of
    Nothing    -> plzAuthenticate
    Just creds -> liftIO (ba creds) >>= \res -> case res of
      BadPassword    -> plzAuthenticate
      NoSuchUser     -> plzAuthenticate
      Unauthorized   -> delayedFailFatal err403
      Authorized usr -> return usr
  where
    -- 401 carrying the Basic challenge, so the client may retry with
    -- (different) credentials.
    plzAuthenticate :: DelayedIO a
    plzAuthenticate =
      delayedFailFatal (err401 { errHeaders = [mkBAChallengerHdr realm] })