{-# LANGUAGE DeriveGeneric       #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE QuasiQuotes         #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell     #-}
module Network.Wai.Middleware.Auth.Provider
  ( AuthProvider(..)
  -- * Provider
  , Provider(..)
  , ProviderUrl(..)
  , ProviderInfo(..)
  , Providers
  -- * Provider Parsing
  , ProviderParser
  , mkProviderParser
  , parseProviders
  -- * User
  , AuthUser(..)
  , AuthLoginState
  , UserIdentity
  , authUserIdentity
  -- * Template
  , mkRouteRender
  , providersTemplate
  ) where

import           Blaze.ByteString.Builder      (toByteString)
import           Control.Arrow                 (second)
import           Data.Aeson                    (FromJSON (..), Object,
                                                Result (..), Value)
import           Data.Aeson.Types              (parseEither)

import           Data.Aeson.TH                 (defaultOptions, deriveJSON,
                                                fieldLabelModifier)
import           Data.Aeson.Types              (Parser)
import           Data.Binary                   (Binary)
import qualified Data.ByteString               as S
import qualified Data.ByteString.Builder       as B
import qualified Data.HashMap.Strict           as HM
import           Data.Int
import           Data.Maybe                    (fromMaybe)
import           Data.Proxy                    (Proxy)
import qualified Data.Text                     as T
import           Data.Text.Encoding            (decodeUtf8With)
import           Data.Text.Encoding.Error      (lenientDecode)
import           GHC.Generics                  (Generic)
import           Network.HTTP.Types            (Status, renderQueryText)
import           Network.Wai                   (Request, Response)
import           Network.Wai.Auth.Tools        (toLowerUnderscore)
import           Text.Blaze.Html.Renderer.Utf8 (renderHtmlBuilder)
import           Text.Hamlet                   (Render, hamlet)

-- | Core authentication class, which allows for extensibility of the Auth
-- middleware created by `Network.Wai.Middleware.Auth.mkAuthMiddleware`. The
-- most important function is `handleLogin`, which implements the actual
-- behavior of a provider. Its function arguments, in order, are:
--
--     * @`ap`@ - Current provider.
--     * @`Request`@ - Request made to the login page.
--     * @[`T.Text`]@ - Url suffix, i.e. the last part of the Url split by the
--     @\'/\'@ character, for instance the @["login", "complete"]@ suffix in the
--     example below.
--     * @`Render` `ProviderUrl`@ - Url renderer. It takes the desired suffix as
--     its first argument and produces the Url for that suffix under the
--     provider's route. It can be used to generate provider urls, for instance
--     in Hamlet templates as
--
--         @
--         \@?{(ProviderUrl ["login", "complete"], [("user", "Hamlet")])}
--         @
--
--     which will result in
--     @"https:\/\/approot.com\/_auth_middleware\/providerName\/login\/complete?user=Hamlet"@,
--     or to generate Urls for callbacks.
--
--     * @(`AuthLoginState` -> `IO` `Response`)@ - Action to call on a
--     successful login.
--     * @(`Status` -> `S.ByteString` -> `IO` `Response`)@ - Should be called
--     in case of a failure in the login process, by supplying a status and a
--     short error message.
class AuthProvider ap where

  -- | Return a name for the provider. It will be used as a unique identifier
  -- for this provider. The argument should not be evaluated, as there are
  -- places where a bottom value (e.g. `undefined`, or the `error` thunk used
  -- by `mkProviderParser`) is passed to this function.
  --
  -- @since 0.1.0
  getProviderName :: ap -> T.Text

  -- | Get info about the provider. It will be used in rendering the web page
  -- with a list of providers.
  --
  -- @since 0.1.0
  getProviderInfo :: ap -> ProviderInfo

  -- | Handle a login request in a custom manner. Can be used to render a login
  -- page with a form or redirect to some other authentication service like
  -- OpenID or OAuth2.
  --
  -- @since 0.1.0
  handleLogin
    :: ap
    -> Request
    -> [T.Text]
    -> Render ProviderUrl
    -> (AuthLoginState -> IO Response)
    -> (Status -> S.ByteString -> IO Response)
    -> IO Response

  -- | Check if the login state in a session is still valid, and have the
  -- opportunity to update it. Return `Nothing` to indicate a session has
  -- expired, and the user will be directed to re-authenticate.
  --
  -- The default implementation never invalidates a session once set: it
  -- returns the request and user unchanged.
  --
  -- @since 0.2.3.0
  refreshLoginState
    :: ap
    -> Request
    -> AuthUser
    -> IO (Maybe (Request, AuthUser))
  refreshLoginState _ req loginState = pure (Just (req, loginState))

-- | Generic authentication provider wrapper. It hides the concrete type of an
-- 'AuthProvider' instance, so that providers of different types can be kept
-- together in one collection (see 'Providers').
data Provider where
  Provider :: AuthProvider provider => provider -> Provider


-- | Delegates every method, including a possibly defaulted
-- `refreshLoginState`, to the wrapped provider.
--
-- Note that, unlike what the `getProviderName` documentation asks of
-- implementations, this instance has to force its argument (by pattern
-- matching on `Provider`) in order to reach the wrapped provider.
instance AuthProvider Provider where

  getProviderName (Provider p) = getProviderName p

  getProviderInfo (Provider p) = getProviderInfo p

  handleLogin (Provider p) = handleLogin p

  refreshLoginState (Provider p) = refreshLoginState p

-- | Collection of supported providers, keyed by provider name.
type Providers = HM.HashMap T.Text Provider

-- | Aeson parser for a provider, paired with the provider's unique name (the
-- same value that `getProviderName` returns for it).
type ProviderParser = (T.Text, Value -> Parser Provider)

