-- |
-- Module      : Vector
-- License     : BSD-3-Clause
-- Copyright   : (c) 2025 Olivier Chéron
--
-- A vector of lifted elements with the vector dimension at type level.
-- Currently backed by the boxed t'Array' type from basement ("Basement.BoxedArray").
--
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Vector
    ( Vector, Vector.concatMap
    , Vector.fold1ZipWith, Vector.foldIndexWith, Vector.toNormalForm
    , Vector.create, Vector.index
#ifdef ML_KEM_TESTING
    , Vector.replicateM, Vector.zipWith
#endif
    ) where

import Basement.BoxedArray (Array)
import qualified Basement.BoxedArray as Array
#ifdef ML_KEM_TESTING
import Basement.Compat.IsList
#endif
import Basement.Nat
import Basement.NormalForm
import Basement.Types.OffsetSize

import Control.DeepSeq (NFData(..))
#ifdef ML_KEM_TESTING
import Control.Monad
#endif

#if !(MIN_VERSION_base(4,20,0))
import Data.List (foldl')
#endif
import Data.Proxy

import Iterate
import Math

-- | A vector of lifted elements of type @a@ whose dimension @n@ is tracked
-- only at the type level. The functions in this module assume the wrapped
-- 'Array' holds exactly @n@ elements: some size their result from @n@
-- (e.g. 'create'), others from the length of the first operand's array
-- (e.g. 'zipWith'). The data constructor is not in the export list, so only
-- this module can build values directly.
newtype Vector (n :: Nat) a = Vector { unVector :: Array a }
    deriving (Eq, Show, NormalForm)

-- | 'fmap' is a point-free alias for 'mapVector'. It is marked INLINE so that
-- uses of 'fmap' turn into 'mapVector' calls, which the @mapVector/*@
-- rewrite rules at the end of this module can then match.
instance Functor (Vector n) where
    fmap = mapVector
    {-# INLINE fmap #-}

-- | Element-wise 'Add' structure.
--
-- * 'zero' fills every offset with the element type's 'zero'.
-- * @(.+)@ and @(.-)@ zip the two operands with the element operator.
-- * 'neg' maps the element 'neg' over the vector.
--
-- The methods are point-free aliases marked INLINE. Once inlined they become
-- 'create', 'Vector.zipWith' and 'mapVector' calls, which the rewrite rules
-- at the end of this module can fuse.
instance (Add a, KnownNat n) => Add (Vector n a) where
    zero = create (const zero)
    {-# INLINE zero #-}
    (.+) = Vector.zipWith (.+)
    {-# INLINE (.+) #-}
    (.-) = Vector.zipWith (.-)
    {-# INLINE (.-) #-}
    neg = mapVector neg
    {-# INLINE neg #-}

-- | Builds a vector of dimension @n@ from a function that gives the element
-- at each offset. This is 'genericCreate' specialised so that the offset is
-- tagged with the element type itself. It is kept as an INLINE alias so the
-- @*/genericCreate@ rewrite rules see the underlying 'genericCreate' call.
create :: forall n a. KnownNat n => (Offset a -> a) -> Vector n a
create = genericCreate
{-# INLINE create #-}

-- | Builds a vector of dimension @n@ whose element at each offset is given by
-- the generator @f@. The dimension is read from the type with 'natVal' and
-- forced before the array is allocated.
--
-- The generator's offset phantom type @a'@ is independent of the element
-- type @a@; each offset is re-tagged before it is handed to @f@. This lets
-- generators that index other vectors (with differently tagged offsets) be
-- used directly, as in 'genericCreateZipLeft' and 'genericCreateZipRight'.
genericCreate :: forall n a a'. KnownNat n => (Offset a' -> a) -> Vector n a
genericCreate f = Vector (Array.create (CountOf dim) generate)
  where
    -- Dimension from the type-level natural, evaluated strictly.
    !dim = fromIntegral (natVal (Proxy :: Proxy n))
    -- Re-tag the array's offset with the generator's phantom type.
    generate (Offset !i) = f (Offset i)
{-# INLINE [1] genericCreate #-}

-- | A 'Vector.zipWith' whose left operand is given as an offset generator
-- @g@ instead of a materialised vector. The element at offset @i@ is
-- @f (g i) (index vec i)@. It is the target of the
-- @zipWith/genericCreate left@ rewrite rule.
genericCreateZipLeft :: KnownNat n => (a -> b -> c) -> (Offset a' -> a) -> Vector n b -> Vector n c
genericCreateZipLeft f g vec = genericCreate combine
  where
    -- The same position is passed to the generator and used to read @vec@.
    combine off@(Offset i) = f (g off) (index vec (Offset i))
{-# INLINE [1] genericCreateZipLeft #-}

-- | A 'Vector.zipWith' whose right operand is given as an offset generator
-- @g@ instead of a materialised vector. The element at offset @i@ is
-- @f (index vec i) (g i)@. It is the target of the
-- @zipWith/genericCreate right@ rewrite rule.
genericCreateZipRight :: KnownNat n => (a -> b -> c) -> (Offset b' -> b) -> Vector n a -> Vector n c
genericCreateZipRight f g vec = genericCreate combine
  where
    -- The same position is used to read @vec@ and passed to the generator.
    combine off@(Offset i) = f (index vec (Offset i)) (g off)
{-# INLINE [1] genericCreateZipRight #-}

-- | Applies a function to every element, keeping the dimension. It is a
-- separate function marked INLINE [1] (inlined only from phase 1 on) so that
-- the @mapVector/*@ rewrite rules, active from phase 2, can match it first.
mapVector :: (a -> b) -> Vector n a -> Vector n b
mapVector f = Vector . fmap f . unVector
{-# INLINE [1] mapVector #-}

-- | Reads one element of a backing array. Test builds (@ML_KEM_TESTING@) use
-- the bounds-checked 'Array.index'. All other builds use 'Array.unsafeIndex',
-- which skips the bounds check, so every caller must pass an in-range offset.
arrayIndex :: Array a -> Offset a -> a
#ifdef ML_KEM_TESTING
arrayIndex = Array.index

-- | Sequences the action @n@ times and packs the results, in order, into a
-- vector of dimension @n@. Only compiled in test builds.
replicateM :: forall n m a. (KnownNat n, Applicative m) => m a -> m (Vector n a)
replicateM f = Vector . fromList <$> Control.Monad.replicateM sz f
  where !sz = fromIntegral $ natVal (Proxy :: Proxy n)
#else
arrayIndex = Array.unsafeIndex
#endif

-- | Element of the vector at the given offset. The offset is bounds-checked
-- only in test builds, because lookups go through 'arrayIndex'.
index :: Vector n a -> Offset a -> a
index (Vector arr) off = arrayIndex arr off

-- | Maps every element to a monoid value and combines the results with
-- 'mconcat', in the order in which 'mapToList' yields them. It is kept
-- point-free so that the INLINE pragma fires as soon as it is applied to its
-- function argument.
concatMap :: Monoid b => (a -> b) -> Vector n a -> b
concatMap f = mconcat . mapToList f
{-# INLINE concatMap #-}

-- | Applies @f@ to the element at each offset produced by 'offsets' for the
-- array length, and returns the results as a list in that order.
mapToList :: (a -> b) -> Vector n a -> [b]
mapToList f (Vector arr) = [ f (arrayIndex arr (Offset i)) | i <- offsets len ]
  where
    CountOf len = Array.length arr

-- | Combines two vectors of the same dimension element-wise.
--
-- * The result length is taken from the first operand's array.
-- * The second operand's array is forced to WHNF on entry (bang pattern).
-- * It is marked INLINE [1] so the @zipWith/*@ rewrite rules can match it
--   before it is inlined.
zipWith :: (a -> b -> c) -> Vector n a -> Vector n b -> Vector n c
zipWith f (Vector xs) (Vector !ys) = Vector (Array.create (CountOf len) pairAt)
  where
    CountOf len = Array.length xs
    -- One raw position, re-tagged for each of the three differently typed arrays.
    pairAt (Offset i) = f (arrayIndex xs (Offset i)) (arrayIndex ys (Offset i))
{-# INLINE [1] zipWith #-}

-- | Strict left fold over the element pairs of two same-dimension vectors.
--
-- * The accumulator is seeded with @g@ applied to the pair at offset 0.
-- * @f@ then folds in the pairs at the offsets given by @offsetsFrom 1@ for
--   the first operand's length.
--
-- Precondition: the vectors must be non-empty. Offset 0 is read
-- unconditionally, and outside test builds 'arrayIndex' does no bounds check.
fold1ZipWith :: (c -> a -> b -> c) -> (a -> b -> c) -> Vector n a -> Vector n b -> c
fold1ZipWith f g (Vector xs) (Vector !ys) =
    foldl' (\acc i -> onPair i (f acc)) (onPair 0 g) (offsetsFrom 1 len)
  where
    -- Apply a binary function to the two elements stored at raw offset @i@.
    onPair i h = h (arrayIndex xs (Offset i)) (arrayIndex ys (Offset i))
    CountOf !len = Array.length xs
{-# INLINE fold1ZipWith #-}

-- | Strict left fold over the elements together with their offsets. The
-- offsets visited are those produced by 'offsets' for the array length.
foldIndexWith :: (c -> Offset a -> a -> c) -> c -> Vector n a -> c
foldIndexWith f c (Vector arr) = foldl' visit c (offsets len)
  where
    -- The same typed offset is passed to @f@ and used for the lookup.
    visit acc i = let off = Offset i in f acc off (arrayIndex arr off)
    CountOf !len = Array.length arr
{-# INLINE foldIndexWith #-}

-- | Fully evaluates every element with 'rnf' during a strict left fold over
-- the backing array, and returns @()@. The element constraint is 'NFData',
-- not the 'NormalForm' class derived for 'Vector'.
toNormalForm :: NFData a => Vector n a -> ()
toNormalForm (Vector arr) = Array.foldl' step () arr
  where
    step acc x = acc `seq` rnf x

-- Fusion rewrite rules.
--
-- Phase interplay: the rules are active from simplifier phase 2 onwards.
-- The functions they mention ('mapVector', 'genericCreate', 'zipWith',
-- 'genericCreateZipLeft', 'genericCreateZipRight') are marked INLINE [1],
-- so they are inlined only from phase 1. This gives the rules a chance to
-- match before those calls disappear.
--
-- What each group does:
--
-- * mapVector/mapVector, mapVector/genericCreate: compose the element
--   functions instead of allocating an intermediate array.
-- * zipWith/genericCreate, genericCreateZip*/genericCreate: use an operand's
--   generator directly instead of materialising that operand.
-- * zipWith/mapVector: fold a map over one operand into the zip's
--   combining function.
--
-- NOTE(review): the rewritten forms can differ in strictness and sizing.
-- 'zipWith' pattern-matches its second vector with a bang and sizes its
-- result from the first operand's array. The genericCreate-based forms size
-- from the type-level @n@ and do not force the remaining vector up front.
-- These rewrites are therefore only equivalent under the invariant that
-- every @Vector n a@ holds exactly @n@ elements.
{-# RULES
"mapVector/mapVector" [2] forall f g a. mapVector f (mapVector g a) = mapVector (f . g) a
"mapVector/genericCreate" [2] forall f g. mapVector f (genericCreate g) = genericCreate (f . g)
"zipWith/genericCreate left" [2] forall f g a. Vector.zipWith f (genericCreate g) a = genericCreateZipLeft f g a
"zipWith/genericCreate right" [2] forall f g a. Vector.zipWith f a (genericCreate g) = genericCreateZipRight f g a
"genericCreateZipLeft/genericCreate" [2] forall f g h. genericCreateZipLeft f g (genericCreate h) = genericCreate $ \(Offset i) -> f (g (Offset i)) (h (Offset i))
"genericCreateZipRight/genericCreate" [2] forall f g h. genericCreateZipRight f g (genericCreate h) = genericCreate $ \(Offset i) -> f (h (Offset i)) (g (Offset i))
"zipWith/mapVector left" [2] forall f g a. Vector.zipWith f (mapVector g a) = Vector.zipWith (f . g) a
"zipWith/mapVector right" [2] forall f g a b. Vector.zipWith f a (mapVector g b) = Vector.zipWith (\aa bb -> f aa (g bb)) a b
  #-}
