{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}

#if MIN_VERSION_GLASGOW_HASKELL (8,6,0,0)
{-# LANGUAGE NoStarIsType #-}
#endif

-- | Vector with size in the type
module Haskus.Binary.Vector
   ( Vector (..)
   , vectorBuffer
   , vectorReverse
   , take
   , drop
   , index
   , fromList
   , fromFilledList
   , fromFilledListZ
   , toList
   , replicate
   , concat
   , zipWith
   )
where

import Prelude hiding ( replicate, head, last
                      , tail, init, map, length, drop, take, concat
                      , zipWith )

import System.IO.Unsafe (unsafePerformIO)

import qualified Haskus.Utils.List as List
import Haskus.Utils.Types
import Haskus.Utils.HList
import Haskus.Utils.Maybe
import Haskus.Utils.Flow
import Haskus.Binary.Storable
import Haskus.Binary.Buffer
import Haskus.Binary.Bits

import Foreign.Ptr
import Foreign.Marshal.Alloc (mallocBytes)

-- | Vector with type-checked size
--
-- The element count @n@ and the element type @a@ only appear in the type; the
-- value itself is a wrapper around a 'Buffer'.
data Vector (n :: Nat) a = Vector Buffer

-- | Rendered as @fromList [..]@ followed by the elements.
--
-- The output is wrapped in parentheses when shown in an application context
-- (precedence above 10), so that e.g. @show (Just v)@ yields
-- @Just (fromList [..])@.
instance (Storable a, Show a, KnownNat n) => Show (Vector n a) where
   showsPrec d v = showParen (d > 10) $ showString "fromList " . shows (toList v)

-- | Return the buffer backing the vector
vectorBuffer :: Vector n a -> Buffer
vectorBuffer (Vector b) = b

-- | Reverse a vector
--
-- Goes through an intermediate list. The 'fromJust' cannot fail here:
-- 'reverse' preserves the length produced by 'toList', which is the only thing
-- 'fromList' checks.
vectorReverse :: (KnownNat n, Storable a) => Vector n a -> Vector n a
vectorReverse = fromJust . fromList . reverse . toList

-- | Offset of the i-th element in a stored vector
--
-- Computed as @i * SizeOf a@, guarded by the condition @i+1 <= n@.
-- NOTE(review): 'Assert' comes from "Haskus.Utils.Types"; it is presumably what
-- turns a failed condition into the custom type error below -- confirm there.
type family ElemOffset a i n where
   ElemOffset a i n = Assert (i+1 <=? n)
      (i * (SizeOf a))
      (('Text "Invalid vector index: " ':<>: 'ShowType i
       ':$$: 'Text "Vector size: " ':<>: 'ShowType n))

-- | Statically-sized storage: the size is @SizeOf a * n@ and the alignment is
-- the alignment of the element type.
instance forall a n.
   ( KnownNat (SizeOf a * n)
   ) => StaticStorable (Vector n a) where

   type SizeOf (Vector n a)    = SizeOf a * n
   type Alignment (Vector n a) = Alignment a

   -- Pack @SizeOf a * n@ bytes starting at the pointer into a buffer
   staticPeekIO ptr =
      Vector <$> bufferPackPtr (natValue @(SizeOf a * n)) (castPtr ptr)

   -- Write the backing buffer at the pointer
   staticPokeIO ptr (Vector b) = bufferPoke ptr b

-- | Dynamically-queried storage: the size is @n * sizeOfT \@a@ and the
-- alignment is the alignment of the element type.
instance forall a n.
   ( KnownNat n
   , Storable a
   ) => Storable (Vector n a) where
   sizeOf _    = natValue @n * sizeOfT @a
   alignment _ = alignmentT @a

   -- Pack the whole vector (its 'sizeOfT'' in bytes) into a buffer
   peekIO ptr  =
      Vector <$> bufferPackPtr (sizeOfT' @(Vector n a)) (castPtr ptr)

   -- Write the backing buffer at the pointer
   pokeIO ptr (Vector b) = bufferPoke ptr b

-- | Yield the first n elements
--
-- Implemented by calling 'bufferTake' with @SizeOf a * n@ on the backing
-- buffer.
take :: forall n m a.
   ( KnownNat (SizeOf a * n)
   ) => Vector (m+n) a -> Vector n a
{-# INLINABLE take #-}
take (Vector b) = Vector (bufferTake (natValue @(SizeOf a * n)) b)

-- | Drop the first n elements
--
-- Implemented by calling 'bufferDrop' with @SizeOf a * n@ on the backing
-- buffer; the result has the remaining @m@ elements.
drop :: forall n m a.
   ( KnownNat (SizeOf a * n)
   ) => Vector (m+n) a -> Vector m a
{-# INLINABLE drop #-}
drop (Vector b) = Vector (bufferDrop (natValue @(SizeOf a * n)) b)

-- | /O(1)/ Index safely into the vector using a type level index.
--
-- The element is read from the buffer at offset @ElemOffset a i n@.
index :: forall i a n.
   ( KnownNat (ElemOffset a i n)
   , Storable a
   ) => Vector n a -> a
{-# INLINABLE index #-}
index (Vector b) = bufferPeekStorableAt b (natValue @(ElemOffset a i n))

-- | Convert a list into a vector if the number of elements matches
--
-- Returns 'Nothing' when the list length differs from @n@. A zero-sized vector
-- is backed by 'emptyBuffer'; otherwise the list is packed with
-- 'bufferPackStorableList'.
fromList :: forall a (n :: Nat) .
   ( KnownNat n
   , Storable a
   ) => [a] -> Maybe (Vector n a)
{-# INLINABLE fromList #-}
fromList v
   | n' /= n   = Nothing
   | n' == 0   = Just $ Vector $ emptyBuffer
   | otherwise = Just $ Vector $ bufferPackStorableList v
   where
      n' = natValue' @n                  -- expected size (type level)
      n  = fromIntegral (List.length v)  -- actual size (list length)

-- | Take at most n elements from the list, then use z
--
-- The list is padded with an infinite supply of @z@ and truncated to exactly
-- @n@ elements before being packed.
fromFilledList :: forall a (n :: Nat) .
   ( KnownNat n
   , Storable a
   ) => a -> [a] -> Vector n a
{-# INLINABLE fromFilledList #-}
fromFilledList z v = Vector $ bufferPackStorableList v'
   where
      v' = List.take (natValue @n) (v ++ repeat z)

-- | Take at most (n-1) elements from the list, then use z
--
-- The input is truncated to @n-1@ elements and then handed to
-- 'fromFilledList', so at least the last element is @z@ when @n > 0@.
fromFilledListZ :: forall a (n :: Nat) .
   ( KnownNat n
   , Storable a
   ) => a -> [a] -> Vector n a
{-# INLINABLE fromFilledListZ #-}
fromFilledListZ z v = fromFilledList z v'
   where
      v' = List.take (natValue @n - 1) v

-- | Convert a vector into a list
--
-- Element @k@ is read from the buffer at offset @sizeOfT' \@a * k@ for @k@ in
-- @[0..n-1]@.
toList :: forall a (n :: Nat) .
   ( KnownNat n
   , Storable a
   ) => Vector n a -> [a]
{-# INLINABLE toList #-}
toList (Vector b)
   -- The empty case is handled separately, so @n-1@ is never evaluated for
   -- n == 0 (presumably to avoid wrap-around of an unsigned size -- confirm
   -- against the type of the offset taken by bufferPeekStorableAt).
   | n == 0    = []
   | otherwise = fmap (bufferPeekStorableAt b . (sza*)) [0..n-1]
   where
      n   = natValue @n
      sza = sizeOfT' @a

-- | Create a vector by replicating a value
--
-- Equivalent to filling an empty list with the value via 'fromFilledList'.
replicate :: forall a (n :: Nat) .
   ( KnownNat n
   , Storable a
   ) => a -> Vector n a
{-# INLINABLE replicate #-}
replicate v = fromFilledList v []

-- | Function tag used with 'hFoldr' by 'concat'
data StoreVector = StoreVector

-- Store a vector at the right offset
--
-- The incoming action yields a pointer just past the slot for this vector: the
-- pointer is moved back by @n * sizeOfT \@a@ bytes, the vector is poked there,
-- and the moved pointer is returned for the next step.
instance forall n v a r.
   ( v ~ Vector n a
   , r ~ IO (Ptr a)
   , KnownNat n
   , KnownNat (SizeOf a)
   , StaticStorable a
   , Storable a
   ) => Apply StoreVector (v, IO (Ptr a)) r where
      apply _ (v, getP) = do
         p <- getP
         let
            vsz = natValue @n
            p'  = p `plusPtr` (-1 * vsz * fromIntegral (sizeOfT @a))
         poke (castPtr p') v
         return p'

-- | Sum of the sizes of a list of vector types
type family WholeSize fs :: Nat where
   WholeSize '[]                 = 0
   WholeSize (Vector n s ': xs)  = n + WholeSize xs

-- | Concat several vectors into a single one
--
-- Allocates @sizeOfT \@a * n@ bytes with 'mallocBytes', then folds over the
-- heterogeneous list with 'StoreVector', starting from a pointer to the end of
-- the allocation; each step writes one vector just before the pointer it
-- receives. The filled memory is finally wrapped with 'bufferUnsafePackPtr'.
--
-- NOTE(review): this assumes 'hFoldr' visits the last vector first, so that the
-- vectors end up in list order, and that 'bufferUnsafePackPtr' takes ownership
-- of the allocation -- confirm both against their definitions.
concat :: forall l (n :: Nat) a .
   ( n ~ WholeSize l
   , KnownNat n
   , Storable a
   , StaticStorable a
   , HFoldr StoreVector (IO (Ptr a)) l (IO (Ptr a))
   )
   => HList l -> Vector n a
concat vs = unsafePerformIO $ do
   let sz = sizeOfT @a * natValue @n
   p <- mallocBytes (fromIntegral sz) :: IO (Ptr ())
   _ <- hFoldr StoreVector (return (castPtr p `plusPtr` fromIntegral sz) :: IO (Ptr a)) vs :: IO (Ptr a)
   Vector <$> bufferUnsafePackPtr (fromIntegral sz) p

-- | Zip two vectors
--
-- Goes through intermediate lists. Both inputs have @n@ elements, so the
-- zipped list has @n@ elements too and the 'fromJust' cannot fail.
zipWith ::
   ( KnownNat n
   , Storable a
   , Storable b
   , Storable c
   ) => (a -> b -> c) -> Vector n a -> Vector n b -> Vector n c
zipWith f u v = fromJust . fromList <| List.zipWith f (toList u) (toList v)

-- | Apply a function to every element of a vector
--
-- Goes through an intermediate list; the length is preserved, so the
-- 'fromJust' cannot fail.
map ::
   ( KnownNat n
   , Storable a
   , Storable b
   ) => (a -> b) -> Vector n a -> Vector n b
map f = fromJust . fromList . fmap f . toList

-- | Element-wise equality (compares the element lists)
instance
   ( KnownNat n
   , Storable a
   , Eq a
   ) => Eq (Vector n a)
   where
      u == v = toList u == toList v

-- | Element-wise bitwise operations
instance
   ( KnownNat n
   , Bitwise a
   , Storable a
   ) => Bitwise (Vector n a)
   where
      u .&. v   = zipWith (.&.) u v
      u .|. v   = zipWith (.|.) u v
      u `xor` v = zipWith xor u v

-- | The vector is seen as a single bit field of @n * BitSize a@ bits.
--
-- Leading zeros are counted from the first element and trailing zeros from the
-- last one.
instance
   ( KnownNat (BitSize a)
   , FiniteBits a
   , KnownNat n
   , Storable a
   ) => FiniteBits (Vector n a)
   where
      type BitSize (Vector n a) = n * BitSize a
      zeroBits = fromJust (fromList (List.replicate (natValue @n) zeroBits))
      oneBits  = fromJust (fromList (List.replicate (natValue @n) oneBits))
      complement u = map complement u

      -- Accumulate whole-element widths while elements are entirely zero; stop
      -- at the first element that contains a set bit.
      countLeadingZeros = go 0 . toList
         where
            go !n []     = n
            go !n (x:xs) = let c = countLeadingZeros x
                           in if c == natValue @(BitSize a)
                                 then go (n+c) xs
                                 else n+c

      -- Same as above, but scanning from the last element.
      countTrailingZeros = go 0 . reverse . toList
         where
            go !n []     = n
            go !n (x:xs) = let c = countTrailingZeros x
                           in if c == natValue @(BitSize a)
                                 then go (n+c) xs
                                 else n+c

-- | Shifts over the whole bit field.
--
-- 'shiftL' and 'shiftR' reduce the shift amount modulo the total bit size
-- before delegating to the unchecked variants.
instance
   ( Storable a
   , ShiftableBits a
   , Bitwise a
   , FiniteBits a
   , KnownNat (BitSize a)
   , KnownNat (n * BitSize a)
   , KnownNat n
   ) => ShiftableBits (Vector n a)
   where
      shiftL u c = uncheckedShiftL u (c `mod` natValue @(BitSize (Vector n a)))
      shiftR u c = uncheckedShiftR u (c `mod` natValue @(BitSize (Vector n a)))

      uncheckedShiftL u c =
         let
            n  = natValue @n
            sa = natValue @(BitSize a)
            -- go s k xs: s = remaining shift (bits), k = elements still to
            -- produce, xs = remaining input (padded with zero elements).
            go _ 0 _  = []
            go 0 k xs = List.take k xs
            go s k xs
               -- shifting by at least one whole element: skip an input element
               | s >= sa   = go (s-sa) k (List.tail xs)
               -- sub-element shift: combine two neighbouring elements
               | otherwise = let (x:y:zs) = xs
                             in ((x `shiftL` s) .|. (y `shiftR` (sa-s))) : go s (k-1) (y:zs)
         in fromJust (fromList (go c n (toList u ++ List.repeat zeroBits)))

      uncheckedShiftR u c =
         let
            n  = natValue @n
            sa = natValue @(BitSize a)
            -- go s k xs: same roles as in uncheckedShiftL; the input carries
            -- one extra leading zero element (see the call below).
            go _ 0 _  = []
            go 0 k xs = List.take k (List.tail xs)
            go s k xs
               -- shifting by at least one whole element: emit a zero element
               | s >= sa   = zeroBits : go (s-sa) (k-1) xs
               -- sub-element shift: combine two neighbouring elements
               | otherwise = let (x:y:zs) = xs
                             in ((x `shiftL` (sa-s)) .|. (y `shiftR` s)) : go s (k-1) (y:zs)
         in fromJust (fromList (go c n (zeroBits : toList u)))

-- | Bit indexing over the whole bit field.
--
-- Bit @i@ lives in the element at list position @n - (i `div` sa) - 1@, at bit
-- @i `mod` sa@, where @sa@ is the bit size of an element. An index at or beyond
-- @n * sa@ gives an all-zero vector for 'bit' and 'False' for 'testBit'.
instance
   ( Storable a
   , IndexableBits a
   , FiniteBits a
   , KnownNat (BitSize a)
   , KnownNat n
   , Bitwise a
   ) => IndexableBits (Vector n a)
   where
      popCount = sum . fmap popCount . toList

      bit i =
         let
            n     = natValue @n
            sa    = natValue @(BitSize a)
            (f,r) = i `divMod` sa
            toRep = fromIntegral (n - f - 1)
            xs    = List.replicate toRep zeroBits ++ [bit r] ++ List.replicate (fromIntegral f) zeroBits
         in fromJust <| fromList <| if i >= n * sa
               then List.replicate (fromIntegral n) zeroBits
               else xs

      testBit u i =
         let
            n      = natValue @n
            sa     = natValue @(BitSize a)
            (f,r)  = i `divMod` sa
            toDrop = fromIntegral (n - f - 1)
         in if i >= n * sa
               then False
               else testBit (List.head (List.drop toDrop (toList u))) r

-- NOTE(review): no method definitions are visible for this instance in this
-- part of the file.
instance
   ( Storable a
   , Bits a
   , KnownNat n
   , KnownNat (n * BitSize a)
   ) => RotatableBits (Vector n a)