module Network.Wai.Middleware.Crowd
(
CrowdSettings
, defaultCrowdSettings
, setCrowdKey
, setCrowdRoot
, setCrowdApprootStatic
, setCrowdApprootGeneric
, setCrowdManager
, setCrowdAge
, mkCrowdMiddleware
, smartApproot
, waiMiddlewareCrowdVersion
, getUserName
) where
import           Data.Monoid                ((<>))
import           Data.Version               (Version)
import           GHC.Generics               (Generic)
import           System.IO.Unsafe           (unsafePerformIO)

import           Blaze.ByteString.Builder   (fromByteString, toByteString)
import           Data.Binary                (Binary)
import qualified Data.ByteString            as S
import qualified Data.Text                  as T
import           Data.Text.Encoding         (decodeUtf8With, encodeUtf8)
import           Data.Text.Encoding.Error   (lenientDecode)
import qualified Data.Vault.Lazy            as Vault
import           Network.HTTP.Client        (Manager, newManager)
import           Network.HTTP.Client.TLS    (tlsManagerSettings)
import           Network.HTTP.Types         (Header, status200, status303,
                                             status403)
import           Network.Wai                (Middleware, Request, pathInfo,
                                             rawPathInfo, rawQueryString,
                                             responseBuilder, responseLBS,
                                             vault)

import           Network.Wai.Approot
import           Network.Wai.ClientSession
import           Network.Wai.OpenId
import qualified Paths_wai_middleware_crowd as Paths
-- | Configuration for 'mkCrowdMiddleware'.
--
-- Only the type is exported, not the constructor: start from
-- 'defaultCrowdSettings' and adjust it with the @setCrowd*@ functions.
data CrowdSettings = CrowdSettings
    { csGetKey     :: IO Key
      -- ^ Action yielding the key handed to 'saveCookieValue' /
      -- 'loadCookieValue' for the session cookie. Run once per middleware.
    , csCrowdRoot  :: T.Text
      -- ^ Root URL of the Crowd OpenID server; @\/op@ and @\/users\/@ are
      -- appended to it.
    , csGetApproot :: IO (Request -> IO T.Text)
      -- ^ Action yielding a function that computes the application root for
      -- a request. Run once per middleware.
    , csGetManager :: IO Manager
      -- ^ Action yielding the HTTP 'Manager' used for OpenID calls.
      -- Run once per middleware.
    , csAge        :: Int
      -- ^ Age passed to 'saveCookieValue' for the session cookie
      -- (presumably seconds — TODO confirm against Network.Wai.ClientSession).
    }
-- | Replace the action that obtains the session-cookie key.
setCrowdKey :: IO Key -> CrowdSettings -> CrowdSettings
setCrowdKey getKey settings = settings { csGetKey = getKey }
-- | Set the root URL of the Crowd OpenID server.
setCrowdRoot :: T.Text -> CrowdSettings -> CrowdSettings
setCrowdRoot root settings = settings { csCrowdRoot = root }
-- | Use the same, fixed application root for every request.
setCrowdApprootStatic :: T.Text -> CrowdSettings -> CrowdSettings
setCrowdApprootStatic root =
    setCrowdApprootGeneric (pure (\_request -> pure root))
-- | Set the action producing a per-request application-root function.
-- The outer action is run once, when the middleware is built.
setCrowdApprootGeneric :: IO (Request -> IO T.Text) -> CrowdSettings -> CrowdSettings
setCrowdApprootGeneric mkApproot settings = settings { csGetApproot = mkApproot }
-- | Set the action that creates the HTTP 'Manager' used for OpenID calls.
setCrowdManager :: IO Manager -> CrowdSettings -> CrowdSettings
setCrowdManager mkManager settings = settings { csGetManager = mkManager }
-- | Set the age given to the session cookie (default 3600).
setCrowdAge :: Int -> CrowdSettings -> CrowdSettings
setCrowdAge age settings = settings { csAge = age }
-- | Baseline settings:
--
-- * Crowd root @http:\/\/localhost:8095\/openidserver@
-- * application root computed by 'smartApproot'
-- * cookie key from 'getDefaultKey'
-- * a new 'Manager' built from 'tlsManagerSettings'
-- * cookie age 3600
defaultCrowdSettings :: CrowdSettings
defaultCrowdSettings = CrowdSettings
    { csCrowdRoot  = "http://localhost:8095/openidserver"
    , csGetApproot = smartApproot
    , csGetKey     = getDefaultKey
    , csGetManager = newManager tlsManagerSettings
    , csAge        = 3600
    }
