module Haskus.Format.Binary.Vector
( Vector (..)
, vectorBuffer
, take
, drop
, index
, fromList
, fromFilledList
, fromFilledListZ
, toList
, replicate
, concat
)
where
import Prelude hiding (replicate, head, last,
tail, init, map, length, drop, take, concat)
import System.IO.Unsafe (unsafePerformIO)
import qualified Haskus.Utils.List as List
import Haskus.Utils.Types
import Haskus.Utils.HList
import Haskus.Format.Binary.Storable
import Haskus.Format.Binary.Ptr
import Haskus.Format.Binary.Buffer
-- | A vector of @n@ elements of type @a@ backed by a 'Buffer'.
--
-- The element count @n@ exists only at the type level; the constructor
-- itself performs no check that the buffer holds @n@ elements.
data Vector (n :: Nat) a = Vector Buffer
-- | Renders a vector as @fromList [x0,x1,...]@.
--
-- Implemented with 'showsPrec' so that the text is parenthesised when the
-- vector appears as an argument of another constructor
-- (e.g. @Just (fromList [1,2])@), following the convention of derived
-- 'Show' instances. At precedence 0 (plain 'show') the output is the same
-- @"fromList " ++ show (toList v)@ as before.
instance (Storable a, Show a, KnownNat n) => Show (Vector n a) where
   showsPrec d v = showParen (d > 10) $
      showString "fromList " . shows (toList v)
-- | Return the 'Buffer' wrapped by a vector (the field itself, unmodified).
vectorBuffer :: Vector n a -> Buffer
vectorBuffer v = case v of
   Vector buf -> buf
-- | Byte offset of element @idx@ in a vector of @len@ elements of type @e@.
--
-- When @idx + 1 <= len@ this is @idx * SizeOf e@; otherwise the branch
-- holding a custom 'TypeError' is selected (assuming 'IfNat' picks its
-- second argument when the condition holds), so out-of-range indices are
-- reported at compile time.
type family ElemOffset e idx len where
   ElemOffset e idx len =
      IfNat (idx + 1 <=? len)
         (idx * (SizeOf e))
         (TypeError ('Text "Invalid vector index: " ':<>: 'ShowType idx
                     ':$$: 'Text "Vector size: " ':<>: 'ShowType len))
-- | Static storage description: a vector occupies @SizeOf a * n@ bytes and
-- uses the alignment of its element type.
instance forall a n.
   ( KnownNat (SizeOf a * n)
   ) => StaticStorable (Vector n a) where
   type SizeOf (Vector n a)    = SizeOf a * n
   type Alignment (Vector n a) = Alignment a

   -- Pack @SizeOf a * n@ bytes found at the source pointer into a Buffer.
   staticPeekIO src =
      fmap Vector (bufferPackPtr (natValue @(SizeOf a * n)) (castPtr src))

   -- Write the backing buffer to the destination pointer.
   staticPokeIO dst v = case v of
      Vector buf -> bufferPoke dst buf
