{- |
 Common types and functions related to authorization.
-}
module WebGear.Core.Trait.Auth.Common (
  AuthorizationHeader,
  getAuthorizationHeaderTrait,
  Realm (..),
  AuthToken (..),
  respondUnauthorized,
) where

import Control.Arrow (returnA, (<<<))
import Data.ByteString (ByteString, drop)
import Data.ByteString.Char8 (break)
import Data.CaseInsensitive (CI, mk, original)
import Data.Proxy (Proxy (..))
import Data.String (IsString (..))
import Data.Text (Text)
import Data.Text.Encoding (decodeUtf8, encodeUtf8)
import Data.Void (absurd)
import GHC.TypeLits (KnownSymbol, Symbol, symbolVal)
import qualified Network.HTTP.Types as HTTP
import Web.HttpApiData (FromHttpApiData (..))
import WebGear.Core.Handler (Handler, unlinkA)
import WebGear.Core.Modifiers (Existence (..), ParseStyle (..))
import WebGear.Core.Request (Request)
import WebGear.Core.Response (Response)
import WebGear.Core.Trait (Get (..), Linked, Sets)
import WebGear.Core.Trait.Body (Body, respondA)
import WebGear.Core.Trait.Header (Header (..), RequiredHeader, setHeader)
import WebGear.Core.Trait.Status (Status)
import Prelude hiding (break, drop)

{- | Trait for the \"Authorization\" request header, parsed as an
 'AuthToken' for the authentication @scheme@.

 The header is declared optional and lenient; see
 'getAuthorizationHeaderTrait' for how the extracted value is shaped.
-}
type AuthorizationHeader scheme =
  Header 'Optional 'Lenient "Authorization" (AuthToken scheme)

{- | Extract the \"Authorization\" header from a request by specifying
   an authentication scheme.

  The header is split into the scheme and token parts and returned.
  The result is 'Nothing' when the header is not present,
  @'Just' ('Left' err)@ when it cannot be parsed as an 'AuthToken' for
  @scheme@ (e.g. the scheme does not match), and @'Just' ('Right' tok)@
  otherwise.
-}
getAuthorizationHeaderTrait ::
  forall scheme h ts.
  Get h (AuthorizationHeader scheme) Request =>
  h (Linked ts Request) (Maybe (Either Text (AuthToken scheme)))
getAuthorizationHeaderTrait = proc request -> do
  result <- getTrait (Header :: AuthorizationHeader scheme) -< request
  -- For an optional, lenient header the trait's absence type is 'Void':
  -- extraction cannot fail at the trait level, so the 'Left' branch is
  -- uninhabited and discharged with 'absurd'.
  returnA -< either absurd id result

-- | The protection space for authentication. It is echoed back to the
-- client in the @WWW-Authenticate@ header built by 'respondUnauthorized'.
newtype Realm = Realm ByteString
  deriving newtype (Eq, Ord, Show, Read, IsString)

-- | The components of the Authorization request header. The type-level
-- @scheme@ names the authentication scheme the token was parsed for.
data AuthToken (scheme :: Symbol) = AuthToken
  { -- | Authentication scheme (case-insensitive)
    authScheme :: CI ByteString
  , -- | Authentication token: the credentials that follow the scheme
    authToken :: ByteString
  }

instance KnownSymbol scheme => FromHttpApiData (AuthToken scheme) where
  -- A URL piece is 'Text': encode it as UTF-8 and reuse the header parser.
  parseUrlPiece = parseHeader . encodeUtf8

  -- Parse "<scheme> <credentials>". The scheme ends at the first space
  -- and is compared case-insensitively with the type-level @scheme@.
  -- RFC 7235 permits one or more spaces (1*SP) before the credentials,
  -- so every separating space is stripped, not just the first one.
  parseHeader hdr =
    let (scm, afterScheme) = break (== ' ') hdr
        tok = snd (break (/= ' ') afterScheme)
        actualScheme = mk scm
        expectedScheme = fromString $ symbolVal $ Proxy @scheme
     in if actualScheme == expectedScheme
          then Right (AuthToken actualScheme tok)
          else Left "scheme mismatch"

{- | Create a \"401 Unauthorized\" response.

 The response will have a plain text body and an appropriate
 \"WWW-Authenticate\" header of the form @\<scheme\> realm=\"\<realm\>\"@.
 Double quotes and backslashes in the realm are backslash-escaped so the
 header value stays a well-formed quoted-string.

 NOTE: the header value is produced with 'decodeUtf8', which throws on
 invalid UTF-8, so the scheme and realm are expected to be valid UTF-8.
-}
respondUnauthorized ::
  ( Handler h m
  , Sets
      h
      [ Status
      , RequiredHeader "Content-Type" Text
      , RequiredHeader "WWW-Authenticate" Text
      , Body Text
      ]
      Response
  ) =>
  -- | The authentication scheme
  CI ByteString ->
  -- | The authentication realm
  Realm ->
  h a Response
respondUnauthorized scheme (Realm realm) = proc _ ->
  unlinkA <<< setHeader @"WWW-Authenticate" (respondA HTTP.unauthorized401 "text/plain")
    -< (headerVal, "Unauthorized" :: Text)
  where
    headerVal :: Text
    headerVal =
      decodeUtf8 $
        original scheme <> " realm=\"" <> escapeQuotedString realm <> "\""

-- | Escape a value for use inside an HTTP quoted-string (RFC 7230,
-- section 3.2.6) by prefixing every double quote and backslash with a
-- backslash.
escapeQuotedString :: ByteString -> ByteString
escapeQuotedString =
  -- Backslashes are escaped first so that the backslashes introduced
  -- for double quotes are not themselves escaped again.
  escapeChar '"' . escapeChar '\\'
  where
    -- Insert a backslash before every occurrence of the given character.
    escapeChar :: Char -> ByteString -> ByteString
    escapeChar c s =
      case break (== c) s of
        (before, rest)
          | rest == mempty -> before
          | otherwise ->
              before <> "\\" <> fromString [c] <> escapeChar c (drop 1 rest)