module Network.Wai.Middleware.SpnegoAuth (
spnegoAuth
, SpnegoAuthSettings(..)
, defaultSpnegoSettings
, spnegoAuthKey
, defaultAuthResponse
, defaultAuthError
) where
import Control.Arrow (second)
import Control.Exception (catch)
import qualified Data.ByteString.Base64 as B64
import qualified Data.ByteString.Char8 as BS
import qualified Data.CaseInsensitive as CI
import Data.Maybe (fromMaybe)
import Data.Monoid ((<>))
import qualified Data.Vault.Lazy as V
import Network.HTTP.Types (status401)
import Network.HTTP.Types.Header (hAuthorization,
hWWWAuthenticate)
import Network.Wai (Application, Middleware,
Request (..),
mapResponseHeaders,
responseLBS, ResponseReceived,
Response)
import Network.Wai.Middleware.HttpAuth (extractBasicAuth)
import System.IO (hPutStrLn, stderr)
import System.IO.Unsafe
import Network.Security.GssApi
import Network.Security.Kerberos
-- | Configuration for 'spnegoAuth'.
data SpnegoAuthSettings = SpnegoAuthSettings {
    spnegoRealm :: Maybe BS.ByteString
    -- ^ Default Kerberos realm. It is appended to Basic-auth user names
    --   (see 'spnegoForceRealm') and to a 'spnegoService' without a realm,
    --   stripped from authenticated user names unless 'spnegoUserFull' is
    --   set, and shown as the realm of the default @Basic@ challenge.
  , spnegoService :: Maybe BS.ByteString
    -- ^ GSSAPI service name passed to 'runGssCheck'. If it contains no
    --   \@ sign, \"\@\" followed by 'spnegoRealm' is appended. 'Nothing' is
    --   passed through unchanged.
  , spnegoUserFull :: Bool
    -- ^ When 'False', a user name whose realm equals 'spnegoRealm' is
    --   reported without its realm part.
  , spnegoBasicFailback :: Bool
    -- ^ When 'True', 'defaultAuthResponse' advertises a @Basic@ challenge
    --   alongside @Negotiate@ so clients can fall back to a username and
    --   password.
  , spnegoForceRealm :: Bool
    -- ^ For Basic-auth logins, replace any realm given in the user name with
    --   'spnegoRealm' (or drop it when no realm is configured). When 'False',
    --   'spnegoRealm' is only added to names that carry no realm.
  , spnegoOnAuthError :: SpnegoAuthSettings -> Maybe (Either KrbException GssException) -> Application
    -- ^ Handler used when the request carries no usable credentials
    --   ('Nothing') or when Kerberos ('Left') / GSSAPI ('Right')
    --   authentication throws.
  , spnegoFakeBasicAuth :: Bool
    -- ^ After a successful @Negotiate@ login, give the request passed to the
    --   wrapped application an @Authorization: Basic@ header encoding
    --   @user:password@, where the password is the literal string
    --   \"password\".
  }
-- | Baseline settings: no realm or service configured, Basic fallback
-- advertised, the configured realm forced onto Basic-auth users, realm
-- stripped from reported user names, no fake Basic header, and
-- authentication errors logged to stderr by 'defaultAuthError'.
defaultSpnegoSettings :: SpnegoAuthSettings
defaultSpnegoSettings =
  SpnegoAuthSettings
    { spnegoRealm = Nothing
    , spnegoService = Nothing
    , spnegoOnAuthError = defaultAuthError logToStderr
    , spnegoUserFull = False
    , spnegoForceRealm = True
    , spnegoBasicFailback = True
    , spnegoFakeBasicAuth = False
    }
  where
    logToStderr = hPutStrLn stderr
-- | Reply with @401 Unauthorized@ and a @WWW-Authenticate: Negotiate@
-- challenge. When 'spnegoBasicFailback' is set, a second @Basic@ challenge
-- is added. Its realm is 'spnegoRealm', or @Auth@ when no realm is configured.
defaultAuthResponse :: SpnegoAuthSettings -> (Response -> IO ResponseReceived) -> IO ResponseReceived
defaultAuthResponse settings respond =
  respond (responseLBS status401 challenges "Unauthorized")
  where
    negotiate = (hWWWAuthenticate, "Negotiate")
    basic = case spnegoRealm settings of
      Just realm -> (hWWWAuthenticate, "Basic realm=\"" <> realm <> "\"")
      Nothing -> (hWWWAuthenticate, "Basic realm=\"Auth\"")
    challenges
      | spnegoBasicFailback settings = [negotiate, basic]
      | otherwise = [negotiate]
-- | Default 'spnegoOnAuthError' handler. If an exception is supplied, it is
-- described through the given logging action. The client then receives
-- 'defaultAuthResponse'. The request itself is ignored.
defaultAuthError :: (String -> IO ()) -> SpnegoAuthSettings -> Maybe (Either KrbException GssException) -> Application
defaultAuthError logerr settings failure _req respond = do
  mapM_ (logerr . describe) failure
  defaultAuthResponse settings respond
  where
    describe (Left (KrbException code err)) =
      "Kerberos error code: " <> show code <> ", error: " <> show err
    describe (Right (GssException code err)) =
      "GSSAPI error code: " <> show code <> ", error: " <> show err