-- | Runtime storage description: a vector takes @n * sizeOf a@ bytes and
-- is aligned like its element type.
instance forall a n.
   ( KnownNat n
   , Storable a
   ) => Storable (Vector n a) where
   sizeOf _    = natValue @n * sizeOfT @a
   alignment _ = alignmentT @a

   -- Pack exactly one vector's worth of bytes from the source pointer.
   peekIO src = do
      buf <- bufferPackPtr (sizeOfT' @(Vector n a)) (castPtr src)
      return (Vector buf)

   -- Write the backing buffer to the destination pointer.
   pokeIO dst v = case v of
      Vector buf -> bufferPoke dst buf
-- | Keep the first @n@ elements of a vector holding @m + n@ elements.
--
-- The cut is made on the backing buffer at byte @SizeOf a * n@.
take :: forall n m a.
   ( KnownNat (SizeOf a * n)
   ) => Vector (m+n) a -> Vector n a
take (Vector buf) = Vector (bufferTake prefixBytes buf)
   where
      -- byte length of the kept prefix
      prefixBytes = natValue @(SizeOf a * n)
-- | Discard the first @n@ elements of a vector holding @m + n@ elements,
-- leaving the remaining @m@.
--
-- The cut is made on the backing buffer at byte @SizeOf a * n@.
drop :: forall n m a.
   ( KnownNat (SizeOf a * n)
   ) => Vector (m+n) a -> Vector m a
drop (Vector buf) = Vector (bufferDrop skippedBytes buf)
   where
      -- byte length of the discarded prefix
      skippedBytes = natValue @(SizeOf a * n)
-- | Read the element at the type-level index @i@ (supplied with a type
-- application, e.g. @index \@2 v@).
--
-- The byte offset comes from 'ElemOffset', which rejects indices that are
-- not smaller than @n@ at compile time.
index :: forall i a n.
   ( KnownNat (ElemOffset a i n)
   , Storable a
   ) => Vector n a -> a
index (Vector buf) = bufferPeekStorableAt buf byteOffset
   where
      byteOffset = natValue @(ElemOffset a i n)
-- | Build a vector from a list whose length must be exactly @n@.
--
-- Returns 'Nothing' when the list length differs from @n@. For @n = 0@
-- (and an empty list) the vector is backed by 'emptyBuffer'.
fromList :: forall a (n :: Nat) .
   ( KnownNat n
   , Storable a
   ) => [a] -> Maybe (Vector n a)
fromList xs
   | actual /= expected = Nothing
   | expected == 0      = Just (Vector emptyBuffer)
   | otherwise          = Just (Vector (bufferPackStorableList xs))
   where
      -- length required by the type
      expected = natValue' @n
      -- length of the supplied list
      actual   = fromIntegral (List.length xs)
-- | Build a vector of exactly @n@ elements from a list.
--
-- Extra elements are dropped, and missing ones are filled with @z@.
fromFilledList :: forall a (n :: Nat) .
   ( KnownNat n
   , Storable a
   ) => a -> [a] -> Vector n a
fromFilledList z v =
   Vector (bufferPackStorableList (List.take (natValue @n) padded))
   where
      -- the input followed by an endless run of the filler
      padded = v ++ repeat z
-- | Like 'fromFilledList', but keeps at most @n - 1@ elements of the list.
--
-- The last slot of the vector therefore always holds the filler @z@
-- (for example a terminating zero).
fromFilledListZ :: forall a (n :: Nat) .
   ( KnownNat n
   , Storable a
   ) => a -> [a] -> Vector n a
fromFilledListZ z v = fromFilledList z v'
   where
      -- Reserve the final slot for the filler. When n = 0 the count is
      -- non-positive (or, for an unsigned count type, wraps). Either way
      -- 'fromFilledList' truncates to zero elements, so the result is
      -- still correct.
      v' = List.take (natValue @n - 1) v
-- | Read all @n@ elements of the vector into a list, in index order.
--
-- Element @i@ is read from byte offset @i * sizeOf a@ of the backing buffer.
toList :: forall a (n :: Nat) .
   ( KnownNat n
   , Storable a
   ) => Vector n a -> [a]
toList (Vector b)
   | n == 0    = []
   | otherwise = fmap (bufferPeekStorableAt b . (sza *)) [0 .. n - 1]
   where
      -- The n == 0 guard keeps @n - 1@ from underflowing if the index type
      -- is unsigned (presumably sizeOfT' yields a Word -- TODO confirm).
      n   = natValue @n
      sza = sizeOfT' @a
-- | A vector whose @n@ elements are all the given value.
--
-- It pads an empty list with that value.
replicate :: forall a (n :: Nat) .
   ( KnownNat n
   , Storable a
   ) => a -> Vector n a
replicate = flip fromFilledList []
-- | Tag selecting the 'Apply' instance that 'concat' folds over an 'HList'
-- of vectors to write each one into the output memory.
data StoreVector = StoreVector
-- | One step of the 'StoreVector' fold used by 'concat'.
--
-- 'concat' runs 'hFoldr' starting from a pointer to the END of its freshly
-- allocated memory. Each step does three things:
--
-- * runs the previous action to get the current pointer;
-- * moves the pointer BACK by this vector's byte size (@n * sizeOf a@);
-- * pokes the vector there and returns the new pointer.
--
-- Because of the right fold, the last vector is written at the end, and
-- the vectors end up laid out in 'HList' order.
instance forall n v a r.
   ( v ~ Vector n a
   , r ~ IO (Ptr a)
   , KnownNat n
   , KnownNat (SizeOf a)
   , StaticStorable a
   , Storable a
   ) => Apply StoreVector (v, IO (Ptr a)) r where
   apply _ (v, getP) = do
      p <- getP
      let
         vsz = natValue @n
         -- Step backwards. A positive offset would write past the end of
         -- the allocation made by 'concat'. If the offset type is unsigned,
         -- this relies on wrap-around, exactly like @-1 * ...@ would.
         p'  = p `indexPtr'` negate (vsz * sizeOfT @a)
      poke (castPtr p') v
      return p'
-- | Total number of elements across a type-level list of vectors.
type family WholeSize vectors :: Nat where
   WholeSize '[]                        = 0
   WholeSize (Vector len _elem ': rest) = len + WholeSize rest
-- | Concatenate an 'HList' of vectors sharing the element type @a@ into a
-- single vector of @WholeSize l@ elements.
--
-- @sizeOf a * n@ bytes are allocated with 'mallocBytes'. The 'StoreVector'
-- fold then fills them from the end backwards, starting at a pointer just
-- past the allocation.
--
-- 'unsafePerformIO' is used because the only effects are allocating and
-- filling fresh memory that is then wrapped in the result.
-- NOTE(review): assumes 'bufferUnsafePackPtr' takes ownership of the
-- malloc'd memory -- confirm.
concat :: forall l (n :: Nat) a .
   ( n ~ WholeSize l
   , KnownNat n
   , Storable a
   , StaticStorable a
   , HFoldr StoreVector (IO (Ptr a)) l (IO (Ptr a))
   )
   => HList l -> Vector n a
concat vs = unsafePerformIO $ do
   base <- mallocBytes (fromIntegral totalBytes) :: IO (Ptr ())
   let atEnd = return (castPtr base `indexPtr'` totalBytes) :: IO (Ptr a)
   _ <- hFoldr StoreVector atEnd vs :: IO (Ptr a)
   Vector <$> bufferUnsafePackPtr (fromIntegral totalBytes) base
   where
      -- byte size of the concatenated vector
      totalBytes = sizeOfT @a * natValue @n