-- | Data type for rendering provider-specific urls: the url path segments
-- that follow the provider's own route.
newtype ProviderUrl = ProviderUrl [T.Text]

-- | Provider information used for rendering a page with the list of supported
-- providers (see `providersTemplate`). Its JSON instances are derived with
-- Template Haskell later in this module, with the @provider@ prefix dropped
-- from each field name.
data ProviderInfo = ProviderInfo
  { providerTitle   :: T.Text -- ^ Title shown as the provider's heading.
  , providerLogoUrl :: T.Text -- ^ Url of the logo image shown for the provider.
  , providerDescr   :: T.Text -- ^ Short description shown below the title.
  } deriving (Show)


-- | An arbitrary state that comes with a logged in user, e.g. a username, a
-- token or an email address.
type AuthLoginState = S.ByteString

-- | Former name of `AuthLoginState`; the two are the same type.
type UserIdentity = AuthLoginState
{-# DEPRECATED UserIdentity "In favor of `AuthLoginState`" #-}

-- | Former name of the `authLoginState` accessor of `AuthUser`; it returns
-- exactly the same value.
authUserIdentity :: AuthUser -> UserIdentity
authUserIdentity = authLoginState
{-# DEPRECATED authUserIdentity "In favor of `authLoginState`" #-}

-- | Representation of a user for a particular `Provider`.
data AuthUser = AuthUser
  { authLoginState   :: !AuthLoginState
    -- ^ Provider-specific state of the logged in user.
  , authProviderName :: !S.ByteString
    -- ^ Presumably the name of the provider the user logged in with
    -- (cf. `getProviderName`) — confirm at the construction site.
  , authLoginTime    :: !Int64
    -- ^ Time of the login. NOTE(review): unit and epoch are not visible in
    -- this module — confirm where `AuthUser` values are created.
  } deriving (Eq, Generic, Show)

-- | Serialization using the `Generic`-based default methods of `Binary`.
instance Binary AuthUser



-- | Build a `ProviderParser` for the provider type @ap@. The parser is the
-- provider's name, as returned by `getProviderName`, paired with a JSON parser
-- whose result is wrapped in `Provider`.
--
-- The first argument is not evaluated and is only needed for restricting the
-- type.
mkProviderParser :: forall ap . (FromJSON ap, AuthProvider ap) => Proxy ap -> ProviderParser
mkProviderParser _ =
  ( getProviderName nameProxyError
  , fmap Provider . (parseJSON :: Value -> Parser ap)
  )
  where
    -- No value of type @ap@ exists yet, so hand getProviderName a thunk that
    -- fails loudly if an implementation breaks the "don't evaluate" contract.
    nameProxyError :: ap
    nameProxyError =
      error "AuthProvider.getProviderName should not evaluate it's argument."

-- | Parse configuration for providers from an `Object`.
--
-- Every key of the object must be the name of one of the given parsers.
-- Otherwise the result is an `Error` listing the unrecognized names.
--
-- Each recognized entry is parsed with its matching parser, and the parsed
-- providers are keyed by name. If any entry fails to parse, one of the
-- failures is returned as an `Error`.
--
-- Parsers whose name does not appear in the object are ignored.
parseProviders :: Object -> [ProviderParser] -> Result Providers
parseProviders unparsedProvidersHM providerParsers
  | HM.null unrecognized =
      sequence $ HM.intersectionWith parseProvider unparsedProvidersHM parsersHM
  | otherwise =
      Error $
        "Provider name(s) are not recognized: " ++
        T.unpack (T.intercalate ", " $ HM.keys unrecognized)
  where
    -- NOTE(review): HashMap functions are applied directly to 'Object', which
    -- relies on aeson < 2 (where Object is a HashMap Text Value).
    parsersHM :: HM.HashMap T.Text (Value -> Parser Provider)
    parsersHM = HM.fromList providerParsers
    -- Configuration keys for which no parser was supplied.
    unrecognized :: Object
    unrecognized = HM.difference unparsedProvidersHM parsersHM
    parseProvider :: Value -> (Value -> Parser Provider) -> Result Provider
    parseProvider v p = either Error Success $ parseEither p v

-- | Create a url renderer for a provider.
--
-- The rendered url is the app root, the auth prefix, the provider name and
-- the given suffix segments joined with @/@, followed by the rendered query
-- string built from the parameters.
--
-- When no app root is given, the url starts with @/@.
mkRouteRender :: Maybe T.Text -> T.Text -> [T.Text] -> Render Provider
mkRouteRender appRoot authPrefix authSuffix (Provider p) params =
  T.intercalate "/" ([root, authPrefix, getProviderName p] ++ authSuffix) <>
  decodeUtf8With lenientDecode (toByteString queryString)
  where
    root = fromMaybe "" appRoot
    -- Every parameter carries a value, so wrap each one in 'Just'.
    queryString = renderQueryText True (map (second Just) params)


-- JSON instances for 'ProviderInfo'. The "provider" prefix is stripped from
-- every field name before 'toLowerUnderscore' is applied, presumably yielding
-- keys like "title" and "logo_url" — confirm against the Tools module.
$(deriveJSON
    defaultOptions
      { fieldLabelModifier =
          toLowerUnderscore . drop (length ("provider" :: String))
      }
    ''ProviderInfo)


-- | Template for the providers page. It shows the error message (if any)
-- above the list of providers. Each provider is rendered with its logo, title
-- and description, linking to the url produced by the renderer.
providersTemplate :: Maybe T.Text -- ^ Error message to display, if any.
                  -> Render Provider -- ^ Renderer function for provider urls.
                  -> Providers -- ^ List of available providers.
                  -> B.Builder
providersTemplate merrMsg render providers =
  renderHtmlBuilder $ [hamlet|
$doctype 5
<html>
  <head>
    <title>WAI Auth Middleware - Authentication Providers.
    <link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.7/css/bootstrap.min.css" integrity="sha384-BVYiiSIFeK1dGmJRAkycuHAHRg32OmUcww7on3RYdg4Va+PmSTsz/K68vbdEjh4u" crossorigin="anonymous">
    <link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.7/css/bootstrap-theme.min.css" integrity="sha384-rHyoN1iRsVXV4nD0JutlnGaslCJuC7uwjduW9SVrLvRYooPp2bWYgmgJQIXwl/Sp" crossorigin="anonymous">
    <script src="https://ajax.googleapis.com/ajax/libs/jquery/3.1.1/jquery.min.js">
    <script src="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.7/js/bootstrap.min.js" integrity="sha384-Tc5IQib027qvyjSMfHjOMaLkfuWVxZxUPnCJA7l2mCWNIpG9mGCD8wGNIcPD7Txa" crossorigin="anonymous">
    <style>
      .provider-logo {
        max-height: 64px;
        max-width: 64px;
        padding: 5px;
        margin: auto;
        position: absolute;
        top: 0;
        bottom: 0;
        left: 0;
        right: 0;
      }
      .media-container {
        width: 600px;
        position: absolute;
        top: 100px;
        bottom: 0;
        left: 0;
        right: 0;
        margin: auto;
      }
      .provider.media {
        border: 1px solid #e1e1e8;
        padding: 5px;
        height: 82px;
        text-overflow: ellipsis;
        margin-top: 5px;
      }
      .provider.media:hover {
        background-color: #f5f5f5;
        border: 1px solid #337ab7;
      }
      .provider .media-left {
        height: 70px;
        width: 0px;
        padding-right: 70px;
        position: relative;
      }
      a:hover {
        text-decoration: none;
      }
  <body>
    <div .media-container>
      <h3>Select one of available authentication methods:
      $maybe errMsg <- merrMsg
        <div .alert .alert-danger role="alert">
          #{errMsg}
      $forall provider <- providers
        $with info <- getProviderInfo provider
          <div .media.provider>
            <a href=@{provider}>
              <div .media-left .container>
                <img .provider-logo src=#{providerLogoUrl info}>
              <div .media-body>
                <h3 .media-heading>
                  #{providerTitle info}
                #{providerDescr info}
|] render