{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | Boxed vectors indexed by a type-level natural @n@.
--
-- The data constructor is not exported, so values outside this module are
-- obtained only through the functions exported here.
module Vector
  ( Vector
  , Vector.concatMap
  , Vector.fold1ZipWith
  , Vector.foldIndexWith
  , Vector.toNormalForm
  , Vector.create
  , Vector.index
#ifdef ML_KEM_TESTING
  , Vector.replicateM
  , Vector.zipWith
#endif
  ) where

import Basement.BoxedArray (Array)
import qualified Basement.BoxedArray as Array
import Basement.Compat.IsList
import Basement.Nat
import Basement.NormalForm
import Basement.Types.OffsetSize
import Control.DeepSeq (NFData(..))
#ifdef ML_KEM_TESTING
import Control.Monad
#endif
#if !(MIN_VERSION_base(4,20,0))
import Data.List (foldl')
#endif
import Data.Proxy

import Math

-- | A boxed 'Array' tagged with a phantom length index @n@.
--
-- 'create' and 'replicateM' build arrays of exactly @n@ elements; 'fmap'
-- and 'zipWith' keep the length of their input.
newtype Vector (n :: Nat) a = Vector { unVector :: Array a }
  deriving (Eq, Show, Functor, NormalForm)

-- | Element-wise arithmetic. 'zero' is the vector whose every element is
-- 'zero'; '(.+)', '(.-)' and 'neg' act on corresponding elements.
instance (Add a, KnownNat n) => Add (Vector n a) where
  zero = create (const zero)
  (.+) = Vector.zipWith (.+)
  (.-) = Vector.zipWith (.-)
  neg (Vector a) = Vector (fmap neg a)

-- | Build a vector of length @n@ by applying the generator to each offset.
create :: forall n a. KnownNat n => (Offset a -> a) -> Vector n a
create = Vector . Array.create (CountOf sz)
  where
    !sz = fromIntegral $ natVal (Proxy :: Proxy n)
{-# INLINE create #-}

-- | Element access shared by every function in this module.
--
-- Testing builds (@ML_KEM_TESTING@) use the bounds-checked 'Array.index'.
-- Other builds use 'Array.unsafeIndex', which does no bounds check, so
-- every caller must keep offsets inside the array.
arrayIndex :: Array a -> Offset a -> a
#ifdef ML_KEM_TESTING
arrayIndex = Array.index
#else
arrayIndex = Array.unsafeIndex
#endif

#ifdef ML_KEM_TESTING
-- | Build a vector of length @n@ from @n@ runs of the given action
-- (testing builds only).
replicateM :: forall n m a. (KnownNat n, Applicative m) => m a -> m (Vector n a)
replicateM f = Vector . fromList <$> Control.Monad.replicateM sz f
  where
    !sz = fromIntegral $ natVal (Proxy :: Proxy n)
#endif

-- | Element at the given offset. Bounds are checked only in testing builds
-- (see 'arrayIndex').
index :: Vector n a -> Offset a -> a
index = arrayIndex . unVector

-- | Map each element into a monoid and combine the results in element
-- order.
concatMap :: Monoid b => (a -> b) -> Vector n a -> b
concatMap f = mconcat . Prelude.map f . toList . unVector
{-# INLINE concatMap #-}

-- | Combine two vectors element by element.
--
-- The result length comes from the first argument; the shared index @n@
-- is what keeps the second argument the same length.
zipWith :: (a -> b -> c) -> Vector n a -> Vector n b -> Vector n c
zipWith f (Vector a) (Vector !b) =
  Vector $ Array.create (CountOf sa) $ \(Offset i) ->
    f (arrayIndex a (Offset i)) (arrayIndex b (Offset i))
  where
    CountOf sa = Array.length a
{-# INLINE zipWith #-}

-- | Zip two vectors and fold the pairs from the left.
--
-- Offset 0 seeds the accumulator via @g@. Offsets @1 .. len-1@ are then
-- folded in order with @f@, using a strict left fold.
--
-- The vectors must be non-empty. An empty input raises an error here
-- instead of reading offset 0 of an empty array, which would be an
-- unchecked out-of-bounds read in non-testing builds.
fold1ZipWith :: (c -> a -> b -> c) -> (a -> b -> c) -> Vector n a -> Vector n b -> c
fold1ZipWith f g (Vector a) (Vector !b)
  | sa <= 0 = error "Vector.fold1ZipWith: empty vector"
  | otherwise = foldl' step seed [1 .. sa - 1]
  where
    step acc i = f acc (arrayIndex a (Offset i)) (arrayIndex b (Offset i))
    seed = g (arrayIndex a 0) (arrayIndex b 0)
    CountOf !sa = Array.length a
{-# INLINE fold1ZipWith #-}

-- | Strict left fold over the elements and their offsets, starting from
-- the given accumulator. An empty vector returns the accumulator unchanged.
foldIndexWith :: (c -> Offset a -> a -> c) -> c -> Vector n a -> c
foldIndexWith f c (Vector a) = foldl' g c [0 .. sa - 1]
  where
    g x i = f x (Offset i) (arrayIndex a (Offset i))
    CountOf !sa = Array.length a
{-# INLINE foldIndexWith #-}

-- | Fully evaluate every element with 'rnf'.
--
-- The strict 'Array.foldl'' forces each step's unit result, which in turn
-- forces that element's 'rnf'.
toNormalForm :: NFData a => Vector n a -> ()
toNormalForm = Array.foldl' (\acc x -> acc `seq` rnf x) () . unVector