-- | Per-client login state, stored in the session cookie via
-- 'saveCookieValue'.
data CrowdState
    = CSNeedRedirect S.ByteString
      -- ^ Login in progress; holds the raw path plus query string to return
      -- to once login completes.
    | CSLoggedIn S.ByteString
      -- ^ Login finished; holds the UTF-8 encoded Crowd user name.
    deriving (Generic, Show)

-- Uses the Generic-based default methods of 'Binary'.
instance Binary CrowdState
-- | Name under which the 'CrowdState' cookie value is saved and loaded.
csKey :: S.ByteString
csKey = "crowd_state"
-- | Vault key under which 'mkCrowdMiddleware' stores the authenticated user
-- name for the wrapped application ('getUserName' reads it back).
--
-- This is a process-global created with 'unsafePerformIO'. The NOINLINE
-- pragma is essential: without it GHC is free to inline the definition at
-- each use site, minting several distinct keys, and then the value inserted
-- by the middleware would never be found by 'getUserName'.
userKey :: Vault.Key S.ByteString
userKey = unsafePerformIO Vault.newKey
{-# NOINLINE userKey #-}
-- | UTF-8 encoded Crowd user name attached to the request by the middleware,
-- or 'Nothing' when none was attached.
getUserName :: Request -> Maybe S.ByteString
getUserName req = Vault.lookup userKey (vault req)
-- | Encode a 'CrowdState' as the @crowd_state@ cookie header, using the
-- given key and age.
saveCrowdState :: Key -> Int -> CrowdState -> IO Header
saveCrowdState key age = saveCookieValue key csKey age
-- | Build a WAI 'Middleware' that requires Crowd (OpenID) login.
--
-- The key, approot and manager actions from the settings are run once, here,
-- not per request. Each request is then handled according to its
-- @crowd_state@ cookie:
--
-- * 'CSLoggedIn': the user name is inserted into the request vault (read it
--   with 'getUserName') and the wrapped application is called.
--
-- * otherwise, a request for @\/_crowd_middleware\/complete@ finishes the
--   OpenID exchange. On success it answers 303 to the page remembered in the
--   cookie (or @\/@). On any failure it answers 403 \"Login failed\".
--
-- * any other request saves its path and query string in the cookie and is
--   redirected (303) to the Crowd OpenID endpoint.
mkCrowdMiddleware :: CrowdSettings -> IO Middleware
mkCrowdMiddleware CrowdSettings {..} = do
    key <- csGetKey
    getApproot <- csGetApproot
    man <- csGetManager
    let -- Identifiers accepted from the OP must live under this prefix; the
        -- remainder is taken as the user name.
        prefix = csCrowdRoot <> "/users/"

        -- Single response for every way the OpenID exchange can fail.
        -- A failed login must not be reported as 200 OK.
        loginFailed respond = respond $ responseLBS status403 [] "Login failed"

        -- Verify the OpenID response and, if valid, log the user in.
        completeLogin cs req respond = do
            eres <- openIdComplete req man
            case eres of
                Left _ -> loginFailed respond
                Right res ->
                    case T.stripPrefix prefix $ identifier $ oirOpLocal res of
                        -- Identity outside this Crowd server's user namespace.
                        -- Without this branch the case is non-exhaustive and
                        -- throws at runtime.
                        Nothing -> loginFailed respond
                        Just username -> do
                            cookie <- saveCrowdState key csAge
                                $ CSLoggedIn $ encodeUtf8 username
                            let dest =
                                    case cs of
                                        Just (CSNeedRedirect bs) -> bs
                                        _ -> "/"
                            respond $ responseBuilder status303
                                [ ("Location", dest)
                                , cookie
                                ]
                                (fromByteString "Redirecting to " <> fromByteString dest)

        -- Remember where the user was heading and send them to Crowd.
        startLogin req respond = do
            approot <- getApproot req
            loc <- runResourceT $ getForwardUrl
                (csCrowdRoot <> "/op")
                (approot <> "/_crowd_middleware/complete")
                Nothing
                []
                man
            cookie <- saveCrowdState key csAge $ CSNeedRedirect
                $ rawPathInfo req <> rawQueryString req
            respond $ responseLBS status303
                [ ("Location", encodeUtf8 loc)
                , cookie
                ]
                "Logging in"

    return $ \app req respond -> do
        cs <- loadCookieValue key csKey req
        case cs of
            Just (CSLoggedIn ident) ->
                app (req { vault = Vault.insert userKey ident $ vault req }) respond
            _ -> case pathInfo req of
                ["_crowd_middleware", "complete"] -> completeLogin cs req respond
                _ -> startLogin req respond
-- | Version of this package, taken from the @Paths_wai_middleware_crowd@
-- module.
waiMiddlewareCrowdVersion :: Version
waiMiddlewareCrowdVersion = Paths.version