{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{- |
The functions in this module perform no bounds checking:
callers are responsible for supplying valid indices and compatible shapes.
-}
module Data.Array.Comfort.Storable.Unchecked (
   Priv.Array(Array, shape, buffer),
   Priv.reshape,
   mapShape,

   (Priv.!),
   unsafeCreate,
   unsafeCreateWithSize,
   unsafeCreateWithSizeAndResult,
   Priv.toList,
   Priv.fromList,
   Priv.vectorFromList,

   map,
   mapWithIndex,
   zipWith,
   (Priv.//),
   Priv.accumulate,
   Priv.fromAssociations,

   singleton,
   append,
   take, drop,
   takeLeft, takeRight, split,
   takeCenter,

   sum, product,
   foldl,
   ) where

import qualified Data.Array.Comfort.Storable.Unchecked.Monadic as Monadic
import qualified Data.Array.Comfort.Storable.Private as Priv
import qualified Data.Array.Comfort.Storable.Memory as Memory
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Private (Array(Array), mapShape)
import Data.Array.Comfort.Shape ((::+)((::+)))

import System.IO.Unsafe (unsafePerformIO)
import Foreign.Marshal.Array (copyArray, advancePtr)
import Foreign.Storable (Storable, poke, peek)
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr)

import Control.Monad.ST (runST)
import Control.Applicative (liftA2)

import qualified Data.List as List

import Prelude hiding (map, zipWith, foldl, take, drop, sum, product)


{- $setup
>>> import DocTest.Data.Array.Comfort.Storable (ShapeInt, genArray)
>>>
>>> import qualified Data.Array.Comfort.Storable as Array
>>> import qualified Data.Array.Comfort.Shape as Shape
>>> import Data.Array.Comfort.Storable (Array, (!))
>>>
>>> import qualified Test.QuickCheck as QC
>>>
>>> import Data.Word (Word16)
>>>
>>> newtype Array16 = Array16 (Array ShapeInt Word16)
>>>    deriving (Show)
>>>
>>> instance QC.Arbitrary Array16 where
>>>    arbitrary = fmap Array16 genArray
-}


{- |
Allocate an array of the given shape and fill it with an 'IO' initialiser
that receives a pointer to the new buffer.

The initialiser is run through 'Monadic.unsafeCreate' inside 'runST'.
It should therefore have no observable effects other than writing the buffer.
NOTE(review): the initialiser is presumably expected to write all
@Shape.size sh@ elements. Confirm in
"Data.Array.Comfort.Storable.Unchecked.Monadic".
-}
unsafeCreate ::
   (Shape.C sh, Storable a) =>
   sh -> (Ptr a -> IO ()) -> Array sh a
unsafeCreate sh arr = runST (Monadic.unsafeCreate sh arr)

{- |
Like 'unsafeCreate', but the initialiser also receives an 'Int' element count.

NOTE(review): the count is presumably @Shape.size sh@, since 'takeCenter'
uses it as the number of elements to copy. Confirm in the monadic
implementation.
-}
unsafeCreateWithSize ::
   (Shape.C sh, Storable a) =>
   sh -> (Int -> Ptr a -> IO ()) -> Array sh a
unsafeCreateWithSize sh arr = runST (Monadic.unsafeCreateWithSize sh arr)

{- |
Like 'unsafeCreateWithSize', but the initialiser may also produce a result.
That result is returned alongside the created array.

The initialiser is run through 'Monadic.unsafeCreateWithSizeAndResult'
inside 'runST'.
NOTE(review): the 'Int' argument is presumably @Shape.size sh@. Confirm in
the monadic implementation.
-}
unsafeCreateWithSizeAndResult ::
   (Shape.C sh, Storable a) =>
   sh -> (Int -> Ptr a -> IO b) -> (Array sh a, b)
unsafeCreateWithSizeAndResult sh arr =
   runST (Monadic.unsafeCreateWithSizeAndResult sh arr)


{- |
Apply a function to every element. The result has the same shape as the input.

The source and destination buffers are walked in lock-step, one element
at a time, for exactly @Shape.size sh@ elements.
-}
map ::
   (Shape.C sh, Storable a, Storable b) =>
   (a -> b) -> Array sh a -> Array sh b
map f (Array sh a) =
   unsafeCreate sh $ \dstPtr ->
   withForeignPtr a $ \srcPtr ->
   -- Both pointer streams are infinite; 'List.take' bounds the loop.
   sequence_ $ List.take (Shape.size sh) $
      List.zipWith
         (\src dst -> poke dst . f =<< peek src)
         (iterate (flip advancePtr 1) srcPtr)
         (iterate (flip advancePtr 1) dstPtr)

{- |
Like 'map', but the function also receives the index of each element.

Indices are taken from @Shape.indices sh@ and paired in order with the
source and destination element pointers. Both pointer streams are
infinite, so one element is processed per index.
NOTE(review): this assumes @Shape.indices sh@ lists exactly
@Shape.size sh@ indices in storage order. Confirm for custom shapes.
-}
mapWithIndex ::
   (Shape.Indexed sh, Shape.Index sh ~ ix, Storable a, Storable b) =>
   (ix -> a -> b) -> Array sh a -> Array sh b
mapWithIndex f (Array sh a) =
   unsafeCreate sh $ \dstPtr ->
   withForeignPtr a $ \srcPtr ->
   -- The finite index list is what terminates this 'zipWith3'.
   sequence_ $
      List.zipWith3
         (\ix src dst -> poke dst . f ix =<< peek src)
         (Shape.indices sh)
         (iterate (flip advancePtr 1) srcPtr)
         (iterate (flip advancePtr 1) dstPtr)

{- |
Combine two arrays element-wise.

The result takes the shape of the /second/ array; the first array's shape
is ignored. @Shape.size sh@ elements are read from both buffers. As
everywhere in this module, nothing checks this: the caller must ensure the
first array holds at least as many elements as the second.
-}
zipWith ::
   (Shape.C sh, Storable a, Storable b, Storable c) =>
   (a -> b -> c) -> Array sh a -> Array sh b -> Array sh c
zipWith f (Array _sh a) (Array sh b) =
   unsafeCreate sh $ \dstPtr ->
   withForeignPtr a $ \srcAPtr ->
   withForeignPtr b $ \srcBPtr ->
   -- All three pointer streams are infinite; 'List.take' bounds the loop.
   sequence_ $ List.take (Shape.size sh) $
      List.zipWith3
         (\srcA srcB dst -> poke dst =<< liftA2 f (peek srcA) (peek srcB))
         (iterate (flip advancePtr 1) srcAPtr)
         (iterate (flip advancePtr 1) srcBPtr)
         (iterate (flip advancePtr 1) dstPtr)


{- |
Create an array of shape @()@ that holds a single element.

prop> \x  ->  Array.singleton x ! () == (x::Word16)
-}
singleton :: (Storable a) => a -> Array () a
singleton a = unsafeCreate () $ \ptr -> poke ptr a


infixr 5 `append`

{- |
Concatenate two arrays into a fresh array of shape @shX ::+ shY@.

All elements of the first array are copied first, immediately followed
by all elements of the second array.
-}
append ::
   (Shape.C shx, Shape.C shy, Storable a) =>
   Array shx a -> Array shy a -> Array (shx::+shy) a
append (Array shX x) (Array shY y) =
   unsafeCreate (shX::+shY) $ \zPtr ->
   withForeignPtr x $ \xPtr ->
   withForeignPtr y $ \yPtr -> do
      let sizeX = Shape.size shX
          sizeY = Shape.size shY
      copyArray zPtr xPtr sizeX
      -- The second part starts right after the first @sizeX@ elements.
      copyArray (advancePtr zPtr sizeX) yPtr sizeY

{- |
@take n@ keeps the first @n@ elements of a zero-based array, and @drop n@
removes them. Both first split the shape with 'splitN', then copy out the
left ('takeLeft') or right ('takeRight') part.

The property below holds for every non-negative @n@, including values
larger than the array. That suggests 'Shape.zeroBasedSplit' clips @n@.
NOTE(review): negative @n@ is not covered by the property.

prop> \(QC.NonNegative n) (Array16 x)  ->  x == Array.mapShape (Shape.ZeroBased . Shape.size) (Array.append (Array.take n x) (Array.drop n x))
-}
take, drop ::
   (Integral n, Storable a) =>
   n -> Array (Shape.ZeroBased n) a -> Array (Shape.ZeroBased n) a
take n = takeLeft . splitN n
drop n = takeRight . splitN n

{- |
View a zero-based array as the concatenation of two zero-based parts,
split at position @n@ by 'Shape.zeroBasedSplit'.

Only the shape is transformed, via 'mapShape'. 'mapShape' has no 'Storable'
constraint, so presumably the element buffer is reused unchanged.
-}
splitN ::
   (Integral n, Storable a) =>
   n -> Array (Shape.ZeroBased n) a ->
   Array (Shape.ZeroBased n ::+ Shape.ZeroBased n) a
splitN n = mapShape (Shape.zeroBasedSplit n)

{- |
Copy the left part of a concatenated array into a fresh array.

The shape @sh0 ::+ sh1@ is viewed as @Zero ::+ sh0 ::+ sh1@. 'takeCenter'
then extracts the middle part, which is @sh0@.

prop> \(Array16 x) (Array16 y) -> let xy = Array.append x y in x == Array.takeLeft xy  &&  y == Array.takeRight xy
-}
takeLeft ::
   (Shape.C sh0, Shape.C sh1, Storable a) =>
   Array (sh0::+sh1) a -> Array sh0 a
takeLeft =
   takeCenter . mapShape (\(sh0 ::+ sh1) -> (Shape.Zero ::+ sh0 ::+ sh1))

{- |
Copy the right part of a concatenated array into a fresh array.

The shape @sh0 ::+ sh1@ is viewed as @sh0 ::+ sh1 ::+ Zero@. 'takeCenter'
then extracts the middle part, which is @sh1@.
-}
takeRight ::
   (Shape.C sh0, Shape.C sh1, Storable a) =>
   Array (sh0::+sh1) a -> Array sh1 a
takeRight =
   takeCenter . mapShape (\(sh0 ::+ sh1) -> (sh0 ::+ sh1 ::+ Shape.Zero))

{- |
Split a concatenated array into its left and right parts.

Each part is produced independently by 'takeLeft' and 'takeRight', so each
is a separate copy.
-}
split ::
   (Shape.C sh0, Shape.C sh1, Storable a) =>
   Array (sh0::+sh1) a -> (Array sh0 a, Array sh1 a)
split x = (takeLeft x, takeRight x)

{- |
Copy the middle part of a three-way concatenation into a fresh array of
shape @sh1@.

Copying starts after the first @Shape.size sh0@ elements. The number of
elements copied is the count that 'unsafeCreateWithSize' passes for
shape @sh1@. The trailing part @sh2@ is never read.

prop> \(Array16 x) (Array16 y) (Array16 z) -> let xyz = Array.append x $ Array.append y z in y == Array.takeCenter xyz
-}
takeCenter ::
   (Shape.C sh0, Shape.C sh1, Shape.C sh2, Storable a) =>
   Array (sh0::+sh1::+sh2) a -> Array sh1 a
takeCenter (Array (sh0::+sh1::+_sh2) x) =
   unsafeCreateWithSize sh1 $ \k yPtr ->
   withForeignPtr x $ \xPtr ->
      copyArray yPtr (advancePtr xPtr (Shape.size sh0)) k



{- |
Sum of all elements: this module's 'foldl' with @(+)@, starting at @0@.

prop> \(Array16 xs)  ->  Array.sum xs == sum (Array.toList xs)
-}
sum :: (Shape.C sh, Storable a, Num a) => Array sh a -> a
sum = foldl (+) 0

{- |
Product of all elements: this module's 'foldl' with @(*)@, starting at @1@.

prop> \(Array16 xs)  ->  Array.product xs == product (Array.toList xs)
-}
product :: (Shape.C sh, Storable a, Num a) => Array sh a -> a
product = foldl (*) 1

{-# INLINE foldl #-}
{- |
Left fold over all @Shape.size sh@ elements of the array's buffer.

Delegates to 'Memory.foldl' with a stride argument of @1@. The 'Int' that
'Memory.foldl' passes to its operator is discarded via 'const'; presumably
it is the element index.

'unsafePerformIO' is used only to read the buffer. NOTE(review): this is
referentially transparent only if an array's buffer is never mutated after
construction. Confirm that no exported function writes into an existing
array.
-}
foldl :: (Shape.C sh, Storable a) => (b -> a -> b) -> b -> Array sh a -> b
foldl op a (Array sh x) = unsafePerformIO $
   withForeignPtr x $ \xPtr ->
      Memory.foldl (const op) a (Shape.size sh) xPtr 1