-- |
-- Module      : Crypto
-- License     : BSD-3-Clause
-- Copyright   : (c) 2025 Olivier Chéron
--
-- Crypto-related utilities like the ML-KEM hash and PRF functions, or more
-- general concerns like constant-time equality and selection.
--
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Crypto
    ( ConstEqW(..), BoolW, andW, toBool, constSelectBytes, snoc, append, eq
    , prf, h, j, g, BlockDigest, unBlockDigest, hashToBlock
    ) where

import Crypto.Hash (Context)
import Crypto.Hash.Algorithms
import Crypto.Hash.IO

import Control.Exception (assert)
import Control.Monad
import Control.Monad.ST

import Data.ByteArray (ByteArray, ByteArrayAccess, Bytes, ScrubbedBytes)
import qualified Data.ByteArray as B
import Data.Bits
import Data.Word

import GHC.TypeNats

import Foreign.Ptr (Ptr, castPtr, plusPtr)
import Foreign.Storable (pokeByteOff)

import Block (Block)
import Builder (Builder)
import Machine
import ScrubbedBlock (ScrubbedBlock)
import Vector (Vector)
import qualified Block
import qualified Builder
import qualified ByteArrayST as ST
import qualified ScrubbedBlock
import qualified Vector

-- | A boolean encoded as a full machine word used as a bit mask: all bits
-- cleared for false ('falseW'), all bits set for true ('trueW').
newtype BoolW = BoolW Word

#ifdef ML_KEM_TESTING
instance Show BoolW where
    showsPrec d = showsPrec d . toBool
#endif

-- | Convert a word-sized boolean to a regular 'Bool'.  Any non-zero mask is
-- reported as 'True'.
toBool :: BoolW -> Bool
toBool (BoolW w) = w /= 0

-- | The two canonical masks: all bits cleared and all bits set.
falseW, trueW :: BoolW
falseW = BoolW 0
trueW = BoolW maxBound

-- | Conjunction of two word-sized booleans, computed as a bitwise AND of the
-- underlying masks.
andW :: BoolW -> BoolW -> BoolW
andW (BoolW x) (BoolW y) = BoolW (x .&. y)

-- | Number of bits in the word stored inside a 'BoolW'.
bitsW :: Int
bitsW = finiteBitSize w
  where BoolW w = falseW

-- | Number of bytes in the word stored inside a 'BoolW'.
bytesW :: Int
bytesW = bitsW `div` 8

-- | Branch-free equality of two words, returned as a mask.
--
-- With @d = a `xor` b@, the top bit of @complement d .&. (d - 1)@ is set only
-- when @d@ is zero.  That bit is shifted down to bit 0 and negated, which
-- yields all ones when the inputs are equal and zero otherwise.
eqW :: Word -> Word -> BoolW
eqW a b = BoolW (spread (complement d .&. (d - 1)))
  where
    d = a `xor` b
    spread t = negate (t `unsafeShiftR` (bitsW - 1))

-- | Assert that a byte count is a multiple of 'bytesW'.  The check masks the
-- low bits, so it relies on 'bytesW' being a power of two.
assertMultW :: Int -> a -> a
assertMultW n = assert (n .&. (bytesW - 1) == 0)

-- | Equality test returning a 'BoolW' mask instead of a 'Bool'.
class ConstEqW a where
    constEqW :: a -> a -> BoolW

-- | Element-wise comparison, combining the per-element masks with 'andW'.
instance ConstEqW a => ConstEqW (Vector n a) where
    constEqW =
        Vector.fold1ZipWith (\acc x y -> acc `andW` constEqW x y) constEqW

-- | Blocks of different lengths are unequal; otherwise every pair of words is
-- compared with 'eqW' and the masks are accumulated with 'andW'.
instance ConstEqW (Block Word) where
    constEqW a b =
        if Block.length a == Block.length b
            then Block.foldZipWith (\acc x y -> acc `andW` eqW x y) trueW a b
            else falseW

-- | Same strategy as the @Block Word@ instance, over scrubbed blocks.
instance ConstEqW (ScrubbedBlock Word) where
    constEqW a b =
        if ScrubbedBlock.length a == ScrubbedBlock.length b
            then ScrubbedBlock.foldZipWith
                     (\acc x y -> acc `andW` eqW x y) trueW a b
            else falseW

instance ConstEqW Bytes where
    constEqW = bytesConstEqW

instance ConstEqW ScrubbedBytes where
    constEqW = bytesConstEqW

-- | Word-wise comparison of two byte arrays.  Arrays of different lengths
-- are unequal.  The comparison itself is delegated to 'foldZipWith', which
-- asserts that the length is a multiple of 'bytesW'.
bytesConstEqW :: (ByteArrayAccess bs1, ByteArrayAccess bs2)
              => bs1 -> bs2 -> BoolW
