{-# LANGUAGE StrictData #-}
{-# LANGUAGE NoFieldSelectors #-}

module Wai.CryptoCookie.Middleware
   ( Config (..)
   , CryptoCookie
   , get
   , set
   , middleware
   ) where

import Control.Concurrent.STM
import Control.Monad.IO.Class
import Data.ByteArray.Encoding qualified as BA
import Data.ByteString qualified as B
import Data.ByteString.Lazy qualified as BL
import Data.IORef
import Data.Kind (Type)
import Data.List (find)
import Data.Time.Clock.POSIX qualified as Time
import Network.Wai qualified as Wai
import Web.Cookie
   ( SetCookie (..)
   , parseCookies
   , parseSetCookie
   , renderSetCookieBS
   )

import Wai.CryptoCookie.Encoding (Encoding (..))
import Wai.CryptoCookie.Encryption (Encryption (..))

-- | Configuration for 'middleware'.
--
-- Consider using 'Wai.CryptoCookie.defaultConfig' and
-- updating desired fields only.
--
-- The encryption scheme @e@ is existentially quantified: any scheme with an
-- 'Encryption' instance can be used, and it does not show up in the type of
-- 'Config'.
data Config (a :: Type) = forall e.
    (Encryption e) =>
   Config
   { key :: Key e
   -- ^ Key from which the encryption and decryption contexts are
   -- 'initial'ized.
   , encoding :: Encoding a
   -- ^ How to convert the cookie data @a@ to and from raw bytes.
   , setCookie :: SetCookie
   -- ^ Template for the @Set-Cookie@ header.  Its 'setCookieName' identifies
   -- the cookie in requests and responses.  Its 'setCookieValue' is replaced
   -- whenever the header is rendered.
   }

-- | Internal runtime environment built from a 'Config' by 'newEnv'.
data Env (a :: Type) = Env
   { encrypt :: BL.ByteString -> IO BL.ByteString
   -- ^ Encrypt raw bytes.  Runs in 'IO' because 'newEnv' keeps the
   -- encryption context in mutable state that is advanced on every call.
   , decrypt :: BL.ByteString -> Maybe BL.ByteString
   -- ^ Decrypt raw bytes.  'Nothing' if decryption fails.
   , encoding :: Encoding a
   -- ^ Conversion between @a@ and raw bytes, copied from the 'Config'.
   , setCookie :: SetCookie
   -- ^ @Set-Cookie@ template, copied from the 'Config'.
   }

-- | Encode a value with the 'Env' encoding, encrypt it, and render the
-- ciphertext as unpadded URL-safe Base64.
encodeEncrypt :: Env a -> a -> IO B.ByteString
encodeEncrypt env a = do
   -- '$!' forces the encoded payload (to WHNF) before handing it to the
   -- encryption step.
   cryl <- env.encrypt $! env.encoding.encode a
   pure $ BA.convertToBase BA.Base64URLUnpadded $ BL.toStrict cryl

-- | Inverse of 'encodeEncrypt': read unpadded URL-safe Base64, decrypt, then
-- decode with the 'Env' encoding.
--
-- 'Nothing' if any of the three steps fails.
decryptDecode :: Env a -> B.ByteString -> Maybe a
decryptDecode env cry64 = do
   -- The Base64 error message is discarded: callers only need to know
   -- whether a value could be recovered.
   cry <- either (const Nothing) Just do
      BA.convertFromBase BA.Base64URLUnpadded cry64
   env.encoding.decode =<< env.decrypt (BL.fromStrict cry)

-- | Build the runtime 'Env' for a 'Config'.
--
-- The encryption context obtained from 'initial' is stored in an 'IORef'.
-- Every call to the resulting @encrypt@ atomically takes the current context
-- and stores its 'advance'd successor, so no two encryptions performed
-- through the same 'Env' use the same context.
newEnv :: Config a -> IO (Env a)
newEnv Config{key, encoding, setCookie} = do
   (!ec0, !dc) <- initial key
   ecRef <- newIORef ec0
   pure
      Env
         { encrypt = \raw -> do
            -- Hand out the current context and leave the advanced one
            -- behind for the next caller.
            ec <- atomicModifyIORef' ecRef \cur -> (advance cur, cur)
            pure $ encrypt ec raw
         , -- The decryption error message is discarded.
           decrypt = either (const Nothing) Just . decrypt dc
         , encoding
         , setCookie
         }

-- | Read-write access to the "Wai.CryptoCookie" data.
--
-- See 'get' and 'set'.
--
-- The first field is the value obtained from the incoming 'Wai.Request'
-- cookie, if any.  The second field records what 'middleware' should do to
-- the outgoing 'Wai.Response':
--
-- * 'Nothing': 'set' was never called, leave the response untouched.
--
-- * @'Just' 'Nothing'@: expire the cookie.
--
-- * @'Just' ('Just' a)@: set the cookie to @a@.
data CryptoCookie a = CryptoCookie (Maybe a) (TVar (Maybe (Maybe a)))

-- | The data that came through the 'Wai.Request' cookie, if any.
--
-- This is the value read from the request.  It is not affected by 'set'.
get :: CryptoCookie a -> Maybe a
get (CryptoCookie fromRequest _) = fromRequest

-- | Cause the eventual 'Wai.Response' corresponding to the current
-- 'Wai.Request' to set the cookie to the specified value if 'Just', or expire
-- (/delete/) the cookie if 'Nothing'.
--
-- Overrides previous uses of 'set'.
set :: CryptoCookie a -> Maybe a -> STM ()
set (CryptoCookie _ pending) = writeTVar pending . Just

-- | Obtain a new 'Wai.Application'-transforming function (more or less a
-- 'Wai.Middleware') wherein the 'Wai.Application' being transformed can interact
-- with a 'CryptoCookie'.
--
-- * 'middleware' can be called multiple times as long as the 'setCookieName'
-- for the 'SetCookie' specified in 'Config' is different each time.
--
-- * It is safe to reuse the same 'Key' for multiple 'middleware' calls.  Each
-- time the 'Key' will have a different randomly 'initial'ized 'Encrypt'ion
-- context.
middleware
   :: forall a m
    . (MonadIO m)
   => Config a
   -- ^ Consider using 'Wai.CryptoCookie.defaultConfig'.
   -> m ((CryptoCookie a -> Wai.Application) -> Wai.Application)
   -- ^ Remember that 'Wai.Middleware' is a type-synonym for
   -- @'Wai.Application' -> 'Wai.Application'@.  This type is not too different
   -- from that.
middleware c = liftIO do
   env <- newEnv c
   pure \fapp -> \req respond -> do
      -- One 'TVar' per request: 'Nothing' until the application calls 'set'.
      tv <- newTVarIO Nothing
      fapp (CryptoCookie (getRequestCookieData env req) tv) req \res -> do
         -- Read the pending update only once the application responds, so
         -- that every 'set' performed before that point is observed.
         pending <- readTVarIO tv
         final <- case pending of
            Nothing -> pure res
            Just Nothing -> expireResponseCookieData env res
            Just (Just a1) -> setResponseCookieData env a1 res
         respond final

-- | Find, decrypt and decode the cookie value from the 'Wai.Request'.
--
-- 'Nothing' if the request does not carry exactly one cookie with the
-- configured 'setCookieName', or if that cookie's value can't be decrypted
-- or decoded.
getRequestCookieData :: Env a -> Wai.Request -> Maybe a
getRequestCookieData env r =
   case lookupMany cookieName (requestCookies r) of
      [cry] -> decryptDecode env cry
      -- Zero or several cookies with our name: treat as absent.
      _ -> Nothing
  where
   cookieName = setCookieName env.setCookie

-- | Adds a @Set-Cookie@ header carrying the encoded and encrypted value to
-- the 'Wai.Response'.
--
-- 'fail's in 'IO' if the response already has a @Set-Cookie@ header for the
-- configured 'setCookieName'.
setResponseCookieData :: Env a -> a -> Wai.Response -> IO Wai.Response
setResponseCookieData env a = \res ->
   -- Check for a clash first, so nothing is encrypted when we are going to
   -- fail anyway.
   case find predicate (responseCookies res) of
      Nothing -> do
         cry <- encodeEncrypt env a
         let hval = renderSetCookieBS $ env.setCookie{setCookieValue = cry}
         pure $ Wai.mapResponseHeaders (("Set-Cookie", hval) :) res
      Just _ -> fail $ "Duplicate cookie name: " <> show cookieName
  where
   cookieName = setCookieName env.setCookie
   predicate = \x -> setCookieName x == cookieName

-- | Adds a @Set-Cookie@ header to the 'Wai.Response' that expires (/deletes/)
-- the cookie: empty value, expiry date at the POSIX epoch, and a negative
-- @Max-Age@.
--
-- 'fail's in 'IO' if the response already has a @Set-Cookie@ header for the
-- configured 'setCookieName'.
expireResponseCookieData :: Env a -> Wai.Response -> IO Wai.Response
expireResponseCookieData env = \res ->
   case find predicate (responseCookies res) of
      Nothing -> pure $ Wai.mapResponseHeaders (("Set-Cookie", hval) :) res
      Just _ -> fail $ "Duplicate cookie name: " <> show cookieName
  where
   cookieName = setCookieName env.setCookie
   predicate = \x -> setCookieName x == cookieName
   -- Same attributes as the configured template, so the header targets the
   -- same cookie; only value and lifetime are overridden.
   hval =
      renderSetCookieBS $
         env.setCookie
            { setCookieValue = mempty
            , setCookieExpires = Just (Time.posixSecondsToUTCTime 0)
            , setCookieMaxAge = Just (negate 1)
            }

--------------------------------------------------------------------------------

-- | All name-value cookie pairs found in the @Cookie@ headers of the
-- 'Wai.Request', in header order.
requestCookies :: Wai.Request -> [(B.ByteString, B.ByteString)]
requestCookies r = parseCookies =<< lookupMany "Cookie" (Wai.requestHeaders r)

-- | Every @Set-Cookie@ header of the 'Wai.Response', parsed.
responseCookies :: Wai.Response -> [SetCookie]
responseCookies =
   fmap parseSetCookie . lookupMany "Set-Cookie" . Wai.responseHeaders

--------------------------------------------------------------------------------

-- | Like 'lookup', but returns the values of /all/ entries whose key equals
-- the given one, in their original order.
lookupMany :: (Eq k) => k -> [(k, v)] -> [v]
lookupMany k = findMany (== k)

-- | Values of all entries whose key satisfies the predicate, in their
-- original order.
--
-- Keys are only ever passed to the predicate, so no 'Eq' instance is needed.
findMany :: (k -> Bool) -> [(k, v)] -> [v]
findMany f kvs = [v | (k, v) <- kvs, f k]