-- Verify.hs: OpenPGP (RFC4880) signature verification
-- Copyright Ⓒ 2012  Clint Adams
-- This software is released under the terms of the Expat (MIT) license.
-- (See the LICENSE file).

module Data.Conduit.OpenPGP.Verify (
   conduitVerify
) where

import qualified Crypto.Cipher.DSA as DSA
import qualified Crypto.Cipher.RSA as RSA

import qualified Crypto.Hash.RIPEMD160 as RIPEMD160
import qualified Crypto.Hash.SHA1 as SHA1
import qualified Crypto.Hash.SHA224 as SHA224
import qualified Crypto.Hash.SHA256 as SHA256
import qualified Crypto.Hash.SHA384 as SHA384
import qualified Crypto.Hash.SHA512 as SHA512

import qualified Data.ASN1.DER as DER
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as BC8
import qualified Data.ByteString.Lazy as BL
import Data.Conduit
import Data.List (find)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Serialize.Put (putByteString, putWord32be, putWord8, runPut)

import Codec.Encryption.OpenPGP.Fingerprint (eightOctetKeyID)
import Codec.Encryption.OpenPGP.Internal (countBits, integerToBEBS)
import Codec.Encryption.OpenPGP.SerializeForSigs (putPartialSigforSigning, putSigTrailer, putKeyforSigning, putSigforSigning)
import Codec.Encryption.OpenPGP.Types

-- | Packets remembered from earlier in the stream, which later signatures
-- may refer to.
--
-- Each field holds the most recent packet of its kind; until one has been
-- seen, it holds whatever placeholder the stream was started with.
data StreamState = StreamState
    { lastLD :: Packet          -- ^ most recent literal data packet
    , lastUIDorUAt :: Packet    -- ^ most recent User ID or User Attribute
    , lastSig :: Packet         -- ^ most recent V4 signature packet
    , lastPrimaryKey :: Packet  -- ^ most recent public or secret primary key
    , lastSubkey :: Packet      -- ^ most recent public or secret subkey
    }

-- | Stream of packets in, verification results out.
--
-- Literal data, User ID / User Attribute, key and subkey packets are
-- remembered (see 'StreamState'). Each V4 signature packet yields one
-- result from 'verifySig', checked against the packets remembered so far,
-- and then becomes the remembered signature. Every other packet is ignored.
--
-- FIXME: canonical-text signatures should have their line endings
-- normalized before hashing; that is not done yet.
conduitVerify :: MonadResource m => Keyring -> Conduit Packet m (Either String Bool)
conduitVerify kr = conduitState initial push close
    where
        blank = Marker B.empty
        initial = StreamState blank blank blank blank blank
        emit st outs = return (StateProducing st outs)
        push st pkt = case pkt of
            LiteralData _ _ _ _ -> emit st { lastLD = pkt } []
            UserId _ -> emit st { lastUIDorUAt = pkt } []
            UserAttribute _ -> emit st { lastUIDorUAt = pkt } []
            PublicKey _ -> emit st { lastPrimaryKey = pkt } []
            SecretKey _ _ -> emit st { lastPrimaryKey = pkt } []
            PublicSubkey _ -> emit st { lastSubkey = pkt } []
            SecretSubkey _ _ -> emit st { lastSubkey = pkt } []
            -- Verified against the state *before* this signature is recorded.
            Signature (SigV4 _ _ _ _ _ _ _) -> emit st { lastSig = pkt } [verifySig kr pkt st]
            _ -> emit st []
        close _ = return []

