module Cgm.Data.Array (
ArrayIx,
Array(..),
ImmArray(..),
STArray(..),
STMkArray(..),
STZArray(..),
PinnedArray(..),
Allocation(..),
newAlignedPinnedWord8Array,
withArrayByteString,
arrayFromByteString,
unsafeWithCastArrayByteString,
Free,
Pinned,
primArrayMatchAllocation,
PrimArray,
runSTPrimArray,
unsafeFreezeSTPrimArray,
STPrimArray,
ArrayRange(..),
fullArrayRange,
skipArrayRange,
headArrayRange,
wordConvArrayRange,
unsafePrimArrayCast,
unsafeSTPrimArrayCast,
mapMArray,
mapMArrayCopy,
mapMArrayCopyImm,
mapImmArray,
mkArrayWith,
ioToST,
primSizeOf,
ST(ST),
module Control.Monad.ST,
module Data.Primitive.Types
) where
import Prelude ()
import Cgm.Prelude
import GHC.IO( IO(..), ioToST )
import GHC.ST( ST(..) )
import GHC.Exts
import GHC.Prim
import qualified Data.ByteString as ByteString
import Data.ByteString.Unsafe
import Foreign.Storable
import Foreign.Ptr
import Control.Monad.ST
import Cgm.Data.Word
import Cgm.Data.Len
import Cgm.System.Mem.Alloc
import Data.Primitive.Types
type ArrayIx a = Len (ArrayElem a) Word
type ArrayST a = ST (ArrayState a)
class Array a where
type ArrayElem a :: *
arrayLen :: a -> ArrayIx a
class Array a => ImmArray a where
indexArray :: a -> ArrayIx a -> ArrayElem a
class Array a => STArray a where
type ArrayState a :: *
class STArray a => STMkArray a where
mkArray :: ArrayIx a -> ArrayST a a
class STArray a => STZArray a where
readArray :: a -> ArrayIx a -> ArrayST a (ArrayElem a)
writeArray :: a -> ArrayIx a -> ArrayElem a -> ArrayST a ()
class Array a => PinnedArray a where
withArrayPtr :: (Ptr (ArrayElem a) -> ArrayIx a -> IO b) -> a -> IO b
class Allocation f where
newWord8Array :: Len Word8 Word -> ST s (STPrimArray s f Word8)
onAllocation :: (f ~ Free => z) -> (f ~ Pinned => z) -> Tagged f z
withArrayByteString :: (PinnedArray a, ArrayElem a ~ Word8) => (ByteString.ByteString -> IO b) -> a -> IO b
withArrayByteString f a = flip withArrayPtr a $ \ptr len -> (unsafePackCStringFinalizer ptr (fromIntegral $ getLen len) $ return ()) >>= f
unsafeWithCastArrayByteString :: forall a b. (PinnedArray a, Prim (ArrayElem a)) => (ByteString.ByteString -> IO b) -> a -> IO b
unsafeWithCastArrayByteString f a = flip withArrayPtr a $
\ptr len -> (unsafePackCStringFinalizer (castPtr ptr) (fromIntegral (getLen len) * primSizeOf (undefined :: ArrayElem a)) $ return ()) >>= f
arrayFromByteString :: Allocation f => ByteString.ByteString -> PrimArray f Word8
arrayFromByteString b = mkArrayWith (unsafeLen $ fromIntegral $ ByteString.length b) $ ByteString.index b . fromIntegral . getLen
arrayToList :: ImmArray a => a -> [ArrayElem a]
arrayToList a = fmap (indexArray a) $ counting $ arrayLen a
immArrayEq :: (ImmArray a, Eq (ArrayElem a), Prim (ArrayElem a)) => a -> a -> Bool
immArrayEq a b = (arrayLen a == arrayLen b) && (and $ fmap (\i -> indexArray a i == indexArray b i) (counting $ min (arrayLen a) (arrayLen b)))
st_ :: (State# s -> State# s) -> ST s ()
st_ f = ST $ \s -> (# f s, () #)
data Free
data Pinned
instance Allocation Free where
newWord8Array l = ST $ \s -> case newByteArray# (unI# (fromIntegral $ getLen l)) s of (# s', a #) -> (# s', STPrimArray a #)
onAllocation f p = tag f
instance Allocation Pinned where
newWord8Array l = ST $ \s -> case newPinnedByteArray# (unI# (fromIntegral $ getLen l)) s of (# s', a #) -> (# s', STPrimArray a #)
onAllocation f p = tag p
newAlignedPinnedWord8Array :: Len Word8 Word -> Len Word8 Word -> ST s (STPrimArray s f Word8)
newAlignedPinnedWord8Array al l = ST $ \s ->
case newAlignedPinnedByteArray# (unI# (fromIntegral $ getLen l)) (unI# (fromIntegral $ getLen al)) s of (# s', a #) -> (# s', STPrimArray a #)
primArrayMatchAllocation :: forall f f' w. (Allocation f, Allocation f', Prim w) => ArrayRange (PrimArray f w) -> ArrayRange (PrimArray f' w)
primArrayMatchAllocation = (at :: At f) $ onAllocation
((at :: At f') $ onAllocation id (fullArrayRange . mapImmArray id))
((at :: At f') $ onAllocation (fullArrayRange . mapImmArray id) id)
primSizeOf :: Prim w => w -> Int
primSizeOf w = I# (sizeOf# w)
data STPrimArray s f w = STPrimArray (MutableByteArray# s)
instance Eq (STPrimArray s f w) where (STPrimArray a) == (STPrimArray b) = sameMutableByteArray# a b
instance Prim w => Array (STPrimArray s f w) where
type ArrayElem (STPrimArray s f w) = w
arrayLen (STPrimArray a) = unsafeLen $ fromIntegral $ I# (sizeofMutableByteArray# a `quotInt#` sizeOf# (undefined :: w))
instance Prim w => STArray (STPrimArray s f w) where
type ArrayState (STPrimArray s f w) = s
instance (Allocation f, Prim w) => STMkArray (STPrimArray s f w) where
mkArray l = unsafeSTPrimArrayCast <$> newWord8Array (unsafeLen $ getLen l * fromIntegral (I# (sizeOf# (undefined :: w))))
instance Prim w => STZArray (STPrimArray s f w) where
readArray (STPrimArray a) i = ST $ readByteArray# a (word2Int# (unW# (getLen i)))
writeArray (STPrimArray a) i w = st_ $ writeByteArray# a (word2Int# (unW# (getLen i))) w
instance Prim w => PinnedArray (STPrimArray s Pinned w) where
withArrayPtr f arr@(STPrimArray a) = f (Ptr (byteArrayContents# (unsafeCoerce# a))) (arrayLen arr) >>= (<$ (IO $ \s -> (# touch# a s, () #)))
data PrimArray f w = PrimArray ByteArray#
instance (Prim w, Eq w) => Eq (PrimArray f w) where (==) = immArrayEq
instance (Prim w, Show w) => Show (PrimArray f w) where show = show . arrayToList
instance Prim w => Array (PrimArray f w) where
type ArrayElem (PrimArray f w) = w
arrayLen (PrimArray a) = unsafeLen $ fromIntegral $ I# (sizeofByteArray# a `quotInt#` sizeOf# (undefined :: w))
instance Prim w => ImmArray (PrimArray f w) where
indexArray (PrimArray a) i = indexByteArray# a (word2Int# (unW# (getLen i)))
instance Prim w => PinnedArray (PrimArray Pinned w) where
withArrayPtr f arr@(PrimArray a) = f (Ptr (byteArrayContents# a)) (arrayLen arr) >>= (<$ (IO $ \s -> (# touch# a s, () #)))
runSTPrimArray :: (forall s . ST s (STPrimArray s f w)) -> PrimArray f w
runSTPrimArray st = runST (st >>= unsafeFreezeSTPrimArray)
unsafeFreezeSTPrimArray :: STPrimArray s f w -> ST s (PrimArray f w)
unsafeFreezeSTPrimArray (STPrimArray a) = ST (\s -> case unsafeFreezeByteArray# a s of (# s', a' #) -> (#s', PrimArray a' #) )
data ArrayRange a = ArrayRange a (ArrayIx a) (ArrayIx a)
instance (ImmArray a, Eq (ArrayElem a), Prim (ArrayElem a)) => Eq (ArrayRange a) where (==) = immArrayEq
instance (ImmArray a, Show (ArrayElem a)) => Show (ArrayRange a) where show = show . arrayToList
instance Array a => Array (ArrayRange a) where
type ArrayElem (ArrayRange a) = ArrayElem a
arrayLen (ArrayRange _ _ len) = len
instance ImmArray a => ImmArray (ArrayRange a) where
indexArray (ArrayRange a s _) i = indexArray a $ s + i
instance STArray a => STArray (ArrayRange a) where
type ArrayState (ArrayRange a) = ArrayState a
instance STZArray a => STZArray (ArrayRange a) where
readArray (ArrayRange a s _) i = readArray a $ i + s
writeArray (ArrayRange a s _) i e = writeArray a (i + s) e
instance (Prim (ArrayElem a), PinnedArray a) => PinnedArray (ArrayRange a) where
withArrayPtr f (ArrayRange a start sz) = withArrayPtr (\ptr _ -> f (ptr `plusPtrLen` start) sz) a
plusPtrLen :: forall a b. (Prim a, Integral b) => Ptr a -> Len a b -> Ptr a
plusPtrLen p = plusPtr p . (* primSizeOf (undefined :: a)) . fromIntegral . getLen
fullArrayRange :: Array a => a -> ArrayRange a
fullArrayRange a = ArrayRange a 0 (arrayLen a)
skipArrayRange :: Array a => ArrayIx a -> ArrayRange a -> ArrayRange a
skipArrayRange n (ArrayRange a s l) = ArrayRange a (s + n) (l n)
headArrayRange :: Array a => ArrayIx a -> ArrayRange a -> ArrayRange a
headArrayRange n (ArrayRange a s l) = ArrayRange a s $ min n l
wordConvArrayRange :: WordConv c => Bijection' (ArrayRange (PrimArray f Word)) (ArrayRange (PrimArray f c))
wordConvArrayRange = uncheckedBijection
(\(ArrayRange a s l) -> ArrayRange (apply wordConv1 a) (apply wordLenB s) (apply wordLenB l))
(\(ArrayRange a s l) -> ArrayRange (retract wordConv1 a) (retract wordLenB s) (retract wordLenB l))
unsafeSTPrimArrayCast :: STPrimArray s f w -> STPrimArray s f w'
unsafeSTPrimArrayCast (STPrimArray a) = STPrimArray a
unsafePrimArrayCast :: PrimArray f w -> PrimArray f w'
unsafePrimArrayCast (PrimArray a) = PrimArray a
instance WordConv1 (PrimArray f) where wordConv1 = uncheckedBijection unsafePrimArrayCast unsafePrimArrayCast
instance WordConv1 (STPrimArray s f) where wordConv1 = uncheckedBijection unsafeSTPrimArrayCast unsafeSTPrimArrayCast
mapMArray :: STZArray a => (ArrayElem a -> ArrayST a (ArrayElem a)) -> a -> ArrayST a ()
mapMArray f a = mapM_ (\i -> readArray a i >>= f >>= writeArray a i) $ counting $ arrayLen a
mapMArrayCopy :: (STZArray a, STZArray b, ArrayState a ~ ArrayState b) => (ArrayElem a -> ArrayST a (ArrayElem b)) -> a -> b -> ArrayST a ()
mapMArrayCopy f src dest = mapM_ (\i -> readArray src (unsafeLen i) >>= f >>= writeArray dest (unsafeLen i)) $
counting $ min (getLen $ arrayLen src) (getLen $ arrayLen dest)
mapMArrayCopyImm :: (ImmArray a, STZArray b) => (ArrayElem a -> ArrayST b (ArrayElem b)) -> a -> b -> ArrayST b ()
mapMArrayCopyImm f src dest = mapM_ (\i -> f (indexArray src $ unsafeLen i) >>= writeArray dest (unsafeLen i)) $
counting $ min (getLen $ arrayLen src) (getLen $ arrayLen dest)
mkArrayWith :: (Allocation f, Prim a) => Len a Word -> (Len a Word -> a) -> PrimArray f a
mkArrayWith l f = runSTPrimArray $ do
m <- mkArray l
mapM_ (\i -> writeArray m i $ f i) $ counting l
return m
mapImmArray :: (Allocation f', Prim b, ImmArray r) => (ArrayElem r -> b) -> r -> PrimArray f' b
mapImmArray f src = mkArrayWith (unsafeLen $ getLen $ arrayLen src) (f . indexArray src . unsafeLen . getLen)