module Network.Wai.Middleware.SpnegoAuth (
spnegoAuth
, SpnegoAuthSettings(..)
, defaultSpnegoSettings
, spnegoAuthKey
) where
import Control.Arrow (second)
import Control.Exception (catch, try)
import Data.Maybe (fromMaybe)
import Data.Monoid ((<>))
import System.IO.Unsafe

import qualified Data.ByteString.Base64 as B64
import qualified Data.ByteString.Char8 as BS
import qualified Data.CaseInsensitive as CI
import qualified Data.Vault.Lazy as V
import Network.HTTP.Types (status401)
import Network.HTTP.Types.Header (hAuthorization,
                                  hWWWAuthenticate)
import Network.Wai (Application, Middleware,
                    Request (..),
                    mapResponseHeaders,
                    responseLBS)
import Network.Wai.Middleware.HttpAuth (extractBasicAuth)

import Network.Security.GssApi
import Network.Security.Kerberos
-- | Configuration for 'spnegoAuth'.
data SpnegoAuthSettings = SpnegoAuthSettings {
    spnegoRealm :: Maybe BS.ByteString
    -- ^ Realm used to qualify user and service principal names that lack one.
    -- It is stripped from authenticated user names (unless 'spnegoUserFull').
    -- The default error handler also uses it as the realm of its @Basic@
    -- challenge.
    , spnegoService :: Maybe BS.ByteString
    -- ^ Service principal passed to the GSSAPI check. When it contains no
    -- @\'\@\'@, it is suffixed with @\@@ and 'spnegoRealm' (if set).
    , spnegoUserFull :: Bool
    -- ^ When 'False', the realm is removed from the stored user name if it
    -- equals 'spnegoRealm'. When 'True', the full principal is kept.
    , spnegoBasicFailback :: Bool
    -- ^ Whether HTTP Basic fallback is offered. The default error handler
    -- adds a @Basic@ challenge to its 401 response when this is set.
    , spnegoForceRealm :: Bool
    -- ^ For Basic-auth logins, replace any realm in the supplied user name
    -- with 'spnegoRealm', or drop it when no realm is configured.
    , spnegoOnAuthError :: SpnegoAuthSettings -> Maybe (Either KrbException GssException) -> Application
    -- ^ Application run when authentication fails. It receives the Kerberos
    -- ('Left') or GSSAPI ('Right') exception, or 'Nothing' when the request
    -- has no usable @Authorization@ header.
    , spnegoFakeBasicAuth :: Bool
    -- ^ After a successful SPNEGO login, prepend a synthetic
    -- @Authorization: Basic@ header. It carries the user name and a
    -- placeholder password.
    }
-- | Default settings for 'spnegoAuth':
--
-- * no realm or service configured;
-- * Basic fallback enabled;
-- * realm forcing enabled;
-- * full user names disabled;
-- * no fake Basic header.
--
-- The default error handler prints any Kerberos/GSSAPI exception to stdout.
-- It then answers @401 Unauthorized@ with a @WWW-Authenticate: Negotiate@
-- challenge. When 'spnegoBasicFailback' is set, it also sends a @Basic@
-- challenge whose realm is 'spnegoRealm', or @\"Auth\"@ when that is unset.
defaultSpnegoSettings :: SpnegoAuthSettings
defaultSpnegoSettings =
  SpnegoAuthSettings
    { spnegoRealm = Nothing
    , spnegoService = Nothing
    , spnegoUserFull = False
    , spnegoBasicFailback = True
    , spnegoForceRealm = True
    , spnegoOnAuthError = logAndReject
    , spnegoFakeBasicAuth = False
    }
  where
    -- Report the failure cause (when there is one), then send the challenge.
    logAndReject cfg failure _ respond = do
      mapM_ (putStrLn . describeFailure) failure
      respond (unauthorized cfg)

    describeFailure (Left (KrbException _ err)) = "Kerberos error: " <> show err
    describeFailure (Right (GssException _ err)) = "GSSAPI error: " <> show err

    unauthorized cfg = responseLBS status401 (challenges cfg) "Unauthorized"

    -- Negotiate is always offered; Basic only when fallback is enabled.
    challenges cfg
      | spnegoBasicFailback cfg =
          [ (hWWWAuthenticate, "Negotiate")
          , (hWWWAuthenticate, "Basic realm=\"" <> fromMaybe "Auth" (spnegoRealm cfg) <> "\"")
          ]
      | otherwise = [(hWWWAuthenticate, "Negotiate")]