-- | Verify a V4 signature against the data it covers, using the packets
-- remembered in the stream state.
--
-- What is hashed depends on the signature type (RFC 4880 §5.2.4):
--
--   * binary and canonical-text signatures: the last literal data packet;
--   * standalone signatures: nothing;
--   * direct-key signatures: the last primary key;
--   * certifications (0x10-0x13) and certification revocations (0x30): the
--     last primary key followed by the last User ID (0xB4, a four-octet
--     big-endian length, then the User ID octets).
--
-- Signature types that are not implemented yet, or a certification with no
-- preceding primary key / User ID, yield a 'Left'.
verifySig kr sig@(Signature (SigV4 st _ _ _ _ _ _)) state =
    case st of
        BinarySig -> verify kr sig (lastLD state)
        CanonicalTextSig -> verify kr sig (lastLD state) -- FIXME: line endings are hashed as-is
        StandaloneSig -> verify kr sig (Marker B.empty)
        GenericCert -> certification
        PersonaCert -> certification
        CasualCert -> certification
        PositiveCert -> certification
        SubkeyBindingSig -> Left "FIXME: SubkeyBindingSig"
        PrimaryKeyBindingSig -> Left "FIXME: PrimaryKeyBindingSig"
        SignatureDirectlyOnAKey -> verify kr sig (lastPrimaryKey state)
        KeyRevocationSig -> Left "FIXME: KeyRevocationSig"
        SubkeyRevocationSig -> Left "FIXME: SubkeyRevocationSig"
        -- Hashes the same key + User ID data as the certification it revokes.
        CertRevocationSig -> certification
        _ -> Left ("I dunno how to " ++ show st)
    where
        -- Certifications cover the primary key *and* the User ID, not the
        -- User ID alone.
        certification = case (lastPrimaryKey state, lastUIDorUAt state) of
            (Marker _, _) -> Left "certification without a preceding primary key"
            (key, UserId uid) -> go kr sig (runPut (putKeyforSigning key >> putUserId uid))
            (_, UserAttribute _) -> Left "FIXME: UserAttribute certification"
            _ -> Left "certification without a preceding User ID"
        -- NOTE(review): BC8.pack keeps only the low 8 bits of each Char; this
        -- is right only if the parser built the String one Char per octet.
        putUserId uid = do
            let body = BC8.pack uid
            putWord8 0xB4
            putWord32be (fromIntegral (B.length body))
            putByteString body

-- | Verify @sig@ over the signing serialization of a single packet.
--
-- The packet is turned into the octets the signature covers and passed to
-- 'go':
--
--   * literal data: the literal data body;
--   * User ID: 0xB4, a four-octet big-endian length, then the User ID
--     octets (RFC 4880 §5.2.4);
--   * User Attribute: not serialized yet (empty payload);
--   * 'Marker': empty payload, used for standalone signatures;
--   * keys and subkeys: 'putKeyforSigning';
--   * signatures: 'putSigforSigning'.
--
-- Any other packet yields a 'Left'.
verify kr sig (LiteralData _ _ _ bs) = go kr sig bs
verify kr sig (UserId uid) = go kr sig (runPut putUid)
    where
        -- NOTE(review): BC8.pack keeps only the low 8 bits of each Char; this
        -- is right only if the parser built the String one Char per octet.
        body = BC8.pack uid
        -- The length field was previously missing, so no UID hash could match.
        putUid = putWord8 0xB4 >> putWord32be (fromIntegral (B.length body)) >> putByteString body
verify kr sig (UserAttribute _) = go kr sig B.empty -- FIXME: proper serialization of uats
verify kr sig (Marker _) = go kr sig B.empty -- fake for standalone sig
verify kr sig p@(PublicKey _) = go kr sig (runPut $ putKeyforSigning p)
verify kr sig p@(PublicSubkey _) = go kr sig (runPut $ putKeyforSigning p)
verify kr sig p@(SecretKey _ _) = go kr sig (runPut $ putKeyforSigning p)
verify kr sig p@(SecretSubkey _ _) = go kr sig (runPut $ putKeyforSigning p)
verify kr sig s@(Signature _) = go kr sig (runPut $ putSigforSigning s)
verify kr sig p = Left $ "So confused..." ++ show p

