-- Signatures.hs: OpenPGP (RFC4880) signature verification
-- Copyright © 2012-2014  Clint Adams
-- This software is released under the terms of the Expat license.
-- (See the LICENSE file).

module Codec.Encryption.OpenPGP.Signatures (
   verifySigWith
 , verifyAgainstKeyring
 , verifyAgainstKeys
 , verifyTKWith
) where

import Control.Lens ((^.), _1)
import Control.Monad (liftM2)

import Crypto.PubKey.HashDescr (HashDescr(..))
import qualified Crypto.PubKey.DSA as DSA
import qualified Crypto.PubKey.RSA.PKCS15 as P15

import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Data.Either (lefts, rights)
import Data.IxSet ((@=))
import qualified Data.IxSet as IxSet
import Data.Time.Clock (UTCTime(..), diffUTCTime)
import Data.Time.Clock.POSIX (posixSecondsToUTCTime)
import Data.Serialize.Put (runPut)

import Codec.Encryption.OpenPGP.Fingerprint (eightOctetKeyID, fingerprint)
import Codec.Encryption.OpenPGP.Internal (countBits, integerToBEBS, PktStreamContext(..), issuer, emptyPSC, hashDescr)
import Codec.Encryption.OpenPGP.SerializeForSigs (putPartialSigforSigning, putSigTrailer, payloadForSig)
import Codec.Encryption.OpenPGP.Types
import Data.Conduit.OpenPGP.Keyring.Instances ()

-- | Verify a V4 signature packet with the supplied verification function
-- and then cross-check the result against the signature's own subpackets.
--
-- @vf@ receives the signature, the optional time @mt@ and the payload
-- computed by 'payloadForSig' from the signature type and the packet-stream
-- context.  If it succeeds, every 'Issuer' subpacket found in the hashed
-- subpacket area must equal the eight-octet key ID of the key that verified
-- the signature; a mismatch yields 'Left'.  Anything that is not a V4
-- signature packet is rejected outright.
verifySigWith :: (Pkt -> Maybe UTCTime -> ByteString -> Either String Verification) -> Pkt -> PktStreamContext -> Maybe UTCTime -> Either String Verification -- FIXME: check expiration here?
verifySigWith vf sig@(SignaturePkt (SigV4 sigtype _ _ hashed _ _ _)) state mt =
    vf sig mt (payloadForSig sigtype state) >>= \verification ->
        let signerId = eightOctetKeyID (verification^.verificationSigner)
        in mapM_ (issuerMatches signerId . _sspPayload) hashed >> return verification
    where
        -- Only Issuer subpackets are constrained; every other subpacket passes.
        issuerMatches :: EightOctetKeyId -> SigSubPacketPayload -> Either String Bool
        issuerMatches signer (Issuer i)
            | signer == i = Right True
            | otherwise = Left "issuer subpacket does not match"
        issuerMatches _ _ = Right True
verifySigWith _ _ _ _ = Left "This should never happen (verifySigWith)."

