-- SecretKey.hs: OpenPGP (RFC4880) secret key decryption
-- Copyright © 2013  Clint Adams
-- This software is released under the terms of the Expat license.
-- (See the LICENSE file).

module Codec.Encryption.OpenPGP.SecretKey (
   decryptPrivateKey
 , encryptPrivateKey
 , encryptPrivateKeyIO
 , reencryptSecretKeyIO
) where

import Codec.Encryption.OpenPGP.Types
import Codec.Encryption.OpenPGP.BlockCipher (saBlockSize, keySize)
import Codec.Encryption.OpenPGP.CFB (decryptNoNonce, encryptNoNonce)
import Codec.Encryption.OpenPGP.Serialize (getSecretKey)
import Codec.Encryption.OpenPGP.S2K (skesk2Key, string2Key)
import Control.Monad ((>=>))
import qualified Crypto.Hash.SHA1 as SHA1
import Crypto.Random (createEntropyPool, cprgCreate, cprgGenerateWithEntropy, SystemRNG)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import Data.Serialize (runGet, runPut, put)
import Data.Serialize.Get (getBytes, remaining, getWord16be)
import qualified Crypto.PubKey.RSA as R

-- | Decrypt the secret-key material of an addendum using a passphrase.
--
-- 'SUUnencrypted' addenda are returned unchanged. 'SUS16bit' and 'SUSSHA1'
-- addenda are decrypted via 'decryptSKA'. If that fails, 'error' is called
-- with a message that includes the underlying reason. 'SUSym' decryption
-- is not implemented and always calls 'error'.
decryptPrivateKey :: (PKPayload, SKAddendum) -> BL.ByteString -> SKAddendum
decryptPrivateKey (pkp, ska@(SUS16bit {})) pp =
    either (\e -> error ("could not decrypt SUS16bit: " ++ e)) id (decryptSKA (pkp, ska) pp)
decryptPrivateKey (pkp, ska@(SUSSHA1 {})) pp =
    either (\e -> error ("could not decrypt SUSSHA1: " ++ e)) id (decryptSKA (pkp, ska) pp)
decryptPrivateKey (_, SUSym {}) _ = error "SUSym key decryption not implemented"
decryptPrivateKey (_, ska@(SUUnencrypted {})) _ = ska

-- | Decrypt an encrypted secret-key addendum with a passphrase.
--
-- The passphrase is turned into a symmetric key via the addendum's S2K and
-- the payload is decrypted. The plaintext is then parsed as the
-- algorithm-specific secret-key fields followed by an integrity check
-- (RFC 4880 section 5.5.3):
--
--   * 'SUS16bit': a two-octet sum of the plaintext key octets, mod 65536;
--   * 'SUSSHA1': a 20-octet SHA-1 hash of the plaintext key octets.
--
-- The check is verified before the key is returned. A wrong passphrase
-- whose output happens to parse is therefore reported as a 'Left' rather
-- than yielding a bogus key. The resulting 'SUUnencrypted' addendum carries
-- the 16-bit octet-sum checksum of the key material.
decryptSKA :: (PKPayload, SKAddendum) -> BL.ByteString -> Either String SKAddendum
decryptSKA (pkp, SUS16bit sa s2k iv payload) pp = do
    let key = skesk2Key (SKESK 4 sa s2k Nothing) pp
    p <- decryptNoNonce sa iv payload key
    -- Note how many octets remain after the key fields, so the checksum is
    -- recomputed over exactly the octets getSecretKey consumed.
    (s, left, cksum) <- runGet (getSecretKey pkp >>= \sk -> remaining >>= \r -> getWord16be >>= \csum -> return (sk, r, csum)) p
    -- Comparing against the Word16 cksum makes this sum wrap mod 65536.
    let checksum = sum . map fromIntegral . B.unpack . B.take (B.length p - left) $ p
    if checksum == cksum
        then return (SUUnencrypted s checksum)
        else Left "SUS16bit checksum mismatch"
decryptSKA (pkp, SUSSHA1 sa s2k iv payload) pp = do
    let key = skesk2Key (SKESK 4 sa s2k Nothing) pp
    p <- decryptNoNonce sa iv payload key
    -- Everything after the key fields is taken as the stored SHA-1 digest.
    (s, digest) <- runGet (getSecretKey pkp >>= \sk -> remaining >>= (getBytes >=> \csum -> return (sk, csum))) p
    let keyOctets = B.take (B.length p - B.length digest) p
        checksum = sum . map fromIntegral . B.unpack $ keyOctets
    if SHA1.hash keyOctets == digest
        then return (SUUnencrypted s checksum)
        else Left "SUSSHA1 hash mismatch"