-- | Vault key under which 'spnegoAuth' stores the authenticated user name.
--
-- The key is created once via 'unsafePerformIO'. The NOINLINE pragma is
-- required: without it GHC may inline or duplicate the call, producing
-- several distinct keys, and lookups by downstream code would then miss the
-- value 'spnegoAuth' inserted.
spnegoAuthKey :: V.Key BS.ByteString
spnegoAuthKey = unsafePerformIO V.newKey
{-# NOINLINE spnegoAuthKey #-}
-- | WAI middleware that authenticates requests via SPNEGO (@Negotiate@) or,
-- when 'spnegoBasicFailback' is set, via HTTP Basic credentials checked
-- against Kerberos.
--
-- On success, the wrapped application is called with the user name stored
-- in the request 'vault' under 'spnegoAuthKey'. For @Negotiate@, the
-- GSSAPI output token is returned to the client in a @WWW-Authenticate@
-- header.
--
-- Missing or unusable credentials are handed to 'spnegoOnAuthError', as are
-- Kerberos/GSSAPI exceptions raised during authentication. Exceptions
-- thrown by the wrapped application itself are not intercepted.
spnegoAuth :: SpnegoAuthSettings -> Middleware
spnegoAuth settings@SpnegoAuthSettings{..} iapp req respond =
  case lookup hAuthorization (requestHeaders req) of
    Just val
      | Just token <- getSpnegoToken val -> do
          -- Guard only the GSSAPI exchange: an exception escaping the
          -- wrapped app must not be reported as an auth failure, and must
          -- not cause 'respond' to be called a second time.
          result <- (Right <$> runGssCheck service token) `catch` (pure . Left)
          case result of
            Left exc -> authError (Just (Right exc))
            Right (user, output) -> do
              let neghdr = (hWWWAuthenticate, "Negotiate " <> B64.encode output)
              iapp (fakeAuth user (insertUserToVault user req))
                   (respond . mapResponseHeaders (neghdr :))
      | spnegoBasicFailback
      , Just (user, password) <- extractBasicAuth val -> do
          result <- (Right <$> kerberosLogin user password) `catch` (pure . Left)
          case result of
            Left exc -> authError (Just (Left exc))
            Right krbUser -> iapp (insertUserToVault krbUser req) respond
    _ -> authError Nothing
  where
    authError failure = spnegoOnAuthError settings failure req respond

    -- "@REALM" built from the configured realm, or empty when none is set.
    realmSuffix = maybe "" ("@" <>) spnegoRealm

    -- Service principal for GSSAPI. A name without a realm gets the
    -- configured one appended.
    service = case spnegoService of
      Just svc | not (BS.elem '@' svc) -> Just (svc <> realmSuffix)
      other -> other

    -- Resolve and verify Basic-auth credentials, returning the resolved
    -- principal.
    kerberosLogin origUser password = do
      krbUser <- krb5Resolve (modifyKrbUser origUser)
      krb5Login krbUser password
      pure krbUser

    -- Apply the realm policy to a user name supplied via Basic auth.
    modifyKrbUser origUser
      | spnegoForceRealm = name <> realmSuffix
      | BS.null givenRealm, Just newRealm <- spnegoRealm = name <> "@" <> newRealm
      | otherwise = origUser
      where
        (name, givenRealm) = splitPrincipal origUser

    insertUserToVault user myreq =
      myreq { vault = V.insert spnegoAuthKey (stripSpnegoRealm user) (vault myreq) }

    -- Replace (not duplicate) the Authorization header with a synthetic
    -- Basic one, so downstream code sees exactly one Authorization header.
    fakeAuth user myreq
      | spnegoFakeBasicAuth =
          let fakeHeader = (hAuthorization, "Basic " <> B64.encode (stripSpnegoRealm user <> ":password"))
              otherHeaders = filter ((/= hAuthorization) . fst) (requestHeaders myreq)
          in myreq { requestHeaders = fakeHeader : otherHeaders }
      | otherwise = myreq

    -- Drop the realm when it equals the configured default, unless full
    -- names were requested.
    stripSpnegoRealm user
      | not spnegoUserFull
      , (clservice, clrealm) <- splitPrincipal user
      , Just clrealm == spnegoRealm = clservice
      | otherwise = user
-- | Extract and base64-decode the token from an @Authorization@ header value
-- of the form @Negotiate <token>@. The scheme is matched case-insensitively.
--
-- Returns 'Nothing' for other schemes or when the token is not valid
-- base64. Any run of spaces between the scheme and the token is accepted
-- (RFC 7235 allows @1*SP@). The original code dropped only one space.
getSpnegoToken :: BS.ByteString -> Maybe BS.ByteString
getSpnegoToken val
  | CI.mk scheme == "negotiate" =
      either (const Nothing) Just (B64.decode (BS.dropWhile (== ' ') rest))
  | otherwise = Nothing
  where
    (scheme, rest) = BS.break (== ' ') val
-- | Split a Kerberos principal at its first @\@@ into the name part and the
-- realm part. The realm is empty when the principal contains no @\@@.
splitPrincipal :: BS.ByteString -> (BS.ByteString, BS.ByteString)
splitPrincipal principal = (name, BS.drop 1 atAndRealm)
  where
    (name, atAndRealm) = BS.break (== '@') principal