-- |
-- Module:          ConfCrypt.Encryption
-- Copyright:       (c) 2018 Chris Coffey
--                  (c) 2018 CollegeVine
-- License:         MIT
-- Maintainer:      Chris Coffey
-- Stability:       experimental
-- Portability:     portable
--
-- This module exposes the interface and instances for encrypting and decrypting values. The two
-- operations are intentionally split into separate classes ('MonadEncrypt' and 'MonadDecrypt').

module ConfCrypt.Encryption (

    -- * Working with RSA keys
    KeyProjection,
    project,
    TextKey(..),

    -- * Working with KMS keys
    RemoteKey(..),

    -- * Working with values
    Encrypted,
    renderEncrypted,
    MonadEncrypt,
    encryptValue,
    MonadDecrypt,
    decryptValue,

    -- * Utilities
    loadRSAKey,

    -- ** Exported for Testing
    unpackPrivateRSAKey
    ) where

import ConfCrypt.Types
import ConfCrypt.Providers.AWS (AWSCtx(..), KMSKeyId(..))

import Control.Lens (view)
import Control.Monad.Trans (lift, liftIO, MonadIO)
import Control.Monad.Trans.Class (MonadTrans)
import Control.Monad.Except (MonadError, throwError, Except, ExceptT, runExcept)
import Conduit (MonadResource, MonadThrow)
import Crypto.PubKey.OpenSsh (OpenSshPublicKey(..), OpenSshPrivateKey(..), decodePublic, decodePrivate)
import qualified Crypto.PubKey.RSA.Types as RSA
import Crypto.Types.PubKey.RSA (PrivateKey(..), PublicKey(..))
import Crypto.PubKey.RSA.PKCS15 (encrypt, decrypt)
import Crypto.Random.Types (MonadRandom, getRandomBytes)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Char8 as BSC
import qualified Data.ByteString.Base64 as B64
import Data.Text as T
import Data.Text.Encoding as T
import qualified Control.Monad.Trans.AWS as AWS
import qualified Network.AWS.KMS.Encrypt as AWS
import qualified Network.AWS.KMS.Decrypt as AWS

-- | Wraps a key whose textual contents are stored on the local machine. Building a
-- 'TextKey' requires (and captures) a 'LocalKey' instance for the wrapped key type.
data TextKey key = LocalKey key => TextKey key


-- | Key types that can be selected out of a cryptonite RSA 'RSA.KeyPair', so callers can
-- ask for whichever half of the pair they need.
class KeyProjection k where
    -- | Pick this component out of the given key pair.
    project :: RSA.KeyPair -> k

-- | Select the public half of a key pair.
instance KeyProjection RSA.PublicKey where
    project keyPair = RSA.toPublicKey keyPair

-- | Select the private half of a key pair.
instance KeyProjection RSA.PrivateKey where
    project keyPair = RSA.toPrivateKey keyPair

-- Both RSA key halves may be wrapped in a 'TextKey'.
-- NOTE(review): 'LocalKey' and the RSA key types are both declared outside this module,
-- so these are orphan instances.
instance LocalKey RSA.PublicKey
instance LocalKey RSA.PrivateKey

-- | Read a file holding the textual form of an RSA private key (as produced by openssh or
-- ssh-keygen) and project the decoded key pair into the key type the caller asks for.
--
-- Decoding problems are raised through 'MonadError' by 'unpackPrivateRSAKey'; problems
-- reading the file itself are not caught here and surface as ordinary IO exceptions.
loadRSAKey :: (MonadIO m, Monad m, MonadError ConfCryptError m, KeyProjection key) =>
    FilePath
    -> m key
loadRSAKey keyPath =
    liftIO (BS.readFile keyPath) >>= fmap project . unpackPrivateRSAKey

