{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE LambdaCase #-}

module Web.Rails7.Session (
  -- * Decoding
    decode
  , decodeEither
  , DecodingError(..)
  -- * Decrypting
  , decrypt
  ) where

import Control.Applicative ((<$>))
import Control.Monad
import Crypto.PBKDF.ByteString (sha1PBKDF2, sha256PBKDF2)
import Data.Aeson qualified as JSON
import Data.Bifunctor
import Data.ByteArray qualified as BA
import Data.ByteString (ByteString)
import Data.ByteString qualified as BS
import Data.ByteString.Base64 qualified as B64
import Data.ByteString.Char8 qualified as C8
import Data.ByteString.Lazy qualified as BL
import Data.Either (Either(..), either)
import Data.Function.Compat ((&))
import Data.Maybe (Maybe(..), fromMaybe)
import Data.Monoid ((<>))
import Data.Ruby.Marshal (RubyObject(..), RubyStringEncoding(..))
import Data.String.Conv (toS)
import Data.Vector qualified as Vec
import Network.HTTP.Types (urlDecode)
import Prelude (Bool(..), Eq, Int, Ord, Show, String, ($!), (.) , (==), const, error, fst, show, snd)
import Prelude hiding (lookup)
import Web.Rails.Session.Types
import Crypto.Cipher.AES (AES256)
import Crypto.Cipher.AESGCMSIV qualified as AESGCM
import Crypto.Cipher.Types (cbcDecrypt, cipherInit, makeIV, aeadInit, AEADMode (..), aeadSimpleDecrypt, AuthTag(..))
import Crypto.Error (CryptoFailable(CryptoFailed, CryptoPassed))

-- | Reasons why a Rails 7 session cookie could not be decrypted or decoded.
data DecodingError
  = -- | The cookie did not split into exactly three @--@-separated segments.
    InvalidCookieFormat
  | -- | The decoded authentication tag was not 16 bytes long; carries the
    -- length that was actually found.
    InvalidAuthTagSize Int
  | -- | The decoded initialisation vector was not 12 bytes long; carries the
    -- length that was actually found.
    InvalidIVSize Int
  | -- | The decrypted payload was not valid JSON; carries the parser's
    -- error message.
    InvalidJSON String
  | -- | A segment was not valid base64. Not produced by the functions
    -- defined in this module, which decode base64 leniently.
    InvalidBase64 String
  | -- | Initialising the cipher or the AEAD context failed; carries the
    -- rendered crypto error.
    InvalidCryptoStep String
  | -- | Building the IV failed. Not produced by the functions defined in
    -- this module.
    MakeIVFailed String
  | -- | The AEAD decryption step yielded no plaintext.
    DecryptionIsEmpty
  deriving (Eq, Show)

-- EXPORTS

-- | Decode a cookie encrypted by Rails, discarding the failure reason.
--
-- Yields 'Nothing' exactly when 'decodeEither' yields a 'Left'; call
-- 'decodeEither' instead to learn why decoding failed.
decode :: Maybe Salt      -- ^ Key-derivation salt; 'Nothing' selects the default.
       -> SecretKeyBase   -- ^ The application's secret key base.
       -> Cookie          -- ^ The raw cookie value.
       -> Maybe JSON.Value
decode mbSalt secretKeyBase cookie =
  case decodeEither mbSalt secretKeyBase cookie of
    Left _failure -> Nothing
    Right value   -> Just value

-- | Decode a cookie encrypted by Rails and retain some error information on
-- failure.
--
-- The cookie is first decrypted with 'decrypt'; any error from that step is
-- passed through unchanged. The resulting plaintext is then parsed as JSON,
-- and a parse failure is reported as 'InvalidJSON' with the parser's message.
decodeEither :: Maybe Salt      -- ^ Key-derivation salt; 'Nothing' selects the default.
             -> SecretKeyBase   -- ^ The application's secret key base.
             -> Cookie          -- ^ The raw cookie value.
             -> Either DecodingError JSON.Value
decodeEither mbSalt secretKeyBase cookie =
  decrypt mbSalt secretKeyBase cookie >>= parsePayload
  where
    -- Parse the decrypted bytes as a JSON document.
    parsePayload :: DecryptedData -> Either DecodingError JSON.Value
    parsePayload (DecryptedData plaintext) =
      first InvalidJSON (JSON.eitherDecode (BL.fromStrict plaintext))

-- | Decrypts a cookie encrypted by Rails. It returns the decrypted
-- data as a 'ByteString' blob, which is your responsibility to deserialise.
--
-- The cookie is split and validated by 'prepare', a 256-bit key is derived
-- from the secret key base and the salt, and the payload is decrypted with
-- AES-256 in GCM mode using empty associated data.
--
-- Failures:
--
-- * any error reported by 'prepare' (malformed cookie, wrong IV or tag size);
-- * 'InvalidCryptoStep' when the cipher or the AEAD context cannot be
--   initialised;
-- * 'DecryptionIsEmpty' when the AEAD decryption yields no plaintext
--   (per the crypto library, this is what a tag mismatch looks like).
decrypt :: Maybe Salt      -- ^ Key-derivation salt; 'Nothing' selects
                           --   @"authenticated encrypted cookie"@.
        -> SecretKeyBase   -- ^ The application's secret key base.
        -> Cookie          -- ^ The raw cookie value.
        -> Either DecodingError DecryptedData
decrypt mbSalt secretKeyBase cookie = do
  -- 'prepare' has already checked that the IV is 12 bytes and the tag is
  -- 16 bytes, so no further size validation is needed here.
  (EncryptedData encData, InitVector ivVec, authTag) <- prepare cookie
  cipher <- doCryptoStep (cipherInit cipherKey :: CryptoFailable AES256)
  aead <- doCryptoStep (aeadInit AEAD_GCM cipher ivVec)
  -- No associated data is authenticated, hence the empty header.
  case aeadSimpleDecrypt aead (mempty :: ByteString) encData authTag of
    Nothing        -> Left DecryptionIsEmpty
    Just plaintext -> Right (DecryptedData plaintext)
  where
    -- Key derived from the secret key base and the effective salt.
    cipherKey :: ByteString
    SecretKey cipherKey = generateSecret salt secretKeyBase

    -- The caller's salt, or the default when none was supplied.
    salt :: Salt
    salt = fromMaybe defaultSalt mbSalt

    defaultSalt :: Salt
    defaultSalt = Salt "authenticated encrypted cookie"

-- | Lift the outcome of a crypto-library operation into this module's
-- error type.
--
-- A failed operation becomes 'InvalidCryptoStep' carrying the 'show'n
-- library error; a successful one is returned unchanged.
doCryptoStep :: CryptoFailable a -> Either DecodingError a
doCryptoStep (CryptoFailed cryptoError) =
  Left (InvalidCryptoStep (show cryptoError))
doCryptoStep (CryptoPassed result) =
  Right result

-- PRIVATE

-- | Generate secret key using same cryptographic routines as Rails.
--
-- The key is derived with 'sha256PBKDF2', passing the secret key base as the
-- first argument and the salt as the second, followed by the constants
-- 1000 and 32. The derived bytes are forced before being wrapped in
-- 'SecretKey'.
generateSecret :: Salt -> SecretKeyBase -> SecretKey
generateSecret (Salt salt) (SecretKeyBase secret) =
  derivedKey `seq` SecretKey derivedKey
  where
    derivedKey :: ByteString
    derivedKey = sha256PBKDF2 secret salt iterations keyLength

    -- NOTE(review): presumably the PBKDF2 iteration count — confirm against
    -- the 'sha256PBKDF2' documentation.
    iterations :: Int
    iterations = 1000

    -- NOTE(review): presumably the derived key length in bytes (32 bytes
    -- would match the AES-256 cipher used by 'decrypt') — confirm.
    keyLength :: Int
    keyLength = 32

-- | Prepare a cookie for decryption.
--
-- The cookie must consist of exactly three segments joined by @--@: the
-- encrypted data, the initialisation vector and the authentication tag, each
-- base64-encoded. Anything else is rejected with 'InvalidCookieFormat'. The
-- decoded IV must be 12 bytes ('InvalidIVSize' otherwise) and the decoded
-- tag 16 bytes ('InvalidAuthTagSize' otherwise).
--
-- /NOTE/: Unlike Rails4, Rails7 cookies contain a final auth tag at the end.
prepare :: Cookie -> Either DecodingError (EncryptedData, InitVector, AuthTag)
prepare (Cookie cookie)
  | [encDataB64, ivVectorB64, authTagB64] <- tokenise separator cookie = do
      encData  <- base64decode encDataB64
      ivVector <- base64decode ivVectorB64 >>= enforceLen 12 InvalidIVSize
      authTag  <- base64decode authTagB64 >>= enforceLen 16 InvalidAuthTagSize
      pure ( EncryptedData encData
           , InitVector ivVector
           , AuthTag (BA.convert authTag)
           )
  | otherwise = Left InvalidCookieFormat
  where
    -- Strip newlines, then decode leniently. Lenient decoding never fails,
    -- so this always yields a 'Right'.
    base64decode :: ByteString -> Either DecodingError ByteString
    base64decode raw = Right (B64.decodeLenient (C8.filter (/= '\n') raw))

    -- Accept the bytes only if they have the expected length; otherwise
    -- report the actual length through the supplied error constructor.
    enforceLen
      :: Int
      -> (Int -> DecodingError)
      -> ByteString
      -> Either DecodingError ByteString
    enforceLen expected mkError bytes
      | actual == expected = Right bytes
      | otherwise          = Left (mkError actual)
      where
        actual = BS.length bytes


-- | The two-dash delimiter, @--@.
separator :: ByteString
separator = C8.pack "--"

-- | Split a 'ByteString' on every occurrence of a separator.
--
-- @tokenise sep input@ returns the pieces of @input@ that lie between
-- occurrences of @sep@, in order. The result is never empty: an input
-- without any separator comes back as a single-element list, and a trailing
-- separator produces a final empty piece.
--
-- An empty separator cannot split anything, so the input is returned whole.
-- Without this guard 'BS.breakSubstring' would match at offset zero forever
-- and the result would be an infinite list of empty strings.
tokenise :: ByteString -> ByteString -> [ByteString]
tokenise x y
  | BS.null x = [y]
  | otherwise = go y
  where
    -- Peel off one piece at a time; each step consumes at least the
    -- separator, so the recursion terminates.
    go :: ByteString -> [ByteString]
    go rest =
      case BS.breakSubstring x rest of
        (piece, remainder)
          | BS.null remainder -> [piece]
          | otherwise         -> piece : go (BS.drop (BS.length x) remainder)