{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE KindSignatures #-}

-- | Test suite for the ML-KEM library:
--
-- * known-answer tests driven by the JSON test vectors under @tests/@,
-- * properties of the public API (round trips, key checks),
-- * when compiled with @ML_KEM_TESTING@, properties of the internal
--   building blocks (bit reversal, compression, byte encoding, @Zq@,
--   @Rq@, @Tq@, vectors and matrices).
module Main (main) where

import Data.ByteArray (Bytes)
import qualified Data.ByteArray as B

import Control.Monad
import Data.Maybe (isJust)
import Data.Proxy
import Data.Typeable (Typeable, typeRep)

import GHC.IO.Exception (IOErrorType(..))
import System.Directory (doesFileExist)
import System.IO.Error (catchIOError, mkIOError)
import System.Process (readProcess)

#ifdef ML_KEM_TESTING
import Data.Bits
import Data.Maybe (fromJust)
import Data.Word
import GHC.TypeNats

import Auxiliary
import BlockN (BlockN)
import Builder (Builder)
import Marking (Leak(..), SecurityMarking(..))
import Math
import Matrix
import Vector
import qualified BlockN
import qualified Builder
#endif

import Crypto.PubKey.ML_KEM as Lib

import Test.Tasty
import Test.Tasty.HUnit
import Test.Tasty.QuickCheck

import qualified EncapDecap
import qualified KeyGen
import qualified Vectors

#ifdef ML_KEM_TESTING

-- | A value in the range 0..127, used as input to the 'bitRev7' properties.
newtype Bit7 = Bit7 Word8
    deriving Show

instance Arbitrary Bit7 where
#if (MIN_VERSION_tasty_quickcheck(0,10,2))
    arbitrary = Bit7 <$> chooseBoundedIntegral (0, 127)
#else
    arbitrary = Bit7 <$> choose (0, 127)
#endif

-- | An element of 'Zq', generated with 'toZq' from the integer range 0..3328.
newtype FE = FE { unFE :: Zq }
    deriving Show

instance Arbitrary FE where
#if (MIN_VERSION_tasty_quickcheck(0,10,2))
    arbitrary = FE . toZq <$> chooseBoundedIntegral (0, 3328)
#else
    arbitrary = FE . toZq <$> choose (0, 3328)
#endif

-- | A polynomial in 'Rq' built from 256 arbitrary 'FE' coefficients.
newtype Poly = Poly (Rq Sec)
    deriving Show

instance Arbitrary Poly where
    arbitrary = do
        coeffs <- map unFE <$> vectorOf 256 arbitrary
        -- exactly 256 coefficients are generated, so 'fromCoeffs' is
        -- expected to return 'Just' here
        let a = fromJust (fromCoeffs coeffs)
        return (Poly a)

-- | A polynomial in NTT form, obtained by applying 'ntt' to an arbitrary 'Poly'.
newtype PolyNTT = PolyNTT (Tq Sec)
    deriving Show

instance Arbitrary PolyNTT where
    arbitrary = (\(Poly f) -> PolyNTT (ntt f)) <$> arbitrary

-- | A bit width in 0..12, used by the compression and encoding properties.
newtype D = D Int
    deriving Show

instance Arbitrary D where
    arbitrary = D <$> choose (0, 12)

-- | A type-level dimension hidden behind an existential, so that QuickCheck
-- can pick vector and matrix sizes at runtime.
data Dim = forall (n :: Nat). KnownNat n => Dim (Proxy n)

-- Show the dimension itself; showing the 'Proxy' would only print "Proxy"
-- and make counterexamples useless.
instance Show Dim where
    show (Dim n) = show (natVal n)

instance Arbitrary Dim where
    arbitrary = toDim <$> choose (1, 9)

-- | Reifies a non-negative 'Int' as a 'Dim'.
toDim :: Int -> Dim
toDim n = case someNatVal (fromIntegral n) of
    SomeNat p -> Dim p

-- | Element type used by the vector/matrix properties; these tests only need
-- a ring, so 'Tq' would also work here.
type VElem = Zq

-- | Generates a vector of @n@ arbitrary 'VElem' values.
arbitraryVector :: KnownNat n => proxy n -> Gen (Vector n VElem)
arbitraryVector _ = Vector.replicateM (unFE <$> arbitrary)

-- | Generates an @n@ x @m@ matrix as a vector of @n@ row vectors of size @m@.
arbitraryMatrix :: (KnownNat m, KnownNat n)
                => proxy n -> proxy m -> Gen (Vector n (Vector m VElem))
arbitraryMatrix _ m = Vector.replicateM (arbitraryVector m)

-- | Generates @n@ arbitrary bytes.
arbitraryBytes :: Int -> Gen Bytes
arbitraryBytes n = B.pack <$> vectorOf n arbitrary

-- | 'byteDecode' specialised to 'Bytes' input.
byteDecodeBytes :: Int -> Bytes -> BlockN Sec 256 Word16
byteDecodeBytes = byteDecode

-- | 'byteEncode' materialised as 'Bytes'.
byteEncodeBytes :: Int -> BlockN Sec 256 Word16 -> Bytes
byteEncodeBytes d = runBytes . byteEncode d

-- | 'byteEncode1' materialised as 'Bytes'.
byteEncodeBytes1 :: BlockN Sec 256 Word16 -> Bytes
byteEncodeBytes1 = runBytes . byteEncode1

-- | 'byteEncode12' materialised as 'Bytes'.
byteEncodeBytes12 :: Tq Sec -> Bytes
byteEncodeBytes12 = runBytes . byteEncode12

-- | Turns a @Builder Sec@ into 'Bytes' by applying 'leak' and then
-- 'Builder.run'.
runBytes :: Builder Sec -> Bytes
runBytes = Builder.run . leak

#endif

-- | A parameter set hidden behind an existential, so that QuickCheck can
-- pick one at runtime.  'Typeable' is captured so that 'Show' can print the
-- parameter-set type name instead of just "Proxy".
data P = forall a. (ParamSet a, Show a, Typeable a) => P (Proxy a)

instance Show P where
    show (P p) = show (typeRep p)

instance Arbitrary P where
    arbitrary = elements
        [ P (Proxy :: Proxy ML_KEM_512)
        , P (Proxy :: Proxy ML_KEM_768)
        , P (Proxy :: Proxy ML_KEM_1024)
        ]

-- | Maps a parameter-set name as it appears in the test-vector files to a
-- 'P'.  Calls 'error' on an unknown name.
toP :: String -> P
toP "ML-KEM-512" = P (Proxy :: Proxy ML_KEM_512)
toP "ML-KEM-768" = P (Proxy :: Proxy ML_KEM_768)
toP "ML-KEM-1024" = P (Proxy :: Proxy ML_KEM_1024)
toP paramSet = error ("unknown parameter set " ++ paramSet)

-- | Location of the compressed key-generation test vectors.
keyGenVectorsPath :: FilePath
keyGenVectorsPath = "tests/keyGen.json.gz"

-- | Location of the compressed encapsulation/decapsulation test vectors.
encapDecapVectorsPath :: FilePath
encapDecapVectorsPath = "tests/encapDecap.json.gz"

-- | Returns the value inside a 'Just'.  On 'Nothing', fails the current
-- HUnit assertion with the given message, which is more informative than
-- the generic 'fromJust' error.
expectJust :: String -> Maybe a -> IO a
expectJust msg = maybe (assertFailure msg) return

-- | Wraps the vector tests in a tasty resource.  When either vector file is
-- missing, @tests/get-vectors.sh@ is run first.  If the script fails, the
-- resulting 'IOError' is rethrown with instructions to run it manually.
withVectors :: (IO () -> TestTree) -> TestTree
withVectors = withResource alloc free
  where
    scriptPath = "tests/get-vectors.sh"

    free _ = return ()

    alloc = do
        keyGenExists <- doesFileExist keyGenVectorsPath
        encapDecapExists <- doesFileExist encapDecapVectorsPath
        unless (keyGenExists && encapDecapExists) $
            catchIOError (void $ readProcess "/bin/sh" [scriptPath] "") $ \e ->
                let msg = "Could not download test vectors, you will need to run the script `"
                            ++ scriptPath ++ "' manually. Script failure was: " ++ show e
                 in ioError (mkIOError OtherError msg Nothing Nothing)

-- | Known-answer tests for key generation.  For every test case, the key pair
-- produced by 'Lib.generateWith' from the vector's @d@ and @z@ values must
-- encode to the expected @ek@ and @dk@.
keyGenVectors :: (String -> IO ()) -> Assertion
keyGenVectors step = do
    step "Reading test vectors ..."
    file <- Vectors.readJson keyGenVectorsPath
    forM_ (Vectors.testGroups file) $ \group -> do
        let paramSet = KeyGen.parameterSet group
        step paramSet
        case toP paramSet of
            P p -> forM_ (KeyGen.tests group) $ \t -> do
                let tcId = KeyGen.tcId t
                (ek, dk) <- expectJust ("generateWith failed for tcId=" ++ show tcId) $
                    Lib.generateWith p (KeyGen.d t) (KeyGen.z t)
                assertEqual ("ek mismatch for tcId=" ++ show tcId) (KeyGen.ek t) (Lib.encode ek)
                assertEqual ("dk mismatch for tcId=" ++ show tcId) (KeyGen.dk t) (Lib.encode dk)

-- | Known-answer tests for encapsulation, decapsulation and the
-- encapsulation/decapsulation key checks.  The group's payload constructor
-- selects which kind of test is run.
encapDecapVectors :: (String -> IO ()) -> Assertion
encapDecapVectors step = do
    step "Reading test vectors ..."
    file <- Vectors.readJson encapDecapVectorsPath
    forM_ (Vectors.testGroups file) $ \group -> do
        let paramSet = EncapDecap.parameterSet group
        step (paramSet ++ " (" ++ EncapDecap.function group ++ ")")
        case toP paramSet of
            P p -> case EncapDecap.payload group of
                EncapDecap.FunctionEncapsulation tests ->
                    forM_ tests (testEncapsulation p)
                EncapDecap.FunctionDecapsulation tests ->
                    forM_ tests (testDecapsulation p)
                EncapDecap.FunctionEncapsulationKeyCheck tests ->
                    forM_ tests (testEncapsulationKeyCheck p)
                EncapDecap.FunctionDecapsulationKeyCheck tests ->
                    forM_ tests (testDecapsulationKeyCheck p)
  where
    -- The key-check tests only inspect the result with 'isJust', which does
    -- not fix the type 'Lib.decode' should produce; these identities pin it.
    ensureEk = id :: f (EncapsulationKey a) -> f (EncapsulationKey a)
    ensureDk = id :: f (DecapsulationKey a) -> f (DecapsulationKey a)

    -- Decodes ek, expected k and c, then encapsulates with the vector's m.
    testEncapsulation p test = do
        let tcId = EncapDecap.tcId test
            ext = EncapDecap.tcExt test
            failed what = what ++ " failed for tcId=" ++ show tcId
        ek <- expectJust (failed "decoding ek") $ Lib.decode p (EncapDecap.ekEnc ext)
        k' <- expectJust (failed "decoding k") $ Lib.decode p (EncapDecap.kEnc ext)
        c' <- expectJust (failed "decoding c") $ Lib.decode p (EncapDecap.cEnc ext)
        (k, c) <- expectJust (failed "encapsulateWith") $
            Lib.encapsulateWith ek (EncapDecap.mEnc ext)
        assertEqual ("k mismatch for tcId=" ++ show tcId) k' k
        assertEqual ("c mismatch for tcId=" ++ show tcId) c' c

    -- Decodes dk, c and expected k, then checks 'Lib.decapsulate'.
    testDecapsulation p test = do
        let tcId = EncapDecap.tcId test
            ext = EncapDecap.tcExt test
            failed what = what ++ " failed for tcId=" ++ show tcId
        dk <- expectJust (failed "decoding dk") $ Lib.decode p (EncapDecap.dkDec ext)
        c <- expectJust (failed "decoding c") $ Lib.decode p (EncapDecap.cDec ext)
        k' <- expectJust (failed "decoding k") $ Lib.decode p (EncapDecap.kDec ext)
        assertEqual ("k mismatch for tcId=" ++ show tcId) k' (Lib.decapsulate dk c)

    -- Decoding the ek must succeed exactly when the vector says it passes.
    testEncapsulationKeyCheck p test = do
        let tcId = EncapDecap.tcId test
            ext = EncapDecap.tcExt test
            mek = ensureEk $ Lib.decode p (EncapDecap.ekEkc ext)
        assertBool ("opposite outcome for tcId=" ++ show tcId)
            (EncapDecap.passedEkc ext == isJust mek)

    -- Decoding the dk must succeed exactly when the vector says it passes.
    testDecapsulationKeyCheck p test = do
        let tcId = EncapDecap.tcId test
            ext = EncapDecap.tcExt test
            mdk = ensureDk $ Lib.decode p (EncapDecap.dkDkc ext)
        assertBool ("opposite outcome for tcId=" ++ show tcId)
            (EncapDecap.passedDkc ext == isJust mdk)

-- | Runs the whole suite.  The "vectors" group is wrapped in 'withVectors'
-- so that missing vector files are fetched before it runs.
main :: IO ()
main = defaultMain $ testGroup "mlkem"
    [ withVectors $ \_ -> testGroup "vectors"
        [ testCaseSteps "keyGen" keyGenVectors
        , testCaseSteps "encapDecap" encapDecapVectors
        ]
    , testGroup "properties"
        [ testGroup "ML-KEM"
            [ testProperty "encapsulate/decapsulate" $ \(P p) -> ioProperty $ do
                (ek, dk) <- Lib.generate p
                (kk, c) <- Lib.encapsulate ek
                let kk' = Lib.decapsulate dk c
                return (kk === kk')
            , testProperty "encode/decode keys" $ \(P p) -> ioProperty $ do
                (ek, dk) <- Lib.generate p
                return $ conjoin
                    [ Just ek === Lib.decode p (Lib.encode ek :: Bytes)
                    , Just dk === Lib.decode p (Lib.encode dk :: Bytes)
                    ]
            , testProperty "convert/decode ciphertext and shared secret" $ \(P p) -> ioProperty $ do
                (ek, _) <- Lib.generate p
                (kk, c) <- Lib.encapsulate ek
                return $ conjoin
                    [ Just c === Lib.decode p (B.convert c :: Bytes)
                    , Just kk === Lib.decode p (B.convert kk :: Bytes)
                    ]
            , testProperty "toPublic" $ \(P p) -> ioProperty $ do
                (ek, dk) <- Lib.generate p
                return (ek === toPublic dk)
            , testProperty "checkKeyPair" $ \(P p) -> ioProperty $
                Lib.generate p >>= checkKeyPair
            ]
#ifdef ML_KEM_TESTING
        , testGroup "bitRev7"
            [ testCase "powers of two" $
                let powers = [1, 2, 4, 8, 16, 32, 64]
                 in reverse powers @=? map bitRev7 powers
            , testProperty "or" $ \(Bit7 a) (Bit7 b) ->
                bitRev7 (a .|. b) === bitRev7 a .|. bitRev7 b
            , testProperty "not" $ \(Bit7 a) ->
                let comp = xor 127
                 in bitRev7 (comp a) === comp (bitRev7 a)
            , testProperty "involutive" $ \(Bit7 a) ->
                a === bitRev7 (bitRev7 a)
            , testProperty "preserves bit count" $ \(Bit7 a) ->
                popCount a === popCount (bitRev7 a)
            ]
        , testGroup "compression"
            [ testProperty "compress . decompress == id" $ \(D d) -> do
                y <- choose (0, 2^d - 1)
                return (d < 12 ==> y === compress d (decompress d y))
            ]
        , testGroup "encoding"
            [ testProperty "byteEncode . byteDecode == id" $ \(D d) -> do
                b <- arbitraryBytes (32 * d)
                return (b === byteEncodeBytes d (byteDecode d b))
            , testProperty "byteEncode1 . byteDecode1 == id" $ do
                b <- arbitraryBytes 32
                return (b === byteEncodeBytes1 (byteDecode1 b))
            , testProperty "byteDecode12 . byteEncode12 == id" $ \(PolyNTT p) ->
                p === byteDecode12 (byteEncodeBytes12 p)
            , testProperty "byteEncode 8" $ \x ->
                B.replicate 256 x === byteEncodeBytes 8 (BlockN.replicate $ fromIntegral x)
            , testCase "byteEncode 1 (zeros)" $
                B.replicate 32 0 @=? byteEncodeBytes 1 (BlockN.replicate 0)
            , testCase "byteEncode 1 (ones)" $
                B.replicate 32 255 @=? byteEncodeBytes 1 (BlockN.replicate 1)
            , testProperty "byteDecode1 == byteDecode 1" $ do
                b <- arbitraryBytes 32
                return (byteDecodeBytes 1 b === byteDecode1 b)
            ]
        , testGroup "Zq"
            [ testProperty "toZq . fromZq == id" $ \(FE a) ->
                a === toZq (fromZq a)
            , testProperty "fromZq . toZq == id" $ \a ->
                mod a 3329 === fromZq (toZq a)
            , testCase "field order" $
                zero @=? toZq 3329
            , testProperty "addition with zero" $ \(FE a) ->
                conjoin [ a === zero .+ a, a === a .+ zero ]
            , testProperty "addition associative" $ \(FE a) (FE b) (FE c) ->
                a .+ (b .+ c) === (a .+ b) .+ c
            , testProperty "addition commutative" $ \(FE a) (FE b) ->
                a .+ b === b .+ a
            , testProperty "subtraction with zero" $ \(FE a) ->
                a === a .- zero
            , testProperty "subtraction non-associative" $ \(FE a) (FE b) (FE c) ->
                a .- (b .- c) === (a .- b) .+ c
            , testProperty "subtraction anti-commutative" $ \(FE a) (FE b) ->
                a .- b === neg (b .- a)
            , testProperty "negation" $ \(FE a) ->
                neg a === zero .- a
            , testProperty "double negation" $ \(FE a) ->
                a === neg (neg a)
            , testProperty "multiplication with zero" $ \(FE a) ->
                conjoin [ zero === zero .* a, zero === a .* zero ]
            , testProperty "multiplication with one" $ \(FE a) ->
                conjoin [ a === one .* a, a === a .* one ]
            , testProperty "multiplication associative" $ \(FE a) (FE b) (FE c) ->
                a .* (b .* c) === (a .* b) .* c
            , testProperty "multiplication commutative" $ \(FE a) (FE b) ->
                a .* b === b .* a
            , testProperty "multiplication distributive" $ \(FE a) (FE b) (FE c) ->
                conjoin
                    [ (a .* b) .+ (a .* c) === a .* (b .+ c)
                    , (b .* a) .+ (c .* a) === (b .+ c) .* a
                    ]
            , testProperty "mulAdd" $ \(FE a) (FE b) (FE c) ->
                a .* b .+ c === mulAdd a b c
            ]
        , testGroup "Rq"
            [ testProperty "fromCoeffs . toCoeffs == id" $ \(Poly a) ->
                Just a === fromCoeffs (toCoeffs a)
            , testProperty "addition with zero" $ \(Poly a) ->
                conjoin [ a === zero .+ a, a === a .+ zero ]
            , testProperty "addition associative" $ \(Poly a) (Poly b) (Poly c) ->
                a .+ (b .+ c) === (a .+ b) .+ c
            , testProperty "addition commutative" $ \(Poly a) (Poly b) ->
                a .+ b === b .+ a
            , testProperty "subtraction with zero" $ \(Poly a) ->
                a === a .- zero
            , testProperty "subtraction non-associative" $ \(Poly a) (Poly b) (Poly c) ->
                a .- (b .- c) === (a .- b) .+ c
            , testProperty "subtraction anti-commutative" $ \(Poly a) (Poly b) ->
                a .- b === neg (b .- a)
            , testProperty "negation" $ \(Poly a) ->
                neg a === zero .- a
            , testProperty "double negation" $ \(Poly a) ->
                a === neg (neg a)
            ]
        , testGroup "Tq"
            [ testProperty "nttInv . ntt == id" $ \(Poly a) ->
                a === nttInv (ntt a)
            , testProperty "addition with zero" $ \(PolyNTT a) ->
                conjoin [ a === zero .+ a, a === a .+ zero ]
            , testProperty "addition associative" $ \(PolyNTT a) (PolyNTT b) (PolyNTT c) ->
                a .+ (b .+ c) === (a .+ b) .+ c
            , testProperty "addition commutative" $ \(PolyNTT a) (PolyNTT b) ->
                a .+ b === b .+ a
            , testProperty "subtraction with zero" $ \(PolyNTT a) ->
                a === a .- zero
            , testProperty "subtraction non-associative" $ \(PolyNTT a) (PolyNTT b) (PolyNTT c) ->
                a .- (b .- c) === (a .- b) .+ c
            , testProperty "subtraction anti-commutative" $ \(PolyNTT a) (PolyNTT b) ->
                a .- b === neg (b .- a)
            , testProperty "negation" $ \(PolyNTT a) ->
                neg a === zero .- a
            , testProperty "double negation" $ \(PolyNTT a) ->
                a === neg (neg a)
            , testProperty "multiplication with zero" $ \(PolyNTT a) ->
                conjoin [ zero === zero .* a, zero === a .* zero ]
            , testProperty "multiplication with one" $ \(PolyNTT a) ->
                conjoin [ a === one .* a, a === a .* one ]
            , testProperty "multiplication associative" $ \(PolyNTT a) (PolyNTT b) (PolyNTT c) ->
                a .* (b .* c) === (a .* b) .* c
            , testProperty "multiplication commutative" $ \(PolyNTT a) (PolyNTT b) ->
                a .* b === b .* a
            , testProperty "multiplication distributive" $ \(PolyNTT a) (PolyNTT b) (PolyNTT c) ->
                conjoin
                    [ (a .* b) .+ (a .* c) === a .* (b .+ c)
                    , (b .* a) .+ (c .* a) === (b .+ c) .* a
                    ]
            , testProperty "mulAdd" $ \(PolyNTT a) (PolyNTT b) (PolyNTT c) ->
                a .* b .+ c === mulAdd a b c
            ]
        , testGroup "Vector"
            [ testProperty "addition with zero" $ \(Dim n) -> do
                a <- arbitraryVector n
                return $ conjoin [ a === zero .+ a, a === a .+ zero ]
            , testProperty "addition associative" $ \(Dim n) -> do
                (a, b, c) <- (,,) <$> arbitraryVector n <*> arbitraryVector n
                                  <*> arbitraryVector n
                return (a .+ (b .+ c) === (a .+ b) .+ c)
            , testProperty "addition commutative" $ \(Dim n) -> do
                (a, b) <- (,) <$> arbitraryVector n <*> arbitraryVector n
                return (a .+ b === b .+ a)
            , testProperty "subtraction with zero" $ \(Dim n) -> do
                a <- arbitraryVector n
                return (a === a .- zero)
            , testProperty "subtraction non-associative" $ \(Dim n) -> do
                (a, b, c) <- (,,) <$> arbitraryVector n <*> arbitraryVector n
                                  <*> arbitraryVector n
                return (a .- (b .- c) === (a .- b) .+ c)
            , testProperty "subtraction anti-commutative" $ \(Dim n) -> do
                (a, b) <- (,) <$> arbitraryVector n <*> arbitraryVector n
                return (a .- b === neg (b .- a))
            , testProperty "negation" $ \(Dim n) -> do
                a <- arbitraryVector n
                return (neg a === zero .- a)
            , testProperty "double negation" $ \(Dim n) -> do
                a <- arbitraryVector n
                return (a === neg (neg a))
            ]
        , testGroup "Matrix"
            [ testProperty "mulw distributive left" $ \(Dim n) (Dim m) -> do
                (a, b, u, v) <- (,,,) <$> arbitraryMatrix m n <*> arbitraryMatrix m n
                                      <*> arbitraryVector m <*> arbitraryVector n
                return (mulw (a .+ b) u v === mulw a u (mulw b u zero) .+ v)
            , testProperty "mulw distributive right" $ \(Dim n) (Dim m) -> do
                (a, u, v, w) <- (,,,) <$> arbitraryMatrix m n <*> arbitraryVector m
                                      <*> arbitraryVector m <*> arbitraryVector n
                return (mulw a (u .+ v) w === mulw a u (mulw a v w))
            , testProperty "muly definition" $ \(Dim n) (Dim m) -> do
                (a, u, v) <- (,,) <$> arbitraryMatrix n m <*> arbitraryVector m
                                  <*> arbitraryVector n
                -- '===' (not '==') so that failures report both sides
                return ((a `muly` u) .+ v === mulw (transpose a) u v)
            , testProperty "muly distributive left" $ \(Dim n) (Dim m) -> do
                (a, b, u) <- (,,) <$> arbitraryMatrix n m <*> arbitraryMatrix n m
                                  <*> arbitraryVector m
                return ((a .+ b) `muly` u === (a `muly` u) .+ (b `muly` u))
            , testProperty "muly distributive right" $ \(Dim n) (Dim m) -> do
                (a, u, v) <- (,,) <$> arbitraryMatrix n m <*> arbitraryVector m
                                  <*> arbitraryVector m
                return (a `muly` (u .+ v) === (a `muly` u) .+ (a `muly` v))
            , testProperty "mulz commutative" $ \(Dim n) -> do
                (u, v) <- (,) <$> arbitraryVector n <*> arbitraryVector n
                return (u `mulz` v === v `mulz` u)
            , testProperty "mulz distributive" $ \(Dim n) -> do
                (u, v, w) <- (,,) <$> arbitraryVector n <*> arbitraryVector n
                                  <*> arbitraryVector n
                return $ conjoin
                    [ u `mulz` (v .+ w) === (u `mulz` v) .+ (u `mulz` w)
                    , (u .+ v) `mulz` w === (u `mulz` w) .+ (v `mulz` w)
                    ]
            ]
#endif
        ]
    ]