-- Verify.hs: OpenPGP (RFC4880) signature verification
-- Copyright © 2012  Clint Adams
-- This software is released under the terms of the ISC license.
-- (See the LICENSE file).

module Codec.Encryption.OpenPGP.Verify (
   verifySig
 , verify
 , verifyTK
) where

import Control.Monad (guard, liftM2)

import qualified Crypto.Cipher.DSA as DSA
import qualified Crypto.Cipher.RSA as RSA

import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Data.Either (lefts, rights)
import qualified Data.Map as Map
import qualified Data.Set as Set
import Data.Time.Clock (UTCTime(..), diffUTCTime)
import Data.Time.Clock.POSIX (posixSecondsToUTCTime)
import Data.Serialize.Put (runPut)

import Codec.Encryption.OpenPGP.Fingerprint (eightOctetKeyID, fingerprint)
import Codec.Encryption.OpenPGP.Internal (countBits, integerToBEBS, PktStreamContext(..), hash, issuer, emptyPSC, asn1Prefix)
import Codec.Encryption.OpenPGP.SerializeForSigs (putPartialSigforSigning, putSigTrailer, payloadForSig)
import Codec.Encryption.OpenPGP.Types

-- | Verify a V4 signature packet in the context of a packet stream.
--
-- The data to be checked is assembled by 'payloadForSig' from the signature
-- type and the supplied stream context, then handed to 'verify'.  After a
-- successful verification every 'Issuer' subpacket found among the
-- signature's hashed subpackets (fourth 'SigV4' field) must name the
-- eight-octet key ID of the key that actually verified the signature;
-- otherwise the result is @Left "issuer subpacket does not match"@.
--
-- Anything that is not a 'SignaturePkt' wrapping a 'SigV4' is rejected
-- with a 'Left'.
verifySig :: Keyring -> Pkt -> PktStreamContext -> Maybe UTCTime -> Either String Verification -- FIXME: check expiration here?
verifySig kr sig@(SignaturePkt (SigV4 st _ _ hs _ _ _)) state mt =
    verify kr sig mt (payloadForSig st state) >>= \v ->
        let signerId = eightOctetKeyID (verificationSigner v)
        in mapM_ (issuerMatches signerId . sspPayload) hs >> return v
    where
        -- Only Issuer subpackets are constrained; every other payload passes.
        issuerMatches :: EightOctetKeyId -> SigSubPacketPayload -> Either String Bool
        issuerMatches signer payload = case payload of
            Issuer i | signer /= i -> Left "issuer subpacket does not match"
            _ -> Right True
verifySig _ _ _ _ = Left "This should never happen."

