{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}

-- | Helpers for nested vectors (@Vector m (Vector n ty)@) used as matrices:
-- construction, element lookup, and products against vectors built on the
-- 'BiMulAdd' operations.
module Matrix
    ( create
    , mulw
    , muly
    , mulz
#ifdef ML_KEM_TESTING
    , transpose
#endif
    ) where

import Basement.Nat
import Basement.Types.OffsetSize
import Math
import Vector (Vector)
import qualified Vector

-- | Build a nested vector from a generator function.
--
-- The element at outer offset @j@ and inner offset @i@ is @f i j@; note
-- that the generator receives the inner offset first.
create
    :: (KnownNat m, KnownNat n)
    => (Offset ty -> Offset (Vector n ty) -> ty)
    -> Vector m (Vector n ty)
create f = Vector.create (\outer -> Vector.create (\inner -> f inner outer))
{-# INLINE create #-}

-- | Element lookup: @index a i j@ takes the @j@-th inner vector of @a@ and
-- returns its @i@-th element (inner offset first, matching 'create').
index :: Vector m (Vector n ty) -> Offset ty -> Offset (Vector n ty) -> ty
index a i j =
    let inner = Vector.index a j
     in Vector.index inner i

-- | Accumulate @a@ against @u@ on top of @b@, one output element at a time.
--
-- Output element @i@ starts from element @i@ of @b@. 'Vector.foldIndexWith'
-- then walks @u@, and for each offset @j@ with value @uj@ the accumulator
-- @acc@ becomes @biMulAdd (index a i j) uj acc@. If the outer vectors of
-- @a@ are rows and 'biMulAdd' is a multiply-accumulate (NOTE(review):
-- confirm in "Math"), this is @b + transpose a * u@.
--
-- Because of the bang patterns, @u@ and @b@ are evaluated to WHNF as soon
-- as an application @mulw a u b@ is evaluated.
mulw
    :: (KnownNat n, BiMulAdd b a)
    => Vector m (Vector n b)
    -> Vector m a
    -> Vector n a
    -> Vector n a
mulw a !u !b = Vector.create entry
  where
    -- The raw position is unwrapped from Offset and re-wrapped so it can be
    -- used at the differently phantom-typed offsets that the lookups expect.
    entry (Offset i) =
        Vector.foldIndexWith step (Vector.index b (Offset i)) u
      where
        step acc (Offset j) uj =
            biMulAdd (index a (Offset i) (Offset j)) uj acc

-- | Combine every inner vector of @a@ with @u@ via 'mulz', giving one result
-- per inner vector (a matrix-vector product if 'mulz' is a dot product).
-- @u@ is forced to WHNF by the bang pattern.
muly :: BiMulAdd b a => Vector m (Vector n b) -> Vector n a -> Vector m a
muly a !u = fmap (\row -> mulz row u) a

-- | Pairwise reduction of two equal-length vectors with
-- 'Vector.fold1ZipWith', using '(..*)' as the seeding combiner and
-- 'biMulAdd' (accumulator passed last) as the folding step. Presumably a
-- dot product; NOTE(review): confirm the fold semantics in "Vector".
mulz :: BiMulAdd b a => Vector n b -> Vector n a -> a
mulz = Vector.fold1ZipWith step (..*)
  where
    step acc x y = biMulAdd x y acc

#ifdef ML_KEM_TESTING
-- | Swap the two index levels, so that
-- @index (transpose a) i j == index a j i@.
-- Only compiled and exported in test builds.
transpose
    :: (KnownNat m, KnownNat n)
    => Vector m (Vector n ty)
    -> Vector n (Vector m ty)
transpose a = create flipped
  where
    -- Offsets are re-wrapped because the phantom types differ between the
    -- input and output shapes.
    flipped (Offset inner) (Offset outer) =
        index a (Offset outer) (Offset inner)
#endif