{-# LANGUAGE OverloadedStrings #-}
module Web.Simple.Auth
( AuthRouter
, basicAuthRoute, basicAuth, authRewriteReq
) where
import Control.Monad
import Data.Bits (xor, (.|.))
import Data.Char (isSpace, toLower)
import Data.List (foldl')
import Data.Maybe

import Data.ByteString.Base64
import qualified Data.ByteString.Char8 as S8
import Network.HTTP.Types
import Network.Wai

import Web.Simple.Controller
import Web.Simple.Responses
-- | A router that guards a controller behind authentication.
--
-- The first argument is a credential check. It receives the current
-- 'Request', a username and a password, and yields either 'Nothing' or the
-- 'Request' that the protected controller should be run with. In this
-- module, 'basicAuthRoute' treats 'Nothing' as a rejection. The second
-- argument is the controller being protected.
type AuthRouter r a =
     (  Request
     -> S8.ByteString
     -> S8.ByteString
     -> Controller r (Maybe Request)
     )
  -> Controller r a
  -> Controller r a
-- | An 'AuthRouter' that authenticates requests in @realm@ using HTTP Basic
-- authentication (RFC 7617).
--
-- The @Authorization@ header is parsed with 'parseBasicCredentials'. If the
-- header is missing or malformed, or if @testAuth@ returns 'Nothing', the
-- controller calls 'respond' with @'requireBasicAuth' realm@. Otherwise
-- @next@ is run with the request returned by @testAuth@.
basicAuthRoute :: String -> AuthRouter r a
basicAuthRoute realm testAuth next = do
  req <- request
  let authStr = fromMaybe "" $ lookup hAuthorization (requestHeaders req)
  case parseBasicCredentials authStr of
    Just (user, pwd) -> do
      mfin <- testAuth req user pwd
      maybe requireAuth (\finReq -> localRequest (const finReq) next) mfin
    Nothing -> requireAuth
  where
    requireAuth :: Controller s b
    requireAuth = respond $ requireBasicAuth realm

-- | Extract @(user, password)@ from an @Authorization@ header value of the
-- form @Basic <base64(user:password)>@.
--
-- The scheme name is matched case-insensitively, and any whitespace between
-- the scheme and the token is skipped. Returns 'Nothing' when:
--
-- * the scheme is not @basic@,
-- * the token is not valid base64, or
-- * the decoded credentials contain no @:@.
parseBasicCredentials :: S8.ByteString -> Maybe (S8.ByteString, S8.ByteString)
parseBasicCredentials authStr = do
  let (scheme, rest) = S8.break isSpace authStr
  guard (S8.map toLower scheme == "basic")
  decoded <- either (const Nothing) Just $ decode $ S8.dropWhile isSpace rest
  -- Only the FIRST ':' separates user from password. RFC 7617 forbids ':'
  -- in the user-id but allows it in the password, so splitting on every
  -- colon would wrongly reject such passwords.
  let (user, colonPwd) = S8.break (== ':') decoded
  (_, pwd) <- S8.uncons colonPwd
  return (user, pwd)
-- | Adapt a plain credential check into the callback an 'AuthRouter' expects.
--
-- @testAuth user pwd@ decides whether the credentials are accepted.
--
-- * On success, the request handed back to @authRouter@ has an @X-User@
--   header carrying @user@ prepended to its headers.
-- * On failure, 'Nothing' is handed back; what happens then is up to
--   @authRouter@.
--
-- Any @X-User@ headers already present (i.e. sent by the client) are
-- removed first. Otherwise a spoofed value would remain in the header list
-- next to the authenticated one, and code that scans every header rather
-- than using 'lookup' could pick it up.
authRewriteReq :: AuthRouter r a
               -> (S8.ByteString -> S8.ByteString -> Controller r Bool)
               -> Controller r a
               -> Controller r a
authRewriteReq authRouter testAuth rt = authRouter check rt
  where
    check req user pwd = do
      success <- testAuth user pwd
      return $ if success then Just (withUser req user) else Nothing

    -- HeaderName is case-insensitive, so this also drops "x-user", etc.
    withUser req user =
      req { requestHeaders =
              ("X-User", user)
                : filter ((/= "X-User") . fst) (requestHeaders req) }
-- | Protect a controller with HTTP Basic authentication for @realm@.
--
-- Exactly one username/password pair is accepted. On success the request
-- reaches @next@ with an @X-User@ header holding the username (see
-- 'authRewriteReq').
--
-- The comparison does not stop at the first differing byte. It also does not
-- skip the password check when the username is wrong, so response time
-- reveals less about a guess. Input lengths can still be observed.
basicAuth :: String
          -> S8.ByteString
          -> S8.ByteString
          -> Controller r a -> Controller r a
basicAuth realm user pwd = authRewriteReq (basicAuthRoute realm) matches
  where
    -- '.|.' is strict in both operands on Int, unlike '&&', so both
    -- comparisons always run.
    matches u p = return $ (diffBytes u user .|. diffBytes p pwd) == 0

-- | Returns 0 iff the two strings are equal.
--
-- Unlike '==' on 'S8.ByteString', it walks every position of the shorter
-- input instead of returning at the first mismatch.
diffBytes :: S8.ByteString -> S8.ByteString -> Int
diffBytes a b =
  foldl' (.|.) lengthDiff (S8.zipWith (\x y -> fromEnum x `xor` fromEnum y) a b)
  where
    lengthDiff = fromEnum (S8.length a /= S8.length b)