decryptSKA _ _ = Left "Unexpected codepath"

-- | Encrypt a secret-key addendum under a passphrase, like
-- 'encryptPrivateKey', but generate the salt and IV here instead of taking
-- them as arguments.
--
-- A single draw of @8 + saBlockSize AES256@ pseudo-random octets is split
-- into an 8-octet salt and a one-block IV. The IV size matches the AES256
-- cipher that 'encryptPrivateKey' uses.
encryptPrivateKeyIO :: SKAddendum -> BL.ByteString -> IO SKAddendum
encryptPrivateKeyIO ska pp = do
    (salt, iv) <- saltiv
    return (encryptPrivateKey salt iv ska pp)
  where
    saltiv = do
        ep <- createEntropyPool
        let gen = cprgCreate ep :: SystemRNG
            bytes = fst (cprgGenerateWithEntropy (8 + saBlockSize AES256) gen)
        return (B.splitAt 8 bytes)

-- | Encrypt an unencrypted secret-key addendum with a passphrase.
--
-- The result is a 'SUSSHA1' addendum using AES256 and an iterated, salted
-- SHA512 S2K. The salt should be 8 octets and the IV must be one AES256
-- block long. Addenda that are already encrypted ('SUS16bit', 'SUSSHA1',
-- 'SUSym') are returned unchanged, whatever the passphrase.
encryptPrivateKey :: B.ByteString -> IV -> SKAddendum -> BL.ByteString -> SKAddendum
encryptPrivateKey salt iv ska pp =
    case ska of
        SUUnencrypted skey _ -> SUSSHA1 AES256 s2k iv (encryptSKey skey s2k iv pp)
        SUS16bit {} -> ska
        SUSSHA1 {} -> ska
        SUSym {} -> ska
  where
    -- NOTE(review): 12058624 is passed straight to IteratedSalted as its
    -- count; presumably the decoded RFC 4880 octet count. Confirm.
    s2k = IteratedSalted SHA512 salt 12058624

-- | Encrypt RSA secret-key material for a 'SUSSHA1' addendum.
--
-- Serializes the MPIs d, p, q and u, appends their SHA-1 hash, and
-- encrypts the result with AES256 via 'encryptNoNonce'. The AES256 key is
-- derived from the passphrase with the given S2K. Calls 'error' if
-- encryption fails or the key is not RSA.
encryptSKey :: SKey -> S2K -> IV -> BL.ByteString -> B.ByteString
encryptSKey (RSAPrivateKey (R.PrivateKey _ d p q _ _ _)) s2k iv pp =
    either error id (encryptNoNonce AES256 s2k iv payload key)
  where
    key = string2Key s2k (keySize AES256) pp
    -- RFC 4880 orders the RSA secret MPIs as d, p, q, u.
    algospecific = runPut $ put (MPI d) >> put (MPI p) >> put (MPI q) >> put (MPI u)
    -- The SHA-1 hash of the plaintext MPIs is appended as the integrity check.
    cksum = SHA1.hash algospecific
    payload = algospecific `B.append` cksum
    -- u is the multiplicative inverse of p modulo q.
    u = inverse q p
encryptSKey _ _ _ _ = error "Non-RSA keytypes not handled yet" -- FIXME: do DSA and ElGamal

-- | @inverse m x@ is the multiplicative inverse of @x@ modulo @m@, i.e. the
-- @y@ with @x * y `mod` m == 1@, assuming @x@ and @m@ are coprime.
--
-- Recursion on @(x, m `mod` x)@ mirrors Euclid's algorithm. Non-coprime
-- inputs eventually divide by zero.
inverse :: Integral a => a -> a -> a
inverse modulus x
    | x == 1 = 1
    | otherwise = (k * modulus + 1) `div` x
  where
    k = x - inverse x (modulus `mod` x)

-- | Run a secret key's addendum through 'encryptPrivateKeyIO' with the
-- given passphrase and store the result back in the key.
--
-- Only an unencrypted addendum is actually encrypted. One that is already
-- encrypted comes back unchanged, so callers presumably need to decrypt
-- first to change the passphrase.
reencryptSecretKeyIO :: SecretKey -> BL.ByteString -> IO SecretKey
reencryptSecretKeyIO sk pp = do
    newAddendum <- encryptPrivateKeyIO (_secretKeySKAddendum sk) pp
    return sk { _secretKeySKAddendum = newAddendum }