{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}
module Network.Wai.Middleware.Auth.OAuth2
( OAuth2(..)
, oAuth2Parser
, URIParseException(..)
, parseAbsoluteURI
, getAccessToken
) where
import Control.Monad.Catch
import Data.Aeson.TH (defaultOptions,
deriveJSON,
fieldLabelModifier)
import Data.Functor ((<&>))
import Data.Int (Int64)
import Data.Proxy (Proxy (..))
import qualified Data.Text as T
import Data.Text.Encoding (encodeUtf8)
import Foreign.C.Types (CTime (..))
import Network.HTTP.Client.TLS (getGlobalManager)
import qualified Network.OAuth.OAuth2 as OA2
import Network.Wai (Request)
import Network.Wai.Auth.Internal (decodeToken, encodeToken,
oauth2Login,
refreshTokens)
import Network.Wai.Auth.Tools (toLowerUnderscore)
import qualified Network.Wai.Middleware.Auth as MA
import Network.Wai.Middleware.Auth.Provider
import System.PosixCompat.Time (epochTime)
import qualified URI.ByteString as U
-- | Configuration of a generic OAuth2 authentication provider.
--
-- JSON instances are derived by the Template Haskell splice in this module,
-- with the @oa2@ prefix dropped from every field label.
data OAuth2 = OAuth2
  { oa2ClientId            :: T.Text
    -- ^ Client identifier, forwarded as the OAuth2 client id.
  , oa2ClientSecret        :: T.Text
    -- ^ Client secret, forwarded as the OAuth2 client secret.
  , oa2AuthorizeEndpoint   :: T.Text
    -- ^ Authorization endpoint URL; parsed with 'parseAbsoluteURI' before use.
  , oa2AccessTokenEndpoint :: T.Text
    -- ^ Access-token endpoint URL; parsed with 'parseAbsoluteURI' before use.
  , oa2Scope               :: Maybe [T.Text]
    -- ^ Optional list of scopes handed to the login flow.
  , oa2ProviderInfo        :: ProviderInfo
    -- ^ Value returned by 'getProviderInfo' for this provider.
  }
-- | Raised when a URL cannot be parsed. Wraps the underlying parser error.
-- Thrown by 'parseAbsoluteURI'.
data URIParseException = URIParseException U.URIParseError
  deriving Show
-- | Lets 'URIParseException' be thrown with 'throwM' (see 'parseAbsoluteURI').
instance Exception URIParseException
-- | Parse a URI from its UTF-8 encoded text using the strict parser options.
--
-- Throws 'URIParseException' (via 'throwM') when the text is not a valid URI;
-- otherwise returns the parsed 'U.URI'.
parseAbsoluteURI :: MonadThrow m => T.Text -> m U.URI
parseAbsoluteURI urlTxt =
  either (throwM . URIParseException) pure $
    U.parseURI U.strictURIParserOptions (encodeUtf8 urlTxt)
