-- Filter.hs: OpenPGP (RFC4880) packet filtering
-- Copyright © 2014  Clint Adams
-- This software is released under the terms of the Expat license.
-- (See the LICENSE file).

{-# LANGUAGE GADTs, RecordWildCards #-}

module Data.Conduit.OpenPGP.Filter (
   conduitPktFilter
 , conduitTKFilter
 , FilterPredicates(..)
 , Expr(..)
 , PKPVar(..)
 , PKPOp(..)
 , PKPValue(..)
 , SPVar(..)
 , SPOp(..)
 , SPValue(..)
 , OVar(..)
 , OOp(..)
 , OValue(..)
 , UPredicate(..)
 , UOp(..)
 , Exp(..)
 , unop
 , binop
) where

import Control.Applicative (Applicative, (<$>), (<*>), pure)
import Control.Error.Util (hush)
import Control.Monad ((>=>))
import Control.Monad.Loops (allM, anyM)
import Control.Monad.Reader (ask, reader, runReader, Reader)
import Control.Monad.Trans.Resource (MonadResource)
import qualified Data.ByteString as B
import Data.Conduit
import qualified Data.Conduit.List as CL
import Data.Maybe (fromMaybe)
import Data.Serialize (runPut, put)

import Codec.Encryption.OpenPGP.Fingerprint (eightOctetKeyID, fingerprint)
import Codec.Encryption.OpenPGP.Internal (sigType, sigPKA, sigHA)
import Codec.Encryption.OpenPGP.KeyInfo (pubkeySize)
import Codec.Encryption.OpenPGP.Types


-- | The kinds of filter understood by 'conduitPktFilter' and
-- 'conduitTKFilter'.
data FilterPredicates
    = UnifiedFilterPredicate (Expr UPredicate)
      -- ^ "old"-style filter predicate, hopefully to be deprecated
    | TransitionalTKFP (Exp (Reader TK) Bool)
      -- ^ a more flexible fp for transferable keys
    | TransitionalUFP (Exp (Reader Pkt) Bool)
      -- ^ a more flexible fp for context-less packets

-- | A boolean expression tree whose leaves are predicates of type @a@;
-- decided by 'eval'.
data Expr a
    = EAny                    -- ^ always holds
    | E a                     -- ^ a single leaf predicate
    | EAnd (Expr a) (Expr a)  -- ^ both sub-expressions hold
    | EOr (Expr a) (Expr a)   -- ^ at least one sub-expression holds
    | ENot (Expr a)           -- ^ the sub-expression does not hold

-- | Decide an 'Expr' against a value.  Each leaf is decided by the supplied
-- test applied to the leaf and the value; 'EAny' always holds and the
-- connectives short-circuit like '&&' and '||'.
eval :: (a -> v -> Bool) -> Expr a -> v -> Bool
eval test expr val = go expr
  where
    go node = case node of
        EAny       -> True
        E leaf     -> test leaf val
        EAnd l r   -> go l && go r
        EOr l r    -> go l || go r
        ENot inner -> not (go inner)

-- | Comparison operators for public-key predicates.  'uEval' converts 'UOp'
-- into this type with @toEnum . fromEnum@, so the constructor order matters.
data PKPOp
    = PKEquals
    | PKLessThan
    | PKGreaterThan
    deriving Enum

-- | A public-key predicate: variable, operator, literal operand.
data PKPPredicate
    = PKPPredicate PKPVar PKPOp PKPValue

-- | The quantities of a public key that a predicate can inspect.
data PKPVar
    = PKPVVersion     -- ^ public key version
    | PKPVPKA         -- ^ public key algorithm
    | PKPVKeysize     -- ^ public key size (in bits)
    | PKPVTimestamp   -- ^ public key creation time
    | PKPVEOKI        -- ^ public key's eight-octet key ID
    | PKPVTOF         -- ^ public key's twenty-octet fingerprint

-- | Literal operands for public-key predicates.
data PKPValue
    = PKPInt Int
      -- ^ plain integer (used for version, key size and timestamp)
    | PKPPKA PubKeyAlgorithm
      -- ^ public key algorithm
    | PKPEOKI (Either String EightOctetKeyId)
      -- ^ eight-octet key ID as computed by 'eightOctetKeyID'
      --   (presumably 'Left' carries an error message -- confirm)
    | PKPTOF TwentyOctetFingerprint
      -- ^ twenty-octet fingerprint
    deriving Eq

-- | Orders values through their integer view 'pkvToInt'; only 'PKPInt' and
-- 'PKPPKA' have such a view.
instance Ord PKPValue where
    compare x y = pkvToInt x `compare` pkvToInt y

-- | Integer view of a numeric 'PKPValue', used by its 'Ord' instance.
--
-- Only 'PKPInt' and 'PKPPKA' carry a number.  Key IDs and fingerprints have
-- no ordering here; asking for one is a programming error.  It fails with
-- a descriptive message rather than an anonymous pattern-match failure.
pkvToInt :: PKPValue -> Int
pkvToInt (PKPInt i) = i
pkvToInt (PKPPKA i) = fromIntegral (fromFVal i)
pkvToInt (PKPEOKI _) = error "Data.Conduit.OpenPGP.Filter.pkvToInt: key IDs (PKPEOKI) cannot be ordered"
pkvToInt (PKPTOF _) = error "Data.Conduit.OpenPGP.Filter.pkvToInt: fingerprints (PKPTOF) cannot be ordered"

-- | Comparison operators for signature-packet predicates.  'uEval' converts
-- 'UOp' into this type with @toEnum . fromEnum@, so the constructor order
-- matters.
data SPOp
    = SPEquals
    | SPLessThan
    | SPGreaterThan
    deriving Enum

-- | A signature-packet predicate: variable, operator, literal operand.
data SPPredicate
    = SPPredicate SPVar SPOp SPValue

-- | The quantities of a signature packet that a predicate can inspect.
data SPVar = SPVVersion       -- ^ signature packet version
           | SPVSigType       -- ^ signature packet type
           | SPVPKA           -- ^ signature packet public key algorithm
           | SPVHA            -- ^ signature packet hash algorithm

-- | Literal operands for signature-packet predicates.  '<' and '>' go
-- through the integer view given by 'spvToInt'.
data SPValue
    = SPInt Int               -- ^ plain integer (e.g. signature version)
    | SPSigType SigType       -- ^ signature type
    | SPPKA PubKeyAlgorithm   -- ^ public key algorithm
    | SPHA HashAlgorithm      -- ^ hash algorithm
    deriving Eq

-- | Orders values through their integer view 'spvToInt', so values of
-- different constructors are comparable.
instance Ord SPValue where
    compare x y = spvToInt x `compare` spvToInt y

-- | Integer view of any 'SPValue', used by its 'Ord' instance.  Enumerated
-- values are converted through 'fromFVal'.
spvToInt :: SPValue -> Int
spvToInt v = case v of
    SPInt n     -> n
    SPSigType t -> fromIntegral (fromFVal t)
    SPPKA a     -> fromIntegral (fromFVal a)
    SPHA h      -> fromIntegral (fromFVal h)

-- | Comparison operators for generic packet predicates.  'uEval' converts
-- 'UOp' into this type with @toEnum . fromEnum@, so the constructor order
-- matters.
data OOp
    = OEquals
    | OLessThan
    | OGreaterThan
    deriving Enum

-- | A generic packet predicate: variable, operator, literal operand.
data OPredicate
    = OPredicate OVar OOp OValue

-- | The quantities of any packet that a generic predicate can inspect.
data OVar
    = OVTag     -- ^ OpenPGP packet tag (as reported by 'pktTag')
    | OVLength  -- ^ octet length of the packet re-serialized with 'put'
                --   (not necessarily a meaningful length; see 'oEval')

-- | Integer operand for generic packet predicates.  'OInt' and 'OInteger'
-- are two spellings of the same integer, so equality and ordering both go
-- through 'ovToInteger'.  A derived 'Eq' would call @OInt 6@ and
-- @OInteger 6@ different while 'compare' calls them equal.  Because 'oEval'
-- always produces 'OInteger', an 'OEquals' predicate written with 'OInt'
-- could then never match.
data OValue = OInt Int
            | OInteger Integer

instance Eq OValue where
    i == j = ovToInteger i == ovToInteger j

instance Ord OValue where
    compare i j = compare (ovToInteger i) (ovToInteger j)

-- | Widen either constructor to an 'Integer'.
ovToInteger :: OValue -> Integer
ovToInteger (OInt i) = fromIntegral i
ovToInteger (OInteger i) = i

-- | A leaf predicate for 'UnifiedFilterPredicate'.  All three families
-- share the operator type 'UOp'; 'uEval' dispatches on the family.
data UPredicate
    = UPKPP PKPVar UOp PKPValue  -- ^ applies to public-key packets
    | USPP SPVar UOp SPValue     -- ^ applies to signature packets
    | UOP OVar UOp OValue        -- ^ applies to any packet

-- | Operators for 'UPredicate'.  'uEval' maps them onto 'PKPOp', 'SPOp'
-- and 'OOp' by position (@toEnum . fromEnum@), so the order is significant.
data UOp
    = UEquals       -- ^ (==)
    | ULessThan     -- ^ (<)
    | UGreaterThan  -- ^ (>)
    deriving Enum

-- | A conduit that passes on only the packets accepted by the filter
-- (see 'superPredicate').
conduitPktFilter :: Monad m => FilterPredicates -> Conduit Pkt m Pkt
conduitPktFilter fp = CL.filter (superPredicate fp)

-- | Decide whether a single packet passes a filter.  'UnifiedFilterPredicate'
-- is decided with 'uEval', and 'TransitionalUFP' runs its expression in a
-- 'Reader' over the packet.  Any other filter kind never matches.
superPredicate :: FilterPredicates -> Pkt -> Bool
superPredicate fp pkt = case fp of
    UnifiedFilterPredicate ufp -> eval uEval ufp pkt
    TransitionalUFP e          -> runReader (evalM e) pkt
    _                          -> False   -- do not match incorrect type of packet

-- | A conduit that passes on only the transferable keys accepted by the
-- filter (see 'superTKPredicate').
conduitTKFilter :: Monad m => FilterPredicates -> Conduit TK m TK
conduitTKFilter fp = CL.filter (superTKPredicate fp)

-- | Decide whether a transferable key passes a filter.
--
-- * 'UnifiedFilterPredicate' is applied only to the key's public-key
--   payload (@fst . _tkKey@), wrapped in a 'PublicKeyPkt'.
-- * 'TransitionalTKFP' runs its expression in a 'Reader' over the whole 'TK'.
-- * Any other filter kind ('TransitionalUFP', written for bare packets)
--   never matches, mirroring 'superPredicate', rather than failing with a
--   pattern-match error.
superTKPredicate :: FilterPredicates -> TK -> Bool
superTKPredicate (UnifiedFilterPredicate ufp) tk = eval uEval ufp (PublicKeyPkt (fst (_tkKey tk)))  -- FIXME: should operate on more than just the pkp
superTKPredicate (TransitionalTKFP e) tk = runReader (evalM e) tk
superTKPredicate _ _ = False   -- do not match incorrect type of filter

-- | Decide a public-key predicate against a public-key payload.
--
-- The left-hand variable is extracted from the payload and compared with
-- the right-hand literal.  The predicate does not match in two cases:
--
-- * The variable cannot be determined, e.g. 'pubkeySize' reports an error.
--   Previously such keys were treated as size 0 and matched @keysize < N@.
-- * '<' or '>' is requested between values without a numeric view.  The
--   'Ord' instance of 'PKPValue' only covers 'PKPInt' and 'PKPPKA', so
--   ordering key IDs or fingerprints would otherwise crash.
pkpEval :: PKPPredicate -> PKPayload -> Bool
pkpEval (PKPPredicate lhs o rhs) pkp = maybe False (\v -> opreduce o v rhs) (vreduce lhs pkp)
    where
        opreduce PKEquals = (==)
        opreduce PKLessThan = ordered (<)
        opreduce PKGreaterThan = ordered (>)
        -- only attempt an ordering when both sides have an integer view
        ordered cmp a b = numeric a && numeric b && cmp a b
        numeric (PKPInt _) = True
        numeric (PKPPKA _) = True
        numeric _ = False
        vreduce PKPVVersion p = Just (PKPInt (kv (_keyVersion p)))
        vreduce PKPVPKA p = Just (PKPPKA (_pkalgo p))
        -- an undeterminable key size invalidates the predicate
        vreduce PKPVKeysize p = PKPInt <$> hush (pubkeySize (_pubkey p))
        vreduce PKPVTimestamp p = Just (PKPInt (fromIntegral (_timestamp p)))
        vreduce PKPVEOKI p = Just (PKPEOKI (eightOctetKeyID p))
        vreduce PKPVTOF p = Just (PKPTOF (fingerprint p))
        kv DeprecatedV3 = 3
        kv V4 = 4

-- | Decide a signature-packet predicate against a signature payload.
--
-- The predicate holds only when the left-hand variable can be extracted
-- ('sigType', 'sigPKA' and 'sigHA' may yield 'Nothing') and its comparison
-- with the right-hand literal succeeds.
spEval :: SPPredicate -> SignaturePayload -> Bool
spEval (SPPredicate lhs o rhs) sp = maybe False (\x -> opreduce o x rhs) (vreduce (lhs, sp))
    where
        opreduce SPEquals = (==)
        opreduce SPLessThan = (<)
        opreduce SPGreaterThan = (>)
        vreduce (SPVVersion, s) = Just (SPInt (sigVersion s))
        vreduce (SPVSigType, s) = fmap SPSigType (sigType s)
        vreduce (SPVPKA, s) = fmap SPPKA (sigPKA s)
        vreduce (SPVHA, s) = fmap SPHA (sigHA s)
        sigVersion (SigV3 {}) = 3
        sigVersion (SigV4 {}) = 4
        sigVersion (SigVOther v _) = fromIntegral v

-- | Decide a generic packet predicate against any packet.  The measured
-- quantity is always produced as an 'OInteger'.
oEval :: OPredicate -> Pkt -> Bool
oEval (OPredicate var op operand) pkt = cmp (measure var) operand
    where
        cmp = case op of
            OEquals      -> (==)
            OLessThan    -> (<)
            OGreaterThan -> (>)
        measure OVTag = OInteger (fromIntegral (pktTag pkt))
        -- FIXME: this should be a length that makes sense
        measure OVLength = OInteger (fromIntegral (B.length (runPut (put pkt))))

-- | Decide a 'UPredicate' against a packet.  It dispatches to 'pkpEval' for
-- public-key packets, 'spEval' for signature packets, and 'oEval' for any
-- packet.  The shared 'UOp' is converted to the family's operator type by
-- position.
uEval :: UPredicate -> Pkt -> Bool
uEval upred pkt = case (upred, pkt) of
    (UPKPP var op val, PublicKeyPkt pkp) -> pkpEval (PKPPredicate var (retag op) val) pkp
    (USPP var op val, SignaturePkt sig)  -> spEval (SPPredicate var (retag op) val) sig
    (UOP var op val, _)                  -> oEval (OPredicate var (retag op) val) pkt
    _                                    -> False  -- do not match packets of wrong type
  where
    retag :: (Enum a, Enum b) => a -> b
    retag = toEnum . fromEnum

--

-- | A small typed expression language, evaluated by 'evalM' in monad @m@.
data Exp m a where
    -- | integer literal
    I      :: Integer -> Exp m Integer
    -- | boolean literal
    B      :: Bool -> Exp m Bool
    -- | string literal
    S      :: String -> Exp m String
    -- | embed an arbitrary pure value
    Lift   :: x -> Exp m x
    -- | apply an expression-valued function to an expression argument
    Ap     :: Exp m (x -> y) -> Exp m x -> Exp m y
    -- | quantify a predicate over a list using the given monadic combinator
    --   (e.g. 'allM' or 'anyM')
    AnyAll :: ((x -> m Bool) -> [x] -> m Bool) -> (x -> Exp m Bool) -> Exp m [x] -> Exp m Bool
    -- | embed a monadic action
    MA     :: m x -> Exp m x

-- | Run an 'Exp' in its monad.
--
-- Literals and 'Lift'ed values are returned unchanged, and 'MA' runs its
-- action.  'Ap' combines function and argument with '<*>'.  'AnyAll'
-- evaluates the list, then hands the per-element predicate (itself run
-- with 'evalM') to the quantifier.
evalM :: (Functor m, Applicative m, Monad m) => Exp m a -> m a
evalM (I n) = return n
evalM (B b) = return b
evalM (S s) = return s
evalM (Lift l) = return l
evalM (MA a) = a
evalM (Ap f a) = evalM f <*> evalM a
evalM (AnyAll aa f l) = evalM l >>= aa (evalM . f)

-- | Lift a unary function over an expression: @unop f x = Ap (Lift f) x@.
unop :: (a -> b) -> Exp m a -> Exp m b
unop f x = Ap (Lift f) x

-- | Lift a binary function over two expressions:
-- @binop f x y = Ap (Ap (Lift f) x) y@.
--
-- The operands may have different types.  The former signature
-- @(a -> a -> b) -> Exp m a -> Exp m a -> Exp m b@ is an instance of this
-- one, so existing callers are unaffected.
binop :: (a -> b -> c) -> Exp m a -> Exp m b -> Exp m c
binop f x y = Ap (Ap (Lift f) x) y