{-# LANGUAGE FlexibleContexts #-}
-- | Support for basic access authentication <http://en.wikipedia.org/wiki/Basic_access_authentication>
module Happstack.Server.Auth where

import Control.Monad                             (MonadPlus(mzero, mplus))
import Data.Bits (xor, (.|.))
import Data.Char (toLower)
import Data.Foldable (foldl')
import Data.Maybe (fromMaybe)
import Data.ByteString.Base64                    as Base64
import qualified Data.ByteString                 as BS
import qualified Data.ByteString.Char8           as B
import qualified Data.Map                        as M
import Happstack.Server.Monads                   (Happstack, escape, getHeaderM, setHeaderM)
import Happstack.Server.Response                 (unauthorized, toResponse)

-- | A simple HTTP basic authentication guard.
--
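-- If authentication fails, the guarded part is not run. Instead the
-- request is aborted (via 'escape') with a 401 Unauthorized response
-- that carries a @WWW-Authenticate@ header for the given realm.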
-- 
-- example:
--
-- > main = simpleHTTP nullConf $ 
-- >  msum [ basicAuth "127.0.0.1" (fromList [("happstack","rocks")]) $ ok "You are in the secret club"
-- >       , ok "You are not in the secret club." 
-- >       ]
-- 
basicAuth :: (Happstack m) =>
   String -- ^ the realm name
   -> M.Map String String -- ^ the username password map
   -> m a -- ^ the part to guard
   -> m a
basicAuth realmName authMap guarded =
    basicAuthBy checkCredentials realmName guarded
  where
    -- Plain-text lookup of the supplied username/password pair in the map.
    checkCredentials = validLoginPlaintext authMap


-- | Generalized version of 'basicAuth'.
--
-- The function that checks the username password combination must be
-- supplied as first argument.
--
-- example:
--
-- > main = simpleHTTP nullConf $
-- >  msum [ basicAuthBy (validLoginPlaintext (fromList [("happstack","rocks")])) "127.0.0.1" $ ok "You are in the secret club"
-- >       , ok "You are not in the secret club."
-- >       ]
--
basicAuthBy :: (Happstack m) =>
   (B.ByteString -> B.ByteString -> Bool) -- ^ function that returns true if the name password combination is valid
   -> String -- ^ the realm name
   -> m a -- ^ the part to guard
   -> m a
basicAuthBy validLogin realmName xs = basicAuthImpl `mplus` xs
  where
    -- On valid credentials this deliberately yields 'mzero', so the
    -- surrounding 'mplus' falls through to the guarded part @xs@.
    -- On any failure it aborts the request with a 401 via 'err'.
    basicAuthImpl = do
        aHeader <- getHeaderM "authorization"
        case aHeader of
            Nothing -> err
            Just x -> do
                (name, password) <- parseHeader x
                -- 'B.break' leaves the ':' separator at the front of
                -- @password@. Match it totally instead of guarding the
                -- partial B.head / B.tail with a length check.
                case B.uncons password of
                    Just (':', secret) | validLogin name secret -> mzero
                    _                                           -> err

    -- Split a header of the form "Basic <base64(user:pass)>" into
    -- (user, ":pass"). The scheme name is compared case-insensitively
    -- (RFC 7617). Any other scheme, or undecodable base64, is rejected.
    parseHeader h
        | B.map toLower (B.take 6 h) /= B.pack "basic " = err
        | otherwise =
            case Base64.decode (B.drop 6 h) of
                Left _   -> err
                Right bs -> return (B.break (== ':') bs)

    headerName  = "WWW-Authenticate"
    headerValue = "Basic realm=\"" ++ realmName ++ "\""

    -- Abort with a 401 response that asks the client for Basic credentials
    -- for this realm. The explicit signature keeps it polymorphic in the
    -- result type, because it is used both as the final result and as a
    -- parse failure.
    err :: (Happstack m) => m a
    err = escape $ do
            setHeaderM headerName headerValue
            unauthorized $ toResponse "Not authorized"


-- | Function that looks up the plain text password for username in a
-- Map and returns True if it matches with the given password.
--
-- Note: The implementation is hardened against timing attacks but not
-- completely safe. Ideally you should build your own predicate, using
-- a robust constant-time equality comparison from a cryptographic
-- library like sodium.
validLoginPlaintext ::
  M.Map String String -- ^ the username password map
  -> B.ByteString -- ^ the username
  -> B.ByteString -- ^ the password
  -> Bool
validLoginPlaintext authMap name password =
    fromMaybe False (matches <$> M.lookup (B.unpack name) authMap)
  where
    -- Compare the stored password, converted with Char8's 'B.pack', against
    -- the supplied bytes. NOTE(review): 'B.pack' keeps only the low 8 bits
    -- of each Char, so non-Latin-1 stored passwords are not UTF-8 encoded.
    -- Likewise 'B.unpack' turns each username byte into one Char.
    matches stored = constTimeEq (B.pack stored) password

    -- (Mostly) constant-time equality: every byte pair is XOR-ed and OR-ed
    -- into one accumulator, so the loop does not exit at the first
    -- mismatch. A length mismatch still returns early, so the configured
    -- password's length can leak through timing. This relies on GHC not
    -- unrolling or vectorizing the fold.
    {-# NOINLINE constTimeEq #-}
    constTimeEq :: BS.ByteString -> BS.ByteString -> Bool
    constTimeEq lhs rhs =
        BS.length lhs == BS.length rhs && accumulated == 0
      where
        accumulated = foldl' (.|.) 0 (BS.zipWith xor lhs rhs)