{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
module Network.Wai.Middleware.SpnegoAuth (
spnegoAuth
, SpnegoAuthSettings(..)
, defaultSpnegoSettings
, spnegoAuthKey
, defaultAuthResponse
, defaultAuthError
) where
import Control.Arrow (second)
import Control.Exception (catch)
import qualified Data.ByteString.Base64 as B64
import qualified Data.ByteString.Char8 as BS
import qualified Data.CaseInsensitive as CI
import Data.Maybe (fromMaybe)
import Data.Monoid ((<>))
import qualified Data.Vault.Lazy as V
import Network.HTTP.Types (status401)
import Network.HTTP.Types.Header (hAuthorization,
hWWWAuthenticate)
import Network.Wai (Application, Middleware,
Request (..),
mapResponseHeaders,
responseLBS, ResponseReceived,
Response)
import Network.Wai.Middleware.HttpAuth (extractBasicAuth)
import System.IO (hPutStrLn, stderr)
import System.IO.Unsafe
import Network.Security.GssApi
import Network.Security.Kerberos
-- | Configuration for 'spnegoAuth'. Start from 'defaultSpnegoSettings' and
-- override individual fields.
data SpnegoAuthSettings = SpnegoAuthSettings {
    spnegoRealm :: Maybe BS.ByteString
    -- ^ Kerberos realm. It is used in four places:
    --
    -- * as the realm of the @Basic@ challenge in 'defaultAuthResponse';
    -- * appended to 'spnegoService' when that names no realm;
    -- * applied to Basic-auth user names (see 'spnegoForceRealm');
    -- * stripped from authenticated user names (see 'spnegoUserFull').
  , spnegoService :: Maybe BS.ByteString
    -- ^ Service principal passed to the GSSAPI check. If it contains no
    -- at-sign, \"\@\" followed by 'spnegoRealm' is appended (when a realm
    -- is set).
    -- NOTE(review): what 'Nothing' means is decided by 'runGssCheck'
    -- (presumably the default acceptor credentials) — confirm.
  , spnegoUserFull :: Bool
    -- ^ When 'False', a principal whose realm equals 'spnegoRealm' is
    -- reported without its realm suffix. When 'True', the full principal
    -- is always kept.
  , spnegoBasicFailback :: Bool
    -- ^ When 'True', 'defaultAuthResponse' advertises a @Basic@ challenge
    -- alongside @Negotiate@.
  , spnegoForceRealm :: Bool
    -- ^ Applies to Basic logins. When 'True', any realm in the supplied
    -- user name is replaced by 'spnegoRealm', or dropped if that is unset.
    -- When 'False', 'spnegoRealm' is only added to names that carry no
    -- realm.
  , spnegoOnAuthError :: SpnegoAuthSettings -> Maybe (Either KrbException GssException) -> Application
    -- ^ Invoked when authentication fails. It receives 'Nothing' when no
    -- usable credentials were sent. Otherwise it receives the Kerberos
    -- ('Left') or GSSAPI ('Right') exception raised by the check.
  , spnegoFakeBasicAuth :: Bool
    -- ^ After a successful SPNEGO login, prepend a synthetic
    -- @Authorization: Basic@ header before calling the wrapped application.
    -- The header carries the user name and the literal password
    -- @password@.
  }
-- | Baseline settings:
--
-- * no explicit service or realm;
-- * Basic fallback advertised;
-- * realm forced for Basic logins;
-- * full principal names not requested;
-- * no fake Basic header;
-- * failures logged to stderr through 'defaultAuthError'.
defaultSpnegoSettings :: SpnegoAuthSettings
defaultSpnegoSettings = SpnegoAuthSettings
  { spnegoService       = Nothing
  , spnegoRealm         = Nothing
  , spnegoForceRealm    = True
  , spnegoBasicFailback = True
  , spnegoUserFull      = False
  , spnegoFakeBasicAuth = False
  , spnegoOnAuthError   = defaultAuthError (hPutStrLn stderr)
  }
