{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
module Main (main) where

import Criterion.Main

import Crypto.Random
import Crypto.PubKey.ML_KEM

import Data.ByteArray (Bytes)
import Data.Proxy

-- | A benchmarkable ML-KEM parameter set.
--
-- The concrete parameter-set type is hidden existentially so that several
-- parameter sets can be stored in one homogeneous list; pattern matching on
-- 'KEM' brings the packed 'Proxy' and its 'ParamSet' instance back into scope.
data KEM = forall params . ParamSet params => KEM (Proxy params)

-- | Every ML-KEM parameter set under benchmark, each paired with the label
-- used as the name of its criterion group.
kems :: [(String, KEM)]
kems =
    [ entry "ML-KEM-512"  (Proxy :: Proxy ML_KEM_512)
    , entry "ML-KEM-768"  (Proxy :: Proxy ML_KEM_768)
    , entry "ML-KEM-1024" (Proxy :: Proxy ML_KEM_1024)
    ]
  where
    -- Pair a label with its existentially wrapped parameter-set proxy.
    entry label proxy = (label, KEM proxy)

-- | Build the criterion group for one labelled ML-KEM parameter set.
--
-- Each operation's inputs are prepared in a criterion environment action
-- rather than inside the measured function:
--
-- * the plain benchmarks use 'perRunEnv', so fresh inputs are made for every run;
-- * the "(batch)" benchmarks use 'perBatchEnv'. They ignore the batch size
--   and prepare one set of inputs that every run in the batch reuses.
--
-- The measured functions are pure. They are lifted with 'pure' to fit
-- criterion's @env -> IO b@ shape.
doBench :: (String, KEM) -> Benchmark
doBench (name, KEM proxy) =
    bgroup name
        [ bench "generate" $
            perRunEnv seedPair (pure . keyGen)
        , bench "encapsulate" $
            perRunEnv encapInputs (pure . encap)
        , bench "encapsulate (batch)" $
            perBatchEnv (\_batchSize -> encapInputs) (pure . encap)
        , bench "decapsulate" $
            perRunEnv decapInputs (pure . decap)
        , bench "decapsulate (batch)" $
            perBatchEnv (\_batchSize -> decapInputs) (pure . decap)
        ]
  where
    -- 32 random bytes drawn through the IO 'MonadRandom' instance. They are
    -- used both as a key-generation seed and as the message handed to
    -- 'encapsulateWith'.
    random32 :: IO Bytes
    random32 = getRandomBytes 32

    -- Key generation: two independently drawn 32-byte values, passed as the
    -- two seed arguments of 'generateWith' (first drawn goes first).
    seedPair = do
        d <- random32
        z <- random32
        pure (d, z)
    keyGen (d, z) = generateWith proxy d z

    -- Encapsulation: a freshly generated encapsulation key plus a 32-byte
    -- message for the deterministic 'encapsulateWith'.
    encapInputs = do
        (ek, _) <- generate proxy
        msg <- random32
        pure (ek, msg)
    encap (ek, msg) = encapsulateWith ek msg

    -- Decapsulation: a decapsulation key together with a ciphertext produced
    -- by encapsulating against its matching encapsulation key.
    decapInputs = do
        (ek, dk) <- generate proxy
        (_, ct) <- encapsulate ek
        pure (dk, ct)
    decap (dk, ct) = decapsulate dk ct

-- | Benchmark entry point.
--
-- Every parameter set listed in 'kems' becomes a subgroup of one top-level
-- "mlkem" group, which criterion's 'defaultMain' then runs.
main :: IO ()
main = defaultMain [allKems]
  where
    allKems = bgroup "mlkem" (fmap doBench kems)
