{-# LANGUAGE TypeFamilies #-}
module Data.Array.Comfort.Storable.Mutable.Private where

import qualified Data.Array.Comfort.Shape as Shape

import qualified Foreign.Marshal.Array.Guarded as Alloc
import Foreign.Marshal.Array (copyArray, pokeArray, peekArray)
import Foreign.Storable (Storable, pokeElemOff, peekElemOff)
import Foreign.Ptr (Ptr)

import Control.Monad.Primitive (PrimMonad, unsafeIOToPrim)
import Control.Monad.ST (ST)
import Control.Monad (liftM)
import Control.Applicative ((<$>))

import Data.Either.HT (maybeRight)
import Data.Tuple.HT (mapFst)

import qualified Prelude as P
import Prelude hiding (read, show)


-- | A mutable array backed by one 'Alloc.MutablePtr' buffer.
--
-- The monad parameter @m@ appears in no field; it only tags the monad
-- (for instance @ST s@ or @IO@) in which the array is manipulated.
-- The number of stored elements is @Shape.size@ of the shape.
data Array (m :: * -> *) sh a =
   Array
      { shape :: sh
        -- ^ shape describing the index space of the array
      , buffer :: Alloc.MutablePtr a
        -- ^ storage holding the elements
      }

-- | Mutable array tagged with the 'ST' monad of state thread @s@.
type STArray s = Array (ST s)
-- | Mutable array tagged with 'IO'.
type IOArray = Array IO


-- | Allocate a new array of the same shape and copy every element of the
-- source buffer into it.  The source array is not modified.
copy ::
   (PrimMonad m, Shape.C sh, Storable a) =>
   Array m sh a -> m (Array m sh a)
copy (Array sh srcBuf) = unsafeCreateWithSize sh copyInto
   where
      -- 'count' is the element count supplied by 'unsafeCreateWithSize'.
      copyInto count dst =
         Alloc.withMutablePtr srcBuf $ \src -> copyArray dst src count


-- | Allocate an 'IO' array of the given shape and run the initialiser on
-- a pointer to its storage.  The initialiser does not receive the size;
-- see 'createWithSize' for that.
create ::
   (Shape.C sh, Storable a) =>
   sh -> (Ptr a -> IO ()) -> IO (IOArray sh a)
create sh fill = createWithSize sh (\_size -> fill)

-- | Like 'createWithSizeAndResult', but the initialiser returns nothing
-- of interest and only the array is kept.
createWithSize ::
   (Shape.C sh, Storable a) =>
   sh -> (Int -> Ptr a -> IO ()) -> IO (IOArray sh a)
createWithSize sh fill = fmap fst (createWithSizeAndResult sh fill)

-- | Allocate a buffer for @Shape.size sh@ elements, run the initialiser
-- with that element count and a pointer to the buffer, and return the
-- new array together with the initialiser's result.
createWithSizeAndResult ::
   (Shape.C sh, Storable a) =>
   sh -> (Int -> Ptr a -> IO b) -> IO (IOArray sh a, b)
createWithSizeAndResult sh fill =
   let n = Shape.size sh
   in  Alloc.new n >>= \buf ->
       Alloc.withMutablePtr buf (fill n) >>= \result ->
       return (Array sh buf, result)


-- | Monad-generic variant of 'create': the 'IO' initialiser is lifted
-- into any 'PrimMonad' through 'unsafeCreateWithSize'.
unsafeCreate ::
   (PrimMonad m, Shape.C sh, Storable a) =>
   sh -> (Ptr a -> IO ()) -> m (Array m sh a)
unsafeCreate sh fill = unsafeCreateWithSize sh (\_size -> fill)

-- | Monad-generic variant of 'createWithSize'; only the array from
-- 'unsafeCreateWithSizeAndResult' is kept.
unsafeCreateWithSize ::
   (PrimMonad m, Shape.C sh, Storable a) =>
   sh -> (Int -> Ptr a -> IO ()) -> m (Array m sh a)
unsafeCreateWithSize sh fill =
   fst <$> unsafeCreateWithSizeAndResult sh fill

-- | Monad-generic variant of 'createWithSizeAndResult'.
--
-- The allocation and initialisation run in 'IO' and are embedded into
-- @m@ with 'unsafeIOToPrim'; the resulting 'IOArray' is then retagged
-- to @m@ via 'unsafeArrayIOToPrim'.
unsafeCreateWithSizeAndResult ::
   (PrimMonad m, Shape.C sh, Storable a) =>
   sh -> (Int -> Ptr a -> IO b) -> m (Array m sh a, b)
unsafeCreateWithSizeAndResult sh fill =
   unsafeIOToPrim (retag <$> createWithSizeAndResult sh fill)
   where
      -- lazy pattern: the pair is not forced here
      retag ~(arr, result) = (unsafeArrayIOToPrim arr, result)

-- | Reinterpret an 'IOArray' as an array tagged with another monad.
-- Shape and buffer are reused unchanged; only the phantom tag differs.
unsafeArrayIOToPrim :: (PrimMonad m) => IOArray sh a -> Array m sh a
unsafeArrayIOToPrim (Array {shape = sh, buffer = buf}) =
   Array {shape = sh, buffer = buf}


-- | Render the current contents as
-- @StorableArray.fromList <shape> <element list>@, where the shape is
-- shown at precedence 11 so it is parenthesised when needed.
show ::
   (PrimMonad m, Shape.C sh, Show sh, Storable a, Show a) =>
   Array m sh a -> m String
show arr =
   fmap render (toList arr)
   where
      render elems =
         "StorableArray.fromList "
            ++ showsPrec 11 (shape arr) (' ' : P.show elems)

-- | Run an 'IO' action on the raw pointer of a buffer and embed the
-- action into the 'PrimMonad' via 'unsafeIOToPrim'.
withArrayPtr :: (PrimMonad m) => Alloc.MutablePtr a -> (Ptr a -> IO b) -> m b
withArrayPtr buf act = unsafeIOToPrim (Alloc.withMutablePtr buf act)

-- | Run an 'IO' action on the raw pointer of an array's buffer.
-- The shape is ignored.
withPtr :: (PrimMonad m) => Array m sh a -> (Ptr a -> IO b) -> m b
withPtr (Array {buffer = buf}) act = withArrayPtr buf act

-- | Read the element stored at the given index.
--
-- The index is validated against the shape with 'Shape.unifiedOffset'.
-- If the shape rejects it, 'error' is raised with the shape's message,
-- instead of peeking outside the buffer as an unchecked offset would.
-- Use 'readMaybe' to get 'Nothing' instead of an exception.
read ::
   (PrimMonad m, Shape.Indexed sh, Storable a) =>
   Array m sh a -> Shape.Index sh -> m a
read (Array sh fptr) ix =
   -- force the checked offset before touching memory, so a bad index
   -- fails even for element types whose 'peekElemOff' ignores it
   withArrayPtr fptr $ \ptr -> k `seq` peekElemOff ptr k
   where
      k =
         either (error . ("Storable.Mutable.read: " ++)) id $
         Shape.getChecked $ Shape.unifiedOffset sh ix

-- | Checked read.
--
-- Returns 'Nothing' when the shape rejects the index, i.e. when
-- @Shape.getChecked@ yields 'Left'. Otherwise it returns 'Just' the
-- action that reads the element at the computed offset.
readMaybe ::
   (PrimMonad m, Shape.Indexed sh, Storable a) =>
   Array m sh a -> Shape.Index sh -> Maybe (m a)
readMaybe (Array sh fptr) ix =
   case maybeRight (Shape.getChecked (Shape.unifiedOffset sh ix)) of
      Nothing -> Nothing
      Just k -> Just (withArrayPtr fptr (\ptr -> peekElemOff ptr k))

-- | Store a value at the given index.
--
-- The index is validated against the shape with 'Shape.unifiedOffset'.
-- If the shape rejects it, 'error' is raised instead of poking outside
-- the allocated buffer.  The original unchecked offset allowed
-- silent memory corruption here.
write ::
   (PrimMonad m, Shape.Indexed sh, Storable a) =>
   Array m sh a -> Shape.Index sh -> a -> m ()
write (Array sh fptr) ix a =
   -- force the checked offset before any memory is written
   withArrayPtr fptr $ \ptr -> k `seq` pokeElemOff ptr k a
   where
      k =
         either (error . ("Storable.Mutable.write: " ++)) id $
         Shape.getChecked $ Shape.unifiedOffset sh ix

-- | Replace the element at the given index by @f@ applied to it.
--
-- The index is validated against the shape with 'Shape.unifiedOffset'.
-- If the shape rejects it, 'error' is raised instead of reading and
-- writing outside the allocated buffer.
update ::
   (PrimMonad m, Shape.Indexed sh, Storable a) =>
   Array m sh a -> Shape.Index sh -> (a -> a) -> m ()
update (Array sh fptr) ix f =
   withArrayPtr fptr $ \ptr ->
      -- force the checked offset before touching memory
      k `seq` (pokeElemOff ptr k . f =<< peekElemOff ptr k)
   where
      k =
         either (error . ("Storable.Mutable.update: " ++)) id $
         Shape.getChecked $ Shape.unifiedOffset sh ix

-- | Allocate an array of the given shape with every element set to @x@.
new :: (PrimMonad m, Shape.C sh, Storable a) => sh -> a -> m (Array m sh a)
new sh x = unsafeCreateWithSize sh fillWith
   where
      fillWith count ptr = pokeArray ptr (replicate count x)

-- | Read all @Shape.size@ elements of the array into a list, in buffer
-- order.
toList :: (PrimMonad m, Shape.C sh, Storable a) => Array m sh a -> m [a]
toList (Array {shape = sh, buffer = buf}) =
   withArrayPtr buf (\ptr -> peekArray (Shape.size sh) ptr)

-- | Create an array of the given shape from a list of elements.
--
-- At most @Shape.size sh@ elements are consumed.  A longer (or even
-- infinite) list is truncated rather than written past the end of the
-- allocated buffer, which the previous @pokeArray ptr xs@ did.
-- If the list is shorter, the remaining slots are not written by this
-- function.
fromList ::
   (PrimMonad m, Shape.C sh, Storable a) => sh -> [a] -> m (Array m sh a)
fromList sh xs =
   unsafeCreateWithSize sh $ \size ptr -> pokeArray ptr (take size xs)

-- | Create a one-dimensional array with a zero-based 'Int' shape whose
-- size equals the length of the list, filled with the list's elements.
vectorFromList ::
   (PrimMonad m, Storable a) => [a] -> m (Array m (Shape.ZeroBased Int) a)
vectorFromList xs =
   let n = length xs
   in  unsafeCreate (Shape.ZeroBased n) (\ptr -> pokeArray ptr xs)