{-# LANGUAGE OverloadedStrings #-}

module Cookie.Secure.Middleware (secureCookies) where

import Network.Wai (Middleware
                  , Request
                  , ResponseReceived
                  , responseLBS
                  , requestHeaders
                  , responseHeaders)
import Network.Wai.Internal (Response(..))
import Network.HTTP.Types.Header (Header
                                , RequestHeaders
                                , ResponseHeaders)
import Network.HTTP.Types.Status (status200)
import qualified Data.ByteString.Char8 as BS
import Data.Maybe (catMaybes)
import Cookie.Secure (encryptAndSignIO, verifyAndDecryptIO)
import Data.List.Split (splitOn)

-- | WAI middleware that transparently protects cookies.
--
-- Before the wrapped application sees the request, every @Cookie@ request
-- header is verified and decrypted. The application's respond callback is
-- wrapped so that every @Set-Cookie@ header in its response is encrypted and
-- signed before being passed on.
secureCookies :: Middleware
secureCookies app request respondWith = do
  decryptedRequest <- verifyAndDecryptCookies request
  app decryptedRequest (encryptAndSignCookies respondWith)

-- | Verify and decrypt the value of every @Cookie@ header of a request.
--
-- Headers with any other name are passed through unchanged, in their
-- original order.
verifyAndDecryptCookies :: Request -> IO Request
verifyAndDecryptCookies request =
  fmap (replaceRequestHeaders request)
    . traverse verifyAndDecryptIfCookieHeader
    $ requestHeaders request

-- | Wrap a respond callback so that every @Set-Cookie@ header of the
-- response is encrypted and signed before the response is handed to it.
--
-- All other headers, the status and the body are left unchanged.
encryptAndSignCookies
  :: (Response -> IO ResponseReceived)
  -> Response -> IO ResponseReceived
encryptAndSignCookies respondWith response = do
  protectedHeaders <-
    traverse encryptAndSignIfSetCookieHeader (responseHeaders response)
  respondWith (replaceResponseHeaders response protectedHeaders)

-- | Encrypt and sign the header if it is a @Set-Cookie@ header; return any
-- other header as-is.
encryptAndSignIfSetCookieHeader :: Header -> IO Header
encryptAndSignIfSetCookieHeader header@(headerName, _)
  | headerName == "Set-Cookie" = encryptAndSignCookieHeader header
  | otherwise = return header

-- | Encrypt and sign the cookie value carried by a single @Set-Cookie@
-- header.
--
-- Only the @name=value@ pair before the first @;@ is rewritten. Everything
-- from the first @;@ onward (e.g. @; Path=/; HttpOnly@) is re-appended
-- unchanged, and so is the header name.
--
-- The pair is split on its /first/ @=@. A value that itself contains @=@
-- (for example base64 padding) is therefore encrypted in full, instead of
-- being truncated to its last @=@-separated segment. A pair with no @=@ at
-- all is treated as a bare name with an empty value.
encryptAndSignCookieHeader :: Header -> IO Header
encryptAndSignCookieHeader (name, value) =
  (\encryptedPair -> (name, encryptedPair `BS.append` attributes))
    <$> encryptAndSignCookie pair
  where
    -- "id=abc; Path=/" ==> ("id=abc", "; Path=/")
    (pair, attributes) = BS.break (== ';') value

    encryptAndSignCookie :: BS.ByteString -> IO BS.ByteString
    encryptAndSignCookie c = do
      let (cName, separatorAndValue) = BS.break (== '=') c
          -- Drop the '=' separator itself; empty when the pair had none.
          cValue = BS.drop 1 separatorAndValue
      encryptedValue <- encryptAndSignIO cValue
      return (BS.intercalate "=" [cName, encryptedValue])

-- | Return a copy of the 'Request' whose header list is replaced wholesale
-- by the given list.
replaceRequestHeaders :: Request -> RequestHeaders -> Request
replaceRequestHeaders req headerList = req { requestHeaders = headerList }

-- | Return a copy of the 'Response' whose header list is replaced wholesale
-- by the given list, keeping its status and body.
--
-- A 'ResponseRaw' has no header list of its own. The replacement is instead
-- applied recursively to the nested 'Response' it wraps.
--
-- OPTIMIZE: Response is imported from Network.Wai.Internal, whose
-- interface is not guaranteed to be stable.
replaceResponseHeaders :: Response -> ResponseHeaders -> Response
replaceResponseHeaders response newHeaders = case response of
  ResponseFile status _ path filePart ->
    ResponseFile status newHeaders path filePart
  ResponseBuilder status _ builder ->
    ResponseBuilder status newHeaders builder
  ResponseStream status _ streamingBody ->
    ResponseStream status newHeaders streamingBody
  ResponseRaw rawHandler nested ->
    ResponseRaw rawHandler (replaceResponseHeaders nested newHeaders)

-- | Verify and decrypt the header if it is a @Cookie@ header; return any
-- other header as-is.
verifyAndDecryptIfCookieHeader :: Header -> IO Header
verifyAndDecryptIfCookieHeader header@(headerName, _)
  | headerName == "Cookie" = verifyAndDecryptCookieHeader header
  | otherwise = return header

-- | Verify and decrypt every cookie in a single @Cookie@ request header.
--
-- The header value is split on @"; "@ into @name=value@ pairs. Each pair is
-- split on its /first/ @=@, so values that themselves contain @=@ (for
-- example base64 padding) are verified in full. Previously only the last
-- @=@-separated segment was checked, which silently dropped such cookies.
--
-- Pairs whose value fails verification/decryption are dropped. The surviving
-- pairs are re-joined with @"; "@. The header name is kept unchanged.
verifyAndDecryptCookieHeader :: Header -> IO Header
verifyAndDecryptCookieHeader (name, value) =
  (,) name <$> verifyAndDecryptCookieHeaderValue value
  where
    verifyAndDecryptCookieHeaderValue :: BS.ByteString -> IO BS.ByteString
    verifyAndDecryptCookieHeaderValue v =
      BS.intercalate "; " . catMaybes
        <$> mapM verifyAndDecryptCookie (splitOn "; " (BS.unpack v))

    verifyAndDecryptCookie :: String -> IO (Maybe BS.ByteString)
    verifyAndDecryptCookie cookie = do
      let (cName, separatorAndValue) = BS.break (== '=') (BS.pack cookie)
          -- Drop the '=' separator itself; empty when the pair had none.
          cValue = BS.drop 1 separatorAndValue
      -- OPTIMIZE: maybe silently dropping cookies which fail to verify
      -- or decrypt isn't the best idea?
      fmap (\decryptedValue -> BS.intercalate "=" [cName, decryptedValue])
        <$> verifyAndDecryptIO cValue