module Data.Array.Comfort.Boxed (
   Array,
   shape,
   reshape,
   mapShape,
   (!),
   Array.toList,
   Array.fromList,
   Array.vectorFromList,
   toAssociations,
   fromMap,
   toMap,
   fromContainer,
   toContainer,
   indices,
   Array.replicate,

   Array.map,
   zipWith,
   (//),
   accumulate,
   fromAssociations,
   ) where

import qualified Data.Array.Comfort.Boxed.Unchecked as Array
import qualified Data.Array.Comfort.Container as Container
import qualified Data.Array.Comfort.Check as Check
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Boxed.Unchecked (Array(Array))

import qualified Data.Primitive.Array as Prim

import qualified Control.Monad.Primitive as PrimM
import Control.Monad.ST (runST)
import Control.Applicative ((<$>))

import qualified Data.Foldable as Fold
import qualified Data.Map as Map
import qualified Data.Set as Set
import Data.Map (Map)
import Data.Set (Set)
import Data.Foldable (forM_)

import Prelude hiding (zipWith, replicate)


-- | Return the shape stored with an array.
--
-- This is the accessor from "Data.Array.Comfort.Boxed.Unchecked",
-- re-exported here.
shape :: Array.Array sh a -> sh
shape array = Array.shape array

-- | Give an array a new shape.
--
-- All checking is delegated to 'Check.reshape'. It receives:
--
-- * the label @\"Boxed\"@ (presumably used in its error messages),
-- * this module's 'shape' accessor,
-- * the unchecked 'Array.reshape'.
--
-- NOTE(review): presumably rejects a target shape whose size differs from
-- the source shape's size; confirm in "Data.Array.Comfort.Check".
reshape :: (Shape.C sh0, Shape.C sh1) => sh1 -> Array sh0 a -> Array sh1 a
reshape newShape arr = Check.reshape label shape Array.reshape newShape arr
   where label = "Boxed"

-- | Apply a function to the shape of an array and 'reshape' the array to
-- the result.
--
-- Because it goes through 'reshape', the same checking applies.
mapShape ::
   (Shape.C sh0, Shape.C sh1) => (sh0 -> sh1) -> Array sh0 a -> Array sh1 a
mapShape f arr = reshape newShape arr
   where newShape = f (shape arr)


-- | Build an array of shape @sh@ from the list @'Shape.indices' sh@.
--
-- Assuming 'Shape.indices' enumerates positions in storage order, each
-- element is its own index.
indices :: (Shape.Indexed sh) => sh -> Array.Array sh (Shape.Index sh)
indices sh = Array.fromList sh ixs
   where ixs = Shape.indices sh

-- | Turn a 'Map' into an array whose shape is the map's key set.
--
-- The elements are taken in ascending key order ('Map.elems'), which
-- matches the ascending order of the keys in 'Map.keysSet'.
fromMap :: (Ord k) => Map k a -> Array (Set k) a
fromMap m =
   let keys = Map.keysSet m
       vals = Map.elems m
   in  Array.fromList keys vals

-- | Pair the keys of the 'Set' shape, in ascending order, with the array
-- elements in 'Array.toList' order, and collect the pairs in a 'Map'.
--
-- The keys come from a 'Set', so they are strictly ascending and pairwise
-- distinct. 'Map.fromDistinctAscList' is therefore applicable; it builds
-- the same map without the key comparisons that 'Map.fromAscList' performs
-- to merge duplicate keys.
toMap :: (Ord k) => Array (Set k) a -> Map k a
toMap arr =
   Map.fromDistinctAscList $
   zip (Set.toAscList $ shape arr) (Array.toList arr)

-- | Convert a container to an array.
--
-- The array's shape is 'Container.toShape' of the container. Its elements
-- are the container's elements in 'Fold.toList' order.
fromContainer :: (Container.C f) => f a -> Array (Container.Shape f) a
fromContainer xs = Array.fromList sh elems
   where
      sh = Container.toShape xs
      elems = Fold.toList xs

-- | Convert an array back into a container via 'Container.fromList'.
--
-- The array's shape and its elements in 'Array.toList' order are passed
-- to 'Container.fromList'.
toContainer :: (Container.C f) => Array (Container.Shape f) a -> f a
toContainer arr = Container.fromList sh elems
   where
      sh = Array.shape arr
      elems = Array.toList arr


infixl 9 !

-- | Look up the element at an index.
--
-- The index is first tested with 'Shape.inBounds'. An out-of-range index
-- raises an 'error' rather than reaching the raw 'Prim.indexArray'.
(!) :: (Shape.Indexed sh) => Array sh a -> Shape.Index sh -> a
Array sh arr ! ix
   | Shape.inBounds sh ix = Prim.indexArray arr (Shape.offset sh ix)
   | otherwise = error "Array.Comfort.Boxed.!: index out of bounds"


-- | Combine two arrays element by element.
--
-- The shapes of both arrays must be equal ('=='). Otherwise this calls
-- 'error' instead of delegating to 'Array.zipWith'.
zipWith ::
   (Shape.C sh, Eq sh) =>
   (a -> b -> c) -> Array sh a -> Array sh b -> Array sh c
zipWith f xs ys
   | shape xs == shape ys = Array.zipWith f xs ys
   | otherwise = error "zipWith: shapes mismatch"


-- | Update an array with a list of @(index, value)@ pairs.
--
-- The storage is copied with 'Prim.thawArray', so the argument array is
-- left untouched. Pairs are written in list order, so a later pair for
-- the same index overrides an earlier one.
--
-- NOTE(review): unlike @(!)@ there is no explicit 'Shape.inBounds' check.
-- This relies on 'Shape.offset' rejecting out-of-range indices, since
-- 'Prim.writeArray' presumably performs no bounds check itself. Confirm in
-- "Data.Array.Comfort.Shape".
(//) ::
   (Shape.Indexed sh) => Array sh a -> [(Shape.Index sh, a)] -> Array sh a
Array sh arr // xs =
   runST (do
      marr <- Prim.thawArray arr 0 (Shape.size sh)
      mapM_ (\(ix, x) -> Prim.writeArray marr (Shape.offset sh ix) x) xs
      fmap (Array sh) (Prim.unsafeFreezeArray marr))

-- | Fold a list of @(index, value)@ pairs into an array.
--
-- For every pair @(ix, b)@, in list order, the element at @ix@ is replaced
-- by @f old b@. The argument array is first copied with 'Prim.thawArray'
-- and is not modified.
--
-- NOTE(review): as with @(//)@, the index is converted with 'Shape.offset'
-- without an explicit 'Shape.inBounds' check. Confirm that 'Shape.offset'
-- rejects out-of-range indices.
accumulate ::
   (Shape.Indexed sh) =>
   (a -> b -> a) -> Array sh a -> [(Shape.Index sh, b)] -> Array sh a
accumulate f (Array sh arr) xs =
   runST (do
      marr <- Prim.thawArray arr 0 (Shape.size sh)
      mapM_
         (\(ix, b) -> updateArray marr (Shape.offset sh ix) (\old -> f old b))
         xs
      fmap (Array sh) (Prim.unsafeFreezeArray marr))

-- | Replace the element at position @k@ of a mutable array by @f@ applied
-- to its current value.
--
-- Nothing here forces @f x@ before it is written. Assuming
-- 'Prim.writeArray' stores its argument lazily (as boxed array writes do),
-- repeated updates of one slot build up a chain of unevaluated
-- applications until that element is demanded.
updateArray ::
   PrimM.PrimMonad m =>
   Prim.MutableArray (PrimM.PrimState m) a -> Int -> (a -> a) -> m ()
updateArray marr k f = do
   x <- Prim.readArray marr k
   Prim.writeArray marr k (f x)

-- | List every index of the array's shape, in 'Shape.indices' order,
-- together with the element of 'Array.toList' at the same position.
toAssociations :: (Shape.Indexed sh) => Array sh a -> [(Shape.Index sh, a)]
toAssociations arr = zip ixs elems
   where
      ixs = Shape.indices (shape arr)
      elems = Array.toList arr

-- | Build an array of shape @sh@ from a default value and a list of
-- @(index, value)@ pairs.
--
-- Every position starts out holding the default value. The pairs are then
-- written in list order, so a later pair for the same index overrides an
-- earlier one.
--
-- NOTE(review): the index is converted with 'Shape.offset' without an
-- explicit 'Shape.inBounds' check (compare @(!)@). Confirm that
-- 'Shape.offset' rejects out-of-range indices.
fromAssociations ::
   (Shape.Indexed sh) => a -> sh -> [(Shape.Index sh, a)] -> Array sh a
fromAssociations deflt sh xs =
   runST (do
      marr <- Prim.newArray (Shape.size sh) deflt
      mapM_ (\(ix, x) -> Prim.writeArray marr (Shape.offset sh ix) x) xs
      fmap (Array sh) (Prim.unsafeFreezeArray marr))