-- | Default rejection. It first drains the request body, then answers
-- @401 Unauthorized@ with a @WWW-Authenticate: Negotiate@ challenge.
-- When 'spnegoBasicFailback' is set, a second @Basic@ challenge follows,
-- whose realm is 'spnegoRealm' or @Auth@ if none is configured.
defaultAuthResponse :: SpnegoAuthSettings -> Request -> (Response -> IO ResponseReceived) -> IO ResponseReceived
defaultAuthResponse settings request respond =
  flushRequestBody request >> respond (responseLBS status401 challenges "Unauthorized")
  where
    negotiate = (hWWWAuthenticate, "Negotiate")
    basic = (hWWWAuthenticate, "Basic realm=\"" <> fromMaybe "Auth" (spnegoRealm settings) <> "\"")
    challenges
      | spnegoBasicFailback settings = [negotiate, basic]
      | otherwise = [negotiate]
-- | Read and discard the rest of the request body. It pulls chunks with
-- 'requestBody' until an empty chunk signals the end.
flushRequestBody :: Request -> IO ()
flushRequestBody req = drain
  where
    drain = requestBody req >>= \chunk ->
      if BS.null chunk then return () else drain
-- | Default failure handler. If a Kerberos or GSSAPI exception is present,
-- it describes the exception through the supplied logger. In every case it
-- then replies with 'defaultAuthResponse'.
defaultAuthError :: (String -> IO ()) -> SpnegoAuthSettings -> Maybe (Either KrbException GssException) -> Application
defaultAuthError logerr settings failure req respond = do
  maybe (return ()) (logerr . describeFailure) failure
  defaultAuthResponse settings req respond
  where
    describeFailure (Left (KrbException code err)) =
      "Kerberos error code: " <> show code <> ", error: " <> show err
    describeFailure (Right (GssException major majorTxt minor minorTxt)) =
      "GSSAPI major code: " <> show major <> ", error: " <> show majorTxt
        <> ", minor code: " <> show minor <> ", error: " <> show minorTxt
