{- |
If you have a 'Trav.Traversable' instance of a record,
you can load and store all of its elements
that are accessible through the @Traversable@ methods.
We treat the record like an array,
that is, we assume that all elements have the same size and alignment.

Example:

> import qualified Foreign.Storable.Traversable as Store
> import Foreign.Storable (Storable (..))
> import qualified Data.Traversable as Trav
> import Control.Applicative (liftA2)
>
> data Stereo a = Stereo {left, right :: a}
>
> instance Functor Stereo where
>    fmap = Trav.fmapDefault
>
> instance Foldable Stereo where
>    foldMap = Trav.foldMapDefault
>
> instance Traversable Stereo where
>    sequenceA ~(Stereo l r) = liftA2 Stereo l r
>
> instance (Storable a) => Storable (Stereo a) where
>    sizeOf = Store.sizeOf
>    alignment = Store.alignment
>    peek = Store.peek (error "instance Traversable Stereo is lazy, so we do not provide a real value here")
>    poke = Store.poke

You would certainly not define the 'Trav.Traversable' instance and the accompanying
'Functor' and 'Foldable' instances just for the implementation of the 'Storable' instance,
but there are usually other applications
where the @Traversable@ instance is useful.
-}
module Foreign.Storable.Traversable (
   alignment, sizeOf,
   peek, poke,
   peekApplicative,
   ) where

import qualified Data.Traversable as Trav
import qualified Data.Foldable as Fold
import qualified Control.Applicative as App

-- ToDo: Maybe we should use State.Strict instead?
import Control.Monad.Trans.State
          (StateT, evalStateT, get, put, modify, )
import Control.Monad.IO.Class (liftIO, )

import Foreign.Storable.FixedArray (roundUp, )
import qualified Foreign.Storable as St

import Foreign.Ptr (Ptr, castPtr, )
import Foreign.Storable (Storable, )
import Foreign.Marshal.Array (advancePtr, )


{-# INLINE elementType #-}
-- | Type-level accessor for the element type of a container.
--
-- Only the /type/ of the result is meaningful; the argument is ignored and
-- evaluating the result raises an error.  It is meant to be passed to
-- 'St.sizeOf' and 'St.alignment', whose element instances therefore must
-- not look at their argument's value.
elementType :: f a -> a
elementType _container =
   error message
   where
      message =
         "Storable.Traversable.alignment and sizeOf may not depend on element values"

{-# INLINE alignment #-}
-- | Alignment of the whole container, taken to be the alignment of one
-- element as reported by the element type's 'Storable' instance.
--
-- The container value itself is not inspected (it only feeds
-- 'elementType').  The 'Fold.Foldable' constraint is not needed by the
-- body; it is kept so the signature parallels 'sizeOf'.
alignment ::
   (Fold.Foldable f, Storable a) =>
   f a -> Int
alignment =
   \container -> St.alignment (elementType container)

{-# INLINE sizeOf #-}
{- |
Size in bytes of the whole container:
the number of elements multiplied by the per-element stride,
where the stride is computed with 'roundUp' from the element 'alignment'
and the element 'St.sizeOf'
(presumably the size rounded up to a multiple of the alignment).

Warning:
The elements are counted via the 'Fold.Foldable' class,
so this will certainly evaluate the container structure
and thus will fail on 'undefined'.

You may call it on the record after re-constructing it lazily:

> sizeOf . lazy
>
> lazy :: Complex a -> Complex a
> lazy ~(r:+i) = r:+i

NOTE(review): 'peek' and 'poke' in this module step through memory with
@advancePtr p 1@, i.e. by the plain element 'St.sizeOf'.
That agrees with the stride used here only when the element size
is already a multiple of its alignment.
-}
sizeOf ::
   (Fold.Foldable f, Storable a) =>
   f a -> Int
sizeOf f =
   let count = Fold.foldl' (\n _ -> n + 1) 0 f
       stride = roundUp (alignment f) (St.sizeOf (elementType f))
   in  count * stride


{- |
@peek skeleton ptr@ fills the @skeleton@ with data read from memory beginning at @ptr@.
The skeleton is needed formally for using 'Trav.Traversable':
only its shape is used, and its @()@ elements are never inspected.
For instance, when reading a list it is not otherwise clear
how many elements shall be read.
Using the skeleton you can give this information,
and you can also provide structure that is not contained in the element type @a@.
For example, you can call

> peek (replicate 10 ()) ptr

for reading 10 elements from memory starting at @ptr@.
Elements are read in traversal order from consecutive positions,
advancing the pointer by one element after each read.
-}
{-# INLINE peek #-}
peek ::
   (Trav.Traversable f, Storable a) =>
   f () -> Ptr (f a) -> IO (f a)
peek skeleton =
   \ptr ->
      -- the lambda ignores its argument on purpose: skeleton elements
      -- may be undefined and must not be forced
      evalStateT
         (Trav.forM skeleton (\_ -> peekState))
         (castPtr ptr)

{- |
Like 'peek', but the shape of the result is produced by 'App.pure'
instead of being taken from a skeleton:
one element is read for every element that @pure@ creates.
'pure' would be in class @Pointed@ if that existed,
thus we use the closest approximation, 'Applicative'.
-}
{-# INLINE peekApplicative #-}
peekApplicative ::
   (App.Applicative f, Trav.Traversable f, Storable a) =>
   Ptr (f a) -> IO (f a)
peekApplicative =
   \ptr ->
      evalStateT
         (Trav.sequence (App.pure peekState))
         (castPtr ptr)

{-# INLINE peekState #-}
-- | Read one element at the current pointer (the state) and move the state
-- one element forward ('advancePtr' by 1) for the next read.
peekState ::
   (Storable a) =>
   StateT (Ptr a) IO a
peekState = do
   current <- get
   put (advancePtr current 1)
   liftIO (St.peek current)

{-# INLINE poke #-}
-- | Store all elements of the container, in 'Fold.Foldable' order,
-- into consecutive element positions starting at the given pointer.
poke ::
   (Fold.Foldable f, Storable a) =>
   Ptr (f a) -> f a -> IO ()
poke ptr x =
   evalStateT (Fold.for_ x pokeState) (castPtr ptr)

{-# INLINE pokeState #-}
-- | Write one element at the current pointer (the state), then move the
-- state one element forward ('advancePtr' by 1) for the next write.
pokeState ::
   (Storable a) =>
   a -> StateT (Ptr a) IO ()
pokeState x = do
   target <- get
   liftIO (St.poke target x)
   modify (`advancePtr` 1)