-- | Decode the raw bytes of an OpenSSH-formatted private key into a cryptonite 'RSA.KeyPair'.
-- Used by 'loadRSAKey', and exported so the decoding step can be tested on its own.
--
-- Raises 'KeyUnpackingError' (carrying the decoder's message) when the bytes cannot be
-- decoded, and 'NonRSAKey' when they contain a DSA key.
unpackPrivateRSAKey :: (MonadError ConfCryptError m) =>
    BS.ByteString
    -> m  RSA.KeyPair
unpackPrivateRSAKey rawPrivateKey =
    either (throwError . KeyUnpackingError . T.pack) fromOpenSsh (decodePrivate rawPrivateKey)
    where
        fromOpenSsh (OpenSshPrivateKeyRsa sshKey) = pure . RSA.KeyPair $ toCryptonitePrivate sshKey
        fromOpenSsh OpenSshPrivateKeyDsa{} = throwError NonRSAKey

        -- The OpenSSH decoder yields crypto-pubkey-types keys, while this module works with
        -- cryptonite's RSA types, so every field is copied across explicitly.
        toCryptonitePublic pub = RSA.PublicKey {
            RSA.public_size = public_size pub,
            RSA.public_n = public_n pub,
            RSA.public_e = public_e pub
            }
        toCryptonitePrivate prv = RSA.PrivateKey {
            RSA.private_pub = toCryptonitePublic $ private_pub prv,
            RSA.private_d = private_d prv,
            RSA.private_p = private_p prv,
            RSA.private_q = private_q prv,
            RSA.private_dP = private_dP prv,
            RSA.private_dQ = private_dQ prv,
            RSA.private_qinv = private_qinv prv
            }

-- | Ciphertext, kept distinct from plaintext 'T.Text' at the type level. The encryption
-- classes in this module still return raw 'T.Text' for now.
--
-- TODO use this type in lieu of raw text
newtype Encrypted = Encrypted T.Text deriving (Eq, Show)

-- | Return the ciphertext text held by an 'Encrypted' value (the inverse of the internal
-- 'toEncrypted' wrapper).
renderEncrypted :: Encrypted -> T.Text
renderEncrypted (Encrypted encText) = encText

-- | Tag ciphertext text as 'Encrypted'. Internal helper; not exported from this module.
toEncrypted :: T.Text -> Encrypted
toEncrypted cipherText = Encrypted cipherText

-- | Decryption capability: a monad @m@ able to decrypt ciphertext with a key of type @k@.
-- Failures are reported through 'MonadError' as a 'ConfCryptError'.
class (MonadError ConfCryptError m, Monad m) => MonadDecrypt m k where
    -- | Decrypt the given ciphertext with the given key, yielding the plaintext or throwing
    -- a 'ConfCryptError'.
    decryptValue :: k -> T.Text -> m T.Text

-- | PKCS#1 v1.5 RSA decryption of base64-encoded ciphertext.
--
-- The empty string decrypts to itself. Otherwise the text is base64-decoded, RSA-decrypted,
-- and the resulting bytes are decoded as UTF-8. A failure at any of those three steps is
-- raised as 'DecryptionError' instead of escaping as an exception from pure code.
--
-- NOTE(review): no blinder is passed to 'decrypt' (the 'Nothing' argument), so the RSA
-- operation is not blinded against timing side channels. cryptonite's @decryptSafer@ adds
-- blinding, but it needs 'MonadRandom', which this instance does not currently require.
instance (Monad m, MonadError ConfCryptError m) => MonadDecrypt m RSA.PrivateKey where
    decryptValue _ "" = pure ""
    decryptValue privateKey encryptedValue =
        either (throwError . DecryptionError) pure $ do
            cipherBytes <- unwrapBytes encryptedValue
            plainBytes <- lMap (T.pack . show) $ decrypt Nothing privateKey cipherBytes
            -- decodeUtf8' reports invalid UTF-8 as a Left; decodeUtf8 would throw instead
            lMap (T.pack . show) $ T.decodeUtf8' plainBytes

-- | A 'TextKey' holding an RSA private key decrypts exactly like the key it wraps.
instance (MonadError ConfCryptError m, Monad m) => MonadDecrypt m (TextKey RSA.PrivateKey) where
    decryptValue wrapped = case wrapped of
        TextKey innerKey -> decryptValue innerKey


--
-- Encryption
--

-- | Encryption capability: a monad @m@ able to turn plaintext into ciphertext with a key of
-- type @k@. Failures are reported through 'MonadError' as a 'ConfCryptError'.
class (MonadError ConfCryptError m, Monad m) => MonadEncrypt m k where
    -- | Encrypt the given plaintext with the given key, yielding the ciphertext or throwing
    -- a 'ConfCryptError'.
    encryptValue :: k -> T.Text -> m T.Text

-- | PKCS#1 v1.5 RSA encryption. The empty string is passed through unchanged. Otherwise the
-- UTF-8 bytes of the plaintext are encrypted and the ciphertext is returned base64-encoded.
-- Errors reported by the RSA library are raised as 'EncryptionError'.
instance (Monad m, MonadRandom m, MonadError ConfCryptError m) => MonadEncrypt m RSA.PublicKey where
    encryptValue _ "" = pure ""
    encryptValue publicKey nakedValue =
        encrypt publicKey (T.encodeUtf8 nakedValue)
            >>= either (throwError . EncryptionError) (pure . wrapBytes)

-- | A 'TextKey' holding an RSA public key encrypts exactly like the key it wraps.
instance (MonadRandom m, MonadError ConfCryptError m, Monad m) =>
    MonadEncrypt m (TextKey RSA.PublicKey) where
    encryptValue wrapped = case wrapped of
        TextKey innerKey -> encryptValue innerKey

-- | 'ConfCryptM' draws random bytes from its base monad @m@, lifting through the three
-- transformer layers it stacks on top of @m@.
-- NOTE(review): the layer count must track the definition of 'ConfCryptM' (declared outside
-- this module); this instance also appears to be an orphan.
instance (MonadRandom m) => MonadRandom (ConfCryptM m k) where
    getRandomBytes byteCount = lift (lift (lift (getRandomBytes byteCount)))

-- | Orphan instance: an 'ExceptT' computation draws random bytes from the monad it wraps.
-- Needed so 'ExceptT'-based stacks can use the RSA 'MonadEncrypt' instance.
instance (MonadRandom m) => MonadRandom (ExceptT e m) where
    getRandomBytes byteCount = lift (getRandomBytes byteCount)


--
-- KMS Support
--

-- | Wraps a key that is managed remotely by a third-party service provider. Building a
-- 'RemoteKey' requires (and captures) a 'KMSKey' instance for the wrapped key type.
data RemoteKey key = KMSKey key => RemoteKey key

-- | An AWS context may be wrapped in a 'RemoteKey'.
instance KMSKey AWSCtx where
-- TODO can this constraint be cleaner? Duplicating 'key' is ugly
-- | Decryption through AWS KMS.
--
-- The empty string decrypts to itself, matching the RSA instances. Otherwise the base64 text
-- is decoded into ciphertext bytes, sent to KMS, and the returned plaintext is decoded as
-- UTF-8. Malformed base64, a response without plaintext, and plaintext that is not valid
-- UTF-8 are all raised as 'AWSDecryptionError'.
instance MonadDecrypt (ConfCryptM IO (RemoteKey AWSCtx)) (RemoteKey AWSCtx) where
    decryptValue _ "" = pure ""
    decryptValue (RemoteKey AWSCtx {env}) rawValue = AWS.runAWST env $ do
        -- Unwrap the base64 text into the raw ciphertext bytes KMS expects
        rawBytes <- either (throwError . AWSDecryptionError) pure $ unwrapBytes rawValue
        -- NOTE(review): amazonka's 'send' is expected to throw a service error on a non-2xx
        -- response, so the response status is not re-checked here -- confirm before relying on it.
        decryptResponse <- AWS.send $ AWS.decrypt rawBytes
        case view AWS.drsPlaintext decryptResponse of
            Nothing -> throwError $ AWSDecryptionError "Unable to decrypt value"
            -- decodeUtf8' turns invalid UTF-8 into a ConfCryptError instead of a pure exception
            Just plainBytes ->
                either (throwError . AWSDecryptionError . T.pack . show) pure $ T.decodeUtf8' plainBytes

-- | Encryption through AWS KMS, using the KMS key configured in the 'AWSCtx'.
--
-- The empty string encrypts to itself, matching the RSA instances. KMS itself rejects empty
-- plaintext (NOTE(review): per the KMS Encrypt API length constraints -- confirm). Otherwise
-- the UTF-8 bytes are sent to KMS and the returned ciphertext blob is base64-encoded.
-- A response without a ciphertext blob is raised as 'AWSEncryptionError'.
instance MonadEncrypt (ConfCryptM IO (RemoteKey AWSCtx)) (RemoteKey AWSCtx) where
    encryptValue _ "" = pure ""
    encryptValue (RemoteKey AWSCtx {env, kmsKey}) rawValue = AWS.runAWST env $ do
        encryptResponse <- AWS.send . AWS.encrypt (keyId kmsKey) $ T.encodeUtf8 rawValue
        -- Wrap the ciphertext blob in base64 so it can be stored as text
        maybe (throwError $ AWSEncryptionError "Unable to encrypt value")
              (pure . wrapBytes)
              (view AWS.ersCiphertextBlob encryptResponse)

-- | Decode base64 text back into raw bytes. Malformed input yields the base64 decoder's
-- error message as 'T.Text'.
unwrapBytes :: T.Text -> Either T.Text BS.ByteString
unwrapBytes encoded =
    case B64.decode (T.encodeUtf8 encoded) of
        Left decodeErr -> Left (T.pack decodeErr)
        Right bytes -> Right bytes

-- | Apply a function to the 'Left' side of an 'Either', leaving a 'Right' untouched.
-- Equivalent to 'Data.Bifunctor.first' specialised to 'Either'.
lMap :: (a -> b) -> Either a r -> Either b r
lMap f = either (Left . f) Right

-- | Base64-encode raw bytes as 'T.Text'. Base64 output is plain ASCII, so the UTF-8 decode
-- below cannot fail.
wrapBytes :: BS.ByteString -> T.Text
wrapBytes raw = T.decodeUtf8 (B64.encode raw)