-- | Check @sig@ over @payload@ with the issuer's key from the keyring @kr@.
--
-- The issuer key ID comes from the signature's Issuer subpacket (see
-- 'issuer') and is looked up in @kr@. The hashed data is @payload@ followed
-- by the hashed part of the signature packet and the signature trailer.
-- Only DSA and RSA signatures are handled.
--
-- Returns @Right True@ on a good signature. Otherwise it returns a 'Left'
-- describing the problem: unknown issuer, unsupported key/signature version,
-- algorithm mismatch or malformed MPIs, or a bad signature.
go kr sig payload =
    case issuer sig >>= (`Map.lookup` kr) of
        Nothing -> Left "pubkey not found"
        Just tpk -> checkWith tpk sig
    where
        checkWith (TPK (PubV4 _ _ pkey) _ _ _ _) s@(Signature (SigV4 _ pka ha _ _ _ mpis)) =
            checkAlgo pka ha mpis pkey (signedData s)
        checkWith _ _ = Left "unsupported key or signature version"

        -- A DSA signature is the MPI pair (r, s); an RSA signature is one MPI.
        -- Any other combination used to be a pattern-match crash.
        checkAlgo DSA ha (rMPI:sMPI:_) (DSAPubKey pkey) bs =
            report (DSA.verify (unMPI rMPI, unMPI sMPI) (dsaTruncate pkey . hash ha) pkey bs)
        checkAlgo RSA ha (mMPI:_) (RSAPubKey pkey) bs =
            report (RSA.verify (hash ha) (asn1Prefix ha) pkey bs (integerToBEBS (unMPI mMPI)))
        checkAlgo _ _ _ _ _ = Left "public key algorithm mismatch or malformed signature"

        report (Left _) = Left "invalid signature"
        report (Right False) = Left "verification failed"
        report (Right True) = Right True

        signedData s =
            payload `B.append` runPut (putPartialSigforSigning s) `B.append` runPut (putSigTrailer s)

        -- Cut the digest down to the bit length of the key's third DSA
        -- parameter (presumably q, in cryptocipher's (p, g, q) order).
        dsaTruncate pkey bs
            | countBits bs > dsaQLen pkey = B.take (fromIntegral (dsaQLen pkey) `div` 8) bs -- FIXME: uneven bits
            | otherwise = bs
        dsaQLen pk = (\(_, _, z) -> countBits (integerToBEBS z)) (DSA.public_params pk)

-- | Key ID from the signature's Issuer subpacket, if it has one.
--
-- Subpackets may be placed in either the hashed or the unhashed area
-- (RFC 4880 §5.2.3). The hashed area is searched first because it is
-- covered by the signature; previously only the unhashed area was searched.
issuer (Signature (SigV4 _ _ _ hsubs usubs _ _)) = fmap (\(Issuer _ i) -> i) (find isIssuer (hsubs ++ usubs))
    where
        isIssuer (Issuer _ _) = True
        isIssuer _ = False

-- | Digest function for an OpenPGP hash algorithm.
--
-- Covers SHA-1, RIPEMD-160, SHA-224, SHA-256, SHA-384 and SHA-512 only.
-- Any other algorithm (e.g. 'DeprecatedMD5') is a pattern-match failure.
hash :: HashAlgorithm -> ByteString -> ByteString
hash ha = case ha of
    SHA1 -> SHA1.hash
    RIPEMD160 -> RIPEMD160.hash
    SHA224 -> SHA224.hash
    SHA256 -> SHA256.hash
    SHA384 -> SHA384.hash
    SHA512 -> SHA512.hash

--emsa_pkcs1_v1_5_encode :: HashAlgorithm -> ByteString -> Int -> Either String ByteString
--emsa_pkcs1_v1_5_encode ha m emLen = if emLen < tLen + 11 then
--                                        Left "intended encoded message length too short"
--                                    else
--                                        Right $ B.concat [header, ps, numpty, t]
--    where
--        t = asn1DigestInfo ha (hash ha m)
--        tLen = B.length t
--        header = B.pack [0,1]
--        ps = B.pack $ replicate (emLen - tLen - 3) 0xff
--        numpty = B.singleton 0

-- | DER-encoded ASN.1 @DigestInfo@ prefix used in PKCS#1 v1.5 signatures.
--
-- This is the encoding of
--
-- > SEQUENCE { SEQUENCE { hashOid, NULL }, OCTET STRING digest }
--
-- with the digest octets removed. It is everything up to and including the
-- OCTET STRING tag and length, as in the hash prefixes listed in RFC 4880
-- §5.2.2. It is produced by encoding a placeholder digest of the right size
-- and chopping exactly that many octets off the end.
--
-- The previous version composed its trimming steps in the wrong order and
-- left (digest length - 1) stray zero octets on the prefix, so RSA
-- signatures could never verify.
asn1Prefix :: HashAlgorithm -> ByteString
asn1Prefix ha =
    case DER.encodeASN1Stream [start, start, hashOid ha, DER.Null, end, placeholder, end] of
        Left _ -> error "encodeASN1 failure"
        Right der -> B.concat . BL.toChunks $ BL.take (BL.length der - digestLen) der
    where
        start = DER.Start DER.Sequence
        end = DER.End DER.Sequence
        digestLen = fromIntegral (bitLength ha `div` 8)
        -- Only the length of the placeholder matters; its octets are cut off.
        placeholder = DER.OctetString (BL.replicate digestLen 0)
        bitLength DeprecatedMD5 = 128
        bitLength SHA1 = 160
        bitLength RIPEMD160 = 160
        bitLength SHA256 = 256
        bitLength SHA384 = 384
        bitLength SHA512 = 512
        bitLength SHA224 = 224
        hashOid DeprecatedMD5 = DER.OID [1,2,840,113549,2,5]
        hashOid RIPEMD160 = DER.OID [1,3,36,3,2,1]
        hashOid SHA1 = DER.OID [1,3,14,3,2,26]
        hashOid SHA224 = DER.OID [2,16,840,1,101,3,4,2,4]
        hashOid SHA256 = DER.OID [2,16,840,1,101,3,4,2,1]
        hashOid SHA384 = DER.OID [2,16,840,1,101,3,4,2,2]
        hashOid SHA512 = DER.OID [2,16,840,1,101,3,4,2,3]