-- | Vault key under which 'spnegoAuth' stores the authenticated user name.
-- The realm is removed unless 'spnegoUserFull' is set.
--
-- The key is a process-wide value created with 'unsafePerformIO'. The
-- @NOINLINE@ pragma is required: without it GHC may inline the definition
-- and create a different key at each use site. Values inserted by the
-- middleware could then not be looked up by the application.
spnegoAuthKey :: V.Key BS.ByteString
spnegoAuthKey = unsafePerformIO V.newKey
{-# NOINLINE spnegoAuthKey #-}
-- | WAI middleware that authenticates requests with SPNEGO (HTTP
-- @Negotiate@). When 'spnegoBasicFailback' is set, it also accepts HTTP
-- Basic credentials, which are checked against Kerberos.
--
-- On success, the user name is stored in the request vault under
-- 'spnegoAuthKey' (realm handling is described for 'spnegoUserFull'), and
-- the wrapped application runs. Otherwise 'spnegoOnAuthError' is called.
-- It receives the Kerberos ('Left') or GSSAPI ('Right') exception, or
-- 'Nothing' when the request carries no usable @Authorization@ header.
--
-- Only the authentication step itself is guarded. Exceptions raised by the
-- wrapped application propagate unchanged, rather than producing a second
-- (401) response after the application may already have responded.
spnegoAuth :: SpnegoAuthSettings -> Middleware
spnegoAuth settings@SpnegoAuthSettings{..} iapp req respond =
  case lookup hAuthorization (requestHeaders req) of
    Just val
      | Just token <- getSpnegoToken val -> do
          result <- try (runGssCheck service token)
          case result of
            Left exc -> authFailed (Just (Right exc))
            Right (user, output) -> do
              -- Hand the server's GSSAPI output token back to the client.
              let neghdr = (hWWWAuthenticate, "Negotiate " <> B64.encode output)
              iapp (fakeAuth user $ insertUserToVault user req)
                   (respond . mapResponseHeaders (neghdr :))
      -- Basic credentials are honoured only when Basic fallback is enabled.
      -- Otherwise the request falls through to the error handler below.
      | spnegoBasicFailback
      , Just (user, password) <- extractBasicAuth val -> do
          result <- try (kerberosLogin user password)
          case result of
            Left exc -> authFailed (Just (Left exc))
            Right principal -> iapp (insertUserToVault principal req) respond
    _ -> authFailed Nothing
  where
    -- Delegate a failed or missing authentication to the configured handler.
    authFailed err = spnegoOnAuthError settings err req respond

    -- Service principal for the GSSAPI check. It is used verbatim when it
    -- already contains an '@'; otherwise it is suffixed with "@<spnegoRealm>"
    -- (if a realm is set).
    service
      | (BS.elem '@' <$> spnegoService) == Just True = spnegoService
      | otherwise = (<> fromMaybe "" (("@" <>) <$> spnegoRealm)) <$> spnegoService

    -- Resolve the realm-adjusted user name and verify the password against
    -- Kerberos. Yields the resolved principal name.
    kerberosLogin origuser password = do
      user <- krb5Resolve (modifyKrbUser origuser)
      _ <- krb5Login user password
      pure user

    -- Store the (possibly realm-stripped) user name in the request vault.
    insertUserToVault user myreq = myreq{vault = vault'}
      where
        vault' = V.insert spnegoAuthKey (stripSpnegoRealm user) (vault myreq)

    -- With 'spnegoFakeBasicAuth', prepend a synthetic Basic Authorization
    -- header carrying the user name and the placeholder password "password".
    fakeAuth user myreq
      | spnegoFakeBasicAuth =
          let oldHeaders = requestHeaders myreq
              fakeHeader = (hAuthorization, "Basic " <> B64.encode (user <> ":password"))
          in myreq{requestHeaders = fakeHeader : oldHeaders}
      | otherwise = myreq

    -- Adjust the realm of a Basic-auth user name before Kerberos login.
    -- With 'spnegoForceRealm', replace any realm with 'spnegoRealm' (or drop
    -- it when none is configured). Otherwise add 'spnegoRealm' only when the
    -- name has no realm.
    modifyKrbUser orig_user
      | spnegoForceRealm = user <> fromMaybe "" (("@" <>) <$> spnegoRealm)
      | BS.null realm, Just newrealm <- spnegoRealm = user <> "@" <> newrealm
      | otherwise = orig_user
      where
        (user, realm) = splitPrincipal orig_user

    -- Drop the realm from an authenticated name when it equals
    -- 'spnegoRealm', unless 'spnegoUserFull' asks for the full principal.
    stripSpnegoRealm user
      | not spnegoUserFull, (clservice, clrealm) <- splitPrincipal user,
        Just clrealm == spnegoRealm = clservice
      | otherwise = user
-- | Extract and base64-decode the GSSAPI token from an @Authorization@
-- header value of the form @Negotiate <token>@.
--
-- The scheme name is compared case-insensitively. RFC 7235 allows one or
-- more spaces between the auth-scheme and its credentials, so all leading
-- spaces before the token are skipped, not just one.
--
-- Returns 'Nothing' when the scheme is not @Negotiate@ or the token is not
-- valid base64.
getSpnegoToken :: BS.ByteString -> Maybe BS.ByteString
getSpnegoToken val
  | CI.mk scheme == "negotiate" =
      either (const Nothing) Just (B64.decode (BS.dropWhile (== ' ') rest))
  | otherwise = Nothing
  where
    (scheme, rest) = BS.break (== ' ') val
-- | Split a Kerberos principal @name\@REALM@ at its first @\'\@\'@ into the
-- name and the realm. The realm is empty when the principal contains no
-- @\'\@\'@; any further @\'\@\'@ characters stay in the realm part.
splitPrincipal :: BS.ByteString -> (BS.ByteString, BS.ByteString)
splitPrincipal principal = (name, BS.drop 1 atRealm)
  where
    (name, atRealm) = BS.break (== '@') principal