-- | Re-validate a transferable key, keeping only the parts whose signatures
-- verify.
--
-- @vsf@ is the signature verifier to use; @mt@ is handed to it unchanged.
-- The result is @Left "Key revoked"@ when a key revocation signature
-- verifies as having been made by the primary key itself or by a designated
-- revoker.  Otherwise the result is a copy of @key@ in which
--
-- * the key-level signatures keep only the verifiable direct-key signatures
--   and the verifiable key revocations that did not trigger the rejection
--   above,
--
-- * user IDs and user attributes keep only their verifiable signatures and
--   are dropped when none remain,
--
-- * a subkey is dropped when one of its subkey revocations verifies or when
--   none of its subkey binding signatures verifies; a surviving subkey keeps
--   only its verified binding signatures.
verifyTKWith :: (Pkt -> PktStreamContext -> Maybe UTCTime -> Either String Verification) -> Maybe UTCTime -> TK -> Either String TK
verifyTKWith vsf mt key = do
    revokers <- checkRevokers key
    revs <- checkKeyRevocations revokers key
    let uids = filter (not . null . snd) . checkUidSigs $ key^.tkUIDs -- FIXME: check revocations here?
    let uats = filter (not . null . snd) . checkUAtSigs $ key^.tkUAts -- FIXME: check revocations here?
    let subs = concatMap checkSub $ key^.tkSubs -- FIXME: check revocations here?
    return (TK (key^.tkKey) revs uids uats subs)
    where
        -- Designated revokers named by verifiable direct-key signatures.
        -- Signatures that fail to verify are skipped, so this never fails.
        checkRevokers :: TK -> Either String [(PubKeyAlgorithm, TwentyOctetFingerprint)]
        checkRevokers = Right . concat . rights . map verifyRevoker . filter isRevokerP . _tkRevs

        -- Pair every key-level signature with its verification result
        -- (in the function monad, @liftM2 fmap (,) vSig@ is
        -- @\\sp -> fmap ((,) sp) (vSig sp)@), discard the ones that do not
        -- verify, and let 'filterRevs' decide what to keep.  The first
        -- 'Left' produced by 'filterRevs' aborts the whole computation.
        checkKeyRevocations :: [(PubKeyAlgorithm, TwentyOctetFingerprint)] -> TK -> Either String [SignaturePayload]
        checkKeyRevocations rs k = Prelude.sequence . concatMap (filterRevs rs) . rights . map (liftM2 fmap (,) vSig) $ k^.tkRevs

        -- Keep, per user ID, only the signatures that verify.
        checkUidSigs :: [(String, [SignaturePayload])] -> [(String, [SignaturePayload])]
        checkUidSigs = map (\(uid, sps) -> (uid, [ sp | sp <- sps, verifies (vUid (uid, sp)) ]))

        -- Keep, per user attribute, only the signatures that verify.
        checkUAtSigs :: [([UserAttrSubPacket], [SignaturePayload])] -> [([UserAttrSubPacket], [SignaturePayload])]
        checkUAtSigs = map (\(uat, sps) -> (uat, [ sp | sp <- sps, verifies (vUAt (uat, sp)) ]))

        -- A subkey survives only if it has no verifiable revocation and at
        -- least one verifiable binding signature.
        checkSub :: (Pkt, [SignaturePayload]) -> [(Pkt, [SignaturePayload])]
        checkSub (pkt, sps)
            | any (vSubSig pkt) (filter isSubkeyRevocation sps) = []
            | null goodsigs = []
            | otherwise = [(pkt, goodsigs)]
            where
                goodsigs = filter (vSubSig pkt) (filter isSubkeyBindingSig sps)

        -- Hashed subpackets of a V4 signature; nothing for other versions.
        getHasheds (SigV4 _ _ _ ha _ _ _) = ha
        getHasheds _ = []

        -- Decide the fate of one verified key-level signature: direct-key
        -- signatures are kept; a key revocation made by the primary key or
        -- by a designated revoker (same algorithm and fingerprint) rejects
        -- the key; any other key revocation is kept; everything else is
        -- dropped.
        filterRevs :: [(PubKeyAlgorithm, TwentyOctetFingerprint)] -> (SignaturePayload, Verification) -> [Either String SignaturePayload]
        filterRevs _ (s@(SigV4 SignatureDirectlyOnAKey _ _ _ _ _ _), _) = [Right s]
        filterRevs vokers (s@(SigV4 KeyRevocationSig pka _ _ _ _ _), v)
            | selfIssued || byRevoker = [Left "Key revoked"]
            | otherwise = [Right s]
            where
                signer = v^.verificationSigner
                selfIssued = signer == key ^. tkKey._1
                byRevoker = any (\(p, f) -> p == pka && f == fingerprint signer) vokers
        filterRevs _ _ = []

        isSubkeyRevocation (SigV4 SubkeyRevocationSig _ _ _ _ _ _) = True
        isSubkeyRevocation _ = False
        isSubkeyBindingSig (SigV4 SubkeyBindingSig _ _ _ _ _ _) = True
        isSubkeyBindingSig _ = False

        -- A direct-key signature names a revoker when it carries a
        -- RevocationKey subpacket in its hashed area and an Issuer
        -- subpacket in its unhashed area.
        isRevokerP (SigV4 SignatureDirectlyOnAKey _ _ h u _ _) = any isRevocationKeySSP h && any isIssuerSSP u
        isRevokerP _ = False
        isRevocationKeySSP (SigSubPacket _ (RevocationKey {})) = True
        isRevocationKeySSP _ = False
        isIssuerSSP (SigSubPacket _ (Issuer _)) = True
        isIssuerSSP _ = False

        verifies :: Either String Verification -> Bool
        verifies = either (const False) (const True)

        -- Packet-stream context that every verification below starts from:
        -- only the primary key is set.
        primaryCtx = emptyPSC { lastPrimaryKey = PublicKeyPkt (key ^. tkKey._1) }

        vUid :: (String, SignaturePayload) -> Either String Verification
        vUid (uid, sp) = vsf (SignaturePkt sp) primaryCtx { lastUIDorUAt = UserIdPkt uid } mt
        vUAt :: ([UserAttrSubPacket], SignaturePayload) -> Either String Verification
        vUAt (uat, sp) = vsf (SignaturePkt sp) primaryCtx { lastUIDorUAt = UserAttributePkt uat } mt
        vSig :: SignaturePayload -> Either String Verification
        vSig sp = vsf (SignaturePkt sp) primaryCtx mt
        vSubSig :: Pkt -> SignaturePayload -> Bool
        vSubSig sk sp = verifies (vsf (SignaturePkt sp) primaryCtx { lastSubkey = sk } mt)

        -- Revokers listed in the hashed area of a signature, provided the
        -- signature itself verifies.  The comprehension skips subpackets
        -- that are not RevocationKey ones instead of relying on a partial
        -- lambda pattern.
        verifyRevoker :: SignaturePayload -> Either String [(PubKeyAlgorithm, TwentyOctetFingerprint)]
        verifyRevoker sp = do
            _ <- vSig sp
            return [ (pka, fp) | SigSubPacket _ (RevocationKey _ pka fp) <- getHasheds sp ]