-- | Vault key under which 'spnegoAuth' stores the authenticated user name.
-- Downstream applications read the name from the request vault with this
-- key.
--
-- The key is a top-level value created with 'unsafePerformIO' the first
-- time it is evaluated. The NOINLINE pragma stops GHC from inlining the
-- definition. Inlining could create several distinct keys, and then values
-- inserted with one copy would be invisible to lookups made with another.
spnegoAuthKey :: V.Key BS.ByteString
spnegoAuthKey = unsafePerformIO V.newKey
{-# NOINLINE spnegoAuthKey #-}
-- | WAI middleware that authenticates requests with SPNEGO
-- (@Authorization: Negotiate@). When 'spnegoBasicFailback' is set, it also
-- accepts HTTP Basic credentials and verifies them through Kerberos.
--
-- On success, the authenticated user name (see 'spnegoUserFull') is stored
-- in the request vault under 'spnegoAuthKey' and the wrapped application
-- runs. Otherwise 'spnegoOnAuthError' is called with one of:
--
-- * 'Nothing', when no usable credentials were sent;
-- * the 'KrbException' or 'GssException' raised by the check.
--
-- Only the credential check is guarded by 'catch'. The wrapped application
-- and the error handler both run outside the exception handler. As a
-- result, exceptions thrown by the application propagate unchanged instead
-- of producing a second, 401 response. The error handler also does not run
-- with asynchronous exceptions masked.
spnegoAuth :: SpnegoAuthSettings -> Middleware
spnegoAuth settings@SpnegoAuthSettings{..} iapp req respond =
  case lookup hAuthorization (requestHeaders req) of
    Just val
      | Just token <- getSpnegoToken val ->
          runSpnegoCheck token
      -- Basic credentials are honoured only when the Basic fallback is
      -- enabled, consistent with the challenge 'defaultAuthResponse' sends.
      | spnegoBasicFailback
      , Just (user, password) <- extractBasicAuth val ->
          runKerberosCheck user password
    _ -> spnegoOnAuthError settings Nothing req respond
  where
    -- Record the (possibly realm-stripped) user name in the request vault.
    insertUserToVault user myreq =
      myreq { vault = V.insert spnegoAuthKey (stripSpnegoRealm user) (vault myreq) }

    -- Optionally prepend a synthetic Basic header (literal password
    -- "password") for downstream applications that expect Basic auth.
    fakeAuth user myreq
      | spnegoFakeBasicAuth =
          let fakeHeader = (hAuthorization, "Basic " <> B64.encode (stripSpnegoRealm user <> ":password"))
          in myreq { requestHeaders = fakeHeader : requestHeaders myreq }
      | otherwise = myreq

    -- Apply the realm policy ('spnegoForceRealm' / 'spnegoRealm') to a
    -- user name supplied through Basic auth.
    modifyKrbUser origUser
      | spnegoForceRealm = user <> maybe "" ("@" <>) spnegoRealm
      | BS.null realm, Just newrealm <- spnegoRealm = user <> "@" <> newrealm
      | otherwise = origUser
      where
        (user, realm) = splitPrincipal origUser

    -- Basic auth: resolve the principal and verify the password with
    -- Kerberos. Only these two calls are inside 'catch'.
    runKerberosCheck origuser password = do
      outcome <- (Right <$> kerberosLogin) `catch` (return . Left)
      case outcome of
        Left exc -> spnegoOnAuthError settings (Just (Left exc)) req respond
        Right user -> iapp (insertUserToVault user req) respond
      where
        kerberosLogin = do
          user <- krb5Resolve (modifyKrbUser origuser)
          krb5Login user password
          return user

    -- SPNEGO: validate the token with GSSAPI (only this call is inside
    -- 'catch'). Then run the application, adding the server's output token
    -- to the response headers.
    runSpnegoCheck token = do
      outcome <- (Right <$> runGssCheck service token) `catch` (return . Left)
      case outcome of
        Left exc -> spnegoOnAuthError settings (Just (Right exc)) req respond
        Right (user, output) ->
          let neghdr = (hWWWAuthenticate, "Negotiate " <> B64.encode output)
          in iapp (fakeAuth user (insertUserToVault user req)) (respond . mapResponseHeaders (neghdr :))
      where
        -- Qualify the service principal with the configured realm unless
        -- it already names one.
        service
          | (BS.elem '@' <$> spnegoService) == Just True = spnegoService
          | otherwise = (<> maybe "" ("@" <>) spnegoRealm) <$> spnegoService

    -- Drop the realm from a principal when it equals 'spnegoRealm', unless
    -- 'spnegoUserFull' asks for the full name.
    stripSpnegoRealm user
      | not spnegoUserFull
      , (clservice, clrealm) <- splitPrincipal user
      , Just clrealm == spnegoRealm = clservice
      | otherwise = user
-- | Extract and Base64-decode the token from an @Authorization@ header
-- value of the form @Negotiate <token>@.
--
-- The scheme name is compared case-insensitively. Every space between the
-- scheme and the token is skipped, because RFC 7235 allows one or more
-- spaces there; @"Negotiate  YWJj"@ therefore decodes as well. Returns
-- 'Nothing' for any other scheme or when the token is not valid Base64.
getSpnegoToken :: BS.ByteString -> Maybe BS.ByteString
getSpnegoToken val
  | CI.mk scheme == "negotiate" =
      either (const Nothing) Just (B64.decode (BS.dropWhile (== ' ') rest))
  | otherwise = Nothing
  where
    (scheme, rest) = BS.break (== ' ') val
-- | Split a Kerberos principal at its first at-sign into the name and the
-- realm. The at-sign itself is dropped. A principal with no at-sign yields
-- an empty realm.
splitPrincipal :: BS.ByteString -> (BS.ByteString, BS.ByteString)
splitPrincipal principal = (name, BS.drop 1 atAndRealm)
  where
    (name, atAndRealm) = BS.break (== '@') principal