{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
module Network.Wai.Middleware.Auth
(
AuthSettings
, defaultAuthSettings
, setAuthKey
, setAuthAppRootStatic
, setAuthAppRootGeneric
, setAuthSessionAge
, setAuthPrefix
, setAuthCookieName
, setAuthProviders
, setAuthProvidersTemplate
, mkAuthMiddleware
, smartAppRoot
, waiMiddlewareAuthVersion
, getAuthUser
, getAuthUserFromVault
, getDeleteSessionHeader
, decodeKey
) where
import Blaze.ByteString.Builder (fromByteString)
import Data.Binary (Binary)
import qualified Data.ByteString as S
import Data.ByteString.Builder (Builder)
import qualified Data.HashMap.Strict as HM
import qualified Data.Text as T
import Data.Text.Encoding (decodeUtf8With,
encodeUtf8)
import Data.Text.Encoding.Error (lenientDecode)
import qualified Data.Vault.Lazy as Vault
import Data.Version (Version)
import Foreign.C.Types (CTime (..))
import GHC.Generics (Generic)
import Network.HTTP.Types (Header, status200,
status303, status404,
status501)
import Network.Wai (mapResponseHeaders,
Middleware, Request,
pathInfo, rawPathInfo,
rawQueryString,
responseBuilder,
responseLBS, vault)
import Network.Wai.Auth.AppRoot
import Network.Wai.Auth.ClientSession
import Network.Wai.Middleware.Auth.Provider
import Network.Wai.Auth.Tools (decodeKey)
import qualified Paths_wai_middleware_auth as Paths
import System.IO.Unsafe (unsafePerformIO)
import System.PosixCompat.Time (epochTime)
import Text.Hamlet (Render)
-- | Configuration consumed by 'mkAuthMiddleware'.
--
-- The field selectors are not exported; build a value from
-- 'defaultAuthSettings' and adjust it with the @setAuth*@ functions.
data AuthSettings = AuthSettings
  { asGetKey :: IO Key
    -- ^ Action producing the key handed to 'saveCookieValue' and
    -- 'loadCookieValue' for the state cookie. 'mkAuthMiddleware' runs it
    -- once, while building the middleware.
  , asGetAppRoot :: Request -> IO T.Text
    -- ^ Computes the application root of a request; used to render the
    -- absolute URLs passed to a provider's login handler.
  , asSessionAge :: Int
    -- ^ Age handed to 'saveCookieValue' whenever the state cookie is
    -- written (presumably seconds -- the default is 3600; confirm in
    -- ClientSession).
  , asAuthPrefix :: T.Text
    -- ^ First path segment under which the middleware serves its own
    -- routes (login page, provider routes, @health@).
  , asStateKey :: S.ByteString
    -- ^ Name of the cookie that stores the encoded 'AuthState'.
  , asProviders :: Providers
    -- ^ Configured providers, looked up by provider name.
  , asProvidersTemplate :: Maybe T.Text -> Render Provider -> Providers -> Builder
    -- ^ Renders the provider selection page; the optional text is an
    -- error message to display (e.g. after a failed login).
  }
-- | Baseline settings for 'mkAuthMiddleware'.
--
-- * no providers (so the login page answers 501 until 'setAuthProviders'
--   is used),
-- * the stock 'providersTemplate' for the selection page,
-- * routes served under @\/_auth_middleware@,
-- * state kept in a cookie named @auth_state@ with a session age of 3600,
-- * the application root derived per request by 'smartAppRoot',
-- * the key obtained from 'getDefaultKey' (presumably the clientsession
--   default key file -- confirm).
defaultAuthSettings :: AuthSettings
defaultAuthSettings =
  AuthSettings
    { asProviders = HM.empty
    , asProvidersTemplate = providersTemplate
    , asAuthPrefix = "_auth_middleware"
    , asStateKey = "auth_state"
    , asSessionAge = 3600
    , asGetAppRoot = \req -> pure (smartAppRoot req)
    , asGetKey = getDefaultKey
    }
-- | Replace the action that supplies the key used for the state cookie.
-- The action is run once, when the middleware is created.
setAuthKey :: IO Key -> AuthSettings -> AuthSettings
setAuthKey newGetKey = \settings -> settings { asGetKey = newGetKey }
-- | Set the name of the cookie in which the authentication state is kept.
setAuthCookieName :: S.ByteString -> AuthSettings -> AuthSettings
setAuthCookieName cookieName = \settings -> settings { asStateKey = cookieName }
-- | Set the first path segment under which the middleware's own routes
-- (login page, provider callbacks, health check) are served.
setAuthPrefix :: T.Text -> AuthSettings -> AuthSettings
setAuthPrefix prefix = \settings -> settings { asAuthPrefix = prefix }
-- | Use one fixed application root for every request, instead of
-- computing it from the request.
setAuthAppRootStatic :: T.Text -> AuthSettings -> AuthSettings
setAuthAppRootStatic appRoot = setAuthAppRootGeneric (\_ -> pure appRoot)
-- | Set the per-request action computing the application root; the result
-- is used to render absolute URLs for the provider login handlers.
setAuthAppRootGeneric :: (Request -> IO T.Text) -> AuthSettings -> AuthSettings
setAuthAppRootGeneric getRoot = \settings -> settings { asGetAppRoot = getRoot }
-- | Set the age handed to 'saveCookieValue' each time the state cookie is
-- written (presumably seconds; confirm in ClientSession).
setAuthSessionAge :: Int -> AuthSettings -> AuthSettings
setAuthSessionAge age = \settings -> settings { asSessionAge = age }
-- | Set the available authentication providers, keyed by provider name.
-- The provider map is forced (to WHNF) when the setter is applied.
setAuthProviders :: Providers -> AuthSettings -> AuthSettings
setAuthProviders providers settings =
  providers `seq` settings { asProviders = providers }
-- | Replace the renderer for the provider selection page. It receives an
-- optional error message, the route renderer and the configured providers.
setAuthProvidersTemplate
  :: (Maybe T.Text -> Render Provider -> Providers -> Builder)
  -> AuthSettings
  -> AuthSettings
setAuthProvidersTemplate template =
  \settings -> settings { asProvidersTemplate = template }
-- | What the state cookie records about a client between requests.
data AuthState
  = AuthNeedRedirect !S.ByteString
    -- ^ Not logged in yet. Holds the path and query string of the request
    -- that triggered the login; the user is sent back there afterwards.
  | AuthLoggedIn !AuthUser
    -- ^ Logged in as the given user.
  deriving (Generic, Show)

-- | Generic-derived encoding, used to serialise the state into the cookie.
instance Binary AuthState
-- | Build the authentication middleware from the given settings.
--
-- 'asGetKey' is run once, here. The resulting key is used on every request
-- to load and save the state cookie named by 'asStateKey'.
--
-- Per request:
--
-- * A cookie in the 'AuthLoggedIn' state whose provider is configured is
--   passed to that provider's 'refreshLoginState'.
--
--     * If that yields a (request, user) pair, the user is stored in the
--       request vault (see 'getAuthUser') and the wrapped application
--       runs. The cookie is rewritten only when the refreshed user differs
--       from the stored one.
--     * If it yields 'Nothing', the request is handled as unauthenticated.
--
-- * Unauthenticated requests under @\/\<asAuthPrefix\>@ are answered by the
--   middleware itself: the provider selection page, provider login routes
--   and @health@.
-- * A request for @favicon.ico@ gets a 404.
-- * Any other request is recorded in the cookie and redirected to the
--   login page.
mkAuthMiddleware :: AuthSettings -> IO Middleware
mkAuthMiddleware AuthSettings {..} = do
  secretKey <- asGetKey
  let saveAuthState = saveCookieValue secretKey asStateKey asSessionAge
      authRouteRender = mkRouteRender Nothing asAuthPrefix []
  let -- Handle a request that is not (or no longer) authenticated.
      -- 'protectedPath' is where the user is sent after a successful login.
      enforceLogin protectedPath req respond =
        let -- Login page: fail when nothing is configured, jump straight to
            -- a sole provider, otherwise let the user pick one.
            serveLoginPage =
              case HM.elems asProviders of
                [] ->
                  respond $
                  responseLBS
                    status501
                    []
                    "No Authentication providers available."
                [soleProvider] ->
                  let loginUrl = encodeUtf8 $ authRouteRender soleProvider []
                  in respond $
                     responseLBS
                       status303
                       [("Location", loginUrl)]
                       "Redirecting to Login page"
                _ ->
                  respond $
                  responseBuilder status200 [] $
                  asProvidersTemplate Nothing authRouteRender asProviders
            -- Delegate a provider route to that provider's 'handleLogin'.
            serveProviderLogin provider pathSuffix = do
              appRoot <- asGetAppRoot req
              let onFailure status errMsg =
                    return $
                    responseBuilder status [] $
                    asProvidersTemplate
                      (Just $ decodeUtf8With lenientDecode errMsg)
                      authRouteRender
                      asProviders
                  onSuccess "" =
                    onFailure status501 "Empty user identity is not allowed"
                  onSuccess loginState = do
                    CTime now <- epochTime
                    cookie <-
                      saveAuthState $
                      AuthLoggedIn $
                      AuthUser
                        { authLoginState = loginState
                        , authProviderName =
                            encodeUtf8 $ getProviderName provider
                        , authLoginTime = fromIntegral now
                        }
                    return $
                      responseBuilder
                        status303
                        [("Location", protectedPath), cookie]
                        (fromByteString "Redirecting to " <>
                         fromByteString protectedPath)
                  -- Absolute URLs (rooted at the app root) for the provider.
                  providerUrlRenderer (ProviderUrl suffix) =
                    mkRouteRender (Just appRoot) asAuthPrefix suffix provider
              respond =<<
                handleLogin
                  provider
                  req
                  pathSuffix
                  providerUrlRenderer
                  onSuccess
                  onFailure
        in case pathInfo req of
            (prefix:rest)
              | prefix == asAuthPrefix ->
                case rest of
                  [] -> serveLoginPage
                  -- WAI keeps a trailing empty path segment, so
                  -- "/<prefix>/" arrives as [prefix, ""]; serve the login
                  -- page for it instead of answering 404.
                  [""] -> serveLoginPage
                  (providerName:pathSuffix)
                    | Just provider <- HM.lookup providerName asProviders ->
                      serveProviderLogin provider pathSuffix
                  ["health"] -> respond $ responseLBS status200 [] "OK"
                  _ -> respond $ responseLBS status404 [] "Unknown URL"
            ["favicon.ico"] ->
              respond $ responseLBS status404 [] "No favicon.ico"
            _ -> do
              -- Remember where the user wanted to go, then send them to
              -- the login page.
              cookie <-
                saveAuthState $
                AuthNeedRedirect (rawPathInfo req <> rawQueryString req)
              respond $
                responseBuilder
                  status303
                  [("Location", "/" <> encodeUtf8 asAuthPrefix), cookie]
                  "Redirecting to Login Page"
  return $ \app req respond -> do
    authState <- loadCookieValue secretKey asStateKey req
    case authState of
      Just (AuthLoggedIn user) ->
        let providerName = decodeUtf8With lenientDecode (authProviderName user)
            -- Expose the user to the application through 'getAuthUser'.
            withUser u r = r {vault = Vault.insert userKey u (vault r)}
        in case HM.lookup providerName asProviders of
             -- NOTE(review): a session from a provider that is no longer
             -- configured is trusted as-is, without any refresh; confirm
             -- this is intended.
             Nothing -> app (withUser user req) respond
             Just provider -> do
               refreshResult <- refreshLoginState provider req user
               case refreshResult of
                 Nothing -> enforceLogin "/" req respond
                 Just (req', user') ->
                   let -- Re-issue the cookie only if the refresh changed
                       -- the user.
                       respond' response
                         | user' == user = respond response
                         | otherwise = do
                           cookieHeader <- saveAuthState (AuthLoggedIn user')
                           respond $
                             mapResponseHeaders (cookieHeader :) response
                   in app (withUser user' req') respond'
      Just (AuthNeedRedirect url) -> enforceLogin url req respond
      Nothing -> enforceLogin "/" req respond
-- | Vault slot in which the middleware stores the authenticated 'AuthUser'
-- for the wrapped application ('getAuthUser' reads it back).
--
-- The key is allocated once, at top level, with 'unsafePerformIO'. This is
-- only sound because of the NOINLINE pragma: without it GHC could inline
-- the definition and allocate several distinct keys. Lookups would then
-- silently miss what the middleware inserted.
userKey :: Vault.Key AuthUser
userKey = unsafePerformIO (Vault.newKey)
{-# NOINLINE userKey #-}
-- | The user the middleware attached to this request, if any. 'Nothing'
-- when the request did not pass through an authenticated session.
getAuthUser :: Request -> Maybe AuthUser
getAuthUser req = getAuthUserFromVault (vault req)
-- | Like 'getAuthUser', but reads the user directly from a request vault.
getAuthUserFromVault :: Vault.Vault -> Maybe AuthUser
getAuthUserFromVault requestVault = Vault.lookup userKey requestVault
-- | Version of this package, as reported by the Cabal-generated
-- @Paths_wai_middleware_auth@ module.
waiMiddlewareAuthVersion :: Version
waiMiddlewareAuthVersion = Paths.version
-- | Header built by 'deleteCookieValue' for the state cookie named by
-- these settings. Presumably it expires that cookie (e.g. to log the user
-- out); confirm in ClientSession.
getDeleteSessionHeader :: AuthSettings -> Header
getDeleteSessionHeader settings = deleteCookieValue (asStateKey settings)