bytesConstEqW a b =
    if B.length a == B.length b
        then foldZipWith (\acc x y -> acc `andW` eqW x y) trueW a b
        else falseW

-- | Left fold over two byte arrays read in lockstep, one 'Word' at a time.
--
-- Both inputs are asserted to have the same length and that length is
-- asserted to be a multiple of 'bytesW'.  The loop stops when the byte
-- offset is exactly equal to the length.
foldZipWith :: (ByteArrayAccess bs1, ByteArrayAccess bs2)
            => (c -> Word -> Word -> c) -> c -> bs1 -> bs2 -> c
foldZipWith f c a b =
    assert (lenA == lenB) $ assertMultW lenA $ assertMultW lenB $ runST $
        ST.withByteArray a $ \ptrA ->
        ST.withByteArray b $ \ptrB ->
            go (ptrA :: Ptr Word) (ptrB :: Ptr Word) c 0
  where
    !lenA = B.length a
    !lenB = B.length b

    -- @off@ counts the bytes consumed so far; both pointers advance by one
    -- word per iteration.
    go !ptrA !ptrB !acc off
        | off == lenA = return acc
        | otherwise = do
            wa <- ST.peek ptrA
            wb <- ST.peek ptrB
            go (ptrA `plusPtr` bytesW) (ptrB `plusPtr` bytesW)
               (f acc wa wb) (off + bytesW)
{-# INLINE foldZipWith #-}

-- | Combine two scrubbed byte arrays word by word into a new array of the
-- same length.
--
-- Both inputs are asserted to have the same length and that length is
-- asserted to be a multiple of 'bytesW'.
zipWith :: (Word -> Word -> Word)
        -> ScrubbedBytes -> ScrubbedBytes -> ScrubbedBytes
zipWith f a b =
    assert (lenA == lenB) $ assertMultW lenA $ assertMultW lenB $
        ST.unsafeCreate lenA $ \dst ->
        ST.withByteArray a $ \ptrA ->
        ST.withByteArray b $ \ptrB ->
            go dst ptrA ptrB 0
  where
    !lenA = B.length a
    !lenB = B.length b

    -- The destination pointer stays fixed and is written at byte offset
    -- @off@, while the two source pointers advance by one word each step.
    go :: Ptr Word -> Ptr Word -> Ptr Word -> Int -> ST s ()
    go !dst !ptrA !ptrB off = when (off < lenA) $ do
        wa <- ST.peek ptrA
        wb <- ST.peek ptrB
        ST.pokeByteOff dst off $ f wa wb
        go dst (ptrA `plusPtr` bytesW) (ptrB `plusPtr` bytesW) (off + bytesW)
{-# INLINE zipWith #-}
-- | Mask-based selection between two scrubbed byte arrays, performed word by
-- word without branching on the condition: an all-ones mask ('trueW') yields
-- the first array and an all-zeros mask ('falseW') yields the second.
constSelectBytes :: BoolW -> ScrubbedBytes -> ScrubbedBytes -> ScrubbedBytes
constSelectBytes (BoolW !mask) = Crypto.zipWith pick
  where
    pick whenTrue whenFalse =
        (mask .&. whenTrue) .|. (complement mask .&. whenFalse)

-- This version of snoc accepts a more general input and uses internally a call
-- to copyByteArrayToPtr, so it does not need a trampoline when the input is
-- backed by Block Word8

-- | Append a single byte to a byte array, producing 'ScrubbedBytes'.
snoc :: ByteArrayAccess a => a -> Word8 -> ScrubbedBytes
snoc a b = B.allocAndFreeze (len + 1) $ \dst -> do
    B.copyByteArrayToPtr a dst
    -- the extra byte goes right after the copied input
    pokeByteOff dst len b
  where
    len = B.length a
{-# INLINE snoc #-}

-- This version of append is more polymorphic and requires no trampoline when
-- fed with an input backed by Block Word8.

-- | Concatenate two byte arrays into a freshly allocated 'ScrubbedBytes'.
append :: (ByteArrayAccess a, ByteArrayAccess b) => a -> b -> ScrubbedBytes
append a b = B.allocAndFreeze (lenA + lenB) $ \dst -> do
    B.copyByteArrayToPtr a dst
    B.copyByteArrayToPtr b (dst `plusPtr` lenA)
  where
    lenA = B.length a
    lenB = B.length b
{-# INLINE append #-}

-- | Equality of two byte arrays, compared one @WordM@ at a time.
--
-- The loop returns 'False' at the first mismatching word, so the running
-- time depends on the data.  Both inputs are asserted to have the same
-- length and to satisfy @assertMultM@.
eq :: (ByteArrayAccess a, ByteArrayAccess b) => a -> b -> Bool
eq a b =
    assert (lenA == lenB) $ assertMultM lenA $ assertMultM lenB $ runST $
        ST.withByteArray a $ \ptrA ->
        ST.withByteArray b $ \ptrB ->
            go (ptrA :: Ptr WordM) (ptrB :: Ptr WordM) 0
  where
    !lenA = B.length a
    !lenB = B.length b

    -- @off@ is the number of bytes already compared.
    go !ptrA !ptrB off
        | off == lenA = return True
        | otherwise = do
            wa <- ST.peek ptrA
            wb <- ST.peek ptrB
            if wa /= wb
                then return False
                else go (ptrA `plusPtr` wordBytes) (ptrB `plusPtr` wordBytes)
                        (off + wordBytes)

-- | PRF: SHAKE256 over the input @s@ followed by the single byte @b@.  The
-- output length requested from SHAKE256 is @8 * 64 * eta@ bits.
prf :: ByteArrayAccess s => Word -> s -> Word8 -> ScrubbedBytes
prf !eta s !b =
    case someNatVal (fromIntegral (8 * 64 * eta)) of
        SomeNat proxy -> unDigest (shake proxy)
  where
    -- The proxy only carries the output length at the type level.
    shake :: KnownNat bitlen => proxy bitlen -> Digest (SHAKE256 bitlen)
    shake _ = hash (snoc s b)

-- | Hash function H: SHA3-256 of the input.
h :: ByteArrayAccess s => s -> Bytes
h = Builder.run . hashWith SHA3_256

-- | Hash function J: SHAKE256 with a 256-bit output.
j :: ScrubbedBytes -> ScrubbedBytes
j = Builder.run . hashWith (SHAKE256 :: SHAKE256 256)

-- | Hash function G: SHA3-512 of the input, with the digest split into its
-- first 32 bytes (converted to the requested byte array type) and a view on
-- the remaining bytes.
g :: ByteArray ba => ScrubbedBytes -> (ba, B.View ScrubbedBytes)
g c = (B.convert (B.takeView digest 32), B.dropView digest 32)
  where
    digest = Builder.run $ hashWith SHA3_512 c

-- Override cryptonite types and hashing functions.
--
-- Standard type Digest is a newtype over an unpinned Block Word8, which
-- requires a trampoline to implement most Ptr access to the underlying byte
-- array.  Instead we re-implement here the Digest type over ScrubbedBytes as
-- well as pinned Block backends, to avoid trampoline costs.  Additionally
-- we use the mutable API to avoid copying the hashing Context in between
-- steps init/update/finalize and then clear the content.

-- | Digest stored in 'ScrubbedBytes', tagged with the hash algorithm.
newtype Digest a = Digest { unDigest :: ScrubbedBytes }

-- | Digest stored in a @Block Word8@, tagged with the hash algorithm.
newtype BlockDigest a = BlockDigest { unBlockDigest :: Block Word8 }

-- | Hash a byte array with the algorithm selected by the result type.  The
-- algorithm value handed to 'hashWith' is @undefined@, so it acts purely as
-- a type witness and must never be evaluated.
hash :: forall a ba. (HashAlgorithm a, ByteArrayAccess ba) => ba -> Digest a
hash = Digest . Builder.run . hashWith (undefined :: a)

-- | Same as 'hash', but takes 'Bytes' as input and runs the builder with
-- @Builder.runToBlock@ to produce a 'BlockDigest'.
hashToBlock :: forall a. HashAlgorithm a => Bytes -> BlockDigest a
hashToBlock = BlockDigest . Builder.runToBlock . hashWith (undefined :: a)

-- | Build the digest of a byte array with the given algorithm.
--
-- Uses the mutable hashing API: a context is initialized, updated with the
-- input and finalized directly into the output buffer.  The context memory is
-- then passed to @ScrubbedBlock.erasePtr@ along with its length.
hashWith :: forall marking a ba. (HashAlgorithm a, ByteArrayAccess ba)
         => a -> ba -> Builder marking
hashWith alg input = Builder.unsafeCreate (hashDigestSize alg) $ \out -> do
    mctx <- hashMutableInit
    hashMutableUpdate (mctx :: MutableContext a) input
    B.withByteArray mctx $ \ctxPtr -> do
        hashInternalFinalize (castPtr ctxPtr :: Ptr (Context a)) out
        ScrubbedBlock.erasePtr (B.length mctx) ctxPtr