-- | Verify a signature packet against the keys of a keyring.
--
-- The signature's issuer is looked up in the keyring and all keys indexed
-- under it are handed to 'verifyAgainstKeys'.  Fails with
-- @"issuer not found"@ when the signature names no issuer and with
-- @"pubkey not found"@ when the keyring has no entry for it.
verifyAgainstKeyring :: Keyring -> Pkt -> Maybe UTCTime -> ByteString -> Either String Verification
verifyAgainstKeyring kr sig mt payload =
    case issuer sig of
        Nothing -> Left "issuer not found"
        Just keyId ->
            let candidates = kr @= keyId
            in if IxSet.null candidates
                   then Left "pubkey not found"
                   else verifyAgainstKeys (IxSet.toList candidates) sig mt payload

-- | Verify a signature packet against a list of candidate transferable keys.
--
-- Every primary key and subkey in @ks@ whose eight-octet key ID equals the
-- signature's issuer is tried against @payload@ (extended by
-- 'finalPayload').  Exactly one candidate has to verify:
--
-- * no success yields 'Left' with the per-key failure messages, each
--   followed by @"/"@ (an empty message when no key matched the issuer);
--
-- * more than one success yields 'Left'.
--
-- When @mt@ is 'Just' a time, a signature carrying an expiration subpacket
-- that lies before that time is rejected with @"signature expired"@.
-- Signatures or keys that are not version 4, unsupported algorithms and
-- malformed MPI lists are reported as 'Left' rather than raising an
-- exception.
verifyAgainstKeys :: [TK] -> Pkt -> Maybe UTCTime -> ByteString -> Either String Verification
verifyAgainstKeys ks sig mt payload = do
    let allrelevantpkps = filter (\x -> issuer sig == Just (eightOctetKeyID x)) (concatMap candidateKeys ks)
    let results = map (\pkp -> verifyWithKey sig pkp (finalPayload sig payload)) allrelevantpkps
    case rights results of
        [] -> Left (concatMap (++"/") (lefts results))
        [r] -> do
            _ <- isSignatureExpired sig mt
            return (Verification r ((_signaturePayload . fromPkt) sig)) -- FIXME: this should also check expiration time and flags of the signing key
        _ -> Left "multiple successes; unexpected condition"
    where
        -- The primary key followed by the key material of every subkey.
        candidateKeys tk = (tk ^. tkKey._1) : concatMap (subkeyPayload . fst) (_tkSubs tk)

        -- Key material of a subkey packet; any other packet contributes
        -- nothing instead of failing a pattern match.
        subkeyPayload (PublicSubkeyPkt p) = [p]
        subkeyPayload (SecretSubkeyPkt p _) = [p]
        subkeyPayload _ = []

        -- Only V4 signatures checked against V4 keys are supported.
        verifyWithKey (SignaturePkt (SigV4 _ pka ha _ _ _ mpis)) pub@(PKPayload V4 _ _ _ pkey) pl =
            verifyByAlgorithm pka mpis ha pub pkey pl
        verifyWithKey _ _ _ = Left "unsupported signature or public key version"

        -- Dispatch on the signature algorithm; it has to agree with the
        -- kind of public key material.
        verifyByAlgorithm DSA mpis ha pub (DSAPubKey pkey) bs = accept (dsaVerify mpis (hashDescr ha) pkey bs) pub
        verifyByAlgorithm RSA mpis ha pub (RSAPubKey pkey) bs = accept (rsaVerify mpis (hashDescr ha) pkey bs) pub
        verifyByAlgorithm _ _ _ _ _ _ = Left "unimplemented key type"

        accept ok pub = if ok then Right pub else Left "verification failed"

        -- A DSA signature is exactly two MPIs (r, s).  A wrong MPI count or
        -- an unusable hash algorithm counts as a failed verification.
        -- FIXME: this should be some sort of Either chain
        dsaVerify [r, s] (Right hd) pkey bs = DSA.verify (dsaTruncate pkey . hashFunction hd) pkey (DSA.Signature (unMPI r) (unMPI s)) bs
        dsaVerify _ _ _ _ = False

        -- An RSA signature is taken from the first MPI.  Matching on the
        -- list (rather than calling 'head') makes an empty MPI list a failed
        -- verification instead of a runtime exception.
        -- FIXME: this should be some sort of Either chain
        rsaVerify (m:_) (Right hd) pkey bs = P15.verify hd pkey bs (integerToBEBS (unMPI m))
        rsaVerify _ _ _ _ = False

        -- Shorten a digest that has more bits than the DSA parameter q.
        dsaTruncate pkey bs = if countBits bs > dsaQLen pkey then B.take (fromIntegral (dsaQLen pkey) `div` 8) bs else bs -- FIXME: uneven bits
        dsaQLen = countBits . integerToBEBS . DSA.params_q . DSA.public_params

        -- Without a reference time nothing is checked.
        isSignatureExpired :: Pkt -> Maybe UTCTime -> Either String Bool
        isSignatureExpired _ Nothing = return False
        isSignatureExpired s (Just t)
            | any (expiredBefore t) (hashedSubpackets s) = Left "signature expired"
            | otherwise = return True

        -- Hashed subpackets of a V4 signature packet; nothing otherwise.
        hashedSubpackets (SignaturePkt (SigV4 _ _ _ h _ _ _)) = h
        hashedSubpackets _ = []

        -- True when the expiration subpacket denotes a time before @ct@.
        -- The value is converted with 'fromIntegral' and the difference is
        -- compared as a NominalDiffTime: the Enum instance of
        -- NominalDiffTime counts picoseconds, so going through
        -- toEnum/fromEnum would both misread the timestamp and risk
        -- wrapping Int on differences of a few months.
        -- NOTE(review): the value is treated as an absolute POSIX timestamp
        -- in seconds; RFC 4880 defines signature expiration relative to the
        -- signature creation time -- confirm against the field's type.
        expiredBefore :: UTCTime -> SigSubPacket -> Bool
        expiredBefore ct (SigSubPacket _ (SigExpirationTime et)) =
            (posixSecondsToUTCTime (fromIntegral (fromEnum et)) `diffUTCTime` ct) < 0
        expiredBefore _ _ = False

-- | Assemble the byte string a signature is computed over: the caller's
-- payload, then the signature packet serialized by
-- 'putPartialSigforSigning', then, for V4 signature packets only, the
-- trailer written by 'putSigTrailer'.
finalPayload :: Pkt -> ByteString -> ByteString
finalPayload s pl = B.concat [pl, signedPart, trailer]
    where
        signedPart :: ByteString
        signedPart = runPut (putPartialSigforSigning s)
        -- Everything that is not a V4 signature packet gets no trailer.
        trailer :: ByteString
        trailer = case s of
            SignaturePkt (SigV4 {}) -> runPut (putSigTrailer s)
            _ -> B.empty