-- | Convert the configured client id into the form expected by the OAuth2
-- client record. Currently the identity function.
getClientId :: T.Text -> T.Text
getClientId = id
-- | Convert the configured client secret into the form expected by the OAuth2
-- client record. Currently the identity function.
getClientSecret :: T.Text -> T.Text
getClientSecret = id
-- Derive the Aeson JSON instances for 'OAuth2'. Each field label has its first
-- three characters (the @oa2@ prefix) dropped and is then passed through
-- 'toLowerUnderscore' (presumably producing lower_snake_case keys -- confirm
-- against "Network.Wai.Auth.Tools").
$(deriveJSON defaultOptions { fieldLabelModifier = toLowerUnderscore . drop 3} ''OAuth2)
-- | Provider parser for the 'OAuth2' provider, built with 'mkProviderParser'
-- from the provider's 'FromJSON' and 'AuthProvider' instances.
oAuth2Parser :: ProviderParser
oAuth2Parser = mkProviderParser (Proxy :: Proxy OAuth2)
-- | Generic OAuth2 provider: login is delegated to 'oauth2Login', and expired
-- tokens are renewed with 'refreshTokens'.
instance AuthProvider OAuth2 where
  -- Fixed provider name, also passed to 'oauth2Login'.
  getProviderName _ = "oauth2"

  getProviderInfo = oa2ProviderInfo

  -- Parse the configured endpoints and the rendered "complete" callback URL
  -- (each may throw 'URIParseException'), then hand over to 'oauth2Login'.
  handleLogin oa2@OAuth2 {..} req suffix renderUrl onSuccess onFailure = do
    authEndpointURI <- parseAbsoluteURI oa2AuthorizeEndpoint
    accessTokenEndpointURI <- parseAbsoluteURI oa2AccessTokenEndpoint
    callbackURI <- parseAbsoluteURI $ renderUrl (ProviderUrl ["complete"]) []
    let oauth2 =
          mkOA2Config oa2 authEndpointURI accessTokenEndpointURI (Just callbackURI)
    man <- getGlobalManager
    oauth2Login
      oauth2
      man
      oa2Scope
      (getProviderName oa2)
      req
      suffix
      onSuccess
      onFailure

  -- Returns 'Nothing' when the stored login state cannot be decoded or when
  -- 'refreshTokens' yields no new tokens; returns the request and user
  -- unchanged while the token has not expired; otherwise returns the user
  -- with the re-encoded tokens and the login time reset to the current time.
  refreshLoginState oa2@OAuth2 {..} req user = do
    -- Endpoints are validated up front, so a malformed configuration throws
    -- 'URIParseException' regardless of the token's state.
    authEndpointURI <- parseAbsoluteURI oa2AuthorizeEndpoint
    accessTokenEndpointURI <- parseAbsoluteURI oa2AccessTokenEndpoint
    case decodeToken (authLoginState user) of
      Left _ -> pure Nothing
      Right tokens -> do
        CTime now <- epochTime
        if tokenExpired user now tokens
          then do
            -- No callback URI is supplied when refreshing.
            let oauth2 =
                  mkOA2Config oa2 authEndpointURI accessTokenEndpointURI Nothing
            man <- getGlobalManager
            refreshed <- refreshTokens tokens man oauth2
            pure $ refreshed <&> \newTokens ->
              ( req
              , user
                  { authLoginState = encodeToken newTokens
                  , authLoginTime = now
                  }
              )
          else pure (Just (req, user))

-- | Build the OAuth2 client record from the provider configuration, the two
-- already-parsed endpoint URIs and an optional callback URI.
mkOA2Config :: OAuth2 -> U.URI -> U.URI -> Maybe U.URI -> OA2.OAuth2
mkOA2Config cfg authorizeURI accessTokenURI callback =
  OA2.OAuth2
    { oauthClientId = getClientId (oa2ClientId cfg)
    , oauthClientSecret = Just (getClientSecret (oa2ClientSecret cfg))
    , oauthOAuthorizeEndpoint = authorizeURI
    , oauthAccessTokenEndpoint = accessTokenURI
    , oauthCallback = callback
    }
-- | Decide whether the user's token has expired at the given time.
--
-- A token without an @expiresIn@ value is never considered expired. Otherwise
-- it is expired once the login time plus @expiresIn@ is strictly less than
-- @now@ (the second argument, in the same unit as 'authLoginTime').
tokenExpired :: AuthUser -> Int64 -> OA2.OAuth2Token -> Bool
tokenExpired user now tokens =
  maybe False isPast (OA2.expiresIn tokens)
  where
    isPast lifetime = authLoginTime user + fromIntegral lifetime < now
-- | Extract the OAuth2 token stored in the login state of the request's
-- authenticated user.
--
-- Returns 'Nothing' when the request carries no authenticated user or when
-- the stored login state cannot be decoded as a token.
getAccessToken :: Request -> Maybe OA2.OAuth2Token
getAccessToken req = do
  user <- MA.getAuthUser req
  either (const Nothing) Just (decodeToken (authLoginState user))