{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE KindSignatures #-}
module Main (main) where

import Data.ByteArray (Bytes)
import qualified Data.ByteArray as B

import Control.Monad

import Data.Maybe (isJust, fromJust)
import Data.Proxy
import Data.Typeable (Typeable, typeRep)

import GHC.IO.Exception (IOErrorType(..))

import System.Directory (doesFileExist)
import System.IO.Error (catchIOError, mkIOError)
import System.Process (readProcess)

#ifdef ML_KEM_TESTING

import Data.Bits
import Data.Word

import GHC.TypeNats

import Auxiliary
import BlockN (BlockN)
import Builder (Builder)
import Marking (Leak(..), SecurityMarking(..))
import Math
import Matrix
import Vector
import qualified BlockN
import qualified Builder
#endif

import Crypto.PubKey.ML_KEM as Lib

import Test.Tasty
import Test.Tasty.HUnit
import Test.Tasty.QuickCheck

import qualified EncapDecap
import qualified KeyGen
import qualified Vectors

#ifdef ML_KEM_TESTING

-- | A Word8 whose generator only produces values in the range 0..127,
-- i.e. values that fit in seven bits.
newtype Bit7 = Bit7 Word8
    deriving Show

-- | Picks a value from the inclusive range 0..127.  Which generator function
-- is used depends on the tasty-quickcheck version being compiled against.
instance Arbitrary Bit7 where
#if (MIN_VERSION_tasty_quickcheck(0,10,2))
    arbitrary = fmap Bit7 (chooseBoundedIntegral (0, 127))
#else
    arbitrary = fmap Bit7 (choose (0, 127))
#endif

-- | Wrapper that gives Zq field elements their own generator.
newtype FE = FE
    { unFE :: Zq
    } deriving Show

-- | Picks an integer from the inclusive range 0..3328 and converts it with
-- toZq.  Which generator function is used depends on the tasty-quickcheck
-- version being compiled against.
instance Arbitrary FE where
#if (MIN_VERSION_tasty_quickcheck(0,10,2))
    arbitrary = fmap (FE . toZq) (chooseBoundedIntegral (0, 3328))
#else
    arbitrary = fmap (FE . toZq) (choose (0, 3328))
#endif

-- | Wrapper that gives Rq polynomials their own generator.
newtype Poly = Poly (Rq Sec)
    deriving Show

-- | Generates 256 field elements and turns them into a polynomial with
-- fromCoeffs.
instance Arbitrary Poly where
    -- NOTE(review): fromJust presumably cannot fail because exactly 256
    -- coefficients are supplied; confirm against fromCoeffs.
    arbitrary =
        Poly . fromJust . fromCoeffs . map unFE <$> vectorOf 256 arbitrary

-- | Wrapper that gives Tq values their own generator.
newtype PolyNTT = PolyNTT (Tq Sec)
    deriving Show

-- | Generates a Poly and applies ntt to it.
instance Arbitrary PolyNTT where
    arbitrary = do
        Poly f <- arbitrary
        return (PolyNTT (ntt f))

-- | The integer parameter d passed to the compression and byte-encoding
-- properties.
newtype D = D Int
    deriving Show

-- | Picks d from the inclusive range 0..12.
instance Arbitrary D where
    arbitrary = fmap D (choose (0, 12))

-- | A dimension chosen at run time and carried as a type-level natural, so
-- that it can be handed to the length-indexed vector and matrix generators.
data Dim = forall (n :: Nat). KnownNat n => Dim (Proxy n)

-- | Shows the numeric value of the dimension.  Showing the proxy itself
-- would print the same text for every dimension, so a QuickCheck
-- counterexample would not reveal which size failed.
instance Show Dim where
    show (Dim n) = "Dim " ++ show (natVal n)

-- | Picks a dimension from the inclusive range 1..9.
instance Arbitrary Dim where
    arbitrary = fmap toDim (choose (1, 9))

-- | Lifts an Int to a type-level natural through someNatVal and wraps the
-- resulting proxy in Dim.
toDim :: Int -> Dim
toDim n =
    case someNatVal (fromIntegral n) of
        SomeNat proxy -> Dim proxy

-- | Element type used by the vector and matrix generators below.  Zq is used
-- here, but any ring (Tq, for instance) would work as well.
type VElem = Zq

-- | Generates a vector of the length named by the proxy, filling every
-- position from the FE generator.
arbitraryVector :: KnownNat n => proxy n -> Gen (Vector n VElem)
arbitraryVector _ = Vector.replicateM element
  where
    element = fmap unFE arbitrary

-- | Generates a vector of vectors: the first proxy fixes the length of the
-- outer vector, the second the length of each inner vector.
arbitraryMatrix :: (KnownNat m, KnownNat n) => proxy n -> proxy m -> Gen (Vector n (Vector m VElem))
arbitraryMatrix _ inner = Vector.replicateM innerGen
  where
    innerGen = arbitraryVector inner

-- | Generates a byte array of exactly n generated bytes.
arbitraryBytes :: Int -> Gen Bytes
arbitraryBytes n = do
    octets <- vectorOf n arbitrary
    return (B.pack octets)

-- | byteDecode with its input and result types pinned to Bytes and a
-- 256-element block, so that call sites need no annotations.
byteDecodeBytes :: Int -> Bytes -> BlockN Sec 256 Word16
byteDecodeBytes d input = byteDecode d input

-- | Runs byteEncode with parameter d and materialises the builder as Bytes.
byteEncodeBytes :: Int -> BlockN Sec 256 Word16 -> Bytes
byteEncodeBytes d block = runBytes (byteEncode d block)

-- | Runs byteEncode1 and materialises the builder as Bytes.
byteEncodeBytes1 :: BlockN Sec 256 Word16 -> Bytes
byteEncodeBytes1 block = runBytes (byteEncode1 block)

-- | Runs byteEncode12 and materialises the builder as Bytes.
byteEncodeBytes12 :: Tq Sec -> Bytes
byteEncodeBytes12 poly = runBytes (byteEncode12 poly)

-- | Applies leak to a builder marked Sec and runs the result to obtain Bytes.
runBytes :: Builder Sec -> Bytes
runBytes builder = Builder.run (leak builder)

#endif

-- | A parameter set chosen at run time, hidden behind an existential.  The
-- Typeable constraint is only there so that the Show instance can name the
-- wrapped type; it is satisfied automatically at every construction site.
data P = forall a. (ParamSet a, Show a, Typeable a) => P (Proxy a)

-- | Shows the name of the wrapped parameter-set type.  Showing the proxy
-- itself would print the same text for every parameter set, so a QuickCheck
-- counterexample would not reveal which one failed.
instance Show P where
    show (P p) = show (typeRep p)

-- | Picks one of the three parameter sets.
instance Arbitrary P where
    arbitrary = elements
        [ P (Proxy :: Proxy ML_KEM_512)
        , P (Proxy :: Proxy ML_KEM_768)
        , P (Proxy :: Proxy ML_KEM_1024)
        ]

-- | Maps a parameter-set name, as it appears in the test vector files, to
-- the matching P value.  Any other name aborts through error.
toP :: String -> P
toP name = case name of
    "ML-KEM-512"  -> P (Proxy :: Proxy ML_KEM_512)
    "ML-KEM-768"  -> P (Proxy :: Proxy ML_KEM_768)
    "ML-KEM-1024" -> P (Proxy :: Proxy ML_KEM_1024)
    _             -> error ("unknown parameter set " ++ name)

-- | Wraps a test tree in a tasty resource that makes sure both test vector
-- archives exist.  When at least one of them is missing, the download script
-- is run through @/bin/sh@; an IOError raised by that run is rethrown as a
-- new IOError telling the user to run the script manually.  Releasing the
-- resource does nothing.
withVectors :: (IO () -> TestTree) -> TestTree
withVectors = withResource ensureVectors (const (return ()))
  where
    scriptPath = "tests/get-vectors.sh"
    vectorFiles = ["tests/keyGen.json.gz", "tests/encapDecap.json.gz"]

    ensureVectors = do
        present <- mapM doesFileExist vectorFiles
        if and present
            then return ()
            else void (readProcess "/bin/sh" [scriptPath] "") `catchIOError` scriptFailed

    scriptFailed e = ioError (mkIOError OtherError (advice e) Nothing Nothing)

    advice e =
        "Could not download test vectors, you will need to run the script `" ++
        scriptPath ++ "' manually. Script failure was: " ++ show e

-- | Runs the key-generation test vectors stored in @tests/keyGen.json.gz@.
--
-- For every test case of every group a key pair is generated with
-- generateWith from the recorded d and z values, and both encoded keys are
-- compared with the expected encodings.  The @step@ callback is invoked once
-- before the file is read and once per group with the parameter-set name.
keyGenVectors :: (String -> IO ()) -> Assertion
keyGenVectors step = do
    step "Reading test vectors ..."
    file <- Vectors.readJson "tests/keyGen.json.gz"
    forM_ (Vectors.testGroups file) $ \group -> do
        let paramSet = KeyGen.parameterSet group
        step paramSet
        case toP paramSet of
            P p -> forM_ (KeyGen.tests group) $ \t -> do
                let tcId = KeyGen.tcId t
                -- A rejected input is reported as a regular test failure that
                -- names the test case, instead of crashing inside fromJust.
                case Lib.generateWith p (KeyGen.d t) (KeyGen.z t) of
                    Nothing ->
                        assertFailure ("key generation failed for tcId=" ++ show tcId)
                    Just (ek, dk) -> do
                        assertEqual ("ek mismatch for tcId=" ++ show tcId) (KeyGen.ek t) (Lib.encode ek)
                        assertEqual ("dk mismatch for tcId=" ++ show tcId) (KeyGen.dk t) (Lib.encode dk)

-- | Runs the encapsulation/decapsulation test vectors stored in
-- @tests/encapDecap.json.gz@.
--
-- Each group carries one of four payload kinds, and every test case of the
-- group is dispatched to the matching checker below.  The @step@ callback is
-- invoked once before the file is read and once per group with the
-- parameter-set name and the function name of the group.
encapDecapVectors :: (String -> IO ()) -> Assertion
encapDecapVectors step = do
    step "Reading test vectors ..."
    file <- Vectors.readJson "tests/encapDecap.json.gz"
    forM_ (Vectors.testGroups file) $ \group -> do
        let paramSet = EncapDecap.parameterSet group
        step (paramSet ++ " (" ++ EncapDecap.function group ++ ")")
        case toP paramSet of
            P p -> case EncapDecap.payload group of
                EncapDecap.FunctionEncapsulation tests ->
                    forM_ tests (testEncapsulation p)
                EncapDecap.FunctionDecapsulation tests ->
                    forM_ tests (testDecapsulation p)
                EncapDecap.FunctionEncapsulationKeyCheck tests ->
                    forM_ tests (testEncapsulationKeyCheck p)
                EncapDecap.FunctionDecapsulationKeyCheck tests ->
                    forM_ tests (testDecapsulationKeyCheck p)
  where
    -- Identity functions whose only job is to fix the type that decode
    -- produces in the two key-check tests.
    ensureEk = id :: f (EncapsulationKey a) -> f (EncapsulationKey a)
    ensureDk = id :: f (DecapsulationKey a) -> f (DecapsulationKey a)

    -- Unwraps a Maybe, turning Nothing into a test failure that names the
    -- test case (instead of an anonymous fromJust crash).
    orFail tcId problem =
        maybe (assertFailure (problem ++ " for tcId=" ++ show tcId)) return

    -- Encapsulates with the recorded input and compares the shared secret
    -- and the ciphertext with the expected values.
    testEncapsulation p test = do
        let tcId = EncapDecap.tcId test
            ext = EncapDecap.tcExt test
        ek <- orFail tcId "cannot decode ek" $ Lib.decode p (EncapDecap.ekEnc ext)
        k' <- orFail tcId "cannot decode k" $ Lib.decode p (EncapDecap.kEnc ext)
        c' <- orFail tcId "cannot decode c" $ Lib.decode p (EncapDecap.cEnc ext)
        (k, c) <- orFail tcId "encapsulation failed" $
            Lib.encapsulateWith ek (EncapDecap.mEnc ext)
        assertEqual ("k mismatch for tcId=" ++ show tcId) k' k
        assertEqual ("c mismatch for tcId=" ++ show tcId) c' c

    -- Decapsulates the recorded ciphertext and compares the shared secret
    -- with the expected value.
    testDecapsulation p test = do
        let tcId = EncapDecap.tcId test
            ext = EncapDecap.tcExt test
        dk <- orFail tcId "cannot decode dk" $ Lib.decode p (EncapDecap.dkDec ext)
        c  <- orFail tcId "cannot decode c" $ Lib.decode p (EncapDecap.cDec ext)
        k' <- orFail tcId "cannot decode k" $ Lib.decode p (EncapDecap.kDec ext)
        assertEqual ("k mismatch for tcId=" ++ show tcId) k' (Lib.decapsulate dk c)

    -- Decoding the encapsulation key must succeed exactly when the vector
    -- says the check passes.
    testEncapsulationKeyCheck p test = do
        let tcId = EncapDecap.tcId test
            ext = EncapDecap.tcExt test
            mek = ensureEk $ Lib.decode p (EncapDecap.ekEkc ext)
        assertBool ("opposite outcome for tcId=" ++ show tcId)
            (EncapDecap.passedEkc ext == isJust mek)

    -- Decoding the decapsulation key must succeed exactly when the vector
    -- says the check passes.
    testDecapsulationKeyCheck p test = do
        let tcId = EncapDecap.tcId test
            ext = EncapDecap.tcExt test
            mdk = ensureDk $ Lib.decode p (EncapDecap.dkDkc ext)
        assertBool ("opposite outcome for tcId=" ++ show tcId)
            (EncapDecap.passedDkc ext == isJust mdk)

-- | Test-suite entry point.
--
-- The tree has two parts: the test vectors, wrapped in withVectors so that
-- the vector files are made available first, and the QuickCheck properties.
-- The property groups after the public "ML-KEM" group exercise library
-- internals and are compiled only when the testing flag is defined.
main :: IO ()
main = defaultMain $ testGroup "mlkem"
    [ withVectors $ \_ -> testGroup "vectors"
        [ testCaseSteps "keyGen" keyGenVectors
        , testCaseSteps "encapDecap" encapDecapVectors
        ]
    , testGroup "properties"
        [ testGroup "ML-KEM"
            [ testProperty "encapsulate/decapsulate" $ \(P p) -> ioProperty $ do
                (ek, dk) <- Lib.generate p
                (kk, c) <- Lib.encapsulate ek
                let kk' = Lib.decapsulate dk c
                return (kk === kk')
            , testProperty "encode/decode keys" $ \(P p) -> ioProperty $ do
                (ek, dk) <- Lib.generate p
                return $ conjoin
                    [ Just ek === Lib.decode p (Lib.encode ek :: Bytes)
                    , Just dk === Lib.decode p (Lib.encode dk :: Bytes)
                    ]
            , testProperty "convert/decode ciphertext and shared secret" $ \(P p) -> ioProperty $ do
                (ek, _) <- Lib.generate p
                (kk, c) <- Lib.encapsulate ek
                return $ conjoin
                    [ Just c === Lib.decode p (B.convert c :: Bytes)
                    , Just kk === Lib.decode p (B.convert kk :: Bytes)
                    ]
            , testProperty "toPublic" $ \(P p) -> ioProperty $ do
                (ek, dk) <- Lib.generate p
                return (ek === toPublic dk)
            , testProperty "checkKeyPair" $ \(P p) -> ioProperty $
                Lib.generate p >>= checkKeyPair
            ]
#ifdef ML_KEM_TESTING
        , testGroup "bitRev7"
            [ testCase "powers of two" $
                let powers = [1, 2, 4, 8, 16, 32, 64]
                 in reverse powers @=? map bitRev7 powers
            , testProperty "or" $ \(Bit7 a) (Bit7 b) ->
                bitRev7 (a .|. b) === bitRev7 a .|. bitRev7 b
            , testProperty "not" $ \(Bit7 a) ->
                let comp = xor 127
                 in bitRev7 (comp a) === comp (bitRev7 a)
            , testProperty "involutive" $ \(Bit7 a) ->
                a === bitRev7 (bitRev7 a)
            , testProperty "preserves bit count" $ \(Bit7 a) ->
                popCount a === popCount (bitRev7 a)
            ]
        , testGroup "compression"
            [ testProperty "compress . decompress == id" $ \(D d) -> do
                y <- choose (0, 2^d - 1)
                -- The round trip is only checked for d below 12; the case
                -- d == 12 is discarded.
                return (d < 12 ==> y === compress d (decompress d y))
            ]
        , testGroup "encoding"
            [ testProperty "byteEncode . byteDecode == id" $ \(D d) -> do
                b <- arbitraryBytes (32 * d)
                return (b === byteEncodeBytes d (byteDecode d b))
            , testProperty "byteEncode1 . byteDecode1 == id" $ do
                b <- arbitraryBytes 32
                return (b === byteEncodeBytes1 (byteDecode1 b))
            , testProperty "byteDecode12 . byteEncode12 == id" $ \(PolyNTT p) ->
                p === byteDecode12 (byteEncodeBytes12 p)
            , testProperty "byteEncode 8" $ \x ->
                B.replicate 256 x === byteEncodeBytes 8 (BlockN.replicate $ fromIntegral x)
            , testCase "byteEncode 1 (zeros)" $
                B.replicate 32 0 @=? byteEncodeBytes 1 (BlockN.replicate 0)
            , testCase "byteEncode 1 (ones)" $
                B.replicate 32 255 @=? byteEncodeBytes 1 (BlockN.replicate 1)
            , testProperty "byteDecode1 == byteDecode 1" $ do
                b <- arbitraryBytes 32
                return (byteDecodeBytes 1 b === byteDecode1 b)
            ]
        , testGroup "Zq"
            [ testProperty "toZq . fromZq == id " $ \(FE a) ->
                a === toZq (fromZq a)
            , testProperty "fromZq . toZq == id " $ \a ->
                mod a 3329 === fromZq (toZq a)
            , testCase "field order" $ zero @=? toZq 3329
            , testProperty "addition with zero" $ \(FE a) ->
                conjoin [ a === zero .+ a
                        , a === a .+ zero
                        ]
            , testProperty "addition associative" $ \(FE a) (FE b) (FE c) ->
                a .+ (b .+ c) === (a .+ b) .+ c
            , testProperty "addition commutative" $ \(FE a) (FE b) ->
                a .+ b === b .+ a
            , testProperty "substraction with zero" $ \(FE a) ->
                a === a .- zero
            , testProperty "substraction non-associative" $ \(FE a) (FE b) (FE c) ->
                a .- (b .- c) === (a .- b) .+ c
            , testProperty "substraction anti-commutative" $ \(FE a) (FE b) ->
                a .- b === neg (b .- a)
            , testProperty "negation" $ \(FE a) ->
                neg a === zero .- a
            , testProperty "double negation" $ \(FE a) ->
                a === neg (neg a)
            , testProperty "multiplication with zero" $ \(FE a) ->
                conjoin [ zero === zero .* a
                        , zero === a .* zero
                        ]
            , testProperty "multiplication with one" $ \(FE a) ->
                conjoin [ a === one .* a
                        , a === a .* one
                        ]
            , testProperty "multiplication associative" $ \(FE a) (FE b) (FE c) ->
                a .* (b .* c) === (a .* b) .* c
            , testProperty "multiplication commutative" $ \(FE a) (FE b) ->
                a .* b === b .* a
            , testProperty "multiplication distributive" $ \(FE a) (FE b) (FE c) ->
                conjoin [ (a .* b) .+ (a .* c) === a .* (b .+ c)
                        , (b .* a) .+ (c .* a) === (b .+ c) .* a
                        ]
            , testProperty "mulAdd" $ \(FE a) (FE b) (FE c) ->
                a .* b .+ c === mulAdd a b c
            ]
        , testGroup "Rq"
            [ testProperty "fromCoeffs . toCoeffs == id " $ \(Poly a) ->
                Just a === fromCoeffs (toCoeffs a)
            , testProperty "addition with zero" $ \(Poly a) ->
                conjoin [ a === zero .+ a
                        , a === a .+ zero
                        ]
            , testProperty "addition associative" $ \(Poly a) (Poly b) (Poly c) ->
                a .+ (b .+ c) === (a .+ b) .+ c
            , testProperty "addition commutative" $ \(Poly a) (Poly b) ->
                a .+ b === b .+ a
            , testProperty "substraction with zero" $ \(Poly a) ->
                a === a .- zero
            , testProperty "substraction non-associative" $ \(Poly a) (Poly b) (Poly c) ->
                a .- (b .- c) === (a .- b) .+ c
            , testProperty "substraction anti-commutative" $ \(Poly a) (Poly b) ->
                a .- b === neg (b .- a)
            , testProperty "negation" $ \(Poly a) ->
                neg a === zero .- a
            , testProperty "double negation" $ \(Poly a) ->
                a === neg (neg a)
            ]
        , testGroup "Tq"
            [ testProperty "nttInv . ntt == id" $ \(Poly a) ->
                a === nttInv (ntt a)
            , testProperty "addition with zero" $ \(PolyNTT a) ->
                conjoin [ a === zero .+ a
                        , a === a .+ zero
                        ]
            , testProperty "addition associative" $ \(PolyNTT a) (PolyNTT b) (PolyNTT c) ->
                a .+ (b .+ c) === (a .+ b) .+ c
            , testProperty "addition commutative" $ \(PolyNTT a) (PolyNTT b) ->
                a .+ b === b .+ a
            , testProperty "substraction with zero" $ \(PolyNTT a) ->
                a === a .- zero
            , testProperty "substraction non-associative" $ \(PolyNTT a) (PolyNTT b) (PolyNTT c) ->
                a .- (b .- c) === (a .- b) .+ c
            , testProperty "substraction anti-commutative" $ \(PolyNTT a) (PolyNTT b) ->
                a .- b === neg (b .- a)
            , testProperty "negation" $ \(PolyNTT a) ->
                neg a === zero .- a
            , testProperty "double negation" $ \(PolyNTT a) ->
                a === neg (neg a)
            , testProperty "multiplication with zero" $ \(PolyNTT a) ->
                conjoin [ zero === zero .* a
                        , zero === a .* zero
                        ]
            , testProperty "multiplication with one" $ \(PolyNTT a) ->
                conjoin [ a === one .* a
                        , a === a .* one
                        ]
            , testProperty "multiplication associative" $ \(PolyNTT a) (PolyNTT b) (PolyNTT c) ->
                a .* (b .* c) === (a .* b) .* c
            , testProperty "multiplication commutative" $ \(PolyNTT a) (PolyNTT b) ->
                a .* b === b .* a
            , testProperty "multiplication distributive" $ \(PolyNTT a) (PolyNTT b) (PolyNTT c) ->
                conjoin [ (a .* b) .+ (a .* c) === a .* (b .+ c)
                        , (b .* a) .+ (c .* a) === (b .+ c) .* a
                        ]
            , testProperty "mulAdd" $ \(PolyNTT a) (PolyNTT b) (PolyNTT c) ->
                a .* b .+ c === mulAdd a b c
            ]
        , testGroup "Vector"
            [ testProperty "addition with zero" $ \(Dim n) -> do
                a <- arbitraryVector n
                return $ conjoin
                    [ a === zero .+ a
                    , a === a .+ zero
                    ]
            , testProperty "addition associative" $ \(Dim n) -> do
                (a, b, c) <- (,,) <$> arbitraryVector n <*> arbitraryVector n <*> arbitraryVector n
                return (a .+ (b .+ c) === (a .+ b) .+ c)
            , testProperty "addition commutative" $ \(Dim n) -> do
                (a, b) <- (,) <$> arbitraryVector n <*> arbitraryVector n
                return (a .+ b === b .+ a)
            , testProperty "substraction with zero" $ \(Dim n) -> do
                a <- arbitraryVector n
                return (a === a .- zero)
            , testProperty "substraction non-associative" $ \(Dim n) -> do
                (a, b, c) <- (,,) <$> arbitraryVector n <*> arbitraryVector n <*> arbitraryVector n
                return (a .- (b .- c) === (a .- b) .+ c)
            , testProperty "substraction anti-commutative" $ \(Dim n) -> do
                (a, b) <- (,) <$> arbitraryVector n <*> arbitraryVector n
                return (a .- b === neg (b .- a))
            , testProperty "negation" $ \(Dim n) -> do
                a <- arbitraryVector n
                return (neg a === zero .- a)
            , testProperty "double negation" $ \(Dim n) -> do
                a <- arbitraryVector n
                return (a === neg (neg a))
            ]
        , testGroup "Matrix"
            [ testProperty "mulw distributive left" $ \(Dim n) (Dim m) -> do
                (a, b, u, v) <- (,,,) <$> arbitraryMatrix m n <*> arbitraryMatrix m n <*> arbitraryVector m <*> arbitraryVector n
                return (mulw (a .+ b) u v === mulw a u (mulw b u zero) .+ v)
            , testProperty "mulw distributive right" $ \(Dim n) (Dim m) -> do
                (a, u, v, w) <- (,,,) <$> arbitraryMatrix m n <*> arbitraryVector m <*> arbitraryVector m <*> arbitraryVector n
                return (mulw a (u .+ v) w === mulw a u (mulw a v w))
            , testProperty "muly definition" $ \(Dim n) (Dim m) -> do
                (a, u, v) <- (,,) <$> arbitraryMatrix n m <*> arbitraryVector m <*> arbitraryVector n
                -- Uses (===) like every other property here, so that a
                -- failure prints both sides of the comparison.
                return ((a `muly` u) .+ v === mulw (transpose a) u v)
            , testProperty "muly distributive left" $ \(Dim n) (Dim m) -> do
                (a, b, u) <- (,,) <$> arbitraryMatrix n m <*> arbitraryMatrix n m <*> arbitraryVector m
                return ((a .+ b) `muly` u === (a `muly` u) .+ (b `muly` u))
            , testProperty "muly distributive right" $ \(Dim n) (Dim m) -> do
                (a, u, v) <- (,,) <$> arbitraryMatrix n m <*> arbitraryVector m <*> arbitraryVector m
                return (a `muly` (u .+ v) === (a `muly` u) .+ (a `muly` v))
            , testProperty "mulz commutative" $ \(Dim n) -> do
                (u, v) <- (,) <$> arbitraryVector n <*> arbitraryVector n
                return (u `mulz` v === v `mulz` u)
            , testProperty "mulz distributive" $ \(Dim n) -> do
                (u, v, w) <- (,,) <$> arbitraryVector n <*> arbitraryVector n <*> arbitraryVector n
                return $ conjoin
                    [ u `mulz` (v .+ w) === (u `mulz` v) .+ (u `mulz` w)
                    , (u .+ v) `mulz` w === (u `mulz` w) .+ (v `mulz` w)
                    ]
            ]
#endif
        ]
    ]
