{- |
 Common types and functions related to authorization.
-}
module WebGear.Core.Trait.Auth.Common (
  AuthorizationHeader,
  Realm (..),
  AuthToken (..),
  respondUnauthorized,
) where

import Data.ByteString (ByteString, drop)
import Data.ByteString.Char8 (break)
import Data.CaseInsensitive (CI, mk, original)
import Data.Proxy (Proxy (..))
import Data.String (IsString (..))
import Data.Text (Text)
import Data.Text.Encoding (decodeUtf8, decodeUtf8With, encodeUtf8)
import Data.Text.Encoding.Error (lenientDecode)
import GHC.TypeLits (KnownSymbol, Symbol, symbolVal)
import Web.HttpApiData (FromHttpApiData (..))
import WebGear.Core.Handler (Handler, unwitnessA, (>->))
import WebGear.Core.MIMETypes (PlainText (..))
import WebGear.Core.Modifiers (Existence (..), ParseStyle (..))
import WebGear.Core.Response (Response)
import WebGear.Core.Trait (Sets)
import WebGear.Core.Trait.Body (Body, setBody)
import WebGear.Core.Trait.Header (RequestHeader (..), RequiredResponseHeader, setHeader)
import WebGear.Core.Trait.Status (Status, unauthorized401)
import Prelude hiding (break, drop)

{- | Trait for the \"Authorization\" request header.

 This is a 'RequestHeader' trait on the header named @Authorization@ whose
 value is parsed as an 'AuthToken' tagged with the type-level @scheme@. It is
 declared with the 'Optional' and 'Lenient' modifiers.
-}
type AuthorizationHeader scheme = RequestHeader Optional Lenient "Authorization" (AuthToken scheme)

{- | The protection space for authentication.

 A thin wrapper around the raw 'ByteString' that 'respondUnauthorized'
 places in the @realm@ parameter of the \"WWW-Authenticate\" header.

 All instances are derived through the underlying 'ByteString', so a
 'Realm' compares, shows and reads exactly like the bytes it wraps, and
 the 'IsString' instance lets a string literal be used where a 'Realm'
 is expected.
-}
newtype Realm = Realm ByteString
  deriving newtype (Eq, Ord, Show, Read, IsString)

{- | The components of Authorization request header.

 The @scheme@ type parameter is a type-level string that is not stored in
 the value; the 'FromHttpApiData' instance uses it as the scheme a header
 must carry in order to parse successfully.
-}
data AuthToken (scheme :: Symbol) = AuthToken
  { authScheme :: CI ByteString
  -- ^ Authentication scheme, held case-insensitively
  , authToken :: ByteString
  -- ^ Authentication token: the part of the header value that follows the scheme
  }

{- | Parse an \"Authorization\" header value of the form @scheme token@.

 The text before the first space is taken as the scheme and compared,
 ignoring case, with the type-level @scheme@. On a match the remainder of
 the value, with its leading spaces removed, becomes the token; otherwise
 parsing fails with @Left \"scheme mismatch\"@. A value that contains no
 space at all is treated as a scheme with an empty token.
-}
instance (KnownSymbol scheme) => FromHttpApiData (AuthToken scheme) where
  {-# INLINE parseUrlPiece #-}
  parseUrlPiece :: Text -> Either Text (AuthToken scheme)
  parseUrlPiece = parseHeader . encodeUtf8

  {-# INLINE parseHeader #-}
  parseHeader :: ByteString -> Either Text (AuthToken scheme)
  parseHeader hdr
    | actualScheme == expectedScheme = Right (AuthToken actualScheme token)
    | otherwise = Left "scheme mismatch"
   where
    -- Split at the first space: scheme on the left, separator and
    -- credentials on the right.
    (scm, rest) = break (== ' ') hdr

    actualScheme :: CI ByteString
    actualScheme = mk scm

    -- The scheme this instance accepts, taken from the type-level symbol.
    expectedScheme :: CI ByteString
    expectedScheme = fromString $ symbolVal $ Proxy @scheme

    -- RFC 7235 allows one or more spaces between the scheme and the
    -- credentials, so skip every leading space rather than exactly one.
    token :: ByteString
    token = snd (break (/= ' ') rest)

{- | Create a \"401 Unauthorized\" response.

 The response will have a plain text body and an appropriate
 \"WWW-Authenticate\" header.

 The header value is built as @scheme realm=\"realm\"@ from the two
 arguments. Both are inserted verbatim, so the realm should not contain
 unescaped double quotes. The bytes are decoded as UTF-8; any byte
 sequence that is not valid UTF-8 is replaced with U+FFFD instead of
 raising an exception while the response is being produced.

 The body of the response is the text @Unauthorized@. The input of the
 resulting handler is ignored.
-}
respondUnauthorized ::
  ( Handler h m
  , Sets
      h
      [ Status
      , RequiredResponseHeader "Content-Type" Text
      , RequiredResponseHeader "WWW-Authenticate" Text
      , Body PlainText Text
      ]
      Response
  ) =>
  -- | The authentication scheme
  CI ByteString ->
  -- | The authentication realm
  Realm ->
  h a Response
respondUnauthorized scheme (Realm realm) = proc _ -> do
  -- Lenient decoding keeps this handler total: the scheme and realm are
  -- arbitrary bytes and need not be valid UTF-8.
  let headerVal =
        decodeUtf8With lenientDecode $
          original scheme <> " realm=\"" <> realm <> "\""
  -- Start from a bare 401 response, then attach the body (which also sets
  -- Content-Type) and the challenge header, and finally drop the
  -- type-level trait list to return a plain 'Response'.
  (unauthorized401 -< ())
    >-> (\resp -> setBody PlainText -< (resp, "Unauthorized" :: Text))
    >-> (\resp -> setHeader @"WWW-Authenticate" -< (resp, headerVal))
    >-> (\resp -> unwitnessA -< resp)
{-# INLINE respondUnauthorized #-}