{-# Language GADTs, DataKinds, TypeOperators, BangPatterns #-}
{-# Language PatternGuards #-}
{-# Language TypeApplications, ScopedTypeVariables #-}
{-# Language Rank2Types, RoleAnnotations #-}
{-# Language CPP #-}
#if __GLASGOW_HASKELL__ >= 805
{-# Language NoStarIsType #-}
#endif
-- | A fixed-size, non-empty vector of typed elements.
--
-- The length of a 'Vector' is tracked at the type level, and every
-- 'Vector' contains at least one element.
module Data.Parameterized.Vector
  ( Vector
    -- * Lists
  , fromList
  , toList

    -- * Length
  , length
  , nonEmpty
  , lengthInt

    -- * Indexing
  , elemAt
  , elemAtMaybe
  , elemAtUnsafe

    -- * Update
  , insertAt
  , insertAtMaybe

    -- * Sub sequences
  , uncons
  , slice
  , Data.Parameterized.Vector.take

    -- * Zipping
  , zipWith
  , zipWithM
  , zipWithM_
  , interleave

    -- * Reorder
  , shuffle
  , reverse
  , rotateL
  , rotateR
  , shiftL
  , shiftR

    -- * Construction
  , singleton
  , cons
  , snoc
  , generate
  , generateM

    -- * Splitting and joining
    -- ** General
  , joinWithM
  , joinWith
  , splitWith
  , splitWithA

    -- ** Vectors
  , split
  , join
  , append
  ) where

import qualified Data.Vector as Vector
import Data.Functor.Compose
import Data.Coerce
import Data.Vector.Mutable (MVector)
import qualified Data.Vector.Mutable as MVector
import Control.Monad.ST
import Data.Functor.Identity
import Data.Parameterized.NatRepr
import Data.Proxy
import Prelude hiding (length,reverse,zipWith)
import Numeric.Natural

import Data.Parameterized.Utils.Endian

-- | Fixed-size non-empty vectors.
-- The constructor carries a proof that @1 <= n@, so pattern matching on a
-- 'Vector' makes that fact available to the type checker.
data Vector n a where
  Vector :: (1 <= n) => !(Vector.Vector a) -> Vector n a

-- The length index is nominal so that 'coerce' can never change it; the
-- element type is representational so element newtypes can be coerced away.
type role Vector nominal representational

instance Eq a => Eq (Vector n a) where
  (Vector x) == (Vector y) = x == y

instance Show a => Show (Vector n a) where
  show (Vector x) = show x

-- | Get the elements of the vector as a list, lowest index first.
toList :: Vector n a -> [a]
toList (Vector v) = Vector.toList v
{-# Inline toList #-}

-- | Length of the vector.
-- @O(1)@
length :: Vector n a -> NatRepr n
length (Vector xs) =
  activateNatReprCoercionBackdoor_IPromiseIKnowWhatIAmDoing $ \mk ->
    mk (fromIntegral (Vector.length xs) :: Natural)
{-# INLINE length #-}

-- | The length of the vector as an 'Int'.
lengthInt :: Vector n a -> Int
lengthInt (Vector xs) = Vector.length xs
{-# Inline lengthInt #-}

-- | Get the element at the given index.
-- The constraint @(i + 1) <= n@ ensures statically that the index is
-- within the vector's bounds.
-- @O(1)@
elemAt :: ((i+1) <= n) => NatRepr i -> Vector n a -> a
elemAt n (Vector xs) = xs Vector.! widthVal n
{-# INLINE elemAt #-}

-- | Get the element at the given index.
-- Returns 'Nothing' if the index is outside the vector's bounds.
-- @O(1)@
elemAtMaybe :: Int -> Vector n a -> Maybe a
elemAtMaybe n (Vector xs) = xs Vector.!? n
{-# INLINE elemAtMaybe #-}

-- | Get the element at the given index.
-- Raises an exception if the element is not in the vector's domain.
-- @O(1)@
elemAtUnsafe :: Int -> Vector n a -> a
elemAtUnsafe n (Vector xs) = xs Vector.! n
{-# INLINE elemAtUnsafe #-}

-- | Replace the element at the given index with a new value.
-- Despite the name, nothing is inserted: the length of the vector is
-- unchanged (the result type is still @Vector n a@).
-- @O(n)@.
insertAt :: ((i + 1) <= n) => NatRepr i -> a -> Vector n a -> Vector n a
insertAt n a (Vector xs) = Vector (Vector.unsafeUpd xs [(widthVal n,a)])

-- | Replace the element at the given index with a new value.
-- Returns 'Nothing' if the index is outside the vector bounds.
-- Like 'insertAt', the length of the vector is unchanged.
-- @O(n)@.
insertAtMaybe :: Int -> a -> Vector n a -> Maybe (Vector n a)
insertAtMaybe n a (Vector xs)
  | 0 <= n && n < Vector.length xs = Just (Vector (Vector.unsafeUpd xs [(n,a)]))
  | otherwise = Nothing


-- | Proof that the length of this vector is not 0.
nonEmpty :: Vector n a -> LeqProof 1 n
nonEmpty (Vector _) = LeqProof
{-# Inline nonEmpty #-}


-- | Remove the first element of the vector, and return the rest, if any.
-- If the vector has exactly one element, the second component is instead
-- a proof that @n@ is 1.
uncons :: forall n a. Vector n a -> (a, Either (n :~: 1) (Vector (n-1) a))
uncons v@(Vector xs) = (Vector.head xs, mbTail)
  where
  -- 'Vector.head' is safe here: every 'Vector' is non-empty.
  mbTail :: Either (n :~: 1) (Vector (n - 1) a)
  mbTail = case testStrictLeq (knownNat @1) (length v) of
             Left n2_leq_n ->
               do LeqProof <- return (leqSub2 n2_leq_n (leqRefl (knownNat @1)))
                  return (Vector (Vector.tail xs))
             Right Refl -> Left Refl
{-# Inline uncons #-}

--------------------------------------------------------------------------------

-- | Make a vector of the given length and element type.
-- Returns 'Nothing' if the input list does not have the right number of
-- elements.
-- @O(n)@.
fromList :: (1 <= n) => NatRepr n -> [a] -> Maybe (Vector n a)
fromList n xs
  | widthVal n == Vector.length v = Just (Vector v)
  | otherwise                     = Nothing
  where
  v = Vector.fromList xs
{-# INLINE fromList #-}


-- | Extract a subvector of the given vector.
-- The constraints ensure the requested range lies within the vector and
-- is non-empty.
slice :: (i + w <= n, 1 <= w) =>
            NatRepr i {- ^ Start index -} ->
            NatRepr w {- ^ Width of sub-vector -} ->
            Vector n a -> Vector w a
slice i w (Vector xs) = Vector (Vector.slice (widthVal i) (widthVal w) xs)
{-# INLINE slice #-}

-- | Take the front (lower-indexes) part of the vector.
take :: forall n x a. (1 <= n) => NatRepr n -> Vector (n + x) a -> Vector n a
take | LeqProof <- prf = slice (knownNat @0)
  where
  prf = leqAdd (leqRefl (Proxy @n)) (Proxy @x)

--------------------------------------------------------------------------------

instance Functor (Vector n) where
  fmap f (Vector xs) = Vector (Vector.map f xs)
  {-# Inline fmap #-}

-- | Folds delegate to the underlying 'Vector.Vector'.
instance Foldable (Vector n) where
  foldMap f (Vector xs) = foldMap f xs
  -- Delegate 'foldr' as well, instead of using the default that goes
  -- through 'foldMap' with 'Data.Monoid.Endo'.
  foldr f z (Vector xs) = foldr f z xs

instance Traversable (Vector n) where
  traverse f (Vector xs) = Vector <$> traverse f xs
  {-# Inline traverse #-}

-- | Zip two vectors, potentially changing types.
-- @O(n)@
zipWith :: (a -> b -> c) -> Vector n a -> Vector n b -> Vector n c
zipWith combine (Vector lhs) (Vector rhs) = Vector $ Vector.zipWith combine lhs rhs
{-# Inline zipWith #-}

-- | Zip two vectors with a monadic function, collecting the results
-- into a vector of the same length.
zipWithM :: Monad m => (a -> b -> m c) ->
                       Vector n a -> Vector n b -> m (Vector n c)
zipWithM combine (Vector lhs) (Vector rhs) =
  fmap Vector (Vector.zipWithM combine lhs rhs)
{-# Inline zipWithM #-}

-- | Zip two vectors with a monadic function, discarding the results.
zipWithM_ :: Monad m => (a -> b -> m ()) -> Vector n a -> Vector n b -> m ()
zipWithM_ act (Vector lhs) (Vector rhs) = Vector.zipWithM_ act lhs rhs
{-# Inline zipWithM_ #-}

{- | Interleave two vectors.  The elements of the first vector are
at even indexes in the result, the elements of the second are at odd indexes. -}
interleave ::
  forall n a. (1 <= n) => Vector n a -> Vector n a -> Vector (2 * n) a
interleave (Vector evens) (Vector odds)
  | LeqProof <- leqMulPos (Proxy @2) (Proxy @n) =
      Vector (Vector.generate total pick)
  where
  total = Vector.length evens + Vector.length odds
  -- Result index @i@ comes from position @i `div` 2@ of one of the inputs.
  pick i
    | even i    = evens Vector.! half
    | otherwise = odds Vector.! half
    where half = i `div` 2

--------------------------------------------------------------------------------

{- | Move the elements around, as specified by the given function.

  * Note: the reindexing function says where each of the elements
          in the new vector come from: result index @i@ holds the input
          element at index @f i@.
  * Note: it is OK for the same input element to end up in multiple places
          in the result.
  * Note: an exception is raised if @f@ yields an index outside the vector.

@O(n)@
-}
shuffle :: (Int -> Int) -> Vector n a -> Vector n a
shuffle f (Vector xs) = Vector (Vector.generate size pickFrom)
  where
  size       = Vector.length xs
  pickFrom i = xs Vector.! f i
{-# Inline shuffle #-}

-- | Reverse the vector: the element at index @i@ moves to index
-- @length - 1 - i@.
reverse :: forall a n. (1 <= n) => Vector n a -> Vector n a
reverse v = shuffle mirror v
  where
  lastIx   = lengthInt v - 1
  mirror i = lastIx - i

-- | Rotate "left".  The first element of the vector is on the "left", so
-- rotate left moves all elements toward the corresponding smaller index.
-- Elements that fall off the beginning end up at the end.
rotateL :: Int -> Vector n a -> Vector n a
rotateL !n xs = shuffle rotL xs
  where
  !len   = lengthInt xs          -- `len` is known to be >= 1
  -- Reduce the rotation amount modulo the length first: @shift@ lies in
  -- @[0, len)@, so @i + shift@ cannot overflow even for a huge @n@.
  shift  = n `mod` len
  rotL i = (i + shift) `mod` len
{-# Inline rotateL #-}

-- | Rotate "right".  The first element of the vector is on the "left", so
-- rotate right moves all elements toward the corresponding larger index.
-- Elements that fall off the end, end up at the beginning.
-- Any 'Int' rotation amount is accepted; it is taken modulo the length.
rotateR :: Int -> Vector n a -> Vector n a
rotateR !n xs = shuffle rotR xs
  where
  !len   = lengthInt xs          -- `len` is known to be >= 1
  -- @shift@ lies in @[0, len)@, so @i - shift@ cannot overflow.
  shift  = n `mod` len
  rotR i = (i - shift) `mod` len
{-# Inline rotateR #-}

{- | Move all elements towards smaller indexes.
Elements that fall off the front are ignored.
Empty slots are filled in with the given element.
A negative shift amount moves elements towards larger indexes instead
(like 'shiftR' by the absolute amount).
@O(n)@. -}
shiftL :: Int -> a -> Vector n a -> Vector n a
shiftL !x a (Vector xs) = Vector ys
  where
  !len = Vector.length xs
  ys   = Vector.generate len pick
  -- Check both bounds: with a negative @x@ (or if @i + x@ overflows) the
  -- source index can be negative, which would make 'Vector.!' throw.
  pick i
    | j < 0 || j >= len = a
    | otherwise         = xs Vector.! j
    where j = i + x
{-# Inline shiftL #-}

{- | Move all elements towards the larger indexes.
Elements that "fall" off the end are ignored.
Empty slots are filled in with the given element.
A negative shift amount moves elements towards smaller indexes instead
(like 'shiftL' by the absolute amount).
@O(n)@. -}
shiftR :: Int -> a -> Vector n a -> Vector n a
shiftR !x a (Vector xs) = Vector ys
  where
  !len = Vector.length xs
  ys   = Vector.generate len pick
  -- Check both bounds: with a negative @x@ the source index @i - x@ can be
  -- past the end of the vector, which would make 'Vector.!' throw.
  pick i
    | j < 0 || j >= len = a
    | otherwise         = xs Vector.! j
    where j = i - x
{-# Inline shiftR #-}

--------------------------------------------------------------------------------

-- | Append two vectors. The first one is at lower indexes in the result.
append :: Vector m a -> Vector n a -> Vector (m + n) a
append v1@(Vector xs) v2@(Vector ys) =
  case leqAddPos (length v1) (length v2) of { LeqProof ->
    Vector (xs Vector.++ ys)
  }
{-# Inline append #-}

--------------------------------------------------------------------------------
-- Constructing Vectors

-- | Vector with exactly one element
singleton :: forall a. a -> Vector 1 a
singleton a = Vector (Vector.singleton a)

-- | Proof that @n + 1@ is at least 1, obtained from the given vector's
-- non-emptiness (@1 <= n@) and @n <= n + 1@.
leqLen :: forall n a. Vector n a -> LeqProof 1 (n + 1)
leqLen v =
  let leqSucc :: forall f z. f z -> LeqProof z (z + 1)
      leqSucc fz = leqAdd (leqRefl fz :: LeqProof z z) (knownNat @1)
  in leqTrans (nonEmpty v :: LeqProof 1 n) (leqSucc (length v))

-- | Add an element to the head of a vector
cons :: forall n a. a -> Vector n a -> Vector (n+1) a
cons a v@(Vector x) = case leqLen v of LeqProof -> (Vector (Vector.cons a x))

-- | Add an element to the tail of a vector
snoc :: forall n a. Vector n a -> a -> Vector (n+1) a
snoc v@(Vector x) a = case leqLen v of LeqProof -> (Vector (Vector.snoc x a))

-- | This newtype wraps Vector so that we can curry it in the call to
-- @natRecBounded@. It adds 1 to the length so that the base case is
-- a @Vector@ of non-zero length.
newtype Vector' a n = MkVector' (Vector (n+1) a)

-- | Unwrap a 'Vector''.
unVector' :: Vector' a n -> Vector (n+1) a
unVector' (MkVector' v) = v

-- | 'snoc' lifted to 'Vector''.
snoc' :: forall a m. Vector' a m -> a -> Vector' a (m+1)
snoc' v = MkVector' . snoc (unVector' v)

-- | Build the vector for 'generate' by recursion on the upper index @h@,
-- appending @gen m@ for each @m@ from 1 up to @h@ to the singleton base case.
generate' :: forall h a
           . NatRepr h
          -> (forall n. (n <= h) => NatRepr n -> a)
          -> Vector' a h
generate' h gen =
  case isZeroOrGT1 h of
    Left Refl -> base
    Right LeqProof ->
      case (minusPlusCancel h (knownNat @1) :: h - 1 + 1 :~: h) of { Refl ->
      natRecBounded (decNat h) (decNat h) base step
      }
  where base :: Vector' a 0
        base = MkVector' $ singleton (gen (knownNat @0))
        step :: forall m. (1 <= h, m <= h - 1)
             => NatRepr m -> Vector' a m -> Vector' a (m + 1)
        step m v =
          case minusPlusCancel h (knownNat @1) :: h - 1 + 1 :~: h of { Refl ->
          case (leqAdd2 (LeqProof :: LeqProof m (h-1))
                        (LeqProof :: LeqProof 1 1) :: LeqProof (m+1) h) of { LeqProof ->
            snoc' v (gen (incNat m))
          }}

-- | Apply a function to each index from 0 to @h@ (inclusive) and return
-- the vector of results, which therefore has length @h + 1@.
-- cf. both @natFromZero@ and @Data.Vector.generate@
generate :: forall h a
          . NatRepr h
         -> (forall n. (n <= h) => NatRepr n -> a)
         -> Vector (h + 1) a
generate h gen = unVector' (generate' h gen)

-- | Monadic version of 'generate': build a vector of actions with
-- 'generate', then run them with 'sequence' (using the 'Traversable'
-- instance of 'Vector').
generateM :: forall m h a. (Monad m)
          => NatRepr h
          -> (forall n. (n <= h) => NatRepr n -> m a)
          -> m (Vector (h + 1) a)
generateM h gen = sequence $ generate h gen

--------------------------------------------------------------------------------

-- | Coerce the elements of a vector (e.g. to add or remove a newtype
-- wrapper) without changing its length index.
coerceVec :: Coercible a b => Vector n a -> Vector n b
coerceVec = coerce

-- | Monadically join a vector of values, using the given function.
--
-- This functionality can sometimes be reproduced by creating a newtype
-- wrapper and using @joinWith@, this implementation is provided for
-- convenience.
joinWithM ::
  forall m f n w.
  (1 <= w, Monad m) =>
  (forall l. (1 <= l) => NatRepr l -> f w -> f l -> m (f (w + l)))
  {- ^ A function for appending contained elements.  Earlier vector indexes
       are the first argument of the join function.  Pass a different
       function to implement little/big endian behaviors -}
  -> NatRepr w
  -> Vector n (f w)
  -> m (f (n * w))

joinWithM jn w = fmap fst . go
  where
  -- Join the tail first, then combine the head with the result; also
  -- returns the width of the joined value.
  go :: forall l. Vector l (f w) -> m (f (l * w), NatRepr (l * w))
  go exprs =
    case uncons exprs of
      (a, Left Refl) -> return (a, w)
      (a, Right rest) ->
        case nonEmpty rest                of { LeqProof ->
        case leqMulPos (length rest) w    of { LeqProof ->
        case nonEmpty exprs               of { LeqProof ->
        case lemmaMul w (length exprs)    of { Refl -> do
          -- @siddharthist: This could probably be written applicatively?
          (res, sz) <- go rest
          joined <- jn sz a res
          return (joined, addNat w sz)
        }}}}

-- | Join a vector of values, using the given function.
joinWith ::
  forall f n w.
  (1 <= w) =>
  (forall l. (1 <= l) => NatRepr l -> f w -> f l -> f (w + l))
  {- ^ A function for appending contained elements.  Earlier vector indexes
       are the first argument of the join function.  Pass a different
       function to implement little/big endian behaviors -}
  -> NatRepr w
  -> Vector n (f w)
  -> f (n * w)
joinWith jn w v = runIdentity $ joinWithM (\n x -> pure . (jn n x)) w v
{-# Inline joinWith #-}

-- | Split a bit-vector into a vector of bit-vectors.
-- If 'LittleEndian', then less significant bits go into smaller indexes.
-- If 'BigEndian', then less significant bits go into larger indexes.
--
-- More precisely, the chunk starting at offset @0@ is placed at index @0@
-- for 'LittleEndian' and at index @n - 1@ for 'BigEndian'; each following
-- chunk of width @w@ goes to the next index in that direction.
splitWith :: forall f w n.
  (1 <= w, 1 <= n) =>
  Endian ->
  (forall i. (i + w <= n * w) =>
             NatRepr (n * w) -> NatRepr i -> f (n * w) -> f w)
  {- ^ A function for slicing out a chunk of length @w@, starting at @i@ -} ->
  NatRepr n -> NatRepr w -> f (n * w) -> Vector n (f w)
splitWith endian select n w val = Vector (Vector.create initializer)
  where
  len = widthVal n
  start :: Int
  next :: Int -> Int
  (start,next) = case endian of
                   LittleEndian -> (0, succ)
                   BigEndian    -> (len - 1, pred)

  initializer :: forall s. ST s (MVector s (f w))
  initializer =
    do LeqProof <- return (leqMulPos n w)
       LeqProof <- return (leqMulMono n w)

       v <- MVector.new len
       -- Write chunk @[i, i + w)@ to slot @loc@ while it fits in the input;
       -- exactly @n@ chunks fit, so every slot of @v@ is written.
       let fill :: Int -> NatRepr i -> ST s ()
           fill loc i =
             let end = addNat i w in
             case testLeq end inLen of
               Just LeqProof ->
                 do MVector.write v loc (select inLen i val)
                    fill (next loc) end
               Nothing -> return ()

       fill start (knownNat @0)
       return v

  inLen :: NatRepr (n * w)
  inLen = natMultiply n w
{-# Inline splitWith #-}

-- We can sneakily put our functor in the parameter "f" of @splitWith@ using the
-- @Compose@ newtype.
-- | An applicative version of @splitWith@.
splitWithA :: forall f g w n. (Applicative f, 1 <= w, 1 <= n) =>
  Endian ->
  (forall i. (i + w <= n * w) =>
             NatRepr (n * w) -> NatRepr i -> g (n * w) -> f (g w))
  {- ^ A function for slicing out a chunk of length @w@, starting at @i@ -} ->
  NatRepr n -> NatRepr w -> g (n * w) -> f (Vector n (g w))
splitWithA e select n w val = traverse getCompose $
  splitWith @(Compose f g) e select' n w $ Compose (pure val)
  where -- Wrap everything in Compose
        select' :: (forall i. (i + w <= n * w)
                   => NatRepr (n * w) -> NatRepr i -> Compose f g (n * w) -> Compose f g w)
        -- Whatever we pass in as "val" is what's passed to select anyway,
        -- so there's no need to examine the argument. Just use "val" directly here.
        select' nw i _ = Compose $ select nw i val

-- | A 'Vector' with its type arguments swapped, so that the length is the
-- last parameter (the shape expected by 'splitWith' and 'joinWith').
newtype Vec a n = Vec (Vector n a)

-- | Slice @w@ elements starting at index @i@, in the argument order used by
-- 'splitWith'.  The total length argument is ignored.
vSlice :: (i + w <= l, 1 <= w) =>
  NatRepr w -> NatRepr l -> NatRepr i -> Vec a l -> Vec a w
vSlice w _ i (Vec xs) = Vec (slice i w xs)
{-# Inline vSlice #-}

-- | Append two vectors.  The first argument is
-- at the lower indexes of the resulting vector.
-- The length argument is ignored; it is there to match 'joinWith'.
vAppend :: NatRepr n -> Vec a m -> Vec a n -> Vec a (m + n)
vAppend _ (Vec xs) (Vec ys) = Vec (append xs ys)
{-# Inline vAppend #-}

-- | Split a vector into a vector of vectors.
-- The first @w@ elements form the sub-vector at index 0, and so on.
split :: (1 <= w, 1 <= n) =>
         NatRepr n -> NatRepr w -> Vector (n * w) a -> Vector n (Vector w a)
split n w xs = coerceVec (splitWith LittleEndian (vSlice w) n w (Vec xs))
{-# Inline split #-}

-- | Join a vector of vectors into a single vector.
-- The sub-vector at index 0 ends up at the lowest indexes of the result.
join :: (1 <= w) => NatRepr w -> Vector n (Vector w a) -> Vector (n * w) a
join w xs = ys
  where Vec ys = joinWith vAppend w (coerceVec xs)
{-# Inline join #-}