{-# LANGUAGE CPP #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TupleSections #-}

-- | Implements HTTP Basic Authentication.
--
-- This module may add digest authentication in the future.
module Network.Wai.Middleware.HttpAuth (
    -- * Middleware
    basicAuth,
    basicAuth',
    CheckCreds,
    AuthSettings,
    authRealm,
    authOnNoAuth,
    authIsProtected,

    -- * Helping functions
    extractBasicAuth,
    extractBearerAuth,
) where

#if __GLASGOW_HASKELL__ < 710
import Control.Applicative
#endif
import Data.ByteString (ByteString)
import qualified Data.ByteString as S
import Data.ByteString.Base64 (decodeLenient)
import Data.String (IsString (..))
import Data.Word8 (isSpace, toLower, _backslash, _colon, _quotedbl)
import Network.HTTP.Types (hAuthorization, hContentType, status401)
import Network.Wai (
    Application,
    Middleware,
    Request (requestHeaders),
    responseLBS,
 )

-- | Credential validator: receives the username first and the password
-- second, and answers in 'IO' whether that pair should be granted access.
type CheckCreds = ByteString -> ByteString -> IO Bool

-- | Perform basic authentication.
--
-- The credential check only sees the username and password; use
-- 'basicAuth'' if it also needs the incoming 'Request'.
--
-- > basicAuth (\u p -> return $ u == "michael" && p == "mypass") "My Realm"
--
-- @since 1.3.4
basicAuth
    :: CheckCreds
    -> AuthSettings
    -> Middleware
-- The request-aware variant, with a check that ignores the request.
basicAuth = basicAuth' . const

-- | Like 'basicAuth', but also passes a request to the authentication function.
--
-- For every request, 'authIsProtected' is consulted first:
--
-- * unprotected requests are forwarded to the wrapped application untouched;
--
-- * protected requests are forwarded only when the first @Authorization@
--   header parses as Basic credentials (see 'extractBasicAuth') and the
--   supplied check accepts them. Otherwise 'authOnNoAuth' is called with
--   'authRealm'. This covers a missing or malformed header as well as
--   rejected credentials.
--
-- @since 3.0.19
basicAuth'
    :: (Request -> CheckCreds)
    -> AuthSettings
    -> Middleware
basicAuth' checkCreds AuthSettings{..} app req sendResponse = do
    isProtected <- authIsProtected req
    -- Credentials (and the user's check) are only consulted when needed.
    allowed <- if isProtected then credentialsAccepted else return True
    if allowed
        then app req sendResponse
        else authOnNoAuth authRealm req sendResponse
  where
    -- A missing or unparseable Authorization header counts as a rejection.
    credentialsAccepted :: IO Bool
    credentialsAccepted =
        maybe (return False) (uncurry (checkCreds req)) $
            lookup hAuthorization (requestHeaders req) >>= extractBasicAuth

-- | Basic authentication settings. This value is an instance of
-- @IsString@, so the recommended way to create one is to write a string
-- literal (which becomes the realm) and then override individual fields.
--
-- > "My Realm" { authIsProtected = someFunc } :: AuthSettings
--
-- @since 1.3.4
data AuthSettings = AuthSettings
    { authRealm :: !ByteString
    -- ^ Realm handed to 'authOnNoAuth' when a protected request is not
    -- authenticated. The default handler advertises it in the
    -- @WWW-Authenticate@ challenge.
    --
    -- @since 1.3.4
    , authOnNoAuth :: !(ByteString -> Application)
    -- ^ Takes the realm and returns an appropriate response for a protected
    -- request without acceptable credentials. This covers a missing or
    -- malformed @Authorization@ header as well as rejected credentials.
    --
    -- Default: a @401@ plain-text response with a Basic
    -- @WWW-Authenticate@ challenge.
    --
    -- @since 1.3.4
    , authIsProtected :: !(Request -> IO Bool)
    -- ^ Determine if access to the requested resource is restricted.
    --
    -- Default: always returns @True@.
    --
    -- @since 1.3.4
    }

-- | The string literal becomes the realm. Unauthenticated protected requests
-- receive a @401@ plain-text response with a Basic challenge, and every
-- request is treated as protected.
instance IsString AuthSettings where
    fromString s =
        AuthSettings
            { authRealm = fromString s
            , authOnNoAuth = \realm _req respond ->
                respond $
                    responseLBS
                        status401
                        [ (hContentType, "text/plain")
                        , ("WWW-Authenticate", basicChallenge realm)
                        ]
                        "Basic authentication is required"
            , authIsProtected = const $ return True
            }
      where
        -- The realm is sent as an HTTP quoted-string. Any double quote or
        -- backslash inside it must be backslash-escaped, otherwise the
        -- challenge header is malformed.
        basicChallenge realm =
            S.concat ["Basic realm=\"", S.concatMap escapeQuoted realm, "\""]
        escapeQuoted w
            | w == _quotedbl || w == _backslash = S.pack [_backslash, w]
            | otherwise = S.singleton w

-- | Extract basic authentication data from an __Authorization__ header
-- value (usually). Returns the username and password.
--
-- The scheme name is matched case-insensitively and may be followed by any
-- amount of whitespace. The rest is base64-decoded with 'decodeLenient' and
-- split at the first colon, so the password itself may contain colons.
-- Returns 'Nothing' when the scheme is not @Basic@ or when the decoded
-- value contains no colon.
--
-- @since 3.0.5
extractBasicAuth :: ByteString -> Maybe (ByteString, ByteString)
extractBasicAuth bs
    | S.map toLower scheme == "basic" =
        splitCredentials (decodeLenient (S.dropWhile isSpace rest))
    | otherwise = Nothing
  where
    (scheme, rest) = S.break isSpace bs

    -- "user:pass" -> Just ("user", "pass"); no colon at all -> Nothing.
    splitCredentials raw =
        let (username, colonAndPassword) = S.break (== _colon) raw
         in (username,) . snd <$> S.uncons colonAndPassword

-- | Extract bearer authentication data from an __Authorization__ header
-- value. Returns the bearer token.
--
-- The scheme name is matched case-insensitively and the whitespace after it
-- is skipped. Everything that follows is returned unchanged as the token.
-- Note that the token can be empty (for example for a bare @Bearer@
-- header), so a 'Just' result alone does not prove a token was sent.
--
-- @since 3.0.5
extractBearerAuth :: ByteString -> Maybe ByteString
extractBearerAuth bs
    | S.map toLower scheme == "bearer" = Just (S.dropWhile isSpace rest)
    | otherwise = Nothing
  where
    (scheme, rest) = S.break isSpace bs