{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE Trustworthy #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
{-# OPTIONS_HADDOCK show-extensions #-}
module CLaSH.Sized.Vector
(
Vec(..), (<:), singleton
, head, tail, last, init
, take, takeI, drop, dropI, at, select, selectI
, (++), (+>>), (<<+), concat, zip, unzip, zip3, unzip3, shiftInAt0, shiftInAtN
, shiftOutFrom0, shiftOutFromN
, splitAt, splitAtI, unconcat, unconcatI, merge
, map, zipWith, zipWith3
, foldr, foldl, foldr1, foldl1, fold
, scanl, scanr, sscanl, sscanr
, mapAccumL, mapAccumR
, dfold, vfold
, (!!), replace, maxIndex, length
, replicate, repeat, iterate, iterateI, generate, generateI
, reverse, toList, v, lazyV, asNatProxy
, concatBitVector#
, unconcatBitVector#
)
where
import Control.Lens (Index, Ixed (..), IxValue)
import Data.Default (Default (..))
import qualified Data.Foldable as F
import Data.Proxy (Proxy (..))
import Data.Singletons.Prelude (TyFun,Apply,type ($))
import GHC.TypeLits (CmpNat, KnownNat, Nat, type (+), type (*),
natVal)
import GHC.Base (Int(I#),Int#,isTrue#)
import GHC.Prim ((==#),(<#),(-#))
import Language.Haskell.TH (ExpQ)
import Language.Haskell.TH.Syntax (Lift(..))
import Prelude hiding ((++), (!!), concat, drop, foldl,
foldl1, foldr, foldr1, head, init,
iterate, last, length, map, repeat,
replicate, reverse, scanl, scanr,
splitAt, tail, take, unzip, unzip3,
zip, zip3, zipWith, zipWith3)
import qualified Prelude as P
import Test.QuickCheck (Arbitrary (..), CoArbitrary (..))
import Unsafe.Coerce (unsafeCoerce)
import CLaSH.Promoted.Nat (SNat (..), UNat (..), withSNat, toUNat)
import CLaSH.Sized.Internal.BitVector (BitVector, (++#), split#)
import CLaSH.Class.BitPack (BitPack (..))
-- | Fixed-size vector whose length is tracked in its type as a type-level
-- natural number.
data Vec :: Nat -> * -> * where
-- The empty vector; its length index is 0.
Nil :: Vec 0 a
-- Prepend one element; the length index of the result is one larger than
-- that of the tail.
(:>) :: a -> Vec n a -> Vec (n + 1) a
-- ':>' is right-associative, so @1 :> 2 :> Nil@ reads as @1 :> (2 :> Nil)@.
infixr 5 :>
-- | Renders a vector as its elements separated by commas and enclosed in
-- angle brackets, e.g. @<1,2,3>@; the empty vector renders as @<>@.
instance Show a => Show (Vec n a) where
  show vec = P.concat ["<", commaSep vec, ">"]
    where
      -- Elements joined by commas, without a trailing separator.
      commaSep :: Vec m a -> String
      commaSep Nil         = ""
      commaSep (e :> Nil)  = show e
      commaSep (e :> rest) = show e P.++ (',' : commaSep rest)
-- | Element-wise equality, delegated to the 'eq#' and 'neq#' primitives.
instance Eq a => Eq (Vec n a) where
  xs == ys = eq# xs ys
  xs /= ys = neq# xs ys

{-# NOINLINE eq# #-}
-- | 'True' when every pair of corresponding elements is equal.
eq# :: Eq a => Vec n a -> Vec n a -> Bool
eq# lhs rhs = foldr (&&) True pairwise
  where
    pairwise = zipWith (==) lhs rhs

{-# NOINLINE neq# #-}
-- | Negation of 'eq#'.
neq# :: Eq a => Vec n a -> Vec n a -> Bool
neq# lhs rhs = not $ eq# lhs rhs
-- | Lexicographic ordering: the result is the first element-wise comparison
-- that is not 'EQ', or 'EQ' when all elements compare equal.
instance Ord a => Ord (Vec n a) where
  compare xs ys = foldr firstNonEq EQ (zipWith compare xs ys)
    where
      firstNonEq ord rest = case ord of
        EQ -> rest
        _  -> ord
-- | \"Zippy\" applicative: 'pure' fills the vector with one value and '<*>'
-- applies functions to arguments position by position.
instance KnownNat n => Applicative (Vec n) where
  pure x    = repeat x
  fs <*> xs = zipWith (\f x -> f x) fs xs
-- | 'F.foldr' is implemented by this module's own 'foldr' for 'Vec'.
instance F.Foldable (Vec n) where
  -- The right-hand side is the module-level 'foldr' (the Prelude one is
  -- hidden and the class method is only in scope qualified), so this is not
  -- a self-recursive definition.
  foldr step z vec = foldr step z vec
-- | 'fmap' is this module's 'map'.
instance Functor (Vec n) where
  fmap f xs = map f xs
-- | Traversal from the first to the last element, delegated to 'traverse#'.
instance Traversable (Vec n) where
  traverse f xs = traverse# f xs

{-# NOINLINE traverse# #-}
-- | Apply an effectful function to every element, front to back, and collect
-- the results in a vector of the same length.
traverse# :: Applicative f => (a -> f b) -> Vec n a -> f (Vec n b)
traverse# f vec = case vec of
  Nil       -> pure Nil
  (x :> xs) -> fmap (:>) (f x) <*> traverse# f xs
-- | The default vector holds the element type's default value at every
-- position.
instance (Default a, KnownNat n) => Default (Vec n a) where
  def = pure def
{-# INLINE singleton #-}
-- | Create a vector of exactly one element.
singleton :: a -> Vec 1 a
singleton x = x :> Nil
{-# NOINLINE head #-}
-- | Extract the first element of a vector; the type guarantees the vector is
-- non-empty.
head :: Vec (n + 1) a -> a
head vec = case vec of
  (x :> _) -> x
{-# NOINLINE tail #-}
-- | Everything after the first element of a non-empty vector.
tail :: Vec (n + 1) a -> Vec n a
tail vec = case vec of
  (_ :> rest) -> rest
{-# NOINLINE last #-}
-- | Extract the final element of a non-empty vector.
last :: Vec (n + 1) a -> a
last vec = case vec of
  (x :> Nil)       -> x
  (_ :> (y :> ys)) -> last (y :> ys)
{-# NOINLINE init #-}
-- | Everything except the final element of a non-empty vector.
init :: Vec (n + 1) a -> Vec n a
init vec = case vec of
  (_ :> Nil)       -> Nil
  (x :> (y :> ys)) -> x :> init (y :> ys)
{-# INLINE shiftInAt0 #-}
-- | Shift elements in at the head of a vector.
--
-- The new elements are placed in front of the old vector; the first @n@
-- elements of that combination form the new vector and the remaining @m@
-- elements are the ones shifted out at the tail.
shiftInAt0 :: KnownNat n
           => Vec n a            -- ^ The old vector
           -> Vec m a            -- ^ The elements to shift in at the head
           -> (Vec n a, Vec m a) -- ^ (The new vector, shifted out elements)
shiftInAt0 xs ys = splitAtI (ys ++ xs)
{-# INLINE shiftInAtN #-}
-- | Shift elements in at the tail of a vector.
--
-- The new elements are placed behind the old vector; the first @m@ elements
-- of that combination are shifted out and the remaining @n@ elements form
-- the new vector.
shiftInAtN :: KnownNat m
           => Vec n a           -- ^ The old vector
           -> Vec m a           -- ^ The elements to shift in at the tail
           -> (Vec n a,Vec m a) -- ^ (The new vector, shifted out elements)
shiftInAtN xs ys =
  let (shiftedOut, newVec) = splitAtI (xs ++ ys)
  in  (newVec, shiftedOut)
infixl 5 <:
{-# INLINE (<:) #-}
-- | Add an element to the end of a vector (\"snoc\"); the result is one
-- element longer.
(<:) :: Vec n a -> a -> Vec (n + 1) a
xs <: x = xs ++ (x :> Nil)
infixr 4 +>>
{-# INLINE (+>>) #-}
-- | Shift one element in at the head of a vector, dropping the element that
-- falls off the tail so the length stays the same.
(+>>) :: KnownNat n => a -> Vec n a -> Vec n a
s +>> xs = fst $ shiftInAt0 xs (s :> Nil)
infixl 4 <<+
{-# INLINE (<<+) #-}
-- | Shift one element in at the tail of a vector, dropping the element that
-- falls off the head so the length stays the same.
(<<+) :: Vec n a -> a -> Vec n a
xs <<+ s = fst $ shiftInAtN xs (s :> Nil)
{-# INLINE shiftOutFrom0 #-}
-- | Shift @m@ elements out from the head of a vector, filling up the tail
-- with 'def' values.
shiftOutFrom0 :: (Default a, KnownNat m)
              => SNat m                   -- ^ @m@, the number of elements to shift out
              -> Vec (m + n) a            -- ^ The old vector
              -> (Vec (m + n) a, Vec m a) -- ^ (The new vector, shifted out elements)
shiftOutFrom0 m xs = shiftInAtN xs defaults
  where
    defaults = replicate m def
{-# INLINE shiftOutFromN #-}
-- | Shift @m@ elements out from the tail of a vector, filling up the head
-- with 'def' values.
shiftOutFromN :: (Default a, KnownNat (m + n))
              => SNat m                   -- ^ @m@, the number of elements to shift out
              -> Vec (m + n) a            -- ^ The old vector
              -> (Vec (m + n) a, Vec m a) -- ^ (The new vector, shifted out elements)
shiftOutFromN m xs = shiftInAt0 xs defaults
  where
    defaults = replicate m def
infixr 5 ++
{-# NOINLINE (++) #-}
-- | Append two vectors; the length of the result is the sum of both lengths.
(++) :: Vec n a -> Vec m a -> Vec (n + m) a
front ++ back = case front of
  Nil       -> back
  (x :> xs) -> x :> (xs ++ back)
{-# NOINLINE splitAt #-}
-- | Split a vector into its first @m@ elements and the remainder.
splitAt :: SNat m -> Vec (m + n) a -> (Vec m a, Vec n a)
splitAt n = splitAtU (toUNat n)

-- | Worker for 'splitAt' that recurses over a unary representation of the
-- split position.
splitAtU :: UNat m -> Vec (m + n) a -> (Vec m a, Vec n a)
splitAtU UZero     ys        = (Nil, ys)
splitAtU (USucc s) (y :> ys) = (y :> front, back)
  where
    -- Lazy pattern: the recursive split is only evaluated on demand.
    (front, back) = splitAtU s ys
{-# INLINE splitAtI #-}
-- | Like 'splitAt', but the split position is determined by the result type
-- instead of an explicit 'SNat' argument.
splitAtI :: KnownNat m => Vec (m + n) a -> (Vec m a, Vec n a)
splitAtI xs = withSNat (\m -> splitAt m xs)
{-# NOINLINE concat #-}
-- | Flatten a vector of vectors by appending the inner vectors in order.
concat :: Vec n (Vec m a) -> Vec (n * m) a
concat vss = case vss of
  Nil           -> Nil
  (row :> rest) -> row ++ concat rest
{-# NOINLINE unconcat #-}
-- | Split a vector of @n * m@ elements into @n@ consecutive vectors of @m@
-- elements each. The number of groups @n@ comes from the context, the group
-- size @m@ is given explicitly.
unconcat :: KnownNat n => SNat m -> Vec (n * m) a -> Vec n (Vec m a)
unconcat n = unconcatU (withSNat toUNat) (toUNat n)

-- | Worker for 'unconcat': peels off one group of @m@ elements per step.
unconcatU :: UNat n -> UNat m -> Vec (n * m) a -> Vec n (Vec m a)
unconcatU UZero      _ _  = Nil
unconcatU (USucc n') m ys = group :> unconcatU n' m rest
  where
    (group, rest) = splitAtU m ys
{-# INLINE unconcatI #-}
-- | Like 'unconcat', but the size of the inner vectors is also determined by
-- the context.
unconcatI :: (KnownNat n, KnownNat m) => Vec (n * m) a -> Vec n (Vec m a)
unconcatI xs = withSNat (\m -> unconcat m xs)
{-# NOINLINE merge #-}
-- | Interleave two vectors of equal length, alternating elements starting
-- with the first vector.
merge :: Vec n a -> Vec n a -> Vec (n + n) a
merge ls rs = case (ls, rs) of
  (Nil, Nil)         -> Nil
  (x :> xs, y :> ys) -> x :> y :> merge xs ys
{-# NOINLINE reverse #-}
-- | The elements of a vector in reverse order.
reverse :: Vec n a -> Vec n a
reverse vec = case vec of
  Nil       -> Nil
  (x :> xs) -> reverse xs <: x
{-# NOINLINE map #-}
-- | Apply a function to every element of a vector.
map :: (a -> b) -> Vec n a -> Vec n b
map f vec = case vec of
  Nil       -> Nil
  (x :> xs) -> f x :> map f xs
{-# NOINLINE zipWith #-}
-- | Combine two vectors of equal length element by element.
--
-- Only the first vector is pattern-matched; the second one is taken apart
-- with 'head' and 'tail', so it is only evaluated as far as the results are
-- demanded.
zipWith :: (a -> b -> c) -> Vec n a -> Vec n b -> Vec n c
zipWith f xs ys = case xs of
  Nil        -> Nil
  (x :> xs') -> f x (head ys) :> zipWith f xs' (tail ys)
{-# INLINE zipWith3 #-}
-- | Combine three vectors of equal length element by element, implemented by
-- zipping the last two vectors into pairs first.
zipWith3 :: (a -> b -> c -> d) -> Vec n a -> Vec n b -> Vec n c -> Vec n d
zipWith3 f us vs ws = zipWith combine us (zip vs ws)
  where
    combine a (b, c) = f a b c
{-# INLINABLE foldr #-}
-- | Right fold over a vector, defined as the first element of the
-- corresponding right scan.
foldr :: (a -> b -> b) -> b -> Vec n a -> b
foldr f z = head . scanr f z
{-# INLINABLE foldl #-}
-- | Left fold over a vector, defined as the last element of the
-- corresponding left scan.
foldl :: (b -> a -> b) -> b -> Vec n a -> b
foldl f z = last . scanl f z
{-# INLINABLE foldr1 #-}
-- | Right fold over a non-empty vector that uses the last element as the
-- starting value.
foldr1 :: (a -> a -> a) -> Vec (n + 1) a -> a
foldr1 f xs = foldr f seed front
  where
    seed  = last xs
    front = init xs
{-# INLINE foldl1 #-}
-- | Left fold over a non-empty vector that uses the first element as the
-- starting value.
foldl1 :: (a -> a -> a) -> Vec (n + 1) a -> a
foldl1 f xs = foldl f seed rest
  where
    seed = head xs
    rest = tail xs
{-# NOINLINE fold #-}
-- | Reduce a non-empty vector with a binary operator in a tree-like fashion:
-- the elements are repeatedly split in two halves which are reduced
-- separately and then combined.
fold :: (a -> a -> a) -> Vec (n + 1) a -> a
fold f vs = go (toList vs)
  where
    go elems = case elems of
      [x] -> x
      _   -> let (ls, rs) = P.splitAt (P.length elems `div` 2) elems
             in  go ls `f` go rs
{-# INLINE scanl #-}
-- | Left scan: the successive intermediate results of a left fold, starting
-- with the initial value, so the result is one element longer than the
-- input.
scanl :: (b -> a -> b) -> b -> Vec n a -> Vec (n + 1) b
scanl f z xs =
  -- The result is defined in terms of itself: every element after the first
  -- combines an input element with the preceding result.
  let results = z :> zipWith (\x acc -> f acc x) xs (init results)
  in  results
{-# INLINE sscanl #-}
-- | Like 'scanl', but without the leading initial value, so the result has
-- the same length as the input.
sscanl :: (b -> a -> b) -> b -> Vec n a -> Vec n b
sscanl f z = tail . scanl f z
{-# INLINE scanr #-}
-- | Right scan: the successive intermediate results of a right fold, ending
-- with the initial value, so the result is one element longer than the
-- input.
scanr :: (a -> b -> b) -> b -> Vec n a -> Vec (n + 1) b
scanr f z xs =
  -- The result is defined in terms of itself: every element before the last
  -- combines an input element with the following result.
  let results = zipWith f xs (tail results) <: z
  in  results
{-# INLINE sscanr #-}
-- | Like 'scanr', but without the trailing initial value, so the result has
-- the same length as the input.
sscanr :: (a -> b -> b) -> b -> Vec n a -> Vec n b
sscanr f z = init . scanr f z
{-# INLINE mapAccumL #-}
-- | Map over a vector while threading an accumulator from the first element
-- to the last. Returns the final accumulator together with the vector of
-- per-element results.
mapAccumL :: (acc -> x -> (acc,y)) -> acc -> Vec n x -> (acc,Vec n y)
mapAccumL f acc xs = (finalAcc, outs)
  where
    -- Element i is processed with the accumulator produced by element i-1;
    -- the very first one uses the initial accumulator.
    steps    = zipWith (\x s -> f s x) xs (init states)
    states   = acc :> map fst steps
    outs     = map snd steps
    finalAcc = last states
{-# INLINE mapAccumR #-}
-- | Map over a vector while threading an accumulator from the last element
-- to the first. Returns the final accumulator together with the vector of
-- per-element results.
mapAccumR :: (acc -> x -> (acc,y)) -> acc -> Vec n x -> (acc, Vec n y)
mapAccumR f acc xs = (finalAcc, outs)
  where
    -- Element i is processed with the accumulator produced by element i+1;
    -- the very last one uses the initial accumulator.
    steps    = zipWith (\x s -> f s x) xs (tail states)
    states   = map fst steps <: acc
    outs     = map snd steps
    finalAcc = head states
{-# INLINE zip #-}
-- | Pair up the corresponding elements of two vectors.
zip :: Vec n a -> Vec n b -> Vec n (a,b)
zip xs ys = zipWith (\a b -> (a, b)) xs ys
{-# INLINE zip3 #-}
-- | Combine the corresponding elements of three vectors into triples.
zip3 :: Vec n a -> Vec n b -> Vec n c -> Vec n (a,b,c)
zip3 xs ys zs = zipWith3 (\a b c -> (a, b, c)) xs ys zs
{-# INLINE unzip #-}
-- | Turn a vector of pairs into a pair of vectors.
unzip :: Vec n (a,b) -> (Vec n a, Vec n b)
unzip ps = (firsts, seconds)
  where
    firsts  = map fst ps
    seconds = map snd ps
{-# INLINE unzip3 #-}
-- | Turn a vector of triples into a triple of vectors.
unzip3 :: Vec n (a,b,c) -> (Vec n a, Vec n b, Vec n c)
unzip3 ts = (map fst3 ts, map snd3 ts, map thd3 ts)
  where
    fst3 (x, _, _) = x
    snd3 (_, y, _) = y
    thd3 (_, _, z) = z
{-# NOINLINE index_int #-}
-- | Look up the element at the given position, counting from 0.
--
-- Calls 'error' for a negative index and for an index beyond the last
-- element.
index_int :: KnownNat n => Vec n a -> Int -> a
index_int xs i@(I# n0) =
  if isTrue# (n0 <# 0#)
    then error "CLaSH.Sized.Vector.(!!): negative index"
    else walk xs n0
  where
    -- Step through the vector while counting the index down to zero;
    -- running out of elements first means the index was too large.
    walk :: Vec m a -> Int# -> a
    walk Nil _ = error (P.concat [ "CLaSH.Sized.Vector.(!!): index "
                                 , show i
                                 , " is larger than maximum index "
                                 , show (maxIndex xs)
                                 ])
    walk (y :> (!ys)) k
      | isTrue# (k ==# 0#) = y
      | otherwise          = walk ys (k -# 1#)
{-# INLINE (!!) #-}
-- | Vector index (subscript) operator, counting from 0. The index may be of
-- any 'Enum' type; it is converted with 'fromEnum'.
(!!) :: (KnownNat n, Enum i) => Vec n a -> i -> a
(!!) xs = index_int xs . fromEnum
{-# NOINLINE maxIndex #-}
-- | Index of the last element of a vector, i.e. its length minus one.
maxIndex :: KnownNat n => Vec n a -> Integer
maxIndex xs = length xs - 1
{-# NOINLINE length #-}
-- | Length of a vector, read from its type-level size rather than by
-- traversing the elements.
length :: KnownNat n => Vec n a -> Integer
length xs = natVal (asNatProxy xs)
{-# NOINLINE replace_int #-}
-- | Replace the element at the given position (counting from 0) with a new
-- value.
--
-- Calls 'error' for a negative index and for an index beyond the last
-- element.
replace_int :: KnownNat n => Vec n a -> Int -> a -> Vec n a
replace_int xs i@(I# n0) a =
  if isTrue# (n0 <# 0#)
    then error "CLaSH.Sized.Vector.replace: negative index"
    else walk xs n0 a
  where
    -- Copy elements until the index has counted down to zero, then swap in
    -- the new value; running out of elements first means the index was too
    -- large.
    walk :: Vec m b -> Int# -> b -> Vec m b
    walk Nil _ _ = error (P.concat [ "CLaSH.Sized.Vector.replace: index "
                                   , show i
                                   , " is larger than maximum index "
                                   , show (maxIndex xs)
                                   ])
    walk (y :> (!ys)) k b
      | isTrue# (k ==# 0#) = b :> ys
      | otherwise          = y :> walk ys (k -# 1#) b
{-# INLINE replace #-}
-- | 'replace' @i y xs@ replaces the element at position @i@ of @xs@ with
-- @y@. The index may be of any 'Enum' type; it is converted with 'fromEnum'.
replace :: (KnownNat n, Enum i) => i -> a -> Vec n a -> Vec n a
replace i y xs = replace_int xs idx y
  where
    idx = fromEnum i
{-# INLINABLE take #-}
-- | 'take' @n xs@ returns the first @n@ elements of @xs@.
take :: SNat m -> Vec (m + n) a -> Vec m a
take n xs = fst (splitAt n xs)
{-# INLINE takeI #-}
-- | Like 'take', but the number of elements to keep is determined by the
-- result type.
takeI :: KnownNat m => Vec (m + n) a -> Vec m a
takeI xs = withSNat (\m -> take m xs)
{-# INLINE drop #-}
-- | 'drop' @n xs@ returns @xs@ without its first @n@ elements.
drop :: SNat m -> Vec (m + n) a -> Vec n a
drop n xs = snd (splitAt n xs)
{-# INLINE dropI #-}
-- | Like 'drop', but the number of elements to remove is determined by the
-- context.
dropI :: KnownNat m => Vec (m + n) a -> Vec n a
dropI xs = withSNat (\m -> drop m xs)
{-# INLINE at #-}
-- | 'at' @n xs@ returns the element of @xs@ at position @n@, counting from
-- 0. The type guarantees that the position lies within the vector.
at :: SNat m -> Vec (m + (n + 1)) a -> a
at n = head . drop n
{-# NOINLINE select #-}
-- | 'select' @f s n xs@ selects @n@ elements from @xs@: the first @f@
-- elements are dropped, after which an element is taken and the vector is
-- advanced by @s@ positions, @n@ times in a row.
select :: (CmpNat (i + s) (s * n) ~ 'GT)
=> SNat f -- ^ Number of leading elements to skip
-> SNat s -- ^ Number of positions to advance after each taken element
-> SNat n -- ^ Number of elements to take
-> Vec (f + i) a
-> Vec n a
select f s n xs = select' (toUNat n) $ drop f xs
where
-- Takes the head of the current vector and continues on what is left after
-- dropping @s@ elements from it.
select' :: UNat n -> Vec i a -> Vec n a
select' UZero _ = Nil
-- NOTE(review): the type of @vs@ does not show that it holds at least @s@
-- elements, so it is cast with 'unsafeCoerce' to satisfy 'drop'. Presumably
-- the 'CmpNat' constraint on 'select' is what makes this safe - confirm.
select' (USucc n') vs@(x :> _) = x :> select' n' (drop s (unsafeCoerce vs))
{-# INLINE selectI #-}
-- | 'selectI' @f s xs@ selects elements from @xs@ with offset @f@ and
-- stepsize @s@; how many elements are selected is determined by the result
-- type.
--
-- >>> selectI d1 d2 (1:>2:>3:>4:>5:>6:>7:>8:>Nil) :: Vec 2 Int
-- <2,4>
selectI :: (CmpNat (i + s) (s * n) ~ 'GT, KnownNat n)
        => SNat f
        -> SNat s
        -> Vec (f + i) a
        -> Vec n a
selectI f s xs = withSNat $ \n -> select f s n xs
{-# NOINLINE replicate #-}
-- | 'replicate' @n a@ builds a vector consisting of @n@ copies of @a@
--
-- >>> replicate (snat :: SNat 3) 6
-- <6,6,6>
-- >>> replicate d3 6
-- <6,6,6>
replicate :: SNat n -> a -> Vec n a
replicate n = replicateU (toUNat n)

-- | Worker for 'replicate' that recurses over a unary representation of the
-- length.
replicateU :: UNat n -> a -> Vec n a
replicateU count x = case count of
  UZero     -> Nil
  USucc cnt -> x :> replicateU cnt x
{-# INLINE repeat #-}
-- | 'repeat' @a@ builds a vector filled with copies of @a@; the number of
-- copies is determined by the context
--
-- >>> repeat 6 :: Vec 5 Int
-- <6,6,6,6,6>
repeat :: KnownNat n => a -> Vec n a
repeat x = withSNat (\n -> replicate n x)
{-# INLINE iterate #-}
-- | 'iterate' @n f x@ returns a vector of @n@ elements that starts with @x@;
-- every following element is @f@ applied to the element before it
--
-- > iterate (snat :: SNat 4) f x == (x :> f x :> f (f x) :> f (f (f x)) :> Nil)
-- > iterate d4 f x               == (x :> f x :> f (f x) :> f (f (f x)) :> Nil)
--
-- >>> iterate d4 (+1) 1
-- <1,2,3,4>
--
-- @iterate n f z@ corresponds to the following circuit layout:
--
-- <<doc/iterate.svg>>
iterate :: SNat n -> (a -> a) -> a -> Vec n a
iterate n = case n of
  -- Matching on the constructor brings the length into scope for 'iterateI'.
  SNat _ -> iterateI
{-# INLINE iterateI #-}
-- | 'iterateI' @f x@ returns a vector that starts with @x@; every following
-- element is @f@ applied to the element before it. The length of the vector
-- is determined by the context
--
-- > iterateI f x :: Vec 3 a == (x :> f x :> f (f x) :> Nil)
--
-- >>> iterateI (+1) 1 :: Vec 3 Int
-- <1,2,3>
--
-- @iterateI f z@ corresponds to the following circuit layout:
--
-- <<doc/iterate.svg>>
iterateI :: KnownNat n => (a -> a) -> a -> Vec n a
iterateI f a =
  -- The result feeds back into its own definition; 'lazyV' lets 'map'
  -- produce its constructors before the result itself has been evaluated.
  let out = init (a :> map f (lazyV out))
  in  out
{-# INLINE generate #-}
-- | 'generate' @n f x@ returns a vector of @n@ elements holding the results
-- of applying @f@ to @x@ once, twice, and so on up to @n@ times
--
-- > generate (snat :: SNat 4) f x == (f x :> f (f x) :> f (f (f x)) :> f (f (f (f x))) :> Nil)
-- > generate d4 f x               == (f x :> f (f x) :> f (f (f x)) :> f (f (f (f x))) :> Nil)
--
-- >>> generate d4 (+1) 1
-- <2,3,4,5>
--
-- @generate n f z@ corresponds to the following circuit layout:
--
-- <<doc/generate.svg>>
generate :: SNat n -> (a -> a) -> a -> Vec n a
generate n f a = case n of
  -- Matching on the constructor brings the length into scope for 'iterateI'.
  SNat _ -> iterateI f (f a)
{-# INLINE generateI #-}
-- | 'generateI' @f x@ returns a vector holding the results of applying @f@
-- to @x@ once, twice, and so on; the length of the vector is determined by
-- the context
--
-- > generateI f x :: Vec 3 a == (f x :> f (f x) :> f (f (f x)) :> Nil)
--
-- >>> generateI (+1) 1 :: Vec 3 Int
-- <2,3,4>
--
-- @generateI f z@ corresponds to the following circuit layout:
--
-- <<doc/generate.svg>>
generateI :: KnownNat n => (a -> a) -> a -> Vec n a
generateI f = iterateI f . f
{-# INLINE toList #-}
-- | Turn a vector into an ordinary list with the same elements in the same
-- order
--
-- >>> toList (1:>2:>3:>Nil)
-- [1,2,3]
--
-- __NB__: Not synthesisable
toList :: Vec n a -> [a]
toList xs = foldr (\x acc -> x : acc) [] xs
-- | Create a vector literal from a list literal, by means of a Template
-- Haskell splice
--
-- > $(v [1::Signed 8,2,3,4,5]) == (1:>2:>3:>4:>5:>Nil) :: Vec 5 (Signed 8)
--
-- >>> [1 :: Signed 8,2,3,4,5]
-- [1,2,3,4,5]
-- >>> $(v [1::Signed 8,2,3,4,5])
-- <1,2,3,4,5>
v :: Lift a => [a] -> ExpQ
v elems = case elems of
  []     -> [| Nil |]
  (y:ys) -> [| y :> $(v ys) |]
-- | A 'Proxy' carrying the length of the given vector; the vector itself is
-- never inspected
asNatProxy :: Vec n a -> Proxy n
asNatProxy = const Proxy
{-# NOINLINE lazyV #-}
-- | For when your vector functions are too strict in their arguments
--
-- For example:
--
-- @
-- -- Bubble sort for 1 iteration
-- sortV xs = 'map' fst sorted '<:' (snd ('last' sorted))
--  where
--    lefts  = 'head' xs :> 'map' snd ('init' sorted)
--    rights = 'tail' xs
--    sorted = 'zipWith' compareSwapL lefts rights
--
-- -- Compare and swap
-- compareSwapL a b = if a < b then (a,b)
--                             else (b,a)
-- @
--
-- Will not terminate because 'zipWith' is too strict in its second argument:
--
-- >>> sortV (4 :> 1 :> 2 :> 3 :> Nil)
-- <*** Exception: <<loop>>
--
-- In this case, adding 'lazyV' on 'zipWith's second argument:
--
-- @
-- sortVL xs = 'map' fst sorted '<:' (snd ('last' sorted))
--  where
--    lefts  = 'head' xs :> map snd ('init' sorted)
--    rights = 'tail' xs
--    sorted = 'zipWith' compareSwapL ('lazyV' lefts) rights
-- @
--
-- Results in a successful computation:
--
-- >>> sortVL (4 :> 1 :> 2 :> 3 :> Nil)
-- <1,2,3,4>
--
-- __NB__: There is also a solution using 'flip', but it slightly obfuscates the
-- meaning of the code:
--
-- @
-- sortV_flip xs = 'map' fst sorted '<:' (snd ('last' sorted))
--  where
--    lefts  = 'head' xs :> 'map' snd ('init' sorted)
--    rights = 'tail' xs
--    sorted = 'zipWith' ('flip' compareSwapL) rights lefts
-- @
--
-- >>> sortV_flip (4 :> 1 :> 2 :> 3 :> Nil)
-- <1,2,3,4>
lazyV :: KnownNat n
=> Vec n a
-> Vec n a
-- The constructors of the result are rebuilt by walking over a vector of the
-- same length created with 'repeat' (its elements are never used), so they
-- do not depend on evaluating the argument.
lazyV = lazyV' (repeat undefined)
where
-- The first argument only supplies the shape; the elements are picked from
-- the second argument with 'head' and 'tail'.
lazyV' :: Vec n a -> Vec n a -> Vec n a
lazyV' Nil _ = Nil
lazyV' (_ :> xs) ys = head ys :> lazyV' xs (tail ys)
{-# NOINLINE dfold #-}
-- | A /dependently/ typed fold.
--
-- __NB__: Not synthesisable
--
-- Using lists, we can define @append@ ('Prelude.++') using 'Prelude.foldr':
--
-- >>> import qualified Prelude
-- >>> let append xs ys = Prelude.foldr (:) ys xs
-- >>> append [1,2] [3,4]
-- [1,2,3,4]
--
-- However, when we try to do the same for 'Vec':
--
-- @
-- append' xs ys = 'foldr' (:>) ys xs
-- @
--
-- We get a type error
--
-- >>> let append' xs ys = foldr (:>) ys xs
-- <BLANKLINE>
-- <interactive>:...
-- Occurs check: cannot construct the infinite type: ... ~ ... + 1
-- Expected type: a -> Vec ... a -> Vec ... a
-- Actual type: a -> Vec ... a -> Vec (... + 1) a
-- Relevant bindings include
-- ys :: Vec ... a (bound at ...)
-- append' :: Vec n a -> Vec ... a -> Vec ... a
-- (bound at ...)
-- In the first argument of ‘foldr’, namely ‘(:>)’
-- In the expression: foldr (:>) ys xs
--
-- The reason is that the type of 'foldr' is:
--
-- >>> :t foldr
-- foldr :: (a -> b -> b) -> b -> Vec n a -> b
--
-- While the type of (':>') is:
--
-- >>> :t (:>)
-- (:>) :: a -> Vec n a -> Vec (n + 1) a
--
-- We thus need a @fold@ function that can handle the growing vector type:
-- 'dfold'. Compared to 'foldr', 'dfold' takes an extra parameter, called the
-- /motive/, that allows the folded function to have an argument and result type
-- that /depends/ on the current index into the vector. Using 'dfold', we can
-- now correctly define ('++'):
--
-- @
-- import Data.Singletons.Prelude
-- import Data.Proxy
--
-- data Append (m :: Nat) (a :: *) (f :: 'TyFun' Nat *) :: *
-- type instance 'Apply' (Append m a) l = 'Vec' (l + m) a
--
-- append' xs ys = 'dfold' (Proxy :: Proxy (Append m a)) (const (':>')) ys xs
-- @
--
-- We now see that @append'@ has the appropriate type:
--
-- >>> :t append'
-- append' :: Vec k a -> Vec m a -> Vec (k + m) a
--
-- And that it works:
--
-- >>> append' (1 :> 2 :> Nil) (3 :> 4 :> Nil)
-- <1,2,3,4>
dfold :: Proxy (p :: TyFun Nat * -> *) -- ^ The /motive/
-> (forall l . Proxy l -> a -> p $ l -> p $ (l + 1)) -- ^ Function to fold
-> (p $ 0) -- ^ Initial element
-> Vec k a -- ^ Vector to fold over
-> p $ k
-- Folding the empty vector yields the initial element unchanged.
dfold _ _ z Nil = z
-- The pattern signature names the length @l@ of the tail so that it can be
-- handed to the folded function as a 'Proxy'.
dfold p f z (x :> (xs :: Vec l a)) = f (Proxy :: Proxy l) x (dfold p f z xs)
-- | Motive used by 'vfold': applied to a length @l@ it stands for the type
-- @Vec l a@.
data V (a :: *) (f :: TyFun Nat *) :: *
type instance Apply (V a) l = Vec l a
{-# NOINLINE vfold #-}
-- | Specialised version of 'dfold' that builds a triangular computational
-- structure.
--
-- __NB__: Not synthesisable
--
-- Example:
--
-- @
-- cs a b     = if a > b then (a,b) else (b,a)
-- csRow y xs = let (y',xs') = 'mapAccumL' cs y xs in xs' '<:' y'
-- csSort     = 'vfold' csRow
-- @
--
-- Builds a triangular structure of compare and swaps to sort a row.
--
-- >>> csSort (7 :> 3 :> 9 :> 1 :> Nil)
-- <1,3,7,9>
vfold :: (forall l . a -> Vec l b -> Vec (l + 1) b) -- ^ Step that extends the result by one element
-> Vec k a -- ^ Vector to fold over
-> Vec k b
-- The fold starts from the empty vector and uses the 'V' motive, so the
-- intermediate result after @l@ steps is a vector of length @l@.
vfold f xs = dfold (Proxy :: Proxy (V a)) (const f) Nil xs
-- | A vector is packed by packing every element and concatenating the
-- resulting bit vectors; unpacking splits the bit vector into equally sized
-- pieces and unpacks each of them.
instance (KnownNat n, KnownNat (BitSize a), BitPack a) => BitPack (Vec n a) where
  type BitSize (Vec n a) = n * (BitSize a)
  pack   vec = concatBitVector# (map pack vec)
  unpack bv  = map unpack (unconcatBitVector# bv)
{-# NOINLINE concatBitVector# #-}
-- | Concatenate a vector of equally sized bit vectors into a single bit
-- vector.
concatBitVector# :: KnownNat m
                 => Vec n (BitVector m)
                 -> BitVector (n * m)
concatBitVector# vs = go (reverse vs)
  where
    -- Works on the reversed vector: each element is appended (with '++#') to
    -- the right of the concatenation of the elements that follow it.
    go :: KnownNat m
       => Vec n (BitVector m)
       -> BitVector (n * m)
    go Nil       = 0
    go (b :> bs) = go bs ++# b
{-# NOINLINE unconcatBitVector# #-}
-- | Split a bit vector into a vector of @n@ bit vectors of @m@ bits each.
unconcatBitVector# :: (KnownNat n, KnownNat m)
                   => BitVector (n * m)
                   -> Vec n (BitVector m)
unconcatBitVector# bv = ucBV (withSNat toUNat) bv

{-# INLINE ucBV #-}
-- | Worker for 'unconcatBitVector#': every step splits the bit vector with
-- 'split#' and places the second part at the end of the result.
ucBV :: forall n m . KnownNat m
     => UNat n -> BitVector (n * m) -> Vec n (BitVector m)
ucBV UZero     _  = Nil
ucBV (USucc n) bv = ucBV n rest <: piece
  where
    -- The annotation fixes the width at which 'split#' divides the input.
    (rest, piece :: BitVector m) = split# bv
-- | Lifts a vector into a Template Haskell expression that rebuilds it from
-- the 'Nil' and ':>' constructors.
instance Lift a => Lift (Vec n a) where
  lift vec = case vec of
    Nil       -> [| Nil |]
    (x :> xs) -> [| x :> $(lift xs) |]
-- | Random vectors are built from independently generated elements.
--
-- 'shrink' proposes the vectors that differ from the input in exactly one
-- position, where that position holds one of the shrinks of the original
-- element. The previous definition (@sequence . fmap shrink@) produced the
-- cartesian product of all per-element shrinks: it returned no candidates at
-- all as soon as a single element could not be shrunk any further, changed
-- every element at once, and grew exponentially with the vector length.
instance (KnownNat n, Arbitrary a) => Arbitrary (Vec n a) where
  arbitrary = sequence $ repeat arbitrary
  shrink    = shrinkOne
    where
      -- The length is fixed by the type, so no elements can be removed; only
      -- shrink one element in place and leave all others untouched.
      shrinkOne :: Vec m a -> [Vec m a]
      shrinkOne Nil       = []
      shrinkOne (x :> xs) =
        [ x' :> xs  | x'  <- shrink x ] P.++
        [ x  :> xs' | xs' <- shrinkOne xs ]
-- | Generator perturbation for vectors reuses the list instance via
-- 'toList'.
instance CoArbitrary a => CoArbitrary (Vec n a) where
  coarbitrary xs = coarbitrary (toList xs)
-- Vectors are indexed by 'Int' and yield their element type.
type instance Index   (Vec n a) = Int
type instance IxValue (Vec n a) = a

-- | Focus on the element at the given position: it is read with 'index_int'
-- and written back with 'replace_int', so an index outside the vector
-- results in the errors raised by those functions.
instance KnownNat n => Ixed (Vec n a) where
  ix i f xs = fmap (replace_int xs i) (f (index_int xs i))