{-# Language GADTs, DataKinds, TypeOperators, BangPatterns #-} {-# Language PatternGuards #-} {-# Language TypeApplications, ScopedTypeVariables #-} {-# Language Rank2Types, RoleAnnotations #-} {-# Language CPP #-} #if __GLASGOW_HASKELL__ >= 805 {-# Language NoStarIsType #-} #endif {-| Copyright : (c) Galois, Inc 2014-2019 A fixed-size vector of typed elements. NB: This module contains an orphan instance. It will be included in GHC 8.10, see https://gitlab.haskell.org/ghc/ghc/merge_requests/273. -} module Data.Parameterized.Vector ( Vector -- * Lists , fromList , toList -- * Length , length , nonEmpty , lengthInt -- * Indexing , elemAt , elemAtMaybe , elemAtUnsafe -- * Update , insertAt , insertAtMaybe -- * Sub sequences , uncons , slice , Data.Parameterized.Vector.take , replace , mapAt , mapAtM -- * Zipping , zipWith , zipWithM , zipWithM_ , interleave -- * Reorder , shuffle , reverse , rotateL , rotateR , shiftL , shiftR -- * Construction , singleton , cons , snoc , generate , generateM -- * Splitting and joining -- ** General , joinWithM , joinWith , splitWith , splitWithA -- ** Vectors , split , join , append ) where import qualified Data.Vector as Vector import Data.Functor.Compose import Data.Coerce import Data.Vector.Mutable (MVector) import qualified Data.Vector.Mutable as MVector import Control.Monad.ST import Data.Functor.Identity import Data.Parameterized.NatRepr import Data.Parameterized.NatRepr.Internal import Data.Proxy import Prelude hiding (length,reverse,zipWith) import Numeric.Natural import Data.Parameterized.Utils.Endian -- | Fixed-size non-empty vectors. data Vector n a where Vector :: (1 <= n) => !(Vector.Vector a) -> Vector n a type role Vector nominal representational instance Eq a => Eq (Vector n a) where (Vector x) == (Vector y) = x == y instance Show a => Show (Vector n a) where show (Vector x) = show x -- | Get the elements of the vector as a list, lowest index first. toList :: Vector n a -> [a] toList (Vector v) = Vector.toList v {-# Inline toList #-} -- NOTE: We are using the raw 'NatRepr' constructor here, which is unsafe. -- | Length of the vector. -- @O(1)@ length :: Vector n a -> NatRepr n length (Vector xs) = NatRepr (fromIntegral (Vector.length xs) :: Natural) {-# INLINE length #-} -- | The length of the vector as an "Int". lengthInt :: Vector n a -> Int lengthInt (Vector xs) = Vector.length xs {-# Inline lengthInt #-} elemAt :: ((i+1) <= n) => NatRepr i -> Vector n a -> a elemAt n (Vector xs) = xs Vector.! widthVal n -- | Get the element at the given index. -- @O(1)@ elemAtMaybe :: Int -> Vector n a -> Maybe a elemAtMaybe n (Vector xs) = xs Vector.!? n {-# INLINE elemAt #-} -- | Get the element at the given index. -- Raises an exception if the element is not in the vector's domain. -- @O(1)@ elemAtUnsafe :: Int -> Vector n a -> a elemAtUnsafe n (Vector xs) = xs Vector.! n {-# INLINE elemAtUnsafe #-} -- | Insert an element at the given index. -- @O(n)@. insertAt :: ((i + 1) <= n) => NatRepr i -> a -> Vector n a -> Vector n a insertAt n a (Vector xs) = Vector (Vector.unsafeUpd xs [(widthVal n,a)]) -- | Insert an element at the given index. -- Return 'Nothing' if the element is outside the vector bounds. -- @O(n)@. insertAtMaybe :: Int -> a -> Vector n a -> Maybe (Vector n a) insertAtMaybe n a (Vector xs) | 0 <= n && n < Vector.length xs = Just (Vector (Vector.unsafeUpd xs [(n,a)])) | otherwise = Nothing -- | Proof that the length of this vector is not 0. nonEmpty :: Vector n a -> LeqProof 1 n nonEmpty (Vector _) = LeqProof {-# Inline nonEmpty #-} -- | Remove the first element of the vector, and return the rest, if any. uncons :: forall n a. Vector n a -> (a, Either (n :~: 1) (Vector (n-1) a)) uncons v@(Vector xs) = (Vector.head xs, mbTail) where mbTail :: Either (n :~: 1) (Vector (n - 1) a) mbTail = case testStrictLeq (knownNat @1) (length v) of Left n2_leq_n -> do LeqProof <- return (leqSub2 n2_leq_n (leqRefl (knownNat @1))) return (Vector (Vector.tail xs)) Right Refl -> Left Refl {-# Inline uncons #-} -------------------------------------------------------------------------------- -- | Make a vector of the given length and element type. -- Returns "Nothing" if the input list does not have the right number of -- elements. -- @O(n)@. fromList :: (1 <= n) => NatRepr n -> [a] -> Maybe (Vector n a) fromList n xs | widthVal n == Vector.length v = Just (Vector v) | otherwise = Nothing where v = Vector.fromList xs {-# INLINE fromList #-} -- | Extract a subvector of the given vector. slice :: (i + w <= n, 1 <= w) => NatRepr i {- ^ Start index -} -> NatRepr w {- ^ Width of sub-vector -} -> Vector n a -> Vector w a slice i w (Vector xs) = Vector (Vector.slice (widthVal i) (widthVal w) xs) {-# INLINE slice #-} -- | Take the front (lower-indexes) part of the vector. take :: forall n x a. (1 <= n) => NatRepr n -> Vector (n + x) a -> Vector n a take | LeqProof <- prf = slice (knownNat @0) where prf = leqAdd (leqRefl (Proxy @n)) (Proxy @x) -- | Scope a monadic function to a sub-section of the given vector. mapAtM :: Monad m => (i + w <= n, 1 <= w) => NatRepr i {- ^ Start index -} -> NatRepr w {- ^ Section width -} -> (Vector w a -> m (Vector w a)) {-^ map for the sub-vector -} -> Vector n a -> m (Vector n a) mapAtM i w f (Vector vn) = let (vhead, vtail) = Vector.splitAt (widthVal i) vn (vsect, vend) = Vector.splitAt (widthVal w) vtail in do Vector vsect' <- f (Vector vsect) return $ Vector $ vhead Vector.++ vsect' Vector.++ vend -- | Scope a function to a sub-section of the given vector. mapAt :: (i + w <= n, 1 <= w) => NatRepr i {- ^ Start index -} -> NatRepr w {- ^ Section width -} -> (Vector w a -> Vector w a) {-^ map for the sub-vector -} -> Vector n a -> Vector n a mapAt i w f vn = runIdentity $ mapAtM i w (pure . f) vn -- | Replace a sub-section of a vector with the given sub-vector. replace :: (i + w <= n, 1 <= w) => NatRepr i {- ^ Start index -} -> Vector w a {- ^ sub-vector -} -> Vector n a -> Vector n a replace i vw vn = mapAt i (length vw) (const vw) vn -------------------------------------------------------------------------------- instance Functor (Vector n) where fmap f (Vector xs) = Vector (Vector.map f xs) {-# Inline fmap #-} instance Foldable (Vector n) where foldMap f (Vector xs) = foldMap f xs instance Traversable (Vector n) where traverse f (Vector xs) = Vector <$> traverse f xs {-# Inline traverse #-} -- | Zip two vectors, potentially changing types. -- @O(n)@ zipWith :: (a -> b -> c) -> Vector n a -> Vector n b -> Vector n c zipWith f (Vector xs) (Vector ys) = Vector (Vector.zipWith f xs ys) {-# Inline zipWith #-} zipWithM :: Monad m => (a -> b -> m c) -> Vector n a -> Vector n b -> m (Vector n c) zipWithM f (Vector xs) (Vector ys) = Vector <$> Vector.zipWithM f xs ys {-# Inline zipWithM #-} zipWithM_ :: Monad m => (a -> b -> m ()) -> Vector n a -> Vector n b -> m () zipWithM_ f (Vector xs) (Vector ys) = Vector.zipWithM_ f xs ys {-# Inline zipWithM_ #-} {- | Interleave two vectors. The elements of the first vector are at even indexes in the result, the elements of the second are at odd indexes. -} interleave :: forall n a. (1 <= n) => Vector n a -> Vector n a -> Vector (2 * n) a interleave (Vector xs) (Vector ys) | LeqProof <- leqMulPos (Proxy @2) (Proxy @n) = Vector zs where len = Vector.length xs + Vector.length ys zs = Vector.generate len (\i -> let v = if even i then xs else ys in v Vector.! (i `div` 2)) -------------------------------------------------------------------------------- {- | Move the elements around, as specified by the given function. * Note: the reindexing function says where each of the elements in the new vector come from. * Note: it is OK for the same input element to end up in mulitple places in the result. @O(n)@ -} shuffle :: (Int -> Int) -> Vector n a -> Vector n a shuffle f (Vector xs) = Vector ys where ys = Vector.generate (Vector.length xs) (\i -> xs Vector.! f i) {-# Inline shuffle #-} -- | Reverse the vector. reverse :: forall a n. (1 <= n) => Vector n a -> Vector n a reverse x = shuffle (\i -> lengthInt x - i - 1) x -- | Rotate "left". The first element of the vector is on the "left", so -- rotate left moves all elemnts toward the corresponding smaller index. -- Elements that fall off the beginning end up at the end. rotateL :: Int -> Vector n a -> Vector n a rotateL !n xs = shuffle rotL xs where !len = lengthInt xs rotL i = (i + n) `mod` len -- `len` is known to be >= 1 {-# Inline rotateL #-} -- | Rotate "right". The first element of the vector is on the "left", so -- rotate right moves all elemnts toward the corresponding larger index. -- Elements that fall off the end, end up at the beginning. rotateR :: Int -> Vector n a -> Vector n a rotateR !n xs = shuffle rotR xs where !len = lengthInt xs rotR i = (i - n) `mod` len -- `len` is known to be >= 1 {-# Inline rotateR #-} {- | Move all elements towards smaller indexes. Elements that fall off the front are ignored. Empty slots are filled in with the given element. @O(n)@. -} shiftL :: Int -> a -> Vector n a -> Vector n a shiftL !x a (Vector xs) = Vector ys where !len = Vector.length xs ys = Vector.generate len (\i -> let j = i + x in if j >= len then a else xs Vector.! j) {-# Inline shiftL #-} {- | Move all elements towards the larger indexes. Elements that "fall" off the end are ignored. Empty slots are filled in with the given element. @O(n)@. -} shiftR :: Int -> a -> Vector n a -> Vector n a shiftR !x a (Vector xs) = Vector ys where !len = Vector.length xs ys = Vector.generate len (\i -> let j = i - x in if j < 0 then a else xs Vector.! j) {-# Inline shiftR #-} -------------------------------------------------------------------------------i -- | Append two vectors. The first one is at lower indexes in the result. append :: Vector m a -> Vector n a -> Vector (m + n) a append v1@(Vector xs) v2@(Vector ys) = case leqAddPos (length v1) (length v2) of { LeqProof -> Vector (xs Vector.++ ys) } {-# Inline append #-} -------------------------------------------------------------------------------- -- Constructing Vectors -- | Vector with exactly one element singleton :: forall a. a -> Vector 1 a singleton a = Vector (Vector.singleton a) leqLen :: forall n a. Vector n a -> LeqProof 1 (n + 1) leqLen v = let leqSucc :: forall f z. f z -> LeqProof z (z + 1) leqSucc fz = leqAdd (leqRefl fz :: LeqProof z z) (knownNat @1) in leqTrans (nonEmpty v :: LeqProof 1 n) (leqSucc (length v)) -- | Add an element to the head of a vector cons :: forall n a. a -> Vector n a -> Vector (n+1) a cons a v@(Vector x) = case leqLen v of LeqProof -> (Vector (Vector.cons a x)) -- | Add an element to the tail of a vector snoc :: forall n a. Vector n a -> a -> Vector (n+1) a snoc v@(Vector x) a = case leqLen v of LeqProof -> (Vector (Vector.snoc x a)) -- | This newtype wraps Vector so that we can curry it in the call to -- @natRecBounded@. It adds 1 to the length so that the base case is -- a @Vector@ of non-zero length. newtype Vector' a n = MkVector' (Vector (n+1) a) unVector' :: Vector' a n -> Vector (n+1) a unVector' (MkVector' v) = v snoc' :: forall a m. Vector' a m -> a -> Vector' a (m+1) snoc' v = MkVector' . snoc (unVector' v) generate' :: forall h a . NatRepr h -> (forall n. (n <= h) => NatRepr n -> a) -> Vector' a h generate' h gen = case isZeroOrGT1 h of Left Refl -> base Right LeqProof -> case (minusPlusCancel h (knownNat @1) :: h - 1 + 1 :~: h) of { Refl -> natRecBounded (decNat h) (decNat h) base step } where base :: Vector' a 0 base = MkVector' $ singleton (gen (knownNat @0)) step :: forall m. (1 <= h, m <= h - 1) => NatRepr m -> Vector' a m -> Vector' a (m + 1) step m v = case minusPlusCancel h (knownNat @1) :: h - 1 + 1 :~: h of { Refl -> case (leqAdd2 (LeqProof :: LeqProof m (h-1)) (LeqProof :: LeqProof 1 1) :: LeqProof (m+1) h) of { LeqProof -> snoc' v (gen (incNat m)) }} -- | Apply a function to each element in a range starting at zero; -- return the a vector of values obtained. -- cf. both @natFromZero@ and @Data.Vector.generate@ generate :: forall h a . NatRepr h -> (forall n. (n <= h) => NatRepr n -> a) -> Vector (h + 1) a generate h gen = unVector' (generate' h gen) -- | Since @Vector@ is traversable, we can pretty trivially sequence -- @natFromZeroVec@ inside a monad. generateM :: forall m h a. (Monad m) => NatRepr h -> (forall n. (n <= h) => NatRepr n -> m a) -> m (Vector (h + 1) a) generateM h gen = sequence $ generate h gen -------------------------------------------------------------------------------- coerceVec :: Coercible a b => Vector n a -> Vector n b coerceVec = coerce -- | Monadically join a vector of values, using the given function. -- This functionality can sometimes be reproduced by creating a newtype -- wrapper and using @joinWith@, this implementation is provided for -- convenience. joinWithM :: forall m f n w. (1 <= w, Monad m) => (forall l. (1 <= l) => NatRepr l -> f w -> f l -> m (f (w + l))) {- ^ A function for joining contained elements. The first argument is the size of the accumulated third term, and the second argument is the element to join to the accumulated term. The function can use any join strategy desired (prepending/"BigEndian", appending/"LittleEndian", etc.). -} -> NatRepr w -> Vector n (f w) -> m (f (n * w)) joinWithM jn w = fmap fst . go where go :: forall l. Vector l (f w) -> m (f (l * w), NatRepr (l * w)) go exprs = case uncons exprs of (a, Left Refl) -> return (a, w) (a, Right rest) -> case nonEmpty rest of { LeqProof -> case leqMulPos (length rest) w of { LeqProof -> case nonEmpty exprs of { LeqProof -> case lemmaMul w (length exprs) of { Refl -> do -- @siddharthist: This could probably be written applicatively? (res, sz) <- go rest joined <- jn sz a res return (joined, addNat w sz) }}}} -- | Join a vector of vectors, using the given function to combine the -- sub-vectors. joinWith :: forall f n w. (1 <= w) => (forall l. (1 <= l) => NatRepr l -> f w -> f l -> f (w + l)) {- ^ A function for joining contained elements. The first argument is the size of the accumulated third term, and the second argument is the element to join to the accumulated term. The function can use any join strategy desired (prepending/"BigEndian", appending/"LittleEndian", etc.). -} -> NatRepr w -> Vector n (f w) -> f (n * w) joinWith jn w v = runIdentity $ joinWithM (\n x -> pure . (jn n x)) w v {-# Inline joinWith #-} -- | Split a vector into a vector of vectors. -- -- The "Endian" parameter determines the ordering of the inner -- vectors. If "LittleEndian", then less significant bits go into -- smaller indexes. If "BigEndian", then less significant bits go -- into larger indexes. See the documentation for 'split' for more -- details. splitWith :: forall f w n. (1 <= w, 1 <= n) => Endian -> (forall i. (i + w <= n * w) => NatRepr (n * w) -> NatRepr i -> f (n * w) -> f w) {- ^ A function for slicing out a chunk of length @w@, starting at @i@ -} -> NatRepr n -> NatRepr w -> f (n * w) -> Vector n (f w) splitWith endian select n w val = Vector (Vector.create initializer) where len = widthVal n start :: Int next :: Int -> Int (start,next) = case endian of LittleEndian -> (0, succ) BigEndian -> (len - 1, pred) initializer :: forall s. ST s (MVector s (f w)) initializer = do LeqProof <- return (leqMulPos n w) LeqProof <- return (leqMulMono n w) v <- MVector.new len let fill :: Int -> NatRepr i -> ST s () fill loc i = let end = addNat i w in case testLeq end inLen of Just LeqProof -> do MVector.write v loc (select inLen i val) fill (next loc) end Nothing -> return () fill start (knownNat @0) return v inLen :: NatRepr (n * w) inLen = natMultiply n w {-# Inline splitWith #-} -- We can sneakily put our functor in the parameter "f" of @splitWith@ using the -- @Compose@ newtype. -- | An applicative version of @splitWith@. splitWithA :: forall f g w n. (Applicative f, 1 <= w, 1 <= n) => Endian -> (forall i. (i + w <= n * w) => NatRepr (n * w) -> NatRepr i -> g (n * w) -> f (g w)) {- ^ f function for slicing out f chunk of length @w@, starting at @i@ -} -> NatRepr n -> NatRepr w -> g (n * w) -> f (Vector n (g w)) splitWithA e select n w val = traverse getCompose $ splitWith @(Compose f g) e select' n w $ Compose (pure val) where -- Wrap everything in Compose select' :: (forall i. (i + w <= n * w) => NatRepr (n * w) -> NatRepr i -> Compose f g (n * w) -> Compose f g w) -- Whatever we pass in as "val" is what's passed to select anyway, -- so there's no need to examine the argument. Just use "val" directly here. select' nw i _ = Compose $ select nw i val newtype Vec a n = Vec (Vector n a) vSlice :: (i + w <= l, 1 <= w) => NatRepr w -> NatRepr l -> NatRepr i -> Vec a l -> Vec a w vSlice w _ i (Vec xs) = Vec (slice i w xs) {-# Inline vSlice #-} -- | Append the two bit vectors. The first argument is -- at the lower indexes of the resulting vector. vAppend :: NatRepr n -> Vec a m -> Vec a n -> Vec a (m + n) vAppend _ (Vec xs) (Vec ys) = Vec (append xs ys) {-# Inline vAppend #-} -- | Split a vector into a vector of vectors. The default ordering of -- the outer result vector is "LittleEndian". -- -- For example: -- @ -- let wordsize = knownNat :: NatRepr 3 -- vecsize = knownNat :: NatRepr 12 -- numwords = knownNat :: NatRepr 4 (12 / 3) -- Just inpvec = fromList vecsize [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 ] -- in show (split numwords wordsize inpvec) == "[ [1,2,3], [4,5,6], [7,8,9], [10,11,12] ]" -- @ -- whereas a BigEndian result would have been -- @ -- [ [10,11,12], [7,8,9], [4,5,6], [1,2,3] ] -- @ split :: (1 <= w, 1 <= n) => NatRepr n -- ^ Inner vector size -> NatRepr w -- ^ Outer vector size -> Vector (n * w) a -- ^ Input vector -> Vector n (Vector w a) split n w xs = coerceVec (splitWith LittleEndian (vSlice w) n w (Vec xs)) {-# Inline split #-} -- | Join a vector of vectors into a single vector. Assumes an -- append/"LittleEndian" join strategy: the order of the inner vectors -- is preserved in the result vector. -- -- @ -- let innersize = knownNat :: NatRepr 4 -- Just inner1 = fromList innersize [ 1, 2, 3, 4 ] -- Just inner2 = fromList innersize [ 5, 6, 7, 8 ] -- Just inner3 = fromList innersize [ 9, 10, 11, 12 ] -- outersize = knownNat :: NatRepr 3 -- Just outer = fromList outersize [ inner1, inner2, inner3 ] -- in show (join innersize outer) = [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 ] -- @ -- a prepend/"BigEndian" join strategy would have the result: -- @ -- [ 9, 10, 11, 12, 5, 6, 7, 8, 1, 2, 3, 4 ] -- @ join :: (1 <= w) => NatRepr w -> Vector n (Vector w a) -> Vector (n * w) a join w xs = ys where Vec ys = joinWith vAppend w (coerceVec xs) {-# Inline join #-}