-- | Validate a transferable key against the keyring and return a pruned copy.
--
-- The 'tkPKP' and 'tkmSKA' fields are copied unchanged.  Of the remaining
-- components only the following survive:
--
--   * signatures from 'tkRevs' that verify and are either direct-key
--     signatures or key revocation signatures;
--
--   * user IDs and user attributes that have at least one verifying
--     signature (and only their verifying signatures);
--
--   * subkeys whose binding signature verifies and whose accompanying
--     signature in the third tuple slot, if present, does not verify.  The
--     third slot of every returned subkey tuple is 'Nothing'.
--
-- Returns @Left "Key revoked"@ when a verifying key revocation signature was
-- issued either by the primary key itself or by a designated revoker (see
-- RFC 4880 section 5.2.1: both are valid issuers of a key revocation).
--
-- The 'Maybe' 'UTCTime' is passed through to 'verifySig' for every check.
verifyTK :: Keyring -> Maybe UTCTime -> TK -> Either String TK
verifyTK kr mt key = do
    revokers <- checkRevokers key
    revs <- checkKeyRevocations revokers key
    -- Drop user IDs / attributes that are left without any verifying signature.
    let uids = filter (not . null . snd) . checkUidSigs $ tkUIDs key -- FIXME: check revocations here?
    let uats = filter (not . null . snd) . checkUAtSigs $ tkUAts key -- FIXME: check revocations here?
    let subs = concatMap checkSub $ tkSubs key -- FIXME: check revocations here?
    return (TK (tkPKP key) (tkmSKA key) revs uids uats subs)
    where
        -- Designated revokers announced by verifying direct-key signatures.
        -- Signatures that fail to verify are silently skipped, so this
        -- never produces a Left.
        checkRevokers :: TK -> Either String [(PubKeyAlgorithm, TwentyOctetFingerprint)]
        checkRevokers = Right . concat . rights . map verifyRevoker . filter isRevokerP . tkRevs

        -- Pair each verifying signature with its verification result, then
        -- classify it with 'filterRevs'; the first Left aborts the whole check.
        -- (liftM2 fmap (,) f) sp  ==  fmap ((,) sp) (f sp)
        checkKeyRevocations :: [(PubKeyAlgorithm, TwentyOctetFingerprint)] -> TK -> Either String [SignaturePayload]
        checkKeyRevocations rs k = Prelude.sequence . concatMap (filterRevs rs) . rights . map (liftM2 fmap (,) (vSig kr)) . tkRevs $ k

        -- Narrow each signature list down to the signatures that verify.
        checkUidSigs :: [(String, [SignaturePayload])] -> [(String, [SignaturePayload])]
        checkUidSigs = map (\(uid, sps) -> (uid, filter (verifies . vUid kr . (,) uid) sps))
        checkUAtSigs :: [([UserAttrSubPacket], [SignaturePayload])] -> [([UserAttrSubPacket], [SignaturePayload])]
        checkUAtSigs = map (\(uat, sps) -> (uat, filter (verifies . vUAt kr . (,) uat) sps))

        -- A subkey yields zero or one result: nothing when its optional
        -- second signature verifies (treated as a revocation) or when its
        -- binding signature does not verify.
        checkSub :: (Pkt, SignaturePayload, Maybe SignaturePayload) -> [(Pkt, SignaturePayload, Maybe SignaturePayload)]
        checkSub (pkt, sp, mrp) = if revokedSub pkt mrp then [] else checkSub' pkt sp
        revokedSub :: Pkt -> Maybe SignaturePayload -> Bool
        revokedSub _ Nothing = False
        revokedSub p (Just rp) = vSubSig kr p rp
        checkSub' :: Pkt -> SignaturePayload -> [(Pkt, SignaturePayload, Maybe SignaturePayload)]
        checkSub' p sp = guard (vSubSig kr p sp) >> return (p, sp, Nothing)

        getHasheds (SigV4 _ _ _ ha _ _ _) = ha
        getHasheds _ = []

        -- Classify one verified signature from 'tkRevs':
        --   direct-key signature              -> kept
        --   revocation by the key or revoker  -> Left "Key revoked"
        --   revocation by anyone else         -> kept (it does not revoke)
        --   any other signature               -> dropped
        filterRevs :: [(PubKeyAlgorithm, TwentyOctetFingerprint)] -> (SignaturePayload, Verification) -> [Either String SignaturePayload]
        filterRevs vokers spv = case spv of
            (s@(SigV4 SignatureDirectlyOnAKey _ _ _ _ _ _), _) -> [Right s]
            (s@(SigV4 KeyRevocationSig pka _ _ _ _ _), v)
                | selfIssued v || revokerIssued vokers pka v -> [Left "Key revoked"]
                | otherwise -> [Right s]
            _ -> []
        -- The revocation was made by the primary key of the key under test.
        selfIssued v = fingerprint (verificationSigner v) == fingerprint (tkPKP key)
        -- The revocation was made by one of the designated revokers, using
        -- the algorithm announced for that revoker.
        revokerIssued vokers pka v = any (\(p, f) -> p == pka && f == fingerprint (verificationSigner v)) vokers

        -- A direct-key signature that carries a RevocationKey subpacket in
        -- its fourth field (hashed subpackets) and an Issuer subpacket in
        -- its fifth field (presumably the unhashed subpackets).
        isRevokerP (SigV4 SignatureDirectlyOnAKey _ _ h u _ _) = any isRevocationKeySSP h && any isIssuerSSP u
        isRevokerP _ = False
        isRevocationKeySSP (SigSubPacket _ (RevocationKey {})) = True
        isRevocationKeySSP _ = False
        isIssuerSSP (SigSubPacket _ (Issuer _)) = True
        isIssuerSSP _ = False

        -- Stream context shared by every check: the key under test is the
        -- most recent primary key.
        primaryCtx = emptyPSC { lastPrimaryKey = PublicKeyPkt (tkPKP key) }

        vUid :: Keyring -> (String, SignaturePayload) -> Either String Verification
        vUid keyring (uid, sp) = verifySig keyring (SignaturePkt sp) (primaryCtx { lastUIDorUAt = UserIdPkt uid }) mt
        vUAt :: Keyring -> ([UserAttrSubPacket], SignaturePayload) -> Either String Verification
        vUAt keyring (uat, sp) = verifySig keyring (SignaturePkt sp) (primaryCtx { lastUIDorUAt = UserAttributePkt uat }) mt
        vSig :: Keyring -> SignaturePayload -> Either String Verification
        vSig keyring sp = verifySig keyring (SignaturePkt sp) primaryCtx mt
        vSubSig :: Keyring -> Pkt -> SignaturePayload -> Bool
        vSubSig keyring sk sp = verifies (verifySig keyring (SignaturePkt sp) (primaryCtx { lastSubkey = sk }) mt)

        -- True exactly for a successful verification.
        verifies r = case r of
            Left _ -> False
            Right _ -> True

        -- Algorithm/fingerprint pairs from the RevocationKey subpackets of a
        -- signature, provided the signature itself verifies.
        verifyRevoker :: SignaturePayload -> Either String [(PubKeyAlgorithm, TwentyOctetFingerprint)]
        verifyRevoker sp = do
            _ <- vSig kr sp
            return [(pka, fp) | SigSubPacket _ (RevocationKey _ pka fp) <- getHasheds sp]

