{-# LANGUAGE StrictData #-}
{-# LANGUAGE NoFieldSelectors #-}

-- | You will need to import this module if you are planning to define an
-- 'Encryption' scheme other than the defaults provided by this library.
module Wai.CryptoCookie.Encryption
   ( Encryption (..)
   , autoKeyFileBase16
   , readKeyFileBase16
   , readKeyFile
   , writeKeyFile
   ) where

import Control.Exception qualified as Ex
import Control.Monad
import Control.Monad.IO.Class
import Crypto.Random qualified as C
import Data.Aeson qualified as Ae
import Data.Bits
import Data.ByteArray qualified as BA
import Data.ByteArray.Encoding qualified as BA
import Data.ByteArray.Sized qualified as BAS
import Data.ByteString.Lazy qualified as BL
import Data.Char qualified as Char
import Data.Kind (Type)
import Data.Text.Encoding qualified as T
import Data.Word
import GHC.TypeNats
import System.IO qualified as IO
import System.IO.Error qualified as IO

-- | An encryption scheme, selected by the type-level tag @e@.
--
-- A single 'Key' is enough to derive both the 'Encrypt'ion and the
-- 'Decrypt'ion contexts (see 'initial').
class (KnownNat (KeyLength e), Eq (Key e)) => Encryption (e :: k) where
   -- | Secret key for this scheme. A fresh random one can be obtained with
   -- 'genKey'. Data encrypted under a 'Key' can be decrypted again for as
   -- long as that same 'Key' is available, so persist it with 'keyToBytes'
   -- and restore it with 'keyFromBytes'.
   data Key e :: Type

   -- | Size of the serialized 'Key' (the length of the 'keyToBytes'
   -- output), fixed at the type level.
   type KeyLength e :: Natural

   -- | State consumed by 'encrypt'.
   data Encrypt e :: Type

   -- | State consumed by 'decrypt'.
   data Decrypt e :: Type

   -- | Produce a new random 'Key'.
   genKey :: (C.MonadRandom m) => m (Key e)

   -- | Serialize a 'Key' into exactly 'KeyLength' bytes.
   keyToBytes :: (BAS.ByteArrayN (KeyLength e) raw) => Key e -> raw

   -- | Rebuild a 'Key' from its serialized bytes, or explain in 'Left'
   -- why the bytes are not acceptable.
   keyFromBytes :: (BA.ByteArrayAccess raw) => raw -> Either String (Key e)

   -- | Build the starting 'Encrypt'ion and 'Decrypt'ion contexts for a
   -- 'Key'.
   --
   -- An 'Encrypt'ion context might hold, for instance, the
   -- __randomly generated nonce__ for the next 'encrypt'ion, the 'Key' (or
   -- material derived from it) used while encrypting, or a deterministic
   -- random number generator.
   --
   -- A 'Decrypt'ion context might hold, for instance, the 'Key' (or
   -- material derived from it) used while decrypting.
   initial :: (C.MonadRandom m) => Key e -> m (Encrypt e, Decrypt e)

   -- | Step the 'Encrypt'ion context forward. This is applied automatically
   -- after every 'encrypt'ion, so it is where a nonce or a deterministic
   -- random number generator carried by the context should be updated.
   advance :: Encrypt e -> Encrypt e

   -- | Encrypt a plaintext message using the 'Encrypt'ion context.
   encrypt :: Encrypt e -> BL.ByteString -> BL.ByteString

   -- | Decrypt a message using the 'Decrypt'ion context.
   --
   -- The 'String' in 'Left' is meant for internal debugging only.
   decrypt :: Decrypt e -> BL.ByteString -> Either String BL.ByteString

-- | Load the base-16 encoded 'Key' stored at 'FilePath', creating it first
-- if the file does not exist.
--
-- * If the file can be read, its contents are decoded with
--   'readKeyFileBase16' (trailing newlines are ignored). Only a
--   "does not exist" error triggers key creation; any other failure
--   (malformed contents, permission problems, ...) is rethrown.
--
-- * If the file does not exist, a fresh random 'Key' is generated with
--   'genKey', written to the file in base-16 via 'writeKeyFile', and read
--   back. If the key read back differs from the generated one, this fails
--   with an 'IOError' (via 'fail').
--
-- Returns the 'Key' as read from the file.
--
-- NOTE(review): creation is not atomic (another process could create the
-- file between the failed read and the write), and the file is created
-- with the process' default permissions. Consider restricting access to
-- the key file.
autoKeyFileBase16
   :: forall e m
    . (Encryption e, MonadIO m)
   => FilePath
   -> m (Key e)
autoKeyFileBase16 path = liftIO do
   Ex.catchJust
      (guard . IO.isDoesNotExistError)
      (readKeyFileBase16 path)
      \() -> do
         k0 <- genKey
         writeKeyFile (BA.convertToBase BA.Base16) path k0
         -- Read the key back so that a mismatched keyToBytes/keyFromBytes
         -- pair is detected right away, not on the next start-up.
         k1 <- readKeyFileBase16 path
         -- The message names this function; it previously referred to a
         -- non-existent "autoKeyFile".
         when (k0 /= k1) $ fail "autoKeyFileBase16: no roundtrip"
         pure k1

-- | Read a 'Key' stored as base-16 text in a file.
--
-- Any trailing carriage-return and line-feed bytes are ignored. If a line
-- break is followed by anything else, this fails with @invalid format@.
-- Invalid base-16 makes this fail as well. Failures are raised by
-- 'readKeyFile'.
readKeyFileBase16
   :: forall e m
    . (Encryption e, MonadIO m)
   => FilePath
   -> m (Key e)
readKeyFileBase16 = readKeyFile decodeHex
  where
   -- Split off the line-break tail, require it to contain nothing but line
   -- breaks, and hex-decode what precedes it.
   decodeHex :: BA.ScrubbedBytes -> Either String BA.ScrubbedBytes
   decodeHex raw =
      let (body, rest) = BA.span (not . isEol) raw
      in if BA.all isEol rest
            then BA.convertFromBase BA.Base16 body
            else Left "invalid format"
   -- True for the bytes of '\r' and '\n'.
   isEol :: Word8 -> Bool
   isEol w = w == cr || w == lf
   cr, lf :: Word8
   cr = fromIntegral (Char.ord '\r')
   lf = fromIntegral (Char.ord '\n')

-- | Read a 'Key' from a file.
--
-- The whole file is read into 'BA.ScrubbedBytes'. Those bytes go through
-- the given conversion, and the result is handed to 'keyFromBytes'. These
-- cases all raise an 'IOError' via 'fail':
--
-- * a zero or out-of-range file size;
-- * a short read;
-- * a conversion error;
-- * an invalid key.
--
-- Every such message starts with @readKeyFile: @.
readKeyFile
   :: forall e m
    . (Encryption e, MonadIO m)
   => (BA.ScrubbedBytes -> Either String BA.ScrubbedBytes)
   -- ^ Turns the raw file contents into the bytes expected by
   -- 'keyFromBytes', or explains in 'Left' why it cannot.
   -> FilePath
   -> m (Key e)
readKeyFile g path = liftIO $ IO.withFile path IO.ReadMode \h -> do
   size <- IO.hFileSize h
   n <- case toIntegralSized size :: Maybe Int of
      Just len | len > 0 -> pure len
      _ -> fail "readKeyFile: invalid key file size"
   (got, buf) <- BA.allocRet n \ptr -> IO.hGetBuf h ptr n
   -- Defensive: hGetBuf only comes up short at end of file, which would mean
   -- the file shrank after its size was queried.
   unless (got == n) $ fail "readKeyFile: could not read key file"
   kraw <- orFail (g buf)
   orFail (keyFromBytes kraw)
  where
   -- Turn a 'Left' into an 'IOError' carrying the common message prefix.
   orFail :: Either String a -> IO a
   orFail = either (fail . ("readKeyFile: " <>)) pure

-- | Write a 'Key' to a file, replacing any previous contents.
--
-- The 'keyToBytes' output is first transformed by the given function, and
-- the result is written verbatim (no newline is appended).
writeKeyFile
   :: forall e m
    . (Encryption e, MonadIO m)
   => (BAS.SizedByteArray (KeyLength e) BA.ScrubbedBytes -> BA.ScrubbedBytes)
   -- ^ Encodes the raw 'keyToBytes' bytes as file contents
   -- (for example, base-16).
   -> FilePath
   -> Key e
   -> m ()
writeKeyFile g path key = liftIO do
   -- Force the encoded bytes before opening the file. Opening in WriteMode
   -- truncates it, so an exception from 'g' must happen first.
   contents <- Ex.evaluate (g (keyToBytes key))
   let emit h =
          BA.withByteArray contents \ptr ->
             IO.hPutBuf h ptr (BA.length contents)
   IO.withFile path IO.WriteMode emit

-- | Base-16 encoded.
--
-- Expects a JSON string holding the base-16 encoding of the 'keyToBytes'
-- bytes. Fails with @Invalid key@ if the string is not valid base-16. If
-- the decoded bytes are not a valid 'Key', fails with the message from
-- 'keyFromBytes'.
instance (Encryption e) => Ae.FromJSON (Key e) where
   parseJSON = Ae.withText "Key" \t -> do
      -- Only the decoded bytes are scrubbable. The hex digits remain in @t@
      -- and in its UTF-8 encoded copy.
      kraw :: BA.ScrubbedBytes <-
         either (\_ -> fail "Invalid key") pure $
            BA.convertFromBase BA.Base16 (T.encodeUtf8 t)
      either fail pure (keyFromBytes kraw)