-- | Verify a signature packet over an already assembled payload.
--
-- The issuer reported by 'issuer' is looked up in the keyring; every primary
-- key and subkey of the matching transferable keys whose eight-octet key ID
-- equals the issuer is tried.  Exactly one key must verify the signature:
--
--   * no success: 'Left' with the individual failure messages, each followed
--     by @"/"@;
--
--   * one success: the signature expiration is checked (only when a time is
--     supplied) and a 'Verification' carrying the verifying key and the
--     signature payload is returned;
--
--   * several successes: @Left "multiple successes; unexpected condition"@.
--
-- Only V4 signatures checked against V4 DSA or RSA public keys are
-- supported.  Unsupported versions, unsupported algorithms and signatures
-- with too few MPIs are reported as 'Left' values rather than raising an
-- exception.
verify :: Keyring -> Pkt -> Maybe UTCTime -> ByteString -> Either String Verification
verify kr sig mt payload = do
    i <- maybe (Left "issuer not found") Right (issuer sig)
    tpkset <- maybe (Left "pubkey not found") Right (Map.lookup i kr)
    let allrelevantpkps = filter ((== i) . eightOctetKeyID) (concatMap candidatePKPs (Set.toAscList tpkset))
    let results = map (\pkp -> verify' sig pkp (finalPayload sig payload)) allrelevantpkps
    case rights results of
        [] -> Left (concatMap (++"/") (lefts results))
        [r] -> do
            _ <- isSignatureExpired sig mt
            return (Verification r ((signaturePayload . fromPkt) sig)) -- FIXME: this should also check expiration time and flags of the signing key
        _ -> Left "multiple successes; unexpected condition"
    where
        -- Primary key followed by every subkey of a transferable key.
        candidatePKPs tk = tkPKP tk : concatMap subPKP (tkSubs tk)
        -- Entries that are not subkey packets contribute no candidate.
        subPKP (PublicSubkeyPkt p, _, _) = [p]
        subPKP (SecretSubkeyPkt p _, _, _) = [p]
        subPKP _ = []

        -- Dispatch on signature and key version; the hash algorithm, public
        -- key algorithm and MPIs are taken straight from the V4 signature.
        verify' (SignaturePkt (SigV4 _ pka ha _ _ _ mpis)) pub@(PubV4 _ _ pkey) pl = verify'' pka mpis ha pub pkey pl
        verify' _ _ _ = Left "unsupported signature or key version"

        -- DSA needs two MPIs, RSA one; a shorter list is a malformed
        -- signature, not a programming error.
        verify'' DSA (r:s:_) ha pub (DSAPubKey pkey) bs = verdict pub (DSA.verify (unMPI r, unMPI s) (dsaTruncate pkey . hash ha) pkey bs)
        verify'' RSA (m:_) ha pub (RSAPubKey pkey) bs = verdict pub (RSA.verify (hash ha) (asn1Prefix ha) pkey bs (integerToBEBS (unMPI m)))
        verify'' DSA _ _ _ (DSAPubKey _) _ = Left "invalid signature"
        verify'' RSA _ _ _ (RSAPubKey _) _ = Left "invalid signature"
        verify'' _ _ _ _ _ _ = Left "unimplemented key type"

        -- Translate the crypto library's result into our own error strings.
        verdict pub outcome = case outcome of
            Left _ -> Left "invalid signature"
            Right False -> Left "verification failed"
            Right True -> Right pub

        -- What actually gets hashed: payload, partial signature, trailer.
        finalPayload s pl = B.concat [pl, sigbit s, trailer s]
        sigbit s = runPut $ putPartialSigforSigning s
        trailer :: Pkt -> ByteString
        trailer s@(SignaturePkt (SigV4 {})) = runPut $ putSigTrailer s
        trailer _ = B.empty

        -- Shorten a digest that has more bits than the third DSA public
        -- parameter to that parameter's length in whole octets.
        dsaTruncate pkey bs = if countBits bs > dsaQLen pkey then B.take (fromIntegral (dsaQLen pkey) `div` 8) bs else bs -- FIXME: uneven bits
        dsaQLen pk = (\(_,_,z) -> countBits (integerToBEBS z)) (DSA.public_params pk)

        -- Without a reference time nothing is checked.
        isSignatureExpired :: Pkt -> Maybe UTCTime -> Either String Bool
        isSignatureExpired _ Nothing = return False
        isSignatureExpired s (Just t)
            | any (expiredBefore t) (hashedSubpackets (signaturePayload (fromPkt s))) = Left "signature expired"
            | otherwise = return True
        hashedSubpackets (SigV4 _ _ _ h _ _ _) = h
        hashedSubpackets _ = []

        -- The subpacket value is read as a count of seconds and compared as
        -- a NominalDiffTime.  Going through the Enum instances of the time
        -- types instead would count picoseconds and overflow Int for gaps of
        -- more than a few months.
        -- Assumes 'fromEnum' of the subpacket value yields whole seconds -- TODO confirm.
        -- NOTE(review): RFC 4880 section 5.2.3.10 defines this value as
        -- seconds after the signature creation time; here it is still
        -- treated as seconds since the epoch.  Making it relative needs the
        -- creation-time subpacket.
        expiredBefore :: UTCTime -> SigSubPacket -> Bool
        expiredBefore ct (SigSubPacket _ (SigExpirationTime et)) =
            (posixSecondsToUTCTime (fromIntegral (fromEnum et)) `diffUTCTime` ct) < 0
        expiredBefore _ _ = False